smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,462 @@
1
def plot_volcano_relative_risk(
    results_dict,
    save_path=None,
    highlight_regions=None,  # List of (start, end) tuples
    highlight_color="lightgray",
    highlight_alpha=0.3,
    xlim=None,
    ylim=None,
):
    """
    Plot volcano-style log2(Relative Risk) vs Genomic Position for each group within each reference.

    One figure is drawn per (reference, group) pair. GpC sites are drawn as
    circles and CpG sites as stars. Both are coloured by '-log10_Adj_P' on a
    single shared colour scale, so the one colourbar is valid for both
    marker types.

    Parameters:
        results_dict (dict): Output from calculate_relative_risk_by_group.
            Format: dict[ref][group_label] = (results_df, sig_df).
            Only results_df is used. It must contain the columns 'GpC_Site',
            'CpG_Site', 'Genomic_Position', 'log2_Relative_Risk' and
            '-log10_Adj_P'.
        save_path (str): Directory to save plots (created if missing).
        highlight_regions (list): List of (start, end) tuples for shaded regions.
        highlight_color (str): Color for highlighted regions.
        highlight_alpha (float): Alpha for highlighted region.
        xlim (tuple): Optional x-axis limit.
        ylim (tuple): Optional y-axis limit.

    Returns:
        None. Each figure is shown and optionally saved as '<ref>_<group>.png'.
        It is then closed so figures do not accumulate across groups.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import os

    for ref, group_results in results_dict.items():
        for group_label, (results_df, _) in group_results.items():
            if results_df.empty:
                print(f"Skipping empty results for {ref} / {group_label}")
                continue

            # Comparing with True yields a clean boolean mask even for
            # object-dtype flag columns or columns containing NaN.
            gpc_mask = results_df['GpC_Site'].eq(True)
            cpg_mask = results_df['CpG_Site'].eq(True)
            gpc_df = results_df[gpc_mask]
            cpg_df = results_df[cpg_mask]

            # Use one colour range for both scatters. Otherwise each scatter
            # autoscales on its own, and the colourbar (built from the GpC
            # scatter) does not describe the CpG points.
            scores = results_df.loc[gpc_mask | cpg_mask, '-log10_Adj_P'].to_numpy(dtype=float)
            scores = scores[np.isfinite(scores)]
            if scores.size:
                color_min, color_max = scores.min(), scores.max()
            else:
                color_min = color_max = None

            fig, ax = plt.subplots(figsize=(12, 6))

            # Highlight regions
            if highlight_regions:
                for start, end in highlight_regions:
                    ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)

            # GpC as circles
            sc1 = ax.scatter(
                gpc_df['Genomic_Position'],
                gpc_df['log2_Relative_Risk'],
                c=gpc_df['-log10_Adj_P'],
                cmap='coolwarm',
                vmin=color_min,
                vmax=color_max,
                edgecolor='k',
                s=40,
                marker='o',
                label='GpC',
            )

            # CpG as stars
            ax.scatter(
                cpg_df['Genomic_Position'],
                cpg_df['log2_Relative_Risk'],
                c=cpg_df['-log10_Adj_P'],
                cmap='coolwarm',
                vmin=color_min,
                vmax=color_max,
                edgecolor='k',
                s=60,
                marker='*',
                label='CpG',
            )

            ax.axhline(y=0, color='gray', linestyle='--')
            ax.set_xlabel("Genomic Position")
            ax.set_ylabel("log2(Relative Risk)")
            ax.set_title(f"{ref} / {group_label} — Relative Risk vs Genomic Position")

            if xlim:
                ax.set_xlim(xlim)
            if ylim:
                ax.set_ylim(ylim)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            # Without any finite score there is no colour scale to describe.
            if color_min is not None:
                cbar = fig.colorbar(sc1, ax=ax)
                cbar.set_label("-log10(Adjusted P-Value)")

            ax.legend()
            fig.tight_layout()

            # Save if requested
            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
                out_file = os.path.join(save_path, f"{safe_name}.png")
                fig.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
            # On non-interactive backends show() does not release the figure.
            plt.close(fig)
96
+
97
def plot_bar_relative_risk(
    results_dict,
    sort_by_position=True,
    xlim=None,
    ylim=None,
    save_path=None,
    highlight_regions=None,  # List of (start, end) tuples
    highlight_color="lightgray",
    highlight_alpha=0.3,
    bar_width=10,
):
    """
    Plot log2(Relative Risk) as a bar plot across genomic positions for each group within each reference.

    Sites are split into three mutually exclusive categories:
    GpC-only (steelblue), CpG-only (darkorange) and GpC + CpG (purple).
    The GpC + CpG series is drawn only when such sites exist.

    Parameters:
        results_dict (dict): Output from calculate_relative_risk_by_group.
            Format: dict[ref][group_label] = (results_df, sig_df).
            Only results_df is used. It must contain 'Genomic_Position',
            'log2_Relative_Risk', 'GpC_Site' and 'CpG_Site'.
        sort_by_position (bool): Whether to sort bars left-to-right by genomic coordinate.
        xlim, ylim (tuple): Axis limits.
        save_path (str or None): Directory to save plots (created if missing).
        highlight_regions (list of tuple): List of (start, end) genomic regions to shade.
        highlight_color (str): Color of shaded region.
        highlight_alpha (float): Transparency of shaded region.
        bar_width (float): Width of each bar in genomic-position units (default 10).

    Returns:
        None. Each figure is shown and optionally saved. It is then closed.
    """
    import matplotlib.pyplot as plt
    import os

    def draw_bars(ax, df, mask, color, label):
        """Draw one bar series for the rows of ``df`` selected by ``mask``."""
        ax.bar(
            df.loc[mask, 'Genomic_Position'],
            df.loc[mask, 'log2_Relative_Risk'],
            width=bar_width,
            color=color,
            label=label,
            edgecolor='black',
        )

    for ref, group_data in results_dict.items():
        for group_label, (df, _) in group_data.items():
            if df.empty:
                print(f"Skipping empty result for {ref} / {group_label}")
                continue

            df = df.copy()
            df['Genomic_Position'] = df['Genomic_Position'].astype(int)

            if sort_by_position:
                df = df.sort_values('Genomic_Position')

            # .eq(True) gives a real boolean mask. Applying '~' to an
            # object-dtype column of Python bools would yield -1/-2 integers.
            is_gpc = df['GpC_Site'].eq(True)
            is_cpg = df['CpG_Site'].eq(True)
            gpc_mask = is_gpc & ~is_cpg
            cpg_mask = is_cpg & ~is_gpc
            both_mask = is_gpc & is_cpg

            fig, ax = plt.subplots(figsize=(14, 6))

            # Optional shaded regions
            if highlight_regions:
                for start, end in highlight_regions:
                    ax.axvspan(start, end, color=highlight_color, alpha=highlight_alpha)

            # GpC-only and CpG-only series are always drawn, so both appear
            # in the legend. The overlap series is drawn only if it is non-empty.
            draw_bars(ax, df, gpc_mask, 'steelblue', 'GpC Site')
            draw_bars(ax, df, cpg_mask, 'darkorange', 'CpG Site')
            if both_mask.any():
                draw_bars(ax, df, both_mask, 'purple', 'GpC + CpG')

            ax.axhline(y=0, color='gray', linestyle='--')
            ax.set_xlabel('Genomic Position')
            ax.set_ylabel('log2(Relative Risk)')
            ax.set_title(f"{ref} — {group_label}")
            ax.legend()

            if xlim:
                ax.set_xlim(xlim)
            if ylim:
                ax.set_ylim(ylim)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            fig.tight_layout()

            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = f"{ref}_{group_label}".replace("=", "").replace("__", "_").replace(",", "_")
                out_file = os.path.join(save_path, f"{safe_name}.png")
                fig.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
            # On non-interactive backends show() does not release the figure.
            plt.close(fig)
199
+
200
def plot_positionwise_matrix(
    adata,
    key="positionwise_result",
    log_transform=False,
    log_base="log1p",  # or 'log2', or None
    triangle="full",
    cmap="vlag",
    figsize=(12, 10),  # Taller to accommodate line plot below
    vmin=None,
    vmax=None,
    xtick_step=10,
    ytick_step=10,
    save_path=None,
    highlight_position=None,  # Can be a single int/float or list of them
    highlight_axis="row",  # "row" or "column"
    annotate_points=False,  # Label each line-plot point with its value
):
    """
    Plot the positionwise matrices stored in ``adata.uns[key]``.

    Each ``group -> DataFrame`` entry gets its own figure: a heatmap on top
    and, when ``highlight_position`` is given, a line plot underneath of the
    selected row(s) or column(s). Each selection is also marked on the
    heatmap with a dashed line through the matching row/column.

    Parameters:
        adata: Object whose ``uns[key]`` is a dict mapping group labels to
            DataFrames with index/columns castable to int.
        key (str): Entry in ``adata.uns`` holding the matrices.
        log_transform (bool): Whether to log-transform the values.
        log_base (str): "log1p" or "log2".
            For "log2", zeros are replaced by 0.1 x the smallest positive
            value. The colour scale is then centred on 0 unless vmin/vmax
            are given.
        triangle (str): "full", "lower" or "upper" part of the matrix to show.
        cmap (str): Heatmap colormap.
        figsize (tuple): Figure size.
        vmin, vmax (float): Colour limits. None means automatic, evaluated
            separately for every group.
        xtick_step, ytick_step (int): Label every n-th column/row.
        save_path (str): Directory in which to save '<key>_<group>.png'.
        highlight_position: A position or list of positions to profile.
            The numerically closest index/column label is used.
        highlight_axis (str): "row" profiles rows; any other value profiles columns.
        annotate_points (bool): Write each value above its line-plot point.

    Returns:
        None. Figures are shown, optionally saved, and then closed.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os

    def find_closest_index(index, target):
        """Return the label in ``index`` whose numeric value is closest to ``target``."""
        index_vals = pd.to_numeric(index, errors="coerce")
        target_val = pd.to_numeric([target], errors="coerce")[0]
        diffs = pd.Series(np.abs(index_vals - target_val), index=index)
        return diffs.idxmin()

    # Ensure highlight_position is a list
    if highlight_position is not None and not isinstance(highlight_position, (list, tuple, np.ndarray)):
        highlight_position = [highlight_position]

    colors = plt.cm.tab10.colors

    for group, mat_df in adata.uns[key].items():
        mat = mat_df.copy()

        if log_transform:
            with np.errstate(divide='ignore', invalid='ignore'):
                if log_base == "log1p":
                    mat = np.log1p(mat)
                elif log_base == "log2":
                    mat = np.log2(mat.replace(0, np.nanmin(mat[mat > 0]) * 0.1))
            mat.replace([np.inf, -np.inf], np.nan, inplace=True)

        # Per-group colour limits. The caller's vmin/vmax are never
        # overwritten, so one group's automatic limits cannot leak into the
        # following groups.
        group_vmin, group_vmax = vmin, vmax
        if log_base == "log2" and log_transform and (vmin is None or vmax is None):
            abs_max = np.nanmax(np.abs(mat.values))
            group_vmin = -abs_max if vmin is None else vmin
            group_vmax = abs_max if vmax is None else vmax

        # Create mask for triangle (True cells are hidden by seaborn)
        mask = None
        if triangle == "lower":
            mask = np.triu(np.ones(mat.shape, dtype=bool), k=1)
        elif triangle == "upper":
            mask = np.tril(np.ones(mat.shape, dtype=bool), k=-1)

        xticks = mat.columns.astype(int)
        yticks = mat.index.astype(int)

        # Heatmap on top, line plot below. gridspec_kw also works on
        # matplotlib versions older than 3.6 (no height_ratios kwarg there).
        fig, (heat_ax, line_ax) = plt.subplots(
            2, 1, figsize=figsize, gridspec_kw={"height_ratios": [3, 1.5]}
        )

        sns.heatmap(
            mat,
            mask=mask,
            cmap=cmap,
            xticklabels=xticks,
            yticklabels=yticks,
            square=True,
            vmin=group_vmin,
            vmax=group_vmax,
            cbar_kws={"label": f"{key} ({log_base})" if log_transform else key},
            ax=heat_ax,
        )

        heat_ax.set_title(f"{key} — {group}", pad=20)
        # seaborn draws cell i over [i, i + 1]; +0.5 centres each tick on its cell.
        heat_ax.set_xticks(np.arange(0, len(xticks), xtick_step) + 0.5)
        heat_ax.set_xticklabels(xticks[::xtick_step], rotation=90)
        heat_ax.set_yticks(np.arange(0, len(yticks), ytick_step) + 0.5)
        heat_ax.set_yticklabels(yticks[::ytick_step])

        # Line plot
        if highlight_position is not None:
            for i, pos in enumerate(highlight_position):
                color = colors[i % len(colors)]
                try:
                    if highlight_axis == "row":
                        closest = find_closest_index(mat.index, pos)
                        series = mat.loc[closest]
                        idx = mat.index.get_loc(closest)
                        heat_ax.axhline(idx + 0.5, color=color, linestyle="--", linewidth=1)
                        label = f"Row {pos} → {closest}"
                    else:
                        closest = find_closest_index(mat.columns, pos)
                        series = mat[closest]
                        idx = mat.columns.get_loc(closest)
                        heat_ax.axvline(idx + 0.5, color=color, linestyle="--", linewidth=1)
                        label = f"Col {pos} → {closest}"

                    x_vals = pd.to_numeric(series.index, errors="coerce")
                    line_ax.plot(x_vals, series.values, marker='o', label=label, color=color)

                    # Annotate each point
                    if annotate_points:
                        for x, y in zip(x_vals, series.values):
                            if pd.notna(y):
                                line_ax.annotate(
                                    f"{y:.2f}",
                                    xy=(x, y),
                                    textcoords="offset points",
                                    xytext=(0, 5),
                                    ha='center',
                                    fontsize=8,
                                )
                except Exception as e:
                    # Best effort: one bad position should not abort the figure.
                    # Axes coordinates keep the message centred in the panel.
                    line_ax.text(0.5, 0.5, f"⚠️ Error plotting {highlight_axis} @ {pos}",
                                 ha='center', va='center', fontsize=10,
                                 transform=line_ax.transAxes)
                    print(f"Error plotting line for {highlight_axis}={pos}: {e}")

            line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
            line_ax.set_xlabel(f"{'Column' if highlight_axis == 'row' else 'Row'} position")
            line_ax.set_ylabel("Value")
            line_ax.grid(True)
            line_ax.legend(fontsize=8)
        else:
            # Nothing to profile: hide the empty lower panel.
            line_ax.axis("off")

        fig.tight_layout()

        # Save if requested
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            safe_name = str(group).replace("=", "").replace("__", "_").replace(",", "_")
            out_file = os.path.join(save_path, f"{key}_{safe_name}.png")
            fig.savefig(out_file, dpi=300)
            print(f"📁 Saved: {out_file}")

        plt.show()
        # On non-interactive backends show() does not release the figure.
        plt.close(fig)
344
+
345
def plot_positionwise_matrix_grid(
    adata,
    key,
    outer_keys=("Reference_strand", "activity_status"),
    inner_keys=("Promoter_Open", "Enhancer_Open"),
    log_transform=None,
    vmin=None,
    vmax=None,
    cmap="vlag",
    save_path=None,
    figsize=(10, 10),
    xtick_step=10,
    ytick_step=10,
    parallel=False,
    max_threads=None,
):
    """
    Plot the matrices in ``adata.uns[key]`` as one grid figure per outer group.

    Group labels are split on "_". The last ``len(inner_keys)`` fields are
    the inner values; everything before them forms the outer label. For
    each outer label a grid is drawn with rows for the values of
    ``inner_keys[0]`` and columns for the values of ``inner_keys[1]``. All
    panels of a grid share one colour scale and one colourbar.

    Parameters:
        adata: Object whose ``uns[key]`` maps group labels to DataFrames
            with int-castable index/columns.
        key (str): Entry in ``adata.uns`` holding the matrices.
        outer_keys (sequence): Not read by this function; kept for API
            compatibility.
        inner_keys (sequence of 2 str): Names of the two trailing label fields.
        log_transform (str or None): "log2" (zeros replaced by 1e-9),
            "log1p", or None for raw values.
        vmin, vmax (float): Colour limits. Missing limits are derived from
            all panels of a grid (symmetric around 0 for "log2").
        cmap (str): Heatmap colormap.
        save_path (str): Directory in which to save '<outer label>.png',
            with underscores and '=' stripped from the name.
            NOTE(review): the file name does not include ``key``, so grids
            for different keys with the same outer label overwrite each other.
        figsize (tuple): Size of each grid figure.
        xtick_step, ytick_step (int): Label every n-th column/row.
        parallel (bool): Render grids via joblib.Parallel. joblib is only
            imported in that case.
        max_threads (int): ``n_jobs`` passed to joblib.Parallel.

    Returns:
        None. Figures are optionally saved and always closed (not shown).
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os
    from matplotlib.gridspec import GridSpec

    matrices = adata.uns[key]
    group_labels = list(matrices.keys())
    n_inner = len(inner_keys)

    parsed_inner = pd.DataFrame([dict(zip(inner_keys, g.split("_")[-n_inner:])) for g in group_labels])
    parsed_outer = pd.Series(["_".join(g.split("_")[:-n_inner]) for g in group_labels], name="outer")
    parsed = pd.concat([parsed_outer, parsed_inner], axis=1)

    cbar_label = {
        "log2": "log2(Value)",
        "log1p": "log1p(Value)",
    }.get(log_transform, "Value")

    def transform(mat):
        """Return a copy of ``mat`` transformed according to ``log_transform``."""
        if log_transform == "log2":
            # Zeros become 1e-9 (log2 ~ -29.9) so they stay finite.
            return np.log2(mat.replace(0, 1e-9))
        if log_transform == "log1p":
            return np.log1p(mat)
        return mat.copy()

    def color_limits(panels):
        """Fill whichever of vmin/vmax is None from the finite values of ``panels``."""
        if (vmin is not None and vmax is not None) or not panels:
            return vmin, vmax
        values = np.concatenate([np.asarray(p, dtype=float).ravel() for p in panels])
        values = values[np.isfinite(values)]
        if values.size == 0:
            return vmin, vmax
        if log_transform == "log2":
            bound = np.abs(values).max()
            auto_min, auto_max = -bound, bound
        else:
            auto_min, auto_max = values.min(), values.max()
        return (auto_min if vmin is None else vmin, auto_max if vmax is None else vmax)

    def plot_one_grid(outer_label):
        """Draw, optionally save, and close the grid figure for one outer label."""
        selected = parsed[parsed["outer"] == outer_label]
        row_vals = sorted(selected[inner_keys[0]].unique())
        col_vals = sorted(selected[inner_keys[1]].unique())

        # Transform each available panel once; reused for limits and plotting.
        panels = {}
        for row_val in row_vals:
            for col_val in col_vals:
                mat = matrices.get(f"{outer_label}_{row_val}_{col_val}")
                if mat is not None:
                    panels[(row_val, col_val)] = transform(mat)

        # Shared limits, because only one colourbar is drawn for the whole grid.
        local_vmin, local_vmax = color_limits(list(panels.values()))

        fig = plt.figure(figsize=figsize)
        gs = GridSpec(len(row_vals), len(col_vals) + 1, width_ratios=[1] * len(col_vals) + [0.05], wspace=0.3)
        cbar_ax = fig.add_subplot(gs[:, -1])
        cbar_drawn = False

        for i, row_val in enumerate(row_vals):
            for j, col_val in enumerate(col_vals):
                ax = fig.add_subplot(gs[i, j])
                data = panels.get((row_val, col_val))
                if data is None:
                    ax.axis("off")
                    continue

                # Draw the colourbar from the first panel that has data, so a
                # missing top-left panel no longer leaves the colourbar empty.
                draw_cbar = not cbar_drawn
                sns.heatmap(
                    data,
                    ax=ax,
                    cmap=cmap,
                    xticklabels=True,
                    yticklabels=True,
                    square=True,
                    vmin=local_vmin,
                    vmax=local_vmax,
                    cbar=draw_cbar,
                    cbar_ax=cbar_ax if draw_cbar else None,
                    cbar_kws={"label": cbar_label},
                )
                cbar_drawn = True
                ax.set_title(f"{inner_keys[0]}={row_val}, {inner_keys[1]}={col_val}", fontsize=9, pad=8)

                xticks = data.columns.astype(int)
                yticks = data.index.astype(int)
                # seaborn draws cell i over [i, i + 1]; +0.5 centres each tick on its cell.
                ax.set_xticks(np.arange(0, len(xticks), xtick_step) + 0.5)
                ax.set_xticklabels(xticks[::xtick_step], rotation=90)
                ax.set_yticks(np.arange(0, len(yticks), ytick_step) + 0.5)
                ax.set_yticklabels(yticks[::ytick_step])

        if not cbar_drawn:
            cbar_ax.axis("off")

        fig.suptitle(f"{key} • {outer_label}", fontsize=14, y=1.02)
        fig.tight_layout(rect=[0, 0, 0.97, 0.95])

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            fname = outer_label.replace("_", "").replace("=", "") + ".png"
            # Save this figure explicitly instead of pyplot's "current" figure.
            fig.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches='tight')
            print(f"✅ Saved {fname}")

        plt.close(fig)

    outer_labels = parsed['outer'].unique()
    if parallel:
        # joblib is only required when parallel rendering is requested.
        from joblib import Parallel, delayed

        Parallel(n_jobs=max_threads)(delayed(plot_one_grid)(outer_label) for outer_label in outer_labels)
    else:
        for outer_label in outer_labels:
            plot_one_grid(outer_label)

    print("✅ Finished plotting all grids.")