smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +34 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/hmm/apply_hmm_batched.py +242 -0
- smftools/hmm/calculate_distances.py +18 -0
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/hmm/display_hmm.py +18 -0
- smftools/hmm/hmm_readwrite.py +16 -0
- smftools/hmm/nucleosome_hmm_refinement.py +104 -0
- smftools/hmm/train_hmm.py +78 -0
- smftools/informatics/__init__.py +14 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/conversion_smf.py +132 -0
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/direct_smf.py +137 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/fast5_to_pod5.py +24 -0
- smftools/informatics/helpers/__init__.py +73 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +70 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +51 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +32 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/data/preprocessing.py +6 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/__init__.py +9 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/machine_learning/models/positional.py +18 -0
- smftools/machine_learning/models/rnn.py +17 -0
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/models/wrappers.py +20 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +10 -0
- smftools/machine_learning/utils/grl.py +14 -0
- smftools/plotting/__init__.py +18 -0
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +682 -0
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +38 -0
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_coverage.py +51 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +62 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1351 -0
- smftools/preprocessing/invert_adata.py +37 -0
- smftools/preprocessing/load_sample_sheet.py +53 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +1004 -0
- smftools/tools/__init__.py +20 -0
- smftools/tools/archived/apply_hmm.py +202 -0
- smftools/tools/archived/classifiers.py +787 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/position_stats.py +601 -0
- smftools/tools/read_stats.py +184 -0
- smftools/tools/spatial_autocorrelation.py +562 -0
- smftools/tools/subset_adata.py +28 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/hmm/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .apply_hmm_batched import apply_hmm_batched
from .calculate_distances import calculate_distances
from .call_hmm_peaks import call_hmm_peaks
from .display_hmm import display_hmm
from .hmm_readwrite import load_hmm, save_hmm
from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
from .train_hmm import train_hmm


__all__ = [
    "apply_hmm_batched",
    "calculate_distances",
    "call_hmm_peaks",
    "display_hmm",
    "load_hmm",
    "refine_nucleosome_calls",
    "infer_nucleosomes_in_large_bound",
    "save_hmm",
    "train_hmm"
]
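For orientation, here is a minimal end-to-end sketch of the public API exported above. The AnnData file name and contents are hypothetical; the signatures come from the modules in this diff.

    import anndata as ad
    from smftools.hmm import train_hmm, save_hmm, load_hmm, apply_hmm_batched

    # Hypothetical AnnData of binarized SMF calls (reads x positions).
    adata = ad.read_h5ad("smf_experiment.h5ad")

    # Train a 2-state HMM on the binary matrix, persist it, reload it.
    hmm = train_hmm(adata.X, device="cpu")
    save_hmm(hmm, "smf_hmm.pt")
    hmm = load_hmm("smf_hmm.pt")

    # Annotate footprints and accessible patches per read.
    apply_hmm_batched(adata, hmm, obs_column="Reference_strand",
                      footprints=True, accessible_patches=True, methbases=["GpC"])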
smftools/hmm/apply_hmm_batched.py
ADDED
@@ -0,0 +1,242 @@
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

def apply_hmm_batched(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A", "C"], device="cpu", threshold=0.7, deaminase_footprinting=False):
    """
    Applies an HMM model to an AnnData object using tensor-based sequence inputs.
    If multiple methbases are passed, generates a combined feature set.
    """

    model.to(device)

    # --- Feature Definitions ---
    feature_sets = {}
    if footprints:
        feature_sets["footprint"] = {
            "features": {
                "small_bound_stretch": [0, 20],
                "medium_bound_stretch": [20, 50],
                "putative_nucleosome": [50, 200],
                "large_bound_stretch": [200, np.inf]
            },
            "state": "Non-Methylated"
        }
    if accessible_patches:
        feature_sets["accessible"] = {
            "features": {
                "small_accessible_patch": [0, 20],
                "mid_accessible_patch": [20, 80],
                "large_accessible_patch": [80, np.inf]
            },
            "state": "Methylated"
        }
    if cpg:
        feature_sets["cpg"] = {
            "features": {
                "cpg_patch": [0, np.inf]
            },
            "state": "Methylated"
        }

    # --- Init columns ---
    all_features = []
    combined_prefix = "Combined"
    for key, fs in feature_sets.items():
        if key == 'cpg':
            all_features += [f"CpG_{f}" for f in fs["features"]]
            all_features.append(f"CpG_all_{key}_features")
        else:
            for methbase in methbases:
                all_features += [f"{methbase}_{f}" for f in fs["features"]]
                all_features.append(f"{methbase}_all_{key}_features")
            if len(methbases) > 1:
                all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
                all_features.append(f"{combined_prefix}_all_{key}_features")

    for feature in all_features:
        adata.obs[feature] = [[] for _ in range(adata.shape[0])]
        adata.obs[f"{feature}_distances"] = [None] * adata.shape[0]
        adata.obs[f"n_{feature}"] = -1

    # --- Main loop ---
    references = adata.obs[obs_column].cat.categories

    for ref in tqdm(references, desc="Processing References"):
        ref_subset = adata[adata.obs[obs_column] == ref]

        # Combined methbase mask
        combined_mask = None
        for methbase in methbases:
            # Look up the site mask lazily so only the var columns for the
            # requested methbases need to exist in adata.var.
            methbase_key = methbase.lower()
            if methbase_key == "a":
                mask = ref_subset.var[f"{ref}_strand_FASTA_base"] == "A"
            elif methbase_key == "c":
                mask = ref_subset.var[f"{ref}_any_C_site"] == True
            elif methbase_key == "gpc":
                mask = ref_subset.var[f"{ref}_GpC_site"] == True
            elif methbase_key == "cpg":
                mask = ref_subset.var[f"{ref}_CpG_site"] == True
            else:
                raise ValueError(f"Unknown methbase: {methbase}")
            combined_mask = mask if combined_mask is None else combined_mask | mask

            methbase_subset = ref_subset[:, mask]
            matrix = methbase_subset.layers[layer] if layer else methbase_subset.X

            # Replace NaNs (unmeasured positions) with random binary draws.
            processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
            tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)

            coords = methbase_subset.var_names
            # The HMM decode does not depend on the feature set, so run it once per batch.
            pred_states = model.predict(tensor_batch)
            probs = model.predict_proba(tensor_batch)
            for key, fs in feature_sets.items():
                if key == 'cpg':
                    continue
                state_target = fs["state"]
                feature_map = fs["features"]
                classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)

                for i, idx in enumerate(methbase_subset.obs.index):
                    for start, length, label, prob in classifications[i]:
                        adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
                        adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])

        # Combined subset
        if len(methbases) > 1:
            combined_subset = ref_subset[:, combined_mask]
            combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
            processed_combined_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in combined_matrix]
            tensor_combined_batch = torch.tensor(processed_combined_reads, dtype=torch.long, device=device).unsqueeze(-1)

            coords = combined_subset.var_names
            pred_states = model.predict(tensor_combined_batch)
            probs = model.predict_proba(tensor_combined_batch)
            for key, fs in feature_sets.items():
                if key == 'cpg':
                    continue
                state_target = fs["state"]
                feature_map = fs["features"]
                classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)

                for i, idx in enumerate(combined_subset.obs.index):
                    for start, length, label, prob in classifications[i]:
                        adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
                        adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])

    # --- Special handling for CpG ---
    if cpg:
        for ref in tqdm(references, desc="Processing CpG"):
            ref_subset = adata[adata.obs[obs_column] == ref]
            mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
            cpg_subset = ref_subset[:, mask]
            matrix = cpg_subset.layers[layer] if layer else cpg_subset.X

            processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
            tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)

            coords = cpg_subset.var_names
            fs = feature_sets['cpg']
            state_target = fs["state"]
            feature_map = fs["features"]

            pred_states = model.predict(tensor_batch)
            probs = model.predict_proba(tensor_batch)
            classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)

            for i, idx in enumerate(cpg_subset.obs.index):
                for start, length, label, prob in classifications[i]:
                    adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
                    adata.obs.at[idx, "CpG_all_cpg_features"].append([start, length, prob])

    # --- Binarization + Distance ---
    coordinates = adata.var_names.astype(int).values

    for feature in tqdm(all_features, desc="Finalizing Layers"):
        bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
        counts = np.zeros(adata.shape[0], dtype=int)
        for row_idx, intervals in enumerate(adata.obs[feature]):
            if not isinstance(intervals, list):
                intervals = []
            for start, length, prob in intervals:
                if prob > threshold:
                    start_idx = np.searchsorted(coordinates, start, side="left")
                    end_idx = np.searchsorted(coordinates, start + length - 1, side="right")
                    bin_matrix[row_idx, start_idx:end_idx] = 1
                    counts[row_idx] += 1
        adata.layers[feature] = bin_matrix
        adata.obs[f"n_{feature}"] = counts
        adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)

def calculate_batch_distances(intervals_list, threshold=0.9):
    """
    Vectorized calculation of distances across multiple reads.

    Parameters:
        intervals_list (list of list): Outer list = reads, inner list = intervals [start, length, prob]
        threshold (float): Minimum probability threshold for filtering

    Returns:
        List of distance lists per read.
    """
    results = []
    for intervals in intervals_list:
        if not isinstance(intervals, list) or len(intervals) == 0:
            results.append([])
            continue
        valid = [iv for iv in intervals if iv[2] > threshold]
        valid = sorted(valid, key=lambda x: x[0])
        dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
        results.append(dists)
    return results


def classify_batch(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Methylated"):
    """
    Classify batch sequences efficiently.

    Parameters:
        predicted_states_batch: Tensor [batch_size, seq_len]
        probabilities_batch: Tensor [batch_size, seq_len, n_states]
        coordinates: list of genomic coordinates
        classification_mapping: dict of feature bins
        target_state: state name ("Methylated" or "Non-Methylated")

    Returns:
        List of classifications for each sequence.
    """

    state_labels = ["Non-Methylated", "Methylated"]
    target_idx = state_labels.index(target_state)
    batch_size = predicted_states_batch.shape[0]

    all_classifications = []

    for b in range(batch_size):
        predicted_states = predicted_states_batch[b].cpu().numpy()
        probabilities = probabilities_batch[b].cpu().numpy()

        regions = []
        current_start, current_length, current_probs = None, 0, []

        for i, state_index in enumerate(predicted_states):
            state_prob = probabilities[i][state_index]
            if state_index == target_idx:
                if current_start is None:
                    current_start = i
                current_length += 1
                current_probs.append(state_prob)
            elif current_start is not None:
                regions.append((current_start, current_length, np.mean(current_probs)))
                current_start, current_length, current_probs = None, 0, []

        if current_start is not None:
            regions.append((current_start, current_length, np.mean(current_probs)))

        final = []
        for start, length, prob in regions:
            # Convert the index run into a genomic span, then bin by span length.
            feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
            label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
            final.append((int(coordinates[start]) + 1, feature_length, label, prob))
        all_classifications.append(final)

    return all_classifications
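Once `apply_hmm_batched` has run, each feature is recorded three ways: interval lists in `adata.obs`, a binarized position matrix in `adata.layers`, and per-read counts and gaps. A consumption sketch, assuming a `GpC` methbase and the `putative_nucleosome` footprint bin defined above:

    # Each obs cell holds [start, length, mean_state_probability] triples.
    for start, length, prob in adata.obs["GpC_putative_nucleosome"].iloc[0]:
        print(f"footprint at {start}, {length} bp, p={prob:.2f}")

    binary_calls = adata.layers["GpC_putative_nucleosome"]   # 1 inside confident calls
    n_calls = adata.obs["n_GpC_putative_nucleosome"]         # calls per read
    gaps = adata.obs["GpC_putative_nucleosome_distances"]    # bp between adjacent calls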
smftools/hmm/calculate_distances.py
ADDED
@@ -0,0 +1,18 @@
# calculate_distances

def calculate_distances(intervals, threshold=0.9):
    """
    Calculates distance between features in a read.
    Takes in a list of intervals [start of feature, length of feature, probability] and returns the gaps between consecutive features that pass the probability threshold.
    """
    # Sort intervals by start position
    intervals = sorted(intervals, key=lambda x: x[0])
    intervals = [interval for interval in intervals if interval[2] > threshold]

    # Calculate distances
    distances = []
    for i in range(len(intervals) - 1):
        end_current = intervals[i][0] + intervals[i][1]
        start_next = intervals[i + 1][0]
        distances.append(start_next - end_current)
    return distances
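A worked example of the helper above (interval format `[start, length, probability]`):

    intervals = [[100, 20, 0.95], [150, 30, 0.99], [400, 10, 0.50]]
    calculate_distances(intervals, threshold=0.9)
    # -> [30]: the first feature ends at 120 and the next starts at 150;
    #    the third interval is dropped because 0.50 <= 0.9.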
smftools/hmm/call_hmm_peaks.py
ADDED
@@ -0,0 +1,106 @@
def call_hmm_peaks(
    adata,
    feature_configs,
    obs_column='Reference_strand',
    site_types=['GpC_site', 'CpG_site'],
    save_plot=False,
    output_dir=None,
    date_tag=None,
    inplace=False
):
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from scipy.signal import find_peaks

    if not inplace:
        adata = adata.copy()

    # Ensure obs_column is categorical
    if not isinstance(adata.obs[obs_column].dtype, pd.CategoricalDtype):
        adata.obs[obs_column] = pd.Categorical(adata.obs[obs_column])

    coordinates = adata.var_names.astype(int).values
    peak_columns = []

    obs_updates = {}

    for feature_layer, config in feature_configs.items():
        min_distance = config.get('min_distance', 200)
        peak_width = config.get('peak_width', 200)
        peak_prominence = config.get('peak_prominence', 0.2)
        peak_threshold = config.get('peak_threshold', 0.8)

        matrix = adata.layers[feature_layer]
        means = np.mean(matrix, axis=0)
        peak_indices, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
        peak_centers = coordinates[peak_indices]
        adata.uns[f'{feature_layer} peak_centers'] = peak_centers.tolist()

        # Plot
        plt.figure(figsize=(6, 3))
        plt.plot(coordinates, means)
        plt.title(f"{feature_layer} with peak calls")
        plt.xlabel("Genomic position")
        plt.ylabel("Mean intensity")
        for i, center in enumerate(peak_centers):
            start, end = center - peak_width // 2, center + peak_width // 2
            plt.axvspan(start, end, color='purple', alpha=0.2)
            plt.axvline(center, color='red', linestyle='--')
            # Alternate label anchors so adjacent peak labels do not overlap
            aligned = [end if i % 2 else start, 'left' if i % 2 else 'right']
            plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
        if save_plot and output_dir:
            filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
            plt.savefig(filename, bbox_inches='tight')
            print(f"Saved plot to {filename}")
        else:
            plt.show()

        feature_peak_columns = []
        for center in peak_centers:
            start, end = center - peak_width // 2, center + peak_width // 2
            colname = f'{feature_layer}_peak_{center}'
            peak_columns.append(colname)
            feature_peak_columns.append(colname)

            peak_mask = (coordinates >= start) & (coordinates <= end)
            adata.var[colname] = peak_mask

            region = matrix[:, peak_mask]
            obs_updates[f'mean_{feature_layer}_around_{center}'] = np.mean(region, axis=1)
            obs_updates[f'sum_{feature_layer}_around_{center}'] = np.sum(region, axis=1)
            obs_updates[f'{feature_layer}_present_at_{center}'] = np.mean(region, axis=1) > peak_threshold

            for site_type in site_types:
                adata.obs[f'{site_type}_sum_around_{center}'] = 0
                adata.obs[f'{site_type}_mean_around_{center}'] = np.nan

            for ref in adata.obs[obs_column].cat.categories:
                ref_idx = adata.obs[obs_column] == ref
                for site_type in site_types:
                    # Per-(reference, site type) mask column in adata.var
                    mask_key = f"{ref}_{site_type}"
                    if mask_key not in adata.var:
                        continue
                    site_mask = adata.var[mask_key].values
                    site_coords = coordinates[site_mask]
                    region_mask = (site_coords >= start) & (site_coords <= end)
                    if not region_mask.any():
                        continue
                    # Restrict the site mask to positions inside the peak window
                    full_mask = site_mask.copy()
                    full_mask[site_mask] = region_mask
                    site_region = adata[ref_idx, full_mask].X
                    if hasattr(site_region, "A"):
                        site_region = site_region.A
                    if site_region.shape[1] > 0:
                        adata.obs.loc[ref_idx, f'{site_type}_sum_around_{center}'] = np.nansum(site_region, axis=1)
                        adata.obs.loc[ref_idx, f'{site_type}_mean_around_{center}'] = np.nanmean(site_region, axis=1)

        adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
        print(f"Annotated {len(peak_centers)} peaks for {feature_layer}")

    adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
    adata.obs = pd.concat([adata.obs, pd.DataFrame(obs_updates, index=adata.obs.index)], axis=1)

    return adata if not inplace else None
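A sketch of the `feature_configs` mapping this function expects: keys are layer names written by `apply_hmm_batched`, values override the per-layer defaults read via `config.get(...)` above. The layer name here is an assumption.

    feature_configs = {
        "GpC_putative_nucleosome": {
            "min_distance": 150,     # minimum spacing between peak calls (positions)
            "peak_width": 147,       # window annotated around each peak center (bp)
            "peak_prominence": 0.1,  # passed to scipy.signal.find_peaks
            "peak_threshold": 0.8,   # per-read mean needed for '..._present_at_...'
        },
    }
    peaked = call_hmm_peaks(adata, feature_configs, obs_column="Reference_strand")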
smftools/hmm/display_hmm.py
ADDED
@@ -0,0 +1,18 @@
def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
    import torch
    print("\n**HMM Model Overview**")
    print(hmm)

    print("\n**Transition Matrix**")
    # pomegranate stores edges as log-probabilities; exponentiate for display
    transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
    for i, row in enumerate(transition_matrix):
        label = state_labels[i] if state_labels else f"State {i}"
        formatted_row = ", ".join(f"{p:.6f}" for p in row)
        print(f"{label}: [{formatted_row}]")

    print("\n**Emission Probabilities**")
    for i, dist in enumerate(hmm.distributions):
        label = state_labels[i] if state_labels else f"State {i}"
        # Flatten the (1, n_observations) parameter matrix of the single-feature Categorical
        probs = dist.probs.detach().cpu().numpy().ravel()
        formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
        print(f"{label}: {formatted_emissions}")
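A one-line usage sketch (the model path is hypothetical; `load_hmm` is defined in the next file):

    display_hmm(load_hmm("smf_hmm.pt"), state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"])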
smftools/hmm/hmm_readwrite.py
ADDED
@@ -0,0 +1,16 @@
def load_hmm(model_path, device='cpu'):
    """
    Reads in a pretrained HMM.

    Parameters:
        model_path (str): Path to a pretrained HMM
    """
    import torch
    # Load model using PyTorch, remapping stored tensors onto the requested device
    hmm = torch.load(model_path, map_location=device)
    hmm.to(device)
    return hmm

def save_hmm(model, model_path):
    import torch
    torch.save(model, model_path)
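Note that these helpers serialize the whole model object via pickle, so loading generally requires the same pomegranate version that produced the file. A minimal roundtrip sketch:

    save_hmm(hmm, "smf_hmm.pt")
    hmm_cpu = load_hmm("smf_hmm.pt", device="cpu")  # works even if trained on GPU/MPS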
smftools/hmm/nucleosome_hmm_refinement.py
ADDED
@@ -0,0 +1,104 @@
def refine_nucleosome_calls(adata, layer_name, nan_mask_layer, hexamer_size=120, octamer_size=147, max_wiggle=40, device="cpu"):
    import numpy as np

    nucleosome_layer = adata.layers[layer_name]
    nan_mask = adata.layers[nan_mask_layer]  # should be binary mask: 1 = nan region, 0 = valid data

    hexamer_layer = np.zeros_like(nucleosome_layer)
    octamer_layer = np.zeros_like(nucleosome_layer)

    for read_idx, row in enumerate(nucleosome_layer):
        in_patch = False
        start_idx = None

        for pos in range(len(row)):
            if row[pos] == 1 and not in_patch:
                in_patch = True
                start_idx = pos
            if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
                in_patch = False
                end_idx = pos if row[pos] == 0 else pos + 1

                # Expand boundaries into NaNs
                left_expand = 0
                right_expand = 0

                # Left
                for i in range(1, max_wiggle + 1):
                    if start_idx - i >= 0 and nan_mask[read_idx, start_idx - i] == 1:
                        left_expand += 1
                    else:
                        break
                # Right
                for i in range(1, max_wiggle + 1):
                    if end_idx + i < nucleosome_layer.shape[1] and nan_mask[read_idx, end_idx + i] == 1:
                        right_expand += 1
                    else:
                        break

                expanded_start = start_idx - left_expand
                expanded_end = end_idx + right_expand

                available_size = expanded_end - expanded_start

                # Octamer placement
                if available_size >= octamer_size:
                    center = (expanded_start + expanded_end) // 2
                    half_oct = octamer_size // 2
                    octamer_layer[read_idx, center - half_oct: center - half_oct + octamer_size] = 1

                # Hexamer placement
                elif available_size >= hexamer_size:
                    center = (expanded_start + expanded_end) // 2
                    half_hex = hexamer_size // 2
                    hexamer_layer[read_idx, center - half_hex: center - half_hex + hexamer_size] = 1

    adata.layers[f"{layer_name}_hexamers"] = hexamer_layer
    adata.layers[f"{layer_name}_octamers"] = octamer_layer

    print(f"Added layers: {layer_name}_hexamers and {layer_name}_octamers")
    return adata

def infer_nucleosomes_in_large_bound(adata, large_bound_layer, combined_nuc_layer, nan_mask_layer, nuc_size=147, linker_size=50, exclusion_buffer=30, device="cpu"):
    import numpy as np

    large_bound = adata.layers[large_bound_layer]
    existing_nucs = adata.layers[combined_nuc_layer]
    nan_mask = adata.layers[nan_mask_layer]

    inferred_layer = np.zeros_like(large_bound)

    for read_idx, row in enumerate(large_bound):
        in_patch = False
        start_idx = None

        for pos in range(len(row)):
            if row[pos] == 1 and not in_patch:
                in_patch = True
                start_idx = pos
            if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
                in_patch = False
                end_idx = pos if row[pos] == 0 else pos + 1

                # Adjust boundaries into flanking NaN regions without getting too close to existing nucleosomes
                left_expand = start_idx
                while left_expand > 0 and nan_mask[read_idx, left_expand - 1] == 1 and np.sum(existing_nucs[read_idx, max(0, left_expand - exclusion_buffer):left_expand]) == 0:
                    left_expand -= 1

                right_expand = end_idx
                while right_expand < row.shape[0] and nan_mask[read_idx, right_expand] == 1 and np.sum(existing_nucs[read_idx, right_expand:min(row.shape[0], right_expand + exclusion_buffer)]) == 0:
                    right_expand += 1

                # Phase nucleosomes with linker spacing only
                region = (left_expand, right_expand)
                pos_cursor = region[0]
                while pos_cursor + nuc_size <= region[1]:
                    # Clamp the left edge so a negative index cannot wrap to the end of the read
                    window_start = max(0, pos_cursor - exclusion_buffer)
                    if np.all(existing_nucs[read_idx, window_start:pos_cursor + nuc_size + exclusion_buffer] == 0):
                        inferred_layer[read_idx, pos_cursor:pos_cursor + nuc_size] = 1
                        pos_cursor += nuc_size + linker_size
                    else:
                        pos_cursor += 1

    adata.layers[f"{large_bound_layer}_phased_nucleosomes"] = inferred_layer
    print(f"Added layer: {large_bound_layer}_phased_nucleosomes")
    return adata
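A sketch of chaining the two refinement passes. The layer names follow the `{methbase}_{feature}` convention from `apply_hmm_batched` but are assumptions here, as is a `nan_mask` layer marking unmeasured positions with 1:

    # Snap footprint calls to hexamer/octamer widths, expanding into flanking
    # NaN positions by up to max_wiggle.
    adata = refine_nucleosome_calls(adata, layer_name="GpC_putative_nucleosome",
                                    nan_mask_layer="nan_mask", max_wiggle=40)

    # Tile additional phased nucleosomes through long protected stretches,
    # keeping linker_size spacing and an exclusion_buffer from existing calls.
    adata = infer_nucleosomes_in_large_bound(
        adata,
        large_bound_layer="GpC_large_bound_stretch",
        combined_nuc_layer="GpC_putative_nucleosome_octamers",
        nan_mask_layer="nan_mask",
    )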
smftools/hmm/train_hmm.py
ADDED
@@ -0,0 +1,78 @@
def train_hmm(
    data,
    emission_probs=[[0.8, 0.2], [0.2, 0.8]],
    transitions=[[0.9, 0.1], [0.1, 0.9]],
    start_probs=[0.5, 0.5],
    end_probs=[0.5, 0.5],
    device=None,
    max_iter=50,
    verbose=True,
    tol=50,
    pad_value=0,
):
    """
    Trains a 2-state DenseHMM model on binary methylation/deamination data.

    Parameters:
        data (list or np.ndarray): List of sequences (lists) with 0, 1, or NaN.
        emission_probs (list): List of emission probabilities for two states.
        transitions (list): Transition matrix between states.
        start_probs (list): Initial state probabilities.
        end_probs (list): End state probabilities.
        device (str or torch.device): "cpu", "mps", "cuda", or None (auto).
        max_iter (int): Maximum EM iterations.
        verbose (bool): Verbose output from pomegranate.
        tol (float): Convergence tolerance.
        pad_value (int): Value used to pad shorter sequences.

    Returns:
        hmm: Trained DenseHMM model
    """
    import torch
    from pomegranate.hmm import DenseHMM
    from pomegranate.distributions import Categorical
    import numpy as np
    from tqdm import tqdm

    # Auto device detection
    if device is None:
        device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    elif isinstance(device, str):
        device = torch.device(device)
    print(f"Using device: {device}")

    # Ensure emission probs on correct device
    dists = [
        Categorical(torch.tensor([p], device=device))
        for p in emission_probs
    ]

    # Create DenseHMM
    hmm = DenseHMM(
        distributions=dists,
        edges=transitions,
        starts=start_probs,
        ends=end_probs,
        verbose=verbose,
        max_iter=max_iter,
        tol=tol,
    ).to(device)

    # Convert data to list if needed
    if isinstance(data, np.ndarray):
        data = data.tolist()

    # Preprocess data (replace NaNs + pad)
    max_length = max(len(seq) for seq in data)
    processed_data = []
    for sequence in tqdm(data, desc="Preprocessing Sequences"):
        cleaned_seq = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in sequence]
        cleaned_seq += [pad_value] * (max_length - len(cleaned_seq))
        processed_data.append(cleaned_seq)

    tensor_data = torch.tensor(processed_data, dtype=torch.long, device=device).unsqueeze(-1)

    # Fit HMM
    hmm.fit(tensor_data)

    return hmm
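A self-contained training sketch on synthetic binary reads (only torch and pomegranate are required; the data is made up):

    import numpy as np

    rng = np.random.default_rng(0)
    # 50 reads: accessible (mostly 1) flanks around a protected (mostly 0) core.
    reads = [np.concatenate([rng.binomial(1, 0.8, 40),
                             rng.binomial(1, 0.2, 60),
                             rng.binomial(1, 0.8, 40)]).astype(float)
             for _ in range(50)]

    hmm = train_hmm(reads, device="cpu", max_iter=25, verbose=False)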
smftools/informatics/__init__.py
ADDED
@@ -0,0 +1,14 @@
from . import helpers
from .basecall_pod5s import basecall_pod5s
from .subsample_fasta_from_bed import subsample_fasta_from_bed
from .subsample_pod5 import subsample_pod5
from .fast5_to_pod5 import fast5_to_pod5


__all__ = [
    "basecall_pod5s",
    "subsample_fasta_from_bed",
    "subsample_pod5",
    "fast5_to_pod5",
    "helpers"
]