risk-network 0.0.8b7__py3-none-any.whl → 0.0.8b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/network/plot/__init__.py +6 -0
- risk/network/plot/canvas.py +229 -0
- risk/network/plot/contour.py +319 -0
- risk/network/plot/labels.py +848 -0
- risk/network/plot/network.py +269 -0
- risk/network/plot/plotter.py +134 -0
- risk/network/plot/utils.py +153 -0
- risk/risk.py +3 -5
- {risk_network-0.0.8b7.dist-info → risk_network-0.0.8b9.dist-info}/METADATA +1 -1
- {risk_network-0.0.8b7.dist-info → risk_network-0.0.8b9.dist-info}/RECORD +14 -8
- risk/network/plot/base.py +0 -1809
- {risk_network-0.0.8b7.dist-info → risk_network-0.0.8b9.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b7.dist-info → risk_network-0.0.8b9.dist-info}/WHEEL +0 -0
- {risk_network-0.0.8b7.dist-info → risk_network-0.0.8b9.dist-info}/top_level.txt +0 -0
risk/network/plot/base.py
DELETED
@@ -1,1809 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
risk/network/plot
|
3
|
-
~~~~~~~~~~~~~~~~~
|
4
|
-
"""
|
5
|
-
|
6
|
-
from typing import Any, Dict, List, Tuple, Union
|
7
|
-
|
8
|
-
import matplotlib.colors as mcolors
|
9
|
-
import matplotlib.pyplot as plt
|
10
|
-
import networkx as nx
|
11
|
-
import numpy as np
|
12
|
-
import pandas as pd
|
13
|
-
from scipy import linalg
|
14
|
-
from scipy.ndimage import label
|
15
|
-
from scipy.stats import gaussian_kde
|
16
|
-
|
17
|
-
from risk.log import params, logger
|
18
|
-
from risk.network.graph import NetworkGraph
|
19
|
-
|
20
|
-
# Delimiter placed between individual domain terms when several are joined into one composite domain label
TERM_DELIMITER = "::::"
|
21
|
-
|
22
|
-
|
23
|
-
class NetworkPlotter:
|
24
|
-
"""A class for visualizing network graphs with customizable options.
|
25
|
-
|
26
|
-
The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
|
27
|
-
flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
|
28
|
-
perimeter, and adjusting background colors.
|
29
|
-
"""
|
30
|
-
|
31
|
-
def __init__(
    self,
    graph: NetworkGraph,
    figsize: Tuple = (10, 10),
    background_color: Union[str, List, Tuple, np.ndarray] = "white",
    background_alpha: Union[float, None] = 1.0,
) -> None:
    """Bind the plotter to a NetworkGraph and create the axis that all plot methods draw on.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        figsize (tuple, optional): Figure size in inches as (width, height). Defaults to (10, 10).
        background_color (str, list, tuple, np.ndarray, optional): Background color of the figure. Defaults to "white".
        background_alpha (float, None, optional): Alpha forwarded together with background_color when the
            background is converted to RGBA. Defaults to 1.0.
    """
    # Collect the plot settings once so the axis setup receives them as keywords
    plot_settings = {
        "graph": graph,
        "figsize": figsize,
        "background_color": background_color,
        "background_alpha": background_alpha,
    }
    # Keep the graph first; the axis is created from the same graph right after
    self.graph = graph
    self.ax = self._initialize_plot(**plot_settings)
|
55
|
-
|
56
|
-
def _initialize_plot(
    self,
    graph: NetworkGraph,
    figsize: Tuple,
    background_color: Union[str, List, Tuple, np.ndarray],
    background_alpha: Union[float, None],
) -> plt.Axes:
    """Create the figure and a bare, equal-aspect axis framed around the network.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        figsize (tuple): Figure size in inches as (width, height).
        background_color (str, list, tuple, np.ndarray): Background color of the figure.
        background_alpha (float, None): Alpha forwarded together with background_color when the
            background is converted to RGBA.

    Returns:
        plt.Axes: The axis object for the plot.
    """
    # Center and radius of the bounding region around all node coordinates
    center, radius = _calculate_bounding_box(graph.node_coordinates)
    center_x, center_y = center[0], center[1]
    margin = 0.3  # Fixed padding added on every side of the bounding radius

    # New figure with a single axis
    fig, ax = plt.subplots(figsize=figsize)
    fig.tight_layout()

    # Frame the view around the network and keep x/y units equal
    ax.set_xlim([center_x - radius - margin, center_x + radius + margin])
    ax.set_ylim([center_y - radius - margin, center_y + radius + margin])
    ax.set_aspect("equal")

    # Paint the figure background with the RGBA form of the requested color
    background_rgba = _to_rgba(color=background_color, alpha=background_alpha)
    fig.patch.set_facecolor(background_rgba)
    # Flip the y-axis so it runs top-to-bottom, as in image coordinates
    ax.invert_yaxis()

    # Strip every decoration: spines, ticks and the axis background patch
    for axis_spine in ax.spines.values():
        axis_spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.patch.set_visible(False)

    return ax
|
102
|
-
|
103
|
-
def plot_title(
    self,
    title: Union[str, None] = None,
    subtitle: Union[str, None] = None,
    title_fontsize: int = 20,
    subtitle_fontsize: int = 14,
    font: str = "Arial",
    title_color: str = "black",
    subtitle_color: str = "gray",
    title_y: float = 0.975,
    title_space_offset: float = 0.075,
    subtitle_offset: float = 0.025,
) -> None:
    """Plot title and subtitle on the network graph with customizable parameters.

    Args:
        title (str, optional): Title of the plot. Nothing is drawn when empty or None. Defaults to None.
        subtitle (str, optional): Subtitle of the plot. Nothing is drawn when empty or None. Defaults to None.
        title_fontsize (int, optional): Font size for the title. Defaults to 20.
        subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
        font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
        title_color (str, optional): Color of the title text. Defaults to "black".
        subtitle_color (str, optional): Color of the subtitle text. Defaults to "gray".
        title_y (float, optional): Y position of the title in figure coordinates. Defaults to 0.975.
        title_space_offset (float, optional): Fraction of figure height reserved above the plot. Defaults to 0.075.
        subtitle_offset (float, optional): Distance subtracted from title_y to position the subtitle. Defaults to 0.025.
    """
    # Log the title and subtitle parameters
    params.log_plotter(
        title=title,
        subtitle=subtitle,
        title_fontsize=title_fontsize,
        subtitle_fontsize=subtitle_fontsize,
        title_subtitle_font=font,
        title_color=title_color,
        subtitle_color=subtitle_color,
        subtitle_offset=subtitle_offset,
        title_y=title_y,
        title_space_offset=title_space_offset,
    )

    # All text is attached to the figure (not the axis) so it stays horizontally centered
    fig = self.ax.figure
    # Shrink the layout rectangle from the top to leave room for the title block
    fig.tight_layout(rect=[0, 0, 1, 1 - title_space_offset])

    if title:
        fig.suptitle(
            title,
            fontsize=title_fontsize,
            color=title_color,
            fontname=font,
            x=0.5,  # Center the title horizontally
            ha="center",
            va="top",
            y=title_y,
        )

    if subtitle:
        # The subtitle hangs a fixed offset below the title anchor
        subtitle_y_position = title_y - subtitle_offset
        fig.text(
            0.5,  # Center the subtitle horizontally
            subtitle_y_position,
            subtitle,
            ha="center",
            va="top",
            fontname=font,
            fontsize=subtitle_fontsize,
            color=subtitle_color,
        )
|
179
|
-
|
180
|
-
def plot_circle_perimeter(
    self,
    scale: float = 1.0,
    linestyle: str = "dashed",
    linewidth: float = 1.5,
    color: Union[str, List, Tuple, np.ndarray] = "black",
    outline_alpha: Union[float, None] = 1.0,
    fill_alpha: Union[float, None] = 0.0,
) -> None:
    """Draw a circle around the network to mark its perimeter.

    Args:
        scale (float, optional): Multiplier applied to the bounding radius of the network. Defaults to 1.0.
        linestyle (str, optional): Line style of the circle outline (e.g., dashed, solid). Defaults to "dashed".
        linewidth (float, optional): Width of the circle outline. Defaults to 1.5.
        color (str, list, tuple, or np.ndarray, optional): Color of the perimeter circle. Defaults to "black".
        outline_alpha (float, None, optional): Alpha forwarded with color when the outline color is converted
            to RGBA. Defaults to 1.0.
        fill_alpha (float, None, optional): Alpha of the circle fill; None is treated as 0.0 and the circle is
            only filled when the value is greater than 0. Defaults to 0.0.
    """
    # Record the perimeter settings; sequence-like colors are logged as "custom"
    is_custom_color = isinstance(color, (list, tuple, np.ndarray))
    params.log_plotter(
        perimeter_type="circle",
        perimeter_scale=scale,
        perimeter_linestyle=linestyle,
        perimeter_linewidth=linewidth,
        perimeter_color="custom" if is_custom_color else color,
        perimeter_outline_alpha=outline_alpha,
        perimeter_fill_alpha=fill_alpha,
    )

    # RGBA color for the outline, using outline_alpha
    outline_color = _to_rgba(color=color, alpha=outline_alpha)
    # A missing fill alpha means "no fill"
    if fill_alpha is None:
        fill_alpha = 0.0
    should_fill = fill_alpha > 0

    # Bounding center/radius of the network, with the radius stretched by scale
    center, radius = _calculate_bounding_box(self.graph.node_coordinates)
    perimeter_radius = radius * scale

    perimeter = plt.Circle(
        center,
        perimeter_radius,
        linestyle=linestyle,
        linewidth=linewidth,
        color=outline_color,
        fill=should_fill,
    )
    # When filling, give the face its own alpha while reusing the outline color
    if should_fill:
        perimeter.set_facecolor(_to_rgba(color=outline_color, alpha=fill_alpha))

    self.ax.add_artist(perimeter)
|
239
|
-
|
240
|
-
def plot_contour_perimeter(
    self,
    scale: float = 1.0,
    levels: int = 3,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, List, Tuple, np.ndarray] = "black",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    outline_alpha: Union[float, None] = 1.0,
    fill_alpha: Union[float, None] = 0.0,
) -> None:
    """Plot a KDE-based contour around the network graph to represent the network perimeter.

    Args:
        scale (float, optional): Factor the node coordinates are multiplied by before the contour is
            computed. Defaults to 1.0.
        levels (int, optional): Number of contour levels. Defaults to 3.
        bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
        grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
        color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
        linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
        linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
        outline_alpha (float, None, optional): Transparency level of the contour outline. Defaults to 1.0.
        fill_alpha (float, None, optional): Transparency level of the contour fill. Defaults to 0.0.
    """
    # Log the contour perimeter plotting parameters
    params.log_plotter(
        perimeter_type="contour",
        perimeter_scale=scale,
        perimeter_levels=levels,
        perimeter_bandwidth=bandwidth,
        perimeter_grid_size=grid_size,
        perimeter_linestyle=linestyle,
        perimeter_linewidth=linewidth,
        perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
        perimeter_outline_alpha=outline_alpha,
        perimeter_fill_alpha=fill_alpha,
    )

    # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
    color = _to_rgba(color=color, alpha=outline_alpha)
    # Extract node coordinates from the network graph and scale them
    node_coordinates = self.graph.node_coordinates
    scaled_coordinates = node_coordinates * scale
    # Outline and fill transparency are forwarded separately. Previously fill_alpha was passed as the
    # line alpha, which made the outline invisible with the default fill_alpha=0.0 and let the fill
    # fall back to the helper's own default instead of the value requested here.
    self._draw_kde_contour(
        ax=self.ax,
        pos=scaled_coordinates,
        nodes=list(range(len(node_coordinates))),  # All nodes are included
        levels=levels,
        bandwidth=bandwidth,
        grid_size=grid_size,
        color=color,
        linestyle=linestyle,
        linewidth=linewidth,
        alpha=outline_alpha,
        fill_alpha=fill_alpha,
    )
|
301
|
-
|
302
|
-
def plot_network(
    self,
    node_size: Union[int, np.ndarray] = 50,
    node_shape: str = "o",
    node_edgewidth: float = 1.0,
    edge_width: float = 1.0,
    node_color: Union[str, List, Tuple, np.ndarray] = "white",
    node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
    edge_color: Union[str, List, Tuple, np.ndarray] = "black",
    node_alpha: Union[float, None] = 1.0,
    edge_alpha: Union[float, None] = 1.0,
) -> None:
    """Draw every node and edge of the network with the given sizes, shapes, widths and colors.

    Args:
        node_size (int or np.ndarray, optional): Node size, either one value or an array of sizes. Defaults to 50.
        node_shape (str, optional): Marker shape of the nodes. Defaults to "o".
        node_edgewidth (float, optional): Width of the node outlines. Defaults to 1.0.
        edge_width (float, optional): Width of the edges. Defaults to 1.0.
        node_color (str, list, tuple, or np.ndarray, optional): Node fill color, either one color or an array
            of colors. Defaults to "white".
        node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node outlines; always
            converted with an alpha of 1.0. Defaults to "black".
        edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
        node_alpha (float, None, optional): Alpha forwarded with node_color during RGBA conversion. Defaults to 1.0.
        edge_alpha (float, None, optional): Alpha forwarded with edge_color during RGBA conversion. Defaults to 1.0.
    """
    # Record the settings; array-valued sizes and colors are logged as "custom"
    logged_node_size = "custom" if isinstance(node_size, np.ndarray) else node_size
    logged_node_color = "custom" if isinstance(node_color, np.ndarray) else node_color
    params.log_plotter(
        network_node_size=logged_node_size,
        network_node_shape=node_shape,
        network_node_edgewidth=node_edgewidth,
        network_edge_width=edge_width,
        network_node_color=logged_node_color,
        network_node_edgecolor=node_edgecolor,
        network_edge_color=edge_color,
        network_node_alpha=node_alpha,
        network_edge_alpha=edge_alpha,
    )

    # Expand each color to RGBA: node colors per node, edge colors per edge
    node_rgba = _to_rgba(
        color=node_color, alpha=node_alpha, num_repeats=len(self.graph.network.nodes)
    )
    node_outline_rgba = _to_rgba(
        color=node_edgecolor, alpha=1.0, num_repeats=len(self.graph.network.nodes)
    )
    edge_rgba = _to_rgba(
        color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
    )

    # Positions shared by the node and edge drawing calls
    positions = self.graph.node_coordinates

    # Nodes first, then edges, both on the plotter's axis
    nx.draw_networkx_nodes(
        self.graph.network,
        pos=positions,
        node_size=node_size,
        node_shape=node_shape,
        node_color=node_rgba,
        edgecolors=node_outline_rgba,
        linewidths=node_edgewidth,
        ax=self.ax,
    )
    nx.draw_networkx_edges(
        self.graph.network,
        pos=positions,
        width=edge_width,
        edge_color=edge_rgba,
        ax=self.ax,
    )
|
381
|
-
|
382
|
-
def plot_subnetwork(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    node_size: Union[int, np.ndarray] = 50,
    node_shape: str = "o",
    node_edgewidth: float = 1.0,
    edge_width: float = 1.0,
    node_color: Union[str, List, Tuple, np.ndarray] = "white",
    node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
    edge_color: Union[str, List, Tuple, np.ndarray] = "black",
    node_alpha: Union[float, None] = None,
    edge_alpha: Union[float, None] = None,
) -> None:
    """Plot a subnetwork of selected nodes with customizable node and edge attributes.

    Labels that are not present in the graph are silently dropped. When node_color is a list, it is
    read positionally against `nodes`, so each kept node receives the color at its own position.

    Args:
        nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
        node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
        node_shape (str, optional): Shape of the nodes. Defaults to "o".
        node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
        edge_width (float, optional): Width of the edges. Defaults to 1.0.
        node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
        node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges; always converted
            with an alpha of 1.0. Defaults to "black".
        edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
        node_alpha (float, None, optional): Alpha forwarded with node_color during RGBA conversion. Defaults to None.
        edge_alpha (float, None, optional): Alpha forwarded with edge_color during RGBA conversion. Defaults to None.

    Raises:
        ValueError: If no valid nodes are found in the network graph.
    """
    # Flatten nested lists of nodes, if necessary
    if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
        nodes = [node for sublist in nodes for node in sublist]

    # Pair each known label's position in `nodes` with its node ID (single pass, single lookup table)
    label_to_id = self.graph.node_label_to_node_id_map
    known_nodes = [
        (position, label_to_id[node])
        for position, node in enumerate(nodes)
        if node in label_to_id
    ]
    node_ids = [node_id for _, node_id in known_nodes]
    if not node_ids:
        raise ValueError("No nodes found in the network graph.")

    # A list of colors is filtered down to the kept nodes by position. Looking positions up with
    # nodes.index() broke for np.ndarray input and returned the first match for duplicated labels.
    if not isinstance(node_color, (str, tuple, np.ndarray)):
        node_color = [node_color[position] for position, _ in known_nodes]

    # Convert colors to RGBA using the _to_rgba helper function
    node_color = _to_rgba(color=node_color, alpha=node_alpha, num_repeats=len(node_ids))
    node_edgecolor = _to_rgba(color=node_edgecolor, alpha=1.0, num_repeats=len(node_ids))
    # NOTE(review): edge colors are expanded to the edge count of the full network although only the
    # subgraph's edges are drawn below - confirm this is intended.
    edge_color = _to_rgba(
        color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
    )

    # Get the coordinates of the filtered nodes
    node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}

    # Draw the nodes in the subnetwork
    nx.draw_networkx_nodes(
        self.graph.network,
        pos=node_coordinates,
        nodelist=node_ids,
        node_size=node_size,
        node_shape=node_shape,
        node_color=node_color,
        edgecolors=node_edgecolor,
        linewidths=node_edgewidth,
        ax=self.ax,
    )
    # Draw the edges between the specified nodes in the subnetwork
    subgraph = self.graph.network.subgraph(node_ids)
    nx.draw_networkx_edges(
        subgraph,
        pos=node_coordinates,
        width=edge_width,
        edge_color=edge_color,
        ax=self.ax,
    )
|
466
|
-
|
467
|
-
def plot_contours(
    self,
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, List, Tuple, np.ndarray] = "white",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    alpha: Union[float, None] = 1.0,
    fill_alpha: Union[float, None] = None,
) -> None:
    """Draw one KDE contour per domain of the graph, skipping domains with fewer than two nodes.

    Args:
        levels (int, optional): Number of contour levels to plot. Defaults to 5.
        bandwidth (float, optional): KDE bandwidth; controls the smoothness of the contour. Defaults to 0.8.
        grid_size (int, optional): Resolution of the KDE grid; higher values create finer contours. Defaults to 250.
        color (str, list, tuple, or np.ndarray, optional): Contour color, either one color or an array of
            colors. Defaults to "white".
        linestyle (str, optional): Line style for the contours. Defaults to "solid".
        linewidth (float, optional): Line width for the contours. Defaults to 1.5.
        alpha (float, None, optional): Transparency level of the contour lines. Defaults to 1.0.
        fill_alpha (float, None, optional): Transparency level of the contour fill. Defaults to None.
    """
    # Record the settings; array-valued colors are logged as "custom"
    logged_color = "custom" if isinstance(color, np.ndarray) else color
    params.log_plotter(
        contour_levels=levels,
        contour_bandwidth=bandwidth,
        contour_grid_size=grid_size,
        contour_color=logged_color,
        contour_alpha=alpha,
        contour_fill_alpha=fill_alpha,
    )

    # One RGBA color per domain, indexed in the map's iteration order
    domain_map = self.graph.domain_id_to_node_ids_map
    domain_colors = _to_rgba(color=color, alpha=alpha, num_repeats=len(domain_map))
    positions = self.graph.node_coordinates

    for domain_index, domain_node_ids in enumerate(domain_map.values()):
        # A single node (or none) cannot form a contour
        if len(domain_node_ids) <= 1:
            continue
        self._draw_kde_contour(
            self.ax,
            positions,
            domain_node_ids,
            color=domain_colors[domain_index],
            levels=levels,
            bandwidth=bandwidth,
            grid_size=grid_size,
            linestyle=linestyle,
            linewidth=linewidth,
            alpha=alpha,
            fill_alpha=fill_alpha,
        )
|
527
|
-
|
528
|
-
def plot_subcontour(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, List, Tuple, np.ndarray] = "white",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    alpha: Union[float, None] = 1.0,
    fill_alpha: Union[float, None] = None,
) -> None:
    """Draw a KDE contour around one group of nodes, or one contour per group for nested input.

    Args:
        nodes (list, tuple, or np.ndarray): Node labels, or a sequence of such sequences, to draw contours for.
        levels (int, optional): Number of contour levels to plot. Defaults to 5.
        bandwidth (float, optional): KDE bandwidth; controls the smoothness of the contour. Defaults to 0.8.
        grid_size (int, optional): Resolution of the KDE grid; higher values create finer contours. Defaults to 250.
        color (str, list, tuple, or np.ndarray, optional): Color of the contour. Defaults to "white".
        linestyle (str, optional): Line style for the contour. Defaults to "solid".
        linewidth (float, optional): Line width for the contour. Defaults to 1.5.
        alpha (float, None, optional): Transparency level of the contour lines. Defaults to 1.0.
        fill_alpha (float, None, optional): Transparency level of the contour fill. Defaults to None.

    Raises:
        ValueError: If a group resolves to fewer than two nodes that exist in the network graph.
    """
    # Nested input means several groups; flat input is wrapped as a single group
    is_nested = any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes)
    node_groups = nodes if is_nested else [nodes]

    # The same RGBA color is used for every group
    contour_color = _to_rgba(color=color, alpha=alpha)

    for group in node_groups:
        # Translate labels to node IDs, dropping labels the graph does not know
        group_ids = []
        for label in group:
            if label in self.graph.node_label_to_node_id_map:
                group_ids.append(self.graph.node_label_to_node_id_map.get(label))
        # At least two nodes are required for a contour
        if len(group_ids) < 2:
            raise ValueError(
                "No nodes found in the network graph or insufficient nodes to plot."
            )

        positions = self.graph.node_coordinates
        self._draw_kde_contour(
            self.ax,
            positions,
            group_ids,
            color=contour_color,
            levels=levels,
            bandwidth=bandwidth,
            grid_size=grid_size,
            linestyle=linestyle,
            linewidth=linewidth,
            alpha=alpha,
            fill_alpha=fill_alpha,
        )
|
598
|
-
|
599
|
-
def _draw_kde_contour(
    self,
    ax: plt.Axes,
    pos: np.ndarray,
    nodes: List,
    levels: int = 5,
    bandwidth: float = 0.8,
    grid_size: int = 250,
    color: Union[str, np.ndarray] = "white",
    linestyle: str = "solid",
    linewidth: float = 1.5,
    alpha: Union[float, None] = 1.0,
    fill_alpha: Union[float, None] = 0.2,
) -> None:
    """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.

    The bandwidth is increased in steps of 0.05 (up to 100.0) until `_is_connected` accepts the
    density grid. Only the first (outermost) contour level is drawn as a visible line; the
    remaining levels get a line width of 0 and show up only through the optional fill.

    Args:
        ax (plt.Axes): The axis to draw the contour on.
        pos (np.ndarray): Node positions, indexed by the entries of `nodes`.
        nodes (list): List of node indices to include in the contour.
        levels (int, optional): Number of contour levels. Defaults to 5.
        bandwidth (float, optional): Starting bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
        grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
        color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
        linestyle (str, optional): Line style for the contour. Defaults to "solid".
        linewidth (float, optional): Line width for the visible contour line. Defaults to 1.5.
        alpha (float, None, optional): Transparency level for the contour lines. Defaults to 1.0.
        fill_alpha (float, None, optional): Transparency level for the contour fill; no fill is drawn when
            it is None or not greater than 0. Defaults to 0.2.
    """
    # Extract the positions of the specified nodes
    points = np.array([pos[n] for n in nodes])
    if len(points) <= 1:
        return None  # Not enough points to form a contour

    # Widen the kernel until the density grid is reported as one connected region
    connected = False
    z = None  # Stays None when no KDE evaluation ever succeeds
    while not connected and bandwidth <= 100.0:
        try:
            # Perform KDE on the points with the current bandwidth
            kde = gaussian_kde(points.T, bw_method=bandwidth)
            # Pad the evaluation window by the bandwidth on every side
            xmin, ymin = points.min(axis=0) - bandwidth
            xmax, ymax = points.max(axis=0) + bandwidth
            x, y = np.mgrid[
                xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
            ]
            z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
            connected = _is_connected(z)
            if not connected:
                bandwidth += 0.05  # Increase bandwidth slightly and retry
        except linalg.LinAlgError:
            bandwidth += 0.05  # Increase bandwidth and retry
        except Exception as e:
            # Deliberate best-effort: log anything unexpected and skip this contour
            logger.error(f"Unexpected error when drawing KDE contour: {e}")
            return None

    # If z is still None, the KDE computation failed
    if z is None:
        logger.error("Failed to compute KDE. Skipping contour plot for these nodes.")
        return None

    # Define contour levels based on the density
    min_density, max_density = z.min(), z.max()
    if min_density == max_density:
        logger.warning(
            "Contour levels could not be created due to lack of variation in density."
        )
        return None

    # Evenly spaced levels between the density extremes, dropping the minimum itself
    contour_levels = np.linspace(min_density, max_density, levels)[1:]
    if len(contour_levels) < 2 or not np.all(np.diff(contour_levels) > 0):
        logger.error("Contour levels must be strictly increasing. Skipping contour plot.")
        return None

    # Every level uses the same color
    contour_colors = [color for _ in range(levels - 1)]
    # Plot the filled contours using fill_alpha for transparency
    if fill_alpha and fill_alpha > 0:
        ax.contourf(
            x,
            y,
            z,
            levels=contour_levels,
            colors=contour_colors,
            antialiased=True,
            alpha=fill_alpha,
        )

    # Only the base level keeps a visible line. The widths are passed per level here instead of
    # editing ContourSet.collections afterwards, which newer Matplotlib releases no longer provide.
    level_linewidths = [linewidth] + [0] * (len(contour_levels) - 1)
    ax.contour(
        x,
        y,
        z,
        levels=contour_levels,
        colors=contour_colors,
        linestyles=linestyle,
        linewidths=level_linewidths,
        alpha=alpha,
    )
|
707
|
-
|
708
|
-
def plot_labels(
    self,
    scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontsize: int = 10,
    fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
    fontalpha: Union[float, None] = 1.0,
    arrow_linewidth: float = 1,
    arrow_style: str = "->",
    arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
    arrow_alpha: Union[float, None] = 1.0,
    arrow_base_shrink: float = 0.0,
    arrow_tip_shrink: float = 0.0,
    max_labels: Union[int, None] = None,
    max_label_lines: Union[int, None] = None,
    min_label_lines: int = 1,
    max_chars_per_line: Union[int, None] = None,
    min_chars_per_line: int = 1,
    words_to_omit: Union[List, None] = None,
    overlay_ids: bool = False,
    ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
    ids_to_replace: Union[Dict, None] = None,
) -> None:
    """Annotate the network graph with labels for different domains, positioned around the network for clarity.

    Args:
        scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
        offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
        font (str, optional): Font name for the labels. Defaults to "Arial".
        fontsize (int, optional): Font size for the labels. Defaults to 10.
        fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array.
            Defaults to "black".
        fontalpha (float, None, optional): Transparency level for the font color. If provided, it overrides any existing alpha
            values found in fontcolor. Defaults to 1.0.
        arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
        arrow_style (str, optional): Style of the arrows pointing to centroids. Defaults to "->".
        arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
        arrow_alpha (float, None, optional): Transparency level for the arrow color. If provided, it overrides any existing alpha
            values found in arrow_color. Defaults to 1.0.
        arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
        arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
        max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
        max_label_lines (int, optional): Maximum number of lines in a label. Defaults to None (no limit).
        min_label_lines (int, optional): Minimum number of lines in a label. Defaults to 1.
        max_chars_per_line (int, optional): Maximum number of characters in a line to display. Defaults to None (no limit).
        min_chars_per_line (int, optional): Minimum number of characters in a line to display. Defaults to 1.
        words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
        overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
        ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
            you can set `overlay_ids=True`. Defaults to None.
        ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
            space-separated words. If provided, the custom labels will replace the default domain terms. To discover domain IDs, you
            can set `overlay_ids=True`. Defaults to None.

    Raises:
        ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
    """
    # Log the plotting parameters
    params.log_plotter(
        label_perimeter_scale=scale,
        label_offset=offset,
        label_font=font,
        label_fontsize=fontsize,
        label_fontcolor=(
            "custom" if isinstance(fontcolor, np.ndarray) else fontcolor
        ),  # np.ndarray usually indicates custom colors
        label_fontalpha=fontalpha,
        label_arrow_linewidth=arrow_linewidth,
        label_arrow_style=arrow_style,
        label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
        label_arrow_alpha=arrow_alpha,
        label_arrow_base_shrink=arrow_base_shrink,
        label_arrow_tip_shrink=arrow_tip_shrink,
        label_max_labels=max_labels,
        label_min_label_lines=min_label_lines,
        label_max_label_lines=max_label_lines,
        label_max_chars_per_line=max_chars_per_line,
        label_min_chars_per_line=min_chars_per_line,
        label_words_to_omit=words_to_omit,
        label_overlay_ids=overlay_ids,
        label_ids_to_keep=ids_to_keep,
        label_ids_to_replace=ids_to_replace,
    )

    # Convert ids_to_keep to a tuple. Compare against None explicitly: the truth value of a
    # multi-element np.ndarray (an accepted input type) is ambiguous and raises ValueError.
    ids_to_keep = tuple(ids_to_keep) if ids_to_keep is not None else tuple()
    # Set max_labels to the total number of domains if not provided (None)
    if max_labels is None:
        max_labels = len(self.graph.domain_id_to_node_ids_map)
    # Set max_label_lines and max_chars_per_line to large numbers if not provided (None)
    if max_label_lines is None:
        max_label_lines = int(1e6)
    if max_chars_per_line is None:
        max_chars_per_line = int(1e6)
    # Normalize words_to_omit to lowercase
    if words_to_omit:
        words_to_omit = set(word.lower() for word in words_to_omit)

    # Calculate the centroid of each non-empty domain; empty domains get no entry
    domain_id_to_centroid_map = {}
    for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items():
        if node_ids:  # Skip if the domain has no nodes
            domain_id_to_centroid_map[domain_id] = self._calculate_domain_centroid(node_ids)

    # Output containers filled in place by the two processing helpers below
    valid_indices = []  # List of valid indices to plot colors and arrows
    filtered_domain_centroids = {}  # Filtered domain centroids to plot
    filtered_domain_terms = {}  # Filtered domain terms to plot
    # Handle the ids_to_keep logic
    if ids_to_keep:
        # Process the ids_to_keep first INPLACE
        self._process_ids_to_keep(
            domain_id_to_centroid_map=domain_id_to_centroid_map,
            ids_to_keep=ids_to_keep,
            ids_to_replace=ids_to_replace,
            words_to_omit=words_to_omit,
            max_labels=max_labels,
            min_label_lines=min_label_lines,
            max_label_lines=max_label_lines,
            min_chars_per_line=min_chars_per_line,
            max_chars_per_line=max_chars_per_line,
            filtered_domain_centroids=filtered_domain_centroids,
            filtered_domain_terms=filtered_domain_terms,
            valid_indices=valid_indices,
        )

    # Calculate remaining labels to plot after processing ids_to_keep
    remaining_labels = (
        max_labels - len(valid_indices) if valid_indices and max_labels else max_labels
    )
    # Process remaining domains INPLACE to fill in additional labels, if there are slots left
    if remaining_labels and remaining_labels > 0:
        self._process_remaining_domains(
            domain_id_to_centroid_map=domain_id_to_centroid_map,
            ids_to_keep=ids_to_keep,
            ids_to_replace=ids_to_replace,
            words_to_omit=words_to_omit,
            remaining_labels=remaining_labels,
            min_chars_per_line=min_chars_per_line,
            max_chars_per_line=max_chars_per_line,
            max_label_lines=max_label_lines,
            min_label_lines=min_label_lines,
            filtered_domain_centroids=filtered_domain_centroids,
            filtered_domain_terms=filtered_domain_terms,
            valid_indices=valid_indices,
        )

    # Calculate the bounding box around the network
    center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
    # Calculate the best positions for labels
    best_label_positions = _calculate_best_label_positions(
        filtered_domain_centroids, center, radius, offset
    )
    # Convert all domain colors to RGBA using the _to_rgba helper function
    fontcolor = _to_rgba(
        color=fontcolor, alpha=fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
    )
    arrow_color = _to_rgba(
        color=arrow_color,
        alpha=arrow_alpha,
        num_repeats=len(self.graph.domain_id_to_node_ids_map),
    )

    # Annotate the network with labels
    # NOTE(review): pairing by zip assumes best_label_positions preserves the insertion
    # order of filtered_domain_centroids (same order as valid_indices) - confirm in helper.
    for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
        centroid = filtered_domain_centroids[domain]
        # Split by special key TERM_DELIMITER to split annotation into multiple lines
        annotations = filtered_domain_terms[domain].split(TERM_DELIMITER)
        self.ax.annotate(
            "\n".join(annotations),
            xy=centroid,
            xytext=pos,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=fontcolor[idx],
            arrowprops=dict(
                arrowstyle=arrow_style,
                color=arrow_color[idx],
                linewidth=arrow_linewidth,
                shrinkA=arrow_base_shrink,
                shrinkB=arrow_tip_shrink,
            ),
        )

    # Overlay domain ID at the centroid regardless of max_labels if requested
    if overlay_ids:
        for idx, domain in enumerate(self.graph.domain_id_to_node_ids_map):
            # Domains without nodes have no centroid (skipped above); nothing to overlay.
            # enumerate still advances so idx stays aligned with the per-domain colors.
            if domain not in domain_id_to_centroid_map:
                continue
            centroid = domain_id_to_centroid_map[domain]
            self.ax.text(
                centroid[0],
                centroid[1],
                domain,
                ha="center",
                va="center",
                fontsize=fontsize,
                fontname=font,
                color=fontcolor[idx],
                alpha=fontalpha,
            )
|
911
|
-
|
912
|
-
def plot_sublabel(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    label: str,
    radial_position: float = 0.0,
    scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontsize: int = 10,
    fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
    fontalpha: Union[float, None] = 1.0,
    arrow_linewidth: float = 1,
    arrow_style: str = "->",
    arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
    arrow_alpha: Union[float, None] = 1.0,
    arrow_base_shrink: float = 0.0,
    arrow_tip_shrink: float = 0.0,
) -> None:
    """Place one label on the perimeter and draw an arrow from it to the centroid of every node group.

    Args:
        nodes (list, tuple, or np.ndarray): Node labels, either as one flat collection (a single group)
            or as a collection of collections (one group per inner collection).
        label (str): Text of the annotation.
        radial_position (float, optional): Angle, in degrees, at which the label is placed. Defaults to 0.0.
        scale (float, optional): Scale factor passed as the bounding-box radius margin. Defaults to 1.05.
        offset (float, optional): Extra distance added to the bounding radius for the label. Defaults to 0.10.
        font (str, optional): Font name for the label. Defaults to "Arial".
        fontsize (int, optional): Font size for the label. Defaults to 10.
        fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
        fontalpha (float, None, optional): Transparency level for the font color. If provided, it overrides any
            existing alpha values found in fontcolor. Defaults to 1.0.
        arrow_linewidth (float, optional): Line width of the arrows. Defaults to 1.
        arrow_style (str, optional): Style of the arrows. Defaults to "->".
        arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
        arrow_alpha (float, None, optional): Transparency level for the arrow color. If provided, it overrides any
            existing alpha values found in arrow_color. Defaults to 1.0.
        arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
        arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.

    Raises:
        ValueError: If a group resolves to fewer than two nodes known to the network graph.
    """
    # A nested container means several groups; a flat one is wrapped as a single group
    has_sublists = any(isinstance(entry, (list, tuple, np.ndarray)) for entry in nodes)
    node_groups = nodes if has_sublists else [nodes]

    # Resolve text and arrow colors to RGBA
    text_rgba = _to_rgba(color=fontcolor, alpha=fontalpha)
    arrow_rgba = _to_rgba(color=arrow_color, alpha=arrow_alpha)

    # Bounding circle of the network, then the label anchor on (radius + offset)
    center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
    # The requested angle is rotated by -90 degrees before conversion to radians
    angle = np.deg2rad(radial_position - 90)
    label_position = (
        center[0] + (radius + offset) * np.cos(angle),
        center[1] + (radius + offset) * np.sin(angle),
    )

    label_to_id = self.graph.node_label_to_node_id_map
    for group in node_groups:
        # Keep only labels the graph knows about, translated to node IDs
        group_ids = [label_to_id.get(name) for name in group if name in label_to_id]
        if len(group_ids) < 2:
            raise ValueError(
                "No nodes found in the network graph or insufficient nodes to plot."
            )

        # One arrow per group, all sharing the same label position
        group_centroid = self._calculate_domain_centroid(group_ids)
        arrow_properties = dict(
            arrowstyle=arrow_style,
            color=arrow_rgba,
            linewidth=arrow_linewidth,
            shrinkA=arrow_base_shrink,
            shrinkB=arrow_tip_shrink,
        )
        self.ax.annotate(
            label,
            xy=group_centroid,
            xytext=label_position,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=text_rgba,
            arrowprops=arrow_properties,
        )
|
1006
|
-
|
1007
|
-
def _calculate_domain_centroid(self, nodes: List) -> tuple:
    """Return the coordinates of the most centrally located node among the given nodes.

    The central node is the one whose summed Euclidean distance to all other
    given nodes is smallest (the first such node on ties).

    Args:
        nodes (list): Node indices used to select rows of `self.graph.node_coordinates`.

    Returns:
        tuple: Coordinates of the central node (the selected row of the coordinate array).
    """
    # Rows of the coordinate array belonging to the requested nodes
    coords = self.graph.node_coordinates[nodes, :]
    # Broadcasting yields every pairwise difference; the norm over the last axis gives distances
    pairwise = np.linalg.norm(coords[:, np.newaxis] - coords, axis=2)
    # Total distance from each node to all others; the smallest total marks the central node
    totals = pairwise.sum(axis=1)
    return coords[np.argmin(totals)]
|
1027
|
-
|
1028
|
-
def _process_ids_to_keep(
    self,
    domain_id_to_centroid_map: Dict[str, np.ndarray],
    ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
    ids_to_replace: Union[Dict[str, str], None],
    words_to_omit: Union[List[str], None],
    max_labels: Union[int, None],
    min_label_lines: int,
    max_label_lines: int,
    min_chars_per_line: int,
    max_chars_per_line: int,
    filtered_domain_centroids: Dict[str, np.ndarray],
    filtered_domain_terms: Dict[str, str],
    valid_indices: List[int],
) -> None:
    """Validate each requested domain ID and record the ones that yield a usable label.

    Args:
        domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
        ids_to_keep (list, tuple, or np.ndarray): IDs of domains that must be labeled.
        ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels.
        words_to_omit (list, optional): Words to omit from the labels.
        max_labels (int, optional): Maximum number of labels allowed.
        min_label_lines (int): Minimum number of lines in a label.
        max_label_lines (int): Maximum number of lines in a label.
        min_chars_per_line (int): Minimum number of characters in a line to display.
        max_chars_per_line (int): Maximum number of characters in a line to display.
        filtered_domain_centroids (dict): Receives the centroids of accepted domains (output).
        filtered_domain_terms (dict): Receives the terms of accepted domains (output).
        valid_indices (list): Receives the indices of accepted domains (output).

    Note:
        The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.

    Raises:
        ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
    """
    # Reject requests that cannot fit within the label budget
    if max_labels is not None and len(ids_to_keep) > max_labels:
        raise ValueError(
            f"Number of provided IDs ({len(ids_to_keep)}) exceeds max_labels ({max_labels})."
        )

    for domain in ids_to_keep:
        # Silently skip IDs without known terms or without a centroid
        if domain not in self.graph.domain_id_to_domain_terms_map:
            continue
        if domain not in domain_id_to_centroid_map:
            continue

        # The validity flag is ignored here: requested IDs do not count against a quota
        self._validate_and_update_domain(
            domain=domain,
            domain_centroid=domain_id_to_centroid_map[domain],
            domain_id_to_centroid_map=domain_id_to_centroid_map,
            ids_to_replace=ids_to_replace,
            words_to_omit=words_to_omit,
            min_label_lines=min_label_lines,
            max_label_lines=max_label_lines,
            min_chars_per_line=min_chars_per_line,
            max_chars_per_line=max_chars_per_line,
            filtered_domain_centroids=filtered_domain_centroids,
            filtered_domain_terms=filtered_domain_terms,
            valid_indices=valid_indices,
        )
|
1093
|
-
|
1094
|
-
def _process_remaining_domains(
    self,
    domain_id_to_centroid_map: Dict[str, np.ndarray],
    ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
    ids_to_replace: Union[Dict[str, str], None],
    words_to_omit: Union[List[str], None],
    remaining_labels: int,
    min_label_lines: int,
    max_label_lines: int,
    min_chars_per_line: int,
    max_chars_per_line: int,
    filtered_domain_centroids: Dict[str, np.ndarray],
    filtered_domain_terms: Dict[str, str],
    valid_indices: List[int],
) -> None:
    """Process remaining domains to fill in additional labels, respecting the remaining_labels limit.

    Domains are chosen greedily so that they are spread out: starting from the first
    remaining domain, each step adds the domain whose minimum centroid distance to the
    already selected domains is largest.

    Args:
        domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
        ids_to_keep (list, tuple, or np.ndarray): IDs of domains that were already handled and are excluded here.
        ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels.
        words_to_omit (list, optional): List of words to omit from the labels.
        remaining_labels (int): The remaining number of labels that can be generated.
        min_label_lines (int): Minimum number of lines in a label.
        max_label_lines (int): Maximum number of lines in a label.
        min_chars_per_line (int): Minimum number of characters in a line to display.
        max_chars_per_line (int): Maximum number of characters in a line to display.
        filtered_domain_centroids (dict): Dictionary to store filtered domain centroids (output).
        filtered_domain_terms (dict): Dictionary to store filtered domain terms (output).
        valid_indices (list): List to store valid indices (output).

    Note:
        The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
    """
    # Counter to track how many labels have been created
    label_count = 0
    # Collect domains not in ids_to_keep, dropping missing (NaN/None) domain IDs
    remaining_domains = {
        domain: centroid
        for domain, centroid in domain_id_to_centroid_map.items()
        if domain not in ids_to_keep and not pd.isna(domain)
    }

    # Nothing to do without candidates or without label slots
    if not remaining_domains or not remaining_labels:
        return

    # Seed the selection with the first remaining domain
    selected_domains = [next(iter(remaining_domains))]

    while len(selected_domains) < remaining_labels:
        farthest_domain = None
        max_distance = -1
        # Find the domain farthest from any already selected domain
        for candidate_domain, candidate_centroid in remaining_domains.items():
            if candidate_domain in selected_domains:
                continue

            # Calculate the minimum distance to any selected domain
            min_distance = min(
                np.linalg.norm(candidate_centroid - remaining_domains[dom])
                for dom in selected_domains
            )
            # Update the farthest domain if the minimum distance is greater
            if min_distance > max_distance:
                max_distance = min_distance
                farthest_domain = candidate_domain

        # Compare against None rather than truthiness: a falsy domain ID such as 0
        # is a legitimate selection and must not end the search early.
        if farthest_domain is None:
            break  # No more domains to select
        selected_domains.append(farthest_domain)

    # Process the selected domains and add to filtered lists
    for domain in selected_domains:
        is_domain_valid = self._validate_and_update_domain(
            domain=domain,
            domain_centroid=remaining_domains[domain],
            domain_id_to_centroid_map=domain_id_to_centroid_map,
            ids_to_replace=ids_to_replace,
            words_to_omit=words_to_omit,
            min_label_lines=min_label_lines,
            max_label_lines=max_label_lines,
            min_chars_per_line=min_chars_per_line,
            max_chars_per_line=max_chars_per_line,
            filtered_domain_centroids=filtered_domain_centroids,
            filtered_domain_terms=filtered_domain_terms,
            valid_indices=valid_indices,
        )
        # Increment the label count if the domain is valid
        if is_domain_valid:
            label_count += 1
            if label_count >= remaining_labels:
                break
|
1193
|
-
|
1194
|
-
def _validate_and_update_domain(
    self,
    domain: str,
    domain_centroid: np.ndarray,
    domain_id_to_centroid_map: Dict[str, np.ndarray],
    ids_to_replace: Union[Dict[str, str], None],
    words_to_omit: Union[List[str], None],
    min_label_lines: int,
    max_label_lines: int,
    min_chars_per_line: int,
    max_chars_per_line: int,
    filtered_domain_centroids: Dict[str, np.ndarray],
    filtered_domain_terms: Dict[str, str],
    valid_indices: List[int],
) -> bool:
    """Build the label text for a domain and, if it is long enough, register the domain for plotting.

    Args:
        domain (str): Domain ID to process.
        domain_centroid (np.ndarray): Centroid position of the domain.
        domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids; the position of
            `domain` among its keys is what gets appended to `valid_indices`.
        ids_to_replace (Union[Dict[str, str], None]): A dictionary mapping domain IDs to custom labels.
        words_to_omit (Union[List[str], None]): Words to omit from the labels.
        min_label_lines (int): Minimum number of lines required in a label.
        max_label_lines (int): Maximum number of lines allowed in a label.
        min_chars_per_line (int): Minimum number of characters allowed per line.
        max_chars_per_line (int): Maximum number of characters allowed per line.
        filtered_domain_centroids (Dict[str, np.ndarray]): Receives the centroid of an accepted domain.
        filtered_domain_terms (Dict[str, str]): Receives the label text of an accepted domain.
        valid_indices (List[int]): Receives the index of an accepted domain.

    Returns:
        bool: True if the domain was accepted and recorded, False otherwise.

    Note:
        The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
    """
    label_text = self._process_terms(
        domain=domain,
        ids_to_replace=ids_to_replace,
        words_to_omit=words_to_omit,
        max_label_lines=max_label_lines,
        min_chars_per_line=min_chars_per_line,
        max_chars_per_line=max_chars_per_line,
    )
    # No usable terms: nothing to record
    if not label_text:
        return False

    # Each TERM_DELIMITER-separated piece is one line of the label
    line_count = len(label_text.split(TERM_DELIMITER))
    if line_count < min_label_lines:
        return False

    filtered_domain_centroids[domain] = domain_centroid
    filtered_domain_terms[domain] = label_text
    # Record the domain's position within the centroid map's key order
    valid_indices.append(list(domain_id_to_centroid_map).index(domain))
    return True
|
1255
|
-
|
1256
|
-
def _process_terms(
    self,
    domain: str,
    ids_to_replace: Union[Dict[str, str], None],
    words_to_omit: Union[List[str], None],
    max_label_lines: int,
    min_chars_per_line: int,
    max_chars_per_line: int,
) -> List[str]:
    """Process terms for a domain, applying word length constraints and combining words where appropriate.

    Args:
        domain (str): The domain being processed.
        ids_to_replace (dict, optional): Dictionary mapping domain IDs to custom labels. When it contains
            `domain`, its space-separated label replaces the domain's default terms.
        words_to_omit (list, optional): Words to omit from the labels; compared against each
            lowercased term, so entries are expected to be lowercase.
        max_label_lines (int): Maximum number of lines in a label.
        min_chars_per_line (int): Minimum number of characters a word must have to be kept.
        max_chars_per_line (int): Maximum number of characters in a line to display.

    Returns:
        str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
            NOTE(review): the annotation says List[str]; the value is whatever `_combine_words` returns - confirm.
    """
    # Custom labels take precedence over the default domain terms
    if ids_to_replace and domain in ids_to_replace:
        terms = ids_to_replace[domain].split(" ")
    else:
        terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")

    # Apply the word length constraint unconditionally; previously it was only applied
    # when words_to_omit was given, so min_chars_per_line was silently ignored otherwise.
    omitted_words = words_to_omit if words_to_omit else ()
    terms = [
        term
        for term in terms
        if len(term) >= min_chars_per_line and term.lower() not in omitted_words
    ]

    # Use the combine_words function directly to handle word combinations and length constraints
    return _combine_words(tuple(terms), max_chars_per_line, max_label_lines)
|
1296
|
-
|
1297
|
-
def get_annotated_node_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    alpha: Union[float, None] = 1.0,
    nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
    nonenriched_alpha: Union[float, None] = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Adjust the colors of nodes in the network graph based on enrichment.

    Args:
        cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
        color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
        scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
        alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values
            found in color. If None, the alpha values returned with the domain colors are kept. Defaults to 1.0.
        nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
        nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing
            alpha values found in nonenriched_color. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors adjusted for enrichment status.
    """
    # Get the initial domain colors for each node, which are returned as RGBA
    network_colors = self.graph.get_domain_colors(
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )
    # Override the A channel only when an alpha is given; assigning None into a float
    # array would store NaN instead of leaving the existing alpha values untouched.
    if alpha is not None:
        network_colors[:, 3] = alpha
    # Convert the non-enriched color to RGBA using the _to_rgba helper function
    nonenriched_color = _to_rgba(color=nonenriched_color, alpha=nonenriched_alpha)
    # Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
    is_black = np.all(network_colors[:, :3] == 0, axis=1, keepdims=True)  # Check RGB values only
    adjusted_network_colors = np.where(
        is_black,
        np.array(nonenriched_color),  # Apply the non-enriched color with alpha
        network_colors,  # Keep the original colors for enriched nodes
    )
    return adjusted_network_colors
|
1347
|
-
|
1348
|
-
def get_annotated_node_sizes(
    self, enriched_size: int = 50, nonenriched_size: int = 25
) -> np.ndarray:
    """Build a per-node size array that distinguishes enriched from non-enriched nodes.

    A node counts as enriched when it appears in at least one entry of
    `self.graph.domain_id_to_node_ids_map`.

    Args:
        enriched_size (int): Size assigned to enriched nodes. Defaults to 50.
        nonenriched_size (int): Size assigned to every other node. Defaults to 25.

    Returns:
        np.ndarray: One size per node in `self.graph.network.nodes`.
    """
    network_nodes = self.graph.network.nodes
    # Every node starts out with the non-enriched size
    sizes = np.full(len(network_nodes), nonenriched_size)
    # Flatten all domain memberships into a single set of node IDs
    enriched_ids = {
        node_id
        for member_ids in self.graph.domain_id_to_node_ids_map.values()
        for node_id in member_ids
    }
    # Node IDs are used directly as positions in the size array
    for node_id in enriched_ids:
        if node_id in network_nodes:
            sizes[node_id] = enriched_size

    return sizes
|
1373
|
-
|
1374
|
-
def get_annotated_contour_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Return one color per domain for drawing contours.

    This is a thin public entry point; every argument is forwarded unchanged to
    `_get_annotated_domain_colors`.

    Args:
        cmap (str, optional): Colormap name used when no explicit color is given. Defaults to "gist_rainbow".
        color (str or None, optional): Explicit contour color. If None, the colormap is used. Defaults to None.
        min_scale (float, optional): Lower bound of the color intensity scale. Defaults to 0.8.
        max_scale (float, optional): Upper bound of the color intensity scale. Defaults to 1.0.
        scale_factor (float, optional): Factor forwarded for adjusting color scaling. Defaults to 1.0.
        random_seed (int, optional): Seed forwarded for reproducible color generation. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors for contour annotations.
    """
    color_options = {
        "cmap": cmap,
        "color": color,
        "min_scale": min_scale,
        "max_scale": max_scale,
        "scale_factor": scale_factor,
        "random_seed": random_seed,
    }
    return self._get_annotated_domain_colors(**color_options)
|
1407
|
-
|
1408
|
-
def get_annotated_label_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Return one color per domain for drawing labels.

    This is a thin public entry point; every argument is forwarded unchanged to
    `_get_annotated_domain_colors`.

    Args:
        cmap (str, optional): Colormap name used when no explicit color is given. Defaults to "gist_rainbow".
        color (str or None, optional): Explicit label color. If None, the colormap is used. Defaults to None.
        min_scale (float, optional): Lower bound of the color intensity scale. Defaults to 0.8.
        max_scale (float, optional): Upper bound of the color intensity scale. Defaults to 1.0.
        scale_factor (float, optional): Factor forwarded for adjusting color scaling. Defaults to 1.0.
        random_seed (int, optional): Seed forwarded for reproducible color generation. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors for label annotations.
    """
    color_options = {
        "cmap": cmap,
        "color": color,
        "min_scale": min_scale,
        "max_scale": max_scale,
        "scale_factor": scale_factor,
        "random_seed": random_seed,
    }
    return self._get_annotated_domain_colors(**color_options)
|
1441
|
-
|
1442
|
-
def _get_annotated_domain_colors(
|
1443
|
-
self,
|
1444
|
-
cmap: str = "gist_rainbow",
|
1445
|
-
color: Union[str, None] = None,
|
1446
|
-
min_scale: float = 0.8,
|
1447
|
-
max_scale: float = 1.0,
|
1448
|
-
scale_factor: float = 1.0,
|
1449
|
-
random_seed: int = 888,
|
1450
|
-
) -> np.ndarray:
|
1451
|
-
"""Get colors for the domains based on node annotations, or use a specified color.
|
1452
|
-
|
1453
|
-
Args:
|
1454
|
-
cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
|
1455
|
-
color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
|
1456
|
-
min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
|
1457
|
-
Defaults to 0.8.
|
1458
|
-
max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
|
1459
|
-
Defaults to 1.0.
|
1460
|
-
scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
|
1461
|
-
enrichment. Higher values increase the contrast. Defaults to 1.0.
|
1462
|
-
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
1463
|
-
|
1464
|
-
Returns:
|
1465
|
-
np.ndarray: Array of RGBA colors for each domain.
|
1466
|
-
"""
|
1467
|
-
# Generate domain colors based on the enrichment data
|
1468
|
-
node_colors = self.graph.get_domain_colors(
|
1469
|
-
cmap=cmap,
|
1470
|
-
color=color,
|
1471
|
-
min_scale=min_scale,
|
1472
|
-
max_scale=max_scale,
|
1473
|
-
scale_factor=scale_factor,
|
1474
|
-
random_seed=random_seed,
|
1475
|
-
)
|
1476
|
-
annotated_colors = []
|
1477
|
-
for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
1478
|
-
if len(node_ids) > 1:
|
1479
|
-
# For multi-node domains, choose the brightest color based on RGB sum
|
1480
|
-
domain_colors = np.array([node_colors[node] for node in node_ids])
|
1481
|
-
brightest_color = domain_colors[
|
1482
|
-
np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
|
1483
|
-
]
|
1484
|
-
annotated_colors.append(brightest_color)
|
1485
|
-
else:
|
1486
|
-
# Single-node domains default to white (RGBA)
|
1487
|
-
default_color = np.array([1.0, 1.0, 1.0, 1.0])
|
1488
|
-
annotated_colors.append(default_color)
|
1489
|
-
|
1490
|
-
return np.array(annotated_colors)
|
1491
|
-
|
1492
|
-
@staticmethod
def savefig(*args, pad_inches: float = 0.5, dpi: int = 100, **kwargs) -> None:
    """Save the current plot to a file with additional export options.

    Args:
        *args: Positional arguments passed to `plt.savefig`.
        pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
        dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
        **kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
            `bbox_inches` defaults to "tight" and may be overridden here.
    """
    # "tight" is applied as a default instead of a fixed argument, so a caller-supplied
    # `bbox_inches` is honored rather than raising a duplicate-keyword TypeError.
    kwargs.setdefault("bbox_inches", "tight")
    plt.savefig(*args, pad_inches=pad_inches, dpi=dpi, **kwargs)
|
1503
|
-
|
1504
|
-
@staticmethod
def show(*args, **kwargs) -> None:
    """Render the current plot on screen.

    This is a direct pass-through to `plt.show`; nothing is returned.

    Args:
        *args: Positional arguments forwarded to `plt.show`.
        **kwargs: Keyword arguments forwarded to `plt.show`.
    """
    plt.show(*args, **kwargs)
|
1513
|
-
|
1514
|
-
|
1515
|
-
def _to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert color(s) to RGBA format, applying alpha and repeating as needed.

    A string, or a flat sequence of 3 or 4 numeric channel values (list, tuple, or 1D array),
    is treated as a single color. Any other list, tuple, or array is treated as a collection
    of colors, each of which may itself be a string or an RGB/RGBA sequence.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
        alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values
            found in color.
        num_repeats (int, None, optional): If provided, the color(s) will be repeated this many times. A single color is
            tiled; a collection of colors is cycled. Defaults to None.

    Returns:
        np.ndarray: 2D array of RGBA colors (one row per color), repeated to `num_repeats` rows if applicable.

    Raises:
        ValueError: If `color`, or any entry of it, is not a valid string or RGB/RGBA sequence.
    """

    def is_channel_value(value: Any) -> bool:
        """Return True for a plain numeric channel value (not a string or nested sequence)."""
        return isinstance(value, (int, float, np.integer, np.floating))

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
        # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
        is_named = isinstance(c, str)
        is_rgb_like = isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]
        if not (is_named or is_rgb_like):
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )

        rgba = np.array(mcolors.to_rgba(c))
        if alpha is not None:  # Override alpha if provided
            rgba[3] = alpha
        return rgba

    if isinstance(color, str):
        is_single_color = True
    elif isinstance(color, (list, tuple, np.ndarray)):
        if len(color) == 0:
            raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
        # Only a flat run of 3 or 4 numbers is one RGB/RGBA color. Checking the element type
        # (instead of "no nested sequences") keeps lists of color names such as
        # ["red", "blue"] and 1D numpy RGB arrays from being misclassified.
        is_single_color = len(color) in (3, 4) and all(is_channel_value(c) for c in color)
    else:
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    # Handle a single color (string or flat RGB/RGBA sequence)
    if is_single_color:
        rgba_color = convert_to_rgba(color)
        if num_repeats:
            # Repeat the color if num_repeats is provided
            return np.tile(rgba_color, (num_repeats, 1))
        return np.array([rgba_color])  # Return a single color wrapped in a numpy array

    # Handle a collection of colors: convert each entry to RGBA
    rgba_colors = np.array([convert_to_rgba(c) for c in color])
    if num_repeats:
        # Cycle through the converted colors until num_repeats rows are produced
        return np.array([rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)])

    return rgba_colors
|
1581
|
-
|
1582
|
-
|
1583
|
-
def _is_connected(z: np.ndarray) -> bool:
|
1584
|
-
"""Determine if a thresholded grid represents a single, connected component.
|
1585
|
-
|
1586
|
-
Args:
|
1587
|
-
z (np.ndarray): A binary grid where the component connectivity is evaluated.
|
1588
|
-
|
1589
|
-
Returns:
|
1590
|
-
bool: True if the grid represents a single connected component, False otherwise.
|
1591
|
-
"""
|
1592
|
-
_, num_features = label(z)
|
1593
|
-
return num_features == 1 # Return True if only one connected component is found
|
1594
|
-
|
1595
|
-
|
1596
|
-
def _calculate_bounding_box(
|
1597
|
-
node_coordinates: np.ndarray, radius_margin: float = 1.05
|
1598
|
-
) -> Tuple[np.ndarray, float]:
|
1599
|
-
"""Calculate the bounding box of the network based on node coordinates.
|
1600
|
-
|
1601
|
-
Args:
|
1602
|
-
node_coordinates (np.ndarray): Array of node coordinates (x, y).
|
1603
|
-
radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.
|
1604
|
-
|
1605
|
-
Returns:
|
1606
|
-
tuple: Center of the bounding box and the radius (adjusted by the radius margin).
|
1607
|
-
"""
|
1608
|
-
# Find minimum and maximum x, y coordinates
|
1609
|
-
x_min, y_min = np.min(node_coordinates, axis=0)
|
1610
|
-
x_max, y_max = np.max(node_coordinates, axis=0)
|
1611
|
-
# Calculate the center of the bounding box
|
1612
|
-
center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
|
1613
|
-
# Calculate the radius of the bounding box, adjusted by the margin
|
1614
|
-
radius = max(x_max - x_min, y_max - y_min) / 2 * radius_margin
|
1615
|
-
return center, radius
|
1616
|
-
|
1617
|
-
|
1618
|
-
def _combine_words(words: List[str], max_chars_per_line: int, max_label_lines: int) -> str:
    """Pack words into label lines and join the lines with TERM_DELIMITER.

    Words are packed greedily, in order, into lines of at most `max_chars_per_line`
    characters. Words are consumed in batches: the first `max_label_lines` words, then as many
    words as there are free lines, until the lines are full or the words run out.

    Args:
        words (List[str]): List of words to combine.
        max_chars_per_line (int): Maximum number of characters in a line to display.
        max_label_lines (int): Maximum number of lines in a label.

    Returns:
        str: The packed lines joined by TERM_DELIMITER, truncated to `max_label_lines` lines.
    """

    def pack_lines(batch: List[str]) -> List[str]:
        """Greedily merge consecutive words of one batch into space-separated lines."""
        lines = []
        position = 0
        batch_size = len(batch)
        while position < batch_size:
            line = batch[position]
            position += 1
            # Keep appending following words while the line (plus a space) still fits
            while (
                position < batch_size
                and len(line) + len(batch[position]) + 1 <= max_chars_per_line
            ):
                line = f"{line} {batch[position]}"
                position += 1

            # A line that is still too long (a single oversized word) is dropped
            if len(line) <= max_chars_per_line:
                lines.append(line)

            # Stop once this batch alone has produced the maximum number of lines
            if len(lines) >= max_label_lines:
                break

        return lines

    # Seed the label with the first max_label_lines words
    label_lines = pack_lines(words[:max_label_lines])
    pending_words = words[max_label_lines:]

    # Top up with further batches sized to the number of lines still free
    while pending_words and len(label_lines) < max_label_lines:
        free_lines = max_label_lines - len(label_lines)
        next_batch, pending_words = pending_words[:free_lines], pending_words[free_lines:]
        label_lines.extend(pack_lines(next_batch))

    # TERM_DELIMITER marks the line breaks in the final label string
    return TERM_DELIMITER.join(label_lines[:max_label_lines])
|
1673
|
-
|
1674
|
-
|
1675
|
-
def _calculate_best_label_positions(
    filtered_domain_centroids: Dict[str, Any], center: np.ndarray, radius: float, offset: float
) -> Dict[str, Any]:
    """Assign each domain a label position on a circle, then optimize the assignment.

    Args:
        filtered_domain_centroids (dict): Centroids of the filtered domains.
        center (np.ndarray): The center coordinates for label positioning.
        radius (float): The radius for positioning labels around the center.
        offset (float): The offset distance from the radius for positioning labels.

    Returns:
        dict: Label positions keyed by domain, as returned by `_optimize_label_positions`.
    """
    # One evenly spaced slot around the center per domain
    initial_positions = _calculate_equidistant_positions_around_center(
        center, radius, offset, len(filtered_domain_centroids)
    )
    # Pair domains with slots in dictionary order as the starting assignment
    label_positions = dict(zip(filtered_domain_centroids.keys(), initial_positions))
    # Let the optimizer swap slots between domains to reduce the distance to the centroids
    return _optimize_label_positions(label_positions, filtered_domain_centroids)
|
1701
|
-
|
1702
|
-
|
1703
|
-
def _calculate_equidistant_positions_around_center(
|
1704
|
-
center: np.ndarray, radius: float, label_offset: float, num_domains: int
|
1705
|
-
) -> List[np.ndarray]:
|
1706
|
-
"""Calculate positions around a center at equidistant angles.
|
1707
|
-
|
1708
|
-
Args:
|
1709
|
-
center (np.ndarray): The central point around which positions are calculated.
|
1710
|
-
radius (float): The radius at which positions are calculated.
|
1711
|
-
label_offset (float): The offset added to the radius for label positioning.
|
1712
|
-
num_domains (int): The number of positions (or domains) to calculate.
|
1713
|
-
|
1714
|
-
Returns:
|
1715
|
-
list[np.ndarray]: List of positions (as 2D numpy arrays) around the center.
|
1716
|
-
"""
|
1717
|
-
# Calculate equidistant angles in radians around the center
|
1718
|
-
angles = np.linspace(0, 2 * np.pi, num_domains, endpoint=False)
|
1719
|
-
# Compute the positions around the center using the angles
|
1720
|
-
return [
|
1721
|
-
center + (radius + label_offset) * np.array([np.cos(angle), np.sin(angle)])
|
1722
|
-
for angle in angles
|
1723
|
-
]
|
1724
|
-
|
1725
|
-
|
1726
|
-
def _optimize_label_positions(
    best_label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
) -> Dict[str, Any]:
    """Repeatedly swap pairs of label positions while a swap shortens the total distance.

    The input dictionary is modified in place and also returned.

    Args:
        best_label_positions (dict): Initial positions of labels around the perimeter.
        domain_centroids (dict): Centroid positions of the domains.

    Returns:
        dict: The label positions after no pairwise swap yields a further improvement.
    """
    # Swapping values never changes the key order, so the key list can be taken once
    labels = list(best_label_positions.keys())
    num_labels = len(domain_centroids)

    found_improvement = True
    while found_improvement:
        found_improvement = False  # Reset for this sweep over all pairs
        for first in range(num_labels):
            for second in range(first + 1, num_labels):
                # Total distance as things stand versus with this pair exchanged
                distance_before = _calculate_total_distance(
                    best_label_positions, domain_centroids
                )
                distance_after = _swap_and_evaluate(
                    best_label_positions, first, second, domain_centroids
                )
                if distance_after < distance_before:
                    # Commit the exchange and schedule another sweep
                    first_label, second_label = labels[first], labels[second]
                    (
                        best_label_positions[first_label],
                        best_label_positions[second_label],
                    ) = (
                        best_label_positions[second_label],
                        best_label_positions[first_label],
                    )
                    found_improvement = True

    return best_label_positions
|
1760
|
-
|
1761
|
-
|
1762
|
-
def _calculate_total_distance(
|
1763
|
-
label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
|
1764
|
-
) -> float:
|
1765
|
-
"""Calculate the total distance from label positions to their domain centroids.
|
1766
|
-
|
1767
|
-
Args:
|
1768
|
-
label_positions (dict): Positions of labels around the perimeter.
|
1769
|
-
domain_centroids (dict): Centroid positions of the domains.
|
1770
|
-
|
1771
|
-
Returns:
|
1772
|
-
float: The total distance from labels to centroids.
|
1773
|
-
"""
|
1774
|
-
total_distance = 0
|
1775
|
-
# Iterate through each domain and calculate the distance to its centroid
|
1776
|
-
for domain, pos in label_positions.items():
|
1777
|
-
centroid = domain_centroids[domain]
|
1778
|
-
total_distance += np.linalg.norm(centroid - pos)
|
1779
|
-
|
1780
|
-
return total_distance
|
1781
|
-
|
1782
|
-
|
1783
|
-
def _swap_and_evaluate(
    label_positions: Dict[str, Any],
    i: int,
    j: int,
    domain_centroids: Dict[str, Any],
) -> float:
    """Compute the total distance that would result from exchanging two label positions.

    The exchange is made on a shallow copy, so `label_positions` itself is left untouched.

    Args:
        label_positions (dict): Positions of labels around the perimeter.
        i (int): Index (in key order) of the first label to swap.
        j (int): Index (in key order) of the second label to swap.
        domain_centroids (dict): Centroid positions of the domains.

    Returns:
        float: The total distance after swapping the two labels.
    """
    # Resolve the two indices to their label keys
    ordered_labels = list(label_positions.keys())
    first_label, second_label = ordered_labels[i], ordered_labels[j]
    # Exchange the two positions on a copy
    trial_positions = label_positions.copy()
    first_position = trial_positions[first_label]
    second_position = trial_positions[second_label]
    trial_positions[first_label] = second_position
    trial_positions[second_label] = first_position
    # Score the trial assignment
    return _calculate_total_distance(trial_positions, domain_centroids)
|