smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,12 @@
1
# Public subpackages of smftools.machine_learning.
from . import models
from . import data
from . import utils
from . import evaluation
from . import inference
from . import training

# NOTE(review): none of these names are defined or imported in this module,
# so `from smftools.machine_learning import calculate_relative_risk_on_activity`
# will raise AttributeError. Presumably they live in a submodule and were
# meant to be re-exported here — confirm and either import them or trim
# this list.
__all__ = [
    "calculate_relative_risk_on_activity",
    "evaluate_models_by_subgroup",
    "prepare_melted_model_data",
]
@@ -0,0 +1,2 @@
1
+ from .anndata_data_module import AnnDataModule, build_anndata_loader
2
+ from .preprocessing import random_fill_nans
@@ -0,0 +1,234 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
3
+ import pytorch_lightning as pl
4
+ import numpy as np
5
+ import pandas as pd
6
+ from .preprocessing import random_fill_nans
7
+ from sklearn.utils.class_weight import compute_class_weight
8
+
9
+
10
+ class AnnDataDataset(Dataset):
11
+ """
12
+ Generic PyTorch Dataset from AnnData.
13
+ """
14
+ def __init__(self, adata, tensor_source="X", tensor_key=None, label_col=None, window_start=None, window_size=None):
15
+ self.adata = adata
16
+ self.tensor_source = tensor_source
17
+ self.tensor_key = tensor_key
18
+ self.label_col = label_col
19
+ self.window_start = window_start
20
+ self.window_size = window_size
21
+
22
+ if tensor_source == "X":
23
+ X = adata.X
24
+ elif tensor_source == "layers":
25
+ assert tensor_key in adata.layers
26
+ X = adata.layers[tensor_key]
27
+ elif tensor_source == "obsm":
28
+ assert tensor_key in adata.obsm
29
+ X = adata.obsm[tensor_key]
30
+ else:
31
+ raise ValueError(f"Invalid tensor_source: {tensor_source}")
32
+
33
+ if self.window_start is not None and self.window_size is not None:
34
+ X = X[:, self.window_start : self.window_start + self.window_size]
35
+
36
+ X = random_fill_nans(X)
37
+
38
+ self.X_tensor = torch.tensor(X, dtype=torch.float32)
39
+
40
+ if label_col is not None:
41
+ y = adata.obs[label_col]
42
+ if y.dtype.name == 'category':
43
+ y = y.cat.codes
44
+ self.y_tensor = torch.tensor(y.values, dtype=torch.long)
45
+ else:
46
+ self.y_tensor = None
47
+
48
+ def numpy(self, indices):
49
+ return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()
50
+
51
+ def __len__(self):
52
+ return len(self.X_tensor)
53
+
54
+ def __getitem__(self, idx):
55
+ x = self.X_tensor[idx]
56
+ if self.y_tensor is not None:
57
+ y = self.y_tensor[idx]
58
+ return x, y
59
+ else:
60
+ return (x,)
61
+
62
+
63
def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
                  random_seed=42, split_col="train_val_test_split",
                  load_existing_split=False, split_save_path=None):
    """
    Split ``dataset`` into train/val/test ``Subset``s and record the
    assignment in ``adata.obs[split_col]``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs[split_col]`` records the per-row assignment.
    dataset : torch Dataset
        Dataset to split; must align row-for-row with ``adata.obs``.
    train_frac, val_frac, test_frac : float
        Fractions of rows per split. ``test_frac`` is implicit: everything
        not assigned to train/val becomes "test".
    random_seed : int
        Seed for the shuffle (uses the global numpy RNG, matching the
        historical splits).
    split_col : str
        Column name used to record/read the split.
    load_existing_split : bool
        Reuse ``adata.obs[split_col]`` if present, else load from
        ``split_save_path``; raises ``ValueError`` if neither exists.
    split_save_path : str, optional
        CSV path to persist (or load) the split assignment.

    Returns
    -------
    (train_set, val_set, test_set) : tuple of torch Subset
    """
    total_len = len(dataset)

    if load_existing_split:
        if split_col in adata.obs:
            pass  # reuse the split already recorded on adata
        elif split_save_path:
            split_df = pd.read_csv(split_save_path, index_col=0)
            # Align by obs_names so row order in the CSV doesn't matter.
            adata.obs[split_col] = split_df.loc[adata.obs_names][split_col].values
        else:
            raise ValueError("No existing split column found and no file provided.")
    else:
        indices = np.arange(total_len)
        np.random.seed(random_seed)
        np.random.shuffle(indices)

        n_train = int(train_frac * total_len)
        n_val = int(val_frac * total_len)
        # Remaining rows default to "test" below; no explicit count needed.

        split_array = np.full(total_len, "test", dtype=object)
        split_array[indices[:n_train]] = "train"
        split_array[indices[n_train:n_train + n_val]] = "val"
        adata.obs[split_col] = split_array

        if split_save_path:
            adata.obs[[split_col]].to_csv(split_save_path)

    split_labels = adata.obs[split_col].values
    train_indices = np.where(split_labels == "train")[0]
    val_indices = np.where(split_labels == "val")[0]
    test_indices = np.where(split_labels == "test")[0]

    train_set = Subset(dataset, train_indices)
    val_set = Subset(dataset, val_indices)
    test_set = Subset(dataset, test_indices)

    return train_set, val_set, test_set
106
+
107
class AnnDataModule(pl.LightningDataModule):
    """
    Unified LightningDataModule wrapping AnnDataDataset + train/val/test
    splitting, with the split recorded in ``adata.obs``.

    In ``inference_mode`` no split is made; the whole dataset is exposed
    through ``predict_dataloader`` / ``inference_numpy``.
    """
    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
                 batch_size=64, train_frac=0.6, val_frac=0.1, test_frac=0.3, random_seed=42,
                 inference_mode=False, split_col="train_val_test_split", split_save_path=None,
                 load_existing_split=False, window_start=None, window_size=None,
                 num_workers=None, persistent_workers=False):
        super().__init__()
        self.adata = adata
        self.tensor_source = tensor_source
        self.tensor_key = tensor_key
        self.label_col = label_col
        self.batch_size = batch_size
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.random_seed = random_seed
        self.inference_mode = inference_mode
        self.split_col = split_col
        self.split_save_path = split_save_path
        self.load_existing_split = load_existing_split
        self.var_names = adata.var_names.copy()  # kept for downstream feature naming
        self.window_start = window_start
        self.window_size = window_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers

    def setup(self, stage=None):
        # Labels are omitted in inference mode so the dataset yields (x,) tuples.
        dataset = AnnDataDataset(self.adata, self.tensor_source, self.tensor_key,
                                 None if self.inference_mode else self.label_col,
                                 window_start=self.window_start, window_size=self.window_size)

        if self.inference_mode:
            self.infer_dataset = dataset
            return

        self.train_set, self.val_set, self.test_set = split_dataset(
            self.adata, dataset, train_frac=self.train_frac, val_frac=self.val_frac,
            test_frac=self.test_frac, random_seed=self.random_seed,
            split_col=self.split_col, split_save_path=self.split_save_path,
            load_existing_split=self.load_existing_split
        )

    def _make_loader(self, subset, shuffle=False):
        """Build a DataLoader; worker options only apply when num_workers is set."""
        if self.num_workers:
            return DataLoader(subset, batch_size=self.batch_size, shuffle=shuffle,
                              num_workers=self.num_workers,
                              persistent_workers=self.persistent_workers)
        return DataLoader(subset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        # Bug fix: previously returned a loader over self.train_set when
        # num_workers was unset.
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        # Bug fix: previously returned a loader over self.train_set when
        # num_workers was unset.
        return self._make_loader(self.test_set)

    def predict_dataloader(self):
        if not self.inference_mode:
            raise RuntimeError("Only valid in inference mode")
        return DataLoader(self.infer_dataset, batch_size=self.batch_size)

    def compute_class_weights(self):
        """Balanced class weights over the *training* split only."""
        train_indices = self.train_set.indices  # indices of the training rows
        # train_set is a Subset, so labels live on the underlying dataset.
        y_all = self.train_set.dataset.y_tensor
        y_train = y_all[train_indices].cpu().numpy()

        class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
        return torch.tensor(class_weights, dtype=torch.float32)

    def inference_numpy(self):
        """
        Return inference data as numpy for use in sklearn inference.
        """
        if not self.inference_mode:
            raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
        return self.infer_dataset.X_tensor.numpy()

    def to_numpy(self):
        """
        Move the AnnDataModule tensors into numpy arrays.

        Returns (train_X, train_y, val_X, val_y, test_X, test_y) when
        splitting, or just X in inference mode.
        """
        if not self.inference_mode:
            train_X, train_y = self.train_set.dataset.numpy(self.train_set.indices)
            val_X, val_y = self.val_set.dataset.numpy(self.val_set.indices)
            test_X, test_y = self.test_set.dataset.numpy(self.test_set.indices)
            return train_X, train_y, val_X, val_y, test_X, test_y
        return self.inference_numpy()
202
+
203
+
204
def build_anndata_loader(
    adata, tensor_source="X", tensor_key=None, label_col=None, train_frac=0.6, val_frac=0.1,
    test_frac=0.3, random_seed=42, batch_size=64, lightning=True, inference_mode=False,
    split_col="train_val_test_split", split_save_path=None, load_existing_split=False
):
    """
    Unified pipeline for both Lightning and raw PyTorch.

    The Lightning loader works for both Lightning and the Sklearn wrapper.
    Set ``lightning=False`` to get plain DataLoaders for base PyTorch or
    base sklearn models.

    Returns
    -------
    AnnDataModule when ``lightning`` is True; otherwise a single DataLoader
    (inference mode) or a (train, val, test) DataLoader triple.
    """
    if lightning:
        return AnnDataModule(
            adata, tensor_source=tensor_source, tensor_key=tensor_key, label_col=label_col,
            batch_size=batch_size, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
            random_seed=random_seed, inference_mode=inference_mode,
            split_col=split_col, split_save_path=split_save_path, load_existing_split=load_existing_split
        )

    # Raw-PyTorch path. (A dead `var_names` copy was removed here.)
    dataset = AnnDataDataset(adata, tensor_source, tensor_key, None if inference_mode else label_col)
    if inference_mode:
        return DataLoader(dataset, batch_size=batch_size)

    train_set, val_set, test_set = split_dataset(
        adata, dataset, train_frac, val_frac, test_frac, random_seed,
        split_col, split_save_path, load_existing_split
    )
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, val_loader, test_loader
@@ -0,0 +1,6 @@
1
+ import numpy as np
2
+
3
def random_fill_nans(X):
    """
    Return a copy of ``X`` with NaN entries replaced by uniform random
    values in [0, 1).

    The input is left untouched. Previously this wrote in place, which
    silently mutated the caller's array — including AnnData matrices
    reached through numpy views.
    """
    X = np.array(X, copy=True)
    nan_mask = np.isnan(X)
    X[nan_mask] = np.random.rand(int(nan_mask.sum()))
    return X
@@ -0,0 +1,2 @@
1
+ from .evaluators import ModelEvaluator, PostInferenceModelEvaluator
2
+ from .eval_utils import flatten_sliding_window_results
@@ -0,0 +1,31 @@
1
+ import pandas as pd
2
+
3
def flatten_sliding_window_results(results_dict):
    """
    Flatten nested sliding window results into pandas DataFrame.

    Expects structure:
        results[model_name][window_size][window_center]['metrics'][metric_name]
    """
    # One flat record per (model, window_size, center), carrying every metric.
    records = [
        {'model': model_name,
         'window_size': window_size,
         'center_var': center_var,
         **result['metrics']}
        for model_name, by_size in results_dict.items()
        for window_size, by_center in by_size.items()
        for center_var, result in by_center.items()
    ]

    df = pd.DataFrame.from_records(records)

    # Numeric centers sort and plot naturally; unparseable ones become NaN.
    df['center_var'] = pd.to_numeric(df['center_var'], errors='coerce')
    return df.sort_values(['model', 'window_size', 'center_var'])
@@ -0,0 +1,223 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+
5
+ from sklearn.metrics import (
6
+ roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
7
+ )
8
+
9
class ModelEvaluator:
    """
    A model evaluator for consolidating Sklearn and Lightning model
    evaluation metrics on testing data.

    Models are registered via :meth:`add_model`; each must expose
    ``test_f1``, ``test_roc_auc``, ``test_pr_auc``, ``test_pr_curve``,
    ``test_roc_curve``, ``test_num_pos`` and ``test_pos_freq``.
    """
    def __init__(self):
        self.results = []     # one metrics dict per registered model
        self.pos_freq = None  # positive-class frequency of the shared test set
        self.num_pos = None   # positive-instance count of the shared test set

    def add_model(self, name, model, is_torch=True):
        """
        Add a trained model with its evaluation metrics.

        Both Torch and sklearn wrappers expose the same ``test_*``
        attributes, so a single code path covers both; ``is_torch`` is
        retained only for backward compatibility (the two branches were
        previously byte-identical duplicates).
        """
        entry = {
            'name': name,
            'f1': model.test_f1,
            'auc': model.test_roc_auc,
            'pr_auc': model.test_pr_auc,
            # Normalize PR-AUC by the positive frequency (lift over chance).
            'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
            'pr_curve': model.test_pr_curve,
            'roc_curve': model.test_roc_curve,
            'num_pos': model.test_num_pos,
            'pos_freq': model.test_pos_freq,
        }
        self.results.append(entry)

        # Remember the test-set stats from the first model added; all models
        # are assumed to share the same test set.
        if not self.pos_freq:
            self.pos_freq = entry['pos_freq']
            self.num_pos = entry['num_pos']

    def get_metrics_dataframe(self):
        """
        Return all scalar metrics as a pandas DataFrame (curves excluded).
        """
        df = pd.DataFrame(self.results)
        return df[['name', 'f1', 'auc', 'pr_auc', 'pr_auc_norm', 'num_pos', 'pos_freq']]

    def plot_all_curves(self):
        """
        Plot unified ROC and PR curves across all registered models.
        """
        plt.figure(figsize=(12, 5))

        # ROC panel
        plt.subplot(1, 2, 1)
        for res in self.results:
            fpr, tpr = res['roc_curve']
            plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")  # chance diagonal
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.ylim(0, 1.05)
        plt.title(f"ROC Curves - {self.num_pos} positive instances")
        plt.legend()

        # PR panel
        plt.subplot(1, 2, 2)
        for res in self.results:
            rc, pr = res['pr_curve']
            plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.ylim(0, 1.05)
        # Baseline precision for a random classifier.
        plt.axhline(self.pos_freq, linestyle='--', color='grey')
        plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
        plt.legend()

        plt.tight_layout()
        plt.show()
92
+
93
class PostInferenceModelEvaluator:
    """
    Evaluate models from predictions already written onto an AnnData object
    (by the inference helpers), keyed in ``self.results`` by
    ``f"{model_name}_{label_col}"``.
    """
    def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
        """
        Initialize evaluator.

        Parameters:
        -----------
        adata : AnnData
            The annotated dataset where predictions are stored in obs/obsm.
        models : dict
            Dictionary of models: {model_name: model_instance}.
            Supports TorchClassifierWrapper and SklearnModelWrapper.
        target_eval_freq : float, optional
            If set, subsample the evaluation rows so the focus-class
            frequency approaches this value.
        max_eval_positive : int, optional
            Cap on the number of positive instances kept when subsampling.
        """
        self.adata = adata
        self.models = models
        self.target_eval_freq = target_eval_freq
        self.max_eval_positive = max_eval_positive
        self.results = {}  # {f"{model_name}_{label_col}": metrics dict}

    def evaluate_all(self):
        """
        Evaluate all models and store results.
        """
        for name, model in self.models.items():
            print(f"Evaluating {name}...")
            label_col = model.label_col
            full_prefix = f"{name}_{label_col}"
            self.results[full_prefix] = self._evaluate_model(name, model)

    def _evaluate_model(self, model_name, model):
        """
        Evaluate one model and return metrics.

        Reads ground truth from ``adata.obs[label_col]`` (categorical codes)
        and predictions/probabilities from the obs/obsm keys written by the
        inference helpers under ``{model_name}_{label_col}_*``.
        """
        label_col = model.label_col
        num_classes = model.num_classes
        class_names = model.class_names  # NOTE(review): read but unused here
        focus_class = model.focus_class

        full_prefix = f"{model_name}_{label_col}"

        # Extract ground truth + predictions
        y_true = self.adata.obs[label_col].cat.codes.to_numpy()
        y_pred = self.adata.obs[f"{full_prefix}_pred"].to_numpy()
        probs_all = self.adata.obsm[f"{full_prefix}_pred_prob_all"]

        # One-vs-rest labels for the focus class.
        binary_focus = (y_true == focus_class).astype(int)

        # OPTIONAL SUBSAMPLING
        if self.target_eval_freq is not None:
            indices = self._subsample_for_fixed_positive_frequency(
                binary_focus, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
            )
            y_true = y_true[indices]
            y_pred = y_pred[indices]
            probs_all = probs_all[indices]
            binary_focus = (y_true == focus_class).astype(int)

        acc = np.mean(y_true == y_pred)

        if num_classes == 2:
            # Binary: all threshold metrics computed on the focus class.
            focus_probs = probs_all[:, focus_class]
            f1 = f1_score(binary_focus, (y_pred == focus_class).astype(int))
            roc_auc = roc_auc_score(binary_focus, focus_probs)
            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
            pr_auc = auc(rc, pr)
            pos_freq = binary_focus.mean()
            # Lift of PR-AUC over the chance baseline (positive frequency).
            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
        else:
            # Multiclass: macro f1/roc_auc, but PR/ROC curves are still
            # one-vs-rest for the focus class only.
            f1 = f1_score(y_true, y_pred, average="macro")
            roc_auc = roc_auc_score(y_true, probs_all, multi_class="ovr", average="macro")
            focus_probs = probs_all[:, focus_class]
            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
            pr_auc = auc(rc, pr)
            pos_freq = binary_focus.mean()
            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan

        cm = confusion_matrix(y_true, y_pred)

        metrics = {
            "accuracy": acc,
            "f1": f1,
            "roc_auc": roc_auc,
            "pr_auc": pr_auc,
            "pr_auc_norm": pr_auc_norm,
            "pos_freq": pos_freq,
            "confusion_matrix": cm,
            "pr_rc_curve": (pr, rc),
            # NOTE(review): stored as (tpr, fpr) — reversed relative to the
            # (fpr, tpr) ordering ModelEvaluator.plot_all_curves unpacks.
            # Confirm which ordering downstream plotting expects.
            "roc_curve": (tpr, fpr)
        }

        return metrics

    def _subsample_for_fixed_positive_frequency(self, binary_labels, target_freq=0.3, max_positive=None):
        # Returns shuffled row indices whose positive frequency approximates
        # target_freq. Uses the *global* numpy RNG without seeding, so
        # repeated evaluations are not reproducible — seed externally if
        # determinism matters.
        pos_idx = np.where(binary_labels == 1)[0]
        neg_idx = np.where(binary_labels == 0)[0]

        max_pos = len(pos_idx)
        max_neg = len(neg_idx)

        # Clamp the request to what's achievable with the available positives.
        max_possible_freq = max_pos / (max_pos + max_neg)
        if target_freq > max_possible_freq:
            target_freq = max_possible_freq

        # Positives needed to hit target_freq if all negatives were kept.
        num_pos_target = int(target_freq * max_neg / (1 - target_freq))
        num_pos_target = min(num_pos_target, max_pos)
        if max_positive is not None:
            num_pos_target = min(num_pos_target, max_positive)

        # Negatives sized to keep the ratio once positives are fixed.
        num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
        num_neg_target = min(num_neg_target, max_neg)

        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
        np.random.shuffle(sampled_idx)
        return sampled_idx

    def to_dataframe(self):
        """
        Convert results to pandas DataFrame (excluding confusion matrices).
        """
        records = []
        for model_name, metrics in self.results.items():
            row = {"model": model_name}
            for k, v in metrics.items():
                # Curves and matrices are not scalar; keep them out of the table.
                if k not in ["confusion_matrix", "pr_rc_curve", "roc_curve"]:
                    row[k] = v
            records.append(row)
        return pd.DataFrame(records)
@@ -0,0 +1,3 @@
1
+ from .lightning_inference import run_lightning_inference
2
+ from .sliding_window_inference import sliding_window_inference
3
+ from .sklearn_inference import run_sklearn_inference
@@ -0,0 +1,27 @@
1
+ import pandas as pd
2
+
3
def annotate_split_column(adata, model, split_col="split"):
    """
    Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.
    """
    # Sets give O(1) membership checks against the model's recorded names.
    train_names = set(model.train_obs_names)
    val_names = set(model.val_obs_names)
    test_names = set(model.test_obs_names)

    def _bucket(obs_name):
        # Map one observation name onto its split bucket.
        if obs_name in train_names:
            return "training"
        if obs_name in val_names:
            return "validation"
        if obs_name in test_names:
            return "testing"
        return "new"

    labels = [_bucket(obs_name) for obs_name in adata.obs_names]

    # Store as a fixed-category categorical so downstream grouping is stable.
    adata.obs[split_col] = pd.Categorical(
        labels, categories=["training", "validation", "testing", "new"]
    )

    print(f"Annotated {split_col} column with training/validation/testing/new status.")
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import pandas as pd
3
+ import numpy as np
4
+ from pytorch_lightning import Trainer
5
+ from .inference_utils import annotate_split_column
6
+
7
def run_lightning_inference(
    adata,
    model,
    datamodule,
    trainer,
    prefix="model",
    devices=1
):
    """
    Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).

    Results are written back onto ``adata`` under ``{prefix}_{label_col}``:
      - obs ``_pred``            : predicted class index
      - obs ``_pred_label``      : predicted class name (categorical)
      - obs ``_pred_prob``       : probability of the predicted class
      - obs ``_prob_<class>``    : per-class probabilities
      - obsm ``_pred_prob_all``  : full (N, C) probability matrix
    Also records which rows were train/val/test/new in
    ``obs[f"{prefix}_training_split"]``.

    Parameters
    ----------
    adata : AnnData to annotate in place.
    model : wrapper exposing label_col, num_classes, class_names and
        train/val/test obs-name lists.
    datamodule : LightningDataModule in inference mode.
    trainer : pytorch_lightning Trainer used to run prediction. (A dead
        accelerator/device-selection block was removed: the trainer is
        supplied by the caller, so ``devices`` is unused and kept only for
        backward compatibility.)
    """
    label_col = model.label_col
    num_classes = model.num_classes
    class_labels = model.class_names

    # Record which reads the model was trained/validated/tested on.
    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")

    # trainer.predict yields a list of (preds, probs) batch tuples; the class
    # index is re-derived from probs below, so the preds tensors are unused.
    outputs = trainer.predict(model, datamodule=datamodule)
    _preds_list, probs_list = zip(*outputs)
    probs = torch.cat(probs_list, dim=0).cpu().numpy()

    # Handle binary vs multiclass probability formats.
    if num_classes == 2:
        # Binary: probs is an (N,) sigmoid vector for the positive class.
        pred_class_idx = (probs >= 0.5).astype(int)
        probs_all = np.vstack([1 - probs, probs]).T  # shape (N, 2)
    else:
        # Multiclass: probs is already (N, C).
        pred_class_idx = probs.argmax(axis=1)
        probs_all = probs
    pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]

    pred_class_labels = [class_labels[i] for i in pred_class_idx]

    full_prefix = f"{prefix}_{label_col}"

    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]

    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

    print(f"Inference complete: stored under prefix '{full_prefix}'")
@@ -0,0 +1,55 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ from .inference_utils import annotate_split_column
4
+
5
+
6
def run_sklearn_inference(
    adata,
    model,
    datamodule,
    prefix="model"
):
    """
    Run inference on AnnData using SklearnModelWrapper.

    Results are written back onto ``adata`` under ``{prefix}_{label_col}``:
      - obs ``_pred``            : predicted class index
      - obs ``_pred_label``      : predicted class name (categorical)
      - obs ``_pred_prob``       : probability of the predicted class
      - obs ``_prob_<class>``    : per-class probabilities
      - obsm ``_pred_prob_all``  : full (N, C) probability matrix
    Also records which rows were train/val/test/new in
    ``obs[f"{prefix}_training_split"]``.
    """
    label_col = model.label_col
    class_labels = model.class_names

    # Record which reads the model was trained/validated/tested on.
    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")

    # Materialize the inference matrix from the (inference-mode) datamodule.
    datamodule.setup()
    X_infer = datamodule.to_numpy()

    preds = model.predict(X_infer)
    probs = model.predict_proba(X_infer)

    # predict_proba returns an (N, C) matrix for binary and multiclass alike,
    # so one code path covers both (the original duplicated it verbatim).
    pred_class_idx = preds
    probs_all = probs
    pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]

    pred_class_labels = [class_labels[i] for i in pred_class_idx]

    full_prefix = f"{prefix}_{label_col}"

    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]

    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

    print(f"Inference complete: stored under prefix '{full_prefix}'")