risk-network 0.0.8b15__py3-none-any.whl → 0.0.8b17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/network/graph.py +1 -235
- risk/network/plot/canvas.py +15 -8
- risk/network/plot/contour.py +11 -9
- risk/network/plot/labels.py +18 -10
- risk/network/plot/network.py +20 -13
- risk/network/plot/plotter.py +9 -6
- risk/network/plot/utils/color.py +353 -0
- risk/network/plot/utils/layout.py +53 -0
- {risk_network-0.0.8b15.dist-info → risk_network-0.0.8b17.dist-info}/METADATA +1 -1
- {risk_network-0.0.8b15.dist-info → risk_network-0.0.8b17.dist-info}/RECORD +14 -13
- risk/network/plot/utils.py +0 -153
- {risk_network-0.0.8b15.dist-info → risk_network-0.0.8b17.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b15.dist-info → risk_network-0.0.8b17.dist-info}/WHEEL +0 -0
- {risk_network-0.0.8b15.dist-info → risk_network-0.0.8b17.dist-info}/top_level.txt +0 -0
risk/__init__.py
CHANGED
risk/network/graph.py
CHANGED
@@ -4,12 +4,11 @@ risk/network/graph
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
from collections import defaultdict
|
7
|
-
from typing import Any, Dict, List
|
7
|
+
from typing import Any, Dict, List
|
8
8
|
|
9
9
|
import networkx as nx
|
10
10
|
import numpy as np
|
11
11
|
import pandas as pd
|
12
|
-
import matplotlib
|
13
12
|
|
14
13
|
|
15
14
|
class NetworkGraph:
|
@@ -107,139 +106,6 @@ class NetworkGraph:
|
|
107
106
|
|
108
107
|
return domain_id_to_label_map
|
109
108
|
|
110
|
-
def get_domain_colors(
|
111
|
-
self,
|
112
|
-
cmap: str = "gist_rainbow",
|
113
|
-
color: Union[str, None] = None,
|
114
|
-
min_scale: float = 0.8,
|
115
|
-
max_scale: float = 1.0,
|
116
|
-
scale_factor: float = 1.0,
|
117
|
-
random_seed: int = 888,
|
118
|
-
) -> np.ndarray:
|
119
|
-
"""Generate composite colors for domains based on enrichment or specified colors.
|
120
|
-
|
121
|
-
Args:
|
122
|
-
cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
|
123
|
-
color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
|
124
|
-
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
125
|
-
Controls the dimmest colors. Defaults to 0.8.
|
126
|
-
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
127
|
-
Controls the brightest colors. Defaults to 1.0.
|
128
|
-
scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores.
|
129
|
-
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
130
|
-
random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments.
|
131
|
-
Defaults to 888.
|
132
|
-
|
133
|
-
Returns:
|
134
|
-
np.ndarray: Array of RGBA colors generated for each domain, based on enrichment or the specified color.
|
135
|
-
"""
|
136
|
-
# Get colors for each domain
|
137
|
-
domain_colors = self._get_domain_colors(cmap=cmap, color=color, random_seed=random_seed)
|
138
|
-
# Generate composite colors for nodes
|
139
|
-
node_colors = self._get_composite_node_colors(domain_colors)
|
140
|
-
# Transform colors to ensure proper alpha values and intensity
|
141
|
-
transformed_colors = _transform_colors(
|
142
|
-
node_colors,
|
143
|
-
self.node_enrichment_sums,
|
144
|
-
min_scale=min_scale,
|
145
|
-
max_scale=max_scale,
|
146
|
-
scale_factor=scale_factor,
|
147
|
-
)
|
148
|
-
|
149
|
-
return transformed_colors
|
150
|
-
|
151
|
-
def _get_domain_colors(
|
152
|
-
self,
|
153
|
-
cmap: str = "gist_rainbow",
|
154
|
-
color: Union[str, None] = None,
|
155
|
-
random_seed: int = 888,
|
156
|
-
) -> Dict[str, Any]:
|
157
|
-
"""Get colors for each domain.
|
158
|
-
|
159
|
-
Args:
|
160
|
-
cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
|
161
|
-
color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
|
162
|
-
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
163
|
-
|
164
|
-
Returns:
|
165
|
-
dict: A dictionary mapping domain keys to their corresponding RGBA colors.
|
166
|
-
"""
|
167
|
-
# Get colors for each domain based on node positions
|
168
|
-
domain_colors = _get_colors(
|
169
|
-
self.network,
|
170
|
-
self.domain_id_to_node_ids_map,
|
171
|
-
cmap=cmap,
|
172
|
-
color=color,
|
173
|
-
random_seed=random_seed,
|
174
|
-
)
|
175
|
-
return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
|
176
|
-
|
177
|
-
def _get_composite_node_colors(self, domain_colors: np.ndarray) -> np.ndarray:
|
178
|
-
"""Generate composite colors for nodes based on domain colors and counts.
|
179
|
-
|
180
|
-
Args:
|
181
|
-
domain_colors (np.ndarray): Array of colors corresponding to each domain.
|
182
|
-
|
183
|
-
Returns:
|
184
|
-
np.ndarray: Array of composite colors for each node.
|
185
|
-
"""
|
186
|
-
# Determine the number of nodes
|
187
|
-
num_nodes = len(self.node_coordinates)
|
188
|
-
# Initialize composite colors array with shape (number of nodes, 4) for RGBA
|
189
|
-
composite_colors = np.zeros((num_nodes, 4))
|
190
|
-
# Assign colors to nodes based on domain_colors
|
191
|
-
for domain_id, nodes in self.domain_id_to_node_ids_map.items():
|
192
|
-
color = domain_colors[domain_id]
|
193
|
-
for node in nodes:
|
194
|
-
composite_colors[node] = color
|
195
|
-
|
196
|
-
return composite_colors
|
197
|
-
|
198
|
-
|
199
|
-
def _transform_colors(
|
200
|
-
colors: np.ndarray,
|
201
|
-
enrichment_sums: np.ndarray,
|
202
|
-
min_scale: float = 0.8,
|
203
|
-
max_scale: float = 1.0,
|
204
|
-
scale_factor: float = 1.0,
|
205
|
-
) -> np.ndarray:
|
206
|
-
"""Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
|
207
|
-
very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
|
208
|
-
|
209
|
-
Args:
|
210
|
-
colors (np.ndarray): An array of RGBA colors.
|
211
|
-
enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
|
212
|
-
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
213
|
-
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
214
|
-
scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
|
215
|
-
values more. Defaults to 1.0.
|
216
|
-
|
217
|
-
Returns:
|
218
|
-
np.ndarray: The transformed array of RGBA colors with adjusted intensities.
|
219
|
-
"""
|
220
|
-
# Ensure that min_scale is less than max_scale
|
221
|
-
if min_scale == max_scale:
|
222
|
-
min_scale = max_scale - 10e-6 # Avoid division by zero
|
223
|
-
|
224
|
-
# Replace black colors (#000000) with very dark grey (#1A1A1A)
|
225
|
-
black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
|
226
|
-
dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
|
227
|
-
# Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
|
228
|
-
is_black = np.all(colors[:, :3] == black_color, axis=1)
|
229
|
-
colors[is_black, :3] = dark_grey
|
230
|
-
|
231
|
-
# Normalize the enrichment sums to the range [0, 1]
|
232
|
-
normalized_sums = enrichment_sums / np.max(enrichment_sums)
|
233
|
-
# Apply power scaling to dim lower values and emphasize higher values
|
234
|
-
scaled_sums = normalized_sums**scale_factor
|
235
|
-
# Linearly scale the normalized sums to the range [min_scale, max_scale]
|
236
|
-
scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
|
237
|
-
# Adjust RGB values based on scaled sums
|
238
|
-
for i in range(3): # Only adjust RGB values
|
239
|
-
colors[:, i] = scaled_sums * colors[:, i]
|
240
|
-
|
241
|
-
return colors
|
242
|
-
|
243
109
|
|
244
110
|
def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
|
245
111
|
"""Convert 3D coordinates to 2D by unfolding a sphere to a plane.
|
@@ -291,103 +157,3 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
|
|
291
157
|
}
|
292
158
|
node_coordinates = np.vstack(list(node_positions.values()))
|
293
159
|
return node_coordinates
|
294
|
-
|
295
|
-
|
296
|
-
def _get_colors(
|
297
|
-
network,
|
298
|
-
domain_id_to_node_ids_map,
|
299
|
-
cmap: str = "gist_rainbow",
|
300
|
-
color: Union[str, None] = None,
|
301
|
-
random_seed: int = 888,
|
302
|
-
) -> List[Tuple]:
|
303
|
-
"""Generate a list of RGBA colors based on domain centroids, ensuring that domains
|
304
|
-
close in space get maximally separated colors, while keeping some randomness.
|
305
|
-
|
306
|
-
Args:
|
307
|
-
network (NetworkX graph): The graph representing the network.
|
308
|
-
domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
|
309
|
-
cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
|
310
|
-
color (str or None, optional): A specific color to use for all generated colors.
|
311
|
-
random_seed (int): Seed for random number generation. Defaults to 888.
|
312
|
-
|
313
|
-
Returns:
|
314
|
-
List[Tuple]: List of RGBA colors.
|
315
|
-
"""
|
316
|
-
# Set random seed for reproducibility
|
317
|
-
np.random.seed(random_seed)
|
318
|
-
# Determine the number of colors to generate based on the number of domains
|
319
|
-
num_colors_to_generate = len(domain_id_to_node_ids_map)
|
320
|
-
if color:
|
321
|
-
# Generate all colors as the same specified color
|
322
|
-
rgba = matplotlib.colors.to_rgba(color)
|
323
|
-
return [rgba] * num_colors_to_generate
|
324
|
-
|
325
|
-
# Load colormap
|
326
|
-
colormap = matplotlib.colormaps.get_cmap(cmap)
|
327
|
-
# Step 1: Calculate centroids for each domain
|
328
|
-
centroids = _calculate_centroids(network, domain_id_to_node_ids_map)
|
329
|
-
# Step 2: Calculate pairwise distances between centroids
|
330
|
-
centroid_array = np.array(centroids)
|
331
|
-
dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
|
332
|
-
# Step 3: Assign distant colors to close centroids
|
333
|
-
color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
|
334
|
-
# Step 4: Randomly shift the entire color palette while maintaining relative distances
|
335
|
-
global_shift = np.random.uniform(-0.1, 0.1) # Small global shift to change the overall palette
|
336
|
-
color_positions = (color_positions + global_shift) % 1 # Wrap around to keep within [0, 1]
|
337
|
-
# Step 5: Ensure that all positions remain between 0 and 1
|
338
|
-
color_positions = np.clip(color_positions, 0, 1)
|
339
|
-
|
340
|
-
# Step 6: Generate RGBA colors based on positions
|
341
|
-
return [colormap(pos) for pos in color_positions]
|
342
|
-
|
343
|
-
|
344
|
-
def _calculate_centroids(network, domain_id_to_node_ids_map):
|
345
|
-
"""Calculate the centroid for each domain based on node x and y coordinates in the network.
|
346
|
-
|
347
|
-
Args:
|
348
|
-
network (NetworkX graph): The graph representing the network.
|
349
|
-
domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
|
350
|
-
|
351
|
-
Returns:
|
352
|
-
List[Tuple[float, float]]: List of centroids (x, y) for each domain.
|
353
|
-
"""
|
354
|
-
centroids = []
|
355
|
-
for domain_id, node_ids in domain_id_to_node_ids_map.items():
|
356
|
-
# Extract x and y coordinates from the network nodes
|
357
|
-
node_positions = np.array(
|
358
|
-
[[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
|
359
|
-
)
|
360
|
-
# Compute the centroid as the mean of the x and y coordinates
|
361
|
-
centroid = np.mean(node_positions, axis=0)
|
362
|
-
centroids.append(tuple(centroid))
|
363
|
-
|
364
|
-
return centroids
|
365
|
-
|
366
|
-
|
367
|
-
def _assign_distant_colors(dist_matrix, num_colors_to_generate):
|
368
|
-
"""Assign colors to centroids that are close in space, ensuring stark color differences.
|
369
|
-
|
370
|
-
Args:
|
371
|
-
dist_matrix (ndarray): Matrix of pairwise centroid distances.
|
372
|
-
num_colors_to_generate (int): Number of colors to generate.
|
373
|
-
|
374
|
-
Returns:
|
375
|
-
np.array: Array of color positions in the range [0, 1].
|
376
|
-
"""
|
377
|
-
color_positions = np.zeros(num_colors_to_generate)
|
378
|
-
# Step 1: Sort indices by centroid proximity (based on sum of distances to others)
|
379
|
-
proximity_order = sorted(
|
380
|
-
range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
|
381
|
-
)
|
382
|
-
# Step 2: Assign colors starting with the most distant points in proximity order
|
383
|
-
for i, idx in enumerate(proximity_order):
|
384
|
-
color_positions[idx] = i / num_colors_to_generate
|
385
|
-
|
386
|
-
# Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
|
387
|
-
half_spectrum = int(num_colors_to_generate / 2)
|
388
|
-
for i in range(half_spectrum):
|
389
|
-
# Split the spectrum so that close centroids are assigned distant colors
|
390
|
-
color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
|
391
|
-
color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
|
392
|
-
|
393
|
-
return color_positions
|
risk/network/plot/canvas.py
CHANGED
@@ -10,7 +10,8 @@ import numpy as np
|
|
10
10
|
|
11
11
|
from risk.log import params
|
12
12
|
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.plot.utils import
|
13
|
+
from risk.network.plot.utils.color import to_rgba
|
14
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
14
15
|
|
15
16
|
|
16
17
|
class Canvas:
|
@@ -33,8 +34,8 @@ class Canvas:
|
|
33
34
|
title_fontsize: int = 20,
|
34
35
|
subtitle_fontsize: int = 14,
|
35
36
|
font: str = "Arial",
|
36
|
-
title_color: str = "black",
|
37
|
-
subtitle_color: str = "gray",
|
37
|
+
title_color: Union[str, list, tuple, np.ndarray] = "black",
|
38
|
+
subtitle_color: Union[str, list, tuple, np.ndarray] = "gray",
|
38
39
|
title_y: float = 0.975,
|
39
40
|
title_space_offset: float = 0.075,
|
40
41
|
subtitle_offset: float = 0.025,
|
@@ -47,8 +48,10 @@ class Canvas:
|
|
47
48
|
title_fontsize (int, optional): Font size for the title. Defaults to 20.
|
48
49
|
subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
|
49
50
|
font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
|
50
|
-
title_color (str, optional): Color of the title text.
|
51
|
-
|
51
|
+
title_color (str, list, tuple, or np.ndarray, optional): Color of the title text. Can be a string or an array of colors.
|
52
|
+
Defaults to "black".
|
53
|
+
subtitle_color (str, list, tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
|
54
|
+
Defaults to "gray".
|
52
55
|
title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
|
53
56
|
title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
|
54
57
|
subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
|
@@ -141,7 +144,9 @@ class Canvas:
|
|
141
144
|
)
|
142
145
|
|
143
146
|
# Convert color to RGBA using the to_rgba helper function - use outline_alpha for the perimeter
|
144
|
-
color = to_rgba(
|
147
|
+
color = to_rgba(
|
148
|
+
color=color, alpha=outline_alpha, num_repeats=1
|
149
|
+
) # num_repeats=1 for a single color
|
145
150
|
# Set the fill_alpha to 0 if not provided
|
146
151
|
fill_alpha = fill_alpha if fill_alpha is not None else 0.0
|
147
152
|
# Extract node coordinates from the network graph
|
@@ -162,7 +167,9 @@ class Canvas:
|
|
162
167
|
)
|
163
168
|
# Set the transparency of the fill if applicable
|
164
169
|
if fill_alpha > 0:
|
165
|
-
circle.set_facecolor(
|
170
|
+
circle.set_facecolor(
|
171
|
+
to_rgba(color=color, alpha=fill_alpha, num_repeats=1)
|
172
|
+
) # num_repeats=1 for a single color
|
166
173
|
|
167
174
|
self.ax.add_artist(circle)
|
168
175
|
|
@@ -209,7 +216,7 @@ class Canvas:
|
|
209
216
|
)
|
210
217
|
|
211
218
|
# Convert color to RGBA using outline_alpha for the line (outline)
|
212
|
-
outline_color = to_rgba(color=color)
|
219
|
+
outline_color = to_rgba(color=color, num_repeats=1) # num_repeats=1 for a single color
|
213
220
|
# Extract node coordinates from the network graph
|
214
221
|
node_coordinates = self.graph.node_coordinates
|
215
222
|
# Scale the node coordinates if needed
|
risk/network/plot/contour.py
CHANGED
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
|
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
15
|
from risk.network.graph import NetworkGraph
|
16
|
-
from risk.network.plot.utils import get_annotated_domain_colors, to_rgba
|
16
|
+
from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
@@ -110,7 +110,7 @@ class Contour:
|
|
110
110
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
111
111
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
112
112
|
color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array.
|
113
|
-
Defaults to "white".
|
113
|
+
Can be a single color or an array of colors. Defaults to "white".
|
114
114
|
linestyle (str, optional): Line style for the contour. Defaults to "solid".
|
115
115
|
linewidth (float, optional): Line width for the contour. Defaults to 1.5.
|
116
116
|
alpha (float, None, optional): Transparency level of the contour lines. If provided, it overrides any existing alpha values
|
@@ -125,15 +125,16 @@ class Contour:
|
|
125
125
|
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
126
126
|
# If it's a list of lists, iterate over sublists
|
127
127
|
node_groups = nodes
|
128
|
+
# Convert color to RGBA arrays to match the number of groups
|
129
|
+
color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=len(node_groups))
|
128
130
|
else:
|
129
131
|
# If it's a flat list of nodes, treat it as a single group
|
130
132
|
node_groups = [nodes]
|
131
|
-
|
132
|
-
|
133
|
-
color_rgba = to_rgba(color=color, alpha=alpha)
|
133
|
+
# Wrap the RGBA color in an array to index the first element
|
134
|
+
color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=1)
|
134
135
|
|
135
136
|
# Iterate over each group of nodes (either sublists or flat list)
|
136
|
-
for sublist in node_groups:
|
137
|
+
for idx, sublist in enumerate(node_groups):
|
137
138
|
# Filter to get node IDs and their coordinates for each sublist
|
138
139
|
node_ids = [
|
139
140
|
self.graph.node_label_to_node_id_map.get(node)
|
@@ -151,7 +152,7 @@ class Contour:
|
|
151
152
|
self.ax,
|
152
153
|
node_coordinates,
|
153
154
|
node_ids,
|
154
|
-
color=color_rgba,
|
155
|
+
color=color_rgba[idx],
|
155
156
|
levels=levels,
|
156
157
|
bandwidth=bandwidth,
|
157
158
|
grid_size=grid_size,
|
@@ -273,7 +274,7 @@ class Contour:
|
|
273
274
|
def get_annotated_contour_colors(
|
274
275
|
self,
|
275
276
|
cmap: str = "gist_rainbow",
|
276
|
-
color: Union[str, None] = None,
|
277
|
+
color: Union[str, list, tuple, np.ndarray, None] = None,
|
277
278
|
min_scale: float = 0.8,
|
278
279
|
max_scale: float = 1.0,
|
279
280
|
scale_factor: float = 1.0,
|
@@ -283,7 +284,8 @@ class Contour:
|
|
283
284
|
|
284
285
|
Args:
|
285
286
|
cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
|
286
|
-
color (str or None, optional): Color to use for the contours.
|
287
|
+
color (str, list, tuple, np.ndarray, or None, optional): Color to use for the contours. Can be a single color or an array of colors.
|
288
|
+
If None, the colormap will be used. Defaults to None.
|
287
289
|
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
288
290
|
Controls the dimmest colors. Defaults to 0.8.
|
289
291
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
risk/network/plot/labels.py
CHANGED
@@ -11,7 +11,8 @@ import pandas as pd
|
|
11
11
|
|
12
12
|
from risk.log import params
|
13
13
|
from risk.network.graph import NetworkGraph
|
14
|
-
from risk.network.plot.utils import
|
14
|
+
from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
|
15
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
15
16
|
|
16
17
|
TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
|
17
18
|
|
@@ -284,13 +285,19 @@ class Labels:
|
|
284
285
|
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
285
286
|
# If it's a list of lists, iterate over sublists
|
286
287
|
node_groups = nodes
|
288
|
+
# Convert fontcolor and arrow_color to RGBA arrays to match the number of groups
|
289
|
+
fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=len(node_groups))
|
290
|
+
arrow_color_rgba = to_rgba(
|
291
|
+
color=arrow_color, alpha=arrow_alpha, num_repeats=len(node_groups)
|
292
|
+
)
|
287
293
|
else:
|
288
294
|
# If it's a flat list of nodes, treat it as a single group
|
289
295
|
node_groups = [nodes]
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
296
|
+
# Wrap the RGBA fontcolor and arrow_color in an array to index the first element
|
297
|
+
fontcolor_rgba = np.array(to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=1))
|
298
|
+
arrow_color_rgba = np.array(
|
299
|
+
to_rgba(color=arrow_color, alpha=arrow_alpha, num_repeats=1)
|
300
|
+
)
|
294
301
|
|
295
302
|
# Calculate the bounding box around the network
|
296
303
|
center, radius = calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
@@ -302,7 +309,7 @@ class Labels:
|
|
302
309
|
)
|
303
310
|
|
304
311
|
# Iterate over each group of nodes (either sublists or flat list)
|
305
|
-
for sublist in node_groups:
|
312
|
+
for idx, sublist in enumerate(node_groups):
|
306
313
|
# Map node labels to IDs
|
307
314
|
node_ids = [
|
308
315
|
self.graph.node_label_to_node_id_map.get(node)
|
@@ -326,10 +333,10 @@ class Labels:
|
|
326
333
|
va="center",
|
327
334
|
fontsize=fontsize,
|
328
335
|
fontname=font,
|
329
|
-
color=fontcolor_rgba,
|
336
|
+
color=fontcolor_rgba[idx],
|
330
337
|
arrowprops=dict(
|
331
338
|
arrowstyle=arrow_style,
|
332
|
-
color=arrow_color_rgba,
|
339
|
+
color=arrow_color_rgba[idx],
|
333
340
|
linewidth=arrow_linewidth,
|
334
341
|
shrinkA=arrow_base_shrink,
|
335
342
|
shrinkB=arrow_tip_shrink,
|
@@ -630,7 +637,7 @@ class Labels:
|
|
630
637
|
def get_annotated_label_colors(
|
631
638
|
self,
|
632
639
|
cmap: str = "gist_rainbow",
|
633
|
-
color: Union[str, None] = None,
|
640
|
+
color: Union[str, list, tuple, np.ndarray, None] = None,
|
634
641
|
min_scale: float = 0.8,
|
635
642
|
max_scale: float = 1.0,
|
636
643
|
scale_factor: float = 1.0,
|
@@ -640,7 +647,8 @@ class Labels:
|
|
640
647
|
|
641
648
|
Args:
|
642
649
|
cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
|
643
|
-
color (str or None, optional): Color to use for the labels.
|
650
|
+
color (str, list, tuple, np.ndarray, or None, optional): Color to use for the labels. Can be a single color or an array
|
651
|
+
of colors. If None, the colormap will be used. Defaults to None.
|
644
652
|
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
645
653
|
Controls the dimmest colors. Defaults to 0.8.
|
646
654
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
risk/network/plot/network.py
CHANGED
@@ -10,7 +10,7 @@ import numpy as np
|
|
10
10
|
|
11
11
|
from risk.log import params
|
12
12
|
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.plot.utils import to_rgba
|
13
|
+
from risk.network.plot.utils.color import get_domain_colors, to_rgba
|
14
14
|
|
15
15
|
|
16
16
|
class Network:
|
@@ -47,8 +47,10 @@ class Network:
|
|
47
47
|
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
48
48
|
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors.
|
49
49
|
Defaults to "white".
|
50
|
-
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges.
|
51
|
-
|
50
|
+
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Can be a single color or an array of colors.
|
51
|
+
Defaults to "black".
|
52
|
+
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Can be a single color or an array of colors.
|
53
|
+
Defaults to "black".
|
52
54
|
node_alpha (float, None, optional): Alpha value (transparency) for the nodes. If provided, it overrides any existing alpha
|
53
55
|
values found in node_color. Defaults to 1.0. Annotated node_color alphas will override this value.
|
54
56
|
edge_alpha (float, None, optional): Alpha value (transparency) for the edges. If provided, it overrides any existing alpha
|
@@ -194,12 +196,12 @@ class Network:
|
|
194
196
|
def get_annotated_node_colors(
|
195
197
|
self,
|
196
198
|
cmap: str = "gist_rainbow",
|
197
|
-
color: Union[str, None] = None,
|
199
|
+
color: Union[str, list, tuple, np.ndarray, None] = None,
|
198
200
|
min_scale: float = 0.8,
|
199
201
|
max_scale: float = 1.0,
|
200
202
|
scale_factor: float = 1.0,
|
201
203
|
alpha: Union[float, None] = 1.0,
|
202
|
-
nonenriched_color: Union[str,
|
204
|
+
nonenriched_color: Union[str, list, tuple, np.ndarray] = "white",
|
203
205
|
nonenriched_alpha: Union[float, None] = 1.0,
|
204
206
|
random_seed: int = 888,
|
205
207
|
) -> np.ndarray:
|
@@ -207,22 +209,25 @@ class Network:
|
|
207
209
|
|
208
210
|
Args:
|
209
211
|
cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
|
210
|
-
color (str or None, optional): Color to use for the nodes.
|
212
|
+
color (str, list, tuple, np.ndarray, or None, optional): Color to use for the nodes. Can be a single color or an array of colors.
|
213
|
+
If None, the colormap will be used. Defaults to None.
|
211
214
|
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
212
215
|
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
213
216
|
scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
|
214
|
-
alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values
|
215
|
-
|
216
|
-
nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes.
|
217
|
-
|
218
|
-
|
217
|
+
alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values found in `color`.
|
218
|
+
Defaults to 1.0.
|
219
|
+
nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Can be a single color or an array of colors.
|
220
|
+
Defaults to "white".
|
221
|
+
nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing alpha values found
|
222
|
+
in `nonenriched_color`. Defaults to 1.0.
|
219
223
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
220
224
|
|
221
225
|
Returns:
|
222
226
|
np.ndarray: Array of RGBA colors adjusted for enrichment status.
|
223
227
|
"""
|
224
228
|
# Get the initial domain colors for each node, which are returned as RGBA
|
225
|
-
network_colors =
|
229
|
+
network_colors = get_domain_colors(
|
230
|
+
graph=self.graph,
|
226
231
|
cmap=cmap,
|
227
232
|
color=color,
|
228
233
|
min_scale=min_scale,
|
@@ -233,7 +238,9 @@ class Network:
|
|
233
238
|
# Apply the alpha value for enriched nodes
|
234
239
|
network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
|
235
240
|
# Convert the non-enriched color to RGBA using the to_rgba helper function
|
236
|
-
nonenriched_color = to_rgba(
|
241
|
+
nonenriched_color = to_rgba(
|
242
|
+
color=nonenriched_color, alpha=nonenriched_alpha, num_repeats=1
|
243
|
+
) # num_repeats=1 for a single color
|
237
244
|
# Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
|
238
245
|
# 0.1 is a predefined threshold for the minimum color intensity
|
239
246
|
adjusted_network_colors = np.where(
|
risk/network/plot/plotter.py
CHANGED
@@ -14,7 +14,8 @@ from risk.network.plot.canvas import Canvas
|
|
14
14
|
from risk.network.plot.contour import Contour
|
15
15
|
from risk.network.plot.labels import Labels
|
16
16
|
from risk.network.plot.network import Network
|
17
|
-
from risk.network.plot.utils import
|
17
|
+
from risk.network.plot.utils.color import to_rgba
|
18
|
+
from risk.network.plot.utils.layout import calculate_bounding_box
|
18
19
|
|
19
20
|
|
20
21
|
class NetworkPlotter(Canvas, Network, Contour, Labels):
|
@@ -58,7 +59,7 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
|
|
58
59
|
self,
|
59
60
|
graph: NetworkGraph,
|
60
61
|
figsize: Tuple,
|
61
|
-
background_color: Union[str,
|
62
|
+
background_color: Union[str, list, tuple, np.ndarray],
|
62
63
|
background_alpha: Union[float, None],
|
63
64
|
pad: float,
|
64
65
|
) -> plt.Axes:
|
@@ -67,9 +68,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
|
|
67
68
|
Args:
|
68
69
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
69
70
|
figsize (tuple): Size of the figure in inches (width, height).
|
70
|
-
background_color (str): Background color of the plot.
|
71
|
-
background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any
|
72
|
-
|
71
|
+
background_color (str, list, tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
|
72
|
+
background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any existing
|
73
|
+
alpha values found in `background_color`.
|
73
74
|
pad (float, optional): Padding value to adjust the axis limits.
|
74
75
|
|
75
76
|
Returns:
|
@@ -98,7 +99,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
|
|
98
99
|
|
99
100
|
# Set the background color of the plot
|
100
101
|
# Convert color to RGBA using the to_rgba helper function
|
101
|
-
fig.patch.set_facecolor(
|
102
|
+
fig.patch.set_facecolor(
|
103
|
+
to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
|
104
|
+
) # num_repeats=1 for single color
|
102
105
|
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
103
106
|
# Remove axis spines for a cleaner look
|
104
107
|
for spine in ax.spines.values():
|
@@ -0,0 +1,353 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plot/utils/plot
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Any, Dict, List, Tuple, Union
|
7
|
+
|
8
|
+
import matplotlib
|
9
|
+
import matplotlib.colors as mcolors
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from risk.network.graph import NetworkGraph
|
13
|
+
from risk.network.plot.utils.layout import calculate_centroids
|
14
|
+
|
15
|
+
|
16
|
+
def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain.

    Per-node colors are first computed with `get_domain_colors`. A domain with more than one node is represented
    by the color of its brightest node, where brightness is the sum of the R, G and B channels. Domains with a
    single node (or none) are represented by opaque white.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): Color to use for the domains. Can be a single color
            or an array of colors. If None, the colormap will be used. Defaults to None.
        min_scale (float, optional): Minimum scale for color intensity when generating domain colors. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity when generating domain colors. Defaults to 1.0.
        scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
            enrichment. Higher values increase the contrast. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

    Returns:
        np.ndarray: One RGBA color per domain, in the iteration order of `graph.domain_id_to_node_ids_map`.
    """
    per_node_colors = get_domain_colors(
        graph=graph,
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )

    def _representative_color(member_nodes) -> np.ndarray:
        # Single-node (and empty) domains are shown in opaque white
        if len(member_nodes) <= 1:
            return np.ones(4)
        member_colors = np.array([per_node_colors[node] for node in member_nodes])
        # Brightness is approximated by the sum of the RGB channels; alpha is ignored
        rgb_totals = member_colors[:, :3].sum(axis=1)
        return member_colors[np.argmax(rgb_totals)]

    return np.array(
        [_representative_color(node_ids) for node_ids in graph.domain_id_to_node_ids_map.values()]
    )
|
66
|
+
|
67
|
+
|
68
|
+
def get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Compute one RGBA color per node from its domain color, dimmed according to enrichment.

    The pipeline has three stages:
    1. A color is chosen for every domain, from `cmap` or from `color`.
    2. Each node is painted with the color of the domain it belongs to.
    3. Each node's RGB channels are scaled by its power-scaled, normalized `graph.node_enrichment_sums` value,
       mapped into [min_scale, max_scale].

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): A specific color or array of colors to use for all
            domains. If None, the colormap will be used. Defaults to None.
        min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap. Controls the
            dimmest colors. Defaults to 0.8.
        max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap. Controls the
            brightest colors. Defaults to 1.0.
        scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores. Higher
            values increase contrast by dimming lower scores more. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility of color
            assignments. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors, one row per node.
    """
    color_by_domain = _get_domain_colors(graph=graph, cmap=cmap, color=color, random_seed=random_seed)
    color_by_node = _get_composite_node_colors(graph=graph, domain_colors=color_by_domain)
    return _transform_colors(
        color_by_node,
        graph.node_enrichment_sums,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
    )
|
108
|
+
|
109
|
+
|
110
|
+
def _get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> Dict[str, Any]:
    """Map every domain ID to an RGBA color.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the
            domains. If None, the colormap will be used. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        dict: Domain ID -> RGBA color, in the iteration order of `graph.domain_id_to_node_ids_map`.
    """
    palette = _get_colors(
        graph.network,
        graph.domain_id_to_node_ids_map,
        cmap=cmap,
        color=color,
        random_seed=random_seed,
    )
    domain_ids = graph.domain_id_to_node_ids_map.keys()
    # Colors are produced in the same order as the mapping's keys, so pair them positionally
    return {domain_id: rgba for domain_id, rgba in zip(domain_ids, palette)}
|
137
|
+
|
138
|
+
|
139
|
+
def _get_composite_node_colors(graph: NetworkGraph, domain_colors: Dict[Any, Any]) -> np.ndarray:
    """Paint every node with the color of its domain.

    Each node listed in `graph.domain_id_to_node_ids_map` receives the RGBA color of its domain. Nodes that belong
    to no domain keep the initial color [0, 0, 0, 0]. If a node is listed under several domains, the domain visited
    last wins.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized. Node IDs are used directly as row
            indices, so they are assumed to be integers in `range(len(graph.node_coordinates))`.
        domain_colors (dict): Mapping from domain ID to an RGBA color (4-element tuple or array), as produced by
            `_get_domain_colors`. Any object indexable by domain ID works.

    Returns:
        np.ndarray: Array of shape (number of nodes, 4) with one RGBA color per node.
    """
    # One RGBA row per node, initialised to fully transparent black
    composite_colors = np.zeros((len(graph.node_coordinates), 4))
    for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
        domain_color = domain_colors[domain_id]
        for node in nodes:
            # Node IDs double as row indices into the color array
            composite_colors[node] = domain_color

    return composite_colors
|
160
|
+
|
161
|
+
|
162
|
+
def _get_colors(
    network,
    domain_id_to_node_ids_map,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> List[Tuple]:
    """Generate one RGBA color per domain.

    If `color` is given, it is repeated (or cycled, for several colors) once per domain via `to_rgba`. Otherwise:
    1. Each domain's centroid is computed from the nodes' "x"/"y" attributes.
    2. `_assign_distant_colors` turns the pairwise centroid distances into positions on the colormap.
    3. All positions are rotated together by one small, seeded random shift before `cmap` is sampled.

    Args:
        network (NetworkX graph): The graph representing the network; its nodes must carry "x" and "y" attributes.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the
            domains. If None (or empty), the colormap will be used. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        List[Tuple]: RGBA colors, one per domain, sampled from the colormap. When `color` is given, the
            `np.ndarray` returned by `to_rgba` is returned instead.
    """
    # NOTE(review): this seeds NumPy's *global* RNG, a side effect visible to other code. It is kept as-is so
    # palettes (and possibly other code relying on this seeding) stay reproducible for a given seed.
    np.random.seed(random_seed)
    # One color per domain
    num_colors_to_generate = len(domain_id_to_node_ids_map)

    # `if color:` raises "truth value of an array ... is ambiguous" for multi-element np.ndarray input, which the
    # signature explicitly allows. Arrays are therefore tested by size; everything else keeps plain truthiness.
    if isinstance(color, np.ndarray):
        use_fixed_color = color.size > 0
    else:
        use_fixed_color = bool(color)
    if use_fixed_color:
        # Generate all colors from the specified color(s)
        return to_rgba(color, num_repeats=num_colors_to_generate)

    colormap = matplotlib.colormaps.get_cmap(cmap)
    # Centroid of each domain, then the full matrix of pairwise Euclidean distances between centroids
    centroid_array = np.array(calculate_centroids(network, domain_id_to_node_ids_map))
    dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
    # Spread the domains over the spectrum based on how central or peripheral their centroids are
    color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
    # Rotate the whole palette by one small random offset. Relative spacing is preserved, and `% 1` wraps the
    # positions back into the unit interval.
    global_shift = np.random.uniform(-0.1, 0.1)
    color_positions = (color_positions + global_shift) % 1
    color_positions = np.clip(color_positions, 0, 1)

    return [colormap(pos) for pos in color_positions]
|
209
|
+
|
210
|
+
|
211
|
+
def _assign_distant_colors(dist_matrix, num_colors_to_generate):
    """Assign distinct colormap positions, interleaving spatially central and peripheral domains.

    Domains are ranked by the sum of their distances to all other centroids, in ascending order, so the most
    central domain comes first. Positions `k / n` (n = `num_colors_to_generate`) are then handed out alternately
    from the two ends of that ranking:
    - the most central domain gets 0;
    - the most peripheral gets 1/n;
    - the second most central gets 2/n;
    - the second most peripheral gets 3/n; and so on.
    When n is odd, the remaining middle-ranked domain gets (n - 1) / n, so every domain receives a distinct
    position.

    Args:
        dist_matrix (np.ndarray): Square matrix of pairwise centroid distances, shape (n, n).
        num_colors_to_generate (int): Number of colors (domains) to position.

    Returns:
        np.ndarray: Array of n distinct color positions in the range [0, 1).
    """
    color_positions = np.zeros(num_colors_to_generate)
    # Rank domains by total distance to all others, most central first. sorted() is stable, so ties keep index
    # order.
    proximity_order = sorted(
        range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
    )

    # Interleave positions from both ends of the ranking
    half_spectrum = num_colors_to_generate // 2
    for i in range(half_spectrum):
        color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
        color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate

    # With an odd count the loop never reaches the middle-ranked domain. (n - 1) / n is the only position the loop
    # leaves unused, so assigning it there keeps every position (and therefore every color) distinct.
    if num_colors_to_generate % 2 == 1:
        color_positions[proximity_order[half_spectrum]] = (
            num_colors_to_generate - 1
        ) / num_colors_to_generate

    return color_positions
|
238
|
+
|
239
|
+
|
240
|
+
def _transform_colors(
    colors: np.ndarray,
    enrichment_sums: np.ndarray,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
) -> np.ndarray:
    """Transform colors using power scaling to emphasize high enrichment sums more.

    Pure black colors are first replaced with very dark grey (rgb(0.1, 0.1, 0.1)), since scaling leaves an
    all-zero color unchanged. Each row's RGB channels are then multiplied by

        min_scale + (max_scale - min_scale) * (s / max(s)) ** scale_factor

    where `s` are the enrichment sums. Alpha is left unchanged, and the input array is not modified.

    Args:
        colors (np.ndarray): An array of RGBA colors, shape (n, 4).
        enrichment_sums (np.ndarray): An array of n enrichment sums corresponding to the colors.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
        scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
            values more. Defaults to 1.0.

    Returns:
        np.ndarray: A new array of RGBA colors with adjusted intensities.
    """
    # Kept for identical output: nudge min_scale slightly below max_scale when the two bounds coincide
    if min_scale == max_scale:
        min_scale = max_scale - 10e-6

    # Work on a float copy so the caller's array is left untouched
    colors = np.array(colors, dtype=float)

    # Replace exact pure-black RGB (#000000) with very dark grey (#1A1A1A); alpha is not touched
    is_black = np.all(colors[:, :3] == 0.0, axis=1)
    colors[is_black, :3] = 0.1

    # Normalize the enrichment sums to [0, 1]. An all-zero (or empty) input would otherwise divide by zero and turn
    # every color into NaN; it maps to zeros instead (so, for a positive scale_factor, every node gets min_scale).
    enrichment_sums = np.asarray(enrichment_sums, dtype=float)
    max_sum = np.max(enrichment_sums) if enrichment_sums.size else 0.0
    if max_sum != 0:
        normalized_sums = enrichment_sums / max_sum
    else:
        normalized_sums = np.zeros_like(enrichment_sums)

    # Power scaling dims low values more when scale_factor > 1; the result is then mapped linearly into
    # [min_scale, max_scale]
    scaled_sums = min_scale + (max_scale - min_scale) * normalized_sums**scale_factor
    # Scale only the RGB channels, one factor per row
    colors[:, :3] *= scaled_sums[:, np.newaxis]

    return colors
|
283
|
+
|
284
|
+
|
285
|
+
def to_rgba(
    color: Union[str, list, tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert color(s) to RGBA format, applying alpha and repeating as needed.

    Args:
        color (str, list, tuple, np.ndarray): The color(s) to convert. Accepted forms:
            - a string (e.g., 'red' or '#FF5733');
            - a list or tuple of RGB/RGBA values;
            - a sequence of such colors;
            - an `np.ndarray`, either one 1-D color or a 2-D array with one color per row.
        alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing
            alpha values found in color.
        num_repeats (int, None, optional): If provided (including 0), the result contains exactly this many colors,
            cycling through the input colors. Defaults to None.

    Returns:
        np.ndarray: 2-D array of shape (k, 4) of RGBA colors. k is `num_repeats` if provided, otherwise the number
            of input colors.

    Raises:
        ValueError: If `color`, or any element of a color sequence, is not a recognizable color.
    """

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
        # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
        if isinstance(c, str):
            # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
            rgba = np.array(mcolors.to_rgba(c))
        elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
            # Convert RGB (3) or RGBA (4) values to RGBA format
            rgba = np.array(mcolors.to_rgba(c))
        else:
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )

        if alpha is not None:  # Override alpha if provided
            rgba[3] = alpha
        return rgba

    # If color is a 2D array of RGBA values, convert it to a list of lists
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(c) for c in color]

    is_sequence = isinstance(color, (list, tuple, np.ndarray))
    # A single color is a string, or a flat sequence of numbers (no nested colors or color names)
    is_single_color = isinstance(color, str) or (
        is_sequence and not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
    )

    if is_single_color:
        rgba_colors = np.array([convert_to_rgba(color)])
    elif is_sequence:
        # Convert each color in the sequence to RGBA
        rgba_colors = np.array([convert_to_rgba(c) for c in color])
    else:
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    if num_repeats is None:
        return rgba_colors

    # Cycle through the input colors to produce exactly `num_repeats` rows. `rgba_colors` is never empty here,
    # because an empty sequence is rejected by convert_to_rgba. num_repeats=0 yields an empty (0, 4) array
    # instead of silently returning a single color.
    cycle_indices = np.arange(num_repeats) % len(rgba_colors)
    return rgba_colors[cycle_indices]
|
@@ -0,0 +1,53 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plot/utils/layout
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Tuple
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
|
11
|
+
def calculate_bounding_box(
    node_coordinates: np.ndarray, radius_margin: float = 1.05
) -> Tuple[np.ndarray, float]:
    """Return the center and padded radius of the axis-aligned box enclosing all nodes.

    The radius is half of the box's longer side, multiplied by `radius_margin`.

    Args:
        node_coordinates (np.ndarray): Array of node coordinates, one (x, y) row per node.
        radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.

    Returns:
        tuple: (center, radius). center is a length-2 array [x, y]; radius is already scaled by `radius_margin`.
    """
    lower_corner = np.min(node_coordinates, axis=0)
    upper_corner = np.max(node_coordinates, axis=0)
    x_min, y_min = lower_corner
    x_max, y_max = upper_corner

    # Midpoint of the box along each axis
    center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
    # Half of the longer side, padded by the margin factor
    longest_side = max(x_max - x_min, y_max - y_min)
    radius = longest_side / 2 * radius_margin

    return center, radius
|
31
|
+
|
32
|
+
|
33
|
+
def calculate_centroids(network, domain_id_to_node_ids_map):
    """Compute the mean (x, y) position of each domain's nodes.

    Node positions are read from the "x" and "y" attributes of `network.nodes[node_id]`.

    Args:
        network (NetworkX graph): The graph representing the network.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.

    Returns:
        List[Tuple[float, float]]: One (x, y) centroid per domain, in the mapping's iteration order.
    """

    def _centroid_of(node_ids):
        # Stack each member node's (x, y) pair, then average column-wise
        positions = np.array(
            [[network.nodes[node]["x"], network.nodes[node]["y"]] for node in node_ids]
        )
        return tuple(np.mean(positions, axis=0))

    return [_centroid_of(node_ids) for node_ids in domain_id_to_node_ids_map.values()]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
risk/__init__.py,sha256=
|
1
|
+
risk/__init__.py,sha256=pF-sA0CxulV6zIC6axfgeXWkkQGyVZGuoGF_u4RgoQM,113
|
2
2
|
risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
|
3
3
|
risk/risk.py,sha256=slJXca_a726_D7oXwe765HaKTv3ZrOvhttyrWdCGPkA,21231
|
4
4
|
risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
|
@@ -13,15 +13,16 @@ risk/neighborhoods/domains.py,sha256=D5MUIghbwyKKCAE8PN_HXvsO9NxLTGejQmyEqetD1Bk
|
|
13
13
|
risk/neighborhoods/neighborhoods.py,sha256=M-wL4xB_BUTlSZg90swygO5NdrZ6hFUFqs6jsiZaqHk,18260
|
14
14
|
risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
|
15
15
|
risk/network/geometry.py,sha256=H1yGVVqgbfpzBzJwEheDLfvGLSA284jGQQTn612L4Vc,6759
|
16
|
-
risk/network/graph.py,sha256=
|
16
|
+
risk/network/graph.py,sha256=x5cur1meitkR0YuE5vGxX0s_IFa5wkx8z44f_C1vK7U,6509
|
17
17
|
risk/network/io.py,sha256=u0PPcKjp6Xze--7eDOlvalYkjQ9S2sjiC-ac2476PUI,22942
|
18
18
|
risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
|
19
|
-
risk/network/plot/canvas.py,sha256=
|
20
|
-
risk/network/plot/contour.py,sha256=
|
21
|
-
risk/network/plot/labels.py,sha256=
|
22
|
-
risk/network/plot/network.py,sha256=
|
23
|
-
risk/network/plot/plotter.py,sha256=
|
24
|
-
risk/network/plot/utils.py,sha256=
|
19
|
+
risk/network/plot/canvas.py,sha256=JnjPQaryRb_J6LP36BT2-rlsbJO3T4tTBornL8Oqqbs,10778
|
20
|
+
risk/network/plot/contour.py,sha256=8uwJ7K-Z6VMyr_uQ5VUyoQSqDHA7zDvR_nYAmLn60-I,14647
|
21
|
+
risk/network/plot/labels.py,sha256=Gt2HIbYsC5QZNZa_Mk29hijDVv0_V0pXEl1IILaNKkY,44016
|
22
|
+
risk/network/plot/network.py,sha256=9blVFeCp5x5XoGhPwOOdADegXC4gC72c2vrM2u4QPe0,13235
|
23
|
+
risk/network/plot/plotter.py,sha256=lN-_GDXRk9V3IFu8q7QmPjJGBZiP0QYwSvU6dVVDV2E,5770
|
24
|
+
risk/network/plot/utils/color.py,sha256=4W4EoQ_Fs4tmbngdczXnFkkAjvyYP5EV_P2Vu-TCCwY,15573
|
25
|
+
risk/network/plot/utils/layout.py,sha256=znssSqe2VZzzSz47hLZtTuXwMTpHR9b8lkQPL0BX7OA,1950
|
25
26
|
risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
|
26
27
|
risk/stats/hypergeom.py,sha256=o6Qnj31gCAKxr2uQirXrbv7XvdDJGEq69MFW-ubx_hA,2272
|
27
28
|
risk/stats/poisson.py,sha256=8x9hB4DCukq4gNIlIKO-c_jYG1-BTwTX53oLauFyfj8,1793
|
@@ -29,8 +30,8 @@ risk/stats/stats.py,sha256=kvShov-94W6ffgDUTb522vB9hDJQSyTsYif_UIaFfSM,7059
|
|
29
30
|
risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
|
30
31
|
risk/stats/permutation/permutation.py,sha256=D84Rcpt6iTQniK0PfQGcw9bLcHbMt9p-ARcurUnIXZQ,10095
|
31
32
|
risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
|
32
|
-
risk_network-0.0.
|
33
|
-
risk_network-0.0.
|
34
|
-
risk_network-0.0.
|
35
|
-
risk_network-0.0.
|
36
|
-
risk_network-0.0.
|
33
|
+
risk_network-0.0.8b17.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
|
34
|
+
risk_network-0.0.8b17.dist-info/METADATA,sha256=HmkNnAzSs7IdEGzYirOsrzy0euWVd0Ezn4OztCtXhnE,47498
|
35
|
+
risk_network-0.0.8b17.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
36
|
+
risk_network-0.0.8b17.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
|
37
|
+
risk_network-0.0.8b17.dist-info/RECORD,,
|
risk/network/plot/utils.py
DELETED
@@ -1,153 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
risk/network/plot/utils
|
3
|
-
~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
-
"""
|
5
|
-
|
6
|
-
from typing import List, Tuple, Union
|
7
|
-
|
8
|
-
import matplotlib.colors as mcolors
|
9
|
-
import numpy as np
|
10
|
-
|
11
|
-
from risk.network.graph import NetworkGraph
|
12
|
-
|
13
|
-
|
14
|
-
def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain.

    Per-node colors come from `graph.get_domain_colors`. A domain with more than one node is represented by the
    color of its brightest node, where brightness is the sum of the R, G and B channels. Domains with a single node
    (or none) are represented by opaque white.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
        color (str or None, optional): Color to use for the domains. If None, the colormap will be used.
            Defaults to None.
        min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
            Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
            Defaults to 1.0.
        scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
            enrichment. Higher values increase the contrast. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

    Returns:
        np.ndarray: One RGBA color per domain, in the iteration order of `graph.domain_id_to_node_ids_map`.
    """
    per_node_colors = graph.get_domain_colors(
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )

    def _representative_color(member_nodes) -> np.ndarray:
        # Single-node (and empty) domains are shown in opaque white
        if len(member_nodes) <= 1:
            return np.ones(4)
        member_colors = np.array([per_node_colors[node] for node in member_nodes])
        # Brightness is approximated by the sum of the RGB channels; alpha is ignored
        rgb_totals = member_colors[:, :3].sum(axis=1)
        return member_colors[np.argmax(rgb_totals)]

    return np.array(
        [_representative_color(node_ids) for node_ids in graph.domain_id_to_node_ids_map.values()]
    )
|
64
|
-
|
65
|
-
|
66
|
-
def calculate_bounding_box(
    node_coordinates: np.ndarray, radius_margin: float = 1.05
) -> Tuple[np.ndarray, float]:
    """Return the center and padded radius of the axis-aligned box enclosing all nodes.

    The radius is half of the box's longer side, multiplied by `radius_margin`.

    Args:
        node_coordinates (np.ndarray): Array of node coordinates, one (x, y) row per node.
        radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.

    Returns:
        tuple: (center, radius). center is a length-2 array [x, y]; radius is already scaled by `radius_margin`.
    """
    lower_corner = np.min(node_coordinates, axis=0)
    upper_corner = np.max(node_coordinates, axis=0)
    x_min, y_min = lower_corner
    x_max, y_max = upper_corner

    # Midpoint of the box along each axis
    center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
    # Half of the longer side, padded by the margin factor
    longest_side = max(x_max - x_min, y_max - y_min)
    radius = longest_side / 2 * radius_margin

    return center, radius
|
86
|
-
|
87
|
-
|
88
|
-
def to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert color(s) to RGBA format, applying alpha and repeating as needed.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Accepted forms:
            - a string;
            - a list or tuple of RGB/RGBA values;
            - a sequence of such colors (including color names);
            - an np.ndarray, either one 1-D color or a 2-D array with one color per row.
        alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing
            alpha values found in color.
        num_repeats (int, None, optional): If provided (including 0), the result contains exactly this many colors,
            cycling through the input colors. Defaults to None.

    Returns:
        np.ndarray: 2-D array of shape (k, 4) of RGBA colors. k is `num_repeats` if provided, otherwise the number
            of input colors.

    Raises:
        ValueError: If `color`, or any element of a color sequence, is not a recognizable color.
    """

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
        # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
        if isinstance(c, str):
            # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
            rgba = np.array(mcolors.to_rgba(c))
        elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
            # Convert RGB (3) or RGBA (4) values to RGBA format
            rgba = np.array(mcolors.to_rgba(c))
        else:
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )

        if alpha is not None:  # Override alpha if provided
            rgba[3] = alpha
        return rgba

    # If color is a 2D array of RGBA values, convert it to a list of lists
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(c) for c in color]

    is_sequence = isinstance(color, (list, tuple, np.ndarray))
    # A single color is a string, or a flat sequence of numbers. Nested colors and color names (e.g.
    # ("red", "blue")) make it a sequence of colors. A 1-D np.ndarray of numbers is also a single color.
    is_single_color = isinstance(color, str) or (
        is_sequence and not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
    )

    if is_single_color:
        rgba_colors = np.array([convert_to_rgba(color)])
    elif is_sequence:
        # Convert each color in the sequence to RGBA
        rgba_colors = np.array([convert_to_rgba(c) for c in color])
    else:
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    if num_repeats is None:
        return rgba_colors

    # Cycle through the input colors to produce exactly `num_repeats` rows. `rgba_colors` is never empty here,
    # because an empty sequence is rejected by convert_to_rgba. num_repeats=0 yields an empty (0, 4) array.
    cycle_indices = np.arange(num_repeats) % len(rgba_colors)
    return rgba_colors[cycle_indices]
|
File without changes
|
File without changes
|
File without changes
|