smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/run_multiqc.py
@@ -1,16 +1,23 @@
-def run_multiqc(input_dir, output_dir):
-    """
-    Runs MultiQC on a given directory and saves the report to the specified output directory.
+from __future__ import annotations
+
+from pathlib import Path
+
+from smftools.logging_utils import get_logger
 
-    Parameters:
-    - input_dir (str): Path to the directory containing QC reports (e.g., FastQC, Samtools, bcftools outputs).
-    - output_dir (str): Path to the directory where MultiQC reports should be saved.
+logger = get_logger(__name__)
 
-    Returns:
-    - None: The function executes MultiQC and prints the status.
+
+def run_multiqc(input_dir: str | Path, output_dir: str | Path) -> None:
+    """Run MultiQC on a directory and save the report to the output directory.
+
+    Args:
+        input_dir: Path to the directory containing QC reports (e.g., FastQC, Samtools outputs).
+        output_dir: Path to the directory where MultiQC reports should be saved.
     """
-    from ..readwrite import make_dirs
     import subprocess
+
+    from ..readwrite import make_dirs
+
     # Ensure the output directory exists
     make_dirs(output_dir)
 
@@ -20,12 +27,11 @@ def run_multiqc(input_dir, output_dir):
     # Construct MultiQC command
     command = ["multiqc", input_dir, "-o", output_dir]
 
-    print(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
-
+    logger.info(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
+
     # Run MultiQC
    try:
         subprocess.run(command, check=True)
-        print(f"MultiQC report generated successfully in: {output_dir}")
+        logger.info(f"MultiQC report generated successfully in: {output_dir}")
     except subprocess.CalledProcessError as e:
-        print(f"Error running MultiQC: {e}")
-
+        logger.error(f"Error running MultiQC: {e}")
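For orientation, a minimal usage sketch of the updated function (the directory paths are hypothetical; setup_logging comes from the new smftools/logging_utils.py shown below):

    from smftools.logging_utils import setup_logging
    from smftools.informatics.run_multiqc import run_multiqc

    setup_logging()  # route smftools log records to stderr with the default format
    run_multiqc("qc_reports/", "multiqc_out/")  # hypothetical input/output directories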
smftools/logging_utils.py (new file)
@@ -0,0 +1,51 @@
+"""Logging utilities for smftools."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Optional, Union
+
+DEFAULT_LOG_FORMAT = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
+DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+
+def setup_logging(
+    level: int = logging.INFO,
+    fmt: str = DEFAULT_LOG_FORMAT,
+    datefmt: str = DEFAULT_DATE_FORMAT,
+    log_file: Optional[Union[str, Path]] = None,
+) -> None:
+    """
+    Configure logging for smftools.
+
+    Should be called once by the CLI entrypoint.
+    Safe to call multiple times.
+    """
+    logger = logging.getLogger("smftools")
+
+    if logger.handlers:
+        return
+
+    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
+
+    # Console handler (stderr)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+
+    # Optional file handler
+    if log_file is not None:
+        log_path = Path(log_file)
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = logging.FileHandler(log_path)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    logger.setLevel(level)
+    logger.propagate = False
+
+
+def get_logger(name: str) -> logging.Logger:
+    return logging.getLogger(name)
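A short sketch of how a CLI entrypoint would be expected to use this module (the log file path is a placeholder):

    import logging
    from smftools.logging_utils import setup_logging, get_logger

    setup_logging(level=logging.DEBUG, log_file="logs/smftools.log")  # first call attaches handlers
    setup_logging()  # later calls return early because the "smftools" logger already has handlers

    logger = get_logger(__name__)
    logger.debug("pipeline starting")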
smftools/machine_learning/__init__.py
@@ -1,12 +1,23 @@
-from . import models
-from . import data
-from . import utils
-from . import evaluation
-from . import inference
-from . import training
-
-__all__ = [
-    "calculate_relative_risk_on_activity",
-    "evaluate_models_by_subgroup",
-    "prepare_melted_model_data",
-]
+from __future__ import annotations
+
+from importlib import import_module
+
+_LAZY_MODULES = {
+    "data": "smftools.machine_learning.data",
+    "evaluation": "smftools.machine_learning.evaluation",
+    "inference": "smftools.machine_learning.inference",
+    "models": "smftools.machine_learning.models",
+    "training": "smftools.machine_learning.training",
+    "utils": "smftools.machine_learning.utils",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_MODULES:
+        module = import_module(_LAZY_MODULES[name])
+        globals()[name] = module
+        return module
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_MODULES.keys())
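A brief sketch of what this lazy loading (PEP 562 module-level __getattr__) means for callers:

    import smftools.machine_learning as ml

    # Heavy submodules are not imported until first attribute access;
    # __getattr__ imports them on demand and caches them in the module globals.
    models = ml.models          # triggers import of smftools.machine_learning.models
    assert ml.models is models  # subsequent access returns the cached module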
smftools/machine_learning/data/__init__.py
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .anndata_data_module import AnnDataModule, build_anndata_loader
 from .preprocessing import random_fill_nans
smftools/machine_learning/data/anndata_data_module.py
@@ -1,24 +1,48 @@
-import torch
-from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
-import pytorch_lightning as pl
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
+
+from smftools.optional_imports import require
+
 from .preprocessing import random_fill_nans
-from sklearn.utils.class_weight import compute_class_weight
 
-
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning data modules")
+torch = require("torch", extra="ml-base", purpose="ML data loading")
+sklearn_class_weight = require(
+    "sklearn.utils.class_weight",
+    extra="ml-base",
+    purpose="class weighting",
+)
+torch_utils_data = require("torch.utils.data", extra="ml-base", purpose="ML data loading")
+
+compute_class_weight = sklearn_class_weight.compute_class_weight
+DataLoader = torch_utils_data.DataLoader
+Dataset = torch_utils_data.Dataset
+Subset = torch_utils_data.Subset
+
+
 class AnnDataDataset(Dataset):
     """
     Generic PyTorch Dataset from AnnData.
     """
-    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col=None, window_start=None, window_size=None):
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col=None,
+        window_start=None,
+        window_size=None,
+    ):
         self.adata = adata
         self.tensor_source = tensor_source
         self.tensor_key = tensor_key
         self.label_col = label_col
         self.window_start = window_start
         self.window_size = window_size
-
+
         if tensor_source == "X":
             X = adata.X
         elif tensor_source == "layers":
@@ -29,17 +53,17 @@ class AnnDataDataset(Dataset):
             X = adata.obsm[tensor_key]
         else:
             raise ValueError(f"Invalid tensor_source: {tensor_source}")
-
+
         if self.window_start is not None and self.window_size is not None:
             X = X[:, self.window_start : self.window_start + self.window_size]
-
+
         X = random_fill_nans(X)
 
         self.X_tensor = torch.tensor(X, dtype=torch.float32)
 
         if label_col is not None:
             y = adata.obs[label_col]
-            if y.dtype.name == 'category':
+            if y.dtype.name == "category":
                 y = y.cat.codes
             self.y_tensor = torch.tensor(y.values, dtype=torch.long)
         else:
@@ -47,7 +71,7 @@
 
     def numpy(self, indices):
         return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()
-
+
     def __len__(self):
         return len(self.X_tensor)
 
@@ -60,9 +84,17 @@ class AnnDataDataset(Dataset):
         return (x,)
 
 
-def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
-                  random_seed=42, split_col="train_val_test_split",
-                  load_existing_split=False, split_save_path=None):
+def split_dataset(
+    adata,
+    dataset,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    split_col="train_val_test_split",
+    load_existing_split=False,
+    split_save_path=None,
+):
     """
     Perform split and record assignment into adata.obs[split_col].
     """
@@ -87,7 +119,7 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     split_array = np.full(total_len, "test", dtype=object)
     split_array[indices[:n_train]] = "train"
-    split_array[indices[n_train:n_train + n_val]] = "val"
+    split_array[indices[n_train : n_train + n_val]] = "val"
     adata.obs[split_col] = split_array
 
     if split_save_path:
@@ -104,14 +136,32 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     return train_set, val_set, test_set
 
+
 class AnnDataModule(pl.LightningDataModule):
     """
     Unified LightningDataModule version of AnnDataDataset + splitting with adata.obs recording.
     """
-    def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
-                 batch_size=64, train_frac=0.6, val_frac=0.1, test_frac=0.3, random_seed=42,
-                 inference_mode=False, split_col="train_val_test_split", split_save_path=None,
-                 load_existing_split=False, window_start=None, window_size=None, num_workers=None, persistent_workers=False):
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col="labels",
+        batch_size=64,
+        train_frac=0.6,
+        val_frac=0.1,
+        test_frac=0.3,
+        random_seed=42,
+        inference_mode=False,
+        split_col="train_val_test_split",
+        split_save_path=None,
+        load_existing_split=False,
+        window_start=None,
+        window_size=None,
+        num_workers=None,
+        persistent_workers=False,
+    ):
         super().__init__()
         self.adata = adata
         self.tensor_source = tensor_source
@@ -133,52 +183,80 @@ class AnnDataModule(pl.LightningDataModule):
         self.persistent_workers = persistent_workers
 
     def setup(self, stage=None):
-        dataset = AnnDataDataset(self.adata, self.tensor_source, self.tensor_key,
-                                 None if self.inference_mode else self.label_col,
-                                 window_start=self.window_start, window_size=self.window_size)
+        dataset = AnnDataDataset(
+            self.adata,
+            self.tensor_source,
+            self.tensor_key,
+            None if self.inference_mode else self.label_col,
+            window_start=self.window_start,
+            window_size=self.window_size,
+        )
 
         if self.inference_mode:
             self.infer_dataset = dataset
             return
 
         self.train_set, self.val_set, self.test_set = split_dataset(
-            self.adata, dataset, train_frac=self.train_frac, val_frac=self.val_frac,
-            test_frac=self.test_frac, random_seed=self.random_seed,
-            split_col=self.split_col, split_save_path=self.split_save_path,
-            load_existing_split=self.load_existing_split
+            self.adata,
+            dataset,
+            train_frac=self.train_frac,
+            val_frac=self.val_frac,
+            test_frac=self.test_frac,
+            random_seed=self.random_seed,
+            split_col=self.split_col,
+            split_save_path=self.split_save_path,
+            load_existing_split=self.load_existing_split,
         )
 
     def train_dataloader(self):
         if self.num_workers:
-            return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+            return DataLoader(
+                self.train_set,
+                batch_size=self.batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
         if self.num_workers:
-            return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+            return DataLoader(
+                self.val_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def test_dataloader(self):
         if self.num_workers:
-            return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=self.persistent_workers)
+            return DataLoader(
+                self.test_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def predict_dataloader(self):
         if not self.inference_mode:
             raise RuntimeError("Only valid in inference mode")
         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
-
+
     def compute_class_weights(self):
-        train_indices = self.train_set.indices # get the indices of the training set
-        y_all = self.train_set.dataset.y_tensor # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
-        y_train = y_all[train_indices].cpu().numpy() # get the labels for the training set and move to a numpy array
-
-        class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
+        train_indices = self.train_set.indices  # get the indices of the training set
+        y_all = self.train_set.dataset.y_tensor  # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
+        y_train = (
+            y_all[train_indices].cpu().numpy()
+        )  # get the labels for the training set and move to a numpy array
+
+        class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
         return torch.tensor(class_weights, dtype=torch.float32)
-
+
     def inference_numpy(self):
         """
         Return inference data as numpy for use in sklearn inference.
@@ -187,7 +265,7 @@ class AnnDataModule(pl.LightningDataModule):
             raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
         X_np = self.infer_dataset.X_tensor.numpy()
         return X_np
-
+
     def to_numpy(self):
         """
         Move the AnnDataModule tensors into numpy arrays
@@ -202,9 +280,20 @@ class AnnDataModule(pl.LightningDataModule):
 
 
 def build_anndata_loader(
-    adata, tensor_source="X", tensor_key=None, label_col=None, train_frac=0.6, val_frac=0.1,
-    test_frac=0.3, random_seed=42, batch_size=64, lightning=True, inference_mode=False,
-    split_col="train_val_test_split", split_save_path=None, load_existing_split=False
+    adata,
+    tensor_source="X",
+    tensor_key=None,
+    label_col=None,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    batch_size=64,
+    lightning=True,
+    inference_mode=False,
+    split_col="train_val_test_split",
+    split_save_path=None,
+    load_existing_split=False,
 ):
     """
     Unified pipeline for both Lightning and raw PyTorch.
@@ -213,22 +302,40 @@ def build_anndata_loader(
     """
     if lightning:
         return AnnDataModule(
-            adata, tensor_source=tensor_source, tensor_key=tensor_key, label_col=label_col,
-            batch_size=batch_size, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac,
-            random_seed=random_seed, inference_mode=inference_mode,
-            split_col=split_col, split_save_path=split_save_path, load_existing_split=load_existing_split
+            adata,
+            tensor_source=tensor_source,
+            tensor_key=tensor_key,
+            label_col=label_col,
+            batch_size=batch_size,
+            train_frac=train_frac,
+            val_frac=val_frac,
+            test_frac=test_frac,
+            random_seed=random_seed,
+            inference_mode=inference_mode,
+            split_col=split_col,
+            split_save_path=split_save_path,
+            load_existing_split=load_existing_split,
        )
     else:
         var_names = adata.var_names.copy()
-        dataset = AnnDataDataset(adata, tensor_source, tensor_key, None if inference_mode else label_col)
+        dataset = AnnDataDataset(
+            adata, tensor_source, tensor_key, None if inference_mode else label_col
+        )
         if inference_mode:
             return DataLoader(dataset, batch_size=batch_size)
         else:
             train_set, val_set, test_set = split_dataset(
-                adata, dataset, train_frac, val_frac, test_frac, random_seed,
-                split_col, split_save_path, load_existing_split
+                adata,
+                dataset,
+                train_frac,
+                val_frac,
+                test_frac,
+                random_seed,
+                split_col,
+                split_save_path,
+                load_existing_split,
             )
             train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
             val_loader = DataLoader(val_set, batch_size=batch_size)
             test_loader = DataLoader(test_set, batch_size=batch_size)
-            return train_loader, val_loader, test_loader
+            return train_loader, val_loader, test_loader
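A minimal usage sketch of the reworked data module, assuming an AnnData file with a categorical "labels" column in .obs (the file name and column are hypothetical):

    import anndata as ad
    from smftools.machine_learning.data import build_anndata_loader

    adata = ad.read_h5ad("reads.h5ad")  # hypothetical input

    # Lightning path: returns an AnnDataModule; split assignments are written
    # to adata.obs["train_val_test_split"] during setup().
    dm = build_anndata_loader(adata, tensor_source="X", label_col="labels", batch_size=64)
    dm.setup()
    train_loader = dm.train_dataloader()
    class_weights = dm.compute_class_weights()  # balanced weights over the training split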
smftools/machine_learning/data/preprocessing.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import numpy as np
 
+
 def random_fill_nans(X):
     nan_mask = np.isnan(X)
     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
-    return X
+    return X
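For reference, the helper fills NaNs in place with uniform [0, 1) draws and returns the same array, e.g.:

    import numpy as np
    from smftools.machine_learning.data.preprocessing import random_fill_nans

    X = np.array([[0.0, np.nan], [1.0, 0.5]])
    X = random_fill_nans(X)  # NaN entries replaced with np.random.rand() values
    assert not np.isnan(X).any()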
smftools/machine_learning/evaluation/__init__.py
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
+from .eval_utils import flatten_sliding_window_results
 from .evaluators import ModelEvaluator, PostInferenceModelEvaluator
-from .eval_utils import flatten_sliding_window_results
smftools/machine_learning/evaluation/eval_utils.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 import pandas as pd
 
+
 def flatten_sliding_window_results(results_dict):
     """
     Flatten nested sliding window results into pandas DataFrame.
-
-    Expects structure:
+
+    Expects structure:
         results[model_name][window_size][window_center]['metrics'][metric_name]
     """
     records = []
@@ -12,20 +15,16 @@ def flatten_sliding_window_results(results_dict):
     for model_name, model_results in results_dict.items():
         for window_size, window_results in model_results.items():
             for center_var, result in window_results.items():
-                metrics = result['metrics']
-                record = {
-                    'model': model_name,
-                    'window_size': window_size,
-                    'center_var': center_var
-                }
+                metrics = result["metrics"]
+                record = {"model": model_name, "window_size": window_size, "center_var": center_var}
                 # Add all metrics
                 record.update(metrics)
                 records.append(record)
-
+
     df = pd.DataFrame.from_records(records)
-
+
     # Convert center_var to numeric if possible (optional but helpful for plotting)
-    df['center_var'] = pd.to_numeric(df['center_var'], errors='coerce')
-    df = df.sort_values(['model', 'window_size', 'center_var'])
-
-    return df
+    df["center_var"] = pd.to_numeric(df["center_var"], errors="coerce")
+    df = df.sort_values(["model", "window_size", "center_var"])
+
+    return df
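A small sketch of the expected input shape and the flattened output (model names, metric names, and values are made-up placeholders):

    from smftools.machine_learning.evaluation import flatten_sliding_window_results

    # results[model_name][window_size][window_center]['metrics'][metric_name]
    results = {
        "cnn": {
            50: {
                "100": {"metrics": {"accuracy": 0.91}},
                "150": {"metrics": {"accuracy": 0.88}},
            },
        },
    }

    df = flatten_sliding_window_results(results)
    # -> columns: model, window_size, center_var, accuracy; sorted by model/window_size/center_var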