pandas-plots 0.12.7__tar.gz → 0.12.8__tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: pandas-plots
- Version: 0.12.7
+ Version: 0.12.8
  Summary: A collection of helper for table handling and visualization
  Home-page: https://github.com/smeisegeier/pandas-plots
  Author: smeisegeier
@@ -1,6 +1,6 @@
  [metadata]
  name = pandas-plots
- version = 0.12.7
+ version = 0.12.8
  author = smeisegeier
  author_email = dexterDSDo@googlemail.com
  description = A collection of helper for table handling and visualization
@@ -1,7 +1,4 @@
  from pathlib import Path
- import warnings
-
- warnings.filterwarnings("ignore")
 
  import os
  from typing import Optional, Literal
@@ -12,51 +9,118 @@ from matplotlib import pyplot as plt
  from plotly import express as px
  import plotly.graph_objects as go
  from plotly.subplots import make_subplots
- import plotly # needed for return types
+ import plotly # needed for return types
 
  from .hlp import *
  from .tbl import print_summary
 
  ### helper functions
 
+
  def _set_caption(caption: str) -> str:
      return f"#️⃣{'-'.join(caption.split())}, " if caption else ""
 
 
- def aggregate_data(df: pd.DataFrame, top_n_index: int, top_n_columns: int, top_n_facet: int, null_label: str) -> pd.DataFrame:
+ def aggregate_data(
+     df: pd.DataFrame,
+     top_n_index: int,
+     top_n_color: int,
+     top_n_facet: int,
+     null_label: str,
+     show_other: bool = False,
+     sort_values_index: bool = False,
+     sort_values_color: bool = False,
+     sort_values_facet: bool = False,
+ ) -> pd.DataFrame:
      """
      Aggregates the data, ensuring each combination of 'index', 'col', and 'facet' is unique with summed 'value'.
-
+
      Args:
          df (pd.DataFrame): Input DataFrame.
          top_n_index (int): top N values of the first column to keep. 0 means take all.
-         top_n_columns (int): top N values of the second column to keep. 0 means take all.
+         top_n_color (int): top N values of the second column to keep. 0 means take all.
          top_n_facet (int): top N values of the third column to keep. 0 means take all.
          null_label (str): Label for null values.
+         show_other (bool): Whether to include "<other>" for columns not in top_n_color. Defaults to False.
+         sort_values (bool): Whether to sort values in descending order based on group sum. Defaults to False.
 
      Returns:
          pd.DataFrame: Aggregated and filtered dataset.
      """
-     for col in ['index', 'col', 'facet']: # Skip 'value' column (numeric)
+
+     for col in ["index", "col", "facet"]: # Skip 'value' column (numeric)
          df[col] = df[col].fillna(null_label)
 
      # Aggregate data to ensure unique combinations
-     aggregated_df = df.groupby(['index', 'col', 'facet'], as_index=False)['value'].sum()
+     aggregated_df = df.groupby(["index", "col", "facet"], as_index=False)["value"].sum()
+
+     # * Reduce data based on top_n parameters
+     if sort_values_index:
+         top_indexes = (
+             aggregated_df.groupby("index")["value"]
+             .sum()
+             .sort_values(ascending=False)[:top_n_index or None]
+             .index
+         )
+     else:
+         top_indexes = aggregated_df["index"].sort_values().unique()[:top_n_index or None]
+
+     aggregated_df = aggregated_df[aggregated_df["index"].isin(top_indexes)]
+
+     if sort_values_color:
+         top_colors = (
+             aggregated_df.groupby("col")["value"]
+             .sum()
+             .sort_values(ascending=False)[:top_n_color or None]
+             .index
+         )
+     else:
+         top_colors = aggregated_df["col"].sort_values().unique()[:top_n_color or None]
+
+     others_df = df[~df["col"].isin(top_colors)]
+     aggregated_df = aggregated_df[aggregated_df["col"].isin(top_colors)]
+     if show_other and top_n_color > 0 and not others_df.empty:
+         other_agg = others_df.groupby(["index", "facet"], as_index=False)[
+             "value"
+         ].sum()
+         other_agg["col"] = "<other>"
+         other_agg = other_agg[["index", "col", "facet", "value"]]
+         aggregated_df = pd.concat([aggregated_df, other_agg], ignore_index=True)
+         top_colors = [*top_colors, "<other>"]
+
+     if sort_values_facet:
+         top_facets = (
+             aggregated_df.groupby("facet")["value"]
+             .sum()
+             .sort_values(ascending=False)[:top_n_facet or None]
+             .index
+         )
+     else:
+         top_facets = aggregated_df["facet"].sort_values().unique()[:top_n_facet or None]
+
+     aggregated_df = aggregated_df[aggregated_df["facet"].isin(top_facets)]
+
+     # * Ensure facets are sorted alphabetically
+     aggregated_df["facet"] = pd.Categorical(
+         values=aggregated_df["facet"],
+         categories=top_facets,
+         ordered=True,
+     )
+
+     aggregated_df["index"] = pd.Categorical(
+         values=aggregated_df["index"],
+         categories=top_indexes,
+         ordered=True,
+     )
+
+     aggregated_df["col"] = pd.Categorical(
+         values=aggregated_df["col"],
+         categories=top_colors,
+         ordered=True,
+     )
 
-     # Reduce data based on top_n parameters
-     if top_n_index > 0:
-         top_indexes = aggregated_df.groupby('index')['value'].sum().nlargest(top_n_index).index
-         aggregated_df = aggregated_df[aggregated_df['index'].isin(top_indexes)]
-     if top_n_columns > 0:
-         top_columns = aggregated_df.groupby('col')['value'].sum().nlargest(top_n_columns).index
-         aggregated_df = aggregated_df[aggregated_df['col'].isin(top_columns)]
-     if top_n_facet > 0:
-         top_facets = aggregated_df.groupby('facet')['value'].sum().nlargest(top_n_facet).index
-         aggregated_df = aggregated_df[aggregated_df['facet'].isin(top_facets)]
 
-     # Ensure facets are sorted alphabetically
-     aggregated_df['facet'] = pd.Categorical(aggregated_df['facet'], sorted(aggregated_df['facet'].unique()))
-     aggregated_df = aggregated_df.sort_values(by='facet')
+     # aggregated_df = aggregated_df.sort_values(by="facet")
 
      return aggregated_df
 
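For orientation, here is a minimal sketch of how the reworked aggregate_data() above might be called. The sample DataFrame, its values, and the import path are illustrative assumptions, not part of the package; the function expects the internal column names index/col/facet/value shown in the diff.

# Hypothetical usage sketch (not part of the package diff).
# Import path and sample data are assumptions for illustration only.
import pandas as pd
from pandas_plots.pls import aggregate_data  # module path assumed

df = pd.DataFrame(
    {
        "index": ["A", "A", "B", "B", None],
        "col": ["x", "y", "x", "y", "z"],
        "facet": ["f1", "f1", "f2", "f2", "f2"],
        "value": [10, 5, 7, 3, 2],
    }
)

agg = aggregate_data(
    df,
    top_n_index=0,           # 0 keeps all index values
    top_n_color=1,           # keep only the largest 'col' category ...
    top_n_facet=0,
    null_label="<NA>",
    show_other=True,         # ... and collapse the rest into "<other>"
    sort_values_color=True,  # rank 'col' by summed value, not alphabetically
)
print(agg)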
@@ -77,13 +141,15 @@ def assign_column_colors(columns, color_palette, null_label):
          palette = getattr(px.colors.qualitative, color_palette)
      else:
          raise ValueError(f"Invalid color palette: {color_palette}")
-
+
      colors = {col: palette[i % len(palette)] for i, col in enumerate(sorted(columns))}
      colors[null_label] = "lightgray"
      return colors
 
+
  ### main functions
 
+
  def plot_quadrants(
      df: pd.DataFrame,
      title: str = None,
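A quick, hypothetical illustration of what assign_column_colors() (diffed above) returns: column names are sorted, colors are taken round-robin from the named qualitative palette, and the null label is always forced to light gray. The call and the resulting hex values are only an approximation for illustration.

# Hypothetical call, not part of the diff; "Plotly" is one of the
# px.colors.qualitative palette names the helper resolves via getattr().
from pandas_plots.pls import assign_column_colors  # module path assumed

colors = assign_column_colors(
    columns=["banana", "apple", "<NA>"],
    color_palette="Plotly",
    null_label="<NA>",
)
# -> roughly {'<NA>': 'lightgray', 'apple': '#EF553B', 'banana': '#00CC96'}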
@@ -163,7 +229,7 @@ def plot_quadrants(
 
      # * save to png if path is provided
      if png_path is not None:
-         plt.savefig(Path(png_path).as_posix(), format='png')
+         plt.savefig(Path(png_path).as_posix(), format="png")
 
      return q1, q2, q3, q4, n
  # * plotly express is not used for the heatmap, although it does not need the derived wide format.
@@ -185,11 +251,14 @@ def plot_stacked_bars(
      renderer: Literal["png", "svg", None] = "png",
      caption: str = None,
      sort_values: bool = False,
+     sort_values_index: bool = False,
+     sort_values_color: bool = False,
      show_total: bool = False,
      precision: int = 0,
      png_path: Path | str = None,
      color_palette: str = "Plotly",
      null_label: str = "<NA>",
+     show_other: bool = False,
  ) -> plotly.graph_objects:
      """
      Generates a stacked bar plot using the provided DataFrame.
@@ -208,7 +277,7 @@ def plot_stacked_bars(
      - title (str): Custom title for the plot.
      - renderer (Literal["png", "svg", None]): Defines the output format.
      - caption (str): Optional caption for additional context.
-     - sort_values (bool):
+     - sort_values (bool): 
          - If True, sorts bars by the sum of their values (descending).
          - If False, sorts bars alphabetically.
      - show_total (bool): If True, adds a row with the total sum of all categories.
@@ -216,20 +285,33 @@ def plot_stacked_bars(
      - png_path (Path | str): If specified, saves the plot as a PNG file.
      - color_palette (str): Name of the color palette to use.
      - null_label (str): Label for null values.
-
+     - show_other (bool): If True, shows the "Other" category in the legend.
+     - sort_values_index (bool): If True, sorts the index categories by group sum
+     - sort_values_color (bool): If True, sorts the columns categories by group sum
+
      Returns:
      - A Plotly figure object representing the stacked bar chart.
      """
      BAR_LENGTH_MULTIPLIER = 1.05
-
+
      # * 2 axis means at least 2 columns
      if len(df.columns) < 2 or len(df.columns) > 3:
          print("❌ df must have exactly 2 or 3 columns")
          return
 
-     # * check if first 2 columns are str
-     if list(set((df.iloc[:, [0, 1]].dtypes)))[0].kind not in ["O", "b"]:
-         print("❌ first 2 columns must be str")
+     # ! do not enforce str columns anymore
+     # # * check if first 2 columns are str
+     # dtypes = set(df.iloc[:, [0, 1]].dtypes)
+     # dtypes_kind = [i.kind for i in dtypes]
+
+     # if set(dtypes_kind) - set(["O", "b"]):
+     # print("❌ first 2 columns must be str")
+     # # * overkill ^^
+     # df.iloc[:, [0, 1]] = df.iloc[:, [0, 1]].astype(str)
+
+     # * but last col must be numeric
+     if df.iloc[:, -1].dtype.kind not in ("f", "i"):
+         print("❌ last column must be numeric")
          return
 
      df = df.copy() # Copy the input DataFrame to avoid modifying the original
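With the relaxed check above, only the last column has to be numeric. A hedged usage sketch of the extended plot_stacked_bars() signature follows; the sample data and the module path are assumptions, not taken from the package's own docs.

# Hypothetical usage sketch of the new parameters shown in this hunk.
import pandas as pd
from pandas_plots import pls  # module name assumed

df = pd.DataFrame(
    {
        "region": ["north", "north", "south", "south", "west"],
        "product": ["a", "b", "a", "c", "c"],
        "revenue": [100, 40, 80, 25, 10],
    }
)

fig = pls.plot_stacked_bars(
    df,
    top_n_color=2,           # keep the two largest products ...
    show_other=True,         # ... and stack everything else as "<other>"
    sort_values_index=True,  # order regions by their total revenue
    normalize=True,          # label bars with count and percentage
    renderer=None,           # use the default renderer instead of static png/svg
)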
@@ -253,87 +335,102 @@ def plot_stacked_bars(
      # * apply precision
      df.iloc[:, 2] = df.iloc[:, 2].round(precision)
 
-     # * set index + color col
+     # # * set index + color col
      col_index = df.columns[0] if not swap else df.columns[1]
      col_color = df.columns[1] if not swap else df.columns[0]
 
      # * ensure df is grouped to prevent false aggregations
-     df = (
-         df.groupby([df.columns[0], df.columns[1]])
-         [df.columns[2]]
-         .sum()
-         .reset_index()
-     )
+     df = df.groupby([df.columns[0], df.columns[1]])[df.columns[2]].sum().reset_index()
 
      # * add total as aggregation of df
      if show_total:
-         df_total = df.groupby(df.columns[1], observed=True, as_index=False)[df.columns[2]].sum()
+         df_total = df.groupby(df.columns[1], observed=True, as_index=False)[
+             df.columns[2]
+         ].sum()
          df_total[df.columns[0]] = " Total"
          df = pd.concat([df, df_total], ignore_index=True)
 
-
-     # * apply top_n, reduce df
-     n_col = top_n_color if top_n_color > 0 else None
-     n_idx = top_n_index if top_n_index > 0 else None
-
-     unique_colors = sorted(
-         df.groupby(col_color)[df.columns[2]]
-         .sum()
-         .sort_values(ascending=False)
-         .index.tolist()[:n_col]
-     )
-
-     unique_idx = df[col_index].sort_values().unique()[:n_idx]
-
-     df = df[df[col_color].isin(unique_colors)]#.sort_values(by=[col_index, col_color])
-     df = df[df[col_index].isin(unique_idx)]#.sort_values(by=[col_index, col_color])
-
-
-     # # * Sorting logic based on sort_values
-     if sort_values:
-         sort_order = (
-             df.groupby(col_index)[df.columns[2]].sum().sort_values(ascending=False).index
-         )
-     else:
-         sort_order = sorted(df[col_index].unique()) # Alphabetical order
-
-     # # * Convert to categorical with explicit ordering
-     df[col_index] = pd.Categorical(df[col_index], categories=sort_order, ordered=True)
-
-     column_colors = assign_column_colors(
-         columns=unique_colors,
-         color_palette=color_palette,
-         null_label=null_label
-     )
-
      # * calculate n
      divider = 2 if show_total else 1
-     n = int(df[df.columns[2]].sum() / divider)
+     n = int(df.iloc[:, 2].sum() / divider)
 
      # * title str
      _title_str_top_index = f"TOP{top_n_index} " if top_n_index > 0 else ""
      _title_str_top_color = f"TOP{top_n_color} " if top_n_color > 0 else ""
      _title_str_null = f", NULL excluded" if dropna else ""
      _title_str_n = f", n={n:_}"
+
+     _df = df.copy().assign(facet=None)
+     _df.columns = (
+         ["index", "col", "value", "facet"]
+         if not swap
+         else ["col", "index", "value", "facet"]
+     )
+
+     aggregated_df = aggregate_data(
+         df=_df,
+         top_n_index=top_n_index,
+         top_n_color=top_n_color,
+         top_n_facet=0,
+         null_label=null_label,
+         show_other=show_other,
+         sort_values_index=sort_values_index,
+         sort_values_color=sort_values_color,
+         sort_values_facet=False, # just a placeholder
+     )
+
+     df = aggregated_df.copy()
+
+     columns = sorted(
+         df.groupby("col", observed=True)["value"]
+         .sum()
+         .sort_values(ascending=False)
+         .index.tolist()
+     )
+     column_colors = assign_column_colors(columns, color_palette, null_label)
+
      caption = _set_caption(caption)
 
-     # * after grouping add cols for pct and formatting
-     df["pct"] = df[df.columns[2]].apply(lambda x: f"{(x / n) * 100:.{precision}f}%")
+     # * after grouping add cols for pct and formatting 
+     df["cnt_pct_only"] = df["value"].apply(lambda x: f"{(x / n) * 100:.{precision}f}%")
 
      # * format output
-     df["cnt_str"] = df[df.columns[2]].apply(lambda x: f"{x:_.{precision}f}")
+     df["cnt_str"] = df["value"].apply(lambda x: f"{x:_.{precision}f}")
 
      divider2 = "<br>" if orientation == "v" else " "
      df["cnt_pct_str"] = df.apply(
-         lambda row: f"{row['cnt_str']}{divider2}({row['pct']})", axis=1
+         lambda row: f"{row['cnt_str']}{divider2}({row['cnt_pct_only']})", axis=1
      )
 
+     # # # * Sorting logic based on sort_values
+     # if sort_values_index:
+     # sort_order = (
+     # df.groupby("index")["value"].sum().sort_values(ascending=False).index
+     # )
+     # else:
+     # sort_order = sorted(df["index"].unique(), reverse=False) # Alphabetical order
+
+     # display(sort_order)
+
+     # df["index"] = pd.Categorical(
+     # values=df["index"],
+     # # categories=sort_order,
+     # ordered=True,
+     # )
+     df = (
+         df.sort_values(by="index", ascending=False)
+         if orientation == "h"
+         else df.sort_values(by="index", ascending=True)
+     )
+
+     # display(df)
+
      # * plot
      fig = px.bar(
          df,
-         x=col_index if orientation == "v" else df.columns[2],
-         y=df.columns[2] if orientation == "v" else col_index,
-         color=col_color,
+         x="index" if orientation == "v" else "value",
+         y="value" if orientation == "v" else "index",
+         color="col",
          text="cnt_pct_str" if normalize else "cnt_str",
          orientation=orientation,
          title=title
@@ -342,13 +439,15 @@ def plot_stacked_bars(
          width=width,
          height=height,
          color_discrete_map=column_colors, # Use assigned colors
-         category_orders={col_index: list(df[col_index].cat.categories)}, # <- Add this line
-
+         category_orders={
+             col_index: list(df["index"].cat.categories)
+         }, # <- Add this line
      )
-
-     # * get longest bar
+
+
+     # * get longest bar
      bar_max = (
-         df.groupby(col_index)[df.columns[2]].sum().sort_values(ascending=False).iloc[0]
+         df.groupby("index")["value"].sum().sort_values(ascending=False).iloc[0]
          * BAR_LENGTH_MULTIPLIER
      )
      # * ignore if bar mode is on
@@ -372,7 +471,7 @@
              },
          },
      )
-
+
      # * set dtick
      if orientation == "h":
          if relative:
@@ -692,7 +791,7 @@ def plot_histogram(
          caption (str): The caption for the plot. Default is None.
          title (str): The title of the plot. Default is None.
          png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
-
+
 
      Returns:
          plot object
@@ -744,7 +843,7 @@ def plot_histogram(
      )
 
      fig.show(renderer)
-
+
      # * save to png if path is provided
      if png_path is not None:
          fig.write_image(Path(png_path).as_posix())
@@ -1156,12 +1255,11 @@ def plot_boxes(
      return fig
 
 
-
  def plot_facet_stacked_bars(
      df: pd.DataFrame,
      subplots_per_row: int = 4,
      top_n_index: int = 0,
-     top_n_columns: int = 0,
+     top_n_color: int = 0,
      top_n_facet: int = 0,
      null_label: str = "<NA>",
      subplot_size: int = 300,
@@ -1171,6 +1269,12 @@ def plot_facet_stacked_bars(
      annotations: bool = False,
      precision: int = 0,
      png_path: Optional[Path] = None,
+     show_other: bool = False,
+     sort_values: bool = True,
+     sort_values_index: bool = False,
+     sort_values_color: bool = False,
+     sort_values_facet: bool = False,
+
  ) -> object:
      """
      Create a grid of stacked bar charts.
@@ -1179,7 +1283,7 @@ def plot_facet_stacked_bars(
          df (pd.DataFrame): DataFrame with 3 or 4 columns.
          subplots_per_row (int): Number of subplots per row.
          top_n_index (int): top N index values to keep.
-         top_n_columns (int): top N column values to keep.
+         top_n_color (int): top N column values to keep.
          top_n_facet (int): top N facet values to keep.
          null_label (str): Label for null values.
          subplot_size (int): Size of each subplot.
@@ -1189,47 +1293,57 @@ def plot_facet_stacked_bars(
          annotations (bool): Whether to show annotations in the subplots.
          precision (int): Decimal precision for annotations.
          png_path (Optional[Path]): Path to save the image.
+         show_other (bool): If True, adds an "<other>" bar for columns not in top_n_color.
+         sort_values_index (bool): If True, sorts index by group sum.
+         sort_values_color (bool): If True, sorts columns by group sum.
+         sort_values_facet (bool): If True, sorts facet by group sum.
+         sort_values (bool): DEPRECATED
+
 
      Returns:
          plot object
-
+
      Remarks:
          If you need to include facets that have no data, fill up like this beforehand:
          df.loc[len(df)]=[None, None, 12]
      """
-
+
      df = df.copy() # Copy the input DataFrame to avoid modifying the original
 
      if not (df.shape[1] == 3 or df.shape[1] == 4):
          raise ValueError("Input DataFrame must have 3 or 4 columns.")
-
+
      original_column_names = df.columns.tolist()
+     original_rows = len(df)
 
      if df.shape[1] == 3:
-         df.columns = ['index', 'col', 'facet']
-         df['value'] = 1
+         df.columns = ["index", "col", "facet"]
+         df["value"] = 1
      elif df.shape[1] == 4:
-         df.columns = ['index', 'col', 'facet', 'value']
-
-     aggregated_df = aggregate_data(df, top_n_index, top_n_columns, top_n_facet, null_label)
-
-     # facets = aggregated_df['facet'].unique()
-     facets = sorted(aggregated_df['facet'].unique()) # Ensure facets are sorted consistently
+         df.columns = ["index", "col", "facet", "value"]
 
-     if top_n_columns > 0:
-         top_columns = aggregated_df.groupby('col', observed=True)['value'].sum().nlargest(top_n_columns).index.tolist()
-         # aggregated_df['col'] = aggregated_df['col'].apply(lambda x: x if x in top_columns else "<other>")
-         # aggregated_df['col'] = pd.Categorical(aggregated_df['col'], categories=top_columns + ["<other>"], ordered=True)
-         # aggregated_df['col'] = pd.Categorical(
-         # aggregated_df['col'].map(lambda x: x if x in top_columns else "<other>"),
-         # categories=top_columns + ["<other>"],
-         # ordered=True
-         # )
-         aggregated_df['col'] = aggregated_df['col'].apply(lambda x: x if x in top_columns else "<other>")
+     aggregated_df = aggregate_data(
+         df,
+         top_n_index,
+         top_n_color,
+         top_n_facet,
+         null_label,
+         show_other=show_other,
+         sort_values_index=sort_values_index,
+         sort_values_color=sort_values_color,
+         sort_values_facet=sort_values_facet,
+     )
 
+     facets = sorted(
+         aggregated_df["facet"].unique()
+     ) # Ensure facets are sorted consistently
 
-     # columns = sorted(aggregated_df['col'].unique())
-     columns = aggregated_df.groupby('col', observed=True)['value'].sum().sort_values(ascending=False).index.tolist()
+     columns = sorted(
+         aggregated_df.groupby("col", observed=True)["value"]
+         .sum()
+         .sort_values(ascending=False)
+         .index.tolist()
+     )
      column_colors = assign_column_colors(columns, color_palette, null_label)
 
      fig = make_subplots(
@@ -1238,25 +1352,39 @@ def plot_facet_stacked_bars(
          subplot_titles=facets,
      )
 
+     # * Ensure all categories appear in the legend by adding an invisible trace
+     for column in columns:
+         fig.add_trace(
+             go.Bar(
+                 x=[None], # Invisible bar
+                 y=[None],
+                 name=column,
+                 marker=dict(color=column_colors[column]),
+                 showlegend=True, # Ensure it appears in the legend
+             )
+         )
+
      added_to_legend = set()
      for i, facet in enumerate(facets):
-         facet_data = aggregated_df[aggregated_df['facet'] == facet]
+         facet_data = aggregated_df[aggregated_df["facet"] == facet]
          row = (i // subplots_per_row) + 1
          col = (i % subplots_per_row) + 1
 
          for column in columns:
-             column_data = facet_data[facet_data['col'] == column]
+             column_data = facet_data[facet_data["col"] == column]
+
              show_legend = column not in added_to_legend
              if show_legend:
                  added_to_legend.add(column)
 
             fig.add_trace(
                 go.Bar(
-                     x=column_data['index'],
-                     y=column_data['value'],
+                     x=column_data["index"],
+                     y=column_data["value"],
                      name=column,
                      marker=dict(color=column_colors[column]),
-                     showlegend=show_legend,
+                     legendgroup=column, # Ensures multiple traces use the same legend entry
+                     showlegend=False, # suppress further legend items
                 ),
                 row=row,
                 col=col,
1265
1393
  if annotations:
1266
1394
  for _, row_data in column_data.iterrows():
1267
1395
  fig.add_annotation(
1268
- x=row_data['index'],
1269
- y=row_data['value'],
1396
+ x=row_data["index"],
1397
+ y=row_data["value"],
1270
1398
  text=f"{row_data['value']:.{precision}f}",
1271
1399
  showarrow=False,
1272
1400
  row=row,
@@ -1280,8 +1408,8 @@ def plot_facet_stacked_bars(
1280
1408
  else:
1281
1409
  axis_details.append(f"[{original_column_names[0]}]")
1282
1410
 
1283
- if top_n_columns > 0:
1284
- axis_details.append(f"TOP {top_n_columns} [{original_column_names[1]}]")
1411
+ if top_n_color > 0:
1412
+ axis_details.append(f"TOP {top_n_color} [{original_column_names[1]}]")
1285
1413
  else:
1286
1414
  axis_details.append(f"[{original_column_names[1]}]")
1287
1415
 
@@ -1290,7 +1418,7 @@ def plot_facet_stacked_bars(
      else:
          axis_details.append(f"[{original_column_names[2]}]")
 
-     title = f"{caption} {', '.join(axis_details)}, n = {unique_rows:_}"
+     title = f"{caption} {', '.join(axis_details)}, n = {original_rows:_}"
      template = "plotly_dark" if os.getenv("THEME") == "dark" else "plotly"
      fig.update_layout(
          title=title,
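Taken together, these changes route plot_facet_stacked_bars() through the reworked aggregate_data() and rename top_n_columns to top_n_color. A hedged usage sketch follows; the sample data and module path are assumptions, and with a 3-column input each row counts as 1, as in the diff above.

# Hypothetical usage sketch, not taken from the package's own docs.
import pandas as pd
from pandas_plots import pls  # module name assumed

df = pd.DataFrame(
    {
        "city": ["Berlin", "Berlin", "Hamburg", "Hamburg", "Munich"],
        "category": ["a", "b", "a", "c", "b"],
        "year": ["2023", "2023", "2024", "2024", "2024"],
    }
)

fig = pls.plot_facet_stacked_bars(
    df,                      # 3 columns -> index, col (color), facet
    subplots_per_row=2,
    top_n_color=2,           # replaces the former top_n_columns argument
    show_other=True,         # collapse the remaining categories into "<other>"
    sort_values_color=True,  # rank colors by their summed counts
)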
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: pandas-plots
- Version: 0.12.7
+ Version: 0.12.8
  Summary: A collection of helper for table handling and visualization
  Home-page: https://github.com/smeisegeier/pandas-plots
  Author: smeisegeier