iplotx 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,112 +1,41 @@
1
1
  from typing import (
2
+ Any,
2
3
  Optional,
3
4
  Sequence,
4
5
  )
5
- from collections.abc import Hashable
6
- from operator import attrgetter
7
- import numpy as np
8
- import pandas as pd
9
-
10
- from ....typing import (
11
- TreeType,
12
- LayoutType,
13
- )
14
6
  from ...typing import (
15
7
  TreeDataProvider,
16
- TreeData,
17
- )
18
- from ...heuristics import (
19
- normalise_tree_layout,
20
8
  )
21
9
 
22
10
 
23
11
  class Cogent3DataProvider(TreeDataProvider):
24
- def __call__(
25
- self,
26
- tree: TreeType,
27
- layout: str | LayoutType,
28
- orientation: str = "horizontal",
29
- directed: bool | str = False,
30
- vertex_labels: Optional[
31
- Sequence[str] | dict[Hashable, str] | pd.Series | bool
32
- ] = None,
33
- edge_labels: Optional[Sequence[str] | dict] = None,
34
- ) -> TreeData:
35
- """Create tree data object for iplotx from cogent3.core.tree.PhyloNode classes."""
36
-
37
- root_fun = lambda tree: tree.root()
38
- preorder_fun = lambda tree: tree.preorder()
39
- postorder_fun = lambda tree: tree.postorder()
40
- children_fun = attrgetter("children")
41
- branch_length_fun = attrgetter("length")
42
- leaves_fun = lambda tree: tree.tips()
43
-
44
- tree_data = {
45
- "root": root_fun(tree),
46
- "leaves": leaves_fun(tree),
47
- "rooted": True,
48
- "directed": directed,
49
- "ndim": 2,
50
- "layout_name": layout,
51
- }
12
+ def preorder(self) -> Sequence[Any]:
13
+ return self.tree.preorder()
52
14
 
53
- # Add vertex_df including layout
54
- tree_data["vertex_df"] = normalise_tree_layout(
55
- layout,
56
- tree=tree,
57
- orientation=orientation,
58
- root_fun=root_fun,
59
- preorder_fun=preorder_fun,
60
- postorder_fun=postorder_fun,
61
- children_fun=children_fun,
62
- branch_length_fun=branch_length_fun,
63
- )
64
- if layout in ("radial",):
65
- tree_data["layout_coordinate_system"] = "polar"
66
- else:
67
- tree_data["layout_coordinate_system"] = "cartesian"
15
+ def postorder(self) -> Sequence[Any]:
16
+ return self.tree.postorder()
68
17
 
69
- # Add edge_df
70
- edge_data = {"_ipx_source": [], "_ipx_target": []}
71
- for node in preorder_fun(tree):
72
- for child in node.children:
73
- if directed == "parent":
74
- edge_data["_ipx_source"].append(child)
75
- edge_data["_ipx_target"].append(node)
76
- else:
77
- edge_data["_ipx_source"].append(node)
78
- edge_data["_ipx_target"].append(child)
79
- edge_df = pd.DataFrame(edge_data)
80
- tree_data["edge_df"] = edge_df
18
+ def get_leaves(self) -> Sequence[Any]:
19
+ return self.tree.tips()
81
20
 
82
- # Add vertex labels
83
- if vertex_labels is None:
84
- vertex_labels = False
85
- if np.isscalar(vertex_labels) and vertex_labels:
86
- tree_data["vertex_df"]["label"] = [
87
- x.name for x in tree_data["vertices"].index
88
- ]
89
- elif not np.isscalar(vertex_labels):
90
- # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
91
- # we fill the rest with empty strings which are not going to show up in the plot.
92
- if isinstance(vertex_labels, pd.Series):
93
- vertex_labels = dict(vertex_labels)
94
- if isinstance(vertex_labels, dict):
95
- for vertex in tree_data["vertex_df"].index:
96
- if vertex not in vertex_labels:
97
- vertex_labels[vertex] = ""
98
- tree_data["vertex_df"]["label"] = pd.Series(vertex_labels)
21
+ @staticmethod
22
+ def get_children(node: Any) -> Sequence[Any]:
23
+ return node.children
99
24
 
100
- return tree_data
25
+ @staticmethod
26
+ def get_branch_length(node: Any) -> Optional[float]:
27
+ return node.length
101
28
 
102
- def check_dependencies(self) -> bool:
29
+ @staticmethod
30
+ def check_dependencies() -> bool:
103
31
  try:
104
32
  import cogent3
105
33
  except ImportError:
106
34
  return False
107
35
  return True
108
36
 
109
- def tree_type(self):
37
+ @staticmethod
38
+ def tree_type():
110
39
  from cogent3.core.tree import PhyloNode
111
40
 
112
41
  return PhyloNode
@@ -1,112 +1,44 @@
1
1
  from typing import (
2
+ Any,
2
3
  Optional,
3
4
  Sequence,
4
5
  )
5
- from collections.abc import Hashable
6
- from operator import attrgetter
7
- import numpy as np
8
- import pandas as pd
6
+ from functools import partialmethod
9
7
 
10
- from ....typing import (
11
- TreeType,
12
- LayoutType,
13
- )
14
8
  from ...typing import (
15
9
  TreeDataProvider,
16
- TreeData,
17
- )
18
- from ...heuristics import (
19
- normalise_tree_layout,
20
10
  )
21
11
 
22
12
 
23
13
  class Ete4DataProvider(TreeDataProvider):
24
- def __call__(
25
- self,
26
- tree: TreeType,
27
- layout: str | LayoutType,
28
- orientation: str = "horizontal",
29
- directed: bool | str = False,
30
- vertex_labels: Optional[
31
- Sequence[str] | dict[Hashable, str] | pd.Series | bool
32
- ] = None,
33
- edge_labels: Optional[Sequence[str] | dict] = None,
34
- ) -> TreeData:
35
- """Create tree data object for iplotx from ete4.core.tre.Tree classes."""
36
-
37
- root_fun = attrgetter("root")
38
- preorder_fun = lambda tree: tree.traverse("preorder")
39
- postorder_fun = lambda tree: tree.traverse("postorder")
40
- children_fun = attrgetter("children")
41
- branch_length_fun = lambda node: node.dist if node.dist is not None else 1.0
42
- leaves_fun = lambda tree: tree.leaves()
43
-
44
- tree_data = {
45
- "root": tree.root,
46
- "leaves": leaves_fun(tree),
47
- "rooted": True,
48
- "directed": directed,
49
- "ndim": 2,
50
- "layout_name": layout,
51
- }
14
+ def _traverse(self, order: str) -> Any:
15
+ """Traverse the tree."""
16
+ return self.tree.traverse(order)
52
17
 
53
- # Add vertex_df including layout
54
- tree_data["vertex_df"] = normalise_tree_layout(
55
- layout,
56
- tree=tree,
57
- orientation=orientation,
58
- root_fun=root_fun,
59
- preorder_fun=preorder_fun,
60
- postorder_fun=postorder_fun,
61
- children_fun=children_fun,
62
- branch_length_fun=branch_length_fun,
63
- )
64
- if layout in ("radial",):
65
- tree_data["layout_coordinate_system"] = "polar"
66
- else:
67
- tree_data["layout_coordinate_system"] = "cartesian"
18
+ preorder = partialmethod(_traverse, order="preorder")
19
+ postorder = partialmethod(_traverse, order="postorder")
68
20
 
69
- # Add edge_df
70
- edge_data = {"_ipx_source": [], "_ipx_target": []}
71
- for node in preorder_fun(tree):
72
- for child in children_fun(node):
73
- if directed == "parent":
74
- edge_data["_ipx_source"].append(child)
75
- edge_data["_ipx_target"].append(node)
76
- else:
77
- edge_data["_ipx_source"].append(node)
78
- edge_data["_ipx_target"].append(child)
79
- edge_df = pd.DataFrame(edge_data)
80
- tree_data["edge_df"] = edge_df
21
+ def get_leaves(self) -> Sequence[Any]:
22
+ return self.tree.leaves()
81
23
 
82
- # Add vertex labels
83
- if vertex_labels is None:
84
- vertex_labels = False
85
- if np.isscalar(vertex_labels) and vertex_labels:
86
- tree_data["vertex_df"]["label"] = [
87
- x.name for x in tree_data["vertices"].index
88
- ]
89
- elif not np.isscalar(vertex_labels):
90
- # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
91
- # we fill the rest with empty strings which are not going to show up in the plot.
92
- if isinstance(vertex_labels, pd.Series):
93
- vertex_labels = dict(vertex_labels)
94
- if isinstance(vertex_labels, dict):
95
- for vertex in tree_data["vertex_df"].index:
96
- if vertex not in vertex_labels:
97
- vertex_labels[vertex] = ""
98
- tree_data["vertex_df"]["label"] = pd.Series(vertex_labels)
24
+ @staticmethod
25
+ def get_children(node: Any) -> Sequence[Any]:
26
+ return node.children
99
27
 
100
- return tree_data
28
+ @staticmethod
29
+ def get_branch_length(node: Any) -> Optional[float]:
30
+ return node.dist
101
31
 
102
- def check_dependencies(self) -> bool:
32
+ @staticmethod
33
+ def check_dependencies() -> bool:
103
34
  try:
104
35
  from ete4 import Tree
105
36
  except ImportError:
106
37
  return False
107
38
  return True
108
39
 
109
- def tree_type(self):
40
+ @staticmethod
41
+ def tree_type():
110
42
  from ete4 import Tree
111
43
 
112
44
  return Tree
@@ -1,112 +1,41 @@
1
1
  from typing import (
2
+ Any,
2
3
  Optional,
3
4
  Sequence,
4
5
  )
5
- from collections.abc import Hashable
6
- from operator import attrgetter
7
- import numpy as np
8
- import pandas as pd
9
-
10
- from ....typing import (
11
- TreeType,
12
- LayoutType,
13
- )
14
6
  from ...typing import (
15
7
  TreeDataProvider,
16
- TreeData,
17
- )
18
- from ...heuristics import (
19
- normalise_tree_layout,
20
8
  )
21
9
 
22
10
 
23
11
  class SkbioDataProvider(TreeDataProvider):
24
- def __call__(
25
- self,
26
- tree: TreeType,
27
- layout: str | LayoutType,
28
- orientation: str = "horizontal",
29
- directed: bool | str = False,
30
- vertex_labels: Optional[
31
- Sequence[str] | dict[Hashable, str] | pd.Series | bool
32
- ] = None,
33
- edge_labels: Optional[Sequence[str] | dict] = None,
34
- ) -> TreeData:
35
- """Create tree data object for iplotx from skbio.tree.TreeNode classes."""
36
-
37
- root_fun = lambda tree: tree.root()
38
- preorder_fun = lambda tree: tree.preorder()
39
- postorder_fun = lambda tree: tree.postorder()
40
- children_fun = attrgetter("children")
41
- branch_length_fun = attrgetter("length")
42
- leaves_fun = lambda tree: tree.tips()
43
-
44
- tree_data = {
45
- "root": root_fun(tree),
46
- "leaves": leaves_fun(tree),
47
- "rooted": True,
48
- "directed": directed,
49
- "ndim": 2,
50
- "layout_name": layout,
51
- }
12
+ def preorder(self) -> Sequence[Any]:
13
+ return self.tree.preorder()
52
14
 
53
- # Add vertex_df including layout
54
- tree_data["vertex_df"] = normalise_tree_layout(
55
- layout,
56
- tree=tree,
57
- orientation=orientation,
58
- root_fun=root_fun,
59
- preorder_fun=preorder_fun,
60
- postorder_fun=postorder_fun,
61
- children_fun=children_fun,
62
- branch_length_fun=branch_length_fun,
63
- )
64
- if layout in ("radial",):
65
- tree_data["layout_coordinate_system"] = "polar"
66
- else:
67
- tree_data["layout_coordinate_system"] = "cartesian"
15
+ def postorder(self) -> Sequence[Any]:
16
+ return self.tree.postorder()
68
17
 
69
- # Add edge_df
70
- edge_data = {"_ipx_source": [], "_ipx_target": []}
71
- for node in preorder_fun(tree):
72
- for child in children_fun(node):
73
- if directed == "parent":
74
- edge_data["_ipx_source"].append(child)
75
- edge_data["_ipx_target"].append(node)
76
- else:
77
- edge_data["_ipx_source"].append(node)
78
- edge_data["_ipx_target"].append(child)
79
- edge_df = pd.DataFrame(edge_data)
80
- tree_data["edge_df"] = edge_df
18
+ def get_leaves(self) -> Sequence[Any]:
19
+ return self.tree.tips()
81
20
 
82
- # Add vertex labels
83
- if vertex_labels is None:
84
- vertex_labels = False
85
- if np.isscalar(vertex_labels) and vertex_labels:
86
- tree_data["vertex_df"]["label"] = [
87
- x.name for x in tree_data["vertices"].index
88
- ]
89
- elif not np.isscalar(vertex_labels):
90
- # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
91
- # we fill the rest with empty strings which are not going to show up in the plot.
92
- if isinstance(vertex_labels, pd.Series):
93
- vertex_labels = dict(vertex_labels)
94
- if isinstance(vertex_labels, dict):
95
- for vertex in tree_data["vertex_df"].index:
96
- if vertex not in vertex_labels:
97
- vertex_labels[vertex] = ""
98
- tree_data["vertex_df"]["label"] = pd.Series(vertex_labels)
21
+ @staticmethod
22
+ def get_children(node: Any) -> Sequence[Any]:
23
+ return node.children
99
24
 
100
- return tree_data
25
+ @staticmethod
26
+ def get_branch_length(node: Any) -> Optional[float]:
27
+ return node.length
101
28
 
102
- def check_dependencies(self) -> bool:
29
+ @staticmethod
30
+ def check_dependencies() -> bool:
103
31
  try:
104
32
  from skbio import TreeNode
105
33
  except ImportError:
106
34
  return False
107
35
  return True
108
36
 
109
- def tree_type(self):
37
+ @staticmethod
38
+ def tree_type():
110
39
  from skbio import TreeNode
111
40
 
112
41
  return TreeNode
iplotx/ingest/typing.py CHANGED
@@ -11,14 +11,20 @@ from typing import (
11
11
  Protocol,
12
12
  Optional,
13
13
  Sequence,
14
+ Any,
15
+ Iterable,
14
16
  )
15
17
  from collections.abc import Hashable
18
+ import numpy as np
16
19
  import pandas as pd
17
20
  from ..typing import (
18
21
  GraphType,
19
22
  LayoutType,
20
23
  TreeType,
21
24
  )
25
+ from .heuristics import (
26
+ normalise_tree_layout,
27
+ )
22
28
 
23
29
 
24
30
  class NetworkData(TypedDict):
@@ -44,15 +50,13 @@ class NetworkDataProvider(Protocol):
44
50
  """Create network data object for iplotx from any provider."""
45
51
  raise NotImplementedError("Network data providers must implement this method.")
46
52
 
47
- def check_dependencies(
48
- self,
49
- ):
53
+ @staticmethod
54
+ def check_dependencies():
50
55
  """Check whether the dependencies for this provider are installed."""
51
56
  raise NotImplementedError("Network data providers must implement this method.")
52
57
 
53
- def graph_type(
54
- self,
55
- ):
58
+ @staticmethod
59
+ def graph_type():
56
60
  """Return the graph type from this provider to check for instances."""
57
61
  raise NotImplementedError("Network data providers must implement this method.")
58
62
 
@@ -63,11 +67,12 @@ class TreeData(TypedDict):
63
67
  rooted: bool
64
68
  directed: bool | str
65
69
  root: Optional[Hashable]
66
- leaves: list[Hashable]
70
+ leaf_df: pd.DataFrame
67
71
  vertex_df: dict[Hashable, tuple[float, float]]
68
72
  edge_df: dict[Hashable, Sequence[tuple[float, float]]]
69
73
  layout_coordinate_system: str
70
74
  layout_name: str
75
+ orientation: str
71
76
  ndim: int
72
77
  tree_library: NotRequired[str]
73
78
 
@@ -75,26 +80,224 @@ class TreeData(TypedDict):
75
80
  class TreeDataProvider(Protocol):
76
81
  """Protocol for tree data ingestion provider for iplotx."""
77
82
 
78
- def __call__(
83
+ def __init__(
79
84
  self,
80
85
  tree: TreeType,
81
- layout: str | LayoutType,
82
- orientation: Optional[str] = None,
83
- directed: bool | str = False,
84
- vertex_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
85
- edge_labels: Optional[Sequence[str] | dict] = None,
86
- ) -> TreeData:
87
- """Create tree data object for iplotx from any provider."""
88
- raise NotImplementedError("Tree data providers must implement this method.")
86
+ ) -> None:
87
+ """Initialize the provider with the tree type.
89
88
 
90
- def check_dependencies(
91
- self,
92
- ):
89
+ Parameters:
90
+ tree: The tree type that this provider will handle.
91
+ """
92
+ self.tree = tree
93
+
94
+ @staticmethod
95
+ def check_dependencies():
93
96
  """Check whether the dependencies for this provider are installed."""
94
97
  raise NotImplementedError("Tree data providers must implement this method.")
95
98
 
96
- def tree_type(
97
- self,
98
- ):
99
+ @staticmethod
100
+ def tree_type():
99
101
  """Return the tree type from this provider to check for instances."""
100
102
  raise NotImplementedError("Tree data providers must implement this method.")
103
+
104
+ def is_rooted(self) -> bool:
105
+ """Get whether the tree is rooted.
106
+
107
+ Returns:
108
+ A boolean indicating whether the tree is rooted.
109
+
110
+ Note: This is a default implemntation that can be overridden by the provider
111
+ if they support unrooted trees (e.g. Biopython).
112
+ """
113
+ return True
114
+
115
+ def get_root(self) -> Any:
116
+ """Get the tree root in a provider-specific data structure.
117
+
118
+ Returns:
119
+ The root of the tree.
120
+
121
+ Note: This is a default implemntation that can be overridden by the provider.
122
+ """
123
+ root_attr = self.tree.root
124
+ if callable(root_attr):
125
+ return root_attr()
126
+ else:
127
+ return root_attr
128
+
129
+ def get_leaves(self) -> Sequence[Any]:
130
+ """Get the tree leaves/tips in a provider-specific data structure.
131
+
132
+ Returns:
133
+ The leaves or tips of the tree.
134
+ """
135
+ raise NotImplementedError("Tree data providers must implement this method.")
136
+
137
+ def preorder(self) -> Iterable[Any]:
138
+ """Preorder (DFS - parent first) iteration over the tree.
139
+
140
+ Returns:
141
+ An iterable of nodes in preorder traversal.
142
+ """
143
+ raise NotImplementedError("Tree data providers must implement this method.")
144
+
145
+ def postorder(self) -> Iterable[Any]:
146
+ """Postorder (DFS - child first) iteration over the tree.
147
+
148
+ Returns:
149
+ An iterable of nodes in preorder traversal.
150
+ """
151
+ raise NotImplementedError("Tree data providers must implement this method.")
152
+
153
+ @staticmethod
154
+ def get_children(
155
+ node: Any,
156
+ ) -> Sequence[Any]:
157
+ """Get the children of a node.
158
+
159
+ Parameters:
160
+ node: The node to get the children from.
161
+ Returns:
162
+ A sequence of children nodes.
163
+ """
164
+ raise NotImplementedError("Tree data providers must implement this method.")
165
+
166
+ @staticmethod
167
+ def get_branch_length(
168
+ node: Any,
169
+ ) -> Optional[float]:
170
+ """Get the length of the branch to this node.
171
+
172
+ Parameters:
173
+ node: The node to get the branch length from.
174
+ Returns:
175
+ The branch length to the node.
176
+ """
177
+ raise NotImplementedError("Tree data providers must implement this method.")
178
+
179
+ def get_branch_length_default_to_one(
180
+ self,
181
+ node: Any,
182
+ ) -> float:
183
+ """Get the length of the branch to this node, defaulting to 1.0 if not available.
184
+
185
+ Parameters:
186
+ node: The node to get the branch length from.
187
+ Returns:
188
+ The branch length to the node, defaulting to 1.0 if not available.
189
+ """
190
+ branch_length = self.get_branch_length(node)
191
+ return branch_length if branch_length is not None else 1.0
192
+
193
+ def __call__(
194
+ self,
195
+ layout: str | LayoutType,
196
+ orientation: Optional[str],
197
+ layout_style: Optional[dict[str, int | float | str]] = None,
198
+ directed: bool | str = False,
199
+ vertex_labels: Optional[
200
+ Sequence[str] | dict[Hashable, str] | pd.Series | bool
201
+ ] = None,
202
+ edge_labels: Optional[Sequence[str] | dict] = None,
203
+ leaf_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
204
+ ) -> TreeData:
205
+ """Create tree data object for iplotx from ete4.core.tre.Tree classes."""
206
+
207
+ if layout_style is None:
208
+ layout_style = {}
209
+
210
+ if orientation is None:
211
+ if layout == "horizontal":
212
+ orientation = "right"
213
+ elif layout == "vertical":
214
+ orientation = "descending"
215
+ elif layout == "radial":
216
+ orientation = "clockwise"
217
+
218
+ tree_data = {
219
+ "root": self.get_root(),
220
+ "rooted": self.is_rooted(),
221
+ "directed": directed,
222
+ "ndim": 2,
223
+ "layout_name": layout,
224
+ "orientation": orientation,
225
+ }
226
+
227
+ # Add vertex_df including layout
228
+ tree_data["vertex_df"] = normalise_tree_layout(
229
+ layout,
230
+ orientation=orientation,
231
+ root=tree_data["root"],
232
+ preorder_fun=self.preorder,
233
+ postorder_fun=self.postorder,
234
+ children_fun=self.get_children,
235
+ branch_length_fun=self.get_branch_length_default_to_one,
236
+ **layout_style,
237
+ )
238
+ if layout in ("radial",):
239
+ tree_data["layout_coordinate_system"] = "polar"
240
+ else:
241
+ tree_data["layout_coordinate_system"] = "cartesian"
242
+
243
+ # Add edge_df
244
+ edge_data = {"_ipx_source": [], "_ipx_target": []}
245
+ for node in self.preorder():
246
+ for child in self.get_children(node):
247
+ if directed == "parent":
248
+ edge_data["_ipx_source"].append(child)
249
+ edge_data["_ipx_target"].append(node)
250
+ else:
251
+ edge_data["_ipx_source"].append(node)
252
+ edge_data["_ipx_target"].append(child)
253
+ edge_df = pd.DataFrame(edge_data)
254
+ tree_data["edge_df"] = edge_df
255
+
256
+ # Add leaf_df
257
+ tree_data["leaf_df"] = pd.DataFrame(index=self.get_leaves())
258
+
259
+ # Add vertex labels
260
+ if vertex_labels is None:
261
+ vertex_labels = False
262
+ if np.isscalar(vertex_labels) and vertex_labels:
263
+ tree_data["vertex_df"]["label"] = [
264
+ x.name for x in tree_data["vertex_df"].index
265
+ ]
266
+ elif not np.isscalar(vertex_labels):
267
+ # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
268
+ # we fill the rest with empty strings which are not going to show up in the plot.
269
+ if isinstance(vertex_labels, pd.Series):
270
+ vertex_labels = dict(vertex_labels)
271
+ if isinstance(vertex_labels, dict):
272
+ for vertex in tree_data["vertex_df"].index:
273
+ if vertex not in vertex_labels:
274
+ vertex_labels[vertex] = ""
275
+ tree_data["vertex_df"]["label"] = pd.Series(vertex_labels)
276
+
277
+ # Add leaf labels
278
+ if leaf_labels is None:
279
+ leaf_labels = False
280
+ if np.isscalar(leaf_labels) and leaf_labels:
281
+ tree_data["leaf_labels"]["label"] = [
282
+ # FIXME: this is likely broken
283
+ x.name
284
+ for x in tree_data["leaf_df"].index
285
+ ]
286
+ elif not np.isscalar(leaf_labels):
287
+ # Leaves are already in the dataframe in a certain order, so sequences are allowed
288
+ if isinstance(leaf_labels, (list, tuple, np.ndarray)):
289
+ leaf_labels = {
290
+ leaf: label
291
+ for leaf, label in zip(tree_data["leaf_df"].index, leaf_labels)
292
+ }
293
+ # If a dict-like object is passed, it can be incomplete (e.g. only the leaves):
294
+ # we fill the rest with empty strings which are not going to show up in the plot.
295
+ if isinstance(leaf_labels, pd.Series):
296
+ leaf_labels = dict(leaf_labels)
297
+ if isinstance(leaf_labels, dict):
298
+ for leaf in tree_data["leaf_df"].index:
299
+ if leaf not in leaf_labels:
300
+ leaf_labels[leaf] = ""
301
+ tree_data["leaf_df"]["label"] = pd.Series(leaf_labels)
302
+
303
+ return tree_data