risk-network 0.0.8b15__tar.gz → 0.0.8b16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/PKG-INFO +1 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/__init__.py +1 -1
- risk_network-0.0.8b16/risk/network/graph.py +159 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/plot/canvas.py +2 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/plot/contour.py +1 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/plot/labels.py +2 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/plot/network.py +3 -2
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/plot/plotter.py +2 -1
- risk_network-0.0.8b16/risk/network/plot/utils/color.py +351 -0
- risk_network-0.0.8b16/risk/network/plot/utils/layout.py +53 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk_network.egg-info/PKG-INFO +1 -1
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk_network.egg-info/SOURCES.txt +2 -1
- risk_network-0.0.8b15/risk/network/graph.py +0 -393
- risk_network-0.0.8b15/risk/network/plot/utils.py +0 -153
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/LICENSE +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/MANIFEST.in +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/README.md +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/pyproject.toml +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/annotations/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/annotations/annotations.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/annotations/io.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/constants.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/log/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/log/config.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/log/params.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/neighborhoods/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/neighborhoods/community.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/neighborhoods/domains.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/neighborhoods/neighborhoods.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/geometry.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/io.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/network/plot/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/risk.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/hypergeom.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/permutation/__init__.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/permutation/permutation.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/permutation/test_functions.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/poisson.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk/stats/stats.py +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk_network.egg-info/dependency_links.txt +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk_network.egg-info/requires.txt +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/risk_network.egg-info/top_level.txt +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/setup.cfg +0 -0
- {risk_network-0.0.8b15 → risk_network-0.0.8b16}/setup.py +0 -0
@@ -0,0 +1,159 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/graph
|
3
|
+
~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from collections import defaultdict
|
7
|
+
from typing import Any, Dict, List
|
8
|
+
|
9
|
+
import networkx as nx
|
10
|
+
import numpy as np
|
11
|
+
import pandas as pd
|
12
|
+
|
13
|
+
|
14
|
+
class NetworkGraph:
    """A network graph bundled with the domain, annotation, and enrichment data of its nodes.

    On construction the class builds lookup tables between domain IDs, node IDs, and node labels,
    unfolds any 3D node coordinates onto a 2D plane, and caches the resulting node coordinates
    as an array.
    """

    def __init__(
        self,
        network: nx.Graph,
        top_annotations: pd.DataFrame,
        domains: pd.DataFrame,
        trimmed_domains: pd.DataFrame,
        node_label_to_node_id_map: Dict[str, Any],
        node_enrichment_sums: np.ndarray,
    ):
        """Initialize the NetworkGraph object.

        Args:
            network (nx.Graph): The network graph. NOTE: it is modified in place when its nodes
                carry a 'z' attribute (see `_unfold_sphere_to_plane`).
            top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
            domains (pd.DataFrame): DataFrame containing domain data for the network nodes. Must have a
                'primary domain' column; its index is used as the node IDs.
            trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data. Must have a 'label'
                column; its index is used as the domain IDs.
            node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
            node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes. It is
                paired with the keys of `node_label_to_node_id_map` in iteration order.
        """
        self.top_annotations = top_annotations
        self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
        self.domains = domains
        self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
            trimmed_domains
        )
        self.node_enrichment_sums = node_enrichment_sums
        self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
        self.node_label_to_enrichment_map = dict(
            zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
        )
        self.node_label_to_node_id_map = node_label_to_node_id_map
        # NOTE: Below this point, instance attributes (i.e., self) will be used!
        self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
        # Unfold the network's 3D coordinates to 2D and extract node coordinates
        self.network = _unfold_sphere_to_plane(network)
        self.node_coordinates = _extract_node_coordinates(self.network)

    def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[Any, List[Any]]:
        """Create a mapping from domains to the list of node IDs belonging to each domain.

        Args:
            domains (pd.DataFrame): DataFrame containing domain information, including the
                'primary domain' for each node. The DataFrame index holds the node IDs.

        Returns:
            dict: A dictionary (a `defaultdict(list)`) where keys are domain IDs and values are
                lists of node IDs belonging to each domain.
        """
        # Read the column directly instead of going through reset_index()/set_index("index"):
        # that round trip only works when the index is unnamed, because reset_index() names the
        # new column after the index.
        node_to_domain_map = domains["primary domain"].to_dict()
        domain_id_to_node_ids_map = defaultdict(list)
        for node_id, domain_id in node_to_domain_map.items():
            domain_id_to_node_ids_map[domain_id].append(node_id)

        return domain_id_to_node_ids_map

    def _create_domain_id_to_domain_terms_map(
        self, trimmed_domains: pd.DataFrame
    ) -> Dict[Any, Any]:
        """Create a mapping from domain IDs to their corresponding terms.

        Args:
            trimmed_domains (pd.DataFrame): DataFrame whose index holds the domain IDs and whose
                'label' column holds the corresponding terms.

        Returns:
            dict: A dictionary mapping domain IDs to their corresponding terms.
        """
        return dict(zip(trimmed_domains.index, trimmed_domains["label"]))

    def _create_domain_id_to_node_labels_map(self) -> Dict[Any, List[str]]:
        """Create a map from domain IDs to node labels.

        Reads `self.domain_id_to_node_ids_map` and `self.node_id_to_node_label_map`, so both
        must be set before this method is called.

        Returns:
            dict: A dictionary mapping domain IDs to the corresponding node labels.
        """
        return {
            domain_id: [self.node_id_to_node_label_map[node_id] for node_id in node_ids]
            for domain_id, node_ids in self.domain_id_to_node_ids_map.items()
        }
|
108
|
+
|
109
|
+
|
110
|
+
def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
    """Convert 3D coordinates to 2D by unfolding a sphere to a plane.

    Nodes without a 'z' attribute are left untouched. The graph is modified in place and
    also returned.

    Args:
        G (nx.Graph): A network graph with 3D coordinates. Each node to be unfolded should have
            'x', 'y', and 'z' attributes.

    Returns:
        nx.Graph: The same graph object with updated 2D coordinates (only 'x' and 'y').
    """
    for node in G.nodes():
        attrs = G.nodes[node]
        if "z" not in attrs:
            continue

        # Extract 3D coordinates
        x, y, z = attrs["x"], attrs["y"], attrs["z"]
        # Calculate spherical coordinates theta and phi from Cartesian coordinates
        r = np.sqrt(x**2 + y**2 + z**2)
        theta = np.arctan2(y, x)
        # A node at the origin has no defined polar angle; place it on the equator instead of
        # dividing by zero. Otherwise clip the ratio, because rounding can push z / r slightly
        # outside [-1, 1], where arccos returns NaN.
        cos_phi = 0.0 if r == 0 else np.clip(z / r, -1.0, 1.0)
        phi = np.arccos(cos_phi)

        # Convert spherical coordinates to 2D plane coordinates
        unfolded_x = (theta + np.pi) / (2 * np.pi)  # Shift and normalize theta to [0, 1]
        # Rotate the seam by half a turn
        unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
        unfolded_y = (np.pi - phi) / np.pi  # Reflect phi and normalize to [0, 1]
        # Update network node attributes (y is stored negated)
        attrs["x"] = unfolded_x
        attrs["y"] = -unfolded_y
        # Remove the 'z' coordinate as it's no longer needed
        del attrs["z"]

    return G
|
139
|
+
|
140
|
+
|
141
|
+
def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
    """Collect the 'x' and 'y' attributes of every node into a single array.

    Args:
        G (nx.Graph): The network graph with node coordinates.

    Returns:
        np.ndarray: One row per node, in node iteration order, holding that node's (x, y) pair.
    """
    # Per-attribute lookups keyed by node
    xs = dict(G.nodes.data("x"))
    ys = dict(G.nodes.data("y"))
    # One (x, y) row per node, stacked vertically
    rows = [np.array([xs[node], ys[node]]) for node in xs]
    return np.vstack(rows)
|
@@ -10,7 +10,8 @@ import numpy as np
|
|
10
10
|
|
11
11
|
from risk.log import params
|
12
12
|
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.plot.utils import
|
13
|
+
from risk.network.plot.utils.color import to_rgba
|
14
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
14
15
|
|
15
16
|
|
16
17
|
class Canvas:
|
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
|
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
15
|
from risk.network.graph import NetworkGraph
|
16
|
-
from risk.network.plot.utils import get_annotated_domain_colors, to_rgba
|
16
|
+
from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
@@ -11,7 +11,8 @@ import pandas as pd
|
|
11
11
|
|
12
12
|
from risk.log import params
|
13
13
|
from risk.network.graph import NetworkGraph
|
14
|
-
from risk.network.plot.utils import
|
14
|
+
from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
|
15
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
15
16
|
|
16
17
|
TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
|
17
18
|
|
@@ -10,7 +10,7 @@ import numpy as np
|
|
10
10
|
|
11
11
|
from risk.log import params
|
12
12
|
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.plot.utils import to_rgba
|
13
|
+
from risk.network.plot.utils.color import get_domain_colors, to_rgba
|
14
14
|
|
15
15
|
|
16
16
|
class Network:
|
@@ -222,7 +222,8 @@ class Network:
|
|
222
222
|
np.ndarray: Array of RGBA colors adjusted for enrichment status.
|
223
223
|
"""
|
224
224
|
# Get the initial domain colors for each node, which are returned as RGBA
|
225
|
-
network_colors =
|
225
|
+
network_colors = get_domain_colors(
|
226
|
+
graph=self.graph,
|
226
227
|
cmap=cmap,
|
227
228
|
color=color,
|
228
229
|
min_scale=min_scale,
|
@@ -14,7 +14,8 @@ from risk.network.plot.canvas import Canvas
|
|
14
14
|
from risk.network.plot.contour import Contour
|
15
15
|
from risk.network.plot.labels import Labels
|
16
16
|
from risk.network.plot.network import Network
|
17
|
-
from risk.network.plot.utils import
|
17
|
+
from risk.network.plot.utils.color import to_rgba
|
18
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
18
19
|
|
19
20
|
|
20
21
|
class NetworkPlotter(Canvas, Network, Contour, Labels):
|
@@ -0,0 +1,351 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plot/utils/color
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Any, Dict, List, Tuple, Union
|
7
|
+
|
8
|
+
import matplotlib
|
9
|
+
import matplotlib.colors as mcolors
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from risk.network.graph import NetworkGraph
|
13
|
+
from risk.network.plot.utils.layout import calculate_centroids
|
14
|
+
|
15
|
+
|
16
|
+
def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain from the per-node domain colors.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
        color (str or None, optional): Color to use for the domains. If None, the colormap will be used.
            Defaults to None.
        min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
            Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
            Defaults to 1.0.
        scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
            enrichment. Higher values increase the contrast. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility.
            Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors, one per domain, in the iteration order of
            `graph.domain_id_to_node_ids_map`.
    """
    # Per-node colors derived from the enrichment data
    node_colors = get_domain_colors(
        graph=graph,
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )

    def _representative_color(node_ids):
        """Return the color that stands for a domain made of `node_ids`."""
        if len(node_ids) <= 1:
            # Domains with at most one node default to white (RGBA)
            return np.array([1.0, 1.0, 1.0, 1.0])
        # Otherwise take the member color with the largest R + G + B total
        candidates = np.array([node_colors[node] for node in node_ids])
        rgb_totals = candidates[:, :3].sum(axis=1)
        return candidates[np.argmax(rgb_totals)]

    return np.array(
        [_representative_color(node_ids) for node_ids in graph.domain_id_to_node_ids_map.values()]
    )
|
67
|
+
|
68
|
+
|
69
|
+
def get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Build per-node RGBA colors from domain membership, scaled by node enrichment.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Name of the colormap to use for generating domain colors.
            Defaults to "gist_rainbow".
        color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
        min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
            Controls the dimmest colors. Defaults to 0.8.
        max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
            Controls the brightest colors. Defaults to 1.0.
        scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores.
            A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility of
            color assignments. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors with one row per node; each node carries the color of its
            domain, with RGB intensity scaled by `graph.node_enrichment_sums`.
    """
    # One color per domain, keyed by domain ID
    colors_by_domain = _get_domain_colors(
        graph=graph, cmap=cmap, color=color, random_seed=random_seed
    )
    # Spread the domain colors over the member nodes
    colors_by_node = _get_composite_node_colors(graph=graph, domain_colors=colors_by_domain)
    # Scale RGB intensity by enrichment
    return _transform_colors(
        colors_by_node,
        graph.node_enrichment_sums,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
    )
|
109
|
+
|
110
|
+
|
111
|
+
def _get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    random_seed: int = 888,
) -> Dict[str, Any]:
    """Assign one RGBA color to every domain of the graph.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        dict: A dictionary mapping domain IDs to their corresponding RGBA colors.
    """
    domain_map = graph.domain_id_to_node_ids_map
    # Colors come back in the iteration order of the domain map
    palette = _get_colors(
        graph.network,
        domain_map,
        cmap=cmap,
        color=color,
        random_seed=random_seed,
    )
    return {domain_id: rgba for domain_id, rgba in zip(domain_map.keys(), palette)}
|
137
|
+
|
138
|
+
|
139
|
+
def _get_composite_node_colors(graph: NetworkGraph, domain_colors: np.ndarray) -> np.ndarray:
|
140
|
+
"""Generate composite colors for nodes based on domain colors and counts.
|
141
|
+
|
142
|
+
Args:
|
143
|
+
graph (NetworkGraph): The network data and attributes to be visualized.
|
144
|
+
domain_colors (np.ndarray): Array of colors corresponding to each domain.
|
145
|
+
|
146
|
+
Returns:
|
147
|
+
np.ndarray: Array of composite colors for each node.
|
148
|
+
"""
|
149
|
+
# Determine the number of nodes
|
150
|
+
num_nodes = len(graph.node_coordinates)
|
151
|
+
# Initialize composite colors array with shape (number of nodes, 4) for RGBA
|
152
|
+
composite_colors = np.zeros((num_nodes, 4))
|
153
|
+
# Assign colors to nodes based on domain_colors
|
154
|
+
for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
|
155
|
+
color = domain_colors[domain_id]
|
156
|
+
for node in nodes:
|
157
|
+
composite_colors[node] = color
|
158
|
+
|
159
|
+
return composite_colors
|
160
|
+
|
161
|
+
|
162
|
+
def _get_colors(
    network,
    domain_id_to_node_ids_map,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    random_seed: int = 888,
) -> List[Tuple]:
    """Produce one RGBA color per domain, either a fixed color or colormap samples
    positioned from the domain centroids plus a small random palette shift.

    Args:
        network (NetworkX graph): The graph representing the network.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str or None, optional): A specific color to use for all generated colors.
        random_seed (int): Seed for random number generation. Defaults to 888.

    Returns:
        List[Tuple]: List of RGBA colors, one per domain. When `color` is given, the result
            of `to_rgba` is returned instead.
    """
    # Seed NumPy's global generator so the palette shift below is reproducible
    np.random.seed(random_seed)
    num_domains = len(domain_id_to_node_ids_map)
    if color:
        # Same color for every domain
        return to_rgba(color, num_repeats=num_domains)

    palette = matplotlib.colormaps.get_cmap(cmap)
    # Pairwise Euclidean distances between domain centroids
    centroid_array = np.array(calculate_centroids(network, domain_id_to_node_ids_map))
    pairwise_distances = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
    # Positions in [0, 1] along the colormap
    positions = _assign_distant_colors(pairwise_distances, num_domains)
    # Shift the whole palette by a small random offset, wrapping around, then clamp to [0, 1]
    palette_offset = np.random.uniform(-0.1, 0.1)
    positions = np.clip((positions + palette_offset) % 1, 0, 1)

    return [palette(position) for position in positions]
|
208
|
+
|
209
|
+
|
210
|
+
def _assign_distant_colors(dist_matrix, num_colors_to_generate):
|
211
|
+
"""Assign colors to centroids that are close in space, ensuring stark color differences.
|
212
|
+
|
213
|
+
Args:
|
214
|
+
dist_matrix (ndarray): Matrix of pairwise centroid distances.
|
215
|
+
num_colors_to_generate (int): Number of colors to generate.
|
216
|
+
|
217
|
+
Returns:
|
218
|
+
np.array: Array of color positions in the range [0, 1].
|
219
|
+
"""
|
220
|
+
color_positions = np.zeros(num_colors_to_generate)
|
221
|
+
# Step 1: Sort indices by centroid proximity (based on sum of distances to others)
|
222
|
+
proximity_order = sorted(
|
223
|
+
range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
|
224
|
+
)
|
225
|
+
# Step 2: Assign colors starting with the most distant points in proximity order
|
226
|
+
for i, idx in enumerate(proximity_order):
|
227
|
+
color_positions[idx] = i / num_colors_to_generate
|
228
|
+
|
229
|
+
# Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
|
230
|
+
half_spectrum = int(num_colors_to_generate / 2)
|
231
|
+
for i in range(half_spectrum):
|
232
|
+
# Split the spectrum so that close centroids are assigned distant colors
|
233
|
+
color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
|
234
|
+
color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
|
235
|
+
|
236
|
+
return color_positions
|
237
|
+
|
238
|
+
|
239
|
+
def _transform_colors(
|
240
|
+
colors: np.ndarray,
|
241
|
+
enrichment_sums: np.ndarray,
|
242
|
+
min_scale: float = 0.8,
|
243
|
+
max_scale: float = 1.0,
|
244
|
+
scale_factor: float = 1.0,
|
245
|
+
) -> np.ndarray:
|
246
|
+
"""Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
|
247
|
+
very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
|
248
|
+
|
249
|
+
Args:
|
250
|
+
colors (np.ndarray): An array of RGBA colors.
|
251
|
+
enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
|
252
|
+
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
253
|
+
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
254
|
+
scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
|
255
|
+
values more. Defaults to 1.0.
|
256
|
+
|
257
|
+
Returns:
|
258
|
+
np.ndarray: The transformed array of RGBA colors with adjusted intensities.
|
259
|
+
"""
|
260
|
+
# Ensure that min_scale is less than max_scale
|
261
|
+
if min_scale == max_scale:
|
262
|
+
min_scale = max_scale - 10e-6 # Avoid division by zero
|
263
|
+
|
264
|
+
# Replace black colors (#000000) with very dark grey (#1A1A1A)
|
265
|
+
black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
|
266
|
+
dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
|
267
|
+
# Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
|
268
|
+
is_black = np.all(colors[:, :3] == black_color, axis=1)
|
269
|
+
colors[is_black, :3] = dark_grey
|
270
|
+
|
271
|
+
# Normalize the enrichment sums to the range [0, 1]
|
272
|
+
normalized_sums = enrichment_sums / np.max(enrichment_sums)
|
273
|
+
# Apply power scaling to dim lower values and emphasize higher values
|
274
|
+
scaled_sums = normalized_sums**scale_factor
|
275
|
+
# Linearly scale the normalized sums to the range [min_scale, max_scale]
|
276
|
+
scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
|
277
|
+
# Adjust RGB values based on scaled sums
|
278
|
+
for i in range(3): # Only adjust RGB values
|
279
|
+
colors[:, i] = scaled_sums * colors[:, i]
|
280
|
+
|
281
|
+
return colors
|
282
|
+
|
283
|
+
|
284
|
+
def to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Normalize one color or a collection of colors to an array of RGBA rows.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Either a single color
            (a string, or a flat RGB/RGBA sequence) or a sequence/array of such colors.
        alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides
            any existing alpha values found in color.
        num_repeats (int, None, optional): If provided, the color(s) will be repeated (cycling through
            the given colors) until this many rows are produced. Defaults to None.

    Returns:
        np.ndarray: 2D array of RGBA colors, repeated to `num_repeats` rows if applicable.

    Raises:
        ValueError: If `color`, or an element of it, is not a string or an RGB/RGBA sequence.
    """
    sequence_types = (list, tuple, np.ndarray)
    color_types = (str,) + sequence_types

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
        # Accept color names / hex strings, or sequences of exactly 3 or 4 channels
        is_supported = isinstance(c, str) or (isinstance(c, sequence_types) and len(c) in [3, 4])
        if not is_supported:
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )
        # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
        rgba = np.array(mcolors.to_rgba(c))
        if alpha is not None:  # Override alpha if provided
            rgba[3] = alpha
        return rgba

    # A 2D array with four columns is treated as a list of RGBA rows
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(row) for row in color]

    if not isinstance(color, color_types):
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    # A string, or a sequence whose items are all scalars, is one single color
    is_single_color = isinstance(color, str) or not any(
        isinstance(item, color_types) for item in color
    )
    if is_single_color:
        rgba_color = convert_to_rgba(color)
        if num_repeats:
            # Repeat the color if num_repeats is provided
            return np.tile(rgba_color, (num_repeats, 1))
        # Return a single color wrapped in a numpy array
        return np.array([rgba_color])

    # Otherwise convert every item of the collection
    rgba_colors = np.array([convert_to_rgba(item) for item in color])
    if num_repeats:
        # Cycle through the converted colors until num_repeats rows exist
        return np.array([rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)])

    return rgba_colors
|
@@ -0,0 +1,53 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plot/utils/layout
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Tuple
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
|
11
|
+
def calculate_bounding_box(
    node_coordinates: np.ndarray, radius_margin: float = 1.05
) -> Tuple[np.ndarray, float]:
    """Compute the center and padded half-extent of the box enclosing all nodes.

    Args:
        node_coordinates (np.ndarray): Array of node coordinates (x, y).
        radius_margin (float, optional): Margin factor to apply to the bounding box radius.
            Defaults to 1.05.

    Returns:
        tuple: Center of the bounding box and the radius, i.e. half of the larger side of the box
            multiplied by the radius margin.
    """
    # Extremes along each of the two coordinate columns
    lower_x, lower_y = np.min(node_coordinates, axis=0)
    upper_x, upper_y = np.max(node_coordinates, axis=0)
    # Midpoint of the box
    center = np.array([(lower_x + upper_x) / 2, (lower_y + upper_y) / 2])
    # Half of the longer side, widened by the margin
    longest_side = max(upper_x - lower_x, upper_y - lower_y)
    radius = longest_side / 2 * radius_margin
    return center, radius
|
31
|
+
|
32
|
+
|
33
|
+
def calculate_centroids(network, domain_id_to_node_ids_map):
    """Compute, for every domain, the mean (x, y) position of its nodes.

    Args:
        network (NetworkX graph): The graph representing the network. Each referenced node must
            have 'x' and 'y' attributes.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.

    Returns:
        List[Tuple[float, float]]: List of centroids (x, y), one per domain, in the iteration
            order of the mapping.
    """

    def _centroid(node_ids):
        """Return the mean position of the given nodes as an (x, y) tuple."""
        positions = np.array(
            [[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
        )
        # Average over the nodes (rows), keeping x and y separate
        return tuple(np.mean(positions, axis=0))

    return [_centroid(node_ids) for node_ids in domain_id_to_node_ids_map.values()]
|
@@ -26,7 +26,8 @@ risk/network/plot/contour.py
|
|
26
26
|
risk/network/plot/labels.py
|
27
27
|
risk/network/plot/network.py
|
28
28
|
risk/network/plot/plotter.py
|
29
|
-
risk/network/plot/utils.py
|
29
|
+
risk/network/plot/utils/color.py
|
30
|
+
risk/network/plot/utils/layout.py
|
30
31
|
risk/stats/__init__.py
|
31
32
|
risk/stats/hypergeom.py
|
32
33
|
risk/stats/poisson.py
|