scib-metrics 0.5.4__tar.gz → 0.5.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.pre-commit-config.yaml +1 -1
  2. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/CHANGELOG.md +27 -0
  3. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/PKG-INFO +1 -1
  4. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/api.md +1 -0
  5. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/references.bib +10 -0
  6. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/pyproject.toml +1 -1
  7. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/__init__.py +2 -0
  8. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/benchmark/_core.py +25 -17
  9. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/__init__.py +2 -1
  10. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_kbet.py +46 -47
  11. scib_metrics-0.5.6/src/scib_metrics/metrics/_silhouette.py +162 -0
  12. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_dist.py +23 -3
  13. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_silhouette.py +62 -13
  14. scib_metrics-0.5.6/tests/test_BRAS_metric.py +347 -0
  15. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_metrics.py +6 -0
  16. scib_metrics-0.5.4/src/scib_metrics/metrics/_silhouette.py +0 -86
  17. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.codecov.yaml +0 -0
  18. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.cruft.json +0 -0
  19. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.editorconfig +0 -0
  20. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/ISSUE_TEMPLATE/bug_report.yml +0 -0
  21. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  22. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/ISSUE_TEMPLATE/feature_request.yml +0 -0
  23. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/build.yaml +0 -0
  24. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/release.yaml +0 -0
  25. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_linux.yaml +0 -0
  26. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_linux_cuda.yaml +0 -0
  27. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_linux_pre.yaml +0 -0
  28. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_macos.yaml +0 -0
  29. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_macos_m1.yaml +0 -0
  30. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.github/workflows/test_windows.yaml +0 -0
  31. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.gitignore +0 -0
  32. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/.readthedocs.yaml +0 -0
  33. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/LICENSE +0 -0
  34. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/README.md +0 -0
  35. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/Makefile +0 -0
  36. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_static/.gitkeep +0 -0
  37. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_static/css/custom.css +0 -0
  38. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_templates/.gitkeep +0 -0
  39. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_templates/autosummary/class.rst +0 -0
  40. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/_templates/class_no_inherited.rst +0 -0
  41. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/changelog.md +0 -0
  42. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/conf.py +0 -0
  43. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/contributing.md +0 -0
  44. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/extensions/.gitkeep +0 -0
  45. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/extensions/typed_returns.py +0 -0
  46. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/index.md +0 -0
  47. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/notebooks/large_scale.ipynb +0 -0
  48. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/notebooks/lung_example.ipynb +0 -0
  49. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/references.md +0 -0
  50. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/template_usage.md +0 -0
  51. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/docs/tutorials.md +0 -0
  52. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/setup.py +0 -0
  53. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/_settings.py +0 -0
  54. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/_types.py +0 -0
  55. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/benchmark/__init__.py +0 -0
  56. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_graph_connectivity.py +0 -0
  57. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_isolated_labels.py +0 -0
  58. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_lisi.py +0 -0
  59. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_nmi_ari.py +0 -0
  60. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/metrics/_pcr_comparison.py +0 -0
  61. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/__init__.py +0 -0
  62. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/_dataclass.py +0 -0
  63. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/_jax.py +0 -0
  64. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/nearest_neighbors/_pynndescent.py +0 -0
  65. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/__init__.py +0 -0
  66. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_diffusion_nn.py +0 -0
  67. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_kmeans.py +0 -0
  68. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_lisi.py +0 -0
  69. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_pca.py +0 -0
  70. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_pcr.py +0 -0
  71. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/src/scib_metrics/utils/_utils.py +0 -0
  72. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/__init__.py +0 -0
  73. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_benchmarker.py +0 -0
  74. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_neighbors.py +0 -0
  75. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/test_pcr_comparison.py +0 -0
  76. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/__init__.py +0 -0
  77. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/data.py +0 -0
  78. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/sampling.py +0 -0
  79. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/test_pca.py +0 -0
  80. {scib_metrics-0.5.4 → scib_metrics-0.5.6}/tests/utils/test_pcr.py +0 -0
@@ -11,7 +11,7 @@ repos:
11
11
  hooks:
12
12
  - id: prettier
13
13
  - repo: https://github.com/astral-sh/ruff-pre-commit
14
- rev: v0.11.6
14
+ rev: v0.12.2
15
15
  hooks:
16
16
  - id: ruff
17
17
  types_or: [python, pyi, jupyter]
@@ -10,6 +10,33 @@ and this project adheres to [Semantic Versioning][].
10
10
 
11
11
  ## 0.6.0 (unreleased)
12
12
 
13
+ ## 0.5.6 (2025-07-08)
14
+
15
+ ### Added
16
+
17
+ - Add BRAS to Benchmarker as default, instead of regular silhouette batch {pr}`217`
18
+ - Added the option to manually set the KNN graphs before running a benchmarker.
19
+
20
+ ### Changed
21
+
22
+ - Changed default of min_max_scale in {func}`scib_metrics.benchmark.get_results` to False {pr}`215`.
23
+
24
+ ### Fixed
25
+
26
+ - Reverted "Skip labels before loop" {pr}`180`, which caused wrong selection of clusters {pr}`213`.
27
+
28
+ ## 0.5.5 (2025-06-03)
29
+
30
+ ### Added
31
+
32
+ - Add batch removal adapted silhouette (BRAS) metric ({func}`scib_metrics.metrics.bras`) {pr}`197`, which addresses limitations of silhouette for scoring batch effect removal.
33
+ - Add cosine distance implementation required for BRAS.
34
+
35
+ ### Changed
36
+
37
+ - Changed {func}`scib_metrics.utils.cdist` to support cosine distance.
38
+ - Changed silhouette-related functions to be compatible with adaptations required for BRAS.
39
+
13
40
  ## 0.5.4 (2025-04-23)
14
41
 
15
42
  ### Fixed
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scib-metrics
3
- Version: 0.5.4
3
+ Version: 0.5.6
4
4
  Summary: Accelerated and Python-only scIB metrics
5
5
  Project-URL: Documentation, https://scib-metrics.readthedocs.io/
6
6
  Project-URL: Source, https://github.com/yoseflab/scib-metrics
@@ -42,6 +42,7 @@ scib_metrics.ilisi_knn(...)
42
42
  pcr_comparison
43
43
  silhouette_label
44
44
  silhouette_batch
45
+ bras
45
46
  ilisi_knn
46
47
  clisi_knn
47
48
  kbet
@@ -36,3 +36,13 @@
36
36
  pages = {43--49},
37
37
  publisher = {Springer Science and Business Media {LLC}}
38
38
  }
39
+
40
+ @article{rautenstrauch2025,
41
+ title = {Metrics Matter: Why We Need to Stop Using Silhouette in Single-Cell Benchmarking},
42
+ author = {Pia Rautenstrauch and Uwe Ohler},
43
+ doi = {10.1101/2025.01.21.634098},
44
+ year = {2025},
45
+ month = jan,
46
+ journal = {bioRxiv},
47
+ publisher = {Cold Spring Harbor Laboratory}
48
+ }
@@ -5,7 +5,7 @@ requires = ["hatchling"]
5
5
 
6
6
  [project]
7
7
  name = "scib-metrics"
8
- version = "0.5.4"
8
+ version = "0.5.6"
9
9
  description = "Accelerated and Python-only scIB metrics"
10
10
  readme = "README.md"
11
11
  requires-python = ">=3.10"
@@ -15,6 +15,7 @@ from .metrics import (
15
15
  pcr_comparison,
16
16
  silhouette_batch,
17
17
  silhouette_label,
18
+ bras,
18
19
  )
19
20
  from ._settings import settings
20
21
 
@@ -25,6 +26,7 @@ __all__ = [
25
26
  "pcr_comparison",
26
27
  "silhouette_label",
27
28
  "silhouette_batch",
29
+ "bras",
28
30
  "ilisi_knn",
29
31
  "clisi_knn",
30
32
  "lisi_knn",
@@ -42,6 +42,7 @@ metric_name_cleaner = {
42
42
  "clisi_knn": "cLISI",
43
43
  "ilisi_knn": "iLISI",
44
44
  "kbet_per_label": "KBET",
45
+ "bras": "BRAS",
45
46
  "graph_connectivity": "Graph connectivity",
46
47
  "pcr_comparison": "PCR comparison",
47
48
  }
@@ -72,7 +73,7 @@ class BatchCorrection:
72
73
  parameters, such as `X` or `labels`.
73
74
  """
74
75
 
75
- silhouette_batch: MetricType = True
76
+ bras: MetricType = True
76
77
  ilisi_knn: MetricType = True
77
78
  kbet_per_label: MetricType = True
78
79
  graph_connectivity: MetricType = True
@@ -88,7 +89,7 @@ class MetricAnnDataAPI(Enum):
88
89
  silhouette_label = lambda ad, fn: fn(ad.X, ad.obs[_LABELS])
89
90
  clisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_LABELS])
90
91
  graph_connectivity = lambda ad, fn: fn(ad.uns["15_neighbor_res"], ad.obs[_LABELS])
91
- silhouette_batch = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH])
92
+ bras = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH])
92
93
  pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True)
93
94
  ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH])
94
95
  kbet_per_label = lambda ad, fn: fn(ad.uns["50_neighbor_res"], ad.obs[_BATCH], ad.obs[_LABELS])
@@ -156,6 +157,7 @@ class Benchmarker:
156
157
  self._label_key = label_key
157
158
  self._n_jobs = n_jobs
158
159
  self._progress_bar = progress_bar
160
+ self._compute_neighbors = True
159
161
 
160
162
  if self._bio_conservation_metrics is None and self._batch_correction_metrics is None:
161
163
  raise ValueError("Either batch or bio metrics must be defined.")
@@ -191,19 +193,25 @@ class Benchmarker:
191
193
  self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key]
192
194
 
193
195
  # Compute neighbors
194
- progress = self._emb_adatas.values()
195
- if self._progress_bar:
196
- progress = tqdm(progress, desc="Computing neighbors")
197
-
198
- for ad in progress:
199
- if neighbor_computer is not None:
200
- neigh_result = neighbor_computer(ad.X, max(self._neighbor_values))
201
- else:
202
- neigh_result = pynndescent(
203
- ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs
204
- )
205
- for n in self._neighbor_values:
206
- ad.uns[f"{n}_neighbor_res"] = neigh_result.subset_neighbors(n=n)
196
+ if self._compute_neighbors:
197
+ progress = self._emb_adatas.values()
198
+ if self._progress_bar:
199
+ progress = tqdm(progress, desc="Computing neighbors")
200
+
201
+ for ad in progress:
202
+ if neighbor_computer is not None:
203
+ neigh_result = neighbor_computer(ad.X, max(self._neighbor_values))
204
+ else:
205
+ neigh_result = pynndescent(
206
+ ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs
207
+ )
208
+ for n in self._neighbor_values:
209
+ ad.uns[f"{n}_neighbor_res"] = neigh_result.subset_neighbors(n=n)
210
+ else:
211
+ warnings.warn(
212
+ "Computing Neighbors Skipped",
213
+ UserWarning,
214
+ )
207
215
 
208
216
  self._prepared = True
209
217
 
@@ -251,7 +259,7 @@ class Benchmarker:
251
259
 
252
260
  self._benchmarked = True
253
261
 
254
- def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> pd.DataFrame:
262
+ def get_results(self, min_max_scale: bool = False, clean_names: bool = True) -> pd.DataFrame:
255
263
  """Return the benchmarking results.
256
264
 
257
265
  Parameters
@@ -291,7 +299,7 @@ class Benchmarker:
291
299
  df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE
292
300
  return df
293
301
 
294
- def plot_results_table(self, min_max_scale: bool = True, show: bool = True, save_dir: str | None = None) -> Table:
302
+ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, save_dir: str | None = None) -> Table:
295
303
  """Plot the benchmarking results.
296
304
 
297
305
  Parameters
@@ -4,13 +4,14 @@ from ._kbet import kbet, kbet_per_label
4
4
  from ._lisi import clisi_knn, ilisi_knn, lisi_knn
5
5
  from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden
6
6
  from ._pcr_comparison import pcr_comparison
7
- from ._silhouette import silhouette_batch, silhouette_label
7
+ from ._silhouette import bras, silhouette_batch, silhouette_label
8
8
 
9
9
  __all__ = [
10
10
  "isolated_labels",
11
11
  "pcr_comparison",
12
12
  "silhouette_label",
13
13
  "silhouette_batch",
14
+ "bras",
14
15
  "ilisi_knn",
15
16
  "clisi_knn",
16
17
  "lisi_knn",
@@ -138,14 +138,8 @@ def kbet_per_label(
138
138
  conn_graph = X.knn_graph_connectivities
139
139
 
140
140
  # prepare call of kBET per cluster
141
- clusters = []
142
- clusters, counts = np.unique(labels, return_counts=True)
143
- skipped = clusters[counts > 10]
144
- clusters = clusters[counts <= 10]
145
- kbet_scores = {"cluster": list(skipped), "kBET": [np.nan] * len(skipped)}
146
- logger.info(f"{len(skipped)} clusters consist of a single batch or are too small. Skip.")
147
-
148
- for clus in clusters:
141
+ kbet_scores = {"cluster": [], "kBET": []}
142
+ for clus in np.unique(labels):
149
143
  # subset by label
150
144
  mask = labels == clus
151
145
  conn_graph_sub = conn_graph[mask, :][:, mask]
@@ -153,55 +147,60 @@ def kbet_per_label(
153
147
  n_obs = conn_graph_sub.shape[0]
154
148
  batches_sub = batches[mask]
155
149
 
156
- quarter_mean = np.floor(np.mean(pd.Series(batches_sub).value_counts()) / 4).astype("int")
157
- k0 = np.min([70, np.max([10, quarter_mean])])
158
- # check k0 for reasonability
159
- if k0 * n_obs >= size_max:
160
- k0 = np.floor(size_max / n_obs).astype("int")
161
-
162
- n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")
163
-
164
- if n_comp == 1: # a single component to compute kBET on
165
- try:
166
- diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
167
- nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
168
- # call kBET
169
- score, _, _ = kbet(
170
- nn_graph_sub,
171
- batches=batches_sub,
172
- alpha=alpha,
173
- )
174
- except ValueError:
175
- logger.info("Diffusion distance failed. Skip.")
176
- score = 0 # i.e. 100% rejection
177
-
150
+ # check if neighborhood size too small or only one batch in subset
151
+ if np.logical_or(n_obs < 10, len(np.unique(batches_sub)) == 1):
152
+ logger.info(f"{clus} consists of a single batch or is too small. Skip.")
153
+ score = np.nan
178
154
  else:
179
- # check the number of components where kBET can be computed upon
180
- comp_size = pd.Series(labs).value_counts()
181
- # check which components are small
182
- comp_size_thresh = 3 * k0
183
- idx_nonan = np.flatnonzero(np.in1d(labs, comp_size[comp_size >= comp_size_thresh].index))
184
-
185
- # check if 75% of all cells can be used for kBET run
186
- if len(idx_nonan) / len(labs) >= 0.75:
187
- # create another subset of components, assume they are not visited in a diffusion process
188
- conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
189
- conn_graph_sub_sub.sort_indices()
155
+ quarter_mean = np.floor(np.mean(pd.Series(batches_sub).value_counts()) / 4).astype("int")
156
+ k0 = np.min([70, np.max([10, quarter_mean])])
157
+ # check k0 for reasonability
158
+ if k0 * n_obs >= size_max:
159
+ k0 = np.floor(size_max / n_obs).astype("int")
160
+
161
+ n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")
190
162
 
163
+ if n_comp == 1: # a single component to compute kBET on
191
164
  try:
192
- diffusion_n_comps = np.min([diffusion_n_comps, conn_graph_sub_sub.shape[0] - 1])
193
- nn_results_sub_sub = diffusion_nn(conn_graph_sub_sub, k=k0, n_comps=diffusion_n_comps)
165
+ diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
166
+ nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
194
167
  # call kBET
195
168
  score, _, _ = kbet(
196
- nn_results_sub_sub,
197
- batches=batches_sub[idx_nonan],
169
+ nn_graph_sub,
170
+ batches=batches_sub,
198
171
  alpha=alpha,
199
172
  )
200
173
  except ValueError:
201
174
  logger.info("Diffusion distance failed. Skip.")
202
175
  score = 0 # i.e. 100% rejection
203
- else: # if there are too many too small connected components, set kBET score to 0
204
- score = 0 # i.e. 100% rejection
176
+
177
+ else:
178
+ # check the number of components where kBET can be computed upon
179
+ comp_size = pd.Series(labs).value_counts()
180
+ # check which components are small
181
+ comp_size_thresh = 3 * k0
182
+ idx_nonan = np.flatnonzero(np.in1d(labs, comp_size[comp_size >= comp_size_thresh].index))
183
+
184
+ # check if 75% of all cells can be used for kBET run
185
+ if len(idx_nonan) / len(labs) >= 0.75:
186
+ # create another subset of components, assume they are not visited in a diffusion process
187
+ conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
188
+ conn_graph_sub_sub.sort_indices()
189
+
190
+ try:
191
+ diffusion_n_comps = np.min([diffusion_n_comps, conn_graph_sub_sub.shape[0] - 1])
192
+ nn_results_sub_sub = diffusion_nn(conn_graph_sub_sub, k=k0, n_comps=diffusion_n_comps)
193
+ # call kBET
194
+ score, _, _ = kbet(
195
+ nn_results_sub_sub,
196
+ batches=batches_sub[idx_nonan],
197
+ alpha=alpha,
198
+ )
199
+ except ValueError:
200
+ logger.info("Diffusion distance failed. Skip.")
201
+ score = 0 # i.e. 100% rejection
202
+ else: # if there are too many too small connected components, set kBET score to 0
203
+ score = 0 # i.e. 100% rejection
205
204
 
206
205
  kbet_scores["cluster"].append(clus)
207
206
  kbet_scores["kBET"].append(score)
@@ -0,0 +1,162 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from scib_metrics.utils import silhouette_samples
7
+
8
+
9
+ def silhouette_label(
10
+ X: np.ndarray,
11
+ labels: np.ndarray,
12
+ rescale: bool = True,
13
+ chunk_size: int = 256,
14
+ metric: Literal["euclidean", "cosine"] = "euclidean",
15
+ ) -> float:
16
+ """Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.
17
+
18
+ Default parameters ('euclidean') match scIB implementation.
19
+
20
+ Parameters
21
+ ----------
22
+ X
23
+ Array of shape (n_cells, n_features).
24
+ labels
25
+ Array of shape (n_cells,) representing label values
26
+ rescale
27
+ Scale asw into the range [0, 1].
28
+ chunk_size
29
+ Size of chunks to process at a time for distance computation
30
+ metric
31
+ The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
32
+
33
+ Returns
34
+ -------
35
+ silhouette score
36
+ """
37
+ asw = np.mean(silhouette_samples(X, labels, chunk_size=chunk_size, metric=metric))
38
+ if rescale:
39
+ asw = (asw + 1) / 2
40
+ return np.mean(asw)
41
+
42
+
43
+ def silhouette_batch(
44
+ X: np.ndarray,
45
+ labels: np.ndarray,
46
+ batch: np.ndarray,
47
+ rescale: bool = True,
48
+ chunk_size: int = 256,
49
+ metric: Literal["euclidean", "cosine"] = "euclidean",
50
+ between_cluster_distances: Literal["nearest", "mean_other", "furthest"] = "nearest",
51
+ ) -> float:
52
+ """Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.
53
+
54
+ Default parameters ('euclidean', 'nearest') match scIB implementation.
55
+
56
+ Additional options enable BRAS compatible usage (see :func:`~scib_metrics.metrics.bras` documentation).
57
+
58
+ Parameters
59
+ ----------
60
+ X
61
+ Array of shape (n_cells, n_features).
62
+ labels
63
+ Array of shape (n_cells,) representing label values
64
+ batch
65
+ Array of shape (n_cells,) representing batch values
66
+ rescale
67
+ Scale asw into the range [0, 1]. If True, higher values are better.
68
+ chunk_size
69
+ Size of chunks to process at a time for distance computation.
70
+ metric
71
+ The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
72
+ between_cluster_distances
73
+ Method for computing inter-cluster distances.
74
+ - 'nearest': Standard silhouette (distance to nearest cluster)
75
+ - 'mean_other': BRAS-specific (mean distance to all other clusters)
76
+ - 'furthest': BRAS-specific (distance to furthest cluster)
77
+
78
+ Returns
79
+ -------
80
+ silhouette score
81
+ """
82
+ sil_dfs = []
83
+ unique_labels = np.unique(labels)
84
+ for group in unique_labels:
85
+ labels_mask = labels == group
86
+ X_subset = X[labels_mask]
87
+ batch_subset = batch[labels_mask]
88
+ n_batches = len(np.unique(batch_subset))
89
+
90
+ if (n_batches == 1) or (n_batches == X_subset.shape[0]):
91
+ continue
92
+
93
+ sil_per_group = silhouette_samples(
94
+ X_subset,
95
+ batch_subset,
96
+ chunk_size=chunk_size,
97
+ metric=metric,
98
+ between_cluster_distances=between_cluster_distances,
99
+ )
100
+
101
+ # take only absolute value
102
+ sil_per_group = np.abs(sil_per_group)
103
+
104
+ if rescale:
105
+ # scale s.t. highest number is optimal
106
+ sil_per_group = 1 - sil_per_group
107
+
108
+ sil_dfs.append(
109
+ pd.DataFrame(
110
+ {
111
+ "group": [group] * len(sil_per_group),
112
+ "silhouette_score": sil_per_group,
113
+ }
114
+ )
115
+ )
116
+
117
+ sil_df = pd.concat(sil_dfs).reset_index(drop=True)
118
+ sil_means = sil_df.groupby("group").mean()
119
+ asw = sil_means["silhouette_score"].mean()
120
+
121
+ return asw
122
+
123
+
124
+ def bras(
125
+ X: np.ndarray,
126
+ labels: np.ndarray,
127
+ batch: np.ndarray,
128
+ chunk_size: int = 256,
129
+ metric: Literal["euclidean", "cosine"] = "cosine",
130
+ between_cluster_distances: Literal["mean_other", "furthest"] = "mean_other",
131
+ ) -> float:
132
+ """Batch removal adapted silhouette (BRAS) for single-cell data integration assessment :cite:p:`rautenstrauch2025`.
133
+
134
+ BRAS evaluates batch effect removal with respect to batch ids within each label (cell type cluster),
135
+ using a modified silhouette score that accounts for nested batch effects. Unlike standard silhouette,
136
+ BRAS computes between-cluster distances using the `between_cluster_distances` method rather than
137
+ nearest-cluster approach. A higher score indicates better batch mixing.
138
+
139
+ Parameters
140
+ ----------
141
+ X
142
+ Array of shape (n_cells, n_features).
143
+ labels
144
+ Array of shape (n_cells,) representing label values
145
+ batch
146
+ Array of shape (n_cells,) representing batch values
147
+ rescale
148
+ Scale asw into the range [0, 1]. If True, higher values are better.
149
+ chunk_size
150
+ Size of chunks to process at a time for distance computation.
151
+ metric
152
+ The distance metric to use. The distance function can be 'cosine' (default) or 'euclidean'.
153
+ between_cluster_distances
154
+ Method for computing inter-cluster distances.
155
+ - 'mean_other': Mean distance to all cells in other clusters (default)
156
+ - 'furthest': Distance to furthest cluster (conservative estimate)
157
+
158
+ Returns
159
+ -------
160
+ BRAS score
161
+ """
162
+ return silhouette_batch(X, labels, batch, True, chunk_size, metric, between_cluster_distances)
@@ -1,3 +1,6 @@
1
+ from functools import partial
2
+ from typing import Literal
3
+
1
4
  import jax
2
5
  import jax.numpy as jnp
3
6
  import numpy as np
@@ -10,10 +13,20 @@ def _euclidean_distance(x: np.array, y: np.array) -> float:
10
13
 
11
14
 
12
15
  @jax.jit
13
- def cdist(x: np.ndarray, y: np.ndarray) -> jnp.ndarray:
16
+ def _cosine_distance(x: np.array, y: np.array) -> float:
17
+ xy = jnp.dot(x, y)
18
+ xx = jnp.dot(x, x)
19
+ yy = jnp.dot(y, y)
20
+ dist = 1.0 - xy / jnp.sqrt(xx * yy)
21
+ # Clip the result to avoid rounding error
22
+ return jnp.clip(dist, 0.0, 2.0)
23
+
24
+
25
+ @partial(jax.jit, static_argnames=["metric"])
26
+ def cdist(x: np.ndarray, y: np.ndarray, metric: Literal["euclidean", "cosine"] = "euclidean") -> jnp.ndarray:
14
27
  """Jax implementation of :func:`scipy.spatial.distance.cdist`.
15
28
 
16
- Uses euclidean distance.
29
+ Uses euclidean distance by default, cosine distance is also available.
17
30
 
18
31
  Parameters
19
32
  ----------
@@ -21,13 +34,20 @@ def cdist(x: np.ndarray, y: np.ndarray) -> jnp.ndarray:
21
34
  Array of shape (n_cells_a, n_features)
22
35
  y
23
36
  Array of shape (n_cells_b, n_features)
37
+ metric
38
+ The distance metric to use. The distance function can be 'euclidean' (default) or 'cosine'.
24
39
 
25
40
  Returns
26
41
  -------
27
42
  dist
28
43
  Array of shape (n_cells_a, n_cells_b)
29
44
  """
30
- return jax.vmap(lambda x1: jax.vmap(lambda y1: _euclidean_distance(x1, y1))(y))(x)
45
+ if metric not in ["euclidean", "cosine"]:
46
+ raise ValueError("Invalid metric choice, must be one of ['euclidean' or 'cosine'].")
47
+ if metric == "cosine":
48
+ return jax.vmap(lambda x1: jax.vmap(lambda y1: _cosine_distance(x1, y1))(y))(x)
49
+ else:
50
+ return jax.vmap(lambda x1: jax.vmap(lambda y1: _euclidean_distance(x1, y1))(y))(x)
31
51
 
32
52
 
33
53
  @jax.jit