smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. smftools/__init__.py +9 -4
  2. smftools/_version.py +1 -1
  3. smftools/cli.py +184 -0
  4. smftools/config/__init__.py +1 -0
  5. smftools/config/conversion.yaml +33 -0
  6. smftools/config/deaminase.yaml +56 -0
  7. smftools/config/default.yaml +253 -0
  8. smftools/config/direct.yaml +17 -0
  9. smftools/config/experiment_config.py +1191 -0
  10. smftools/hmm/HMM.py +1576 -0
  11. smftools/hmm/__init__.py +20 -0
  12. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  13. smftools/hmm/call_hmm_peaks.py +106 -0
  14. smftools/{tools → hmm}/display_hmm.py +3 -3
  15. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  16. smftools/{tools → hmm}/train_hmm.py +1 -1
  17. smftools/informatics/__init__.py +0 -2
  18. smftools/informatics/archived/deaminase_smf.py +132 -0
  19. smftools/informatics/fast5_to_pod5.py +4 -1
  20. smftools/informatics/helpers/__init__.py +3 -4
  21. smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
  22. smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
  23. smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
  24. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
  25. smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
  26. smftools/informatics/helpers/discover_input_files.py +100 -0
  27. smftools/informatics/helpers/extract_base_identities.py +29 -3
  28. smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
  29. smftools/informatics/helpers/find_conversion_sites.py +5 -4
  30. smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
  31. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  32. smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
  33. smftools/informatics/helpers/split_and_index_BAM.py +1 -5
  34. smftools/load_adata.py +1346 -0
  35. smftools/machine_learning/__init__.py +12 -0
  36. smftools/machine_learning/data/__init__.py +2 -0
  37. smftools/machine_learning/data/anndata_data_module.py +234 -0
  38. smftools/machine_learning/evaluation/__init__.py +2 -0
  39. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  40. smftools/machine_learning/evaluation/evaluators.py +223 -0
  41. smftools/machine_learning/inference/__init__.py +3 -0
  42. smftools/machine_learning/inference/inference_utils.py +27 -0
  43. smftools/machine_learning/inference/lightning_inference.py +68 -0
  44. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  45. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  46. smftools/machine_learning/models/base.py +295 -0
  47. smftools/machine_learning/models/cnn.py +138 -0
  48. smftools/machine_learning/models/lightning_base.py +345 -0
  49. smftools/machine_learning/models/mlp.py +26 -0
  50. smftools/{tools → machine_learning}/models/positional.py +3 -2
  51. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  52. smftools/machine_learning/models/sklearn_models.py +273 -0
  53. smftools/machine_learning/models/transformer.py +303 -0
  54. smftools/machine_learning/training/__init__.py +2 -0
  55. smftools/machine_learning/training/train_lightning_model.py +135 -0
  56. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  57. smftools/plotting/__init__.py +4 -1
  58. smftools/plotting/autocorrelation_plotting.py +611 -0
  59. smftools/plotting/general_plotting.py +566 -89
  60. smftools/plotting/hmm_plotting.py +260 -0
  61. smftools/plotting/qc_plotting.py +270 -0
  62. smftools/preprocessing/__init__.py +13 -8
  63. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  64. smftools/preprocessing/append_base_context.py +122 -0
  65. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  66. smftools/preprocessing/calculate_complexity_II.py +248 -0
  67. smftools/preprocessing/calculate_coverage.py +10 -1
  68. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  69. smftools/preprocessing/clean_NaN.py +17 -1
  70. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  71. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  72. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  73. smftools/preprocessing/invert_adata.py +12 -5
  74. smftools/preprocessing/load_sample_sheet.py +19 -4
  75. smftools/readwrite.py +849 -43
  76. smftools/tools/__init__.py +3 -32
  77. smftools/tools/calculate_umap.py +5 -5
  78. smftools/tools/general_tools.py +3 -3
  79. smftools/tools/position_stats.py +468 -106
  80. smftools/tools/read_stats.py +115 -1
  81. smftools/tools/spatial_autocorrelation.py +562 -0
  82. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
  83. smftools-0.2.1.dist-info/RECORD +161 -0
  84. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  85. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  86. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  87. smftools/informatics/load_adata.py +0 -182
  88. smftools/preprocessing/append_C_context.py +0 -82
  89. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  90. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  91. smftools/preprocessing/filter_reads_on_length.py +0 -51
  92. smftools/tools/call_hmm_peaks.py +0 -105
  93. smftools/tools/data/__init__.py +0 -2
  94. smftools/tools/data/anndata_data_module.py +0 -90
  95. smftools/tools/evaluation/__init__.py +0 -0
  96. smftools/tools/inference/__init__.py +0 -1
  97. smftools/tools/inference/lightning_inference.py +0 -41
  98. smftools/tools/models/base.py +0 -14
  99. smftools/tools/models/cnn.py +0 -34
  100. smftools/tools/models/lightning_base.py +0 -41
  101. smftools/tools/models/mlp.py +0 -17
  102. smftools/tools/models/sklearn_models.py +0 -40
  103. smftools/tools/models/transformer.py +0 -133
  104. smftools/tools/training/__init__.py +0 -1
  105. smftools/tools/training/train_lightning_model.py +0 -47
  106. smftools-0.1.7.dist-info/RECORD +0 -136
  107. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  108. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  109. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  110. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  111. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  112. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  113. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  114. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  115. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  116. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  117. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  118. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  119. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  120. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -17,24 +17,30 @@ def clean_barplot(ax, mean_values, title):
17
17
 
18
18
  ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
19
19
 
20
-
21
20
  def combined_hmm_raw_clustermap(
22
21
  adata,
23
22
  sample_col='Sample_Names',
23
+ reference_col='Reference_strand',
24
24
  hmm_feature_layer="hmm_combined",
25
25
  layer_gpc="nan0_0minus1",
26
26
  layer_cpg="nan0_0minus1",
27
+ layer_any_c="nan0_0minus1",
27
28
  cmap_hmm="tab10",
28
29
  cmap_gpc="coolwarm",
29
30
  cmap_cpg="viridis",
31
+ cmap_any_c='coolwarm',
30
32
  min_quality=20,
31
- min_length=2700,
33
+ min_length=200,
34
+ min_mapped_length_to_reference_length_ratio=0.8,
35
+ min_position_valid_fraction=0.5,
32
36
  sample_mapping=None,
33
37
  save_path=None,
34
38
  normalize_hmm=False,
35
39
  sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
36
- bins=None
37
- ):
40
+ bins=None,
41
+ deaminase=False,
42
+ min_signal=0
43
+ ):
38
44
  import scipy.cluster.hierarchy as sch
39
45
  import pandas as pd
40
46
  import numpy as np
@@ -44,38 +50,51 @@ def combined_hmm_raw_clustermap(
44
50
  import os
45
51
 
46
52
  results = []
53
+ if deaminase:
54
+ signal_type = 'deamination'
55
+ else:
56
+ signal_type = 'methylation'
47
57
 
48
- for ref in adata.obs["Reference_strand"].cat.categories:
58
+ for ref in adata.obs[reference_col].cat.categories:
49
59
  for sample in adata.obs[sample_col].cat.categories:
50
60
  try:
51
61
  subset = adata[
52
- (adata.obs['Reference_strand'] == ref) &
62
+ (adata.obs[reference_col] == ref) &
53
63
  (adata.obs[sample_col] == sample) &
54
- (adata.obs['query_read_quality'] >= min_quality) &
64
+ (adata.obs['read_quality'] >= min_quality) &
55
65
  (adata.obs['read_length'] >= min_length) &
56
- (adata.obs['Raw_methylation_signal'] >= 20)
66
+ (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
57
67
  ]
58
68
 
59
- subset = subset[:, subset.var[f'position_in_{ref}'] == True]
69
+ mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
70
+ subset = subset[:, mask]
60
71
 
61
72
  if subset.shape[0] == 0:
62
- print(f" No reads left after filtering for {sample} - {ref}")
73
+ print(f" No reads left after filtering for {sample} - {ref}")
63
74
  continue
64
75
 
65
76
  if bins:
66
- pass
77
+ print(f"Using defined bins to subset clustermap for {sample} - {ref}")
78
+ bins_temp = bins
67
79
  else:
68
- bins = {"All": (subset.obs['Reference_strand'] != None)}
80
+ print(f"Using all reads for clustermap for {sample} - {ref}")
81
+ bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
69
82
 
70
83
  # Get column positions (not var_names!) of site masks
71
84
  gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
72
85
  cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
86
+ any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
87
+ num_gpc = len(gpc_sites)
88
+ num_cpg = len(cpg_sites)
89
+ num_c = len(any_c_sites)
90
+ print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
73
91
 
74
92
  # Use var_names for x-axis tick labels
75
93
  gpc_labels = subset.var_names[gpc_sites].astype(int)
76
94
  cpg_labels = subset.var_names[cpg_sites].astype(int)
95
+ any_c_labels = subset.var_names[any_c_sites].astype(int)
77
96
 
78
- stacked_hmm_feature, stacked_gpc, stacked_cpg = [], [], []
97
+ stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
79
98
  row_labels, bin_labels = [], []
80
99
  bin_boundaries = []
81
100
 
@@ -83,13 +102,14 @@ def combined_hmm_raw_clustermap(
83
102
  percentages = {}
84
103
  last_idx = 0
85
104
 
86
- for bin_label, bin_filter in bins.items():
105
+ for bin_label, bin_filter in bins_temp.items():
87
106
  subset_bin = subset[bin_filter].copy()
88
107
  num_reads = subset_bin.shape[0]
108
+ print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
89
109
  percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
90
110
  percentages[bin_label] = percent_reads
91
111
 
92
- if num_reads > 0:
112
+ if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
93
113
  # Determine sorting order
94
114
  if sort_by.startswith("obs:"):
95
115
  colname = sort_by.split("obs:")[1]
@@ -105,12 +125,16 @@ def combined_hmm_raw_clustermap(
105
125
  order = sch.leaves_list(linkage)
106
126
  elif sort_by == "none":
107
127
  order = np.arange(num_reads)
128
+ elif sort_by == "any_c":
129
+ linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
130
+ order = sch.leaves_list(linkage)
108
131
  else:
109
132
  raise ValueError(f"Unsupported sort_by option: {sort_by}")
110
133
 
111
134
  stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
112
135
  stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
113
136
  stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
137
+ stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
114
138
 
115
139
  row_labels.extend([bin_label] * num_reads)
116
140
  bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
@@ -121,85 +145,538 @@ def combined_hmm_raw_clustermap(
121
145
  hmm_matrix = np.vstack(stacked_hmm_feature)
122
146
  gpc_matrix = np.vstack(stacked_gpc)
123
147
  cpg_matrix = np.vstack(stacked_cpg)
148
+ any_c_matrix = np.vstack(stacked_any_c)
124
149
 
125
- def normalized_mean(matrix):
126
- mean = np.nanmean(matrix, axis=0)
127
- normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
128
- return normalized
150
+ if hmm_matrix.size > 0:
151
+ def normalized_mean(matrix):
152
+ mean = np.nanmean(matrix, axis=0)
153
+ normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
154
+ return normalized
129
155
 
130
- def methylation_fraction(matrix):
131
- methylated = (matrix == 1).sum(axis=0)
132
- valid = (matrix != 0).sum(axis=0)
133
- return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
156
+ def methylation_fraction(matrix):
157
+ methylated = (matrix == 1).sum(axis=0)
158
+ valid = (matrix != 0).sum(axis=0)
159
+ return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
134
160
 
135
- if normalize_hmm:
136
- mean_hmm = normalized_mean(hmm_matrix)
137
- else:
138
- mean_hmm = np.nanmean(hmm_matrix, axis=0)
139
- mean_gpc = methylation_fraction(gpc_matrix)
140
- mean_cpg = methylation_fraction(cpg_matrix)
141
-
142
- fig = plt.figure(figsize=(18, 12))
143
- gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
144
- fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
145
-
146
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
147
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]
148
-
149
- clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
150
- clean_barplot(axes_bar[1], mean_gpc, f"GpC Methylation")
151
- clean_barplot(axes_bar[2], mean_cpg, f"CpG Methylation")
152
-
153
- hmm_labels = subset.var_names.astype(int)
154
- hmm_label_spacing = 150
155
- sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
156
- axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
157
- axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
158
- for boundary in bin_boundaries[:-1]:
159
- axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
160
-
161
- sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
162
- axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
163
- axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
164
- for boundary in bin_boundaries[:-1]:
165
- axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
166
-
167
- sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
168
- axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
169
- for boundary in bin_boundaries[:-1]:
170
- axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
171
-
172
- plt.tight_layout()
173
-
174
- if save_path:
175
- save_name = f"{ref} — {sample}"
176
- os.makedirs(save_path, exist_ok=True)
177
- safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
178
- out_file = os.path.join(save_path, f"{safe_name}.png")
179
- plt.savefig(out_file, dpi=300)
180
- print(f"📁 Saved: {out_file}")
181
-
182
- plt.show()
183
-
184
- print(f"📊 Summary for {sample} - {ref}:")
185
- for bin_label, percent in percentages.items():
186
- print(f" - {bin_label}: {percent:.1f}%")
187
-
188
- results.append({
189
- "sample": sample,
190
- "ref": ref,
191
- "hmm_matrix": hmm_matrix,
192
- "gpc_matrix": gpc_matrix,
193
- "cpg_matrix": cpg_matrix,
194
- "row_labels": row_labels,
195
- "bin_labels": bin_labels,
196
- "bin_boundaries": bin_boundaries,
197
- "percentages": percentages
198
- })
199
-
200
- adata.uns['clustermap_results'] = results
161
+ if normalize_hmm:
162
+ mean_hmm = normalized_mean(hmm_matrix)
163
+ else:
164
+ mean_hmm = np.nanmean(hmm_matrix, axis=0)
165
+ mean_gpc = methylation_fraction(gpc_matrix)
166
+ mean_cpg = methylation_fraction(cpg_matrix)
167
+ mean_any_c = methylation_fraction(any_c_matrix)
168
+
169
+ fig = plt.figure(figsize=(18, 12))
170
+ gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
171
+ fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)
172
+
173
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
174
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
175
+
176
+ clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
177
+ clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
178
+ clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
179
+ clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
180
+
181
+ hmm_labels = subset.var_names.astype(int)
182
+ hmm_label_spacing = 150
183
+ sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
184
+ axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
185
+ axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
186
+ for boundary in bin_boundaries[:-1]:
187
+ axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
188
+
189
+ sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
190
+ axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
191
+ axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
192
+ for boundary in bin_boundaries[:-1]:
193
+ axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
194
+
195
+ sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
196
+ axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
197
+ for boundary in bin_boundaries[:-1]:
198
+ axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
199
+
200
+ sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
201
+ axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
202
+ axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
203
+ for boundary in bin_boundaries[:-1]:
204
+ axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
205
+
206
+ plt.tight_layout()
207
+
208
+ if save_path:
209
+ save_name = f"{ref} — {sample}"
210
+ os.makedirs(save_path, exist_ok=True)
211
+ safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
212
+ out_file = os.path.join(save_path, f"{safe_name}.png")
213
+ plt.savefig(out_file, dpi=300)
214
+ print(f"Saved: {out_file}")
215
+ plt.close()
216
+ else:
217
+ plt.show()
218
+
219
+ print(f"Summary for {sample} - {ref}:")
220
+ for bin_label, percent in percentages.items():
221
+ print(f" - {bin_label}: {percent:.1f}%")
222
+
223
+ results.append({
224
+ "sample": sample,
225
+ "ref": ref,
226
+ "hmm_matrix": hmm_matrix,
227
+ "gpc_matrix": gpc_matrix,
228
+ "cpg_matrix": cpg_matrix,
229
+ "row_labels": row_labels,
230
+ "bin_labels": bin_labels,
231
+ "bin_boundaries": bin_boundaries,
232
+ "percentages": percentages
233
+ })
234
+
235
+ adata.uns['clustermap_results'] = results
236
+
237
+ except Exception as e:
238
+ import traceback
239
+ traceback.print_exc()
240
+ continue
241
+
242
+
243
+ def combined_raw_clustermap(
244
+ adata,
245
+ sample_col='Sample_Names',
246
+ reference_col='Reference_strand',
247
+ layer_any_c="nan0_0minus1",
248
+ layer_gpc="nan0_0minus1",
249
+ layer_cpg="nan0_0minus1",
250
+ cmap_any_c="coolwarm",
251
+ cmap_gpc="coolwarm",
252
+ cmap_cpg="viridis",
253
+ min_quality=20,
254
+ min_length=200,
255
+ min_mapped_length_to_reference_length_ratio=0.8,
256
+ min_position_valid_fraction=0.5,
257
+ sample_mapping=None,
258
+ save_path=None,
259
+ sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
260
+ bins=None,
261
+ deaminase=False,
262
+ min_signal=0
263
+ ):
264
+ import scipy.cluster.hierarchy as sch
265
+ import pandas as pd
266
+ import numpy as np
267
+ import seaborn as sns
268
+ import matplotlib.pyplot as plt
269
+ import matplotlib.gridspec as gridspec
270
+ import os
271
+
272
+ results = []
273
+
274
+ for ref in adata.obs[reference_col].cat.categories:
275
+ for sample in adata.obs[sample_col].cat.categories:
276
+ try:
277
+ subset = adata[
278
+ (adata.obs[reference_col] == ref) &
279
+ (adata.obs[sample_col] == sample) &
280
+ (adata.obs['read_quality'] >= min_quality) &
281
+ (adata.obs['mapped_length'] >= min_length) &
282
+ (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
283
+ ]
284
+
285
+ mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
286
+ subset = subset[:, mask]
287
+
288
+ if subset.shape[0] == 0:
289
+ print(f" No reads left after filtering for {sample} - {ref}")
290
+ continue
291
+
292
+ if bins:
293
+ print(f"Using defined bins to subset clustermap for {sample} - {ref}")
294
+ bins_temp = bins
295
+ else:
296
+ print(f"Using all reads for clustermap for {sample} - {ref}")
297
+ bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
298
+
299
+ # Get column positions (not var_names!) of site masks
300
+ any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
301
+ gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
302
+ cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
303
+ num_any_c = len(any_c_sites)
304
+ num_gpc = len(gpc_sites)
305
+ num_cpg = len(cpg_sites)
306
+ print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
307
+
308
+ # Use var_names for x-axis tick labels
309
+ gpc_labels = subset.var_names[gpc_sites].astype(int)
310
+ cpg_labels = subset.var_names[cpg_sites].astype(int)
311
+ any_c_labels = subset.var_names[any_c_sites].astype(int)
312
+
313
+ stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
314
+ row_labels, bin_labels = [], []
315
+ bin_boundaries = []
316
+
317
+ total_reads = subset.shape[0]
318
+ percentages = {}
319
+ last_idx = 0
320
+
321
+ for bin_label, bin_filter in bins_temp.items():
322
+ subset_bin = subset[bin_filter].copy()
323
+ num_reads = subset_bin.shape[0]
324
+ print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
325
+ percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
326
+ percentages[bin_label] = percent_reads
327
+
328
+ if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
329
+ # Determine sorting order
330
+ if sort_by.startswith("obs:"):
331
+ colname = sort_by.split("obs:")[1]
332
+ order = np.argsort(subset_bin.obs[colname].values)
333
+ elif sort_by == "gpc":
334
+ linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
335
+ order = sch.leaves_list(linkage)
336
+ elif sort_by == "cpg":
337
+ linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
338
+ order = sch.leaves_list(linkage)
339
+ elif sort_by == "any_c":
340
+ linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
341
+ order = sch.leaves_list(linkage)
342
+ elif sort_by == "gpc_cpg":
343
+ linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
344
+ order = sch.leaves_list(linkage)
345
+ elif sort_by == "none":
346
+ order = np.arange(num_reads)
347
+ else:
348
+ raise ValueError(f"Unsupported sort_by option: {sort_by}")
349
+
350
+ stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
351
+ stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
352
+ stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
353
+
354
+ row_labels.extend([bin_label] * num_reads)
355
+ bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
356
+ last_idx += num_reads
357
+ bin_boundaries.append(last_idx)
358
+
359
+ if stacked_any_c:
360
+ any_c_matrix = np.vstack(stacked_any_c)
361
+ gpc_matrix = np.vstack(stacked_gpc)
362
+ cpg_matrix = np.vstack(stacked_cpg)
363
+
364
+ if any_c_matrix.size > 0:
365
+ def normalized_mean(matrix):
366
+ mean = np.nanmean(matrix, axis=0)
367
+ normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
368
+ return normalized
369
+
370
+ def methylation_fraction(matrix):
371
+ methylated = (matrix == 1).sum(axis=0)
372
+ valid = (matrix != 0).sum(axis=0)
373
+ return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
374
+
375
+ mean_gpc = methylation_fraction(gpc_matrix)
376
+ mean_cpg = methylation_fraction(cpg_matrix)
377
+ mean_any_c = methylation_fraction(any_c_matrix)
378
+
379
+ fig = plt.figure(figsize=(18, 12))
380
+ gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
381
+ fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
382
+
383
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
384
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]
385
+
386
+ clean_barplot(axes_bar[0], mean_any_c, f"any C site Modification Signal")
387
+ clean_barplot(axes_bar[1], mean_gpc, f"GpC Modification Signal")
388
+ clean_barplot(axes_bar[2], mean_cpg, f"CpG Modification Signal")
389
+
390
+
391
+ sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[0], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
392
+ axes_heat[0].set_xticks(range(0, len(any_c_labels), 20))
393
+ axes_heat[0].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
394
+ for boundary in bin_boundaries[:-1]:
395
+ axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
396
+
397
+ sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
398
+ axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
399
+ axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
400
+ for boundary in bin_boundaries[:-1]:
401
+ axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
402
+
403
+ sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
404
+ axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
405
+ for boundary in bin_boundaries[:-1]:
406
+ axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
407
+
408
+ plt.tight_layout()
409
+
410
+ if save_path:
411
+ save_name = f"{ref} — {sample}"
412
+ os.makedirs(save_path, exist_ok=True)
413
+ safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
414
+ out_file = os.path.join(save_path, f"{safe_name}.png")
415
+ plt.savefig(out_file, dpi=300)
416
+ print(f"Saved: {out_file}")
417
+ plt.close()
418
+ else:
419
+ plt.show()
420
+
421
+ print(f"Summary for {sample} - {ref}:")
422
+ for bin_label, percent in percentages.items():
423
+ print(f" - {bin_label}: {percent:.1f}%")
424
+
425
+ results.append({
426
+ "sample": sample,
427
+ "ref": ref,
428
+ "any_c_matrix": any_c_matrix,
429
+ "gpc_matrix": gpc_matrix,
430
+ "cpg_matrix": cpg_matrix,
431
+ "row_labels": row_labels,
432
+ "bin_labels": bin_labels,
433
+ "bin_boundaries": bin_boundaries,
434
+ "percentages": percentages
435
+ })
436
+
437
+ adata.uns['clustermap_results'] = results
201
438
 
202
439
  except Exception as e:
203
440
  import traceback
204
441
  traceback.print_exc()
205
442
  continue
443
+
444
+
445
+ import os
446
+ import math
447
+ from typing import List, Optional, Sequence, Tuple
448
+
449
+ import numpy as np
450
+ import pandas as pd
451
+ import matplotlib.pyplot as plt
452
+
453
+ def plot_hmm_layers_rolling_by_sample_ref(
454
+ adata,
455
+ layers: Optional[Sequence[str]] = None,
456
+ sample_col: str = "Barcode",
457
+ ref_col: str = "Reference_strand",
458
+ samples: Optional[Sequence[str]] = None,
459
+ references: Optional[Sequence[str]] = None,
460
+ window: int = 51,
461
+ min_periods: int = 1,
462
+ center: bool = True,
463
+ rows_per_page: int = 6,
464
+ figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
465
+ dpi: int = 160,
466
+ output_dir: Optional[str] = None,
467
+ save: bool = True,
468
+ show_raw: bool = False,
469
+ cmap: str = "tab10",
470
+ use_var_coords: bool = True,
471
+ ):
472
+ """
473
+ For each sample (row) and reference (col) plot the rolling average of the
474
+ positional mean (mean across reads) for each layer listed.
475
+
476
+ Parameters
477
+ ----------
478
+ adata : AnnData
479
+ Input annotated data (expects obs columns sample_col and ref_col).
480
+ layers : list[str] | None
481
+ Which adata.layers to plot. If None, attempts to autodetect layers whose
482
+ matrices look like "HMM" outputs (else will error). If None and layers
483
+ cannot be found, user must pass a list.
484
+ sample_col, ref_col : str
485
+ obs columns used to group rows.
486
+ samples, references : optional lists
487
+ explicit ordering of samples / references. If None, categories in adata.obs are used.
488
+ window : int
489
+ rolling window size (odd recommended). If window <= 1, no smoothing applied.
490
+ min_periods : int
491
+ min periods param for pd.Series.rolling.
492
+ center : bool
493
+ center the rolling window.
494
+ rows_per_page : int
495
+ paginate rows per page into multiple figures if needed.
496
+ figsize_per_cell : (w,h)
497
+ per-subplot size in inches.
498
+ dpi : int
499
+ figure dpi when saving.
500
+ output_dir : str | None
501
+ directory to save pages; created if necessary. If None and save=True, uses cwd.
502
+ save : bool
503
+ whether to save PNG files.
504
+ show_raw : bool
505
+ draw unsmoothed mean as faint line under smoothed curve.
506
+ cmap : str
507
+ matplotlib colormap for layer lines.
508
+ use_var_coords : bool
509
+ if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
510
+
511
+ Returns
512
+ -------
513
+ saved_files : list[str]
514
+ list of saved filenames (may be empty if save=False).
515
+ """
516
+
517
+ # --- basic checks / defaults ---
518
+ if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
519
+ raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
520
+
521
+ # canonicalize samples / refs
522
+ if samples is None:
523
+ sseries = adata.obs[sample_col]
524
+ if not pd.api.types.is_categorical_dtype(sseries):
525
+ sseries = sseries.astype("category")
526
+ samples_all = list(sseries.cat.categories)
527
+ else:
528
+ samples_all = list(samples)
529
+
530
+ if references is None:
531
+ rseries = adata.obs[ref_col]
532
+ if not pd.api.types.is_categorical_dtype(rseries):
533
+ rseries = rseries.astype("category")
534
+ refs_all = list(rseries.cat.categories)
535
+ else:
536
+ refs_all = list(references)
537
+
538
+ # choose layers: if not provided, try a sensible default: all layers
539
+ if layers is None:
540
+ layers = list(adata.layers.keys())
541
+ if len(layers) == 0:
542
+ raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
543
+ layers = list(layers)
544
+
545
+ # x coordinates (positions)
546
+ try:
547
+ if use_var_coords:
548
+ x_coords = np.array([int(v) for v in adata.var_names])
549
+ else:
550
+ raise Exception("user disabled var coords")
551
+ except Exception:
552
+ # fallback to 0..n_vars-1
553
+ x_coords = np.arange(adata.shape[1], dtype=int)
554
+
555
+ # make output dir
556
+ if save:
557
+ outdir = output_dir or os.getcwd()
558
+ os.makedirs(outdir, exist_ok=True)
559
+ else:
560
+ outdir = None
561
+
562
+ n_samples = len(samples_all)
563
+ n_refs = len(refs_all)
564
+ total_pages = math.ceil(n_samples / rows_per_page)
565
+ saved_files = []
566
+
567
+ # color cycle for layers
568
+ cmap_obj = plt.get_cmap(cmap)
569
+ n_layers = max(1, len(layers))
570
+ colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
571
+
572
+ for page in range(total_pages):
573
+ start = page * rows_per_page
574
+ end = min(start + rows_per_page, n_samples)
575
+ chunk = samples_all[start:end]
576
+ nrows = len(chunk)
577
+ ncols = n_refs
578
+
579
+ fig_w = figsize_per_cell[0] * ncols
580
+ fig_h = figsize_per_cell[1] * nrows
581
+ fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
582
+ figsize=(fig_w, fig_h), dpi=dpi,
583
+ squeeze=False)
584
+
585
+ for r_idx, sample_name in enumerate(chunk):
586
+ for c_idx, ref_name in enumerate(refs_all):
587
+ ax = axes[r_idx][c_idx]
588
+
589
+ # subset adata
590
+ mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
591
+ sub = adata[mask]
592
+ if sub.n_obs == 0:
593
+ ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
594
+ ax.set_xticks([])
595
+ ax.set_yticks([])
596
+ if r_idx == 0:
597
+ ax.set_title(str(ref_name), fontsize=9)
598
+ if c_idx == 0:
599
+ total_reads = int((adata.obs[sample_col] == sample_name).sum())
600
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
601
+ continue
602
+
603
+ # for each layer, compute positional mean across reads (ignore NaNs)
604
+ plotted_any = False
605
+ for li, layer in enumerate(layers):
606
+ if layer in sub.layers:
607
+ mat = sub.layers[layer]
608
+ else:
609
+ # fallback: try .X only for the first layer if layer not present
610
+ if layer == layers[0] and getattr(sub, "X", None) is not None:
611
+ mat = sub.X
612
+ else:
613
+ # layer not present for this subset
614
+ continue
615
+
616
+ # convert matrix to numpy 2D
617
+ if hasattr(mat, "toarray"):
618
+ try:
619
+ arr = mat.toarray()
620
+ except Exception:
621
+ arr = np.asarray(mat)
622
+ else:
623
+ arr = np.asarray(mat)
624
+
625
+ if arr.size == 0 or arr.shape[1] == 0:
626
+ continue
627
+
628
+ # compute column-wise mean ignoring NaNs
629
+ # if arr is boolean or int, convert to float to support NaN
630
+ arr = arr.astype(float)
631
+ with np.errstate(all="ignore"):
632
+ col_mean = np.nanmean(arr, axis=0)
633
+
634
+ # If all-NaN, skip
635
+ if np.all(np.isnan(col_mean)):
636
+ continue
637
+
638
+ # smooth via pandas rolling (centered)
639
+ if (window is None) or (window <= 1):
640
+ smoothed = col_mean
641
+ else:
642
+ ser = pd.Series(col_mean)
643
+ smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()
644
+
645
+ # x axis: x_coords (trim/pad to match length)
646
+ L = len(col_mean)
647
+ x = x_coords[:L]
648
+
649
+ # optionally plot raw faint line first
650
+ if show_raw:
651
+ ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
652
+
653
+ ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
654
+ plotted_any = True
655
+
656
+ # labels / titles
657
+ if r_idx == 0:
658
+ ax.set_title(str(ref_name), fontsize=9)
659
+ if c_idx == 0:
660
+ total_reads = int((adata.obs[sample_col] == sample_name).sum())
661
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
662
+ if r_idx == nrows - 1:
663
+ ax.set_xlabel("position", fontsize=8)
664
+
665
+ # legend (only show in top-left plot to reduce clutter)
666
+ if (r_idx == 0 and c_idx == 0) and plotted_any:
667
+ ax.legend(fontsize=7, loc="upper right")
668
+
669
+ ax.grid(True, alpha=0.2)
670
+
671
+ fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
672
+ fig.tight_layout(rect=[0, 0, 1, 0.97])
673
+
674
+ if save:
675
+ fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
676
+ plt.savefig(fname, bbox_inches="tight", dpi=dpi)
677
+ saved_files.append(fname)
678
+ else:
679
+ plt.show()
680
+ plt.close(fig)
681
+
682
+ return saved_files