risk-network 0.0.8b27__py3-none-any.whl → 0.0.9b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/annotations.py +39 -38
- risk/annotations/io.py +8 -6
- risk/log/__init__.py +3 -1
- risk/log/{params.py → parameters.py} +9 -34
- risk/neighborhoods/domains.py +18 -18
- risk/neighborhoods/neighborhoods.py +104 -92
- risk/network/graph/__init__.py +6 -0
- risk/network/{graph.py → graph/network.py} +38 -27
- risk/network/graph/summary.py +239 -0
- risk/network/io.py +3 -3
- risk/network/plot/contour.py +1 -1
- risk/network/plot/labels.py +1 -1
- risk/network/plot/network.py +28 -28
- risk/network/plot/utils/color.py +27 -27
- risk/risk.py +25 -30
- risk/stats/stats.py +13 -13
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9b2.dist-info}/METADATA +1 -1
- risk_network-0.0.9b2.dist-info/RECORD +39 -0
- risk_network-0.0.8b27.dist-info/RECORD +0 -37
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9b2.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9b2.dist-info}/WHEEL +0 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9b2.dist-info}/top_level.txt +0 -0
risk/network/graph/summary.py
ADDED
@@ -0,0 +1,239 @@
"""
risk/network/graph/summary
~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

import warnings
from functools import lru_cache
from typing import Any, Dict, Tuple, Union

import numpy as np
import pandas as pd
from statsmodels.stats.multitest import fdrcorrection

from risk.log.console import logger, log_header


# Suppress all warnings - this is to resolve warnings from multiprocessing
warnings.filterwarnings("ignore")


class Summary:
    """Handles the processing, storage, and export of network analysis results.

    The Summary class provides methods to process significance and depletion data, compute
    FDR-corrected q-values, and structure information on domains and annotations into a
    DataFrame. It also offers functionality to export the processed data in CSV, JSON,
    and text formats for analysis and reporting.
    """

    def __init__(
        self,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        graph,  # Avoid type hinting NetworkGraph to avoid circular import
    ):
        """Initialize the Summary object with analysis components.

        Args:
            annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
            neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
            graph (NetworkGraph): Graph object representing domain-to-node and node-to-label mappings.
        """
        self.annotations = annotations
        self.neighborhoods = neighborhoods
        self.graph = graph

    def to_csv(self, filepath: str) -> None:
        """Export significance results to a CSV file.

        Args:
            filepath (str): The path where the CSV file will be saved.
        """
        # Load results and export directly to CSV
        results = self.load()
        results.to_csv(filepath, index=False)
        logger.info(f"Results summary exported to CSV file: {filepath}")

    def to_json(self, filepath: str) -> None:
        """Export significance results to a JSON file.

        Args:
            filepath (str): The path where the JSON file will be saved.
        """
        # Load results and export directly to JSON
        results = self.load()
        results.to_json(filepath, orient="records", indent=4)
        logger.info(f"Results summary exported to JSON file: {filepath}")

    def to_txt(self, filepath: str) -> None:
        """Export significance results to a text file.

        Args:
            filepath (str): The path where the text file will be saved.
        """
        # Load results and export directly to text file
        results = self.load()
        with open(filepath, "w") as txt_file:
            txt_file.write(results.to_string(index=False))

        logger.info(f"Results summary exported to text file: {filepath}")

    @lru_cache(maxsize=None)
    def load(self) -> pd.DataFrame:
        """Load and process domain and annotation data into a DataFrame with significance metrics.

        Returns:
            pd.DataFrame: Processed DataFrame containing significance scores, p-values, q-values,
                and annotation member information.
        """
        log_header("Loading parameters")
        # Calculate enrichment and depletion q-values from p-value matrices in `neighborhoods`
        enrichment_pvals = self.neighborhoods["enrichment_pvals"]
        depletion_pvals = self.neighborhoods["depletion_pvals"]
        enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
        depletion_qvals = self._calculate_qvalues(depletion_pvals)

        # Initialize DataFrame with domain and annotation details
        results = pd.DataFrame(
            [
                {"Domain ID": domain_id, "Annotation": desc, "Summed Significance Score": score}
                for domain_id, info in self.graph.domain_id_to_domain_info_map.items()
                for desc, score in zip(info["full_descriptions"], info["significance_scores"])
            ]
        )
        # Sort by Domain ID and Summed Significance Score
        results = results.sort_values(
            by=["Domain ID", "Summed Significance Score"], ascending=[True, False]
        ).reset_index(drop=True)

        # Add minimum p-values and q-values to DataFrame
        results[
            [
                "Enrichment P-value",
                "Enrichment Q-value",
                "Depletion P-value",
                "Depletion Q-value",
            ]
        ] = results.apply(
            lambda row: self._get_significance_values(
                row["Domain ID"],
                row["Annotation"],
                enrichment_pvals,
                depletion_pvals,
                enrichment_qvals,
                depletion_qvals,
            ),
            axis=1,
            result_type="expand",
        )
        # Add annotation members and their counts
        results["Annotation Members"] = results["Annotation"].apply(
            lambda desc: self._get_annotation_members(desc)
        )
        results["Annotation Member Count"] = results["Annotation Members"].apply(
            lambda x: len(x.split(";")) if x else 0
        )

        # Reorder columns and drop rows with NaN values
        results = (
            results[
                [
                    "Domain ID",
                    "Annotation",
                    "Annotation Members",
                    "Annotation Member Count",
                    "Summed Significance Score",
                    "Enrichment P-value",
                    "Enrichment Q-value",
                    "Depletion P-value",
                    "Depletion Q-value",
                ]
            ]
            .dropna()
            .reset_index(drop=True)
        )

        return results

    @staticmethod
    def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
        """Calculate q-values (FDR) for each row of a p-value matrix.

        Args:
            pvals (np.ndarray): 2D array of p-values.

        Returns:
            np.ndarray: 2D array of q-values, with FDR correction applied row-wise.
        """
        return np.apply_along_axis(lambda row: fdrcorrection(row)[1], 1, pvals)

    def _get_significance_values(
        self,
        domain_id: int,
        description: str,
        enrichment_pvals: np.ndarray,
        depletion_pvals: np.ndarray,
        enrichment_qvals: np.ndarray,
        depletion_qvals: np.ndarray,
    ) -> Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
        """Retrieve the most significant p-values and q-values (FDR) for a given annotation.

        Args:
            domain_id (int): The domain ID associated with the annotation.
            description (str): The annotation description.
            enrichment_pvals (np.ndarray): Matrix of enrichment p-values.
            depletion_pvals (np.ndarray): Matrix of depletion p-values.
            enrichment_qvals (np.ndarray): Matrix of enrichment q-values.
            depletion_qvals (np.ndarray): Matrix of depletion q-values.

        Returns:
            Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
                Minimum enrichment p-value, enrichment q-value, depletion p-value, and depletion q-value.
        """
        try:
            annotation_idx = self.annotations["ordered_annotations"].index(description)
        except ValueError:
            return None, None, None, None  # Description not found

        node_indices = self.graph.domain_id_to_node_ids_map.get(domain_id, [])
        if not node_indices:
            return None, None, None, None  # No associated nodes

        sig_p = enrichment_pvals[node_indices, annotation_idx]
        dep_p = depletion_pvals[node_indices, annotation_idx]
        sig_q = enrichment_qvals[node_indices, annotation_idx]
        dep_q = depletion_qvals[node_indices, annotation_idx]

        return (
            np.min(sig_p) if sig_p.size > 0 else None,
            np.min(sig_q) if sig_q.size > 0 else None,
            np.min(dep_p) if dep_p.size > 0 else None,
            np.min(dep_q) if dep_q.size > 0 else None,
        )

    def _get_annotation_members(self, description: str) -> str:
        """Retrieve node labels associated with a given annotation description.

        Args:
            description (str): The annotation description.

        Returns:
            str: ';'-separated string of node labels that are associated with the annotation.
        """
        try:
            annotation_idx = self.annotations["ordered_annotations"].index(description)
        except ValueError:
            return ""  # Description not found

        nodes_present = np.where(self.annotations["matrix"][:, annotation_idx] == 1)[0]
        node_labels = sorted(
            self.graph.node_id_to_node_label_map[node_id]
            for node_id in nodes_present
            if node_id in self.graph.node_id_to_node_label_map
        )
        return ";".join(node_labels)
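For orientation, here is a minimal, self-contained usage sketch of the new Summary class. The toy inputs below are illustrative only; they simply mirror the shapes the constructor and docstrings describe (annotation matrix, p-value matrices, and a graph object exposing the three mapping attributes the class reads).

from types import SimpleNamespace
import numpy as np
from risk.network.graph.summary import Summary

# Toy inputs shaped the way the constructor docstrings describe (purely illustrative).
annotations = {
    "ordered_annotations": ["pathway A", "pathway B"],
    "matrix": np.array([[1, 0], [1, 1], [0, 1]]),  # nodes x annotations
}
neighborhoods = {
    "enrichment_pvals": np.array([[0.01, 0.2], [0.03, 0.5], [0.9, 0.04]]),
    "depletion_pvals": np.array([[0.8, 0.1], [0.7, 0.2], [0.05, 0.9]]),
}
graph = SimpleNamespace(
    domain_id_to_domain_info_map={
        1: {"full_descriptions": ["pathway A"], "significance_scores": [2.5]}
    },
    domain_id_to_node_ids_map={1: [0, 1]},
    node_id_to_node_label_map={0: "geneX", 1: "geneY", 2: "geneZ"},
)

summary = Summary(annotations=annotations, neighborhoods=neighborhoods, graph=graph)
results_df = summary.load()            # cached DataFrame with p-values, q-values, and members
summary.to_csv("results_summary.csv")  # illustrative output path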
risk/network/io.py
CHANGED
@@ -165,12 +165,12 @@ class NetworkIO:
         filepath: str,
         source_label: str = "source",
         target_label: str = "target",
+        view_name: str = "",
         compute_sphere: bool = True,
         surface_depth: float = 0.0,
         min_edges_per_node: int = 0,
         include_edge_weight: bool = True,
         weight_label: str = "weight",
-        view_name: str = "",
     ) -> nx.Graph:
         """Load a network from a Cytoscape file.

@@ -178,7 +178,7 @@ class NetworkIO:
             filepath (str): Path to the Cytoscape file.
             source_label (str, optional): Source node label. Defaults to "source".
             target_label (str, optional): Target node label. Defaults to "target".
-            view_name (str, optional): Specific view name to load. Defaults to
+            view_name (str, optional): Specific view name to load. Defaults to "".
             compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
             surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
             min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
@@ -215,7 +215,7 @@ class NetworkIO:
             filepath (str): Path to the Cytoscape file.
             source_label (str, optional): Source node label. Defaults to "source".
             target_label (str, optional): Target node label. Defaults to "target".
-            view_name (str, optional): Specific view name to load. Defaults to
+            view_name (str, optional): Specific view name to load. Defaults to "".

         Returns:
             nx.Graph: Loaded and processed network.
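Because `view_name` now comes before `compute_sphere` in these loader signatures, callers that pass arguments positionally could silently hand a value to the wrong parameter, while keyword callers are unaffected. The stand-in functions below are not risk-network code; they are simplified hypothetical signatures that only illustrate the effect of the reordering.

# Standalone illustration (hypothetical stand-ins, not the actual NetworkIO methods).
def load_v8(filepath, source_label="source", target_label="target",
            compute_sphere=True, surface_depth=0.0, view_name=""):
    return dict(view_name=view_name, compute_sphere=compute_sphere)

def load_v9(filepath, source_label="source", target_label="target",
            view_name="", compute_sphere=True, surface_depth=0.0):
    return dict(view_name=view_name, compute_sphere=compute_sphere)

# Positional call: the fourth argument now lands on view_name instead of compute_sphere.
print(load_v8("net.cys", "source", "target", False))  # {'view_name': '', 'compute_sphere': False}
print(load_v9("net.cys", "source", "target", False))  # {'view_name': False, 'compute_sphere': True}

# Keyword call: unaffected by the reordering.
print(load_v9("net.cys", view_name="main", compute_sphere=False))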
risk/network/plot/contour.py
CHANGED
@@ -294,7 +294,7 @@ class Contour:
                 Controls the dimmest colors. Defaults to 0.8.
             max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
                 Controls the brightest colors. Defaults to 1.0.
-            scale_factor (float, optional): Exponent for adjusting color scaling based on
+            scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
                 A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
             random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

risk/network/plot/labels.py
CHANGED
@@ -659,7 +659,7 @@ class Labels:
                 Controls the dimmest colors. Defaults to 0.8.
             max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
                 Controls the brightest colors. Defaults to 1.0.
-            scale_factor (float, optional): Exponent for adjusting color scaling based on
+            scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
                 A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
             random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

risk/network/plot/network.py
CHANGED
@@ -203,11 +203,11 @@ class Network:
         max_scale: float = 1.0,
         scale_factor: float = 1.0,
         alpha: Union[float, None] = 1.0,
-
-
+        nonsignificant_color: Union[str, List, Tuple, np.ndarray] = "white",
+        nonsignificant_alpha: Union[float, None] = 1.0,
         random_seed: int = 888,
     ) -> np.ndarray:
-        """Adjust the colors of nodes in the network graph based on
+        """Adjust the colors of nodes in the network graph based on significance.

         Args:
             cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
@@ -218,16 +218,16 @@ class Network:
             min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
             max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
             scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
-            alpha (float, None, optional): Alpha value for
+            alpha (float, None, optional): Alpha value for significant nodes. If provided, it overrides any existing alpha values found in `color`.
                 Defaults to 1.0.
-
+            nonsignificant_color (str, List, Tuple, or np.ndarray, optional): Color for non-significant nodes. Can be a single color or an array of colors.
                 Defaults to "white".
-
-                in `
+            nonsignificant_alpha (float, None, optional): Alpha value for non-significant nodes. If provided, it overrides any existing alpha values found
+                in `nonsignificant_color`. Defaults to 1.0.
             random_seed (int, optional): Seed for random number generation. Defaults to 888.

         Returns:
-            np.ndarray: Array of RGBA colors adjusted for
+            np.ndarray: Array of RGBA colors adjusted for significance status.
         """
         # Get the initial domain colors for each node, which are returned as RGBA
         network_colors = get_domain_colors(
@@ -241,11 +241,11 @@ class Network:
             scale_factor=scale_factor,
             random_seed=random_seed,
         )
-        # Apply the alpha value for
-        network_colors[:, 3] = alpha  # Apply the alpha value to the
-        # Convert the non-
-
-            color=
+        # Apply the alpha value for significant nodes
+        network_colors[:, 3] = alpha  # Apply the alpha value to the significant nodes' A channel
+        # Convert the non-significant color to RGBA using the to_rgba helper function
+        nonsignificant_color_rgba = to_rgba(
+            color=nonsignificant_color, alpha=nonsignificant_alpha, num_repeats=1
         )  # num_repeats=1 for a single color
         # Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
         # 0.1 is a predefined threshold for the minimum color intensity
@@ -255,34 +255,34 @@ class Network:
                 & np.all(network_colors[:, :3] == network_colors[:, 0:1], axis=1)
             )[:, None],
             np.tile(
-                np.array(
-            ),  # Replace with the full RGBA non-
+                np.array(nonsignificant_color_rgba), (network_colors.shape[0], 1)
+            ),  # Replace with the full RGBA non-significant color
             network_colors,  # Keep the original colors where no match is found
         )
         return adjusted_network_colors

     def get_annotated_node_sizes(
-        self,
+        self, significant_size: int = 50, nonsignificant_size: int = 25
     ) -> np.ndarray:
-        """Adjust the sizes of nodes in the network graph based on whether they are
+        """Adjust the sizes of nodes in the network graph based on whether they are significant or not.

         Args:
-
-
+            significant_size (int): Size for significant nodes. Defaults to 50.
+            nonsignificant_size (int): Size for non-significant nodes. Defaults to 25.

         Returns:
-            np.ndarray: Array of node sizes, with
+            np.ndarray: Array of node sizes, with significant nodes larger than non-significant ones.
         """
-        # Merge all
-
+        # Merge all significant nodes from the domain_id_to_node_ids_map dictionary
+        significant_nodes = set()
         for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
-
+            significant_nodes.update(node_ids)

-        # Initialize all node sizes to the non-
-        node_sizes = np.full(len(self.graph.network.nodes),
-        # Set the size for
-        for node in
+        # Initialize all node sizes to the non-significant size
+        node_sizes = np.full(len(self.graph.network.nodes), nonsignificant_size)
+        # Set the size for significant nodes
+        for node in significant_nodes:
             if node in self.graph.network.nodes:
-                node_sizes[node] =
+                node_sizes[node] = significant_size

         return node_sizes
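The size-assignment rule shown in the last hunk is simple enough to run in isolation: every node starts at the non-significant size, and any node that belongs to at least one domain is bumped to the significant size. The toy mapping and node count below are illustrative only.

import numpy as np

# Self-contained sketch of the size-assignment logic (toy values, not package data).
domain_id_to_node_ids_map = {1: [0, 2], 2: [2, 4]}   # illustrative domain-to-node mapping
num_nodes = 6
significant_size, nonsignificant_size = 50, 25

significant_nodes = set()
for node_ids in domain_id_to_node_ids_map.values():
    significant_nodes.update(node_ids)

node_sizes = np.full(num_nodes, nonsignificant_size)
for node in significant_nodes:
    node_sizes[node] = significant_size

print(node_sizes)  # [50 25 50 25 50 25]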
risk/network/plot/utils/color.py
CHANGED
@@ -35,14 +35,14 @@ def get_annotated_domain_colors(
         blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
         min_scale (float, optional): Minimum scale for color intensity when generating domain colors. Defaults to 0.8.
         max_scale (float, optional): Maximum scale for color intensity when generating domain colors. Defaults to 1.0.
-        scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
+        scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on significance. Higher values
             increase the contrast. Defaults to 1.0.
         random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

     Returns:
         np.ndarray: Array of RGBA colors for each domain.
     """
-    # Generate domain colors based on the
+    # Generate domain colors based on the significance data
     node_colors = get_domain_colors(
         graph=graph,
         cmap=cmap,
@@ -82,7 +82,7 @@ def get_domain_colors(
     scale_factor: float = 1.0,
     random_seed: int = 888,
 ) -> np.ndarray:
-    """Generate composite colors for domains based on
+    """Generate composite colors for domains based on significance or specified colors.

     Args:
         graph (NetworkGraph): The network data and attributes to be visualized.
@@ -95,12 +95,12 @@ def get_domain_colors(
             Defaults to 0.8.
         max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap. Controls the brightest colors.
             Defaults to 1.0.
-        scale_factor (float, optional): Exponent for adjusting the color scaling based on
+        scale_factor (float, optional): Exponent for adjusting the color scaling based on significance scores. Higher values increase
             contrast by dimming lower scores more. Defaults to 1.0.
         random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments. Defaults to 888.

     Returns:
-        np.ndarray: Array of RGBA colors generated for each domain, based on
+        np.ndarray: Array of RGBA colors generated for each domain, based on significance or the specified color.
     """
     # Get colors for each domain
     domain_colors = _get_domain_colors(graph=graph, cmap=cmap, color=color, random_seed=random_seed)
@@ -111,7 +111,7 @@ def get_domain_colors(
     # Transform colors to ensure proper alpha values and intensity
     transformed_colors = _transform_colors(
         node_colors,
-        graph.
+        graph.node_significance_sums,
         min_scale=min_scale,
         max_scale=max_scale,
         scale_factor=scale_factor,
@@ -151,7 +151,7 @@ def _get_domain_colors(
 def _get_composite_node_colors(
     graph, domain_colors: np.ndarray, blend_colors: bool = False, blend_gamma: float = 2.2
 ) -> np.ndarray:
-    """Generate composite colors for nodes based on domain colors and
+    """Generate composite colors for nodes based on domain colors and significance values, with optional color blending.

     Args:
         graph (NetworkGraph): The network data and attributes to be visualized.
@@ -177,26 +177,26 @@ def _get_composite_node_colors(

     # If blending is required
     else:
-        for node, node_info in graph.
+        for node, node_info in graph.node_id_to_domain_ids_and_significance_map.items():
             domains = node_info["domains"]  # List of domain IDs
-
-            # Filter domains and
-
-                (domain_id,
-                for domain_id,
+            significances = node_info["significances"]  # List of significance values
+            # Filter domains and significances to keep only those with corresponding colors in domain_colors
+            filtered_domains_significances = [
+                (domain_id, significance)
+                for domain_id, significance in zip(domains, significances)
                 if domain_id in domain_colors
             ]
             # If no valid domains exist, skip this node
-            if not
+            if not filtered_domains_significances:
                 continue

-            # Unpack filtered domains and
-            filtered_domains,
+            # Unpack filtered domains and significances
+            filtered_domains, filtered_significances = zip(*filtered_domains_significances)
             # Get the colors corresponding to the valid filtered domains
             colors = [domain_colors[domain_id] for domain_id in filtered_domains]
             # Blend the colors using the given gamma (default is 2.2 if None)
             gamma = blend_gamma if blend_gamma is not None else 2.2
-            composite_color = _blend_colors_perceptually(colors,
+            composite_color = _blend_colors_perceptually(colors, filtered_significances, gamma)
             # Assign the composite color to the node
             composite_colors[node] = composite_color

@@ -282,21 +282,21 @@ def _assign_distant_colors(dist_matrix, num_colors_to_generate):


 def _blend_colors_perceptually(
-    colors: Union[List, Tuple, np.ndarray],
+    colors: Union[List, Tuple, np.ndarray], significances: List[float], gamma: float = 2.2
 ) -> Tuple[float, float, float, float]:
     """Blends a list of RGBA colors using gamma correction for perceptually uniform color mixing.

     Args:
         colors (List, Tuple, np.ndarray): List of RGBA colors. Can be a list, tuple, or NumPy array of RGBA values.
-
+        significances (List[float]): Corresponding list of significance values.
         gamma (float, optional): Gamma correction factor, default is 2.2 (typical for perceptual blending).

     Returns:
         Tuple[float, float, float, float]: The blended RGBA color.
     """
-    # Normalize
-
-    proportions = [
+    # Normalize significances so they sum up to 1 (proportions)
+    total_significance = sum(significances)
+    proportions = [significance / total_significance for significance in significances]
     # Convert colors to gamma-corrected space (apply gamma correction to RGB channels)
     gamma_corrected_colors = [[channel**gamma for channel in color[:3]] for color in colors]
     # Blend the colors in gamma-corrected space
@@ -310,17 +310,17 @@ def _blend_colors_perceptually(

 def _transform_colors(
     colors: np.ndarray,
-
+    significance_sums: np.ndarray,
     min_scale: float = 0.8,
     max_scale: float = 1.0,
     scale_factor: float = 1.0,
 ) -> np.ndarray:
-    """Transform colors using power scaling to emphasize high
+    """Transform colors using power scaling to emphasize high significance sums more. Black colors are replaced with
     very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).

     Args:
         colors (np.ndarray): An array of RGBA colors.
-
+        significance_sums (np.ndarray): An array of significance sums corresponding to the colors.
         min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
         max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
         scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
@@ -340,8 +340,8 @@ def _transform_colors(
     is_black = np.all(colors[:, :3] == black_color, axis=1)
     colors[is_black, :3] = dark_grey

-    # Normalize the
-    normalized_sums =
+    # Normalize the significance sums to the range [0, 1]
+    normalized_sums = significance_sums / np.max(significance_sums)
     # Apply power scaling to dim lower values and emphasize higher values
     scaled_sums = normalized_sums**scale_factor
     # Linearly scale the normalized sums to the range [min_scale, max_scale]
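The hunks above only show fragments of _blend_colors_perceptually, so the sketch below reconstructs the technique they describe: significance values become blend weights, RGB channels are mixed in gamma-corrected space, and the result is converted back with 1/gamma. How the alpha channel is combined is not shown in this diff, so the linear alpha blend here is an assumption.

import numpy as np

# Minimal sketch of significance-weighted, gamma-corrected color blending (assumptions noted above).
def blend_colors_perceptually(colors, significances, gamma=2.2):
    total = sum(significances)
    proportions = [s / total for s in significances]  # normalize significances to blend weights
    # Gamma-correct the RGB channels, blend with the proportions, then invert the gamma
    blended_rgb = [
        sum(p * (color[i] ** gamma) for p, color in zip(proportions, colors)) ** (1 / gamma)
        for i in range(3)
    ]
    # Alpha handling is assumed here: blend linearly with the same proportions
    alpha = sum(p * color[3] for p, color in zip(proportions, colors))
    return (*blended_rgb, alpha)

red = (1.0, 0.0, 0.0, 1.0)
blue = (0.0, 0.0, 1.0, 1.0)
print(blend_colors_perceptually([red, blue], significances=[3.0, 1.0]))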