risk-network 0.0.8b26__py3-none-any.whl → 0.0.9b26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +74 -47
  4. risk/annotations/io.py +47 -31
  5. risk/log/__init__.py +4 -2
  6. risk/log/{config.py → console.py} +5 -3
  7. risk/log/{params.py → parameters.py} +17 -42
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +446 -0
  10. risk/neighborhoods/community.py +255 -77
  11. risk/neighborhoods/domains.py +62 -31
  12. risk/neighborhoods/neighborhoods.py +156 -160
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +65 -57
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +194 -0
  17. risk/network/{graph.py → graph/network.py} +87 -37
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +56 -47
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +7 -4
  23. risk/network/{plot → plotter}/contour.py +22 -19
  24. risk/network/{plot → plotter}/labels.py +69 -74
  25. risk/network/{plot → plotter}/network.py +170 -34
  26. risk/network/{plot/utils/color.py → plotter/utils/colors.py} +104 -112
  27. risk/network/{plot → plotter}/utils/layout.py +8 -5
  28. risk/risk.py +11 -500
  29. risk/stats/__init__.py +8 -4
  30. risk/stats/binom.py +51 -0
  31. risk/stats/chi2.py +69 -0
  32. risk/stats/hypergeom.py +27 -17
  33. risk/stats/permutation/__init__.py +1 -1
  34. risk/stats/permutation/permutation.py +44 -38
  35. risk/stats/permutation/test_functions.py +25 -17
  36. risk/stats/poisson.py +15 -9
  37. risk/stats/stats.py +15 -13
  38. risk/stats/zscore.py +68 -0
  39. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
  40. risk_network-0.0.9b26.dist-info/RECORD +44 -0
  41. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
  42. risk/network/plot/__init__.py +0 -6
  43. risk/network/plot/plotter.py +0 -137
  44. risk_network-0.0.8b26.dist-info/RECORD +0 -37
  45. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
  46. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
@@ -9,8 +9,7 @@ import matplotlib
9
9
  import matplotlib.colors as mcolors
10
10
  import numpy as np
11
11
 
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils.layout import calculate_centroids
12
+ from risk.network.graph.network import NetworkGraph
14
13
 
15
14
 
16
15
  def get_annotated_domain_colors(
@@ -22,6 +21,7 @@ def get_annotated_domain_colors(
22
21
  min_scale: float = 0.8,
23
22
  max_scale: float = 1.0,
24
23
  scale_factor: float = 1.0,
24
+ ids_to_colors: Union[Dict[int, Any], None] = None,
25
25
  random_seed: int = 888,
26
26
  ) -> np.ndarray:
27
27
  """Get colors for the domains based on node annotations, or use a specified color.
@@ -35,14 +35,15 @@ def get_annotated_domain_colors(
35
35
  blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
36
36
  min_scale (float, optional): Minimum scale for color intensity when generating domain colors. Defaults to 0.8.
37
37
  max_scale (float, optional): Maximum scale for color intensity when generating domain colors. Defaults to 1.0.
38
- scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on enrichment. Higher values
38
+ scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on significance. Higher values
39
39
  increase the contrast. Defaults to 1.0.
40
+ ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
40
41
  random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
41
42
 
42
43
  Returns:
43
44
  np.ndarray: Array of RGBA colors for each domain.
44
45
  """
45
- # Generate domain colors based on the enrichment data
46
+ # Generate domain colors based on the significance data
46
47
  node_colors = get_domain_colors(
47
48
  graph=graph,
48
49
  cmap=cmap,
@@ -52,6 +53,7 @@ def get_annotated_domain_colors(
52
53
  min_scale=min_scale,
53
54
  max_scale=max_scale,
54
55
  scale_factor=scale_factor,
56
+ ids_to_colors=ids_to_colors,
55
57
  random_seed=random_seed,
56
58
  )
57
59
  annotated_colors = []
@@ -59,16 +61,14 @@ def get_annotated_domain_colors(
59
61
  if len(node_ids) > 1:
60
62
  # For multi-node domains, choose the brightest color based on RGB sum
61
63
  domain_colors = np.array([node_colors[node] for node in node_ids])
62
- brightest_color = domain_colors[
63
- np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
64
- ]
65
- annotated_colors.append(brightest_color)
64
+ color = domain_colors[np.argmax(domain_colors[:, :3].sum(axis=1))] # Sum the RGB values
66
65
  else:
67
66
  # Single-node domains default to white (RGBA)
68
- default_color = np.array([1.0, 1.0, 1.0, 1.0])
69
- annotated_colors.append(default_color)
67
+ color = np.array([1.0, 1.0, 1.0, 1.0])
68
+
69
+ annotated_colors.append(color)
70
70
 
71
- return np.array(annotated_colors)
71
+ return annotated_colors
72
72
 
73
73
 
74
74
  def get_domain_colors(
@@ -80,9 +80,10 @@ def get_domain_colors(
80
80
  min_scale: float = 0.8,
81
81
  max_scale: float = 1.0,
82
82
  scale_factor: float = 1.0,
83
+ ids_to_colors: Union[Dict[int, Any], None] = None,
83
84
  random_seed: int = 888,
84
85
  ) -> np.ndarray:
85
- """Generate composite colors for domains based on enrichment or specified colors.
86
+ """Generate composite colors for domains based on significance or specified colors.
86
87
 
87
88
  Args:
88
89
  graph (NetworkGraph): The network data and attributes to be visualized.
@@ -95,23 +96,29 @@ def get_domain_colors(
95
96
  Defaults to 0.8.
96
97
  max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap. Controls the brightest colors.
97
98
  Defaults to 1.0.
98
- scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores. Higher values increase
99
+ scale_factor (float, optional): Exponent for adjusting the color scaling based on significance scores. Higher values increase
99
100
  contrast by dimming lower scores more. Defaults to 1.0.
101
+ ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
100
102
  random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments. Defaults to 888.
101
103
 
102
104
  Returns:
103
- np.ndarray: Array of RGBA colors generated for each domain, based on enrichment or the specified color.
105
+ np.ndarray: Array of RGBA colors generated for each domain, based on significance or the specified color.
104
106
  """
105
107
  # Get colors for each domain
106
- domain_colors = _get_domain_colors(graph=graph, cmap=cmap, color=color, random_seed=random_seed)
108
+ domain_ids_to_colors = _get_domain_ids_to_colors(
109
+ graph=graph, cmap=cmap, color=color, ids_to_colors=ids_to_colors, random_seed=random_seed
110
+ )
107
111
  # Generate composite colors for nodes
108
112
  node_colors = _get_composite_node_colors(
109
- graph=graph, domain_colors=domain_colors, blend_colors=blend_colors, blend_gamma=blend_gamma
113
+ graph=graph,
114
+ domain_ids_to_colors=domain_ids_to_colors,
115
+ blend_colors=blend_colors,
116
+ blend_gamma=blend_gamma,
110
117
  )
111
118
  # Transform colors to ensure proper alpha values and intensity
112
119
  transformed_colors = _transform_colors(
113
120
  node_colors,
114
- graph.node_enrichment_sums,
121
+ graph.node_significance_sums,
115
122
  min_scale=min_scale,
116
123
  max_scale=max_scale,
117
124
  scale_factor=scale_factor,
@@ -119,10 +126,11 @@ def get_domain_colors(
119
126
  return transformed_colors
120
127
 
121
128
 
122
- def _get_domain_colors(
129
+ def _get_domain_ids_to_colors(
123
130
  graph: NetworkGraph,
124
131
  cmap: str = "gist_rainbow",
125
132
  color: Union[str, List, Tuple, np.ndarray, None] = None,
133
+ ids_to_colors: Union[Dict[int, Any], None] = None,
126
134
  random_seed: int = 888,
127
135
  ) -> Dict[int, Any]:
128
136
  """Get colors for each domain.
@@ -132,6 +140,7 @@ def _get_domain_colors(
132
140
  cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
133
141
  color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
134
142
  If None, the colormap will be used. Defaults to None.
143
+ ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
135
144
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
136
145
 
137
146
  Returns:
@@ -139,23 +148,35 @@ def _get_domain_colors(
139
148
  """
140
149
  # Get colors for each domain based on node positions
141
150
  domain_colors = _get_colors(
142
- graph.network,
143
151
  graph.domain_id_to_node_ids_map,
144
152
  cmap=cmap,
145
153
  color=color,
146
154
  random_seed=random_seed,
147
155
  )
148
- return dict(zip(graph.domain_id_to_node_ids_map.keys(), domain_colors))
156
+ # Assign colors to domains either based on the generated colormap or the user-specified colors
157
+ domain_ids_to_colors = {}
158
+ for domain_id, domain_color in zip(graph.domain_id_to_node_ids_map.keys(), domain_colors):
159
+ if ids_to_colors and domain_id in ids_to_colors:
160
+ # Convert user-specified colors to RGBA format
161
+ user_rgba = to_rgba(ids_to_colors[domain_id])
162
+ domain_ids_to_colors[domain_id] = user_rgba
163
+ else:
164
+ domain_ids_to_colors[domain_id] = domain_color
165
+
166
+ return domain_ids_to_colors
149
167
 
150
168
 
151
169
  def _get_composite_node_colors(
152
- graph, domain_colors: np.ndarray, blend_colors: bool = False, blend_gamma: float = 2.2
170
+ graph: NetworkGraph,
171
+ domain_ids_to_colors: Dict[int, Any],
172
+ blend_colors: bool = False,
173
+ blend_gamma: float = 2.2,
153
174
  ) -> np.ndarray:
154
- """Generate composite colors for nodes based on domain colors and enrichment values, with optional color blending.
175
+ """Generate composite colors for nodes based on domain colors and significance values, with optional color blending.
155
176
 
156
177
  Args:
157
178
  graph (NetworkGraph): The network data and attributes to be visualized.
158
- domain_colors (np.ndarray): Array or list of RGBA colors corresponding to each domain.
179
+ domain_ids_to_colors (Dict[int, Any]): Mapping of domain IDs to RGBA colors.
159
180
  blend_colors (bool): Whether to blend colors for nodes with multiple domains. Defaults to False.
160
181
  blend_gamma (float, optional): Gamma correction factor to be used for perceptual color blending.
161
182
  This parameter is only relevant if blend_colors is True. Defaults to 2.2.
@@ -167,36 +188,35 @@ def _get_composite_node_colors(
167
188
  num_nodes = len(graph.node_coordinates)
168
189
  # Initialize composite colors array with shape (number of nodes, 4) for RGBA
169
190
  composite_colors = np.zeros((num_nodes, 4))
170
-
171
191
  # If blending is not required, directly assign domain colors to nodes
172
192
  if not blend_colors:
173
193
  for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
174
- color = domain_colors[domain_id]
194
+ color = domain_ids_to_colors[domain_id]
175
195
  for node in nodes:
176
196
  composite_colors[node] = color
177
197
 
178
198
  # If blending is required
179
199
  else:
180
- for node, node_info in graph.node_id_to_domain_ids_and_enrichments_map.items():
200
+ for node, node_info in graph.node_id_to_domain_ids_and_significance_map.items():
181
201
  domains = node_info["domains"] # List of domain IDs
182
- enrichments = node_info["enrichments"] # List of enrichment values
183
- # Filter domains and enrichments to keep only those with corresponding colors in domain_colors
184
- filtered_domains_enrichments = [
185
- (domain_id, enrichment)
186
- for domain_id, enrichment in zip(domains, enrichments)
187
- if domain_id in domain_colors
202
+ significances = node_info["significances"] # List of significance values
203
+ # Filter domains and significances to keep only those with corresponding colors in domain_ids_to_colors
204
+ filtered_domains_significances = [
205
+ (domain_id, significance)
206
+ for domain_id, significance in zip(domains, significances)
207
+ if domain_id in domain_ids_to_colors
188
208
  ]
189
209
  # If no valid domains exist, skip this node
190
- if not filtered_domains_enrichments:
210
+ if not filtered_domains_significances:
191
211
  continue
192
212
 
193
- # Unpack filtered domains and enrichments
194
- filtered_domains, filtered_enrichments = zip(*filtered_domains_enrichments)
213
+ # Unpack filtered domains and significances
214
+ filtered_domains, filtered_significances = zip(*filtered_domains_significances)
195
215
  # Get the colors corresponding to the valid filtered domains
196
- colors = [domain_colors[domain_id] for domain_id in filtered_domains]
216
+ colors = [domain_ids_to_colors[domain_id] for domain_id in filtered_domains]
197
217
  # Blend the colors using the given gamma (default is 2.2 if None)
198
218
  gamma = blend_gamma if blend_gamma is not None else 2.2
199
- composite_color = _blend_colors_perceptually(colors, filtered_enrichments, gamma)
219
+ composite_color = _blend_colors_perceptually(colors, filtered_significances, gamma)
200
220
  # Assign the composite color to the node
201
221
  composite_colors[node] = composite_color
202
222
 
@@ -204,99 +224,57 @@ def _get_composite_node_colors(
204
224
 
205
225
 
206
226
  def _get_colors(
207
- network,
208
- domain_id_to_node_ids_map,
227
+ domain_id_to_node_ids_map: Dict[int, Any],
209
228
  cmap: str = "gist_rainbow",
210
229
  color: Union[str, List, Tuple, np.ndarray, None] = None,
211
230
  random_seed: int = 888,
212
231
  ) -> List[Tuple]:
213
- """Generate a list of RGBA colors based on domain centroids, ensuring that domains
214
- close in space get maximally separated colors, while keeping some randomness.
232
+ """Generate a list of RGBA colors for domains, ensuring maximally separated colors for nearby domains.
215
233
 
216
234
  Args:
217
- network (NetworkX graph): The graph representing the network.
218
235
  domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
219
236
  cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
220
- color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
237
+ color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use.
221
238
  If None, the colormap will be used. Defaults to None.
222
239
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
223
240
 
224
241
  Returns:
225
- List[Tuple]: List of RGBA colors.
242
+ List[Tuple]: List of RGBA colors for each domain.
226
243
  """
227
- # Set random seed for reproducibility
228
244
  np.random.seed(random_seed)
229
- # Determine the number of colors to generate based on the number of domains
230
- num_colors_to_generate = len(domain_id_to_node_ids_map)
245
+ num_domains = len(domain_id_to_node_ids_map)
231
246
  if color:
232
- # Generate all colors as the same specified color
233
- rgba = to_rgba(color, num_repeats=num_colors_to_generate)
247
+ # If a single color is specified, apply it to all domains
248
+ rgba = to_rgba(color, num_repeats=num_domains)
234
249
  return rgba
235
250
 
236
- # Load colormap
251
+ # Load colormap and generate a large, maximally separated set of colors
237
252
  colormap = matplotlib.colormaps.get_cmap(cmap)
238
- # Step 1: Calculate centroids for each domain
239
- centroids = calculate_centroids(network, domain_id_to_node_ids_map)
240
- # Step 2: Calculate pairwise distances between centroids
241
- centroid_array = np.array(centroids)
242
- dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
243
- # Step 3: Assign distant colors to close centroids
244
- color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
245
- # Step 4: Randomly shift the entire color palette while maintaining relative distances
246
- global_shift = np.random.uniform(-0.1, 0.1) # Small global shift to change the overall palette
247
- color_positions = (color_positions + global_shift) % 1 # Wrap around to keep within [0, 1]
248
- # Step 5: Ensure that all positions remain between 0 and 1
249
- color_positions = np.clip(color_positions, 0, 1)
250
-
251
- # Step 6: Generate RGBA colors based on positions
252
- return [colormap(pos) for pos in color_positions]
253
-
254
-
255
- def _assign_distant_colors(dist_matrix, num_colors_to_generate):
256
- """Assign colors to centroids that are close in space, ensuring stark color differences.
253
+ color_positions = np.linspace(0, 1, num_domains, endpoint=False)
254
+ # Shuffle color positions to avoid spatial clustering of similar colors
255
+ np.random.shuffle(color_positions)
256
+ # Assign colors based on positions in the colormap
257
+ colors = [colormap(pos) for pos in color_positions]
257
258
 
258
- Args:
259
- dist_matrix (ndarray): Matrix of pairwise centroid distances.
260
- num_colors_to_generate (int): Number of colors to generate.
261
-
262
- Returns:
263
- np.array: Array of color positions in the range [0, 1].
264
- """
265
- color_positions = np.zeros(num_colors_to_generate)
266
- # Step 1: Sort indices by centroid proximity (based on sum of distances to others)
267
- proximity_order = sorted(
268
- range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
269
- )
270
- # Step 2: Assign colors starting with the most distant points in proximity order
271
- for i, idx in enumerate(proximity_order):
272
- color_positions[idx] = i / num_colors_to_generate
273
-
274
- # Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
275
- half_spectrum = int(num_colors_to_generate / 2)
276
- for i in range(half_spectrum):
277
- # Split the spectrum so that close centroids are assigned distant colors
278
- color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
279
- color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
280
-
281
- return color_positions
259
+ return colors
282
260
 
283
261
 
284
262
  def _blend_colors_perceptually(
285
- colors: Union[List, Tuple, np.ndarray], enrichments: List[float], gamma: float = 2.2
263
+ colors: Union[List, Tuple, np.ndarray], significances: List[float], gamma: float = 2.2
286
264
  ) -> Tuple[float, float, float, float]:
287
265
  """Blends a list of RGBA colors using gamma correction for perceptually uniform color mixing.
288
266
 
289
267
  Args:
290
268
  colors (List, Tuple, np.ndarray): List of RGBA colors. Can be a list, tuple, or NumPy array of RGBA values.
291
- enrichments (List[float]): Corresponding list of enrichment values.
269
+ significances (List[float]): Corresponding list of significance values.
292
270
  gamma (float, optional): Gamma correction factor, default is 2.2 (typical for perceptual blending).
293
271
 
294
272
  Returns:
295
273
  Tuple[float, float, float, float]: The blended RGBA color.
296
274
  """
297
- # Normalize enrichments so they sum up to 1 (proportions)
298
- total_enrichment = sum(enrichments)
299
- proportions = [enrichment / total_enrichment for enrichment in enrichments]
275
+ # Normalize significances so they sum up to 1 (proportions)
276
+ total_significance = sum(significances)
277
+ proportions = [significance / total_significance for significance in significances]
300
278
  # Convert colors to gamma-corrected space (apply gamma correction to RGB channels)
301
279
  gamma_corrected_colors = [[channel**gamma for channel in color[:3]] for color in colors]
302
280
  # Blend the colors in gamma-corrected space
@@ -310,17 +288,17 @@ def _blend_colors_perceptually(
310
288
 
311
289
  def _transform_colors(
312
290
  colors: np.ndarray,
313
- enrichment_sums: np.ndarray,
291
+ significance_sums: np.ndarray,
314
292
  min_scale: float = 0.8,
315
293
  max_scale: float = 1.0,
316
294
  scale_factor: float = 1.0,
317
295
  ) -> np.ndarray:
318
- """Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
296
+ """Transform colors using power scaling to emphasize high significance sums more. Black colors are replaced with
319
297
  very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
320
298
 
321
299
  Args:
322
300
  colors (np.ndarray): An array of RGBA colors.
323
- enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
301
+ significance_sums (np.ndarray): An array of significance sums corresponding to the colors.
324
302
  min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
325
303
  max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
326
304
  scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
@@ -333,19 +311,28 @@ def _transform_colors(
333
311
  if min_scale == max_scale:
334
312
  min_scale = max_scale - 10e-6 # Avoid division by zero
335
313
 
314
+ # Replace invalid values in colors early
315
+ colors = np.nan_to_num(colors, nan=0.0) # Replace NaN with black
336
316
  # Replace black colors (#000000) with very dark grey (#1A1A1A)
337
317
  black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
338
318
  dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
339
- # Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
340
319
  is_black = np.all(colors[:, :3] == black_color, axis=1)
341
320
  colors[is_black, :3] = dark_grey
342
321
 
343
- # Normalize the enrichment sums to the range [0, 1]
344
- normalized_sums = enrichment_sums / np.max(enrichment_sums)
345
- # Apply power scaling to dim lower values and emphasize higher values
322
+ # Handle invalid or zero significance sums
323
+ max_significance = np.max(significance_sums)
324
+ if max_significance == 0:
325
+ max_significance = 1 # Avoid division by zero
326
+ normalized_sums = significance_sums / max_significance
327
+ # Replace NaN values in normalized sums
328
+ normalized_sums = np.nan_to_num(normalized_sums, nan=0.0)
329
+
330
+ # Apply power scaling to emphasize higher significance values
346
331
  scaled_sums = normalized_sums**scale_factor
347
332
  # Linearly scale the normalized sums to the range [min_scale, max_scale]
348
333
  scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
334
+ # Replace NaN or invalid scaled sums
335
+ scaled_sums = np.nan_to_num(scaled_sums, nan=min_scale)
349
336
  # Adjust RGB values based on scaled sums
350
337
  for i in range(3): # Only adjust RGB values
351
338
  colors[:, i] = scaled_sums * colors[:, i]
@@ -354,15 +341,15 @@ def _transform_colors(
354
341
 
355
342
 
356
343
  def to_rgba(
357
- color: Union[str, List, Tuple, np.ndarray],
344
+ color: Union[str, List, Tuple, np.ndarray, None],
358
345
  alpha: Union[float, None] = None,
359
346
  num_repeats: Union[int, None] = None,
360
347
  ) -> np.ndarray:
361
348
  """Convert color(s) to RGBA format, applying alpha and repeating as needed.
362
349
 
363
350
  Args:
364
- color (str, List, Tuple, np.ndarray): The color(s) to convert. Can be a string (e.g., 'red'), a list or tuple of RGB/RGBA values,
365
- or an `np.ndarray` of colors.
351
+ color (str, List, Tuple, np.ndarray, None): The color(s) to convert. Can be a string (e.g., 'red'), a list or tuple of RGB/RGBA values,
352
+ or an `np.ndarray` of colors. If None, the function will return an array of white (RGBA) colors.
366
353
  alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values found
367
354
  in color.
368
355
  num_repeats (int, None, optional): If provided, the color(s) will be repeated this many times. Defaults to None.
@@ -387,8 +374,13 @@ def to_rgba(
387
374
 
388
375
  if alpha is not None: # Override alpha if provided
389
376
  rgba[3] = alpha
377
+
390
378
  return rgba
391
379
 
380
+ # Default to white if no color is provided
381
+ if color is None:
382
+ color = "white"
383
+
392
384
  # If color is a 2D array of RGBA values, convert it to a list of lists
393
385
  if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
394
386
  color = [list(c) for c in color]
@@ -404,10 +396,11 @@ def to_rgba(
404
396
  return np.tile(
405
397
  rgba_color, (num_repeats, 1)
406
398
  ) # Repeat the color if num_repeats is provided
407
- return np.array([rgba_color]) # Return a single color wrapped in a numpy array
399
+
400
+ return rgba_color
408
401
 
409
402
  # Handle a list/array of colors
410
- elif isinstance(color, (list, tuple, np.ndarray)):
403
+ if isinstance(color, (list, tuple, np.ndarray)):
411
404
  rgba_colors = np.array(
412
405
  [convert_to_rgba(c) for c in color]
413
406
  ) # Convert each color in the list to RGBA
@@ -420,5 +413,4 @@ def to_rgba(
420
413
 
421
414
  return rgba_colors
422
415
 
423
- else:
424
- raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
416
+ raise ValueError("Color must be a valid string, RGB/RGBA, or array of RGB/RGBA colors.")
@@ -3,8 +3,9 @@ risk/network/plot/utils/layout
3
3
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
- from typing import Tuple
6
+ from typing import Any, Dict, List, Tuple
7
7
 
8
+ import networkx as nx
8
9
  import numpy as np
9
10
 
10
11
 
@@ -48,7 +49,7 @@ def refine_center_iteratively(
48
49
  tuple: Refined center and the final radius.
49
50
  """
50
51
  # Initial center and radius based on the bounding box
51
- center, radius = calculate_bounding_box(node_coordinates, radius_margin)
52
+ center, _ = calculate_bounding_box(node_coordinates, radius_margin)
52
53
  for _ in range(max_iterations):
53
54
  # Shift the coordinates based on the current center
54
55
  shifted_coordinates = node_coordinates - center
@@ -68,18 +69,20 @@ def refine_center_iteratively(
68
69
  return center, new_radius
69
70
 
70
71
 
71
- def calculate_centroids(network, domain_id_to_node_ids_map):
72
+ def calculate_centroids(
73
+ network: nx.Graph, domain_id_to_node_ids_map: Dict[int, Any]
74
+ ) -> List[Tuple[float, float]]:
72
75
  """Calculate the centroid for each domain based on node x and y coordinates in the network.
73
76
 
74
77
  Args:
75
- network (NetworkX graph): The graph representing the network.
78
+ network (nx.Graph): The graph representing the network.
76
79
  domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
77
80
 
78
81
  Returns:
79
82
  List[Tuple[float, float]]: List of centroids (x, y) for each domain.
80
83
  """
81
84
  centroids = []
82
- for domain_id, node_ids in domain_id_to_node_ids_map.items():
85
+ for _, node_ids in domain_id_to_node_ids_map.items():
83
86
  # Extract x and y coordinates from the network nodes
84
87
  node_positions = np.array(
85
88
  [[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]