marsilea 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,98 +2,292 @@
2
2
  # issues across different rendering backend at different DPI. Currently
3
3
  # not a public API.
4
4
  from functools import partial
5
+ from hashlib import sha256
6
+ from numbers import Number
7
+ from pathlib import Path
5
8
 
6
9
  import numpy as np
10
+ from PIL import Image as PILImage
7
11
  from matplotlib.image import imread, BboxImage
8
12
  from matplotlib.transforms import Bbox
9
- from pathlib import Path
10
13
  from platformdirs import user_cache_dir
11
- from urllib.request import urlretrieve
12
14
 
13
15
  from .base import RenderPlan
14
16
 
15
- TWEMOJI_CDN = "https://cdn.jsdelivr.net/gh/twitter/twemoji/assets/72x72/"
16
-
17
17
 
18
18
def _cache_remote(url, cache=True):
    """Download ``url`` into the Marsilea user cache directory.

    Parameters
    ----------
    url : str
        The remote resource to fetch.
    cache : bool, default: True
        When True, reuse a previously downloaded copy if one exists.

    Returns
    -------
    pathlib.Path
        Path of the cached file. The file name is the SHA-256 hex digest
        of ``url``, so two URLs that share a final path component do not
        overwrite each other.

    Raises
    ------
    ImportError
        If ``requests`` is not installed.
    requests.HTTPError
        If the server answers with an error status.
    """
    try:
        import requests
    except ImportError as err:
        raise ImportError("Required requests, try `pip install requests`.") from err

    data_dir = Path(user_cache_dir(appname="Marsilea"))
    data_dir.mkdir(exist_ok=True, parents=True)

    fname = sha256(url.encode("utf-8")).hexdigest()
    dest = data_dir / fname

    if not (cache and dest.exists()):
        # Without a timeout, an unresponsive server would block forever.
        r = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=30)
        r.raise_for_status()
        # Write to a sibling temp file, then atomically move it into place,
        # so an interrupted write never leaves a truncated file behind that
        # later calls would happily reuse as a valid cache entry.
        tmp = dest.with_suffix(".part")
        tmp.write_bytes(r.content)
        tmp.replace(dest)

    return dest
28
38
 
29
39
 
30
- class Emoji(RenderPlan):
31
- def __init__(self, images, lang="en", scale=1, mode="filled"):
32
- try:
33
- import emoji
34
- except ImportError:
35
- raise ImportError("Required emoji, try `pip install emoji`.")
40
class Image(RenderPlan):
    """Plot static images

    Parameters
    ----------
    images : array of image
        You can input either path to the image, URL, or numpy array.
        Strings starting with ``http://`` or ``https://`` are downloaded
        (and cached); other strings and :class:`pathlib.Path` objects are
        read from disk; anything else is converted with
        :func:`numpy.asarray`.
    align : {"center", "top", "bottom", "left", "right"}, default: "center"
        The alignment of the images. "top"/"bottom" take effect when the
        images are spread along the x axis, "left"/"right" when they are
        spread along the y axis; any other value centers them.
    scale : float, default: 1
        The scale of the images.
    spacing : float, default: 0.1
        The spacing between images, a value between 0 and 1,
        relative to the image container size.
    resize : int or tuple, default: None
        Resize every image before plotting. An int gives a square size;
        a tuple is ``(width, height)`` in pixels, as accepted by
        :meth:`PIL.Image.Image.resize`.

    Examples
    --------

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> import marsilea as ma
        >>> c = ma.ZeroWidth(height=2)
        >>> c.add_right(
        ...     ma.plotter.Image(
        ...         [
        ...             "https://www.iconfinder.com/icons/4375050/download/png/512",
        ...             "https://www.iconfinder.com/icons/8666426/download/png/512",
        ...             "https://www.iconfinder.com/icons/652581/download/png/512",
        ...         ],
        ...         align="right",
        ...     ),
        ...     pad=0.1,
        ... )
        >>> c.add_right(
        ...     ma.plotter.Labels(["Python", "Rust", "JavaScript"], fontsize=20), pad=0.1
        ... )
        >>> c.render()

    """

    def __init__(
        self,
        images,
        align="center",
        scale=1,
        spacing=0.1,
        resize=None,
    ):
        # Validate cheap arguments before any (possibly remote) image is loaded.
        if not 0 <= spacing <= 1:
            raise ValueError("spacing should be between 0 and 1")
        if isinstance(resize, Number):
            resize = (int(resize), int(resize))

        self.images_mapper = {}
        for i, img in enumerate(images):
            img = self._load_image(img)
            if resize is not None:
                img = self._resize_image(img, resize)
            self.images_mapper[i] = img

        # The render data is the image indices; render_ax maps them back.
        self.images_codes = np.asarray(list(self.images_mapper.keys()))

        self.images = images
        self.align = align
        self.scale = scale
        self.spacing = spacing
        self.set_data(self.images_codes)

    @staticmethod
    def _load_image(img):
        """Turn one user-supplied image (URL, path or array-like) into an array."""
        if isinstance(img, str):
            # Only real URL schemes; a local file named "http_logo.png"
            # must still be read from disk.
            if img.startswith(("http://", "https://")):
                return imread(_cache_remote(img))
            return imread(img)
        if isinstance(img, Path):
            return imread(img)
        return np.asarray(img)

    @staticmethod
    def _resize_image(img, size):
        """Resize an image array to ``size`` (width, height) with LANCZOS.

        numpy arrays have no PIL-style ``resize``, so the array is converted
        to a PIL image first and back to an array afterwards.
        """
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # matplotlib's imread returns PNGs as floats in [0, 1];
            # PIL.Image.fromarray needs 8-bit data for RGB/RGBA images.
            arr = np.round(np.clip(arr, 0, 1) * 255).astype(np.uint8)
        elif arr.dtype != np.uint8:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        resized = PILImage.fromarray(arr).resize(
            tuple(size), PILImage.Resampling.LANCZOS
        )
        return np.asarray(resized)

    @staticmethod
    def _img_bbox(renderer, ax, x, y, width, height, base_dpi):
        """Bbox callback for :class:`~matplotlib.image.BboxImage`.

        Sizes were computed at the figure DPI; they are rescaled by the
        ratio of the renderer DPI to that figure DPI.
        """
        factor = renderer.dpi / base_dpi
        x0, y0 = ax.transData.transform((x, y))
        return Bbox.from_bounds(x0, y0, width * factor, height * factor)

    def _get_bbox_imges(
        self, ax, imgs, scale=None, align=None, ax_height=None, ax_width=None
    ):
        """Lay out ``imgs`` in evenly spaced slots on ``ax``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes providing the transforms.
        imgs : list of array
            Images to place, in data order.
        scale, align : optional
            Default to ``self.scale`` / ``self.align`` when None.
        ax_height, ax_width : float, optional
            Container size; when None or 0 the size of ``ax`` in display
            pixels is used.

        Returns
        -------
        (list of BboxImage, float)
            The artists, and the largest scaled image extent across the
            stacking direction, in the same units as ``ax_width`` /
            ``ax_height``. ``([], 0)`` when ``imgs`` is empty.
        """
        if scale is None:
            scale = self.scale
        if align is None:
            align = self.align
        if len(imgs) == 0:
            return [], 0

        locs = np.linspace(0, 1, len(imgs) + 1)
        slot_size = locs[1] - locs[0]
        # Shift each slot start by half the spacing so images are centered.
        locs = locs[:-1] + slot_size * self.spacing / 2

        xmin, ymin = ax.transAxes.transform((0, 0))
        xmax, ymax = ax.transAxes.transform((1, 1))
        if ax_width is None or ax_width == 0:
            ax_width = xmax - xmin
        if ax_height is None or ax_height == 0:
            ax_height = ymax - ymin

        base_dpi = ax.get_figure().get_dpi()

        bbox_images = []
        images_sizes = []

        def add_image(img, x, y, width, height):
            bbox = partial(
                self._img_bbox,
                ax=ax,
                x=x,
                y=y,
                width=width,
                height=height,
                base_dpi=base_dpi,
            )
            bbox_images.append(BboxImage(bbox, data=img))

        if self.is_body:
            # Images spread along x; each fills its slot width.
            fit_width = ax_width / len(imgs) * (1 - self.spacing)
            for loc, img in zip(locs, imgs):
                height, width = img.shape[:2]
                fit_height = height / width * fit_width
                scaled_w = fit_width * scale
                scaled_h = fit_height * scale

                x = loc + (fit_width - scaled_w) / 2 / ax_width
                if align == "top":
                    y = 1 - scaled_h / ax_height
                elif align == "bottom":
                    y = 0
                else:
                    y = 0.5 - scaled_h / 2 / ax_height

                add_image(img, x, y, scaled_w, scaled_h)
                images_sizes.append(scaled_h)
        else:
            # Images spread along y; reversed so the first image is on top.
            fit_height = ax_height / len(imgs) * (1 - self.spacing)
            for loc, img in zip(locs, imgs[::-1]):
                height, width = img.shape[:2]
                fit_width = width / height * fit_height
                scaled_w = fit_width * scale
                scaled_h = fit_height * scale

                y = loc + (fit_height - scaled_h) / 2 / ax_height
                if align == "right":
                    x = 1 - scaled_w / ax_width
                elif align == "left":
                    x = 0
                else:
                    x = 0.5 - scaled_w / 2 / ax_width

                add_image(img, x, y, scaled_w, scaled_h)
                images_sizes.append(scaled_w)

        return bbox_images, max(images_sizes)

    def render_ax(self, spec):
        ax = spec.ax
        imgs = [self.images_mapper[i] for i in spec.data]
        # Pass scale/align explicitly: they were previously dropped here.
        bbox_images, _ = self._get_bbox_imges(
            ax, imgs, scale=self.scale, align=self.align
        )
        for artist in bbox_images:
            ax.add_artist(artist)
        ax.set_axis_off()

    def get_canvas_size(
        self, figure, main_height=None, main_width=None, **kwargs
    ) -> float:
        # A throw-away axes supplies the transforms needed to measure images;
        # it is removed even if measuring fails.
        ax = figure.add_subplot(111)
        try:
            imgs = [self.images_mapper[i] for i in self.images_codes]
            _, size = self._get_bbox_imges(
                ax,
                imgs,
                scale=self.scale,
                align=self.align,
                ax_width=main_width,
                ax_height=main_height,
            )
        finally:
            ax.remove()
        return size
68
241
 
69
- xmin, ymin = ax.transAxes.transform((0, 0))
70
- xmax, ymax = ax.transAxes.transform((1, 1))
71
242
 
72
- ax_width = xmax - xmin
73
- ax_height = ymax - ymin
243
# ======== EMOJI Plotter ========

TWEMOJI_CDN = "https://cdn.jsdelivr.net/gh/twitter/twemoji/assets/72x72/"


def _twemoji_name(char):
    """Return the twemoji asset name (without ``.png``) for one emoji string.

    This follows twemoji's own file-naming rule:
    1. Drop the U+FE0F variation selector, unless the sequence contains a
       U+200D zero-width joiner.
    2. Join the lowercase hex code points with "-".

    For a single code point this is just its lowercase hex value.
    """
    if "\u200d" not in char:
        char = char.replace("\ufe0f", "")
    return "-".join(f"{ord(c):x}" for c in char)


class Emoji(Image):
    """Have fun with emoji images

    The emoji images are from `twemoji <https://twemoji.twitter.com/>`_.

    You can find all twemoji `here <https://twemoji-cheatsheet.vercel.app/>`_.

    Parameters
    ----------
    codes : array of str
        The emoji codes. You can input either unicode or short code.
        Each element is one emoji. To use an emoji made of several code
        points (e.g. flags, skin tones, joined sequences), pass a list of
        strings; iterating a plain string splits it into single characters.
    lang : str, default: "en"
        The language of the emoji.
    scale : float, default: 1
        The scale of the emoji.
    spacing : float, default: 0.1
        The spacing between emoji, a value between 0 and 1,
        relative to the emoji container size.
    **kwargs
        Passed on to :class:`Image`, e.g. ``align`` or ``resize``.

    Examples
    --------

    .. plot::

        >>> import marsilea as ma
        >>> c = ma.ZeroHeight(width=2)
        >>> c.add_top(ma.plotter.Emoji("😆😆🤣😂😉😇🐍🦀🦄"))
        >>> c.render()

    """

    def __init__(self, codes, lang="en", scale=1, spacing=0.1, **kwargs):
        try:
            import emoji
        except ImportError as err:
            raise ImportError("Required emoji, try `pip install emoji`.") from err

        urls = []
        for code in codes:
            char = emoji.emojize(code, language=lang)
            if not emoji.is_emoji(char):
                raise ValueError(f"{char} is not a valid emoji")
            # ord() would fail on multi-code-point emoji such as "❤️";
            # build the twemoji file name from every code point instead.
            urls.append(f"{TWEMOJI_CDN}{_twemoji_name(char)}.png")

        super().__init__(urls, scale=scale, spacing=spacing, **kwargs)
@@ -195,7 +195,7 @@ def _seaborn_doc(obj: _SeabornBase):
195
195
  >>> sdata = {sdata}
196
196
  >>> plot = {cls_name}(sdata, {kws})
197
197
  >>> h = ma.Heatmap(data)
198
- >>> h.hsplit(cut=[3, 7])
198
+ >>> h.cut_rows(cut=[3, 7])
199
199
  >>> h.add_right(plot)
200
200
  >>> h.render()
201
201
  """
@@ -208,7 +208,7 @@ def _seaborn_doc(obj: _SeabornBase):
208
208
 
209
209
  >>> plot = {cls_name}({hue_data}, {kws})
210
210
  >>> h = ma.Heatmap(data)
211
- >>> h.hsplit(cut=[3, 7])
211
+ >>> h.cut_rows(cut=[3, 7])
212
212
  >>> h.add_right(plot)
213
213
  >>> h.render()
214
214
 
@@ -222,7 +222,7 @@ def _seaborn_doc(obj: _SeabornBase):
222
222
  >>> anno = ma.plotter.Chunk(['C1', 'C2', 'C3'], colors, padding=10)
223
223
  >>> cb = ma.ClusterBoard(data, height=2, margin=.5)
224
224
  >>> cb.add_layer(plot)
225
- >>> cb.vsplit(cut=[3, 7])
225
+ >>> cb.cut_cols([3, 7])
226
226
  >>> cb.add_bottom(anno)
227
227
  >>> cb.render()
228
228
 
@@ -233,9 +233,10 @@ def _seaborn_doc(obj: _SeabornBase):
233
233
 
234
234
  >>> plot = {cls_name}(sdata, orient='h',
235
235
  ... {h_kws})
236
+ >>> anno = ma.plotter.Chunk(['C1', 'C2', 'C3'], colors, padding=10)
236
237
  >>> cb = ma.ClusterBoard(data.T, width=2)
237
238
  >>> cb.add_layer(plot)
238
- >>> cb.hsplit(cut=[3, 7])
239
+ >>> cb.cut_rows([3, 7])
239
240
  >>> cb.add_left(anno)
240
241
  >>> cb.render()
241
242
 
marsilea/plotter/arc.py CHANGED
@@ -141,8 +141,7 @@ class Arc(StatsBase):
141
141
  >>> colors = ["C0", "C1", "C2", "C3", "C4", "C5", "C6"]
142
142
  >>> labels = ["A", "B", "C", "D", "E", "F", "G"]
143
143
  >>> h = ma.Heatmap(np.random.rand(10, 10))
144
- >>> h.add_top(Arc(anchors, links, weights=weights,
145
- ... colors=colors, labels=labels))
144
+ >>> h.add_top(Arc(anchors, links, weights=weights, colors=colors, labels=labels))
146
145
  >>> h.render()
147
146
 
148
147
 
marsilea/plotter/bar.py CHANGED
@@ -276,13 +276,13 @@ class CenterBar(_BarBase):
276
276
  class StackBar(_BarBase):
277
277
  """Stacked Bar
278
278
 
279
- Parameters
280
- ----------
281
- data : np.ndarray, pd.DataFrame
279
+ Parameters
280
+ ----------
281
+ data : np.ndarray, pd.DataFrame
282
282
  2D data, index of dataframe is used as the name of items.
283
- items : list of str
283
+ items : list of str
284
284
  The name of items.
285
- colors : list of colors, mapping of (item, color)
285
+ colors : list of colors, mapping of (item, color)
286
286
  The colors of the bar for each item.
287
287
  orient : {"v", "h"}
288
288
  The orientation of the plot
@@ -309,8 +309,9 @@ class StackBar(_BarBase):
309
309
  :context: close-figs
310
310
 
311
311
  >>> from marsilea.plotter import StackBar
312
- >>> stack_data = pd.DataFrame(data=np.random.randint(1, 10, (5, 10)),
313
- ... index=list("abcde"))
312
+ >>> stack_data = pd.DataFrame(
313
+ ... data=np.random.randint(1, 10, (5, 10)), index=list("abcde")
314
+ ... )
314
315
  >>> _, ax = plt.subplots()
315
316
  >>> StackBar(stack_data).render(ax)
316
317
 
marsilea/plotter/base.py CHANGED
@@ -14,6 +14,13 @@ from .._deform import Deformation
14
14
  from ..exceptions import DataError, SplitConflict
15
15
 
16
16
 
17
+ # class DataValidator:
18
+ #
19
+ # @singledispatch(np.ndarray)
20
+ # def parse(self, data):
21
+ # pass
22
+
23
+
17
24
  class DataLoader:
18
25
  """Handle user data"""
19
26
 
@@ -397,7 +404,8 @@ class RenderPlan:
397
404
  return self._get_intact_render_spec(axes)
398
405
  except Exception as _:
399
406
  raise DataError(
400
- f"Please check your data input with {self.__class__.__name__}"
407
+ f"Please check your data input "
408
+ f"with {self.__class__.__name__} at '{self.side}'"
401
409
  )
402
410
 
403
411
  # def get_render_data(self):
@@ -508,7 +516,9 @@ class RenderPlan:
508
516
  if self._plan_label is not None:
509
517
  self._plan_label.add(axes, self.side)
510
518
 
511
- def get_canvas_size(self, figure) -> float:
519
+ def get_canvas_size(
520
+ self, figure, main_height=None, main_width=None, **kwargs
521
+ ) -> float:
512
522
  """
513
523
  If the size is unknown before rendering, this function must be
514
524
  implemented to return the canvas size in inches.
marsilea/plotter/bio.py CHANGED
@@ -102,8 +102,9 @@ class SeqLogo(StatsBase):
102
102
 
103
103
  >>> import pandas as pd
104
104
  >>> from marsilea.plotter import SeqLogo
105
- >>> matrix = pd.DataFrame(data=np.random.randint(1, 10, (4, 10)),
106
- ... index=list("ACGT"))
105
+ >>> matrix = pd.DataFrame(
106
+ ... data=np.random.randint(1, 10, (4, 10)), index=list("ACGT")
107
+ ... )
107
108
  >>> _, ax = plt.subplots()
108
109
  >>> colors = {"A": "r", "C": "b", "G": "g", "T": "black"}
109
110
  >>> SeqLogo(matrix, color_encode=colors).render(ax)
marsilea/plotter/mesh.py CHANGED
@@ -103,7 +103,7 @@ class ColorMesh(MeshBase):
103
103
 
104
104
  >>> import marsilea as ma
105
105
  >>> from marsilea.plotter import ColorMesh
106
- >>> _, ax = plt.subplots(figsize=(5, .5))
106
+ >>> _, ax = plt.subplots(figsize=(5, 0.5))
107
107
  >>> ColorMesh(np.arange(10), cmap="Blues").render(ax)
108
108
 
109
109
  .. plot::
@@ -111,13 +111,13 @@ class ColorMesh(MeshBase):
111
111
 
112
112
  >>> data = np.random.randn(10, 8)
113
113
  >>> h = ma.Heatmap(data)
114
- >>> h.hsplit(cut=[5])
114
+ >>> h.cut_rows(cut=[5])
115
115
  >>> h.add_dendrogram("left")
116
116
  >>> cmap1, cmap2 = "Purples", "Greens"
117
- >>> colors1 = ColorMesh(np.arange(10)+1, cmap=cmap1, label=cmap1, annot=True)
118
- >>> colors2 = ColorMesh(np.arange(10)+1, cmap=cmap2, label=cmap2)
119
- >>> h.add_right(colors1, size=.2, pad=.05)
120
- >>> h.add_right(colors2, size=.2, pad=.05)
117
+ >>> colors1 = ColorMesh(np.arange(10) + 1, cmap=cmap1, label=cmap1, annot=True)
118
+ >>> colors2 = ColorMesh(np.arange(10) + 1, cmap=cmap2, label=cmap2)
119
+ >>> h.add_right(colors1, size=0.2, pad=0.05)
120
+ >>> h.add_right(colors2, size=0.2, pad=0.05)
121
121
  >>> h.render()
122
122
 
123
123
 
@@ -280,7 +280,7 @@ class Colors(MeshBase):
280
280
 
281
281
  >>> import marsilea as ma
282
282
  >>> from marsilea.plotter import Colors
283
- >>> _, ax = plt.subplots(figsize=(5, .5))
283
+ >>> _, ax = plt.subplots(figsize=(5, 0.5))
284
284
  >>> data = np.random.choice(["A", "B", "C"], 10)
285
285
  >>> Colors(data).render(ax)
286
286
 
@@ -288,10 +288,10 @@ class Colors(MeshBase):
288
288
  :context: close-figs
289
289
 
290
290
  >>> h = ma.Heatmap(np.random.randn(10, 8))
291
- >>> h.hsplit(cut=[5])
291
+ >>> h.cut_rows(cut=[5])
292
292
  >>> h.add_dendrogram("left")
293
293
  >>> color = Colors(data, label="Colors")
294
- >>> h.add_right(color, size=.2, pad=.05)
294
+ >>> h.add_right(color, size=0.2, pad=0.05)
295
295
  >>> h.render()
296
296
 
297
297
 
@@ -467,7 +467,7 @@ class SizedMesh(MeshBase):
467
467
 
468
468
  >>> import marsilea as ma
469
469
  >>> from marsilea.plotter import SizedMesh
470
- >>> _, ax = plt.subplots(figsize=(5, .5))
470
+ >>> _, ax = plt.subplots(figsize=(5, 0.5))
471
471
  >>> size, color = np.random.rand(1, 10), np.random.rand(1, 10)
472
472
  >>> SizedMesh(size, color).render(ax)
473
473
 
@@ -475,10 +475,10 @@ class SizedMesh(MeshBase):
475
475
  :context: close-figs
476
476
 
477
477
  >>> h = ma.Heatmap(np.random.randn(10, 8))
478
- >>> h.hsplit(cut=[5])
478
+ >>> h.cut_rows(cut=[5])
479
479
  >>> h.add_dendrogram("left")
480
480
  >>> mesh = SizedMesh(size, color, marker="*", label="SizedMesh")
481
- >>> h.add_right(mesh, size=.2, pad=.05)
481
+ >>> h.add_right(mesh, size=0.2, pad=0.05)
482
482
  >>> h.render()
483
483
 
484
484
 
@@ -668,6 +668,8 @@ class MarkerMesh(MeshBase):
668
668
  See :mod:`matplotlib.markers`
669
669
  size : int
670
670
  The of marker in fontsize unit
671
+ frameon : bool
672
+ Whether to draw the border of the plot
671
673
  label : str
672
674
  The label of the plot, only show when added to the side plot
673
675
  label_loc : {'top', 'bottom', 'left', 'right'}
@@ -698,6 +700,7 @@ class MarkerMesh(MeshBase):
698
700
  color="black",
699
701
  marker="*",
700
702
  size=35,
703
+ frameon=False,
701
704
  label=None,
702
705
  label_loc=None,
703
706
  label_props=None,
@@ -706,9 +709,10 @@ class MarkerMesh(MeshBase):
706
709
  self.set_data(np.asarray(data))
707
710
  self.color = color
708
711
  self.marker = marker
712
+ self.marker_size = size
713
+ self.frameon = frameon
709
714
  self.set_label(label, label_loc, label_props)
710
715
  self.kwargs = kwargs
711
- self.marker_size = size
712
716
 
713
717
  def get_legends(self):
714
718
  return CatLegend(
@@ -737,6 +741,8 @@ class MarkerMesh(MeshBase):
737
741
  ax.set_xlim(0, xticks[-1] + 0.5)
738
742
  ax.set_ylim(0, yticks[-1] + 0.5)
739
743
  ax.invert_yaxis()
744
+ if not self.frameon:
745
+ ax.set_axis_off()
740
746
 
741
747
 
742
748
  class TextMesh(MeshBase):
@@ -748,6 +754,8 @@ class TextMesh(MeshBase):
748
754
  The text to draw
749
755
  color : color
750
756
  The color of the text
757
+ frameon : bool
758
+ Whether to draw the border of the plot
751
759
  label : str
752
760
  The label of the plot, only show when added to the side plot
753
761
  label_loc : {'top', 'bottom', 'left', 'right'}
@@ -765,6 +773,7 @@ class TextMesh(MeshBase):
765
773
  self,
766
774
  texts,
767
775
  color="black",
776
+ frameon=False,
768
777
  label=None,
769
778
  label_loc=None,
770
779
  label_props=None,
@@ -772,6 +781,7 @@ class TextMesh(MeshBase):
772
781
  ):
773
782
  self.set_data(self.data_validator(texts))
774
783
  self.color = color
784
+ self.frameon = frameon
775
785
  self.set_label(label, label_loc, label_props)
776
786
  self.kwargs = kwargs
777
787
 
@@ -797,3 +807,5 @@ class TextMesh(MeshBase):
797
807
  ax.set_xlim(0, xticks[-1] + 0.5)
798
808
  ax.set_ylim(0, yticks[-1] + 0.5)
799
809
  ax.invert_yaxis()
810
+ if not self.frameon:
811
+ ax.set_axis_off()