scdlkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. scdlkit/__init__.py +20 -0
  2. scdlkit/data/__init__.py +6 -0
  3. scdlkit/data/datasets.py +30 -0
  4. scdlkit/data/prepare.py +262 -0
  5. scdlkit/data/schemas.py +35 -0
  6. scdlkit/data/splits.py +90 -0
  7. scdlkit/evaluation/__init__.py +6 -0
  8. scdlkit/evaluation/compare.py +69 -0
  9. scdlkit/evaluation/evaluator.py +34 -0
  10. scdlkit/evaluation/metrics.py +83 -0
  11. scdlkit/evaluation/report.py +40 -0
  12. scdlkit/models/__init__.py +20 -0
  13. scdlkit/models/autoencoder.py +43 -0
  14. scdlkit/models/base.py +22 -0
  15. scdlkit/models/blocks.py +32 -0
  16. scdlkit/models/classifier.py +30 -0
  17. scdlkit/models/denoising.py +37 -0
  18. scdlkit/models/registry.py +33 -0
  19. scdlkit/models/transformer.py +73 -0
  20. scdlkit/models/vae.py +61 -0
  21. scdlkit/runner.py +278 -0
  22. scdlkit/tasks/__init__.py +14 -0
  23. scdlkit/tasks/base.py +40 -0
  24. scdlkit/tasks/classification.py +28 -0
  25. scdlkit/tasks/reconstruction.py +41 -0
  26. scdlkit/tasks/representation.py +14 -0
  27. scdlkit/training/__init__.py +5 -0
  28. scdlkit/training/callbacks.py +12 -0
  29. scdlkit/training/trainer.py +176 -0
  30. scdlkit/utils/__init__.py +7 -0
  31. scdlkit/utils/device.py +13 -0
  32. scdlkit/utils/io.py +13 -0
  33. scdlkit/utils/seed.py +18 -0
  34. scdlkit/visualization/__init__.py +15 -0
  35. scdlkit/visualization/classification.py +26 -0
  36. scdlkit/visualization/compare.py +32 -0
  37. scdlkit/visualization/latent.py +36 -0
  38. scdlkit/visualization/reconstruction.py +24 -0
  39. scdlkit/visualization/training.py +21 -0
  40. scdlkit-0.1.0.dist-info/METADATA +265 -0
  41. scdlkit-0.1.0.dist-info/RECORD +44 -0
  42. scdlkit-0.1.0.dist-info/WHEEL +5 -0
  43. scdlkit-0.1.0.dist-info/licenses/LICENSE +21 -0
  44. scdlkit-0.1.0.dist-info/top_level.txt +1 -0
scdlkit/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """Public package surface for scDLKit."""
2
+
3
+ from scdlkit.data import PreparedData, prepare_data
4
+ from scdlkit.evaluation.compare import BenchmarkResult, compare_models
5
+ from scdlkit.models import BaseModel, create_model
6
+ from scdlkit.runner import TaskRunner
7
+ from scdlkit.training import Trainer
8
+
9
+ __all__ = [
10
+ "BaseModel",
11
+ "BenchmarkResult",
12
+ "PreparedData",
13
+ "TaskRunner",
14
+ "Trainer",
15
+ "compare_models",
16
+ "create_model",
17
+ "prepare_data",
18
+ ]
19
+
20
+ __version__ = "0.1.0"
@@ -0,0 +1,6 @@
1
+ """Data preparation utilities."""
2
+
3
+ from scdlkit.data.prepare import prepare_data, transform_adata
4
+ from scdlkit.data.schemas import PreparedData, SplitData
5
+
6
+ __all__ = ["PreparedData", "SplitData", "prepare_data", "transform_adata"]
@@ -0,0 +1,30 @@
1
+ """PyTorch datasets backed by dense or sparse matrices."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import torch
7
+ from scipy import sparse
8
+ from torch.utils.data import Dataset
9
+
10
+ from scdlkit.data.schemas import SplitData
11
+
12
+
13
class AnnDataset(Dataset[dict[str, torch.Tensor]]):
    """Map-style dataset yielding one dense float32 row per cell.

    Each item is a dict with key ``"x"`` (the expression row) plus optional
    ``"y"`` (encoded label) and ``"batch"`` (encoded batch) when the split
    carries those annotations.
    """

    def __init__(self, split: SplitData):
        self.split = split

    def __len__(self) -> int:
        return len(self.split)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        row = self.split.X[index]
        # Densify one row at a time so a sparse matrix is never fully expanded.
        if sparse.issparse(row):
            dense = row.toarray().ravel()
        else:
            dense = np.asarray(row).ravel()
        item: dict[str, torch.Tensor] = {"x": torch.as_tensor(dense, dtype=torch.float32)}
        labels = self.split.labels
        if labels is not None:
            item["y"] = torch.as_tensor(int(labels[index]), dtype=torch.long)
        batches = self.split.batches
        if batches is not None:
            item["batch"] = torch.as_tensor(int(batches[index]), dtype=torch.long)
        return item
@@ -0,0 +1,262 @@
1
+ """AnnData preparation and transformation utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from anndata import AnnData
9
+ from scipy import sparse
10
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
11
+
12
+ from scdlkit.data.schemas import PreparedData, SplitData
13
+ from scdlkit.data.splits import build_splits
14
+
15
+
16
def _require_scanpy() -> Any:
    """Import and return scanpy, failing with an actionable message if absent."""
    try:
        import scanpy
    except ImportError as exc:
        raise ImportError(
            "scanpy-backed preprocessing requires `pip install scdlkit[scanpy]`."
        ) from exc
    return scanpy
23
+
24
+
25
def _encode_obs(values: np.ndarray | None) -> tuple[np.ndarray | None, dict[str, int] | None]:
    """Encode string categories as dense integer codes.

    Returns ``(codes, mapping)`` where ``mapping`` maps each category to its
    integer code with categories sorted lexicographically, or ``(None, None)``
    when *values* is None. Uses ``np.unique(..., return_inverse=True)``
    directly instead of a sklearn ``LabelEncoder`` round-trip — the encoding
    is identical (LabelEncoder also sorts classes via np.unique) and avoids
    the extra fit/transform machinery.
    """
    if values is None:
        return None, None
    classes, encoded = np.unique(values.astype(str), return_inverse=True)
    mapping = {label: int(index) for index, label in enumerate(classes)}
    return encoded, mapping
32
+
33
+
34
def _transform_obs(values: np.ndarray | None, mapping: dict[str, int] | None) -> np.ndarray | None:
    """Map string categories through a previously-fitted encoder *mapping*.

    Returns None when either argument is None. Raises ValueError on the
    first category that was never seen during fitting, so the train-time
    encoding is never silently extended.
    """
    if values is None or mapping is None:
        return None
    as_str = values.astype(str)
    for value in as_str:
        if value not in mapping:
            msg = f"Encountered unseen label '{value}' during transform."
            raise ValueError(msg)
    return np.asarray([mapping[value] for value in as_str], dtype=int)
44
+
45
+
46
def _extract_matrix(adata: AnnData, layer: str) -> Any:
    """Return ``adata.X`` for layer "X", otherwise the named entry of ``adata.layers``."""
    if layer == "X":
        return adata.X
    if layer in adata.layers:
        return adata.layers[layer]
    msg = f"Layer '{layer}' not found in AnnData.layers."
    raise ValueError(msg)
53
+
54
+
55
def _to_split_data(
    x_matrix: Any,
    indices: np.ndarray,
    *,
    labels: np.ndarray | None,
    batches: np.ndarray | None,
    obs_names: list[str],
) -> SplitData:
    """Slice the matrix, annotations, and obs names down to *indices*."""
    subset_labels = None if labels is None else labels[indices]
    subset_batches = None if batches is None else batches[indices]
    subset_names = [obs_names[i] for i in indices]
    return SplitData(
        X=x_matrix[indices],
        labels=subset_labels,
        batches=subset_batches,
        obs_names=subset_names,
    )
69
+
70
+
71
def _prepare_matrix(
    adata: AnnData,
    *,
    layer: str,
    use_hvg: bool,
    n_top_genes: int,
    normalize: bool,
    log1p: bool,
    scale: bool,
) -> tuple[AnnData, Any, list[str], StandardScaler | None]:
    """Run optional preprocessing and return the working expression matrix.

    Returns ``(working_adata, x_matrix, feature_names, scaler)``.
    NOTE: the scanpy steps mutate *adata* in place — callers are expected
    to pass a copy when the original must be preserved.
    """
    working = adata
    # scanpy is only needed (and only imported) for these three steps.
    if normalize or log1p or use_hvg:
        sc = _require_scanpy()
        if normalize:
            sc.pp.normalize_total(working)
        if log1p:
            sc.pp.log1p(working)
        if use_hvg:
            # subset=True shrinks `working` to the selected genes in place.
            sc.pp.highly_variable_genes(working, n_top_genes=n_top_genes, subset=True)
    x_matrix = _extract_matrix(working, layer)
    scaler: StandardScaler | None = None
    if scale:
        # Centering would densify a sparse matrix, so skip the mean shift there.
        scaler = StandardScaler(with_mean=not sparse.issparse(x_matrix))
        x_matrix = scaler.fit_transform(x_matrix)
    if sparse.issparse(x_matrix):
        # CSR gives fast row slicing for downstream per-sample access.
        x_matrix = x_matrix.tocsr()
    feature_names = working.var_names.astype(str).tolist()
    return working, x_matrix, feature_names, scaler
99
+
100
+
101
def prepare_data(
    adata: AnnData,
    *,
    layer: str = "X",
    use_hvg: bool = False,
    n_top_genes: int = 2000,
    normalize: bool = False,
    log1p: bool = False,
    scale: bool = False,
    label_key: str | None = None,
    batch_key: str | None = None,
    val_size: float = 0.15,
    test_size: float = 0.15,
    batch_aware_split: bool = False,
    random_state: int = 42,
    copy: bool = True,
) -> PreparedData:
    """Prepare AnnData splits and preprocessing metadata.

    Optionally runs normalization/log1p/HVG selection (scanpy) and scaling,
    encodes the ``label_key``/``batch_key`` obs columns as integers, and
    builds train/val/test splits. With ``copy=False`` the caller's AnnData
    is preprocessed in place.

    Raises:
        ValueError: when ``label_key`` or ``batch_key`` is given but missing
            from ``adata.obs``, or the split sizes are invalid.
    """

    working = adata.copy() if copy else adata
    working, x_matrix, feature_names, scaler = _prepare_matrix(
        working,
        layer=layer,
        use_hvg=use_hvg,
        n_top_genes=n_top_genes,
        normalize=normalize,
        log1p=log1p,
        scale=scale,
    )
    labels_raw = (
        working.obs[label_key].astype(str).to_numpy()
        if label_key is not None and label_key in working.obs
        else None
    )
    if label_key is not None and labels_raw is None:
        msg = f"label_key '{label_key}' not found in adata.obs."
        raise ValueError(msg)
    batches_raw = (
        working.obs[batch_key].astype(str).to_numpy()
        if batch_key is not None and batch_key in working.obs
        else None
    )
    if batch_key is not None and batches_raw is None:
        msg = f"batch_key '{batch_key}' not found in adata.obs."
        raise ValueError(msg)
    labels, label_encoder = _encode_obs(labels_raw)
    batches, batch_encoder = _encode_obs(batches_raw)

    # Stratify on labels when available; keep whole batches together only
    # when batch-aware splitting was explicitly requested.
    split_indices = build_splits(
        working.n_obs,
        val_size=val_size,
        test_size=test_size,
        random_state=random_state,
        stratify=labels if label_key is not None else None,
        groups=batches if batch_aware_split and batch_key is not None else None,
    )
    obs_names = working.obs_names.astype(str).tolist()
    train = _to_split_data(
        x_matrix,
        split_indices.train,
        labels=labels,
        batches=batches,
        obs_names=obs_names,
    )
    # val/test collapse to None when their split fraction was zero.
    val = (
        _to_split_data(
            x_matrix,
            split_indices.val,
            labels=labels,
            batches=batches,
            obs_names=obs_names,
        )
        if split_indices.val.size
        else None
    )
    test = (
        _to_split_data(
            x_matrix,
            split_indices.test,
            labels=labels,
            batches=batches,
            obs_names=obs_names,
        )
        if split_indices.test.size
        else None
    )
    # Recorded so transform_adata can replay the same preprocessing later.
    preprocessing = {
        "layer": layer,
        "use_hvg": use_hvg,
        "n_top_genes": n_top_genes,
        "normalize": normalize,
        "log1p": log1p,
        "scale": scale,
        "scaler": scaler,
        "feature_names": feature_names,
        "label_key": label_key,
        "batch_key": batch_key,
        "batch_aware_split": batch_aware_split,
    }
    return PreparedData(
        train=train,
        val=val,
        test=test,
        input_dim=int(x_matrix.shape[1]),
        feature_names=feature_names,
        label_encoder=label_encoder,
        batch_encoder=batch_encoder,
        preprocessing=preprocessing,
    )
210
+
211
+
212
def transform_adata(
    adata: AnnData,
    preprocessing: dict[str, Any],
    *,
    label_encoder: dict[str, int] | None = None,
    batch_encoder: dict[str, int] | None = None,
    copy: bool = True,
) -> SplitData:
    """Transform new AnnData using stored preprocessing metadata.

    Replays the normalize/log1p/HVG steps recorded by ``prepare_data``,
    aligns features to the training feature order, applies the fitted
    scaler (if any), and encodes labels/batches with the train-time
    mappings.

    Raises:
        ValueError: when required training features are missing, or an
            unseen label/batch category is encountered.
    """

    working = adata.copy() if copy else adata
    # scale=False here: scaling is applied below with the *fitted* scaler
    # rather than re-fitting one on the new data.
    # NOTE(review): replaying HVG selection re-fits it on the new data, so
    # the selected genes may differ from training and trigger the
    # missing-features error below — confirm whether storing only the
    # trained gene list was intended.
    working, x_matrix, _, _ = _prepare_matrix(
        working,
        layer=preprocessing["layer"],
        use_hvg=preprocessing["use_hvg"],
        n_top_genes=preprocessing["n_top_genes"],
        normalize=preprocessing["normalize"],
        log1p=preprocessing["log1p"],
        scale=False,
    )
    feature_names = preprocessing["feature_names"]
    # Align columns to the training features (same set, same order).
    if list(working.var_names.astype(str)) != feature_names:
        missing = sorted(set(feature_names) - set(working.var_names.astype(str)))
        if missing:
            msg = f"AnnData is missing required features: {missing[:5]}"
            raise ValueError(msg)
        working = working[:, feature_names].copy()
        x_matrix = _extract_matrix(working, preprocessing["layer"])
    scaler = preprocessing.get("scaler")
    if scaler is not None:
        x_matrix = scaler.transform(x_matrix)
    if sparse.issparse(x_matrix):
        # CSR for fast row access in the datasets.
        x_matrix = x_matrix.tocsr()
    labels = None
    if preprocessing["label_key"] is not None and preprocessing["label_key"] in working.obs:
        labels = _transform_obs(
            working.obs[preprocessing["label_key"]].astype(str).to_numpy(),
            label_encoder,
        )
    batches = None
    if preprocessing["batch_key"] is not None and preprocessing["batch_key"] in working.obs:
        batches = _transform_obs(
            working.obs[preprocessing["batch_key"]].astype(str).to_numpy(),
            batch_encoder,
        )
    return SplitData(
        X=x_matrix,
        labels=labels,
        batches=batches,
        obs_names=working.obs_names.astype(str).tolist(),
    )
@@ -0,0 +1,35 @@
1
+ """Dataclasses for prepared datasets and preprocessing metadata."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+
11
+ @dataclass(slots=True)
12
+ class SplitData:
13
+ """One dataset split with optional encoded labels and batches."""
14
+
15
+ X: Any
16
+ labels: np.ndarray | None = None
17
+ batches: np.ndarray | None = None
18
+ obs_names: list[str] = field(default_factory=list)
19
+
20
+ def __len__(self) -> int:
21
+ return int(self.X.shape[0])
22
+
23
+
24
@dataclass(slots=True)
class PreparedData:
    """Prepared train/validation/test splits and metadata."""

    train: SplitData
    # val/test are None when the corresponding split fraction was zero.
    val: SplitData | None
    test: SplitData | None
    # Number of feature columns after preprocessing/HVG selection.
    input_dim: int
    feature_names: list[str]
    # Category -> integer-code mappings; None when no key was supplied.
    label_encoder: dict[str, int] | None
    batch_encoder: dict[str, int] | None
    # Everything needed to replay the same preprocessing via transform_adata.
    preprocessing: dict[str, Any]
scdlkit/data/splits.py ADDED
@@ -0,0 +1,90 @@
1
+ """Split helpers for prepared AnnData workflows."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ from sklearn.model_selection import GroupShuffleSplit, train_test_split
9
+
10
+
11
@dataclass(slots=True)
class SplitIndices:
    """Sorted integer index arrays per split; val/test may be empty."""

    train: np.ndarray
    val: np.ndarray
    test: np.ndarray
16
+
17
+
18
def build_splits(
    n_samples: int,
    *,
    val_size: float,
    test_size: float,
    random_state: int,
    stratify: np.ndarray | None = None,
    groups: np.ndarray | None = None,
) -> SplitIndices:
    """Create train/validation/test indices.

    When *groups* is given, GroupShuffleSplit keeps every group inside a
    single split; otherwise a plain (optionally stratified) shuffle split
    is used. Zero-sized val/test fractions yield empty index arrays.
    """

    if val_size < 0 or test_size < 0 or val_size + test_size >= 1:
        msg = "val_size and test_size must be >= 0 and sum to less than 1"
        raise ValueError(msg)

    all_indices = np.arange(n_samples)
    holdout_fraction = val_size + test_size

    # Nothing held out: everything is training data.
    if holdout_fraction == 0:
        return SplitIndices(
            train=all_indices,
            val=np.array([], dtype=int),
            test=np.array([], dtype=int),
        )

    # First split: train vs. (val + test), group-aware when requested.
    if groups is None:
        train_idx, holdout_idx = train_test_split(
            all_indices,
            test_size=holdout_fraction,
            random_state=random_state,
            stratify=stratify,
        )
    else:
        splitter = GroupShuffleSplit(
            n_splits=1,
            test_size=holdout_fraction,
            random_state=random_state,
        )
        train_idx, holdout_idx = next(splitter.split(all_indices, groups=groups))

    if test_size == 0:
        return SplitIndices(
            train=np.sort(train_idx),
            val=np.sort(holdout_idx),
            test=np.array([], dtype=int),
        )
    if val_size == 0:
        return SplitIndices(
            train=np.sort(train_idx),
            val=np.array([], dtype=int),
            test=np.sort(holdout_idx),
        )

    # Second split: divide the holdout into val and test by relative size.
    test_fraction = test_size / holdout_fraction
    if groups is None:
        val_rel, test_rel = train_test_split(
            np.arange(holdout_idx.size),
            test_size=test_fraction,
            random_state=random_state,
            stratify=stratify[holdout_idx] if stratify is not None else None,
        )
    else:
        splitter = GroupShuffleSplit(
            n_splits=1,
            test_size=test_fraction,
            random_state=random_state,
        )
        val_rel, test_rel = next(splitter.split(holdout_idx, groups=groups[holdout_idx]))

    return SplitIndices(
        train=np.sort(train_idx),
        val=np.sort(holdout_idx[val_rel]),
        test=np.sort(holdout_idx[test_rel]),
    )
@@ -0,0 +1,6 @@
1
+ """Evaluation utilities."""
2
+
3
+ from scdlkit.evaluation.compare import BenchmarkResult, compare_models
4
+ from scdlkit.evaluation.evaluator import evaluate_predictions
5
+
6
+ __all__ = ["BenchmarkResult", "compare_models", "evaluate_predictions"]
@@ -0,0 +1,69 @@
1
+ """Compare multiple models on the same AnnData workflow."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ import pandas as pd
9
+
10
+ from scdlkit.evaluation.report import save_markdown_report
11
+ from scdlkit.utils import ensure_directory
12
+
13
+
14
+ @dataclass(slots=True)
15
+ class BenchmarkResult:
16
+ """Collected results from comparing multiple models."""
17
+
18
+ metrics_frame: pd.DataFrame
19
+ runners: dict[str, Any]
20
+ output_paths: dict[str, str] = field(default_factory=dict)
21
+
22
+
23
def compare_models(
    adata: Any,
    *,
    models: list[str],
    task: str,
    shared_kwargs: dict[str, Any] | None = None,
    output_dir: str | None = None,
) -> BenchmarkResult:
    """Train and evaluate several models with shared configuration.

    Args:
        adata: AnnData handed to each runner's ``fit``.
        models: model registry names to benchmark.
        task: task name shared by all runners.
        shared_kwargs: extra keyword arguments forwarded to every TaskRunner.
        output_dir: when given, a metrics CSV, markdown report, and
            comparison plot are written there.
    """

    # Imported lazily to avoid a circular import (runner imports this module).
    from scdlkit.runner import TaskRunner
    from scdlkit.visualization.compare import plot_model_comparison

    shared = dict(shared_kwargs or {})
    records: list[dict[str, Any]] = []
    runners: dict[str, TaskRunner] = {}
    output_paths: dict[str, str] = {}
    for model_name in models:
        runner = TaskRunner(model=model_name, task=task, **shared)
        runner.fit(adata)
        metrics = runner.evaluate()
        # Keep only scalar metrics; structured values (e.g. confusion
        # matrices) do not fit in a flat comparison table.
        scalar_metrics = {k: v for k, v in metrics.items() if isinstance(v, (int, float))}
        records.append({"model": model_name, **scalar_metrics})
        runners[model_name] = runner

    metrics_frame = pd.DataFrame.from_records(records).sort_values("model").reset_index(drop=True)
    if output_dir is not None:
        directory = ensure_directory(output_dir)
        csv_path = directory / "benchmark_metrics.csv"
        md_path = directory / "benchmark_report.md"
        png_path = directory / "benchmark_comparison.png"
        metrics_frame.to_csv(csv_path, index=False)
        fig, _ = plot_model_comparison(metrics_frame)
        fig.savefig(png_path, dpi=150, bbox_inches="tight")
        # Bug fix: close the figure after saving so repeated benchmark runs
        # do not accumulate open matplotlib figures.
        import matplotlib.pyplot as plt

        plt.close(fig)
        report_lines = ["## Compared models", "", *[f"- `{name}`" for name in models]]
        save_markdown_report(
            {"num_models": len(models), "task": task},
            path=md_path,
            title="Benchmark Report",
            extra_sections=report_lines,
        )
        output_paths = {
            "metrics_csv": str(csv_path),
            "report_md": str(md_path),
            "comparison_png": str(png_path),
        }
    return BenchmarkResult(metrics_frame=metrics_frame, runners=runners, output_paths=output_paths)
@@ -0,0 +1,34 @@
1
+ """Task-aware evaluation entrypoints."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+ from scdlkit.evaluation.metrics import (
10
+ classification_metrics,
11
+ reconstruction_metrics,
12
+ representation_metrics,
13
+ )
14
+
15
+
16
def evaluate_predictions(task: str, predictions: dict[str, np.ndarray]) -> dict[str, Any]:
    """Evaluate model predictions for a task.

    Classification needs ``predictions["y"]`` and ``predictions["logits"]``.
    Every other task is scored with reconstruction metrics on ``"x"`` vs
    ``"reconstruction"``; the "representation" task additionally receives
    latent-space metrics.
    """

    if task == "classification":
        if "y" not in predictions:
            msg = "Classification evaluation requires encoded labels."
            raise ValueError(msg)
        return classification_metrics(predictions["y"], predictions["logits"])

    # Any non-classification task is scored as a reconstruction first.
    scores = reconstruction_metrics(predictions["x"], predictions["reconstruction"])
    if task == "representation":
        n_obs = predictions["x"].shape[0]
        latent = predictions.get("latent", np.empty((n_obs, 0)))
        scores.update(
            representation_metrics(latent, predictions.get("y"), predictions.get("batch"))
        )
    return scores
@@ -0,0 +1,83 @@
1
+ """Metric helpers for reconstruction, representation, and classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from collections.abc import Callable
7
+
8
+ import numpy as np
9
+ from scipy.stats import pearsonr, spearmanr
10
+ from sklearn.cluster import KMeans
11
+ from sklearn.metrics import (
12
+ accuracy_score,
13
+ adjusted_rand_score,
14
+ confusion_matrix,
15
+ f1_score,
16
+ normalized_mutual_info_score,
17
+ silhouette_score,
18
+ )
19
+ from sklearn.neighbors import NearestNeighbors
20
+
21
+
22
def _safe_correlation(
    func: Callable[[np.ndarray, np.ndarray], tuple[float, float]],
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> float:
    """Flattened correlation via *func*, returning 0.0 for degenerate input."""
    a = np.ravel(y_true)
    b = np.ravel(y_pred)
    # A constant vector has undefined correlation; report 0 instead of NaN.
    if np.std(a) == 0 or np.std(b) == 0:
        return 0.0
    value = func(a, b)[0]
    return float(value) if not math.isnan(value) else 0.0


def reconstruction_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """MSE/MAE plus Pearson and Spearman correlation of the reconstruction."""
    residual = y_true - y_pred
    metrics = {
        "mse": float(np.mean(residual**2)),
        "mae": float(np.mean(np.abs(residual))),
    }
    metrics["pearson"] = _safe_correlation(pearsonr, y_true, y_pred)
    metrics["spearman"] = _safe_correlation(spearmanr, y_true, y_pred)
    return metrics
43
+
44
+
45
def knn_label_consistency(latent: np.ndarray, labels: np.ndarray, n_neighbors: int = 10) -> float:
    """Fraction of cells whose nearest neighbours' majority label matches theirs.

    Returns 0.0 when there are fewer than two distinct labels or at most
    one cell, where the score would be meaningless.
    """
    if latent.shape[0] <= 1 or np.unique(labels).size < 2:
        return 0.0
    # +1 because each point is its own nearest neighbour and is dropped below.
    k = min(n_neighbors + 1, latent.shape[0])
    model = NearestNeighbors(n_neighbors=k)
    model.fit(latent)
    neighbour_idx = model.kneighbors(latent, return_distance=False)[:, 1:]
    neighbour_labels = labels[neighbour_idx]
    majority = np.array([np.bincount(row).argmax() for row in neighbour_labels])
    return float(np.mean(majority == labels))
55
+
56
+
57
def representation_metrics(
    latent: np.ndarray,
    labels: np.ndarray | None,
    batches: np.ndarray | None,
) -> dict[str, float]:
    """Latent-space scores against label and batch annotations.

    Each metric group is computed only when the annotation is present,
    has at least two categories, and there are more cells than categories.
    """
    scores: dict[str, float] = {}
    n_cells = latent.shape[0]
    if labels is not None:
        n_labels = len(np.unique(labels))
        if n_labels > 1 and n_cells > n_labels:
            scores["silhouette"] = float(silhouette_score(latent, labels))
            scores["knn_label_consistency"] = knn_label_consistency(latent, labels)
            clusterer = KMeans(n_clusters=n_labels, random_state=42, n_init="auto")
            clusters = clusterer.fit_predict(latent)
            scores["ari"] = float(adjusted_rand_score(labels, clusters))
            scores["nmi"] = float(normalized_mutual_info_score(labels, clusters))
    if batches is not None:
        n_batches = len(np.unique(batches))
        if n_batches > 1 and n_cells > n_batches:
            scores["batch_silhouette"] = float(silhouette_score(latent, batches))
    return scores
75
+
76
+
77
def classification_metrics(y_true: np.ndarray, logits: np.ndarray) -> dict[str, object]:
    """Accuracy, macro-F1, and confusion matrix from raw class logits."""
    y_pred = logits.argmax(axis=1)
    result: dict[str, object] = {}
    result["accuracy"] = float(accuracy_score(y_true, y_pred))
    result["macro_f1"] = float(f1_score(y_true, y_pred, average="macro"))
    # Lists (not ndarrays) so the result is JSON/markdown friendly.
    result["confusion_matrix"] = confusion_matrix(y_true, y_pred).tolist()
    return result
@@ -0,0 +1,40 @@
1
+ """Report export helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import pandas as pd
9
+
10
+
11
def _scalar_metrics(metrics: dict[str, Any]) -> dict[str, float]:
    """Keep only numeric entries, coerced to float (bools count as ints)."""
    out: dict[str, float] = {}
    for key, value in metrics.items():
        if isinstance(value, (int, float)):
            out[key] = float(value)
    return out


def save_metrics_table(metrics: dict[str, Any], path: str | Path) -> Path:
    """Write scalar metrics to CSV.

    Non-numeric entries (lists, strings, matrices) are dropped; the CSV
    holds a single row with one column per scalar metric.
    """

    destination = Path(path)
    pd.DataFrame([_scalar_metrics(metrics)]).to_csv(destination, index=False)
    return destination
22
+
23
+
24
def save_markdown_report(
    metrics: dict[str, Any],
    *,
    path: str | Path,
    title: str,
    extra_sections: list[str] | None = None,
) -> Path:
    """Write a markdown report with scalar and structured metrics.

    Every metric is rendered as a bold bullet under a "Metrics" heading;
    *extra_sections* lines, when given, are appended verbatim after a
    blank separator line.
    """

    destination = Path(path)
    body = [f"# {title}", "", "## Metrics", ""]
    body.extend(f"- **{key}**: {value}" for key, value in metrics.items())
    if extra_sections:
        body.append("")
        body.extend(extra_sections)
    destination.write_text("\n".join(body), encoding="utf-8")
    return destination