ggplot2-python 4.0.2.9000__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. ggplot2_py/__init__.py +852 -0
  2. ggplot2_py/_compat.py +475 -0
  3. ggplot2_py/_plugins.py +129 -0
  4. ggplot2_py/_utils.py +544 -0
  5. ggplot2_py/aes.py +586 -0
  6. ggplot2_py/annotation.py +540 -0
  7. ggplot2_py/coord.py +2108 -0
  8. ggplot2_py/coords/__init__.py +49 -0
  9. ggplot2_py/datasets.py +265 -0
  10. ggplot2_py/draw_key.py +454 -0
  11. ggplot2_py/facet.py +1456 -0
  12. ggplot2_py/fortify.py +95 -0
  13. ggplot2_py/geom.py +4516 -0
  14. ggplot2_py/geoms/__init__.py +12 -0
  15. ggplot2_py/ggproto.py +279 -0
  16. ggplot2_py/guide.py +2925 -0
  17. ggplot2_py/guide_axis.py +615 -0
  18. ggplot2_py/guide_colourbar.py +657 -0
  19. ggplot2_py/guide_legend.py +1061 -0
  20. ggplot2_py/guides/__init__.py +8 -0
  21. ggplot2_py/labeller.py +296 -0
  22. ggplot2_py/labels.py +309 -0
  23. ggplot2_py/layer.py +954 -0
  24. ggplot2_py/layout.py +754 -0
  25. ggplot2_py/limits.py +314 -0
  26. ggplot2_py/plot.py +1401 -0
  27. ggplot2_py/plot_render.py +866 -0
  28. ggplot2_py/position.py +1269 -0
  29. ggplot2_py/protocols.py +171 -0
  30. ggplot2_py/py.typed +0 -0
  31. ggplot2_py/qplot.py +233 -0
  32. ggplot2_py/resources/diamonds.csv +53941 -0
  33. ggplot2_py/resources/economics.csv +575 -0
  34. ggplot2_py/resources/economics_long.csv +2871 -0
  35. ggplot2_py/resources/faithfuld.csv +5626 -0
  36. ggplot2_py/resources/luv_colours.csv +658 -0
  37. ggplot2_py/resources/midwest.csv +438 -0
  38. ggplot2_py/resources/mpg.csv +235 -0
  39. ggplot2_py/resources/msleep.csv +84 -0
  40. ggplot2_py/resources/presidential.csv +13 -0
  41. ggplot2_py/resources/seals.csv +1156 -0
  42. ggplot2_py/resources/txhousing.csv +8603 -0
  43. ggplot2_py/save.py +316 -0
  44. ggplot2_py/scale.py +2727 -0
  45. ggplot2_py/scales/__init__.py +4252 -0
  46. ggplot2_py/stat.py +6071 -0
  47. ggplot2_py/stats/__init__.py +9 -0
  48. ggplot2_py/theme.py +490 -0
  49. ggplot2_py/theme_defaults.py +1350 -0
  50. ggplot2_py/theme_elements.py +2052 -0
  51. ggplot2_python-4.0.2.9000.dist-info/METADATA +179 -0
  52. ggplot2_python-4.0.2.9000.dist-info/RECORD +54 -0
  53. ggplot2_python-4.0.2.9000.dist-info/WHEEL +4 -0
  54. ggplot2_python-4.0.2.9000.dist-info/licenses/LICENSE +3 -0
ggplot2_py/scale.py ADDED
@@ -0,0 +1,2727 @@
1
+ """
2
+ Base Scale classes and constructor functions for the ggplot2 scale system.
3
+
4
+ This module implements the core Scale hierarchy:
5
+
6
+ - :class:`Scale` -- abstract base
7
+ - :class:`ScaleContinuous` -- continuous data
8
+ - :class:`ScaleDiscrete` -- discrete / categorical data
9
+ - :class:`ScaleBinned` -- binned continuous data
10
+ - Position sub-classes (:class:`ScaleContinuousPosition`, etc.)
11
+ - Identity sub-classes (:class:`ScaleContinuousIdentity`, etc.)
12
+ - Date/datetime sub-classes
13
+
14
+ Also provides constructor helpers:
15
+
16
+ - :func:`continuous_scale`
17
+ - :func:`discrete_scale`
18
+ - :func:`binned_scale`
19
+
20
+ Container class :class:`ScalesList`, secondary axis support
21
+ (:class:`AxisSecondary`, :func:`sec_axis`, :func:`dup_axis`),
22
+ and expansion helpers (:func:`expansion`, :func:`expand_scale`).
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import copy
28
+ import math
29
+ import warnings
30
+ from typing import (
31
+ Any,
32
+ Callable,
33
+ Dict,
34
+ List,
35
+ Optional,
36
+ Sequence,
37
+ Tuple,
38
+ Type,
39
+ Union,
40
+ )
41
+
42
+ import numpy as np
43
+ import pandas as pd
44
+
45
+ from scales import (
46
+ ContinuousRange,
47
+ DiscreteRange,
48
+ as_transform,
49
+ breaks_extended,
50
+ censor,
51
+ discard,
52
+ expand_range,
53
+ is_transform,
54
+ oob_censor,
55
+ oob_squish,
56
+ oob_squish_infinite,
57
+ rescale,
58
+ rescale_mid,
59
+ rescale_max,
60
+ squish,
61
+ train_continuous,
62
+ train_discrete,
63
+ transform_identity,
64
+ zero_range,
65
+ )
66
+
67
+ from ggplot2_py._compat import (
68
+ Waiver,
69
+ cli_abort,
70
+ cli_inform,
71
+ cli_warn,
72
+ deprecate_warn,
73
+ is_waiver,
74
+ waiver,
75
+ )
76
+ from ggplot2_py.aes import standardise_aes_names
77
+ from ggplot2_py.ggproto import GGProto, fetch_ggproto, ggproto, ggproto_parent
78
+
79
# Public API of this module: the names exported by
# ``from ggplot2_py.scale import *``. Private helpers (leading underscore)
# are intentionally omitted.
__all__ = [
    # Base classes
    "Scale",
    "ScaleContinuous",
    "ScaleDiscrete",
    "ScaleBinned",
    # Position sub-classes
    "ScaleContinuousPosition",
    "ScaleDiscretePosition",
    "ScaleBinnedPosition",
    # Identity sub-classes
    "ScaleContinuousIdentity",
    "ScaleDiscreteIdentity",
    # Date/datetime sub-classes
    "ScaleContinuousDate",
    "ScaleContinuousDatetime",
    # Constructors
    "continuous_scale",
    "discrete_scale",
    "binned_scale",
    # Container
    "ScalesList",
    "scales_list",
    # Secondary axis
    "AxisSecondary",
    "sec_axis",
    "dup_axis",
    "derive",
    "is_derived",
    "is_sec_axis",
    # Expansion helpers
    "expansion",
    "expand_scale",
    "expand_range4",
    "default_expansion",
    # Scale detection
    "find_scale",
    "is_scale",
    # Mapped discrete sentinel
    "mapped_discrete",
    "is_mapped_discrete",
]
121
+
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # Utility helpers
125
+ # ---------------------------------------------------------------------------
126
+
127
+ _POSITION_AESTHETICS = frozenset(
128
+ [
129
+ "x",
130
+ "xmin",
131
+ "xmax",
132
+ "xend",
133
+ "xintercept",
134
+ "xmin_final",
135
+ "xmax_final",
136
+ "xlower",
137
+ "xmiddle",
138
+ "xupper",
139
+ "x0",
140
+ "y",
141
+ "ymin",
142
+ "ymax",
143
+ "yend",
144
+ "yintercept",
145
+ "ymin_final",
146
+ "ymax_final",
147
+ "ylower",
148
+ "ymiddle",
149
+ "yupper",
150
+ "y0",
151
+ ]
152
+ )
153
+
154
+ _X_AESTHETICS = [
155
+ "x", "xmin", "xmax", "xend", "xintercept",
156
+ "xmin_final", "xmax_final", "xlower", "xmiddle", "xupper", "x0",
157
+ ]
158
+
159
+ _Y_AESTHETICS = [
160
+ "y", "ymin", "ymax", "yend", "yintercept",
161
+ "ymin_final", "ymax_final", "ylower", "ymiddle", "yupper", "y0",
162
+ ]
163
+
164
+
165
def _is_position_aes(aesthetics: Union[str, Sequence[str]]) -> bool:
    """Report whether *aesthetics* is, or contains, a position aesthetic.

    A single string is tested on its own. Any other iterable is True as soon
    as one of its elements is a position aesthetic (x/y families).
    """
    candidates = [aesthetics] if isinstance(aesthetics, str) else aesthetics
    for candidate in candidates:
        if candidate in _POSITION_AESTHETICS:
            return True
    return False
170
+
171
+
172
+ def _is_discrete(x: Any) -> bool:
173
+ """Check whether *x* should be treated as discrete data."""
174
+ if isinstance(x, pd.Categorical) or isinstance(x, pd.CategoricalDtype):
175
+ return True
176
+ if isinstance(x, pd.Series):
177
+ if isinstance(x.dtype, pd.CategoricalDtype):
178
+ return True
179
+ if x.dtype == object:
180
+ return True
181
+ if pd.api.types.is_bool_dtype(x.dtype):
182
+ return True
183
+ return False
184
+ if isinstance(x, np.ndarray):
185
+ if x.dtype.kind in ("U", "S", "O", "b"):
186
+ return True
187
+ return False
188
+ if isinstance(x, (list, tuple)):
189
+ if len(x) == 0:
190
+ return False
191
+ first = x[0]
192
+ return isinstance(first, (str, bool))
193
+ if isinstance(x, (str, bool)):
194
+ return True
195
+ return False
196
+
197
+
198
+ def _empty(df: Any) -> bool:
199
+ """Check whether *df* is empty (None or zero-length)."""
200
+ if df is None:
201
+ return True
202
+ if isinstance(df, pd.DataFrame):
203
+ return len(df) == 0
204
+ if isinstance(df, dict):
205
+ return len(df) == 0
206
+ return False
207
+
208
+
209
+ def _unique0(x: Any) -> np.ndarray:
210
+ """Unique values preserving order."""
211
+ if x is None:
212
+ return np.array([])
213
+ arr = np.asarray(x)
214
+ _, idx = np.unique(arr, return_index=True)
215
+ return arr[np.sort(idx)]
216
+
217
+
218
+ def _check_breaks_labels(
219
+ breaks: Any,
220
+ labels: Any,
221
+ ) -> None:
222
+ """Validate that breaks and labels are compatible."""
223
+ if breaks is None or labels is None:
224
+ return
225
+ if isinstance(breaks, np.ndarray) and np.isscalar(breaks) and np.isnan(breaks):
226
+ cli_abort("Invalid breaks specification. Use None, not NaN.")
227
+ if (
228
+ not callable(breaks)
229
+ and not callable(labels)
230
+ and not is_waiver(breaks)
231
+ and not is_waiver(labels)
232
+ ):
233
+ breaks_arr = np.asarray(breaks) if not isinstance(breaks, (list, tuple)) else breaks
234
+ labels_arr = labels if not isinstance(labels, (list, tuple)) else labels
235
+ if hasattr(breaks_arr, "__len__") and hasattr(labels_arr, "__len__"):
236
+ if len(breaks_arr) != len(labels_arr):
237
+ cli_abort("breaks and labels must have the same length.")
238
+
239
+
240
+ def _is_finite(x: Any) -> np.ndarray:
241
+ """Element-wise finite check."""
242
+ arr = np.asarray(x, dtype=float)
243
+ return np.isfinite(arr)
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # Mapped discrete sentinel
248
+ # ---------------------------------------------------------------------------
249
+
250
+ class _MappedDiscrete(np.ndarray):
251
+ """Sentinel wrapper for discrete values that have been mapped to numeric."""
252
+
253
+ def __new__(cls, x: Any) -> "_MappedDiscrete":
254
+ arr = np.asarray(x, dtype=float).view(cls)
255
+ return arr
256
+
257
+
258
def mapped_discrete(x: Any) -> Optional[_MappedDiscrete]:
    """Tag *x* as discrete data that has already been mapped to numbers.

    ``None`` passes through unchanged. Anything else is converted to a float
    ``_MappedDiscrete`` array.
    """
    return None if x is None else _MappedDiscrete(x)
263
+
264
+
265
def is_mapped_discrete(x: Any) -> bool:
    """Return True when *x* carries the ``_MappedDiscrete`` marker type."""
    marker_type = _MappedDiscrete
    return isinstance(x, marker_type)
268
+
269
+
270
+ # ---------------------------------------------------------------------------
271
+ # Expansion helpers
272
+ # ---------------------------------------------------------------------------
273
+
274
def expansion(
    mult: Union[float, Sequence[float]] = 0,
    add: Union[float, Sequence[float]] = 0,
) -> np.ndarray:
    """Build the four-element padding vector used to expand scale limits.

    Parameters
    ----------
    mult : float or sequence of float
        Multiplicative expansion, as a fraction of the range width. One
        value applies to both ends; two values mean ``(lower, upper)``.
    add : float or sequence of float
        Additive expansion in data units, recycled the same way.

    Returns
    -------
    numpy.ndarray
        ``[mult_lower, add_lower, mult_upper, add_upper]``.
    """
    def _recycle(value: Any) -> np.ndarray:
        # Promote to 1-D float and duplicate a lone value for both ends.
        vec = np.atleast_1d(np.asarray(value, dtype=float))
        return np.repeat(vec, 2) if len(vec) == 1 else vec

    mult_vec = _recycle(mult)
    add_vec = _recycle(add)
    if len(mult_vec) != 2 or len(add_vec) != 2:
        cli_abort("mult and add must be numeric vectors with 1 or 2 elements.")
    lower_side = [mult_vec[0], add_vec[0]]
    upper_side = [mult_vec[1], add_vec[1]]
    return np.array(lower_side + upper_side)
302
+
303
+
304
def expand_scale(
    mult: Union[float, Sequence[float]] = 0,
    add: Union[float, Sequence[float]] = 0,
) -> np.ndarray:
    """Deprecated alias of :func:`expansion`.

    Emits a deprecation warning (deprecated since 3.3.0, replaced by
    ``expansion()``), then delegates to :func:`expansion`.

    Parameters
    ----------
    mult : float or sequence of float
        Multiplicative range expansion factors.
    add : float or sequence of float
        Additive range expansion constants.

    Returns
    -------
    numpy.ndarray
        Length-4 expansion vector, exactly as returned by :func:`expansion`.
    """
    deprecate_warn("3.3.0", "expand_scale()", with_="expansion()")
    return expansion(mult=mult, add=add)
324
+
325
+
326
def expand_range4(
    limits: Sequence[float],
    expand: np.ndarray,
) -> np.ndarray:
    """Pad a numeric range using a 2- or 4-element expansion vector.

    Parameters
    ----------
    limits : array-like
        Length-2 numeric range ``(low, high)``.
    expand : array-like
        Either ``[mult_lo, add_lo, mult_hi, add_hi]``, or ``[mult, add]``
        applied to both ends.

    Returns
    -------
    numpy.ndarray
        The padded range (length 2). It is ``[-inf, inf]`` when no limit
        is finite.

    Notes
    -----
    Each side is padded by ``width * mult + add``. A zero-width range uses
    a width of 1, so multiplicative expansion still has an effect. The
    arithmetic is done here rather than via ``scales.expand_range``, which
    only accepts a scalar ``mul``/``add``.
    """
    expand_vec = np.asarray(expand, dtype=float)
    lims = np.asarray(limits, dtype=float)
    if len(expand_vec) not in (2, 4):
        cli_abort("expand must be a numeric vector with 2 or 4 elements.")
    if not np.isfinite(lims).any():
        return np.array([-np.inf, np.inf])
    if len(expand_vec) == 2:
        expand_vec = np.concatenate([expand_vec, expand_vec])
    mult_lo, add_lo, mult_hi, add_hi = (expand_vec[k] for k in range(4))

    width = lims[1] - lims[0]
    if width == 0:
        width = 1.0
    new_low = lims[0] - width * mult_lo - add_lo
    new_high = lims[1] + width * mult_hi + add_hi
    return np.array([new_low, new_high])
362
+
363
+
364
def default_expansion(
    scale: Any,
    discrete: Optional[np.ndarray] = None,
    continuous: Optional[np.ndarray] = None,
    expand: Union[bool, Sequence[bool]] = True,
) -> np.ndarray:
    """Compute the default expansion for a scale.

    Parameters
    ----------
    scale : Scale
        A position scale. Its ``expand`` attribute is used unless it is a
        waiver.
    discrete : array-like, optional
        Fallback expansion for discrete scales (default
        ``expansion(add=0.6)``).
    continuous : array-like, optional
        Fallback expansion for continuous scales (default
        ``expansion(mult=0.05)``).
    expand : bool or pair of bool
        Whether to apply expansion at all. A pair ``(lower, upper)``
        enables each side independently, as in ggplot2. A single bool
        applies to both sides.

    Returns
    -------
    numpy.ndarray
        Length-4 expansion vector ``[mult_lo, add_lo, mult_hi, add_hi]``.
        A disabled side is left at zero.
    """
    if discrete is None:
        discrete = expansion(add=0.6)
    if continuous is None:
        continuous = expansion(mult=0.05)
    out = expansion()
    # One flag per side (lower, upper); a scalar is recycled to both.
    sides = np.resize(np.asarray(expand, dtype=bool), 2)
    if not sides.any():
        return out
    scale_expand = scale.expand
    if is_waiver(scale_expand):
        scale_expand = discrete if scale.is_discrete() else continuous
    # Recycle 1-, 2- or 3-element specs cyclically to the 4-element layout.
    # A 1-element spec previously produced a length-2 vector and crashed
    # on the upper-side assignment.
    scale_expand = np.resize(np.asarray(scale_expand, dtype=float), 4)
    if sides[0]:
        out[0:2] = scale_expand[0:2]
    if sides[1]:
        out[2:4] = scale_expand[2:4]
    return out
404
+
405
+
406
+ # ---------------------------------------------------------------------------
407
+ # Base Scale class
408
+ # ---------------------------------------------------------------------------
409
+
410
class Scale(GGProto):
    """Abstract base class for all ggplot2 scales.

    Scales translate data values to aesthetic values and populate
    breaks and labels. Methods whose body only calls
    ``cli_abort("Not implemented.")`` are abstract hooks that concrete
    subclasses must override.
    """

    # --- Auto-registration registry (Python-exclusive) -------------------
    # A single dict shared by the whole hierarchy. Every subclass named
    # ``Scale<Suffix>`` is stored under ``"<Suffix>"`` and its lower-cased form.
    _registry: Dict[str, Any] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register each ``Scale<Suffix>`` subclass in ``Scale._registry``."""
        super().__init_subclass__(**kwargs)
        name = cls.__name__
        if name.startswith("Scale") and len(name) > 5:
            key = name[5:]  # strip "Scale" prefix
            Scale._registry[key] = cls
            Scale._registry[key.lower()] = cls

    call: Optional[str] = None
    # NOTE: mutable class-level default shared by every subclass that does
    # not override it; assign a new list instead of mutating it in place.
    aesthetics: List[str] = []
    palette: Optional[Callable] = None
    fallback_palette: Optional[Callable] = None
    limits: Any = None
    na_value: Any = np.nan
    expand: Any = waiver()
    name: Any = waiver()
    breaks: Any = waiver()
    labels: Any = waiver()
    guide: Any = "legend"
    position: str = "left"

    # -- Transformation -------------------------------------------------------

    def transform_df(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Apply ``transform`` to the columns of *df* named in ``aesthetics``.

        Parameters
        ----------
        df : pandas.DataFrame
            Layer data.

        Returns
        -------
        dict
            Transformed columns keyed by aesthetic name. Empty when *df* is
            empty or has no matching columns.
        """
        if _empty(df):
            return {}
        aesthetics = [a for a in self.aesthetics if a in df.columns]
        if not aesthetics:
            return {}
        return {a: self.transform(df[a]) for a in aesthetics}

    def transform(self, x: Any) -> Any:
        """Transform raw data values. Must be overridden."""
        cli_abort("Not implemented.")

    # -- Training -------------------------------------------------------------

    def train_df(self, df: pd.DataFrame) -> None:
        """Train the scale on every column of *df* named in ``aesthetics``.

        Parameters
        ----------
        df : pandas.DataFrame
            Layer data. Nothing happens when it is empty.
        """
        if _empty(df):
            return
        aesthetics = [a for a in self.aesthetics if a in df.columns]
        for a in aesthetics:
            self.train(df[a])

    def train(self, x: Any) -> None:
        """Train on a vector. Must be overridden."""
        cli_abort("Not implemented.")

    # -- Mapping --------------------------------------------------------------

    def map_df(self, df: pd.DataFrame, i: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Map the columns of *df* named in ``aesthetics`` to aesthetic values.

        If ``palette`` is unset and ``fallback_palette`` is available, the
        fallback is installed as the scale's palette first. This mutates
        the scale.

        Parameters
        ----------
        df : pandas.DataFrame
            Layer data.
        i : array-like, optional
            Row index subset applied to each column before mapping.

        Returns
        -------
        dict
            Mapped columns keyed by aesthetic name.
        """
        if _empty(df):
            return {}
        if self.palette is None:
            pal = getattr(self, "fallback_palette", None)
            if pal is not None:
                self.palette = pal
        aesthetics = [a for a in self.aesthetics if a in df.columns]
        if not aesthetics:
            return {}
        result = {}
        for a in aesthetics:
            col = df[a].values if i is None else df[a].values[i]
            result[a] = self.map(col)
        return result

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Map data values to aesthetic values. Must be overridden."""
        cli_abort("Not implemented.")

    def rescale(
        self,
        x: Any,
        limits: Optional[Any] = None,
        range: Optional[Any] = None,
    ) -> Any:
        """Rescale to 0-1 range. Must be overridden."""
        cli_abort("Not implemented.")

    # -- Getters --------------------------------------------------------------

    def get_limits(self) -> Any:
        """Return the current scale limits (without expansion).

        Returns
        -------
        array-like
            ``[0, 1]`` for an empty scale. Otherwise the trained range when
            no limits are set, the result of calling ``limits`` on the
            trained range when it is callable, or the explicit limits.
        """
        if self.is_empty():
            return np.array([0.0, 1.0])
        if self.limits is None:
            return self.range.range
        if callable(self.limits):
            return self.limits(self.range.range)
        return self.limits

    def dimension(
        self,
        expand: Optional[np.ndarray] = None,
        limits: Optional[Any] = None,
    ) -> Any:
        """Return continuous dimension of the scale. Must be overridden."""
        cli_abort("Not implemented.")

    def get_breaks(self, limits: Optional[Any] = None) -> Any:
        """Resolve and return scale breaks. Must be overridden."""
        cli_abort("Not implemented.")

    def get_breaks_minor(
        self,
        n: int = 2,
        b: Optional[Any] = None,
        limits: Optional[Any] = None,
    ) -> Any:
        """Resolve and return minor breaks. Must be overridden."""
        cli_abort("Not implemented.")

    def get_labels(self, breaks: Optional[Any] = None) -> Any:
        """Resolve and return labels for the given breaks. Must be overridden."""
        cli_abort("Not implemented.")

    def get_transformation(self) -> Any:
        """Return the scale's transformation object.

        Returns the ``trans`` attribute when it is set. When the attribute
        is missing *or* ``None``, a new identity transformation is
        returned, so callers never receive ``None``.

        Returns
        -------
        Transform
            A scales-package transform object.
        """
        trans = getattr(self, "trans", None)
        if trans is None:
            # Built lazily: previously the identity transform was constructed
            # on every call as a ``getattr`` default, and an explicit
            # ``trans = None`` leaked through as ``None``.
            return transform_identity()
        return trans

    def break_positions(self, range: Optional[Any] = None) -> Any:
        """Return mapped break positions.

        Parameters
        ----------
        range : array-like, optional
            Scale limits; defaults to ``get_limits()``.

        Returns
        -------
        array-like
            ``map`` applied to ``get_breaks(range)``.
        """
        if range is None:
            range = self.get_limits()
        return self.map(self.get_breaks(range))

    def break_info(self, range: Optional[Any] = None) -> Dict[str, Any]:
        """Return all break-related information. Must be overridden."""
        cli_abort("Not implemented.")

    # -- Titles ---------------------------------------------------------------

    def make_title(
        self,
        guide_title: Any = None,
        scale_title: Any = None,
        label_title: Any = None,
    ) -> Any:
        """Resolve scale title from guide, scale, and label titles.

        Precedence, lowest to highest: label title, scale title, guide
        title. A callable title receives the title resolved so far. A
        waiver defers to the lower-precedence title.

        NOTE(review): ``None`` arguments are converted to waivers, so
        ``None`` cannot be used here to *remove* a title (ggplot2 uses
        ``NULL`` for that). Confirm this is intended for callers.

        Parameters
        ----------
        guide_title : str or Waiver, optional
            Title from the guide.
        scale_title : str or Waiver, optional
            Title from the scale ``name`` field.
        label_title : str or Waiver, optional
            Title from ``labs()``.

        Returns
        -------
        str or None
            Resolved title.
        """
        if guide_title is None:
            guide_title = waiver()
        if scale_title is None:
            scale_title = waiver()
        if label_title is None:
            label_title = waiver()
        title = label_title
        if callable(scale_title) and not is_waiver(scale_title):
            title = scale_title(title)
        elif not is_waiver(scale_title):
            title = scale_title
        if callable(guide_title) and not is_waiver(guide_title):
            title = guide_title(title)
        elif not is_waiver(guide_title):
            title = guide_title
        return title

    def make_sec_title(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve secondary axis title (delegates to ``make_title``)."""
        return self.make_title(*args, **kwargs)

    # -- Axis order -----------------------------------------------------------

    def axis_order(self) -> List[str]:
        """Return ``['primary', 'secondary']``, reversed for right/bottom positions."""
        order = ["primary", "secondary"]
        if self.position in ("right", "bottom"):
            order = list(reversed(order))
        return order

    # -- Utilities ------------------------------------------------------------

    def clone(self) -> "Scale":
        """Create an untrained copy of this scale. Must be overridden.

        Returns
        -------
        Scale
            A new Scale with a fresh ``range``.
        """
        cli_abort("Not implemented.")

    def reset(self) -> None:
        """Reset the scale's range, un-training it."""
        self.range.reset()

    def is_empty(self) -> bool:
        """Whether the scale contains no information for limits.

        Returns
        -------
        bool
            True when the range is untrained and no limits are set.
        """
        return self.range.range is None and self.limits is None

    def is_discrete(self) -> bool:
        """Whether the scale is discrete. Must be overridden.

        Returns
        -------
        bool
        """
        cli_abort("Not implemented.")
693
+
694
+
695
+ # ---------------------------------------------------------------------------
696
+ # ScaleContinuous
697
+ # ---------------------------------------------------------------------------
698
+
699
+ def _default_transform(self: Any, x: Any) -> Any:
700
+ """Apply the scale's transformation to data values."""
701
+ transformation = self.get_transformation()
702
+ x_arr = np.asarray(x, dtype=float)
703
+ new_x = transformation.transform(x_arr)
704
+ new_x = np.asarray(new_x, dtype=float)
705
+ # Check for introduced infinities
706
+ finite_orig = np.isfinite(x_arr)
707
+ finite_new = np.isfinite(new_x)
708
+ if np.any(finite_orig & ~finite_new):
709
+ cli_warn(
710
+ f"{transformation.name} transformation introduced infinite values."
711
+ )
712
+ return new_x
713
+
714
+
715
class ScaleContinuous(Scale):
    """Scale for continuous data.

    Attributes
    ----------
    trans : Transform
        Transformation object from the ``scales`` package. ``None`` on this
        class itself; subclasses that do not set their own receive an
        identity transformation in ``__init_subclass__``.
    rescaler : callable
        Function to rescale values (default ``rescale``).
    oob : callable
        Out-of-bounds handler (default ``censor``).
    minor_breaks : any
        Minor break specification: ``waiver()`` derives them from the
        transformation, ``None`` disables them, a callable or explicit
        values are used as given.
    n_breaks : int or None
        Desired number of major breaks.
    """

    # Value written by ``map`` for missing / unmappable entries.
    na_value: Any = np.nan
    # Wrapped in ``staticmethod`` so they are called as plain functions and
    # not bound to the scale instance.
    rescaler: Callable = staticmethod(rescale)
    oob: Callable = staticmethod(censor)
    minor_breaks: Any = waiver()
    n_breaks: Optional[int] = None
    trans: Any = None
738
+
739
+ def __init_subclass__(cls, **kwargs: Any) -> None:
740
+ super().__init_subclass__(**kwargs)
741
+ if cls.trans is None:
742
+ cls.trans = transform_identity()
743
+
744
+ def is_discrete(self) -> bool:
745
+ return False
746
+
747
+ def train(self, x: Any) -> None:
748
+ """Train the continuous range on *x*.
749
+
750
+ Parameters
751
+ ----------
752
+ x : array-like
753
+ Numeric data values.
754
+ """
755
+ x_arr = np.asarray(x, dtype=float)
756
+ if len(x_arr) == 0:
757
+ return
758
+ self.range.train(x_arr)
759
+
760
+ def is_empty(self) -> bool:
761
+ has_data = self.range.range is not None
762
+ has_limits = callable(self.limits) or (
763
+ self.limits is not None
764
+ and np.all(np.isfinite(np.asarray(self.limits, dtype=float)))
765
+ )
766
+ return not has_data and not has_limits
767
+
768
+ def transform(self, x: Any) -> Any:
769
+ """Transform data values using the scale's transformation.
770
+
771
+ Parameters
772
+ ----------
773
+ x : array-like
774
+ Raw data values.
775
+
776
+ Returns
777
+ -------
778
+ numpy.ndarray
779
+ Transformed values.
780
+ """
781
+ return _default_transform(self, x)
782
+
783
+ def map(self, x: Any, limits: Optional[Any] = None) -> Any:
784
+ """Map data values to aesthetic values via palette.
785
+
786
+ Parameters
787
+ ----------
788
+ x : array-like
789
+ Values in transformed space.
790
+ limits : array-like, optional
791
+ Scale limits; defaults to ``get_limits()``.
792
+
793
+ Returns
794
+ -------
795
+ numpy.ndarray
796
+ Mapped aesthetic values.
797
+ """
798
+ if limits is None:
799
+ limits = self.get_limits()
800
+ x_arr = np.asarray(x, dtype=float)
801
+ x_oob = self.oob(x_arr, range=limits)
802
+ x_rescaled = self.rescale(x_oob, limits)
803
+
804
+ uniq = _unique0(x_rescaled)
805
+ if len(uniq) == 0:
806
+ return np.full_like(x_arr, self.na_value)
807
+ pal = np.asarray(self.palette(uniq))
808
+
809
+ # Determine output dtype: colour strings need object dtype
810
+ out_dtype = pal.dtype if pal.dtype.kind in ("U", "S", "O") else float
811
+ scaled = np.full(len(x_arr), self.na_value, dtype=out_dtype)
812
+
813
+ for i, u in enumerate(uniq):
814
+ mask = x_rescaled == u
815
+ if np.any(mask):
816
+ scaled[mask] = pal[i] if i < len(pal) else self.na_value
817
+
818
+ # Fill NaN from x_rescaled
819
+ nan_mask = np.isnan(x_rescaled)
820
+ scaled[nan_mask] = self.na_value
821
+ return scaled
822
+
823
+ def rescale(
824
+ self,
825
+ x: Any,
826
+ limits: Optional[Any] = None,
827
+ range: Optional[Any] = None,
828
+ ) -> np.ndarray:
829
+ """Rescale *x* to [0, 1].
830
+
831
+ Parameters
832
+ ----------
833
+ x : array-like
834
+ Values to rescale.
835
+ limits : array-like, optional
836
+ Scale limits.
837
+ range : array-like, optional
838
+ Range to rescale from. Defaults to *limits*.
839
+
840
+ Returns
841
+ -------
842
+ numpy.ndarray
843
+ Rescaled values.
844
+ """
845
+ if limits is None:
846
+ limits = self.get_limits()
847
+ if range is None:
848
+ range = limits
849
+ return self.rescaler(x, from_range=range)
850
+
851
+ def get_limits(self) -> np.ndarray:
852
+ if self.is_empty():
853
+ return np.array([0.0, 1.0])
854
+ if self.limits is None:
855
+ return np.asarray(self.range.range)
856
+ if callable(self.limits):
857
+ transformation = self.get_transformation()
858
+ inv = transformation.inverse(np.asarray(self.range.range))
859
+ user_limits = self.limits(inv)
860
+ return np.asarray(transformation.transform(user_limits))
861
+ limits = np.asarray(self.limits, dtype=float)
862
+ r = np.asarray(self.range.range, dtype=float) if self.range.range is not None else limits
863
+ return np.where(np.isnan(limits), r, limits)
864
+
865
+ def dimension(
866
+ self,
867
+ expand: Optional[np.ndarray] = None,
868
+ limits: Optional[Any] = None,
869
+ ) -> np.ndarray:
870
+ """Return the (optionally expanded) continuous range.
871
+
872
+ Mirrors R ``Scale$dimension`` (scale-.R:713):
873
+
874
+ dimension = function(self, expand = expansion(0, 0), limits = ...)
875
+
876
+ R's default is **no expansion** (``expansion(0, 0)``). The
877
+ caller applies expansion explicitly when needed (e.g. at
878
+ panel-param setup). Python used to default to
879
+ ``expansion(0.05, 0)`` which inflated values that downstream
880
+ consumers — notably ``hex_binwidth = diff(x$dimension()) /
881
+ bins`` — expected to be the raw data extent.
882
+
883
+ Parameters
884
+ ----------
885
+ expand : array-like, optional
886
+ Expansion vector (defaults to none).
887
+ limits : array-like, optional
888
+ Scale limits.
889
+
890
+ Returns
891
+ -------
892
+ numpy.ndarray
893
+ Length-2 range.
894
+ """
895
+ if expand is None:
896
+ expand = expansion(0, 0)
897
+ if limits is None:
898
+ limits = self.get_limits()
899
+ return expand_range4(limits, expand)
900
+
901
+ def get_breaks(self, limits: Optional[Any] = None) -> Optional[np.ndarray]:
902
+ """Resolve and return break positions.
903
+
904
+ Parameters
905
+ ----------
906
+ limits : array-like, optional
907
+ Scale limits.
908
+
909
+ Returns
910
+ -------
911
+ numpy.ndarray or None
912
+ """
913
+ if self.is_empty():
914
+ return np.array([])
915
+ if limits is None:
916
+ limits = self.get_limits()
917
+ limits = np.asarray(limits, dtype=float)
918
+ transformation = self.get_transformation()
919
+
920
+ breaks = self.breaks
921
+ if is_waiver(breaks):
922
+ breaks = transformation.breaks_func
923
+
924
+ if breaks is None:
925
+ return None
926
+
927
+ if zero_range(limits.astype(float)):
928
+ return np.array([limits[0]])
929
+
930
+ if callable(breaks):
931
+ inv_limits = transformation.inverse(limits)
932
+ n_brk = getattr(self, "n_breaks", None)
933
+ if n_brk is not None:
934
+ try:
935
+ result = breaks(inv_limits, n=n_brk)
936
+ except TypeError:
937
+ result = breaks(inv_limits)
938
+ else:
939
+ result = breaks(inv_limits)
940
+ breaks_val = np.asarray(result, dtype=float)
941
+ else:
942
+ breaks_val = np.asarray(breaks, dtype=float)
943
+
944
+ return np.asarray(transformation.transform(breaks_val), dtype=float)
945
+
946
def get_breaks_minor(
    self,
    n: int = 2,
    b: Optional[Any] = None,
    limits: Optional[Any] = None,
) -> Optional[np.ndarray]:
    """Resolve minor breaks.

    Parameters
    ----------
    n : int
        Number of minor breaks between major breaks; forwarded to the
        transformation's ``minor_breaks_func`` when ``minor_breaks`` is a
        waiver.
    b : array-like, optional
        Major break positions (transformed space). Only needed for the
        waiver case; defaults to ``self.break_positions()``.
    limits : array-like, optional
        Scale limits (transformed space); defaults to ``self.get_limits()``.

    Returns
    -------
    numpy.ndarray or None
        Float array of minor break positions in transformed space, or
        ``None`` when minor breaks are disabled, the limits have zero range,
        no major breaks are available, or the transformation provides no
        ``minor_breaks_func``.
    """
    minor = self.minor_breaks
    # Minor breaks explicitly disabled: skip all further work.
    if minor is None:
        return None

    if limits is None:
        limits = self.get_limits()
    limits = np.asarray(limits, dtype=float)
    if zero_range(limits):
        return None

    transformation = self.get_transformation()

    if is_waiver(minor):
        # Major breaks are only required to derive default minor breaks.
        if b is None:
            b = self.break_positions()
        if b is None:
            return None
        minor_func = getattr(transformation, "minor_breaks_func", None)
        if not callable(minor_func):
            return None
        b_finite = np.asarray(b, dtype=float)
        b_finite = b_finite[np.isfinite(b_finite)]
        # dtype=float keeps this branch consistent with the others so that
        # callers can apply np.isnan to the result.
        return np.asarray(minor_func(b_finite, limits, n), dtype=float)

    if callable(minor):
        # User functions receive the limits in data (untransformed) space.
        result = minor(transformation.inverse(limits))
        return np.asarray(transformation.transform(result), dtype=float)

    # Explicit positions are given in data space; map them to transformed
    # space (left unconverted here so e.g. date transforms see raw values).
    return np.asarray(transformation.transform(minor), dtype=float)
996
+
997
def get_labels(self, breaks: Optional[Any] = None) -> Optional[Any]:
    """Produce the text labels displayed at each break.

    Parameters
    ----------
    breaks : array-like, optional
        Break positions in transformed space; ``self.get_breaks()`` is used
        when omitted.

    Returns
    -------
    list or None
        ``None`` when there are no breaks or labels are disabled. Otherwise
        a list produced by the transformation's ``format_func`` (waiver),
        by the user's label function (called on data-space breaks), or
        taken from the user-supplied label sequence.
    """
    if breaks is None:
        breaks = self.get_breaks()
        if breaks is None:
            return None

    trans = self.get_transformation()
    # Labels describe data values, so undo the transformation first.
    data_breaks = trans.inverse(np.asarray(breaks, dtype=float))

    spec = self.labels
    if spec is None:
        return None
    if is_waiver(spec):
        formatted = trans.format_func(data_breaks)
    elif callable(spec):
        formatted = spec(data_breaks)
    else:
        formatted = spec
    return list(formatted)
1025
+
1026
def clone(self) -> "ScaleContinuous":
    """Return a shallow copy of this scale with a fresh, untrained range."""
    duplicate = copy.copy(self)
    # The trained range must not be shared between original and copy.
    duplicate.range = ContinuousRange()
    return duplicate
1030
+
1031
def break_info(self, range: Optional[Any] = None) -> Dict[str, Any]:
    """Compute all break info for position scales.

    Parameters
    ----------
    range : array-like, optional
        The continuous range (transformed space) to compute breaks for;
        defaults to ``self.dimension()``.

    Returns
    -------
    dict
        ``range``; ``labels``; ``major`` / ``minor`` (break positions
        rescaled to [0, 1] over *range*); ``major_source`` /
        ``minor_source`` (positions in transformed space). Major breaks
        outside *range* or missing (NaN) are dropped together with their
        labels; NaN minor breaks are dropped.
    """
    if range is None:
        range = self.dimension()
    range = np.asarray(range, dtype=float)

    major = self.get_breaks(range)
    labels = self.get_labels(major)
    minor = self.get_breaks_minor(b=major, limits=range)
    if minor is not None:
        minor = np.asarray(minor, dtype=float)
        minor = minor[~np.isnan(minor)]

    major_arr = None
    if major is not None:
        major_arr = np.asarray(major, dtype=float)
        # Expressed as an "inside" test so NaN breaks (for which every
        # comparison is False) are dropped too, like ggplot2's
        # `major[!is.na(major)]` after censoring.
        keep = (major_arr >= range[0]) & (major_arr <= range[1])
        if labels is not None:
            labels = [lab for lab, k in zip(labels, keep) if k]
        major_arr = major_arr[keep]

    major_n = None if major_arr is None else rescale(major_arr, from_range=range)
    minor_n = None if minor is None else rescale(minor, from_range=range)

    return {
        "range": range,
        "labels": labels,
        "major": major_n,
        "minor": minor_n,
        "major_source": major_arr,
        "minor_source": minor,
    }
1074
+
1075
+
1076
+ # ---------------------------------------------------------------------------
1077
+ # ScaleDiscrete
1078
+ # ---------------------------------------------------------------------------
1079
+
1080
class ScaleDiscrete(Scale):
    """Scale for discrete / categorical data.

    Attributes
    ----------
    drop : bool
        Whether to drop unused factor levels.
    na_value : any
        Value substituted for data that cannot be mapped to the palette.
    na_translate : bool
        Whether to include NA in the scale. When ``False``, unmapped values
        become ``np.nan`` instead of ``na_value``.
    n_breaks_cache : int or None
        Number of limits for which ``palette_cache`` was computed.
    palette_cache : any
        Palette output reused by :meth:`map` while the number of limits is
        unchanged.
    """

    drop: bool = True
    na_value: Any = np.nan
    na_translate: bool = True
    n_breaks_cache: Optional[int] = None
    palette_cache: Optional[Any] = None

    def is_discrete(self) -> bool:
        """Return ``True``: this scale handles discrete data."""
        return True

    def train(self, x: Any) -> None:
        """Train the discrete range on *x*.

        Parameters
        ----------
        x : array-like
            Discrete data values. Empty input is ignored.
        """
        if isinstance(x, pd.Series):
            x_arr = x.values
        else:
            x_arr = np.asarray(x)
        if len(x_arr) == 0:
            return
        self.range.train(x_arr, drop=self.drop)

    def transform(self, x: Any) -> Any:
        """Identity transform for discrete scales."""
        return x

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Map discrete values to palette values.

        Values are compared with the limits by their ``str()`` form. A list
        or array palette is matched by position (i-th limit -> i-th palette
        entry); a dict palette is matched by key, restricted to keys that
        appear in the limits (the behaviour ggplot2 uses for named
        palettes).

        Parameters
        ----------
        x : array-like
            Discrete data values.
        limits : array-like, optional
            Scale limits; defaults to ``self.get_limits()``.

        Returns
        -------
        numpy.ndarray
            One palette value per element of *x*; unmatched values become
            ``na_value`` (or ``np.nan`` when ``na_translate`` is False).
        """
        x_arr = np.asarray(x)
        if limits is None:
            limits = self.get_limits()
        if limits is None or len(limits) == 0:
            return np.full(len(x_arr), self.na_value)

        # Missing values are not valid limits.
        limits = [
            lim for lim in limits
            if lim is not None and not (isinstance(lim, float) and np.isnan(lim))
        ]
        n = len(limits)
        if n < 1:
            return np.full(len(x_arr), self.na_value)

        if self.n_breaks_cache is not None and self.n_breaks_cache == n:
            pal = self.palette_cache
        else:
            pal = self.palette(n)
            self.palette_cache = pal
            self.n_breaks_cache = n

        limits_str = [str(lim) for lim in limits]

        # Build a str(value) -> palette value table once so each data value
        # is an O(1) lookup instead of a linear list.index() scan.
        if isinstance(pal, dict):
            allowed = set(limits_str)
            lookup = {str(k): v for k, v in pal.items() if str(k) in allowed}
        else:
            pal_list = list(pal) if hasattr(pal, "__iter__") else [pal]
            lookup = {}
            for idx, key in enumerate(limits_str):
                # First occurrence wins; limits beyond the palette stay
                # unmapped.
                if key not in lookup and idx < len(pal_list):
                    lookup[key] = pal_list[idx]

        na_val = self.na_value if self.na_translate else np.nan
        return np.array([lookup.get(str(v), na_val) for v in x_arr])

    def rescale(
        self,
        x: Any,
        limits: Optional[Any] = None,
        range: Optional[Any] = None,
    ) -> np.ndarray:
        """Rescale discrete values via their 1-based position in *limits*.

        Values not found in *limits* become NaN. *range* defaults to
        ``(1, len(limits))``.
        """
        if limits is None:
            limits = self.get_limits()
        if range is None:
            range = (1, len(limits))
        positions: Dict[str, int] = {}
        for pos, lim in enumerate(limits, start=1):
            positions.setdefault(str(lim), pos)
        matched = np.array([positions.get(str(v), np.nan) for v in np.asarray(x)])
        return rescale(matched, from_range=range)

    def dimension(
        self,
        expand: Optional[np.ndarray] = None,
        limits: Optional[Any] = None,
    ) -> np.ndarray:
        """Return the expanded continuous extent ``[1, n]`` of the limits.

        Falls back to ``[0, 1]`` when there are no limits.
        """
        if expand is None:
            # R default for discrete position scales: expansion(add = 0.6)
            expand = expansion(0, 0.6)
        if limits is None:
            limits = self.get_limits()
        n = len(limits) if limits is not None else 0
        if n == 0:
            return np.array([0.0, 1.0])
        return expand_range4(np.array([1.0, float(n)]), expand)

    def get_breaks(self, limits: Optional[Any] = None) -> Optional[Any]:
        """Resolve major breaks.

        A waiver uses the limits themselves; a callable is applied to the
        limits; explicit breaks are filtered to those present in the limits.
        """
        if self.is_empty():
            return np.array([])
        if limits is None:
            limits = self.get_limits()
        breaks = self.breaks
        if breaks is None:
            return None
        if is_waiver(breaks):
            return limits
        if callable(breaks):
            return breaks(limits)
        # Filter breaks to those in limits
        if limits is not None:
            limits_str = set(str(lim) for lim in limits)
            return [b for b in breaks if str(b) in limits_str]
        return breaks

    def get_breaks_minor(
        self,
        n: int = 2,
        b: Optional[Any] = None,
        limits: Optional[Any] = None,
    ) -> Optional[Any]:
        """Resolve minor breaks; a waiver or ``None`` yields no minor breaks."""
        minor = self.minor_breaks if hasattr(self, "minor_breaks") else waiver()
        if is_waiver(minor) or minor is None:
            return None
        if callable(minor):
            if limits is None:
                limits = self.get_limits()
            return minor(limits)
        return minor

    def get_labels(self, breaks: Optional[Any] = None) -> Optional[Any]:
        """Resolve labels; a waiver labels each break with ``str(break)``."""
        if self.is_empty():
            return []
        if breaks is None:
            breaks = self.get_breaks()
        if breaks is None:
            return None
        labels = self.labels
        if labels is None:
            return None
        if is_waiver(labels):
            return [str(b) for b in breaks]
        if callable(labels):
            return list(labels(breaks))
        return list(labels)

    def clone(self) -> "ScaleDiscrete":
        """Return a shallow copy with a fresh, untrained discrete range."""
        new = copy.copy(self)
        new.range = DiscreteRange()
        return new

    def break_info(self, range: Optional[Any] = None) -> Dict[str, Any]:
        """Compute break info.

        Breaks are derived from the limits (not from *range*); *range* is
        only used to rescale the mapped break positions.
        """
        limits = self.get_limits()
        major = self.get_breaks(limits)
        if major is None:
            return {
                "range": range,
                "labels": None,
                "major": None,
                "minor": None,
                "major_source": None,
                "minor_source": None,
            }
        labels = self.get_labels(major)
        major_mapped = self.map(major)
        major_mapped = major_mapped[~np.isnan(major_mapped.astype(float))]
        major_n = rescale(major_mapped, from_range=range) if range is not None else None
        return {
            "range": range,
            "labels": labels,
            "major": major_n,
            "minor": None,
            "major_source": major_mapped,
            "minor_source": None,
        }
1284
+
1285
+
1286
+ # ---------------------------------------------------------------------------
1287
+ # ScaleBinned
1288
+ # ---------------------------------------------------------------------------
1289
+
1290
class ScaleBinned(Scale):
    """Scale for binned continuous data.

    Attributes
    ----------
    n_breaks : int or None
        Desired number of bins (``None`` / 0 fall back to 5).
    nice_breaks : bool
        Whether to use nicely-spaced breaks.
    right : bool
        Whether bins are closed on the right, i.e. ``(a, b]``.
    after_stat : bool
        When set, :meth:`map` returns its input unchanged.
    show_limits : bool
        Whether to show scale limits as ticks.
    palette_cache : any
        Palette output reused by :meth:`map` once computed.
    """

    na_value: Any = np.nan
    rescaler: Callable = staticmethod(rescale)
    oob: Callable = staticmethod(squish)
    n_breaks: Optional[int] = None
    nice_breaks: bool = True
    right: bool = True
    after_stat: bool = False
    show_limits: bool = False
    trans: Any = None
    palette_cache: Optional[Any] = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Subclasses that leave `trans` unset default to the identity.
        if cls.trans is None:
            cls.trans = transform_identity()

    def is_discrete(self) -> bool:
        """Return ``False``: binned scales take continuous input."""
        return False

    def train(self, x: Any) -> None:
        """Train the continuous range on *x*.

        Input that cannot be converted to floats is reported through
        ``cli_abort``; empty input is ignored.
        """
        # The original converted with dtype=float before checking the dtype,
        # so the friendly error was unreachable and numpy's ValueError
        # surfaced instead.
        try:
            x_arr = np.asarray(x, dtype=float)
        except (TypeError, ValueError):
            x_arr = None
        if x_arr is None:
            cli_abort("Binned scales only support continuous data.")
            return
        if len(x_arr) == 0:
            return
        self.range.train(x_arr)

    def transform(self, x: Any) -> Any:
        """Apply the scale's transformation to *x*."""
        return _default_transform(self, x)

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Map continuous values to the palette value of their bin.

        Parameters
        ----------
        x : array-like
            Values in transformed space.
        limits : array-like, optional
            Scale limits; defaults to ``self.get_limits()``.

        Returns
        -------
        numpy.ndarray
            One palette value per element of *x*. Missing input (NaN after
            ``oob``) and missing palette entries become ``na_value``. When
            ``after_stat`` is set, *x* is returned unchanged.
        """
        if limits is None:
            limits = self.get_limits()
        limits = np.asarray(limits, dtype=float)

        if self.after_stat:
            return x

        breaks = self.get_breaks(limits)
        if breaks is None:
            breaks = np.array([])
        all_breaks = np.unique(
            np.sort(np.concatenate([limits[:1], np.asarray(breaks), limits[1:]]))
        )

        x_arr = np.asarray(self.oob(np.asarray(x, dtype=float), range=limits), dtype=float)
        # Keep missing values as NaN while binning and substitute na_value
        # only at the end. np.digitize would otherwise put NaN past the last
        # edge (into the last bin), and a non-numeric na_value such as a
        # colour cannot be rescaled.
        missing = np.isnan(x_arr)

        breaks_resc = self.rescale(all_breaks, limits)
        if len(breaks_resc) <= 1:
            fill_dtype = float if isinstance(self.na_value, (int, float, np.number)) else object
            return np.full(x_arr.shape, self.na_value, dtype=fill_dtype)

        # right=True closes bins on the right, (a, b], matching self.right.
        # Values on the outer edges are clipped into the first / last bin.
        bins = np.digitize(self.rescale(x_arr, limits), breaks_resc, right=self.right)
        bins = np.clip(bins, 1, len(breaks_resc) - 1)
        midpoints = breaks_resc[:-1] + np.diff(breaks_resc) / 2.0

        if self.palette_cache is not None:
            pal = self.palette_cache
        else:
            pal = self.palette(midpoints)
            self.palette_cache = pal

        pal_arr = np.asarray(pal)
        scaled = pal_arr[np.clip(bins - 1, 0, len(pal_arr) - 1)]

        if scaled.dtype.kind in ("U", "S", "O"):
            # np.isnan doesn't work on string/object arrays (e.g. colours).
            na_mask = missing | np.array(
                [v is None or (isinstance(v, float) and np.isnan(v)) for v in scaled],
                dtype=bool,
            )
            if na_mask.any():
                # Object dtype so na_value is neither truncated to nor
                # coerced into the palette's fixed-width string type.
                scaled = scaled.astype(object)
                scaled[na_mask] = self.na_value
            return scaled
        return np.where(missing | np.isnan(scaled), self.na_value, scaled)

    def rescale(
        self,
        x: Any,
        limits: Optional[Any] = None,
        range: Optional[Any] = None,
    ) -> np.ndarray:
        """Rescale *x* with ``self.rescaler`` from *range* (default: limits)."""
        if limits is None:
            limits = self.get_limits()
        if range is None:
            range = limits
        return self.rescaler(x, from_range=range)

    def dimension(
        self,
        expand: Optional[np.ndarray] = None,
        limits: Optional[Any] = None,
    ) -> np.ndarray:
        """Return the limits expanded by *expand* (default: no expansion)."""
        if expand is None:
            expand = np.array([0.0, 0.0, 0.0, 0.0])
        if limits is None:
            limits = self.get_limits()
        return expand_range4(np.asarray(limits), expand)

    def get_limits(self) -> np.ndarray:
        """Resolve limits in transformed space.

        Empty scales give ``[0, 1]``. A callable ``limits`` receives the
        trained range in data space. NaN entries in explicit limits are
        filled from the trained range.
        """
        # Delegate to continuous logic
        if self.is_empty():
            return np.array([0.0, 1.0])
        if self.limits is None:
            return np.asarray(self.range.range)
        if callable(self.limits):
            transformation = self.get_transformation()
            inv = transformation.inverse(np.asarray(self.range.range))
            user_limits = self.limits(inv)
            return np.asarray(transformation.transform(user_limits))
        limits = np.asarray(self.limits, dtype=float)
        r = np.asarray(self.range.range, dtype=float) if self.range.range is not None else limits
        return np.where(np.isnan(limits), r, limits)

    def get_breaks(self, limits: Optional[Any] = None) -> Optional[np.ndarray]:
        """Resolve major breaks in transformed space.

        With a waiver, breaks come from the transformation's ``breaks_func``
        (``nice_breaks``) or from evenly spaced interior points, and are then
        restricted to the limits.
        """
        if self.is_empty():
            return np.array([])
        if limits is None:
            limits = self.get_limits()

        transformation = self.get_transformation()
        inv_limits = transformation.inverse(np.asarray(limits, dtype=float))
        inv_limits_sorted = np.sort(inv_limits)

        breaks = self.breaks
        if breaks is None:
            return None
        if is_waiver(breaks):
            if self.nice_breaks:
                n = self.n_breaks or 5
                try:
                    result = transformation.breaks_func(inv_limits_sorted, n=n)
                except TypeError:
                    # breaks_func may not accept an `n` keyword.
                    result = transformation.breaks_func(inv_limits_sorted)
            else:
                n = self.n_breaks or 5
                result = np.linspace(inv_limits_sorted[0], inv_limits_sorted[1], n + 2)[1:-1]
            breaks_val = np.asarray(result, dtype=float)
            # Discard out of range
            breaks_val = breaks_val[(breaks_val >= inv_limits_sorted[0]) & (breaks_val <= inv_limits_sorted[1])]
        elif callable(breaks):
            n = self.n_breaks or 5
            try:
                breaks_val = np.asarray(breaks(inv_limits_sorted, n=n), dtype=float)
            except TypeError:
                breaks_val = np.asarray(breaks(inv_limits_sorted), dtype=float)
        else:
            breaks_val = np.asarray(breaks, dtype=float)

        return np.asarray(transformation.transform(breaks_val), dtype=float)

    def get_breaks_minor(self, *args: Any, **kwargs: Any) -> None:
        """Binned scales have no minor breaks.

        Accepts (and ignores) the same positional/keyword arguments as the
        continuous version so callers can use either calling style.
        """
        return None

    def get_labels(self, breaks: Optional[Any] = None) -> Optional[Any]:
        """Resolve labels for *breaks* (transformed space)."""
        if breaks is None:
            breaks = self.get_breaks()
        if breaks is None:
            return None
        transformation = self.get_transformation()
        breaks_data = transformation.inverse(np.asarray(breaks, dtype=float))
        labels = self.labels
        if labels is None:
            return None
        if is_waiver(labels):
            return list(transformation.format_func(breaks_data))
        if callable(labels):
            return list(labels(breaks_data))
        return list(labels)

    def clone(self) -> "ScaleBinned":
        """Return a shallow copy with a fresh, untrained continuous range."""
        new = copy.copy(self)
        new.range = ContinuousRange()
        return new

    def break_info(self, range: Optional[Any] = None) -> Dict[str, Any]:
        """Compute break info; major breaks are reported unrescaled."""
        if range is None:
            range = self.dimension()
        range = np.asarray(range, dtype=float)
        major = self.get_breaks(range)
        labels = self.get_labels(major)
        return {
            "range": range,
            "labels": labels,
            "major": major,
            "minor": None,
            "major_source": major,
            "minor_source": None,
        }
1492
+
1493
+
1494
+ # ---------------------------------------------------------------------------
1495
+ # Position sub-classes
1496
+ # ---------------------------------------------------------------------------
1497
+
1498
class ScaleContinuousPosition(ScaleContinuous):
    """Continuous scale for position aesthetics (x/y).

    Attributes
    ----------
    secondary_axis : any
        A waiver (no secondary axis) or an object providing ``empty()``,
        ``init()``, ``break_info()``, ``name`` and ``make_title()``.
    """

    secondary_axis: Any = None  # waiver or AxisSecondary

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Subclasses that leave the attribute unset default to a waiver.
        if cls.secondary_axis is None:
            cls.secondary_axis = waiver()

    def map(self, x: Any, limits: Optional[Any] = None) -> np.ndarray:
        """Map position values (oob only, no palette).

        Parameters
        ----------
        x : array-like
            Values in transformed space.
        limits : array-like, optional
            Scale limits.

        Returns
        -------
        numpy.ndarray
            A new float array; NaN after ``oob`` is replaced by ``na_value``.
        """
        if limits is None:
            limits = self.get_limits()
        x_arr = np.asarray(x, dtype=float)
        scaled = np.asarray(self.oob(x_arr, range=limits), dtype=float)
        # np.where allocates a new array. Assigning into `scaled` in place
        # could write through to the caller's data whenever oob (and
        # np.asarray) hand back the input array itself.
        return np.where(np.isnan(scaled), self.na_value, scaled)

    def break_info(self, range: Optional[Any] = None) -> Dict[str, Any]:
        """Compute break info, merging in secondary-axis breaks if present."""
        info = super().break_info(range)
        sec = getattr(self, "secondary_axis", None)
        if sec is not None and not is_waiver(sec) and not sec.empty():
            sec.init(self)
            sec_info = sec.break_info(info["range"], self)
            info.update(sec_info)
        return info

    def sec_name(self) -> Any:
        """Return the secondary axis name, or a waiver when there is none."""
        sec = getattr(self, "secondary_axis", None)
        if sec is None or is_waiver(sec):
            return waiver()
        return sec.name

    def make_sec_title(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate secondary-title creation to the secondary axis if set."""
        sec = getattr(self, "secondary_axis", None)
        if sec is not None and not is_waiver(sec):
            return sec.make_title(*args, **kwargs)
        return super().make_sec_title(*args, **kwargs)
1551
+
1552
+
1553
class ScaleDiscretePosition(ScaleDiscrete):
    """Discrete scale for position aesthetics (x/y).

    Discrete values are trained into ``range``; continuous values are
    trained into a separate ``range_c``.
    """

    secondary_axis: Any = None
    continuous_limits: Any = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Subclasses that leave the attribute unset default to a waiver.
        if cls.secondary_axis is None:
            cls.secondary_axis = waiver()

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if not hasattr(self, "range_c") or self.range_c is None:
            self.range_c = ContinuousRange()

    def train(self, x: Any) -> None:
        """Train the discrete range, or ``range_c`` for continuous input."""
        if _is_discrete(x):
            super().train(x)
        else:
            self.range_c.train(np.asarray(x, dtype=float))

    def get_limits(self) -> Any:
        """Resolve discrete limits (``[0, 1]`` when the scale is empty)."""
        if self.is_empty():
            return np.array([0.0, 1.0])
        if callable(self.limits):
            return self.limits(self.range.range)
        return self.limits if self.limits is not None else (
            self.range.range if self.range.range is not None else []
        )

    def is_empty(self) -> bool:
        """True when neither range is trained and no fixed limits are set."""
        r = self.range.range
        return (
            r is None
            and (self.limits is None or callable(self.limits))
            and (not hasattr(self, "range_c") or self.range_c is None or self.range_c.range is None)
        )

    def reset(self) -> None:
        """Reset only the continuous range."""
        if hasattr(self, "range_c") and self.range_c is not None:
            self.range_c.reset()

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Map discrete values to numeric positions from the palette.

        Discrete *x* is matched against *limits* by ``str()`` form; values
        not in the limits (or beyond the palette) become NaN. Continuous
        *x* is passed through as floats. The output always has one entry
        per element of *x*.
        """
        if limits is None:
            limits = self.get_limits()
        if not _is_discrete(x):
            return mapped_discrete(np.asarray(x, dtype=float))

        x_arr = np.asarray(x)
        if limits is None or len(limits) == 0:
            # Nothing to match against: every value is missing, but the
            # output must stay aligned with the input.
            return mapped_discrete(np.full(len(x_arr), np.nan))

        values = self.palette(len(limits))
        if not isinstance(values, (np.ndarray, list)):
            cli_abort("The palette function must return a numeric vector.")
        values = np.asarray(values)

        # First occurrence of each limit wins (same as list.index), with
        # O(1) lookups per data value.
        positions: Dict[str, int] = {}
        for idx, lim in enumerate(limits):
            positions.setdefault(str(lim), idx)

        def _position(value: Any) -> float:
            idx = positions.get(str(value))
            if idx is None or idx >= len(values):
                return np.nan
            return values[idx]

        mapped = np.array([_position(v) for v in x_arr], dtype=float)
        return mapped_discrete(mapped)

    def dimension(
        self,
        expand: Optional[np.ndarray] = None,
        limits: Optional[Any] = None,
    ) -> np.ndarray:
        """Return the expanded extent of the mapped limits.

        NOTE(review): ggplot2 additionally unions in the continuous range
        (``range_c``); this implementation only uses the mapped limits.
        """
        if expand is None:
            # R default for discrete position scales: expansion(add = 0.6)
            expand = expansion(0, 0.6)
        if limits is None:
            limits = self.get_limits()
        mapped = self.map(limits)
        if mapped is None or len(mapped) == 0:
            lo, hi = 0.0, 1.0
        else:
            lo, hi = float(np.nanmin(mapped)), float(np.nanmax(mapped))
        return expand_range4(np.array([lo, hi]), expand)

    def clone(self) -> "ScaleDiscretePosition":
        """Return a shallow copy with both ranges reset."""
        new = copy.copy(self)
        new.range = DiscreteRange()
        new.range_c = ContinuousRange()
        return new

    def sec_name(self) -> Any:
        """Return the secondary axis name, or a waiver when there is none."""
        sec = getattr(self, "secondary_axis", None)
        if sec is None or is_waiver(sec):
            return waiver()
        return sec.name
1643
+
1644
+
1645
class ScaleBinnedPosition(ScaleBinned):
    """Binned scale for position aesthetics (x/y)."""

    after_stat: bool = False

    def train(self, x: Any) -> None:
        """Train the continuous range on *x* unless ``after_stat`` is set.

        Input that cannot be converted to floats is reported through
        ``cli_abort``.
        """
        # Converting with dtype=float first made the original dtype check
        # unreachable; catch the conversion failure instead.
        try:
            x_arr = np.asarray(x, dtype=float)
        except (TypeError, ValueError):
            x_arr = None
        if x_arr is None:
            cli_abort("Binned scales only support continuous data.")
            return
        if len(x_arr) == 0 or self.after_stat:
            return
        self.range.train(x_arr)

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Return the 1-based bin index (as float) of each value.

        Bins are delimited by the limits plus the breaks and are closed on
        the right when ``self.right`` is True. Missing values stay NaN.
        With fewer than two distinct edges, the oob-handled values are
        returned unbinned.

        NOTE(review): ggplot2's ScaleBinnedPosition maps to bin midpoints
        (and back-transforms when after_stat is set); this returns bin
        indices. Confirm against the coord code before changing.
        """
        if limits is None:
            limits = self.get_limits()
        limits = np.asarray(limits, dtype=float)
        x_arr = np.asarray(x, dtype=float)

        breaks = self.get_breaks(limits)
        if breaks is None:
            breaks = np.array([])
        all_breaks = np.unique(np.sort(np.concatenate([limits[:1], np.asarray(breaks), limits[1:]])))

        x_oob = np.asarray(self.oob(x_arr, range=limits), dtype=float)
        x_oob = np.where(~np.isnan(x_oob), x_oob, self.na_value)
        if len(all_breaks) < 2:
            # No interval to bin into (np.clip(bins, 1, 0) would be bogus).
            return x_oob

        missing = np.isnan(x_oob)
        # right=True closes intervals on the right, (a, b]; edge values are
        # clipped into the first/last bin.
        bins = np.digitize(x_oob, all_breaks, right=self.right)
        bins = np.clip(bins, 1, len(all_breaks) - 1).astype(float)
        # np.digitize places NaN past the last edge; keep it missing instead.
        bins[missing] = np.nan
        return bins

    def reset(self) -> None:
        """Retrain the range on the current limits and breaks.

        Also sets ``after_stat`` so later ``train`` calls are ignored.
        """
        self.after_stat = True
        limits = self.get_limits()
        breaks = self.get_breaks(limits)
        self.range.reset()
        combined = np.concatenate([np.asarray(limits), np.asarray(breaks) if breaks is not None else np.array([])])
        self.range.train(combined)

    def get_breaks(self, limits: Optional[Any] = None) -> Optional[np.ndarray]:
        """Resolve breaks, adding the limits when ``show_limits`` is set."""
        breaks = super().get_breaks(limits)
        if self.show_limits and breaks is not None:
            lims = self.get_limits()
            breaks = np.sort(np.unique(np.concatenate([lims, np.asarray(breaks)])))
        return breaks
1689
+
1690
+
1691
+ # ---------------------------------------------------------------------------
1692
+ # Identity sub-classes
1693
+ # ---------------------------------------------------------------------------
1694
+
1695
class ScaleContinuousIdentity(ScaleContinuous):
    """Continuous identity scale -- data values are used as-is."""

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Return *x* as an array; categorical input becomes strings.

        Both ``pandas.Categorical`` objects and Series/Index with a
        categorical dtype (the usual form of a DataFrame column) are
        converted with ``astype(str)``. *limits* is ignored.
        """
        # pd.Categorical's dtype is a CategoricalDtype too, so this single
        # check covers the original isinstance(x, pd.Categorical) case.
        if isinstance(getattr(x, "dtype", None), pd.CategoricalDtype):
            return np.asarray(x.astype(str))
        return np.asarray(x)

    def train(self, x: Any) -> None:
        """Train only when a guide is shown (guide != "none")."""
        if self.guide == "none":
            return
        super().train(x)
1708
+
1709
+
1710
class ScaleDiscreteIdentity(ScaleDiscrete):
    """Discrete identity scale -- data values are used as-is."""

    def map(self, x: Any, limits: Optional[Any] = None) -> Any:
        """Return *x* as an array; categorical input becomes strings.

        Both ``pandas.Categorical`` objects and Series/Index with a
        categorical dtype (the usual form of a DataFrame column) are
        converted with ``astype(str)``. *limits* is ignored.
        """
        # pd.Categorical's dtype is a CategoricalDtype too, so this single
        # check covers the original isinstance(x, pd.Categorical) case.
        if isinstance(getattr(x, "dtype", None), pd.CategoricalDtype):
            return np.asarray(x.astype(str))
        return np.asarray(x)

    def train(self, x: Any) -> None:
        """Train only when a guide is shown (guide != "none")."""
        if self.guide == "none":
            return
        super().train(x)
1723
+
1724
+
1725
+ # ---------------------------------------------------------------------------
1726
+ # Date/Datetime sub-classes
1727
+ # ---------------------------------------------------------------------------
1728
+
1729
class ScaleContinuousDate(ScaleContinuous):
    """Continuous scale for date-valued data.

    Inherits all behaviour from ``ScaleContinuous`` unchanged.
    """
1732
+
1733
+
1734
class ScaleContinuousDatetime(ScaleContinuous):
    """Continuous scale for datetime-valued data.

    Inherits all behaviour from ``ScaleContinuous`` unchanged.
    """
1737
+
1738
+
1739
+ # ---------------------------------------------------------------------------
1740
+ # Constructor functions
1741
+ # ---------------------------------------------------------------------------
1742
+
1743
def continuous_scale(
    aesthetics: Union[str, List[str]],
    palette: Optional[Callable] = None,
    *,
    name: Any = None,
    breaks: Any = None,
    minor_breaks: Any = None,
    n_breaks: Optional[int] = None,
    labels: Any = None,
    limits: Optional[Any] = None,
    rescaler: Optional[Callable] = None,
    oob: Optional[Callable] = None,
    expand: Any = None,
    na_value: Any = np.nan,
    transform: Union[str, Any] = "identity",
    trans: Optional[Any] = None,
    guide: Any = "legend",
    position: str = "left",
    fallback_palette: Optional[Callable] = None,
    super_class: Optional[Type[ScaleContinuous]] = None,
) -> ScaleContinuous:
    """Construct a continuous scale.

    Parameters
    ----------
    aesthetics : str or list of str
        Aesthetic names this scale applies to.
    palette : callable, optional
        Palette function mapping [0,1] to aesthetic values.
    name : str or Waiver, optional
        Scale title. ``None`` means the default (a waiver).
    breaks : array-like, callable, or None
        Break specification. ``None`` means the default breaks (a waiver).
    minor_breaks : array-like, callable, or None
        Minor break specification. ``None`` means the default (a waiver).
    n_breaks : int, optional
        Desired number of breaks.
    labels : array-like, callable, or None
        Label specification. ``None`` means the default (a waiver).
    limits : array-like, callable, or None
        Scale limits in data space. Non-callable limits are transformed and,
        when they contain no NaN, sorted.
    rescaler : callable, optional
        Rescaling function (default ``rescale``).
    oob : callable, optional
        Out-of-bounds handler (default ``censor``).
    expand : array-like or Waiver, optional
        Expansion. ``None`` means the default (a waiver).
    na_value : any
        Value to use for missing data.
    transform : str or Transform
        Transformation name or object.
    trans : str or Transform, optional
        Deprecated alias for *transform*; when given it overrides
        *transform* and emits a deprecation warning.
    guide : str
        Guide type.
    position : str
        Axis position: one of "left", "right", "top", "bottom".
    fallback_palette : callable, optional
        Palette to use when *palette* is None and theme provides none.
    super_class : type, optional
        Scale class to instantiate (default ``ScaleContinuous``).

    Returns
    -------
    ScaleContinuous
    """
    # In this API ``None`` is the "use the default" marker, so it is turned
    # into a waiver before anything else inspects these arguments.
    if name is None:
        name = waiver()
    if breaks is None:
        breaks = waiver()
    if minor_breaks is None:
        minor_breaks = waiver()
    if labels is None:
        labels = waiver()
    if expand is None:
        expand = waiver()

    if trans is not None:
        deprecate_warn("3.5.0", "continuous_scale(trans=)", with_="continuous_scale(transform=)")
        transform = trans

    if isinstance(aesthetics, str):
        aesthetics = [aesthetics]
    aesthetics = standardise_aes_names(aesthetics)

    _check_breaks_labels(breaks, labels)

    if position not in ("left", "right", "top", "bottom"):
        cli_abort(f"position must be one of 'left', 'right', 'top', 'bottom', got '{position}'.")

    # NOTE: ggplot2 hides the guide of a non-position scale whose breaks are
    # NULL. ``breaks=None`` has already been normalised to a waiver above,
    # so that condition can never hold here; the unreachable check was
    # removed.

    if isinstance(transform, str):
        transform = as_transform(transform)

    # Transform limits if provided
    if limits is not None and not callable(limits):
        limits_arr = np.asarray(limits, dtype=float)
        limits_arr = transform.transform(limits_arr)
        # Partially specified limits (NaN entries) keep their positions.
        if not np.any(np.isnan(limits_arr)):
            limits_arr = np.sort(limits_arr)
        limits = limits_arr

    if super_class is None:
        super_class = ScaleContinuous

    sc = super_class()
    sc.aesthetics = list(aesthetics)
    sc.palette = palette
    sc.fallback_palette = fallback_palette
    sc.range = ContinuousRange()
    sc.limits = limits
    sc.trans = transform
    sc.na_value = na_value
    sc.expand = expand
    sc.rescaler = rescaler if rescaler is not None else rescale
    sc.oob = oob if oob is not None else censor
    sc.name = name
    sc.breaks = breaks
    sc.minor_breaks = minor_breaks
    sc.n_breaks = n_breaks
    sc.labels = labels
    sc.guide = guide
    sc.position = position
    return sc
1870
+
1871
+
1872
def discrete_scale(
    aesthetics: Union[str, List[str]],
    palette: Optional[Callable] = None,
    *,
    name: Any = None,
    breaks: Any = None,
    minor_breaks: Any = None,
    labels: Any = None,
    limits: Optional[Any] = None,
    expand: Any = None,
    na_translate: bool = True,
    na_value: Any = np.nan,
    drop: bool = True,
    guide: Any = "legend",
    position: str = "left",
    fallback_palette: Optional[Callable] = None,
    super_class: Optional[Type[ScaleDiscrete]] = None,
) -> ScaleDiscrete:
    """Construct a discrete scale.

    Parameters
    ----------
    aesthetics : str or list of str
        Aesthetic names this scale applies to.
    palette : callable, optional
        Palette function taking an integer and returning *n* values.
    name : str or Waiver, optional
        Scale title. ``None`` means the default (a waiver).
    breaks : array-like, callable, or None
        Break specification. ``None`` means the default breaks (a waiver).
    minor_breaks : array-like, callable, or None
        Minor break specification. ``None`` means the default (a waiver).
    labels : array-like, callable, or None
        Label specification. ``None`` means the default (a waiver).
    limits : array-like, callable, or None
        Scale limits.
    expand : array-like or Waiver, optional
        Expansion. ``None`` means the default (a waiver).
    na_translate : bool
        Whether to translate NAs.
    na_value : any
        Value for missing data.
    drop : bool
        Whether to drop unused levels.
    guide : str
        Guide type.
    position : str
        Axis position: one of "left", "right", "top", "bottom".
    fallback_palette : callable, optional
        Fallback palette.
    super_class : type, optional
        Scale class to instantiate (default ``ScaleDiscrete``).

    Returns
    -------
    ScaleDiscrete
    """
    # In this API ``None`` is the "use the default" marker, so it is turned
    # into a waiver before anything else inspects these arguments.
    if name is None:
        name = waiver()
    if breaks is None:
        breaks = waiver()
    if minor_breaks is None:
        minor_breaks = waiver()
    if labels is None:
        labels = waiver()
    if expand is None:
        expand = waiver()

    if isinstance(aesthetics, str):
        aesthetics = [aesthetics]
    aesthetics = standardise_aes_names(aesthetics)

    _check_breaks_labels(breaks, labels)

    if position not in ("left", "right", "top", "bottom"):
        cli_abort(f"position must be one of 'left', 'right', 'top', 'bottom', got '{position}'.")

    # NOTE: ggplot2 hides the guide of a non-position scale whose breaks are
    # NULL. ``breaks=None`` has already been normalised to a waiver above,
    # so that condition can never hold here; the unreachable check was
    # removed.

    if super_class is None:
        super_class = ScaleDiscrete

    sc = super_class()
    sc.aesthetics = list(aesthetics)
    sc.palette = palette
    sc.fallback_palette = fallback_palette
    sc.range = DiscreteRange()
    sc.limits = limits
    sc.na_value = na_value
    sc.na_translate = na_translate
    sc.expand = expand
    sc.name = name
    sc.breaks = breaks
    sc.minor_breaks = minor_breaks
    sc.labels = labels
    sc.drop = drop
    sc.guide = guide
    sc.position = position
    return sc
1973
+
1974
+
1975
def binned_scale(
    aesthetics: Union[str, List[str]],
    palette: Optional[Callable] = None,
    *,
    name: Any = None,
    breaks: Any = None,
    labels: Any = None,
    limits: Optional[Any] = None,
    rescaler: Optional[Callable] = None,
    oob: Optional[Callable] = None,
    expand: Any = None,
    na_value: Any = np.nan,
    n_breaks: Optional[int] = None,
    nice_breaks: bool = True,
    right: bool = True,
    transform: Union[str, Any] = "identity",
    trans: Optional[Any] = None,
    show_limits: bool = False,
    guide: Any = "bins",
    position: str = "left",
    fallback_palette: Optional[Callable] = None,
    super_class: Optional[Type[ScaleBinned]] = None,
) -> ScaleBinned:
    """Build a binned scale object.

    Parameters
    ----------
    aesthetics : str or list of str
        Aesthetic(s) the scale serves. A single string is wrapped in a list
        and the names are passed through ``standardise_aes_names``.
    palette : callable, optional
        Palette function.
    name : str or Waiver, optional
        Scale title; ``None`` becomes ``waiver()``.
    breaks : array-like, callable, or None
        Break specification; ``None`` becomes ``waiver()``.
    labels : array-like, callable, or None
        Label specification; ``None`` becomes ``waiver()``.
    limits : array-like, callable, or None
        Scale limits. Non-callable limits are converted to floats, passed
        through the transformation and sorted unless they contain NaN.
    rescaler : callable, optional
        Rescaling function (``rescale`` when omitted).
    oob : callable, optional
        Out-of-bounds handler (``squish`` when omitted).
    expand : array-like or Waiver, optional
        Expansion; ``None`` becomes ``waiver()``.
    na_value : any
        Value used for missing data.
    n_breaks : int, optional
        Desired number of breaks.
    nice_breaks : bool
        Whether to use nicely spaced breaks.
    right : bool
        Whether bins are closed on the right.
    transform : str or Transform
        Transformation; strings are resolved with ``as_transform``.
    trans : str or Transform, optional
        Deprecated alias for *transform*. When given, a deprecation warning
        is emitted and it overrides *transform*.
    show_limits : bool
        Whether the limits are shown as ticks.
    guide : str
        Guide type.
    position : str
        Axis position: ``"left"``, ``"right"``, ``"top"`` or ``"bottom"``.
    fallback_palette : callable, optional
        Fallback palette.
    super_class : type, optional
        Class to instantiate (``ScaleBinned`` when omitted).

    Returns
    -------
    ScaleBinned
        A fresh instance of *super_class* with every setting attached.
    """
    # Specs left at their ``None`` default become waiver() placeholders.
    name, breaks, labels, expand = (
        waiver() if spec is None else spec
        for spec in (name, breaks, labels, expand)
    )

    if trans is not None:
        deprecate_warn("3.5.0", "binned_scale(trans=)", with_="binned_scale(transform=)")
        transform = trans

    aesthetics = standardise_aes_names(
        [aesthetics] if isinstance(aesthetics, str) else aesthetics
    )

    _check_breaks_labels(breaks, labels)

    allowed_positions = ("left", "right", "top", "bottom")
    if position not in allowed_positions:
        cli_abort(f"position must be one of 'left', 'right', 'top', 'bottom', got '{position}'.")

    # NOTE(review): ``breaks`` cannot be None here (None was swapped for
    # waiver() above), so this guard never fires as written.
    if breaks is None and not _is_position_aes(aesthetics) and guide != "none":
        guide = "none"

    if isinstance(transform, str):
        transform = as_transform(transform)

    if limits is not None and not callable(limits):
        transformed = transform.transform(np.asarray(limits, dtype=float))
        # Sort only when every limit is present; NaN limits keep their order.
        limits = transformed if np.any(np.isnan(transformed)) else np.sort(transformed)

    scale_cls = ScaleBinned if super_class is None else super_class
    sc = scale_cls()
    settings = (
        ("aesthetics", list(aesthetics)),
        ("palette", palette),
        ("fallback_palette", fallback_palette),
        ("range", ContinuousRange()),
        ("limits", limits),
        ("trans", transform),
        ("na_value", na_value),
        ("expand", expand),
        ("rescaler", rescaler if rescaler is not None else rescale),
        ("oob", oob if oob is not None else squish),
        ("n_breaks", n_breaks),
        ("nice_breaks", nice_breaks),
        ("right", right),
        ("show_limits", show_limits),
        ("name", name),
        ("breaks", breaks),
        ("labels", labels),
        ("guide", guide),
        ("position", position),
    )
    for attr, value in settings:
        setattr(sc, attr, value)
    return sc
2106
+
2107
+
2108
+ # ---------------------------------------------------------------------------
2109
+ # ScalesList container
2110
+ # ---------------------------------------------------------------------------
2111
+
2112
class ScalesList:
    """Container for a plot's collection of scales.

    At most one scale is kept per aesthetic: :meth:`add` removes any
    existing scale that shares an aesthetic with the scale being added.

    Attributes
    ----------
    scales : list of Scale
        The individual scales, in the order they were added.
    """

    def __init__(self) -> None:
        self.scales: List[Scale] = []

    def find(self, aesthetic: Union[str, List[str]]) -> List[bool]:
        """Return a boolean mask of scales matching *aesthetic*.

        Parameters
        ----------
        aesthetic : str or list of str
            Aesthetic name, or several names. A scale matches when its
            ``aesthetics`` contain any of the requested names.

        Returns
        -------
        list of bool
            One entry per scale in :attr:`scales`, in the same order.
        """
        wanted = [aesthetic] if isinstance(aesthetic, str) else list(aesthetic)
        return [any(name in s.aesthetics for name in wanted) for s in self.scales]

    def has_scale(self, aesthetic: str) -> bool:
        """Check whether a scale exists for *aesthetic*.

        Parameters
        ----------
        aesthetic : str
            Aesthetic name.

        Returns
        -------
        bool
        """
        return any(aesthetic in s.aesthetics for s in self.scales)

    def add(self, scale: Optional[Scale]) -> None:
        """Add a scale, replacing any existing scale for the same aesthetics.

        Parameters
        ----------
        scale : Scale or None
            Scale to add. ``None`` is silently ignored.
        """
        if scale is None:
            return
        prev_aes = self.find(scale.aesthetics)
        if any(prev_aes):
            # A scale may list several aesthetics; report the first name of
            # the first clashing scale.
            first_name = next(
                s.aesthetics[0] for s, p in zip(self.scales, prev_aes) if p
            )
            cli_inform(
                f"Scale for {first_name} is already present. "
                f"Adding another scale for {first_name}, which will replace the existing scale."
            )
        self.scales = [s for s, p in zip(self.scales, prev_aes) if not p]
        self.scales.append(scale)

    def n(self) -> int:
        """Return the number of scales."""
        return len(self.scales)

    def input(self) -> List[str]:
        """Return all aesthetic names across all scales, in scale order."""
        result: List[str] = []
        for s in self.scales:
            result.extend(s.aesthetics)
        return result

    def clone(self) -> "ScalesList":
        """Return a new list holding a ``clone()`` of every scale."""
        new = ScalesList()
        new.scales = [s.clone() for s in self.scales]
        return new

    def non_position_scales(self) -> "ScalesList":
        """Return a new ScalesList with only non-position scales.

        A scale is dropped when any of its aesthetics is in
        ``_POSITION_AESTHETICS``. The kept scales are shared, not cloned.
        """
        new = ScalesList()
        new.scales = [
            s for s in self.scales
            if not any(a in _POSITION_AESTHETICS for a in s.aesthetics)
        ]
        return new

    def get_scales(self, output: str) -> Optional[Scale]:
        """Get the scale for a given aesthetic.

        Parameters
        ----------
        output : str
            Aesthetic name.

        Returns
        -------
        Scale or None
            The first scale covering *output*, or ``None``.
        """
        for s in self.scales:
            if output in s.aesthetics:
                return s
        return None

    def train_df(self, df: pd.DataFrame) -> None:
        """Train all scales on *df*.

        Parameters
        ----------
        df : pandas.DataFrame
            Layer data. Nothing happens when it is empty or there are no
            scales.
        """
        if _empty(df) or len(self.scales) == 0:
            return
        for s in self.scales:
            s.train_df(df)

    def map_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Map all scales on *df*.

        Parameters
        ----------
        df : pandas.DataFrame
            Layer data.

        Returns
        -------
        pandas.DataFrame
            The same *df* object: each scale's mapped columns are written
            back into it in place.
        """
        # NOTE(review): this mutates the caller's DataFrame; copy first if
        # callers may share the frame.
        if _empty(df) or len(self.scales) == 0:
            return df
        for s in self.scales:
            mapped = s.map_df(df)
            for k, v in mapped.items():
                df[k] = v
        return df

    def transform_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Transform all scale columns in *df*.

        Parameters
        ----------
        df : pandas.DataFrame
            Layer data.

        Returns
        -------
        pandas.DataFrame
            The same *df* object, with transformed columns written back in
            place.
        """
        if _empty(df):
            return df
        for s in self.scales:
            transformed = s.transform_df(df)
            for k, v in transformed.items():
                df[k] = v
        return df

    def add_defaults(self, data: pd.DataFrame, env: Optional[Any] = None) -> None:
        """Add default scales for aesthetics in *data* not yet covered.

        Parameters
        ----------
        data : pandas.DataFrame
            Layer data; each column name is treated as an aesthetic.
        env : any, optional
            Lookup environment (unused in Python port).
        """
        # Snapshot taken once: aesthetics covered by scales added during this
        # loop are not re-checked.
        existing = set(self.input())
        for aes_name in data.columns:
            if aes_name not in existing:
                sc = find_scale(aes_name, data[aes_name])
                if sc is not None:
                    self.add(sc)

    def add_missing(self, aesthetics: List[str], env: Optional[Any] = None) -> None:
        """Add missing but required scales.

        Parameters
        ----------
        aesthetics : list of str
            Required aesthetic names (typically ``['x', 'y']``).
        env : any, optional
            Lookup environment (unused).
        """
        existing = set(self.input())
        for aes_name in aesthetics:
            if aes_name not in existing:
                sc = _default_continuous_scale(aes_name)
                if sc is not None:
                    self.add(sc)
2305
+
2306
+
2307
def scales_list() -> ScalesList:
    """Make an empty scales container.

    Returns
    -------
    ScalesList
        A new container whose ``scales`` list starts out empty.
    """
    container = ScalesList()
    return container
2315
+
2316
+
2317
+ def _default_continuous_scale(aes: str) -> Optional[Scale]:
2318
+ """Create a default continuous scale for the given aesthetic."""
2319
+ if aes in ("x", "xmin", "xmax", "xend", "xintercept"):
2320
+ return continuous_scale(
2321
+ _X_AESTHETICS,
2322
+ palette=lambda x: x,
2323
+ position="bottom",
2324
+ super_class=ScaleContinuousPosition,
2325
+ )
2326
+ if aes in ("y", "ymin", "ymax", "yend", "yintercept"):
2327
+ return continuous_scale(
2328
+ _Y_AESTHETICS,
2329
+ palette=lambda x: x,
2330
+ position="left",
2331
+ super_class=ScaleContinuousPosition,
2332
+ )
2333
+ return None
2334
+
2335
+
2336
+ # ---------------------------------------------------------------------------
2337
+ # Secondary axis support
2338
+ # ---------------------------------------------------------------------------
2339
+
2340
+ class _Derived:
2341
+ """Sentinel for inheriting settings from the primary axis."""
2342
+
2343
+ def __repr__(self) -> str:
2344
+ return "derive()"
2345
+
2346
+
2347
def derive() -> _Derived:
    """Create a marker telling a secondary axis to reuse the primary's setting.

    Returns
    -------
    _Derived
        A new sentinel instance; detect it with :func:`is_derived`.
    """
    marker = _Derived()
    return marker
2355
+
2356
+
2357
def is_derived(x: Any) -> bool:
    """Tell whether *x* is a ``derive()`` sentinel.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    bool
        ``True`` when *x* is an instance of ``_Derived``.
    """
    sentinel_cls = _Derived
    return isinstance(x, sentinel_cls)
2369
+
2370
+
2371
class AxisSecondary:
    """Specification for a secondary axis.

    Parameters
    ----------
    trans : callable, optional
        Monotonic transformation from primary to secondary scale values.
        ``None`` makes the axis empty (see :meth:`empty`).
    name : any
        Axis title (``waiver()`` when ``None``).
    breaks : any
        Break specification (``waiver()`` when ``None``).
    labels : any
        Label specification (``waiver()`` when ``None``).
    guide : any
        Guide specification (``waiver()`` when ``None``).

    Attributes
    ----------
    detail : int
        Number of points used to sample the primary range in
        :meth:`break_info`.
    """

    def __init__(
        self,
        trans: Optional[Callable] = None,
        name: Any = None,
        breaks: Any = None,
        labels: Any = None,
        guide: Any = None,
    ) -> None:
        self.trans = trans
        self.name = name if name is not None else waiver()
        self.breaks = breaks if breaks is not None else waiver()
        self.labels = labels if labels is not None else waiver()
        self.guide = guide if guide is not None else waiver()
        self.detail = 1000

    def empty(self) -> bool:
        """Whether this secondary axis is empty (has no transform)."""
        return self.trans is None

    def init(self, scale: Scale) -> None:
        """Resolve ``derive()`` and waiver settings against the primary scale.

        Does nothing for an empty axis. A derived name is only replaced when
        the primary name is not a waiver. Waiver breaks default to the primary
        transformation's ``breaks_func``.

        Parameters
        ----------
        scale : Scale
            The primary scale.

        Raises
        ------
        Aborts via ``cli_abort`` when the transform is not callable.
        """
        if self.empty():
            return
        # ``transform_range`` calls ``self.trans`` directly; reject a
        # non-callable here with a clear message instead of a later TypeError.
        if not callable(self.trans):
            cli_abort("Transformation for secondary axes must be a function.")
        if is_derived(self.name) and not is_waiver(scale.name):
            self.name = scale.name
        if is_derived(self.breaks):
            self.breaks = scale.breaks
        if is_waiver(self.breaks):
            transformation = scale.get_transformation()
            self.breaks = transformation.breaks_func
        if is_derived(self.labels):
            self.labels = scale.labels
        if is_derived(self.guide):
            self.guide = scale.guide

    def transform_range(self, range: np.ndarray) -> np.ndarray:
        """Apply the secondary transform to a range.

        Parameters
        ----------
        range : numpy.ndarray
            Primary range values.

        Returns
        -------
        numpy.ndarray
            ``self.trans(range)`` converted with ``np.asarray``.
        """
        return np.asarray(self.trans(range))

    def break_info(self, range: np.ndarray, scale: Scale) -> Dict[str, Any]:
        """Compute secondary-axis break information.

        Parameters
        ----------
        range : numpy.ndarray
            Primary axis range (only ``range[0]`` and ``range[1]`` are used).
        scale : Scale
            Primary scale.

        Returns
        -------
        dict
            The temporary scale's ``break_info()`` with every key prefixed
            by ``sec.``; empty for an empty axis.
        """
        if self.empty():
            return {}

        transformation = scale.get_transformation()
        # Sample the primary range, map it back through the primary inverse,
        # then through the secondary transform.
        along = np.linspace(range[0], range[1], self.detail)
        old_range = transformation.inverse(along)
        full_range = self.transform_range(old_range)

        new_range = np.array([np.nanmin(full_range), np.nanmax(full_range)])

        # Temporary scale used only to compute breaks/labels for new_range.
        # NOTE(review): new_range holds secondary values produced after
        # ``transformation.inverse``, yet the temporary scale reuses the
        # primary transformation. For a non-identity primary transform
        # (e.g. log) this looks wrong. Break positions are presumably relative
        # to new_range and are not mapped back onto the primary range.
        # TODO: confirm for non-linear or decreasing secondary transforms.
        temp_sc = ScaleContinuousPosition()
        temp_sc.name = self.name
        temp_sc.breaks = self.breaks
        temp_sc.labels = self.labels
        temp_sc.limits = new_range
        temp_sc.expand = np.array([0.0, 0.0])
        temp_sc.minor_breaks = None  # no minor breaks for secondary axis
        temp_sc.trans = transformation
        temp_sc.range = ContinuousRange()
        temp_sc.train(new_range)
        range_info = temp_sc.break_info()

        result = {}
        for k, v in range_info.items():
            result[f"sec.{k}"] = v
        return result

    def make_title(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve the secondary axis title.

        Delegates to ``ScaleContinuous.make_title`` with ``self=None``
        (assumes that method does not read ``self`` -- TODO confirm).
        """
        return ScaleContinuous.make_title(None, *args, **kwargs)
2490
+
2491
+
2492
def sec_axis(
    transform: Optional[Callable] = None,
    name: Any = None,
    breaks: Any = None,
    labels: Any = None,
    guide: Any = None,
    trans: Optional[Callable] = None,
) -> AxisSecondary:
    """Build a secondary axis specification.

    Every setting left as ``None`` is replaced by ``waiver()``.

    Parameters
    ----------
    transform : callable, optional
        Monotonic transformation from primary to secondary values.
    name : any, optional
        Axis title.
    breaks : any, optional
        Break specification.
    labels : any, optional
        Label specification.
    guide : any, optional
        Guide specification.
    trans : callable, optional
        Deprecated alias for *transform*. When given, a deprecation warning
        is emitted and it overrides *transform*.

    Returns
    -------
    AxisSecondary
    """
    if trans is not None:
        deprecate_warn("3.5.0", "sec_axis(trans=)", with_="sec_axis(transform=)")
        transform = trans

    name, breaks, labels, guide = (
        waiver() if setting is None else setting
        for setting in (name, breaks, labels, guide)
    )

    spec = AxisSecondary(
        trans=transform,
        name=name,
        breaks=breaks,
        labels=labels,
        guide=guide,
    )
    return spec
2541
+
2542
+
2543
def dup_axis(
    transform: Optional[Callable] = None,
    name: Any = None,
    breaks: Any = None,
    labels: Any = None,
    guide: Any = None,
    trans: Optional[Callable] = None,
) -> AxisSecondary:
    """Build a secondary axis that mirrors the primary one.

    Unlike :func:`sec_axis`, unspecified settings default to ``derive()``
    (resolved from the primary scale in ``AxisSecondary.init``). The
    transform defaults to the identity.

    Parameters
    ----------
    transform : callable, optional
        Transformation (identity when omitted).
    name : any, optional
        Axis title (``derive()`` when omitted).
    breaks : any, optional
        Breaks (``derive()`` when omitted).
    labels : any, optional
        Labels (``derive()`` when omitted).
    guide : any, optional
        Guide (``derive()`` when omitted).
    trans : callable, optional
        Deprecated alias for *transform*, forwarded to :func:`sec_axis`.

    Returns
    -------
    AxisSecondary
    """

    def _or_derive(value: Any) -> Any:
        # Missing settings are inherited from the primary axis.
        return derive() if value is None else value

    return sec_axis(
        transform=(lambda x: x) if transform is None else transform,
        name=_or_derive(name),
        breaks=_or_derive(breaks),
        labels=_or_derive(labels),
        guide=_or_derive(guide),
        trans=trans,
    )
2583
+
2584
+
2585
def is_sec_axis(x: Any) -> bool:
    """Tell whether *x* is a secondary-axis specification.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    bool
        ``True`` when *x* is an :class:`AxisSecondary` (or subclass) instance.
    """
    expected_cls = AxisSecondary
    return isinstance(x, expected_cls)
2597
+
2598
+
2599
def _set_sec_axis(sec_axis_obj: Any, scale: Scale) -> Scale:
    """Attach a secondary axis to a scale (internal helper).

    Parameters
    ----------
    sec_axis_obj : AxisSecondary or Waiver
        Secondary axis specification; a waiver leaves *scale* untouched.
    scale : Scale
        Target scale.

    Returns
    -------
    Scale
        The same *scale* object, with ``secondary_axis`` set unless
        *sec_axis_obj* is a waiver.

    Raises
    ------
    Aborts via ``cli_abort`` when *sec_axis_obj* is neither a waiver nor an
    :class:`AxisSecondary`.
    """
    if is_waiver(sec_axis_obj):
        return scale
    # Validate the type before reading any attribute, so a bad value gets
    # the intended message rather than an AttributeError.
    if not is_sec_axis(sec_axis_obj):
        cli_abort("Secondary axes must be specified using sec_axis().")
    # NOTE(review): discrete axes are meant to take only an identity
    # transform. That cannot be detected by comparing against a freshly
    # created ``lambda x: x`` (function identity never matches), so no check
    # is performed here.
    scale.secondary_axis = sec_axis_obj
    return scale
2622
+
2623
+
2624
+ # ---------------------------------------------------------------------------
2625
+ # Scale detection
2626
+ # ---------------------------------------------------------------------------
2627
+
2628
def is_scale(x: Any) -> bool:
    """Tell whether *x* is a scale object.

    Parameters
    ----------
    x : Any
        Object to test.

    Returns
    -------
    bool
        ``True`` when *x* is an instance of ``Scale`` or a subclass.
    """
    base_cls = Scale
    return isinstance(x, base_cls)
2640
+
2641
+
2642
def scale_type(x: Any) -> List[str]:
    """Determine the appropriate scale type for data *x*.

    Parameters
    ----------
    x : array-like
        Data values. ``pandas.Series`` and ``numpy.ndarray`` are classified
        by dtype. Non-empty lists/tuples, ``pandas.Index`` and pandas
        extension arrays (e.g. ``pandas.Categorical``) are first wrapped in a
        Series so they get the same dtype-based classification.

    Returns
    -------
    list of str
        Scale type names, most specific first (e.g. ``['ordinal', 'discrete']``,
        ``['datetime', 'continuous']``, ``['discrete']``). Anything not
        recognised falls back to ``['continuous']``.
    """
    # Plain containers would otherwise skip dtype inspection entirely and be
    # reported as continuous (e.g. a list of strings or a bare Categorical).
    if isinstance(x, (list, tuple)):
        if len(x) == 0:
            return ["continuous"]
        x = pd.Series(x)
    elif isinstance(x, (pd.Index, pd.api.extensions.ExtensionArray)):
        x = pd.Series(x)

    if isinstance(x, pd.Series):
        dtype = x.dtype
        if isinstance(dtype, pd.CategoricalDtype):
            if x.cat.ordered:
                return ["ordinal", "discrete"]
            return ["discrete"]
        if pd.api.types.is_bool_dtype(dtype):
            return ["discrete"]
        if pd.api.types.is_datetime64_any_dtype(dtype):
            return ["datetime", "continuous"]
        if pd.api.types.is_numeric_dtype(dtype):
            return ["continuous"]
        # Object columns and pandas' dedicated string dtype both hold labels.
        if dtype == object or pd.api.types.is_string_dtype(dtype):
            return ["discrete"]

    if isinstance(x, np.ndarray):
        if x.dtype.kind in ("U", "S", "O", "b"):
            return ["discrete"]
        if np.issubdtype(x.dtype, np.datetime64):
            return ["datetime", "continuous"]
        if np.issubdtype(x.dtype, np.number):
            return ["continuous"]

    return ["continuous"]
2676
+
2677
+
2678
+ #: Aesthetics that R does **not** auto-create a default scale for.
2679
+ #:
2680
+ #: R ships no ``scale_stroke_*`` family — ``stroke`` data flows through
2681
+ #: as raw aesthetic values and is not trained by a scale nor given a
2682
+ #: legend. Python has ``scale_stroke_*()`` available for explicit user
2683
+ #: control, but ``find_scale`` must not auto-instantiate one or a
2684
+ #: spurious "stroke" guide would appear that R never produces.
2685
+ _NO_DEFAULT_SCALE_AES: frozenset = frozenset({"stroke"})
2686
+
2687
+
2688
def find_scale(aes: str, x: Any, env: Optional[Any] = None) -> Optional[Scale]:
    """Auto-detect an appropriate scale for aesthetic *aes* and data *x*.

    Mirrors R's ``find_scale()`` (ggplot2/R/scales-.R) which only looks
    up scales for aesthetics that R has registered scale constructors
    for. ``stroke`` is intentionally excluded -- see
    :data:`_NO_DEFAULT_SCALE_AES`.

    For each type returned by :func:`scale_type` (most specific first), a
    constructor named ``scale_<aes>_<type>`` is looked up on
    ``ggplot2_py.scales``. The first one found is called with no arguments.

    Parameters
    ----------
    aes : str
        Aesthetic name.
    x : array-like
        Data values. ``None`` and non-empty float data made up only of
        ``+/-inf`` yield no scale.
    env : any, optional
        Lookup environment (unused).

    Returns
    -------
    Scale or None
        ``None`` when no scale applies or no constructor can be built.
    """
    if x is None:
        return None

    # Infinite values (e.g. -Inf/Inf annotation bounds) suit either a
    # continuous or a discrete scale. Skip them so a layer with real data
    # decides the scale type instead of it being fixed here. R's
    # find_scale() applies the same all-infinite check.
    try:
        values = np.asarray(x)
    except (TypeError, ValueError):
        values = None
    if (
        values is not None
        and values.dtype.kind == "f"
        and values.size > 0
        and bool(np.isinf(values).all())
    ):
        return None

    if aes in _NO_DEFAULT_SCALE_AES:
        return None

    # Imported lazily, once (presumably to avoid a circular import -- TODO
    # confirm).
    try:
        from ggplot2_py import scales as scales_mod
    except ImportError:
        return None

    for stype in scale_type(x):
        constructor = getattr(scales_mod, f"scale_{aes}_{stype}", None)
        if constructor is None:
            continue
        try:
            return constructor()
        except (ImportError, AttributeError):
            # Best effort, as before: fall back to the next candidate type.
            continue

    return None