smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +49 -7
- smftools/cli/hmm_adata.py +250 -32
- smftools/cli/latent_adata.py +773 -0
- smftools/cli/load_adata.py +78 -74
- smftools/cli/preprocess_adata.py +122 -58
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +74 -112
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +52 -4
- smftools/config/conversion.yaml +1 -1
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +85 -12
- smftools/config/experiment_config.py +146 -1
- smftools/constants.py +69 -0
- smftools/hmm/HMM.py +88 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +636 -175
- smftools/informatics/h5ad_functions.py +198 -2
- smftools/informatics/modkit_extract_to_adata.py +1007 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +26 -3
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +62 -1583
- smftools/plotting/hmm_plotting.py +1670 -8
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +4 -0
- smftools/preprocessing/append_base_context.py +18 -18
- smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +159 -99
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +10 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +130 -0
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +79 -80
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +872 -0
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +217 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/variant_plotting.py (new file)
@@ -0,0 +1,1231 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Sequence

import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch

from smftools.logging_utils import get_logger
from smftools.optional_imports import require

plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
sns = require("seaborn", extra="plotting", purpose="plot styling")

logger = get_logger(__name__)

DNA_5COLOR_PALETTE = {
    "A": "#00A000",  # green
    "C": "#0000FF",  # blue
    "G": "#FF7F00",  # orange
    "T": "#FF0000",  # red
    "OTHER": "#808080",  # gray (N, PAD, unknown)
}


def plot_mismatch_base_frequency_by_position(
    adata,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",
    mismatch_layer: str = "mismatch_integer_encoding",
    read_span_layer: str = "read_span_mask",
    quality_layer: str = "base_quality_scores",
    plot_zscores: bool = False,
    exclude_mod_sites: bool = False,
    mod_site_bases: Sequence[str] | None = None,
    min_quality: float | None = None,
    min_length: int | None = None,
    min_mapped_length_to_reference_length_ratio: float | None = None,
    demux_types: Sequence[str] = ("single", "double", "already"),
    save_path: str | Path | None = None,
) -> List[Dict[str, Any]]:
    """Plot mismatch base frequencies by position per sample/reference.

    Args:
        adata: AnnData with mismatch integer encoding layer.
        sample_col: Column in ``adata.obs`` that identifies samples.
        reference_col: Column in ``adata.obs`` that identifies references.
        mismatch_layer: Layer name containing mismatch integer encodings.
        read_span_layer: Layer name containing read-span masks.
        quality_layer: Layer name containing base-quality scores used for z-scores.
        plot_zscores: Whether to plot quality-normalized z-scores in a separate panel.
        exclude_mod_sites: Whether to exclude annotated modification sites.
        mod_site_bases: Base-context labels used to build mod-site masks (e.g., ``["GpC", "CpG"]``).
        min_quality: Optional minimum read quality filter.
        min_length: Optional minimum mapped length filter.
        min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
        demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
        save_path: Optional output directory for saving plots.

    Returns:
        List of dictionaries with per-plot metadata and output paths. Includes
        a pooled-samples entry per reference.
    """
    logger.info("Plotting mismatch base frequency by position.")

    def _mask_or_true(series_name: str, predicate):
        if series_name not in adata.obs:
            return pd.Series(True, index=adata.obs.index)
        s = adata.obs[series_name]
        try:
            return predicate(s)
        except Exception:
            return pd.Series(True, index=s.index)

    def _build_mod_site_mask(var_frame, ref_name: str) -> np.ndarray | None:
        if not exclude_mod_sites or not mod_site_bases:
            return None

        mod_site_cols = [f"{ref_name}_{base}_site" for base in mod_site_bases]
        missing_required = [col for col in mod_site_cols if col not in var_frame.columns]
        if missing_required:
            return None

        extra_cols = []
        if any(base in {"GpC", "CpG"} for base in mod_site_bases):
            ambiguous_col = f"{ref_name}_ambiguous_GpC_CpG_site"
            if ambiguous_col in var_frame.columns:
                extra_cols.append(ambiguous_col)

        mod_site_cols.extend(extra_cols)
        mod_site_cols = list(dict.fromkeys(mod_site_cols))

        mod_masks = [np.asarray(var_frame[col].values, dtype=bool) for col in mod_site_cols]
        mod_mask = mod_masks[0] if len(mod_masks) == 1 else np.logical_or.reduce(mod_masks)

        position_col = f"position_in_{ref_name}"
        if position_col in var_frame.columns:
            position_mask = np.asarray(var_frame[position_col].values, dtype=bool)
            mod_mask = np.logical_and(mod_mask, position_mask)

        return mod_mask

    def _get_reference_base_series(subset, ref_name: str) -> pd.Series | None:
        if f"{ref_name}_strand_FASTA_base" in subset.var.columns:
            return subset.var[f"{ref_name}_strand_FASTA_base"].astype("string")
        return None

    if mismatch_layer not in adata.layers:
        raise KeyError(f"Layer '{mismatch_layer}' not found in adata.layers")
    if plot_zscores and quality_layer not in adata.layers:
        raise KeyError(f"Layer '{quality_layer}' not found in adata.layers")

    mismatch_map = adata.uns.get("mismatch_integer_encoding_map", {}) or {}
    if not mismatch_map:
        raise KeyError("Mismatch encoding map not found in adata.uns")

    base_int_to_label = {
        int(value): str(base)
        for base, value in mismatch_map.items()
        if base not in {"N", "PAD"} and isinstance(value, (int, np.integer))
    }
    if not base_int_to_label:
        raise ValueError("Mismatch encoding map missing base labels.")

    base_label_to_int = {label: int_val for int_val, label in base_int_to_label.items()}

    def _pooled_label(sample_categories: Sequence[str]) -> str:
        base_label = "pooled_samples"
        if base_label not in sample_categories:
            return base_label
        suffix = 1
        while f"{base_label}_{suffix}" in sample_categories:
            suffix += 1
        return f"{base_label}_{suffix}"

    results: List[Dict[str, Any]] = []
    save_path = Path(save_path) if save_path is not None else None
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    for col in (sample_col, reference_col):
        if col not in adata.obs:
            raise KeyError(f"{col} not in adata.obs")
        if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
            adata.obs[col] = adata.obs[col].astype("category")

    sample_categories = [str(sample) for sample in adata.obs[sample_col].cat.categories]
    pooled_sample = _pooled_label(sample_categories)

    for ref in adata.obs[reference_col].cat.categories:
        ref_name = str(ref)
        base_mask = np.ones(adata.n_vars, dtype=bool)
        position_col = f"position_in_{ref_name}"
        if position_col in adata.var.columns:
            base_mask = np.asarray(adata.var[position_col].values, dtype=bool)

        base_mod_mask = _build_mod_site_mask(adata.var, ref_name)
        if base_mod_mask is not None:
            base_mask = base_mask & ~base_mod_mask

        summary_data = {
            "var_names_position": np.asarray(adata.var_names)[base_mask],
        }
        if "Original_var_names" in adata.var.columns:
            summary_data["original_var_names_position"] = np.asarray(
                adata.var["Original_var_names"]
            )[base_mask]
        reindexed_col = f"{ref_name}_reindexed"
        if reindexed_col in adata.var.columns:
            summary_data["reindexed_var_names_position"] = np.asarray(adata.var[reindexed_col])[
                base_mask
            ]
        ref_sequence_col = f"{ref_name}_strand_FASTA_base"
        if ref_sequence_col in adata.var.columns:
            summary_data["reference_sequence_base"] = np.asarray(adata.var[ref_sequence_col])[
                base_mask
            ]

        summary_df = pd.DataFrame(summary_data)
        mean_positional_error = adata.var[f"{ref}_mean_error_rate"].values
        std_positional_error = adata.var[f"{ref}_std_error_rate"].values
        for sample in [*sample_categories, pooled_sample]:
            qmask = _mask_or_true(
                "read_quality",
                (lambda s: s >= float(min_quality))
                if (min_quality is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            lm_mask = _mask_or_true(
                "mapped_length",
                (lambda s: s >= float(min_length))
                if (min_length is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            lrr_mask = _mask_or_true(
                "mapped_length_to_reference_length_ratio",
                (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
                if (min_mapped_length_to_reference_length_ratio is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            demux_mask = _mask_or_true(
                "demux_type",
                (lambda s: s.astype("string").isin(list(demux_types)))
                if (demux_types is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )

            row_mask = (adata.obs[reference_col] == ref) & qmask & lm_mask & lrr_mask & demux_mask
            if sample != pooled_sample:
                row_mask = row_mask & (adata.obs[sample_col] == sample)
            if not bool(row_mask.any()):
                continue

            subset = adata[row_mask, :].copy()
            mismatch_matrix = np.asarray(subset.layers[mismatch_layer])

            if read_span_layer in subset.layers:
                span_matrix = np.asarray(subset.layers[read_span_layer])
                coverage_mask = span_matrix > 0
            else:
                coverage_mask = np.ones_like(mismatch_matrix, dtype=bool)

            ref_bases = _get_reference_base_series(subset, str(ref))
            logger.debug(f"Unconverted reference sequence for {ref}: {ref_bases.values}")
            ref_lower = str(ref).lower()
            if ref_bases is not None:
                target_base = None
                if "top" in ref_lower:
                    target_base = "C"
                elif "bottom" in ref_lower:
                    target_base = "G"
                else:
                    logger.debug(f"Could not find strand in {ref_lower}")
                if target_base in base_label_to_int:
                    target_int = base_label_to_int[target_base]
                    logger.debug(
                        f"Ignoring {target_base} site mismatches in {ref} that yield a mismatch ambiguous to conversion"
                    )
                    ref_base_mask = np.asarray(ref_bases.values == target_base, dtype=bool)
                    ignore_mask = (mismatch_matrix == target_int) & ref_base_mask[None, :]
                    coverage_mask = coverage_mask & ~ignore_mask

            coverage_counts = coverage_mask.sum(axis=0).astype(float)

            ref_position_mask = subset.var.get(f"position_in_{ref}")
            if ref_position_mask is None:
                position_mask = np.ones(mismatch_matrix.shape[1], dtype=bool)
            else:
                position_mask = np.asarray(ref_position_mask.values, dtype=bool)

            mod_site_mask = _build_mod_site_mask(subset.var, str(ref))
            if mod_site_mask is not None:
                position_mask = position_mask & ~mod_site_mask

            position_mask = position_mask & (coverage_counts > 0)
            if not np.any(position_mask):
                continue

            positions = np.arange(mismatch_matrix.shape[1])[position_mask]
            mean_errors = mean_positional_error[position_mask]
            normalized_mean_errors = (
                mean_errors / 3
            )  # This is a conservative normalization against variant specific error rate
            std_errors = std_positional_error[position_mask]
            base_freqs: Dict[str, np.ndarray] = {}
            for base_int, base_label in base_int_to_label.items():
                base_counts = ((mismatch_matrix == base_int) & coverage_mask).sum(axis=0)
                freq = np.divide(
                    base_counts,
                    coverage_counts,
                    out=np.full(mismatch_matrix.shape[1], np.nan, dtype=float),
                    where=coverage_counts > 0,
                )
                freq = np.where(freq > 0, freq, np.nan)
                freq = freq[position_mask]
                if np.all(np.isnan(freq)):
                    continue
                base_freqs[base_label] = freq

            if not base_freqs:
                continue

            zscore_freqs: Dict[str, np.ndarray] = {}
            if plot_zscores:
                quality_matrix = np.asarray(subset.layers[quality_layer]).astype(float)
                quality_matrix[quality_matrix < 0] = np.nan
                valid_quality = coverage_mask & ~np.isnan(quality_matrix)
                error_probs = np.power(10.0, -quality_matrix / 10.0)
                error_probs = np.where(valid_quality, error_probs, 0.0)
                variant_probs = error_probs / 3.0
                variance = (variant_probs * (1.0 - variant_probs)).sum(axis=0)
                variance = variance[position_mask]
                variance = np.where(variance > 0, variance, np.nan)

                for base_int, base_label in base_int_to_label.items():
                    base_counts = (
                        ((mismatch_matrix == base_int) & coverage_mask).sum(axis=0).astype(float)
                    )
                    expected_counts = variant_probs.sum(axis=0)
                    expected_counts = expected_counts[position_mask]
                    observed_counts = base_counts[position_mask]
                    zscores = np.divide(
                        observed_counts - expected_counts,
                        np.sqrt(variance),
                        out=np.full_like(expected_counts, np.nan, dtype=float),
                        where=~np.isnan(variance),
                    )
                    if np.all(np.isnan(zscores)):
                        continue
                    zscore_freqs[base_label] = zscores
            if plot_zscores and save_path is not None:
                full_max = np.full(adata.n_vars, np.nan, dtype=float)
                full_base = np.full(adata.n_vars, None, dtype=object)
                if zscore_freqs:
                    base_labels = sorted(zscore_freqs.keys())
                    zscore_stack = []
                    for base_label in base_labels:
                        full_z = np.full(adata.n_vars, np.nan, dtype=float)
                        full_z[position_mask] = zscore_freqs[base_label]
                        zscore_stack.append(full_z)
                    zscore_stack = np.vstack(zscore_stack)
                    all_nan = np.all(np.isnan(zscore_stack), axis=0)
                    safe_stack = np.where(np.isnan(zscore_stack), -np.inf, zscore_stack)
                    with np.errstate(invalid="ignore"):
                        full_max = np.nanmax(zscore_stack, axis=0)
                    max_idx = np.argmax(safe_stack, axis=0)
                    full_base = np.array([base_labels[idx] for idx in max_idx], dtype=object)
                    full_max[all_nan] = np.nan
                    full_base[all_nan] = None
                summary_df[f"{sample}_max_zscore"] = full_max[base_mask]
                summary_df[f"{sample}_max_zscore_base"] = full_base[base_mask]

            if plot_zscores:
                fig, axes = plt.subplots(nrows=2, figsize=(12, 7), sharex=True)
                ax = axes[0]
                zscore_ax = axes[1]
            else:
                fig, ax = plt.subplots(figsize=(12, 4))
                zscore_ax = None

            for base_label in sorted(base_freqs.keys()):
                normalized_base = base_label if base_label in {"A", "C", "G", "T"} else "OTHER"
                color = DNA_5COLOR_PALETTE.get(normalized_base, DNA_5COLOR_PALETTE["OTHER"])
                ax.scatter(
                    positions, base_freqs[base_label], label=base_label, color=color, linewidth=1
                )

            ax.plot(
                positions,
                normalized_mean_errors,
                label="Mean error rate",
                color="black",
                linestyle="--",
            )
            # ax.fill_between(
            #     positions,
            #     np.full_like(positions, lower, dtype=float),
            #     np.full_like(positions, upper, dtype=float),
            #     color="black",
            #     alpha=0.12,
            #     label="±1 std error",
            # )

            ax.set_yscale("log")
            ax.set_xlabel("Position")
            ax.set_ylabel("Mismatch frequency")
            ax.set_title(f"{sample} - {ref} mismatch base frequencies")
            ax.legend(title="Mismatch base", ncol=4, fontsize=9)

            if plot_zscores and zscore_ax is not None and zscore_freqs:
                for base_label in sorted(zscore_freqs.keys()):
                    normalized_base = base_label if base_label in {"A", "C", "G", "T"} else "OTHER"
                    color = DNA_5COLOR_PALETTE.get(normalized_base, DNA_5COLOR_PALETTE["OTHER"])
                    zscore_ax.scatter(
                        positions, zscore_freqs[base_label], label=base_label, color=color
                    )
                zscore_ax.axhline(0.0, color="black", linestyle="--", linewidth=1)
                zscore_ax.set_xlabel("Position")
                zscore_ax.set_ylabel("Z-score")
                zscore_ax.set_title(f"{sample} - {ref} quality-normalized mismatch z-scores")
                zscore_ax.legend(title="Mismatch base", ncol=4, fontsize=9)
            fig.tight_layout()

            out_file = None
            if save_path is not None:
                safe_name = f"{ref}__{sample}__mismatch_base_frequency".replace("=", "").replace(
                    ",", "_"
                )
                out_file = save_path / f"{safe_name}.png"
                fig.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)
                logger.info("Saved mismatch base frequency plot to %s.", out_file)
            else:
                plt.show()

            results.append(
                {
                    "reference": str(ref),
                    "sample": str(sample),
                    "n_positions": int(positions.size),
                    "quality_layer": quality_layer if plot_zscores else None,
                    "output_path": str(out_file) if out_file is not None else None,
                }
            )

        if save_path is not None and not summary_df.empty:
            safe_ref = f"{ref_name}__mismatch_base_frequency_summary".replace("=", "").replace(
                ",", "_"
            )
            summary_file = save_path / f"{safe_ref}.csv"
            summary_df.to_csv(summary_file, index=False)
            logger.info("Saved mismatch base frequency summary to %s.", summary_file)

    return results
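Usage sketch (illustrative, not part of the package diff): assuming `adata` is an AnnData produced by the smftools pipeline with the default layer and obs names documented above, the function could be called as follows; the h5ad path and output directory are placeholders.

import anndata as ad
from smftools.plotting.variant_plotting import plot_mismatch_base_frequency_by_position

adata = ad.read_h5ad("experiment.h5ad")  # placeholder input file
entries = plot_mismatch_base_frequency_by_position(
    adata,
    plot_zscores=True,                 # adds the quality-normalized z-score panel
    exclude_mod_sites=True,
    mod_site_bases=["GpC", "CpG"],     # base contexts masked out of the frequency plots
    min_quality=20,
    save_path="plots/mismatch_frequency",  # placeholder output directory
)
for entry in entries:
    print(entry["reference"], entry["sample"], entry["output_path"])
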
def plot_sequence_integer_encoding_clustermaps(
    adata,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",
    layer: str = "sequence_integer_encoding",
    mismatch_layer: str = "mismatch_integer_encoding",
    exclude_mod_sites: bool = False,
    mod_site_bases: Sequence[str] | None = None,
    min_quality: float | None = 20,
    min_length: int | None = 200,
    min_mapped_length_to_reference_length_ratio: float | None = 0,
    demux_types: Sequence[str] = ("single", "double", "already"),
    sort_by: str = "none",  # "none", "hierarchical", "obs:<col>"
    cmap: str = "viridis",
    max_unknown_fraction: float | None = None,
    unknown_values: Sequence[int] = (4, 5),
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 9,
    max_reads: int | None = None,
    save_path: str | Path | None = None,
    use_dna_5color_palette: bool = True,
    show_numeric_colorbar: bool = False,
    show_position_axis: bool = False,
    position_axis_tick_target: int = 25,
):
    """Plot integer-encoded sequence clustermaps per sample/reference.

    Args:
        adata: AnnData with a ``sequence_integer_encoding`` layer.
        sample_col: Column in ``adata.obs`` that identifies samples.
        reference_col: Column in ``adata.obs`` that identifies references.
        layer: Layer name containing integer-encoded sequences.
        mismatch_layer: Optional layer name containing mismatch integer encodings.
        exclude_mod_sites: Whether to exclude annotated modification sites.
        mod_site_bases: Base-context labels used to build mod-site masks (e.g., ``["GpC", "CpG"]``).
        min_quality: Optional minimum read quality filter.
        min_length: Optional minimum mapped length filter.
        min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
        demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
        sort_by: Row sorting strategy: ``none``, ``hierarchical``, or ``obs:<col>``.
        cmap: Matplotlib colormap for the heatmap when ``use_dna_5color_palette`` is False.
        max_unknown_fraction: Optional maximum fraction of ``unknown_values`` allowed per
            position; positions above this threshold are excluded.
        unknown_values: Integer values to treat as unknown/padding.
        xtick_step: Spacing between x-axis tick labels (None = no labels).
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        max_reads: Optional maximum number of reads to plot per sample/reference.
        save_path: Optional output directory for saving plots.
        use_dna_5color_palette: Whether to use a fixed A/C/G/T/Other palette.
        show_numeric_colorbar: If False, use a legend instead of a numeric colorbar.
        show_position_axis: Whether to draw a position axis with tick labels.
        position_axis_tick_target: Approximate number of ticks to show when auto-sizing.

    Returns:
        List of dictionaries with per-plot metadata and output paths.
    """
    logger.info("Plotting sequence integer encoding clustermaps.")

    def _mask_or_true(series_name: str, predicate):
        if series_name not in adata.obs:
            return pd.Series(True, index=adata.obs.index)
        s = adata.obs[series_name]
        try:
            return predicate(s)
        except Exception:
            return pd.Series(True, index=adata.obs.index)

    if layer not in adata.layers:
        raise KeyError(f"Layer '{layer}' not found in adata.layers")

    if max_unknown_fraction is not None and not (0 <= max_unknown_fraction <= 1):
        raise ValueError("max_unknown_fraction must be between 0 and 1.")

    if position_axis_tick_target < 1:
        raise ValueError("position_axis_tick_target must be at least 1.")

    results: List[Dict[str, Any]] = []
    save_path = Path(save_path) if save_path is not None else None
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    for col in (sample_col, reference_col):
        if col not in adata.obs:
            raise KeyError(f"{col} not in adata.obs")
        if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
            adata.obs[col] = adata.obs[col].astype("category")

    int_to_base = adata.uns.get("sequence_integer_decoding_map", {}) or {}
    if not int_to_base:
        encoding_map = adata.uns.get("sequence_integer_encoding_map", {}) or {}
        int_to_base = {int(v): str(k) for k, v in encoding_map.items()} if encoding_map else {}

    coerced_int_to_base = {}
    for key, value in int_to_base.items():
        try:
            coerced_key = int(key)
        except Exception:
            continue
        coerced_int_to_base[coerced_key] = str(value)
    int_to_base = coerced_int_to_base

    def normalize_base(base: str) -> str:
        return base if base in {"A", "C", "G", "T"} else "OTHER"

    mismatch_int_to_base = {}
    if mismatch_layer in adata.layers:
        mismatch_encoding_map = adata.uns.get("mismatch_integer_encoding_map", {}) or {}
        mismatch_int_to_base = {
            int(v): str(k)
            for k, v in mismatch_encoding_map.items()
            if isinstance(v, (int, np.integer))
        }

    def _resolve_xtick_step(n_positions: int) -> int | None:
        if xtick_step is not None:
            return xtick_step
        if not show_position_axis:
            return None
        return max(1, int(np.ceil(n_positions / position_axis_tick_target)))

    def _build_mod_site_mask(var_frame, ref_name: str) -> np.ndarray | None:
        if not exclude_mod_sites or not mod_site_bases:
            return None

        if hasattr(var_frame, "var"):
            var_frame = var_frame.var

        mod_site_cols = [f"{ref_name}_{base}_site" for base in mod_site_bases]
        missing_required = [col for col in mod_site_cols if col not in var_frame.columns]
        if missing_required:
            return None

        extra_cols = []
        if any(base in {"GpC", "CpG"} for base in mod_site_bases):
            ambiguous_col = f"{ref_name}_ambiguous_GpC_CpG_site"
            if ambiguous_col in var_frame.columns:
                extra_cols.append(ambiguous_col)

        mod_site_cols.extend(extra_cols)
        mod_site_cols = list(dict.fromkeys(mod_site_cols))

        mod_masks = [np.asarray(var_frame[col].values, dtype=bool) for col in mod_site_cols]
        mod_mask = mod_masks[0] if len(mod_masks) == 1 else np.logical_or.reduce(mod_masks)

        position_col = f"position_in_{ref_name}"
        if position_col in var_frame.columns:
            position_mask = np.asarray(var_frame[position_col].values, dtype=bool)
            mod_mask = np.logical_and(mod_mask, position_mask)

        return mod_mask

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            qmask = _mask_or_true(
                "read_quality",
                (lambda s: s >= float(min_quality))
                if (min_quality is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            lm_mask = _mask_or_true(
                "mapped_length",
                (lambda s: s >= float(min_length))
                if (min_length is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            lrr_mask = _mask_or_true(
                "mapped_length_to_reference_length_ratio",
                (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
                if (min_mapped_length_to_reference_length_ratio is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            demux_mask = _mask_or_true(
                "demux_type",
                (lambda s: s.astype("string").isin(list(demux_types)))
                if (demux_types is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )

            row_mask = (
                (adata.obs[reference_col] == ref)
                & (adata.obs[sample_col] == sample)
                & qmask
                & lm_mask
                & lrr_mask
                & demux_mask
            )
            if not bool(row_mask.any()):
                continue

            subset = adata[row_mask, :].copy()
            matrix = np.asarray(subset.layers[layer])
            mismatch_matrix = None
            if mismatch_layer in subset.layers:
                mismatch_matrix = np.asarray(subset.layers[mismatch_layer])

            mod_site_mask = _build_mod_site_mask(subset, str(ref))
            if mod_site_mask is not None:
                keep_columns = ~mod_site_mask
                if not np.any(keep_columns):
                    continue
                matrix = matrix[:, keep_columns]
                subset = subset[:, keep_columns].copy()
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[:, keep_columns]

            if max_unknown_fraction is not None:
                unknown_mask = np.isin(matrix, np.asarray(unknown_values))
                unknown_fraction = unknown_mask.mean(axis=0)
                keep_columns = unknown_fraction <= max_unknown_fraction
                if not np.any(keep_columns):
                    continue
                matrix = matrix[:, keep_columns]
                subset = subset[:, keep_columns].copy()
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[:, keep_columns]

            if max_reads is not None and matrix.shape[0] > max_reads:
                matrix = matrix[:max_reads]
                subset = subset[:max_reads, :].copy()
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[:max_reads]

            if matrix.size == 0:
                continue

            if use_dna_5color_palette and not int_to_base:
                uniq_vals = np.unique(matrix[~pd.isna(matrix)])
                guess = {}
                for val in uniq_vals:
                    try:
                        int_val = int(val)
                    except Exception:
                        continue
                    guess[int_val] = {0: "A", 1: "C", 2: "G", 3: "T"}.get(int_val, "OTHER")
                int_to_base_local = guess
            else:
                int_to_base_local = int_to_base

            order = None
            if sort_by.startswith("obs:"):
                colname = sort_by.split("obs:")[1]
                order = np.argsort(subset.obs[colname].values)
            elif sort_by == "hierarchical":
                linkage = sch.linkage(np.nan_to_num(matrix), method="ward")
                order = sch.leaves_list(linkage)
            elif sort_by != "none":
                raise ValueError("sort_by must be 'none', 'hierarchical', or 'obs:<col>'")

            if order is not None:
                matrix = matrix[order]
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[order]

            has_mismatch = mismatch_matrix is not None
            fig, axes = plt.subplots(
                ncols=2 if has_mismatch else 1,
                figsize=(18, 6) if has_mismatch else (12, 6),
                sharey=has_mismatch,
            )
            if not isinstance(axes, np.ndarray):
                axes = np.asarray([axes])
            ax = axes[0]

            if use_dna_5color_palette and int_to_base_local:
                int_to_color = {
                    int(int_val): DNA_5COLOR_PALETTE[normalize_base(str(base))]
                    for int_val, base in int_to_base_local.items()
                }
                uniq_matrix = np.unique(matrix[~pd.isna(matrix)])
                for val in uniq_matrix:
                    try:
                        int_val = int(val)
                    except Exception:
                        continue
                    if int_val not in int_to_color:
                        int_to_color[int_val] = DNA_5COLOR_PALETTE["OTHER"]

                ordered = sorted(int_to_color.items(), key=lambda x: x[0])
                colors_list = [color for _, color in ordered]
                bounds = [int_val - 0.5 for int_val, _ in ordered]
                bounds.append(ordered[-1][0] + 0.5)

                cmap_obj = colors.ListedColormap(colors_list)
                norm = colors.BoundaryNorm(bounds, cmap_obj.N)

                sns.heatmap(
                    matrix,
                    cmap=cmap_obj,
                    norm=norm,
                    ax=ax,
                    yticklabels=False,
                    cbar=show_numeric_colorbar,
                )

                legend_handles = [
                    patches.Patch(facecolor=DNA_5COLOR_PALETTE["A"], label="A"),
                    patches.Patch(facecolor=DNA_5COLOR_PALETTE["C"], label="C"),
                    patches.Patch(facecolor=DNA_5COLOR_PALETTE["G"], label="G"),
                    patches.Patch(facecolor=DNA_5COLOR_PALETTE["T"], label="T"),
                    patches.Patch(
                        facecolor=DNA_5COLOR_PALETTE["OTHER"],
                        label="Other (N / PAD / unknown)",
                    ),
                ]
                ax.legend(
                    handles=legend_handles,
                    title="Base",
                    loc="upper left",
                    bbox_to_anchor=(1.02, 1.0),
                    frameon=False,
                )
            else:
                sns.heatmap(matrix, cmap=cmap, ax=ax, yticklabels=False, cbar=True)

            ax.set_title(layer)

            resolved_step = _resolve_xtick_step(matrix.shape[1])
            if resolved_step is not None and resolved_step > 0:
                sites = np.arange(0, matrix.shape[1], resolved_step)
                ax.set_xticks(sites)
                ax.set_xticklabels(
                    subset.var_names[sites].astype(str),
                    rotation=xtick_rotation,
                    fontsize=xtick_fontsize,
                )
            else:
                ax.set_xticks([])
            if show_position_axis or xtick_step is not None:
                ax.set_xlabel("Position")

            if has_mismatch:
                mismatch_ax = axes[1]
                mismatch_int_to_base_local = mismatch_int_to_base or int_to_base_local
                if use_dna_5color_palette and mismatch_int_to_base_local:
                    mismatch_int_to_color = {}
                    for int_val, base in mismatch_int_to_base_local.items():
                        base_upper = str(base).upper()
                        if base_upper == "PAD":
                            mismatch_int_to_color[int(int_val)] = "#D3D3D3"
                        elif base_upper == "N":
                            mismatch_int_to_color[int(int_val)] = "#808080"
                        else:
                            mismatch_int_to_color[int(int_val)] = DNA_5COLOR_PALETTE[
                                normalize_base(base_upper)
                            ]

                    uniq_mismatch = np.unique(mismatch_matrix[~pd.isna(mismatch_matrix)])
                    for val in uniq_mismatch:
                        try:
                            int_val = int(val)
                        except Exception:
                            continue
                        if int_val not in mismatch_int_to_color:
                            mismatch_int_to_color[int_val] = DNA_5COLOR_PALETTE["OTHER"]

                    ordered_mismatch = sorted(mismatch_int_to_color.items(), key=lambda x: x[0])
                    mismatch_colors = [color for _, color in ordered_mismatch]
                    mismatch_bounds = [int_val - 0.5 for int_val, _ in ordered_mismatch]
                    mismatch_bounds.append(ordered_mismatch[-1][0] + 0.5)

                    mismatch_cmap = colors.ListedColormap(mismatch_colors)
                    mismatch_norm = colors.BoundaryNorm(mismatch_bounds, mismatch_cmap.N)

                    sns.heatmap(
                        mismatch_matrix,
                        cmap=mismatch_cmap,
                        norm=mismatch_norm,
                        ax=mismatch_ax,
                        yticklabels=False,
                        cbar=show_numeric_colorbar,
                    )

                    mismatch_legend_handles = [
                        patches.Patch(facecolor=DNA_5COLOR_PALETTE["A"], label="A"),
                        patches.Patch(facecolor=DNA_5COLOR_PALETTE["C"], label="C"),
                        patches.Patch(facecolor=DNA_5COLOR_PALETTE["G"], label="G"),
                        patches.Patch(facecolor=DNA_5COLOR_PALETTE["T"], label="T"),
                        patches.Patch(facecolor="#808080", label="Match/N"),
                        patches.Patch(facecolor="#D3D3D3", label="PAD"),
                    ]
                    mismatch_ax.legend(
                        handles=mismatch_legend_handles,
                        title="Mismatch base",
                        loc="upper left",
                        bbox_to_anchor=(1.02, 1.0),
                        frameon=False,
                    )
                else:
                    sns.heatmap(
                        mismatch_matrix,
                        cmap=cmap,
                        ax=mismatch_ax,
                        yticklabels=False,
                        cbar=True,
                    )

                mismatch_ax.set_title(mismatch_layer)
                if resolved_step is not None and resolved_step > 0:
                    sites = np.arange(0, mismatch_matrix.shape[1], resolved_step)
                    mismatch_ax.set_xticks(sites)
                    mismatch_ax.set_xticklabels(
                        subset.var_names[sites].astype(str),
                        rotation=xtick_rotation,
                        fontsize=xtick_fontsize,
                    )
                else:
                    mismatch_ax.set_xticks([])
                if show_position_axis or xtick_step is not None:
                    mismatch_ax.set_xlabel("Position")

            n_reads = matrix.shape[0]

            fig.suptitle(f"{sample} - {ref} - {n_reads} reads")
            fig.tight_layout(rect=(0, 0, 1, 0.95))

            out_file = None
            if save_path is not None:
                safe_name = f"{ref}__{sample}__{layer}".replace("=", "").replace(",", "_")
                out_file = save_path / f"{safe_name}.png"
                fig.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)
                logger.info("Saved sequence encoding clustermap to %s.", out_file)
            else:
                plt.show()

            results.append(
                {
                    "reference": str(ref),
                    "sample": str(sample),
                    "layer": layer,
                    "n_positions": int(matrix.shape[1]),
                    "mismatch_layer": mismatch_layer if has_mismatch else None,
                    "mismatch_layer_present": bool(has_mismatch),
                    "output_path": str(out_file) if out_file is not None else None,
                }
            )

    return results
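Usage sketch (illustrative, not part of the package diff): a call to the clustermap plotter above under its default per-read filters; the read cap and output directory are placeholder choices.

from smftools.plotting.variant_plotting import plot_sequence_integer_encoding_clustermaps

entries = plot_sequence_integer_encoding_clustermaps(
    adata,  # AnnData with a "sequence_integer_encoding" layer, as documented above
    sort_by="hierarchical",            # cluster reads with Ward linkage
    max_unknown_fraction=0.5,          # drop positions that are >50% N/PAD
    max_reads=500,                     # cap reads per sample/reference panel
    show_position_axis=True,
    save_path="plots/sequence_clustermaps",  # placeholder output directory
)
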
def plot_variant_segment_clustermaps(
    adata,
    seq1_column: str,
    seq2_column: str,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",
    variant_segment_layer: str | None = None,
    read_span_layer: str = "read_span_mask",
    sort_by: str = "hierarchical",
    max_reads: int | None = None,
    save_path: str | Path | None = None,
    seq1_color: str = "#8B5CF6",
    seq2_color: str = "#4682b4",
    transition_color: str = "#F5F5DC",
    no_coverage_color: str = "#f0f0f0",
    ref1_marker_color: str = "white",
    ref2_marker_color: str = "black",
    breakpoint_marker_color: str = "red",
    marker_size: float = 4.0,
    show_position_axis: bool = False,
    position_axis_tick_target: int = 25,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 9,
    mismatch_type_obs_col: str | None = None,
    mismatch_type_colors: Dict[str, str] | None = None,
) -> List[Dict[str, Any]]:
    """Plot variant segment heatmaps with variant call and breakpoint overlays.

    Renders per-read segment blocks (seq1/seq2) as colored fills with beige
    transition zones between different-class blocks. Overlays white circles
    at seq1 variant call sites, black circles at seq2 sites, and red circles
    at putative breakpoint positions (midpoint of each transition zone).

    Args:
        adata: AnnData object.
        seq1_column: Name of reference 1 sequence column in ``adata.var``.
        seq2_column: Name of reference 2 sequence column in ``adata.var``.
        sample_col: Obs column for sample grouping.
        reference_col: Obs column for reference grouping.
        variant_segment_layer: Layer with variant segments (0=no coverage,
            1=seq1, 2=seq2, 3=transition zone). Auto-derived if None.
        read_span_layer: Layer containing read span masks.
        sort_by: Row sorting strategy — ``"none"`` or ``"hierarchical"``.
        max_reads: Maximum reads to display per panel.
        save_path: Directory to save plots. If None, displays interactively.
        seq1_color: Fill color for seq1 segments.
        seq2_color: Fill color for seq2 segments.
        transition_color: Fill color for transition zones between blocks.
        no_coverage_color: Fill color for positions with no coverage.
        ref1_marker_color: Circle color for seq1 variant call positions.
        ref2_marker_color: Circle color for seq2 variant call positions.
        breakpoint_marker_color: Circle color for putative breakpoint positions.
        marker_size: Size of overlay circles.
        show_position_axis: Whether to show genomic position labels on x-axis.
        position_axis_tick_target: Target number of x-axis ticks.
        xtick_rotation: Rotation angle for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        mismatch_type_obs_col: Optional obs column to annotate per read as an
            adjacent categorical strip.
        mismatch_type_colors: Optional mapping from mismatch class label to color.
            Missing labels fall back to gray.

    Returns:
        List of dicts with metadata about each generated plot.
    """
    output_prefix = f"{seq1_column}__{seq2_column}"
    if variant_segment_layer is None:
        variant_segment_layer = f"{output_prefix}_variant_segments"

    if variant_segment_layer not in adata.layers:
        logger.warning("Variant segment layer '%s' not found; skipping.", variant_segment_layer)
        return []

    if save_path is not None:
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

    logger.info("Plotting variant segment clustermaps.")

    suffix = "_strand_FASTA_base"
    seq1_label = seq1_column[: -len(suffix)] if seq1_column.endswith(suffix) else seq1_column
    seq2_label = seq2_column[: -len(suffix)] if seq2_column.endswith(suffix) else seq2_column

    # Colormap: 0=no coverage, 1=seq1, 2=seq2, 3=transition zone (beige)
    seg_cmap = colors.ListedColormap([no_coverage_color, seq1_color, seq2_color, transition_color])
    seg_norm = colors.BoundaryNorm([0, 0.5, 1.5, 2.5, 3.5], seg_cmap.N)

    if mismatch_type_colors is None:
        mismatch_type_colors = {
            "no_segment_mismatch": "#bdbdbd",
            "left_segment_mismatch": "#d73027",
            "right_segment_mismatch": "#4575b4",
            "middle_segment_mismatch": "#1a9850",
            "multi_segment_mismatch": "#f46d43",
        }

    variant_call_layer = f"{output_prefix}_variant_call"
    has_variant_calls = variant_call_layer in adata.layers

    results: List[Dict[str, Any]] = []

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            row_mask = (adata.obs[reference_col] == ref) & (adata.obs[sample_col] == sample)
            if not bool(row_mask.any()):
                continue

            subset = adata[row_mask, :].copy()
            seg_matrix = np.asarray(subset.layers[variant_segment_layer])
            n_reads, n_pos = seg_matrix.shape

            if max_reads is not None and n_reads > max_reads:
                subset = subset[:max_reads, :].copy()
                seg_matrix = seg_matrix[:max_reads]
                n_reads = max_reads

            # Filter out positions with no coverage in any read
            if read_span_layer in subset.layers:
                span_matrix = np.asarray(subset.layers[read_span_layer])
                if max_reads is not None:
                    span_matrix = span_matrix[:max_reads]
                col_has_coverage = np.any(span_matrix > 0, axis=0)
            else:
                col_has_coverage = np.any(seg_matrix > 0, axis=0)
            if not np.all(col_has_coverage):
                seg_matrix = seg_matrix[:, col_has_coverage]
                subset = subset[:, col_has_coverage].copy()

            # Load variant call matrix for overlays and clustering
            call_matrix = None
            if has_variant_calls:
                call_matrix = np.asarray(subset.layers[variant_call_layer])
                if max_reads is not None:
                    call_matrix = call_matrix[:max_reads]

            # Row ordering — cluster on variant call status
            order = np.arange(n_reads)
            if sort_by == "hierarchical" and n_reads > 1 and call_matrix is not None:
                try:
                    informative_cols = np.any(call_matrix > 0, axis=0)
                    if informative_cols.any():
                        cluster_data = call_matrix[:, informative_cols].astype(np.float64)
                        cluster_data[cluster_data == 1] = 0.0
                        cluster_data[cluster_data == 2] = 1.0
                        cluster_data[(cluster_data != 0.0) & (cluster_data != 1.0)] = 0.5
                        linkage = sch.linkage(cluster_data, method="ward")
                        order = sch.leaves_list(linkage)
                except Exception:
                    pass
            seg_matrix = seg_matrix[order]
            if call_matrix is not None:
                call_matrix = call_matrix[order]

            row_mismatch_labels = None
            row_mismatch_legend = []
            if mismatch_type_obs_col is not None and mismatch_type_obs_col in subset.obs:
                mm_series = subset.obs[mismatch_type_obs_col]
                mm_values = mm_series.astype("string").to_numpy()[order]
                row_mismatch_labels = []
                for val in mm_values:
                    if pd.isna(val):
                        row_mismatch_labels.append("unknown")
                    else:
                        row_mismatch_labels.append(str(val))
                row_mismatch_legend = list(dict.fromkeys(row_mismatch_labels))

            # Plot segment heatmap
            if row_mismatch_labels is None:
                fig, ax = plt.subplots(figsize=(16, 8))
                ax_mismatch = None
            else:
                fig = plt.figure(figsize=(16, 8))
                gs = fig.add_gridspec(1, 2, width_ratios=[0.6, 18], wspace=0.02)
                ax_mismatch = fig.add_subplot(gs[0, 0])
                ax = fig.add_subplot(gs[0, 1])
            sns.heatmap(
                seg_matrix.astype(np.float32),
                cmap=seg_cmap,
                norm=seg_norm,
                ax=ax,
                yticklabels=False,
                cbar=False,
            )

            if row_mismatch_labels is not None and ax_mismatch is not None:
                mismatch_categories = list(dict.fromkeys(row_mismatch_legend))
                mismatch_color_list = [
                    mismatch_type_colors.get(label, "#636363") for label in mismatch_categories
                ]
                mismatch_to_code = {label: i for i, label in enumerate(mismatch_categories)}
                mismatch_codes = np.array(
                    [mismatch_to_code[label] for label in row_mismatch_labels], dtype=np.int32
                ).reshape(-1, 1)
                mismatch_cmap = colors.ListedColormap(mismatch_color_list)
                mismatch_norm = colors.BoundaryNorm(
                    np.arange(-0.5, len(mismatch_categories) + 0.5, 1),
                    mismatch_cmap.N,
                )
                ax_mismatch.imshow(
                    mismatch_codes,
                    cmap=mismatch_cmap,
                    norm=mismatch_norm,
                    aspect="auto",
                    interpolation="nearest",
                    origin="upper",
                )
                ax_mismatch.set_xticks([])
                ax_mismatch.set_yticks([])
                ax_mismatch.set_title("Type", fontsize=8, pad=8)

            # Overlay variant call circles
            if call_matrix is not None:
                ref1_rows, ref1_cols = np.where(call_matrix == 1)
                ref2_rows, ref2_cols = np.where(call_matrix == 2)

                if len(ref1_rows) > 0:
                    ax.scatter(
                        ref1_cols + 0.5,
                        ref1_rows + 0.5,
                        c=ref1_marker_color,
                        s=marker_size,
                        marker="o",
                        edgecolors="gray",
                        linewidths=0.3,
                        zorder=3,
                        label=f"{seq1_label} call",
                    )
                if len(ref2_rows) > 0:
                    ax.scatter(
                        ref2_cols + 0.5,
                        ref2_rows + 0.5,
                        c=ref2_marker_color,
                        s=marker_size,
                        marker="o",
                        edgecolors="gray",
                        linewidths=0.3,
                        zorder=3,
                        label=f"{seq2_label} call",
                    )

            # Overlay breakpoint circles at the midpoint of each transition zone
            bp_rows_list = []
            bp_cols_list = []
            for r in range(n_reads):
                row = seg_matrix[r]
                # Find contiguous runs of value 3 (transition zones)
                is_trans = row == 3
                if not np.any(is_trans):
                    continue
                trans_positions = np.where(is_trans)[0]
                # Group into contiguous runs
                breaks = np.where(np.diff(trans_positions) > 1)[0] + 1
                runs = np.split(trans_positions, breaks)
                for run in runs:
                    midpoint = run[len(run) // 2]
                    bp_rows_list.append(r)
                    bp_cols_list.append(midpoint)

            if bp_rows_list:
                ax.scatter(
                    np.array(bp_cols_list) + 0.5,
                    np.array(bp_rows_list) + 0.5,
                    c=breakpoint_marker_color,
                    s=marker_size,
                    marker="o",
                    edgecolors="gray",
                    linewidths=0.3,
                    zorder=4,
                    label="Breakpoint",
                )

            ax.set_title(f"{ref} — {sample}", fontsize=10)
            ax.set_ylabel("Reads")

            if show_position_axis:
                n_cols = seg_matrix.shape[1]
                step = max(1, n_cols // position_axis_tick_target)
                sites = np.arange(0, n_cols, step)
                ax.set_xticks(sites)
                ax.set_xticklabels(
                    subset.var_names[sites].astype(str),
                    rotation=xtick_rotation,
                    fontsize=xtick_fontsize,
                )
            else:
                ax.set_xticks([])

            legend_elements = [
                patches.Patch(facecolor=seq1_color, label=f"{seq1_label} segment"),
                patches.Patch(facecolor=seq2_color, label=f"{seq2_label} segment"),
                patches.Patch(facecolor=transition_color, label="Transition zone"),
                patches.Patch(
                    facecolor=no_coverage_color,
                    edgecolor="gray",
                    linewidth=0.5,
                    label="No coverage",
                ),
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markerfacecolor=ref1_marker_color,
                    markeredgecolor="gray",
                    markersize=5,
                    label=f"{seq1_label} call",
                ),
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markerfacecolor=ref2_marker_color,
                    markeredgecolor="gray",
                    markersize=5,
                    label=f"{seq2_label} call",
                ),
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markerfacecolor=breakpoint_marker_color,
                    markeredgecolor="gray",
                    markersize=5,
                    label="Breakpoint",
                ),
            ]
            if row_mismatch_labels is not None:
                for label in row_mismatch_legend:
                    legend_elements.append(
                        patches.Patch(
                            facecolor=mismatch_type_colors.get(label, "#636363"),
                            label=f"Mismatch type: {label}",
                        )
                    )
            ax.legend(
                handles=legend_elements,
                loc="upper left",
                bbox_to_anchor=(1.02, 1.0),
                fontsize=7,
                framealpha=0.8,
                frameon=False,
            )

            fig.tight_layout(rect=(0, 0, 0.88, 1))

            out_file = None
            if save_path is not None:
                safe_name = f"{ref}__{sample}__variant_segments".replace("=", "").replace(",", "_")
                out_file = save_path / f"{safe_name}.png"
                fig.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)
                logger.info("Saved variant segment clustermap to %s.", out_file)
            else:
                plt.show()

            n_with_bp = int(np.sum(np.any(seg_matrix == 3, axis=1)))
            results.append(
                {
                    "reference": str(ref),
                    "sample": str(sample),
                    "n_reads": n_reads,
                    "n_reads_with_breakpoints": n_with_bp,
                    "output_path": str(out_file) if out_file is not None else None,
                }
            )

    return results
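Usage sketch (illustrative, not part of the package diff): a call to the variant segment plotter above keyed to two reference FASTA-base columns; the column names below are hypothetical examples of the `<reference>_strand_FASTA_base` columns this module expects, and the output directory is a placeholder.

from smftools.plotting.variant_plotting import plot_variant_segment_clustermaps

entries = plot_variant_segment_clustermaps(
    adata,
    seq1_column="refA_top_strand_FASTA_base",  # hypothetical reference 1 column
    seq2_column="refB_top_strand_FASTA_base",  # hypothetical reference 2 column
    sort_by="hierarchical",                    # cluster reads on variant call status
    max_reads=1000,
    show_position_axis=True,
    save_path="plots/variant_segments",        # placeholder output directory
)
for entry in entries:
    print(entry["reference"], entry["sample"], entry["n_reads_with_breakpoints"])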