smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +7 -1
- smftools/cli/hmm_adata.py +902 -244
- smftools/cli/load_adata.py +318 -198
- smftools/cli/preprocess_adata.py +285 -171
- smftools/cli/spatial_adata.py +137 -53
- smftools/cli_entry.py +94 -178
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +22 -17
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +505 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2125 -1426
- smftools/hmm/__init__.py +2 -3
- smftools/hmm/archived/call_hmm_peaks.py +16 -1
- smftools/hmm/call_hmm_peaks.py +173 -193
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +379 -156
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +195 -29
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +347 -168
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +145 -85
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +8 -8
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +103 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +688 -271
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.4.dist-info/RECORD +0 -176
- /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import numpy as np
|
|
4
|
-
import seaborn as sns
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
import scipy.cluster.hierarchy as sch
|
|
7
|
-
import matplotlib.gridspec as gridspec
|
|
8
|
-
import os
|
|
9
3
|
import math
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
7
|
+
|
|
8
|
+
import matplotlib.gridspec as gridspec
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import numpy as np
|
|
10
11
|
import pandas as pd
|
|
12
|
+
import scipy.cluster.hierarchy as sch
|
|
13
|
+
import seaborn as sns
|
|
11
14
|
|
|
12
|
-
from typing import Optional, Mapping, Sequence, Any, Dict, List, Tuple
|
|
13
|
-
from pathlib import Path
|
|
14
15
|
|
|
15
16
|
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
16
17
|
"""
|
|
@@ -25,6 +26,7 @@ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
|
25
26
|
pos = np.linspace(0, n_positions - 1, n_ticks)
|
|
26
27
|
return np.unique(np.round(pos).astype(int))
|
|
27
28
|
|
|
29
|
+
|
|
28
30
|
def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
|
|
29
31
|
"""
|
|
30
32
|
Select tick labels for the heatmap axis.
|
|
@@ -65,11 +67,21 @@ def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix:
|
|
|
65
67
|
labels = subset.var[colname].astype(str).values
|
|
66
68
|
return labels[sites]
|
|
67
69
|
|
|
70
|
+
|
|
68
71
|
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
|
|
72
|
+
"""Compute normalized column means for a matrix.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
matrix: Input matrix.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
1D array of normalized means.
|
|
79
|
+
"""
|
|
69
80
|
mean = np.nanmean(matrix, axis=0)
|
|
70
81
|
denom = (mean.max() - mean.min()) + 1e-9
|
|
71
82
|
return (mean - mean.min()) / denom
|
|
72
83
|
|
|
84
|
+
|
|
73
85
|
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
74
86
|
"""
|
|
75
87
|
Fraction methylated per column.
|
|
@@ -84,14 +96,20 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
|
84
96
|
valid = valid_mask.sum(axis=0)
|
|
85
97
|
|
|
86
98
|
return np.divide(
|
|
87
|
-
methylated, valid,
|
|
88
|
-
out=np.zeros_like(methylated, dtype=float),
|
|
89
|
-
where=valid != 0
|
|
99
|
+
methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
|
|
90
100
|
)
|
|
91
101
|
|
|
102
|
+
|
|
92
103
|
def clean_barplot(ax, mean_values, title):
|
|
104
|
+
"""Format a barplot with consistent axes and labels.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
ax: Matplotlib axes.
|
|
108
|
+
mean_values: Values to plot.
|
|
109
|
+
title: Plot title.
|
|
110
|
+
"""
|
|
93
111
|
x = np.arange(len(mean_values))
|
|
94
|
-
ax.bar(x, mean_values, color="gray", width=1.0, align=
|
|
112
|
+
ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
|
|
95
113
|
ax.set_xlim(0, len(mean_values))
|
|
96
114
|
ax.set_ylim(0, 1)
|
|
97
115
|
ax.set_yticks([0.0, 0.5, 1.0])
|
|
@@ -100,9 +118,10 @@ def clean_barplot(ax, mean_values, title):
|
|
|
100
118
|
|
|
101
119
|
# Hide all spines except left
|
|
102
120
|
for spine_name, spine in ax.spines.items():
|
|
103
|
-
spine.set_visible(spine_name ==
|
|
121
|
+
spine.set_visible(spine_name == "left")
|
|
122
|
+
|
|
123
|
+
ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
|
|
104
124
|
|
|
105
|
-
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
|
106
125
|
|
|
107
126
|
# def combined_hmm_raw_clustermap(
|
|
108
127
|
# adata,
|
|
@@ -145,7 +164,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
145
164
|
# (adata.obs['read_length'] >= min_length) &
|
|
146
165
|
# (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
|
|
147
166
|
# ]
|
|
148
|
-
|
|
167
|
+
|
|
149
168
|
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
150
169
|
# subset = subset[:, mask]
|
|
151
170
|
|
|
@@ -257,7 +276,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
257
276
|
# clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
|
|
258
277
|
# clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
|
|
259
278
|
# clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
|
|
260
|
-
|
|
279
|
+
|
|
261
280
|
# hmm_labels = subset.var_names.astype(int)
|
|
262
281
|
# hmm_label_spacing = 150
|
|
263
282
|
# sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
@@ -311,7 +330,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
311
330
|
# "bin_boundaries": bin_boundaries,
|
|
312
331
|
# "percentages": percentages
|
|
313
332
|
# })
|
|
314
|
-
|
|
333
|
+
|
|
315
334
|
# #adata.uns['clustermap_results'] = results
|
|
316
335
|
|
|
317
336
|
# except Exception as e:
|
|
@@ -319,45 +338,39 @@ def clean_barplot(ax, mean_values, title):
|
|
|
319
338
|
# traceback.print_exc()
|
|
320
339
|
# continue
|
|
321
340
|
|
|
341
|
+
|
|
322
342
|
def combined_hmm_raw_clustermap(
|
|
323
343
|
adata,
|
|
324
344
|
sample_col: str = "Sample_Names",
|
|
325
345
|
reference_col: str = "Reference_strand",
|
|
326
|
-
|
|
327
346
|
hmm_feature_layer: str = "hmm_combined",
|
|
328
|
-
|
|
329
347
|
layer_gpc: str = "nan0_0minus1",
|
|
330
348
|
layer_cpg: str = "nan0_0minus1",
|
|
331
349
|
layer_c: str = "nan0_0minus1",
|
|
332
350
|
layer_a: str = "nan0_0minus1",
|
|
333
|
-
|
|
334
351
|
cmap_hmm: str = "tab10",
|
|
335
352
|
cmap_gpc: str = "coolwarm",
|
|
336
353
|
cmap_cpg: str = "viridis",
|
|
337
354
|
cmap_c: str = "coolwarm",
|
|
338
355
|
cmap_a: str = "coolwarm",
|
|
339
|
-
|
|
340
356
|
min_quality: int = 20,
|
|
341
357
|
min_length: int = 200,
|
|
342
358
|
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
343
359
|
min_position_valid_fraction: float = 0.5,
|
|
344
|
-
|
|
360
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
361
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
345
362
|
save_path: str | Path | None = None,
|
|
346
363
|
normalize_hmm: bool = False,
|
|
347
|
-
|
|
348
364
|
sort_by: str = "gpc",
|
|
349
365
|
bins: Optional[Dict[str, Any]] = None,
|
|
350
|
-
|
|
351
366
|
deaminase: bool = False,
|
|
352
367
|
min_signal: float = 0.0,
|
|
353
|
-
|
|
354
368
|
# ---- fixed tick label controls (counts, not spacing)
|
|
355
369
|
n_xticks_hmm: int = 10,
|
|
356
370
|
n_xticks_any_c: int = 8,
|
|
357
371
|
n_xticks_gpc: int = 8,
|
|
358
372
|
n_xticks_cpg: int = 8,
|
|
359
373
|
n_xticks_a: int = 8,
|
|
360
|
-
|
|
361
374
|
index_col_suffix: str | None = None,
|
|
362
375
|
):
|
|
363
376
|
"""
|
|
@@ -369,39 +382,92 @@ def combined_hmm_raw_clustermap(
|
|
|
369
382
|
sort_by options:
|
|
370
383
|
'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
|
|
371
384
|
"""
|
|
385
|
+
|
|
372
386
|
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
387
|
+
"""Pick tick indices/labels from an array."""
|
|
373
388
|
if labels.size == 0:
|
|
374
389
|
return [], []
|
|
375
390
|
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
376
391
|
idx = np.unique(idx)
|
|
377
392
|
return idx.tolist(), labels[idx].tolist()
|
|
378
|
-
|
|
393
|
+
|
|
394
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
395
|
+
def _mask_or_true(series_name: str, predicate):
|
|
396
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
397
|
+
if series_name not in adata.obs:
|
|
398
|
+
return pd.Series(True, index=adata.obs.index)
|
|
399
|
+
s = adata.obs[series_name]
|
|
400
|
+
try:
|
|
401
|
+
return predicate(s)
|
|
402
|
+
except Exception:
|
|
403
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
404
|
+
return pd.Series(True, index=adata.obs.index)
|
|
405
|
+
|
|
379
406
|
results = []
|
|
380
407
|
signal_type = "deamination" if deaminase else "methylation"
|
|
381
408
|
|
|
382
409
|
for ref in adata.obs[reference_col].cat.categories:
|
|
383
410
|
for sample in adata.obs[sample_col].cat.categories:
|
|
411
|
+
# Optionally remap sample label for display
|
|
412
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
413
|
+
# Row-level masks (obs)
|
|
414
|
+
qmask = _mask_or_true(
|
|
415
|
+
"read_quality",
|
|
416
|
+
(lambda s: s >= float(min_quality))
|
|
417
|
+
if (min_quality is not None)
|
|
418
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
419
|
+
)
|
|
420
|
+
lm_mask = _mask_or_true(
|
|
421
|
+
"mapped_length",
|
|
422
|
+
(lambda s: s >= float(min_length))
|
|
423
|
+
if (min_length is not None)
|
|
424
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
425
|
+
)
|
|
426
|
+
lrr_mask = _mask_or_true(
|
|
427
|
+
"mapped_length_to_reference_length_ratio",
|
|
428
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
429
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
430
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
431
|
+
)
|
|
432
|
+
|
|
433
|
+
demux_mask = _mask_or_true(
|
|
434
|
+
"demux_type",
|
|
435
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
436
|
+
if (demux_types is not None)
|
|
437
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
441
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
442
|
+
|
|
443
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
444
|
+
|
|
445
|
+
if not bool(row_mask.any()):
|
|
446
|
+
print(
|
|
447
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
448
|
+
)
|
|
449
|
+
continue
|
|
384
450
|
|
|
385
451
|
try:
|
|
386
452
|
# ---- subset reads ----
|
|
387
|
-
subset = adata[
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
>
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
subset = subset[:, mask]
|
|
453
|
+
subset = adata[row_mask, :].copy()
|
|
454
|
+
|
|
455
|
+
# Column-level mask (var)
|
|
456
|
+
if min_position_valid_fraction is not None:
|
|
457
|
+
valid_key = f"{ref}_valid_fraction"
|
|
458
|
+
if valid_key in subset.var:
|
|
459
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
460
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
461
|
+
if col_mask.any():
|
|
462
|
+
subset = subset[:, col_mask].copy()
|
|
463
|
+
else:
|
|
464
|
+
print(
|
|
465
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
466
|
+
)
|
|
467
|
+
continue
|
|
403
468
|
|
|
404
469
|
if subset.shape[0] == 0:
|
|
470
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
405
471
|
continue
|
|
406
472
|
|
|
407
473
|
# ---- bins ----
|
|
@@ -412,22 +478,23 @@ def combined_hmm_raw_clustermap(
|
|
|
412
478
|
|
|
413
479
|
# ---- site masks (robust) ----
|
|
414
480
|
def _sites(*keys):
|
|
481
|
+
"""Return indices for the first matching site key."""
|
|
415
482
|
for k in keys:
|
|
416
483
|
if k in subset.var:
|
|
417
484
|
return np.where(subset.var[k].values)[0]
|
|
418
485
|
return np.array([], dtype=int)
|
|
419
486
|
|
|
420
|
-
gpc_sites
|
|
421
|
-
cpg_sites
|
|
487
|
+
gpc_sites = _sites(f"{ref}_GpC_site")
|
|
488
|
+
cpg_sites = _sites(f"{ref}_CpG_site")
|
|
422
489
|
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
423
490
|
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
424
491
|
|
|
425
492
|
# ---- labels via _select_labels ----
|
|
426
493
|
# HMM uses *all* columns
|
|
427
|
-
hmm_sites
|
|
428
|
-
hmm_labels
|
|
429
|
-
gpc_labels
|
|
430
|
-
cpg_labels
|
|
494
|
+
hmm_sites = np.arange(subset.n_vars, dtype=int)
|
|
495
|
+
hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
|
|
496
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
497
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
431
498
|
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
432
499
|
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
433
500
|
|
|
@@ -477,9 +544,11 @@ def combined_hmm_raw_clustermap(
|
|
|
477
544
|
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
478
545
|
linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
|
|
479
546
|
order = sch.leaves_list(linkage)
|
|
480
|
-
|
|
547
|
+
|
|
481
548
|
elif sort_by == "hmm" and hmm_sites.size:
|
|
482
|
-
linkage = sch.linkage(
|
|
549
|
+
linkage = sch.linkage(
|
|
550
|
+
sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
|
|
551
|
+
)
|
|
483
552
|
order = sch.leaves_list(linkage)
|
|
484
553
|
|
|
485
554
|
else:
|
|
@@ -505,46 +574,62 @@ def combined_hmm_raw_clustermap(
|
|
|
505
574
|
|
|
506
575
|
# ---------------- stack ----------------
|
|
507
576
|
hmm_matrix = np.vstack(stacked_hmm)
|
|
508
|
-
mean_hmm =
|
|
577
|
+
mean_hmm = (
|
|
578
|
+
normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
|
|
579
|
+
)
|
|
509
580
|
|
|
510
581
|
panels = [
|
|
511
|
-
(
|
|
582
|
+
(
|
|
583
|
+
f"HMM - {hmm_feature_layer}",
|
|
584
|
+
hmm_matrix,
|
|
585
|
+
hmm_labels,
|
|
586
|
+
cmap_hmm,
|
|
587
|
+
mean_hmm,
|
|
588
|
+
n_xticks_hmm,
|
|
589
|
+
),
|
|
512
590
|
]
|
|
513
591
|
|
|
514
592
|
if stacked_any_c:
|
|
515
593
|
m = np.vstack(stacked_any_c)
|
|
516
|
-
panels.append(
|
|
594
|
+
panels.append(
|
|
595
|
+
("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
|
|
596
|
+
)
|
|
517
597
|
|
|
518
598
|
if stacked_gpc:
|
|
519
599
|
m = np.vstack(stacked_gpc)
|
|
520
|
-
panels.append(
|
|
600
|
+
panels.append(
|
|
601
|
+
("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
|
|
602
|
+
)
|
|
521
603
|
|
|
522
604
|
if stacked_cpg:
|
|
523
605
|
m = np.vstack(stacked_cpg)
|
|
524
|
-
panels.append(
|
|
606
|
+
panels.append(
|
|
607
|
+
("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
|
|
608
|
+
)
|
|
525
609
|
|
|
526
610
|
if stacked_any_a:
|
|
527
611
|
m = np.vstack(stacked_any_a)
|
|
528
|
-
panels.append(
|
|
612
|
+
panels.append(
|
|
613
|
+
("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
|
|
614
|
+
)
|
|
529
615
|
|
|
530
616
|
# ---------------- plotting ----------------
|
|
531
617
|
n_panels = len(panels)
|
|
532
618
|
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
533
619
|
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
534
|
-
fig.suptitle(
|
|
535
|
-
|
|
620
|
+
fig.suptitle(
|
|
621
|
+
f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
|
|
622
|
+
)
|
|
536
623
|
|
|
537
624
|
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
538
625
|
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
539
626
|
|
|
540
627
|
for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
|
|
541
|
-
|
|
542
628
|
# ---- your clean barplot ----
|
|
543
629
|
clean_barplot(axes_bar[i], mean_vec, name)
|
|
544
630
|
|
|
545
631
|
# ---- heatmap ----
|
|
546
|
-
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
|
|
547
|
-
yticklabels=False, cbar=False)
|
|
632
|
+
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
|
|
548
633
|
|
|
549
634
|
# ---- xticks ----
|
|
550
635
|
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
@@ -568,6 +653,7 @@ def combined_hmm_raw_clustermap(
|
|
|
568
653
|
|
|
569
654
|
except Exception:
|
|
570
655
|
import traceback
|
|
656
|
+
|
|
571
657
|
traceback.print_exc()
|
|
572
658
|
continue
|
|
573
659
|
|
|
@@ -687,7 +773,7 @@ def combined_hmm_raw_clustermap(
|
|
|
687
773
|
# order = np.arange(num_reads)
|
|
688
774
|
# elif sort_by == "any_a":
|
|
689
775
|
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
690
|
-
# order = sch.leaves_list(linkage)
|
|
776
|
+
# order = sch.leaves_list(linkage)
|
|
691
777
|
# else:
|
|
692
778
|
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
693
779
|
|
|
@@ -716,13 +802,13 @@ def combined_hmm_raw_clustermap(
|
|
|
716
802
|
# order = np.arange(num_reads)
|
|
717
803
|
# elif sort_by == "any_a":
|
|
718
804
|
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
719
|
-
# order = sch.leaves_list(linkage)
|
|
805
|
+
# order = sch.leaves_list(linkage)
|
|
720
806
|
# else:
|
|
721
807
|
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
722
|
-
|
|
808
|
+
|
|
723
809
|
# stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
|
|
724
|
-
|
|
725
|
-
|
|
810
|
+
|
|
811
|
+
|
|
726
812
|
# row_labels.extend([bin_label] * num_reads)
|
|
727
813
|
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
728
814
|
# last_idx += num_reads
|
|
@@ -745,7 +831,7 @@ def combined_hmm_raw_clustermap(
|
|
|
745
831
|
# if any_a_matrix.size > 0:
|
|
746
832
|
# mean_any_a = methylation_fraction(any_a_matrix)
|
|
747
833
|
# gs_dim += 1
|
|
748
|
-
|
|
834
|
+
|
|
749
835
|
|
|
750
836
|
# fig = plt.figure(figsize=(18, 12))
|
|
751
837
|
# gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
|
|
@@ -777,8 +863,8 @@ def combined_hmm_raw_clustermap(
|
|
|
777
863
|
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
778
864
|
# axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
779
865
|
# for boundary in bin_boundaries[:-1]:
|
|
780
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
781
|
-
# current_ax +=1
|
|
866
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
867
|
+
# current_ax +=1
|
|
782
868
|
|
|
783
869
|
# results.append({
|
|
784
870
|
# "sample": sample,
|
|
@@ -790,7 +876,7 @@ def combined_hmm_raw_clustermap(
|
|
|
790
876
|
# "bin_labels": bin_labels,
|
|
791
877
|
# "bin_boundaries": bin_boundaries,
|
|
792
878
|
# "percentages": percentages
|
|
793
|
-
# })
|
|
879
|
+
# })
|
|
794
880
|
|
|
795
881
|
# if stacked_any_a:
|
|
796
882
|
# if any_a_matrix.size > 0:
|
|
@@ -810,7 +896,7 @@ def combined_hmm_raw_clustermap(
|
|
|
810
896
|
# "bin_labels": bin_labels,
|
|
811
897
|
# "bin_boundaries": bin_boundaries,
|
|
812
898
|
# "percentages": percentages
|
|
813
|
-
# })
|
|
899
|
+
# })
|
|
814
900
|
|
|
815
901
|
# plt.tight_layout()
|
|
816
902
|
|
|
@@ -828,7 +914,7 @@ def combined_hmm_raw_clustermap(
|
|
|
828
914
|
# print(f"Summary for {sample} - {ref}:")
|
|
829
915
|
# for bin_label, percent in percentages.items():
|
|
830
916
|
# print(f" - {bin_label}: {percent:.1f}%")
|
|
831
|
-
|
|
917
|
+
|
|
832
918
|
# adata.uns['clustermap_results'] = results
|
|
833
919
|
|
|
834
920
|
# except Exception as e:
|
|
@@ -836,6 +922,7 @@ def combined_hmm_raw_clustermap(
|
|
|
836
922
|
# traceback.print_exc()
|
|
837
923
|
# continue
|
|
838
924
|
|
|
925
|
+
|
|
839
926
|
def combined_raw_clustermap(
|
|
840
927
|
adata,
|
|
841
928
|
sample_col: str = "Sample_Names",
|
|
@@ -849,10 +936,11 @@ def combined_raw_clustermap(
|
|
|
849
936
|
cmap_gpc: str = "coolwarm",
|
|
850
937
|
cmap_cpg: str = "viridis",
|
|
851
938
|
cmap_a: str = "coolwarm",
|
|
852
|
-
min_quality: float = 20,
|
|
853
|
-
min_length: int = 200,
|
|
854
|
-
min_mapped_length_to_reference_length_ratio: float = 0
|
|
855
|
-
min_position_valid_fraction: float = 0
|
|
939
|
+
min_quality: float | None = 20,
|
|
940
|
+
min_length: int | None = 200,
|
|
941
|
+
min_mapped_length_to_reference_length_ratio: float | None = 0,
|
|
942
|
+
min_position_valid_fraction: float | None = 0,
|
|
943
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
856
944
|
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
857
945
|
save_path: str | Path | None = None,
|
|
858
946
|
sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
|
|
@@ -884,6 +972,18 @@ def combined_raw_clustermap(
|
|
|
884
972
|
One entry per (sample, ref) plot with matrices + bin metadata.
|
|
885
973
|
"""
|
|
886
974
|
|
|
975
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
976
|
+
def _mask_or_true(series_name: str, predicate):
|
|
977
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
978
|
+
if series_name not in adata.obs:
|
|
979
|
+
return pd.Series(True, index=adata.obs.index)
|
|
980
|
+
s = adata.obs[series_name]
|
|
981
|
+
try:
|
|
982
|
+
return predicate(s)
|
|
983
|
+
except Exception:
|
|
984
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
985
|
+
return pd.Series(True, index=adata.obs.index)
|
|
986
|
+
|
|
887
987
|
results: List[Dict[str, Any]] = []
|
|
888
988
|
save_path = Path(save_path) if save_path is not None else None
|
|
889
989
|
if save_path is not None:
|
|
@@ -902,24 +1002,63 @@ def combined_raw_clustermap(
|
|
|
902
1002
|
|
|
903
1003
|
for ref in adata.obs[reference_col].cat.categories:
|
|
904
1004
|
for sample in adata.obs[sample_col].cat.categories:
|
|
905
|
-
|
|
906
1005
|
# Optionally remap sample label for display
|
|
907
1006
|
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
908
1007
|
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
1008
|
+
# Row-level masks (obs)
|
|
1009
|
+
qmask = _mask_or_true(
|
|
1010
|
+
"read_quality",
|
|
1011
|
+
(lambda s: s >= float(min_quality))
|
|
1012
|
+
if (min_quality is not None)
|
|
1013
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1014
|
+
)
|
|
1015
|
+
lm_mask = _mask_or_true(
|
|
1016
|
+
"mapped_length",
|
|
1017
|
+
(lambda s: s >= float(min_length))
|
|
1018
|
+
if (min_length is not None)
|
|
1019
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1020
|
+
)
|
|
1021
|
+
lrr_mask = _mask_or_true(
|
|
1022
|
+
"mapped_length_to_reference_length_ratio",
|
|
1023
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
1024
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
1025
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
demux_mask = _mask_or_true(
|
|
1029
|
+
"demux_type",
|
|
1030
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
1031
|
+
if (demux_types is not None)
|
|
1032
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1033
|
+
)
|
|
1034
|
+
|
|
1035
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
1036
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
1037
|
+
|
|
1038
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
1039
|
+
|
|
1040
|
+
if not bool(row_mask.any()):
|
|
1041
|
+
print(
|
|
1042
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
1043
|
+
)
|
|
1044
|
+
continue
|
|
917
1045
|
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
1046
|
+
try:
|
|
1047
|
+
subset = adata[row_mask, :].copy()
|
|
1048
|
+
|
|
1049
|
+
# Column-level mask (var)
|
|
1050
|
+
if min_position_valid_fraction is not None:
|
|
1051
|
+
valid_key = f"{ref}_valid_fraction"
|
|
1052
|
+
if valid_key in subset.var:
|
|
1053
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
1054
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
1055
|
+
if col_mask.any():
|
|
1056
|
+
subset = subset[:, col_mask].copy()
|
|
1057
|
+
else:
|
|
1058
|
+
print(
|
|
1059
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
1060
|
+
)
|
|
1061
|
+
continue
|
|
923
1062
|
|
|
924
1063
|
if subset.shape[0] == 0:
|
|
925
1064
|
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
@@ -939,14 +1078,14 @@ def combined_raw_clustermap(
|
|
|
939
1078
|
|
|
940
1079
|
if include_any_c:
|
|
941
1080
|
any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
|
|
942
|
-
gpc_sites
|
|
943
|
-
cpg_sites
|
|
1081
|
+
gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
|
|
1082
|
+
cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
|
|
944
1083
|
|
|
945
1084
|
num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
|
|
946
1085
|
|
|
947
1086
|
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
948
|
-
gpc_labels
|
|
949
|
-
cpg_labels
|
|
1087
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
1088
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
950
1089
|
|
|
951
1090
|
if include_any_a:
|
|
952
1091
|
any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
|
|
@@ -978,15 +1117,21 @@ def combined_raw_clustermap(
|
|
|
978
1117
|
order = np.argsort(subset_bin.obs[colname].values)
|
|
979
1118
|
|
|
980
1119
|
elif sort_by == "gpc" and num_gpc > 0:
|
|
981
|
-
linkage = sch.linkage(
|
|
1120
|
+
linkage = sch.linkage(
|
|
1121
|
+
subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
|
|
1122
|
+
)
|
|
982
1123
|
order = sch.leaves_list(linkage)
|
|
983
1124
|
|
|
984
1125
|
elif sort_by == "cpg" and num_cpg > 0:
|
|
985
|
-
linkage = sch.linkage(
|
|
1126
|
+
linkage = sch.linkage(
|
|
1127
|
+
subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
|
|
1128
|
+
)
|
|
986
1129
|
order = sch.leaves_list(linkage)
|
|
987
1130
|
|
|
988
1131
|
elif sort_by == "c" and num_any_c > 0:
|
|
989
|
-
linkage = sch.linkage(
|
|
1132
|
+
linkage = sch.linkage(
|
|
1133
|
+
subset_bin[:, any_c_sites].layers[layer_c], method="ward"
|
|
1134
|
+
)
|
|
990
1135
|
order = sch.leaves_list(linkage)
|
|
991
1136
|
|
|
992
1137
|
elif sort_by == "gpc_cpg":
|
|
@@ -994,7 +1139,9 @@ def combined_raw_clustermap(
|
|
|
994
1139
|
order = sch.leaves_list(linkage)
|
|
995
1140
|
|
|
996
1141
|
elif sort_by == "a" and num_any_a > 0:
|
|
997
|
-
linkage = sch.linkage(
|
|
1142
|
+
linkage = sch.linkage(
|
|
1143
|
+
subset_bin[:, any_a_sites].layers[layer_a], method="ward"
|
|
1144
|
+
)
|
|
998
1145
|
order = sch.leaves_list(linkage)
|
|
999
1146
|
|
|
1000
1147
|
elif sort_by == "none":
|
|
@@ -1027,57 +1174,65 @@ def combined_raw_clustermap(
|
|
|
1027
1174
|
|
|
1028
1175
|
if include_any_c and stacked_any_c:
|
|
1029
1176
|
any_c_matrix = np.vstack(stacked_any_c)
|
|
1030
|
-
gpc_matrix
|
|
1031
|
-
cpg_matrix
|
|
1177
|
+
gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
|
|
1178
|
+
cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
|
|
1032
1179
|
|
|
1033
1180
|
mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
|
|
1034
|
-
mean_gpc
|
|
1035
|
-
mean_cpg
|
|
1181
|
+
mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
|
|
1182
|
+
mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
|
|
1036
1183
|
|
|
1037
1184
|
if any_c_matrix.size:
|
|
1038
|
-
blocks.append(
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1185
|
+
blocks.append(
|
|
1186
|
+
dict(
|
|
1187
|
+
name="c",
|
|
1188
|
+
matrix=any_c_matrix,
|
|
1189
|
+
mean=mean_any_c,
|
|
1190
|
+
labels=any_c_labels,
|
|
1191
|
+
cmap=cmap_c,
|
|
1192
|
+
n_xticks=n_xticks_any_c,
|
|
1193
|
+
title="any C site Modification Signal",
|
|
1194
|
+
)
|
|
1195
|
+
)
|
|
1047
1196
|
if gpc_matrix.size:
|
|
1048
|
-
blocks.append(
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1197
|
+
blocks.append(
|
|
1198
|
+
dict(
|
|
1199
|
+
name="gpc",
|
|
1200
|
+
matrix=gpc_matrix,
|
|
1201
|
+
mean=mean_gpc,
|
|
1202
|
+
labels=gpc_labels,
|
|
1203
|
+
cmap=cmap_gpc,
|
|
1204
|
+
n_xticks=n_xticks_gpc,
|
|
1205
|
+
title="GpC Modification Signal",
|
|
1206
|
+
)
|
|
1207
|
+
)
|
|
1057
1208
|
if cpg_matrix.size:
|
|
1058
|
-
blocks.append(
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1209
|
+
blocks.append(
|
|
1210
|
+
dict(
|
|
1211
|
+
name="cpg",
|
|
1212
|
+
matrix=cpg_matrix,
|
|
1213
|
+
mean=mean_cpg,
|
|
1214
|
+
labels=cpg_labels,
|
|
1215
|
+
cmap=cmap_cpg,
|
|
1216
|
+
n_xticks=n_xticks_cpg,
|
|
1217
|
+
title="CpG Modification Signal",
|
|
1218
|
+
)
|
|
1219
|
+
)
|
|
1067
1220
|
|
|
1068
1221
|
if include_any_a and stacked_any_a:
|
|
1069
1222
|
any_a_matrix = np.vstack(stacked_any_a)
|
|
1070
1223
|
mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
|
|
1071
1224
|
if any_a_matrix.size:
|
|
1072
|
-
blocks.append(
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1225
|
+
blocks.append(
|
|
1226
|
+
dict(
|
|
1227
|
+
name="a",
|
|
1228
|
+
matrix=any_a_matrix,
|
|
1229
|
+
mean=mean_any_a,
|
|
1230
|
+
labels=any_a_labels,
|
|
1231
|
+
cmap=cmap_a,
|
|
1232
|
+
n_xticks=n_xticks_any_a,
|
|
1233
|
+
title="any A site Modification Signal",
|
|
1234
|
+
)
|
|
1235
|
+
)
|
|
1081
1236
|
|
|
1082
1237
|
if not blocks:
|
|
1083
1238
|
print(f"No matrices to plot for {display_sample} - {ref}")
|
|
@@ -1089,7 +1244,7 @@ def combined_raw_clustermap(
|
|
|
1089
1244
|
fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
|
|
1090
1245
|
|
|
1091
1246
|
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
1092
|
-
axes_bar
|
|
1247
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
1093
1248
|
|
|
1094
1249
|
# ----------------------------
|
|
1095
1250
|
# plot blocks
|
|
@@ -1105,20 +1260,14 @@ def combined_raw_clustermap(
|
|
|
1105
1260
|
|
|
1106
1261
|
# heatmap
|
|
1107
1262
|
sns.heatmap(
|
|
1108
|
-
mat,
|
|
1109
|
-
cmap=blk["cmap"],
|
|
1110
|
-
ax=axes_heat[i],
|
|
1111
|
-
yticklabels=False,
|
|
1112
|
-
cbar=False
|
|
1263
|
+
mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
|
|
1113
1264
|
)
|
|
1114
1265
|
|
|
1115
1266
|
# fixed tick labels
|
|
1116
1267
|
tick_pos = _fixed_tick_positions(len(labels), n_xticks)
|
|
1117
1268
|
axes_heat[i].set_xticks(tick_pos)
|
|
1118
1269
|
axes_heat[i].set_xticklabels(
|
|
1119
|
-
labels[tick_pos],
|
|
1120
|
-
rotation=xtick_rotation,
|
|
1121
|
-
fontsize=xtick_fontsize
|
|
1270
|
+
labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
|
|
1122
1271
|
)
|
|
1123
1272
|
|
|
1124
1273
|
# bin separators
|
|
@@ -1131,7 +1280,12 @@ def combined_raw_clustermap(
|
|
|
1131
1280
|
|
|
1132
1281
|
# save or show
|
|
1133
1282
|
if save_path is not None:
|
|
1134
|
-
safe_name =
|
|
1283
|
+
safe_name = (
|
|
1284
|
+
f"{ref}__{display_sample}".replace("=", "")
|
|
1285
|
+
.replace("__", "_")
|
|
1286
|
+
.replace(",", "_")
|
|
1287
|
+
.replace(" ", "_")
|
|
1288
|
+
)
|
|
1135
1289
|
out_file = save_path / f"{safe_name}.png"
|
|
1136
1290
|
fig.savefig(out_file, dpi=300)
|
|
1137
1291
|
plt.close(fig)
|
|
@@ -1157,20 +1311,15 @@ def combined_raw_clustermap(
|
|
|
1157
1311
|
for bin_label, percent in percentages.items():
|
|
1158
1312
|
print(f" - {bin_label}: {percent:.1f}%")
|
|
1159
1313
|
|
|
1160
|
-
except Exception
|
|
1314
|
+
except Exception:
|
|
1161
1315
|
import traceback
|
|
1316
|
+
|
|
1162
1317
|
traceback.print_exc()
|
|
1163
1318
|
continue
|
|
1164
1319
|
|
|
1165
|
-
# store once at the end (HDF5 safe)
|
|
1166
|
-
# matrices won't be HDF5-safe; store only metadata + maybe hit counts
|
|
1167
|
-
# adata.uns["clustermap_results"] = [
|
|
1168
|
-
# {k: v for k, v in r.items() if not k.endswith("_matrix")}
|
|
1169
|
-
# for r in results
|
|
1170
|
-
# ]
|
|
1171
|
-
|
|
1172
1320
|
return results
|
|
1173
1321
|
|
|
1322
|
+
|
|
1174
1323
|
def plot_hmm_layers_rolling_by_sample_ref(
|
|
1175
1324
|
adata,
|
|
1176
1325
|
layers: Optional[Sequence[str]] = None,
|
|
@@ -1237,7 +1386,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1237
1386
|
|
|
1238
1387
|
# --- basic checks / defaults ---
|
|
1239
1388
|
if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
|
|
1240
|
-
raise ValueError(
|
|
1389
|
+
raise ValueError(
|
|
1390
|
+
f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
|
|
1391
|
+
)
|
|
1241
1392
|
|
|
1242
1393
|
# canonicalize samples / refs
|
|
1243
1394
|
if samples is None:
|
|
@@ -1260,7 +1411,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1260
1411
|
if layers is None:
|
|
1261
1412
|
layers = list(adata.layers.keys())
|
|
1262
1413
|
if len(layers) == 0:
|
|
1263
|
-
raise ValueError(
|
|
1414
|
+
raise ValueError(
|
|
1415
|
+
"No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
|
|
1416
|
+
)
|
|
1264
1417
|
layers = list(layers)
|
|
1265
1418
|
|
|
1266
1419
|
# x coordinates (positions)
|
|
@@ -1299,19 +1452,29 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1299
1452
|
|
|
1300
1453
|
fig_w = figsize_per_cell[0] * ncols
|
|
1301
1454
|
fig_h = figsize_per_cell[1] * nrows
|
|
1302
|
-
fig, axes = plt.subplots(
|
|
1303
|
-
|
|
1304
|
-
|
|
1455
|
+
fig, axes = plt.subplots(
|
|
1456
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
1457
|
+
)
|
|
1305
1458
|
|
|
1306
1459
|
for r_idx, sample_name in enumerate(chunk):
|
|
1307
1460
|
for c_idx, ref_name in enumerate(refs_all):
|
|
1308
1461
|
ax = axes[r_idx][c_idx]
|
|
1309
1462
|
|
|
1310
1463
|
# subset adata
|
|
1311
|
-
mask = (adata.obs[sample_col].values == sample_name) & (
|
|
1464
|
+
mask = (adata.obs[sample_col].values == sample_name) & (
|
|
1465
|
+
adata.obs[ref_col].values == ref_name
|
|
1466
|
+
)
|
|
1312
1467
|
sub = adata[mask]
|
|
1313
1468
|
if sub.n_obs == 0:
|
|
1314
|
-
ax.text(
|
|
1469
|
+
ax.text(
|
|
1470
|
+
0.5,
|
|
1471
|
+
0.5,
|
|
1472
|
+
"No reads",
|
|
1473
|
+
ha="center",
|
|
1474
|
+
va="center",
|
|
1475
|
+
transform=ax.transAxes,
|
|
1476
|
+
color="gray",
|
|
1477
|
+
)
|
|
1315
1478
|
ax.set_xticks([])
|
|
1316
1479
|
ax.set_yticks([])
|
|
1317
1480
|
if r_idx == 0:
|
|
@@ -1361,7 +1524,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1361
1524
|
smoothed = col_mean
|
|
1362
1525
|
else:
|
|
1363
1526
|
ser = pd.Series(col_mean)
|
|
1364
|
-
smoothed =
|
|
1527
|
+
smoothed = (
|
|
1528
|
+
ser.rolling(window=window, min_periods=min_periods, center=center)
|
|
1529
|
+
.mean()
|
|
1530
|
+
.to_numpy()
|
|
1531
|
+
)
|
|
1365
1532
|
|
|
1366
1533
|
# x axis: x_coords (trim/pad to match length)
|
|
1367
1534
|
L = len(col_mean)
|
|
@@ -1371,7 +1538,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1371
1538
|
if show_raw:
|
|
1372
1539
|
ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
|
|
1373
1540
|
|
|
1374
|
-
ax.plot(
|
|
1541
|
+
ax.plot(
|
|
1542
|
+
x,
|
|
1543
|
+
smoothed[:L],
|
|
1544
|
+
label=layer,
|
|
1545
|
+
color=colors[li],
|
|
1546
|
+
linewidth=1.2,
|
|
1547
|
+
alpha=0.95,
|
|
1548
|
+
zorder=2,
|
|
1549
|
+
)
|
|
1375
1550
|
plotted_any = True
|
|
1376
1551
|
|
|
1377
1552
|
# labels / titles
|
|
@@ -1389,11 +1564,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1389
1564
|
|
|
1390
1565
|
ax.grid(True, alpha=0.2)
|
|
1391
1566
|
|
|
1392
|
-
fig.suptitle(
|
|
1567
|
+
fig.suptitle(
|
|
1568
|
+
f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
|
|
1569
|
+
fontsize=11,
|
|
1570
|
+
y=0.995,
|
|
1571
|
+
)
|
|
1393
1572
|
fig.tight_layout(rect=[0, 0, 1, 0.97])
|
|
1394
1573
|
|
|
1395
1574
|
if save:
|
|
1396
|
-
fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
|
|
1575
|
+
fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
|
|
1397
1576
|
plt.savefig(fname, bbox_inches="tight", dpi=dpi)
|
|
1398
1577
|
saved_files.append(fname)
|
|
1399
1578
|
else:
|