smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +34 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/hmm/apply_hmm_batched.py +242 -0
- smftools/hmm/calculate_distances.py +18 -0
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/hmm/display_hmm.py +18 -0
- smftools/hmm/hmm_readwrite.py +16 -0
- smftools/hmm/nucleosome_hmm_refinement.py +104 -0
- smftools/hmm/train_hmm.py +78 -0
- smftools/informatics/__init__.py +14 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/conversion_smf.py +132 -0
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/direct_smf.py +137 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/fast5_to_pod5.py +24 -0
- smftools/informatics/helpers/__init__.py +73 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +70 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +51 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +32 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/data/preprocessing.py +6 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/__init__.py +9 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/machine_learning/models/positional.py +18 -0
- smftools/machine_learning/models/rnn.py +17 -0
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/models/wrappers.py +20 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +10 -0
- smftools/machine_learning/utils/grl.py +14 -0
- smftools/plotting/__init__.py +18 -0
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +682 -0
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +38 -0
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_coverage.py +51 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +62 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1351 -0
- smftools/preprocessing/invert_adata.py +37 -0
- smftools/preprocessing/load_sample_sheet.py +53 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +1004 -0
- smftools/tools/__init__.py +20 -0
- smftools/tools/archived/apply_hmm.py +202 -0
- smftools/tools/archived/classifiers.py +787 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/position_stats.py +601 -0
- smftools/tools/read_stats.py +184 -0
- smftools/tools/spatial_autocorrelation.py +562 -0
- smftools/tools/subset_adata.py +28 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from . import models
from . import data
from . import utils
from . import evaluation
from . import inference
from . import training

# Only names actually bound in this module belong in __all__.  The previous
# list named evaluator helpers ("calculate_relative_risk_on_activity",
# "evaluate_models_by_subgroup", "prepare_melted_model_data") that are never
# imported here, which makes `from smftools.machine_learning import *`
# raise AttributeError.  Export the subpackages that are really bound.
__all__ = [
    "models",
    "data",
    "utils",
    "evaluation",
    "inference",
    "training",
]
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
|
|
3
|
+
import pytorch_lightning as pl
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from .preprocessing import random_fill_nans
|
|
7
|
+
from sklearn.utils.class_weight import compute_class_weight
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AnnDataDataset(Dataset):
    """
    Generic PyTorch Dataset built from an AnnData object.

    The feature matrix is pulled from ``adata.X``, ``adata.layers[tensor_key]``
    or ``adata.obsm[tensor_key]``, optionally restricted to a fixed positional
    window, NaN-filled, and stored as a float32 tensor.  When ``label_col`` is
    given, labels are read from ``adata.obs`` (categorical columns are encoded
    as integer codes) and returned alongside each sample.
    """

    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col=None, window_start=None, window_size=None):
        self.adata = adata
        self.tensor_source = tensor_source
        self.tensor_key = tensor_key
        self.label_col = label_col
        self.window_start = window_start
        self.window_size = window_size

        # Resolve the raw feature matrix from the requested container.
        if tensor_source == "X":
            matrix = adata.X
        elif tensor_source == "layers":
            assert tensor_key in adata.layers
            matrix = adata.layers[tensor_key]
        elif tensor_source == "obsm":
            assert tensor_key in adata.obsm
            matrix = adata.obsm[tensor_key]
        else:
            raise ValueError(f"Invalid tensor_source: {tensor_source}")

        # Optional fixed-width window along the feature axis.
        if self.window_start is not None and self.window_size is not None:
            stop = self.window_start + self.window_size
            matrix = matrix[:, self.window_start:stop]

        matrix = random_fill_nans(matrix)
        self.X_tensor = torch.tensor(matrix, dtype=torch.float32)

        if label_col is None:
            self.y_tensor = None
        else:
            labels = adata.obs[label_col]
            if labels.dtype.name == 'category':
                labels = labels.cat.codes
            self.y_tensor = torch.tensor(labels.values, dtype=torch.long)

    def numpy(self, indices):
        """Return ``(features, labels)`` for *indices* as numpy arrays."""
        return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()

    def __len__(self):
        return len(self.X_tensor)

    def __getitem__(self, idx):
        features = self.X_tensor[idx]
        if self.y_tensor is None:
            # Unlabeled (inference) mode: 1-tuple keeps DataLoader batching uniform.
            return (features,)
        return features, self.y_tensor[idx]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
                  random_seed=42, split_col="train_val_test_split",
                  load_existing_split=False, split_save_path=None):
    """
    Split *dataset* into train/val/test Subsets and record the assignment
    in ``adata.obs[split_col]``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.obs`` records (or already holds) the split labels.
    dataset : Dataset
        Torch dataset aligned row-for-row with ``adata``.
    train_frac, val_frac : float
        Fractions assigned to "train" and "val".
    test_frac : float
        Kept for interface compatibility; the test split is implicit —
        every row not assigned to train/val becomes "test".
    random_seed : int
        Seed for the (global) numpy RNG used to shuffle assignments.
    split_col : str
        Column name in ``adata.obs`` holding "train"/"val"/"test" labels.
    load_existing_split : bool
        Reuse ``adata.obs[split_col]`` or load it from ``split_save_path``
        instead of drawing a new split.
    split_save_path : str, optional
        CSV path to save a fresh split to, or load an existing one from.

    Returns
    -------
    (train_set, val_set, test_set) : tuple of torch.utils.data.Subset

    Raises
    ------
    ValueError
        If ``load_existing_split`` is set but no split column or file exists.
    """
    total_len = len(dataset)

    if load_existing_split:
        if split_col in adata.obs:
            pass  # reuse the split already recorded on adata
        elif split_save_path:
            split_df = pd.read_csv(split_save_path, index_col=0)
            adata.obs[split_col] = split_df.loc[adata.obs_names][split_col].values
        else:
            raise ValueError("No existing split column found and no file provided.")
    else:
        indices = np.arange(total_len)
        np.random.seed(random_seed)
        np.random.shuffle(indices)

        n_train = int(train_frac * total_len)
        n_val = int(val_frac * total_len)
        # No n_test needed: rows not claimed below default to "test".

        split_array = np.full(total_len, "test", dtype=object)
        split_array[indices[:n_train]] = "train"
        split_array[indices[n_train:n_train + n_val]] = "val"
        adata.obs[split_col] = split_array

        if split_save_path:
            adata.obs[[split_col]].to_csv(split_save_path)

    split_labels = adata.obs[split_col].values
    train_indices = np.where(split_labels == "train")[0]
    val_indices = np.where(split_labels == "val")[0]
    test_indices = np.where(split_labels == "test")[0]

    train_set = Subset(dataset, train_indices)
    val_set = Subset(dataset, val_indices)
    test_set = Subset(dataset, test_indices)

    return train_set, val_set, test_set
|
|
106
|
+
|
|
107
|
+
class AnnDataModule(pl.LightningDataModule):
    """
    Unified LightningDataModule wrapping AnnDataDataset plus train/val/test
    splitting, with the split assignment recorded in ``adata.obs``.
    """

    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
                 batch_size=64, train_frac=0.6, val_frac=0.1, test_frac=0.3, random_seed=42,
                 inference_mode=False, split_col="train_val_test_split", split_save_path=None,
                 load_existing_split=False, window_start=None, window_size=None, num_workers=None, persistent_workers=False):
        super().__init__()
        self.adata = adata
        self.tensor_source = tensor_source
        self.tensor_key = tensor_key
        self.label_col = label_col
        self.batch_size = batch_size
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.random_seed = random_seed
        self.inference_mode = inference_mode
        self.split_col = split_col
        self.split_save_path = split_save_path
        self.load_existing_split = load_existing_split
        self.var_names = adata.var_names.copy()  # snapshot of feature names for downstream use
        self.window_start = window_start
        self.window_size = window_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers

    def setup(self, stage=None):
        """Build the dataset; split into train/val/test unless in inference mode."""
        dataset = AnnDataDataset(self.adata, self.tensor_source, self.tensor_key,
                                 None if self.inference_mode else self.label_col,
                                 window_start=self.window_start, window_size=self.window_size)

        if self.inference_mode:
            self.infer_dataset = dataset
            return

        self.train_set, self.val_set, self.test_set = split_dataset(
            self.adata, dataset, train_frac=self.train_frac, val_frac=self.val_frac,
            test_frac=self.test_frac, random_seed=self.random_seed,
            split_col=self.split_col, split_save_path=self.split_save_path,
            load_existing_split=self.load_existing_split
        )

    def _make_loader(self, dataset, shuffle=False):
        # Worker options are only forwarded when num_workers was explicitly
        # set; persistent_workers requires num_workers > 0.
        if self.num_workers:
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                              num_workers=self.num_workers, persistent_workers=self.persistent_workers)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        # Bug fix: the no-worker branch previously returned the *training*
        # set, so validation metrics were computed on training data.
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        # Bug fix: the no-worker branch previously returned the *training*
        # set, so test metrics were computed on training data.
        return self._make_loader(self.test_set)

    def predict_dataloader(self):
        if not self.inference_mode:
            raise RuntimeError("Only valid in inference mode")
        return DataLoader(self.infer_dataset, batch_size=self.batch_size)

    def compute_class_weights(self):
        """Return balanced class weights (float32 tensor) for the training subset."""
        train_indices = self.train_set.indices  # indices of the training subset
        y_all = self.train_set.dataset.y_tensor  # labels for the entire dataset (train_set is a Subset)
        y_train = y_all[train_indices].cpu().numpy()  # training labels as numpy

        class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
        return torch.tensor(class_weights, dtype=torch.float32)

    def inference_numpy(self):
        """
        Return inference data as numpy for use in sklearn inference.
        """
        if not self.inference_mode:
            raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
        X_np = self.infer_dataset.X_tensor.numpy()
        return X_np

    def to_numpy(self):
        """
        Move the AnnDataModule tensors into numpy arrays.

        Returns ``(train_X, train_y, val_X, val_y, test_X, test_y)`` in
        training mode, or just the inference matrix in inference mode.
        """
        if not self.inference_mode:
            train_X, train_y = self.train_set.dataset.numpy(self.train_set.indices)
            val_X, val_y = self.val_set.dataset.numpy(self.val_set.indices)
            test_X, test_y = self.test_set.dataset.numpy(self.test_set.indices)
            return train_X, train_y, val_X, val_y, test_X, test_y
        else:
            return self.inference_numpy()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def build_anndata_loader(
    adata, tensor_source="X", tensor_key=None, label_col=None, train_frac=0.6, val_frac=0.1,
    test_frac=0.3, random_seed=42, batch_size=64, lightning=True, inference_mode=False,
    split_col="train_val_test_split", split_save_path=None, load_existing_split=False
):
    """
    Unified pipeline for both Lightning and raw PyTorch.

    The lightning loader works for both Lightning and the Sklearn wrapper.
    Set lightning to False if you want to make data loaders for base PyTorch
    or base sklearn models.

    Returns
    -------
    AnnDataModule when ``lightning`` is True; otherwise a single DataLoader
    (inference mode) or a ``(train_loader, val_loader, test_loader)`` tuple.
    """
    if lightning:
        return AnnDataModule(
            adata, tensor_source=tensor_source, tensor_key=tensor_key, label_col=label_col,
            batch_size=batch_size, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
            random_seed=random_seed, inference_mode=inference_mode,
            split_col=split_col, split_save_path=split_save_path, load_existing_split=load_existing_split
        )
    else:
        dataset = AnnDataDataset(adata, tensor_source, tensor_key, None if inference_mode else label_col)
        if inference_mode:
            return DataLoader(dataset, batch_size=batch_size)
        else:
            # Keyword arguments fix a latent bug: the old positional call
            # passed split_save_path into load_existing_split (and vice
            # versa), silently ignoring the caller's intent.
            train_set, val_set, test_set = split_dataset(
                adata, dataset, train_frac=train_frac, val_frac=val_frac,
                test_frac=test_frac, random_seed=random_seed, split_col=split_col,
                split_save_path=split_save_path, load_existing_split=load_existing_split
            )
            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_set, batch_size=batch_size)
            test_loader = DataLoader(test_set, batch_size=batch_size)
            return train_loader, val_loader, test_loader
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
def flatten_sliding_window_results(results_dict):
    """
    Flatten nested sliding-window results into a tidy pandas DataFrame.

    Expects structure:
        results[model_name][window_size][window_center]['metrics'][metric_name]

    Each (model, window_size, center) combination becomes one row, with one
    column per metric.
    """
    records = [
        {
            'model': model_name,
            'window_size': window_size,
            'center_var': center_var,
            # Metrics are merged last, mirroring record.update(metrics).
            **payload['metrics'],
        }
        for model_name, per_model in results_dict.items()
        for window_size, per_window in per_model.items()
        for center_var, payload in per_window.items()
    ]

    df = pd.DataFrame.from_records(records)

    # Numeric centers sort and plot naturally; non-numeric values become NaN.
    df['center_var'] = pd.to_numeric(df['center_var'], errors='coerce')
    df = df.sort_values(['model', 'window_size', 'center_var'])

    return df
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
|
|
5
|
+
from sklearn.metrics import (
|
|
6
|
+
roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
class ModelEvaluator:
    """
    A model evaluator for consolidating Sklearn and Lightning model
    evaluation metrics on testing data.

    Each registered model must expose ``test_f1``, ``test_roc_auc``,
    ``test_pr_auc``, ``test_pr_curve``, ``test_roc_curve``, ``test_num_pos``
    and ``test_pos_freq`` attributes.
    """

    def __init__(self):
        self.results = []     # one metrics dict per registered model
        self.pos_freq = None  # positive-class frequency from the first model added
        self.num_pos = None   # positive-instance count from the first model added

    def add_model(self, name, model, is_torch=True):
        """
        Add a trained model with its evaluation metrics.

        ``is_torch`` is kept for interface compatibility only: the Torch and
        sklearn branches built byte-identical entries, so they are merged.
        """
        entry = {
            'name': name,
            'f1': model.test_f1,
            'auc': model.test_roc_auc,
            'pr_auc': model.test_pr_auc,
            'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
            'pr_curve': model.test_pr_curve,
            'roc_curve': model.test_roc_curve,
            'num_pos': model.test_num_pos,
            'pos_freq': model.test_pos_freq
        }

        self.results.append(entry)

        # Capture baseline stats from the first model only.  `is None` (not
        # truthiness) so a legitimate pos_freq of 0.0 is not re-overwritten
        # by every subsequent model.
        if self.pos_freq is None:
            self.pos_freq = entry['pos_freq']
            self.num_pos = entry['num_pos']

    def get_metrics_dataframe(self):
        """
        Return all scalar metrics as a pandas DataFrame (curves excluded).
        """
        df = pd.DataFrame(self.results)
        return df[['name', 'f1', 'auc', 'pr_auc', 'pr_auc_norm', 'num_pos', 'pos_freq']]

    def plot_all_curves(self):
        """
        Plot unified ROC and PR curves across all models.
        """
        plt.figure(figsize=(12, 5))

        # ROC panel: each model's stored (fpr, tpr) plus the chance diagonal.
        plt.subplot(1, 2, 1)
        for res in self.results:
            fpr, tpr = res['roc_curve']
            plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.ylim(0, 1.05)
        plt.title(f"ROC Curves - {self.num_pos} positive instances")
        plt.legend()

        # PR panel: baseline is the positive-class frequency.
        plt.subplot(1, 2, 2)
        for res in self.results:
            rc, pr = res['pr_curve']
            plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.ylim(0, 1.05)
        plt.axhline(self.pos_freq, linestyle='--', color='grey')
        plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
        plt.legend()

        plt.tight_layout()
        plt.show()
|
|
92
|
+
|
|
93
|
+
class PostInferenceModelEvaluator:
    # Evaluates models *after* predictions have already been written back to
    # the AnnData object (obs/obsm), as opposed to ModelEvaluator which reads
    # metrics stored on the model objects themselves.
    def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
        """
        Initialize evaluator.

        Parameters:
        -----------
        adata : AnnData
            The annotated dataset where predictions are stored in obs/obsm.
        models : dict
            Dictionary of models: {model_name: model_instance}.
            Supports TorchClassifierWrapper and SklearnModelWrapper.
        target_eval_freq : float, optional
            If set, subsample the evaluation pool to this positive-class
            frequency before computing metrics.
        max_eval_positive : int, optional
            Cap on the number of positive instances used when subsampling.
        """
        self.adata = adata
        self.models = models
        self.target_eval_freq = target_eval_freq
        self.max_eval_positive = max_eval_positive
        self.results = {}  # keyed by f"{model_name}_{label_col}"

    def evaluate_all(self):
        """
        Evaluate all models and store results.

        Results are keyed by ``f"{name}_{label_col}"`` in ``self.results``.
        """
        for name, model in self.models.items():
            print(f"Evaluating {name}...")
            label_col = model.label_col
            full_prefix = f"{name}_{label_col}"
            self.results[full_prefix] = self._evaluate_model(name, model)

    def _evaluate_model(self, model_name, model):
        """
        Evaluate one model and return metrics.

        Reads ground truth from ``adata.obs[label_col]`` (categorical codes),
        predictions from ``adata.obs[f"{prefix}_pred"]`` and per-class
        probabilities from ``adata.obsm[f"{prefix}_pred_prob_all"]``.
        """
        label_col = model.label_col
        num_classes = model.num_classes
        class_names = model.class_names  # NOTE(review): unused here; kept for parity with the model interface
        focus_class = model.focus_class

        full_prefix = f"{model_name}_{label_col}"

        # Extract ground truth + predictions
        y_true = self.adata.obs[label_col].cat.codes.to_numpy()
        y_pred = self.adata.obs[f"{full_prefix}_pred"].to_numpy()
        probs_all = self.adata.obsm[f"{full_prefix}_pred_prob_all"]

        # Binary indicator for the focus class (1 = focus, 0 = everything else).
        binary_focus = (y_true == focus_class).astype(int)

        # OPTIONAL SUBSAMPLING: rebalance the evaluation pool to a fixed
        # positive-class frequency.  NOTE(review): uses the global, unseeded
        # numpy RNG, so repeated runs can produce different metric values.
        if self.target_eval_freq is not None:
            indices = self._subsample_for_fixed_positive_frequency(
                binary_focus, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
            )
            y_true = y_true[indices]
            y_pred = y_pred[indices]
            probs_all = probs_all[indices]
            binary_focus = (y_true == focus_class).astype(int)

        acc = np.mean(y_true == y_pred)

        if num_classes == 2:
            # Binary: score everything against the focus class directly.
            focus_probs = probs_all[:, focus_class]
            f1 = f1_score(binary_focus, (y_pred == focus_class).astype(int))
            roc_auc = roc_auc_score(binary_focus, focus_probs)
            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
            pr_auc = auc(rc, pr)
            pos_freq = binary_focus.mean()
            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
        else:
            # Multiclass: macro F1 / one-vs-rest AUC overall, plus
            # focus-class-vs-rest curves for the PR/ROC outputs.
            f1 = f1_score(y_true, y_pred, average="macro")
            roc_auc = roc_auc_score(y_true, probs_all, multi_class="ovr", average="macro")
            focus_probs = probs_all[:, focus_class]
            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
            pr_auc = auc(rc, pr)
            pos_freq = binary_focus.mean()
            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan

        cm = confusion_matrix(y_true, y_pred)

        metrics = {
            "accuracy": acc,
            "f1": f1,
            "roc_auc": roc_auc,
            "pr_auc": pr_auc,
            "pr_auc_norm": pr_auc_norm,  # PR AUC relative to the chance baseline
            "pos_freq": pos_freq,
            "confusion_matrix": cm,
            "pr_rc_curve": (pr, rc),
            # NOTE(review): stored as (tpr, fpr) — reversed relative to
            # sklearn's roc_curve return order; confirm consumers expect this.
            "roc_curve": (tpr, fpr)
        }

        return metrics

    def _subsample_for_fixed_positive_frequency(self, binary_labels, target_freq=0.3, max_positive=None):
        # Draw a random subset of positive and negative indices such that the
        # positive fraction is (approximately) target_freq, optionally capping
        # the number of positives.  Returns a shuffled index array.
        pos_idx = np.where(binary_labels == 1)[0]
        neg_idx = np.where(binary_labels == 0)[0]

        max_pos = len(pos_idx)
        max_neg = len(neg_idx)

        # Clamp the requested frequency to what the data can actually supply.
        max_possible_freq = max_pos / (max_pos + max_neg)
        if target_freq > max_possible_freq:
            target_freq = max_possible_freq

        # Positives needed to hit target_freq if we kept every negative.
        num_pos_target = int(target_freq * max_neg / (1 - target_freq))
        num_pos_target = min(num_pos_target, max_pos)
        if max_positive is not None:
            num_pos_target = min(num_pos_target, max_positive)

        # Negatives sized so that pos / (pos + neg) ≈ target_freq.
        num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
        num_neg_target = min(num_neg_target, max_neg)

        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
        np.random.shuffle(sampled_idx)
        return sampled_idx

    def to_dataframe(self):
        """
        Convert results to pandas DataFrame (excluding confusion matrices).
        """
        records = []
        for model_name, metrics in self.results.items():
            row = {"model": model_name}
            for k, v in metrics.items():
                # Keep scalar metrics only; drop matrix- and curve-valued entries.
                if k not in ["confusion_matrix", "pr_rc_curve", "roc_curve"]:
                    row[k] = v
            records.append(row)
        return pd.DataFrame(records)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
def annotate_split_column(adata, model, split_col="split"):
    """
    Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.

    Each observation is labeled "training", "validation" or "testing"
    according to the obs_name sets recorded on the model during fitting;
    anything the model never saw is labeled "new".
    """
    # Ordered membership table; sets give O(1) lookups per observation.
    membership = (
        ("training", set(model.train_obs_names)),
        ("validation", set(model.val_obs_names)),
        ("testing", set(model.test_obs_names)),
    )

    def _label(obs_name):
        for label, names in membership:
            if obs_name in names:
                return label
        return "new"

    labels = [_label(obs) for obs in adata.obs_names]

    # Store as a categorical with all four levels fixed up front.
    adata.obs[split_col] = pd.Categorical(labels, categories=["training", "validation", "testing", "new"])

    print(f"Annotated {split_col} column with training/validation/testing/new status.")
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
from pytorch_lightning import Trainer
|
|
5
|
+
from .inference_utils import annotate_split_column
|
|
6
|
+
|
|
7
|
+
def run_lightning_inference(
    adata,
    model,
    datamodule,
    trainer,
    prefix="model",
    devices=1
):
    """
    Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).

    Writes predictions back onto ``adata`` in place:

    - ``obs[f"{prefix}_{label_col}_pred"]`` — predicted class index
    - ``obs[f"{prefix}_{label_col}_pred_label"]`` — predicted class name (Categorical)
    - ``obs[f"{prefix}_{label_col}_pred_prob"]`` — probability of the predicted class
    - ``obs[f"{prefix}_{label_col}_prob_<class>"]`` — per-class probability columns
    - ``obsm[f"{prefix}_{label_col}_pred_prob_all"]`` — full (N, C) probability matrix

    Parameters
    ----------
    adata : AnnData
        Object annotated in place; also gains a ``{prefix}_training_split`` obs column.
    model :
        Trained wrapper exposing ``label_col``, ``num_classes``, ``class_names``
        and the stored train/val/test obs_names used for split annotation.
    datamodule :
        Datamodule passed to ``trainer.predict``.
    trainer : pytorch_lightning.Trainer
        Trainer used to run prediction; its accelerator/devices configuration
        is used as-is (the previous in-function device selection was dead code
        and has been removed).
    prefix : str
        Prefix for the obs/obsm keys written back to ``adata``.
    devices : int
        Unused; retained for backward compatibility with existing callers.
    """
    label_col = model.label_col
    num_classes = model.num_classes
    class_labels = model.class_names

    # Mark which observations were seen during training/validation/testing.
    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")

    # Run predictions; each predict_step is expected to yield (preds, probs).
    outputs = trainer.predict(model, datamodule=datamodule)

    # Only the probabilities are needed — class indices are re-derived below.
    _, probs_list = zip(*outputs)
    probs = torch.cat(probs_list, dim=0).cpu().numpy()

    # Handle binary vs multiclass probability formats.
    if num_classes == 2:
        # probs shape: (N,) from sigmoid — expand to an (N, 2) column layout.
        pred_class_idx = (probs >= 0.5).astype(int)
        probs_all = np.vstack([1 - probs, probs]).T  # shape (N, 2)
    else:
        # probs shape: (N, C) from softmax.
        pred_class_idx = probs.argmax(axis=1)
        probs_all = probs
    # Probability assigned to each row's predicted class.
    pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]

    pred_class_labels = [class_labels[i] for i in pred_class_idx]

    full_prefix = f"{prefix}_{label_col}"

    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]

    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

    print(f"Inference complete: stored under prefix '{full_prefix}'")
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from .inference_utils import annotate_split_column
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def run_sklearn_inference(
    adata,
    model,
    datamodule,
    prefix="model"
):
    """
    Run inference on AnnData using SklearnModelWrapper.

    Writes predictions back onto ``adata`` in place:

    - ``obs[f"{prefix}_{label_col}_pred"]`` — predicted class index
    - ``obs[f"{prefix}_{label_col}_pred_label"]`` — predicted class name (Categorical)
    - ``obs[f"{prefix}_{label_col}_pred_prob"]`` — probability of the predicted class
    - ``obs[f"{prefix}_{label_col}_prob_<class>"]`` — per-class probability columns
    - ``obsm[f"{prefix}_{label_col}_pred_prob_all"]`` — full (N, C) probability matrix

    Parameters
    ----------
    adata : AnnData
        Object annotated in place; also gains a ``{prefix}_training_split`` obs column.
    model :
        Trained wrapper exposing ``label_col``, ``class_names``, ``predict``,
        ``predict_proba``, and the stored train/val/test obs_names.
    datamodule :
        Object exposing ``setup()`` and ``to_numpy()`` producing the inference matrix.
    prefix : str
        Prefix for the obs/obsm keys written back to ``adata``.
    """
    label_col = model.label_col
    class_labels = model.class_names

    # Mark which observations were seen during training/validation/testing.
    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")

    datamodule.setup()

    X_infer = datamodule.to_numpy()

    # Run predictions. sklearn's predict_proba returns (N, C) for binary and
    # multiclass alike, so no per-case branching is needed (the original
    # num_classes == 2 and else branches were byte-identical).
    pred_class_idx = model.predict(X_infer)
    probs_all = model.predict_proba(X_infer)
    # Probability assigned to each row's predicted class.
    pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]

    pred_class_labels = [class_labels[i] for i in pred_class_idx]

    full_prefix = f"{prefix}_{label_col}"

    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]

    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

    print(f"Inference complete: stored under prefix '{full_prefix}'")
|