risk-network 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +13 -0
- risk/annotations/__init__.py +7 -0
- risk/annotations/annotations.py +259 -0
- risk/annotations/io.py +183 -0
- risk/constants.py +31 -0
- risk/log/__init__.py +9 -0
- risk/log/console.py +16 -0
- risk/log/params.py +195 -0
- risk/neighborhoods/__init__.py +10 -0
- risk/neighborhoods/community.py +189 -0
- risk/neighborhoods/domains.py +257 -0
- risk/neighborhoods/neighborhoods.py +319 -0
- risk/network/__init__.py +8 -0
- risk/network/geometry.py +165 -0
- risk/network/graph.py +280 -0
- risk/network/io.py +319 -0
- risk/network/plot.py +795 -0
- risk/risk.py +379 -0
- risk/stats/__init__.py +6 -0
- risk/stats/permutation.py +88 -0
- risk/stats/stats.py +373 -0
- risk_network-0.0.3.dist-info/LICENSE +674 -0
- risk_network-0.0.3.dist-info/METADATA +751 -0
- risk_network-0.0.3.dist-info/RECORD +26 -0
- risk_network-0.0.3.dist-info/WHEEL +5 -0
- risk_network-0.0.3.dist-info/top_level.txt +1 -0
risk/network/plot.py
ADDED
@@ -0,0 +1,795 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plot
|
3
|
+
~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Any, Dict, List, Tuple, Union
|
7
|
+
|
8
|
+
import matplotlib
|
9
|
+
import matplotlib.colors as mcolors
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
import networkx as nx
|
12
|
+
import numpy as np
|
13
|
+
from scipy.ndimage import label
|
14
|
+
from scipy.stats import gaussian_kde
|
15
|
+
|
16
|
+
from risk.log import params
|
17
|
+
from risk.network.graph import NetworkGraph
|
18
|
+
|
19
|
+
|
20
|
+
class NetworkPlotter:
|
21
|
+
"""A class responsible for visualizing network graphs with various customization options.
|
22
|
+
|
23
|
+
The NetworkPlotter class takes in a NetworkGraph object, which contains the network's data and attributes,
|
24
|
+
and provides methods for plotting the network with customizable node and edge properties,
|
25
|
+
as well as optional features like drawing the network's perimeter and setting background colors.
|
26
|
+
"""
|
27
|
+
|
28
|
+
def __init__(
    self,
    network_graph: NetworkGraph,
    figsize: tuple = (10, 10),
    background_color: str = "white",
    plot_outline: bool = True,
    outline_color: str = "black",
    outline_scale: float = 1.0,
) -> None:
    """Keep a reference to the network graph and create the axis all plot calls draw on.

    Args:
        network_graph (NetworkGraph): The network data and attributes to be visualized.
        figsize (tuple, optional): Figure size in inches as (width, height). Defaults to (10, 10).
        background_color (str, optional): Figure background color. Defaults to "white".
        plot_outline (bool, optional): Whether to draw the dashed perimeter circle. Defaults to True.
        outline_color (str, optional): Color of the perimeter circle. Defaults to "black".
        outline_scale (float, optional): Multiplier applied to the perimeter radius. Defaults to 1.0.
    """
    # The plotting methods read coordinates, domains and the graph itself from this object
    self.network_graph = network_graph
    # Build the axis once; every later plot_* call reuses it
    self.ax = self._initialize_plot(
        network_graph=network_graph,
        figsize=figsize,
        background_color=background_color,
        plot_outline=plot_outline,
        outline_color=outline_color,
        outline_scale=outline_scale,
    )
|
52
|
+
|
53
|
+
def _initialize_plot(
    self,
    network_graph: NetworkGraph,
    figsize: tuple,
    background_color: str,
    plot_outline: bool,
    outline_color: str,
    outline_scale: float,
) -> plt.Axes:
    """Create the figure and axis, optionally draw the perimeter circle, and strip decorations.

    Args:
        network_graph (NetworkGraph): The network data and attributes to be visualized.
        figsize (tuple): Figure size in inches as (width, height).
        background_color (str): Figure background color.
        plot_outline (bool): Whether to draw the dashed perimeter circle.
        outline_color (str): Color of the perimeter circle.
        outline_scale (float): Multiplier applied to the perimeter radius.

    Returns:
        plt.Axes: The axis object for the plot.
    """
    # Bounding circle of the layout, then scaled by the caller-supplied factor
    center, radius = _calculate_bounding_box(network_graph.node_coordinates)
    scaled_radius = radius * outline_scale

    fig, ax = plt.subplots(figsize=figsize)
    fig.tight_layout()
    if plot_outline:
        perimeter = plt.Circle(
            center,
            scaled_radius,
            linestyle="--",
            color=outline_color,
            fill=False,
            linewidth=1.5,
        )
        ax.add_artist(perimeter)

    # Axis limits hug the (scaled) perimeter with a fixed 0.3 data-unit pad on each side
    center_x, center_y = center[0], center[1]
    ax.set_xlim([center_x - scaled_radius - 0.3, center_x + scaled_radius + 0.3])
    ax.set_ylim([center_y - scaled_radius - 0.3, center_y + scaled_radius + 0.3])
    ax.set_aspect("equal")
    fig.patch.set_facecolor(background_color)
    # Flip the y-axis so that y grows downwards, as in image coordinates
    ax.invert_yaxis()

    # Strip every visual decoration: spines, ticks and the axis background patch
    for border in ax.spines.values():
        border.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.patch.set_visible(False)

    return ax
|
114
|
+
|
115
|
+
def plot_network(
    self,
    node_size: Union[int, np.ndarray] = 50,
    edge_width: float = 1.0,
    node_color: Union[str, np.ndarray] = "white",
    node_edgecolor: str = "black",
    edge_color: str = "black",
    node_shape: str = "o",
) -> None:
    """Draw every node and edge of the network on the plotter's axis.

    Args:
        node_size (int or np.ndarray, optional): Node size, either one value or one per node. Defaults to 50.
        edge_width (float, optional): Width of the edges. Defaults to 1.0.
        node_color (str or np.ndarray, optional): Node color, either one color or an array of colors.
            Defaults to "white".
        node_edgecolor (str, optional): Color of the node borders. Defaults to "black".
        edge_color (str, optional): Color of the edges. Defaults to "black".
        node_shape (str, optional): Marker shape of the nodes. Defaults to "o".
    """
    # Array-valued arguments are recorded as "custom" rather than logged in full
    logged_size = "custom" if isinstance(node_size, np.ndarray) else node_size
    logged_color = "custom" if isinstance(node_color, np.ndarray) else node_color
    params.log_plotter(
        network_node_size=logged_size,
        network_edge_width=edge_width,
        network_node_color=logged_color,
        network_node_edgecolor=node_edgecolor,
        network_edge_color=edge_color,
        network_node_shape=node_shape,
    )

    layout = self.network_graph.node_coordinates
    # Nodes first ...
    nx.draw_networkx_nodes(
        self.network_graph.G,
        pos=layout,
        node_size=node_size,
        node_color=node_color,
        node_shape=node_shape,
        alpha=1.00,
        edgecolors=node_edgecolor,
        ax=self.ax,
    )
    # ... then the edges, on the same axis
    nx.draw_networkx_edges(
        self.network_graph.G,
        pos=layout,
        width=edge_width,
        edge_color=edge_color,
        ax=self.ax,
    )
|
164
|
+
|
165
|
+
def plot_subnetwork(
    self,
    nodes: list,
    node_size: Union[int, np.ndarray] = 50,
    edge_width: float = 1.0,
    node_color: Union[str, np.ndarray] = "white",
    node_edgecolor: str = "black",
    edge_color: str = "black",
    node_shape: str = "o",
) -> None:
    """Draw only the given nodes and the edges that run between them.

    Labels that are not present in the graph's label-to-ID map are silently dropped.

    Args:
        nodes (list): Node labels to include in the subnetwork.
        node_size (int or np.ndarray, optional): Node size, either one value or an array. Defaults to 50.
        edge_width (float, optional): Width of the edges. Defaults to 1.0.
        node_color (str or np.ndarray, optional): Node color, either one color or an array of colors.
            Defaults to "white".
        node_edgecolor (str, optional): Color of the node borders. Defaults to "black".
        edge_color (str, optional): Color of the edges. Defaults to "black".
        node_shape (str, optional): Marker shape of the nodes. Defaults to "o".

    Raises:
        ValueError: If none of the given labels exist in the network graph.
    """
    # NOTE(review): the last key is "subnet_node_shape" while its siblings use the
    # "subnetwork_" prefix; kept as-is because the logger's consumers are not visible here.
    logged_size = "custom" if isinstance(node_size, np.ndarray) else node_size
    logged_color = "custom" if isinstance(node_color, np.ndarray) else node_color
    params.log_plotter(
        subnetwork_node_size=logged_size,
        subnetwork_edge_width=edge_width,
        subnetwork_node_color=logged_color,
        subnetwork_node_edgecolor=node_edgecolor,
        subnetwork_edge_color=edge_color,
        subnet_node_shape=node_shape,
    )

    # Translate labels to node IDs, skipping labels the graph does not know
    label_to_id = self.network_graph.node_label_to_id_map
    node_ids = [label_to_id.get(name) for name in nodes if name in label_to_id]
    if not node_ids:
        raise ValueError("No nodes found in the network graph.")

    # Restrict the layout to the selected nodes
    node_coordinates = {}
    for node_id in node_ids:
        node_coordinates[node_id] = self.network_graph.node_coordinates[node_id]

    nx.draw_networkx_nodes(
        self.network_graph.G,
        pos=node_coordinates,
        nodelist=node_ids,
        node_size=node_size,
        node_color=node_color,
        node_shape=node_shape,
        alpha=1.00,
        edgecolors=node_edgecolor,
        ax=self.ax,
    )
    # Only edges whose two endpoints are both selected are drawn
    induced = self.network_graph.G.subgraph(node_ids)
    nx.draw_networkx_edges(
        induced,
        pos=node_coordinates,
        width=edge_width,
        edge_color=edge_color,
        ax=self.ax,
    )
|
232
|
+
|
233
|
+
def plot_contours(
    self,
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    alpha: float = 0.2,
    color: Union[str, np.ndarray] = "white",
) -> None:
    """Draw a KDE contour for every domain that contains more than one node.

    Args:
        levels (int, optional): Number of contour levels to plot. Defaults to 5.
        bandwidth (float, optional): Bandwidth for KDE; controls contour smoothness. Defaults to 0.8.
        grid_size (int, optional): Resolution of the KDE grid. Defaults to 250.
        alpha (float, optional): Transparency of the contour fill. Defaults to 0.2.
        color (str or np.ndarray, optional): A single color name applied to every domain, or an
            array holding one color per domain. Defaults to "white".
    """
    # Arrays are recorded as "custom"; the log call happens before any color conversion
    params.log_plotter(
        contour_levels=levels,
        contour_bandwidth=bandwidth,
        contour_grid_size=grid_size,
        contour_alpha=alpha,
        contour_color="custom" if isinstance(color, np.ndarray) else color,
    )
    # A color name is expanded into one RGBA row per domain
    if isinstance(color, str):
        color = self.get_annotated_contour_colors(color=color)

    node_coordinates = self.network_graph.node_coordinates
    # idx counts every domain, including skipped ones, so it stays aligned with `color`
    for idx, domain_nodes in enumerate(self.network_graph.domain_to_nodes.values()):
        if len(domain_nodes) <= 1:
            continue
        self._draw_kde_contour(
            self.ax,
            node_coordinates,
            domain_nodes,
            color=color[idx],
            levels=levels,
            bandwidth=bandwidth,
            grid_size=grid_size,
            alpha=alpha,
        )
|
277
|
+
|
278
|
+
def plot_subcontour(
    self,
    nodes: list,
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    alpha: float = 0.2,
    color: Union[str, np.ndarray] = "white",
) -> None:
    """Draw one KDE contour around an explicit set of nodes.

    Labels that are not present in the graph's label-to-ID map are silently dropped.

    Args:
        nodes (list): Node labels to draw the contour around.
        levels (int, optional): Number of contour levels to plot. Defaults to 5.
        bandwidth (float, optional): Bandwidth for KDE; controls contour smoothness. Defaults to 0.8.
        grid_size (int, optional): Resolution of the KDE grid. Defaults to 250.
        alpha (float, optional): Transparency of the contour fill. Defaults to 0.2.
        color (str or np.ndarray, optional): Color of the contour, as a name or an RGBA array.
            Defaults to "white".

    Raises:
        ValueError: If fewer than two of the given labels exist in the network graph.
    """
    # Logged under the same "contour_*" keys that plot_contours uses
    logged_color = "custom" if isinstance(color, np.ndarray) else color
    params.log_plotter(
        contour_levels=levels,
        contour_bandwidth=bandwidth,
        contour_grid_size=grid_size,
        contour_alpha=alpha,
        contour_color=logged_color,
    )

    # Translate labels to node IDs, skipping labels the graph does not know
    label_to_id = self.network_graph.node_label_to_id_map
    node_ids = [label_to_id.get(name) for name in nodes if name in label_to_id]
    # A density contour needs at least two points
    if len(node_ids) < 2:
        raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")

    self._draw_kde_contour(
        self.ax,
        self.network_graph.node_coordinates,
        node_ids,
        color=color,
        levels=levels,
        bandwidth=bandwidth,
        grid_size=grid_size,
        alpha=alpha,
    )
|
329
|
+
|
330
|
+
def _draw_kde_contour(
    self,
    ax: plt.Axes,
    pos: np.ndarray,
    nodes: list,
    color: Union[str, np.ndarray],
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    alpha: float = 0.5,
) -> None:
    """Draw a Kernel Density Estimate (KDE) contour for a set of nodes on a given axis.

    The KDE is evaluated on a grid padded by the bandwidth. If the resulting density grid is
    not a single connected component, the bandwidth is raised in steps of 0.05 and the KDE is
    recomputed, until the grid is connected or the bandwidth exceeds 100.

    Args:
        ax (plt.Axes): The axis to draw the contour on.
        pos (np.ndarray): Node positions, indexable by node ID, each giving (x, y).
        nodes (list): Node IDs to include in the contour.
        color (str or np.ndarray): Color used for every contour level.
        levels (int, optional): Number of contour levels. Defaults to 5.
        bandwidth (float, optional): Starting bandwidth for the KDE. Defaults to 0.8.
        grid_size (int, optional): Number of grid points per axis for the KDE. Defaults to 250.
        alpha (float, optional): Transparency of the contour fill; no fill is drawn when it is
            not positive. Defaults to 0.5.
    """
    points = np.array([pos[n] for n in nodes])
    if len(points) <= 1:
        return  # Not enough points to form a contour

    # The density is always evaluated at least once, so `x`, `y` and `z` are bound even when
    # the caller passes a bandwidth above the 100.0 cap.
    while True:
        kde = gaussian_kde(points.T, bw_method=bandwidth)
        # Pad the evaluation window by the bandwidth on every side
        xmin, ymin = points.min(axis=0) - bandwidth
        xmax, ymax = points.max(axis=0) + bandwidth
        # A complex step makes mgrid return exactly `grid_size` points per axis
        x, y = np.mgrid[
            xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
        ]
        z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
        if _is_connected(z):
            break
        bandwidth += 0.05  # Smooth a little more and retry
        if bandwidth > 100.0:
            break

    # Evenly spaced levels between the min and max density; the lowest level is dropped
    min_density, max_density = z.min(), z.max()
    contour_levels = np.linspace(min_density, max_density, levels)[1:]
    contour_colors = [color for _ in range(levels - 1)]

    # Filled bands are only drawn when they would be visible
    if alpha > 0:
        ax.contourf(
            x,
            y,
            z,
            levels=contour_levels,
            colors=contour_colors,
            alpha=alpha,
            extend="neither",
            antialiased=True,
        )

    # Only the outermost (lowest) level keeps a visible line; inner levels get zero width.
    # The widths are passed to `contour` directly instead of mutating `ContourSet.collections`,
    # which newer Matplotlib releases deprecate and then remove.
    outline_width = plt.rcParams.get("contour.linewidth")
    if outline_width is None:
        outline_width = plt.rcParams["lines.linewidth"]
    line_widths = [outline_width] + [0] * (len(contour_levels) - 1)
    ax.contour(
        x,
        y,
        z,
        levels=contour_levels,
        colors=contour_colors,
        linewidths=line_widths,
    )
|
395
|
+
|
396
|
+
def plot_labels(
    self,
    perimeter_scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontsize: int = 10,
    fontcolor: Union[str, np.ndarray] = "black",
    arrow_linewidth: float = 1,
    arrow_color: Union[str, np.ndarray] = "black",
    max_words: int = 10,
    min_words: int = 1,
) -> None:
    """Annotate each domain with its term, placed around the network perimeter.

    Every labelled domain gets an arrow from the label to the domain's central node. Domains
    without nodes, or whose term has fewer than `min_words` words after truncation to
    `max_words`, are not labelled.

    Args:
        perimeter_scale (float, optional): Scale factor for the label perimeter radius. Defaults to 1.05.
        offset (float, optional): Extra distance of the labels from the perimeter. Defaults to 0.10.
        font (str, optional): Font name for the labels. Defaults to "Arial".
        fontsize (int, optional): Font size for the labels. Defaults to 10.
        fontcolor (str or np.ndarray, optional): Label text color: a color name, or an array with
            one color per domain. Defaults to "black".
        arrow_linewidth (float, optional): Line width of the arrows. Defaults to 1.
        arrow_color (str or np.ndarray, optional): Arrow color: a color name, or an array with one
            color per domain. Defaults to "black".
        max_words (int, optional): Maximum number of words shown in a label. Defaults to 10.
        min_words (int, optional): Minimum number of words required to show a label. Defaults to 1.
    """
    # Arrays are recorded as "custom"; the log call happens before any color conversion
    params.log_plotter(
        label_perimeter_scale=perimeter_scale,
        label_offset=offset,
        label_font=font,
        label_fontsize=fontsize,
        label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
        label_arrow_linewidth=arrow_linewidth,
        label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
        label_max_words=max_words,
        label_min_words=min_words,
    )
    # A color name is expanded into one RGBA row per domain
    if isinstance(fontcolor, str):
        fontcolor = self.get_annotated_contour_colors(color=fontcolor)
    if isinstance(arrow_color, str):
        arrow_color = self.get_annotated_contour_colors(color=arrow_color)

    domain_to_nodes = self.network_graph.domain_to_nodes
    # Row index of every domain in the per-domain color arrays (domain_to_nodes order)
    domain_row = {domain: row for row, domain in enumerate(domain_to_nodes)}

    def _color_for(colors, position, domain):
        # Per-domain arrays (one row per entry of domain_to_nodes) are looked up by the
        # domain's own row, so colors stay aligned when some domains are not labelled.
        # Arrays of any other length keep the earlier positional lookup.
        if len(colors) == len(domain_to_nodes):
            return colors[domain_row[domain]]
        return colors[position]

    # Central node coordinates of every non-empty domain (arrow targets)
    domain_centroids = {}
    for domain, nodes in domain_to_nodes.items():
        if nodes:  # Skip if the domain has no nodes
            domain_centroids[domain] = self._calculate_domain_centroid(nodes)

    # Bounding circle of the whole network; labels are placed relative to it
    center, radius = _calculate_bounding_box(
        self.network_graph.node_coordinates, radius_margin=perimeter_scale
    )

    # Truncated word list of each candidate domain, computed once and reused below
    domain_words = {
        domain: self.network_graph.trimmed_domain_to_term[domain].split(" ")[:max_words]
        for domain in domain_centroids
    }
    # Drop domains whose (truncated) term is too short to be worth a label
    filtered_domains = {
        domain: centroid
        for domain, centroid in domain_centroids.items()
        if len(domain_words[domain]) >= min_words
    }

    best_label_positions = _best_label_positions(filtered_domains, center, radius, offset)
    for position, (domain, pos) in enumerate(best_label_positions.items()):
        self.ax.annotate(
            "\n".join(domain_words[domain]),  # one word per line
            xy=filtered_domains[domain],
            xytext=pos,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=_color_for(fontcolor, position, domain),
            arrowprops=dict(
                arrowstyle="->",
                color=_color_for(arrow_color, position, domain),
                linewidth=arrow_linewidth,
            ),
        )
|
475
|
+
|
476
|
+
def _calculate_domain_centroid(self, nodes: list) -> tuple:
|
477
|
+
"""Calculate the most centrally located node in .
|
478
|
+
|
479
|
+
Args:
|
480
|
+
nodes (list): List of node labels to include in the subnetwork.
|
481
|
+
|
482
|
+
Returns:
|
483
|
+
tuple: A tuple containing the domain's central node coordinates.
|
484
|
+
"""
|
485
|
+
# Extract positions of all nodes in the domain
|
486
|
+
node_positions = self.network_graph.node_coordinates[nodes, :]
|
487
|
+
# Calculate the pairwise distance matrix between all nodes in the domain
|
488
|
+
distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
|
489
|
+
# Sum the distances for each node to all other nodes in the domain
|
490
|
+
sum_distances = np.sum(distances_matrix, axis=1)
|
491
|
+
# Identify the node with the smallest total distance to others (the centroid)
|
492
|
+
central_node_idx = np.argmin(sum_distances)
|
493
|
+
# Map the domain to the coordinates of its central node
|
494
|
+
domain_central_node = node_positions[central_node_idx]
|
495
|
+
return domain_central_node
|
496
|
+
|
497
|
+
def get_annotated_node_colors(
    self, nonenriched_color: str = "white", random_seed: int = 888, **kwargs
) -> np.ndarray:
    """Return per-node colors, substituting a fixed color for nodes without a domain color.

    Args:
        nonenriched_color (str, optional): Color for non-enriched nodes. A string is converted
            to RGBA; any other value is used as given. Defaults to "white".
        random_seed (int, optional): Seed forwarded to `get_domain_colors`. Defaults to 888.
        **kwargs: Additional keyword arguments for `get_domain_colors`.

    Returns:
        np.ndarray: Array of RGBA colors, one row per node.
    """
    network_colors = self.network_graph.get_domain_colors(**kwargs, random_seed=random_seed)
    if isinstance(nonenriched_color, str):
        nonenriched_color = mcolors.to_rgba(nonenriched_color)

    # Rows that are zero in every channel are treated as non-enriched nodes
    is_blank = np.all(network_colors == 0, axis=1, keepdims=True)
    fallback = np.array([nonenriched_color])
    # The mask and fallback broadcast against the (nodes, channels) color array
    return np.where(is_blank, fallback, network_colors)
|
523
|
+
|
524
|
+
def get_annotated_node_sizes(
    self, enriched_nodesize: int = 50, nonenriched_nodesize: int = 25
) -> np.ndarray:
    """Build a per-node size array that distinguishes enriched from non-enriched nodes.

    A node counts as enriched when it appears in at least one domain of `domain_to_nodes`.

    Args:
        enriched_nodesize (int): Size for enriched nodes. Defaults to 50.
        nonenriched_nodesize (int): Size for non-enriched nodes. Defaults to 25.

    Returns:
        np.ndarray: Array of node sizes with one entry per graph node.
    """
    graph_nodes = self.network_graph.G.nodes
    # Start with every node at the non-enriched size
    node_sizes = np.full(len(graph_nodes), nonenriched_nodesize)
    # Union of the members of all domains
    enriched = set().union(*self.network_graph.domain_to_nodes.values())
    for node in enriched:
        # Node IDs double as positions in the size array; IDs unknown to the graph are ignored
        if node in graph_nodes:
            node_sizes[node] = enriched_nodesize
    return node_sizes
|
549
|
+
|
550
|
+
def get_annotated_contour_colors(self, random_seed: int = 888, **kwargs) -> np.ndarray:
    """Return one color per domain for use with the contour plots.

    Thin public wrapper that forwards everything to `_get_annotated_domain_colors`.

    Args:
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        **kwargs: Additional keyword arguments for `_get_annotated_domain_colors`.

    Returns:
        np.ndarray: Array of RGBA colors for contour annotations.
    """
    contour_colors = self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
    return contour_colors
|
561
|
+
|
562
|
+
def get_annotated_label_colors(self, random_seed: int = 888, **kwargs) -> np.ndarray:
    """Return one color per domain for use with the domain labels.

    Thin public wrapper that forwards everything to `_get_annotated_domain_colors`.

    Args:
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        **kwargs: Additional keyword arguments for `_get_annotated_domain_colors`.

    Returns:
        np.ndarray: Array of RGBA colors for label annotations.
    """
    label_colors = self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
    return label_colors
|
573
|
+
|
574
|
+
def _get_annotated_domain_colors(
|
575
|
+
self, color: Union[str, list, None] = None, random_seed: int = 888, **kwargs
|
576
|
+
) -> np.ndarray:
|
577
|
+
"""Get colors for the domains based on node annotations.
|
578
|
+
|
579
|
+
Args:
|
580
|
+
color (str, list, or None, optional): If provided, use this color or list of colors for domains. Defaults to None.
|
581
|
+
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
582
|
+
**kwargs: Additional keyword arguments for `get_domain_colors`.
|
583
|
+
|
584
|
+
Returns:
|
585
|
+
np.ndarray: Array of RGBA colors for each domain.
|
586
|
+
"""
|
587
|
+
if isinstance(color, str):
|
588
|
+
# If a single color string is provided, convert it to RGBA and apply to all domains
|
589
|
+
rgba_color = np.array(matplotlib.colors.to_rgba(color))
|
590
|
+
return np.array([rgba_color for _ in self.network_graph.domain_to_nodes])
|
591
|
+
|
592
|
+
# Generate colors for each domain using the provided arguments and random seed
|
593
|
+
node_colors = self.network_graph.get_domain_colors(**kwargs, random_seed=random_seed)
|
594
|
+
annotated_colors = []
|
595
|
+
for _, nodes in self.network_graph.domain_to_nodes.items():
|
596
|
+
if len(nodes) > 1:
|
597
|
+
# For domains with multiple nodes, choose the brightest color (sum of RGB values)
|
598
|
+
domain_colors = np.array([node_colors[node] for node in nodes])
|
599
|
+
brightest_color = domain_colors[np.argmax(domain_colors.sum(axis=1))]
|
600
|
+
annotated_colors.append(brightest_color)
|
601
|
+
else:
|
602
|
+
# Assign a default color (white) for single-node domains
|
603
|
+
default_color = np.array([1.0, 1.0, 1.0, 1.0])
|
604
|
+
annotated_colors.append(default_color)
|
605
|
+
|
606
|
+
return np.array(annotated_colors)
|
607
|
+
|
608
|
+
@staticmethod
def savefig(*args, **kwargs) -> None:
    """Save the current Matplotlib figure to a file with a tight bounding box by default.

    Args:
        *args: Positional arguments passed to `plt.savefig`.
        **kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
            `bbox_inches` defaults to "tight" but may be overridden by the caller.
    """
    # setdefault lets callers pass their own bbox_inches without a duplicate-keyword TypeError
    kwargs.setdefault("bbox_inches", "tight")
    plt.savefig(*args, **kwargs)
|
617
|
+
|
618
|
+
@staticmethod
def show(*args, **kwargs) -> None:
    """Show the current Matplotlib figure(s) by delegating to `plt.show`.

    Args:
        *args: Positional arguments forwarded unchanged to `plt.show`.
        **kwargs: Keyword arguments forwarded unchanged to `plt.show`.
    """
    plt.show(*args, **kwargs)
|
627
|
+
|
628
|
+
|
629
|
+
def _is_connected(z: np.ndarray) -> bool:
|
630
|
+
"""Determine if a thresholded grid represents a single, connected component.
|
631
|
+
|
632
|
+
Args:
|
633
|
+
z (np.ndarray): A binary grid where the component connectivity is evaluated.
|
634
|
+
|
635
|
+
Returns:
|
636
|
+
bool: True if the grid represents a single connected component, False otherwise.
|
637
|
+
"""
|
638
|
+
_, num_features = label(z)
|
639
|
+
return num_features == 1 # Return True if only one connected component is found
|
640
|
+
|
641
|
+
|
642
|
+
def _calculate_bounding_box(
|
643
|
+
node_coordinates: np.ndarray, radius_margin: float = 1.05
|
644
|
+
) -> Tuple[np.ndarray, float]:
|
645
|
+
"""Calculate the bounding box of the network based on node coordinates.
|
646
|
+
|
647
|
+
Args:
|
648
|
+
node_coordinates (np.ndarray): Array of node coordinates (x, y).
|
649
|
+
radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.
|
650
|
+
|
651
|
+
Returns:
|
652
|
+
tuple: Center of the bounding box and the radius (adjusted by the radius margin).
|
653
|
+
"""
|
654
|
+
# Find minimum and maximum x, y coordinates
|
655
|
+
x_min, y_min = np.min(node_coordinates, axis=0)
|
656
|
+
x_max, y_max = np.max(node_coordinates, axis=0)
|
657
|
+
# Calculate the center of the bounding box
|
658
|
+
center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
|
659
|
+
# Calculate the radius of the bounding box, adjusted by the margin
|
660
|
+
radius = max(x_max - x_min, y_max - y_min) / 2 * radius_margin
|
661
|
+
return center, radius
|
662
|
+
|
663
|
+
|
664
|
+
def _best_label_positions(
    filtered_domains: Dict[str, Any], center: np.ndarray, radius: float, offset: float
) -> Dict[str, Any]:
    """Assign each domain a label position on the perimeter and optimize the assignment.

    Args:
        filtered_domains (dict): Centroids of the filtered domains, keyed by domain.
        center (np.ndarray): The center coordinates for label positioning.
        radius (float): The radius for positioning labels around the center.
        offset (float): The offset distance from the radius for positioning labels.

    Returns:
        dict: Optimized positions for labels.
    """
    # One evenly spaced perimeter slot per domain
    slots = _equidistant_angles_around_center(center, radius, offset, len(filtered_domains))
    # Initial assignment: domains take the slots in iteration order
    label_positions = dict(zip(filtered_domains.keys(), slots))
    # Let the optimizer rearrange the assignment against the domain centroids
    return _optimize_label_positions(label_positions, filtered_domains)
|
687
|
+
|
688
|
+
|
689
|
+
def _equidistant_angles_around_center(
|
690
|
+
center: np.ndarray, radius: float, label_offset: float, num_domains: int
|
691
|
+
) -> List[np.ndarray]:
|
692
|
+
"""Calculate positions around a center at equidistant angles.
|
693
|
+
|
694
|
+
Args:
|
695
|
+
center (np.ndarray): The central point around which positions are calculated.
|
696
|
+
radius (float): The radius at which positions are calculated.
|
697
|
+
label_offset (float): The offset added to the radius for label positioning.
|
698
|
+
num_domains (int): The number of positions (or domains) to calculate.
|
699
|
+
|
700
|
+
Returns:
|
701
|
+
list[np.ndarray]: List of positions (as 2D numpy arrays) around the center.
|
702
|
+
"""
|
703
|
+
# Calculate equidistant angles in radians around the center
|
704
|
+
angles = np.linspace(0, 2 * np.pi, num_domains, endpoint=False)
|
705
|
+
# Compute the positions around the center using the angles
|
706
|
+
return [
|
707
|
+
center + (radius + label_offset) * np.array([np.cos(angle), np.sin(angle)])
|
708
|
+
for angle in angles
|
709
|
+
]
|
710
|
+
|
711
|
+
|
712
|
+
def _optimize_label_positions(
|
713
|
+
best_label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
|
714
|
+
) -> Dict[str, Any]:
|
715
|
+
"""Optimize label positions around the perimeter to minimize total distance to centroids.
|
716
|
+
|
717
|
+
Args:
|
718
|
+
best_label_positions (dict): Initial positions of labels around the perimeter.
|
719
|
+
domain_centroids (dict): Centroid positions of the domains.
|
720
|
+
|
721
|
+
Returns:
|
722
|
+
dict: Optimized label positions.
|
723
|
+
"""
|
724
|
+
while True:
|
725
|
+
improvement = False # Start each iteration assuming no improvement
|
726
|
+
# Iterate through each pair of labels to check for potential improvements
|
727
|
+
for i in range(len(domain_centroids)):
|
728
|
+
for j in range(i + 1, len(domain_centroids)):
|
729
|
+
# Calculate the current total distance
|
730
|
+
current_distance = _calculate_total_distance(best_label_positions, domain_centroids)
|
731
|
+
# Evaluate the total distance after swapping two labels
|
732
|
+
swapped_distance = _swap_and_evaluate(best_label_positions, i, j, domain_centroids)
|
733
|
+
# If the swap improves the total distance, perform the swap
|
734
|
+
if swapped_distance < current_distance:
|
735
|
+
labels = list(best_label_positions.keys())
|
736
|
+
best_label_positions[labels[i]], best_label_positions[labels[j]] = (
|
737
|
+
best_label_positions[labels[j]],
|
738
|
+
best_label_positions[labels[i]],
|
739
|
+
)
|
740
|
+
improvement = True # Found an improvement, so continue optimizing
|
741
|
+
|
742
|
+
if not improvement:
|
743
|
+
break # Exit the loop if no improvement was found in this iteration
|
744
|
+
|
745
|
+
return best_label_positions
|
746
|
+
|
747
|
+
|
748
|
+
def _calculate_total_distance(
|
749
|
+
label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
|
750
|
+
) -> float:
|
751
|
+
"""Calculate the total distance from label positions to their domain centroids.
|
752
|
+
|
753
|
+
Args:
|
754
|
+
label_positions (dict): Positions of labels around the perimeter.
|
755
|
+
domain_centroids (dict): Centroid positions of the domains.
|
756
|
+
|
757
|
+
Returns:
|
758
|
+
float: The total distance from labels to centroids.
|
759
|
+
"""
|
760
|
+
total_distance = 0
|
761
|
+
# Iterate through each domain and calculate the distance to its centroid
|
762
|
+
for domain, pos in label_positions.items():
|
763
|
+
centroid = domain_centroids[domain]
|
764
|
+
total_distance += np.linalg.norm(centroid - pos)
|
765
|
+
|
766
|
+
return total_distance
|
767
|
+
|
768
|
+
|
769
|
+
def _swap_and_evaluate(
    label_positions: Dict[str, Any],
    i: int,
    j: int,
    domain_centroids: Dict[str, Any],
) -> float:
    """Compute the total label-to-centroid distance if two labels traded positions.

    Args:
        label_positions (dict): Positions of labels around the perimeter.
        i (int): Index (in key order) of the first label of the pair.
        j (int): Index (in key order) of the second label of the pair.
        domain_centroids (dict): Centroid positions of the domains.

    Returns:
        float: The total distance with the two labels' positions exchanged.
    """
    # Resolve the two indices to their label keys
    ordered_labels = list(label_positions)
    first, second = ordered_labels[i], ordered_labels[j]
    # Evaluate on a shallow copy so the caller's mapping is left untouched
    trial_positions = dict(label_positions)
    trial_positions[first] = label_positions[second]
    trial_positions[second] = label_positions[first]
    return _calculate_total_distance(trial_positions, domain_centroids)
|