MatplotLibAPI 3.2.21__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. MatplotLibAPI/__init__.py +4 -86
  2. MatplotLibAPI/accessor.py +519 -196
  3. MatplotLibAPI/area.py +177 -0
  4. MatplotLibAPI/bar.py +185 -0
  5. MatplotLibAPI/base_plot.py +88 -0
  6. MatplotLibAPI/box_violin.py +180 -0
  7. MatplotLibAPI/bubble.py +568 -0
  8. MatplotLibAPI/{Composite.py → composite.py} +127 -106
  9. MatplotLibAPI/heatmap.py +223 -0
  10. MatplotLibAPI/histogram.py +170 -0
  11. MatplotLibAPI/mcp/__init__.py +17 -0
  12. MatplotLibAPI/mcp/metadata.py +90 -0
  13. MatplotLibAPI/mcp/renderers.py +45 -0
  14. MatplotLibAPI/mcp_server.py +626 -0
  15. MatplotLibAPI/network/__init__.py +28 -0
  16. MatplotLibAPI/network/constants.py +22 -0
  17. MatplotLibAPI/network/core.py +1360 -0
  18. MatplotLibAPI/network/plot.py +597 -0
  19. MatplotLibAPI/network/scaling.py +56 -0
  20. MatplotLibAPI/pie.py +154 -0
  21. MatplotLibAPI/pivot.py +274 -0
  22. MatplotLibAPI/sankey.py +99 -0
  23. MatplotLibAPI/{StyleTemplate.py → style_template.py} +27 -22
  24. MatplotLibAPI/sunburst.py +139 -0
  25. MatplotLibAPI/{Table.py → table.py} +112 -87
  26. MatplotLibAPI/{Timeserie.py → timeserie.py} +98 -42
  27. MatplotLibAPI/{Treemap.py → treemap.py} +43 -55
  28. MatplotLibAPI/typing.py +12 -0
  29. MatplotLibAPI/{_visualization_utils.py → utils.py} +7 -13
  30. MatplotLibAPI/waffle.py +173 -0
  31. MatplotLibAPI/word_cloud.py +489 -0
  32. {matplotlibapi-3.2.21.dist-info → matplotlibapi-4.0.0.dist-info}/METADATA +98 -9
  33. matplotlibapi-4.0.0.dist-info/RECORD +36 -0
  34. {matplotlibapi-3.2.21.dist-info → matplotlibapi-4.0.0.dist-info}/WHEEL +1 -1
  35. matplotlibapi-4.0.0.dist-info/entry_points.txt +2 -0
  36. MatplotLibAPI/Area.py +0 -80
  37. MatplotLibAPI/Bar.py +0 -83
  38. MatplotLibAPI/BoxViolin.py +0 -75
  39. MatplotLibAPI/Bubble.py +0 -458
  40. MatplotLibAPI/Heatmap.py +0 -121
  41. MatplotLibAPI/Histogram.py +0 -73
  42. MatplotLibAPI/Network.py +0 -989
  43. MatplotLibAPI/Pie.py +0 -70
  44. MatplotLibAPI/Pivot.py +0 -134
  45. MatplotLibAPI/Sankey.py +0 -46
  46. MatplotLibAPI/Sunburst.py +0 -89
  47. MatplotLibAPI/Waffle.py +0 -86
  48. MatplotLibAPI/Wordcloud.py +0 -373
  49. MatplotLibAPI/_typing.py +0 -17
  50. matplotlibapi-3.2.21.dist-info/RECORD +0 -26
  51. {matplotlibapi-3.2.21.dist-info → matplotlibapi-4.0.0.dist-info}/licenses/LICENSE +0 -0
MatplotLibAPI/Network.py DELETED
@@ -1,989 +0,0 @@
1
- """Network chart plotting helpers."""
2
-
3
- import logging
4
- from collections import defaultdict
5
- from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
6
-
7
- import matplotlib.pyplot as plt
8
- import networkx as nx
9
- import numpy as np
10
- import pandas as pd
11
- import seaborn as sns
12
- from matplotlib.axes import Axes
13
- from matplotlib.figure import Figure
14
-
15
- from .StyleTemplate import (
16
- NETWORK_STYLE_TEMPLATE,
17
- FIG_SIZE,
18
- StyleTemplate,
19
- format_func,
20
- string_formatter,
21
- validate_dataframe,
22
- )
23
-
# Fallback limits shared by the network helpers: caps on how many edges/nodes
# are kept, node-size and edge-width bounds, a graph scale factor, and the
# font-size range used for node labels.
DEFAULT = dict(
    MAX_EDGES=100,
    MAX_NODES=30,
    MIN_NODE_SIZE=100,
    MAX_NODE_SIZE=2000,
    MAX_EDGE_WIDTH=10,
    GRAPH_SCALE=2,
    MAX_FONT_SIZE=20,
    MIN_FONT_SIZE=8,
)
34
-
35
-
def softmax(x: Iterable[float]) -> np.ndarray:
    """Compute softmax values for ``x``.

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow. The normalisation runs over all elements, so the result
    sums to 1.

    Parameters
    ----------
    x : Iterable[float]
        Input values. Any iterable is accepted, including generators.

    Returns
    -------
    np.ndarray
        Softmax-transformed values. An empty input yields an empty array.
    """
    # np.array(generator) would build a 0-d object array, so materialise
    # non-array iterables first.
    if not isinstance(x, np.ndarray):
        x = list(x)
    x_arr = np.asarray(x, dtype=float)
    if x_arr.size == 0:
        return x_arr
    # Compute the shifted exponentials once instead of twice.
    exps = np.exp(x_arr - x_arr.max())
    return exps / exps.sum()
51
-
52
-
def scale_weights(
    weights: Iterable[float], scale_min: float = 0, scale_max: float = 1
) -> List[float]:
    """Scale weights into deciles within the given range.

    Each weight is placed in one of ten buckets (0..9) by comparing it
    against the 10th..90th percentiles of ``weights``. A value equal to a
    percentile falls into the lower bucket. The bucket index is then mapped
    linearly onto ``[scale_min, scale_max]``.

    Parameters
    ----------
    weights : Iterable[float]
        Sequence of weights to scale. Generators are accepted.
    scale_min : float, optional
        Minimum of the output range. The default is 0.
    scale_max : float, optional
        Maximum of the output range. The default is 1.

    Returns
    -------
    list[float]
        Scaled weights, one per input weight. Empty input yields ``[]``.
    """
    values = np.asarray(
        weights if isinstance(weights, np.ndarray) else list(weights), dtype=float
    )
    if values.size == 0:
        # np.percentile cannot handle an empty sample.
        return []
    deciles = np.percentile(values, np.arange(10, 100, 10))
    buckets = np.searchsorted(deciles, values)
    span = scale_max - scale_min
    return [float(bucket * span / len(deciles) + scale_min) for bucket in buckets]
76
-
77
-
class NodeView(nx.classes.reportviews.NodeView):
    """NetworkX node view with attribute-based ``sort`` and ``filter`` helpers."""

    def sort(self, attribute: str = "weight", reverse: bool = True) -> List[Any]:
        """Order nodes by one of their attributes.

        Nodes that lack ``attribute`` are ranked as if its value were 1.

        Parameters
        ----------
        attribute : str, optional
            Node attribute to rank by. The default is "weight".
        reverse : bool, optional
            ``True`` (the default) puts the largest values first.

        Returns
        -------
        list[Any]
            Node identifiers in ranked order.
        """
        ranked = sorted(
            self.items(),
            key=lambda item: item[1].get(attribute, 1),
            reverse=reverse,
        )
        return [node for node, _ in ranked]

    def filter(self, attribute: str, value: str) -> List[Any]:
        """Select nodes whose ``attribute`` equals ``value``.

        Parameters
        ----------
        attribute : str
            Node attribute to inspect.
        value : str
            Value the attribute must equal.

        Returns
        -------
        list
            Matching node identifiers, in view iteration order.
        """
        return [node for node, attrs in self.items() if attrs.get(attribute) == value]
118
-
119
-
class AdjacencyView(nx.classes.coreviews.AdjacencyView):
    """NetworkX adjacency view with ``sort`` and ``filter`` helpers."""

    # NOTE(review): ``self[node]`` here is the node's neighbour mapping, so
    # ``.get(attribute)`` appears to look up a neighbour *named* ``attribute``
    # rather than a node/edge attribute. Confirm this is the intended
    # semantics before relying on these helpers.

    def sort(self, attribute: str = "weight", reverse: bool = True) -> List[Any]:
        """Order the view's nodes by ``self[node].get(attribute, 1)``.

        Parameters
        ----------
        attribute : str, optional
            Key looked up in each node's adjacency mapping. The default is
            "weight".
        reverse : bool, optional
            ``True`` (the default) puts the largest values first.

        Returns
        -------
        list[Any]
            Node identifiers in ranked order.
        """
        ranked = sorted(
            self.items(),
            key=lambda item: item[1].get(attribute, 1),
            reverse=reverse,
        )
        return [node for node, _ in ranked]

    def filter(self, attribute: str, value: str) -> List[Any]:
        """Select nodes whose ``self[node].get(attribute)`` equals ``value``.

        Parameters
        ----------
        attribute : str
            Key looked up in each node's adjacency mapping.
        value : str
            Value the lookup must equal.

        Returns
        -------
        list
            Matching node identifiers, in view iteration order.
        """
        return [node for node, nbrs in self.items() if nbrs.get(attribute) == value]
160
-
161
-
class EdgeView(nx.classes.reportviews.EdgeView):
    """NetworkX edge view with attribute-based ``sort`` and ``filter`` helpers."""

    def sort(
        self, attribute: str = "weight", reverse: bool = True
    ) -> Dict[Tuple[Any, Any], Dict[str, Any]]:
        """Order edges by one of their attributes.

        Edges that lack ``attribute`` are ranked as if its value were 1.

        Parameters
        ----------
        attribute : str, optional
            Edge attribute to rank by. The default is "weight".
        reverse : bool, optional
            ``True`` (the default) puts the largest values first.

        Returns
        -------
        dict[tuple[Any, Any], dict[str, Any]]
            Insertion-ordered mapping ``(u, v) -> edge attributes``.
        """
        ranked = sorted(
            self(data=True),
            key=lambda triple: triple[2].get(attribute, 1),
            reverse=reverse,
        )
        return {(src, dst): attrs for src, dst, attrs in ranked}

    def filter(self, attribute: str, value: str) -> List[Tuple[Any, Any]]:
        """Select edges whose ``attribute`` equals ``value``.

        Parameters
        ----------
        attribute : str
            Edge attribute to inspect.
        value : str
            Value the attribute must equal.

        Returns
        -------
        list[tuple[Any, Any]]
            Matching ``(u, v)`` pairs, in view iteration order.
        """
        return [(src, dst) for src, dst in self if self[src, dst].get(attribute) == value]
204
-
205
-
class NetworkGraph:
    """Wrapper around a NetworkX graph with weighting, trimming and plotting helpers.

    The wrapped graph is kept in ``_nx_graph``. Node, edge and adjacency
    access goes through the extended ``NodeView``, ``EdgeView`` and
    ``AdjacencyView`` classes, which provide ``sort`` and ``filter`` helpers.
    """

    _nx_graph: nx.Graph

    def __init__(self, nx_graph: nx.Graph):
        """Wrap an existing NetworkX graph.

        Parameters
        ----------
        nx_graph : nx.Graph
            Graph to wrap. It is stored by reference, not copied.
        """
        self._nx_graph = nx_graph
        self._scale = 1.0

    @property
    def scale(self) -> float:
        """Return the scaling factor for plotting sizes (initially 1.0)."""
        return self._scale

    @scale.setter
    def scale(self, value: float):
        """Set the scaling factor for plotting sizes.

        Parameters
        ----------
        value : float
            Scaling factor.
        """
        self._scale = value

    @property
    def nodes(self) -> NodeView:
        """Return a ``NodeView`` over the graph."""
        return NodeView(self._nx_graph)

    @property
    def edges(self) -> EdgeView:
        """Return an ``EdgeView`` over the graph."""
        return EdgeView(self._nx_graph)

    @property
    def adjacency(self) -> AdjacencyView:
        """Return an ``AdjacencyView`` over the graph's adjacency mapping."""
        return AdjacencyView(self._nx_graph.adj)

    @property
    def connected_components(self) -> List[set]:
        """Return the connected components of the graph as node sets."""
        return list(nx.connected_components(self._nx_graph))

    @property
    def number_of_nodes(self) -> int:
        """Return the number of nodes in the graph."""
        return self._nx_graph.number_of_nodes()

    @property
    def number_of_edges(self) -> int:
        """Return the number of edges in the graph."""
        return self._nx_graph.number_of_edges()

    def edge_subgraph(self, edges: Iterable) -> "NetworkGraph":
        """Return a subgraph containing only the specified edges.

        Parameters
        ----------
        edges : Iterable
            ``(u, v)`` edges to include.

        Returns
        -------
        NetworkGraph
            Wrapper around ``nx.edge_subgraph`` of this graph. Per the
            NetworkX docs this is a read-only view, so copy it before mutating.
        """
        return NetworkGraph(nx.edge_subgraph(self._nx_graph, edges))

    def layout(
        self,
        max_node_size: int = DEFAULT["MAX_NODE_SIZE"],
        min_node_size: int = DEFAULT["MIN_NODE_SIZE"],
        max_edge_width: int = DEFAULT["MAX_EDGE_WIDTH"],
        max_font_size: int = DEFAULT["MAX_FONT_SIZE"],
        min_font_size: int = DEFAULT["MIN_FONT_SIZE"],
        weight: str = "weight",
    ) -> Tuple[List[float], List[float], Dict[int, List[str]]]:
        """Calculate node, edge and font sizes based on weights.

        Node and edge weights missing ``weight`` count as 1. All sizes are
        derived with ``scale_weights``, which buckets values into deciles.

        Parameters
        ----------
        max_node_size : int, optional
            Upper bound for node size. The default is `DEFAULT["MAX_NODE_SIZE"]`.
        min_node_size : int, optional
            Lower bound for node size. The default is `DEFAULT["MIN_NODE_SIZE"]`.
        max_edge_width : int, optional
            Upper bound for edge width. The default is `DEFAULT["MAX_EDGE_WIDTH"]`.
        max_font_size : int, optional
            Upper bound for font size. The default is `DEFAULT["MAX_FONT_SIZE"]`.
        min_font_size : int, optional
            Lower bound for font size. The default is `DEFAULT["MIN_FONT_SIZE"]`.
        weight : str, optional
            Node/edge attribute used for weighting. The default is "weight".

        Returns
        -------
        tuple[list[float], list[float], dict[int, list[str]]]
            Node sizes (node iteration order), edge widths (edge iteration
            order) and nodes grouped by integer font size. An empty graph
            yields ``([], [], {})``.
        """
        node_weights = [data.get(weight, 1) for _, data in self.nodes(data=True)]
        edge_weights = [data.get(weight, 1) for _, _, data in self.edges(data=True)]
        if not node_weights:
            # Nothing to size; percentile-based scaling needs a non-empty sample.
            return [], [], {}

        node_size = scale_weights(
            weights=node_weights, scale_max=max_node_size, scale_min=min_node_size
        )
        edges_width = (
            scale_weights(weights=edge_weights, scale_max=max_edge_width)
            if edge_weights
            else []
        )

        # Map each node's weight into the font-size range and group nodes by
        # the resulting integer size, so labels can be drawn per size.
        font_by_node = dict(
            zip(
                self.nodes,
                scale_weights(
                    weights=node_weights,
                    scale_max=max_font_size,
                    scale_min=min_font_size,
                ),
            )
        )
        fonts_size: Dict[int, List[str]] = defaultdict(list)
        for node, font in font_by_node.items():
            fonts_size[int(font)].append(node)

        return node_size, edges_width, dict(fonts_size)

    def subgraph(
        self,
        node_list: Optional[List[str]] = None,
        max_edges: int = DEFAULT["MAX_EDGES"],
        min_degree: int = 2,
        top_k_edges_per_node: int = 5,
        weight: str = "weight",
    ) -> "NetworkGraph":
        """Return a trimmed subgraph limited by nodes and edges.

        Parameters
        ----------
        node_list : list[str], optional
            Nodes to include. When omitted, the `DEFAULT["MAX_NODES"]`
            heaviest nodes (by ``weight``) are used.
        max_edges : int, optional
            Maximum edges to retain. The default is `DEFAULT["MAX_EDGES"]`.
        min_degree : int, optional
            Nodes outside the ``min_degree``-core of this graph are dropped.
            The default is 2.
        top_k_edges_per_node : int, optional
            Number of top edges to keep per node. The default is 5.
        weight : str, optional
            Node/edge attribute used for ranking. The default is "weight".

        Returns
        -------
        NetworkGraph
            Trimmed subgraph.
        """
        if node_list is None:
            node_list = self.nodes.sort(weight)[: DEFAULT["MAX_NODES"]]
        # Set membership keeps the core filter linear in len(node_list).
        core_nodes = set(self.get_core_subgraph(k=min_degree).nodes)
        node_list = [node for node in node_list if node in core_nodes]

        subgraph = NetworkGraph(nx.subgraph(self._nx_graph, nbunch=node_list))
        edges = list(subgraph.top_k_edges(attribute=weight, k=top_k_edges_per_node))
        return subgraph.edge_subgraph(edges[:max_edges])

    def plot_network(
        self,
        title: Optional[str] = None,
        style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
        weight: str = "weight",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Plot the graph using node and edge weights.

        Isolated nodes are dropped from a copy of the graph before drawing;
        the wrapped graph itself is not modified.

        Parameters
        ----------
        title : str, optional
            Plot title.
        style : StyleTemplate, optional
            Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
        weight : str, optional
            Attribute used for weighting. The default is "weight".
        ax : Axes, optional
            Axes to draw on. Defaults to the current axes.

        Returns
        -------
        Axes
            Matplotlib axes with the plotted network.
        """
        sns.set_palette(style.palette)
        if ax is None:
            ax = cast(Axes, plt.gca())

        graph = self
        isolated_nodes = list(nx.isolates(self._nx_graph))
        if isolated_nodes:
            # Prune display-only isolates on a copy so the caller's graph is
            # left unchanged after plotting.
            pruned = self._nx_graph.copy()
            pruned.remove_nodes_from(isolated_nodes)
            graph = NetworkGraph(pruned)

        if graph.number_of_nodes == 0:
            ax.set_axis_off()
            if title:
                ax.set_title(
                    title, color=style.font_color, fontsize=style.font_size * 2
                )
            return ax

        node_sizes, edge_widths, font_sizes = graph.layout(
            min_node_size=DEFAULT["MIN_NODE_SIZE"] // 5,
            max_node_size=DEFAULT["MAX_NODE_SIZE"],
            max_edge_width=DEFAULT["MAX_EDGE_WIDTH"],
            min_font_size=style.font_mapping.get(0, DEFAULT["MIN_FONT_SIZE"]),
            max_font_size=style.font_mapping.get(4, DEFAULT["MAX_FONT_SIZE"]),
            weight=weight,
        )
        pos = nx.spring_layout(graph._nx_graph, k=1)

        nx.draw_networkx_nodes(
            graph._nx_graph,
            pos,
            ax=ax,
            node_size=cast(Any, [int(size) for size in node_sizes]),
            node_color=cast(Any, node_sizes),
            cmap=plt.get_cmap(style.palette),
        )
        nx.draw_networkx_edges(
            graph._nx_graph,
            pos,
            ax=ax,
            edge_color=style.font_color,
            edge_cmap=plt.get_cmap(style.palette),
            width=cast(Any, edge_widths),
        )
        # Labels are drawn once per font-size group.
        for font_size, nodes in font_sizes.items():
            nx.draw_networkx_labels(
                graph._nx_graph,
                pos,
                ax=ax,
                font_size=font_size,
                font_color=style.font_color,
                labels={n: string_formatter(n) for n in nodes},
            )
        ax.set_facecolor(style.background_color)
        if title:
            ax.set_title(title, color=style.font_color, fontsize=style.font_size * 2)
        ax.set_axis_off()

        return ax

    def plot_network_components(self, *args: Any, **kwargs: Any) -> List:
        """Plot network components.

        .. deprecated:: 0.1.0
            `plot_network_components` will be removed in a future version.
            Use `fplot_network_components` instead. This method only emits a
            ``DeprecationWarning`` and returns an empty list.
        """
        import warnings

        warnings.warn(
            "`plot_network_components` is deprecated and will be removed in a future version. "
            "Please use `fplot_network_components`.",
            DeprecationWarning,
            stacklevel=2,
        )
        return []

    def get_core_subgraph(self, k: int = 2) -> "NetworkGraph":
        """Return the k-core of the graph.

        The k-core is the maximal subgraph in which every node has degree
        >= k within that subgraph.

        Parameters
        ----------
        k : int, optional
            The minimum degree for nodes in the core. The default is 2.

        Returns
        -------
        NetworkGraph
            The k-core subgraph.
        """
        return NetworkGraph(nx.k_core(self._nx_graph, k=k))

    def top_k_edges(
        self, attribute: str, reverse: bool = True, k: int = 5
    ) -> Dict[Tuple[Any, Any], Any]:
        """Return each node's top ``k`` edges ranked by ``attribute``.

        Edges lacking ``attribute`` are ranked and reported with value 0.
        Previously they were ranked with 0 but then raised ``KeyError``. Keys
        are ``(node, neighbour)`` from the perspective of the node being
        ranked. An undirected edge can therefore appear twice, once in each
        orientation, when it is in the top ``k`` of both endpoints.

        Parameters
        ----------
        attribute : str
            Edge attribute name used for sorting.
        reverse : bool, optional
            Whether to sort in descending order. The default is `True`.
        k : int, optional
            Number of top edges to keep per node. The default is 5.

        Returns
        -------
        dict[tuple[Any, Any], Any]
            Mapping of ``(u, v)`` edge tuples to their ``attribute`` values.
        """

        def edge_value(edge: Tuple[Any, Any, Dict]) -> Any:
            return edge[2].get(attribute, 0)

        top_list: Dict[Tuple[Any, Any], Any] = {}
        for node in self.nodes:
            ranked = sorted(self.edges(node, data=True), key=edge_value, reverse=reverse)
            for u, v, data in ranked[:k]:
                top_list[(u, v)] = data.get(attribute, 0)
        return top_list

    def calculate_node_weights_from_edges(self, weight: str = "weight", k: int = 10):
        """Set node weights by summing the weights of top-k edges.

        The result is written in place onto the wrapped graph's nodes under
        the attribute name ``weight``.

        Parameters
        ----------
        weight : str, optional
            Edge attribute to use for weighting, and the node attribute that
            receives the sums. The default is "weight".
        k : int, optional
            Number of top edges to consider for each node. The default is 10.
        """
        # NOTE(review): every key returned by top_k_edges adds to both
        # endpoints. An edge in the top-k of both endpoints appears in both
        # orientations and is counted twice; confirm this weighting is intended.
        node_aggregates: Dict[Any, Any] = defaultdict(int)
        for (u, v), weight_value in self.top_k_edges(attribute=weight, k=k).items():
            node_aggregates[u] += weight_value
            node_aggregates[v] += weight_value

        nx.set_node_attributes(self._nx_graph, dict(node_aggregates), name=weight)

    def trim_edges(
        self, weight: str = "weight", top_k_per_node: int = 5
    ) -> "NetworkGraph":
        """Trim the graph to keep only the top k edges per node.

        Parameters
        ----------
        weight : str, optional
            Edge attribute to use for sorting. The default is "weight".
        top_k_per_node : int, optional
            Number of top edges to keep per node. The default is 5.

        Returns
        -------
        NetworkGraph
            An edge subgraph containing only the retained edges.
        """
        edges_to_keep = self.top_k_edges(attribute=weight, k=top_k_per_node)
        return self.edge_subgraph(edges=list(edges_to_keep))

    @staticmethod
    def from_pandas_edgelist(
        df: pd.DataFrame,
        source: str = "source",
        target: str = "target",
        weight: str = "weight",
    ) -> "NetworkGraph":
        """Build a NetworkGraph from an edge-list DataFrame.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame containing edge data.
        source : str, optional
            Column name for source nodes. The default is "source".
        target : str, optional
            Column name for target nodes. The default is "target".
        weight : str, optional
            Column copied onto each edge as an attribute of the same name.
            The default is "weight".

        Returns
        -------
        NetworkGraph
            Initialized network graph.
        """
        network_G = nx.from_pandas_edgelist(
            df, source=source, target=target, edge_attr=weight
        )
        return NetworkGraph(network_G)
614
-
615
-
def compute_network_grid(
    connected_components: List[set],
    style: StyleTemplate,
    figsize: Tuple[float, float] = FIG_SIZE,
) -> Tuple[Figure, np.ndarray]:
    """Compute the grid layout for network component subplots.

    The grid is roughly square: ``ceil(sqrt(n))`` columns and enough rows to
    hold ``n`` components. An empty component list still produces a single
    subplot instead of raising ``ZeroDivisionError``.

    Parameters
    ----------
    connected_components : list[set]
        A list of sets, where each set contains the nodes of a connected component.
    style : StyleTemplate
        The style template used for plotting.
    figsize : tuple[float, float], optional
        Size of the created figure. The default is FIG_SIZE.

    Returns
    -------
    Tuple[Figure, np.ndarray]
        The Matplotlib figure and a flat 1-D array of its axes.
    """
    n_components = max(len(connected_components), 1)
    n_cols = int(np.ceil(np.sqrt(n_components)))
    n_rows = int(np.ceil(n_components / n_cols))
    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize)
    fig = cast(Figure, fig)
    fig.patch.set_facecolor(style.background_color)
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to an array.
    if isinstance(axes_grid, np.ndarray):
        axes = axes_grid.flatten()
    else:
        axes = np.array([axes_grid])
    return fig, axes
644
-
645
-
def prepare_network_graph(
    pd_df: pd.DataFrame,
    source: str,
    target: str,
    weight: str,
    sort_by: Optional[str],
    node_list: Optional[List],
) -> NetworkGraph:
    """Prepare a NetworkGraph for plotting from a pandas DataFrame.

    This function takes a DataFrame and prepares it for network visualization by:
    1. Filtering the DataFrame to include only edges touching `node_list` (if provided).
    2. Validating the DataFrame to ensure it has the required columns.
    3. Creating a `NetworkGraph` from the edge list.
    4. Extracting the k-core of the graph (k=2) to focus on the main structure.
    5. Calculating node weights based on the sum of their top k (k=10) edge weights.
    6. Trimming the graph to keep only the top k (k=5) edges per node.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing the edge list.
    source : str
        Column name for source nodes.
    target : str
        Column name for target nodes.
    weight : str
        Column name for edge weights.
    sort_by : str, optional
        Passed through to ``validate_dataframe``; it is not used for sorting here.
    node_list : list, optional
        A list of nodes to include in the graph. If non-empty, the DataFrame
        is filtered to edges whose `source` or `target` column is in it.

    Returns
    -------
    NetworkGraph
        The prepared `NetworkGraph` object.
    """
    if node_list:
        # Use the caller-supplied column names; hard-coding "source"/"target"
        # broke (KeyError) or mis-filtered any custom edge-list schema.
        mask = pd_df[source].isin(node_list) | pd_df[target].isin(node_list)
        df = pd_df.loc[mask]
    else:
        df = pd_df
    validate_dataframe(df, cols=[source, target, weight], sort_by=sort_by)

    graph = NetworkGraph.from_pandas_edgelist(
        df, source=source, target=target, weight=weight
    )
    graph = graph.get_core_subgraph(k=2)
    graph.calculate_node_weights_from_edges(weight=weight, k=10)
    return graph.trim_edges(weight=weight, top_k_per_node=5)
700
-
701
-
def aplot_network(
    pd_df: pd.DataFrame,
    source: str = "source",
    target: str = "target",
    weight: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    node_list: Optional[List] = None,
    ax: Optional[Axes] = None,
) -> Axes:
    """Draw a network built from an edge-list DataFrame onto an axes.

    The DataFrame is turned into a trimmed graph by ``prepare_network_graph``
    and then rendered with ``NetworkGraph.plot_network``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Edge-list data.
    source : str, optional
        Source-node column. The default is "source".
    target : str, optional
        Target-node column. The default is "target".
    weight : str, optional
        Edge-weight column. The default is "weight".
    title : str, optional
        Plot title.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    sort_by : str, optional
        Forwarded to ``prepare_network_graph``.
    ascending : bool, optional
        Accepted for API symmetry; not used by this function. The default
        is `False`.
    node_list : list, optional
        Nodes whose edges should be kept.
    ax : Axes, optional
        Target axes; ``plot_network`` falls back to the current axes.

    Returns
    -------
    Axes
        The axes holding the network drawing.
    """
    prepared = prepare_network_graph(
        pd_df=pd_df,
        source=source,
        target=target,
        weight=weight,
        sort_by=sort_by,
        node_list=node_list,
    )
    return prepared.plot_network(title=title, style=style, weight=weight, ax=ax)
746
-
747
-
def aplot_network_components(
    pd_df: pd.DataFrame,
    axes: Optional[np.ndarray],
    source: str = "source",
    target: str = "target",
    weight: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    sort_by: Optional[str] = None,
    node_list: Optional[List] = None,
    ascending: bool = False,
) -> None:
    """Plot network components separately on multiple axes.

    Each connected component with more than five nodes is drawn on its own
    axes. Axes for smaller components, and any leftover axes, are switched
    off. Components beyond the number of available axes are skipped and a
    warning is logged.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    axes : np.ndarray, optional
        Axes to draw on, in any shape (it is flattened). When ``None``, a
        grid is created with ``compute_network_grid``.
    source : str, optional
        Column name for source nodes. The default is "source".
    target : str, optional
        Column name for target nodes. The default is "target".
    weight : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Base title for subplots; each subplot is titled ``"{title}::{i}"``,
        or just ``"{i}"`` when no title is given.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    sort_by : str, optional
        Forwarded to ``prepare_network_graph``.
    node_list : list, optional
        Nodes to include.
    ascending : bool, optional
        Accepted for API symmetry; not used by this function. The default
        is `False`.
    """
    graph = prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)

    isolated_nodes = list(nx.isolates(graph._nx_graph))
    if isolated_nodes:
        # The prepared graph comes from nx.edge_subgraph and may be a
        # read-only view, so drop isolates on a copy.
        graph = NetworkGraph(graph._nx_graph.copy())
        graph._nx_graph.remove_nodes_from(isolated_nodes)
    connected_components = list(nx.connected_components(graph._nx_graph))

    if not connected_components:
        if axes is not None:
            for ax in np.asarray(axes).ravel():
                ax.set_axis_off()
        return

    local_axes = axes
    if local_axes is None:
        _, local_axes = compute_network_grid(connected_components, style)
    # Accept any grid shape; indexing a 2-D grid would otherwise yield rows.
    flat_axes = np.asarray(local_axes).ravel()

    n_used = min(len(connected_components), len(flat_axes))
    if n_used < len(connected_components):
        logging.getLogger(__name__).warning(
            "Only %d axes available for %d network components; extra components are not plotted.",
            len(flat_axes),
            len(connected_components),
        )

    for i, component in enumerate(connected_components[:n_used]):
        ax = flat_axes[i]
        if len(component) > 5:
            component_graph = graph.subgraph(node_list=list(component))
            component_graph.plot_network(
                title=f"{title}::{i}" if title else str(i),
                style=style,
                weight=weight,
                ax=ax,
            )
            ax.set_axis_on()
        else:
            # Small components are not drawn; hide the empty frame.
            ax.set_axis_off()

    for ax in flat_axes[n_used:]:
        ax.set_axis_off()
830
-
831
-
def fplot_network(
    pd_df: pd.DataFrame,
    source: str = "source",
    target: str = "target",
    weight: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    node_list: Optional[List] = None,
    figsize: Tuple[float, float] = FIG_SIZE,
    save_path: Optional[str] = None,
    savefig_kwargs: Optional[Dict[str, Any]] = None,
) -> Figure:
    """Create a new figure containing a single network plot.

    Parameters
    ----------
    pd_df : pd.DataFrame
        Edge-list data.
    source : str, optional
        Source-node column. The default is "source".
    target : str, optional
        Target-node column. The default is "target".
    weight : str, optional
        Edge-weight column. The default is "weight".
    title : str, optional
        Plot title.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    sort_by : str, optional
        Forwarded to ``aplot_network``.
    ascending : bool, optional
        Forwarded to ``aplot_network``. The default is `False`.
    node_list : list, optional
        Nodes whose edges should be kept.
    figsize : tuple[float, float], optional
        Figure size. The default is FIG_SIZE.
    save_path : str, optional
        When given, the figure is also written to this path.
    savefig_kwargs : dict[str, Any], optional
        Extra keyword arguments for ``Figure.savefig``.

    Returns
    -------
    Figure
        The newly created figure.
    """
    figure = cast(Figure, plt.figure(figsize=figsize))
    figure.patch.set_facecolor(style.background_color)
    target_ax = figure.add_subplot()
    aplot_network(
        pd_df,
        source=source,
        target=target,
        weight=weight,
        title=title,
        style=style,
        sort_by=sort_by,
        ascending=ascending,
        node_list=node_list,
        ax=target_ax,
    )
    if save_path:
        figure.savefig(save_path, **(savefig_kwargs or {}))
    return figure
894
-
895
-
def fplot_network_components(
    pd_df: pd.DataFrame,
    source: str = "source",
    target: str = "target",
    weight: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    sort_by: Optional[str] = None,
    ascending: bool = False,
    node_list: Optional[List] = None,
    figsize: Tuple[float, float] = FIG_SIZE,
    n_cols: Optional[int] = None,
    save_path: Optional[str] = None,
    savefig_kwargs: Optional[Dict[str, Any]] = None,
) -> Figure:
    """Return a figure showing individual network components.

    The subplot grid is sized from the same prepared graph that
    ``aplot_network_components`` plots: k-core, then edge trimming, then
    isolate removal. This gives one axes per plotted component. Sizing from
    the untrimmed graph could under-count components, because trimming can
    split them.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    source : str, optional
        Column name for source nodes. The default is "source".
    target : str, optional
        Column name for target nodes. The default is "target".
    weight : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Figure super-title, also used as the base of subplot titles.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    sort_by : str, optional
        Forwarded to the graph-preparation step.
    ascending : bool, optional
        Forwarded to ``aplot_network_components``. The default is `False`.
    node_list : list, optional
        Nodes to include.
    figsize : tuple[float, float], optional
        Size of the created figure. The default is FIG_SIZE.
    n_cols : int, optional
        Number of subplot columns. When None it is ``ceil(sqrt(n))``; values
        below 1 are treated as 1.
    save_path : str, optional
        When given, the figure is also written to this path.
    savefig_kwargs : dict[str, Any], optional
        Extra keyword arguments for ``Figure.savefig``.

    Returns
    -------
    Figure
        Matplotlib figure displaying component plots.
    """
    graph = prepare_network_graph(pd_df, source, target, weight, sort_by, node_list)
    # Each isolate is its own component; aplot_network_components drops them.
    n_components = nx.number_connected_components(
        graph._nx_graph
    ) - nx.number_of_isolates(graph._nx_graph)
    # Always create at least one subplot so an empty result cannot divide by zero.
    n_components = max(n_components, 1)

    if n_cols is None:
        n_cols = int(np.ceil(np.sqrt(n_components)))
    n_cols = max(1, int(n_cols))
    n_rows = int(np.ceil(n_components / n_cols))

    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize)
    fig = cast(Figure, fig)
    fig.patch.set_facecolor(style.background_color)

    if isinstance(axes_grid, np.ndarray):
        axes = axes_grid.flatten()
    else:
        axes = np.array([axes_grid])

    aplot_network_components(
        pd_df=pd_df,
        source=source,
        target=target,
        weight=weight,
        title=title,
        style=style,
        sort_by=sort_by,
        ascending=ascending,
        node_list=node_list,
        axes=axes,
    )

    if title:
        fig.suptitle(title, color=style.font_color, fontsize=style.font_size * 2.5)

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout(rect=(0, 0.03, 1, 0.95))

    if save_path:
        fig.savefig(save_path, **(savefig_kwargs or {}))
    return fig