risk-network 0.0.8b15__tar.gz → 0.0.8b17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/PKG-INFO +1 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/__init__.py +1 -1
- risk_network-0.0.8b17/risk/network/graph.py +159 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/canvas.py +15 -8
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/contour.py +11 -9
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/labels.py +18 -10
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/network.py +20 -13
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/plotter.py +9 -6
- risk_network-0.0.8b17/risk/network/plot/utils/color.py +353 -0
- risk_network-0.0.8b17/risk/network/plot/utils/layout.py +53 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/PKG-INFO +1 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/SOURCES.txt +2 -1
- risk_network-0.0.8b15/risk/network/graph.py +0 -393
- risk_network-0.0.8b15/risk/network/plot/utils.py +0 -153
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/LICENSE +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/MANIFEST.in +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/README.md +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/pyproject.toml +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/annotations/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/annotations/annotations.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/annotations/io.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/constants.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/log/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/log/config.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/log/params.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/community.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/domains.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/neighborhoods/neighborhoods.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/geometry.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/io.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/network/plot/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/risk.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/hypergeom.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/permutation/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/permutation/permutation.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/permutation/test_functions.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/poisson.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk/stats/stats.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/dependency_links.txt +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/requires.txt +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/risk_network.egg-info/top_level.txt +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/setup.cfg +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b17}/setup.py +0 -0
@@ -0,0 +1,159 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/graph
|
3
|
+
~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from collections import defaultdict
|
7
|
+
from typing import Any, Dict, List
|
8
|
+
|
9
|
+
import networkx as nx
|
10
|
+
import numpy as np
|
11
|
+
import pandas as pd
|
12
|
+
|
13
|
+
|
14
|
+
class NetworkGraph:
    """A class to represent a network graph and process its nodes and edges.

    On construction, NetworkGraph builds lookup maps between domain IDs, node IDs, node labels,
    domain terms, and node enrichment sums. It then unfolds any 3D node coordinates onto a 2D
    plane and caches the resulting node coordinates as a NumPy array.
    """

    def __init__(
        self,
        network: nx.Graph,
        top_annotations: pd.DataFrame,
        domains: pd.DataFrame,
        trimmed_domains: pd.DataFrame,
        node_label_to_node_id_map: Dict[str, Any],
        node_enrichment_sums: np.ndarray,
    ):
        """Initialize the NetworkGraph object.

        Args:
            network (nx.Graph): The network graph. Its node attributes are modified in place for
                nodes that carry a 'z' coordinate (see `_unfold_sphere_to_plane`).
            top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
            domains (pd.DataFrame): DataFrame containing domain data for the network nodes, indexed
                by node ID and with a 'primary domain' column.
            trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network
                nodes, indexed by domain ID and with a 'label' column.
            node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
            node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
        """
        self.top_annotations = top_annotations
        self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
        self.domains = domains
        self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
            trimmed_domains
        )
        self.node_enrichment_sums = node_enrichment_sums
        # Invert the label -> ID map; if several labels share an ID, the last one wins
        self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
        # NOTE(review): labels are paired with enrichment sums positionally, i.e. this assumes the
        # map's insertion order matches the order of `node_enrichment_sums` — confirm with caller.
        self.node_label_to_enrichment_map = dict(
            zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
        )
        self.node_label_to_node_id_map = node_label_to_node_id_map
        # NOTE: Below this point, instance attributes (i.e., self) will be used!
        self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
        # Unfold the network's 3D coordinates to 2D and extract node coordinates
        self.network = _unfold_sphere_to_plane(network)
        self.node_coordinates = _extract_node_coordinates(self.network)

    def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
        """Create a mapping from domains to the list of node IDs belonging to each domain.

        Args:
            domains (pd.DataFrame): DataFrame containing domain information, indexed by node ID and
                including the 'primary domain' for each node.

        Returns:
            dict: A `defaultdict(list)` where keys are domain IDs and values are lists of node IDs
                belonging to each domain, in the DataFrame's row order.
        """
        # Take node IDs straight from the index. Going through reset_index() only works when the
        # index is unnamed (otherwise no "index" column is produced), so avoid it.
        node_to_domains_map = domains["primary domain"].to_dict()
        domain_id_to_node_ids_map = defaultdict(list)
        for node_id, domain_id in node_to_domains_map.items():
            domain_id_to_node_ids_map[domain_id].append(node_id)

        return domain_id_to_node_ids_map

    def _create_domain_id_to_domain_terms_map(
        self, trimmed_domains: pd.DataFrame
    ) -> Dict[str, Any]:
        """Create a mapping from domain IDs to their corresponding terms.

        Args:
            trimmed_domains (pd.DataFrame): DataFrame indexed by domain ID with a 'label' column.

        Returns:
            dict: A dictionary mapping domain IDs to their corresponding terms.
        """
        return dict(
            zip(
                trimmed_domains.index,
                trimmed_domains["label"],
            )
        )

    def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
        """Create a map from domain IDs to node labels.

        Requires `domain_id_to_node_ids_map` and `node_id_to_node_label_map` to be set already.

        Returns:
            dict: A dictionary mapping domain IDs to the corresponding node labels.
        """
        domain_id_to_label_map = {}
        for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
            domain_id_to_label_map[domain_id] = [
                self.node_id_to_node_label_map[node_id] for node_id in node_ids
            ]

        return domain_id_to_label_map
|
108
|
+
|
109
|
+
|
110
|
+
def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
    """Convert 3D coordinates to 2D by unfolding a sphere to a plane.

    Only nodes that carry a 'z' attribute are converted; other nodes keep their existing 'x'/'y'.
    The graph is modified in place and also returned.

    Args:
        G (nx.Graph): A network graph with 3D coordinates. Each node should have 'x', 'y', and 'z' attributes.

    Returns:
        nx.Graph: The same graph with updated 2D coordinates (only 'x' and 'y'). Converted nodes end
            up with x in [0, 1] and y in [-1, 0].
    """
    for node in G.nodes():
        attrs = G.nodes[node]
        if "z" not in attrs:
            continue

        # Extract 3D coordinates
        x, y, z = attrs["x"], attrs["y"], attrs["z"]
        # Calculate spherical coordinates theta and phi from Cartesian coordinates
        r = np.sqrt(x**2 + y**2 + z**2)
        theta = np.arctan2(y, x)
        # A node at the origin has no defined polar angle (z / r would be 0/0 -> NaN); place it on
        # the equator instead. Clipping keeps arccos in its domain despite rounding error.
        phi = np.arccos(np.clip(z / r, -1.0, 1.0)) if r > 0 else np.pi / 2

        # Convert spherical coordinates to 2D plane coordinates
        unfolded_x = (theta + np.pi) / (2 * np.pi)  # Shift and normalize theta to [0, 1]
        # Rotate by half a turn (wrapping around) so the seam falls at theta = 0 instead of +/-pi
        unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
        unfolded_y = (np.pi - phi) / np.pi  # Reflect phi and normalize to [0, 1]

        # Update network node attributes; y is stored negated
        attrs["x"] = unfolded_x
        attrs["y"] = -unfolded_y
        # Remove the 'z' coordinate as it's no longer needed
        del attrs["z"]

    return G
|
139
|
+
|
140
|
+
|
141
|
+
def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
    """Extract 2D coordinates of nodes from the graph.

    Rows follow the graph's node iteration order.

    Args:
        G (nx.Graph): The network graph with 'x' and 'y' node attributes.

    Returns:
        np.ndarray: Array of node coordinates with shape (num_nodes, 2). A graph with no nodes
            yields an empty array of shape (0, 2).
    """
    # Extract x and y coordinates from graph nodes
    x_coords = dict(G.nodes.data("x"))
    y_coords = dict(G.nodes.data("y"))
    if not x_coords:
        # np.vstack raises ValueError on an empty sequence; return a well-shaped empty array
        return np.empty((0, 2))

    # Combine x and y coordinates into a single array, one row per node
    return np.vstack([np.array([x_coords[node], y_coords[node]]) for node in x_coords])
|
@@ -10,7 +10,8 @@ import numpy as np
|
|
10
10
|
|
11
11
|
from risk.log import params
|
12
12
|
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.plot.utils import
|
13
|
+
from risk.network.plot.utils.color import to_rgba
|
14
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
14
15
|
|
15
16
|
|
16
17
|
class Canvas:
|
@@ -33,8 +34,8 @@ class Canvas:
|
|
33
34
|
title_fontsize: int = 20,
|
34
35
|
subtitle_fontsize: int = 14,
|
35
36
|
font: str = "Arial",
|
36
|
-
title_color: str = "black",
|
37
|
-
subtitle_color: str = "gray",
|
37
|
+
title_color: Union[str, list, tuple, np.ndarray] = "black",
|
38
|
+
subtitle_color: Union[str, list, tuple, np.ndarray] = "gray",
|
38
39
|
title_y: float = 0.975,
|
39
40
|
title_space_offset: float = 0.075,
|
40
41
|
subtitle_offset: float = 0.025,
|
@@ -47,8 +48,10 @@ class Canvas:
|
|
47
48
|
title_fontsize (int, optional): Font size for the title. Defaults to 20.
|
48
49
|
subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
|
49
50
|
font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
|
50
|
-
title_color (str, optional): Color of the title text.
|
51
|
-
|
51
|
+
title_color (str, list, tuple, or np.ndarray, optional): Color of the title text. Can be a string or an array of colors.
|
52
|
+
Defaults to "black".
|
53
|
+
subtitle_color (str, list, tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
|
54
|
+
Defaults to "gray".
|
52
55
|
title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
|
53
56
|
title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
|
54
57
|
subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
|
@@ -141,7 +144,9 @@ class Canvas:
|
|
141
144
|
)
|
142
145
|
|
143
146
|
# Convert color to RGBA using the to_rgba helper function - use outline_alpha for the perimeter
|
144
|
-
color = to_rgba(
|
147
|
+
color = to_rgba(
|
148
|
+
color=color, alpha=outline_alpha, num_repeats=1
|
149
|
+
) # num_repeats=1 for a single color
|
145
150
|
# Set the fill_alpha to 0 if not provided
|
146
151
|
fill_alpha = fill_alpha if fill_alpha is not None else 0.0
|
147
152
|
# Extract node coordinates from the network graph
|
@@ -162,7 +167,9 @@ class Canvas:
|
|
162
167
|
)
|
163
168
|
# Set the transparency of the fill if applicable
|
164
169
|
if fill_alpha > 0:
|
165
|
-
circle.set_facecolor(
|
170
|
+
circle.set_facecolor(
|
171
|
+
to_rgba(color=color, alpha=fill_alpha, num_repeats=1)
|
172
|
+
) # num_repeats=1 for a single color
|
166
173
|
|
167
174
|
self.ax.add_artist(circle)
|
168
175
|
|
@@ -209,7 +216,7 @@ class Canvas:
|
|
209
216
|
)
|
210
217
|
|
211
218
|
# Convert color to RGBA using outline_alpha for the line (outline)
|
212
|
-
outline_color = to_rgba(color=color)
|
219
|
+
outline_color = to_rgba(color=color, num_repeats=1) # num_repeats=1 for a single color
|
213
220
|
# Extract node coordinates from the network graph
|
214
221
|
node_coordinates = self.graph.node_coordinates
|
215
222
|
# Scale the node coordinates if needed
|
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
|
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
15
|
from risk.network.graph import NetworkGraph
|
16
|
-
from risk.network.plot.utils import get_annotated_domain_colors, to_rgba
|
16
|
+
from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
@@ -110,7 +110,7 @@ class Contour:
|
|
110
110
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
111
111
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
112
112
|
color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array.
|
113
|
-
Defaults to "white".
|
113
|
+
Can be a single color or an array of colors. Defaults to "white".
|
114
114
|
linestyle (str, optional): Line style for the contour. Defaults to "solid".
|
115
115
|
linewidth (float, optional): Line width for the contour. Defaults to 1.5.
|
116
116
|
alpha (float, None, optional): Transparency level of the contour lines. If provided, it overrides any existing alpha values
|
@@ -125,15 +125,16 @@ class Contour:
|
|
125
125
|
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
126
126
|
# If it's a list of lists, iterate over sublists
|
127
127
|
node_groups = nodes
|
128
|
+
# Convert color to RGBA arrays to match the number of groups
|
129
|
+
color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=len(node_groups))
|
128
130
|
else:
|
129
131
|
# If it's a flat list of nodes, treat it as a single group
|
130
132
|
node_groups = [nodes]
|
131
|
-
|
132
|
-
|
133
|
-
color_rgba = to_rgba(color=color, alpha=alpha)
|
133
|
+
# Wrap the RGBA color in an array to index the first element
|
134
|
+
color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=1)
|
134
135
|
|
135
136
|
# Iterate over each group of nodes (either sublists or flat list)
|
136
|
-
for sublist in node_groups:
|
137
|
+
for idx, sublist in enumerate(node_groups):
|
137
138
|
# Filter to get node IDs and their coordinates for each sublist
|
138
139
|
node_ids = [
|
139
140
|
self.graph.node_label_to_node_id_map.get(node)
|
@@ -151,7 +152,7 @@ class Contour:
|
|
151
152
|
self.ax,
|
152
153
|
node_coordinates,
|
153
154
|
node_ids,
|
154
|
-
color=color_rgba,
|
155
|
+
color=color_rgba[idx],
|
155
156
|
levels=levels,
|
156
157
|
bandwidth=bandwidth,
|
157
158
|
grid_size=grid_size,
|
@@ -273,7 +274,7 @@ class Contour:
|
|
273
274
|
def get_annotated_contour_colors(
|
274
275
|
self,
|
275
276
|
cmap: str = "gist_rainbow",
|
276
|
-
color: Union[str, None] = None,
|
277
|
+
color: Union[str, list, tuple, np.ndarray, None] = None,
|
277
278
|
min_scale: float = 0.8,
|
278
279
|
max_scale: float = 1.0,
|
279
280
|
scale_factor: float = 1.0,
|
@@ -283,7 +284,8 @@ class Contour:
|
|
283
284
|
|
284
285
|
Args:
|
285
286
|
cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
|
286
|
-
color (str or None, optional): Color to use for the contours.
|
287
|
+
color (str, list, tuple, np.ndarray, or None, optional): Color to use for the contours. Can be a single color or an array of colors.
|
288
|
+
If None, the colormap will be used. Defaults to None.
|
287
289
|
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
288
290
|
Controls the dimmest colors. Defaults to 0.8.
|
289
291
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
@@ -11,7 +11,8 @@ import pandas as pd
|
|
11
11
|
|
12
12
|
from risk.log import params
|
13
13
|
from risk.network.graph import NetworkGraph
|
14
|
-
from risk.network.plot.utils import
|
14
|
+
from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
|
15
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
15
16
|
|
16
17
|
TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
|
17
18
|
|
@@ -284,13 +285,19 @@ class Labels:
|
|
284
285
|
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
285
286
|
# If it's a list of lists, iterate over sublists
|
286
287
|
node_groups = nodes
|
288
|
+
# Convert fontcolor and arrow_color to RGBA arrays to match the number of groups
|
289
|
+
fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=len(node_groups))
|
290
|
+
arrow_color_rgba = to_rgba(
|
291
|
+
color=arrow_color, alpha=arrow_alpha, num_repeats=len(node_groups)
|
292
|
+
)
|
287
293
|
else:
|
288
294
|
# If it's a flat list of nodes, treat it as a single group
|
289
295
|
node_groups = [nodes]
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
296
|
+
# Wrap the RGBA fontcolor and arrow_color in an array to index the first element
|
297
|
+
fontcolor_rgba = np.array(to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=1))
|
298
|
+
arrow_color_rgba = np.array(
|
299
|
+
to_rgba(color=arrow_color, alpha=arrow_alpha, num_repeats=1)
|
300
|
+
)
|
294
301
|
|
295
302
|
# Calculate the bounding box around the network
|
296
303
|
center, radius = calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
@@ -302,7 +309,7 @@ class Labels:
|
|
302
309
|
)
|
303
310
|
|
304
311
|
# Iterate over each group of nodes (either sublists or flat list)
|
305
|
-
for sublist in node_groups:
|
312
|
+
for idx, sublist in enumerate(node_groups):
|
306
313
|
# Map node labels to IDs
|
307
314
|
node_ids = [
|
308
315
|
self.graph.node_label_to_node_id_map.get(node)
|
@@ -326,10 +333,10 @@ class Labels:
|
|
326
333
|
va="center",
|
327
334
|
fontsize=fontsize,
|
328
335
|
fontname=font,
|
329
|
-
color=fontcolor_rgba,
|
336
|
+
color=fontcolor_rgba[idx],
|
330
337
|
arrowprops=dict(
|
331
338
|
arrowstyle=arrow_style,
|
332
|
-
color=arrow_color_rgba,
|
339
|
+
color=arrow_color_rgba[idx],
|
333
340
|
linewidth=arrow_linewidth,
|
334
341
|
shrinkA=arrow_base_shrink,
|
335
342
|
shrinkB=arrow_tip_shrink,
|
@@ -630,7 +637,7 @@ class Labels:
|
|
630
637
|
def get_annotated_label_colors(
|
631
638
|
self,
|
632
639
|
cmap: str = "gist_rainbow",
|
633
|
-
color: Union[str, None] = None,
|
640
|
+
color: Union[str, list, tuple, np.ndarray, None] = None,
|
634
641
|
min_scale: float = 0.8,
|
635
642
|
max_scale: float = 1.0,
|
636
643
|
scale_factor: float = 1.0,
|
@@ -640,7 +647,8 @@ class Labels:
|
|
640
647
|
|
641
648
|
Args:
|
642
649
|
cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
|
643
|
-
color (str or None, optional): Color to use for the labels.
|
650
|
+
color (str, list, tuple, np.ndarray, or None, optional): Color to use for the labels. Can be a single color or an array
|
651
|
+
of colors. If None, the colormap will be used. Defaults to None.
|
644
652
|
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
645
653
|
Controls the dimmest colors. Defaults to 0.8.
|
646
654
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
@@ -10,7 +10,7 @@ import numpy as np
|
|
10
10
|
|
11
11
|
from risk.log import params
|
12
12
|
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.plot.utils import to_rgba
|
13
|
+
from risk.network.plot.utils.color import get_domain_colors, to_rgba
|
14
14
|
|
15
15
|
|
16
16
|
class Network:
|
@@ -47,8 +47,10 @@ class Network:
|
|
47
47
|
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
48
48
|
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors.
|
49
49
|
Defaults to "white".
|
50
|
-
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges.
|
51
|
-
|
50
|
+
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Can be a single color or an array of colors.
|
51
|
+
Defaults to "black".
|
52
|
+
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Can be a single color or an array of colors.
|
53
|
+
Defaults to "black".
|
52
54
|
node_alpha (float, None, optional): Alpha value (transparency) for the nodes. If provided, it overrides any existing alpha
|
53
55
|
values found in node_color. Defaults to 1.0. Annotated node_color alphas will override this value.
|
54
56
|
edge_alpha (float, None, optional): Alpha value (transparency) for the edges. If provided, it overrides any existing alpha
|
@@ -194,12 +196,12 @@ class Network:
|
|
194
196
|
def get_annotated_node_colors(
|
195
197
|
self,
|
196
198
|
cmap: str = "gist_rainbow",
|
197
|
-
color: Union[str, None] = None,
|
199
|
+
color: Union[str, list, tuple, np.ndarray, None] = None,
|
198
200
|
min_scale: float = 0.8,
|
199
201
|
max_scale: float = 1.0,
|
200
202
|
scale_factor: float = 1.0,
|
201
203
|
alpha: Union[float, None] = 1.0,
|
202
|
-
nonenriched_color: Union[str,
|
204
|
+
nonenriched_color: Union[str, list, tuple, np.ndarray] = "white",
|
203
205
|
nonenriched_alpha: Union[float, None] = 1.0,
|
204
206
|
random_seed: int = 888,
|
205
207
|
) -> np.ndarray:
|
@@ -207,22 +209,25 @@ class Network:
|
|
207
209
|
|
208
210
|
Args:
|
209
211
|
cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
|
210
|
-
color (str or None, optional): Color to use for the nodes.
|
212
|
+
color (str, list, tuple, np.ndarray, or None, optional): Color to use for the nodes. Can be a single color or an array of colors.
|
213
|
+
If None, the colormap will be used. Defaults to None.
|
211
214
|
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
212
215
|
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
213
216
|
scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
|
214
|
-
alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values
|
215
|
-
|
216
|
-
nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes.
|
217
|
-
|
218
|
-
|
217
|
+
alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values found in `color`.
|
218
|
+
Defaults to 1.0.
|
219
|
+
nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Can be a single color or an array of colors.
|
220
|
+
Defaults to "white".
|
221
|
+
nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing alpha values found
|
222
|
+
in `nonenriched_color`. Defaults to 1.0.
|
219
223
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
220
224
|
|
221
225
|
Returns:
|
222
226
|
np.ndarray: Array of RGBA colors adjusted for enrichment status.
|
223
227
|
"""
|
224
228
|
# Get the initial domain colors for each node, which are returned as RGBA
|
225
|
-
network_colors =
|
229
|
+
network_colors = get_domain_colors(
|
230
|
+
graph=self.graph,
|
226
231
|
cmap=cmap,
|
227
232
|
color=color,
|
228
233
|
min_scale=min_scale,
|
@@ -233,7 +238,9 @@ class Network:
|
|
233
238
|
# Apply the alpha value for enriched nodes
|
234
239
|
network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
|
235
240
|
# Convert the non-enriched color to RGBA using the to_rgba helper function
|
236
|
-
nonenriched_color = to_rgba(
|
241
|
+
nonenriched_color = to_rgba(
|
242
|
+
color=nonenriched_color, alpha=nonenriched_alpha, num_repeats=1
|
243
|
+
) # num_repeats=1 for a single color
|
237
244
|
# Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
|
238
245
|
# 0.1 is a predefined threshold for the minimum color intensity
|
239
246
|
adjusted_network_colors = np.where(
|
@@ -14,7 +14,8 @@ from risk.network.plot.canvas import Canvas
|
|
14
14
|
from risk.network.plot.contour import Contour
|
15
15
|
from risk.network.plot.labels import Labels
|
16
16
|
from risk.network.plot.network import Network
|
17
|
-
from risk.network.plot.utils import
|
17
|
+
from risk.network.plot.utils.color import to_rgba
|
18
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
18
19
|
|
19
20
|
|
20
21
|
class NetworkPlotter(Canvas, Network, Contour, Labels):
|
@@ -58,7 +59,7 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
|
|
58
59
|
self,
|
59
60
|
graph: NetworkGraph,
|
60
61
|
figsize: Tuple,
|
61
|
-
background_color: Union[str,
|
62
|
+
background_color: Union[str, list, tuple, np.ndarray],
|
62
63
|
background_alpha: Union[float, None],
|
63
64
|
pad: float,
|
64
65
|
) -> plt.Axes:
|
@@ -67,9 +68,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
|
|
67
68
|
Args:
|
68
69
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
69
70
|
figsize (tuple): Size of the figure in inches (width, height).
|
70
|
-
background_color (str): Background color of the plot.
|
71
|
-
background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any
|
72
|
-
|
71
|
+
background_color (str, list, tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
|
72
|
+
background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any existing
|
73
|
+
alpha values found in `background_color`.
|
73
74
|
pad (float, optional): Padding value to adjust the axis limits.
|
74
75
|
|
75
76
|
Returns:
|
@@ -98,7 +99,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
|
|
98
99
|
|
99
100
|
# Set the background color of the plot
|
100
101
|
# Convert color to RGBA using the to_rgba helper function
|
101
|
-
fig.patch.set_facecolor(
|
102
|
+
fig.patch.set_facecolor(
|
103
|
+
to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
|
104
|
+
) # num_repeats=1 for single color
|
102
105
|
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
103
106
|
# Remove axis spines for a cleaner look
|
104
107
|
for spine in ax.spines.values():
|