iplotx 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iplotx/tree.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from typing import (
2
2
  Optional,
3
3
  Sequence,
4
+ Any,
4
5
  )
5
6
  from collections.abc import Hashable
6
7
  from collections import defaultdict
@@ -13,6 +14,7 @@ from .style import (
13
14
  context,
14
15
  get_style,
15
16
  rotate_style,
17
+ merge_styles,
16
18
  )
17
19
  from .utils.matplotlib import (
18
20
  _stale_wrapper,
@@ -30,6 +32,9 @@ from .edge import (
30
32
  EdgeCollection,
31
33
  make_stub_patch as make_undirected_edge_patch,
32
34
  )
35
+ from .edge.leaf import (
36
+ LeafEdgeCollection,
37
+ )
33
38
  from .label import (
34
39
  LabelCollection,
35
40
  )
@@ -58,24 +63,19 @@ class TreeArtist(mpl.artist.Artist):
58
63
  self,
59
64
  tree,
60
65
  layout: Optional[str] = "horizontal",
61
- orientation: Optional[str] = None,
62
66
  directed: bool | str = False,
63
- vertex_labels: Optional[
64
- bool | list[str] | dict[Hashable, str] | pd.Series
65
- ] = None,
67
+ vertex_labels: Optional[bool | list[str] | dict[Hashable, str] | pd.Series] = None,
66
68
  edge_labels: Optional[Sequence | dict[Hashable, str] | pd.Series] = None,
67
69
  leaf_labels: Optional[Sequence | dict[Hashable, str]] | pd.Series = None,
68
70
  transform: mpl.transforms.Transform = mpl.transforms.IdentityTransform(),
69
71
  offset_transform: Optional[mpl.transforms.Transform] = None,
72
+ show_support: bool = False,
70
73
  ):
71
74
  """Initialize the TreeArtist.
72
75
 
73
76
  Parameters:
74
77
  tree: The tree to plot.
75
78
  layout: The layout to use for the tree. Can be "horizontal", "vertical", or "radial".
76
- orientation: The orientation of the tree layout. Can be "right" or "left" (for
77
- horizontal and radial layouts) and "descending" or "ascending" (for vertical
78
- layouts).
79
79
  directed: Whether the tree is directed. Can be a boolean or a string with the
80
80
  following choices: "parent" or "child".
81
81
  vertex_labels: Labels for the vertices. Can be a list, dictionary, or pandas Series.
@@ -88,13 +88,14 @@ class TreeArtist(mpl.artist.Artist):
88
88
  transform: The transform to apply to the tree artist. This is usually the identity.
89
89
  offset_transform: The offset transform to apply to the tree artist. This is
90
90
  usually `ax.transData`.
91
+ show_support: Whether to show support values on the nodes. If both show_support and
92
+ vertex_labels are set, this parameters takes precedence.
91
93
  """
92
94
 
93
95
  self.tree = tree
94
96
  self._ipx_internal_data = ingest_tree_data(
95
97
  tree,
96
98
  layout,
97
- orientation=orientation,
98
99
  directed=directed,
99
100
  layout_style=get_style(".layout", {}),
100
101
  vertex_labels=vertex_labels,
@@ -102,6 +103,10 @@ class TreeArtist(mpl.artist.Artist):
102
103
  leaf_labels=leaf_labels,
103
104
  )
104
105
 
106
+ if show_support:
107
+ support = self._ipx_internal_data["vertex_df"]["support"]
108
+ self._ipx_internal_data["vertex_df"]["label"] = support
109
+
105
110
  super().__init__()
106
111
 
107
112
  # This is usually the identity (which scales poorly with dpi)
@@ -116,10 +121,11 @@ class TreeArtist(mpl.artist.Artist):
116
121
  self._add_vertices()
117
122
  self._add_edges()
118
123
  self._add_leaf_vertices()
124
+ self._add_leaf_edges()
119
125
 
120
126
  # NOTE: cascades need to be created after leaf vertices in case
121
127
  # they are requested to wrap around them.
122
- if "cascade" in self.get_vertices().get_style():
128
+ if get_style(".cascade") != {}:
123
129
  self._add_cascades()
124
130
 
125
131
  def get_children(self) -> tuple[mpl.artist.Artist]:
@@ -131,6 +137,8 @@ class TreeArtist(mpl.artist.Artist):
131
137
  children = [self._vertices, self._edges]
132
138
  if hasattr(self, "_leaf_vertices"):
133
139
  children.append(self._leaf_vertices)
140
+ if hasattr(self, "_leaf_edges"):
141
+ children.append(self._leaf_edges)
134
142
  if hasattr(self, "_cascades"):
135
143
  children.append(self._cascades)
136
144
  return tuple(children)
@@ -141,12 +149,24 @@ class TreeArtist(mpl.artist.Artist):
141
149
  Parameters:
142
150
  fig: the figure to set for this artist and its children.
143
151
  """
144
- super().set_figure(fig)
145
- for child in self.get_children():
146
- child.set_figure(fig)
147
-
148
152
  # At the end, if there are cadcades with extent depending on
149
153
  # leaf edges, we should update them
154
+ super().set_figure(fig)
155
+
156
+ # The next two are vanilla NetworkArtist
157
+ self._vertices.set_figure(fig)
158
+ self._edges.set_figure(fig)
159
+
160
+ # For trees, there are a few more elements to coordinate,
161
+ # including possibly text at the fringes (leaf labels)
162
+ # which might require a redraw (without rendering) to compute
163
+ # its actual scren real estate.
164
+ if hasattr(self, "_leaf_vertices"):
165
+ self._leaf_vertices.set_figure(fig)
166
+ if hasattr(self, "_leaf_edges"):
167
+ self._leaf_edges.set_figure(fig)
168
+ if hasattr(self, "_cascades"):
169
+ self._cascades.set_figure(fig)
150
170
  self._update_cascades_extent()
151
171
 
152
172
  def _update_cascades_extent(self) -> None:
@@ -154,7 +174,7 @@ class TreeArtist(mpl.artist.Artist):
154
174
  if not hasattr(self, "_cascades"):
155
175
  return
156
176
 
157
- style_cascade = self.get_vertices().get_style()["cascade"]
177
+ style_cascade = get_style(".cascade")
158
178
  extend_to_labels = style_cascade.get("extend", False) == "leaf_labels"
159
179
  if not extend_to_labels:
160
180
  return
@@ -172,9 +192,7 @@ class TreeArtist(mpl.artist.Artist):
172
192
 
173
193
  def get_layout(self, kind="vertex"):
174
194
  """Get vertex or edge layout."""
175
- layout_columns = [
176
- f"_ipx_layout_{i}" for i in range(self._ipx_internal_data["ndim"])
177
- ]
195
+ layout_columns = [f"_ipx_layout_{i}" for i in range(self._ipx_internal_data["ndim"])]
178
196
 
179
197
  if kind == "vertex":
180
198
  layout = self._ipx_internal_data["vertex_df"][layout_columns]
@@ -213,13 +231,16 @@ class TreeArtist(mpl.artist.Artist):
213
231
  edge_bbox = self._edges.get_datalim(transData)
214
232
  bbox = mpl.transforms.Bbox.union([bbox, edge_bbox])
215
233
 
234
+ if hasattr(self, "_leaf_vertices"):
235
+ leaf_labels_bbox = self._leaf_vertices.get_datalim(transData)
236
+ bbox = mpl.transforms.Bbox.union([bbox, leaf_labels_bbox])
237
+
216
238
  if hasattr(self, "_cascades"):
217
239
  cascades_bbox = self._cascades.get_datalim(transData)
218
240
  bbox = mpl.transforms.Bbox.union([bbox, cascades_bbox])
219
241
 
220
- if hasattr(self, "_leaf_vertices"):
221
- leaf_labels_bbox = self._leaf_vertices.get_datalim(transData)
222
- bbox = mpl.transforms.Bbox.union([bbox, leaf_labels_bbox])
242
+ # NOTE: We do not need to check leaf edges for bbox, because they are
243
+ # guaranteed within the convex hull of leaf vertices.
223
244
 
224
245
  bbox = bbox.expanded(sw=(1.0 + pad), sh=(1.0 + pad))
225
246
  return bbox
@@ -240,8 +261,12 @@ class TreeArtist(mpl.artist.Artist):
240
261
 
241
262
  def get_leaf_vertices(self) -> Optional[VertexCollection]:
242
263
  """Get leaf VertexCollection artist."""
243
- if hasattr(self, "_leaf_vertices"):
244
- return self._leaf_vertices
264
+ return self._leaf_vertices
265
+
266
+ def get_leaf_edges(self) -> Optional[LeafEdgeCollection]:
267
+ """Get LeafEdgeCollection artist if present."""
268
+ if hasattr(self, "_leaf_edges"):
269
+ return self._leaf_edges
245
270
  return None
246
271
 
247
272
  def get_vertex_labels(self) -> LabelCollection:
@@ -253,8 +278,14 @@ class TreeArtist(mpl.artist.Artist):
253
278
  return self._edges.get_labels()
254
279
 
255
280
  def get_leaf_labels(self) -> Optional[LabelCollection]:
256
- if hasattr(self, "_leaf_vertices"):
257
- return self._leaf_vertices.get_labels()
281
+ """Get the leaf label artist if present."""
282
+ return self._leaf_vertices.get_labels()
283
+
284
+ def get_leaf_edge_labels(self) -> Optional[LabelCollection]:
285
+ """Get the leaf edge label artist if present."""
286
+ # TODO: leaf edge labels are basically unsupported as of now
287
+ if hasattr(self, "_leaf_edges"):
288
+ return self._leaf_edges.get_labels()
258
289
  return None
259
290
 
260
291
  def _add_vertices(self) -> None:
@@ -271,36 +302,149 @@ class TreeArtist(mpl.artist.Artist):
271
302
  offset_transform=self.get_offset_transform(),
272
303
  )
273
304
 
305
+ def _add_leaf_edges(self) -> None:
306
+ """Add edges from the leaf to the max leaf depth."""
307
+ # If there are no leaves, no leaf labels, or leaves are not deep,
308
+ # skip leaf edges
309
+ leaf_style = get_style(".leaf", {})
310
+ if ("deep" not in leaf_style) and self.get_leaf_labels() is None:
311
+ return
312
+ if not leaf_style.get("deep", True):
313
+ return
314
+
315
+ edge_style = get_style(
316
+ ".leafedge",
317
+ )
318
+ default_style = {
319
+ "linestyle": "--",
320
+ "linewidth": 1,
321
+ "edgecolor": "#111",
322
+ }
323
+ for key, value in default_style.items():
324
+ if key not in edge_style:
325
+ edge_style[key] = value
326
+
327
+ labels = None
328
+ # TODO: implement leaf edge labels
329
+ # self._get_label_series("leafedge")
330
+
331
+ if "cmap" in edge_style:
332
+ cmap_fun = _build_cmap_fun(
333
+ edge_style["color"],
334
+ edge_style["cmap"],
335
+ )
336
+ else:
337
+ cmap_fun = None
338
+
339
+ leaf_shallow_layout = self.get_layout("leaf")
340
+
341
+ if "cmap" in edge_style:
342
+ colorarray = []
343
+ edgepatches = []
344
+ adjacent_vertex_ids = []
345
+ for i, vid in enumerate(leaf_shallow_layout.index):
346
+ edge_stylei = rotate_style(edge_style, index=i, key=vid)
347
+
348
+ if cmap_fun is not None:
349
+ colorarray.append(edge_stylei["color"])
350
+ edge_stylei["color"] = cmap_fun(edge_stylei["color"])
351
+
352
+ # These are not the actual edges drawn, only stubs to establish
353
+ # the styles which are then fed into the dynamic, optimised
354
+ # factory (the collection) below
355
+ patch = make_undirected_edge_patch(
356
+ **edge_stylei,
357
+ )
358
+ edgepatches.append(patch)
359
+ adjacent_vertex_ids.append(vid)
360
+
361
+ if "cmap" in edge_style:
362
+ vmin = np.min(colorarray)
363
+ vmax = np.max(colorarray)
364
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
365
+ edge_style["norm"] = norm
366
+
367
+ self._leaf_edges = LeafEdgeCollection(
368
+ edgepatches,
369
+ vertex_leaf_ids=adjacent_vertex_ids,
370
+ vertex_collection=self._vertices,
371
+ leaf_collection=self._leaf_vertices,
372
+ labels=labels,
373
+ transform=self.get_offset_transform(),
374
+ style=edge_style,
375
+ directed=False,
376
+ )
377
+ if "cmap" in edge_style:
378
+ self._leaf_edges.set_array(colorarray)
379
+
274
380
  def _add_leaf_vertices(self) -> None:
275
381
  """Add invisible deep vertices as leaf label anchors."""
382
+ layout_name = self._ipx_internal_data["layout_name"]
383
+ orientation = self._ipx_internal_data["orientation"]
384
+ user_leaf_style = get_style(".leaf", {})
385
+
276
386
  leaf_layout = self.get_layout("leaf").copy()
387
+
277
388
  # Set all to max depth
278
- depth_idx = int(self._ipx_internal_data["layout_name"] == "vertical")
279
- leaf_layout.iloc[:, depth_idx] = leaf_layout.iloc[:, depth_idx].max()
389
+ if user_leaf_style.get("deep", True):
390
+ if layout_name == "radial":
391
+ leaf_layout.iloc[:, 0] = leaf_layout.iloc[:, 0].max()
392
+ elif (layout_name, orientation) == ("horizontal", "right"):
393
+ leaf_layout.iloc[:, 0] = leaf_layout.iloc[:, 0].max()
394
+ elif (layout_name, orientation) == ("horizontal", "left"):
395
+ leaf_layout.iloc[:, 0] = leaf_layout.iloc[:, 0].min()
396
+ elif (layout_name, orientation) == ("vertical", "descending"):
397
+ leaf_layout.iloc[:, 1] = leaf_layout.iloc[:, 1].min()
398
+ elif (layout_name, orientation) == ("vertical", "ascending"):
399
+ leaf_layout.iloc[:, 1] = leaf_layout.iloc[:, 1].max()
400
+ else:
401
+ raise ValueError(
402
+ f"Layout and orientation not supported: {layout_name}, {orientation}."
403
+ )
280
404
 
281
405
  # Set invisible vertices with visible labels
282
- layout_name = self._ipx_internal_data["layout_name"]
283
- orientation = self._ipx_internal_data["orientation"]
284
406
  if layout_name == "radial":
285
407
  ha = "auto"
286
- elif orientation in ("left", "ascending"):
408
+ rotation = 0
409
+ elif orientation == "right":
410
+ ha = "left"
411
+ rotation = 0
412
+ elif orientation == "left":
287
413
  ha = "right"
414
+ rotation = 0
415
+ elif orientation == "ascending":
416
+ ha = "left"
417
+ rotation = 90
288
418
  else:
289
419
  ha = "left"
420
+ rotation = -90
290
421
 
291
- leaf_vertex_style = {
422
+ default_leaf_style = {
292
423
  "size": 0,
293
424
  "label": {
294
- "verticalalignment": "center",
425
+ "verticalalignment": "center_baseline",
295
426
  "horizontalalignment": ha,
427
+ "rotation": rotation,
296
428
  "hmargin": 5,
429
+ "vmargin": 0,
297
430
  "bbox": {
298
431
  "facecolor": (1, 1, 1, 0),
299
432
  },
300
433
  },
301
434
  }
302
- with context({"vertex": leaf_vertex_style}):
435
+ with context([{"vertex": default_leaf_style}, {"vertex": user_leaf_style}]):
303
436
  leaf_vertex_style = get_style(".vertex")
437
+ # Left horizontal layout has no rotation of the labels but we need to
438
+ # reverse hmargin
439
+ if (
440
+ layout_name == "horizontal"
441
+ and orientation == "left"
442
+ and "label" in leaf_vertex_style
443
+ and "hmargin" in leaf_vertex_style["label"]
444
+ ):
445
+ # Reverse the horizontal margin
446
+ leaf_vertex_style["label"]["hmargin"] *= -1
447
+
304
448
  self._leaf_vertices = VertexCollection(
305
449
  layout=leaf_layout,
306
450
  layout_coordinate_system=self._ipx_internal_data.get(
@@ -317,8 +461,25 @@ class TreeArtist(mpl.artist.Artist):
317
461
  """Add cascade patches."""
318
462
  # NOTE: If leaf labels are present and the cascades are requested to wrap around them,
319
463
  # we have to compute the max extend of the cascades from the leaf labels.
320
- maxdepth = None
321
- style_cascade = self.get_vertices().get_style()["cascade"]
464
+ layout = self.get_layout()
465
+ layout_name = self._ipx_internal_data["layout_name"]
466
+ orientation = self._ipx_internal_data["orientation"]
467
+ maxdepth = 1e-10
468
+ if layout_name == "horizontal":
469
+ if orientation == "right":
470
+ maxdepth = layout.values[:, 0].max()
471
+ else:
472
+ maxdepth = layout.values[:, 0].min()
473
+ elif layout_name == "vertical":
474
+ if orientation == "descending":
475
+ maxdepth = layout.values[:, 1].min()
476
+ else:
477
+ maxdepth = layout.values[:, 1].max()
478
+ elif layout_name == "radial":
479
+ # layout values are: r, theta
480
+ maxdepth = layout.values[:, 0].max()
481
+
482
+ style_cascade = get_style(".cascade")
322
483
  extend_to_labels = style_cascade.get("extend", False) == "leaf_labels"
323
484
  has_leaf_labels = self.get_leaf_labels() is not None
324
485
  if extend_to_labels and not has_leaf_labels:
@@ -329,9 +490,9 @@ class TreeArtist(mpl.artist.Artist):
329
490
 
330
491
  self._cascades = CascadeCollection(
331
492
  tree=self.tree,
332
- layout=self.get_layout(),
333
- layout_name=self._ipx_internal_data["layout_name"],
334
- orientation=self._ipx_internal_data["orientation"],
493
+ layout=layout,
494
+ layout_name=layout_name,
495
+ orientation=orientation,
335
496
  style=style_cascade,
336
497
  provider=data_providers["tree"][self._ipx_internal_data["tree_library"]],
337
498
  transform=self.get_offset_transform(),
@@ -342,10 +503,7 @@ class TreeArtist(mpl.artist.Artist):
342
503
  layout_name = self.get_layout_name()
343
504
  if layout_name == "radial":
344
505
  maxdepth = 0
345
- # These are the text boxes, they must all be included
346
- bboxes = self.get_leaf_labels().get_datalims_children(
347
- self.get_offset_transform()
348
- )
506
+ bboxes = self.get_leaf_labels().get_datalims_children(self.get_offset_transform())
349
507
  for bbox in bboxes:
350
508
  r1 = np.linalg.norm([bbox.xmax, bbox.ymax])
351
509
  r2 = np.linalg.norm([bbox.xmax, bbox.ymin])
@@ -388,9 +546,7 @@ class TreeArtist(mpl.artist.Artist):
388
546
  else:
389
547
  cmap_fun = None
390
548
 
391
- edge_df = self._ipx_internal_data["edge_df"].set_index(
392
- ["_ipx_source", "_ipx_target"]
393
- )
549
+ edge_df = self._ipx_internal_data["edge_df"].set_index(["_ipx_source", "_ipx_target"])
394
550
 
395
551
  if "cmap" in edge_style:
396
552
  colorarray = []
@@ -398,7 +554,7 @@ class TreeArtist(mpl.artist.Artist):
398
554
  adjacent_vertex_ids = []
399
555
  waypoints = []
400
556
  for i, (vid1, vid2) in enumerate(edge_df.index):
401
- edge_stylei = rotate_style(edge_style, index=i, key=(vid1, vid2))
557
+ edge_stylei = rotate_style(edge_style, index=i, key=vid2)
402
558
 
403
559
  # FIXME:: Improve this logic. We have three layers of priority:
404
560
  # 1. Explicitely set in the style of "plot"
@@ -420,6 +576,8 @@ class TreeArtist(mpl.artist.Artist):
420
576
 
421
577
  # Tree layout determines waypoints
422
578
  waypointsi = edge_stylei.pop("waypoints", None)
579
+ if isinstance(waypointsi, (bool, np.bool)):
580
+ waypointsi = ["none", None][int(waypointsi)]
423
581
  if waypointsi is None:
424
582
  layout_name = self._ipx_internal_data["layout_name"]
425
583
  if layout_name == "horizontal":
@@ -429,7 +587,9 @@ class TreeArtist(mpl.artist.Artist):
429
587
  elif layout_name == "radial":
430
588
  waypointsi = "r0a1"
431
589
  else:
432
- waypointsi = "none"
590
+ raise ValueError(
591
+ f"Layout not supported: {layout_name}. ",
592
+ )
433
593
  waypoints.append(waypointsi)
434
594
 
435
595
  # These are not the actual edges drawn, only stubs to establish
@@ -447,7 +607,10 @@ class TreeArtist(mpl.artist.Artist):
447
607
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
448
608
  edge_style["norm"] = norm
449
609
 
450
- edge_style["waypoints"] = waypoints
610
+ if get_style(".layout", {}).get("angular", False):
611
+ edge_style.pop("waypoints", None)
612
+ else:
613
+ edge_style["waypoints"] = waypoints
451
614
 
452
615
  # NOTE: Trees are directed is their "directed" property is True, "child", or "parent"
453
616
  self._edges = EdgeCollection(
@@ -470,24 +633,89 @@ class TreeArtist(mpl.artist.Artist):
470
633
  """Get the orientation of the tree layout."""
471
634
  return self._ipx_internal_data.get("orientation", None)
472
635
 
636
+ def style_subtree(
637
+ self,
638
+ nodes: Sequence[Hashable],
639
+ style: Optional[dict[str, Any] | Sequence[str | dict[str, Any]]] = None,
640
+ **kwargs,
641
+ ) -> None:
642
+ """Style a subtree of the tree.
643
+
644
+ Parameters:
645
+ nodes: Sequence of nodes that span the subtree. All elements below including
646
+ the most recent common ancestor of these leaves will be styled.
647
+ style: Style or sequence of styles to apply to the subtree. Each style can
648
+ be either a string, referring to an internal `iplotx` style, or a dictionary
649
+ with custom styling elements.
650
+ kwargs: Additional flat style elements. If both style and kwargs are provided,
651
+ kwargs is applied last.
652
+ """
653
+ styles = []
654
+ if isinstance(style, (str, dict)):
655
+ styles = [style]
656
+ elif style is not None:
657
+ styles = list(style)
658
+ style = merge_styles(styles + [kwargs])
659
+
660
+ provider = data_providers["tree"][self._ipx_internal_data["tree_library"]]
661
+
662
+ # Get last (deepest) common ancestor of the requested nodes
663
+ root = provider(self.tree).get_lca(nodes)
664
+
665
+ # Populate a DataFrame with the array of properties to update
666
+ vertex_idx = {node: i for i, node in enumerate(self._ipx_internal_data["vertex_df"].index)}
667
+ edge_idx = {
668
+ node: i
669
+ for i, node in enumerate(self._ipx_internal_data["edge_df"]["_ipx_target"].values)
670
+ }
671
+ vertex_props = {}
672
+ edge_props = {}
673
+ vertex_style = style.get("vertex", {})
674
+ edge_style = style.get("edge", {})
675
+ for inode, node in enumerate(provider(root).preorder()):
676
+ for attr, value in vertex_style.items():
677
+ if attr not in vertex_props:
678
+ vertex_props[attr] = list(getattr(self._vertices, f"get_{attr}")())
679
+ vertex_props[attr][vertex_idx[node]] = value
680
+
681
+ # Ignore branch coming into the root node
682
+ if inode == 0:
683
+ continue
684
+
685
+ for attr, value in edge_style.items():
686
+ # Edge color is actually edgecolor
687
+ if attr == "color":
688
+ attr = "edgecolor"
689
+ if attr not in edge_props:
690
+ edge_props[attr] = list(getattr(self._edges, f"get_{attr}")())
691
+ edge_props[attr][edge_idx[node]] = value
692
+
693
+ # Update the properties from the DataFrames
694
+ for attr in vertex_props:
695
+ getattr(self._vertices, f"set_{attr}")(vertex_props[attr])
696
+ for attr in edge_props:
697
+ getattr(self._edges, f"set_{attr}")(edge_props[attr])
698
+
473
699
  @_stale_wrapper
474
700
  def draw(self, renderer) -> None:
475
701
  """Draw each of the children, with some buffering mechanism."""
476
702
  if not self.get_visible():
477
703
  return
478
704
 
479
- # At the end, if there are cadcades with extent depending on
480
- # leaf edges, we should update them
481
- self._update_cascades_extent()
482
-
483
705
  # NOTE: looks like we have to manage the zorder ourselves
484
706
  # this is kind of funny actually. Btw we need to ensure
485
707
  # that cascades are drawn behind (earlier than) vertices
486
708
  # and edges at equal zorder because it looks better that way.
487
709
  z_suborder = defaultdict(int)
710
+ if hasattr(self, "_leaf_vertices"):
711
+ z_suborder[self._leaf_vertices] = -2
712
+ if hasattr(self, "_leaf_edges"):
713
+ z_suborder[self._leaf_edges] = -2
488
714
  if hasattr(self, "_cascades"):
489
715
  z_suborder[self._cascades] = -1
490
716
  children = list(self.get_children())
491
717
  children.sort(key=lambda x: (x.zorder, z_suborder[x]))
492
718
  for art in children:
719
+ if isinstance(art, CascadeCollection):
720
+ self._update_cascades_extent()
493
721
  art.draw(renderer)
iplotx/typing.py CHANGED
@@ -32,12 +32,14 @@ LayoutType = Union[
32
32
  Sequence[Sequence[float]],
33
33
  np.ndarray,
34
34
  pd.DataFrame,
35
+ dict[Hashable, Sequence[float] | tuple[float, float]],
35
36
  # igraph.Layout,
36
37
  ]
37
38
  GroupingType = Union[
38
39
  Sequence[set],
39
40
  Sequence[int],
40
41
  Sequence[str],
42
+ dict[str, set],
41
43
  # igraph.clustering.Clustering,
42
44
  # igraph.clustering.VertexClustering,
43
45
  # igraph.clustering.Cover,
iplotx/utils/geometry.py CHANGED
@@ -1,5 +1,9 @@
1
+ from typing import (
2
+ Sequence,
3
+ )
1
4
  from math import atan2
2
5
  import numpy as np
6
+ import matplotlib as mpl
3
7
 
4
8
 
5
9
  # See also this link for the general answer (using scipy to compute coefficients):
@@ -13,28 +17,7 @@ def _evaluate_squared_bezier(points, t):
13
17
  def _evaluate_cubic_bezier(points, t):
14
18
  """Evaluate a cubic Bezier curve at t."""
15
19
  p0, p1, p2, p3 = points
16
- return (
17
- (1 - t) ** 3 * p0
18
- + 3 * (1 - t) ** 2 * t * p1
19
- + 3 * (1 - t) * t**2 * p2
20
- + t**3 * p3
21
- )
22
-
23
-
24
- def _evaluate_cubic_bezier_derivative(points, t):
25
- """Evaluate the derivative of a cubic Bezier curve at t."""
26
- p0, p1, p2, p3 = points
27
- # (dx / dt, dy / dt) is the parametric gradient
28
- # to get the angle from this, one can just atanh(dy/dt, dx/dt)
29
- # This is equivalent to computing the actual bezier curve
30
- # at low t, of course, which is the geometric interpretation
31
- # (obviously, division by t is irrelenant)
32
- return (
33
- 3 * p0 * (1 - t) ** 2
34
- + 3 * p1 * (1 - t) * (-3 * t + 1)
35
- + 3 * p2 * t * (2 - 3 * t)
36
- + 3 * p3 * t**2
37
- )
20
+ return (1 - t) ** 3 * p0 + 3 * (1 - t) ** 2 * t * p1 + 3 * (1 - t) * t**2 * p2 + t**3 * p3
38
21
 
39
22
 
40
23
  def convex_hull(points):
@@ -90,9 +73,7 @@ def _convex_hull_Graham_scan(points):
90
73
  pivot_idx = miny_idx[points[miny_idx, 0].argmin()]
91
74
 
92
75
  # Compute angles against that pivot, ensuring the pivot itself last
93
- angles = np.arctan2(
94
- points[:, 1] - points[pivot_idx, 1], points[:, 0] - points[pivot_idx, 0]
95
- )
76
+ angles = np.arctan2(points[:, 1] - points[pivot_idx, 1], points[:, 0] - points[pivot_idx, 0])
96
77
  angles[pivot_idx] = np.inf
97
78
 
98
79
  # Sort points by angle
@@ -169,22 +150,36 @@ def _convex_hull_Graham_scan(points):
169
150
 
170
151
 
171
152
  def _compute_group_path_with_vertex_padding(
172
- hull,
173
- points,
174
- transform,
175
- vertexpadding=10,
176
- points_per_curve=30,
153
+ hull: np.ndarray | Sequence[int],
154
+ points: np.ndarray,
155
+ transform: mpl.transforms.Transform,
156
+ vertexpadding: int = 10,
177
157
  # TODO: check how dpi affects this
178
- dpi=72.0,
179
- ):
158
+ dpi: float = 72.0,
159
+ ) -> np.ndarray:
180
160
  """Offset path for a group based on vertex padding.
181
161
 
182
- At the input, the structure is [v1, v1, v1, ..., vn, vn, vn, v1]
183
-
184
- # NOTE: this would look better as a cubic Bezier, but ok for now.
162
+ Parameters:
163
+ hull: The coordinates (not indices!) of the convex hull.
164
+ points: This is the np.ndarray where the coordinates will be written to (output).
165
+ The length is some integer ppc * len(hull) + 1 because for each vertex, this
166
+ function wraps around it using a certain fixed ppc number of points, plus the
167
+ final point for CLOSEPOLY.
168
+ transform: The transform of the hull points.
169
+ vertexpadding: The padding to apply to the vertices, in figure coordinates.
170
+ dpi (WIP): The dpi of the figure renderer.
171
+
172
+ Returns:
173
+ None. The output is written to the `points` array in place. This ensures that the
174
+ length of this array is unchanged, which is important to ensure that the vertices
175
+ and SVG codes are in sync.
185
176
  """
186
- # Short form
187
- ppc = points_per_curve
177
+ if len(hull) == 0:
178
+ return
179
+
180
+ # Short form for point per curve
181
+ ppc = (len(points) - 1) // len(hull)
182
+ assert len(points) % ppc == 1
188
183
 
189
184
  # No padding, set degenerate path
190
185
  if vertexpadding == 0:
@@ -196,11 +191,9 @@ def _compute_group_path_with_vertex_padding(
196
191
  # Transform into figure coordinates
197
192
  trans = transform.transform
198
193
  trans_inv = transform.inverted().transform
199
- points = trans(points)
200
194
 
201
195
  # Singleton: draw a circle around it
202
196
  if len(hull) == 1:
203
-
204
197
  # NOTE: linspace is double inclusive, which covers CLOSEPOLY
205
198
  thetas = np.linspace(
206
199
  -np.pi,
@@ -213,7 +206,6 @@ def _compute_group_path_with_vertex_padding(
213
206
 
214
207
  # Doublet: draw two semicircles
215
208
  if len(hull) == 2:
216
-
217
209
  # Unit vector connecting the two points
218
210
  dv = trans(hull[0]) - trans(hull[1])
219
211
  dv = dv / np.sqrt((dv**2).sum())