MatplotLibAPI 3.2.14__py3-none-any.whl → 3.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MatplotLibAPI/Area.py ADDED
@@ -0,0 +1,76 @@
1
+ """Area chart helpers."""
2
+
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import pandas as pd
6
+ from matplotlib.axes import Axes
7
+ from matplotlib.figure import Figure
8
+
9
+ from .StyleTemplate import (
10
+ AREA_STYLE_TEMPLATE,
11
+ StyleTemplate,
12
+ string_formatter,
13
+ validate_dataframe,
14
+ )
15
+ from ._visualization_utils import _get_axis, _wrap_aplot
16
+
17
+
18
def aplot_area(
    pd_df: pd.DataFrame,
    x: str,
    y: str,
    label: Optional[str] = None,
    stacked: bool = True,
    title: Optional[str] = None,
    style: StyleTemplate = AREA_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Plot area charts, optionally stacked for part-to-whole trends.

    Values of ``y`` that share the same ``x`` (and ``label``, when given) are
    summed before plotting, in both the multi-series and single-series modes.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data.
    x : str
        Column used for the horizontal axis.
    y : str
        Column holding the plotted values.
    label : str, optional
        Column whose distinct values become separate area series. When
        omitted, a single filled area is drawn.
    stacked : bool, optional
        Stack the per-label areas. Only used when ``label`` is given.
        The default is ``True``.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        Style template; ``font_color`` colours the single-series area.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the plot.
    """
    cols = [x, y]
    if label:
        cols.append(label)
    validate_dataframe(pd_df, cols=cols)
    plot_ax = _get_axis(ax)

    if label:
        # One column per label value, one row per x value; duplicate
        # (x, label) pairs are summed.
        pivot_df = pd_df.pivot_table(
            index=x, columns=label, values=y, aggfunc="sum"
        ).sort_index()
        pivot_df.plot(kind="area", stacked=stacked, alpha=0.7, ax=plot_ax)
    else:
        # Sum duplicate x values, mirroring the pivot branch above. Plotting
        # raw duplicate x positions would draw vertical zig-zags. min_count=1
        # keeps an all-missing x as NaN (a gap) rather than a fabricated 0.
        totals = pd_df.groupby(x, sort=True)[y].sum(min_count=1)
        plot_ax.fill_between(
            totals.index, totals.to_numpy(), color=style.font_color, alpha=0.4
        )
        plot_ax.plot(totals.index, totals.to_numpy(), color=style.font_color)

    plot_ax.set_xlabel(string_formatter(x))
    plot_ax.set_ylabel(string_formatter(y))
    if title:
        plot_ax.set_title(title)
    return plot_ax
53
+
54
+
55
def fplot_area(
    pd_df: pd.DataFrame,
    x: str,
    y: str,
    label: Optional[str] = None,
    stacked: bool = True,
    title: Optional[str] = None,
    style: StyleTemplate = AREA_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (10, 6),
) -> Figure:
    """Draw an area chart on a freshly created figure.

    Figure-level counterpart of :func:`aplot_area`. Every plotting argument is
    forwarded unchanged, together with ``figsize``, through ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "x": x,
        "y": y,
        "label": label,
        "stacked": stacked,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(aplot_area, pd_df=pd_df, figsize=figsize, **chart_options)
MatplotLibAPI/Bar.py ADDED
@@ -0,0 +1,79 @@
1
+ """Bar and stacked bar chart helpers."""
2
+
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import pandas as pd
6
+ import seaborn as sns
7
+ from matplotlib.axes import Axes
8
+ from matplotlib.figure import Figure
9
+
10
+ from .StyleTemplate import (
11
+ DISTRIBUTION_STYLE_TEMPLATE,
12
+ StyleTemplate,
13
+ string_formatter,
14
+ validate_dataframe,
15
+ )
16
+ from ._visualization_utils import _get_axis, _wrap_aplot
17
+
18
+
19
def aplot_bar(
    pd_df: pd.DataFrame,
    category: str,
    value: str,
    group: Optional[str] = None,
    stacked: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = DISTRIBUTION_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Plot bar or stacked bar charts for categorical comparisons.

    Bar heights are the sum of ``value`` per category (and per group when
    ``group`` is given), so the grouped and ungrouped modes aggregate the same
    way.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data.
    category : str
        Column whose values label the bars on the x axis.
    value : str
        Column summed to obtain each bar height.
    group : str, optional
        Column whose values become separate bar series.
    stacked : bool, optional
        Stack the group series. Only used when ``group`` is given.
        The default is ``False``.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        Supplies ``palette`` (ungrouped bars) and ``background_color``.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the chart.
    """
    cols = [category, value]
    if group:
        cols.append(group)
    validate_dataframe(pd_df, cols=cols)

    plot_ax = _get_axis(ax)

    if group:
        pivot_df = pd_df.pivot_table(
            index=category, columns=group, values=value, aggfunc="sum"
        )
        pivot_df.plot(kind="bar", stacked=stacked, ax=plot_ax, alpha=0.85)
    else:
        # Pre-aggregate with a sum to match the grouped branch. Passed raw
        # rows, seaborn would draw the per-category *mean* plus an error bar.
        # sort=False keeps first-appearance order (seaborn's default for
        # non-numeric categories); observed=True avoids inventing zero rows
        # for unused categorical levels.
        totals = (
            pd_df.groupby(category, sort=False, observed=True)[value]
            .sum()
            .reset_index()
        )
        sns.barplot(
            data=totals, x=category, y=value, palette=style.palette, ax=plot_ax
        )

    plot_ax.set_facecolor(style.background_color)
    plot_ax.set_xlabel(string_formatter(category))
    plot_ax.set_ylabel(string_formatter(value))
    if title:
        plot_ax.set_title(title)
    plot_ax.tick_params(axis="x", labelrotation=45)
    return plot_ax
56
+
57
+
58
def fplot_bar(
    pd_df: pd.DataFrame,
    category: str,
    value: str,
    group: Optional[str] = None,
    stacked: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = DISTRIBUTION_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (10, 6),
) -> Figure:
    """Draw a bar or stacked bar chart on a freshly created figure.

    Figure-level counterpart of :func:`aplot_bar`. The plotting arguments are
    forwarded unchanged, together with ``figsize``, through ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "category": category,
        "value": value,
        "group": group,
        "stacked": stacked,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(aplot_bar, pd_df=pd_df, figsize=figsize, **chart_options)
@@ -0,0 +1,69 @@
1
+ """Box and violin plot helpers."""
2
+
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import pandas as pd
6
+ import seaborn as sns
7
+ from matplotlib.axes import Axes
8
+ from matplotlib.figure import Figure
9
+
10
+ from .StyleTemplate import (
11
+ DISTRIBUTION_STYLE_TEMPLATE,
12
+ StyleTemplate,
13
+ string_formatter,
14
+ validate_dataframe,
15
+ )
16
+ from ._visualization_utils import _get_axis, _wrap_aplot
17
+
18
+
19
def aplot_box_violin(
    pd_df: pd.DataFrame,
    column: str,
    by: Optional[str] = None,
    violin: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = DISTRIBUTION_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Summarise the distribution of ``column`` as box or violin plots.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data.
    column : str
        Numeric column whose distribution is drawn on the y axis.
    by : str, optional
        Column used to split the data into one box/violin per value (x axis).
    violin : bool, optional
        Draw violins instead of boxes. The default is ``False``.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        Supplies ``palette`` and ``background_color``.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the plot.
    """
    required = [column, by] if by else [column]
    validate_dataframe(pd_df, cols=required)
    target_ax = _get_axis(ax)

    # Both seaborn functions share the same call signature here.
    draw = sns.violinplot if violin else sns.boxplot
    draw(data=pd_df, x=by, y=column, palette=style.palette, ax=target_ax)

    target_ax.set_facecolor(style.background_color)
    target_ax.set_ylabel(string_formatter(column))
    if by:
        target_ax.set_xlabel(string_formatter(by))
    if title:
        target_ax.set_title(title)
    return target_ax
48
+
49
+
50
def fplot_box_violin(
    pd_df: pd.DataFrame,
    column: str,
    by: Optional[str] = None,
    violin: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = DISTRIBUTION_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (10, 6),
) -> Figure:
    """Draw box or violin plots on a freshly created figure.

    Figure-level counterpart of :func:`aplot_box_violin`. The plotting
    arguments are forwarded unchanged, together with ``figsize``, through
    ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "column": column,
        "by": by,
        "violin": violin,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(
        aplot_box_violin, pd_df=pd_df, figsize=figsize, **chart_options
    )
+ )
@@ -0,0 +1,113 @@
1
+ """Heatmap and correlation matrix helpers."""
2
+
3
+ from typing import Any, Optional, Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import seaborn as sns
8
+ from matplotlib.axes import Axes
9
+ from matplotlib.figure import Figure
10
+ from pandas._typing import CorrelationMethod
11
+
12
+ from .StyleTemplate import (
13
+ HEATMAP_STYLE_TEMPLATE,
14
+ StyleTemplate,
15
+ string_formatter,
16
+ validate_dataframe,
17
+ )
18
+ from ._visualization_utils import _get_axis, _wrap_aplot
19
+
20
+
21
def aplot_heatmap(
    pd_df: pd.DataFrame,
    x: str,
    y: str,
    value: str,
    title: Optional[str] = None,
    style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Render ``value`` as a 2-D heatmap indexed by ``y`` (rows) and ``x`` (columns).

    Rows sharing the same (x, y) pair are averaged into a single cell.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data.
    x : str
        Column whose values form the heatmap columns.
    y : str
        Column whose values form the heatmap rows.
    value : str
        Column averaged into each cell.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        ``palette`` is used as the colormap.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the heatmap.
    """
    validate_dataframe(pd_df, cols=[x, y, value])
    target_ax = _get_axis(ax)

    grid = pd_df.pivot_table(values=value, index=y, columns=x, aggfunc="mean")
    sns.heatmap(grid, cmap=style.palette, ax=target_ax)

    target_ax.set_xlabel(string_formatter(x))
    target_ax.set_ylabel(string_formatter(y))
    if title:
        target_ax.set_title(title)
    return target_ax
43
+
44
+
45
def aplot_correlation_matrix(
    pd_df: pd.DataFrame,
    columns: Optional[Sequence[str]] = None,
    method: CorrelationMethod = "pearson",
    title: Optional[str] = None,
    style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Plot an annotated correlation-matrix heatmap.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data.
    columns : Sequence[str], optional
        Columns to correlate. A single column name given as a plain string is
        treated as a one-element list. When omitted, all numeric columns are
        used.
    method : CorrelationMethod, optional
        Passed to ``DataFrame.corr``. The default is ``"pearson"``.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        ``palette`` is used as the colormap.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the heatmap.

    Raises
    ------
    AttributeError
        If no columns are selected (none given, or no numeric columns found).
    """
    if isinstance(columns, str):
        # A bare string is technically a Sequence[str]; without this guard
        # list("abc") would be treated as the columns "a", "b" and "c".
        columns = [columns]

    if columns is not None:
        subset = list(columns)
        if not subset:
            raise AttributeError("No columns provided for correlation matrix")
    else:
        subset = list(pd_df.select_dtypes(include=[np.number]).columns)
        if not subset:
            raise AttributeError(
                "No numeric columns available for correlation matrix"
            )

    validate_dataframe(pd_df, cols=subset)
    plot_ax = _get_axis(ax)

    selected: pd.DataFrame = pd_df.loc[:, subset]
    corr = selected.corr(method=method)

    heatmap_kwargs: dict = {}
    if isinstance(method, str):
        # The built-in methods (pearson/kendall/spearman) are bounded by
        # [-1, 1]. Pin the colour scale to that range so colours mean the same
        # thing across plots instead of stretching to the observed min/max.
        # A custom callable may return other ranges, so its scale is left to
        # seaborn.
        heatmap_kwargs.update(vmin=-1.0, vmax=1.0)
    sns.heatmap(
        corr, cmap=style.palette, annot=True, fmt=".2f", ax=plot_ax, **heatmap_kwargs
    )
    if title:
        plot_ax.set_title(title)
    return plot_ax
72
+
73
+
74
def fplot_heatmap(
    pd_df: pd.DataFrame,
    x: str,
    y: str,
    value: str,
    title: Optional[str] = None,
    style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (10, 6),
) -> Figure:
    """Draw a matrix heatmap on a freshly created figure.

    Figure-level counterpart of :func:`aplot_heatmap`. The plotting arguments
    are forwarded unchanged, together with ``figsize``, through
    ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "x": x,
        "y": y,
        "value": value,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(aplot_heatmap, pd_df=pd_df, figsize=figsize, **chart_options)
94
+
95
+
96
def fplot_correlation_matrix(
    pd_df: pd.DataFrame,
    columns: Optional[Sequence[str]] = None,
    method: CorrelationMethod = "pearson",
    title: Optional[str] = None,
    style: StyleTemplate = HEATMAP_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (10, 6),
) -> Figure:
    """Draw a correlation-matrix heatmap on a freshly created figure.

    Figure-level counterpart of :func:`aplot_correlation_matrix`. The plotting
    arguments are forwarded unchanged, together with ``figsize``, through
    ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "columns": columns,
        "method": method,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(
        aplot_correlation_matrix, pd_df=pd_df, figsize=figsize, **chart_options
    )
@@ -0,0 +1,69 @@
1
+ """Histogram and KDE plotting helpers."""
2
+
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import pandas as pd
6
+ import seaborn as sns
7
+ from matplotlib.axes import Axes
8
+ from matplotlib.figure import Figure
9
+
10
+ from .StyleTemplate import (
11
+ DISTRIBUTION_STYLE_TEMPLATE,
12
+ StyleTemplate,
13
+ string_formatter,
14
+ validate_dataframe,
15
+ )
16
+ from ._visualization_utils import _get_axis, _wrap_aplot
17
+
18
+
19
def aplot_histogram_kde(
    pd_df: pd.DataFrame,
    column: str,
    bins: int = 20,
    kde: bool = True,
    title: Optional[str] = None,
    style: StyleTemplate = DISTRIBUTION_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Draw a histogram of ``column``, optionally overlaid with a KDE curve.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data.
    column : str
        Column whose distribution is plotted.
    bins : int, optional
        Number of histogram bins. The default is ``20``.
    kde : bool, optional
        Overlay a kernel density estimate. The default is ``True``.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        ``font_color`` fills the bars; ``background_color`` is used for bar
        edges and the axes face.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the histogram.
    """
    validate_dataframe(pd_df, cols=[column])
    target_ax = _get_axis(ax)

    hist_options = {
        "bins": bins,
        "kde": kde,
        "color": style.font_color,
        "edgecolor": style.background_color,
    }
    sns.histplot(data=pd_df, x=column, ax=target_ax, **hist_options)

    target_ax.set_facecolor(style.background_color)
    target_ax.set_xlabel(string_formatter(column))
    target_ax.set_ylabel("Frequency")
    if title:
        target_ax.set_title(title)
    return target_ax
48
+
49
+
50
def fplot_histogram_kde(
    pd_df: pd.DataFrame,
    column: str,
    bins: int = 20,
    kde: bool = True,
    title: Optional[str] = None,
    style: StyleTemplate = DISTRIBUTION_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (10, 6),
) -> Figure:
    """Draw a histogram (with optional KDE) on a freshly created figure.

    Figure-level counterpart of :func:`aplot_histogram_kde`. The plotting
    arguments are forwarded unchanged, together with ``figsize``, through
    ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "column": column,
        "bins": bins,
        "kde": kde,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(
        aplot_histogram_kde, pd_df=pd_df, figsize=figsize, **chart_options
    )
+ )
MatplotLibAPI/Network.py CHANGED
@@ -257,6 +257,21 @@ class NetworkGraph:
257
257
  """Return an ``AdjacencyView`` of the graph."""
258
258
  return AdjacencyView(self._nx_graph.adj)
259
259
 
260
@property
def connected_components(self) -> List[set]:
    """List the connected components of the wrapped graph.

    Each component is returned as a set of node identifiers, exactly as
    yielded by ``networkx.connected_components``.
    """
    return [component for component in nx.connected_components(self._nx_graph)]
264
+
265
@property
def number_of_nodes(self) -> int:
    """Count of nodes in the wrapped networkx graph."""
    graph = self._nx_graph
    return graph.number_of_nodes()
269
+
270
@property
def number_of_edges(self) -> int:
    """Count of edges in the wrapped networkx graph."""
    graph = self._nx_graph
    return graph.number_of_edges()
274
+
260
275
  def edge_subgraph(self, edges: Iterable) -> "NetworkGraph":
261
276
  """Return a subgraph containing only the specified edges.
262
277
 
@@ -441,12 +456,11 @@ class NetworkGraph:
441
456
  return ax
442
457
 
443
458
  def plot_network_components(self, *args: Any, **kwargs: Any) -> List:
444
- """Plot network components (DEPRECATED).
459
+ """Plot network components.
445
460
 
446
461
  .. deprecated:: 0.1.0
447
- This method will be removed in a future version.
448
- Please use `fplot_network_components` which provides a figure-level interface
449
- for plotting components.
462
+ `plot_network_components` will be removed in a future version.
463
+ Use `fplot_network_components` instead.
450
464
  """
451
465
  import warnings
452
466
 
@@ -580,7 +594,37 @@ class NetworkGraph:
580
594
  return NetworkGraph(network_G)
581
595
 
582
596
 
583
- def _prepare_network_graph(
597
def compute_network_grid(
    connected_components: List[set],
    style: StyleTemplate,
    figsize: Tuple[float, float] = (19.2, 10.8),
) -> Tuple[Figure, np.ndarray]:
    """Compute the grid layout for network component subplots.

    Builds a near-square grid with ``ceil(sqrt(n))`` columns and just enough
    rows to give every connected component its own subplot.

    Parameters
    ----------
    connected_components : list[set]
        A list of sets, where each set contains the nodes of a connected component.
    style : StyleTemplate
        The style template; its ``background_color`` is applied to the figure.
    figsize : tuple of float, optional
        Figure size in inches. The default is ``(19.2, 10.8)``.

    Returns
    -------
    Tuple[Figure, np.ndarray]
        The new figure and a flat 1-D array of its axes in row-major order.
        The grid may hold more axes than components; extra axes are left
        untouched here.
    """
    # Always allocate at least one subplot: with zero components the column
    # count would be 0 and the row computation would divide 0 by 0.
    n_components = max(len(connected_components), 1)
    n_cols = int(np.ceil(np.sqrt(n_components)))
    n_rows = int(np.ceil(n_components / n_cols))
    # squeeze=False always yields a 2-D ndarray of Axes, even for a 1x1 grid,
    # so no special case is needed for a lone Axes object.
    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    fig = cast(Figure, fig)
    fig.patch.set_facecolor(style.background_color)
    return fig, axes_grid.ravel()
625
+
626
+
627
+ def prepare_network_graph(
584
628
  pd_df: pd.DataFrame,
585
629
  source: str,
586
630
  target: str,
@@ -588,7 +632,37 @@ def _prepare_network_graph(
588
632
  sort_by: Optional[str],
589
633
  node_list: Optional[List],
590
634
  ) -> NetworkGraph:
591
- """Prepare NetworkGraph for plotting."""
635
+ """Prepare a NetworkGraph for plotting from a pandas DataFrame.
636
+
637
+ This function takes a DataFrame and prepares it for network visualization by:
638
+ 1. Filtering the DataFrame to include only the nodes in `node_list` (if provided).
639
+ 2. Validating the DataFrame to ensure it has the required columns.
640
+ 3. Creating a `NetworkGraph` from the edge list.
641
+ 4. Extracting the k-core of the graph (k=2) to focus on the main structure.
642
+ 5. Calculating node weights based on the sum of their top k edge weights.
643
+ 6. Trimming the graph to keep only the top k edges per node.
644
+
645
+ Parameters
646
+ ----------
647
+ pd_df : pd.DataFrame
648
+ DataFrame containing the edge list.
649
+ source : str
650
+ Column name for source nodes.
651
+ target : str
652
+ Column name for target nodes.
653
+ weight : str
654
+ Column name for edge weights.
655
+ sort_by : str, optional
656
+ Column to sort the DataFrame by before processing.
657
+ node_list : list, optional
658
+ A list of nodes to include in the graph. If provided, the DataFrame
659
+ will be filtered to include only edges connected to these nodes.
660
+
661
+ Returns
662
+ -------
663
+ NetworkGraph
664
+ The prepared `NetworkGraph` object.
665
+ """
592
666
  if node_list:
593
667
  df = pd_df.loc[
594
668
  (pd_df["source"].isin(node_list)) | (pd_df["target"].isin(node_list))
@@ -648,12 +722,13 @@ def aplot_network(
648
722
  Axes
649
723
  Matplotlib axes with the plotted network.
650
724
  """
651
- graph = _prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
725
+ graph = prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
652
726
  return graph.plot_network(title=title, style=style, weight=weight, ax=ax)
653
727
 
654
728
 
655
729
  def aplot_network_components(
656
730
  pd_df: pd.DataFrame,
731
+ axes: Optional[np.ndarray],
657
732
  source: str = "source",
658
733
  target: str = "target",
659
734
  weight: str = "weight",
@@ -662,7 +737,6 @@ def aplot_network_components(
662
737
  sort_by: Optional[str] = None,
663
738
  node_list: Optional[List] = None,
664
739
  ascending: bool = False,
665
- axes: Optional[np.ndarray] = None,
666
740
  ) -> None:
667
741
  """Plot network components separately on multiple axes.
668
742
 
@@ -686,10 +760,10 @@ def aplot_network_components(
686
760
  Nodes to include.
687
761
  ascending : bool, optional
688
762
  Sort order for the data. The default is `False`.
689
- axes : np.ndarray, optional
690
- Existing axes to draw on. If None, new axes are created.
763
+ axes : np.ndarray
764
+ Existing axes to draw on.
691
765
  """
692
- graph = _prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
766
+ graph = prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
693
767
 
694
768
  connected_components = list(nx.connected_components(graph._nx_graph))
695
769
 
@@ -699,37 +773,27 @@ def aplot_network_components(
699
773
  ax.set_axis_off()
700
774
  return
701
775
 
702
- if axes is None:
703
- n_components = len(connected_components)
704
- n_cols = int(np.ceil(np.sqrt(n_components)))
705
- n_rows = int(np.ceil(n_components / n_cols))
706
- fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=(19.2, 10.8))
707
- fig = cast(Figure, fig)
708
- fig.patch.set_facecolor(style.background_color)
709
- if not isinstance(axes_grid, np.ndarray):
710
- axes = np.array([axes_grid])
711
- else:
712
- axes = axes_grid.flatten()
776
+ local_axes = axes
777
+ if local_axes is None:
778
+ fig, local_axes = compute_network_grid(connected_components, style)
713
779
 
714
780
  i = -1
715
781
  for i, component in enumerate(connected_components):
716
- if i < len(axes):
782
+ if i < len(local_axes):
717
783
  if len(component) > 5:
718
784
  component_graph = graph.subgraph(node_list=list(component))
719
785
  component_graph.plot_network(
720
786
  title=f"{title}::{i}" if title else str(i),
721
787
  style=style,
722
788
  weight=weight,
723
- ax=axes[i],
789
+ ax=local_axes[i],
724
790
  )
725
- axes[i].set_axis_on()
791
+ local_axes[i].set_axis_on()
726
792
  else:
727
- break # Stop if there are more components than axes
793
+ break
728
794
 
729
- # Turn off any unused axes
730
- if axes is not None:
731
- for j in range(i + 1, len(axes)):
732
- axes[j].set_axis_off()
795
+ for j in range(i + 1, len(local_axes)):
796
+ local_axes[j].set_axis_off()
733
797
 
734
798
 
735
799
  def fplot_network(
MatplotLibAPI/Pie.py ADDED
@@ -0,0 +1,66 @@
1
+ """Pie and donut chart helpers."""
2
+
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ import matplotlib.pyplot as plt
6
+ import pandas as pd
7
+ import seaborn as sns
8
+ from matplotlib.axes import Axes
9
+ from matplotlib.figure import Figure
10
+
11
+ from .StyleTemplate import PIE_STYLE_TEMPLATE, StyleTemplate, validate_dataframe
12
+ from ._visualization_utils import _get_axis, _wrap_aplot
13
+
14
+
15
def aplot_pie_donut(
    pd_df: pd.DataFrame,
    category: str,
    value: str,
    donut: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = PIE_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Plot pie or donut charts for categorical share visualization.

    Each row of ``pd_df`` becomes one wedge, labelled with its ``category``
    value and annotated with its percentage share.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data; one wedge per row.
    category : str
        Column providing wedge labels (converted to ``str``).
    value : str
        Column providing wedge sizes.
    donut : bool, optional
        Draw a ring (donut) instead of a full pie. The default is ``False``.
    title : str, optional
        Axes title.
    style : StyleTemplate, optional
        ``palette`` is expanded to one colour per wedge.
    ax : Axes, optional
        Axes to draw on; resolved through ``_get_axis``.
    **kwargs : Any
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    Axes
        The axes containing the chart.
    """
    validate_dataframe(pd_df, cols=[category, value])
    plot_ax = _get_axis(ax)
    labels = pd_df[category].astype(str).tolist()
    sizes = pd_df[value]

    wedgeprops: Optional[Dict[str, Any]] = None
    pctdistance = 0.6  # matplotlib's own default for Axes.pie
    if donut:
        # The ring spans radius 0.7..1.0. Move the percentage labels into it;
        # at the default 0.6 they would float in the empty centre.
        wedgeprops = {"width": 0.3}
        pctdistance = 0.85

    # Ask seaborn for one colour per wedge. Without n_colors, a named palette
    # has a fixed default length (e.g. 6 for colormap names), and matplotlib
    # would then cycle and repeat colours across wedges.
    colors = sns.color_palette(style.palette, n_colors=max(len(labels), 1))
    plot_ax.pie(
        sizes,
        labels=labels,
        autopct="%1.1f%%",
        pctdistance=pctdistance,
        colors=colors,
        wedgeprops=wedgeprops,
    )
    plot_ax.axis("equal")
    if title:
        plot_ax.set_title(title)
    return plot_ax
45
+
46
+
47
def fplot_pie_donut(
    pd_df: pd.DataFrame,
    category: str,
    value: str,
    donut: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = PIE_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (8, 8),
) -> Figure:
    """Draw a pie or donut chart on a freshly created figure.

    Figure-level counterpart of :func:`aplot_pie_donut`. The plotting
    arguments are forwarded unchanged, together with ``figsize`` (square by
    default), through ``_wrap_aplot``.

    Returns
    -------
    Figure
        The figure returned by ``_wrap_aplot``.
    """
    chart_options = {
        "category": category,
        "value": value,
        "donut": donut,
        "title": title,
        "style": style,
    }
    return _wrap_aplot(
        aplot_pie_donut, pd_df=pd_df, figsize=figsize, **chart_options
    )
+ )