scdlkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdlkit/__init__.py +20 -0
- scdlkit/data/__init__.py +6 -0
- scdlkit/data/datasets.py +30 -0
- scdlkit/data/prepare.py +262 -0
- scdlkit/data/schemas.py +35 -0
- scdlkit/data/splits.py +90 -0
- scdlkit/evaluation/__init__.py +6 -0
- scdlkit/evaluation/compare.py +69 -0
- scdlkit/evaluation/evaluator.py +34 -0
- scdlkit/evaluation/metrics.py +83 -0
- scdlkit/evaluation/report.py +40 -0
- scdlkit/models/__init__.py +20 -0
- scdlkit/models/autoencoder.py +43 -0
- scdlkit/models/base.py +22 -0
- scdlkit/models/blocks.py +32 -0
- scdlkit/models/classifier.py +30 -0
- scdlkit/models/denoising.py +37 -0
- scdlkit/models/registry.py +33 -0
- scdlkit/models/transformer.py +73 -0
- scdlkit/models/vae.py +61 -0
- scdlkit/runner.py +278 -0
- scdlkit/tasks/__init__.py +14 -0
- scdlkit/tasks/base.py +40 -0
- scdlkit/tasks/classification.py +28 -0
- scdlkit/tasks/reconstruction.py +41 -0
- scdlkit/tasks/representation.py +14 -0
- scdlkit/training/__init__.py +5 -0
- scdlkit/training/callbacks.py +12 -0
- scdlkit/training/trainer.py +176 -0
- scdlkit/utils/__init__.py +7 -0
- scdlkit/utils/device.py +13 -0
- scdlkit/utils/io.py +13 -0
- scdlkit/utils/seed.py +18 -0
- scdlkit/visualization/__init__.py +15 -0
- scdlkit/visualization/classification.py +26 -0
- scdlkit/visualization/compare.py +32 -0
- scdlkit/visualization/latent.py +36 -0
- scdlkit/visualization/reconstruction.py +24 -0
- scdlkit/visualization/training.py +21 -0
- scdlkit-0.1.0.dist-info/METADATA +265 -0
- scdlkit-0.1.0.dist-info/RECORD +44 -0
- scdlkit-0.1.0.dist-info/WHEEL +5 -0
- scdlkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- scdlkit-0.1.0.dist-info/top_level.txt +1 -0
scdlkit/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Public package surface for scDLKit."""
|
|
2
|
+
|
|
3
|
+
from scdlkit.data import PreparedData, prepare_data
|
|
4
|
+
from scdlkit.evaluation.compare import BenchmarkResult, compare_models
|
|
5
|
+
from scdlkit.models import BaseModel, create_model
|
|
6
|
+
from scdlkit.runner import TaskRunner
|
|
7
|
+
from scdlkit.training import Trainer
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"BaseModel",
|
|
11
|
+
"BenchmarkResult",
|
|
12
|
+
"PreparedData",
|
|
13
|
+
"TaskRunner",
|
|
14
|
+
"Trainer",
|
|
15
|
+
"compare_models",
|
|
16
|
+
"create_model",
|
|
17
|
+
"prepare_data",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
__version__ = "0.1.0"
|
scdlkit/data/__init__.py
ADDED
scdlkit/data/datasets.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""PyTorch datasets backed by dense or sparse matrices."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from scipy import sparse
|
|
8
|
+
from torch.utils.data import Dataset
|
|
9
|
+
|
|
10
|
+
from scdlkit.data.schemas import SplitData
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AnnDataset(Dataset[dict[str, torch.Tensor]]):
    """Map-style dataset over a :class:`SplitData`.

    Each item is a dict holding the row as a dense float32 tensor under
    ``"x"``, plus ``"y"`` / ``"batch"`` long tensors when the split carries
    encoded labels / batch ids. Sparse rows are densified lazily, per access.
    """

    def __init__(self, split: SplitData):
        self.split = split

    def __len__(self) -> int:
        return len(self.split)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        row = self.split.X[index]
        # Densify one row at a time so a sparse matrix never fully materializes.
        if sparse.issparse(row):
            dense = row.toarray().ravel()
        else:
            dense = np.asarray(row).ravel()
        sample: dict[str, torch.Tensor] = {
            "x": torch.as_tensor(dense, dtype=torch.float32)
        }
        labels = self.split.labels
        if labels is not None:
            sample["y"] = torch.as_tensor(int(labels[index]), dtype=torch.long)
        batches = self.split.batches
        if batches is not None:
            sample["batch"] = torch.as_tensor(int(batches[index]), dtype=torch.long)
        return sample
|
scdlkit/data/prepare.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""AnnData preparation and transformation utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from anndata import AnnData
|
|
9
|
+
from scipy import sparse
|
|
10
|
+
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
|
11
|
+
|
|
12
|
+
from scdlkit.data.schemas import PreparedData, SplitData
|
|
13
|
+
from scdlkit.data.splits import build_splits
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _require_scanpy() -> Any:
|
|
17
|
+
try:
|
|
18
|
+
import scanpy as sc
|
|
19
|
+
except ImportError as exc:
|
|
20
|
+
msg = "scanpy-backed preprocessing requires `pip install scdlkit[scanpy]`."
|
|
21
|
+
raise ImportError(msg) from exc
|
|
22
|
+
return sc
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _encode_obs(values: np.ndarray | None) -> tuple[np.ndarray | None, dict[str, int] | None]:
|
|
26
|
+
if values is None:
|
|
27
|
+
return None, None
|
|
28
|
+
encoder = LabelEncoder()
|
|
29
|
+
encoded = encoder.fit_transform(values.astype(str))
|
|
30
|
+
mapping = {label: int(index) for index, label in enumerate(encoder.classes_)}
|
|
31
|
+
return encoded, mapping
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _transform_obs(values: np.ndarray | None, mapping: dict[str, int] | None) -> np.ndarray | None:
|
|
35
|
+
if values is None or mapping is None:
|
|
36
|
+
return None
|
|
37
|
+
encoded = np.empty(values.shape[0], dtype=int)
|
|
38
|
+
for index, value in enumerate(values.astype(str)):
|
|
39
|
+
if value not in mapping:
|
|
40
|
+
msg = f"Encountered unseen label '{value}' during transform."
|
|
41
|
+
raise ValueError(msg)
|
|
42
|
+
encoded[index] = mapping[value]
|
|
43
|
+
return encoded
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _extract_matrix(adata: AnnData, layer: str) -> Any:
|
|
47
|
+
if layer == "X":
|
|
48
|
+
return adata.X
|
|
49
|
+
if layer not in adata.layers:
|
|
50
|
+
msg = f"Layer '{layer}' not found in AnnData.layers."
|
|
51
|
+
raise ValueError(msg)
|
|
52
|
+
return adata.layers[layer]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _to_split_data(
    x_matrix: Any,
    indices: np.ndarray,
    *,
    labels: np.ndarray | None,
    batches: np.ndarray | None,
    obs_names: list[str],
) -> SplitData:
    """Slice matrix, labels, batches, and obs names down to ``indices``."""
    sliced_labels = None if labels is None else labels[indices]
    sliced_batches = None if batches is None else batches[indices]
    selected_names = [obs_names[i] for i in indices]
    return SplitData(
        X=x_matrix[indices],
        labels=sliced_labels,
        batches=sliced_batches,
        obs_names=selected_names,
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _prepare_matrix(
    adata: AnnData,
    *,
    layer: str,
    use_hvg: bool,
    n_top_genes: int,
    normalize: bool,
    log1p: bool,
    scale: bool,
) -> tuple[AnnData, Any, list[str], StandardScaler | None]:
    """Apply optional scanpy preprocessing and extract the working matrix.

    Returns ``(adata, matrix, feature_names, scaler)``; ``scaler`` is the
    fitted StandardScaler when ``scale`` is True, else None. Scanpy is only
    imported when one of the scanpy-backed steps is requested.
    """
    working = adata
    if normalize or log1p or use_hvg:
        sc = _require_scanpy()
        if normalize:
            sc.pp.normalize_total(working)
        if log1p:
            sc.pp.log1p(working)
        if use_hvg:
            sc.pp.highly_variable_genes(working, n_top_genes=n_top_genes, subset=True)
    x_matrix = _extract_matrix(working, layer)
    scaler: StandardScaler | None = None
    if scale:
        # Centering a sparse matrix would densify it; only center dense input.
        scaler = StandardScaler(with_mean=not sparse.issparse(x_matrix))
        x_matrix = scaler.fit_transform(x_matrix)
    if sparse.issparse(x_matrix):
        x_matrix = x_matrix.tocsr()
    names = working.var_names.astype(str).tolist()
    return working, x_matrix, names, scaler
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _obs_column(adata: AnnData, key: str | None, *, key_name: str) -> np.ndarray | None:
    """Fetch ``adata.obs[key]`` as a string array, or None when ``key`` is None.

    ``key_name`` is the parameter name (``"label_key"`` / ``"batch_key"``)
    used in the error message when ``key`` is given but absent from ``obs``.
    """
    if key is None:
        return None
    if key not in adata.obs:
        msg = f"{key_name} '{key}' not found in adata.obs."
        raise ValueError(msg)
    return adata.obs[key].astype(str).to_numpy()


def prepare_data(
    adata: AnnData,
    *,
    layer: str = "X",
    use_hvg: bool = False,
    n_top_genes: int = 2000,
    normalize: bool = False,
    log1p: bool = False,
    scale: bool = False,
    label_key: str | None = None,
    batch_key: str | None = None,
    val_size: float = 0.15,
    test_size: float = 0.15,
    batch_aware_split: bool = False,
    random_state: int = 42,
    copy: bool = True,
) -> PreparedData:
    """Prepare AnnData splits and preprocessing metadata.

    Runs optional scanpy preprocessing (normalize/log1p/HVG) and scaling,
    encodes optional label/batch obs columns, builds train/val/test index
    splits, and packages everything (plus the settings needed to re-apply
    the same preprocessing via ``transform_adata``) into a ``PreparedData``.

    Raises
    ------
    ValueError
        If ``label_key`` or ``batch_key`` is given but missing from
        ``adata.obs``, or if the requested layer does not exist.
    """
    working = adata.copy() if copy else adata
    working, x_matrix, feature_names, scaler = _prepare_matrix(
        working,
        layer=layer,
        use_hvg=use_hvg,
        n_top_genes=n_top_genes,
        normalize=normalize,
        log1p=log1p,
        scale=scale,
    )
    # Optional obs columns; _obs_column raises when a requested key is absent.
    labels_raw = _obs_column(working, label_key, key_name="label_key")
    batches_raw = _obs_column(working, batch_key, key_name="batch_key")
    labels, label_encoder = _encode_obs(labels_raw)
    batches, batch_encoder = _encode_obs(batches_raw)

    split_indices = build_splits(
        working.n_obs,
        val_size=val_size,
        test_size=test_size,
        random_state=random_state,
        stratify=labels if label_key is not None else None,
        groups=batches if batch_aware_split and batch_key is not None else None,
    )
    obs_names = working.obs_names.astype(str).tolist()
    train = _to_split_data(
        x_matrix,
        split_indices.train,
        labels=labels,
        batches=batches,
        obs_names=obs_names,
    )
    # Empty index arrays (size 0) mean the split was not requested.
    val = (
        _to_split_data(
            x_matrix,
            split_indices.val,
            labels=labels,
            batches=batches,
            obs_names=obs_names,
        )
        if split_indices.val.size
        else None
    )
    test = (
        _to_split_data(
            x_matrix,
            split_indices.test,
            labels=labels,
            batches=batches,
            obs_names=obs_names,
        )
        if split_indices.test.size
        else None
    )
    # Everything transform_adata needs to reproduce this preprocessing.
    preprocessing = {
        "layer": layer,
        "use_hvg": use_hvg,
        "n_top_genes": n_top_genes,
        "normalize": normalize,
        "log1p": log1p,
        "scale": scale,
        "scaler": scaler,
        "feature_names": feature_names,
        "label_key": label_key,
        "batch_key": batch_key,
        "batch_aware_split": batch_aware_split,
    }
    return PreparedData(
        train=train,
        val=val,
        test=test,
        input_dim=int(x_matrix.shape[1]),
        feature_names=feature_names,
        label_encoder=label_encoder,
        batch_encoder=batch_encoder,
        preprocessing=preprocessing,
    )
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def transform_adata(
    adata: AnnData,
    preprocessing: dict[str, Any],
    *,
    label_encoder: dict[str, int] | None = None,
    batch_encoder: dict[str, int] | None = None,
    copy: bool = True,
) -> SplitData:
    """Transform new AnnData using stored preprocessing metadata.

    Re-applies the same normalize/log1p/HVG pipeline recorded by
    ``prepare_data`` (scaling is deferred and applied with the *fitted*
    scaler, not refit), aligns features to the training feature set, and
    encodes obs columns with the provided training-time mappings.

    Raises ValueError when required training features are missing, or when
    an obs column contains labels unseen during fitting.
    """

    working = adata.copy() if copy else adata
    # scale=False: the matrix must be scaled with the stored, already-fitted
    # scaler below, never refit on the new data.
    working, x_matrix, _, _ = _prepare_matrix(
        working,
        layer=preprocessing["layer"],
        use_hvg=preprocessing["use_hvg"],
        n_top_genes=preprocessing["n_top_genes"],
        normalize=preprocessing["normalize"],
        log1p=preprocessing["log1p"],
        scale=False,
    )
    feature_names = preprocessing["feature_names"]
    if list(working.var_names.astype(str)) != feature_names:
        missing = sorted(set(feature_names) - set(working.var_names.astype(str)))
        if missing:
            # Only the first few missing features are shown to keep the message short.
            msg = f"AnnData is missing required features: {missing[:5]}"
            raise ValueError(msg)
        # Same feature set, different order/extras: reorder to the training
        # layout and re-extract the matrix from the reindexed view.
        working = working[:, feature_names].copy()
        x_matrix = _extract_matrix(working, preprocessing["layer"])
    scaler = preprocessing.get("scaler")
    if scaler is not None:
        x_matrix = scaler.transform(x_matrix)
    if sparse.issparse(x_matrix):
        # Normalize sparse output to CSR, mirroring _prepare_matrix.
        x_matrix = x_matrix.tocsr()
    labels = None
    if preprocessing["label_key"] is not None and preprocessing["label_key"] in working.obs:
        labels = _transform_obs(
            working.obs[preprocessing["label_key"]].astype(str).to_numpy(),
            label_encoder,
        )
    batches = None
    if preprocessing["batch_key"] is not None and preprocessing["batch_key"] in working.obs:
        batches = _transform_obs(
            working.obs[preprocessing["batch_key"]].astype(str).to_numpy(),
            batch_encoder,
        )
    return SplitData(
        X=x_matrix,
        labels=labels,
        batches=batches,
        obs_names=working.obs_names.astype(str).tolist(),
    )
|
scdlkit/data/schemas.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Dataclasses for prepared datasets and preprocessing metadata."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
class SplitData:
    """One dataset split with optional encoded labels and batches."""

    # Expression matrix (dense array or scipy sparse); rows are observations.
    X: Any
    # Integer-encoded labels aligned with rows of X, or None when absent.
    labels: np.ndarray | None = None
    # Integer-encoded batch ids aligned with rows of X, or None when absent.
    batches: np.ndarray | None = None
    # Observation names, one per row of X.
    obs_names: list[str] = field(default_factory=list)

    def __len__(self) -> int:
        # Number of observations (rows) in this split.
        return int(self.X.shape[0])
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(slots=True)
class PreparedData:
    """Prepared train/validation/test splits and metadata."""

    # Training split; always present.
    train: SplitData
    # Validation split, or None when no validation fraction was requested.
    val: SplitData | None
    # Test split, or None when no test fraction was requested.
    test: SplitData | None
    # Number of features (columns) in each split's matrix.
    input_dim: int
    # Feature names after preprocessing (e.g. HVG subsetting).
    feature_names: list[str]
    # Label string -> integer code mapping, or None when no labels were used.
    label_encoder: dict[str, int] | None
    # Batch string -> integer code mapping, or None when no batches were used.
    batch_encoder: dict[str, int] | None
    # Settings (and fitted scaler) needed to re-apply the same preprocessing.
    preprocessing: dict[str, Any]
|
scdlkit/data/splits.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Split helpers for prepared AnnData workflows."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from sklearn.model_selection import GroupShuffleSplit, train_test_split
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
class SplitIndices:
    """Row-index arrays selecting the train/val/test partition."""

    # Indices of training rows.
    train: np.ndarray
    # Indices of validation rows (size-0 array when val_size is 0).
    val: np.ndarray
    # Indices of test rows (size-0 array when test_size is 0).
    test: np.ndarray
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def build_splits(
    n_samples: int,
    *,
    val_size: float,
    test_size: float,
    random_state: int,
    stratify: np.ndarray | None = None,
    groups: np.ndarray | None = None,
) -> SplitIndices:
    """Create train/validation/test indices.

    Splits in two stages: first train vs. holdout (val+test combined), then
    holdout into val vs. test. When ``groups`` is given, GroupShuffleSplit
    keeps whole groups together; NOTE(review): the grouped path ignores
    ``stratify`` entirely — confirm that is intended.

    Raises ValueError for negative fractions or when they sum to >= 1.
    """

    all_indices = np.arange(n_samples)
    if val_size < 0 or test_size < 0 or val_size + test_size >= 1:
        msg = "val_size and test_size must be >= 0 and sum to less than 1"
        raise ValueError(msg)

    holdout_fraction = val_size + test_size
    if holdout_fraction == 0:
        # Nothing held out: everything is training data.
        return SplitIndices(
            train=all_indices,
            val=np.array([], dtype=int),
            test=np.array([], dtype=int),
        )

    # Stage 1: separate the combined holdout from the training set.
    if groups is not None:
        splitter = GroupShuffleSplit(
            n_splits=1,
            test_size=holdout_fraction,
            random_state=random_state,
        )
        train_idx, holdout_idx = next(splitter.split(all_indices, groups=groups))
    else:
        train_idx, holdout_idx = train_test_split(
            all_indices,
            test_size=holdout_fraction,
            random_state=random_state,
            stratify=stratify,
        )

    if test_size == 0:
        # Whole holdout is validation.
        return SplitIndices(
            train=np.sort(train_idx),
            val=np.sort(holdout_idx),
            test=np.array([], dtype=int),
        )

    holdout_stratify = stratify[holdout_idx] if stratify is not None else None
    if val_size == 0:
        # Whole holdout is test.
        return SplitIndices(
            train=np.sort(train_idx),
            val=np.array([], dtype=int),
            test=np.sort(holdout_idx),
        )

    # Stage 2: split the holdout into val/test; test_fraction is relative
    # to the holdout, not the full dataset.
    test_fraction = test_size / holdout_fraction
    if groups is not None:
        holdout_groups = groups[holdout_idx]
        splitter = GroupShuffleSplit(
            n_splits=1,
            test_size=test_fraction,
            random_state=random_state,
        )
        val_rel, test_rel = next(splitter.split(holdout_idx, groups=holdout_groups))
    else:
        val_rel, test_rel = train_test_split(
            np.arange(holdout_idx.size),
            test_size=test_fraction,
            random_state=random_state,
            stratify=holdout_stratify,
        )
    # val_rel/test_rel are positions within holdout_idx; map back to row ids.
    val_idx = holdout_idx[val_rel]
    test_idx = holdout_idx[test_rel]
    return SplitIndices(train=np.sort(train_idx), val=np.sort(val_idx), test=np.sort(test_idx))
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Compare multiple models on the same AnnData workflow."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from scdlkit.evaluation.report import save_markdown_report
|
|
11
|
+
from scdlkit.utils import ensure_directory
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(slots=True)
class BenchmarkResult:
    """Collected results from comparing multiple models."""

    # One row per model: model name plus its scalar evaluation metrics.
    metrics_frame: pd.DataFrame
    # Fitted runner per model name, kept for further inspection/prediction.
    runners: dict[str, Any]
    # Saved artifact paths (CSV/markdown/PNG) when output_dir was given; else empty.
    output_paths: dict[str, str] = field(default_factory=dict)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def compare_models(
    adata: Any,
    *,
    models: list[str],
    task: str,
    shared_kwargs: dict[str, Any] | None = None,
    output_dir: str | None = None,
) -> BenchmarkResult:
    """Train and evaluate several models with shared configuration.

    Each model name is fitted and evaluated with identical ``shared_kwargs``
    on the same ``adata``; only scalar metrics are collected into the
    comparison table. When ``output_dir`` is set, a metrics CSV, a markdown
    report, and a comparison plot are saved there.
    """

    # Imported locally — presumably to avoid a circular import with
    # scdlkit.runner / visualization; confirm before moving to module level.
    from scdlkit.runner import TaskRunner
    from scdlkit.visualization.compare import plot_model_comparison

    # Defensive copy so the caller's dict is never mutated by TaskRunner.
    shared = dict(shared_kwargs or {})
    records: list[dict[str, Any]] = []
    runners: dict[str, TaskRunner] = {}
    output_paths: dict[str, str] = {}
    for model_name in models:
        runner = TaskRunner(model=model_name, task=task, **shared)
        runner.fit(adata)
        metrics = runner.evaluate()
        # Keep only scalar metrics; structured values (e.g. confusion
        # matrices) do not fit a flat comparison table.
        scalar_metrics = {k: v for k, v in metrics.items() if isinstance(v, (int, float))}
        records.append({"model": model_name, **scalar_metrics})
        runners[model_name] = runner

    metrics_frame = pd.DataFrame.from_records(records).sort_values("model").reset_index(drop=True)
    if output_dir is not None:
        directory = ensure_directory(output_dir)
        csv_path = directory / "benchmark_metrics.csv"
        md_path = directory / "benchmark_report.md"
        png_path = directory / "benchmark_comparison.png"
        metrics_frame.to_csv(csv_path, index=False)
        # NOTE(review): the figure is saved but never closed — repeated calls
        # may accumulate open matplotlib figures; consider closing it.
        fig, _ = plot_model_comparison(metrics_frame)
        fig.savefig(png_path, dpi=150, bbox_inches="tight")
        report_lines = ["## Compared models", "", *[f"- `{name}`" for name in models]]
        save_markdown_report(
            {"num_models": len(models), "task": task},
            path=md_path,
            title="Benchmark Report",
            extra_sections=report_lines,
        )
        output_paths = {
            "metrics_csv": str(csv_path),
            "report_md": str(md_path),
            "comparison_png": str(png_path),
        }
    return BenchmarkResult(metrics_frame=metrics_frame, runners=runners, output_paths=output_paths)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Task-aware evaluation entrypoints."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from scdlkit.evaluation.metrics import (
|
|
10
|
+
classification_metrics,
|
|
11
|
+
reconstruction_metrics,
|
|
12
|
+
representation_metrics,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def evaluate_predictions(task: str, predictions: dict[str, np.ndarray]) -> dict[str, Any]:
    """Evaluate model predictions for a task.

    Parameters
    ----------
    task:
        ``"classification"``, ``"reconstruction"``, or ``"representation"``.
        Any non-classification task is evaluated with reconstruction metrics;
        ``"representation"`` additionally adds latent-space metrics.
    predictions:
        Arrays keyed by name — ``"y"``/``"logits"`` for classification,
        ``"x"``/``"reconstruction"`` (and optionally ``"latent"``, ``"y"``,
        ``"batch"``) otherwise.

    Raises
    ------
    ValueError
        When a required prediction key is missing (previously surfaced as an
        unhelpful KeyError for everything except missing labels).
    """
    if task == "classification":
        if "y" not in predictions:
            msg = "Classification evaluation requires encoded labels."
            raise ValueError(msg)
        if "logits" not in predictions:
            msg = "Classification evaluation requires model logits."
            raise ValueError(msg)
        return classification_metrics(predictions["y"], predictions["logits"])

    missing = [key for key in ("x", "reconstruction") if key not in predictions]
    if missing:
        msg = f"Reconstruction evaluation requires keys: {missing}"
        raise ValueError(msg)
    metrics = reconstruction_metrics(predictions["x"], predictions["reconstruction"])
    if task == "representation":
        # Fall back to a zero-width latent so metric helpers receive an array.
        metrics.update(
            representation_metrics(
                predictions.get("latent", np.empty((predictions["x"].shape[0], 0))),
                predictions.get("y"),
                predictions.get("batch"),
            )
        )
    return metrics
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Metric helpers for reconstruction, representation, and classification."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.stats import pearsonr, spearmanr
|
|
10
|
+
from sklearn.cluster import KMeans
|
|
11
|
+
from sklearn.metrics import (
|
|
12
|
+
accuracy_score,
|
|
13
|
+
adjusted_rand_score,
|
|
14
|
+
confusion_matrix,
|
|
15
|
+
f1_score,
|
|
16
|
+
normalized_mutual_info_score,
|
|
17
|
+
silhouette_score,
|
|
18
|
+
)
|
|
19
|
+
from sklearn.neighbors import NearestNeighbors
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _safe_correlation(
|
|
23
|
+
func: Callable[[np.ndarray, np.ndarray], tuple[float, float]],
|
|
24
|
+
y_true: np.ndarray,
|
|
25
|
+
y_pred: np.ndarray,
|
|
26
|
+
) -> float:
|
|
27
|
+
flat_true = np.ravel(y_true)
|
|
28
|
+
flat_pred = np.ravel(y_pred)
|
|
29
|
+
if np.std(flat_true) == 0 or np.std(flat_pred) == 0:
|
|
30
|
+
return 0.0
|
|
31
|
+
corr, _ = func(flat_true, flat_pred)
|
|
32
|
+
return 0.0 if math.isnan(corr) else float(corr)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def reconstruction_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """MSE, MAE, and flattened Pearson/Spearman correlations."""
    residual = y_true - y_pred
    metrics = {
        "mse": float(np.mean(np.square(residual))),
        "mae": float(np.mean(np.abs(residual))),
    }
    metrics["pearson"] = _safe_correlation(pearsonr, y_true, y_pred)
    metrics["spearman"] = _safe_correlation(spearmanr, y_true, y_pred)
    return metrics
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def knn_label_consistency(latent: np.ndarray, labels: np.ndarray, n_neighbors: int = 10) -> float:
    """Fraction of points whose k-NN majority label matches their own.

    Returns 0.0 for degenerate input (a single point or a single label class).
    The query point itself is excluded from its neighbor set.
    """
    if latent.shape[0] <= 1 or len(np.unique(labels)) < 2:
        return 0.0
    # +1 because the nearest neighbor of a point is the point itself.
    k = min(n_neighbors + 1, latent.shape[0])
    model = NearestNeighbors(n_neighbors=k)
    model.fit(latent)
    neighbor_idx = model.kneighbors(latent, return_distance=False)[:, 1:]
    neighbor_labels = labels[neighbor_idx]
    majority = np.array([np.bincount(row).argmax() for row in neighbor_labels])
    return float((majority == labels).mean())
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def representation_metrics(
    latent: np.ndarray,
    labels: np.ndarray | None,
    batches: np.ndarray | None,
) -> dict[str, float]:
    """Clustering/structure metrics for a latent embedding.

    Label-based metrics (silhouette, kNN consistency, ARI/NMI via KMeans)
    require at least two label classes and more points than classes; the
    batch silhouette has the analogous requirement for batches. Metrics
    whose preconditions fail are simply omitted.
    """
    metrics: dict[str, float] = {}
    n_points = latent.shape[0]
    if labels is not None:
        label_classes = np.unique(labels)
        if 1 < len(label_classes) < n_points:
            metrics["silhouette"] = float(silhouette_score(latent, labels))
            metrics["knn_label_consistency"] = knn_label_consistency(latent, labels)
            kmeans = KMeans(n_clusters=len(label_classes), random_state=42, n_init="auto")
            clusters = kmeans.fit_predict(latent)
            metrics["ari"] = float(adjusted_rand_score(labels, clusters))
            metrics["nmi"] = float(normalized_mutual_info_score(labels, clusters))
    if batches is not None:
        batch_classes = np.unique(batches)
        if 1 < len(batch_classes) < n_points:
            metrics["batch_silhouette"] = float(silhouette_score(latent, batches))
    return metrics
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def classification_metrics(y_true: np.ndarray, logits: np.ndarray) -> dict[str, object]:
    """Accuracy, macro-F1, and confusion matrix from raw class logits."""
    y_hat = np.argmax(logits, axis=1)
    results: dict[str, object] = {
        "accuracy": float(accuracy_score(y_true, y_hat)),
        "macro_f1": float(f1_score(y_true, y_hat, average="macro")),
    }
    # Confusion matrix as nested lists so the result is JSON-serializable.
    results["confusion_matrix"] = confusion_matrix(y_true, y_hat).tolist()
    return results
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Report export helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _scalar_metrics(metrics: dict[str, Any]) -> dict[str, float]:
|
|
12
|
+
return {key: float(value) for key, value in metrics.items() if isinstance(value, (int, float))}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def save_metrics_table(metrics: dict[str, Any], path: str | Path) -> Path:
    """Write scalar metrics to CSV.

    Non-scalar entries (lists, matrices) are dropped; returns the path written.
    """
    destination = Path(path)
    pd.DataFrame([_scalar_metrics(metrics)]).to_csv(destination, index=False)
    return destination
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def save_markdown_report(
    metrics: dict[str, Any],
    *,
    path: str | Path,
    title: str,
    extra_sections: list[str] | None = None,
) -> Path:
    """Write a markdown report with scalar and structured metrics.

    Renders a title, a bulleted "Metrics" section, then any extra markdown
    lines; returns the path written.
    """
    destination = Path(path)
    body = [f"# {title}", "", "## Metrics", ""]
    body.extend(f"- **{key}**: {value}" for key, value in metrics.items())
    if extra_sections:
        body.append("")
        body.extend(extra_sections)
    destination.write_text("\n".join(body), encoding="utf-8")
    return destination
|