iplotx 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iplotx/edge/leaf.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Module containing leaf edges, i.e. special edges of tree visualisations
3
+ that connect leaf vertices to the deepest leaf (typically for labeling).
4
+ """
5
+
6
+ from typing import (
7
+ Sequence,
8
+ Optional,
9
+ Any,
10
+ )
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib as mpl
14
+
15
+ from ..utils.matplotlib import (
16
+ _forwarder,
17
+ )
18
+ from ..vertex import VertexCollection
19
+ from iplotx.edge import EdgeCollection
20
+
21
+
22
+ @_forwarder(
23
+ (
24
+ "set_clip_path",
25
+ "set_clip_box",
26
+ "set_snap",
27
+ "set_sketch_params",
28
+ "set_animated",
29
+ "set_picker",
30
+ )
31
+ )
32
+ class LeafEdgeCollection(EdgeCollection):
33
+ """Artist for leaf edges in tree visualisations."""
34
+
35
+ def __init__(
36
+ self,
37
+ patches: Sequence[mpl.patches.Patch],
38
+ vertex_leaf_ids: Sequence[tuple],
39
+ vertex_collection: VertexCollection,
40
+ leaf_collection: VertexCollection,
41
+ *args,
42
+ transform: mpl.transforms.Transform = mpl.transforms.IdentityTransform(),
43
+ arrow_transform: mpl.transforms.Transform = mpl.transforms.IdentityTransform(),
44
+ directed: bool = False,
45
+ style: Optional[dict[str, Any]] = None,
46
+ **kwargs,
47
+ ) -> None:
48
+ """Initialise a LeafEdgeCollection.
49
+
50
+ Parameters:
51
+ patches: A sequence (usually, list) of matplotlib `Patch`es describing the edges.
52
+ vertex_leaf_ids: A sequence of pairs `(v1, v2)`, each defining the ids of vertices at the
53
+ end of an edge.
54
+ vertex_collection: The VertexCollection instance containing the Artist for the
55
+ vertices. This is needed to compute vertex borders and adjust edges accordingly.
56
+ transform: The matplotlib transform for the edges, usually transData.
57
+ arrow_transform: The matplotlib transform for the arrow patches. This is not the
58
+ *offset_transform* of arrows, which is set equal to the edge transform (previous
59
+ parameter). Instead, it specifies how arrow size scales, similar to vertex size.
60
+ This is usually the identity transform.
61
+ directed: Whether the graph is directed (in which case arrows are drawn, possibly
62
+ with zero size or opacity to obtain an "arrowless" effect).
63
+ style: The edge style (subdictionary: "edge") to use at creation.
64
+ """
65
+ self._leaf_collection = leaf_collection
66
+ super().__init__(
67
+ patches=patches,
68
+ vertex_ids=vertex_leaf_ids,
69
+ vertex_collection=vertex_collection,
70
+ *args,
71
+ transform=transform,
72
+ arrow_transform=arrow_transform,
73
+ directed=directed,
74
+ style=style,
75
+ **kwargs,
76
+ )
77
+
78
+ def _get_adjacent_vertices_info(self):
79
+ lindex = self._leaf_collection.get_index()
80
+ lindex = pd.Series(
81
+ np.arange(len(lindex)),
82
+ index=lindex,
83
+ )
84
+ vindex = self._vertex_collection.get_index()
85
+ vindex = pd.Series(
86
+ np.arange(len(vindex)),
87
+ index=vindex,
88
+ ).loc[lindex.index]
89
+
90
+ voffsets = []
91
+ vpaths = []
92
+ vsizes = []
93
+ for vid in self._vertex_ids:
94
+ # NOTE: these are in the original layout coordinate system
95
+ # not cartesianised yet.
96
+ offset1 = self._vertex_collection.get_layout().values[vindex[vid]]
97
+ offset2 = self._leaf_collection.get_layout().values[lindex[vid]]
98
+ voffsets.append((offset1, offset2))
99
+
100
+ path1 = self._vertex_collection.get_paths()[vindex[vid]]
101
+ path2 = self._leaf_collection.get_paths()[lindex[vid]]
102
+ vpaths.append((path1, path2))
103
+
104
+ # NOTE: This needs to be computed here because the
105
+ # VertexCollection._transforms are reset each draw in order to
106
+ # accommodate DPI changes on the canvas
107
+ size1 = self._vertex_collection.get_sizes_dpi()[vindex[vid]]
108
+ size2 = self._leaf_collection.get_sizes_dpi()[lindex[vid]]
109
+ vsizes.append((size1, size2))
110
+
111
+ return {
112
+ "ids": [(vid, vid) for vid in self._vertex_ids],
113
+ "offsets": voffsets,
114
+ "paths": vpaths,
115
+ "sizes": vsizes,
116
+ "loops": False,
117
+ }
iplotx/edge/ports.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Module for handling edge ports in iplotx.
3
3
  """
4
4
 
5
+ from collections.abc import Callable
5
6
  import numpy as np
6
7
 
7
8
  sq2 = np.sqrt(2) / 2
@@ -19,8 +20,8 @@ port_dict = {
19
20
 
20
21
 
21
22
  def _get_port_unit_vector(
22
- portstring,
23
- trans_inv,
23
+ portstring: str,
24
+ trans_inv: Callable,
24
25
  ):
25
26
  """Get the tangent unit vector from a port string."""
26
27
  # The only tricky bit is if the port says e.g. north but the y axis is inverted, in which
iplotx/groups.py CHANGED
@@ -130,14 +130,12 @@ class GroupingArtist(PatchCollection):
130
130
  return patches, grouping, coords_hulls
131
131
 
132
132
  def _compute_paths(self, dpi: float = 72.0) -> None:
133
- ppc = self._points_per_curve
134
133
  for i, hull in enumerate(self._coords_hulls):
135
- self._paths[i].vertices = _compute_group_path_with_vertex_padding(
134
+ _compute_group_path_with_vertex_padding(
136
135
  hull,
137
136
  self._paths[i].vertices,
138
137
  self.get_transform(),
139
138
  vertexpadding=self.get_vertexpadding_dpi(dpi),
140
- points_per_curve=ppc,
141
139
  )
142
140
 
143
141
  def _process(self) -> None:
iplotx/ingest/__init__.py CHANGED
@@ -32,15 +32,11 @@ provider_protocols = {
32
32
  }
33
33
 
34
34
  # Internally supported data providers
35
- data_providers: dict[str, dict[str, Protocol]] = {
36
- kind: {} for kind in provider_protocols
37
- }
35
+ data_providers: dict[str, dict[str, Protocol]] = {kind: {} for kind in provider_protocols}
38
36
  for kind in data_providers:
39
37
  providers_path = pathlib.Path(__file__).parent.joinpath("providers").joinpath(kind)
40
38
  for importer, module_name, _ in pkgutil.iter_modules([providers_path]):
41
- module = importlib.import_module(
42
- f"iplotx.ingest.providers.{kind}.{module_name}"
43
- )
39
+ module = importlib.import_module(f"iplotx.ingest.providers.{kind}.{module_name}")
44
40
  for key, val in module.__dict__.items():
45
41
  if key == provider_protocols[kind].__name__:
46
42
  continue
@@ -95,8 +91,7 @@ def ingest_network_data(
95
91
  f"Currently installed supported libraries: {sup}."
96
92
  )
97
93
 
98
- result = provider()(
99
- network=network,
94
+ result = provider(network)(
100
95
  layout=layout,
101
96
  vertex_labels=vertex_labels,
102
97
  edge_labels=edge_labels,
@@ -108,7 +103,6 @@ def ingest_network_data(
108
103
  def ingest_tree_data(
109
104
  tree: TreeType,
110
105
  layout: Optional[str] = "horizontal",
111
- orientation: Optional[str] = None,
112
106
  directed: bool | str = False,
113
107
  layout_style: Optional[dict[str, str | int | float]] = None,
114
108
  vertex_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
@@ -125,15 +119,13 @@ def ingest_tree_data(
125
119
  else:
126
120
  sup = ", ".join(data_providers["tree"].keys())
127
121
  raise ValueError(
128
- f"Tree library '{tl}' is not installed. "
129
- f"Currently installed supported libraries: {sup}."
122
+ f"Tree library '{tl}' is not installed. Currently installed supported libraries: {sup}."
130
123
  )
131
124
 
132
125
  result = provider(
133
126
  tree=tree,
134
127
  )(
135
128
  layout=layout,
136
- orientation=orientation,
137
129
  directed=directed,
138
130
  layout_style=layout_style,
139
131
  vertex_labels=vertex_labels,
@@ -142,22 +134,16 @@ def ingest_tree_data(
142
134
  )
143
135
  result["tree_library"] = tl
144
136
 
145
- # TODO: cascading thing here
146
-
147
137
  return result
148
138
 
149
139
 
150
140
  # INTERNAL FUNCTIONS
151
141
  def _update_data_providers(kind):
152
142
  """Update data providers dynamically from external packages."""
153
- discovered_providers = importlib.metadata.entry_points(
154
- group=f"iplotx.{kind}_data_providers"
155
- )
143
+ discovered_providers = importlib.metadata.entry_points(group=f"iplotx.{kind}_data_providers")
156
144
  for entry_point in discovered_providers:
157
145
  if entry_point.name not in data_providers["network"]:
158
146
  try:
159
147
  data_providers[kind][entry_point.name] = entry_point.load()
160
148
  except Exception as e:
161
- warnings.warn(
162
- f"Failed to load {kind} data provider '{entry_point.name}': {e}"
163
- )
149
+ warnings.warn(f"Failed to load {kind} data provider '{entry_point.name}': {e}")
@@ -3,7 +3,6 @@ Heuristics module to funnel certain variable inputs (e.g. layouts) into a standa
3
3
  """
4
4
 
5
5
  from typing import (
6
- Optional,
7
6
  Any,
8
7
  )
9
8
  from collections.abc import Hashable
@@ -13,42 +12,13 @@ import pandas as pd
13
12
 
14
13
  from ..layout import compute_tree_layout
15
14
  from ..typing import (
16
- GraphType,
17
15
  GroupingType,
18
16
  LayoutType,
19
17
  )
20
18
 
21
19
 
22
- def number_of_vertices(network: GraphType) -> int:
23
- """Get the number of vertices in the network."""
24
- from . import network_library
25
-
26
- if network_library(network) == "igraph":
27
- return network.vcount()
28
- if network_library(network) == "networkx":
29
- return network.number_of_nodes()
30
- raise TypeError("Unsupported graph type. Supported types are igraph and networkx.")
31
-
32
-
33
- def detect_directedness(
34
- network: GraphType,
35
- ) -> bool:
36
- """Detect if the network is directed or not."""
37
- from . import network_library
38
-
39
- nl = network_library(network)
40
-
41
- if nl == "igraph":
42
- return network.is_directed()
43
- if nl == "networkx":
44
- import networkx as nx
45
-
46
- if isinstance(network, (nx.DiGraph, nx.MultiDiGraph)):
47
- return True
48
- return False
49
-
50
-
51
- def normalise_layout(layout, network=None):
20
+ # TODO: some of this logic should be moved into individual providers
21
+ def normalise_layout(layout, network=None, nvertices=None):
52
22
  """Normalise the layout to a pandas.DataFrame."""
53
23
  from . import network_library
54
24
 
@@ -58,7 +28,7 @@ def normalise_layout(layout, network=None):
58
28
  ig = None
59
29
 
60
30
  if layout is None:
61
- if (network is not None) and (number_of_vertices(network) == 0):
31
+ if (network is not None) and (nvertices == 0):
62
32
  return pd.DataFrame(np.zeros((0, 2)))
63
33
  return None
64
34
  if (network is not None) and isinstance(layout, str):
@@ -108,9 +78,7 @@ def normalise_tree_layout(
108
78
  if isinstance(layout, str):
109
79
  layout = compute_tree_layout(layout, **kwargs)
110
80
  else:
111
- raise NotImplementedError(
112
- "Only internally computed tree layout currently accepted."
113
- )
81
+ raise NotImplementedError("Only internally computed tree layout currently accepted.")
114
82
 
115
83
  if isinstance(layout, dict):
116
84
  # Adjust vertex layout
@@ -3,16 +3,15 @@ from typing import (
3
3
  Sequence,
4
4
  )
5
5
  from collections.abc import Hashable
6
+ import importlib
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
 
9
10
  from ....typing import (
10
- GraphType,
11
11
  LayoutType,
12
12
  )
13
13
  from ...heuristics import (
14
14
  normalise_layout,
15
- detect_directedness,
16
15
  )
17
16
  from ...typing import (
18
17
  NetworkDataProvider,
@@ -26,21 +25,24 @@ from ....utils.internal import (
26
25
  class IGraphDataProvider(NetworkDataProvider):
27
26
  def __call__(
28
27
  self,
29
- network: GraphType,
30
28
  layout: Optional[LayoutType] = None,
31
29
  vertex_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
32
30
  edge_labels: Optional[Sequence[str] | dict[str]] = None,
33
31
  ) -> NetworkData:
34
- """Create network data object for iplotx from any provider."""
35
-
36
- directed = detect_directedness(network)
32
+ """Create network data object for iplotx from an igraph object."""
33
+ network = self.network
34
+ directed = self.is_directed()
37
35
 
38
36
  # Recast vertex_labels=False as vertex_labels=None
39
37
  if np.isscalar(vertex_labels) and (not vertex_labels):
40
38
  vertex_labels = None
41
39
 
42
40
  # Vertices are ordered integers, no gaps
43
- vertex_df = normalise_layout(layout, network=network)
41
+ vertex_df = normalise_layout(
42
+ layout,
43
+ network=network,
44
+ nvertices=self.number_of_vertices(),
45
+ )
44
46
  ndim = vertex_df.shape[1]
45
47
  vertex_df.columns = _make_layout_columns(ndim)
46
48
 
@@ -49,9 +51,7 @@ class IGraphDataProvider(NetworkDataProvider):
49
51
  if np.isscalar(vertex_labels):
50
52
  vertex_df["label"] = vertex_df.index.astype(str)
51
53
  elif len(vertex_labels) != len(vertex_df):
52
- raise ValueError(
53
- "Vertex labels must be the same length as the number of vertices."
54
- )
54
+ raise ValueError("Vertex labels must be the same length as the number of vertices.")
55
55
  else:
56
56
  vertex_df["label"] = vertex_labels
57
57
 
@@ -70,9 +70,7 @@ class IGraphDataProvider(NetworkDataProvider):
70
70
  # Edge labels
71
71
  if edge_labels is not None:
72
72
  if len(edge_labels) != len(edge_df):
73
- raise ValueError(
74
- "Edge labels must be the same length as the number of edges."
75
- )
73
+ raise ValueError("Edge labels must be the same length as the number of edges.")
76
74
  edge_df["label"] = edge_labels
77
75
 
78
76
  network_data = {
@@ -85,14 +83,18 @@ class IGraphDataProvider(NetworkDataProvider):
85
83
 
86
84
  @staticmethod
87
85
  def check_dependencies() -> bool:
88
- try:
89
- import igraph
90
- except ImportError:
91
- return False
92
- return True
86
+ return importlib.util.find_spec("igraph") is not None
93
87
 
94
88
  @staticmethod
95
89
  def graph_type():
96
90
  import igraph as ig
97
91
 
98
92
  return ig.Graph
93
+
94
+ def is_directed(self):
95
+ """Whether the network is directed."""
96
+ return self.network.is_directed()
97
+
98
+ def number_of_vertices(self):
99
+ """The number of vertices/nodes in the network."""
100
+ return self.network.vcount()
@@ -3,16 +3,15 @@ from typing import (
3
3
  Sequence,
4
4
  )
5
5
  from collections.abc import Hashable
6
+ import importlib
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
 
9
10
  from ....typing import (
10
- GraphType,
11
11
  LayoutType,
12
12
  )
13
13
  from ...heuristics import (
14
14
  normalise_layout,
15
- detect_directedness,
16
15
  )
17
16
  from ...typing import (
18
17
  NetworkDataProvider,
@@ -26,16 +25,17 @@ from ....utils.internal import (
26
25
  class NetworkXDataProvider(NetworkDataProvider):
27
26
  def __call__(
28
27
  self,
29
- network: GraphType,
30
28
  layout: Optional[LayoutType] = None,
31
29
  vertex_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
32
30
  edge_labels: Optional[Sequence[str] | dict[str]] = None,
33
31
  ) -> NetworkData:
34
- """Create network data object for iplotx from any provider."""
32
+ """Create network data object for iplotx from a networkx object."""
35
33
 
36
34
  import networkx as nx
37
35
 
38
- directed = detect_directedness(network)
36
+ network = self.network
37
+
38
+ directed = self.is_directed()
39
39
 
40
40
  # Recast vertex_labels=False as vertex_labels=None
41
41
  if np.isscalar(vertex_labels) and (not vertex_labels):
@@ -45,6 +45,7 @@ class NetworkXDataProvider(NetworkDataProvider):
45
45
  vertex_df = normalise_layout(
46
46
  layout,
47
47
  network=network,
48
+ nvertices=self.number_of_vertices(),
48
49
  ).loc[pd.Index(network.nodes)]
49
50
  ndim = vertex_df.shape[1]
50
51
  vertex_df.columns = _make_layout_columns(ndim)
@@ -63,21 +64,13 @@ class NetworkXDataProvider(NetworkDataProvider):
63
64
  if "label" in vertex_df:
64
65
  del vertex_df["label"]
65
66
  else:
66
- if (
67
- np.isscalar(vertex_labels)
68
- and (not vertex_labels)
69
- and ("label" in vertex_df)
70
- ):
67
+ if np.isscalar(vertex_labels) and (not vertex_labels) and ("label" in vertex_df):
71
68
  del vertex_df["label"]
72
69
  elif vertex_labels is True:
73
70
  if "label" not in vertex_df:
74
71
  vertex_df["label"] = vertex_df.index
75
- elif (not np.isscalar(vertex_labels)) and (
76
- len(vertex_labels) != len(vertex_df)
77
- ):
78
- raise ValueError(
79
- "Vertex labels must be the same length as the number of vertices."
80
- )
72
+ elif (not np.isscalar(vertex_labels)) and (len(vertex_labels) != len(vertex_df)):
73
+ raise ValueError("Vertex labels must be the same length as the number of vertices.")
81
74
  elif isinstance(vertex_labels, nx.classes.reportviews.NodeDataView):
82
75
  vertex_df["label"] = pd.Series(dict(vertex_labels))
83
76
  else:
@@ -107,9 +100,7 @@ class NetworkXDataProvider(NetworkDataProvider):
107
100
  edge_df["label"] = [str(i) for i in edge_df.index]
108
101
  else:
109
102
  if len(edge_labels) != len(edge_df):
110
- raise ValueError(
111
- "Edge labels must be the same length as the number of edges."
112
- )
103
+ raise ValueError("Edge labels must be the same length as the number of edges.")
113
104
  edge_df["label"] = edge_labels
114
105
 
115
106
  network_data = {
@@ -122,14 +113,19 @@ class NetworkXDataProvider(NetworkDataProvider):
122
113
 
123
114
  @staticmethod
124
115
  def check_dependencies() -> bool:
125
- try:
126
- import networkx
127
- except ImportError:
128
- return False
129
- return True
116
+ return importlib.util.find_spec("networkx") is not None
130
117
 
131
118
  @staticmethod
132
119
  def graph_type():
133
120
  from networkx import Graph
134
121
 
135
122
  return Graph
123
+
124
+ def is_directed(self):
125
+ import networkx as nx
126
+
127
+ return isinstance(self.network, (nx.DiGraph, nx.MultiDiGraph))
128
+
129
+ def number_of_vertices(self):
130
+ """The number of vertices/nodes in the network."""
131
+ return self.network.number_of_nodes()
@@ -0,0 +1,114 @@
1
+ from typing import (
2
+ Optional,
3
+ Sequence,
4
+ )
5
+ from collections.abc import Hashable
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from ....typing import (
10
+ LayoutType,
11
+ )
12
+ from ...typing import (
13
+ NetworkDataProvider,
14
+ NetworkData,
15
+ )
16
+ from ....utils.internal import (
17
+ _make_layout_columns,
18
+ )
19
+
20
+
21
+ class SimpleNetworkDataProvider(NetworkDataProvider):
22
+ def __call__(
23
+ self,
24
+ layout: Optional[LayoutType] = None,
25
+ vertex_labels: Optional[Sequence[str] | dict[Hashable, str] | pd.Series] = None,
26
+ edge_labels: Optional[Sequence[str] | dict[str]] = None,
27
+ ) -> NetworkData:
28
+ """Create network data object for iplotx from a simple Python object."""
29
+ network = self.network
30
+ directed = self.is_directed()
31
+
32
+ # Recast vertex_labels=False as vertex_labels=None
33
+ if np.isscalar(vertex_labels) and (not vertex_labels):
34
+ vertex_labels = None
35
+
36
+ # Vertices come from the "nodes"/"vertices" key if present, else they are inferred from edges
37
+ for key in ["nodes", "vertices"]:
38
+ if key in network:
39
+ vertices = network[key]
40
+ break
41
+ else:
42
+ # Infer from edge adjacent vertices, singletons will be missed
43
+ vertices = set()
44
+ for edge in self.network.get("edges", []):
45
+ vertices.add(edge[0])
46
+ vertices.add(edge[1])
47
+ vertices = list(vertices)
48
+
49
+ # NOTE: This is underpowered, but it's ok for a simple educational provider
50
+ if isinstance(layout, pd.DataFrame):
51
+ vertex_df = layout.loc[vertices].copy()
52
+ elif isinstance(layout, dict):
53
+ vertex_df = pd.DataFrame(layout).T.loc[vertices]
54
+ else:
55
+ vertex_df = pd.DataFrame(
56
+ index=vertices,
57
+ data=layout,
58
+ )
59
+ ndim = vertex_df.shape[1]
60
+ vertex_df.columns = _make_layout_columns(ndim)
61
+
62
+ # Vertex labels
63
+ if vertex_labels is not None:
64
+ if np.isscalar(vertex_labels):
65
+ vertex_df["label"] = vertex_df.index.astype(str)
66
+ elif len(vertex_labels) != len(vertex_df):
67
+ raise ValueError("Vertex labels must be the same length as the number of vertices.")
68
+ else:
69
+ vertex_df["label"] = vertex_labels
70
+
71
+ # Edges are a list of tuples, because of multiedges
72
+ tmp = []
73
+ for edge in network.get("edges", []):
74
+ row = {"_ipx_source": edge[0], "_ipx_target": edge[1]}
75
+ tmp.append(row)
76
+ if len(tmp):
77
+ edge_df = pd.DataFrame(tmp)
78
+ else:
79
+ edge_df = pd.DataFrame(columns=["_ipx_source", "_ipx_target"])
80
+ del tmp
81
+
82
+ network_data = {
83
+ "vertex_df": vertex_df,
84
+ "edge_df": edge_df,
85
+ "directed": directed,
86
+ "ndim": ndim,
87
+ }
88
+ return network_data
89
+
90
+ @staticmethod
91
+ def check_dependencies() -> bool:
92
+ """Check dependencies. Returns True since this provider has no dependencies."""
93
+ return True
94
+
95
+ @staticmethod
96
+ def graph_type():
97
+ return dict
98
+
99
+ def is_directed(self):
100
+ """Whether the network is directed."""
101
+ return self.network.get("directed", False)
102
+
103
+ def number_of_vertices(self):
104
+ """The number of vertices/nodes in the network."""
105
+ for key in ("nodes", "vertices"):
106
+ if key in self.network:
107
+ return len(self.network[key])
108
+
109
+ # Default to unique edge adjacent nodes (this will ignore singletons)
110
+ nodes = set()
111
+ for edge in self.network.get("edges", []):
112
+ nodes.add(edge[0])
113
+ nodes.add(edge[1])
114
+ return len(nodes)
@@ -3,6 +3,7 @@ from typing import (
3
3
  Optional,
4
4
  Sequence,
5
5
  )
6
+ import importlib
6
7
  from functools import partialmethod
7
8
 
8
9
  from ...typing import (
@@ -34,14 +35,23 @@ class BiopythonDataProvider(TreeDataProvider):
34
35
 
35
36
  @staticmethod
36
37
  def check_dependencies() -> bool:
37
- try:
38
- from Bio import Phylo
39
- except ImportError:
40
- return False
41
- return True
38
+ return importlib.util.find_spec("Bio") is not None
42
39
 
43
40
  @staticmethod
44
41
  def tree_type():
45
42
  from Bio import Phylo
46
43
 
47
44
  return Phylo.BaseTree.Tree
45
+
46
+ def get_support(self):
47
+ """Get support/confidence values for all nodes."""
48
+ support_dict = {}
49
+ for node in self.preorder():
50
+ if hasattr(node, "confidences"):
51
+ support = node.confidences
52
+ elif hasattr(node, "confidence"):
53
+ support = node.confidence
54
+ else:
55
+ support = None
56
+ support_dict[node] = support
57
+ return support_dict
@@ -3,6 +3,7 @@ from typing import (
3
3
  Optional,
4
4
  Sequence,
5
5
  )
6
+ import importlib
6
7
  from ...typing import (
7
8
  TreeDataProvider,
8
9
  )
@@ -28,14 +29,17 @@ class Cogent3DataProvider(TreeDataProvider):
28
29
 
29
30
  @staticmethod
30
31
  def check_dependencies() -> bool:
31
- try:
32
- import cogent3
33
- except ImportError:
34
- return False
35
- return True
32
+ return importlib.util.find_spec("cogent3") is not None
36
33
 
37
34
  @staticmethod
38
35
  def tree_type():
39
36
  from cogent3.core.tree import PhyloNode
40
37
 
41
38
  return PhyloNode
39
+
40
+ def get_support(self):
41
+ """Get support values for all nodes."""
42
+ support_dict = {}
43
+ for node in self.preorder():
44
+ support_dict[node] = node.params.get("support", None)
45
+ return support_dict