risk-network 0.0.8b15__tar.gz → 0.0.8b17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/PKG-INFO +1 -1
  2. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/__init__.py +1 -1
  3. risk_network-0.0.8b17/risk/network/graph.py +159 -0
  4. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/canvas.py +15 -8
  5. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/contour.py +11 -9
  6. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/labels.py +18 -10
  7. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/network.py +20 -13
  8. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/plotter.py +9 -6
  9. risk_network-0.0.8b17/risk/network/plot/utils/color.py +353 -0
  10. risk_network-0.0.8b17/risk/network/plot/utils/layout.py +53 -0
  11. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/PKG-INFO +1 -1
  12. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/SOURCES.txt +2 -1
  13. risk_network-0.0.8b15/risk/network/graph.py +0 -393
  14. risk_network-0.0.8b15/risk/network/plot/utils.py +0 -153
  15. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/LICENSE +0 -0
  16. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/MANIFEST.in +0 -0
  17. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/README.md +0 -0
  18. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/pyproject.toml +0 -0
  19. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/annotations/__init__.py +0 -0
  20. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/annotations/annotations.py +0 -0
  21. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/annotations/io.py +0 -0
  22. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/constants.py +0 -0
  23. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/log/__init__.py +0 -0
  24. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/log/config.py +0 -0
  25. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/log/params.py +0 -0
  26. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/__init__.py +0 -0
  27. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/community.py +0 -0
  28. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/domains.py +0 -0
  29. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/neighborhoods.py +0 -0
  30. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/__init__.py +0 -0
  31. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/geometry.py +0 -0
  32. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/io.py +0 -0
  33. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/__init__.py +0 -0
  34. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/risk.py +0 -0
  35. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/__init__.py +0 -0
  36. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/hypergeom.py +0 -0
  37. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/permutation/__init__.py +0 -0
  38. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/permutation/permutation.py +0 -0
  39. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/permutation/test_functions.py +0 -0
  40. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/poisson.py +0 -0
  41. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/stats.py +0 -0
  42. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/dependency_links.txt +0 -0
  43. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/requires.txt +0 -0
  44. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/top_level.txt +0 -0
  45. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/setup.cfg +0 -0
  46. {risk_network-0.0.8b15 → risk_network-0.0.8b17}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.8b15
3
+ Version: 0.0.8b17
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.8-beta.15"
10
+ __version__ = "0.0.8-beta.17"
@@ -0,0 +1,159 @@
1
+ """
2
+ risk/network/graph
3
+ ~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from collections import defaultdict
7
+ from typing import Any, Dict, List
8
+
9
+ import networkx as nx
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+
14
class NetworkGraph:
    """A class to represent a network graph and process its nodes and edges.

    The NetworkGraph class bundles a network graph with its domain, annotation, and node
    enrichment data, and builds lookup maps between node labels, node IDs, domain IDs, and
    domain terms. On construction, any 3D node coordinates are unfolded onto a 2D plane and
    the resulting 2D node coordinates are cached in ``node_coordinates``.
    """

    def __init__(
        self,
        network: nx.Graph,
        top_annotations: pd.DataFrame,
        domains: pd.DataFrame,
        trimmed_domains: pd.DataFrame,
        node_label_to_node_id_map: Dict[str, Any],
        node_enrichment_sums: np.ndarray,
    ):
        """Initialize the NetworkGraph object.

        Args:
            network (nx.Graph): The network graph. It is modified in place: nodes with a 'z'
                attribute have their 'x'/'y' overwritten with unfolded 2D coordinates and 'z' removed.
            top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
            domains (pd.DataFrame): DataFrame indexed by node ID with a 'primary domain' column.
            trimmed_domains (pd.DataFrame): DataFrame indexed by domain ID with a 'label' column.
            node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
            node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
        """
        self.top_annotations = top_annotations
        self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
        self.domains = domains
        self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
            trimmed_domains
        )
        self.node_enrichment_sums = node_enrichment_sums
        self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
        # NOTE(review): labels are paired with enrichment sums by the insertion order of
        # node_label_to_node_id_map, not by node ID; this assumes that order matches the
        # order of node_enrichment_sums — TODO confirm against the caller.
        self.node_label_to_enrichment_map = dict(
            zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
        )
        self.node_label_to_node_id_map = node_label_to_node_id_map
        # NOTE: Below this point, instance attributes (i.e., self) will be used!
        self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
        # Unfold the network's 3D coordinates to 2D and extract node coordinates.
        # _unfold_sphere_to_plane mutates the passed graph and returns that same object.
        self.network = _unfold_sphere_to_plane(network)
        self.node_coordinates = _extract_node_coordinates(self.network)

    def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[Any, List[Any]]:
        """Create a mapping from domains to the list of node IDs belonging to each domain.

        Node IDs are taken from the DataFrame index, whether or not the index is named.

        Args:
            domains (pd.DataFrame): DataFrame containing domain information, including the
                'primary domain' for each node (indexed by node ID).

        Returns:
            dict: A defaultdict(list) where keys are domain IDs and values are lists of node IDs
                belonging to each domain, in index order.
        """
        # Series.to_dict() keys directly on the index, so this works for a named index too
        # (reset_index() only yields an "index" column when the index is unnamed).
        # For duplicate index labels the last value wins, as with any dict construction.
        node_to_domain_map = domains["primary domain"].to_dict()
        domain_id_to_node_ids_map = defaultdict(list)
        for node_id, domain_id in node_to_domain_map.items():
            domain_id_to_node_ids_map[domain_id].append(node_id)

        return domain_id_to_node_ids_map

    def _create_domain_id_to_domain_terms_map(
        self, trimmed_domains: pd.DataFrame
    ) -> Dict[str, Any]:
        """Create a mapping from domain IDs to their corresponding terms.

        Args:
            trimmed_domains (pd.DataFrame): DataFrame indexed by domain ID with a 'label' column.

        Returns:
            dict: A dictionary mapping each domain ID (index value) to its 'label' value.
        """
        return dict(
            zip(
                trimmed_domains.index,
                trimmed_domains["label"],
            )
        )

    def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
        """Create a map from domain IDs to node labels.

        Requires ``domain_id_to_node_ids_map`` and ``node_id_to_node_label_map`` to be set.

        Returns:
            dict: A dictionary mapping domain IDs to the corresponding node labels.
        """
        return {
            domain_id: [self.node_id_to_node_label_map[node_id] for node_id in node_ids]
            for domain_id, node_ids in self.domain_id_to_node_ids_map.items()
        }
108
+
109
+
110
def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
    """Project nodes with 3D coordinates onto a 2D plane by unrolling a sphere.

    For every node that carries a 'z' attribute, the Cartesian point (x, y, z) is converted
    to spherical angles. The azimuth is mapped to [0, 1] and rotated by half a turn to give
    the new 'x'. The polar angle is reflected and normalized to [0, 1], then negated, to give
    the new 'y' (so 'y' ends up in [-1, 0]). The 'z' attribute is then removed. Nodes without
    'z' are left untouched.

    NOTE(review): a node located exactly at the origin gives r == 0, so z / r would produce
    nan coordinates — confirm inputs never contain such a node.

    Args:
        G (nx.Graph): A network graph whose nodes may have 'x', 'y', and 'z' attributes.

    Returns:
        nx.Graph: The same graph object, mutated in place, with 2D 'x'/'y' coordinates.
    """
    for node in G.nodes():
        attrs = G.nodes[node]
        if "z" not in attrs:
            continue

        x, y, z = attrs["x"], attrs["y"], attrs["z"]
        radius = np.sqrt(x**2 + y**2 + z**2)
        azimuth = np.arctan2(y, x)
        polar = np.arccos(z / radius)

        # Azimuth in [-pi, pi] -> [0, 1], then shift by half a revolution (wrapping at 1)
        u = (azimuth + np.pi) / (2 * np.pi)
        if u < 0.5:
            u = u + 0.5
        else:
            u = u - 0.5
        # Polar angle in [0, pi] -> reflected and normalized to [0, 1]
        v = (np.pi - polar) / np.pi

        attrs["x"] = u
        attrs["y"] = -v
        # The third dimension is no longer meaningful after unfolding
        del attrs["z"]

    return G
139
+
140
+
141
def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
    """Extract 2D coordinates of nodes from the graph.

    Args:
        G (nx.Graph): The network graph with 'x' and 'y' node attributes.

    Returns:
        np.ndarray: Array of node coordinates with shape (num_nodes, 2), one row per node in the
            graph's node iteration order. A graph with no nodes yields an empty (0, 2) array.
    """
    # Extract x and y coordinates from graph nodes, keyed by node
    x_coords = dict(G.nodes.data("x"))
    y_coords = dict(G.nodes.data("y"))
    # np.vstack raises ValueError on an empty sequence, so handle the empty graph explicitly
    if not x_coords:
        return np.empty((0, 2))
    # Stack one [x, y] row per node
    return np.vstack([np.array([x_coords[node], y_coords[node]]) for node in x_coords])
@@ -10,7 +10,8 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils import calculate_bounding_box, to_rgba
13
+ from risk.network.plot.utils.color import to_rgba
14
+ from risk.network.plot.utils.layout import calculate_bounding_box
14
15
 
15
16
 
16
17
  class Canvas:
@@ -33,8 +34,8 @@ class Canvas:
33
34
  title_fontsize: int = 20,
34
35
  subtitle_fontsize: int = 14,
35
36
  font: str = "Arial",
36
- title_color: str = "black",
37
- subtitle_color: str = "gray",
37
+ title_color: Union[str, list, tuple, np.ndarray] = "black",
38
+ subtitle_color: Union[str, list, tuple, np.ndarray] = "gray",
38
39
  title_y: float = 0.975,
39
40
  title_space_offset: float = 0.075,
40
41
  subtitle_offset: float = 0.025,
@@ -47,8 +48,10 @@ class Canvas:
47
48
  title_fontsize (int, optional): Font size for the title. Defaults to 20.
48
49
  subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
49
50
  font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
50
- title_color (str, optional): Color of the title text. Defaults to "black".
51
- subtitle_color (str, optional): Color of the subtitle text. Defaults to "gray".
51
+ title_color (str, list, tuple, or np.ndarray, optional): Color of the title text. Can be a string or an array of colors.
52
+ Defaults to "black".
53
+ subtitle_color (str, list, tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
54
+ Defaults to "gray".
52
55
  title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
53
56
  title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
54
57
  subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
@@ -141,7 +144,9 @@ class Canvas:
141
144
  )
142
145
 
143
146
  # Convert color to RGBA using the to_rgba helper function - use outline_alpha for the perimeter
144
- color = to_rgba(color=color, alpha=outline_alpha)
147
+ color = to_rgba(
148
+ color=color, alpha=outline_alpha, num_repeats=1
149
+ ) # num_repeats=1 for a single color
145
150
  # Set the fill_alpha to 0 if not provided
146
151
  fill_alpha = fill_alpha if fill_alpha is not None else 0.0
147
152
  # Extract node coordinates from the network graph
@@ -162,7 +167,9 @@ class Canvas:
162
167
  )
163
168
  # Set the transparency of the fill if applicable
164
169
  if fill_alpha > 0:
165
- circle.set_facecolor(to_rgba(color=color, alpha=fill_alpha))
170
+ circle.set_facecolor(
171
+ to_rgba(color=color, alpha=fill_alpha, num_repeats=1)
172
+ ) # num_repeats=1 for a single color
166
173
 
167
174
  self.ax.add_artist(circle)
168
175
 
@@ -209,7 +216,7 @@ class Canvas:
209
216
  )
210
217
 
211
218
  # Convert color to RGBA using outline_alpha for the line (outline)
212
- outline_color = to_rgba(color=color)
219
+ outline_color = to_rgba(color=color, num_repeats=1) # num_repeats=1 for a single color
213
220
  # Extract node coordinates from the network graph
214
221
  node_coordinates = self.graph.node_coordinates
215
222
  # Scale the node coordinates if needed
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
13
13
 
14
14
  from risk.log import params, logger
15
15
  from risk.network.graph import NetworkGraph
16
- from risk.network.plot.utils import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
17
17
 
18
18
 
19
19
  class Contour:
@@ -110,7 +110,7 @@ class Contour:
110
110
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
111
111
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
112
112
  color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array.
113
- Defaults to "white".
113
+ Can be a single color or an array of colors. Defaults to "white".
114
114
  linestyle (str, optional): Line style for the contour. Defaults to "solid".
115
115
  linewidth (float, optional): Line width for the contour. Defaults to 1.5.
116
116
  alpha (float, None, optional): Transparency level of the contour lines. If provided, it overrides any existing alpha values
@@ -125,15 +125,16 @@ class Contour:
125
125
  if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
126
126
  # If it's a list of lists, iterate over sublists
127
127
  node_groups = nodes
128
+ # Convert color to RGBA arrays to match the number of groups
129
+ color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=len(node_groups))
128
130
  else:
129
131
  # If it's a flat list of nodes, treat it as a single group
130
132
  node_groups = [nodes]
131
-
132
- # Convert color to RGBA using the to_rgba helper function
133
- color_rgba = to_rgba(color=color, alpha=alpha)
133
+ # Wrap the RGBA color in an array to index the first element
134
+ color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=1)
134
135
 
135
136
  # Iterate over each group of nodes (either sublists or flat list)
136
- for sublist in node_groups:
137
+ for idx, sublist in enumerate(node_groups):
137
138
  # Filter to get node IDs and their coordinates for each sublist
138
139
  node_ids = [
139
140
  self.graph.node_label_to_node_id_map.get(node)
@@ -151,7 +152,7 @@ class Contour:
151
152
  self.ax,
152
153
  node_coordinates,
153
154
  node_ids,
154
- color=color_rgba,
155
+ color=color_rgba[idx],
155
156
  levels=levels,
156
157
  bandwidth=bandwidth,
157
158
  grid_size=grid_size,
@@ -273,7 +274,7 @@ class Contour:
273
274
  def get_annotated_contour_colors(
274
275
  self,
275
276
  cmap: str = "gist_rainbow",
276
- color: Union[str, None] = None,
277
+ color: Union[str, list, tuple, np.ndarray, None] = None,
277
278
  min_scale: float = 0.8,
278
279
  max_scale: float = 1.0,
279
280
  scale_factor: float = 1.0,
@@ -283,7 +284,8 @@ class Contour:
283
284
 
284
285
  Args:
285
286
  cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
286
- color (str or None, optional): Color to use for the contours. If None, the colormap will be used. Defaults to None.
287
+ color (str, list, tuple, np.ndarray, or None, optional): Color to use for the contours. Can be a single color or an array of colors.
288
+ If None, the colormap will be used. Defaults to None.
287
289
  min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
288
290
  Controls the dimmest colors. Defaults to 0.8.
289
291
  max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
@@ -11,7 +11,8 @@ import pandas as pd
11
11
 
12
12
  from risk.log import params
13
13
  from risk.network.graph import NetworkGraph
14
- from risk.network.plot.utils import calculate_bounding_box, get_annotated_domain_colors, to_rgba
14
+ from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
15
+ from risk.network.plot.utils.layout import calculate_bounding_box
15
16
 
16
17
  TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
17
18
 
@@ -284,13 +285,19 @@ class Labels:
284
285
  if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
285
286
  # If it's a list of lists, iterate over sublists
286
287
  node_groups = nodes
288
+ # Convert fontcolor and arrow_color to RGBA arrays to match the number of groups
289
+ fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=len(node_groups))
290
+ arrow_color_rgba = to_rgba(
291
+ color=arrow_color, alpha=arrow_alpha, num_repeats=len(node_groups)
292
+ )
287
293
  else:
288
294
  # If it's a flat list of nodes, treat it as a single group
289
295
  node_groups = [nodes]
290
-
291
- # Convert fontcolor and arrow_color to RGBA
292
- fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha)
293
- arrow_color_rgba = to_rgba(color=arrow_color, alpha=arrow_alpha)
296
+ # Wrap the RGBA fontcolor and arrow_color in an array to index the first element
297
+ fontcolor_rgba = np.array(to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=1))
298
+ arrow_color_rgba = np.array(
299
+ to_rgba(color=arrow_color, alpha=arrow_alpha, num_repeats=1)
300
+ )
294
301
 
295
302
  # Calculate the bounding box around the network
296
303
  center, radius = calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
@@ -302,7 +309,7 @@ class Labels:
302
309
  )
303
310
 
304
311
  # Iterate over each group of nodes (either sublists or flat list)
305
- for sublist in node_groups:
312
+ for idx, sublist in enumerate(node_groups):
306
313
  # Map node labels to IDs
307
314
  node_ids = [
308
315
  self.graph.node_label_to_node_id_map.get(node)
@@ -326,10 +333,10 @@ class Labels:
326
333
  va="center",
327
334
  fontsize=fontsize,
328
335
  fontname=font,
329
- color=fontcolor_rgba,
336
+ color=fontcolor_rgba[idx],
330
337
  arrowprops=dict(
331
338
  arrowstyle=arrow_style,
332
- color=arrow_color_rgba,
339
+ color=arrow_color_rgba[idx],
333
340
  linewidth=arrow_linewidth,
334
341
  shrinkA=arrow_base_shrink,
335
342
  shrinkB=arrow_tip_shrink,
@@ -630,7 +637,7 @@ class Labels:
630
637
  def get_annotated_label_colors(
631
638
  self,
632
639
  cmap: str = "gist_rainbow",
633
- color: Union[str, None] = None,
640
+ color: Union[str, list, tuple, np.ndarray, None] = None,
634
641
  min_scale: float = 0.8,
635
642
  max_scale: float = 1.0,
636
643
  scale_factor: float = 1.0,
@@ -640,7 +647,8 @@ class Labels:
640
647
 
641
648
  Args:
642
649
  cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
643
- color (str or None, optional): Color to use for the labels. If None, the colormap will be used. Defaults to None.
650
+ color (str, list, tuple, np.ndarray, or None, optional): Color to use for the labels. Can be a single color or an array
651
+ of colors. If None, the colormap will be used. Defaults to None.
644
652
  min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
645
653
  Controls the dimmest colors. Defaults to 0.8.
646
654
  max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
@@ -10,7 +10,7 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils import to_rgba
13
+ from risk.network.plot.utils.color import get_domain_colors, to_rgba
14
14
 
15
15
 
16
16
  class Network:
@@ -47,8 +47,10 @@ class Network:
47
47
  edge_width (float, optional): Width of the edges. Defaults to 1.0.
48
48
  node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors.
49
49
  Defaults to "white".
50
- node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
51
- edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
50
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Can be a single color or an array of colors.
51
+ Defaults to "black".
52
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Can be a single color or an array of colors.
53
+ Defaults to "black".
52
54
  node_alpha (float, None, optional): Alpha value (transparency) for the nodes. If provided, it overrides any existing alpha
53
55
  values found in node_color. Defaults to 1.0. Annotated node_color alphas will override this value.
54
56
  edge_alpha (float, None, optional): Alpha value (transparency) for the edges. If provided, it overrides any existing alpha
@@ -194,12 +196,12 @@ class Network:
194
196
  def get_annotated_node_colors(
195
197
  self,
196
198
  cmap: str = "gist_rainbow",
197
- color: Union[str, None] = None,
199
+ color: Union[str, list, tuple, np.ndarray, None] = None,
198
200
  min_scale: float = 0.8,
199
201
  max_scale: float = 1.0,
200
202
  scale_factor: float = 1.0,
201
203
  alpha: Union[float, None] = 1.0,
202
- nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
204
+ nonenriched_color: Union[str, list, tuple, np.ndarray] = "white",
203
205
  nonenriched_alpha: Union[float, None] = 1.0,
204
206
  random_seed: int = 888,
205
207
  ) -> np.ndarray:
@@ -207,22 +209,25 @@ class Network:
207
209
 
208
210
  Args:
209
211
  cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
210
- color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
212
+ color (str, list, tuple, np.ndarray, or None, optional): Color to use for the nodes. Can be a single color or an array of colors.
213
+ If None, the colormap will be used. Defaults to None.
211
214
  min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
212
215
  max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
213
216
  scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
214
- alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values
215
- found in color. Defaults to 1.0.
216
- nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
217
- nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing
218
- alpha values found in nonenriched_color. Defaults to 1.0.
217
+ alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values found in `color`.
218
+ Defaults to 1.0.
219
+ nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Can be a single color or an array of colors.
220
+ Defaults to "white".
221
+ nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing alpha values found
222
+ in `nonenriched_color`. Defaults to 1.0.
219
223
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
220
224
 
221
225
  Returns:
222
226
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
223
227
  """
224
228
  # Get the initial domain colors for each node, which are returned as RGBA
225
- network_colors = self.graph.get_domain_colors(
229
+ network_colors = get_domain_colors(
230
+ graph=self.graph,
226
231
  cmap=cmap,
227
232
  color=color,
228
233
  min_scale=min_scale,
@@ -233,7 +238,9 @@ class Network:
233
238
  # Apply the alpha value for enriched nodes
234
239
  network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
235
240
  # Convert the non-enriched color to RGBA using the to_rgba helper function
236
- nonenriched_color = to_rgba(color=nonenriched_color, alpha=nonenriched_alpha)
241
+ nonenriched_color = to_rgba(
242
+ color=nonenriched_color, alpha=nonenriched_alpha, num_repeats=1
243
+ ) # num_repeats=1 for a single color
237
244
  # Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
238
245
  # 0.1 is a predefined threshold for the minimum color intensity
239
246
  adjusted_network_colors = np.where(
@@ -14,7 +14,8 @@ from risk.network.plot.canvas import Canvas
14
14
  from risk.network.plot.contour import Contour
15
15
  from risk.network.plot.labels import Labels
16
16
  from risk.network.plot.network import Network
17
- from risk.network.plot.utils import calculate_bounding_box, to_rgba
17
+ from risk.network.plot.utils.color import to_rgba
18
+ from risk.network.plot.utils.layout import calculate_bounding_box
18
19
 
19
20
 
20
21
  class NetworkPlotter(Canvas, Network, Contour, Labels):
@@ -58,7 +59,7 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
58
59
  self,
59
60
  graph: NetworkGraph,
60
61
  figsize: Tuple,
61
- background_color: Union[str, List, Tuple, np.ndarray],
62
+ background_color: Union[str, list, tuple, np.ndarray],
62
63
  background_alpha: Union[float, None],
63
64
  pad: float,
64
65
  ) -> plt.Axes:
@@ -67,9 +68,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
67
68
  Args:
68
69
  graph (NetworkGraph): The network data and attributes to be visualized.
69
70
  figsize (tuple): Size of the figure in inches (width, height).
70
- background_color (str): Background color of the plot.
71
- background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any
72
- existing alpha values found in background_color.
71
+ background_color (str, list, tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
72
+ background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any existing
73
+ alpha values found in `background_color`.
73
74
  pad (float, optional): Padding value to adjust the axis limits.
74
75
 
75
76
  Returns:
@@ -98,7 +99,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
98
99
 
99
100
  # Set the background color of the plot
100
101
  # Convert color to RGBA using the to_rgba helper function
101
- fig.patch.set_facecolor(to_rgba(color=background_color, alpha=background_alpha))
102
+ fig.patch.set_facecolor(
103
+ to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
104
+ ) # num_repeats=1 for single color
102
105
  ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
103
106
  # Remove axis spines for a cleaner look
104
107
  for spine in ax.spines.values():