smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/cluster_adata_on_methylation.py (new file)
@@ -0,0 +1,105 @@
+# cluster_adata_on_methylation
+
+def cluster_adata_on_methylation(adata, obs_columns, method='hierarchical', n_clusters=3, layer=None, site_types = ['GpC_site', 'CpG_site']):
+    """
+    Adds cluster groups to the adata object as an observation column
+
+    Parameters:
+        adata
+        obs_columns
+        method
+        n_clusters
+        layer
+        site_types
+
+    Returns:
+        None
+    """
+    import pandas as pd
+    import numpy as np
+    from . import subset_adata
+    from ..readwrite import adata_to_df
+
+    # Ensure obs_columns are categorical
+    for col in obs_columns:
+        adata.obs[col] = adata.obs[col].astype('category')
+
+    references = adata.obs['Reference'].cat.categories
+
+    # Add subset metadata to the adata
+    subset_adata(adata, obs_columns)
+
+    subgroup_name = '_'.join(obs_columns)
+    subgroups = adata.obs[subgroup_name].cat.categories
+
+    subgroup_to_reference_map = {}
+    for subgroup in subgroups:
+        for reference in references:
+            if reference in subgroup:
+                subgroup_to_reference_map[subgroup] = reference
+            else:
+                pass
+
+    if method == 'hierarchical':
+        for site_type in site_types:
+            adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
+    elif method == 'kmeans':
+        for site_type in site_types:
+            adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
+
+    for subgroup in subgroups:
+        subgroup_subset = adata[adata.obs[subgroup_name] == subgroup].copy()
+        reference = subgroup_to_reference_map[subgroup]
+        for site_type in site_types:
+            site_subset = subgroup_subset[:, np.array(subgroup_subset.var[f'{reference}_{site_type}'])].copy()
+            df = adata_to_df(site_subset, layer=layer)
+            df2 = df.reset_index(drop=True)
+            if method == 'hierarchical':
+                try:
+                    from scipy.cluster.hierarchy import linkage, dendrogram
+                    # Perform hierarchical clustering on rows using the average linkage method and Euclidean metric
+                    row_linkage = linkage(df2.values, method='average', metric='euclidean')
+
+                    # Generate the dendrogram to get the ordered indices
+                    dendro = dendrogram(row_linkage, no_plot=True)
+                    reordered_row_indices = np.array(dendro['leaves']).astype(int)
+
+                    # Get the reordered observation names
+                    reordered_obs_names = [df.index[i] for i in reordered_row_indices]
+
+                    temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}': np.arange(0, len(reordered_obs_names), 1)}, index=reordered_obs_names, dtype=int)
+                    adata.obs.update(temp_obs_data)
+                except:
+                    print(f'Error found in {subgroup} of {site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}')
+            elif method == 'kmeans':
+                try:
+                    from sklearn.cluster import KMeans
+                    kmeans = KMeans(n_clusters=n_clusters)
+                    kmeans.fit(site_subset.layers[layer])
+                    # Get the cluster labels for each data point
+                    cluster_labels = kmeans.labels_
+                    # Add the kmeans cluster data as an observation to the anndata object
+                    site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = cluster_labels.astype(str)
+                    # Calculate the mean of each observation category of each cluster
+                    cluster_means = site_subset.obs.groupby(f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}').mean()
+                    # Sort the cluster indices by mean methylation value
+                    sorted_clusters = cluster_means.sort_values(by=f'{site_type}_row_methylation_means', ascending=False).index
+                    # Create a mapping of the old cluster values to the new cluster values
+                    sorted_cluster_mapping = {old: new for new, old in enumerate(sorted_clusters)}
+                    # Apply the mapping to create a new observation value: kmeans_labels_reordered
+                    site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].map(sorted_cluster_mapping)
+                    temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}': site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}']}, index=site_subset.obs_names, dtype=int)
+                    adata.obs.update(temp_obs_data)
+                except:
+                    print(f'Error found in {subgroup} of {site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}')
+
+    if method == 'hierarchical':
+        # Ensure that the observation values are type int
+        for site_type in site_types:
+            adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'].astype(int)
+    elif method == 'kmeans':
+        # Ensure that the observation values are type int
+        for site_type in site_types:
+            adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].astype(int)
+
+    return None
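A minimal usage sketch of the new clustering helper: the path, obs/var column names, and layer name below are placeholders for an already-preprocessed AnnData, which this function expects to carry a categorical 'Reference' obs column and per-reference '<ref>_GpC_site'/'<ref>_CpG_site' var columns.

```python
# Hypothetical call sketch; names and paths are placeholders, not package defaults.
import anndata as ad
from smftools.tools.cluster_adata_on_methylation import cluster_adata_on_methylation

adata = ad.read_h5ad("experiment.h5ad")        # placeholder path
cluster_adata_on_methylation(
    adata,
    obs_columns=["Sample", "Reference"],       # joined into a "Sample_Reference" subgroup key
    method="hierarchical",                     # or "kmeans" (uses n_clusters)
    layer="binarized",                         # placeholder layer holding 0/1 methylation calls
    site_types=["GpC_site"],
)
# The per-read ordering is written back to adata.obs, e.g.
# "GpC_site_binarized_hierarchical_clustering_index_within_Sample_Reference".
```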
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import DataLoader, TensorDataset, random_split
|
|
3
|
+
import pytorch_lightning as pl
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
class AnnDataModule(pl.LightningDataModule):
|
|
8
|
+
def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
|
|
9
|
+
batch_size=64, train_frac=0.7, random_seed=42, split_col='train_val_split', split_save_path=None, load_existing_split=False,
|
|
10
|
+
inference_mode=False):
|
|
11
|
+
super().__init__()
|
|
12
|
+
self.adata = adata # The adata object
|
|
13
|
+
self.tensor_source = tensor_source # X, layers, obsm
|
|
14
|
+
self.tensor_key = tensor_key # name of the layer or obsm key
|
|
15
|
+
self.label_col = label_col # name of the label column in obs
|
|
16
|
+
self.batch_size = batch_size
|
|
17
|
+
self.train_frac = train_frac
|
|
18
|
+
self.random_seed = random_seed
|
|
19
|
+
self.split_col = split_col # Name of obs column to store "train"/"val"
|
|
20
|
+
self.split_save_path = split_save_path # Where to save the obs_names and train/test split logging
|
|
21
|
+
self.load_existing_split = load_existing_split # Whether to load from an existing split
|
|
22
|
+
self.inference_mode = inference_mode # Whether to load the AnnDataModule in inference mode.
|
|
23
|
+
|
|
24
|
+
def setup(self, stage=None):
|
|
25
|
+
# Load feature matrix
|
|
26
|
+
if self.tensor_source == "X":
|
|
27
|
+
X = self.adata.X
|
|
28
|
+
elif self.tensor_source == "layers":
|
|
29
|
+
assert self.tensor_key in self.adata.layers, f"Layer '{self.tensor_key}' not found."
|
|
30
|
+
X = self.adata.layers[self.tensor_key]
|
|
31
|
+
elif self.tensor_source == "obsm":
|
|
32
|
+
assert self.tensor_key in self.adata.obsm, f"obsm key '{self.tensor_key}' not found."
|
|
33
|
+
X = self.adata.obsm[self.tensor_key]
|
|
34
|
+
else:
|
|
35
|
+
raise ValueError(f"Invalid tensor_source: {self.tensor_source}")
|
|
36
|
+
|
|
37
|
+
# Convert to tensor
|
|
38
|
+
X_tensor = torch.tensor(X, dtype=torch.float32)
|
|
39
|
+
|
|
40
|
+
if self.inference_mode:
|
|
41
|
+
self.infer_dataset = TensorDataset(X_tensor)
|
|
42
|
+
|
|
43
|
+
else:
|
|
44
|
+
# Load and encode labels
|
|
45
|
+
y = self.adata.obs[self.label_col]
|
|
46
|
+
if y.dtype.name == 'category':
|
|
47
|
+
y = y.cat.codes
|
|
48
|
+
y_tensor = torch.tensor(y.values, dtype=torch.long)
|
|
49
|
+
|
|
50
|
+
# Use existing split
|
|
51
|
+
if self.load_existing_split:
|
|
52
|
+
split_df = pd.read_csv(self.split_save_path, index_col=0)
|
|
53
|
+
assert self.split_col in split_df.columns, f"'{self.split_col}' column missing in split file."
|
|
54
|
+
self.adata.obs[self.split_col] = split_df.loc[self.adata.obs_names][self.split_col].values
|
|
55
|
+
|
|
56
|
+
# If no split exists, create one
|
|
57
|
+
if self.split_col not in self.adata.obs:
|
|
58
|
+
full_dataset = TensorDataset(X_tensor, y_tensor)
|
|
59
|
+
n_train = int(self.train_frac * len(full_dataset))
|
|
60
|
+
n_val = len(full_dataset) - n_train
|
|
61
|
+
self.train_set, self.val_set = random_split(
|
|
62
|
+
full_dataset, [n_train, n_val],
|
|
63
|
+
generator=torch.Generator().manual_seed(self.random_seed)
|
|
64
|
+
)
|
|
65
|
+
# Assign split labels
|
|
66
|
+
split_array = np.full(len(self.adata), "val", dtype=object)
|
|
67
|
+
train_idx = self.train_set.indices if hasattr(self.train_set, "indices") else self.train_set._indices
|
|
68
|
+
split_array[train_idx] = "train"
|
|
69
|
+
self.adata.obs[self.split_col] = split_array
|
|
70
|
+
|
|
71
|
+
# Save to disk
|
|
72
|
+
if self.split_save_path:
|
|
73
|
+
self.adata.obs[[self.split_col]].to_csv(self.split_save_path)
|
|
74
|
+
else:
|
|
75
|
+
split_labels = self.adata.obs[self.split_col].values
|
|
76
|
+
train_mask = split_labels == "train"
|
|
77
|
+
val_mask = split_labels == "val"
|
|
78
|
+
self.train_set = TensorDataset(X_tensor[train_mask], y_tensor[train_mask])
|
|
79
|
+
self.val_set = TensorDataset(X_tensor[val_mask], y_tensor[val_mask])
|
|
80
|
+
|
|
81
|
+
def train_dataloader(self):
|
|
82
|
+
return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
|
|
83
|
+
|
|
84
|
+
def val_dataloader(self):
|
|
85
|
+
return DataLoader(self.val_set, batch_size=self.batch_size)
|
|
86
|
+
|
|
87
|
+
def predict_dataloader(self):
|
|
88
|
+
if not self.inference_mode:
|
|
89
|
+
raise RuntimeError("predict_dataloader only available in inference mode.")
|
|
90
|
+
return DataLoader(self.infer_dataset, batch_size=self.batch_size)
|
|
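A minimal sketch of driving the data module on its own, assuming adata.X is a dense float matrix; the label column and paths are placeholders.

```python
# Build a reproducible train/val split from an AnnData object (sketch; placeholder names).
import anndata as ad
from smftools.tools.data.anndata_data_module import AnnDataModule

adata = ad.read_h5ad("experiment.h5ad")                     # placeholder path
adata.obs["Sample"] = adata.obs["Sample"].astype("category")

dm = AnnDataModule(adata, tensor_source="X", label_col="Sample",
                   batch_size=128, train_frac=0.8,
                   split_save_path="train_val_split.csv")
dm.setup()                                                   # writes adata.obs["train_val_split"] and the CSV
xb, yb = next(iter(dm.train_dataloader()))
print(xb.shape, yb.shape)
```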
smftools/tools/display_hmm.py (new file)
@@ -0,0 +1,18 @@
+def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
+    import torch
+    print("\n🔹 **HMM Model Overview**")
+    print(hmm)
+
+    print("\n🔹 **Transition Matrix**")
+    transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
+    for i, row in enumerate(transition_matrix):
+        label = state_labels[i] if state_labels else f"State {i}"
+        formatted_row = ", ".join(f"{p:.6f}" for p in row)
+        print(f"{label}: [{formatted_row}]")
+
+    print("\n🔹 **Emission Probabilities**")
+    for i, dist in enumerate(hmm.distributions):
+        label = state_labels[i] if state_labels else f"State {i}"
+        probs = dist.probs.detach().cpu().numpy()
+        formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
+        print(f"{label}: {formatted_emissions}")
smftools/tools/general_tools.py (new file)
@@ -0,0 +1,69 @@
+def create_nan_mask_from_X(adata, new_layer_name="nan_mask"):
+    """
+    Generates a nan mask where 1 = NaN in adata.X and 0 = valid value.
+    """
+    import numpy as np
+    nan_mask = np.isnan(adata.X).astype(int)
+    adata.layers[new_layer_name] = nan_mask
+    print(f"✅ Created '{new_layer_name}' layer based on NaNs in adata.X")
+    return adata
+
+def create_nan_or_non_gpc_mask(adata, obs_column, new_layer_name="nan_or_non_gpc_mask"):
+    import numpy as np
+
+    nan_mask = np.isnan(adata.X).astype(int)
+    combined_mask = np.zeros_like(nan_mask)
+
+    for idx, row in enumerate(adata.obs.itertuples()):
+        ref = getattr(row, obs_column)
+        gpc_mask = adata.var[f"{ref}_GpC_site"].astype(int).values
+        combined_mask[idx, :] = 1 - gpc_mask  # non-GpC is 1
+
+    mask = np.maximum(nan_mask, combined_mask)
+    adata.layers[new_layer_name] = mask
+
+    print(f"✅ Created '{new_layer_name}' layer based on NaNs in adata.X and non-GpC regions using {obs_column}")
+    return adata
+
+def combine_layers(adata, input_layers, output_layer, negative_mask=None, values=None, binary_mode=False):
+    """
+    Combines layers into a single layer with specific coding:
+    - Background stays 0
+    - If binary_mode=True: any overlap = 1
+    - If binary_mode=False:
+        - Defaults to [1, 2, 3, ...] if values=None
+        - Later layers take precedence in overlaps
+
+    Parameters:
+        adata: AnnData object
+        input_layers: list of str
+        output_layer: str, name of the output layer
+        negative_mask: str (optional), binary mask to enforce 0s
+        values: list of ints (optional), values to assign to each input layer
+        binary_mode: bool, if True, creates a simple 0/1 mask regardless of values
+
+    Returns:
+        Updated AnnData with new layer.
+    """
+    import numpy as np
+    combined = np.zeros_like(adata.layers[input_layers[0]])
+
+    if binary_mode:
+        for layer in input_layers:
+            combined = np.logical_or(combined, adata.layers[layer] > 0)
+        combined = combined.astype(int)
+    else:
+        if values is None:
+            values = list(range(1, len(input_layers) + 1))
+        for i, layer in enumerate(input_layers):
+            arr = adata.layers[layer]
+            combined[arr > 0] = values[i]
+
+    if negative_mask:
+        mask = adata.layers[negative_mask]
+        combined[mask == 0] = 0
+
+    adata.layers[output_layer] = combined
+    print(f"✅ Combined layers into {output_layer} {'(binary)' if binary_mode else f'with values {values}'}")
+
+    return adata
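A sketch of the masking helpers; the input layer names are placeholders (for example, per-read HMM feature layers), not names the package guarantees, and adata.X is assumed to be a dense float matrix so np.isnan applies.

```python
# Build a NaN mask and a combined, integer-coded feature layer (sketch; placeholder layer names).
import anndata as ad
from smftools.tools.general_tools import create_nan_mask_from_X, combine_layers

adata = ad.read_h5ad("experiment.h5ad")                     # placeholder path
adata = create_nan_mask_from_X(adata, new_layer_name="nan_mask")
adata = combine_layers(
    adata,
    input_layers=["hmm_accessible", "hmm_nucleosome"],      # placeholder layers; later layers win overlaps
    output_layer="combined_features",
    values=[1, 2],
)
```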
smftools/tools/hmm_readwrite.py (new file)
@@ -0,0 +1,16 @@
+def load_hmm(model_path, device='cpu'):
+    """
+    Reads in a pretrained HMM.
+
+    Parameters:
+        model_path (str): Path to a pretrained HMM
+    """
+    import torch
+    # Load model using PyTorch
+    hmm = torch.load(model_path)
+    hmm.to(device)
+    return hmm
+
+def save_hmm(model, model_path):
+    import torch
+    torch.save(model, model_path)
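A round-trip sketch of the save/load helpers. Because they pickle the entire model object, the model's class must be importable at load time, and PyTorch releases that default torch.load to weights_only=True may refuse fully pickled models; the stand-in module below is illustrative only (a real HMM would come from tools.train_hmm).

```python
# Save and reload a model object (sketch; a plain nn.Module stands in for a trained HMM).
import torch
from smftools.tools.hmm_readwrite import save_hmm, load_hmm

model = torch.nn.Linear(2, 2)                 # stand-in object for illustration only
save_hmm(model, "hmm.pt")
restored = load_hmm("hmm.pt", device="cpu")
```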
smftools/tools/inference/__init__.py (new file)
@@ -0,0 +1 @@
+from .lightning_inference import run_lightning_inference
smftools/tools/inference/lightning_inference.py (new file)
@@ -0,0 +1,41 @@
+import torch
+import pandas as pd
+import numpy as np
+from pytorch_lightning import Trainer
+
+def run_lightning_inference(
+    adata,
+    model,
+    datamodule,
+    label_col="labels",
+    prefix="model"
+):
+
+    # Get class labels
+    if label_col in adata.obs and pd.api.types.is_categorical_dtype(adata.obs[label_col]):
+        class_labels = adata.obs[label_col].cat.categories.tolist()
+    else:
+        raise ValueError("label_col must be a categorical column in adata.obs")
+
+    # Run predictions
+    trainer = Trainer(accelerator="auto", devices=1, logger=False, enable_checkpointing=False)
+    preds = trainer.predict(model, datamodule=datamodule)
+    probs = torch.cat(preds, dim=0).cpu().numpy()  # (N, C)
+    pred_class_idx = probs.argmax(axis=1)
+    pred_class_labels = [class_labels[i] for i in pred_class_idx]
+    pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]
+
+    # Construct full prefix with label_col
+    full_prefix = f"{prefix}_{label_col}"
+
+    # Store predictions in obs
+    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
+
+    # Per-class probabilities
+    for i, class_name in enumerate(class_labels):
+        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs[:, i]
+
+    # Full probability matrix in obsm
+    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs
smftools/tools/models/__init__.py (new file)
@@ -0,0 +1,9 @@
+from .base import BaseTorchModel
+from .mlp import MLPClassifier
+from .cnn import CNNClassifier
+from .rnn import RNNClassifier
+from .transformer import BaseTransformer, TransformerClassifier, DANNTransformerClassifier, MaskedTransformerPretrainer
+from .positional import PositionalEncoding
+from .wrappers import ScaledModel
+from .lightning_base import TorchClassifierWrapper
+from .sklearn_models import SklearnModelWrapper
smftools/tools/models/base.py (new file)
@@ -0,0 +1,14 @@
+import torch.nn as nn
+from ..utils.device import detect_device
+
+class BaseTorchModel(nn.Module):
+    """
+    Minimal base class for torch models that:
+    - Stores device
+    - Moves model to detected device on init
+    """
+    def __init__(self, dropout_rate=0.2):
+        super().__init__()
+        self.device = detect_device()  # detects available devices
+        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
+        self.to(self.device)  # move model to device
smftools/tools/models/cnn.py (new file)
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+
+class CNNClassifier(BaseTorchModel):
+    def __init__(self, input_size, num_classes, **kwargs):
+        super().__init__(**kwargs)
+        # Define convolutional layers
+        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
+        # Define activation function
+        self.relu = nn.ReLU()
+
+        # Determine the flattened size dynamically
+        dummy_input = torch.zeros(1, 1, input_size).to(self.device)
+        with torch.no_grad():
+            dummy_output = self._forward_conv(dummy_input)
+            flattened_size = dummy_output.view(1, -1).shape[1]
+
+        # Define fully connected layers
+        self.fc1 = nn.Linear(flattened_size, 64)
+        self.fc2 = nn.Linear(64, num_classes)
+
+    def _forward_conv(self, x):
+        x = self.relu(self.conv1(x))
+        x = self.relu(self.conv2(x))
+        return x
+
+    def forward(self, x):
+        x = x.unsqueeze(1)  # [B, 1, L]
+        x = self._forward_conv(x)
+        x = x.view(x.size(0), -1)  # flatten
+        x = self.relu(self.fc1(x))
+        return self.fc2(x)
smftools/tools/models/lightning_base.py (new file)
@@ -0,0 +1,41 @@
+import torch
+import pytorch_lightning as pl
+
+class TorchClassifierWrapper(pl.LightningModule):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer_cls=torch.optim.AdamW,
+        optimizer_kwargs=None,
+        criterion_cls=torch.nn.CrossEntropyLoss,
+        criterion_kwargs=None,
+        lr: float = 1e-3,
+    ):
+        super().__init__()
+        self.model = model
+        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
+        self.optimizer_cls = optimizer_cls
+        self.optimizer_kwargs = optimizer_kwargs or {}
+        self.criterion = criterion_cls(**(criterion_kwargs or {}))
+        self.lr = lr
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.criterion(logits, y)
+        self.log("train_loss", loss, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.criterion(logits, y)
+        acc = (logits.argmax(dim=1) == y).float().mean()
+        self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        return self.optimizer_cls(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
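A sketch of how this wrapper composes with the new models and the AnnDataModule under a Lightning Trainer; the label column, path, and hyperparameters are placeholders, and adata.X is assumed dense and NaN-free (e.g. after preprocessing).

```python
# Hypothetical end-to-end training sketch; names are placeholders, not package defaults.
import anndata as ad
import pytorch_lightning as pl
from smftools.tools.data.anndata_data_module import AnnDataModule
from smftools.tools.models import MLPClassifier, TorchClassifierWrapper

adata = ad.read_h5ad("experiment.h5ad")                      # placeholder path
adata.obs["Sample"] = adata.obs["Sample"].astype("category")

dm = AnnDataModule(adata, tensor_source="X", label_col="Sample", batch_size=128)
net = MLPClassifier(input_dim=adata.n_vars,
                    num_classes=adata.obs["Sample"].cat.categories.size)
wrapped = TorchClassifierWrapper(net, lr=1e-3)

trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
trainer.fit(wrapped, datamodule=dm)                          # logs train_loss / val_loss / val_acc
```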
smftools/tools/models/mlp.py (new file)
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+
+class MLPClassifier(BaseTorchModel):
+    def __init__(self, input_dim, num_classes, hidden_sizes=(128, 64), **kwargs):
+        super().__init__(**kwargs)
+        layers = []
+        prev = input_dim
+        for h in hidden_sizes:
+            layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(self.dropout_rate)])
+            prev = h
+        layers.append(nn.Linear(prev, num_classes))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
smftools/tools/models/positional.py (new file)
@@ -0,0 +1,17 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+
+    def forward(self, x):
+        x = x + self.pe[:, :x.size(1)].to(x.device)
+        return x
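This is the standard sinusoidal encoding added to (batch, seq_len, d_model) inputs; a quick shape check with arbitrary dimensions:

```python
# Shape check for the sinusoidal positional encoding (dimensions are arbitrary here).
import torch
from smftools.tools.models import PositionalEncoding

pe = PositionalEncoding(d_model=32, max_len=1000)
tokens = torch.zeros(4, 250, 32)               # (batch, seq_len, d_model)
print(pe(tokens).shape)                        # torch.Size([4, 250, 32])
```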
smftools/tools/models/rnn.py (new file)
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+
+class RNNClassifier(BaseTorchModel):
+    def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
+        super().__init__(**kwargs)
+        # Define LSTM layer
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
+        # Define fully connected output layer
+        self.fc = nn.Linear(hidden_dim, num_classes)
+
+    def forward(self, x):
+        x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
+        _, (h_n, _) = self.lstm(x)  # h_n: [1, B, H]
+        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
smftools/tools/models/sklearn_models.py (new file)
@@ -0,0 +1,40 @@
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.naive_bayes import GaussianNB
+from sklearn.metrics import (
+    roc_curve, precision_recall_curve, auc, f1_score, confusion_matrix
+)
+from sklearn.utils.class_weight import compute_class_weight
+
+import numpy as np
+
+class SklearnModelWrapper:
+    def __init__(self, model):
+        self.model = model
+
+    def fit(self, X_train, y_train):
+        self.model.fit(X_train, y_train)
+
+    def predict(self, X):
+        return self.model.predict(X)
+
+    def predict_proba(self, X):
+        return self.model.predict_proba(X)
+
+    def evaluate(self, X_test, y_test):
+        probs = self.predict_proba(X_test)[:, 1]
+        preds = self.predict(X_test)
+
+        fpr, tpr, _ = roc_curve(y_test, probs)
+        precision, recall, _ = precision_recall_curve(y_test, probs)
+        f1 = f1_score(y_test, preds)
+        auc_score = auc(fpr, tpr)
+        pr_auc = auc(recall, precision)
+        cm = confusion_matrix(y_test, preds)
+        pos_freq = np.mean(y_test == 1)
+        pr_auc_norm = pr_auc / pos_freq
+
+        return {
+            "fpr": fpr, "tpr": tpr, "precision": precision, "recall": recall,
+            "f1": f1, "auc": auc_score, "pr_auc": pr_auc,
+            "pr_auc_norm": pr_auc_norm, "confusion_matrix": cm
+        }
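A usage sketch of the sklearn wrapper on synthetic binary data; evaluate() as written assumes a binary task, since it takes column 1 of predict_proba.

```python
# Wrap a scikit-learn estimator and compute ROC/PR metrics on toy data (sketch).
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from smftools.tools.models import SklearnModelWrapper

rng = np.random.default_rng(0)
X = rng.random((200, 50))
y = (X[:, 0] > 0.5).astype(int)                # toy binary labels

clf = SklearnModelWrapper(RandomForestClassifier(n_estimators=100, random_state=0))
clf.fit(X[:150], y[:150])
metrics = clf.evaluate(X[150:], y[150:])
print(metrics["auc"], metrics["f1"])
print(metrics["confusion_matrix"])
```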