smftools 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +29 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/informatics/__init__.py +16 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +132 -0
- smftools/informatics/direct_smf.py +137 -0
- smftools/informatics/fast5_to_pod5.py +21 -0
- smftools/informatics/helpers/LoadExperimentConfig.py +75 -0
- smftools/informatics/helpers/__init__.py +74 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +59 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +74 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +79 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +55 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +44 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +50 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +884 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +53 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +36 -0
- smftools/informatics/load_adata.py +182 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +33 -0
- smftools/preprocessing/append_C_context.py +82 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +94 -0
- smftools/preprocessing/calculate_coverage.py +42 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/clean_NaN.py +46 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +44 -0
- smftools/preprocessing/filter_reads_on_length.py +51 -0
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +30 -0
- smftools/preprocessing/load_sample_sheet.py +38 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +198 -0
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/evaluation/__init__.py +0 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +28 -0
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/METADATA +5 -2
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/readwrite.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
## readwrite ##
|
|
2
|
+
|
|
3
|
+
######################################################################################################
|
|
4
|
+
## Datetime functionality
|
|
5
|
+
def date_string():
    """
    Return today's local date as a six-character ``YYMMDD`` string.

    The value is recomputed from the local clock on every call.
    """
    from datetime import datetime

    full_date = datetime.now().strftime("%Y%m%d")
    # Drop the two century digits, e.g. "20240315" -> "240315".
    return full_date[2:]
|
|
14
|
+
|
|
15
|
+
def time_string():
    """
    Return the current local wall-clock time formatted as ``HH:MM:SS``.

    The value is recomputed from the local clock on every call.
    """
    from datetime import datetime

    time_format = "%H:%M:%S"
    now = datetime.now()
    return now.strftime(time_format)
|
|
22
|
+
######################################################################################################
|
|
23
|
+
|
|
24
|
+
######################################################################################################
|
|
25
|
+
## Numpy, Pandas, Anndata functionality
|
|
26
|
+
|
|
27
|
+
def adata_to_df(adata, layer=None):
    """
    Convert an AnnData object into a pandas DataFrame.

    Only ``adata.X``, ``adata.layers``, ``adata.obs.index`` and
    ``adata.var.index`` are accessed.

    Parameters:
        adata (AnnData): The input AnnData object.
        layer (str, optional): The layer to extract. If None (or empty),
            ``adata.X`` is used.

    Returns:
        pd.DataFrame: A DataFrame where rows are observations (indexed by
        ``adata.obs.index``) and columns are positions (labelled by
        ``adata.var.index``). Sparse matrices are densified.

    Raises:
        ValueError: If ``layer`` is not in ``adata.layers``, or if the
            observation or variable index contains duplicates.
    """
    import pandas as pd

    # Validate that the requested layer exists
    if layer and layer not in adata.layers:
        raise ValueError(f"Layer '{layer}' not found in adata.layers.")

    # Select the matrix explicitly: this avoids relying on Mapping.get(None, ...)
    # and does not touch adata.X when a layer was requested.
    data_matrix = adata.layers[layer] if layer else adata.X

    # Ensure matrix is dense (handle sparse formats)
    if hasattr(data_matrix, "toarray"):
        data_matrix = data_matrix.toarray()

    # Ensure obs and var have unique indices
    if adata.obs.index.duplicated().any():
        raise ValueError("Duplicate values found in `adata.obs.index`. Ensure unique observation indices.")

    if adata.var.index.duplicated().any():
        raise ValueError("Duplicate values found in `adata.var.index`. Ensure unique variable indices.")

    return pd.DataFrame(data_matrix, index=adata.obs.index, columns=adata.var.index)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def save_matrix(matrix, save_name):
    """
    Write a numeric matrix to a whitespace-delimited text file.

    Parameters:
        matrix: Array-like accepted by ``numpy.savetxt``.
        save_name: Output path without extension; ``.txt`` is appended.
    """
    import numpy as np

    output_path = "{}.txt".format(save_name)
    np.savetxt(output_path, matrix)
|
|
73
|
+
|
|
74
|
+
def concatenate_h5ads(output_file, file_suffix='h5ad.gz', delete_inputs=True):
    """
    Concatenate all h5ad files in the current working directory into one file.

    Every regular file in ``os.getcwd()`` whose name ends with ``file_suffix``
    (excluding ``output_file`` itself, so a previous run's output is never
    re-ingested) is read in sorted name order. The files are concatenated
    along observations with an outer join on variables, and the result is
    written gzip-compressed to ``output_file``.

    Parameters:
        output_file (str): Output path, relative to the directory in which the
            function is called (or absolute).
        file_suffix (str): Filename suffix that identifies input files.
        delete_inputs (bool): If True, delete exactly the input files that were
            concatenated, after the output has been written.

    Raises:
        FileNotFoundError: If no input files ending in ``file_suffix`` exist.
    """
    import os
    import warnings

    cwd = os.getcwd()
    output_path = os.path.abspath(output_file)

    # Collect matching input files, skipping directories and the output file.
    hdfs = sorted(
        name
        for name in os.listdir(cwd)
        if name.endswith(file_suffix)
        and os.path.isfile(os.path.join(cwd, name))
        and os.path.abspath(os.path.join(cwd, name)) != output_path
    )
    print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
    if not hdfs:
        raise FileNotFoundError(f"No files ending in '{file_suffix}' found in {cwd}")

    # Imported after discovery so an empty directory fails fast and cheaply.
    import anndata as ad

    final_adata = None
    with warnings.catch_warnings():
        # Silence anndata runtime warnings for this operation only, without
        # altering the caller's global warning filters.
        warnings.filterwarnings('ignore', category=UserWarning, module='anndata')
        warnings.filterwarnings('ignore', category=FutureWarning, module='anndata')

        # Iterate over all of the hdf5 files and concatenate them.
        for hdf in hdfs:
            print('{0}: Reading in {1} hdf5 file'.format(time_string(), hdf))
            temp_adata = ad.read_h5ad(hdf)
            # Compare with None explicitly: an AnnData with zero observations is falsy.
            if final_adata is not None:
                print('{0}: Concatenating final adata object with {1} hdf5 file'.format(time_string(), hdf))
                final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
            else:
                print('{0}: Initializing final adata object with {1} hdf5 file'.format(time_string(), hdf))
                final_adata = temp_adata

        print('{0}: Writing final concatenated hdf5 file'.format(time_string()))
        final_adata.write_h5ad(output_file, compression='gzip')

    # Delete only the files that were actually concatenated, never anything
    # that merely happens to match the suffix now.
    if delete_inputs:
        for hdf in hdfs:
            try:
                os.remove(hdf)
                print(f"Deleted file: {hdf}")
            except OSError as e:
                print(f"Error deleting file {hdf}: {e}")
    else:
        print('Keeping input files')
|
|
125
|
+
|
|
126
|
+
def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir="./"):
    """
    Save an AnnData object safely by omitting problematic columns from .obs and .var.

    An object-dtype column is kept only when every value is a ``str`` or
    ``None``. Any other object-dtype column is left out of the written file;
    lists, dicts, or a float NaN mixed among strings all trigger this. The
    input ``adata`` is not modified.

    Parameters:
        adata (AnnData): The AnnData object to save.
        path (str): Output .h5ad file path.
        compression (str): Compression method for h5ad file.
        backup (bool): If True, saves problematic columns to CSV files.
        backup_dir (str): Directory to store backups if backup=True. Only
            created when ``backup`` is True.
    """
    import os

    import anndata as ad

    if backup:
        os.makedirs(backup_dir, exist_ok=True)

    def filter_df(df, df_name):
        # Object columns are kept only if they hold nothing but strings/None.
        bad_cols = [
            col
            for col in df.columns
            if df[col].dtype == 'object'
            and not df[col].apply(lambda x: isinstance(x, (str, type(None)))).all()
        ]
        if bad_cols:
            print(f"⚠️ Skipping columns from {df_name}: {bad_cols}")
            if backup:
                df[bad_cols].to_csv(os.path.join(backup_dir, f"{df_name}_skipped_columns.csv"))
                print(f"📝 Backed up skipped columns to {backup_dir}/{df_name}_skipped_columns.csv")
        return df.drop(columns=bad_cols)

    # Clean obs and var
    obs_clean = filter_df(adata.obs, "obs")
    var_clean = filter_df(adata.var, "var")

    # Build a new AnnData that shares every other element with `adata`.
    # obsp/varp are passed through so pairwise annotations are written too.
    # NOTE(review): adata.raw is not carried over — confirm whether it should be.
    adata_copy = ad.AnnData(
        X=adata.X,
        obs=obs_clean,
        var=var_clean,
        layers=adata.layers,
        uns=adata.uns,
        obsm=adata.obsm,
        varm=adata.varm,
        obsp=adata.obsp,
        varp=adata.varp,
    )
    adata_copy.write_h5ad(path, compression=compression)
    print(f"✅ Saved safely to {path}")
|
|
172
|
+
|
|
173
|
+
def merge_barcoded_anndatas(adata_single, adata_double):
    """
    Merge single- and double-barcode AnnData objects, preferring double-barcode reads.

    Observations whose obs_names occur in both inputs are taken only from
    ``adata_double``. Every observation in the result is tagged in
    ``obs['source']`` with ``'single_barcode'`` or ``'double_barcode'``.
    Neither input object is modified.

    Parameters:
        adata_single (AnnData): Reads assigned from a single barcode.
        adata_double (AnnData): Reads assigned from both barcodes.

    Returns:
        AnnData: The non-overlapping single-barcode observations followed by
        all double-barcode observations (outer join on variables). ``.uns``
        is the union of both inputs' ``.uns``; ``adata_double`` wins on key
        conflicts.
    """
    import anndata as ad

    # Step 1: drop single-barcode reads that were also called as double-barcode.
    keep_single = ~adata_single.obs_names.isin(adata_double.obs_names)
    adata_single_filtered = adata_single[keep_single]

    # Step 2: concatenate. `label`/`keys` add the per-observation source tag
    # without writing into the caller's objects. merge='same' keeps
    # non-concatenated-axis annotations (var/varm/varp) that are identical
    # across inputs; layers and obsm follow `join`.
    adata_merged = ad.concat(
        [adata_single_filtered, adata_double],
        join='outer',
        merge='same',
        label='source',
        keys=['single_barcode', 'double_barcode'],
    )
    # concat stores the label as categorical; keep the plain string column.
    adata_merged.obs['source'] = adata_merged.obs['source'].astype(str)

    # Step 3: merge `.uns`
    adata_merged.uns = {**adata_single.uns, **adata_double.uns}

    return adata_merged
|
|
197
|
+
|
|
198
|
+
######################################################################################################
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from .apply_hmm import apply_hmm
|
|
2
|
+
from .apply_hmm_batched import apply_hmm_batched
|
|
3
|
+
from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistic
|
|
4
|
+
from .calculate_distances import calculate_distances
|
|
5
|
+
from .calculate_umap import calculate_umap
|
|
6
|
+
from .call_hmm_peaks import call_hmm_peaks
|
|
7
|
+
from .classifiers import run_training_loop, run_inference, evaluate_models_by_subgroup, prepare_melted_model_data, sliding_window_train_test
|
|
8
|
+
from .cluster_adata_on_methylation import cluster_adata_on_methylation
|
|
9
|
+
from .display_hmm import display_hmm
|
|
10
|
+
from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
|
|
11
|
+
from .hmm_readwrite import load_hmm, save_hmm
|
|
12
|
+
from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
|
|
13
|
+
from .read_stats import calculate_row_entropy
|
|
14
|
+
from .subset_adata import subset_adata
|
|
15
|
+
from .train_hmm import train_hmm
|
|
16
|
+
|
|
17
|
+
from . import models
|
|
18
|
+
from . import data
|
|
19
|
+
from . import utils
|
|
20
|
+
from . import evaluation
|
|
21
|
+
from . import inference
|
|
22
|
+
from . import training
|
|
23
|
+
|
|
24
|
+
# Public API of smftools.tools; every entry must name an object imported above
# (a missing comma would silently concatenate two adjacent entries).
__all__ = [
    "apply_hmm",
    "apply_hmm_batched",
    "calculate_distances",
    "compute_positionwise_statistic",
    "calculate_row_entropy",
    "calculate_umap",
    "calculate_relative_risk_on_activity",
    "call_hmm_peaks",
    "cluster_adata_on_methylation",
    "create_nan_mask_from_X",
    "create_nan_or_non_gpc_mask",
    "combine_layers",
    "display_hmm",
    "evaluate_models_by_subgroup",
    "load_hmm",
    "prepare_melted_model_data",
    "refine_nucleosome_calls",
    "infer_nucleosomes_in_large_bound",
    "run_training_loop",
    "run_inference",
    "save_hmm",
    "sliding_window_train_test",
    "subset_adata",
    "train_hmm",
]
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import torch
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
def apply_hmm(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=("GpC", "CpG", "A"), device="cpu", threshold=0.7):
    """
    Applies an HMM model to an AnnData object using tensor-based sequence inputs.

    Reads are processed one at a time. For each category (reference) of
    ``adata.obs[obs_column]``, reads are restricted to the site columns of each
    requested methbase. Each read is classified with ``classify_features``, and
    the resulting intervals are appended to per-read lists in ``adata.obs``. A
    ``Combined`` feature set over the union of all methbase sites is always
    produced as well.

    Parameters:
        adata (AnnData): Reads x positions; modified in place.
        model: HMM exposing ``to``, ``predict`` and ``predict_proba``.
        obs_column (str): Categorical obs column whose categories are references.
            ``adata.var`` must contain ``{ref}_GpC_site`` / ``{ref}_CpG_site`` /
            ``{ref}_strand_FASTA_base`` for the methbases that are requested.
        layer (str, optional): Layer to read; ``adata.X`` when None.
        footprints (bool): Call "Non-Methylated" footprint features.
        accessible_patches (bool): Call "Methylated" accessible-patch features.
        cpg (bool): Additionally call ``cpg_patch`` features on CpG sites only.
        methbases (sequence of str): Any of "GpC", "CpG", "A" (case-insensitive).
        device (str): Torch device for the model and input tensors.
        threshold (float): Minimum interval probability for binarization,
            counting and distance calculation.

    Side effects:
        For every feature name ``f`` this sets ``adata.obs[f]`` (lists of
        ``[start, length, prob]``), ``adata.obs[f"{f}_distances"]``,
        ``adata.obs[f"n_{f}"]`` and ``adata.layers[f]`` (0/1 matrix).

    NaN entries are replaced by a random 0/1 draw from numpy's global RNG
    before prediction, so results vary unless that RNG is seeded.
    """
    model.to(device)

    # --- Feature Definitions ---
    # label -> [min_len, max_len) in coordinate units, plus the HMM state to call on.
    feature_sets = {}
    if footprints:
        feature_sets["footprint"] = {
            "features": {
                "small_bound_stretch": [0, 30],
                "medium_bound_stretch": [30, 80],
                "putative_nucleosome": [80, 200],
                "large_bound_stretch": [200, np.inf],
            },
            "state": "Non-Methylated",
        }
    if accessible_patches:
        feature_sets["accessible"] = {
            "features": {
                "small_accessible_patch": [0, 30],
                "mid_accessible_patch": [30, 80],
                "large_accessible_patch": [80, np.inf],
            },
            "state": "Methylated",
        }
    if cpg:
        feature_sets["cpg"] = {
            "features": {
                "cpg_patch": [0, np.inf],
            },
            "state": "Methylated",
        }

    # CpG patches get their own pass over CpG sites only (see below).
    site_feature_sets = {key: fs for key, fs in feature_sets.items() if key != 'cpg'}

    # --- Init columns ---
    all_features = []
    combined_prefix = "Combined"
    for key, fs in feature_sets.items():
        if key == 'cpg':
            all_features += [f"CpG_{f}" for f in fs["features"]]
            all_features.append(f"CpG_all_{key}_features")
        else:
            for methbase in methbases:
                all_features += [f"{methbase}_{f}" for f in fs["features"]]
                all_features.append(f"{methbase}_all_{key}_features")
            all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
            all_features.append(f"{combined_prefix}_all_{key}_features")

    n_obs = adata.shape[0]
    for feature in all_features:
        adata.obs[feature] = pd.Series([[] for _ in range(n_obs)], dtype=object, index=adata.obs.index)
        # A plain list assigns positionally; an unindexed Series would align on a
        # RangeIndex and become all-NaN for string obs names.
        adata.obs[f"{feature}_distances"] = [None] * n_obs
        adata.obs[f"n_{feature}"] = -1

    # --- Main loop ---
    references = adata.obs[obs_column].cat.categories

    for ref in tqdm(references, desc="Processing References"):
        ref_subset = adata[adata.obs[obs_column] == ref]

        # Per-methbase calls, accumulating the union of their sites.
        combined_mask = None
        for methbase in methbases:
            mask = _site_mask(ref_subset.var, ref, methbase)
            combined_mask = mask if combined_mask is None else combined_mask | mask
            _call_features(adata, ref_subset[:, mask], layer, model, site_feature_sets, methbase, device)

        # Combined methbase subset
        if combined_mask is not None:
            _call_features(adata, ref_subset[:, combined_mask], layer, model, site_feature_sets, combined_prefix, device)

    # --- Special handling for CpG ---
    if cpg:
        cpg_feature_sets = {"cpg": feature_sets["cpg"]}
        for ref in tqdm(references, desc="Processing CpG"):
            ref_subset = adata[adata.obs[obs_column] == ref]
            mask = _site_mask(ref_subset.var, ref, "CpG")
            _call_features(adata, ref_subset[:, mask], layer, model, cpg_feature_sets, "CpG", device)

    # --- Binarization + Distance ---
    for feature in tqdm(all_features, desc="Finalizing Layers"):
        bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
        counts = np.zeros(adata.shape[0], dtype=int)
        for row_idx, intervals in enumerate(adata.obs[feature]):
            if not isinstance(intervals, list):
                intervals = []
            for start, length, prob in intervals:
                if prob > threshold:
                    # NOTE(review): `start` is int(coordinate) + 1 from classify_features
                    # and is used as a column index here — assumes var_names are
                    # 0..n_vars-1 positions; confirm.
                    bin_matrix[row_idx, start:start + length] = 1
                    counts[row_idx] += 1
        adata.layers[feature] = bin_matrix
        adata.obs[f"n_{feature}"] = counts
        adata.obs[f"{feature}_distances"] = adata.obs[feature].apply(lambda x: calculate_distances(x, threshold))


def _site_mask(var, ref, methbase):
    """Return the boolean mask over ``var`` rows for the site type named by ``methbase``.

    Only the column needed for ``methbase`` is read, so the other site columns
    may be absent. Unknown methbases raise ``KeyError`` (as a dict lookup would).
    """
    key = methbase.lower()
    if key == "a":
        return var[f"{ref}_strand_FASTA_base"] == "A"
    if key == "gpc":
        return var[f"{ref}_GpC_site"] == True  # noqa: E712 - column may be object dtype
    if key == "cpg":
        return var[f"{ref}_CpG_site"] == True  # noqa: E712 - column may be object dtype
    raise KeyError(key)


def _read_to_tensor(raw_read, device):
    """Return one read as a ``(1, n_sites, 1)`` long tensor; NaN sites get a random 0/1."""
    read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
    return torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)


def _call_features(adata, subset, layer, model, feature_sets, prefix, device):
    """Classify each read of ``subset`` and append calls to ``adata.obs`` columns named with ``prefix``."""
    if not feature_sets or subset.shape[0] == 0 or subset.shape[1] == 0:
        # Nothing to call: no feature sets, no reads, or no matching sites.
        return
    matrix = subset.layers[layer] if layer else subset.X
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    coords = subset.var_names
    obs_index = subset.obs.index

    for i, raw_read in enumerate(matrix):
        tensor_read = _read_to_tensor(raw_read, device)
        idx = obs_index[i]
        for key, fs in feature_sets.items():
            classifications = classify_features(tensor_read, model, coords, fs["features"], target_state=fs["state"])
            for start, length, label, prob in classifications:
                adata.obs.at[idx, f"{prefix}_{label}"].append([start, length, prob])
                adata.obs.at[idx, f"{prefix}_all_{key}_features"].append([start, length, prob])
|
|
155
|
+
|
|
156
|
+
def calculate_distances(intervals, threshold=0.9):
    """
    Calculates distances between consecutive features in a read.

    Parameters:
        intervals (list): ``[start, length, prob]`` entries. A value that is not
            a list or tuple (e.g. None or NaN from an unfilled obs cell) is
            treated as having no intervals.
        threshold (float): Only intervals with ``prob > threshold`` are used.

    Returns:
        list: For each pair of neighbouring retained intervals (sorted by
        start), ``next_start - (start + length)``.
    """
    if not isinstance(intervals, (list, tuple)):
        return []
    kept = sorted((iv for iv in intervals if iv[2] > threshold), key=lambda iv: iv[0])
    return [nxt[0] - (cur[0] + cur[1]) for cur, nxt in zip(kept, kept[1:])]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def classify_features(sequence, model, coordinates, classification_mapping=None, target_state="Methylated"):
    """
    Classifies regions based on HMM state.

    Maximal runs of ``target_state`` in the model's predicted state path are
    sized in coordinate units and labelled via ``classification_mapping``.

    Parameters:
        sequence (torch.Tensor): Binarized read. A leading batch axis of size 1
            is squeezed, so presumably ``[1, seq_len, 1]``.
        model: Trained HMM whose ``predict`` returns state indices and whose
            ``predict_proba`` returns per-state probabilities, both as tensors.
        coordinates (sequence): Coordinate of each position; each must be
            convertible with ``int()``.
        classification_mapping (dict, optional): ``label -> [min_len, max_len)``.
            A run matching no range is labelled with ``target_state``.
        target_state (str): "Methylated" (state index 1) or "Non-Methylated"
            (state index 0).

    Returns:
        list of tuple: ``(start, feature_length, label, mean_prob)``. Here
        ``start`` is ``int(coordinates[run_start]) + 1``, ``feature_length``
        spans the run's first to last coordinate inclusive, and ``mean_prob``
        is the mean probability of the predicted state over the run.
    """
    if classification_mapping is None:
        classification_mapping = {}

    predicted_states = model.predict(sequence).squeeze(-1).squeeze(0).cpu().numpy()
    probabilities = model.predict_proba(sequence).squeeze(0).cpu().numpy()
    state_labels = ["Non-Methylated", "Methylated"]

    # Collect (start_index, run_length, mean_prob) for each maximal run of target_state.
    runs = []
    current_start, current_probs = None, []
    for i, state_index in enumerate(predicted_states):
        if state_labels[state_index] == target_state:
            if current_start is None:
                current_start = i
            current_probs.append(probabilities[i][state_index])
        elif current_start is not None:
            runs.append((current_start, len(current_probs), np.mean(current_probs)))
            current_start, current_probs = None, []

    if current_start is not None:
        runs.append((current_start, len(current_probs), np.mean(current_probs)))

    final = []
    for start, length, prob in runs:
        feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
        label = next(
            (ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]),
            target_state,
        )
        final.append((int(coordinates[start]) + 1, feature_length, label, prob))
    return final
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import torch
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
def apply_hmm_batched(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=("GpC", "CpG", "A"), device="cpu", threshold=0.7):
    """
    Applies an HMM model to an AnnData object using tensor-based sequence inputs.

    This is the batched variant: for each category (reference) of
    ``adata.obs[obs_column]`` and each requested methbase, all reads are sent
    through the model as one ``(n_reads, n_sites, 1)`` tensor. Each resulting
    batch is classified with ``classify_batch``. When more than one methbase is
    given, a ``Combined`` feature set over the union of their sites is also
    produced.

    Parameters:
        adata (AnnData): Reads x positions; modified in place.
        model: HMM exposing ``to``, ``predict`` and ``predict_proba``.
        obs_column (str): Categorical obs column whose categories are references.
            ``adata.var`` must contain ``{ref}_GpC_site`` / ``{ref}_CpG_site`` /
            ``{ref}_strand_FASTA_base`` for the methbases that are requested.
        layer (str, optional): Layer to read; ``adata.X`` when None.
        footprints (bool): Call "Non-Methylated" footprint features.
        accessible_patches (bool): Call "Methylated" accessible-patch features.
        cpg (bool): Additionally call ``cpg_patch`` features on CpG sites only.
        methbases (sequence of str): Any of "GpC", "CpG", "A" (case-insensitive).
        device (str): Torch device for the model and input tensors.
        threshold (float): Minimum interval probability for binarization,
            counting and distance calculation.

    Side effects:
        For every feature name ``f`` this sets ``adata.obs[f]`` (lists of
        ``[start, length, prob]``), ``adata.obs[f"{f}_distances"]``,
        ``adata.obs[f"n_{f}"]`` and ``adata.layers[f]`` (0/1 matrix).

    NaN entries are replaced by a random 0/1 draw from numpy's global RNG
    before prediction, so results vary unless that RNG is seeded.
    """
    model.to(device)

    # --- Feature Definitions ---
    # label -> [min_len, max_len) in coordinate units, plus the HMM state to call on.
    feature_sets = {}
    if footprints:
        feature_sets["footprint"] = {
            "features": {
                "small_bound_stretch": [0, 20],
                "medium_bound_stretch": [20, 50],
                "putative_nucleosome": [50, 200],
                "large_bound_stretch": [200, np.inf],
            },
            "state": "Non-Methylated",
        }
    if accessible_patches:
        feature_sets["accessible"] = {
            "features": {
                "small_accessible_patch": [0, 20],
                "mid_accessible_patch": [20, 80],
                "large_accessible_patch": [80, np.inf],
            },
            "state": "Methylated",
        }
    if cpg:
        feature_sets["cpg"] = {
            "features": {
                "cpg_patch": [0, np.inf],
            },
            "state": "Methylated",
        }

    # CpG patches get their own pass over CpG sites only (see below).
    site_feature_sets = {key: fs for key, fs in feature_sets.items() if key != 'cpg'}

    # --- Init columns ---
    all_features = []
    combined_prefix = "Combined"
    for key, fs in feature_sets.items():
        if key == 'cpg':
            all_features += [f"CpG_{f}" for f in fs["features"]]
            all_features.append(f"CpG_all_{key}_features")
        else:
            for methbase in methbases:
                all_features += [f"{methbase}_{f}" for f in fs["features"]]
                all_features.append(f"{methbase}_all_{key}_features")
            if len(methbases) > 1:
                all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
                all_features.append(f"{combined_prefix}_all_{key}_features")

    n_obs = adata.shape[0]
    for feature in all_features:
        # Explicit object Series so each cell holds its own (mutable) list.
        adata.obs[feature] = pd.Series([[] for _ in range(n_obs)], dtype=object, index=adata.obs.index)
        adata.obs[f"{feature}_distances"] = [None] * n_obs
        adata.obs[f"n_{feature}"] = -1

    # --- Main loop ---
    references = adata.obs[obs_column].cat.categories

    for ref in tqdm(references, desc="Processing References"):
        ref_subset = adata[adata.obs[obs_column] == ref]

        # Per-methbase calls, accumulating the union of their sites.
        combined_mask = None
        for methbase in methbases:
            mask = _site_mask(ref_subset.var, ref, methbase)
            combined_mask = mask if combined_mask is None else combined_mask | mask
            _call_features_batched(adata, ref_subset[:, mask], layer, model, site_feature_sets, methbase, device)

        # Combined subset
        if len(methbases) > 1:
            _call_features_batched(adata, ref_subset[:, combined_mask], layer, model, site_feature_sets, combined_prefix, device)

    # --- Special handling for CpG ---
    if cpg:
        cpg_feature_sets = {"cpg": feature_sets["cpg"]}
        for ref in tqdm(references, desc="Processing CpG"):
            ref_subset = adata[adata.obs[obs_column] == ref]
            mask = _site_mask(ref_subset.var, ref, "CpG")
            _call_features_batched(adata, ref_subset[:, mask], layer, model, cpg_feature_sets, "CpG", device)

    # --- Binarization + Distance ---
    for feature in tqdm(all_features, desc="Finalizing Layers"):
        bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
        counts = np.zeros(adata.shape[0], dtype=int)
        for row_idx, intervals in enumerate(adata.obs[feature]):
            if not isinstance(intervals, list):
                intervals = []
            for start, length, prob in intervals:
                if prob > threshold:
                    # NOTE(review): `start` comes from classify_batch and is used as a
                    # column index — assumes it is a 0-based position; confirm.
                    bin_matrix[row_idx, start:start + length] = 1
                    counts[row_idx] += 1
        adata.layers[feature] = bin_matrix
        adata.obs[f"n_{feature}"] = counts
        adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)


def _site_mask(var, ref, methbase):
    """Return the boolean mask over ``var`` rows for the site type named by ``methbase``.

    Only the column needed for ``methbase`` is read, so the other site columns
    may be absent. Unknown methbases raise ``KeyError`` (as a dict lookup would).
    """
    key = methbase.lower()
    if key == "a":
        return var[f"{ref}_strand_FASTA_base"] == "A"
    if key == "gpc":
        return var[f"{ref}_GpC_site"] == True  # noqa: E712 - column may be object dtype
    if key == "cpg":
        return var[f"{ref}_CpG_site"] == True  # noqa: E712 - column may be object dtype
    raise KeyError(key)


def _reads_to_batch_tensor(matrix, device):
    """Return reads as a ``(n_reads, n_sites, 1)`` long tensor; NaN sites get a random 0/1."""
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    values = np.asarray(matrix, dtype=float)
    nan_mask = np.isnan(values)
    if nan_mask.any():
        # np.where allocates a new array, so the AnnData matrix is never modified.
        values = np.where(nan_mask, np.random.randint(0, 2, size=values.shape), values)
    # astype truncates toward zero, matching int() on each element.
    return torch.as_tensor(values.astype(np.int64), device=device).unsqueeze(-1)


def _call_features_batched(adata, subset, layer, model, feature_sets, prefix, device):
    """Classify all reads of ``subset`` as one batch and append calls to ``adata.obs`` columns named with ``prefix``."""
    if not feature_sets or subset.shape[0] == 0 or subset.shape[1] == 0:
        # Nothing to call: no feature sets, no reads, or no matching sites.
        return
    matrix = subset.layers[layer] if layer else subset.X
    tensor_batch = _reads_to_batch_tensor(matrix, device)
    coords = subset.var_names

    # Model output does not depend on the feature set, so predict once per batch.
    pred_states = model.predict(tensor_batch)
    probs = model.predict_proba(tensor_batch)

    for key, fs in feature_sets.items():
        classifications = classify_batch(pred_states, probs, coords, fs["features"], target_state=fs["state"])
        for i, idx in enumerate(subset.obs.index):
            for start, length, label, prob in classifications[i]:
                adata.obs.at[idx, f"{prefix}_{label}"].append([start, length, prob])
                adata.obs.at[idx, f"{prefix}_all_{key}_features"].append([start, length, prob])
|
|
166
|
+
|
|
167
|
+
def calculate_batch_distances(intervals_list, threshold=0.9):
    """
    Compute, for each read, the gaps between consecutive confident intervals.

    This is a plain per-read Python loop (it is not vectorized).

    Parameters:
        intervals_list (list): One entry per read. Each entry is expected to be a
            list of ``[start, length, prob]`` intervals. Any entry that is not a
            ``list`` (e.g. ``None`` or a NaN placeholder) is treated as a read
            with no intervals.
        threshold (float): An interval is kept only if ``prob > threshold``
            (strictly greater).

    Returns:
        list of list: One list per read. For each pair of consecutive kept
        intervals (ordered by start), the list holds
        ``next_start - (start + length)``. Overlapping intervals therefore give
        negative values. Reads with fewer than two kept intervals yield ``[]``.
    """
    results = []
    for intervals in intervals_list:
        if not isinstance(intervals, list):
            results.append([])
            continue
        # Drop low-confidence calls, then order by start so gaps are measured left to right.
        kept = sorted((iv for iv in intervals if iv[2] > threshold), key=lambda iv: iv[0])
        # Gap = start of the next interval minus the exclusive end of the current one.
        results.append([nxt[0] - (cur[0] + cur[1]) for cur, nxt in zip(kept, kept[1:])])
    return results
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def classify_batch(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Methylated"):
    """
    Classify contiguous runs of a target state in a batch of decoded sequences.

    Parameters:
        predicted_states_batch: Torch tensor or array-like of shape [batch_size, seq_len]
            holding state indices (0 = "Non-Methylated", 1 = "Methylated").
        probabilities_batch: Torch tensor or array-like of shape [batch_size, seq_len, n_states].
        coordinates: Indexable of length >= seq_len whose items are convertible with ``int()``
            (e.g. AnnData ``var_names`` strings). Position ``i`` of a sequence maps to
            ``int(coordinates[i])``.
        classification_mapping: dict of ``label -> (min_len, max_len)``. A run gets the first
            label (in dict order) for which ``min_len <= feature_length < max_len``.
        target_state: "Methylated" or "Non-Methylated"; the state whose runs are reported.
            It is also used as the label when no bin matches.

    Returns:
        List (one entry per sequence) of lists of tuples
        ``(int(coordinates[start]) + 1, feature_length, label, mean_prob)``, where
        ``feature_length = int(coordinates[end]) - int(coordinates[start]) + 1`` over the run
        and ``mean_prob`` is the mean probability of the target state across the run.

    Raises:
        ValueError: If ``target_state`` is not one of the two known state labels.
    """
    state_labels = ["Non-Methylated", "Methylated"]
    if target_state not in state_labels:
        raise ValueError(f"target_state must be one of {state_labels}, got {target_state!r}")
    target_idx = state_labels.index(target_state)

    # Convert once for the whole batch (one device transfer instead of one per row).
    states_batch = _to_host_numpy(predicted_states_batch)
    probs_batch = _to_host_numpy(probabilities_batch)
    feature_bins = list(classification_mapping.items())

    all_classifications = []
    for b, predicted_states in enumerate(states_batch):
        probabilities = probs_batch[b]
        final = []
        for start, length, prob in _target_state_runs(predicted_states, probabilities, target_idx):
            first_coord = int(coordinates[start])
            feature_length = int(coordinates[start + length - 1]) - first_coord + 1
            label = next(
                (ftype for ftype, rng in feature_bins if rng[0] <= feature_length < rng[1]),
                target_state,
            )
            # NOTE(review): the reported start is shifted by +1 relative to coordinates[start].
            # The binarization loop in this file slices layer columns with this value
            # (start:start+length); confirm the +1 offset is intended for that use.
            final.append((first_coord + 1, feature_length, label, prob))
        all_classifications.append(final)

    return all_classifications


def _to_host_numpy(values):
    """Return ``values`` as a NumPy array; torch tensors are detached and copied to CPU first."""
    import numpy as np

    if hasattr(values, "detach"):
        # torch.Tensor: ``.numpy()`` refuses tensors that require grad or live on a GPU.
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _target_state_runs(predicted_states, probabilities, target_idx):
    """Return ``(start, length, mean_prob)`` for each maximal run of ``target_idx`` in one sequence.

    ``start`` and ``length`` are sequence positions; ``mean_prob`` is the mean of
    ``probabilities[i][target_idx]`` over the run.
    """
    import numpy as np

    runs = []
    run_start, run_probs = None, []
    for i, state_index in enumerate(predicted_states):
        if state_index == target_idx:
            if run_start is None:
                run_start = i
            run_probs.append(probabilities[i][state_index])
        elif run_start is not None:
            runs.append((run_start, len(run_probs), np.mean(run_probs)))
            run_start, run_probs = None, []
    # Close a run that extends to the end of the sequence.
    if run_start is not None:
        runs.append((run_start, len(run_probs), np.mean(run_probs)))
    return runs
|