smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/recipes.py
CHANGED

@@ -1,6 +1,6 @@
 # recipes
 
-def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column = 'Reference', sample_names_col='Sample_names', invert=
+def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column = 'Reference', sample_names_col='Sample_names', invert=True):
     """
     The first part of the preprocessing workflow applied to the smf.inform.pod_to_adata() output derived from Kissiov_and_McKenna_2025.
 
@@ -9,9 +9,10 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
         2) Appends a boolean indicating whether each position in var_names is within a given reference.
         3) Appends the cytosine context to each position from each reference.
         4) Calculate read level methylation statistics.
-        5)
-        6)
-        7)
+        5) Calculates read length statistics (start position, end position, read length).
+        6) Optionally inverts the adata to flip the position coordinate orientation.
+        7) Adds new layers containing NaN replaced variants of adata.X (fill_closest, nan0_0minus1, nan1_12).
+        8) Returns a dictionary to pass the variable namespace to the parent scope.
 
     Parameters:
         adata (AnnData): The AnnData object to use as input.
@@ -34,6 +35,7 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
     from .calculate_converted_read_methylation_stats import calculate_converted_read_methylation_stats
     from .invert_adata import invert_adata
     from .calculate_read_length_stats import calculate_read_length_stats
+    from .clean_NaN import clean_NaN
 
     # Clean up some of the Reference metadata and save variable names that point to sets of values in the column.
     adata.obs[reference_column] = adata.obs[reference_column].astype('category')
@@ -42,7 +44,7 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
     reference_mapping = {k: v for k, v in split_references}
     adata.obs[f'{reference_column}_short'] = adata.obs[reference_column].map(reference_mapping)
     short_references = set(adata.obs[f'{reference_column}_short'])
-    binary_layers = adata.layers.keys()
+    binary_layers = list(adata.layers.keys())
 
     # load sample sheet metadata
     load_sample_sheet(adata, sample_sheet_path, mapping_key_column)
@@ -59,16 +61,19 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
     append_C_context(adata, obs_column=reference_column, use_consensus=False)
 
     # Calculate read level methylation statistics. Assess if GpC methylation level is above other_C methylation level as a QC.
-    calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col
+    calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col)
+
+    # Calculate read length statistics
+    upper_bound, lower_bound = calculate_read_length_stats(adata, reference_column, sample_names_col)
 
     # Invert the adata object (ie flip the strand orientation for visualization)
     if invert:
-        invert_adata(adata)
+        adata = invert_adata(adata)
     else:
        pass
 
-    #
-
+    # NaN replacement strategies stored in additional layers. Having layer=None uses adata.X
+    clean_NaN(adata, layer=None)
 
     variables = {
         "short_references": short_references,
@@ -80,22 +85,21 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
     }
     return variables
 
-def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers,
+def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, distance_thresholds={}, reference_column = 'Reference', sample_names_col='Sample_names'):
     """
     The second part of the preprocessing workflow applied to the adata that has already been preprocessed by recipe_1_Kissiov_and_McKenna_2025.
 
     Performs the following tasks:
-        1)
-        2)
-        3)
-        4)
-        5) Returns two adata object: one for the filtered adata and one for the duplicate adata.
+        1) Marks putative PCR duplicates using pairwise hamming distance metrics.
+        2) Performs a complexity analysis of the library based on the PCR duplicate detection rate.
+        3) Removes PCR duplicates from the adata.
+        4) Returns two adata object: one for the filtered adata and one for the duplicate adata.
 
     Parameters:
         adata (AnnData): The AnnData object to use as input.
         output_directory (str): String representing the path to the output directory for plots.
         binary_layers (list): A list of layers to used for the binary encoding of read sequences. Used for duplicate detection.
-
+        distance_thresholds (dict): A dictionary keyed by obs_column categories that points to a float corresponding to the distance threshold to apply. Default is an empty dict.
         reference_column (str): The name of the reference column to use.
         sample_names_col (str): The name of the sample name column to use.
 
@@ -106,16 +110,14 @@ def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, ha
     import anndata as ad
     import pandas as pd
     import numpy as np
-    from .clean_NaN import clean_NaN
     from .mark_duplicates import mark_duplicates
     from .calculate_complexity import calculate_complexity
     from .remove_duplicates import remove_duplicates
 
-    #
-    clean_NaN(adata, layer=None)
+    # Add here a way to remove reads below a given read quality (based on nan content). Need to also add a way to pull from BAM files the read quality from each read
 
     # Duplicate detection using pairwise hamming distance across reads
-    mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col,
+    mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col, distance_thresholds=distance_thresholds, method='N_masked_distances')
 
     # Complexity analysis using the marked duplicates and the lander-watermann algorithm
     calculate_complexity(adata, output_directory, obs_column=reference_column, sample_col=sample_names_col, plot=True, save_plot=False)
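
For orientation, the two recipes above might be chained roughly as in the sketch below. It is based only on the signatures visible in this diff: the input file, sample sheet path, and output directory are placeholders, and the assumption that recipe_2 returns the filtered and duplicate AnnData objects in that order comes from its docstring.

    # Hypothetical usage sketch; file paths and column names are placeholders.
    import anndata as ad
    from smftools.preprocessing.recipes import (
        recipe_1_Kissiov_and_McKenna_2025,
        recipe_2_Kissiov_and_McKenna_2025,
    )

    adata = ad.read_h5ad("example_smf.h5ad")  # placeholder input

    # Recipe 1: reference cleanup, C-context annotation, read-level stats, optional inversion, NaN layers.
    variables = recipe_1_Kissiov_and_McKenna_2025(
        adata,
        sample_sheet_path="sample_sheet.csv",  # placeholder
        output_directory="qc_outputs",         # placeholder
        invert=True,
    )

    # Recipe 2: duplicate marking, complexity analysis, duplicate removal.
    filtered_adata, duplicate_adata = recipe_2_Kissiov_and_McKenna_2025(
        adata,
        output_directory="qc_outputs",
        binary_layers=list(adata.layers.keys()),
        distance_thresholds={},
    )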
smftools/preprocessing/subsample_adata.py
ADDED

@@ -0,0 +1,58 @@
+def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
+    """
+    Subsamples an AnnData object so that each unique combination of categories
+    in the given `obs_columns` has at most `max_samples` observations.
+    If `obs_columns` is None or empty, the function randomly subsamples the entire dataset.
+
+    Parameters:
+        adata (AnnData): The AnnData object to subsample.
+        obs_columns (list of str, optional): List of observation column names to group by.
+        max_samples (int): The maximum number of observations per category combination.
+        random_seed (int): Random seed for reproducibility.
+
+    Returns:
+        AnnData: A new AnnData object with subsampled observations.
+    """
+    import anndata as ad
+    import numpy as np
+
+    np.random.seed(random_seed) # Ensure reproducibility
+
+    if not obs_columns: # If no obs columns are given, sample globally
+        if adata.shape[0] > max_samples:
+            sampled_indices = np.random.choice(adata.obs.index, max_samples, replace=False)
+        else:
+            sampled_indices = adata.obs.index # Keep all if fewer than max_samples
+
+        return adata[sampled_indices].copy()
+
+    sampled_indices = []
+
+    # Get unique category combinations from all specified obs columns
+    unique_combinations = adata.obs[obs_columns].drop_duplicates()
+
+    for _, row in unique_combinations.iterrows():
+        # Build filter condition dynamically for multiple columns
+        condition = (adata.obs[obs_columns] == row.values).all(axis=1)
+
+        # Get indices for the current category combination
+        subset_indices = adata.obs[condition].index.to_numpy()
+
+        # Subsample or take all
+        if len(subset_indices) > max_samples:
+            sampled = np.random.choice(subset_indices, max_samples, replace=False)
+        else:
+            sampled = subset_indices # Keep all if fewer than max_samples
+
+        sampled_indices.extend(sampled)
+
+    # ⚠ Handle backed mode detection
+    if adata.isbacked:
+        print("⚠ Detected backed mode. Subset will be loaded fully into memory.")
+        subset = adata[sampled_indices]
+        subset = subset.to_memory()
+    else:
+        subset = adata[sampled_indices]
+
+    # Create a new AnnData object with only the selected indices
+    return subset[sampled_indices].copy()
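
A short usage sketch for the new subsample_adata helper (the .h5ad path and grouping columns below are placeholders; "Reference" and "Sample_names" simply mirror the defaults used elsewhere in this package):

    # Hypothetical usage sketch; the input file and obs columns are placeholders.
    import anndata as ad
    from smftools.preprocessing.subsample_adata import subsample_adata

    adata = ad.read_h5ad("example_smf.h5ad")

    # Cap each (Reference, Sample_names) combination at 1000 reads, reproducibly.
    subsampled = subsample_adata(
        adata,
        obs_columns=["Reference", "Sample_names"],
        max_samples=1000,
        random_seed=42,
    )

    # With no grouping columns, the whole object is randomly downsampled instead.
    subsampled_global = subsample_adata(adata, obs_columns=None, max_samples=2000)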
smftools/readwrite.py
CHANGED

@@ -23,27 +23,46 @@ def time_string():
 
 ######################################################################################################
 ## Numpy, Pandas, Anndata functionality
+
 def adata_to_df(adata, layer=None):
     """
-
-
+    Convert an AnnData object into a Pandas DataFrame.
+
+    Parameters:
+        adata (AnnData): The input AnnData object.
+        layer (str, optional): The layer to extract. If None, uses adata.X.
+
+    Returns:
+        pd.DataFrame: A DataFrame where rows are observations and columns are positions.
     """
     import pandas as pd
     import anndata as ad
+    import numpy as np
+
+    # Validate that the requested layer exists
+    if layer and layer not in adata.layers:
+        raise ValueError(f"Layer '{layer}' not found in adata.layers.")
+
+    # Extract the data matrix
+    data_matrix = adata.layers.get(layer, adata.X)
+
+    # Ensure matrix is dense (handle sparse formats)
+    if hasattr(data_matrix, "toarray"):
+        data_matrix = data_matrix.toarray()
+
+    # Ensure obs and var have unique indices
+    if adata.obs.index.duplicated().any():
+        raise ValueError("Duplicate values found in `adata.obs.index`. Ensure unique observation indices.")
+
+    if adata.var.index.duplicated().any():
+        raise ValueError("Duplicate values found in `adata.var.index`. Ensure unique variable indices.")
+
+    # Convert to DataFrame
+    df = pd.DataFrame(data_matrix, index=adata.obs.index, columns=adata.var.index)
 
-    # Extract the data matrix from the given layer
-    if layer:
-        data_matrix = adata.layers[layer]
-    else:
-        data_matrix = adata.X
-    # Extract observation (read) annotations
-    obs_df = adata.obs
-    # Extract variable (position) annotations
-    var_df = adata.var
-    # Convert data matrix and annotations to pandas DataFrames
-    df = pd.DataFrame(data_matrix, index=obs_df.index, columns=var_df.index)
     return df
 
+
 def save_matrix(matrix, save_name):
     """
     Input: A numpy matrix and a save_name
@@ -103,4 +122,77 @@ def concatenate_h5ads(output_file, file_suffix='h5ad.gz', delete_inputs=True):
             print(f"Error deleting file {hdf}: {e}")
     else:
         print('Keeping input files')
+
+def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir="./"):
+    """
+    Saves an AnnData object safely by omitting problematic columns from .obs and .var.
+
+    Parameters:
+        adata (AnnData): The AnnData object to save.
+        path (str): Output .h5ad file path.
+        compression (str): Compression method for h5ad file.
+        backup (bool): If True, saves problematic columns to CSV files.
+        backup_dir (str): Directory to store backups if backup=True.
+    """
+    import anndata as ad
+    import pandas as pd
+    import os
+
+    os.makedirs(backup_dir, exist_ok=True)
+
+    def filter_df(df, df_name):
+        bad_cols = []
+        for col in df.columns:
+            if df[col].dtype == 'object':
+                if not df[col].apply(lambda x: isinstance(x, (str, type(None)))).all():
+                    bad_cols.append(col)
+        if bad_cols:
+            print(f"⚠️ Skipping columns from {df_name}: {bad_cols}")
+            if backup:
+                df[bad_cols].to_csv(os.path.join(backup_dir, f"{df_name}_skipped_columns.csv"))
+                print(f"📝 Backed up skipped columns to {backup_dir}/{df_name}_skipped_columns.csv")
+        return df.drop(columns=bad_cols)
+
+    # Clean obs and var
+    obs_clean = filter_df(adata.obs, "obs")
+    var_clean = filter_df(adata.var, "var")
+
+    # Save clean version
+    adata_copy = ad.AnnData(
+        X=adata.X,
+        obs=obs_clean,
+        var=var_clean,
+        layers=adata.layers,
+        uns=adata.uns,
+        obsm=adata.obsm,
+        varm=adata.varm
+    )
+    adata_copy.write_h5ad(path, compression=compression)
+    print(f"✅ Saved safely to {path}")
+
+def merge_barcoded_anndatas(adata_single, adata_double):
+    import numpy as np
+    import anndata as ad
+
+    # Step 1: Identify overlap
+    overlap = np.intersect1d(adata_single.obs_names, adata_double.obs_names)
+
+    # Step 2: Filter out overlaps from adata_single
+    adata_single_filtered = adata_single[~adata_single.obs_names.isin(overlap)].copy()
+
+    # Step 3: Add source tag
+    adata_single_filtered.obs['source'] = 'single_barcode'
+    adata_double.obs['source'] = 'double_barcode'
+
+    # Step 4: Concatenate all components
+    adata_merged = ad.concat([
+        adata_single_filtered,
+        adata_double
+    ], join='outer', merge='same') # merge='same' preserves matching layers, obsm, etc.
+
+    # Step 5: Merge `.uns`
+    adata_merged.uns = {**adata_single.uns, **adata_double.uns}
+
+    return adata_merged
+
 ######################################################################################################
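
The readwrite additions might be exercised as in the sketch below (paths are placeholders, and the calls simply follow the signatures shown in this diff):

    # Hypothetical usage sketch for adata_to_df and safe_write_h5ad; paths are placeholders.
    import anndata as ad
    from smftools import readwrite

    adata = ad.read_h5ad("example_smf.h5ad")

    # adata_to_df now validates the requested layer, densifies sparse matrices,
    # and raises on duplicate obs/var indices.
    df = readwrite.adata_to_df(adata, layer=None)  # layer=None uses adata.X

    # safe_write_h5ad drops non-string object columns from .obs/.var before writing,
    # optionally backing them up as CSVs.
    readwrite.safe_write_h5ad(
        adata,
        "example_safe.h5ad",
        compression="gzip",
        backup=True,
        backup_dir="./backups",
    )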
smftools/tools/__init__.py
CHANGED

@@ -0,0 +1,49 @@
+from .apply_hmm import apply_hmm
+from .apply_hmm_batched import apply_hmm_batched
+from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistic
+from .calculate_distances import calculate_distances
+from .calculate_umap import calculate_umap
+from .call_hmm_peaks import call_hmm_peaks
+from .classifiers import run_training_loop, run_inference, evaluate_models_by_subgroup, prepare_melted_model_data, sliding_window_train_test
+from .cluster_adata_on_methylation import cluster_adata_on_methylation
+from .display_hmm import display_hmm
+from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
+from .hmm_readwrite import load_hmm, save_hmm
+from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
+from .read_stats import calculate_row_entropy
+from .subset_adata import subset_adata
+from .train_hmm import train_hmm
+
+from . import models
+from . import data
+from . import utils
+from . import evaluation
+from . import inference
+from . import training
+
+__all__ = [
+    "apply_hmm",
+    "apply_hmm_batched",
+    "calculate_distances",
+    "compute_positionwise_statistic",
+    "calculate_row_entropy",
+    "calculate_umap",
+    "calculate_relative_risk_on_activity",
+    "call_hmm_peaks",
+    "cluster_adata_on_methylation",
+    "create_nan_mask_from_X",
+    "create_nan_or_non_gpc_mask",
+    "combine_layers",
+    "display_hmm",
+    "evaluate_models_by_subgroup",
+    "load_hmm",
+    "prepare_melted_model_data",
+    "refine_nucleosome_calls",
+    "infer_nucleosomes_in_large_bound",
+    "run_training_loop",
+    "run_inference",
+    "save_hmm",
+    "sliding_window_train_test"
+    "subset_adata",
+    "train_hmm"
+]
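
With this new package __init__, the names listed in __all__ become importable directly from smftools.tools, for example:

    from smftools.tools import apply_hmm, train_hmm, call_hmm_peaks, calculate_umap
    from smftools.tools import models, data, utils  # submodule namespaces exposed above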
smftools/tools/apply_hmm.py
ADDED

@@ -0,0 +1,202 @@
+import numpy as np
+import pandas as pd
+import torch
+from tqdm import tqdm
+
+def apply_hmm(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A"], device="cpu", threshold=0.7):
+    """
+    Applies an HMM model to an AnnData object using tensor-based sequence inputs.
+    If multiple methbases are passed, generates a combined feature set.
+    """
+    model.to(device)
+
+    # --- Feature Definitions ---
+    feature_sets = {}
+    if footprints:
+        feature_sets["footprint"] = {
+            "features": {
+                "small_bound_stretch": [0, 30],
+                "medium_bound_stretch": [30, 80],
+                "putative_nucleosome": [80, 200],
+                "large_bound_stretch": [200, np.inf]
+            },
+            "state": "Non-Methylated"
+        }
+    if accessible_patches:
+        feature_sets["accessible"] = {
+            "features": {
+                "small_accessible_patch": [0, 30],
+                "mid_accessible_patch": [30, 80],
+                "large_accessible_patch": [80, np.inf]
+            },
+            "state": "Methylated"
+        }
+    if cpg:
+        feature_sets["cpg"] = {
+            "features": {
+                "cpg_patch": [0, np.inf]
+            },
+            "state": "Methylated"
+        }
+
+    # --- Init columns ---
+    all_features = []
+    combined_prefix = "Combined"
+    for key, fs in feature_sets.items():
+        if key == 'cpg':
+            all_features += [f"CpG_{f}" for f in fs["features"]]
+            all_features.append(f"CpG_all_{key}_features")
+        else:
+            for methbase in methbases:
+                all_features += [f"{methbase}_{f}" for f in fs["features"]]
+                all_features.append(f"{methbase}_all_{key}_features")
+            all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+            all_features.append(f"{combined_prefix}_all_{key}_features")
+
+    for feature in all_features:
+        adata.obs[feature] = pd.Series([[] for _ in range(adata.shape[0])], dtype=object, index=adata.obs.index)
+        adata.obs[f"{feature}_distances"] = pd.Series([None] * adata.shape[0])
+        adata.obs[f"n_{feature}"] = -1
+
+    # --- Main loop ---
+    references = adata.obs[obs_column].cat.categories
+
+    for ref in tqdm(references, desc="Processing References"):
+        ref_subset = adata[adata.obs[obs_column] == ref]
+
+        # Create combined mask for methbases
+        combined_mask = None
+        for methbase in methbases:
+            mask = {
+                "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+                "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+                "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+            }[methbase.lower()]
+            combined_mask = mask if combined_mask is None else combined_mask | mask
+
+            methbase_subset = ref_subset[:, mask]
+            matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+            for i, raw_read in enumerate(matrix):
+                read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+                tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+                coords = methbase_subset.var_names
+
+                for key, fs in feature_sets.items():
+                    if key == 'cpg':
+                        continue
+                    state_target = fs["state"]
+                    feature_map = fs["features"]
+
+                    classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+                    idx = methbase_subset.obs.index[i]
+
+                    for start, length, label, prob in classifications:
+                        adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+                        adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+        # Combined methbase subset
+        combined_subset = ref_subset[:, combined_mask]
+        combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+
+        for i, raw_read in enumerate(combined_matrix):
+            read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+            tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+            coords = combined_subset.var_names
+
+            for key, fs in feature_sets.items():
+                if key == 'cpg':
+                    continue
+                state_target = fs["state"]
+                feature_map = fs["features"]
+
+                classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+                idx = combined_subset.obs.index[i]
+
+                for start, length, label, prob in classifications:
+                    adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                    adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+    # --- Special handling for CpG ---
+    if cpg:
+        for ref in tqdm(references, desc="Processing CpG"):
+            ref_subset = adata[adata.obs[obs_column] == ref]
+            mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+            cpg_subset = ref_subset[:, mask]
+            matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+            for i, raw_read in enumerate(matrix):
+                read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+                tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+                coords = cpg_subset.var_names
+                fs = feature_sets['cpg']
+                state_target = fs["state"]
+                feature_map = fs["features"]
+                classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+                idx = cpg_subset.obs.index[i]
+                for start, length, label, prob in classifications:
+                    adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                    adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
+
+    # --- Binarization + Distance ---
+    for feature in tqdm(all_features, desc="Finalizing Layers"):
+        bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+        counts = np.zeros(adata.shape[0], dtype=int)
+        for row_idx, intervals in enumerate(adata.obs[feature]):
+            if not isinstance(intervals, list):
+                intervals = []
+            for start, length, prob in intervals:
+                if prob > threshold:
+                    bin_matrix[row_idx, start:start+length] = 1
+                    counts[row_idx] += 1
+        adata.layers[f"{feature}"] = bin_matrix
+        adata.obs[f"n_{feature}"] = counts
+        adata.obs[f"{feature}_distances"] = adata.obs[feature].apply(lambda x: calculate_distances(x, threshold))
+
+def calculate_distances(intervals, threshold=0.9):
+    """Calculates distances between consecutive features in a read."""
+    intervals = sorted([iv for iv in intervals if iv[2] > threshold], key=lambda x: x[0])
+    distances = [(intervals[i + 1][0] - (intervals[i][0] + intervals[i][1]))
+                 for i in range(len(intervals) - 1)]
+    return distances
+
+
+def classify_features(sequence, model, coordinates, classification_mapping={}, target_state="Methylated"):
+    """
+    Classifies regions based on HMM state.
+
+    Parameters:
+        sequence (torch.Tensor): Tensor of binarized data [batch_size, seq_len, 1]
+        model: Trained pomegranate HMM
+        coordinates (list): Genomic coordinates for sequence
+        classification_mapping (dict): Mapping for feature labeling
+        target_state (str): The state to classify ("Methylated" or "Non-Methylated")
+    """
+    predicted_states = model.predict(sequence).squeeze(-1).squeeze(0).cpu().numpy()
+    probabilities = model.predict_proba(sequence).squeeze(0).cpu().numpy()
+    state_labels = ["Non-Methylated", "Methylated"]
+
+    classifications, current_start, current_length, current_probs = [], None, 0, []
+
+    for i, state_index in enumerate(predicted_states):
+        state_name = state_labels[state_index]
+        state_prob = probabilities[i][state_index]
+
+        if state_name == target_state:
+            if current_start is None:
+                current_start = i
+            current_length += 1
+            current_probs.append(state_prob)
+        elif current_start is not None:
+            classifications.append((current_start, current_length, avg := np.mean(current_probs)))
+            current_start, current_length, current_probs = None, 0, []
+
+    if current_start is not None:
+        classifications.append((current_start, current_length, avg := np.mean(current_probs)))
+
+    final = []
+    for start, length, prob in classifications:
+        feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+        label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+        final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+    return final