risk-network 0.0.6b6__tar.gz → 0.0.6b7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/PKG-INFO +1 -1
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/__init__.py +1 -1
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/network/plot.py +103 -64
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk_network.egg-info/PKG-INFO +1 -1
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/LICENSE +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/MANIFEST.in +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/README.md +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/pyproject.toml +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/annotations/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/annotations/annotations.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/annotations/io.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/constants.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/log/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/log/console.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/log/params.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/neighborhoods/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/neighborhoods/community.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/neighborhoods/domains.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/neighborhoods/neighborhoods.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/network/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/network/geometry.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/network/graph.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/network/io.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/risk.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/fisher_exact.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/hypergeom.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/permutation/__init__.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/permutation/permutation.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/permutation/test_functions.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk/stats/stats.py +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk_network.egg-info/SOURCES.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk_network.egg-info/dependency_links.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk_network.egg-info/requires.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/risk_network.egg-info/top_level.txt +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/setup.cfg +0 -0
- {risk_network-0.0.6b6 → risk_network-0.0.6b7}/setup.py +0 -0
@@ -274,7 +274,7 @@ class NetworkPlotter:
|
|
274
274
|
|
275
275
|
def plot_subnetwork(
|
276
276
|
self,
|
277
|
-
nodes: List,
|
277
|
+
nodes: Union[List, Tuple, np.ndarray],
|
278
278
|
node_size: Union[int, np.ndarray] = 50,
|
279
279
|
node_shape: str = "o",
|
280
280
|
node_edgewidth: float = 1.0,
|
@@ -288,20 +288,24 @@ class NetworkPlotter:
|
|
288
288
|
"""Plot a subnetwork of selected nodes with customizable node and edge attributes.
|
289
289
|
|
290
290
|
Args:
|
291
|
-
nodes (list): List of node labels to include in the subnetwork.
|
291
|
+
nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
|
292
292
|
node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
|
293
293
|
node_shape (str, optional): Shape of the nodes. Defaults to "o".
|
294
294
|
node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
|
295
295
|
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
296
|
-
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes.
|
296
|
+
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
|
297
297
|
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
|
298
298
|
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
|
299
|
-
node_alpha (float, optional):
|
300
|
-
edge_alpha (float, optional):
|
299
|
+
node_alpha (float, optional): Transparency for the nodes. Defaults to 1.0.
|
300
|
+
edge_alpha (float, optional): Transparency for the edges. Defaults to 1.0.
|
301
301
|
|
302
302
|
Raises:
|
303
303
|
ValueError: If no valid nodes are found in the network graph.
|
304
304
|
"""
|
305
|
+
# Flatten nested lists of nodes, if necessary
|
306
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
307
|
+
nodes = [node for sublist in nodes for node in sublist]
|
308
|
+
|
305
309
|
# Filter to get node IDs and their coordinates
|
306
310
|
node_ids = [
|
307
311
|
self.graph.node_label_to_id_map.get(node)
|
@@ -407,7 +411,7 @@ class NetworkPlotter:
|
|
407
411
|
|
408
412
|
def plot_subcontour(
|
409
413
|
self,
|
410
|
-
nodes: List,
|
414
|
+
nodes: Union[List, Tuple, np.ndarray],
|
411
415
|
levels: int = 5,
|
412
416
|
bandwidth: float = 0.8,
|
413
417
|
grid_size: int = 250,
|
@@ -417,10 +421,10 @@ class NetworkPlotter:
|
|
417
421
|
alpha: float = 1.0,
|
418
422
|
fill_alpha: float = 0.2,
|
419
423
|
) -> None:
|
420
|
-
"""Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
|
424
|
+
"""Plot a subcontour for a given set of nodes or a list of node sets using Kernel Density Estimation (KDE).
|
421
425
|
|
422
426
|
Args:
|
423
|
-
nodes (list): List of node labels to plot the contour for.
|
427
|
+
nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels to plot the contour for.
|
424
428
|
levels (int, optional): Number of contour levels to plot. Defaults to 5.
|
425
429
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
426
430
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
@@ -433,33 +437,45 @@ class NetworkPlotter:
|
|
433
437
|
Raises:
|
434
438
|
ValueError: If no valid nodes are found in the network graph.
|
435
439
|
"""
|
436
|
-
#
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
if not node_ids or len(node_ids) == 1:
|
444
|
-
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
440
|
+
# Check if nodes is a list of lists or a flat list
|
441
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
442
|
+
# If it's a list of lists, iterate over sublists
|
443
|
+
node_groups = nodes
|
444
|
+
else:
|
445
|
+
# If it's a flat list of nodes, treat it as a single group
|
446
|
+
node_groups = [nodes]
|
445
447
|
|
446
448
|
# Convert color to RGBA using the _to_rgba helper function
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
449
|
+
color_rgba = _to_rgba(color, alpha)
|
450
|
+
|
451
|
+
# Iterate over each group of nodes (either sublists or flat list)
|
452
|
+
for sublist in node_groups:
|
453
|
+
# Filter to get node IDs and their coordinates for each sublist
|
454
|
+
node_ids = [
|
455
|
+
self.graph.node_label_to_id_map.get(node)
|
456
|
+
for node in sublist
|
457
|
+
if node in self.graph.node_label_to_id_map
|
458
|
+
]
|
459
|
+
if not node_ids or len(node_ids) == 1:
|
460
|
+
raise ValueError(
|
461
|
+
"No nodes found in the network graph or insufficient nodes to plot."
|
462
|
+
)
|
463
|
+
|
464
|
+
# Draw the KDE contour for the specified nodes
|
465
|
+
node_coordinates = self.graph.node_coordinates
|
466
|
+
self._draw_kde_contour(
|
467
|
+
self.ax,
|
468
|
+
node_coordinates,
|
469
|
+
node_ids,
|
470
|
+
color=color_rgba,
|
471
|
+
levels=levels,
|
472
|
+
bandwidth=bandwidth,
|
473
|
+
grid_size=grid_size,
|
474
|
+
linestyle=linestyle,
|
475
|
+
linewidth=linewidth,
|
476
|
+
alpha=alpha,
|
477
|
+
fill_alpha=fill_alpha,
|
478
|
+
)
|
463
479
|
|
464
480
|
def _draw_kde_contour(
|
465
481
|
self,
|
@@ -617,6 +633,10 @@ class NetworkPlotter:
|
|
617
633
|
label_ids_to_replace=ids_to_replace,
|
618
634
|
)
|
619
635
|
|
636
|
+
# Set max_labels to the total number of domains if not provided (None)
|
637
|
+
if max_labels is None:
|
638
|
+
max_labels = len(self.graph.domain_to_nodes_map)
|
639
|
+
|
620
640
|
# Convert colors to RGBA using the _to_rgba helper function
|
621
641
|
fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes_map))
|
622
642
|
arrow_color = _to_rgba(
|
@@ -639,6 +659,8 @@ class NetworkPlotter:
|
|
639
659
|
filtered_domain_terms = {}
|
640
660
|
# Handle the ids_to_keep logic
|
641
661
|
if ids_to_keep:
|
662
|
+
# Convert ids_to_keep to remove accidental duplicates
|
663
|
+
ids_to_keep = set(ids_to_keep)
|
642
664
|
# Check if the number of provided ids_to_keep exceeds max_labels
|
643
665
|
if max_labels is not None and len(ids_to_keep) > max_labels:
|
644
666
|
raise ValueError(
|
@@ -708,6 +730,7 @@ class NetworkPlotter:
|
|
708
730
|
best_label_positions = _calculate_best_label_positions(
|
709
731
|
filtered_domain_centroids, center, radius, offset
|
710
732
|
)
|
733
|
+
|
711
734
|
# Annotate the network with labels
|
712
735
|
for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
|
713
736
|
centroid = filtered_domain_centroids[domain]
|
@@ -742,7 +765,7 @@ class NetworkPlotter:
|
|
742
765
|
|
743
766
|
def plot_sublabel(
|
744
767
|
self,
|
745
|
-
nodes: List,
|
768
|
+
nodes: Union[List, Tuple, np.ndarray],
|
746
769
|
label: str,
|
747
770
|
radial_position: float = 0.0,
|
748
771
|
scale: float = 1.05,
|
@@ -752,13 +775,14 @@ class NetworkPlotter:
|
|
752
775
|
fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
|
753
776
|
fontalpha: float = 1.0,
|
754
777
|
arrow_linewidth: float = 1,
|
778
|
+
arrow_style: str = "->",
|
755
779
|
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
756
780
|
arrow_alpha: float = 1.0,
|
757
781
|
) -> None:
|
758
|
-
"""Annotate the network graph with a
|
782
|
+
"""Annotate the network graph with a label for the given nodes, with one arrow pointing to each centroid of sublists of nodes.
|
759
783
|
|
760
784
|
Args:
|
761
|
-
nodes (
|
785
|
+
nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels.
|
762
786
|
label (str): The label to be annotated on the network.
|
763
787
|
radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
|
764
788
|
scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
|
@@ -768,24 +792,22 @@ class NetworkPlotter:
|
|
768
792
|
fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
|
769
793
|
fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
|
770
794
|
arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
|
795
|
+
arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
|
771
796
|
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
|
772
797
|
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
773
798
|
"""
|
774
|
-
#
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
arrow_color = _to_rgba(arrow_color, arrow_alpha)
|
787
|
-
# Calculate the centroid of the provided nodes
|
788
|
-
centroid = self._calculate_domain_centroid(node_ids)
|
799
|
+
# Check if nodes is a list of lists or a flat list
|
800
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
801
|
+
# If it's a list of lists, iterate over sublists
|
802
|
+
node_groups = nodes
|
803
|
+
else:
|
804
|
+
# If it's a flat list of nodes, treat it as a single group
|
805
|
+
node_groups = [nodes]
|
806
|
+
|
807
|
+
# Convert fontcolor and arrow_color to RGBA
|
808
|
+
fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
|
809
|
+
arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
|
810
|
+
|
789
811
|
# Calculate the bounding box around the network
|
790
812
|
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
791
813
|
# Convert radial position to radians, adjusting for a 90-degree rotation
|
@@ -795,19 +817,36 @@ class NetworkPlotter:
|
|
795
817
|
center[1] + (radius + offset) * np.sin(radial_radians),
|
796
818
|
)
|
797
819
|
|
798
|
-
#
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
820
|
+
# Iterate over each group of nodes (either sublists or flat list)
|
821
|
+
for sublist in node_groups:
|
822
|
+
# Map node labels to IDs
|
823
|
+
node_ids = [
|
824
|
+
self.graph.node_label_to_id_map.get(node)
|
825
|
+
for node in sublist
|
826
|
+
if node in self.graph.node_label_to_id_map
|
827
|
+
]
|
828
|
+
if not node_ids or len(node_ids) == 1:
|
829
|
+
raise ValueError(
|
830
|
+
"No nodes found in the network graph or insufficient nodes to plot."
|
831
|
+
)
|
832
|
+
|
833
|
+
# Calculate the centroid of the provided nodes in this sublist
|
834
|
+
centroid = self._calculate_domain_centroid(node_ids)
|
835
|
+
# Annotate the network with the label and an arrow pointing to each centroid
|
836
|
+
self.ax.annotate(
|
837
|
+
label,
|
838
|
+
xy=centroid,
|
839
|
+
xytext=label_position,
|
840
|
+
textcoords="data",
|
841
|
+
ha="center",
|
842
|
+
va="center",
|
843
|
+
fontsize=fontsize,
|
844
|
+
fontname=font,
|
845
|
+
color=fontcolor_rgba,
|
846
|
+
arrowprops=dict(
|
847
|
+
arrowstyle=arrow_style, color=arrow_color_rgba, linewidth=arrow_linewidth
|
848
|
+
),
|
849
|
+
)
|
811
850
|
|
812
851
|
def _calculate_domain_centroid(self, nodes: List) -> tuple:
|
813
852
|
"""Calculate the most centrally located node in .
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|