smftools-0.1.3-py3-none-any.whl → smftools-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/position_stats.py
@@ -0,0 +1,462 @@
+ def plot_volcano_relative_risk(
+     results_dict,
+     save_path=None,
+     highlight_regions=None,  # List of (start, end) tuples
+     highlight_color="lightgray",
+     highlight_alpha=0.3,
+     xlim=None,
+     ylim=None,
+ ):
+     """
+     Plot volcano-style log2(Relative Risk) vs Genomic Position for each group within each reference.
+ 
+     Parameters:
+         results_dict (dict): Output from calculate_relative_risk_by_group.
+             Format: dict[ref][group_label] = (results_df, sig_df)
+         save_path (str): Directory to save plots.
+         highlight_regions (list): List of (start, end) tuples for shaded regions.
+         highlight_color (str): Color for highlighted regions.
+         highlight_alpha (float): Alpha for highlighted region.
+         xlim (tuple): Optional x-axis limit.
+         ylim (tuple): Optional y-axis limit.
+     """
+     import matplotlib.pyplot as plt
+     import numpy as np
+     import os
+ 
+     for ref, group_results in results_dict.items():
+         for group_label, (results_df, _) in group_results.items():
+             if results_df.empty:
+                 print(f"Skipping empty results for {ref} / {group_label}")
+                 continue
+ 
+             # Split by site type
+             gpc_df = results_df[results_df['GpC_Site']]
+             cpg_df = results_df[results_df['CpG_Site']]
+ 
+             fig, ax = plt.subplots(figsize=(12, 6))
+ 
+             # Highlight regions
+             if highlight_regions:
+                 for start, end in highlight_regions:
+                     ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)
+ 
+             # GpC as circles
+             sc1 = ax.scatter(
+                 gpc_df['Genomic_Position'],
+                 gpc_df['log2_Relative_Risk'],
+                 c=gpc_df['-log10_Adj_P'],
+                 cmap='coolwarm',
+                 edgecolor='k',
+                 s=40,
+                 marker='o',
+                 label='GpC'
+             )
+ 
+             # CpG as stars
+             sc2 = ax.scatter(
+                 cpg_df['Genomic_Position'],
+                 cpg_df['log2_Relative_Risk'],
+                 c=cpg_df['-log10_Adj_P'],
+                 cmap='coolwarm',
+                 edgecolor='k',
+                 s=60,
+                 marker='*',
+                 label='CpG'
+             )
+ 
+             ax.axhline(y=0, color='gray', linestyle='--')
+             ax.set_xlabel("Genomic Position")
+             ax.set_ylabel("log2(Relative Risk)")
+             ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")
+ 
+             if xlim:
+                 ax.set_xlim(xlim)
+             if ylim:
+                 ax.set_ylim(ylim)
+ 
+             ax.spines['top'].set_visible(False)
+             ax.spines['right'].set_visible(False)
+ 
+             cbar = plt.colorbar(sc1, ax=ax)
+             cbar.set_label("-log10(Adjusted P-Value)")
+ 
+             ax.legend()
+             plt.tight_layout()
+ 
+             # Save if requested
+             if save_path:
+                 os.makedirs(save_path, exist_ok=True)
+                 safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
+                 out_file = os.path.join(save_path, f"{safe_name}.png")
+                 plt.savefig(out_file, dpi=300)
+                 print(f"📁 Saved: {out_file}")
+ 
+             plt.show()
+ 
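A minimal, self-contained usage sketch for the function above. The synthetic frame and the locus/group names are illustrative assumptions; in practice results_dict comes from calculate_relative_risk_by_group:

import numpy as np
import pandas as pd

# Build a toy results_dict with the columns the plotting code reads:
# 'Genomic_Position', 'log2_Relative_Risk', '-log10_Adj_P', 'GpC_Site', 'CpG_Site'.
rng = np.random.default_rng(0)
positions = np.arange(1000, 1500, 10)
results_df = pd.DataFrame({
    "Genomic_Position": positions,
    "log2_Relative_Risk": rng.normal(0, 1, positions.size),
    "-log10_Adj_P": rng.uniform(0, 5, positions.size),
    "GpC_Site": rng.random(positions.size) < 0.5,
})
results_df["CpG_Site"] = ~results_df["GpC_Site"]  # toy data: each site is exactly one type

results_dict = {"locus1_top": {"treated": (results_df, None)}}  # (results_df, sig_df) per the docstring
plot_volcano_relative_risk(results_dict, highlight_regions=[(1200, 1300)])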
+ def plot_bar_relative_risk(
+     results_dict,
+     sort_by_position=True,
+     xlim=None,
+     ylim=None,
+     save_path=None,
+     highlight_regions=None,  # List of (start, end) tuples
+     highlight_color="lightgray",
+     highlight_alpha=0.3
+ ):
+     """
+     Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.
+ 
+     Parameters:
+         results_dict (dict): Output from calculate_relative_risk_by_group.
+         sort_by_position (bool): Whether to sort bars left-to-right by genomic coordinate.
+         xlim, ylim (tuple): Axis limits.
+         save_path (str or None): Directory to save plots.
+         highlight_regions (list of tuple): List of (start, end) genomic regions to shade.
+         highlight_color (str): Color of shaded region.
+         highlight_alpha (float): Transparency of shaded region.
+     """
+     import matplotlib.pyplot as plt
+     import numpy as np
+     import os
+ 
+     for ref, group_data in results_dict.items():
+         for group_label, (df, _) in group_data.items():
+             if df.empty:
+                 print(f"Skipping empty result for {ref} / {group_label}")
+                 continue
+ 
+             df = df.copy()
+             df['Genomic_Position'] = df['Genomic_Position'].astype(int)
+ 
+             if sort_by_position:
+                 df = df.sort_values('Genomic_Position')
+ 
+             gpc_mask = df['GpC_Site'] & ~df['CpG_Site']
+             cpg_mask = df['CpG_Site'] & ~df['GpC_Site']
+             both_mask = df['GpC_Site'] & df['CpG_Site']
+ 
+             fig, ax = plt.subplots(figsize=(14, 6))
+ 
+             # Optional shaded regions
+             if highlight_regions:
+                 for start, end in highlight_regions:
+                     ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)
+ 
+             # Bar plots
+             ax.bar(
+                 df['Genomic_Position'][gpc_mask],
+                 df['log2_Relative_Risk'][gpc_mask],
+                 width=10,
+                 color='steelblue',
+                 label='GpC Site',
+                 edgecolor='black'
+             )
+ 
+             ax.bar(
+                 df['Genomic_Position'][cpg_mask],
+                 df['log2_Relative_Risk'][cpg_mask],
+                 width=10,
+                 color='darkorange',
+                 label='CpG Site',
+                 edgecolor='black'
+             )
+ 
+             if both_mask.any():
+                 ax.bar(
+                     df['Genomic_Position'][both_mask],
+                     df['log2_Relative_Risk'][both_mask],
+                     width=10,
+                     color='purple',
+                     label='GpC + CpG',
+                     edgecolor='black'
+                 )
+ 
+             ax.axhline(y=0, color='gray', linestyle='--')
+             ax.set_xlabel('Genomic Position')
+             ax.set_ylabel('log2(Relative Risk)')
+             ax.set_title(f"{ref} — {group_label}")
+             ax.legend()
+ 
+             if xlim:
+                 ax.set_xlim(xlim)
+             if ylim:
+                 ax.set_ylim(ylim)
+ 
+             ax.spines['top'].set_visible(False)
+             ax.spines['right'].set_visible(False)
+ 
+             plt.tight_layout()
+ 
+             if save_path:
+                 os.makedirs(save_path, exist_ok=True)
+                 safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
+                 out_file = os.path.join(save_path, f"{safe_name}.png")
+                 plt.savefig(out_file, dpi=300)
+                 print(f"📁 Saved: {out_file}")
+ 
+             plt.show()
+ 
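The same toy results_dict from the sketch above works here as well; sites flagged as both GpC and CpG would get their own purple bars:

plot_bar_relative_risk(results_dict, ylim=(-3, 3))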
+ def plot_positionwise_matrix(
+     adata,
+     key="positionwise_result",
+     log_transform=False,
+     log_base="log1p",  # or 'log2', or None
+     triangle="full",
+     cmap="vlag",
+     figsize=(12, 10),  # Taller to accommodate line plot below
+     vmin=None,
+     vmax=None,
+     xtick_step=10,
+     ytick_step=10,
+     save_path=None,
+     highlight_position=None,  # Can be a single int/float or list of them
+     highlight_axis="row",  # "row" or "column"
+     annotate_points=False  # ✅ New option
+ ):
+     """
+     Plots positionwise matrices stored in adata.uns[key], with an optional line plot
+     for specified row(s) or column(s), and highlights them on the heatmap.
+     """
+     import matplotlib.pyplot as plt
+     import seaborn as sns
+     import numpy as np
+     import pandas as pd
+     import os
+ 
+     def find_closest_index(index, target):
+         index_vals = pd.to_numeric(index, errors="coerce")
+         target_val = pd.to_numeric([target], errors="coerce")[0]
+         diffs = pd.Series(np.abs(index_vals - target_val), index=index)
+         return diffs.idxmin()
+ 
+     # Ensure highlight_position is a list
+     if highlight_position is not None and not isinstance(highlight_position, (list, tuple, np.ndarray)):
+         highlight_position = [highlight_position]
+ 
+     for group, mat_df in adata.uns[key].items():
+         mat = mat_df.copy()
+ 
+         if log_transform:
+             with np.errstate(divide='ignore', invalid='ignore'):
+                 if log_base == "log1p":
+                     mat = np.log1p(mat)
+                 elif log_base == "log2":
+                     mat = np.log2(mat.replace(0, np.nanmin(mat[mat > 0]) * 0.1))
+             mat.replace([np.inf, -np.inf], np.nan, inplace=True)
+ 
+         # Set color limits for log2 to be centered around 0
+         if log_base == "log2" and log_transform and (vmin is None or vmax is None):
+             abs_max = np.nanmax(np.abs(mat.values))
+             vmin = -abs_max if vmin is None else vmin
+             vmax = abs_max if vmax is None else vmax
+ 
+         # Create mask for triangle
+         mask = None
+         if triangle == "lower":
+             mask = np.triu(np.ones_like(mat, dtype=bool), k=1)
+         elif triangle == "upper":
+             mask = np.tril(np.ones_like(mat, dtype=bool), k=-1)
+ 
+         xticks = mat.columns.astype(int)
+         yticks = mat.index.astype(int)
+ 
+         # 👉 Make taller figure: heatmap on top, line plot below
+         fig, axs = plt.subplots(2, 1, figsize=figsize, height_ratios=[3, 1.5])
+         heat_ax, line_ax = axs
+ 
+         # Heatmap
+         sns.heatmap(
+             mat,
+             mask=mask,
+             cmap=cmap,
+             xticklabels=xticks,
+             yticklabels=yticks,
+             square=True,
+             vmin=vmin,
+             vmax=vmax,
+             cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
+             ax=heat_ax
+         )
+ 
+         heat_ax.set_title(f"{key} — {group}", pad=20)
+         heat_ax.set_xticks(np.arange(0, len(xticks), xtick_step))
+         heat_ax.set_xticklabels(xticks[::xtick_step], rotation=90)
+         heat_ax.set_yticks(np.arange(0, len(yticks), ytick_step))
+         heat_ax.set_yticklabels(yticks[::ytick_step])
+ 
+         # Line plot
+         if highlight_position is not None:
+             colors = plt.cm.tab10.colors
+             for i, pos in enumerate(highlight_position):
+                 try:
+                     if highlight_axis == "row":
+                         closest = find_closest_index(mat.index, pos)
+                         series = mat.loc[closest]
+                         x_vals = pd.to_numeric(series.index, errors="coerce")
+                         idx = mat.index.get_loc(closest)
+                         heat_ax.axhline(idx, color=colors[i % len(colors)], linestyle="--", linewidth=1)
+                         label = f"Row {pos} → {closest}"
+                     else:
+                         closest = find_closest_index(mat.columns, pos)
+                         series = mat[closest]
+                         x_vals = pd.to_numeric(series.index, errors="coerce")
+                         idx = mat.columns.get_loc(closest)
+                         heat_ax.axvline(idx, color=colors[i % len(colors)], linestyle="--", linewidth=1)
+                         label = f"Col {pos} → {closest}"
+ 
+                     line = line_ax.plot(x_vals, series.values, marker='o', label=label, color=colors[i % len(colors)])
+ 
+                     # Annotate each point
+                     if annotate_points:
+                         for x, y in zip(x_vals, series.values):
+                             if not np.isnan(y):
+                                 line_ax.annotate(
+                                     f"{y:.2f}",
+                                     xy=(x, y),
+                                     textcoords="offset points",
+                                     xytext=(0, 5),
+                                     ha='center',
+                                     fontsize=8
+                                 )
+                 except Exception as e:
+                     line_ax.text(0.5, 0.5, f"⚠️ Error plotting {highlight_axis} @ {pos}",
+                                  ha='center', va='center', fontsize=10)
+                     print(f"Error plotting line for {highlight_axis}={pos}: {e}")
+ 
+             line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
+             line_ax.set_xlabel(f"{'Column' if highlight_axis == 'row' else 'Row'} position")
+             line_ax.set_ylabel("Value")
+             line_ax.grid(True)
+             line_ax.legend(fontsize=8)
+ 
+         plt.tight_layout()
+ 
+         # Save if requested
+         if save_path:
+             os.makedirs(save_path, exist_ok=True)
+             safe_name = group.replace("=", "").replace("__", "_").replace(",", "_")
+             out_file = os.path.join(save_path, f"{key}_{safe_name}.png")
+             plt.savefig(out_file, dpi=300)
+             print(f"📁 Saved: {out_file}")
+ 
+         plt.show()
+ 
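A sketch of the expected input layout, assuming adata.uns[key] maps a group label to a square position-by-position DataFrame (synthetic matrix and labels; the real matrices come from the package's positionwise tools):

import anndata as ad
import numpy as np
import pandas as pd

positions = list(range(0, 100, 10))
mat = pd.DataFrame(np.random.rand(len(positions), len(positions)),
                   index=positions, columns=positions)

adata = ad.AnnData(X=np.zeros((1, len(positions))))  # minimal AnnData; only .uns is used here
adata.uns["positionwise_result"] = {"locus1_top_treated": mat}

# Draw the heatmap plus a line profile for the row nearest position 45.
plot_positionwise_matrix(adata, key="positionwise_result",
                         highlight_position=45, highlight_axis="row")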
+ def plot_positionwise_matrix_grid(
+     adata,
+     key,
+     outer_keys=["Reference_strand", "activity_status"],
+     inner_keys=["Promoter_Open", "Enhancer_Open"],
+     log_transform=None,
+     vmin=None,
+     vmax=None,
+     cmap="vlag",
+     save_path=None,
+     figsize=(10, 10),
+     xtick_step=10,
+     ytick_step=10,
+     parallel=False,
+     max_threads=None
+ ):
+     import matplotlib.pyplot as plt
+     import seaborn as sns
+     import numpy as np
+     import pandas as pd
+     import os
+     from matplotlib.gridspec import GridSpec
+     from joblib import Parallel, delayed
+ 
+     matrices = adata.uns[key]
+     group_labels = list(matrices.keys())
+ 
+     parsed_inner = pd.DataFrame([dict(zip(inner_keys, g.split("_")[-len(inner_keys):])) for g in group_labels])
+     parsed_outer = pd.Series(["_".join(g.split("_")[:-len(inner_keys)]) for g in group_labels], name="outer")
+     parsed = pd.concat([parsed_outer, parsed_inner], axis=1)
+ 
+     def plot_one_grid(outer_label):
+         selected = parsed[parsed['outer'] == outer_label].copy()
+         selected["group_str"] = [f"{outer_label}_{row[inner_keys[0]]}_{row[inner_keys[1]]}" for _, row in selected.iterrows()]
+ 
+         row_vals = sorted(selected[inner_keys[0]].unique())
+         col_vals = sorted(selected[inner_keys[1]].unique())
+ 
+         fig = plt.figure(figsize=figsize)
+         gs = GridSpec(len(row_vals), len(col_vals) + 1, width_ratios=[1]*len(col_vals) + [0.05], wspace=0.3)
+         axes = np.empty((len(row_vals), len(col_vals)), dtype=object)
+ 
+         local_vmin, local_vmax = vmin, vmax
+         if log_transform == "log2" and (vmin is None or vmax is None):
+             all_data = []
+             for group_str in selected["group_str"]:
+                 mat = matrices.get(group_str)
+                 if mat is not None:
+                     all_data.append(np.log2(mat.replace(0, 1e-9).values))
+             if all_data:
+                 combined = np.concatenate([arr.flatten() for arr in all_data])
+                 vmax_auto = np.nanmax(np.abs(combined))
+                 local_vmin = -vmax_auto if vmin is None else vmin
+                 local_vmax = vmax_auto if vmax is None else vmax
+ 
+         cbar_label = {
+             "log2": "log2(Value)",
+             "log1p": "log1p(Value)"
+         }.get(log_transform, "Value")
+ 
+         cbar_ax = fig.add_subplot(gs[:, -1])
+ 
+         for i, row_val in enumerate(row_vals):
+             for j, col_val in enumerate(col_vals):
+                 group_label = f"{outer_label}_{row_val}_{col_val}"
+                 ax = fig.add_subplot(gs[i, j])
+                 axes[i, j] = ax
+                 mat = matrices.get(group_label)
+                 if mat is None:
+                     ax.axis("off")
+                     continue
+ 
+                 data = mat.copy()
+                 if log_transform == "log2":
+                     data = np.log2(data.replace(0, 1e-9))
+                 elif log_transform == "log1p":
+                     data = np.log1p(data)
+ 
+                 sns.heatmap(
+                     data,
+                     ax=ax,
+                     cmap=cmap,
+                     xticklabels=True,
+                     yticklabels=True,
+                     square=True,
+                     vmin=local_vmin,
+                     vmax=local_vmax,
+                     cbar=(i == 0 and j == 0),
+                     cbar_ax=cbar_ax if (i == 0 and j == 0) else None,
+                     cbar_kws={"label": cbar_label if (i == 0 and j == 0) else ""}
+                 )
+                 ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)
+ 
+                 xticks = data.columns.astype(int)
+                 yticks = data.index.astype(int)
+                 ax.set_xticks(np.arange(0, len(xticks), xtick_step))
+                 ax.set_xticklabels(xticks[::xtick_step], rotation=90)
+                 ax.set_yticks(np.arange(0, len(yticks), ytick_step))
+                 ax.set_yticklabels(yticks[::ytick_step])
+ 
+         fig.suptitle(f"{key} • {outer_label}", fontsize=14, y=1.02)
+         fig.tight_layout(rect=[0, 0, 0.97, 0.95])
+ 
+         if save_path:
+             os.makedirs(save_path, exist_ok=True)
+             fname = outer_label.replace("_", "").replace("=", "") + ".png"
+             plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
+             print(f"✅ Saved {fname}")
+ 
+         plt.close(fig)
+ 
+     if parallel:
+         Parallel(n_jobs=max_threads)(delayed(plot_one_grid)(outer_label) for outer_label in parsed['outer'].unique())
+     else:
+         for outer_label in parsed['outer'].unique():
+             plot_one_grid(outer_label)
+ 
+     print("✅ Finished plotting all grids.")
smftools/preprocessing/__init__.py
@@ -6,13 +6,13 @@ from .calculate_coverage import calculate_coverage
  from .calculate_position_Youden import calculate_position_Youden
  from .calculate_read_length_stats import calculate_read_length_stats
  from .clean_NaN import clean_NaN
+ from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
  from .filter_converted_reads_on_methylation import filter_converted_reads_on_methylation
  from .filter_reads_on_length import filter_reads_on_length
  from .invert_adata import invert_adata
  from .load_sample_sheet import load_sample_sheet
- from .mark_duplicates import mark_duplicates
- from .remove_duplicates import remove_duplicates
- from .recipes import recipe_1_Kissiov_and_McKenna_2025, recipe_2_Kissiov_and_McKenna_2025
+ from .flag_duplicate_reads import flag_duplicate_reads
+ from .subsample_adata import subsample_adata
  
  __all__ = [
      "append_C_context",
@@ -23,12 +23,11 @@ __all__ = [
      "calculate_position_Youden",
      "calculate_read_length_stats",
      "clean_NaN",
+     "filter_adata_by_nan_proportion",
      "filter_converted_reads_on_methylation",
      "filter_reads_on_length",
      "invert_adata",
      "load_sample_sheet",
-     "mark_duplicates",
-     "remove_duplicates",
-     "recipe_1_Kissiov_and_McKenna_2025",
-     "recipe_2_Kissiov_and_McKenna_2025"
+     "flag_duplicate_reads",
+     "subsample_adata"
  ]
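The net effect of these two hunks on the public preprocessing API (an import check against 0.1.7; the recipes module still ships per the file list above, it is just no longer re-exported here):

from smftools.preprocessing import (
    filter_adata_by_nan_proportion,  # new export in 0.1.7
    flag_duplicate_reads,            # replaces mark_duplicates / remove_duplicates
    subsample_adata,                 # new export in 0.1.7
)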
smftools/preprocessing/append_C_context.py
@@ -2,7 +2,7 @@
  
  ## Conversion SMF Specific
  # Read methylation QC
- def append_C_context(adata, obs_column='Reference', use_consensus=False):
+ def append_C_context(adata, obs_column='Reference', use_consensus=False, native=False):
      """
      Adds Cytosine context to the position within the given category. When use_consensus is True, it uses the consensus sequence, otherwise it defaults to the FASTA sequence.
  
@@ -10,14 +10,17 @@
          adata (AnnData): The input adata object.
          obs_column (str): The observation column in which to stratify on. Default is 'Reference', which should not be changed for most purposes.
          use_consensus (bool): A truth statement indicating whether to use the consensus sequence from the reads mapped to the reference. If False, the reference FASTA is used instead.
-         Input: An adata object, the obs_column of interst, and whether to use the consensus sequence from the category.
+         native (bool): If False, perform conversion SMF assumptions. If True, perform native SMF assumptions
  
      Returns:
          None
      """
      import numpy as np
      import anndata as ad
-     site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C']
+ 
+     print('Adding Cytosine context based on reference FASTA sequence for sample')
+ 
+     site_types = ['GpC_site', 'CpG_site', 'ambiguous_GpC_CpG_site', 'other_C', 'any_C_site']
      categories = adata.obs[obs_column].cat.categories
      for cat in categories:
          # Assess if the strand is the top or bottom strand converted
@@ -26,11 +29,20 @@
          elif 'bottom' in cat:
              strand = 'bottom'
  
-         if use_consensus:
-             sequence = adata.uns[f'{cat}_consensus_sequence']
+         if native:
+             basename = cat.split(f"_{strand}")[0]
+             if use_consensus:
+                 sequence = adata.uns[f'{basename}_consensus_sequence']
+             else:
+                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
+                 sequence = adata.uns[f'{basename}_FASTA_sequence']
          else:
-             # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
-             sequence = adata.uns[f'{cat}_FASTA_sequence']
+             basename = cat.split(f"_{strand}")[0]
+             if use_consensus:
+                 sequence = adata.uns[f'{basename}_consensus_sequence']
+             else:
+                 # This sequence is the unconverted FASTA sequence of the original input FASTA for the locus
+                 sequence = adata.uns[f'{basename}_FASTA_sequence']
          # Init a dict keyed by reference site type that points to a bool of whether the position is that site type.
          boolean_dict = {}
          for site_type in site_types:
@@ -40,6 +52,7 @@
              # Iterate through the sequence and apply the criteria
              for i in range(1, len(sequence) - 1):
                  if sequence[i] == 'C':
+                     boolean_dict[f'{cat}_any_C_site'][i] = True
                      if sequence[i - 1] == 'G' and sequence[i + 1] != 'G':
                          boolean_dict[f'{cat}_GpC_site'][i] = True
                      elif sequence[i - 1] == 'G' and sequence[i + 1] == 'G':
@@ -52,6 +65,7 @@
              # Iterate through the sequence and apply the criteria
              for i in range(1, len(sequence) - 1):
                  if sequence[i] == 'G':
+                     boolean_dict[f'{cat}_any_C_site'][i] = True
                      if sequence[i + 1] == 'C' and sequence[i - 1] != 'C':
                          boolean_dict[f'{cat}_GpC_site'][i] = True
                      elif sequence[i - 1] == 'C' and sequence[i + 1] == 'C':
@@ -65,5 +79,4 @@
  
          for site_type in site_types:
              adata.var[f'{cat}_{site_type}'] = boolean_dict[f'{cat}_{site_type}'].astype(bool)
-             adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].copy().X
-
+             adata.obsm[f'{cat}_{site_type}'] = adata[:, adata.var[f'{cat}_{site_type}'] == True].X
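A quick sanity check of the top-strand rules visible in these hunks: every interior C is now also flagged any_C_site, G·C·non-G is a GpC site, and G·C·G is ambiguous. The CpG and other_C branches fall outside the hunk context, so this toy sketch mirrors only what is shown:

seq = "TGCAGCGT"
for i in range(1, len(seq) - 1):
    if seq[i] == 'C':  # every interior C would set any_C_site
        gpc = seq[i - 1] == 'G' and seq[i + 1] != 'G'
        ambiguous = seq[i - 1] == 'G' and seq[i + 1] == 'G'
        print(i, seq[i - 1:i + 2], "GpC" if gpc else "ambiguous" if ambiguous else "other")
# prints: 2 GCA GpC
#         5 GCG ambiguous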