risk-network 0.0.8b27__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -2
- risk/annotations/__init__.py +2 -2
- risk/annotations/annotations.py +195 -118
- risk/annotations/io.py +47 -31
- risk/log/__init__.py +4 -2
- risk/log/console.py +3 -1
- risk/log/{params.py → parameters.py} +17 -42
- risk/neighborhoods/__init__.py +3 -5
- risk/neighborhoods/api.py +442 -0
- risk/neighborhoods/community.py +324 -101
- risk/neighborhoods/domains.py +125 -52
- risk/neighborhoods/neighborhoods.py +177 -165
- risk/network/__init__.py +1 -3
- risk/network/geometry.py +71 -89
- risk/network/graph/__init__.py +6 -0
- risk/network/graph/api.py +200 -0
- risk/network/{graph.py → graph/graph.py} +90 -40
- risk/network/graph/summary.py +254 -0
- risk/network/io.py +103 -114
- risk/network/plotter/__init__.py +6 -0
- risk/network/plotter/api.py +54 -0
- risk/network/{plot → plotter}/canvas.py +9 -8
- risk/network/{plot → plotter}/contour.py +27 -24
- risk/network/{plot → plotter}/labels.py +73 -78
- risk/network/{plot → plotter}/network.py +45 -39
- risk/network/{plot → plotter}/plotter.py +23 -17
- risk/network/{plot/utils/color.py → plotter/utils/colors.py} +114 -122
- risk/network/{plot → plotter}/utils/layout.py +10 -7
- risk/risk.py +11 -500
- risk/stats/__init__.py +10 -4
- risk/stats/permutation/__init__.py +1 -1
- risk/stats/permutation/permutation.py +44 -38
- risk/stats/permutation/test_functions.py +26 -18
- risk/stats/{stats.py → significance.py} +17 -15
- risk/stats/stat_tests.py +267 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/METADATA +31 -46
- risk_network-0.0.9.dist-info/RECORD +40 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/WHEEL +1 -1
- risk/constants.py +0 -31
- risk/network/plot/__init__.py +0 -6
- risk/stats/hypergeom.py +0 -54
- risk/stats/poisson.py +0 -44
- risk_network-0.0.8b27.dist-info/RECORD +0 -37
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/contour
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
from typing import List, Tuple, Union
|
6
|
+
from typing import Any, Dict, List, Tuple, Union
|
7
7
|
|
8
8
|
import matplotlib.pyplot as plt
|
9
9
|
import numpy as np
|
@@ -12,18 +12,18 @@ from scipy.ndimage import label
|
|
12
12
|
from scipy.stats import gaussian_kde
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
|
-
from risk.network.graph import
|
16
|
-
from risk.network.
|
15
|
+
from risk.network.graph.graph import Graph
|
16
|
+
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
20
20
|
"""Class to generate Kernel Density Estimate (KDE) contours for nodes in a network graph."""
|
21
21
|
|
22
|
-
def __init__(self, graph:
|
23
|
-
"""Initialize the Contour with a
|
22
|
+
def __init__(self, graph: Graph, ax: plt.Axes) -> None:
|
23
|
+
"""Initialize the Contour with a Graph and axis for plotting.
|
24
24
|
|
25
25
|
Args:
|
26
|
-
graph (
|
26
|
+
graph (Graph): The Graph object containing the network data.
|
27
27
|
ax (plt.Axes): The axis to plot the contours on.
|
28
28
|
"""
|
29
29
|
self.graph = graph
|
@@ -63,6 +63,8 @@ class Contour:
|
|
63
63
|
contour_color=(
|
64
64
|
"custom" if isinstance(color, np.ndarray) else color
|
65
65
|
), # np.ndarray usually indicates custom colors
|
66
|
+
contour_linestyle=linestyle,
|
67
|
+
contour_linewidth=linewidth,
|
66
68
|
contour_alpha=alpha,
|
67
69
|
contour_fill_alpha=fill_alpha,
|
68
70
|
)
|
@@ -82,7 +84,7 @@ class Contour:
|
|
82
84
|
self.ax,
|
83
85
|
node_coordinates,
|
84
86
|
node_ids,
|
85
|
-
color=
|
87
|
+
color=color_rgba[idx],
|
86
88
|
levels=levels,
|
87
89
|
bandwidth=bandwidth,
|
88
90
|
grid_size=grid_size,
|
@@ -195,7 +197,7 @@ class Contour:
|
|
195
197
|
# Extract the positions of the specified nodes
|
196
198
|
points = np.array([pos[n] for n in nodes])
|
197
199
|
if len(points) <= 1:
|
198
|
-
return
|
200
|
+
return # Not enough points to form a contour
|
199
201
|
|
200
202
|
# Check if the KDE forms a single connected component
|
201
203
|
connected = False
|
@@ -219,12 +221,12 @@ class Contour:
|
|
219
221
|
except Exception as e:
|
220
222
|
# Catch any other exceptions and log them
|
221
223
|
logger.error(f"Unexpected error when drawing KDE contour: {e}")
|
222
|
-
return
|
224
|
+
return
|
223
225
|
|
224
226
|
# If z is still None, the KDE computation failed
|
225
227
|
if z is None:
|
226
228
|
logger.error("Failed to compute KDE. Skipping contour plot for these nodes.")
|
227
|
-
return
|
229
|
+
return
|
228
230
|
|
229
231
|
# Define contour levels based on the density
|
230
232
|
min_density, max_density = z.min(), z.max()
|
@@ -232,15 +234,15 @@ class Contour:
|
|
232
234
|
logger.warning(
|
233
235
|
"Contour levels could not be created due to lack of variation in density."
|
234
236
|
)
|
235
|
-
return
|
237
|
+
return
|
236
238
|
|
237
239
|
# Create contour levels based on the density values
|
238
240
|
contour_levels = np.linspace(min_density, max_density, levels)[1:]
|
239
241
|
if len(contour_levels) < 2 or not np.all(np.diff(contour_levels) > 0):
|
240
242
|
logger.error("Contour levels must be strictly increasing. Skipping contour plot.")
|
241
|
-
return
|
243
|
+
return
|
242
244
|
|
243
|
-
# Set the contour color and linestyle
|
245
|
+
# Set the contour color, fill, and linestyle
|
244
246
|
contour_colors = [color for _ in range(levels - 1)]
|
245
247
|
# Plot the filled contours using fill_alpha for transparency
|
246
248
|
if fill_alpha and fill_alpha > 0:
|
@@ -256,21 +258,19 @@ class Contour:
|
|
256
258
|
alpha=fill_alpha,
|
257
259
|
)
|
258
260
|
|
259
|
-
# Plot the contour
|
260
|
-
|
261
|
+
# Plot the base contour line with the specified RGBA alpha for transparency
|
262
|
+
base_contour_color = [color]
|
263
|
+
base_contour_level = [contour_levels[0]]
|
264
|
+
ax.contour(
|
261
265
|
x,
|
262
266
|
y,
|
263
267
|
z,
|
264
|
-
levels=
|
265
|
-
colors=
|
268
|
+
levels=base_contour_level,
|
269
|
+
colors=base_contour_color,
|
266
270
|
linestyles=linestyle,
|
267
271
|
linewidths=linewidth,
|
268
272
|
)
|
269
273
|
|
270
|
-
# Set linewidth for the contour lines to 0 for levels other than the base level
|
271
|
-
for i in range(1, len(contour_levels)):
|
272
|
-
c.collections[i].set_linewidth(0)
|
273
|
-
|
274
274
|
def get_annotated_contour_colors(
|
275
275
|
self,
|
276
276
|
cmap: str = "gist_rainbow",
|
@@ -280,6 +280,7 @@ class Contour:
|
|
280
280
|
min_scale: float = 0.8,
|
281
281
|
max_scale: float = 1.0,
|
282
282
|
scale_factor: float = 1.0,
|
283
|
+
ids_to_colors: Union[Dict[int, Any], None] = None,
|
283
284
|
random_seed: int = 888,
|
284
285
|
) -> np.ndarray:
|
285
286
|
"""Get colors for the contours based on node annotations or a specified colormap.
|
@@ -294,8 +295,9 @@ class Contour:
|
|
294
295
|
Controls the dimmest colors. Defaults to 0.8.
|
295
296
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
296
297
|
Controls the brightest colors. Defaults to 1.0.
|
297
|
-
scale_factor (float, optional): Exponent for adjusting color scaling based on
|
298
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
|
298
299
|
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
300
|
+
ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
|
299
301
|
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
300
302
|
|
301
303
|
Returns:
|
@@ -310,6 +312,7 @@ class Contour:
|
|
310
312
|
min_scale=min_scale,
|
311
313
|
max_scale=max_scale,
|
312
314
|
scale_factor=scale_factor,
|
315
|
+
ids_to_colors=ids_to_colors,
|
313
316
|
random_seed=random_seed,
|
314
317
|
)
|
315
318
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/labels
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
6
|
import copy
|
@@ -11,9 +11,9 @@ import numpy as np
|
|
11
11
|
import pandas as pd
|
12
12
|
|
13
13
|
from risk.log import params
|
14
|
-
from risk.network.graph import
|
15
|
-
from risk.network.
|
16
|
-
from risk.network.
|
14
|
+
from risk.network.graph.graph import Graph
|
15
|
+
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
16
|
+
from risk.network.plotter.utils.layout import calculate_bounding_box
|
17
17
|
|
18
18
|
TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
|
19
19
|
|
@@ -21,11 +21,11 @@ TERM_DELIMITER = "::::" # String used to separate multiple domain terms when co
|
|
21
21
|
class Labels:
|
22
22
|
"""Class to handle the annotation of network graphs with labels for different domains."""
|
23
23
|
|
24
|
-
def __init__(self, graph:
|
24
|
+
def __init__(self, graph: Graph, ax: plt.Axes):
|
25
25
|
"""Initialize the Labeler object with a network graph and matplotlib axes.
|
26
26
|
|
27
27
|
Args:
|
28
|
-
graph (
|
28
|
+
graph (Graph): Graph object containing the network data.
|
29
29
|
ax (plt.Axes): Matplotlib axes object to plot the labels on.
|
30
30
|
"""
|
31
31
|
self.graph = graph
|
@@ -54,7 +54,7 @@ class Labels:
|
|
54
54
|
words_to_omit: Union[List, None] = None,
|
55
55
|
overlay_ids: bool = False,
|
56
56
|
ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
|
57
|
-
|
57
|
+
ids_to_labels: Union[Dict[int, str], None] = None,
|
58
58
|
) -> None:
|
59
59
|
"""Annotate the network graph with labels for different domains, positioned around the network for clarity.
|
60
60
|
|
@@ -62,7 +62,7 @@ class Labels:
|
|
62
62
|
scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
|
63
63
|
offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
|
64
64
|
font (str, optional): Font name for the labels. Defaults to "Arial".
|
65
|
-
fontcase (
|
65
|
+
fontcase (str, Dict[str, str], or None, optional): Defines how to transform the case of words.
|
66
66
|
- If a string (e.g., 'upper', 'lower', 'title'), applies the transformation to all words.
|
67
67
|
- If a dictionary, maps specific cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
|
68
68
|
- If None, no transformation is applied.
|
@@ -87,7 +87,7 @@ class Labels:
|
|
87
87
|
overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
|
88
88
|
ids_to_keep (List, Tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
|
89
89
|
you can set `overlay_ids=True`. Defaults to None.
|
90
|
-
|
90
|
+
ids_to_labels (Dict[int, str], optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
|
91
91
|
space-separated words. If provided, the custom labels will replace the default domain terms. To discover domain IDs, you
|
92
92
|
can set `overlay_ids=True`. Defaults to None.
|
93
93
|
|
@@ -119,7 +119,7 @@ class Labels:
|
|
119
119
|
label_words_to_omit=words_to_omit,
|
120
120
|
label_overlay_ids=overlay_ids,
|
121
121
|
label_ids_to_keep=ids_to_keep,
|
122
|
-
|
122
|
+
label_ids_to_labels=ids_to_labels,
|
123
123
|
)
|
124
124
|
|
125
125
|
# Convert ids_to_keep to a tuple if it is not None
|
@@ -152,7 +152,7 @@ class Labels:
|
|
152
152
|
self._process_ids_to_keep(
|
153
153
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
154
154
|
ids_to_keep=ids_to_keep,
|
155
|
-
|
155
|
+
ids_to_labels=ids_to_labels,
|
156
156
|
words_to_omit=words_to_omit,
|
157
157
|
max_labels=max_labels,
|
158
158
|
min_label_lines=min_label_lines,
|
@@ -173,7 +173,7 @@ class Labels:
|
|
173
173
|
self._process_remaining_domains(
|
174
174
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
175
175
|
ids_to_keep=ids_to_keep,
|
176
|
-
|
176
|
+
ids_to_labels=ids_to_labels,
|
177
177
|
words_to_omit=words_to_omit,
|
178
178
|
remaining_labels=remaining_labels,
|
179
179
|
min_chars_per_line=min_chars_per_line,
|
@@ -218,13 +218,14 @@ class Labels:
|
|
218
218
|
fontsize=fontsize,
|
219
219
|
fontname=font,
|
220
220
|
color=fontcolor_rgba[idx],
|
221
|
-
arrowprops=
|
222
|
-
arrowstyle
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
221
|
+
arrowprops={
|
222
|
+
"arrowstyle": arrow_style,
|
223
|
+
"linewidth": arrow_linewidth,
|
224
|
+
"color": arrow_color_rgba[idx],
|
225
|
+
"alpha": arrow_alpha,
|
226
|
+
"shrinkA": arrow_base_shrink,
|
227
|
+
"shrinkB": arrow_tip_shrink,
|
228
|
+
},
|
228
229
|
)
|
229
230
|
|
230
231
|
# Overlay domain ID at the centroid regardless of max_labels if requested
|
@@ -334,13 +335,14 @@ class Labels:
|
|
334
335
|
fontsize=fontsize,
|
335
336
|
fontname=font,
|
336
337
|
color=fontcolor_rgba[idx],
|
337
|
-
arrowprops=
|
338
|
-
arrowstyle
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
338
|
+
arrowprops={
|
339
|
+
"arrowstyle": arrow_style,
|
340
|
+
"linewidth": arrow_linewidth,
|
341
|
+
"color": arrow_color_rgba[idx],
|
342
|
+
"alpha": arrow_alpha,
|
343
|
+
"shrinkA": arrow_base_shrink,
|
344
|
+
"shrinkB": arrow_tip_shrink,
|
345
|
+
},
|
344
346
|
)
|
345
347
|
|
346
348
|
def _calculate_domain_centroid(self, nodes: List) -> tuple:
|
@@ -368,7 +370,7 @@ class Labels:
|
|
368
370
|
self,
|
369
371
|
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
370
372
|
ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
|
371
|
-
|
373
|
+
ids_to_labels: Union[Dict[int, str], None],
|
372
374
|
words_to_omit: Union[List[str], None],
|
373
375
|
max_labels: Union[int, None],
|
374
376
|
min_label_lines: int,
|
@@ -384,7 +386,7 @@ class Labels:
|
|
384
386
|
Args:
|
385
387
|
domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
|
386
388
|
ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
|
387
|
-
|
389
|
+
ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
388
390
|
words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
|
389
391
|
max_labels (int, optional): Maximum number of labels allowed.
|
390
392
|
min_label_lines (int): Minimum number of lines in a label.
|
@@ -419,7 +421,7 @@ class Labels:
|
|
419
421
|
domain=domain,
|
420
422
|
domain_centroid=domain_centroid,
|
421
423
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
422
|
-
|
424
|
+
ids_to_labels=ids_to_labels,
|
423
425
|
words_to_omit=words_to_omit,
|
424
426
|
min_label_lines=min_label_lines,
|
425
427
|
max_label_lines=max_label_lines,
|
@@ -434,7 +436,7 @@ class Labels:
|
|
434
436
|
self,
|
435
437
|
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
436
438
|
ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
|
437
|
-
|
439
|
+
ids_to_labels: Union[Dict[int, str], None],
|
438
440
|
words_to_omit: Union[List[str], None],
|
439
441
|
remaining_labels: int,
|
440
442
|
min_label_lines: int,
|
@@ -450,7 +452,7 @@ class Labels:
|
|
450
452
|
Args:
|
451
453
|
domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
|
452
454
|
ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
|
453
|
-
|
455
|
+
ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
454
456
|
words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
|
455
457
|
remaining_labels (int): The remaining number of labels that can be generated.
|
456
458
|
min_label_lines (int): Minimum number of lines in a label.
|
@@ -515,7 +517,7 @@ class Labels:
|
|
515
517
|
domain=domain,
|
516
518
|
domain_centroid=domain_centroid,
|
517
519
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
518
|
-
|
520
|
+
ids_to_labels=ids_to_labels,
|
519
521
|
words_to_omit=words_to_omit,
|
520
522
|
min_label_lines=min_label_lines,
|
521
523
|
max_label_lines=max_label_lines,
|
@@ -536,7 +538,7 @@ class Labels:
|
|
536
538
|
domain: str,
|
537
539
|
domain_centroid: np.ndarray,
|
538
540
|
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
539
|
-
|
541
|
+
ids_to_labels: Union[Dict[int, str], None],
|
540
542
|
words_to_omit: Union[List[str], None],
|
541
543
|
min_label_lines: int,
|
542
544
|
max_label_lines: int,
|
@@ -552,7 +554,7 @@ class Labels:
|
|
552
554
|
domain (str): Domain ID to process.
|
553
555
|
domain_centroid (np.ndarray): Centroid position of the domain.
|
554
556
|
domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
|
555
|
-
|
557
|
+
ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
556
558
|
words_to_omit (List[str], None, optional): List of words to omit from the labels. Defaults to None.
|
557
559
|
min_label_lines (int): Minimum number of lines required in a label.
|
558
560
|
max_label_lines (int): Maximum number of lines allowed in a label.
|
@@ -564,39 +566,39 @@ class Labels:
|
|
564
566
|
|
565
567
|
Returns:
|
566
568
|
bool: True if the domain is valid and added to the filtered dictionaries, False otherwise.
|
567
|
-
|
568
|
-
Note:
|
569
|
-
The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
|
570
569
|
"""
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
570
|
+
if ids_to_labels and domain in ids_to_labels:
|
571
|
+
# Directly use custom labels without filtering
|
572
|
+
domain_terms = ids_to_labels[domain]
|
573
|
+
else:
|
574
|
+
# Process the domain terms automatically
|
575
|
+
domain_terms = self._process_terms(
|
576
|
+
domain=domain,
|
577
|
+
words_to_omit=words_to_omit,
|
578
|
+
max_label_lines=max_label_lines,
|
579
|
+
min_chars_per_line=min_chars_per_line,
|
580
|
+
max_chars_per_line=max_chars_per_line,
|
581
|
+
)
|
582
|
+
# If no valid terms are generated, skip further processing
|
583
|
+
if not domain_terms:
|
584
|
+
return False
|
585
|
+
|
586
|
+
# Split the terms by TERM_DELIMITER and count the number of lines
|
587
|
+
num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
|
588
|
+
# Check if the number of lines meets the minimum requirement
|
589
|
+
if num_domain_lines < min_label_lines:
|
590
|
+
return False
|
591
|
+
|
592
|
+
# Store the valid terms and centroids
|
593
|
+
filtered_domain_centroids[domain] = domain_centroid
|
594
|
+
filtered_domain_terms[domain] = domain_terms
|
595
|
+
valid_indices.append(list(domain_id_to_centroid_map.keys()).index(domain))
|
596
|
+
|
597
|
+
return True
|
595
598
|
|
596
599
|
def _process_terms(
|
597
600
|
self,
|
598
601
|
domain: str,
|
599
|
-
ids_to_replace: Union[Dict[str, str], None],
|
600
602
|
words_to_omit: Union[List[str], None],
|
601
603
|
max_label_lines: int,
|
602
604
|
min_chars_per_line: int,
|
@@ -606,8 +608,7 @@ class Labels:
|
|
606
608
|
|
607
609
|
Args:
|
608
610
|
domain (str): The domain being processed.
|
609
|
-
|
610
|
-
words_to_omit (List, optional): List of words to omit from the labels.
|
611
|
+
words_to_omit (List[str], None): List of words to omit from the labels.
|
611
612
|
max_label_lines (int): Maximum number of lines in a label.
|
612
613
|
min_chars_per_line (int): Minimum number of characters in a line to display.
|
613
614
|
max_chars_per_line (int): Maximum number of characters in a line to display.
|
@@ -615,14 +616,8 @@ class Labels:
|
|
615
616
|
Returns:
|
616
617
|
str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
|
617
618
|
"""
|
618
|
-
#
|
619
|
-
|
620
|
-
terms = ids_to_replace[domain].replace(" ", TERM_DELIMITER)
|
621
|
-
return terms
|
622
|
-
|
623
|
-
else:
|
624
|
-
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
625
|
-
|
619
|
+
# Set custom labels from significant terms
|
620
|
+
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
626
621
|
# Apply words_to_omit and word length constraints
|
627
622
|
if words_to_omit:
|
628
623
|
terms = [
|
@@ -645,6 +640,7 @@ class Labels:
|
|
645
640
|
min_scale: float = 0.8,
|
646
641
|
max_scale: float = 1.0,
|
647
642
|
scale_factor: float = 1.0,
|
643
|
+
ids_to_colors: Union[Dict[int, Any], None] = None,
|
648
644
|
random_seed: int = 888,
|
649
645
|
) -> np.ndarray:
|
650
646
|
"""Get colors for the labels based on node annotations or a specified colormap.
|
@@ -659,8 +655,9 @@ class Labels:
|
|
659
655
|
Controls the dimmest colors. Defaults to 0.8.
|
660
656
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
661
657
|
Controls the brightest colors. Defaults to 1.0.
|
662
|
-
scale_factor (float, optional): Exponent for adjusting color scaling based on
|
658
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
|
663
659
|
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
660
|
+
ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
|
664
661
|
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
665
662
|
|
666
663
|
Returns:
|
@@ -675,6 +672,7 @@ class Labels:
|
|
675
672
|
min_scale=min_scale,
|
676
673
|
max_scale=max_scale,
|
677
674
|
scale_factor=scale_factor,
|
675
|
+
ids_to_colors=ids_to_colors,
|
678
676
|
random_seed=random_seed,
|
679
677
|
)
|
680
678
|
|
@@ -763,10 +761,7 @@ def _calculate_best_label_positions(
|
|
763
761
|
center, radius, offset, num_domains
|
764
762
|
)
|
765
763
|
# Create a mapping of domains to their initial label positions
|
766
|
-
label_positions =
|
767
|
-
domain: position
|
768
|
-
for domain, position in zip(filtered_domain_centroids.keys(), equidistant_positions)
|
769
|
-
}
|
764
|
+
label_positions = dict(zip(filtered_domain_centroids.keys(), equidistant_positions))
|
770
765
|
# Optimize the label positions to minimize distance to domain centroids
|
771
766
|
return _optimize_label_positions(label_positions, filtered_domain_centroids)
|
772
767
|
|
@@ -1,26 +1,29 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/network
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
from typing import Any, List, Tuple, Union
|
6
|
+
from typing import Any, Dict, List, Tuple, Union
|
7
7
|
|
8
8
|
import networkx as nx
|
9
9
|
import numpy as np
|
10
10
|
|
11
11
|
from risk.log import params
|
12
|
-
from risk.network.graph import
|
13
|
-
from risk.network.
|
12
|
+
from risk.network.graph.graph import Graph
|
13
|
+
from risk.network.plotter.utils.colors import get_domain_colors, to_rgba
|
14
14
|
|
15
15
|
|
16
16
|
class Network:
|
17
|
-
"""
|
17
|
+
"""A class for plotting network graphs with customizable options.
|
18
18
|
|
19
|
-
|
20
|
-
|
19
|
+
The Network class provides methods to plot network graphs with flexible node and edge properties.
|
20
|
+
"""
|
21
|
+
|
22
|
+
def __init__(self, graph: Graph, ax: Any = None) -> None:
|
23
|
+
"""Initialize the Plotter class.
|
21
24
|
|
22
25
|
Args:
|
23
|
-
graph (
|
26
|
+
graph (Graph): The network data and attributes to be visualized.
|
24
27
|
ax (Any, optional): Axes object to plot the network graph. Defaults to None.
|
25
28
|
"""
|
26
29
|
self.graph = graph
|
@@ -203,11 +206,12 @@ class Network:
|
|
203
206
|
max_scale: float = 1.0,
|
204
207
|
scale_factor: float = 1.0,
|
205
208
|
alpha: Union[float, None] = 1.0,
|
206
|
-
|
207
|
-
|
209
|
+
nonsignificant_color: Union[str, List, Tuple, np.ndarray] = "white",
|
210
|
+
nonsignificant_alpha: Union[float, None] = 1.0,
|
211
|
+
ids_to_colors: Union[Dict[int, Any], None] = None,
|
208
212
|
random_seed: int = 888,
|
209
213
|
) -> np.ndarray:
|
210
|
-
"""Adjust the colors of nodes in the network graph based on
|
214
|
+
"""Adjust the colors of nodes in the network graph based on significance.
|
211
215
|
|
212
216
|
Args:
|
213
217
|
cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
|
@@ -218,16 +222,17 @@ class Network:
|
|
218
222
|
min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
|
219
223
|
max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
|
220
224
|
scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
|
221
|
-
alpha (float, None, optional): Alpha value for
|
225
|
+
alpha (float, None, optional): Alpha value for significant nodes. If provided, it overrides any existing alpha values found in `color`.
|
222
226
|
Defaults to 1.0.
|
223
|
-
|
227
|
+
nonsignificant_color (str, List, Tuple, or np.ndarray, optional): Color for non-significant nodes. Can be a single color or an array of colors.
|
224
228
|
Defaults to "white".
|
225
|
-
|
226
|
-
in `
|
229
|
+
nonsignificant_alpha (float, None, optional): Alpha value for non-significant nodes. If provided, it overrides any existing alpha values found
|
230
|
+
in `nonsignificant_color`. Defaults to 1.0.
|
231
|
+
ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
|
227
232
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
228
233
|
|
229
234
|
Returns:
|
230
|
-
np.ndarray: Array of RGBA colors adjusted for
|
235
|
+
np.ndarray: Array of RGBA colors adjusted for significance status.
|
231
236
|
"""
|
232
237
|
# Get the initial domain colors for each node, which are returned as RGBA
|
233
238
|
network_colors = get_domain_colors(
|
@@ -239,50 +244,51 @@ class Network:
|
|
239
244
|
min_scale=min_scale,
|
240
245
|
max_scale=max_scale,
|
241
246
|
scale_factor=scale_factor,
|
247
|
+
ids_to_colors=ids_to_colors,
|
242
248
|
random_seed=random_seed,
|
243
249
|
)
|
244
|
-
# Apply the alpha value for
|
245
|
-
network_colors[:, 3] = alpha # Apply the alpha value to the
|
246
|
-
# Convert the non-
|
247
|
-
|
248
|
-
color=
|
250
|
+
# Apply the alpha value for significant nodes
|
251
|
+
network_colors[:, 3] = alpha # Apply the alpha value to the significant nodes' A channel
|
252
|
+
# Convert the non-significant color to RGBA using the to_rgba helper function
|
253
|
+
nonsignificant_color_rgba = to_rgba(
|
254
|
+
color=nonsignificant_color, alpha=nonsignificant_alpha, num_repeats=1
|
249
255
|
) # num_repeats=1 for a single color
|
250
|
-
# Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
|
256
|
+
# Adjust node colors: replace any nodes where all three RGB values are equal and less than or equal to 0.1
|
251
257
|
# 0.1 is a predefined threshold for the minimum color intensity
|
252
258
|
adjusted_network_colors = np.where(
|
253
259
|
(
|
254
|
-
np.all(network_colors[:, :3]
|
260
|
+
np.all(network_colors[:, :3] <= 0.1, axis=1)
|
255
261
|
& np.all(network_colors[:, :3] == network_colors[:, 0:1], axis=1)
|
256
262
|
)[:, None],
|
257
263
|
np.tile(
|
258
|
-
np.array(
|
259
|
-
), # Replace with the full RGBA non-
|
264
|
+
np.array(nonsignificant_color_rgba), (network_colors.shape[0], 1)
|
265
|
+
), # Replace with the full RGBA non-significant color
|
260
266
|
network_colors, # Keep the original colors where no match is found
|
261
267
|
)
|
262
268
|
return adjusted_network_colors
|
263
269
|
|
264
270
|
def get_annotated_node_sizes(
|
265
|
-
self,
|
271
|
+
self, significant_size: int = 50, nonsignificant_size: int = 25
|
266
272
|
) -> np.ndarray:
|
267
|
-
"""Adjust the sizes of nodes in the network graph based on whether they are
|
273
|
+
"""Adjust the sizes of nodes in the network graph based on whether they are significant or not.
|
268
274
|
|
269
275
|
Args:
|
270
|
-
|
271
|
-
|
276
|
+
significant_size (int): Size for significant nodes. Defaults to 50.
|
277
|
+
nonsignificant_size (int): Size for non-significant nodes. Defaults to 25.
|
272
278
|
|
273
279
|
Returns:
|
274
|
-
np.ndarray: Array of node sizes, with
|
280
|
+
np.ndarray: Array of node sizes, with significant nodes larger than non-significant ones.
|
275
281
|
"""
|
276
|
-
# Merge all
|
277
|
-
|
282
|
+
# Merge all significant nodes from the domain_id_to_node_ids_map dictionary
|
283
|
+
significant_nodes = set()
|
278
284
|
for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
279
|
-
|
285
|
+
significant_nodes.update(node_ids)
|
280
286
|
|
281
|
-
# Initialize all node sizes to the non-
|
282
|
-
node_sizes = np.full(len(self.graph.network.nodes),
|
283
|
-
# Set the size for
|
284
|
-
for node in
|
287
|
+
# Initialize all node sizes to the non-significant size
|
288
|
+
node_sizes = np.full(len(self.graph.network.nodes), nonsignificant_size)
|
289
|
+
# Set the size for significant nodes
|
290
|
+
for node in significant_nodes:
|
285
291
|
if node in self.graph.network.nodes:
|
286
|
-
node_sizes[node] =
|
292
|
+
node_sizes[node] = significant_size
|
287
293
|
|
288
294
|
return node_sizes
|