scales-python 1.4.0.9000__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,593 @@
1
+ """
2
+ Colour mapping functions for the scales package.
3
+
4
+ Python port of R/colour-mapping.R from the R scales package
5
+ (https://github.com/r-lib/scales). Provides factory functions that return
6
+ callables mapping data values to hex colour strings:
7
+
8
+ - :func:`col_numeric` -- continuous linear interpolation
9
+ - :func:`col_bin` -- binned (stepped) colour mapping
10
+ - :func:`col_quantile` -- quantile-based binning
11
+ - :func:`col_factor` -- categorical / factor mapping
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import warnings
17
+ from typing import (
18
+ Any,
19
+ Callable,
20
+ Dict,
21
+ List,
22
+ Optional,
23
+ Sequence,
24
+ Tuple,
25
+ Union,
26
+ )
27
+
28
+ import numpy as np
29
+ from numpy.typing import ArrayLike
30
+
31
+ from .colour_ramp import colour_ramp
32
+ from .palettes import pal_brewer, pal_viridis
33
+
34
+ __all__ = [
35
+ "col_numeric",
36
+ "col_bin",
37
+ "col_quantile",
38
+ "col_factor",
39
+ ]
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Brewer palette name lookup (from embedded data)
43
+ # ---------------------------------------------------------------------------
44
+
45
+ from ._palettes_data import BREWER_MAXCOLORS as _BREWER_PALETTES
46
+
47
+ _VIRIDIS_NAMES = {"viridis", "magma", "inferno", "plasma"}
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Internal helpers (matching R's safePaletteFunc / toPaletteFunc dispatch)
52
+ # ---------------------------------------------------------------------------
53
+
54
def _to_palette_func(
    pal: Union[str, Sequence[str], Callable],
    alpha: bool = True,
    nlevels: Optional[int] = None,
) -> Callable[[ArrayLike], List[str]]:
    """
    Turn a palette specification into a colour ramp callable.

    Dispatch mirrors R's ``toPaletteFunc``:

    - a callable (that is not a str/list/tuple) is returned unchanged;
    - a Brewer palette name is sampled via ``pal_brewer`` and handed to
      ``colour_ramp``;
    - a viridis option name is sampled at 256 colours via ``pal_viridis``
      and handed to ``colour_ramp``;
    - any other string is treated as a single colour;
    - any other sequence is treated as a list of colours.

    Parameters
    ----------
    pal : str, sequence of str, or callable
        Palette specification.
    alpha : bool
        Forwarded to ``colour_ramp`` as its ``alpha`` argument.
    nlevels : int or None
        When given, a Brewer palette is sampled at ``abs(nlevels)`` colours,
        capped at the palette's maximum; otherwise the maximum is used.

    Returns
    -------
    callable
        Either *pal* itself or the result of ``colour_ramp``.
    """
    is_plain_sequence = isinstance(pal, (str, list, tuple))
    if callable(pal) and not is_plain_sequence:
        return pal

    # Anything that is not a string is taken to be a sequence of colours.
    if not isinstance(pal, str):
        return colour_ramp(list(pal), alpha=alpha)

    if pal in _BREWER_PALETTES:
        limit = _BREWER_PALETTES[pal]
        count = limit if nlevels is None else min(abs(nlevels), limit)
        sampled = pal_brewer(palette=pal)(count)
        # Drop any ``None`` entries the Brewer sampler hands back.
        return colour_ramp([c for c in sampled if c is not None], alpha=alpha)

    if pal in _VIRIDIS_NAMES:
        return colour_ramp(pal_viridis(option=pal)(256), alpha=alpha)

    # Fallback: a lone colour string becomes a one-colour ramp.
    return colour_ramp([pal], alpha=alpha)
102
+
103
+
104
def _safe_palette_func(
    pal: Union[str, Sequence[str], Callable],
    na_color: str,
    alpha: bool = True,
    nlevels: Optional[int] = None,
) -> Callable[[ArrayLike], List[str]]:
    """
    Build a palette function that tolerates missing and out-of-range input.

    The palette is resolved with ``_to_palette_func``; the returned wrapper
    (modelled on R's ``safePaletteFunc``) then:

    - returns ``[]`` for empty input,
    - replaces values outside [0, 1] with NaN before calling the ramp,
    - substitutes *na_color* for every ``None`` the ramp returns.
    """
    inner = _to_palette_func(pal, alpha=alpha, nlevels=nlevels)

    def _safe(x: ArrayLike) -> List[str]:
        values = np.asarray(x, dtype=float)
        if values.size == 0:
            return []
        # Keep only values inside the unit interval; everything else is NaN.
        inside = ~((values < 0) | (values > 1))
        clipped = np.where(inside, values, np.nan)
        return [c if c is not None else na_color for c in inner(clipped)]

    return _safe
128
+
129
+
130
def _rescale(
    x: np.ndarray,
    domain_min: float,
    domain_max: float,
) -> np.ndarray:
    """
    Map *x* linearly so that *domain_min* → 0 and *domain_max* → 1.

    A zero-width domain sends every non-NaN value to 0.5 (NaN stays NaN).
    """
    span = domain_max - domain_min
    if span != 0:
        return (x - domain_min) / span
    # Degenerate domain: place everything in the middle of the ramp.
    return np.where(np.isnan(x), np.nan, 0.5)
140
+
141
+
142
def _pretty_breaks(domain_min: float, domain_max: float, n: int) -> np.ndarray:
    """
    Return "pretty" breakpoints covering ``[domain_min, domain_max]``.

    The raw step ``(domain_max - domain_min) / n`` is snapped to 1, 2, 5 or
    10 times a power of ten (a heuristic resembling R's ``pretty()``), and
    breaks are generated with ``np.arange`` from the step-aligned floor of
    *domain_min* to the step-aligned ceiling of *domain_max*. The number of
    breaks therefore need not be exactly ``n + 1``.

    A zero-width domain returns ``[domain_min, domain_max]`` unchanged.
    """
    raw_step = (domain_max - domain_min) / n
    if raw_step == 0:
        return np.array([domain_min, domain_max])

    magnitude = 10 ** np.floor(np.log10(raw_step))
    residual = raw_step / magnitude

    # (upper bound on residual, multiplier); fall through to 10x.
    nice_step = 10.0 * magnitude
    for upper, multiplier in ((1.5, 1.0), (3.0, 2.0), (7.0, 5.0)):
        if residual <= upper:
            nice_step = multiplier * magnitude
            break

    start = np.floor(domain_min / nice_step) * nice_step
    stop = np.ceil(domain_max / nice_step) * nice_step
    # Half a step of slack so the final break is included by arange.
    return np.arange(start, stop + nice_step * 0.5, nice_step)
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # Public API
171
+ # ---------------------------------------------------------------------------
172
+
173
def col_numeric(
    palette: Union[str, Sequence[str]],
    domain: Optional[Tuple[float, float]] = None,
    na_color: str = "#808080",
    alpha: bool = False,
    reverse: bool = False,
) -> Callable[[ArrayLike], List[str]]:
    """
    Map continuous numeric values to colours via linear interpolation.

    Parameters
    ----------
    palette : str or sequence of str
        Colourmap name (e.g. ``"Blues"``, ``"Greens"``, ``"viridis"``) or
        a list of colour strings defining the ramp.
    domain : tuple of (float, float), optional
        ``(min, max)`` of the data domain. If *None*, the domain is taken
        from the finite values of the first call that contains any, and is
        then kept for all later calls.
    na_color : str, default "#808080"
        Colour for missing / ``NaN`` values and for values outside the
        domain.
    alpha : bool, default False
        Passed through to the palette function; controls whether alpha
        channels take part in interpolation.
    reverse : bool, default False
        Reverse the palette direction.

    Returns
    -------
    callable
        ``f(x)`` mapping a numeric scalar or array to a list of colour
        strings (one per element).

    Warns
    -----
    UserWarning
        When any non-NaN value falls outside the domain.

    Examples
    --------
    The exact hex formatting is produced by ``colour_ramp``.

    >>> f = col_numeric(["white", "red"], domain=(0, 100))
    >>> f([0, 50, 100])  # doctest: +SKIP
    """
    ramp = _safe_palette_func(palette, na_color, alpha=alpha)

    # Mutable state so the domain can be fixed lazily on first use.
    state: Dict[str, Any] = {
        "domain": domain,
    }

    def _map(x: ArrayLike) -> List[str]:
        # atleast_1d: accept bare scalars the same way col_bin does.
        x = np.atleast_1d(np.asarray(x, dtype=float))

        if state["domain"] is None:
            finite = x[np.isfinite(x)]
            if len(finite) == 0:
                # Nothing to infer a domain from yet; leave it unset.
                return [na_color] * x.size
            state["domain"] = (float(finite.min()), float(finite.max()))

        lo, hi = state["domain"]
        scaled = _rescale(x, lo, hi)
        if reverse:
            scaled = 1.0 - scaled
        # R: warn when values are outside the colour scale.
        if np.any((scaled < 0) | (scaled > 1), where=~np.isnan(scaled)):
            warnings.warn(
                "Some values were outside the color scale and will be "
                "treated as NA",
                stacklevel=2,
            )
        return ramp(scaled)

    return _map
240
+
241
+
242
def col_bin(
    palette: Union[str, Sequence[str]],
    domain: Optional[Tuple[float, float]] = None,
    bins: Union[int, Sequence[float]] = 7,
    pretty: bool = True,
    na_color: str = "#808080",
    alpha: bool = False,
    reverse: bool = False,
    right: bool = False,
) -> Callable[[ArrayLike], List[str]]:
    """
    Map continuous data to colours through binning.

    Parameters
    ----------
    palette : str or sequence of str
        Palette specification.
    domain : tuple of (float, float), optional
        Data domain. If *None* and *bins* is an integer, the breaks are
        derived from the data of the first call that contains a non-NaN
        value and are then reused.
    bins : int or sequence of float, default 7
        Number of bins (Python or numpy integer) or explicit breakpoints.
    pretty : bool, default True
        Use "pretty" breakpoints when *bins* is an integer.
    na_color : str, default "#808080"
        Colour for missing values and values outside the breaks.
    alpha : bool, default False
        Passed through to the palette function.
    reverse : bool, default False
        Reverse palette direction.
    right : bool, default False
        If *True*, bins are right-closed ``(a, b]``; otherwise left-closed
        ``[a, b)``.

    Returns
    -------
    callable
        ``f(x)`` mapping a numeric scalar or array to a list of colour
        strings (one per element, in flattened order).

    Raises
    ------
    ValueError
        From ``_get_bins`` when an integer *bins* is smaller than 2. This is
        raised at construction if *domain* is given, otherwise on first use.
    """
    # numpy integer scalars are not ``int`` instances; normalise so the
    # "bin count vs explicit breaks" test below treats them as a count.
    if isinstance(bins, np.integer):
        bins = int(bins)

    state: Dict[str, Any] = {
        "breaks": None,
        "color_func": None,
    }

    if domain is not None:
        state["breaks"] = _get_bins(domain, None, bins, pretty)
    elif not isinstance(bins, int):
        # R: explicit breaks don't need a domain.
        state["breaks"] = np.sort(np.asarray(bins, dtype=float))

    def _map(x: ArrayLike) -> List[str]:
        x = np.atleast_1d(np.asarray(x, dtype=float))

        if x.size == 0 or np.all(np.isnan(x)):
            return [na_color] * x.size

        breaks = state["breaks"]
        if breaks is None:
            breaks = _get_bins(None, x, bins, pretty)
            state["breaks"] = breaks

        n_bins = len(breaks) - 1
        if n_bins < 1:
            return [na_color] * x.size

        # R: col_bin delegates to col_factor(palette, domain=1:numColors, ...)
        # The breaks never change once set, so build the mapper only once.
        color_func = state["color_func"]
        if color_func is None:
            color_func = col_factor(
                palette,
                domain=[str(i) for i in range(1, n_bins + 1)],
                na_color=na_color,
                alpha=alpha,
                reverse=reverse,
            )
            state["color_func"] = color_func

        # R: cut(x, breaks, labels=FALSE, include.lowest=TRUE, right=right)
        ints = _cut(x, breaks, include_lowest=True, right=right)

        # "NA" is not one of the factor levels, so col_factor maps it to
        # na_color; one call handles the whole batch.
        labels = [
            "NA" if np.isnan(v) else str(int(v))
            for v in ints.flat
        ]
        return color_func(labels)

    return _map
334
+
335
+
336
def _get_bins(
    domain: Optional[Tuple[float, float]],
    x: Optional[np.ndarray],
    bins: Union[int, Sequence[float]],
    pretty: bool,
) -> np.ndarray:
    """
    Compute bin breakpoints (R's ``getBins``).

    Parameters
    ----------
    domain : sequence of float or None
        Reference range; takes precedence over *x* when given.
    x : array-like or None
        Data to derive the range from when *domain* is None.
    bins : int or sequence of float
        A bin count (Python or numpy integer) or explicit breakpoints.
        Explicit breakpoints are returned sorted and *domain* / *x* are
        ignored.
    pretty : bool
        Use ``_pretty_breaks`` instead of equal-width ``np.linspace``.

    Returns
    -------
    numpy.ndarray
        Ascending breakpoints. ``[0.0, 1.0]`` when the reference data holds
        no finite value.

    Raises
    ------
    ValueError
        If the bin count is below 2, or both *domain* and *x* are None.
    """
    # numpy integer scalars are not ``int`` instances; treat them as counts
    # rather than as a (0-d) sequence of breaks.
    if isinstance(bins, np.integer):
        bins = int(bins)

    if not isinstance(bins, int):
        return np.sort(np.asarray(bins, dtype=float))
    if bins < 2:
        raise ValueError(f"Invalid bins value ({bins}); bin count must be at least 2")

    if domain is not None:
        ref = np.asarray(domain, dtype=float)
    elif x is not None:
        ref = np.asarray(x, dtype=float)
    else:
        raise ValueError("domain and x can't both be None")

    # Ignore NaN / inf so they cannot poison min() / max().
    ref = ref[np.isfinite(ref)]
    if len(ref) == 0:
        return np.array([0.0, 1.0])

    lo, hi = float(ref.min()), float(ref.max())
    if pretty:
        return _pretty_breaks(lo, hi, bins)
    return np.linspace(lo, hi, bins + 1)
362
+
363
+
364
def _cut(
    x: np.ndarray,
    breaks: np.ndarray,
    include_lowest: bool = True,
    right: bool = False,
) -> np.ndarray:
    """
    Bin values into integer labels (R's ``cut(..., labels=FALSE)``).

    Parameters
    ----------
    x : numpy.ndarray
        Values to bin.
    breaks : numpy.ndarray
        Breakpoints in ascending order; ``len(breaks) - 1`` bins.
    include_lowest : bool, default True
        Close the outermost open edge: the first bin becomes ``[lo, hi]``
        when *right* is True, the last bin becomes ``[lo, hi]`` otherwise.
    right : bool, default False
        If True bins are ``(lo, hi]``; otherwise ``[lo, hi)``.

    Returns
    -------
    numpy.ndarray
        Float array shaped like *x* holding 1-based bin indices, with NaN
        for NaN inputs and for values outside the breaks range.
    """
    x = np.asarray(x, dtype=float)
    breaks = np.asarray(breaks, dtype=float)
    n_bins = len(breaks) - 1

    if n_bins < 1:
        return np.full(x.shape, np.nan, dtype=float)

    if right:
        # breaks[i] < v <= breaks[i + 1]  ->  zero-based bin i
        idx = np.searchsorted(breaks, x, side="left") - 1
        if include_lowest:
            # The lowest break itself belongs to the first bin.
            idx = np.where(x == breaks[0], 0, idx)
    else:
        # breaks[i] <= v < breaks[i + 1]  ->  zero-based bin i
        idx = np.searchsorted(breaks, x, side="right") - 1
        if include_lowest:
            # The highest break itself belongs to the last bin.
            idx = np.where(x == breaks[-1], n_bins - 1, idx)

    # NaN inputs sort past the end in searchsorted, so mask them explicitly.
    inside = ~np.isnan(x) & (idx >= 0) & (idx < n_bins)
    return np.where(inside, idx + 1.0, np.nan)
408
+
409
+
410
def _safe_quantile(
    x: np.ndarray,
    probs: np.ndarray,
    n_requested: int,
) -> np.ndarray:
    """
    Quantile breakpoints with duplicates removed (R's ``safe_quantile``).

    Computes ``np.quantile(x, probs)`` and keeps only the unique values.
    If that leaves fewer breakpoints than probabilities, a warning reports
    how many unique values remain versus *n_requested*.
    """
    cutpoints = np.unique(np.quantile(x, probs))
    n_unique = len(cutpoints)
    if n_unique >= len(probs):
        return cutpoints

    warnings.warn(
        f"Skewed data means we can only allocate {n_unique} unique "
        f"colours not the {n_requested} requested",
        stacklevel=3,
    )
    return cutpoints
424
+
425
+
426
def col_quantile(
    palette: Union[str, Sequence[str]],
    domain: Optional[ArrayLike] = None,
    n: int = 4,
    probs: Optional[Sequence[float]] = None,
    na_color: str = "#808080",
    alpha: bool = False,
    reverse: bool = False,
    right: bool = False,
) -> Callable[[ArrayLike], List[str]]:
    """
    Colour values according to the quantile bin they fall into.

    Parameters
    ----------
    palette : str or sequence of str
        Palette specification.
    domain : array-like, optional
        Reference data used to compute the quantile breaks. When *None*
        (or when it holds no finite value) the breaks are computed from the
        first call whose input contains a finite value.
    n : int, default 4
        Number of equal-probability bins; unused when *probs* is supplied.
    probs : sequence of float, optional
        Explicit quantile probabilities (e.g. ``[0, 0.25, 0.5, 0.75, 1]``).
    na_color : str, default "#808080"
        Colour for missing values.
    alpha : bool, default False
        Forwarded to ``col_bin``.
    reverse : bool, default False
        Reverse palette direction.
    right : bool, default False
        Bin closure direction, forwarded to ``col_bin``.

    Returns
    -------
    callable
        ``f(x)`` mapping numeric array to a list of hex colour strings.
    """
    probs_arr = (
        np.linspace(0, 1, n + 1)
        if probs is None
        else np.asarray(probs, dtype=float)
    )
    n_requested = len(probs_arr) - 1

    # Quantile breaks; stays None until finite reference data is seen.
    breaks = None

    if domain is not None:
        reference = np.asarray(domain, dtype=float)
        reference = reference[np.isfinite(reference)]
        if len(reference) > 0:
            breaks = _safe_quantile(reference, probs_arr, n_requested)

    def _map(x: ArrayLike) -> List[str]:
        nonlocal breaks
        values = np.asarray(x, dtype=float)

        if breaks is None:
            usable = values[np.isfinite(values)]
            if len(usable) == 0:
                return [na_color] * values.size
            breaks = _safe_quantile(usable, probs_arr, n_requested)

        # The actual binning and colouring is col_bin's job.
        binned = col_bin(
            palette,
            bins=breaks,
            na_color=na_color,
            alpha=alpha,
            reverse=reverse,
            right=right,
        )
        return binned(values)

    return _map
498
+
499
+
500
def col_factor(
    palette: Union[str, Sequence[str]],
    domain: Optional[Sequence[str]] = None,
    levels: Optional[Sequence[str]] = None,
    ordered: bool = False,
    na_color: str = "#808080",
    alpha: bool = False,
    reverse: bool = False,
) -> Callable[[Union[ArrayLike, Sequence[str]]], List[str]]:
    """
    Map categorical (factor) values to colours.

    Parameters
    ----------
    palette : str or sequence of str
        Palette specification. When a list of colours, the number of
        colours should ideally match the number of levels.
    domain : sequence of str, optional
        Valid category labels. If *None*, inferred from the first call
        that contains at least one non-missing value.
    levels : sequence of str, optional
        Synonym for *domain* (mirrors R's ``levels`` argument); wins when
        both are given.
    ordered : bool, default False
        Only affects inferred levels: if *True* they keep first-appearance
        order, otherwise they are sorted.
    na_color : str, default "#808080"
        Colour for missing (``None`` / NaN) values and unknown levels.
    alpha : bool, default False
        Passed through to the palette function.
    reverse : bool, default False
        Reverse palette direction.

    Returns
    -------
    callable
        ``f(x)`` mapping an array of category labels to a list of colour
        strings. Inputs are compared to levels by their ``str()`` form.
    """
    given = levels if levels is not None else domain
    # Inputs are stringified before lookup, so levels must be as well.
    lvls = [str(lvl) for lvl in given] if given is not None else None

    state: Dict[str, Any] = {"levels": lvls, "colors": None}

    def _is_missing(value: Any) -> bool:
        """True for None and floating-point NaN."""
        if value is None:
            return True
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    def _ensure_colors() -> Dict[str, str]:
        """Build (once) and return the level -> colour lookup."""
        if state["colors"] is not None:
            return state["colors"]

        all_levels = state["levels"]
        if reverse:
            all_levels = list(reversed(all_levels))

        n = len(all_levels)
        if n == 0:
            state["colors"] = {}
            return state["colors"]

        # R: safePaletteFunc(palette, na.color, alpha,
        #                    nlevels = length(lvls) * ifelse(reverse, -1, 1))
        nlevels = n * (-1 if reverse else 1)
        ramp = _safe_palette_func(
            palette, na_color, alpha=alpha, nlevels=nlevels
        )

        if n == 1:
            positions = np.array([0.5])
        else:
            # R: rescale(as.integer(x), from = c(1, length(lvls)))
            positions = np.linspace(0, 1, n)

        state["colors"] = dict(zip(all_levels, ramp(positions)))
        return state["colors"]

    def _map(x: Union[ArrayLike, Sequence[str]]) -> List[str]:
        if isinstance(x, np.ndarray):
            flat = np.atleast_1d(x).ravel()
            raw = flat.tolist()
            labels = flat.astype(str).tolist()
        else:
            raw = list(x)
            labels = [str(v) for v in raw]

        # Detect missing values before they are stringified, so that
        # None / NaN never turn into the levels "None" / "nan".
        missing = [_is_missing(v) for v in raw]

        if state["levels"] is None:
            # R: calcLevels — ordered preserves insertion order,
            # unordered sorts alphabetically.
            discovered = list(dict.fromkeys(
                lab for lab, gone in zip(labels, missing) if not gone
            ))
            if not discovered:
                # Nothing to learn levels from yet; do not freeze an
                # empty level set for later calls.
                return [na_color] * len(labels)
            if not ordered:
                discovered.sort()
            state["levels"] = discovered

        color_map = _ensure_colors()
        return [
            na_color if gone else color_map.get(lab, na_color)
            for lab, gone in zip(labels, missing)
        ]

    return _map