risk-network 0.0.5b6__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/io.py +44 -3
- risk/log/params.py +2 -0
- risk/neighborhoods/community.py +7 -3
- risk/neighborhoods/domains.py +24 -18
- risk/neighborhoods/neighborhoods.py +2 -2
- risk/network/graph.py +68 -40
- risk/network/io.py +30 -10
- risk/network/plot.py +713 -309
- risk/risk.py +10 -22
- {risk_network-0.0.5b6.dist-info → risk_network-0.0.6.dist-info}/METADATA +3 -4
- {risk_network-0.0.5b6.dist-info → risk_network-0.0.6.dist-info}/RECORD +15 -15
- {risk_network-0.0.5b6.dist-info → risk_network-0.0.6.dist-info}/WHEEL +1 -1
- {risk_network-0.0.5b6.dist-info → risk_network-0.0.6.dist-info}/LICENSE +0 -0
- {risk_network-0.0.5b6.dist-info → risk_network-0.0.6.dist-info}/top_level.txt +0 -0
risk/network/plot.py
CHANGED
@@ -5,7 +5,6 @@ risk/network/plot
|
|
5
5
|
|
6
6
|
from typing import Any, Dict, List, Tuple, Union
|
7
7
|
|
8
|
-
import matplotlib
|
9
8
|
import matplotlib.colors as mcolors
|
10
9
|
import matplotlib.pyplot as plt
|
11
10
|
import networkx as nx
|
@@ -18,66 +17,42 @@ from risk.network.graph import NetworkGraph
|
|
18
17
|
|
19
18
|
|
20
19
|
class NetworkPlotter:
|
21
|
-
"""A class
|
20
|
+
"""A class for visualizing network graphs with customizable options.
|
22
21
|
|
23
|
-
The NetworkPlotter class
|
24
|
-
and
|
25
|
-
|
22
|
+
The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
|
23
|
+
flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
|
24
|
+
perimeter, and adjusting background colors.
|
26
25
|
"""
|
27
26
|
|
28
27
|
def __init__(
|
29
28
|
self,
|
30
29
|
graph: NetworkGraph,
|
31
|
-
figsize:
|
32
|
-
background_color: str = "white",
|
33
|
-
plot_outline: bool = True,
|
34
|
-
outline_color: str = "black",
|
35
|
-
outline_scale: float = 1.0,
|
36
|
-
linestyle: str = "dashed",
|
30
|
+
figsize: Tuple = (10, 10),
|
31
|
+
background_color: Union[str, List, Tuple, np.ndarray] = "white",
|
37
32
|
) -> None:
|
38
33
|
"""Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
|
39
34
|
|
40
35
|
Args:
|
41
36
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
42
37
|
figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
|
43
|
-
background_color (str, optional): Background color of the plot. Defaults to "white".
|
44
|
-
plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
|
45
|
-
outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
|
46
|
-
outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
|
47
|
-
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
|
38
|
+
background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
|
48
39
|
"""
|
49
40
|
self.graph = graph
|
50
41
|
# Initialize the plot with the specified parameters
|
51
|
-
self.ax = self._initialize_plot(
|
52
|
-
graph,
|
53
|
-
figsize,
|
54
|
-
background_color,
|
55
|
-
plot_outline,
|
56
|
-
outline_color,
|
57
|
-
outline_scale,
|
58
|
-
linestyle,
|
59
|
-
)
|
42
|
+
self.ax = self._initialize_plot(graph, figsize, background_color)
|
60
43
|
|
61
44
|
def _initialize_plot(
|
62
45
|
self,
|
63
46
|
graph: NetworkGraph,
|
64
|
-
figsize:
|
65
|
-
background_color: str,
|
66
|
-
plot_outline: bool,
|
67
|
-
outline_color: str,
|
68
|
-
outline_scale: float,
|
69
|
-
linestyle: str,
|
47
|
+
figsize: Tuple,
|
48
|
+
background_color: Union[str, List, Tuple, np.ndarray],
|
70
49
|
) -> plt.Axes:
|
71
|
-
"""Set up the plot with figure size
|
50
|
+
"""Set up the plot with figure size and background color.
|
72
51
|
|
73
52
|
Args:
|
74
53
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
75
54
|
figsize (tuple): Size of the figure in inches (width, height).
|
76
55
|
background_color (str): Background color of the plot.
|
77
|
-
plot_outline (bool): Whether to plot the network perimeter circle.
|
78
|
-
outline_color (str): Color of the network perimeter circle.
|
79
|
-
outline_scale (float): Outline scaling factor for the perimeter diameter.
|
80
|
-
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
|
81
56
|
|
82
57
|
Returns:
|
83
58
|
plt.Axes: The axis object for the plot.
|
@@ -86,31 +61,19 @@ class NetworkPlotter:
|
|
86
61
|
node_coordinates = graph.node_coordinates
|
87
62
|
# Calculate the center and radius of the bounding box around the network
|
88
63
|
center, radius = _calculate_bounding_box(node_coordinates)
|
89
|
-
# Scale the radius by the outline_scale factor
|
90
|
-
scaled_radius = radius * outline_scale
|
91
64
|
|
92
65
|
# Create a new figure and axis for plotting
|
93
66
|
fig, ax = plt.subplots(figsize=figsize)
|
94
67
|
fig.tight_layout() # Adjust subplot parameters to give specified padding
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
center,
|
99
|
-
scaled_radius,
|
100
|
-
linestyle=linestyle, # Use the linestyle argument here
|
101
|
-
color=outline_color,
|
102
|
-
fill=False,
|
103
|
-
linewidth=1.5,
|
104
|
-
)
|
105
|
-
ax.add_artist(circle) # Add the circle to the plot
|
106
|
-
|
107
|
-
# Set axis limits based on the calculated bounding box and scaled radius
|
108
|
-
ax.set_xlim([center[0] - scaled_radius - 0.3, center[0] + scaled_radius + 0.3])
|
109
|
-
ax.set_ylim([center[1] - scaled_radius - 0.3, center[1] + scaled_radius + 0.3])
|
68
|
+
# Set axis limits based on the calculated bounding box and radius
|
69
|
+
ax.set_xlim([center[0] - radius - 0.3, center[0] + radius + 0.3])
|
70
|
+
ax.set_ylim([center[1] - radius - 0.3, center[1] + radius + 0.3])
|
110
71
|
ax.set_aspect("equal") # Ensure the aspect ratio is equal
|
111
|
-
fig.patch.set_facecolor(background_color) # Set the figure background color
|
112
|
-
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
113
72
|
|
73
|
+
# Set the background color of the plot
|
74
|
+
# Convert color to RGBA using the _to_rgba helper function
|
75
|
+
fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
|
76
|
+
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
114
77
|
# Remove axis spines for a cleaner look
|
115
78
|
for spine in ax.spines.values():
|
116
79
|
spine.set_visible(False)
|
@@ -122,45 +85,182 @@ class NetworkPlotter:
|
|
122
85
|
|
123
86
|
return ax
|
124
87
|
|
88
|
+
def plot_circle_perimeter(
|
89
|
+
self,
|
90
|
+
scale: float = 1.0,
|
91
|
+
linestyle: str = "dashed",
|
92
|
+
linewidth: float = 1.5,
|
93
|
+
color: Union[str, List, Tuple, np.ndarray] = "black",
|
94
|
+
outline_alpha: float = 1.0,
|
95
|
+
fill_alpha: float = 0.0,
|
96
|
+
) -> None:
|
97
|
+
"""Plot a circle around the network graph to represent the network perimeter.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
|
101
|
+
linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
|
102
|
+
linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
|
103
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
|
104
|
+
outline_alpha (float, optional): Transparency level of the circle outline. Defaults to 1.0.
|
105
|
+
fill_alpha (float, optional): Transparency level of the circle fill. Defaults to 0.0.
|
106
|
+
"""
|
107
|
+
# Log the circle perimeter plotting parameters
|
108
|
+
params.log_plotter(
|
109
|
+
perimeter_type="circle",
|
110
|
+
perimeter_scale=scale,
|
111
|
+
perimeter_linestyle=linestyle,
|
112
|
+
perimeter_linewidth=linewidth,
|
113
|
+
perimeter_color=(
|
114
|
+
"custom" if isinstance(color, (list, tuple, np.ndarray)) else color
|
115
|
+
), # np.ndarray usually indicates custom colors
|
116
|
+
perimeter_outline_alpha=outline_alpha,
|
117
|
+
perimeter_fill_alpha=fill_alpha,
|
118
|
+
)
|
119
|
+
|
120
|
+
# Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
|
121
|
+
color = _to_rgba(color, outline_alpha)
|
122
|
+
# Extract node coordinates from the network graph
|
123
|
+
node_coordinates = self.graph.node_coordinates
|
124
|
+
# Calculate the center and radius of the bounding box around the network
|
125
|
+
center, radius = _calculate_bounding_box(node_coordinates)
|
126
|
+
# Scale the radius by the scale factor
|
127
|
+
scaled_radius = radius * scale
|
128
|
+
|
129
|
+
# Draw a circle to represent the network perimeter
|
130
|
+
circle = plt.Circle(
|
131
|
+
center,
|
132
|
+
scaled_radius,
|
133
|
+
linestyle=linestyle,
|
134
|
+
linewidth=linewidth,
|
135
|
+
color=color,
|
136
|
+
fill=fill_alpha > 0, # Fill the circle if fill_alpha is greater than 0
|
137
|
+
)
|
138
|
+
# Set the transparency of the fill if applicable
|
139
|
+
if fill_alpha > 0:
|
140
|
+
circle.set_facecolor(_to_rgba(color, fill_alpha))
|
141
|
+
|
142
|
+
self.ax.add_artist(circle)
|
143
|
+
|
144
|
+
def plot_contour_perimeter(
|
145
|
+
self,
|
146
|
+
scale: float = 1.0,
|
147
|
+
levels: int = 3,
|
148
|
+
bandwidth: float = 0.8,
|
149
|
+
grid_size: int = 250,
|
150
|
+
color: Union[str, List, Tuple, np.ndarray] = "black",
|
151
|
+
linestyle: str = "solid",
|
152
|
+
linewidth: float = 1.5,
|
153
|
+
outline_alpha: float = 1.0,
|
154
|
+
fill_alpha: float = 0.0,
|
155
|
+
) -> None:
|
156
|
+
"""
|
157
|
+
Plot a KDE-based contour around the network graph to represent the network perimeter.
|
158
|
+
|
159
|
+
Args:
|
160
|
+
scale (float, optional): Scaling factor for the perimeter size. Defaults to 1.0.
|
161
|
+
levels (int, optional): Number of contour levels. Defaults to 3.
|
162
|
+
bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
|
163
|
+
grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
|
164
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
|
165
|
+
linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
|
166
|
+
linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
|
167
|
+
outline_alpha (float, optional): Transparency level of the contour outline. Defaults to 1.0.
|
168
|
+
fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.0.
|
169
|
+
"""
|
170
|
+
# Log the contour perimeter plotting parameters
|
171
|
+
params.log_plotter(
|
172
|
+
perimeter_type="contour",
|
173
|
+
perimeter_scale=scale,
|
174
|
+
perimeter_levels=levels,
|
175
|
+
perimeter_bandwidth=bandwidth,
|
176
|
+
perimeter_grid_size=grid_size,
|
177
|
+
perimeter_linestyle=linestyle,
|
178
|
+
perimeter_linewidth=linewidth,
|
179
|
+
perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
|
180
|
+
perimeter_outline_alpha=outline_alpha,
|
181
|
+
perimeter_fill_alpha=fill_alpha,
|
182
|
+
)
|
183
|
+
|
184
|
+
# Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
|
185
|
+
color = _to_rgba(color, outline_alpha)
|
186
|
+
# Extract node coordinates from the network graph
|
187
|
+
node_coordinates = self.graph.node_coordinates
|
188
|
+
# Scale the node coordinates if needed
|
189
|
+
scaled_coordinates = node_coordinates * scale
|
190
|
+
# Use the existing _draw_kde_contour method
|
191
|
+
self._draw_kde_contour(
|
192
|
+
ax=self.ax,
|
193
|
+
pos=scaled_coordinates,
|
194
|
+
nodes=list(range(len(node_coordinates))), # All nodes are included
|
195
|
+
levels=levels,
|
196
|
+
bandwidth=bandwidth,
|
197
|
+
grid_size=grid_size,
|
198
|
+
color=color,
|
199
|
+
linestyle=linestyle,
|
200
|
+
linewidth=linewidth,
|
201
|
+
alpha=fill_alpha,
|
202
|
+
)
|
203
|
+
|
125
204
|
def plot_network(
|
126
205
|
self,
|
127
206
|
node_size: Union[int, np.ndarray] = 50,
|
128
|
-
edge_width: float = 1.0,
|
129
|
-
node_color: Union[str, np.ndarray] = "white",
|
130
|
-
node_edgecolor: str = "black",
|
131
|
-
edge_color: str = "black",
|
132
207
|
node_shape: str = "o",
|
208
|
+
node_edgewidth: float = 1.0,
|
209
|
+
edge_width: float = 1.0,
|
210
|
+
node_color: Union[str, List, Tuple, np.ndarray] = "white",
|
211
|
+
node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
|
212
|
+
edge_color: Union[str, List, Tuple, np.ndarray] = "black",
|
213
|
+
node_alpha: float = 1.0,
|
214
|
+
edge_alpha: float = 1.0,
|
133
215
|
) -> None:
|
134
|
-
"""Plot the network graph with customizable node colors, sizes, and edge widths.
|
216
|
+
"""Plot the network graph with customizable node colors, sizes, edge widths, and node edge widths.
|
135
217
|
|
136
218
|
Args:
|
137
219
|
node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
|
138
|
-
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
139
|
-
node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
140
|
-
node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
|
141
|
-
edge_color (str, optional): Color of the edges. Defaults to "black".
|
142
220
|
node_shape (str, optional): Shape of the nodes. Defaults to "o".
|
221
|
+
node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
|
222
|
+
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
223
|
+
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
224
|
+
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
|
225
|
+
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
|
226
|
+
node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0. Annotated node_color alphas will override this value.
|
227
|
+
edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
|
143
228
|
"""
|
144
229
|
# Log the plotting parameters
|
145
230
|
params.log_plotter(
|
146
|
-
network_node_size=
|
231
|
+
network_node_size=(
|
232
|
+
"custom" if isinstance(node_size, np.ndarray) else node_size
|
233
|
+
), # np.ndarray usually indicates custom sizes
|
234
|
+
network_node_shape=node_shape,
|
235
|
+
network_node_edgewidth=node_edgewidth,
|
147
236
|
network_edge_width=edge_width,
|
148
|
-
network_node_color=
|
237
|
+
network_node_color=(
|
238
|
+
"custom" if isinstance(node_color, np.ndarray) else node_color
|
239
|
+
), # np.ndarray usually indicates custom colors
|
149
240
|
network_node_edgecolor=node_edgecolor,
|
150
241
|
network_edge_color=edge_color,
|
151
|
-
|
242
|
+
network_node_alpha=node_alpha,
|
243
|
+
network_edge_alpha=edge_alpha,
|
152
244
|
)
|
245
|
+
|
246
|
+
# Convert colors to RGBA using the _to_rgba helper function
|
247
|
+
# If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
|
248
|
+
node_color = _to_rgba(node_color, node_alpha, num_repeats=len(self.graph.network.nodes))
|
249
|
+
node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(self.graph.network.nodes))
|
250
|
+
edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
|
251
|
+
|
153
252
|
# Extract node coordinates from the network graph
|
154
253
|
node_coordinates = self.graph.node_coordinates
|
254
|
+
|
155
255
|
# Draw the nodes of the graph
|
156
256
|
nx.draw_networkx_nodes(
|
157
257
|
self.graph.network,
|
158
258
|
pos=node_coordinates,
|
159
259
|
node_size=node_size,
|
160
|
-
node_color=node_color,
|
161
260
|
node_shape=node_shape,
|
162
|
-
|
261
|
+
node_color=node_color,
|
163
262
|
edgecolors=node_edgecolor,
|
263
|
+
linewidths=node_edgewidth,
|
164
264
|
ax=self.ax,
|
165
265
|
)
|
166
266
|
# Draw the edges of the graph
|
@@ -174,58 +274,73 @@ class NetworkPlotter:
|
|
174
274
|
|
175
275
|
def plot_subnetwork(
|
176
276
|
self,
|
177
|
-
nodes:
|
277
|
+
nodes: Union[List, Tuple, np.ndarray],
|
178
278
|
node_size: Union[int, np.ndarray] = 50,
|
179
|
-
edge_width: float = 1.0,
|
180
|
-
node_color: Union[str, np.ndarray] = "white",
|
181
|
-
node_edgecolor: str = "black",
|
182
|
-
edge_color: str = "black",
|
183
279
|
node_shape: str = "o",
|
280
|
+
node_edgewidth: float = 1.0,
|
281
|
+
edge_width: float = 1.0,
|
282
|
+
node_color: Union[str, List, Tuple, np.ndarray] = "white",
|
283
|
+
node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
|
284
|
+
edge_color: Union[str, List, Tuple, np.ndarray] = "black",
|
285
|
+
node_alpha: float = 1.0,
|
286
|
+
edge_alpha: float = 1.0,
|
184
287
|
) -> None:
|
185
288
|
"""Plot a subnetwork of selected nodes with customizable node and edge attributes.
|
186
289
|
|
187
290
|
Args:
|
188
|
-
nodes (list): List of node labels to include in the subnetwork.
|
291
|
+
nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
|
189
292
|
node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
|
190
|
-
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
191
|
-
node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
192
|
-
node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
|
193
|
-
edge_color (str, optional): Color of the edges. Defaults to "black".
|
194
293
|
node_shape (str, optional): Shape of the nodes. Defaults to "o".
|
294
|
+
node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
|
295
|
+
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
296
|
+
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
|
297
|
+
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
|
298
|
+
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
|
299
|
+
node_alpha (float, optional): Transparency for the nodes. Defaults to 1.0.
|
300
|
+
edge_alpha (float, optional): Transparency for the edges. Defaults to 1.0.
|
195
301
|
|
196
302
|
Raises:
|
197
303
|
ValueError: If no valid nodes are found in the network graph.
|
198
304
|
"""
|
199
|
-
#
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
subnetwork_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
|
204
|
-
subnetwork_node_edgecolor=node_edgecolor,
|
205
|
-
subnetwork_edge_color=edge_color,
|
206
|
-
subnet_node_shape=node_shape,
|
207
|
-
)
|
305
|
+
# Flatten nested lists of nodes, if necessary
|
306
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
307
|
+
nodes = [node for sublist in nodes for node in sublist]
|
308
|
+
|
208
309
|
# Filter to get node IDs and their coordinates
|
209
310
|
node_ids = [
|
210
|
-
self.graph.
|
311
|
+
self.graph.node_label_to_node_id_map.get(node)
|
211
312
|
for node in nodes
|
212
|
-
if node in self.graph.
|
313
|
+
if node in self.graph.node_label_to_node_id_map
|
213
314
|
]
|
214
315
|
if not node_ids:
|
215
316
|
raise ValueError("No nodes found in the network graph.")
|
216
317
|
|
318
|
+
# Check if node_color is a single color or a list of colors
|
319
|
+
if not isinstance(node_color, (str, tuple, np.ndarray)):
|
320
|
+
node_color = [
|
321
|
+
node_color[nodes.index(node)]
|
322
|
+
for node in nodes
|
323
|
+
if node in self.graph.node_label_to_node_id_map
|
324
|
+
]
|
325
|
+
|
326
|
+
# Convert colors to RGBA using the _to_rgba helper function
|
327
|
+
node_color = _to_rgba(node_color, node_alpha, num_repeats=len(node_ids))
|
328
|
+
node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(node_ids))
|
329
|
+
edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
|
330
|
+
|
217
331
|
# Get the coordinates of the filtered nodes
|
218
332
|
node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
|
333
|
+
|
219
334
|
# Draw the nodes in the subnetwork
|
220
335
|
nx.draw_networkx_nodes(
|
221
336
|
self.graph.network,
|
222
337
|
pos=node_coordinates,
|
223
338
|
nodelist=node_ids,
|
224
339
|
node_size=node_size,
|
225
|
-
node_color=node_color,
|
226
340
|
node_shape=node_shape,
|
227
|
-
|
341
|
+
node_color=node_color,
|
228
342
|
edgecolors=node_edgecolor,
|
343
|
+
linewidths=node_edgewidth,
|
229
344
|
ax=self.ax,
|
230
345
|
)
|
231
346
|
# Draw the edges between the specified nodes in the subnetwork
|
@@ -243,8 +358,11 @@ class NetworkPlotter:
|
|
243
358
|
levels: int = 5,
|
244
359
|
bandwidth: float = 0.8,
|
245
360
|
grid_size: int = 250,
|
246
|
-
|
247
|
-
|
361
|
+
color: Union[str, List, Tuple, np.ndarray] = "white",
|
362
|
+
linestyle: str = "solid",
|
363
|
+
linewidth: float = 1.5,
|
364
|
+
alpha: float = 1.0,
|
365
|
+
fill_alpha: float = 0.2,
|
248
366
|
) -> None:
|
249
367
|
"""Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.
|
250
368
|
|
@@ -252,99 +370,126 @@ class NetworkPlotter:
|
|
252
370
|
levels (int, optional): Number of contour levels to plot. Defaults to 5.
|
253
371
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
254
372
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
255
|
-
|
256
|
-
|
373
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors. Defaults to "white".
|
374
|
+
linestyle (str, optional): Line style for the contours. Defaults to "solid".
|
375
|
+
linewidth (float, optional): Line width for the contours. Defaults to 1.5.
|
376
|
+
alpha (float, optional): Transparency level of the contour lines. Defaults to 1.0.
|
377
|
+
fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
|
257
378
|
"""
|
258
379
|
# Log the contour plotting parameters
|
259
380
|
params.log_plotter(
|
260
381
|
contour_levels=levels,
|
261
382
|
contour_bandwidth=bandwidth,
|
262
383
|
contour_grid_size=grid_size,
|
384
|
+
contour_color=(
|
385
|
+
"custom" if isinstance(color, np.ndarray) else color
|
386
|
+
), # np.ndarray usually indicates custom colors
|
263
387
|
contour_alpha=alpha,
|
264
|
-
|
388
|
+
contour_fill_alpha=fill_alpha,
|
265
389
|
)
|
266
|
-
# Convert color string to RGBA array if necessary
|
267
|
-
if isinstance(color, str):
|
268
|
-
color = self.get_annotated_contour_colors(color=color)
|
269
390
|
|
391
|
+
# Ensure color is converted to RGBA with repetition matching the number of domains
|
392
|
+
color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map))
|
270
393
|
# Extract node coordinates from the network graph
|
271
394
|
node_coordinates = self.graph.node_coordinates
|
272
395
|
# Draw contours for each domain in the network
|
273
|
-
for idx, (_,
|
274
|
-
if len(
|
396
|
+
for idx, (_, node_ids) in enumerate(self.graph.domain_id_to_node_ids_map.items()):
|
397
|
+
if len(node_ids) > 1:
|
275
398
|
self._draw_kde_contour(
|
276
399
|
self.ax,
|
277
400
|
node_coordinates,
|
278
|
-
|
401
|
+
node_ids,
|
279
402
|
color=color[idx],
|
280
403
|
levels=levels,
|
281
404
|
bandwidth=bandwidth,
|
282
405
|
grid_size=grid_size,
|
406
|
+
linestyle=linestyle,
|
407
|
+
linewidth=linewidth,
|
283
408
|
alpha=alpha,
|
409
|
+
fill_alpha=fill_alpha,
|
284
410
|
)
|
285
411
|
|
286
412
|
def plot_subcontour(
|
287
413
|
self,
|
288
|
-
nodes:
|
414
|
+
nodes: Union[List, Tuple, np.ndarray],
|
289
415
|
levels: int = 5,
|
290
416
|
bandwidth: float = 0.8,
|
291
417
|
grid_size: int = 250,
|
292
|
-
|
293
|
-
|
418
|
+
color: Union[str, List, Tuple, np.ndarray] = "white",
|
419
|
+
linestyle: str = "solid",
|
420
|
+
linewidth: float = 1.5,
|
421
|
+
alpha: float = 1.0,
|
422
|
+
fill_alpha: float = 0.2,
|
294
423
|
) -> None:
|
295
|
-
"""Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
|
424
|
+
"""Plot a subcontour for a given set of nodes or a list of node sets using Kernel Density Estimation (KDE).
|
296
425
|
|
297
426
|
Args:
|
298
|
-
nodes (list): List of node labels to plot the contour for.
|
427
|
+
nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels to plot the contour for.
|
299
428
|
levels (int, optional): Number of contour levels to plot. Defaults to 5.
|
300
429
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
301
430
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
302
|
-
|
303
|
-
|
431
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
|
432
|
+
linestyle (str, optional): Line style for the contour. Defaults to "solid".
|
433
|
+
linewidth (float, optional): Line width for the contour. Defaults to 1.5.
|
434
|
+
alpha (float, optional): Transparency level of the contour lines. Defaults to 1.0.
|
435
|
+
fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
|
304
436
|
|
305
437
|
Raises:
|
306
438
|
ValueError: If no valid nodes are found in the network graph.
|
307
439
|
"""
|
308
|
-
#
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
#
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
440
|
+
# Check if nodes is a list of lists or a flat list
|
441
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
442
|
+
# If it's a list of lists, iterate over sublists
|
443
|
+
node_groups = nodes
|
444
|
+
else:
|
445
|
+
# If it's a flat list of nodes, treat it as a single group
|
446
|
+
node_groups = [nodes]
|
447
|
+
|
448
|
+
# Convert color to RGBA using the _to_rgba helper function
|
449
|
+
color_rgba = _to_rgba(color, alpha)
|
450
|
+
|
451
|
+
# Iterate over each group of nodes (either sublists or flat list)
|
452
|
+
for sublist in node_groups:
|
453
|
+
# Filter to get node IDs and their coordinates for each sublist
|
454
|
+
node_ids = [
|
455
|
+
self.graph.node_label_to_node_id_map.get(node)
|
456
|
+
for node in sublist
|
457
|
+
if node in self.graph.node_label_to_node_id_map
|
458
|
+
]
|
459
|
+
if not node_ids or len(node_ids) == 1:
|
460
|
+
raise ValueError(
|
461
|
+
"No nodes found in the network graph or insufficient nodes to plot."
|
462
|
+
)
|
324
463
|
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
464
|
+
# Draw the KDE contour for the specified nodes
|
465
|
+
node_coordinates = self.graph.node_coordinates
|
466
|
+
self._draw_kde_contour(
|
467
|
+
self.ax,
|
468
|
+
node_coordinates,
|
469
|
+
node_ids,
|
470
|
+
color=color_rgba,
|
471
|
+
levels=levels,
|
472
|
+
bandwidth=bandwidth,
|
473
|
+
grid_size=grid_size,
|
474
|
+
linestyle=linestyle,
|
475
|
+
linewidth=linewidth,
|
476
|
+
alpha=alpha,
|
477
|
+
fill_alpha=fill_alpha,
|
478
|
+
)
|
337
479
|
|
338
480
|
def _draw_kde_contour(
|
339
481
|
self,
|
340
482
|
ax: plt.Axes,
|
341
483
|
pos: np.ndarray,
|
342
|
-
nodes:
|
343
|
-
color: Union[str, np.ndarray],
|
484
|
+
nodes: List,
|
344
485
|
levels: int = 5,
|
345
486
|
bandwidth: float = 0.8,
|
346
487
|
grid_size: int = 250,
|
347
|
-
|
488
|
+
color: Union[str, np.ndarray] = "white",
|
489
|
+
linestyle: str = "solid",
|
490
|
+
linewidth: float = 1.5,
|
491
|
+
alpha: float = 1.0,
|
492
|
+
fill_alpha: float = 0.2,
|
348
493
|
) -> None:
|
349
494
|
"""Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
|
350
495
|
|
@@ -352,11 +497,14 @@ class NetworkPlotter:
|
|
352
497
|
ax (plt.Axes): The axis to draw the contour on.
|
353
498
|
pos (np.ndarray): Array of node positions (x, y).
|
354
499
|
nodes (list): List of node indices to include in the contour.
|
355
|
-
color (str or np.ndarray): Color for the contour.
|
356
500
|
levels (int, optional): Number of contour levels. Defaults to 5.
|
357
501
|
bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
|
358
502
|
grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
|
359
|
-
|
503
|
+
color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
|
504
|
+
linestyle (str, optional): Line style for the contour. Defaults to "solid".
|
505
|
+
linewidth (float, optional): Line width for the contour. Defaults to 1.5.
|
506
|
+
alpha (float, optional): Transparency level for the contour lines. Defaults to 1.0.
|
507
|
+
fill_alpha (float, optional): Transparency level for the contour fill. Defaults to 0.2.
|
360
508
|
"""
|
361
509
|
# Extract the positions of the specified nodes
|
362
510
|
points = np.array([pos[n] for n in nodes])
|
@@ -382,141 +530,219 @@ class NetworkPlotter:
|
|
382
530
|
min_density, max_density = z.min(), z.max()
|
383
531
|
contour_levels = np.linspace(min_density, max_density, levels)[1:]
|
384
532
|
contour_colors = [color for _ in range(levels - 1)]
|
385
|
-
|
386
|
-
|
387
|
-
if alpha > 0:
|
533
|
+
# Plot the filled contours using fill_alpha for transparency
|
534
|
+
if fill_alpha > 0:
|
388
535
|
ax.contourf(
|
389
536
|
x,
|
390
537
|
y,
|
391
538
|
z,
|
392
539
|
levels=contour_levels,
|
393
540
|
colors=contour_colors,
|
394
|
-
alpha=alpha,
|
395
|
-
extend="neither",
|
396
541
|
antialiased=True,
|
542
|
+
alpha=fill_alpha,
|
397
543
|
)
|
398
544
|
|
399
|
-
# Plot the contour lines
|
400
|
-
c = ax.contour(
|
545
|
+
# Plot the contour lines with the specified alpha for transparency
|
546
|
+
c = ax.contour(
|
547
|
+
x,
|
548
|
+
y,
|
549
|
+
z,
|
550
|
+
levels=contour_levels,
|
551
|
+
colors=contour_colors,
|
552
|
+
linestyles=linestyle,
|
553
|
+
linewidths=linewidth,
|
554
|
+
alpha=alpha,
|
555
|
+
)
|
556
|
+
# Set linewidth for the contour lines to 0 for levels other than the base level
|
401
557
|
for i in range(1, len(contour_levels)):
|
402
558
|
c.collections[i].set_linewidth(0)
|
403
559
|
|
404
560
|
def plot_labels(
|
405
561
|
self,
|
406
|
-
|
562
|
+
scale: float = 1.05,
|
407
563
|
offset: float = 0.10,
|
408
564
|
font: str = "Arial",
|
409
565
|
fontsize: int = 10,
|
410
|
-
fontcolor: Union[str, np.ndarray] = "black",
|
566
|
+
fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
|
567
|
+
fontalpha: float = 1.0,
|
411
568
|
arrow_linewidth: float = 1,
|
412
|
-
|
569
|
+
arrow_style: str = "->",
|
570
|
+
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
571
|
+
arrow_alpha: float = 1.0,
|
572
|
+
arrow_base_shrink: float = 0.0,
|
573
|
+
arrow_tip_shrink: float = 0.0,
|
413
574
|
max_labels: Union[int, None] = None,
|
414
575
|
max_words: int = 10,
|
415
576
|
min_words: int = 1,
|
416
577
|
max_word_length: int = 20,
|
417
578
|
min_word_length: int = 1,
|
418
|
-
words_to_omit: Union[List
|
579
|
+
words_to_omit: Union[List, None] = None,
|
580
|
+
overlay_ids: bool = False,
|
581
|
+
ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
|
582
|
+
ids_to_replace: Union[Dict, None] = None,
|
419
583
|
) -> None:
|
420
584
|
"""Annotate the network graph with labels for different domains, positioned around the network for clarity.
|
421
585
|
|
422
586
|
Args:
|
423
|
-
|
587
|
+
scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
|
424
588
|
offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
|
425
589
|
font (str, optional): Font name for the labels. Defaults to "Arial".
|
426
590
|
fontsize (int, optional): Font size for the labels. Defaults to 10.
|
427
|
-
fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
|
591
|
+
fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
|
592
|
+
fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
|
428
593
|
arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
|
429
|
-
|
594
|
+
arrow_style (str, optional): Style of the arrows pointing to centroids. Defaults to "->".
|
595
|
+
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
|
596
|
+
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
597
|
+
arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
|
598
|
+
arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
|
430
599
|
max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
|
431
600
|
max_words (int, optional): Maximum number of words in a label. Defaults to 10.
|
432
601
|
min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
|
433
602
|
max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
|
434
603
|
min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
|
435
|
-
words_to_omit (List
|
604
|
+
words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
|
605
|
+
overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
|
606
|
+
ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
|
607
|
+
you can set `overlay_ids=True`. Defaults to None.
|
608
|
+
ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be space-separated words.
|
609
|
+
If provided, the custom labels will replace the default domain terms. To discover domain IDs, you can set `overlay_ids=True`.
|
610
|
+
Defaults to None.
|
611
|
+
|
612
|
+
Raises:
|
613
|
+
ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
|
436
614
|
"""
|
437
615
|
# Log the plotting parameters
|
438
616
|
params.log_plotter(
|
439
|
-
label_perimeter_scale=
|
617
|
+
label_perimeter_scale=scale,
|
440
618
|
label_offset=offset,
|
441
619
|
label_font=font,
|
442
620
|
label_fontsize=fontsize,
|
443
|
-
label_fontcolor=
|
621
|
+
label_fontcolor=(
|
622
|
+
"custom" if isinstance(fontcolor, np.ndarray) else fontcolor
|
623
|
+
), # np.ndarray usually indicates custom colors
|
624
|
+
label_fontalpha=fontalpha,
|
444
625
|
label_arrow_linewidth=arrow_linewidth,
|
626
|
+
label_arrow_style=arrow_style,
|
445
627
|
label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
628
|
+
label_arrow_alpha=arrow_alpha,
|
629
|
+
label_arrow_base_shrink=arrow_base_shrink,
|
630
|
+
label_arrow_tip_shrink=arrow_tip_shrink,
|
446
631
|
label_max_labels=max_labels,
|
447
632
|
label_max_words=max_words,
|
448
633
|
label_min_words=min_words,
|
449
634
|
label_max_word_length=max_word_length,
|
450
635
|
label_min_word_length=min_word_length,
|
451
636
|
label_words_to_omit=words_to_omit,
|
637
|
+
label_overlay_ids=overlay_ids,
|
638
|
+
label_ids_to_keep=ids_to_keep,
|
639
|
+
label_ids_to_replace=ids_to_replace,
|
640
|
+
)
|
641
|
+
|
642
|
+
# Set max_labels to the total number of domains if not provided (None)
|
643
|
+
if max_labels is None:
|
644
|
+
max_labels = len(self.graph.domain_id_to_node_ids_map)
|
645
|
+
|
646
|
+
# Convert colors to RGBA using the _to_rgba helper function
|
647
|
+
fontcolor = _to_rgba(
|
648
|
+
fontcolor, fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
649
|
+
)
|
650
|
+
arrow_color = _to_rgba(
|
651
|
+
arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
452
652
|
)
|
453
653
|
|
454
|
-
# Convert color strings to RGBA arrays if necessary
|
455
|
-
if isinstance(fontcolor, str):
|
456
|
-
fontcolor = self.get_annotated_label_colors(color=fontcolor)
|
457
|
-
if isinstance(arrow_color, str):
|
458
|
-
arrow_color = self.get_annotated_label_colors(color=arrow_color)
|
459
654
|
# Normalize words_to_omit to lowercase
|
460
655
|
if words_to_omit:
|
461
656
|
words_to_omit = set(word.lower() for word in words_to_omit)
|
462
657
|
|
463
658
|
# Calculate the center and radius of the network
|
464
659
|
domain_centroids = {}
|
465
|
-
for
|
466
|
-
if
|
467
|
-
domain_centroids[
|
660
|
+
for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
661
|
+
if node_ids: # Skip if the domain has no nodes
|
662
|
+
domain_centroids[domain_id] = self._calculate_domain_centroid(node_ids)
|
468
663
|
|
469
|
-
# Initialize
|
664
|
+
# Initialize dictionaries and lists for valid indices
|
470
665
|
valid_indices = []
|
471
666
|
filtered_domain_centroids = {}
|
472
667
|
filtered_domain_terms = {}
|
473
|
-
#
|
474
|
-
|
475
|
-
#
|
476
|
-
|
477
|
-
#
|
478
|
-
if
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
# Trim to max_words
|
483
|
-
terms = terms[:max_words]
|
484
|
-
# Check if the domain passes the word count condition
|
485
|
-
if len(terms) >= min_words:
|
486
|
-
# Add to filtered_domain_centroids
|
487
|
-
filtered_domain_centroids[domain] = centroid
|
488
|
-
# Store the filtered and trimmed terms
|
489
|
-
filtered_domain_terms[domain] = " ".join(terms)
|
490
|
-
# Keep track of the valid index - used for fontcolor and arrow_color
|
491
|
-
valid_indices.append(idx)
|
492
|
-
|
493
|
-
# If max_labels is specified and less than the available labels
|
494
|
-
if max_labels is not None and max_labels < len(filtered_domain_centroids):
|
495
|
-
step = len(filtered_domain_centroids) / max_labels
|
496
|
-
selected_indices = [int(i * step) for i in range(max_labels)]
|
497
|
-
# Filter the centroids, terms, and valid_indices to only use the selected indices
|
498
|
-
filtered_domain_centroids = {
|
499
|
-
k: v
|
500
|
-
for i, (k, v) in enumerate(filtered_domain_centroids.items())
|
501
|
-
if i in selected_indices
|
502
|
-
}
|
503
|
-
filtered_domain_terms = {
|
504
|
-
k: v
|
505
|
-
for i, (k, v) in enumerate(filtered_domain_terms.items())
|
506
|
-
if i in selected_indices
|
507
|
-
}
|
508
|
-
# Update valid_indices to match selected indices
|
509
|
-
valid_indices = [valid_indices[i] for i in selected_indices]
|
668
|
+
# Handle the ids_to_keep logic
|
669
|
+
if ids_to_keep:
|
670
|
+
# Convert ids_to_keep to remove accidental duplicates
|
671
|
+
ids_to_keep = set(ids_to_keep)
|
672
|
+
# Check if the number of provided ids_to_keep exceeds max_labels
|
673
|
+
if max_labels is not None and len(ids_to_keep) > max_labels:
|
674
|
+
raise ValueError(
|
675
|
+
f"Number of provided IDs ({len(ids_to_keep)}) exceeds max_labels ({max_labels})."
|
676
|
+
)
|
510
677
|
|
511
|
-
|
512
|
-
|
513
|
-
|
678
|
+
# Process the specified IDs first
|
679
|
+
for domain in ids_to_keep:
|
680
|
+
if (
|
681
|
+
domain in self.graph.domain_id_to_domain_terms_map
|
682
|
+
and domain in domain_centroids
|
683
|
+
):
|
684
|
+
# Handle ids_to_replace logic here for ids_to_keep
|
685
|
+
if ids_to_replace and domain in ids_to_replace:
|
686
|
+
terms = ids_to_replace[domain].split(" ")
|
687
|
+
else:
|
688
|
+
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
689
|
+
|
690
|
+
# Apply words_to_omit, word length constraints, and max_words
|
691
|
+
if words_to_omit:
|
692
|
+
terms = [term for term in terms if term.lower() not in words_to_omit]
|
693
|
+
terms = [
|
694
|
+
term for term in terms if min_word_length <= len(term) <= max_word_length
|
695
|
+
]
|
696
|
+
terms = terms[:max_words]
|
697
|
+
|
698
|
+
# Check if the domain passes the word count condition
|
699
|
+
if len(terms) >= min_words:
|
700
|
+
filtered_domain_centroids[domain] = domain_centroids[domain]
|
701
|
+
filtered_domain_terms[domain] = " ".join(terms)
|
702
|
+
valid_indices.append(
|
703
|
+
list(domain_centroids.keys()).index(domain)
|
704
|
+
) # Track the valid index
|
705
|
+
|
706
|
+
# Calculate remaining labels to plot after processing ids_to_keep
|
707
|
+
remaining_labels = (
|
708
|
+
max_labels - len(ids_to_keep) if ids_to_keep and max_labels else max_labels
|
514
709
|
)
|
515
|
-
#
|
710
|
+
# Process remaining domains to fill in additional labels, if there are slots left
|
711
|
+
if remaining_labels and remaining_labels > 0:
|
712
|
+
for idx, (domain, centroid) in enumerate(domain_centroids.items()):
|
713
|
+
if ids_to_keep and domain in ids_to_keep:
|
714
|
+
continue # Skip domains already handled by ids_to_keep
|
715
|
+
|
716
|
+
# Handle ids_to_replace logic first
|
717
|
+
if ids_to_replace and domain in ids_to_replace:
|
718
|
+
terms = ids_to_replace[domain].split(" ")
|
719
|
+
else:
|
720
|
+
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
721
|
+
|
722
|
+
# Apply words_to_omit, word length constraints, and max_words
|
723
|
+
if words_to_omit:
|
724
|
+
terms = [term for term in terms if term.lower() not in words_to_omit]
|
725
|
+
|
726
|
+
terms = [term for term in terms if min_word_length <= len(term) <= max_word_length]
|
727
|
+
terms = terms[:max_words]
|
728
|
+
# Check if the domain passes the word count condition
|
729
|
+
if len(terms) >= min_words:
|
730
|
+
filtered_domain_centroids[domain] = centroid
|
731
|
+
filtered_domain_terms[domain] = " ".join(terms)
|
732
|
+
valid_indices.append(idx) # Track the valid index
|
733
|
+
|
734
|
+
# Stop once we've reached the max_labels limit
|
735
|
+
if len(filtered_domain_centroids) >= max_labels:
|
736
|
+
break
|
737
|
+
|
738
|
+
# Calculate the bounding box around the network
|
739
|
+
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
740
|
+
# Calculate the best positions for labels
|
516
741
|
best_label_positions = _calculate_best_label_positions(
|
517
742
|
filtered_domain_centroids, center, radius, offset
|
518
743
|
)
|
519
|
-
|
744
|
+
|
745
|
+
# Annotate the network with labels
|
520
746
|
for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
|
521
747
|
centroid = filtered_domain_centroids[domain]
|
522
748
|
annotations = filtered_domain_terms[domain].split(" ")[:max_words]
|
@@ -530,63 +756,79 @@ class NetworkPlotter:
|
|
530
756
|
fontsize=fontsize,
|
531
757
|
fontname=font,
|
532
758
|
color=fontcolor[idx],
|
533
|
-
arrowprops=dict(
|
759
|
+
arrowprops=dict(
|
760
|
+
arrowstyle=arrow_style,
|
761
|
+
color=arrow_color[idx],
|
762
|
+
linewidth=arrow_linewidth,
|
763
|
+
shrinkA=arrow_base_shrink,
|
764
|
+
shrinkB=arrow_tip_shrink,
|
765
|
+
),
|
534
766
|
)
|
767
|
+
# Overlay domain ID at the centroid if requested
|
768
|
+
if overlay_ids:
|
769
|
+
self.ax.text(
|
770
|
+
centroid[0],
|
771
|
+
centroid[1],
|
772
|
+
domain,
|
773
|
+
ha="center",
|
774
|
+
va="center",
|
775
|
+
fontsize=fontsize,
|
776
|
+
fontname=font,
|
777
|
+
color=fontcolor[idx],
|
778
|
+
alpha=fontalpha,
|
779
|
+
)
|
535
780
|
|
536
781
|
def plot_sublabel(
|
537
782
|
self,
|
538
|
-
nodes:
|
783
|
+
nodes: Union[List, Tuple, np.ndarray],
|
539
784
|
label: str,
|
540
785
|
radial_position: float = 0.0,
|
541
|
-
|
786
|
+
scale: float = 1.05,
|
542
787
|
offset: float = 0.10,
|
543
788
|
font: str = "Arial",
|
544
789
|
fontsize: int = 10,
|
545
|
-
fontcolor: str = "black",
|
790
|
+
fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
|
791
|
+
fontalpha: float = 1.0,
|
546
792
|
arrow_linewidth: float = 1,
|
547
|
-
|
793
|
+
arrow_style: str = "->",
|
794
|
+
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
795
|
+
arrow_alpha: float = 1.0,
|
796
|
+
arrow_base_shrink: float = 0.0,
|
797
|
+
arrow_tip_shrink: float = 0.0,
|
548
798
|
) -> None:
|
549
|
-
"""Annotate the network graph with a
|
799
|
+
"""Annotate the network graph with a label for the given nodes, with one arrow pointing to each centroid of sublists of nodes.
|
550
800
|
|
551
801
|
Args:
|
552
|
-
nodes (
|
802
|
+
nodes (list, tuple, or np.ndarray): List of node labels or list of lists of node labels.
|
553
803
|
label (str): The label to be annotated on the network.
|
554
804
|
radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
|
555
|
-
|
805
|
+
scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
|
556
806
|
offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
|
557
807
|
font (str, optional): Font name for the label. Defaults to "Arial".
|
558
808
|
fontsize (int, optional): Font size for the label. Defaults to 10.
|
559
|
-
fontcolor (str, optional): Color of the label text. Defaults to "black".
|
809
|
+
fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
|
810
|
+
fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
|
560
811
|
arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
|
561
|
-
|
812
|
+
arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
|
813
|
+
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
|
814
|
+
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
815
|
+
arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
|
816
|
+
arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
|
562
817
|
"""
|
563
|
-
#
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
)
|
574
|
-
|
575
|
-
# Map node labels to IDs
|
576
|
-
node_ids = [
|
577
|
-
self.graph.node_label_to_id_map.get(node)
|
578
|
-
for node in nodes
|
579
|
-
if node in self.graph.node_label_to_id_map
|
580
|
-
]
|
581
|
-
if not node_ids or len(node_ids) == 1:
|
582
|
-
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
818
|
+
# Check if nodes is a list of lists or a flat list
|
819
|
+
if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
|
820
|
+
# If it's a list of lists, iterate over sublists
|
821
|
+
node_groups = nodes
|
822
|
+
else:
|
823
|
+
# If it's a flat list of nodes, treat it as a single group
|
824
|
+
node_groups = [nodes]
|
825
|
+
|
826
|
+
# Convert fontcolor and arrow_color to RGBA
|
827
|
+
fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
|
828
|
+
arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
|
583
829
|
|
584
|
-
# Calculate the centroid of the provided nodes
|
585
|
-
centroid = self._calculate_domain_centroid(node_ids)
|
586
830
|
# Calculate the bounding box around the network
|
587
|
-
center, radius = _calculate_bounding_box(
|
588
|
-
self.graph.node_coordinates, radius_margin=perimeter_scale
|
589
|
-
)
|
831
|
+
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
590
832
|
# Convert radial position to radians, adjusting for a 90-degree rotation
|
591
833
|
radial_radians = np.deg2rad(radial_position - 90)
|
592
834
|
label_position = (
|
@@ -594,21 +836,42 @@ class NetworkPlotter:
|
|
594
836
|
center[1] + (radius + offset) * np.sin(radial_radians),
|
595
837
|
)
|
596
838
|
|
597
|
-
#
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
)
|
839
|
+
# Iterate over each group of nodes (either sublists or flat list)
|
840
|
+
for sublist in node_groups:
|
841
|
+
# Map node labels to IDs
|
842
|
+
node_ids = [
|
843
|
+
self.graph.node_label_to_node_id_map.get(node)
|
844
|
+
for node in sublist
|
845
|
+
if node in self.graph.node_label_to_node_id_map
|
846
|
+
]
|
847
|
+
if not node_ids or len(node_ids) == 1:
|
848
|
+
raise ValueError(
|
849
|
+
"No nodes found in the network graph or insufficient nodes to plot."
|
850
|
+
)
|
610
851
|
|
611
|
-
|
852
|
+
# Calculate the centroid of the provided nodes in this sublist
|
853
|
+
centroid = self._calculate_domain_centroid(node_ids)
|
854
|
+
# Annotate the network with the label and an arrow pointing to each centroid
|
855
|
+
self.ax.annotate(
|
856
|
+
label,
|
857
|
+
xy=centroid,
|
858
|
+
xytext=label_position,
|
859
|
+
textcoords="data",
|
860
|
+
ha="center",
|
861
|
+
va="center",
|
862
|
+
fontsize=fontsize,
|
863
|
+
fontname=font,
|
864
|
+
color=fontcolor_rgba,
|
865
|
+
arrowprops=dict(
|
866
|
+
arrowstyle=arrow_style,
|
867
|
+
color=arrow_color_rgba,
|
868
|
+
linewidth=arrow_linewidth,
|
869
|
+
shrinkA=arrow_base_shrink,
|
870
|
+
shrinkB=arrow_tip_shrink,
|
871
|
+
),
|
872
|
+
)
|
873
|
+
|
874
|
+
def _calculate_domain_centroid(self, nodes: List) -> tuple:
|
612
875
|
"""Calculate the most centrally located node in .
|
613
876
|
|
614
877
|
Args:
|
@@ -630,111 +893,193 @@ class NetworkPlotter:
|
|
630
893
|
return domain_central_node
|
631
894
|
|
632
895
|
def get_annotated_node_colors(
|
633
|
-
self,
|
896
|
+
self,
|
897
|
+
cmap: str = "gist_rainbow",
|
898
|
+
color: Union[str, None] = None,
|
899
|
+
min_scale: float = 0.8,
|
900
|
+
max_scale: float = 1.0,
|
901
|
+
scale_factor: float = 1.0,
|
902
|
+
alpha: float = 1.0,
|
903
|
+
nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
|
904
|
+
nonenriched_alpha: float = 1.0,
|
905
|
+
random_seed: int = 888,
|
634
906
|
) -> np.ndarray:
|
635
907
|
"""Adjust the colors of nodes in the network graph based on enrichment.
|
636
908
|
|
637
909
|
Args:
|
638
|
-
|
910
|
+
cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
|
911
|
+
color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
|
912
|
+
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
913
|
+
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
914
|
+
scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
|
915
|
+
alpha (float, optional): Alpha value for enriched nodes. Defaults to 1.0.
|
916
|
+
nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
|
917
|
+
nonenriched_alpha (float, optional): Alpha value for non-enriched nodes. Defaults to 1.0.
|
639
918
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
640
|
-
**kwargs: Additional keyword arguments for `get_domain_colors`.
|
641
919
|
|
642
920
|
Returns:
|
643
921
|
np.ndarray: Array of RGBA colors adjusted for enrichment status.
|
644
922
|
"""
|
645
|
-
# Get the initial domain colors for each node
|
646
|
-
network_colors = self.graph.get_domain_colors(
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
923
|
+
# Get the initial domain colors for each node, which are returned as RGBA
|
924
|
+
network_colors = self.graph.get_domain_colors(
|
925
|
+
cmap=cmap,
|
926
|
+
color=color,
|
927
|
+
min_scale=min_scale,
|
928
|
+
max_scale=max_scale,
|
929
|
+
scale_factor=scale_factor,
|
930
|
+
random_seed=random_seed,
|
931
|
+
)
|
932
|
+
# Apply the alpha value for enriched nodes
|
933
|
+
network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
|
934
|
+
# Convert the non-enriched color to RGBA using the _to_rgba helper function
|
935
|
+
nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
|
936
|
+
# Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
|
652
937
|
adjusted_network_colors = np.where(
|
653
|
-
np.all(network_colors == 0, axis=1, keepdims=True),
|
654
|
-
np.array([nonenriched_color]),
|
655
|
-
network_colors,
|
938
|
+
np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
|
939
|
+
np.array([nonenriched_color]), # Apply the non-enriched color with alpha
|
940
|
+
network_colors, # Keep the original colors for enriched nodes
|
656
941
|
)
|
657
942
|
return adjusted_network_colors
|
658
943
|
|
659
944
|
def get_annotated_node_sizes(
|
660
|
-
self,
|
945
|
+
self, enriched_size: int = 50, nonenriched_size: int = 25
|
661
946
|
) -> np.ndarray:
|
662
947
|
"""Adjust the sizes of nodes in the network graph based on whether they are enriched or not.
|
663
948
|
|
664
949
|
Args:
|
665
|
-
|
666
|
-
|
950
|
+
enriched_size (int): Size for enriched nodes. Defaults to 50.
|
951
|
+
nonenriched_size (int): Size for non-enriched nodes. Defaults to 25.
|
667
952
|
|
668
953
|
Returns:
|
669
954
|
np.ndarray: Array of node sizes, with enriched nodes larger than non-enriched ones.
|
670
955
|
"""
|
671
|
-
# Merge all enriched nodes from the
|
956
|
+
# Merge all enriched nodes from the domain_id_to_node_ids_map dictionary
|
672
957
|
enriched_nodes = set()
|
673
|
-
for _,
|
674
|
-
enriched_nodes.update(
|
958
|
+
for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
959
|
+
enriched_nodes.update(node_ids)
|
675
960
|
|
676
961
|
# Initialize all node sizes to the non-enriched size
|
677
|
-
node_sizes = np.full(len(self.graph.network.nodes),
|
962
|
+
node_sizes = np.full(len(self.graph.network.nodes), nonenriched_size)
|
678
963
|
# Set the size for enriched nodes
|
679
964
|
for node in enriched_nodes:
|
680
965
|
if node in self.graph.network.nodes:
|
681
|
-
node_sizes[node] =
|
966
|
+
node_sizes[node] = enriched_size
|
682
967
|
|
683
968
|
return node_sizes
|
684
969
|
|
685
|
-
def get_annotated_contour_colors(
|
686
|
-
|
970
|
+
def get_annotated_contour_colors(
|
971
|
+
self,
|
972
|
+
cmap: str = "gist_rainbow",
|
973
|
+
color: Union[str, None] = None,
|
974
|
+
min_scale: float = 0.8,
|
975
|
+
max_scale: float = 1.0,
|
976
|
+
scale_factor: float = 1.0,
|
977
|
+
random_seed: int = 888,
|
978
|
+
) -> np.ndarray:
|
979
|
+
"""Get colors for the contours based on node annotations or a specified colormap.
|
687
980
|
|
688
981
|
Args:
|
689
|
-
|
690
|
-
|
982
|
+
cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
|
983
|
+
color (str or None, optional): Color to use for the contours. If None, the colormap will be used. Defaults to None.
|
984
|
+
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
985
|
+
Controls the dimmest colors. Defaults to 0.8.
|
986
|
+
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
987
|
+
Controls the brightest colors. Defaults to 1.0.
|
988
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
|
989
|
+
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
990
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
691
991
|
|
692
992
|
Returns:
|
693
993
|
np.ndarray: Array of RGBA colors for contour annotations.
|
694
994
|
"""
|
695
|
-
return self._get_annotated_domain_colors(
|
995
|
+
return self._get_annotated_domain_colors(
|
996
|
+
cmap=cmap,
|
997
|
+
color=color,
|
998
|
+
min_scale=min_scale,
|
999
|
+
max_scale=max_scale,
|
1000
|
+
scale_factor=scale_factor,
|
1001
|
+
random_seed=random_seed,
|
1002
|
+
)
|
696
1003
|
|
697
|
-
def get_annotated_label_colors(
|
698
|
-
|
1004
|
+
def get_annotated_label_colors(
|
1005
|
+
self,
|
1006
|
+
cmap: str = "gist_rainbow",
|
1007
|
+
color: Union[str, None] = None,
|
1008
|
+
min_scale: float = 0.8,
|
1009
|
+
max_scale: float = 1.0,
|
1010
|
+
scale_factor: float = 1.0,
|
1011
|
+
random_seed: int = 888,
|
1012
|
+
) -> np.ndarray:
|
1013
|
+
"""Get colors for the labels based on node annotations or a specified colormap.
|
699
1014
|
|
700
1015
|
Args:
|
701
|
-
|
702
|
-
|
1016
|
+
cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
|
1017
|
+
color (str or None, optional): Color to use for the labels. If None, the colormap will be used. Defaults to None.
|
1018
|
+
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
1019
|
+
Controls the dimmest colors. Defaults to 0.8.
|
1020
|
+
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
1021
|
+
Controls the brightest colors. Defaults to 1.0.
|
1022
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
|
1023
|
+
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
1024
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
703
1025
|
|
704
1026
|
Returns:
|
705
1027
|
np.ndarray: Array of RGBA colors for label annotations.
|
706
1028
|
"""
|
707
|
-
return self._get_annotated_domain_colors(
|
1029
|
+
return self._get_annotated_domain_colors(
|
1030
|
+
cmap=cmap,
|
1031
|
+
color=color,
|
1032
|
+
min_scale=min_scale,
|
1033
|
+
max_scale=max_scale,
|
1034
|
+
scale_factor=scale_factor,
|
1035
|
+
random_seed=random_seed,
|
1036
|
+
)
|
708
1037
|
|
709
1038
|
def _get_annotated_domain_colors(
|
710
|
-
self,
|
1039
|
+
self,
|
1040
|
+
cmap: str = "gist_rainbow",
|
1041
|
+
color: Union[str, None] = None,
|
1042
|
+
min_scale: float = 0.8,
|
1043
|
+
max_scale: float = 1.0,
|
1044
|
+
scale_factor: float = 1.0,
|
1045
|
+
random_seed: int = 888,
|
711
1046
|
) -> np.ndarray:
|
712
|
-
"""Get colors for the domains based on node annotations.
|
1047
|
+
"""Get colors for the domains based on node annotations, or use a specified color.
|
713
1048
|
|
714
1049
|
Args:
|
715
|
-
|
716
|
-
|
717
|
-
|
1050
|
+
cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
|
1051
|
+
color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
|
1052
|
+
min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
|
1053
|
+
Defaults to 0.8.
|
1054
|
+
max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
|
1055
|
+
Defaults to 1.0.
|
1056
|
+
scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
|
1057
|
+
enrichment. Higher values increase the contrast. Defaults to 1.0.
|
1058
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
718
1059
|
|
719
1060
|
Returns:
|
720
1061
|
np.ndarray: Array of RGBA colors for each domain.
|
721
1062
|
"""
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
1063
|
+
# Generate domain colors based on the enrichment data
|
1064
|
+
node_colors = self.graph.get_domain_colors(
|
1065
|
+
cmap=cmap,
|
1066
|
+
color=color,
|
1067
|
+
min_scale=min_scale,
|
1068
|
+
max_scale=max_scale,
|
1069
|
+
scale_factor=scale_factor,
|
1070
|
+
random_seed=random_seed,
|
1071
|
+
)
|
729
1072
|
annotated_colors = []
|
730
|
-
for _,
|
731
|
-
if len(
|
732
|
-
# For domains
|
733
|
-
domain_colors = np.array([node_colors[node] for node in
|
734
|
-
brightest_color = domain_colors[
|
1073
|
+
for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
1074
|
+
if len(node_ids) > 1:
|
1075
|
+
# For multi-node domains, choose the brightest color based on RGB sum
|
1076
|
+
domain_colors = np.array([node_colors[node] for node in node_ids])
|
1077
|
+
brightest_color = domain_colors[
|
1078
|
+
np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
|
1079
|
+
]
|
735
1080
|
annotated_colors.append(brightest_color)
|
736
1081
|
else:
|
737
|
-
#
|
1082
|
+
# Single-node domains default to white (RGBA)
|
738
1083
|
default_color = np.array([1.0, 1.0, 1.0, 1.0])
|
739
1084
|
annotated_colors.append(default_color)
|
740
1085
|
|
@@ -761,6 +1106,65 @@ class NetworkPlotter:
|
|
761
1106
|
plt.show(*args, **kwargs)
|
762
1107
|
|
763
1108
|
|
1109
|
+
def _to_rgba(
|
1110
|
+
color: Union[str, List, Tuple, np.ndarray],
|
1111
|
+
alpha: float = 1.0,
|
1112
|
+
num_repeats: Union[int, None] = None,
|
1113
|
+
) -> np.ndarray:
|
1114
|
+
"""Convert a color or array of colors to RGBA format, applying alpha only if the color is RGB.
|
1115
|
+
|
1116
|
+
Args:
|
1117
|
+
color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
|
1118
|
+
alpha (float, optional): Alpha value (transparency) to apply if the color is in RGB format. Defaults to 1.0.
|
1119
|
+
num_repeats (int or None, optional): If provided, the color will be repeated this many times. Defaults to None.
|
1120
|
+
|
1121
|
+
Returns:
|
1122
|
+
np.ndarray: The RGBA color or array of RGBA colors.
|
1123
|
+
"""
|
1124
|
+
# Handle single color case (string, RGB, or RGBA)
|
1125
|
+
if isinstance(color, str) or (
|
1126
|
+
isinstance(color, (list, tuple, np.ndarray)) and len(color) in [3, 4]
|
1127
|
+
):
|
1128
|
+
rgba_color = np.array(mcolors.to_rgba(color))
|
1129
|
+
# Only set alpha if the input is an RGB color or a string (not RGBA)
|
1130
|
+
if len(rgba_color) == 4 and (
|
1131
|
+
len(color) == 3 or isinstance(color, str)
|
1132
|
+
): # If it's RGB or a string, set the alpha
|
1133
|
+
rgba_color[3] = alpha
|
1134
|
+
|
1135
|
+
# Repeat the color if num_repeats argument is provided
|
1136
|
+
if num_repeats is not None:
|
1137
|
+
return np.array([rgba_color] * num_repeats)
|
1138
|
+
|
1139
|
+
return rgba_color
|
1140
|
+
|
1141
|
+
# Handle array of colors case (including strings, RGB, and RGBA)
|
1142
|
+
elif isinstance(color, (list, tuple, np.ndarray)):
|
1143
|
+
rgba_colors = []
|
1144
|
+
for c in color:
|
1145
|
+
# Ensure each element is either a valid string or a list/tuple of length 3 (RGB) or 4 (RGBA)
|
1146
|
+
if isinstance(c, str) or (
|
1147
|
+
isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]
|
1148
|
+
):
|
1149
|
+
rgba_c = np.array(mcolors.to_rgba(c))
|
1150
|
+
# Apply alpha only to RGB colors (not RGBA) and strings
|
1151
|
+
if len(rgba_c) == 4 and (len(c) == 3 or isinstance(c, str)):
|
1152
|
+
rgba_c[3] = alpha
|
1153
|
+
|
1154
|
+
rgba_colors.append(rgba_c)
|
1155
|
+
else:
|
1156
|
+
raise ValueError(f"Invalid color: {c}. Must be a valid RGB/RGBA or string color.")
|
1157
|
+
|
1158
|
+
# Repeat the colors if num_repeats argument is provided
|
1159
|
+
if num_repeats is not None and len(rgba_colors) == 1:
|
1160
|
+
return np.array([rgba_colors[0]] * num_repeats)
|
1161
|
+
|
1162
|
+
return np.array(rgba_colors)
|
1163
|
+
|
1164
|
+
else:
|
1165
|
+
raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
|
1166
|
+
|
1167
|
+
|
764
1168
|
def _is_connected(z: np.ndarray) -> bool:
|
765
1169
|
"""Determine if a thresholded grid represents a single, connected component.
|
766
1170
|
|