smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
--- a/smftools/informatics/run_multiqc.py
+++ b/smftools/informatics/run_multiqc.py
@@ -1,16 +1,23 @@
-def run_multiqc(input_dir, output_dir):
-    """
-    Runs MultiQC on a given directory and saves the report to the specified output directory.
+from __future__ import annotations
+
+from pathlib import Path
+
+from smftools.logging_utils import get_logger
 
-    Parameters:
-    - input_dir (str): Path to the directory containing QC reports (e.g., FastQC, Samtools, bcftools outputs).
-    - output_dir (str): Path to the directory where MultiQC reports should be saved.
+logger = get_logger(__name__)
 
-    Returns:
-    - None: The function executes MultiQC and prints the status.
+
+def run_multiqc(input_dir: str | Path, output_dir: str | Path) -> None:
+    """Run MultiQC on a directory and save the report to the output directory.
+
+    Args:
+        input_dir: Path to the directory containing QC reports (e.g., FastQC, Samtools outputs).
+        output_dir: Path to the directory where MultiQC reports should be saved.
     """
-    from ..readwrite import make_dirs
     import subprocess
+
+    from ..readwrite import make_dirs
+
     # Ensure the output directory exists
     make_dirs(output_dir)
 
@@ -20,12 +27,11 @@ def run_multiqc(input_dir, output_dir):
     # Construct MultiQC command
     command = ["multiqc", input_dir, "-o", output_dir]
 
-    print(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
-
+    logger.info(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
+
     # Run MultiQC
     try:
         subprocess.run(command, check=True)
-        print(f"MultiQC report generated successfully in: {output_dir}")
+        logger.info(f"MultiQC report generated successfully in: {output_dir}")
     except subprocess.CalledProcessError as e:
-        print(f"Error running MultiQC: {e}")
-
+        logger.error(f"Error running MultiQC: {e}")
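For reference, a minimal sketch of calling the updated function, assuming the `multiqc` CLI is installed and on PATH; the directory names here are hypothetical:

    from smftools.informatics.run_multiqc import run_multiqc
    from smftools.logging_utils import setup_logging

    setup_logging()  # route smftools log records to stderr

    # Hypothetical directories; any folder of FastQC/Samtools outputs works.
    run_multiqc("qc_reports", "multiqc_out")  # runs: multiqc qc_reports -o multiqc_out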
--- /dev/null
+++ b/smftools/logging_utils.py
@@ -0,0 +1,51 @@
+"""Logging utilities for smftools."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Optional, Union
+
+DEFAULT_LOG_FORMAT = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
+DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+
+def setup_logging(
+    level: int = logging.INFO,
+    fmt: str = DEFAULT_LOG_FORMAT,
+    datefmt: str = DEFAULT_DATE_FORMAT,
+    log_file: Optional[Union[str, Path]] = None,
+) -> None:
+    """
+    Configure logging for smftools.
+
+    Should be called once by the CLI entrypoint.
+    Safe to call multiple times.
+    """
+    logger = logging.getLogger("smftools")
+
+    if logger.handlers:
+        return
+
+    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
+
+    # Console handler (stderr)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+
+    # Optional file handler
+    if log_file is not None:
+        log_path = Path(log_file)
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = logging.FileHandler(log_path)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    logger.setLevel(level)
+    logger.propagate = False
+
+
+def get_logger(name: str) -> logging.Logger:
+    return logging.getLogger(name)
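The module follows the standard `logging` pattern: the CLI entrypoint configures the package-level `smftools` logger once, and submodules fetch named loggers that propagate to it. A minimal sketch, with a hypothetical log path:

    import logging

    from smftools.logging_utils import get_logger, setup_logging

    # Safe to call more than once: setup_logging returns early when the
    # "smftools" logger already has handlers attached.
    setup_logging(level=logging.DEBUG, log_file="logs/run.log")  # hypothetical path

    logger = get_logger("smftools.demo")  # a child of the configured "smftools" logger
    logger.info("This record goes to stderr and to logs/run.log")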
--- a/smftools/machine_learning/__init__.py
+++ b/smftools/machine_learning/__init__.py
@@ -1,12 +1,7 @@
-from . import models
-from . import data
-from . import utils
-from . import evaluation
-from . import inference
-from . import training
+from . import data, evaluation, inference, models, training, utils
 
 __all__ = [
     "calculate_relative_risk_on_activity",
     "evaluate_models_by_subgroup",
     "prepare_melted_model_data",
-]
+]
--- a/smftools/machine_learning/data/anndata_data_module.py
+++ b/smftools/machine_learning/data/anndata_data_module.py
@@ -1,24 +1,34 @@
-import torch
-from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
-import pytorch_lightning as pl
 import numpy as np
 import pandas as pd
-from .preprocessing import random_fill_nans
+import pytorch_lightning as pl
+import torch
 from sklearn.utils.class_weight import compute_class_weight
+from torch.utils.data import DataLoader, Dataset, Subset
+
+from .preprocessing import random_fill_nans
+
 
-
 class AnnDataDataset(Dataset):
     """
     Generic PyTorch Dataset from AnnData.
     """
-    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col=None, window_start=None, window_size=None):
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col=None,
+        window_start=None,
+        window_size=None,
+    ):
         self.adata = adata
         self.tensor_source = tensor_source
         self.tensor_key = tensor_key
         self.label_col = label_col
         self.window_start = window_start
         self.window_size = window_size
-
+
         if tensor_source == "X":
             X = adata.X
         elif tensor_source == "layers":
@@ -29,17 +39,17 @@ class AnnDataDataset(Dataset):
             X = adata.obsm[tensor_key]
         else:
             raise ValueError(f"Invalid tensor_source: {tensor_source}")
-
+
         if self.window_start is not None and self.window_size is not None:
             X = X[:, self.window_start : self.window_start + self.window_size]
-
+
         X = random_fill_nans(X)
 
         self.X_tensor = torch.tensor(X, dtype=torch.float32)
 
         if label_col is not None:
             y = adata.obs[label_col]
-            if y.dtype.name == 'category':
+            if y.dtype.name == "category":
                 y = y.cat.codes
             self.y_tensor = torch.tensor(y.values, dtype=torch.long)
         else:
@@ -47,7 +57,7 @@ class AnnDataDataset(Dataset):
 
     def numpy(self, indices):
         return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()
-
+
     def __len__(self):
         return len(self.X_tensor)
 
@@ -60,9 +70,17 @@ class AnnDataDataset(Dataset):
         return (x,)
 
 
-def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
-                  random_seed=42, split_col="train_val_test_split",
-                  load_existing_split=False, split_save_path=None):
+def split_dataset(
+    adata,
+    dataset,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    split_col="train_val_test_split",
+    load_existing_split=False,
+    split_save_path=None,
+):
     """
     Perform split and record assignment into adata.obs[split_col].
     """
@@ -87,7 +105,7 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     split_array = np.full(total_len, "test", dtype=object)
     split_array[indices[:n_train]] = "train"
-    split_array[indices[n_train:n_train + n_val]] = "val"
+    split_array[indices[n_train : n_train + n_val]] = "val"
     adata.obs[split_col] = split_array
 
     if split_save_path:
@@ -104,14 +122,32 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     return train_set, val_set, test_set
 
+
 class AnnDataModule(pl.LightningDataModule):
     """
     Unified LightningDataModule version of AnnDataDataset + splitting with adata.obs recording.
     """
-    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
-                 batch_size=64, train_frac=0.6, val_frac=0.1, test_frac=0.3, random_seed=42,
-                 inference_mode=False, split_col="train_val_test_split", split_save_path=None,
-                 load_existing_split=False, window_start=None, window_size=None, num_workers=None, persistent_workers=False):
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col="labels",
+        batch_size=64,
+        train_frac=0.6,
+        val_frac=0.1,
+        test_frac=0.3,
+        random_seed=42,
+        inference_mode=False,
+        split_col="train_val_test_split",
+        split_save_path=None,
+        load_existing_split=False,
+        window_start=None,
+        window_size=None,
+        num_workers=None,
+        persistent_workers=False,
+    ):
         super().__init__()
         self.adata = adata
         self.tensor_source = tensor_source
@@ -133,52 +169,80 @@ class AnnDataModule(pl.LightningDataModule):
         self.persistent_workers = persistent_workers
 
     def setup(self, stage=None):
-        dataset = AnnDataDataset(self.adata, self.tensor_source, self.tensor_key,
-                                 None if self.inference_mode else self.label_col,
-                                 window_start=self.window_start, window_size=self.window_size)
+        dataset = AnnDataDataset(
+            self.adata,
+            self.tensor_source,
+            self.tensor_key,
+            None if self.inference_mode else self.label_col,
+            window_start=self.window_start,
+            window_size=self.window_size,
+        )
 
         if self.inference_mode:
             self.infer_dataset = dataset
             return
 
         self.train_set, self.val_set, self.test_set = split_dataset(
-            self.adata, dataset, train_frac=self.train_frac, val_frac=self.val_frac,
-            test_frac=self.test_frac, random_seed=self.random_seed,
-            split_col=self.split_col, split_save_path=self.split_save_path,
-            load_existing_split=self.load_existing_split
+            self.adata,
+            dataset,
+            train_frac=self.train_frac,
+            val_frac=self.val_frac,
+            test_frac=self.test_frac,
+            random_seed=self.random_seed,
+            split_col=self.split_col,
+            split_save_path=self.split_save_path,
+            load_existing_split=self.load_existing_split,
        )
 
     def train_dataloader(self):
         if self.num_workers:
-            return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+            return DataLoader(
+                self.train_set,
+                batch_size=self.batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
         if self.num_workers:
-            return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+            return DataLoader(
+                self.val_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def test_dataloader(self):
         if self.num_workers:
-            return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+            return DataLoader(
+                self.test_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def predict_dataloader(self):
         if not self.inference_mode:
             raise RuntimeError("Only valid in inference mode")
         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
-
+
     def compute_class_weights(self):
-        train_indices = self.train_set.indices  # get the indices of the training set
-        y_all = self.train_set.dataset.y_tensor  # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
-        y_train = y_all[train_indices].cpu().numpy()  # get the labels for the training set and move to a numpy array
-
-        class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
+        train_indices = self.train_set.indices  # get the indices of the training set
+        y_all = self.train_set.dataset.y_tensor  # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
+        y_train = (
+            y_all[train_indices].cpu().numpy()
+        )  # get the labels for the training set and move to a numpy array
+
+        class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
         return torch.tensor(class_weights, dtype=torch.float32)
-
+
     def inference_numpy(self):
         """
         Return inference data as numpy for use in sklearn inference.
@@ -187,7 +251,7 @@ class AnnDataModule(pl.LightningDataModule):
             raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
         X_np = self.infer_dataset.X_tensor.numpy()
         return X_np
-
+
     def to_numpy(self):
         """
         Move the AnnDataModule tensors into numpy arrays
@@ -202,9 +266,20 @@
 
 
 def build_anndata_loader(
-    adata, tensor_source="X", tensor_key=None, label_col=None, train_frac=0.6, val_frac=0.1,
-    test_frac=0.3, random_seed=42, batch_size=64, lightning=True, inference_mode=False,
-    split_col="train_val_test_split", split_save_path=None, load_existing_split=False
+    adata,
+    tensor_source="X",
+    tensor_key=None,
+    label_col=None,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    batch_size=64,
+    lightning=True,
+    inference_mode=False,
+    split_col="train_val_test_split",
+    split_save_path=None,
+    load_existing_split=False,
 ):
     """
     Unified pipeline for both Lightning and raw PyTorch.
@@ -213,22 +288,40 @@ def build_anndata_loader(
     """
     if lightning:
         return AnnDataModule(
-            adata, tensor_source=tensor_source, tensor_key=tensor_key, label_col=label_col,
-            batch_size=batch_size, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
-            random_seed=random_seed, inference_mode=inference_mode,
-            split_col=split_col, split_save_path=split_save_path, load_existing_split=load_existing_split
+            adata,
+            tensor_source=tensor_source,
+            tensor_key=tensor_key,
+            label_col=label_col,
+            batch_size=batch_size,
+            train_frac=train_frac,
+            val_frac=val_frac,
+            test_frac=test_frac,
+            random_seed=random_seed,
+            inference_mode=inference_mode,
+            split_col=split_col,
+            split_save_path=split_save_path,
+            load_existing_split=load_existing_split,
         )
     else:
         var_names = adata.var_names.copy()
-        dataset = AnnDataDataset(adata, tensor_source, tensor_key, None if inference_mode else label_col)
+        dataset = AnnDataDataset(
+            adata, tensor_source, tensor_key, None if inference_mode else label_col
+        )
        if inference_mode:
             return DataLoader(dataset, batch_size=batch_size)
         else:
             train_set, val_set, test_set = split_dataset(
-                adata, dataset, train_frac, val_frac, test_frac, random_seed,
-                split_col, split_save_path, load_existing_split
+                adata,
+                dataset,
+                train_frac,
+                val_frac,
+                test_frac,
+                random_seed,
+                split_col,
+                split_save_path,
+                load_existing_split,
             )
             train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
             val_loader = DataLoader(val_set, batch_size=batch_size)
             test_loader = DataLoader(test_set, batch_size=batch_size)
-            return train_loader, val_loader, test_loader
+            return train_loader, val_loader, test_loader
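To see the reformatted API end to end, a sketch of the Lightning path through `build_anndata_loader`; the toy AnnData built here is hypothetical and assumes the usual dependencies (`anndata`, `torch`, `pytorch_lightning`) are installed:

    import anndata as ad
    import numpy as np

    from smftools.machine_learning.data.anndata_data_module import build_anndata_loader

    # Hypothetical toy AnnData: 100 reads x 50 positions, binary categorical labels.
    adata = ad.AnnData(np.random.rand(100, 50).astype(np.float32))
    adata.obs["labels"] = np.random.randint(0, 2, size=adata.n_obs)
    adata.obs["labels"] = adata.obs["labels"].astype("category")

    # lightning=True returns an AnnDataModule; setup() performs the 60/10/30
    # split and records it in adata.obs["train_val_test_split"].
    dm = build_anndata_loader(adata, label_col="labels", lightning=True)
    dm.setup()
    train_loader = dm.train_dataloader()
    class_weights = dm.compute_class_weights()  # balanced weights over the train subset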
--- a/smftools/machine_learning/data/preprocessing.py
+++ b/smftools/machine_learning/data/preprocessing.py
@@ -1,6 +1,7 @@
 import numpy as np
 
+
 def random_fill_nans(X):
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
-    return X
+    return X
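Note that `random_fill_nans` mutates its argument: NaN entries are overwritten in place with uniform draws from [0, 1), and the same array is returned. A small sketch:

    import numpy as np

    from smftools.machine_learning.data.preprocessing import random_fill_nans

    X = np.array([[0.0, np.nan], [1.0, 0.5]], dtype=np.float32)
    filled = random_fill_nans(X.copy())  # pass a copy to keep the original X intact
    assert not np.isnan(filled).any()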
--- a/smftools/machine_learning/evaluation/__init__.py
+++ b/smftools/machine_learning/evaluation/__init__.py
@@ -1,2 +1,2 @@
+from .eval_utils import flatten_sliding_window_results
 from .evaluators import ModelEvaluator, PostInferenceModelEvaluator
-from .eval_utils import flatten_sliding_window_results
--- a/smftools/machine_learning/evaluation/eval_utils.py
+++ b/smftools/machine_learning/evaluation/eval_utils.py
@@ -1,10 +1,11 @@
 import pandas as pd
 
+
 def flatten_sliding_window_results(results_dict):
     """
     Flatten nested sliding window results into pandas DataFrame.
-
-    Expects structure:
+
+    Expects structure:
     results[model_name][window_size][window_center]['metrics'][metric_name]
     """
     records = []
@@ -12,20 +13,16 @@ def flatten_sliding_window_results(results_dict):
     for model_name, model_results in results_dict.items():
         for window_size, window_results in model_results.items():
             for center_var, result in window_results.items():
-                metrics = result['metrics']
-                record = {
-                    'model': model_name,
-                    'window_size': window_size,
-                    'center_var': center_var
-                }
+                metrics = result["metrics"]
+                record = {"model": model_name, "window_size": window_size, "center_var": center_var}
                 # Add all metrics
                 record.update(metrics)
                 records.append(record)
-
+
     df = pd.DataFrame.from_records(records)
-
+
     # Convert center_var to numeric if possible (optional but helpful for plotting)
-    df['center_var'] = pd.to_numeric(df['center_var'], errors='coerce')
-    df = df.sort_values(['model', 'window_size', 'center_var'])
-
-    return df
+    df["center_var"] = pd.to_numeric(df["center_var"], errors="coerce")
+    df = df.sort_values(["model", "window_size", "center_var"])
+
+    return df
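The expected nesting is easiest to see on a tiny hand-built dict; the model and metric names below are hypothetical:

    from smftools.machine_learning.evaluation import flatten_sliding_window_results

    # results[model_name][window_size][window_center]["metrics"][metric_name]
    results = {
        "cnn": {
            50: {
                "100": {"metrics": {"f1": 0.81, "auc": 0.90}},
                "150": {"metrics": {"f1": 0.78, "auc": 0.88}},
            }
        }
    }

    df = flatten_sliding_window_results(results)
    # One row per (model, window_size, center_var); center_var is coerced
    # to numeric and the frame is sorted for plotting.
    print(df[["model", "window_size", "center_var", "f1", "auc"]])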
--- a/smftools/machine_learning/evaluation/evaluators.py
+++ b/smftools/machine_learning/evaluation/evaluators.py
@@ -1,15 +1,21 @@
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
-
 from sklearn.metrics import (
-    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+    roc_curve,
 )
 
+
 class ModelEvaluator:
     """
     A model evaluator for consolidating Sklearn and Lightning model evaluation metrics on testing data
     """
+
     def __init__(self):
         self.results = []
         self.pos_freq = None
@@ -21,41 +27,45 @@ class ModelEvaluator:
         """
         if is_torch:
             entry = {
-                'name': name,
-                'f1': model.test_f1,
-                'auc': model.test_roc_auc,
-                'pr_auc': model.test_pr_auc,
-                'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
-                'pr_curve': model.test_pr_curve,
-                'roc_curve': model.test_roc_curve,
-                'num_pos': model.test_num_pos,
-                'pos_freq': model.test_pos_freq
+                "name": name,
+                "f1": model.test_f1,
+                "auc": model.test_roc_auc,
+                "pr_auc": model.test_pr_auc,
+                "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                if model.test_pos_freq > 0
+                else np.nan,
+                "pr_curve": model.test_pr_curve,
+                "roc_curve": model.test_roc_curve,
+                "num_pos": model.test_num_pos,
+                "pos_freq": model.test_pos_freq,
             }
         else:
             entry = {
-                'name': name,
-                'f1': model.test_f1,
-                'auc': model.test_roc_auc,
-                'pr_auc': model.test_pr_auc,
-                'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
-                'pr_curve': model.test_pr_curve,
-                'roc_curve': model.test_roc_curve,
-                'num_pos': model.test_num_pos,
-                'pos_freq': model.test_pos_freq
+                "name": name,
+                "f1": model.test_f1,
+                "auc": model.test_roc_auc,
+                "pr_auc": model.test_pr_auc,
+                "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                if model.test_pos_freq > 0
+                else np.nan,
+                "pr_curve": model.test_pr_curve,
+                "roc_curve": model.test_roc_curve,
+                "num_pos": model.test_num_pos,
+                "pos_freq": model.test_pos_freq,
             }
-
+
         self.results.append(entry)
 
         if not self.pos_freq:
-            self.pos_freq = entry['pos_freq']
-            self.num_pos = entry['num_pos']
+            self.pos_freq = entry["pos_freq"]
+            self.num_pos = entry["num_pos"]
 
     def get_metrics_dataframe(self):
         """
         Return all metrics as pandas DataFrame.
         """
         df = pd.DataFrame(self.results)
-        return df[['name', 'f1', 'auc', 'pr_auc', 'pr_auc_norm', 'num_pos', 'pos_freq']]
+        return df[["name", "f1", "auc", "pr_auc", "pr_auc_norm", "num_pos", "pos_freq"]]
 
     def plot_all_curves(self):
         """
@@ -66,30 +76,31 @@ class ModelEvaluator:
         # ROC
         plt.subplot(1, 2, 1)
         for res in self.results:
-            fpr, tpr = res['roc_curve']
+            fpr, tpr = res["roc_curve"]
             plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
         plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
         plt.xlabel("False Positive Rate")
         plt.ylabel("True Positive Rate")
-        plt.ylim(0,1.05)
+        plt.ylim(0, 1.05)
         plt.title(f"ROC Curves - {self.num_pos} positive instances")
         plt.legend()
 
         # PR
         plt.subplot(1, 2, 2)
         for res in self.results:
-            rc, pr = res['pr_curve']
+            rc, pr = res["pr_curve"]
             plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
-        plt.ylim(0,1.05)
-        plt.axhline(self.pos_freq, linestyle='--', color='grey')
+        plt.ylim(0, 1.05)
+        plt.axhline(self.pos_freq, linestyle="--", color="grey")
         plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
         plt.legend()
 
         plt.tight_layout()
         plt.show()
 
+
 class PostInferenceModelEvaluator:
     def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
         """
@@ -179,12 +190,14 @@ class PostInferenceModelEvaluator:
             "pos_freq": pos_freq,
             "confusion_matrix": cm,
             "pr_rc_curve": (pr, rc),
-            "roc_curve": (tpr, fpr)
+            "roc_curve": (tpr, fpr),
         }
 
         return metrics
-
-    def _subsample_for_fixed_positive_frequency(self, binary_labels, target_freq=0.3, max_positive=None):
+
+    def _subsample_for_fixed_positive_frequency(
+        self, binary_labels, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(binary_labels == 1)[0]
         neg_idx = np.where(binary_labels == 0)[0]
 
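`ModelEvaluator.add_model` only reads precomputed `test_*` attributes off the model object, so any object carrying them works; the full signature sits outside this hunk, so the keyword call below is inferred from the names used in the body. A sketch with a stand-in object and made-up numbers:

    from types import SimpleNamespace

    import numpy as np

    from smftools.machine_learning.evaluation import ModelEvaluator

    # Stand-in for a tested model; smftools models expose these after testing.
    model = SimpleNamespace(
        test_f1=0.80,
        test_roc_auc=0.91,
        test_pr_auc=0.55,
        test_pos_freq=0.25,
        test_num_pos=250,
        test_pr_curve=(np.linspace(0, 1, 50), np.linspace(1, 0.25, 50)),  # (recall, precision)
        test_roc_curve=(np.linspace(0, 1, 50), np.linspace(0, 1, 50)),    # (fpr, tpr)
    )

    evaluator = ModelEvaluator()
    evaluator.add_model(model=model, name="demo", is_torch=True)
    print(evaluator.get_metrics_dataframe())  # one row per registered model
    # evaluator.plot_all_curves()  # side-by-side ROC and PR panels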
--- a/smftools/machine_learning/inference/__init__.py
+++ b/smftools/machine_learning/inference/__init__.py
@@ -1,3 +1,3 @@
 from .lightning_inference import run_lightning_inference
+from .sklearn_inference import run_sklearn_inference
 from .sliding_window_inference import sliding_window_inference
-from .sklearn_inference import run_sklearn_inference