smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,16 +1,20 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import numpy as np
|
|
4
|
-
import seaborn as sns
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
import scipy.cluster.hierarchy as sch
|
|
7
|
-
import matplotlib.gridspec as gridspec
|
|
8
|
-
import os
|
|
9
3
|
import math
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
10
9
|
import pandas as pd
|
|
10
|
+
import scipy.cluster.hierarchy as sch
|
|
11
|
+
|
|
12
|
+
from smftools.optional_imports import require
|
|
13
|
+
|
|
14
|
+
gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
|
|
15
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
|
|
16
|
+
sns = require("seaborn", extra="plotting", purpose="plot styling")
|
|
11
17
|
|
|
12
|
-
from typing import Optional, Mapping, Sequence, Any, Dict, List, Tuple
|
|
13
|
-
from pathlib import Path
|
|
14
18
|
|
|
15
19
|
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
16
20
|
"""
|
|
@@ -25,6 +29,7 @@ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
|
25
29
|
pos = np.linspace(0, n_positions - 1, n_ticks)
|
|
26
30
|
return np.unique(np.round(pos).astype(int))
|
|
27
31
|
|
|
32
|
+
|
|
28
33
|
def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
|
|
29
34
|
"""
|
|
30
35
|
Select tick labels for the heatmap axis.
|
|
@@ -65,11 +70,21 @@ def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix:
|
|
|
65
70
|
labels = subset.var[colname].astype(str).values
|
|
66
71
|
return labels[sites]
|
|
67
72
|
|
|
73
|
+
|
|
68
74
|
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
|
|
75
|
+
"""Compute normalized column means for a matrix.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
matrix: Input matrix.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
1D array of normalized means.
|
|
82
|
+
"""
|
|
69
83
|
mean = np.nanmean(matrix, axis=0)
|
|
70
84
|
denom = (mean.max() - mean.min()) + 1e-9
|
|
71
85
|
return (mean - mean.min()) / denom
|
|
72
86
|
|
|
87
|
+
|
|
73
88
|
def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
74
89
|
"""
|
|
75
90
|
Fraction methylated per column.
|
|
@@ -84,14 +99,20 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
|
84
99
|
valid = valid_mask.sum(axis=0)
|
|
85
100
|
|
|
86
101
|
return np.divide(
|
|
87
|
-
methylated, valid,
|
|
88
|
-
out=np.zeros_like(methylated, dtype=float),
|
|
89
|
-
where=valid != 0
|
|
102
|
+
methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
|
|
90
103
|
)
|
|
91
104
|
|
|
105
|
+
|
|
92
106
|
def clean_barplot(ax, mean_values, title):
|
|
107
|
+
"""Format a barplot with consistent axes and labels.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
ax: Matplotlib axes.
|
|
111
|
+
mean_values: Values to plot.
|
|
112
|
+
title: Plot title.
|
|
113
|
+
"""
|
|
93
114
|
x = np.arange(len(mean_values))
|
|
94
|
-
ax.bar(x, mean_values, color="gray", width=1.0, align=
|
|
115
|
+
ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
|
|
95
116
|
ax.set_xlim(0, len(mean_values))
|
|
96
117
|
ax.set_ylim(0, 1)
|
|
97
118
|
ax.set_yticks([0.0, 0.5, 1.0])
|
|
@@ -100,9 +121,10 @@ def clean_barplot(ax, mean_values, title):
|
|
|
100
121
|
|
|
101
122
|
# Hide all spines except left
|
|
102
123
|
for spine_name, spine in ax.spines.items():
|
|
103
|
-
spine.set_visible(spine_name ==
|
|
124
|
+
spine.set_visible(spine_name == "left")
|
|
125
|
+
|
|
126
|
+
ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
|
|
104
127
|
|
|
105
|
-
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
|
106
128
|
|
|
107
129
|
# def combined_hmm_raw_clustermap(
|
|
108
130
|
# adata,
|
|
@@ -145,7 +167,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
145
167
|
# (adata.obs['read_length'] >= min_length) &
|
|
146
168
|
# (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
|
|
147
169
|
# ]
|
|
148
|
-
|
|
170
|
+
|
|
149
171
|
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
150
172
|
# subset = subset[:, mask]
|
|
151
173
|
|
|
@@ -257,7 +279,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
257
279
|
# clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
|
|
258
280
|
# clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
|
|
259
281
|
# clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
|
|
260
|
-
|
|
282
|
+
|
|
261
283
|
# hmm_labels = subset.var_names.astype(int)
|
|
262
284
|
# hmm_label_spacing = 150
|
|
263
285
|
# sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
@@ -311,7 +333,7 @@ def clean_barplot(ax, mean_values, title):
|
|
|
311
333
|
# "bin_boundaries": bin_boundaries,
|
|
312
334
|
# "percentages": percentages
|
|
313
335
|
# })
|
|
314
|
-
|
|
336
|
+
|
|
315
337
|
# #adata.uns['clustermap_results'] = results
|
|
316
338
|
|
|
317
339
|
# except Exception as e:
|
|
@@ -319,45 +341,39 @@ def clean_barplot(ax, mean_values, title):
|
|
|
319
341
|
# traceback.print_exc()
|
|
320
342
|
# continue
|
|
321
343
|
|
|
344
|
+
|
|
322
345
|
def combined_hmm_raw_clustermap(
|
|
323
346
|
adata,
|
|
324
347
|
sample_col: str = "Sample_Names",
|
|
325
348
|
reference_col: str = "Reference_strand",
|
|
326
|
-
|
|
327
349
|
hmm_feature_layer: str = "hmm_combined",
|
|
328
|
-
|
|
329
350
|
layer_gpc: str = "nan0_0minus1",
|
|
330
351
|
layer_cpg: str = "nan0_0minus1",
|
|
331
352
|
layer_c: str = "nan0_0minus1",
|
|
332
353
|
layer_a: str = "nan0_0minus1",
|
|
333
|
-
|
|
334
354
|
cmap_hmm: str = "tab10",
|
|
335
355
|
cmap_gpc: str = "coolwarm",
|
|
336
356
|
cmap_cpg: str = "viridis",
|
|
337
357
|
cmap_c: str = "coolwarm",
|
|
338
358
|
cmap_a: str = "coolwarm",
|
|
339
|
-
|
|
340
359
|
min_quality: int = 20,
|
|
341
360
|
min_length: int = 200,
|
|
342
361
|
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
343
362
|
min_position_valid_fraction: float = 0.5,
|
|
344
|
-
|
|
363
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
364
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
345
365
|
save_path: str | Path | None = None,
|
|
346
366
|
normalize_hmm: bool = False,
|
|
347
|
-
|
|
348
367
|
sort_by: str = "gpc",
|
|
349
368
|
bins: Optional[Dict[str, Any]] = None,
|
|
350
|
-
|
|
351
369
|
deaminase: bool = False,
|
|
352
370
|
min_signal: float = 0.0,
|
|
353
|
-
|
|
354
371
|
# ---- fixed tick label controls (counts, not spacing)
|
|
355
372
|
n_xticks_hmm: int = 10,
|
|
356
373
|
n_xticks_any_c: int = 8,
|
|
357
374
|
n_xticks_gpc: int = 8,
|
|
358
375
|
n_xticks_cpg: int = 8,
|
|
359
376
|
n_xticks_a: int = 8,
|
|
360
|
-
|
|
361
377
|
index_col_suffix: str | None = None,
|
|
362
378
|
):
|
|
363
379
|
"""
|
|
@@ -369,39 +385,92 @@ def combined_hmm_raw_clustermap(
|
|
|
369
385
|
sort_by options:
|
|
370
386
|
'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
|
|
371
387
|
"""
|
|
388
|
+
|
|
372
389
|
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
390
|
+
"""Pick tick indices/labels from an array."""
|
|
373
391
|
if labels.size == 0:
|
|
374
392
|
return [], []
|
|
375
393
|
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
376
394
|
idx = np.unique(idx)
|
|
377
395
|
return idx.tolist(), labels[idx].tolist()
|
|
378
|
-
|
|
396
|
+
|
|
397
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
398
|
+
def _mask_or_true(series_name: str, predicate):
|
|
399
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
400
|
+
if series_name not in adata.obs:
|
|
401
|
+
return pd.Series(True, index=adata.obs.index)
|
|
402
|
+
s = adata.obs[series_name]
|
|
403
|
+
try:
|
|
404
|
+
return predicate(s)
|
|
405
|
+
except Exception:
|
|
406
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
407
|
+
return pd.Series(True, index=adata.obs.index)
|
|
408
|
+
|
|
379
409
|
results = []
|
|
380
410
|
signal_type = "deamination" if deaminase else "methylation"
|
|
381
411
|
|
|
382
412
|
for ref in adata.obs[reference_col].cat.categories:
|
|
383
413
|
for sample in adata.obs[sample_col].cat.categories:
|
|
414
|
+
# Optionally remap sample label for display
|
|
415
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
416
|
+
# Row-level masks (obs)
|
|
417
|
+
qmask = _mask_or_true(
|
|
418
|
+
"read_quality",
|
|
419
|
+
(lambda s: s >= float(min_quality))
|
|
420
|
+
if (min_quality is not None)
|
|
421
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
422
|
+
)
|
|
423
|
+
lm_mask = _mask_or_true(
|
|
424
|
+
"mapped_length",
|
|
425
|
+
(lambda s: s >= float(min_length))
|
|
426
|
+
if (min_length is not None)
|
|
427
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
428
|
+
)
|
|
429
|
+
lrr_mask = _mask_or_true(
|
|
430
|
+
"mapped_length_to_reference_length_ratio",
|
|
431
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
432
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
433
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
demux_mask = _mask_or_true(
|
|
437
|
+
"demux_type",
|
|
438
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
439
|
+
if (demux_types is not None)
|
|
440
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
441
|
+
)
|
|
442
|
+
|
|
443
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
444
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
445
|
+
|
|
446
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
447
|
+
|
|
448
|
+
if not bool(row_mask.any()):
|
|
449
|
+
print(
|
|
450
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
451
|
+
)
|
|
452
|
+
continue
|
|
384
453
|
|
|
385
454
|
try:
|
|
386
455
|
# ---- subset reads ----
|
|
387
|
-
subset = adata[
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
>
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
subset = subset[:, mask]
|
|
456
|
+
subset = adata[row_mask, :].copy()
|
|
457
|
+
|
|
458
|
+
# Column-level mask (var)
|
|
459
|
+
if min_position_valid_fraction is not None:
|
|
460
|
+
valid_key = f"{ref}_valid_fraction"
|
|
461
|
+
if valid_key in subset.var:
|
|
462
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
463
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
464
|
+
if col_mask.any():
|
|
465
|
+
subset = subset[:, col_mask].copy()
|
|
466
|
+
else:
|
|
467
|
+
print(
|
|
468
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
469
|
+
)
|
|
470
|
+
continue
|
|
403
471
|
|
|
404
472
|
if subset.shape[0] == 0:
|
|
473
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
405
474
|
continue
|
|
406
475
|
|
|
407
476
|
# ---- bins ----
|
|
@@ -412,22 +481,23 @@ def combined_hmm_raw_clustermap(
|
|
|
412
481
|
|
|
413
482
|
# ---- site masks (robust) ----
|
|
414
483
|
def _sites(*keys):
|
|
484
|
+
"""Return indices for the first matching site key."""
|
|
415
485
|
for k in keys:
|
|
416
486
|
if k in subset.var:
|
|
417
487
|
return np.where(subset.var[k].values)[0]
|
|
418
488
|
return np.array([], dtype=int)
|
|
419
489
|
|
|
420
|
-
gpc_sites
|
|
421
|
-
cpg_sites
|
|
490
|
+
gpc_sites = _sites(f"{ref}_GpC_site")
|
|
491
|
+
cpg_sites = _sites(f"{ref}_CpG_site")
|
|
422
492
|
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
423
493
|
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
424
494
|
|
|
425
495
|
# ---- labels via _select_labels ----
|
|
426
496
|
# HMM uses *all* columns
|
|
427
|
-
hmm_sites
|
|
428
|
-
hmm_labels
|
|
429
|
-
gpc_labels
|
|
430
|
-
cpg_labels
|
|
497
|
+
hmm_sites = np.arange(subset.n_vars, dtype=int)
|
|
498
|
+
hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
|
|
499
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
500
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
431
501
|
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
432
502
|
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
433
503
|
|
|
@@ -477,9 +547,11 @@ def combined_hmm_raw_clustermap(
|
|
|
477
547
|
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
478
548
|
linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
|
|
479
549
|
order = sch.leaves_list(linkage)
|
|
480
|
-
|
|
550
|
+
|
|
481
551
|
elif sort_by == "hmm" and hmm_sites.size:
|
|
482
|
-
linkage = sch.linkage(
|
|
552
|
+
linkage = sch.linkage(
|
|
553
|
+
sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
|
|
554
|
+
)
|
|
483
555
|
order = sch.leaves_list(linkage)
|
|
484
556
|
|
|
485
557
|
else:
|
|
@@ -505,46 +577,62 @@ def combined_hmm_raw_clustermap(
|
|
|
505
577
|
|
|
506
578
|
# ---------------- stack ----------------
|
|
507
579
|
hmm_matrix = np.vstack(stacked_hmm)
|
|
508
|
-
mean_hmm =
|
|
580
|
+
mean_hmm = (
|
|
581
|
+
normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
|
|
582
|
+
)
|
|
509
583
|
|
|
510
584
|
panels = [
|
|
511
|
-
(
|
|
585
|
+
(
|
|
586
|
+
f"HMM - {hmm_feature_layer}",
|
|
587
|
+
hmm_matrix,
|
|
588
|
+
hmm_labels,
|
|
589
|
+
cmap_hmm,
|
|
590
|
+
mean_hmm,
|
|
591
|
+
n_xticks_hmm,
|
|
592
|
+
),
|
|
512
593
|
]
|
|
513
594
|
|
|
514
595
|
if stacked_any_c:
|
|
515
596
|
m = np.vstack(stacked_any_c)
|
|
516
|
-
panels.append(
|
|
597
|
+
panels.append(
|
|
598
|
+
("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
|
|
599
|
+
)
|
|
517
600
|
|
|
518
601
|
if stacked_gpc:
|
|
519
602
|
m = np.vstack(stacked_gpc)
|
|
520
|
-
panels.append(
|
|
603
|
+
panels.append(
|
|
604
|
+
("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
|
|
605
|
+
)
|
|
521
606
|
|
|
522
607
|
if stacked_cpg:
|
|
523
608
|
m = np.vstack(stacked_cpg)
|
|
524
|
-
panels.append(
|
|
609
|
+
panels.append(
|
|
610
|
+
("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
|
|
611
|
+
)
|
|
525
612
|
|
|
526
613
|
if stacked_any_a:
|
|
527
614
|
m = np.vstack(stacked_any_a)
|
|
528
|
-
panels.append(
|
|
615
|
+
panels.append(
|
|
616
|
+
("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
|
|
617
|
+
)
|
|
529
618
|
|
|
530
619
|
# ---------------- plotting ----------------
|
|
531
620
|
n_panels = len(panels)
|
|
532
621
|
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
533
622
|
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
534
|
-
fig.suptitle(
|
|
535
|
-
|
|
623
|
+
fig.suptitle(
|
|
624
|
+
f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
|
|
625
|
+
)
|
|
536
626
|
|
|
537
627
|
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
538
628
|
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
539
629
|
|
|
540
630
|
for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
|
|
541
|
-
|
|
542
631
|
# ---- your clean barplot ----
|
|
543
632
|
clean_barplot(axes_bar[i], mean_vec, name)
|
|
544
633
|
|
|
545
634
|
# ---- heatmap ----
|
|
546
|
-
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
|
|
547
|
-
yticklabels=False, cbar=False)
|
|
635
|
+
sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
|
|
548
636
|
|
|
549
637
|
# ---- xticks ----
|
|
550
638
|
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
@@ -568,6 +656,7 @@ def combined_hmm_raw_clustermap(
|
|
|
568
656
|
|
|
569
657
|
except Exception:
|
|
570
658
|
import traceback
|
|
659
|
+
|
|
571
660
|
traceback.print_exc()
|
|
572
661
|
continue
|
|
573
662
|
|
|
@@ -687,7 +776,7 @@ def combined_hmm_raw_clustermap(
|
|
|
687
776
|
# order = np.arange(num_reads)
|
|
688
777
|
# elif sort_by == "any_a":
|
|
689
778
|
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
690
|
-
# order = sch.leaves_list(linkage)
|
|
779
|
+
# order = sch.leaves_list(linkage)
|
|
691
780
|
# else:
|
|
692
781
|
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
693
782
|
|
|
@@ -716,13 +805,13 @@ def combined_hmm_raw_clustermap(
|
|
|
716
805
|
# order = np.arange(num_reads)
|
|
717
806
|
# elif sort_by == "any_a":
|
|
718
807
|
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
719
|
-
# order = sch.leaves_list(linkage)
|
|
808
|
+
# order = sch.leaves_list(linkage)
|
|
720
809
|
# else:
|
|
721
810
|
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
722
|
-
|
|
811
|
+
|
|
723
812
|
# stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
|
|
724
|
-
|
|
725
|
-
|
|
813
|
+
|
|
814
|
+
|
|
726
815
|
# row_labels.extend([bin_label] * num_reads)
|
|
727
816
|
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
728
817
|
# last_idx += num_reads
|
|
@@ -745,7 +834,7 @@ def combined_hmm_raw_clustermap(
|
|
|
745
834
|
# if any_a_matrix.size > 0:
|
|
746
835
|
# mean_any_a = methylation_fraction(any_a_matrix)
|
|
747
836
|
# gs_dim += 1
|
|
748
|
-
|
|
837
|
+
|
|
749
838
|
|
|
750
839
|
# fig = plt.figure(figsize=(18, 12))
|
|
751
840
|
# gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
|
|
@@ -777,8 +866,8 @@ def combined_hmm_raw_clustermap(
|
|
|
777
866
|
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
778
867
|
# axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
779
868
|
# for boundary in bin_boundaries[:-1]:
|
|
780
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
781
|
-
# current_ax +=1
|
|
869
|
+
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
870
|
+
# current_ax +=1
|
|
782
871
|
|
|
783
872
|
# results.append({
|
|
784
873
|
# "sample": sample,
|
|
@@ -790,7 +879,7 @@ def combined_hmm_raw_clustermap(
|
|
|
790
879
|
# "bin_labels": bin_labels,
|
|
791
880
|
# "bin_boundaries": bin_boundaries,
|
|
792
881
|
# "percentages": percentages
|
|
793
|
-
# })
|
|
882
|
+
# })
|
|
794
883
|
|
|
795
884
|
# if stacked_any_a:
|
|
796
885
|
# if any_a_matrix.size > 0:
|
|
@@ -810,7 +899,7 @@ def combined_hmm_raw_clustermap(
|
|
|
810
899
|
# "bin_labels": bin_labels,
|
|
811
900
|
# "bin_boundaries": bin_boundaries,
|
|
812
901
|
# "percentages": percentages
|
|
813
|
-
# })
|
|
902
|
+
# })
|
|
814
903
|
|
|
815
904
|
# plt.tight_layout()
|
|
816
905
|
|
|
@@ -828,7 +917,7 @@ def combined_hmm_raw_clustermap(
|
|
|
828
917
|
# print(f"Summary for {sample} - {ref}:")
|
|
829
918
|
# for bin_label, percent in percentages.items():
|
|
830
919
|
# print(f" - {bin_label}: {percent:.1f}%")
|
|
831
|
-
|
|
920
|
+
|
|
832
921
|
# adata.uns['clustermap_results'] = results
|
|
833
922
|
|
|
834
923
|
# except Exception as e:
|
|
@@ -836,6 +925,7 @@ def combined_hmm_raw_clustermap(
|
|
|
836
925
|
# traceback.print_exc()
|
|
837
926
|
# continue
|
|
838
927
|
|
|
928
|
+
|
|
839
929
|
def combined_raw_clustermap(
|
|
840
930
|
adata,
|
|
841
931
|
sample_col: str = "Sample_Names",
|
|
@@ -849,10 +939,11 @@ def combined_raw_clustermap(
|
|
|
849
939
|
cmap_gpc: str = "coolwarm",
|
|
850
940
|
cmap_cpg: str = "viridis",
|
|
851
941
|
cmap_a: str = "coolwarm",
|
|
852
|
-
min_quality: float = 20,
|
|
853
|
-
min_length: int = 200,
|
|
854
|
-
min_mapped_length_to_reference_length_ratio: float = 0
|
|
855
|
-
min_position_valid_fraction: float = 0
|
|
942
|
+
min_quality: float | None = 20,
|
|
943
|
+
min_length: int | None = 200,
|
|
944
|
+
min_mapped_length_to_reference_length_ratio: float | None = 0,
|
|
945
|
+
min_position_valid_fraction: float | None = 0,
|
|
946
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
856
947
|
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
857
948
|
save_path: str | Path | None = None,
|
|
858
949
|
sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
|
|
@@ -884,6 +975,18 @@ def combined_raw_clustermap(
|
|
|
884
975
|
One entry per (sample, ref) plot with matrices + bin metadata.
|
|
885
976
|
"""
|
|
886
977
|
|
|
978
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
979
|
+
def _mask_or_true(series_name: str, predicate):
|
|
980
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
981
|
+
if series_name not in adata.obs:
|
|
982
|
+
return pd.Series(True, index=adata.obs.index)
|
|
983
|
+
s = adata.obs[series_name]
|
|
984
|
+
try:
|
|
985
|
+
return predicate(s)
|
|
986
|
+
except Exception:
|
|
987
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
988
|
+
return pd.Series(True, index=adata.obs.index)
|
|
989
|
+
|
|
887
990
|
results: List[Dict[str, Any]] = []
|
|
888
991
|
save_path = Path(save_path) if save_path is not None else None
|
|
889
992
|
if save_path is not None:
|
|
@@ -902,24 +1005,63 @@ def combined_raw_clustermap(
|
|
|
902
1005
|
|
|
903
1006
|
for ref in adata.obs[reference_col].cat.categories:
|
|
904
1007
|
for sample in adata.obs[sample_col].cat.categories:
|
|
905
|
-
|
|
906
1008
|
# Optionally remap sample label for display
|
|
907
1009
|
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
908
1010
|
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
1011
|
+
# Row-level masks (obs)
|
|
1012
|
+
qmask = _mask_or_true(
|
|
1013
|
+
"read_quality",
|
|
1014
|
+
(lambda s: s >= float(min_quality))
|
|
1015
|
+
if (min_quality is not None)
|
|
1016
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1017
|
+
)
|
|
1018
|
+
lm_mask = _mask_or_true(
|
|
1019
|
+
"mapped_length",
|
|
1020
|
+
(lambda s: s >= float(min_length))
|
|
1021
|
+
if (min_length is not None)
|
|
1022
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1023
|
+
)
|
|
1024
|
+
lrr_mask = _mask_or_true(
|
|
1025
|
+
"mapped_length_to_reference_length_ratio",
|
|
1026
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
1027
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
1028
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1029
|
+
)
|
|
1030
|
+
|
|
1031
|
+
demux_mask = _mask_or_true(
|
|
1032
|
+
"demux_type",
|
|
1033
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
1034
|
+
if (demux_types is not None)
|
|
1035
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1036
|
+
)
|
|
1037
|
+
|
|
1038
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
1039
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
1040
|
+
|
|
1041
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
1042
|
+
|
|
1043
|
+
if not bool(row_mask.any()):
|
|
1044
|
+
print(
|
|
1045
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
1046
|
+
)
|
|
1047
|
+
continue
|
|
917
1048
|
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
1049
|
+
try:
|
|
1050
|
+
subset = adata[row_mask, :].copy()
|
|
1051
|
+
|
|
1052
|
+
# Column-level mask (var)
|
|
1053
|
+
if min_position_valid_fraction is not None:
|
|
1054
|
+
valid_key = f"{ref}_valid_fraction"
|
|
1055
|
+
if valid_key in subset.var:
|
|
1056
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
1057
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
1058
|
+
if col_mask.any():
|
|
1059
|
+
subset = subset[:, col_mask].copy()
|
|
1060
|
+
else:
|
|
1061
|
+
print(
|
|
1062
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
1063
|
+
)
|
|
1064
|
+
continue
|
|
923
1065
|
|
|
924
1066
|
if subset.shape[0] == 0:
|
|
925
1067
|
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
@@ -939,14 +1081,14 @@ def combined_raw_clustermap(
|
|
|
939
1081
|
|
|
940
1082
|
if include_any_c:
|
|
941
1083
|
any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
|
|
942
|
-
gpc_sites
|
|
943
|
-
cpg_sites
|
|
1084
|
+
gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
|
|
1085
|
+
cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
|
|
944
1086
|
|
|
945
1087
|
num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
|
|
946
1088
|
|
|
947
1089
|
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
948
|
-
gpc_labels
|
|
949
|
-
cpg_labels
|
|
1090
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
1091
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
950
1092
|
|
|
951
1093
|
if include_any_a:
|
|
952
1094
|
any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
|
|
@@ -978,15 +1120,21 @@ def combined_raw_clustermap(
|
|
|
978
1120
|
order = np.argsort(subset_bin.obs[colname].values)
|
|
979
1121
|
|
|
980
1122
|
elif sort_by == "gpc" and num_gpc > 0:
|
|
981
|
-
linkage = sch.linkage(
|
|
1123
|
+
linkage = sch.linkage(
|
|
1124
|
+
subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
|
|
1125
|
+
)
|
|
982
1126
|
order = sch.leaves_list(linkage)
|
|
983
1127
|
|
|
984
1128
|
elif sort_by == "cpg" and num_cpg > 0:
|
|
985
|
-
linkage = sch.linkage(
|
|
1129
|
+
linkage = sch.linkage(
|
|
1130
|
+
subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
|
|
1131
|
+
)
|
|
986
1132
|
order = sch.leaves_list(linkage)
|
|
987
1133
|
|
|
988
1134
|
elif sort_by == "c" and num_any_c > 0:
|
|
989
|
-
linkage = sch.linkage(
|
|
1135
|
+
linkage = sch.linkage(
|
|
1136
|
+
subset_bin[:, any_c_sites].layers[layer_c], method="ward"
|
|
1137
|
+
)
|
|
990
1138
|
order = sch.leaves_list(linkage)
|
|
991
1139
|
|
|
992
1140
|
elif sort_by == "gpc_cpg":
|
|
@@ -994,7 +1142,9 @@ def combined_raw_clustermap(
|
|
|
994
1142
|
order = sch.leaves_list(linkage)
|
|
995
1143
|
|
|
996
1144
|
elif sort_by == "a" and num_any_a > 0:
|
|
997
|
-
linkage = sch.linkage(
|
|
1145
|
+
linkage = sch.linkage(
|
|
1146
|
+
subset_bin[:, any_a_sites].layers[layer_a], method="ward"
|
|
1147
|
+
)
|
|
998
1148
|
order = sch.leaves_list(linkage)
|
|
999
1149
|
|
|
1000
1150
|
elif sort_by == "none":
|
|
@@ -1027,57 +1177,65 @@ def combined_raw_clustermap(
|
|
|
1027
1177
|
|
|
1028
1178
|
if include_any_c and stacked_any_c:
|
|
1029
1179
|
any_c_matrix = np.vstack(stacked_any_c)
|
|
1030
|
-
gpc_matrix
|
|
1031
|
-
cpg_matrix
|
|
1180
|
+
gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
|
|
1181
|
+
cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
|
|
1032
1182
|
|
|
1033
1183
|
mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
|
|
1034
|
-
mean_gpc
|
|
1035
|
-
mean_cpg
|
|
1184
|
+
mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
|
|
1185
|
+
mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
|
|
1036
1186
|
|
|
1037
1187
|
if any_c_matrix.size:
|
|
1038
|
-
blocks.append(
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1188
|
+
blocks.append(
|
|
1189
|
+
dict(
|
|
1190
|
+
name="c",
|
|
1191
|
+
matrix=any_c_matrix,
|
|
1192
|
+
mean=mean_any_c,
|
|
1193
|
+
labels=any_c_labels,
|
|
1194
|
+
cmap=cmap_c,
|
|
1195
|
+
n_xticks=n_xticks_any_c,
|
|
1196
|
+
title="any C site Modification Signal",
|
|
1197
|
+
)
|
|
1198
|
+
)
|
|
1047
1199
|
if gpc_matrix.size:
|
|
1048
|
-
blocks.append(
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1200
|
+
blocks.append(
|
|
1201
|
+
dict(
|
|
1202
|
+
name="gpc",
|
|
1203
|
+
matrix=gpc_matrix,
|
|
1204
|
+
mean=mean_gpc,
|
|
1205
|
+
labels=gpc_labels,
|
|
1206
|
+
cmap=cmap_gpc,
|
|
1207
|
+
n_xticks=n_xticks_gpc,
|
|
1208
|
+
title="GpC Modification Signal",
|
|
1209
|
+
)
|
|
1210
|
+
)
|
|
1057
1211
|
if cpg_matrix.size:
|
|
1058
|
-
blocks.append(
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1212
|
+
blocks.append(
|
|
1213
|
+
dict(
|
|
1214
|
+
name="cpg",
|
|
1215
|
+
matrix=cpg_matrix,
|
|
1216
|
+
mean=mean_cpg,
|
|
1217
|
+
labels=cpg_labels,
|
|
1218
|
+
cmap=cmap_cpg,
|
|
1219
|
+
n_xticks=n_xticks_cpg,
|
|
1220
|
+
title="CpG Modification Signal",
|
|
1221
|
+
)
|
|
1222
|
+
)
|
|
1067
1223
|
|
|
1068
1224
|
if include_any_a and stacked_any_a:
|
|
1069
1225
|
any_a_matrix = np.vstack(stacked_any_a)
|
|
1070
1226
|
mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
|
|
1071
1227
|
if any_a_matrix.size:
|
|
1072
|
-
blocks.append(
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1228
|
+
blocks.append(
|
|
1229
|
+
dict(
|
|
1230
|
+
name="a",
|
|
1231
|
+
matrix=any_a_matrix,
|
|
1232
|
+
mean=mean_any_a,
|
|
1233
|
+
labels=any_a_labels,
|
|
1234
|
+
cmap=cmap_a,
|
|
1235
|
+
n_xticks=n_xticks_any_a,
|
|
1236
|
+
title="any A site Modification Signal",
|
|
1237
|
+
)
|
|
1238
|
+
)
|
|
1081
1239
|
|
|
1082
1240
|
if not blocks:
|
|
1083
1241
|
print(f"No matrices to plot for {display_sample} - {ref}")
|
|
@@ -1089,7 +1247,7 @@ def combined_raw_clustermap(
|
|
|
1089
1247
|
fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
|
|
1090
1248
|
|
|
1091
1249
|
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
1092
|
-
axes_bar
|
|
1250
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
1093
1251
|
|
|
1094
1252
|
# ----------------------------
|
|
1095
1253
|
# plot blocks
|
|
@@ -1105,20 +1263,14 @@ def combined_raw_clustermap(
|
|
|
1105
1263
|
|
|
1106
1264
|
# heatmap
|
|
1107
1265
|
sns.heatmap(
|
|
1108
|
-
mat,
|
|
1109
|
-
cmap=blk["cmap"],
|
|
1110
|
-
ax=axes_heat[i],
|
|
1111
|
-
yticklabels=False,
|
|
1112
|
-
cbar=False
|
|
1266
|
+
mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
|
|
1113
1267
|
)
|
|
1114
1268
|
|
|
1115
1269
|
# fixed tick labels
|
|
1116
1270
|
tick_pos = _fixed_tick_positions(len(labels), n_xticks)
|
|
1117
1271
|
axes_heat[i].set_xticks(tick_pos)
|
|
1118
1272
|
axes_heat[i].set_xticklabels(
|
|
1119
|
-
labels[tick_pos],
|
|
1120
|
-
rotation=xtick_rotation,
|
|
1121
|
-
fontsize=xtick_fontsize
|
|
1273
|
+
labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
|
|
1122
1274
|
)
|
|
1123
1275
|
|
|
1124
1276
|
# bin separators
|
|
@@ -1131,7 +1283,12 @@ def combined_raw_clustermap(
|
|
|
1131
1283
|
|
|
1132
1284
|
# save or show
|
|
1133
1285
|
if save_path is not None:
|
|
1134
|
-
safe_name =
|
|
1286
|
+
safe_name = (
|
|
1287
|
+
f"{ref}__{display_sample}".replace("=", "")
|
|
1288
|
+
.replace("__", "_")
|
|
1289
|
+
.replace(",", "_")
|
|
1290
|
+
.replace(" ", "_")
|
|
1291
|
+
)
|
|
1135
1292
|
out_file = save_path / f"{safe_name}.png"
|
|
1136
1293
|
fig.savefig(out_file, dpi=300)
|
|
1137
1294
|
plt.close(fig)
|
|
@@ -1157,20 +1314,15 @@ def combined_raw_clustermap(
|
|
|
1157
1314
|
for bin_label, percent in percentages.items():
|
|
1158
1315
|
print(f" - {bin_label}: {percent:.1f}%")
|
|
1159
1316
|
|
|
1160
|
-
except Exception
|
|
1317
|
+
except Exception:
|
|
1161
1318
|
import traceback
|
|
1319
|
+
|
|
1162
1320
|
traceback.print_exc()
|
|
1163
1321
|
continue
|
|
1164
1322
|
|
|
1165
|
-
# store once at the end (HDF5 safe)
|
|
1166
|
-
# matrices won't be HDF5-safe; store only metadata + maybe hit counts
|
|
1167
|
-
# adata.uns["clustermap_results"] = [
|
|
1168
|
-
# {k: v for k, v in r.items() if not k.endswith("_matrix")}
|
|
1169
|
-
# for r in results
|
|
1170
|
-
# ]
|
|
1171
|
-
|
|
1172
1323
|
return results
|
|
1173
1324
|
|
|
1325
|
+
|
|
1174
1326
|
def plot_hmm_layers_rolling_by_sample_ref(
|
|
1175
1327
|
adata,
|
|
1176
1328
|
layers: Optional[Sequence[str]] = None,
|
|
@@ -1237,7 +1389,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1237
1389
|
|
|
1238
1390
|
# --- basic checks / defaults ---
|
|
1239
1391
|
if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
|
|
1240
|
-
raise ValueError(
|
|
1392
|
+
raise ValueError(
|
|
1393
|
+
f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
|
|
1394
|
+
)
|
|
1241
1395
|
|
|
1242
1396
|
# canonicalize samples / refs
|
|
1243
1397
|
if samples is None:
|
|
@@ -1260,7 +1414,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1260
1414
|
if layers is None:
|
|
1261
1415
|
layers = list(adata.layers.keys())
|
|
1262
1416
|
if len(layers) == 0:
|
|
1263
|
-
raise ValueError(
|
|
1417
|
+
raise ValueError(
|
|
1418
|
+
"No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
|
|
1419
|
+
)
|
|
1264
1420
|
layers = list(layers)
|
|
1265
1421
|
|
|
1266
1422
|
# x coordinates (positions)
|
|
@@ -1299,19 +1455,29 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1299
1455
|
|
|
1300
1456
|
fig_w = figsize_per_cell[0] * ncols
|
|
1301
1457
|
fig_h = figsize_per_cell[1] * nrows
|
|
1302
|
-
fig, axes = plt.subplots(
|
|
1303
|
-
|
|
1304
|
-
|
|
1458
|
+
fig, axes = plt.subplots(
|
|
1459
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
1460
|
+
)
|
|
1305
1461
|
|
|
1306
1462
|
for r_idx, sample_name in enumerate(chunk):
|
|
1307
1463
|
for c_idx, ref_name in enumerate(refs_all):
|
|
1308
1464
|
ax = axes[r_idx][c_idx]
|
|
1309
1465
|
|
|
1310
1466
|
# subset adata
|
|
1311
|
-
mask = (adata.obs[sample_col].values == sample_name) & (
|
|
1467
|
+
mask = (adata.obs[sample_col].values == sample_name) & (
|
|
1468
|
+
adata.obs[ref_col].values == ref_name
|
|
1469
|
+
)
|
|
1312
1470
|
sub = adata[mask]
|
|
1313
1471
|
if sub.n_obs == 0:
|
|
1314
|
-
ax.text(
|
|
1472
|
+
ax.text(
|
|
1473
|
+
0.5,
|
|
1474
|
+
0.5,
|
|
1475
|
+
"No reads",
|
|
1476
|
+
ha="center",
|
|
1477
|
+
va="center",
|
|
1478
|
+
transform=ax.transAxes,
|
|
1479
|
+
color="gray",
|
|
1480
|
+
)
|
|
1315
1481
|
ax.set_xticks([])
|
|
1316
1482
|
ax.set_yticks([])
|
|
1317
1483
|
if r_idx == 0:
|
|
@@ -1361,7 +1527,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1361
1527
|
smoothed = col_mean
|
|
1362
1528
|
else:
|
|
1363
1529
|
ser = pd.Series(col_mean)
|
|
1364
|
-
smoothed =
|
|
1530
|
+
smoothed = (
|
|
1531
|
+
ser.rolling(window=window, min_periods=min_periods, center=center)
|
|
1532
|
+
.mean()
|
|
1533
|
+
.to_numpy()
|
|
1534
|
+
)
|
|
1365
1535
|
|
|
1366
1536
|
# x axis: x_coords (trim/pad to match length)
|
|
1367
1537
|
L = len(col_mean)
|
|
@@ -1371,7 +1541,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1371
1541
|
if show_raw:
|
|
1372
1542
|
ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
|
|
1373
1543
|
|
|
1374
|
-
ax.plot(
|
|
1544
|
+
ax.plot(
|
|
1545
|
+
x,
|
|
1546
|
+
smoothed[:L],
|
|
1547
|
+
label=layer,
|
|
1548
|
+
color=colors[li],
|
|
1549
|
+
linewidth=1.2,
|
|
1550
|
+
alpha=0.95,
|
|
1551
|
+
zorder=2,
|
|
1552
|
+
)
|
|
1375
1553
|
plotted_any = True
|
|
1376
1554
|
|
|
1377
1555
|
# labels / titles
|
|
@@ -1389,11 +1567,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1389
1567
|
|
|
1390
1568
|
ax.grid(True, alpha=0.2)
|
|
1391
1569
|
|
|
1392
|
-
fig.suptitle(
|
|
1570
|
+
fig.suptitle(
|
|
1571
|
+
f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
|
|
1572
|
+
fontsize=11,
|
|
1573
|
+
y=0.995,
|
|
1574
|
+
)
|
|
1393
1575
|
fig.tight_layout(rect=[0, 0, 1, 0.97])
|
|
1394
1576
|
|
|
1395
1577
|
if save:
|
|
1396
|
-
fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
|
|
1578
|
+
fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
|
|
1397
1579
|
plt.savefig(fname, bbox_inches="tight", dpi=dpi)
|
|
1398
1580
|
saved_files.append(fname)
|
|
1399
1581
|
else:
|