risk-network 0.0.4b2__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -5
- risk/annotations/annotations.py +3 -3
- risk/constants.py +2 -2
- risk/neighborhoods/neighborhoods.py +5 -1
- risk/network/geometry.py +2 -2
- risk/network/graph.py +45 -19
- risk/network/io.py +45 -30
- risk/network/plot.py +70 -18
- risk/risk.py +175 -19
- risk/stats/__init__.py +4 -1
- risk/stats/fisher_exact.py +132 -0
- risk/stats/hypergeom.py +131 -0
- risk/stats/permutation/__init__.py +6 -0
- risk/stats/permutation/permutation.py +212 -0
- risk/stats/{permutation.py → permutation/test_functions.py} +12 -39
- risk/stats/stats.py +1 -212
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/METADATA +4 -5
- risk_network-0.0.5.dist-info/RECORD +30 -0
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/WHEEL +1 -1
- risk_network-0.0.4b2.dist-info/RECORD +0 -26
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/LICENSE +0 -0
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/top_level.txt +0 -0
risk/__init__.py
CHANGED
risk/annotations/annotations.py
CHANGED
@@ -197,7 +197,7 @@ def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
|
|
197
197
|
word_counts = Counter(words)
|
198
198
|
filtered_words = []
|
199
199
|
used_words = set()
|
200
|
-
|
200
|
+
# Iterate through the words to find similar words
|
201
201
|
for word in word_counts:
|
202
202
|
if word in used_words:
|
203
203
|
continue
|
@@ -207,7 +207,7 @@ def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
|
|
207
207
|
similar_words = [
|
208
208
|
other_word
|
209
209
|
for other_word in word_counts
|
210
|
-
if
|
210
|
+
if _calculate_jaccard_index(word_set, set(other_word)) >= threshold
|
211
211
|
]
|
212
212
|
# Sort by frequency and choose the most frequent word
|
213
213
|
similar_words.sort(key=lambda w: word_counts[w], reverse=True)
|
@@ -220,7 +220,7 @@ def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
|
|
220
220
|
return final_words
|
221
221
|
|
222
222
|
|
223
|
-
def
|
223
|
+
def _calculate_jaccard_index(set1: Set[Any], set2: Set[Any]) -> float:
|
224
224
|
"""Calculate the Jaccard Index of two sets.
|
225
225
|
|
226
226
|
Args:
|
risk/constants.py
CHANGED
@@ -3,6 +3,8 @@ risk/constants
|
|
3
3
|
~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
+
GROUP_LINKAGE_METHODS = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
|
7
|
+
|
6
8
|
GROUP_DISTANCE_METRICS = [
|
7
9
|
"braycurtis",
|
8
10
|
"canberra",
|
@@ -27,5 +29,3 @@ GROUP_DISTANCE_METRICS = [
|
|
27
29
|
"sqeuclidean",
|
28
30
|
"yule",
|
29
31
|
]
|
30
|
-
|
31
|
-
GROUP_LINKAGE_METHODS = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
|
@@ -321,7 +321,11 @@ def _calculate_threshold(average_distances: list, distance_threshold: float) ->
|
|
321
321
|
rank_percentiles = np.linspace(0, 1, len(sorted_distances))
|
322
322
|
# Interpolating the ranks to 1000 evenly spaced percentiles
|
323
323
|
interpolated_percentiles = np.linspace(0, 1, 1000)
|
324
|
-
|
324
|
+
try:
|
325
|
+
smoothed_distances = np.interp(interpolated_percentiles, rank_percentiles, sorted_distances)
|
326
|
+
except ValueError as e:
|
327
|
+
raise ValueError("No significant annotations found.") from e
|
328
|
+
|
325
329
|
# Determine the index corresponding to the distance threshold
|
326
330
|
threshold_index = int(np.ceil(distance_threshold * len(smoothed_distances))) - 1
|
327
331
|
# Return the smoothed distance at the calculated index
|
risk/network/geometry.py
CHANGED
@@ -7,13 +7,13 @@ import networkx as nx
|
|
7
7
|
import numpy as np
|
8
8
|
|
9
9
|
|
10
|
-
def
|
10
|
+
def assign_edge_lengths(
|
11
11
|
G: nx.Graph,
|
12
12
|
compute_sphere: bool = True,
|
13
13
|
surface_depth: float = 0.0,
|
14
14
|
include_edge_weight: bool = False,
|
15
15
|
) -> nx.Graph:
|
16
|
-
"""
|
16
|
+
"""Assign edge lengths in the graph, optionally mapping nodes to a sphere and including edge weights.
|
17
17
|
|
18
18
|
Args:
|
19
19
|
G (nx.Graph): The input graph.
|
risk/network/graph.py
CHANGED
@@ -5,13 +5,12 @@ risk/network/graph
|
|
5
5
|
|
6
6
|
import random
|
7
7
|
from collections import defaultdict
|
8
|
-
from typing import Any, Dict, List, Tuple
|
8
|
+
from typing import Any, Dict, List, Tuple, Union
|
9
9
|
|
10
10
|
import networkx as nx
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
import matplotlib
|
14
|
-
import matplotlib.cm as cm
|
15
14
|
|
16
15
|
|
17
16
|
class NetworkGraph:
|
@@ -101,7 +100,12 @@ class NetworkGraph:
|
|
101
100
|
self.node_coordinates = _extract_node_coordinates(G_2d)
|
102
101
|
|
103
102
|
def get_domain_colors(
|
104
|
-
self,
|
103
|
+
self,
|
104
|
+
min_scale: float = 0.8,
|
105
|
+
max_scale: float = 1.0,
|
106
|
+
scale_factor: float = 1.0,
|
107
|
+
random_seed: int = 888,
|
108
|
+
**kwargs,
|
105
109
|
) -> np.ndarray:
|
106
110
|
"""Generate composite colors for domains.
|
107
111
|
|
@@ -112,6 +116,8 @@ class NetworkGraph:
|
|
112
116
|
Args:
|
113
117
|
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
114
118
|
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
119
|
+
scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
|
120
|
+
values more. Defaults to 1.0.
|
115
121
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
116
122
|
**kwargs: Additional keyword arguments for color generation.
|
117
123
|
|
@@ -119,7 +125,7 @@ class NetworkGraph:
|
|
119
125
|
np.ndarray: Array of transformed colors.
|
120
126
|
"""
|
121
127
|
# Get colors for each domain
|
122
|
-
domain_colors = self._get_domain_colors(
|
128
|
+
domain_colors = self._get_domain_colors(random_seed=random_seed)
|
123
129
|
# Generate composite colors for nodes
|
124
130
|
node_colors = self._get_composite_node_colors(domain_colors)
|
125
131
|
# Transform colors to ensure proper alpha values and intensity
|
@@ -128,6 +134,7 @@ class NetworkGraph:
|
|
128
134
|
self.node_enrichment_sums,
|
129
135
|
min_scale=min_scale,
|
130
136
|
max_scale=max_scale,
|
137
|
+
scale_factor=scale_factor,
|
131
138
|
)
|
132
139
|
|
133
140
|
return transformed_colors
|
@@ -153,9 +160,15 @@ class NetworkGraph:
|
|
153
160
|
|
154
161
|
return composite_colors
|
155
162
|
|
156
|
-
def _get_domain_colors(
|
163
|
+
def _get_domain_colors(
|
164
|
+
self, color: Union[str, None] = None, random_seed: int = 888
|
165
|
+
) -> Dict[str, Any]:
|
157
166
|
"""Get colors for each domain.
|
158
167
|
|
168
|
+
Args:
|
169
|
+
color (Union[str, None], optional): Specific color to use for all domains. If specified, it will overwrite the colormap.
|
170
|
+
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
171
|
+
|
159
172
|
Returns:
|
160
173
|
dict: A dictionary mapping domain keys to their corresponding RGBA colors.
|
161
174
|
"""
|
@@ -164,20 +177,28 @@ class NetworkGraph:
|
|
164
177
|
col for col in self.domains.columns if isinstance(col, (int, np.integer))
|
165
178
|
]
|
166
179
|
domains = np.sort(numeric_domains)
|
167
|
-
domain_colors = _get_colors(
|
180
|
+
domain_colors = _get_colors(
|
181
|
+
num_colors_to_generate=len(domains), color=color, random_seed=random_seed
|
182
|
+
)
|
168
183
|
return dict(zip(self.domain_to_nodes.keys(), domain_colors))
|
169
184
|
|
170
185
|
|
171
186
|
def _transform_colors(
|
172
|
-
colors: np.ndarray,
|
187
|
+
colors: np.ndarray,
|
188
|
+
enrichment_sums: np.ndarray,
|
189
|
+
min_scale: float = 0.8,
|
190
|
+
max_scale: float = 1.0,
|
191
|
+
scale_factor: float = 1.0,
|
173
192
|
) -> np.ndarray:
|
174
|
-
"""Transform colors
|
193
|
+
"""Transform colors using power scaling to emphasize high enrichment sums more.
|
175
194
|
|
176
195
|
Args:
|
177
196
|
colors (np.ndarray): An array of RGBA colors.
|
178
197
|
enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
|
179
198
|
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
180
199
|
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
200
|
+
scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
|
201
|
+
values more. Defaults to 1.0.
|
181
202
|
|
182
203
|
Returns:
|
183
204
|
np.ndarray: The transformed array of RGBA colors with adjusted intensities.
|
@@ -185,11 +206,12 @@ def _transform_colors(
|
|
185
206
|
if min_scale == max_scale:
|
186
207
|
min_scale = max_scale - 10e-6 # Avoid division by zero
|
187
208
|
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
209
|
+
# Normalize the enrichment sums to the range [0, 1]
|
210
|
+
normalized_sums = enrichment_sums / np.max(enrichment_sums)
|
211
|
+
# Apply power scaling to dim lower values and emphasize higher values
|
212
|
+
scaled_sums = normalized_sums**scale_factor
|
213
|
+
# Linearly scale the normalized sums to the range [min_scale, max_scale]
|
214
|
+
scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
|
193
215
|
# Adjust RGB values based on scaled sums
|
194
216
|
for i in range(3): # Only adjust RGB values
|
195
217
|
colors[:, i] = scaled_sums * colors[:, i]
|
@@ -250,7 +272,10 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
|
|
250
272
|
|
251
273
|
|
252
274
|
def _get_colors(
|
253
|
-
num_colors_to_generate: int = 10,
|
275
|
+
num_colors_to_generate: int = 10,
|
276
|
+
cmap: str = "hsv",
|
277
|
+
random_seed: int = 888,
|
278
|
+
color: Union[str, None] = None,
|
254
279
|
) -> List[Tuple]:
|
255
280
|
"""Generate a list of RGBA colors from a specified colormap or use a direct color string.
|
256
281
|
|
@@ -258,19 +283,20 @@ def _get_colors(
|
|
258
283
|
num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
|
259
284
|
cmap (str): The name of the colormap to use. Defaults to "hsv".
|
260
285
|
random_seed (int): Seed for random number generation. Defaults to 888.
|
261
|
-
|
286
|
+
color (str, optional): Specific color to use for all nodes. If specified, it will overwrite the colormap.
|
287
|
+
Defaults to None.
|
262
288
|
|
263
289
|
Returns:
|
264
290
|
list of tuple: List of RGBA colors.
|
265
291
|
"""
|
266
292
|
# Set random seed for reproducibility
|
267
293
|
random.seed(random_seed)
|
268
|
-
if
|
269
|
-
# If a direct color
|
270
|
-
rgba = matplotlib.colors.to_rgba(
|
294
|
+
if color:
|
295
|
+
# If a direct color is provided, generate a list with that color
|
296
|
+
rgba = matplotlib.colors.to_rgba(color)
|
271
297
|
rgbas = [rgba] * num_colors_to_generate
|
272
298
|
else:
|
273
|
-
colormap =
|
299
|
+
colormap = matplotlib.colormaps.get_cmap(cmap)
|
274
300
|
# Generate evenly distributed color positions
|
275
301
|
color_positions = np.linspace(0, 1, num_colors_to_generate)
|
276
302
|
random.shuffle(color_positions) # Shuffle the positions to randomize colors
|
risk/network/io.py
CHANGED
@@ -6,6 +6,7 @@ This file contains the code for the RISK class and command-line access.
|
|
6
6
|
"""
|
7
7
|
|
8
8
|
import json
|
9
|
+
import os
|
9
10
|
import pickle
|
10
11
|
import shutil
|
11
12
|
import zipfile
|
@@ -14,7 +15,7 @@ from xml.dom import minidom
|
|
14
15
|
import networkx as nx
|
15
16
|
import pandas as pd
|
16
17
|
|
17
|
-
from risk.network.geometry import
|
18
|
+
from risk.network.geometry import assign_edge_lengths
|
18
19
|
from risk.log import params, print_header
|
19
20
|
|
20
21
|
|
@@ -215,14 +216,20 @@ class NetworkIO:
|
|
215
216
|
params.log_network(filetype=filetype, filepath=str(filepath))
|
216
217
|
self._log_loading(filetype, filepath=filepath)
|
217
218
|
cys_files = []
|
219
|
+
tmp_dir = ".tmp_cytoscape"
|
218
220
|
# Try / finally to remove unzipped files
|
219
221
|
try:
|
220
|
-
#
|
222
|
+
# Create the temporary directory if it doesn't exist
|
223
|
+
if not os.path.exists(tmp_dir):
|
224
|
+
os.makedirs(tmp_dir)
|
225
|
+
|
226
|
+
# Unzip CYS file into the temporary directory
|
221
227
|
with zipfile.ZipFile(filepath, "r") as zip_ref:
|
222
228
|
cys_files = zip_ref.namelist()
|
223
|
-
zip_ref.extractall(
|
229
|
+
zip_ref.extractall(tmp_dir)
|
230
|
+
|
224
231
|
# Get first view and network instances
|
225
|
-
cys_view_files = [cf for cf in cys_files if "/views/" in cf]
|
232
|
+
cys_view_files = [os.path.join(tmp_dir, cf) for cf in cys_files if "/views/" in cf]
|
226
233
|
cys_view_file = (
|
227
234
|
cys_view_files[0]
|
228
235
|
if not view_name
|
@@ -244,7 +251,7 @@ class NetworkIO:
|
|
244
251
|
# Read the node attributes (from /tables/)
|
245
252
|
attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
|
246
253
|
attribute_metadata = [
|
247
|
-
cf
|
254
|
+
os.path.join(tmp_dir, cf)
|
248
255
|
for cf in cys_files
|
249
256
|
if all(keyword in cf for keyword in attribute_metadata_keywords)
|
250
257
|
][0]
|
@@ -291,10 +298,9 @@ class NetworkIO:
|
|
291
298
|
return self._initialize_graph(G)
|
292
299
|
|
293
300
|
finally:
|
294
|
-
# Remove
|
295
|
-
|
296
|
-
|
297
|
-
shutil.rmtree(dirname)
|
301
|
+
# Remove the temporary directory and its contents
|
302
|
+
if os.path.exists(tmp_dir):
|
303
|
+
shutil.rmtree(tmp_dir)
|
298
304
|
|
299
305
|
@classmethod
|
300
306
|
def load_cytoscape_json_network(
|
@@ -402,12 +408,13 @@ class NetworkIO:
|
|
402
408
|
Returns:
|
403
409
|
nx.Graph: The processed and validated graph.
|
404
410
|
"""
|
411
|
+
self._validate_nodes(G)
|
412
|
+
self._assign_edge_weights(G)
|
413
|
+
self._assign_edge_lengths(G)
|
414
|
+
self._remove_invalid_graph_properties(G)
|
405
415
|
# IMPORTANT: This is where the graph node labels are converted to integers
|
416
|
+
# Make sure to perform this step after all other processing
|
406
417
|
G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
|
407
|
-
self._remove_invalid_graph_properties(G)
|
408
|
-
self._validate_edges(G)
|
409
|
-
self._validate_nodes(G)
|
410
|
-
self._process_graph(G)
|
411
418
|
return G
|
412
419
|
|
413
420
|
def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
|
@@ -416,18 +423,26 @@ class NetworkIO:
|
|
416
423
|
Args:
|
417
424
|
G (nx.Graph): A NetworkX graph object.
|
418
425
|
"""
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
426
|
+
# First, Remove self-loop edges to ensure correct edge count
|
427
|
+
G.remove_edges_from(list(nx.selfloop_edges(G)))
|
428
|
+
# Then, iteratively remove nodes with fewer edges than the specified threshold
|
429
|
+
while True:
|
430
|
+
nodes_to_remove = [
|
431
|
+
node for node in G.nodes() if G.degree(node) < self.min_edges_per_node
|
432
|
+
]
|
433
|
+
if not nodes_to_remove:
|
434
|
+
break # Exit loop if no more nodes to remove
|
435
|
+
|
436
|
+
# Remove the nodes and their associated edges
|
437
|
+
G.remove_nodes_from(nodes_to_remove)
|
438
|
+
|
439
|
+
# Optionally: Remove any isolated nodes if needed
|
440
|
+
isolated_nodes = list(nx.isolates(G))
|
441
|
+
if isolated_nodes:
|
442
|
+
G.remove_nodes_from(isolated_nodes)
|
443
|
+
|
444
|
+
def _assign_edge_weights(self, G: nx.Graph) -> None:
|
445
|
+
"""Assign weights to the edges in the graph.
|
431
446
|
|
432
447
|
Args:
|
433
448
|
G (nx.Graph): A NetworkX graph object.
|
@@ -456,13 +471,13 @@ class NetworkIO:
|
|
456
471
|
), f"Node {node} is missing 'x' or 'y' position attributes."
|
457
472
|
assert "label" in attrs, f"Node {node} is missing a 'label' attribute."
|
458
473
|
|
459
|
-
def
|
474
|
+
def _assign_edge_lengths(self, G: nx.Graph) -> None:
|
460
475
|
"""Prepare the network by adjusting surface depth and calculating edge lengths.
|
461
476
|
|
462
477
|
Args:
|
463
478
|
G (nx.Graph): The input network graph.
|
464
479
|
"""
|
465
|
-
|
480
|
+
assign_edge_lengths(
|
466
481
|
G,
|
467
482
|
compute_sphere=self.compute_sphere,
|
468
483
|
surface_depth=self.surface_depth,
|
@@ -484,9 +499,9 @@ class NetworkIO:
|
|
484
499
|
print(f"Filetype: {filetype}")
|
485
500
|
if filepath:
|
486
501
|
print(f"Filepath: {filepath}")
|
487
|
-
print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
|
488
|
-
if self.compute_sphere:
|
489
|
-
print(f"Surface depth: {self.surface_depth}")
|
490
502
|
print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
|
491
503
|
if self.include_edge_weight:
|
492
504
|
print(f"Weight label: {self.weight_label}")
|
505
|
+
print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
|
506
|
+
if self.compute_sphere:
|
507
|
+
print(f"Surface depth: {self.surface_depth}")
|
risk/network/plot.py
CHANGED
@@ -33,6 +33,7 @@ class NetworkPlotter:
|
|
33
33
|
plot_outline: bool = True,
|
34
34
|
outline_color: str = "black",
|
35
35
|
outline_scale: float = 1.0,
|
36
|
+
linestyle: str = "dashed",
|
36
37
|
) -> None:
|
37
38
|
"""Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
|
38
39
|
|
@@ -43,11 +44,18 @@ class NetworkPlotter:
|
|
43
44
|
plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
|
44
45
|
outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
|
45
46
|
outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
|
47
|
+
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
|
46
48
|
"""
|
47
49
|
self.graph = graph
|
48
50
|
# Initialize the plot with the specified parameters
|
49
51
|
self.ax = self._initialize_plot(
|
50
|
-
graph,
|
52
|
+
graph,
|
53
|
+
figsize,
|
54
|
+
background_color,
|
55
|
+
plot_outline,
|
56
|
+
outline_color,
|
57
|
+
outline_scale,
|
58
|
+
linestyle,
|
51
59
|
)
|
52
60
|
|
53
61
|
def _initialize_plot(
|
@@ -58,6 +66,7 @@ class NetworkPlotter:
|
|
58
66
|
plot_outline: bool,
|
59
67
|
outline_color: str,
|
60
68
|
outline_scale: float,
|
69
|
+
linestyle: str,
|
61
70
|
) -> plt.Axes:
|
62
71
|
"""Set up the plot with figure size, optional circle perimeter, and background color.
|
63
72
|
|
@@ -68,6 +77,7 @@ class NetworkPlotter:
|
|
68
77
|
plot_outline (bool): Whether to plot the network perimeter circle.
|
69
78
|
outline_color (str): Color of the network perimeter circle.
|
70
79
|
outline_scale (float): Outline scaling factor for the perimeter diameter.
|
80
|
+
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
|
71
81
|
|
72
82
|
Returns:
|
73
83
|
plt.Axes: The axis object for the plot.
|
@@ -87,7 +97,7 @@ class NetworkPlotter:
|
|
87
97
|
circle = plt.Circle(
|
88
98
|
center,
|
89
99
|
scaled_radius,
|
90
|
-
linestyle=
|
100
|
+
linestyle=linestyle, # Use the linestyle argument here
|
91
101
|
color=outline_color,
|
92
102
|
fill=False,
|
93
103
|
linewidth=1.5,
|
@@ -400,8 +410,12 @@ class NetworkPlotter:
|
|
400
410
|
fontcolor: Union[str, np.ndarray] = "black",
|
401
411
|
arrow_linewidth: float = 1,
|
402
412
|
arrow_color: Union[str, np.ndarray] = "black",
|
413
|
+
max_labels: Union[int, None] = None,
|
403
414
|
max_words: int = 10,
|
404
415
|
min_words: int = 1,
|
416
|
+
max_word_length: int = 20,
|
417
|
+
min_word_length: int = 1,
|
418
|
+
words_to_omit: Union[List[str], None] = None,
|
405
419
|
) -> None:
|
406
420
|
"""Annotate the network graph with labels for different domains, positioned around the network for clarity.
|
407
421
|
|
@@ -413,8 +427,12 @@ class NetworkPlotter:
|
|
413
427
|
fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
|
414
428
|
arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
|
415
429
|
arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
|
430
|
+
max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
|
416
431
|
max_words (int, optional): Maximum number of words in a label. Defaults to 10.
|
417
432
|
min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
|
433
|
+
max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
|
434
|
+
min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
|
435
|
+
words_to_omit (List[str], optional): List of words to omit from the labels. Defaults to None.
|
418
436
|
"""
|
419
437
|
# Log the plotting parameters
|
420
438
|
params.log_plotter(
|
@@ -425,14 +443,22 @@ class NetworkPlotter:
|
|
425
443
|
label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
|
426
444
|
label_arrow_linewidth=arrow_linewidth,
|
427
445
|
label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
|
446
|
+
label_max_labels=max_labels,
|
428
447
|
label_max_words=max_words,
|
429
448
|
label_min_words=min_words,
|
449
|
+
label_max_word_length=max_word_length,
|
450
|
+
label_min_word_length=min_word_length,
|
451
|
+
label_words_to_omit=words_to_omit,
|
430
452
|
)
|
453
|
+
|
431
454
|
# Convert color strings to RGBA arrays if necessary
|
432
455
|
if isinstance(fontcolor, str):
|
433
|
-
fontcolor = self.
|
456
|
+
fontcolor = self.get_annotated_label_colors(color=fontcolor)
|
434
457
|
if isinstance(arrow_color, str):
|
435
|
-
arrow_color = self.
|
458
|
+
arrow_color = self.get_annotated_label_colors(color=arrow_color)
|
459
|
+
# Normalize words_to_omit to lowercase
|
460
|
+
if words_to_omit:
|
461
|
+
words_to_omit = set(word.lower() for word in words_to_omit)
|
436
462
|
|
437
463
|
# Calculate the center and radius of the network
|
438
464
|
domain_centroids = {}
|
@@ -443,31 +469,57 @@ class NetworkPlotter:
|
|
443
469
|
# Initialize empty lists to collect valid indices
|
444
470
|
valid_indices = []
|
445
471
|
filtered_domain_centroids = {}
|
472
|
+
filtered_domain_terms = {}
|
446
473
|
# Loop through domain_centroids with index
|
447
474
|
for idx, (domain, centroid) in enumerate(domain_centroids.items()):
|
475
|
+
# Process the domain term
|
476
|
+
terms = self.graph.trimmed_domain_to_term[domain].split(" ")
|
477
|
+
# Remove words_to_omit
|
478
|
+
if words_to_omit:
|
479
|
+
terms = [term for term in terms if term.lower() not in words_to_omit]
|
480
|
+
# Filter words based on length
|
481
|
+
terms = [term for term in terms if min_word_length <= len(term) <= max_word_length]
|
482
|
+
# Trim to max_words
|
483
|
+
terms = terms[:max_words]
|
448
484
|
# Check if the domain passes the word count condition
|
449
|
-
if len(
|
485
|
+
if len(terms) >= min_words:
|
450
486
|
# Add to filtered_domain_centroids
|
451
487
|
filtered_domain_centroids[domain] = centroid
|
452
|
-
#
|
488
|
+
# Store the filtered and trimmed terms
|
489
|
+
filtered_domain_terms[domain] = " ".join(terms)
|
490
|
+
# Keep track of the valid index - used for fontcolor and arrow_color
|
453
491
|
valid_indices.append(idx)
|
454
492
|
|
455
|
-
#
|
456
|
-
|
457
|
-
|
493
|
+
# If max_labels is specified and less than the available labels
|
494
|
+
if max_labels is not None and max_labels < len(filtered_domain_centroids):
|
495
|
+
step = len(filtered_domain_centroids) / max_labels
|
496
|
+
selected_indices = [int(i * step) for i in range(max_labels)]
|
497
|
+
# Filter the centroids, terms, and valid_indices to only use the selected indices
|
498
|
+
filtered_domain_centroids = {
|
499
|
+
k: v
|
500
|
+
for i, (k, v) in enumerate(filtered_domain_centroids.items())
|
501
|
+
if i in selected_indices
|
502
|
+
}
|
503
|
+
filtered_domain_terms = {
|
504
|
+
k: v
|
505
|
+
for i, (k, v) in enumerate(filtered_domain_terms.items())
|
506
|
+
if i in selected_indices
|
507
|
+
}
|
508
|
+
# Update valid_indices to match selected indices
|
509
|
+
valid_indices = [valid_indices[i] for i in selected_indices]
|
458
510
|
|
459
511
|
# Calculate the bounding box around the network
|
460
512
|
center, radius = _calculate_bounding_box(
|
461
513
|
self.graph.node_coordinates, radius_margin=perimeter_scale
|
462
514
|
)
|
463
515
|
# Calculate the best positions for labels around the perimeter
|
464
|
-
best_label_positions =
|
516
|
+
best_label_positions = _calculate_best_label_positions(
|
465
517
|
filtered_domain_centroids, center, radius, offset
|
466
518
|
)
|
467
|
-
# Annotate the network with labels
|
468
|
-
for idx, (domain, pos) in
|
519
|
+
# Annotate the network with labels - valid_indices is used for fontcolor and arrow_color
|
520
|
+
for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
|
469
521
|
centroid = filtered_domain_centroids[domain]
|
470
|
-
annotations =
|
522
|
+
annotations = filtered_domain_terms[domain].split(" ")[:max_words]
|
471
523
|
self.ax.annotate(
|
472
524
|
"\n".join(annotations),
|
473
525
|
xy=centroid,
|
@@ -531,12 +583,10 @@ class NetworkPlotter:
|
|
531
583
|
|
532
584
|
# Calculate the centroid of the provided nodes
|
533
585
|
centroid = self._calculate_domain_centroid(node_ids)
|
534
|
-
|
535
586
|
# Calculate the bounding box around the network
|
536
587
|
center, radius = _calculate_bounding_box(
|
537
588
|
self.graph.node_coordinates, radius_margin=perimeter_scale
|
538
589
|
)
|
539
|
-
|
540
590
|
# Convert radial position to radians, adjusting for a 90-degree rotation
|
541
591
|
radial_radians = np.deg2rad(radial_position - 90)
|
542
592
|
label_position = (
|
@@ -746,7 +796,7 @@ def _calculate_bounding_box(
|
|
746
796
|
return center, radius
|
747
797
|
|
748
798
|
|
749
|
-
def
|
799
|
+
def _calculate_best_label_positions(
|
750
800
|
filtered_domain_centroids: Dict[str, Any], center: np.ndarray, radius: float, offset: float
|
751
801
|
) -> Dict[str, Any]:
|
752
802
|
"""Calculate and optimize label positions for clarity.
|
@@ -762,7 +812,9 @@ def _best_label_positions(
|
|
762
812
|
"""
|
763
813
|
num_domains = len(filtered_domain_centroids)
|
764
814
|
# Calculate equidistant positions around the center for initial label placement
|
765
|
-
equidistant_positions =
|
815
|
+
equidistant_positions = _calculate_equidistant_positions_around_center(
|
816
|
+
center, radius, offset, num_domains
|
817
|
+
)
|
766
818
|
# Create a mapping of domains to their initial label positions
|
767
819
|
label_positions = {
|
768
820
|
domain: position
|
@@ -772,7 +824,7 @@ def _best_label_positions(
|
|
772
824
|
return _optimize_label_positions(label_positions, filtered_domain_centroids)
|
773
825
|
|
774
826
|
|
775
|
-
def
|
827
|
+
def _calculate_equidistant_positions_around_center(
|
776
828
|
center: np.ndarray, radius: float, label_offset: float, num_domains: int
|
777
829
|
) -> List[np.ndarray]:
|
778
830
|
"""Calculate positions around a center at equidistant angles.
|