iplotx 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iplotx/__init__.py CHANGED
@@ -1,2 +1,23 @@
1
+ """
2
+ iplotx is a library for interactive plotting of networks and trees in
3
+ matplotlib.
4
+
5
+ It guarantees the visualisation will look exactly the same no matter what
6
+ library was used to construct the network.
7
+ """
8
+
1
9
  from .version import __version__
2
- from .plotting import plot
10
+ from .plotting import (
11
+ network,
12
+ tree,
13
+ )
14
+
15
# `plot` is kept as an alias of `network`: earlier releases exposed the
# network-plotting entry point under the name `plot`.
plot = network

# Public API of the top-level package.
__all__ = ["network", "tree", "plot", "__version__"]
@@ -0,0 +1,623 @@
1
+ """
2
+ Module defining the main matplotlib Artist for network/tree edges, EdgeCollection.
3
+
4
+ Some supporting functions are also defined here.
5
+ """
6
+
7
+ from typing import (
8
+ Sequence,
9
+ Optional,
10
+ Any,
11
+ )
12
+ from math import atan2, cos, pi, sin
13
+ from collections import defaultdict
14
+ import numpy as np
15
+ import pandas as pd
16
+ import matplotlib as mpl
17
+
18
+ from ..typing import (
19
+ Pair,
20
+ LeafProperty,
21
+ )
22
+ from ..utils.matplotlib import (
23
+ _compute_mid_coord_and_rot,
24
+ _stale_wrapper,
25
+ _forwarder,
26
+ )
27
+ from ..style import (
28
+ rotate_style,
29
+ )
30
+ from ..label import LabelCollection
31
+ from ..vertex import VertexCollection
32
+ from .arrow import EdgeArrowCollection
33
+ from .geometry import (
34
+ _compute_loops_per_angle,
35
+ _fix_parallel_edges_straight,
36
+ _compute_loop_path,
37
+ _compute_edge_path,
38
+ )
39
+
40
+
41
@_forwarder(
    (
        "set_clip_path",
        "set_clip_box",
        "set_snap",
        "set_sketch_params",
        "set_animated",
        "set_picker",
    )
)
class EdgeCollection(mpl.collections.PatchCollection):
    """Artist for a collection of edges within a network/tree.

    This artist is derived from PatchCollection with a few notable differences:
    - It updates the ends of each edge based on the vertex borders.
    - It may contain edge labels as a child (a LabelCollection).
    - For directed graphs, it contains arrows as a child (an EdgeArrowCollection).

    This class is not designed to be instantiated directly but rather by internal
    iplotx functions such as iplotx.network. However, some of its methods can be
    called directly to edit edge style after the initial draw.
    """

    def __init__(
        self,
        patches: Sequence[mpl.patches.Patch],
        vertex_ids: Sequence[tuple],
        vertex_collection: VertexCollection,
        *args,
        transform: Optional[mpl.transforms.Transform] = None,
        arrow_transform: Optional[mpl.transforms.Transform] = None,
        directed: bool = False,
        style: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Initialise an EdgeCollection.

        Parameters:
            patches: A sequence (usually, list) of matplotlib `Patch`es describing the edges.
            vertex_ids: A sequence of pairs `(v1, v2)`, each defining the ids of vertices at the
                end of an edge.
            vertex_collection: The VertexCollection instance containing the Artist for the
                vertices. This is needed to compute vertex borders and adjust edges accordingly.
            transform: The matplotlib transform for the edges, usually transData. If None, a
                new identity transform is used.
            arrow_transform: The matplotlib transform for the arrow patches. This is not the
                *offset_transform* of arrows, which is set equal to the edge transform (previous
                parameter). Instead, it specifies how arrow size scales, similar to vertex size.
                This is usually the identity transform, which is also used if None.
            directed: Whether the graph is directed (in which case arrows are drawn, possibly
                with zero size or opacity to obtain an "arrowless" effect).
            style: The edge style (subdictionary: "edge") to use at creation. If it contains
                "cmap", that colormap (and "norm", if present) is passed on to the collection.
            **kwargs: Forwarded to PatchCollection, except for "labels", which is consumed
                here: if given, a LabelCollection child is created for the edge labels.
        """
        # Build fresh identity transforms per instance rather than sharing a single
        # instance through the default arguments.
        if transform is None:
            transform = mpl.transforms.IdentityTransform()
        if arrow_transform is None:
            arrow_transform = mpl.transforms.IdentityTransform()

        kwargs["match_original"] = True
        self._vertex_ids = vertex_ids

        self._vertex_collection = vertex_collection
        self._style = style if style is not None else {}
        self._labels = kwargs.pop("labels", None)
        self._directed = directed
        self._arrow_transform = arrow_transform
        if "cmap" in self._style:
            kwargs["cmap"] = self._style["cmap"]
            # "norm" is optional: None lets matplotlib fall back to its default norm.
            kwargs["norm"] = self._style.get("norm")

        # NOTE: This should also set the transform
        super().__init__(patches, transform=transform, *args, **kwargs)

        # This is important because it prepares the right flags for scalarmappable
        self.set_facecolor("none")

        if self.directed:
            self._arrows = EdgeArrowCollection(
                self,
                transform=self._arrow_transform,
            )
        if self._labels is not None:
            style = self._style.get("label", {})
            self._label_collection = LabelCollection(
                self._labels,
                style=style,
                transform=transform,
            )

    def get_children(self) -> tuple:
        """Return the child artists: arrows (if directed) and labels (if any)."""
        children = []
        if hasattr(self, "_arrows"):
            children.append(self._arrows)
        if hasattr(self, "_label_collection"):
            children.append(self._label_collection)
        return tuple(children)

    def set_figure(self, fig) -> None:
        """Set the figure and propagate it to children after recomputing paths."""
        super().set_figure(fig)
        self._update_paths()
        # NOTE: This sets the correct offsets in the arrows,
        # but not the correct sizes (see below)
        self._update_children()
        for child in self.get_children():
            # NOTE: This sets the sizes with correct dpi scaling in the arrows
            child.set_figure(fig)

    def _update_children(self):
        """Refresh arrow geometry and label positions from the current edge paths."""
        self._update_arrows()
        self._update_labels()

    @property
    def directed(self) -> bool:
        """Whether the network is directed."""
        return self._directed

    @directed.setter
    def directed(self, value) -> None:
        """Setter for the directed property.

        Changing this property triggers the addition/removal of arrows from the plot.
        """
        value = bool(value)
        if self._directed != value:
            # Moving to undirected, remove arrows
            if not value:
                # The arrows are drawn as a child of this collection. If they were never
                # added to an Axes, matplotlib's Artist.remove raises NotImplementedError;
                # dropping our reference is all that is needed in that case.
                try:
                    self._arrows.remove()
                except NotImplementedError:
                    pass
                del self._arrows
            # Moving to directed, create arrows
            else:
                self._arrows = EdgeArrowCollection(
                    self,
                    transform=self._arrow_transform,
                )

        self._directed = value
        # NOTE: setting stale to True should trigger a redraw as soon as needed
        # and that will update children. We might need to verify that.
        self.stale = True

    def set_array(self, A) -> None:
        """Set the array for cmap/norm coloring, preserving per-edge alpha."""
        super().set_array(A)
        # Alpha needs to be kept separately, since colormapping replaces the colors
        if self.get_alpha() is None:
            self.set_alpha(self.get_edgecolor()[:, 3])
        # This is necessary to ensure edgecolors are bool-flagged correctly
        self.set_edgecolor(None)

    def update_scalarmappable(self) -> None:
        """Update colors from the scalar mappable array, if any.

        Assign edge colors from a numerical array, and match arrow colors
        if the graph is directed.
        """
        # NOTE: The superclass also sets stale = True
        super().update_scalarmappable()
        # Now self._edgecolors has the correct colorspace values
        if hasattr(self, "_arrows"):
            self._arrows.set_colors(self.get_edgecolor())

    def get_labels(self) -> Optional[LabelCollection]:
        """Get LabelCollection artist for labels if present."""
        if hasattr(self, "_label_collection"):
            return self._label_collection
        return None

    def get_mappable(self):
        """Return mappable for colorbar."""
        return self

    def _get_adjacent_vertices_info(self):
        """Collect, for each edge, the layout offsets, paths and dpi sizes of its end vertices.

        Returns:
            A dict with keys "ids" (the edge vertex id pairs), "offsets", "paths" and
            "sizes", the latter three being lists of (source, target) pairs aligned
            with the edges.
        """
        vcoll = self._vertex_collection
        index = vcoll.get_index()
        # Map vertex id -> positional index in the vertex collection
        index = pd.Series(
            np.arange(len(index)),
            index=index,
        )

        voffsets = []
        vpaths = []
        vsizes = []
        if len(self._vertex_ids) == 0:
            return {
                "ids": self._vertex_ids,
                "offsets": voffsets,
                "paths": vpaths,
                "sizes": vsizes,
            }

        # These do not change while iterating over edges, so fetch them once.
        # NOTE: layout values are in the original layout coordinate system,
        # not cartesianised yet.
        layout = vcoll.get_layout().values
        vertex_paths = vcoll.get_paths()
        # NOTE: This needs to be computed here because the
        # VertexCollection._transforms are reset each draw in order to
        # accommodate for DPI changes on the canvas
        vertex_sizes = vcoll.get_sizes_dpi()

        for v1, v2 in self._vertex_ids:
            i1 = index[v1]
            i2 = index[v2]
            voffsets.append((layout[i1], layout[i2]))
            vpaths.append((vertex_paths[i1], vertex_paths[i2]))
            vsizes.append((vertex_sizes[i1], vertex_sizes[i2]))

        return {
            "ids": self._vertex_ids,
            "offsets": voffsets,
            "paths": vpaths,
            "sizes": vsizes,
        }

    def _update_paths(self, transform=None):
        """Compute paths for the edges and store them in ``self._paths``.

        Non-loop edges are computed first, then straight parallel/antiparallel
        edges are spread apart, and finally loops are placed.

        Loops split the largest wedge left open by other
        edges of that vertex. The algo is:
        (i) Find what vertices each loop belongs to
        (ii) While going through the edges, record the angles
             for vertices with loops
        (iii) Plot each loop based on the recorded angles

        Parameters:
            transform: Transform between data and figure coordinates. Defaults to
                this collection's transform.
        """
        vinfo = self._get_adjacent_vertices_info()
        vids = vinfo["ids"]
        vcenters = vinfo["offsets"]
        vpaths = vinfo["paths"]
        vsizes = vinfo["sizes"]
        loopmaxangle = pi / 180.0 * self._style.get("loopmaxangle", 60.0)

        if transform is None:
            transform = self.get_transform()
        trans = transform.transform
        trans_inv = transform.inverted().transform

        # 1. Make a list of vertices with loops, and store them for later
        loop_vertex_dict = defaultdict(lambda: dict(indices=[], edge_angles=[]))
        for i, (v1, v2) in enumerate(vids):
            if v1 == v2:
                loop_vertex_dict[v1]["indices"].append(i)

        # 2. Make paths for non-loop edges
        # NOTE: keep track of parallel edges to offset them
        parallel_edges = defaultdict(list)
        paths = []
        for i, (v1, v2) in enumerate(vids):
            # Postpone loops (step 3), keeping a placeholder at their position
            if v1 == v2:
                paths.append(None)
                continue

            # Per-edge style (leaf properties resolved for this edge)
            edge_stylei = rotate_style(self._style, index=i, key=(v1, v2))
            if edge_stylei.get("curved", False):
                tension = edge_stylei.get("tension", 5)
                ports = edge_stylei.get("ports", (None, None))
            else:
                tension = 0
                ports = None

            waypoints = edge_stylei.get("waypoints", "none")

            # Compute actual edge path. Inputs: vertex centers in data coords,
            # vertex paths and sizes in figure (default) coords.
            path, angles = _compute_edge_path(
                vcenters[i],
                vpaths[i],
                vsizes[i],
                trans,
                trans_inv,
                tension=tension,
                waypoints=waypoints,
                ports=ports,
                layout_coordinate_system=self._vertex_collection.get_layout_coordinate_system(),
            )

            # Collect angles for this vertex, to be used for loops plotting below
            if v1 in loop_vertex_dict:
                loop_vertex_dict[v1]["edge_angles"].append(angles[0])
            if v2 in loop_vertex_dict:
                loop_vertex_dict[v2]["edge_angles"].append(angles[1])

            # Add the path for this non-loop edge
            paths.append(path)
            # FIXME: curved parallel edges depend on the direction of curvature...!
            parallel_edges[(v1, v2)].append(i)

        # Fix parallel edges.
        # NOTE: there is deliberately no "every group has one edge" shortcut: an
        # antiparallel pair (v1, v2)/(v2, v1) has two groups of size 1 that still
        # need spreading, so every group must be inspected.
        if not self._style.get("curved", False):
            offset = self._style.get("offset", 3)
            while parallel_edges:
                (v1, v2), indices = parallel_edges.popitem()
                indices_inv = parallel_edges.pop((v2, v1), [])
                if len(indices) + len(indices_inv) > 1:
                    _fix_parallel_edges_straight(
                        paths,
                        indices,
                        indices_inv,
                        trans,
                        trans_inv,
                        offset=offset,
                    )

        # 3. Deal with loops at the end
        looptension = self._style.get("looptension", 2.5)
        for ldict in loop_vertex_dict.values():
            loop_indices = ldict["indices"]
            first = loop_indices[0]
            # All loops of this vertex share the vertex path, size and center
            vpath = vpaths[first][0]
            vsize = vsizes[first][0]
            vcoord_fig = trans(vcenters[first][0])

            # The space between the existing angles is where we can fit the loops
            # One loop we can fit in the largest wedge, multiple loops we need
            # to distribute across wedges.
            nloops_per_angle = _compute_loops_per_angle(
                len(loop_indices), ldict["edge_angles"]
            )

            idx = 0
            for theta1, theta2, nloops_wedge in nloops_per_angle:
                # Angular size of each loop in this wedge
                delta = (theta2 - theta1) / nloops_wedge

                # Iterate over individual loops, centering each within its slice
                # and capping its opening at loopmaxangle
                for j in range(nloops_wedge):
                    thetaj1 = theta1 + j * delta + max(delta - loopmaxangle, 0) / 2
                    thetaj2 = thetaj1 + min(delta, loopmaxangle)

                    paths[loop_indices[idx]] = _compute_loop_path(
                        vcoord_fig,
                        vpath,
                        vsize,
                        thetaj1,
                        thetaj2,
                        trans_inv,
                        looptension=looptension,
                    )
                    idx += 1

        self._paths = paths

    def _update_labels(self):
        """Place edge labels at the middle of each edge path (and rotate if requested)."""
        if self._labels is None:
            return

        # A missing (or None) "label" sub-style means defaults, not a crash.
        style = (self._style or {}).get("label") or {}
        rotate = style.get("rotate", True)
        trans = self.get_transform().transform

        offsets = []
        rotations = []
        for path in self._paths:
            offset, rotation = _compute_mid_coord_and_rot(path, trans)
            offsets.append(offset)
            rotations.append(rotation)

        self._label_collection.set_offsets(offsets)
        # NOTE(review): per-edge rotations are applied only when "rotate" is False;
        # confirm this matches the intended meaning of the style flag.
        if not rotate:
            self._label_collection.set_rotations(rotations)

    def _update_arrows(
        self,
    ) -> None:
        """Extract the end angles of the paths to position and rotate the arrows.

        Each arrow is placed at the last vertex of its edge path and rotated to
        match the direction of the final path segment (in figure coordinates).

        NOTE: This function does *not* update the arrow sizes/_transforms to the correct dpi
        scaling. That's ok since the correct dpi scaling is set whenever there is a different
        figure (before first draw) and whenever a draw is called.
        """
        if not hasattr(self, "_arrows"):
            return

        transform = self.get_transform()
        trans = transform.transform

        for i, epath in enumerate(self.get_paths()):
            # Offset the arrow to point to the end of the edge
            self._arrows._offsets[i] = epath.vertices[-1]

            # Rotate the arrow to point in the direction of the edge
            apath = self._arrows._paths[i]
            # NOTE: because the tip of the arrow is at (0, 0) in patch space,
            # in theory it will rotate around that point already
            v2 = trans(epath.vertices[-1])
            v1 = trans(epath.vertices[-2])
            dv = v2 - v1
            theta = atan2(*(dv[::-1]))
            # Rotate incrementally from the previously stored angle
            theta_old = self._arrows._angles[i]
            dtheta = theta - theta_old
            mrot = np.array([[cos(dtheta), sin(dtheta)], [-sin(dtheta), cos(dtheta)]])
            apath.vertices = apath.vertices @ mrot
            self._arrows._angles[i] = theta

    @_stale_wrapper
    def draw(self, renderer):
        """Draw the edges, recomputing paths and children geometry first."""
        # Visibility affects the children too
        if not self.get_visible():
            return

        self._update_paths()
        # This sets the arrow offsets
        self._update_children()

        super().draw(renderer)
        for child in self.get_children():
            # This sets the arrow sizes with dpi scaling
            child.draw(renderer)

    def get_ports(self) -> Optional[LeafProperty[Pair[Optional[str]]]]:
        """Get the ports for all edges.

        Returns:
            The ports for the edges, as a pair of strings or None for each edge. If None, it
            means all edges are free.
        """
        return self._style.get("ports", None)

    def set_ports(self, ports: Optional[LeafProperty[Pair[Optional[str]]]]) -> None:
        """Set new ports for the edges.

        Parameters:
            ports: A pair of ports strings for each edge. Each port can be None to mean free
                edge end. If None, any ports setting is removed (a no-op if none was set).
        """
        if ports is None:
            self._style.pop("ports", None)
        else:
            self._style["ports"] = ports
        self.stale = True

    def get_tension(self) -> Optional[LeafProperty[float]]:
        """Get the tension for the edges.

        Returns:
            The tension for the edges. If None, the edges are straight.
        """
        return self._style.get("tension", None)

    def set_tension(self, tension: Optional[LeafProperty[float]]) -> None:
        """Set new tension for the edges.

        Parameters:
            tension: The tension to use for curved edges. If None, the tension setting is
                removed (a no-op if none was set).

        Note: This function does not set self.set_curved(True) automatically. If you are
        unsure whether that property is set already, you should call both functions.

        Example:
            # Set curved edges with different tensions
            >>> network.get_edges().set_curved(True)
            >>> network.get_edges().set_tension([1, 0.5])

            # Set straight edges
            # (the latter call is optional but helps readability)
            >>> network.get_edges().set_curved(False)
            >>> network.get_edges().set_tension(None)

        """
        if tension is None:
            self._style.pop("tension", None)
        else:
            self._style["tension"] = tension
        self.stale = True

    get_tensions = get_tension
    set_tensions = set_tension

    def get_curved(self) -> bool:
        """Get whether the edges are curved or not.

        Returns:
            A bool that is True if the edges are curved, False if they are straight.
        """
        return self._style.get("curved", False)

    def set_curved(self, curved: bool) -> None:
        """Set whether the edges are curved or not.

        Parameters:
            curved: Whether the edges should be curved (True) or straight (False).

        Note: If you want only some edges to be curved, set curved to True and set tensions to
        0 for the straight edges.
        """
        self._style["curved"] = bool(curved)
        self.stale = True

    def get_loopmaxangle(self) -> Optional[float]:
        """Get the maximum angle for loops.

        Returns:
            The maximum angle in degrees that a loop can take (60 if not set).
        """
        return self._style.get("loopmaxangle", 60)

    def set_loopmaxangle(self, loopmaxangle: float) -> None:
        """Set the maximum angle for loops.

        Parameters:
            loopmaxangle: The maximum angle in degrees that a loop can take.
        """
        self._style["loopmaxangle"] = loopmaxangle
        self.stale = True

    def get_looptension(self) -> Optional[float]:
        """Get the tension for loops.

        Returns:
            The tension for loops (2.5 if not set).
        """
        return self._style.get("looptension", 2.5)

    def set_looptension(self, looptension: Optional[float]) -> None:
        """Set new tension for loops.

        Parameters:
            looptension: The tension to use for loops. If None, the default (2.5) is
                restored.
        """
        if looptension is None:
            self._style.pop("looptension", None)
        else:
            self._style["looptension"] = looptension
        self.stale = True

    def get_offset(self) -> Optional[float]:
        """Get the offset for parallel straight edges.

        Returns:
            The offset in points for parallel straight edges (3 if not set).
        """
        return self._style.get("offset", 3)

    def set_offset(self, offset: Optional[float]) -> None:
        """Set the offset for parallel straight edges.

        Parameters:
            offset: The offset in points for parallel straight edges. If None, the default
                (3) is restored.
        """
        if offset is None:
            self._style.pop("offset", None)
        else:
            self._style["offset"] = offset
        self.stale = True
588
+
589
+
590
def make_stub_patch(**kwargs):
    """Make a stub undirected edge patch, without actual path information.

    Parameters:
        **kwargs: Patch properties. "color" is an alias for "edgecolor" and is used
            only if "edgecolor" is not given. iplotx edge-style keys that a
            matplotlib patch does not understand (arrow, label, curved, ...) are
            dropped. "clip_on" defaults to True.

    Returns:
        A hollow matplotlib PathPatch with a placeholder path. The real path is
        assigned later.
    """
    kwargs.setdefault("clip_on", True)

    # "color" is an alias for the edge color. Always consume it: if it reached
    # the Patch constructor, matplotlib would apply it to both face and edge
    # color, overriding facecolor="none" below (and an explicit edgecolor).
    if "color" in kwargs:
        color = kwargs.pop("color")
        kwargs.setdefault("edgecolor", color)

    # Edges are always hollow, because they are not closed paths
    # NOTE: This is supposed to cascade onto what boolean flags are set
    # for color mapping (Colorizer)
    kwargs["facecolor"] = "none"

    # Forget specific properties that are not supported here
    forbidden_props = (
        "arrow",
        "label",
        "curved",
        "tension",
        "waypoints",
        "ports",
        "looptension",
        "loopmaxangle",
        "offset",
        "cmap",
    )
    for prop in forbidden_props:
        kwargs.pop(prop, None)

    # NOTE: the path is overwritten later anyway, so no reason to spend any time here
    art = mpl.patches.PathPatch(
        mpl.path.Path([[0, 0]]),
        **kwargs,
    )
    return art