risk-network 0.0.8b3__tar.gz → 0.0.8b5__tar.gz
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/PKG-INFO +1 -1
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/__init__.py +1 -1
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/plot.py +204 -102
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/PKG-INFO +1 -1
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/LICENSE +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/MANIFEST.in +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/README.md +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/pyproject.toml +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/annotations/__init__.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/annotations/annotations.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/annotations/io.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/constants.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/log/__init__.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/log/config.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/log/params.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/__init__.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/community.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/domains.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/neighborhoods/neighborhoods.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/__init__.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/geometry.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/graph.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/network/io.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/risk.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/__init__.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/hypergeom.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/permutation/__init__.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/permutation/permutation.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/permutation/test_functions.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/poisson.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk/stats/stats.py +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/SOURCES.txt +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/dependency_links.txt +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/requires.txt +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/risk_network.egg-info/top_level.txt +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/setup.cfg +0 -0
- {risk_network-0.0.8b3 → risk_network-0.0.8b5}/setup.py +0 -0
@@ -33,6 +33,7 @@ class NetworkPlotter:
|
|
33
33
|
graph: NetworkGraph,
|
34
34
|
figsize: Tuple = (10, 10),
|
35
35
|
background_color: Union[str, List, Tuple, np.ndarray] = "white",
|
36
|
+
background_alpha: float = 1.0,
|
36
37
|
) -> None:
|
37
38
|
"""Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
|
38
39
|
|
@@ -40,16 +41,23 @@ class NetworkPlotter:
|
|
40
41
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
41
42
|
figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
|
42
43
|
background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
|
44
|
+
background_alpha (float, optional): Transparency level of the background color. Defaults to 1.0.
|
43
45
|
"""
|
44
46
|
self.graph = graph
|
45
47
|
# Initialize the plot with the specified parameters
|
46
|
-
self.ax = self._initialize_plot(
|
48
|
+
self.ax = self._initialize_plot(
|
49
|
+
graph=graph,
|
50
|
+
figsize=figsize,
|
51
|
+
background_color=background_color,
|
52
|
+
background_alpha=background_alpha,
|
53
|
+
)
|
47
54
|
|
48
55
|
def _initialize_plot(
|
49
56
|
self,
|
50
57
|
graph: NetworkGraph,
|
51
58
|
figsize: Tuple,
|
52
59
|
background_color: Union[str, List, Tuple, np.ndarray],
|
60
|
+
background_alpha: float = 1.0,
|
53
61
|
) -> plt.Axes:
|
54
62
|
"""Set up the plot with figure size and background color.
|
55
63
|
|
@@ -57,6 +65,7 @@ class NetworkPlotter:
|
|
57
65
|
graph (NetworkGraph): The network data and attributes to be visualized.
|
58
66
|
figsize (tuple): Size of the figure in inches (width, height).
|
59
67
|
background_color (str): Background color of the plot.
|
68
|
+
background_alpha (float, optional): Transparency level of the background color. Defaults to 1.0.
|
60
69
|
|
61
70
|
Returns:
|
62
71
|
plt.Axes: The axis object for the plot.
|
@@ -76,7 +85,7 @@ class NetworkPlotter:
|
|
76
85
|
|
77
86
|
# Set the background color of the plot
|
78
87
|
# Convert color to RGBA using the _to_rgba helper function
|
79
|
-
fig.patch.set_facecolor(_to_rgba(background_color,
|
88
|
+
fig.patch.set_facecolor(_to_rgba(color=background_color, alpha=background_alpha))
|
80
89
|
ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
|
81
90
|
# Remove axis spines for a cleaner look
|
82
91
|
for spine in ax.spines.values():
|
@@ -199,7 +208,7 @@ class NetworkPlotter:
|
|
199
208
|
)
|
200
209
|
|
201
210
|
# Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
|
202
|
-
color = _to_rgba(color, outline_alpha)
|
211
|
+
color = _to_rgba(color=color, alpha=outline_alpha)
|
203
212
|
# Extract node coordinates from the network graph
|
204
213
|
node_coordinates = self.graph.node_coordinates
|
205
214
|
# Calculate the center and radius of the bounding box around the network
|
@@ -218,7 +227,7 @@ class NetworkPlotter:
|
|
218
227
|
)
|
219
228
|
# Set the transparency of the fill if applicable
|
220
229
|
if fill_alpha > 0:
|
221
|
-
circle.set_facecolor(_to_rgba(color, fill_alpha))
|
230
|
+
circle.set_facecolor(_to_rgba(color=color, alpha=fill_alpha))
|
222
231
|
|
223
232
|
self.ax.add_artist(circle)
|
224
233
|
|
@@ -263,7 +272,7 @@ class NetworkPlotter:
|
|
263
272
|
)
|
264
273
|
|
265
274
|
# Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
|
266
|
-
color = _to_rgba(color, outline_alpha)
|
275
|
+
color = _to_rgba(color=color, alpha=outline_alpha)
|
267
276
|
# Extract node coordinates from the network graph
|
268
277
|
node_coordinates = self.graph.node_coordinates
|
269
278
|
# Scale the node coordinates if needed
|
@@ -326,9 +335,15 @@ class NetworkPlotter:
|
|
326
335
|
|
327
336
|
# Convert colors to RGBA using the _to_rgba helper function
|
328
337
|
# If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
|
329
|
-
node_color = _to_rgba(
|
330
|
-
|
331
|
-
|
338
|
+
node_color = _to_rgba(
|
339
|
+
color=node_color, alpha=node_alpha, num_repeats=len(self.graph.network.nodes)
|
340
|
+
)
|
341
|
+
node_edgecolor = _to_rgba(
|
342
|
+
color=node_edgecolor, alpha=1.0, num_repeats=len(self.graph.network.nodes)
|
343
|
+
)
|
344
|
+
edge_color = _to_rgba(
|
345
|
+
color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
|
346
|
+
)
|
332
347
|
|
333
348
|
# Extract node coordinates from the network graph
|
334
349
|
node_coordinates = self.graph.node_coordinates
|
@@ -405,9 +420,11 @@ class NetworkPlotter:
|
|
405
420
|
]
|
406
421
|
|
407
422
|
# Convert colors to RGBA using the _to_rgba helper function
|
408
|
-
node_color = _to_rgba(node_color, node_alpha, num_repeats=len(node_ids))
|
409
|
-
node_edgecolor = _to_rgba(node_edgecolor, 1.0, num_repeats=len(node_ids))
|
410
|
-
edge_color = _to_rgba(
|
423
|
+
node_color = _to_rgba(color=node_color, alpha=node_alpha, num_repeats=len(node_ids))
|
424
|
+
node_edgecolor = _to_rgba(color=node_edgecolor, alpha=1.0, num_repeats=len(node_ids))
|
425
|
+
edge_color = _to_rgba(
|
426
|
+
color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
|
427
|
+
)
|
411
428
|
|
412
429
|
# Get the coordinates of the filtered nodes
|
413
430
|
node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
|
@@ -470,7 +487,9 @@ class NetworkPlotter:
|
|
470
487
|
)
|
471
488
|
|
472
489
|
# Ensure color is converted to RGBA with repetition matching the number of domains
|
473
|
-
color = _to_rgba(
|
490
|
+
color = _to_rgba(
|
491
|
+
color=color, alpha=alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
492
|
+
)
|
474
493
|
# Extract node coordinates from the network graph
|
475
494
|
node_coordinates = self.graph.node_coordinates
|
476
495
|
# Draw contours for each domain in the network
|
@@ -527,7 +546,7 @@ class NetworkPlotter:
|
|
527
546
|
node_groups = [nodes]
|
528
547
|
|
529
548
|
# Convert color to RGBA using the _to_rgba helper function
|
530
|
-
color_rgba = _to_rgba(color, alpha)
|
549
|
+
color_rgba = _to_rgba(color=color, alpha=alpha)
|
531
550
|
|
532
551
|
# Iterate over each group of nodes (either sublists or flat list)
|
533
552
|
for sublist in node_groups:
|
@@ -705,10 +724,10 @@ class NetworkPlotter:
|
|
705
724
|
arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
|
706
725
|
arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
|
707
726
|
max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
|
708
|
-
max_label_lines (int, optional): Maximum number of lines in a label. Defaults to None (no limit).
|
709
727
|
min_label_lines (int, optional): Minimum number of lines in a label. Defaults to 1.
|
710
|
-
|
728
|
+
max_label_lines (int, optional): Maximum number of lines in a label. Defaults to None (no limit).
|
711
729
|
min_chars_per_line (int, optional): Minimum number of characters in a line to display. Defaults to 1.
|
730
|
+
max_chars_per_line (int, optional): Maximum number of characters in a line to display. Defaults to None (no limit).
|
712
731
|
words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
|
713
732
|
overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
|
714
733
|
ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
|
@@ -761,11 +780,11 @@ class NetworkPlotter:
|
|
761
780
|
if words_to_omit:
|
762
781
|
words_to_omit = set(word.lower() for word in words_to_omit)
|
763
782
|
|
764
|
-
# Calculate the center and radius of the network
|
765
|
-
|
783
|
+
# Calculate the center and radius of domains to position labels around the network
|
784
|
+
domain_id_to_centroid_map = {}
|
766
785
|
for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
767
786
|
if node_ids: # Skip if the domain has no nodes
|
768
|
-
|
787
|
+
domain_id_to_centroid_map[domain_id] = self._calculate_domain_centroid(node_ids)
|
769
788
|
|
770
789
|
# Initialize dictionaries and lists for valid indices
|
771
790
|
valid_indices = [] # List of valid indices to plot colors and arrows
|
@@ -775,8 +794,8 @@ class NetworkPlotter:
|
|
775
794
|
if ids_to_keep:
|
776
795
|
# Process the ids_to_keep first INPLACE
|
777
796
|
self._process_ids_to_keep(
|
797
|
+
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
778
798
|
ids_to_keep=ids_to_keep,
|
779
|
-
domain_centroids=domain_centroids,
|
780
799
|
ids_to_replace=ids_to_replace,
|
781
800
|
words_to_omit=words_to_omit,
|
782
801
|
max_labels=max_labels,
|
@@ -796,7 +815,7 @@ class NetworkPlotter:
|
|
796
815
|
# Process remaining domains INPLACE to fill in additional labels, if there are slots left
|
797
816
|
if remaining_labels and remaining_labels > 0:
|
798
817
|
self._process_remaining_domains(
|
799
|
-
|
818
|
+
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
800
819
|
ids_to_keep=ids_to_keep,
|
801
820
|
ids_to_replace=ids_to_replace,
|
802
821
|
words_to_omit=words_to_omit,
|
@@ -816,12 +835,14 @@ class NetworkPlotter:
|
|
816
835
|
best_label_positions = _calculate_best_label_positions(
|
817
836
|
filtered_domain_centroids, center, radius, offset
|
818
837
|
)
|
819
|
-
# Convert colors to RGBA using the _to_rgba helper function
|
838
|
+
# Convert all domain colors to RGBA using the _to_rgba helper function
|
820
839
|
fontcolor = _to_rgba(
|
821
|
-
fontcolor, fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
840
|
+
color=fontcolor, alpha=fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
822
841
|
)
|
823
842
|
arrow_color = _to_rgba(
|
824
|
-
arrow_color,
|
843
|
+
color=arrow_color,
|
844
|
+
alpha=arrow_alpha,
|
845
|
+
num_repeats=len(self.graph.domain_id_to_node_ids_map),
|
825
846
|
)
|
826
847
|
|
827
848
|
# Annotate the network with labels
|
@@ -847,8 +868,11 @@ class NetworkPlotter:
|
|
847
868
|
shrinkB=arrow_tip_shrink,
|
848
869
|
),
|
849
870
|
)
|
850
|
-
|
851
|
-
|
871
|
+
|
872
|
+
# Overlay domain ID at the centroid regardless of max_labels if requested
|
873
|
+
if overlay_ids:
|
874
|
+
for idx, domain in enumerate(self.graph.domain_id_to_node_ids_map):
|
875
|
+
centroid = domain_id_to_centroid_map[domain]
|
852
876
|
self.ax.text(
|
853
877
|
centroid[0],
|
854
878
|
centroid[1],
|
@@ -907,8 +931,8 @@ class NetworkPlotter:
|
|
907
931
|
node_groups = [nodes]
|
908
932
|
|
909
933
|
# Convert fontcolor and arrow_color to RGBA
|
910
|
-
fontcolor_rgba = _to_rgba(fontcolor, fontalpha)
|
911
|
-
arrow_color_rgba = _to_rgba(arrow_color, arrow_alpha)
|
934
|
+
fontcolor_rgba = _to_rgba(color=fontcolor, alpha=fontalpha)
|
935
|
+
arrow_color_rgba = _to_rgba(color=arrow_color, alpha=arrow_alpha)
|
912
936
|
|
913
937
|
# Calculate the bounding box around the network
|
914
938
|
center, radius = _calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
|
@@ -977,8 +1001,8 @@ class NetworkPlotter:
|
|
977
1001
|
|
978
1002
|
def _process_ids_to_keep(
|
979
1003
|
self,
|
1004
|
+
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
980
1005
|
ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
|
981
|
-
domain_centroids: Dict[str, np.ndarray],
|
982
1006
|
ids_to_replace: Union[Dict[str, str], None],
|
983
1007
|
words_to_omit: Union[List[str], None],
|
984
1008
|
max_labels: Union[int, None],
|
@@ -993,8 +1017,8 @@ class NetworkPlotter:
|
|
993
1017
|
"""Process the ids_to_keep, apply filtering, and store valid domain centroids and terms.
|
994
1018
|
|
995
1019
|
Args:
|
1020
|
+
domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
|
996
1021
|
ids_to_keep (list, tuple, or np.ndarray, optional): IDs of domains that must be labeled.
|
997
|
-
domain_centroids (dict): Mapping of domains to their centroids.
|
998
1022
|
ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
999
1023
|
words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
|
1000
1024
|
max_labels (int, optional): Maximum number of labels allowed.
|
@@ -1020,25 +1044,30 @@ class NetworkPlotter:
|
|
1020
1044
|
|
1021
1045
|
# Process each domain in ids_to_keep
|
1022
1046
|
for domain in ids_to_keep:
|
1023
|
-
if
|
1024
|
-
|
1047
|
+
if (
|
1048
|
+
domain in self.graph.domain_id_to_domain_terms_map
|
1049
|
+
and domain in domain_id_to_centroid_map
|
1050
|
+
):
|
1051
|
+
domain_centroid = domain_id_to_centroid_map[domain]
|
1052
|
+
# No need to filter the domain terms if it is in ids_to_keep
|
1053
|
+
_ = self._validate_and_update_domain(
|
1025
1054
|
domain=domain,
|
1055
|
+
domain_centroid=domain_centroid,
|
1056
|
+
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
1026
1057
|
ids_to_replace=ids_to_replace,
|
1027
1058
|
words_to_omit=words_to_omit,
|
1059
|
+
min_label_lines=min_label_lines,
|
1028
1060
|
max_label_lines=max_label_lines,
|
1029
1061
|
min_chars_per_line=min_chars_per_line,
|
1030
1062
|
max_chars_per_line=max_chars_per_line,
|
1063
|
+
filtered_domain_centroids=filtered_domain_centroids,
|
1064
|
+
filtered_domain_terms=filtered_domain_terms,
|
1065
|
+
valid_indices=valid_indices,
|
1031
1066
|
)
|
1032
|
-
num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
|
1033
|
-
# Check if the number of lines in the label is greater than or equal to the minimum
|
1034
|
-
if num_domain_lines >= min_label_lines:
|
1035
|
-
filtered_domain_terms[domain] = domain_terms
|
1036
|
-
filtered_domain_centroids[domain] = domain_centroids[domain]
|
1037
|
-
valid_indices.append(list(domain_centroids.keys()).index(domain))
|
1038
1067
|
|
1039
1068
|
def _process_remaining_domains(
|
1040
1069
|
self,
|
1041
|
-
|
1070
|
+
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
1042
1071
|
ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
|
1043
1072
|
ids_to_replace: Union[Dict[str, str], None],
|
1044
1073
|
words_to_omit: Union[List[str], None],
|
@@ -1054,7 +1083,7 @@ class NetworkPlotter:
|
|
1054
1083
|
"""Process remaining domains to fill in additional labels, respecting the remaining_labels limit.
|
1055
1084
|
|
1056
1085
|
Args:
|
1057
|
-
|
1086
|
+
domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
|
1058
1087
|
ids_to_keep (list, tuple, or np.ndarray, optional): IDs of domains that must be labeled.
|
1059
1088
|
ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
1060
1089
|
words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
|
@@ -1066,13 +1095,16 @@ class NetworkPlotter:
|
|
1066
1095
|
filtered_domain_centroids (dict): Dictionary to store filtered domain centroids (output).
|
1067
1096
|
filtered_domain_terms (dict): Dictionary to store filtered domain terms (output).
|
1068
1097
|
valid_indices (list): List to store valid indices (output).
|
1098
|
+
|
1099
|
+
Note:
|
1100
|
+
The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
|
1069
1101
|
"""
|
1070
1102
|
# Counter to track how many labels have been created
|
1071
1103
|
label_count = 0
|
1072
1104
|
# Collect domains not in ids_to_keep
|
1073
1105
|
remaining_domains = {
|
1074
1106
|
domain: centroid
|
1075
|
-
for domain, centroid in
|
1107
|
+
for domain, centroid in domain_id_to_centroid_map.items()
|
1076
1108
|
if domain not in ids_to_keep and not pd.isna(domain)
|
1077
1109
|
}
|
1078
1110
|
|
@@ -1112,26 +1144,89 @@ class NetworkPlotter:
|
|
1112
1144
|
|
1113
1145
|
# Process the selected domains and add to filtered lists
|
1114
1146
|
for domain in selected_domains:
|
1115
|
-
|
1116
|
-
|
1147
|
+
domain_centroid = remaining_domains[domain]
|
1148
|
+
is_domain_valid = self._validate_and_update_domain(
|
1117
1149
|
domain=domain,
|
1150
|
+
domain_centroid=domain_centroid,
|
1151
|
+
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
1118
1152
|
ids_to_replace=ids_to_replace,
|
1119
1153
|
words_to_omit=words_to_omit,
|
1154
|
+
min_label_lines=min_label_lines,
|
1120
1155
|
max_label_lines=max_label_lines,
|
1121
1156
|
min_chars_per_line=min_chars_per_line,
|
1122
1157
|
max_chars_per_line=max_chars_per_line,
|
1158
|
+
filtered_domain_centroids=filtered_domain_centroids,
|
1159
|
+
filtered_domain_terms=filtered_domain_terms,
|
1160
|
+
valid_indices=valid_indices,
|
1123
1161
|
)
|
1124
|
-
|
1125
|
-
|
1126
|
-
if num_domain_lines >= min_label_lines:
|
1127
|
-
filtered_domain_centroids[domain] = centroid
|
1128
|
-
filtered_domain_terms[domain] = domain_terms
|
1129
|
-
valid_indices.append(list(domain_centroids.keys()).index(domain))
|
1130
|
-
|
1162
|
+
# Increment the label count if the domain is valid
|
1163
|
+
if is_domain_valid:
|
1131
1164
|
label_count += 1
|
1132
1165
|
if label_count >= remaining_labels:
|
1133
1166
|
break
|
1134
1167
|
|
1168
|
+
def _validate_and_update_domain(
|
1169
|
+
self,
|
1170
|
+
domain: str,
|
1171
|
+
domain_centroid: np.ndarray,
|
1172
|
+
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
1173
|
+
ids_to_replace: Union[Dict[str, str], None],
|
1174
|
+
words_to_omit: Union[List[str], None],
|
1175
|
+
min_label_lines: int,
|
1176
|
+
max_label_lines: int,
|
1177
|
+
min_chars_per_line: int,
|
1178
|
+
max_chars_per_line: int,
|
1179
|
+
filtered_domain_centroids: Dict[str, np.ndarray],
|
1180
|
+
filtered_domain_terms: Dict[str, str],
|
1181
|
+
valid_indices: List[int],
|
1182
|
+
) -> bool:
|
1183
|
+
"""Validate and process the domain terms, updating relevant dictionaries if valid.
|
1184
|
+
|
1185
|
+
Args:
|
1186
|
+
domain (str): Domain ID to process.
|
1187
|
+
domain_centroid (np.ndarray): Centroid position of the domain.
|
1188
|
+
domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
|
1189
|
+
ids_to_replace (Union[Dict[str, str], None]): A dictionary mapping domain IDs to custom labels.
|
1190
|
+
words_to_omit (Union[List[str], None]): List of words to omit from the labels.
|
1191
|
+
min_label_lines (int): Minimum number of lines required in a label.
|
1192
|
+
max_label_lines (int): Maximum number of lines allowed in a label.
|
1193
|
+
min_chars_per_line (int): Minimum number of characters allowed per line.
|
1194
|
+
max_chars_per_line (int): Maximum number of characters allowed per line.
|
1195
|
+
filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store valid domain centroids.
|
1196
|
+
filtered_domain_terms (Dict[str, str]): Dictionary to store valid domain terms.
|
1197
|
+
valid_indices (List[int]): List of valid domain indices.
|
1198
|
+
|
1199
|
+
Returns:
|
1200
|
+
bool: True if the domain is valid and added to the filtered dictionaries, False otherwise.
|
1201
|
+
|
1202
|
+
Note:
|
1203
|
+
The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
|
1204
|
+
"""
|
1205
|
+
# Process the domain terms
|
1206
|
+
domain_terms = self._process_terms(
|
1207
|
+
domain=domain,
|
1208
|
+
ids_to_replace=ids_to_replace,
|
1209
|
+
words_to_omit=words_to_omit,
|
1210
|
+
max_label_lines=max_label_lines,
|
1211
|
+
min_chars_per_line=min_chars_per_line,
|
1212
|
+
max_chars_per_line=max_chars_per_line,
|
1213
|
+
)
|
1214
|
+
# If domain_terms is empty, skip further processing
|
1215
|
+
if not domain_terms:
|
1216
|
+
return False
|
1217
|
+
|
1218
|
+
# Split the terms by TERM_DELIMITER and count the number of lines
|
1219
|
+
num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
|
1220
|
+
# Check if the number of lines is greater than or equal to the minimum
|
1221
|
+
if num_domain_lines >= min_label_lines:
|
1222
|
+
filtered_domain_centroids[domain] = domain_centroid
|
1223
|
+
filtered_domain_terms[domain] = domain_terms
|
1224
|
+
# Add the index of the domain to the valid indices list
|
1225
|
+
valid_indices.append(list(domain_id_to_centroid_map.keys()).index(domain))
|
1226
|
+
return True
|
1227
|
+
|
1228
|
+
return False
|
1229
|
+
|
1135
1230
|
def _process_terms(
|
1136
1231
|
self,
|
1137
1232
|
domain: str,
|
@@ -1152,7 +1247,7 @@ class NetworkPlotter:
|
|
1152
1247
|
max_chars_per_line (int): Maximum number of characters in a line to display.
|
1153
1248
|
|
1154
1249
|
Returns:
|
1155
|
-
|
1250
|
+
str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
|
1156
1251
|
"""
|
1157
1252
|
# Handle ids_to_replace logic
|
1158
1253
|
if ids_to_replace and domain in ids_to_replace:
|
@@ -1213,11 +1308,11 @@ class NetworkPlotter:
|
|
1213
1308
|
# Apply the alpha value for enriched nodes
|
1214
1309
|
network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
|
1215
1310
|
# Convert the non-enriched color to RGBA using the _to_rgba helper function
|
1216
|
-
nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
|
1311
|
+
nonenriched_color = _to_rgba(color=nonenriched_color, alpha=nonenriched_alpha)
|
1217
1312
|
# Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
|
1218
1313
|
adjusted_network_colors = np.where(
|
1219
1314
|
np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
|
1220
|
-
np.array(
|
1315
|
+
np.array(nonenriched_color), # Apply the non-enriched color with alpha
|
1221
1316
|
network_colors, # Keep the original colors for enriched nodes
|
1222
1317
|
)
|
1223
1318
|
return adjusted_network_colors
|
@@ -1391,62 +1486,65 @@ class NetworkPlotter:
|
|
1391
1486
|
|
1392
1487
|
def _to_rgba(
|
1393
1488
|
color: Union[str, List, Tuple, np.ndarray],
|
1394
|
-
alpha: float =
|
1489
|
+
alpha: Union[float, None] = None,
|
1395
1490
|
num_repeats: Union[int, None] = None,
|
1396
1491
|
) -> np.ndarray:
|
1397
|
-
"""Convert
|
1492
|
+
"""Convert color(s) to RGBA format, applying alpha and repeating as needed.
|
1398
1493
|
|
1399
1494
|
Args:
|
1400
1495
|
color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
|
1401
|
-
alpha (float, optional): Alpha value (transparency) to apply
|
1402
|
-
num_repeats (int
|
1496
|
+
alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values.
|
1497
|
+
num_repeats (int, None, optional): If provided, the color(s) will be repeated this many times. Defaults to None.
|
1403
1498
|
|
1404
1499
|
Returns:
|
1405
|
-
np.ndarray:
|
1500
|
+
np.ndarray: Array of RGBA colors repeated `num_repeats` times, if applicable.
|
1406
1501
|
"""
|
1407
|
-
# Handle single color case (string, RGB, or RGBA)
|
1408
|
-
if isinstance(color, str) or (
|
1409
|
-
isinstance(color, (list, tuple, np.ndarray))
|
1410
|
-
and len(color) in [3, 4]
|
1411
|
-
and not any(isinstance(c, (list, tuple, np.ndarray)) for c in color)
|
1412
|
-
):
|
1413
|
-
rgba_color = np.array(mcolors.to_rgba(color))
|
1414
|
-
# Only set alpha if the input is an RGB color or a string (not RGBA)
|
1415
|
-
if len(rgba_color) == 4 and (
|
1416
|
-
len(color) == 3 or isinstance(color, str)
|
1417
|
-
): # If it's RGB or a string, set the alpha
|
1418
|
-
rgba_color[3] = alpha
|
1419
1502
|
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1503
|
+
def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
|
1504
|
+
"""Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
|
1505
|
+
if isinstance(c, str):
|
1506
|
+
# Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
|
1507
|
+
rgba = np.array(mcolors.to_rgba(c))
|
1508
|
+
elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
|
1509
|
+
# Convert RGB (3) or RGBA (4) values to RGBA format
|
1510
|
+
rgba = np.array(mcolors.to_rgba(c))
|
1511
|
+
else:
|
1512
|
+
raise ValueError(
|
1513
|
+
f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
|
1514
|
+
)
|
1423
1515
|
|
1424
|
-
|
1516
|
+
if alpha is not None: # Override alpha if provided
|
1517
|
+
rgba[3] = alpha
|
1518
|
+
return rgba
|
1425
1519
|
|
1426
|
-
#
|
1427
|
-
|
1428
|
-
|
1429
|
-
for i in range(num_repeats):
|
1430
|
-
# Reiterate over the colors if the number of repeats exceeds the number of colors
|
1431
|
-
c = color[i % len(color)]
|
1432
|
-
# Ensure each element is either a valid string or a list/tuple of length 3 (RGB) or 4 (RGBA)
|
1433
|
-
if isinstance(c, str) or (
|
1434
|
-
isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]
|
1435
|
-
):
|
1436
|
-
rgba_c = np.array(mcolors.to_rgba(c))
|
1437
|
-
# Apply alpha only to RGB colors (not RGBA) and strings
|
1438
|
-
if len(rgba_c) == 4 and (len(c) == 3 or isinstance(c, str)):
|
1439
|
-
rgba_c[3] = alpha
|
1440
|
-
|
1441
|
-
rgba_colors.append(rgba_c)
|
1442
|
-
else:
|
1443
|
-
raise ValueError(f"Invalid color: {c}. Must be a valid RGB/RGBA or string color.")
|
1520
|
+
# If color is a 2D array of RGBA values, convert it to a list of lists
|
1521
|
+
if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
|
1522
|
+
color = [list(c) for c in color]
|
1444
1523
|
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1524
|
+
# Handle a single color (string or RGB/RGBA list/tuple)
|
1525
|
+
if isinstance(color, (str, list, tuple)) and not any(
|
1526
|
+
isinstance(c, (list, tuple, np.ndarray)) for c in color
|
1527
|
+
):
|
1528
|
+
rgba_color = convert_to_rgba(color)
|
1529
|
+
if num_repeats:
|
1530
|
+
return np.tile(
|
1531
|
+
rgba_color, (num_repeats, 1)
|
1532
|
+
) # Repeat the color if num_repeats is provided
|
1533
|
+
return np.array([rgba_color]) # Return a single color wrapped in a numpy array
|
1534
|
+
|
1535
|
+
# Handle a list/array of colors
|
1536
|
+
elif isinstance(color, (list, tuple, np.ndarray)):
|
1537
|
+
rgba_colors = np.array(
|
1538
|
+
[convert_to_rgba(c) for c in color]
|
1539
|
+
) # Convert each color in the list to RGBA
|
1540
|
+
# Handle repetition if num_repeats is provided
|
1541
|
+
if num_repeats:
|
1542
|
+
repeated_colors = np.array(
|
1543
|
+
[rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)]
|
1544
|
+
)
|
1545
|
+
return repeated_colors
|
1448
1546
|
|
1449
|
-
return
|
1547
|
+
return rgba_colors
|
1450
1548
|
|
1451
1549
|
else:
|
1452
1550
|
raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
|
@@ -1487,13 +1585,13 @@ def _calculate_bounding_box(
|
|
1487
1585
|
return center, radius
|
1488
1586
|
|
1489
1587
|
|
1490
|
-
def _combine_words(words: List[str],
|
1491
|
-
"""Combine words to fit within the
|
1492
|
-
and separate the final output by
|
1588
|
+
def _combine_words(words: List[str], max_chars_per_line: int, max_label_lines: int) -> str:
|
1589
|
+
"""Combine words to fit within the max_chars_per_line and max_label_lines constraints,
|
1590
|
+
and separate the final output by TERM_DELIMITER for plotting.
|
1493
1591
|
|
1494
1592
|
Args:
|
1495
1593
|
words (List[str]): List of words to combine.
|
1496
|
-
|
1594
|
+
max_chars_per_line (int): Maximum number of characters in a line to display.
|
1497
1595
|
max_label_lines (int): Maximum number of lines in a label.
|
1498
1596
|
|
1499
1597
|
Returns:
|
@@ -1510,14 +1608,18 @@ def _combine_words(words: List[str], max_length: int, max_label_lines: int) -> s
|
|
1510
1608
|
# Try to combine more words if possible, and ensure the combination fits within max_length
|
1511
1609
|
for j in range(i + 1, len(words_batch)):
|
1512
1610
|
next_word = words_batch[j]
|
1513
|
-
|
1611
|
+
# Ensure that the combined word fits within the max_chars_per_line limit
|
1612
|
+
if len(combined_word) + len(next_word) + 1 <= max_chars_per_line: # +1 for space
|
1514
1613
|
combined_word = f"{combined_word} {next_word}"
|
1515
1614
|
i += 1 # Move past the combined word
|
1516
1615
|
else:
|
1517
1616
|
break # Stop combining if the length is exceeded
|
1518
1617
|
|
1519
|
-
|
1520
|
-
|
1618
|
+
# Add the combined word only if it fits within the max_chars_per_line limit
|
1619
|
+
if len(combined_word) <= max_chars_per_line:
|
1620
|
+
combined_lines.append(combined_word) # Add the combined word
|
1621
|
+
# Move to the next word
|
1622
|
+
i += 1
|
1521
1623
|
|
1522
1624
|
# Stop if we've reached the max_label_lines limit
|
1523
1625
|
if len(combined_lines) >= max_label_lines:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|