risk-network 0.0.3b4__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -4
- risk/annotations/annotations.py +4 -2
- risk/annotations/io.py +1 -1
- risk/neighborhoods/neighborhoods.py +15 -2
- risk/network/geometry.py +2 -2
- risk/network/graph.py +4 -4
- risk/network/io.py +234 -53
- risk/network/plot.py +179 -58
- risk/risk.py +187 -75
- risk/stats/__init__.py +4 -1
- risk/stats/fisher_exact.py +132 -0
- risk/stats/hypergeom.py +131 -0
- risk/stats/permutation/__init__.py +6 -0
- risk/stats/permutation/permutation.py +212 -0
- risk/stats/{permutation.py → permutation/test_functions.py} +12 -39
- risk/stats/stats.py +1 -212
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/METADATA +6 -6
- risk_network-0.0.4.dist-info/RECORD +30 -0
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/WHEEL +1 -1
- risk_network-0.0.3b4.dist-info/RECORD +0 -26
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/LICENSE +0 -0
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/top_level.txt +0 -0
risk/network/plot.py
CHANGED
@@ -27,7 +27,7 @@ class NetworkPlotter:
|
|
27
27
|
|
28
28
|
def __init__(
|
29
29
|
self,
|
30
|
-
|
30
|
+
graph: NetworkGraph,
|
31
31
|
figsize: tuple = (10, 10),
|
32
32
|
background_color: str = "white",
|
33
33
|
plot_outline: bool = True,
|
@@ -37,22 +37,22 @@ class NetworkPlotter:
|
|
37
37
|
"""Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
|
38
38
|
|
39
39
|
Args:
|
40
|
-
|
40
|
+
graph (NetworkGraph): The network data and attributes to be visualized.
|
41
41
|
figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
|
42
42
|
background_color (str, optional): Background color of the plot. Defaults to "white".
|
43
43
|
plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
|
44
44
|
outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
|
45
45
|
outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
|
46
46
|
"""
|
47
|
-
self.
|
47
|
+
self.graph = graph
|
48
48
|
# Initialize the plot with the specified parameters
|
49
49
|
self.ax = self._initialize_plot(
|
50
|
-
|
50
|
+
graph, figsize, background_color, plot_outline, outline_color, outline_scale
|
51
51
|
)
|
52
52
|
|
53
53
|
def _initialize_plot(
|
54
54
|
self,
|
55
|
-
|
55
|
+
graph: NetworkGraph,
|
56
56
|
figsize: tuple,
|
57
57
|
background_color: str,
|
58
58
|
plot_outline: bool,
|
@@ -62,7 +62,7 @@ class NetworkPlotter:
|
|
62
62
|
"""Set up the plot with figure size, optional circle perimeter, and background color.
|
63
63
|
|
64
64
|
Args:
|
65
|
-
|
65
|
+
graph (NetworkGraph): The network data and attributes to be visualized.
|
66
66
|
figsize (tuple): Size of the figure in inches (width, height).
|
67
67
|
background_color (str): Background color of the plot.
|
68
68
|
plot_outline (bool): Whether to plot the network perimeter circle.
|
@@ -73,7 +73,7 @@ class NetworkPlotter:
|
|
73
73
|
plt.Axes: The axis object for the plot.
|
74
74
|
"""
|
75
75
|
# Extract node coordinates from the network graph
|
76
|
-
node_coordinates =
|
76
|
+
node_coordinates = graph.node_coordinates
|
77
77
|
# Calculate the center and radius of the bounding box around the network
|
78
78
|
center, radius = _calculate_bounding_box(node_coordinates)
|
79
79
|
# Scale the radius by the outline_scale factor
|
@@ -141,10 +141,10 @@ class NetworkPlotter:
|
|
141
141
|
network_node_shape=node_shape,
|
142
142
|
)
|
143
143
|
# Extract node coordinates from the network graph
|
144
|
-
node_coordinates = self.
|
144
|
+
node_coordinates = self.graph.node_coordinates
|
145
145
|
# Draw the nodes of the graph
|
146
146
|
nx.draw_networkx_nodes(
|
147
|
-
self.
|
147
|
+
self.graph.network,
|
148
148
|
pos=node_coordinates,
|
149
149
|
node_size=node_size,
|
150
150
|
node_color=node_color,
|
@@ -155,7 +155,7 @@ class NetworkPlotter:
|
|
155
155
|
)
|
156
156
|
# Draw the edges of the graph
|
157
157
|
nx.draw_networkx_edges(
|
158
|
-
self.
|
158
|
+
self.graph.network,
|
159
159
|
pos=node_coordinates,
|
160
160
|
width=edge_width,
|
161
161
|
edge_color=edge_color,
|
@@ -197,20 +197,18 @@ class NetworkPlotter:
|
|
197
197
|
)
|
198
198
|
# Filter to get node IDs and their coordinates
|
199
199
|
node_ids = [
|
200
|
-
self.
|
200
|
+
self.graph.node_label_to_id_map.get(node)
|
201
201
|
for node in nodes
|
202
|
-
if node in self.
|
202
|
+
if node in self.graph.node_label_to_id_map
|
203
203
|
]
|
204
204
|
if not node_ids:
|
205
205
|
raise ValueError("No nodes found in the network graph.")
|
206
206
|
|
207
207
|
# Get the coordinates of the filtered nodes
|
208
|
-
node_coordinates = {
|
209
|
-
node_id: self.network_graph.node_coordinates[node_id] for node_id in node_ids
|
210
|
-
}
|
208
|
+
node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
|
211
209
|
# Draw the nodes in the subnetwork
|
212
210
|
nx.draw_networkx_nodes(
|
213
|
-
self.
|
211
|
+
self.graph.network,
|
214
212
|
pos=node_coordinates,
|
215
213
|
nodelist=node_ids,
|
216
214
|
node_size=node_size,
|
@@ -221,7 +219,7 @@ class NetworkPlotter:
|
|
221
219
|
ax=self.ax,
|
222
220
|
)
|
223
221
|
# Draw the edges between the specified nodes in the subnetwork
|
224
|
-
subgraph = self.
|
222
|
+
subgraph = self.graph.network.subgraph(node_ids)
|
225
223
|
nx.draw_networkx_edges(
|
226
224
|
subgraph,
|
227
225
|
pos=node_coordinates,
|
@@ -260,9 +258,9 @@ class NetworkPlotter:
|
|
260
258
|
color = self.get_annotated_contour_colors(color=color)
|
261
259
|
|
262
260
|
# Extract node coordinates from the network graph
|
263
|
-
node_coordinates = self.
|
261
|
+
node_coordinates = self.graph.node_coordinates
|
264
262
|
# Draw contours for each domain in the network
|
265
|
-
for idx, (_, nodes) in enumerate(self.
|
263
|
+
for idx, (_, nodes) in enumerate(self.graph.domain_to_nodes.items()):
|
266
264
|
if len(nodes) > 1:
|
267
265
|
self._draw_kde_contour(
|
268
266
|
self.ax,
|
@@ -299,23 +297,23 @@ class NetworkPlotter:
|
|
299
297
|
"""
|
300
298
|
# Log the plotting parameters
|
301
299
|
params.log_plotter(
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
300
|
+
subcontour_levels=levels,
|
301
|
+
subcontour_bandwidth=bandwidth,
|
302
|
+
subcontour_grid_size=grid_size,
|
303
|
+
subcontour_alpha=alpha,
|
304
|
+
subcontour_color="custom" if isinstance(color, np.ndarray) else color,
|
307
305
|
)
|
308
306
|
# Filter to get node IDs and their coordinates
|
309
307
|
node_ids = [
|
310
|
-
self.
|
308
|
+
self.graph.node_label_to_id_map.get(node)
|
311
309
|
for node in nodes
|
312
|
-
if node in self.
|
310
|
+
if node in self.graph.node_label_to_id_map
|
313
311
|
]
|
314
312
|
if not node_ids or len(node_ids) == 1:
|
315
313
|
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
316
314
|
|
317
315
|
# Draw the KDE contour for the specified nodes
|
318
|
-
node_coordinates = self.
|
316
|
+
node_coordinates = self.graph.node_coordinates
|
319
317
|
self._draw_kde_contour(
|
320
318
|
self.ax,
|
321
319
|
node_coordinates,
|
@@ -402,8 +400,10 @@ class NetworkPlotter:
|
|
402
400
|
fontcolor: Union[str, np.ndarray] = "black",
|
403
401
|
arrow_linewidth: float = 1,
|
404
402
|
arrow_color: Union[str, np.ndarray] = "black",
|
405
|
-
|
403
|
+
max_labels: Union[int, None] = None,
|
404
|
+
max_words: int = 10,
|
406
405
|
min_words: int = 1,
|
406
|
+
words_to_omit: Union[List[str], None] = None,
|
407
407
|
) -> None:
|
408
408
|
"""Annotate the network graph with labels for different domains, positioned around the network for clarity.
|
409
409
|
|
@@ -415,8 +415,10 @@ class NetworkPlotter:
|
|
415
415
|
fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
|
416
416
|
arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
|
417
417
|
arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
|
418
|
-
|
418
|
+
max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
|
419
|
+
max_words (int, optional): Maximum number of words in a label. Defaults to 10.
|
419
420
|
min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
|
421
|
+
words_to_omit (List[str], optional): List of words to omit from the labels. Defaults to None.
|
420
422
|
"""
|
421
423
|
# Log the plotting parameters
|
422
424
|
params.log_plotter(
|
@@ -427,39 +429,82 @@ class NetworkPlotter:
|
|
427
429
|
label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
|
428
430
|
label_arrow_linewidth=arrow_linewidth,
|
429
431
|
label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
430
|
-
|
432
|
+
label_max_labels=max_labels,
|
433
|
+
label_max_words=max_words,
|
431
434
|
label_min_words=min_words,
|
435
|
+
label_words_to_omit=words_to_omit, # Log words_to_omit parameter
|
432
436
|
)
|
437
|
+
|
433
438
|
# Convert color strings to RGBA arrays if necessary
|
434
439
|
if isinstance(fontcolor, str):
|
435
|
-
fontcolor = self.
|
440
|
+
fontcolor = self.get_annotated_label_colors(color=fontcolor)
|
436
441
|
if isinstance(arrow_color, str):
|
437
|
-
arrow_color = self.
|
442
|
+
arrow_color = self.get_annotated_label_colors(color=arrow_color)
|
443
|
+
# Normalize words_to_omit to lowercase
|
444
|
+
if words_to_omit:
|
445
|
+
words_to_omit = set(word.lower() for word in words_to_omit)
|
438
446
|
|
439
447
|
# Calculate the center and radius of the network
|
440
448
|
domain_centroids = {}
|
441
|
-
for domain, nodes in self.
|
449
|
+
for domain, nodes in self.graph.domain_to_nodes.items():
|
442
450
|
if nodes: # Skip if the domain has no nodes
|
443
451
|
domain_centroids[domain] = self._calculate_domain_centroid(nodes)
|
444
452
|
|
453
|
+
# Initialize empty lists to collect valid indices
|
454
|
+
valid_indices = []
|
455
|
+
filtered_domain_centroids = {}
|
456
|
+
filtered_domain_terms = {}
|
457
|
+
# Loop through domain_centroids with index
|
458
|
+
for idx, (domain, centroid) in enumerate(domain_centroids.items()):
|
459
|
+
# Process the domain term
|
460
|
+
terms = self.graph.trimmed_domain_to_term[domain].split(" ")
|
461
|
+
# Remove words_to_omit
|
462
|
+
if words_to_omit:
|
463
|
+
terms = [term for term in terms if term.lower() not in words_to_omit]
|
464
|
+
# Trim to max_words
|
465
|
+
terms = terms[:max_words]
|
466
|
+
# Check if the domain passes the word count condition
|
467
|
+
if len(terms) >= min_words:
|
468
|
+
# Add to filtered_domain_centroids
|
469
|
+
filtered_domain_centroids[domain] = centroid
|
470
|
+
# Store the trimmed terms
|
471
|
+
filtered_domain_terms[domain] = " ".join(terms)
|
472
|
+
# Keep track of the valid index
|
473
|
+
valid_indices.append(idx)
|
474
|
+
|
475
|
+
# If max_labels is specified and less than the available labels
|
476
|
+
if max_labels is not None and max_labels < len(filtered_domain_centroids):
|
477
|
+
step = len(filtered_domain_centroids) / max_labels
|
478
|
+
selected_indices = [int(i * step) for i in range(max_labels)]
|
479
|
+
filtered_domain_centroids = {
|
480
|
+
k: v
|
481
|
+
for i, (k, v) in enumerate(filtered_domain_centroids.items())
|
482
|
+
if i in selected_indices
|
483
|
+
}
|
484
|
+
filtered_domain_terms = {
|
485
|
+
k: v
|
486
|
+
for i, (k, v) in enumerate(filtered_domain_terms.items())
|
487
|
+
if i in selected_indices
|
488
|
+
}
|
489
|
+
fontcolor = fontcolor[selected_indices]
|
490
|
+
arrow_color = arrow_color[selected_indices]
|
491
|
+
|
492
|
+
# Update the terms in the graph after omitting words and filtering
|
493
|
+
for domain, terms in filtered_domain_terms.items():
|
494
|
+
self.graph.trimmed_domain_to_term[domain] = terms
|
495
|
+
|
445
496
|
# Calculate the bounding box around the network
|
446
497
|
center, radius = _calculate_bounding_box(
|
447
|
-
self.
|
498
|
+
self.graph.node_coordinates, radius_margin=perimeter_scale
|
448
499
|
)
|
449
|
-
|
450
|
-
# Filter out domains with insufficient words for labeling
|
451
|
-
filtered_domains = {
|
452
|
-
domain: centroid
|
453
|
-
for domain, centroid in domain_centroids.items()
|
454
|
-
if len(self.network_graph.trimmed_domain_to_term[domain].split(" ")[:num_words])
|
455
|
-
>= min_words
|
456
|
-
}
|
457
500
|
# Calculate the best positions for labels around the perimeter
|
458
|
-
best_label_positions = _best_label_positions(
|
501
|
+
best_label_positions = _best_label_positions(
|
502
|
+
filtered_domain_centroids, center, radius, offset
|
503
|
+
)
|
459
504
|
# Annotate the network with labels
|
460
505
|
for idx, (domain, pos) in enumerate(best_label_positions.items()):
|
461
|
-
centroid =
|
462
|
-
annotations = self.
|
506
|
+
centroid = filtered_domain_centroids[domain]
|
507
|
+
annotations = self.graph.trimmed_domain_to_term[domain].split(" ")[:max_words]
|
463
508
|
self.ax.annotate(
|
464
509
|
"\n".join(annotations),
|
465
510
|
xy=centroid,
|
@@ -473,6 +518,81 @@ class NetworkPlotter:
|
|
473
518
|
arrowprops=dict(arrowstyle="->", color=arrow_color[idx], linewidth=arrow_linewidth),
|
474
519
|
)
|
475
520
|
|
521
|
+
def plot_sublabel(
|
522
|
+
self,
|
523
|
+
nodes: list,
|
524
|
+
label: str,
|
525
|
+
radial_position: float = 0.0,
|
526
|
+
perimeter_scale: float = 1.05,
|
527
|
+
offset: float = 0.10,
|
528
|
+
font: str = "Arial",
|
529
|
+
fontsize: int = 10,
|
530
|
+
fontcolor: str = "black",
|
531
|
+
arrow_linewidth: float = 1,
|
532
|
+
arrow_color: str = "black",
|
533
|
+
) -> None:
|
534
|
+
"""Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
|
535
|
+
|
536
|
+
Args:
|
537
|
+
nodes (List[str]): List of node labels to be used for calculating the centroid.
|
538
|
+
label (str): The label to be annotated on the network.
|
539
|
+
radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
|
540
|
+
perimeter_scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
|
541
|
+
offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
|
542
|
+
font (str, optional): Font name for the label. Defaults to "Arial".
|
543
|
+
fontsize (int, optional): Font size for the label. Defaults to 10.
|
544
|
+
fontcolor (str, optional): Color of the label text. Defaults to "black".
|
545
|
+
arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
|
546
|
+
arrow_color (str, optional): Color of the arrow. Defaults to "black".
|
547
|
+
"""
|
548
|
+
# Log the plotting parameters
|
549
|
+
params.log_plotter(
|
550
|
+
sublabel_perimeter_scale=perimeter_scale,
|
551
|
+
sublabel_offset=offset,
|
552
|
+
sublabel_font=font,
|
553
|
+
sublabel_fontsize=fontsize,
|
554
|
+
sublabel_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
|
555
|
+
sublabel_arrow_linewidth=arrow_linewidth,
|
556
|
+
sublabel_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
557
|
+
sublabel_radial_position=radial_position,
|
558
|
+
)
|
559
|
+
|
560
|
+
# Map node labels to IDs
|
561
|
+
node_ids = [
|
562
|
+
self.graph.node_label_to_id_map.get(node)
|
563
|
+
for node in nodes
|
564
|
+
if node in self.graph.node_label_to_id_map
|
565
|
+
]
|
566
|
+
if not node_ids or len(node_ids) == 1:
|
567
|
+
raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
|
568
|
+
|
569
|
+
# Calculate the centroid of the provided nodes
|
570
|
+
centroid = self._calculate_domain_centroid(node_ids)
|
571
|
+
# Calculate the bounding box around the network
|
572
|
+
center, radius = _calculate_bounding_box(
|
573
|
+
self.graph.node_coordinates, radius_margin=perimeter_scale
|
574
|
+
)
|
575
|
+
# Convert radial position to radians, adjusting for a 90-degree rotation
|
576
|
+
radial_radians = np.deg2rad(radial_position - 90)
|
577
|
+
label_position = (
|
578
|
+
center[0] + (radius + offset) * np.cos(radial_radians),
|
579
|
+
center[1] + (radius + offset) * np.sin(radial_radians),
|
580
|
+
)
|
581
|
+
|
582
|
+
# Annotate the network with the label
|
583
|
+
self.ax.annotate(
|
584
|
+
label,
|
585
|
+
xy=centroid,
|
586
|
+
xytext=label_position,
|
587
|
+
textcoords="data",
|
588
|
+
ha="center",
|
589
|
+
va="center",
|
590
|
+
fontsize=fontsize,
|
591
|
+
fontname=font,
|
592
|
+
color=fontcolor,
|
593
|
+
arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
|
594
|
+
)
|
595
|
+
|
476
596
|
def _calculate_domain_centroid(self, nodes: list) -> tuple:
|
477
597
|
"""Calculate the most centrally located node in .
|
478
598
|
|
@@ -483,7 +603,7 @@ class NetworkPlotter:
|
|
483
603
|
tuple: A tuple containing the domain's central node coordinates.
|
484
604
|
"""
|
485
605
|
# Extract positions of all nodes in the domain
|
486
|
-
node_positions = self.
|
606
|
+
node_positions = self.graph.node_coordinates[nodes, :]
|
487
607
|
# Calculate the pairwise distance matrix between all nodes in the domain
|
488
608
|
distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
|
489
609
|
# Sum the distances for each node to all other nodes in the domain
|
@@ -508,7 +628,7 @@ class NetworkPlotter:
|
|
508
628
|
np.ndarray: Array of RGBA colors adjusted for enrichment status.
|
509
629
|
"""
|
510
630
|
# Get the initial domain colors for each node
|
511
|
-
network_colors = self.
|
631
|
+
network_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
|
512
632
|
if isinstance(nonenriched_color, str):
|
513
633
|
# Convert the non-enriched color from string to RGBA
|
514
634
|
nonenriched_color = mcolors.to_rgba(nonenriched_color)
|
@@ -535,14 +655,14 @@ class NetworkPlotter:
|
|
535
655
|
"""
|
536
656
|
# Merge all enriched nodes from the domain_to_nodes dictionary
|
537
657
|
enriched_nodes = set()
|
538
|
-
for _, nodes in self.
|
658
|
+
for _, nodes in self.graph.domain_to_nodes.items():
|
539
659
|
enriched_nodes.update(nodes)
|
540
660
|
|
541
661
|
# Initialize all node sizes to the non-enriched size
|
542
|
-
node_sizes = np.full(len(self.
|
662
|
+
node_sizes = np.full(len(self.graph.network.nodes), nonenriched_nodesize)
|
543
663
|
# Set the size for enriched nodes
|
544
664
|
for node in enriched_nodes:
|
545
|
-
if node in self.
|
665
|
+
if node in self.graph.network.nodes:
|
546
666
|
node_sizes[node] = enriched_nodesize
|
547
667
|
|
548
668
|
return node_sizes
|
@@ -587,12 +707,12 @@ class NetworkPlotter:
|
|
587
707
|
if isinstance(color, str):
|
588
708
|
# If a single color string is provided, convert it to RGBA and apply to all domains
|
589
709
|
rgba_color = np.array(matplotlib.colors.to_rgba(color))
|
590
|
-
return np.array([rgba_color for _ in self.
|
710
|
+
return np.array([rgba_color for _ in self.graph.domain_to_nodes])
|
591
711
|
|
592
712
|
# Generate colors for each domain using the provided arguments and random seed
|
593
|
-
node_colors = self.
|
713
|
+
node_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
|
594
714
|
annotated_colors = []
|
595
|
-
for _, nodes in self.
|
715
|
+
for _, nodes in self.graph.domain_to_nodes.items():
|
596
716
|
if len(nodes) > 1:
|
597
717
|
# For domains with multiple nodes, choose the brightest color (sum of RGB values)
|
598
718
|
domain_colors = np.array([node_colors[node] for node in nodes])
|
@@ -662,12 +782,12 @@ def _calculate_bounding_box(
|
|
662
782
|
|
663
783
|
|
664
784
|
def _best_label_positions(
|
665
|
-
|
785
|
+
filtered_domain_centroids: Dict[str, Any], center: np.ndarray, radius: float, offset: float
|
666
786
|
) -> Dict[str, Any]:
|
667
787
|
"""Calculate and optimize label positions for clarity.
|
668
788
|
|
669
789
|
Args:
|
670
|
-
|
790
|
+
filtered_domain_centroids (dict): Centroids of the filtered domains.
|
671
791
|
center (np.ndarray): The center coordinates for label positioning.
|
672
792
|
radius (float): The radius for positioning labels around the center.
|
673
793
|
offset (float): The offset distance from the radius for positioning labels.
|
@@ -675,15 +795,16 @@ def _best_label_positions(
|
|
675
795
|
Returns:
|
676
796
|
dict: Optimized positions for labels.
|
677
797
|
"""
|
678
|
-
num_domains = len(
|
798
|
+
num_domains = len(filtered_domain_centroids)
|
679
799
|
# Calculate equidistant positions around the center for initial label placement
|
680
800
|
equidistant_positions = _equidistant_angles_around_center(center, radius, offset, num_domains)
|
681
801
|
# Create a mapping of domains to their initial label positions
|
682
802
|
label_positions = {
|
683
|
-
domain: position
|
803
|
+
domain: position
|
804
|
+
for domain, position in zip(filtered_domain_centroids.keys(), equidistant_positions)
|
684
805
|
}
|
685
806
|
# Optimize the label positions to minimize distance to domain centroids
|
686
|
-
return _optimize_label_positions(label_positions,
|
807
|
+
return _optimize_label_positions(label_positions, filtered_domain_centroids)
|
687
808
|
|
688
809
|
|
689
810
|
def _equidistant_angles_around_center(
|