plotnine 0.14.5__py3-none-any.whl → 0.15.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. plotnine/__init__.py +31 -37
  2. plotnine/_mpl/gridspec.py +265 -0
  3. plotnine/_mpl/layout_manager/__init__.py +6 -0
  4. plotnine/_mpl/layout_manager/_engine.py +87 -0
  5. plotnine/_mpl/layout_manager/_layout_items.py +957 -0
  6. plotnine/_mpl/layout_manager/_layout_tree.py +905 -0
  7. plotnine/_mpl/layout_manager/_spaces.py +1154 -0
  8. plotnine/_mpl/patches.py +70 -34
  9. plotnine/_mpl/text.py +159 -37
  10. plotnine/_mpl/utils.py +78 -10
  11. plotnine/_utils/__init__.py +35 -9
  12. plotnine/_utils/dev.py +45 -27
  13. plotnine/_utils/yippie.py +115 -0
  14. plotnine/animation.py +1 -1
  15. plotnine/coords/coord.py +3 -3
  16. plotnine/coords/coord_trans.py +1 -1
  17. plotnine/data/__init__.py +43 -8
  18. plotnine/data/anscombe-quartet.csv +45 -0
  19. plotnine/doctools.py +2 -2
  20. plotnine/facets/facet.py +34 -43
  21. plotnine/facets/facet_grid.py +14 -6
  22. plotnine/facets/facet_wrap.py +3 -5
  23. plotnine/facets/strips.py +20 -33
  24. plotnine/geoms/annotate.py +3 -3
  25. plotnine/geoms/annotation_logticks.py +2 -0
  26. plotnine/geoms/annotation_stripes.py +2 -0
  27. plotnine/geoms/geom.py +3 -3
  28. plotnine/geoms/geom_bar.py +10 -2
  29. plotnine/geoms/geom_col.py +6 -0
  30. plotnine/geoms/geom_crossbar.py +2 -3
  31. plotnine/geoms/geom_path.py +2 -2
  32. plotnine/geoms/geom_violin.py +24 -7
  33. plotnine/ggplot.py +95 -66
  34. plotnine/guides/guide.py +19 -20
  35. plotnine/guides/guide_colorbar.py +6 -6
  36. plotnine/guides/guide_legend.py +15 -16
  37. plotnine/guides/guides.py +8 -8
  38. plotnine/helpers.py +49 -0
  39. plotnine/iapi.py +33 -7
  40. plotnine/labels.py +8 -3
  41. plotnine/layer.py +4 -4
  42. plotnine/mapping/_env.py +2 -2
  43. plotnine/mapping/_eval_environment.py +85 -0
  44. plotnine/mapping/aes.py +14 -30
  45. plotnine/mapping/evaluation.py +7 -65
  46. plotnine/options.py +14 -7
  47. plotnine/plot_composition/__init__.py +10 -0
  48. plotnine/plot_composition/_compose.py +462 -0
  49. plotnine/plot_composition/_plotspec.py +50 -0
  50. plotnine/plot_composition/_spacer.py +32 -0
  51. plotnine/positions/position_dodge.py +1 -1
  52. plotnine/positions/position_dodge2.py +1 -1
  53. plotnine/positions/position_stack.py +1 -2
  54. plotnine/qplot.py +1 -2
  55. plotnine/scales/__init__.py +0 -6
  56. plotnine/scales/limits.py +7 -7
  57. plotnine/scales/scale.py +4 -4
  58. plotnine/scales/scale_continuous.py +2 -1
  59. plotnine/scales/scale_identity.py +10 -2
  60. plotnine/scales/scale_manual.py +6 -2
  61. plotnine/stats/binning.py +5 -2
  62. plotnine/stats/smoothers.py +3 -5
  63. plotnine/stats/stat.py +3 -3
  64. plotnine/stats/stat_bindot.py +1 -3
  65. plotnine/stats/stat_density.py +2 -2
  66. plotnine/stats/stat_qq_line.py +1 -1
  67. plotnine/stats/stat_sina.py +34 -1
  68. plotnine/themes/elements/__init__.py +3 -0
  69. plotnine/themes/elements/element_text.py +35 -24
  70. plotnine/themes/elements/margin.py +137 -61
  71. plotnine/themes/targets.py +3 -1
  72. plotnine/themes/theme.py +21 -7
  73. plotnine/themes/theme_538.py +0 -1
  74. plotnine/themes/theme_bw.py +0 -1
  75. plotnine/themes/theme_dark.py +0 -1
  76. plotnine/themes/theme_gray.py +32 -34
  77. plotnine/themes/theme_light.py +1 -1
  78. plotnine/themes/theme_matplotlib.py +28 -31
  79. plotnine/themes/theme_seaborn.py +36 -36
  80. plotnine/themes/theme_void.py +25 -27
  81. plotnine/themes/theme_xkcd.py +0 -1
  82. plotnine/themes/themeable.py +369 -169
  83. plotnine/typing.py +3 -3
  84. plotnine/watermark.py +3 -3
  85. {plotnine-0.14.5.dist-info → plotnine-0.15.0a2.dist-info}/METADATA +8 -5
  86. {plotnine-0.14.5.dist-info → plotnine-0.15.0a2.dist-info}/RECORD +89 -78
  87. {plotnine-0.14.5.dist-info → plotnine-0.15.0a2.dist-info}/WHEEL +1 -1
  88. plotnine/_mpl/_plot_side_space.py +0 -888
  89. plotnine/_mpl/_plotnine_tight_layout.py +0 -293
  90. plotnine/_mpl/layout_engine.py +0 -110
  91. {plotnine-0.14.5.dist-info → plotnine-0.15.0a2.dist-info/licenses}/LICENSE +0 -0
  92. {plotnine-0.14.5.dist-info → plotnine-0.15.0a2.dist-info}/top_level.txt +0 -0
plotnine/geoms/geom.py CHANGED
@@ -426,7 +426,7 @@ class geom(ABC, metaclass=Register):
426
426
  msg = "The geom should implement this method."
427
427
  raise NotImplementedError(msg)
428
428
 
429
- def __radd__(self, plot: ggplot) -> ggplot:
429
+ def __radd__(self, other: ggplot) -> ggplot:
430
430
  """
431
431
  Add layer representing geom object on the right
432
432
 
@@ -440,8 +440,8 @@ class geom(ABC, metaclass=Register):
440
440
  :
441
441
  ggplot object with added layer.
442
442
  """
443
- plot += self.to_layer() # Add layer
444
- return plot
443
+ other += self.to_layer() # Add layer
444
+ return other
445
445
 
446
446
  def to_layer(self) -> layer:
447
447
  """
@@ -20,6 +20,11 @@ class geom_bar(geom_rect):
20
20
  Parameters
21
21
  ----------
22
22
  {common_parameters}
23
+ just : float, default=0.5
24
+ How to align the column with respect to the axis breaks. The default
25
+ `0.5` aligns the center of the column with the break. `0` aligns the
26
+ left of the of the column with the break and `1` aligns the right of
27
+ the column with the break.
23
28
  width : float, default=None
24
29
  Bar width. If `None`{.py}, the width is set to
25
30
  `90%` of the resolution of the data.
@@ -35,6 +40,7 @@ class geom_bar(geom_rect):
35
40
  "stat": "count",
36
41
  "position": "stack",
37
42
  "na_rm": False,
43
+ "just": 0.5,
38
44
  "width": None,
39
45
  }
40
46
 
@@ -45,6 +51,8 @@ class geom_bar(geom_rect):
45
51
  else:
46
52
  data["width"] = resolution(data["x"], False) * 0.9
47
53
 
54
+ just = self.params.get("just", 0.5)
55
+
48
56
  bool_idx = data["y"] < 0
49
57
 
50
58
  data["ymin"] = 0.0
@@ -53,7 +61,7 @@ class geom_bar(geom_rect):
53
61
  data["ymax"] = data["y"]
54
62
  data.loc[bool_idx, "ymax"] = 0.0
55
63
 
56
- data["xmin"] = data["x"] - data["width"] / 2
57
- data["xmax"] = data["x"] + data["width"] / 2
64
+ data["xmin"] = data["x"] - data["width"] * just
65
+ data["xmax"] = data["x"] + data["width"] * (1 - just)
58
66
  del data["width"]
59
67
  return data
@@ -17,6 +17,11 @@ class geom_col(geom_bar):
17
17
  Parameters
18
18
  ----------
19
19
  {common_parameters}
20
+ just : float, default=0.5
21
+ How to align the column with respect to the axis breaks. The default
22
+ `0.5` aligns the center of the column with the break. `0` aligns the
23
+ left of the of the column with the break and `1` aligns the right of
24
+ the column with the break.
20
25
  width : float, default=None
21
26
  Bar width. If `None`{.py}, the width is set to
22
27
  `90%` of the resolution of the data.
@@ -32,5 +37,6 @@ class geom_col(geom_bar):
32
37
  "stat": "identity",
33
38
  "position": "stack",
34
39
  "na_rm": False,
40
+ "just": 0.5,
35
41
  "width": None,
36
42
  }
@@ -88,7 +88,7 @@ class geom_crossbar(geom):
88
88
  group = data["group"]
89
89
 
90
90
  # From violin
91
- notchwidth = typing.cast(float, params.get("notchwidth"))
91
+ notchwidth = typing.cast("float", params.get("notchwidth"))
92
92
  # ynotchupper = data.get('ynotchupper')
93
93
  # ynotchlower = data.get('ynotchlower')
94
94
 
@@ -110,8 +110,7 @@ class geom_crossbar(geom):
110
110
 
111
111
  if any(ynotchlower < ymin) or any(ynotchupper > ymax):
112
112
  warn(
113
- "Notch went outside the hinges. "
114
- "Try setting notch=False.",
113
+ "Notch went outside the hinges. Try setting notch=False.",
115
114
  PlotnineWarning,
116
115
  )
117
116
 
@@ -91,7 +91,7 @@ class geom_path(geom):
91
91
 
92
92
  # return data
93
93
  n1 = len(data)
94
- data = data.loc[bool_idx] # pyright: ignore[reportCallIssue,reportArgumentType]
94
+ data = data.loc[bool_idx]
95
95
  data.reset_index(drop=True, inplace=True)
96
96
  n2 = len(data)
97
97
 
@@ -482,7 +482,7 @@ def _draw_segments(data: pd.DataFrame, ax: Axes, **params: Any):
482
482
  linestyle = data.loc[indices, "linetype"]
483
483
 
484
484
  coll = LineCollection(
485
- segments, # pyright: ignore[reportArgumentType]
485
+ segments,
486
486
  edgecolor=edgecolor,
487
487
  linewidth=linewidth,
488
488
  linestyle=linestyle,
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- import typing
3
+ from typing import TYPE_CHECKING, cast
4
4
 
5
5
  import numpy as np
6
6
  import pandas as pd
@@ -11,7 +11,7 @@ from .geom import geom
11
11
  from .geom_path import geom_path
12
12
  from .geom_polygon import geom_polygon
13
13
 
14
- if typing.TYPE_CHECKING:
14
+ if TYPE_CHECKING:
15
15
  from typing import Any
16
16
 
17
17
  from matplotlib.axes import Axes
@@ -115,10 +115,17 @@ class geom_violin(geom):
115
115
  ax: Axes,
116
116
  **params: Any,
117
117
  ):
118
- quantiles = params["draw_quantiles"]
119
- style = params["style"]
118
+ quantiles = params.pop("draw_quantiles")
119
+ style = params.pop("style")
120
+ zorder = params.pop("zorder")
121
+
122
+ for i, (group, df) in enumerate(data.groupby("group")):
123
+ # Place the violins with the smalleer group number on top
124
+ # of those with larger numbers. The group_zorder values should be
125
+ # in the range [zorder, zorder + 1) to stay within the layer.
126
+ group = cast("int", group)
127
+ group_zorder = zorder + 0.9 / group
120
128
 
121
- for i, (_, df) in enumerate(data.groupby("group")):
122
129
  # Find the points for the line to go all the way around
123
130
  df["xminv"] = df["x"] - df["violinwidth"] * (df["x"] - df["xmin"])
124
131
  df["xmaxv"] = df["x"] + df["violinwidth"] * (df["xmax"] - df["x"])
@@ -156,7 +163,12 @@ class geom_violin(geom):
156
163
 
157
164
  # plot violin polygon
158
165
  geom_polygon.draw_group(
159
- polygon_df, panel_params, coord, ax, **params
166
+ polygon_df,
167
+ panel_params,
168
+ coord,
169
+ ax,
170
+ zorder=group_zorder,
171
+ **params,
160
172
  )
161
173
 
162
174
  if quantiles is not None:
@@ -174,7 +186,12 @@ class geom_violin(geom):
174
186
 
175
187
  # plot quantile segments
176
188
  geom_path.draw_group(
177
- segment_df, panel_params, coord, ax, **params
189
+ segment_df,
190
+ panel_params,
191
+ coord,
192
+ ax,
193
+ zorder=group_zorder,
194
+ **params,
178
195
  )
179
196
 
180
197
 
plotnine/ggplot.py CHANGED
@@ -6,7 +6,15 @@ from io import BytesIO
6
6
  from itertools import chain
7
7
  from pathlib import Path
8
8
  from types import SimpleNamespace as NS
9
- from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, cast
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ Dict,
13
+ Iterable,
14
+ Optional,
15
+ cast,
16
+ overload,
17
+ )
10
18
  from warnings import warn
11
19
 
12
20
  from ._utils import (
@@ -44,9 +52,11 @@ if TYPE_CHECKING:
44
52
  from typing_extensions import Self
45
53
 
46
54
  from plotnine import watermark
55
+ from plotnine._mpl.gridspec import p9GridSpec
47
56
  from plotnine.coords.coord import coord
48
57
  from plotnine.facets.facet import facet
49
58
  from plotnine.layer import layer
59
+ from plotnine.plot_composition import Compose
50
60
  from plotnine.typing import DataLike
51
61
 
52
62
  class PlotAddable(Protocol):
@@ -54,7 +64,7 @@ if TYPE_CHECKING:
54
64
  Object that can be added to a ggplot object
55
65
  """
56
66
 
57
- def __radd__(self, plot: ggplot) -> ggplot:
67
+ def __radd__(self, other: ggplot) -> ggplot:
58
68
  """
59
69
  Add to ggplot object
60
70
 
@@ -95,9 +105,7 @@ class ggplot:
95
105
 
96
106
  figure: Figure
97
107
  axs: list[Axes]
98
- theme: theme
99
- facet: facet
100
- coordinates: coord
108
+ _gridspec: p9GridSpec
101
109
 
102
110
  def __init__(
103
111
  self,
@@ -110,7 +118,7 @@ class ggplot:
110
118
  data, mapping = order_as_data_mapping(data, mapping)
111
119
  self.data = data
112
120
  self.mapping = mapping if mapping is not None else aes()
113
- self.facet = facet_null()
121
+ self.facet: facet = facet_null()
114
122
  self.labels = make_labels(self.mapping)
115
123
  self.layers = Layers()
116
124
  self.guides = guides()
@@ -147,6 +155,11 @@ class ggplot:
147
155
  Users should prefer this method instead of printing or repring
148
156
  the object.
149
157
  """
158
+ # Prevent against any modifications to the users
159
+ # ggplot object. Do the copy here as we may/may not
160
+ # assign a default theme
161
+ self = deepcopy(self)
162
+
150
163
  if is_inline_backend() or is_quarto_environment():
151
164
  # Take charge of the display because we have to make
152
165
  # adjustments for retina output.
@@ -167,18 +180,15 @@ class ggplot:
167
180
  format = get_option("figure_format") or ip.config.InlineBackend.get(
168
181
  "figure_format", "retina"
169
182
  )
170
- save_format = format
171
-
172
183
  # While jpegs can be displayed as retina, we restrict the output
173
184
  # of "retina" to png
174
185
  if format == "retina":
175
186
  self = copy(self)
176
187
  self.theme = self.theme.to_retina()
177
- save_format = "png"
178
188
 
179
- figure_size_px = self.theme._figure_size_px
180
189
  buf = BytesIO()
181
- self.save(buf, format=save_format, verbose=False)
190
+ self.save(buf, "png" if format == "retina" else format, verbose=False)
191
+ figure_size_px = self.theme._figure_size_px
182
192
  display_func = get_display_function(format, figure_size_px)
183
193
  display_func(buf.getvalue())
184
194
 
@@ -193,7 +203,7 @@ class ggplot:
193
203
  new = result.__dict__
194
204
 
195
205
  # don't make a deepcopy of data
196
- shallow = {"data", "figure", "_build_objs"}
206
+ shallow = {"data", "figure", "gs", "_build_objs"}
197
207
  for key, item in old.items():
198
208
  if key in shallow:
199
209
  new[key] = item
@@ -220,9 +230,20 @@ class ggplot:
220
230
  other.__radd__(self)
221
231
  return self
222
232
 
223
- def __add__(self, other: PlotAddable | list[PlotAddable] | None) -> ggplot:
233
+ @overload
234
+ def __add__(
235
+ self, rhs: PlotAddable | list[PlotAddable] | None
236
+ ) -> ggplot: ...
237
+
238
+ @overload
239
+ def __add__(self, rhs: ggplot | Compose) -> Compose: ...
240
+
241
+ def __add__(
242
+ self,
243
+ rhs: PlotAddable | list[PlotAddable] | None | ggplot | Compose,
244
+ ) -> ggplot | Compose:
224
245
  """
225
- Add to ggplot from a list
246
+ Add to ggplot
226
247
 
227
248
  Parameters
228
249
  ----------
@@ -230,8 +251,37 @@ class ggplot:
230
251
  Either an object that knows how to "radd"
231
252
  itself to a ggplot, or a list of such objects.
232
253
  """
254
+ from .plot_composition import ADD, Compose
255
+
256
+ if isinstance(rhs, (ggplot, Compose)):
257
+ return ADD([self, rhs])
258
+
233
259
  self = deepcopy(self)
234
- return self.__iadd__(other)
260
+ return self.__iadd__(rhs)
261
+
262
+ def __or__(self, rhs: ggplot | Compose) -> Compose:
263
+ """
264
+ Compose 2 plots columnwise
265
+ """
266
+ from .plot_composition import OR
267
+
268
+ return OR([self, rhs])
269
+
270
+ def __truediv__(self, rhs: ggplot | Compose) -> Compose:
271
+ """
272
+ Compose 2 plots rowwise
273
+ """
274
+ from .plot_composition import DIV
275
+
276
+ return DIV([self, rhs])
277
+
278
+ def __sub__(self, rhs: ggplot | Compose) -> Compose:
279
+ """
280
+ Compose 2 plots columnwise
281
+ """
282
+ from .plot_composition import OR
283
+
284
+ return OR([self, rhs])
235
285
 
236
286
  def __rrshift__(self, other: DataLike) -> ggplot:
237
287
  """
@@ -248,7 +298,7 @@ class ggplot:
248
298
  raise TypeError(msg.format(type(other)))
249
299
  return self
250
300
 
251
- def draw(self, show: bool = False) -> Figure:
301
+ def draw(self, *, show: bool = False) -> Figure:
252
302
  """
253
303
  Render the complete plot
254
304
 
@@ -262,23 +312,17 @@ class ggplot:
262
312
  :
263
313
  Matplotlib figure
264
314
  """
265
- from ._mpl.layout_engine import PlotnineLayoutEngine
266
-
267
- # Do not draw if drawn already.
268
- # This prevents a needless error when reusing
269
- # figure & axes in the jupyter notebook.
270
- if hasattr(self, "figure"):
271
- return self.figure
315
+ from ._mpl.layout_manager import PlotnineLayoutEngine
272
316
 
273
- # Prevent against any modifications to the users
274
- # ggplot object. Do the copy here as we may/may not
275
- # assign a default theme
276
- self = deepcopy(self)
277
317
  with plot_context(self, show=show):
318
+ if not hasattr(self, "figure"):
319
+ self._create_figure()
320
+ figure = self.figure
321
+
278
322
  self._build()
279
323
 
280
324
  # setup
281
- self.figure, self.axs = self.facet.setup(self)
325
+ self.axs = self.facet.setup(self)
282
326
  self.guides._setup(self)
283
327
  self.theme.setup(self)
284
328
 
@@ -289,51 +333,24 @@ class ggplot:
289
333
  self.guides.draw()
290
334
  self._draw_figure_texts()
291
335
  self._draw_watermarks()
336
+ self._draw_figure_background()
292
337
 
293
338
  # Artist object theming
294
339
  self.theme.apply()
295
- self.figure.set_layout_engine(PlotnineLayoutEngine(self))
340
+ figure.set_layout_engine(PlotnineLayoutEngine(self))
296
341
 
297
- return self.figure
342
+ return figure
298
343
 
299
- def _draw_using_figure(self, figure: Figure, axs: list[Axes]) -> ggplot:
344
+ def _create_figure(self):
300
345
  """
301
- Draw onto already created figure and axes
302
-
303
- This is can be used to draw animation frames,
304
- or inset plots. It is intended to be used
305
- after the key plot has been drawn.
306
-
307
- Parameters
308
- ----------
309
- figure :
310
- Matplotlib figure
311
- axs :
312
- Array of Axes onto which to draw the plots
346
+ Create gridspec for the panels
313
347
  """
314
- from ._mpl.layout_engine import PlotnineLayoutEngine
348
+ import matplotlib.pyplot as plt
315
349
 
316
- self = deepcopy(self)
317
- self.figure = figure
318
- self.axs = axs
319
- with plot_context(self):
320
- self._build()
350
+ from ._mpl.gridspec import p9GridSpec
321
351
 
322
- # setup
323
- self.figure, self.axs = self.facet.setup(self)
324
- self.guides._setup(self)
325
- self.theme.setup(self)
326
-
327
- # drawing
328
- self._draw_layers()
329
- self._draw_breaks_and_labels()
330
- self.guides.draw()
331
-
332
- # artist theming
333
- self.theme.apply()
334
- self.figure.set_layout_engine(PlotnineLayoutEngine(self))
335
-
336
- return self
352
+ self.figure = plt.figure()
353
+ self._gridspec = p9GridSpec(1, 1, self.figure)
337
354
 
338
355
  def _build(self):
339
356
  """
@@ -491,6 +508,7 @@ class ggplot:
491
508
  title = self.labels.get("title", "")
492
509
  subtitle = self.labels.get("subtitle", "")
493
510
  caption = self.labels.get("caption", "")
511
+ tag = self.labels.get("tag", "")
494
512
 
495
513
  # Get the axis labels (default or specified by user)
496
514
  # and let the coordinate modify them e.g. flip
@@ -508,6 +526,9 @@ class ggplot:
508
526
  if caption:
509
527
  targets.plot_caption = figure.text(0, 0, caption)
510
528
 
529
+ if tag:
530
+ targets.plot_tag = figure.text(0, 0, tag)
531
+
511
532
  if labels.x:
512
533
  targets.axis_title_x = figure.text(0, 0, labels.x)
513
534
 
@@ -521,6 +542,14 @@ class ggplot:
521
542
  for wm in self.watermarks:
522
543
  wm.draw(self.figure)
523
544
 
545
+ def _draw_figure_background(self):
546
+ from matplotlib.patches import Rectangle
547
+
548
+ rect = Rectangle((0, 0), 0, 0, facecolor="none", zorder=-1000)
549
+ self.figure.add_artist(rect)
550
+ self._gridspec.patch = rect
551
+ self.theme.targets.plot_background = rect
552
+
524
553
  def _save_filename(self, ext: str) -> Path:
525
554
  """
526
555
  Make a filename for use by the save method
@@ -572,7 +601,7 @@ class ggplot:
572
601
  fig_kwargs: Dict[str, Any] = {"format": format, **kwargs}
573
602
 
574
603
  if limitsize is None:
575
- limitsize = cast(bool, get_option("limitsize"))
604
+ limitsize = cast("bool", get_option("limitsize"))
576
605
 
577
606
  # filename, depends on the object
578
607
  if filename is None:
@@ -598,7 +627,7 @@ class ggplot:
598
627
  raise PlotnineError("You must specify both width and height")
599
628
  else:
600
629
  width, height = cast(
601
- tuple[float, float], self.theme.getp("figure_size")
630
+ "tuple[float, float]", self.theme.getp("figure_size")
602
631
  )
603
632
 
604
633
  if limitsize and (width > 25 or height > 25):
plotnine/guides/guide.py CHANGED
@@ -18,18 +18,18 @@ if TYPE_CHECKING:
18
18
  from typing_extensions import Self
19
19
 
20
20
  from plotnine import aes, guides
21
- from plotnine.layer import Layers
21
+ from plotnine.layer import Layers, layer
22
22
  from plotnine.scales.scale import scale
23
23
  from plotnine.typing import (
24
24
  LegendPosition,
25
25
  Orientation,
26
- SidePosition,
26
+ Side,
27
27
  )
28
28
 
29
29
  from .guides import GuidesElements
30
30
 
31
31
  AlignDict: TypeAlias = dict[
32
- Literal["ha", "va"], dict[tuple[Orientation, SidePosition], str]
32
+ Literal["ha", "va"], dict[tuple[Orientation, Side], str]
33
33
  ]
34
34
 
35
35
 
@@ -76,10 +76,10 @@ class guide(ABC, metaclass=Register):
76
76
  self.plot_layers: Layers
77
77
  self.plot_mapping: aes
78
78
  self._elements_cls = GuideElements
79
- self.elements = cast(GuideElements, None)
79
+ self.elements = cast("GuideElements", None)
80
80
  self.guides_elements: GuidesElements
81
81
 
82
- def legend_aesthetics(self, layer):
82
+ def legend_aesthetics(self, layer: layer):
83
83
  """
84
84
  Return the aesthetics that contribute to the legend
85
85
 
@@ -122,24 +122,21 @@ class guide(ABC, metaclass=Register):
122
122
  @property
123
123
  def _resolved_position_justification(
124
124
  self,
125
- ) -> (
126
- tuple[SidePosition, float]
127
- | tuple[tuple[float, float], tuple[float, float]]
128
- ):
125
+ ) -> tuple[Side, float] | tuple[tuple[float, float], tuple[float, float]]:
129
126
  """
130
127
  Return the final position & justification to draw the guide
131
128
  """
132
129
  pos = self.elements.position
133
130
  just_view = asdict(self.guides_elements.justification)
134
131
  if isinstance(pos, str):
135
- just = cast(float, just_view[pos])
132
+ just = cast("float", just_view[pos])
136
133
  return (pos, just)
137
134
  else:
138
135
  # If no justification is given for an inside legend,
139
136
  # we use the position of the legend
140
137
  if (just := just_view["inside"]) is None:
141
138
  just = pos
142
- just = cast(tuple[float, float], just)
139
+ just = cast("tuple[float, float]", just)
143
140
  return (pos, just)
144
141
 
145
142
  def train(
@@ -191,9 +188,9 @@ class GuideElements:
191
188
  def title(self):
192
189
  ha = self.theme.getp(("legend_title", "ha"))
193
190
  va = self.theme.getp(("legend_title", "va"), "center")
194
- _margin = self.theme.getp(("legend_title", "margin"))
191
+ _margin = self.theme.getp(("legend_title", "margin")).pt
195
192
  _loc = get_opposite_side(self.title_position)[0]
196
- margin = _margin.get_as(_loc, "pt") if _margin else 0
193
+ margin = getattr(_margin, _loc)
197
194
  top_or_bottom = self.title_position in ("top", "bottom")
198
195
  is_blank = self.theme.T.is_blank("legend_title")
199
196
 
@@ -213,17 +210,19 @@ class GuideElements:
213
210
  )
214
211
 
215
212
  @cached_property
216
- def text_position(self) -> SidePosition:
213
+ def text_position(self) -> Side:
217
214
  raise NotImplementedError
218
215
 
219
216
  @cached_property
220
217
  def _text_margin(self) -> float:
221
- _margin = self.theme.getp((f"legend_text_{self.guide_kind}", "margin"))
222
- _loc = get_opposite_side(self.text_position)
223
- return _margin.get_as(_loc[0], "pt") if _margin else 0
218
+ _margin = self.theme.getp(
219
+ (f"legend_text_{self.guide_kind}", "margin")
220
+ ).pt
221
+ _loc = get_opposite_side(self.text_position)[0]
222
+ return getattr(_margin, _loc)
224
223
 
225
224
  @cached_property
226
- def title_position(self) -> SidePosition:
225
+ def title_position(self) -> Side:
227
226
  if not (pos := self.theme.getp("legend_title_position")):
228
227
  pos = "top" if self.is_vertical else "left"
229
228
  return pos
@@ -242,7 +241,7 @@ class GuideElements:
242
241
  return direction
243
242
 
244
243
  @cached_property
245
- def position(self) -> SidePosition | tuple[float, float]:
244
+ def position(self) -> Side | tuple[float, float]:
246
245
  if (guide_pos := self.guide.position) == "inside":
247
246
  guide_pos = self._position_inside
248
247
 
@@ -254,7 +253,7 @@ class GuideElements:
254
253
  return pos
255
254
 
256
255
  @cached_property
257
- def _position_inside(self) -> SidePosition | tuple[float, float]:
256
+ def _position_inside(self) -> Side | tuple[float, float]:
258
257
  pos = self.theme.getp("legend_position_inside")
259
258
  if isinstance(pos, tuple):
260
259
  return pos
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
27
27
 
28
28
  from plotnine import theme
29
29
  from plotnine.scales.scale import scale
30
- from plotnine.typing import SidePosition
30
+ from plotnine.typing import Side
31
31
 
32
32
 
33
33
  @dataclass
@@ -75,8 +75,8 @@ class guide_colorbar(guide):
75
75
  self.nbin = 300 # if self.display == "gradient" else 300
76
76
 
77
77
  def train(self, scale: scale, aesthetic=None):
78
- self.nbin = cast(int, self.nbin)
79
- self.title = cast(str, self.title)
78
+ self.nbin = cast("int", self.nbin)
79
+ self.title = cast("str", self.title)
80
80
 
81
81
  if not isinstance(scale, scale_continuous):
82
82
  warn("colorbar guide needs continuous scales", PlotnineWarning)
@@ -213,7 +213,7 @@ class guide_colorbar(guide):
213
213
  auxbox = DPICorAuxTransformBox(IdentityTransform())
214
214
 
215
215
  # title
216
- title = cast(str, self.title)
216
+ title = cast("str", self.title)
217
217
  props = {"ha": elements.title.ha, "va": elements.title.va}
218
218
  title_box = TextArea(title, textprops=props)
219
219
  targets.legend_title = title_box._text # type: ignore
@@ -242,7 +242,7 @@ class guide_colorbar(guide):
242
242
  targets.legend_frame = frame
243
243
 
244
244
  # title + colorbar(with labels)
245
- lookup: dict[SidePosition, tuple[type[PackerBase], slice]] = {
245
+ lookup: dict[Side, tuple[type[PackerBase], slice]] = {
246
246
  "right": (HPacker, reverse),
247
247
  "left": (HPacker, obverse),
248
248
  "bottom": (VPacker, reverse),
@@ -495,7 +495,7 @@ class GuideElementsColorbar(GuideElements):
495
495
  )
496
496
 
497
497
  @cached_property
498
- def text_position(self) -> SidePosition:
498
+ def text_position(self) -> Side:
499
499
  if not (position := self.theme.getp("legend_text_position")):
500
500
  position = "right" if self.is_vertical else "bottom"
501
501