smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,682 @@
1
+ import numpy as np
2
+ import seaborn as sns
3
+ import matplotlib.pyplot as plt
4
+
5
def clean_barplot(ax, mean_values, title):
    """Draw a minimalist bar chart of per-position mean values on *ax*.

    Bars fill [0, len(mean_values)] on x and the y axis is fixed to [0, 1]
    with ticks at 0.0/0.5/1.0. All spines except the left one are hidden and
    x tick marks/labels are suppressed (the heatmap below supplies them).
    """
    n_positions = len(mean_values)
    positions = np.arange(n_positions)
    ax.bar(positions, mean_values, color="gray", width=1.0, align='edge')
    ax.set_xlim(0, n_positions)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_ylabel("Mean")
    ax.set_title(title, fontsize=12, pad=2)

    # Keep only the left spine visible.
    for name in ax.spines:
        ax.spines[name].set_visible(name == 'left')

    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
19
+
20
def combined_hmm_raw_clustermap(
    adata,
    sample_col='Sample_Names',
    reference_col='Reference_strand',
    hmm_feature_layer="hmm_combined",
    layer_gpc="nan0_0minus1",
    layer_cpg="nan0_0minus1",
    layer_any_c="nan0_0minus1",
    cmap_hmm="tab10",
    cmap_gpc="coolwarm",
    cmap_cpg="viridis",
    cmap_any_c='coolwarm',
    min_quality=20,
    min_length=200,
    min_mapped_length_to_reference_length_ratio=0.8,
    min_position_valid_fraction=0.5,
    sample_mapping=None,
    save_path=None,
    normalize_hmm=False,
    sort_by="gpc",  # options: 'gpc', 'cpg', 'gpc_cpg', 'any_c', 'none', or 'obs:<column>'
    bins=None,
    deaminase=False,
    min_signal=0
):
    """Plot HMM-feature and raw-signal clustermaps for every (sample, reference) pair.

    For each combination of ``reference_col`` and ``sample_col`` categories the
    reads are filtered on quality/length/mapping ratio, positions are filtered on
    per-reference valid fraction, reads are optionally split into ``bins`` and
    ordered per ``sort_by``, and a 2x4 figure is drawn: a row of mean barplots
    above heatmaps of the HMM feature layer, GpC sites, CpG sites, and any-C
    sites. Accumulated matrices and bin summaries are stored in
    ``adata.uns['clustermap_results']``.

    Parameters
    ----------
    adata : AnnData
        Expects categorical obs columns ``sample_col`` and ``reference_col``;
        obs columns ``read_quality``, ``read_length``,
        ``mapped_length_to_reference_length_ratio``; and per-reference var
        columns ``{ref}_valid_fraction``, ``{ref}_GpC_site``, ``{ref}_CpG_site``,
        ``{ref}_any_C_site``. Layers named by ``hmm_feature_layer`` and the
        ``layer_*`` arguments must exist.
    sort_by : str
        Read-ordering strategy: hierarchical clustering on GpC/CpG/any-C/all
        columns, no sorting, or ``'obs:<column>'`` to sort by an obs column.
    bins : dict[str, array-like] | None
        Optional mapping of bin label -> boolean read mask; defaults to a single
        "All" bin.
    normalize_hmm : bool
        If True, min-max normalize the mean HMM trace in the barplot.
    save_path : str | None
        Directory to save one PNG per (sample, reference); shows figures when None.
    sample_mapping, deaminase, min_signal
        Currently unused; kept for backward-compatible call sites.

    Returns
    -------
    None. Results are written to ``adata.uns['clustermap_results']``.
    """
    import scipy.cluster.hierarchy as sch
    import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import os

    results = []

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            try:
                # Read-level QC filter for this (sample, reference) pair.
                subset = adata[
                    (adata.obs[reference_col] == ref) &
                    (adata.obs[sample_col] == sample) &
                    (adata.obs['read_quality'] >= min_quality) &
                    (adata.obs['read_length'] >= min_length) &
                    (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
                ]

                # Position-level filter: keep columns with enough valid calls.
                mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
                subset = subset[:, mask]

                if subset.shape[0] == 0:
                    print(f" No reads left after filtering for {sample} - {ref}")
                    continue

                if bins:
                    print(f"Using defined bins to subset clustermap for {sample} - {ref}")
                    bins_temp = bins
                else:
                    print(f"Using all reads for clustermap for {sample} - {ref}")
                    bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}

                # Get column positions (not var_names!) of site masks
                gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
                cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
                any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
                num_gpc = len(gpc_sites)
                num_cpg = len(cpg_sites)
                print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")

                # Use var_names for x-axis tick labels
                gpc_labels = subset.var_names[gpc_sites].astype(int)
                cpg_labels = subset.var_names[cpg_sites].astype(int)
                any_c_labels = subset.var_names[any_c_sites].astype(int)

                stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
                row_labels, bin_labels = [], []
                bin_boundaries = []

                total_reads = subset.shape[0]
                percentages = {}
                last_idx = 0

                for bin_label, bin_filter in bins_temp.items():
                    subset_bin = subset[bin_filter].copy()
                    num_reads = subset_bin.shape[0]
                    print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
                    percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
                    percentages[bin_label] = percent_reads

                    # Skip bins with no reads or references lacking GpC/CpG sites.
                    if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
                        # Determine sorting order
                        if sort_by.startswith("obs:"):
                            colname = sort_by.split("obs:")[1]
                            order = np.argsort(subset_bin.obs[colname].values)
                        elif sort_by == "gpc":
                            linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "cpg":
                            linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "gpc_cpg":
                            linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "none":
                            order = np.arange(num_reads)
                        elif sort_by == "any_c":
                            linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
                            order = sch.leaves_list(linkage)
                        else:
                            raise ValueError(f"Unsupported sort_by option: {sort_by}")

                        stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
                        stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
                        stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
                        stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])

                        row_labels.extend([bin_label] * num_reads)
                        bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
                        last_idx += num_reads
                        bin_boundaries.append(last_idx)

                if stacked_hmm_feature:
                    hmm_matrix = np.vstack(stacked_hmm_feature)
                    gpc_matrix = np.vstack(stacked_gpc)
                    cpg_matrix = np.vstack(stacked_cpg)
                    any_c_matrix = np.vstack(stacked_any_c)

                    if hmm_matrix.size > 0:
                        def normalized_mean(matrix):
                            # Min-max normalize the positional mean to [0, 1].
                            mean = np.nanmean(matrix, axis=0)
                            normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
                            return normalized

                        def methylation_fraction(matrix):
                            # Fraction of methylated (==1) calls among valid (!=0) calls per column.
                            methylated = (matrix == 1).sum(axis=0)
                            valid = (matrix != 0).sum(axis=0)
                            return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)

                        if normalize_hmm:
                            mean_hmm = normalized_mean(hmm_matrix)
                        else:
                            mean_hmm = np.nanmean(hmm_matrix, axis=0)
                        mean_gpc = methylation_fraction(gpc_matrix)
                        mean_cpg = methylation_fraction(cpg_matrix)
                        mean_any_c = methylation_fraction(any_c_matrix)

                        # Layout: barplot row (short) above heatmap row, 4 columns.
                        fig = plt.figure(figsize=(18, 12))
                        gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
                        fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)

                        axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
                        axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]

                        clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
                        clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
                        clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
                        clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")

                        hmm_labels = subset.var_names.astype(int)
                        hmm_label_spacing = 150
                        sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
                        axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
                        axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
                        # Horizontal lines delimit the read bins in each heatmap.
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[0].axhline(y=boundary, color="black", linewidth=2)

                        sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
                        axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
                        axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[1].axhline(y=boundary, color="black", linewidth=2)

                        sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
                        axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[2].axhline(y=boundary, color="black", linewidth=2)

                        sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
                        axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
                        axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[3].axhline(y=boundary, color="black", linewidth=2)

                        plt.tight_layout()

                        if save_path:
                            save_name = f"{ref} — {sample}"
                            os.makedirs(save_path, exist_ok=True)
                            safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
                            out_file = os.path.join(save_path, f"{safe_name}.png")
                            plt.savefig(out_file, dpi=300)
                            print(f"Saved: {out_file}")
                            plt.close()
                        else:
                            plt.show()

                        print(f"Summary for {sample} - {ref}:")
                        for bin_label, percent in percentages.items():
                            print(f" - {bin_label}: {percent:.1f}%")

                        # Include any_c_matrix for consistency with combined_raw_clustermap.
                        results.append({
                            "sample": sample,
                            "ref": ref,
                            "hmm_matrix": hmm_matrix,
                            "gpc_matrix": gpc_matrix,
                            "cpg_matrix": cpg_matrix,
                            "any_c_matrix": any_c_matrix,
                            "row_labels": row_labels,
                            "bin_labels": bin_labels,
                            "bin_boundaries": bin_boundaries,
                            "percentages": percentages
                        })

                        adata.uns['clustermap_results'] = results

            except Exception:
                # Best-effort per (sample, reference): report with context and
                # keep going so one bad combination does not abort the sweep.
                import traceback
                print(f"Error while plotting {sample} - {ref}:")
                traceback.print_exc()
                continue
241
+
242
+
243
def combined_raw_clustermap(
    adata,
    sample_col='Sample_Names',
    reference_col='Reference_strand',
    layer_any_c="nan0_0minus1",
    layer_gpc="nan0_0minus1",
    layer_cpg="nan0_0minus1",
    cmap_any_c="coolwarm",
    cmap_gpc="coolwarm",
    cmap_cpg="viridis",
    min_quality=20,
    min_length=200,
    min_mapped_length_to_reference_length_ratio=0.8,
    min_position_valid_fraction=0.5,
    sample_mapping=None,
    save_path=None,
    sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
    bins=None,
    deaminase=False,
    min_signal=0
):
    """Plot raw-signal clustermaps (any-C, GpC, CpG) for every (sample, reference) pair.

    For each combination of ``reference_col`` and ``sample_col`` categories the
    reads are filtered on quality/length/mapping ratio, positions are filtered on
    per-reference valid fraction, reads are optionally split into ``bins`` and
    ordered per ``sort_by``, and a 2x3 figure is drawn: mean-signal barplots
    above heatmaps of any-C, GpC, and CpG site columns. Accumulated matrices
    and bin summaries are stored in ``adata.uns['clustermap_results']``.

    Expects obs columns ``read_quality``, ``mapped_length``,
    ``mapped_length_to_reference_length_ratio`` and per-reference var columns
    ``{ref}_valid_fraction``, ``{ref}_any_C_site``, ``{ref}_GpC_site``,
    ``{ref}_CpG_site``. ``sample_mapping``, ``deaminase`` and ``min_signal``
    are currently unused (kept for API compatibility).

    Returns None; results are written to ``adata.uns['clustermap_results']``.
    """
    import scipy.cluster.hierarchy as sch
    import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import os

    results = []

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            try:
                # Read-level QC filter for this (sample, reference) pair.
                # NOTE(review): this uses 'mapped_length' and '>=' on the ratio,
                # while the sibling combined_hmm_raw_clustermap uses 'read_length'
                # and '>' — confirm whether the difference is intentional.
                subset = adata[
                    (adata.obs[reference_col] == ref) &
                    (adata.obs[sample_col] == sample) &
                    (adata.obs['read_quality'] >= min_quality) &
                    (adata.obs['mapped_length'] >= min_length) &
                    (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
                ]

                # Position-level filter: keep columns with enough valid calls.
                mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
                subset = subset[:, mask]

                if subset.shape[0] == 0:
                    print(f" No reads left after filtering for {sample} - {ref}")
                    continue

                if bins:
                    print(f"Using defined bins to subset clustermap for {sample} - {ref}")
                    bins_temp = bins
                else:
                    # Default: a single "All" bin covering every remaining read.
                    print(f"Using all reads for clustermap for {sample} - {ref}")
                    bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}

                # Get column positions (not var_names!) of site masks
                any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
                gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
                cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
                num_any_c = len(any_c_sites)
                num_gpc = len(gpc_sites)
                num_cpg = len(cpg_sites)
                print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")

                # Use var_names for x-axis tick labels
                gpc_labels = subset.var_names[gpc_sites].astype(int)
                cpg_labels = subset.var_names[cpg_sites].astype(int)
                any_c_labels = subset.var_names[any_c_sites].astype(int)

                # Per-bin matrices are stacked, then concatenated for plotting.
                stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
                row_labels, bin_labels = [], []
                bin_boundaries = []

                total_reads = subset.shape[0]
                percentages = {}
                last_idx = 0

                for bin_label, bin_filter in bins_temp.items():
                    subset_bin = subset[bin_filter].copy()
                    num_reads = subset_bin.shape[0]
                    print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
                    percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
                    percentages[bin_label] = percent_reads

                    # Skip empty bins and references lacking GpC/CpG sites.
                    # NOTE(review): num_any_c is not checked here — confirm whether
                    # references with zero any-C sites should also be skipped.
                    if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
                        # Determine sorting order
                        if sort_by.startswith("obs:"):
                            # Sort reads by an arbitrary obs column.
                            colname = sort_by.split("obs:")[1]
                            order = np.argsort(subset_bin.obs[colname].values)
                        elif sort_by == "gpc":
                            linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "cpg":
                            linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "any_c":
                            linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "gpc_cpg":
                            # Clusters on all columns of the GpC layer (not just GpC sites).
                            linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "none":
                            order = np.arange(num_reads)
                        else:
                            raise ValueError(f"Unsupported sort_by option: {sort_by}")

                        stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
                        stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
                        stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])

                        row_labels.extend([bin_label] * num_reads)
                        bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
                        last_idx += num_reads
                        bin_boundaries.append(last_idx)

                if stacked_any_c:
                    any_c_matrix = np.vstack(stacked_any_c)
                    gpc_matrix = np.vstack(stacked_gpc)
                    cpg_matrix = np.vstack(stacked_cpg)

                    if any_c_matrix.size > 0:
                        def normalized_mean(matrix):
                            # Min-max normalize the positional mean to [0, 1].
                            # (Defined but unused in this function.)
                            mean = np.nanmean(matrix, axis=0)
                            normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
                            return normalized

                        def methylation_fraction(matrix):
                            # Fraction of modified (==1) calls among valid (!=0) calls per column.
                            methylated = (matrix == 1).sum(axis=0)
                            valid = (matrix != 0).sum(axis=0)
                            return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)

                        mean_gpc = methylation_fraction(gpc_matrix)
                        mean_cpg = methylation_fraction(cpg_matrix)
                        mean_any_c = methylation_fraction(any_c_matrix)

                        # Layout: barplot row (short) above heatmap row, 3 columns.
                        fig = plt.figure(figsize=(18, 12))
                        gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
                        fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)

                        axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
                        axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]

                        clean_barplot(axes_bar[0], mean_any_c, f"any C site Modification Signal")
                        clean_barplot(axes_bar[1], mean_gpc, f"GpC Modification Signal")
                        clean_barplot(axes_bar[2], mean_cpg, f"CpG Modification Signal")


                        sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[0], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
                        axes_heat[0].set_xticks(range(0, len(any_c_labels), 20))
                        axes_heat[0].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
                        # Horizontal lines delimit the read bins in each heatmap.
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[0].axhline(y=boundary, color="black", linewidth=2)

                        sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
                        axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
                        axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[1].axhline(y=boundary, color="black", linewidth=2)

                        sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
                        axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
                        for boundary in bin_boundaries[:-1]:
                            axes_heat[2].axhline(y=boundary, color="black", linewidth=2)

                        plt.tight_layout()

                        if save_path:
                            save_name = f"{ref} — {sample}"
                            os.makedirs(save_path, exist_ok=True)
                            safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
                            out_file = os.path.join(save_path, f"{safe_name}.png")
                            plt.savefig(out_file, dpi=300)
                            print(f"Saved: {out_file}")
                            plt.close()
                        else:
                            plt.show()

                        print(f"Summary for {sample} - {ref}:")
                        for bin_label, percent in percentages.items():
                            print(f" - {bin_label}: {percent:.1f}%")

                        results.append({
                            "sample": sample,
                            "ref": ref,
                            "any_c_matrix": any_c_matrix,
                            "gpc_matrix": gpc_matrix,
                            "cpg_matrix": cpg_matrix,
                            "row_labels": row_labels,
                            "bin_labels": bin_labels,
                            "bin_boundaries": bin_boundaries,
                            "percentages": percentages
                        })

                        # Reassigned after each successful pair; final value holds all results.
                        adata.uns['clustermap_results'] = results

            except Exception as e:
                # Best-effort: report the failure and continue with the next pair.
                import traceback
                traceback.print_exc()
                continue
443
+
444
+
445
+ import os
446
+ import math
447
+ from typing import List, Optional, Sequence, Tuple
448
+
449
+ import numpy as np
450
+ import pandas as pd
451
+ import matplotlib.pyplot as plt
452
+
453
def plot_hmm_layers_rolling_by_sample_ref(
    adata,
    layers: Optional[Sequence[str]] = None,
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    samples: Optional[Sequence[str]] = None,
    references: Optional[Sequence[str]] = None,
    window: int = 51,
    min_periods: int = 1,
    center: bool = True,
    rows_per_page: int = 6,
    figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
    dpi: int = 160,
    output_dir: Optional[str] = None,
    save: bool = True,
    show_raw: bool = False,
    cmap: str = "tab10",
    use_var_coords: bool = True,
):
    """
    For each sample (row) and reference (col) plot the rolling average of the
    positional mean (mean across reads) for each layer listed.

    Parameters
    ----------
    adata : AnnData
        Input annotated data (expects obs columns sample_col and ref_col).
    layers : list[str] | None
        Which adata.layers to plot. If None, all adata.layers are used;
        raises ValueError if none exist.
    sample_col, ref_col : str
        obs columns used to group rows.
    samples, references : optional lists
        explicit ordering of samples / references. If None, categories in adata.obs are used.
    window : int
        rolling window size (odd recommended). If window <= 1, no smoothing applied.
    min_periods : int
        min periods param for pd.Series.rolling.
    center : bool
        center the rolling window.
    rows_per_page : int
        paginate rows per page into multiple figures if needed.
    figsize_per_cell : (w,h)
        per-subplot size in inches.
    dpi : int
        figure dpi when saving.
    output_dir : str | None
        directory to save pages; created if necessary. If None and save=True, uses cwd.
    save : bool
        whether to save PNG files.
    show_raw : bool
        draw unsmoothed mean as faint line under smoothed curve.
    cmap : str
        matplotlib colormap for layer lines.
    use_var_coords : bool
        if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.

    Returns
    -------
    saved_files : list[str]
        list of saved filenames (may be empty if save=False).
    """

    # --- basic checks / defaults ---
    if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
        raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")

    # canonicalize samples / refs
    # (isinstance on the dtype replaces pd.api.types.is_categorical_dtype,
    # which is deprecated in pandas 2.x)
    if samples is None:
        sseries = adata.obs[sample_col]
        if not isinstance(sseries.dtype, pd.CategoricalDtype):
            sseries = sseries.astype("category")
        samples_all = list(sseries.cat.categories)
    else:
        samples_all = list(samples)

    if references is None:
        rseries = adata.obs[ref_col]
        if not isinstance(rseries.dtype, pd.CategoricalDtype):
            rseries = rseries.astype("category")
        refs_all = list(rseries.cat.categories)
    else:
        refs_all = list(references)

    # choose layers: if not provided, default to every available layer
    if layers is None:
        layers = list(adata.layers.keys())
        if len(layers) == 0:
            raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
    layers = list(layers)

    # x coordinates (positions): prefer integer-coerced var_names, otherwise
    # fall back to 0..n_vars-1 (no exception-as-control-flow needed)
    x_coords = np.arange(adata.shape[1], dtype=int)
    if use_var_coords:
        try:
            x_coords = np.array([int(v) for v in adata.var_names])
        except (TypeError, ValueError):
            pass  # var_names are not integer-like; keep the positional fallback

    # make output dir
    if save:
        outdir = output_dir or os.getcwd()
        os.makedirs(outdir, exist_ok=True)
    else:
        outdir = None

    n_samples = len(samples_all)
    n_refs = len(refs_all)
    total_pages = math.ceil(n_samples / rows_per_page)
    saved_files = []

    # color cycle for layers (evenly spaced samples from the colormap)
    cmap_obj = plt.get_cmap(cmap)
    n_layers = max(1, len(layers))
    colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]

    for page in range(total_pages):
        start = page * rows_per_page
        end = min(start + rows_per_page, n_samples)
        chunk = samples_all[start:end]
        nrows = len(chunk)
        ncols = n_refs

        fig_w = figsize_per_cell[0] * ncols
        fig_h = figsize_per_cell[1] * nrows
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                                 figsize=(fig_w, fig_h), dpi=dpi,
                                 squeeze=False)

        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(refs_all):
                ax = axes[r_idx][c_idx]

                # subset adata to this (sample, reference) cell
                mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
                sub = adata[mask]
                if sub.n_obs == 0:
                    # placeholder cell; still label row/column edges
                    ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if r_idx == 0:
                        ax.set_title(str(ref_name), fontsize=9)
                    if c_idx == 0:
                        total_reads = int((adata.obs[sample_col] == sample_name).sum())
                        ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                    continue

                # for each layer, compute positional mean across reads (ignore NaNs)
                plotted_any = False
                for li, layer in enumerate(layers):
                    if layer in sub.layers:
                        mat = sub.layers[layer]
                    else:
                        # fallback: try .X only for the first layer if layer not present
                        if layer == layers[0] and getattr(sub, "X", None) is not None:
                            mat = sub.X
                        else:
                            # layer not present for this subset
                            continue

                    # convert matrix to numpy 2D (handles sparse via toarray)
                    if hasattr(mat, "toarray"):
                        try:
                            arr = mat.toarray()
                        except Exception:
                            arr = np.asarray(mat)
                    else:
                        arr = np.asarray(mat)

                    if arr.size == 0 or arr.shape[1] == 0:
                        continue

                    # compute column-wise mean ignoring NaNs;
                    # cast to float so bool/int layers support NaN
                    arr = arr.astype(float)
                    with np.errstate(all="ignore"):
                        col_mean = np.nanmean(arr, axis=0)

                    # If all-NaN, skip
                    if np.all(np.isnan(col_mean)):
                        continue

                    # smooth via pandas rolling (centered)
                    if (window is None) or (window <= 1):
                        smoothed = col_mean
                    else:
                        ser = pd.Series(col_mean)
                        smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()

                    # x axis: x_coords (trim/pad to match length)
                    L = len(col_mean)
                    x = x_coords[:L]

                    # optionally plot raw faint line first
                    if show_raw:
                        ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)

                    ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
                    plotted_any = True

                # labels / titles
                if r_idx == 0:
                    ax.set_title(str(ref_name), fontsize=9)
                if c_idx == 0:
                    total_reads = int((adata.obs[sample_col] == sample_name).sum())
                    ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                if r_idx == nrows - 1:
                    ax.set_xlabel("position", fontsize=8)

                # legend (only show in top-left plot to reduce clutter)
                if (r_idx == 0 and c_idx == 0) and plotted_any:
                    ax.legend(fontsize=7, loc="upper right")

                ax.grid(True, alpha=0.2)

        fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        if save:
            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
            plt.savefig(fname, bbox_inches="tight", dpi=dpi)
            saved_files.append(fname)
        else:
            plt.show()
        plt.close(fig)

    return saved_files