smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
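A change that recurs throughout the module diffs below is that heavy optional dependencies (matplotlib, scikit-learn, torch) are no longer imported at module top level; each module now resolves them lazily through `smftools.optional_imports.require(...)`, a helper introduced in 0.3.0 (smftools/optional_imports.py, +31 lines). That file's contents are not shown in this diff, so the following is only a minimal sketch of how such a helper could behave, assuming it wraps `importlib.import_module` and points the user at the matching pip extra on failure; the real implementation may differ.

# Hypothetical sketch of a lazy-import helper in the spirit of
# smftools.optional_imports.require; the actual implementation is not shown here.
import importlib


def require(module_name: str, extra: str, purpose: str = ""):
    """Import module_name, or raise an ImportError naming the pip extra to install."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        hint = f" (needed for {purpose})" if purpose else ""
        raise ImportError(
            f"'{module_name}' is required{hint}. "
            f"Install it with: pip install 'smftools[{extra}]'"
        ) from exc


# Usage mirroring the call sites in the diffs below:
# plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")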
smftools/machine_learning/evaluation/evaluators.py
@@ -1,15 +1,26 @@
+ from __future__ import annotations
+
  import numpy as np
  import pandas as pd
- import matplotlib.pyplot as plt
 
- from sklearn.metrics import (
-     roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
- )
+ from smftools.optional_imports import require
+
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")
+ sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+ auc = sklearn_metrics.auc
+ confusion_matrix = sklearn_metrics.confusion_matrix
+ f1_score = sklearn_metrics.f1_score
+ precision_recall_curve = sklearn_metrics.precision_recall_curve
+ roc_auc_score = sklearn_metrics.roc_auc_score
+ roc_curve = sklearn_metrics.roc_curve
+
 
  class ModelEvaluator:
      """
      A model evaluator for consolidating Sklearn and Lightning model evaluation metrics on testing data
      """
+
      def __init__(self):
          self.results = []
          self.pos_freq = None
@@ -21,41 +32,45 @@ class ModelEvaluator:
          """
          if is_torch:
              entry = {
-                 'name': name,
-                 'f1': model.test_f1,
-                 'auc': model.test_roc_auc,
-                 'pr_auc': model.test_pr_auc,
-                 'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
-                 'pr_curve': model.test_pr_curve,
-                 'roc_curve': model.test_roc_curve,
-                 'num_pos': model.test_num_pos,
-                 'pos_freq': model.test_pos_freq
+                 "name": name,
+                 "f1": model.test_f1,
+                 "auc": model.test_roc_auc,
+                 "pr_auc": model.test_pr_auc,
+                 "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                 if model.test_pos_freq > 0
+                 else np.nan,
+                 "pr_curve": model.test_pr_curve,
+                 "roc_curve": model.test_roc_curve,
+                 "num_pos": model.test_num_pos,
+                 "pos_freq": model.test_pos_freq,
              }
          else:
              entry = {
-                 'name': name,
-                 'f1': model.test_f1,
-                 'auc': model.test_roc_auc,
-                 'pr_auc': model.test_pr_auc,
-                 'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
-                 'pr_curve': model.test_pr_curve,
-                 'roc_curve': model.test_roc_curve,
-                 'num_pos': model.test_num_pos,
-                 'pos_freq': model.test_pos_freq
+                 "name": name,
+                 "f1": model.test_f1,
+                 "auc": model.test_roc_auc,
+                 "pr_auc": model.test_pr_auc,
+                 "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                 if model.test_pos_freq > 0
+                 else np.nan,
+                 "pr_curve": model.test_pr_curve,
+                 "roc_curve": model.test_roc_curve,
+                 "num_pos": model.test_num_pos,
+                 "pos_freq": model.test_pos_freq,
              }
-
+
          self.results.append(entry)
 
          if not self.pos_freq:
-             self.pos_freq = entry['pos_freq']
-             self.num_pos = entry['num_pos']
+             self.pos_freq = entry["pos_freq"]
+             self.num_pos = entry["num_pos"]
 
      def get_metrics_dataframe(self):
          """
          Return all metrics as pandas DataFrame.
          """
          df = pd.DataFrame(self.results)
-         return df[['name', 'f1', 'auc', 'pr_auc', 'pr_auc_norm', 'num_pos', 'pos_freq']]
+         return df[["name", "f1", "auc", "pr_auc", "pr_auc_norm", "num_pos", "pos_freq"]]
 
      def plot_all_curves(self):
          """
@@ -66,30 +81,31 @@ class ModelEvaluator:
          # ROC
          plt.subplot(1, 2, 1)
          for res in self.results:
-             fpr, tpr = res['roc_curve']
+             fpr, tpr = res["roc_curve"]
              plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
          plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
          plt.xlabel("False Positive Rate")
          plt.ylabel("True Positive Rate")
-         plt.ylim(0,1.05)
+         plt.ylim(0, 1.05)
          plt.title(f"ROC Curves - {self.num_pos} positive instances")
          plt.legend()
 
          # PR
          plt.subplot(1, 2, 2)
          for res in self.results:
-             rc, pr = res['pr_curve']
+             rc, pr = res["pr_curve"]
              plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
          plt.xlabel("Recall")
          plt.ylabel("Precision")
-         plt.ylim(0,1.05)
-         plt.axhline(self.pos_freq, linestyle='--', color='grey')
+         plt.ylim(0, 1.05)
+         plt.axhline(self.pos_freq, linestyle="--", color="grey")
          plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
          plt.legend()
 
          plt.tight_layout()
          plt.show()
 
+
  class PostInferenceModelEvaluator:
      def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
          """
@@ -179,12 +195,14 @@ class PostInferenceModelEvaluator:
              "pos_freq": pos_freq,
              "confusion_matrix": cm,
              "pr_rc_curve": (pr, rc),
-             "roc_curve": (tpr, fpr)
+             "roc_curve": (tpr, fpr),
          }
 
          return metrics
-
-     def _subsample_for_fixed_positive_frequency(self, binary_labels, target_freq=0.3, max_positive=None):
+
+     def _subsample_for_fixed_positive_frequency(
+         self, binary_labels, target_freq=0.3, max_positive=None
+     ):
          pos_idx = np.where(binary_labels == 1)[0]
          neg_idx = np.where(binary_labels == 0)[0]
 

smftools/machine_learning/inference/__init__.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from .lightning_inference import run_lightning_inference
+ from .sklearn_inference import run_sklearn_inference
  from .sliding_window_inference import sliding_window_inference
- from .sklearn_inference import run_sklearn_inference

smftools/machine_learning/inference/inference_utils.py
@@ -1,5 +1,8 @@
+ from __future__ import annotations
+
  import pandas as pd
 
+
  def annotate_split_column(adata, model, split_col="split"):
      """
      Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.
@@ -8,7 +11,7 @@ def annotate_split_column(adata, model, split_col="split"):
      train_set = set(model.train_obs_names)
      val_set = set(model.val_obs_names)
      test_set = set(model.test_obs_names)
-
+
      # Create array for split labels
      split_labels = []
      for obs in adata.obs_names:
@@ -20,8 +23,10 @@ def annotate_split_column(adata, model, split_col="split"):
              split_labels.append("testing")
          else:
              split_labels.append("new")
-
+
      # Store in AnnData.obs
-     adata.obs[split_col] = pd.Categorical(split_labels, categories=["training", "validation", "testing", "new"])
-
+     adata.obs[split_col] = pd.Categorical(
+         split_labels, categories=["training", "validation", "testing", "new"]
+     )
+
      print(f"Annotated {split_col} column with training/validation/testing/new status.")

smftools/machine_learning/inference/lightning_inference.py
@@ -1,17 +1,16 @@
- import torch
- import pandas as pd
+ from __future__ import annotations
+
  import numpy as np
- from pytorch_lightning import Trainer
+ import pandas as pd
+
+ from smftools.optional_imports import require
+
  from .inference_utils import annotate_split_column
 
- def run_lightning_inference(
-     adata,
-     model,
-     datamodule,
-     trainer,
-     prefix="model",
-     devices=1
- ):
+ torch = require("torch", extra="ml-base", purpose="Lightning inference")
+
+
+ def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
      """
      Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).
      """
@@ -57,7 +56,9 @@ def run_lightning_inference(
      full_prefix = f"{prefix}_{label_col}"
 
      adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-     adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+     adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+         pred_class_labels, categories=class_labels
+     )
      adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
 
      for i, class_name in enumerate(class_labels):
@@ -65,4 +66,4 @@ def run_lightning_inference(
 
      adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all
 
-     print(f"Inference complete: stored under prefix '{full_prefix}'")
+     print(f"Inference complete: stored under prefix '{full_prefix}'")

smftools/machine_learning/inference/sklearn_inference.py
@@ -1,14 +1,12 @@
- import pandas as pd
+ from __future__ import annotations
+
  import numpy as np
+ import pandas as pd
+
  from .inference_utils import annotate_split_column
 
 
- def run_sklearn_inference(
-     adata,
-     model,
-     datamodule,
-     prefix="model"
- ):
+ def run_sklearn_inference(adata, model, datamodule, prefix="model"):
      """
      Run inference on AnnData using SklearnModelWrapper.
      """
@@ -44,7 +42,9 @@ def run_sklearn_inference(
      full_prefix = f"{prefix}_{label_col}"
 
      adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-     adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+     adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+         pred_class_labels, categories=class_labels
+     )
      adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
 
      for i, class_name in enumerate(class_labels):

smftools/machine_learning/inference/sliding_window_inference.py
@@ -1,18 +1,21 @@
+ from __future__ import annotations
+
  from ..data import AnnDataModule
  from ..evaluation import PostInferenceModelEvaluator
  from .lightning_inference import run_lightning_inference
  from .sklearn_inference import run_sklearn_inference
 
+
  def sliding_window_inference(
-     adata,
-     trained_results,
-     tensor_source='X',
+     adata,
+     trained_results,
+     tensor_source="X",
      tensor_key=None,
-     label_col='activity_status',
+     label_col="activity_status",
      batch_size=64,
      cleanup=False,
-     target_eval_freq=None,
-     max_eval_positive=None
+     target_eval_freq=None,
+     max_eval_positive=None,
  ):
      """
      Apply trained sliding window models to an AnnData object (Lightning or Sklearn).
@@ -24,11 +27,11 @@ def sliding_window_inference(
          for window_size, window_data in model_dict.items():
              for center_varname, run in window_data.items():
                  print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")
-
+
                  # Extract window start from varname
                  center_idx = adata.var_names.get_loc(center_varname)
                  window_start = center_idx - window_size // 2
-
+
                  # Build datamodule for window
                  datamodule = AnnDataModule(
                      adata,
@@ -38,31 +41,31 @@ def sliding_window_inference(
                      batch_size=batch_size,
                      window_start=window_start,
                      window_size=window_size,
-                     inference_mode=True
+                     inference_mode=True,
                  )
                  datamodule.setup()
 
                  # Extract model + detect type
-                 model = run['model']
+                 model = run["model"]
 
                  # Lightning models
-                 if hasattr(run, 'trainer') or 'trainer' in run:
-                     trainer = run['trainer']
+                 if hasattr(run, "trainer") or "trainer" in run:
+                     trainer = run["trainer"]
                      run_lightning_inference(
                          adata,
                          model=model,
                          datamodule=datamodule,
                          trainer=trainer,
-                         prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                         prefix=f"{model_name}_w{window_size}_c{center_varname}",
                      )
-
+
                  # Sklearn models
                  else:
                      run_sklearn_inference(
                          adata,
                          model=model,
                          datamodule=datamodule,
-                         prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                         prefix=f"{model_name}_w{window_size}_c{center_varname}",
                      )
 
      print("Inference complete across all models.")
@@ -77,27 +80,36 @@ def sliding_window_inference(
                  prefix = f"{model_name}_w{window_size}_c{center_varname}"
                  # Use full key for uniqueness
                  key = prefix
-                 model_wrappers[key] = run['model']
+                 model_wrappers[key] = run["model"]
 
      # Run evaluator
-     evaluator = PostInferenceModelEvaluator(adata, model_wrappers, target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive)
+     evaluator = PostInferenceModelEvaluator(
+         adata,
+         model_wrappers,
+         target_eval_freq=target_eval_freq,
+         max_eval_positive=max_eval_positive,
+     )
      evaluator.evaluate_all()
 
      # Get results
      df = evaluator.to_dataframe()
 
-     df[['model_name', 'window_size', 'center']] = df['model'].str.extract(r'(\w+)_w(\d+)_c(\d+)_activity_status')
+     df[["model_name", "window_size", "center"]] = df["model"].str.extract(
+         r"(\w+)_w(\d+)_c(\d+)_activity_status"
+     )
 
      # Cast window_size and center to integers for plotting
-     df['window_size'] = df['window_size'].astype(int)
-     df['center'] = df['center'].astype(int)
+     df["window_size"] = df["window_size"].astype(int)
+     df["center"] = df["center"].astype(int)
 
      ## Optional cleanup:
      if cleanup:
-         prefixes = [f"{model_name}_w{window_size}_c{center_varname}"
-                     for model_name, model_dict in trained_results.items()
-                     for window_size, window_data in model_dict.items()
-                     for center_varname in window_data.keys()]
+         prefixes = [
+             f"{model_name}_w{window_size}_c{center_varname}"
+             for model_name, model_dict in trained_results.items()
+             for window_size, window_data in model_dict.items()
+             for center_varname in window_data.keys()
+         ]
 
          # Remove matching obs columns
          for prefix in prefixes:
@@ -111,4 +123,4 @@ def sliding_window_inference(
 
          print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")
 
-     return df
+     return df
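Note on the regex in the hunk above: it assumes the per-window prefix layout `{model_name}_w{window_size}_c{center_varname}` built earlier in the same function, plus the `_activity_status` suffix that the inference helpers append for the default `label_col`. A standalone check of that extraction, with made-up column values:

# Standalone check of the prefix-parsing regex used in sliding_window_inference;
# the "model" values below are invented for illustration.
import pandas as pd

df = pd.DataFrame({"model": ["cnn_w50_c1200_activity_status", "mlp_w100_c1500_activity_status"]})
df[["model_name", "window_size", "center"]] = df["model"].str.extract(
    r"(\w+)_w(\d+)_c(\d+)_activity_status"
)
print(df)
# str.extract yields strings, hence the .astype(int) casts that follow in the diff.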

smftools/machine_learning/models/__init__.py
@@ -1,9 +1,16 @@
+ from __future__ import annotations
+
  from .base import BaseTorchModel
- from .mlp import MLPClassifier
  from .cnn import CNNClassifier
- from .rnn import RNNClassifier
- from .transformer import BaseTransformer, TransformerClassifier, DANNTransformerClassifier, MaskedTransformerPretrainer
+ from .lightning_base import TorchClassifierWrapper
+ from .mlp import MLPClassifier
  from .positional import PositionalEncoding
+ from .rnn import RNNClassifier
+ from .sklearn_models import SklearnModelWrapper
+ from .transformer import (
+     BaseTransformer,
+     DANNTransformerClassifier,
+     MaskedTransformerPretrainer,
+     TransformerClassifier,
+ )
  from .wrappers import ScaledModel
- from .lightning_base import TorchClassifierWrapper
- from .sklearn_models import SklearnModelWrapper

smftools/machine_learning/models/base.py
@@ -1,17 +1,25 @@
- import torch
- import torch.nn as nn
+ from __future__ import annotations
+
  import numpy as np
+
+ from smftools.optional_imports import require
+
  from ..utils.device import detect_device
 
+ torch = require("torch", extra="ml-base", purpose="ML base models")
+ nn = torch.nn
+
+
  class BaseTorchModel(nn.Module):
      """
      Minimal base class for torch models that:
      - Stores device and dropout regularization
      """
+
      def __init__(self, dropout_rate=0.0):
          super().__init__()
-         self.device = detect_device() # detects available devices
-         self.dropout_rate = dropout_rate # default dropout rate to be used in regularization.
+         self.device = detect_device()  # detects available devices
+         self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
 
      def compute_saliency(
          self,
@@ -21,11 +29,11 @@ class BaseTorchModel(nn.Module):
          smoothgrad=False,
          smooth_samples=25,
          smooth_noise=0.1,
-         signed=True
+         signed=True,
      ):
          """
          Compute vanilla saliency or SmoothGrad saliency.
-
+
          Arguments:
          ----------
          x : torch.Tensor
@@ -43,7 +51,7 @@ class BaseTorchModel(nn.Module):
          """
          self.eval()
          x = x.clone().detach().requires_grad_(True)
-
+
          if smoothgrad:
              saliency_accum = torch.zeros_like(x)
              for i in range(smooth_samples):
@@ -56,7 +64,7 @@ class BaseTorchModel(nn.Module):
                  if logits.shape[1] == 1:
                      scores = logits.squeeze(1)
                  else:
-                     scores = logits[torch.arange(x.shape[0]), target_class]
+                     scores = logits[torch.arange(x.shape[0]), target_class]
                  scores.sum().backward()
                  saliency_accum += x_noisy.grad.detach()
              saliency = saliency_accum / smooth_samples
@@ -69,17 +77,17 @@ class BaseTorchModel(nn.Module):
              scores = logits[torch.arange(x.shape[0]), target_class]
              scores.sum().backward()
              saliency = x.grad.detach()
-
+
          if not signed:
              saliency = saliency.abs()
-
+
          if reduction == "sum" and x.ndim == 3:
              return saliency.sum(dim=-1)
          elif reduction == "mean" and x.ndim == 3:
              return saliency.mean(dim=-1)
          else:
              return saliency
-
+
      def compute_gradient_x_input(self, x, target_class=None):
          """
          Computes gradient × input attribution.
@@ -118,22 +126,11 @@ class BaseTorchModel(nn.Module):
              baseline = torch.zeros_like(x)
 
          attributions, delta = ig.attribute(
-             x,
-             baselines=baseline,
-             target=target_class,
-             n_steps=steps,
-             return_convergence_delta=True
+             x, baselines=baseline, target=target_class, n_steps=steps, return_convergence_delta=True
          )
          return attributions, delta
 
-     def compute_deeplift(
-         self,
-         x,
-         baseline=None,
-         target_class=None,
-         reduction="sum",
-         signed=True
-     ):
+     def compute_deeplift(self, x, baseline=None, target_class=None, reduction="sum", signed=True):
          """
          Compute DeepLIFT scores using captum.
 
@@ -158,21 +155,15 @@ class BaseTorchModel(nn.Module):
 
          if not signed:
              attr = attr.abs()
-
+
          if reduction == "sum" and x.ndim == 3:
              return attr.sum(dim=-1)
          elif reduction == "mean" and x.ndim == 3:
              return attr.mean(dim=-1)
          else:
              return attr
-
-     def compute_occlusion(
-         self,
-         x,
-         target_class=None,
-         window_size=5,
-         baseline=None
-     ):
+
+     def compute_occlusion(self, x, target_class=None, window_size=5, baseline=None):
          """
          Computes per-sample occlusion attribution.
          Supports 2D [B, S] or 3D [B, S, D] inputs.
@@ -208,9 +199,7 @@ class BaseTorchModel(nn.Module):
                  x_occluded[left:right, :] = baseline[left:right, :]
 
                  x_tensor = torch.tensor(
-                     x_occluded,
-                     device=self.device,
-                     dtype=torch.float32
+                     x_occluded, device=self.device, dtype=torch.float32
                  ).unsqueeze(0)
 
                  logits = self.forward(x_tensor)
@@ -235,7 +224,7 @@ class BaseTorchModel(nn.Module):
          device="cpu",
          target_class=None,
          normalize=True,
-         signed=True
+         signed=True,
      ):
          """
          Apply a chosen attribution method to a dataloader and store results in adata.
@@ -252,7 +241,9 @@ class BaseTorchModel(nn.Module):
                  attr = model.compute_saliency(x, target_class=target_class, signed=signed)
 
              elif method == "smoothgrad":
-                 attr = model.compute_saliency(x, smoothgrad=True, target_class=target_class, signed=signed)
+                 attr = model.compute_saliency(
+                     x, smoothgrad=True, target_class=target_class, signed=signed
+                 )
 
              elif method == "IG":
                  attributions, delta = model.compute_integrated_gradients(
@@ -261,15 +252,15 @@ class BaseTorchModel(nn.Module):
                  attr = attributions
 
              elif method == "deeplift":
-                 attr = model.compute_deeplift(x, baseline=baseline, target_class=target_class, signed=signed)
+                 attr = model.compute_deeplift(
+                     x, baseline=baseline, target_class=target_class, signed=signed
+                 )
 
              elif method == "gradxinput":
                  attr = model.compute_gradient_x_input(x, target_class=target_class)
 
              elif method == "occlusion":
-                 attr = model.compute_occlusion(
-                     x, target_class=target_class, baseline=baseline
-                 )
+                 attr = model.compute_occlusion(x, target_class=target_class, baseline=baseline)
 
              else:
                  raise ValueError(f"Unknown method {method}")
@@ -292,4 +283,4 @@ class BaseTorchModel(nn.Module):
              return target_class
          if logits.shape[1] == 1:
              return (logits > 0).long().squeeze(1)
-         return logits.argmax(dim=1)
+         return logits.argmax(dim=1)
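The compute_saliency changes above are formatting-only; the underlying technique is plain gradient saliency (optionally averaged over noisy copies for SmoothGrad). A self-contained illustration of the vanilla case, using a toy model rather than smftools code:

# Toy illustration of vanilla gradient saliency, the technique behind
# BaseTorchModel.compute_saliency; this is not smftools API.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # toy single-logit classifier
model.eval()

x = torch.randn(4, 10, requires_grad=True)  # [batch, features]
scores = model(x).squeeze(1)                # single-logit branch, as in the diff
scores.sum().backward()                     # gradients of scores w.r.t. every input
saliency = x.grad.detach()                  # signed per-feature attribution
print(saliency.shape)                       # torch.Size([4, 10])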