smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +34 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/hmm/apply_hmm_batched.py +242 -0
- smftools/hmm/calculate_distances.py +18 -0
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/hmm/display_hmm.py +18 -0
- smftools/hmm/hmm_readwrite.py +16 -0
- smftools/hmm/nucleosome_hmm_refinement.py +104 -0
- smftools/hmm/train_hmm.py +78 -0
- smftools/informatics/__init__.py +14 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/conversion_smf.py +132 -0
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/direct_smf.py +137 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/fast5_to_pod5.py +24 -0
- smftools/informatics/helpers/__init__.py +73 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +70 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +51 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +32 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/data/preprocessing.py +6 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/__init__.py +9 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/machine_learning/models/positional.py +18 -0
- smftools/machine_learning/models/rnn.py +17 -0
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/models/wrappers.py +20 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +10 -0
- smftools/machine_learning/utils/grl.py +14 -0
- smftools/plotting/__init__.py +18 -0
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +682 -0
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +38 -0
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_coverage.py +51 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +62 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1351 -0
- smftools/preprocessing/invert_adata.py +37 -0
- smftools/preprocessing/load_sample_sheet.py +53 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +1004 -0
- smftools/tools/__init__.py +20 -0
- smftools/tools/archived/apply_hmm.py +202 -0
- smftools/tools/archived/classifiers.py +787 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/position_stats.py +601 -0
- smftools/tools/read_stats.py +184 -0
- smftools/tools/spatial_autocorrelation.py +562 -0
- smftools/tools/subset_adata.py +28 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
def plot_volcano_relative_risk(
    results_dict,
    save_path=None,
    highlight_regions=None,  # List of (start, end) tuples
    highlight_color="lightgray",
    highlight_alpha=0.3,
    xlim=None,
    ylim=None,
):
    """
    Plot volcano-style log2(Relative Risk) vs Genomic Position for each group within each reference.

    One figure is drawn per (reference, group). GpC sites are plotted as circles and
    CpG sites as stars. Both are coloured by '-log10_Adj_P' on one shared colour
    scale, so the single colorbar applies to both marker types.

    Parameters:
        results_dict (dict): Output from calculate_relative_risk_by_group.
            Format: dict[ref][group_label] = (results_df, sig_df)
            Only results_df is used. It must provide the columns 'GpC_Site' and
            'CpG_Site' (boolean masks), 'Genomic_Position', 'log2_Relative_Risk'
            and '-log10_Adj_P'.
        save_path (str): Directory to save plots (created if missing). If None,
            nothing is written to disk.
        highlight_regions (list): List of (start, end) tuples for shaded regions.
        highlight_color (str): Color for highlighted regions.
        highlight_alpha (float): Alpha for highlighted region.
        xlim (tuple): Optional x-axis limit.
        ylim (tuple): Optional y-axis limit.

    Returns:
        None. Figures are shown and, if save_path is given, also saved as PNG files
        named after a sanitised "<ref>_<group_label>".
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import os

    for ref, group_results in results_dict.items():
        for group_label, (results_df, _) in group_results.items():
            if results_df.empty:
                print(f"Skipping empty results for {ref} / {group_label}")
                continue

            # Split by site type (a position flagged as both appears in both frames)
            gpc_df = results_df[results_df['GpC_Site']]
            cpg_df = results_df[results_df['CpG_Site']]

            # Shared colour limits over the finite scores of both site types.
            # Without them, each scatter autoscales on its own, and the single
            # colorbar (built from one scatter) misrepresents the other.
            # Non-finite scores (e.g. -log10(0) = inf) are excluded so they cannot
            # collapse the colour range.
            scores = np.concatenate([
                np.asarray(gpc_df['-log10_Adj_P'], dtype=float),
                np.asarray(cpg_df['-log10_Adj_P'], dtype=float),
            ])
            finite_scores = scores[np.isfinite(scores)]
            if finite_scores.size:
                color_min = float(finite_scores.min())
                color_max = float(finite_scores.max())
            else:
                color_min = color_max = None

            fig, ax = plt.subplots(figsize=(12, 6))

            # Highlight regions
            if highlight_regions:
                for start, end in highlight_regions:
                    ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)

            # GpC as circles
            sc1 = ax.scatter(
                gpc_df['Genomic_Position'],
                gpc_df['log2_Relative_Risk'],
                c=gpc_df['-log10_Adj_P'],
                cmap='coolwarm',
                vmin=color_min,
                vmax=color_max,
                edgecolor='k',
                s=40,
                marker='o',
                label='GpC'
            )

            # CpG as stars
            sc2 = ax.scatter(
                cpg_df['Genomic_Position'],
                cpg_df['log2_Relative_Risk'],
                c=cpg_df['-log10_Adj_P'],
                cmap='coolwarm',
                vmin=color_min,
                vmax=color_max,
                edgecolor='k',
                s=60,
                marker='*',
                label='CpG'
            )

            ax.axhline(y=0, color='gray', linestyle='--')
            ax.set_xlabel("Genomic Position")
            ax.set_ylabel("log2(Relative Risk)")
            ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")

            if xlim:
                ax.set_xlim(xlim)
            if ylim:
                ax.set_ylim(ylim)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            # Both scatters share one normalisation, so either can back the colorbar.
            # Use the CpG one when there are no GpC points to map.
            cbar = fig.colorbar(sc1 if not gpc_df.empty else sc2, ax=ax)
            cbar.set_label("-log10(Adjusted P-Value)")

            ax.legend()
            fig.tight_layout()

            # Save if requested
            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
                out_file = os.path.join(save_path, f"{safe_name}.png")
                fig.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
            # Under a non-interactive backend (e.g. Agg in batch runs), plt.show()
            # does not release the figure. Close it so figures do not pile up
            # across the (ref, group) loop.
            if not plt.isinteractive():
                plt.close(fig)
|
|
96
|
+
|
|
97
|
+
def plot_bar_relative_risk(
    results_dict,
    sort_by_position=True,
    xlim=None,
    ylim=None,
    save_path=None,
    highlight_regions=None,  # List of (start, end) tuples
    highlight_color="lightgray",
    highlight_alpha=0.3,
    bar_width=10,
):
    """
    Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.

    Sites are split into three mutually exclusive categories: GpC only, CpG only,
    and both. Each category gets its own colour. Positions flagged as neither
    are not drawn.

    Parameters:
        results_dict (dict): Output from calculate_relative_risk_by_group.
            Format: dict[ref][group_label] = (df, sig_df). Only df is used. It must
            provide 'Genomic_Position' (castable to int), 'log2_Relative_Risk',
            and boolean 'GpC_Site' / 'CpG_Site' columns.
        sort_by_position (bool): Whether to sort bars left-to-right by genomic coordinate.
        xlim, ylim (tuple): Axis limits.
        save_path (str or None): Directory to save plots (created if missing).
        highlight_regions (list of tuple): List of (start, end) genomic regions to shade.
        highlight_color (str): Color of shaded region.
        highlight_alpha (float): Transparency of shaded region.
        bar_width (float): Width of each bar in genomic-coordinate units (default 10).

    Returns:
        None. Figures are shown and optionally saved as PNG files.
    """
    import matplotlib.pyplot as plt
    import os

    for ref, group_data in results_dict.items():
        for group_label, (df, _) in group_data.items():
            if df.empty:
                print(f"Skipping empty result for {ref} / {group_label}")
                continue

            df = df.copy()
            df['Genomic_Position'] = df['Genomic_Position'].astype(int)

            if sort_by_position:
                df = df.sort_values('Genomic_Position')

            # Mutually exclusive categories, so no position is drawn twice
            gpc_mask = df['GpC_Site'] & ~df['CpG_Site']
            cpg_mask = df['CpG_Site'] & ~df['GpC_Site']
            both_mask = df['GpC_Site'] & df['CpG_Site']

            fig, ax = plt.subplots(figsize=(14, 6))

            # Optional shaded regions
            if highlight_regions:
                for start, end in highlight_regions:
                    ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)

            # GpC-only and CpG-only bars are always drawn (even if empty) so they
            # always appear in the legend. The combined category is added only
            # when present.
            bar_groups = [
                (gpc_mask, 'steelblue', 'GpC Site'),
                (cpg_mask, 'darkorange', 'CpG Site'),
            ]
            if both_mask.any():
                bar_groups.append((both_mask, 'purple', 'GpC + CpG'))

            for mask, color, label in bar_groups:
                ax.bar(
                    df.loc[mask, 'Genomic_Position'],
                    df.loc[mask, 'log2_Relative_Risk'],
                    width=bar_width,
                    color=color,
                    label=label,
                    edgecolor='black'
                )

            ax.axhline(y=0, color='gray', linestyle='--')
            ax.set_xlabel('Genomic Position')
            ax.set_ylabel('log2(Relative Risk)')
            ax.set_title(f"{ref} — {group_label}")
            ax.legend()

            if xlim:
                ax.set_xlim(xlim)
            if ylim:
                ax.set_ylim(ylim)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            fig.tight_layout()

            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
                out_file = os.path.join(save_path, f"{safe_name}.png")
                fig.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
            # Under a non-interactive backend (e.g. Agg in batch runs), plt.show()
            # does not release the figure. Close it so figures do not accumulate.
            if not plt.isinteractive():
                plt.close(fig)
|
|
199
|
+
|
|
200
|
+
def plot_positionwise_matrix(
    adata,
    key="positionwise_result",
    log_transform=False,
    log_base="log1p",  # or 'log2', or None
    triangle="full",
    cmap="vlag",
    figsize=(12, 10),  # Taller to accommodate line plot below
    vmin=None,
    vmax=None,
    xtick_step=10,
    ytick_step=10,
    save_path=None,
    highlight_position=None,  # Can be a single int/float or list of them
    highlight_axis="row",  # "row" or "column"
    annotate_points=False
):
    """
    Plot each positionwise matrix stored in adata.uns[key].

    Each matrix is drawn as a heatmap, with a line-plot panel below. The line
    panel shows the profile of the row(s) or column(s) nearest to
    highlight_position, and those rows/columns are marked on the heatmap with
    dashed lines.

    Parameters:
        adata: Object whose .uns[key] is a dict of {group: DataFrame}. Each
            DataFrame's index and columns must be castable to int (used as tick
            labels).
        key (str): Key in adata.uns holding the matrices.
        log_transform (bool): Apply the transform named by log_base.
        log_base (str or None): "log1p" or "log2". For "log2", zeros are replaced
            by 0.1 x the smallest positive value before taking the log.
        triangle (str): "full", "lower" or "upper" (masks the other triangle).
        cmap (str): Heatmap colormap.
        figsize (tuple): Figure size (heatmap above, line plot below).
        vmin, vmax (float or None): Colour limits. For log2, missing limits are
            set per matrix, symmetric around 0.
        xtick_step, ytick_step (int): Label every n-th column/row.
        save_path (str or None): Directory for PNGs named "<key>_<group>.png".
        highlight_position: Single value or list. The nearest numeric row/column
            label is used for each value.
        highlight_axis (str): "row" to plot rows; any other value plots columns.
        annotate_points (bool): Write each non-NaN value above its point.

    Returns:
        None. Figures are shown and optionally saved.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os

    def find_closest_index(index, target):
        """Return the label in *index* whose numeric value is nearest to *target*."""
        index_vals = pd.to_numeric(index, errors="coerce")
        target_val = pd.to_numeric([target], errors="coerce")[0]
        diffs = pd.Series(np.abs(index_vals - target_val), index=index)
        return diffs.idxmin()

    # Ensure highlight_position is a list
    if highlight_position is not None and not isinstance(highlight_position, (list, tuple, np.ndarray)):
        highlight_position = [highlight_position]

    for group, mat_df in adata.uns[key].items():
        mat = mat_df.copy()

        if log_transform:
            with np.errstate(divide='ignore', invalid='ignore'):
                if log_base == "log1p":
                    mat = np.log1p(mat)
                elif log_base == "log2":
                    # Pseudo-count for zeros: 0.1 x the smallest positive value
                    mat = np.log2(mat.replace(0, np.nanmin(mat[mat > 0]) * 0.1))
            mat.replace([np.inf, -np.inf], np.nan, inplace=True)

        # Colour limits are resolved per group into separate locals. Reassigning
        # vmin/vmax here would freeze the first group's auto limits for every
        # later group.
        group_vmin, group_vmax = vmin, vmax
        if log_base == "log2" and log_transform and (vmin is None or vmax is None):
            abs_vals = np.abs(mat.to_numpy(dtype=float))
            finite_abs = abs_vals[np.isfinite(abs_vals)]
            if finite_abs.size:  # an all-NaN matrix keeps seaborn's own scaling
                abs_max = finite_abs.max()
                group_vmin = -abs_max if vmin is None else vmin
                group_vmax = abs_max if vmax is None else vmax

        # Create mask for triangle
        mask = None
        if triangle == "lower":
            mask = np.triu(np.ones_like(mat, dtype=bool), k=1)
        elif triangle == "upper":
            mask = np.tril(np.ones_like(mat, dtype=bool), k=-1)

        xticks = mat.columns.astype(int)
        yticks = mat.index.astype(int)

        # Heatmap on top, line plot below
        fig, (heat_ax, line_ax) = plt.subplots(2, 1, figsize=figsize, height_ratios=[3, 1.5])

        sns.heatmap(
            mat,
            mask=mask,
            cmap=cmap,
            xticklabels=xticks,
            yticklabels=yticks,
            square=True,
            vmin=group_vmin,
            vmax=group_vmax,
            cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
            ax=heat_ax
        )

        heat_ax.set_title(f"{key} — {group}", pad=20)
        # seaborn draws cell k over [k, k + 1]. Place ticks (and the highlight
        # lines below) on cell centres, k + 0.5.
        heat_ax.set_xticks(np.arange(0, len(xticks), xtick_step) + 0.5)
        heat_ax.set_xticklabels(xticks[::xtick_step], rotation=90)
        heat_ax.set_yticks(np.arange(0, len(yticks), ytick_step) + 0.5)
        heat_ax.set_yticklabels(yticks[::ytick_step])

        # Line plot
        if highlight_position is not None:
            colors = plt.cm.tab10.colors
            for i, pos in enumerate(highlight_position):
                color = colors[i % len(colors)]
                # Best-effort per position: a failure is reported on the panel and
                # the remaining positions are still plotted.
                try:
                    if highlight_axis == "row":
                        closest = find_closest_index(mat.index, pos)
                        series = mat.loc[closest]
                        x_vals = pd.to_numeric(series.index, errors="coerce")
                        heat_ax.axhline(mat.index.get_loc(closest) + 0.5, color=color, linestyle="--", linewidth=1)
                        label = f"Row {pos} → {closest}"
                    else:
                        closest = find_closest_index(mat.columns, pos)
                        series = mat[closest]
                        x_vals = pd.to_numeric(series.index, errors="coerce")
                        heat_ax.axvline(mat.columns.get_loc(closest) + 0.5, color=color, linestyle="--", linewidth=1)
                        label = f"Col {pos} → {closest}"

                    line_ax.plot(x_vals, series.values, marker='o', label=label, color=color)

                    # Annotate each point
                    if annotate_points:
                        for x, y in zip(x_vals, series.values):
                            if not np.isnan(y):
                                line_ax.annotate(
                                    f"{y:.2f}",
                                    xy=(x, y),
                                    textcoords="offset points",
                                    xytext=(0, 5),
                                    ha='center',
                                    fontsize=8
                                )
                except Exception as e:
                    line_ax.text(0.5, 0.5, f"⚠️ Error plotting {highlight_axis} @ {pos}",
                                 ha='center', va='center', fontsize=10)
                    print(f"Error plotting line for {highlight_axis}={pos}: {e}")

        line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
        line_ax.set_xlabel(f"{'Column' if highlight_axis == 'row' else 'Row'} position")
        line_ax.set_ylabel("Value")
        line_ax.grid(True)
        # Only add a legend when something labelled was plotted (avoids the
        # "No artists with labels found" warning).
        if line_ax.get_legend_handles_labels()[0]:
            line_ax.legend(fontsize=8)

        fig.tight_layout()

        # Save if requested
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            safe_name = str(group).replace("=", "").replace("__", "_").replace(",", "_")
            out_file = os.path.join(save_path, f"{key}_{safe_name}.png")
            fig.savefig(out_file, dpi=300)
            print(f"📁 Saved: {out_file}")

        plt.show()
        # Under a non-interactive backend (e.g. Agg in batch runs), plt.show()
        # does not release the figure. Close it so figures do not accumulate.
        if not plt.isinteractive():
            plt.close(fig)
|
|
344
|
+
|
|
345
|
+
def plot_positionwise_matrix_grid(
    adata,
    key,
    outer_keys=("Reference_strand", "activity_status"),
    inner_keys=("Promoter_Open", "Enhancer_Open"),
    log_transform=None,
    vmin=None,
    vmax=None,
    cmap="vlag",
    save_path=None,
    figsize=(10, 10),
    xtick_step=10,
    ytick_step=10,
    parallel=False,
    max_threads=None
):
    """
    Plot the matrices in adata.uns[key] as one grid figure per outer group.

    Group labels (the keys of adata.uns[key]) are split on "_". The last
    len(inner_keys) pieces are the inner values, and the remaining prefix is the
    outer label. Each outer label gets one figure whose rows are the values of
    inner_keys[0] and whose columns are the values of inner_keys[1]. Missing
    combinations are left blank. All panels of a figure share one colour scale
    and one colorbar.

    Parameters:
        adata: Object whose .uns[key] is a dict of {group_label: DataFrame}. Each
            DataFrame's index and columns must be castable to int (used as tick
            labels).
        key (str): Key in adata.uns holding the matrices.
        outer_keys (sequence of str): Currently unused; kept for API compatibility.
        inner_keys (sequence of str): Exactly two names for the trailing label
            pieces (grid rows, grid columns). Inner values must not contain "_".
        log_transform (str or None): "log2", "log1p" or None. For "log2", zeros
            are replaced by 0.1 x the smallest positive value of that matrix.
        vmin, vmax (float or None): Colour limits. Missing limits are derived from
            the finite values of all panels in the figure (symmetric around 0
            for "log2").
        cmap (str): Colormap.
        save_path (str or None): Directory for PNGs (created if missing). When None
            and parallel is False, each figure is shown instead.
        figsize (tuple): Size of each grid figure.
        xtick_step, ytick_step (int): Label every n-th column/row.
        parallel (bool): Render the outer groups with joblib.
        max_threads (int or None): joblib n_jobs. None means all CPUs (-1).

    Returns:
        None.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os
    from matplotlib.gridspec import GridSpec
    from joblib import Parallel, delayed

    matrices = adata.uns[key]
    group_labels = list(matrices.keys())
    n_inner = len(inner_keys)

    parsed_inner = pd.DataFrame([dict(zip(inner_keys, g.split("_")[-n_inner:])) for g in group_labels])
    parsed_outer = pd.Series(["_".join(g.split("_")[:-n_inner]) for g in group_labels], name="outer")
    parsed = pd.concat([parsed_outer, parsed_inner], axis=1)

    cbar_label = {
        "log2": "log2(Value)",
        "log1p": "log1p(Value)"
    }.get(log_transform, "Value")

    def transform(mat):
        """Return a copy of *mat* with log_transform applied and +/-inf mapped to NaN."""
        data = mat.copy()
        if log_transform == "log2":
            with np.errstate(divide="ignore", invalid="ignore"):
                values = data.to_numpy(dtype=float)
                positive = values[values > 0]
                # Pseudo-count for zeros, one decade below the smallest positive
                # value. A fixed tiny floor (e.g. 1e-9) maps zeros to about -30 and
                # swamps the symmetric colour scale.
                floor = positive.min() * 0.1 if positive.size else 1e-9
                data = np.log2(data.replace(0, floor))
            data = data.replace([np.inf, -np.inf], np.nan)
        elif log_transform == "log1p":
            with np.errstate(divide="ignore", invalid="ignore"):
                data = np.log1p(data)
            data = data.replace([np.inf, -np.inf], np.nan)
        return data

    def plot_one_grid(outer_label):
        """Draw one grid figure for *outer_label*, then save it or show it, and close it."""
        selected = parsed[parsed["outer"] == outer_label]
        row_vals = sorted(selected[inner_keys[0]].unique())
        col_vals = sorted(selected[inner_keys[1]].unique())

        # Transform every available panel up front so colour limits can be shared
        panels = {}
        for i, row_val in enumerate(row_vals):
            for j, col_val in enumerate(col_vals):
                mat = matrices.get(f"{outer_label}_{row_val}_{col_val}")
                if mat is not None:
                    panels[(i, j)] = transform(mat)

        # One colorbar serves the whole figure, so every panel must use the same
        # limits. Otherwise each heatmap would autoscale independently.
        local_vmin, local_vmax = vmin, vmax
        if panels and (vmin is None or vmax is None):
            combined = np.concatenate([p.to_numpy(dtype=float).ravel() for p in panels.values()])
            finite = combined[np.isfinite(combined)]
            if finite.size:
                if log_transform == "log2":
                    bound = float(np.abs(finite).max())
                    auto_min, auto_max = -bound, bound
                else:
                    auto_min, auto_max = float(finite.min()), float(finite.max())
                if vmin is None:
                    local_vmin = auto_min
                if vmax is None:
                    local_vmax = auto_max

        fig = plt.figure(figsize=figsize)
        gs = GridSpec(len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3)
        cbar_ax = fig.add_subplot(gs[:, -1])

        cbar_drawn = False
        for i, row_val in enumerate(row_vals):
            for j, col_val in enumerate(col_vals):
                ax = fig.add_subplot(gs[i, j])
                data = panels.get((i, j))
                if data is None:
                    ax.axis("off")
                    continue

                # The colorbar goes on the first panel actually drawn, so it still
                # appears when the top-left combination is missing.
                draw_cbar = not cbar_drawn
                sns.heatmap(
                    data,
                    ax=ax,
                    cmap=cmap,
                    xticklabels=True,
                    yticklabels=True,
                    square=True,
                    vmin=local_vmin,
                    vmax=local_vmax,
                    cbar=draw_cbar,
                    cbar_ax=cbar_ax if draw_cbar else None,
                    cbar_kws={"label": cbar_label}
                )
                cbar_drawn = True
                ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)

                # seaborn draws cell k over [k, k + 1]; put tick labels on cell centres
                xticks = data.columns.astype(int)
                yticks = data.index.astype(int)
                ax.set_xticks(np.arange(0, len(xticks), xtick_step) + 0.5)
                ax.set_xticklabels(xticks[::xtick_step], rotation=90)
                ax.set_yticks(np.arange(0, len(yticks), ytick_step) + 0.5)
                ax.set_yticklabels(yticks[::ytick_step])

        if not cbar_drawn:
            cbar_ax.axis("off")

        fig.suptitle(f"{key} • {outer_label}", fontsize=14, y=1.02)
        fig.tight_layout(rect=[0, 0, 0.97, 0.95])

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            fname = outer_label.replace("_", "").replace("=", "") + ".png"
            fig.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
            print(f"✅ Saved {fname}")
        elif not parallel:
            # Without a save_path, the figure would otherwise be discarded unseen
            plt.show()

        plt.close(fig)

    outer_labels = parsed['outer'].unique()
    if parallel:
        # joblib interprets n_jobs=None as a single job, so map "no limit" to all CPUs
        n_jobs = -1 if max_threads is None else max_threads
        Parallel(n_jobs=n_jobs)(delayed(plot_one_grid)(outer_label) for outer_label in outer_labels)
    else:
        for outer_label in outer_labels:
            plot_one_grid(outer_label)

    print("✅ Finished plotting all grids.")
|