risk-network 0.0.7b11__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/network/plot.py DELETED
@@ -1,1343 +0,0 @@
1
- """
2
- risk/network/plot
3
- ~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from typing import Any, Dict, List, Tuple, Union
7
-
8
- import matplotlib.colors as mcolors
9
- import matplotlib.pyplot as plt
10
- import networkx as nx
11
- import numpy as np
12
- import pandas as pd
13
- from scipy.ndimage import label
14
- from scipy.stats import gaussian_kde
15
-
16
- from risk.log import params
17
- from risk.network.graph import NetworkGraph
18
-
19
-
20
- class NetworkPlotter:
21
- """A class for visualizing network graphs with customizable options.
22
-
23
- The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
24
- flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
25
- perimeter, and adjusting background colors.
26
- """
27
-
28
def __init__(
    self,
    graph: NetworkGraph,
    figsize: Tuple = (10, 10),
    background_color: Union[str, List, Tuple, np.ndarray] = "white",
) -> None:
    """Bind the plotter to a NetworkGraph and create the axis it draws on.

    Args:
        graph (NetworkGraph): Network data and attributes to visualize.
        figsize (tuple, optional): Figure size in inches as (width, height). Defaults to (10, 10).
        background_color (str, list, tuple, np.ndarray, optional): Background color of the figure.
            Defaults to "white".
    """
    self.graph = graph
    # The figure/axis is built once here; every plot_* method draws onto self.ax.
    plot_axis = self._initialize_plot(graph, figsize, background_color)
    self.ax = plot_axis
44
-
45
def _initialize_plot(
    self,
    graph: NetworkGraph,
    figsize: Tuple,
    background_color: Union[str, List, Tuple, np.ndarray],
) -> plt.Axes:
    """Create the figure and axis, framed around the network's bounding circle.

    Args:
        graph (NetworkGraph): Network whose node coordinates determine the axis limits.
        figsize (tuple): Figure size in inches as (width, height).
        background_color (str, list, tuple, or np.ndarray): Figure background color, converted to
            RGBA with alpha 1.0.

    Returns:
        plt.Axes: The axis, with equal aspect, inverted y-axis, and hidden spines/ticks/background.
    """
    # Center and radius of the bounding box around all node coordinates
    center, radius = _calculate_bounding_box(graph.node_coordinates)
    # Fixed padding (in data units) added on every side of the bounding box
    margin = 0.3

    fig, ax = plt.subplots(figsize=figsize)
    fig.tight_layout()
    ax.set_xlim([center[0] - radius - margin, center[0] + radius + margin])
    ax.set_ylim([center[1] - radius - margin, center[1] + radius + margin])
    ax.set_aspect("equal")

    fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
    # Flip y so the layout reads like image coordinates (origin at the top)
    ax.invert_yaxis()

    # Strip all axis decoration so only the network is visible
    for axis_spine in ax.spines.values():
        axis_spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.patch.set_visible(False)

    return ax
88
-
89
def plot_circle_perimeter(
    self,
    scale: float = 1.0,
    linestyle: str = "dashed",
    linewidth: float = 1.5,
    color: Union[str, List, Tuple, np.ndarray] = "black",
    outline_alpha: float = 1.0,
    fill_alpha: float = 0.0,
) -> None:
    """Plot a circle around the network graph to represent the network perimeter.

    The circle is centered on the bounding box of the node coordinates and its radius is the
    bounding radius multiplied by ``scale``.

    Args:
        scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
        linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
        linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
        color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
        outline_alpha (float, optional): Transparency level of the circle outline. Defaults to 1.0.
        fill_alpha (float, optional): Transparency level of the circle fill. The circle is only
            filled when this is greater than 0. Defaults to 0.0.
    """
    # Log the circle perimeter plotting parameters
    params.log_plotter(
        perimeter_type="circle",
        perimeter_scale=scale,
        perimeter_linestyle=linestyle,
        perimeter_linewidth=linewidth,
        perimeter_color=(
            "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
        ),  # np.ndarray usually indicates custom colors
        perimeter_outline_alpha=outline_alpha,
        perimeter_fill_alpha=fill_alpha,
    )

    # Convert the user's color separately for outline and fill. Deriving the fill from the
    # already-converted outline RGBA would make fill_alpha depend on whether _to_rgba replaces
    # an existing alpha channel.
    outline_rgba = _to_rgba(color, outline_alpha)
    # Center and radius of the bounding box around the network
    center, radius = _calculate_bounding_box(self.graph.node_coordinates)
    scaled_radius = radius * scale

    circle = plt.Circle(
        center,
        scaled_radius,
        linestyle=linestyle,
        linewidth=linewidth,
        color=outline_rgba,
        fill=fill_alpha > 0,  # Fill the circle only if fill_alpha is greater than 0
    )
    if fill_alpha > 0:
        circle.set_facecolor(_to_rgba(color, fill_alpha))

    self.ax.add_artist(circle)
144
-
145
def plot_contour_perimeter(
    self,
    scale: float = 1.0,
    levels: int = 3,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, List, Tuple, np.ndarray] = "black",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    outline_alpha: float = 1.0,
    fill_alpha: float = 0.0,
) -> None:
    """Plot a KDE-based contour around the network graph to represent the network perimeter.

    Args:
        scale (float, optional): Scaling factor applied to the node coordinates before the KDE. Defaults to 1.0.
        levels (int, optional): Number of contour levels. Defaults to 3.
        bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
        grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
        color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
        linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
        linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
        outline_alpha (float, optional): Transparency level of the contour outline. Defaults to 1.0.
        fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.0 (no fill).
    """
    # Log the contour perimeter plotting parameters
    params.log_plotter(
        perimeter_type="contour",
        perimeter_scale=scale,
        perimeter_levels=levels,
        perimeter_bandwidth=bandwidth,
        perimeter_grid_size=grid_size,
        perimeter_linestyle=linestyle,
        perimeter_linewidth=linewidth,
        perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
        perimeter_outline_alpha=outline_alpha,
        perimeter_fill_alpha=fill_alpha,
    )

    # Convert color to RGBA using the outline alpha for the perimeter lines
    color = _to_rgba(color, outline_alpha)
    node_coordinates = self.graph.node_coordinates
    # NOTE(review): this scales about the origin, so scale != 1.0 also moves the contour unless
    # the layout is centred on (0, 0) -- confirm against the coordinate convention.
    scaled_coordinates = node_coordinates * scale
    # Previously fill_alpha was passed as the *line* alpha and fill_alpha was never forwarded,
    # so the default outline (fill_alpha=0.0) was invisible while a 0.2 fill was drawn.
    self._draw_kde_contour(
        ax=self.ax,
        pos=scaled_coordinates,
        nodes=list(range(len(node_coordinates))),  # All nodes are included
        levels=levels,
        bandwidth=bandwidth,
        grid_size=grid_size,
        color=color,
        linestyle=linestyle,
        linewidth=linewidth,
        alpha=outline_alpha,
        fill_alpha=fill_alpha,
    )
204
-
205
def plot_network(
    self,
    node_size: Union[int, np.ndarray] = 50,
    node_shape: str = "o",
    node_edgewidth: float = 1.0,
    edge_width: float = 1.0,
    node_color: Union[str, List, Tuple, np.ndarray] = "white",
    node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
    edge_color: Union[str, List, Tuple, np.ndarray] = "black",
    node_alpha: float = 1.0,
    edge_alpha: float = 1.0,
) -> None:
    """Draw every node and edge of the network on the plotter's axis.

    Args:
        node_size (int or np.ndarray, optional): One size for all nodes, or an array of per-node sizes. Defaults to 50.
        node_shape (str, optional): Marker shape of the nodes. Defaults to "o".
        node_edgewidth (float, optional): Line width of the node outlines. Defaults to 1.0.
        edge_width (float, optional): Line width of the edges. Defaults to 1.0.
        node_color (str, list, tuple, or np.ndarray, optional): One color for all nodes, or per-node colors. Defaults to "white".
        node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node outlines, converted with alpha 1.0. Defaults to "black".
        edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
        node_alpha (float, optional): Node transparency. Defaults to 1.0. Alphas carried by annotated node_color values override this.
        edge_alpha (float, optional): Edge transparency. Defaults to 1.0.
    """
    # Arrays are logged as "custom" instead of their (potentially large) contents
    logged_size = "custom" if isinstance(node_size, np.ndarray) else node_size
    logged_color = "custom" if isinstance(node_color, np.ndarray) else node_color
    params.log_plotter(
        network_node_size=logged_size,
        network_node_shape=node_shape,
        network_node_edgewidth=node_edgewidth,
        network_edge_width=edge_width,
        network_node_color=logged_color,
        network_node_edgecolor=node_edgecolor,
        network_edge_color=edge_color,
        network_node_alpha=node_alpha,
        network_edge_alpha=edge_alpha,
    )

    network = self.graph.network
    n_nodes = len(network.nodes)
    # Colors from get_annotated_node_colors keep their own alpha values over node_alpha
    node_rgba = _to_rgba(node_color, node_alpha, num_repeats=n_nodes)
    outline_rgba = _to_rgba(node_edgecolor, 1.0, num_repeats=n_nodes)
    edge_rgba = _to_rgba(edge_color, edge_alpha, num_repeats=len(network.edges))

    positions = self.graph.node_coordinates
    nx.draw_networkx_nodes(
        network,
        pos=positions,
        node_size=node_size,
        node_shape=node_shape,
        node_color=node_rgba,
        edgecolors=outline_rgba,
        linewidths=node_edgewidth,
        ax=self.ax,
    )
    nx.draw_networkx_edges(
        network,
        pos=positions,
        width=edge_width,
        edge_color=edge_rgba,
        ax=self.ax,
    )
275
-
276
def plot_subnetwork(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    node_size: Union[int, np.ndarray] = 50,
    node_shape: str = "o",
    node_edgewidth: float = 1.0,
    edge_width: float = 1.0,
    node_color: Union[str, List, Tuple, np.ndarray] = "white",
    node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
    edge_color: Union[str, List, Tuple, np.ndarray] = "black",
    node_alpha: float = 1.0,
    edge_alpha: float = 1.0,
) -> None:
    """Plot a subnetwork of selected nodes with customizable node and edge attributes.

    Labels not present in the graph are ignored. When ``node_color`` is a list of per-node
    colors (aligned with ``nodes`` after flattening), the entries belonging to ignored labels
    are dropped so the remaining colors stay aligned with the drawn nodes.

    Args:
        nodes (list, tuple, or np.ndarray): Node labels to include in the subnetwork. Accepts one level of nesting.
        node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
        node_shape (str, optional): Shape of the nodes. Defaults to "o".
        node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
        edge_width (float, optional): Width of the edges. Defaults to 1.0.
        node_color (str, list, tuple, or np.ndarray, optional): A single color, or a list of per-node colors. Defaults to "white".
        node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
        edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
        node_alpha (float, optional): Transparency for the nodes. Defaults to 1.0.
        edge_alpha (float, optional): Transparency for the edges. Defaults to 1.0.

    Raises:
        ValueError: If no valid nodes are found in the network graph.
    """
    # Flatten one level of nesting; always work on a list (an ndarray has no .index and
    # would otherwise break the per-node color alignment below).
    if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
        nodes = [node for sublist in nodes for node in sublist]
    else:
        nodes = list(nodes)

    label_to_id = self.graph.node_label_to_node_id_map
    # Which requested labels exist in the graph; reused to align per-node colors
    is_present = [node in label_to_id for node in nodes]
    node_ids = [label_to_id[node] for node, present in zip(nodes, is_present) if present]
    if not node_ids:
        raise ValueError("No nodes found in the network graph.")

    # A list of per-node colors is filtered positionally to the kept nodes. A list that is itself
    # a single color (e.g. [1.0, 0.0, 0.0]) must not be split into per-node entries.
    if not isinstance(node_color, (str, tuple, np.ndarray)) and not mcolors.is_color_like(
        node_color
    ):
        node_color = [color for color, present in zip(node_color, is_present) if present]

    # Convert colors to RGBA using the _to_rgba helper function
    node_color = _to_rgba(node_color, node_alpha, num_repeats=len(node_ids))
    node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(node_ids))
    edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))

    # Coordinates of the kept nodes only
    node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}

    nx.draw_networkx_nodes(
        self.graph.network,
        pos=node_coordinates,
        nodelist=node_ids,
        node_size=node_size,
        node_shape=node_shape,
        node_color=node_color,
        edgecolors=node_edgecolor,
        linewidths=node_edgewidth,
        ax=self.ax,
    )
    # Only edges whose endpoints are both in the subnetwork are drawn
    subgraph = self.graph.network.subgraph(node_ids)
    nx.draw_networkx_edges(
        subgraph,
        pos=node_coordinates,
        width=edge_width,
        edge_color=edge_color,
        ax=self.ax,
    )
356
-
357
def plot_contours(
    self,
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, List, Tuple, np.ndarray] = "white",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    alpha: float = 1.0,
    fill_alpha: float = 0.2,
) -> None:
    """Draw one KDE contour per domain, outlining where each domain's nodes are dense.

    Domains with fewer than two nodes are skipped.

    Args:
        levels (int, optional): Number of contour levels. Defaults to 5.
        bandwidth (float, optional): Starting KDE bandwidth; larger values give smoother contours. Defaults to 0.8.
        grid_size (int, optional): KDE grid resolution; larger values give finer contours. Defaults to 250.
        color (str, list, tuple, or np.ndarray, optional): One color for all domains, or per-domain colors. Defaults to "white".
        linestyle (str, optional): Contour line style. Defaults to "solid".
        linewidth (float, optional): Contour line width. Defaults to 1.5.
        alpha (float, optional): Transparency of the contour lines. Defaults to 1.0.
        fill_alpha (float, optional): Transparency of the contour fill. Defaults to 0.2.
    """
    # Arrays are logged as "custom" instead of their contents
    logged_color = "custom" if isinstance(color, np.ndarray) else color
    params.log_plotter(
        contour_levels=levels,
        contour_bandwidth=bandwidth,
        contour_grid_size=grid_size,
        contour_color=logged_color,
        contour_alpha=alpha,
        contour_fill_alpha=fill_alpha,
    )

    domain_map = self.graph.domain_id_to_node_ids_map
    # One RGBA entry per domain, indexed in the map's iteration order
    domain_colors = _to_rgba(color, alpha, num_repeats=len(domain_map))
    positions = self.graph.node_coordinates
    for position, member_ids in enumerate(domain_map.values()):
        # A KDE needs at least two points
        if len(member_ids) <= 1:
            continue
        self._draw_kde_contour(
            ax=self.ax,
            pos=positions,
            nodes=member_ids,
            color=domain_colors[position],
            levels=levels,
            bandwidth=bandwidth,
            grid_size=grid_size,
            linestyle=linestyle,
            linewidth=linewidth,
            alpha=alpha,
            fill_alpha=fill_alpha,
        )
412
-
413
def plot_subcontour(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, List, Tuple, np.ndarray] = "white",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    alpha: float = 1.0,
    fill_alpha: float = 0.2,
) -> None:
    """Draw a KDE contour around one group of nodes, or around each group in a list of groups.

    Args:
        nodes (list, tuple, or np.ndarray): Node labels, or a sequence of node-label groups.
        levels (int, optional): Number of contour levels. Defaults to 5.
        bandwidth (float, optional): Starting KDE bandwidth; larger values give smoother contours. Defaults to 0.8.
        grid_size (int, optional): KDE grid resolution; larger values give finer contours. Defaults to 250.
        color (str, list, tuple, or np.ndarray, optional): Single contour color used for every group. Defaults to "white".
        linestyle (str, optional): Contour line style. Defaults to "solid".
        linewidth (float, optional): Contour line width. Defaults to 1.5.
        alpha (float, optional): Transparency of the contour lines. Defaults to 1.0.
        fill_alpha (float, optional): Transparency of the contour fill. Defaults to 0.2.

    Raises:
        ValueError: If any group has fewer than two labels present in the network graph.
    """
    # Nested input means several groups; a flat input is a single group
    is_nested = any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes)
    node_groups = nodes if is_nested else [nodes]

    contour_rgba = _to_rgba(color, alpha)
    label_to_id = self.graph.node_label_to_node_id_map

    for group in node_groups:
        group_ids = [label_to_id.get(label) for label in group if label in label_to_id]
        # A KDE contour needs at least two nodes
        if len(group_ids) < 2:
            raise ValueError(
                "No nodes found in the network graph or insufficient nodes to plot."
            )
        self._draw_kde_contour(
            self.ax,
            self.graph.node_coordinates,
            group_ids,
            color=contour_rgba,
            levels=levels,
            bandwidth=bandwidth,
            grid_size=grid_size,
            linestyle=linestyle,
            linewidth=linewidth,
            alpha=alpha,
            fill_alpha=fill_alpha,
        )
480
-
481
def _draw_kde_contour(
    self,
    ax: plt.Axes,
    pos: np.ndarray,
    nodes: List,
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, np.ndarray] = "white",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    alpha: float = 1.0,
    fill_alpha: float = 0.2,
) -> None:
    """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.

    The bandwidth starts at ``bandwidth`` and grows in steps of 0.05 until ``_is_connected``
    accepts the density grid or the next step would exceed 100.0. In the latter case the last
    grid computed is used. Only the lowest contour level is stroked; filled bands (when
    ``fill_alpha > 0``) cover the regions between levels.

    Args:
        ax (plt.Axes): The axis to draw the contour on.
        pos (np.ndarray): Node positions (x, y), indexable by the entries of ``nodes``.
        nodes (list): Node indices to include in the contour.
        levels (int, optional): Number of contour levels. Defaults to 5.
        bandwidth (float, optional): Starting bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
        grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
        color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
        linestyle (str, optional): Line style for the contour. Defaults to "solid".
        linewidth (float, optional): Line width for the contour. Defaults to 1.5.
        alpha (float, optional): Transparency level for the contour lines. Defaults to 1.0.
        fill_alpha (float, optional): Transparency level for the contour fill. Defaults to 0.2.
    """
    # Extract the positions of the specified nodes
    points = np.array([pos[n] for n in nodes])
    if len(points) <= 1:
        return  # Not enough points to form a contour

    max_bandwidth = 100.0
    bandwidth_step = 0.05
    # Always compute at least one grid, so x/y/z exist even if the starting bandwidth exceeds the cap
    while True:
        kde = gaussian_kde(points.T, bw_method=bandwidth)
        xmin, ymin = points.min(axis=0) - bandwidth
        xmax, ymax = points.max(axis=0) + bandwidth
        x, y = np.mgrid[
            xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
        ]
        z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
        # Stop once the density forms a single connected component, or when widening further
        # would exceed the cap
        if _is_connected(z) or bandwidth + bandwidth_step > max_bandwidth:
            break
        bandwidth += bandwidth_step

    # Contour thresholds evenly spaced above the minimum density
    min_density, max_density = z.min(), z.max()
    contour_levels = np.linspace(min_density, max_density, levels)[1:]
    if contour_levels.size == 0:
        return  # levels < 2 leaves no threshold above the minimum density
    contour_colors = [color for _ in range(len(contour_levels))]

    # contourf needs at least two level boundaries to define a band
    if fill_alpha > 0 and len(contour_levels) > 1:
        ax.contourf(
            x,
            y,
            z,
            levels=contour_levels,
            colors=contour_colors,
            antialiased=True,
            alpha=fill_alpha,
        )

    # Stroke only the base (lowest) level. This replaces zeroing the other levels' linewidths
    # through ContourSet.collections, which is deprecated in matplotlib 3.8 and removed in 3.10.
    ax.contour(
        x,
        y,
        z,
        levels=contour_levels[:1],
        colors=contour_colors[:1],
        linestyles=linestyle,
        linewidths=linewidth,
        alpha=alpha,
    )
560
-
561
def plot_labels(
    self,
    scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontsize: int = 10,
    fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
    fontalpha: float = 1.0,
    arrow_linewidth: float = 1,
    arrow_style: str = "->",
    arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
    arrow_alpha: float = 1.0,
    arrow_base_shrink: float = 0.0,
    arrow_tip_shrink: float = 0.0,
    max_labels: Union[int, None] = None,
    max_words: int = 10,
    min_words: int = 1,
    max_word_length: int = 20,
    min_word_length: int = 1,
    words_to_omit: Union[List, None] = None,
    overlay_ids: bool = False,
    ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
    ids_to_replace: Union[Dict, None] = None,
) -> None:
    """Annotate the network graph with labels for different domains, positioned around the network for clarity.

    Domains listed in ``ids_to_keep`` are labeled first; remaining slots (up to ``max_labels``)
    are filled from the other domains in centroid order. Domains whose filtered term list is
    shorter than ``min_words``, domains with no nodes, NaN domain IDs, and domains with neither
    a replacement label nor recorded terms are not labeled.

    Args:
        scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
        offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
        font (str, optional): Font name for the labels. Defaults to "Arial".
        fontsize (int, optional): Font size for the labels. Defaults to 10.
        fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
        fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
        arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
        arrow_style (str, optional): Style of the arrows pointing to centroids. Defaults to "->".
        arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
        arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
        arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
        arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
        max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
        max_words (int, optional): Maximum number of words in a label. Defaults to 10.
        min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
        max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
        min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
        words_to_omit (list, optional): List of words to omit from the labels (case-insensitive). Defaults to None.
        overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
        ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
            you can set `overlay_ids=True`. Defaults to None.
        ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be space-separated words.
            If provided, the custom labels will replace the default domain terms. To discover domain IDs, you can set `overlay_ids=True`.
            Defaults to None.

    Raises:
        ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
    """
    # Log the plotting parameters
    params.log_plotter(
        label_perimeter_scale=scale,
        label_offset=offset,
        label_font=font,
        label_fontsize=fontsize,
        label_fontcolor=(
            "custom" if isinstance(fontcolor, np.ndarray) else fontcolor
        ),  # np.ndarray usually indicates custom colors
        label_fontalpha=fontalpha,
        label_arrow_linewidth=arrow_linewidth,
        label_arrow_style=arrow_style,
        label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
        label_arrow_alpha=arrow_alpha,
        label_arrow_base_shrink=arrow_base_shrink,
        label_arrow_tip_shrink=arrow_tip_shrink,
        label_max_labels=max_labels,
        label_max_words=max_words,
        label_min_words=min_words,
        label_max_word_length=max_word_length,
        label_min_word_length=min_word_length,
        label_words_to_omit=words_to_omit,
        label_overlay_ids=overlay_ids,
        label_ids_to_keep=ids_to_keep,
        label_ids_to_replace=ids_to_replace,
    )

    domain_map = self.graph.domain_id_to_node_ids_map
    terms_map = self.graph.domain_id_to_domain_terms_map

    # Default to one label slot per domain
    if max_labels is None:
        max_labels = len(domain_map)

    # One RGBA entry per domain for text and arrows
    fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(domain_map))
    arrow_color = _to_rgba(arrow_color, arrow_alpha, num_repeats=len(domain_map))

    # Normalize words_to_omit to lowercase for case-insensitive matching
    if words_to_omit:
        words_to_omit = set(word.lower() for word in words_to_omit)

    def _label_terms(domain):
        """Return the filtered label words for a domain, or None if it has no label source."""
        if ids_to_replace and domain in ids_to_replace:
            terms = ids_to_replace[domain].split(" ")
        elif domain in terms_map:
            terms = terms_map[domain].split(" ")
        else:
            return None
        if words_to_omit:
            terms = [term for term in terms if term.lower() not in words_to_omit]
        terms = [term for term in terms if min_word_length <= len(term) <= max_word_length]
        return terms[:max_words]

    # Centroids of all non-empty domains; their order defines the color index of each domain
    domain_centroids = {
        domain_id: self._calculate_domain_centroid(node_ids)
        for domain_id, node_ids in domain_map.items()
        if node_ids
    }
    centroid_index = {domain: idx for idx, domain in enumerate(domain_centroids)}

    filtered_domain_centroids = {}
    filtered_domain_terms = {}
    color_index = {}

    # Domains the caller insists on are processed first
    if ids_to_keep:
        # Remove accidental duplicates
        ids_to_keep = set(ids_to_keep)
        if len(ids_to_keep) > max_labels:
            raise ValueError(
                f"Number of provided IDs ({len(ids_to_keep)}) exceeds max_labels ({max_labels})."
            )
        for domain in ids_to_keep:
            if domain not in terms_map or domain not in domain_centroids:
                continue
            terms = _label_terms(domain)
            if len(terms) >= min_words:
                filtered_domain_centroids[domain] = domain_centroids[domain]
                filtered_domain_terms[domain] = " ".join(terms)
                color_index[domain] = centroid_index[domain]

    # Fill any remaining slots from the other domains
    remaining_labels = (
        max_labels - len(ids_to_keep) if ids_to_keep and max_labels else max_labels
    )
    if remaining_labels and remaining_labels > 0:
        for domain, centroid in domain_centroids.items():
            if pd.isna(domain):
                continue  # Skip NaN domains
            if ids_to_keep and domain in ids_to_keep:
                continue  # Already handled above
            terms = _label_terms(domain)
            if terms is None:
                continue  # No replacement label and no recorded terms for this domain
            if len(terms) >= min_words:
                filtered_domain_centroids[domain] = centroid
                filtered_domain_terms[domain] = " ".join(terms)
                color_index[domain] = centroid_index[domain]
            # Stop once we've reached the max_labels limit
            if len(filtered_domain_centroids) >= max_labels:
                break

    # Place labels around the (scaled) bounding circle of the network
    center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
    best_label_positions = _calculate_best_label_positions(
        filtered_domain_centroids, center, radius, offset
    )

    # Colors are looked up by domain rather than by zip order, so this does not depend on the
    # helper returning positions in insertion order.
    for domain, pos in best_label_positions.items():
        idx = color_index[domain]
        centroid = filtered_domain_centroids[domain]
        # Terms were already truncated to max_words above; one word per line
        annotations = filtered_domain_terms[domain].split(" ")
        self.ax.annotate(
            "\n".join(annotations),
            xy=centroid,
            xytext=pos,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=fontcolor[idx],
            arrowprops=dict(
                arrowstyle=arrow_style,
                color=arrow_color[idx],
                linewidth=arrow_linewidth,
                shrinkA=arrow_base_shrink,
                shrinkB=arrow_tip_shrink,
            ),
        )
        # Overlay domain ID at the centroid if requested
        if overlay_ids:
            self.ax.text(
                centroid[0],
                centroid[1],
                domain,
                ha="center",
                va="center",
                fontsize=fontsize,
                fontname=font,
                color=fontcolor[idx],
                alpha=fontalpha,
            )
784
-
785
def plot_sublabel(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    label: str,
    radial_position: float = 0.0,
    scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontsize: int = 10,
    fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
    fontalpha: float = 1.0,
    arrow_linewidth: float = 1,
    arrow_style: str = "->",
    arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
    arrow_alpha: float = 1.0,
    arrow_base_shrink: float = 0.0,
    arrow_tip_shrink: float = 0.0,
) -> None:
    """Place one text label on the network perimeter with an arrow to each node group's center.

    ``nodes`` may be a flat collection of node labels (one group) or a collection of
    collections (one arrow per inner group). All arrows start from the same label position.

    Args:
        nodes (list, tuple, or np.ndarray): Node labels, or a list of lists of node labels.
        label (str): Text drawn at the perimeter.
        radial_position (float, optional): Angle in degrees (0-360) of the label around the
            network; 0 is shifted by -90 degrees before conversion. Defaults to 0.0.
        scale (float, optional): Margin factor applied to the bounding-box radius. Defaults to 1.05.
        offset (float, optional): Extra distance beyond the scaled radius. Defaults to 0.10.
        font (str, optional): Font family for the label. Defaults to "Arial".
        fontsize (int, optional): Font size for the label. Defaults to 10.
        fontcolor (str, list, tuple, or np.ndarray, optional): Label text color. Defaults to "black".
        fontalpha (float, optional): Alpha applied to the label color. Defaults to 1.0.
        arrow_linewidth (float, optional): Arrow line width. Defaults to 1.
        arrow_style (str, optional): Matplotlib arrow style string. Defaults to "->".
        arrow_color (str, list, tuple, or np.ndarray, optional): Arrow color. Defaults to "black".
        arrow_alpha (float, optional): Alpha applied to the arrow color. Defaults to 1.0.
        arrow_base_shrink (float, optional): Gap between text and arrow base. Defaults to 0.0.
        arrow_tip_shrink (float, optional): Gap between arrow tip and target. Defaults to 0.0.

    Raises:
        ValueError: If any group maps to fewer than two nodes present in the graph.
    """
    # A collection containing any list/tuple/array element is treated as a list of groups
    has_subgroups = any(isinstance(entry, (list, tuple, np.ndarray)) for entry in nodes)
    groups = nodes if has_subgroups else [nodes]

    text_rgba = _to_rgba(fontcolor, fontalpha)
    arrow_rgba = _to_rgba(arrow_color, arrow_alpha)

    # Single label anchor on the (scaled) bounding circle, rotated so 0 degrees is at the bottom-left offset
    center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
    theta = np.deg2rad(radial_position - 90)
    reach = radius + offset
    anchor = (
        center[0] + reach * np.cos(theta),
        center[1] + reach * np.sin(theta),
    )

    for group in groups:
        # Translate labels to node IDs, silently dropping labels unknown to the graph
        group_ids = [
            self.graph.node_label_to_node_id_map.get(member)
            for member in group
            if member in self.graph.node_label_to_node_id_map
        ]
        if len(group_ids) <= 1:
            raise ValueError(
                "No nodes found in the network graph or insufficient nodes to plot."
            )

        target = self._calculate_domain_centroid(group_ids)
        self.ax.annotate(
            label,
            xy=target,
            xytext=anchor,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=text_rgba,
            arrowprops=dict(
                arrowstyle=arrow_style,
                color=arrow_rgba,
                linewidth=arrow_linewidth,
                shrinkA=arrow_base_shrink,
                shrinkB=arrow_tip_shrink,
            ),
        )
877
-
878
def _calculate_domain_centroid(self, nodes: List) -> np.ndarray:
    """Return the coordinates of the most centrally located node among ``nodes``.

    The result is a medoid rather than a geometric mean: it is the member node whose summed
    Euclidean distance to all other members is smallest, so it is always an actual node position.

    Args:
        nodes (list): Node IDs, used as row indices into ``self.graph.node_coordinates``.

    Returns:
        np.ndarray: Coordinates of the central node.

    Raises:
        ValueError: If ``nodes`` is empty.
    """
    if len(nodes) == 0:
        raise ValueError("Cannot calculate a centroid for an empty list of nodes.")

    # Rows of the coordinate matrix for the requested node IDs
    node_positions = self.graph.node_coordinates[nodes, :]
    # Pairwise distances via broadcasting: (n, 1, d) - (n, d) -> (n, n, d), norm over last axis
    distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
    # The node with the smallest total distance to all others is the medoid
    central_node_idx = np.argmin(distances_matrix.sum(axis=1))
    return node_positions[central_node_idx]
898
-
899
def get_annotated_node_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    alpha: float = 1.0,
    nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
    nonenriched_alpha: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Build per-node RGBA colors, substituting a fallback color for non-enriched nodes.

    Args:
        cmap (str, optional): Colormap name passed to the graph's color generator. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color instead of the colormap, if given. Defaults to None.
        min_scale (float, optional): Lower bound of color intensity scaling. Defaults to 0.8.
        max_scale (float, optional): Upper bound of color intensity scaling. Defaults to 1.0.
        scale_factor (float, optional): Contrast exponent for intensity scaling. Defaults to 1.0.
        alpha (float, optional): Alpha written into every generated color. Defaults to 1.0.
        nonenriched_color (str, list, tuple, or np.ndarray, optional): Replacement color for nodes
            whose generated RGB is pure black. Defaults to "white".
        nonenriched_alpha (float, optional): Alpha for the replacement color. Defaults to 1.0.
        random_seed (int, optional): Seed forwarded to the color generator. Defaults to 888.

    Returns:
        np.ndarray: One RGBA row per node.
    """
    generator_options = dict(
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )
    domain_rgba = self.graph.get_domain_colors(**generator_options)
    # Overwrite the alpha channel of every generated color in place
    domain_rgba[:, 3] = alpha

    fallback_rgba = _to_rgba(nonenriched_color, nonenriched_alpha)
    # Rows whose RGB channels are all zero (pure black) get the non-enriched fallback color
    is_black = np.all(domain_rgba[:, :3] == 0, axis=1, keepdims=True)
    return np.where(is_black, np.array([fallback_rgba]), domain_rgba)
947
-
948
def get_annotated_node_sizes(
    self, enriched_size: int = 50, nonenriched_size: int = 25
) -> np.ndarray:
    """Return a per-node size array: enriched nodes get ``enriched_size``, others ``nonenriched_size``.

    A node counts as enriched if it appears in any domain of
    ``self.graph.domain_id_to_node_ids_map``. Node IDs are used directly as array indices.

    Args:
        enriched_size (int): Size for enriched nodes. Defaults to 50. Floats are kept as-is.
        nonenriched_size (int): Size for non-enriched nodes. Defaults to 25.

    Returns:
        np.ndarray: Array of length ``len(self.graph.network.nodes)`` with node sizes. Its dtype
        is wide enough to hold both sizes (integer for integer inputs, float if either is a float).
    """
    # Union of all node IDs that belong to any domain
    enriched_nodes = set()
    for node_ids in self.graph.domain_id_to_node_ids_map.values():
        enriched_nodes.update(node_ids)

    network_nodes = self.graph.network.nodes
    node_sizes = np.full(len(network_nodes), nonenriched_size)
    # np.full fixes the dtype from nonenriched_size; widen it so a float enriched_size
    # is not silently truncated when written below.
    node_sizes = node_sizes.astype(np.result_type(node_sizes, enriched_size), copy=False)

    for node in enriched_nodes:
        # Skip domain members that are not in the network
        if node in network_nodes:
            node_sizes[node] = enriched_size

    return node_sizes
973
-
974
def get_annotated_contour_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Return one RGBA color per domain for drawing contours.

    This is a thin wrapper over ``_get_annotated_domain_colors``; all arguments are forwarded unchanged.

    Args:
        cmap (str, optional): Colormap name used to generate colors. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color instead of the colormap, if given. Defaults to None.
        min_scale (float, optional): Dimmest intensity produced by the colormap. Defaults to 0.8.
        max_scale (float, optional): Brightest intensity produced by the colormap. Defaults to 1.0.
        scale_factor (float, optional): Contrast exponent applied to enrichment-based scaling;
            larger values dim low scores more. Defaults to 1.0.
        random_seed (int, optional): Seed for reproducible color assignment. Defaults to 888.

    Returns:
        np.ndarray: RGBA colors for contour annotations.
    """
    forwarded = {
        "cmap": cmap,
        "color": color,
        "min_scale": min_scale,
        "max_scale": max_scale,
        "scale_factor": scale_factor,
        "random_seed": random_seed,
    }
    return self._get_annotated_domain_colors(**forwarded)
1007
-
1008
def get_annotated_label_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Return one RGBA color per domain for drawing labels.

    This is a thin wrapper over ``_get_annotated_domain_colors``; all arguments are forwarded unchanged.

    Args:
        cmap (str, optional): Colormap name used to generate colors. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color instead of the colormap, if given. Defaults to None.
        min_scale (float, optional): Dimmest intensity produced by the colormap. Defaults to 0.8.
        max_scale (float, optional): Brightest intensity produced by the colormap. Defaults to 1.0.
        scale_factor (float, optional): Contrast exponent applied to enrichment-based scaling;
            larger values dim low scores more. Defaults to 1.0.
        random_seed (int, optional): Seed for reproducible color assignment. Defaults to 888.

    Returns:
        np.ndarray: RGBA colors for label annotations.
    """
    forwarded = {
        "cmap": cmap,
        "color": color,
        "min_scale": min_scale,
        "max_scale": max_scale,
        "scale_factor": scale_factor,
        "random_seed": random_seed,
    }
    return self._get_annotated_domain_colors(**forwarded)
1041
-
1042
def _get_annotated_domain_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain.

    For each domain in ``self.graph.domain_id_to_node_ids_map`` (in map order), a domain with
    more than one node gets the member color with the largest R+G+B sum (first one wins on ties).
    Domains with zero or one node get opaque white.

    Args:
        cmap (str, optional): Colormap name used to generate node colors. Defaults to "gist_rainbow".
        color (str or None, optional): Fixed color instead of the colormap, if given. Defaults to None.
        min_scale (float, optional): Lower bound of intensity scaling. Defaults to 0.8.
        max_scale (float, optional): Upper bound of intensity scaling. Defaults to 1.0.
        scale_factor (float, optional): Contrast exponent for enrichment-based scaling. Defaults to 1.0.
        random_seed (int, optional): Seed for reproducible color assignment. Defaults to 888.

    Returns:
        np.ndarray: One RGBA row per domain.
    """
    palette = self.graph.get_domain_colors(
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )

    def _representative(members):
        # Zero- or single-node domains fall back to opaque white
        if len(members) <= 1:
            return np.array([1.0, 1.0, 1.0, 1.0])
        candidates = np.array([palette[member] for member in members])
        brightness = candidates[:, :3].sum(axis=1)
        return candidates[np.argmax(brightness)]

    return np.array(
        [_representative(members) for members in self.graph.domain_id_to_node_ids_map.values()]
    )
1091
-
1092
@staticmethod
def savefig(*args, **kwargs) -> None:
    """Save the current plot to a file via ``plt.savefig``.

    ``bbox_inches`` defaults to ``"tight"`` but may be overridden by the caller.

    Args:
        *args: Positional arguments passed to `plt.savefig` (e.g. the output filename).
        **kwargs: Keyword arguments passed to `plt.savefig`, such as format or dpi.
    """
    # Previously bbox_inches was passed explicitly alongside **kwargs, so a caller-supplied
    # bbox_inches raised "got multiple values for keyword argument". Use it as a default only.
    kwargs.setdefault("bbox_inches", "tight")
    plt.savefig(*args, **kwargs)
1101
-
1102
@staticmethod
def show(*args, **kwargs) -> None:
    """Display the current matplotlib figure.

    Args:
        *args: Forwarded unchanged to `plt.show`.
        **kwargs: Forwarded unchanged to `plt.show` (e.g. ``block``).
    """
    display_figure = plt.show
    display_figure(*args, **kwargs)
1111
-
1112
-
1113
def _to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: float = 1.0,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert a color or collection of colors to RGBA, applying ``alpha`` only to non-RGBA inputs.

    A single color is either a string or a length-3/4 sequence of scalar channel values.
    Any other list/tuple/array is treated as a collection of colors. This includes a
    length-3/4 sequence of color *strings* such as ``["red", "green", "blue"]``.

    Strings and RGB triples get their alpha set to ``alpha``. RGBA inputs keep their own alpha.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert.
        alpha (float, optional): Alpha applied to string and RGB colors. Defaults to 1.0.
        num_repeats (int or None, optional): If given, a single color (or a one-element
            collection) is repeated this many times. Defaults to None.

    Returns:
        np.ndarray: A single RGBA color of shape (4,), or an array of RGBA colors of shape (n, 4).

    Raises:
        ValueError: If ``color`` or any element of a collection is not a valid color spec.
    """
    # Single color: a string, or a 3/4-element sequence whose elements are scalar channels.
    # Elements that are strings or sequences mean this is a collection of colors, not channels.
    is_single_color = isinstance(color, str) or (
        isinstance(color, (list, tuple, np.ndarray))
        and len(color) in [3, 4]
        and not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
    )
    if is_single_color:
        rgba_color = np.array(mcolors.to_rgba(color))
        # Only override alpha for strings and RGB triples; RGBA keeps its own alpha
        if isinstance(color, str) or len(color) == 3:
            rgba_color[3] = alpha

        if num_repeats is not None:
            return np.array([rgba_color] * num_repeats)

        return rgba_color

    # Collection of colors (strings, RGB, and/or RGBA)
    if isinstance(color, (list, tuple, np.ndarray)):
        rgba_colors = []
        for c in color:
            if isinstance(c, str) or (
                isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]
            ):
                rgba_c = np.array(mcolors.to_rgba(c))
                if isinstance(c, str) or len(c) == 3:
                    rgba_c[3] = alpha
                rgba_colors.append(rgba_c)
            else:
                raise ValueError(f"Invalid color: {c}. Must be a valid RGB/RGBA or string color.")

        # Only a one-element collection is expanded by num_repeats
        if num_repeats is not None and len(rgba_colors) == 1:
            return np.array([rgba_colors[0]] * num_repeats)

        return np.array(rgba_colors)

    raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
1172
-
1173
-
1174
def _is_connected(z: np.ndarray) -> bool:
    """Report whether the non-zero cells of ``z`` form exactly one connected component.

    Connectivity follows ``scipy.ndimage.label``'s default structuring element.

    Args:
        z (np.ndarray): Binary grid to evaluate.

    Returns:
        bool: True when exactly one component is found; False for zero or several.
    """
    component_count = label(z)[1]
    return component_count == 1
1185
-
1186
-
1187
def _calculate_bounding_box(
    node_coordinates: np.ndarray, radius_margin: float = 1.05
) -> Tuple[np.ndarray, float]:
    """Compute the center and a margin-scaled half-extent of the node bounding box.

    Args:
        node_coordinates (np.ndarray): Node coordinates with exactly two columns (x, y).
        radius_margin (float, optional): Multiplier applied to the half-extent. Defaults to 1.05.

    Returns:
        tuple: ``(center, radius)``. ``center`` is the box midpoint as an array. ``radius`` is
        half the larger of the box width and height, multiplied by ``radius_margin``.
    """
    lower_x, lower_y = np.min(node_coordinates, axis=0)
    upper_x, upper_y = np.max(node_coordinates, axis=0)
    midpoint = np.array([(lower_x + upper_x) / 2, (lower_y + upper_y) / 2])
    span = max(upper_x - lower_x, upper_y - lower_y)
    return midpoint, span / 2 * radius_margin
1207
-
1208
-
1209
def _calculate_best_label_positions(
    filtered_domain_centroids: Dict[str, Any], center: np.ndarray, radius: float, offset: float
) -> Dict[str, Any]:
    """Lay labels out evenly on a circle, then reassign slots to reduce label-to-centroid distance.

    Args:
        filtered_domain_centroids (dict): Domain -> centroid position. Dict order sets the initial slot order.
        center (np.ndarray): Center of the label circle.
        radius (float): Base radius of the label circle.
        offset (float): Extra distance added to ``radius``.

    Returns:
        dict: Domain -> optimized label position.
    """
    domains = list(filtered_domain_centroids.keys())
    slots = _calculate_equidistant_positions_around_center(center, radius, offset, len(domains))
    initial_layout = dict(zip(domains, slots))
    return _optimize_label_positions(initial_layout, filtered_domain_centroids)
1235
-
1236
-
1237
def _calculate_equidistant_positions_around_center(
    center: np.ndarray, radius: float, label_offset: float, num_domains: int
) -> List[np.ndarray]:
    """Place ``num_domains`` points evenly on a circle around ``center``.

    Angles start at 0 radians (positive x direction) and step counter-clockwise by 2*pi/num_domains.

    Args:
        center (np.ndarray): Circle center.
        radius (float): Base circle radius.
        label_offset (float): Extra distance added to ``radius``.
        num_domains (int): Number of points to generate.

    Returns:
        list[np.ndarray]: One 2D position per point, in angle order.
    """
    ring_radius = radius + label_offset
    positions = []
    for theta in np.linspace(0, 2 * np.pi, num_domains, endpoint=False):
        direction = np.array([np.cos(theta), np.sin(theta)])
        positions.append(center + ring_radius * direction)
    return positions
1258
-
1259
-
1260
def _optimize_label_positions(
    best_label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
) -> Dict[str, Any]:
    """Greedily swap label positions pairwise until no swap lowers the total label-to-centroid distance.

    The input dict is modified in place and also returned.

    Only the two swapped labels' terms change when two labels exchange positions. Each candidate
    swap is therefore evaluated from those four distances instead of recomputing the full total
    (the previous approach did O(n) work plus a dict copy per pair).

    Args:
        best_label_positions (dict): Domain -> initial label position. Updated in place.
        domain_centroids (dict): Domain -> centroid position. Must contain every key of
            ``best_label_positions``.

    Returns:
        dict: The same ``best_label_positions`` object with optimized positions.
    """
    # Key order is stable: only values are reassigned below
    labels = list(best_label_positions.keys())

    def _distance(label: Any, position: Any) -> float:
        return np.linalg.norm(domain_centroids[label] - position)

    improvement = True
    while improvement:
        improvement = False
        for i in range(len(labels)):
            for j in range(i + 1, len(labels)):
                label_i, label_j = labels[i], labels[j]
                pos_i, pos_j = best_label_positions[label_i], best_label_positions[label_j]
                current = _distance(label_i, pos_i) + _distance(label_j, pos_j)
                swapped = _distance(label_i, pos_j) + _distance(label_j, pos_i)
                if swapped < current:
                    best_label_positions[label_i], best_label_positions[label_j] = pos_j, pos_i
                    improvement = True  # keep iterating until a full pass finds nothing

    return best_label_positions
1294
-
1295
-
1296
def _calculate_total_distance(
    label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
) -> float:
    """Sum the Euclidean distances from each label position to its domain's centroid.

    Args:
        label_positions (dict): Domain -> label position.
        domain_centroids (dict): Domain -> centroid position.

    Returns:
        float: Total distance (``0`` when ``label_positions`` is empty).
    """
    return sum(
        np.linalg.norm(domain_centroids[domain] - spot)
        for domain, spot in label_positions.items()
    )
1315
-
1316
-
1317
def _swap_and_evaluate(
    label_positions: Dict[str, Any],
    i: int,
    j: int,
    domain_centroids: Dict[str, Any],
) -> float:
    """Return the total label-to-centroid distance if labels ``i`` and ``j`` traded positions.

    The input dict is not modified; the swap is applied to a shallow copy.

    Args:
        label_positions (dict): Domain -> label position.
        i (int): Index (in key order) of the first label.
        j (int): Index (in key order) of the second label.
        domain_centroids (dict): Domain -> centroid position.

    Returns:
        float: Total distance for the hypothetical swapped layout.
    """
    ordered_keys = list(label_positions)
    first, second = ordered_keys[i], ordered_keys[j]
    trial_layout = dict(label_positions)
    trial_layout[first], trial_layout[second] = label_positions[second], label_positions[first]
    return _calculate_total_distance(trial_layout, domain_centroids)