smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py
@@ -0,0 +1,20 @@
+ from .apply_hmm_batched import apply_hmm_batched
+ from .calculate_distances import calculate_distances
+ from .call_hmm_peaks import call_hmm_peaks
+ from .display_hmm import display_hmm
+ from .hmm_readwrite import load_hmm, save_hmm
+ from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
+ from .train_hmm import train_hmm
+
+
+ __all__ = [
+     "apply_hmm_batched",
+     "calculate_distances",
+     "call_hmm_peaks",
+     "display_hmm",
+     "load_hmm",
+     "refine_nucleosome_calls",
+     "infer_nucleosomes_in_large_bound",
+     "save_hmm",
+     "train_hmm"
+ ]
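The re-exports above define the public surface of the new smftools.hmm subpackage. As a quick orientation (not part of the diff itself), downstream code can pull the whole workflow from one namespace:

from smftools.hmm import train_hmm, apply_hmm_batched, call_hmm_peaks, save_hmm, load_hmm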
smftools/hmm/apply_hmm_batched.py
@@ -0,0 +1,242 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ def apply_hmm_batched(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A", "C"], device="cpu", threshold=0.7, deaminase_footprinting=False):
+     """
+     Applies an HMM model to an AnnData object using tensor-based sequence inputs.
+     If multiple methbases are passed, generates a combined feature set.
+     """
+
+     model.to(device)
+
+     # --- Feature Definitions ---
+     feature_sets = {}
+     if footprints:
+         feature_sets["footprint"] = {
+             "features": {
+                 "small_bound_stretch": [0, 20],
+                 "medium_bound_stretch": [20, 50],
+                 "putative_nucleosome": [50, 200],
+                 "large_bound_stretch": [200, np.inf]
+             },
+             "state": "Non-Methylated"
+         }
+     if accessible_patches:
+         feature_sets["accessible"] = {
+             "features": {
+                 "small_accessible_patch": [0, 20],
+                 "mid_accessible_patch": [20, 80],
+                 "large_accessible_patch": [80, np.inf]
+             },
+             "state": "Methylated"
+         }
+     if cpg:
+         feature_sets["cpg"] = {
+             "features": {
+                 "cpg_patch": [0, np.inf]
+             },
+             "state": "Methylated"
+         }
+
+     # --- Init columns ---
+     all_features = []
+     combined_prefix = "Combined"
+     for key, fs in feature_sets.items():
+         if key == 'cpg':
+             all_features += [f"CpG_{f}" for f in fs["features"]]
+             all_features.append(f"CpG_all_{key}_features")
+         else:
+             for methbase in methbases:
+                 all_features += [f"{methbase}_{f}" for f in fs["features"]]
+                 all_features.append(f"{methbase}_all_{key}_features")
+             if len(methbases) > 1:
+                 all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+                 all_features.append(f"{combined_prefix}_all_{key}_features")
+
+     for feature in all_features:
+         adata.obs[feature] = [[] for _ in range(adata.shape[0])]
+         adata.obs[f"{feature}_distances"] = [None] * adata.shape[0]
+         adata.obs[f"n_{feature}"] = -1
+
+     # --- Main loop ---
+     references = adata.obs[obs_column].cat.categories
+
+     for ref in tqdm(references, desc="Processing References"):
+         ref_subset = adata[adata.obs[obs_column] == ref]
+
+         # Combined methbase mask
+         combined_mask = None
+         for methbase in methbases:
+             mask = {
+                 "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+                 "c": ref_subset.var[f"{ref}_any_C_site"] == True,
+                 "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+                 "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+             }[methbase.lower()]
+             combined_mask = mask if combined_mask is None else combined_mask | mask
+
+             methbase_subset = ref_subset[:, mask]
+             matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+             processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+             tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = methbase_subset.var_names
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 pred_states = model.predict(tensor_batch)
+                 probs = model.predict_proba(tensor_batch)
+                 classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                 for i, idx in enumerate(methbase_subset.obs.index):
+                     for start, length, label, prob in classifications[i]:
+                         adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+         # Combined subset
+         if len(methbases) > 1:
+             combined_subset = ref_subset[:, combined_mask]
+             combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+             processed_combined_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in combined_matrix]
+             tensor_combined_batch = torch.tensor(processed_combined_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = combined_subset.var_names
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 pred_states = model.predict(tensor_combined_batch)
+                 probs = model.predict_proba(tensor_combined_batch)
+                 classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                 for i, idx in enumerate(combined_subset.obs.index):
+                     for start, length, label, prob in classifications[i]:
+                         adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+     # --- Special handling for CpG ---
+     if cpg:
+         for ref in tqdm(references, desc="Processing CpG"):
+             ref_subset = adata[adata.obs[obs_column] == ref]
+             mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+             cpg_subset = ref_subset[:, mask]
+             matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+             processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+             tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = cpg_subset.var_names
+             fs = feature_sets['cpg']
+             state_target = fs["state"]
+             feature_map = fs["features"]
+
+             pred_states = model.predict(tensor_batch)
+             probs = model.predict_proba(tensor_batch)
+             classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+             for i, idx in enumerate(cpg_subset.obs.index):
+                 for start, length, label, prob in classifications[i]:
+                     adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                     adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
+
+     # --- Binarization + Distance ---
+     coordinates = adata.var_names.astype(int).values
+
+     for feature in tqdm(all_features, desc="Finalizing Layers"):
+         bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+         counts = np.zeros(adata.shape[0], dtype=int)
+         for row_idx, intervals in enumerate(adata.obs[feature]):
+             if not isinstance(intervals, list):
+                 intervals = []
+             for start, length, prob in intervals:
+                 if prob > threshold:
+                     start_idx = np.searchsorted(coordinates, start, side="left")
+                     end_idx = np.searchsorted(coordinates, start + length - 1, side="right")
+                     bin_matrix[row_idx, start_idx:end_idx] = 1
+                     counts[row_idx] += 1
+         adata.layers[feature] = bin_matrix
+         adata.obs[f"n_{feature}"] = counts
+         adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
+
+ def calculate_batch_distances(intervals_list, threshold=0.9):
+     """
+     Vectorized calculation of distances across multiple reads.
+
+     Parameters:
+         intervals_list (list of list): Outer list = reads, inner list = intervals [start, length, prob]
+         threshold (float): Minimum probability threshold for filtering
+
+     Returns:
+         List of distance lists per read.
+     """
+     results = []
+     for intervals in intervals_list:
+         if not isinstance(intervals, list) or len(intervals) == 0:
+             results.append([])
+             continue
+         valid = [iv for iv in intervals if iv[2] > threshold]
+         valid = sorted(valid, key=lambda x: x[0])
+         dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
+         results.append(dists)
+     return results
+
+
+ def classify_batch(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Methylated"):
+     """
+     Classify batch sequences efficiently.
+
+     Parameters:
+         predicted_states_batch: Tensor [batch_size, seq_len]
+         probabilities_batch: Tensor [batch_size, seq_len, n_states]
+         coordinates: list of genomic coordinates
+         classification_mapping: dict of feature bins
+         target_state: state name ("Methylated" or "Non-Methylated")
+
+     Returns:
+         List of classifications for each sequence.
+     """
+     state_labels = ["Non-Methylated", "Methylated"]
+     target_idx = state_labels.index(target_state)
+     batch_size = predicted_states_batch.shape[0]
+
+     all_classifications = []
+
+     for b in range(batch_size):
+         predicted_states = predicted_states_batch[b].cpu().numpy()
+         probabilities = probabilities_batch[b].cpu().numpy()
+
+         regions = []
+         current_start, current_length, current_probs = None, 0, []
+
+         for i, state_index in enumerate(predicted_states):
+             state_prob = probabilities[i][state_index]
+             if state_index == target_idx:
+                 if current_start is None:
+                     current_start = i
+                 current_length += 1
+                 current_probs.append(state_prob)
+             elif current_start is not None:
+                 regions.append((current_start, current_length, np.mean(current_probs)))
+                 current_start, current_length, current_probs = None, 0, []
+
+         if current_start is not None:
+             regions.append((current_start, current_length, np.mean(current_probs)))
+
+         final = []
+         for start, length, prob in regions:
+             feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+             label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+             final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+         all_classifications.append(final)
+
+     return all_classifications
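For orientation, here is a minimal sketch of how apply_hmm_batched might be invoked, assuming an AnnData whose .var table already carries the {ref}_GpC_site / {ref}_CpG_site / {ref}_any_C_site annotations that the mask lookup expects; the file name and object contents are illustrative, not from the diff:

import anndata as ad
from smftools.hmm import load_hmm, apply_hmm_batched

adata = ad.read_h5ad("smf_experiment.h5ad")   # hypothetical input produced upstream
model = load_hmm("gpc_footprint_hmm.pt")      # hypothetical pretrained 2-state HMM
apply_hmm_batched(
    adata,
    model,
    obs_column="Reference_strand",  # must be categorical; one category per reference
    layer=None,                     # None falls back to adata.X
    footprints=True,
    methbases=["GpC"],
    threshold=0.7,
)
# Outputs: binary masks in adata.layers["GpC_putative_nucleosome"] etc.,
# counts in adata.obs["n_<feature>"], and gap lists in adata.obs["<feature>_distances"].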
smftools/hmm/calculate_distances.py
@@ -0,0 +1,18 @@
+ # calculate_distances
+
+ def calculate_distances(intervals, threshold=0.9):
+     """
+     Calculates distances between features in a read.
+     Takes a list of intervals [start, length, probability] and ignores
+     intervals whose probability is at or below the threshold.
+     """
+     # Sort intervals by start position, then drop low-probability calls
+     intervals = sorted(intervals, key=lambda x: x[0])
+     intervals = [interval for interval in intervals if interval[2] > threshold]
+
+     # Calculate gaps between the end of each feature and the start of the next
+     distances = []
+     for i in range(len(intervals) - 1):
+         end_current = intervals[i][0] + intervals[i][1]
+         start_next = intervals[i + 1][0]
+         distances.append(start_next - end_current)
+     return distances
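A small worked example of the gap arithmetic (values illustrative): with the default threshold of 0.9, the low-confidence interval is dropped and the single reported distance is 250 - (100 + 50) = 100.

intervals = [[100, 50, 0.95], [400, 20, 0.50], [250, 30, 0.99]]  # [start, length, prob]
calculate_distances(intervals)  # -> [100]; the 0.50 interval falls below the threshold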
smftools/hmm/call_hmm_peaks.py
@@ -0,0 +1,106 @@
+ def call_hmm_peaks(
+     adata,
+     feature_configs,
+     obs_column='Reference_strand',
+     site_types=['GpC_site', 'CpG_site'],
+     save_plot=False,
+     output_dir=None,
+     date_tag=None,
+     inplace=False
+ ):
+     import numpy as np
+     import pandas as pd
+     import matplotlib.pyplot as plt
+     from scipy.signal import find_peaks
+
+     if not inplace:
+         adata = adata.copy()
+
+     # Ensure obs_column is categorical
+     if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
+         adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])
+
+     coordinates = adata.var_names.astype(int).values
+     peak_columns = []
+
+     obs_updates = {}
+
+     for feature_layer, config in feature_configs.items():
+         min_distance = config.get('min_distance', 200)
+         peak_width = config.get('peak_width', 200)
+         peak_prominence = config.get('peak_prominence', 0.2)
+         peak_threshold = config.get('peak_threshold', 0.8)
+
+         matrix = adata.layers[feature_layer]
+         means = np.mean(matrix, axis=0)
+         peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
+         peak_centers = coordinates[peak_indices]
+         adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()
+
+         # Plot
+         plt.figure(figsize=(6, 3))
+         plt.plot(coordinates, means)
+         plt.title(f"{feature_layer} with peak calls")
+         plt.xlabel("Genomic position")
+         plt.ylabel("Mean intensity")
+         for i, center in enumerate(peak_centers):
+             start, end = center - peak_width // 2, center + peak_width // 2
+             plt.axvspan(start, end, color='purple', alpha=0.2)
+             plt.axvline(center, color='red', linestyle='--')
+             aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
+             plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
+         if save_plot and output_dir:
+             filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
+             plt.savefig(filename, bbox_inches='tight')
+             print(f"Saved plot to {filename}")
+         else:
+             plt.show()
+
+         feature_peak_columns = []
+         for center in peak_centers:
+             start, end = center - peak_width // 2, center + peak_width // 2
+             colname = f'{feature_layer}_peak_{center}'
+             peak_columns.append(colname)
+             feature_peak_columns.append(colname)
+
+             peak_mask = (coordinates >= start) & (coordinates <= end)
+             adata.var[colname] = peak_mask
+
+             region = matrix[:, peak_mask]
+             obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
+             obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
+             obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold
+
+             for site_type in site_types:
+                 adata.obs[f'{site_type}_sum_around_{center}'] = 0
+                 adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
+
+             for ref in adata.obs[obs_column].cat.categories:
+                 ref_idx = adata.obs[obs_column] == ref
+                 for site_type in site_types:
+                     mask_key = f"{ref}_{site_type}"
+                     if mask_key not in adata.var:
+                         continue
+                     site_mask = adata.var[mask_key].values
+                     site_coords = coordinates[site_mask]
+                     region_mask = (site_coords >= start) & (site_coords <= end)
+                     if not region_mask.any():
+                         continue
+                     full_mask = site_mask.copy()
+                     full_mask[site_mask] = region_mask
+                     site_region = adata[ref_idx, full_mask].X
+                     if hasattr(site_region, "A"):
+                         site_region = site_region.A
+                     if site_region.shape[1] > 0:
+                         adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
+                         adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)
+
+         adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
+         print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")
+
+     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
+     adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)
+
+     return adata if not inplace else None
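feature_configs maps an existing adata.layers key to its peak-calling parameters; any omitted key falls back to the defaults read with config.get above. A hedged sketch, using a layer name that follows the {methbase}_{feature} convention from apply_hmm_batched (the layer must already exist on your object):

feature_configs = {
    "GpC_putative_nucleosome": {
        "min_distance": 200,      # minimum spacing passed to scipy.signal.find_peaks
        "peak_width": 200,        # window aggregated around each peak center
        "peak_prominence": 0.2,
        "peak_threshold": 0.8,    # mean-in-window cutoff for the "present_at" flag
    }
}
adata = call_hmm_peaks(adata, feature_configs, obs_column="Reference_strand")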
smftools/hmm/display_hmm.py
@@ -0,0 +1,18 @@
+ def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
+     import torch
+     print("\n**HMM Model Overview**")
+     print(hmm)
+
+     print("\n**Transition Matrix**")
+     transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
+     for i, row in enumerate(transition_matrix):
+         label = state_labels[i] if state_labels else f"State {i}"
+         formatted_row = ", ".join(f"{p:.6f}" for p in row)
+         print(f"{label}: [{formatted_row}]")
+
+     print("\n**Emission Probabilities**")
+     for i, dist in enumerate(hmm.distributions):
+         label = state_labels[i] if state_labels else f"State {i}"
+         probs = dist.probs.detach().cpu().numpy()
+         formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
+         print(f"{label}: {formatted_emissions}")
smftools/hmm/hmm_readwrite.py
@@ -0,0 +1,16 @@
+ def load_hmm(model_path, device='cpu'):
+     """
+     Reads in a pretrained HMM.
+
+     Parameters:
+         model_path (str): Path to a pretrained HMM
+     """
+     import torch
+     # Load model using PyTorch
+     hmm = torch.load(model_path)
+     hmm.to(device)
+     return hmm
+
+ def save_hmm(model, model_path):
+     import torch
+     torch.save(model, model_path)
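Since save_hmm/load_hmm are thin wrappers around torch.save/torch.load of the entire model object, a round trip is one line each (path illustrative; as with any pickled model, only load files you trust):

from smftools.hmm import save_hmm, load_hmm, display_hmm

save_hmm(hmm, "gpc_footprint_hmm.pt")                 # hypothetical path
hmm = load_hmm("gpc_footprint_hmm.pt", device="cpu")
display_hmm(hmm)                                      # prints transition and emission tables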
smftools/hmm/nucleosome_hmm_refinement.py
@@ -0,0 +1,104 @@
+ def refine_nucleosome_calls(adata, layer_name, nan_mask_layer, hexamer_size=120, octamer_size=147, max_wiggle=40, device="cpu"):
+     import numpy as np
+
+     nucleosome_layer = adata.layers[layer_name]
+     nan_mask = adata.layers[nan_mask_layer]  # should be a binary mask: 1 = NaN region, 0 = valid data
+
+     hexamer_layer = np.zeros_like(nucleosome_layer)
+     octamer_layer = np.zeros_like(nucleosome_layer)
+
+     for read_idx, row in enumerate(nucleosome_layer):
+         in_patch = False
+         start_idx = None
+
+         for pos in range(len(row)):
+             if row[pos] == 1 and not in_patch:
+                 in_patch = True
+                 start_idx = pos
+             if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
+                 in_patch = False
+                 end_idx = pos if row[pos] == 0 else pos + 1
+
+                 # Expand boundaries into NaNs
+                 left_expand = 0
+                 right_expand = 0
+
+                 # Left
+                 for i in range(1, max_wiggle + 1):
+                     if start_idx - i >= 0 and nan_mask[read_idx, start_idx - i] == 1:
+                         left_expand += 1
+                     else:
+                         break
+                 # Right
+                 for i in range(1, max_wiggle + 1):
+                     if end_idx + i < nucleosome_layer.shape[1] and nan_mask[read_idx, end_idx + i] == 1:
+                         right_expand += 1
+                     else:
+                         break
+
+                 expanded_start = start_idx - left_expand
+                 expanded_end = end_idx + right_expand
+
+                 available_size = expanded_end - expanded_start
+
+                 # Octamer placement
+                 if available_size >= octamer_size:
+                     center = (expanded_start + expanded_end) // 2
+                     half_oct = octamer_size // 2
+                     octamer_layer[read_idx, center - half_oct: center - half_oct + octamer_size] = 1
+
+                 # Hexamer placement
+                 elif available_size >= hexamer_size:
+                     center = (expanded_start + expanded_end) // 2
+                     half_hex = hexamer_size // 2
+                     hexamer_layer[read_idx, center - half_hex: center - half_hex + hexamer_size] = 1
+
+     adata.layers[f"{layer_name}_hexamers"] = hexamer_layer
+     adata.layers[f"{layer_name}_octamers"] = octamer_layer
+
+     print(f"Added layers: {layer_name}_hexamers and {layer_name}_octamers")
+     return adata
+
+ def infer_nucleosomes_in_large_bound(adata, large_bound_layer, combined_nuc_layer, nan_mask_layer, nuc_size=147, linker_size=50, exclusion_buffer=30, device="cpu"):
+     import numpy as np
+
+     large_bound = adata.layers[large_bound_layer]
+     existing_nucs = adata.layers[combined_nuc_layer]
+     nan_mask = adata.layers[nan_mask_layer]
+
+     inferred_layer = np.zeros_like(large_bound)
+
+     for read_idx, row in enumerate(large_bound):
+         in_patch = False
+         start_idx = None
+
+         for pos in range(len(row)):
+             if row[pos] == 1 and not in_patch:
+                 in_patch = True
+                 start_idx = pos
+             if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
+                 in_patch = False
+                 end_idx = pos if row[pos] == 0 else pos + 1
+
+                 # Adjust boundaries into flanking NaN regions without getting too close to existing nucleosomes
+                 left_expand = start_idx
+                 while left_expand > 0 and nan_mask[read_idx, left_expand - 1] == 1 and np.sum(existing_nucs[read_idx, max(0, left_expand - exclusion_buffer):left_expand]) == 0:
+                     left_expand -= 1
+
+                 right_expand = end_idx
+                 while right_expand < row.shape[0] and nan_mask[read_idx, right_expand] == 1 and np.sum(existing_nucs[read_idx, right_expand:min(row.shape[0], right_expand + exclusion_buffer)]) == 0:
+                     right_expand += 1
+
+                 # Phase nucleosomes with linker spacing only
+                 region = (left_expand, right_expand)
+                 pos_cursor = region[0]
+                 while pos_cursor + nuc_size <= region[1]:
+                     if np.all((existing_nucs[read_idx, pos_cursor - exclusion_buffer:pos_cursor + nuc_size + exclusion_buffer] == 0)):
+                         inferred_layer[read_idx, pos_cursor:pos_cursor + nuc_size] = 1
+                         pos_cursor += nuc_size + linker_size
+                     else:
+                         pos_cursor += 1
+
+     adata.layers[f"{large_bound_layer}_phased_nucleosomes"] = inferred_layer
+     print(f"Added layer: {large_bound_layer}_phased_nucleosomes")
+     return adata
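Both refinement passes operate on binary layers such as those written by apply_hmm_batched, plus a NaN-mask layer (1 = missing data) that must be prepared separately. A hedged sketch of the intended chain; the layer names here are assumptions following the {methbase}_{feature} convention above:

adata = refine_nucleosome_calls(
    adata,
    layer_name="GpC_putative_nucleosome",
    nan_mask_layer="nan_mask",                # hypothetical precomputed mask layer
)
adata = infer_nucleosomes_in_large_bound(
    adata,
    large_bound_layer="GpC_large_bound_stretch",
    combined_nuc_layer="GpC_putative_nucleosome_octamers",  # written by the call above
    nan_mask_layer="nan_mask",
)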
smftools/hmm/train_hmm.py
@@ -0,0 +1,78 @@
+ def train_hmm(
+     data,
+     emission_probs=[[0.8, 0.2], [0.2, 0.8]],
+     transitions=[[0.9, 0.1], [0.1, 0.9]],
+     start_probs=[0.5, 0.5],
+     end_probs=[0.5, 0.5],
+     device=None,
+     max_iter=50,
+     verbose=True,
+     tol=50,
+     pad_value=0,
+ ):
+     """
+     Trains a 2-state DenseHMM model on binary methylation/deamination data.
+
+     Parameters:
+         data (list or np.ndarray): List of sequences (lists) with 0, 1, or NaN.
+         emission_probs (list): List of emission probabilities for two states.
+         transitions (list): Transition matrix between states.
+         start_probs (list): Initial state probabilities.
+         end_probs (list): End state probabilities.
+         device (str or torch.device): "cpu", "mps", "cuda", or None (auto).
+         max_iter (int): Maximum EM iterations.
+         verbose (bool): Verbose output from pomegranate.
+         tol (float): Convergence tolerance.
+         pad_value (int): Value used to pad shorter sequences.
+
+     Returns:
+         hmm: Trained DenseHMM model
+     """
+     import torch
+     from pomegranate.hmm import DenseHMM
+     from pomegranate.distributions import Categorical
+     import numpy as np
+     from tqdm import tqdm
+
+     # Auto device detection
+     if device is None:
+         device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
+     elif isinstance(device, str):
+         device = torch.device(device)
+     print(f"Using device: {device}")
+
+     # Ensure emission probs on correct device
+     dists = [
+         Categorical(torch.tensor([p], device=device))
+         for p in emission_probs
+     ]
+
+     # Create DenseHMM
+     hmm = DenseHMM(
+         distributions=dists,
+         edges=transitions,
+         starts=start_probs,
+         ends=end_probs,
+         verbose=verbose,
+         max_iter=max_iter,
+         tol=tol,
+     ).to(device)
+
+     # Convert data to list if needed
+     if isinstance(data, np.ndarray):
+         data = data.tolist()
+
+     # Preprocess data (replace NaNs + pad)
+     max_length = max(len(seq) for seq in data)
+     processed_data = []
+     for sequence in tqdm(data, desc="Preprocessing Sequences"):
+         cleaned_seq = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in sequence]
+         cleaned_seq += [pad_value] * (max_length - len(cleaned_seq))
+         processed_data.append(cleaned_seq)
+
+     tensor_data = torch.tensor(processed_data, dtype=torch.long, device=device).unsqueeze(-1)
+
+     # Fit HMM
+     hmm.fit(tensor_data)
+
+     return hmm
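A minimal end-to-end sketch on synthetic binary data (values illustrative; as in the function body, NaNs would be imputed with random 0/1 draws and short reads padded):

import numpy as np
from smftools.hmm import train_hmm

rng = np.random.default_rng(0)
data = rng.integers(0, 2, size=(100, 500)).astype(float)  # 100 toy reads, 500 sites each
hmm = train_hmm(data, device="cpu", max_iter=50, verbose=False)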
smftools/informatics/__init__.py
@@ -0,0 +1,14 @@
+ from . import helpers
+ from .basecall_pod5s import basecall_pod5s
+ from .subsample_fasta_from_bed import subsample_fasta_from_bed
+ from .subsample_pod5 import subsample_pod5
+ from .fast5_to_pod5 import fast5_to_pod5
+
+
+ __all__ = [
+     "basecall_pod5s",
+     "subsample_fasta_from_bed",
+     "subsample_pod5",
+     "fast5_to_pod5",
+     "helpers"
+ ]