smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,8 +1,20 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import math
|
|
2
|
-
from typing import
|
|
4
|
+
from typing import Optional, Tuple, Union
|
|
5
|
+
|
|
3
6
|
import numpy as np
|
|
4
|
-
|
|
5
|
-
from
|
|
7
|
+
|
|
8
|
+
from smftools.optional_imports import require
|
|
9
|
+
|
|
10
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
|
|
11
|
+
pdf_backend = require(
|
|
12
|
+
"matplotlib.backends.backend_pdf",
|
|
13
|
+
extra="plotting",
|
|
14
|
+
purpose="PDF output",
|
|
15
|
+
)
|
|
16
|
+
PdfPages = pdf_backend.PdfPages
|
|
17
|
+
|
|
6
18
|
|
|
7
19
|
def plot_hmm_size_contours(
|
|
8
20
|
adata,
|
|
@@ -36,32 +48,41 @@ def plot_hmm_size_contours(
|
|
|
36
48
|
|
|
37
49
|
Other args are the same as prior function.
|
|
38
50
|
"""
|
|
51
|
+
|
|
39
52
|
# --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
|
|
40
53
|
def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
|
|
54
|
+
"""Build a normalized 1D Gaussian kernel."""
|
|
41
55
|
if sigma <= 0 or sigma is None:
|
|
42
56
|
return np.array([1.0], dtype=float)
|
|
43
57
|
# choose kernel size = odd ~ 6*sigma (covers +/-3 sigma)
|
|
44
58
|
radius = max(1, int(math.ceil(3.0 * float(sigma))))
|
|
45
59
|
xs = np.arange(-radius, radius + 1, dtype=float)
|
|
46
|
-
k = np.exp(-(xs
|
|
60
|
+
k = np.exp(-(xs**2) / (2.0 * sigma**2))
|
|
47
61
|
k_sum = k.sum()
|
|
48
62
|
if k_sum <= eps:
|
|
49
63
|
k = np.array([1.0], dtype=float)
|
|
50
64
|
k_sum = 1.0
|
|
51
65
|
return k / k_sum
|
|
52
66
|
|
|
53
|
-
def _smooth_with_numpy_separable(
|
|
67
|
+
def _smooth_with_numpy_separable(
|
|
68
|
+
Z: np.ndarray, sigma_len: float, sigma_pos: float
|
|
69
|
+
) -> np.ndarray:
|
|
70
|
+
"""Apply separable Gaussian smoothing with NumPy."""
|
|
54
71
|
# Z shape: (n_lengths, n_positions)
|
|
55
72
|
out = Z.copy()
|
|
56
73
|
# smooth along length axis (axis=0)
|
|
57
74
|
if sigma_len and sigma_len > 0:
|
|
58
75
|
k_len = _gaussian_1d_kernel(sigma_len)
|
|
59
76
|
# convolve each column
|
|
60
|
-
out = np.apply_along_axis(
|
|
77
|
+
out = np.apply_along_axis(
|
|
78
|
+
lambda col: np.convolve(col, k_len, mode="same"), axis=0, arr=out
|
|
79
|
+
)
|
|
61
80
|
# smooth along position axis (axis=1)
|
|
62
81
|
if sigma_pos and sigma_pos > 0:
|
|
63
82
|
k_pos = _gaussian_1d_kernel(sigma_pos)
|
|
64
|
-
out = np.apply_along_axis(
|
|
83
|
+
out = np.apply_along_axis(
|
|
84
|
+
lambda row: np.convolve(row, k_pos, mode="same"), axis=1, arr=out
|
|
85
|
+
)
|
|
65
86
|
return out
|
|
66
87
|
|
|
67
88
|
# prefer scipy.ndimage if available (faster and better boundary handling)
|
|
@@ -69,11 +90,13 @@ def plot_hmm_size_contours(
|
|
|
69
90
|
if use_scipy_if_available:
|
|
70
91
|
try:
|
|
71
92
|
from scipy.ndimage import gaussian_filter as _scipy_gaussian_filter
|
|
93
|
+
|
|
72
94
|
_have_scipy = True
|
|
73
95
|
except Exception:
|
|
74
96
|
_have_scipy = False
|
|
75
97
|
|
|
76
98
|
def _smooth_Z(Z: np.ndarray, sigma_len: float, sigma_pos: float) -> np.ndarray:
|
|
99
|
+
"""Smooth a matrix using scipy if available or NumPy fallback."""
|
|
77
100
|
if (sigma_len is None or sigma_len == 0) and (sigma_pos is None or sigma_pos == 0):
|
|
78
101
|
return Z
|
|
79
102
|
if _have_scipy:
|
|
@@ -84,8 +107,16 @@ def plot_hmm_size_contours(
|
|
|
84
107
|
return _smooth_with_numpy_separable(Z, float(sigma_len or 0.0), float(sigma_pos or 0.0))
|
|
85
108
|
|
|
86
109
|
# --- gather unique ordered labels ---
|
|
87
|
-
samples =
|
|
88
|
-
|
|
110
|
+
samples = (
|
|
111
|
+
list(adata.obs[sample_col].cat.categories)
|
|
112
|
+
if getattr(adata.obs[sample_col], "dtype", None) == "category"
|
|
113
|
+
else list(pd.Categorical(adata.obs[sample_col]).categories)
|
|
114
|
+
)
|
|
115
|
+
refs = (
|
|
116
|
+
list(adata.obs[ref_obs_col].cat.categories)
|
|
117
|
+
if getattr(adata.obs[ref_obs_col], "dtype", None) == "category"
|
|
118
|
+
else list(pd.Categorical(adata.obs[ref_obs_col]).categories)
|
|
119
|
+
)
|
|
89
120
|
|
|
90
121
|
n_samples = len(samples)
|
|
91
122
|
n_refs = len(refs)
|
|
@@ -102,6 +133,7 @@ def plot_hmm_size_contours(
|
|
|
102
133
|
|
|
103
134
|
# helper to get dense layer array for subset
|
|
104
135
|
def _get_layer_array(layer):
|
|
136
|
+
"""Convert a layer to a dense NumPy array."""
|
|
105
137
|
arr = layer
|
|
106
138
|
# sparse -> toarray
|
|
107
139
|
if hasattr(arr, "toarray"):
|
|
@@ -146,7 +178,7 @@ def plot_hmm_size_contours(
|
|
|
146
178
|
fig_w = n_refs * figsize_per_cell[0]
|
|
147
179
|
fig_h = rows_on_page * figsize_per_cell[1]
|
|
148
180
|
fig, axes = plt.subplots(rows_on_page, n_refs, figsize=(fig_w, fig_h), squeeze=False)
|
|
149
|
-
fig.suptitle(f"HMM size contours (page {p+1}/{pages})", fontsize=12)
|
|
181
|
+
fig.suptitle(f"HMM size contours (page {p + 1}/{pages})", fontsize=12)
|
|
150
182
|
|
|
151
183
|
# for each panel compute p(length | position)
|
|
152
184
|
for i_row, sample in enumerate(page_samples):
|
|
@@ -160,7 +192,9 @@ def plot_hmm_size_contours(
|
|
|
160
192
|
ax.set_title(f"{sample} / {ref}")
|
|
161
193
|
continue
|
|
162
194
|
|
|
163
|
-
row_idx = np.nonzero(
|
|
195
|
+
row_idx = np.nonzero(
|
|
196
|
+
panel_mask.values if hasattr(panel_mask, "values") else np.asarray(panel_mask)
|
|
197
|
+
)[0]
|
|
164
198
|
if row_idx.size == 0:
|
|
165
199
|
ax.text(0.5, 0.5, "no reads", ha="center", va="center")
|
|
166
200
|
ax.set_title(f"{sample} / {ref}")
|
|
@@ -178,7 +212,9 @@ def plot_hmm_size_contours(
|
|
|
178
212
|
max_len_here = min(max_len, max_len_local)
|
|
179
213
|
|
|
180
214
|
lengths_range = np.arange(1, max_len_here + 1, dtype=int)
|
|
181
|
-
Z = np.zeros(
|
|
215
|
+
Z = np.zeros(
|
|
216
|
+
(len(lengths_range), n_positions), dtype=float
|
|
217
|
+
) # rows=length, cols=pos
|
|
182
218
|
|
|
183
219
|
# fill Z by efficient bincount across columns
|
|
184
220
|
for j in range(n_positions):
|
|
@@ -222,7 +258,9 @@ def plot_hmm_size_contours(
|
|
|
222
258
|
dy = 1.0
|
|
223
259
|
y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])
|
|
224
260
|
|
|
225
|
-
pcm = ax.pcolormesh(
|
|
261
|
+
pcm = ax.pcolormesh(
|
|
262
|
+
x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
|
|
263
|
+
)
|
|
226
264
|
ax.set_title(f"{sample} / {ref}")
|
|
227
265
|
ax.set_ylabel("length")
|
|
228
266
|
if i_row == rows_on_page - 1:
|
|
@@ -243,9 +281,10 @@ def plot_hmm_size_contours(
|
|
|
243
281
|
# saving per page if requested
|
|
244
282
|
if save_path is not None:
|
|
245
283
|
import os
|
|
284
|
+
|
|
246
285
|
os.makedirs(save_path, exist_ok=True)
|
|
247
286
|
if save_each_page:
|
|
248
|
-
fname = f"hmm_size_page_{p+1:03d}.png"
|
|
287
|
+
fname = f"hmm_size_page_{p + 1:03d}.png"
|
|
249
288
|
out = os.path.join(save_path, fname)
|
|
250
289
|
fig.savefig(out, dpi=dpi, bbox_inches="tight")
|
|
251
290
|
|
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from smftools.optional_imports import require
|
|
4
|
+
|
|
5
|
+
|
|
1
6
|
def plot_volcano_relative_risk(
|
|
2
7
|
results_dict,
|
|
3
8
|
save_path=None,
|
|
@@ -20,10 +25,10 @@ def plot_volcano_relative_risk(
|
|
|
20
25
|
xlim (tuple): Optional x-axis limit.
|
|
21
26
|
ylim (tuple): Optional y-axis limit.
|
|
22
27
|
"""
|
|
23
|
-
import matplotlib.pyplot as plt
|
|
24
|
-
import numpy as np
|
|
25
28
|
import os
|
|
26
29
|
|
|
30
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
|
|
31
|
+
|
|
27
32
|
for ref, group_results in results_dict.items():
|
|
28
33
|
for group_label, (results_df, _) in group_results.items():
|
|
29
34
|
if results_df.empty:
|
|
@@ -31,8 +36,8 @@ def plot_volcano_relative_risk(
|
|
|
31
36
|
continue
|
|
32
37
|
|
|
33
38
|
# Split by site type
|
|
34
|
-
gpc_df = results_df[results_df[
|
|
35
|
-
cpg_df = results_df[results_df[
|
|
39
|
+
gpc_df = results_df[results_df["GpC_Site"]]
|
|
40
|
+
cpg_df = results_df[results_df["CpG_Site"]]
|
|
36
41
|
|
|
37
42
|
fig, ax = plt.subplots(figsize=(12, 6))
|
|
38
43
|
|
|
@@ -43,29 +48,29 @@ def plot_volcano_relative_risk(
|
|
|
43
48
|
|
|
44
49
|
# GpC as circles
|
|
45
50
|
sc1 = ax.scatter(
|
|
46
|
-
gpc_df[
|
|
47
|
-
gpc_df[
|
|
48
|
-
c=gpc_df[
|
|
49
|
-
cmap=
|
|
50
|
-
edgecolor=
|
|
51
|
+
gpc_df["Genomic_Position"],
|
|
52
|
+
gpc_df["log2_Relative_Risk"],
|
|
53
|
+
c=gpc_df["-log10_Adj_P"],
|
|
54
|
+
cmap="coolwarm",
|
|
55
|
+
edgecolor="k",
|
|
51
56
|
s=40,
|
|
52
|
-
marker=
|
|
53
|
-
label=
|
|
57
|
+
marker="o",
|
|
58
|
+
label="GpC",
|
|
54
59
|
)
|
|
55
60
|
|
|
56
61
|
# CpG as stars
|
|
57
62
|
sc2 = ax.scatter(
|
|
58
|
-
cpg_df[
|
|
59
|
-
cpg_df[
|
|
60
|
-
c=cpg_df[
|
|
61
|
-
cmap=
|
|
62
|
-
edgecolor=
|
|
63
|
+
cpg_df["Genomic_Position"],
|
|
64
|
+
cpg_df["log2_Relative_Risk"],
|
|
65
|
+
c=cpg_df["-log10_Adj_P"],
|
|
66
|
+
cmap="coolwarm",
|
|
67
|
+
edgecolor="k",
|
|
63
68
|
s=60,
|
|
64
|
-
marker=
|
|
65
|
-
label=
|
|
69
|
+
marker="*",
|
|
70
|
+
label="CpG",
|
|
66
71
|
)
|
|
67
72
|
|
|
68
|
-
ax.axhline(y=0, color=
|
|
73
|
+
ax.axhline(y=0, color="gray", linestyle="--")
|
|
69
74
|
ax.set_xlabel("Genomic Position")
|
|
70
75
|
ax.set_ylabel("log2(Relative Risk)")
|
|
71
76
|
ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")
|
|
@@ -75,8 +80,8 @@ def plot_volcano_relative_risk(
|
|
|
75
80
|
if ylim:
|
|
76
81
|
ax.set_ylim(ylim)
|
|
77
82
|
|
|
78
|
-
ax.spines[
|
|
79
|
-
ax.spines[
|
|
83
|
+
ax.spines["top"].set_visible(False)
|
|
84
|
+
ax.spines["right"].set_visible(False)
|
|
80
85
|
|
|
81
86
|
cbar = plt.colorbar(sc1, ax=ax)
|
|
82
87
|
cbar.set_label("-log10(Adjusted P-Value)")
|
|
@@ -87,13 +92,19 @@ def plot_volcano_relative_risk(
|
|
|
87
92
|
# Save if requested
|
|
88
93
|
if save_path:
|
|
89
94
|
os.makedirs(save_path, exist_ok=True)
|
|
90
|
-
safe_name =
|
|
95
|
+
safe_name = (
|
|
96
|
+
f"{ref}_{group_label}".replace("=", "")
|
|
97
|
+
.replace("__", "_")
|
|
98
|
+
.replace(",", "_")
|
|
99
|
+
.replace(" ", "_")
|
|
100
|
+
)
|
|
91
101
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
92
102
|
plt.savefig(out_file, dpi=300)
|
|
93
103
|
print(f"Saved: {out_file}")
|
|
94
104
|
|
|
95
105
|
plt.show()
|
|
96
106
|
|
|
107
|
+
|
|
97
108
|
def plot_bar_relative_risk(
|
|
98
109
|
results_dict,
|
|
99
110
|
sort_by_position=True,
|
|
@@ -102,7 +113,7 @@ def plot_bar_relative_risk(
|
|
|
102
113
|
save_path=None,
|
|
103
114
|
highlight_regions=None, # List of (start, end) tuples
|
|
104
115
|
highlight_color="lightgray",
|
|
105
|
-
highlight_alpha=0.3
|
|
116
|
+
highlight_alpha=0.3,
|
|
106
117
|
):
|
|
107
118
|
"""
|
|
108
119
|
Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.
|
|
@@ -116,10 +127,10 @@ def plot_bar_relative_risk(
|
|
|
116
127
|
highlight_color (str): Color of shaded region.
|
|
117
128
|
highlight_alpha (float): Transparency of shaded region.
|
|
118
129
|
"""
|
|
119
|
-
import matplotlib.pyplot as plt
|
|
120
|
-
import numpy as np
|
|
121
130
|
import os
|
|
122
131
|
|
|
132
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
|
|
133
|
+
|
|
123
134
|
for ref, group_data in results_dict.items():
|
|
124
135
|
for group_label, (df, _) in group_data.items():
|
|
125
136
|
if df.empty:
|
|
@@ -127,14 +138,14 @@ def plot_bar_relative_risk(
|
|
|
127
138
|
continue
|
|
128
139
|
|
|
129
140
|
df = df.copy()
|
|
130
|
-
df[
|
|
141
|
+
df["Genomic_Position"] = df["Genomic_Position"].astype(int)
|
|
131
142
|
|
|
132
143
|
if sort_by_position:
|
|
133
|
-
df = df.sort_values(
|
|
144
|
+
df = df.sort_values("Genomic_Position")
|
|
134
145
|
|
|
135
|
-
gpc_mask = df[
|
|
136
|
-
cpg_mask = df[
|
|
137
|
-
both_mask = df[
|
|
146
|
+
gpc_mask = df["GpC_Site"] & ~df["CpG_Site"]
|
|
147
|
+
cpg_mask = df["CpG_Site"] & ~df["GpC_Site"]
|
|
148
|
+
both_mask = df["GpC_Site"] & df["CpG_Site"]
|
|
138
149
|
|
|
139
150
|
fig, ax = plt.subplots(figsize=(14, 6))
|
|
140
151
|
|
|
@@ -145,36 +156,36 @@ def plot_bar_relative_risk(
|
|
|
145
156
|
|
|
146
157
|
# Bar plots
|
|
147
158
|
ax.bar(
|
|
148
|
-
df[
|
|
149
|
-
df[
|
|
159
|
+
df["Genomic_Position"][gpc_mask],
|
|
160
|
+
df["log2_Relative_Risk"][gpc_mask],
|
|
150
161
|
width=10,
|
|
151
|
-
color=
|
|
152
|
-
label=
|
|
153
|
-
edgecolor=
|
|
162
|
+
color="steelblue",
|
|
163
|
+
label="GpC Site",
|
|
164
|
+
edgecolor="black",
|
|
154
165
|
)
|
|
155
166
|
|
|
156
167
|
ax.bar(
|
|
157
|
-
df[
|
|
158
|
-
df[
|
|
168
|
+
df["Genomic_Position"][cpg_mask],
|
|
169
|
+
df["log2_Relative_Risk"][cpg_mask],
|
|
159
170
|
width=10,
|
|
160
|
-
color=
|
|
161
|
-
label=
|
|
162
|
-
edgecolor=
|
|
171
|
+
color="darkorange",
|
|
172
|
+
label="CpG Site",
|
|
173
|
+
edgecolor="black",
|
|
163
174
|
)
|
|
164
175
|
|
|
165
176
|
if both_mask.any():
|
|
166
177
|
ax.bar(
|
|
167
|
-
df[
|
|
168
|
-
df[
|
|
178
|
+
df["Genomic_Position"][both_mask],
|
|
179
|
+
df["log2_Relative_Risk"][both_mask],
|
|
169
180
|
width=10,
|
|
170
|
-
color=
|
|
171
|
-
label=
|
|
172
|
-
edgecolor=
|
|
181
|
+
color="purple",
|
|
182
|
+
label="GpC + CpG",
|
|
183
|
+
edgecolor="black",
|
|
173
184
|
)
|
|
174
185
|
|
|
175
|
-
ax.axhline(y=0, color=
|
|
176
|
-
ax.set_xlabel(
|
|
177
|
-
ax.set_ylabel(
|
|
186
|
+
ax.axhline(y=0, color="gray", linestyle="--")
|
|
187
|
+
ax.set_xlabel("Genomic Position")
|
|
188
|
+
ax.set_ylabel("log2(Relative Risk)")
|
|
178
189
|
ax.set_title(f"{ref} — {group_label}")
|
|
179
190
|
ax.legend()
|
|
180
191
|
|
|
@@ -183,20 +194,23 @@ def plot_bar_relative_risk(
|
|
|
183
194
|
if ylim:
|
|
184
195
|
ax.set_ylim(ylim)
|
|
185
196
|
|
|
186
|
-
ax.spines[
|
|
187
|
-
ax.spines[
|
|
197
|
+
ax.spines["top"].set_visible(False)
|
|
198
|
+
ax.spines["right"].set_visible(False)
|
|
188
199
|
|
|
189
200
|
plt.tight_layout()
|
|
190
201
|
|
|
191
202
|
if save_path:
|
|
192
203
|
os.makedirs(save_path, exist_ok=True)
|
|
193
|
-
safe_name =
|
|
204
|
+
safe_name = (
|
|
205
|
+
f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
|
|
206
|
+
)
|
|
194
207
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
195
208
|
plt.savefig(out_file, dpi=300)
|
|
196
209
|
print(f"📁 Saved: {out_file}")
|
|
197
210
|
|
|
198
211
|
plt.show()
|
|
199
212
|
|
|
213
|
+
|
|
200
214
|
def plot_positionwise_matrix(
|
|
201
215
|
adata,
|
|
202
216
|
key="positionwise_result",
|
|
@@ -210,35 +224,40 @@ def plot_positionwise_matrix(
|
|
|
210
224
|
xtick_step=10,
|
|
211
225
|
ytick_step=10,
|
|
212
226
|
save_path=None,
|
|
213
|
-
highlight_position=None,
|
|
214
|
-
highlight_axis="row",
|
|
215
|
-
annotate_points=False
|
|
227
|
+
highlight_position=None, # Can be a single int/float or list of them
|
|
228
|
+
highlight_axis="row", # "row" or "column"
|
|
229
|
+
annotate_points=False, # ✅ New option
|
|
216
230
|
):
|
|
217
231
|
"""
|
|
218
232
|
Plots positionwise matrices stored in adata.uns[key], with an optional line plot
|
|
219
233
|
for specified row(s) or column(s), and highlights them on the heatmap.
|
|
220
234
|
"""
|
|
221
|
-
import
|
|
222
|
-
|
|
235
|
+
import os
|
|
236
|
+
|
|
223
237
|
import numpy as np
|
|
224
238
|
import pandas as pd
|
|
225
|
-
|
|
239
|
+
|
|
240
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
|
|
241
|
+
sns = require("seaborn", extra="plotting", purpose="position stats plots")
|
|
226
242
|
|
|
227
243
|
def find_closest_index(index, target):
|
|
244
|
+
"""Find the index value closest to a target value."""
|
|
228
245
|
index_vals = pd.to_numeric(index, errors="coerce")
|
|
229
246
|
target_val = pd.to_numeric([target], errors="coerce")[0]
|
|
230
247
|
diffs = pd.Series(np.abs(index_vals - target_val), index=index)
|
|
231
248
|
return diffs.idxmin()
|
|
232
249
|
|
|
233
250
|
# Ensure highlight_position is a list
|
|
234
|
-
if highlight_position is not None and not isinstance(
|
|
251
|
+
if highlight_position is not None and not isinstance(
|
|
252
|
+
highlight_position, (list, tuple, np.ndarray)
|
|
253
|
+
):
|
|
235
254
|
highlight_position = [highlight_position]
|
|
236
255
|
|
|
237
256
|
for group, mat_df in adata.uns[key].items():
|
|
238
257
|
mat = mat_df.copy()
|
|
239
258
|
|
|
240
259
|
if log_transform:
|
|
241
|
-
with np.errstate(divide=
|
|
260
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
|
242
261
|
if log_base == "log1p":
|
|
243
262
|
mat = np.log1p(mat)
|
|
244
263
|
elif log_base == "log2":
|
|
@@ -276,7 +295,7 @@ def plot_positionwise_matrix(
|
|
|
276
295
|
vmin=vmin,
|
|
277
296
|
vmax=vmax,
|
|
278
297
|
cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
|
|
279
|
-
ax=heat_ax
|
|
298
|
+
ax=heat_ax,
|
|
280
299
|
)
|
|
281
300
|
|
|
282
301
|
heat_ax.set_title(f"{key} — {group}", pad=20)
|
|
@@ -295,17 +314,27 @@ def plot_positionwise_matrix(
|
|
|
295
314
|
series = mat.loc[closest]
|
|
296
315
|
x_vals = pd.to_numeric(series.index, errors="coerce")
|
|
297
316
|
idx = mat.index.get_loc(closest)
|
|
298
|
-
heat_ax.axhline(
|
|
317
|
+
heat_ax.axhline(
|
|
318
|
+
idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
|
|
319
|
+
)
|
|
299
320
|
label = f"Row {pos} → {closest}"
|
|
300
321
|
else:
|
|
301
322
|
closest = find_closest_index(mat.columns, pos)
|
|
302
323
|
series = mat[closest]
|
|
303
324
|
x_vals = pd.to_numeric(series.index, errors="coerce")
|
|
304
325
|
idx = mat.columns.get_loc(closest)
|
|
305
|
-
heat_ax.axvline(
|
|
326
|
+
heat_ax.axvline(
|
|
327
|
+
idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
|
|
328
|
+
)
|
|
306
329
|
label = f"Col {pos} → {closest}"
|
|
307
330
|
|
|
308
|
-
line = line_ax.plot(
|
|
331
|
+
line = line_ax.plot(
|
|
332
|
+
x_vals,
|
|
333
|
+
series.values,
|
|
334
|
+
marker="o",
|
|
335
|
+
label=label,
|
|
336
|
+
color=colors[i % len(colors)],
|
|
337
|
+
)
|
|
309
338
|
|
|
310
339
|
# Annotate each point
|
|
311
340
|
if annotate_points:
|
|
@@ -316,12 +345,18 @@ def plot_positionwise_matrix(
|
|
|
316
345
|
xy=(x, y),
|
|
317
346
|
textcoords="offset points",
|
|
318
347
|
xytext=(0, 5),
|
|
319
|
-
ha=
|
|
320
|
-
fontsize=8
|
|
348
|
+
ha="center",
|
|
349
|
+
fontsize=8,
|
|
321
350
|
)
|
|
322
351
|
except Exception as e:
|
|
323
|
-
line_ax.text(
|
|
324
|
-
|
|
352
|
+
line_ax.text(
|
|
353
|
+
0.5,
|
|
354
|
+
0.5,
|
|
355
|
+
f"⚠️ Error plotting {highlight_axis} @ {pos}",
|
|
356
|
+
ha="center",
|
|
357
|
+
va="center",
|
|
358
|
+
fontsize=10,
|
|
359
|
+
)
|
|
325
360
|
print(f"Error plotting line for {highlight_axis}={pos}: {e}")
|
|
326
361
|
|
|
327
362
|
line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
|
|
@@ -342,6 +377,7 @@ def plot_positionwise_matrix(
|
|
|
342
377
|
|
|
343
378
|
plt.show()
|
|
344
379
|
|
|
380
|
+
|
|
345
381
|
def plot_positionwise_matrix_grid(
|
|
346
382
|
adata,
|
|
347
383
|
key,
|
|
@@ -356,32 +392,63 @@ def plot_positionwise_matrix_grid(
|
|
|
356
392
|
xtick_step=10,
|
|
357
393
|
ytick_step=10,
|
|
358
394
|
parallel=False,
|
|
359
|
-
max_threads=None
|
|
395
|
+
max_threads=None,
|
|
360
396
|
):
|
|
361
|
-
|
|
362
|
-
|
|
397
|
+
"""Plot a grid of positionwise matrices grouped by metadata.
|
|
398
|
+
|
|
399
|
+
Args:
|
|
400
|
+
adata: AnnData containing matrices in ``adata.uns``.
|
|
401
|
+
key: Key for positionwise matrices.
|
|
402
|
+
outer_keys: Keys for outer grouping.
|
|
403
|
+
inner_keys: Keys for inner grouping.
|
|
404
|
+
log_transform: Optional log transform (``log2`` or ``log1p``).
|
|
405
|
+
vmin: Minimum color scale value.
|
|
406
|
+
vmax: Maximum color scale value.
|
|
407
|
+
cmap: Matplotlib colormap.
|
|
408
|
+
save_path: Optional path to save plots.
|
|
409
|
+
figsize: Figure size.
|
|
410
|
+
xtick_step: X-axis tick step.
|
|
411
|
+
ytick_step: Y-axis tick step.
|
|
412
|
+
parallel: Whether to plot in parallel.
|
|
413
|
+
max_threads: Max thread count for parallel plotting.
|
|
414
|
+
"""
|
|
415
|
+
import os
|
|
416
|
+
|
|
363
417
|
import numpy as np
|
|
364
418
|
import pandas as pd
|
|
365
|
-
import os
|
|
366
|
-
from matplotlib.gridspec import GridSpec
|
|
367
419
|
from joblib import Parallel, delayed
|
|
368
420
|
|
|
421
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
|
|
422
|
+
sns = require("seaborn", extra="plotting", purpose="position stats plots")
|
|
423
|
+
grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="position stats plots")
|
|
424
|
+
GridSpec = grid_spec.GridSpec
|
|
425
|
+
|
|
369
426
|
matrices = adata.uns[key]
|
|
370
427
|
group_labels = list(matrices.keys())
|
|
371
428
|
|
|
372
|
-
parsed_inner = pd.DataFrame(
|
|
373
|
-
|
|
429
|
+
parsed_inner = pd.DataFrame(
|
|
430
|
+
[dict(zip(inner_keys, g.split("_")[-len(inner_keys) :])) for g in group_labels]
|
|
431
|
+
)
|
|
432
|
+
parsed_outer = pd.Series(
|
|
433
|
+
["_".join(g.split("_")[: -len(inner_keys)]) for g in group_labels], name="outer"
|
|
434
|
+
)
|
|
374
435
|
parsed = pd.concat([parsed_outer, parsed_inner], axis=1)
|
|
375
436
|
|
|
376
437
|
def plot_one_grid(outer_label):
|
|
377
|
-
|
|
378
|
-
selected["
|
|
438
|
+
"""Plot one grid for a specific outer label."""
|
|
439
|
+
selected = parsed[parsed["outer"] == outer_label].copy()
|
|
440
|
+
selected["group_str"] = [
|
|
441
|
+
f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}"
|
|
442
|
+
for _, row in selected.iterrows()
|
|
443
|
+
]
|
|
379
444
|
|
|
380
445
|
row_vals = sorted(selected[inner_keys[0]].unique())
|
|
381
446
|
col_vals = sorted(selected[inner_keys[1]].unique())
|
|
382
447
|
|
|
383
448
|
fig = plt.figure(figsize=figsize)
|
|
384
|
-
gs = GridSpec(
|
|
449
|
+
gs = GridSpec(
|
|
450
|
+
len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3
|
|
451
|
+
)
|
|
385
452
|
axes = np.empty((len(row_vals), len(col_vals)), dtype=object)
|
|
386
453
|
|
|
387
454
|
local_vmin, local_vmax = vmin, vmax
|
|
@@ -397,10 +464,7 @@ def plot_positionwise_matrix_grid(
|
|
|
397
464
|
local_vmin = -vmax_auto if vmin is None else vmin
|
|
398
465
|
local_vmax = vmax_auto if vmax is None else vmax
|
|
399
466
|
|
|
400
|
-
cbar_label = {
|
|
401
|
-
"log2": "log2(Value)",
|
|
402
|
-
"log1p": "log1p(Value)"
|
|
403
|
-
}.get(log_transform, "Value")
|
|
467
|
+
cbar_label = {"log2": "log2(Value)", "log1p": "log1p(Value)"}.get(log_transform, "Value")
|
|
404
468
|
|
|
405
469
|
cbar_ax = fig.add_subplot(gs[:, -1])
|
|
406
470
|
|
|
@@ -431,9 +495,11 @@ def plot_positionwise_matrix_grid(
|
|
|
431
495
|
vmax=local_vmax,
|
|
432
496
|
cbar=(i == 0 and j == 0),
|
|
433
497
|
cbar_ax=cbar_ax if (i == 0 and j == 0) else None,
|
|
434
|
-
cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""}
|
|
498
|
+
cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""},
|
|
499
|
+
)
|
|
500
|
+
ax.set_title(
|
|
501
|
+
f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8
|
|
435
502
|
)
|
|
436
|
-
ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)
|
|
437
503
|
|
|
438
504
|
xticks = data.columns.astype(int)
|
|
439
505
|
yticks = data.index.astype(int)
|
|
@@ -448,15 +514,17 @@ def plot_positionwise_matrix_grid(
|
|
|
448
514
|
if save_path:
|
|
449
515
|
os.makedirs(save_path, exist_ok=True)
|
|
450
516
|
fname = outer_label.replace("_", "").replace("=", "") + ".png"
|
|
451
|
-
plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches=
|
|
517
|
+
plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches="tight")
|
|
452
518
|
print(f"Saved {fname}")
|
|
453
519
|
|
|
454
520
|
plt.close(fig)
|
|
455
521
|
|
|
456
522
|
if parallel:
|
|
457
|
-
Parallel(n_jobs=max_threads)(
|
|
523
|
+
Parallel(n_jobs=max_threads)(
|
|
524
|
+
delayed(plot_one_grid)(outer_label) for outer_label in parsed["outer"].unique()
|
|
525
|
+
)
|
|
458
526
|
else:
|
|
459
|
-
for outer_label in parsed[
|
|
527
|
+
for outer_label in parsed["outer"].unique():
|
|
460
528
|
plot_one_grid(outer_label)
|
|
461
529
|
|
|
462
|
-
print("Finished plotting all grids.")
|
|
530
|
+
print("Finished plotting all grids.")
|