risk-network 0.0.8b14__py3-none-any.whl → 0.0.8b16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.8-beta.14"
10
+ __version__ = "0.0.8-beta.16"
risk/network/graph.py CHANGED
@@ -4,12 +4,11 @@ risk/network/graph
4
4
  """
5
5
 
6
6
  from collections import defaultdict
7
- from typing import Any, Dict, List, Tuple, Union
7
+ from typing import Any, Dict, List
8
8
 
9
9
  import networkx as nx
10
10
  import numpy as np
11
11
  import pandas as pd
12
- import matplotlib
13
12
 
14
13
 
15
14
  class NetworkGraph:
@@ -107,139 +106,6 @@ class NetworkGraph:
107
106
 
108
107
  return domain_id_to_label_map
109
108
 
110
- def get_domain_colors(
111
- self,
112
- cmap: str = "gist_rainbow",
113
- color: Union[str, None] = None,
114
- min_scale: float = 0.8,
115
- max_scale: float = 1.0,
116
- scale_factor: float = 1.0,
117
- random_seed: int = 888,
118
- ) -> np.ndarray:
119
- """Generate composite colors for domains based on enrichment or specified colors.
120
-
121
- Args:
122
- cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
123
- color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
124
- min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
125
- Controls the dimmest colors. Defaults to 0.8.
126
- max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
127
- Controls the brightest colors. Defaults to 1.0.
128
- scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores.
129
- A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
130
- random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments.
131
- Defaults to 888.
132
-
133
- Returns:
134
- np.ndarray: Array of RGBA colors generated for each domain, based on enrichment or the specified color.
135
- """
136
- # Get colors for each domain
137
- domain_colors = self._get_domain_colors(cmap=cmap, color=color, random_seed=random_seed)
138
- # Generate composite colors for nodes
139
- node_colors = self._get_composite_node_colors(domain_colors)
140
- # Transform colors to ensure proper alpha values and intensity
141
- transformed_colors = _transform_colors(
142
- node_colors,
143
- self.node_enrichment_sums,
144
- min_scale=min_scale,
145
- max_scale=max_scale,
146
- scale_factor=scale_factor,
147
- )
148
-
149
- return transformed_colors
150
-
151
- def _get_domain_colors(
152
- self,
153
- cmap: str = "gist_rainbow",
154
- color: Union[str, None] = None,
155
- random_seed: int = 888,
156
- ) -> Dict[str, Any]:
157
- """Get colors for each domain.
158
-
159
- Args:
160
- cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
161
- color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
162
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
163
-
164
- Returns:
165
- dict: A dictionary mapping domain keys to their corresponding RGBA colors.
166
- """
167
- # Get colors for each domain based on node positions
168
- domain_colors = _get_colors(
169
- self.network,
170
- self.domain_id_to_node_ids_map,
171
- cmap=cmap,
172
- color=color,
173
- random_seed=random_seed,
174
- )
175
- return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
176
-
177
- def _get_composite_node_colors(self, domain_colors: np.ndarray) -> np.ndarray:
178
- """Generate composite colors for nodes based on domain colors and counts.
179
-
180
- Args:
181
- domain_colors (np.ndarray): Array of colors corresponding to each domain.
182
-
183
- Returns:
184
- np.ndarray: Array of composite colors for each node.
185
- """
186
- # Determine the number of nodes
187
- num_nodes = len(self.node_coordinates)
188
- # Initialize composite colors array with shape (number of nodes, 4) for RGBA
189
- composite_colors = np.zeros((num_nodes, 4))
190
- # Assign colors to nodes based on domain_colors
191
- for domain_id, nodes in self.domain_id_to_node_ids_map.items():
192
- color = domain_colors[domain_id]
193
- for node in nodes:
194
- composite_colors[node] = color
195
-
196
- return composite_colors
197
-
198
-
199
- def _transform_colors(
200
- colors: np.ndarray,
201
- enrichment_sums: np.ndarray,
202
- min_scale: float = 0.8,
203
- max_scale: float = 1.0,
204
- scale_factor: float = 1.0,
205
- ) -> np.ndarray:
206
- """Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
207
- very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
208
-
209
- Args:
210
- colors (np.ndarray): An array of RGBA colors.
211
- enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
212
- min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
213
- max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
214
- scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
215
- values more. Defaults to 1.0.
216
-
217
- Returns:
218
- np.ndarray: The transformed array of RGBA colors with adjusted intensities.
219
- """
220
- # Ensure that min_scale is less than max_scale
221
- if min_scale == max_scale:
222
- min_scale = max_scale - 10e-6 # Avoid division by zero
223
-
224
- # Replace black colors (#000000) with very dark grey (#1A1A1A)
225
- black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
226
- dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
227
- # Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
228
- is_black = np.all(colors[:, :3] == black_color, axis=1)
229
- colors[is_black, :3] = dark_grey
230
-
231
- # Normalize the enrichment sums to the range [0, 1]
232
- normalized_sums = enrichment_sums / np.max(enrichment_sums)
233
- # Apply power scaling to dim lower values and emphasize higher values
234
- scaled_sums = normalized_sums**scale_factor
235
- # Linearly scale the normalized sums to the range [min_scale, max_scale]
236
- scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
237
- # Adjust RGB values based on scaled sums
238
- for i in range(3): # Only adjust RGB values
239
- colors[:, i] = scaled_sums * colors[:, i]
240
-
241
- return colors
242
-
243
109
 
244
110
  def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
245
111
  """Convert 3D coordinates to 2D by unfolding a sphere to a plane.
@@ -291,103 +157,3 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
291
157
  }
292
158
  node_coordinates = np.vstack(list(node_positions.values()))
293
159
  return node_coordinates
294
-
295
-
296
- def _get_colors(
297
- network,
298
- domain_id_to_node_ids_map,
299
- cmap: str = "gist_rainbow",
300
- color: Union[str, None] = None,
301
- random_seed: int = 888,
302
- ) -> List[Tuple]:
303
- """Generate a list of RGBA colors based on domain centroids, ensuring that domains
304
- close in space get maximally separated colors, while keeping some randomness.
305
-
306
- Args:
307
- network (NetworkX graph): The graph representing the network.
308
- domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
309
- cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
310
- color (str or None, optional): A specific color to use for all generated colors.
311
- random_seed (int): Seed for random number generation. Defaults to 888.
312
-
313
- Returns:
314
- List[Tuple]: List of RGBA colors.
315
- """
316
- # Set random seed for reproducibility
317
- np.random.seed(random_seed)
318
- # Determine the number of colors to generate based on the number of domains
319
- num_colors_to_generate = len(domain_id_to_node_ids_map)
320
- if color:
321
- # Generate all colors as the same specified color
322
- rgba = matplotlib.colors.to_rgba(color)
323
- return [rgba] * num_colors_to_generate
324
-
325
- # Load colormap
326
- colormap = matplotlib.colormaps.get_cmap(cmap)
327
- # Step 1: Calculate centroids for each domain
328
- centroids = _calculate_centroids(network, domain_id_to_node_ids_map)
329
- # Step 2: Calculate pairwise distances between centroids
330
- centroid_array = np.array(centroids)
331
- dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
332
- # Step 3: Assign distant colors to close centroids
333
- color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
334
- # Step 4: Randomly shift the entire color palette while maintaining relative distances
335
- global_shift = np.random.uniform(-0.1, 0.1) # Small global shift to change the overall palette
336
- color_positions = (color_positions + global_shift) % 1 # Wrap around to keep within [0, 1]
337
- # Step 5: Ensure that all positions remain between 0 and 1
338
- color_positions = np.clip(color_positions, 0, 1)
339
-
340
- # Step 6: Generate RGBA colors based on positions
341
- return [colormap(pos) for pos in color_positions]
342
-
343
-
344
- def _calculate_centroids(network, domain_id_to_node_ids_map):
345
- """Calculate the centroid for each domain based on node x and y coordinates in the network.
346
-
347
- Args:
348
- network (NetworkX graph): The graph representing the network.
349
- domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
350
-
351
- Returns:
352
- List[Tuple[float, float]]: List of centroids (x, y) for each domain.
353
- """
354
- centroids = []
355
- for domain_id, node_ids in domain_id_to_node_ids_map.items():
356
- # Extract x and y coordinates from the network nodes
357
- node_positions = np.array(
358
- [[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
359
- )
360
- # Compute the centroid as the mean of the x and y coordinates
361
- centroid = np.mean(node_positions, axis=0)
362
- centroids.append(tuple(centroid))
363
-
364
- return centroids
365
-
366
-
367
- def _assign_distant_colors(dist_matrix, num_colors_to_generate):
368
- """Assign colors to centroids that are close in space, ensuring stark color differences.
369
-
370
- Args:
371
- dist_matrix (ndarray): Matrix of pairwise centroid distances.
372
- num_colors_to_generate (int): Number of colors to generate.
373
-
374
- Returns:
375
- np.array: Array of color positions in the range [0, 1].
376
- """
377
- color_positions = np.zeros(num_colors_to_generate)
378
- # Step 1: Sort indices by centroid proximity (based on sum of distances to others)
379
- proximity_order = sorted(
380
- range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
381
- )
382
- # Step 2: Assign colors starting with the most distant points in proximity order
383
- for i, idx in enumerate(proximity_order):
384
- color_positions[idx] = i / num_colors_to_generate
385
-
386
- # Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
387
- half_spectrum = int(num_colors_to_generate / 2)
388
- for i in range(half_spectrum):
389
- # Split the spectrum so that close centroids are assigned distant colors
390
- color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
391
- color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
392
-
393
- return color_positions
@@ -10,7 +10,8 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils import calculate_bounding_box, to_rgba
13
+ from risk.network.plot.utils.color import to_rgba
14
+ from risk.network.plot.utils.layout import calculate_bounding_box
14
15
 
15
16
 
16
17
  class Canvas:
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
13
13
 
14
14
  from risk.log import params, logger
15
15
  from risk.network.graph import NetworkGraph
16
- from risk.network.plot.utils import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
17
17
 
18
18
 
19
19
  class Contour:
@@ -11,7 +11,8 @@ import pandas as pd
11
11
 
12
12
  from risk.log import params
13
13
  from risk.network.graph import NetworkGraph
14
- from risk.network.plot.utils import calculate_bounding_box, get_annotated_domain_colors, to_rgba
14
+ from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
15
+ from risk.network.plot.utils.layout import calculate_bounding_box
15
16
 
16
17
  TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
17
18
 
@@ -34,7 +35,7 @@ class Labels:
34
35
  scale: float = 1.05,
35
36
  offset: float = 0.10,
36
37
  font: str = "Arial",
37
- fontcase: Union[str, None] = None,
38
+ fontcase: Union[str, Dict[str, str], None] = None,
38
39
  fontsize: int = 10,
39
40
  fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
40
41
  fontalpha: Union[float, None] = 1.0,
@@ -60,8 +61,10 @@ class Labels:
60
61
  scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
61
62
  offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
62
63
  font (str, optional): Font name for the labels. Defaults to "Arial".
63
- fontcase (str, None, optional): Case transformation for the labels. Can be "capitalize", "lower", "title",
64
- "upper", or None. Defaults to None.
64
+ fontcase (Union[str, Dict[str, str], None]): Defines how to transform the case of words.
65
+ - If a string (e.g., 'upper', 'lower', 'title'), applies the transformation to all words.
66
+ - If a dictionary, maps specific cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
67
+ - If None, no transformation is applied.
65
68
  fontsize (int, optional): Font size for the labels. Defaults to 10.
66
69
  fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array.
67
70
  Defaults to "black".
@@ -855,23 +858,43 @@ def _swap_and_evaluate(
855
858
  return _calculate_total_distance(swapped_positions, domain_centroids)
856
859
 
857
860
 
858
- def _apply_str_transformation(words: List[str], transformation: str) -> List[str]:
861
+ def _apply_str_transformation(
862
+ words: List[str], transformation: Union[str, Dict[str, str]]
863
+ ) -> List[str]:
859
864
  """Apply a user-specified case transformation to each word in the list without appending duplicates.
860
865
 
861
866
  Args:
862
867
  words (List[str]): A list of words to transform.
863
- transformation (str): The case transformation to apply (e.g., 'lower', 'upper', 'title', 'capitalize').
868
+ transformation (Union[str, Dict[str, str]]): A single transformation (e.g., 'lower', 'upper', 'title', 'capitalize')
869
+ or a dictionary mapping cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
864
870
 
865
871
  Returns:
866
872
  List[str]: A list of transformed words with no duplicates.
867
873
  """
868
874
  transformed_words = []
869
875
  for word in words:
870
- if hasattr(word, transformation):
871
- transformed_word = getattr(word, transformation)() # Apply the string method
872
-
873
- # Only append if the transformed word is not already in the list
874
- if transformed_word not in transformed_words:
875
- transformed_words.append(transformed_word)
876
+ # Convert the word to a string if it is not already
877
+ word = str(word)
878
+ transformed_word = word # Start with the original word
879
+ # If transformation is a string, apply it to all words
880
+ if isinstance(transformation, str):
881
+ if hasattr(word, transformation):
882
+ transformed_word = getattr(
883
+ word, transformation
884
+ )() # Apply the single transformation
885
+
886
+ # If transformation is a dictionary, apply case-specific transformations
887
+ elif isinstance(transformation, dict):
888
+ for case_type, transform in transformation.items():
889
+ if case_type == "lower" and word.islower() and transform:
890
+ transformed_word = getattr(word, transform)()
891
+ elif case_type == "upper" and word.isupper() and transform:
892
+ transformed_word = getattr(word, transform)()
893
+ elif case_type == "title" and word.istitle() and transform:
894
+ transformed_word = getattr(word, transform)()
895
+
896
+ # Only append if the transformed word is not already in the list
897
+ if transformed_word not in transformed_words:
898
+ transformed_words.append(transformed_word)
876
899
 
877
900
  return transformed_words
@@ -10,7 +10,7 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils import to_rgba
13
+ from risk.network.plot.utils.color import get_domain_colors, to_rgba
14
14
 
15
15
 
16
16
  class Network:
@@ -222,7 +222,8 @@ class Network:
222
222
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
223
223
  """
224
224
  # Get the initial domain colors for each node, which are returned as RGBA
225
- network_colors = self.graph.get_domain_colors(
225
+ network_colors = get_domain_colors(
226
+ graph=self.graph,
226
227
  cmap=cmap,
227
228
  color=color,
228
229
  min_scale=min_scale,
@@ -14,7 +14,8 @@ from risk.network.plot.canvas import Canvas
14
14
  from risk.network.plot.contour import Contour
15
15
  from risk.network.plot.labels import Labels
16
16
  from risk.network.plot.network import Network
17
- from risk.network.plot.utils import calculate_bounding_box, to_rgba
17
+ from risk.network.plot.utils.color import to_rgba
18
+ from risk.network.plot.utils.layout import calculate_bounding_box
18
19
 
19
20
 
20
21
  class NetworkPlotter(Canvas, Network, Contour, Labels):
@@ -0,0 +1,351 @@
1
+ """
2
+ risk/network/plot/utils/plot
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict, List, Tuple, Union
7
+
8
+ import matplotlib
9
+ import matplotlib.colors as mcolors
10
+ import numpy as np
11
+
12
+ from risk.network.graph import NetworkGraph
13
+ from risk.network.plot.utils.layout import calculate_centroids
14
+
15
+
16
def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain, for use when annotating domains.

    Per-node colors come from ``get_domain_colors``. For every domain with more than one node, the member
    color with the largest R + G + B sum is chosen; domains with zero or one node get opaque white.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap used to generate domain colors. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color for all domains; when None the colormap is used.
            Defaults to None.
        min_scale (float, optional): Lower bound of the enrichment-based intensity scale. Defaults to 0.8.
        max_scale (float, optional): Upper bound of the enrichment-based intensity scale. Defaults to 1.0.
        scale_factor (float, optional): Exponent controlling contrast of the enrichment scaling.
            Defaults to 1.0.
        random_seed (int, optional): Seed forwarded to color generation for reproducibility. Defaults to 888.

    Returns:
        np.ndarray: One RGBA row per domain, in ``graph.domain_id_to_node_ids_map`` iteration order.
    """
    per_node_colors = get_domain_colors(
        graph=graph,
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )

    def _representative(member_ids) -> np.ndarray:
        # Domains with at most one node are drawn in opaque white
        if len(member_ids) <= 1:
            return np.array([1.0, 1.0, 1.0, 1.0])
        member_colors = np.array([per_node_colors[member] for member in member_ids])
        rgb_totals = member_colors[:, :3].sum(axis=1)  # brightness proxy: sum of R, G, B
        return member_colors[np.argmax(rgb_totals)]

    return np.array(
        [_representative(member_ids) for member_ids in graph.domain_id_to_node_ids_map.values()]
    )
67
+
68
+
69
def get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Build per-node RGBA colors from domain colors, dimmed according to enrichment.

    The pipeline is: one color per domain (``_get_domain_colors``) -> spread onto member nodes
    (``_get_composite_node_colors``) -> intensity-scaled by ``graph.node_enrichment_sums``
    (``_transform_colors``).

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Name of the colormap used for domain colors. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color applied to every domain instead of the colormap.
            Defaults to None.
        min_scale (float, optional): Intensity multiplier for the least enriched nodes. Defaults to 0.8.
        max_scale (float, optional): Intensity multiplier for the most enriched nodes. Defaults to 1.0.
        scale_factor (float, optional): Exponent applied to normalized enrichment; larger values dim
            low-enrichment nodes more strongly. Defaults to 1.0.
        random_seed (int, optional): Seed controlling the palette shift. Defaults to 888.

    Returns:
        np.ndarray: RGBA colors, one row per node.
    """
    colors_by_domain = _get_domain_colors(
        graph=graph, cmap=cmap, color=color, random_seed=random_seed
    )
    return _transform_colors(
        _get_composite_node_colors(graph=graph, domain_colors=colors_by_domain),
        graph.node_enrichment_sums,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
    )
109
+
110
+
111
def _get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    random_seed: int = 888,
) -> Dict[str, Any]:
    """Map every domain ID to its RGBA color.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap name passed to ``_get_colors``. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color for all domains. Defaults to None.
        random_seed (int, optional): Seed passed to ``_get_colors``. Defaults to 888.

    Returns:
        dict: Domain ID -> RGBA color, keyed in ``graph.domain_id_to_node_ids_map`` order.
    """
    network = graph.network
    domain_map = graph.domain_id_to_node_ids_map
    palette = _get_colors(
        network,
        domain_map,
        cmap=cmap,
        color=color,
        random_seed=random_seed,
    )
    return {domain_id: rgba for domain_id, rgba in zip(domain_map.keys(), palette)}
137
+
138
+
139
def _get_composite_node_colors(graph: NetworkGraph, domain_colors: np.ndarray) -> np.ndarray:
    """Paint every node with the color of the domain it belongs to.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        domain_colors (np.ndarray): Colors indexable by domain ID (in practice the dict built by
            ``_get_domain_colors``).

    Returns:
        np.ndarray: Array of shape (number of nodes, 4). Nodes not in any domain stay (0, 0, 0, 0);
            a node listed in several domains keeps the color of the last one visited.
    """
    palette = np.zeros((len(graph.node_coordinates), 4))
    for domain_id, member_nodes in graph.domain_id_to_node_ids_map.items():
        rgba = domain_colors[domain_id]
        # Node IDs are used directly as row indices (assumes integer IDs in [0, number of nodes))
        palette[list(member_nodes)] = rgba
    return palette
160
+
161
+
162
def _get_colors(
    network,
    domain_id_to_node_ids_map,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    random_seed: int = 888,
) -> List[Tuple]:
    """Generate a list of RGBA colors based on domain centroids, ensuring that domains
    close in space get maximally separated colors, while keeping some randomness.

    Args:
        network (NetworkX graph): The graph representing the network.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str or None, optional): A specific color to use for all generated colors.
        random_seed (int): Seed for the palette shift. Defaults to 888.

    Returns:
        List[Tuple]: One RGBA color per domain, in mapping order. When ``color`` is given the result is
            the ``np.ndarray`` returned by ``to_rgba``. An empty mapping yields an empty list.
    """
    # Determine the number of colors to generate based on the number of domains
    num_colors_to_generate = len(domain_id_to_node_ids_map)
    if num_colors_to_generate == 0:
        return []
    if color:
        # Generate all colors as the same specified color
        return to_rgba(color, num_repeats=num_colors_to_generate)

    # Use a private generator instead of np.random.seed(): reseeding the global RNG is a hidden
    # side effect on every caller. RandomState(seed) draws the same values as the seeded global
    # legacy RNG, so palettes are unchanged.
    rng = np.random.RandomState(random_seed)

    # Load colormap
    colormap = matplotlib.colormaps.get_cmap(cmap)
    # Step 1: Calculate centroids for each domain
    centroids = calculate_centroids(network, domain_id_to_node_ids_map)
    # Step 2: Calculate pairwise distances between centroids
    centroid_array = np.array(centroids)
    dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
    # Step 3: Assign distant colors to close centroids
    color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
    # Step 4: Randomly shift the entire palette while keeping relative distances; "% 1" wraps the
    # positions back into [0, 1], so no further clipping is needed
    global_shift = rng.uniform(-0.1, 0.1)
    color_positions = (color_positions + global_shift) % 1

    # Step 5: Generate RGBA colors based on positions
    return [colormap(pos) for pos in color_positions]
208
+
209
+
210
+ def _assign_distant_colors(dist_matrix, num_colors_to_generate):
211
+ """Assign colors to centroids that are close in space, ensuring stark color differences.
212
+
213
+ Args:
214
+ dist_matrix (ndarray): Matrix of pairwise centroid distances.
215
+ num_colors_to_generate (int): Number of colors to generate.
216
+
217
+ Returns:
218
+ np.array: Array of color positions in the range [0, 1].
219
+ """
220
+ color_positions = np.zeros(num_colors_to_generate)
221
+ # Step 1: Sort indices by centroid proximity (based on sum of distances to others)
222
+ proximity_order = sorted(
223
+ range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
224
+ )
225
+ # Step 2: Assign colors starting with the most distant points in proximity order
226
+ for i, idx in enumerate(proximity_order):
227
+ color_positions[idx] = i / num_colors_to_generate
228
+
229
+ # Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
230
+ half_spectrum = int(num_colors_to_generate / 2)
231
+ for i in range(half_spectrum):
232
+ # Split the spectrum so that close centroids are assigned distant colors
233
+ color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
234
+ color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
235
+
236
+ return color_positions
237
+
238
+
239
+ def _transform_colors(
240
+ colors: np.ndarray,
241
+ enrichment_sums: np.ndarray,
242
+ min_scale: float = 0.8,
243
+ max_scale: float = 1.0,
244
+ scale_factor: float = 1.0,
245
+ ) -> np.ndarray:
246
+ """Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
247
+ very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
248
+
249
+ Args:
250
+ colors (np.ndarray): An array of RGBA colors.
251
+ enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
252
+ min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
253
+ max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
254
+ scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
255
+ values more. Defaults to 1.0.
256
+
257
+ Returns:
258
+ np.ndarray: The transformed array of RGBA colors with adjusted intensities.
259
+ """
260
+ # Ensure that min_scale is less than max_scale
261
+ if min_scale == max_scale:
262
+ min_scale = max_scale - 10e-6 # Avoid division by zero
263
+
264
+ # Replace black colors (#000000) with very dark grey (#1A1A1A)
265
+ black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
266
+ dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
267
+ # Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
268
+ is_black = np.all(colors[:, :3] == black_color, axis=1)
269
+ colors[is_black, :3] = dark_grey
270
+
271
+ # Normalize the enrichment sums to the range [0, 1]
272
+ normalized_sums = enrichment_sums / np.max(enrichment_sums)
273
+ # Apply power scaling to dim lower values and emphasize higher values
274
+ scaled_sums = normalized_sums**scale_factor
275
+ # Linearly scale the normalized sums to the range [min_scale, max_scale]
276
+ scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
277
+ # Adjust RGB values based on scaled sums
278
+ for i in range(3): # Only adjust RGB values
279
+ colors[:, i] = scaled_sums * colors[:, i]
280
+
281
+ return colors
282
+
283
+
284
def to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert color(s) to RGBA format, applying alpha and repeating as needed.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Either a single color (a color
            name / hex string, or an RGB/RGBA sequence of numbers) or a sequence of such colors.
        alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any
            existing alpha values found in color.
        num_repeats (int, None, optional): If provided, return exactly this many rows: a single color is
            tiled, and a list of colors is cycled. ``0`` yields an empty (0, 4) array. Defaults to None.

    Returns:
        np.ndarray: Array of RGBA colors of shape (n, 4).

    Raises:
        ValueError: If a color is neither a string nor an RGB/RGBA sequence.
    """

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
        # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
        if isinstance(c, str) or (isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]):
            rgba = np.array(mcolors.to_rgba(c))
        else:
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )

        if alpha is not None:  # Override alpha if provided
            rgba[3] = alpha
        return rgba

    # If color is a 2D array of RGBA values, convert it to a list of lists
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(c) for c in color]

    # A single color is a string, or a sequence none of whose items is itself a color
    is_single_color = isinstance(color, str) or (
        isinstance(color, (list, tuple, np.ndarray))
        and not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
    )
    if is_single_color:
        rgba_color = convert_to_rgba(color)
        # Compare against None so that num_repeats=0 means "zero rows", not "not provided"
        if num_repeats is not None:
            return np.tile(rgba_color, (num_repeats, 1))
        return np.array([rgba_color])  # Return a single color wrapped in a numpy array

    # Handle a list/array of colors
    if isinstance(color, (list, tuple, np.ndarray)):
        rgba_colors = np.array([convert_to_rgba(c) for c in color])
        if num_repeats is not None:
            # Cycle through the provided colors until num_repeats rows are produced
            return rgba_colors[np.arange(num_repeats) % len(rgba_colors)]
        return rgba_colors

    raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
@@ -0,0 +1,53 @@
1
+ """
2
+ risk/network/plot/utils/layout
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import numpy as np
9
+
10
+
11
+ def calculate_bounding_box(
12
+ node_coordinates: np.ndarray, radius_margin: float = 1.05
13
+ ) -> Tuple[np.ndarray, float]:
14
+ """Calculate the bounding box of the network based on node coordinates.
15
+
16
+ Args:
17
+ node_coordinates (np.ndarray): Array of node coordinates (x, y).
18
+ radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.
19
+
20
+ Returns:
21
+ tuple: Center of the bounding box and the radius (adjusted by the radius margin).
22
+ """
23
+ # Find minimum and maximum x, y coordinates
24
+ x_min, y_min = np.min(node_coordinates, axis=0)
25
+ x_max, y_max = np.max(node_coordinates, axis=0)
26
+ # Calculate the center of the bounding box
27
+ center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
28
+ # Calculate the radius of the bounding box, adjusted by the margin
29
+ radius = max(x_max - x_min, y_max - y_min) / 2 * radius_margin
30
+ return center, radius
31
+
32
+
33
+ def calculate_centroids(network, domain_id_to_node_ids_map):
34
+ """Calculate the centroid for each domain based on node x and y coordinates in the network.
35
+
36
+ Args:
37
+ network (NetworkX graph): The graph representing the network.
38
+ domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
39
+
40
+ Returns:
41
+ List[Tuple[float, float]]: List of centroids (x, y) for each domain.
42
+ """
43
+ centroids = []
44
+ for domain_id, node_ids in domain_id_to_node_ids_map.items():
45
+ # Extract x and y coordinates from the network nodes
46
+ node_positions = np.array(
47
+ [[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
48
+ )
49
+ # Compute the centroid as the mean of the x and y coordinates
50
+ centroid = np.mean(node_positions, axis=0)
51
+ centroids.append(tuple(centroid))
52
+
53
+ return centroids
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.8b14
3
+ Version: 0.0.8b16
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -1,4 +1,4 @@
1
- risk/__init__.py,sha256=m9OKov65SuifezPxufH3hRsggm2xXDSzlRDR6p5FURo,113
1
+ risk/__init__.py,sha256=tmGuKXSNSCXcVzCf-nWdonudI0kVDYxEZHIa6MfXIOE,113
2
2
  risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
3
3
  risk/risk.py,sha256=slJXca_a726_D7oXwe765HaKTv3ZrOvhttyrWdCGPkA,21231
4
4
  risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
@@ -13,15 +13,16 @@ risk/neighborhoods/domains.py,sha256=D5MUIghbwyKKCAE8PN_HXvsO9NxLTGejQmyEqetD1Bk
13
13
  risk/neighborhoods/neighborhoods.py,sha256=M-wL4xB_BUTlSZg90swygO5NdrZ6hFUFqs6jsiZaqHk,18260
14
14
  risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
15
15
  risk/network/geometry.py,sha256=H1yGVVqgbfpzBzJwEheDLfvGLSA284jGQQTn612L4Vc,6759
16
- risk/network/graph.py,sha256=WxWh1BZvH9E_kgDb5HrwcQ-30hM8aNHPkBm2l3vzkAw,16973
16
+ risk/network/graph.py,sha256=x5cur1meitkR0YuE5vGxX0s_IFa5wkx8z44f_C1vK7U,6509
17
17
  risk/network/io.py,sha256=u0PPcKjp6Xze--7eDOlvalYkjQ9S2sjiC-ac2476PUI,22942
18
18
  risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
19
- risk/network/plot/canvas.py,sha256=LXHndwanWIBShChoPag8zgGHF2P9MFWYdEnLKc2eeb0,10295
20
- risk/network/plot/contour.py,sha256=YPG8Uz0VlJ4skLdGaTH_FmQN6A_ArK8XSTNo1LzkSws,14276
21
- risk/network/plot/labels.py,sha256=PV21hig6gQJZRgfUAP9-zpn4wEmQFEjS2_X63SgzWMs,42064
22
- risk/network/plot/network.py,sha256=t5eMh7mBJOh_Wa19aK8g_1zpL7maiXZAYw-TEFHxAVM,12819
23
- risk/network/plot/plotter.py,sha256=rQV4Db6Ud86FJm11uaBvgSuzpmGsrZxnsRnUKjg6w84,5572
24
- risk/network/plot/utils.py,sha256=jZgI8EysSjviQmdYAceZk2MwJXcdeFAkYp-odZNqV0k,6316
19
+ risk/network/plot/canvas.py,sha256=s5nB2c1xY5ymxaDYHwXWc74LjfgUaBFU3xXmsolZrcI,10343
20
+ risk/network/plot/contour.py,sha256=aS-UGF0M7MQ7zggx7A6d0dYZJFlOFQHjsU9Q8ls15nU,14282
21
+ risk/network/plot/labels.py,sha256=JkLTXglKK4L1nI1G7GcHg5vUAUJxympUerbbPWmcL1E,43419
22
+ risk/network/plot/network.py,sha256=h4KqQR5AcJIM23kKvz1toTsDtHo11HcMxGbdV-6tux4,12863
23
+ risk/network/plot/plotter.py,sha256=eZ3X6XCNiAbJjWTXw88btbwJ0TI7Fujk6A1EVkOqJkE,5620
24
+ risk/network/plot/utils/color.py,sha256=Wrq7j5vYYUi2TgZvzsowyVdx7tPP3EsvOBz_vQP6Oic,15091
25
+ risk/network/plot/utils/layout.py,sha256=znssSqe2VZzzSz47hLZtTuXwMTpHR9b8lkQPL0BX7OA,1950
25
26
  risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
26
27
  risk/stats/hypergeom.py,sha256=o6Qnj31gCAKxr2uQirXrbv7XvdDJGEq69MFW-ubx_hA,2272
27
28
  risk/stats/poisson.py,sha256=8x9hB4DCukq4gNIlIKO-c_jYG1-BTwTX53oLauFyfj8,1793
@@ -29,8 +30,8 @@ risk/stats/stats.py,sha256=kvShov-94W6ffgDUTb522vB9hDJQSyTsYif_UIaFfSM,7059
29
30
  risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
30
31
  risk/stats/permutation/permutation.py,sha256=D84Rcpt6iTQniK0PfQGcw9bLcHbMt9p-ARcurUnIXZQ,10095
31
32
  risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
32
- risk_network-0.0.8b14.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
33
- risk_network-0.0.8b14.dist-info/METADATA,sha256=QUStGHs-tu2cPCAtumH931wZx1ll8PUaewDgPROd9hM,47498
34
- risk_network-0.0.8b14.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
35
- risk_network-0.0.8b14.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
36
- risk_network-0.0.8b14.dist-info/RECORD,,
33
+ risk_network-0.0.8b16.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
34
+ risk_network-0.0.8b16.dist-info/METADATA,sha256=6aKAr6uhZlBaoEflmp2A3vEx9naheGYAsb2RtF8cYZc,47498
35
+ risk_network-0.0.8b16.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
36
+ risk_network-0.0.8b16.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
37
+ risk_network-0.0.8b16.dist-info/RECORD,,
@@ -1,153 +0,0 @@
1
- """
2
- risk/network/plot/utils
3
- ~~~~~~~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from typing import List, Tuple, Union
7
-
8
- import matplotlib.colors as mcolors
9
- import numpy as np
10
-
11
- from risk.network.graph import NetworkGraph
12
-
13
-
14
- def get_annotated_domain_colors(
15
- graph: NetworkGraph,
16
- cmap: str = "gist_rainbow",
17
- color: Union[str, None] = None,
18
- min_scale: float = 0.8,
19
- max_scale: float = 1.0,
20
- scale_factor: float = 1.0,
21
- random_seed: int = 888,
22
- ) -> np.ndarray:
23
- """Get colors for the domains based on node annotations, or use a specified color.
24
-
25
- Args:
26
- graph (NetworkGraph): The network data and attributes to be visualized.
27
- cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
28
- color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
29
- min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
30
- Defaults to 0.8.
31
- max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
32
- Defaults to 1.0.
33
- scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
34
- enrichment. Higher values increase the contrast. Defaults to 1.0.
35
- random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
36
-
37
- Returns:
38
- np.ndarray: Array of RGBA colors for each domain.
39
- """
40
- # Generate domain colors based on the enrichment data
41
- node_colors = graph.get_domain_colors(
42
- cmap=cmap,
43
- color=color,
44
- min_scale=min_scale,
45
- max_scale=max_scale,
46
- scale_factor=scale_factor,
47
- random_seed=random_seed,
48
- )
49
- annotated_colors = []
50
- for _, node_ids in graph.domain_id_to_node_ids_map.items():
51
- if len(node_ids) > 1:
52
- # For multi-node domains, choose the brightest color based on RGB sum
53
- domain_colors = np.array([node_colors[node] for node in node_ids])
54
- brightest_color = domain_colors[
55
- np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
56
- ]
57
- annotated_colors.append(brightest_color)
58
- else:
59
- # Single-node domains default to white (RGBA)
60
- default_color = np.array([1.0, 1.0, 1.0, 1.0])
61
- annotated_colors.append(default_color)
62
-
63
- return np.array(annotated_colors)
64
-
65
-
66
- def calculate_bounding_box(
67
- node_coordinates: np.ndarray, radius_margin: float = 1.05
68
- ) -> Tuple[np.ndarray, float]:
69
- """Calculate the bounding box of the network based on node coordinates.
70
-
71
- Args:
72
- node_coordinates (np.ndarray): Array of node coordinates (x, y).
73
- radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.
74
-
75
- Returns:
76
- tuple: Center of the bounding box and the radius (adjusted by the radius margin).
77
- """
78
- # Find minimum and maximum x, y coordinates
79
- x_min, y_min = np.min(node_coordinates, axis=0)
80
- x_max, y_max = np.max(node_coordinates, axis=0)
81
- # Calculate the center of the bounding box
82
- center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
83
- # Calculate the radius of the bounding box, adjusted by the margin
84
- radius = max(x_max - x_min, y_max - y_min) / 2 * radius_margin
85
- return center, radius
86
-
87
-
88
- def to_rgba(
89
- color: Union[str, List, Tuple, np.ndarray],
90
- alpha: Union[float, None] = None,
91
- num_repeats: Union[int, None] = None,
92
- ) -> np.ndarray:
93
- """Convert color(s) to RGBA format, applying alpha and repeating as needed.
94
-
95
- Args:
96
- color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
97
- alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values
98
- found in color.
99
- num_repeats (int, None, optional): If provided, the color(s) will be repeated this many times. Defaults to None.
100
-
101
- Returns:
102
- np.ndarray: Array of RGBA colors repeated `num_repeats` times, if applicable.
103
- """
104
-
105
- def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
106
- """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
107
- # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
108
- if isinstance(c, str):
109
- # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
110
- rgba = np.array(mcolors.to_rgba(c))
111
- elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
112
- # Convert RGB (3) or RGBA (4) values to RGBA format
113
- rgba = np.array(mcolors.to_rgba(c))
114
- else:
115
- raise ValueError(
116
- f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
117
- )
118
-
119
- if alpha is not None: # Override alpha if provided
120
- rgba[3] = alpha
121
- return rgba
122
-
123
- # If color is a 2D array of RGBA values, convert it to a list of lists
124
- if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
125
- color = [list(c) for c in color]
126
-
127
- # Handle a single color (string or RGB/RGBA list/tuple)
128
- if isinstance(color, (str, list, tuple)) and not any(
129
- isinstance(c, (list, tuple, np.ndarray)) for c in color
130
- ):
131
- rgba_color = convert_to_rgba(color)
132
- if num_repeats:
133
- return np.tile(
134
- rgba_color, (num_repeats, 1)
135
- ) # Repeat the color if num_repeats is provided
136
- return np.array([rgba_color]) # Return a single color wrapped in a numpy array
137
-
138
- # Handle a list/array of colors
139
- elif isinstance(color, (list, tuple, np.ndarray)):
140
- rgba_colors = np.array(
141
- [convert_to_rgba(c) for c in color]
142
- ) # Convert each color in the list to RGBA
143
- # Handle repetition if num_repeats is provided
144
- if num_repeats:
145
- repeated_colors = np.array(
146
- [rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)]
147
- )
148
- return repeated_colors
149
-
150
- return rgba_colors
151
-
152
- else:
153
- raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")