smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,241 @@
+import numpy as np
+import pandas as pd
+import torch
+from tqdm import tqdm
+
+def apply_hmm_batched(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A"], device="cpu", threshold=0.7):
+    """
+    Applies an HMM model to an AnnData object using tensor-based sequence inputs.
+    If multiple methbases are passed, generates a combined feature set.
+    """
+    import numpy as np
+    import torch
+    from tqdm import tqdm
+
+    model.to(device)
+
+    # --- Feature Definitions ---
+    feature_sets = {}
+    if footprints:
+        feature_sets["footprint"] = {
+            "features": {
+                "small_bound_stretch": [0, 20],
+                "medium_bound_stretch": [20, 50],
+                "putative_nucleosome": [50, 200],
+                "large_bound_stretch": [200, np.inf]
+            },
+            "state": "Non-Methylated"
+        }
+    if accessible_patches:
+        feature_sets["accessible"] = {
+            "features": {
+                "small_accessible_patch": [0, 20],
+                "mid_accessible_patch": [20, 80],
+                "large_accessible_patch": [80, np.inf]
+            },
+            "state": "Methylated"
+        }
+    if cpg:
+        feature_sets["cpg"] = {
+            "features": {
+                "cpg_patch": [0, np.inf]
+            },
+            "state": "Methylated"
+        }
+
+    # --- Init columns ---
+    all_features = []
+    combined_prefix = "Combined"
+    for key, fs in feature_sets.items():
+        if key == 'cpg':
+            all_features += [f"CpG_{f}" for f in fs["features"]]
+            all_features.append(f"CpG_all_{key}_features")
+        else:
+            for methbase in methbases:
+                all_features += [f"{methbase}_{f}" for f in fs["features"]]
+                all_features.append(f"{methbase}_all_{key}_features")
+            if len(methbases) > 1:
+                all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+                all_features.append(f"{combined_prefix}_all_{key}_features")
+
+    for feature in all_features:
+        adata.obs[feature] = [[] for _ in range(adata.shape[0])]
+        adata.obs[f"{feature}_distances"] = [None] * adata.shape[0]
+        adata.obs[f"n_{feature}"] = -1
+
+    # --- Main loop ---
+    references = adata.obs[obs_column].cat.categories
+
+    for ref in tqdm(references, desc="Processing References"):
+        ref_subset = adata[adata.obs[obs_column] == ref]
+
+        # Combined methbase mask
+        combined_mask = None
+        for methbase in methbases:
+            mask = {
+                "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+                "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+                "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+            }[methbase.lower()]
+            combined_mask = mask if combined_mask is None else combined_mask | mask
+
+            methbase_subset = ref_subset[:, mask]
+            matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+            processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+            tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+            coords = methbase_subset.var_names
+            for key, fs in feature_sets.items():
+                if key == 'cpg':
+                    continue
+                state_target = fs["state"]
+                feature_map = fs["features"]
+
+                pred_states = model.predict(tensor_batch)
+                probs = model.predict_proba(tensor_batch)
+                classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                for i, idx in enumerate(methbase_subset.obs.index):
+                    for start, length, label, prob in classifications[i]:
+                        adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+                        adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+        # Combined subset
+        if len(methbases) > 1:
+            combined_subset = ref_subset[:, combined_mask]
+            combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+            processed_combined_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in combined_matrix]
+            tensor_combined_batch = torch.tensor(processed_combined_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+            coords = combined_subset.var_names
+            for key, fs in feature_sets.items():
+                if key == 'cpg':
+                    continue
+                state_target = fs["state"]
+                feature_map = fs["features"]
+
+                pred_states = model.predict(tensor_combined_batch)
+                probs = model.predict_proba(tensor_combined_batch)
+                classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                for i, idx in enumerate(combined_subset.obs.index):
+                    for start, length, label, prob in classifications[i]:
+                        adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                        adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+    # --- Special handling for CpG ---
+    if cpg:
+        for ref in tqdm(references, desc="Processing CpG"):
+            ref_subset = adata[adata.obs[obs_column] == ref]
+            mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+            cpg_subset = ref_subset[:, mask]
+            matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+            processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+            tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+            coords = cpg_subset.var_names
+            fs = feature_sets['cpg']
+            state_target = fs["state"]
+            feature_map = fs["features"]
+
+            pred_states = model.predict(tensor_batch)
+            probs = model.predict_proba(tensor_batch)
+            classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+            for i, idx in enumerate(cpg_subset.obs.index):
+                for start, length, label, prob in classifications[i]:
+                    adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                    adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
+
+    # --- Binarization + Distance ---
+    for feature in tqdm(all_features, desc="Finalizing Layers"):
+        bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+        counts = np.zeros(adata.shape[0], dtype=int)
+        for row_idx, intervals in enumerate(adata.obs[feature]):
+            if not isinstance(intervals, list):
+                intervals = []
+            for start, length, prob in intervals:
+                if prob > threshold:
+                    bin_matrix[row_idx, start:start+length] = 1
+                    counts[row_idx] += 1
+        adata.layers[f"{feature}"] = bin_matrix
+        adata.obs[f"n_{feature}"] = counts
+        adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
+
+def calculate_batch_distances(intervals_list, threshold=0.9):
+    """
+    Vectorized calculation of distances across multiple reads.
+
+    Parameters:
+        intervals_list (list of list): Outer list = reads, inner list = intervals [start, length, prob]
+        threshold (float): Minimum probability threshold for filtering
+
+    Returns:
+        List of distance lists per read.
+    """
+    results = []
+    for intervals in intervals_list:
+        if not isinstance(intervals, list) or len(intervals) == 0:
+            results.append([])
+            continue
+        valid = [iv for iv in intervals if iv[2] > threshold]
+        valid = sorted(valid, key=lambda x: x[0])
+        dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
+        results.append(dists)
+    return results
+
+
+def classify_batch(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Methylated"):
+    """
+    Classify batch sequences efficiently.
+
+    Parameters:
+        predicted_states_batch: Tensor [batch_size, seq_len]
+        probabilities_batch: Tensor [batch_size, seq_len, n_states]
+        coordinates: list of genomic coordinates
+        classification_mapping: dict of feature bins
+        target_state: state name ("Methylated" or "Non-Methylated")
+
+    Returns:
+        List of classifications for each sequence.
+    """
+    import numpy as np
+
+    state_labels = ["Non-Methylated", "Methylated"]
+    target_idx = state_labels.index(target_state)
+    batch_size = predicted_states_batch.shape[0]
+
+    all_classifications = []
+
+    for b in range(batch_size):
+        predicted_states = predicted_states_batch[b].cpu().numpy()
+        probabilities = probabilities_batch[b].cpu().numpy()
+
+        regions = []
+        current_start, current_length, current_probs = None, 0, []
+
+        for i, state_index in enumerate(predicted_states):
+            state_prob = probabilities[i][state_index]
+            if state_index == target_idx:
+                if current_start is None:
+                    current_start = i
+                current_length += 1
+                current_probs.append(state_prob)
+            elif current_start is not None:
+                regions.append((current_start, current_length, np.mean(current_probs)))
+                current_start, current_length, current_probs = None, 0, []
+
+        if current_start is not None:
+            regions.append((current_start, current_length, np.mean(current_probs)))
+
+        final = []
+        for start, length, prob in regions:
+            feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+            label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+            final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+        all_classifications.append(final)
+
+    return all_classifications
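For orientation, a minimal usage sketch of the batched HMM runner added above. The layer and column names below are illustrative assumptions, not part of the released API; the model is assumed to be a trained HMM whose predict()/predict_proba() accept [batch, seq_len, 1] tensors, as the function's own calls imply.

    # Hypothetical usage sketch (illustrative names, not fixed by the package)
    import anndata as ad

    adata = ad.read_h5ad("smf_experiment.h5ad")   # binarized SMF reads (path is an assumption)
    apply_hmm_batched(
        adata,
        model,                                    # trained HMM with tensor predict/predict_proba
        obs_column="Reference_strand",            # categorical .obs column of references
        layer="binarized_methylation",            # illustrative layer name
        footprints=True,
        accessible_patches=True,
        methbases=["GpC", "CpG"],
        device="cpu",
        threshold=0.7,
    )
    # Per-feature binary layers land in adata.layers and per-read counts in adata.obs["n_<feature>"]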
@@ -0,0 +1,66 @@
+# classify_methylated_features
+
+def classify_methylated_features(read, model, coordinates, classification_mapping={}):
+    """
+    Classifies methylated features (accessible features or CpG methylated features)
+
+    Parameters:
+        read (np.ndarray): An array of binarized SMF data representing a read
+        model (): a trained pomegranate HMM
+        coordinates (list): a list of positional coordinates corresponding to the positions in the read
+        classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+    Returns:
+        final_classifications (list): A list of tuples, where each tuple is an instance of a methylated feature in the read. The tuple contains: feature start, feature length, feature classification, and HMM probability.
+    """
+    import numpy as np
+
+    sequence = list(read)
+    # Get the predicted states using the MAP algorithm
+    predicted_states = model.predict(sequence, algorithm='map')
+
+    # Get the probabilities for each state using the forward-backward algorithm
+    probabilities = model.predict_proba(sequence)
+
+    # Initialize lists to store the classifications and their probabilities
+    classifications = []
+    current_start = None
+    current_length = 0
+    current_probs = []
+
+    for i, state_index in enumerate(predicted_states):
+        state_name = model.states[state_index].name
+        state_prob = probabilities[i][state_index]
+
+        if state_name == "Methylated":
+            if current_start is None:
+                current_start = i
+            current_length += 1
+            current_probs.append(state_prob)
+        else:
+            if current_start is not None:
+                avg_prob = np.mean(current_probs)
+                classifications.append((current_start, current_length, "Methylated", avg_prob))
+                current_start = None
+                current_length = 0
+                current_probs = []
+
+    if current_start is not None:
+        avg_prob = np.mean(current_probs)
+        classifications.append((current_start, current_length, "Methylated", avg_prob))
+
+    final_classifications = []
+    for start, length, classification, prob in classifications:
+        final_classification = None
+        feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+        for feature_type, size_range in classification_mapping.items():
+            if size_range[0] <= feature_length < size_range[1]:
+                final_classification = feature_type
+                break
+        if not final_classification:
+            final_classification = classification
+
+        final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+    return final_classifications
@@ -0,0 +1,75 @@
+# classify_non_methylated_features
+
+def classify_non_methylated_features(read, model, coordinates, classification_mapping={}):
+    """
+    Classifies non-methylated features (inaccessible features)
+
+    Parameters:
+        read (np.ndarray): An array of binarized SMF data representing a read
+        model (): a trained pomegranate HMM
+        coordinates (list): a list of positional coordinates corresponding to the positions in the read
+        classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+    Returns:
+        final_classifications (list): A list of tuples, where each tuple is an instance of a non-methylated feature in the read. The tuple contains: feature start, feature length, feature classification, and HMM probability.
+    """
+    import numpy as np
+
+    sequence = list(read)
+    # Get the predicted states using the MAP algorithm
+    predicted_states = model.predict(sequence, algorithm='map')
+
+    # Get the probabilities for each state using the forward-backward algorithm
+    probabilities = model.predict_proba(sequence)
+
+    # Initialize lists to store the classifications and their probabilities
+    classifications = []
+    current_start = None
+    current_length = 0
+    current_probs = []
+
+    for i, state_index in enumerate(predicted_states):
+        state_name = model.states[state_index].name
+        state_prob = probabilities[i][state_index]
+
+        if state_name == "Non-Methylated":
+            if current_start is None:
+                current_start = i
+            current_length += 1
+            current_probs.append(state_prob)
+        else:
+            if current_start is not None:
+                avg_prob = np.mean(current_probs)
+                classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+                current_start = None
+                current_length = 0
+                current_probs = []
+
+    if current_start is not None:
+        avg_prob = np.mean(current_probs)
+        classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+
+    final_classifications = []
+    for start, length, classification, prob in classifications:
+        final_classification = None
+        feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+        for feature_type, size_range in classification_mapping.items():
+            if size_range[0] <= feature_length < size_range[1]:
+                final_classification = feature_type
+                break
+        if not final_classification:
+            final_classification = classification
+
+        final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+    return final_classifications
+
+# if feature_length < 80:
+#     final_classification = 'small_bound_stretch'
+# elif 80 <= feature_length <= 200:
+#     final_classification = 'Putative_Nucleosome'
+# elif 200 < feature_length:
+#     final_classification = 'large_bound_stretch'
+# else:
+#     pass
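Both archived classifiers consume the same classification_mapping shape: a dict of feature names mapped to [min_length, max_length) size bins in base pairs. A minimal sketch, with boundaries taken from the commented-out thresholds above and otherwise illustrative:

    import numpy as np

    # Hypothetical size bins; each value is [min_length, max_length) in bp
    classification_mapping = {
        "small_bound_stretch": [0, 80],
        "putative_nucleosome": [80, 200],
        "large_bound_stretch": [200, np.inf],
    }
    # features = classify_non_methylated_features(read, model, coordinates, classification_mapping)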
@@ -0,0 +1,32 @@
+# subset_adata
+
+def subset_adata(adata, obs_columns):
+    """
+    Subsets an AnnData object based on categorical values in specified `.obs` columns.
+
+    Parameters:
+        adata (AnnData): The AnnData object to subset.
+        obs_columns (list of str): List of `.obs` column names to subset by. The order matters.
+
+    Returns:
+        dict: A dictionary where keys are tuples of category values and values are corresponding AnnData subsets.
+    """
+
+    def subset_recursive(adata_subset, columns):
+        if not columns:
+            return {(): adata_subset}
+
+        current_column = columns[0]
+        categories = adata_subset.obs[current_column].cat.categories
+
+        subsets = {}
+        for cat in categories:
+            subset = adata_subset[adata_subset.obs[current_column] == cat]
+            subsets.update(subset_recursive(subset, columns[1:]))
+
+        return subsets
+
+    # Start the recursive subset process
+    subsets_dict = subset_recursive(adata, obs_columns)
+
+    return subsets_dict
@@ -0,0 +1,46 @@
+# subset_adata
+
+def subset_adata(adata, columns, cat_type='obs'):
+    """
+    Subsets an AnnData object based on categorical values in specified .obs or .var columns.
+
+    Parameters:
+        adata (AnnData): The AnnData object to subset.
+        columns (list of str): List of .obs or .var column names to subset by. The order matters.
+        cat_type (str): obs or var. Default is obs
+
+    Returns:
+        dict: A dictionary where keys are tuples of category values and values are corresponding AnnData subsets.
+    """
+
+    def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
+        # Returns when the bottom of the stack is reached
+        if not columns:
+            # If there's only one column, return the key as a single value, not a tuple
+            if len(key_prefix) == 1:
+                return {key_prefix[0]: adata_subset}
+            return {key_prefix: adata_subset}
+
+        current_column = columns[0]
+        subsets = {}
+
+        if 'obs' in cat_type:
+            categories = adata_subset.obs[current_column].cat.categories
+            for cat in categories:
+                subset = adata_subset[adata_subset.obs[current_column] == cat].copy()
+                new_key = key_prefix + (cat,)
+                subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+        elif 'var' in cat_type:
+            categories = adata_subset.var[current_column].cat.categories
+            for cat in categories:
+                subset = adata_subset[:, adata_subset.var[current_column] == cat].copy()
+                new_key = key_prefix + (cat,)
+                subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+        return subsets
+
+    # Start the recursive subset process
+    subsets_dict = subset_recursive(adata, columns, cat_type)
+
+    return subsets_dict
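A minimal usage sketch of the recursive subsetter, assuming categorical .obs columns; "Sample" is an illustrative column name, not one defined by the package:

    # Hypothetical usage: split reads first by reference, then by sample label
    subsets = subset_adata(adata, ["Reference_strand", "Sample"], cat_type="obs")
    for key, sub in subsets.items():
        print(key, sub.shape)   # key is a single category (one column) or a tuple of categories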
@@ -0,0 +1,18 @@
+# calculate_distances
+
+def calculate_distances(intervals, threshold=0.9):
+    """
+    Calculates distance between features in a read.
+    Takes in a list of intervals (start of feature, length of feature, probability)
+    """
+    # Sort intervals by start position
+    intervals = sorted(intervals, key=lambda x: x[0])
+    intervals = [interval for interval in intervals if interval[2] > threshold]
+
+    # Calculate distances
+    distances = []
+    for i in range(len(intervals) - 1):
+        end_current = intervals[i][0] + intervals[i][1]
+        start_next = intervals[i + 1][0]
+        distances.append(start_next - end_current)
+    return distances
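A small worked example of the gap computation, assuming (start, length, probability) intervals as used by the function above:

    intervals = [(100, 50, 0.95), (180, 20, 0.99), (300, 40, 0.5)]
    # The 0.5 call is filtered out; the gap is 180 - (100 + 50) = 30
    print(calculate_distances(intervals, threshold=0.9))   # [30]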
@@ -0,0 +1,62 @@
+def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neighbors=100, overwrite=True, threads=8):
+    import scanpy as sc
+    import numpy as np
+    import os
+    from scipy.sparse import issparse
+
+    os.environ["OMP_NUM_THREADS"] = str(threads)
+
+    # Step 1: Apply var filter
+    if var_filters:
+        subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+        adata_subset = adata[:, subset_mask].copy()
+        print(f"🔹 Subsetting adata: Retained {adata_subset.shape[1]} features based on filters {var_filters}")
+    else:
+        adata_subset = adata.copy()
+        print("🔹 No var filters provided. Using all features.")
+
+    # Step 2: NaN handling inside layer
+    if layer:
+        data = adata_subset.layers[layer]
+        if not issparse(data):
+            if np.isnan(data).any():
+                print("⚠ NaNs detected, filling with 0.5 before PCA + neighbors.")
+                data = np.nan_to_num(data, nan=0.5)
+                adata_subset.layers[layer] = data
+            else:
+                print("✅ No NaNs detected.")
+        else:
+            print("✅ Sparse matrix detected; skipping NaN check (sparse formats typically do not store NaNs).")
+
+    # Step 3: PCA + neighbors + UMAP on subset
+    if "X_umap" not in adata_subset.obsm or overwrite:
+        n_pcs = min(adata_subset.shape[1], n_pcs)
+        print(f"Running PCA with n_pcs={n_pcs}")
+        sc.pp.pca(adata_subset, layer=layer)
+        print('Running neighborhood graph')
+        sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
+        print('Running UMAP')
+        sc.tl.umap(adata_subset)
+
+    # Step 4: Store results in original adata
+    adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
+    adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
+    adata.obsp["distances"] = adata_subset.obsp["distances"]
+    adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
+    adata.uns["neighbors"] = adata_subset.uns["neighbors"]
+
+    # Fix varm["PCs"] shape mismatch
+    pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
+    if var_filters:
+        subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+        pc_matrix[subset_mask, :] = adata_subset.varm["PCs"]
+    else:
+        pc_matrix = adata_subset.varm["PCs"]  # No subsetting case
+
+    adata.varm["PCs"] = pc_matrix
+
+    print(f"✅ Stored: adata.obsm['X_pca'] and adata.obsm['X_umap']")
+
+    return adata
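A minimal call sketch for the UMAP helper above; the var filter column name is an assumption (boolean .var columns in this package appear to be reference-prefixed, so adapt to your object):

    # Hypothetical call: embed reads using only positions flagged by a boolean .var column
    adata = calculate_umap(
        adata,
        layer="nan_half",            # default layer name used by the function
        var_filters=["GpC_site"],    # illustrative boolean .var column(s)
        n_pcs=15,
        knn_neighbors=100,
    )
    # Embeddings are written back to adata.obsm["X_pca"] and adata.obsm["X_umap"]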
@@ -0,0 +1,105 @@
+def call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand', site_types=['GpC_site', 'CpG_site'], save_plot=False, output_dir=None, date_tag=None):
+    """
+    Calls peaks from HMM feature layers and annotates them into the AnnData object.
+
+    Parameters:
+        adata : AnnData object with HMM layers (from apply_hmm)
+        feature_configs : dict
+            min_distance : minimum distance between peaks
+            peak_width : window size around peak centers
+            peak_prominence : required peak prominence
+            peak_threshold : threshold for labeling a read as "present" at a peak
+        site_types : list of var site types to aggregate
+        save_plot : whether to save the plot
+        output_dir : path to save the figure if save_plot=True
+        date_tag : optional tag for filename
+    """
+    import matplotlib.pyplot as plt
+    from scipy.signal import find_peaks
+    import os
+    import numpy as np
+
+    peak_columns = []
+
+    for feature_layer, config in feature_configs.items():
+        min_distance = config.get('min_distance', 200)
+        peak_width = config.get('peak_width', 200)
+        peak_prominence = config.get('peak_prominence', 0.2)
+        peak_threshold = config.get('peak_threshold', 0.8)
+
+        # 1️⃣ Calculate mean intensity profile
+        matrix = adata.layers[feature_layer]
+        means = np.mean(matrix, axis=0)
+        feature_peak_columns = []
+
+        # 2️⃣ Peak calling
+        peak_centers, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
+        adata.uns[f'{feature_layer} peak_centers'] = peak_centers
+
+        # 3️⃣ Plot
+        plt.figure(figsize=(6, 3))
+        plt.plot(range(len(means)), means)
+        plt.title(f"{feature_layer} density with peak calls")
+        plt.xlabel("Genomic position")
+        plt.ylabel("Mean feature density")
+        y = max(means) / 2
+        for i, center in enumerate(peak_centers):
+            plus_minus_width = peak_width // 2
+            start = center - plus_minus_width
+            end = center + plus_minus_width
+            plt.axvspan(start, end, color='purple', alpha=0.2)
+            plt.axvline(center, color='red', linestyle='--')
+            if i % 2:
+                aligned = [end, 'left']
+            else:
+                aligned = [start, 'right']
+            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
+
+        if save_plot and output_dir:
+            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
+            plt.savefig(filename, bbox_inches='tight')
+            print(f"Saved plot to {filename}")
+        else:
+            plt.show()
+
+        # 4️⃣ Annotate peaks back into adata.obs
+        for center in peak_centers:
+            half_width = peak_width // 2
+            start, end = center - half_width, center + half_width
+            colname = f'{feature_layer}_peak_{center}'
+            peak_columns.append(colname)
+            feature_peak_columns.append(colname)
+
+            adata.var[colname] = (
+                (adata.var_names.astype(int) >= start) &
+                (adata.var_names.astype(int) <= end)
+            )
+
+            # Feature layer intensity around peak
+            mean_values = np.mean(matrix[:, start:end+1], axis=1)
+            sum_values = np.sum(matrix[:, start:end+1], axis=1)
+            adata.obs[f'mean_{feature_layer}_around_{center}'] = mean_values
+            adata.obs[f'sum_{feature_layer}_around_{center}'] = sum_values
+            adata.obs[f'{feature_layer}_present_at_{center}'] = mean_values > peak_threshold
+
+            # Site-type based aggregation
+            for site_type in site_types:
+                adata.obs[f'{site_type}_sum_around_{center}'] = 0
+                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
+
+            references = adata.obs[obs_column].cat.categories
+            for ref in adata.obs[obs_column].cat.categories:
+                subset = adata[adata.obs[obs_column] == ref]
+                for site_type in site_types:
+                    mask = subset.var.get(f'{ref}_{site_type}', None)
+                    if mask is not None:
+                        region_mask = (subset.var_names[mask].astype(int) >= start) & (subset.var_names[mask].astype(int) <= end)
+                        region = subset[:, mask].X[:, region_mask]
+                        adata.obs.loc[subset.obs.index, f'{site_type}_sum_around_{center}'] = np.nansum(region, axis=1)
+                        adata.obs.loc[subset.obs.index, f'{site_type}_mean_around_{center}'] = np.nanmean(region, axis=1)
+
+        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
+        print(f"✅ Peak annotation completed for {feature_layer} with {len(peak_centers)} peaks.")
+
+    # Combine all peaks into a single "is_in_any_peak" column
+    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
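A minimal configuration sketch for call_hmm_peaks; the layer name is illustrative (any layer written by apply_hmm / apply_hmm_batched), and the numeric values simply restate the function's defaults:

    # Hypothetical configuration: one entry per HMM feature layer to scan for peaks
    feature_configs = {
        "GpC_putative_nucleosome": {   # illustrative layer name
            "min_distance": 200,
            "peak_width": 200,
            "peak_prominence": 0.2,
            "peak_threshold": 0.8,
        },
    }
    call_hmm_peaks(adata, feature_configs, obs_column="Reference_strand", save_plot=False)
    # Peak membership columns appear in adata.var; per-read summaries around each peak appear in adata.obs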