smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
--- a/smftools/informatics/run_multiqc.py
+++ b/smftools/informatics/run_multiqc.py
@@ -1,16 +1,23 @@
-
-
-
+from __future__ import annotations
+
+from pathlib import Path
+
+from smftools.logging_utils import get_logger
 
-
-    - input_dir (str): Path to the directory containing QC reports (e.g., FastQC, Samtools, bcftools outputs).
-    - output_dir (str): Path to the directory where MultiQC reports should be saved.
+logger = get_logger(__name__)
 
-
-
+
+def run_multiqc(input_dir: str | Path, output_dir: str | Path) -> None:
+    """Run MultiQC on a directory and save the report to the output directory.
+
+    Args:
+        input_dir: Path to the directory containing QC reports (e.g., FastQC, Samtools outputs).
+        output_dir: Path to the directory where MultiQC reports should be saved.
     """
-    from ..readwrite import make_dirs
     import subprocess
+
+    from ..readwrite import make_dirs
+
     # Ensure the output directory exists
     make_dirs(output_dir)
 
@@ -20,12 +27,11 @@ def run_multiqc(input_dir, output_dir):
     # Construct MultiQC command
     command = ["multiqc", input_dir, "-o", output_dir]
 
-
-
+    logger.info(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
+
     # Run MultiQC
     try:
         subprocess.run(command, check=True)
-
+        logger.info(f"MultiQC report generated successfully in: {output_dir}")
     except subprocess.CalledProcessError as e:
-
-
+        logger.error(f"Error running MultiQC: {e}")
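For orientation, a minimal sketch of calling the retyped function. The paths are illustrative and the import path is inferred from the file list above; only the signature comes from the diff:

```python
from smftools.informatics.run_multiqc import run_multiqc

# Aggregate FastQC/Samtools outputs under qc_reports/ into one MultiQC report.
# Either str or Path satisfies the new `str | Path` signature.
run_multiqc("qc_reports", "qc_reports/multiqc")
```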
--- /dev/null
+++ b/smftools/logging_utils.py
@@ -0,0 +1,51 @@
+"""Logging utilities for smftools."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Optional, Union
+
+DEFAULT_LOG_FORMAT = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
+DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+
+def setup_logging(
+    level: int = logging.INFO,
+    fmt: str = DEFAULT_LOG_FORMAT,
+    datefmt: str = DEFAULT_DATE_FORMAT,
+    log_file: Optional[Union[str, Path]] = None,
+) -> None:
+    """
+    Configure logging for smftools.
+
+    Should be called once by the CLI entrypoint.
+    Safe to call multiple times.
+    """
+    logger = logging.getLogger("smftools")
+
+    if logger.handlers:
+        return
+
+    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
+
+    # Console handler (stderr)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+
+    # Optional file handler
+    if log_file is not None:
+        log_path = Path(log_file)
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = logging.FileHandler(log_path)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    logger.setLevel(level)
+    logger.propagate = False
+
+
+def get_logger(name: str) -> logging.Logger:
+    return logging.getLogger(name)
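A minimal sketch of how a caller might use the new module. The script context and log-file path are hypothetical; `setup_logging` and `get_logger` come from the diff above:

```python
import logging

from smftools.logging_utils import get_logger, setup_logging

# Configure the shared "smftools" logger once (e.g., in a CLI entrypoint);
# repeated calls are no-ops because existing handlers short-circuit setup.
setup_logging(level=logging.DEBUG, log_file="logs/smftools.log")  # hypothetical path

# Child loggers propagate up to the "smftools" logger's handlers.
log = get_logger("smftools.example")
log.info("pipeline started")
```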
--- a/smftools/machine_learning/__init__.py
+++ b/smftools/machine_learning/__init__.py
@@ -1,12 +1,7 @@
-from . import models
-from . import data
-from . import utils
-from . import evaluation
-from . import inference
-from . import training
+from . import data, evaluation, inference, models, training, utils
 
 __all__ = [
     "calculate_relative_risk_on_activity",
     "evaluate_models_by_subgroup",
     "prepare_melted_model_data",
-]
+]
--- a/smftools/machine_learning/data/anndata_data_module.py
+++ b/smftools/machine_learning/data/anndata_data_module.py
@@ -1,24 +1,34 @@
-import torch
-from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset, Subset
-import pytorch_lightning as pl
 import numpy as np
 import pandas as pd
-
+import pytorch_lightning as pl
+import torch
 from sklearn.utils.class_weight import compute_class_weight
+from torch.utils.data import DataLoader, Dataset, Subset
+
+from .preprocessing import random_fill_nans
+
 
-
 class AnnDataDataset(Dataset):
     """
     Generic PyTorch Dataset from AnnData.
     """
-
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col=None,
+        window_start=None,
+        window_size=None,
+    ):
         self.adata = adata
         self.tensor_source = tensor_source
         self.tensor_key = tensor_key
         self.label_col = label_col
         self.window_start = window_start
         self.window_size = window_size
-
+
         if tensor_source == "X":
             X = adata.X
         elif tensor_source == "layers":
@@ -29,17 +39,17 @@ class AnnDataDataset(Dataset):
             X = adata.obsm[tensor_key]
         else:
             raise ValueError(f"Invalid tensor_source: {tensor_source}")
-
+
         if self.window_start is not None and self.window_size is not None:
             X = X[:, self.window_start : self.window_start + self.window_size]
-
+
         X = random_fill_nans(X)
 
         self.X_tensor = torch.tensor(X, dtype=torch.float32)
 
         if label_col is not None:
             y = adata.obs[label_col]
-            if y.dtype.name ==
+            if y.dtype.name == "category":
                 y = y.cat.codes
             self.y_tensor = torch.tensor(y.values, dtype=torch.long)
         else:
@@ -47,7 +57,7 @@ class AnnDataDataset(Dataset):
 
     def numpy(self, indices):
         return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()
-
+
     def __len__(self):
         return len(self.X_tensor)
 
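To make the expanded constructor concrete, a toy sketch of building the dataset from an AnnData object. The data and label values are fabricated for illustration, and the import path is inferred from the file list; the constructor arguments follow the signature in the diff:

```python
import anndata as ad
import numpy as np

from smftools.machine_learning.data.anndata_data_module import AnnDataDataset

# Toy AnnData: 100 reads x 50 positions, with a categorical label column.
adata = ad.AnnData(X=np.random.rand(100, 50))
adata.obs["labels"] = np.random.choice(["bound", "unbound"], size=100)
adata.obs["labels"] = adata.obs["labels"].astype("category")

# Slice a 20-position window starting at position 10; categorical labels
# are converted to integer codes (y.cat.codes above) before tensor creation.
ds = AnnDataDataset(
    adata,
    tensor_source="X",
    label_col="labels",
    window_start=10,
    window_size=20,
)
print(len(ds), ds.X_tensor.shape)  # 100 torch.Size([100, 20])
```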
@@ -60,9 +70,17 @@ class AnnDataDataset(Dataset):
         return (x,)
 
 
-def split_dataset(
-
-
+def split_dataset(
+    adata,
+    dataset,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    split_col="train_val_test_split",
+    load_existing_split=False,
+    split_save_path=None,
+):
     """
     Perform split and record assignment into adata.obs[split_col].
     """
@@ -87,7 +105,7 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     split_array = np.full(total_len, "test", dtype=object)
     split_array[indices[:n_train]] = "train"
-    split_array[indices[n_train:n_train + n_val]] = "val"
+    split_array[indices[n_train : n_train + n_val]] = "val"
     adata.obs[split_col] = split_array
 
     if split_save_path:
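The assignment logic above defaults every read to "test", then overwrites the first `n_train` shuffled indices with "train" and the next `n_val` with "val". A standalone sketch of that bookkeeping; the `n_train`/`n_val` arithmetic is an assumption, since the hunk only shows the assignment lines:

```python
import numpy as np

total_len = 10
train_frac, val_frac = 0.6, 0.1
n_train = int(total_len * train_frac)  # 6 (assumed rounding)
n_val = int(total_len * val_frac)      # 1; the remaining 3 reads stay "test"

indices = np.random.default_rng(42).permutation(total_len)
split_array = np.full(total_len, "test", dtype=object)
split_array[indices[:n_train]] = "train"
split_array[indices[n_train : n_train + n_val]] = "val"

print(sorted(split_array))  # 6 x 'train', 1 x 'val', 3 x 'test'
```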
@@ -104,14 +122,32 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     return train_set, val_set, test_set
 
+
 class AnnDataModule(pl.LightningDataModule):
     """
     Unified LightningDataModule version of AnnDataDataset + splitting with adata.obs recording.
     """
-
-
-
-
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col="labels",
+        batch_size=64,
+        train_frac=0.6,
+        val_frac=0.1,
+        test_frac=0.3,
+        random_seed=42,
+        inference_mode=False,
+        split_col="train_val_test_split",
+        split_save_path=None,
+        load_existing_split=False,
+        window_start=None,
+        window_size=None,
+        num_workers=None,
+        persistent_workers=False,
+    ):
         super().__init__()
         self.adata = adata
         self.tensor_source = tensor_source
@@ -133,52 +169,80 @@ class AnnDataModule(pl.LightningDataModule):
         self.persistent_workers = persistent_workers
 
     def setup(self, stage=None):
-        dataset = AnnDataDataset(
-
-
+        dataset = AnnDataDataset(
+            self.adata,
+            self.tensor_source,
+            self.tensor_key,
+            None if self.inference_mode else self.label_col,
+            window_start=self.window_start,
+            window_size=self.window_size,
+        )
 
         if self.inference_mode:
             self.infer_dataset = dataset
             return
 
         self.train_set, self.val_set, self.test_set = split_dataset(
-            self.adata,
-
-
-
+            self.adata,
+            dataset,
+            train_frac=self.train_frac,
+            val_frac=self.val_frac,
+            test_frac=self.test_frac,
+            random_seed=self.random_seed,
+            split_col=self.split_col,
+            split_save_path=self.split_save_path,
+            load_existing_split=self.load_existing_split,
         )
 
     def train_dataloader(self):
         if self.num_workers:
-            return DataLoader(
+            return DataLoader(
+                self.train_set,
+                batch_size=self.batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
        if self.num_workers:
-            return DataLoader(
+            return DataLoader(
+                self.val_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def test_dataloader(self):
         if self.num_workers:
-            return DataLoader(
+            return DataLoader(
+                self.test_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def predict_dataloader(self):
         if not self.inference_mode:
             raise RuntimeError("Only valid in inference mode")
         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
-
+
     def compute_class_weights(self):
-        train_indices = self.train_set.indices
-        y_all = self.train_set.dataset.y_tensor
-        y_train =
-
-
+        train_indices = self.train_set.indices  # get the indices of the training set
+        y_all = self.train_set.dataset.y_tensor  # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
+        y_train = (
+            y_all[train_indices].cpu().numpy()
+        )  # get the labels for the training set and move to a numpy array
+
+        class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
         return torch.tensor(class_weights, dtype=torch.float32)
-
+
     def inference_numpy(self):
         """
         Return inference data as numpy for use in sklearn inference.
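A sketch of driving the datamodule end to end; `adata` is the toy object from the AnnDataDataset sketch and the Trainer wiring is hypothetical. Note that the fallback (`num_workers` unset) branches of `val_dataloader` and `test_dataloader` above still return `self.train_set`, so this sketch passes `num_workers` explicitly:

```python
import pytorch_lightning as pl

from smftools.machine_learning.data.anndata_data_module import AnnDataModule

dm = AnnDataModule(
    adata,                     # toy AnnData from the earlier sketch
    label_col="labels",
    batch_size=128,
    num_workers=4,
    persistent_workers=True,
)
dm.setup()  # builds the dataset and records the 60/10/30 split in adata.obs

# "balanced" weights are n_samples / (n_classes * bincount(y_train)),
# ready to hand to a weighted loss when one methylation class dominates.
weights = dm.compute_class_weights()

trainer = pl.Trainer(max_epochs=5)
# trainer.fit(model, datamodule=dm)  # `model` is any compatible LightningModule
```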
@@ -187,7 +251,7 @@ class AnnDataModule(pl.LightningDataModule):
             raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
         X_np = self.infer_dataset.X_tensor.numpy()
         return X_np
-
+
     def to_numpy(self):
         """
         Move the AnnDataModule tensors into numpy arrays
@@ -202,9 +266,20 @@ class AnnDataModule(pl.LightningDataModule):
 
 
 def build_anndata_loader(
-    adata,
-
-
+    adata,
+    tensor_source="X",
+    tensor_key=None,
+    label_col=None,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    batch_size=64,
+    lightning=True,
+    inference_mode=False,
+    split_col="train_val_test_split",
+    split_save_path=None,
+    load_existing_split=False,
 ):
     """
     Unified pipeline for both Lightning and raw PyTorch.
@@ -213,22 +288,40 @@ def build_anndata_loader(
     """
     if lightning:
         return AnnDataModule(
-            adata,
-
-
-
+            adata,
+            tensor_source=tensor_source,
+            tensor_key=tensor_key,
+            label_col=label_col,
+            batch_size=batch_size,
+            train_frac=train_frac,
+            val_frac=val_frac,
+            test_frac=test_frac,
+            random_seed=random_seed,
+            inference_mode=inference_mode,
+            split_col=split_col,
+            split_save_path=split_save_path,
+            load_existing_split=load_existing_split,
         )
     else:
         var_names = adata.var_names.copy()
-        dataset = AnnDataDataset(
+        dataset = AnnDataDataset(
+            adata, tensor_source, tensor_key, None if inference_mode else label_col
+        )
         if inference_mode:
             return DataLoader(dataset, batch_size=batch_size)
         else:
             train_set, val_set, test_set = split_dataset(
-                adata,
-
+                adata,
+                dataset,
+                train_frac,
+                val_frac,
+                test_frac,
+                random_seed,
+                split_col,
+                split_save_path,
+                load_existing_split,
             )
             train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
             val_loader = DataLoader(val_set, batch_size=batch_size)
             test_loader = DataLoader(test_set, batch_size=batch_size)
-            return train_loader, val_loader, test_loader
+            return train_loader, val_loader, test_loader
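The raw-PyTorch branch returns three plain DataLoaders instead of a datamodule. A usage sketch (import path inferred from the file list, `adata` reused from the earlier sketch); note that this branch forwards the trailing `split_dataset` arguments positionally, so their order must line up with the `split_dataset` signature shown earlier:

```python
from smftools.machine_learning.data.anndata_data_module import build_anndata_loader

train_loader, val_loader, test_loader = build_anndata_loader(
    adata,                 # toy AnnData from the earlier sketch
    label_col="labels",
    batch_size=32,
    lightning=False,       # skip Lightning, get raw DataLoaders
)
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # e.g. torch.Size([32, 50]) torch.Size([32])
```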
--- a/smftools/machine_learning/evaluation/eval_utils.py
+++ b/smftools/machine_learning/evaluation/eval_utils.py
@@ -1,10 +1,11 @@
 import pandas as pd
 
+
 def flatten_sliding_window_results(results_dict):
     """
     Flatten nested sliding window results into pandas DataFrame.
-
-    Expects structure:
+
+    Expects structure:
     results[model_name][window_size][window_center]['metrics'][metric_name]
     """
     records = []
@@ -12,20 +13,16 @@ def flatten_sliding_window_results(results_dict):
     for model_name, model_results in results_dict.items():
         for window_size, window_results in model_results.items():
             for center_var, result in window_results.items():
-                metrics = result[
-                record = {
-                    'model': model_name,
-                    'window_size': window_size,
-                    'center_var': center_var
-                }
+                metrics = result["metrics"]
+                record = {"model": model_name, "window_size": window_size, "center_var": center_var}
                 # Add all metrics
                 record.update(metrics)
                 records.append(record)
-
+
     df = pd.DataFrame.from_records(records)
-
+
     # Convert center_var to numeric if possible (optional but helpful for plotting)
-    df[
-    df = df.sort_values([
-
-    return df
+    df["center_var"] = pd.to_numeric(df["center_var"], errors="coerce")
+    df = df.sort_values(["model", "window_size", "center_var"])
+
+    return df
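A worked example of the documented nesting and the flattened frame it produces; the model name and metric values are fabricated:

```python
from smftools.machine_learning.evaluation.eval_utils import flatten_sliding_window_results

# results[model_name][window_size][window_center]['metrics'][metric_name]
results = {
    "cnn": {
        50: {
            "125": {"metrics": {"f1": 0.81, "auc": 0.90}},
            "175": {"metrics": {"f1": 0.77, "auc": 0.88}},
        },
    },
}

df = flatten_sliding_window_results(results)
print(df)
#   model  window_size  center_var    f1   auc
# 0   cnn           50       125.0  0.81  0.90
# 1   cnn           50       175.0  0.77  0.88
```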
--- a/smftools/machine_learning/evaluation/evaluators.py
+++ b/smftools/machine_learning/evaluation/evaluators.py
@@ -1,15 +1,21 @@
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
-
 from sklearn.metrics import (
-
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+    roc_curve,
 )
 
+
 class ModelEvaluator:
     """
     A model evaluator for consolidating Sklearn and Lightning model evaluation metrics on testing data
     """
+
     def __init__(self):
         self.results = []
         self.pos_freq = None
@@ -21,41 +27,45 @@ class ModelEvaluator:
         """
         if is_torch:
             entry = {
-
-
-
-
-
-
-
-
-
+                "name": name,
+                "f1": model.test_f1,
+                "auc": model.test_roc_auc,
+                "pr_auc": model.test_pr_auc,
+                "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                if model.test_pos_freq > 0
+                else np.nan,
+                "pr_curve": model.test_pr_curve,
+                "roc_curve": model.test_roc_curve,
+                "num_pos": model.test_num_pos,
+                "pos_freq": model.test_pos_freq,
             }
         else:
             entry = {
-
-
-
-
-
-
-
-
-
+                "name": name,
+                "f1": model.test_f1,
+                "auc": model.test_roc_auc,
+                "pr_auc": model.test_pr_auc,
+                "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                if model.test_pos_freq > 0
+                else np.nan,
+                "pr_curve": model.test_pr_curve,
+                "roc_curve": model.test_roc_curve,
+                "num_pos": model.test_num_pos,
+                "pos_freq": model.test_pos_freq,
             }
-
+
         self.results.append(entry)
 
         if not self.pos_freq:
-            self.pos_freq = entry[
-            self.num_pos = entry[
+            self.pos_freq = entry["pos_freq"]
+            self.num_pos = entry["num_pos"]
 
     def get_metrics_dataframe(self):
         """
         Return all metrics as pandas DataFrame.
         """
         df = pd.DataFrame(self.results)
-        return df[[
+        return df[["name", "f1", "auc", "pr_auc", "pr_auc_norm", "num_pos", "pos_freq"]]
 
     def plot_all_curves(self):
         """
@@ -66,30 +76,31 @@ class ModelEvaluator:
         # ROC
         plt.subplot(1, 2, 1)
         for res in self.results:
-            fpr, tpr = res[
+            fpr, tpr = res["roc_curve"]
             plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
         plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
         plt.xlabel("False Positive Rate")
         plt.ylabel("True Positive Rate")
-        plt.ylim(0,1.05)
+        plt.ylim(0, 1.05)
         plt.title(f"ROC Curves - {self.num_pos} positive instances")
         plt.legend()
 
         # PR
         plt.subplot(1, 2, 2)
         for res in self.results:
-            rc, pr = res[
+            rc, pr = res["pr_curve"]
             plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
-        plt.ylim(0,1.05)
-        plt.axhline(self.pos_freq, linestyle=
+        plt.ylim(0, 1.05)
+        plt.axhline(self.pos_freq, linestyle="--", color="grey")
         plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
         plt.legend()
 
         plt.tight_layout()
         plt.show()
 
+
 class PostInferenceModelEvaluator:
     def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
         """
@@ -179,12 +190,14 @@ class PostInferenceModelEvaluator:
             "pos_freq": pos_freq,
             "confusion_matrix": cm,
             "pr_rc_curve": (pr, rc),
-            "roc_curve": (tpr, fpr)
+            "roc_curve": (tpr, fpr),
         }
 
         return metrics
-
-    def _subsample_for_fixed_positive_frequency(
+
+    def _subsample_for_fixed_positive_frequency(
+        self, binary_labels, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(binary_labels == 1)[0]
         neg_idx = np.where(binary_labels == 0)[0]
 
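Finally, a sketch of how the evaluator consumes trained models. The hunks above do not show the recording method's name, so `add_result(name, model, is_torch)` is a stand-in; each model is expected to expose the `test_*` attributes read inside the entry dicts:

```python
from smftools.machine_learning.evaluation.evaluators import ModelEvaluator

evaluator = ModelEvaluator()
evaluator.add_result("cnn", lightning_model, is_torch=True)   # hypothetical method name
evaluator.add_result("logreg", sklearn_model, is_torch=False)

print(evaluator.get_metrics_dataframe())  # name, f1, auc, pr_auc, pr_auc_norm, num_pos, pos_freq
evaluator.plot_all_curves()               # paired ROC / precision-recall panels
```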