smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
import math
|
|
2
|
-
from typing import
|
|
3
|
-
|
|
2
|
+
from typing import Optional, Tuple, Union
|
|
3
|
+
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
5
6
|
from matplotlib.backends.backend_pdf import PdfPages
|
|
6
7
|
|
|
8
|
+
|
|
7
9
|
def plot_hmm_size_contours(
|
|
8
10
|
adata,
|
|
9
11
|
length_layer: str,
|
|
@@ -36,32 +38,41 @@ def plot_hmm_size_contours(
|
|
|
36
38
|
|
|
37
39
|
Other args are the same as prior function.
|
|
38
40
|
"""
|
|
41
|
+
|
|
39
42
|
# --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
|
|
40
43
|
def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
|
|
44
|
+
"""Build a normalized 1D Gaussian kernel."""
|
|
41
45
|
if sigma <= 0 or sigma is None:
|
|
42
46
|
return np.array([1.0], dtype=float)
|
|
43
47
|
# choose kernel size = odd ~ 6*sigma (covers +/-3 sigma)
|
|
44
48
|
radius = max(1, int(math.ceil(3.0 * float(sigma))))
|
|
45
49
|
xs = np.arange(-radius, radius + 1, dtype=float)
|
|
46
|
-
k = np.exp(-(xs
|
|
50
|
+
k = np.exp(-(xs**2) / (2.0 * sigma**2))
|
|
47
51
|
k_sum = k.sum()
|
|
48
52
|
if k_sum <= eps:
|
|
49
53
|
k = np.array([1.0], dtype=float)
|
|
50
54
|
k_sum = 1.0
|
|
51
55
|
return k / k_sum
|
|
52
56
|
|
|
53
|
-
def _smooth_with_numpy_separable(
|
|
57
|
+
def _smooth_with_numpy_separable(
|
|
58
|
+
Z: np.ndarray, sigma_len: float, sigma_pos: float
|
|
59
|
+
) -> np.ndarray:
|
|
60
|
+
"""Apply separable Gaussian smoothing with NumPy."""
|
|
54
61
|
# Z shape: (n_lengths, n_positions)
|
|
55
62
|
out = Z.copy()
|
|
56
63
|
# smooth along length axis (axis=0)
|
|
57
64
|
if sigma_len and sigma_len > 0:
|
|
58
65
|
k_len = _gaussian_1d_kernel(sigma_len)
|
|
59
66
|
# convolve each column
|
|
60
|
-
out = np.apply_along_axis(
|
|
67
|
+
out = np.apply_along_axis(
|
|
68
|
+
lambda col: np.convolve(col, k_len, mode="same"), axis=0, arr=out
|
|
69
|
+
)
|
|
61
70
|
# smooth along position axis (axis=1)
|
|
62
71
|
if sigma_pos and sigma_pos > 0:
|
|
63
72
|
k_pos = _gaussian_1d_kernel(sigma_pos)
|
|
64
|
-
out = np.apply_along_axis(
|
|
73
|
+
out = np.apply_along_axis(
|
|
74
|
+
lambda row: np.convolve(row, k_pos, mode="same"), axis=1, arr=out
|
|
75
|
+
)
|
|
65
76
|
return out
|
|
66
77
|
|
|
67
78
|
# prefer scipy.ndimage if available (faster and better boundary handling)
|
|
@@ -69,11 +80,13 @@ def plot_hmm_size_contours(
|
|
|
69
80
|
if use_scipy_if_available:
|
|
70
81
|
try:
|
|
71
82
|
from scipy.ndimage import gaussian_filter as _scipy_gaussian_filter
|
|
83
|
+
|
|
72
84
|
_have_scipy = True
|
|
73
85
|
except Exception:
|
|
74
86
|
_have_scipy = False
|
|
75
87
|
|
|
76
88
|
def _smooth_Z(Z: np.ndarray, sigma_len: float, sigma_pos: float) -> np.ndarray:
|
|
89
|
+
"""Smooth a matrix using scipy if available or NumPy fallback."""
|
|
77
90
|
if (sigma_len is None or sigma_len == 0) and (sigma_pos is None or sigma_pos == 0):
|
|
78
91
|
return Z
|
|
79
92
|
if _have_scipy:
|
|
@@ -84,8 +97,16 @@ def plot_hmm_size_contours(
|
|
|
84
97
|
return _smooth_with_numpy_separable(Z, float(sigma_len or 0.0), float(sigma_pos or 0.0))
|
|
85
98
|
|
|
86
99
|
# --- gather unique ordered labels ---
|
|
87
|
-
samples =
|
|
88
|
-
|
|
100
|
+
samples = (
|
|
101
|
+
list(adata.obs[sample_col].cat.categories)
|
|
102
|
+
if getattr(adata.obs[sample_col], "dtype", None) == "category"
|
|
103
|
+
else list(pd.Categorical(adata.obs[sample_col]).categories)
|
|
104
|
+
)
|
|
105
|
+
refs = (
|
|
106
|
+
list(adata.obs[ref_obs_col].cat.categories)
|
|
107
|
+
if getattr(adata.obs[ref_obs_col], "dtype", None) == "category"
|
|
108
|
+
else list(pd.Categorical(adata.obs[ref_obs_col]).categories)
|
|
109
|
+
)
|
|
89
110
|
|
|
90
111
|
n_samples = len(samples)
|
|
91
112
|
n_refs = len(refs)
|
|
@@ -102,6 +123,7 @@ def plot_hmm_size_contours(
|
|
|
102
123
|
|
|
103
124
|
# helper to get dense layer array for subset
|
|
104
125
|
def _get_layer_array(layer):
|
|
126
|
+
"""Convert a layer to a dense NumPy array."""
|
|
105
127
|
arr = layer
|
|
106
128
|
# sparse -> toarray
|
|
107
129
|
if hasattr(arr, "toarray"):
|
|
@@ -146,7 +168,7 @@ def plot_hmm_size_contours(
|
|
|
146
168
|
fig_w = n_refs * figsize_per_cell[0]
|
|
147
169
|
fig_h = rows_on_page * figsize_per_cell[1]
|
|
148
170
|
fig, axes = plt.subplots(rows_on_page, n_refs, figsize=(fig_w, fig_h), squeeze=False)
|
|
149
|
-
fig.suptitle(f"HMM size contours (page {p+1}/{pages})", fontsize=12)
|
|
171
|
+
fig.suptitle(f"HMM size contours (page {p + 1}/{pages})", fontsize=12)
|
|
150
172
|
|
|
151
173
|
# for each panel compute p(length | position)
|
|
152
174
|
for i_row, sample in enumerate(page_samples):
|
|
@@ -160,7 +182,9 @@ def plot_hmm_size_contours(
|
|
|
160
182
|
ax.set_title(f"{sample} / {ref}")
|
|
161
183
|
continue
|
|
162
184
|
|
|
163
|
-
row_idx = np.nonzero(
|
|
185
|
+
row_idx = np.nonzero(
|
|
186
|
+
panel_mask.values if hasattr(panel_mask, "values") else np.asarray(panel_mask)
|
|
187
|
+
)[0]
|
|
164
188
|
if row_idx.size == 0:
|
|
165
189
|
ax.text(0.5, 0.5, "no reads", ha="center", va="center")
|
|
166
190
|
ax.set_title(f"{sample} / {ref}")
|
|
@@ -178,7 +202,9 @@ def plot_hmm_size_contours(
|
|
|
178
202
|
max_len_here = min(max_len, max_len_local)
|
|
179
203
|
|
|
180
204
|
lengths_range = np.arange(1, max_len_here + 1, dtype=int)
|
|
181
|
-
Z = np.zeros(
|
|
205
|
+
Z = np.zeros(
|
|
206
|
+
(len(lengths_range), n_positions), dtype=float
|
|
207
|
+
) # rows=length, cols=pos
|
|
182
208
|
|
|
183
209
|
# fill Z by efficient bincount across columns
|
|
184
210
|
for j in range(n_positions):
|
|
@@ -222,7 +248,9 @@ def plot_hmm_size_contours(
|
|
|
222
248
|
dy = 1.0
|
|
223
249
|
y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])
|
|
224
250
|
|
|
225
|
-
pcm = ax.pcolormesh(
|
|
251
|
+
pcm = ax.pcolormesh(
|
|
252
|
+
x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
|
|
253
|
+
)
|
|
226
254
|
ax.set_title(f"{sample} / {ref}")
|
|
227
255
|
ax.set_ylabel("length")
|
|
228
256
|
if i_row == rows_on_page - 1:
|
|
@@ -243,9 +271,10 @@ def plot_hmm_size_contours(
|
|
|
243
271
|
# saving per page if requested
|
|
244
272
|
if save_path is not None:
|
|
245
273
|
import os
|
|
274
|
+
|
|
246
275
|
os.makedirs(save_path, exist_ok=True)
|
|
247
276
|
if save_each_page:
|
|
248
|
-
fname = f"hmm_size_page_{p+1:03d}.png"
|
|
277
|
+
fname = f"hmm_size_page_{p + 1:03d}.png"
|
|
249
278
|
out = os.path.join(save_path, fname)
|
|
250
279
|
fig.savefig(out, dpi=dpi, bbox_inches="tight")
|
|
251
280
|
|
|
@@ -20,10 +20,10 @@ def plot_volcano_relative_risk(
|
|
|
20
20
|
xlim (tuple): Optional x-axis limit.
|
|
21
21
|
ylim (tuple): Optional y-axis limit.
|
|
22
22
|
"""
|
|
23
|
-
import matplotlib.pyplot as plt
|
|
24
|
-
import numpy as np
|
|
25
23
|
import os
|
|
26
24
|
|
|
25
|
+
import matplotlib.pyplot as plt
|
|
26
|
+
|
|
27
27
|
for ref, group_results in results_dict.items():
|
|
28
28
|
for group_label, (results_df, _) in group_results.items():
|
|
29
29
|
if results_df.empty:
|
|
@@ -31,8 +31,8 @@ def plot_volcano_relative_risk(
|
|
|
31
31
|
continue
|
|
32
32
|
|
|
33
33
|
# Split by site type
|
|
34
|
-
gpc_df = results_df[results_df[
|
|
35
|
-
cpg_df = results_df[results_df[
|
|
34
|
+
gpc_df = results_df[results_df["GpC_Site"]]
|
|
35
|
+
cpg_df = results_df[results_df["CpG_Site"]]
|
|
36
36
|
|
|
37
37
|
fig, ax = plt.subplots(figsize=(12, 6))
|
|
38
38
|
|
|
@@ -43,29 +43,29 @@ def plot_volcano_relative_risk(
|
|
|
43
43
|
|
|
44
44
|
# GpC as circles
|
|
45
45
|
sc1 = ax.scatter(
|
|
46
|
-
gpc_df[
|
|
47
|
-
gpc_df[
|
|
48
|
-
c=gpc_df[
|
|
49
|
-
cmap=
|
|
50
|
-
edgecolor=
|
|
46
|
+
gpc_df["Genomic_Position"],
|
|
47
|
+
gpc_df["log2_Relative_Risk"],
|
|
48
|
+
c=gpc_df["-log10_Adj_P"],
|
|
49
|
+
cmap="coolwarm",
|
|
50
|
+
edgecolor="k",
|
|
51
51
|
s=40,
|
|
52
|
-
marker=
|
|
53
|
-
label=
|
|
52
|
+
marker="o",
|
|
53
|
+
label="GpC",
|
|
54
54
|
)
|
|
55
55
|
|
|
56
56
|
# CpG as stars
|
|
57
57
|
sc2 = ax.scatter(
|
|
58
|
-
cpg_df[
|
|
59
|
-
cpg_df[
|
|
60
|
-
c=cpg_df[
|
|
61
|
-
cmap=
|
|
62
|
-
edgecolor=
|
|
58
|
+
cpg_df["Genomic_Position"],
|
|
59
|
+
cpg_df["log2_Relative_Risk"],
|
|
60
|
+
c=cpg_df["-log10_Adj_P"],
|
|
61
|
+
cmap="coolwarm",
|
|
62
|
+
edgecolor="k",
|
|
63
63
|
s=60,
|
|
64
|
-
marker=
|
|
65
|
-
label=
|
|
64
|
+
marker="*",
|
|
65
|
+
label="CpG",
|
|
66
66
|
)
|
|
67
67
|
|
|
68
|
-
ax.axhline(y=0, color=
|
|
68
|
+
ax.axhline(y=0, color="gray", linestyle="--")
|
|
69
69
|
ax.set_xlabel("Genomic Position")
|
|
70
70
|
ax.set_ylabel("log2(Relative Risk)")
|
|
71
71
|
ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")
|
|
@@ -75,8 +75,8 @@ def plot_volcano_relative_risk(
|
|
|
75
75
|
if ylim:
|
|
76
76
|
ax.set_ylim(ylim)
|
|
77
77
|
|
|
78
|
-
ax.spines[
|
|
79
|
-
ax.spines[
|
|
78
|
+
ax.spines["top"].set_visible(False)
|
|
79
|
+
ax.spines["right"].set_visible(False)
|
|
80
80
|
|
|
81
81
|
cbar = plt.colorbar(sc1, ax=ax)
|
|
82
82
|
cbar.set_label("-log10(Adjusted P-Value)")
|
|
@@ -87,13 +87,19 @@ def plot_volcano_relative_risk(
|
|
|
87
87
|
# Save if requested
|
|
88
88
|
if save_path:
|
|
89
89
|
os.makedirs(save_path, exist_ok=True)
|
|
90
|
-
safe_name =
|
|
90
|
+
safe_name = (
|
|
91
|
+
f"{ref}_{group_label}".replace("=", "")
|
|
92
|
+
.replace("__", "_")
|
|
93
|
+
.replace(",", "_")
|
|
94
|
+
.replace(" ", "_")
|
|
95
|
+
)
|
|
91
96
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
92
97
|
plt.savefig(out_file, dpi=300)
|
|
93
|
-
print(f"
|
|
98
|
+
print(f"Saved: {out_file}")
|
|
94
99
|
|
|
95
100
|
plt.show()
|
|
96
101
|
|
|
102
|
+
|
|
97
103
|
def plot_bar_relative_risk(
|
|
98
104
|
results_dict,
|
|
99
105
|
sort_by_position=True,
|
|
@@ -102,7 +108,7 @@ def plot_bar_relative_risk(
|
|
|
102
108
|
save_path=None,
|
|
103
109
|
highlight_regions=None, # List of (start, end) tuples
|
|
104
110
|
highlight_color="lightgray",
|
|
105
|
-
highlight_alpha=0.3
|
|
111
|
+
highlight_alpha=0.3,
|
|
106
112
|
):
|
|
107
113
|
"""
|
|
108
114
|
Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.
|
|
@@ -116,10 +122,10 @@ def plot_bar_relative_risk(
|
|
|
116
122
|
highlight_color (str): Color of shaded region.
|
|
117
123
|
highlight_alpha (float): Transparency of shaded region.
|
|
118
124
|
"""
|
|
119
|
-
import matplotlib.pyplot as plt
|
|
120
|
-
import numpy as np
|
|
121
125
|
import os
|
|
122
126
|
|
|
127
|
+
import matplotlib.pyplot as plt
|
|
128
|
+
|
|
123
129
|
for ref, group_data in results_dict.items():
|
|
124
130
|
for group_label, (df, _) in group_data.items():
|
|
125
131
|
if df.empty:
|
|
@@ -127,14 +133,14 @@ def plot_bar_relative_risk(
|
|
|
127
133
|
continue
|
|
128
134
|
|
|
129
135
|
df = df.copy()
|
|
130
|
-
df[
|
|
136
|
+
df["Genomic_Position"] = df["Genomic_Position"].astype(int)
|
|
131
137
|
|
|
132
138
|
if sort_by_position:
|
|
133
|
-
df = df.sort_values(
|
|
139
|
+
df = df.sort_values("Genomic_Position")
|
|
134
140
|
|
|
135
|
-
gpc_mask = df[
|
|
136
|
-
cpg_mask = df[
|
|
137
|
-
both_mask = df[
|
|
141
|
+
gpc_mask = df["GpC_Site"] & ~df["CpG_Site"]
|
|
142
|
+
cpg_mask = df["CpG_Site"] & ~df["GpC_Site"]
|
|
143
|
+
both_mask = df["GpC_Site"] & df["CpG_Site"]
|
|
138
144
|
|
|
139
145
|
fig, ax = plt.subplots(figsize=(14, 6))
|
|
140
146
|
|
|
@@ -145,36 +151,36 @@ def plot_bar_relative_risk(
|
|
|
145
151
|
|
|
146
152
|
# Bar plots
|
|
147
153
|
ax.bar(
|
|
148
|
-
df[
|
|
149
|
-
df[
|
|
154
|
+
df["Genomic_Position"][gpc_mask],
|
|
155
|
+
df["log2_Relative_Risk"][gpc_mask],
|
|
150
156
|
width=10,
|
|
151
|
-
color=
|
|
152
|
-
label=
|
|
153
|
-
edgecolor=
|
|
157
|
+
color="steelblue",
|
|
158
|
+
label="GpC Site",
|
|
159
|
+
edgecolor="black",
|
|
154
160
|
)
|
|
155
161
|
|
|
156
162
|
ax.bar(
|
|
157
|
-
df[
|
|
158
|
-
df[
|
|
163
|
+
df["Genomic_Position"][cpg_mask],
|
|
164
|
+
df["log2_Relative_Risk"][cpg_mask],
|
|
159
165
|
width=10,
|
|
160
|
-
color=
|
|
161
|
-
label=
|
|
162
|
-
edgecolor=
|
|
166
|
+
color="darkorange",
|
|
167
|
+
label="CpG Site",
|
|
168
|
+
edgecolor="black",
|
|
163
169
|
)
|
|
164
170
|
|
|
165
171
|
if both_mask.any():
|
|
166
172
|
ax.bar(
|
|
167
|
-
df[
|
|
168
|
-
df[
|
|
173
|
+
df["Genomic_Position"][both_mask],
|
|
174
|
+
df["log2_Relative_Risk"][both_mask],
|
|
169
175
|
width=10,
|
|
170
|
-
color=
|
|
171
|
-
label=
|
|
172
|
-
edgecolor=
|
|
176
|
+
color="purple",
|
|
177
|
+
label="GpC + CpG",
|
|
178
|
+
edgecolor="black",
|
|
173
179
|
)
|
|
174
180
|
|
|
175
|
-
ax.axhline(y=0, color=
|
|
176
|
-
ax.set_xlabel(
|
|
177
|
-
ax.set_ylabel(
|
|
181
|
+
ax.axhline(y=0, color="gray", linestyle="--")
|
|
182
|
+
ax.set_xlabel("Genomic Position")
|
|
183
|
+
ax.set_ylabel("log2(Relative Risk)")
|
|
178
184
|
ax.set_title(f"{ref} — {group_label}")
|
|
179
185
|
ax.legend()
|
|
180
186
|
|
|
@@ -183,20 +189,23 @@ def plot_bar_relative_risk(
|
|
|
183
189
|
if ylim:
|
|
184
190
|
ax.set_ylim(ylim)
|
|
185
191
|
|
|
186
|
-
ax.spines[
|
|
187
|
-
ax.spines[
|
|
192
|
+
ax.spines["top"].set_visible(False)
|
|
193
|
+
ax.spines["right"].set_visible(False)
|
|
188
194
|
|
|
189
195
|
plt.tight_layout()
|
|
190
196
|
|
|
191
197
|
if save_path:
|
|
192
198
|
os.makedirs(save_path, exist_ok=True)
|
|
193
|
-
safe_name =
|
|
199
|
+
safe_name = (
|
|
200
|
+
f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
|
|
201
|
+
)
|
|
194
202
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
195
203
|
plt.savefig(out_file, dpi=300)
|
|
196
204
|
print(f"📁 Saved: {out_file}")
|
|
197
205
|
|
|
198
206
|
plt.show()
|
|
199
207
|
|
|
208
|
+
|
|
200
209
|
def plot_positionwise_matrix(
|
|
201
210
|
adata,
|
|
202
211
|
key="positionwise_result",
|
|
@@ -210,35 +219,39 @@ def plot_positionwise_matrix(
|
|
|
210
219
|
xtick_step=10,
|
|
211
220
|
ytick_step=10,
|
|
212
221
|
save_path=None,
|
|
213
|
-
highlight_position=None,
|
|
214
|
-
highlight_axis="row",
|
|
215
|
-
annotate_points=False
|
|
222
|
+
highlight_position=None, # Can be a single int/float or list of them
|
|
223
|
+
highlight_axis="row", # "row" or "column"
|
|
224
|
+
annotate_points=False, # ✅ New option
|
|
216
225
|
):
|
|
217
226
|
"""
|
|
218
227
|
Plots positionwise matrices stored in adata.uns[key], with an optional line plot
|
|
219
228
|
for specified row(s) or column(s), and highlights them on the heatmap.
|
|
220
229
|
"""
|
|
230
|
+
import os
|
|
231
|
+
|
|
221
232
|
import matplotlib.pyplot as plt
|
|
222
|
-
import seaborn as sns
|
|
223
233
|
import numpy as np
|
|
224
234
|
import pandas as pd
|
|
225
|
-
import
|
|
235
|
+
import seaborn as sns
|
|
226
236
|
|
|
227
237
|
def find_closest_index(index, target):
|
|
238
|
+
"""Find the index value closest to a target value."""
|
|
228
239
|
index_vals = pd.to_numeric(index, errors="coerce")
|
|
229
240
|
target_val = pd.to_numeric([target], errors="coerce")[0]
|
|
230
241
|
diffs = pd.Series(np.abs(index_vals - target_val), index=index)
|
|
231
242
|
return diffs.idxmin()
|
|
232
243
|
|
|
233
244
|
# Ensure highlight_position is a list
|
|
234
|
-
if highlight_position is not None and not isinstance(
|
|
245
|
+
if highlight_position is not None and not isinstance(
|
|
246
|
+
highlight_position, (list, tuple, np.ndarray)
|
|
247
|
+
):
|
|
235
248
|
highlight_position = [highlight_position]
|
|
236
249
|
|
|
237
250
|
for group, mat_df in adata.uns[key].items():
|
|
238
251
|
mat = mat_df.copy()
|
|
239
252
|
|
|
240
253
|
if log_transform:
|
|
241
|
-
with np.errstate(divide=
|
|
254
|
+
with np.errstate(divide="ignore", invalid="ignore"):
|
|
242
255
|
if log_base == "log1p":
|
|
243
256
|
mat = np.log1p(mat)
|
|
244
257
|
elif log_base == "log2":
|
|
@@ -276,7 +289,7 @@ def plot_positionwise_matrix(
|
|
|
276
289
|
vmin=vmin,
|
|
277
290
|
vmax=vmax,
|
|
278
291
|
cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
|
|
279
|
-
ax=heat_ax
|
|
292
|
+
ax=heat_ax,
|
|
280
293
|
)
|
|
281
294
|
|
|
282
295
|
heat_ax.set_title(f"{key} — {group}", pad=20)
|
|
@@ -295,17 +308,27 @@ def plot_positionwise_matrix(
|
|
|
295
308
|
series = mat.loc[closest]
|
|
296
309
|
x_vals = pd.to_numeric(series.index, errors="coerce")
|
|
297
310
|
idx = mat.index.get_loc(closest)
|
|
298
|
-
heat_ax.axhline(
|
|
311
|
+
heat_ax.axhline(
|
|
312
|
+
idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
|
|
313
|
+
)
|
|
299
314
|
label = f"Row {pos} → {closest}"
|
|
300
315
|
else:
|
|
301
316
|
closest = find_closest_index(mat.columns, pos)
|
|
302
317
|
series = mat[closest]
|
|
303
318
|
x_vals = pd.to_numeric(series.index, errors="coerce")
|
|
304
319
|
idx = mat.columns.get_loc(closest)
|
|
305
|
-
heat_ax.axvline(
|
|
320
|
+
heat_ax.axvline(
|
|
321
|
+
idx, color=colors[i % len(colors)], linestyle="--", linewidth=1
|
|
322
|
+
)
|
|
306
323
|
label = f"Col {pos} → {closest}"
|
|
307
324
|
|
|
308
|
-
line = line_ax.plot(
|
|
325
|
+
line = line_ax.plot(
|
|
326
|
+
x_vals,
|
|
327
|
+
series.values,
|
|
328
|
+
marker="o",
|
|
329
|
+
label=label,
|
|
330
|
+
color=colors[i % len(colors)],
|
|
331
|
+
)
|
|
309
332
|
|
|
310
333
|
# Annotate each point
|
|
311
334
|
if annotate_points:
|
|
@@ -316,12 +339,18 @@ def plot_positionwise_matrix(
|
|
|
316
339
|
xy=(x, y),
|
|
317
340
|
textcoords="offset points",
|
|
318
341
|
xytext=(0, 5),
|
|
319
|
-
ha=
|
|
320
|
-
fontsize=8
|
|
342
|
+
ha="center",
|
|
343
|
+
fontsize=8,
|
|
321
344
|
)
|
|
322
345
|
except Exception as e:
|
|
323
|
-
line_ax.text(
|
|
324
|
-
|
|
346
|
+
line_ax.text(
|
|
347
|
+
0.5,
|
|
348
|
+
0.5,
|
|
349
|
+
f"⚠️ Error plotting {highlight_axis} @ {pos}",
|
|
350
|
+
ha="center",
|
|
351
|
+
va="center",
|
|
352
|
+
fontsize=10,
|
|
353
|
+
)
|
|
325
354
|
print(f"Error plotting line for {highlight_axis}={pos}: {e}")
|
|
326
355
|
|
|
327
356
|
line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
|
|
@@ -342,6 +371,7 @@ def plot_positionwise_matrix(
|
|
|
342
371
|
|
|
343
372
|
plt.show()
|
|
344
373
|
|
|
374
|
+
|
|
345
375
|
def plot_positionwise_matrix_grid(
|
|
346
376
|
adata,
|
|
347
377
|
key,
|
|
@@ -356,32 +386,61 @@ def plot_positionwise_matrix_grid(
|
|
|
356
386
|
xtick_step=10,
|
|
357
387
|
ytick_step=10,
|
|
358
388
|
parallel=False,
|
|
359
|
-
max_threads=None
|
|
389
|
+
max_threads=None,
|
|
360
390
|
):
|
|
391
|
+
"""Plot a grid of positionwise matrices grouped by metadata.
|
|
392
|
+
|
|
393
|
+
Args:
|
|
394
|
+
adata: AnnData containing matrices in ``adata.uns``.
|
|
395
|
+
key: Key for positionwise matrices.
|
|
396
|
+
outer_keys: Keys for outer grouping.
|
|
397
|
+
inner_keys: Keys for inner grouping.
|
|
398
|
+
log_transform: Optional log transform (``log2`` or ``log1p``).
|
|
399
|
+
vmin: Minimum color scale value.
|
|
400
|
+
vmax: Maximum color scale value.
|
|
401
|
+
cmap: Matplotlib colormap.
|
|
402
|
+
save_path: Optional path to save plots.
|
|
403
|
+
figsize: Figure size.
|
|
404
|
+
xtick_step: X-axis tick step.
|
|
405
|
+
ytick_step: Y-axis tick step.
|
|
406
|
+
parallel: Whether to plot in parallel.
|
|
407
|
+
max_threads: Max thread count for parallel plotting.
|
|
408
|
+
"""
|
|
409
|
+
import os
|
|
410
|
+
|
|
361
411
|
import matplotlib.pyplot as plt
|
|
362
|
-
import seaborn as sns
|
|
363
412
|
import numpy as np
|
|
364
413
|
import pandas as pd
|
|
365
|
-
import
|
|
366
|
-
from matplotlib.gridspec import GridSpec
|
|
414
|
+
import seaborn as sns
|
|
367
415
|
from joblib import Parallel, delayed
|
|
416
|
+
from matplotlib.gridspec import GridSpec
|
|
368
417
|
|
|
369
418
|
matrices = adata.uns[key]
|
|
370
419
|
group_labels = list(matrices.keys())
|
|
371
420
|
|
|
372
|
-
parsed_inner = pd.DataFrame(
|
|
373
|
-
|
|
421
|
+
parsed_inner = pd.DataFrame(
|
|
422
|
+
[dict(zip(inner_keys, g.split("_")[-len(inner_keys) :])) for g in group_labels]
|
|
423
|
+
)
|
|
424
|
+
parsed_outer = pd.Series(
|
|
425
|
+
["_".join(g.split("_")[: -len(inner_keys)]) for g in group_labels], name="outer"
|
|
426
|
+
)
|
|
374
427
|
parsed = pd.concat([parsed_outer, parsed_inner], axis=1)
|
|
375
428
|
|
|
376
429
|
def plot_one_grid(outer_label):
|
|
377
|
-
|
|
378
|
-
selected["
|
|
430
|
+
"""Plot one grid for a specific outer label."""
|
|
431
|
+
selected = parsed[parsed["outer"] == outer_label].copy()
|
|
432
|
+
selected["group_str"] = [
|
|
433
|
+
f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}"
|
|
434
|
+
for _, row in selected.iterrows()
|
|
435
|
+
]
|
|
379
436
|
|
|
380
437
|
row_vals = sorted(selected[inner_keys[0]].unique())
|
|
381
438
|
col_vals = sorted(selected[inner_keys[1]].unique())
|
|
382
439
|
|
|
383
440
|
fig = plt.figure(figsize=figsize)
|
|
384
|
-
gs = GridSpec(
|
|
441
|
+
gs = GridSpec(
|
|
442
|
+
len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3
|
|
443
|
+
)
|
|
385
444
|
axes = np.empty((len(row_vals), len(col_vals)), dtype=object)
|
|
386
445
|
|
|
387
446
|
local_vmin, local_vmax = vmin, vmax
|
|
@@ -397,10 +456,7 @@ def plot_positionwise_matrix_grid(
|
|
|
397
456
|
local_vmin = -vmax_auto if vmin is None else vmin
|
|
398
457
|
local_vmax = vmax_auto if vmax is None else vmax
|
|
399
458
|
|
|
400
|
-
cbar_label = {
|
|
401
|
-
"log2": "log2(Value)",
|
|
402
|
-
"log1p": "log1p(Value)"
|
|
403
|
-
}.get(log_transform, "Value")
|
|
459
|
+
cbar_label = {"log2": "log2(Value)", "log1p": "log1p(Value)"}.get(log_transform, "Value")
|
|
404
460
|
|
|
405
461
|
cbar_ax = fig.add_subplot(gs[:, -1])
|
|
406
462
|
|
|
@@ -431,9 +487,11 @@ def plot_positionwise_matrix_grid(
|
|
|
431
487
|
vmax=local_vmax,
|
|
432
488
|
cbar=(i == 0 and j == 0),
|
|
433
489
|
cbar_ax=cbar_ax if (i == 0 and j == 0) else None,
|
|
434
|
-
cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""}
|
|
490
|
+
cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""},
|
|
491
|
+
)
|
|
492
|
+
ax.set_title(
|
|
493
|
+
f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8
|
|
435
494
|
)
|
|
436
|
-
ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)
|
|
437
495
|
|
|
438
496
|
xticks = data.columns.astype(int)
|
|
439
497
|
yticks = data.index.astype(int)
|
|
@@ -448,15 +506,17 @@ def plot_positionwise_matrix_grid(
|
|
|
448
506
|
if save_path:
|
|
449
507
|
os.makedirs(save_path, exist_ok=True)
|
|
450
508
|
fname = outer_label.replace("_", "").replace("=", "") + ".png"
|
|
451
|
-
plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches=
|
|
452
|
-
print(f"
|
|
509
|
+
plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches="tight")
|
|
510
|
+
print(f"Saved {fname}")
|
|
453
511
|
|
|
454
512
|
plt.close(fig)
|
|
455
513
|
|
|
456
514
|
if parallel:
|
|
457
|
-
Parallel(n_jobs=max_threads)(
|
|
515
|
+
Parallel(n_jobs=max_threads)(
|
|
516
|
+
delayed(plot_one_grid)(outer_label) for outer_label in parsed["outer"].unique()
|
|
517
|
+
)
|
|
458
518
|
else:
|
|
459
|
-
for outer_label in parsed[
|
|
519
|
+
for outer_label in parsed["outer"].unique():
|
|
460
520
|
plot_one_grid(outer_label)
|
|
461
521
|
|
|
462
|
-
print("
|
|
522
|
+
print("Finished plotting all grids.")
|