MatplotLibAPI 3.2.21__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
MatplotLibAPI/Network.py CHANGED
@@ -1,8 +1,7 @@
1
1
  """Network chart plotting helpers."""
2
2
 
3
- import logging
4
3
  from collections import defaultdict
5
- from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
4
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
6
5
 
7
6
  import matplotlib.pyplot as plt
8
7
  import networkx as nx
@@ -15,13 +14,13 @@ from matplotlib.figure import Figure
15
14
  from .StyleTemplate import (
16
15
  NETWORK_STYLE_TEMPLATE,
17
16
  FIG_SIZE,
17
+ TITLE_SCALE_FACTOR,
18
18
  StyleTemplate,
19
- format_func,
20
19
  string_formatter,
21
20
  validate_dataframe,
22
21
  )
23
22
 
24
- DEFAULT = {
23
+ _DEFAULT = {
25
24
  "MAX_EDGES": 100,
26
25
  "MAX_NODES": 30,
27
26
  "MIN_NODE_SIZE": 100,
@@ -30,11 +29,28 @@ DEFAULT = {
30
29
  "GRAPH_SCALE": 2,
31
30
  "MAX_FONT_SIZE": 20,
32
31
  "MIN_FONT_SIZE": 8,
32
+ "SPRING_LAYOUT_K": 1.0,
33
+ "SPRING_LAYOUT_SEED": 42,
33
34
  }
34
35
 
36
+ _WEIGHT_PERCENTILES = np.arange(10, 100, 10)
35
37
 
36
- def softmax(x: Iterable[float]) -> np.ndarray:
37
- """Compute softmax values for array ``x``.
38
+ __all__ = [
39
+ "NETWORK_STYLE_TEMPLATE",
40
+ "NetworkGraph",
41
+ "aplot_network",
42
+ "aplot_network_node",
43
+ "aplot_network_components",
44
+ "fplot_network",
45
+ "fplot_network_node",
46
+ "fplot_network_components",
47
+ ]
48
+
49
+
50
+ def _softmax(x: Iterable[float]) -> np.ndarray:
51
+ """Private helper to compute softmax values for array ``x``.
52
+
53
+ Not part of the public API; used internally for edge weight scaling.
38
54
 
39
55
  Parameters
40
56
  ----------
@@ -47,13 +63,21 @@ def softmax(x: Iterable[float]) -> np.ndarray:
47
63
  Softmax-transformed values.
48
64
  """
49
65
  x_arr = np.array(x)
50
- return np.exp(x_arr - np.max(x_arr)) / np.exp(x_arr - np.max(x_arr)).sum()
66
+ shifted = x_arr - np.max(x_arr)
67
+ exp_shifted = np.exp(shifted)
68
+ return exp_shifted / exp_shifted.sum()
51
69
 
52
70
 
53
- def scale_weights(
54
- weights: Iterable[float], scale_min: float = 0, scale_max: float = 1
71
+ def _scale_weights(
72
+ weights: Iterable[float],
73
+ scale_min: float = 0,
74
+ scale_max: float = 1,
75
+ deciles: Optional[np.ndarray] = None,
55
76
  ) -> List[float]:
56
- """Scale weights into deciles within the given range.
77
+ """Private helper to scale weights into deciles within the given range.
78
+
79
+ This helper is internal to plotting utilities and is not part of the
80
+ supported public interface.
57
81
 
58
82
  Parameters
59
83
  ----------
@@ -63,16 +87,25 @@ def scale_weights(
63
87
  Minimum of the output range. The default is 0.
64
88
  scale_max : float, optional
65
89
  Maximum of the output range. The default is 1.
90
+ deciles : np.ndarray, optional
91
+ Precomputed percentile breakpoints to reuse. The default is ``None``.
66
92
 
67
93
  Returns
68
94
  -------
69
95
  list[float]
70
- Scaled weights.
96
+ Scaled weights or an empty list when ``weights`` is empty.
71
97
  """
72
98
  weights_arr = np.array(weights)
73
- deciles = np.percentile(weights_arr, [10, 20, 30, 40, 50, 60, 70, 80, 90])
74
- outs = np.searchsorted(deciles, weights_arr)
75
- return [out * (scale_max - scale_min) / len(deciles) + scale_min for out in outs]
99
+ if weights_arr.size == 0:
100
+ return []
101
+
102
+ percentiles = (
103
+ np.percentile(weights_arr, _WEIGHT_PERCENTILES) if deciles is None else deciles
104
+ )
105
+ outs = np.searchsorted(percentiles, weights_arr)
106
+ return [
107
+ out * (scale_max - scale_min) / len(percentiles) + scale_min for out in outs
108
+ ]
76
109
 
77
110
 
78
111
  class NodeView(nx.classes.reportviews.NodeView):
@@ -203,15 +236,86 @@ class EdgeView(nx.classes.reportviews.EdgeView):
203
236
  return [(edge[0], edge[1]) for edge in filtered_edges]
204
237
 
205
238
 
239
+ def _sanitize_node_dataframe(
240
+ node_df: Optional[pd.DataFrame],
241
+ edge_df: pd.DataFrame,
242
+ node_col: str = "node",
243
+ node_weight_col: str = "weight",
244
+ edge_source_col: str = "source",
245
+ edge_target_col: str = "target",
246
+ edge_weight_col: str = "weight",
247
+ ) -> Optional[pd.DataFrame]:
248
+ """Private helper returning ``node_df`` rows present in the edge list.
249
+
250
+ Intended for internal use when preparing plotting data.
251
+
252
+ Parameters
253
+ ----------
254
+ node_df : pd.DataFrame, optional
255
+ DataFrame containing ``node`` and ``weight`` columns.
256
+ edge_df : pd.DataFrame
257
+ Edge DataFrame containing source and target columns.
258
+ node_col : str, optional
259
+ Column name for node identifiers. The default is "node".
260
+ node_weight_col : str, optional
261
+ Column name for node weights. The default is "weight".
262
+ edge_source_col : str, optional
263
+ Column name for source edges. The default is "source".
264
+ edge_target_col : str, optional
265
+ Column name for target edges. The default is "target".
266
+ edge_weight_col : str, optional
267
+ Column name for edge weights. The default is "weight". Included to
268
+ keep signature parity with other sanitization helpers.
269
+
270
+ Returns
271
+ -------
272
+ pd.DataFrame
273
+ Filtered ``node_df`` with only nodes that appear as sources or targets.
274
+ """
275
+ if node_df is None:
276
+ return None
277
+
278
+ validate_dataframe(node_df, cols=[node_col, node_weight_col])
279
+ filtered_node_df = node_df.copy()
280
+ nodes_in_edges = list(set(edge_df[edge_source_col]).union(edge_df[edge_target_col]))
281
+ return filtered_node_df.loc[filtered_node_df[node_col].isin(nodes_in_edges)]
282
+
283
+
206
284
  class NetworkGraph:
207
285
  """Custom graph class based on NetworkX's ``Graph``.
208
286
 
209
287
  Methods
210
288
  -------
211
- sort
212
- Return nodes sorted by the specified attribute.
213
- filter
214
- Return nodes where ``attribute`` equals ``value``.
289
+ compute_positions
290
+ Return node positions computed with a spring layout.
291
+ layout
292
+ Return scaled node sizes, edge widths, and grouped font sizes.
293
+ aplot_network
294
+ Plot the graph on a provided axis.
295
+ fplot_network
296
+ Plot the graph and return a new figure.
297
+ aplot_connected_components
298
+ Plot each connected component on a shared axis.
299
+ fplot_connected_components
300
+ Plot each connected component on a new figure.
301
+ get_component_subgraph
302
+ Return the subgraph containing the specified node.
303
+ k_core
304
+ Return the k-core of the graph.
305
+ get_core_subgraph
306
+ Return the 2-core of the graph.
307
+ top_k_edges
308
+ Return the top edges for each node based on an attribute.
309
+ calculate_node_weights_from_edges
310
+ Populate node weights by summing top edge weights.
311
+ trim_edges
312
+ Create a subgraph that retains the top edges per node.
313
+ set_node_attributes
314
+ Set multiple node attributes from a mapping.
315
+ from_pandas_edgelist
316
+ Build a graph from a pandas edge list.
317
+ build_from_dataframes
318
+ Construct a graph from node and edge DataFrames with validation.
215
319
  """
216
320
 
217
321
  _nx_graph: nx.Graph
@@ -273,6 +377,61 @@ class NetworkGraph:
273
377
  """Return the number of edges in the graph."""
274
378
  return self._nx_graph.number_of_edges()
275
379
 
380
+ @property
381
+ def density(self) -> float:
382
+ """Return the density of the graph."""
383
+ return nx.density(self._nx_graph)
384
+
385
+ @property
386
+ def is_connected(self) -> bool:
387
+ """Return whether the graph is connected."""
388
+ return nx.is_connected(self._nx_graph)
389
+
390
+ @property
391
+ def average_clustering(self) -> float:
392
+ """Return the average clustering coefficient of the graph."""
393
+ return nx.average_clustering(self._nx_graph)
394
+
395
+ @property
396
+ def diameter(self) -> int:
397
+ """Return the diameter of the graph."""
398
+ return nx.diameter(self._nx_graph)
399
+
400
+ @property
401
+ def radius(self) -> int:
402
+ """Return the radius of the graph."""
403
+ return nx.radius(self._nx_graph)
404
+
405
+ @property
406
+ def center(self) -> List[Any]:
407
+ """Return the center nodes of the graph."""
408
+ return nx.center(self._nx_graph)
409
+
410
+ @property
411
+ def periphery(self) -> List[Any]:
412
+ """Return the periphery nodes of the graph."""
413
+ return nx.periphery(self._nx_graph)
414
+
415
+ @property
416
+ def average_shortest_path_length(self) -> float:
417
+ """Return the average shortest path length of the graph."""
418
+ return nx.average_shortest_path_length(self._nx_graph)
419
+
420
+ @property
421
+ def transitivity(self) -> float:
422
+ """Return the transitivity of the graph."""
423
+ return nx.transitivity(self._nx_graph)
424
+
425
+ @property
426
+ def clustering_coefficients(self) -> Dict[Any, float]:
427
+ """Return the clustering coefficients of the graph."""
428
+ return nx.clustering(self._nx_graph) # pyright: ignore[reportReturnType]
429
+
430
+ @property
431
+ def degree_assortativity_coefficient(self) -> float:
432
+ """Return the degree assortativity coefficient of the graph."""
433
+ return nx.degree_assortativity_coefficient(self._nx_graph)
434
+
276
435
  def edge_subgraph(self, edges: Iterable) -> "NetworkGraph":
277
436
  """Return a subgraph containing only the specified edges.
278
437
 
@@ -290,28 +449,28 @@ class NetworkGraph:
290
449
 
291
450
  def layout(
292
451
  self,
293
- max_node_size: int = DEFAULT["MAX_NODE_SIZE"],
294
- min_node_size: int = DEFAULT["MIN_NODE_SIZE"],
295
- max_edge_width: int = DEFAULT["MAX_EDGE_WIDTH"],
296
- max_font_size: int = DEFAULT["MAX_FONT_SIZE"],
297
- min_font_size: int = DEFAULT["MIN_FONT_SIZE"],
298
- weight: str = "weight",
452
+ max_node_size: int = _DEFAULT["MAX_NODE_SIZE"],
453
+ min_node_size: int = _DEFAULT["MIN_NODE_SIZE"],
454
+ max_edge_width: int = _DEFAULT["MAX_EDGE_WIDTH"],
455
+ max_font_size: int = _DEFAULT["MAX_FONT_SIZE"],
456
+ min_font_size: int = _DEFAULT["MIN_FONT_SIZE"],
457
+ edge_weight_col: str = "weight",
299
458
  ) -> Tuple[List[float], List[float], Dict[int, List[str]]]:
300
459
  """Calculate node, edge and font sizes based on weights.
301
460
 
302
461
  Parameters
303
462
  ----------
304
463
  max_node_size : int, optional
305
- Upper bound for node size. The default is `DEFAULT["MAX_NODE_SIZE"]`.
464
+ Upper bound for node size. The default is `_DEFAULT["MAX_NODE_SIZE"]`.
306
465
  min_node_size : int, optional
307
- Lower bound for node size. The default is `DEFAULT["MIN_NODE_SIZE"]`.
466
+ Lower bound for node size. The default is `_DEFAULT["MIN_NODE_SIZE"]`.
308
467
  max_edge_width : int, optional
309
- Upper bound for edge width. The default is `DEFAULT["MAX_EDGE_WIDTH"]`.
468
+ Upper bound for edge width. The default is `_DEFAULT["MAX_EDGE_WIDTH"]`.
310
469
  max_font_size : int, optional
311
- Upper bound for font size. The default is `DEFAULT["MAX_FONT_SIZE"]`.
470
+ Upper bound for font size. The default is `_DEFAULT["MAX_FONT_SIZE"]`.
312
471
  min_font_size : int, optional
313
- Lower bound for font size. The default is `DEFAULT["MIN_FONT_SIZE"]`.
314
- weight : str, optional
472
+ Lower bound for font size. The default is `_DEFAULT["MIN_FONT_SIZE"]`.
473
+ edge_weight_col : str, optional
315
474
  Node attribute used for weighting. The default is "weight".
316
475
 
317
476
  Returns
@@ -320,23 +479,36 @@ class NetworkGraph:
320
479
  Node sizes, edge widths and nodes grouped by font size.
321
480
  """
322
481
  # Normalize and scale nodes' weights within the desired range of edge widths
323
- node_weights = [data.get(weight, 1) for node, data in self.nodes(data=True)]
324
- node_size = scale_weights(
325
- weights=node_weights, scale_max=max_node_size, scale_min=min_node_size
482
+ node_weights = [
483
+ data.get(edge_weight_col, 1) for node, data in self.nodes(data=True)
484
+ ]
485
+ node_deciles = (
486
+ np.percentile(np.array(node_weights), _WEIGHT_PERCENTILES)
487
+ if node_weights
488
+ else None
489
+ )
490
+ node_size = _scale_weights(
491
+ weights=node_weights,
492
+ scale_max=max_node_size,
493
+ scale_min=min_node_size,
494
+ deciles=node_deciles,
326
495
  )
327
496
 
328
497
  # Normalize and scale edges' weights within the desired range of edge widths
329
- edge_weights = [data.get(weight, 1) for _, _, data in self.edges(data=True)]
330
- edges_width = scale_weights(weights=edge_weights, scale_max=max_edge_width)
498
+ edge_weights = [
499
+ data.get(edge_weight_col, 1) for _, _, data in self.edges(data=True)
500
+ ]
501
+ edges_width = _scale_weights(weights=edge_weights, scale_max=max_edge_width)
331
502
 
332
503
  # Scale the normalized node weights within the desired range of font sizes
333
504
  node_size_dict = dict(
334
505
  zip(
335
506
  self.nodes,
336
- scale_weights(
507
+ _scale_weights(
337
508
  weights=node_weights,
338
509
  scale_max=max_font_size,
339
510
  scale_min=min_font_size,
511
+ deciles=node_deciles,
340
512
  ),
341
513
  )
342
514
  )
@@ -350,7 +522,7 @@ class NetworkGraph:
350
522
  def subgraph(
351
523
  self,
352
524
  node_list: Optional[List[str]] = None,
353
- max_edges: int = DEFAULT["MAX_EDGES"],
525
+ max_edges: int = _DEFAULT["MAX_EDGES"],
354
526
  min_degree: int = 2,
355
527
  top_k_edges_per_node: int = 5,
356
528
  ) -> "NetworkGraph":
@@ -361,7 +533,7 @@ class NetworkGraph:
361
533
  node_list : list[str], optional
362
534
  Nodes to include.
363
535
  max_edges : int, optional
364
- Maximum edges to retain. The default is `DEFAULT["MAX_EDGES"]`.
536
+ Maximum edges to retain. The default is `_DEFAULT["MAX_EDGES"]`.
365
537
  min_degree : int, optional
366
538
  Minimum degree for nodes in the core subgraph. The default is 2.
367
539
  top_k_edges_per_node : int, optional
@@ -373,20 +545,82 @@ class NetworkGraph:
373
545
  Trimmed subgraph.
374
546
  """
375
547
  if node_list is None:
376
- node_list = self.nodes.sort("weight")[: DEFAULT["MAX_NODES"]]
377
- core_subgraph_nodes = list(self.get_core_subgraph(k=min_degree).nodes)
548
+ node_list = self.nodes.sort("weight")[: _DEFAULT["MAX_NODES"]]
549
+ core_subgraph_nodes = list(self.k_core(k=min_degree).nodes)
378
550
  node_list = [node for node in node_list if node in core_subgraph_nodes]
379
-
380
- subgraph = NetworkGraph(nx.subgraph(self._nx_graph, nbunch=node_list))
381
- edges = subgraph.top_k_edges(attribute="weight", k=top_k_edges_per_node).keys()
382
- subgraph = subgraph.edge_subgraph(list(edges)[:max_edges])
551
+ _subgraph = nx.subgraph(self._nx_graph, node_list)
552
+ subgraph = NetworkGraph(_subgraph)
553
+ if top_k_edges_per_node > 0:
554
+ edges = subgraph.top_k_edges(
555
+ attribute="weight", k=top_k_edges_per_node
556
+ ).keys()
557
+ subgraph = subgraph.edge_subgraph(list(edges)[:max_edges])
383
558
  return subgraph
384
559
 
385
- def plot_network(
560
+ def compute_positions(
561
+ self,
562
+ k: Optional[float] = None,
563
+ seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
564
+ ) -> Dict[Any, np.ndarray]:
565
+ """Return spring layout positions for the graph.
566
+
567
+ Parameters
568
+ ----------
569
+ k : float, optional
570
+ Optimal distance between nodes. The default is ``_DEFAULT["SPRING_LAYOUT_K"]``.
571
+ seed : int, optional
572
+ Seed for reproducible layouts. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
573
+
574
+ Returns
575
+ -------
576
+ dict[Any, np.ndarray]
577
+ Mapping of nodes to their layout coordinates.
578
+ """
579
+ layout_k = _DEFAULT["SPRING_LAYOUT_K"] if k is None else k
580
+ return nx.spring_layout(self._nx_graph, k=layout_k, seed=seed)
581
+
582
+ def get_component_subgraph(self, node: Any) -> "NetworkGraph":
583
+ """Return the connected component containing ``node``.
584
+
585
+ Parameters
586
+ ----------
587
+ node : Any
588
+ Node identifier to anchor the component selection.
589
+
590
+ Returns
591
+ -------
592
+ NetworkGraph
593
+ Subgraph made of the nodes in the same connected component as
594
+ ``node``.
595
+
596
+ Raises
597
+ ------
598
+ ValueError
599
+ If ``node`` is not present in the graph.
600
+ """
601
+ if node not in self._nx_graph:
602
+ raise ValueError(f"Node {node!r} is not present in the graph.")
603
+
604
+ component_nodes = next(
605
+ (
606
+ component
607
+ for component in nx.connected_components(self._nx_graph)
608
+ if node in component
609
+ ),
610
+ None,
611
+ )
612
+
613
+ if component_nodes is None:
614
+ return NetworkGraph(nx.Graph())
615
+
616
+ return NetworkGraph(nx.subgraph(self._nx_graph, component_nodes).copy())
617
+
618
+ def aplot_network(
386
619
  self,
387
620
  title: Optional[str] = None,
388
621
  style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
389
- weight: str = "weight",
622
+ edge_weight_col: str = "weight",
623
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
390
624
  ax: Optional[Axes] = None,
391
625
  ) -> Axes:
392
626
  """Plot the graph using node and edge weights.
@@ -397,8 +631,10 @@ class NetworkGraph:
397
631
  Plot title.
398
632
  style : StyleTemplate, optional
399
633
  Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
400
- weight : str, optional
634
+ edge_weight_col : str, optional
401
635
  Edge attribute used for weighting. The default is "weight".
636
+ layout_seed : int, optional
637
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
402
638
  ax : Axes, optional
403
639
  Axes to draw on.
404
640
 
@@ -414,8 +650,6 @@ class NetworkGraph:
414
650
  isolated_nodes = list(nx.isolates(self._nx_graph))
415
651
  graph_nx = self._nx_graph
416
652
  if isolated_nodes:
417
- # Avoid mutating the user-provided graph when pruning display-only
418
- # isolates so the underlying data remains unchanged after plotting.
419
653
  graph_nx = graph_nx.copy()
420
654
  graph_nx.remove_nodes_from(isolated_nodes)
421
655
 
@@ -425,19 +659,32 @@ class NetworkGraph:
425
659
  ax.set_axis_off()
426
660
  if title:
427
661
  ax.set_title(
428
- title, color=style.font_color, fontsize=style.font_size * 2
662
+ title,
663
+ color=style.font_color,
664
+ fontsize=style.font_size * TITLE_SCALE_FACTOR,
429
665
  )
430
666
  return ax
431
667
 
668
+ mapped_min_font_size = style.font_mapping.get(0)
669
+ mapped_max_font_size = style.font_mapping.get(4)
670
+
432
671
  node_sizes, edge_widths, font_sizes = graph.layout(
433
- min_node_size=DEFAULT["MIN_NODE_SIZE"] // 5,
434
- max_node_size=DEFAULT["MAX_NODE_SIZE"],
435
- max_edge_width=DEFAULT["MAX_EDGE_WIDTH"],
436
- min_font_size=style.font_mapping.get(0, DEFAULT["MIN_FONT_SIZE"]),
437
- max_font_size=style.font_mapping.get(4, DEFAULT["MAX_FONT_SIZE"]),
438
- weight=weight,
672
+ min_node_size=_DEFAULT["MIN_NODE_SIZE"] // 5,
673
+ max_node_size=_DEFAULT["MAX_NODE_SIZE"],
674
+ max_edge_width=_DEFAULT["MAX_EDGE_WIDTH"],
675
+ min_font_size=(
676
+ mapped_min_font_size
677
+ if mapped_min_font_size is not None
678
+ else _DEFAULT["MIN_FONT_SIZE"]
679
+ ),
680
+ max_font_size=(
681
+ mapped_max_font_size
682
+ if mapped_max_font_size is not None
683
+ else _DEFAULT["MAX_FONT_SIZE"]
684
+ ),
685
+ edge_weight_col=edge_weight_col,
439
686
  )
440
- pos = nx.spring_layout(graph._nx_graph, k=1)
687
+ pos = graph.compute_positions(seed=layout_seed)
441
688
  # nodes
442
689
  node_sizes_int = [int(size) for size in node_sizes]
443
690
  nx.draw_networkx_nodes(
@@ -469,29 +716,167 @@ class NetworkGraph:
469
716
  )
470
717
  ax.set_facecolor(style.background_color)
471
718
  if title:
472
- ax.set_title(title, color=style.font_color, fontsize=style.font_size * 2)
719
+ ax.set_title(
720
+ title,
721
+ color=style.font_color,
722
+ fontsize=style.font_size * TITLE_SCALE_FACTOR,
723
+ )
473
724
  ax.set_axis_off()
474
725
 
475
726
  return ax
476
727
 
477
- def plot_network_components(self, *args: Any, **kwargs: Any) -> List:
478
- """Plot network components.
728
+ def fplot_network(
729
+ self,
730
+ title: Optional[str] = None,
731
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
732
+ edge_weight_col: str = "weight",
733
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
734
+ ) -> Figure:
735
+ """Plot the graph using node and edge weights.
479
736
 
480
- .. deprecated:: 0.1.0
481
- `plot_network_components` will be removed in a future version.
482
- Use `fplot_network_components` instead.
737
+ Parameters
738
+ ----------
739
+ title : str, optional
740
+ Plot title.
741
+ style : StyleTemplate, optional
742
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
743
+ edge_weight_col : str, optional
744
+ Edge attribute used for weighting. The default is "weight".
745
+ layout_seed : int, optional
746
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
747
+ Returns
748
+ -------
749
+ Figure
750
+ Matplotlib figure with the plotted network.
483
751
  """
484
- import warnings
752
+ fig, ax = plt.subplots(figsize=FIG_SIZE)
753
+ fig = cast(Figure, fig)
754
+ ax = cast(Axes, ax)
755
+ fig.patch.set_facecolor(style.background_color)
756
+ self.aplot_network(
757
+ title=title,
758
+ style=style,
759
+ edge_weight_col=edge_weight_col,
760
+ layout_seed=layout_seed,
761
+ ax=ax,
762
+ )
763
+ return fig
485
764
 
486
- warnings.warn(
487
- "`plot_network_components` is deprecated and will be removed in a future version. "
488
- "Please use `fplot_network_components`.",
489
- DeprecationWarning,
490
- stacklevel=2,
765
+ def aplot_connected_components(
766
+ self,
767
+ title: Optional[str] = None,
768
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
769
+ edge_weight_col: str = "weight",
770
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
771
+ axes: Optional[np.ndarray] = None,
772
+ ) -> Union[Axes, np.ndarray]:
773
+ """Plot all connected components of the graph.
774
+
775
+ Parameters
776
+ ----------
777
+ title : str, optional
778
+ Base title for component subplots. When provided, each axis title is
779
+ suffixed with the component index.
780
+ style : StyleTemplate, optional
781
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
782
+ edge_weight_col : str, optional
783
+ Edge attribute used for weighting. The default is "weight".
784
+ layout_seed : int, optional
785
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
786
+ axes : np.ndarray, optional
787
+ Existing axes to draw each component on. If None, a grid is created
788
+ based on the number of components.
789
+
790
+ Returns
791
+ -------
792
+ Union[Axes, np.ndarray]
793
+ Matplotlib axes with the plotted network. When ``axes`` is provided
794
+ or created, the flattened array of axes is returned; otherwise, a
795
+ single Axes is returned.
796
+ """
797
+ sns.set_palette(style.palette)
798
+
799
+ graph = self
800
+ isolated_nodes = list(nx.isolates(self._nx_graph))
801
+ if isolated_nodes:
802
+ graph = NetworkGraph(self._nx_graph.copy())
803
+ graph._nx_graph.remove_nodes_from(isolated_nodes)
804
+
805
+ connected_components = list(nx.connected_components(graph._nx_graph))
806
+
807
+ local_axes = axes
808
+ created_axes = False
809
+
810
+ if not connected_components:
811
+ if local_axes is None:
812
+ local_axes = np.array([cast(Axes, plt.gca())])
813
+ for axis in local_axes.flatten():
814
+ axis.set_facecolor(style.background_color)
815
+ axis.set_axis_off()
816
+ return local_axes
817
+
818
+ if local_axes is None:
819
+ _, local_axes = _compute_network_grid(connected_components, style)
820
+ created_axes = True
821
+
822
+ for i, component in enumerate(connected_components):
823
+ if i >= len(local_axes):
824
+ break
825
+ component_graph = NetworkGraph(
826
+ nx.subgraph(graph._nx_graph, component).copy()
827
+ )
828
+ component_graph.aplot_network(
829
+ title=f"{title}::{i}" if title else str(i),
830
+ style=style,
831
+ edge_weight_col=edge_weight_col,
832
+ layout_seed=layout_seed,
833
+ ax=local_axes[i],
834
+ )
835
+ local_axes[i].set_axis_on()
836
+
837
+ for axis in local_axes[len(connected_components) :]:
838
+ axis.set_axis_off()
839
+
840
+ return local_axes if created_axes or len(local_axes) > 1 else local_axes[0]
841
+
842
+ def fplot_connected_components(
843
+ self,
844
+ title: Optional[str] = None,
845
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
846
+ edge_weight_col: str = "weight",
847
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
848
+ ) -> Figure:
849
+ """Plot all connected components of the graph.
850
+
851
+ Parameters
852
+ ----------
853
+ title : str, optional
854
+ Plot title to apply to the first component axis.
855
+ style : StyleTemplate, optional
856
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
857
+ edge_weight_col : str, optional
858
+ Edge attribute used for weighting. The default is "weight".
859
+ layout_seed : int, optional
860
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
861
+
862
+ Returns
863
+ -------
864
+ Figure
865
+ Matplotlib figure with the plotted network.
866
+ """
867
+ fig, ax = plt.subplots(figsize=FIG_SIZE)
868
+ fig = cast(Figure, fig)
869
+ fig.patch.set_facecolor(style.background_color)
870
+ self.aplot_connected_components(
871
+ title=title,
872
+ style=style,
873
+ edge_weight_col=edge_weight_col,
874
+ layout_seed=layout_seed,
875
+ axes=np.array([ax]),
491
876
  )
492
- return []
877
+ return fig
493
878
 
494
- def get_core_subgraph(self, k: int = 2) -> "NetworkGraph":
879
+ def k_core(self, k: int = 2) -> "NetworkGraph":
495
880
  """Return the k-core of the graph.
496
881
 
497
882
  The k-core is a subgraph containing only nodes with degree >= k.
@@ -509,6 +894,16 @@ class NetworkGraph:
509
894
  core_graph = nx.k_core(self._nx_graph, k=k)
510
895
  return NetworkGraph(core_graph)
511
896
 
897
+ def get_core_subgraph(self) -> "NetworkGraph":
898
+ """Return the 2-core of the graph.
899
+
900
+ Returns
901
+ -------
902
+ NetworkGraph
903
+ The k-core subgraph with minimum degree 2.
904
+ """
905
+ return self.k_core(k=2)
906
+
512
907
  def top_k_edges(
513
908
  self, attribute: str, reverse: bool = True, k: int = 5
514
909
  ) -> Dict[Any, List[Tuple[Any, Dict]]]:
@@ -540,17 +935,19 @@ class NetworkGraph:
540
935
  top_list[edge_key] = data[attribute]
541
936
  return top_list
542
937
 
543
- def calculate_node_weights_from_edges(self, weight: str = "weight", k: int = 10):
938
+ def calculate_node_weights_from_edges(
939
+ self, edge_weight_col: str = "weight", k: int = 10
940
+ ):
544
941
  """Calculate node weights by summing weights of top k edges.
545
942
 
546
943
  Parameters
547
944
  ----------
548
- weight : str, optional
945
+ edge_weight_col : str, optional
549
946
  Edge attribute to use for weighting. The default is "weight".
550
947
  k : int, optional
551
948
  Number of top edges to consider for each node. The default is 10.
552
949
  """
553
- edge_aggregates = self.top_k_edges(attribute=weight, k=k)
950
+ edge_aggregates = self.top_k_edges(attribute=edge_weight_col, k=k)
554
951
  node_aggregates = {}
555
952
  for (u, v), weight_value in edge_aggregates.items():
556
953
  if u not in node_aggregates:
@@ -560,16 +957,16 @@ class NetworkGraph:
560
957
  node_aggregates[u] += weight_value
561
958
  node_aggregates[v] += weight_value
562
959
 
563
- nx.set_node_attributes(self._nx_graph, node_aggregates, name=weight)
960
+ nx.set_node_attributes(self._nx_graph, node_aggregates, name=edge_weight_col)
564
961
 
565
962
  def trim_edges(
566
- self, weight: str = "weight", top_k_per_node: int = 5
963
+ self, edge_weight_col: str = "weight", top_k_per_node: int = 5
567
964
  ) -> "NetworkGraph":
568
965
  """Trim the graph to keep only the top k edges per node.
569
966
 
570
967
  Parameters
571
968
  ----------
572
- weight : str, optional
969
+ edge_weight_col : str, optional
573
970
  Edge attribute to use for sorting. The default is "weight".
574
971
  top_k_per_node : int, optional
575
972
  Number of top edges to keep per node. The default is 5.
@@ -579,27 +976,38 @@ class NetworkGraph:
579
976
  NetworkGraph
580
977
  A new graph containing only the top edges.
581
978
  """
582
- edges_to_keep = self.top_k_edges(attribute=weight, k=top_k_per_node)
979
+ edges_to_keep = self.top_k_edges(attribute=edge_weight_col, k=top_k_per_node)
583
980
  return self.edge_subgraph(edges=edges_to_keep)
584
981
 
982
+ def set_node_attributes(self, attributes: Dict[Any, Dict[str, Any]]):
983
+ """Set multiple node attributes from a dictionary.
984
+
985
+ Parameters
986
+ ----------
987
+ attributes : dict[Any, dict[str, Any]]
988
+ Mapping of node identifiers to their attribute dictionaries.
989
+ """
990
+ for node, attrs in attributes.items():
991
+ nx.set_node_attributes(self._nx_graph, {node: attrs})
992
+
585
993
  @staticmethod
586
994
  def from_pandas_edgelist(
587
- df: pd.DataFrame,
995
+ edges_df: pd.DataFrame,
588
996
  source: str = "source",
589
997
  target: str = "target",
590
- weight: str = "weight",
998
+ edge_weight_col: str = "weight",
591
999
  ) -> "NetworkGraph":
592
1000
  """Initialize a NetworkGraph from a simple DataFrame.
593
1001
 
594
1002
  Parameters
595
1003
  ----------
596
- df : pd.DataFrame
1004
+ edges_df : pd.DataFrame
597
1005
  DataFrame containing edge data.
598
1006
  source : str, optional
599
1007
  Column name for source nodes. The default is "source".
600
1008
  target : str, optional
601
1009
  Column name for target nodes. The default is "target".
602
- weight : str, optional
1010
+ edge_weight_col : str, optional
603
1011
  Column name for edge weights. The default is "weight".
604
1012
 
605
1013
  Returns
@@ -608,12 +1016,180 @@ class NetworkGraph:
608
1016
  Initialized network graph.
609
1017
  """
610
1018
  network_G = nx.from_pandas_edgelist(
611
- df, source=source, target=target, edge_attr=weight
1019
+ edges_df, source=source, target=target, edge_attr=edge_weight_col
612
1020
  )
613
1021
  return NetworkGraph(network_G)
614
1022
 
1023
+ @staticmethod
1024
+ def _sanitize_node_dataframe(
1025
+ node_df: pd.DataFrame,
1026
+ edge_df: pd.DataFrame,
1027
+ node_col: str = "node",
1028
+ node_weight_col: str = "weight",
1029
+ edge_source_col: str = "source",
1030
+ edge_target_col: str = "target",
1031
+ edge_weight_col: str = "weight",
1032
+ ) -> pd.DataFrame:
1033
+ """Private helper returning ``node_df`` rows present in the edge list.
1034
+
1035
+ This method supports internal builders and is not part of the public API.
1036
+
1037
+ Parameters
1038
+ ----------
1039
+ node_df : pd.DataFrame
1040
+ DataFrame containing ``node`` and ``weight`` columns.
1041
+ edge_df : pd.DataFrame
1042
+ Edge DataFrame containing source and target columns.
1043
+ node_col : str, optional
1044
+ Column name for node identifiers. The default is "node".
1045
+ node_weight_col : str, optional
1046
+ Column name for node weights. The default is "weight".
1047
+ edge_source_col : str, optional
1048
+ Column name for source edges. The default is "source".
1049
+ edge_target_col : str, optional
1050
+ Column name for target edges. The default is "target".
1051
+ edge_weight_col : str, optional
1052
+ Column name for edge weights. The default is "weight". Included for
1053
+ signature parity with other sanitization helpers.
1054
+
1055
+ Returns
1056
+ -------
1057
+ pd.DataFrame
1058
+ Filtered ``node_df`` with only nodes that appear as sources or targets.
1059
+ """
1060
+ validate_dataframe(node_df, cols=[node_col, node_weight_col])
1061
+ filtered_node_df = node_df.copy()
1062
+ nodes_in_edges = list(
1063
+ set(edge_df[edge_source_col]).union(edge_df[edge_target_col])
1064
+ )
1065
+ return filtered_node_df.loc[filtered_node_df[node_col].isin(nodes_in_edges)]
1066
+
1067
+ @staticmethod
1068
+ def _sanitize_edge_dataframe(
1069
+ node_df: pd.DataFrame,
1070
+ edge_df: pd.DataFrame,
1071
+ node_col: str = "node",
1072
+ node_weight_col: str = "weight",
1073
+ edge_source_col: str = "source",
1074
+ edge_target_col: str = "target",
1075
+ edge_weight_col: str = "weight",
1076
+ ) -> pd.DataFrame:
1077
+ """Private helper returning a sanitized copy of the edge DataFrame.
1078
+
1079
+ Intended for internal validation when building graphs from dataframes.
1080
+
1081
+ Parameters
1082
+ ----------
1083
+ node_df : pd.DataFrame
1084
+ DataFrame containing node identifiers and weights.
1085
+ edge_df : pd.DataFrame
1086
+ Edge DataFrame containing source and target columns.
1087
+ node_col : str, optional
1088
+ Column name for node identifiers. The default is "node".
1089
+ node_weight_col : str, optional
1090
+ Column name for node weights. The default is "weight".
1091
+ edge_source_col : str, optional
1092
+ Column name for source nodes. The default is "source".
1093
+ edge_target_col : str, optional
1094
+ Column name for target nodes. The default is "target".
1095
+ edge_weight_col : str, optional
1096
+ Column name for edge weights. The default is "weight".
1097
+
1098
+ Returns
1099
+ -------
1100
+ pd.DataFrame
1101
+ Sanitized edge DataFrame containing only edges whose nodes appear
1102
+ in ``node_df``.
1103
+ """
1104
+ validate_dataframe(
1105
+ edge_df, cols=[edge_source_col, edge_target_col, edge_weight_col]
1106
+ )
1107
+ validate_dataframe(node_df, cols=[node_col, node_weight_col])
1108
+ allowed_nodes = node_df[node_col].tolist()
1109
+ edge_df = edge_df.loc[
1110
+ edge_df[edge_source_col].isin(allowed_nodes)
1111
+ & edge_df[edge_target_col].isin(allowed_nodes)
1112
+ ]
1113
+ return edge_df
1114
+
1115
+ @staticmethod
1116
+ def build_from_dataframes(
1117
+ node_df: pd.DataFrame,
1118
+ edge_df: pd.DataFrame,
1119
+ node_col: str = "node",
1120
+ node_weight_col: str = "weight",
1121
+ edge_source_col: str = "source",
1122
+ edge_target_col: str = "target",
1123
+ edge_weight_col: str = "weight",
1124
+ ) -> "NetworkGraph":
1125
+ """Build a NetworkGraph from node and edge DataFrames.
1126
+
1127
+ Parameters
1128
+ ----------
1129
+ node_df : pd.DataFrame
1130
+ DataFrame containing node identifiers and weights.
1131
+ edge_df : pd.DataFrame
1132
+ DataFrame containing the edge list.
1133
+ node_col : str, optional
1134
+ Column name for node identifiers. The default is "node".
1135
+ node_weight_col : str, optional
1136
+ Column name for node weights. The default is "weight".
1137
+ edge_source_col : str, optional
1138
+ Column name for source nodes. The default is "source".
1139
+ edge_target_col : str, optional
1140
+ Column name for target nodes. The default is "target".
1141
+ edge_weight_col : str, optional
1142
+ Column name for edge weights. The default is "weight".
1143
+
1144
+ Returns
1145
+ -------
1146
+ NetworkGraph
1147
+ Prepared ``NetworkGraph`` instance with node weights set and edges
1148
+ filtered to nodes present in ``node_df``.
1149
+ """
1150
+ if node_df is not None:
1151
+ node_df = NetworkGraph._sanitize_node_dataframe(
1152
+ node_df,
1153
+ edge_df=edge_df,
1154
+ node_col=node_col,
1155
+ node_weight_col=node_weight_col,
1156
+ edge_source_col=edge_source_col,
1157
+ edge_target_col=edge_target_col,
1158
+ edge_weight_col=edge_weight_col,
1159
+ )
1160
+ edge_df = NetworkGraph._sanitize_edge_dataframe(
1161
+ node_df,
1162
+ edge_df=edge_df,
1163
+ node_col=node_col,
1164
+ node_weight_col=node_weight_col,
1165
+ edge_source_col=edge_source_col,
1166
+ edge_target_col=edge_target_col,
1167
+ edge_weight_col=edge_weight_col,
1168
+ )
1169
+ graph = NetworkGraph.from_pandas_edgelist(
1170
+ edge_df,
1171
+ source=edge_source_col,
1172
+ target=edge_target_col,
1173
+ edge_weight_col=edge_weight_col,
1174
+ )
1175
+ if node_df is None or node_df.empty:
1176
+ graph.calculate_node_weights_from_edges(
1177
+ edge_weight_col=edge_weight_col, k=10
1178
+ )
1179
+ else:
1180
+ node_weights = {
1181
+ node: {"weight": weight_value}
1182
+ for node, weight_value in node_df.set_index(node_col)[
1183
+ node_weight_col
1184
+ ].items()
1185
+ if node in graph._nx_graph.nodes
1186
+ }
1187
+ graph.set_node_attributes(node_weights)
1188
+
1189
+ return graph
615
1190
 
616
- def compute_network_grid(
1191
+
1192
+ def _compute_network_grid(
617
1193
  connected_components: List[set], style: StyleTemplate
618
1194
  ) -> Tuple[Figure, np.ndarray]:
619
1195
  """Compute the grid layout for network component subplots.
@@ -643,72 +1219,120 @@ def compute_network_grid(
643
1219
  return fig, axes
644
1220
 
645
1221
 
646
- def prepare_network_graph(
1222
+ def _prepare_network_graph(
647
1223
  pd_df: pd.DataFrame,
648
- source: str,
649
- target: str,
650
- weight: str,
651
- sort_by: Optional[str],
652
- node_list: Optional[List],
1224
+ node_col: str = "node",
1225
+ node_weight_col: str = "weight",
1226
+ edge_source_col: str = "source",
1227
+ edge_target_col: str = "target",
1228
+ edge_weight_col: str = "weight",
1229
+ sort_by: Optional[str] = None,
1230
+ node_df: Optional[pd.DataFrame] = None,
653
1231
  ) -> NetworkGraph:
654
1232
  """Prepare a NetworkGraph for plotting from a pandas DataFrame.
655
1233
 
656
1234
  This function takes a DataFrame and prepares it for network visualization by:
657
- 1. Filtering the DataFrame to include only the nodes in `node_list` (if provided).
1235
+ 1. Filtering the DataFrame to include only the nodes in ``node_df`` (if provided).
658
1236
  2. Validating the DataFrame to ensure it has the required columns.
659
1237
  3. Creating a `NetworkGraph` from the edge list.
660
1238
  4. Extracting the k-core of the graph (k=2) to focus on the main structure.
661
- 5. Calculating node weights based on the sum of their top k edge weights.
1239
+ 5. Applying node weights provided in ``node_df`` or calculating them from
1240
+ the top-k edge weights.
662
1241
  6. Trimming the graph to keep only the top k edges per node.
663
1242
 
664
1243
  Parameters
665
1244
  ----------
666
1245
  pd_df : pd.DataFrame
667
1246
  DataFrame containing the edge list.
668
- source : str
669
- Column name for source nodes.
670
- target : str
671
- Column name for target nodes.
672
- weight : str
673
- Column name for edge weights.
1247
+ node_col : str, optional
1248
+ Column name for node identifiers. The default is "node".
1249
+ node_weight_col : str, optional
1250
+ Column name for node weights. The default is "weight".
1251
+ edge_source_col : str, optional
1252
+ Column name for source nodes. The default is "source".
1253
+ edge_target_col : str, optional
1254
+ Column name for target nodes. The default is "target".
1255
+ edge_weight_col : str, optional
1256
+ Column name for edge weights. The default is "weight".
674
1257
  sort_by : str, optional
675
1258
  Column to sort the DataFrame by before processing.
676
- node_list : list, optional
677
- A list of nodes to include in the graph. If provided, the DataFrame
678
- will be filtered to include only edges connected to these nodes.
1259
+ node_df : pd.DataFrame, optional
1260
+ DataFrame containing ``node`` and ``weight`` columns. If provided, the
1261
+ DataFrame will be filtered to include only edges connected to these
1262
+ nodes, and their provided weights will be used instead of calculated
1263
+ values.
679
1264
 
680
1265
  Returns
681
1266
  -------
682
1267
  NetworkGraph
683
1268
  The prepared `NetworkGraph` object.
1269
+
1270
+ Raises
1271
+ ------
1272
+ ValueError
1273
+ If ``node_df`` is provided but none of its nodes appear as sources or
1274
+ targets in ``pd_df``.
684
1275
  """
685
- if node_list:
1276
+ filtered_node_df = _sanitize_node_dataframe(
1277
+ node_df,
1278
+ pd_df,
1279
+ node_col=node_col,
1280
+ node_weight_col=node_weight_col,
1281
+ edge_source_col=edge_source_col,
1282
+ edge_target_col=edge_target_col,
1283
+ edge_weight_col=edge_weight_col,
1284
+ )
1285
+ if node_df is not None:
1286
+ if filtered_node_df is None or filtered_node_df.empty:
1287
+ raise ValueError(
1288
+ "node_df must include at least one node present as a source or target."
1289
+ )
1290
+ allowed_nodes = filtered_node_df[node_col].tolist()
686
1291
  df = pd_df.loc[
687
- (pd_df["source"].isin(node_list)) | (pd_df["target"].isin(node_list))
1292
+ pd_df[edge_source_col].isin(allowed_nodes)
1293
+ & pd_df[edge_target_col].isin(allowed_nodes)
688
1294
  ]
689
1295
  else:
690
1296
  df = pd_df
691
- validate_dataframe(df, cols=[source, target, weight], sort_by=sort_by)
1297
+ validate_dataframe(
1298
+ df, cols=[edge_source_col, edge_target_col, edge_weight_col], sort_by=sort_by
1299
+ )
692
1300
 
693
1301
  graph = NetworkGraph.from_pandas_edgelist(
694
- df, source=source, target=target, weight=weight
1302
+ df,
1303
+ source=edge_source_col,
1304
+ target=edge_target_col,
1305
+ edge_weight_col=edge_weight_col,
695
1306
  )
696
- graph = graph.get_core_subgraph(k=2)
697
- graph.calculate_node_weights_from_edges(weight=weight, k=10)
698
- graph = graph.trim_edges(weight=weight, top_k_per_node=5)
1307
+ graph = graph.get_core_subgraph()
1308
+ if filtered_node_df is not None and not filtered_node_df.empty:
1309
+ node_weights = {
1310
+ node: weight_value
1311
+ for node, weight_value in filtered_node_df.set_index(node_col)[
1312
+ node_weight_col
1313
+ ].items()
1314
+ if node in graph._nx_graph.nodes
1315
+ }
1316
+ nx.set_node_attributes(graph._nx_graph, node_weights, name=edge_weight_col)
1317
+ else:
1318
+ graph.calculate_node_weights_from_edges(edge_weight_col=edge_weight_col, k=10)
1319
+ graph = graph.trim_edges(edge_weight_col=edge_weight_col, top_k_per_node=5)
699
1320
  return graph
700
1321
 
701
1322
 
702
1323
  def aplot_network(
703
1324
  pd_df: pd.DataFrame,
704
- source: str = "source",
705
- target: str = "target",
706
- weight: str = "weight",
707
- title: Optional[str] = None,
708
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1325
+ node_col: str = "node",
1326
+ node_weight_col: str = "weight",
1327
+ edge_source_col: str = "source",
1328
+ edge_target_col: str = "target",
1329
+ edge_weight_col: str = "weight",
709
1330
  sort_by: Optional[str] = None,
710
1331
  ascending: bool = False,
711
- node_list: Optional[List] = None,
1332
+ node_df: Optional[pd.DataFrame] = None,
1333
+ title: Optional[str] = None,
1334
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1335
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
712
1336
  ax: Optional[Axes] = None,
713
1337
  ) -> Axes:
714
1338
  """Plot a network graph on the provided axes.
@@ -717,45 +1341,149 @@ def aplot_network(
717
1341
  ----------
718
1342
  pd_df : pd.DataFrame
719
1343
  DataFrame containing edge data.
720
- source : str, optional
1344
+ node_col : str, optional
1345
+ Column name for node identifiers. The default is "node".
1346
+ node_weight_col : str, optional
1347
+ Column name for node weights. The default is "weight".
1348
+ edge_source_col : str, optional
721
1349
  Column name for source nodes. The default is "source".
722
- target : str, optional
1350
+ edge_target_col : str, optional
723
1351
  Column name for target nodes. The default is "target".
724
- weight : str, optional
1352
+ edge_weight_col : str, optional
725
1353
  Column name for edge weights. The default is "weight".
1354
+ sort_by : str, optional
1355
+ Column used to sort the data.
1356
+ ascending : bool, optional
1357
+ Sort order for the data. The default is `False`.
1358
+ node_df : pd.DataFrame, optional
1359
+ DataFrame containing ``node`` and ``weight`` columns to include.
726
1360
  title : str, optional
727
1361
  Plot title.
728
1362
  style : StyleTemplate, optional
729
1363
  Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
1364
+ layout_seed : int, optional
1365
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
1366
+ ax : Axes, optional
1367
+ Axes to draw on.
1368
+
1369
+ Returns
1370
+ -------
1371
+ Axes
1372
+ Matplotlib axes with the plotted network.
1373
+ """
1374
+ graph = _prepare_network_graph(
1375
+ pd_df,
1376
+ node_col=node_col,
1377
+ node_weight_col=node_weight_col,
1378
+ edge_source_col=edge_source_col,
1379
+ edge_target_col=edge_target_col,
1380
+ edge_weight_col=edge_weight_col,
1381
+ sort_by=sort_by,
1382
+ node_df=node_df,
1383
+ )
1384
+ return graph.aplot_network(
1385
+ title=title,
1386
+ style=style,
1387
+ edge_weight_col=edge_weight_col,
1388
+ layout_seed=layout_seed,
1389
+ ax=ax,
1390
+ )
1391
+
1392
+
1393
+ def aplot_network_node(
1394
+ pd_df: pd.DataFrame,
1395
+ node: Any,
1396
+ node_col: str = "node",
1397
+ node_weight_col: str = "weight",
1398
+ edge_source_col: str = "source",
1399
+ edge_target_col: str = "target",
1400
+ edge_weight_col: str = "weight",
1401
+ sort_by: Optional[str] = None,
1402
+ ascending: bool = False,
1403
+ node_df: Optional[pd.DataFrame] = None,
1404
+ title: Optional[str] = None,
1405
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1406
+ ax: Optional[Axes] = None,
1407
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
1408
+ ) -> Axes:
1409
+ """Plot the connected component containing ``node`` on the provided axes.
1410
+
1411
+ Parameters
1412
+ ----------
1413
+ pd_df : pd.DataFrame
1414
+ DataFrame containing edge data.
1415
+ node : Any
1416
+ Node identifier whose component should be visualized.
1417
+ node_col : str, optional
1418
+ Column name for node identifiers. The default is "node".
1419
+ node_weight_col : str, optional
1420
+ Column name for node weights. The default is "weight".
1421
+ edge_source_col : str, optional
1422
+ Column name for source nodes. The default is "source".
1423
+ edge_target_col : str, optional
1424
+ Column name for target nodes. The default is "target".
1425
+ edge_weight_col : str, optional
1426
+ Column name for edge weights. The default is "weight".
730
1427
  sort_by : str, optional
731
1428
  Column used to sort the data.
732
1429
  ascending : bool, optional
733
1430
  Sort order for the data. The default is `False`.
734
- node_list : list, optional
735
- Nodes to include.
1431
+ node_df : pd.DataFrame, optional
1432
+ DataFrame containing ``node`` and ``weight`` columns to include.
1433
+ title : str, optional
1434
+ Plot title. If ``None``, defaults to the node identifier.
1435
+ style : StyleTemplate, optional
1436
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
736
1437
  ax : Axes, optional
737
1438
  Axes to draw on.
1439
+ layout_seed : int, optional
1440
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
738
1441
 
739
1442
  Returns
740
1443
  -------
741
1444
  Axes
742
- Matplotlib axes with the plotted network.
1445
+ Matplotlib axes with the plotted component.
1446
+
1447
+ Raises
1448
+ ------
1449
+ ValueError
1450
+ If ``node`` is not present in the prepared graph.
743
1451
  """
744
- graph = prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
745
- return graph.plot_network(title=title, style=style, weight=weight, ax=ax)
1452
+ graph = _prepare_network_graph(
1453
+ pd_df,
1454
+ node_col=node_col,
1455
+ node_weight_col=node_weight_col,
1456
+ edge_source_col=edge_source_col,
1457
+ edge_target_col=edge_target_col,
1458
+ edge_weight_col=edge_weight_col,
1459
+ sort_by=sort_by,
1460
+ node_df=node_df,
1461
+ )
1462
+ component_graph = graph.get_component_subgraph(node)
1463
+ resolved_title = title if title is not None else string_formatter(node)
1464
+ return component_graph.aplot_network(
1465
+ title=resolved_title,
1466
+ style=style,
1467
+ edge_weight_col=edge_weight_col,
1468
+ ax=ax,
1469
+ layout_seed=layout_seed,
1470
+ )
746
1471
 
747
1472
 
748
1473
  def aplot_network_components(
749
1474
  pd_df: pd.DataFrame,
750
- axes: Optional[np.ndarray],
751
- source: str = "source",
752
- target: str = "target",
753
- weight: str = "weight",
754
- title: Optional[str] = None,
755
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1475
+ node_col: str = "node",
1476
+ node_weight_col: str = "weight",
1477
+ edge_source_col: str = "source",
1478
+ edge_target_col: str = "target",
1479
+ edge_weight_col: str = "weight",
756
1480
  sort_by: Optional[str] = None,
757
- node_list: Optional[List] = None,
758
1481
  ascending: bool = False,
1482
+ node_df: Optional[pd.DataFrame] = None,
1483
+ title: Optional[str] = None,
1484
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1485
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
1486
+ axes: Optional[np.ndarray] = None,
759
1487
  ) -> None:
760
1488
  """Plot network components separately on multiple axes.
761
1489
 
@@ -763,82 +1491,63 @@ def aplot_network_components(
763
1491
  ----------
764
1492
  pd_df : pd.DataFrame
765
1493
  DataFrame containing edge data.
766
- source : str, optional
1494
+ node_col : str, optional
1495
+ Column name for node identifiers. The default is "node".
1496
+ node_weight_col : str, optional
1497
+ Column name for node weights. The default is "weight".
1498
+ edge_source_col : str, optional
767
1499
  Column name for source nodes. The default is "source".
768
- target : str, optional
1500
+ edge_target_col : str, optional
769
1501
  Column name for target nodes. The default is "target".
770
- weight : str, optional
1502
+ edge_weight_col : str, optional
771
1503
  Column name for edge weights. The default is "weight".
772
- title : str, optional
773
- Base title for subplots.
774
- style : StyleTemplate, optional
775
- Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
776
1504
  sort_by : str, optional
777
1505
  Column used to sort the data.
778
- node_list : list, optional
779
- Nodes to include.
780
1506
  ascending : bool, optional
781
1507
  Sort order for the data. The default is `False`.
1508
+ node_df : pd.DataFrame, optional
1509
+ DataFrame containing ``node`` and ``weight`` columns to include.
1510
+ title : str, optional
1511
+ Base title for subplots.
1512
+ style : StyleTemplate, optional
1513
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
1514
+ layout_seed : int, optional
1515
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
782
1516
  axes : np.ndarray
783
1517
  Existing axes to draw on.
784
1518
  """
785
- graph = prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
786
-
787
- connected_components = list(nx.connected_components(graph._nx_graph))
788
-
789
- if not connected_components:
790
- if axes is not None:
791
- for ax in axes.flatten():
792
- ax.set_axis_off()
793
- return
794
-
795
- isolated_nodes = list(nx.isolates(graph._nx_graph))
796
- if isolated_nodes:
797
- # Keep the caller's graph intact while dropping isolates purely for
798
- # visualization.
799
- graph = NetworkGraph(graph._nx_graph.copy())
800
- graph._nx_graph.remove_nodes_from(isolated_nodes)
801
- connected_components = list(nx.connected_components(graph._nx_graph))
802
-
803
- if not connected_components:
804
- if axes is not None:
805
- for ax in axes.flatten():
806
- ax.set_axis_off()
807
- return
808
-
809
- local_axes = axes
810
- if local_axes is None:
811
- fig, local_axes = compute_network_grid(connected_components, style)
812
-
813
- i = -1
814
- for i, component in enumerate(connected_components):
815
- if i < len(local_axes):
816
- if len(component) > 5:
817
- component_graph = graph.subgraph(node_list=list(component))
818
- component_graph.plot_network(
819
- title=f"{title}::{i}" if title else str(i),
820
- style=style,
821
- weight=weight,
822
- ax=local_axes[i],
823
- )
824
- local_axes[i].set_axis_on()
825
- else:
826
- break
827
-
828
- for j in range(i + 1, len(local_axes)):
829
- local_axes[j].set_axis_off()
1519
+ graph = _prepare_network_graph(
1520
+ pd_df,
1521
+ node_col=node_col,
1522
+ node_weight_col=node_weight_col,
1523
+ edge_source_col=edge_source_col,
1524
+ edge_target_col=edge_target_col,
1525
+ edge_weight_col=edge_weight_col,
1526
+ sort_by=sort_by,
1527
+ node_df=node_df,
1528
+ )
1529
+ graph.aplot_connected_components(
1530
+ title=title,
1531
+ style=style,
1532
+ edge_weight_col=edge_weight_col,
1533
+ layout_seed=layout_seed,
1534
+ axes=axes,
1535
+ )
830
1536
 
831
1537
 
832
1538
  def fplot_network(
833
1539
  pd_df: pd.DataFrame,
834
- source: str = "source",
835
- target: str = "target",
836
- weight: str = "weight",
837
- title: Optional[str] = None,
838
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1540
+ node_col: str = "node",
1541
+ node_weight_col: str = "weight",
1542
+ edge_source_col: str = "source",
1543
+ edge_target_col: str = "target",
1544
+ edge_weight_col: str = "weight",
839
1545
  sort_by: Optional[str] = None,
840
1546
  ascending: bool = False,
841
- node_list: Optional[List] = None,
1547
+ node_df: Optional[pd.DataFrame] = None,
1548
+ title: Optional[str] = None,
1549
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1550
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
842
1551
  figsize: Tuple[float, float] = FIG_SIZE,
843
1552
  save_path: Optional[str] = None,
844
1553
  savefig_kwargs: Optional[Dict[str, Any]] = None,
@@ -849,24 +1558,34 @@ def fplot_network(
849
1558
  ----------
850
1559
  pd_df : pd.DataFrame
851
1560
  DataFrame containing edge data.
852
- source : str, optional
1561
+ node_col : str, optional
1562
+ Column name for node identifiers. The default is "node".
1563
+ node_weight_col : str, optional
1564
+ Column name for node weights. The default is "weight".
1565
+ edge_source_col : str, optional
853
1566
  Column name for source nodes. The default is "source".
854
- target : str, optional
1567
+ edge_target_col : str, optional
855
1568
  Column name for target nodes. The default is "target".
856
- weight : str, optional
1569
+ edge_weight_col : str, optional
857
1570
  Column name for edge weights. The default is "weight".
858
- title : str, optional
859
- Plot title.
860
- style : StyleTemplate, optional
861
- Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
862
1571
  sort_by : str, optional
863
1572
  Column used to sort the data.
864
1573
  ascending : bool, optional
865
1574
  Sort order for the data. The default is `False`.
866
- node_list : list, optional
867
- Nodes to include.
1575
+ node_df : pd.DataFrame, optional
1576
+ DataFrame containing ``node`` and ``weight`` columns to include.
1577
+ title : str, optional
1578
+ Plot title.
1579
+ style : StyleTemplate, optional
1580
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
1581
+ layout_seed : int, optional
1582
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
868
1583
  figsize : tuple[float, float], optional
869
1584
  Size of the created figure. The default is FIG_SIZE.
1585
+ save_path : str, optional
1586
+ File path to save the figure. The default is ``None``.
1587
+ savefig_kwargs : dict[str, Any], optional
1588
+ Extra keyword arguments forwarded to ``Figure.savefig``. The default is ``None``.
870
1589
 
871
1590
  Returns
872
1591
  -------
@@ -878,15 +1597,107 @@ def fplot_network(
878
1597
  ax = fig.add_subplot()
879
1598
  ax = aplot_network(
880
1599
  pd_df,
881
- source=source,
882
- target=target,
883
- weight=weight,
1600
+ node_col=node_col,
1601
+ node_weight_col=node_weight_col,
1602
+ edge_source_col=edge_source_col,
1603
+ edge_target_col=edge_target_col,
1604
+ edge_weight_col=edge_weight_col,
1605
+ sort_by=sort_by,
1606
+ ascending=ascending,
1607
+ node_df=node_df,
884
1608
  title=title,
885
1609
  style=style,
1610
+ layout_seed=layout_seed,
1611
+ ax=ax,
1612
+ )
1613
+ if save_path:
1614
+ fig.savefig(save_path, **(savefig_kwargs or {}))
1615
+ return fig
1616
+
1617
+
1618
+ def fplot_network_node(
1619
+ pd_df: pd.DataFrame,
1620
+ node: Any,
1621
+ node_col: str = "node",
1622
+ node_weight_col: str = "weight",
1623
+ edge_source_col: str = "source",
1624
+ edge_target_col: str = "target",
1625
+ edge_weight_col: str = "weight",
1626
+ sort_by: Optional[str] = None,
1627
+ ascending: bool = False,
1628
+ node_df: Optional[pd.DataFrame] = None,
1629
+ title: Optional[str] = None,
1630
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1631
+ figsize: Tuple[float, float] = FIG_SIZE,
1632
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
1633
+ save_path: Optional[str] = None,
1634
+ savefig_kwargs: Optional[Dict[str, Any]] = None,
1635
+ ) -> Figure:
1636
+ """Return a figure with the component containing ``node``.
1637
+
1638
+ Parameters
1639
+ ----------
1640
+ pd_df : pd.DataFrame
1641
+ DataFrame containing edge data.
1642
+ node : Any
1643
+ Node identifier whose component should be visualized.
1644
+ node_col : str, optional
1645
+ Column name for node identifiers. The default is "node".
1646
+ node_weight_col : str, optional
1647
+ Column name for node weights. The default is "weight".
1648
+ edge_source_col : str, optional
1649
+ Column name for source nodes. The default is "source".
1650
+ edge_target_col : str, optional
1651
+ Column name for target nodes. The default is "target".
1652
+ edge_weight_col : str, optional
1653
+ Column name for edge weights. The default is "weight".
1654
+ sort_by : str, optional
1655
+ Column used to sort the data.
1656
+ ascending : bool, optional
1657
+ Sort order for the data. The default is `False`.
1658
+ node_df : pd.DataFrame, optional
1659
+ DataFrame containing ``node`` and ``weight`` columns to include.
1660
+ figsize : tuple[float, float], optional
1661
+ Size of the created figure. The default is FIG_SIZE.
1662
+ title : str, optional
1663
+ Plot title. If ``None``, defaults to the node identifier.
1664
+ style : StyleTemplate, optional
1665
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
1666
+ save_path : str, optional
1667
+ File path to save the figure. The default is ``None``.
1668
+ savefig_kwargs : dict[str, Any], optional
1669
+ Extra keyword arguments forwarded to ``Figure.savefig``. The default is ``None``.
1670
+ layout_seed : int, optional
1671
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
1672
+
1673
+ Returns
1674
+ -------
1675
+ Figure
1676
+ Matplotlib figure with the component plot.
1677
+
1678
+ Raises
1679
+ ------
1680
+ ValueError
1681
+ If ``node`` is not present in the prepared graph.
1682
+ """
1683
+ fig = cast(Figure, plt.figure(figsize=figsize))
1684
+ fig.patch.set_facecolor(style.background_color)
1685
+ ax = fig.add_subplot()
1686
+ ax = aplot_network_node(
1687
+ pd_df,
1688
+ node=node,
1689
+ node_col=node_col,
1690
+ node_weight_col=node_weight_col,
1691
+ edge_source_col=edge_source_col,
1692
+ edge_target_col=edge_target_col,
1693
+ edge_weight_col=edge_weight_col,
886
1694
  sort_by=sort_by,
887
1695
  ascending=ascending,
888
- node_list=node_list,
1696
+ node_df=node_df,
1697
+ title=title,
1698
+ style=style,
889
1699
  ax=ax,
1700
+ layout_seed=layout_seed,
890
1701
  )
891
1702
  if save_path:
892
1703
  fig.savefig(save_path, **(savefig_kwargs or {}))
@@ -895,14 +1706,17 @@ def fplot_network(
895
1706
 
896
1707
  def fplot_network_components(
897
1708
  pd_df: pd.DataFrame,
898
- source: str = "source",
899
- target: str = "target",
900
- weight: str = "weight",
901
- title: Optional[str] = None,
902
- style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1709
+ node_col: str = "node",
1710
+ node_weight_col: str = "weight",
1711
+ edge_source_col: str = "source",
1712
+ edge_target_col: str = "target",
1713
+ edge_weight_col: str = "weight",
903
1714
  sort_by: Optional[str] = None,
904
1715
  ascending: bool = False,
905
- node_list: Optional[List] = None,
1716
+ node_df: Optional[pd.DataFrame] = None,
1717
+ title: Optional[str] = None,
1718
+ style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
1719
+ layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
906
1720
  figsize: Tuple[float, float] = FIG_SIZE,
907
1721
  n_cols: Optional[int] = None,
908
1722
  save_path: Optional[str] = None,
@@ -914,73 +1728,93 @@ def fplot_network_components(
914
1728
  ----------
915
1729
  pd_df : pd.DataFrame
916
1730
  DataFrame containing edge data.
917
- source : str, optional
1731
+ node_col : str, optional
1732
+ Column name for node identifiers. The default is "node".
1733
+ node_weight_col : str, optional
1734
+ Column name for node weights. The default is "weight".
1735
+ edge_source_col : str, optional
918
1736
  Column name for source nodes. The default is "source".
919
- target : str, optional
1737
+ edge_target_col : str, optional
920
1738
  Column name for target nodes. The default is "target".
921
- weight : str, optional
1739
+ edge_weight_col : str, optional
922
1740
  Column name for edge weights. The default is "weight".
923
- title : str, optional
924
- Plot title.
925
- style : StyleTemplate, optional
926
- Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
927
1741
  sort_by : str, optional
928
1742
  Column used to sort the data.
929
1743
  ascending : bool, optional
930
1744
  Sort order for the data. The default is `False`.
931
- node_list : list, optional
932
- Nodes to include.
1745
+ node_df : pd.DataFrame, optional
1746
+ DataFrame containing ``node`` and ``weight`` columns to include.
1747
+ title : str, optional
1748
+ Plot title.
1749
+ style : StyleTemplate, optional
1750
+ Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
1751
+ layout_seed : int, optional
1752
+ Seed for the spring layout used to place nodes. The default is ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
933
1753
  figsize : tuple[float, float], optional
934
1754
  Size of the created figure. The default is FIG_SIZE.
935
1755
  n_cols : int, optional
936
1756
  Number of columns for subplots. If None, it's inferred.
1757
+ save_path : str, optional
1758
+ File path to save the figure. The default is ``None``.
1759
+ savefig_kwargs : dict[str, Any], optional
1760
+ Extra keyword arguments forwarded to ``Figure.savefig``. The default is ``None``.
937
1761
 
938
1762
  Returns
939
1763
  -------
940
1764
  Figure
941
1765
  Matplotlib figure displaying component plots.
942
- """
943
- # First, get the graph and components to determine the layout
944
- df = pd_df.copy()
945
- if node_list:
946
- df = df.loc[(df["source"].isin(node_list)) | (df["target"].isin(node_list))]
947
1766
 
948
- validate_dataframe(df, cols=[source, target, weight], sort_by=sort_by)
949
- graph = NetworkGraph.from_pandas_edgelist(
950
- df, source=source, target=target, weight=weight
1767
+ Raises
1768
+ ------
1769
+ ValueError
1770
+ If ``node_df`` is provided but none of its nodes appear as sources or
1771
+ targets in ``pd_df``.
1772
+ """
1773
+ graph = _prepare_network_graph(
1774
+ pd_df,
1775
+ node_col=node_col,
1776
+ node_weight_col=node_weight_col,
1777
+ edge_source_col=edge_source_col,
1778
+ edge_target_col=edge_target_col,
1779
+ edge_weight_col=edge_weight_col,
1780
+ sort_by=sort_by,
1781
+ node_df=node_df,
951
1782
  )
952
- graph = graph.get_core_subgraph(k=2)
953
- connected_components = list(nx.connected_components(graph._nx_graph))
954
1783
 
955
- n_components = len(connected_components)
956
- if n_cols is None:
957
- n_cols = int(np.ceil(np.sqrt(n_components)))
958
- n_rows = int(np.ceil(n_components / n_cols))
1784
+ working_graph = graph
1785
+ isolated_nodes = list(nx.isolates(graph._nx_graph))
1786
+ if isolated_nodes:
1787
+ working_graph = NetworkGraph(graph._nx_graph.copy())
1788
+ working_graph._nx_graph.remove_nodes_from(isolated_nodes)
1789
+
1790
+ connected_components = list(nx.connected_components(working_graph._nx_graph))
959
1791
 
960
- fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize)
1792
+ n_components = max(1, len(connected_components))
1793
+ n_cols_local = int(np.ceil(np.sqrt(n_components))) if n_cols is None else n_cols
1794
+ n_rows = int(np.ceil(n_components / n_cols_local))
1795
+
1796
+ fig, axes_grid = plt.subplots(n_rows, n_cols_local, figsize=figsize)
961
1797
  fig = cast(Figure, fig)
962
1798
  fig.patch.set_facecolor(style.background_color)
963
-
964
1799
  if not isinstance(axes_grid, np.ndarray):
965
1800
  axes = np.array([axes_grid])
966
1801
  else:
967
1802
  axes = axes_grid.flatten()
968
1803
 
969
- aplot_network_components(
970
- pd_df=pd_df,
971
- source=source,
972
- target=target,
973
- weight=weight,
1804
+ graph.aplot_connected_components(
974
1805
  title=title,
975
1806
  style=style,
976
- sort_by=sort_by,
977
- ascending=ascending,
978
- node_list=node_list,
1807
+ edge_weight_col=edge_weight_col,
1808
+ layout_seed=layout_seed,
979
1809
  axes=axes,
980
1810
  )
981
1811
 
982
1812
  if title:
983
- fig.suptitle(title, color=style.font_color, fontsize=style.font_size * 2.5)
1813
+ fig.suptitle(
1814
+ title,
1815
+ color=style.font_color,
1816
+ fontsize=style.font_size * TITLE_SCALE_FACTOR * 1.25,
1817
+ )
984
1818
 
985
1819
  plt.tight_layout(rect=(0, 0.03, 1, 0.95))
986
1820