scib-metrics 0.5.4__tar.gz → 0.5.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.pre-commit-config.yaml +1 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/CHANGELOG.md +27 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/PKG-INFO +1 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/api.md +1 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/references.bib +10 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/pyproject.toml +1 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/__init__.py +2 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/benchmark/_core.py +25 -17
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/__init__.py +2 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_kbet.py +46 -47
- scib_metrics-0.5.6/src/scib_metrics/metrics/_silhouette.py +162 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_dist.py +23 -3
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_silhouette.py +62 -13
- scib_metrics-0.5.6/tests/test_BRAS_metric.py +347 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_metrics.py +6 -0
- scib_metrics-0.5.4/src/scib_metrics/metrics/_silhouette.py +0 -86
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.codecov.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.cruft.json +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.editorconfig +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/ISSUE_TEMPLATE/bug_report.yml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/ISSUE_TEMPLATE/feature_request.yml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/build.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/release.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_linux.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_linux_cuda.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_linux_pre.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_macos.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_macos_m1.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_windows.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.gitignore +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.readthedocs.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/LICENSE +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/README.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/Makefile +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_static/.gitkeep +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_static/css/custom.css +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_templates/.gitkeep +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_templates/autosummary/class.rst +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_templates/class_no_inherited.rst +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/changelog.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/conf.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/contributing.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/extensions/.gitkeep +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/extensions/typed_returns.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/index.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/notebooks/large_scale.ipynb +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/notebooks/lung_example.ipynb +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/references.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/template_usage.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/tutorials.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/setup.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/_settings.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/_types.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/benchmark/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_graph_connectivity.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_isolated_labels.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_lisi.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_nmi_ari.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_pcr_comparison.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/_dataclass.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/_jax.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/_pynndescent.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_diffusion_nn.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_kmeans.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_lisi.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_pca.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_pcr.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_utils.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_benchmarker.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_neighbors.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_pcr_comparison.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/data.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/sampling.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/test_pca.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/test_pcr.py +0 -0
|
@@ -10,6 +10,33 @@ and this project adheres to [Semantic Versioning][].
|
|
|
10
10
|
|
|
11
11
|
## 0.6.0 (unreleased)
|
|
12
12
|
|
|
13
|
+
## 0.5.6 (2025-07-08)
|
|
14
|
+
|
|
15
|
+
### Added
|
|
16
|
+
|
|
17
|
+
- Add BRAS to Benchmarker as default, instead of regular silhouette batch {pr}`217`
|
|
18
|
+
- Added the option to manually set the KNN graphs before running a benchmarker.
|
|
19
|
+
|
|
20
|
+
### Changed
|
|
21
|
+
|
|
22
|
+
- Changed the default of `min_max_scale` in {meth}`scib_metrics.benchmark.Benchmarker.get_results` and {meth}`scib_metrics.benchmark.Benchmarker.plot_results_table` to False {pr}`215`.
|
|
23
|
+
|
|
24
|
+
### Fixed
|
|
25
|
+
|
|
26
|
+
- Reverted "Skip labels before loop" ({pr}`180`) in kBET per label, which caused a wrong selection of clusters {pr}`213`.
|
|
27
|
+
|
|
28
|
+
## 0.5.5 (2025-06-03)
|
|
29
|
+
|
|
30
|
+
### Added
|
|
31
|
+
|
|
32
|
+
- Add batch removal adapted silhouette (BRAS) metric ({func}`scib_metrics.metrics.bras`) {pr}`197`, which addresses limitations of silhouette for scoring batch effect removal.
|
|
33
|
+
- Add cosine distance implementation required for BRAS.
|
|
34
|
+
|
|
35
|
+
### Changed
|
|
36
|
+
|
|
37
|
+
- Changed {func}`scib_metrics.utils.cdist` to support cosine distance.
|
|
38
|
+
- Changed silhouette-related functions to be compatible with the adaptations required for BRAS.
|
|
39
|
+
|
|
13
40
|
## 0.5.4 (2025-04-23)
|
|
14
41
|
|
|
15
42
|
### Fixed
|
|
@@ -36,3 +36,13 @@
|
|
|
36
36
|
pages = {43--49},
|
|
37
37
|
publisher = {Springer Science and Business Media {LLC}}
|
|
38
38
|
}
|
|
39
|
+
|
|
40
|
+
@article{rautenstrauch2025,
|
|
41
|
+
title = {Metrics Matter: Why We Need to Stop Using Silhouette in Single-Cell Benchmarking},
|
|
42
|
+
author = {Pia Rautenstrauch and Uwe Ohler},
|
|
43
|
+
doi = {10.1101/2025.01.21.634098},
|
|
44
|
+
year = {2025},
|
|
45
|
+
month = jan,
|
|
46
|
+
journal = {bioRxiv},
|
|
47
|
+
publisher = {Cold Spring Harbor Laboratory}
|
|
48
|
+
}
|
|
@@ -15,6 +15,7 @@ from .metrics import (
|
|
|
15
15
|
pcr_comparison,
|
|
16
16
|
silhouette_batch,
|
|
17
17
|
silhouette_label,
|
|
18
|
+
bras,
|
|
18
19
|
)
|
|
19
20
|
from ._settings import settings
|
|
20
21
|
|
|
@@ -25,6 +26,7 @@ __all__ = [
|
|
|
25
26
|
"pcr_comparison",
|
|
26
27
|
"silhouette_label",
|
|
27
28
|
"silhouette_batch",
|
|
29
|
+
"bras",
|
|
28
30
|
"ilisi_knn",
|
|
29
31
|
"clisi_knn",
|
|
30
32
|
"lisi_knn",
|
|
@@ -42,6 +42,7 @@ metric_name_cleaner = {
|
|
|
42
42
|
"clisi_knn": "cLISI",
|
|
43
43
|
"ilisi_knn": "iLISI",
|
|
44
44
|
"kbet_per_label": "KBET",
|
|
45
|
+
"bras": "BRAS",
|
|
45
46
|
"graph_connectivity": "Graph connectivity",
|
|
46
47
|
"pcr_comparison": "PCR comparison",
|
|
47
48
|
}
|
|
@@ -72,7 +73,7 @@ class BatchCorrection:
|
|
|
72
73
|
parameters, such as `X` or `labels`.
|
|
73
74
|
"""
|
|
74
75
|
|
|
75
|
-
|
|
76
|
+
bras: MetricType = True
|
|
76
77
|
ilisi_knn: MetricType = True
|
|
77
78
|
kbet_per_label: MetricType = True
|
|
78
79
|
graph_connectivity: MetricType = True
|
|
@@ -88,7 +89,7 @@ class MetricAnnDataAPI(Enum):
|
|
|
88
89
|
silhouette_label = lambda ad, fn: fn(ad.X, ad.obs[_LABELS])
|
|
89
90
|
clisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_LABELS])
|
|
90
91
|
graph_connectivity = lambda ad, fn: fn(ad.uns["15_neighbor_res"], ad.obs[_LABELS])
|
|
91
|
-
|
|
92
|
+
bras = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH])
|
|
92
93
|
pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True)
|
|
93
94
|
ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH])
|
|
94
95
|
kbet_per_label = lambda ad, fn: fn(ad.uns["50_neighbor_res"], ad.obs[_BATCH], ad.obs[_LABELS])
|
|
@@ -156,6 +157,7 @@ class Benchmarker:
|
|
|
156
157
|
self._label_key = label_key
|
|
157
158
|
self._n_jobs = n_jobs
|
|
158
159
|
self._progress_bar = progress_bar
|
|
160
|
+
self._compute_neighbors = True
|
|
159
161
|
|
|
160
162
|
if self._bio_conservation_metrics is None and self._batch_correction_metrics is None:
|
|
161
163
|
raise ValueError("Either batch or bio metrics must be defined.")
|
|
@@ -191,19 +193,25 @@ class Benchmarker:
|
|
|
191
193
|
self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key]
|
|
192
194
|
|
|
193
195
|
# Compute neighbors
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
196
|
+
if self._compute_neighbors:
|
|
197
|
+
progress = self._emb_adatas.values()
|
|
198
|
+
if self._progress_bar:
|
|
199
|
+
progress = tqdm(progress, desc="Computing neighbors")
|
|
200
|
+
|
|
201
|
+
for ad in progress:
|
|
202
|
+
if neighbor_computer is not None:
|
|
203
|
+
neigh_result = neighbor_computer(ad.X, max(self._neighbor_values))
|
|
204
|
+
else:
|
|
205
|
+
neigh_result = pynndescent(
|
|
206
|
+
ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs
|
|
207
|
+
)
|
|
208
|
+
for n in self._neighbor_values:
|
|
209
|
+
ad.uns[f"{n}_neighbor_res"] = neigh_result.subset_neighbors(n=n)
|
|
210
|
+
else:
|
|
211
|
+
warnings.warn(
|
|
212
|
+
"Computing Neighbors Skipped",
|
|
213
|
+
UserWarning,
|
|
214
|
+
)
|
|
207
215
|
|
|
208
216
|
self._prepared = True
|
|
209
217
|
|
|
@@ -251,7 +259,7 @@ class Benchmarker:
|
|
|
251
259
|
|
|
252
260
|
self._benchmarked = True
|
|
253
261
|
|
|
254
|
-
def get_results(self, min_max_scale: bool =
|
|
262
|
+
def get_results(self, min_max_scale: bool = False, clean_names: bool = True) -> pd.DataFrame:
|
|
255
263
|
"""Return the benchmarking results.
|
|
256
264
|
|
|
257
265
|
Parameters
|
|
@@ -291,7 +299,7 @@ class Benchmarker:
|
|
|
291
299
|
df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE
|
|
292
300
|
return df
|
|
293
301
|
|
|
294
|
-
def plot_results_table(self, min_max_scale: bool =
|
|
302
|
+
def plot_results_table(self, min_max_scale: bool = False, show: bool = True, save_dir: str | None = None) -> Table:
|
|
295
303
|
"""Plot the benchmarking results.
|
|
296
304
|
|
|
297
305
|
Parameters
|
|
@@ -4,13 +4,14 @@ from ._kbet import kbet, kbet_per_label
|
|
|
4
4
|
from ._lisi import clisi_knn, ilisi_knn, lisi_knn
|
|
5
5
|
from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden
|
|
6
6
|
from ._pcr_comparison import pcr_comparison
|
|
7
|
-
from ._silhouette import silhouette_batch, silhouette_label
|
|
7
|
+
from ._silhouette import bras, silhouette_batch, silhouette_label
|
|
8
8
|
|
|
9
9
|
__all__ = [
|
|
10
10
|
"isolated_labels",
|
|
11
11
|
"pcr_comparison",
|
|
12
12
|
"silhouette_label",
|
|
13
13
|
"silhouette_batch",
|
|
14
|
+
"bras",
|
|
14
15
|
"ilisi_knn",
|
|
15
16
|
"clisi_knn",
|
|
16
17
|
"lisi_knn",
|
|
@@ -138,14 +138,8 @@ def kbet_per_label(
|
|
|
138
138
|
conn_graph = X.knn_graph_connectivities
|
|
139
139
|
|
|
140
140
|
# prepare call of kBET per cluster
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
skipped = clusters[counts > 10]
|
|
144
|
-
clusters = clusters[counts <= 10]
|
|
145
|
-
kbet_scores = {"cluster": list(skipped), "kBET": [np.nan] * len(skipped)}
|
|
146
|
-
logger.info(f"{len(skipped)} clusters consist of a single batch or are too small. Skip.")
|
|
147
|
-
|
|
148
|
-
for clus in clusters:
|
|
141
|
+
kbet_scores = {"cluster": [], "kBET": []}
|
|
142
|
+
for clus in np.unique(labels):
|
|
149
143
|
# subset by label
|
|
150
144
|
mask = labels == clus
|
|
151
145
|
conn_graph_sub = conn_graph[mask, :][:, mask]
|
|
@@ -153,55 +147,60 @@ def kbet_per_label(
|
|
|
153
147
|
n_obs = conn_graph_sub.shape[0]
|
|
154
148
|
batches_sub = batches[mask]
|
|
155
149
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
k0 = np.floor(size_max / n_obs).astype("int")
|
|
161
|
-
|
|
162
|
-
n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")
|
|
163
|
-
|
|
164
|
-
if n_comp == 1: # a single component to compute kBET on
|
|
165
|
-
try:
|
|
166
|
-
diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
|
|
167
|
-
nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
|
|
168
|
-
# call kBET
|
|
169
|
-
score, _, _ = kbet(
|
|
170
|
-
nn_graph_sub,
|
|
171
|
-
batches=batches_sub,
|
|
172
|
-
alpha=alpha,
|
|
173
|
-
)
|
|
174
|
-
except ValueError:
|
|
175
|
-
logger.info("Diffusion distance failed. Skip.")
|
|
176
|
-
score = 0 # i.e. 100% rejection
|
|
177
|
-
|
|
150
|
+
# check if neighborhood size too small or only one batch in subset
|
|
151
|
+
if np.logical_or(n_obs < 10, len(np.unique(batches_sub)) == 1):
|
|
152
|
+
logger.info(f"{clus} consists of a single batch or is too small. Skip.")
|
|
153
|
+
score = np.nan
|
|
178
154
|
else:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
# check
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
if len(idx_nonan) / len(labs) >= 0.75:
|
|
187
|
-
# create another subset of components, assume they are not visited in a diffusion process
|
|
188
|
-
conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
|
|
189
|
-
conn_graph_sub_sub.sort_indices()
|
|
155
|
+
quarter_mean = np.floor(np.mean(pd.Series(batches_sub).value_counts()) / 4).astype("int")
|
|
156
|
+
k0 = np.min([70, np.max([10, quarter_mean])])
|
|
157
|
+
# check k0 for reasonability
|
|
158
|
+
if k0 * n_obs >= size_max:
|
|
159
|
+
k0 = np.floor(size_max / n_obs).astype("int")
|
|
160
|
+
|
|
161
|
+
n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")
|
|
190
162
|
|
|
163
|
+
if n_comp == 1: # a single component to compute kBET on
|
|
191
164
|
try:
|
|
192
|
-
diffusion_n_comps = np.min([diffusion_n_comps,
|
|
193
|
-
|
|
165
|
+
diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
|
|
166
|
+
nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
|
|
194
167
|
# call kBET
|
|
195
168
|
score, _, _ = kbet(
|
|
196
|
-
|
|
197
|
-
batches=batches_sub
|
|
169
|
+
nn_graph_sub,
|
|
170
|
+
batches=batches_sub,
|
|
198
171
|
alpha=alpha,
|
|
199
172
|
)
|
|
200
173
|
except ValueError:
|
|
201
174
|
logger.info("Diffusion distance failed. Skip.")
|
|
202
175
|
score = 0 # i.e. 100% rejection
|
|
203
|
-
|
|
204
|
-
|
|
176
|
+
|
|
177
|
+
else:
|
|
178
|
+
# check the number of components where kBET can be computed upon
|
|
179
|
+
comp_size = pd.Series(labs).value_counts()
|
|
180
|
+
# check which components are small
|
|
181
|
+
comp_size_thresh = 3 * k0
|
|
182
|
+
idx_nonan = np.flatnonzero(np.in1d(labs, comp_size[comp_size >= comp_size_thresh].index))
|
|
183
|
+
|
|
184
|
+
# check if 75% of all cells can be used for kBET run
|
|
185
|
+
if len(idx_nonan) / len(labs) >= 0.75:
|
|
186
|
+
# create another subset of components, assume they are not visited in a diffusion process
|
|
187
|
+
conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
|
|
188
|
+
conn_graph_sub_sub.sort_indices()
|
|
189
|
+
|
|
190
|
+
try:
|
|
191
|
+
diffusion_n_comps = np.min([diffusion_n_comps, conn_graph_sub_sub.shape[0] - 1])
|
|
192
|
+
nn_results_sub_sub = diffusion_nn(conn_graph_sub_sub, k=k0, n_comps=diffusion_n_comps)
|
|
193
|
+
# call kBET
|
|
194
|
+
score, _, _ = kbet(
|
|
195
|
+
nn_results_sub_sub,
|
|
196
|
+
batches=batches_sub[idx_nonan],
|
|
197
|
+
alpha=alpha,
|
|
198
|
+
)
|
|
199
|
+
except ValueError:
|
|
200
|
+
logger.info("Diffusion distance failed. Skip.")
|
|
201
|
+
score = 0 # i.e. 100% rejection
|
|
202
|
+
else: # if there are too many too small connected components, set kBET score to 0
|
|
203
|
+
score = 0 # i.e. 100% rejection
|
|
205
204
|
|
|
206
205
|
kbet_scores["cluster"].append(clus)
|
|
207
206
|
kbet_scores["kBET"].append(score)
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from scib_metrics.utils import silhouette_samples
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def silhouette_label(
    X: np.ndarray,
    labels: np.ndarray,
    rescale: bool = True,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "euclidean",
) -> float:
    """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.

    Default parameters ('euclidean') match scIB implementation.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values.
    rescale
        If True, map the ASW from [-1, 1] onto [0, 1] via ``(asw + 1) / 2``.
    chunk_size
        Size of chunks to process at a time for distance computation.
    metric
        The distance metric to use, either 'euclidean' (default) or 'cosine'.

    Returns
    -------
    silhouette score
    """
    per_cell = silhouette_samples(X, labels, chunk_size=chunk_size, metric=metric)
    mean_width = np.mean(per_cell)
    # Shift/scale from [-1, 1] to [0, 1] only when requested.
    result = (mean_width + 1) / 2 if rescale else mean_width
    return np.mean(result)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def silhouette_batch(
    X: np.ndarray,
    labels: np.ndarray,
    batch: np.ndarray,
    rescale: bool = True,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "euclidean",
    between_cluster_distances: Literal["nearest", "mean_other", "furthest"] = "nearest",
) -> float:
    """Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.

    Default parameters ('euclidean', 'nearest') match scIB implementation.

    Additional options enable BRAS compatible usage (see :func:`~scib_metrics.metrics.bras` documentation).

    For every label, the silhouette of each cell with respect to its batch is computed and
    its absolute value taken; with ``rescale`` it becomes ``1 - |s|``. The per-label means
    are then averaged, so every scored label contributes equally regardless of its size.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    rescale
        Scale asw into the range [0, 1]. If True, higher values are better.
    chunk_size
        Size of chunks to process at a time for distance computation.
    metric
        The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
    between_cluster_distances
        Method for computing inter-cluster distances.
        - 'nearest': Standard silhouette (distance to nearest cluster)
        - 'mean_other': BRAS-specific (mean distance to all other clusters)
        - 'furthest': BRAS-specific (distance to furthest cluster)

    Returns
    -------
    silhouette score. ``np.nan`` if no label could be scored, i.e. every label contains
    either a single batch or as many batches as cells.
    """
    sil_dfs = []
    for group in np.unique(labels):
        labels_mask = labels == group
        X_subset = X[labels_mask]
        batch_subset = batch[labels_mask]
        n_batches = len(np.unique(batch_subset))

        # Skip labels where a batch silhouette cannot be formed: only one batch,
        # or every cell in its own batch.
        if (n_batches == 1) or (n_batches == X_subset.shape[0]):
            continue

        sil_per_group = silhouette_samples(
            X_subset,
            batch_subset,
            chunk_size=chunk_size,
            metric=metric,
            between_cluster_distances=between_cluster_distances,
        )

        # take only absolute value
        sil_per_group = np.abs(sil_per_group)

        if rescale:
            # scale s.t. highest number is optimal
            sil_per_group = 1 - sil_per_group

        sil_dfs.append(
            pd.DataFrame(
                {
                    "group": [group] * len(sil_per_group),
                    "silhouette_score": sil_per_group,
                }
            )
        )

    if not sil_dfs:
        # Every label was skipped; pd.concat([]) would raise "No objects to concatenate".
        return np.nan

    sil_df = pd.concat(sil_dfs).reset_index(drop=True)
    # Average within each label first, then across labels (equal weight per label).
    sil_means = sil_df.groupby("group").mean()
    asw = sil_means["silhouette_score"].mean()

    return asw
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def bras(
    X: np.ndarray,
    labels: np.ndarray,
    batch: np.ndarray,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "cosine",
    between_cluster_distances: Literal["mean_other", "furthest"] = "mean_other",
) -> float:
    """Batch removal adapted silhouette (BRAS) for single-cell data integration assessment :cite:p:`rautenstrauch2025`.

    BRAS evaluates batch effect removal with respect to batch ids within each label (cell type cluster),
    using a modified silhouette score that accounts for nested batch effects. Unlike standard silhouette,
    BRAS computes between-cluster distances using the `between_cluster_distances` method rather than the
    nearest-cluster approach. The score is always rescaled to [0, 1]; a higher score indicates better
    batch mixing.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    chunk_size
        Size of chunks to process at a time for distance computation.
    metric
        The distance metric to use, either 'cosine' (default) or 'euclidean'.
    between_cluster_distances
        Method for computing inter-cluster distances.
        - 'mean_other': Mean distance to all cells in other clusters (default)
        - 'furthest': Distance to furthest cluster (conservative estimate)

    Returns
    -------
    BRAS score
    """
    # Keyword arguments guard against silent mis-binding if the signature of
    # silhouette_batch is ever reordered.
    return silhouette_batch(
        X,
        labels,
        batch,
        rescale=True,
        chunk_size=chunk_size,
        metric=metric,
        between_cluster_distances=between_cluster_distances,
    )
|
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
1
4
|
import jax
|
|
2
5
|
import jax.numpy as jnp
|
|
3
6
|
import numpy as np
|
|
@@ -10,10 +13,20 @@ def _euclidean_distance(x: np.array, y: np.array) -> float:
|
|
|
10
13
|
|
|
11
14
|
|
|
12
15
|
@jax.jit
def _cosine_distance(x: np.array, y: np.array) -> float:
    """Cosine distance ``1 - <x, y> / (||x|| * ||y||)`` between two vectors.

    The result is clipped to [0, 2] to absorb floating-point rounding. When either
    vector has zero norm the cosine similarity is undefined; it is treated as 0,
    which yields a distance of 1.0 instead of NaN.
    """
    xy = jnp.dot(x, y)
    # Multiply the norms instead of taking sqrt(xx * yy): the product of squared
    # norms under/overflows much earlier, especially in float32.
    denom = jnp.sqrt(jnp.dot(x, x)) * jnp.sqrt(jnp.dot(y, y))
    nonzero = denom > 0
    # Divide by a safe denominator so the discarded branch of jnp.where holds no NaN.
    sim = jnp.where(nonzero, xy / jnp.where(nonzero, denom, 1.0), 0.0)
    # Clip the result to avoid rounding error
    return jnp.clip(1.0 - sim, 0.0, 2.0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@partial(jax.jit, static_argnames=["metric"])
def cdist(x: np.ndarray, y: np.ndarray, metric: Literal["euclidean", "cosine"] = "euclidean") -> jnp.ndarray:
    """Jax implementation of :func:`scipy.spatial.distance.cdist`.

    Uses euclidean distance by default, cosine distance is also available.

    Parameters
    ----------
    x
        Array of shape (n_cells_a, n_features)
    y
        Array of shape (n_cells_b, n_features)
    metric
        The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.

    Returns
    -------
    dist
        Array of shape (n_cells_a, n_cells_b)
    """
    # `metric` is static under jit, so this validation runs once at trace time.
    if metric not in ("euclidean", "cosine"):
        raise ValueError("Invalid metric choice, must be one of ['euclidean' or 'cosine'].")
    pairwise = _cosine_distance if metric == "cosine" else _euclidean_distance

    def _distances_from(row):
        # Distances from one row of x to every row of y -> shape (n_cells_b,).
        return jax.vmap(lambda other: pairwise(row, other))(y)

    # Map over the rows of x -> shape (n_cells_a, n_cells_b).
    return jax.vmap(_distances_from)(x)
|
|
31
51
|
|
|
32
52
|
|
|
33
53
|
@jax.jit
|