marsilea 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
marsilea/base.py CHANGED
@@ -45,6 +45,7 @@ def get_breakpoints(arr):
45
45
 
46
46
 
47
47
  class LegendMaker:
48
+ """The factory class to handle legends"""
48
49
  layout: CrossLayout | CompositeCrossLayout
49
50
  _legend_box: List[Artist] = None
50
51
  _legend_name: str = None
@@ -65,13 +66,13 @@ class LegendMaker:
65
66
  """
66
67
  raise NotImplementedError("Should be implemented in derived class")
67
68
 
68
- def custom_legends(self, legends, name=None):
69
- """Add custom legends
69
+ def custom_legend(self, legend, name=None):
70
+ """Add a custom legend
70
71
 
71
72
  Parameters
72
73
  ----------
73
74
 
74
- legends : `Artist <matplotlib.artist.Artists>`
75
+ legend : `Artist <matplotlib.artist.Artists>`
75
76
  A legend object
76
77
  name : str, optional
77
78
  The name of the legend
@@ -79,16 +80,21 @@ class LegendMaker:
79
80
  """
80
81
  if name is None:
81
82
  name = str(uuid4())
82
- self._user_legends[name] = legends
83
-
84
- def add_legends(self, side="right", pad=0, order=None,
85
- stack_by=None, stack_size=3,
86
- align_legends=None,
87
- align_stacks=None,
88
- legend_spacing=10,
89
- stack_spacing=10,
90
- box_padding=2,
91
- ):
83
+ self._user_legends[name] = legend
84
+
85
+ def add_legends(
86
+ self,
87
+ side="right",
88
+ pad=0,
89
+ order=None,
90
+ stack_by=None,
91
+ stack_size=3,
92
+ align_legends=None,
93
+ align_stacks=None,
94
+ legend_spacing=10,
95
+ stack_spacing=10,
96
+ box_padding=2,
97
+ ):
92
98
  """Draw legend based on the order of annotation
93
99
 
94
100
  .. note::
@@ -99,21 +105,26 @@ class LegendMaker:
99
105
 
100
106
  Parameters
101
107
  ----------
102
- side : str, default: 'left'
108
+ side : {'right', 'left', 'top', 'bottom'}, default: 'right'
103
109
  Which side to draw legend
104
110
  pad : number, default: 0
111
+ The padding of the legend in inches
105
112
  order : array of plot name
106
- stack_by :
107
- stack_size :
108
- align_legends : str
113
+ The order of the legend, if None, the order will be the same as the order when adding plotters.
114
+ You need to set name for each plotter when adding them, and specify the order here.
115
+ stack_by : {'row', 'col'}
116
+ The direction to stack legends
117
+ stack_size : int, default: 3
118
+ The number of legends in a stack
119
+ align_legends : {'left', 'right', 'top', 'bottom'}
109
120
  The side to align legends in a stack
110
- align_stacks : str
121
+ align_stacks : {'left', 'right', 'top', 'bottom'}
111
122
  The side to align stacks
112
- legend_spacing : float
123
+ legend_spacing : float, default: 10
113
124
  The space between legends
114
- stack_spacing : float
125
+ stack_spacing : float, default: 10
115
126
  The space between stacks
116
- box_padding : float
127
+ box_padding : float, default: 2
117
128
  Add pad around the whole legend box
118
129
 
119
130
  """
@@ -128,10 +139,15 @@ class LegendMaker:
128
139
 
129
140
  self._legend_grid_kws = dict(side=side, size=0.01, pad=pad)
130
141
  self._legend_draw_kws = dict(
131
- order=order, stack_by=stack_by, stack_size=stack_size,
132
- align_legends=align_legends, align_stacks=align_stacks,
133
- legend_spacing=legend_spacing, stack_spacing=stack_spacing,
134
- box_padding=box_padding)
142
+ order=order,
143
+ stack_by=stack_by,
144
+ stack_size=stack_size,
145
+ align_legends=align_legends,
146
+ align_stacks=align_stacks,
147
+ legend_spacing=legend_spacing,
148
+ stack_spacing=stack_spacing,
149
+ box_padding=box_padding,
150
+ )
135
151
 
136
152
  def remove_legends(self):
137
153
  self._draw_legend = False
@@ -150,14 +166,14 @@ class LegendMaker:
150
166
  except Exception:
151
167
  pass
152
168
 
153
- legend_order = self._legend_draw_kws['order']
154
- stack_by = self._legend_draw_kws['stack_by']
155
- stack_size = self._legend_draw_kws['stack_size']
156
- align_legends = self._legend_draw_kws['align_legends']
157
- align_stacks = self._legend_draw_kws['align_stacks']
158
- legend_spacing = self._legend_draw_kws['legend_spacing']
159
- stack_spacing = self._legend_draw_kws['stack_spacing']
160
- box_padding = self._legend_draw_kws['box_padding']
169
+ legend_order = self._legend_draw_kws["order"]
170
+ stack_by = self._legend_draw_kws["stack_by"]
171
+ stack_size = self._legend_draw_kws["stack_size"]
172
+ align_legends = self._legend_draw_kws["align_legends"]
173
+ align_stacks = self._legend_draw_kws["align_stacks"]
174
+ legend_spacing = self._legend_draw_kws["legend_spacing"]
175
+ stack_spacing = self._legend_draw_kws["stack_spacing"]
176
+ box_padding = self._legend_draw_kws["box_padding"]
161
177
 
162
178
  inner, outer = vstack, hstack
163
179
  if stack_by == "row":
@@ -175,8 +191,13 @@ class LegendMaker:
175
191
  for legs in batched(all_legs, stack_size):
176
192
  box = inner(legs, align=align_legends, spacing=legend_spacing)
177
193
  bboxes.append(box)
178
- legend_box = outer(bboxes, align=align_stacks, loc="center left",
179
- spacing=stack_spacing, padding=box_padding)
194
+ legend_box = outer(
195
+ bboxes,
196
+ align=align_stacks,
197
+ loc="center left",
198
+ spacing=stack_spacing,
199
+ padding=box_padding,
200
+ )
180
201
  ax.add_artist(legend_box)
181
202
  # uncomment this to visualize legend ax
182
203
  # from matplotlib.patches import Rectangle
@@ -195,7 +216,7 @@ class LegendMaker:
195
216
  legend_ax = figure.add_axes([0, 0, 1, 1])
196
217
  legends_box = self._legends_drawer(legend_ax)
197
218
  bbox = legends_box.get_window_extent(renderer)
198
- if self._legend_grid_kws['side'] in ["left", "right"]:
219
+ if self._legend_grid_kws["side"] in ["left", "right"]:
199
220
  size = bbox.xmax - bbox.xmin
200
221
  else:
201
222
  size = bbox.ymax - bbox.ymin
@@ -213,22 +234,67 @@ class LegendMaker:
213
234
  class WhiteBoard(LegendMaker):
214
235
  """The base class that handle all rendering process
215
236
 
237
+ Parameters
238
+ ----------
239
+ width : int, optional
240
+ The width of the main canvas in inches
241
+ height : int, optional
242
+ The height of the main canvas in inches
243
+ name : str, optional
244
+ The name of the main canvas
245
+ margin : float, 4-tuple, optional
246
+ The margin of the main canvas in inches
247
+ init_main : bool, optional
248
+ If True, the main canvas will be initialized
249
+
250
+
251
+ See Also
252
+ --------
253
+ :class:`~marsilea.base.ClusterBoard`
254
+
255
+
256
+ Attributes
257
+ ----------
258
+ layout : CrossLayout
259
+ The layout manager
260
+ figure : Figure
261
+ The matplotlib figure object
262
+
263
+ Examples
264
+ --------
265
+ Create a violin plot in white board
266
+
267
+ .. plot::
268
+ :context: close-figs
269
+
270
+ >>> import numpy as np
271
+ >>> import marsilea as ma
272
+ >>> data = np.random.rand(10, 10)
273
+ >>> h = ma.WhiteBoard(height=2)
274
+ >>> h.add_layer(ma.plotter.Violin(data))
275
+ >>> h.render()
276
+
277
+
216
278
  """
279
+
217
280
  layout: CrossLayout
218
281
  figure: Figure = None
219
282
  _row_plan: List[RenderPlan]
220
283
  _col_plan: List[RenderPlan]
221
284
  _layer_plan: List[RenderPlan]
222
285
 
223
- def __init__(self, width=None, height=None, name=None, margin=.2):
286
+ def __init__(self, width=None, height=None, name=None, margin=0.2, init_main=True):
224
287
  self.main_name = get_plot_name(name, "main", "board")
225
288
  self._main_size_updatable = (width is None) & (height is None)
226
289
  width = 4 if width is None else width
227
290
  height = 4 if height is None else height
228
- self.layout = CrossLayout(name=self.main_name,
229
- width=width,
230
- height=height,
231
- margin=margin)
291
+ self.layout = CrossLayout(
292
+ name=self.main_name,
293
+ width=width,
294
+ height=height,
295
+ margin=margin,
296
+ init_main=init_main,
297
+ )
232
298
 
233
299
  # self._side_count = {"right": 0, "left": 0, "top": 0, "bottom": 0}
234
300
  self._col_plan = []
@@ -238,8 +304,27 @@ class WhiteBoard(LegendMaker):
238
304
  self._legend_switch = {}
239
305
  super().__init__()
240
306
 
241
- def add_plot(self, side, plot: RenderPlan, name=None,
242
- size=None, pad=0., legend=True):
307
+ def add_plot(
308
+ self, side, plot: RenderPlan, name=None, size=None, pad=0.0, legend=True
309
+ ):
310
+ """Add a plotter to the board
311
+
312
+ Parameters
313
+ ----------
314
+ side : {"left", "right", "top", "bottom"}
315
+ Which side to add the plotter
316
+ plot : RenderPlan
317
+ The plotter to add
318
+ name : str, optional
319
+ The name of the plot
320
+ size : float, optional
321
+ The size of the plot in inches
322
+ pad : float, optional
323
+ The padding of the plot in inches
324
+ legend : bool, optional
325
+ If True, the legend will be included when calling :meth:`~marsilea.base.LegendMaker.add_legends`
326
+
327
+ """
243
328
  plot_name = get_plot_name(name, side, plot.__class__.__name__)
244
329
  self._legend_switch[plot_name] = legend
245
330
 
@@ -249,7 +334,7 @@ class WhiteBoard(LegendMaker):
249
334
  if plot.size is not None:
250
335
  ax_size = plot.size
251
336
  else:
252
- ax_size = 1.
337
+ ax_size = 1.0
253
338
 
254
339
  self.layout.add_ax(side, name=plot_name, size=ax_size, pad=pad)
255
340
 
@@ -262,20 +347,80 @@ class WhiteBoard(LegendMaker):
262
347
 
263
348
  plan.append(plot)
264
349
 
265
- def add_left(self, plot: RenderPlan, name=None,
266
- size=None, pad=0., legend=True):
350
+ def add_left(self, plot: RenderPlan, name=None, size=None, pad=0.0, legend=True):
351
+ """Add a plotter to the left-side of main canvas
352
+
353
+ Parameters
354
+ ----------
355
+ plot : RenderPlan
356
+ The plotter to add
357
+ name : str, optional
358
+ The name of the plot
359
+ size : float, optional
360
+ The size of the plot in inches
361
+ pad : float, optional
362
+ The padding of the plot in inches
363
+ legend : bool, optional
364
+ If True, the legend will be included when calling :meth:`~marsilea.base.LegendMaker.add_legends`
365
+
366
+ """
267
367
  self.add_plot("left", plot, name, size, pad, legend)
268
368
 
269
- def add_right(self, plot: RenderPlan, name=None,
270
- size=None, pad=0., legend=True):
369
+ def add_right(self, plot: RenderPlan, name=None, size=None, pad=0.0, legend=True):
370
+ """Add a plotter to the right-side of main canvas
371
+
372
+ Parameters
373
+ ----------
374
+ plot : RenderPlan
375
+ The plotter to add
376
+ name : str, optional
377
+ The name of the plot
378
+ size : float, optional
379
+ The size of the plot in inches
380
+ pad : float, optional
381
+ The padding of the plot in inches
382
+ legend : bool, optional
383
+ If True, the legend will be included when calling :meth:`~marsilea.base.LegendMaker.add_legends`
384
+
385
+ """
271
386
  self.add_plot("right", plot, name, size, pad, legend)
272
387
 
273
- def add_top(self, plot: RenderPlan, name=None,
274
- size=None, pad=0., legend=True):
388
+ def add_top(self, plot: RenderPlan, name=None, size=None, pad=0.0, legend=True):
389
+ """Add a plotter to the top-side of main canvas
390
+
391
+ Parameters
392
+ ----------
393
+ plot : RenderPlan
394
+ The plotter to add
395
+ name : str, optional
396
+ The name of the plot
397
+ size : float, optional
398
+ The size of the plot in inches
399
+ pad : float, optional
400
+ The padding of the plot in inches
401
+ legend : bool, optional
402
+ If True, the legend will be included when calling :meth:`~marsilea.base.LegendMaker.add_legends`
403
+
404
+ """
275
405
  self.add_plot("top", plot, name, size, pad, legend)
276
406
 
277
- def add_bottom(self, plot: RenderPlan, name=None,
278
- size=None, pad=0., legend=True):
407
+ def add_bottom(self, plot: RenderPlan, name=None, size=None, pad=0.0, legend=True):
408
+ """Add a plotter to the bottom-side of main canvas
409
+
410
+ Parameters
411
+ ----------
412
+ plot : RenderPlan
413
+ The plotter to add
414
+ name : str, optional
415
+ The name of the plot
416
+ size : float, optional
417
+ The size of the plot in inches
418
+ pad : float, optional
419
+ The padding of the plot in inches
420
+ legend : bool, optional
421
+ If True, the legend will be included when calling :meth:`~marsilea.base.LegendMaker.add_legends`
422
+
423
+ """
279
424
  self.add_plot("bottom", plot, name, size, pad, legend)
280
425
 
281
426
  def _render_plan(self):
@@ -293,14 +438,31 @@ class WhiteBoard(LegendMaker):
293
438
  plan.render(main_ax)
294
439
 
295
440
  def add_layer(self, plot: RenderPlan, zorder=None, name=None, legend=True):
441
+ """Add a plotter to the main canvas
442
+
443
+ .. note::
444
+
445
+ Not every plotter can be added as a layer.
446
+
447
+ Parameters
448
+ ----------
449
+ plot : RenderPlan
450
+ The plotter to add
451
+ zorder : int, optional
452
+ The zorder of the plot
453
+ name : str, optional
454
+ The name of the plot
455
+ legend : bool, optional
456
+ If True, the legend will be included when calling :meth:`~marsilea.base.LegendMaker.add_legends`
457
+
458
+ """
296
459
  if name is None:
297
460
  name = plot.name
298
461
  plot_type = plot.__class__.__name__
299
462
  name = get_plot_name(name, side="main", chart=plot_type)
300
463
  self._legend_switch[name] = legend
301
464
  if not plot.render_main:
302
- msg = f"{plot_type} " \
303
- f"cannot be rendered as another layer."
465
+ msg = f"{plot_type} " f"cannot be rendered as another layer."
304
466
  raise TypeError(msg)
305
467
  if zorder is not None:
306
468
  plot.zorder = zorder
@@ -323,13 +485,57 @@ class WhiteBoard(LegendMaker):
323
485
  return sorted(self._layer_plan, key=lambda p: p.zorder)
324
486
 
325
487
  def add_pad(self, side, size):
488
+ """Add padding to the main canvas
489
+
490
+ Parameters
491
+ ----------
492
+ side : {"left", "right", "top", "bottom"}
493
+ Which side to add padding
494
+ size : float
495
+ The size of padding in inches
496
+
497
+ """
326
498
  self.layout.add_pad(side, size)
327
499
 
328
- def add_canvas(self, side, name, size, pad=0.):
500
+ def add_canvas(self, side, name, size, pad=0.0):
501
+ """Add an axes to the main canvas
502
+
503
+ Parameters
504
+ ----------
505
+ side : {"left", "right", "top", "bottom"}
506
+ Which side to add the axes
507
+ name : str
508
+ The name of the axes
509
+ size : float
510
+ The size of the axes in inches
511
+ pad : float, optional
512
+ The padding of the axes in inches
513
+
514
+ """
329
515
  self.layout.add_ax(side, name, size, pad=pad)
330
516
 
331
- def add_title(self, top=None, bottom=None, left=None, right=None,
332
- pad=0, **props):
517
+ def add_title(self, top=None, bottom=None, left=None, right=None, pad=0, **props):
518
+ """A shortcut to add title to the main canvas
519
+
520
+ Parameters
521
+ ----------
522
+ top : str, optional
523
+ The title of the top side
524
+ bottom : str, optional
525
+ The title of the bottom side
526
+ left : str, optional
527
+ The title of the left side
528
+ right : str, optional
529
+ The title of the right side
530
+ pad : float, optional
531
+ The padding of the title in inches
532
+ props : dict
533
+ The properties of the title
534
+
535
+ Returns
536
+ -------
537
+
538
+ """
333
539
  if left is not None:
334
540
  self.add_plot("left", Title(left, **props), pad=pad)
335
541
  if right is not None:
@@ -363,6 +569,7 @@ class WhiteBoard(LegendMaker):
363
569
  return {}
364
570
 
365
571
  def get_legends(self):
572
+ """Get all legends from the main canvas"""
366
573
  legends = {}
367
574
  legends.update(self._extra_legends())
368
575
  for plan in self._layer_plan + self._col_plan + self._row_plan:
@@ -384,6 +591,7 @@ class WhiteBoard(LegendMaker):
384
591
  return self.append("bottom", other)
385
592
 
386
593
  def append(self, side, other):
594
+ """Append two :class:`~marsilea.base.CrossLayout` together"""
387
595
  compose_board = CompositeBoard(self)
388
596
  compose_board.append(side, other)
389
597
  return compose_board
@@ -396,12 +604,15 @@ class WhiteBoard(LegendMaker):
396
604
  self.layout.set_render_size(plan.name, render_size)
397
605
 
398
606
  def render(self, figure=None, scale=1):
399
- """
607
+ """Finalize the layout and render all plots
400
608
 
401
609
  Parameters
402
610
  ----------
403
- figure
404
- scale
611
+ figure : :class:`~matplotlib.figure.FigureBase`, optional
612
+ The matplotlib figure object
613
+ scale : float, optional
614
+ The scale value of the figure size. You can use this to
615
+ adjust the overall size of the figure
405
616
 
406
617
  Returns
407
618
  -------
@@ -413,10 +624,6 @@ class WhiteBoard(LegendMaker):
413
624
  self._freeze_legend(figure)
414
625
  self._freeze_flex_plots(figure)
415
626
 
416
- # if refreeze:
417
- # self.figure = self.layout.freeze(figure=figure, scale=scale)
418
- # else:
419
- # self.figure = self.layout.figure
420
627
  self.layout.freeze(figure=figure, scale=scale)
421
628
 
422
629
  # render other plots
@@ -424,18 +631,71 @@ class WhiteBoard(LegendMaker):
424
631
  self._render_legend()
425
632
 
426
633
  def save(self, fname, **kwargs):
634
+ """Save the figure to a file
635
+
636
+ This will force a re-render of the figure
637
+
638
+ Parameters
639
+ ----------
640
+ fname : str, path-like
641
+ The file name to save
642
+ kwargs : dict
643
+ Additional options for saving the figure, will be passed to :meth:`~matplotlib.pyplot.savefig`
644
+
645
+ """
427
646
  self.render()
428
647
  save_options = dict(bbox_inches="tight")
429
648
  save_options.update(kwargs)
430
649
  self.figure.savefig(fname, **save_options)
431
650
 
432
- def set_margin(self, margin):
651
+ def set_margin(self, margin: float | tuple[float, float, float, float]):
652
+ """Set margin of the main canvas
653
+
654
+ Parameters
655
+ ----------
656
+ margin : float, 4-tuple
657
+ The margin of the main canvas in inches
658
+
659
+ """
433
660
  self.layout.set_margin(margin)
434
661
 
435
662
 
663
+ class ZeroWidth(WhiteBoard):
664
+ """A utility class to initialize a canvas \
665
+ with zero width
666
+
667
+ This is useful when you try to stack many plots
668
+
669
+ Parameters
670
+ ----------
671
+ height : float
672
+ The
673
+ name : str
674
+ margin : float
675
+
676
+ """
677
+
678
+ def __init__(self, height, name=None, margin=0.2):
679
+ super().__init__(width=0, height=height, name=name,
680
+ margin=margin, init_main=False)
681
+
682
+
683
+ class ZeroHeight(WhiteBoard):
684
+ """A utility class to initialize a canvas \
685
+ with zero height
686
+
687
+ This is useful when you try to stack many plots
688
+
689
+ """
690
+
691
+ def __init__(self, width, name=None, margin=0.2):
692
+ super().__init__(width=width, height=0, name=name,
693
+ margin=margin, init_main=False)
694
+
695
+
436
696
  class CompositeBoard(LegendMaker):
437
- layout: CompositeCrossLayout
438
- figure: Figure
697
+ layout: CompositeCrossLayout = None
698
+ figure: Figure = None
439
699
 
440
700
  def __init__(self, main_board: WhiteBoard):
441
701
  self.main_board = self.new_board(main_board)
@@ -487,8 +747,9 @@ class CompositeBoard(LegendMaker):
487
747
  save_options.update(kwargs)
488
748
  self.figure.savefig(fname, **save_options)
489
749
  else:
490
- warnings.warn("Figure does not exist, "
491
- "please render it before saving as file.")
750
+ warnings.warn(
751
+ "Figure does not exist, " "please render it before saving as file."
752
+ )
492
753
 
493
754
  def get_legends(self):
494
755
  legends = {}
@@ -504,6 +765,31 @@ class CompositeBoard(LegendMaker):
504
765
 
505
766
 
506
767
  class ClusterBoard(WhiteBoard):
768
+ """A main canvas class that can handle cluster data
769
+
770
+ Parameters
771
+ ----------
772
+ cluster_data : ndarray
773
+ The cluster data
774
+ width : int, optional
775
+ The width of the main canvas in inches
776
+ height : int, optional
777
+ The height of the main canvas in inches
778
+ name : str, optional
779
+ The name of the main canvas
780
+ margin : float, 4-tuple, optional
781
+ The margin of the main canvas in inches
782
+ init_main : bool, optional
783
+ If True, the main canvas will be initialized
784
+
785
+
786
+ See Also
787
+ --------
788
+ :class:`~marsilea.base.WhiteBoard`
789
+
790
+
791
+ """
792
+
507
793
  _row_reindex: List[int] = None
508
794
  _col_reindex: List[int] = None
509
795
  # If cluster data need to be defined by user
@@ -511,21 +797,44 @@ class ClusterBoard(WhiteBoard):
511
797
  _split_col: bool = False
512
798
  _split_row: bool = False
513
799
  _mesh = None
514
- square = False
515
800
 
516
- def __init__(self, cluster_data, width=None, height=None,
517
- name=None, margin=.2):
518
- super().__init__(width=width, height=height, name=name, margin=margin)
801
+ def __init__(
802
+ self,
803
+ cluster_data,
804
+ width=None,
805
+ height=None,
806
+ name=None,
807
+ margin=0.2,
808
+ init_main=True,
809
+ ):
810
+ super().__init__(
811
+ width=width, height=height, name=name, margin=margin, init_main=init_main
812
+ )
519
813
  self._row_den = []
520
814
  self._col_den = []
521
815
  self._cluster_data = cluster_data
522
816
  self._deform = Deformation(cluster_data)
523
817
 
524
- def add_dendrogram(self, side, method=None, metric=None, linkage=None,
525
- add_meta=True, add_base=True, add_divider=True,
526
- meta_color=None, linewidth=None, colors=None,
527
- divider_style="--", meta_ratio=.2,
528
- show=True, name=None, size=0.5, pad=0., get_meta_center=None):
818
+ def add_dendrogram(
819
+ self,
820
+ side,
821
+ method=None,
822
+ metric=None,
823
+ linkage=None,
824
+ add_meta=True,
825
+ add_base=True,
826
+ add_divider=True,
827
+ meta_color=None,
828
+ linewidth=None,
829
+ colors=None,
830
+ divider_style="--",
831
+ meta_ratio=0.2,
832
+ show=True,
833
+ name=None,
834
+ size=0.5,
835
+ pad=0.0,
836
+ get_meta_center=None,
837
+ ):
529
838
  """Run cluster and add dendrogram
530
839
 
531
840
  .. note::
@@ -619,15 +928,16 @@ class ClusterBoard(WhiteBoard):
619
928
 
620
929
  """
621
930
  if not self._allow_cluster:
622
- msg = f"Please specify cluster data when initialize " \
623
- f"'{self.__class__.__name__}' class."
931
+ msg = (
932
+ f"Please specify cluster data when initialize "
933
+ f"'{self.__class__.__name__}' class."
934
+ )
624
935
  raise ValueError(msg)
625
936
  plot_name = get_plot_name(name, side, "Dendrogram")
626
937
 
627
938
  # if only colors is passed
628
939
  # the color should be applied to all
629
- if (colors is not None) & (is_color_like(colors)) & (
630
- meta_color is None):
940
+ if (colors is not None) & (is_color_like(colors)) & (meta_color is None):
631
941
  meta_color = colors
632
942
 
633
943
  # if nothing is added
@@ -638,27 +948,90 @@ class ClusterBoard(WhiteBoard):
638
948
  if show:
639
949
  self.layout.add_ax(side, name=plot_name, size=size, pad=pad)
640
950
 
641
- den_options = dict(name=plot_name, show=show, side=side,
642
- add_meta=add_meta, add_base=add_base,
643
- add_divider=add_divider, meta_color=meta_color,
644
- linewidth=linewidth, colors=colors,
645
- divider_style=divider_style, meta_ratio=meta_ratio)
951
+ den_options = dict(
952
+ name=plot_name,
953
+ show=show,
954
+ side=side,
955
+ add_meta=add_meta,
956
+ add_base=add_base,
957
+ add_divider=add_divider,
958
+ meta_color=meta_color,
959
+ linewidth=linewidth,
960
+ colors=colors,
961
+ divider_style=divider_style,
962
+ meta_ratio=meta_ratio,
963
+ )
646
964
 
647
965
  deform = self.get_deform()
648
966
  if side in ["right", "left"]:
649
- den_options['pos'] = "row"
967
+ den_options["pos"] = "row"
650
968
  self._row_den.append(den_options)
651
- deform.set_cluster(row=True, method=method, metric=metric,
652
- linkage=linkage, use_meta=add_meta,
653
- get_meta_center=get_meta_center)
969
+ deform.set_cluster(
970
+ row=True,
971
+ method=method,
972
+ metric=metric,
973
+ linkage=linkage,
974
+ use_meta=add_meta,
975
+ get_meta_center=get_meta_center,
976
+ )
654
977
  else:
655
- den_options['pos'] = "col"
978
+ den_options["pos"] = "col"
656
979
  self._col_den.append(den_options)
657
- deform.set_cluster(col=True, method=method, metric=metric,
658
- linkage=linkage, use_meta=add_meta,
659
- get_meta_center=get_meta_center)
660
-
980
+ deform.set_cluster(
981
+ col=True,
982
+ method=method,
983
+ metric=metric,
984
+ linkage=linkage,
985
+ use_meta=add_meta,
986
+ get_meta_center=get_meta_center,
987
+ )
988
+
661
989
  def hsplit(self, cut=None, labels=None, order=None, spacing=0.01):
990
+ """Split the main canvas horizontally
991
+
992
+ .. deprecated:: 0.5.0
993
+ Use :meth:`~marsilea.base.ClusterBoard.cut_rows` \
994
+ or :meth:`~marsilea.base.ClusterBoard.group_rows` instead
995
+
996
+ Parameters
997
+ ----------
998
+ cut : array-like, optional
999
+ The index of your data to specify where to split the canvas
1000
+ labels : array-like, optional
1001
+ The labels of your data, must be the same length as the data
1002
+ order : array-like, optional
1003
+ The order of the unique labels
1004
+ spacing : float, optional
1005
+ The spacing between each split chunks, default is 0.01
1006
+
1007
+ Examples
1008
+ --------
1009
+ Split the canvas by the unique labels
1010
+
1011
+ .. plot::
1012
+ :context: close-figs
1013
+
1014
+ >>> data = np.random.rand(10, 11)
1015
+ >>> import marsilea as ma
1016
+ >>> h = ma.Heatmap(data)
1017
+ >>> labels = ["A", "B", "C", "A", "B", "C", "A", "B", "C", "A"]
1018
+ >>> h.hsplit(labels=labels, order=["A", "B", "C"])
1019
+ >>> h.add_left(ma.plotter.Labels(labels), pad=.1)
1020
+ >>> h.render()
1021
+
1022
+
1023
+ Split the canvas by the index
1024
+
1025
+ .. plot::
1026
+ :context: close-figs
1027
+
1028
+ >>> h = ma.Heatmap(data)
1029
+ >>> h.hsplit(cut=[4, 8])
1030
+ >>> h.render()
1031
+
1032
+
1033
+ """
1034
+ warnings.warn(DeprecationWarning("`hsplit` will be deprecated in v0.5.0, use `cut_rows` or `group_rows` instead"))
662
1035
  if self._split_row:
663
1036
  raise SplitTwice(axis="horizontally")
664
1037
  self._split_row = True
@@ -677,6 +1050,51 @@ class ClusterBoard(WhiteBoard):
677
1050
  deform.set_split_row(breakpoints=breakpoints, order=order)
678
1051
 
679
1052
  def vsplit(self, cut=None, labels=None, order=None, spacing=0.01):
1053
+ """Split the main canvas vertically
1054
+
1055
+ .. deprecated:: 0.5.0
1056
+ Use :meth:`~marsilea.base.ClusterBoard.cut_cols` \
1057
+ or :meth:`~marsilea.base.ClusterBoard.group_cols` instead
1058
+
1059
+ Parameters
1060
+ ----------
1061
+ cut : array-like, optional
1062
+ The index of your data to specify where to split the canvas
1063
+ labels : array-like, optional
1064
+ The labels of your data, must be the same length as the data
1065
+ order : array-like, optional
1066
+ The order of the unique labels
1067
+ spacing : float, optional
1068
+ The spacing between each split chunks, default is 0.01
1069
+
1070
+ Examples
1071
+ --------
1072
+ Split the canvas by the unique labels
1073
+
1074
+ .. plot::
1075
+ :context: close-figs
1076
+
1077
+ >>> data = np.random.rand(10, 11)
1078
+ >>> import marsilea as ma
1079
+ >>> h = ma.Heatmap(data)
1080
+ >>> labels = ["A", "B", "C", "A", "B", "C", "A", "B", "C", "A", "B"]
1081
+ >>> h.vsplit(labels=labels, order=["A", "B", "C"])
1082
+ >>> h.add_top(ma.plotter.Labels(labels), pad=.1)
1083
+ >>> h.render()
1084
+
1085
+
1086
+ Split the canvas by the index
1087
+
1088
+ .. plot::
1089
+ :context: close-figs
1090
+
1091
+ >>> h = ma.Heatmap(data)
1092
+ >>> h.vsplit(cut=[4, 8])
1093
+ >>> h.render()
1094
+
1095
+
1096
+ """
1097
+ warnings.warn(DeprecationWarning("`vsplit` will be deprecated in v0.5.0, use `cut_cols` or `group_cols` instead"))
680
1098
  if self._split_col:
681
1099
  raise SplitTwice(axis="vertically")
682
1100
  self._split_col = True
@@ -694,6 +1112,154 @@ class ClusterBoard(WhiteBoard):
694
1112
  breakpoints = get_breakpoints(labels[reindex])
695
1113
  deform.set_split_col(breakpoints=breakpoints, order=order)
696
1114
 
1115
+ def group_rows(self, group, order=None, spacing=0.01):
1116
+ """Group rows into chunks
1117
+
1118
+ Parameters
1119
+ ----------
1120
+ group : array-like
1121
+ The group of each row
1122
+ order : array-like, optional
1123
+ The order of the unique groups
1124
+ spacing : float, optional
1125
+ The spacing between each split chunks, default is 0.01
1126
+
1127
+ Examples
1128
+ --------
1129
+ Group rows by the unique labels
1130
+
1131
+ .. plot::
1132
+ :context: close-figs
1133
+
1134
+ >>> data = np.random.rand(10, 11)
1135
+ >>> import marsilea as ma
1136
+ >>> h = ma.Heatmap(data)
1137
+ >>> labels = ["A", "B", "C", "A", "B", "C", "A", "B", "C", "A"]
1138
+ >>> h.group_rows(labels, order=["A", "B", "C"])
1139
+ >>> h.add_left(ma.plotter.Labels(labels), pad=.1)
1140
+ >>> h.render()
1141
+
1142
+ """
1143
+ if self._split_row:
1144
+ raise SplitTwice(axis="rows")
1145
+ self._split_row = True
1146
+
1147
+ deform = self.get_deform()
1148
+ deform.hspace = spacing
1149
+
1150
+ labels = np.asarray(group)
1151
+ reindex, order = reorder_index(labels, order=order)
1152
+ deform.set_data_row_reindex(reindex)
1153
+
1154
+ breakpoints = get_breakpoints(labels[reindex])
1155
+ deform.set_split_row(breakpoints=breakpoints, order=order)
1156
+
1157
+ def group_cols(self, group, order=None, spacing=0.01):
1158
+ """Group columns into chunks
1159
+
1160
+ Parameters
1161
+ ----------
1162
+ group : array-like
1163
+ The group of each column
1164
+ order : array-like, optional
1165
+ The order of the unique groups
1166
+ spacing : float, optional
1167
+ The spacing between each split chunks, default is 0.01
1168
+
1169
+ Examples
1170
+ --------
1171
+ Group columns by the unique labels
1172
+
1173
+ .. plot::
1174
+ :context: close-figs
1175
+
1176
+ >>> data = np.random.rand(11, 10)
1177
+ >>> import marsilea as ma
1178
+ >>> h = ma.Heatmap(data)
1179
+ >>> labels = ["A", "B", "C", "A", "B", "C", "A", "B", "C", "A"]
1180
+ >>> h.group_cols(labels, order=["A", "B", "C"])
1181
+ >>> h.add_top(ma.plotter.Labels(labels), pad=.1)
1182
+ >>> h.render()
1183
+
1184
+ """
1185
+ if self._split_col:
1186
+ raise SplitTwice(axis="columns")
1187
+ self._split_col = True
1188
+
1189
+ deform = self.get_deform()
1190
+ deform.wspace = spacing
1191
+
1192
+ labels = np.asarray(group)
1193
+ reindex, order = reorder_index(labels, order=order)
1194
+ deform.set_data_col_reindex(reindex)
1195
+
1196
+ breakpoints = get_breakpoints(labels[reindex])
1197
+ deform.set_split_col(breakpoints=breakpoints, order=order)
1198
+
1199
+ def cut_rows(self, cut, spacing=0.01):
1200
+ """Cut the main canvas by rows
1201
+
1202
+ Parameters
1203
+ ----------
1204
+ cut : array-like
1205
+ The index of your data to specify where to cut the canvas
1206
+ spacing : float, optional
1207
+ The spacing between each cut, default is 0.01
1208
+
1209
+ Examples
1210
+ --------
1211
+ Cut the canvas by the index
1212
+
1213
+ .. plot::
1214
+ :context: close-figs
1215
+
1216
+ >>> data = np.random.rand(10, 11)
1217
+ >>> import marsilea as ma
1218
+ >>> h = ma.Heatmap(data)
1219
+ >>> h.cut_rows([4, 8])
1220
+ >>> h.render()
1221
+
1222
+ """
1223
+ if self._split_row:
1224
+ raise SplitTwice(axis="horizontally")
1225
+ self._split_row = True
1226
+
1227
+ deform = self.get_deform()
1228
+ deform.hspace = spacing
1229
+ deform.set_split_row(breakpoints=cut)
1230
+
1231
+ def cut_cols(self, cut, spacing=0.01):
1232
+ """Cut the main canvas by columns
1233
+
1234
+ Parameters
1235
+ ----------
1236
+ cut : array-like
1237
+ The index of your data to specify where to cut the canvas
1238
+ spacing : float, optional
1239
+ The spacing between each cut, default is 0.01
1240
+
1241
+ Examples
1242
+ --------
1243
+ Cut the canvas by the index
1244
+
1245
+ .. plot::
1246
+ :context: close-figs
1247
+
1248
+ >>> data = np.random.rand(10, 11)
1249
+ >>> import marsilea as ma
1250
+ >>> h = ma.Heatmap(data)
1251
+ >>> h.cut_cols([4, 8])
1252
+ >>> h.render()
1253
+
1254
+ """
1255
+ if self._split_col:
1256
+ raise SplitTwice(axis="vertically")
1257
+ self._split_col = True
1258
+
1259
+ deform = self.get_deform()
1260
+ deform.wspace = spacing
1261
+ deform.set_split_col(breakpoints=cut)
1262
+
697
1263
  def _setup_axes(self):
698
1264
  deform = self.get_deform()
699
1265
  w_ratios = deform.col_ratios
@@ -716,8 +1282,7 @@ class ClusterBoard(WhiteBoard):
716
1282
  group_ratios = None
717
1283
  else:
718
1284
  group_ratios = plan.get_split_regroup()
719
- self.layout.vsplit(plan.name, w_ratios, wspace,
720
- group_ratios)
1285
+ self.layout.vsplit(plan.name, w_ratios, wspace, group_ratios)
721
1286
 
722
1287
  # split row axes
723
1288
  if deform.is_row_split:
@@ -727,38 +1292,40 @@ class ClusterBoard(WhiteBoard):
727
1292
  group_ratios = None
728
1293
  else:
729
1294
  group_ratios = plan.get_split_regroup()
730
- self.layout.hsplit(plan.name, h_ratios, hspace,
731
- group_ratios)
1295
+ self.layout.hsplit(plan.name, h_ratios, hspace, group_ratios)
732
1296
 
733
1297
  def _render_dendrogram(self):
734
1298
  deform = self.get_deform()
735
- for den in (self._row_den + self._col_den):
736
- if den['show']:
737
- ax = self.layout.get_ax(den['name'])
1299
+ for den in self._row_den + self._col_den:
1300
+ if den["show"]:
1301
+ ax = self.layout.get_ax(den["name"])
738
1302
  ax.set_axis_off()
739
1303
  spacing = deform.hspace
740
1304
  den_obj = deform.get_row_dendrogram()
741
- if den['pos'] == "col":
1305
+ if den["pos"] == "col":
742
1306
  spacing = deform.wspace
743
1307
  den_obj = deform.get_col_dendrogram()
744
1308
  if isinstance(den_obj, Dendrogram):
745
- color = den['colors']
1309
+ color = den["colors"]
746
1310
  if (color is not None) & (not is_color_like(color)):
747
1311
  color = color[0]
748
- den_obj.draw(ax, orient=den['side'], color=color,
749
- linewidth=den['linewidth'])
1312
+ den_obj.draw(
1313
+ ax, orient=den["side"], color=color, linewidth=den["linewidth"]
1314
+ )
750
1315
  else:
751
- den_obj.draw(ax, orient=den['side'],
752
- spacing=spacing,
753
- add_meta=den['add_meta'],
754
- add_base=den['add_base'],
755
- base_colors=den['colors'],
756
- meta_color=den['meta_color'],
757
- linewidth=den['linewidth'],
758
- divide=den['add_divider'],
759
- divide_style=den['divider_style'],
760
- meta_ratio=den['meta_ratio']
761
- )
1316
+ den_obj.draw(
1317
+ ax,
1318
+ orient=den["side"],
1319
+ spacing=spacing,
1320
+ add_meta=den["add_meta"],
1321
+ add_base=den["add_base"],
1322
+ base_colors=den["colors"],
1323
+ meta_color=den["meta_color"],
1324
+ linewidth=den["linewidth"],
1325
+ divide=den["add_divider"],
1326
+ divide_style=den["divider_style"],
1327
+ meta_ratio=den["meta_ratio"],
1328
+ )
762
1329
 
763
1330
  def _render_plan(self):
764
1331
  deform = self.get_deform()
@@ -781,20 +1348,37 @@ class ClusterBoard(WhiteBoard):
781
1348
  plan.render(main_ax)
782
1349
 
783
1350
  def get_deform(self):
1351
+ """Return the deformation object of the cluster data"""
784
1352
  return self._deform
785
1353
 
786
1354
  def get_row_linkage(self):
1355
+ """Return the linkage matrix of row dendrogram
1356
+
1357
+ If the canvas is not split, the linkage matrix will be returned;
1358
+ otherwise, a dictionary of linkage matrix will be returned, the key is either
1359
+ index or the name of each chunk.
1360
+
1361
+ """
787
1362
  return self._deform.get_row_linkage()
788
1363
 
789
1364
  def get_col_linkage(self):
1365
+ """Return the linkage matrix of column dendrogram
1366
+
1367
+ If the canvas is not split, the linkage matrix will be returned;
1368
+ otherwise, a dictionary of linkage matrix will be returned, the key is either
1369
+ index or the name of each chunk.
1370
+
1371
+ """
790
1372
  return self._deform.get_col_linkage()
791
1373
 
792
1374
  @property
793
1375
  def row_cluster(self):
1376
+ """If row dendrogram is added"""
794
1377
  return len(self._row_den) > 0
795
1378
 
796
1379
  @property
797
1380
  def col_cluster(self):
1381
+ """If column dendrogram is added"""
798
1382
  return len(self._col_den) > 0
799
1383
 
800
1384
  def render(self, figure=None, scale=1):