risk-network 0.0.6b0__py3-none-any.whl → 0.0.6b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/io.py +15 -3
- risk/neighborhoods/community.py +7 -3
- risk/neighborhoods/neighborhoods.py +2 -2
- risk/network/graph.py +26 -21
- risk/network/io.py +11 -0
- risk/network/plot.py +430 -202
- risk/risk.py +6 -18
- {risk_network-0.0.6b0.dist-info → risk_network-0.0.6b2.dist-info}/METADATA +1 -1
- {risk_network-0.0.6b0.dist-info → risk_network-0.0.6b2.dist-info}/RECORD +13 -13
- {risk_network-0.0.6b0.dist-info → risk_network-0.0.6b2.dist-info}/LICENSE +0 -0
- {risk_network-0.0.6b0.dist-info → risk_network-0.0.6b2.dist-info}/WHEEL +0 -0
- {risk_network-0.0.6b0.dist-info → risk_network-0.0.6b2.dist-info}/top_level.txt +0 -0
risk/network/plot.py
CHANGED
@@ -5,7 +5,6 @@ risk/network/plot
|
|
5
5
|
|
6
6
|
from typing import Any, Dict, List, Tuple, Union
|
7
7
|
|
8
|
-
import matplotlib
|
9
8
|
import matplotlib.colors as mcolors
|
10
9
|
import matplotlib.pyplot as plt
|
11
10
|
import networkx as nx
|
@@ -18,66 +17,42 @@ from risk.network.graph import NetworkGraph
|
|
18
17
|
|
19
18
|
|
20
19
|
class NetworkPlotter:
|
21
|
-
"""A class
|
20
|
+
"""A class for visualizing network graphs with customizable options.
|
22
21
|
|
23
|
-
The NetworkPlotter class
|
24
|
-
and
|
25
|
-
|
22
|
+
The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
|
23
|
+
flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
|
24
|
+
perimeter, and adjusting background colors.
|
26
25
|
"""
|
27
26
|
|
28
27
|
def __init__(
|
29
28
|
self,
|
30
29
|
graph: NetworkGraph,
|
31
|
-
figsize:
|
32
|
-
background_color: str = "white",
|
33
|
-
plot_outline: bool = True,
|
34
|
-
outline_color: str = "black",
|
35
|
-
outline_scale: float = 1.0,
|
36
|
-
linestyle: str = "dashed",
|
30
|
+
figsize: Tuple = (10, 10),
|
31
|
+
background_color: Union[str, List, Tuple, np.ndarray] = "white",
|
37
32
|
) -> None:
|
38
33
|
"""Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
|
39
34
|
|
40
35
|
Args:
|
41
36
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
42
37
|
figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
|
43
|
-
background_color (str, optional): Background color of the plot. Defaults to "white".
|
44
|
-
plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
|
45
|
-
outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
|
46
|
-
outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
|
47
|
-
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
|
38
|
+
background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
|
48
39
|
"""
|
49
40
|
self.graph = graph
|
50
41
|
# Initialize the plot with the specified parameters
|
51
|
-
self.ax = self._initialize_plot(
|
52
|
-
graph,
|
53
|
-
figsize,
|
54
|
-
background_color,
|
55
|
-
plot_outline,
|
56
|
-
outline_color,
|
57
|
-
outline_scale,
|
58
|
-
linestyle,
|
59
|
-
)
|
42
|
+
self.ax = self._initialize_plot(graph, figsize, background_color)
|
60
43
|
|
61
44
|
def _initialize_plot(
|
62
45
|
self,
|
63
46
|
graph: NetworkGraph,
|
64
|
-
figsize:
|
65
|
-
background_color: str,
|
66
|
-
plot_outline: bool,
|
67
|
-
outline_color: str,
|
68
|
-
outline_scale: float,
|
69
|
-
linestyle: str,
|
47
|
+
figsize: Tuple,
|
48
|
+
background_color: Union[str, List, Tuple, np.ndarray],
|
70
49
|
) -> plt.Axes:
|
71
|
-
"""Set up the plot with figure size
|
50
|
+
"""Set up the plot with figure size and background color.
|
72
51
|
|
73
52
|
Args:
|
74
53
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
75
54
|
figsize (tuple): Size of the figure in inches (width, height).
|
76
55
|
background_color (str): Background color of the plot.
|
77
|
-
plot_outline (bool): Whether to plot the network perimeter circle.
|
78
|
-
outline_color (str): Color of the network perimeter circle.
|
79
|
-
outline_scale (float): Outline scaling factor for the perimeter diameter.
|
80
|
-
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
|
81
56
|
|
82
57
|
Returns:
|
83
58
|
plt.Axes: The axis object for the plot.
|
@@ -86,31 +61,19 @@ class NetworkPlotter:
|
|
86
61
|
node_coordinates = graph.node_coordinates
|
87
62
|
# Calculate the center and radius of the bounding box around the network
|
88
63
|
center, radius = _calculate_bounding_box(node_coordinates)
|
89
|
-
# Scale the radius by the outline_scale factor
|
90
|
-
scaled_radius = radius * outline_scale
|
91
64
|
|
92
65
|
# Create a new figure and axis for plotting
|
93
66
|
fig, ax = plt.subplots(figsize=figsize)
|
94
67
|
fig.tight_layout() # Adjust subplot parameters to give specified padding
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
center,
|
99
|
-
scaled_radius,
|
100
|
-
linestyle=linestyle, # Use the linestyle argument here
|
101
|
-
color=outline_color,
|
102
|
-
fill=False,
|
103
|
-
linewidth=1.5,
|
104
|
-
)
|
105
|
-
ax.add_artist(circle) # Add the circle to the plot
|
106
|
-
|
107
|
-
# Set axis limits based on the calculated bounding box and scaled radius
|
108
|
-
ax.set_xlim([center[0] - scaled_radius - 0.3, center[0] + scaled_radius + 0.3])
|
109
|
-
ax.set_ylim([center[1] - scaled_radius - 0.3, center[1] + scaled_radius + 0.3])
|
68
|
+
# Set axis limits based on the calculated bounding box and radius
|
69
|
+
ax.set_xlim([center[0] - radius - 0.3, center[0] + radius + 0.3])
|
70
|
+
ax.set_ylim([center[1] - radius - 0.3, center[1] + radius + 0.3])
|
110
71
|
ax.set_aspect("equal") # Ensure the aspect ratio is equal
|
111
|
-
fig.patch.set_facecolor(background_color) # Set the figure background color
|
112
|
-
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
113
72
|
|
73
|
+
# Set the background color of the plot
|
74
|
+
# Convert color to RGBA using the _to_rgba helper function
|
75
|
+
fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
|
76
|
+
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
114
77
|
# Remove axis spines for a cleaner look
|
115
78
|
for spine in ax.spines.values():
|
116
79
|
spine.set_visible(False)
|
@@ -122,44 +85,182 @@ class NetworkPlotter:
|
|
122
85
|
|
123
86
|
return ax
|
124
87
|
|
88
|
+
def plot_circle_perimeter(
|
89
|
+
self,
|
90
|
+
scale: float = 1.0,
|
91
|
+
linestyle: str = "dashed",
|
92
|
+
linewidth: float = 1.5,
|
93
|
+
color: Union[str, List, Tuple, np.ndarray] = "black",
|
94
|
+
outline_alpha: float = 1.0,
|
95
|
+
fill_alpha: float = 0.0,
|
96
|
+
) -> None:
|
97
|
+
"""Plot a circle around the network graph to represent the network perimeter.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
|
101
|
+
linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
|
102
|
+
linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
|
103
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
|
104
|
+
outline_alpha (float, optional): Transparency level of the circle outline. Defaults to 1.0.
|
105
|
+
fill_alpha (float, optional): Transparency level of the circle fill. Defaults to 0.0.
|
106
|
+
"""
|
107
|
+
# Log the circle perimeter plotting parameters
|
108
|
+
params.log_plotter(
|
109
|
+
perimeter_type="circle",
|
110
|
+
perimeter_scale=scale,
|
111
|
+
perimeter_linestyle=linestyle,
|
112
|
+
perimeter_linewidth=linewidth,
|
113
|
+
perimeter_color=(
|
114
|
+
"custom" if isinstance(color, (list, tuple, np.ndarray)) else color
|
115
|
+
), # np.ndarray usually indicates custom colors
|
116
|
+
perimeter_outline_alpha=outline_alpha,
|
117
|
+
perimeter_fill_alpha=fill_alpha,
|
118
|
+
)
|
119
|
+
|
120
|
+
# Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
|
121
|
+
color = _to_rgba(color, outline_alpha)
|
122
|
+
# Extract node coordinates from the network graph
|
123
|
+
node_coordinates = self.graph.node_coordinates
|
124
|
+
# Calculate the center and radius of the bounding box around the network
|
125
|
+
center, radius = _calculate_bounding_box(node_coordinates)
|
126
|
+
# Scale the radius by the scale factor
|
127
|
+
scaled_radius = radius * scale
|
128
|
+
|
129
|
+
# Draw a circle to represent the network perimeter
|
130
|
+
circle = plt.Circle(
|
131
|
+
center,
|
132
|
+
scaled_radius,
|
133
|
+
linestyle=linestyle,
|
134
|
+
linewidth=linewidth,
|
135
|
+
color=color,
|
136
|
+
fill=fill_alpha > 0, # Fill the circle if fill_alpha is greater than 0
|
137
|
+
)
|
138
|
+
# Set the transparency of the fill if applicable
|
139
|
+
if fill_alpha > 0:
|
140
|
+
circle.set_facecolor(
|
141
|
+
_to_rgba(color, fill_alpha)
|
142
|
+
) # Use _to_rgba to set the fill color with transparency
|
143
|
+
|
144
|
+
self.ax.add_artist(circle)
|
145
|
+
|
146
|
+
def plot_contour_perimeter(
|
147
|
+
self,
|
148
|
+
scale: float = 1.0,
|
149
|
+
levels: int = 3,
|
150
|
+
bandwidth: float = 0.8,
|
151
|
+
grid_size: int = 250,
|
152
|
+
color: Union[str, List, Tuple, np.ndarray] = "black",
|
153
|
+
linestyle: str = "solid",
|
154
|
+
linewidth: float = 1.5,
|
155
|
+
outline_alpha: float = 1.0,
|
156
|
+
fill_alpha: float = 0.0,
|
157
|
+
) -> None:
|
158
|
+
"""
|
159
|
+
Plot a KDE-based contour around the network graph to represent the network perimeter.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
scale (float, optional): Scaling factor for the perimeter size. Defaults to 1.0.
|
163
|
+
levels (int, optional): Number of contour levels. Defaults to 3.
|
164
|
+
bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
|
165
|
+
grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
|
166
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
|
167
|
+
linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
|
168
|
+
linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
|
169
|
+
outline_alpha (float, optional): Transparency level of the contour outline. Defaults to 1.0.
|
170
|
+
fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.0.
|
171
|
+
"""
|
172
|
+
# Log the contour perimeter plotting parameters
|
173
|
+
params.log_plotter(
|
174
|
+
perimeter_type="contour",
|
175
|
+
perimeter_scale=scale,
|
176
|
+
perimeter_levels=levels,
|
177
|
+
perimeter_bandwidth=bandwidth,
|
178
|
+
perimeter_grid_size=grid_size,
|
179
|
+
perimeter_linestyle=linestyle,
|
180
|
+
perimeter_linewidth=linewidth,
|
181
|
+
perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
|
182
|
+
perimeter_outline_alpha=outline_alpha,
|
183
|
+
perimeter_fill_alpha=fill_alpha,
|
184
|
+
)
|
185
|
+
|
186
|
+
# Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
|
187
|
+
color = _to_rgba(color, outline_alpha)
|
188
|
+
# Extract node coordinates from the network graph
|
189
|
+
node_coordinates = self.graph.node_coordinates
|
190
|
+
# Scale the node coordinates if needed
|
191
|
+
scaled_coordinates = node_coordinates * scale
|
192
|
+
# Use the existing _draw_kde_contour method
|
193
|
+
self._draw_kde_contour(
|
194
|
+
ax=self.ax,
|
195
|
+
pos=scaled_coordinates,
|
196
|
+
nodes=list(range(len(node_coordinates))), # All nodes are included
|
197
|
+
levels=levels,
|
198
|
+
bandwidth=bandwidth,
|
199
|
+
grid_size=grid_size,
|
200
|
+
color=color,
|
201
|
+
linestyle=linestyle,
|
202
|
+
linewidth=linewidth,
|
203
|
+
alpha=fill_alpha, # Use fill_alpha for the fill
|
204
|
+
)
|
205
|
+
|
125
206
|
def plot_network(
|
126
207
|
self,
|
127
208
|
node_size: Union[int, np.ndarray] = 50,
|
128
|
-
edge_width: float = 1.0,
|
129
|
-
node_color: Union[str, np.ndarray] = "white",
|
130
|
-
node_edgecolor: str = "black",
|
131
|
-
edge_color: str = "black",
|
132
209
|
node_shape: str = "o",
|
210
|
+
edge_width: float = 1.0,
|
211
|
+
node_color: Union[str, List, Tuple, np.ndarray] = "white",
|
212
|
+
node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
|
213
|
+
edge_color: Union[str, List, Tuple, np.ndarray] = "black",
|
214
|
+
node_alpha: float = 1.0,
|
215
|
+
edge_alpha: float = 1.0,
|
133
216
|
) -> None:
|
134
217
|
"""Plot the network graph with customizable node colors, sizes, and edge widths.
|
135
218
|
|
136
219
|
Args:
|
137
220
|
node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
|
138
|
-
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
139
|
-
node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
140
|
-
node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
|
141
|
-
edge_color (str, optional): Color of the edges. Defaults to "black".
|
142
221
|
node_shape (str, optional): Shape of the nodes. Defaults to "o".
|
222
|
+
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
223
|
+
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
224
|
+
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
|
225
|
+
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
|
226
|
+
node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0. Annotated node_color alphas will override this value.
|
227
|
+
edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
|
143
228
|
"""
|
144
229
|
# Log the plotting parameters
|
145
230
|
params.log_plotter(
|
146
|
-
network_node_size=
|
231
|
+
network_node_size=(
|
232
|
+
"custom" if isinstance(node_size, np.ndarray) else node_size
|
233
|
+
), # np.ndarray usually indicates custom sizes
|
234
|
+
network_node_shape=node_shape,
|
147
235
|
network_edge_width=edge_width,
|
148
|
-
network_node_color=
|
236
|
+
network_node_color=(
|
237
|
+
"custom" if isinstance(node_color, np.ndarray) else node_color
|
238
|
+
), # np.ndarray usually indicates custom colors
|
149
239
|
network_node_edgecolor=node_edgecolor,
|
150
240
|
network_edge_color=edge_color,
|
151
|
-
|
241
|
+
network_node_alpha=node_alpha,
|
242
|
+
network_edge_alpha=edge_alpha,
|
152
243
|
)
|
244
|
+
|
245
|
+
# Convert colors to RGBA using the _to_rgba helper function
|
246
|
+
# If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
|
247
|
+
node_color = _to_rgba(node_color, node_alpha, num_repeats=len(self.graph.network.nodes))
|
248
|
+
# Convert other colors to RGBA using the _to_rgba helper function
|
249
|
+
node_edgecolor = _to_rgba(
|
250
|
+
node_edgecolor, 1.0, num_repeats=len(self.graph.network.nodes)
|
251
|
+
) # Node edges are usually fully opaque
|
252
|
+
edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
|
253
|
+
|
153
254
|
# Extract node coordinates from the network graph
|
154
255
|
node_coordinates = self.graph.node_coordinates
|
256
|
+
|
155
257
|
# Draw the nodes of the graph
|
156
258
|
nx.draw_networkx_nodes(
|
157
259
|
self.graph.network,
|
158
260
|
pos=node_coordinates,
|
159
261
|
node_size=node_size,
|
160
|
-
node_color=node_color,
|
161
262
|
node_shape=node_shape,
|
162
|
-
|
263
|
+
node_color=node_color,
|
163
264
|
edgecolors=node_edgecolor,
|
164
265
|
ax=self.ax,
|
165
266
|
)
|
@@ -174,37 +275,33 @@ class NetworkPlotter:
|
|
174
275
|
|
175
276
|
def plot_subnetwork(
|
176
277
|
self,
|
177
|
-
nodes:
|
278
|
+
nodes: List,
|
178
279
|
node_size: Union[int, np.ndarray] = 50,
|
179
|
-
edge_width: float = 1.0,
|
180
|
-
node_color: Union[str, np.ndarray] = "white",
|
181
|
-
node_edgecolor: str = "black",
|
182
|
-
edge_color: str = "black",
|
183
280
|
node_shape: str = "o",
|
281
|
+
edge_width: float = 1.0,
|
282
|
+
node_color: Union[str, List, Tuple, np.ndarray] = "white",
|
283
|
+
node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
|
284
|
+
edge_color: Union[str, List, Tuple, np.ndarray] = "black",
|
285
|
+
node_alpha: float = 1.0,
|
286
|
+
edge_alpha: float = 1.0,
|
184
287
|
) -> None:
|
185
288
|
"""Plot a subnetwork of selected nodes with customizable node and edge attributes.
|
186
289
|
|
187
290
|
Args:
|
188
291
|
nodes (list): List of node labels to include in the subnetwork.
|
189
292
|
node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
|
190
|
-
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
191
|
-
node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
192
|
-
node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
|
193
|
-
edge_color (str, optional): Color of the edges. Defaults to "black".
|
194
293
|
node_shape (str, optional): Shape of the nodes. Defaults to "o".
|
294
|
+
edge_width (float, optional): Width of the edges. Defaults to 1.0.
|
295
|
+
node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
|
296
|
+
node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
|
297
|
+
edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
|
298
|
+
node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0.
|
299
|
+
edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
|
195
300
|
|
196
301
|
Raises:
|
197
302
|
ValueError: If no valid nodes are found in the network graph.
|
198
303
|
"""
|
199
|
-
#
|
200
|
-
params.log_plotter(
|
201
|
-
subnetwork_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
|
202
|
-
subnetwork_edge_width=edge_width,
|
203
|
-
subnetwork_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
|
204
|
-
subnetwork_node_edgecolor=node_edgecolor,
|
205
|
-
subnetwork_edge_color=edge_color,
|
206
|
-
subnet_node_shape=node_shape,
|
207
|
-
)
|
304
|
+
# Don't log subnetwork parameters as they are specific to individual annotations
|
208
305
|
# Filter to get node IDs and their coordinates
|
209
306
|
node_ids = [
|
210
307
|
self.graph.node_label_to_id_map.get(node)
|
@@ -214,17 +311,21 @@ class NetworkPlotter:
|
|
214
311
|
if not node_ids:
|
215
312
|
raise ValueError("No nodes found in the network graph.")
|
216
313
|
|
314
|
+
# Convert colors to RGBA using the _to_rgba helper function
|
315
|
+
node_color = _to_rgba(node_color, node_alpha)
|
316
|
+
node_edgecolor = _to_rgba(node_edgecolor, 1.0) # Node edges usually fully opaque
|
317
|
+
edge_color = _to_rgba(edge_color, edge_alpha)
|
217
318
|
# Get the coordinates of the filtered nodes
|
218
319
|
node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
|
320
|
+
|
219
321
|
# Draw the nodes in the subnetwork
|
220
322
|
nx.draw_networkx_nodes(
|
221
323
|
self.graph.network,
|
222
324
|
pos=node_coordinates,
|
223
325
|
nodelist=node_ids,
|
224
326
|
node_size=node_size,
|
225
|
-
node_color=node_color,
|
226
327
|
node_shape=node_shape,
|
227
|
-
|
328
|
+
node_color=node_color,
|
228
329
|
edgecolors=node_edgecolor,
|
229
330
|
ax=self.ax,
|
230
331
|
)
|
@@ -243,8 +344,10 @@ class NetworkPlotter:
|
|
243
344
|
levels: int = 5,
|
244
345
|
bandwidth: float = 0.8,
|
245
346
|
grid_size: int = 250,
|
347
|
+
color: Union[str, List, Tuple, np.ndarray] = "white",
|
348
|
+
linestyle: str = "solid",
|
349
|
+
linewidth: float = 1.5,
|
246
350
|
alpha: float = 0.2,
|
247
|
-
color: Union[str, np.ndarray] = "white",
|
248
351
|
) -> None:
|
249
352
|
"""Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.
|
250
353
|
|
@@ -252,21 +355,24 @@ class NetworkPlotter:
|
|
252
355
|
levels (int, optional): Number of contour levels to plot. Defaults to 5.
|
253
356
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
254
357
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
358
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors. Defaults to "white".
|
359
|
+
linestyle (str, optional): Line style for the contours. Defaults to "solid".
|
360
|
+
linewidth (float, optional): Line width for the contours. Defaults to 1.5.
|
255
361
|
alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
|
256
|
-
color (str or np.ndarray, optional): Color of the contours. Can be a string (e.g., 'white') or an array of colors. Defaults to "white".
|
257
362
|
"""
|
258
363
|
# Log the contour plotting parameters
|
259
364
|
params.log_plotter(
|
260
365
|
contour_levels=levels,
|
261
366
|
contour_bandwidth=bandwidth,
|
262
367
|
contour_grid_size=grid_size,
|
368
|
+
contour_color=(
|
369
|
+
"custom" if isinstance(color, np.ndarray) else color
|
370
|
+
), # np.ndarray usually indicates custom colors
|
263
371
|
contour_alpha=alpha,
|
264
|
-
contour_color="custom" if isinstance(color, np.ndarray) else color,
|
265
372
|
)
|
266
|
-
# Convert color string to RGBA array if necessary
|
267
|
-
if isinstance(color, str):
|
268
|
-
color = self.get_annotated_contour_colors(color=color)
|
269
373
|
|
374
|
+
# Ensure color is converted to RGBA with repetition matching the number of domains
|
375
|
+
color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_to_nodes))
|
270
376
|
# Extract node coordinates from the network graph
|
271
377
|
node_coordinates = self.graph.node_coordinates
|
272
378
|
# Draw contours for each domain in the network
|
@@ -280,17 +386,21 @@ class NetworkPlotter:
|
|
280
386
|
levels=levels,
|
281
387
|
bandwidth=bandwidth,
|
282
388
|
grid_size=grid_size,
|
389
|
+
linestyle=linestyle,
|
390
|
+
linewidth=linewidth,
|
283
391
|
alpha=alpha,
|
284
392
|
)
|
285
393
|
|
286
394
|
def plot_subcontour(
|
287
395
|
self,
|
288
|
-
nodes:
|
396
|
+
nodes: List,
|
289
397
|
levels: int = 5,
|
290
398
|
bandwidth: float = 0.8,
|
291
399
|
grid_size: int = 250,
|
400
|
+
color: Union[str, List, Tuple, np.ndarray] = "white",
|
401
|
+
linestyle: str = "solid",
|
402
|
+
linewidth: float = 1.5,
|
292
403
|
alpha: float = 0.2,
|
293
|
-
color: Union[str, np.ndarray] = "white",
|
294
404
|
) -> None:
|
295
405
|
"""Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
|
296
406
|
|
@@ -299,20 +409,15 @@ class NetworkPlotter:
|
|
299
409
|
levels (int, optional): Number of contour levels to plot. Defaults to 5.
|
300
410
|
bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
|
301
411
|
grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
|
412
|
+
color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
|
413
|
+
linestyle (str, optional): Line style for the contour. Defaults to "solid".
|
414
|
+
linewidth (float, optional): Line width for the contour. Defaults to 1.5.
|
302
415
|
alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
|
303
|
-
color (str or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
|
304
416
|
|
305
417
|
Raises:
|
306
418
|
ValueError: If no valid nodes are found in the network graph.
|
307
419
|
"""
|
308
|
-
#
|
309
|
-
params.log_plotter(
|
310
|
-
subcontour_levels=levels,
|
311
|
-
subcontour_bandwidth=bandwidth,
|
312
|
-
subcontour_grid_size=grid_size,
|
313
|
-
subcontour_alpha=alpha,
|
314
|
-
subcontour_color="custom" if isinstance(color, np.ndarray) else color,
|
315
|
-
)
|
420
|
+
# Don't log subcontour parameters as they are specific to individual annotations
|
316
421
|
# Filter to get node IDs and their coordinates
|
317
422
|
node_ids = [
|
318
423
|
self.graph.node_label_to_id_map.get(node)
|
@@ -322,6 +427,8 @@ class NetworkPlotter:
|
|
322
427
|
if not node_ids or len(node_ids) == 1:
|
323
428
|
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
324
429
|
|
430
|
+
# Convert color to RGBA using the _to_rgba helper function
|
431
|
+
color = _to_rgba(color, alpha)
|
325
432
|
# Draw the KDE contour for the specified nodes
|
326
433
|
node_coordinates = self.graph.node_coordinates
|
327
434
|
self._draw_kde_contour(
|
@@ -332,6 +439,8 @@ class NetworkPlotter:
|
|
332
439
|
levels=levels,
|
333
440
|
bandwidth=bandwidth,
|
334
441
|
grid_size=grid_size,
|
442
|
+
linestyle=linestyle,
|
443
|
+
linewidth=linewidth,
|
335
444
|
alpha=alpha,
|
336
445
|
)
|
337
446
|
|
@@ -339,11 +448,13 @@ class NetworkPlotter:
|
|
339
448
|
self,
|
340
449
|
ax: plt.Axes,
|
341
450
|
pos: np.ndarray,
|
342
|
-
nodes:
|
343
|
-
color: Union[str, np.ndarray],
|
451
|
+
nodes: List,
|
344
452
|
levels: int = 5,
|
345
453
|
bandwidth: float = 0.8,
|
346
454
|
grid_size: int = 250,
|
455
|
+
color: Union[str, np.ndarray] = "white",
|
456
|
+
linestyle: str = "solid",
|
457
|
+
linewidth: float = 1.5,
|
347
458
|
alpha: float = 0.5,
|
348
459
|
) -> None:
|
349
460
|
"""Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
|
@@ -352,10 +463,12 @@ class NetworkPlotter:
|
|
352
463
|
ax (plt.Axes): The axis to draw the contour on.
|
353
464
|
pos (np.ndarray): Array of node positions (x, y).
|
354
465
|
nodes (list): List of node indices to include in the contour.
|
355
|
-
color (str or np.ndarray): Color for the contour.
|
356
466
|
levels (int, optional): Number of contour levels. Defaults to 5.
|
357
467
|
bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
|
358
468
|
grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
|
469
|
+
color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
|
470
|
+
linestyle (str, optional): Line style for the contour. Defaults to "solid".
|
471
|
+
linewidth (float, optional): Line width for the contour. Defaults to 1.5.
|
359
472
|
alpha (float, optional): Transparency level for the contour fill. Defaults to 0.5.
|
360
473
|
"""
|
361
474
|
# Extract the positions of the specified nodes
|
@@ -382,8 +495,7 @@ class NetworkPlotter:
|
|
382
495
|
min_density, max_density = z.min(), z.max()
|
383
496
|
contour_levels = np.linspace(min_density, max_density, levels)[1:]
|
384
497
|
contour_colors = [color for _ in range(levels - 1)]
|
385
|
-
|
386
|
-
# Plot the filled contours if alpha > 0
|
498
|
+
# Plot the filled contours only if alpha > 0
|
387
499
|
if alpha > 0:
|
388
500
|
ax.contourf(
|
389
501
|
x,
|
@@ -391,25 +503,34 @@ class NetworkPlotter:
|
|
391
503
|
z,
|
392
504
|
levels=contour_levels,
|
393
505
|
colors=contour_colors,
|
394
|
-
alpha=alpha,
|
395
|
-
extend="neither",
|
396
506
|
antialiased=True,
|
507
|
+
alpha=alpha,
|
397
508
|
)
|
398
509
|
|
399
|
-
# Plot the contour lines without
|
400
|
-
c = ax.contour(
|
510
|
+
# Plot the contour lines without any change in behavior
|
511
|
+
c = ax.contour(
|
512
|
+
x,
|
513
|
+
y,
|
514
|
+
z,
|
515
|
+
levels=contour_levels,
|
516
|
+
colors=contour_colors,
|
517
|
+
linestyles=linestyle,
|
518
|
+
linewidths=linewidth,
|
519
|
+
)
|
401
520
|
for i in range(1, len(contour_levels)):
|
402
521
|
c.collections[i].set_linewidth(0)
|
403
522
|
|
404
523
|
def plot_labels(
|
405
524
|
self,
|
406
|
-
|
525
|
+
scale: float = 1.05,
|
407
526
|
offset: float = 0.10,
|
408
527
|
font: str = "Arial",
|
409
528
|
fontsize: int = 10,
|
410
|
-
fontcolor: Union[str, np.ndarray] = "black",
|
529
|
+
fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
|
530
|
+
fontalpha: float = 1.0,
|
411
531
|
arrow_linewidth: float = 1,
|
412
|
-
arrow_color: Union[str, np.ndarray] = "black",
|
532
|
+
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
533
|
+
arrow_alpha: float = 1.0,
|
413
534
|
max_labels: Union[int, None] = None,
|
414
535
|
max_words: int = 10,
|
415
536
|
min_words: int = 1,
|
@@ -420,13 +541,15 @@ class NetworkPlotter:
|
|
420
541
|
"""Annotate the network graph with labels for different domains, positioned around the network for clarity.
|
421
542
|
|
422
543
|
Args:
|
423
|
-
|
544
|
+
scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
|
424
545
|
offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
|
425
546
|
font (str, optional): Font name for the labels. Defaults to "Arial".
|
426
547
|
fontsize (int, optional): Font size for the labels. Defaults to 10.
|
427
|
-
fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
|
548
|
+
fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
|
549
|
+
fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
|
428
550
|
arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
|
429
|
-
arrow_color (str or np.ndarray, optional): Color of the arrows.
|
551
|
+
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
|
552
|
+
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
430
553
|
max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
|
431
554
|
max_words (int, optional): Maximum number of words in a label. Defaults to 10.
|
432
555
|
min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
|
@@ -436,13 +559,17 @@ class NetworkPlotter:
|
|
436
559
|
"""
|
437
560
|
# Log the plotting parameters
|
438
561
|
params.log_plotter(
|
439
|
-
label_perimeter_scale=
|
562
|
+
label_perimeter_scale=scale,
|
440
563
|
label_offset=offset,
|
441
564
|
label_font=font,
|
442
565
|
label_fontsize=fontsize,
|
443
|
-
label_fontcolor=
|
566
|
+
label_fontcolor=(
|
567
|
+
"custom" if isinstance(fontcolor, np.ndarray) else fontcolor
|
568
|
+
), # np.ndarray usually indicates custom colors
|
569
|
+
label_fontalpha=fontalpha,
|
444
570
|
label_arrow_linewidth=arrow_linewidth,
|
445
571
|
label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
572
|
+
label_arrow_alpha=arrow_alpha,
|
446
573
|
label_max_labels=max_labels,
|
447
574
|
label_max_words=max_words,
|
448
575
|
label_min_words=min_words,
|
@@ -451,11 +578,12 @@ class NetworkPlotter:
|
|
451
578
|
label_words_to_omit=words_to_omit,
|
452
579
|
)
|
453
580
|
|
454
|
-
# Convert
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
581
|
+
# Convert colors to RGBA using the _to_rgba helper function, applying alpha separately for font and arrow
|
582
|
+
fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes))
|
583
|
+
arrow_color = _to_rgba(
|
584
|
+
arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_to_nodes)
|
585
|
+
)
|
586
|
+
|
459
587
|
# Normalize words_to_omit to lowercase
|
460
588
|
if words_to_omit:
|
461
589
|
words_to_omit = set(word.lower() for word in words_to_omit)
|
@@ -509,9 +637,7 @@ class NetworkPlotter:
|
|
509
637
|
valid_indices = [valid_indices[i] for i in selected_indices]
|
510
638
|
|
511
639
|
# Calculate the bounding box around the network
|
512
|
-
center, radius = _calculate_bounding_box(
|
513
|
-
self.graph.node_coordinates, radius_margin=perimeter_scale
|
514
|
-
)
|
640
|
+
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
515
641
|
# Calculate the best positions for labels around the perimeter
|
516
642
|
best_label_positions = _calculate_best_label_positions(
|
517
643
|
filtered_domain_centroids, center, radius, offset
|
@@ -535,16 +661,18 @@ class NetworkPlotter:
|
|
535
661
|
|
536
662
|
def plot_sublabel(
|
537
663
|
self,
|
538
|
-
nodes:
|
664
|
+
nodes: List,
|
539
665
|
label: str,
|
540
666
|
radial_position: float = 0.0,
|
541
|
-
|
667
|
+
scale: float = 1.05,
|
542
668
|
offset: float = 0.10,
|
543
669
|
font: str = "Arial",
|
544
670
|
fontsize: int = 10,
|
545
|
-
fontcolor: str = "black",
|
671
|
+
fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
|
672
|
+
fontalpha: float = 1.0,
|
546
673
|
arrow_linewidth: float = 1,
|
547
|
-
arrow_color: str = "black",
|
674
|
+
arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
|
675
|
+
arrow_alpha: float = 1.0,
|
548
676
|
) -> None:
|
549
677
|
"""Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
|
550
678
|
|
@@ -552,26 +680,17 @@ class NetworkPlotter:
|
|
552
680
|
nodes (List[str]): List of node labels to be used for calculating the centroid.
|
553
681
|
label (str): The label to be annotated on the network.
|
554
682
|
radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
|
555
|
-
|
683
|
+
scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
|
556
684
|
offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
|
557
685
|
font (str, optional): Font name for the label. Defaults to "Arial".
|
558
686
|
fontsize (int, optional): Font size for the label. Defaults to 10.
|
559
|
-
fontcolor (str, optional): Color of the label text. Defaults to "black".
|
687
|
+
fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
|
688
|
+
fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
|
560
689
|
arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
|
561
|
-
arrow_color (str, optional): Color of the arrow. Defaults to "black".
|
690
|
+
arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
|
691
|
+
arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
|
562
692
|
"""
|
563
|
-
#
|
564
|
-
params.log_plotter(
|
565
|
-
sublabel_perimeter_scale=perimeter_scale,
|
566
|
-
sublabel_offset=offset,
|
567
|
-
sublabel_font=font,
|
568
|
-
sublabel_fontsize=fontsize,
|
569
|
-
sublabel_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
|
570
|
-
sublabel_arrow_linewidth=arrow_linewidth,
|
571
|
-
sublabel_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
572
|
-
sublabel_radial_position=radial_position,
|
573
|
-
)
|
574
|
-
|
693
|
+
# Don't log sublabel parameters as they are specific to individual annotations
|
575
694
|
# Map node labels to IDs
|
576
695
|
node_ids = [
|
577
696
|
self.graph.node_label_to_id_map.get(node)
|
@@ -581,12 +700,13 @@ class NetworkPlotter:
|
|
581
700
|
if not node_ids or len(node_ids) == 1:
|
582
701
|
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
583
702
|
|
703
|
+
# Convert fontcolor and arrow_color to RGBA using the _to_rgba helper function
|
704
|
+
fontcolor = _to_rgba(fontcolor, fontalpha)
|
705
|
+
arrow_color = _to_rgba(arrow_color, arrow_alpha)
|
584
706
|
# Calculate the centroid of the provided nodes
|
585
707
|
centroid = self._calculate_domain_centroid(node_ids)
|
586
708
|
# Calculate the bounding box around the network
|
587
|
-
center, radius = _calculate_bounding_box(
|
588
|
-
self.graph.node_coordinates, radius_margin=perimeter_scale
|
589
|
-
)
|
709
|
+
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
590
710
|
# Convert radial position to radians, adjusting for a 90-degree rotation
|
591
711
|
radial_radians = np.deg2rad(radial_position - 90)
|
592
712
|
label_position = (
|
@@ -608,7 +728,7 @@ class NetworkPlotter:
|
|
608
728
|
arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
|
609
729
|
)
|
610
730
|
|
611
|
-
def _calculate_domain_centroid(self, nodes:
|
731
|
+
def _calculate_domain_centroid(self, nodes: List) -> tuple:
|
612
732
|
"""Calculate the most centrally located node in .
|
613
733
|
|
614
734
|
Args:
|
@@ -631,44 +751,51 @@ class NetworkPlotter:
|
|
631
751
|
|
632
752
|
def get_annotated_node_colors(
|
633
753
|
self,
|
754
|
+
cmap: str = "gist_rainbow",
|
755
|
+
color: Union[str, None] = None,
|
756
|
+
min_scale: float = 0.8,
|
757
|
+
max_scale: float = 1.0,
|
758
|
+
scale_factor: float = 1.0,
|
759
|
+
alpha: float = 1.0,
|
634
760
|
nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
|
761
|
+
nonenriched_alpha: float = 1.0,
|
635
762
|
random_seed: int = 888,
|
636
|
-
**kwargs,
|
637
763
|
) -> np.ndarray:
|
638
764
|
"""Adjust the colors of nodes in the network graph based on enrichment.
|
639
765
|
|
640
766
|
Args:
|
641
|
-
|
642
|
-
|
767
|
+
cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
|
768
|
+
color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
|
769
|
+
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
770
|
+
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
771
|
+
scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
|
772
|
+
alpha (float, optional): Alpha value for enriched nodes. Defaults to 1.0.
|
773
|
+
nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
|
774
|
+
nonenriched_alpha (float, optional): Alpha value for non-enriched nodes. Defaults to 1.0.
|
643
775
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
644
|
-
**kwargs: Additional keyword arguments for `get_domain_colors`.
|
645
776
|
|
646
777
|
Returns:
|
647
778
|
np.ndarray: Array of RGBA colors adjusted for enrichment status.
|
648
779
|
"""
|
649
|
-
# Get the initial domain colors for each node
|
650
|
-
network_colors = self.graph.get_domain_colors(
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
nonenriched_color[3] = 1.0
|
664
|
-
|
665
|
-
# Adjust node colors: replace any fully transparent nodes (enriched) with the non-enriched color
|
780
|
+
# Get the initial domain colors for each node, which are returned as RGBA
|
781
|
+
network_colors = self.graph.get_domain_colors(
|
782
|
+
cmap=cmap,
|
783
|
+
color=color,
|
784
|
+
min_scale=min_scale,
|
785
|
+
max_scale=max_scale,
|
786
|
+
scale_factor=scale_factor,
|
787
|
+
random_seed=random_seed,
|
788
|
+
)
|
789
|
+
# Apply the alpha value for enriched nodes
|
790
|
+
network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
|
791
|
+
# Convert the non-enriched color to RGBA using the _to_rgba helper function
|
792
|
+
nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
|
793
|
+
# Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
|
666
794
|
adjusted_network_colors = np.where(
|
667
795
|
np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
|
668
|
-
np.array([nonenriched_color]), # Apply the non-enriched color
|
669
|
-
network_colors, # Keep the original colors
|
796
|
+
np.array([nonenriched_color]), # Apply the non-enriched color with alpha
|
797
|
+
network_colors, # Keep the original colors for enriched nodes
|
670
798
|
)
|
671
|
-
|
672
799
|
return adjusted_network_colors
|
673
800
|
|
674
801
|
def get_annotated_node_sizes(
|
@@ -697,59 +824,119 @@ class NetworkPlotter:
|
|
697
824
|
|
698
825
|
return node_sizes
|
699
826
|
|
700
|
-
def get_annotated_contour_colors(
|
701
|
-
|
827
|
+
def get_annotated_contour_colors(
|
828
|
+
self,
|
829
|
+
cmap: str = "gist_rainbow",
|
830
|
+
color: Union[str, None] = None,
|
831
|
+
min_scale: float = 0.8,
|
832
|
+
max_scale: float = 1.0,
|
833
|
+
scale_factor: float = 1.0,
|
834
|
+
random_seed: int = 888,
|
835
|
+
) -> np.ndarray:
|
836
|
+
"""Get colors for the contours based on node annotations or a specified colormap.
|
702
837
|
|
703
838
|
Args:
|
704
|
-
|
705
|
-
|
839
|
+
cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
|
840
|
+
color (str or None, optional): Color to use for the contours. If None, the colormap will be used. Defaults to None.
|
841
|
+
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
842
|
+
Controls the dimmest colors. Defaults to 0.8.
|
843
|
+
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
844
|
+
Controls the brightest colors. Defaults to 1.0.
|
845
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
|
846
|
+
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
847
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
706
848
|
|
707
849
|
Returns:
|
708
850
|
np.ndarray: Array of RGBA colors for contour annotations.
|
709
851
|
"""
|
710
|
-
return self._get_annotated_domain_colors(
|
852
|
+
return self._get_annotated_domain_colors(
|
853
|
+
cmap=cmap,
|
854
|
+
color=color,
|
855
|
+
min_scale=min_scale,
|
856
|
+
max_scale=max_scale,
|
857
|
+
scale_factor=scale_factor,
|
858
|
+
random_seed=random_seed,
|
859
|
+
)
|
711
860
|
|
712
|
-
def get_annotated_label_colors(
|
713
|
-
|
861
|
+
def get_annotated_label_colors(
|
862
|
+
self,
|
863
|
+
cmap: str = "gist_rainbow",
|
864
|
+
color: Union[str, None] = None,
|
865
|
+
min_scale: float = 0.8,
|
866
|
+
max_scale: float = 1.0,
|
867
|
+
scale_factor: float = 1.0,
|
868
|
+
random_seed: int = 888,
|
869
|
+
) -> np.ndarray:
|
870
|
+
"""Get colors for the labels based on node annotations or a specified colormap.
|
714
871
|
|
715
872
|
Args:
|
716
|
-
|
717
|
-
|
873
|
+
cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
|
874
|
+
color (str or None, optional): Color to use for the labels. If None, the colormap will be used. Defaults to None.
|
875
|
+
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
876
|
+
Controls the dimmest colors. Defaults to 0.8.
|
877
|
+
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
878
|
+
Controls the brightest colors. Defaults to 1.0.
|
879
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on enrichment scores.
|
880
|
+
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
881
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
718
882
|
|
719
883
|
Returns:
|
720
884
|
np.ndarray: Array of RGBA colors for label annotations.
|
721
885
|
"""
|
722
|
-
return self._get_annotated_domain_colors(
|
886
|
+
return self._get_annotated_domain_colors(
|
887
|
+
cmap=cmap,
|
888
|
+
color=color,
|
889
|
+
min_scale=min_scale,
|
890
|
+
max_scale=max_scale,
|
891
|
+
scale_factor=scale_factor,
|
892
|
+
random_seed=random_seed,
|
893
|
+
)
|
723
894
|
|
724
895
|
def _get_annotated_domain_colors(
|
725
|
-
self,
|
896
|
+
self,
|
897
|
+
cmap: str = "gist_rainbow",
|
898
|
+
color: Union[str, None] = None,
|
899
|
+
min_scale: float = 0.8,
|
900
|
+
max_scale: float = 1.0,
|
901
|
+
scale_factor: float = 1.0,
|
902
|
+
random_seed: int = 888,
|
726
903
|
) -> np.ndarray:
|
727
|
-
"""Get colors for the domains based on node annotations.
|
904
|
+
"""Get colors for the domains based on node annotations, or use a specified color.
|
728
905
|
|
729
906
|
Args:
|
730
|
-
|
731
|
-
|
732
|
-
|
907
|
+
cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
|
908
|
+
color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
|
909
|
+
min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
|
910
|
+
Defaults to 0.8.
|
911
|
+
max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
|
912
|
+
Defaults to 1.0.
|
913
|
+
scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
|
914
|
+
enrichment. Higher values increase the contrast. Defaults to 1.0.
|
915
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
733
916
|
|
734
917
|
Returns:
|
735
918
|
np.ndarray: Array of RGBA colors for each domain.
|
736
919
|
"""
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
920
|
+
# Generate domain colors based on the enrichment data
|
921
|
+
node_colors = self.graph.get_domain_colors(
|
922
|
+
cmap=cmap,
|
923
|
+
color=color,
|
924
|
+
min_scale=min_scale,
|
925
|
+
max_scale=max_scale,
|
926
|
+
scale_factor=scale_factor,
|
927
|
+
random_seed=random_seed,
|
928
|
+
)
|
744
929
|
annotated_colors = []
|
745
930
|
for _, nodes in self.graph.domain_to_nodes.items():
|
746
931
|
if len(nodes) > 1:
|
747
|
-
# For domains
|
932
|
+
# For multi-node domains, choose the brightest color based on RGB sum
|
748
933
|
domain_colors = np.array([node_colors[node] for node in nodes])
|
749
|
-
brightest_color = domain_colors[
|
934
|
+
brightest_color = domain_colors[
|
935
|
+
np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
|
936
|
+
]
|
750
937
|
annotated_colors.append(brightest_color)
|
751
938
|
else:
|
752
|
-
#
|
939
|
+
# Single-node domains default to white (RGBA)
|
753
940
|
default_color = np.array([1.0, 1.0, 1.0, 1.0])
|
754
941
|
annotated_colors.append(default_color)
|
755
942
|
|
@@ -776,6 +963,47 @@ class NetworkPlotter:
|
|
776
963
|
plt.show(*args, **kwargs)
|
777
964
|
|
778
965
|
|
966
|
+
def _to_rgba(
|
967
|
+
color: Union[str, List, Tuple, np.ndarray],
|
968
|
+
alpha: float = 1.0,
|
969
|
+
num_repeats: Union[int, None] = None,
|
970
|
+
) -> np.ndarray:
|
971
|
+
"""Convert a color or array of colors to RGBA format, applying or updating the alpha as needed.
|
972
|
+
|
973
|
+
Args:
|
974
|
+
color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
|
975
|
+
alpha (float, optional): Alpha value (transparency) to apply. Defaults to 1.0.
|
976
|
+
num_repeats (int or None, optional): If provided, the color will be repeated this many times. Defaults to None.
|
977
|
+
|
978
|
+
Returns:
|
979
|
+
np.ndarray: The RGBA color or array of RGBA colors.
|
980
|
+
"""
|
981
|
+
# Handle single color case
|
982
|
+
if isinstance(color, str) or (
|
983
|
+
isinstance(color, (list, tuple, np.ndarray)) and len(color) in [3, 4]
|
984
|
+
):
|
985
|
+
rgba_color = np.array(mcolors.to_rgba(color, alpha))
|
986
|
+
# Repeat the color if repeat argument is provided
|
987
|
+
if num_repeats is not None:
|
988
|
+
return np.array([rgba_color] * num_repeats)
|
989
|
+
|
990
|
+
return rgba_color
|
991
|
+
|
992
|
+
# Handle array of colors case
|
993
|
+
elif isinstance(color, (list, tuple, np.ndarray)) and isinstance(
|
994
|
+
color[0], (list, tuple, np.ndarray)
|
995
|
+
):
|
996
|
+
rgba_colors = [mcolors.to_rgba(c, alpha) if len(c) == 3 else np.array(c) for c in color]
|
997
|
+
# Repeat the colors if repeat argument is provided
|
998
|
+
if num_repeats is not None and len(rgba_colors) == 1:
|
999
|
+
return np.array([rgba_colors[0]] * num_repeats)
|
1000
|
+
|
1001
|
+
return np.array(rgba_colors)
|
1002
|
+
|
1003
|
+
else:
|
1004
|
+
raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
|
1005
|
+
|
1006
|
+
|
779
1007
|
def _is_connected(z: np.ndarray) -> bool:
|
780
1008
|
"""Determine if a thresholded grid represents a single, connected component.
|
781
1009
|
|