matplotlib-sankey 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matplotlib_sankey/__init__.py +5 -0
- matplotlib_sankey/_patches.py +143 -0
- matplotlib_sankey/_plotting.py +269 -0
- matplotlib_sankey/_types.py +8 -0
- matplotlib_sankey/_utils.py +97 -0
- matplotlib_sankey/_version.py +1 -0
- matplotlib_sankey-0.2.1.dist-info/METADATA +100 -0
- matplotlib_sankey-0.2.1.dist-info/RECORD +10 -0
- matplotlib_sankey-0.2.1.dist-info/WHEEL +4 -0
- matplotlib_sankey-0.2.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
from matplotlib.patches import PathPatch
|
|
2
|
+
from matplotlib.path import Path
|
|
3
|
+
|
|
4
|
+
from matplotlib_sankey._types import AcceptedColors
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def patch_line(
    x_start: float,
    x_end: float,
    y1_start: float,
    y1_end: float,
    y2_start: float,
    y2_end: float,
    row_index: int,
    spacing: float = 0.0,
    alpha: float = 0.5,
    color: AcceptedColors | None = None,
) -> PathPatch:
    """Build a ribbon with straight edges as a filled four-corner patch.

    The outline starts on the left at ``y1_start``, runs to the right at
    ``y1_end``, follows the right side to ``y2_start`` and returns to the left
    at ``y2_end``. Every y coordinate is shifted by ``spacing * row_index``.

    Args:
        x_start: X coordinate of the left side of the ribbon.
        x_end: X coordinate of the right side of the ribbon.
        y1_start: Y coordinate of the first edge on the left side.
        y1_end: Y coordinate of the first edge on the right side.
        y2_start: Y coordinate of the second edge on the right side.
        y2_end: Y coordinate of the second edge on the left side.
        row_index: Multiplier applied to ``spacing`` for the vertical shift.
        spacing: Vertical shift per row. Defaults to `0.0`.
        alpha: Alpha of the patch. Defaults to `0.5`.
        color: Color of the patch. Defaults to `None`.

    Returns:
        The ribbon patch, created with zero line width and z-order 0.

    """
    y_shift = spacing * row_index

    corners = [
        (x_start, y1_start + y_shift),
        (x_end, y1_end + y_shift),
        (x_end, y2_start + y_shift),
        (x_start, y2_end + y_shift),
    ]
    # One MOVETO for the first corner, then a LINETO for each remaining corner.
    corner_codes = [Path.MOVETO] + [Path.LINETO] * 3

    outline = Path(vertices=corners, codes=corner_codes, closed=True)
    return PathPatch(outline, color=color, zorder=0, alpha=alpha, lw=0)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def patch_curve3(
    x_start: float,
    x_end: float,
    y1_start: float,
    y1_end: float,
    y2_start: float,
    y2_end: float,
    row_index: int,
    spacing: float = 0.0,
    alpha: float = 0.5,
    color: AcceptedColors | None = None,
) -> PathPatch:
    """Build a ribbon whose two long edges are drawn with ``Path.CURVE3`` codes.

    Each long edge consists of four CURVE3 vertices placed around the
    horizontal midpoint between ``x_start`` and ``x_end``; the two edges are
    joined on the right side by a straight segment. Every y coordinate is
    shifted by ``spacing * row_index``.

    Args:
        x_start: X coordinate of the left side of the ribbon.
        x_end: X coordinate of the right side of the ribbon.
        y1_start: Y coordinate of the first edge on the left side.
        y1_end: Y coordinate of the first edge on the right side.
        y2_start: Y coordinate of the second edge on the right side.
        y2_end: Y coordinate of the second edge on the left side.
        row_index: Multiplier applied to ``spacing`` for the vertical shift.
        spacing: Vertical shift per row. Defaults to `0.0`.
        alpha: Alpha of the patch. Defaults to `0.5`.
        color: Color of the patch. Defaults to `None`.

    Returns:
        The ribbon patch, created with zero line width and z-order 0.

    """
    y_shift = spacing * row_index
    x_middle = ((x_end - x_start) / 2) + x_start

    # First edge: left side -> right side, bending at the horizontal midpoint.
    first_edge = [
        (x_middle, y1_start + y_shift),
        (x_middle, (y1_start - y1_end) / 2 + y1_end + y_shift),
        (x_middle, y1_end + y_shift),
        (x_end, y1_end + y_shift),
    ]
    # Second edge: right side -> left side, mirrored construction.
    second_edge = [
        (x_middle, y2_start + y_shift),
        (x_middle, (y2_start - y2_end) / 2 + y2_end + y_shift),
        (x_middle, y2_end + y_shift),
        (x_start, y2_end + y_shift),
    ]

    vertices = [
        (x_start, y1_start + y_shift),
        *first_edge,
        (x_end, y2_start + y_shift),
        *second_edge,
    ]
    codes = [
        Path.MOVETO,
        *([Path.CURVE3] * 4),
        Path.LINETO,
        *([Path.CURVE3] * 4),
    ]

    outline = Path(vertices=vertices, codes=codes, closed=True)
    return PathPatch(outline, color=color, zorder=0, alpha=alpha, lw=0)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def patch_curve4(
    x_start: float,
    x_end: float,
    y1_start: float,
    y1_end: float,
    y2_start: float,
    y2_end: float,
    row_index: int,
    spacing: float = 0.0,
    alpha: float = 0.5,
    color: AcceptedColors | None = None,
) -> PathPatch:
    """Build a ribbon whose two long edges are drawn with ``Path.CURVE4`` codes.

    Each long edge consists of three CURVE4 vertices: two at the horizontal
    midpoint between ``x_start`` and ``x_end`` (at the edge's start and end
    height) followed by the edge's end point. The two edges are joined on the
    right side by a straight segment. Every y coordinate is shifted by
    ``spacing * row_index``.

    Args:
        x_start: X coordinate of the left side of the ribbon.
        x_end: X coordinate of the right side of the ribbon.
        y1_start: Y coordinate of the first edge on the left side.
        y1_end: Y coordinate of the first edge on the right side.
        y2_start: Y coordinate of the second edge on the right side.
        y2_end: Y coordinate of the second edge on the left side.
        row_index: Multiplier applied to ``spacing`` for the vertical shift.
        spacing: Vertical shift per row. Defaults to `0.0`.
        alpha: Alpha of the patch. Defaults to `0.5`.
        color: Color of the patch. Defaults to `None`.

    Returns:
        The ribbon patch, created with zero line width and z-order 0.

    """
    y_shift = spacing * row_index
    x_middle = ((x_end - x_start) / 2) + x_start

    # First edge: left side -> right side.
    first_edge = [
        (x_middle, y1_start + y_shift),
        (x_middle, y1_end + y_shift),
        (x_end, y1_end + y_shift),
    ]
    # Second edge: right side -> left side.
    second_edge = [
        (x_middle, y2_start + y_shift),
        (x_middle, y2_end + y_shift),
        (x_start, y2_end + y_shift),
    ]

    vertices = [
        (x_start, y1_start + y_shift),
        *first_edge,
        (x_end, y2_start + y_shift),
        *second_edge,
    ]
    codes = [
        Path.MOVETO,
        *([Path.CURVE4] * 3),
        Path.LINETO,
        *([Path.CURVE4] * 3),
    ]

    outline = Path(vertices=vertices, codes=codes, closed=True)
    return PathPatch(outline, color=color, zorder=0, alpha=alpha, lw=0)
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
from matplotlib.axes import Axes
|
|
3
|
+
from matplotlib.patches import Rectangle, PathPatch, Patch
|
|
4
|
+
from matplotlib.ticker import FixedLocator
|
|
5
|
+
|
|
6
|
+
from ._types import AcceptedColors, CurveType
|
|
7
|
+
from ._utils import _clean_axis, _generate_cmap
|
|
8
|
+
from ._patches import patch_curve3, patch_curve4, patch_line
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def sankey(
    data: list[list[tuple[int | str, int | str, float | int]]],
    figsize: tuple[int, int] | None = None,
    frameon: bool = False,
    ax: Axes | None = None,
    spacing: float = 0.00,
    annotate_columns: bool = False,
    rel_column_width: float = 0.15,
    cmap: AcceptedColors = "tab10",
    curve_type: CurveType = "curve4",
    ribbon_alpha: float = 0.2,
    ribbon_color: str = "black",
    title: str | None = None,
    show: bool = True,
    show_legend: bool = False,
    legend_labels: list[str] | None = None,
    column_labels: list[str] | None = None,
) -> Axes:
    """Sankey plot.

    ``data`` is a list of frames; frame ``i`` holds the connections between
    column ``i`` and column ``i + 1`` as ``(source, target, weight)`` tuples.
    The node sizes of a column are the summed weights of the connections
    leaving it; only the last column is sized by the weights arriving at it.
    Consequently every target of a frame (except the last one) has to appear
    as a source in the following frame.

    Args:
        data (list[list[tuple[int | str, int | str, float | int]]]): Input data.
        figsize (tuple[int, int] | None): Size of figure. Only used when no `ax` is given. Defaults to `None`.
        frameon (bool, optional): Draw frame. Defaults to `False`.
        ax (matplotlib.Axes | None, optional): Provide matplotlib Axes instance for plotting. Defaults to `None`.
        spacing (float, optional): Spacing between column and ribbon patches. Defaults to `0.0`.
        annotate_columns (bool, optional): Annotate columns of plot. Defaults to `False`.
        rel_column_width (float, optional): Relative width of columns compared to ribbons. Defaults to `0.15`. Value must be between 0 and 1.
        cmap (Sequence[str] | Colormap | str | Sequence[tuple[float, float, float]], optional): Colors or colormap for columns.
        curve_type (Literal["curve3", "curve4", "line"], optional): Curve type of ribbon. Defaults to `"curve4"`.
        ribbon_alpha (float, optional): Alpha of ribbons. Defaults to `0.2`.
        ribbon_color (str, optional): Color of ribbons. Defaults to `"black"`.
        title (str | None, optional): Title of figure. Defaults to `None`.
        show (bool, optional): Show figure. Defaults to `True`. If `False`, `pyplot.close()` is called.
        show_legend (bool, optional): Show legend. Defaults to `False`. If legend should be displayed, also provide `legend_labels`.
        legend_labels (list[str] | None, optional): Labels to display in legend. Defaults to `None`.
        column_labels (list[str] | None, optional): Labels for columns. Defaults to `None`.

    Returns:
        Matplotlib axes instance.

    ReturnType:
        matplotlib.Axes

    Raises:
        AssertionError: If `rel_column_width` is not between 0 and 1, or if the
            number of `column_labels` differs from the number of columns.
        ValueError: If a ribbon has to be drawn and `curve_type` is not supported.
        KeyError: If a target has no node in the following column.

    """
    assert 0 < rel_column_width < 1, "Value for parameter 'rel_column_width' must be between 0 and 1."

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, frameon=frameon)

    ncols = len(data) + 1
    half_column_width = rel_column_width / 2

    ax = _clean_axis(ax, frameon=frameon)

    ax.set_ylim(0.0, 1.0)
    ax.set_xlim(-1 * half_column_width, (ncols - 1) + half_column_width)

    if column_labels is not None:
        # One tick per column so that the labels set at the end line up with the columns.
        ax.xaxis.set_major_locator(FixedLocator(list(range(ncols))))

    # Prepare data: per column, map node key -> total weight.
    column_weights: list[dict[int | str, int | float]] = [{} for _ in range(ncols)]

    cmap = _generate_cmap(cmap, 20)

    last_frame_index = len(data) - 1
    for frame_index, frame in enumerate(data):
        source_weights = column_weights[frame_index]

        for source_index, target_index, weight in frame:
            source_weights[source_index] = source_weights.get(source_index, 0) + weight

            if frame_index == last_frame_index:
                # The last column has no outgoing connections, so size it by incoming weights.
                target_weights = column_weights[frame_index + 1]
                target_weights[target_index] = target_weights.get(target_index, 0) + weight

    # Plot rectangles; remember (x, y, width, height) of every node for the ribbon pass.
    column_rects: list[dict[int | str, tuple[float, float, float, float]]] = [{} for _ in range(ncols)]
    rect_num = 0

    legend_handles: list[tuple[str, tuple[float, float, float, float]]] = []

    for frame_index in range(ncols):
        node_weights = column_weights[frame_index]
        column_total_weight = sum(node_weights.values())
        column_prev_weight = 0.0

        column_n_spacing = len(node_weights) - 1

        # Shrink all nodes so that nodes plus the gaps between them fill the unit height.
        spacing_scale_factor = 1 - (spacing * column_n_spacing)

        for column_index, (column_key, weights) in enumerate(node_weights.items()):
            rect_x = frame_index - half_column_width
            rect_y = column_prev_weight / column_total_weight + (column_index * spacing)
            rect_height = (weights * spacing_scale_factor) / column_total_weight

            column_prev_weight += weights * spacing_scale_factor

            rect_color = cmap(rect_num)
            ax.add_patch(
                Rectangle(
                    xy=(rect_x, rect_y),
                    width=rel_column_width,
                    height=rect_height,
                    color=rect_color,
                    zorder=1,
                    lw=0,
                )
            )

            # Save in lookup dict
            column_rects[frame_index][column_key] = (rect_x, rect_y, rel_column_width, rect_height)

            if annotate_columns is True:
                ax.text(
                    x=rect_x + half_column_width,
                    y=rect_y + (rect_height / 2),
                    s=str(column_key),
                    ha="center",
                    va="center",
                )

            legend_handles.append((str(column_key), rect_color))

            rect_num += 1

    # Plot ribbons
    for frame_index in range(ncols - 1):
        # Weight already attached to each target node; stacks ribbons arriving at the same node.
        target_ribbon_offset: dict[int | str, int | float] = {}

        ribbon_x_start = frame_index + half_column_width
        ribbon_x_end = frame_index + 1 - half_column_width

        for column_key in column_weights[frame_index]:
            # Source rect dimensions
            _, rect_y, _, rect_height = column_rects[frame_index][column_key]

            # Get all connection targets. Repeated (source, target) pairs are summed,
            # matching how the node weights above are accumulated.
            column_targets: dict[int | str, float | int] = {}
            for source, target, connection_weights in data[frame_index]:
                if source == column_key:
                    column_targets[target] = column_targets.get(target, 0) + connection_weights

            outgoing_total = sum(column_targets.values())
            ribbon_offset: float = 0.0

            for target_index, ribbon_weight in column_targets.items():
                # Start coords
                y1_start = rect_y + (rect_height * (ribbon_offset / outgoing_total))
                y2_end = rect_y + (rect_height * ((ribbon_offset + ribbon_weight) / outgoing_total))

                ribbon_offset += ribbon_weight

                _, target_rect_y, _, target_rect_height = column_rects[frame_index + 1][target_index]
                target_total = column_weights[frame_index + 1][target_index]
                target_offset = target_ribbon_offset.get(target_index, 0)

                # End coords
                y1_end = target_rect_y + (target_rect_height * (target_offset / target_total))
                y2_start = target_rect_y + (target_rect_height * ((ribbon_weight + target_offset) / target_total))

                target_ribbon_offset[target_index] = target_offset + ribbon_weight

                if curve_type == "curve4":
                    make_ribbon = patch_curve4
                elif curve_type == "curve3":
                    make_ribbon = patch_curve3
                elif curve_type == "line":
                    make_ribbon = patch_line
                else:
                    raise ValueError(f"curve_type '{curve_type}' not supported.")

                # Coordinates are already absolute, so no additional row shift is applied.
                poly: PathPatch = make_ribbon(
                    x_start=ribbon_x_start,
                    x_end=ribbon_x_end,
                    y1_start=y1_start,
                    y1_end=y1_end,
                    y2_start=y2_start,
                    y2_end=y2_end,
                    row_index=0,
                    alpha=ribbon_alpha,
                    color=ribbon_color,
                    spacing=0,
                )

                ax.add_patch(poly)

    if show_legend is True:
        legend_patches = [
            Patch(
                facecolor=color,
                label=label if legend_labels is None else legend_labels[handle_index],
            )
            for handle_index, (label, color) in enumerate(legend_handles)
        ]

        ax.legend(
            handles=legend_patches,
            frameon=False,
            bbox_to_anchor=(1, 1),
            loc="upper left",
        )

    if title is not None:
        ax.set_title(title)

    if show is False:
        # NOTE(review): called without an argument, so this targets pyplot's current
        # figure; confirm this is intended when a caller-supplied `ax` is used.
        plt.close()

    if column_labels is not None:
        assert len(column_labels) == ncols
        ax.set_xticklabels(column_labels)

    return ax
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Literal, TypeAlias
|
|
3
|
+
|
|
4
|
+
from matplotlib.colors import Colormap
|
|
5
|
+
|
|
6
|
+
# Names of the ribbon shapes that can be selected for a plot.
CurveType: TypeAlias = Literal["curve3", "curve4", "line"]
|
|
7
|
+
|
|
8
|
+
# Color inputs accepted by the package: a string, a Colormap instance,
# or a sequence of color strings / 3-float tuples.
AcceptedColors: TypeAlias = Sequence[str] | Colormap | str | Sequence[tuple[float, float, float]]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from matplotlib import colormaps
|
|
5
|
+
from matplotlib.axes import Axes
|
|
6
|
+
from matplotlib.colors import Colormap, ListedColormap
|
|
7
|
+
|
|
8
|
+
from ._types import AcceptedColors
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _clean_axis(
    _ax: Axes,
    frameon: bool = True,
    reset_x_ticks: bool = True,
    reset_y_ticks: bool = True,
) -> Axes:
    """Strip ticks, tick labels and optionally the frame from an axes.

    Args:
        _ax: Axes to modify in place.
        frameon: If `False`, all four spines are hidden. Defaults to `True`.
        reset_x_ticks: If `True`, remove x tick labels and x ticks. Defaults to `True`.
        reset_y_ticks: If `True`, remove y tick labels and y ticks. Defaults to `True`.

    Returns:
        The same axes instance that was passed in.

    """
    if reset_x_ticks is True:
        _ax.set_xticklabels([])
        _ax.set_xticks([])

    if reset_y_ticks is True:
        _ax.set_yticklabels([])
        _ax.set_yticks([])

    if frameon is False:
        # Despine
        for side in ("top", "left", "right", "bottom"):
            _ax.spines[side].set_visible(False)

    return _ax
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _generate_cmap(value: AcceptedColors, nrows: int) -> Colormap:
    """Util function to generate colormap from list of string or name.

    Args:
        value: Name of a registered colormap, a sequence of colors, or a
            Colormap instance.
        nrows: Number of colors to sample when the colormap has at least 256
            entries and is therefore converted to a listed colormap.

    Returns:
        A colormap. Colormaps with fewer than 256 entries are returned
        unchanged; sequences are wrapped in a `ListedColormap`.

    Raises:
        AssertionError: If `value` is a string that is not a registered colormap name.
        TypeError: If `value` is neither a string, a sequence nor a Colormap.

    """

    def _convert_sequential_cmap_to_listed(c: Colormap, threshold: int = 256) -> Colormap:
        """Sample `nrows` evenly spaced colors from colormaps with at least `threshold` entries."""
        if c.N >= threshold:
            # Valid integer indices run from 0 to N - 1, so the last sample must be N - 1.
            sample_indices = np.linspace(start=0, stop=c.N - 1, num=nrows).astype(int)
            return ListedColormap([c(i) for i in sample_indices])

        return c

    if isinstance(value, str):
        # String argument must be the name of an colormap
        assert value in colormaps, f"Value '{value}' is not the name of a valid colormap."

        return _convert_sequential_cmap_to_listed(colormaps.get_cmap(value))

    elif isinstance(value, Sequence):
        return ListedColormap(value)

    elif isinstance(value, Colormap):
        return _convert_sequential_cmap_to_listed(value)

    raise TypeError(f"Type '{type(value).__name__}' not allowed.")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def from_matrix(
|
|
61
|
+
mtx: Sequence[Sequence[int | float]],
|
|
62
|
+
source_indicies: list[int | str] | None = None,
|
|
63
|
+
target_indicies: list[int | str] | None = None,
|
|
64
|
+
) -> Sequence[tuple[int | str, int | str, float | int]]:
|
|
65
|
+
"""Convert weight matrix to tuple of source, target and weight.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
mtx (Sequence[Sequence[int | float]], optional): Weight matrix (source x target).
|
|
69
|
+
source_indicies (list[int | str] | None, optional): List of source indices. Defaults to `None`.
|
|
70
|
+
target_indicies (list[int | str] | None, optional): List of target indices. Defaults to `None`.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
List of tuples containing source, target and weight.
|
|
74
|
+
|
|
75
|
+
ReturnType:
|
|
76
|
+
list[tuple[int | str, int | str, float | int]]
|
|
77
|
+
|
|
78
|
+
"""
|
|
79
|
+
# Check correct dimensions of index list
|
|
80
|
+
if source_indicies is not None:
|
|
81
|
+
assert len(source_indicies) == len(mtx)
|
|
82
|
+
if target_indicies is not None:
|
|
83
|
+
assert len(target_indicies) == len(mtx[0])
|
|
84
|
+
|
|
85
|
+
connections = []
|
|
86
|
+
for row in range(len(mtx)):
|
|
87
|
+
for col in range(len(mtx[row])):
|
|
88
|
+
if mtx[row][col] > 0:
|
|
89
|
+
source_index: int | str = row
|
|
90
|
+
target_index: int | str = col
|
|
91
|
+
|
|
92
|
+
if source_indicies is not None:
|
|
93
|
+
source_index = source_indicies[row]
|
|
94
|
+
if target_indicies is not None:
|
|
95
|
+
target_index = target_indicies[col]
|
|
96
|
+
connections.append((source_index, target_index, mtx[row][col]))
|
|
97
|
+
return connections
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Version string of the package.
__version__ = "0.2.1"
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: matplotlib-sankey
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: Sankey plot for matplotlib
|
|
5
|
+
Author: harryhaller001
|
|
6
|
+
Maintainer-email: harryhaller001 <harryhaller001@gmail.com>
|
|
7
|
+
Requires-Python: >=3.10
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Framework :: Matplotlib
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Natural Language :: English
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Typing :: Typed
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: matplotlib
|
|
22
|
+
Requires-Dist: numpy
|
|
23
|
+
Requires-Dist: ipykernel ; extra == "docs"
|
|
24
|
+
Requires-Dist: ipython ; extra == "docs"
|
|
25
|
+
Requires-Dist: ipywidgets ; extra == "docs"
|
|
26
|
+
Requires-Dist: myst-parser ; extra == "docs"
|
|
27
|
+
Requires-Dist: nbsphinx ; extra == "docs"
|
|
28
|
+
Requires-Dist: networkx[default] ; extra == "docs"
|
|
29
|
+
Requires-Dist: sphinx>=4 ; extra == "docs"
|
|
30
|
+
Requires-Dist: sphinx-autoapi ; extra == "docs"
|
|
31
|
+
Requires-Dist: sphinx-autodoc-typehints ; extra == "docs"
|
|
32
|
+
Requires-Dist: sphinx-book-theme>=1 ; extra == "docs"
|
|
33
|
+
Requires-Dist: coverage ; extra == "test"
|
|
34
|
+
Requires-Dist: flit ; extra == "test"
|
|
35
|
+
Requires-Dist: mypy ; extra == "test"
|
|
36
|
+
Requires-Dist: pre-commit ; extra == "test"
|
|
37
|
+
Requires-Dist: pytest ; extra == "test"
|
|
38
|
+
Requires-Dist: ruff ; extra == "test"
|
|
39
|
+
Requires-Dist: setuptools ; extra == "test"
|
|
40
|
+
Requires-Dist: twine>=4.0.2 ; extra == "test"
|
|
41
|
+
Project-URL: Documentation, https://github.com/harryhaller001/matplotlib-sankey
|
|
42
|
+
Project-URL: Homepage, https://github.com/harryhaller001/matplotlib-sankey
|
|
43
|
+
Project-URL: Source, https://github.com/harryhaller001/matplotlib-sankey
|
|
44
|
+
Provides-Extra: docs
|
|
45
|
+
Provides-Extra: test
|
|
46
|
+
|
|
47
|
+
# matplotlib-sankey
|
|
48
|
+
|
|
49
|
+
Sankey plot for matplotlib
|
|
50
|
+
|
|
51
|
+
### Installation
|
|
52
|
+
|
|
53
|
+
Install with pip:
|
|
54
|
+
|
|
55
|
+
`pip install matplotlib-sankey`
|
|
56
|
+
|
|
57
|
+
Install from source:
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
git clone https://github.com/harryhaller001/matplotlib-sankey
|
|
61
|
+
cd matplotlib-sankey
|
|
62
|
+
pip install .
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
### Example
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
data = [
|
|
70
|
+
# (source index, target index, weight)
|
|
71
|
+
[(0, 2, 20), (0, 1, 10), (3, 4, 15), (3, 2, 10), (5, 1, 5), (5, 2, 50)],
|
|
72
|
+
[(2, 6, 40), (1, 6, 15), (2, 7, 40), (4, 6, 15)],
|
|
73
|
+
[(7, 8, 5), (7, 9, 5), (7, 10, 20), (7, 11, 10), (6, 11, 55), (6, 8, 15)],
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
fig, ax = plt.subplots(figsize=(10, 5))
|
|
77
|
+
fig.tight_layout()
|
|
78
|
+
sankey(
|
|
79
|
+
data=data,
|
|
80
|
+
cmap="tab20",
|
|
81
|
+
annotate_columns=True,
|
|
82
|
+
ax=ax,
|
|
83
|
+
spacing=0.03,
|
|
84
|
+
)
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+

|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
### Development
|
|
92
|
+
|
|
93
|
+
```bash
|
|
94
|
+
python3.10 -m virtualenv venv
|
|
95
|
+
source venv/bin/activate
|
|
96
|
+
|
|
97
|
+
# Install dev dependencies
|
|
98
|
+
make install
|
|
99
|
+
```
|
|
100
|
+
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
matplotlib_sankey/__init__.py,sha256=1_BkFKWc5IsotEarp0lzZpFXQBvew07qnplc0_QNUkA,148
|
|
2
|
+
matplotlib_sankey/_patches.py,sha256=3xYWHhAgbBD6yYF_qfqHezchZVAbXFUL9zSGhi-l3_A,3941
|
|
3
|
+
matplotlib_sankey/_plotting.py,sha256=w3EEKSj2QT0XjpUt4juDjVvnPXyTt8VcYpe9BJ35Jow,10580
|
|
4
|
+
matplotlib_sankey/_types.py,sha256=4BhTH8vndjpF8QZjW7l43xLvOdMYlFTPLVcu3KX3a_Q,274
|
|
5
|
+
matplotlib_sankey/_utils.py,sha256=3CwG1cjuqh24tSKTzpvYxM7pew_o9ix0zalVYX0mr-I,3246
|
|
6
|
+
matplotlib_sankey/_version.py,sha256=HfjVOrpTnmZ-xVFCYSVmX50EXaBQeJteUHG-PD6iQs8,22
|
|
7
|
+
matplotlib_sankey-0.2.1.dist-info/licenses/LICENSE,sha256=pKQY7EaklCnxWGJNlxr2OrDtJrb6MuqlcDj-x4jHKmU,1071
|
|
8
|
+
matplotlib_sankey-0.2.1.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
|
9
|
+
matplotlib_sankey-0.2.1.dist-info/METADATA,sha256=UTPYBlLKk2hJzRVE4up3ZkYZ1FU_PJL4z-pypEEwbjU,2797
|
|
10
|
+
matplotlib_sankey-0.2.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 harryhaller001
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|