plot-misc 2.2.0__py3-none-any.whl → 2.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plot_misc/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '2.2.0'
1
+ __version__ = '2.2.1'
plot_misc/barchart.py CHANGED
@@ -30,11 +30,12 @@ from plot_misc.errors import (
30
30
  is_df,
31
31
  Error_MSG,
32
32
  )
33
- from typing import Any, Optional
33
+ from typing import Any
34
34
  from plot_misc.constants import Real
35
35
 
36
36
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
37
37
  def bar(data:pd.DataFrame, label:str, column:str,
38
+ positions:np.ndarray | list[Real] | None = None,
38
39
  error_max:str | None = None, error_min:str | None = None,
39
40
  colours:list[str]=['tab:blue', 'tab:pink'], transparency:float=0.7,
40
41
  wd:Real=1.0, edgecolour:str='black',
@@ -54,9 +55,13 @@ def bar(data:pd.DataFrame, label:str, column:str,
54
55
  The column name for the axes labels.
55
56
  column : `str`
56
57
  The column name for the bar height values.
57
- error_max : `str`, default `NoneType`
58
+ positions : `np.ndarray` or `list` or `None`, default `None`
59
+ Numeric positions for the bars along the category axis. If None,
60
+ bars are placed at integer positions 0, 1, 2, ... with tick
61
+ labels taken from the `label` column.
62
+ error_max : `str`, default `None`
58
63
  column name for the upper value of the error line segment.
59
- error_min : `str`, default `NoneType`
64
+ error_min : `str`, default `None`
60
65
  column name for the lower value of the error line segment.
61
66
  colours : `list` [`str`], default ['tab:blue', 'tab:pink']
62
67
  Colours for the bars; recycled if shorter than the number of bars.
@@ -68,7 +73,7 @@ def bar(data:pd.DataFrame, label:str, column:str,
68
73
  The bar edgecolour.
69
74
  horizontal : `bool`, default `False`
70
75
  Whether plot a horizontal bar chart.
71
- ax : `plt.ax`, default `NoneType`
76
+ ax : `plt.ax`, default `None`
72
77
  The pyplot.axes object.
73
78
  figsize : `tuple` [`float`, `float`], default (2, 2),
74
79
  The figure size in inches, when ax is set to None.
@@ -96,6 +101,7 @@ def bar(data:pd.DataFrame, label:str, column:str,
96
101
  is_df(data)
97
102
  is_type(label, str)
98
103
  is_type(column, str)
104
+ is_type(positions, (type(None), np.ndarray, list))
99
105
  is_type(colours, list)
100
106
  is_type(transparency, float)
101
107
  is_type(wd, (float, int))
@@ -117,10 +123,19 @@ def bar(data:pd.DataFrame, label:str, column:str,
117
123
  # ### check input
118
124
  if any(data.isna().any()):
119
125
  raise ValueError(Error_MSG.MISSING_DF.format('data'))
120
- # ### get labels
126
+ # ### get labels and positions
121
127
  labels = data[label]
128
+ if positions is None:
129
+ pos = np.arange(len(labels))
130
+ else:
131
+ pos = np.asarray(positions)
132
+ if len(pos) != len(labels):
133
+ raise ValueError(
134
+ f'Length of positions ({len(pos)}) does not match '
135
+ f'number of rows in data ({len(labels)}).'
136
+ )
122
137
  # ### plotting
123
- if horizontal == False:
138
+ if not horizontal:
124
139
  # plotting vertical bar chart
125
140
  new_kwargs = _update_kwargs(update_dict=kwargs_bar,
126
141
  edgecolor=edgecolour,
@@ -128,8 +143,10 @@ def bar(data:pd.DataFrame, label:str, column:str,
128
143
  alpha=transparency,
129
144
  zorder=2,
130
145
  )
131
- bars = ax.bar(labels, height=data[column], **new_kwargs,
146
+ bars = ax.bar(pos, height=data[column], **new_kwargs,
132
147
  )
148
+ ax.set_xticks(pos)
149
+ ax.set_xticklabels(labels)
133
150
  else:
134
151
  # plotting horizontal bar chart
135
152
  new_kwargs = _update_kwargs(update_dict=kwargs_bar,
@@ -138,13 +155,15 @@ def bar(data:pd.DataFrame, label:str, column:str,
138
155
  alpha=transparency,
139
156
  zorder=2,
140
157
  )
141
- bars = ax.barh(labels, width=data[column], **new_kwargs,
158
+ bars = ax.barh(pos, width=data[column], **new_kwargs,
142
159
  )
160
+ ax.set_yticks(pos)
161
+ ax.set_yticklabels(labels)
143
162
  # do we need to plot error bars
144
163
  if error_min is not None or error_max is not None:
145
164
  # finding the mid points of the bars and
146
165
  # initialising the bounds, allowing for one-sided limits.
147
- if horizontal == False:
166
+ if not horizontal:
148
167
  min_l = [b.get_y() + b.get_height() for b in bars]
149
168
  max_l = min_l.copy()
150
169
  else:
@@ -164,7 +183,7 @@ def bar(data:pd.DataFrame, label:str, column:str,
164
183
  color='black',
165
184
  zorder=1,
166
185
  )
167
- if horizontal == False:
186
+ if not horizontal:
168
187
  mids = [b.get_x() + b.get_width() / 2 for b in bars]
169
188
  ax.vlines(mids, min_l, max_l, **new_kwargs_error,)
170
189
  else:
@@ -178,10 +197,11 @@ def bar(data:pd.DataFrame, label:str, column:str,
178
197
 
179
198
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
180
199
  def stack_bar(data:pd.DataFrame, label:str, columns:list[str],
200
+ positions:np.ndarray | list[Real] | None = None,
181
201
  colours:list[str]=['tab:blue', 'tab:pink'],
182
202
  transparency:float=0.7, wd:Real=1.0, edgecolour:str='black',
183
203
  horizontal:bool = False, figsize:tuple[Real,Real] = (2,2),
184
- ax:plt.Axes | None = None, **kwargs:Optional[Any],
204
+ ax:plt.Axes | None = None, **kwargs:Any,
185
205
  ) -> tuple[plt.Figure, plt.Axes]:
186
206
  """
187
207
  Plot a stacked bar chart with each bar divided into segments.
@@ -194,6 +214,9 @@ def stack_bar(data:pd.DataFrame, label:str, columns:list[str],
194
214
  Column name used for bar labels.
195
215
  columns : `list` [`str`]
196
216
  Column names representing bar segments to stack.
217
+ positions : `np.ndarray` or `list` or `None`, default `None`
218
+ Numeric positions for the bars along the category axis. If None,
219
+ bars are placed at integer positions 0, 1, 2, ...
197
220
  colours : `list` [`str`]
198
221
  List of colours for each stack segment.
199
222
  transparency : `float`, default 0.7
@@ -204,7 +227,7 @@ def stack_bar(data:pd.DataFrame, label:str, columns:list[str],
204
227
  Colour for bar borders.
205
228
  horizontal : `bool`, default `False`
206
229
  Whether plot a horizontal barchart.
207
- ax : `plt.ax`, default `NoneType`
230
+ ax : `plt.ax`, default `None`
208
231
  The pyplot.axes object.
209
232
  figsize : `tuple` [`float`, `float`], default (2, 2),
210
233
  The figure size in inches, when ax is set to None.
@@ -222,6 +245,7 @@ def stack_bar(data:pd.DataFrame, label:str, columns:list[str],
222
245
  is_df(data)
223
246
  is_type(label, str)
224
247
  is_type(columns, list)
248
+ is_type(positions, (type(None), np.ndarray, list))
225
249
  is_type(colours, list)
226
250
  is_type(transparency, float)
227
251
  is_type(wd, (float, int))
@@ -254,17 +278,17 @@ def stack_bar(data:pd.DataFrame, label:str, columns:list[str],
254
278
  color=colours[idx],
255
279
  alpha=transparency,
256
280
  )
257
- if horizontal == False:
281
+ if not horizontal:
258
282
  new_kwargs = _update_kwargs(new_kwargs, bottom=left,
259
283
  )
260
284
  else:
261
285
  new_kwargs = _update_kwargs(new_kwargs, left=left,
262
286
  )
263
287
  # The actual plotting
264
- # NOTE adding wd here because it bar assigns it to either width or
288
+ # NOTE adding wd here because bar assigns it to either width or
265
289
  # height depending on horizontal.
266
- _, ax = bar(data=data, label=label, column=name, horizontal=horizontal,
267
- wd=wd, ax=ax, kwargs_bar=new_kwargs,
290
+ _, ax = bar(data=data, label=label, column=name, positions=positions,
291
+ horizontal=horizontal, wd=wd, ax=ax, kwargs_bar=new_kwargs,
268
292
  )
269
293
  # updating the coordinate where the last bar stops
270
294
  left = left + data[name]
@@ -277,6 +301,7 @@ def stack_bar(data:pd.DataFrame, label:str, columns:list[str],
277
301
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
278
302
  def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
279
303
  subtotal_col: str | None = None,
304
+ positions:np.ndarray | list[Real] | None = None,
280
305
  colours:tuple[str,str]=('grey','tab:blue'),
281
306
  transparency:tuple[float,float]=(0.7,0.9),
282
307
  wd:tuple[float,float]=(1,0.6),
@@ -299,8 +324,11 @@ def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
299
324
  Column name for axis labels.
300
325
  total_col : `str`
301
326
  Column containing values for the base (total) bars.
302
- subtotal_col : `str` or `None`, default `NoneType`
327
+ subtotal_col : `str` or `None`, default `None`
303
328
  Column containing values for (smaller) overlaid subtotal bars.
329
+ positions : `np.ndarray` or `list` or `None`, default `None`
330
+ Numeric positions for the bars along the category axis. If None,
331
+ bars are placed at integer positions 0, 1, 2, ...
304
332
  colours : `tuple` [`str`,`str`], default ("grey", "tab:blue")
305
333
  Colours for the total and subtotal bars.
306
334
  transparency : `tuple` [`float`,`float`], default (0.7, 0.9)
@@ -339,6 +367,7 @@ def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
339
367
  is_type(ax, (type(None), plt.Axes))
340
368
  is_type(total_col, str)
341
369
  is_type(subtotal_col, (str, type(None)))
370
+ is_type(positions, (type(None), np.ndarray, list))
342
371
  is_type(zorder, tuple)
343
372
  is_type(colours, tuple)
344
373
  is_type(transparency, tuple)
@@ -370,6 +399,7 @@ def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
370
399
  ax=ax,
371
400
  label=label,
372
401
  column=total_col,
402
+ positions=positions,
373
403
  colours=[colours[0]],
374
404
  transparency=transparency[0],
375
405
  wd=wd[0],
@@ -378,7 +408,7 @@ def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
378
408
  kwargs_bar=new_total_kwargs_bar,
379
409
  )
380
410
  # plot subtotal
381
- if not subtotal_col is None:
411
+ if subtotal_col is not None:
382
412
  subtotal = data[subtotal_col]
383
413
  # updating kwargs
384
414
  new_subtotal_kwargs_bar = _update_kwargs(
@@ -390,6 +420,7 @@ def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
390
420
  ax=ax,
391
421
  label=label,
392
422
  column=subtotal_col,
423
+ positions=positions,
393
424
  colours=[colours[1]],
394
425
  transparency=transparency[1],
395
426
  wd=wd[1],
@@ -405,6 +436,7 @@ def subtotal_bar(data:pd.DataFrame, label:str, total_col:str,
405
436
 
406
437
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
438
  def group_bar(data:pd.DataFrame, label:str, columns:list[str],
439
+ group_positions:np.ndarray | list[Real] | None = None,
408
440
  errors_max:list[str] | None = None,
409
441
  errors_min:list[str] | None = None,
410
442
  colours:list[str]=['tab:blue', 'tab:pink'],
@@ -433,10 +465,14 @@ def group_bar(data:pd.DataFrame, label:str, columns:list[str],
433
465
  Column name for group labels.
434
466
  column : `list` [`str`]
435
467
  Value columns to plot as grouped bars.
436
- errors_max : `list` [`str`] or `None`, default `NoneType`
468
+ group_positions : `np.ndarray` or `list` or `None`, default `None`
469
+ Numeric positions for the group centres along the category axis.
470
+ If None, groups are placed at positions determined by
471
+ `group_spacing` (0, group_spacing, 2*group_spacing, ...).
472
+ errors_max : `list` [`str`] or `None`, default `None`
437
473
  Column names in `data` containing the upper values of the error bars.
438
474
  Should be structured similarly to `columns` if used.
439
- errors_min : `list` [`str`] or `None` default `NoneType`
475
+ errors_min : `list` [`str`] or `None` default `None`
440
476
  Column names in `data` containing the lower values of the error bars.
441
477
  colours : `list` [`str`], default ['tab:blue', 'tab:pink']
442
478
  Colours for the bars. Recycled if fewer colours than `columns`.
@@ -448,7 +484,7 @@ def group_bar(data:pd.DataFrame, label:str, columns:list[str],
448
484
  The bar edge colours.
449
485
  horizontal : `bool`, default `False`
450
486
  Whether plot a horizontal barchart.
451
- ax : `plt.ax`, default `NoneType`
487
+ ax : `plt.ax`, default `None`
452
488
  The pyplot.axes object.
453
489
  figsize : `tuple` [`float`, `float`], default (2, 2),
454
490
  The figure size in inches, when ax is set to None.
@@ -464,11 +500,10 @@ def group_bar(data:pd.DataFrame, label:str, columns:list[str],
464
500
  ax : plt.Axes
465
501
  The matplotlib Axes object with the plot.
466
502
  """
467
- # constants
468
- OFFSET_COL = "__offset__"
469
503
  # check input - most will be done by bar, just keeping the minimum
470
504
  is_df(data)
471
505
  is_type(columns, list)
506
+ is_type(group_positions, (type(None), np.ndarray, list))
472
507
  is_type(errors_max, (type(None),list))
473
508
  is_type(errors_min, (type(None),list))
474
509
  is_type(horizontal, bool)
@@ -483,8 +518,16 @@ def group_bar(data:pd.DataFrame, label:str, columns:list[str],
483
518
  # ### prepare the loop
484
519
  # the number of bars for each group
485
520
  n_bars = len(columns)
486
- # the number of groups
487
- base = np.arange(data.shape[0]) * group_spacing
521
+ # the base position of each group
522
+ if group_positions is None:
523
+ base = np.arange(data.shape[0]) * group_spacing
524
+ else:
525
+ base = np.asarray(group_positions)
526
+ if len(base) != data.shape[0]:
527
+ raise ValueError(
528
+ f'Length of group_positions ({len(base)}) does not '
529
+ f'match number of rows in data ({data.shape[0]}).'
530
+ )
488
531
  # the total width of all the bars in a single group
489
532
  spacing_per_bar = bar_spacing * wd
490
533
  total_spacing = spacing_per_bar * (n_bars - 1)
@@ -494,20 +537,19 @@ def group_bar(data:pd.DataFrame, label:str, columns:list[str],
494
537
  group_width = wd * n_bars + total_spacing
495
538
  tick_pos = base + (group_width - wd) / 2
496
539
  # looping
497
- df_offset = data.copy()
498
540
  for i, column in enumerate(columns):
499
541
  # the location of the bar
500
542
  offset = base + i * (wd + spacing_per_bar)
501
- df_offset[OFFSET_COL] = offset
502
543
  # cycling the colours
503
544
  col = colours[i % len(colours)]
504
545
  # the limits
505
546
  err_max = errors_max[i] if errors_max else None
506
547
  err_min = errors_min[i] if errors_min else None
507
548
  _ = bar(
508
- data=df_offset,
509
- label=OFFSET_COL,
549
+ data=data,
550
+ label=label,
510
551
  column=column,
552
+ positions=offset,
511
553
  error_max=err_max,
512
554
  error_min=err_min,
513
555
  colours=[col],
plot_misc/constants.py CHANGED
@@ -42,9 +42,9 @@ class ForestNames(object):
42
42
  PVALUE = 'p-value'
43
43
  CI = 'confidence_interval'
44
44
  data_table = 'data_table'
45
- EmpericalSupport_Coverage = 'coverage'
46
- EmpericalSupport_Compatability = 'compatibility'
47
- EmpericalSupportResults = 'results_'
45
+ EmpiricalSupport_Coverage = 'coverage'
46
+ EmpiricalSupport_Compatibility = 'compatibility'
47
+ EmpiricalSupportResults = 'results_'
48
48
 
49
49
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
50
50
  # Utils Names
@@ -103,7 +103,7 @@ class NamesIncidenceMatrix(object):
103
103
  GRID_POS_O = 'outline'
104
104
 
105
105
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
106
- class NamesMachineLearnig(object):
106
+ class NamesMachineLearning(object):
107
107
  '''
108
108
  Names used by machinelearning.py
109
109
  '''
plot_misc/errors.py CHANGED
@@ -95,7 +95,7 @@ def is_type(param: Any, types: tuple[Type] | Type,
95
95
  Expected type(s) of the object.
96
96
  param_name : `str` or `None`
97
97
  Name of the parameter. Will attempt to infer the parameter name if set
98
- to `NoneType`.
98
+ to `None`.
99
99
 
100
100
  Returns
101
101
  -------
@@ -185,7 +185,7 @@ def are_columns_in_df(
185
185
  missing_columns = expected_columns_set - set(df.columns)
186
186
  # return
187
187
  if missing_columns:
188
- if warning == False:
188
+ if not warning:
189
189
  raise InputValidationError(
190
190
  message.format(missing_columns)
191
191
  )
@@ -273,7 +273,7 @@ def same_len(object1: Any, object2: Any, object_names: list[str] | None = None,
273
273
  if object_names is None:
274
274
  object_names = ['object1', 'object2']
275
275
  elif len(object_names) !=2:
276
- raise ValueError('`object_names` should be `NoneType` or contain '
276
+ raise ValueError('`object_names` should be `None` or contain '
277
277
  'two strings')
278
278
  # the actual test
279
279
  if n1 != n2:
@@ -74,6 +74,7 @@ import os
74
74
  import re
75
75
  import pandas as pd
76
76
  import numpy as np
77
+ import matplotlib.lines as mlines
77
78
  from plot_misc.constants import (
78
79
  UtilsNames,
79
80
  ForestNames,
@@ -724,3 +725,234 @@ def load_survival_table(**kwargs):
724
725
  # return
725
726
  return table
726
727
 
728
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
729
@dataset
def load_forest_preprocessed(**kwargs):
    """
    Load the forest example data with marker colour and shape columns.

    The table returned by `load_forest_data` gains two columns: `col`,
    looked up from `subgroup_name`, and `shape`, looked up from `model`.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to `load_forest_data`.

    Returns
    -------
    pd.DataFrame
        The forest data with the additional `col` and `shape` columns.

    Notes
    -----
    NOTE(review): the same hunk later defines another function called
    `load_forest_preprocessed`. If both end up in one module namespace the
    later definition replaces this one - confirm which is intended.
    """
    # base data (already carries the hardcoded ForestNames.y_col)
    forest = load_forest_data(**kwargs)
    # target column -> (source column, lookup table); the literals mirror
    # resources/examples/forestplot.ipynb cell 2
    lookups = {
        'col': (
            'subgroup_name',
            {
                'wo T2DM/CVD': 'orangered',
                'w T2DM': 'blueviolet',
                'w T2DM & CVD': 'limegreen',
            },
        ),
        'shape': (
            'model',
            {
                'PGS only': 'o',
                'PGS plus': 's',
                'PGS extended': 'H',
            },
        ),
    }
    # attach the colour column first, then the shape column
    for target, (source, mapping) in lookups.items():
        forest[target] = forest[source].map(mapping)
    # return
    return forest
757
+
758
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
759
@dataset
def load_barchart_preprocessed(**kwargs):
    """
    Load the association counts between cardiac chambers (`LV`, `RV`,
    `LA`) and cardiac outcomes, arranged with one row per outcome.

    The table from `load_barchart_data` is transposed, its index is copied
    into a `labels` column, and four outcome rows are selected in a fixed
    order.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to `load_barchart_data`.

    Returns
    -------
    pd.DataFrame
        Rows `Heart failure`, `HCM`, `DCM` and `AF` (in that order),
        including the `labels` column.
    """
    # the outcomes to keep, in plotting order
    outcomes = ["Heart failure", "HCM", "DCM", "AF"]
    # files
    transposed = load_barchart_data(**kwargs).T
    transposed['labels'] = transposed.index
    # return
    return transposed.loc[outcomes]
775
+
776
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
777
@dataset
def load_groupbar_preprocessed(**kwargs):
    """
    Load the mean and SD percentage of sarcomere disruption per knockdown
    gene and control in iPS-CM, with upper error-bar columns added.

    For each of `Control`, `AP4S1`, `LRRC39` and `ZFAND4` a `<name>_max`
    column is added holding `<name>_mean + <name>_std`.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to `load_groupbar_data`.

    Returns
    -------
    pd.DataFrame
        The group-bar data including the four `*_max` columns.
    """
    genes = ('Control', 'AP4S1', 'LRRC39', 'ZFAND4')
    # files
    table = load_groupbar_data(**kwargs)
    # upper error limit = mean + standard deviation
    for name in genes:
        mean_col, std_col = f'{name}_mean', f'{name}_std'
        table[f'{name}_max'] = table[mean_col] + table[std_col]
    # return
    return table
794
+
795
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
796
@dataset
def load_subtotal_preprocessed(**kwargs):
    """
    Reshape the `load_barchart_data` table for subtotal bar charts.

    The table is transposed, after which the following columns are added:
    `labels` (a copy of the index), `total` (the row-wise sum of every
    column other than `labels`) and `sub` (a copy of the `LV` column). The
    `LA`, `LV` and `RV` columns are then removed and the rows are sorted by
    `total` in ascending order.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to `load_barchart_data`.

    Returns
    -------
    pd.DataFrame
        The reshaped table containing `labels`, `total` and `sub`.
    """
    label_col, total_col, sub_col = "labels", "total", "sub"
    chambers = ["LA", "LV", "RV"]
    # files
    data = load_barchart_data(**kwargs).T
    # the labels are taken before any other column is added
    data[label_col] = data.index
    # total across everything except the label text
    data[total_col] = data.drop(columns=[label_col]).sum(axis=1)
    data[sub_col] = data["LV"]
    # the per-chamber counts are no longer needed
    data = data.drop(columns=chambers)
    # return
    return data.sort_values(by=total_col, ascending=True)
816
+
817
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
818
@dataset
def load_survival_preprocessed(**kwargs) -> tuple[dict, pd.DataFrame]:
    """
    Return two example survival curves together with an at-risk table.

    Parameters
    ----------
    **kwargs
        Accepted for interface consistency; not used inside this function.

    Returns
    -------
    data_dict : `dict` [`str`, `list`]
        Keys `curve1` and `curve2`, each mapping to a two-item list holding
        the output of `create_survival_data` and a colour name.
    bottom_table : pd.DataFrame
        Columns `time`, `group 1` and `group 2` with formatted counts at
        times 0, 50 and 100.
    """
    # the two curves; the call order is kept as in the original
    first_curve = create_survival_data(nrows=24)
    second_curve = create_survival_data(
        nrows=15, survival_rate=0.04, ci_width=0.45,
    )
    # at-risk counts at three time points, stored as literals.
    # NOTE(review): presumably precomputed with
    # plot_misc.survival.extract_follow_up(points=[0, 50, 100]) on the two
    # tables above - confirm the numbers still match if the curves change.
    at_risk = pd.DataFrame({
        "time": [0, 50, 100],
        "group 1": ["1,000", "379", "121"],
        "group 2": ["1,000", "143", "1"],
    })
    # make data dict
    curves = {
        'curve1': [first_curve, 'steelblue'],
        'curve2': [second_curve, 'crimson'],
    }
    return curves, at_risk
851
+
852
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
853
@dataset
def load_heatmap_preprocessed(**kwargs) -> pd.DataFrame:
    """
    Return an example correlation matrix.

    Parameters
    ----------
    **kwargs
        Forwarded to the `pd.DataFrame` constructor.

    Returns
    -------
    pd.DataFrame
        A symmetric 6 by 6 table with ones on the diagonal, whose rows and
        columns are both labelled `Var1` to `Var6`.
    """
    names = [f"Var{number}" for number in range(1, 7)]
    values = np.array([
        [1.00, 0.92, 0.81, 0.55, 0.30, 0.10],
        [0.92, 1.00, 0.76, 0.50, 0.28, 0.08],
        [0.81, 0.76, 1.00, 0.45, 0.22, 0.05],
        [0.55, 0.50, 0.45, 1.00, 0.18, 0.02],
        [0.30, 0.28, 0.22, 0.18, 1.00, -0.75],
        [0.10, 0.08, 0.05, 0.02, -0.75, 1.00],
    ])
    return pd.DataFrame(values, index=names, columns=names, **kwargs)
881
+
882
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
883
@dataset
def load_bubble_preprocessed(**kwargs) -> tuple[pd.DataFrame, list, list]:
    """
    Return the data needed for a bubble chart.

    Parameters
    ----------
    **kwargs
        Accepted for interface consistency; not used inside this function.

    Returns
    -------
    table : pd.DataFrame
        An 8 by 5 integer table (traits by genes) that is zero except for
        eight cells holding values between 1 and 4.
    mappings : `list`
        `[DOT_COLOUR, DOT_SIZE]`, each a list of five `(value, cut-off)`
        tuples with cut-offs 0.9, 1.9, 2.9, 3.9 and 4.9.
    handles : `list` [`matplotlib.lines.Line2D`]
        Four legend handles labelled `1` to `4`, built from the last four
        colour and size entries.
    """
    genes = [f'GENE{number}' for number in range(1, 6)]
    traits = ['CRP', 'SBP', 'BMI', 'HbA1c', 'eGFR', 'CHD', 'Stroke', 'T2DM']
    # start from an all-zero table
    table = pd.DataFrame(
        np.zeros((len(traits), len(genes)), dtype=int),
        columns=genes,
        index=traits,
    )
    # the non-zero cells: three 4s, one 3, one 2 and three 1s
    non_zero = {
        ('SBP', 'GENE3'): 4,
        ('CHD', 'GENE4'): 4,
        ('HbA1c', 'GENE2'): 3,
        ('BMI', 'GENE5'): 4,
        ('CRP', 'GENE1'): 2,
        ('eGFR', 'GENE2'): 1,
        ('Stroke', 'GENE4'): 1,
        ('T2DM', 'GENE5'): 1,
    }
    for (trait, gene), value in non_zero.items():
        table.loc[trait, gene] = value
    # Cut-offs and mappings: (colour, cut-off) per category.
    # NOTE(review): the cut-offs look like upper bounds of each value bin,
    # with the first (grey) entry covering zero - confirm with the consumer.
    DOT_COLOUR = [
        ('#AAAAAA', 0.9),
        ('#d65db1', 1.9),
        ('#ff6f91', 2.9),
        ('#008f7a', 3.9),
        ('#ffc75f', 4.9),
    ]
    # (dot size, cut-off) per category, using the same five cut-offs
    DOT_SIZE = [
        (0, .9),
        (20, 1.9),
        (40, 2.9),
        (60, 3.9),
        (80, 4.9),
    ]
    # legend handles: one marker-only line per non-zero category
    marker_style = dict(
        marker='o', linestyle='none',
        markeredgecolor='black', markeredgewidth=0.4,
    )
    legend_text = ['1', '2', '3', '4']
    handles = []
    for (colour, _), (size, _), text in zip(
            DOT_COLOUR[1:], DOT_SIZE[1:], legend_text):
        # scale the marker between 2 and 8 points, with 2 as the floor
        if size > 0:
            marker_size = 2 + 6 * size / 80
        else:
            marker_size = 2
        handles.append(
            mlines.Line2D(
                [], [], markerfacecolor=colour,
                markersize=marker_size,
                label=text, **marker_style,
            )
        )
    return table, [DOT_COLOUR, DOT_SIZE], handles
940
+
941
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
942
@dataset
def load_forest_preprocessed(**kwargs) -> pd.DataFrame:
    """
    Return the data needed for a forest plot.

    The associations from `load_mace_associations` are restricted to the
    `LDL-C (mmol/L)` exposure, the index is reset, and two columns are
    added: `Independent` (`Exposure 1`, `Exposure 2`, ...) and `y_axis`
    (the y-coordinates 1, 5, 9, ..., 25).

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to `load_mace_associations`.

    Returns
    -------
    pd.DataFrame
        The LDL-C rows with the `Independent` and `y_axis` columns.

    Notes
    -----
    The `y_axis` values are a fixed list of seven numbers, so the LDL-C
    subset is expected to contain exactly seven rows.

    NOTE(review): an earlier function in the same hunk is also called
    `load_forest_preprocessed`. If both live in one module namespace this
    definition replaces the earlier one - confirm which is intended.
    """
    exposure = 'LDL-C (mmol/L)'
    # subsetting the data on the exposure
    data = load_mace_associations(**kwargs)
    dat = data[data['Exposure'] == exposure].reset_index()
    # label the rows consecutively, starting at 1
    dat['Independent'] = [
        f'Exposure {number}' for number in range(1, dat.shape[0] + 1)
    ]
    # add y-coordinates
    dat['y_axis'] = [1.0, 5.0, 9.0, 13.0, 17.0, 21.0, 25.0]
    return dat
957
+
958
+