iplotx 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,882 @@
1
+ """
2
+ Module defining the main matplotlib Artist for network/tree edges, EdgeCollection.
3
+
4
+ Some supporting functions are also defined here.
5
+ """
6
+
7
+ from typing import (
8
+ Sequence,
9
+ Optional,
10
+ Never,
11
+ Any,
12
+ )
13
+ from math import atan2, tan, cos, pi, sin
14
+ from collections import defaultdict
15
+ import numpy as np
16
+ import pandas as pd
17
+ import matplotlib as mpl
18
+
19
+ from ..utils.matplotlib import (
20
+ _compute_mid_coord_and_rot,
21
+ _stale_wrapper,
22
+ _forwarder,
23
+ )
24
+ from ..style import (
25
+ rotate_style,
26
+ )
27
+ from ..label import LabelCollection
28
+ from ..vertex import VertexCollection
29
+ from .arrow import EdgeArrowCollection
30
+ from .ports import _get_port_unit_vector
31
+
32
+
33
@_forwarder(
    (
        "set_clip_path",
        "set_clip_box",
        "set_snap",
        "set_sketch_params",
        "set_animated",
        "set_picker",
    )
)
class EdgeCollection(mpl.collections.PatchCollection):
    """Artist for a collection of edges within a network/tree.

    This artist is derived from PatchCollection with a few notable differences:
    - It updates the ends of each edge based on the vertex borders.
    - It may contain edge labels as a child (a LabelCollection).
    - For directed graphs, it contains arrows as a child (an EdgeArrowCollection).

    This class is not designed to be instantiated directly but rather by internal
    iplotx functions such as iplotx.network. However, some of its methods can be
    called directly to edit edge style after the initial draw.
    """

    def __init__(
        self,
        patches: Sequence[mpl.patches.Patch],
        vertex_ids: Sequence[tuple],
        vertex_collection: VertexCollection,
        layout: pd.DataFrame,
        *args,
        layout_coordinate_system: str = "cartesian",
        transform: mpl.transforms.Transform = mpl.transforms.IdentityTransform(),
        arrow_transform: mpl.transforms.Transform = mpl.transforms.IdentityTransform(),
        directed: bool = False,
        style: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Initialise an EdgeCollection.

        Parameters:
            patches: A sequence (usually, list) of matplotlib `Patch`es describing the edges.
            vertex_ids: A sequence of pairs `(v1, v2)`, each defining the ids of vertices at the
                end of an edge.
            vertex_collection: The VertexCollection instance containing the Artist for the
                vertices. This is needed to compute vertex borders and adjust edges accordingly.
            layout: The vertex layout.
            layout_coordinate_system: The coordinate system the previous parameter is in. For
                certain layouts, this might not be "cartesian" (e.g. "polar" layout for radial
                trees).
            transform: The matplotlib transform for the edges, usually transData.
            arrow_transform: The matplotlib transform for the arrow patches. This is not the
                *offset_transform* of arrows, which is set equal to the edge transform (previous
                parameter). Instead, it specifies how arrow size scales, similar to vertex size.
                This is usually the identity transform.
            directed: Whether the graph is directed (in which case arrows are drawn, possibly
                with zero size or opacity to obtain an "arrowless" effect).
            style: The edge style (subdictionary: "edge") to use at creation. If it contains
                "cmap", the optional "norm" entry is forwarded alongside it.
            **kwargs: Forwarded to PatchCollection, except "labels", which is consumed here
                and used to build a LabelCollection child.
        """
        # Keep the per-patch properties (colors, widths, ...) of the stub patches
        kwargs["match_original"] = True
        self._vertex_ids = vertex_ids

        self._vertex_collection = vertex_collection
        # NOTE: the layout is needed for non-cartesian coordinate systems
        # for which information is lost upon cartesianisation (e.g. polar,
        # for which multiple angles are degenerate in cartesian space).
        self._layout = layout
        self._layout_coordinate_system = layout_coordinate_system
        self._style = style if style is not None else {}
        self._labels = kwargs.pop("labels", None)
        self._directed = directed
        self._arrow_transform = arrow_transform
        if "cmap" in self._style:
            kwargs["cmap"] = self._style["cmap"]
            # A missing norm falls back to matplotlib's default (None)
            kwargs["norm"] = self._style.get("norm")

        # NOTE: This should also set the transform
        super().__init__(patches, *args, transform=transform, **kwargs)

        # This is important because it prepares the right flags for scalarmappable
        self.set_facecolor("none")

        if self.directed:
            self._arrows = EdgeArrowCollection(
                self,
                transform=self._arrow_transform,
            )
        if self._labels is not None:
            label_style = self._style.get("label") or {}
            self._label_collection = LabelCollection(
                self._labels,
                style=label_style,
                transform=transform,
            )

    def get_children(self) -> tuple:
        """Return the child artists: arrows (if directed) and labels (if any)."""
        children = []
        if hasattr(self, "_arrows"):
            children.append(self._arrows)
        if hasattr(self, "_label_collection"):
            children.append(self._label_collection)
        return tuple(children)

    def set_figure(self, fig) -> None:
        """Attach the artist to a figure and refresh edge paths and children."""
        super().set_figure(fig)
        self._update_paths()
        # NOTE: This sets the correct offsets in the arrows,
        # but not the correct sizes (see below)
        self._update_children()
        for child in self.get_children():
            # NOTE: This sets the sizes with correct dpi scaling in the arrows
            child.set_figure(fig)

    def _update_children(self):
        """Refresh arrow geometry and label positions from the current paths."""
        self._update_arrows()
        self._update_labels()

    @property
    def directed(self) -> bool:
        """Whether the network is directed."""
        return self._directed

    @directed.setter
    def directed(self, value) -> None:
        """Setter for the directed property.

        Changing this property triggers the addition/removal of arrows from the plot.
        """
        value = bool(value)
        if self._directed != value:
            # Moving to undirected, remove arrows
            if not value:
                self._arrows.remove()
                del self._arrows
            # Moving to directed, create arrows
            else:
                self._arrows = EdgeArrowCollection(
                    self,
                    transform=self._arrow_transform,
                )

        self._directed = value
        # NOTE: setting stale to True should trigger a redraw as soon as needed
        # and that will update children. We might need to verify that.
        self.stale = True

    def set_array(self, A) -> None:
        """Set the array for cmap/norm coloring."""
        super().set_array(A)
        # Alpha needs to be kept separately, otherwise the colormap overwrites it
        if self.get_alpha() is None:
            self.set_alpha(self.get_edgecolor()[:, 3])
        # This is necessary to ensure edgecolors are bool-flagged correctly
        self.set_edgecolor(None)

    def update_scalarmappable(self) -> None:
        """Update colors from the scalar mappable array, if any.

        Assign edge colors from a numerical array, and match arrow colors
        if the graph is directed.
        """
        # NOTE: The superclass also sets stale = True
        super().update_scalarmappable()
        # Now self._edgecolors has the correct colorspace values
        if hasattr(self, "_arrows"):
            self._arrows.set_colors(self.get_edgecolors())

    def get_labels(self) -> Optional[LabelCollection]:
        """Get LabelCollection artist for labels if present."""
        return getattr(self, "_label_collection", None)

    def get_mappable(self):
        """Return mappable for colorbar."""
        return self

    def _get_adjacent_vertices_info(self):
        """Collect geometry of the two end vertices of every edge.

        Returns:
            A dict with keys "ids" (the edge vertex id pairs), "offsets" (pairs of
            vertex coordinates in the original layout coordinate system), "paths"
            (pairs of vertex paths) and "sizes" (pairs of dpi-scaled vertex sizes).
            Each list is aligned with the edge order.
        """
        vertex_index = self._vertex_collection.get_index()
        index = pd.Series(
            np.arange(len(vertex_index)),
            index=vertex_index,
        )

        # Hoisted out of the per-edge loop: these do not change between edges.
        layout_values = self._layout.values
        vertex_paths = self._vertex_collection.get_paths()
        # NOTE: This needs to be computed here because the
        # VertexCollection._transforms are reset each draw in order to
        # accommodate for DPI changes on the canvas
        vertex_sizes = self._vertex_collection.get_sizes_dpi()

        voffsets = []
        vpaths = []
        vsizes = []
        for v1, v2 in self._vertex_ids:
            i1 = index[v1]
            i2 = index[v2]
            # NOTE: these are in the original layout coordinate system
            # not cartesianised yet.
            voffsets.append((layout_values[i1], layout_values[i2]))
            vpaths.append((vertex_paths[i1], vertex_paths[i2]))
            vsizes.append((vertex_sizes[i1], vertex_sizes[i2]))

        return {
            "ids": self._vertex_ids,
            "offsets": voffsets,
            "paths": vpaths,
            "sizes": vsizes,
        }

    def _update_paths(self, transform=None):
        """Compute paths for the edges.

        Loops split the largest wedge left open by other
        edges of that vertex. The algo is:
        (i) Find what vertices each loop belongs to
        (ii) While going through the edges, record the angles
            for vertices with loops
        (iii) Plot each loop based on the recorded angles

        Parameters:
            transform: Transform between data and display coordinates. Defaults
                to this artist's own transform.
        """
        vinfo = self._get_adjacent_vertices_info()
        vids = vinfo["ids"]
        vcenters = vinfo["offsets"]
        vpaths = vinfo["paths"]
        vsizes = vinfo["sizes"]
        # The style stores "loopmaxangle" in degrees: convert to radians.
        loopmaxangle = pi / 180.0 * self._style.get("loopmaxangle", 60)

        if transform is None:
            transform = self.get_transform()
        trans = transform.transform
        trans_inv = transform.inverted().transform

        # 1. Make a list of vertices with loops, and store them for later
        loop_vertex_dict = defaultdict(lambda: dict(indices=[], edge_angles=[]))
        for i, (v1, v2) in enumerate(vids):
            if v1 == v2:
                loop_vertex_dict[v1]["indices"].append(i)

        # 2. Make paths for non-loop edges
        # NOTE: keep track of parallel edges to offset them
        parallel_edges = defaultdict(list)
        paths = []
        for i, (v1, v2) in enumerate(vids):
            # Postpone loops (step 3)
            if v1 == v2:
                paths.append(None)
                continue

            # Per-edge style (e.g. leaf rotation)
            edge_stylei = rotate_style(self._style, index=i, key=(v1, v2))
            if edge_stylei.get("curved", False):
                tension = edge_stylei.get("tension", 5)
                ports = edge_stylei.get("ports", (None, None))
            else:
                tension = 0
                ports = None

            waypoints = edge_stylei.get("waypoints", "none")

            # Compute actual edge path. Vertex centres are in data coords,
            # vertex paths and sizes in figure (default) coords.
            path, angles = self._compute_edge_path(
                vcenters[i],
                vpaths[i],
                vsizes[i],
                trans,
                trans_inv,
                tension=tension,
                waypoints=waypoints,
                ports=ports,
            )

            # Collect angles for this vertex, to be used for loops plotting below
            if v1 in loop_vertex_dict:
                loop_vertex_dict[v1]["edge_angles"].append(angles[0])
            if v2 in loop_vertex_dict:
                loop_vertex_dict[v2]["edge_angles"].append(angles[1])

            # Add the path for this non-loop edge
            paths.append(path)
            # FIXME: curved parallel edges depend on the direction of curvature...!
            parallel_edges[(v1, v2)].append(i)

        # Fix parallel edges (straight edges only). Edges going in opposite
        # directions between the same two vertices are offset as one bundle.
        if not self._style.get("curved", False):
            while parallel_edges:
                (v1, v2), indices = parallel_edges.popitem()
                indices_inv = parallel_edges.pop((v2, v1), [])
                if len(indices) + len(indices_inv) > 1:
                    self._fix_parallel_edges_straight(
                        paths,
                        indices,
                        indices_inv,
                        trans,
                        trans_inv,
                        offset=self._style.get("offset", 3),
                    )

        # 3. Deal with loops at the end
        looptension = self._style.get("looptension", 2.5)
        for ldict in loop_vertex_dict.values():
            loop_indices = ldict["indices"]
            first = loop_indices[0]
            vpath = vpaths[first][0]
            vsize = vsizes[first][0]
            vcoord_fig = trans(vcenters[first][0])

            # The space between the existing angles is where we can fit the loops
            # One loop we can fit in the largest wedge, multiple loops we need
            # to distribute among wedges
            wedges = self._compute_loops_per_angle(
                len(loop_indices), ldict["edge_angles"]
            )

            idx = 0
            for theta1, theta2, nloops_wedge in wedges:
                # Angular size of each loop in this wedge
                delta = (theta2 - theta1) / nloops_wedge

                # Iterate over individual loops, each capped at loopmaxangle
                # and centred within its share of the wedge
                for j in range(nloops_wedge):
                    thetaj1 = theta1 + j * delta + max(delta - loopmaxangle, 0) / 2
                    thetaj2 = thetaj1 + min(delta, loopmaxangle)

                    paths[loop_indices[idx]] = self._compute_loop_path(
                        vcoord_fig,
                        vpath,
                        vsize,
                        thetaj1,
                        thetaj2,
                        trans_inv,
                        looptension=looptension,
                    )
                    idx += 1

        self._paths = paths

    def _fix_parallel_edges_straight(
        self,
        paths,
        indices,
        indices_inv,
        trans,
        trans_inv,
        offset=3,
    ):
        """Offset parallel edges so that they do not overlap.

        All edges listed in ``indices`` and ``indices_inv`` connect the same two
        vertices (in either direction), so their paths coincide. Each path is
        shifted orthogonally to the edge in display coordinates, and the bundle
        is centred on the original line.

        Parameters:
            paths: List of edge paths in data coordinates, modified in place.
            indices: Indices into ``paths`` of edges in one direction.
            indices_inv: Indices into ``paths`` of edges in the opposite direction.
            trans: Data-to-display transform function.
            trans_inv: Display-to-data transform function.
            offset: Distance between consecutive parallel edges, in display units.
        """
        all_indices = list(indices) + list(indices_inv)
        ntot = len(all_indices)

        # Use the endpoints (not just a 2-vertex unpack) so that paths with
        # intermediate vertices, e.g. waypoint edges, are also supported.
        vertices = trans(paths[all_indices[0]].vertices)
        direction = vertices[0] - vertices[-1]
        length = np.sqrt((direction**2).sum())
        # Coincident endpoints: there is no direction to offset against
        if length == 0:
            return

        # Unit vector orthogonal to the line
        fracs = direction / length @ np.array([[0, 1], [-1, 0]])

        # NOTE: for now treat both direction the same
        for i, idx in enumerate(all_indices):
            # Shifts are symmetric around the straight line, e.g. +-offset/2
            # for two edges and -offset, 0, +offset for three.
            shift = fracs * offset * (i - (ntot - 1) / 2)
            paths[idx].vertices = trans_inv(trans(paths[idx].vertices) + shift)

    def _compute_loop_path(
        self,
        vcoord_fig,
        vpath,
        vsize,
        angle1,
        angle2,
        trans_inv,
        looptension,
    ):
        """Build a cubic Bezier self-loop leaving the vertex at angle1 and returning at angle2.

        The two control points are placed along the start/end directions, at
        ``looptension`` times their distance from the vertex centre. The result
        is converted back to data coordinates with ``trans_inv``.
        """
        # Shorten at starting angle
        start = self._get_shorter_edge_coords(vpath, vsize, angle1) + vcoord_fig
        # Shorten at end angle
        end = self._get_shorter_edge_coords(vpath, vsize, angle2) + vcoord_fig

        aux1 = (start - vcoord_fig) * looptension + vcoord_fig
        aux2 = (end - vcoord_fig) * looptension + vcoord_fig

        vertices = np.vstack([start, aux1, aux2, end])
        codes = [mpl.path.Path.MOVETO] + [mpl.path.Path.CURVE4] * 3

        # Transform back to data coordinates
        return mpl.path.Path(trans_inv(vertices), codes=codes)

    def _compute_edge_path(
        self,
        *args,
        **kwargs,
    ):
        """Dispatch to the waypoint, straight, or curved edge path builder.

        Keyword arguments "tension" (default 0), "waypoints" (default "none")
        and "ports" are consumed here; the remaining arguments are forwarded.

        Raises:
            ValueError: If both waypoints and a nonzero tension are requested.
        """
        tension = kwargs.pop("tension", 0)
        waypoints = kwargs.pop("waypoints", "none")
        ports = kwargs.pop("ports", (None, None))

        if (waypoints != "none") and (tension != 0):
            raise ValueError("Waypoints not supported for curved edges.")

        if waypoints != "none":
            return self._compute_edge_path_waypoints(waypoints, *args, **kwargs)

        if tension == 0:
            return self._compute_edge_path_straight(*args, **kwargs)

        return self._compute_edge_path_curved(
            tension,
            *args,
            ports=ports,
            **kwargs,
        )

    def _compute_edge_path_waypoints(
        self,
        waypoints,
        vcoord_data,
        vpath_fig,
        vsize_fig,
        trans,
        trans_inv,
        points_per_curve=30,
        **kwargs,
    ):
        """Build an edge path that passes through intermediate waypoints.

        Supported waypoint modes:
            "x0y1" / "y0x1": one elbow (cartesian layouts only).
            "r0a1": an arc at the inner radius followed by a radial segment
                (polar layouts only), sampled with ``points_per_curve`` points.

        Returns:
            The path (in data coordinates) and the pair of edge angles at the
            two ends.

        Raises:
            ValueError: If the layout coordinate system does not match the mode.
            NotImplementedError: For unknown waypoint modes.
        """
        if waypoints in ("x0y1", "y0x1"):
            if self._layout_coordinate_system != "cartesian":
                raise ValueError(
                    f"Waypoints {waypoints} require a cartesian layout.",
                )

            # Coordinates in figure (default) coords
            vcoord_fig = trans(vcoord_data)

            if waypoints == "x0y1":
                waypoint = np.array([vcoord_fig[0][0], vcoord_fig[1][1]])
            else:
                waypoint = np.array([vcoord_fig[1][0], vcoord_fig[0][1]])

            # Angles of the straight lines from each vertex towards the elbow
            theta0 = atan2(*((waypoint - vcoord_fig[0])[::-1]))
            theta1 = atan2(*((waypoint - vcoord_fig[1])[::-1]))

            # Shorten at starting vertex
            vs = (
                self._get_shorter_edge_coords(vpath_fig[0], vsize_fig[0], theta0)
                + vcoord_fig[0]
            )

            # Shorten at end vertex
            ve = (
                self._get_shorter_edge_coords(vpath_fig[1], vsize_fig[1], theta1)
                + vcoord_fig[1]
            )

            points = [vs, waypoint, ve]
            codes = ["MOVETO", "LINETO", "LINETO"]
            angles = (theta0, theta1)
        elif waypoints == "r0a1":
            if self._layout_coordinate_system != "polar":
                raise ValueError(
                    f"Waypoints {waypoints} require a polar layout.",
                )

            r0, alpha0 = vcoord_data[0]
            r1, alpha1 = vcoord_data[1]
            idx_inner = np.argmin([r0, r1])
            idx_outer = 1 - idx_inner
            alpha_outer = [alpha0, alpha1][idx_outer]

            # FIXME: this is aware of chirality as stored by the layout function
            betas = np.linspace(alpha0, alpha1, points_per_curve)
            arc_points = [r0, r1][idx_inner] * np.vstack(
                [np.cos(betas), np.sin(betas)]
            ).T
            endpoint = [r0, r1][idx_outer] * np.array(
                [np.cos(alpha_outer), np.sin(alpha_outer)]
            )
            points = np.array(list(arc_points) + [endpoint])
            points = trans(points)
            codes = ["MOVETO"] + ["LINETO"] * len(arc_points)
            # FIXME: same as previous comment
            angles = (alpha0 + pi / 2, alpha1)

        else:
            raise NotImplementedError(
                f"Edge shortening with waypoints not implemented yet: {waypoints}.",
            )

        path = mpl.path.Path(
            points,
            codes=[getattr(mpl.path.Path, x) for x in codes],
        )

        path.vertices = trans_inv(path.vertices)
        return path, angles

    def _compute_edge_path_straight(
        self,
        vcoord_data,
        vpath_fig,
        vsize_fig,
        trans,
        trans_inv,
        **kwargs,
    ):
        """Build a straight edge shortened at both vertex borders.

        Returns:
            The path (in data coordinates) and the angles (theta, theta + pi)
            of the edge direction at the start and end vertices.
        """
        # Coordinates in figure (default) coords
        vcoord_fig = trans(vcoord_data)

        # Angle of the straight line
        theta = atan2(*((vcoord_fig[1] - vcoord_fig[0])[::-1]))

        # Shorten at starting vertex
        vs = (
            self._get_shorter_edge_coords(vpath_fig[0], vsize_fig[0], theta)
            + vcoord_fig[0]
        )

        # Shorten at end vertex (the line reaches it from the opposite side)
        ve = (
            self._get_shorter_edge_coords(vpath_fig[1], vsize_fig[1], theta + pi)
            + vcoord_fig[1]
        )

        path = mpl.path.Path(
            [vs, ve],
            codes=[mpl.path.Path.MOVETO, mpl.path.Path.LINETO],
        )
        path.vertices = trans_inv(path.vertices)
        return path, (theta, theta + np.pi)

    def _compute_edge_path_curved(
        self,
        tension,
        vcoord_data,
        vpath_fig,
        vsize_fig,
        trans,
        trans_inv,
        ports=(None, None),
    ):
        """Shorten the edge path along a cubic Bezier between the vertex centres.

        The most important part is that the derivative of the Bezier at the start
        and end point towards the vertex centres: people notice if they do not.

        Returns:
            The path (in data coordinates) and the exit angles at both vertices.
        """
        # Coordinates in figure (default) coords
        vcoord_fig = trans(vcoord_data)

        dv = vcoord_fig[1] - vcoord_fig[0]
        edge_straight_length = np.sqrt((dv**2).sum())

        auxs = [None, None]
        for i in range(2):
            if ports[i] is not None:
                der = _get_port_unit_vector(ports[i], trans_inv)
                auxs[i] = der * edge_straight_length * tension + vcoord_fig[i]

        # Both ports defined, just use them and hope for the best
        # Obviously, if the user specifies ports that make no sense,
        # this is going to be a (technically valid) mess.
        if all(aux is not None for aux in auxs):
            pass

        # If no ports are specified (the most common case), compute
        # the Bezier and shorten it
        elif all(aux is None for aux in auxs):
            # Put auxs along the way
            auxs = np.array(
                [
                    vcoord_fig[0] + 0.33 * dv,
                    vcoord_fig[1] - 0.33 * dv,
                ]
            )
            # Right rotation from the straight edge
            dv_rot = -0.1 * dv @ np.array([[0, 1], [-1, 0]])
            # Shift the auxs orthogonal to the straight edge
            auxs += dv_rot * tension

        # First port is defined
        elif (auxs[0] is not None) and (auxs[1] is None):
            auxs[1] = auxs[0]

        # Second port is defined
        else:
            auxs[0] = auxs[1]

        vs = [None, None]
        thetas = [None, None]
        for i in range(2):
            thetas[i] = atan2(*((auxs[i] - vcoord_fig[i])[::-1]))
            vs[i] = (
                self._get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], thetas[i])
                + vcoord_fig[i]
            )

        path = mpl.path.Path(
            [vs[0], auxs[0], auxs[1], vs[1]],
            codes=[mpl.path.Path.MOVETO] + [mpl.path.Path.CURVE4] * 3,
        )

        # Return to data transform
        path.vertices = trans_inv(path.vertices)
        return path, tuple(thetas)

    def _update_labels(self):
        """Position edge labels using ``_compute_mid_coord_and_rot`` on each edge path."""
        if self._labels is None:
            return

        # A missing (or None) label style must not crash: fall back to {}
        style = self._style.get("label") or {}
        trans = self.get_transform().transform

        # NOTE(review): rotations are only applied when "rotate" is False, which
        # reads inverted relative to the option name. Behaviour kept as-is
        # pending confirmation of the intended semantics.
        apply_rotations = not style.get("rotate", True)

        offsets = []
        rotations = []
        for path in self._paths:
            offset, rotation = _compute_mid_coord_and_rot(path, trans)
            offsets.append(offset)
            rotations.append(rotation)

        self._label_collection.set_offsets(offsets)
        if apply_rotations:
            self._label_collection.set_rotations(rotations)

    def _update_arrows(
        self,
        which: str = "end",
    ) -> None:
        """Extract the start and/or end angles of the paths to compute arrows.

        Parameters:
            which: Which end of the edge to put an arrow on. Currently only "end" is
                accepted (the value is not read: arrows are always placed at the end).

        NOTE: This function does *not* update the arrow sizes/_transforms to the correct dpi scaling.
        That's ok since the correct dpi scaling is set whenever there is a different figure (before
        first draw) and whenever a draw is called.
        """
        if not hasattr(self, "_arrows"):
            return

        trans = self.get_transform().transform

        for i, epath in enumerate(self.get_paths()):
            # Offset the arrow to point to the end of the edge
            self._arrows._offsets[i] = epath.vertices[-1]

            # Rotate the arrow to point in the direction of the edge
            apath = self._arrows._paths[i]
            # NOTE: because the tip of the arrow is at (0, 0) in patch space,
            # in theory it will rotate around that point already
            v2 = trans(epath.vertices[-1])
            v1 = trans(epath.vertices[-2])
            dv = v2 - v1
            theta = atan2(*(dv[::-1]))
            # Rotate incrementally from the previously stored angle
            dtheta = theta - self._arrows._angles[i]
            mrot = np.array([[cos(dtheta), sin(dtheta)], [-sin(dtheta), cos(dtheta)]])
            apath.vertices = apath.vertices @ mrot
            self._arrows._angles[i] = theta

    @_stale_wrapper
    def draw(self, renderer):
        """Recompute paths and children, then draw edges followed by children."""
        # Visibility affects the children too
        if not self.get_visible():
            return

        self._update_paths()
        # This sets the arrow offsets
        self._update_children()

        super().draw(renderer)
        for child in self.get_children():
            # This sets the arrow sizes with dpi scaling
            child.draw(renderer)

    @property
    def stale(self):
        return super().stale

    @stale.setter
    def stale(self, val):
        mpl.collections.PatchCollection.stale.fset(self, val)
        # Additionally notify an optional post-stale callback, if one is attached
        if val and hasattr(self, "stale_callback_post"):
            self.stale_callback_post(self)

    @staticmethod
    def _compute_loops_per_angle(nloops, angles):
        """Distribute loops among the angular gaps left free by other edges.

        Parameters:
            nloops: Number of self-loops on the vertex.
            angles: Angles (radians) at which non-loop edges leave the vertex.

        Returns:
            A list of (theta_start, theta_end, n_loops_in_wedge) tuples.
        """
        if len(angles) == 0:
            return [(0, 2 * pi, nloops)]

        angles_sorted_closed = sorted(angles)
        angles_sorted_closed.append(angles_sorted_closed[0] + 2 * pi)
        deltas = np.diff(angles_sorted_closed)

        # Now we have the deltas and the total number of loops
        # 1. Assign all loops to the largest wedge
        idx_dmax = deltas.argmax()
        largest_wedge = (
            angles_sorted_closed[idx_dmax],
            angles_sorted_closed[idx_dmax + 1],
        )
        if nloops == 1:
            return [(*largest_wedge, nloops)]

        # 2. Check if any other wedges are larger than this
        # If not, we are done (this is the algo in igraph)
        dsplit = deltas[idx_dmax] / nloops
        if (deltas > dsplit).sum() < 2:
            return [(*largest_wedge, nloops)]

        # 3. Move one loop to the second-largest wedge
        idx_second = np.argsort(deltas)[-2]
        return [
            (*largest_wedge, nloops - 1),
            (
                angles_sorted_closed[idx_second],
                angles_sorted_closed[idx_second + 1],
                1,
            ),
        ]

    @staticmethod
    def _get_shorter_edge_coords(vpath, vsize, theta):
        """Find where a ray from the vertex centre at angle theta crosses the vertex border.

        Presumably ``vpath`` is centred on the origin (angles of its vertices are
        measured with atan2 around (0, 0)). The crossing point is returned
        relative to the vertex centre and scaled by ``vsize``.

        Raises:
            ValueError: If no side of the vertex path spans the angle theta.
        """
        # Bound theta from -pi to pi (why is that not guaranteed?)
        theta = (theta + pi) % (2 * pi) - pi

        # Size zero vertices need no shortening
        if vsize == 0:
            return np.array([0, 0])

        for i in range(len(vpath)):
            v1 = vpath.vertices[i]
            v2 = vpath.vertices[(i + 1) % len(vpath)]
            theta1 = atan2(*((v1)[::-1]))
            theta2 = atan2(*((v2)[::-1]))

            # atan2 ranges ]-3.14, 3.14]
            # so it can be that theta1 is -3 and theta2 is +3
            # therefore we need two separate cases, one that cuts at pi and one at 0
            cond1 = theta1 <= theta <= theta2
            cond2 = (
                (theta1 + 2 * pi) % (2 * pi)
                <= (theta + 2 * pi) % (2 * pi)
                <= (theta2 + 2 * pi) % (2 * pi)
            )
            if cond1 or cond2:
                break
        else:
            raise ValueError("Angle for patch not found")

        # The edge meets the patch of the vertex on the v1-v2 side,
        # at angle theta from the center
        mtheta = tan(theta)
        if v2[0] == v1[0]:
            xe = v1[0]
        else:
            m12 = (v2[1] - v1[1]) / (v2[0] - v1[0])
            xe = (v1[1] - m12 * v1[0]) / (mtheta - m12)
        ye = mtheta * xe
        ve = np.array([xe, ye])
        return ve * vsize
847
+
848
+
849
def make_stub_patch(**kwargs):
    """Make a stub undirected edge patch, without actual path information.

    Parameters:
        **kwargs: Edge style properties. "color" is translated to "edgecolor"
            (unless "edgecolor" is also given), "clip_on" defaults to True and
            iplotx-specific keys that matplotlib patches do not accept are
            dropped. Everything else is forwarded to ``PathPatch``.

    Returns:
        A ``PathPatch`` with a placeholder path and the requested style.
    """
    kwargs.setdefault("clip_on", True)
    if ("color" in kwargs) and ("edgecolor" not in kwargs):
        kwargs["edgecolor"] = kwargs.pop("color")

    # Edges are always hollow, because they are not closed paths
    # NOTE: This is supposed to cascade onto what boolean flags are set
    # for color mapping (Colorizer)
    kwargs["facecolor"] = "none"

    # Forget specific properties that are not supported by matplotlib patches.
    # "norm" travels together with "cmap" in edge styles (both are consumed by
    # EdgeCollection), and PathPatch rejects it just like "cmap".
    forbidden_props = (
        "arrow",
        "label",
        "curved",
        "tension",
        "waypoints",
        "ports",
        "looptension",
        "loopmaxangle",
        "offset",
        "cmap",
        "norm",
    )
    for prop in forbidden_props:
        kwargs.pop(prop, None)

    # NOTE: the path is overwritten later anyway, so no reason to spend any time here
    return mpl.patches.PathPatch(
        mpl.path.Path([[0, 0]]),
        **kwargs,
    )