plotnine 0.15.0a1__py3-none-any.whl → 0.15.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86):
  1. plotnine/_mpl/layout_manager/_layout_items.py +85 -23
  2. plotnine/_mpl/layout_manager/_layout_tree.py +16 -6
  3. plotnine/_mpl/layout_manager/_spaces.py +5 -5
  4. plotnine/_mpl/patches.py +70 -34
  5. plotnine/_mpl/text.py +150 -63
  6. plotnine/_mpl/utils.py +1 -1
  7. plotnine/_utils/__init__.py +30 -2
  8. plotnine/doctools.py +1 -1
  9. plotnine/facets/strips.py +17 -28
  10. plotnine/geoms/annotation_logticks.py +7 -8
  11. plotnine/geoms/annotation_stripes.py +6 -6
  12. plotnine/geoms/geom.py +20 -8
  13. plotnine/geoms/geom_abline.py +3 -2
  14. plotnine/geoms/geom_blank.py +0 -3
  15. plotnine/geoms/geom_boxplot.py +4 -4
  16. plotnine/geoms/geom_crossbar.py +3 -3
  17. plotnine/geoms/geom_dotplot.py +1 -1
  18. plotnine/geoms/geom_errorbar.py +2 -2
  19. plotnine/geoms/geom_errorbarh.py +2 -2
  20. plotnine/geoms/geom_hline.py +3 -2
  21. plotnine/geoms/geom_linerange.py +2 -2
  22. plotnine/geoms/geom_map.py +3 -3
  23. plotnine/geoms/geom_path.py +10 -11
  24. plotnine/geoms/geom_point.py +4 -5
  25. plotnine/geoms/geom_pointrange.py +3 -5
  26. plotnine/geoms/geom_polygon.py +2 -3
  27. plotnine/geoms/geom_raster.py +4 -5
  28. plotnine/geoms/geom_rect.py +3 -4
  29. plotnine/geoms/geom_ribbon.py +7 -7
  30. plotnine/geoms/geom_rug.py +1 -1
  31. plotnine/geoms/geom_segment.py +2 -2
  32. plotnine/geoms/geom_smooth.py +3 -3
  33. plotnine/geoms/geom_step.py +2 -2
  34. plotnine/geoms/geom_text.py +2 -3
  35. plotnine/geoms/geom_violin.py +4 -5
  36. plotnine/geoms/geom_vline.py +3 -2
  37. plotnine/guides/guides.py +1 -1
  38. plotnine/helpers.py +49 -0
  39. plotnine/iapi.py +28 -5
  40. plotnine/layer.py +18 -12
  41. plotnine/mapping/_eval_environment.py +1 -1
  42. plotnine/scales/scale_color.py +46 -14
  43. plotnine/scales/scale_continuous.py +5 -4
  44. plotnine/scales/scale_datetime.py +28 -14
  45. plotnine/scales/scale_discrete.py +2 -2
  46. plotnine/scales/scale_identity.py +10 -2
  47. plotnine/scales/scale_xy.py +2 -2
  48. plotnine/stats/binning.py +4 -1
  49. plotnine/stats/smoothers.py +19 -19
  50. plotnine/stats/stat.py +15 -25
  51. plotnine/stats/stat_bin.py +2 -5
  52. plotnine/stats/stat_bin_2d.py +7 -9
  53. plotnine/stats/stat_bindot.py +6 -11
  54. plotnine/stats/stat_boxplot.py +5 -5
  55. plotnine/stats/stat_count.py +5 -9
  56. plotnine/stats/stat_density.py +6 -9
  57. plotnine/stats/stat_density_2d.py +12 -9
  58. plotnine/stats/stat_ecdf.py +6 -5
  59. plotnine/stats/stat_ellipse.py +5 -6
  60. plotnine/stats/stat_function.py +6 -8
  61. plotnine/stats/stat_hull.py +2 -3
  62. plotnine/stats/stat_identity.py +1 -2
  63. plotnine/stats/stat_pointdensity.py +4 -7
  64. plotnine/stats/stat_qq.py +45 -20
  65. plotnine/stats/stat_qq_line.py +15 -11
  66. plotnine/stats/stat_quantile.py +6 -7
  67. plotnine/stats/stat_sina.py +12 -14
  68. plotnine/stats/stat_smooth.py +7 -10
  69. plotnine/stats/stat_sum.py +1 -2
  70. plotnine/stats/stat_summary.py +6 -9
  71. plotnine/stats/stat_summary_bin.py +10 -13
  72. plotnine/stats/stat_unique.py +1 -2
  73. plotnine/stats/stat_ydensity.py +7 -10
  74. plotnine/themes/elements/__init__.py +2 -1
  75. plotnine/themes/elements/margin.py +64 -1
  76. plotnine/themes/theme_gray.py +5 -3
  77. plotnine/themes/theme_matplotlib.py +5 -4
  78. plotnine/themes/theme_seaborn.py +7 -4
  79. plotnine/themes/theme_void.py +11 -4
  80. plotnine/themes/themeable.py +2 -2
  81. plotnine/typing.py +2 -2
  82. {plotnine-0.15.0a1.dist-info → plotnine-0.15.0a3.dist-info}/METADATA +7 -4
  83. {plotnine-0.15.0a1.dist-info → plotnine-0.15.0a3.dist-info}/RECORD +86 -85
  84. {plotnine-0.15.0a1.dist-info → plotnine-0.15.0a3.dist-info}/WHEEL +1 -1
  85. {plotnine-0.15.0a1.dist-info → plotnine-0.15.0a3.dist-info}/licenses/LICENSE +0 -0
  86. {plotnine-0.15.0a1.dist-info → plotnine-0.15.0a3.dist-info}/top_level.txt +0 -0
plotnine/layer.py CHANGED
@@ -163,6 +163,7 @@ class layer:
163
163
  self._make_layer_data(plot.data)
164
164
  self._make_layer_mapping(plot.mapping)
165
165
  self._make_layer_environments(plot.environment)
166
+ self._share_layer_params()
166
167
 
167
168
  def _make_layer_data(self, plot_data: DataLike | None):
168
169
  """
@@ -250,6 +251,13 @@ class layer:
250
251
  self.geom.environment = plot_environment
251
252
  self.stat.environment = plot_environment
252
253
 
254
+ def _share_layer_params(self):
255
+ """
256
+ Pass necessary layer parameters to the geom
257
+ """
258
+ self.geom.params["zorder"] = self.zorder
259
+ self.geom.params["raster"] = self.raster
260
+
253
261
  def compute_aesthetics(self, plot: ggplot):
254
262
  """
255
263
  Return a dataframe where the columns match the aesthetic mappings
@@ -278,10 +286,10 @@ class layer:
278
286
  if not len(data):
279
287
  return
280
288
 
281
- params = self.stat.setup_params(data)
289
+ self.stat.setup_params(data)
282
290
  data = self.stat.use_defaults(data)
283
291
  data = self.stat.setup_data(data)
284
- data = self.stat.compute_layer(data, params, layout)
292
+ data = self.stat.compute_layer(data, layout)
285
293
  self.data = data
286
294
 
287
295
  def map_statistic(self, plot: ggplot):
@@ -320,6 +328,8 @@ class layer:
320
328
  if len(data) == 0:
321
329
  return
322
330
 
331
+ self.geom.params.update(self.stat.params)
332
+ self.geom.setup_params(data)
323
333
  data = self.geom.setup_data(data)
324
334
 
325
335
  check_required_aesthetics(
@@ -357,14 +367,10 @@ class layer:
357
367
  coord : coord
358
368
  Type of coordinate axes
359
369
  """
360
- params = copy(self.geom.params)
361
- params.update(self.stat.params)
362
- params["zorder"] = self.zorder
363
- params["raster"] = self.raster
364
370
  self.data = self.geom.handle_na(self.data)
365
371
  # At this point each layer must have the data
366
372
  # that is created by the plot build process
367
- self.geom.draw_layer(self.data, layout, coord, **params)
373
+ self.geom.draw_layer(self.data, layout, coord)
368
374
 
369
375
  def use_defaults(
370
376
  self,
@@ -399,7 +405,7 @@ class layer:
399
405
  """
400
406
  Prepare/modify data for plotting
401
407
  """
402
- self.stat.finish_layer(self.data, self.stat.params)
408
+ self.stat.finish_layer(self.data)
403
409
 
404
410
 
405
411
  class Layers(List[layer]):
@@ -450,7 +456,9 @@ class Layers(List[layer]):
450
456
  return [l.data for l in self]
451
457
 
452
458
  def setup(self, plot: ggplot):
453
- for l in self:
459
+ # If zorder is 0, it is left to MPL
460
+ for i, l in enumerate(self, start=1):
461
+ l.zorder = i
454
462
  l.setup(plot)
455
463
 
456
464
  def setup_data(self):
@@ -458,9 +466,7 @@ class Layers(List[layer]):
458
466
  l.setup_data()
459
467
 
460
468
  def draw(self, layout: Layout, coord: coord):
461
- # If zorder is 0, it is left to MPL
462
- for i, l in enumerate(self, start=1):
463
- l.zorder = i
469
+ for l in self:
464
470
  l.draw(layout, coord)
465
471
 
466
472
  def compute_aesthetics(self, plot: ggplot):
@@ -48,7 +48,7 @@ def factor(
48
48
  `categories` attribute (which in turn is the `categories` argument, if
49
49
  provided).
50
50
  """
51
- return pd.Categorical(values, categories=categories, ordered=None)
51
+ return pd.Categorical(values, categories=categories, ordered=None) # pyright: ignore[reportArgumentType]
52
52
 
53
53
 
54
54
  def reorder(x, y, fun=np.median, ascending=True):
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import KW_ONLY, InitVar, dataclass
3
+ from dataclasses import KW_ONLY, InitVar, dataclass, field
4
4
  from typing import Literal, Sequence
5
5
  from warnings import warn
6
6
 
@@ -50,34 +50,66 @@ class scale_color_hue(_scale_color_discrete):
50
50
  Qualitative color scale with evenly spaced hues
51
51
  """
52
52
 
53
- h: InitVar[float] = 0.01
53
+ h: InitVar[float | tuple[float, float]] = 15
54
54
  """
55
- Hue. Must be in the range [0, 1]
55
+ Hue. If a float, it is the first hue value, in the range `[0, 360]`.
56
+ The range of the palette will be `[first, first + 360)`.
57
+
58
+ If a tuple, it is the range `[first, last)` of the hues.
56
59
  """
57
60
 
58
- l: InitVar[float] = 0.6
61
+ c: InitVar[float] = 100
59
62
  """
60
- Lightness. Must be in the range [0, 1]
63
+ Chroma. Must be in the range `[0, 100]`
61
64
  """
62
65
 
63
- s: InitVar[float] = 0.65
66
+ l: InitVar[float] = 65
64
67
  """
65
- Saturation. Must be in the range [0, 1]
68
+ Lightness. Must be in the range [0, 100]
66
69
  """
67
70
 
68
- color_space: InitVar[Literal["hls", "hsluv"]] = "hls"
71
+ direction: InitVar[Literal[1, -1]] = 1
69
72
  """
70
- Color space to use. Should be one of
71
- [hls](https://en.wikipedia.org/wiki/HSL_and_HSV)
72
- or [hsluv](https://www.hsluv.org/).
73
- https://www.hsluv.org/
73
+ The order of colours in the scale. If -1 the order
74
+ of colours is reversed. The default is 1.
74
75
  """
75
76
 
76
- def __post_init__(self, h, l, s, color_space):
77
+ _: KW_ONLY
78
+
79
+ s: None = field(default=None, repr=False)
80
+ """
81
+ Not being use and will be removed in a future version
82
+ """
83
+ color_space: None = field(default=None, repr=False)
84
+ """
85
+ Not being use and will be removed in a future version
86
+ """
87
+
88
+ def __post_init__(self, h, c, l, direction):
77
89
  from mizani.palettes import hue_pal
78
90
 
91
+ if (s := self.s) is not None:
92
+ warn(
93
+ f"You used {s=} for the saturation which has been ignored. "
94
+ f"{self.__class__.__name__} now works in HCL colorspace. "
95
+ f"Using `s` in future versions will throw an exception.",
96
+ FutureWarning,
97
+ )
98
+ del self.s
99
+
100
+ if (color_space := self.color_space) is not None:
101
+ warn(
102
+ f"You used {color_space=} to select a color_space and it "
103
+ f"has been ignored. {self.__class__.__name__} now only works "
104
+ f"in HCL colorspace. Using `color_space` in future versions "
105
+ "will throw an exception.",
106
+ FutureWarning,
107
+ )
108
+ del self.color_space
109
+
79
110
  super().__post_init__()
80
- self.palette = hue_pal(h, l, s, color_space=color_space)
111
+ self.palette = hue_pal(h, c, l, direction)
112
+ self.palette.h
81
113
 
82
114
 
83
115
  @dataclass
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from contextlib import suppress
4
4
  from dataclasses import dataclass
5
- from typing import TYPE_CHECKING, Sequence
5
+ from typing import TYPE_CHECKING, Sequence, cast
6
6
  from warnings import warn
7
7
 
8
8
  import numpy as np
@@ -387,15 +387,16 @@ class scale_continuous(
387
387
  limits = self.final_limits
388
388
 
389
389
  x = self.oob(self.rescaler(x, _from=limits))
390
+ na_value = cast("float", self.na_value)
390
391
 
391
392
  uniq = np.unique(x)
392
393
  pal = np.asarray(self.palette(uniq))
393
394
  scaled = pal[match(x, uniq)]
394
395
  if scaled.dtype.kind == "U":
395
- scaled = [self.na_value if x == "nan" else x for x in scaled]
396
+ scaled = [na_value if x == "nan" else x for x in scaled]
396
397
  else:
397
- scaled[pd.isna(scaled)] = self.na_value
398
- return scaled
398
+ scaled[pd.isna(scaled)] = na_value
399
+ return scaled # pyright: ignore[reportReturnType]
399
400
 
400
401
  def get_breaks(
401
402
  self, limits: Optional[tuple[float, float]] = None
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from dataclasses import KW_ONLY, InitVar, dataclass
4
4
  from typing import TYPE_CHECKING
5
+ from warnings import warn
5
6
 
6
7
  from ._runtime_typing import TransUser # noqa: TCH001
7
8
  from .scale_continuous import scale_continuous
@@ -20,24 +21,21 @@ class scale_datetime(scale_continuous):
20
21
  """
21
22
  A string giving the distance between major breaks.
22
23
  For example `'2 weeks'`, `'5 years'`. If specified,
23
- `date_breaks` takes precedence over
24
- `breaks`.
24
+ `date_breaks` takes precedence over `breaks`.
25
25
  """
26
26
 
27
27
  date_labels: InitVar[str | None] = None
28
28
  """
29
29
  Format string for the labels.
30
30
  See [strftime](:ref:`strftime-strptime-behavior`).
31
- If specified, `date_labels` takes precedence over
32
- `labels`.
31
+ If specified, `date_labels` takes precedence over `labels`.
33
32
  """
34
33
 
35
34
  date_minor_breaks: InitVar[str | None] = None
36
35
  """
37
36
  A string giving the distance between minor breaks.
38
37
  For example `'2 weeks'`, `'5 years'`. If specified,
39
- `date_minor_breaks` takes precedence over
40
- `minor_breaks`.
38
+ `date_minor_breaks` takes precedence over `minor_breaks`.
41
39
  """
42
40
 
43
41
  _: KW_ONLY
@@ -80,22 +78,38 @@ class scale_datetime(scale_continuous):
80
78
  date_labels: str | None,
81
79
  date_minor_breaks: str | None,
82
80
  ):
83
- from mizani.breaks import breaks_date as breaks_func
84
- from mizani.labels import label_date as labels_func
81
+ from mizani.breaks import breaks_date_width
82
+ from mizani.labels import label_date
85
83
 
86
84
  if date_breaks is not None:
87
- self.breaks = breaks_func(date_breaks) # pyright: ignore
85
+ self.breaks = breaks_date_width(date_breaks) # pyright: ignore[reportAttributeAccessIssue]
88
86
  elif isinstance(self.breaks, str):
89
- self.breaks = breaks_func(width=self.breaks) # pyright: ignore
87
+ warn(
88
+ "Passing a string to `breaks` will not work in "
89
+ f"future versions. Use `date_breaks={self.breaks!r}`",
90
+ FutureWarning,
91
+ )
92
+ self.breaks = breaks_date_width(width=self.breaks) # pyright: ignore[reportAttributeAccessIssue]
90
93
 
91
94
  if date_labels is not None:
92
- self.labels = labels_func(date_labels) # pyright: ignore
95
+ self.labels = label_date(fmt=date_labels) # pyright: ignore[reportAttributeAccessIssue]
93
96
  elif isinstance(self.labels, str):
94
- self.labels = labels_func(width=self.labels) # pyright: ignore
97
+ warn(
98
+ "Passing a string to `labels` will not work in "
99
+ f"future versions. Use `date_labels={self.labels!r}`",
100
+ FutureWarning,
101
+ )
102
+ self.labels = label_date(fmt=self.labels) # pyright: ignore[reportAttributeAccessIssue]
95
103
 
96
104
  if date_minor_breaks is not None:
97
- self.minor_breaks = breaks_func(date_minor_breaks) # pyright: ignore
105
+ self.minor_breaks = breaks_date_width(date_minor_breaks) # pyright: ignore[reportAttributeAccessIssue]
98
106
  elif isinstance(self.minor_breaks, str):
99
- self.minor_breaks = breaks_func(width=self.minor_breaks) # pyright: ignore
107
+ warn(
108
+ "Passing a string to `minor_breaks` will not work in "
109
+ "future versions. "
110
+ f"Use `date_minor_breaks={self.minor_breaks!r}`",
111
+ FutureWarning,
112
+ )
113
+ self.minor_breaks = breaks_date_width(width=self.minor_breaks) # pyright: ignore[reportAttributeAccessIssue]
100
114
 
101
115
  scale_continuous.__post_init__(self)
@@ -156,7 +156,7 @@ class scale_discrete(
156
156
  range = self.dimension(limits=limits)
157
157
 
158
158
  breaks_d = self.get_breaks(limits)
159
- breaks = self.map(pd.Categorical(breaks_d))
159
+ breaks = self.map(pd.Categorical(breaks_d)) # pyright: ignore[reportArgumentType]
160
160
  minor_breaks = []
161
161
  labels = self.get_labels(breaks_d)
162
162
 
@@ -206,7 +206,7 @@ class scale_discrete(
206
206
  pal = np.asarray(pal, dtype=object)
207
207
  idx = np.asarray(match(x, limits))
208
208
  try:
209
- pal_match = [pal[i] if i >= 0 else None for i in idx]
209
+ pal_match = [pal[i] if i >= 0 else None for i in idx] # pyright: ignore[reportCallIssue,reportArgumentType]
210
210
  except IndexError:
211
211
  # Deal with missing data
212
212
  # - Insert NaN where there is no match
@@ -43,6 +43,8 @@ class scale_color_identity(MapTrainMixin, scale_discrete):
43
43
  """
44
44
 
45
45
  _aesthetics = ["color"]
46
+ _: KW_ONLY
47
+ guide: Literal["legend"] | None = None
46
48
 
47
49
 
48
50
  @dataclass
@@ -52,6 +54,8 @@ class scale_fill_identity(scale_color_identity):
52
54
  """
53
55
 
54
56
  _aesthetics = ["fill"]
57
+ _: KW_ONLY
58
+ guide: Literal["legend"] | None = None
55
59
 
56
60
 
57
61
  @dataclass
@@ -61,6 +65,8 @@ class scale_shape_identity(MapTrainMixin, scale_discrete):
61
65
  """
62
66
 
63
67
  _aesthetics = ["shape"]
68
+ _: KW_ONLY
69
+ guide: Literal["legend"] | None = None
64
70
 
65
71
 
66
72
  @dataclass
@@ -70,6 +76,8 @@ class scale_linetype_identity(MapTrainMixin, scale_discrete):
70
76
  """
71
77
 
72
78
  _aesthetics = ["linetype"]
79
+ _: KW_ONLY
80
+ guide: Literal["legend"] | None = None
73
81
 
74
82
 
75
83
  @dataclass
@@ -82,7 +90,7 @@ class scale_alpha_identity(
82
90
 
83
91
  _aesthetics = ["alpha"]
84
92
  _: KW_ONLY
85
- guide: Literal["legend"] | None = "legend"
93
+ guide: Literal["legend"] | None = None
86
94
 
87
95
 
88
96
  @dataclass
@@ -95,7 +103,7 @@ class scale_size_identity(
95
103
 
96
104
  _aesthetics = ["size"]
97
105
  _: KW_ONLY
98
- guide: Literal["legend"] | None = "legend"
106
+ guide: Literal["legend"] | None = None
99
107
 
100
108
 
101
109
  # American to British spelling
@@ -213,7 +213,7 @@ class scale_x_discrete(scale_position_discrete):
213
213
  Discrete x position
214
214
  """
215
215
 
216
- _aesthetics = ["x", "xmin", "xmax", "xend"]
216
+ _aesthetics = ["x", "xmin", "xmax", "xend", "xintercept"]
217
217
 
218
218
 
219
219
  @dataclass(kw_only=True)
@@ -222,7 +222,7 @@ class scale_y_discrete(scale_position_discrete):
222
222
  Discrete y position
223
223
  """
224
224
 
225
- _aesthetics = ["y", "ymin", "ymax", "yend"]
225
+ _aesthetics = ["y", "ymin", "ymax", "yend", "yintercept"]
226
226
 
227
227
 
228
228
  # Not part of the user API
plotnine/stats/binning.py CHANGED
@@ -165,7 +165,10 @@ def assign_bins(
165
165
  if weight is None:
166
166
  weight = np.ones(len(x))
167
167
  else:
168
- weight = np.asarray(weight)
168
+ # If weight is a dtype that isn't writeable
169
+ # and does not own it's memory. Using a list
170
+ # as an intermediate easily solves this.
171
+ weight = np.array(list(weight))
169
172
  weight[np.isnan(weight)] = 0
170
173
 
171
174
  bin_idx = pd.cut(
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
17
17
  from plotnine.mapping import Environment
18
18
 
19
19
 
20
- def predictdf(data, xseq, **params) -> pd.DataFrame:
20
+ def predictdf(data, xseq, params) -> pd.DataFrame:
21
21
  """
22
22
  Make prediction on the data
23
23
 
@@ -49,21 +49,21 @@ def predictdf(data, xseq, **params) -> pd.DataFrame:
49
49
  if not callable(method):
50
50
  msg = (
51
51
  "'method' should either be a string or a function"
52
- "with the signature `func(data, xseq, **params)`"
52
+ "with the signature `func(data, xseq, params)`"
53
53
  )
54
54
  raise PlotnineError(msg)
55
55
 
56
- return method(data, xseq, **params)
56
+ return method(data, xseq, params)
57
57
 
58
58
 
59
- def lm(data, xseq, **params) -> pd.DataFrame:
59
+ def lm(data, xseq, params) -> pd.DataFrame:
60
60
  """
61
61
  Fit OLS / WLS if data has weight
62
62
  """
63
63
  import statsmodels.api as sm
64
64
 
65
65
  if params["formula"]:
66
- return lm_formula(data, xseq, **params)
66
+ return lm_formula(data, xseq, params)
67
67
 
68
68
  X = sm.add_constant(data["x"])
69
69
  Xseq = sm.add_constant(xseq)
@@ -96,7 +96,7 @@ def lm(data, xseq, **params) -> pd.DataFrame:
96
96
  return data
97
97
 
98
98
 
99
- def lm_formula(data, xseq, **params) -> pd.DataFrame:
99
+ def lm_formula(data, xseq, params) -> pd.DataFrame:
100
100
  """
101
101
  Fit OLS / WLS using a formula
102
102
  """
@@ -140,14 +140,14 @@ def lm_formula(data, xseq, **params) -> pd.DataFrame:
140
140
  return data
141
141
 
142
142
 
143
- def rlm(data, xseq, **params) -> pd.DataFrame:
143
+ def rlm(data, xseq, params) -> pd.DataFrame:
144
144
  """
145
145
  Fit RLM
146
146
  """
147
147
  import statsmodels.api as sm
148
148
 
149
149
  if params["formula"]:
150
- return rlm_formula(data, xseq, **params)
150
+ return rlm_formula(data, xseq, params)
151
151
 
152
152
  X = sm.add_constant(data["x"])
153
153
  Xseq = sm.add_constant(xseq)
@@ -170,7 +170,7 @@ def rlm(data, xseq, **params) -> pd.DataFrame:
170
170
  return data
171
171
 
172
172
 
173
- def rlm_formula(data, xseq, **params) -> pd.DataFrame:
173
+ def rlm_formula(data, xseq, params) -> pd.DataFrame:
174
174
  """
175
175
  Fit RLM using a formula
176
176
  """
@@ -196,14 +196,14 @@ def rlm_formula(data, xseq, **params) -> pd.DataFrame:
196
196
  return data
197
197
 
198
198
 
199
- def gls(data, xseq, **params) -> pd.DataFrame:
199
+ def gls(data, xseq, params) -> pd.DataFrame:
200
200
  """
201
201
  Fit GLS
202
202
  """
203
203
  import statsmodels.api as sm
204
204
 
205
205
  if params["formula"]:
206
- return gls_formula(data, xseq, **params)
206
+ return gls_formula(data, xseq, params)
207
207
 
208
208
  X = sm.add_constant(data["x"])
209
209
  Xseq = sm.add_constant(xseq)
@@ -227,7 +227,7 @@ def gls(data, xseq, **params) -> pd.DataFrame:
227
227
  return data
228
228
 
229
229
 
230
- def gls_formula(data, xseq, **params):
230
+ def gls_formula(data, xseq, params):
231
231
  """
232
232
  Fit GLL using a formula
233
233
  """
@@ -258,14 +258,14 @@ def gls_formula(data, xseq, **params):
258
258
  return data
259
259
 
260
260
 
261
- def glm(data, xseq, **params) -> pd.DataFrame:
261
+ def glm(data, xseq, params) -> pd.DataFrame:
262
262
  """
263
263
  Fit GLM
264
264
  """
265
265
  import statsmodels.api as sm
266
266
 
267
267
  if params["formula"]:
268
- return glm_formula(data, xseq, **params)
268
+ return glm_formula(data, xseq, params)
269
269
 
270
270
  X = sm.add_constant(data["x"])
271
271
  Xseq = sm.add_constant(xseq)
@@ -292,7 +292,7 @@ def glm(data, xseq, **params) -> pd.DataFrame:
292
292
  return data
293
293
 
294
294
 
295
- def glm_formula(data, xseq, **params):
295
+ def glm_formula(data, xseq, params):
296
296
  """
297
297
  Fit with GLM formula
298
298
  """
@@ -321,7 +321,7 @@ def glm_formula(data, xseq, **params):
321
321
  return data
322
322
 
323
323
 
324
- def lowess(data, xseq, **params) -> pd.DataFrame:
324
+ def lowess(data, xseq, params) -> pd.DataFrame:
325
325
  """
326
326
  Lowess fitting
327
327
  """
@@ -351,7 +351,7 @@ def lowess(data, xseq, **params) -> pd.DataFrame:
351
351
  return data
352
352
 
353
353
 
354
- def loess(data, xseq, **params) -> pd.DataFrame:
354
+ def loess(data, xseq, params) -> pd.DataFrame:
355
355
  """
356
356
  Loess smoothing
357
357
  """
@@ -402,7 +402,7 @@ def loess(data, xseq, **params) -> pd.DataFrame:
402
402
  return data
403
403
 
404
404
 
405
- def mavg(data, xseq, **params) -> pd.DataFrame:
405
+ def mavg(data, xseq, params) -> pd.DataFrame:
406
406
  """
407
407
  Fit moving average
408
408
  """
@@ -426,7 +426,7 @@ def mavg(data, xseq, **params) -> pd.DataFrame:
426
426
  return data
427
427
 
428
428
 
429
- def gpr(data, xseq, **params):
429
+ def gpr(data, xseq, params):
430
430
  """
431
431
  Fit gaussian process
432
432
  """
plotnine/stats/stat.py CHANGED
@@ -195,9 +195,9 @@ class stat(ABC, metaclass=Register):
195
195
 
196
196
  return data
197
197
 
198
- def setup_params(self, data: pd.DataFrame) -> dict[str, Any]:
198
+ def setup_params(self, data: pd.DataFrame):
199
199
  """
200
- Override this to verify or adjust parameters
200
+ Override this to verify and/or adjust parameters
201
201
 
202
202
  Parameters
203
203
  ----------
@@ -209,7 +209,6 @@ class stat(ABC, metaclass=Register):
209
209
  out :
210
210
  Parameters used by the stats.
211
211
  """
212
- return self.params
213
212
 
214
213
  def setup_data(self, data: pd.DataFrame) -> pd.DataFrame:
215
214
  """
@@ -227,9 +226,7 @@ class stat(ABC, metaclass=Register):
227
226
  """
228
227
  return data
229
228
 
230
- def finish_layer(
231
- self, data: pd.DataFrame, params: dict[str, Any]
232
- ) -> pd.DataFrame:
229
+ def finish_layer(self, data: pd.DataFrame) -> pd.DataFrame:
233
230
  """
234
231
  Modify data after the aesthetics have been mapped
235
232
 
@@ -257,9 +254,8 @@ class stat(ABC, metaclass=Register):
257
254
  """
258
255
  return data
259
256
 
260
- @classmethod
261
257
  def compute_layer(
262
- cls, data: pd.DataFrame, params: dict[str, Any], layout: Layout
258
+ self, data: pd.DataFrame, layout: Layout
263
259
  ) -> pd.DataFrame:
264
260
  """
265
261
  Calculate statistics for this layers
@@ -275,22 +271,20 @@ class stat(ABC, metaclass=Register):
275
271
  ----------
276
272
  data :
277
273
  Data points for all objects in a layer.
278
- params :
279
- Stat parameters
280
274
  layout :
281
275
  Panel layout information
282
276
  """
283
277
  check_required_aesthetics(
284
- cls.REQUIRED_AES,
285
- list(data.columns) + list(params.keys()),
286
- cls.__name__,
278
+ self.REQUIRED_AES,
279
+ list(data.columns) + list(self.params.keys()),
280
+ self.__class__.__name__,
287
281
  )
288
282
 
289
283
  data = remove_missing(
290
284
  data,
291
- na_rm=params.get("na_rm", False),
292
- vars=list(cls.REQUIRED_AES | cls.NON_MISSING_AES),
293
- name=cls.__name__,
285
+ na_rm=self.params.get("na_rm", False),
286
+ vars=list(self.REQUIRED_AES | self.NON_MISSING_AES),
287
+ name=self.__class__.__name__,
294
288
  finite=True,
295
289
  )
296
290
 
@@ -304,14 +298,11 @@ class stat(ABC, metaclass=Register):
304
298
  if len(pdata) == 0:
305
299
  return pdata
306
300
  pscales = layout.get_scales(pdata["PANEL"].iloc[0])
307
- return cls.compute_panel(pdata, pscales, **params)
301
+ return self.compute_panel(pdata, pscales)
308
302
 
309
303
  return groupby_apply(data, "PANEL", fn)
310
304
 
311
- @classmethod
312
- def compute_panel(
313
- cls, data: pd.DataFrame, scales: pos_scales, **params: Any
314
- ):
305
+ def compute_panel(self, data: pd.DataFrame, scales: pos_scales):
315
306
  """
316
307
  Calculate the statistics for all the groups
317
308
 
@@ -341,7 +332,7 @@ class stat(ABC, metaclass=Register):
341
332
 
342
333
  stats = []
343
334
  for _, old in data.groupby("group"):
344
- new = cls.compute_group(old, scales, **params)
335
+ new = self.compute_group(old, scales)
345
336
  new.reset_index(drop=True, inplace=True)
346
337
  unique = uniquecols(old)
347
338
  missing = unique.columns.difference(new.columns)
@@ -365,9 +356,8 @@ class stat(ABC, metaclass=Register):
365
356
  # it completely.
366
357
  return stats
367
358
 
368
- @classmethod
369
359
  def compute_group(
370
- cls, data: pd.DataFrame, scales: pos_scales, **params: Any
360
+ self, data: pd.DataFrame, scales: pos_scales
371
361
  ) -> pd.DataFrame:
372
362
  """
373
363
  Calculate statistics for the group
@@ -390,7 +380,7 @@ class stat(ABC, metaclass=Register):
390
380
  Parameters
391
381
  """
392
382
  msg = "{} should implement this method."
393
- raise NotImplementedError(msg.format(cls.__name__))
383
+ raise NotImplementedError(msg.format(self.__class__.__name__))
394
384
 
395
385
  def __radd__(self, other: ggplot) -> ggplot:
396
386
  """