pandas-plots 0.12.7__tar.gz → 0.12.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pandas_plots-0.12.7/src/pandas_plots.egg-info → pandas_plots-0.12.8}/PKG-INFO +1 -1
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/setup.cfg +1 -1
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots/pls.py +253 -125
- {pandas_plots-0.12.7 → pandas_plots-0.12.8/src/pandas_plots.egg-info}/PKG-INFO +1 -1
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/LICENSE +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/README.md +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/pyproject.toml +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots/hlp.py +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots/pii.py +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots/tbl.py +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots/ven.py +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots.egg-info/SOURCES.txt +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots.egg-info/dependency_links.txt +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots.egg-info/pii.py +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots.egg-info/requires.txt +0 -0
- {pandas_plots-0.12.7 → pandas_plots-0.12.8}/src/pandas_plots.egg-info/top_level.txt +0 -0
@@ -1,7 +1,4 @@
|
|
1
1
|
from pathlib import Path
|
2
|
-
import warnings
|
3
|
-
|
4
|
-
warnings.filterwarnings("ignore")
|
5
2
|
|
6
3
|
import os
|
7
4
|
from typing import Optional, Literal
|
@@ -12,51 +9,118 @@ from matplotlib import pyplot as plt
|
|
12
9
|
from plotly import express as px
|
13
10
|
import plotly.graph_objects as go
|
14
11
|
from plotly.subplots import make_subplots
|
15
|
-
import plotly
|
12
|
+
import plotly # needed for return types
|
16
13
|
|
17
14
|
from .hlp import *
|
18
15
|
from .tbl import print_summary
|
19
16
|
|
20
17
|
### helper functions
|
21
18
|
|
19
|
+
|
22
20
|
def _set_caption(caption: str) -> str:
|
23
21
|
return f"#️⃣{'-'.join(caption.split())}, " if caption else ""
|
24
22
|
|
25
23
|
|
26
|
-
def aggregate_data(
|
24
|
+
def aggregate_data(
|
25
|
+
df: pd.DataFrame,
|
26
|
+
top_n_index: int,
|
27
|
+
top_n_color: int,
|
28
|
+
top_n_facet: int,
|
29
|
+
null_label: str,
|
30
|
+
show_other: bool = False,
|
31
|
+
sort_values_index: bool = False,
|
32
|
+
sort_values_color: bool = False,
|
33
|
+
sort_values_facet: bool = False,
|
34
|
+
) -> pd.DataFrame:
|
27
35
|
"""
|
28
36
|
Aggregates the data, ensuring each combination of 'index', 'col', and 'facet' is unique with summed 'value'.
|
29
|
-
|
37
|
+
|
30
38
|
Args:
|
31
39
|
df (pd.DataFrame): Input DataFrame.
|
32
40
|
top_n_index (int): top N values of the first column to keep. 0 means take all.
|
33
|
-
|
41
|
+
top_n_color (int): top N values of the second column to keep. 0 means take all.
|
34
42
|
top_n_facet (int): top N values of the third column to keep. 0 means take all.
|
35
43
|
null_label (str): Label for null values.
|
44
|
+
show_other (bool): Whether to include "<other>" for columns not in top_n_color. Defaults to False.
|
45
|
+
sort_values_index, sort_values_color, sort_values_facet (bool): Whether to order the respective column's values by descending group sum instead of by sorted label. Each defaults to False.
|
36
46
|
|
37
47
|
Returns:
|
38
48
|
pd.DataFrame: Aggregated and filtered dataset.
|
39
49
|
"""
|
40
|
-
|
50
|
+
|
51
|
+
for col in ["index", "col", "facet"]: # Skip 'value' column (numeric)
|
41
52
|
df[col] = df[col].fillna(null_label)
|
42
53
|
|
43
54
|
# Aggregate data to ensure unique combinations
|
44
|
-
aggregated_df = df.groupby([
|
55
|
+
aggregated_df = df.groupby(["index", "col", "facet"], as_index=False)["value"].sum()
|
56
|
+
|
57
|
+
# * Reduce data based on top_n parameters
|
58
|
+
if sort_values_index:
|
59
|
+
top_indexes = (
|
60
|
+
aggregated_df.groupby("index")["value"]
|
61
|
+
.sum()
|
62
|
+
.sort_values(ascending=False)[:top_n_index or None]
|
63
|
+
.index
|
64
|
+
)
|
65
|
+
else:
|
66
|
+
top_indexes = aggregated_df["index"].sort_values().unique()[:top_n_index or None]
|
67
|
+
|
68
|
+
aggregated_df = aggregated_df[aggregated_df["index"].isin(top_indexes)]
|
69
|
+
|
70
|
+
if sort_values_color:
|
71
|
+
top_colors = (
|
72
|
+
aggregated_df.groupby("col")["value"]
|
73
|
+
.sum()
|
74
|
+
.sort_values(ascending=False)[:top_n_color or None]
|
75
|
+
.index
|
76
|
+
)
|
77
|
+
else:
|
78
|
+
top_colors = aggregated_df["col"].sort_values().unique()[:top_n_color or None]
|
79
|
+
|
80
|
+
others_df = df[~df["col"].isin(top_colors)]
|
81
|
+
aggregated_df = aggregated_df[aggregated_df["col"].isin(top_colors)]
|
82
|
+
if show_other and top_n_color > 0 and not others_df.empty:
|
83
|
+
other_agg = others_df.groupby(["index", "facet"], as_index=False)[
|
84
|
+
"value"
|
85
|
+
].sum()
|
86
|
+
other_agg["col"] = "<other>"
|
87
|
+
other_agg = other_agg[["index", "col", "facet", "value"]]
|
88
|
+
aggregated_df = pd.concat([aggregated_df, other_agg], ignore_index=True)
|
89
|
+
top_colors = [*top_colors, "<other>"]
|
90
|
+
|
91
|
+
if sort_values_facet:
|
92
|
+
top_facets = (
|
93
|
+
aggregated_df.groupby("facet")["value"]
|
94
|
+
.sum()
|
95
|
+
.sort_values(ascending=False)[:top_n_facet or None]
|
96
|
+
.index
|
97
|
+
)
|
98
|
+
else:
|
99
|
+
top_facets = aggregated_df["facet"].sort_values().unique()[:top_n_facet or None]
|
100
|
+
|
101
|
+
aggregated_df = aggregated_df[aggregated_df["facet"].isin(top_facets)]
|
102
|
+
|
103
|
+
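# * Fix the category order of facet, index and col to the order selected above (sorted labels, or descending group sum when the matching sort_values_* flag is set)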
|
104
|
+
aggregated_df["facet"] = pd.Categorical(
|
105
|
+
values=aggregated_df["facet"],
|
106
|
+
categories=top_facets,
|
107
|
+
ordered=True,
|
108
|
+
)
|
109
|
+
|
110
|
+
aggregated_df["index"] = pd.Categorical(
|
111
|
+
values=aggregated_df["index"],
|
112
|
+
categories=top_indexes,
|
113
|
+
ordered=True,
|
114
|
+
)
|
115
|
+
|
116
|
+
aggregated_df["col"] = pd.Categorical(
|
117
|
+
values=aggregated_df["col"],
|
118
|
+
categories=top_colors,
|
119
|
+
ordered=True,
|
120
|
+
)
|
45
121
|
|
46
|
-
# Reduce data based on top_n parameters
|
47
|
-
if top_n_index > 0:
|
48
|
-
top_indexes = aggregated_df.groupby('index')['value'].sum().nlargest(top_n_index).index
|
49
|
-
aggregated_df = aggregated_df[aggregated_df['index'].isin(top_indexes)]
|
50
|
-
if top_n_columns > 0:
|
51
|
-
top_columns = aggregated_df.groupby('col')['value'].sum().nlargest(top_n_columns).index
|
52
|
-
aggregated_df = aggregated_df[aggregated_df['col'].isin(top_columns)]
|
53
|
-
if top_n_facet > 0:
|
54
|
-
top_facets = aggregated_df.groupby('facet')['value'].sum().nlargest(top_n_facet).index
|
55
|
-
aggregated_df = aggregated_df[aggregated_df['facet'].isin(top_facets)]
|
56
122
|
|
57
|
-
#
|
58
|
-
aggregated_df['facet'] = pd.Categorical(aggregated_df['facet'], sorted(aggregated_df['facet'].unique()))
|
59
|
-
aggregated_df = aggregated_df.sort_values(by='facet')
|
123
|
+
# aggregated_df = aggregated_df.sort_values(by="facet")
|
60
124
|
|
61
125
|
return aggregated_df
|
62
126
|
|
@@ -77,13 +141,15 @@ def assign_column_colors(columns, color_palette, null_label):
|
|
77
141
|
palette = getattr(px.colors.qualitative, color_palette)
|
78
142
|
else:
|
79
143
|
raise ValueError(f"Invalid color palette: {color_palette}")
|
80
|
-
|
144
|
+
|
81
145
|
colors = {col: palette[i % len(palette)] for i, col in enumerate(sorted(columns))}
|
82
146
|
colors[null_label] = "lightgray"
|
83
147
|
return colors
|
84
148
|
|
149
|
+
|
85
150
|
### main functions
|
86
151
|
|
152
|
+
|
87
153
|
def plot_quadrants(
|
88
154
|
df: pd.DataFrame,
|
89
155
|
title: str = None,
|
@@ -163,7 +229,7 @@ def plot_quadrants(
|
|
163
229
|
|
164
230
|
# * save to png if path is provided
|
165
231
|
if png_path is not None:
|
166
|
-
plt.savefig(Path(png_path).as_posix(), format=
|
232
|
+
plt.savefig(Path(png_path).as_posix(), format="png")
|
167
233
|
|
168
234
|
return q1, q2, q3, q4, n
|
169
235
|
# * plotly express is not used for the heatmap, although it does not need the derived wide format.
|
@@ -185,11 +251,14 @@ def plot_stacked_bars(
|
|
185
251
|
renderer: Literal["png", "svg", None] = "png",
|
186
252
|
caption: str = None,
|
187
253
|
sort_values: bool = False,
|
254
|
+
sort_values_index: bool = False,
|
255
|
+
sort_values_color: bool = False,
|
188
256
|
show_total: bool = False,
|
189
257
|
precision: int = 0,
|
190
258
|
png_path: Path | str = None,
|
191
259
|
color_palette: str = "Plotly",
|
192
260
|
null_label: str = "<NA>",
|
261
|
+
show_other: bool = False,
|
193
262
|
) -> plotly.graph_objects:
|
194
263
|
"""
|
195
264
|
Generates a stacked bar plot using the provided DataFrame.
|
@@ -208,7 +277,7 @@ def plot_stacked_bars(
|
|
208
277
|
- title (str): Custom title for the plot.
|
209
278
|
- renderer (Literal["png", "svg", None]): Defines the output format.
|
210
279
|
- caption (str): Optional caption for additional context.
|
211
|
-
- sort_values (bool):
|
280
|
+
- sort_values (bool):
|
212
281
|
- If True, sorts bars by the sum of their values (descending).
|
213
282
|
- If False, sorts bars alphabetically.
|
214
283
|
- show_total (bool): If True, adds a row with the total sum of all categories.
|
@@ -216,20 +285,33 @@ def plot_stacked_bars(
|
|
216
285
|
- png_path (Path | str): If specified, saves the plot as a PNG file.
|
217
286
|
- color_palette (str): Name of the color palette to use.
|
218
287
|
- null_label (str): Label for null values.
|
219
|
-
|
288
|
+
- show_other (bool): If True and top_n_color > 0, the color categories outside the top N are summed into a single "<other>" category.
|
289
|
+
- sort_values_index (bool): If True, sorts the index categories by group sum
|
290
|
+
- sort_values_color (bool): If True, sorts the color (column) categories by group sum, descending
|
291
|
+
|
220
292
|
Returns:
|
221
293
|
- A Plotly figure object representing the stacked bar chart.
|
222
294
|
"""
|
223
295
|
BAR_LENGTH_MULTIPLIER = 1.05
|
224
|
-
|
296
|
+
|
225
297
|
# * 2 axis means at least 2 columns
|
226
298
|
if len(df.columns) < 2 or len(df.columns) > 3:
|
227
299
|
print("❌ df must have exactly 2 or 3 columns")
|
228
300
|
return
|
229
301
|
|
230
|
-
#
|
231
|
-
|
232
|
-
|
302
|
+
# ! do not enforce str columns anymore
|
303
|
+
# # * check if first 2 columns are str
|
304
|
+
# dtypes = set(df.iloc[:, [0, 1]].dtypes)
|
305
|
+
# dtypes_kind = [i.kind for i in dtypes]
|
306
|
+
|
307
|
+
# if set(dtypes_kind) - set(["O", "b"]):
|
308
|
+
# print("❌ first 2 columns must be str")
|
309
|
+
# # * overkill ^^
|
310
|
+
# df.iloc[:, [0, 1]] = df.iloc[:, [0, 1]].astype(str)
|
311
|
+
|
312
|
+
# * but last col must be numeric
|
313
|
+
if df.iloc[:, -1].dtype.kind not in ("f", "i"):
|
314
|
+
print("❌ last column must be numeric")
|
233
315
|
return
|
234
316
|
|
235
317
|
df = df.copy() # Copy the input DataFrame to avoid modifying the original
|
@@ -253,87 +335,102 @@ def plot_stacked_bars(
|
|
253
335
|
# * apply precision
|
254
336
|
df.iloc[:, 2] = df.iloc[:, 2].round(precision)
|
255
337
|
|
256
|
-
# * set index + color col
|
338
|
+
# # * set index + color col
|
257
339
|
col_index = df.columns[0] if not swap else df.columns[1]
|
258
340
|
col_color = df.columns[1] if not swap else df.columns[0]
|
259
341
|
|
260
342
|
# * ensure df is grouped to prevent false aggregations
|
261
|
-
df = (
|
262
|
-
df.groupby([df.columns[0], df.columns[1]])
|
263
|
-
[df.columns[2]]
|
264
|
-
.sum()
|
265
|
-
.reset_index()
|
266
|
-
)
|
343
|
+
df = df.groupby([df.columns[0], df.columns[1]])[df.columns[2]].sum().reset_index()
|
267
344
|
|
268
345
|
# * add total as aggregation of df
|
269
346
|
if show_total:
|
270
|
-
df_total = df.groupby(df.columns[1], observed=True, as_index=False)[
|
347
|
+
df_total = df.groupby(df.columns[1], observed=True, as_index=False)[
|
348
|
+
df.columns[2]
|
349
|
+
].sum()
|
271
350
|
df_total[df.columns[0]] = " Total"
|
272
351
|
df = pd.concat([df, df_total], ignore_index=True)
|
273
352
|
|
274
|
-
|
275
|
-
# * apply top_n, reduce df
|
276
|
-
n_col = top_n_color if top_n_color > 0 else None
|
277
|
-
n_idx = top_n_index if top_n_index > 0 else None
|
278
|
-
|
279
|
-
unique_colors = sorted(
|
280
|
-
df.groupby(col_color)[df.columns[2]]
|
281
|
-
.sum()
|
282
|
-
.sort_values(ascending=False)
|
283
|
-
.index.tolist()[:n_col]
|
284
|
-
)
|
285
|
-
|
286
|
-
unique_idx = df[col_index].sort_values().unique()[:n_idx]
|
287
|
-
|
288
|
-
df = df[df[col_color].isin(unique_colors)]#.sort_values(by=[col_index, col_color])
|
289
|
-
df = df[df[col_index].isin(unique_idx)]#.sort_values(by=[col_index, col_color])
|
290
|
-
|
291
|
-
|
292
|
-
# # * Sorting logic based on sort_values
|
293
|
-
if sort_values:
|
294
|
-
sort_order = (
|
295
|
-
df.groupby(col_index)[df.columns[2]].sum().sort_values(ascending=False).index
|
296
|
-
)
|
297
|
-
else:
|
298
|
-
sort_order = sorted(df[col_index].unique()) # Alphabetical order
|
299
|
-
|
300
|
-
# # * Convert to categorical with explicit ordering
|
301
|
-
df[col_index] = pd.Categorical(df[col_index], categories=sort_order, ordered=True)
|
302
|
-
|
303
|
-
column_colors = assign_column_colors(
|
304
|
-
columns=unique_colors,
|
305
|
-
color_palette=color_palette,
|
306
|
-
null_label=null_label
|
307
|
-
)
|
308
|
-
|
309
353
|
# * calculate n
|
310
354
|
divider = 2 if show_total else 1
|
311
|
-
n = int(df
|
355
|
+
n = int(df.iloc[:, 2].sum() / divider)
|
312
356
|
|
313
357
|
# * title str
|
314
358
|
_title_str_top_index = f"TOP{top_n_index} " if top_n_index > 0 else ""
|
315
359
|
_title_str_top_color = f"TOP{top_n_color} " if top_n_color > 0 else ""
|
316
360
|
_title_str_null = f", NULL excluded" if dropna else ""
|
317
361
|
_title_str_n = f", n={n:_}"
|
362
|
+
|
363
|
+
_df = df.copy().assign(facet=None)
|
364
|
+
_df.columns = (
|
365
|
+
["index", "col", "value", "facet"]
|
366
|
+
if not swap
|
367
|
+
else ["col", "index", "value", "facet"]
|
368
|
+
)
|
369
|
+
|
370
|
+
aggregated_df = aggregate_data(
|
371
|
+
df=_df,
|
372
|
+
top_n_index=top_n_index,
|
373
|
+
top_n_color=top_n_color,
|
374
|
+
top_n_facet=0,
|
375
|
+
null_label=null_label,
|
376
|
+
show_other=show_other,
|
377
|
+
sort_values_index=sort_values_index,
|
378
|
+
sort_values_color=sort_values_color,
|
379
|
+
sort_values_facet=False, # just a placeholder
|
380
|
+
)
|
381
|
+
|
382
|
+
df = aggregated_df.copy()
|
383
|
+
|
384
|
+
columns = sorted(
|
385
|
+
df.groupby("col", observed=True)["value"]
|
386
|
+
.sum()
|
387
|
+
.sort_values(ascending=False)
|
388
|
+
.index.tolist()
|
389
|
+
)
|
390
|
+
column_colors = assign_column_colors(columns, color_palette, null_label)
|
391
|
+
|
318
392
|
caption = _set_caption(caption)
|
319
393
|
|
320
|
-
|
321
|
-
df["
|
394
|
+
# * after grouping add cols for pct and formatting
|
395
|
+
df["cnt_pct_only"] = df["value"].apply(lambda x: f"{(x / n) * 100:.{precision}f}%")
|
322
396
|
|
323
397
|
# * format output
|
324
|
-
df["cnt_str"] = df[
|
398
|
+
df["cnt_str"] = df["value"].apply(lambda x: f"{x:_.{precision}f}")
|
325
399
|
|
326
400
|
divider2 = "<br>" if orientation == "v" else " "
|
327
401
|
df["cnt_pct_str"] = df.apply(
|
328
|
-
lambda row: f"{row['cnt_str']}{divider2}({row['
|
402
|
+
lambda row: f"{row['cnt_str']}{divider2}({row['cnt_pct_only']})", axis=1
|
329
403
|
)
|
330
404
|
|
405
|
+
# # # * Sorting logic based on sort_values
|
406
|
+
# if sort_values_index:
|
407
|
+
# sort_order = (
|
408
|
+
# df.groupby("index")["value"].sum().sort_values(ascending=False).index
|
409
|
+
# )
|
410
|
+
# else:
|
411
|
+
# sort_order = sorted(df["index"].unique(), reverse=False) # Alphabetical order
|
412
|
+
|
413
|
+
# display(sort_order)
|
414
|
+
|
415
|
+
# df["index"] = pd.Categorical(
|
416
|
+
# values=df["index"],
|
417
|
+
# # categories=sort_order,
|
418
|
+
# ordered=True,
|
419
|
+
# )
|
420
|
+
df = (
|
421
|
+
df.sort_values(by="index", ascending=False)
|
422
|
+
if orientation == "h"
|
423
|
+
else df.sort_values(by="index", ascending=True)
|
424
|
+
)
|
425
|
+
|
426
|
+
# display(df)
|
427
|
+
|
331
428
|
# * plot
|
332
429
|
fig = px.bar(
|
333
430
|
df,
|
334
|
-
x=
|
335
|
-
y=
|
336
|
-
color=
|
431
|
+
x="index" if orientation == "v" else "value",
|
432
|
+
y="value" if orientation == "v" else "index",
|
433
|
+
color="col",
|
337
434
|
text="cnt_pct_str" if normalize else "cnt_str",
|
338
435
|
orientation=orientation,
|
339
436
|
title=title
|
@@ -342,13 +439,15 @@ def plot_stacked_bars(
|
|
342
439
|
width=width,
|
343
440
|
height=height,
|
344
441
|
color_discrete_map=column_colors, # Use assigned colors
|
345
|
-
category_orders={
|
346
|
-
|
442
|
+
category_orders={
|
443
|
+
col_index: list(df["index"].cat.categories)
|
444
|
+
}, # <- Add this line
|
347
445
|
)
|
348
|
-
|
349
|
-
|
446
|
+
|
447
|
+
|
448
|
+
# * get longest bar
|
350
449
|
bar_max = (
|
351
|
-
df.groupby(
|
450
|
+
df.groupby("index")["value"].sum().sort_values(ascending=False).iloc[0]
|
352
451
|
* BAR_LENGTH_MULTIPLIER
|
353
452
|
)
|
354
453
|
# * ignore if bar mode is on
|
@@ -372,7 +471,7 @@ def plot_stacked_bars(
|
|
372
471
|
},
|
373
472
|
},
|
374
473
|
)
|
375
|
-
|
474
|
+
|
376
475
|
# * set dtick
|
377
476
|
if orientation == "h":
|
378
477
|
if relative:
|
@@ -692,7 +791,7 @@ def plot_histogram(
|
|
692
791
|
caption (str): The caption for the plot. Default is None.
|
693
792
|
title (str): The title of the plot. Default is None.
|
694
793
|
png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
|
695
|
-
|
794
|
+
|
696
795
|
|
697
796
|
Returns:
|
698
797
|
plot object
|
@@ -744,7 +843,7 @@ def plot_histogram(
|
|
744
843
|
)
|
745
844
|
|
746
845
|
fig.show(renderer)
|
747
|
-
|
846
|
+
|
748
847
|
# * save to png if path is provided
|
749
848
|
if png_path is not None:
|
750
849
|
fig.write_image(Path(png_path).as_posix())
|
@@ -1156,12 +1255,11 @@ def plot_boxes(
|
|
1156
1255
|
return fig
|
1157
1256
|
|
1158
1257
|
|
1159
|
-
|
1160
1258
|
def plot_facet_stacked_bars(
|
1161
1259
|
df: pd.DataFrame,
|
1162
1260
|
subplots_per_row: int = 4,
|
1163
1261
|
top_n_index: int = 0,
|
1164
|
-
|
1262
|
+
top_n_color: int = 0,
|
1165
1263
|
top_n_facet: int = 0,
|
1166
1264
|
null_label: str = "<NA>",
|
1167
1265
|
subplot_size: int = 300,
|
@@ -1171,6 +1269,12 @@ def plot_facet_stacked_bars(
|
|
1171
1269
|
annotations: bool = False,
|
1172
1270
|
precision: int = 0,
|
1173
1271
|
png_path: Optional[Path] = None,
|
1272
|
+
show_other: bool = False,
|
1273
|
+
sort_values: bool = True,
|
1274
|
+
sort_values_index: bool = False,
|
1275
|
+
sort_values_color: bool = False,
|
1276
|
+
sort_values_facet: bool = False,
|
1277
|
+
|
1174
1278
|
) -> object:
|
1175
1279
|
"""
|
1176
1280
|
Create a grid of stacked bar charts.
|
@@ -1179,7 +1283,7 @@ def plot_facet_stacked_bars(
|
|
1179
1283
|
df (pd.DataFrame): DataFrame with 3 or 4 columns.
|
1180
1284
|
subplots_per_row (int): Number of subplots per row.
|
1181
1285
|
top_n_index (int): top N index values to keep.
|
1182
|
-
|
1286
|
+
top_n_color (int): top N column values to keep.
|
1183
1287
|
top_n_facet (int): top N facet values to keep.
|
1184
1288
|
null_label (str): Label for null values.
|
1185
1289
|
subplot_size (int): Size of each subplot.
|
@@ -1189,47 +1293,57 @@ def plot_facet_stacked_bars(
|
|
1189
1293
|
annotations (bool): Whether to show annotations in the subplots.
|
1190
1294
|
precision (int): Decimal precision for annotations.
|
1191
1295
|
png_path (Optional[Path]): Path to save the image.
|
1296
|
+
show_other (bool): If True, adds an "<other>" bar for columns not in top_n_color.
|
1297
|
+
sort_values_index (bool): If True, sorts index by group sum.
|
1298
|
+
sort_values_color (bool): If True, sorts columns by group sum.
|
1299
|
+
sort_values_facet (bool): If True, sorts facet by group sum.
|
1300
|
+
sort_values (bool): DEPRECATED
|
1301
|
+
|
1192
1302
|
|
1193
1303
|
Returns:
|
1194
1304
|
plot object
|
1195
|
-
|
1305
|
+
|
1196
1306
|
Remarks:
|
1197
1307
|
If you need to include facets that have no data, fill up like this beforehand:
|
1198
1308
|
df.loc[len(df)]=[None, None, 12]
|
1199
1309
|
"""
|
1200
|
-
|
1310
|
+
|
1201
1311
|
df = df.copy() # Copy the input DataFrame to avoid modifying the original
|
1202
1312
|
|
1203
1313
|
if not (df.shape[1] == 3 or df.shape[1] == 4):
|
1204
1314
|
raise ValueError("Input DataFrame must have 3 or 4 columns.")
|
1205
|
-
|
1315
|
+
|
1206
1316
|
original_column_names = df.columns.tolist()
|
1317
|
+
original_rows = len(df)
|
1207
1318
|
|
1208
1319
|
if df.shape[1] == 3:
|
1209
|
-
df.columns = [
|
1210
|
-
df[
|
1320
|
+
df.columns = ["index", "col", "facet"]
|
1321
|
+
df["value"] = 1
|
1211
1322
|
elif df.shape[1] == 4:
|
1212
|
-
df.columns = [
|
1213
|
-
|
1214
|
-
aggregated_df = aggregate_data(df, top_n_index, top_n_columns, top_n_facet, null_label)
|
1215
|
-
|
1216
|
-
# facets = aggregated_df['facet'].unique()
|
1217
|
-
facets = sorted(aggregated_df['facet'].unique()) # Ensure facets are sorted consistently
|
1323
|
+
df.columns = ["index", "col", "facet", "value"]
|
1218
1324
|
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1325
|
+
aggregated_df = aggregate_data(
|
1326
|
+
df,
|
1327
|
+
top_n_index,
|
1328
|
+
top_n_color,
|
1329
|
+
top_n_facet,
|
1330
|
+
null_label,
|
1331
|
+
show_other=show_other,
|
1332
|
+
sort_values_index=sort_values_index,
|
1333
|
+
sort_values_color=sort_values_color,
|
1334
|
+
sort_values_facet=sort_values_facet,
|
1335
|
+
)
|
1229
1336
|
|
1337
|
+
facets = sorted(
|
1338
|
+
aggregated_df["facet"].unique()
|
1339
|
+
) # Ensure facets are sorted consistently
|
1230
1340
|
|
1231
|
-
|
1232
|
-
|
1341
|
+
columns = sorted(
|
1342
|
+
aggregated_df.groupby("col", observed=True)["value"]
|
1343
|
+
.sum()
|
1344
|
+
.sort_values(ascending=False)
|
1345
|
+
.index.tolist()
|
1346
|
+
)
|
1233
1347
|
column_colors = assign_column_colors(columns, color_palette, null_label)
|
1234
1348
|
|
1235
1349
|
fig = make_subplots(
|
@@ -1238,25 +1352,39 @@ def plot_facet_stacked_bars(
|
|
1238
1352
|
subplot_titles=facets,
|
1239
1353
|
)
|
1240
1354
|
|
1355
|
+
# * Ensure all categories appear in the legend by adding an invisible trace
|
1356
|
+
for column in columns:
|
1357
|
+
fig.add_trace(
|
1358
|
+
go.Bar(
|
1359
|
+
x=[None], # Invisible bar
|
1360
|
+
y=[None],
|
1361
|
+
name=column,
|
1362
|
+
marker=dict(color=column_colors[column]),
|
1363
|
+
showlegend=True, # Ensure it appears in the legend
|
1364
|
+
)
|
1365
|
+
)
|
1366
|
+
|
1241
1367
|
added_to_legend = set()
|
1242
1368
|
for i, facet in enumerate(facets):
|
1243
|
-
facet_data = aggregated_df[aggregated_df[
|
1369
|
+
facet_data = aggregated_df[aggregated_df["facet"] == facet]
|
1244
1370
|
row = (i // subplots_per_row) + 1
|
1245
1371
|
col = (i % subplots_per_row) + 1
|
1246
1372
|
|
1247
1373
|
for column in columns:
|
1248
|
-
column_data = facet_data[facet_data[
|
1374
|
+
column_data = facet_data[facet_data["col"] == column]
|
1375
|
+
|
1249
1376
|
show_legend = column not in added_to_legend
|
1250
1377
|
if show_legend:
|
1251
1378
|
added_to_legend.add(column)
|
1252
1379
|
|
1253
1380
|
fig.add_trace(
|
1254
1381
|
go.Bar(
|
1255
|
-
x=column_data[
|
1256
|
-
y=column_data[
|
1382
|
+
x=column_data["index"],
|
1383
|
+
y=column_data["value"],
|
1257
1384
|
name=column,
|
1258
1385
|
marker=dict(color=column_colors[column]),
|
1259
|
-
|
1386
|
+
legendgroup=column, # Ensures multiple traces use the same legend entry
|
1387
|
+
showlegend=False, # suppress further legend items
|
1260
1388
|
),
|
1261
1389
|
row=row,
|
1262
1390
|
col=col,
|
@@ -1265,8 +1393,8 @@ def plot_facet_stacked_bars(
|
|
1265
1393
|
if annotations:
|
1266
1394
|
for _, row_data in column_data.iterrows():
|
1267
1395
|
fig.add_annotation(
|
1268
|
-
x=row_data[
|
1269
|
-
y=row_data[
|
1396
|
+
x=row_data["index"],
|
1397
|
+
y=row_data["value"],
|
1270
1398
|
text=f"{row_data['value']:.{precision}f}",
|
1271
1399
|
showarrow=False,
|
1272
1400
|
row=row,
|
@@ -1280,8 +1408,8 @@ def plot_facet_stacked_bars(
|
|
1280
1408
|
else:
|
1281
1409
|
axis_details.append(f"[{original_column_names[0]}]")
|
1282
1410
|
|
1283
|
-
if
|
1284
|
-
axis_details.append(f"TOP {
|
1411
|
+
if top_n_color > 0:
|
1412
|
+
axis_details.append(f"TOP {top_n_color} [{original_column_names[1]}]")
|
1285
1413
|
else:
|
1286
1414
|
axis_details.append(f"[{original_column_names[1]}]")
|
1287
1415
|
|
@@ -1290,7 +1418,7 @@ def plot_facet_stacked_bars(
|
|
1290
1418
|
else:
|
1291
1419
|
axis_details.append(f"[{original_column_names[2]}]")
|
1292
1420
|
|
1293
|
-
title = f"{caption} {', '.join(axis_details)}, n = {
|
1421
|
+
title = f"{caption} {', '.join(axis_details)}, n = {original_rows:_}"
|
1294
1422
|
template = "plotly_dark" if os.getenv("THEME") == "dark" else "plotly"
|
1295
1423
|
fig.update_layout(
|
1296
1424
|
title=title,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|