risk-network 0.0.6b6__tar.gz → 0.0.6b8__tar.gz
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/PKG-INFO +1 -1
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/__init__.py +1 -1
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/network/plot.py +122 -65
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk_network.egg-info/PKG-INFO +1 -1
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/LICENSE +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/MANIFEST.in +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/README.md +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/pyproject.toml +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/annotations/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/annotations/annotations.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/annotations/io.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/constants.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/log/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/log/console.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/log/params.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/neighborhoods/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/neighborhoods/community.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/neighborhoods/domains.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/neighborhoods/neighborhoods.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/network/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/network/geometry.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/network/graph.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/network/io.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/risk.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/fisher_exact.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/hypergeom.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/permutation/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/permutation/permutation.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/permutation/test_functions.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk/stats/stats.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk_network.egg-info/SOURCES.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk_network.egg-info/dependency_links.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk_network.egg-info/requires.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/risk_network.egg-info/top_level.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/setup.cfg +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b8}/setup.py +0 -0
@@ -274,7 +274,7 @@ class NetworkPlotter:
|
|
274
274
|
|
275
275
|
def plot_subnetwork(
|
276
276
|
self,
|
277
|
-
nodes: List,
|
277
|
+
nodes: Union[List, Tuple, np.ndarray],
|
278
278
|
node_size: Union[int, np.ndarray] = 50,
|
279
279
|
node_shape: str = "o",
|
280
280
|
node_edgewidth: float = 1.0,
|
@@ -288,20 +288,24 @@ class NetworkPlotter:
|
|
288
288
|
"""Plot a subnetwork of selected nodes with customizable node and edge attributes.
|
289
289
|
|
290
290
|
Args:
|
291
|
-
nodes (list): List of node labels to include in the subnetwork.
|
291
|
+
nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
|
292
292
|
node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
|
293
293
|
node_shape (str, optional): Shape of the nodes. Defaults to "o".
|
294
294
|
node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
|
295
295
|
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
296
|
-
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes.
|
296
|
+
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
|
297
297
|
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
|
298
298
|
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
|
299
|
-
node_alpha (float, optional):
|
300
|
-
edge_alpha (float, optional):
|
299
|
+
node_alpha (float, optional): Transparency for the nodes. Defaults to 1.0.
|
300
|
+
edge_alpha (float, optional): Transparency for the edges. Defaults to 1.0.
|
301
301
|
|
302
302
|
Raises:
|
303
303
|
ValueError: If no valid nodes are found in the network graph.
|
304
304
|
"""
|
305
|
+
# Flatten nested lists of nodes, if necessary
|
306
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
307
|
+
nodes = [node for sublist in nodes for node in sublist]
|
308
|
+
|
305
309
|
# Filter to get node IDs and their coordinates
|
306
310
|
node_ids = [
|
307
311
|
self.graph.node_label_to_id_map.get(node)
|
@@ -407,7 +411,7 @@ class NetworkPlotter:
|
|
407
411
|
|
408
412
|
def plot_subcontour(
|
409
413
|
self,
|
410
|
-
nodes: List,
|
414
|
+
nodes: Union[List, Tuple, np.ndarray],
|
411
415
|
levels: int = 5,
|
412
416
|
bandwidth: float = 0.8,
|
413
417
|
grid_size: int = 250,
|
@@ -417,10 +421,10 @@ class NetworkPlotter:
|
|
417
421
|
alpha: float = 1.0,
|
418
422
|
fill_alpha: float = 0.2,
|
419
423
|
) -> None:
|
420
|
-
"""Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
|
424
|
+
"""Plot a subcontour for a given set of nodes or a list of node sets using Kernel Density Estimation (KDE).
|
421
425
|
|
422
426
|
Args:
|
423
|
-
nodes (list): List of node labels to plot the contour for.
|
427
|
+
nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels to plot the contour for.
|
424
428
|
levels (int, optional): Number of contour levels to plot. Defaults to 5.
|
425
429
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
426
430
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
@@ -433,33 +437,45 @@ class NetworkPlotter:
|
|
433
437
|
Raises:
|
434
438
|
ValueError: If no valid nodes are found in the network graph.
|
435
439
|
"""
|
436
|
-
#
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
if not node_ids or len(node_ids) == 1:
|
444
|
-
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
440
|
+
# Check if nodes is a list of lists or a flat list
|
441
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
442
|
+
# If it's a list of lists, iterate over sublists
|
443
|
+
node_groups = nodes
|
444
|
+
else:
|
445
|
+
# If it's a flat list of nodes, treat it as a single group
|
446
|
+
node_groups = [nodes]
|
445
447
|
|
446
448
|
# Convert color to RGBA using the _to_rgba helper function
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
449
|
+
color_rgba = _to_rgba(color, alpha)
|
450
|
+
|
451
|
+
# Iterate over each group of nodes (either sublists or flat list)
|
452
|
+
for sublist in node_groups:
|
453
|
+
# Filter to get node IDs and their coordinates for each sublist
|
454
|
+
node_ids = [
|
455
|
+
self.graph.node_label_to_id_map.get(node)
|
456
|
+
for node in sublist
|
457
|
+
if node in self.graph.node_label_to_id_map
|
458
|
+
]
|
459
|
+
if not node_ids or len(node_ids) == 1:
|
460
|
+
raise ValueError(
|
461
|
+
"No nodes found in the network graph or insufficient nodes to plot."
|
462
|
+
)
|
463
|
+
|
464
|
+
# Draw the KDE contour for the specified nodes
|
465
|
+
node_coordinates = self.graph.node_coordinates
|
466
|
+
self._draw_kde_contour(
|
467
|
+
self.ax,
|
468
|
+
node_coordinates,
|
469
|
+
node_ids,
|
470
|
+
color=color_rgba,
|
471
|
+
levels=levels,
|
472
|
+
bandwidth=bandwidth,
|
473
|
+
grid_size=grid_size,
|
474
|
+
linestyle=linestyle,
|
475
|
+
linewidth=linewidth,
|
476
|
+
alpha=alpha,
|
477
|
+
fill_alpha=fill_alpha,
|
478
|
+
)
|
463
479
|
|
464
480
|
def _draw_kde_contour(
|
465
481
|
self,
|
@@ -553,6 +569,8 @@ class NetworkPlotter:
|
|
553
569
|
arrow_style: str = "->",
|
554
570
|
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
555
571
|
arrow_alpha: float = 1.0,
|
572
|
+
arrow_base_shrink: float = 0.0,
|
573
|
+
arrow_tip_shrink: float = 0.0,
|
556
574
|
max_labels: Union[int, None] = None,
|
557
575
|
max_words: int = 10,
|
558
576
|
min_words: int = 1,
|
@@ -576,6 +594,8 @@ class NetworkPlotter:
|
|
576
594
|
arrow_style (str, optional): Style of the arrows pointing to centroids. Defaults to "->".
|
577
595
|
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
|
578
596
|
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
597
|
+
arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
|
598
|
+
arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
|
579
599
|
max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
|
580
600
|
max_words (int, optional): Maximum number of words in a label. Defaults to 10.
|
581
601
|
min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
|
@@ -606,6 +626,8 @@ class NetworkPlotter:
|
|
606
626
|
label_arrow_style=arrow_style,
|
607
627
|
label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
608
628
|
label_arrow_alpha=arrow_alpha,
|
629
|
+
label_arrow_base_shrink=arrow_base_shrink,
|
630
|
+
label_arrow_tip_shrink=arrow_tip_shrink,
|
609
631
|
label_max_labels=max_labels,
|
610
632
|
label_max_words=max_words,
|
611
633
|
label_min_words=min_words,
|
@@ -617,6 +639,10 @@ class NetworkPlotter:
|
|
617
639
|
label_ids_to_replace=ids_to_replace,
|
618
640
|
)
|
619
641
|
|
642
|
+
# Set max_labels to the total number of domains if not provided (None)
|
643
|
+
if max_labels is None:
|
644
|
+
max_labels = len(self.graph.domain_to_nodes_map)
|
645
|
+
|
620
646
|
# Convert colors to RGBA using the _to_rgba helper function
|
621
647
|
fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes_map))
|
622
648
|
arrow_color = _to_rgba(
|
@@ -639,6 +665,8 @@ class NetworkPlotter:
|
|
639
665
|
filtered_domain_terms = {}
|
640
666
|
# Handle the ids_to_keep logic
|
641
667
|
if ids_to_keep:
|
668
|
+
# Convert ids_to_keep to remove accidental duplicates
|
669
|
+
ids_to_keep = set(ids_to_keep)
|
642
670
|
# Check if the number of provided ids_to_keep exceeds max_labels
|
643
671
|
if max_labels is not None and len(ids_to_keep) > max_labels:
|
644
672
|
raise ValueError(
|
@@ -708,6 +736,7 @@ class NetworkPlotter:
|
|
708
736
|
best_label_positions = _calculate_best_label_positions(
|
709
737
|
filtered_domain_centroids, center, radius, offset
|
710
738
|
)
|
739
|
+
|
711
740
|
# Annotate the network with labels
|
712
741
|
for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
|
713
742
|
centroid = filtered_domain_centroids[domain]
|
@@ -723,7 +752,11 @@ class NetworkPlotter:
|
|
723
752
|
fontname=font,
|
724
753
|
color=fontcolor[idx],
|
725
754
|
arrowprops=dict(
|
726
|
-
arrowstyle=arrow_style,
|
755
|
+
arrowstyle=arrow_style,
|
756
|
+
color=arrow_color[idx],
|
757
|
+
linewidth=arrow_linewidth,
|
758
|
+
shrinkA=arrow_base_shrink,
|
759
|
+
shrinkB=arrow_tip_shrink,
|
727
760
|
),
|
728
761
|
)
|
729
762
|
# Overlay domain ID at the centroid if requested
|
@@ -742,7 +775,7 @@ class NetworkPlotter:
|
|
742
775
|
|
743
776
|
def plot_sublabel(
|
744
777
|
self,
|
745
|
-
nodes: List,
|
778
|
+
nodes: Union[List, Tuple, np.ndarray],
|
746
779
|
label: str,
|
747
780
|
radial_position: float = 0.0,
|
748
781
|
scale: float = 1.05,
|
@@ -752,13 +785,16 @@ class NetworkPlotter:
|
|
752
785
|
fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
|
753
786
|
fontalpha: float = 1.0,
|
754
787
|
arrow_linewidth: float = 1,
|
788
|
+
arrow_style: str = "->",
|
755
789
|
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
756
790
|
arrow_alpha: float = 1.0,
|
791
|
+
arrow_base_shrink: float = 0.0,
|
792
|
+
arrow_tip_shrink: float = 0.0,
|
757
793
|
) -> None:
|
758
|
-
"""Annotate the network graph with a
|
794
|
+
"""Annotate the network graph with a label for the given nodes, with one arrow pointing to each centroid of sublists of nodes.
|
759
795
|
|
760
796
|
Args:
|
761
|
-
nodes (
|
797
|
+
nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels.
|
762
798
|
label (str): The label to be annotated on the network.
|
763
799
|
radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
|
764
800
|
scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
|
@@ -768,24 +804,24 @@ class NetworkPlotter:
|
|
768
804
|
fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
|
769
805
|
fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
|
770
806
|
arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
|
807
|
+
arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
|
771
808
|
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
|
772
809
|
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
810
|
+
arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
|
811
|
+
arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
|
773
812
|
"""
|
774
|
-
#
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
arrow_color = _to_rgba(arrow_color, arrow_alpha)
|
787
|
-
# Calculate the centroid of the provided nodes
|
788
|
-
centroid = self._calculate_domain_centroid(node_ids)
|
813
|
+
# Check if nodes is a list of lists or a flat list
|
814
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
815
|
+
# If it's a list of lists, iterate over sublists
|
816
|
+
node_groups = nodes
|
817
|
+
else:
|
818
|
+
# If it's a flat list of nodes, treat it as a single group
|
819
|
+
node_groups = [nodes]
|
820
|
+
|
821
|
+
# Convert fontcolor and arrow_color to RGBA
|
822
|
+
fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
|
823
|
+
arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
|
824
|
+
|
789
825
|
# Calculate the bounding box around the network
|
790
826
|
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
791
827
|
# Convert radial position to radians, adjusting for a 90-degree rotation
|
@@ -795,19 +831,40 @@ class NetworkPlotter:
|
|
795
831
|
center[1] + (radius + offset) * np.sin(radial_radians),
|
796
832
|
)
|
797
833
|
|
798
|
-
#
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
834
|
+
# Iterate over each group of nodes (either sublists or flat list)
|
835
|
+
for sublist in node_groups:
|
836
|
+
# Map node labels to IDs
|
837
|
+
node_ids = [
|
838
|
+
self.graph.node_label_to_id_map.get(node)
|
839
|
+
for node in sublist
|
840
|
+
if node in self.graph.node_label_to_id_map
|
841
|
+
]
|
842
|
+
if not node_ids or len(node_ids) == 1:
|
843
|
+
raise ValueError(
|
844
|
+
"No nodes found in the network graph or insufficient nodes to plot."
|
845
|
+
)
|
846
|
+
|
847
|
+
# Calculate the centroid of the provided nodes in this sublist
|
848
|
+
centroid = self._calculate_domain_centroid(node_ids)
|
849
|
+
# Annotate the network with the label and an arrow pointing to each centroid
|
850
|
+
self.ax.annotate(
|
851
|
+
label,
|
852
|
+
xy=centroid,
|
853
|
+
xytext=label_position,
|
854
|
+
textcoords="data",
|
855
|
+
ha="center",
|
856
|
+
va="center",
|
857
|
+
fontsize=fontsize,
|
858
|
+
fontname=font,
|
859
|
+
color=fontcolor_rgba,
|
860
|
+
arrowprops=dict(
|
861
|
+
arrowstyle=arrow_style,
|
862
|
+
color=arrow_color_rgba,
|
863
|
+
linewidth=arrow_linewidth,
|
864
|
+
shrinkA=arrow_base_shrink,
|
865
|
+
shrinkB=arrow_tip_shrink,
|
866
|
+
),
|
867
|
+
)
|
811
868
|
|
812
869
|
def _calculate_domain_centroid(self, nodes: List) -> tuple:
|
813
870
|
"""Calculate the most centrally located node in .
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|