MatplotLibAPI 4.0.2__py3-none-any.whl → 4.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MatplotLibAPI/bubble.py CHANGED
@@ -4,11 +4,10 @@ Provides a Bubble class to create and render bubble charts using seaborn and mat
4
4
  with customizable styling via `StyleTemplate`.
5
5
  """
6
6
 
7
- from typing import Any, Dict, Optional, Tuple, cast
7
+ from typing import Any, Dict, Optional, cast
8
8
 
9
9
  import matplotlib.pyplot as plt
10
10
  import pandas as pd
11
- from pandas.api.extensions import register_dataframe_accessor
12
11
  import seaborn as sns
13
12
  from matplotlib.axes import Axes
14
13
  from matplotlib.figure import Figure
@@ -18,7 +17,6 @@ from .base_plot import BasePlot
18
17
 
19
18
  from .style_template import (
20
19
  BUBBLE_STYLE_TEMPLATE,
21
- FIG_SIZE,
22
20
  MAX_RESULTS,
23
21
  TITLE_SCALE_FACTOR,
24
22
  StyleTemplate,
@@ -29,10 +27,9 @@ from .style_template import (
29
27
  FormatterFunc,
30
28
  )
31
29
 
32
- __all__ = ["BUBBLE_STYLE_TEMPLATE", "Bubble"]
30
+ __all__ = ["BUBBLE_STYLE_TEMPLATE", "Bubble", "aplot_bubble", "fplot_bubble"]
33
31
 
34
32
 
35
- @register_dataframe_accessor("bubble")
36
33
  class Bubble(BasePlot):
37
34
  """Bubble chart plot implementing the BasePlot interface.
38
35
 
@@ -399,26 +396,96 @@ class Bubble(BasePlot):
399
396
  """
400
397
  if not style:
401
398
  style = BUBBLE_STYLE_TEMPLATE
402
- if ax is None:
403
- ax = cast(Axes, plt.gca())
399
+ plot_ax = BasePlot.get_axis(ax)
404
400
 
405
401
  format_funcs = format_func(
406
402
  style.format_funcs, label=self.label, x=self.x, y=self.y, z=self.z
407
403
  )
408
404
 
409
- Bubble._setup_axes(ax, style, self._obj, self.x, self.y, format_funcs)
405
+ Bubble._setup_axes(plot_ax, style, self._obj, self.x, self.y, format_funcs)
410
406
 
411
- Bubble._draw_bubbles(ax, self._obj, self.x, self.y, self.z, style)
412
- Bubble._draw_lines(ax, self._obj, self.x, self.y, hline, vline, style)
407
+ Bubble._draw_bubbles(plot_ax, self._obj, self.x, self.y, self.z, style)
408
+ Bubble._draw_lines(plot_ax, self._obj, self.x, self.y, hline, vline, style)
413
409
  Bubble._draw_labels(
414
- ax, self._obj, self.label, self.x, self.y, style, format_funcs
410
+ plot_ax, self._obj, self.label, self.x, self.y, style, format_funcs
415
411
  )
416
412
 
417
413
  if title:
418
- ax.set_title(
414
+ plot_ax.set_title(
419
415
  title,
420
416
  color=style.font_color,
421
417
  fontsize=style.font_size * TITLE_SCALE_FACTOR,
422
418
  )
423
419
 
424
- return ax
420
+ return plot_ax
421
+
422
+
423
def aplot_bubble(
    pd_df: pd.DataFrame,
    label: str,
    x: str,
    y: str,
    z: str,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    max_values: int = MAX_RESULTS,
    center_to_mean: bool = False,
    title: Optional[str] = None,
    style: Optional[StyleTemplate] = None,
    hline: bool = False,
    vline: bool = False,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Plot a bubble chart on a Matplotlib axes.

    Thin functional wrapper around ``Bubble(...).aplot(...)``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Data to be plotted.
    label : str
        Column name used for bubble labels.
    x : str
        Column name for the x-axis values.
    y : str
        Column name for the y-axis values.
    z : str
        Column name for bubble sizes.
    sort_by : str, optional
        Column used to sort the data. The default is ``None``.
    ascending : bool, optional
        Sort order for the data. The default is ``False``.
    max_values : int, optional
        Maximum number of rows to plot. The default is ``MAX_RESULTS``.
    center_to_mean : bool, optional
        Whether to center the bubbles on the mean. The default is ``False``.
    title : str, optional
        Title drawn on the axes. The default is ``None`` (no title).
    style : StyleTemplate, optional
        Style configuration. ``None`` falls back to ``BUBBLE_STYLE_TEMPLATE``.
    hline : bool, optional
        Forwarded to ``Bubble.aplot`` (presumably toggles a horizontal
        reference line). The default is ``False``.
    vline : bool, optional
        Forwarded to ``Bubble.aplot`` (presumably toggles a vertical
        reference line). The default is ``False``.
    ax : Axes, optional
        Axes to draw on. When ``None``, ``Bubble.aplot`` resolves one via
        ``BasePlot.get_axis``.
    **kwargs : Any
        Accepted for signature parity with the other ``aplot_*`` helpers.
        NOTE(review): these are currently NOT forwarded to ``Bubble.aplot``.

    Returns
    -------
    Axes
        The axes containing the bubble chart.
    """
    return Bubble(
        pd_df=pd_df,
        label=label,
        x=x,
        y=y,
        z=z,
        max_values=max_values,
        center_to_mean=center_to_mean,
        sort_by=sort_by,
        ascending=ascending,
    ).aplot(
        title=title,
        style=style or BUBBLE_STYLE_TEMPLATE,
        hline=hline,
        vline=vline,
        ax=ax,
    )
458
+
459
+
460
def fplot_bubble(
    pd_df: pd.DataFrame,
    label: str,
    x: str,
    y: str,
    z: str,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    max_values: int = MAX_RESULTS,
    center_to_mean: bool = False,
    title: Optional[str] = None,
    style: Optional[StyleTemplate] = None,
    hline: bool = False,
    vline: bool = False,
) -> Figure:
    """Plot a bubble chart on a new Matplotlib figure.

    Thin functional wrapper around ``Bubble(...).fplot(...)``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Data to be plotted.
    label : str
        Column name used for bubble labels.
    x : str
        Column name for the x-axis values.
    y : str
        Column name for the y-axis values.
    z : str
        Column name for bubble sizes.
    sort_by : str, optional
        Column used to sort the data. The default is ``None``.
    ascending : bool, optional
        Sort order for the data. The default is ``False``.
    max_values : int, optional
        Maximum number of rows to plot. The default is ``MAX_RESULTS``.
    center_to_mean : bool, optional
        Whether to center the bubbles on the mean. The default is ``False``.
    title : str, optional
        Title of the chart. The default is ``None`` (no title).
    style : StyleTemplate, optional
        Style configuration. ``None`` falls back to ``BUBBLE_STYLE_TEMPLATE``.
    hline : bool, optional
        Forwarded to ``Bubble.fplot`` (presumably toggles a horizontal
        reference line). The default is ``False``.
    vline : bool, optional
        Forwarded to ``Bubble.fplot`` (presumably toggles a vertical
        reference line). The default is ``False``.

    Returns
    -------
    Figure
        The figure containing the bubble chart.
    """
    return Bubble(
        pd_df=pd_df,
        label=label,
        x=x,
        y=y,
        z=z,
        max_values=max_values,
        center_to_mean=center_to_mean,
        sort_by=sort_by,
        ascending=ascending,
    ).fplot(
        title=title,
        style=style or BUBBLE_STYLE_TEMPLATE,
        hline=hline,
        vline=vline,
    )
@@ -7,13 +7,12 @@ import pandas as pd
7
7
  import plotly.graph_objects as go
8
8
  from matplotlib.axes import Axes
9
9
  from matplotlib.figure import Figure
10
- from matplotlib.gridspec import GridSpec
11
10
  from plotly.subplots import make_subplots
12
11
 
13
- from .base_plot import BasePlot
12
+ from .base_plot import FIG_SIZE
14
13
 
15
- from .bubble import BUBBLE_STYLE_TEMPLATE, FIG_SIZE, Bubble
16
- from .network import aplot_network, NetworkGraph
14
+ from .bubble import aplot_bubble, Bubble
15
+ from .network import NetworkGraph
17
16
  from .style_template import (
18
17
  MAX_RESULTS,
19
18
  TITLE_SCALE_FACTOR,
@@ -21,8 +20,8 @@ from .style_template import (
21
20
  validate_dataframe,
22
21
  )
23
22
  from .table import aplot_table
24
- from .treemap import TREEMAP_STYLE_TEMPLATE, aplot_treemap
25
- from .word_cloud import WORDCLOUD_STYLE_TEMPLATE, aplot_wordcloud, WordCloud
23
+ from .treemap import aplot_treemap
24
+ from .word_cloud import WORDCLOUD_STYLE_TEMPLATE, aplot_wordcloud
26
25
 
27
26
 
28
27
  def plot_composite_bubble(
@@ -80,13 +79,20 @@ def plot_composite_bubble(
80
79
  Matplotlib figure containing the composite bubble chart and tables.
81
80
  """
82
81
  validate_dataframe(pd_df, cols=[label, x, y, z], sort_by=sort_by)
83
- if not style:
84
- style = BUBBLE_STYLE_TEMPLATE
85
- fig = cast(Figure, plt.figure(figsize=figsize))
86
- fig.set_facecolor(style.background_color)
87
- grid = GridSpec(2, 2, height_ratios=[2, 1], width_ratios=[1, 1])
88
- ax = fig.add_subplot(grid[0, 0:])
89
- ax = Bubble(
82
+ if style is None:
83
+ style = WORDCLOUD_STYLE_TEMPLATE
84
+ fig = plt.figure(figsize=figsize)
85
+ gs = fig.add_gridspec(
86
+ 2,
87
+ 2,
88
+ height_ratios=[2, 1],
89
+ width_ratios=[1, 1],
90
+ )
91
+ ax = fig.add_subplot(gs[0, :])
92
+ ax2 = fig.add_subplot(gs[1, 0])
93
+ ax3 = fig.add_subplot(gs[1, 1])
94
+
95
+ aplot_bubble(
90
96
  pd_df=pd_df,
91
97
  label=label,
92
98
  x=x,
@@ -96,13 +102,11 @@ def plot_composite_bubble(
96
102
  center_to_mean=center_to_mean,
97
103
  sort_by=sort_by,
98
104
  ascending=ascending,
99
- ).aplot(
100
105
  title=title,
101
106
  style=style,
102
107
  ax=ax,
103
108
  )
104
109
 
105
- ax2 = fig.add_subplot(grid[1, 0])
106
110
  ax2 = aplot_table(
107
111
  pd_df=pd_df,
108
112
  cols=[label, z, y, x],
@@ -113,7 +117,7 @@ def plot_composite_bubble(
113
117
  max_values=table_rows,
114
118
  style=style,
115
119
  )
116
- ax3 = fig.add_subplot(grid[1, 1])
120
+
117
121
  ax3 = aplot_table(
118
122
  pd_df=pd_df,
119
123
  cols=[label, z, y, x],
@@ -197,6 +201,219 @@ def plot_composite_treemap(
197
201
  return fig
198
202
 
199
203
 
204
def fplot_bubble(
    pd_df: pd.DataFrame,
    label: str,
    x: str,
    y: str,
    z: str,
    title: Optional[str] = None,
    style: Optional[StyleTemplate] = None,
    max_values: int = 50,
    center_to_mean: bool = False,
    filter_by: Optional[str] = None,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    table_rows: int = 10,
    figsize: Tuple[float, float] = FIG_SIZE,
) -> Figure:
    """Plot a composite bubble chart with summary tables.

    The bubble chart spans the top row of a 2x2 grid (height ratio 2:1).
    The bottom row holds a "Top N" table (left, sorted descending) and a
    "Last N" table (right, sorted ascending).

    Parameters
    ----------
    pd_df : pd.DataFrame
        Data to be plotted.
    label : str
        Column name for bubble labels.
    x : str
        Column name for the x-axis values.
    y : str
        Column name for the y-axis values.
    z : str
        Column name for bubble sizes.
    title : str, optional
        Title of the bubble chart. The default is ``None``.
    style : StyleTemplate, optional
        Style configuration shared by the chart and both tables. The default
        is ``BUBBLE_STYLE_TEMPLATE``.
    max_values : int, optional
        Maximum number of rows to display in the chart. The default is 50.
    center_to_mean : bool, optional
        Whether to center the bubbles on the mean. The default is ``False``.
    filter_by : str, optional
        Currently unused; accepted for API compatibility.
    sort_by : str, optional
        Column used to sort the data (chart and tables).
    ascending : bool, optional
        Sort order for the bubble chart data. The default is ``False``.
    table_rows : int, optional
        Number of rows to display in each table. The default is 10.
    figsize : tuple[float, float], optional
        Size of the created figure. The default is ``FIG_SIZE``.

    Returns
    -------
    Figure
        Matplotlib figure containing the composite bubble chart and tables.
    """
    # Local import keeps this fix self-contained; ``.bubble`` is already
    # imported at module level, so this adds no new import cycle.
    from .bubble import BUBBLE_STYLE_TEMPLATE

    validate_dataframe(pd_df, cols=[label, x, y, z], sort_by=sort_by)
    if style is None:
        # Honour the documented default (previously fell back to the
        # word-cloud template by mistake).
        style = BUBBLE_STYLE_TEMPLATE

    fig = plt.figure(figsize=figsize)
    # Apply the style's background to the whole figure, not only the axes.
    fig.set_facecolor(style.background_color)
    grid = fig.add_gridspec(2, 2, height_ratios=[2, 1], width_ratios=[1, 1])
    bubble_ax = fig.add_subplot(grid[0, :])
    top_ax = fig.add_subplot(grid[1, 0])
    last_ax = fig.add_subplot(grid[1, 1])

    aplot_bubble(
        pd_df=pd_df,
        label=label,
        x=x,
        y=y,
        z=z,
        max_values=max_values,
        center_to_mean=center_to_mean,
        sort_by=sort_by,
        ascending=ascending,
        title=title,
        style=style,
        ax=bubble_ax,
    )

    aplot_table(
        pd_df=pd_df,
        cols=[label, z, y, x],
        title=f"Top {table_rows}",
        ax=top_ax,
        sort_by=sort_by,
        ascending=False,
        max_values=table_rows,
        style=style,
    )

    aplot_table(
        pd_df=pd_df,
        cols=[label, z, y, x],
        title=f"Last {table_rows}",
        ax=last_ax,
        sort_by=sort_by,
        ascending=True,
        max_values=table_rows,
        style=style,
    )

    # Reserve the top 5% of the figure when a title is present.
    if title:
        fig.tight_layout(rect=(0, 0, 1, 0.95))
    else:
        fig.tight_layout()
    return fig
313
+
314
+
315
def fplot_wordcloud_network2(
    edges_df: pd.DataFrame,
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    max_words: int = MAX_RESULTS,
    stopwords: Optional[Iterable[str]] = None,
    title: Optional[str] = None,
    style: Optional[StyleTemplate] = None,
    wordcloud_style: Optional[StyleTemplate] = None,
    network_style: Optional[StyleTemplate] = None,
    figsize: Tuple[float, float] = FIG_SIZE,
) -> Figure:
    """Plot a word cloud above a network graph built from an edge list.

    The network is built from ``edges_df``. The word cloud is then derived
    from that network's node table (``network.node_view.to_dataframe()``),
    using the ``"node"`` column as text. The layout is a 2x1 grid with
    height ratio 1:2 (word cloud on top, network below).

    Parameters
    ----------
    edges_df : pd.DataFrame
        DataFrame containing edge connections for the network plot.
    edge_source_col : str, optional
        Column in ``edges_df`` containing source nodes. The default is ``"source"``.
    edge_target_col : str, optional
        Column in ``edges_df`` containing target nodes. The default is ``"target"``.
    edge_weight_col : str, optional
        Column in ``edges_df`` containing edge weights. It is also used as
        the weight column of the node table fed to the word cloud. The
        default is ``"weight"``.
    max_words : int, optional
        Maximum number of words to include in the word cloud. The default is
        ``MAX_RESULTS``.
    stopwords : Iterable[str], optional
        Stopwords to exclude from the word cloud. The default is ``None``.
    title : str, optional
        Figure-level title (``suptitle``). The default is ``None``.
    style : StyleTemplate, optional
        Shared style configuration applied to the composite figure and used for
        subplots when specialized styles are not provided. The default is
        ``WORDCLOUD_STYLE_TEMPLATE``.
    wordcloud_style : StyleTemplate, optional
        Style for the word cloud subplot. When ``None``, the shared ``style``
        is used.
    network_style : StyleTemplate, optional
        Style for the network subplot. When ``None``, the shared ``style`` is
        used.
    figsize : tuple[float, float], optional
        Size of the composite figure. The default is ``FIG_SIZE``.

    Returns
    -------
    Figure
        Matplotlib figure containing the word cloud on top and network below.
    """
    if not style:
        style = WORDCLOUD_STYLE_TEMPLATE
    wordcloud_style = wordcloud_style or style
    network_style = network_style or style

    fig_raw, axes_raw = plt.subplots(
        2,
        1,
        figsize=figsize,
        gridspec_kw={"height_ratios": [1, 2]},
    )
    fig = cast(Figure, fig_raw)
    wordcloud_ax, network_ax = cast(Tuple[Axes, Axes], axes_raw)

    fig.set_facecolor(style.background_color)
    if title:
        fig.suptitle(
            title,
            color=style.font_color,
            fontsize=style.font_size * TITLE_SCALE_FACTOR,
            fontname=style.font_name,
        )

    network = NetworkGraph.from_pandas_edgelist(
        edges_df=edges_df,
        source=edge_source_col,
        target=edge_target_col,
        edge_weight_col=edge_weight_col,
    )
    network.aplot(
        title=None,
        style=network_style,
        ax=network_ax,
    )

    # NOTE(review): assumes the node table produced by ``node_view`` exposes
    # a "node" column and a column named ``edge_weight_col`` — confirm.
    aplot_wordcloud(
        pd_df=network.node_view.to_dataframe(),
        text_column="node",
        weight_column=edge_weight_col,
        title=None,
        style=wordcloud_style,
        max_words=max_words,
        stopwords=stopwords,
        ax=wordcloud_ax,
    )

    # Leave room for the suptitle when one was drawn.
    if title:
        fig.tight_layout(rect=(0, 0, 1, 0.95))
    else:
        fig.tight_layout()
    return fig
415
+
416
+
200
417
  def fplot_wordcloud_network(
201
418
  node_df: pd.DataFrame,
202
419
  edges_df: pd.DataFrame,
MatplotLibAPI/heatmap.py CHANGED
@@ -1,22 +1,20 @@
1
1
  """Heatmap and correlation matrix helpers."""
2
2
 
3
- from typing import Any, Optional, Sequence, Tuple, cast, Literal
3
+ from typing import Any, Optional, Sequence, Tuple, cast
4
4
 
5
5
  import pandas as pd
6
6
  import seaborn as sns
7
7
  from matplotlib.axes import Axes
8
8
  from matplotlib.figure import Figure
9
- from pandas.api.extensions import register_dataframe_accessor
10
9
 
11
10
  from .base_plot import BasePlot
11
+ from .types import CorrelationMethod
12
12
  from .style_template import (
13
13
  HEATMAP_STYLE_TEMPLATE,
14
14
  StyleTemplate,
15
15
  string_formatter,
16
16
  validate_dataframe,
17
17
  )
18
- from .typing import CorrelationMethod
19
- from .utils import _get_axis, _merge_kwargs, create_fig
20
18
 
21
19
  __all__ = [
22
20
  "HEATMAP_STYLE_TEMPLATE",
@@ -27,7 +25,6 @@ __all__ = [
27
25
  ]
28
26
 
29
27
 
30
- @register_dataframe_accessor("heatmap")
31
28
  class Heatmap(BasePlot):
32
29
  """Class for plotting heatmaps and correlation matrices."""
33
30
 
@@ -60,13 +57,13 @@ class Heatmap(BasePlot):
60
57
  """Plot a heatmap on an existing Matplotlib axes."""
61
58
  if not style:
62
59
  style = HEATMAP_STYLE_TEMPLATE
63
- plot_ax = _get_axis(ax)
60
+ plot_ax = BasePlot.get_axis(ax)
64
61
  heatmap_kwargs: dict[str, Any] = {
65
62
  "data": self._obj,
66
63
  "cmap": style.palette,
67
64
  "ax": plot_ax,
68
65
  }
69
- sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
66
+ sns.heatmap(**BasePlot.merge_kwargs(heatmap_kwargs, kwargs))
70
67
 
71
68
  plot_ax.set_xlabel(string_formatter(self.x))
72
69
  plot_ax.set_ylabel(string_formatter(self.y))
@@ -85,7 +82,7 @@ class Heatmap(BasePlot):
85
82
  """Plot a correlation matrix heatmap on existing axes."""
86
83
  if not style:
87
84
  style = HEATMAP_STYLE_TEMPLATE
88
- plot_ax = _get_axis(ax)
85
+ plot_ax = BasePlot.get_axis(ax)
89
86
  heatmap_kwargs: dict[str, Any] = {
90
87
  "data": self.correlation_matrix(correlation_method),
91
88
  "cmap": style.palette,
@@ -93,7 +90,7 @@ class Heatmap(BasePlot):
93
90
  "fmt": ".2f",
94
91
  "ax": plot_ax,
95
92
  }
96
- sns.heatmap(**_merge_kwargs(heatmap_kwargs, kwargs))
93
+ sns.heatmap(**BasePlot.merge_kwargs(heatmap_kwargs, kwargs))
97
94
  if title:
98
95
  plot_ax.set_title(title)
99
96
  return plot_ax
@@ -108,7 +105,7 @@ class Heatmap(BasePlot):
108
105
  """Plot a correlation matrix heatmap on a new figure."""
109
106
  if not style:
110
107
  style = HEATMAP_STYLE_TEMPLATE
111
- fig, ax = create_fig(figsize=figsize, style=style)
108
+ fig, ax = BasePlot.create_fig(figsize=figsize, style=style)
112
109
  self.aplot_correlation_matrix(
113
110
  title=title,
114
111
  style=style,
@@ -3,9 +3,7 @@
3
3
  from typing import Any, Optional, Tuple
4
4
 
5
5
  import pandas as pd
6
- from pandas.api.extensions import register_dataframe_accessor
7
6
  import seaborn as sns
8
- import matplotlib.pyplot as plt
9
7
  from matplotlib.axes import Axes
10
8
  from matplotlib.figure import Figure
11
9
 
@@ -19,12 +17,10 @@ from .style_template import (
19
17
  string_formatter,
20
18
  validate_dataframe,
21
19
  )
22
- from .utils import _get_axis, _merge_kwargs
23
20
 
24
21
  __all__ = ["DISTRIBUTION_STYLE_TEMPLATE", "aplot_histogram", "fplot_histogram"]
25
22
 
26
23
 
27
- @register_dataframe_accessor("histogram")
28
24
  class Histogram(BasePlot):
29
25
  """Class for plotting histograms with optional KDE."""
30
26
 
@@ -49,7 +45,7 @@ class Histogram(BasePlot):
49
45
  ) -> Axes:
50
46
 
51
47
  validate_dataframe(self._obj, cols=[self.column])
52
- plot_ax = _get_axis(ax)
48
+ plot_ax = BasePlot.get_axis(ax)
53
49
  histplot_kwargs: dict[str, Any] = {
54
50
  "data": self._obj,
55
51
  "x": self.column,
@@ -59,7 +55,7 @@ class Histogram(BasePlot):
59
55
  "edgecolor": style.background_color,
60
56
  "ax": plot_ax,
61
57
  }
62
- sns.histplot(**_merge_kwargs(histplot_kwargs, kwargs))
58
+ sns.histplot(**BasePlot.merge_kwargs(histplot_kwargs, kwargs))
63
59
  plot_ax.set_facecolor(style.background_color)
64
60
  plot_ax.set_xlabel(string_formatter(self.column))
65
61
  plot_ax.set_ylabel("Frequency")
@@ -19,4 +19,5 @@ _DEFAULT = {
19
19
 
20
20
  _WEIGHT_PERCENTILES = np.arange(10, 100, 10)
21
21
 
22
+
22
23
  __all__ = ["_DEFAULT", "_WEIGHT_PERCENTILES"]
@@ -2,7 +2,6 @@
2
2
 
3
3
  from collections import defaultdict
4
4
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
5
- from pandas.api.extensions import register_dataframe_accessor
6
5
  import matplotlib.pyplot as plt
7
6
  import networkx as nx
8
7
  import numpy as np
@@ -243,7 +242,6 @@ class EdgeView(nx.classes.reportviews.EdgeView):
243
242
  return pd.DataFrame(data)
244
243
 
245
244
 
246
- @register_dataframe_accessor("network")
247
245
  class NetworkGraph(BasePlot):
248
246
  """Custom graph class based on NetworkX's ``Graph``.
249
247
 
MatplotLibAPI/pie.py CHANGED
@@ -11,7 +11,6 @@ from matplotlib.figure import Figure
11
11
  from .base_plot import BasePlot
12
12
 
13
13
  from .style_template import PIE_STYLE_TEMPLATE, StyleTemplate, validate_dataframe
14
- from .utils import _get_axis, _merge_kwargs
15
14
 
16
15
  __all__ = ["PIE_STYLE_TEMPLATE", "aplot_pie", "fplot_pie"]
17
16
 
@@ -63,7 +62,7 @@ class PieChart(BasePlot):
63
62
  """
64
63
  labels = self._obj[self.category].astype(str).tolist()
65
64
  sizes = self._obj[self.value]
66
- plot_ax = _get_axis(ax)
65
+ plot_ax = BasePlot.get_axis(ax)
67
66
  wedgeprops: Optional[Dict[str, Any]] = None
68
67
  if donut:
69
68
  wedgeprops = {"width": 0.3}
@@ -74,7 +73,7 @@ class PieChart(BasePlot):
74
73
  "wedgeprops": wedgeprops,
75
74
  "textprops": {"color": style.font_color, "fontsize": style.font_size},
76
75
  }
77
- plot_ax.pie(sizes, **_merge_kwargs(pie_kwargs, kwargs))
76
+ plot_ax.pie(sizes, **BasePlot.merge_kwargs(pie_kwargs, kwargs))
78
77
  plot_ax.axis("equal")
79
78
  if title:
80
79
  plot_ax.set_title(title)
MatplotLibAPI/pivot.py CHANGED
@@ -9,7 +9,6 @@ from matplotlib.axes import Axes
9
9
  from matplotlib.figure import Figure
10
10
 
11
11
  from .base_plot import BasePlot
12
- from .utils import _merge_kwargs
13
12
 
14
13
  from .style_template import (
15
14
  FIG_SIZE,
@@ -103,8 +102,7 @@ class PivotBarChart(BasePlot):
103
102
  ascending=ascending,
104
103
  )
105
104
 
106
- if ax is None:
107
- ax = cast(Axes, plt.gca())
105
+ plot_ax = BasePlot.get_axis(ax)
108
106
 
109
107
  if pd.api.types.is_datetime64_any_dtype(pivot_df[self.x]):
110
108
  pivot_df[self.x] = pivot_df[self.x].dt.strftime("%Y-%m-%d")
@@ -113,24 +111,24 @@ class PivotBarChart(BasePlot):
113
111
  "kind": "bar",
114
112
  "x": self.x,
115
113
  "stacked": self.stacked,
116
- "ax": ax,
114
+ "ax": plot_ax,
117
115
  "alpha": 0.7,
118
116
  }
119
- pivot_df.plot(**_merge_kwargs(plot_kwargs, kwargs))
117
+ pivot_df.plot(**BasePlot.merge_kwargs(plot_kwargs, kwargs))
120
118
 
121
- ax.set_ylabel(string_formatter(self.y))
122
- ax.set_xlabel(string_formatter(self.x))
119
+ plot_ax.set_ylabel(string_formatter(self.y))
120
+ plot_ax.set_xlabel(string_formatter(self.x))
123
121
  if title:
124
- ax.set_title(title)
122
+ plot_ax.set_title(title)
125
123
 
126
- ax.legend(
124
+ plot_ax.legend(
127
125
  fontsize=style.font_size - 2,
128
126
  title_fontsize=style.font_size + 2,
129
127
  labelcolor="linecolor",
130
128
  facecolor=style.background_color,
131
129
  )
132
- ax.tick_params(axis="x", rotation=90)
133
- return ax
130
+ plot_ax.tick_params(axis="x", rotation=90)
131
+ return plot_ax
134
132
 
135
133
 
136
134
  def aplot_pivoted_bars(
MatplotLibAPI/table.py CHANGED
@@ -2,7 +2,6 @@
2
2
 
3
3
  from typing import Any, List, Optional, Tuple
4
4
  import pandas as pd
5
- import matplotlib.pyplot as plt
6
5
  from matplotlib.axes import Axes
7
6
  from matplotlib.figure import Figure
8
7
  from matplotlib.transforms import Bbox
@@ -10,7 +9,6 @@ from matplotlib.table import Table
10
9
 
11
10
  from .base_plot import BasePlot
12
11
 
13
- from .utils import _get_axis
14
12
 
15
13
  from .style_template import (
16
14
  FIG_SIZE,
@@ -62,7 +60,7 @@ class TablePlot(BasePlot):
62
60
  ax: Optional[Axes] = None,
63
61
  **kwargs: Any,
64
62
  ) -> Axes:
65
- plot_ax = _get_axis(ax)
63
+ plot_ax = BasePlot.get_axis(ax)
66
64
 
67
65
  if sort_by is None:
68
66
  sort_by = self.cols[0]
@@ -9,7 +9,7 @@ from matplotlib.axes import Axes
9
9
  from matplotlib.figure import Figure
10
10
 
11
11
  from .base_plot import BasePlot
12
- from .utils import _get_axis
12
+
13
13
  from .style_template import (
14
14
  TIMESERIE_STYLE_TEMPLATE,
15
15
  FIG_SIZE,
@@ -210,7 +210,7 @@ class TimeSeriePlot(BasePlot):
210
210
  format_funcs = format_func(
211
211
  style.format_funcs, label=self.label, x=self.x, y=self.y
212
212
  )
213
- plot_ax = _get_axis(ax)
213
+ plot_ax = BasePlot.get_axis(ax)
214
214
  _plot_timeserie_lines(
215
215
  plot_ax, df, self.label, self.x, self.y, std, style, format_funcs
216
216
  )
MatplotLibAPI/types.py ADDED
@@ -0,0 +1,6 @@
"""Shared type aliases for MatplotLibAPI."""

from typing import Literal

try:  # ``TypeAlias`` is in the standard library from Python 3.10 onward.
    from typing import TypeAlias
except ImportError:  # pragma: no cover - older interpreters only
    from typing_extensions import TypeAlias

# Correlation method names; these match the string options of
# ``pandas.DataFrame.corr(method=...)``.
CorrelationMethod: TypeAlias = Literal["pearson", "kendall", "spearman"]