risk-network 0.0.5b6__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/network/plot.py CHANGED
@@ -5,7 +5,6 @@ risk/network/plot
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
7
7
 
8
- import matplotlib
9
8
  import matplotlib.colors as mcolors
10
9
  import matplotlib.pyplot as plt
11
10
  import networkx as nx
@@ -18,66 +17,42 @@ from risk.network.graph import NetworkGraph
18
17
 
19
18
 
20
19
  class NetworkPlotter:
21
- """A class responsible for visualizing network graphs with various customization options.
20
+ """A class for visualizing network graphs with customizable options.
22
21
 
23
- The NetworkPlotter class takes in a NetworkGraph object, which contains the network's data and attributes,
24
- and provides methods for plotting the network with customizable node and edge properties,
25
- as well as optional features like drawing the network's perimeter and setting background colors.
22
+ The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
23
+ flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
24
+ perimeter, and adjusting background colors.
26
25
  """
27
26
 
28
27
  def __init__(
29
28
  self,
30
29
  graph: NetworkGraph,
31
- figsize: tuple = (10, 10),
32
- background_color: str = "white",
33
- plot_outline: bool = True,
34
- outline_color: str = "black",
35
- outline_scale: float = 1.0,
36
- linestyle: str = "dashed",
30
+ figsize: Tuple = (10, 10),
31
+ background_color: Union[str, List, Tuple, np.ndarray] = "white",
37
32
  ) -> None:
38
33
  """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
39
34
 
40
35
  Args:
41
36
  graph (NetworkGraph): The network data and attributes to be visualized.
42
37
  figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
43
- background_color (str, optional): Background color of the plot. Defaults to "white".
44
- plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
45
- outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
46
- outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
47
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
38
+ background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
48
39
  """
49
40
  self.graph = graph
50
41
  # Initialize the plot with the specified parameters
51
- self.ax = self._initialize_plot(
52
- graph,
53
- figsize,
54
- background_color,
55
- plot_outline,
56
- outline_color,
57
- outline_scale,
58
- linestyle,
59
- )
42
+ self.ax = self._initialize_plot(graph, figsize, background_color)
60
43
 
61
44
  def _initialize_plot(
62
45
  self,
63
46
  graph: NetworkGraph,
64
- figsize: tuple,
65
- background_color: str,
66
- plot_outline: bool,
67
- outline_color: str,
68
- outline_scale: float,
69
- linestyle: str,
47
+ figsize: Tuple,
48
+ background_color: Union[str, List, Tuple, np.ndarray],
70
49
  ) -> plt.Axes:
71
- """Set up the plot with figure size, optional circle perimeter, and background color.
50
+ """Set up the plot with figure size and background color.
72
51
 
73
52
  Args:
74
53
  graph (NetworkGraph): The network data and attributes to be visualized.
75
54
  figsize (tuple): Size of the figure in inches (width, height).
76
55
  background_color (str): Background color of the plot.
77
- plot_outline (bool): Whether to plot the network perimeter circle.
78
- outline_color (str): Color of the network perimeter circle.
79
- outline_scale (float): Outline scaling factor for the perimeter diameter.
80
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
81
56
 
82
57
  Returns:
83
58
  plt.Axes: The axis object for the plot.
@@ -86,31 +61,19 @@ class NetworkPlotter:
86
61
  node_coordinates = graph.node_coordinates
87
62
  # Calculate the center and radius of the bounding box around the network
88
63
  center, radius = _calculate_bounding_box(node_coordinates)
89
- # Scale the radius by the outline_scale factor
90
- scaled_radius = radius * outline_scale
91
64
 
92
65
  # Create a new figure and axis for plotting
93
66
  fig, ax = plt.subplots(figsize=figsize)
94
67
  fig.tight_layout() # Adjust subplot parameters to give specified padding
95
- if plot_outline:
96
- # Draw a circle to represent the network perimeter
97
- circle = plt.Circle(
98
- center,
99
- scaled_radius,
100
- linestyle=linestyle, # Use the linestyle argument here
101
- color=outline_color,
102
- fill=False,
103
- linewidth=1.5,
104
- )
105
- ax.add_artist(circle) # Add the circle to the plot
106
-
107
- # Set axis limits based on the calculated bounding box and scaled radius
108
- ax.set_xlim([center[0] - scaled_radius - 0.3, center[0] + scaled_radius + 0.3])
109
- ax.set_ylim([center[1] - scaled_radius - 0.3, center[1] + scaled_radius + 0.3])
68
+ # Set axis limits based on the calculated bounding box and radius
69
+ ax.set_xlim([center[0] - radius - 0.3, center[0] + radius + 0.3])
70
+ ax.set_ylim([center[1] - radius - 0.3, center[1] + radius + 0.3])
110
71
  ax.set_aspect("equal") # Ensure the aspect ratio is equal
111
- fig.patch.set_facecolor(background_color) # Set the figure background color
112
- ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
113
72
 
73
+ # Set the background color of the plot
74
+ # Convert color to RGBA using the _to_rgba helper function
75
+ fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
76
+ ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
114
77
  # Remove axis spines for a cleaner look
115
78
  for spine in ax.spines.values():
116
79
  spine.set_visible(False)
@@ -122,45 +85,182 @@ class NetworkPlotter:
122
85
 
123
86
  return ax
124
87
 
88
+ def plot_circle_perimeter(
89
+ self,
90
+ scale: float = 1.0,
91
+ linestyle: str = "dashed",
92
+ linewidth: float = 1.5,
93
+ color: Union[str, List, Tuple, np.ndarray] = "black",
94
+ outline_alpha: float = 1.0,
95
+ fill_alpha: float = 0.0,
96
+ ) -> None:
97
+ """Plot a circle around the network graph to represent the network perimeter.
98
+
99
+ Args:
100
+ scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
101
+ linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
102
+ linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
103
+ color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
104
+ outline_alpha (float, optional): Transparency level of the circle outline. Defaults to 1.0.
105
+ fill_alpha (float, optional): Transparency level of the circle fill. Defaults to 0.0.
106
+ """
107
+ # Log the circle perimeter plotting parameters
108
+ params.log_plotter(
109
+ perimeter_type="circle",
110
+ perimeter_scale=scale,
111
+ perimeter_linestyle=linestyle,
112
+ perimeter_linewidth=linewidth,
113
+ perimeter_color=(
114
+ "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
115
+ ), # np.ndarray usually indicates custom colors
116
+ perimeter_outline_alpha=outline_alpha,
117
+ perimeter_fill_alpha=fill_alpha,
118
+ )
119
+
120
+ # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
121
+ color = _to_rgba(color, outline_alpha)
122
+ # Extract node coordinates from the network graph
123
+ node_coordinates = self.graph.node_coordinates
124
+ # Calculate the center and radius of the bounding box around the network
125
+ center, radius = _calculate_bounding_box(node_coordinates)
126
+ # Scale the radius by the scale factor
127
+ scaled_radius = radius * scale
128
+
129
+ # Draw a circle to represent the network perimeter
130
+ circle = plt.Circle(
131
+ center,
132
+ scaled_radius,
133
+ linestyle=linestyle,
134
+ linewidth=linewidth,
135
+ color=color,
136
+ fill=fill_alpha > 0, # Fill the circle if fill_alpha is greater than 0
137
+ )
138
+ # Set the transparency of the fill if applicable
139
+ if fill_alpha > 0:
140
+ circle.set_facecolor(_to_rgba(color, fill_alpha))
141
+
142
+ self.ax.add_artist(circle)
143
+
144
+ def plot_contour_perimeter(
145
+ self,
146
+ scale: float = 1.0,
147
+ levels: int = 3,
148
+ bandwidth: float = 0.8,
149
+ grid_size: int = 250,
150
+ color: Union[str, List, Tuple, np.ndarray] = "black",
151
+ linestyle: str = "solid",
152
+ linewidth: float = 1.5,
153
+ outline_alpha: float = 1.0,
154
+ fill_alpha: float = 0.0,
155
+ ) -> None:
156
+ """
157
+ Plot a KDE-based contour around the network graph to represent the network perimeter.
158
+
159
+ Args:
160
+ scale (float, optional): Scaling factor for the perimeter size. Defaults to 1.0.
161
+ levels (int, optional): Number of contour levels. Defaults to 3.
162
+ bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
163
+ grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
164
+ color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
165
+ linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
166
+ linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
167
+ outline_alpha (float, optional): Transparency level of the contour outline. Defaults to 1.0.
168
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.0.
169
+ """
170
+ # Log the contour perimeter plotting parameters
171
+ params.log_plotter(
172
+ perimeter_type="contour",
173
+ perimeter_scale=scale,
174
+ perimeter_levels=levels,
175
+ perimeter_bandwidth=bandwidth,
176
+ perimeter_grid_size=grid_size,
177
+ perimeter_linestyle=linestyle,
178
+ perimeter_linewidth=linewidth,
179
+ perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
180
+ perimeter_outline_alpha=outline_alpha,
181
+ perimeter_fill_alpha=fill_alpha,
182
+ )
183
+
184
+ # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
185
+ color = _to_rgba(color, outline_alpha)
186
+ # Extract node coordinates from the network graph
187
+ node_coordinates = self.graph.node_coordinates
188
+ # Scale the node coordinates if needed
189
+ scaled_coordinates = node_coordinates * scale
190
+ # Use the existing _draw_kde_contour method
191
+ self._draw_kde_contour(
192
+ ax=self.ax,
193
+ pos=scaled_coordinates,
194
+ nodes=list(range(len(node_coordinates))), # All nodes are included
195
+ levels=levels,
196
+ bandwidth=bandwidth,
197
+ grid_size=grid_size,
198
+ color=color,
199
+ linestyle=linestyle,
200
+ linewidth=linewidth,
201
+ alpha=fill_alpha,
202
+ )
203
+
125
204
  def plot_network(
126
205
  self,
127
206
  node_size: Union[int, np.ndarray] = 50,
128
- edge_width: float = 1.0,
129
- node_color: Union[str, np.ndarray] = "white",
130
- node_edgecolor: str = "black",
131
- edge_color: str = "black",
132
207
  node_shape: str = "o",
208
+ node_edgewidth: float = 1.0,
209
+ edge_width: float = 1.0,
210
+ node_color: Union[str, List, Tuple, np.ndarray] = "white",
211
+ node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
212
+ edge_color: Union[str, List, Tuple, np.ndarray] = "black",
213
+ node_alpha: float = 1.0,
214
+ edge_alpha: float = 1.0,
133
215
  ) -> None:
134
- """Plot the network graph with customizable node colors, sizes, and edge widths.
216
+ """Plot the network graph with customizable node colors, sizes, edge widths, and node edge widths.
135
217
 
136
218
  Args:
137
219
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
138
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
139
- node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
140
- node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
141
- edge_color (str, optional): Color of the edges. Defaults to "black".
142
220
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
221
+ node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
222
+ edge_width (float, optional): Width of the edges. Defaults to 1.0.
223
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
224
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
225
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
226
+ node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0. Annotated node_color alphas will override this value.
227
+ edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
143
228
  """
144
229
  # Log the plotting parameters
145
230
  params.log_plotter(
146
- network_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
231
+ network_node_size=(
232
+ "custom" if isinstance(node_size, np.ndarray) else node_size
233
+ ), # np.ndarray usually indicates custom sizes
234
+ network_node_shape=node_shape,
235
+ network_node_edgewidth=node_edgewidth,
147
236
  network_edge_width=edge_width,
148
- network_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
237
+ network_node_color=(
238
+ "custom" if isinstance(node_color, np.ndarray) else node_color
239
+ ), # np.ndarray usually indicates custom colors
149
240
  network_node_edgecolor=node_edgecolor,
150
241
  network_edge_color=edge_color,
151
- network_node_shape=node_shape,
242
+ network_node_alpha=node_alpha,
243
+ network_edge_alpha=edge_alpha,
152
244
  )
245
+
246
+ # Convert colors to RGBA using the _to_rgba helper function
247
+ # If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
248
+ node_color = _to_rgba(node_color, node_alpha, num_repeats=len(self.graph.network.nodes))
249
+ node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(self.graph.network.nodes))
250
+ edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
251
+
153
252
  # Extract node coordinates from the network graph
154
253
  node_coordinates = self.graph.node_coordinates
254
+
155
255
  # Draw the nodes of the graph
156
256
  nx.draw_networkx_nodes(
157
257
  self.graph.network,
158
258
  pos=node_coordinates,
159
259
  node_size=node_size,
160
- node_color=node_color,
161
260
  node_shape=node_shape,
162
- alpha=1.00,
261
+ node_color=node_color,
163
262
  edgecolors=node_edgecolor,
263
+ linewidths=node_edgewidth,
164
264
  ax=self.ax,
165
265
  )
166
266
  # Draw the edges of the graph
@@ -174,58 +274,73 @@ class NetworkPlotter:
174
274
 
175
275
  def plot_subnetwork(
176
276
  self,
177
- nodes: list,
277
+ nodes: Union[List, Tuple, np.ndarray],
178
278
  node_size: Union[int, np.ndarray] = 50,
179
- edge_width: float = 1.0,
180
- node_color: Union[str, np.ndarray] = "white",
181
- node_edgecolor: str = "black",
182
- edge_color: str = "black",
183
279
  node_shape: str = "o",
280
+ node_edgewidth: float = 1.0,
281
+ edge_width: float = 1.0,
282
+ node_color: Union[str, List, Tuple, np.ndarray] = "white",
283
+ node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
284
+ edge_color: Union[str, List, Tuple, np.ndarray] = "black",
285
+ node_alpha: float = 1.0,
286
+ edge_alpha: float = 1.0,
184
287
  ) -> None:
185
288
  """Plot a subnetwork of selected nodes with customizable node and edge attributes.
186
289
 
187
290
  Args:
188
- nodes (list): List of node labels to include in the subnetwork.
291
+ nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
189
292
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
190
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
191
- node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
192
- node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
193
- edge_color (str, optional): Color of the edges. Defaults to "black".
194
293
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
294
+ node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
295
+ edge_width (float, optional): Width of the edges. Defaults to 1.0.
296
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
297
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
298
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
299
+ node_alpha (float, optional): Transparency for the nodes. Defaults to 1.0.
300
+ edge_alpha (float, optional): Transparency for the edges. Defaults to 1.0.
195
301
 
196
302
  Raises:
197
303
  ValueError: If no valid nodes are found in the network graph.
198
304
  """
199
- # Log the plotting parameters for the subnetwork
200
- params.log_plotter(
201
- subnetwork_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
202
- subnetwork_edge_width=edge_width,
203
- subnetwork_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
204
- subnetwork_node_edgecolor=node_edgecolor,
205
- subnetwork_edge_color=edge_color,
206
- subnet_node_shape=node_shape,
207
- )
305
+ # Flatten nested lists of nodes, if necessary
306
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
307
+ nodes = [node for sublist in nodes for node in sublist]
308
+
208
309
  # Filter to get node IDs and their coordinates
209
310
  node_ids = [
210
- self.graph.node_label_to_id_map.get(node)
311
+ self.graph.node_label_to_node_id_map.get(node)
211
312
  for node in nodes
212
- if node in self.graph.node_label_to_id_map
313
+ if node in self.graph.node_label_to_node_id_map
213
314
  ]
214
315
  if not node_ids:
215
316
  raise ValueError("No nodes found in the network graph.")
216
317
 
318
+ # Check if node_color is a single color or a list of colors
319
+ if not isinstance(node_color, (str, tuple, np.ndarray)):
320
+ node_color = [
321
+ node_color[nodes.index(node)]
322
+ for node in nodes
323
+ if node in self.graph.node_label_to_node_id_map
324
+ ]
325
+
326
+ # Convert colors to RGBA using the _to_rgba helper function
327
+ node_color = _to_rgba(node_color, node_alpha, num_repeats=len(node_ids))
328
+ node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(node_ids))
329
+ edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
330
+
217
331
  # Get the coordinates of the filtered nodes
218
332
  node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
333
+
219
334
  # Draw the nodes in the subnetwork
220
335
  nx.draw_networkx_nodes(
221
336
  self.graph.network,
222
337
  pos=node_coordinates,
223
338
  nodelist=node_ids,
224
339
  node_size=node_size,
225
- node_color=node_color,
226
340
  node_shape=node_shape,
227
- alpha=1.00,
341
+ node_color=node_color,
228
342
  edgecolors=node_edgecolor,
343
+ linewidths=node_edgewidth,
229
344
  ax=self.ax,
230
345
  )
231
346
  # Draw the edges between the specified nodes in the subnetwork
@@ -243,8 +358,11 @@ class NetworkPlotter:
243
358
  levels: int = 5,
244
359
  bandwidth: float = 0.8,
245
360
  grid_size: int = 250,
246
- alpha: float = 0.2,
247
- color: Union[str, np.ndarray] = "white",
361
+ color: Union[str, List, Tuple, np.ndarray] = "white",
362
+ linestyle: str = "solid",
363
+ linewidth: float = 1.5,
364
+ alpha: float = 1.0,
365
+ fill_alpha: float = 0.2,
248
366
  ) -> None:
249
367
  """Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.
250
368
 
@@ -252,99 +370,126 @@ class NetworkPlotter:
252
370
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
253
371
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
254
372
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
255
- alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
256
- color (str or np.ndarray, optional): Color of the contours. Can be a string (e.g., 'white') or an array of colors. Defaults to "white".
373
+ color (str, list, tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors. Defaults to "white".
374
+ linestyle (str, optional): Line style for the contours. Defaults to "solid".
375
+ linewidth (float, optional): Line width for the contours. Defaults to 1.5.
376
+ alpha (float, optional): Transparency level of the contour lines. Defaults to 1.0.
377
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
257
378
  """
258
379
  # Log the contour plotting parameters
259
380
  params.log_plotter(
260
381
  contour_levels=levels,
261
382
  contour_bandwidth=bandwidth,
262
383
  contour_grid_size=grid_size,
384
+ contour_color=(
385
+ "custom" if isinstance(color, np.ndarray) else color
386
+ ), # np.ndarray usually indicates custom colors
263
387
  contour_alpha=alpha,
264
- contour_color="custom" if isinstance(color, np.ndarray) else color,
388
+ contour_fill_alpha=fill_alpha,
265
389
  )
266
- # Convert color string to RGBA array if necessary
267
- if isinstance(color, str):
268
- color = self.get_annotated_contour_colors(color=color)
269
390
 
391
+ # Ensure color is converted to RGBA with repetition matching the number of domains
392
+ color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map))
270
393
  # Extract node coordinates from the network graph
271
394
  node_coordinates = self.graph.node_coordinates
272
395
  # Draw contours for each domain in the network
273
- for idx, (_, nodes) in enumerate(self.graph.domain_to_nodes.items()):
274
- if len(nodes) > 1:
396
+ for idx, (_, node_ids) in enumerate(self.graph.domain_id_to_node_ids_map.items()):
397
+ if len(node_ids) > 1:
275
398
  self._draw_kde_contour(
276
399
  self.ax,
277
400
  node_coordinates,
278
- nodes,
401
+ node_ids,
279
402
  color=color[idx],
280
403
  levels=levels,
281
404
  bandwidth=bandwidth,
282
405
  grid_size=grid_size,
406
+ linestyle=linestyle,
407
+ linewidth=linewidth,
283
408
  alpha=alpha,
409
+ fill_alpha=fill_alpha,
284
410
  )
285
411
 
286
412
  def plot_subcontour(
287
413
  self,
288
- nodes: list,
414
+ nodes: Union[List, Tuple, np.ndarray],
289
415
  levels: int = 5,
290
416
  bandwidth: float = 0.8,
291
417
  grid_size: int = 250,
292
- alpha: float = 0.2,
293
- color: Union[str, np.ndarray] = "white",
418
+ color: Union[str, List, Tuple, np.ndarray] = "white",
419
+ linestyle: str = "solid",
420
+ linewidth: float = 1.5,
421
+ alpha: float = 1.0,
422
+ fill_alpha: float = 0.2,
294
423
  ) -> None:
295
- """Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
424
+ """Plot a subcontour for a given set of nodes or a list of node sets using Kernel Density Estimation (KDE).
296
425
 
297
426
  Args:
298
- nodes (list): List of node labels to plot the contour for.
427
+ nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels to plot the contour for.
299
428
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
300
429
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
301
430
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
302
- alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
303
- color (str or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
431
+ color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
432
+ linestyle (str, optional): Line style for the contour. Defaults to "solid".
433
+ linewidth (float, optional): Line width for the contour. Defaults to 1.5.
434
+ alpha (float, optional): Transparency level of the contour lines. Defaults to 1.0.
435
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
304
436
 
305
437
  Raises:
306
438
  ValueError: If no valid nodes are found in the network graph.
307
439
  """
308
- # Log the plotting parameters
309
- params.log_plotter(
310
- subcontour_levels=levels,
311
- subcontour_bandwidth=bandwidth,
312
- subcontour_grid_size=grid_size,
313
- subcontour_alpha=alpha,
314
- subcontour_color="custom" if isinstance(color, np.ndarray) else color,
315
- )
316
- # Filter to get node IDs and their coordinates
317
- node_ids = [
318
- self.graph.node_label_to_id_map.get(node)
319
- for node in nodes
320
- if node in self.graph.node_label_to_id_map
321
- ]
322
- if not node_ids or len(node_ids) == 1:
323
- raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
440
+ # Check if nodes is a list of lists or a flat list
441
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
442
+ # If it's a list of lists, iterate over sublists
443
+ node_groups = nodes
444
+ else:
445
+ # If it's a flat list of nodes, treat it as a single group
446
+ node_groups = [nodes]
447
+
448
+ # Convert color to RGBA using the _to_rgba helper function
449
+ color_rgba = _to_rgba(color, alpha)
450
+
451
+ # Iterate over each group of nodes (either sublists or flat list)
452
+ for sublist in node_groups:
453
+ # Filter to get node IDs and their coordinates for each sublist
454
+ node_ids = [
455
+ self.graph.node_label_to_node_id_map.get(node)
456
+ for node in sublist
457
+ if node in self.graph.node_label_to_node_id_map
458
+ ]
459
+ if not node_ids or len(node_ids) == 1:
460
+ raise ValueError(
461
+ "No nodes found in the network graph or insufficient nodes to plot."
462
+ )
324
463
 
325
- # Draw the KDE contour for the specified nodes
326
- node_coordinates = self.graph.node_coordinates
327
- self._draw_kde_contour(
328
- self.ax,
329
- node_coordinates,
330
- node_ids,
331
- color=color,
332
- levels=levels,
333
- bandwidth=bandwidth,
334
- grid_size=grid_size,
335
- alpha=alpha,
336
- )
464
+ # Draw the KDE contour for the specified nodes
465
+ node_coordinates = self.graph.node_coordinates
466
+ self._draw_kde_contour(
467
+ self.ax,
468
+ node_coordinates,
469
+ node_ids,
470
+ color=color_rgba,
471
+ levels=levels,
472
+ bandwidth=bandwidth,
473
+ grid_size=grid_size,
474
+ linestyle=linestyle,
475
+ linewidth=linewidth,
476
+ alpha=alpha,
477
+ fill_alpha=fill_alpha,
478
+ )
337
479
 
338
480
  def _draw_kde_contour(
339
481
  self,
340
482
  ax: plt.Axes,
341
483
  pos: np.ndarray,
342
- nodes: list,
343
- color: Union[str, np.ndarray],
484
+ nodes: List,
344
485
  levels: int = 5,
345
486
  bandwidth: float = 0.8,
346
487
  grid_size: int = 250,
347
- alpha: float = 0.5,
488
+ color: Union[str, np.ndarray] = "white",
489
+ linestyle: str = "solid",
490
+ linewidth: float = 1.5,
491
+ alpha: float = 1.0,
492
+ fill_alpha: float = 0.2,
348
493
  ) -> None:
349
494
  """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
350
495
 
@@ -352,11 +497,14 @@ class NetworkPlotter:
352
497
  ax (plt.Axes): The axis to draw the contour on.
353
498
  pos (np.ndarray): Array of node positions (x, y).
354
499
  nodes (list): List of node indices to include in the contour.
355
- color (str or np.ndarray): Color for the contour.
356
500
  levels (int, optional): Number of contour levels. Defaults to 5.
357
501
  bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
358
502
  grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
359
- alpha (float, optional): Transparency level for the contour fill. Defaults to 0.5.
503
+ color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
504
+ linestyle (str, optional): Line style for the contour. Defaults to "solid".
505
+ linewidth (float, optional): Line width for the contour. Defaults to 1.5.
506
+ alpha (float, optional): Transparency level for the contour lines. Defaults to 1.0.
507
+ fill_alpha (float, optional): Transparency level for the contour fill. Defaults to 0.2.
360
508
  """
361
509
  # Extract the positions of the specified nodes
362
510
  points = np.array([pos[n] for n in nodes])
@@ -382,141 +530,219 @@ class NetworkPlotter:
382
530
  min_density, max_density = z.min(), z.max()
383
531
  contour_levels = np.linspace(min_density, max_density, levels)[1:]
384
532
  contour_colors = [color for _ in range(levels - 1)]
385
-
386
- # Plot the filled contours if alpha > 0
387
- if alpha > 0:
533
+ # Plot the filled contours using fill_alpha for transparency
534
+ if fill_alpha > 0:
388
535
  ax.contourf(
389
536
  x,
390
537
  y,
391
538
  z,
392
539
  levels=contour_levels,
393
540
  colors=contour_colors,
394
- alpha=alpha,
395
- extend="neither",
396
541
  antialiased=True,
542
+ alpha=fill_alpha,
397
543
  )
398
544
 
399
- # Plot the contour lines without antialiasing for clarity
400
- c = ax.contour(x, y, z, levels=contour_levels, colors=contour_colors)
545
+ # Plot the contour lines with the specified alpha for transparency
546
+ c = ax.contour(
547
+ x,
548
+ y,
549
+ z,
550
+ levels=contour_levels,
551
+ colors=contour_colors,
552
+ linestyles=linestyle,
553
+ linewidths=linewidth,
554
+ alpha=alpha,
555
+ )
556
+ # Set linewidth for the contour lines to 0 for levels other than the base level
401
557
  for i in range(1, len(contour_levels)):
402
558
  c.collections[i].set_linewidth(0)
403
559
 
404
560
  def plot_labels(
405
561
  self,
406
- perimeter_scale: float = 1.05,
562
+ scale: float = 1.05,
407
563
  offset: float = 0.10,
408
564
  font: str = "Arial",
409
565
  fontsize: int = 10,
410
- fontcolor: Union[str, np.ndarray] = "black",
566
+ fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
567
+ fontalpha: float = 1.0,
411
568
  arrow_linewidth: float = 1,
412
- arrow_color: Union[str, np.ndarray] = "black",
569
+ arrow_style: str = "->",
570
+ arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
571
+ arrow_alpha: float = 1.0,
572
+ arrow_base_shrink: float = 0.0,
573
+ arrow_tip_shrink: float = 0.0,
413
574
  max_labels: Union[int, None] = None,
414
575
  max_words: int = 10,
415
576
  min_words: int = 1,
416
577
  max_word_length: int = 20,
417
578
  min_word_length: int = 1,
418
- words_to_omit: Union[List[str], None] = None,
579
+ words_to_omit: Union[List, None] = None,
580
+ overlay_ids: bool = False,
581
+ ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
582
+ ids_to_replace: Union[Dict, None] = None,
419
583
  ) -> None:
420
584
  """Annotate the network graph with labels for different domains, positioned around the network for clarity.
421
585
 
422
586
  Args:
423
- perimeter_scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
587
+ scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
424
588
  offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
425
589
  font (str, optional): Font name for the labels. Defaults to "Arial".
426
590
  fontsize (int, optional): Font size for the labels. Defaults to 10.
427
- fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
591
+ fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
592
+ fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
428
593
  arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
429
- arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
594
+ arrow_style (str, optional): Style of the arrows pointing to centroids. Defaults to "->".
595
+ arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
596
+ arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
597
+ arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
598
+ arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
430
599
  max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
431
600
  max_words (int, optional): Maximum number of words in a label. Defaults to 10.
432
601
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
433
602
  max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
434
603
  min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
435
- words_to_omit (List[str], optional): List of words to omit from the labels. Defaults to None.
604
+ words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
605
+ overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
606
+ ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
607
+ you can set `overlay_ids=True`. Defaults to None.
608
+ ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be space-separated words.
609
+ If provided, the custom labels will replace the default domain terms. To discover domain IDs, you can set `overlay_ids=True`.
610
+ Defaults to None.
611
+
612
+ Raises:
613
+ ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
436
614
  """
437
615
  # Log the plotting parameters
438
616
  params.log_plotter(
439
- label_perimeter_scale=perimeter_scale,
617
+ label_perimeter_scale=scale,
440
618
  label_offset=offset,
441
619
  label_font=font,
442
620
  label_fontsize=fontsize,
443
- label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
621
+ label_fontcolor=(
622
+ "custom" if isinstance(fontcolor, np.ndarray) else fontcolor
623
+ ), # np.ndarray usually indicates custom colors
624
+ label_fontalpha=fontalpha,
444
625
  label_arrow_linewidth=arrow_linewidth,
626
+ label_arrow_style=arrow_style,
445
627
  label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
628
+ label_arrow_alpha=arrow_alpha,
629
+ label_arrow_base_shrink=arrow_base_shrink,
630
+ label_arrow_tip_shrink=arrow_tip_shrink,
446
631
  label_max_labels=max_labels,
447
632
  label_max_words=max_words,
448
633
  label_min_words=min_words,
449
634
  label_max_word_length=max_word_length,
450
635
  label_min_word_length=min_word_length,
451
636
  label_words_to_omit=words_to_omit,
637
+ label_overlay_ids=overlay_ids,
638
+ label_ids_to_keep=ids_to_keep,
639
+ label_ids_to_replace=ids_to_replace,
640
+ )
641
+
642
+ # Set max_labels to the total number of domains if not provided (None)
643
+ if max_labels is None:
644
+ max_labels = len(self.graph.domain_id_to_node_ids_map)
645
+
646
+ # Convert colors to RGBA using the _to_rgba helper function
647
+ fontcolor = _to_rgba(
648
+ fontcolor, fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
649
+ )
650
+ arrow_color = _to_rgba(
651
+ arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
452
652
  )
453
653
 
454
- # Convert color strings to RGBA arrays if necessary
455
- if isinstance(fontcolor, str):
456
- fontcolor = self.get_annotated_label_colors(color=fontcolor)
457
- if isinstance(arrow_color, str):
458
- arrow_color = self.get_annotated_label_colors(color=arrow_color)
459
654
  # Normalize words_to_omit to lowercase
460
655
  if words_to_omit:
461
656
  words_to_omit = set(word.lower() for word in words_to_omit)
462
657
 
463
658
  # Calculate the center and radius of the network
464
659
  domain_centroids = {}
465
- for domain, nodes in self.graph.domain_to_nodes.items():
466
- if nodes: # Skip if the domain has no nodes
467
- domain_centroids[domain] = self._calculate_domain_centroid(nodes)
660
+ for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items():
661
+ if node_ids: # Skip if the domain has no nodes
662
+ domain_centroids[domain_id] = self._calculate_domain_centroid(node_ids)
468
663
 
469
- # Initialize empty lists to collect valid indices
664
+ # Initialize dictionaries and lists for valid indices
470
665
  valid_indices = []
471
666
  filtered_domain_centroids = {}
472
667
  filtered_domain_terms = {}
473
- # Loop through domain_centroids with index
474
- for idx, (domain, centroid) in enumerate(domain_centroids.items()):
475
- # Process the domain term
476
- terms = self.graph.trimmed_domain_to_term[domain].split(" ")
477
- # Remove words_to_omit
478
- if words_to_omit:
479
- terms = [term for term in terms if term.lower() not in words_to_omit]
480
- # Filter words based on length
481
- terms = [term for term in terms if min_word_length <= len(term) <= max_word_length]
482
- # Trim to max_words
483
- terms = terms[:max_words]
484
- # Check if the domain passes the word count condition
485
- if len(terms) >= min_words:
486
- # Add to filtered_domain_centroids
487
- filtered_domain_centroids[domain] = centroid
488
- # Store the filtered and trimmed terms
489
- filtered_domain_terms[domain] = " ".join(terms)
490
- # Keep track of the valid index - used for fontcolor and arrow_color
491
- valid_indices.append(idx)
492
-
493
- # If max_labels is specified and less than the available labels
494
- if max_labels is not None and max_labels < len(filtered_domain_centroids):
495
- step = len(filtered_domain_centroids) / max_labels
496
- selected_indices = [int(i * step) for i in range(max_labels)]
497
- # Filter the centroids, terms, and valid_indices to only use the selected indices
498
- filtered_domain_centroids = {
499
- k: v
500
- for i, (k, v) in enumerate(filtered_domain_centroids.items())
501
- if i in selected_indices
502
- }
503
- filtered_domain_terms = {
504
- k: v
505
- for i, (k, v) in enumerate(filtered_domain_terms.items())
506
- if i in selected_indices
507
- }
508
- # Update valid_indices to match selected indices
509
- valid_indices = [valid_indices[i] for i in selected_indices]
668
+ # Handle the ids_to_keep logic
669
+ if ids_to_keep:
670
+ # Convert ids_to_keep to remove accidental duplicates
671
+ ids_to_keep = set(ids_to_keep)
672
+ # Check if the number of provided ids_to_keep exceeds max_labels
673
+ if max_labels is not None and len(ids_to_keep) > max_labels:
674
+ raise ValueError(
675
+ f"Number of provided IDs ({len(ids_to_keep)}) exceeds max_labels ({max_labels})."
676
+ )
510
677
 
511
- # Calculate the bounding box around the network
512
- center, radius = _calculate_bounding_box(
513
- self.graph.node_coordinates, radius_margin=perimeter_scale
678
+ # Process the specified IDs first
679
+ for domain in ids_to_keep:
680
+ if (
681
+ domain in self.graph.domain_id_to_domain_terms_map
682
+ and domain in domain_centroids
683
+ ):
684
+ # Handle ids_to_replace logic here for ids_to_keep
685
+ if ids_to_replace and domain in ids_to_replace:
686
+ terms = ids_to_replace[domain].split(" ")
687
+ else:
688
+ terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
689
+
690
+ # Apply words_to_omit, word length constraints, and max_words
691
+ if words_to_omit:
692
+ terms = [term for term in terms if term.lower() not in words_to_omit]
693
+ terms = [
694
+ term for term in terms if min_word_length <= len(term) <= max_word_length
695
+ ]
696
+ terms = terms[:max_words]
697
+
698
+ # Check if the domain passes the word count condition
699
+ if len(terms) >= min_words:
700
+ filtered_domain_centroids[domain] = domain_centroids[domain]
701
+ filtered_domain_terms[domain] = " ".join(terms)
702
+ valid_indices.append(
703
+ list(domain_centroids.keys()).index(domain)
704
+ ) # Track the valid index
705
+
706
+ # Calculate remaining labels to plot after processing ids_to_keep
707
+ remaining_labels = (
708
+ max_labels - len(ids_to_keep) if ids_to_keep and max_labels else max_labels
514
709
  )
515
- # Calculate the best positions for labels around the perimeter
710
+ # Process remaining domains to fill in additional labels, if there are slots left
711
+ if remaining_labels and remaining_labels > 0:
712
+ for idx, (domain, centroid) in enumerate(domain_centroids.items()):
713
+ if ids_to_keep and domain in ids_to_keep:
714
+ continue # Skip domains already handled by ids_to_keep
715
+
716
+ # Handle ids_to_replace logic first
717
+ if ids_to_replace and domain in ids_to_replace:
718
+ terms = ids_to_replace[domain].split(" ")
719
+ else:
720
+ terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
721
+
722
+ # Apply words_to_omit, word length constraints, and max_words
723
+ if words_to_omit:
724
+ terms = [term for term in terms if term.lower() not in words_to_omit]
725
+
726
+ terms = [term for term in terms if min_word_length <= len(term) <= max_word_length]
727
+ terms = terms[:max_words]
728
+ # Check if the domain passes the word count condition
729
+ if len(terms) >= min_words:
730
+ filtered_domain_centroids[domain] = centroid
731
+ filtered_domain_terms[domain] = " ".join(terms)
732
+ valid_indices.append(idx) # Track the valid index
733
+
734
+ # Stop once we've reached the max_labels limit
735
+ if len(filtered_domain_centroids) >= max_labels:
736
+ break
737
+
738
+ # Calculate the bounding box around the network
739
+ center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
740
+ # Calculate the best positions for labels
516
741
  best_label_positions = _calculate_best_label_positions(
517
742
  filtered_domain_centroids, center, radius, offset
518
743
  )
519
- # Annotate the network with labels - valid_indices is used for fontcolor and arrow_color
744
+
745
+ # Annotate the network with labels
520
746
  for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
521
747
  centroid = filtered_domain_centroids[domain]
522
748
  annotations = filtered_domain_terms[domain].split(" ")[:max_words]
@@ -530,63 +756,79 @@ class NetworkPlotter:
530
756
  fontsize=fontsize,
531
757
  fontname=font,
532
758
  color=fontcolor[idx],
533
- arrowprops=dict(arrowstyle="->", color=arrow_color[idx], linewidth=arrow_linewidth),
759
+ arrowprops=dict(
760
+ arrowstyle=arrow_style,
761
+ color=arrow_color[idx],
762
+ linewidth=arrow_linewidth,
763
+ shrinkA=arrow_base_shrink,
764
+ shrinkB=arrow_tip_shrink,
765
+ ),
534
766
  )
767
+ # Overlay domain ID at the centroid if requested
768
+ if overlay_ids:
769
+ self.ax.text(
770
+ centroid[0],
771
+ centroid[1],
772
+ domain,
773
+ ha="center",
774
+ va="center",
775
+ fontsize=fontsize,
776
+ fontname=font,
777
+ color=fontcolor[idx],
778
+ alpha=fontalpha,
779
+ )
535
780
 
536
781
  def plot_sublabel(
537
782
  self,
538
- nodes: list,
783
+ nodes: Union[List, Tuple, np.ndarray],
539
784
  label: str,
540
785
  radial_position: float = 0.0,
541
- perimeter_scale: float = 1.05,
786
+ scale: float = 1.05,
542
787
  offset: float = 0.10,
543
788
  font: str = "Arial",
544
789
  fontsize: int = 10,
545
- fontcolor: str = "black",
790
+ fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
791
+ fontalpha: float = 1.0,
546
792
  arrow_linewidth: float = 1,
547
- arrow_color: str = "black",
793
+ arrow_style: str = "->",
794
+ arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
795
+ arrow_alpha: float = 1.0,
796
+ arrow_base_shrink: float = 0.0,
797
+ arrow_tip_shrink: float = 0.0,
548
798
  ) -> None:
549
- """Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
799
+ """Annotate the network graph with a label for the given nodes, with one arrow pointing to each centroid of sublists of nodes.
550
800
 
551
801
  Args:
552
- nodes (List[str]): List of node labels to be used for calculating the centroid.
802
+ nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels.
553
803
  label (str): The label to be annotated on the network.
554
804
  radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
555
- perimeter_scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
805
+ scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
556
806
  offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
557
807
  font (str, optional): Font name for the label. Defaults to "Arial".
558
808
  fontsize (int, optional): Font size for the label. Defaults to 10.
559
- fontcolor (str, optional): Color of the label text. Defaults to "black".
809
+ fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
810
+ fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
560
811
  arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
561
- arrow_color (str, optional): Color of the arrow. Defaults to "black".
812
+ arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
813
+ arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
814
+ arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
815
+ arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
816
+ arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
562
817
  """
563
- # Log the plotting parameters
564
- params.log_plotter(
565
- sublabel_perimeter_scale=perimeter_scale,
566
- sublabel_offset=offset,
567
- sublabel_font=font,
568
- sublabel_fontsize=fontsize,
569
- sublabel_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
570
- sublabel_arrow_linewidth=arrow_linewidth,
571
- sublabel_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
572
- sublabel_radial_position=radial_position,
573
- )
574
-
575
- # Map node labels to IDs
576
- node_ids = [
577
- self.graph.node_label_to_id_map.get(node)
578
- for node in nodes
579
- if node in self.graph.node_label_to_id_map
580
- ]
581
- if not node_ids or len(node_ids) == 1:
582
- raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
818
+ # Check if nodes is a list of lists or a flat list
819
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
820
+ # If it's a list of lists, iterate over sublists
821
+ node_groups = nodes
822
+ else:
823
+ # If it's a flat list of nodes, treat it as a single group
824
+ node_groups = [nodes]
825
+
826
+ # Convert fontcolor and arrow_color to RGBA
827
+ fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
828
+ arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
583
829
 
584
- # Calculate the centroid of the provided nodes
585
- centroid = self._calculate_domain_centroid(node_ids)
586
830
  # Calculate the bounding box around the network
587
- center, radius = _calculate_bounding_box(
588
- self.graph.node_coordinates, radius_margin=perimeter_scale
589
- )
831
+ center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
590
832
  # Convert radial position to radians, adjusting for a 90-degree rotation
591
833
  radial_radians = np.deg2rad(radial_position - 90)
592
834
  label_position = (
@@ -594,21 +836,42 @@ class NetworkPlotter:
594
836
  center[1] + (radius + offset) * np.sin(radial_radians),
595
837
  )
596
838
 
597
- # Annotate the network with the label
598
- self.ax.annotate(
599
- label,
600
- xy=centroid,
601
- xytext=label_position,
602
- textcoords="data",
603
- ha="center",
604
- va="center",
605
- fontsize=fontsize,
606
- fontname=font,
607
- color=fontcolor,
608
- arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
609
- )
839
+ # Iterate over each group of nodes (either sublists or flat list)
840
+ for sublist in node_groups:
841
+ # Map node labels to IDs
842
+ node_ids = [
843
+ self.graph.node_label_to_node_id_map.get(node)
844
+ for node in sublist
845
+ if node in self.graph.node_label_to_node_id_map
846
+ ]
847
+ if not node_ids or len(node_ids) == 1:
848
+ raise ValueError(
849
+ "No nodes found in the network graph or insufficient nodes to plot."
850
+ )
610
851
 
611
- def _calculate_domain_centroid(self, nodes: list) -> tuple:
852
+ # Calculate the centroid of the provided nodes in this sublist
853
+ centroid = self._calculate_domain_centroid(node_ids)
854
+ # Annotate the network with the label and an arrow pointing to each centroid
855
+ self.ax.annotate(
856
+ label,
857
+ xy=centroid,
858
+ xytext=label_position,
859
+ textcoords="data",
860
+ ha="center",
861
+ va="center",
862
+ fontsize=fontsize,
863
+ fontname=font,
864
+ color=fontcolor_rgba,
865
+ arrowprops=dict(
866
+ arrowstyle=arrow_style,
867
+ color=arrow_color_rgba,
868
+ linewidth=arrow_linewidth,
869
+ shrinkA=arrow_base_shrink,
870
+ shrinkB=arrow_tip_shrink,
871
+ ),
872
+ )
873
+
874
+ def _calculate_domain_centroid(self, nodes: List) -> tuple:
612
875
  """Calculate the most centrally located node in .
613
876
 
614
877
  Args:
@@ -630,111 +893,193 @@ class NetworkPlotter:
630
893
  return domain_central_node
631
894
 
632
895
  def get_annotated_node_colors(
633
- self, nonenriched_color: str = "white", random_seed: int = 888, **kwargs
896
+ self,
897
+ cmap: str = "gist_rainbow",
898
+ color: Union[str, None] = None,
899
+ min_scale: float = 0.8,
900
+ max_scale: float = 1.0,
901
+ scale_factor: float = 1.0,
902
+ alpha: float = 1.0,
903
+ nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
904
+ nonenriched_alpha: float = 1.0,
905
+ random_seed: int = 888,
634
906
  ) -> np.ndarray:
635
907
  """Adjust the colors of nodes in the network graph based on enrichment.
636
908
 
637
909
  Args:
638
- nonenriched_color (str, optional): Color for non-enriched nodes. Defaults to "white".
910
+ cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
911
+ color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
912
+ min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
913
+ max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
914
+ scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
915
+ alpha (float, optional): Alpha value for enriched nodes. Defaults to 1.0.
916
+ nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
917
+ nonenriched_alpha (float, optional): Alpha value for non-enriched nodes. Defaults to 1.0.
639
918
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
640
- **kwargs: Additional keyword arguments for `get_domain_colors`.
641
919
 
642
920
  Returns:
643
921
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
644
922
  """
645
- # Get the initial domain colors for each node
646
- network_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
647
- if isinstance(nonenriched_color, str):
648
- # Convert the non-enriched color from string to RGBA
649
- nonenriched_color = mcolors.to_rgba(nonenriched_color)
650
-
651
- # Adjust node colors: replace any fully transparent nodes (enriched) with the non-enriched color
923
+ # Get the initial domain colors for each node, which are returned as RGBA
924
+ network_colors = self.graph.get_domain_colors(
925
+ cmap=cmap,
926
+ color=color,
927
+ min_scale=min_scale,
928
+ max_scale=max_scale,
929
+ scale_factor=scale_factor,
930
+ random_seed=random_seed,
931
+ )
932
+ # Apply the alpha value for enriched nodes
933
+ network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
934
+ # Convert the non-enriched color to RGBA using the _to_rgba helper function
935
+ nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
936
+ # Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
652
937
  adjusted_network_colors = np.where(
653
- np.all(network_colors == 0, axis=1, keepdims=True),
654
- np.array([nonenriched_color]),
655
- network_colors,
938
+ np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
939
+ np.array([nonenriched_color]), # Apply the non-enriched color with alpha
940
+ network_colors, # Keep the original colors for enriched nodes
656
941
  )
657
942
  return adjusted_network_colors
658
943
 
659
944
  def get_annotated_node_sizes(
660
- self, enriched_nodesize: int = 50, nonenriched_nodesize: int = 25
945
+ self, enriched_size: int = 50, nonenriched_size: int = 25
661
946
  ) -> np.ndarray:
662
947
  """Adjust the sizes of nodes in the network graph based on whether they are enriched or not.
663
948
 
664
949
  Args:
665
- enriched_nodesize (int): Size for enriched nodes. Defaults to 50.
666
- nonenriched_nodesize (int): Size for non-enriched nodes. Defaults to 25.
950
+ enriched_size (int): Size for enriched nodes. Defaults to 50.
951
+ nonenriched_size (int): Size for non-enriched nodes. Defaults to 25.
667
952
 
668
953
  Returns:
669
954
  np.ndarray: Array of node sizes, with enriched nodes larger than non-enriched ones.
670
955
  """
671
- # Merge all enriched nodes from the domain_to_nodes dictionary
956
+ # Merge all enriched nodes from the domain_id_to_node_ids_map dictionary
672
957
  enriched_nodes = set()
673
- for _, nodes in self.graph.domain_to_nodes.items():
674
- enriched_nodes.update(nodes)
958
+ for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
959
+ enriched_nodes.update(node_ids)
675
960
 
676
961
  # Initialize all node sizes to the non-enriched size
677
- node_sizes = np.full(len(self.graph.network.nodes), nonenriched_nodesize)
962
+ node_sizes = np.full(len(self.graph.network.nodes), nonenriched_size)
678
963
  # Set the size for enriched nodes
679
964
  for node in enriched_nodes:
680
965
  if node in self.graph.network.nodes:
681
- node_sizes[node] = enriched_nodesize
966
+ node_sizes[node] = enriched_size
682
967
 
683
968
  return node_sizes
684
969
 
685
- def get_annotated_contour_colors(self, random_seed: int = 888, **kwargs) -> np.ndarray:
686
- """Get colors for the contours based on node annotations.
970
+ def get_annotated_contour_colors(
971
+ self,
972
+ cmap: str = "gist_rainbow",
973
+ color: Union[str, None] = None,
974
+ min_scale: float = 0.8,
975
+ max_scale: float = 1.0,
976
+ scale_factor: float = 1.0,
977
+ random_seed: int = 888,
978
+ ) -> np.ndarray:
979
+ """Get colors for the contours based on node annotations or a specified colormap.
687
980
 
688
981
  Args:
689
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
690
- **kwargs: Additional keyword arguments for `_get_annotated_domain_colors`.
982
+ cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
983
+ color (str or None, optional): Color to use for the contours. If None, the colormap will be used. Defaults to None.
984
+ min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
985
+ Controls the dimmest colors. Defaults to 0.8.
986
+ max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
987
+ Controls the brightest colors. Defaults to 1.0.
988
+ scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
989
+ A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
990
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
691
991
 
692
992
  Returns:
693
993
  np.ndarray: Array of RGBA colors for contour annotations.
694
994
  """
695
- return self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
995
+ return self._get_annotated_domain_colors(
996
+ cmap=cmap,
997
+ color=color,
998
+ min_scale=min_scale,
999
+ max_scale=max_scale,
1000
+ scale_factor=scale_factor,
1001
+ random_seed=random_seed,
1002
+ )
696
1003
 
697
- def get_annotated_label_colors(self, random_seed: int = 888, **kwargs) -> np.ndarray:
698
- """Get colors for the labels based on node annotations.
1004
+ def get_annotated_label_colors(
1005
+ self,
1006
+ cmap: str = "gist_rainbow",
1007
+ color: Union[str, None] = None,
1008
+ min_scale: float = 0.8,
1009
+ max_scale: float = 1.0,
1010
+ scale_factor: float = 1.0,
1011
+ random_seed: int = 888,
1012
+ ) -> np.ndarray:
1013
+ """Get colors for the labels based on node annotations or a specified colormap.
699
1014
 
700
1015
  Args:
701
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
702
- **kwargs: Additional keyword arguments for `_get_annotated_domain_colors`.
1016
+ cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
1017
+ color (str or None, optional): Color to use for the labels. If None, the colormap will be used. Defaults to None.
1018
+ min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
1019
+ Controls the dimmest colors. Defaults to 0.8.
1020
+ max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
1021
+ Controls the brightest colors. Defaults to 1.0.
1022
+ scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
1023
+ A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
1024
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
703
1025
 
704
1026
  Returns:
705
1027
  np.ndarray: Array of RGBA colors for label annotations.
706
1028
  """
707
- return self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
1029
+ return self._get_annotated_domain_colors(
1030
+ cmap=cmap,
1031
+ color=color,
1032
+ min_scale=min_scale,
1033
+ max_scale=max_scale,
1034
+ scale_factor=scale_factor,
1035
+ random_seed=random_seed,
1036
+ )
708
1037
 
709
1038
  def _get_annotated_domain_colors(
710
- self, color: Union[str, list, None] = None, random_seed: int = 888, **kwargs
1039
+ self,
1040
+ cmap: str = "gist_rainbow",
1041
+ color: Union[str, None] = None,
1042
+ min_scale: float = 0.8,
1043
+ max_scale: float = 1.0,
1044
+ scale_factor: float = 1.0,
1045
+ random_seed: int = 888,
711
1046
  ) -> np.ndarray:
712
- """Get colors for the domains based on node annotations.
1047
+ """Get colors for the domains based on node annotations, or use a specified color.
713
1048
 
714
1049
  Args:
715
- color (str, list, or None, optional): If provided, use this color or list of colors for domains. Defaults to None.
716
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
717
- **kwargs: Additional keyword arguments for `get_domain_colors`.
1050
+ cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
1051
+ color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
1052
+ min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
1053
+ Defaults to 0.8.
1054
+ max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
1055
+ Defaults to 1.0.
1056
+ scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
1057
+ enrichment. Higher values increase the contrast. Defaults to 1.0.
1058
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
718
1059
 
719
1060
  Returns:
720
1061
  np.ndarray: Array of RGBA colors for each domain.
721
1062
  """
722
- if isinstance(color, str):
723
- # If a single color string is provided, convert it to RGBA and apply to all domains
724
- rgba_color = np.array(matplotlib.colors.to_rgba(color))
725
- return np.array([rgba_color for _ in self.graph.domain_to_nodes])
726
-
727
- # Generate colors for each domain using the provided arguments and random seed
728
- node_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
1063
+ # Generate domain colors based on the enrichment data
1064
+ node_colors = self.graph.get_domain_colors(
1065
+ cmap=cmap,
1066
+ color=color,
1067
+ min_scale=min_scale,
1068
+ max_scale=max_scale,
1069
+ scale_factor=scale_factor,
1070
+ random_seed=random_seed,
1071
+ )
729
1072
  annotated_colors = []
730
- for _, nodes in self.graph.domain_to_nodes.items():
731
- if len(nodes) > 1:
732
- # For domains with multiple nodes, choose the brightest color (sum of RGB values)
733
- domain_colors = np.array([node_colors[node] for node in nodes])
734
- brightest_color = domain_colors[np.argmax(domain_colors.sum(axis=1))]
1073
+ for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
1074
+ if len(node_ids) > 1:
1075
+ # For multi-node domains, choose the brightest color based on RGB sum
1076
+ domain_colors = np.array([node_colors[node] for node in node_ids])
1077
+ brightest_color = domain_colors[
1078
+ np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
1079
+ ]
735
1080
  annotated_colors.append(brightest_color)
736
1081
  else:
737
- # Assign a default color (white) for single-node domains
1082
+ # Single-node domains default to white (RGBA)
738
1083
  default_color = np.array([1.0, 1.0, 1.0, 1.0])
739
1084
  annotated_colors.append(default_color)
740
1085
 
@@ -761,6 +1106,65 @@ class NetworkPlotter:
761
1106
  plt.show(*args, **kwargs)
762
1107
 
763
1108
 
1109
def _to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: float = 1.0,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert a color or a collection of colors to RGBA format.

    A "single color" is either a matplotlib color string or a flat sequence of 3 (RGB) or
    4 (RGBA) numeric components. Any other list, tuple, or np.ndarray is treated as a
    collection of single colors. ``alpha`` is applied to string and RGB inputs; RGBA inputs
    keep their own alpha channel.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string,
            an RGB/RGBA sequence, or a list/tuple/np.ndarray of such colors.
        alpha (float, optional): Alpha value (transparency) applied to string and RGB colors.
            Defaults to 1.0.
        num_repeats (int or None, optional): If provided, a single color (or a collection that
            contains exactly one color) is repeated this many times. Defaults to None.

    Returns:
        np.ndarray: A single RGBA color, or an array with one RGBA row per input color.

    Raises:
        ValueError: If ``color`` or any of its elements is not a valid color specification.
    """

    def _is_single_color(value: Any) -> bool:
        # A length check alone cannot tell an RGB triple from a list of three colors
        # (e.g. ["red", "green", "blue"]), so the component types are inspected as well.
        if isinstance(value, str):
            return True
        if isinstance(value, (list, tuple, np.ndarray)) and len(value) in (3, 4):
            return all(
                isinstance(component, (int, float, np.integer, np.floating))
                for component in value
            )
        return False

    def _convert(value: Any) -> np.ndarray:
        # Convert one validated single color to an RGBA array
        rgba = np.array(mcolors.to_rgba(value))
        # Only strings and RGB triples receive the requested alpha; RGBA keeps its own
        if isinstance(value, str) or len(value) == 3:
            rgba[3] = alpha
        return rgba

    # Handle single color case (string, RGB, or RGBA)
    if _is_single_color(color):
        rgba_color = _convert(color)
        # Repeat the color if num_repeats argument is provided
        if num_repeats is not None:
            return np.array([rgba_color] * num_repeats)
        return rgba_color

    # Handle collection of colors (strings, RGB, and RGBA entries)
    if isinstance(color, (list, tuple, np.ndarray)):
        rgba_colors = []
        for c in color:
            if not _is_single_color(c):
                raise ValueError(f"Invalid color: {c}. Must be a valid RGB/RGBA or string color.")
            rgba_colors.append(_convert(c))

        # Repeat the color only when the collection holds exactly one entry
        if num_repeats is not None and len(rgba_colors) == 1:
            return np.array([rgba_colors[0]] * num_repeats)

        return np.array(rgba_colors)

    raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
1166
+
1167
+
764
1168
  def _is_connected(z: np.ndarray) -> bool:
765
1169
  """Determine if a thresholded grid represents a single, connected component.
766
1170