risk-network 0.0.6b6__py3-none-any.whl → 0.0.6b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.6-beta.6"
10
+ __version__ = "0.0.6-beta.7"
risk/network/plot.py CHANGED
@@ -274,7 +274,7 @@ class NetworkPlotter:
274
274
 
275
275
  def plot_subnetwork(
276
276
  self,
277
- nodes: List,
277
+ nodes: Union[List, Tuple, np.ndarray],
278
278
  node_size: Union[int, np.ndarray] = 50,
279
279
  node_shape: str = "o",
280
280
  node_edgewidth: float = 1.0,
@@ -288,20 +288,24 @@ class NetworkPlotter:
288
288
  """Plot a subnetwork of selected nodes with customizable node and edge attributes.
289
289
 
290
290
  Args:
291
- nodes (list): List of node labels to include in the subnetwork.
291
+ nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
292
292
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
293
293
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
294
294
  node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
295
295
  edge_width (float, optional): Width of the edges. Defaults to 1.0.
296
- node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
296
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
297
297
  node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
298
298
  edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
299
- node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0.
300
- edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
299
+ node_alpha (float, optional): Transparency for the nodes. Defaults to 1.0.
300
+ edge_alpha (float, optional): Transparency for the edges. Defaults to 1.0.
301
301
 
302
302
  Raises:
303
303
  ValueError: If no valid nodes are found in the network graph.
304
304
  """
305
+ # Flatten nested lists of nodes, if necessary
306
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
307
+ nodes = [node for sublist in nodes for node in sublist]
308
+
305
309
  # Filter to get node IDs and their coordinates
306
310
  node_ids = [
307
311
  self.graph.node_label_to_id_map.get(node)
@@ -407,7 +411,7 @@ class NetworkPlotter:
407
411
 
408
412
  def plot_subcontour(
409
413
  self,
410
- nodes: List,
414
+ nodes: Union[List, Tuple, np.ndarray],
411
415
  levels: int = 5,
412
416
  bandwidth: float = 0.8,
413
417
  grid_size: int = 250,
@@ -417,10 +421,10 @@ class NetworkPlotter:
417
421
  alpha: float = 1.0,
418
422
  fill_alpha: float = 0.2,
419
423
  ) -> None:
420
- """Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
424
+ """Plot a subcontour for a given set of nodes or a list of node sets using Kernel Density Estimation (KDE).
421
425
 
422
426
  Args:
423
- nodes (list): List of node labels to plot the contour for.
427
+ nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels to plot the contour for.
424
428
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
425
429
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
426
430
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
@@ -433,33 +437,45 @@ class NetworkPlotter:
433
437
  Raises:
434
438
  ValueError: If no valid nodes are found in the network graph.
435
439
  """
436
- # Don't log subcontour parameters as they are specific to individual annotations
437
- # Filter to get node IDs and their coordinates
438
- node_ids = [
439
- self.graph.node_label_to_id_map.get(node)
440
- for node in nodes
441
- if node in self.graph.node_label_to_id_map
442
- ]
443
- if not node_ids or len(node_ids) == 1:
444
- raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
440
+ # Check if nodes is a list of lists or a flat list
441
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
442
+ # If it's a list of lists, iterate over sublists
443
+ node_groups = nodes
444
+ else:
445
+ # If it's a flat list of nodes, treat it as a single group
446
+ node_groups = [nodes]
445
447
 
446
448
  # Convert color to RGBA using the _to_rgba helper function
447
- color = _to_rgba(color, alpha)
448
- # Draw the KDE contour for the specified nodes
449
- node_coordinates = self.graph.node_coordinates
450
- self._draw_kde_contour(
451
- self.ax,
452
- node_coordinates,
453
- node_ids,
454
- color=color,
455
- levels=levels,
456
- bandwidth=bandwidth,
457
- grid_size=grid_size,
458
- linestyle=linestyle,
459
- linewidth=linewidth,
460
- alpha=alpha,
461
- fill_alpha=fill_alpha,
462
- )
449
+ color_rgba = _to_rgba(color, alpha)
450
+
451
+ # Iterate over each group of nodes (either sublists or flat list)
452
+ for sublist in node_groups:
453
+ # Filter to get node IDs and their coordinates for each sublist
454
+ node_ids = [
455
+ self.graph.node_label_to_id_map.get(node)
456
+ for node in sublist
457
+ if node in self.graph.node_label_to_id_map
458
+ ]
459
+ if not node_ids or len(node_ids) == 1:
460
+ raise ValueError(
461
+ "No nodes found in the network graph or insufficient nodes to plot."
462
+ )
463
+
464
+ # Draw the KDE contour for the specified nodes
465
+ node_coordinates = self.graph.node_coordinates
466
+ self._draw_kde_contour(
467
+ self.ax,
468
+ node_coordinates,
469
+ node_ids,
470
+ color=color_rgba,
471
+ levels=levels,
472
+ bandwidth=bandwidth,
473
+ grid_size=grid_size,
474
+ linestyle=linestyle,
475
+ linewidth=linewidth,
476
+ alpha=alpha,
477
+ fill_alpha=fill_alpha,
478
+ )
463
479
 
464
480
  def _draw_kde_contour(
465
481
  self,
@@ -617,6 +633,10 @@ class NetworkPlotter:
617
633
  label_ids_to_replace=ids_to_replace,
618
634
  )
619
635
 
636
+ # Set max_labels to the total number of domains if not provided (None)
637
+ if max_labels is None:
638
+ max_labels = len(self.graph.domain_to_nodes_map)
639
+
620
640
  # Convert colors to RGBA using the _to_rgba helper function
621
641
  fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes_map))
622
642
  arrow_color = _to_rgba(
@@ -639,6 +659,8 @@ class NetworkPlotter:
639
659
  filtered_domain_terms = {}
640
660
  # Handle the ids_to_keep logic
641
661
  if ids_to_keep:
662
+ # Convert ids_to_keep to remove accidental duplicates
663
+ ids_to_keep = set(ids_to_keep)
642
664
  # Check if the number of provided ids_to_keep exceeds max_labels
643
665
  if max_labels is not None and len(ids_to_keep) > max_labels:
644
666
  raise ValueError(
@@ -708,6 +730,7 @@ class NetworkPlotter:
708
730
  best_label_positions = _calculate_best_label_positions(
709
731
  filtered_domain_centroids, center, radius, offset
710
732
  )
733
+
711
734
  # Annotate the network with labels
712
735
  for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
713
736
  centroid = filtered_domain_centroids[domain]
@@ -742,7 +765,7 @@ class NetworkPlotter:
742
765
 
743
766
  def plot_sublabel(
744
767
  self,
745
- nodes: List,
768
+ nodes: Union[List, Tuple, np.ndarray],
746
769
  label: str,
747
770
  radial_position: float = 0.0,
748
771
  scale: float = 1.05,
@@ -752,13 +775,14 @@ class NetworkPlotter:
752
775
  fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
753
776
  fontalpha: float = 1.0,
754
777
  arrow_linewidth: float = 1,
778
+ arrow_style: str = "->",
755
779
  arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
756
780
  arrow_alpha: float = 1.0,
757
781
  ) -> None:
758
- """Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
782
+ """Annotate the network graph with a label for the given nodes, with one arrow pointing to each centroid of sublists of nodes.
759
783
 
760
784
  Args:
761
- nodes (List): List of node labels to be used for calculating the centroid.
785
+ nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels.
762
786
  label (str): The label to be annotated on the network.
763
787
  radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
764
788
  scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
@@ -768,24 +792,22 @@ class NetworkPlotter:
768
792
  fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
769
793
  fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
770
794
  arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
795
+ arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
771
796
  arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
772
797
  arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
773
798
  """
774
- # Don't log sublabel parameters as they are specific to individual annotations
775
- # Map node labels to IDs
776
- node_ids = [
777
- self.graph.node_label_to_id_map.get(node)
778
- for node in nodes
779
- if node in self.graph.node_label_to_id_map
780
- ]
781
- if not node_ids or len(node_ids) == 1:
782
- raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
783
-
784
- # Convert fontcolor and arrow_color to RGBA using the _to_rgba helper function
785
- fontcolor = _to_rgba(fontcolor, fontalpha)
786
- arrow_color = _to_rgba(arrow_color, arrow_alpha)
787
- # Calculate the centroid of the provided nodes
788
- centroid = self._calculate_domain_centroid(node_ids)
799
+ # Check if nodes is a list of lists or a flat list
800
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
801
+ # If it's a list of lists, iterate over sublists
802
+ node_groups = nodes
803
+ else:
804
+ # If it's a flat list of nodes, treat it as a single group
805
+ node_groups = [nodes]
806
+
807
+ # Convert fontcolor and arrow_color to RGBA
808
+ fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
809
+ arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
810
+
789
811
  # Calculate the bounding box around the network
790
812
  center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
791
813
  # Convert radial position to radians, adjusting for a 90-degree rotation
@@ -795,19 +817,36 @@ class NetworkPlotter:
795
817
  center[1] + (radius + offset) * np.sin(radial_radians),
796
818
  )
797
819
 
798
- # Annotate the network with the label
799
- self.ax.annotate(
800
- label,
801
- xy=centroid,
802
- xytext=label_position,
803
- textcoords="data",
804
- ha="center",
805
- va="center",
806
- fontsize=fontsize,
807
- fontname=font,
808
- color=fontcolor,
809
- arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
810
- )
820
+ # Iterate over each group of nodes (either sublists or flat list)
821
+ for sublist in node_groups:
822
+ # Map node labels to IDs
823
+ node_ids = [
824
+ self.graph.node_label_to_id_map.get(node)
825
+ for node in sublist
826
+ if node in self.graph.node_label_to_id_map
827
+ ]
828
+ if not node_ids or len(node_ids) == 1:
829
+ raise ValueError(
830
+ "No nodes found in the network graph or insufficient nodes to plot."
831
+ )
832
+
833
+ # Calculate the centroid of the provided nodes in this sublist
834
+ centroid = self._calculate_domain_centroid(node_ids)
835
+ # Annotate the network with the label and an arrow pointing to each centroid
836
+ self.ax.annotate(
837
+ label,
838
+ xy=centroid,
839
+ xytext=label_position,
840
+ textcoords="data",
841
+ ha="center",
842
+ va="center",
843
+ fontsize=fontsize,
844
+ fontname=font,
845
+ color=fontcolor_rgba,
846
+ arrowprops=dict(
847
+ arrowstyle=arrow_style, color=arrow_color_rgba, linewidth=arrow_linewidth
848
+ ),
849
+ )
811
850
 
812
851
  def _calculate_domain_centroid(self, nodes: List) -> tuple:
813
852
  """Calculate the most centrally located node in .
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.6b6
3
+ Version: 0.0.6b7
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -1,4 +1,4 @@
1
- risk/__init__.py,sha256=W_0EIUjGsvlMd6HMEtSq_HCSBIQ8FH9MMb1zWxoALzU,112
1
+ risk/__init__.py,sha256=DLDQBVlM5oYOp-S9i1GDq-rRwaNv1dutx241xcHbK3w,112
2
2
  risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
3
3
  risk/risk.py,sha256=PONl5tzN5DSVUf4MgczfOvzGV-5JoAOLTQ6YWl10mZ8,20697
4
4
  risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
@@ -15,7 +15,7 @@ risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
15
15
  risk/network/geometry.py,sha256=H1yGVVqgbfpzBzJwEheDLfvGLSA284jGQQTn612L4Vc,6759
16
16
  risk/network/graph.py,sha256=scPFQIJjioup1FjQLyxNrAB17RmskY9MmvoFHrMlqNI,13135
17
17
  risk/network/io.py,sha256=gG50kOknO-D3HkW1HsbHMkTMvjUtn3l4W4Jwd-rXNr8,21202
18
- risk/network/plot.py,sha256=jZm8wfrcswI_SDgDV1PEE2pO0CyuHk2fOi44o5fwEVI,58845
18
+ risk/network/plot.py,sha256=4zcFJWZgGrj7AG1crRuwGpzzjjHXjQ8Eh5VjlkOQblE,60747
19
19
  risk/stats/__init__.py,sha256=e-BE_Dr_jgiK6hKM-T-tlG4yvHnId8e5qjnM0pdwNVc,230
20
20
  risk/stats/fisher_exact.py,sha256=-bPwzu76-ob0HzrTV20mXUTot7v-MLuqFaAoab-QxPg,4966
21
21
  risk/stats/hypergeom.py,sha256=lrIFdhCWRjvM4apYw1MlOKqT_IY5OjtCwrjdtJdt6Tg,4954
@@ -23,8 +23,8 @@ risk/stats/stats.py,sha256=kvShov-94W6ffgDUTb522vB9hDJQSyTsYif_UIaFfSM,7059
23
23
  risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
24
24
  risk/stats/permutation/permutation.py,sha256=qLWdwxEY6nmkYPxpM8HLDcd2mbqYv9Qr7CKtJvhLqIM,9220
25
25
  risk/stats/permutation/test_functions.py,sha256=HuDIM-V1jkkfE1rlaIqrWWBSKZt3dQ1f-YEDjWpnLSE,2343
26
- risk_network-0.0.6b6.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
27
- risk_network-0.0.6b6.dist-info/METADATA,sha256=3DAPqiBMVnOmSDTgQ-F49litJZ0zl4EeZ3sENuw0UHM,43142
28
- risk_network-0.0.6b6.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
29
- risk_network-0.0.6b6.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
30
- risk_network-0.0.6b6.dist-info/RECORD,,
26
+ risk_network-0.0.6b7.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
27
+ risk_network-0.0.6b7.dist-info/METADATA,sha256=yiCOKKWr1cByfyrssz1g0LLSYLfvUr7XziphQPe2IkA,43142
28
+ risk_network-0.0.6b7.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
29
+ risk_network-0.0.6b7.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
30
+ risk_network-0.0.6b7.dist-info/RECORD,,