plotnine 0.14.5__py3-none-any.whl → 0.15.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. plotnine/__init__.py +31 -37
  2. plotnine/_mpl/gridspec.py +265 -0
  3. plotnine/_mpl/layout_manager/__init__.py +6 -0
  4. plotnine/_mpl/layout_manager/_engine.py +87 -0
  5. plotnine/_mpl/layout_manager/_layout_items.py +916 -0
  6. plotnine/_mpl/layout_manager/_layout_tree.py +625 -0
  7. plotnine/_mpl/layout_manager/_spaces.py +1007 -0
  8. plotnine/_mpl/patches.py +1 -1
  9. plotnine/_mpl/text.py +59 -24
  10. plotnine/_mpl/utils.py +78 -10
  11. plotnine/_utils/__init__.py +5 -5
  12. plotnine/_utils/dev.py +45 -27
  13. plotnine/animation.py +1 -1
  14. plotnine/coords/coord_trans.py +1 -1
  15. plotnine/data/__init__.py +12 -8
  16. plotnine/doctools.py +1 -1
  17. plotnine/facets/facet.py +30 -39
  18. plotnine/facets/facet_grid.py +14 -6
  19. plotnine/facets/facet_wrap.py +3 -5
  20. plotnine/facets/strips.py +7 -9
  21. plotnine/geoms/geom_crossbar.py +2 -3
  22. plotnine/geoms/geom_path.py +1 -1
  23. plotnine/ggplot.py +94 -65
  24. plotnine/guides/guide.py +12 -10
  25. plotnine/guides/guide_colorbar.py +3 -3
  26. plotnine/guides/guide_legend.py +12 -13
  27. plotnine/guides/guides.py +3 -3
  28. plotnine/iapi.py +5 -2
  29. plotnine/labels.py +5 -0
  30. plotnine/mapping/aes.py +4 -3
  31. plotnine/options.py +14 -7
  32. plotnine/plot_composition/__init__.py +10 -0
  33. plotnine/plot_composition/_compose.py +436 -0
  34. plotnine/plot_composition/_plotspec.py +50 -0
  35. plotnine/plot_composition/_spacer.py +32 -0
  36. plotnine/positions/position_dodge.py +1 -1
  37. plotnine/positions/position_dodge2.py +1 -1
  38. plotnine/positions/position_stack.py +1 -2
  39. plotnine/qplot.py +1 -2
  40. plotnine/scales/__init__.py +0 -6
  41. plotnine/scales/scale.py +1 -1
  42. plotnine/stats/binning.py +1 -1
  43. plotnine/stats/smoothers.py +3 -5
  44. plotnine/stats/stat_density.py +1 -1
  45. plotnine/stats/stat_qq_line.py +1 -1
  46. plotnine/stats/stat_sina.py +1 -1
  47. plotnine/themes/elements/__init__.py +2 -0
  48. plotnine/themes/elements/element_text.py +35 -24
  49. plotnine/themes/elements/margin.py +73 -60
  50. plotnine/themes/targets.py +3 -1
  51. plotnine/themes/theme.py +13 -7
  52. plotnine/themes/theme_gray.py +28 -31
  53. plotnine/themes/theme_matplotlib.py +25 -28
  54. plotnine/themes/theme_seaborn.py +31 -34
  55. plotnine/themes/theme_void.py +17 -26
  56. plotnine/themes/themeable.py +290 -157
  57. {plotnine-0.14.5.dist-info → plotnine-0.15.0.dev2.dist-info}/METADATA +4 -3
  58. {plotnine-0.14.5.dist-info → plotnine-0.15.0.dev2.dist-info}/RECORD +61 -54
  59. {plotnine-0.14.5.dist-info → plotnine-0.15.0.dev2.dist-info}/WHEEL +1 -1
  60. plotnine/_mpl/_plot_side_space.py +0 -888
  61. plotnine/_mpl/_plotnine_tight_layout.py +0 -293
  62. plotnine/_mpl/layout_engine.py +0 -110
  63. {plotnine-0.14.5.dist-info → plotnine-0.15.0.dev2.dist-info/licenses}/LICENSE +0 -0
  64. {plotnine-0.14.5.dist-info → plotnine-0.15.0.dev2.dist-info}/top_level.txt +0 -0
plotnine/_mpl/patches.py CHANGED
@@ -49,7 +49,7 @@ class StripTextPatch(FancyBboxPatch):
49
49
  return
50
50
 
51
51
  text = self.text
52
- posx, posy = text.get_transform().transform((text._x, text._y))
52
+ posx, posy = text.get_transform().transform(text.get_position())
53
53
  x, y, w, h = _get_textbox(text, renderer)
54
54
 
55
55
  self.set_bounds(0.0, 0.0, w, h)
plotnine/_mpl/text.py CHANGED
@@ -1,16 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- import typing
3
+ from typing import TYPE_CHECKING
4
4
 
5
5
  from matplotlib.text import Text
6
6
 
7
7
  from .patches import StripTextPatch
8
- from .utils import bbox_in_axes_space
8
+ from .utils import bbox_in_axes_space, rel_position
9
9
 
10
- if typing.TYPE_CHECKING:
10
+ if TYPE_CHECKING:
11
11
  from matplotlib.backend_bases import RendererBase
12
12
 
13
13
  from plotnine.iapi import strip_draw_info
14
+ from plotnine.typing import HorizontalJustification, VerticalJustification
14
15
 
15
16
 
16
17
  class StripText(Text):
@@ -23,8 +24,6 @@ class StripText(Text):
23
24
 
24
25
  def __init__(self, info: strip_draw_info):
25
26
  kwargs = {
26
- "ha": info.ha,
27
- "va": info.va,
28
27
  "rotation": info.rotation,
29
28
  "transform": info.ax.transAxes,
30
29
  "clip_on": False,
@@ -40,38 +39,74 @@ class StripText(Text):
40
39
  self.draw_info = info
41
40
  self.patch = StripTextPatch(self)
42
41
 
42
+ # TODO: Move these _justify methods to the layout manager
43
+ # We need to first make sure that the patch has the final size during
44
+ # layout computation. Right now, the final size is calculated during
45
+ # draw (in these justify methods)
46
+ def _justify_horizontally(self, renderer):
47
+ """
48
+ Justify the text along the strip_background
49
+ """
50
+ info = self.draw_info
51
+ lookup: dict[HorizontalJustification, float] = {
52
+ "left": 0.0,
53
+ "center": 0.5,
54
+ "right": 1.0,
55
+ }
56
+ rel = lookup.get(info.ha, 0.5) if isinstance(info.ha, str) else info.ha
57
+ patch_bbox = bbox_in_axes_space(self.patch, info.ax, renderer)
58
+ text_bbox = bbox_in_axes_space(self, info.ax, renderer)
59
+ l, b, w, h = info.x, info.y, info.box_width, patch_bbox.height
60
+ b = b + patch_bbox.height * info.strip_align
61
+ x = rel_position(rel, text_bbox.width, patch_bbox.x0, patch_bbox.x1)
62
+ y = b + h / 2
63
+ self.set_horizontalalignment("left")
64
+ self.patch.set_bounds(l, b, w, h)
65
+ self.set_position((x, y))
66
+
67
+ def _justify_vertically(self, renderer):
68
+ """
69
+ Justify the text along the strip_background
70
+ """
71
+ # Note that the strip text & background and horizontal but
72
+ # rotated to appear vertical. So we really are still justifying
73
+ # horizontally.
74
+ info = self.draw_info
75
+ lookup: dict[VerticalJustification, float] = {
76
+ "bottom": 0.0,
77
+ "center": 0.5,
78
+ "top": 1.0,
79
+ }
80
+ rel = lookup.get(info.va, 0.5) if isinstance(info.va, str) else info.va
81
+ patch_bbox = bbox_in_axes_space(self.patch, info.ax, renderer)
82
+ text_bbox = bbox_in_axes_space(self, info.ax, renderer)
83
+ l, b, w, h = info.x, info.y, patch_bbox.width, info.box_height
84
+ l = l + patch_bbox.width * info.strip_align
85
+ x = l + w / 2
86
+ y = rel_position(rel, text_bbox.height, patch_bbox.y0, patch_bbox.y1)
87
+ self.set_horizontalalignment("right") # 90CW right means bottom
88
+ self.patch.set_bounds(l, b, w, h)
89
+ self.set_position((x, y))
90
+
43
91
  def draw(self, renderer: RendererBase):
44
92
  if not self.get_visible():
45
93
  return
46
94
 
47
- info = self.draw_info
48
- # "fill up" spatch to contain the text
95
+ # expand strip_text patch to contain the text
49
96
  self.patch.update_position_size(renderer)
50
97
 
51
- # Get bbox of spatch in transAxes space
52
- patch_bbox = bbox_in_axes_space(self.patch, info.ax, renderer)
53
-
54
98
  # Align patch across the edge of the panel
55
- if info.position == "top":
56
- l, b, w, h = info.x, info.y, info.box_width, patch_bbox.height
57
- b = b + patch_bbox.height * info.strip_align
99
+ if self.draw_info.position == "top":
100
+ self._justify_horizontally(renderer)
58
101
  else: # "right"
59
- l, b, w, h = info.x, info.y, patch_bbox.width, info.box_height
60
- l = l + patch_bbox.width * info.strip_align
102
+ self._justify_vertically(renderer)
61
103
 
62
- self.patch.set_bounds(l, b, w, h)
63
- self.patch.set_transform(info.ax.transAxes)
104
+ self.patch.set_transform(self.draw_info.ax.transAxes)
64
105
  self.patch.set_mutation_scale(0)
65
106
 
66
107
  # Put text in center of patch
67
- self._x = l + w / 2
68
- self._y = b + h / 2
69
-
70
- # "anchor" aligns before rotation so the right-strip get properly
71
- # centered text
72
108
  self.set_rotation_mode("anchor")
73
- self.set_horizontalalignment("center") # right-strip
74
- self.set_verticalalignment("center_baseline") # top-strip
109
+ self.set_verticalalignment("center_baseline")
75
110
 
76
111
  # Draw spatch
77
112
  self.patch.draw(renderer)
plotnine/_mpl/utils.py CHANGED
@@ -1,18 +1,21 @@
1
1
  from __future__ import annotations
2
2
 
3
- import typing
3
+ from typing import TYPE_CHECKING
4
4
 
5
5
  from matplotlib.transforms import Affine2D, Bbox
6
6
 
7
7
  from .transforms import ZEROS_BBOX
8
8
 
9
- if typing.TYPE_CHECKING:
9
+ if TYPE_CHECKING:
10
10
  from matplotlib.artist import Artist
11
11
  from matplotlib.axes import Axes
12
12
  from matplotlib.backend_bases import RendererBase
13
13
  from matplotlib.figure import Figure
14
+ from matplotlib.gridspec import SubplotSpec
14
15
  from matplotlib.transforms import Transform
15
16
 
17
+ from .gridspec import p9GridSpec
18
+
16
19
 
17
20
  def bbox_in_figure_space(
18
21
  artist: Artist, fig: Figure, renderer: RendererBase
@@ -51,28 +54,93 @@ def pts_in_figure_space(fig: Figure, pts: float) -> float:
51
54
  return fig.transFigure.inverted().transform([0, pts])[1]
52
55
 
53
56
 
54
- def get_transPanels(fig: Figure) -> Transform:
57
+ def get_transPanels(fig: Figure, gs: p9GridSpec) -> Transform:
55
58
  """
56
59
  Coordinate system of the Panels (facets) area
57
60
 
58
61
  (0, 0) is the bottom-left of the bottom-left panel and
59
62
  (1, 1) is the top right of the top-right panel.
60
63
 
61
- The subplot parameters must be set before calling this function.
62
- i.e. fig.subplots_adjust should have been called.
64
+ The gridspec parameters must be set before calling this function.
65
+ i.e. gs.update have been called.
63
66
  """
64
- # Contains the layout information from which the panel area
65
- # is derived
66
- params = fig.subplotpars
67
+ # The position of the panels area in figure coordinates
68
+ params = gs.get_subplot_params(fig)
67
69
 
68
70
  # Figure width & height in display coordinates
69
71
  W, H = fig.bbox.width, fig.bbox.height
70
72
 
71
73
  # 1. The panels occupy space that is smaller than the figure
72
74
  # 2. That space is contained within the figure
73
- # We create a transform that represent these separable aspects
74
- # (but order matters), and use to transform transFigure
75
+ # We create a transform that represents these separable aspects
76
+ # (but order matters), and use it to transform transFigure
75
77
  sx, sy = params.right - params.left, params.top - params.bottom
76
78
  dx, dy = params.left * W, params.bottom * H
77
79
  transFiguretoPanels = Affine2D().scale(sx, sy).translate(dx, dy)
78
80
  return fig.transFigure + transFiguretoPanels
81
+
82
+
83
+ def rel_position(rel: float, length: float, low: float, high: float) -> float:
84
+ """
85
+ Relatively position an object of a given length between two position
86
+
87
+ Parameters
88
+ ----------
89
+ rel:
90
+ Relative position of the object between the limits.
91
+ length:
92
+ Length of the object
93
+ low:
94
+ Lower limit position
95
+ high:
96
+ Upper limit position
97
+ """
98
+ return low * (1 - rel) + (high - length) * rel
99
+
100
+
101
+ def get_subplotspecs(axs: list[Axes]) -> list[SubplotSpec]:
102
+ """
103
+ Return the SubplotSpecs of the given axes
104
+
105
+ Parameters
106
+ ----------
107
+ axs:
108
+ List of axes
109
+
110
+ Notes
111
+ -----
112
+ This functions returns the innermost subplotspec and it expects
113
+ every axes object to have one.
114
+ """
115
+ subplotspecs: list[SubplotSpec] = []
116
+ for ax in axs:
117
+ if not (subplotspec := ax.get_subplotspec()):
118
+ raise ValueError("Axes has no suplotspec")
119
+ subplotspecs.append(subplotspec)
120
+ return subplotspecs
121
+
122
+
123
+ def draw_gridspec(gs: p9GridSpec, color="black", **kwargs):
124
+ """
125
+ A debug function to draw a rectangle around the gridspec
126
+ """
127
+ draw_bbox(gs.bbox_relative, gs.figure, color, **kwargs)
128
+
129
+
130
+ def draw_bbox(bbox, figure, color="black", **kwargs):
131
+ """
132
+ A debug function to draw a rectangle around a bounding bbox
133
+ """
134
+ from matplotlib.patches import Rectangle
135
+
136
+ figure.add_artist(
137
+ Rectangle(
138
+ xy=bbox.p0,
139
+ width=bbox.width,
140
+ height=bbox.height,
141
+ edgecolor=color,
142
+ fill=False,
143
+ clip_on=False,
144
+ **kwargs,
145
+ )
146
+ )
plotnine/_utils/__init__.py CHANGED
@@ -299,7 +299,7 @@ def ninteraction(df: pd.DataFrame, drop: bool = False) -> list[int]:
299
299
  return _id_var(df[df.columns[0]], drop)
300
300
 
301
301
  # Calculate individual ids
302
- ids = df.apply(_id_var, axis=0)
302
+ ids = df.apply(_id_var, axis=0, drop=drop)
303
303
  ids = ids.reindex(columns=list(reversed(ids.columns)))
304
304
 
305
305
  # Calculate dimensions
@@ -310,8 +310,8 @@ def ninteraction(df: pd.DataFrame, drop: bool = False) -> list[int]:
310
310
 
311
311
  combs = np.array(np.hstack([1, np.cumprod(ndistinct[:-1])]))
312
312
  mat = np.array(ids)
313
- res = (mat - 1) @ combs.T + 1
314
- res = np.array(res).flatten().tolist()
313
+ _res = (mat - 1) @ combs.T + 1
314
+ res: list[int] = np.array(_res).flatten().tolist()
315
315
 
316
316
  if drop:
317
317
  return _id_var(res, drop)
@@ -511,7 +511,7 @@ def remove_missing(
511
511
  if finite:
512
512
  lst = [np.inf, -np.inf]
513
513
  to_replace = {v: lst for v in vars}
514
- data.replace(to_replace, np.nan, inplace=True)
514
+ data.replace(to_replace, np.nan, inplace=True) # pyright: ignore[reportArgumentType,reportCallIssue]
515
515
  txt = "non-finite"
516
516
  else:
517
517
  txt = "missing"
@@ -604,7 +604,7 @@ def to_rgba(
604
604
  return c
605
605
 
606
606
  if is_iterable(colors):
607
- colors = cast(Sequence["ColorType"], colors)
607
+ colors = cast("Sequence[ColorType]", colors)
608
608
 
609
609
  if all(no_color(c) for c in colors):
610
610
  return "none"
plotnine/_utils/dev.py CHANGED
@@ -1,13 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
4
3
 
5
-
6
- def get_plotnine_all(use_clipboard=True) -> Optional[str]:
4
+ def get_plotnine_all() -> str:
7
5
  """
8
6
  Generate package level * (star) imports for plotnine
9
-
10
- The contents of __all__ in plotnine/__init__.py
11
7
  """
12
8
  from importlib import import_module
13
9
 
@@ -28,32 +24,54 @@ def get_plotnine_all(use_clipboard=True) -> Optional[str]:
28
24
  "watermark",
29
25
  )
30
26
 
31
- def get_all_from_module(name, quote=False):
27
+ def get_all_from_module(name):
32
28
  """
33
29
  Module level imports
34
30
  """
35
31
  qname = f"plotnine.{name}"
36
32
  m = import_module(qname)
37
- fmt = ('"{}",' if quote else "{},").format
38
- return "\n ".join(fmt(x) for x in sorted(m.__all__))
39
33
 
40
- _imports = "\n".join(
41
- f"from .{name} import (\n {get_all_from_module(name)}\n)"
42
- for name in modules
43
- )
44
- _all = "\n".join(
45
- [
46
- "__all__ = (",
47
- "\n".join(
48
- f" {get_all_from_module(name, True)}" for name in modules
49
- ),
50
- ")",
51
- ]
52
- )
53
- content = f"{_imports}\n\n{_all}"
54
- if use_clipboard:
55
- from pandas.io import clipboard
34
+ return sorted(m.__all__)
35
+
36
+ imports = []
37
+ all_funcs = []
38
+
39
+ for name in modules:
40
+ funcs = get_all_from_module(name)
41
+ import_funcs = "\n ".join(f"{x}," for x in funcs)
42
+ imports.append(f"from .{name} import (\n {import_funcs}\n)")
43
+ all_funcs.extend(funcs)
44
+
45
+ all_funcs = [f' "{x}",' for x in sorted(all_funcs)]
46
+
47
+ _imports = "\n".join(imports)
48
+ _all = "__all__ = (\n" + "\n".join(all_funcs) + "\n)"
49
+
50
+ return f"{_imports}\n\n{_all}"
51
+
52
+
53
+ def get_init_py() -> str:
54
+ """
55
+ Generate plotnine/__init__.py
56
+ """
57
+
58
+ preamble: str = """# Do not edit this file by hand.
59
+ #
60
+ # Generate it using:
61
+ #
62
+ # $ python -c 'from plotnine._utils import dev; print(dev.get_init_py())'
63
+
64
+ from importlib.metadata import PackageNotFoundError, version
65
+
66
+ try:
67
+ __version__ = version("plotnine")
68
+ except PackageNotFoundError:
69
+ # package is not installed
70
+ pass
71
+ finally:
72
+ del version
73
+ del PackageNotFoundError
74
+
75
+ """
56
76
 
57
- clipboard.copy(content) # pyright: ignore
58
- else:
59
- return content
77
+ return preamble + get_plotnine_all()
plotnine/animation.py CHANGED
@@ -233,6 +233,6 @@ class PlotnineAnimation(ArtistAnimation):
233
233
  plot.axs = axs
234
234
  with plot_context(plot):
235
235
  plot._build()
236
- plot.figure, plot.axs = plot.facet.setup(plot)
236
+ plot.axs = plot.facet.setup(plot)
237
237
  plot._draw_layers()
238
238
  return plot
plotnine/coords/coord_trans.py CHANGED
@@ -117,7 +117,7 @@ class coord_trans(coord):
117
117
  range=ranges.range,
118
118
  )
119
119
  sv.range = tuple(sorted(ranges.range_coord)) # type: ignore
120
- breaks = cast(tuple[float, float], sv.breaks)
120
+ breaks = cast("tuple[float, float]", sv.breaks)
121
121
  sv.breaks = transform_value(trans, breaks)
122
122
  sv.minor_breaks = transform_value(trans, sv.minor_breaks)
123
123
  return sv
plotnine/data/__init__.py CHANGED
@@ -81,20 +81,24 @@ def _process_categories():
81
81
  """
82
82
  Set columns in some of the dataframes to categoricals
83
83
  """
84
- global diamonds, midwest, mpg, msleep, penguins
84
+ global diamonds, penguins
85
85
  diamonds = _ordered_categories(
86
86
  diamonds,
87
87
  {
88
- "cut": "Fair, Good, Very Good, Premium, Ideal".split(", "),
89
- "clarity": "I1 SI2 SI1 VS2 VS1 VVS2 VVS1 IF".split(),
88
+ "cut": ["Fair", "Good", "Very Good", "Premium", "Ideal"],
89
+ "clarity": [
90
+ "I1",
91
+ "SI2",
92
+ "SI1",
93
+ "VS2",
94
+ "VS1",
95
+ "VVS2",
96
+ "VVS1",
97
+ "IF",
98
+ ],
90
99
  "color": list("DEFGHIJ"),
91
100
  },
92
101
  )
93
- mpg = _unordered_categories(
94
- mpg, "manufacturer model trans fl drv class".split()
95
- )
96
- midwest = _unordered_categories(midwest, ["category"])
97
- msleep = _unordered_categories(msleep, ["vore", "conservation"])
98
102
  penguins = _unordered_categories(penguins, ["species", "island", "sex"])
99
103
 
100
104
 
plotnine/doctools.py CHANGED
@@ -188,7 +188,7 @@ def dict_to_table(header: tuple[str, str], contents: dict[str, str]) -> str:
188
188
  fill `None`
189
189
  """
190
190
  rows = [
191
- (name, value if value == "" else f"`{value!r}`" "{.py}")
191
+ (name, value if value == "" else f"`{value!r}`{{.py}}")
192
192
  for name, value in contents.items()
193
193
  ]
194
194
  return table_function(rows, headers=header, tablefmt="grid")
plotnine/facets/facet.py CHANGED
@@ -20,9 +20,9 @@ if typing.TYPE_CHECKING:
20
20
  import numpy.typing as npt
21
21
  from matplotlib.axes import Axes
22
22
  from matplotlib.figure import Figure
23
- from matplotlib.gridspec import GridSpec
24
23
 
25
24
  from plotnine import ggplot, theme
25
+ from plotnine._mpl.gridspec import p9GridSpec
26
26
  from plotnine.coords.coord import coord
27
27
  from plotnine.facets.labelling import CanBeStripLabellingFunc
28
28
  from plotnine.facets.layout import Layout
@@ -93,6 +93,7 @@ class facet:
93
93
 
94
94
  # Axes
95
95
  axs: list[Axes]
96
+ _panels_gridspec: p9GridSpec
96
97
 
97
98
  # ggplot object that the facet belongs to
98
99
  plot: ggplot
@@ -100,8 +101,6 @@ class facet:
100
101
  # Facet strips
101
102
  strips: Strips
102
103
 
103
- grid_spec: GridSpec
104
-
105
104
  # The plot environment
106
105
  environment: Environment
107
106
 
@@ -121,6 +120,11 @@ class facet:
121
120
  self.as_table = as_table
122
121
  self.drop = drop
123
122
  self.dir = dir
123
+ allowed_scales = ["fixed", "free", "free_x", "free_y"]
124
+ if scales not in allowed_scales:
125
+ raise ValueError(
126
+ "Argument `scales` must be one of {allowed_scales}."
127
+ )
124
128
  self.free = {
125
129
  "x": scales in ("free_x", "free"),
126
130
  "y": scales in ("free_y", "free"),
@@ -137,17 +141,18 @@ class facet:
137
141
  def setup(self, plot: ggplot):
138
142
  self.plot = plot
139
143
  self.layout = plot.layout
144
+ self.figure = plot.figure
140
145
 
141
- if hasattr(plot, "figure"):
142
- self.figure, self.axs = plot.figure, plot.axs
146
+ if hasattr(plot, "axs"):
147
+ self.axs = plot.axs
143
148
  else:
144
- self.figure, self.axs = self.make_figure()
149
+ self.axs = self._make_axes()
145
150
 
146
151
  self.coordinates = plot.coordinates
147
152
  self.theme = plot.theme
148
153
  self.layout.axs = self.axs
149
154
  self.strips = Strips.from_facet(self)
150
- return self.figure, self.axs
155
+ return self.axs
151
156
 
152
157
  def setup_data(self, data: list[pd.DataFrame]) -> list[pd.DataFrame]:
153
158
  """
@@ -346,17 +351,8 @@ class facet:
346
351
  ax.xaxis.set_major_formatter(MyFixedFormatter(panel_params.x.labels))
347
352
  ax.yaxis.set_major_formatter(MyFixedFormatter(panel_params.y.labels))
348
353
 
349
- pad_x = (
350
- margin.get_as("t", "pt")
351
- if (margin := theme.getp(("axis_text_x", "margin")))
352
- else 0
353
- )
354
-
355
- pad_y = (
356
- margin.get_as("r", "pt")
357
- if (margin := theme.getp(("axis_text_y", "margin")))
358
- else 0
359
- )
354
+ pad_x = theme.get_margin("axis_text_x").pt.t
355
+ pad_y = theme.get_margin("axis_text_y").pt.r
360
356
 
361
357
  ax.tick_params(axis="x", which="major", pad=pad_x)
362
358
  ax.tick_params(axis="y", which="major", pad=pad_y)
@@ -372,7 +368,7 @@ class facet:
372
368
  new = result.__dict__
373
369
 
374
370
  # don't make a deepcopy of the figure & the axes
375
- shallow = {"figure", "axs", "first_ax", "last_ax"}
371
+ shallow = {"axs", "first_ax", "last_ax"}
376
372
  for key, item in old.items():
377
373
  if key in shallow:
378
374
  new[key] = item
@@ -382,35 +378,31 @@ class facet:
382
378
 
383
379
  return result
384
380
 
385
- def _make_figure(self) -> tuple[Figure, GridSpec]:
381
+ def _get_panels_gridspec(self) -> p9GridSpec:
386
382
  """
387
- Create figure & gridspec
383
+ Create gridspec for the panels
388
384
  """
389
- import matplotlib.pyplot as plt
390
- from matplotlib.gridspec import GridSpec
385
+ from plotnine._mpl.gridspec import p9GridSpec
391
386
 
392
- return plt.figure(), GridSpec(self.nrow, self.ncol)
387
+ return p9GridSpec(
388
+ self.nrow, self.ncol, self.figure, nest_into=self.plot._gridspec[0]
389
+ )
393
390
 
394
- def make_figure(self) -> tuple[Figure, list[Axes]]:
391
+ def _make_axes(self) -> list[Axes]:
395
392
  """
396
- Create and return Matplotlib figure and subplot axes
393
+ Create and return subplot axes
397
394
  """
398
395
  num_panels = len(self.layout.layout)
399
396
  axsarr = np.empty((self.nrow, self.ncol), dtype=object)
400
397
 
401
- # Create figure & gridspec
402
- figure, gs = self._make_figure()
403
- self.grid_spec = gs
398
+ self._panels_gridspec = self._get_panels_gridspec()
404
399
 
405
400
  # Create axes
406
401
  it = itertools.product(range(self.nrow), range(self.ncol))
407
402
  for i, (row, col) in enumerate(it):
408
- axsarr[row, col] = figure.add_subplot(gs[i])
409
-
410
- # axsarr = np.array([
411
- # figure.add_subplot(gs[i])
412
- # for i in range(self.nrow * self.ncol)
413
- # ]).reshape((self.nrow, self.ncol))
403
+ axsarr[row, col] = self.figure.add_subplot(
404
+ self._panels_gridspec[i]
405
+ )
414
406
 
415
407
  # Rearrange axes
416
408
  # They are ordered to match the positions in the layout table
@@ -429,9 +421,9 @@ class facet:
429
421
 
430
422
  # Delete unused axes
431
423
  for ax in axs[num_panels:]:
432
- figure.delaxes(ax)
424
+ self.figure.delaxes(ax)
433
425
  axs = axs[:num_panels]
434
- return figure, list(axs)
426
+ return list(axs)
435
427
 
436
428
  def _aspect_ratio(self) -> Optional[float]:
437
429
  """
@@ -477,8 +469,7 @@ def combine_vars(
477
469
  has_all = [x.shape[1] == len(vars) for x in values]
478
470
  if not any(has_all):
479
471
  raise PlotnineError(
480
- "At least one layer must contain all variables "
481
- "used for facetting"
472
+ "At least one layer must contain all variables used for facetting"
482
473
  )
483
474
  base = pd.concat([x for i, x in enumerate(values) if has_all[i]], axis=0)
484
475
  base = base.drop_duplicates()
plotnine/facets/facet_grid.py CHANGED
@@ -107,9 +107,11 @@ class facet_grid(facet):
107
107
  self.space = space
108
108
  self.margins = margins
109
109
 
110
- def _make_figure(self):
111
- import matplotlib.pyplot as plt
112
- from matplotlib.gridspec import GridSpec
110
+ def _get_panels_gridspec(self):
111
+ """
112
+ Create gridspec for the panels
113
+ """
114
+ from plotnine._mpl.gridspec import p9GridSpec
113
115
 
114
116
  layout = self.layout
115
117
  space = self.space
@@ -155,7 +157,13 @@ class facet_grid(facet):
155
157
  ratios["width_ratios"] = self.space.get("x")
156
158
  ratios["height_ratios"] = self.space.get("y")
157
159
 
158
- return plt.figure(), GridSpec(self.nrow, self.ncol, **ratios)
160
+ return p9GridSpec(
161
+ self.nrow,
162
+ self.ncol,
163
+ self.figure,
164
+ nest_into=self.plot._gridspec[0],
165
+ **ratios,
166
+ )
159
167
 
160
168
  def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame:
161
169
  if not self.rows and not self.cols:
@@ -302,7 +310,7 @@ def parse_grid_facets_old(
302
310
  "((var1, var2), (var3, var4))",
303
311
  ]
304
312
  error_msg_s = (
305
- "Valid sequences for specifying 'facets' look like" f" {valid_seqs}"
313
+ f"Valid sequences for specifying 'facets' look like {valid_seqs}"
306
314
  )
307
315
 
308
316
  valid_forms = [
@@ -314,7 +322,7 @@ def parse_grid_facets_old(
314
322
  ". ~ func(var1) + func(var2)",
315
323
  ". ~ func(var1+var3) + func(var2)",
316
324
  ] + valid_seqs
317
- error_msg_f = "Valid formula for 'facet_grid' look like" f" {valid_forms}"
325
+ error_msg_f = f"Valid formula for 'facet_grid' look like {valid_forms}"
318
326
 
319
327
  if not isinstance(facets, str):
320
328
  if len(facets) == 1: