multiscoresplot 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- multiscoresplot/__init__.py +28 -0
- multiscoresplot/_colorspace.py +321 -0
- multiscoresplot/_interactive.py +328 -0
- multiscoresplot/_legend.py +284 -0
- multiscoresplot/_plotting.py +266 -0
- multiscoresplot/_scoring.py +99 -0
- multiscoresplot/py.typed +0 -0
- multiscoresplot-1.0.0.dist-info/METADATA +85 -0
- multiscoresplot-1.0.0.dist-info/RECORD +11 -0
- multiscoresplot-1.0.0.dist-info/WHEEL +4 -0
- multiscoresplot-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""multiscoresplot -- multi-dimensional gene set scoring visualization."""
|
|
2
|
+
|
|
3
|
+
from multiscoresplot._colorspace import (
|
|
4
|
+
blend_to_rgb,
|
|
5
|
+
get_component_labels,
|
|
6
|
+
project_direct,
|
|
7
|
+
project_pca,
|
|
8
|
+
reduce_to_rgb,
|
|
9
|
+
register_reducer,
|
|
10
|
+
)
|
|
11
|
+
from multiscoresplot._interactive import plot_embedding_interactive
|
|
12
|
+
from multiscoresplot._legend import render_legend
|
|
13
|
+
from multiscoresplot._plotting import plot_embedding
|
|
14
|
+
from multiscoresplot._scoring import score_gene_sets
|
|
15
|
+
|
|
16
|
+
# Public API re-exported by the package (alphabetical order).
__all__ = [
    "blend_to_rgb", "get_component_labels",
    "plot_embedding", "plot_embedding_interactive",
    "project_direct", "project_pca",
    "reduce_to_rgb", "register_reducer",
    "render_legend", "score_gene_sets",
]
__version__ = "1.0.0"
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""Color space construction and cell projection (pipeline steps 2-3).
|
|
2
|
+
|
|
3
|
+
Provides two projection strategies:
|
|
4
|
+
|
|
5
|
+
* **Direct** (`blend_to_rgb`): multiplicative blending from white using
|
|
6
|
+
explicit base colors. Supports 2-3 gene sets.
|
|
7
|
+
* **Reduction** (`reduce_to_rgb`): dimensionality reduction (PCA, NMF, ICA,
|
|
8
|
+
or custom) to 3 color channels. Works for any number of gene sets (≥ 2).
|
|
9
|
+
|
|
10
|
+
Legacy names ``project_direct`` and ``project_pca`` are kept as thin
|
|
11
|
+
deprecation wrappers.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import warnings
|
|
17
|
+
from typing import TYPE_CHECKING
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from multiscoresplot._scoring import SCORE_PREFIX
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from collections.abc import Callable
|
|
25
|
+
|
|
26
|
+
from numpy.typing import NDArray
|
|
27
|
+
from pandas import DataFrame
|
|
28
|
+
|
|
29
|
+
# Names exported by ``from multiscoresplot._colorspace import *``.
__all__ = [
    "blend_to_rgb", "get_component_labels", "project_direct",
    "project_pca", "reduce_to_rgb", "register_reducer",
]
|
|
37
|
+
|
|
38
|
+
# ---- default color palettes ------------------------------------------------

# Two gene sets: first set blue, second set red.
DEFAULT_COLORS_2: list[tuple[float, float, float]] = [(0.0, 0.0, 1.0), (1.0, 0.0, 0.0)]

# Three gene sets: one pure primary per set, in R, G, B order.
DEFAULT_COLORS_3: list[tuple[float, float, float]] = [
    (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0),
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ---- private helpers --------------------------------------------------------
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _validate_score_columns(scores: DataFrame, prefix: str = SCORE_PREFIX) -> list[str]:
    """Return the names of the columns of *scores* that start with *prefix*.

    Parameters
    ----------
    scores
        DataFrame whose columns are searched.
    prefix
        Column-name prefix identifying score columns (defaults to ``SCORE_PREFIX``).

    Returns
    -------
    list[str]
        Matching column names, in DataFrame column order.

    Raises
    ------
    ValueError
        If no column starts with *prefix*.
    """
    # Non-string labels (e.g. integer column labels) cannot be score columns;
    # skip them instead of crashing with AttributeError on ``.startswith``.
    cols = [c for c in scores.columns if isinstance(c, str) and c.startswith(prefix)]
    if not cols:
        msg = (
            "No score columns found. Expected columns starting with "
            f"'{prefix}'. Run score_gene_sets() first."
        )
        raise ValueError(msg)
    return cols
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _multiplicative_blend(
|
|
68
|
+
score_matrix: NDArray,
|
|
69
|
+
colors: list[tuple[float, float, float]],
|
|
70
|
+
) -> NDArray:
|
|
71
|
+
"""Blend gene set scores into RGB via multiplicative gradients from white.
|
|
72
|
+
|
|
73
|
+
For each gene set *i* with base color ``c_i`` and score ``s_i``:
|
|
74
|
+
``gradient_i = 1 - s_i * (1 - c_i)``
|
|
75
|
+
|
|
76
|
+
The final colour is the element-wise product of all gradients.
|
|
77
|
+
"""
|
|
78
|
+
n_cells = score_matrix.shape[0]
|
|
79
|
+
rgb = np.ones((n_cells, 3), dtype=np.float64)
|
|
80
|
+
|
|
81
|
+
for i, color in enumerate(colors):
|
|
82
|
+
c = np.asarray(color, dtype=np.float64) # (3,)
|
|
83
|
+
s = score_matrix[:, i : i + 1] # (n_cells, 1)
|
|
84
|
+
gradient = 1.0 - s * (1.0 - c) # (n_cells, 3)
|
|
85
|
+
rgb *= gradient
|
|
86
|
+
|
|
87
|
+
return np.clip(rgb, 0.0, 1.0)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _minmax_normalize(X: NDArray, n_target: int = 3) -> NDArray:
|
|
91
|
+
"""Min-max normalize each column to [0, 1], zero-pad to *n_target* columns."""
|
|
92
|
+
k = X.shape[1]
|
|
93
|
+
for j in range(k):
|
|
94
|
+
col = X[:, j]
|
|
95
|
+
lo, hi = col.min(), col.max()
|
|
96
|
+
if hi - lo > 0:
|
|
97
|
+
X[:, j] = (col - lo) / (hi - lo)
|
|
98
|
+
else:
|
|
99
|
+
X[:, j] = 0.0
|
|
100
|
+
|
|
101
|
+
if k < n_target:
|
|
102
|
+
pad = np.zeros((X.shape[0], n_target - k), dtype=np.float64)
|
|
103
|
+
X = np.hstack([X, pad])
|
|
104
|
+
|
|
105
|
+
return X
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ---- reducer registry ------------------------------------------------------
|
|
109
|
+
|
|
110
|
+
ReducerFn = type(lambda: None) # placeholder for type alias
|
|
111
|
+
|
|
112
|
+
_REDUCERS: dict[str, Callable[..., NDArray]] = {}
|
|
113
|
+
_COMPONENT_PREFIXES: dict[str, str] = {}
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def register_reducer(
|
|
117
|
+
name: str,
|
|
118
|
+
fn: Callable[..., NDArray],
|
|
119
|
+
*,
|
|
120
|
+
component_prefix: str | None = None,
|
|
121
|
+
) -> None:
|
|
122
|
+
"""Register a dimensionality reduction method for use with ``reduce_to_rgb``.
|
|
123
|
+
|
|
124
|
+
Parameters
|
|
125
|
+
----------
|
|
126
|
+
name
|
|
127
|
+
Short identifier (e.g. ``"pca"``, ``"nmf"``).
|
|
128
|
+
fn
|
|
129
|
+
Callable with signature ``(X, n_components, **kwargs) -> NDArray``
|
|
130
|
+
returning an ``(n_cells, 3)`` array with values in [0, 1].
|
|
131
|
+
component_prefix
|
|
132
|
+
Label prefix for legend axes (e.g. ``"PC"`` → PC1, PC2, PC3).
|
|
133
|
+
"""
|
|
134
|
+
_REDUCERS[name] = fn
|
|
135
|
+
if component_prefix is not None:
|
|
136
|
+
_COMPONENT_PREFIXES[name] = component_prefix
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def get_component_labels(method: str) -> list[str]:
|
|
140
|
+
"""Return ``["<prefix>1", "<prefix>2", "<prefix>3"]`` for a registered method."""
|
|
141
|
+
prefix = _COMPONENT_PREFIXES.get(method, "C")
|
|
142
|
+
return [f"{prefix}{i + 1}" for i in range(3)]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# ---- built-in reducer implementations --------------------------------------
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _reduce_pca(X: NDArray, n_components: int, **kwargs: object) -> NDArray:
    """PCA via numpy SVD.

    Parameters
    ----------
    X
        ``(n_cells, n_sets)`` score matrix.
    n_components
        Number of principal components to keep (further capped by the rank
        bound ``min(n_cells, n_sets)``).
    **kwargs
        Accepted for interface compatibility with other reducers; ignored.

    Returns
    -------
    numpy.ndarray
        ``(n_cells, 3)`` array of min-max normalized component scores,
        zero-padded when fewer than 3 components are available. All zeros when
        *X* has no variance.
    """
    mean = X.mean(axis=0)
    X_centered = X - mean

    if np.allclose(X_centered, 0.0):
        return np.zeros((X.shape[0], 3), dtype=np.float64)

    U, S, _ = np.linalg.svd(X_centered, full_matrices=False)
    k = min(n_components, U.shape[1])
    U_k = U[:, :k]

    # Singular-vector signs are arbitrary and may differ between LAPACK builds,
    # which would invert colour channels. Make the largest-magnitude entry of
    # each left singular vector positive (same convention as scikit-learn's
    # ``svd_flip`` with u-based decision) so colours are reproducible.
    pivot_rows = np.argmax(np.abs(U_k), axis=0)
    signs = np.sign(U_k[pivot_rows, np.arange(k)])
    signs[signs == 0] = 1.0

    pc_scores = U_k * (S[:k] * signs)
    return _minmax_normalize(pc_scores, n_target=3)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _reduce_nmf(X: NDArray, n_components: int, **kwargs: object) -> NDArray:
    """Non-negative matrix factorization of *X* using scikit-learn's ``NMF``.

    ``init="nndsvda"`` and ``max_iter=300`` are used unless overridden by
    *kwargs*, which are forwarded to ``NMF``. The per-cell weights ``W`` are
    min-max normalized and zero-padded to 3 columns. If every row of *X* equals
    the column means, an all-zero ``(n_cells, 3)`` array is returned instead.
    """
    from sklearn.decomposition import NMF

    # Nothing to factorize when all rows equal the column means.
    if np.allclose(X, X.mean(axis=0)):
        return np.zeros((X.shape[0], 3), dtype=np.float64)

    # Caller-supplied kwargs take precedence over the defaults.
    params: dict[str, object] = {"init": "nndsvda", "max_iter": 300, **kwargs}
    n_keep = min(n_components, X.shape[1])

    weights = NMF(n_components=n_keep, **params).fit_transform(X)  # type: ignore[arg-type]
    return _minmax_normalize(weights, n_target=3)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _reduce_ica(X: NDArray, n_components: int, **kwargs: object) -> NDArray:
    """ICA via scikit-learn FastICA.

    Parameters
    ----------
    X
        ``(n_cells, n_sets)`` score matrix.
    n_components
        Number of independent components (capped at ``n_sets``).
    **kwargs
        Forwarded to ``FastICA``; they override the defaults
        ``max_iter=300``, ``tol=1e-4`` and ``random_state=0``.

    Returns
    -------
    numpy.ndarray
        ``(n_cells, 3)`` array of min-max normalized sources, zero-padded when
        fewer than 3 components. All zeros when every row equals the column means.
    """
    from sklearn.decomposition import FastICA

    if np.allclose(X, X.mean(axis=0)):
        return np.zeros((X.shape[0], 3), dtype=np.float64)

    k = min(n_components, X.shape[1])
    # FastICA starts from a random unmixing matrix; without a fixed seed the
    # recovered components (and therefore the colours) change on every call.
    # A caller can still pass random_state explicitly via kwargs.
    defaults: dict[str, object] = {"max_iter": 300, "tol": 1e-4, "random_state": 0}
    defaults.update(kwargs)
    model = FastICA(n_components=k, **defaults)  # type: ignore[arg-type]
    S = model.fit_transform(X)
    return _minmax_normalize(S, n_target=3)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# Register the built-in reducers with their legend-axis label prefixes
# (e.g. "pca" -> PC1, PC2, PC3).
register_reducer(name="pca", fn=_reduce_pca, component_prefix="PC")
register_reducer(name="nmf", fn=_reduce_nmf, component_prefix="NMF")
register_reducer(name="ica", fn=_reduce_ica, component_prefix="IC")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# ---- public API -------------------------------------------------------------
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def blend_to_rgb(
    scores: DataFrame,
    *,
    colors: list[tuple[float, float, float]] | None = None,
) -> NDArray:
    """Map gene set scores to RGB via multiplicative blending from white.

    Parameters
    ----------
    scores
        DataFrame returned by :func:`score_gene_sets`. Only columns whose
        names start with ``score-`` are used.
    colors
        One ``(R, G, B)`` tuple per gene set. If *None*, defaults are chosen
        based on the number of gene sets (2 → blue/red, 3 → RGB).

    Returns
    -------
    numpy.ndarray
        ``(n_cells, 3)`` RGB array with values in [0, 1].

    Raises
    ------
    ValueError
        If fewer than 2 or more than 3 gene sets are present, if the number
        of supplied colours does not match the number of gene sets, or if a
        colour is not an ``(R, G, B)`` triple.
    """
    score_cols = _validate_score_columns(scores)
    n_sets = len(score_cols)

    if n_sets < 2:
        raise ValueError("At least 2 gene sets are required.")
    if n_sets > 3:
        raise ValueError(
            f"Direct projection supports at most 3 gene sets (got {n_sets}). "
            "Use reduce_to_rgb() for higher dimensions."
        )

    mat = scores[score_cols].to_numpy(dtype=np.float64)

    if colors is None:
        colors = DEFAULT_COLORS_2 if n_sets == 2 else DEFAULT_COLORS_3
    if len(colors) != n_sets:
        raise ValueError(f"Expected {n_sets} colors for {n_sets} gene sets, got {len(colors)}.")
    # Each colour must be exactly three channels. Otherwise the blend would fail
    # with an opaque broadcasting error (RGBA tuples) or silently broadcast a
    # scalar or 1-tuple across all channels.
    for idx, color in enumerate(colors):
        if np.shape(color) != (3,):
            raise ValueError(f"Color {idx} must be an (R, G, B) triple, got {color!r}.")
    return _multiplicative_blend(mat, colors)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def reduce_to_rgb(
    scores: DataFrame,
    *,
    method: str = "pca",
    n_components: int = 3,
    **kwargs: object,
) -> NDArray:
    """Map gene set scores to RGB via dimensionality reduction.

    Parameters
    ----------
    scores
        DataFrame returned by :func:`score_gene_sets`.
    method
        Reduction method: ``"pca"`` (default), ``"nmf"``, ``"ica"``, or any
        method registered via :func:`register_reducer`.
    n_components
        Number of components to retain (at least 1; values above 3 are
        capped at 3 for RGB).
    **kwargs
        Extra keyword arguments forwarded to the reducer function.

    Returns
    -------
    numpy.ndarray
        ``(n_cells, 3)`` float64 RGB array with values in [0, 1].

    Raises
    ------
    ValueError
        If fewer than 2 gene sets are present, *method* is unknown,
        *n_components* is below 1, or the reducer returns an array that is
        not ``(n_cells, 3)``.
    """
    if method not in _REDUCERS:
        available = ", ".join(sorted(_REDUCERS))
        raise ValueError(f"Unknown reduction method '{method}'. Available: {available}.")
    # Zero or negative components would silently give an all-black image (PCA)
    # or an opaque error from the underlying library.
    if n_components < 1:
        raise ValueError(f"n_components must be at least 1 (got {n_components}).")

    score_cols = _validate_score_columns(scores)
    if len(score_cols) < 2:
        raise ValueError("At least 2 gene sets are required.")

    mat = scores[score_cols].to_numpy(dtype=np.float64)
    k = min(n_components, 3)
    rgb = np.asarray(_REDUCERS[method](mat, k, **kwargs), dtype=np.float64)

    # Enforce the register_reducer() contract so a faulty custom reducer fails
    # here with a clear message rather than later during plotting.
    expected = (mat.shape[0], 3)
    if rgb.shape != expected:
        raise ValueError(
            f"Reducer '{method}' returned an array of shape {rgb.shape}; expected {expected}."
        )
    return rgb
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
# ---- deprecated wrappers ---------------------------------------------------
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def project_direct(
    scores: DataFrame,
    *,
    colors: list[tuple[float, float, float]] | None = None,
) -> NDArray:
    """Deprecated alias of :func:`blend_to_rgb`.

    Emits a ``DeprecationWarning`` attributed to the caller, then delegates
    to ``blend_to_rgb`` with the same arguments.
    """
    message = "project_direct() is deprecated, use blend_to_rgb() instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return blend_to_rgb(scores, colors=colors)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def project_pca(
    scores: DataFrame,
    *,
    n_components: int = 3,
) -> NDArray:
    """Deprecated alias of :func:`reduce_to_rgb` with ``method="pca"``.

    Emits a ``DeprecationWarning`` attributed to the caller, then delegates
    to ``reduce_to_rgb``.
    """
    message = "project_pca() is deprecated, use reduce_to_rgb() instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return reduce_to_rgb(scores, method="pca", n_components=n_components)
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
"""Interactive Plotly-based embedding plots (optional dependency).
|
|
2
|
+
|
|
3
|
+
Provides ``plot_embedding_interactive`` which renders a WebGL-accelerated
|
|
4
|
+
scatter plot of cells in embedding space, coloured by projected RGB values,
|
|
5
|
+
with rich hover information.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import base64
|
|
11
|
+
import io
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from multiscoresplot._colorspace import get_component_labels
|
|
17
|
+
from multiscoresplot._legend import render_legend
|
|
18
|
+
from multiscoresplot._plotting import _extract_coords, _validate_rgb
|
|
19
|
+
from multiscoresplot._scoring import SCORE_PREFIX
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from numpy.typing import NDArray
|
|
23
|
+
from pandas import DataFrame
|
|
24
|
+
|
|
25
|
+
__all__ = ["plot_embedding_interactive"]

# ---------------------------------------------------------------------------
# Legend position lookup (mirrors _INSET_BOUNDS in _plotting.py)
# ---------------------------------------------------------------------------

# Paper-coordinate anchor for each corner, 0.02 in from both edges. Keys are
# generated in the order: lower right, lower left, upper right, upper left.
_PLOTLY_LEGEND_POS: dict[str, dict[str, str | float]] = {
    f"{vertical} {horizontal}": {"x": x, "y": y, "xanchor": horizontal, "yanchor": y_anchor}
    for vertical, y, y_anchor in (("lower", 0.02, "bottom"), ("upper", 0.98, "top"))
    for horizontal, x in (("right", 0.98), ("left", 0.02))
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
# Legend helpers
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _render_legend_to_base64(
    method: str,
    *,
    gene_set_names: list[str] | None = None,
    colors: list[tuple[float, float, float]] | None = None,
    component_labels: list[str] | None = None,
    resolution: int = 128,
) -> str:
    """Render the legend to a base64-encoded PNG data URI.

    Parameters
    ----------
    method
        Projection method passed through to ``render_legend``.
    gene_set_names, colors, component_labels, resolution
        Forwarded unchanged to ``render_legend``.

    Returns
    -------
    str
        ``"data:image/png;base64,..."`` string for a transparent 2x2-inch,
        150-dpi PNG cropped with ``bbox_inches="tight"``.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    # Built through the object-oriented API, not pyplot. The figure is never
    # registered with pyplot's figure manager, so no ``plt.close`` is needed
    # (and pyplot does not have to be importable here).
    fig = Figure(figsize=(2, 2), dpi=150)
    FigureCanvasAgg(fig)  # attach an Agg canvas for rasterizing the figure
    ax = fig.add_subplot(111)

    render_legend(
        ax,
        method,
        gene_set_names=gene_set_names,
        colors=colors,
        component_labels=component_labels,
        resolution=resolution,
    )

    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", transparent=True, bbox_inches="tight", dpi=150)
        png_bytes = buf.getvalue()
    b64 = base64.b64encode(png_bytes).decode("ascii")
    return f"data:image/png;base64,{b64}"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _add_plotly_legend(
    fig: object,
    *,
    method: str,
    gene_set_names: list[str] | None = None,
    colors: list[tuple[float, float, float]] | None = None,
    legend_loc: str = "lower right",
    legend_size: float = 0.30,
    legend_resolution: int = 128,
) -> None:
    """Overlay a matplotlib-rendered colour-space legend on a Plotly figure.

    The legend is rasterized to a PNG data URI and added with
    ``fig.add_layout_image`` in paper coordinates. An unrecognized
    *legend_loc* falls back to the lower-right corner.
    """
    # Direct mode needs no component labels; every other method uses
    # its registered component prefix (e.g. PC1..PC3).
    component_labels = None if method == "direct" else get_component_labels(method)

    image_uri = _render_legend_to_base64(
        method,
        gene_set_names=gene_set_names,
        colors=colors,
        component_labels=component_labels,
        resolution=legend_resolution,
    )

    # Supplies the x, y, xanchor and yanchor keyword arguments.
    anchor = _PLOTLY_LEGEND_POS.get(legend_loc, _PLOTLY_LEGEND_POS["lower right"])

    fig.add_layout_image(  # type: ignore[attr-defined]
        source=image_uri,
        xref="paper",
        yref="paper",
        sizex=legend_size,
        sizey=legend_size,
        sizing="contain",
        layer="above",
        **anchor,
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _ensure_plotly():
|
|
120
|
+
"""Lazy-import plotly, raising a helpful error if not installed."""
|
|
121
|
+
try:
|
|
122
|
+
import plotly.graph_objects as go
|
|
123
|
+
except ImportError:
|
|
124
|
+
raise ImportError(
|
|
125
|
+
"plotly is required for interactive plotting. "
|
|
126
|
+
"Install it with: pip install 'multiscoresplot[interactive]'"
|
|
127
|
+
) from None
|
|
128
|
+
return go
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _score_hover_fields(
    score_df: DataFrame,
    gene_set_names: list[str] | None,
    n_cells: int,
) -> list[list[str]]:
    """Return one list of per-cell hover strings for each ``score-*`` column."""
    score_cols = [
        c for c in score_df.columns if isinstance(c, str) and c.startswith(SCORE_PREFIX)
    ]
    # Scores are matched to cells by position, so the row counts must agree.
    if score_cols and len(score_df) != n_cells:
        msg = f"scores has {len(score_df)} rows but the embedding has {n_cells} cells."
        raise ValueError(msg)
    # Use the caller's names only when they line up one-to-one with the columns.
    if gene_set_names is not None and len(gene_set_names) == len(score_cols):
        labels = list(gene_set_names)
    else:
        labels = [c[len(SCORE_PREFIX) :] for c in score_cols]
    values = score_df[score_cols].to_numpy(dtype=np.float64)
    return [
        [f"{label}: {v:.3f}" for v in values[:, j].tolist()]
        for j, label in enumerate(labels)
    ]


def _channel_hover_fields(rgb: NDArray, method: str | None) -> list[list[str]]:
    """Return per-cell hover strings for the three colour channels."""
    if method is not None and method != "direct":
        channel_labels = get_component_labels(method)
    else:
        channel_labels = ["R", "G", "B"]
    return [
        [f"{ch_label}: {v:.2f}" for v in rgb[:, j].tolist()]
        for j, ch_label in enumerate(channel_labels)
    ]


def _obs_hover_fields(
    obs: DataFrame,
    hover_columns: list[str],
    n_cells: int,
) -> list[list[str]]:
    """Return per-cell hover strings for extra ``adata.obs`` columns."""
    import pandas as _pd

    missing = [c for c in hover_columns if c not in obs.columns]
    if missing:
        raise KeyError(f"Columns not found in adata.obs: {missing}")

    fields: list[list[str]] = []
    for col_name in hover_columns:
        values = obs[col_name].iloc[:n_cells]
        # Only floating-point columns get 3-decimal formatting. Integers,
        # booleans, categoricals and strings are shown verbatim; applying ".3f"
        # to every numeric dtype rendered counts as "42.000" and flags as "1.000".
        if _pd.api.types.is_float_dtype(values):
            fields.append([f"{col_name}: {v:.3f}" for v in values.tolist()])
        else:
            fields.append([f"{col_name}: {v}" for v in values.tolist()])
    return fields


def _rgb_to_plotly_colors(rgb: NDArray, alpha: float) -> list[str]:
    """Convert an ``(n, 3)`` RGB array in [0, 1] to Plotly ``rgba(...)`` strings."""
    # Round to the nearest 8-bit level instead of truncating
    # (int(0.999 * 255) would give 254 rather than 255).
    levels = np.clip(np.rint(np.asarray(rgb, dtype=np.float64) * 255.0), 0.0, 255.0)
    return [f"rgba({int(r)},{int(g)},{int(b)},{alpha})" for r, g, b in levels.tolist()]


def plot_embedding_interactive(
    adata_or_coords: object,
    rgb: NDArray,
    *,
    basis: str | None = None,
    components: tuple[int, int] = (0, 1),
    scores: DataFrame | None = None,
    method: str | None = None,
    gene_set_names: list[str] | None = None,
    # legend
    legend: bool = True,
    legend_loc: str = "lower right",
    legend_size: float = 0.30,
    legend_resolution: int = 128,
    colors: list[tuple[float, float, float]] | None = None,
    # hover / scatter
    hover_columns: list[str] | None = None,
    point_size: float = 2,
    alpha: float = 1.0,
    width: int = 500,
    height: int = 450,
    title: str = "",
    show: bool = True,
) -> object | None:
    """Interactive Plotly scatter plot of embedding coordinates coloured by RGB.

    Parameters
    ----------
    adata_or_coords
        An ``AnnData`` object (with *basis* in ``.obsm``) or a raw
        ``(n_cells, 2)`` coordinate array.
    rgb
        ``(n_cells, 3)`` RGB array from ``blend_to_rgb`` or ``reduce_to_rgb``.
    basis
        Embedding key (e.g. ``"umap"``, ``"pca"``). Required when
        *adata_or_coords* is AnnData.
    components
        Which two components to plot (0-indexed).
    scores
        DataFrame with ``score-*`` columns, matched to cells by row position.
        If *None* and *adata_or_coords* is AnnData, scores are auto-extracted
        from ``adata.obs``.
    method
        Reduction method (``"pca"``, ``"nmf"``, etc.) used to derive RGB.
        Controls the channel labels in hover info. If *None* or ``"direct"``,
        channels are labeled R/G/B.
    gene_set_names
        Human-readable labels for gene set scores in hover info (used only
        when their count matches the number of score columns).
    legend
        Whether to add a colour-space legend overlay.
    legend_loc
        Position for the legend (``"lower right"``, ``"lower left"``,
        ``"upper right"``, ``"upper left"``).
    legend_size
        Size of the legend as a fraction of the plot (0-1).
    legend_resolution
        Pixel resolution of the legend image.
    colors
        Base colours for direct-mode legends.
    hover_columns
        Extra columns from ``adata.obs`` to include in hover info. Float
        columns are shown with 3 decimals; other dtypes are shown verbatim.
    point_size
        Scatter marker size.
    alpha
        Marker opacity.
    width
        Figure width in pixels.
    height
        Figure height in pixels.
    title
        Plot title.
    show
        If *True*, call ``fig.show()`` and return *None*. If *False*,
        return the ``plotly.graph_objects.Figure``.

    Returns
    -------
    Figure or None
        The figure when ``show=False``; *None* when ``show=True``.

    Raises
    ------
    ImportError
        If plotly is not installed.
    ValueError
        If *hover_columns* is given with raw coordinates, or if *scores* has
        a different number of rows than there are cells.
    KeyError
        If a *hover_columns* entry is missing from ``adata.obs``.
    """
    go = _ensure_plotly()

    coords, basis_label = _extract_coords(adata_or_coords, basis, components)
    n_cells = coords.shape[0]
    rgb = _validate_rgb(rgb, n_cells)

    # AnnData-like inputs expose per-cell metadata through ``.obs``.
    has_obs = hasattr(adata_or_coords, "obs")

    # --- Build hover text: one list of per-cell strings per field ---
    hover_fields: list[list[str]] = []

    # 1. Gene set scores (explicit DataFrame, else auto-extracted from .obs)
    score_df: DataFrame | None = scores
    if score_df is None and has_obs:
        obs = adata_or_coords.obs  # type: ignore[attr-defined]
        obs_score_cols = [
            c for c in obs.columns if isinstance(c, str) and c.startswith(SCORE_PREFIX)
        ]
        if obs_score_cols:
            score_df = obs[obs_score_cols]
    if score_df is not None:
        hover_fields.extend(_score_hover_fields(score_df, gene_set_names, n_cells))

    # 2. RGB channel values
    hover_fields.extend(_channel_hover_fields(rgb, method))

    # 3. Extra .obs columns
    if hover_columns is not None:
        if not has_obs:
            raise ValueError("hover_columns requires an AnnData object, not raw coordinates.")
        obs = adata_or_coords.obs  # type: ignore[attr-defined]
        hover_fields.extend(_obs_hover_fields(obs, hover_columns, n_cells))

    # Transpose the field-major lists into one "<br>"-joined string per cell.
    hover_text = ["<br>".join(parts) for parts in zip(*hover_fields)]

    # --- Build color strings ---
    marker_colors = _rgb_to_plotly_colors(rgb, alpha)

    # --- Axis labels ---
    if basis_label is not None:
        xaxis_title = f"{basis_label}{components[0] + 1}"
        yaxis_title = f"{basis_label}{components[1] + 1}"
    else:
        xaxis_title = ""
        yaxis_title = ""

    # --- Create figure ---
    fig = go.Figure(
        data=go.Scattergl(
            x=coords[:, 0],
            y=coords[:, 1],
            mode="markers",
            marker=dict(
                size=point_size,
                color=marker_colors,
            ),
            hovertext=hover_text,
            hoverinfo="text",
        ),
    )

    fig.update_layout(
        width=width,
        height=height,
        title=title,
        xaxis=dict(title=xaxis_title, scaleanchor="y"),
        yaxis=dict(title=yaxis_title),
        plot_bgcolor="white",
    )

    # Legend (skip silently when method is None, or direct mode without gene_set_names)
    if legend and method is not None and (method != "direct" or gene_set_names is not None):
        _add_plotly_legend(
            fig,
            method=method,
            gene_set_names=gene_set_names,
            colors=colors,
            legend_loc=legend_loc,
            legend_size=legend_size,
            legend_resolution=legend_resolution,
        )

    if show:
        fig.show()
        return None

    return fig  # type: ignore[no-any-return]
|