smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0

smftools/informatics/run_multiqc.py

```diff
@@ -1,16 +1,23 @@
-
-
-
+from __future__ import annotations
+
+from pathlib import Path
+
+from smftools.logging_utils import get_logger
 
-
-        - input_dir (str): Path to the directory containing QC reports (e.g., FastQC, Samtools, bcftools outputs).
-        - output_dir (str): Path to the directory where MultiQC reports should be saved.
+logger = get_logger(__name__)
 
-
-
+
+def run_multiqc(input_dir: str | Path, output_dir: str | Path) -> None:
+    """Run MultiQC on a directory and save the report to the output directory.
+
+    Args:
+        input_dir: Path to the directory containing QC reports (e.g., FastQC, Samtools outputs).
+        output_dir: Path to the directory where MultiQC reports should be saved.
     """
-    from ..readwrite import make_dirs
     import subprocess
+
+    from ..readwrite import make_dirs
+
     # Ensure the output directory exists
     make_dirs(output_dir)
 
@@ -20,12 +27,11 @@ def run_multiqc(input_dir, output_dir):
     # Construct MultiQC command
     command = ["multiqc", input_dir, "-o", output_dir]
 
-
-
+    logger.info(f"Running MultiQC on '{input_dir}' and saving results to '{output_dir}'...")
+
     # Run MultiQC
    try:
         subprocess.run(command, check=True)
-
+        logger.info(f"MultiQC report generated successfully in: {output_dir}")
     except subprocess.CalledProcessError as e:
-
-
+        logger.error(f"Error running MultiQC: {e}")
```
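
A minimal usage sketch of the refactored runner, assuming MultiQC is installed on the PATH and that logging is configured once by the CLI entrypoint; the directory paths below are placeholders, not values from the diff.

```python
# Hypothetical usage sketch (paths are placeholders, not from the diff).
from smftools.logging_utils import setup_logging
from smftools.informatics.run_multiqc import run_multiqc

setup_logging()  # configure the package-wide "smftools" logger once, as the CLI entrypoint would
run_multiqc("results/qc_reports", "results/multiqc")  # shells out to `multiqc <input> -o <output>`
```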

smftools/logging_utils.py (new file)

```diff
@@ -0,0 +1,51 @@
+"""Logging utilities for smftools."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Optional, Union
+
+DEFAULT_LOG_FORMAT = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
+DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+
+def setup_logging(
+    level: int = logging.INFO,
+    fmt: str = DEFAULT_LOG_FORMAT,
+    datefmt: str = DEFAULT_DATE_FORMAT,
+    log_file: Optional[Union[str, Path]] = None,
+) -> None:
+    """
+    Configure logging for smftools.
+
+    Should be called once by the CLI entrypoint.
+    Safe to call multiple times.
+    """
+    logger = logging.getLogger("smftools")
+
+    if logger.handlers:
+        return
+
+    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
+
+    # Console handler (stderr)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+
+    # Optional file handler
+    if log_file is not None:
+        log_path = Path(log_file)
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = logging.FileHandler(log_path)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    logger.setLevel(level)
+    logger.propagate = False
+
+
+def get_logger(name: str) -> logging.Logger:
+    return logging.getLogger(name)
```
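
A short sketch of how the new logging utilities compose; the log-file path below is an illustrative assumption.

```python
# Sketch only: the logs/ directory is an assumed, writable location.
import logging

from smftools.logging_utils import get_logger, setup_logging

setup_logging(level=logging.DEBUG, log_file="logs/smftools.log")  # console handler plus optional file handler
log = get_logger("smftools.example")  # child of the configured "smftools" logger, so it inherits its handlers
log.info("logging configured")  # rendered as "[timestamp] [INFO] [smftools.example]: logging configured"
```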

smftools/machine_learning/__init__.py

```diff
@@ -1,12 +1,23 @@
-from
-
-from
-
-
-
-
-
-"
-"
-"
-
+from __future__ import annotations
+
+from importlib import import_module
+
+_LAZY_MODULES = {
+    "data": "smftools.machine_learning.data",
+    "evaluation": "smftools.machine_learning.evaluation",
+    "inference": "smftools.machine_learning.inference",
+    "models": "smftools.machine_learning.models",
+    "training": "smftools.machine_learning.training",
+    "utils": "smftools.machine_learning.utils",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_MODULES:
+        module = import_module(_LAZY_MODULES[name])
+        globals()[name] = module
+        return module
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_MODULES.keys())
```
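
The rewritten package `__init__` defers submodule imports through a module-level `__getattr__` (PEP 562). A small sketch of the resulting behaviour, assuming the relevant optional ML dependencies are installed for the submodule being touched:

```python
# Sketch of the lazy-import behaviour; attribute names come from _LAZY_MODULES above.
import smftools.machine_learning as ml

models = ml.models          # first access calls import_module(...) and caches the module in the package globals
assert ml.models is models  # later accesses hit the cached attribute and bypass __getattr__
# ml.does_not_exist         # would raise AttributeError from the __getattr__ fallback
```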

smftools/machine_learning/data/anndata_data_module.py

```diff
@@ -1,24 +1,48 @@
-import
-
-import pytorch_lightning as pl
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
+
+from smftools.optional_imports import require
+
 from .preprocessing import random_fill_nans
-from sklearn.utils.class_weight import compute_class_weight
 
-
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning data modules")
+torch = require("torch", extra="ml-base", purpose="ML data loading")
+sklearn_class_weight = require(
+    "sklearn.utils.class_weight",
+    extra="ml-base",
+    purpose="class weighting",
+)
+torch_utils_data = require("torch.utils.data", extra="ml-base", purpose="ML data loading")
+
+compute_class_weight = sklearn_class_weight.compute_class_weight
+DataLoader = torch_utils_data.DataLoader
+Dataset = torch_utils_data.Dataset
+Subset = torch_utils_data.Subset
+
+
 class AnnDataDataset(Dataset):
     """
     Generic PyTorch Dataset from AnnData.
     """
-
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col=None,
+        window_start=None,
+        window_size=None,
+    ):
         self.adata = adata
         self.tensor_source = tensor_source
         self.tensor_key = tensor_key
         self.label_col = label_col
         self.window_start = window_start
         self.window_size = window_size
-
+
         if tensor_source == "X":
             X = adata.X
         elif tensor_source == "layers":
@@ -29,17 +53,17 @@ class AnnDataDataset(Dataset):
             X = adata.obsm[tensor_key]
         else:
             raise ValueError(f"Invalid tensor_source: {tensor_source}")
-
+
         if self.window_start is not None and self.window_size is not None:
             X = X[:, self.window_start : self.window_start + self.window_size]
-
+
         X = random_fill_nans(X)
 
         self.X_tensor = torch.tensor(X, dtype=torch.float32)
 
         if label_col is not None:
             y = adata.obs[label_col]
-            if y.dtype.name ==
+            if y.dtype.name == "category":
                 y = y.cat.codes
             self.y_tensor = torch.tensor(y.values, dtype=torch.long)
         else:
@@ -47,7 +71,7 @@ class AnnDataDataset(Dataset):
 
     def numpy(self, indices):
         return self.X_tensor[indices].numpy(), self.y_tensor[indices].numpy()
-
+
     def __len__(self):
         return len(self.X_tensor)
 
@@ -60,9 +84,17 @@ class AnnDataDataset(Dataset):
             return (x,)
 
 
-def split_dataset(
-
-
+def split_dataset(
+    adata,
+    dataset,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    split_col="train_val_test_split",
+    load_existing_split=False,
+    split_save_path=None,
+):
     """
     Perform split and record assignment into adata.obs[split_col].
     """
@@ -87,7 +119,7 @@ def split_dataset(adata, dataset, train_frac=0.6, val_frac=0.1, test_frac=0.3,
 
     split_array = np.full(total_len, "test", dtype=object)
     split_array[indices[:n_train]] = "train"
-    split_array[indices[n_train:n_train + n_val]] = "val"
+    split_array[indices[n_train : n_train + n_val]] = "val"
     adata.obs[split_col] = split_array
 
     if split_save_path:
@@ -104,14 +136,32 @@
 
     return train_set, val_set, test_set
 
+
 class AnnDataModule(pl.LightningDataModule):
     """
     Unified LightningDataModule version of AnnDataDataset + splitting with adata.obs recording.
     """
-
-
-
-
+
+    def __init__(
+        self,
+        adata,
+        tensor_source="X",
+        tensor_key=None,
+        label_col="labels",
+        batch_size=64,
+        train_frac=0.6,
+        val_frac=0.1,
+        test_frac=0.3,
+        random_seed=42,
+        inference_mode=False,
+        split_col="train_val_test_split",
+        split_save_path=None,
+        load_existing_split=False,
+        window_start=None,
+        window_size=None,
+        num_workers=None,
+        persistent_workers=False,
+    ):
         super().__init__()
         self.adata = adata
         self.tensor_source = tensor_source
@@ -133,52 +183,80 @@ class AnnDataModule(pl.LightningDataModule):
         self.persistent_workers = persistent_workers
 
     def setup(self, stage=None):
-        dataset = AnnDataDataset(
-
-
+        dataset = AnnDataDataset(
+            self.adata,
+            self.tensor_source,
+            self.tensor_key,
+            None if self.inference_mode else self.label_col,
+            window_start=self.window_start,
+            window_size=self.window_size,
+        )
 
         if self.inference_mode:
             self.infer_dataset = dataset
             return
 
         self.train_set, self.val_set, self.test_set = split_dataset(
-            self.adata,
-
-
-
+            self.adata,
+            dataset,
+            train_frac=self.train_frac,
+            val_frac=self.val_frac,
+            test_frac=self.test_frac,
+            random_seed=self.random_seed,
+            split_col=self.split_col,
+            split_save_path=self.split_save_path,
+            load_existing_split=self.load_existing_split,
         )
 
     def train_dataloader(self):
         if self.num_workers:
-            return DataLoader(
+            return DataLoader(
+                self.train_set,
+                batch_size=self.batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
         if self.num_workers:
-            return DataLoader(
+            return DataLoader(
+                self.val_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def test_dataloader(self):
         if self.num_workers:
-            return DataLoader(
+            return DataLoader(
+                self.test_set,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                persistent_workers=self.persistent_workers,
+            )
         else:
             return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=False)
-
+
     def predict_dataloader(self):
         if not self.inference_mode:
             raise RuntimeError("Only valid in inference mode")
         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
-
+
     def compute_class_weights(self):
-        train_indices = self.train_set.indices
-        y_all = self.train_set.dataset.y_tensor
-        y_train =
-
-
+        train_indices = self.train_set.indices  # get the indices of the training set
+        y_all = self.train_set.dataset.y_tensor  # get labels for the entire dataset (We are pulling from a Subset object, so this syntax can be confusing)
+        y_train = (
+            y_all[train_indices].cpu().numpy()
+        )  # get the labels for the training set and move to a numpy array
+
+        class_weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
         return torch.tensor(class_weights, dtype=torch.float32)
-
+
     def inference_numpy(self):
         """
         Return inference data as numpy for use in sklearn inference.
@@ -187,7 +265,7 @@ class AnnDataModule(pl.LightningDataModule):
             raise RuntimeError("Must be in inference_mode=True to use inference_numpy()")
         X_np = self.infer_dataset.X_tensor.numpy()
         return X_np
-
+
     def to_numpy(self):
         """
         Move the AnnDataModule tensors into numpy arrays
@@ -202,9 +280,20 @@ class AnnDataModule(pl.LightningDataModule):
 
 
 def build_anndata_loader(
-    adata,
-
-
+    adata,
+    tensor_source="X",
+    tensor_key=None,
+    label_col=None,
+    train_frac=0.6,
+    val_frac=0.1,
+    test_frac=0.3,
+    random_seed=42,
+    batch_size=64,
+    lightning=True,
+    inference_mode=False,
+    split_col="train_val_test_split",
+    split_save_path=None,
+    load_existing_split=False,
 ):
     """
     Unified pipeline for both Lightning and raw PyTorch.
@@ -213,22 +302,40 @@ def build_anndata_loader(
     """
     if lightning:
         return AnnDataModule(
-            adata,
-
-
-
+            adata,
+            tensor_source=tensor_source,
+            tensor_key=tensor_key,
+            label_col=label_col,
+            batch_size=batch_size,
+            train_frac=train_frac,
+            val_frac=val_frac,
+            test_frac=test_frac,
+            random_seed=random_seed,
+            inference_mode=inference_mode,
+            split_col=split_col,
+            split_save_path=split_save_path,
+            load_existing_split=load_existing_split,
        )
     else:
         var_names = adata.var_names.copy()
-        dataset = AnnDataDataset(
+        dataset = AnnDataDataset(
+            adata, tensor_source, tensor_key, None if inference_mode else label_col
+        )
         if inference_mode:
            return DataLoader(dataset, batch_size=batch_size)
         else:
            train_set, val_set, test_set = split_dataset(
-                adata,
-
+                adata,
+                dataset,
+                train_frac,
+                val_frac,
+                test_frac,
+                random_seed,
+                split_col,
+                split_save_path,
+                load_existing_split,
            )
            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_set, batch_size=batch_size)
            test_loader = DataLoader(test_set, batch_size=batch_size)
-            return train_loader, val_loader, test_loader
+            return train_loader, val_loader, test_loader
```
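
A hedged end-to-end sketch of the reworked data module, assuming the optional ML extras (torch, pytorch_lightning, scikit-learn) are installed; the toy AnnData object and the label values below are illustrative only, not part of the package.

```python
# Illustrative sketch only: toy data, not a real SMF dataset.
import anndata as ad
import numpy as np

from smftools.machine_learning.data.anndata_data_module import AnnDataModule

rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.random((100, 50)).astype(np.float32))
adata.obs["labels"] = rng.choice(["bound", "unbound"], size=100)
adata.obs["labels"] = adata.obs["labels"].astype("category")  # category codes become the torch.long labels

dm = AnnDataModule(adata, tensor_source="X", label_col="labels", batch_size=32)
dm.setup()                                  # builds AnnDataDataset, splits 60/10/30, records adata.obs["train_val_test_split"]
class_weights = dm.compute_class_weights()  # balanced class weights computed over the training subset
first_batch = next(iter(dm.train_dataloader()))  # standard PyTorch DataLoader batches
```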

Sliding-window results helper (`flatten_sliding_window_results`)

```diff
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 import pandas as pd
 
+
 def flatten_sliding_window_results(results_dict):
     """
     Flatten nested sliding window results into pandas DataFrame.
-
-    Expects structure:
+
+    Expects structure:
         results[model_name][window_size][window_center]['metrics'][metric_name]
     """
     records = []
@@ -12,20 +15,16 @@ def flatten_sliding_window_results(results_dict):
     for model_name, model_results in results_dict.items():
         for window_size, window_results in model_results.items():
             for center_var, result in window_results.items():
-                metrics = result[
-                record = {
-                    'model': model_name,
-                    'window_size': window_size,
-                    'center_var': center_var
-                }
+                metrics = result["metrics"]
+                record = {"model": model_name, "window_size": window_size, "center_var": center_var}
                 # Add all metrics
                 record.update(metrics)
                 records.append(record)
-
+
     df = pd.DataFrame.from_records(records)
-
+
     # Convert center_var to numeric if possible (optional but helpful for plotting)
-    df[
-    df = df.sort_values([
-
-    return df
+    df["center_var"] = pd.to_numeric(df["center_var"], errors="coerce")
+    df = df.sort_values(["model", "window_size", "center_var"])
+
+    return df
```
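
For orientation, a sketch of the nested structure this helper expects and of the flattened frame it returns; the model name, window size, and metric values are invented, and the import location is whichever module defines `flatten_sliding_window_results` (the hunks above do not name the file).

```python
# Invented example data; only the nesting results[model][window_size][center]["metrics"] matters.
results = {
    "cnn": {
        50: {
            "1200": {"metrics": {"accuracy": 0.91, "auroc": 0.95}},
            "1300": {"metrics": {"accuracy": 0.88, "auroc": 0.93}},
        },
    },
}

df = flatten_sliding_window_results(results)
# df columns: model, window_size, center_var (coerced to numeric), accuracy, auroc
# rows sorted by ["model", "window_size", "center_var"]
```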