pandas-plots 0.14.1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pandas_plots/pls.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from pathlib import Path
2
2
  import warnings
3
+
3
4
  warnings.filterwarnings("ignore")
4
5
 
5
6
  import os
@@ -62,12 +63,14 @@ def aggregate_data(
62
63
  top_indexes = (
63
64
  aggregated_df.groupby("index")["value"]
64
65
  .sum()
65
- .sort_values(ascending=False)[:top_n_index or None]
66
+ .sort_values(ascending=False)[: top_n_index or None]
66
67
  .index
67
68
  )
68
-
69
+
69
70
  else:
70
- top_indexes = aggregated_df["index"].sort_values().unique()[:top_n_index or None]
71
+ top_indexes = (
72
+ aggregated_df["index"].sort_values().unique()[: top_n_index or None]
73
+ )
71
74
 
72
75
  aggregated_df = aggregated_df[aggregated_df["index"].isin(top_indexes)]
73
76
 
@@ -75,18 +78,16 @@ def aggregate_data(
75
78
  top_colors = (
76
79
  aggregated_df.groupby("col")["value"]
77
80
  .sum()
78
- .sort_values(ascending=False)[:top_n_color or None]
81
+ .sort_values(ascending=False)[: top_n_color or None]
79
82
  .index
80
83
  )
81
84
  else:
82
- top_colors = aggregated_df["col"].sort_values().unique()[:top_n_color or None]
85
+ top_colors = aggregated_df["col"].sort_values().unique()[: top_n_color or None]
83
86
 
84
87
  others_df = df[~df["col"].isin(top_colors)]
85
88
  aggregated_df = aggregated_df[aggregated_df["col"].isin(top_colors)]
86
89
  if show_other and top_n_color > 0 and not others_df.empty:
87
- other_agg = others_df.groupby(["index", "facet"], as_index=False)[
88
- "value"
89
- ].sum()
90
+ other_agg = others_df.groupby(["index", "facet"], as_index=False)["value"].sum()
90
91
  other_agg["col"] = "<other>"
91
92
  other_agg = other_agg[["index", "col", "facet", "value"]]
92
93
  aggregated_df = pd.concat([aggregated_df, other_agg], ignore_index=True)
@@ -96,11 +97,13 @@ def aggregate_data(
96
97
  top_facets = (
97
98
  aggregated_df.groupby("facet")["value"]
98
99
  .sum()
99
- .sort_values(ascending=False)[:top_n_facet or None]
100
+ .sort_values(ascending=False)[: top_n_facet or None]
100
101
  .index
101
102
  )
102
103
  else:
103
- top_facets = aggregated_df["facet"].sort_values().unique()[:top_n_facet or None]
104
+ top_facets = (
105
+ aggregated_df["facet"].sort_values().unique()[: top_n_facet or None]
106
+ )
104
107
 
105
108
  aggregated_df = aggregated_df[aggregated_df["facet"].isin(top_facets)]
106
109
 
@@ -241,7 +244,9 @@ def plot_stacked_bars(
241
244
  color_palette: str = "Plotly",
242
245
  null_label: str = "<NA>",
243
246
  show_other: bool = False,
244
- ) -> plotly.graph_objects:
247
+ show_pct_all: bool = False,
248
+ show_pct_bar: bool = False,
249
+ ) -> None:
245
250
  """
246
251
  Generates a stacked bar plot using the provided DataFrame.
247
252
 
@@ -270,9 +275,10 @@ def plot_stacked_bars(
270
275
  - show_other (bool): If True, shows the "Other" category in the legend.
271
276
  - sort_values_index (bool): If True, sorts the index categories by group sum
272
277
  - sort_values_color (bool): If True, sorts the columns categories by group sum
278
+ - show_pct_all (bool): If True, formats the bar text with percentages from the total n.
279
+ - show_pct_bar (bool): If True, formats the bar text with percentages from the bar's total.
273
280
 
274
- Returns:
275
- - A Plotly figure object representing the stacked bar chart.
281
+ Returns: None
276
282
  """
277
283
  BAR_LENGTH_MULTIPLIER = 1.05
278
284
 
@@ -358,27 +364,41 @@ def plot_stacked_bars(
358
364
  show_other=show_other,
359
365
  sort_values_index=sort_values_index,
360
366
  sort_values_color=sort_values_color,
361
- sort_values_facet=False, # just a placeholder
367
+ sort_values_facet=False, # just a placeholder
362
368
  )
363
369
 
364
370
  df = aggregated_df.copy()
365
371
 
372
+ # * calculate bar totals
373
+ bar_totals = df.groupby("index")["value"].transform("sum")
374
+
366
375
  caption = _set_caption(caption)
367
376
 
368
377
  # * after grouping add cols for pct and formatting
369
- df["cnt_pct_only"] = df["value"].apply(lambda x: f"{(x / n) * 100:.{precision}f}%")
378
+ df["cnt_pct_all_only"] = df["value"].apply(lambda x: f"{(x / n) * 100:.{precision}f}%")
379
+ df["cnt_pct_bar_only"] = (df["value"] / bar_totals * 100).apply(lambda x: f"{x:.{precision}f}%")
370
380
 
371
381
  # * format output
372
382
  df["cnt_str"] = df["value"].apply(lambda x: f"{x:_.{precision}f}")
373
383
 
374
384
  divider2 = "<br>" if orientation == "v" else " "
375
- df["cnt_pct_str"] = df.apply(
376
- lambda row: f"{row['cnt_str']}{divider2}({row['cnt_pct_only']})", axis=1
385
+
386
+ df["cnt_pct_all_str"] = df.apply(
387
+ lambda row: f"{row['cnt_str']}{divider2}({row['cnt_pct_all_only']})", axis=1
388
+ )
389
+ df["cnt_pct_bar_str"] = df.apply(
390
+ lambda row: f"{row['cnt_str']}{divider2}({row['cnt_pct_bar_only']})", axis=1
377
391
  )
378
392
 
393
+ text_to_show = "cnt_str"
394
+ if show_pct_all:
395
+ text_to_show = "cnt_pct_all_str"
396
+ elif show_pct_bar:
397
+ text_to_show = "cnt_pct_bar_str"
398
+
379
399
  if sort_values_color:
380
- colors_unique = (df
381
- .groupby("col", observed=True)["value"]
400
+ colors_unique = (
401
+ df.groupby("col", observed=True)["value"]
382
402
  .sum()
383
403
  .sort_values(ascending=False)
384
404
  .index.tolist()
@@ -387,8 +407,8 @@ def plot_stacked_bars(
387
407
  colors_unique = sorted(df["col"].unique().tolist())
388
408
 
389
409
  if sort_values_index:
390
- index_unique = (df
391
- .groupby("index", observed=True)["value"]
410
+ index_unique = (
411
+ df.groupby("index", observed=True)["value"]
392
412
  .sum()
393
413
  .sort_values(ascending=False)
394
414
  .index.tolist()
@@ -397,7 +417,6 @@ def plot_stacked_bars(
397
417
  index_unique = sorted(df["index"].unique().tolist())
398
418
 
399
419
  color_map = assign_column_colors(colors_unique, color_palette, null_label)
400
-
401
420
 
402
421
  cat_orders = {
403
422
  "index": index_unique,
@@ -405,8 +424,9 @@ def plot_stacked_bars(
405
424
  }
406
425
 
407
426
  # Ensure bl is categorical with the correct order
408
- df["index"] = pd.Categorical(df["index"], categories=cat_orders["index"], ordered=True)
409
-
427
+ df["index"] = pd.Categorical(
428
+ df["index"], categories=cat_orders["index"], ordered=True
429
+ )
410
430
 
411
431
  # * plot
412
432
  fig = px.bar(
@@ -415,18 +435,15 @@ def plot_stacked_bars(
415
435
  y="value" if orientation == "v" else "index",
416
436
  # color=columns,
417
437
  color="col",
418
- text="cnt_pct_str" if normalize else "cnt_str",
438
+ text=text_to_show,
419
439
  orientation=orientation,
420
440
  title=title
421
441
  or f"{caption}{_title_str_top_index}[{col_index}] by {_title_str_top_color}[{col_color}]{_title_str_null}{_title_str_n}",
422
442
  template="plotly_dark" if os.getenv("THEME") == "dark" else "plotly",
423
- width=width,
424
- height=height,
425
443
  color_discrete_map=color_map, # Use assigned colors
426
- category_orders= cat_orders,
444
+ category_orders=cat_orders,
427
445
  )
428
446
 
429
-
430
447
  # print(cat_orders)
431
448
  # print(color_map)
432
449
  # display(df)
@@ -457,10 +474,9 @@ def plot_stacked_bars(
457
474
  },
458
475
  },
459
476
  )
460
- fig.update_layout(legend_traceorder="normal")
477
+ fig.update_layout(legend_traceorder="normal")
461
478
  fig.update_layout(legend_title_text=col_color)
462
479
 
463
-
464
480
  # * set dtick
465
481
  if orientation == "h":
466
482
  if relative:
@@ -482,10 +498,13 @@ def plot_stacked_bars(
482
498
  if png_path is not None:
483
499
  fig.write_image(Path(png_path).as_posix())
484
500
 
485
- fig.show(renderer=renderer)
486
-
487
- return fig
501
+ fig.show(
502
+ renderer=renderer,
503
+ width=width,
504
+ height=height,
505
+ )
488
506
 
507
+ return
489
508
 
490
509
  def plot_bars(
491
510
  df_in: pd.Series | pd.DataFrame,
@@ -504,7 +523,7 @@ def plot_bars(
504
523
  precision: int = 0,
505
524
  renderer: Literal["png", "svg", None] = "png",
506
525
  png_path: Path | str = None,
507
- ) -> object:
526
+ ) -> None:
508
527
  """
509
528
  A function to plot a bar chart based on a *categorical* column (must be string or bool) and a numerical value.
510
529
  Accepts:
@@ -532,8 +551,7 @@ def plot_bars(
532
551
  - renderer: A string indicating the renderer to use for displaying the chart. It can be "png", "svg", or None. Default is "png".
533
552
  - png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
534
553
 
535
- Returns:
536
- - plot object
554
+ Returns: None
537
555
  """
538
556
  # * if series, apply value_counts, deselect use_ci
539
557
  if isinstance(df_in, pd.Series):
@@ -563,8 +581,9 @@ def plot_bars(
563
581
 
564
582
  # * ensure df is grouped to prevent false aggregations, reset index to return df
565
583
  if use_ci:
566
- # * grouping is smoother on df than on series
567
- df = (df_in
584
+ # * grouping is smoother on df than on series
585
+ df = (
586
+ df_in
568
587
  # ? dont dropna() here, this biases the input data
569
588
  .groupby(
570
589
  col_index,
@@ -573,7 +592,12 @@ def plot_bars(
573
592
  .agg(
574
593
  mean=(col_name, ci_agg),
575
594
  # * retrieve margin from custom func
576
- margin=(col_name, lambda x: mean_confidence_interval(x, use_median = (ci_agg == "median"))[1]),
595
+ margin=(
596
+ col_name,
597
+ lambda x: mean_confidence_interval(
598
+ x, use_median=(ci_agg == "median")
599
+ )[1],
600
+ ),
577
601
  )
578
602
  .reset_index()
579
603
  )
@@ -593,7 +617,6 @@ def plot_bars(
593
617
  else:
594
618
  df = df.fillna("<NA>")
595
619
 
596
-
597
620
  # * get n, col1 now is always numeric
598
621
  n = df[df.columns[1]].sum()
599
622
  n_len = len(df_in)
@@ -657,7 +680,9 @@ def plot_bars(
657
680
 
658
681
  # * title str n
659
682
  _title_str_n = (
660
- f", n={n_len:_} ({n:_})" if not use_ci else f", n={n_len:_})<br><sub>ci(95) on {ci_agg}s<sub>"
683
+ f", n={n_len:_} ({n:_})"
684
+ if not use_ci
685
+ else f", n={n_len:_})<br><sub>ci(95) on {ci_agg}s<sub>"
661
686
  )
662
687
 
663
688
  # * title str na
@@ -680,8 +705,6 @@ def plot_bars(
680
705
  or f"{caption}{_title_str_minval}{_title_str_top}[{col_name}] by [{col_index}]{_title_str_null}{_title_str_n}",
681
706
  # * retrieve theme from env (intro.set_theme) or default
682
707
  template="plotly_dark" if os.getenv("THEME") == "dark" else "plotly",
683
- width=width,
684
- height=height,
685
708
  error_y=None if not use_ci else df["margin"],
686
709
  color_discrete_sequence=px.colors.qualitative.D3,
687
710
  color=col_index,
@@ -734,14 +757,12 @@ def plot_bars(
734
757
  _fig.update_layout(yaxis={"categoryorder": "category descending"})
735
758
 
736
759
  # * looks better on single bars
737
- _fig.update_traces(
738
- error_y=dict(thickness=5)
739
- )
760
+ _fig.update_traces(error_y=dict(thickness=5))
740
761
  if use_ci:
741
762
  _fig.update_traces(
742
763
  textposition="inside", # Put labels inside bars
743
764
  insidetextanchor="start", # Align labels at the bottom
744
- textfont=dict(size=14, color="white") # Adjust text color for visibility
765
+ textfont=dict(size=14, color="white"), # Adjust text color for visibility
745
766
  )
746
767
  else:
747
768
  _fig.update_traces(
@@ -750,14 +771,17 @@ def plot_bars(
750
771
  )
751
772
 
752
773
  # * set axis title
753
-
754
- _fig.show(renderer)
774
+ _fig.show(
775
+ renderer,
776
+ width=width,
777
+ height=height,
778
+ )
755
779
 
756
780
  # * save to png if path is provided
757
781
  if png_path is not None:
758
782
  _fig.write_image(Path(png_path).as_posix())
759
783
 
760
- return _fig
784
+ return
761
785
 
762
786
 
763
787
  def plot_histogram(
@@ -776,7 +800,7 @@ def plot_histogram(
776
800
  caption: str = None,
777
801
  title: str = None,
778
802
  png_path: Path | str = None,
779
- ) -> object:
803
+ ) -> None:
780
804
  """
781
805
  A function to plot a histogram based on *numeric* columns in a DataFrame.
782
806
  Accepts:
@@ -799,8 +823,7 @@ def plot_histogram(
799
823
  png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
800
824
 
801
825
 
802
- Returns:
803
- plot object
826
+ Returns: None
804
827
  """
805
828
 
806
829
  # * convert to df if series
@@ -828,8 +851,6 @@ def plot_histogram(
828
851
  marginal="box",
829
852
  barmode=barmode,
830
853
  text_auto=text_auto,
831
- height=height,
832
- width=width,
833
854
  orientation=orientation,
834
855
  title=title or f"{_caption}[{', '.join(df.columns)}], n={df.shape[0]:_}",
835
856
  template="plotly_dark" if os.getenv("THEME") == "dark" else "plotly",
@@ -848,13 +869,17 @@ def plot_histogram(
848
869
  showlegend=False if df.shape[1] == 1 else True,
849
870
  )
850
871
 
851
- fig.show(renderer)
872
+ fig.show(
873
+ renderer,
874
+ width=width,
875
+ height=height,
876
+ )
852
877
 
853
878
  # * save to png if path is provided
854
879
  if png_path is not None:
855
880
  fig.write_image(Path(png_path).as_posix())
856
881
 
857
- return fig
882
+ return
858
883
 
859
884
 
860
885
  def plot_joint(
@@ -866,7 +891,7 @@ def plot_joint(
866
891
  caption: str = "",
867
892
  title: str = "",
868
893
  png_path: Path | str = None,
869
- ) -> object:
894
+ ) -> None:
870
895
  """
871
896
  Generate a seaborn joint plot for *two numeric* columns of a given DataFrame.
872
897
 
@@ -880,8 +905,7 @@ def plot_joint(
880
905
  - title: The title of the plot.
881
906
  - png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
882
907
 
883
- Returns:
884
- plot object
908
+ Returns: None
885
909
  """
886
910
 
887
911
  if df.shape[1] != 2:
@@ -953,7 +977,7 @@ def plot_joint(
953
977
  if png_path is not None:
954
978
  fig.savefig(Path(png_path).as_posix())
955
979
 
956
- return fig
980
+ return
957
981
 
958
982
 
959
983
  def plot_box(
@@ -971,7 +995,8 @@ def plot_box(
971
995
  x_max: float = None,
972
996
  use_log: bool = False,
973
997
  png_path: Path | str = None,
974
- ) -> object:
998
+ renderer: Literal["png", "svg", None] = "png",
999
+ ) -> None:
975
1000
  """
976
1001
  Plots a horizontal box plot for the given pandas Series.
977
1002
 
@@ -990,9 +1015,9 @@ def plot_box(
990
1015
  x_max: The maximum value for the x-axis scale (max and min must be set).
991
1016
  use_log: Use logarithmic scale for the axis.
992
1017
  png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
1018
+ renderer (Literal["png", "svg", None], optional): The renderer to use for saving the image. Defaults to "png".
993
1019
 
994
- Returns:
995
- plot object
1020
+ Returns: None
996
1021
  """
997
1022
  ser = to_series(ser)
998
1023
  if ser is None:
@@ -1024,11 +1049,9 @@ def plot_box(
1024
1049
  "data_frame": ser,
1025
1050
  "orientation": "h",
1026
1051
  "template": "plotly_dark" if os.getenv("THEME") == "dark" else "plotly",
1027
- "height": height,
1028
- "width": width,
1029
1052
  "points": points,
1030
1053
  # 'box':True,
1031
- "log_x": use_log, # * logarithmic scale, axis is always x
1054
+ "log_x": use_log, # * logarithmic scale, axis is always x
1032
1055
  # "notched": True,
1033
1056
  "title": f"{caption}[{ser.name}]{log_str}, n = {n_:_}" if not title else title,
1034
1057
  }
@@ -1106,7 +1129,11 @@ def plot_box(
1106
1129
  y=-0,
1107
1130
  )
1108
1131
 
1109
- fig.show("png")
1132
+ fig.show(
1133
+ renderer=renderer,
1134
+ width=width,
1135
+ height=height,
1136
+ )
1110
1137
 
1111
1138
  if summary:
1112
1139
  # * if only series is provided, col name is None
@@ -1116,9 +1143,7 @@ def plot_box(
1116
1143
  if png_path is not None:
1117
1144
  fig.write_image(Path(png_path).as_posix())
1118
1145
 
1119
- return fig
1120
-
1121
-
1146
+ return
1122
1147
 
1123
1148
 
1124
1149
  def plot_boxes(
@@ -1134,7 +1159,8 @@ def plot_boxes(
1134
1159
  use_log: bool = False,
1135
1160
  box_width: float = 0.5,
1136
1161
  png_path: Path | str = None,
1137
- ) -> object:
1162
+ renderer: Literal["png", "svg", None] = "png",
1163
+ ) -> None:
1138
1164
  """
1139
1165
  [Experimental] Plot vertical boxes for each unique item in the DataFrame and add annotations for statistics.
1140
1166
 
@@ -1149,9 +1175,9 @@ def plot_boxes(
1149
1175
  summary (bool): Whether to add a summary to the plot.
1150
1176
  use_log (bool): Whether to use logarithmic scale for the plot (cannot show negative values).
1151
1177
  png_path (Path | str, optional): The path to save the image as a png file. Defaults to None.
1178
+ renderer (Literal["png", "svg", None], optional): The renderer to use for saving the image. Defaults to "png".
1152
1179
 
1153
- Returns:
1154
- plot object
1180
+ Returns: None
1155
1181
  """
1156
1182
 
1157
1183
  if (
@@ -1184,8 +1210,6 @@ def plot_boxes(
1184
1210
  color=df.iloc[:, 0],
1185
1211
  template="plotly_dark" if os.getenv("THEME") == "dark" else "plotly",
1186
1212
  orientation="v",
1187
- height=height,
1188
- width=width,
1189
1213
  points=points,
1190
1214
  log_y=use_log,
1191
1215
  # color_discrete_sequence=px.colors.qualitative.Plotly,
@@ -1264,9 +1288,11 @@ def plot_boxes(
1264
1288
  fig.update_yaxes(title_text=df.columns[1])
1265
1289
  fig.update_layout(boxmode="group") # Ensures boxes are not too compressed
1266
1290
  fig.update_layout(showlegend=False)
1267
- fig.update_traces(marker=dict(size=5), width=box_width) # Adjust width (default ~0.5)
1291
+ fig.update_traces(
1292
+ marker=dict(size=5), width=box_width
1293
+ ) # Adjust width (default ~0.5)
1268
1294
 
1269
- fig.show("png")
1295
+ fig.show(renderer=renderer, width=width, height=height)
1270
1296
  if summary:
1271
1297
  # * sort df by first column
1272
1298
  print_summary(df=df.sort_values(df.columns[0]), precision=precision)
@@ -1275,7 +1301,7 @@ def plot_boxes(
1275
1301
  if png_path is not None:
1276
1302
  fig.write_image(Path(png_path).as_posix())
1277
1303
 
1278
- return fig
1304
+ return
1279
1305
 
1280
1306
 
1281
1307
  def plot_facet_stacked_bars(
@@ -1299,20 +1325,51 @@ def plot_facet_stacked_bars(
1299
1325
  sort_values_facet: bool = False,
1300
1326
  relative: bool = False,
1301
1327
  show_pct: bool = False,
1302
- ) -> go.Figure:
1328
+ ) -> None:
1329
+
1330
+ """
1331
+ A function to plot multiple (subplots_per_row) stacked bar charts, facetted by the third column, with the first column as the index and the second column as the colors.
1303
1332
 
1304
- # --- ENFORCE show_pct RULES ---
1333
+ Parameters:
1334
+ - df (pd.DataFrame): Input DataFrame with 3 or 4 columns.
1335
+ - subplots_per_row (int): The number of subplots to display per row.
1336
+ - top_n_index (int): The number of top indexes to include in the chart. Default is 0, which includes all indexes.
1337
+ - top_n_color (int): The number of top colors to include in the chart. Default is 0, which includes all colors.
1338
+ - top_n_facet (int): The number of top facets to include in the chart. Default is 0, which includes all facets.
1339
+ - null_label (str): The label to use for null values. Default is "<NA>".
1340
+ - subplot_size (int): The size of each subplot in pixels. Default is 300.
1341
+ - color_palette (str): The name of the color palette to use. Default is "Plotly".
1342
+ - caption (str): An optional string indicating the caption for the chart.
1343
+ - renderer (str): The output format. Default is "png".
1344
+ - annotations (bool): Whether to include annotations on the chart. Default is False.
1345
+ - precision (int): The number of decimal places to round the values to. Default is 0.
1346
+ - png_path (str): The path to save the chart to, if provided.
1347
+ - show_other (bool): Whether to include "<other>" for columns not in top_n_color. Default is False.
1348
+ - sort_values (bool): Whether to sort the values in the chart. Default is True.
1349
+ - sort_values_index (bool): Whether to sort the index column. Default is False.
1350
+ - sort_values_color (bool): Whether to sort the color column. Default is False.
1351
+ - sort_values_facet (bool): Whether to sort the facet column. Default is False.
1352
+ - relative (bool): Whether to show the bars as relative values (0-1 range). Default is False.
1353
+ - show_pct (bool): Whether to show the annotations as percentages. Default is False.
1354
+
1355
+ Returns: None
1356
+ """
1357
+ # ENFORCE show_pct RULES ---
1305
1358
  if not relative:
1306
1359
  # If bars are absolute, annotations MUST be absolute
1307
1360
  if show_pct:
1308
- print("Warning: 'show_pct' cannot be True when 'relative' is False. Setting 'show_pct' to False.")
1361
+ print(
1362
+ "Warning: 'show_pct' cannot be True when 'relative' is False. Setting 'show_pct' to False."
1363
+ )
1309
1364
  show_pct = False
1310
- # ------------------------------
1365
+ #
1311
1366
 
1312
1367
  try:
1313
1368
  precision = int(precision)
1314
1369
  except (ValueError, TypeError):
1315
- print(f"Warning: 'precision' received as {precision} (type: {type(precision)}). Defaulting to 0.")
1370
+ print(
1371
+ f"Warning: 'precision' received as {precision} (type: {type(precision)}). Defaulting to 0."
1372
+ )
1316
1373
  precision = 0
1317
1374
 
1318
1375
  df_copy = df.copy()
@@ -1331,7 +1388,7 @@ def plot_facet_stacked_bars(
1331
1388
  n = df_copy["value"].sum()
1332
1389
  original_rows = len(df_copy)
1333
1390
 
1334
- aggregated_df = aggregate_data( # Assumes aggregate_data is accessible
1391
+ aggregated_df = aggregate_data( # Assumes aggregate_data is accessible
1335
1392
  df_copy,
1336
1393
  top_n_index,
1337
1394
  top_n_color,
@@ -1343,46 +1400,60 @@ def plot_facet_stacked_bars(
1343
1400
  sort_values_facet=sort_values_facet,
1344
1401
  )
1345
1402
 
1346
- aggregated_df['index'] = aggregated_df['index'].astype(str)
1347
- aggregated_df['col'] = aggregated_df['col'].astype(str)
1348
- aggregated_df['facet'] = aggregated_df['facet'].astype(str)
1403
+ aggregated_df["index"] = aggregated_df["index"].astype(str)
1404
+ aggregated_df["col"] = aggregated_df["col"].astype(str)
1405
+ aggregated_df["facet"] = aggregated_df["facet"].astype(str)
1349
1406
 
1350
1407
  # --- Store original 'value' for annotations before potential scaling ---
1351
- aggregated_df['annotation_value'] = aggregated_df['value'].copy()
1408
+ aggregated_df["annotation_value"] = aggregated_df["value"].copy()
1352
1409
  # ----------------------------------------------------------------------
1353
1410
 
1354
1411
  if relative:
1355
1412
  # This transforms the bar heights (value column) to percentages (0-1 range)
1356
- aggregated_df["value"] = aggregated_df.groupby(["facet", "index"])["value"].transform(lambda x: x / x.sum())
1413
+ aggregated_df["value"] = aggregated_df.groupby(["facet", "index"])[
1414
+ "value"
1415
+ ].transform(lambda x: x / x.sum())
1357
1416
 
1358
1417
  category_orders = {}
1359
1418
 
1360
1419
  if sort_values_index:
1361
- sum_by_index = aggregated_df.groupby('index')['value'].sum().sort_values(ascending=False)
1420
+ sum_by_index = (
1421
+ aggregated_df.groupby("index")["value"].sum().sort_values(ascending=False)
1422
+ )
1362
1423
  category_orders["index"] = sum_by_index.index.tolist()
1363
1424
 
1364
1425
  if sort_values_color:
1365
- sum_by_col = aggregated_df.groupby('col')['value'].sum().sort_values(ascending=False)
1426
+ sum_by_col = (
1427
+ aggregated_df.groupby("col")["value"].sum().sort_values(ascending=False)
1428
+ )
1366
1429
  category_orders["col"] = sum_by_col.index.tolist()
1367
1430
 
1368
1431
  if sort_values_facet:
1369
- sum_by_facet = aggregated_df.groupby('facet')['value'].sum().sort_values(ascending=False)
1432
+ sum_by_facet = (
1433
+ aggregated_df.groupby("facet")["value"].sum().sort_values(ascending=False)
1434
+ )
1370
1435
  category_orders["facet"] = sum_by_facet.index.tolist()
1371
1436
 
1372
1437
  columns_for_color = sorted(aggregated_df["col"].unique().tolist())
1373
- column_colors_map = assign_column_colors(columns_for_color, color_palette, null_label) # Assumes assign_column_colors is accessible
1438
+ column_colors_map = assign_column_colors(
1439
+ columns_for_color, color_palette, null_label
1440
+ ) # Assumes assign_column_colors is accessible
1374
1441
 
1375
- # --- Prepare the text series for annotations with 'show_pct' control ---
1442
+ # Prepare the text series for annotations with 'show_pct' control
1376
1443
  if annotations:
1377
1444
  if show_pct:
1378
1445
  # When show_pct is True, use the scaled 'value' column (0-1) and format as percentage
1379
- formatted_text_series = aggregated_df["value"].apply(lambda x: f"{x:.{precision}%}".replace('.', ','))
1446
+ formatted_text_series = aggregated_df["value"].apply(
1447
+ lambda x: f"{x:.{precision}%}".replace(".", ",")
1448
+ )
1380
1449
  else:
1381
1450
  # When show_pct is False, use the 'annotation_value' (original absolute) and format as absolute
1382
- formatted_text_series = aggregated_df["annotation_value"].apply(lambda x: f"{x:_.{precision}f}".replace('.', ','))
1451
+ formatted_text_series = aggregated_df["annotation_value"].apply(
1452
+ lambda x: f"{x:_.{precision}f}".replace(".", ",")
1453
+ )
1383
1454
  else:
1384
1455
  formatted_text_series = None
1385
- # -----------------------------------------------------------------------
1456
+ # - - - -
1386
1457
 
1387
1458
  fig = px.bar(
1388
1459
  aggregated_df,
@@ -1396,7 +1467,7 @@ def plot_facet_stacked_bars(
1396
1467
  category_orders=category_orders,
1397
1468
  text=formatted_text_series,
1398
1469
  text_auto=False,
1399
- height=subplot_size * (-(-len(aggregated_df["facet"].unique()) // subplots_per_row)),
1470
+ # height=subplot_size * (-(-len(aggregated_df["facet"].unique()) // subplots_per_row)),
1400
1471
  title=f"{caption} {original_column_names[0]}, {original_column_names[1]}, {original_column_names[2]}",
1401
1472
  )
1402
1473
 
@@ -1410,19 +1481,19 @@ def plot_facet_stacked_bars(
1410
1481
  template = "plotly_dark" if os.getenv("THEME") == "dark" else "plotly"
1411
1482
 
1412
1483
  layout_updates = {
1413
- "title_text": f"{caption} "
1414
- f"{'TOP ' + str(top_n_index) + ' ' if top_n_index > 0 else ''}[{original_column_names[0]}] "
1415
- f"{'TOP ' + str(top_n_color) + ' ' if top_n_color > 0 else ''}[{original_column_names[1]}] "
1416
- f"{'TOP ' + str(top_n_facet) + ' ' if top_n_facet > 0 else ''}[{original_column_names[2]}] "
1417
- f", n = {original_rows:_} ({n:_})",
1484
+ "title_text": f"{caption} "
1485
+ f"{'TOP ' + str(top_n_index) + ' ' if top_n_index > 0 else ''}[{original_column_names[0]}] "
1486
+ f"{'TOP ' + str(top_n_color) + ' ' if top_n_color > 0 else ''}[{original_column_names[1]}] "
1487
+ f"{'TOP ' + str(top_n_facet) + ' ' if top_n_facet > 0 else ''}[{original_column_names[2]}] "
1488
+ f", n = {original_rows:_} ({n:_})",
1418
1489
  "showlegend": True,
1419
1490
  "template": template,
1420
- "width": subplot_size * subplots_per_row,
1491
+ # "width": subplot_size * subplots_per_row,
1421
1492
  }
1422
1493
 
1423
1494
  if relative:
1424
- layout_updates['yaxis_range'] = [0, 1.1]
1425
- layout_updates['yaxis_tickformat'] = ".0%"
1495
+ layout_updates["yaxis_range"] = [0, 1.1]
1496
+ layout_updates["yaxis_tickformat"] = ".0%"
1426
1497
 
1427
1498
  fig.update_layout(**layout_updates)
1428
1499
 
@@ -1433,12 +1504,27 @@ def plot_facet_stacked_bars(
1433
1504
  png_path = Path(png_path)
1434
1505
  fig.write_image(str(png_path))
1435
1506
 
1436
- fig.show(renderer=renderer)
1507
+ fig.show(
1508
+ renderer=renderer,
1509
+ width=subplot_size * subplots_per_row,
1510
+ height=subplot_size
1511
+ * (-(-len(aggregated_df["facet"].unique()) // subplots_per_row)),
1512
+ )
1437
1513
 
1438
- return fig
1514
+ return
1439
1515
 
1440
1516
 
1441
- def plot_sankey(df=None, max_events_per_id=None, height=None, width=None, exclude_overlap_id=False, exclude_overlap_event=False, renderer=None, show_start_node=True):
1517
+ def plot_sankey(
1518
+ df=None,
1519
+ max_events_per_id=None,
1520
+ height=None,
1521
+ width=None,
1522
+ exclude_overlap_id=False,
1523
+ exclude_overlap_event=False,
1524
+ renderer=None,
1525
+ show_start_node=True,
1526
+ font_size=10,
1527
+ ):
1442
1528
  """
1443
1529
  Generates a Sankey diagram from a Pandas DataFrame, assuming the column order is:
1444
1530
  1. ID (string or integer)
@@ -1450,71 +1536,117 @@ def plot_sankey(df=None, max_events_per_id=None, height=None, width=None, exclud
1450
1536
 
1451
1537
  Args:
1452
1538
  df (pd.DataFrame, optional): A Pandas DataFrame containing the event data.
1453
- Expected column order: ID, Date, Event.
1539
+ Expected column order: ID, Date, Event.
1454
1540
  max_events_per_id (int, optional): The maximum number of events to display for each ID.
1455
- If None, all events for each ID will be used.
1541
+ If None, all events for each ID will be used.
1456
1542
  height (int, optional): The height of the plot in pixels.
1457
1543
  width (int, optional): The width of the plot in pixels.
1458
1544
  exclude_overlap_id (bool): If True, excludes any IDs that have multiple events on the same date.
1459
- This takes precedence over `exclude_overlap_event`.
1545
+ This takes precedence over `exclude_overlap_event`.
1460
1546
  exclude_overlap_event (bool): If True, only excludes the specific events that fall on the same date,
1461
- retaining other non-overlapping events for that ID.
1547
+ retaining other non-overlapping events for that ID.
1462
1548
  renderer (str, optional): The renderer to use for displaying the plot. Options include
1463
- 'browser', 'notebook', 'json', 'png', 'svg', 'jpeg', 'webp', or 'pdf'.
1464
- If None, plotly's default renderer is used.
1549
+ 'browser', 'notebook', 'json', 'png', 'svg', 'jpeg', 'webp', or 'pdf'.
1550
+ If None, plotly's default renderer is used.
1465
1551
  show_start_node (bool): If True, adds a visual 'start' node and links all
1466
1552
  first events to it. This is useful for visualizing
1467
1553
  IDs with only one event.
1554
+ font_size (int): The font size of the labels in the plot.
1468
1555
  """
1469
1556
  # --- Example Usage with Enlarged Pandas DataFrame if no DataFrame is provided ---
1470
1557
  if df is None:
1471
- data_demo = { # Renamed to data_demo for clarity
1472
- 'tumor-id': [
1473
- '1', '1', '1', '1', '1',
1474
- '2', '2', '2', '2',
1475
- '3', '3', '3', '3',
1476
- '4', '4', '4',
1477
- '5', '5',
1478
- '6', '6',
1479
- '7', '7',
1480
- '8',
1481
- '9',
1482
- '10',
1483
- '11',
1484
- '12'
1558
+ data_demo = { # Renamed to data_demo for clarity
1559
+ "tumor-id": [
1560
+ "1",
1561
+ "1",
1562
+ "1",
1563
+ "1",
1564
+ "1",
1565
+ "2",
1566
+ "2",
1567
+ "2",
1568
+ "2",
1569
+ "3",
1570
+ "3",
1571
+ "3",
1572
+ "3",
1573
+ "4",
1574
+ "4",
1575
+ "4",
1576
+ "5",
1577
+ "5",
1578
+ "6",
1579
+ "6",
1580
+ "7",
1581
+ "7",
1582
+ "8",
1583
+ "9",
1584
+ "10",
1585
+ "11",
1586
+ "12",
1587
+ ],
1588
+ "diagnosis date": [
1589
+ "2020-01-01",
1590
+ "2021-02-01",
1591
+ "2022-03-01",
1592
+ "2023-04-01",
1593
+ "2024-05-01", # Tumor 1
1594
+ "2010-01-01",
1595
+ "2011-02-01",
1596
+ "2012-03-01",
1597
+ "2013-04-01", # Tumor 2
1598
+ "2015-01-01",
1599
+ "2016-02-01",
1600
+ "2017-03-01",
1601
+ "2018-04-01", # Tumor 3
1602
+ "2005-01-01",
1603
+ "2006-02-01",
1604
+ "2007-03-01", # Tumor 4
1605
+ "2019-01-01",
1606
+ "2020-02-01", # Tumor 5
1607
+ "2021-01-01",
1608
+ "2022-02-01", # Tumor 6
1609
+ "2014-01-01",
1610
+ "2015-02-01", # Tumor 7
1611
+ "2025-01-01", # Tumor 8 (single event)
1612
+ "2025-02-01", # Tumor 9 (single event)
1613
+ "2025-03-01", # Tumor 10 (single event)
1614
+ "2025-04-01", # Tumor 11 (single event)
1615
+ "2025-05-01", # Tumor 12 (single event)
1485
1616
  ],
1486
- 'diagnosis date': [
1487
- '2020-01-01', '2021-02-01', '2022-03-01', '2023-04-01', '2024-05-01', # Tumor 1
1488
- '2010-01-01', '2011-02-01', '2012-03-01', '2013-04-01', # Tumor 2
1489
- '2015-01-01', '2016-02-01', '2017-03-01', '2018-04-01', # Tumor 3
1490
- '2005-01-01', '2006-02-01', '2007-03-01', # Tumor 4
1491
- '2019-01-01', '2020-02-01', # Tumor 5
1492
- '2021-01-01', '2022-02-01', # Tumor 6
1493
- '2014-01-01', '2015-02-01', # Tumor 7
1494
- '2025-01-01', # Tumor 8 (single event)
1495
- '2025-02-01', # Tumor 9 (single event)
1496
- '2025-03-01', # Tumor 10 (single event)
1497
- '2025-04-01', # Tumor 11 (single event)
1498
- '2025-05-01' # Tumor 12 (single event)
1617
+ "treatment": [
1618
+ "op",
1619
+ "syst",
1620
+ "op",
1621
+ "rad",
1622
+ "op", # Tumor 1
1623
+ "syst",
1624
+ "st",
1625
+ "op",
1626
+ "rad", # Tumor 2
1627
+ "op",
1628
+ "rad",
1629
+ "syst",
1630
+ "op", # Tumor 3
1631
+ "st",
1632
+ "syst",
1633
+ "op", # Tumor 4
1634
+ "op",
1635
+ "rad", # Tumor 5
1636
+ "syst",
1637
+ "op", # Tumor 6
1638
+ "st",
1639
+ "rad", # Tumor 7
1640
+ "op", # Tumor 8
1641
+ "op", # Tumor 9
1642
+ "syst", # Tumor 10
1643
+ "rad", # Tumor 11
1644
+ "op", # Tumor 12
1499
1645
  ],
1500
- 'treatment': [
1501
- 'op', 'syst', 'op', 'rad', 'op', # Tumor 1
1502
- 'syst', 'st', 'op', 'rad', # Tumor 2
1503
- 'op', 'rad', 'syst', 'op', # Tumor 3
1504
- 'st', 'syst', 'op', # Tumor 4
1505
- 'op', 'rad', # Tumor 5
1506
- 'syst', 'op', # Tumor 6
1507
- 'st', 'rad', # Tumor 7
1508
- 'op', # Tumor 8
1509
- 'op', # Tumor 9
1510
- 'syst', # Tumor 10
1511
- 'rad', # Tumor 11
1512
- 'op' # Tumor 12
1513
- ]
1514
1646
  }
1515
1647
  df = pd.DataFrame(data_demo)
1516
1648
  print("--- Using demo data (data_demo) ---")
1517
- print(df.head().to_string()) # Print first 5 rows of the DataFrame prettily
1649
+ print(df.head().to_string()) # Print first 5 rows of the DataFrame prettily
1518
1650
  print("-----------------------------------")
1519
1651
 
1520
1652
  # --- Simplified Column Recognition based on index ---
@@ -1525,139 +1657,193 @@ def plot_sankey(df=None, max_events_per_id=None, height=None, width=None, exclud
1525
1657
  df_processed = df.copy()
1526
1658
 
1527
1659
  # --- Aggregate the data to remove duplicate rows before processing ---
1528
- df_processed = df_processed.drop_duplicates(subset=[id_col_name, date_col_name, event_col_name])
1660
+ df_processed = df_processed.drop_duplicates(
1661
+ subset=[id_col_name, date_col_name, event_col_name]
1662
+ )
1529
1663
 
1530
1664
  try:
1531
1665
  df_processed[date_col_name] = pd.to_datetime(df_processed[date_col_name])
1532
1666
  except (ValueError, TypeError):
1533
- print(f"Error: Could not convert column '{date_col_name}' to a valid date format.")
1667
+ print(
1668
+ f"Error: Could not convert column '{date_col_name}' to a valid date format."
1669
+ )
1534
1670
  return None
1535
1671
 
1536
1672
  # --- Handle overlap exclusion based on user selection ---
1537
1673
  overlap_title_part = ""
1538
1674
  if exclude_overlap_id:
1539
- overlapping_ids = df_processed.groupby([id_col_name, date_col_name]).size().loc[lambda x: x > 1].index.get_level_values(id_col_name).unique()
1540
- df_processed = df_processed[~df_processed[id_col_name].isin(overlapping_ids)].copy()
1675
+ overlapping_ids = (
1676
+ df_processed.groupby([id_col_name, date_col_name])
1677
+ .size()
1678
+ .loc[lambda x: x > 1]
1679
+ .index.get_level_values(id_col_name)
1680
+ .unique()
1681
+ )
1682
+ df_processed = df_processed[
1683
+ ~df_processed[id_col_name].isin(overlapping_ids)
1684
+ ].copy()
1541
1685
  overlap_title_part = ", overlap ids excluded"
1542
1686
  elif exclude_overlap_event:
1543
- overlapping_event_set = set(df_processed.groupby([id_col_name, date_col_name]).size().loc[lambda x: x > 1].index)
1544
- df_processed = df_processed[~df_processed.set_index([id_col_name, date_col_name]).index.isin(overlapping_event_set)].copy()
1687
+ overlapping_event_set = set(
1688
+ df_processed.groupby([id_col_name, date_col_name])
1689
+ .size()
1690
+ .loc[lambda x: x > 1]
1691
+ .index
1692
+ )
1693
+ df_processed = df_processed[
1694
+ ~df_processed.set_index([id_col_name, date_col_name]).index.isin(
1695
+ overlapping_event_set
1696
+ )
1697
+ ].copy()
1545
1698
  overlap_title_part = ", overlap events excluded"
1546
1699
 
1547
1700
  df_sorted = df_processed.sort_values(by=[id_col_name, date_col_name])
1548
-
1701
+
1549
1702
  # --- Performance Optimization: Use vectorized operations instead of loops ---
1550
- df_sorted['event_order'] = df_sorted.groupby(id_col_name).cumcount() + 1
1551
-
1703
+ df_sorted["event_order"] = df_sorted.groupby(id_col_name).cumcount() + 1
1704
+
1552
1705
  if max_events_per_id is not None:
1553
- df_sorted = df_sorted[df_sorted['event_order'] <= max_events_per_id]
1554
-
1555
- df_sorted['ordered_event_label'] = '[' + df_sorted['event_order'].astype(str) + '] ' + df_sorted[event_col_name]
1556
-
1706
+ df_sorted = df_sorted[df_sorted["event_order"] <= max_events_per_id]
1707
+
1708
+ df_sorted["ordered_event_label"] = (
1709
+ "[" + df_sorted["event_order"].astype(str) + "] " + df_sorted[event_col_name]
1710
+ )
1711
+
1557
1712
  if df_sorted.empty:
1558
1713
  print("No valid data to plot after filtering.")
1559
1714
  return None
1560
1715
 
1561
1716
  # Use a vectorized shift operation to create source and target columns
1562
- df_sorted['source_label'] = df_sorted.groupby(id_col_name)['ordered_event_label'].shift(1)
1563
- df_with_links = df_sorted.dropna(subset=['source_label']).copy()
1717
+ df_sorted["source_label"] = df_sorted.groupby(id_col_name)[
1718
+ "ordered_event_label"
1719
+ ].shift(1)
1720
+ df_with_links = df_sorted.dropna(subset=["source_label"]).copy()
1564
1721
 
1565
1722
  # Create the start node and links if enabled
1566
1723
  if show_start_node:
1567
1724
  first_events = df_sorted.groupby(id_col_name).first().reset_index()
1568
- first_events['source_label'] = "[0] start"
1569
- df_with_links = pd.concat([first_events[['source_label', 'ordered_event_label']], df_with_links[['source_label', 'ordered_event_label']]], ignore_index=True)
1570
-
1571
- link_counts = df_with_links.groupby(['source_label', 'ordered_event_label']).size().reset_index(name='value')
1725
+ first_events["source_label"] = "[0] start"
1726
+ df_with_links = pd.concat(
1727
+ [
1728
+ first_events[["source_label", "ordered_event_label"]],
1729
+ df_with_links[["source_label", "ordered_event_label"]],
1730
+ ],
1731
+ ignore_index=True,
1732
+ )
1733
+
1734
+ link_counts = (
1735
+ df_with_links.groupby(["source_label", "ordered_event_label"])
1736
+ .size()
1737
+ .reset_index(name="value")
1738
+ )
1572
1739
 
1573
1740
  # Get all unique nodes for the labels and sorting
1574
- all_labels = pd.concat([link_counts['source_label'], link_counts['ordered_event_label']]).unique()
1575
- unique_labels_df = pd.DataFrame(all_labels, columns=['label'])
1576
- unique_labels_df['event_order_num'] = unique_labels_df['label'].str.extract(r'\[(\d+)\]').astype(float).fillna(0)
1577
- unique_labels_df['event_name'] = unique_labels_df['label'].str.extract(r'\] (.*)').fillna('start')
1578
- unique_labels_df_sorted = unique_labels_df.sort_values(by=['event_order_num', 'event_name'])
1579
- unique_unformatted_labels_sorted = unique_labels_df_sorted['label'].tolist()
1741
+ all_labels = pd.concat(
1742
+ [link_counts["source_label"], link_counts["ordered_event_label"]]
1743
+ ).unique()
1744
+ unique_labels_df = pd.DataFrame(all_labels, columns=["label"])
1745
+ unique_labels_df["event_order_num"] = (
1746
+ unique_labels_df["label"].str.extract(r"\[(\d+)\]").astype(float).fillna(0)
1747
+ )
1748
+ unique_labels_df["event_name"] = (
1749
+ unique_labels_df["label"].str.extract(r"\] (.*)").fillna("start")
1750
+ )
1751
+ unique_labels_df_sorted = unique_labels_df.sort_values(
1752
+ by=["event_order_num", "event_name"]
1753
+ )
1754
+ unique_unformatted_labels_sorted = unique_labels_df_sorted["label"].tolist()
1580
1755
 
1581
- label_to_index = {label: i for i, label in enumerate(unique_unformatted_labels_sorted)}
1756
+ label_to_index = {
1757
+ label: i for i, label in enumerate(unique_unformatted_labels_sorted)
1758
+ }
1582
1759
 
1583
1760
  # Calculate total unique IDs for percentage calculation
1584
1761
  total_unique_ids = df_processed[id_col_name].nunique()
1585
1762
 
1586
1763
  display_labels = []
1587
- node_counts = df_sorted['ordered_event_label'].value_counts()
1764
+ node_counts = df_sorted["ordered_event_label"].value_counts()
1588
1765
  for label in unique_unformatted_labels_sorted:
1589
1766
  if label == "[0] start":
1590
1767
  count = total_unique_ids
1591
1768
  else:
1592
1769
  count = node_counts.get(label, 0)
1593
-
1770
+
1594
1771
  percentage = (count / total_unique_ids) * 100
1595
- formatted_count = f"{count:,}".replace(',', '_')
1772
+ formatted_count = f"{count:,}".replace(",", "_")
1596
1773
  formatted_percentage = f"({int(round(percentage, 0))}%)"
1597
1774
 
1598
1775
  display_labels.append(f"{label} {formatted_count} {formatted_percentage}")
1599
1776
 
1600
1777
  # Map sources and targets to indices
1601
- sources = link_counts['source_label'].map(label_to_index).tolist()
1602
- targets = link_counts['ordered_event_label'].map(label_to_index).tolist()
1603
- values = link_counts['value'].tolist()
1778
+ sources = link_counts["source_label"].map(label_to_index).tolist()
1779
+ targets = link_counts["ordered_event_label"].map(label_to_index).tolist()
1780
+ values = link_counts["value"].tolist()
1604
1781
 
1605
1782
  # Define a color palette for links
1606
1783
  color_palette = [
1607
- "rgba(255, 99, 71, 0.6)", "rgba(60, 179, 113, 0.6)", "rgba(65, 105, 225, 0.6)",
1608
- "rgba(255, 215, 0, 0.6)", "rgba(147, 112, 219, 0.6)", "rgba(0, 206, 209, 0.6)",
1609
- "rgba(255, 160, 122, 0.6)", "rgba(124, 252, 0, 0.6)", "rgba(30, 144, 255, 0.6)",
1610
- "rgba(218, 165, 32, 0.6)"
1784
+ "rgba(255, 99, 71, 0.6)",
1785
+ "rgba(60, 179, 113, 0.6)",
1786
+ "rgba(65, 105, 225, 0.6)",
1787
+ "rgba(255, 215, 0, 0.6)",
1788
+ "rgba(147, 112, 219, 0.6)",
1789
+ "rgba(0, 206, 209, 0.6)",
1790
+ "rgba(255, 160, 122, 0.6)",
1791
+ "rgba(124, 252, 0, 0.6)",
1792
+ "rgba(30, 144, 255, 0.6)",
1793
+ "rgba(218, 165, 32, 0.6)",
1611
1794
  ]
1612
1795
  start_link_color = "rgba(128, 128, 128, 0.6)"
1613
-
1796
+
1614
1797
  link_colors = []
1615
1798
  link_type_to_color = {}
1616
1799
  color_index = 0
1617
1800
  for i, row in link_counts.iterrows():
1618
- source_l = row['source_label']
1619
- target_l = row['ordered_event_label']
1801
+ source_l = row["source_label"]
1802
+ target_l = row["ordered_event_label"]
1620
1803
  if source_l == "[0] start":
1621
1804
  link_colors.append(start_link_color)
1622
1805
  else:
1623
- source_event_name = re.search(r'\] (.*)', source_l).group(1)
1624
- target_event_name = re.search(r'\] (.*)', target_l).group(1)
1806
+ source_event_name = re.search(r"\] (.*)", source_l).group(1)
1807
+ target_event_name = re.search(r"\] (.*)", target_l).group(1)
1625
1808
  link_type = (source_event_name, target_event_name)
1626
1809
 
1627
1810
  if link_type not in link_type_to_color:
1628
- link_type_to_color[link_type] = color_palette[color_index % len(color_palette)]
1811
+ link_type_to_color[link_type] = color_palette[
1812
+ color_index % len(color_palette)
1813
+ ]
1629
1814
  color_index += 1
1630
1815
  link_colors.append(link_type_to_color[link_type])
1631
1816
 
1632
- formatted_total_ids = f"{total_unique_ids:,}".replace(',', '_')
1817
+ formatted_total_ids = f"{total_unique_ids:,}".replace(",", "_")
1633
1818
  total_rows = len(df_processed)
1634
- formatted_total_rows = f"{total_rows:,}".replace(',', '_')
1635
-
1819
+ formatted_total_rows = f"{total_rows:,}".replace(",", "_")
1820
+
1636
1821
  chart_title = f"[{id_col_name}] over [{event_col_name}]"
1637
1822
  if max_events_per_id is not None:
1638
1823
  chart_title += f", top {max_events_per_id} events"
1639
1824
  chart_title += overlap_title_part
1640
1825
  chart_title += f", n = {formatted_total_ids} ({formatted_total_rows})"
1641
1826
 
1642
- fig = go.Figure(data=[go.Sankey(
1643
- node=dict(
1644
- pad=15,
1645
- thickness=20,
1646
- line=dict(color="black", width=0.5),
1647
- label=display_labels,
1648
- color="blue",
1649
- align="left"
1650
- ),
1651
- link=dict(
1652
- source=sources,
1653
- target=targets,
1654
- value=values,
1655
- color=link_colors
1656
- )
1657
- )])
1827
+ fig = go.Figure(
1828
+ data=[
1829
+ go.Sankey(
1830
+ node=dict(
1831
+ pad=15,
1832
+ thickness=20,
1833
+ line=dict(color="black", width=0.5),
1834
+ label=display_labels,
1835
+ color="blue",
1836
+ align="left",
1837
+ ),
1838
+ link=dict(
1839
+ source=sources, target=targets, value=values, color=link_colors
1840
+ ),
1841
+ )
1842
+ ]
1843
+ )
1658
1844
 
1659
- fig.update_layout(title_text=chart_title, font_size=10, height=height, width=width)
1660
- fig.show(renderer=renderer)
1845
+ fig.update_layout(title_text=chart_title, font_size=font_size)
1846
+ fig.show(renderer=renderer, width=width, height=height)
1661
1847
 
1662
1848
 
1663
1849
  # * extend objects to enable chaining
@@ -1669,4 +1855,4 @@ pd.DataFrame.plot_stacked_boxes = plot_boxes
1669
1855
  pd.DataFrame.plot_quadrants = plot_quadrants
1670
1856
  pd.DataFrame.plot_histogram = plot_histogram
1671
1857
  pd.DataFrame.plot_joint = plot_joint
1672
- pd.DataFrame.plot_sankey = plot_sankey
1858
+ pd.DataFrame.plot_sankey = plot_sankey
pandas_plots/tbl.py CHANGED
@@ -70,8 +70,9 @@ def describe_df(
70
70
  fig_cols: int = 3,
71
71
  fig_offset: int = None,
72
72
  fig_rowheight: int = 300,
73
+ fig_width: int = 400,
73
74
  sort_mode: Literal["value", "index"] = "value",
74
- top_n_uniques: int = 30,
75
+ top_n_uniques: int = 5,
75
76
  top_n_chars_in_index: int = 0,
76
77
  top_n_chars_in_columns: int = 0,
77
78
  ):
@@ -88,6 +89,7 @@ def describe_df(
88
89
  fig_cols (int): number of columns in plot
89
90
  fig_offset (int): offset for plots as iloc Argument. None = no offset, -1 = omit last plot
90
91
  fig_rowheight (int): row height for plot (default 300)
92
+ fig_width (int): width for plot (default 400)
91
93
  sort_mode (Literal["value", "index"]): sort by value or index
92
94
  top_n_uniques (int): number of uniques to display
93
95
  top_n_chars_in_index (int): number of characters to display on plot axis
@@ -203,8 +205,8 @@ def describe_df(
203
205
  subplot_titles=cols,
204
206
  )
205
207
  # * layout settings
206
- fig.layout.height = fig_rowheight * fig_rows
207
- fig.layout.width = 400 * fig_cols
208
+ # fig.layout.height = fig_rowheight * fig_rows
209
+ # fig.layout.width = 400 * fig_cols
208
210
 
209
211
  # * construct subplots
210
212
  for i, col in enumerate(cols):
@@ -246,7 +248,7 @@ def describe_df(
246
248
  fig.update_layout(
247
249
  template="plotly_dark" if os.getenv("THEME") == "dark" else "plotly"
248
250
  )
249
- fig.show(renderer)
251
+ fig.show(renderer, width=fig_width * fig_cols, height=fig_rowheight * fig_rows)
250
252
 
251
253
  if use_missing:
252
254
  import missingno as msno
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pandas-plots
3
- Version: 0.14.1
3
+ Version: 0.15.1
4
4
  Summary: A collection of helper for table handling and visualization
5
5
  Project-URL: Homepage, https://github.com/smeisegeier/pandas-plots
6
6
  Project-URL: Repository, https://github.com/smeisegeier/pandas-plots
7
7
  Project-URL: Bug Tracker, https://github.com/smeisegeier/pandas-plots/issues
8
- Author-email: smeisegeier <meisegeiers@rki.de>
8
+ Author-email: smeisegeier <dexterDSD@googlemail.com>
9
9
  License-File: LICENSE
10
10
  Keywords: pivot,plot,plotly,tables,venn,vizualization
11
11
  Classifier: Development Status :: 4 - Beta
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3
16
16
  Classifier: Programming Language :: Python :: 3.10
17
17
  Classifier: Topic :: Scientific/Engineering
18
18
  Requires-Python: >=3.10
19
+ Requires-Dist: connection-helper>=0.11.2
19
20
  Requires-Dist: dataframe-image>=0.2.6
20
21
  Requires-Dist: duckdb>=1.3.0
21
22
  Requires-Dist: jinja2>=3.1.4
@@ -0,0 +1,9 @@
1
+ pandas_plots/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ pandas_plots/hlp.py,sha256=z8rrVNbH9qMohdXPT-FksP-VkTOjI0bGFj47Sw5p3aY,21141
3
+ pandas_plots/pls.py,sha256=k3btK4TWHUJCyHEzu3yLh40G9SuFlW84dYP2RLS5lWY,64118
4
+ pandas_plots/tbl.py,sha256=mzrUif2TUZ8JJmkgzNpVYApBZS8L0MS1Yjpx9KZN7Vs,32920
5
+ pandas_plots/ven.py,sha256=2x3ACo2vSfO3q6fv-UdDQ0h1SJyt8WChBGgE5SDCdCk,11673
6
+ pandas_plots-0.15.1.dist-info/METADATA,sha256=xQ1FomsfZp38k4o_7J-Bp8dIkW3PvHM_wq4qK8QnWFU,7467
7
+ pandas_plots-0.15.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ pandas_plots-0.15.1.dist-info/licenses/LICENSE,sha256=ltLbQWUCs-GBQlTPXbt5nHNBE9U5LzjjoS1Y8hHETM4,1051
9
+ pandas_plots-0.15.1.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- pandas_plots/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- pandas_plots/hlp.py,sha256=z8rrVNbH9qMohdXPT-FksP-VkTOjI0bGFj47Sw5p3aY,21141
3
- pandas_plots/pls.py,sha256=80uXr3bT66LGjDcuT4a0ewCBwATcOUZ3QQ228Hn9glY,60052
4
- pandas_plots/tbl.py,sha256=R2E6FLhxNpUtS88Zf88Eh9i8dSKgmJtmFimFvOt0foQ,32780
5
- pandas_plots/ven.py,sha256=2x3ACo2vSfO3q6fv-UdDQ0h1SJyt8WChBGgE5SDCdCk,11673
6
- pandas_plots-0.14.1.dist-info/METADATA,sha256=7wU-RjYxYQGfw8rshzpbuQ0ci7xJfe-xldiAAshAMjw,7420
7
- pandas_plots-0.14.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- pandas_plots-0.14.1.dist-info/licenses/LICENSE,sha256=ltLbQWUCs-GBQlTPXbt5nHNBE9U5LzjjoS1Y8hHETM4,1051
9
- pandas_plots-0.14.1.dist-info/RECORD,,