iplotx 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iplotx/ingest/typing.py CHANGED
@@ -11,14 +11,20 @@ from typing import (
11
11
  Protocol,
12
12
  Optional,
13
13
  Sequence,
14
+ Any,
15
+ Iterable,
14
16
  )
15
17
  from collections.abc import Hashable
18
+ import numpy as np
16
19
  import pandas as pd
17
20
  from ..typing import (
18
21
  GraphType,
19
22
  LayoutType,
20
23
  TreeType,
21
24
  )
25
+ from .heuristics import (
26
+ normalise_tree_layout,
27
+ )
22
28
 
23
29
 
24
30
  class NetworkData(TypedDict):
@@ -44,15 +50,13 @@ class NetworkDataProvider(Protocol):
44
50
  """Create network data object for iplotx from any provider."""
45
51
  raise NotImplementedError("Network data providers must implement this method.")
46
52
 
47
- def check_dependencies(
48
- self,
49
- ):
53
+ @staticmethod
54
+ def check_dependencies():
50
55
  """Check whether the dependencies for this provider are installed."""
51
56
  raise NotImplementedError("Network data providers must implement this method.")
52
57
 
53
- def graph_type(
54
- self,
55
- ):
58
+ @staticmethod
59
+ def graph_type():
56
60
  """Return the graph type from this provider to check for instances."""
57
61
  raise NotImplementedError("Network data providers must implement this method.")
58
62
 
@@ -63,11 +67,12 @@ class TreeData(TypedDict):
63
67
  rooted: bool
64
68
  directed: bool | str
65
69
  root: Optional[Hashable]
66
- leaves: list[Hashable]
70
+ leaf_df: pd.DataFrame
67
71
  vertex_df: dict[Hashable, tuple[float, float]]
68
72
  edge_df: dict[Hashable, Sequence[tuple[float, float]]]
69
73
  layout_coordinate_system: str
70
74
  layout_name: str
75
+ orientation: str
71
76
  ndim: int
72
77
  tree_library: NotRequired[str]
73
78
 
@@ -75,26 +80,224 @@ class TreeData(TypedDict):
75
80
  class TreeDataProvider(Protocol):
76
81
  """Protocol for tree data ingestion provider for iplotx."""
77
82
 
78
- def __call__(
83
+ def __init__(
79
84
  self,
80
85
  tree: TreeType,
81
- layout: str | LayoutType,
82
- orientation: Optional[str] = None,
83
- directed: bool | str = False,
84
- vertex_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
85
- edge_labels: Optional[Sequence[str] | dict] = None,
86
- ) -> TreeData:
87
- """Create tree data object for iplotx from any provider."""
88
- raise NotImplementedError("Tree data providers must implement this method.")
86
+ ) -> None:
87
+ """Initialize the provider with the tree type.
89
88
 
90
- def check_dependencies(
91
- self,
92
- ):
89
+ Parameters:
90
+ tree: The tree type that this provider will handle.
91
+ """
92
+ self.tree = tree
93
+
94
+ @staticmethod
95
+ def check_dependencies():
93
96
  """Check whether the dependencies for this provider are installed."""
94
97
  raise NotImplementedError("Tree data providers must implement this method.")
95
98
 
96
- def tree_type(
97
- self,
98
- ):
99
+ @staticmethod
100
+ def tree_type():
99
101
  """Return the tree type from this provider to check for instances."""
100
102
  raise NotImplementedError("Tree data providers must implement this method.")
103
+
104
+ def is_rooted(self) -> bool:
105
+ """Get whether the tree is rooted.
106
+
107
+ Returns:
108
+ A boolean indicating whether the tree is rooted.
109
+
110
+ Note: This is a default implementation that can be overridden by the provider
111
+ if they support unrooted trees (e.g. Biopython).
112
+ """
113
+ return True
114
+
115
+ def get_root(self) -> Any:
116
+ """Get the tree root in a provider-specific data structure.
117
+
118
+ Returns:
119
+ The root of the tree.
120
+
121
+ Note: This is a default implementation that can be overridden by the provider.
122
+ """
123
+ root_attr = self.tree.root
124
+ if callable(root_attr):
125
+ return root_attr()
126
+ else:
127
+ return root_attr
128
+
129
+ def get_leaves(self) -> Sequence[Any]:
130
+ """Get the tree leaves/tips in a provider-specific data structure.
131
+
132
+ Returns:
133
+ The leaves or tips of the tree.
134
+ """
135
+ raise NotImplementedError("Tree data providers must implement this method.")
136
+
137
+ def preorder(self) -> Iterable[Any]:
138
+ """Preorder (DFS - parent first) iteration over the tree.
139
+
140
+ Returns:
141
+ An iterable of nodes in preorder traversal.
142
+ """
143
+ raise NotImplementedError("Tree data providers must implement this method.")
144
+
145
+ def postorder(self) -> Iterable[Any]:
146
+ """Postorder (DFS - child first) iteration over the tree.
147
+
148
+ Returns:
149
+ An iterable of nodes in postorder traversal.
150
+ """
151
+ raise NotImplementedError("Tree data providers must implement this method.")
152
+
153
+ @staticmethod
154
+ def get_children(
155
+ node: Any,
156
+ ) -> Sequence[Any]:
157
+ """Get the children of a node.
158
+
159
+ Parameters:
160
+ node: The node to get the children from.
161
+ Returns:
162
+ A sequence of children nodes.
163
+ """
164
+ raise NotImplementedError("Tree data providers must implement this method.")
165
+
166
+ @staticmethod
167
+ def get_branch_length(
168
+ node: Any,
169
+ ) -> Optional[float]:
170
+ """Get the length of the branch to this node.
171
+
172
+ Parameters:
173
+ node: The node to get the branch length from.
174
+ Returns:
175
+ The branch length to the node.
176
+ """
177
+ raise NotImplementedError("Tree data providers must implement this method.")
178
+
179
+ def get_branch_length_default_to_one(
180
+ self,
181
+ node: Any,
182
+ ) -> float:
183
+ """Get the length of the branch to this node, defaulting to 1.0 if not available.
184
+
185
+ Parameters:
186
+ node: The node to get the branch length from.
187
+ Returns:
188
+ The branch length to the node, defaulting to 1.0 if not available.
189
+ """
190
+ branch_length = self.get_branch_length(node)
191
+ return branch_length if branch_length is not None else 1.0
192
+
193
+ def __call__(
194
+ self,
195
+ layout: str | LayoutType,
196
+ orientation: Optional[str],
197
+ layout_style: Optional[dict[str, int | float | str]] = None,
198
+ directed: bool | str = False,
199
+ vertex_labels: Optional[
200
+ Sequence[str] | dict[Hashable, str] | pd.Series | bool
201
+ ] = None,
202
+ edge_labels: Optional[Sequence[str] | dict] = None,
203
+ leaf_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
204
+ ) -> TreeData:
205
+ """Create tree data object for iplotx from any provider."""
206
+
207
+ if layout_style is None:
208
+ layout_style = {}
209
+
210
+ if orientation is None:
211
+ if layout == "horizontal":
212
+ orientation = "right"
213
+ elif layout == "vertical":
214
+ orientation = "descending"
215
+ elif layout == "radial":
216
+ orientation = "clockwise"
217
+
218
+ tree_data = {
219
+ "root": self.get_root(),
220
+ "rooted": self.is_rooted(),
221
+ "directed": directed,
222
+ "ndim": 2,
223
+ "layout_name": layout,
224
+ "orientation": orientation,
225
+ }
226
+
227
+ # Add vertex_df including layout
228
+ tree_data["vertex_df"] = normalise_tree_layout(
229
+ layout,
230
+ orientation=orientation,
231
+ root=tree_data["root"],
232
+ preorder_fun=self.preorder,
233
+ postorder_fun=self.postorder,
234
+ children_fun=self.get_children,
235
+ branch_length_fun=self.get_branch_length_default_to_one,
236
+ **layout_style,
237
+ )
238
+ if layout in ("radial",):
239
+ tree_data["layout_coordinate_system"] = "polar"
240
+ else:
241
+ tree_data["layout_coordinate_system"] = "cartesian"
242
+
243
+ # Add edge_df
244
+ edge_data = {"_ipx_source": [], "_ipx_target": []}
245
+ for node in self.preorder():
246
+ for child in self.get_children(node):
247
+ if directed == "parent":
248
+ edge_data["_ipx_source"].append(child)
249
+ edge_data["_ipx_target"].append(node)
250
+ else:
251
+ edge_data["_ipx_source"].append(node)
252
+ edge_data["_ipx_target"].append(child)
253
+ edge_df = pd.DataFrame(edge_data)
254
+ tree_data["edge_df"] = edge_df
255
+
256
+ # Add leaf_df
257
+ tree_data["leaf_df"] = pd.DataFrame(index=self.get_leaves())
258
+
259
+ # Add vertex labels
260
+ if vertex_labels is None:
261
+ vertex_labels = False
262
+ if np.isscalar(vertex_labels) and vertex_labels:
263
+ tree_data["vertex_df"]["label"] = [
264
+ x.name for x in tree_data["vertex_df"].index
265
+ ]
266
+ elif not np.isscalar(vertex_labels):
267
+ # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
268
+ # we fill the rest with empty strings which are not going to show up in the plot.
269
+ if isinstance(vertex_labels, pd.Series):
270
+ vertex_labels = dict(vertex_labels)
271
+ if isinstance(vertex_labels, dict):
272
+ for vertex in tree_data["vertex_df"].index:
273
+ if vertex not in vertex_labels:
274
+ vertex_labels[vertex] = ""
275
+ tree_data["vertex_df"]["label"] = pd.Series(vertex_labels)
276
+
277
+ # Add leaf labels
278
+ if leaf_labels is None:
279
+ leaf_labels = False
280
+ if np.isscalar(leaf_labels) and leaf_labels:
281
+ tree_data["leaf_labels"]["label"] = [
282
+ # FIXME: this is likely broken: tree_data has no "leaf_labels" key at this point (probably meant "leaf_df")
283
+ x.name
284
+ for x in tree_data["leaf_df"].index
285
+ ]
286
+ elif not np.isscalar(leaf_labels):
287
+ # Leaves are already in the dataframe in a certain order, so sequences are allowed
288
+ if isinstance(leaf_labels, (list, tuple, np.ndarray)):
289
+ leaf_labels = {
290
+ leaf: label
291
+ for leaf, label in zip(tree_data["leaf_df"].index, leaf_labels)
292
+ }
293
+ # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
294
+ # we fill the rest with empty strings which are not going to show up in the plot.
295
+ if isinstance(leaf_labels, pd.Series):
296
+ leaf_labels = dict(leaf_labels)
297
+ if isinstance(leaf_labels, dict):
298
+ for leaf in tree_data["leaf_df"].index:
299
+ if leaf not in leaf_labels:
300
+ leaf_labels[leaf] = ""
301
+ tree_data["leaf_df"]["label"] = pd.Series(leaf_labels)
302
+
303
+ return tree_data
iplotx/label.py CHANGED
@@ -1,8 +1,13 @@
1
+ """
2
+ Module for label collection in iplotx.
3
+ """
4
+
1
5
  from typing import (
2
6
  Optional,
3
7
  Sequence,
4
8
  )
5
9
  import numpy as np
10
+ import pandas as pd
6
11
  import matplotlib as mpl
7
12
 
8
13
  from .style import (
@@ -26,13 +31,28 @@ from .utils.matplotlib import (
26
31
  )
27
32
  )
28
33
  class LabelCollection(mpl.artist.Artist):
34
+ """Collection of labels for iplotx with styles.
35
+
36
+ NOTE: This class is not a subclass of `mpl.collections.Collection`, although in some ways it
37
+ behaves like one. It is named LabelCollection quite literally to indicate it contains a list of
38
+ labels for vertices, edges, etc.
39
+ """
40
+
29
41
  def __init__(
30
42
  self,
31
- labels: Sequence[str],
43
+ labels: pd.Series,
32
44
  style: Optional[dict[str, dict]] = None,
33
45
  offsets: Optional[np.ndarray] = None,
34
46
  transform: mpl.transforms.Transform = mpl.transforms.IdentityTransform(),
35
- ):
47
+ ) -> None:
48
+ """Initialize a collection of labels.
49
+
50
+ Parameters:
51
+ labels: A sequence of labels to be displayed.
52
+ style: A dictionary of styles to apply to the labels. The keys are style properties.
53
+ offsets: A sequence of offsets for each label, specifying the position of the label.
54
+ transform: A transform to apply to the labels. This is usually ax.transData.
55
+ """
36
56
  self._labels = labels
37
57
  self._offsets = offsets if offsets is not None else np.zeros((len(labels), 2))
38
58
  self._style = style
@@ -41,19 +61,25 @@ class LabelCollection(mpl.artist.Artist):
41
61
  self.set_transform(transform)
42
62
  self._create_artists()
43
63
 
44
- def get_children(self):
64
+ def get_children(self) -> tuple[mpl.artist.Artist]:
65
+ """Get the children of this artist, which are the label artists."""
45
66
  return tuple(self._labelartists)
46
67
 
47
- def set_figure(self, figure):
48
- super().set_figure(figure)
68
+ def set_figure(self, fig) -> None:
69
+ """Set the figure of this artist.
70
+
71
+ Parameters:
72
+ fig: The figure to set.
73
+ """
74
+ super().set_figure(fig)
49
75
  for child in self.get_children():
50
- child.set_figure(figure)
51
- self._update_offsets(dpi=figure.dpi)
76
+ child.set_figure(fig)
77
+ self._update_offsets(dpi=fig.dpi)
52
78
 
53
- def _get_margins_with_dpi(self, dpi=72.0):
79
+ def _get_margins_with_dpi(self, dpi: float = 72.0) -> np.ndarray:
54
80
  return self._margins * dpi / 72.0
55
81
 
56
- def _create_artists(self):
82
+ def _create_artists(self) -> None:
57
83
  style = copy_with_deep_values(self._style) if self._style is not None else {}
58
84
  transform = self.get_transform()
59
85
 
@@ -72,6 +98,11 @@ class LabelCollection(mpl.artist.Artist):
72
98
  vmargin = stylei.pop("vmargin", 0.0)
73
99
  margins.append((hmargin, vmargin))
74
100
 
101
+ # Initially, ignore autoalignment since we do not know the
102
+ # rotations
103
+ if stylei.get("horizontalalignment") == "auto":
104
+ stylei["horizontalalignment"] = "center"
105
+
75
106
  art = mpl.text.Text(
76
107
  self._offsets[i][0],
77
108
  self._offsets[i][1],
@@ -82,14 +113,20 @@ class LabelCollection(mpl.artist.Artist):
82
113
  arts.append(art)
83
114
  self._labelartists = arts
84
115
  self._margins = np.array(margins)
116
+ self._rotations = np.zeros(len(self._labels))
85
117
 
86
- def _update_offsets(self, dpi=72.0):
118
+ def _update_offsets(self, dpi: float = 72.0) -> None:
87
119
  """Update offsets including margins."""
88
- offsets = self._adjust_offsets_for_margins(self._offsets, dpi=dpi)
89
- self.set_offsets(offsets)
90
-
91
- def get_offsets(self):
92
- return self._offsets
120
+ self.set_offsets(self._offsets, dpi=dpi)
121
+
122
+ def get_offsets(self, with_margins: bool = False) -> np.ndarray:
123
+ """Get the positions (offsets) of the labels."""
124
+ if not with_margins:
125
+ return self._offsets
126
+ else:
127
+ return np.array(
128
+ [art.get_position() for art in self._labelartists],
129
+ )
93
130
 
94
131
  def _adjust_offsets_for_margins(self, offsets, dpi=72.0):
95
132
  margins = self._get_margins_with_dpi(dpi=dpi)
@@ -97,24 +134,69 @@ class LabelCollection(mpl.artist.Artist):
97
134
  transform = self.get_transform()
98
135
  trans = transform.transform
99
136
  trans_inv = transform.inverted().transform
100
- offsets = trans_inv(trans(offsets) + margins)
137
+ rotations = self.get_rotations()
138
+ vrot = [np.cos(rotations), np.sin(rotations)]
139
+
140
+ margins_rot = np.empty_like(margins)
141
+ margins_rot[:, 0] = margins[:, 0] * vrot[0] - margins[:, 1] * vrot[1]
142
+ margins_rot[:, 1] = margins[:, 0] * vrot[1] + margins[:, 1] * vrot[0]
143
+ offsets = trans_inv(trans(offsets) + margins_rot)
101
144
  return offsets
102
145
 
103
- def set_offsets(self, offsets):
104
- """Set positions (offsets) of the labels."""
146
+ def set_offsets(self, offsets, dpi: float = 72.0) -> None:
147
+ """Set positions (offsets) of the labels.
148
+
149
+ Parameters:
150
+ offsets: A sequence of offsets for each label, specifying the position of the label.
151
+ """
105
152
  self._offsets = np.asarray(offsets)
106
- for art, offset in zip(self._labelartists, self._offsets):
153
+ offsets_with_margins = self._adjust_offsets_for_margins(offsets, dpi=dpi)
154
+ for art, offset in zip(self._labelartists, offsets_with_margins):
107
155
  art.set_position((offset[0], offset[1]))
108
156
 
109
- def set_rotations(self, rotations):
157
+ def get_rotations(self) -> np.ndarray:
158
+ """Get the rotations of the labels in radians."""
159
+ return self._rotations
160
+
161
+ def set_rotations(self, rotations: Sequence[float]) -> None:
162
+ """Set the rotations of the labels.
163
+
164
+ Parameters:
165
+ rotations: A sequence of rotations in radians for each label.
166
+ """
167
+ self._rotations = np.asarray(rotations)
168
+ ha = self._style.get("horizontalalignment", "center")
110
169
  for art, rotation in zip(self._labelartists, rotations):
111
170
  rot_deg = 180.0 / np.pi * rotation
112
171
  # Force the text to read upright: fold the angle into [-90, 90) degrees
172
+ if ha == "auto":
173
+ if -90 <= rot_deg < 90:
174
+ art.set_horizontalalignment("left")
175
+ else:
176
+ art.set_horizontalalignment("right")
113
177
  rot_deg = ((rot_deg + 90) % 180) - 90
114
178
  art.set_rotation(rot_deg)
115
179
 
180
+ def get_datalim(self, transData=None) -> mpl.transforms.Bbox:
181
+ """Get the data limits of the labels."""
182
+ bboxes = self.get_datalims_children(transData=transData)
183
+ bbox = mpl.transforms.Bbox.union(bboxes)
184
+ return bbox
185
+
186
+ def get_datalims_children(self, transData=None) -> Sequence[mpl.transforms.Bbox]:
187
+ """Get the data limits of the children of this artist."""
188
+ if transData is None:
189
+ transData = self.get_transform()
190
+ trans_inv = transData.inverted().transform_bbox
191
+ bboxes = []
192
+ for art in self._labelartists:
193
+ bbox_fig = art.get_bbox_patch().get_extents()
194
+ bbox_data = trans_inv(bbox_fig)
195
+ bboxes.append(bbox_data)
196
+ return bboxes
197
+
116
198
  @_stale_wrapper
117
- def draw(self, renderer):
199
+ def draw(self, renderer) -> None:
118
200
  """Draw each of the children, with some buffering mechanism."""
119
201
  if not self.get_visible():
120
202
  return
iplotx/layout.py CHANGED
@@ -2,34 +2,48 @@
2
2
  Layout functions, currently limited to trees.
3
3
  """
4
4
 
5
- from collections.abc import Hashable
5
+ from typing import Any
6
+ from collections.abc import (
7
+ Hashable,
8
+ Callable,
9
+ )
6
10
 
7
11
  import numpy as np
8
12
 
9
13
 
10
14
  def compute_tree_layout(
11
- tree,
12
15
  layout: str,
13
16
  orientation: str,
17
+ root: Any,
18
+ preorder_fun: Callable,
19
+ postorder_fun: Callable,
20
+ children_fun: Callable,
21
+ branch_length_fun: Callable,
14
22
  **kwargs,
15
23
  ) -> dict[Hashable, list[float]]:
16
24
  """Compute the layout for a tree.
17
25
 
18
26
  Parameters:
19
- tree: The tree to compute the layout for.
20
- layout: The name of the layout, e.g. "horizontal" or "radial".
21
- orientation: The orientation of the layout, e.g. "right", "left", "descending", or "ascending".
27
+ layout: The name of the layout, e.g. "horizontal", "vertical", or "radial".
28
+ orientation: The orientation of the layout, e.g. "right", "left", "descending",
29
+ "ascending", "clockwise", "anticlockwise".
22
30
 
23
31
  Returns:
24
32
  A layout dictionary with node positions.
25
33
  """
34
+ kwargs["root"] = root
35
+ kwargs["preorder_fun"] = preorder_fun
36
+ kwargs["postorder_fun"] = postorder_fun
37
+ kwargs["children_fun"] = children_fun
38
+ kwargs["branch_length_fun"] = branch_length_fun
39
+ kwargs["orientation"] = orientation
26
40
 
27
41
  if layout == "radial":
28
- layout_dict = _circular_tree_layout(tree, orientation=orientation, **kwargs)
42
+ layout_dict = _radial_tree_layout(**kwargs)
29
43
  elif layout == "horizontal":
30
- layout_dict = _horizontal_tree_layout(tree, orientation=orientation, **kwargs)
44
+ layout_dict = _horizontal_tree_layout(**kwargs)
31
45
  elif layout == "vertical":
32
- layout_dict = _vertical_tree_layout(tree, orientation=orientation, **kwargs)
46
+ layout_dict = _vertical_tree_layout(**kwargs)
33
47
  else:
34
48
  raise ValueError(f"Tree layout not available: {layout}")
35
49
 
@@ -37,12 +51,11 @@ def compute_tree_layout(
37
51
 
38
52
 
39
53
  def _horizontal_tree_layout_right(
40
- tree,
41
- root_fun: callable,
42
- preorder_fun: callable,
43
- postorder_fun: callable,
44
- children_fun: callable,
45
- branch_length_fun: callable,
54
+ root: Any,
55
+ preorder_fun: Callable,
56
+ postorder_fun: Callable,
57
+ children_fun: Callable,
58
+ branch_length_fun: Callable,
46
59
  ) -> dict[Hashable, list[float]]:
47
60
  """Build a tree layout horizontally, left to right.
48
61
 
@@ -57,7 +70,7 @@ def _horizontal_tree_layout_right(
57
70
 
58
71
  # Set the y values for vertices
59
72
  i = 0
60
- for node in postorder_fun(tree):
73
+ for node in postorder_fun():
61
74
  children = children_fun(node)
62
75
  if len(children) == 0:
63
76
  layout[node] = [None, i]
@@ -69,9 +82,8 @@ def _horizontal_tree_layout_right(
69
82
  ]
70
83
 
71
84
  # Set the x values for vertices
72
- layout[root_fun(tree)][0] = 0
73
- for node in preorder_fun(tree):
74
- x0, y0 = layout[node]
85
+ layout[root][0] = 0
86
+ for node in preorder_fun():
75
87
  for child in children_fun(node):
76
88
  bl = branch_length_fun(child)
77
89
  if bl is None:
@@ -82,7 +94,6 @@ def _horizontal_tree_layout_right(
82
94
 
83
95
 
84
96
  def _horizontal_tree_layout(
85
- tree,
86
97
  orientation="right",
87
98
  **kwargs,
88
99
  ) -> dict[Hashable, list[float]]:
@@ -90,22 +101,21 @@ def _horizontal_tree_layout(
90
101
  if orientation not in ("right", "left"):
91
102
  raise ValueError("Orientation must be 'right' or 'left'.")
92
103
 
93
- layout = _horizontal_tree_layout_right(tree, **kwargs)
104
+ layout = _horizontal_tree_layout_right(**kwargs)
94
105
 
95
106
  if orientation == "left":
96
- for key, value in layout.items():
107
+ for key in layout:
97
108
  layout[key][0] *= -1
98
109
  return layout
99
110
 
100
111
 
101
112
  def _vertical_tree_layout(
102
- tree,
103
113
  orientation="descending",
104
114
  **kwargs,
105
115
  ) -> dict[Hashable, list[float]]:
106
116
  """Vertical tree layout."""
107
- sign = 1 if orientation == "descending" else -1
108
- layout = _horizontal_tree_layout(tree, **kwargs)
117
+ sign = -1 if orientation == "descending" else 1
118
+ layout = _horizontal_tree_layout(**kwargs)
109
119
  for key, value in layout.items():
110
120
  # Invert x and y
111
121
  layout[key] = value[::-1]
@@ -114,24 +124,35 @@ def _vertical_tree_layout(
114
124
  return layout
115
125
 
116
126
 
117
- def _circular_tree_layout(
118
- tree,
119
- orientation="right",
120
- starting_angle=0,
121
- angular_span=360,
127
+ def _radial_tree_layout(
128
+ orientation: str = "right",
129
+ start: float = 180,
130
+ span: float = 360,
122
131
  **kwargs,
123
- ) -> dict[Hashable, list[float]]:
124
- """Circular tree layout."""
132
+ ) -> dict[Hashable, tuple[float, float]]:
133
+ """Radial tree layout.
134
+
135
+ Parameters:
136
+ orientation: Whether the layout fans out towards the right (clockwise) or left
137
+ (anticlockwise).
138
+ start: The starting angle in degrees, default is 180 (left).
139
+ span: The angular span in degrees, default is 360 (full circle). When this is
140
+ 360, it leaves a small gap at the end to ensure the first and last leaf
141
+ are not overlapping.
142
+ Returns:
143
+ A dictionary with the radial layout.
144
+ """
125
145
  # Short form
126
- th = starting_angle * np.pi / 180
127
- th_span = angular_span * np.pi / 180
128
- sign = 1 if orientation == "right" else -1
146
+ th = start * np.pi / 180
147
+ th_span = span * np.pi / 180
148
+ pad = int(span == 360)
149
+ sign = -1 if orientation in ("right", "clockwise") else 1
129
150
 
130
- layout = _horizontal_tree_layout_right(tree, **kwargs)
151
+ layout = _horizontal_tree_layout_right(**kwargs)
131
152
  ymax = max(point[1] for point in layout.values())
132
153
  for key, (x, y) in layout.items():
133
154
  r = x
134
- theta = sign * th_span * y / (ymax + 1) + th
155
+ theta = sign * th_span * y / (ymax + pad) + th
135
156
  # We export r and theta to ensure theta does not
136
157
  # modulo 2pi if we take the tan and then arctan later.
137
158
  layout[key] = (r, theta)