iplotx 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
iplotx/cascades.py CHANGED
@@ -2,6 +2,7 @@ from typing import (
2
2
  Any,
3
3
  Optional,
4
4
  )
5
+ import warnings
5
6
  import numpy as np
6
7
  import pandas as pd
7
8
 
@@ -38,7 +39,7 @@ class CascadeCollection(mpl.collections.PatchCollection):
38
39
 
39
40
  # NOTE: there is a weird bug in pandas when using generic Hashable-s
40
41
  # with .loc. Seems like doing .T[...] works for individual index
41
- # elements only though
42
+ # elements only though (i.e. using __getitem__ a la dict)
42
43
  def get_node_coords(node):
43
44
  return layout.T[node].values
44
45
 
@@ -54,9 +55,19 @@ class CascadeCollection(mpl.collections.PatchCollection):
54
55
  # These patches need at least a facecolor (usually) or an edgecolor
55
56
  # so it's safe to make a list from these
56
57
  nodes_unordered = set()
57
- for prop in ("facecolor", "edgecolor"):
58
+ for prop in ("facecolor", "edgecolor", "linewidth", "linestyle"):
58
59
  if prop in style:
59
- nodes_unordered |= set(style[prop].keys())
60
+ value = style[prop]
61
+ if isinstance(value, dict):
62
+ nodes_unordered |= set(value.keys())
63
+
64
+ if len(nodes_unordered) == 0:
65
+ warnings.warn(
66
+ "No nodes found in the style for the cascading patches. "
67
+ "Please provide a style with at least one dict-like "
68
+ "specification among the following properties: 'facecolor', "
69
+ "'edgecolor', 'color', 'linewidth', or 'linestyle'.",
70
+ )
60
71
 
61
72
  # Draw the patches from the closest to the root (earlier drawing)
62
73
  # to the closer to the leaves (later drawing).
@@ -70,32 +81,15 @@ class CascadeCollection(mpl.collections.PatchCollection):
70
81
  f"Cascading patches not implemented for layout: {layout_name}.",
71
82
  )
72
83
 
73
- nleaves = sum(1 for leaf in provider(tree).get_leaves())
74
- extend_mode = style.get("extend", False)
75
- if extend_mode and (extend_mode != "leaf_labels"):
76
- if layout_name == "horizontal":
77
- if orientation == "right":
78
- maxdepth = layout.values[:, 0].max()
79
- else:
80
- maxdepth = layout.values[:, 0].min()
81
- elif layout_name == "vertical":
82
- if orientation == "descending":
83
- maxdepth = layout.values[:, 1].min()
84
- else:
85
- maxdepth = layout.values[:, 1].max()
86
- elif layout_name == "radial":
87
- # layout values are: r, theta
88
- maxdepth = layout.values[:, 0].max()
89
84
  self._maxdepth = maxdepth
90
85
 
91
86
  cascading_patches = []
87
+ nleaves = sum(1 for leaf in provider(tree).get_leaves())
92
88
  for node in drawing_order:
93
89
  stylei = rotate_style(style, key=node)
94
90
  stylei.pop("extend", None)
95
91
  # Default alpha is 0.5 for simple colors
96
- if isinstance(stylei.get("facecolor", None), str) and (
97
- "alpha" not in stylei
98
- ):
92
+ if isinstance(stylei.get("facecolor", None), str) and ("alpha" not in stylei):
99
93
  stylei["alpha"] = 0.5
100
94
 
101
95
  provider_node = provider(node)
@@ -137,9 +131,7 @@ class CascadeCollection(mpl.collections.PatchCollection):
137
131
  rmax = maxdepth if extend else leaves_coords[:, 0].max()
138
132
  thetamin = leaves_coords[:, 1].min() - 0.5 * dtheta
139
133
  thetamax = leaves_coords[:, 1].max() + 0.5 * dtheta
140
- thetas = np.linspace(
141
- thetamin, thetamax, max(30, (thetamax - thetamin) // 3)
142
- )
134
+ thetas = np.linspace(thetamin, thetamax, max(30, (thetamax - thetamin) // 3))
143
135
  xs = list(rmin * np.cos(thetas)) + list(rmax * np.cos(thetas[::-1]))
144
136
  ys = list(rmin * np.sin(thetas)) + list(rmax * np.sin(thetas[::-1]))
145
137
  points = list(zip(xs, ys))
@@ -200,9 +192,8 @@ class CascadeCollection(mpl.collections.PatchCollection):
200
192
  for path in self.get_paths():
201
193
  # Old radii
202
194
  r2old = np.linalg.norm(path.vertices[-2])
203
- path.vertices[(len(path.vertices) - 1) // 2 :] *= (
204
- self.get_maxdepth() / r2old
205
- )
195
+ # Update the outer part of the wedge patch
196
+ path.vertices[(len(path.vertices) - 1) // 2 :] *= self.get_maxdepth() / r2old
206
197
  return
207
198
 
208
199
  if (layout_name, orientation) == ("horizontal", "right"):
iplotx/edge/__init__.py CHANGED
@@ -121,8 +121,47 @@ class EdgeCollection(mpl.collections.PatchCollection):
121
121
  transform=transform,
122
122
  )
123
123
 
124
+ if "split" in self._style:
125
+ self._add_subedges(
126
+ len(patches),
127
+ self._style["split"],
128
+ )
129
+
130
+ def _add_subedges(
131
+ self,
132
+ nedges,
133
+ style,
134
+ ):
135
+ """Add subedges to shadow the current edges."""
136
+ segments = [np.zeros((2, 2)) for i in range(nedges)]
137
+ kwargs = {
138
+ "linewidths": [],
139
+ "edgecolors": [],
140
+ "linestyles": [],
141
+ }
142
+ for i in range(nedges):
143
+ vids = self._vertex_ids[i]
144
+ stylei = rotate_style(style, index=i, key=vids, key2=vids[-1])
145
+ for key, values in kwargs.items():
146
+ # iplotx uses singular style properties
147
+ key = key.rstrip("s")
148
+ # "color" has higher priority than "edgecolor"
149
+ if (key == "edgecolor") and ("color" in stylei):
150
+ val = stylei["color"]
151
+ else:
152
+ val = stylei.get(key.rstrip("s"), getattr(self, f"get_{key}")()[i])
153
+ values.append(val)
154
+
155
+ self._subedges = mpl.collections.LineCollection(
156
+ segments,
157
+ transform=self.get_transform(),
158
+ **kwargs,
159
+ )
160
+
124
161
  def get_children(self) -> tuple:
125
162
  children = []
163
+ if hasattr(self, "_subedges"):
164
+ children.append(self._subedges)
126
165
  if hasattr(self, "_arrows"):
127
166
  children.append(self._arrows)
128
167
  if hasattr(self, "_label_collection"):
@@ -209,7 +248,7 @@ class EdgeCollection(mpl.collections.PatchCollection):
209
248
  index = pd.Series(
210
249
  np.arange(len(index)),
211
250
  index=index,
212
- )
251
+ ).to_dict()
213
252
 
214
253
  voffsets = []
215
254
  vpaths = []
@@ -298,10 +337,20 @@ class EdgeCollection(mpl.collections.PatchCollection):
298
337
  tension = 0
299
338
  ports = None
300
339
 
340
+ # False is a synonym for "none"
301
341
  waypoints = edge_stylei.get("waypoints", "none")
342
+ if waypoints is False or waypoints is np.False_:
343
+ waypoints = "none"
344
+ elif waypoints is True or waypoints is np.True_:
345
+ raise ValueError(
346
+ "Could not determine automatically type of edge waypoints.",
347
+ )
302
348
  if waypoints != "none":
303
349
  ports = edge_stylei.get("ports", (None, None))
304
350
 
351
+ if not isinstance(waypoints, str):
352
+ __import__("ipdb").set_trace()
353
+
305
354
  # Compute actual edge path
306
355
  path, angles = _compute_edge_path(
307
356
  vcoord_data,
@@ -331,6 +380,17 @@ class EdgeCollection(mpl.collections.PatchCollection):
331
380
  if (offset != 0).any():
332
381
  path.vertices[:] = trans_inv(trans(path.vertices) + offset)
333
382
 
383
+ # If splitting is active, split the path here, shedding off the last straight
384
+ # segment but only if waypoints were used
385
+ if hasattr(self, "_subedges") and waypoints != "none":
386
+ # NOTE: we are already in the middle of a redraw, so we can happily avoid
387
+ # causing stale of the subedges. They are already scheduled to be redrawn
388
+ # at the end of this function.
389
+ self._subedges._paths[i].vertices[:] = path.vertices[-2:].copy()
390
+ # NOTE: instead of shortening the path, we just make the last bit invisible
391
+ # that makes it easier on memory management etc.
392
+ path.vertices[-1] = path.vertices[-2]
393
+
334
394
  # Collect angles for this vertex, to be used for loops plotting below
335
395
  if vinfo.get("loops", True):
336
396
  if v1 in loop_vertex_dict:
@@ -399,6 +459,9 @@ class EdgeCollection(mpl.collections.PatchCollection):
399
459
  idx += 1
400
460
 
401
461
  self._paths = paths
462
+ # FIXME:??
463
+ # if hasattr(self, "_subedges"):
464
+ # self._subedges.stale = True
402
465
 
403
466
  def _update_labels(self):
404
467
  if self._labels is None:
@@ -463,13 +526,13 @@ class EdgeCollection(mpl.collections.PatchCollection):
463
526
  if not self.get_visible():
464
527
  return
465
528
 
529
+ # This includes the subedges if present
466
530
  self._update_paths()
467
531
  # This sets the arrow offsets
468
532
  self._update_children()
469
533
 
470
534
  super().draw(renderer)
471
535
  for child in self.get_children():
472
- # This sets the arrow sizes with dpi scaling
473
536
  child.draw(renderer)
474
537
 
475
538
  def get_ports(self) -> Optional[LeafProperty[Pair[Optional[str]]]]:
@@ -489,7 +552,8 @@ class EdgeCollection(mpl.collections.PatchCollection):
489
552
  edge end.
490
553
  """
491
554
  if ports is None:
492
- del self._style["ports"]
555
+ if "ports" in self._style:
556
+ del self._style["ports"]
493
557
  else:
494
558
  self._style["ports"] = ports
495
559
  self.stale = True
@@ -523,7 +587,8 @@ class EdgeCollection(mpl.collections.PatchCollection):
523
587
 
524
588
  """
525
589
  if tension is None:
526
- del self._style["tension"]
590
+ if "tension" in self._style:
591
+ del self._style["tension"]
527
592
  else:
528
593
  self._style["tension"] = tension
529
594
  self.stale = True
@@ -583,7 +648,8 @@ class EdgeCollection(mpl.collections.PatchCollection):
583
648
  looptension: The tension to use for loops. If None, the default is 2.5.
584
649
  """
585
650
  if looptension is None:
586
- del self._style["looptension"]
651
+ if "looptension" in self._style:
652
+ del self._style["looptension"]
587
653
  else:
588
654
  self._style["looptension"] = looptension
589
655
  self.stale = True
@@ -603,7 +669,8 @@ class EdgeCollection(mpl.collections.PatchCollection):
603
669
  offset: The offset in points for parallel straight edges. If None, the default is 3.
604
670
  """
605
671
  if offset is None:
606
- del self._style["offset"]
672
+ if "offset" in self._style:
673
+ del self._style["offset"]
607
674
  else:
608
675
  self._style["offset"] = offset
609
676
  self.stale = True
@@ -634,6 +701,7 @@ def make_stub_patch(**kwargs):
634
701
  "paralleloffset",
635
702
  "cmap",
636
703
  "norm",
704
+ "split",
637
705
  ]
638
706
  for prop in forbidden_props:
639
707
  if prop in kwargs:
iplotx/edge/arrow.py CHANGED
@@ -2,6 +2,8 @@
2
2
  Module for edge arrows in iplotx.
3
3
  """
4
4
 
5
+ from typing import Never
6
+
5
7
  import numpy as np
6
8
  import matplotlib as mpl
7
9
  from matplotlib.patches import PathPatch
@@ -124,12 +126,17 @@ class EdgeArrowCollection(mpl.collections.PatchCollection):
124
126
 
125
127
  return patches, sizes
126
128
 
127
- def set_array(self, A):
129
+ def set_array(self, A: np.ndarray) -> Never:
128
130
  """Set the array for cmap/norm coloring, but keep the facecolors as set (usually 'none')."""
129
131
  raise ValueError("Setting an array for arrows directly is not supported.")
130
132
 
131
- def set_colors(self, colors):
132
- """Set arrow colors (edge and/or face) based on a colormap."""
133
+ def set_colors(self, colors: np.ndarray) -> None:
134
+ """Set arrow colors (edge and/or face) based on a colormap.
135
+
136
+ Parameters:
137
+ colors: Color array to apply. This must be an Nx3 or Nx4 vector of RGB or RGBA colors.
138
+ This function will NOT attempt to convert other color descriptions to RGB/RGBA.
139
+ """
133
140
  # NOTE: facecolors is always an array because we come from patches
134
141
  # It can have zero alpha (i.e. if we choose "none", or a hollow marker)
135
142
  self.set_edgecolor(colors)
iplotx/edge/geometry.py CHANGED
@@ -96,6 +96,7 @@ def _get_shorter_edge_coords(vpath, vsize, theta):
96
96
  xe = v1[0]
97
97
  else:
98
98
  m12 = (v2[1] - v1[1]) / (v2[0] - v1[0])
99
+ print(m12, mtheta)
99
100
  xe = (v1[1] - m12 * v1[0]) / (mtheta - m12)
100
101
  ye = mtheta * xe
101
102
  ve = np.array([xe, ye])
@@ -198,15 +199,14 @@ def _compute_edge_path_straight(
198
199
 
199
200
  # Angle of the straight line
200
201
  theta = atan2(*((vcoord_fig[1] - vcoord_fig[0])[::-1]))
202
+ print(vcoord_data_cart, vcoord_fig, theta)
201
203
 
202
204
  # Shorten at starting vertex
203
205
  vs = _get_shorter_edge_coords(vpath_fig[0], vsize_fig[0], theta) + vcoord_fig[0]
204
206
  points.append(vs)
205
207
 
206
208
  # Shorten at end vertex
207
- ve = (
208
- _get_shorter_edge_coords(vpath_fig[1], vsize_fig[1], theta + pi) + vcoord_fig[1]
209
- )
209
+ ve = _get_shorter_edge_coords(vpath_fig[1], vsize_fig[1], theta + pi) + vcoord_fig[1]
210
210
  points.append(ve)
211
211
 
212
212
  codes = ["MOVETO", "LINETO"]
@@ -230,7 +230,6 @@ def _compute_edge_path_waypoints(
230
230
  ports: Pair[Optional[str]] = (None, None),
231
231
  **kwargs,
232
232
  ):
233
-
234
233
  if waypoints in ("x0y1", "y0x1"):
235
234
  assert layout_coordinate_system == "cartesian"
236
235
 
@@ -253,8 +252,7 @@ def _compute_edge_path_waypoints(
253
252
 
254
253
  # Shorten at vertex border
255
254
  vshorts[i] = (
256
- _get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], thetas[i])
257
- + vcoord_fig[i]
255
+ _get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], thetas[i]) + vcoord_fig[i]
258
256
  )
259
257
 
260
258
  # Shorten waypoints to keep the angles right
@@ -302,10 +300,7 @@ def _compute_edge_path_waypoints(
302
300
  theta = atan2(*(_get_port_unit_vector(ports[i], trans_inv)[::-1]))
303
301
 
304
302
  # Shorten at vertex border
305
- vshort = (
306
- _get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], theta)
307
- + vcoord_fig[i]
308
- )
303
+ vshort = _get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], theta) + vcoord_fig[i]
309
304
  thetas.append(theta)
310
305
  vshorts.append(vshort)
311
306
 
@@ -324,9 +319,7 @@ def _compute_edge_path_waypoints(
324
319
 
325
320
  betas = np.linspace(alpha0, alpha1, points_per_curve)
326
321
  waypoints = [r0, r1][idx_inner] * np.vstack([np.cos(betas), np.sin(betas)]).T
327
- endpoint = [r0, r1][idx_outer] * np.array(
328
- [np.cos(alpha_outer), np.sin(alpha_outer)]
329
- )
322
+ endpoint = [r0, r1][idx_outer] * np.array([np.cos(alpha_outer), np.sin(alpha_outer)])
330
323
  points = np.array(list(waypoints) + [endpoint])
331
324
  points = trans(points)
332
325
  codes = ["MOVETO"] + ["LINETO"] * len(waypoints)
@@ -406,10 +399,7 @@ def _compute_edge_path_curved(
406
399
  thetas = [None, None]
407
400
  for i in range(2):
408
401
  thetas[i] = atan2(*((auxs[i] - vcoord_fig[i])[::-1]))
409
- vs[i] = (
410
- _get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], thetas[i])
411
- + vcoord_fig[i]
412
- )
402
+ vs[i] = _get_shorter_edge_coords(vpath_fig[i], vsize_fig[i], thetas[i]) + vcoord_fig[i]
413
403
 
414
404
  path = {
415
405
  "vertices": [
iplotx/edge/ports.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Module for handling edge ports in iplotx.
3
3
  """
4
4
 
5
+ from collections.abc import Callable
5
6
  import numpy as np
6
7
 
7
8
  sq2 = np.sqrt(2) / 2
@@ -19,8 +20,8 @@ port_dict = {
19
20
 
20
21
 
21
22
  def _get_port_unit_vector(
22
- portstring,
23
- trans_inv,
23
+ portstring: str,
24
+ trans_inv: Callable,
24
25
  ):
25
26
  """Get the tangent unit vector from a port string."""
26
27
  # The only tricky bit is if the port says e.g. north but the y axis is inverted, in which
iplotx/groups.py CHANGED
@@ -130,14 +130,12 @@ class GroupingArtist(PatchCollection):
130
130
  return patches, grouping, coords_hulls
131
131
 
132
132
  def _compute_paths(self, dpi: float = 72.0) -> None:
133
- ppc = self._points_per_curve
134
133
  for i, hull in enumerate(self._coords_hulls):
135
- self._paths[i].vertices = _compute_group_path_with_vertex_padding(
134
+ _compute_group_path_with_vertex_padding(
136
135
  hull,
137
136
  self._paths[i].vertices,
138
137
  self.get_transform(),
139
138
  vertexpadding=self.get_vertexpadding_dpi(dpi),
140
- points_per_curve=ppc,
141
139
  )
142
140
 
143
141
  def _process(self) -> None:
iplotx/ingest/__init__.py CHANGED
@@ -32,15 +32,11 @@ provider_protocols = {
32
32
  }
33
33
 
34
34
  # Internally supported data providers
35
- data_providers: dict[str, dict[str, Protocol]] = {
36
- kind: {} for kind in provider_protocols
37
- }
35
+ data_providers: dict[str, dict[str, Protocol]] = {kind: {} for kind in provider_protocols}
38
36
  for kind in data_providers:
39
37
  providers_path = pathlib.Path(__file__).parent.joinpath("providers").joinpath(kind)
40
38
  for importer, module_name, _ in pkgutil.iter_modules([providers_path]):
41
- module = importlib.import_module(
42
- f"iplotx.ingest.providers.{kind}.{module_name}"
43
- )
39
+ module = importlib.import_module(f"iplotx.ingest.providers.{kind}.{module_name}")
44
40
  for key, val in module.__dict__.items():
45
41
  if key == provider_protocols[kind].__name__:
46
42
  continue
@@ -123,8 +119,7 @@ def ingest_tree_data(
123
119
  else:
124
120
  sup = ", ".join(data_providers["tree"].keys())
125
121
  raise ValueError(
126
- f"Tree library '{tl}' is not installed. "
127
- f"Currently installed supported libraries: {sup}."
122
+ f"Tree library '{tl}' is not installed. Currently installed supported libraries: {sup}."
128
123
  )
129
124
 
130
125
  result = provider(
@@ -145,14 +140,10 @@ def ingest_tree_data(
145
140
  # INTERNAL FUNCTIONS
146
141
  def _update_data_providers(kind):
147
142
  """Update data provieders dynamically from external packages."""
148
- discovered_providers = importlib.metadata.entry_points(
149
- group=f"iplotx.{kind}_data_providers"
150
- )
143
+ discovered_providers = importlib.metadata.entry_points(group=f"iplotx.{kind}_data_providers")
151
144
  for entry_point in discovered_providers:
152
145
  if entry_point.name not in data_providers["network"]:
153
146
  try:
154
147
  data_providers[kind][entry_point.name] = entry_point.load()
155
148
  except Exception as e:
156
- warnings.warn(
157
- f"Failed to load {kind} data provider '{entry_point.name}': {e}"
158
- )
149
+ warnings.warn(f"Failed to load {kind} data provider '{entry_point.name}': {e}")
@@ -3,7 +3,6 @@ Heuristics module to funnel certain variable inputs (e.g. layouts) into a standa
3
3
  """
4
4
 
5
5
  from typing import (
6
- Optional,
7
6
  Any,
8
7
  )
9
8
  from collections.abc import Hashable
@@ -79,9 +78,7 @@ def normalise_tree_layout(
79
78
  if isinstance(layout, str):
80
79
  layout = compute_tree_layout(layout, **kwargs)
81
80
  else:
82
- raise NotImplementedError(
83
- "Only internally computed tree layout currently accepted."
84
- )
81
+ raise NotImplementedError("Only internally computed tree layout currently accepted.")
85
82
 
86
83
  if isinstance(layout, dict):
87
84
  # Adjust vertex layout
@@ -3,11 +3,11 @@ from typing import (
3
3
  Sequence,
4
4
  )
5
5
  from collections.abc import Hashable
6
+ import importlib
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
 
9
10
  from ....typing import (
10
- GraphType,
11
11
  LayoutType,
12
12
  )
13
13
  from ...heuristics import (
@@ -51,9 +51,7 @@ class IGraphDataProvider(NetworkDataProvider):
51
51
  if np.isscalar(vertex_labels):
52
52
  vertex_df["label"] = vertex_df.index.astype(str)
53
53
  elif len(vertex_labels) != len(vertex_df):
54
- raise ValueError(
55
- "Vertex labels must be the same length as the number of vertices."
56
- )
54
+ raise ValueError("Vertex labels must be the same length as the number of vertices.")
57
55
  else:
58
56
  vertex_df["label"] = vertex_labels
59
57
 
@@ -72,9 +70,7 @@ class IGraphDataProvider(NetworkDataProvider):
72
70
  # Edge labels
73
71
  if edge_labels is not None:
74
72
  if len(edge_labels) != len(edge_df):
75
- raise ValueError(
76
- "Edge labels must be the same length as the number of edges."
77
- )
73
+ raise ValueError("Edge labels must be the same length as the number of edges.")
78
74
  edge_df["label"] = edge_labels
79
75
 
80
76
  network_data = {
@@ -87,11 +83,7 @@ class IGraphDataProvider(NetworkDataProvider):
87
83
 
88
84
  @staticmethod
89
85
  def check_dependencies() -> bool:
90
- try:
91
- import igraph
92
- except ImportError:
93
- return False
94
- return True
86
+ return importlib.util.find_spec("igraph") is not None
95
87
 
96
88
  @staticmethod
97
89
  def graph_type():
@@ -3,11 +3,11 @@ from typing import (
3
3
  Sequence,
4
4
  )
5
5
  from collections.abc import Hashable
6
+ import importlib
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
 
9
10
  from ....typing import (
10
- GraphType,
11
11
  LayoutType,
12
12
  )
13
13
  from ...heuristics import (
@@ -64,21 +64,13 @@ class NetworkXDataProvider(NetworkDataProvider):
64
64
  if "label" in vertex_df:
65
65
  del vertex_df["label"]
66
66
  else:
67
- if (
68
- np.isscalar(vertex_labels)
69
- and (not vertex_labels)
70
- and ("label" in vertex_df)
71
- ):
67
+ if np.isscalar(vertex_labels) and (not vertex_labels) and ("label" in vertex_df):
72
68
  del vertex_df["label"]
73
69
  elif vertex_labels is True:
74
70
  if "label" not in vertex_df:
75
71
  vertex_df["label"] = vertex_df.index
76
- elif (not np.isscalar(vertex_labels)) and (
77
- len(vertex_labels) != len(vertex_df)
78
- ):
79
- raise ValueError(
80
- "Vertex labels must be the same length as the number of vertices."
81
- )
72
+ elif (not np.isscalar(vertex_labels)) and (len(vertex_labels) != len(vertex_df)):
73
+ raise ValueError("Vertex labels must be the same length as the number of vertices.")
82
74
  elif isinstance(vertex_labels, nx.classes.reportviews.NodeDataView):
83
75
  vertex_df["label"] = pd.Series(dict(vertex_labels))
84
76
  else:
@@ -108,9 +100,7 @@ class NetworkXDataProvider(NetworkDataProvider):
108
100
  edge_df["label"] = [str(i) for i in edge_df.index]
109
101
  else:
110
102
  if len(edge_labels) != len(edge_df):
111
- raise ValueError(
112
- "Edge labels must be the same length as the number of edges."
113
- )
103
+ raise ValueError("Edge labels must be the same length as the number of edges.")
114
104
  edge_df["label"] = edge_labels
115
105
 
116
106
  network_data = {
@@ -123,11 +113,7 @@ class NetworkXDataProvider(NetworkDataProvider):
123
113
 
124
114
  @staticmethod
125
115
  def check_dependencies() -> bool:
126
- try:
127
- import networkx
128
- except ImportError:
129
- return False
130
- return True
116
+ return importlib.util.find_spec("networkx") is not None
131
117
 
132
118
  @staticmethod
133
119
  def graph_type():
@@ -1,19 +1,14 @@
1
1
  from typing import (
2
2
  Optional,
3
3
  Sequence,
4
- Any,
5
4
  )
6
5
  from collections.abc import Hashable
7
6
  import numpy as np
8
7
  import pandas as pd
9
8
 
10
9
  from ....typing import (
11
- GraphType,
12
10
  LayoutType,
13
11
  )
14
- from ...heuristics import (
15
- normalise_layout,
16
- )
17
12
  from ...typing import (
18
13
  NetworkDataProvider,
19
14
  NetworkData,
@@ -23,7 +18,7 @@ from ....utils.internal import (
23
18
  )
24
19
 
25
20
 
26
- class SimpleDataProvider(NetworkDataProvider):
21
+ class SimpleNetworkDataProvider(NetworkDataProvider):
27
22
  def __call__(
28
23
  self,
29
24
  layout: Optional[LayoutType] = None,
@@ -69,9 +64,7 @@ class SimpleDataProvider(NetworkDataProvider):
69
64
  if np.isscalar(vertex_labels):
70
65
  vertex_df["label"] = vertex_df.index.astype(str)
71
66
  elif len(vertex_labels) != len(vertex_df):
72
- raise ValueError(
73
- "Vertex labels must be the same length as the number of vertices."
74
- )
67
+ raise ValueError("Vertex labels must be the same length as the number of vertices.")
75
68
  else:
76
69
  vertex_df["label"] = vertex_labels
77
70
 
@@ -3,6 +3,7 @@ from typing import (
3
3
  Optional,
4
4
  Sequence,
5
5
  )
6
+ import importlib
6
7
  from functools import partialmethod
7
8
 
8
9
  from ...typing import (
@@ -34,11 +35,7 @@ class BiopythonDataProvider(TreeDataProvider):
34
35
 
35
36
  @staticmethod
36
37
  def check_dependencies() -> bool:
37
- try:
38
- from Bio import Phylo
39
- except ImportError:
40
- return False
41
- return True
38
+ return importlib.util.find_spec("Bio") is not None
42
39
 
43
40
  @staticmethod
44
41
  def tree_type():
@@ -3,6 +3,7 @@ from typing import (
3
3
  Optional,
4
4
  Sequence,
5
5
  )
6
+ import importlib
6
7
  from ...typing import (
7
8
  TreeDataProvider,
8
9
  )
@@ -28,11 +29,7 @@ class Cogent3DataProvider(TreeDataProvider):
28
29
 
29
30
  @staticmethod
30
31
  def check_dependencies() -> bool:
31
- try:
32
- import cogent3
33
- except ImportError:
34
- return False
35
- return True
32
+ return importlib.util.find_spec("cogent3") is not None
36
33
 
37
34
  @staticmethod
38
35
  def tree_type():
@@ -3,6 +3,7 @@ from typing import (
3
3
  Optional,
4
4
  Sequence,
5
5
  )
6
+ import importlib
6
7
  from functools import partialmethod
7
8
 
8
9
  from ...typing import (
@@ -31,11 +32,7 @@ class Ete4DataProvider(TreeDataProvider):
31
32
 
32
33
  @staticmethod
33
34
  def check_dependencies() -> bool:
34
- try:
35
- from ete4 import Tree
36
- except ImportError:
37
- return False
38
- return True
35
+ return importlib.util.find_spec("ete4") is not None
39
36
 
40
37
  @staticmethod
41
38
  def tree_type():