marsilea 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marsilea/upset.py CHANGED
@@ -14,14 +14,23 @@ from matplotlib.patches import Rectangle, Patch
14
14
  from typing import List, Set, Mapping
15
15
 
16
16
  from .base import WhiteBoard
17
- from .plotter import Numbers, Labels, StackBar, Bar, Box, \
18
- Boxen, Violin, Point, Strip, Swarm
17
+ from .plotter import (
18
+ Numbers,
19
+ Labels,
20
+ StackBar,
21
+ Bar,
22
+ Box,
23
+ Boxen,
24
+ Violin,
25
+ Point,
26
+ Strip,
27
+ Swarm,
28
+ )
19
29
  from .utils import get_canvas_size_by_data
20
30
 
21
31
 
22
32
  def _get_sets_table(binary_table):
23
- cardinality = binary_table.groupby(list(binary_table.columns),
24
- observed=True).size()
33
+ cardinality = binary_table.groupby(list(binary_table.columns), observed=True).size()
25
34
  sets_table = pd.DataFrame(cardinality, columns=["cardinality"])
26
35
  sets_table["degree"] = sets_table.index.to_frame().sum(axis=1)
27
36
  return sets_table
@@ -71,8 +80,9 @@ class UpsetData:
71
80
  nitems, nsets = self._binary_table.shape
72
81
  return f"UpsetData: {nsets} sets, {nitems} items"
73
82
 
74
- def __init__(self, data, sets_names=None, items=None,
75
- sets_attrs=None, items_attrs=None):
83
+ def __init__(
84
+ self, data, sets_names=None, items=None, sets_attrs=None, items_attrs=None
85
+ ):
76
86
  if isinstance(data, pd.DataFrame):
77
87
  if sets_names is None:
78
88
  sets_names = data.columns.tolist()
@@ -84,8 +94,7 @@ class UpsetData:
84
94
  if items is None:
85
95
  raise ValueError("The name of items must be provided")
86
96
 
87
- assert len(sets_names) == len(set(sets_names)), \
88
- "Duplicates in set names"
97
+ assert len(sets_names) == len(set(sets_names)), "Duplicates in set names"
89
98
  assert len(items) == len(set(items)), "Duplicates in items"
90
99
 
91
100
  if sets_attrs is not None:
@@ -96,16 +105,16 @@ class UpsetData:
96
105
  items_attrs = items_attrs.loc[list(items)]
97
106
  self._items_attrs = items_attrs
98
107
 
99
- self._binary_table = pd.DataFrame(columns=sets_names, index=items,
100
- data=data)
108
+ self._binary_table = pd.DataFrame(columns=sets_names, index=items, data=data)
101
109
  self._sets_table = _get_sets_table(self._binary_table)
102
110
 
103
- def filter(self,
104
- min_degree=None,
105
- max_degree=None,
106
- min_cardinality=None,
107
- max_cardinality=None,
108
- ):
111
+ def filter(
112
+ self,
113
+ min_degree=None,
114
+ max_degree=None,
115
+ min_cardinality=None,
116
+ max_cardinality=None,
117
+ ):
109
118
  """Filter by degree or cardinality
110
119
 
111
120
  Parameters
@@ -126,11 +135,9 @@ class UpsetData:
126
135
  if max_degree is not None:
127
136
  sets_table = sets_table[sets_table["degree"] <= max_degree]
128
137
  if min_cardinality is not None:
129
- sets_table = sets_table[
130
- sets_table["cardinality"] >= min_cardinality]
138
+ sets_table = sets_table[sets_table["cardinality"] >= min_cardinality]
131
139
  if max_cardinality is not None:
132
- sets_table = sets_table[
133
- sets_table["cardinality"] <= max_cardinality]
140
+ sets_table = sets_table[sets_table["cardinality"] <= max_cardinality]
134
141
  self._sets_table = sets_table
135
142
  return self
136
143
 
@@ -149,17 +156,18 @@ class UpsetData:
149
156
  if by not in ["degree", "cardinality"]:
150
157
  raise ValueError("Sort by either `degree` or `cardinality`")
151
158
  if by == "cardinality":
152
- self._sets_table.sort_values(by=by, ascending=not ascending,
153
- inplace=True)
159
+ self._sets_table.sort_values(
160
+ by=by, ascending=not ascending, inplace=True, kind="stable"
161
+ )
154
162
  else:
155
163
  matrix = self._sets_table.index.to_frame().reset_index(drop=True)
156
164
  _, num = matrix.shape
157
165
 
158
- matrix['SUM'] = matrix.sum(axis=1)
166
+ matrix["SUM"] = matrix.sum(axis=1)
159
167
 
160
168
  reorder_ix = []
161
- for n, df in matrix.groupby('SUM'):
162
- del df['SUM']
169
+ for n, df in matrix.groupby("SUM"):
170
+ del df["SUM"]
163
171
  rows = df.to_numpy()
164
172
  c_rows = []
165
173
  for row in rows:
@@ -192,7 +200,9 @@ class UpsetData:
192
200
  if order is not None:
193
201
  sets_names = order
194
202
  else:
195
- sets_sizes = self.sets_size().sort_values(ascending=ascending)
203
+ sets_sizes = self.sets_size().sort_values(
204
+ ascending=ascending, kind="stable"
205
+ )
196
206
  sets_names = sets_sizes.index.to_list()
197
207
  self._binary_table = self._binary_table.loc[:, sets_names]
198
208
  self._sets_table = self._sets_table.reorder_levels(order=sets_names)
@@ -200,14 +210,15 @@ class UpsetData:
200
210
  self._sets_attrs = self._sets_attrs.loc[sets_names]
201
211
  return self
202
212
 
203
- def mark(self,
204
- present=None,
205
- absent=None,
206
- min_cardinality=None,
207
- max_cardinality=None,
208
- min_degree=None,
209
- max_degree=None
210
- ):
213
+ def mark(
214
+ self,
215
+ present=None,
216
+ absent=None,
217
+ min_cardinality=None,
218
+ max_cardinality=None,
219
+ min_degree=None,
220
+ max_degree=None,
221
+ ):
211
222
  sets_table = self._sets_table
212
223
  marks = np.ones(len(sets_table), dtype=int)
213
224
 
@@ -237,9 +248,13 @@ class UpsetData:
237
248
  return self
238
249
 
239
250
  @classmethod
240
- def from_sets(cls, sets: List[Set], sets_names=None,
241
- sets_attrs: pd.DataFrame = None,
242
- items_attrs: pd.DataFrame = None) -> UpsetData:
251
+ def from_sets(
252
+ cls,
253
+ sets: List[Set],
254
+ sets_names=None,
255
+ sets_attrs: pd.DataFrame = None,
256
+ items_attrs: pd.DataFrame = None,
257
+ ) -> UpsetData:
243
258
  """Create UpsetData from a series of sets
244
259
 
245
260
  Parameters
@@ -256,7 +271,7 @@ class UpsetData:
256
271
  The attributes of items, the input index should be the
257
272
  same as items
258
273
 
259
-
274
+
260
275
  """
261
276
  items = set()
262
277
  new_sets = []
@@ -283,15 +298,23 @@ class UpsetData:
283
298
  d = [i in s for i in items]
284
299
  data.append(d)
285
300
  data = np.array(data, dtype=int).T
286
- container = cls(data, sets_names=new_names, items=items,
287
- sets_attrs=sets_attrs,
288
- items_attrs=items_attrs)
301
+ container = cls(
302
+ data,
303
+ sets_names=new_names,
304
+ items=items,
305
+ sets_attrs=sets_attrs,
306
+ items_attrs=items_attrs,
307
+ )
289
308
  return container
290
309
 
291
310
  @classmethod
292
- def from_memberships(cls, items, items_names=None,
293
- sets_attrs: pd.DataFrame = None,
294
- items_attrs: pd.DataFrame = None):
311
+ def from_memberships(
312
+ cls,
313
+ items,
314
+ items_names=None,
315
+ sets_attrs: pd.DataFrame = None,
316
+ items_attrs: pd.DataFrame = None,
317
+ ):
295
318
  """Describe the sets an item are in
296
319
 
297
320
  Parameters
@@ -325,10 +348,14 @@ class UpsetData:
325
348
  if items_names is not None:
326
349
  new_items_names = items_names
327
350
 
328
- df = pd.DataFrame(sets).fillna(False).astype(int)
329
- container = cls(df.to_numpy(), sets_names=df.columns,
330
- items=new_items_names, sets_attrs=sets_attrs,
331
- items_attrs=items_attrs)
351
+ df = pd.DataFrame(sets).astype(float).fillna(False).astype(int)
352
+ container = cls(
353
+ df.to_numpy(),
354
+ sets_names=df.columns,
355
+ items=new_items_names,
356
+ sets_attrs=sets_attrs,
357
+ items_attrs=items_attrs,
358
+ )
332
359
  return container
333
360
 
334
361
  def has_item(self, item):
@@ -353,11 +380,11 @@ class UpsetData:
353
380
 
354
381
  def cardinality(self):
355
382
  """The number of items in intersections"""
356
- return self._sets_table['cardinality']
383
+ return self._sets_table["cardinality"]
357
384
 
358
385
  def degree(self):
359
386
  """Intersection between how many sets"""
360
- return self._sets_table['degree']
387
+ return self._sets_table["degree"]
361
388
 
362
389
  def sets_size(self):
363
390
  return self._binary_table.sum()[self.sets_names]
@@ -460,38 +487,42 @@ class Upset(WhiteBoard):
460
487
  >>> [3, 4, 5, 6],
461
488
  >>> [1, 6, 10, 11]])
462
489
  >>> Upset(data).render()
463
-
490
+
464
491
 
465
492
  """
466
493
 
467
- def __init__(self, data: UpsetData,
468
- orient="h",
469
- sort_sets=None, # ascending, descending
470
- sets_order=None,
471
- sets_color=None,
472
- sort_subsets="cardinality", # cardinality, degree
473
- min_degree=None,
474
- max_degree=None,
475
- min_cardinality=None,
476
- max_cardinality=None,
477
- color=".1",
478
- shading=.3,
479
- radius=50,
480
- linewidth=1.5,
481
- grid_background=0.1,
482
- fontsize=None,
483
- add_intersections=True,
484
- add_sets_size=True,
485
- add_labels=True,
486
- width=None,
487
- height=None,
488
- ):
494
+ def __init__(
495
+ self,
496
+ data: UpsetData,
497
+ orient="h",
498
+ sort_sets=None, # ascending, descending
499
+ sets_order=None,
500
+ sets_color=None,
501
+ sort_subsets="cardinality", # cardinality, degree
502
+ min_degree=None,
503
+ max_degree=None,
504
+ min_cardinality=None,
505
+ max_cardinality=None,
506
+ color=".1",
507
+ shading=0.3,
508
+ radius=50,
509
+ linewidth=1.5,
510
+ grid_background=0.1,
511
+ fontsize=None,
512
+ add_intersections=True,
513
+ add_sets_size=True,
514
+ add_labels=True,
515
+ width=None,
516
+ height=None,
517
+ ):
489
518
  # The modification happens inplace
490
519
  upset_data = data
491
- upset_data.filter(min_degree=min_degree,
492
- max_degree=max_degree,
493
- min_cardinality=min_cardinality,
494
- max_cardinality=max_cardinality)
520
+ upset_data.filter(
521
+ min_degree=min_degree,
522
+ max_degree=max_degree,
523
+ min_cardinality=min_cardinality,
524
+ max_cardinality=max_cardinality,
525
+ )
495
526
 
496
527
  ascending = sort_subsets.startswith("-")
497
528
  if ascending:
@@ -536,7 +567,8 @@ class Upset(WhiteBoard):
536
567
  self.orient = orient
537
568
 
538
569
  width, height = get_canvas_size_by_data(
539
- main_shape, scale=.3, width=width, height=height, aspect=1)
570
+ main_shape, scale=0.3, width=width, height=height, aspect=1
571
+ )
540
572
 
541
573
  super().__init__(width=width, height=height)
542
574
  if add_intersections:
@@ -558,13 +590,21 @@ class Upset(WhiteBoard):
558
590
  side = "left" if orient == "h" else "top"
559
591
  self.add_sets_size(side, color=self.sets_color)
560
592
 
561
- def highlight_subsets(self, present=None, absent=None,
562
- min_cardinality=None, max_cardinality=None,
563
- min_degree=None, max_degree=None,
564
- facecolor=None, edgecolor=None,
565
- edgewidth=None, hatch=None, edgestyle=None,
566
- label=None,
567
- ):
593
+ def highlight_subsets(
594
+ self,
595
+ present=None,
596
+ absent=None,
597
+ min_cardinality=None,
598
+ max_cardinality=None,
599
+ min_degree=None,
600
+ max_degree=None,
601
+ facecolor=None,
602
+ edgecolor=None,
603
+ edgewidth=None,
604
+ hatch=None,
605
+ edgestyle=None,
606
+ label=None,
607
+ ):
568
608
  """Highlight a subset of the data.
569
609
 
570
610
  Notice that the color of hatch is determined by the edgecolor.
@@ -597,11 +637,14 @@ class Upset(WhiteBoard):
597
637
  The label for the highlighting
598
638
 
599
639
  """
600
- marks = self.data.mark(present=present, absent=absent,
601
- min_cardinality=min_cardinality,
602
- max_cardinality=max_cardinality,
603
- min_degree=min_degree,
604
- max_degree=max_degree)
640
+ marks = self.data.mark(
641
+ present=present,
642
+ absent=absent,
643
+ min_cardinality=min_cardinality,
644
+ max_cardinality=max_cardinality,
645
+ min_degree=min_degree,
646
+ max_degree=max_degree,
647
+ )
605
648
 
606
649
  def _label(name, value):
607
650
  if value is not None:
@@ -617,13 +660,29 @@ class Upset(WhiteBoard):
617
660
  min_degree = _label("min_degree", min_degree)
618
661
  max_degree = _label("max_degree", max_degree)
619
662
 
620
- label = ", ".join([s for s in [present, absent, min_cardinality,
621
- max_cardinality, min_degree,
622
- max_degree] if s])
623
-
624
- styles = dict(facecolor=facecolor, edgecolor=edgecolor,
625
- linestyle=edgestyle, linewidth=edgewidth,
626
- hatch=hatch, label=label)
663
+ label = ", ".join(
664
+ [
665
+ s
666
+ for s in [
667
+ present,
668
+ absent,
669
+ min_cardinality,
670
+ max_cardinality,
671
+ min_degree,
672
+ max_degree,
673
+ ]
674
+ if s
675
+ ]
676
+ )
677
+
678
+ styles = dict(
679
+ facecolor=facecolor,
680
+ edgecolor=edgecolor,
681
+ linestyle=edgestyle,
682
+ linewidth=edgewidth,
683
+ hatch=hatch,
684
+ label=label,
685
+ )
627
686
  styles = {k: v for k, v in styles.items() if v is not None}
628
687
 
629
688
  line_styles = dict()
@@ -642,63 +701,68 @@ class Upset(WhiteBoard):
642
701
  self._subset_styles[i] = current_styles
643
702
  self._subset_line_styles[i] = {**line_styles}
644
703
 
645
- if 'facecolor' not in styles.keys():
646
- styles['facecolor'] = 'none'
647
- if 'edgecolor' not in styles.keys():
648
- styles['edgecolor'] = 'none'
704
+ if "facecolor" not in styles.keys():
705
+ styles["facecolor"] = "none"
706
+ if "edgecolor" not in styles.keys():
707
+ styles["edgecolor"] = "none"
649
708
  self._legend_entries.append(styles)
650
709
 
651
710
  def _check_side(self, side, chart_name, allow):
652
711
  options = allow[self.orient]
653
712
  if side not in options:
654
- msg = f"{chart_name} cannot be placed at '{side}', " \
655
- f"try {' ,'.join(options)}"
713
+ msg = (
714
+ f"{chart_name} cannot be placed at '{side}', "
715
+ f"try {' ,'.join(options)}"
716
+ )
656
717
  raise ValueError(msg)
657
718
 
658
- def add_intersections(self, side, pad=.1, size=1.):
659
- self._check_side(side, 'Intersections',
660
- dict(h=["top", "bottom"], v=["left", "right"]))
719
+ def add_intersections(self, side, pad=0.1, size=1.0):
720
+ self._check_side(
721
+ side, "Intersections", dict(h=["top", "bottom"], v=["left", "right"])
722
+ )
661
723
  data = self.data.cardinality()
662
724
  self._intersection_bar = Numbers(data, color=self.color)
663
725
  self.add_plot(side, self._intersection_bar, size=size, pad=pad)
664
726
 
665
- def add_sets_size(self, side, pad=.1, size=1., **props):
666
- self._check_side(side, 'Sets size',
667
- dict(h=["left", "right"], v=["top", "bottom"]))
727
+ def add_sets_size(self, side, pad=0.1, size=1.0, **props):
728
+ self._check_side(
729
+ side, "Sets size", dict(h=["left", "right"], v=["top", "bottom"])
730
+ )
668
731
  data = self.sets_size
669
732
  options = dict(color=self.color)
670
733
  options.update(props)
671
734
  self._sets_size_bar = Numbers(data, **options)
672
735
  self.add_plot(side, self._sets_size_bar, size=size, pad=pad)
673
736
 
674
- def add_sets_label(self, side, pad=.1, size=None, **props):
675
- self._check_side(side, 'Sets label',
676
- dict(h=["left", "right"], v=["top", "bottom"]))
737
+ def add_sets_label(self, side, pad=0.1, size=None, **props):
738
+ self._check_side(
739
+ side, "Sets label", dict(h=["left", "right"], v=["top", "bottom"])
740
+ )
677
741
  data = self.data.sets_names
678
742
  self.add_plot(side, Labels(data, **props), pad=pad, size=size)
679
743
 
680
744
  def get_intersection_ax(self):
681
- return self.get_ax('Intersections')
745
+ return self.get_ax("Intersections")
682
746
 
683
747
  def get_sets_size_ax(self):
684
- return self.get_ax('Sets size')
748
+ return self.get_ax("Sets size")
685
749
 
686
750
  def get_sets_label_ax(self):
687
- return self.get_ax('Sets label')
751
+ return self.get_ax("Sets label")
688
752
 
689
753
  def get_data(self):
690
754
  return self.data
691
755
 
692
756
  _attr_plotter = {
693
- 'bar': Bar,
694
- 'box': Box,
695
- 'boxen': Boxen,
696
- 'violin': Violin,
697
- 'point': Point,
698
- 'strip': Strip,
699
- 'swarm': Swarm,
700
- 'stack_bar': StackBar,
701
- 'number': Numbers,
757
+ "bar": Bar,
758
+ "box": Box,
759
+ "boxen": Boxen,
760
+ "violin": Violin,
761
+ "point": Point,
762
+ "strip": Strip,
763
+ "swarm": Swarm,
764
+ "stack_bar": StackBar,
765
+ "number": Numbers,
702
766
  }
703
767
 
704
768
  @classmethod
@@ -706,8 +770,9 @@ class Upset(WhiteBoard):
706
770
  """Update the global upset plot for attr plotter"""
707
771
  cls._attr_plotter.update(attr_plotter)
708
772
 
709
- def add_sets_attr(self, side, attr_name, plot,
710
- name=None, pad=.1, size=None, plot_kws=None):
773
+ def add_sets_attr(
774
+ self, side, attr_name, plot, name=None, pad=0.1, size=None, plot_kws=None
775
+ ):
711
776
  """Add a plot for the sets attribute
712
777
 
713
778
  Parameters
@@ -732,14 +797,14 @@ class Upset(WhiteBoard):
732
797
  data = self.data.sets_attrs
733
798
  attr = data[attr_name]
734
799
  plot = self._attr_plotter[plot]
735
- kws = {'label': attr_name}
800
+ kws = {"label": attr_name}
736
801
  if plot_kws is not None:
737
802
  kws.update(plot_kws)
738
- self.add_plot(side, plot(attr, **plot_kws),
739
- name=name, pad=pad, size=size)
803
+ self.add_plot(side, plot(attr, **plot_kws), name=name, pad=pad, size=size)
740
804
 
741
- def add_items_attr(self, side, attr_name, plot,
742
- name=None, pad=.1, size=None, plot_kws=None):
805
+ def add_items_attr(
806
+ self, side, attr_name, plot, name=None, pad=0.1, size=None, plot_kws=None
807
+ ):
743
808
  """Add a plot for the items attribute
744
809
 
745
810
  Parameters
@@ -769,16 +834,15 @@ class Upset(WhiteBoard):
769
834
  if plot == StackBar:
770
835
  collect = [Counter(col) for col in data_collector]
771
836
  construct = pd.DataFrame(collect).T
772
- construct = ((construct.loc[~pd.isnull(construct.index)])
773
- .fillna(0)
774
- .astype(int))
837
+ construct = (
838
+ (construct.loc[~pd.isnull(construct.index)]).fillna(0).astype(int)
839
+ )
775
840
 
776
841
  plot = self._attr_plotter[plot]
777
- kws = {'label': attr_name}
842
+ kws = {"label": attr_name}
778
843
  if plot_kws is not None:
779
844
  kws.update(plot_kws)
780
- self.add_plot(side, plot(construct, **kws),
781
- name=name, pad=pad, size=size)
845
+ self.add_plot(side, plot(construct, **kws), name=name, pad=pad, size=size)
782
846
 
783
847
  def _render_matrix(self, ax):
784
848
  ax.set_axis_off()
@@ -795,8 +859,14 @@ class Upset(WhiteBoard):
795
859
  if self.shading > 0:
796
860
  if self.orient == "v":
797
861
  xv, yv = yv, xv
798
- ax.scatter(xv, yv, s=self.radius, facecolor=self.color,
799
- alpha=self.shading, edgecolor="none")
862
+ ax.scatter(
863
+ xv,
864
+ yv,
865
+ s=self.radius,
866
+ facecolor=self.color,
867
+ alpha=self.shading,
868
+ edgecolor="none",
869
+ )
800
870
 
801
871
  for ix1, chunk in enumerate(matrix):
802
872
  custom_style = self._subset_styles.get(ix1)
@@ -810,19 +880,23 @@ class Upset(WhiteBoard):
810
880
  if len(cy) > 0:
811
881
  line_low, line_up = np.min(cy), np.max(cy)
812
882
  if (self.linewidth > 0) & (line_up - line_low > 0):
813
- line_style = {'color': self.color, 'lw': self.linewidth,
814
- **custom_line_style}
883
+ line_style = {
884
+ "color": self.color,
885
+ "lw": self.linewidth,
886
+ **custom_line_style,
887
+ }
815
888
  xs, ys = ix1, (line_low, line_up)
816
- liner = ax.vlines
817
- if self.orient == "v":
818
- xs, ys = ys, xs
819
- liner = ax.hlines
889
+ liner = ax.vlines if self.orient == "h" else ax.hlines
820
890
  liner(xs, *ys, **line_style)
821
891
  scatter_colors = self.sets_color[cy]
822
892
  if self.orient == "v":
823
893
  cx, cy = cy, cx
824
- current_style = {'facecolor': scatter_colors, 'zorder': 100,
825
- 'alpha': 1, **custom_style}
894
+ current_style = {
895
+ "facecolor": scatter_colors,
896
+ "zorder": 100,
897
+ "alpha": 1,
898
+ **custom_style,
899
+ }
826
900
  ax.scatter(cx, cy, s=self.radius, **current_style)
827
901
 
828
902
  xlow, xup = 0 - 0.5, np.max(xv) + 0.5
@@ -840,9 +914,13 @@ class Upset(WhiteBoard):
840
914
  height = ylow - yup
841
915
  for i, coord in enumerate(bg_coords):
842
916
  if i % 2 == 0:
843
- rect = Rectangle(xy=coord, height=height, width=width,
844
- facecolor=self.color,
845
- alpha=self.grid_background)
917
+ rect = Rectangle(
918
+ xy=coord,
919
+ height=height,
920
+ width=width,
921
+ facecolor=self.color,
922
+ alpha=self.grid_background,
923
+ )
846
924
  bg_circles.append(rect)
847
925
  # add bg_circles
848
926
  bg_circles = PatchCollection(bg_circles, match_original=True)
@@ -860,7 +938,7 @@ class Upset(WhiteBoard):
860
938
  handles = [Patch(**entry) for entry in self._legend_entries]
861
939
  highlight_legend = ListLegend(handles=handles, handlelength=2)
862
940
  highlight_legend.figure = None
863
- return {'highlight_subsets': [highlight_legend]}
941
+ return {"highlight_subsets": [highlight_legend]}
864
942
 
865
943
  def render(self, figure=None, scale=1):
866
944
  super().render(figure=figure, scale=scale)
marsilea/utils.py CHANGED
@@ -6,10 +6,22 @@ from matplotlib.colors import Colormap
6
6
  from uuid import uuid4
7
7
 
8
8
  ECHARTS16 = [
9
- "#5470c6", "#91cc75", "#fac858", "#ee6666",
10
- "#9a60b4", "#73c0de", "#3ba272", "#fc8452",
11
- "#27727b", "#ea7ccc", "#d7504b", "#e87c25",
12
- "#b5c334", "#fe8463", "#26c0c0", "#f4e001"
9
+ "#5470c6",
10
+ "#91cc75",
11
+ "#fac858",
12
+ "#ee6666",
13
+ "#9a60b4",
14
+ "#73c0de",
15
+ "#3ba272",
16
+ "#fc8452",
17
+ "#27727b",
18
+ "#ea7ccc",
19
+ "#d7504b",
20
+ "#e87c25",
21
+ "#b5c334",
22
+ "#fe8463",
23
+ "#26c0c0",
24
+ "#f4e001",
13
25
  ]
14
26
 
15
27
 
@@ -30,9 +42,9 @@ def batched(iterable, n):
30
42
  """Batch data into lists of length n. The last batch may be shorter."""
31
43
  # batched('ABCDEFG', 3) --> ABC DEF G
32
44
  if n < 1:
33
- raise ValueError('n must be at least one')
45
+ raise ValueError("n must be at least one")
34
46
  it = iter(iterable)
35
- while (batch := list(islice(it, n))):
47
+ while batch := list(islice(it, n)):
36
48
  yield batch
37
49
 
38
50
 
@@ -48,8 +60,8 @@ def relative_luminance(color):
48
60
  luminance : float(s) between 0 and 1
49
61
  """
50
62
  rgb = mcolors.colorConverter.to_rgba_array(color)[:, :3]
51
- rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
52
- lum = rgb.dot([.2126, .7152, .0722])
63
+ rgb = np.where(rgb <= 0.03928, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
64
+ lum = rgb.dot([0.2126, 0.7152, 0.0722])
53
65
  try:
54
66
  return lum.item()
55
67
  except ValueError:
@@ -65,8 +77,9 @@ def get_colormap(cmap):
65
77
  return mpl.cm.get_cmap(cmap)
66
78
 
67
79
 
68
- def get_canvas_size_by_data(shape, width=None, height=None,
69
- scale=.3, aspect=1, max_side=15):
80
+ def get_canvas_size_by_data(
81
+ shape, width=None, height=None, scale=0.3, aspect=1, max_side=15
82
+ ):
70
83
  h, w = shape
71
84
  no_w = width is None
72
85
  no_h = height is None