iplotx 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iplotx/cascades.py ADDED
@@ -0,0 +1,223 @@
1
+ from typing import (
2
+ Any,
3
+ Optional,
4
+ )
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from .typing import (
9
+ TreeType,
10
+ )
11
+ from .ingest.typing import (
12
+ TreeDataProvider,
13
+ )
14
+ import matplotlib as mpl
15
+
16
+ from .style import (
17
+ copy_with_deep_values,
18
+ rotate_style,
19
+ )
20
+
21
+
22
class CascadeCollection(mpl.collections.PatchCollection):
    """Collection of "cascading" patches highlighting subtrees of a tree.

    One patch is built for each node that appears as a key in the
    ``facecolor`` and/or ``edgecolor`` style mappings. Each patch starts at
    the base of the branch leading into that node and covers all leaves
    below it. Patches are added in preorder, so patches of nodes closer to
    the root are drawn first.
    """

    def __init__(
        self,
        tree: TreeType,
        layout: pd.DataFrame,
        layout_name: str,
        orientation: str,
        style: dict[str, Any],
        provider: TreeDataProvider,
        transform: mpl.transforms.Transform,
        maxdepth: Optional[float] = None,
    ):
        """Build the cascading patches.

        Parameters:
            tree: The tree whose subtrees are highlighted.
            layout: Node coordinates indexed by node. The columns are (x, y)
                for "horizontal"/"vertical" layouts and (r, theta) for
                "radial" layouts, with theta in radians.
            layout_name: One of "horizontal", "vertical", "radial".
            orientation: "right"/"left" for horizontal layouts and
                "descending"/"ascending" for vertical layouts. It is not
                used for radial layouts.
            style: Cascade style.
                - "facecolor"/"edgecolor" are per-node mappings. "color"
                  sets both.
                - "extend" makes patches reach ``maxdepth`` instead of
                  stopping at the deepest leaf of each subtree.
                - "zorder" sets the collection zorder.
                - Remaining entries are forwarded to each patch.
            provider: Callable wrapping a tree (or a node) into a tree data
                provider. It is used for traversal, leaves and branch
                lengths.
            transform: Transform applied to the whole collection.
            maxdepth: Depth that patches extend to when ``style["extend"]``
                is "leaf_labels". For any other truthy "extend" value, it is
                recomputed from the layout.

        Raises:
            NotImplementedError: If ``layout_name`` is not supported.
        """
        self._layout_name = layout_name
        self._orientation = orientation
        style = copy_with_deep_values(style)
        zorder = style.get("zorder", 0)

        if layout_name not in ("horizontal", "vertical", "radial"):
            raise NotImplementedError(
                f"Cascading patches not implemented for layout: {layout_name}.",
            )

        # NOTE: there is a weird bug in pandas when using generic Hashable-s
        # with .loc. Seems like doing .T[...] works for individual index
        # elements only though. Transpose once instead of once per lookup.
        layout_t = layout.T

        def get_node_coords(node):
            return layout_t[node].values

        def get_leaves_coords(leaves):
            return np.array([get_node_coords(leaf) for leaf in leaves])

        if "color" in style:
            style["facecolor"] = style["edgecolor"] = style.pop("color")
        extend = style.get("extend", False)

        # These patches need at least a facecolor (usually) or an edgecolor.
        # The nodes to draw are the union of the keys of those mappings
        # (assumed to be dicts keyed by node).
        nodes_to_draw = set()
        for prop in ("facecolor", "edgecolor"):
            if prop in style:
                nodes_to_draw |= set(style[prop].keys())

        # Draw the patches from the closest to the root (earlier drawing)
        # to the closer to the leaves (later drawing).
        drawing_order = [
            node for node in provider(tree).preorder() if node in nodes_to_draw
        ]

        nleaves = sum(1 for _ in provider(tree).get_leaves())

        # In "leaf_labels" mode the caller-provided maxdepth is kept.
        # Otherwise extended patches reach the deepest point of the layout.
        if extend and (extend != "leaf_labels"):
            maxdepth = self._layout_maxdepth(layout, layout_name, orientation)
        self._maxdepth = maxdepth

        cascading_patches = []
        for node in drawing_order:
            stylei = rotate_style(style, key=node)
            stylei.pop("extend", None)
            # Default alpha is 0.5 for simple colors
            if isinstance(stylei.get("facecolor", None), str) and (
                "alpha" not in stylei
            ):
                stylei["alpha"] = 0.5

            provider_node = provider(node)
            bl = provider_node.get_branch_length_default_to_one(node)
            node_coords = get_node_coords(node).copy()
            leaves_coords = get_leaves_coords(provider_node.get_leaves())
            # No leaves below this node: fall back to its own coordinates.
            if len(leaves_coords) == 0:
                leaves_coords = np.array([node_coords])

            if layout_name == "radial":
                patch = self._radial_patch(
                    node_coords, leaves_coords, bl, maxdepth, extend, nleaves, stylei
                )
            else:
                # The +/-0.5 margin presumably equals half the spacing between
                # adjacent leaves (assumed to be 1) - TODO confirm.
                if layout_name == "horizontal":
                    ybot = leaves_coords[:, 1].min() - 0.5
                    ytop = leaves_coords[:, 1].max() + 0.5
                    if orientation == "right":
                        xleft = node_coords[0] - bl
                        xright = maxdepth if extend else leaves_coords[:, 0].max()
                    else:
                        xleft = maxdepth if extend else leaves_coords[:, 0].min()
                        xright = node_coords[0] + bl
                else:  # vertical
                    xleft = leaves_coords[:, 0].min() - 0.5
                    xright = leaves_coords[:, 0].max() + 0.5
                    if orientation == "descending":
                        ytop = node_coords[1] + bl
                        ybot = maxdepth if extend else leaves_coords[:, 1].min()
                    else:
                        ytop = maxdepth if extend else leaves_coords[:, 1].max()
                        ybot = node_coords[1] - bl

                patch = mpl.patches.Rectangle(
                    (xleft, ybot),
                    xright - xleft,
                    ytop - ybot,
                    **stylei,
                )

            cascading_patches.append(patch)

        super().__init__(
            cascading_patches,
            transform=transform,
            match_original=True,
            zorder=zorder,
        )

    @staticmethod
    def _layout_maxdepth(layout, layout_name, orientation):
        """Return the outermost layout coordinate along the depth axis."""
        values = layout.values
        if layout_name == "horizontal":
            depths = values[:, 0]
            return depths.max() if orientation == "right" else depths.min()
        if layout_name == "vertical":
            depths = values[:, 1]
            return depths.min() if orientation == "descending" else depths.max()
        # radial: layout values are (r, theta)
        return values[:, 0].max()

    @staticmethod
    def _radial_patch(node_coords, leaves_coords, bl, maxdepth, extend, nleaves, style):
        """Build an annular-sector PathPatch for one node in a radial layout.

        Vertex order is: inner arc at rmin, outer arc at rmax (traversed
        backwards), then a closing vertex equal to the first one.
        ``_update_maxdepth`` relies on this order.
        """
        dtheta = 2 * np.pi / nleaves
        rmin = node_coords[0] - bl
        rmax = maxdepth if extend else leaves_coords[:, 0].max()
        thetamin = leaves_coords[:, 1].min() - 0.5 * dtheta
        thetamax = leaves_coords[:, 1].max() + 0.5 * dtheta
        # np.linspace requires an integer number of samples.
        npoints = max(30, int((thetamax - thetamin) // 3))
        thetas = np.linspace(thetamin, thetamax, npoints)
        xs = list(rmin * np.cos(thetas)) + list(rmax * np.cos(thetas[::-1]))
        ys = list(rmin * np.sin(thetas)) + list(rmax * np.sin(thetas[::-1]))
        points = list(zip(xs, ys))
        points.append(points[0])
        codes = (
            [mpl.path.Path.MOVETO]
            + [mpl.path.Path.LINETO] * (len(points) - 2)
            + [mpl.path.Path.CLOSEPOLY]
        )

        style.setdefault("edgecolor", "none")

        path = mpl.path.Path(points, codes=codes)
        return mpl.patches.PathPatch(path, **style)

    def get_maxdepth(self) -> float:
        """Get the maxdepth of the cascades.

        Returns: The maximum depth of the cascading patches.
        """
        return self._maxdepth

    def set_maxdepth(self, maxdepth: float):
        """Set the maximum depth of the cascading patches.

        Parameters:
            maxdepth: The new maximum depth for the cascades.

        NOTE: Calling this function updates the cascade patches
        without checking whether the "extend" style requires it.
        """
        self._maxdepth = maxdepth
        self._update_maxdepth()

    def _update_maxdepth(self):
        """Move the depth-side edge of every patch to the current maxdepth.

        Note: This function changes the paths without checking whether
        the "extend" style is set or not.

        Raises:
            ValueError: If the layout/orientation combination is not
                supported.
        """
        layout_name = self._layout_name
        orientation = self._orientation
        maxdepth = self.get_maxdepth()

        # This being a PatchCollection, we have to touch the paths
        if layout_name == "radial":
            for path in self.get_paths():
                vertices = path.vertices
                # Layout: inner arc, outer arc, closing vertex. Only the
                # outer arc is rescaled. The closing vertex duplicates the
                # first (inner) vertex and must stay on the inner arc.
                nhalf = (len(vertices) - 1) // 2
                r2old = np.linalg.norm(vertices[-2])
                vertices[nhalf:-1] *= maxdepth / r2old
            self.stale = True
            return

        # Rectangle paths are: (xleft, ybot), (xright, ybot), (xright, ytop),
        # (xleft, ytop), plus a closing copy of the first vertex.
        if (layout_name, orientation) == ("horizontal", "right"):
            indices, axis = [1, 2], 0  # xright
        elif (layout_name, orientation) == ("horizontal", "left"):
            indices, axis = [0, 3, 4], 0  # xleft
        elif (layout_name, orientation) == ("vertical", "descending"):
            indices, axis = [0, 1, 4], 1  # ybot
        elif (layout_name, orientation) == ("vertical", "ascending"):
            indices, axis = [2, 3], 1  # ytop
        else:
            raise ValueError(
                f"Layout name and orientation not supported: {layout_name}, {orientation}."
            )

        for path in self.get_paths():
            path.vertices[indices, axis] = maxdepth
        self.stale = True