smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/evaluation/evaluators.py (new file)
@@ -0,0 +1,223 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
)

class ModelEvaluator:
    """
    Consolidates sklearn and Lightning model evaluation metrics on test data.
    """
    def __init__(self):
        self.results = []
        self.pos_freq = None
        self.num_pos = None

    def add_model(self, name, model, is_torch=True):
        """
        Add a trained model with its evaluation metrics.
        """
        # Sklearn and Lightning wrappers expose the same test_* attributes,
        # so both model types are collected identically (is_torch is kept for API symmetry).
        entry = {
            'name': name,
            'f1': model.test_f1,
            'auc': model.test_roc_auc,
            'pr_auc': model.test_pr_auc,
            'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
            'pr_curve': model.test_pr_curve,
            'roc_curve': model.test_roc_curve,
            'num_pos': model.test_num_pos,
            'pos_freq': model.test_pos_freq
        }

        self.results.append(entry)

        # Record the positive frequency/count from the first model added
        if self.pos_freq is None:
            self.pos_freq = entry['pos_freq']
            self.num_pos = entry['num_pos']

    def get_metrics_dataframe(self):
        """
        Return all metrics as a pandas DataFrame.
        """
        df = pd.DataFrame(self.results)
        return df[['name', 'f1', 'auc', 'pr_auc', 'pr_auc_norm', 'num_pos', 'pos_freq']]

    def plot_all_curves(self):
        """
        Plot unified ROC and PR curves across all models.
        """
        plt.figure(figsize=(12, 5))

        # ROC
        plt.subplot(1, 2, 1)
        for res in self.results:
            fpr, tpr = res['roc_curve']
            plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")  # chance diagonal
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.ylim(0, 1.05)
        plt.title(f"ROC Curves - {self.num_pos} positive instances")
        plt.legend()

        # PR
        plt.subplot(1, 2, 2)
        for res in self.results:
            rc, pr = res['pr_curve']
            plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.ylim(0, 1.05)
        plt.axhline(self.pos_freq, linestyle='--', color='grey')  # chance-level precision baseline
        plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
        plt.legend()

        plt.tight_layout()
        plt.show()


class PostInferenceModelEvaluator:
    def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
        """
        Initialize the evaluator.

        Parameters
        ----------
        adata : AnnData
            The annotated dataset where predictions are stored in obs/obsm.
        models : dict
            Dictionary of models: {model_name: model_instance}.
            Supports TorchClassifierWrapper and SklearnModelWrapper.
        target_eval_freq : float, optional
            If set, subsample so positives occur at this frequency during evaluation.
        max_eval_positive : int, optional
            If set, cap the number of positive instances used for evaluation.
        """
        self.adata = adata
        self.models = models
        self.target_eval_freq = target_eval_freq
        self.max_eval_positive = max_eval_positive
        self.results = {}

    def evaluate_all(self):
        """
        Evaluate all models and store results.
        """
        for name, model in self.models.items():
            print(f"Evaluating {name}...")
            label_col = model.label_col
            full_prefix = f"{name}_{label_col}"
            self.results[full_prefix] = self._evaluate_model(name, model)

    def _evaluate_model(self, model_name, model):
        """
        Evaluate one model and return its metrics.
        """
        label_col = model.label_col
        num_classes = model.num_classes
        class_names = model.class_names
        focus_class = model.focus_class

        full_prefix = f"{model_name}_{label_col}"

        # Extract ground truth + predictions
        y_true = self.adata.obs[label_col].cat.codes.to_numpy()
        y_pred = self.adata.obs[f"{full_prefix}_pred"].to_numpy()
        probs_all = self.adata.obsm[f"{full_prefix}_pred_prob_all"]

        binary_focus = (y_true == focus_class).astype(int)

        # Optional subsampling to a fixed positive frequency
        if self.target_eval_freq is not None:
            indices = self._subsample_for_fixed_positive_frequency(
                binary_focus, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
            )
            y_true = y_true[indices]
            y_pred = y_pred[indices]
            probs_all = probs_all[indices]
            binary_focus = (y_true == focus_class).astype(int)

        acc = np.mean(y_true == y_pred)

        if num_classes == 2:
            focus_probs = probs_all[:, focus_class]
            f1 = f1_score(binary_focus, (y_pred == focus_class).astype(int))
            roc_auc = roc_auc_score(binary_focus, focus_probs)
            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
            pr_auc = auc(rc, pr)
            pos_freq = binary_focus.mean()
            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
        else:
            # Macro-averaged scores across classes; curves are one-vs-rest for the focus class
            f1 = f1_score(y_true, y_pred, average="macro")
            roc_auc = roc_auc_score(y_true, probs_all, multi_class="ovr", average="macro")
            focus_probs = probs_all[:, focus_class]
            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
            pr_auc = auc(rc, pr)
            pos_freq = binary_focus.mean()
            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan

        cm = confusion_matrix(y_true, y_pred)

        metrics = {
            "accuracy": acc,
            "f1": f1,
            "roc_auc": roc_auc,
            "pr_auc": pr_auc,
            "pr_auc_norm": pr_auc_norm,
            "pos_freq": pos_freq,
            "confusion_matrix": cm,
            "pr_rc_curve": (pr, rc),
            "roc_curve": (tpr, fpr)
        }

        return metrics

    def _subsample_for_fixed_positive_frequency(self, binary_labels, target_freq=0.3, max_positive=None):
        pos_idx = np.where(binary_labels == 1)[0]
        neg_idx = np.where(binary_labels == 0)[0]

        max_pos = len(pos_idx)
        max_neg = len(neg_idx)

        # Cap the target frequency at what the data can support
        max_possible_freq = max_pos / (max_pos + max_neg)
        if target_freq > max_possible_freq:
            target_freq = max_possible_freq

        # Solve num_pos / (num_pos + max_neg) = target_freq for the positive count
        num_pos_target = int(target_freq * max_neg / (1 - target_freq))
        num_pos_target = min(num_pos_target, max_pos)
        if max_positive is not None:
            num_pos_target = min(num_pos_target, max_positive)

        num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
        num_neg_target = min(num_neg_target, max_neg)

        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
        np.random.shuffle(sampled_idx)
        return sampled_idx

    def to_dataframe(self):
        """
        Convert results to a pandas DataFrame (excluding confusion matrices and curves).
        """
        records = []
        for model_name, metrics in self.results.items():
            row = {"model": model_name}
            for k, v in metrics.items():
                if k not in ["confusion_matrix", "pr_rc_curve", "roc_curve"]:
                    row[k] = v
            records.append(row)
        return pd.DataFrame(records)
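For orientation, a minimal usage sketch of `ModelEvaluator` (hypothetical names; the wrappers are assumed to already carry the `test_*` attributes read by `add_model`, and the import path assumes `evaluation/__init__.py` re-exports the class):

# Hypothetical usage sketch; cnn_wrapper / logreg_wrapper are illustrative stand-ins
from smftools.machine_learning.evaluation import ModelEvaluator

evaluator = ModelEvaluator()
evaluator.add_model("cnn", cnn_wrapper, is_torch=True)         # Lightning wrapper
evaluator.add_model("logreg", logreg_wrapper, is_torch=False)  # sklearn wrapper

print(evaluator.get_metrics_dataframe())  # one row per model: f1, auc, pr_auc, ...
evaluator.plot_all_curves()               # overlaid ROC and PR curves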
smftools/machine_learning/inference/__init__.py (new file)
@@ -0,0 +1,3 @@
from .lightning_inference import run_lightning_inference
from .sliding_window_inference import sliding_window_inference
from .sklearn_inference import run_sklearn_inference
smftools/machine_learning/inference/inference_utils.py (new file)
@@ -0,0 +1,27 @@
import pandas as pd

def annotate_split_column(adata, model, split_col="split"):
    """
    Annotate adata.obs with training/validation/testing/new labels based on the model's stored obs_names.
    """
    # Build sets for fast membership lookup
    train_set = set(model.train_obs_names)
    val_set = set(model.val_obs_names)
    test_set = set(model.test_obs_names)

    # Assign a split label to every observation
    split_labels = []
    for obs in adata.obs_names:
        if obs in train_set:
            split_labels.append("training")
        elif obs in val_set:
            split_labels.append("validation")
        elif obs in test_set:
            split_labels.append("testing")
        else:
            split_labels.append("new")

    # Store in AnnData.obs as a categorical column
    adata.obs[split_col] = pd.Categorical(split_labels, categories=["training", "validation", "testing", "new"])

    print(f"Annotated {split_col} column with training/validation/testing/new status.")
smftools/machine_learning/inference/lightning_inference.py (new file)
@@ -0,0 +1,68 @@
import torch
import pandas as pd
import numpy as np
from pytorch_lightning import Trainer
from .inference_utils import annotate_split_column

def run_lightning_inference(
    adata,
    model,
    datamodule,
    trainer,
    prefix="model",
    devices=1
):
    """
    Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).
    """

    # Device detection (informational; the supplied trainer controls actual placement)
    if torch.cuda.is_available():
        accelerator = "gpu"
    elif torch.backends.mps.is_available():
        accelerator = "mps"
        devices = 1
    else:
        accelerator = "cpu"
        devices = 1

    label_col = model.label_col
    num_classes = model.num_classes
    class_labels = model.class_names
    focus_class = model.focus_class
    focus_class_name = model.focus_class_name

    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")

    # Run predictions
    outputs = trainer.predict(model, datamodule=datamodule)

    preds_list, probs_list = zip(*outputs)
    preds = torch.cat(preds_list, dim=0).cpu().numpy()
    probs = torch.cat(probs_list, dim=0).cpu().numpy()

    # Handle binary vs multiclass probability formats
    if num_classes == 2:
        # probs shape: (N,) from sigmoid; expand to an (N, 2) matrix
        pred_class_idx = (probs >= 0.5).astype(int)
        probs_all = np.vstack([1 - probs, probs]).T  # shape (N, 2)
        pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]
    else:
        # probs shape: (N, num_classes) from softmax
        pred_class_idx = probs.argmax(axis=1)
        probs_all = probs
        pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]

    pred_class_labels = [class_labels[i] for i in pred_class_idx]

    full_prefix = f"{prefix}_{label_col}"

    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]

    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

    print(f"Inference complete: stored under prefix '{full_prefix}'")
smftools/machine_learning/inference/sklearn_inference.py (new file)
@@ -0,0 +1,55 @@
import pandas as pd
import numpy as np
from .inference_utils import annotate_split_column


def run_sklearn_inference(
    adata,
    model,
    datamodule,
    prefix="model"
):
    """
    Run inference on AnnData using SklearnModelWrapper.
    """

    label_col = model.label_col
    num_classes = model.num_classes
    class_labels = model.class_names
    focus_class_name = model.focus_class_name

    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")

    datamodule.setup()

    X_infer = datamodule.to_numpy()

    # Run predictions
    preds = model.predict(X_infer)
    probs = model.predict_proba(X_infer)

    # predict_proba returns an (N, num_classes) matrix in both the binary and
    # multiclass cases, so the same handling applies to both
    pred_class_idx = preds
    probs_all = probs
    pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]

    pred_class_labels = [class_labels[i] for i in pred_class_idx]

    full_prefix = f"{prefix}_{label_col}"

    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

    for i, class_name in enumerate(class_labels):
        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]

    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

    print(f"Inference complete: stored under prefix '{full_prefix}'")
smftools/machine_learning/inference/sliding_window_inference.py (new file)
@@ -0,0 +1,114 @@
from ..data import AnnDataModule
from ..evaluation import PostInferenceModelEvaluator
from .lightning_inference import run_lightning_inference
from .sklearn_inference import run_sklearn_inference

def sliding_window_inference(
    adata,
    trained_results,
    tensor_source='X',
    tensor_key=None,
    label_col='activity_status',
    batch_size=64,
    cleanup=False,
    target_eval_freq=None,
    max_eval_positive=None
):
    """
    Apply trained sliding-window models to an AnnData object (Lightning or sklearn),
    evaluate model performance, and return a DataFrame of metrics.
    Optionally remove the appended inference columns from AnnData to clean up the obs namespace.
    """
    ## Inference using trained models
    for model_name, model_dict in trained_results.items():
        for window_size, window_data in model_dict.items():
            for center_varname, run in window_data.items():
                print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")

                # Derive the window start from the center var name
                center_idx = adata.var_names.get_loc(center_varname)
                window_start = center_idx - window_size // 2

                # Build a datamodule for this window
                datamodule = AnnDataModule(
                    adata,
                    tensor_source=tensor_source,
                    tensor_key=tensor_key,
                    label_col=label_col,
                    batch_size=batch_size,
                    window_start=window_start,
                    window_size=window_size,
                    inference_mode=True
                )
                datamodule.setup()

                # Extract the model + detect its type
                model = run['model']

                # Lightning runs store a trainer alongside the model; sklearn runs do not
                if 'trainer' in run:
                    trainer = run['trainer']
                    run_lightning_inference(
                        adata,
                        model=model,
                        datamodule=datamodule,
                        trainer=trainer,
                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
                    )
                else:
                    run_sklearn_inference(
                        adata,
                        model=model,
                        datamodule=datamodule,
                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
                    )

    print("Inference complete across all models.")

    ## Post-inference model evaluation
    model_wrappers = {}

    for model_name, model_dict in trained_results.items():
        for window_size, window_data in model_dict.items():
            for center_varname, run in window_data.items():
                # Reconstruct the prefix string used during inference;
                # the full prefix keeps entries unique
                key = f"{model_name}_w{window_size}_c{center_varname}"
                model_wrappers[key] = run['model']

    # Run the evaluator
    evaluator = PostInferenceModelEvaluator(
        adata, model_wrappers, target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive
    )
    evaluator.evaluate_all()

    # Collect results
    df = evaluator.to_dataframe()

    # Parse model name, window size, and center position back out of the prefix
    # (centers are assumed to be numeric var names)
    df[['model_name', 'window_size', 'center']] = df['model'].str.extract(
        rf'(\w+)_w(\d+)_c(\d+)_{label_col}'
    )

    # Cast window_size and center to integers for plotting
    df['window_size'] = df['window_size'].astype(int)
    df['center'] = df['center'].astype(int)

    ## Optional cleanup
    if cleanup:
        prefixes = [f"{model_name}_w{window_size}_c{center_varname}"
                    for model_name, model_dict in trained_results.items()
                    for window_size, window_data in model_dict.items()
                    for center_varname in window_data.keys()]

        for prefix in prefixes:
            # Remove matching obs columns
            to_remove = [col for col in adata.obs.columns if col.startswith(prefix)]
            adata.obs.drop(columns=to_remove, inplace=True)

            # Remove obsm entries if any
            obsm_key = f"{prefix}_pred_prob_all"
            if obsm_key in adata.obsm:
                del adata.obsm[obsm_key]

        print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")

    return df
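A hedged end-to-end sketch: `trained_results` is assumed to be nested as {model_name: {window_size: {center_varname: run_dict}}}, with each run dict holding 'model' (and 'trainer' for Lightning runs), matching the structure this function iterates over:

# Hypothetical usage sketch
df = sliding_window_inference(
    adata,
    trained_results,
    label_col="activity_status",
    target_eval_freq=0.3,   # evaluate at a fixed 30% positive frequency
    cleanup=True,           # drop appended obs/obsm columns afterwards
)
print(df.sort_values(["model_name", "window_size", "center"]).head())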