multiscoresplot 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ """multiscoresplot -- multi-dimensional gene set scoring visualization."""
2
+
3
+ from multiscoresplot._colorspace import (
4
+ blend_to_rgb,
5
+ get_component_labels,
6
+ project_direct,
7
+ project_pca,
8
+ reduce_to_rgb,
9
+ register_reducer,
10
+ )
11
+ from multiscoresplot._interactive import plot_embedding_interactive
12
+ from multiscoresplot._legend import render_legend
13
+ from multiscoresplot._plotting import plot_embedding
14
+ from multiscoresplot._scoring import score_gene_sets
15
+
16
# Names re-exported as the package's public API.
__all__ = [
    "blend_to_rgb",
    "get_component_labels",
    "plot_embedding",
    "plot_embedding_interactive",
    "project_direct",
    "project_pca",
    "reduce_to_rgb",
    "register_reducer",
    "render_legend",
    "score_gene_sets",
]

# Version string of this release.
__version__ = "1.0.0"
@@ -0,0 +1,321 @@
1
+ """Color space construction and cell projection (pipeline steps 2-3).
2
+
3
+ Provides two projection strategies:
4
+
5
+ * **Direct** (`blend_to_rgb`): multiplicative blending from white using
6
+ explicit base colors. Supports 2-3 gene sets.
7
+ * **Reduction** (`reduce_to_rgb`): dimensionality reduction (PCA, NMF, ICA,
8
+ or custom) to 3 color channels. Works for any number of gene sets (≥ 2).
9
+
10
+ Legacy names ``project_direct`` and ``project_pca`` are kept as thin
11
+ deprecation wrappers.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import warnings
17
+ from typing import TYPE_CHECKING
18
+
19
+ import numpy as np
20
+
21
+ from multiscoresplot._scoring import SCORE_PREFIX
22
+
23
+ if TYPE_CHECKING:
24
+ from collections.abc import Callable
25
+
26
+ from numpy.typing import NDArray
27
+ from pandas import DataFrame
28
+
29
# Public names exported by this module.
__all__ = [
    "blend_to_rgb",
    "get_component_labels",
    "project_direct",
    "project_pca",
    "reduce_to_rgb",
    "register_reducer",
]
37
+
38
+ # ---- default color palettes ------------------------------------------------
39
+
40
# Base colors used when exactly two gene sets are blended: blue, then red.
DEFAULT_COLORS_2: list[tuple[float, float, float]] = [
    (0.0, 0.0, 1.0),
    (1.0, 0.0, 0.0),
]

# Base colors used when three gene sets are blended: red, green, blue.
DEFAULT_COLORS_3: list[tuple[float, float, float]] = [
    (1.0, 0.0, 0.0),
    (0.0, 1.0, 0.0),
    (0.0, 0.0, 1.0),
]
50
+
51
+
52
+ # ---- private helpers --------------------------------------------------------
53
+
54
+
55
def _validate_score_columns(scores: DataFrame, prefix: str = SCORE_PREFIX) -> list[str]:
    """Collect the score columns of *scores*.

    Parameters
    ----------
    scores
        Table whose column names are inspected.
    prefix
        Column-name prefix that marks a score column.

    Returns
    -------
    list of str
        Column names starting with *prefix*, in their original order.

    Raises
    ------
    ValueError
        If no column name starts with *prefix*.
    """
    matching = [name for name in scores.columns if name.startswith(prefix)]
    if matching:
        return matching

    raise ValueError(
        "No score columns found. Expected columns starting with "
        f"'{prefix}'. Run score_gene_sets() first."
    )
65
+
66
+
67
def _multiplicative_blend(
    score_matrix: NDArray,
    colors: list[tuple[float, float, float]],
) -> NDArray:
    """Turn per-gene-set scores into RGB by multiplying gradients that start at white.

    Gene set *i*, with base color ``c_i`` and score ``s_i``, contributes the
    factor ``1 - s_i * (1 - c_i)``. The factors of all gene sets are
    multiplied element-wise and the product is clipped to [0, 1].

    Parameters
    ----------
    score_matrix
        2-D array; column *i* holds the scores paired with ``colors[i]``.
    colors
        One ``(R, G, B)`` base color per gene set.

    Returns
    -------
    numpy.ndarray
        ``(n_rows, 3)`` float64 array of blended colors.
    """
    blended = np.ones((score_matrix.shape[0], 3), dtype=np.float64)

    for idx, base in enumerate(colors):
        # Distance of the base color from white, per channel: shape (3,)
        fade = 1.0 - np.asarray(base, dtype=np.float64)
        # Keep the column 2-D so it broadcasts against the three channels
        weights = score_matrix[:, idx : idx + 1]
        blended *= 1.0 - weights * fade

    return np.clip(blended, 0.0, 1.0)
88
+
89
+
90
def _minmax_normalize(X: NDArray, n_target: int = 3) -> NDArray:
    """Min-max normalize each column to [0, 1], zero-pad to *n_target* columns.

    The input is never modified: the work is done on a float64 copy, so
    integer inputs are scaled correctly instead of being truncated by an
    in-place write into an integer array.

    Parameters
    ----------
    X
        2-D array of shape ``(n_rows, k)``.
    n_target
        Minimum number of output columns. When ``k < n_target`` the result
        is padded on the right with zero columns; when ``k >= n_target`` all
        ``k`` columns are returned.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(n_rows, max(k, n_target))``. Columns with
        no positive range (constant, or containing NaN) are all zeros.
    """
    data = np.asarray(X, dtype=np.float64)
    n_rows, k = data.shape
    out = np.zeros((n_rows, max(k, n_target)), dtype=np.float64)

    # min()/max() raise on zero-size input; nothing to scale in that case.
    if n_rows == 0 or k == 0:
        return out

    lo = data.min(axis=0)
    span = data.max(axis=0) - lo
    # Only columns with a strictly positive range are rescaled; the rest stay 0.
    cols = np.flatnonzero(span > 0)
    out[:, cols] = (data[:, cols] - lo[cols]) / span[cols]
    return out
106
+
107
+
108
+ # ---- reducer registry ------------------------------------------------------
109
+
110
# Placeholder for a reducer type alias; at runtime this is simply the type of
# a plain Python function.
ReducerFn = type(lambda: None)

# Maps a reducer name to the callable that implements it.
_REDUCERS: dict[str, Callable[..., NDArray]] = {}

# Maps a reducer name to the prefix used when labelling its components.
_COMPONENT_PREFIXES: dict[str, str] = {}
114
+
115
+
116
def register_reducer(
    name: str,
    fn: Callable[..., NDArray],
    *,
    component_prefix: str | None = None,
) -> None:
    """Make a dimensionality reduction method available to ``reduce_to_rgb``.

    Parameters
    ----------
    name
        Short identifier (e.g. ``"pca"``, ``"nmf"``). Registering an
        existing name replaces its callable.
    fn
        Callable with signature ``(X, n_components, **kwargs) -> NDArray``
        returning an ``(n_cells, 3)`` array with values in [0, 1].
    component_prefix
        Label prefix for legend axes (e.g. ``"PC"`` → PC1, PC2, PC3).
        When *None*, no prefix is stored and any prefix already stored for
        *name* is left as it is.
    """
    _REDUCERS[name] = fn

    if component_prefix is None:
        return
    _COMPONENT_PREFIXES[name] = component_prefix
137
+
138
+
139
def get_component_labels(method: str) -> list[str]:
    """Build the three component labels ``<prefix>1`` .. ``<prefix>3`` for *method*.

    The prefix is the one stored for *method* by ``register_reducer``; the
    generic prefix ``"C"`` is used when none is stored.
    """
    stem = _COMPONENT_PREFIXES.get(method, "C")
    return [f"{stem}{number}" for number in (1, 2, 3)]
143
+
144
+
145
+ # ---- built-in reducer implementations --------------------------------------
146
+
147
+
148
def _reduce_pca(X: NDArray, n_components: int, **kwargs: object) -> NDArray:
    """Project *X* onto its leading principal components using numpy's SVD.

    Parameters
    ----------
    X
        2-D score matrix (rows are observations).
    n_components
        Upper bound on the number of components kept.
    **kwargs
        Accepted for signature compatibility with other reducers; unused.

    Returns
    -------
    numpy.ndarray
        Component scores passed through ``_minmax_normalize`` with
        ``n_target=3``, or an all-zero ``(n_rows, 3)`` array when *X* has no
        variation after centering.
    """
    centered = X - X.mean(axis=0)

    # Nothing to decompose when every row equals the column means.
    if np.allclose(centered, 0.0):
        return np.zeros((X.shape[0], 3), dtype=np.float64)

    left, singular, _ = np.linalg.svd(centered, full_matrices=False)
    keep = min(n_components, left.shape[1])
    # Scaling the left singular vectors by the singular values gives the PC scores.
    projected = left[:, :keep] * singular[:keep]
    return _minmax_normalize(projected, n_target=3)
160
+
161
+
162
def _reduce_nmf(X: NDArray, n_components: int, **kwargs: object) -> NDArray:
    """Factorize *X* with scikit-learn's NMF and return the normalized factors.

    Parameters
    ----------
    X
        2-D score matrix (rows are observations).
    n_components
        Upper bound on the number of components; capped at the number of
        columns of *X*.
    **kwargs
        Extra options for ``sklearn.decomposition.NMF``. They override the
        defaults ``init="nndsvda"`` and ``max_iter=300``.

    Returns
    -------
    numpy.ndarray
        The fitted factor matrix passed through ``_minmax_normalize`` with
        ``n_target=3``, or an all-zero ``(n_rows, 3)`` array when every row
        of *X* equals the column means.
    """
    from sklearn.decomposition import NMF

    if np.allclose(X, X.mean(axis=0)):
        return np.zeros((X.shape[0], 3), dtype=np.float64)

    # Caller-supplied options win over the built-in defaults.
    options: dict[str, object] = {"init": "nndsvda", "max_iter": 300, **kwargs}
    rank = min(n_components, X.shape[1])
    factors = NMF(n_components=rank, **options).fit_transform(X)  # type: ignore[arg-type]
    return _minmax_normalize(factors, n_target=3)
175
+
176
+
177
def _reduce_ica(X: NDArray, n_components: int, **kwargs: object) -> NDArray:
    """Decompose *X* with scikit-learn's FastICA and return the normalized sources.

    Parameters
    ----------
    X
        2-D score matrix (rows are observations).
    n_components
        Upper bound on the number of components; capped at the number of
        columns of *X*.
    **kwargs
        Extra options for ``sklearn.decomposition.FastICA``. They override
        the defaults ``max_iter=300`` and ``tol=1e-4``.

    Returns
    -------
    numpy.ndarray
        The fitted source matrix passed through ``_minmax_normalize`` with
        ``n_target=3``, or an all-zero ``(n_rows, 3)`` array when every row
        of *X* equals the column means.
    """
    from sklearn.decomposition import FastICA

    if np.allclose(X, X.mean(axis=0)):
        return np.zeros((X.shape[0], 3), dtype=np.float64)

    # Caller-supplied options win over the built-in defaults.
    options: dict[str, object] = {"max_iter": 300, "tol": 1e-4, **kwargs}
    rank = min(n_components, X.shape[1])
    sources = FastICA(n_components=rank, **options).fit_transform(X)  # type: ignore[arg-type]
    return _minmax_normalize(sources, n_target=3)
190
+
191
+
192
+ # Register built-in reducers
193
# (name, implementation, legend label prefix) for each reducer shipped here.
for _name, _fn, _prefix in (
    ("pca", _reduce_pca, "PC"),
    ("nmf", _reduce_nmf, "NMF"),
    ("ica", _reduce_ica, "IC"),
):
    register_reducer(_name, _fn, component_prefix=_prefix)
# Do not leave the loop variables behind in the module namespace.
del _name, _fn, _prefix
196
+
197
+
198
+ # ---- public API -------------------------------------------------------------
199
+
200
+
201
def blend_to_rgb(
    scores: DataFrame,
    *,
    colors: list[tuple[float, float, float]] | None = None,
) -> NDArray:
    """Convert gene set scores to RGB by multiplicative blending from white.

    Parameters
    ----------
    scores
        DataFrame returned by :func:`score_gene_sets`. Only columns whose
        names start with ``score-`` are used.
    colors
        One ``(R, G, B)`` tuple per gene set. When *None*, a default palette
        is picked from the number of gene sets (2 → blue/red, 3 → RGB).

    Returns
    -------
    numpy.ndarray
        ``(n_cells, 3)`` RGB array with values in [0, 1].

    Raises
    ------
    ValueError
        If there are fewer than 2 or more than 3 gene sets, or if the number
        of supplied colours differs from the number of gene sets.
    """
    score_cols = _validate_score_columns(scores)
    n_sets = len(score_cols)

    if n_sets < 2:
        raise ValueError("At least 2 gene sets are required.")
    if n_sets > 3:
        raise ValueError(
            f"Direct projection supports at most 3 gene sets (got {n_sets}). "
            "Use reduce_to_rgb() for higher dimensions."
        )

    matrix = scores[score_cols].to_numpy(dtype=np.float64)

    if colors is None:
        # n_sets is 2 or 3 at this point.
        colors = DEFAULT_COLORS_3 if n_sets == 3 else DEFAULT_COLORS_2

    n_colors = len(colors)
    if n_colors != n_sets:
        raise ValueError(f"Expected {n_sets} colors for {n_sets} gene sets, got {n_colors}.")

    return _multiplicative_blend(matrix, colors)
247
+
248
+
249
def reduce_to_rgb(
    scores: DataFrame,
    *,
    method: str = "pca",
    n_components: int = 3,
    **kwargs: object,
) -> NDArray:
    """Convert gene set scores to RGB through a registered reducer.

    Parameters
    ----------
    scores
        DataFrame returned by :func:`score_gene_sets`.
    method
        Reduction method: ``"pca"`` (default), ``"nmf"``, ``"ica"``, or any
        method registered via :func:`register_reducer`.
    n_components
        Number of components to retain; values above 3 are capped at 3.
    **kwargs
        Extra keyword arguments forwarded to the reducer function.

    Returns
    -------
    numpy.ndarray
        ``(n_cells, 3)`` RGB array with values in [0, 1].

    Raises
    ------
    ValueError
        If *method* is unknown or fewer than 2 gene sets are present.
    """
    if method not in _REDUCERS:
        known = ", ".join(sorted(_REDUCERS))
        raise ValueError(f"Unknown reduction method '{method}'. Available: {known}.")
    reducer = _REDUCERS[method]

    score_cols = _validate_score_columns(scores)
    if len(score_cols) < 2:
        raise ValueError("At least 2 gene sets are required.")

    matrix = scores[score_cols].to_numpy(dtype=np.float64)
    # RGB offers three channels, so never request more than three components.
    return reducer(matrix, min(n_components, 3), **kwargs)
291
+
292
+
293
+ # ---- deprecated wrappers ---------------------------------------------------
294
+
295
+
296
def project_direct(
    scores: DataFrame,
    *,
    colors: list[tuple[float, float, float]] | None = None,
) -> NDArray:
    """Deprecated alias of :func:`blend_to_rgb`.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    *scores* and *colors* to :func:`blend_to_rgb` and returns its result.
    """
    message = "project_direct() is deprecated, use blend_to_rgb() instead."
    # stacklevel=2 points the warning at the code that called this wrapper.
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return blend_to_rgb(scores, colors=colors)
308
+
309
+
310
def project_pca(
    scores: DataFrame,
    *,
    n_components: int = 3,
) -> NDArray:
    """Deprecated alias of :func:`reduce_to_rgb` with ``method="pca"``.

    Emits a ``DeprecationWarning`` attributed to the caller, then forwards
    *scores* and *n_components* to :func:`reduce_to_rgb` and returns its
    result.
    """
    message = "project_pca() is deprecated, use reduce_to_rgb() instead."
    # stacklevel=2 points the warning at the code that called this wrapper.
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return reduce_to_rgb(scores, method="pca", n_components=n_components)
@@ -0,0 +1,328 @@
1
+ """Interactive Plotly-based embedding plots (optional dependency).
2
+
3
+ Provides ``plot_embedding_interactive`` which renders a WebGL-accelerated
4
+ scatter plot of cells in embedding space, coloured by projected RGB values,
5
+ with rich hover information.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import base64
11
+ import io
12
+ from typing import TYPE_CHECKING
13
+
14
+ import numpy as np
15
+
16
+ from multiscoresplot._colorspace import get_component_labels
17
+ from multiscoresplot._legend import render_legend
18
+ from multiscoresplot._plotting import _extract_coords, _validate_rgb
19
+ from multiscoresplot._scoring import SCORE_PREFIX
20
+
21
+ if TYPE_CHECKING:
22
+ from numpy.typing import NDArray
23
+ from pandas import DataFrame
24
+
25
__all__ = ["plot_embedding_interactive"]

# ---------------------------------------------------------------------------
# Legend position lookup (mirrors _INSET_BOUNDS in _plotting.py)
# ---------------------------------------------------------------------------

# Each corner name maps to paper-relative coordinates plus the anchors that
# pin the image to that corner, with a 0.02 margin from the plot edge.
_PLOTLY_LEGEND_POS: dict[str, dict[str, str | float]] = {
    f"{vertical} {horizontal}": {
        "x": x_pos,
        "y": y_pos,
        "xanchor": horizontal,
        "yanchor": y_anchor,
    }
    for vertical, y_pos, y_anchor in (("lower", 0.02, "bottom"), ("upper", 0.98, "top"))
    for horizontal, x_pos in (("right", 0.98), ("left", 0.02))
}
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Legend helpers
41
+ # ---------------------------------------------------------------------------
42
+
43
+
44
def _render_legend_to_base64(
    method: str,
    *,
    gene_set_names: list[str] | None = None,
    colors: list[tuple[float, float, float]] | None = None,
    component_labels: list[str] | None = None,
    resolution: int = 128,
) -> str:
    """Render the legend to a base64-encoded PNG data URI.

    The legend is drawn by ``render_legend`` onto a stand-alone matplotlib
    ``Figure`` backed by an Agg canvas. The figure is created without
    pyplot, so it is never registered with pyplot's figure manager and needs
    no ``pyplot.close`` call.

    Parameters
    ----------
    method
        Passed through to ``render_legend``.
    gene_set_names
        Passed through to ``render_legend``.
    colors
        Passed through to ``render_legend``.
    component_labels
        Passed through to ``render_legend``.
    resolution
        Passed through to ``render_legend``.

    Returns
    -------
    str
        ``data:image/png;base64,...`` URI holding the rendered PNG.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    fig = Figure(figsize=(2, 2), dpi=150)
    # Attaching the canvas is enough; the figure keeps a reference to it.
    FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)

    render_legend(
        ax,
        method,
        gene_set_names=gene_set_names,
        colors=colors,
        component_labels=component_labels,
        resolution=resolution,
    )

    buf = io.BytesIO()
    fig.savefig(buf, format="png", transparent=True, bbox_inches="tight", dpi=150)
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"
76
+
77
+
78
def _add_plotly_legend(
    fig: object,
    *,
    method: str,
    gene_set_names: list[str] | None = None,
    colors: list[tuple[float, float, float]] | None = None,
    legend_loc: str = "lower right",
    legend_size: float = 0.30,
    legend_resolution: int = 128,
) -> None:
    """Draw the legend as an image and attach it to a Plotly figure.

    Parameters
    ----------
    fig
        Figure that receives the image via ``add_layout_image``.
    method
        Method name handed to the legend renderer. Component labels are
        looked up for every method except ``"direct"``.
    gene_set_names
        Forwarded to the legend renderer.
    colors
        Forwarded to the legend renderer.
    legend_loc
        Key into ``_PLOTLY_LEGEND_POS``; unknown keys fall back to
        ``"lower right"``.
    legend_size
        Used for both ``sizex`` and ``sizey`` of the layout image.
    legend_resolution
        Forwarded to the legend renderer as ``resolution``.
    """
    component_labels = get_component_labels(method) if method != "direct" else None

    image_uri = _render_legend_to_base64(
        method,
        gene_set_names=gene_set_names,
        colors=colors,
        component_labels=component_labels,
        resolution=legend_resolution,
    )

    fallback = _PLOTLY_LEGEND_POS["lower right"]
    placement = _PLOTLY_LEGEND_POS.get(legend_loc, fallback)

    fig.add_layout_image(  # type: ignore[attr-defined]
        source=image_uri,
        xref="paper",
        yref="paper",
        x=placement["x"],
        y=placement["y"],
        xanchor=placement["xanchor"],
        yanchor=placement["yanchor"],
        sizex=legend_size,
        sizey=legend_size,
        sizing="contain",
        layer="above",
    )
117
+
118
+
119
def _ensure_plotly():
    """Import and return ``plotly.graph_objects`` on demand.

    Raises
    ------
    ImportError
        With installation instructions when plotly cannot be imported.
    """
    try:
        from plotly import graph_objects
    except ImportError:
        # Replace the low-level import failure with an actionable message.
        raise ImportError(
            "plotly is required for interactive plotting. "
            "Install it with: pip install 'multiscoresplot[interactive]'"
        ) from None
    return graph_objects
129
+
130
+
131
def plot_embedding_interactive(
    adata_or_coords: object,
    rgb: NDArray,
    *,
    basis: str | None = None,
    components: tuple[int, int] = (0, 1),
    scores: DataFrame | None = None,
    method: str | None = None,
    gene_set_names: list[str] | None = None,
    # legend
    legend: bool = True,
    legend_loc: str = "lower right",
    legend_size: float = 0.30,
    legend_resolution: int = 128,
    colors: list[tuple[float, float, float]] | None = None,
    # hover / scatter
    hover_columns: list[str] | None = None,
    point_size: float = 2,
    alpha: float = 1.0,
    width: int = 500,
    height: int = 450,
    title: str = "",
    show: bool = True,
) -> object | None:
    """Interactive Plotly scatter plot of embedding coordinates coloured by RGB.

    Parameters
    ----------
    adata_or_coords
        An ``AnnData`` object (with *basis* in ``.obsm``) or a raw
        ``(n_cells, 2)`` coordinate array.
    rgb
        ``(n_cells, 3)`` RGB array from ``blend_to_rgb`` or ``reduce_to_rgb``.
    basis
        Embedding key (e.g. ``"umap"``, ``"pca"``). Required when
        *adata_or_coords* is AnnData.
    components
        Which two components to plot (0-indexed).
    scores
        DataFrame with ``score-*`` columns, one row per cell. If *None* and
        *adata_or_coords* has an ``obs`` attribute, score columns are taken
        from ``adata.obs``.
    method
        Reduction method (``"pca"``, ``"nmf"``, etc.) used to derive RGB.
        Controls the channel labels in hover info. If *None* or ``"direct"``,
        channels are labeled R/G/B.
    gene_set_names
        Human-readable labels for gene set scores in hover info. Used only
        when their number matches the number of score columns; otherwise the
        column names without the score prefix are shown.
    legend
        Whether to add a colour-space legend overlay. The legend is skipped
        when *method* is *None*, or when *method* is ``"direct"`` and
        *gene_set_names* is *None*.
    legend_loc
        Position for the legend (``"lower right"``, ``"lower left"``,
        ``"upper right"``, ``"upper left"``).
    legend_size
        Size of the legend as a fraction of the plot (0-1).
    legend_resolution
        Pixel resolution of the legend image.
    colors
        Base colours for direct-mode legends.
    hover_columns
        Extra columns from ``adata.obs`` to include in hover info.
    point_size
        Scatter marker size.
    alpha
        Marker opacity.
    width
        Figure width in pixels.
    height
        Figure height in pixels.
    title
        Plot title.
    show
        If *True*, call ``fig.show()`` and return *None*. If *False*,
        return the ``plotly.graph_objects.Figure``.

    Returns
    -------
    Figure or None
        The figure when ``show=False``; *None* when ``show=True``.

    Raises
    ------
    ImportError
        If plotly is not installed.
    ValueError
        If the score table does not have one row per cell, or if
        *hover_columns* is given without an object exposing ``.obs``.
    KeyError
        If a requested hover column is missing from ``adata.obs``.
    """
    go = _ensure_plotly()

    coords, basis_label = _extract_coords(adata_or_coords, basis, components)
    n_cells = coords.shape[0]
    rgb = _validate_rgb(rgb, n_cells)

    # An object exposing ``.obs`` is treated as AnnData-like.
    has_obs = hasattr(adata_or_coords, "obs")

    # --- Build hover text ---
    hover_parts: list[list[str]] = [[] for _ in range(n_cells)]

    # 1. Gene set scores
    score_df: DataFrame | None = scores
    if score_df is None and has_obs:
        obs = adata_or_coords.obs  # type: ignore[attr-defined]
        obs_score_cols = [c for c in obs.columns if c.startswith(SCORE_PREFIX)]
        if obs_score_cols:
            score_df = obs[obs_score_cols]

    if score_df is not None:
        score_cols = [c for c in score_df.columns if c.startswith(SCORE_PREFIX)]
        # Fail early: a mismatched table would otherwise raise an IndexError
        # mid-loop (too few rows) or silently show values that belong to
        # other cells (too many rows).
        if score_cols and len(score_df) != n_cells:
            raise ValueError(
                f"scores has {len(score_df)} rows but the embedding has {n_cells} cells."
            )
        labels = (
            gene_set_names
            if gene_set_names is not None and len(gene_set_names) == len(score_cols)
            else [c[len(SCORE_PREFIX) :] for c in score_cols]
        )
        score_vals = score_df[score_cols].to_numpy(dtype=np.float64)
        for i in range(n_cells):
            for j, label in enumerate(labels):
                hover_parts[i].append(f"{label}: {score_vals[i, j]:.3f}")

    # 2. RGB channel values
    if method is not None and method != "direct":
        channel_labels = get_component_labels(method)
    else:
        channel_labels = ["R", "G", "B"]

    for i in range(n_cells):
        for j, ch_label in enumerate(channel_labels):
            hover_parts[i].append(f"{ch_label}: {rgb[i, j]:.2f}")

    # 3. Extra .obs columns
    if hover_columns is not None:
        if not has_obs:
            raise ValueError("hover_columns requires an AnnData object, not raw coordinates.")
        obs = adata_or_coords.obs  # type: ignore[attr-defined]
        missing = [c for c in hover_columns if c not in obs.columns]
        if missing:
            raise KeyError(f"Columns not found in adata.obs: {missing}")

        import pandas as _pd

        for col_name in hover_columns:
            col = obs[col_name]
            # Decide the format once per column rather than once per cell.
            is_numeric = _pd.api.types.is_numeric_dtype(col)
            for i in range(n_cells):
                val = col.iloc[i]
                if is_numeric:
                    hover_parts[i].append(f"{col_name}: {val:.3f}")
                else:
                    hover_parts[i].append(f"{col_name}: {val}")

    hover_text = ["<br>".join(parts) for parts in hover_parts]

    # --- Build color strings ---
    # Channels are scaled to 0-255 and truncated; alpha is embedded per marker.
    marker_colors = [
        f"rgba({int(r * 255)},{int(g * 255)},{int(b * 255)},{alpha})" for r, g, b in rgb
    ]

    # --- Axis labels ---
    if basis_label is not None:
        # Components are 0-indexed in the API but shown 1-indexed on the axes.
        xaxis_title = f"{basis_label}{components[0] + 1}"
        yaxis_title = f"{basis_label}{components[1] + 1}"
    else:
        xaxis_title = ""
        yaxis_title = ""

    # --- Create figure ---
    fig = go.Figure(
        data=go.Scattergl(
            x=coords[:, 0],
            y=coords[:, 1],
            mode="markers",
            marker=dict(
                size=point_size,
                color=marker_colors,
            ),
            hovertext=hover_text,
            hoverinfo="text",
        ),
    )

    fig.update_layout(
        width=width,
        height=height,
        title=title,
        # Anchoring x to y keeps both axes on the same scale.
        xaxis=dict(title=xaxis_title, scaleanchor="y"),
        yaxis=dict(title=yaxis_title),
        plot_bgcolor="white",
    )

    # Legend (skip silently when method is None, or direct mode without gene_set_names)
    if legend and method is not None and (method != "direct" or gene_set_names is not None):
        _add_plotly_legend(
            fig,
            method=method,
            gene_set_names=gene_set_names,
            colors=colors,
            legend_loc=legend_loc,
            legend_size=legend_size,
            legend_resolution=legend_resolution,
        )

    if show:
        fig.show()
        return None

    return fig  # type: ignore[no-any-return]