risk-network 0.0.6b0__py3-none-any.whl → 0.0.6b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/network/plot.py CHANGED
@@ -5,7 +5,6 @@ risk/network/plot
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
7
7
 
8
- import matplotlib
9
8
  import matplotlib.colors as mcolors
10
9
  import matplotlib.pyplot as plt
11
10
  import networkx as nx
@@ -18,66 +17,42 @@ from risk.network.graph import NetworkGraph
18
17
 
19
18
 
20
19
  class NetworkPlotter:
21
- """A class responsible for visualizing network graphs with various customization options.
20
+ """A class for visualizing network graphs with customizable options.
22
21
 
23
- The NetworkPlotter class takes in a NetworkGraph object, which contains the network's data and attributes,
24
- and provides methods for plotting the network with customizable node and edge properties,
25
- as well as optional features like drawing the network's perimeter and setting background colors.
22
+ The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
23
+ flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
24
+ perimeter, and adjusting background colors.
26
25
  """
27
26
 
28
27
  def __init__(
29
28
  self,
30
29
  graph: NetworkGraph,
31
- figsize: tuple = (10, 10),
32
- background_color: str = "white",
33
- plot_outline: bool = True,
34
- outline_color: str = "black",
35
- outline_scale: float = 1.0,
36
- linestyle: str = "dashed",
30
+ figsize: Tuple = (10, 10),
31
+ background_color: Union[str, List, Tuple, np.ndarray] = "white",
37
32
  ) -> None:
38
33
  """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
39
34
 
40
35
  Args:
41
36
  graph (NetworkGraph): The network data and attributes to be visualized.
42
37
  figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
43
- background_color (str, optional): Background color of the plot. Defaults to "white".
44
- plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
45
- outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
46
- outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
47
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
38
+ background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
48
39
  """
49
40
  self.graph = graph
50
41
  # Initialize the plot with the specified parameters
51
- self.ax = self._initialize_plot(
52
- graph,
53
- figsize,
54
- background_color,
55
- plot_outline,
56
- outline_color,
57
- outline_scale,
58
- linestyle,
59
- )
42
+ self.ax = self._initialize_plot(graph, figsize, background_color)
60
43
 
61
44
  def _initialize_plot(
62
45
  self,
63
46
  graph: NetworkGraph,
64
- figsize: tuple,
65
- background_color: str,
66
- plot_outline: bool,
67
- outline_color: str,
68
- outline_scale: float,
69
- linestyle: str,
47
+ figsize: Tuple,
48
+ background_color: Union[str, List, Tuple, np.ndarray],
70
49
  ) -> plt.Axes:
71
- """Set up the plot with figure size, optional circle perimeter, and background color.
50
+ """Set up the plot with figure size and background color.
72
51
 
73
52
  Args:
74
53
  graph (NetworkGraph): The network data and attributes to be visualized.
75
54
  figsize (tuple): Size of the figure in inches (width, height).
76
55
  background_color (str): Background color of the plot.
77
- plot_outline (bool): Whether to plot the network perimeter circle.
78
- outline_color (str): Color of the network perimeter circle.
79
- outline_scale (float): Outline scaling factor for the perimeter diameter.
80
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
81
56
 
82
57
  Returns:
83
58
  plt.Axes: The axis object for the plot.
@@ -86,31 +61,19 @@ class NetworkPlotter:
86
61
  node_coordinates = graph.node_coordinates
87
62
  # Calculate the center and radius of the bounding box around the network
88
63
  center, radius = _calculate_bounding_box(node_coordinates)
89
- # Scale the radius by the outline_scale factor
90
- scaled_radius = radius * outline_scale
91
64
 
92
65
  # Create a new figure and axis for plotting
93
66
  fig, ax = plt.subplots(figsize=figsize)
94
67
  fig.tight_layout() # Adjust subplot parameters to give specified padding
95
- if plot_outline:
96
- # Draw a circle to represent the network perimeter
97
- circle = plt.Circle(
98
- center,
99
- scaled_radius,
100
- linestyle=linestyle, # Use the linestyle argument here
101
- color=outline_color,
102
- fill=False,
103
- linewidth=1.5,
104
- )
105
- ax.add_artist(circle) # Add the circle to the plot
106
-
107
- # Set axis limits based on the calculated bounding box and scaled radius
108
- ax.set_xlim([center[0] - scaled_radius - 0.3, center[0] + scaled_radius + 0.3])
109
- ax.set_ylim([center[1] - scaled_radius - 0.3, center[1] + scaled_radius + 0.3])
68
+ # Set axis limits based on the calculated bounding box and radius
69
+ ax.set_xlim([center[0] - radius - 0.3, center[0] + radius + 0.3])
70
+ ax.set_ylim([center[1] - radius - 0.3, center[1] + radius + 0.3])
110
71
  ax.set_aspect("equal") # Ensure the aspect ratio is equal
111
- fig.patch.set_facecolor(background_color) # Set the figure background color
112
- ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
113
72
 
73
+ # Set the background color of the plot
74
+ # Convert color to RGBA using the _to_rgba helper function
75
+ fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
76
+ ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
114
77
  # Remove axis spines for a cleaner look
115
78
  for spine in ax.spines.values():
116
79
  spine.set_visible(False)
@@ -122,44 +85,182 @@ class NetworkPlotter:
122
85
 
123
86
  return ax
124
87
 
88
+ def plot_circle_perimeter(
89
+ self,
90
+ scale: float = 1.0,
91
+ linestyle: str = "dashed",
92
+ linewidth: float = 1.5,
93
+ color: Union[str, List, Tuple, np.ndarray] = "black",
94
+ outline_alpha: float = 1.0,
95
+ fill_alpha: float = 0.0,
96
+ ) -> None:
97
+ """Plot a circle around the network graph to represent the network perimeter.
98
+
99
+ Args:
100
+ scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
101
+ linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
102
+ linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
103
+ color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
104
+ outline_alpha (float, optional): Transparency level of the circle outline. Defaults to 1.0.
105
+ fill_alpha (float, optional): Transparency level of the circle fill. Defaults to 0.0.
106
+ """
107
+ # Log the circle perimeter plotting parameters
108
+ params.log_plotter(
109
+ perimeter_type="circle",
110
+ perimeter_scale=scale,
111
+ perimeter_linestyle=linestyle,
112
+ perimeter_linewidth=linewidth,
113
+ perimeter_color=(
114
+ "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
115
+ ), # np.ndarray usually indicates custom colors
116
+ perimeter_outline_alpha=outline_alpha,
117
+ perimeter_fill_alpha=fill_alpha,
118
+ )
119
+
120
+ # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
121
+ color = _to_rgba(color, outline_alpha)
122
+ # Extract node coordinates from the network graph
123
+ node_coordinates = self.graph.node_coordinates
124
+ # Calculate the center and radius of the bounding box around the network
125
+ center, radius = _calculate_bounding_box(node_coordinates)
126
+ # Scale the radius by the scale factor
127
+ scaled_radius = radius * scale
128
+
129
+ # Draw a circle to represent the network perimeter
130
+ circle = plt.Circle(
131
+ center,
132
+ scaled_radius,
133
+ linestyle=linestyle,
134
+ linewidth=linewidth,
135
+ color=color,
136
+ fill=fill_alpha > 0, # Fill the circle if fill_alpha is greater than 0
137
+ )
138
+ # Set the transparency of the fill if applicable
139
+ if fill_alpha > 0:
140
+ circle.set_facecolor(
141
+ _to_rgba(color, fill_alpha)
142
+ ) # Use _to_rgba to set the fill color with transparency
143
+
144
+ self.ax.add_artist(circle)
145
+
146
+ def plot_contour_perimeter(
147
+ self,
148
+ scale: float = 1.0,
149
+ levels: int = 3,
150
+ bandwidth: float = 0.8,
151
+ grid_size: int = 250,
152
+ color: Union[str, List, Tuple, np.ndarray] = "black",
153
+ linestyle: str = "solid",
154
+ linewidth: float = 1.5,
155
+ outline_alpha: float = 1.0,
156
+ fill_alpha: float = 0.0,
157
+ ) -> None:
158
+ """
159
+ Plot a KDE-based contour around the network graph to represent the network perimeter.
160
+
161
+ Args:
162
+ scale (float, optional): Scaling factor for the perimeter size. Defaults to 1.0.
163
+ levels (int, optional): Number of contour levels. Defaults to 3.
164
+ bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
165
+ grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
166
+ color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
167
+ linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
168
+ linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
169
+ outline_alpha (float, optional): Transparency level of the contour outline. Defaults to 1.0.
170
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.0.
171
+ """
172
+ # Log the contour perimeter plotting parameters
173
+ params.log_plotter(
174
+ perimeter_type="contour",
175
+ perimeter_scale=scale,
176
+ perimeter_levels=levels,
177
+ perimeter_bandwidth=bandwidth,
178
+ perimeter_grid_size=grid_size,
179
+ perimeter_linestyle=linestyle,
180
+ perimeter_linewidth=linewidth,
181
+ perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
182
+ perimeter_outline_alpha=outline_alpha,
183
+ perimeter_fill_alpha=fill_alpha,
184
+ )
185
+
186
+ # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
187
+ color = _to_rgba(color, outline_alpha)
188
+ # Extract node coordinates from the network graph
189
+ node_coordinates = self.graph.node_coordinates
190
+ # Scale the node coordinates if needed
191
+ scaled_coordinates = node_coordinates * scale
192
+ # Use the existing _draw_kde_contour method
193
+ self._draw_kde_contour(
194
+ ax=self.ax,
195
+ pos=scaled_coordinates,
196
+ nodes=list(range(len(node_coordinates))), # All nodes are included
197
+ levels=levels,
198
+ bandwidth=bandwidth,
199
+ grid_size=grid_size,
200
+ color=color,
201
+ linestyle=linestyle,
202
+ linewidth=linewidth,
203
+ alpha=fill_alpha, # Use fill_alpha for the fill
204
+ )
205
+
125
206
  def plot_network(
126
207
  self,
127
208
  node_size: Union[int, np.ndarray] = 50,
128
- edge_width: float = 1.0,
129
- node_color: Union[str, np.ndarray] = "white",
130
- node_edgecolor: str = "black",
131
- edge_color: str = "black",
132
209
  node_shape: str = "o",
210
+ edge_width: float = 1.0,
211
+ node_color: Union[str, List, Tuple, np.ndarray] = "white",
212
+ node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
213
+ edge_color: Union[str, List, Tuple, np.ndarray] = "black",
214
+ node_alpha: float = 1.0,
215
+ edge_alpha: float = 1.0,
133
216
  ) -> None:
134
217
  """Plot the network graph with customizable node colors, sizes, and edge widths.
135
218
 
136
219
  Args:
137
220
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
138
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
139
- node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
140
- node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
141
- edge_color (str, optional): Color of the edges. Defaults to "black".
142
221
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
222
+ edge_width (float, optional): Width of the edges. Defaults to 1.0.
223
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
224
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
225
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
226
+ node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0. Annotated node_color alphas will override this value.
227
+ edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
143
228
  """
144
229
  # Log the plotting parameters
145
230
  params.log_plotter(
146
- network_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
231
+ network_node_size=(
232
+ "custom" if isinstance(node_size, np.ndarray) else node_size
233
+ ), # np.ndarray usually indicates custom sizes
234
+ network_node_shape=node_shape,
147
235
  network_edge_width=edge_width,
148
- network_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
236
+ network_node_color=(
237
+ "custom" if isinstance(node_color, np.ndarray) else node_color
238
+ ), # np.ndarray usually indicates custom colors
149
239
  network_node_edgecolor=node_edgecolor,
150
240
  network_edge_color=edge_color,
151
- network_node_shape=node_shape,
241
+ network_node_alpha=node_alpha,
242
+ network_edge_alpha=edge_alpha,
152
243
  )
244
+
245
+ # Convert colors to RGBA using the _to_rgba helper function
246
+ # If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
247
+ node_color = _to_rgba(node_color, node_alpha, num_repeats=len(self.graph.network.nodes))
248
+ # Convert other colors to RGBA using the _to_rgba helper function
249
+ node_edgecolor = _to_rgba(
250
+ node_edgecolor, 1.0, num_repeats=len(self.graph.network.nodes)
251
+ ) # Node edges are usually fully opaque
252
+ edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
253
+
153
254
  # Extract node coordinates from the network graph
154
255
  node_coordinates = self.graph.node_coordinates
256
+
155
257
  # Draw the nodes of the graph
156
258
  nx.draw_networkx_nodes(
157
259
  self.graph.network,
158
260
  pos=node_coordinates,
159
261
  node_size=node_size,
160
- node_color=node_color,
161
262
  node_shape=node_shape,
162
- alpha=1.00,
263
+ node_color=node_color,
163
264
  edgecolors=node_edgecolor,
164
265
  ax=self.ax,
165
266
  )
@@ -174,37 +275,33 @@ class NetworkPlotter:
174
275
 
175
276
  def plot_subnetwork(
176
277
  self,
177
- nodes: list,
278
+ nodes: List,
178
279
  node_size: Union[int, np.ndarray] = 50,
179
- edge_width: float = 1.0,
180
- node_color: Union[str, np.ndarray] = "white",
181
- node_edgecolor: str = "black",
182
- edge_color: str = "black",
183
280
  node_shape: str = "o",
281
+ edge_width: float = 1.0,
282
+ node_color: Union[str, List, Tuple, np.ndarray] = "white",
283
+ node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
284
+ edge_color: Union[str, List, Tuple, np.ndarray] = "black",
285
+ node_alpha: float = 1.0,
286
+ edge_alpha: float = 1.0,
184
287
  ) -> None:
185
288
  """Plot a subnetwork of selected nodes with customizable node and edge attributes.
186
289
 
187
290
  Args:
188
291
  nodes (list): List of node labels to include in the subnetwork.
189
292
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
190
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
191
- node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
192
- node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
193
- edge_color (str, optional): Color of the edges. Defaults to "black".
194
293
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
294
+ edge_width (float, optional): Width of the edges. Defaults to 1.0.
295
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
296
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
297
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
298
+ node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0.
299
+ edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
195
300
 
196
301
  Raises:
197
302
  ValueError: If no valid nodes are found in the network graph.
198
303
  """
199
- # Log the plotting parameters for the subnetwork
200
- params.log_plotter(
201
- subnetwork_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
202
- subnetwork_edge_width=edge_width,
203
- subnetwork_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
204
- subnetwork_node_edgecolor=node_edgecolor,
205
- subnetwork_edge_color=edge_color,
206
- subnet_node_shape=node_shape,
207
- )
304
+ # Don't log subnetwork parameters as they are specific to individual annotations
208
305
  # Filter to get node IDs and their coordinates
209
306
  node_ids = [
210
307
  self.graph.node_label_to_id_map.get(node)
@@ -214,17 +311,21 @@ class NetworkPlotter:
214
311
  if not node_ids:
215
312
  raise ValueError("No nodes found in the network graph.")
216
313
 
314
+ # Convert colors to RGBA using the _to_rgba helper function
315
+ node_color = _to_rgba(node_color, node_alpha)
316
+ node_edgecolor = _to_rgba(node_edgecolor, 1.0) # Node edges usually fully opaque
317
+ edge_color = _to_rgba(edge_color, edge_alpha)
217
318
  # Get the coordinates of the filtered nodes
218
319
  node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
320
+
219
321
  # Draw the nodes in the subnetwork
220
322
  nx.draw_networkx_nodes(
221
323
  self.graph.network,
222
324
  pos=node_coordinates,
223
325
  nodelist=node_ids,
224
326
  node_size=node_size,
225
- node_color=node_color,
226
327
  node_shape=node_shape,
227
- alpha=1.00,
328
+ node_color=node_color,
228
329
  edgecolors=node_edgecolor,
229
330
  ax=self.ax,
230
331
  )
@@ -243,8 +344,10 @@ class NetworkPlotter:
243
344
  levels: int = 5,
244
345
  bandwidth: float = 0.8,
245
346
  grid_size: int = 250,
347
+ color: Union[str, List, Tuple, np.ndarray] = "white",
348
+ linestyle: str = "solid",
349
+ linewidth: float = 1.5,
246
350
  alpha: float = 0.2,
247
- color: Union[str, np.ndarray] = "white",
248
351
  ) -> None:
249
352
  """Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.
250
353
 
@@ -252,21 +355,24 @@ class NetworkPlotter:
252
355
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
253
356
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
254
357
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
358
+ color (str, list, tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors. Defaults to "white".
359
+ linestyle (str, optional): Line style for the contours. Defaults to "solid".
360
+ linewidth (float, optional): Line width for the contours. Defaults to 1.5.
255
361
  alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
256
- color (str or np.ndarray, optional): Color of the contours. Can be a string (e.g., 'white') or an array of colors. Defaults to "white".
257
362
  """
258
363
  # Log the contour plotting parameters
259
364
  params.log_plotter(
260
365
  contour_levels=levels,
261
366
  contour_bandwidth=bandwidth,
262
367
  contour_grid_size=grid_size,
368
+ contour_color=(
369
+ "custom" if isinstance(color, np.ndarray) else color
370
+ ), # np.ndarray usually indicates custom colors
263
371
  contour_alpha=alpha,
264
- contour_color="custom" if isinstance(color, np.ndarray) else color,
265
372
  )
266
- # Convert color string to RGBA array if necessary
267
- if isinstance(color, str):
268
- color = self.get_annotated_contour_colors(color=color)
269
373
 
374
+ # Ensure color is converted to RGBA with repetition matching the number of domains
375
+ color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_to_nodes))
270
376
  # Extract node coordinates from the network graph
271
377
  node_coordinates = self.graph.node_coordinates
272
378
  # Draw contours for each domain in the network
@@ -280,17 +386,21 @@ class NetworkPlotter:
280
386
  levels=levels,
281
387
  bandwidth=bandwidth,
282
388
  grid_size=grid_size,
389
+ linestyle=linestyle,
390
+ linewidth=linewidth,
283
391
  alpha=alpha,
284
392
  )
285
393
 
286
394
  def plot_subcontour(
287
395
  self,
288
- nodes: list,
396
+ nodes: List,
289
397
  levels: int = 5,
290
398
  bandwidth: float = 0.8,
291
399
  grid_size: int = 250,
400
+ color: Union[str, List, Tuple, np.ndarray] = "white",
401
+ linestyle: str = "solid",
402
+ linewidth: float = 1.5,
292
403
  alpha: float = 0.2,
293
- color: Union[str, np.ndarray] = "white",
294
404
  ) -> None:
295
405
  """Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
296
406
 
@@ -299,20 +409,15 @@ class NetworkPlotter:
299
409
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
300
410
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
301
411
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
412
+ color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
413
+ linestyle (str, optional): Line style for the contour. Defaults to "solid".
414
+ linewidth (float, optional): Line width for the contour. Defaults to 1.5.
302
415
  alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
303
- color (str or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
304
416
 
305
417
  Raises:
306
418
  ValueError: If no valid nodes are found in the network graph.
307
419
  """
308
- # Log the plotting parameters
309
- params.log_plotter(
310
- subcontour_levels=levels,
311
- subcontour_bandwidth=bandwidth,
312
- subcontour_grid_size=grid_size,
313
- subcontour_alpha=alpha,
314
- subcontour_color="custom" if isinstance(color, np.ndarray) else color,
315
- )
420
+ # Don't log subcontour parameters as they are specific to individual annotations
316
421
  # Filter to get node IDs and their coordinates
317
422
  node_ids = [
318
423
  self.graph.node_label_to_id_map.get(node)
@@ -322,6 +427,8 @@ class NetworkPlotter:
322
427
  if not node_ids or len(node_ids) == 1:
323
428
  raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
324
429
 
430
+ # Convert color to RGBA using the _to_rgba helper function
431
+ color = _to_rgba(color, alpha)
325
432
  # Draw the KDE contour for the specified nodes
326
433
  node_coordinates = self.graph.node_coordinates
327
434
  self._draw_kde_contour(
@@ -332,6 +439,8 @@ class NetworkPlotter:
332
439
  levels=levels,
333
440
  bandwidth=bandwidth,
334
441
  grid_size=grid_size,
442
+ linestyle=linestyle,
443
+ linewidth=linewidth,
335
444
  alpha=alpha,
336
445
  )
337
446
 
@@ -339,11 +448,13 @@ class NetworkPlotter:
339
448
  self,
340
449
  ax: plt.Axes,
341
450
  pos: np.ndarray,
342
- nodes: list,
343
- color: Union[str, np.ndarray],
451
+ nodes: List,
344
452
  levels: int = 5,
345
453
  bandwidth: float = 0.8,
346
454
  grid_size: int = 250,
455
+ color: Union[str, np.ndarray] = "white",
456
+ linestyle: str = "solid",
457
+ linewidth: float = 1.5,
347
458
  alpha: float = 0.5,
348
459
  ) -> None:
349
460
  """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
@@ -352,10 +463,12 @@ class NetworkPlotter:
352
463
  ax (plt.Axes): The axis to draw the contour on.
353
464
  pos (np.ndarray): Array of node positions (x, y).
354
465
  nodes (list): List of node indices to include in the contour.
355
- color (str or np.ndarray): Color for the contour.
356
466
  levels (int, optional): Number of contour levels. Defaults to 5.
357
467
  bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
358
468
  grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
469
+ color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
470
+ linestyle (str, optional): Line style for the contour. Defaults to "solid".
471
+ linewidth (float, optional): Line width for the contour. Defaults to 1.5.
359
472
  alpha (float, optional): Transparency level for the contour fill. Defaults to 0.5.
360
473
  """
361
474
  # Extract the positions of the specified nodes
@@ -382,8 +495,7 @@ class NetworkPlotter:
382
495
  min_density, max_density = z.min(), z.max()
383
496
  contour_levels = np.linspace(min_density, max_density, levels)[1:]
384
497
  contour_colors = [color for _ in range(levels - 1)]
385
-
386
- # Plot the filled contours if alpha > 0
498
+ # Plot the filled contours only if alpha > 0
387
499
  if alpha > 0:
388
500
  ax.contourf(
389
501
  x,
@@ -391,25 +503,34 @@ class NetworkPlotter:
391
503
  z,
392
504
  levels=contour_levels,
393
505
  colors=contour_colors,
394
- alpha=alpha,
395
- extend="neither",
396
506
  antialiased=True,
507
+ alpha=alpha,
397
508
  )
398
509
 
399
- # Plot the contour lines without antialiasing for clarity
400
- c = ax.contour(x, y, z, levels=contour_levels, colors=contour_colors)
510
+ # Plot the contour lines without any change in behavior
511
+ c = ax.contour(
512
+ x,
513
+ y,
514
+ z,
515
+ levels=contour_levels,
516
+ colors=contour_colors,
517
+ linestyles=linestyle,
518
+ linewidths=linewidth,
519
+ )
401
520
  for i in range(1, len(contour_levels)):
402
521
  c.collections[i].set_linewidth(0)
403
522
 
404
523
  def plot_labels(
405
524
  self,
406
- perimeter_scale: float = 1.05,
525
+ scale: float = 1.05,
407
526
  offset: float = 0.10,
408
527
  font: str = "Arial",
409
528
  fontsize: int = 10,
410
- fontcolor: Union[str, np.ndarray] = "black",
529
+ fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
530
+ fontalpha: float = 1.0,
411
531
  arrow_linewidth: float = 1,
412
- arrow_color: Union[str, np.ndarray] = "black",
532
+ arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
533
+ arrow_alpha: float = 1.0,
413
534
  max_labels: Union[int, None] = None,
414
535
  max_words: int = 10,
415
536
  min_words: int = 1,
@@ -420,13 +541,15 @@ class NetworkPlotter:
420
541
  """Annotate the network graph with labels for different domains, positioned around the network for clarity.
421
542
 
422
543
  Args:
423
- perimeter_scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
544
+ scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
424
545
  offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
425
546
  font (str, optional): Font name for the labels. Defaults to "Arial".
426
547
  fontsize (int, optional): Font size for the labels. Defaults to 10.
427
- fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
548
+ fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
549
+ fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
428
550
  arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
429
- arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
551
+ arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
552
+ arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
430
553
  max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
431
554
  max_words (int, optional): Maximum number of words in a label. Defaults to 10.
432
555
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
@@ -436,13 +559,17 @@ class NetworkPlotter:
436
559
  """
437
560
  # Log the plotting parameters
438
561
  params.log_plotter(
439
- label_perimeter_scale=perimeter_scale,
562
+ label_perimeter_scale=scale,
440
563
  label_offset=offset,
441
564
  label_font=font,
442
565
  label_fontsize=fontsize,
443
- label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
566
+ label_fontcolor=(
567
+ "custom" if isinstance(fontcolor, np.ndarray) else fontcolor
568
+ ), # np.ndarray usually indicates custom colors
569
+ label_fontalpha=fontalpha,
444
570
  label_arrow_linewidth=arrow_linewidth,
445
571
  label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
572
+ label_arrow_alpha=arrow_alpha,
446
573
  label_max_labels=max_labels,
447
574
  label_max_words=max_words,
448
575
  label_min_words=min_words,
@@ -451,11 +578,12 @@ class NetworkPlotter:
451
578
  label_words_to_omit=words_to_omit,
452
579
  )
453
580
 
454
- # Convert color strings to RGBA arrays if necessary
455
- if isinstance(fontcolor, str):
456
- fontcolor = self.get_annotated_label_colors(color=fontcolor)
457
- if isinstance(arrow_color, str):
458
- arrow_color = self.get_annotated_label_colors(color=arrow_color)
581
+ # Convert colors to RGBA using the _to_rgba helper function, applying alpha separately for font and arrow
582
+ fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes))
583
+ arrow_color = _to_rgba(
584
+ arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_to_nodes)
585
+ )
586
+
459
587
  # Normalize words_to_omit to lowercase
460
588
  if words_to_omit:
461
589
  words_to_omit = set(word.lower() for word in words_to_omit)
@@ -509,9 +637,7 @@ class NetworkPlotter:
509
637
  valid_indices = [valid_indices[i] for i in selected_indices]
510
638
 
511
639
  # Calculate the bounding box around the network
512
- center, radius = _calculate_bounding_box(
513
- self.graph.node_coordinates, radius_margin=perimeter_scale
514
- )
640
+ center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
515
641
  # Calculate the best positions for labels around the perimeter
516
642
  best_label_positions = _calculate_best_label_positions(
517
643
  filtered_domain_centroids, center, radius, offset
@@ -535,16 +661,18 @@ class NetworkPlotter:
535
661
 
536
662
  def plot_sublabel(
537
663
  self,
538
- nodes: list,
664
+ nodes: List,
539
665
  label: str,
540
666
  radial_position: float = 0.0,
541
- perimeter_scale: float = 1.05,
667
+ scale: float = 1.05,
542
668
  offset: float = 0.10,
543
669
  font: str = "Arial",
544
670
  fontsize: int = 10,
545
- fontcolor: str = "black",
671
+ fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
672
+ fontalpha: float = 1.0,
546
673
  arrow_linewidth: float = 1,
547
- arrow_color: str = "black",
674
+ arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
675
+ arrow_alpha: float = 1.0,
548
676
  ) -> None:
549
677
  """Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
550
678
 
@@ -552,26 +680,17 @@ class NetworkPlotter:
552
680
  nodes (List[str]): List of node labels to be used for calculating the centroid.
553
681
  label (str): The label to be annotated on the network.
554
682
  radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
555
- perimeter_scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
683
+ scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
556
684
  offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
557
685
  font (str, optional): Font name for the label. Defaults to "Arial".
558
686
  fontsize (int, optional): Font size for the label. Defaults to 10.
559
- fontcolor (str, optional): Color of the label text. Defaults to "black".
687
+ fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
688
+ fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
560
689
  arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
561
- arrow_color (str, optional): Color of the arrow. Defaults to "black".
690
+ arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
691
+ arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
562
692
  """
563
- # Log the plotting parameters
564
- params.log_plotter(
565
- sublabel_perimeter_scale=perimeter_scale,
566
- sublabel_offset=offset,
567
- sublabel_font=font,
568
- sublabel_fontsize=fontsize,
569
- sublabel_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
570
- sublabel_arrow_linewidth=arrow_linewidth,
571
- sublabel_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
572
- sublabel_radial_position=radial_position,
573
- )
574
-
693
+ # Don't log sublabel parameters as they are specific to individual annotations
575
694
  # Map node labels to IDs
576
695
  node_ids = [
577
696
  self.graph.node_label_to_id_map.get(node)
@@ -581,12 +700,13 @@ class NetworkPlotter:
581
700
  if not node_ids or len(node_ids) == 1:
582
701
  raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
583
702
 
703
+ # Convert fontcolor and arrow_color to RGBA using the _to_rgba helper function
704
+ fontcolor = _to_rgba(fontcolor, fontalpha)
705
+ arrow_color = _to_rgba(arrow_color, arrow_alpha)
584
706
  # Calculate the centroid of the provided nodes
585
707
  centroid = self._calculate_domain_centroid(node_ids)
586
708
  # Calculate the bounding box around the network
587
- center, radius = _calculate_bounding_box(
588
- self.graph.node_coordinates, radius_margin=perimeter_scale
589
- )
709
+ center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
590
710
  # Convert radial position to radians, adjusting for a 90-degree rotation
591
711
  radial_radians = np.deg2rad(radial_position - 90)
592
712
  label_position = (
@@ -608,7 +728,7 @@ class NetworkPlotter:
608
728
  arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
609
729
  )
610
730
 
611
- def _calculate_domain_centroid(self, nodes: list) -> tuple:
731
+ def _calculate_domain_centroid(self, nodes: List) -> tuple:
612
732
  """Calculate the most centrally located node in .
613
733
 
614
734
  Args:
@@ -631,44 +751,51 @@ class NetworkPlotter:
631
751
 
632
752
  def get_annotated_node_colors(
633
753
  self,
754
+ cmap: str = "gist_rainbow",
755
+ color: Union[str, None] = None,
756
+ min_scale: float = 0.8,
757
+ max_scale: float = 1.0,
758
+ scale_factor: float = 1.0,
759
+ alpha: float = 1.0,
634
760
  nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
761
+ nonenriched_alpha: float = 1.0,
635
762
  random_seed: int = 888,
636
- **kwargs,
637
763
  ) -> np.ndarray:
638
764
  """Adjust the colors of nodes in the network graph based on enrichment.
639
765
 
640
766
  Args:
641
- nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes.
642
- Defaults to "white". If an alpha channel is provided, it will be used to darken the RGB values.
767
+ cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
768
+ color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
769
+ min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
770
+ max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
771
+ scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
772
+ alpha (float, optional): Alpha value for enriched nodes. Defaults to 1.0.
773
+ nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
774
+ nonenriched_alpha (float, optional): Alpha value for non-enriched nodes. Defaults to 1.0.
643
775
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
644
- **kwargs: Additional keyword arguments for `get_domain_colors`.
645
776
 
646
777
  Returns:
647
778
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
648
779
  """
649
- # Get the initial domain colors for each node
650
- network_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
651
- if isinstance(nonenriched_color, str):
652
- # Convert the non-enriched color from string to RGBA
653
- nonenriched_color = mcolors.to_rgba(nonenriched_color)
654
-
655
- # Ensure nonenriched_color is a NumPy array
656
- nonenriched_color = np.array(nonenriched_color)
657
- # If alpha is provided (4th value), darken the RGB values based on alpha
658
- if len(nonenriched_color) == 4 and nonenriched_color[3] < 1.0:
659
- alpha = nonenriched_color[3]
660
- # Adjust RGB based on alpha (darken)
661
- nonenriched_color[:3] = nonenriched_color[:3] * alpha
662
- # Set alpha to 1.0 after darkening the color
663
- nonenriched_color[3] = 1.0
664
-
665
- # Adjust node colors: replace any fully transparent nodes (enriched) with the non-enriched color
780
+ # Get the initial domain colors for each node, which are returned as RGBA
781
+ network_colors = self.graph.get_domain_colors(
782
+ cmap=cmap,
783
+ color=color,
784
+ min_scale=min_scale,
785
+ max_scale=max_scale,
786
+ scale_factor=scale_factor,
787
+ random_seed=random_seed,
788
+ )
789
+ # Apply the alpha value for enriched nodes
790
+ network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
791
+ # Convert the non-enriched color to RGBA using the _to_rgba helper function
792
+ nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
793
+ # Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
666
794
  adjusted_network_colors = np.where(
667
795
  np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
668
- np.array([nonenriched_color]), # Apply the non-enriched color (with adjusted alpha)
669
- network_colors, # Keep the original colors
796
+ np.array([nonenriched_color]), # Apply the non-enriched color with alpha
797
+ network_colors, # Keep the original colors for enriched nodes
670
798
  )
671
-
672
799
  return adjusted_network_colors
673
800
 
674
801
  def get_annotated_node_sizes(
@@ -697,59 +824,119 @@ class NetworkPlotter:
697
824
 
698
825
  return node_sizes
699
826
 
700
- def get_annotated_contour_colors(self, random_seed: int = 888, **kwargs) -> np.ndarray:
701
- """Get colors for the contours based on node annotations.
827
+ def get_annotated_contour_colors(
828
+ self,
829
+ cmap: str = "gist_rainbow",
830
+ color: Union[str, None] = None,
831
+ min_scale: float = 0.8,
832
+ max_scale: float = 1.0,
833
+ scale_factor: float = 1.0,
834
+ random_seed: int = 888,
835
+ ) -> np.ndarray:
836
+ """Get colors for the contours based on node annotations or a specified colormap.
702
837
 
703
838
  Args:
704
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
705
- **kwargs: Additional keyword arguments for `_get_annotated_domain_colors`.
839
+ cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
840
+ color (str or None, optional): Color to use for the contours. If None, the colormap will be used. Defaults to None.
841
+ min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
842
+ Controls the dimmest colors. Defaults to 0.8.
843
+ max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
844
+ Controls the brightest colors. Defaults to 1.0.
845
+ scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
846
+ A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
847
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
706
848
 
707
849
  Returns:
708
850
  np.ndarray: Array of RGBA colors for contour annotations.
709
851
  """
710
- return self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
852
+ return self._get_annotated_domain_colors(
853
+ cmap=cmap,
854
+ color=color,
855
+ min_scale=min_scale,
856
+ max_scale=max_scale,
857
+ scale_factor=scale_factor,
858
+ random_seed=random_seed,
859
+ )
711
860
 
712
- def get_annotated_label_colors(self, random_seed: int = 888, **kwargs) -> np.ndarray:
713
- """Get colors for the labels based on node annotations.
861
+ def get_annotated_label_colors(
862
+ self,
863
+ cmap: str = "gist_rainbow",
864
+ color: Union[str, None] = None,
865
+ min_scale: float = 0.8,
866
+ max_scale: float = 1.0,
867
+ scale_factor: float = 1.0,
868
+ random_seed: int = 888,
869
+ ) -> np.ndarray:
870
+ """Get colors for the labels based on node annotations or a specified colormap.
714
871
 
715
872
  Args:
716
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
717
- **kwargs: Additional keyword arguments for `_get_annotated_domain_colors`.
873
+ cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
874
+ color (str or None, optional): Color to use for the labels. If None, the colormap will be used. Defaults to None.
875
+ min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
876
+ Controls the dimmest colors. Defaults to 0.8.
877
+ max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
878
+ Controls the brightest colors. Defaults to 1.0.
879
+ scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
880
+ A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
881
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
718
882
 
719
883
  Returns:
720
884
  np.ndarray: Array of RGBA colors for label annotations.
721
885
  """
722
- return self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
886
+ return self._get_annotated_domain_colors(
887
+ cmap=cmap,
888
+ color=color,
889
+ min_scale=min_scale,
890
+ max_scale=max_scale,
891
+ scale_factor=scale_factor,
892
+ random_seed=random_seed,
893
+ )
723
894
 
724
895
  def _get_annotated_domain_colors(
725
- self, color: Union[str, list, None] = None, random_seed: int = 888, **kwargs
896
+ self,
897
+ cmap: str = "gist_rainbow",
898
+ color: Union[str, None] = None,
899
+ min_scale: float = 0.8,
900
+ max_scale: float = 1.0,
901
+ scale_factor: float = 1.0,
902
+ random_seed: int = 888,
726
903
  ) -> np.ndarray:
727
- """Get colors for the domains based on node annotations.
904
+ """Get colors for the domains based on node annotations, or use a specified color.
728
905
 
729
906
  Args:
730
- color (str, list, or None, optional): If provided, use this color or list of colors for domains. Defaults to None.
731
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
732
- **kwargs: Additional keyword arguments for `get_domain_colors`.
907
+ cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
908
+ color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
909
+ min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
910
+ Defaults to 0.8.
911
+ max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
912
+ Defaults to 1.0.
913
+ scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
914
+ enrichment. Higher values increase the contrast. Defaults to 1.0.
915
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
733
916
 
734
917
  Returns:
735
918
  np.ndarray: Array of RGBA colors for each domain.
736
919
  """
737
- if isinstance(color, str):
738
- # If a single color string is provided, convert it to RGBA and apply to all domains
739
- rgba_color = np.array(matplotlib.colors.to_rgba(color))
740
- return np.array([rgba_color for _ in self.graph.domain_to_nodes])
741
-
742
- # Generate colors for each domain using the provided arguments and random seed
743
- node_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
920
+ # Generate domain colors based on the enrichment data
921
+ node_colors = self.graph.get_domain_colors(
922
+ cmap=cmap,
923
+ color=color,
924
+ min_scale=min_scale,
925
+ max_scale=max_scale,
926
+ scale_factor=scale_factor,
927
+ random_seed=random_seed,
928
+ )
744
929
  annotated_colors = []
745
930
  for _, nodes in self.graph.domain_to_nodes.items():
746
931
  if len(nodes) > 1:
747
- # For domains with multiple nodes, choose the brightest color (sum of RGB values)
932
+ # For multi-node domains, choose the brightest color based on RGB sum
748
933
  domain_colors = np.array([node_colors[node] for node in nodes])
749
- brightest_color = domain_colors[np.argmax(domain_colors.sum(axis=1))]
934
+ brightest_color = domain_colors[
935
+ np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
936
+ ]
750
937
  annotated_colors.append(brightest_color)
751
938
  else:
752
- # Assign a default color (white) for single-node domains
939
+ # Single-node domains default to white (RGBA)
753
940
  default_color = np.array([1.0, 1.0, 1.0, 1.0])
754
941
  annotated_colors.append(default_color)
755
942
 
@@ -776,6 +963,47 @@ class NetworkPlotter:
776
963
  plt.show(*args, **kwargs)
777
964
 
778
965
 
966
+ def _to_rgba(
967
+ color: Union[str, List, Tuple, np.ndarray],
968
+ alpha: float = 1.0,
969
+ num_repeats: Union[int, None] = None,
970
+ ) -> np.ndarray:
971
+ """Convert a color or array of colors to RGBA format, applying or updating the alpha as needed.
972
+
973
+ Args:
974
+ color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
975
+ alpha (float, optional): Alpha value (transparency) to apply. Defaults to 1.0.
976
+ num_repeats (int or None, optional): If provided, the color will be repeated this many times. Defaults to None.
977
+
978
+ Returns:
979
+ np.ndarray: The RGBA color or array of RGBA colors.
980
+ """
981
+ # Handle single color case
982
+ if isinstance(color, str) or (
983
+ isinstance(color, (list, tuple, np.ndarray)) and len(color) in [3, 4]
984
+ ):
985
+ rgba_color = np.array(mcolors.to_rgba(color, alpha))
986
+ # Repeat the color if repeat argument is provided
987
+ if num_repeats is not None:
988
+ return np.array([rgba_color] * num_repeats)
989
+
990
+ return rgba_color
991
+
992
+ # Handle array of colors case
993
+ elif isinstance(color, (list, tuple, np.ndarray)) and isinstance(
994
+ color[0], (list, tuple, np.ndarray)
995
+ ):
996
+ rgba_colors = [mcolors.to_rgba(c, alpha) if len(c) == 3 else np.array(c) for c in color]
997
+ # Repeat the colors if repeat argument is provided
998
+ if num_repeats is not None and len(rgba_colors) == 1:
999
+ return np.array([rgba_colors[0]] * num_repeats)
1000
+
1001
+ return np.array(rgba_colors)
1002
+
1003
+ else:
1004
+ raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
1005
+
1006
+
779
1007
  def _is_connected(z: np.ndarray) -> bool:
780
1008
  """Determine if a thresholded grid represents a single, connected component.
781
1009