smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/apply_hmm_batched.py
@@ -0,0 +1,241 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ def apply_hmm_batched(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A"], device="cpu", threshold=0.7):
+     """
+     Applies an HMM model to an AnnData object using tensor-based sequence inputs.
+     If multiple methbases are passed, a combined feature set is generated as well.
+     """
+     import numpy as np
+     import torch
+     from tqdm import tqdm
+
+     model.to(device)
+
+     # --- Feature definitions ---
+     feature_sets = {}
+     if footprints:
+         feature_sets["footprint"] = {
+             "features": {
+                 "small_bound_stretch": [0, 20],
+                 "medium_bound_stretch": [20, 50],
+                 "putative_nucleosome": [50, 200],
+                 "large_bound_stretch": [200, np.inf]
+             },
+             "state": "Non-Methylated"
+         }
+     if accessible_patches:
+         feature_sets["accessible"] = {
+             "features": {
+                 "small_accessible_patch": [0, 20],
+                 "mid_accessible_patch": [20, 80],
+                 "large_accessible_patch": [80, np.inf]
+             },
+             "state": "Methylated"
+         }
+     if cpg:
+         feature_sets["cpg"] = {
+             "features": {
+                 "cpg_patch": [0, np.inf]
+             },
+             "state": "Methylated"
+         }
+
+     # --- Initialize obs columns ---
+     all_features = []
+     combined_prefix = "Combined"
+     for key, fs in feature_sets.items():
+         if key == 'cpg':
+             all_features += [f"CpG_{f}" for f in fs["features"]]
+             all_features.append(f"CpG_all_{key}_features")
+         else:
+             for methbase in methbases:
+                 all_features += [f"{methbase}_{f}" for f in fs["features"]]
+                 all_features.append(f"{methbase}_all_{key}_features")
+             if len(methbases) > 1:
+                 all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+                 all_features.append(f"{combined_prefix}_all_{key}_features")
+
+     for feature in all_features:
+         adata.obs[feature] = [[] for _ in range(adata.shape[0])]
+         adata.obs[f"{feature}_distances"] = [None] * adata.shape[0]
+         adata.obs[f"n_{feature}"] = -1
+
+     # --- Main loop ---
+     references = adata.obs[obs_column].cat.categories
+
+     for ref in tqdm(references, desc="Processing References"):
+         ref_subset = adata[adata.obs[obs_column] == ref]
+
+         # Per-methbase processing; also accumulate the combined methbase mask
+         combined_mask = None
+         for methbase in methbases:
+             mask = {
+                 "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+                 "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+                 "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+             }[methbase.lower()]
+             combined_mask = mask if combined_mask is None else combined_mask | mask
+
+             methbase_subset = ref_subset[:, mask]
+             matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+             # Missing calls (NaN) are randomly imputed as 0/1 so the tensor is fully integer
+             processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+             tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = methbase_subset.var_names
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 pred_states = model.predict(tensor_batch)
+                 probs = model.predict_proba(tensor_batch)
+                 classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                 for i, idx in enumerate(methbase_subset.obs.index):
+                     for start, length, label, prob in classifications[i]:
+                         adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+         # Combined subset across all methbases
+         if len(methbases) > 1:
+             combined_subset = ref_subset[:, combined_mask]
+             combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+             processed_combined_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in combined_matrix]
+             tensor_combined_batch = torch.tensor(processed_combined_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = combined_subset.var_names
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 pred_states = model.predict(tensor_combined_batch)
+                 probs = model.predict_proba(tensor_combined_batch)
+                 classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                 for i, idx in enumerate(combined_subset.obs.index):
+                     for start, length, label, prob in classifications[i]:
+                         adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+     # --- Special handling for CpG ---
+     if cpg:
+         for ref in tqdm(references, desc="Processing CpG"):
+             ref_subset = adata[adata.obs[obs_column] == ref]
+             mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+             cpg_subset = ref_subset[:, mask]
+             matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+             processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+             tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = cpg_subset.var_names
+             fs = feature_sets['cpg']
+             state_target = fs["state"]
+             feature_map = fs["features"]
+
+             pred_states = model.predict(tensor_batch)
+             probs = model.predict_proba(tensor_batch)
+             classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+             for i, idx in enumerate(cpg_subset.obs.index):
+                 for start, length, label, prob in classifications[i]:
+                     adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                     adata.obs.at[idx, "CpG_all_cpg_features"].append([start, length, prob])
+
+     # --- Binarization + distances ---
+     for feature in tqdm(all_features, desc="Finalizing Layers"):
+         bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+         counts = np.zeros(adata.shape[0], dtype=int)
+         for row_idx, intervals in enumerate(adata.obs[feature]):
+             if not isinstance(intervals, list):
+                 intervals = []
+             for start, length, prob in intervals:
+                 if prob > threshold:
+                     bin_matrix[row_idx, start:start+length] = 1
+                     counts[row_idx] += 1
+         adata.layers[feature] = bin_matrix
+         adata.obs[f"n_{feature}"] = counts
+         adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
+
+ def calculate_batch_distances(intervals_list, threshold=0.9):
+     """
+     Vectorized calculation of distances across multiple reads.
+
+     Parameters:
+         intervals_list (list of list): Outer list = reads, inner list = intervals [start, length, prob]
+         threshold (float): Minimum probability threshold for filtering
+
+     Returns:
+         List of distance lists per read.
+     """
+     results = []
+     for intervals in intervals_list:
+         if not isinstance(intervals, list) or len(intervals) == 0:
+             results.append([])
+             continue
+         valid = [iv for iv in intervals if iv[2] > threshold]
+         valid = sorted(valid, key=lambda x: x[0])
+         dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
+         results.append(dists)
+     return results
+
+
+ def classify_batch(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Methylated"):
+     """
+     Classify batched sequences efficiently.
+
+     Parameters:
+         predicted_states_batch: Tensor [batch_size, seq_len]
+         probabilities_batch: Tensor [batch_size, seq_len, n_states]
+         coordinates: list of genomic coordinates
+         classification_mapping: dict of feature size bins
+         target_state: state name ("Methylated" or "Non-Methylated")
+
+     Returns:
+         List of classifications for each sequence.
+     """
+     import numpy as np
+
+     state_labels = ["Non-Methylated", "Methylated"]
+     target_idx = state_labels.index(target_state)
+     batch_size = predicted_states_batch.shape[0]
+
+     all_classifications = []
+
+     for b in range(batch_size):
+         predicted_states = predicted_states_batch[b].cpu().numpy()
+         probabilities = probabilities_batch[b].cpu().numpy()
+
+         regions = []
+         current_start, current_length, current_probs = None, 0, []
+
+         for i, state_index in enumerate(predicted_states):
+             state_prob = probabilities[i][state_index]
+             if state_index == target_idx:
+                 if current_start is None:
+                     current_start = i
+                 current_length += 1
+                 current_probs.append(state_prob)
+             elif current_start is not None:
+                 regions.append((current_start, current_length, np.mean(current_probs)))
+                 current_start, current_length, current_probs = None, 0, []
+
+         if current_start is not None:
+             regions.append((current_start, current_length, np.mean(current_probs)))
+
+         final = []
+         for start, length, prob in regions:
+             feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+             label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+             final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+         all_classifications.append(final)
+
+     return all_classifications
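
A minimal usage sketch for the new batched HMM runner above. The AnnData layout (a categorical Reference_strand obs column, per-reference *_GpC_site/*_CpG_site/*_strand_FASTA_base var columns, and a binarized layer) and the trained two-state model `hmm` are assumptions for illustration, as is importing the function from smftools.tools.

apply_hmm_batched(
    adata,
    model=hmm,                      # trained HMM exposing predict()/predict_proba() on [batch, seq_len, 1] tensors
    obs_column="Reference_strand",  # assumed categorical obs column
    layer="binarized_methylation",  # assumed binary layer; .X is used when layer=None
    footprints=True,
    accessible_patches=True,
    methbases=["GpC"],
    device="cpu",
    threshold=0.7,
)
# Per-read interval lists, counts, and gap distances land in adata.obs,
# and one binary layer per feature name is written to adata.layers.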
smftools/tools/archived/classify_methylated_features.py
@@ -0,0 +1,66 @@
+ # classify_methylated_features
+
+ def classify_methylated_features(read, model, coordinates, classification_mapping={}):
+     """
+     Classifies methylated features (accessible features or CpG methylated features).
+
+     Parameters:
+         read (np.ndarray): An array of binarized SMF data representing a read
+         model: a trained pomegranate HMM
+         coordinates (list): a list of positional coordinates corresponding to the positions in the read
+         classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+     Returns:
+         final_classifications (list): A list of tuples, where each tuple is an instance of a methylated feature in the read. The tuple contains: feature start, feature length, feature classification, and HMM probability.
+     """
+     import numpy as np
+
+     sequence = list(read)
+     # Get the predicted states using the MAP algorithm
+     predicted_states = model.predict(sequence, algorithm='map')
+
+     # Get the probabilities for each state using the forward-backward algorithm
+     probabilities = model.predict_proba(sequence)
+
+     # Initialize lists to store the classifications and their probabilities
+     classifications = []
+     current_start = None
+     current_length = 0
+     current_probs = []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = model.states[state_index].name
+         state_prob = probabilities[i][state_index]
+
+         if state_name == "Methylated":
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         else:
+             if current_start is not None:
+                 avg_prob = np.mean(current_probs)
+                 classifications.append((current_start, current_length, "Methylated", avg_prob))
+                 current_start = None
+                 current_length = 0
+                 current_probs = []
+
+     if current_start is not None:
+         avg_prob = np.mean(current_probs)
+         classifications.append((current_start, current_length, "Methylated", avg_prob))
+
+     final_classifications = []
+     for start, length, classification, prob in classifications:
+         final_classification = None
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         # Map the feature length onto a named size bin, if one matches
+         for feature_type, size_range in classification_mapping.items():
+             if size_range[0] <= feature_length < size_range[1]:
+                 final_classification = feature_type
+                 break
+         if not final_classification:
+             final_classification = classification
+
+         final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+     return final_classifications
smftools/tools/archived/classify_non_methylated_features.py
@@ -0,0 +1,75 @@
+ # classify_non_methylated_features
+
+ def classify_non_methylated_features(read, model, coordinates, classification_mapping={}):
+     """
+     Classifies non-methylated features (inaccessible features).
+
+     Parameters:
+         read (np.ndarray): An array of binarized SMF data representing a read
+         model: a trained pomegranate HMM
+         coordinates (list): a list of positional coordinates corresponding to the positions in the read
+         classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+     Returns:
+         final_classifications (list): A list of tuples, where each tuple is an instance of a non-methylated feature in the read. The tuple contains: feature start, feature length, feature classification, and HMM probability.
+     """
+     import numpy as np
+
+     sequence = list(read)
+     # Get the predicted states using the MAP algorithm
+     predicted_states = model.predict(sequence, algorithm='map')
+
+     # Get the probabilities for each state using the forward-backward algorithm
+     probabilities = model.predict_proba(sequence)
+
+     # Initialize lists to store the classifications and their probabilities
+     classifications = []
+     current_start = None
+     current_length = 0
+     current_probs = []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = model.states[state_index].name
+         state_prob = probabilities[i][state_index]
+
+         if state_name == "Non-Methylated":
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         else:
+             if current_start is not None:
+                 avg_prob = np.mean(current_probs)
+                 classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+                 current_start = None
+                 current_length = 0
+                 current_probs = []
+
+     if current_start is not None:
+         avg_prob = np.mean(current_probs)
+         classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+
+     final_classifications = []
+     for start, length, classification, prob in classifications:
+         final_classification = None
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         # Map the feature length onto a named size bin, if one matches
+         for feature_type, size_range in classification_mapping.items():
+             if size_range[0] <= feature_length < size_range[1]:
+                 final_classification = feature_type
+                 break
+         if not final_classification:
+             final_classification = classification
+
+         final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+     return final_classifications
+
+ # if feature_length < 80:
+ #     final_classification = 'small_bound_stretch'
+ # elif 80 <= feature_length <= 200:
+ #     final_classification = 'Putative_Nucleosome'
+ # elif 200 < feature_length:
+ #     final_classification = 'large_bound_stretch'
+ # else:
+ #     pass
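
For reference, a sketch of how the two archived per-read classifiers above could be driven on a single binarized read. The read, coordinates, size bins, and the trained pomegranate HMM `hmm` (with "Methylated"/"Non-Methylated" states) are all illustrative assumptions.

import numpy as np

read = np.array([1, 1, 0, 0, 0, 0, 0, 1, 1, 1])   # 1 = methylated/accessible, 0 = protected
coordinates = list(range(1000, 1010))              # genomic position of each call

footprint_bins = {"small_bound_stretch": [0, 80],
                  "putative_nucleosome": [80, 200],
                  "large_bound_stretch": [200, np.inf]}
patch_bins = {"small_accessible_patch": [0, 20],
              "large_accessible_patch": [20, np.inf]}

footprints = classify_non_methylated_features(read, hmm, coordinates, footprint_bins)
patches = classify_methylated_features(read, hmm, coordinates, patch_bins)
# Each result is a list of (feature_start, feature_length, label, mean_probability) tuples.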
smftools/tools/archived/subset_adata_v1.py
@@ -0,0 +1,32 @@
+ # subset_adata
+
+ def subset_adata(adata, obs_columns):
+     """
+     Subsets an AnnData object based on categorical values in specified `.obs` columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subset.
+         obs_columns (list of str): List of `.obs` column names to subset by. The order matters.
+
+     Returns:
+         dict: A dictionary where keys are tuples of category values and values are the corresponding AnnData subsets.
+     """
+
+     def subset_recursive(adata_subset, columns):
+         if not columns:
+             return {(): adata_subset}
+
+         current_column = columns[0]
+         categories = adata_subset.obs[current_column].cat.categories
+
+         subsets = {}
+         for cat in categories:
+             subset = adata_subset[adata_subset.obs[current_column] == cat]
+             subsets.update(subset_recursive(subset, columns[1:]))
+
+         return subsets
+
+     # Start the recursive subset process
+     subsets_dict = subset_recursive(adata, obs_columns)
+
+     return subsets_dict
smftools/tools/archived/subset_adata_v2.py
@@ -0,0 +1,46 @@
+ # subset_adata
+
+ def subset_adata(adata, columns, cat_type='obs'):
+     """
+     Subsets an AnnData object based on categorical values in specified .obs or .var columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subset.
+         columns (list of str): List of .obs or .var column names to subset by. The order matters.
+         cat_type (str): 'obs' or 'var'. Default is 'obs'.
+
+     Returns:
+         dict: A dictionary where keys are tuples of category values and values are the corresponding AnnData subsets.
+     """
+
+     def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
+         # Returns when the bottom of the stack is reached
+         if not columns:
+             # If there's only one column, return the key as a single value, not a tuple
+             if len(key_prefix) == 1:
+                 return {key_prefix[0]: adata_subset}
+             return {key_prefix: adata_subset}
+
+         current_column = columns[0]
+         subsets = {}
+
+         if 'obs' in cat_type:
+             categories = adata_subset.obs[current_column].cat.categories
+             for cat in categories:
+                 subset = adata_subset[adata_subset.obs[current_column] == cat].copy()
+                 new_key = key_prefix + (cat,)
+                 subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+         elif 'var' in cat_type:
+             categories = adata_subset.var[current_column].cat.categories
+             for cat in categories:
+                 subset = adata_subset[:, adata_subset.var[current_column] == cat].copy()
+                 new_key = key_prefix + (cat,)
+                 subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+         return subsets
+
+     # Start the recursive subset process
+     subsets_dict = subset_recursive(adata, columns, cat_type)
+
+     return subsets_dict
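
A short sketch of the archived v2 `subset_adata` signature in use; the column names are illustrative.

# Split reads by reference and then by sample; keys are (reference, sample) tuples.
# 'Reference_strand' and 'Sample' are assumed categorical .obs columns.
subsets = subset_adata(adata, ["Reference_strand", "Sample"], cat_type="obs")
for key, sub in subsets.items():
    print(key, sub.shape)

# A single column returns plain keys rather than 1-tuples:
by_ref = subset_adata(adata, ["Reference_strand"])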
smftools/tools/calculate_distances.py
@@ -0,0 +1,18 @@
+ # calculate_distances
+
+ def calculate_distances(intervals, threshold=0.9):
+     """
+     Calculates the distance between features in a read.
+     Takes a list of intervals (feature start, feature length, feature probability) and ignores intervals whose probability is at or below the threshold.
+     """
+     # Sort intervals by start position, then keep only confident features
+     intervals = sorted(intervals, key=lambda x: x[0])
+     intervals = [interval for interval in intervals if interval[2] > threshold]
+
+     # Calculate distances between the end of one feature and the start of the next
+     distances = []
+     for i in range(len(intervals) - 1):
+         end_current = intervals[i][0] + intervals[i][1]
+         start_next = intervals[i + 1][0]
+         distances.append(start_next - end_current)
+     return distances
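
A worked example of the distance calculation above: intervals are (start, length, probability), low-confidence intervals are dropped, and each gap is the next start minus the previous end.

intervals = [(100, 50, 0.95), (200, 30, 0.99), (400, 20, 0.50)]
calculate_distances(intervals, threshold=0.9)
# -> [50]    (200 - (100 + 50)); the 0.50-probability interval is filtered out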
smftools/tools/calculate_umap.py
@@ -0,0 +1,62 @@
+ def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neighbors=100, overwrite=True, threads=8):
+     import scanpy as sc
+     import numpy as np
+     import os
+     from scipy.sparse import issparse
+
+     os.environ["OMP_NUM_THREADS"] = str(threads)
+
+     # Step 1: Apply var filter
+     if var_filters:
+         subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+         adata_subset = adata[:, subset_mask].copy()
+         print(f"🔹 Subsetting adata: Retained {adata_subset.shape[1]} features based on filters {var_filters}")
+     else:
+         adata_subset = adata.copy()
+         print("🔹 No var filters provided. Using all features.")
+
+     # Step 2: NaN handling inside the layer
+     if layer:
+         data = adata_subset.layers[layer]
+         if not issparse(data):
+             if np.isnan(data).any():
+                 print("⚠ NaNs detected, filling with 0.5 before PCA + neighbors.")
+                 data = np.nan_to_num(data, nan=0.5)
+                 adata_subset.layers[layer] = data
+             else:
+                 print("✅ No NaNs detected.")
+         else:
+             print("✅ Sparse matrix detected; skipping NaN check (sparse formats typically do not store NaNs).")
+
+     # Step 3: PCA + neighbors + UMAP on the subset
+     if "X_umap" not in adata_subset.obsm or overwrite:
+         n_pcs = min(adata_subset.shape[1], n_pcs)
+         print(f"Running PCA with n_pcs={n_pcs}")
+         sc.pp.pca(adata_subset, layer=layer)
+         print('Running neighborhood graph')
+         sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
+         print('Running UMAP')
+         sc.tl.umap(adata_subset)
+
+         # Step 4: Store results in the original adata
+         adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
+         adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
+         adata.obsp["distances"] = adata_subset.obsp["distances"]
+         adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
+         adata.uns["neighbors"] = adata_subset.uns["neighbors"]
+
+         # Fix varm["PCs"] shape mismatch between the subset and the full object
+         pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
+         if var_filters:
+             subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+             pc_matrix[subset_mask, :] = adata_subset.varm["PCs"]
+         else:
+             pc_matrix = adata_subset.varm["PCs"]  # No subsetting case
+
+         adata.varm["PCs"] = pc_matrix
+
+     print("✅ Stored: adata.obsm['X_pca'] and adata.obsm['X_umap']")
+
+     return adata
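
A usage sketch for the UMAP helper above; the var filter names are assumptions about how a particular AnnData has been annotated, while 'nan_half' matches the function's default layer.

# Embed reads using only positions flagged by the given boolean .var columns;
# NaNs in the layer are filled with 0.5 before PCA, and the embedding is
# copied back onto the full object.
adata = calculate_umap(
    adata,
    layer="nan_half",
    var_filters=["GpC_site_any", "CpG_site_any"],  # hypothetical boolean .var columns
    n_pcs=15,
    knn_neighbors=100,
    threads=8,
)
adata.obsm["X_umap"]  # 2D embedding per read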
smftools/tools/call_hmm_peaks.py
@@ -0,0 +1,105 @@
+ def call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand', site_types=['GpC_site', 'CpG_site'], save_plot=False, output_dir=None, date_tag=None):
+     """
+     Calls peaks from HMM feature layers and annotates them into the AnnData object.
+
+     Parameters:
+         adata : AnnData object with HMM feature layers (from apply_hmm)
+         feature_configs : dict keyed by feature layer name; per-layer options:
+             min_distance : minimum distance between peaks
+             peak_width : window size around peak centers
+             peak_prominence : required peak prominence
+             peak_threshold : threshold for labeling a read as "present" at a peak
+         obs_column : categorical obs column holding each read's reference
+         site_types : list of var site types to aggregate
+         save_plot : whether to save the plot
+         output_dir : path to save the figure if save_plot=True
+         date_tag : optional tag for the filename
+     """
+     import matplotlib.pyplot as plt
+     from scipy.signal import find_peaks
+     import os
+     import numpy as np
+
+     peak_columns = []
+
+     for feature_layer, config in feature_configs.items():
+         min_distance = config.get('min_distance', 200)
+         peak_width = config.get('peak_width', 200)
+         peak_prominence = config.get('peak_prominence', 0.2)
+         peak_threshold = config.get('peak_threshold', 0.8)
+
+         # 1️⃣ Calculate mean intensity profile
+         matrix = adata.layers[feature_layer]
+         means = np.mean(matrix, axis=0)
+         feature_peak_columns = []
+
+         # 2️⃣ Peak calling
+         peak_centers, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
+         adata.uns[f'{feature_layer} peak_centers'] = peak_centers
+
+         # 3️⃣ Plot
+         plt.figure(figsize=(6, 3))
+         plt.plot(range(len(means)), means)
+         plt.title(f"{feature_layer} density with peak calls")
+         plt.xlabel("Genomic position")
+         plt.ylabel("Mean feature density")
+         y = max(means) / 2
+         for i, center in enumerate(peak_centers):
+             plus_minus_width = peak_width // 2
+             start = center - plus_minus_width
+             end = center + plus_minus_width
+             plt.axvspan(start, end, color='purple', alpha=0.2)
+             plt.axvline(center, color='red', linestyle='--')
+             if i % 2:
+                 aligned = [end, 'left']
+             else:
+                 aligned = [start, 'right']
+             plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
+
+         if save_plot and output_dir:
+             filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
+             plt.savefig(filename, bbox_inches='tight')
+             print(f"Saved plot to {filename}")
+         else:
+             plt.show()
+
+         # 4️⃣ Annotate peaks back into adata.obs
+         for center in peak_centers:
+             half_width = peak_width // 2
+             start, end = center - half_width, center + half_width
+             colname = f'{feature_layer}_peak_{center}'
+             peak_columns.append(colname)
+             feature_peak_columns.append(colname)
+
+             adata.var[colname] = (
+                 (adata.var_names.astype(int) >= start) &
+                 (adata.var_names.astype(int) <= end)
+             )
+
+             # Feature layer intensity around the peak
+             mean_values = np.mean(matrix[:, start:end+1], axis=1)
+             sum_values = np.sum(matrix[:, start:end+1], axis=1)
+             adata.obs[f'mean_{feature_layer}_around_{center}'] = mean_values
+             adata.obs[f'sum_{feature_layer}_around_{center}'] = sum_values
+             adata.obs[f'{feature_layer}_present_at_{center}'] = mean_values > peak_threshold
+
+             # Site-type based aggregation
+             for site_type in site_types:
+                 adata.obs[f'{site_type}_sum_around_{center}'] = 0
+                 adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
+
+             references = adata.obs[obs_column].cat.categories
+             for ref in references:
+                 subset = adata[adata.obs[obs_column] == ref]
+                 for site_type in site_types:
+                     mask = subset.var.get(f'{ref}_{site_type}', None)
+                     if mask is not None:
+                         region_mask = (subset.var_names[mask].astype(int) >= start) & (subset.var_names[mask].astype(int) <= end)
+                         region = subset[:, mask].X[:, region_mask]
+                         adata.obs.loc[subset.obs.index, f'{site_type}_sum_around_{center}'] = np.nansum(region, axis=1)
+                         adata.obs.loc[subset.obs.index, f'{site_type}_mean_around_{center}'] = np.nanmean(region, axis=1)
+
+         adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
+         print(f"✅ Peak annotation completed for {feature_layer} with {len(peak_centers)} peaks.")
+
+     # Combine all peaks into a single "is_in_any_peak" column
+     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
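
A sketch of the expected `feature_configs` shape for the peak caller above. The layer name follows the feature layers that apply_hmm_batched writes (e.g. 'GpC_putative_nucleosome'); the threshold values are illustrative defaults.

feature_configs = {
    "GpC_putative_nucleosome": {         # an HMM feature layer already in adata.layers
        "min_distance": 200,             # minimum spacing between called peaks
        "peak_width": 200,               # window annotated around each peak center
        "peak_prominence": 0.2,          # prominence passed to scipy.signal.find_peaks
        "peak_threshold": 0.8,           # mean density needed to mark a read as "present"
    },
}

call_hmm_peaks(
    adata,
    feature_configs,
    obs_column="Reference_strand",
    site_types=["GpC_site", "CpG_site"],
    save_plot=False,
)
# Adds per-peak boolean .var columns, per-read mean/sum/present .obs columns,
# and an 'is_in_any_peak' summary column to .var.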