smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
def plot_volcano_relative_risk(
    results_dict,
    save_path=None,
    highlight_regions=None,  # List of (start, end) tuples
    highlight_color="lightgray",
    highlight_alpha=0.3,
    xlim=None,
    ylim=None,
):
    """
    Plot volcano-style log2(Relative Risk) vs Genomic Position for each group within each reference.

    GpC sites are drawn as circles and CpG sites as stars, both colored by
    -log10(adjusted p-value). The two scatters share one color range, so the
    single colorbar is valid for both marker types. A position flagged as both
    GpC and CpG is drawn with both markers.

    Parameters:
        results_dict (dict): Output from calculate_relative_risk_by_group.
            Format: dict[ref][group_label] = (results_df, sig_df). Each
            results_df is read for the columns 'GpC_Site', 'CpG_Site',
            'Genomic_Position', 'log2_Relative_Risk' and '-log10_Adj_P';
            sig_df is not used.
        save_path (str): Directory to save plots (created if missing). One PNG
            per (ref, group_label) is written at 300 dpi.
        highlight_regions (list): List of (start, end) tuples for shaded regions.
        highlight_color (str): Color for highlighted regions.
        highlight_alpha (float): Alpha for highlighted region.
        xlim (tuple): Optional x-axis limit.
        ylim (tuple): Optional y-axis limit.

    Returns:
        None. Each figure is shown, optionally saved, and then closed so that
        figures do not accumulate across many groups.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import os

    for ref, group_results in results_dict.items():
        for group_label, (results_df, _) in group_results.items():
            if results_df.empty:
                print(f"Skipping empty results for {ref} / {group_label}")
                continue

            # Coerce the site flags to plain booleans: object columns holding
            # NaN cannot be used directly as boolean indexers.
            gpc_mask = results_df['GpC_Site'].fillna(False).astype(bool)
            cpg_mask = results_df['CpG_Site'].fillna(False).astype(bool)
            gpc_df = results_df[gpc_mask]
            cpg_df = results_df[cpg_mask]

            # One color range for both scatters. Without this each scatter
            # autoscales independently and the single colorbar would
            # misdescribe one of the two marker types.
            pvals = results_df.loc[gpc_mask | cpg_mask, '-log10_Adj_P'].to_numpy(dtype=float)
            pvals = pvals[np.isfinite(pvals)]
            c_vmin, c_vmax = (pvals.min(), pvals.max()) if pvals.size else (None, None)

            fig, ax = plt.subplots(figsize=(12, 6))

            # Highlight regions
            if highlight_regions:
                for start, end in highlight_regions:
                    ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)

            # GpC as circles
            sc1 = ax.scatter(
                gpc_df['Genomic_Position'],
                gpc_df['log2_Relative_Risk'],
                c=gpc_df['-log10_Adj_P'],
                cmap='coolwarm',
                vmin=c_vmin,
                vmax=c_vmax,
                edgecolor='k',
                s=40,
                marker='o',
                label='GpC',
            )

            # CpG as stars
            sc2 = ax.scatter(
                cpg_df['Genomic_Position'],
                cpg_df['log2_Relative_Risk'],
                c=cpg_df['-log10_Adj_P'],
                cmap='coolwarm',
                vmin=c_vmin,
                vmax=c_vmax,
                edgecolor='k',
                s=60,
                marker='*',
                label='CpG',
            )

            ax.axhline(y=0, color='gray', linestyle='--')
            ax.set_xlabel("Genomic Position")
            ax.set_ylabel("log2(Relative Risk)")
            ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")

            if xlim:
                ax.set_xlim(xlim)
            if ylim:
                ax.set_ylim(ylim)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            # Both scatters share the same limits; prefer the non-empty one.
            cbar = plt.colorbar(sc1 if len(gpc_df) else sc2, ax=ax)
            cbar.set_label("-log10(Adjusted P-Value)")

            ax.legend()
            plt.tight_layout()

            # Save if requested
            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = (
                    f"{ref}_{group_label}"
                    .replace("=", "")
                    .replace("__", "_")
                    .replace(",", "_")
                    .replace(" ", "_")
                    # A path separator in a label would point savefig at a
                    # subdirectory that does not exist.
                    .replace("/", "_")
                )
                out_file = os.path.join(save_path, f"{safe_name}.png")
                fig.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
            plt.close(fig)
|
|
96
|
+
|
|
97
|
+
def plot_bar_relative_risk(
    results_dict,
    sort_by_position=True,
    xlim=None,
    ylim=None,
    save_path=None,
    highlight_regions=None,  # List of (start, end) tuples
    highlight_color="lightgray",
    highlight_alpha=0.3,
    bar_width=10,
):
    """
    Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.

    Positions are split into three exclusive classes (GpC only, CpG only, and
    both). Each non-empty class is drawn in its own color, so no position is
    drawn twice and the legend lists only the classes that are present.

    Parameters:
        results_dict (dict): Output from calculate_relative_risk_by_group.
            Format: dict[ref][group_label] = (df, sig_df). Each df is read for
            the columns 'Genomic_Position', 'log2_Relative_Risk', 'GpC_Site'
            and 'CpG_Site'; sig_df is not used.
        sort_by_position (bool): Whether to sort bars left-to-right by genomic coordinate.
        xlim, ylim (tuple): Axis limits.
        save_path (str or None): Directory to save plots (created if missing).
        highlight_regions (list of tuple): List of (start, end) genomic regions to shade.
        highlight_color (str): Color of shaded region.
        highlight_alpha (float): Transparency of shaded region.
        bar_width (float): Width of each bar in genomic-coordinate units
            (default 10, the previously hard-coded value).

    Returns:
        None. Each figure is shown, optionally saved, and then closed.
    """
    import matplotlib.pyplot as plt
    import os

    for ref, group_data in results_dict.items():
        for group_label, (df, _) in group_data.items():
            if df.empty:
                print(f"Skipping empty result for {ref} / {group_label}")
                continue

            df = df.copy()
            df['Genomic_Position'] = df['Genomic_Position'].astype(int)

            if sort_by_position:
                df = df.sort_values('Genomic_Position')

            # Coerce the flags to real booleans first. On an object column,
            # `~True` evaluates to -2, which would corrupt the exclusive masks.
            is_gpc = df['GpC_Site'].fillna(False).astype(bool)
            is_cpg = df['CpG_Site'].fillna(False).astype(bool)

            # (mask, color, legend label), drawn in this order.
            site_classes = [
                (is_gpc & ~is_cpg, 'steelblue', 'GpC Site'),
                (is_cpg & ~is_gpc, 'darkorange', 'CpG Site'),
                (is_gpc & is_cpg, 'purple', 'GpC + CpG'),
            ]

            fig, ax = plt.subplots(figsize=(14, 6))

            # Optional shaded regions
            if highlight_regions:
                for start, end in highlight_regions:
                    ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)

            # Bar plots; empty classes are skipped so they add no legend entry.
            for mask, color, label in site_classes:
                if not mask.any():
                    continue
                ax.bar(
                    df.loc[mask, 'Genomic_Position'],
                    df.loc[mask, 'log2_Relative_Risk'],
                    width=bar_width,
                    color=color,
                    label=label,
                    edgecolor='black',
                )

            ax.axhline(y=0, color='gray', linestyle='--')
            ax.set_xlabel('Genomic Position')
            ax.set_ylabel('log2(Relative Risk)')
            ax.set_title(f"{ref} — {group_label}")
            # Only build a legend when at least one labeled bar set was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

            if xlim:
                ax.set_xlim(xlim)
            if ylim:
                ax.set_ylim(ylim)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            plt.tight_layout()

            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = (
                    f"{ref}_{group_label}"
                    .replace("=", "")
                    .replace("__", "_")
                    .replace(",", "_")
                    # A path separator in a label would point savefig at a
                    # subdirectory that does not exist.
                    .replace("/", "_")
                )
                out_file = os.path.join(save_path, f"{safe_name}.png")
                fig.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
            plt.close(fig)
|
|
199
|
+
|
|
200
|
+
def plot_positionwise_matrix(
    adata,
    key="positionwise_result",
    log_transform=False,
    log_base="log1p",  # or 'log2', or None
    triangle="full",
    cmap="vlag",
    figsize=(12, 10),  # Taller to accommodate line plot below
    vmin=None,
    vmax=None,
    xtick_step=10,
    ytick_step=10,
    save_path=None,
    highlight_position=None,  # Can be a single int/float or list of them
    highlight_axis="row",  # "row" or "column"
    annotate_points=False,
):
    """
    Plot positionwise matrices stored in adata.uns[key], with an optional line plot
    for specified row(s) or column(s), and highlight them on the heatmap.

    Parameters:
        adata: Object whose ``uns[key]`` is a mapping of group -> square
            DataFrame whose index and columns are numeric positions (they are
            cast to int for tick labels).
        key (str): Key into ``adata.uns``.
        log_transform (bool): Apply the transform named by ``log_base``.
        log_base (str or None): "log1p" or "log2". For "log2", zeros are
            replaced by a tenth of the smallest positive entry before taking
            the log, and +/-inf are turned into NaN.
        triangle (str): "full", "lower" or "upper" (which half to show).
        cmap (str): Colormap for the heatmap.
        figsize (tuple): Figure size (heatmap on top, line plot below).
        vmin, vmax (float or None): Color limits. For log2 with a missing
            limit, a symmetric range around 0 is derived per group.
        xtick_step, ytick_step (int): Label every n-th column / row.
        save_path (str or None): Directory to save PNGs (created if missing).
        highlight_position (number or list): Position(s) to profile; the
            nearest row/column label is used.
        highlight_axis (str): "row" to profile rows, anything else for columns.
        annotate_points (bool): Write each value above its point in the line plot.

    Returns:
        None. One figure per group is shown, optionally saved, then closed.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os

    def find_closest_index(index, target):
        """Return the label in `index` whose numeric value is closest to `target`."""
        index_vals = pd.to_numeric(index, errors="coerce")
        target_val = pd.to_numeric([target], errors="coerce")[0]
        diffs = pd.Series(np.abs(index_vals - target_val), index=index)
        return diffs.idxmin()

    # Ensure highlight_position is a list
    if highlight_position is not None and not isinstance(highlight_position, (list, tuple, np.ndarray)):
        highlight_position = [highlight_position]

    for group, mat_df in adata.uns[key].items():
        mat = mat_df.copy()

        if log_transform:
            with np.errstate(divide='ignore', invalid='ignore'):
                if log_base == "log1p":
                    mat = np.log1p(mat)
                elif log_base == "log2":
                    values = mat.to_numpy(dtype=float)
                    positive = values[values > 0]
                    # Pseudocount for zeros: a tenth of the smallest positive entry.
                    pseudo = positive.min() * 0.1 if positive.size else np.nan
                    mat = np.log2(mat.replace(0, pseudo))
            mat = mat.replace([np.inf, -np.inf], np.nan)

        # Per-group color limits. The caller's vmin/vmax are never overwritten,
        # so one group's auto-scaling does not leak into the following groups.
        group_vmin, group_vmax = vmin, vmax
        if log_transform and log_base == "log2" and (vmin is None or vmax is None):
            magnitudes = np.abs(mat.to_numpy(dtype=float))
            magnitudes = magnitudes[np.isfinite(magnitudes)]
            if magnitudes.size:
                abs_max = magnitudes.max()
                group_vmin = -abs_max if vmin is None else vmin
                group_vmax = abs_max if vmax is None else vmax

        # Create mask for triangle
        mask = None
        if triangle == "lower":
            mask = np.triu(np.ones_like(mat, dtype=bool), k=1)
        elif triangle == "upper":
            mask = np.tril(np.ones_like(mat, dtype=bool), k=-1)

        xticks = mat.columns.astype(int)
        yticks = mat.index.astype(int)

        # Heatmap on top, line plot below. gridspec_kw also works on
        # matplotlib versions that predate the height_ratios= shortcut.
        fig, (heat_ax, line_ax) = plt.subplots(
            2, 1, figsize=figsize, gridspec_kw={"height_ratios": [3, 1.5]}
        )

        sns.heatmap(
            mat,
            mask=mask,
            cmap=cmap,
            xticklabels=xticks,
            yticklabels=yticks,
            square=True,
            vmin=group_vmin,
            vmax=group_vmax,
            cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
            ax=heat_ax,
        )

        heat_ax.set_title(f"{key} — {group}", pad=20)
        # seaborn centers cell k at k + 0.5, so ticks are shifted half a cell.
        heat_ax.set_xticks(np.arange(0, len(xticks), xtick_step) + 0.5)
        heat_ax.set_xticklabels(xticks[::xtick_step], rotation=90)
        heat_ax.set_yticks(np.arange(0, len(yticks), ytick_step) + 0.5)
        heat_ax.set_yticklabels(yticks[::ytick_step])

        if highlight_position is None:
            # Nothing to profile: hide the empty lower panel.
            line_ax.axis("off")
        else:
            colors = plt.cm.tab10.colors
            for i, pos in enumerate(highlight_position):
                color = colors[i % len(colors)]
                try:
                    if highlight_axis == "row":
                        closest = find_closest_index(mat.index, pos)
                        series = mat.loc[closest]
                        idx = mat.index.get_loc(closest)
                        heat_ax.axhline(idx + 0.5, color=color, linestyle="--", linewidth=1)
                        label = f"Row {pos} → {closest}"
                    else:
                        closest = find_closest_index(mat.columns, pos)
                        series = mat[closest]
                        idx = mat.columns.get_loc(closest)
                        heat_ax.axvline(idx + 0.5, color=color, linestyle="--", linewidth=1)
                        label = f"Col {pos} → {closest}"
                    x_vals = pd.to_numeric(series.index, errors="coerce")

                    line_ax.plot(x_vals, series.values, marker='o', label=label, color=color)

                    # Annotate each point
                    if annotate_points:
                        for x, y in zip(x_vals, series.values):
                            if not np.isnan(y):
                                line_ax.annotate(
                                    f"{y:.2f}",
                                    xy=(x, y),
                                    textcoords="offset points",
                                    xytext=(0, 5),
                                    ha='center',
                                    fontsize=8,
                                )
                except Exception as e:
                    # Best effort: report and keep plotting the other positions.
                    # Axes coordinates keep the message centered in the panel
                    # regardless of the data range already plotted.
                    line_ax.text(0.5, 0.5, f"⚠️ Error plotting {highlight_axis} @ {pos}",
                                 ha='center', va='center', fontsize=10,
                                 transform=line_ax.transAxes)
                    print(f"Error plotting line for {highlight_axis}={pos}: {e}")

            line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
            line_ax.set_xlabel(f"{'Column' if highlight_axis == 'row' else 'Row'} position")
            line_ax.set_ylabel("Value")
            line_ax.grid(True)
            line_ax.legend(fontsize=8)

        plt.tight_layout()

        # Save if requested
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            # str() so non-string group keys do not crash; "/" would otherwise
            # point savefig at a missing subdirectory.
            safe_name = (
                str(group)
                .replace("=", "")
                .replace("__", "_")
                .replace(",", "_")
                .replace("/", "_")
            )
            out_file = os.path.join(save_path, f"{key}_{safe_name}.png")
            fig.savefig(out_file, dpi=300)
            print(f"📁 Saved: {out_file}")

        plt.show()
        plt.close(fig)
|
|
344
|
+
|
|
345
|
+
def plot_positionwise_matrix_grid(
    adata,
    key,
    outer_keys=["Reference_strand", "activity_status"],
    inner_keys=["Promoter_Open", "Enhancer_Open"],
    log_transform=None,
    vmin=None,
    vmax=None,
    cmap="vlag",
    save_path=None,
    figsize=(10, 10),
    xtick_step=10,
    ytick_step=10,
    parallel=False,
    max_threads=None
):
    """
    Plot the matrices in adata.uns[key] as one grid figure per outer label.

    Group labels are split on "_": the last len(inner_keys) parts give the
    inner values and the remaining prefix is the outer label. For each outer
    label, a grid is drawn with rows indexed by the first inner key and
    columns by the second. Missing combinations are left blank. All panels in
    one grid share one color range and one colorbar.

    Parameters:
        adata: Object whose ``uns[key]`` maps "<outer>_<inner0>_<inner1>"
            labels to DataFrames with numeric index/columns.
        key (str): Key into ``adata.uns``.
        outer_keys (list): Not read by this function; kept for API compatibility.
        inner_keys (list): Names of the two trailing label components
            (row key, column key). The default lists are never mutated.
        log_transform (str or None): "log2", "log1p" or None. For "log2", zeros
            are replaced by a tenth of the smallest positive entry (the same
            rule plot_positionwise_matrix uses) and +/-inf become NaN.
        vmin, vmax (float or None): Color limits. A missing limit is derived
            from all panels of the grid. For "log2" the range is symmetric
            around 0; otherwise it is the data min/max.
        cmap (str): Colormap.
        save_path (str or None): Directory for PNGs (created if missing).
        figsize (tuple): Size of each grid figure.
        xtick_step, ytick_step (int): Label every n-th column / row.
        parallel (bool): Render grids with joblib instead of a plain loop.
        max_threads (int or None): Passed to joblib as ``n_jobs``.

    Returns:
        None. Figures are saved (if requested) and closed, not shown.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os
    from matplotlib.gridspec import GridSpec
    from joblib import Parallel, delayed

    matrices = adata.uns[key]
    group_labels = list(matrices.keys())
    n_inner = len(inner_keys)

    parsed_inner = pd.DataFrame([dict(zip(inner_keys, g.split("_")[-n_inner:])) for g in group_labels])
    parsed_outer = pd.Series(["_".join(g.split("_")[:-n_inner]) for g in group_labels], name="outer")
    parsed = pd.concat([parsed_outer, parsed_inner], axis=1)

    cbar_label = {
        "log2": "log2(Value)",
        "log1p": "log1p(Value)"
    }.get(log_transform, "Value")

    def transform(mat):
        """Return a transformed copy of one matrix according to log_transform."""
        data = mat.copy()
        if log_transform == "log2":
            values = data.to_numpy(dtype=float)
            positive = values[values > 0]
            # A tiny constant such as 1e-9 (log2 ~ -29.9) would dominate the
            # symmetric color range; a tenth of the smallest positive value
            # does not.
            pseudo = positive.min() * 0.1 if positive.size else np.nan
            with np.errstate(divide='ignore', invalid='ignore'):
                data = np.log2(data.replace(0, pseudo))
            data = data.replace([np.inf, -np.inf], np.nan)
        elif log_transform == "log1p":
            with np.errstate(divide='ignore', invalid='ignore'):
                data = np.log1p(data)
        return data

    def plot_one_grid(outer_label):
        """Draw (and optionally save) the grid figure for one outer label."""
        selected = parsed[parsed['outer'] == outer_label]
        row_vals = sorted(selected[inner_keys[0]].unique())
        col_vals = sorted(selected[inner_keys[1]].unique())

        # Transform each available panel once; reused for limits and plotting.
        panels = {}
        for row_val in row_vals:
            for col_val in col_vals:
                mat = matrices.get(f"{outer_label}_{row_val}_{col_val}")
                if mat is not None:
                    panels[(row_val, col_val)] = transform(mat)

        # One color range for every panel, so the single colorbar is accurate.
        local_vmin, local_vmax = vmin, vmax
        if panels and (vmin is None or vmax is None):
            values = np.concatenate([p.to_numpy(dtype=float).ravel() for p in panels.values()])
            values = values[np.isfinite(values)]
            if values.size:
                if log_transform == "log2":
                    bound = np.abs(values).max()
                    auto_vmin, auto_vmax = -bound, bound
                else:
                    auto_vmin, auto_vmax = values.min(), values.max()
                local_vmin = auto_vmin if vmin is None else vmin
                local_vmax = auto_vmax if vmax is None else vmax

        fig = plt.figure(figsize=figsize)
        gs = GridSpec(len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3)
        cbar_ax = fig.add_subplot(gs[:, -1])
        cbar_drawn = False

        for i, row_val in enumerate(row_vals):
            for j, col_val in enumerate(col_vals):
                ax = fig.add_subplot(gs[i, j])
                data = panels.get((row_val, col_val))
                if data is None:
                    ax.axis("off")
                    continue

                # The colorbar goes with the first panel that exists, not
                # necessarily (0, 0), which may be missing.
                draw_cbar = not cbar_drawn
                sns.heatmap(
                    data,
                    ax=ax,
                    cmap=cmap,
                    xticklabels=True,
                    yticklabels=True,
                    square=True,
                    vmin=local_vmin,
                    vmax=local_vmax,
                    cbar=draw_cbar,
                    cbar_ax=cbar_ax if draw_cbar else None,
                    cbar_kws={"label": cbar_label} if draw_cbar else None,
                )
                cbar_drawn = True
                ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)

                xticks = data.columns.astype(int)
                yticks = data.index.astype(int)
                # seaborn centers cell k at k + 0.5.
                ax.set_xticks(np.arange(0, len(xticks), xtick_step) + 0.5)
                ax.set_xticklabels(xticks[::xtick_step], rotation=90)
                ax.set_yticks(np.arange(0, len(yticks), ytick_step) + 0.5)
                ax.set_yticklabels(yticks[::ytick_step])

        if not cbar_drawn:
            cbar_ax.axis("off")

        fig.suptitle(f"{key} • {outer_label}", fontsize=14, y=1.02)
        fig.tight_layout(rect=[0, 0, 0.97, 0.95])

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            fname = outer_label.replace("_", "").replace("=", "") + ".png"
            fig.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
            print(f"✅ Saved {fname}")

        plt.close(fig)

    outer_labels = parsed['outer'].unique()
    if parallel:
        Parallel(n_jobs=max_threads)(delayed(plot_one_grid)(outer_label) for outer_label in outer_labels)
    else:
        for outer_label in outer_labels:
            plot_one_grid(outer_label)

    print("✅ Finished plotting all grids.")
|
|
@@ -6,13 +6,13 @@ from .calculate_coverage import calculate_coverage
|
|
|
6
6
|
from .calculate_position_Youden import calculate_position_Youden
|
|
7
7
|
from .calculate_read_length_stats import calculate_read_length_stats
|
|
8
8
|
from .clean_NaN import clean_NaN
|
|
9
|
+
from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
|
|
9
10
|
from .filter_converted_reads_on_methylation import filter_converted_reads_on_methylation
|
|
10
11
|
from .filter_reads_on_length import filter_reads_on_length
|
|
11
12
|
from .invert_adata import invert_adata
|
|
12
13
|
from .load_sample_sheet import load_sample_sheet
|
|
13
|
-
from .
|
|
14
|
-
from .
|
|
15
|
-
from .recipes import recipe_1_Kissiov_and_McKenna_2025, recipe_2_Kissiov_and_McKenna_2025
|
|
14
|
+
from .flag_duplicate_reads import flag_duplicate_reads
|
|
15
|
+
from .subsample_adata import subsample_adata
|
|
16
16
|
|
|
17
17
|
__all__ = [
|
|
18
18
|
"append_C_context",
|
|
@@ -23,12 +23,11 @@ __all__ = [
|
|
|
23
23
|
"calculate_position_Youden",
|
|
24
24
|
"calculate_read_length_stats",
|
|
25
25
|
"clean_NaN",
|
|
26
|
+
"filter_adata_by_nan_proportion",
|
|
26
27
|
"filter_converted_reads_on_methylation",
|
|
27
28
|
"filter_reads_on_length",
|
|
28
29
|
"invert_adata",
|
|
29
30
|
"load_sample_sheet",
|
|
30
|
-
"
|
|
31
|
-
"
|
|
32
|
-
"recipe_1_Kissiov_and_McKenna_2025",
|
|
33
|
-
"recipe_2_Kissiov_and_McKenna_2025"
|
|
31
|
+
"flag_duplicate_reads",
|
|
32
|
+
"subsample_adata"
|
|
34
33
|
]
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
## Conversion SMF Specific
|
|
4
4
|
# Read methylation QC
|
|
5
|
-
def append_C_context(adata, obs_column='Reference', use_consensus=False):
|
|
5
|
+
def append_C_context(adata, obs_column='Reference', use_consensus=False, native=False):
|
|
6
6
|
"""
|
|
7
7
|
Adds Cytosine context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.
|
|
8
8
|
|
|
@@ -10,14 +10,17 @@ def append_C_context(adata, obs_column='Reference', use_consensus=False):
|
|
|
10
10
|
adata (AnnData): The input adata object.
|
|
11
11
|
obs_column (str): The observation column in which to stratify on. Default is 'Reference', which should not be changed for most purposes.
|
|
12
12
|
use_consensus (bool): A truth statement indicating whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
|
|
13
|
-
|
|
13
|
+
native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions
|
|
14
14
|
|
|
15
15
|
Returns:
|
|
16
16
|
None
|
|
17
17
|
"""
|
|
18
18
|
import numpy as np
|
|
19
19
|
import anndata as ad
|
|
20
|
-
|
|
20
|
+
|
|
21
|
+
print('Adding Cytosine context based on reference FASTA sequence for sample')
|
|
22
|
+
|
|
23
|
+
site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C', 'any_C_site']
|
|
21
24
|
categories = adata.obs[obs_column].cat.categories
|
|
22
25
|
for cat in categories:
|
|
23
26
|
# Assess if the strand is the top or bottom strand converted
|
|
@@ -26,11 +29,20 @@ def append_C_context(adata, obs_column='Reference', use_consensus=False):
|
|
|
26
29
|
elif 'bottom' in cat:
|
|
27
30
|
strand = 'bottom'
|
|
28
31
|
|
|
29
|
-
if
|
|
30
|
-
|
|
32
|
+
if native:
|
|
33
|
+
basename = cat.split(f"_{strand}")[0]
|
|
34
|
+
if use_consensus:
|
|
35
|
+
sequence = adata.uns[f'{basename}_consensus_sequence']
|
|
36
|
+
else:
|
|
37
|
+
# This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
|
|
38
|
+
sequence = adata.uns[f'{basename}_FASTA_sequence']
|
|
31
39
|
else:
|
|
32
|
-
|
|
33
|
-
|
|
40
|
+
basename = cat.split(f"_{strand}")[0]
|
|
41
|
+
if use_consensus:
|
|
42
|
+
sequence = adata.uns[f'{basename}_consensus_sequence']
|
|
43
|
+
else:
|
|
44
|
+
# This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
|
|
45
|
+
sequence = adata.uns[f'{basename}_FASTA_sequence']
|
|
34
46
|
# Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
|
|
35
47
|
boolean_dict = {}
|
|
36
48
|
for site_type in site_types:
|
|
@@ -40,6 +52,7 @@ def append_C_context(adata, obs_column='Reference', use_consensus=False):
|
|
|
40
52
|
# Iterate through the sequence and apply the criteria
|
|
41
53
|
for i in range(1, len(sequence) - 1):
|
|
42
54
|
if sequence[i] == 'C':
|
|
55
|
+
boolean_dict[f'{cat}_any_C_site'][i] = True
|
|
43
56
|
if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
|
|
44
57
|
boolean_dict[f'{cat}_GpC_site'][i] = True
|
|
45
58
|
elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
|
|
@@ -52,6 +65,7 @@ def append_C_context(adata, obs_column='Reference', use_consensus=False):
|
|
|
52
65
|
# Iterate through the sequence and apply the criteria
|
|
53
66
|
for i in range(1, len(sequence) - 1):
|
|
54
67
|
if sequence[i] == 'G':
|
|
68
|
+
boolean_dict[f'{cat}_any_C_site'][i] = True
|
|
55
69
|
if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
|
|
56
70
|
boolean_dict[f'{cat}_GpC_site'][i] = True
|
|
57
71
|
elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
|
|
@@ -65,5 +79,4 @@ def append_C_context(adata, obs_column='Reference', use_consensus=False):
|
|
|
65
79
|
|
|
66
80
|
for site_type in site_types:
|
|
67
81
|
adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
|
|
68
|
-
adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].
|
|
69
|
-
|
|
82
|
+
adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
|