MatplotLibAPI 3.3.0__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. MatplotLibAPI/__init__.py +4 -86
  2. MatplotLibAPI/accessor.py +279 -182
  3. MatplotLibAPI/area.py +177 -0
  4. MatplotLibAPI/bar.py +185 -0
  5. MatplotLibAPI/base_plot.py +88 -0
  6. MatplotLibAPI/box_violin.py +180 -0
  7. MatplotLibAPI/bubble.py +568 -0
  8. MatplotLibAPI/{Composite.py → composite.py} +70 -83
  9. MatplotLibAPI/heatmap.py +223 -0
  10. MatplotLibAPI/histogram.py +170 -0
  11. MatplotLibAPI/mcp/__init__.py +17 -0
  12. MatplotLibAPI/mcp/metadata.py +90 -0
  13. MatplotLibAPI/mcp/renderers.py +45 -0
  14. MatplotLibAPI/mcp_server.py +626 -0
  15. MatplotLibAPI/network/__init__.py +28 -0
  16. MatplotLibAPI/network/constants.py +22 -0
  17. MatplotLibAPI/{Network.py → network/core.py} +346 -809
  18. MatplotLibAPI/network/plot.py +597 -0
  19. MatplotLibAPI/network/scaling.py +56 -0
  20. MatplotLibAPI/pie.py +154 -0
  21. MatplotLibAPI/pivot.py +274 -0
  22. MatplotLibAPI/sankey.py +99 -0
  23. MatplotLibAPI/{StyleTemplate.py → style_template.py} +8 -4
  24. MatplotLibAPI/sunburst.py +139 -0
  25. MatplotLibAPI/{Table.py → table.py} +108 -93
  26. MatplotLibAPI/{Timeserie.py → timeserie.py} +98 -42
  27. MatplotLibAPI/{Treemap.py → treemap.py} +43 -55
  28. MatplotLibAPI/typing.py +12 -0
  29. MatplotLibAPI/{_visualization_utils.py → utils.py} +7 -13
  30. MatplotLibAPI/waffle.py +173 -0
  31. MatplotLibAPI/{Wordcloud.py → word_cloud.py} +187 -88
  32. {matplotlibapi-3.3.0.dist-info → matplotlibapi-4.0.0.dist-info}/METADATA +98 -9
  33. matplotlibapi-4.0.0.dist-info/RECORD +36 -0
  34. {matplotlibapi-3.3.0.dist-info → matplotlibapi-4.0.0.dist-info}/WHEEL +1 -1
  35. matplotlibapi-4.0.0.dist-info/entry_points.txt +2 -0
  36. MatplotLibAPI/Area.py +0 -80
  37. MatplotLibAPI/Bar.py +0 -83
  38. MatplotLibAPI/BoxViolin.py +0 -75
  39. MatplotLibAPI/Bubble.py +0 -460
  40. MatplotLibAPI/Heatmap.py +0 -121
  41. MatplotLibAPI/Histogram.py +0 -73
  42. MatplotLibAPI/Pie.py +0 -70
  43. MatplotLibAPI/Pivot.py +0 -134
  44. MatplotLibAPI/Sankey.py +0 -46
  45. MatplotLibAPI/Sunburst.py +0 -89
  46. MatplotLibAPI/Waffle.py +0 -86
  47. MatplotLibAPI/_typing.py +0 -17
  48. matplotlibapi-3.3.0.dist-info/RECORD +0 -26
  49. {matplotlibapi-3.3.0.dist-info → matplotlibapi-4.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,597 @@
1
+ """Network chart plotting helpers."""
2
+
3
+ from typing import Any, Optional, Tuple, cast
4
+ import matplotlib.pyplot as plt
5
+ import networkx as nx
6
+ import numpy as np
7
+ import pandas as pd
8
+ from matplotlib.axes import Axes
9
+ from matplotlib.figure import Figure
10
+
11
+ from .core import NetworkGraph
12
+ from .constants import _DEFAULT
13
+ from ..style_template import (
14
+ NETWORK_STYLE_TEMPLATE,
15
+ FIG_SIZE,
16
+ TITLE_SCALE_FACTOR,
17
+ StyleTemplate,
18
+ string_formatter,
19
+ validate_dataframe,
20
+ )
21
+
22
+ __all__ = [
23
+ "aplot_network",
24
+ "aplot_network_node",
25
+ "aplot_network_components",
26
+ "fplot_network",
27
+ "fplot_network_node",
28
+ "fplot_network_components",
29
+ ]
30
+
31
+
32
def _sanitize_node_dataframe(
    node_df: Optional[pd.DataFrame],
    edge_df: pd.DataFrame,
    node_col: str = "node",
    node_weight_col: str = "weight",
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
) -> Optional[pd.DataFrame]:
    """Return the rows of ``node_df`` whose node appears in the edge list.

    Private helper intended for internal use when preparing plotting data.

    Parameters
    ----------
    node_df : pd.DataFrame, optional
        DataFrame containing the ``node_col`` and ``node_weight_col`` columns.
        When ``None`` nothing is validated and ``None`` is returned.
    edge_df : pd.DataFrame
        Edge DataFrame containing the source, target and weight columns.
    node_col : str, optional
        Column name for node identifiers. The default is "node".
    node_weight_col : str, optional
        Column name for node weights. The default is "weight".
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight". Only used
        for validating ``edge_df``; it does not affect the filtering.

    Returns
    -------
    pd.DataFrame or None
        ``None`` when ``node_df`` is ``None``. Otherwise a new DataFrame
        holding only the rows of ``node_df`` whose ``node_col`` value occurs
        in the source or target column of ``edge_df``. The result may be
        empty; ``node_df`` itself is never modified.
    """
    if node_df is None:
        return None

    validate_dataframe(node_df, cols=[node_col, node_weight_col])
    validate_dataframe(
        edge_df, cols=[edge_source_col, edge_target_col, edge_weight_col]
    )

    # Every identifier that is an endpoint of at least one edge.
    nodes_in_edges = set(edge_df[edge_source_col]).union(edge_df[edge_target_col])

    # Filter first and copy afterwards: only the surviving rows are
    # duplicated, and the caller receives an independent frame.
    return node_df.loc[node_df[node_col].isin(nodes_in_edges)].copy()
78
+
79
+
80
def _prepare_network_graph(
    pd_df: pd.DataFrame,
    node_col: str = "node",
    node_weight_col: str = "weight",
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    sort_by: Optional[str] = None,
    node_df: Optional[pd.DataFrame] = None,
) -> NetworkGraph:
    """Build a ``NetworkGraph`` from an edge list, optionally limited to ``node_df``.

    The function performs these steps:

    1. When ``node_df`` is given, keep only its rows whose node occurs in the
       edge list, then keep only the edges whose source *and* target are both
       among those nodes.
    2. Validate that the (possibly filtered) edge frame has the source,
       target and weight columns.
    3. Create the graph with ``NetworkGraph.from_pandas_edgelist``.
    4. Assign node weights: the values from ``node_df`` when it is given,
       otherwise ``graph.calculate_nodes(edge_weight_col=..., k=10)``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing the edge list.
    node_col : str, optional
        Column name for node identifiers in ``node_df``. The default is "node".
    node_weight_col : str, optional
        Column name for node weights in ``node_df``. The default is "weight".
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    sort_by : str, optional
        Forwarded to ``validate_dataframe`` for the edge frame.
    node_df : pd.DataFrame, optional
        DataFrame containing node identifiers and weights. If provided, only
        edges with both endpoints listed here are kept, and the provided
        weights are used instead of calculated values.

    Returns
    -------
    NetworkGraph
        The prepared ``NetworkGraph`` object.

    Raises
    ------
    ValueError
        If ``node_df`` is provided but none of its nodes appear as sources or
        targets in ``pd_df``.
    """
    filtered_node_df = _sanitize_node_dataframe(
        node_df,
        pd_df,
        node_col=node_col,
        node_weight_col=node_weight_col,
        edge_source_col=edge_source_col,
        edge_target_col=edge_target_col,
        edge_weight_col=edge_weight_col,
    )

    if node_df is None:
        df = pd_df
    else:
        if filtered_node_df is None or filtered_node_df.empty:
            raise ValueError(
                "node_df must include at least one node present as a source or target."
            )
        allowed_nodes = filtered_node_df[node_col].tolist()
        # An edge survives only when BOTH endpoints are allowed nodes.
        df = pd_df.loc[
            pd_df[edge_source_col].isin(allowed_nodes)
            & pd_df[edge_target_col].isin(allowed_nodes)
        ]

    # NOTE(review): the return value is discarded, so ``sort_by`` only has an
    # effect if ``validate_dataframe`` sorts in place - confirm.
    validate_dataframe(
        df, cols=[edge_source_col, edge_target_col, edge_weight_col], sort_by=sort_by
    )

    graph = NetworkGraph.from_pandas_edgelist(
        df,
        source=edge_source_col,
        target=edge_target_col,
        edge_weight_col=edge_weight_col,
    )

    # ``filtered_node_df`` is None exactly when no ``node_df`` was supplied;
    # the empty case has already raised above.
    if filtered_node_df is None:
        graph.calculate_nodes(edge_weight_col=edge_weight_col, k=10)
        return graph

    graph_nodes = graph._nx_graph.nodes
    node_weights = {
        node: weight_value
        for node, weight_value in zip(
            filtered_node_df[node_col], filtered_node_df[node_weight_col]
        )
        if node in graph_nodes
    }
    # NOTE(review): node weights are stored under ``edge_weight_col``, the
    # same name handed to ``calculate_nodes`` in the other branch; presumably
    # the plotting code reads node weights from that attribute - confirm.
    nx.set_node_attributes(graph._nx_graph, node_weights, name=edge_weight_col)
    return graph
177
+
178
+
179
def aplot_network(
    pd_df: pd.DataFrame,
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
    ax: Optional[Axes] = None,
) -> Axes:
    """Draw the full edge list as a network on the given axes.

    A ``NetworkGraph`` is built from ``pd_df`` and its ``aplot`` method does
    the drawing.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Plot title.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    layout_seed : int, optional
        Seed for the spring layout used to place nodes. The default is
        ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
    ax : Axes, optional
        Axes to draw on; forwarded unchanged to ``NetworkGraph.aplot``.

    Returns
    -------
    Axes
        Matplotlib axes with the plotted network.
    """
    network = NetworkGraph(
        pd_df=pd_df,
        source=edge_source_col,
        target=edge_target_col,
        weight=edge_weight_col,
    )
    plot_options = dict(
        title=title,
        style=style,
        edge_weight_col=edge_weight_col,
        layout_seed=layout_seed,
        ax=ax,
    )
    return network.aplot(**plot_options)
227
+
228
+
229
def aplot_network_node(
    pd_df: pd.DataFrame,
    node: Any,
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    ax: Optional[Axes] = None,
    layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
) -> Axes:
    """Plot the connected component containing ``node`` on the provided axes.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    node : Any
        Node identifier whose component should be visualized.
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Plot title. If ``None``, defaults to ``string_formatter(node)``.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    ax : Axes, optional
        Axes to draw on; forwarded unchanged to ``NetworkGraph.aplot``.
    layout_seed : int, optional
        Seed for the spring layout used to place nodes. The default is
        ``_DEFAULT["SPRING_LAYOUT_SEED"]``.

    Returns
    -------
    Axes
        Matplotlib axes with the plotted component.

    Raises
    ------
    ValueError
        If ``node`` is not present in the graph. Presumably raised by
        ``NetworkGraph.subgraph_component``; this function performs no
        check of its own.
    """
    graph = NetworkGraph(
        pd_df=pd_df,
        source=edge_source_col,
        target=edge_target_col,
        weight=edge_weight_col,
    )
    component_graph = graph.subgraph_component(node)

    # An explicit title (even an empty string) wins over the node label.
    resolved_title = string_formatter(node) if title is None else title
    return component_graph.aplot(
        title=resolved_title,
        style=style,
        edge_weight_col=edge_weight_col,
        ax=ax,
        layout_seed=layout_seed,
    )
298
+
299
+
300
def aplot_network_components(
    pd_df: pd.DataFrame,
    node_col: str = "node",
    node_weight_col: str = "weight",
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    sort_by: Optional[str] = None,
    ascending: bool = False,
    node_df: Optional[pd.DataFrame] = None,
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
    axes: Optional[np.ndarray] = None,
) -> None:
    """Plot network components separately on multiple axes.

    The graph is built with ``_prepare_network_graph`` and drawn with
    ``NetworkGraph.aplot_connected_components``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    node_col : str, optional
        Column name for node identifiers in ``node_df``. The default is "node".
    node_weight_col : str, optional
        Column name for node weights in ``node_df``. The default is "weight".
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    sort_by : str, optional
        Column used to sort the data; forwarded to ``_prepare_network_graph``.
    ascending : bool, optional
        Sort order for the data. The default is `False`. Currently unused:
        it is accepted for signature compatibility but never forwarded.
    node_df : pd.DataFrame, optional
        DataFrame containing node identifiers and weights to include.
    title : str, optional
        Base title for subplots.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    layout_seed : int, optional
        Seed for the spring layout used to place nodes. The default is
        ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
    axes : np.ndarray, optional
        Existing axes to draw on; forwarded unchanged to
        ``NetworkGraph.aplot_connected_components``.

    Raises
    ------
    ValueError
        If ``node_df`` is provided but none of its nodes appear as sources or
        targets in ``pd_df``.
    """
    # NOTE(review): ``ascending`` has no effect because
    # ``_prepare_network_graph`` takes no such argument.
    graph = _prepare_network_graph(
        pd_df,
        node_col=node_col,
        node_weight_col=node_weight_col,
        edge_source_col=edge_source_col,
        edge_target_col=edge_target_col,
        edge_weight_col=edge_weight_col,
        sort_by=sort_by,
        node_df=node_df,
    )
    graph.aplot_connected_components(
        title=title,
        style=style,
        edge_weight_col=edge_weight_col,
        layout_seed=layout_seed,
        axes=axes,
    )
363
+
364
+
365
def fplot_network(
    pd_df: pd.DataFrame,
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
    figsize: Tuple[float, float] = FIG_SIZE,
) -> Figure:
    """Return a new figure with a network graph.

    A figure with a single subplot is created and the network is drawn on it
    through ``aplot_network``, so ``edge_weight_col`` and ``layout_seed`` are
    honoured in the same way as in the axes-level function.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Plot title.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    layout_seed : int, optional
        Seed for the spring layout used to place nodes. The default is
        ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
    figsize : tuple[float, float], optional
        Size of the created figure. The default is FIG_SIZE.

    Returns
    -------
    Figure
        Matplotlib figure with the network graph.
    """
    fig = cast(Figure, plt.figure(figsize=figsize))
    fig.set_facecolor(style.background_color)
    try:
        aplot_network(
            pd_df,
            edge_source_col=edge_source_col,
            edge_target_col=edge_target_col,
            edge_weight_col=edge_weight_col,
            title=title,
            style=style,
            layout_seed=layout_seed,
            ax=fig.add_subplot(),
        )
    except Exception:
        # pyplot keeps a reference to every figure it creates; release the
        # half-built one before propagating the error.
        plt.close(fig)
        raise
    return fig
421
+
422
+
423
def fplot_network_node(
    pd_df: pd.DataFrame,
    node: Any,
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    figsize: Tuple[float, float] = FIG_SIZE,
    layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
) -> Figure:
    """Return a new figure with the component containing ``node``.

    A figure with a single subplot is created and the drawing is delegated
    to ``aplot_network_node``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    node : Any
        Node identifier whose component should be visualized.
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Plot title. If ``None``, ``aplot_network_node`` derives it from the
        node identifier.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    figsize : tuple[float, float], optional
        Size of the created figure. The default is FIG_SIZE.
    layout_seed : int, optional
        Seed for the spring layout used to place nodes. The default is
        ``_DEFAULT["SPRING_LAYOUT_SEED"]``.

    Returns
    -------
    Figure
        Matplotlib figure with the component plot.

    Raises
    ------
    ValueError
        If ``node`` is not present in the graph (propagated from
        ``aplot_network_node``). The figure created here is closed before
        the error is re-raised.
    """
    fig = cast(Figure, plt.figure(figsize=figsize))
    fig.set_facecolor(style.background_color)
    try:
        aplot_network_node(
            pd_df,
            node=node,
            edge_source_col=edge_source_col,
            edge_target_col=edge_target_col,
            edge_weight_col=edge_weight_col,
            title=title,
            style=style,
            ax=fig.add_subplot(),
            layout_seed=layout_seed,
        )
    except Exception:
        # pyplot keeps a reference to every figure it creates; release the
        # half-built one before propagating the error.
        plt.close(fig)
        raise
    return fig
496
+
497
+
498
def fplot_network_components(
    pd_df: pd.DataFrame,
    edge_source_col: str = "source",
    edge_target_col: str = "target",
    edge_weight_col: str = "weight",
    title: Optional[str] = None,
    style: StyleTemplate = NETWORK_STYLE_TEMPLATE,
    layout_seed: Optional[int] = _DEFAULT["SPRING_LAYOUT_SEED"],
    figsize: Tuple[float, float] = FIG_SIZE,
    n_cols: Optional[int] = None,
) -> Figure:
    """Return a figure showing individual network components.

    Isolated nodes are removed from the graph, a subplot grid with at least
    one cell per connected component is created, and the drawing is
    delegated to ``NetworkGraph.aplot_connected_components``.

    Parameters
    ----------
    pd_df : pd.DataFrame
        DataFrame containing edge data.
    edge_source_col : str, optional
        Column name for source nodes. The default is "source".
    edge_target_col : str, optional
        Column name for target nodes. The default is "target".
    edge_weight_col : str, optional
        Column name for edge weights. The default is "weight".
    title : str, optional
        Plot title. It is passed to the component plots and, when non-empty,
        also used as the figure-level super title.
    style : StyleTemplate, optional
        Style configuration. The default is `NETWORK_STYLE_TEMPLATE`.
    layout_seed : int, optional
        Seed for the spring layout used to place nodes. The default is
        ``_DEFAULT["SPRING_LAYOUT_SEED"]``.
    figsize : tuple[float, float], optional
        Size of the created figure. The default is FIG_SIZE.
    n_cols : int, optional
        Number of subplot columns. If ``None``, the ceiling of the square
        root of the component count is used.

    Returns
    -------
    Figure
        Matplotlib figure displaying component plots.

    Raises
    ------
    ValueError
        If ``n_cols`` is given but smaller than 1.
    """
    if n_cols is not None and n_cols < 1:
        raise ValueError("n_cols must be a positive integer.")

    graph = NetworkGraph(
        pd_df=pd_df,
        source=edge_source_col,
        target=edge_target_col,
        weight=edge_weight_col,
    )
    # Materialise the isolates before removing them; the graph must not be
    # mutated while it is being iterated.
    isolated_nodes = list(nx.isolates(graph._nx_graph))
    if isolated_nodes:
        graph._nx_graph.remove_nodes_from(isolated_nodes)

    # Always reserve at least one subplot, even for an empty graph.
    n_components = max(1, len(graph.connected_components))
    n_cols_local = int(np.ceil(np.sqrt(n_components))) if n_cols is None else n_cols
    n_rows = int(np.ceil(n_components / n_cols_local))

    # squeeze=False always yields a 2-D array of axes, so the single-subplot
    # case needs no special handling.
    fig, axes_grid = plt.subplots(
        n_rows, n_cols_local, figsize=figsize, squeeze=False
    )
    fig = cast(Figure, fig)
    fig.set_facecolor(style.background_color)
    axes = axes_grid.flatten()

    try:
        graph.aplot_connected_components(
            title=title,
            style=style,
            edge_weight_col=edge_weight_col,
            layout_seed=layout_seed,
            axes=axes,
        )

        if title:
            fig.suptitle(
                title,
                color=style.font_color,
                fontsize=style.font_size * TITLE_SCALE_FACTOR * 1.25,
            )

        # Lay out this figure explicitly rather than whichever figure pyplot
        # currently considers active; the rect leaves room for the suptitle.
        fig.tight_layout(rect=(0, 0.03, 1, 0.95))
    except Exception:
        # pyplot keeps a reference to every figure it creates; release the
        # half-built one before propagating the error.
        plt.close(fig)
        raise

    return fig
@@ -0,0 +1,56 @@
1
+ """Weight scaling helpers for network plots."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Iterable, List, Optional
6
+
7
+ import numpy as np
8
+
9
+ from .constants import _WEIGHT_PERCENTILES
10
+
11
+
12
def _softmax(x: Iterable[float]) -> np.ndarray:
    """Compute softmax values for array-like or iterable input.

    Parameters
    ----------
    x : Iterable[float]
        Values to normalise. Arrays, lists, tuples and scalars are converted
        directly; any other iterable (generator, set, ...) is materialised
        first, because NumPy would otherwise wrap the object itself in a
        0-d object array.

    Returns
    -------
    np.ndarray
        Float array with the same shape as the converted input whose entries
        are non-negative and sum to 1 over all elements. An empty input
        yields an empty array.
    """
    if isinstance(x, (np.ndarray, list, tuple)) or np.isscalar(x):
        values = x
    else:
        values = list(x)
    x_arr = np.array(values, dtype=float)

    # np.max raises on empty input; softmax of nothing is nothing.
    if x_arr.size == 0:
        return x_arr

    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps np.exp from overflowing for large inputs.
    shifted = x_arr - np.max(x_arr)
    exp_shifted = np.exp(shifted)
    return exp_shifted / exp_shifted.sum()
18
+
19
+
20
def _scale_weights(
    weights: Iterable[float],
    scale_min: float = 0,
    scale_max: float = 1,
    deciles: Optional[np.ndarray] = None,
) -> List[float]:
    """Scale weights into percentile bins within the given range.

    Each weight is assigned to a bin delimited by the cut points in
    ``deciles``. Bin ``i`` of ``n`` bins starts at fraction ``i / (n - 1)``
    of the ``[scale_min, scale_max]`` span; inside a bin, weights are spread
    over a further ``1 / (n - 1)`` of the span according to their min-max
    normalised softmax value.

    Parameters
    ----------
    weights : Iterable[float]
        Weights to scale.
    scale_min : float, optional
        Lower end of the target range. The default is 0.
    scale_max : float, optional
        Upper end of the target range. The default is 1.
    deciles : np.ndarray, optional
        Ascending cut points separating the bins. If ``None`` they are
        computed from ``weights`` at the ``_WEIGHT_PERCENTILES`` percentiles.
        The number of bins follows the number of cut points, so cut points
        other than nine deciles are scaled consistently as well.

    Returns
    -------
    list[float]
        Scaled weights in the order of the input; empty for empty input.
    """
    weights_arr = np.array(list(weights), dtype=float)
    if weights_arr.size == 0:
        return []

    soft = _softmax(weights_arr)
    deciles_arr = np.asarray(
        np.percentile(weights_arr, _WEIGHT_PERCENTILES) if deciles is None else deciles,
        dtype=float,
    )

    # One more bin than cut points. With nine cut points this gives the ten
    # bins and the divisor of 9 that were previously hard-coded.
    # NOTE(review): assumes _WEIGHT_PERCENTILES holds nine values - confirm.
    n_bins = deciles_arr.size + 1
    divisor = max(n_bins - 1, 1)  # guard against an empty cut-point list
    span = scale_max - scale_min

    scaled = np.zeros_like(weights_arr, dtype=float)
    edges = np.concatenate(([-np.inf], deciles_arr, [np.inf]))
    bins = np.digitize(weights_arr, edges) - 1

    for idx in range(n_bins):
        mask = bins == idx
        if not np.any(mask):
            continue
        bin_soft = soft[mask]
        spread = bin_soft.max() - bin_soft.min()
        if spread < 1e-12:
            # All members are indistinguishable: place them at the bin start.
            scaled[mask] = scale_min + span * (idx / divisor)
        else:
            normalized = (bin_soft - bin_soft.min()) / spread
            # NOTE(review): in the top bin this reaches n_bins / divisor of
            # the span, i.e. slightly above scale_max. Kept as is so existing
            # plots do not change.
            scaled[mask] = scale_min + span * (
                idx / divisor + normalized / divisor
            )

    return scaled.tolist()
54
+
55
+
56
+ __all__ = ["_softmax", "_scale_weights"]