risk-network 0.0.7b11__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,424 @@
1
+ """
2
+ risk/network/plot/utils/color
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict, List, Tuple, Union
7
+
8
+ import matplotlib
9
+ import matplotlib.colors as mcolors
10
+ import numpy as np
11
+
12
+ from risk.network.graph import NetworkGraph
13
+ from risk.network.plot.utils.layout import calculate_centroids
14
+
15
+
16
def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    blend_colors: bool = False,
    blend_gamma: float = 2.2,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain, e.g. for annotating domains.

    Node colors are first computed with `get_domain_colors`. For every domain in
    `graph.domain_id_to_node_ids_map` that has more than one node, the member node color
    with the largest R+G+B sum is chosen; every other domain gets opaque white.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap used to generate domain colors. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A single color or an array of colors
            to use instead of the colormap. If None, the colormap is used. Defaults to None.
        blend_colors (bool, optional): Whether to blend colors for nodes with multiple domains. Defaults to False.
        blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
        scale_factor (float, optional): Exponent controlling contrast of the enrichment-based scaling.
            Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        np.ndarray: One RGBA row per domain, in the iteration order of `graph.domain_id_to_node_ids_map`.
    """
    # Per-node colors; rows are looked up by node ID below
    node_colors = get_domain_colors(
        graph=graph,
        cmap=cmap,
        color=color,
        blend_colors=blend_colors,
        blend_gamma=blend_gamma,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )
    # Domains without at least two nodes fall back to opaque white (RGBA)
    fallback_rgba = (1.0, 1.0, 1.0, 1.0)
    representative_colors = []
    for member_ids in graph.domain_id_to_node_ids_map.values():
        if len(member_ids) <= 1:
            representative_colors.append(np.array(fallback_rgba))
            continue

        # Brightness is judged by the plain sum of the RGB channels (alpha is ignored)
        member_colors = np.array([node_colors[node_id] for node_id in member_ids])
        rgb_totals = member_colors[:, :3].sum(axis=1)
        representative_colors.append(member_colors[np.argmax(rgb_totals)])

    return np.array(representative_colors)
72
+
73
+
74
def get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    blend_colors: bool = False,
    blend_gamma: float = 2.2,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Build the final RGBA colors from domain colors and enrichment sums.

    The work is done in three stages: a color is chosen for every domain, those colors are
    combined into per-node composite colors, and the result is intensity-scaled using
    `graph.node_enrichment_sums`.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Name of the colormap used for domain colors. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use
            for all domains. If None, the colormap is used. Defaults to None.
        blend_colors (bool, optional): Whether to blend colors for nodes with multiple domains. Defaults to False.
        blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
        min_scale (float, optional): Intensity scale applied to the lowest enrichment sums. Defaults to 0.8.
        max_scale (float, optional): Intensity scale applied to the highest enrichment sums. Defaults to 1.0.
        scale_factor (float, optional): Exponent applied to the normalized enrichment sums; higher values
            dim low scores more. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors after composition and intensity scaling.
    """
    # Stage 1: one color per domain ID
    colors_by_domain = _get_domain_colors(
        graph=graph, cmap=cmap, color=color, random_seed=random_seed
    )
    # Stage 2: combine domain colors into one color per node
    composite = _get_composite_node_colors(
        graph=graph,
        domain_colors=colors_by_domain,
        blend_colors=blend_colors,
        blend_gamma=blend_gamma,
    )
    # Stage 3: scale RGB intensity by enrichment
    return _transform_colors(
        composite,
        graph.node_enrichment_sums,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
    )
120
+
121
+
122
def _get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> Dict[int, Any]:
    """Map every domain ID of the graph to an RGBA color.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use
            for the domains. If None, the colormap is used. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        Dict[int, Any]: Domain IDs (keys of `graph.domain_id_to_node_ids_map`) mapped to RGBA colors.
    """
    domain_map = graph.domain_id_to_node_ids_map
    # Colors come back in the iteration order of the domain map
    palette = _get_colors(
        graph.network,
        domain_map,
        cmap=cmap,
        color=color,
        random_seed=random_seed,
    )
    # Pair each domain ID with the color at the same position
    return {domain_id: rgba for domain_id, rgba in zip(domain_map.keys(), palette)}
149
+
150
+
151
+ def _get_composite_node_colors(
152
+ graph, domain_colors: np.ndarray, blend_colors: bool = False, blend_gamma: float = 2.2
153
+ ) -> np.ndarray:
154
+ """Generate composite colors for nodes based on domain colors and enrichment values, with optional color blending.
155
+
156
+ Args:
157
+ graph (NetworkGraph): The network data and attributes to be visualized.
158
+ domain_colors (np.ndarray): Array or list of RGBA colors corresponding to each domain.
159
+ blend_colors (bool): Whether to blend colors for nodes with multiple domains. Defaults to False.
160
+ blend_gamma (float, optional): Gamma correction factor to be used for perceptual color blending.
161
+ This parameter is only relevant if blend_colors is True. Defaults to 2.2.
162
+
163
+ Returns:
164
+ np.ndarray: Array of composite colors for each node.
165
+ """
166
+ # Determine the number of nodes
167
+ num_nodes = len(graph.node_coordinates)
168
+ # Initialize composite colors array with shape (number of nodes, 4) for RGBA
169
+ composite_colors = np.zeros((num_nodes, 4))
170
+
171
+ # If blending is not required, directly assign domain colors to nodes
172
+ if not blend_colors:
173
+ for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
174
+ color = domain_colors[domain_id]
175
+ for node in nodes:
176
+ composite_colors[node] = color
177
+
178
+ # If blending is required
179
+ else:
180
+ for node, node_info in graph.node_id_to_domain_ids_and_enrichments_map.items():
181
+ domains = node_info["domains"] # List of domain IDs
182
+ enrichments = node_info["enrichments"] # List of enrichment values
183
+ # Filter domains and enrichments to keep only those with corresponding colors in domain_colors
184
+ filtered_domains_enrichments = [
185
+ (domain_id, enrichment)
186
+ for domain_id, enrichment in zip(domains, enrichments)
187
+ if domain_id in domain_colors
188
+ ]
189
+ # If no valid domains exist, skip this node
190
+ if not filtered_domains_enrichments:
191
+ continue
192
+
193
+ # Unpack filtered domains and enrichments
194
+ filtered_domains, filtered_enrichments = zip(*filtered_domains_enrichments)
195
+ # Get the colors corresponding to the valid filtered domains
196
+ colors = [domain_colors[domain_id] for domain_id in filtered_domains]
197
+ # Blend the colors using the given gamma (default is 2.2 if None)
198
+ gamma = blend_gamma if blend_gamma is not None else 2.2
199
+ composite_color = _blend_colors_perceptually(colors, filtered_enrichments, gamma)
200
+ # Assign the composite color to the node
201
+ composite_colors[node] = composite_color
202
+
203
+ return composite_colors
204
+
205
+
206
def _get_colors(
    network,
    domain_id_to_node_ids_map,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> Union[np.ndarray, List[Tuple]]:
    """Generate one RGBA color per domain, either from a fixed color or from a colormap.

    When a color is given it is converted with `to_rgba` and repeated for every domain. Otherwise
    colormap positions are derived from the pairwise distances between domain centroids and the
    whole palette is shifted by a small random offset.

    Args:
        network (NetworkX graph): The graph representing the network.
        domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use
            for the domains. None or an empty value selects the colormap. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        np.ndarray or List[Tuple]: RGBA colors, one per domain. An array when `color` is given,
            otherwise a list of colormap outputs.
    """
    # NOTE: this reseeds NumPy's global random state, as the original implementation did
    np.random.seed(random_seed)
    num_colors_to_generate = len(domain_id_to_node_ids_map)

    # `if color:` is ambiguous (raises ValueError) for multi-element arrays, so test arrays by size.
    # Empty strings/sequences/arrays still mean "no color given" and fall through to the colormap.
    if isinstance(color, np.ndarray):
        color_given = color.size > 0
    else:
        color_given = bool(color)

    if color_given:
        return to_rgba(color, num_repeats=num_colors_to_generate)

    # Load colormap
    colormap = matplotlib.colormaps.get_cmap(cmap)
    # Step 1: Calculate centroids for each domain
    centroid_array = np.array(calculate_centroids(network, domain_id_to_node_ids_map))
    # Step 2: Pairwise Euclidean distances between centroids
    dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
    # Step 3: Turn distances into positions along the colormap
    color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
    # Step 4: Shift the whole palette by one random offset; relative spacing is preserved
    global_shift = np.random.uniform(-0.1, 0.1)
    # The modulo already keeps every position within [0, 1], so no further clipping is needed
    color_positions = (color_positions + global_shift) % 1

    # Step 5: Generate RGBA colors based on positions
    return [colormap(position) for position in color_positions]
253
+
254
+
255
+ def _assign_distant_colors(dist_matrix, num_colors_to_generate):
256
+ """Assign colors to centroids that are close in space, ensuring stark color differences.
257
+
258
+ Args:
259
+ dist_matrix (ndarray): Matrix of pairwise centroid distances.
260
+ num_colors_to_generate (int): Number of colors to generate.
261
+
262
+ Returns:
263
+ np.array: Array of color positions in the range [0, 1].
264
+ """
265
+ color_positions = np.zeros(num_colors_to_generate)
266
+ # Step 1: Sort indices by centroid proximity (based on sum of distances to others)
267
+ proximity_order = sorted(
268
+ range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
269
+ )
270
+ # Step 2: Assign colors starting with the most distant points in proximity order
271
+ for i, idx in enumerate(proximity_order):
272
+ color_positions[idx] = i / num_colors_to_generate
273
+
274
+ # Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
275
+ half_spectrum = int(num_colors_to_generate / 2)
276
+ for i in range(half_spectrum):
277
+ # Split the spectrum so that close centroids are assigned distant colors
278
+ color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
279
+ color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
280
+
281
+ return color_positions
282
+
283
+
284
+ def _blend_colors_perceptually(
285
+ colors: Union[List, Tuple, np.ndarray], enrichments: List[float], gamma: float = 2.2
286
+ ) -> Tuple[float, float, float, float]:
287
+ """Blends a list of RGBA colors using gamma correction for perceptually uniform color mixing.
288
+
289
+ Args:
290
+ colors (List, Tuple, np.ndarray): List of RGBA colors. Can be a list, tuple, or NumPy array of RGBA values.
291
+ enrichments (List[float]): Corresponding list of enrichment values.
292
+ gamma (float, optional): Gamma correction factor, default is 2.2 (typical for perceptual blending).
293
+
294
+ Returns:
295
+ Tuple[float, float, float, float]: The blended RGBA color.
296
+ """
297
+ # Normalize enrichments so they sum up to 1 (proportions)
298
+ total_enrichment = sum(enrichments)
299
+ proportions = [enrichment / total_enrichment for enrichment in enrichments]
300
+ # Convert colors to gamma-corrected space (apply gamma correction to RGB channels)
301
+ gamma_corrected_colors = [[channel**gamma for channel in color[:3]] for color in colors]
302
+ # Blend the colors in gamma-corrected space
303
+ blended_color = np.dot(proportions, gamma_corrected_colors)
304
+ # Convert back from gamma-corrected space to linear space (by applying inverse gamma correction)
305
+ blended_color = [channel ** (1 / gamma) for channel in blended_color]
306
+ # Average the alpha channel separately (no gamma correction on alpha)
307
+ alpha = np.dot(proportions, [color[3] for color in colors])
308
+ return tuple(blended_color + [alpha])
309
+
310
+
311
+ def _transform_colors(
312
+ colors: np.ndarray,
313
+ enrichment_sums: np.ndarray,
314
+ min_scale: float = 0.8,
315
+ max_scale: float = 1.0,
316
+ scale_factor: float = 1.0,
317
+ ) -> np.ndarray:
318
+ """Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
319
+ very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
320
+
321
+ Args:
322
+ colors (np.ndarray): An array of RGBA colors.
323
+ enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
324
+ min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
325
+ max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
326
+ scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
327
+ values more. Defaults to 1.0.
328
+
329
+ Returns:
330
+ np.ndarray: The transformed array of RGBA colors with adjusted intensities.
331
+ """
332
+ # Ensure that min_scale is less than max_scale
333
+ if min_scale == max_scale:
334
+ min_scale = max_scale - 10e-6 # Avoid division by zero
335
+
336
+ # Replace black colors (#000000) with very dark grey (#1A1A1A)
337
+ black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
338
+ dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
339
+ # Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
340
+ is_black = np.all(colors[:, :3] == black_color, axis=1)
341
+ colors[is_black, :3] = dark_grey
342
+
343
+ # Normalize the enrichment sums to the range [0, 1]
344
+ normalized_sums = enrichment_sums / np.max(enrichment_sums)
345
+ # Apply power scaling to dim lower values and emphasize higher values
346
+ scaled_sums = normalized_sums**scale_factor
347
+ # Linearly scale the normalized sums to the range [min_scale, max_scale]
348
+ scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
349
+ # Adjust RGB values based on scaled sums
350
+ for i in range(3): # Only adjust RGB values
351
+ colors[:, i] = scaled_sums * colors[:, i]
352
+
353
+ return colors
354
+
355
+
356
def to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Normalize one color or a collection of colors into an array of RGBA rows.

    Args:
        color (str, List, Tuple, np.ndarray): A color name or hex string, a flat RGB/RGBA sequence,
            or a list/tuple/array whose items are themselves colors.
        alpha (float, None, optional): If given, replaces the alpha channel of every converted color.
        num_repeats (int, None, optional): If truthy, the number of rows to return. A single color is
            tiled; a collection is cycled through (and cut short if it is longer). Defaults to None.

    Returns:
        np.ndarray: Two-dimensional array with one RGBA row per output color.

    Raises:
        ValueError: If a color is neither a string nor a 3- or 4-element sequence, or if `color`
            is not a string, list, tuple, or array at all.
    """
    sequence_types = (list, tuple, np.ndarray)

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert one color (name, hex string, or RGB/RGBA sequence) to an RGBA array."""
        is_rgb_sequence = isinstance(c, sequence_types) and len(c) in [3, 4]
        if not (isinstance(c, str) or is_rgb_sequence):
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )

        # mcolors.to_rgba supplies alpha = 1.0 when the input carries no alpha
        rgba = np.array(mcolors.to_rgba(c))
        if alpha is not None:  # An explicit alpha always wins
            rgba[3] = alpha
        return rgba

    # A 2D array of RGBA rows is unpacked into a list of per-color lists
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(row) for row in color]

    is_sequence = isinstance(color, sequence_types)
    # "Flat" means none of the items is itself a color (string or nested sequence)
    is_flat_sequence = is_sequence and not any(
        isinstance(item, (str,) + sequence_types) for item in color
    )

    # Single color: a string or a flat RGB/RGBA sequence
    if isinstance(color, str) or is_flat_sequence:
        single_rgba = convert_to_rgba(color)
        if not num_repeats:
            return np.array([single_rgba])  # One row, still two-dimensional
        return np.tile(single_rgba, (num_repeats, 1))

    if not is_sequence:
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    # Collection of colors: convert each item, then cycle through them if repeats were requested
    rgba_colors = np.array([convert_to_rgba(item) for item in color])
    if not num_repeats:
        return rgba_colors

    cycled_rows = [rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)]
    return np.array(cycled_rows)
@@ -0,0 +1,91 @@
1
+ """
2
+ risk/network/plot/utils/layout
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import numpy as np
9
+
10
+
11
def calculate_bounding_box(
    node_coordinates: np.ndarray, radius_margin: float = 1.05
) -> Tuple[np.ndarray, float]:
    """Compute the center and padded half-extent of the axis-aligned box around the nodes.

    Args:
        node_coordinates (np.ndarray): Array of node coordinates (x, y), one row per node.
        radius_margin (float, optional): Factor the radius is multiplied by. Defaults to 1.05.

    Returns:
        tuple: The box center as a 2-element array, and the radius, i.e. half of the larger
            box side multiplied by the radius margin.
    """
    # Column-wise extremes give the lower-left and upper-right corners
    lower_corner = np.min(node_coordinates, axis=0)
    upper_corner = np.max(node_coordinates, axis=0)
    (x_min, y_min), (x_max, y_max) = lower_corner, upper_corner
    # The center is the midpoint of the two corners
    center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
    # The radius follows the longer side of the box
    width, height = x_max - x_min, y_max - y_min
    radius = max(width, height) / 2 * radius_margin
    return center, radius
31
+
32
+
33
def refine_center_iteratively(
    node_coordinates: np.ndarray,
    radius_margin: float = 1.05,
    max_iterations: int = 10,
    tolerance: float = 1e-2,
) -> Tuple[np.ndarray, float]:
    """Move the bounding-box center toward the mean node position and recompute the radius.

    Args:
        node_coordinates (np.ndarray): Array of node coordinates (x, y).
        radius_margin (float, optional): Factor the final radius is multiplied by. Defaults to 1.05.
        max_iterations (int, optional): Maximum number of refinement iterations. Defaults to 10.
        tolerance (float, optional): Refinement stops once the norm of the mean offset drops below
            this value. Defaults to 1e-2.

    Returns:
        tuple: Refined center, and the largest node distance from it multiplied by the radius margin.
    """
    # Start from the bounding-box center; the bounding-box radius itself is not used
    center, _unused_radius = calculate_bounding_box(node_coordinates, radius_margin)
    for _ in range(max_iterations):
        # Mean offset of the nodes relative to the current center
        mean_offset = np.mean(node_coordinates - center, axis=0)
        if np.linalg.norm(mean_offset) < tolerance:
            break

        # Adding the mean offset moves the center toward the mean node position
        center += mean_offset

    # The radius must reach the node farthest from the refined center
    distances = np.linalg.norm(node_coordinates - center, axis=1)
    return center, np.max(distances) * radius_margin
69
+
70
+
71
def calculate_centroids(network, domain_id_to_node_ids_map):
    """Compute the mean (x, y) position of the nodes in each domain.

    Args:
        network (NetworkX graph): The graph representing the network; each node's attributes
            are read through `network.nodes[node_id]["x"]` and `["y"]`.
        domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.

    Returns:
        List[Tuple[float, float]]: One centroid (x, y) per domain, in the mapping's iteration order.
    """
    node_attributes = network.nodes
    centroids = []
    # Only the node lists matter here; the domain IDs themselves are not needed
    for member_ids in domain_id_to_node_ids_map.values():
        positions = np.array(
            [[node_attributes[node_id]["x"], node_attributes[node_id]["y"]] for node_id in member_ids]
        )
        # Column-wise mean yields the centroid's x and y
        centroids.append(tuple(np.mean(positions, axis=0)))

    return centroids