scib-metrics 0.5.4__tar.gz → 0.5.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.pre-commit-config.yaml +1 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/CHANGELOG.md +12 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/PKG-INFO +1 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/api.md +1 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/references.bib +10 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/pyproject.toml +1 -1
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/__init__.py +2 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/__init__.py +2 -1
- scib_metrics-0.5.5/src/scib_metrics/metrics/_silhouette.py +162 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_dist.py +23 -3
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_silhouette.py +62 -13
- scib_metrics-0.5.5/tests/test_BRAS_metric.py +347 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/test_metrics.py +6 -0
- scib_metrics-0.5.4/src/scib_metrics/metrics/_silhouette.py +0 -86
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.codecov.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.cruft.json +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.editorconfig +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/ISSUE_TEMPLATE/bug_report.yml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/ISSUE_TEMPLATE/feature_request.yml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/build.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/release.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/test_linux.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/test_linux_cuda.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/test_linux_pre.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/test_macos.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/test_macos_m1.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.github/workflows/test_windows.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.gitignore +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/.readthedocs.yaml +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/LICENSE +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/README.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/Makefile +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/_static/.gitkeep +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/_static/css/custom.css +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/_templates/.gitkeep +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/_templates/autosummary/class.rst +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/_templates/class_no_inherited.rst +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/changelog.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/conf.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/contributing.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/extensions/.gitkeep +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/extensions/typed_returns.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/index.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/notebooks/large_scale.ipynb +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/notebooks/lung_example.ipynb +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/references.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/template_usage.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/docs/tutorials.md +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/setup.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/_settings.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/_types.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/benchmark/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/benchmark/_core.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/_graph_connectivity.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/_isolated_labels.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/_kbet.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/_lisi.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/_nmi_ari.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/metrics/_pcr_comparison.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/nearest_neighbors/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/nearest_neighbors/_dataclass.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/nearest_neighbors/_jax.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/nearest_neighbors/_pynndescent.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_diffusion_nn.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_kmeans.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_lisi.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_pca.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_pcr.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/utils/_utils.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/test_benchmarker.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/test_neighbors.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/test_pcr_comparison.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/utils/__init__.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/utils/data.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/utils/sampling.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/utils/test_pca.py +0 -0
- {scib_metrics-0.5.4 → scib_metrics-0.5.5}/tests/utils/test_pcr.py +0 -0
|
@@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning][].
|
|
|
10
10
|
|
|
11
11
|
## 0.6.0 (unreleased)
|
|
12
12
|
|
|
13
|
+
## 0.5.5 (2025-06-03)
|
|
14
|
+
|
|
15
|
+
### Added
|
|
16
|
+
|
|
17
|
+
- Add batch removal adapted silhouette (BRAS) metric ({func}`scib_metrics.metrics.bras`) {pr}`197`, which addresses limitations of silhouette for scoring batch effect removal.
|
|
18
|
+
- Add cosine distance implementation required for BRAS.
|
|
19
|
+
|
|
20
|
+
### Changed
|
|
21
|
+
|
|
22
|
+
- Changed {func}`scib_metrics.utils.cdist` to support cosine distance.
|
|
23
|
+
- Changed silhouette-related functions to support the adaptations required for BRAS.
|
|
24
|
+
|
|
13
25
|
## 0.5.4 (2025-04-23)
|
|
14
26
|
|
|
15
27
|
### Fixed
|
|
@@ -36,3 +36,13 @@
|
|
|
36
36
|
pages = {43--49},
|
|
37
37
|
publisher = {Springer Science and Business Media {LLC}}
|
|
38
38
|
}
|
|
39
|
+
|
|
40
|
+
@article{rautenstrauch2025,
|
|
41
|
+
title = {Metrics Matter: Why We Need to Stop Using Silhouette in Single-Cell Benchmarking},
|
|
42
|
+
author = {Pia Rautenstrauch and Uwe Ohler},
|
|
43
|
+
doi = {10.1101/2025.01.21.634098},
|
|
44
|
+
year = {2025},
|
|
45
|
+
month = jan,
|
|
46
|
+
journal = {bioRxiv},
|
|
47
|
+
publisher = {Cold Spring Harbor Laboratory}
|
|
48
|
+
}
|
|
@@ -15,6 +15,7 @@ from .metrics import (
|
|
|
15
15
|
pcr_comparison,
|
|
16
16
|
silhouette_batch,
|
|
17
17
|
silhouette_label,
|
|
18
|
+
bras,
|
|
18
19
|
)
|
|
19
20
|
from ._settings import settings
|
|
20
21
|
|
|
@@ -25,6 +26,7 @@ __all__ = [
|
|
|
25
26
|
"pcr_comparison",
|
|
26
27
|
"silhouette_label",
|
|
27
28
|
"silhouette_batch",
|
|
29
|
+
"bras",
|
|
28
30
|
"ilisi_knn",
|
|
29
31
|
"clisi_knn",
|
|
30
32
|
"lisi_knn",
|
|
@@ -4,13 +4,14 @@ from ._kbet import kbet, kbet_per_label
|
|
|
4
4
|
from ._lisi import clisi_knn, ilisi_knn, lisi_knn
|
|
5
5
|
from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden
|
|
6
6
|
from ._pcr_comparison import pcr_comparison
|
|
7
|
-
from ._silhouette import silhouette_batch, silhouette_label
|
|
7
|
+
from ._silhouette import bras, silhouette_batch, silhouette_label
|
|
8
8
|
|
|
9
9
|
__all__ = [
|
|
10
10
|
"isolated_labels",
|
|
11
11
|
"pcr_comparison",
|
|
12
12
|
"silhouette_label",
|
|
13
13
|
"silhouette_batch",
|
|
14
|
+
"bras",
|
|
14
15
|
"ilisi_knn",
|
|
15
16
|
"clisi_knn",
|
|
16
17
|
"lisi_knn",
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from scib_metrics.utils import silhouette_samples
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def silhouette_label(
    X: np.ndarray,
    labels: np.ndarray,
    rescale: bool = True,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "euclidean",
) -> float:
    """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.

    Default parameters ('euclidean') match scIB implementation.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    rescale
        Scale asw into the range [0, 1].
    chunk_size
        Size of chunks to process at a time for distance computation
    metric
        The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.

    Returns
    -------
    silhouette score
    """
    per_cell_scores = silhouette_samples(X, labels, chunk_size=chunk_size, metric=metric)
    score = np.mean(per_cell_scores)
    # map the raw silhouette range [-1, 1] onto [0, 1]
    score = (score + 1) / 2 if rescale else score
    return np.mean(score)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def silhouette_batch(
    X: np.ndarray,
    labels: np.ndarray,
    batch: np.ndarray,
    rescale: bool = True,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "euclidean",
    between_cluster_distances: Literal["nearest", "mean_other", "furthest"] = "nearest",
) -> float:
    """Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.

    Default parameters ('euclidean', 'nearest') match scIB implementation.

    Additional options enable BRAS compatible usage (see :func:`~scib_metrics.metrics.bras` documentation).

    Label groups that contain a single batch, or exactly one cell per batch, are skipped. The per-cell
    scores of each remaining group are averaged, and the result is the mean over groups.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    rescale
        Scale asw into the range [0, 1]. If True, higher values are better.
    chunk_size
        Size of chunks to process at a time for distance computation.
    metric
        The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
    between_cluster_distances
        Method for computing inter-cluster distances.
        - 'nearest': Standard silhouette (distance to nearest cluster)
        - 'mean_other': BRAS-specific (mean distance to all other clusters)
        - 'furthest': BRAS-specific (distance to furthest cluster)

    Returns
    -------
    silhouette score

    Raises
    ------
    ValueError
        If no label group can be scored, i.e. every label group contains either a single batch
        or exactly one cell per batch.
    """
    sil_dfs = []
    for group in np.unique(labels):
        labels_mask = labels == group
        X_subset = X[labels_mask]
        batch_subset = batch[labels_mask]
        n_batches = len(np.unique(batch_subset))

        # silhouette over batches is undefined with one batch or with one cell per batch
        if (n_batches == 1) or (n_batches == X_subset.shape[0]):
            continue

        sil_per_group = silhouette_samples(
            X_subset,
            batch_subset,
            chunk_size=chunk_size,
            metric=metric,
            between_cluster_distances=between_cluster_distances,
        )

        # take only absolute value: any separation between batches counts as poor mixing
        sil_per_group = np.abs(sil_per_group)

        if rescale:
            # scale s.t. highest number is optimal
            sil_per_group = 1 - sil_per_group

        sil_dfs.append(
            pd.DataFrame(
                {
                    "group": [group] * len(sil_per_group),
                    "silhouette_score": sil_per_group,
                }
            )
        )

    if not sil_dfs:
        # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate"
        raise ValueError(
            "Cannot compute batch silhouette: every label group contains either a single batch "
            "or exactly one cell per batch."
        )

    sil_df = pd.concat(sil_dfs).reset_index(drop=True)
    sil_means = sil_df.groupby("group").mean()
    asw = sil_means["silhouette_score"].mean()

    return asw
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def bras(
    X: np.ndarray,
    labels: np.ndarray,
    batch: np.ndarray,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "cosine",
    between_cluster_distances: Literal["mean_other", "furthest"] = "mean_other",
) -> float:
    """Batch removal adapted silhouette (BRAS) for single-cell data integration assessment :cite:p:`rautenstrauch2025`.

    BRAS evaluates batch effect removal with respect to batch ids within each label (cell type cluster),
    using a modified silhouette score that accounts for nested batch effects. Unlike standard silhouette,
    BRAS computes between-cluster distances using the `between_cluster_distances` method rather than the
    nearest-cluster approach. The score is always rescaled into the range [0, 1]; a higher score
    indicates better batch mixing.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    chunk_size
        Size of chunks to process at a time for distance computation.
    metric
        The distance metric to use. The distance function can be 'cosine' (default) or 'euclidean'.
    between_cluster_distances
        Method for computing inter-cluster distances.
        - 'mean_other': Mean distance to all cells in other clusters (default)
        - 'furthest': Distance to furthest cluster (conservative estimate)

    Returns
    -------
    BRAS score
    """
    # Forward by keyword so a change in silhouette_batch's parameter order cannot
    # silently bind these options to the wrong argument.
    return silhouette_batch(
        X,
        labels,
        batch,
        rescale=True,
        chunk_size=chunk_size,
        metric=metric,
        between_cluster_distances=between_cluster_distances,
    )
|
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
1
4
|
import jax
|
|
2
5
|
import jax.numpy as jnp
|
|
3
6
|
import numpy as np
|
|
@@ -10,10 +13,20 @@ def _euclidean_distance(x: np.array, y: np.array) -> float:
|
|
|
10
13
|
|
|
11
14
|
|
|
12
15
|
@jax.jit
|
|
13
|
-
def
|
|
16
|
+
def _cosine_distance(x: np.array, y: np.array) -> float:
|
|
17
|
+
xy = jnp.dot(x, y)
|
|
18
|
+
xx = jnp.dot(x, x)
|
|
19
|
+
yy = jnp.dot(y, y)
|
|
20
|
+
dist = 1.0 - xy / jnp.sqrt(xx * yy)
|
|
21
|
+
# Clip the result to avoid rounding error
|
|
22
|
+
return jnp.clip(dist, 0.0, 2.0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@partial(jax.jit, static_argnames=["metric"])
|
|
26
|
+
def cdist(x: np.ndarray, y: np.ndarray, metric: Literal["euclidean", "cosine"] = "euclidean") -> jnp.ndarray:
|
|
14
27
|
"""Jax implementation of :func:`scipy.spatial.distance.cdist`.
|
|
15
28
|
|
|
16
|
-
Uses euclidean distance.
|
|
29
|
+
Uses euclidean distance by default, cosine distance is also available.
|
|
17
30
|
|
|
18
31
|
Parameters
|
|
19
32
|
----------
|
|
@@ -21,13 +34,20 @@ def cdist(x: np.ndarray, y: np.ndarray) -> jnp.ndarray:
|
|
|
21
34
|
Array of shape (n_cells_a, n_features)
|
|
22
35
|
y
|
|
23
36
|
Array of shape (n_cells_b, n_features)
|
|
37
|
+
metric
|
|
38
|
+
The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
|
|
24
39
|
|
|
25
40
|
Returns
|
|
26
41
|
-------
|
|
27
42
|
dist
|
|
28
43
|
Array of shape (n_cells_a, n_cells_b)
|
|
29
44
|
"""
|
|
30
|
-
|
|
45
|
+
if metric not in ["euclidean", "cosine"]:
|
|
46
|
+
raise ValueError("Invalid metric choice, must be one of ['euclidean' or 'cosine'].")
|
|
47
|
+
if metric == "cosine":
|
|
48
|
+
return jax.vmap(lambda x1: jax.vmap(lambda y1: _cosine_distance(x1, y1))(y))(x)
|
|
49
|
+
else:
|
|
50
|
+
return jax.vmap(lambda x1: jax.vmap(lambda y1: _euclidean_distance(x1, y1))(y))(x)
|
|
31
51
|
|
|
32
52
|
|
|
33
53
|
@jax.jit
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from functools import partial
|
|
2
|
+
from typing import Literal
|
|
2
3
|
|
|
3
4
|
import jax
|
|
4
5
|
import jax.numpy as jnp
|
|
@@ -9,13 +10,21 @@ from ._dist import cdist
|
|
|
9
10
|
from ._utils import get_ndarray
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
@jax.jit
|
|
13
|
+
@partial(jax.jit, static_argnames=["between_cluster_distances"])
|
|
13
14
|
def _silhouette_reduce(
|
|
14
|
-
D_chunk: jnp.ndarray,
|
|
15
|
+
D_chunk: jnp.ndarray,
|
|
16
|
+
start: int,
|
|
17
|
+
labels: jnp.ndarray,
|
|
18
|
+
label_freqs: jnp.ndarray,
|
|
19
|
+
between_cluster_distances: Literal["nearest", "mean_other", "furthest"] = "nearest",
|
|
15
20
|
) -> tuple[jnp.ndarray, jnp.ndarray]:
|
|
16
21
|
"""Accumulate silhouette statistics for vertical chunk of X.
|
|
17
22
|
|
|
18
|
-
Follows scikit-learn implementation.
|
|
23
|
+
Follows scikit-learn implementation with default parameter usage ('nearest').
|
|
24
|
+
|
|
25
|
+
Additional options enable BRAS compatible usage, addressing specific limitations of using silhouette in the context
|
|
26
|
+
of evaluating data integration (see :func:`~scib_metrics.metrics.bras` documentation).
|
|
27
|
+
|
|
19
28
|
|
|
20
29
|
Parameters
|
|
21
30
|
----------
|
|
@@ -29,6 +38,12 @@ def _silhouette_reduce(
|
|
|
29
38
|
Corresponding cluster labels, encoded as {0, ..., n_clusters-1}.
|
|
30
39
|
label_freqs
|
|
31
40
|
Distribution of cluster labels in ``labels``.
|
|
41
|
+
between_cluster_distances
|
|
42
|
+
Method for computing inter-cluster distances.
|
|
43
|
+
- 'nearest': Standard silhouette (distance to nearest cluster)
|
|
44
|
+
- 'mean_other': BRAS-specific (mean distance to all other clusters)
|
|
45
|
+
- 'furthest': BRAS-specific (distance to furthest cluster)
|
|
46
|
+
|
|
32
47
|
"""
|
|
33
48
|
# accumulate distances from each sample to each cluster
|
|
34
49
|
D_chunk_len = D_chunk.shape[0]
|
|
@@ -43,21 +58,36 @@ def _silhouette_reduce(
|
|
|
43
58
|
# clust_dists = jax.lax.fori_loop(
|
|
44
59
|
# 0, D_chunk_len, lambda i, _data: _bincount(i, _data), (clust_dists, D_chunk, labels, label_freqs)
|
|
45
60
|
# )[0]
|
|
46
|
-
|
|
47
61
|
clust_dists = jax.vmap(partial(jnp.bincount, length=label_freqs.shape[0]), in_axes=(None, 0))(labels, D_chunk)
|
|
48
62
|
|
|
49
63
|
# intra_index selects intra-cluster distances within clust_dists
|
|
50
64
|
intra_index = (jnp.arange(D_chunk_len), jax.lax.dynamic_slice(labels, (start,), (D_chunk_len,)))
|
|
51
65
|
# intra_clust_dists are averaged over cluster size outside this function
|
|
52
66
|
intra_clust_dists = clust_dists[intra_index]
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
67
|
+
|
|
68
|
+
if between_cluster_distances == "furthest":
|
|
69
|
+
# of the remaining distances we normalise and extract the maximum
|
|
70
|
+
clust_dists = clust_dists.at[intra_index].set(-jnp.inf)
|
|
71
|
+
clust_dists /= label_freqs
|
|
72
|
+
inter_clust_dists = clust_dists.max(axis=1)
|
|
73
|
+
elif between_cluster_distances == "mean_other":
|
|
74
|
+
clust_dists = clust_dists.at[intra_index].set(jnp.nan)
|
|
75
|
+
total_other_dists = jnp.nansum(clust_dists, axis=1)
|
|
76
|
+
total_other_count = jnp.sum(label_freqs) - label_freqs[jax.lax.dynamic_slice(labels, (start,), (D_chunk_len,))]
|
|
77
|
+
inter_clust_dists = total_other_dists / total_other_count
|
|
78
|
+
elif between_cluster_distances == "nearest":
|
|
79
|
+
# of the remaining distances we normalise and extract the minimum
|
|
80
|
+
clust_dists = clust_dists.at[intra_index].set(jnp.inf)
|
|
81
|
+
clust_dists /= label_freqs
|
|
82
|
+
inter_clust_dists = clust_dists.min(axis=1)
|
|
83
|
+
else:
|
|
84
|
+
raise ValueError("Parameter 'between_cluster_distances' must be one of ['nearest', 'mean_other', 'furthest'].")
|
|
57
85
|
return intra_clust_dists, inter_clust_dists
|
|
58
86
|
|
|
59
87
|
|
|
60
|
-
def _pairwise_distances_chunked(
|
|
88
|
+
def _pairwise_distances_chunked(
|
|
89
|
+
X: jnp.ndarray, chunk_size: int, reduce_fn: callable, metric: Literal["euclidean", "cosine"] = "euclidean"
|
|
90
|
+
) -> jnp.ndarray:
|
|
61
91
|
"""Compute pairwise distances in chunks to reduce memory usage."""
|
|
62
92
|
n_samples = X.shape[0]
|
|
63
93
|
n_chunks = jnp.ceil(n_samples / chunk_size).astype(int)
|
|
@@ -66,17 +96,27 @@ def _pairwise_distances_chunked(X: jnp.ndarray, chunk_size: int, reduce_fn: call
|
|
|
66
96
|
for i in range(n_chunks):
|
|
67
97
|
start = i * chunk_size
|
|
68
98
|
end = min((i + 1) * chunk_size, n_samples)
|
|
69
|
-
intra_cluster_dists, inter_cluster_dists = reduce_fn(cdist(X[start:end], X), start=start)
|
|
99
|
+
intra_cluster_dists, inter_cluster_dists = reduce_fn(cdist(X[start:end], X, metric=metric), start=start)
|
|
70
100
|
intra_dists_all.append(intra_cluster_dists)
|
|
71
101
|
inter_dists_all.append(inter_cluster_dists)
|
|
72
102
|
return jnp.concatenate(intra_dists_all), jnp.concatenate(inter_dists_all)
|
|
73
103
|
|
|
74
104
|
|
|
75
|
-
def silhouette_samples(
|
|
105
|
+
def silhouette_samples(
|
|
106
|
+
X: np.ndarray,
|
|
107
|
+
labels: np.ndarray,
|
|
108
|
+
chunk_size: int = 256,
|
|
109
|
+
metric: Literal["euclidean", "cosine"] = "euclidean",
|
|
110
|
+
between_cluster_distances: Literal["nearest", "mean_other", "furthest"] = "nearest",
|
|
111
|
+
) -> np.ndarray:
|
|
76
112
|
"""Compute the Silhouette Coefficient for each observation.
|
|
77
113
|
|
|
78
114
|
Implements :func:`sklearn.metrics.silhouette_samples`.
|
|
79
115
|
|
|
116
|
+
Default parameters ('euclidean', 'nearest') match scIB implementation.
|
|
117
|
+
|
|
118
|
+
Additional options enable BRAS compatible usage (see `bras()` documentation).
|
|
119
|
+
|
|
80
120
|
Parameters
|
|
81
121
|
----------
|
|
82
122
|
X
|
|
@@ -87,6 +127,13 @@ def silhouette_samples(X: np.ndarray, labels: np.ndarray, chunk_size: int = 256)
|
|
|
87
127
|
for each observation.
|
|
88
128
|
chunk_size
|
|
89
129
|
Number of samples to process at a time for distance computation.
|
|
130
|
+
metric
|
|
131
|
+
The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
|
|
132
|
+
between_cluster_distances
|
|
133
|
+
Method for computing inter-cluster distances.
|
|
134
|
+
- 'nearest': Standard silhouette (distance to nearest cluster)
|
|
135
|
+
- 'mean_other': BRAS-specific (mean distance to all other clusters)
|
|
136
|
+
- 'furthest': BRAS-specific (distance to furthest cluster)
|
|
90
137
|
|
|
91
138
|
Returns
|
|
92
139
|
-------
|
|
@@ -97,8 +144,10 @@ def silhouette_samples(X: np.ndarray, labels: np.ndarray, chunk_size: int = 256)
|
|
|
97
144
|
labels = pd.Categorical(labels).codes
|
|
98
145
|
labels = jnp.asarray(labels)
|
|
99
146
|
label_freqs = jnp.bincount(labels)
|
|
100
|
-
reduce_fn = partial(
|
|
101
|
-
|
|
147
|
+
reduce_fn = partial(
|
|
148
|
+
_silhouette_reduce, labels=labels, label_freqs=label_freqs, between_cluster_distances=between_cluster_distances
|
|
149
|
+
)
|
|
150
|
+
results = _pairwise_distances_chunked(X, chunk_size=chunk_size, reduce_fn=reduce_fn, metric=metric)
|
|
102
151
|
intra_clust_dists, inter_clust_dists = results
|
|
103
152
|
|
|
104
153
|
denom = jnp.take(label_freqs - 1, labels, mode="clip")
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from sklearn.metrics import silhouette_samples as sk_silhouette_samples
|
|
6
|
+
from sklearn.metrics.pairwise import pairwise_distances
|
|
7
|
+
|
|
8
|
+
import scib_metrics
|
|
9
|
+
from tests.utils.data import dummy_benchmarker_adata, dummy_x_labels
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# v 0.5.3 with modifications for BRAS usage
|
|
13
|
+
def silhouette_batch_custom(
    X: np.ndarray,
    labels: np.ndarray,
    batch: np.ndarray,
    rescale: bool = True,
    chunk_size: int = 256,
    metric: Literal["euclidean", "cosine"] = "euclidean",
    between_cluster_distances: Literal["nearest", "mean_other", "furthest"] = "nearest",
) -> float:
    """Reference batch ASW computed within each label, built on the naive per-cell silhouette.

    Mirrors the package's ``silhouette_batch`` (v0.5.3 plus BRAS options) but delegates per-cell
    scoring to ``silhouette_samples_custom`` so the fast implementation can be checked against it.
    Default parameters ('euclidean', 'nearest') match scIB implementation.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    rescale
        Scale asw into the range [0, 1]. If True, higher values are better.
    chunk_size
        Accepted for signature parity with the fast implementation; not used here.
    metric
        The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
    between_cluster_distances
        Method for computing inter-cluster distances.
        - 'nearest': Standard silhouette (distance to nearest cluster)
        - 'mean_other': BRAS-specific (mean distance to all other clusters)
        - 'furthest': BRAS-specific (distance to furthest cluster)

    Returns
    -------
    silhouette score
    """
    frames = []
    for group in np.unique(labels):
        in_group = labels == group
        X_group, batch_group = X[in_group], batch[in_group]
        batch_count = np.unique(batch_group).size

        # skip groups where a batch silhouette is meaningless
        if batch_count == 1 or batch_count == X_group.shape[0]:
            continue

        scores = np.abs(
            silhouette_samples_custom(
                X_group,
                batch_group,
                metric=metric,
                between_cluster_distances=between_cluster_distances,
            )
        )
        if rescale:
            scores = 1 - scores

        frames.append(pd.DataFrame({"group": [group] * len(scores), "silhouette_score": scores}))

    per_group = pd.concat(frames).reset_index(drop=True).groupby("group").mean()
    return per_group["silhouette_score"].mean()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def silhouette_samples_custom(X, cluster_labels, metric="euclidean", between_cluster_distances="nearest"):
    """
    Naive implementation of silhouette score modifications changing inter-cluster distance calculations for testing fast
    implementations.

    Experimental variants include:
    - Standard silhouette ('nearest' cluster distance)
    - BRAS-specific modifications ('mean_other', 'furthest')

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    cluster_labels
        Array of shape (n_cells,) representing cluster label values
    metric
        The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
    between_cluster_distances
        Method for computing inter-cluster distances.
        - 'nearest': Standard silhouette (distance to nearest cluster)
        - 'mean_other': BRAS-specific (mean distance to all other clusters)
        - 'furthest': BRAS-specific (distance to furthest cluster)

    Returns
    -------
    Array of shape (n_cells,) with the (modified) silhouette score of each cell for the selected
    inter-cluster distance calculation. Cells in singleton clusters, and all cells when there are fewer
    than two clusters, get a score of 0.

    Raises
    ------
    ValueError
        If ``between_cluster_distances`` is not one of the supported options.
    """
    # Reject unknown modes like the fast implementation does, instead of silently using 'nearest'.
    if between_cluster_distances not in ("nearest", "mean_other", "furthest"):
        raise ValueError("Parameter 'between_cluster_distances' must be one of ['nearest', 'mean_other', 'furthest'].")

    unique_cluster_labels = np.unique(cluster_labels)
    silhouette_scores = np.zeros(len(X))

    # Silhouette cannot be computed with fewer than two clusters; return an all-zero array
    # (rather than a scalar) so callers can still take len() of the result.
    if len(unique_cluster_labels) < 2:
        return silhouette_scores

    distance_matrix = pairwise_distances(X, metric=metric)

    for i in range(len(X)):
        # Points in the same cluster, excluding the current point itself
        same_cluster = cluster_labels == cluster_labels[i]
        other_clusters = cluster_labels != cluster_labels[i]
        same_cluster[i] = False

        # singleton cluster: score stays 0
        if not np.any(same_cluster):
            continue

        # a: Mean distance from i to all other points in the same cluster
        a = np.mean(distance_matrix[i, same_cluster])

        if between_cluster_distances == "mean_other":
            # b: Mean distance from i to all points in any other cluster
            b = np.mean(distance_matrix[i, other_clusters])
        else:
            per_cluster_means = [
                np.mean(distance_matrix[i, cluster_labels == label])
                for label in unique_cluster_labels
                if label != cluster_labels[i]
            ]
            # b: Mean distance to the furthest ('furthest') or nearest ('nearest') other cluster
            b = np.max(per_cluster_means) if between_cluster_distances == "furthest" else np.min(per_cluster_means)

        # Coincident points give a == b == 0; define the score as 0 (as sklearn does) instead of 0 / 0 = NaN.
        denominator = max(a, b)
        silhouette_scores[i] = (b - a) / denominator if denominator > 0 else 0.0

    return silhouette_scores
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_silhouette_samples_cosine():
    """Per-cell silhouette with the cosine metric matches the reference implementation."""
    X, labels = dummy_x_labels()
    options = {"metric": "cosine"}
    computed = scib_metrics.utils.silhouette_samples(X, labels, **options)
    reference = silhouette_samples_custom(X, labels, **options)
    assert np.allclose(computed, reference, atol=1e-5)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def test_silhouette_samples_nearest():
    """Per-cell silhouette using nearest-cluster distances matches the reference implementation."""
    X, labels = dummy_x_labels()
    options = {"between_cluster_distances": "nearest"}
    computed = scib_metrics.utils.silhouette_samples(X, labels, **options)
    reference = silhouette_samples_custom(X, labels, **options)
    assert np.allclose(computed, reference, atol=1e-5)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def test_silhouette_samples_mean_other():
    """Per-cell silhouette using mean distance to all other clusters matches the reference implementation."""
    X, labels = dummy_x_labels()
    options = {"between_cluster_distances": "mean_other"}
    computed = scib_metrics.utils.silhouette_samples(X, labels, **options)
    reference = silhouette_samples_custom(X, labels, **options)
    assert np.allclose(computed, reference, atol=1e-5)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def test_silhouette_samples_furthest():
    """Per-cell silhouette using furthest-cluster distances matches the reference implementation."""
    X, labels = dummy_x_labels()
    options = {"between_cluster_distances": "furthest"}
    computed = scib_metrics.utils.silhouette_samples(X, labels, **options)
    reference = silhouette_samples_custom(X, labels, **options)
    assert np.allclose(computed, reference, atol=1e-5)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def test_silhouette_label():
    """Label ASW equals scikit-learn's mean silhouette rescaled from [-1, 1] to [0, 1]."""
    X, labels = dummy_x_labels()
    score = scib_metrics.silhouette_label(X, labels)
    raw_sk_mean = np.mean(sk_silhouette_samples(X, labels))
    expected = (raw_sk_mean + 1) / 2
    assert np.allclose(score, expected)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def test_silhouette_label_cosine():
    """Cosine-metric label ASW equals scikit-learn's cosine silhouette, rescaled to [0, 1]."""
    X, labels = dummy_x_labels()
    score = scib_metrics.silhouette_label(X, labels, metric="cosine")
    raw_sk_mean = np.mean(sk_silhouette_samples(X, labels, metric="cosine"))
    expected = (raw_sk_mean + 1) / 2
    assert np.allclose(score, expected)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def test_bras():
    """BRAS with default settings matches the batch silhouette reference using cosine/mean_other."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    embedding = ad.obsm[emb_keys[0]]
    cell_labels = ad.obs[labels_key]
    batches = ad.obs[batch_key]

    score = scib_metrics.bras(embedding, cell_labels, batches)
    expected = silhouette_batch_custom(
        embedding,
        cell_labels,
        batches,
        metric="cosine",
        between_cluster_distances="mean_other",
    )
    assert np.allclose(score, expected)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def test_silhouette_batch_default():
    """Batch ASW with default arguments matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    score = scib_metrics.silhouette_batch(*inputs)
    expected = silhouette_batch_custom(*inputs)
    assert np.allclose(score, expected)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def test_silhouette_batch_cosine():
    """Batch ASW with the cosine metric matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"metric": "cosine"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def test_silhouette_batch_nearest():
    """Batch ASW using nearest-cluster distances matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"between_cluster_distances": "nearest"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def test_silhouette_batch_cosine_nearest():
    """Batch ASW with cosine metric and nearest-cluster distances matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"metric": "cosine", "between_cluster_distances": "nearest"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def test_silhouette_batch_furthest():
    """Batch ASW using furthest-cluster distances matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"between_cluster_distances": "furthest"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def test_silhouette_batch_cosine_furthest():
    """Batch ASW with cosine metric and furthest-cluster distances matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"metric": "cosine", "between_cluster_distances": "furthest"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def test_silhouette_batch_mean_other():
    """Batch ASW using mean distance to all other clusters matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"between_cluster_distances": "mean_other"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def test_silhouette_batch_cosine_mean_other():
    """Batch ASW with cosine metric and mean_other distances matches the reference implementation."""
    ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata()
    inputs = (ad.obsm[emb_keys[0]], ad.obs[labels_key], ad.obs[batch_key])
    options = {"metric": "cosine", "between_cluster_distances": "mean_other"}
    score = scib_metrics.silhouette_batch(*inputs, **options)
    expected = silhouette_batch_custom(*inputs, **options)
    assert np.allclose(score, expected)
|
|
@@ -32,6 +32,12 @@ def test_cdist():
|
|
|
32
32
|
assert np.allclose(scib_metrics.utils.cdist(x, y), sp_cdist(x, y))
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
def test_cdist_cosine():
    """Cosine cdist agrees with scipy's cdist on a small 2x2 example."""
    x = jnp.array([[1, 2], [3, 4]])
    y = jnp.array([[5, 6], [7, 8]])
    ours = scib_metrics.utils.cdist(x, y, metric="cosine")
    theirs = sp_cdist(x, y, metric="cosine")
    assert np.allclose(ours, theirs, atol=1e-5)
|
|
39
|
+
|
|
40
|
+
|
|
35
41
|
def test_pdist():
    """Square-form pairwise distances agree with scipy's pdist + squareform."""
    x = jnp.array([[1, 2], [3, 4]])
    ours = scib_metrics.utils.pdist_squareform(x)
    theirs = squareform(pdist(x))
    assert np.allclose(ours, theirs)
|
|
@@ -1,86 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pandas as pd
|
|
3
|
-
|
|
4
|
-
from scib_metrics.utils import silhouette_samples
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True, chunk_size: int = 256) -> float:
    """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    rescale
        Scale asw into the range [0, 1].
    chunk_size
        Size of chunks to process at a time for distance computation.

    Returns
    -------
    silhouette score
    """
    per_cell_widths = silhouette_samples(X, labels, chunk_size=chunk_size)
    score = np.mean(per_cell_widths)
    # Map the silhouette range [-1, 1] linearly onto [0, 1].
    score = (score + 1) / 2 if rescale else score
    return np.mean(score)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def silhouette_batch(
    X: np.ndarray, labels: np.ndarray, batch: np.ndarray, rescale: bool = True, chunk_size: int = 256
) -> float:
    """Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.

    For each label group, per-cell silhouette widths are computed using the batch ids
    of that group's cells as cluster assignments. Their absolute values are taken and,
    if requested, transformed to ``1 - |s|``. The result is the unweighted mean of the
    per-group means.

    Groups whose cells all come from a single batch, or where every cell has its own
    batch, are skipped because the silhouette is not meaningful there.

    Parameters
    ----------
    X
        Array of shape (n_cells, n_features).
    labels
        Array of shape (n_cells,) representing label values
    batch
        Array of shape (n_cells,) representing batch values
    rescale
        Scale asw into the range [0, 1]. If True, higher values are better.
    chunk_size
        Size of chunks to process at a time for distance computation.

    Returns
    -------
    silhouette score

    Raises
    ------
    ValueError
        If every label group was skipped, so no score can be computed.
    """
    group_means = []
    for group in np.unique(labels):
        labels_mask = labels == group
        X_subset = X[labels_mask]
        batch_subset = batch[labels_mask]
        n_batches = len(np.unique(batch_subset))

        # A single batch gives one cluster (silhouette undefined). One batch per cell
        # gives singleton clusters (degenerate). Skip both cases.
        if n_batches == 1 or n_batches == X_subset.shape[0]:
            continue

        # Only the magnitude matters: widths near 0 indicate well-mixed batches.
        sil_per_group = np.abs(silhouette_samples(X_subset, batch_subset, chunk_size=chunk_size))

        if rescale:
            # Flip so that the highest value (1) is optimal.
            sil_per_group = 1 - sil_per_group

        group_means.append(np.mean(sil_per_group))

    if not group_means:
        # Previously this surfaced as pandas' opaque "No objects to concatenate".
        raise ValueError(
            "silhouette_batch: no label group contains cells from at least two batches "
            "(with fewer batches than cells); the score is undefined."
        )

    # Unweighted mean over groups, matching the former per-group groupby().mean() then mean().
    return float(np.mean(group_means))
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{scib_metrics-0.5.4 → scib_metrics-0.5.5}/src/scib_metrics/nearest_neighbors/_pynndescent.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|