MatplotLibAPI 3.2.21__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MatplotLibAPI/__init__.py +4 -86
- MatplotLibAPI/accessor.py +519 -196
- MatplotLibAPI/area.py +177 -0
- MatplotLibAPI/bar.py +185 -0
- MatplotLibAPI/base_plot.py +88 -0
- MatplotLibAPI/box_violin.py +180 -0
- MatplotLibAPI/bubble.py +568 -0
- MatplotLibAPI/{Composite.py → composite.py} +127 -106
- MatplotLibAPI/heatmap.py +223 -0
- MatplotLibAPI/histogram.py +170 -0
- MatplotLibAPI/mcp/__init__.py +17 -0
- MatplotLibAPI/mcp/metadata.py +90 -0
- MatplotLibAPI/mcp/renderers.py +45 -0
- MatplotLibAPI/mcp_server.py +626 -0
- MatplotLibAPI/network/__init__.py +28 -0
- MatplotLibAPI/network/constants.py +22 -0
- MatplotLibAPI/network/core.py +1360 -0
- MatplotLibAPI/network/plot.py +597 -0
- MatplotLibAPI/network/scaling.py +56 -0
- MatplotLibAPI/pie.py +154 -0
- MatplotLibAPI/pivot.py +274 -0
- MatplotLibAPI/sankey.py +99 -0
- MatplotLibAPI/{StyleTemplate.py → style_template.py} +27 -22
- MatplotLibAPI/sunburst.py +139 -0
- MatplotLibAPI/{Table.py → table.py} +112 -87
- MatplotLibAPI/{Timeserie.py → timeserie.py} +98 -42
- MatplotLibAPI/{Treemap.py → treemap.py} +43 -55
- MatplotLibAPI/typing.py +12 -0
- MatplotLibAPI/{_visualization_utils.py → utils.py} +7 -13
- MatplotLibAPI/waffle.py +173 -0
- MatplotLibAPI/word_cloud.py +489 -0
- {matplotlibapi-3.2.21.dist-info → matplotlibapi-4.0.0.dist-info}/METADATA +98 -9
- matplotlibapi-4.0.0.dist-info/RECORD +36 -0
- {matplotlibapi-3.2.21.dist-info → matplotlibapi-4.0.0.dist-info}/WHEEL +1 -1
- matplotlibapi-4.0.0.dist-info/entry_points.txt +2 -0
- MatplotLibAPI/Area.py +0 -80
- MatplotLibAPI/Bar.py +0 -83
- MatplotLibAPI/BoxViolin.py +0 -75
- MatplotLibAPI/Bubble.py +0 -458
- MatplotLibAPI/Heatmap.py +0 -121
- MatplotLibAPI/Histogram.py +0 -73
- MatplotLibAPI/Network.py +0 -989
- MatplotLibAPI/Pie.py +0 -70
- MatplotLibAPI/Pivot.py +0 -134
- MatplotLibAPI/Sankey.py +0 -46
- MatplotLibAPI/Sunburst.py +0 -89
- MatplotLibAPI/Waffle.py +0 -86
- MatplotLibAPI/Wordcloud.py +0 -373
- MatplotLibAPI/_typing.py +0 -17
- matplotlibapi-3.2.21.dist-info/RECORD +0 -26
- {matplotlibapi-3.2.21.dist-info → matplotlibapi-4.0.0.dist-info}/licenses/LICENSE +0 -0
MatplotLibAPI/pie.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""Pie and donut chart helpers."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import seaborn as sns
|
|
8
|
+
from matplotlib.axes import Axes
|
|
9
|
+
from matplotlib.figure import Figure
|
|
10
|
+
|
|
11
|
+
from .base_plot import BasePlot
|
|
12
|
+
|
|
13
|
+
from .style_template import PIE_STYLE_TEMPLATE, StyleTemplate, validate_dataframe
|
|
14
|
+
from .utils import _get_axis
|
|
15
|
+
|
|
16
|
+
__all__ = ["PIE_STYLE_TEMPLATE", "aplot_pie", "fplot_pie"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PieChart(BasePlot):
    """Plot pie and donut charts from categorical aggregates.

    Methods
    -------
    aplot
        Plot a pie or donut chart on an existing Matplotlib axes.
    fplot
        Plot a pie or donut chart on a new Matplotlib figure.
    """

    def __init__(self, pd_df: pd.DataFrame, category: str, value: str):
        """Validate the input frame and remember which columns to plot.

        Parameters
        ----------
        pd_df : pd.DataFrame
            Data with one row per wedge.
        category : str
            Column whose values (cast to ``str``) label the wedges.
        value : str
            Column whose values size the wedges.
        """
        validate_dataframe(pd_df, cols=[category, value])
        super().__init__(pd_df=pd_df)
        self.category = category
        self.value = value

    def aplot(
        self,
        donut: bool = False,
        title: Optional[str] = None,
        style: StyleTemplate = PIE_STYLE_TEMPLATE,
        ax: Optional[Axes] = None,
        **kwargs: Any,
    ) -> Axes:
        """Plot a pie or donut chart on the provided axis.

        Parameters
        ----------
        donut : bool, optional
            If True, render a donut chart. The default is False.
        title : str, optional
            Title for the plot. The default is None.
        style : StyleTemplate, optional
            Style template for the plot. The default is PIE_STYLE_TEMPLATE.
        ax : Axes, optional
            Matplotlib axes to plot on; resolved through ``_get_axis``.
        **kwargs : Any
            Additional keyword arguments reserved for compatibility
            (accepted but not used by this method).

        Returns
        -------
        Axes
            The Matplotlib axes containing the pie or donut chart.
        """
        labels = self._obj[self.category].astype(str).tolist()
        sizes = self._obj[self.value]
        plot_ax = _get_axis(ax)
        # A donut is a pie whose wedges only cover the outer 30% of the radius.
        wedgeprops: Optional[Dict[str, Any]] = {"width": 0.3} if donut else None
        plot_ax.pie(
            sizes,
            labels=labels,
            autopct="%1.1f%%",
            colors=sns.color_palette(style.palette),
            wedgeprops=wedgeprops,
            textprops={"color": style.font_color, "fontsize": style.font_size},
        )
        # Equal aspect keeps the pie circular regardless of the figure shape.
        plot_ax.axis("equal")
        if title:
            plot_ax.set_title(title)
        return plot_ax

    def fplot(
        self,
        donut: bool = False,
        title: Optional[str] = None,
        style: StyleTemplate = PIE_STYLE_TEMPLATE,
        figsize: Tuple[float, float] = (8, 8),
    ) -> Figure:
        """Plot a pie or donut chart on a new figure.

        Parameters
        ----------
        donut : bool, optional
            If True, render a donut chart. The default is False.
        title : str, optional
            Title for the plot. The default is None.
        style : StyleTemplate, optional
            Style template for the plot. The default is PIE_STYLE_TEMPLATE.
        figsize : tuple[float, float], optional
            Figure size. The default is (8, 8).

        Returns
        -------
        Figure
            The Matplotlib figure containing the pie or donut chart.
        """
        fig = Figure(
            figsize=figsize,
            facecolor=style.background_color,
            edgecolor=style.background_color,
        )
        # Instantiating ``Axes`` directly without a rect/subplot spec fails and
        # would not attach the axes to the figure; ``add_subplot`` creates the
        # axes and registers it on ``fig`` so the chart is actually rendered.
        ax = fig.add_subplot(1, 1, 1, facecolor=style.background_color)
        self.aplot(donut=donut, title=title, style=style, ax=ax)
        return fig
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def aplot_pie(
    pd_df: pd.DataFrame,
    category: str,
    value: str,
    donut: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = PIE_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    **kwargs: Any,
) -> Axes:
    """Draw a pie or donut chart of categorical shares on an axes.

    Builds a ``PieChart`` from ``pd_df`` using the ``category`` and ``value``
    columns, then delegates to ``PieChart.aplot`` with the remaining
    arguments (including any extra keyword arguments).
    """
    chart = PieChart(pd_df=pd_df, category=category, value=value)
    drawn_axes = chart.aplot(donut=donut, title=title, style=style, ax=ax, **kwargs)
    return drawn_axes
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def fplot_pie(
    pd_df: pd.DataFrame,
    category: str,
    value: str,
    donut: bool = False,
    title: Optional[str] = None,
    style: StyleTemplate = PIE_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = (8, 8),
) -> Figure:
    """Draw a pie or donut chart of categorical shares on a new figure.

    Builds a ``PieChart`` from ``pd_df`` using the ``category`` and ``value``
    columns, then delegates to ``PieChart.fplot``.
    """
    chart = PieChart(pd_df=pd_df, category=category, value=value)
    figure = chart.fplot(donut=donut, title=title, style=style, figsize=figsize)
    return figure
|
MatplotLibAPI/pivot.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Pivot chart helpers for bar and line plots."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional, Tuple, cast
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from matplotlib.axes import Axes
|
|
9
|
+
from matplotlib.figure import Figure
|
|
10
|
+
|
|
11
|
+
from .base_plot import BasePlot
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from .style_template import (
|
|
15
|
+
FIG_SIZE,
|
|
16
|
+
PIVOTBARS_STYLE_TEMPLATE,
|
|
17
|
+
StyleTemplate,
|
|
18
|
+
string_formatter,
|
|
19
|
+
validate_dataframe,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Public API of the pivot module: export both the axes-level and the
# figure-level helper, mirroring the other chart modules.
__all__ = ["PIVOTBARS_STYLE_TEMPLATE", "aplot_pivoted_bars", "fplot_pivoted_bars"]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _pivot_and_sort_data(
|
|
26
|
+
data: pd.DataFrame,
|
|
27
|
+
index: str,
|
|
28
|
+
columns: str,
|
|
29
|
+
values: str,
|
|
30
|
+
aggfunc: str = "sum",
|
|
31
|
+
sort_by: Optional[str] = None,
|
|
32
|
+
ascending: bool = False,
|
|
33
|
+
) -> pd.DataFrame:
|
|
34
|
+
"""Pivot and sort a DataFrame.
|
|
35
|
+
|
|
36
|
+
Parameters
|
|
37
|
+
----------
|
|
38
|
+
data : pd.DataFrame
|
|
39
|
+
The input DataFrame.
|
|
40
|
+
index : str
|
|
41
|
+
The column to use as the pivot table index.
|
|
42
|
+
columns : str
|
|
43
|
+
The column to use for pivot table columns.
|
|
44
|
+
values : str
|
|
45
|
+
The column to aggregate.
|
|
46
|
+
aggfunc : str, optional
|
|
47
|
+
The aggregation function, by default "sum".
|
|
48
|
+
sort_by : str, optional
|
|
49
|
+
The column to sort by.
|
|
50
|
+
ascending : bool, optional
|
|
51
|
+
The sort order, by default `False`.
|
|
52
|
+
|
|
53
|
+
Returns
|
|
54
|
+
-------
|
|
55
|
+
pd.DataFrame
|
|
56
|
+
A pivoted and sorted DataFrame.
|
|
57
|
+
"""
|
|
58
|
+
pivot_df = pd.pivot_table(
|
|
59
|
+
data, values=values, index=[index], columns=[columns], aggfunc=aggfunc # type: ignore
|
|
60
|
+
) # type: ignore
|
|
61
|
+
if sort_by:
|
|
62
|
+
pivot_df = pivot_df.sort_values(by=sort_by, ascending=ascending)
|
|
63
|
+
return pivot_df.reset_index()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class PivotBarChart(BasePlot):
    """Class for plotting bar charts from pivoted data."""

    def __init__(
        self,
        pd_df: pd.DataFrame,
        label: str,
        x: str,
        y: str,
        agg: str = "sum",
        stacked: bool = False,
    ):
        """Validate the input frame and store the pivot configuration.

        Parameters
        ----------
        pd_df : pd.DataFrame
            The DataFrame containing the data to plot.
        label : str
            The column to pivot into series.
        x : str
            The column for the x-axis.
        y : str
            The column for the y-values.
        agg : str, optional
            The aggregation function for the pivot. The default is "sum".
        stacked : bool, optional
            Whether to stack the bars. The default is `False`.
        """
        cols = [label, x, y]
        validate_dataframe(pd_df, cols=cols)
        super().__init__(pd_df=pd_df)
        self.label = label
        self.x = x
        self.y = y
        self.agg = agg
        self.stacked = stacked

    def aplot(
        self,
        title: Optional[str] = None,
        style: StyleTemplate = PIVOTBARS_STYLE_TEMPLATE,
        sort_by: Optional[str] = None,
        ascending: bool = False,
        ax: Optional[Axes] = None,
        **kwargs: Any,
    ) -> Axes:
        """Plot the pivoted bar chart on an axes.

        Parameters
        ----------
        title : str, optional
            The plot title.
        style : StyleTemplate, optional
            The style configuration. The default is `PIVOTBARS_STYLE_TEMPLATE`.
        sort_by : str, optional
            The pivoted column to sort by.
        ascending : bool, optional
            The sort order. The default is `False`.
        ax : Axes, optional
            The axes to draw on. If None, the current pyplot axes is used.
        **kwargs : Any
            Accepted for compatibility; not used by this method.

        Returns
        -------
        Axes
            The matplotlib axes with the bar chart.
        """
        pivot_df = _pivot_and_sort_data(
            self._obj,
            index=self.x,
            columns=self.label,
            values=self.y,
            aggfunc=self.agg,
            sort_by=sort_by,
            ascending=ascending,
        )

        if ax is None:
            ax = cast(Axes, plt.gca())

        # Bar plots treat x as categorical, so render datetimes as short dates
        # instead of full timestamps.
        if pd.api.types.is_datetime64_any_dtype(pivot_df[self.x]):
            pivot_df[self.x] = pivot_df[self.x].dt.strftime("%Y-%m-%d")

        pivot_df.plot(kind="bar", x=self.x, stacked=self.stacked, ax=ax, alpha=0.7)

        ax.set_ylabel(string_formatter(self.y))
        ax.set_xlabel(string_formatter(self.x))
        if title:
            ax.set_title(title)

        ax.legend(
            fontsize=style.font_size - 2,
            title_fontsize=style.font_size + 2,
            labelcolor="linecolor",
            facecolor=style.background_color,
        )
        ax.tick_params(axis="x", rotation=90)
        return ax

    def fplot(
        self,
        title: Optional[str] = None,
        style: StyleTemplate = PIVOTBARS_STYLE_TEMPLATE,
        sort_by: Optional[str] = None,
        ascending: bool = False,
        ax: Optional[Axes] = None,
        figsize: Tuple[float, float] = FIG_SIZE,
        **kwargs: Any,
    ) -> Figure:
        """Plot the pivoted bar chart on a new figure.

        Parameters
        ----------
        title : str, optional
            The plot title.
        style : StyleTemplate, optional
            The style configuration. The default is `PIVOTBARS_STYLE_TEMPLATE`.
        sort_by : str, optional
            The pivoted column to sort by.
        ascending : bool, optional
            The sort order. The default is `False`.
        ax : Axes, optional
            Accepted for signature compatibility only; a fresh axes is always
            created on the new figure.
        figsize : tuple[float, float], optional
            Figure size. The default is `FIG_SIZE`.
        **kwargs : Any
            Accepted for compatibility; not used by this method.

        Returns
        -------
        Figure
            The matplotlib figure with the bar chart.
        """
        fig = Figure(
            figsize=figsize,
            facecolor=style.background_color,
            edgecolor=style.background_color,
        )
        # Instantiating ``Axes`` directly without a rect/subplot spec fails and
        # would not attach the axes to the figure; ``add_subplot`` creates the
        # axes and registers it on ``fig`` so the chart is actually rendered.
        ax = fig.add_subplot(1, 1, 1, facecolor=style.background_color)
        self.aplot(
            title=title, style=style, sort_by=sort_by, ascending=ascending, ax=ax
        )
        return fig
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def aplot_pivoted_bars(
    data: pd.DataFrame,
    label: str,
    x: str,
    y: str,
    agg: str = "sum",
    style: StyleTemplate = PIVOTBARS_STYLE_TEMPLATE,
    title: Optional[str] = None,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    ax: Optional[Axes] = None,
    stacked: bool = False,
    **kwargs,
) -> Axes:
    """Draw a pivoted bar chart on an axes.

    Parameters
    ----------
    data : pd.DataFrame
        Source data for the chart.
    label : str
        Column pivoted into one bar series per distinct value.
    x : str
        Column used for the x-axis.
    y : str
        Column providing the aggregated y-values.
    agg : str, optional
        Aggregation function for the pivot. The default is "sum".
    style : StyleTemplate, optional
        Style configuration. The default is `PIVOTBARS_STYLE_TEMPLATE`.
    title : str, optional
        Plot title.
    sort_by : str, optional
        Pivoted column to sort by.
    ascending : bool, optional
        Sort order. The default is `False`.
    ax : Axes, optional
        Axes to draw on.
    stacked : bool, optional
        Whether to stack the bars. The default is `False`.
    **kwargs
        Extra keyword arguments forwarded to ``PivotBarChart.aplot``.

    Returns
    -------
    Axes
        The matplotlib axes with the bar chart.
    """
    chart = PivotBarChart(
        pd_df=data, label=label, x=x, y=y, agg=agg, stacked=stacked
    )
    drawn_axes = chart.aplot(
        title=title,
        style=style,
        sort_by=sort_by,
        ascending=ascending,
        ax=ax,
        **kwargs,
    )
    return drawn_axes
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def fplot_pivoted_bars(
    pd_df: pd.DataFrame,
    label: str,
    x: str,
    y: str,
    agg: str = "sum",
    style: StyleTemplate = PIVOTBARS_STYLE_TEMPLATE,
    title: Optional[str] = None,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    ax: Optional[Axes] = None,
    stacked: bool = False,
    figsize: Tuple[float, float] = FIG_SIZE,
) -> Figure:
    """Draw a pivoted bar chart on a new figure.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Source data for the chart.
    label : str
        Column pivoted into one bar series per distinct value.
    x : str
        Column used for the x-axis.
    y : str
        Column providing the aggregated y-values.
    agg : str, optional
        Aggregation function for the pivot. The default is "sum".
    style : StyleTemplate, optional
        Style configuration. The default is `PIVOTBARS_STYLE_TEMPLATE`.
    title : str, optional
        Plot title.
    sort_by : str, optional
        Pivoted column to sort by.
    ascending : bool, optional
        Sort order. The default is `False`.
    ax : Axes, optional
        Forwarded to ``PivotBarChart.fplot``.
    stacked : bool, optional
        Whether to stack the bars. The default is `False`.
    figsize : tuple[float, float], optional
        Figure size. The default is `FIG_SIZE`.

    Returns
    -------
    Figure
        The matplotlib figure with the bar chart.
    """
    chart = PivotBarChart(
        pd_df=pd_df, label=label, x=x, y=y, agg=agg, stacked=stacked
    )
    figure = chart.fplot(
        title=title,
        style=style,
        sort_by=sort_by,
        ascending=ascending,
        ax=ax,
        figsize=figsize,
    )
    return figure
|
MatplotLibAPI/sankey.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Sankey plotting helpers."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, List, Optional, cast
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import plotly.graph_objects as go
|
|
8
|
+
|
|
9
|
+
from .style_template import SANKEY_STYLE_TEMPLATE, StyleTemplate, validate_dataframe
|
|
10
|
+
|
|
11
|
+
__all__ = ["SANKEY_STYLE_TEMPLATE", "fplot_sankey"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class SankeyNode:
    """Node section of a Sankey diagram."""

    # Display label of every node, in node order.
    label: List[str]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SankeyLink:
    """Link section of a Sankey diagram, stored as three parallel lists."""

    # Entry ``i`` of each list describes the same link.
    source: List[int]
    target: List[int]
    value: List[float]

    def __post_init__(self) -> None:
        """Reject links whose parallel lists differ in length."""
        n_source = len(self.source)
        n_target = len(self.target)
        n_value = len(self.value)
        if n_source == n_target and n_target == n_value:
            return
        raise ValueError(
            f"All lists must have the same length. "
            f"Got source={n_source}, target={n_target}, value={n_value}"
        )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class SankeyData:
    """Nodes and links of a Sankey diagram, plus a Plotly renderer."""

    node: SankeyNode
    link: SankeyLink

    @staticmethod
    def from_pandas_edgelist(
        edges_df: pd.DataFrame,
        source: str = "source",
        target: str = "target",
        edge_weight_col: str = "weight",
    ) -> "SankeyData":
        """Create SankeyData from an edge-list DataFrame.

        Parameters
        ----------
        edges_df : pd.DataFrame
            One row per flow.
        source : str, optional
            Column holding the origin of each flow. The default is "source".
        target : str, optional
            Column holding the destination of each flow. The default is
            "target".
        edge_weight_col : str, optional
            Column holding the size of each flow. The default is "weight".

        Returns
        -------
        SankeyData
            Node labels are the distinct source/target values in order of
            first appearance; links reference them by position.
        """
        validate_dataframe(edges_df, cols=[source, target, edge_weight_col])
        source_series = cast(pd.Series, edges_df[source])
        target_series = cast(pd.Series, edges_df[target])
        labels: List[str] = list(pd.unique(pd.concat([source_series, target_series])))
        label_to_index: Dict[str, int] = {name: idx for idx, name in enumerate(labels)}

        return SankeyData(
            node=SankeyNode(label=labels),
            link=SankeyLink(
                source=[label_to_index[val] for val in source_series],
                target=[label_to_index[val] for val in target_series],
                value=edges_df[edge_weight_col].tolist(),
            ),
        )

    def fplot(
        self, title: Optional[str] = None, style: StyleTemplate = SANKEY_STYLE_TEMPLATE
    ) -> go.Figure:
        """Plot the Sankey diagram using Plotly.

        Parameters
        ----------
        title : str, optional
            Figure title. No title is set when falsy.
        style : StyleTemplate, optional
            Supplies the node colour and the font colour/size. The default is
            `SANKEY_STYLE_TEMPLATE`.

        Returns
        -------
        go.Figure
            The Plotly figure containing the Sankey trace.
        """
        sankey = go.Sankey(
            node=dict(
                label=self.node.label,
                pad=15,
                thickness=20,
                color=style.font_color,
            ),
            link=dict(
                source=self.link.source,
                target=self.link.target,
                value=self.link.value,
            ),
        )

        fig = go.Figure(sankey)
        # The font also drives the node labels, so apply the style even when
        # no title is requested; previously it was skipped for untitled plots.
        fig.update_layout(font=dict(color=style.font_color, size=style.font_size))
        if title:
            fig.update_layout(title_text=title)
        return fig
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def fplot_sankey(
    pd_df: pd.DataFrame,
    source: str,
    target: str,
    value: str,
    title: Optional[str] = None,
    style: StyleTemplate = SANKEY_STYLE_TEMPLATE,
) -> go.Figure:
    """Render the flows between categories as a Sankey diagram.

    The edge list in ``pd_df`` is converted with
    ``SankeyData.from_pandas_edgelist`` (``value`` is the flow-size column)
    and drawn through ``SankeyData.fplot``.
    """
    flows = SankeyData.from_pandas_edgelist(
        edges_df=pd_df,
        source=source,
        target=target,
        edge_weight_col=value,
    )
    figure = flows.fplot(title=title, style=style)
    return figure
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Common style utilities and formatters for plotting."""
|
|
2
2
|
|
|
3
|
+
import math
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import Callable, Dict, List, Optional, Union, cast
|
|
5
6
|
|
|
@@ -90,10 +91,11 @@ def format_func(
|
|
|
90
91
|
# region Style Constants
|
|
91
92
|
|
|
92
93
|
FIG_SIZE = (19.2, 10.8)
|
|
93
|
-
BACKGROUND_COLOR = "
|
|
94
|
-
TEXT_COLOR = "
|
|
95
|
-
PALETTE = "
|
|
94
|
+
BACKGROUND_COLOR = "white"
|
|
95
|
+
TEXT_COLOR = "black"
|
|
96
|
+
PALETTE = "tab10"
|
|
96
97
|
FONT_SIZE = 14
|
|
98
|
+
TITLE_SCALE_FACTOR = 2
|
|
97
99
|
MAX_RESULTS = 50
|
|
98
100
|
|
|
99
101
|
|
|
@@ -122,19 +124,23 @@ class StyleTemplate:
|
|
|
122
124
|
|
|
123
125
|
@property
|
|
124
126
|
def font_mapping(self) -> Dict[int, int]:
|
|
125
|
-
"""
|
|
127
|
+
"""Compute progressive font sizes based on the base font.
|
|
128
|
+
|
|
129
|
+
The mapping spans five emphasis levels, centered around ``font_size``.
|
|
130
|
+
Each step is scaled to 15% of the base font (minimum step of 1) and
|
|
131
|
+
clamped to a size of at least 1 point to avoid non-readable values for
|
|
132
|
+
very small fonts.
|
|
126
133
|
|
|
127
134
|
Returns
|
|
128
135
|
-------
|
|
129
136
|
dict[int, int]
|
|
130
|
-
Level
|
|
137
|
+
Level-to-font-size mapping where keys increase with size.
|
|
131
138
|
"""
|
|
139
|
+
base_size = max(int(self.font_size), 1)
|
|
140
|
+
step = max(int(math.ceil(base_size * 0.15)), 1)
|
|
132
141
|
return {
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
2: self.font_size,
|
|
136
|
-
3: self.font_size + 1,
|
|
137
|
-
4: self.font_size + 3,
|
|
142
|
+
idx: max(base_size + offset * step, 1)
|
|
143
|
+
for idx, offset in enumerate(range(-2, 3))
|
|
138
144
|
}
|
|
139
145
|
|
|
140
146
|
|
|
@@ -209,7 +215,7 @@ def generate_ticks(
|
|
|
209
215
|
min_val: Union[float, str, pd.Timestamp],
|
|
210
216
|
max_val: Union[float, str, pd.Timestamp],
|
|
211
217
|
num_ticks: int = 5,
|
|
212
|
-
) -> Union[np.ndarray, pd.DatetimeIndex]:
|
|
218
|
+
) -> Union[np.ndarray, pd.DatetimeIndex, pd.Timestamp]:
|
|
213
219
|
"""Generate evenly spaced ticks between min and max.
|
|
214
220
|
|
|
215
221
|
Parameters
|
|
@@ -261,6 +267,10 @@ def generate_ticks(
|
|
|
261
267
|
|
|
262
268
|
# region Style Presets
|
|
263
269
|
|
|
270
|
+
TABLE_STYLE_TEMPLATE = StyleTemplate(
|
|
271
|
+
background_color="black", fig_border="darkgrey", font_color="white", palette="magma"
|
|
272
|
+
)
|
|
273
|
+
|
|
264
274
|
BUBBLE_STYLE_TEMPLATE = StyleTemplate(
|
|
265
275
|
format_funcs=cast(
|
|
266
276
|
Dict[str, Optional[FormatterFunc]],
|
|
@@ -286,19 +296,12 @@ TABLE_STYLE_TEMPLATE = StyleTemplate()
|
|
|
286
296
|
TREEMAP_STYLE_TEMPLATE = StyleTemplate()
|
|
287
297
|
|
|
288
298
|
PIVOTBARS_STYLE_TEMPLATE = StyleTemplate(
|
|
289
|
-
background_color="black",
|
|
290
|
-
fig_border="darkgrey",
|
|
291
|
-
font_color="white",
|
|
292
|
-
palette="magma",
|
|
293
299
|
format_funcs=cast(
|
|
294
300
|
Dict[str, Optional[FormatterFunc]],
|
|
295
301
|
{"y": percent_formatter, "label": string_formatter},
|
|
296
302
|
),
|
|
297
303
|
)
|
|
298
304
|
PIVOTLINES_STYLE_TEMPLATE = StyleTemplate(
|
|
299
|
-
background_color="white",
|
|
300
|
-
fig_border="lightgrey",
|
|
301
|
-
palette="viridis",
|
|
302
305
|
format_funcs=cast(
|
|
303
306
|
Dict[str, Optional[FormatterFunc]],
|
|
304
307
|
{"y": percent_formatter, "label": string_formatter},
|
|
@@ -306,9 +309,11 @@ PIVOTLINES_STYLE_TEMPLATE = StyleTemplate(
|
|
|
306
309
|
)
|
|
307
310
|
|
|
308
311
|
NETWORK_STYLE_TEMPLATE = StyleTemplate()
|
|
309
|
-
DISTRIBUTION_STYLE_TEMPLATE = StyleTemplate(
|
|
310
|
-
HEATMAP_STYLE_TEMPLATE = StyleTemplate(
|
|
311
|
-
AREA_STYLE_TEMPLATE = StyleTemplate(
|
|
312
|
-
PIE_STYLE_TEMPLATE = StyleTemplate(
|
|
312
|
+
DISTRIBUTION_STYLE_TEMPLATE = StyleTemplate()
|
|
313
|
+
HEATMAP_STYLE_TEMPLATE = StyleTemplate()
|
|
314
|
+
AREA_STYLE_TEMPLATE = StyleTemplate()
|
|
315
|
+
PIE_STYLE_TEMPLATE = StyleTemplate()
|
|
313
316
|
SANKEY_STYLE_TEMPLATE = StyleTemplate()
|
|
317
|
+
WORDCLOUD_STYLE_TEMPLATE = StyleTemplate()
|
|
318
|
+
|
|
314
319
|
# endregion
|