smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,22 +1,87 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import numpy as np
|
|
4
|
-
import seaborn as sns
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
import scipy.cluster.hierarchy as sch
|
|
7
|
-
import matplotlib.gridspec as gridspec
|
|
8
|
-
import os
|
|
9
3
|
import math
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
7
|
+
|
|
8
|
+
import matplotlib.gridspec as gridspec
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import numpy as np
|
|
10
11
|
import pandas as pd
|
|
12
|
+
import scipy.cluster.hierarchy as sch
|
|
13
|
+
import seaborn as sns
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
17
|
+
"""
|
|
18
|
+
Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
|
|
19
|
+
Always includes 0 and n_positions-1 when possible.
|
|
20
|
+
"""
|
|
21
|
+
n_ticks = int(max(2, n_ticks))
|
|
22
|
+
if n_positions <= n_ticks:
|
|
23
|
+
return np.arange(n_positions)
|
|
24
|
+
|
|
25
|
+
# linspace gives fixed count
|
|
26
|
+
pos = np.linspace(0, n_positions - 1, n_ticks)
|
|
27
|
+
return np.unique(np.round(pos).astype(int))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
|
|
31
|
+
"""
|
|
32
|
+
Select tick labels for the heatmap axis.
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
subset : AnnData view
|
|
37
|
+
The per-bin subset of the AnnData.
|
|
38
|
+
sites : np.ndarray[int]
|
|
39
|
+
Indices of the subset.var positions to annotate.
|
|
40
|
+
reference : str
|
|
41
|
+
Reference name (e.g., '6B6_top').
|
|
42
|
+
index_col_suffix : None or str
|
|
43
|
+
If None → use subset.var_names
|
|
44
|
+
Else → use subset.var[f"{reference}_{index_col_suffix}"]
|
|
45
|
+
|
|
46
|
+
Returns
|
|
47
|
+
-------
|
|
48
|
+
np.ndarray[str]
|
|
49
|
+
The labels to use for tick positions.
|
|
50
|
+
"""
|
|
51
|
+
if sites.size == 0:
|
|
52
|
+
return np.array([])
|
|
53
|
+
|
|
54
|
+
# Default behavior: use var_names
|
|
55
|
+
if index_col_suffix is None:
|
|
56
|
+
return subset.var_names[sites].astype(str)
|
|
57
|
+
|
|
58
|
+
# Otherwise: use a computed column adata.var[f"{reference}_{suffix}"]
|
|
59
|
+
colname = f"{reference}_{index_col_suffix}"
|
|
60
|
+
|
|
61
|
+
if colname not in subset.var:
|
|
62
|
+
raise KeyError(
|
|
63
|
+
f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
|
|
64
|
+
f"but it is not present in adata.var."
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
labels = subset.var[colname].astype(str).values
|
|
68
|
+
return labels[sites]
|
|
11
69
|
|
|
12
|
-
from typing import Optional, Mapping, Sequence, Any, Dict, List
|
|
13
|
-
from pathlib import Path
|
|
14
70
|
|
|
15
71
|
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
|
|
72
|
+
"""Compute normalized column means for a matrix.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
matrix: Input matrix.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
1D array of normalized means.
|
|
79
|
+
"""
|
|
16
80
|
mean = np.nanmean(matrix, axis=0)
|
|
17
81
|
denom = (mean.max() - mean.min()) + 1e-9
|
|
18
82
|
return (mean - mean.min()) / denom
|
|
19
83
|
|
|
84
|
+
|
|
20
85
|
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
21
86
|
"""
|
|
22
87
|
Fraction methylated per column.
|
|
@@ -31,14 +96,20 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
|
31
96
|
valid = valid_mask.sum(axis=0)
|
|
32
97
|
|
|
33
98
|
return np.divide(
|
|
34
|
-
methylated, valid,
|
|
35
|
-
out=np.zeros_like(methylated, dtype=float),
|
|
36
|
-
where=valid != 0
|
|
99
|
+
methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
|
|
37
100
|
)
|
|
38
101
|
|
|
102
|
+
|
|
39
103
|
def clean_barplot(ax, mean_values, title):
|
|
104
|
+
"""Format a barplot with consistent axes and labels.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
ax: Matplotlib axes.
|
|
108
|
+
mean_values: Values to plot.
|
|
109
|
+
title: Plot title.
|
|
110
|
+
"""
|
|
40
111
|
x = np.arange(len(mean_values))
|
|
41
|
-
ax.bar(x, mean_values, color="gray", width=1.0, align=
|
|
112
|
+
ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
|
|
42
113
|
ax.set_xlim(0, len(mean_values))
|
|
43
114
|
ax.set_ylim(0, 1)
|
|
44
115
|
ax.set_yticks([0.0, 0.5, 1.0])
|
|
@@ -47,9 +118,10 @@ def clean_barplot(ax, mean_values, title):
|
|
|
47
118
|
|
|
48
119
|
# Hide all spines except left
|
|
49
120
|
for spine_name, spine in ax.spines.items():
|
|
50
|
-
spine.set_visible(spine_name ==
|
|
121
|
+
spine.set_visible(spine_name == "left")
|
|
122
|
+
|
|
123
|
+
ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
|
|
51
124
|
|
|
52
|
-
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
|
53
125
|
|
|
54
126
|
# def combined_hmm_raw_clustermap(
|
|
55
127
|
# adata,
|
|
@@ -92,7 +164,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
92
164
|
# (adata.obs['read_length'] >= min_length) &
|
|
93
165
|
# (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
|
|
94
166
|
# ]
|
|
95
|
-
|
|
167
|
+
|
|
96
168
|
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
97
169
|
# subset = subset[:, mask]
|
|
98
170
|
|
|
@@ -204,7 +276,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
204
276
|
# clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
|
|
205
277
|
# clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
|
|
206
278
|
# clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
|
|
207
|
-
|
|
279
|
+
|
|
208
280
|
# hmm_labels = subset.var_names.astype(int)
|
|
209
281
|
# hmm_label_spacing = 150
|
|
210
282
|
# sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
@@ -258,7 +330,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
258
330
|
# "bin_boundaries": bin_boundaries,
|
|
259
331
|
# "percentages": percentages
|
|
260
332
|
# })
|
|
261
|
-
|
|
333
|
+
|
|
262
334
|
# #adata.uns['clustermap_results'] = results
|
|
263
335
|
|
|
264
336
|
# except Exception as e:
|
|
@@ -271,83 +343,131 @@ def combined_hmm_raw_clustermap(
|
|
|
271
343
|
adata,
|
|
272
344
|
sample_col: str = "Sample_Names",
|
|
273
345
|
reference_col: str = "Reference_strand",
|
|
274
|
-
|
|
275
346
|
hmm_feature_layer: str = "hmm_combined",
|
|
276
|
-
|
|
277
347
|
layer_gpc: str = "nan0_0minus1",
|
|
278
348
|
layer_cpg: str = "nan0_0minus1",
|
|
279
|
-
|
|
349
|
+
layer_c: str = "nan0_0minus1",
|
|
280
350
|
layer_a: str = "nan0_0minus1",
|
|
281
|
-
|
|
282
351
|
cmap_hmm: str = "tab10",
|
|
283
352
|
cmap_gpc: str = "coolwarm",
|
|
284
353
|
cmap_cpg: str = "viridis",
|
|
285
|
-
|
|
354
|
+
cmap_c: str = "coolwarm",
|
|
286
355
|
cmap_a: str = "coolwarm",
|
|
287
|
-
|
|
288
356
|
min_quality: int = 20,
|
|
289
357
|
min_length: int = 200,
|
|
290
358
|
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
291
359
|
min_position_valid_fraction: float = 0.5,
|
|
292
|
-
|
|
360
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
361
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
293
362
|
save_path: str | Path | None = None,
|
|
294
363
|
normalize_hmm: bool = False,
|
|
295
|
-
|
|
296
364
|
sort_by: str = "gpc",
|
|
297
365
|
bins: Optional[Dict[str, Any]] = None,
|
|
298
|
-
|
|
299
366
|
deaminase: bool = False,
|
|
300
367
|
min_signal: float = 0.0,
|
|
301
|
-
|
|
302
368
|
# ---- fixed tick label controls (counts, not spacing)
|
|
303
369
|
n_xticks_hmm: int = 10,
|
|
304
370
|
n_xticks_any_c: int = 8,
|
|
305
371
|
n_xticks_gpc: int = 8,
|
|
306
372
|
n_xticks_cpg: int = 8,
|
|
307
373
|
n_xticks_a: int = 8,
|
|
374
|
+
index_col_suffix: str | None = None,
|
|
308
375
|
):
|
|
309
376
|
"""
|
|
310
377
|
Makes a multi-panel clustermap per (sample, reference):
|
|
311
|
-
HMM panel (always) + optional raw panels for
|
|
378
|
+
HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.
|
|
312
379
|
|
|
313
380
|
Panels are added only if the corresponding site mask exists AND has >0 sites.
|
|
314
381
|
|
|
315
382
|
sort_by options:
|
|
316
|
-
'gpc', 'cpg', '
|
|
383
|
+
'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
|
|
317
384
|
"""
|
|
385
|
+
|
|
318
386
|
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
387
|
+
"""Pick tick indices/labels from an array."""
|
|
319
388
|
if labels.size == 0:
|
|
320
389
|
return [], []
|
|
321
390
|
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
322
391
|
idx = np.unique(idx)
|
|
323
392
|
return idx.tolist(), labels[idx].tolist()
|
|
324
|
-
|
|
393
|
+
|
|
394
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
395
|
+
def _mask_or_true(series_name: str, predicate):
|
|
396
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
397
|
+
if series_name not in adata.obs:
|
|
398
|
+
return pd.Series(True, index=adata.obs.index)
|
|
399
|
+
s = adata.obs[series_name]
|
|
400
|
+
try:
|
|
401
|
+
return predicate(s)
|
|
402
|
+
except Exception:
|
|
403
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
404
|
+
return pd.Series(True, index=adata.obs.index)
|
|
405
|
+
|
|
325
406
|
results = []
|
|
326
407
|
signal_type = "deamination" if deaminase else "methylation"
|
|
327
408
|
|
|
328
409
|
for ref in adata.obs[reference_col].cat.categories:
|
|
329
410
|
for sample in adata.obs[sample_col].cat.categories:
|
|
411
|
+
# Optionally remap sample label for display
|
|
412
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
413
|
+
# Row-level masks (obs)
|
|
414
|
+
qmask = _mask_or_true(
|
|
415
|
+
"read_quality",
|
|
416
|
+
(lambda s: s >= float(min_quality))
|
|
417
|
+
if (min_quality is not None)
|
|
418
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
419
|
+
)
|
|
420
|
+
lm_mask = _mask_or_true(
|
|
421
|
+
"mapped_length",
|
|
422
|
+
(lambda s: s >= float(min_length))
|
|
423
|
+
if (min_length is not None)
|
|
424
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
425
|
+
)
|
|
426
|
+
lrr_mask = _mask_or_true(
|
|
427
|
+
"mapped_length_to_reference_length_ratio",
|
|
428
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
429
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
430
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
431
|
+
)
|
|
432
|
+
|
|
433
|
+
demux_mask = _mask_or_true(
|
|
434
|
+
"demux_type",
|
|
435
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
436
|
+
if (demux_types is not None)
|
|
437
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
441
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
442
|
+
|
|
443
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
444
|
+
|
|
445
|
+
if not bool(row_mask.any()):
|
|
446
|
+
print(
|
|
447
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
448
|
+
)
|
|
449
|
+
continue
|
|
330
450
|
|
|
331
451
|
try:
|
|
332
452
|
# ---- subset reads ----
|
|
333
|
-
subset = adata[
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
>
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
subset = subset[:, mask]
|
|
453
|
+
subset = adata[row_mask, :].copy()
|
|
454
|
+
|
|
455
|
+
# Column-level mask (var)
|
|
456
|
+
if min_position_valid_fraction is not None:
|
|
457
|
+
valid_key = f"{ref}_valid_fraction"
|
|
458
|
+
if valid_key in subset.var:
|
|
459
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
460
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
461
|
+
if col_mask.any():
|
|
462
|
+
subset = subset[:, col_mask].copy()
|
|
463
|
+
else:
|
|
464
|
+
print(
|
|
465
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
466
|
+
)
|
|
467
|
+
continue
|
|
349
468
|
|
|
350
469
|
if subset.shape[0] == 0:
|
|
470
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
351
471
|
continue
|
|
352
472
|
|
|
353
473
|
# ---- bins ----
|
|
@@ -358,6 +478,7 @@ def combined_hmm_raw_clustermap(
|
|
|
358
478
|
|
|
359
479
|
# ---- site masks (robust) ----
|
|
360
480
|
def _sites(*keys):
|
|
481
|
+
"""Return indices for the first matching site key."""
|
|
361
482
|
for k in keys:
|
|
362
483
|
if k in subset.var:
|
|
363
484
|
return np.where(subset.var[k].values)[0]
|
|
@@ -368,13 +489,14 @@ def combined_hmm_raw_clustermap(
|
|
|
368
489
|
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
369
490
|
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
370
491
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
492
|
+
# ---- labels via _select_labels ----
|
|
493
|
+
# HMM uses *all* columns
|
|
494
|
+
hmm_sites = np.arange(subset.n_vars, dtype=int)
|
|
495
|
+
hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
|
|
496
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
497
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
498
|
+
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
499
|
+
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
378
500
|
|
|
379
501
|
# storage
|
|
380
502
|
stacked_hmm = []
|
|
@@ -411,11 +533,11 @@ def combined_hmm_raw_clustermap(
|
|
|
411
533
|
linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
412
534
|
order = sch.leaves_list(linkage)
|
|
413
535
|
|
|
414
|
-
elif sort_by == "
|
|
415
|
-
linkage = sch.linkage(sb[:, any_c_sites].layers[
|
|
536
|
+
elif sort_by == "c" and any_c_sites.size:
|
|
537
|
+
linkage = sch.linkage(sb[:, any_c_sites].layers[layer_c], method="ward")
|
|
416
538
|
order = sch.leaves_list(linkage)
|
|
417
539
|
|
|
418
|
-
elif sort_by == "
|
|
540
|
+
elif sort_by == "a" and any_a_sites.size:
|
|
419
541
|
linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
|
|
420
542
|
order = sch.leaves_list(linkage)
|
|
421
543
|
|
|
@@ -423,6 +545,12 @@ def combined_hmm_raw_clustermap(
|
|
|
423
545
|
linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
|
|
424
546
|
order = sch.leaves_list(linkage)
|
|
425
547
|
|
|
548
|
+
elif sort_by == "hmm" and hmm_sites.size:
|
|
549
|
+
linkage = sch.linkage(
|
|
550
|
+
sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
|
|
551
|
+
)
|
|
552
|
+
order = sch.leaves_list(linkage)
|
|
553
|
+
|
|
426
554
|
else:
|
|
427
555
|
order = np.arange(n)
|
|
428
556
|
|
|
@@ -431,7 +559,7 @@ def combined_hmm_raw_clustermap(
|
|
|
431
559
|
# ---- collect matrices ----
|
|
432
560
|
stacked_hmm.append(sb.layers[hmm_feature_layer])
|
|
433
561
|
if any_c_sites.size:
|
|
434
|
-
stacked_any_c.append(sb[:, any_c_sites].layers[
|
|
562
|
+
stacked_any_c.append(sb[:, any_c_sites].layers[layer_c])
|
|
435
563
|
if gpc_sites.size:
|
|
436
564
|
stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
|
|
437
565
|
if cpg_sites.size:
|
|
@@ -446,46 +574,62 @@ def combined_hmm_raw_clustermap(
|
|
|
446
574
|
|
|
447
575
|
# ---------------- stack ----------------
|
|
448
576
|
hmm_matrix = np.vstack(stacked_hmm)
|
|
449
|
-
mean_hmm =
|
|
577
|
+
mean_hmm = (
|
|
578
|
+
normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
|
|
579
|
+
)
|
|
450
580
|
|
|
451
581
|
panels = [
|
|
452
|
-
(
|
|
582
|
+
(
|
|
583
|
+
f"HMM - {hmm_feature_layer}",
|
|
584
|
+
hmm_matrix,
|
|
585
|
+
hmm_labels,
|
|
586
|
+
cmap_hmm,
|
|
587
|
+
mean_hmm,
|
|
588
|
+
n_xticks_hmm,
|
|
589
|
+
),
|
|
453
590
|
]
|
|
454
591
|
|
|
455
592
|
if stacked_any_c:
|
|
456
593
|
m = np.vstack(stacked_any_c)
|
|
457
|
-
panels.append(
|
|
594
|
+
panels.append(
|
|
595
|
+
("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
|
|
596
|
+
)
|
|
458
597
|
|
|
459
598
|
if stacked_gpc:
|
|
460
599
|
m = np.vstack(stacked_gpc)
|
|
461
|
-
panels.append(
|
|
600
|
+
panels.append(
|
|
601
|
+
("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
|
|
602
|
+
)
|
|
462
603
|
|
|
463
604
|
if stacked_cpg:
|
|
464
605
|
m = np.vstack(stacked_cpg)
|
|
465
|
-
panels.append(
|
|
606
|
+
panels.append(
|
|
607
|
+
("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
|
|
608
|
+
)
|
|
466
609
|
|
|
467
610
|
if stacked_any_a:
|
|
468
611
|
m = np.vstack(stacked_any_a)
|
|
469
|
-
panels.append(
|
|
612
|
+
panels.append(
|
|
613
|
+
("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
|
|
614
|
+
)
|
|
470
615
|
|
|
471
616
|
# ---------------- plotting ----------------
|
|
472
617
|
n_panels = len(panels)
|
|
473
618
|
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
474
619
|
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
475
|
-
fig.suptitle(
|
|
476
|
-
|
|
620
|
+
fig.suptitle(
|
|
621
|
+
f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
|
|
622
|
+
)
|
|
477
623
|
|
|
478
624
|
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
479
625
|
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
480
626
|
|
|
481
627
|
for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
|
|
482
|
-
|
|
483
628
|
# ---- your clean barplot ----
|
|
484
629
|
clean_barplot(axes_bar[i], mean_vec, name)
|
|
485
630
|
|
|
486
631
|
# ---- heatmap ----
|
|
487
|
-
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
|
|
488
|
-
yticklabels=False, cbar=False)
|
|
632
|
+
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
|
|
489
633
|
|
|
490
634
|
# ---- xticks ----
|
|
491
635
|
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
@@ -509,6 +653,7 @@ def combined_hmm_raw_clustermap(
|
|
|
509
653
|
|
|
510
654
|
except Exception:
|
|
511
655
|
import traceback
|
|
656
|
+
|
|
512
657
|
traceback.print_exc()
|
|
513
658
|
continue
|
|
514
659
|
|
|
@@ -628,7 +773,7 @@ def combined_hmm_raw_clustermap(
|
|
|
628
773
|
# order = np.arange(num_reads)
|
|
629
774
|
# elif sort_by == "any_a":
|
|
630
775
|
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
631
|
-
# order = sch.leaves_list(linkage)
|
|
776
|
+
# order = sch.leaves_list(linkage)
|
|
632
777
|
# else:
|
|
633
778
|
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
634
779
|
|
|
@@ -657,13 +802,13 @@ def combined_hmm_raw_clustermap(
|
|
|
657
802
|
# order = np.arange(num_reads)
|
|
658
803
|
# elif sort_by == "any_a":
|
|
659
804
|
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
660
|
-
# order = sch.leaves_list(linkage)
|
|
805
|
+
# order = sch.leaves_list(linkage)
|
|
661
806
|
# else:
|
|
662
807
|
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
663
|
-
|
|
808
|
+
|
|
664
809
|
# stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
|
|
665
|
-
|
|
666
|
-
|
|
810
|
+
|
|
811
|
+
|
|
667
812
|
# row_labels.extend([bin_label] * num_reads)
|
|
668
813
|
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
669
814
|
# last_idx += num_reads
|
|
@@ -686,7 +831,7 @@ def combined_hmm_raw_clustermap(
|
|
|
686
831
|
# if any_a_matrix.size > 0:
|
|
687
832
|
# mean_any_a = methylation_fraction(any_a_matrix)
|
|
688
833
|
# gs_dim += 1
|
|
689
|
-
|
|
834
|
+
|
|
690
835
|
|
|
691
836
|
# fig = plt.figure(figsize=(18, 12))
|
|
692
837
|
# gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
|
|
@@ -718,8 +863,8 @@ def combined_hmm_raw_clustermap(
|
|
|
718
863
|
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
719
864
|
# axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
720
865
|
# for boundary in bin_boundaries[:-1]:
|
|
721
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
722
|
-
# current_ax +=1
|
|
866
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
867
|
+
# current_ax +=1
|
|
723
868
|
|
|
724
869
|
# results.append({
|
|
725
870
|
# "sample": sample,
|
|
@@ -731,7 +876,7 @@ def combined_hmm_raw_clustermap(
|
|
|
731
876
|
# "bin_labels": bin_labels,
|
|
732
877
|
# "bin_boundaries": bin_boundaries,
|
|
733
878
|
# "percentages": percentages
|
|
734
|
-
# })
|
|
879
|
+
# })
|
|
735
880
|
|
|
736
881
|
# if stacked_any_a:
|
|
737
882
|
# if any_a_matrix.size > 0:
|
|
@@ -751,7 +896,7 @@ def combined_hmm_raw_clustermap(
|
|
|
751
896
|
# "bin_labels": bin_labels,
|
|
752
897
|
# "bin_boundaries": bin_boundaries,
|
|
753
898
|
# "percentages": percentages
|
|
754
|
-
# })
|
|
899
|
+
# })
|
|
755
900
|
|
|
756
901
|
# plt.tight_layout()
|
|
757
902
|
|
|
@@ -769,7 +914,7 @@ def combined_hmm_raw_clustermap(
|
|
|
769
914
|
# print(f"Summary for {sample} - {ref}:")
|
|
770
915
|
# for bin_label, percent in percentages.items():
|
|
771
916
|
# print(f" - {bin_label}: {percent:.1f}%")
|
|
772
|
-
|
|
917
|
+
|
|
773
918
|
# adata.uns['clustermap_results'] = results
|
|
774
919
|
|
|
775
920
|
# except Exception as e:
|
|
@@ -777,52 +922,41 @@ def combined_hmm_raw_clustermap(
|
|
|
777
922
|
# traceback.print_exc()
|
|
778
923
|
# continue
|
|
779
924
|
|
|
780
|
-
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
781
|
-
"""
|
|
782
|
-
Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
|
|
783
|
-
Always includes 0 and n_positions-1 when possible.
|
|
784
|
-
"""
|
|
785
|
-
n_ticks = int(max(2, n_ticks))
|
|
786
|
-
if n_positions <= n_ticks:
|
|
787
|
-
return np.arange(n_positions)
|
|
788
|
-
|
|
789
|
-
# linspace gives fixed count
|
|
790
|
-
pos = np.linspace(0, n_positions - 1, n_ticks)
|
|
791
|
-
return np.unique(np.round(pos).astype(int))
|
|
792
925
|
|
|
793
926
|
def combined_raw_clustermap(
|
|
794
927
|
adata,
|
|
795
928
|
sample_col: str = "Sample_Names",
|
|
796
929
|
reference_col: str = "Reference_strand",
|
|
797
930
|
mod_target_bases: Sequence[str] = ("GpC", "CpG"),
|
|
798
|
-
|
|
931
|
+
layer_c: str = "nan0_0minus1",
|
|
799
932
|
layer_gpc: str = "nan0_0minus1",
|
|
800
933
|
layer_cpg: str = "nan0_0minus1",
|
|
801
934
|
layer_a: str = "nan0_0minus1",
|
|
802
|
-
|
|
935
|
+
cmap_c: str = "coolwarm",
|
|
803
936
|
cmap_gpc: str = "coolwarm",
|
|
804
937
|
cmap_cpg: str = "viridis",
|
|
805
938
|
cmap_a: str = "coolwarm",
|
|
806
|
-
min_quality: float = 20,
|
|
807
|
-
min_length: int = 200,
|
|
808
|
-
min_mapped_length_to_reference_length_ratio: float = 0
|
|
809
|
-
min_position_valid_fraction: float = 0
|
|
939
|
+
min_quality: float | None = 20,
|
|
940
|
+
min_length: int | None = 200,
|
|
941
|
+
min_mapped_length_to_reference_length_ratio: float | None = 0,
|
|
942
|
+
min_position_valid_fraction: float | None = 0,
|
|
943
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
810
944
|
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
811
945
|
save_path: str | Path | None = None,
|
|
812
|
-
sort_by: str = "gpc", # 'gpc','cpg','
|
|
946
|
+
sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
|
|
813
947
|
bins: Optional[Dict[str, Any]] = None,
|
|
814
948
|
deaminase: bool = False,
|
|
815
949
|
min_signal: float = 0,
|
|
816
|
-
# NEW tick controls
|
|
817
950
|
n_xticks_any_c: int = 10,
|
|
818
951
|
n_xticks_gpc: int = 10,
|
|
819
952
|
n_xticks_cpg: int = 10,
|
|
820
953
|
n_xticks_any_a: int = 10,
|
|
821
954
|
xtick_rotation: int = 90,
|
|
822
955
|
xtick_fontsize: int = 9,
|
|
956
|
+
index_col_suffix: str | None = None,
|
|
823
957
|
):
|
|
824
958
|
"""
|
|
825
|
-
Plot stacked heatmaps + per-position mean barplots for
|
|
959
|
+
Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
|
|
826
960
|
|
|
827
961
|
Key fixes vs old version:
|
|
828
962
|
- order computed ONCE per bin, applied to all matrices
|
|
@@ -838,6 +972,18 @@ def combined_raw_clustermap(
|
|
|
838
972
|
One entry per (sample, ref) plot with matrices + bin metadata.
|
|
839
973
|
"""
|
|
840
974
|
|
|
975
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
976
|
+
def _mask_or_true(series_name: str, predicate):
|
|
977
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
978
|
+
if series_name not in adata.obs:
|
|
979
|
+
return pd.Series(True, index=adata.obs.index)
|
|
980
|
+
s = adata.obs[series_name]
|
|
981
|
+
try:
|
|
982
|
+
return predicate(s)
|
|
983
|
+
except Exception:
|
|
984
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
985
|
+
return pd.Series(True, index=adata.obs.index)
|
|
986
|
+
|
|
841
987
|
results: List[Dict[str, Any]] = []
|
|
842
988
|
save_path = Path(save_path) if save_path is not None else None
|
|
843
989
|
if save_path is not None:
|
|
@@ -856,24 +1002,63 @@ def combined_raw_clustermap(
|
|
|
856
1002
|
|
|
857
1003
|
for ref in adata.obs[reference_col].cat.categories:
|
|
858
1004
|
for sample in adata.obs[sample_col].cat.categories:
|
|
859
|
-
|
|
860
1005
|
# Optionally remap sample label for display
|
|
861
1006
|
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
862
1007
|
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
1008
|
+
# Row-level masks (obs)
|
|
1009
|
+
qmask = _mask_or_true(
|
|
1010
|
+
"read_quality",
|
|
1011
|
+
(lambda s: s >= float(min_quality))
|
|
1012
|
+
if (min_quality is not None)
|
|
1013
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1014
|
+
)
|
|
1015
|
+
lm_mask = _mask_or_true(
|
|
1016
|
+
"mapped_length",
|
|
1017
|
+
(lambda s: s >= float(min_length))
|
|
1018
|
+
if (min_length is not None)
|
|
1019
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1020
|
+
)
|
|
1021
|
+
lrr_mask = _mask_or_true(
|
|
1022
|
+
"mapped_length_to_reference_length_ratio",
|
|
1023
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
1024
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
1025
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
demux_mask = _mask_or_true(
|
|
1029
|
+
"demux_type",
|
|
1030
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
1031
|
+
if (demux_types is not None)
|
|
1032
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1033
|
+
)
|
|
1034
|
+
|
|
1035
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
1036
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
1037
|
+
|
|
1038
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
1039
|
+
|
|
1040
|
+
if not bool(row_mask.any()):
|
|
1041
|
+
print(
|
|
1042
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
1043
|
+
)
|
|
1044
|
+
continue
|
|
871
1045
|
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
1046
|
+
try:
|
|
1047
|
+
subset = adata[row_mask, :].copy()
|
|
1048
|
+
|
|
1049
|
+
# Column-level mask (var)
|
|
1050
|
+
if min_position_valid_fraction is not None:
|
|
1051
|
+
valid_key = f"{ref}_valid_fraction"
|
|
1052
|
+
if valid_key in subset.var:
|
|
1053
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
1054
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
1055
|
+
if col_mask.any():
|
|
1056
|
+
subset = subset[:, col_mask].copy()
|
|
1057
|
+
else:
|
|
1058
|
+
print(
|
|
1059
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
1060
|
+
)
|
|
1061
|
+
continue
|
|
877
1062
|
|
|
878
1063
|
if subset.shape[0] == 0:
|
|
879
1064
|
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
@@ -893,19 +1078,19 @@ def combined_raw_clustermap(
|
|
|
893
1078
|
|
|
894
1079
|
if include_any_c:
|
|
895
1080
|
any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
|
|
896
|
-
gpc_sites
|
|
897
|
-
cpg_sites
|
|
1081
|
+
gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
|
|
1082
|
+
cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
|
|
898
1083
|
|
|
899
1084
|
num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
|
|
900
1085
|
|
|
901
|
-
any_c_labels = subset
|
|
902
|
-
gpc_labels
|
|
903
|
-
cpg_labels
|
|
1086
|
+
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
1087
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
1088
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
904
1089
|
|
|
905
1090
|
if include_any_a:
|
|
906
1091
|
any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
|
|
907
1092
|
num_any_a = len(any_a_sites)
|
|
908
|
-
any_a_labels = subset
|
|
1093
|
+
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
909
1094
|
|
|
910
1095
|
stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
|
|
911
1096
|
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
@@ -932,23 +1117,31 @@ def combined_raw_clustermap(
|
|
|
932
1117
|
order = np.argsort(subset_bin.obs[colname].values)
|
|
933
1118
|
|
|
934
1119
|
elif sort_by == "gpc" and num_gpc > 0:
|
|
935
|
-
linkage = sch.linkage(
|
|
1120
|
+
linkage = sch.linkage(
|
|
1121
|
+
subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
|
|
1122
|
+
)
|
|
936
1123
|
order = sch.leaves_list(linkage)
|
|
937
1124
|
|
|
938
1125
|
elif sort_by == "cpg" and num_cpg > 0:
|
|
939
|
-
linkage = sch.linkage(
|
|
1126
|
+
linkage = sch.linkage(
|
|
1127
|
+
subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
|
|
1128
|
+
)
|
|
940
1129
|
order = sch.leaves_list(linkage)
|
|
941
1130
|
|
|
942
|
-
elif sort_by == "
|
|
943
|
-
linkage = sch.linkage(
|
|
1131
|
+
elif sort_by == "c" and num_any_c > 0:
|
|
1132
|
+
linkage = sch.linkage(
|
|
1133
|
+
subset_bin[:, any_c_sites].layers[layer_c], method="ward"
|
|
1134
|
+
)
|
|
944
1135
|
order = sch.leaves_list(linkage)
|
|
945
1136
|
|
|
946
1137
|
elif sort_by == "gpc_cpg":
|
|
947
1138
|
linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
948
1139
|
order = sch.leaves_list(linkage)
|
|
949
1140
|
|
|
950
|
-
elif sort_by == "
|
|
951
|
-
linkage = sch.linkage(
|
|
1141
|
+
elif sort_by == "a" and num_any_a > 0:
|
|
1142
|
+
linkage = sch.linkage(
|
|
1143
|
+
subset_bin[:, any_a_sites].layers[layer_a], method="ward"
|
|
1144
|
+
)
|
|
952
1145
|
order = sch.leaves_list(linkage)
|
|
953
1146
|
|
|
954
1147
|
elif sort_by == "none":
|
|
@@ -961,7 +1154,7 @@ def combined_raw_clustermap(
|
|
|
961
1154
|
|
|
962
1155
|
# stack consistently
|
|
963
1156
|
if include_any_c and num_any_c > 0:
|
|
964
|
-
stacked_any_c.append(subset_bin[:, any_c_sites].layers[
|
|
1157
|
+
stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_c])
|
|
965
1158
|
if include_any_c and num_gpc > 0:
|
|
966
1159
|
stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
|
|
967
1160
|
if include_any_c and num_cpg > 0:
|
|
@@ -981,57 +1174,65 @@ def combined_raw_clustermap(
|
|
|
981
1174
|
|
|
982
1175
|
if include_any_c and stacked_any_c:
|
|
983
1176
|
any_c_matrix = np.vstack(stacked_any_c)
|
|
984
|
-
gpc_matrix
|
|
985
|
-
cpg_matrix
|
|
1177
|
+
gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
|
|
1178
|
+
cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
|
|
986
1179
|
|
|
987
1180
|
mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
|
|
988
|
-
mean_gpc
|
|
989
|
-
mean_cpg
|
|
1181
|
+
mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
|
|
1182
|
+
mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
|
|
990
1183
|
|
|
991
1184
|
if any_c_matrix.size:
|
|
992
|
-
blocks.append(
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1185
|
+
blocks.append(
|
|
1186
|
+
dict(
|
|
1187
|
+
name="c",
|
|
1188
|
+
matrix=any_c_matrix,
|
|
1189
|
+
mean=mean_any_c,
|
|
1190
|
+
labels=any_c_labels,
|
|
1191
|
+
cmap=cmap_c,
|
|
1192
|
+
n_xticks=n_xticks_any_c,
|
|
1193
|
+
title="any C site Modification Signal",
|
|
1194
|
+
)
|
|
1195
|
+
)
|
|
1001
1196
|
if gpc_matrix.size:
|
|
1002
|
-
blocks.append(
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1197
|
+
blocks.append(
|
|
1198
|
+
dict(
|
|
1199
|
+
name="gpc",
|
|
1200
|
+
matrix=gpc_matrix,
|
|
1201
|
+
mean=mean_gpc,
|
|
1202
|
+
labels=gpc_labels,
|
|
1203
|
+
cmap=cmap_gpc,
|
|
1204
|
+
n_xticks=n_xticks_gpc,
|
|
1205
|
+
title="GpC Modification Signal",
|
|
1206
|
+
)
|
|
1207
|
+
)
|
|
1011
1208
|
if cpg_matrix.size:
|
|
1012
|
-
blocks.append(
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1209
|
+
blocks.append(
|
|
1210
|
+
dict(
|
|
1211
|
+
name="cpg",
|
|
1212
|
+
matrix=cpg_matrix,
|
|
1213
|
+
mean=mean_cpg,
|
|
1214
|
+
labels=cpg_labels,
|
|
1215
|
+
cmap=cmap_cpg,
|
|
1216
|
+
n_xticks=n_xticks_cpg,
|
|
1217
|
+
title="CpG Modification Signal",
|
|
1218
|
+
)
|
|
1219
|
+
)
|
|
1021
1220
|
|
|
1022
1221
|
if include_any_a and stacked_any_a:
|
|
1023
1222
|
any_a_matrix = np.vstack(stacked_any_a)
|
|
1024
1223
|
mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
|
|
1025
1224
|
if any_a_matrix.size:
|
|
1026
|
-
blocks.append(
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1225
|
+
blocks.append(
|
|
1226
|
+
dict(
|
|
1227
|
+
name="a",
|
|
1228
|
+
matrix=any_a_matrix,
|
|
1229
|
+
mean=mean_any_a,
|
|
1230
|
+
labels=any_a_labels,
|
|
1231
|
+
cmap=cmap_a,
|
|
1232
|
+
n_xticks=n_xticks_any_a,
|
|
1233
|
+
title="any A site Modification Signal",
|
|
1234
|
+
)
|
|
1235
|
+
)
|
|
1035
1236
|
|
|
1036
1237
|
if not blocks:
|
|
1037
1238
|
print(f"No matrices to plot for {display_sample} - {ref}")
|
|
@@ -1043,7 +1244,7 @@ def combined_raw_clustermap(
|
|
|
1043
1244
|
fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
|
|
1044
1245
|
|
|
1045
1246
|
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
1046
|
-
axes_bar
|
|
1247
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
1047
1248
|
|
|
1048
1249
|
# ----------------------------
|
|
1049
1250
|
# plot blocks
|
|
@@ -1059,20 +1260,14 @@ def combined_raw_clustermap(
|
|
|
1059
1260
|
|
|
1060
1261
|
# heatmap
|
|
1061
1262
|
sns.heatmap(
|
|
1062
|
-
mat,
|
|
1063
|
-
cmap=blk["cmap"],
|
|
1064
|
-
ax=axes_heat[i],
|
|
1065
|
-
yticklabels=False,
|
|
1066
|
-
cbar=False
|
|
1263
|
+
mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
|
|
1067
1264
|
)
|
|
1068
1265
|
|
|
1069
1266
|
# fixed tick labels
|
|
1070
1267
|
tick_pos = _fixed_tick_positions(len(labels), n_xticks)
|
|
1071
1268
|
axes_heat[i].set_xticks(tick_pos)
|
|
1072
1269
|
axes_heat[i].set_xticklabels(
|
|
1073
|
-
labels[tick_pos],
|
|
1074
|
-
rotation=xtick_rotation,
|
|
1075
|
-
fontsize=xtick_fontsize
|
|
1270
|
+
labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
|
|
1076
1271
|
)
|
|
1077
1272
|
|
|
1078
1273
|
# bin separators
|
|
@@ -1085,7 +1280,12 @@ def combined_raw_clustermap(
|
|
|
1085
1280
|
|
|
1086
1281
|
# save or show
|
|
1087
1282
|
if save_path is not None:
|
|
1088
|
-
safe_name =
|
|
1283
|
+
safe_name = (
|
|
1284
|
+
f"{ref}__{display_sample}".replace("=", "")
|
|
1285
|
+
.replace("__", "_")
|
|
1286
|
+
.replace(",", "_")
|
|
1287
|
+
.replace(" ", "_")
|
|
1288
|
+
)
|
|
1089
1289
|
out_file = save_path / f"{safe_name}.png"
|
|
1090
1290
|
fig.savefig(out_file, dpi=300)
|
|
1091
1291
|
plt.close(fig)
|
|
@@ -1111,20 +1311,15 @@ def combined_raw_clustermap(
|
|
|
1111
1311
|
for bin_label, percent in percentages.items():
|
|
1112
1312
|
print(f" - {bin_label}: {percent:.1f}%")
|
|
1113
1313
|
|
|
1114
|
-
except Exception
|
|
1314
|
+
except Exception:
|
|
1115
1315
|
import traceback
|
|
1316
|
+
|
|
1116
1317
|
traceback.print_exc()
|
|
1117
1318
|
continue
|
|
1118
1319
|
|
|
1119
|
-
# store once at the end (HDF5 safe)
|
|
1120
|
-
# matrices won't be HDF5-safe; store only metadata + maybe hit counts
|
|
1121
|
-
# adata.uns["clustermap_results"] = [
|
|
1122
|
-
# {k: v for k, v in r.items() if not k.endswith("_matrix")}
|
|
1123
|
-
# for r in results
|
|
1124
|
-
# ]
|
|
1125
|
-
|
|
1126
1320
|
return results
|
|
1127
1321
|
|
|
1322
|
+
|
|
1128
1323
|
def plot_hmm_layers_rolling_by_sample_ref(
|
|
1129
1324
|
adata,
|
|
1130
1325
|
layers: Optional[Sequence[str]] = None,
|
|
@@ -1141,7 +1336,7 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1141
1336
|
output_dir: Optional[str] = None,
|
|
1142
1337
|
save: bool = True,
|
|
1143
1338
|
show_raw: bool = False,
|
|
1144
|
-
cmap: str = "
|
|
1339
|
+
cmap: str = "tab20",
|
|
1145
1340
|
use_var_coords: bool = True,
|
|
1146
1341
|
):
|
|
1147
1342
|
"""
|
|
@@ -1191,7 +1386,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1191
1386
|
|
|
1192
1387
|
# --- basic checks / defaults ---
|
|
1193
1388
|
if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
|
|
1194
|
-
raise ValueError(
|
|
1389
|
+
raise ValueError(
|
|
1390
|
+
f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
|
|
1391
|
+
)
|
|
1195
1392
|
|
|
1196
1393
|
# canonicalize samples / refs
|
|
1197
1394
|
if samples is None:
|
|
@@ -1214,7 +1411,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1214
1411
|
if layers is None:
|
|
1215
1412
|
layers = list(adata.layers.keys())
|
|
1216
1413
|
if len(layers) == 0:
|
|
1217
|
-
raise ValueError(
|
|
1414
|
+
raise ValueError(
|
|
1415
|
+
"No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
|
|
1416
|
+
)
|
|
1218
1417
|
layers = list(layers)
|
|
1219
1418
|
|
|
1220
1419
|
# x coordinates (positions)
|
|
@@ -1253,19 +1452,29 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1253
1452
|
|
|
1254
1453
|
fig_w = figsize_per_cell[0] * ncols
|
|
1255
1454
|
fig_h = figsize_per_cell[1] * nrows
|
|
1256
|
-
fig, axes = plt.subplots(
|
|
1257
|
-
|
|
1258
|
-
|
|
1455
|
+
fig, axes = plt.subplots(
|
|
1456
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
1457
|
+
)
|
|
1259
1458
|
|
|
1260
1459
|
for r_idx, sample_name in enumerate(chunk):
|
|
1261
1460
|
for c_idx, ref_name in enumerate(refs_all):
|
|
1262
1461
|
ax = axes[r_idx][c_idx]
|
|
1263
1462
|
|
|
1264
1463
|
# subset adata
|
|
1265
|
-
mask = (adata.obs[sample_col].values == sample_name) & (
|
|
1464
|
+
mask = (adata.obs[sample_col].values == sample_name) & (
|
|
1465
|
+
adata.obs[ref_col].values == ref_name
|
|
1466
|
+
)
|
|
1266
1467
|
sub = adata[mask]
|
|
1267
1468
|
if sub.n_obs == 0:
|
|
1268
|
-
ax.text(
|
|
1469
|
+
ax.text(
|
|
1470
|
+
0.5,
|
|
1471
|
+
0.5,
|
|
1472
|
+
"No reads",
|
|
1473
|
+
ha="center",
|
|
1474
|
+
va="center",
|
|
1475
|
+
transform=ax.transAxes,
|
|
1476
|
+
color="gray",
|
|
1477
|
+
)
|
|
1269
1478
|
ax.set_xticks([])
|
|
1270
1479
|
ax.set_yticks([])
|
|
1271
1480
|
if r_idx == 0:
|
|
@@ -1315,7 +1524,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1315
1524
|
smoothed = col_mean
|
|
1316
1525
|
else:
|
|
1317
1526
|
ser = pd.Series(col_mean)
|
|
1318
|
-
smoothed =
|
|
1527
|
+
smoothed = (
|
|
1528
|
+
ser.rolling(window=window, min_periods=min_periods, center=center)
|
|
1529
|
+
.mean()
|
|
1530
|
+
.to_numpy()
|
|
1531
|
+
)
|
|
1319
1532
|
|
|
1320
1533
|
# x axis: x_coords (trim/pad to match length)
|
|
1321
1534
|
L = len(col_mean)
|
|
@@ -1325,7 +1538,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1325
1538
|
if show_raw:
|
|
1326
1539
|
ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
|
|
1327
1540
|
|
|
1328
|
-
ax.plot(
|
|
1541
|
+
ax.plot(
|
|
1542
|
+
x,
|
|
1543
|
+
smoothed[:L],
|
|
1544
|
+
label=layer,
|
|
1545
|
+
color=colors[li],
|
|
1546
|
+
linewidth=1.2,
|
|
1547
|
+
alpha=0.95,
|
|
1548
|
+
zorder=2,
|
|
1549
|
+
)
|
|
1329
1550
|
plotted_any = True
|
|
1330
1551
|
|
|
1331
1552
|
# labels / titles
|
|
@@ -1343,11 +1564,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1343
1564
|
|
|
1344
1565
|
ax.grid(True, alpha=0.2)
|
|
1345
1566
|
|
|
1346
|
-
fig.suptitle(
|
|
1567
|
+
fig.suptitle(
|
|
1568
|
+
f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
|
|
1569
|
+
fontsize=11,
|
|
1570
|
+
y=0.995,
|
|
1571
|
+
)
|
|
1347
1572
|
fig.tight_layout(rect=[0, 0, 1, 0.97])
|
|
1348
1573
|
|
|
1349
1574
|
if save:
|
|
1350
|
-
fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
|
|
1575
|
+
fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
|
|
1351
1576
|
plt.savefig(fname, bbox_inches="tight", dpi=dpi)
|
|
1352
1577
|
saved_files.append(fname)
|
|
1353
1578
|
else:
|