risk-network 0.0.8b26__py3-none-any.whl → 0.0.9b26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -2
- risk/annotations/__init__.py +2 -2
- risk/annotations/annotations.py +74 -47
- risk/annotations/io.py +47 -31
- risk/log/__init__.py +4 -2
- risk/log/{config.py → console.py} +5 -3
- risk/log/{params.py → parameters.py} +17 -42
- risk/neighborhoods/__init__.py +3 -5
- risk/neighborhoods/api.py +446 -0
- risk/neighborhoods/community.py +255 -77
- risk/neighborhoods/domains.py +62 -31
- risk/neighborhoods/neighborhoods.py +156 -160
- risk/network/__init__.py +1 -3
- risk/network/geometry.py +65 -57
- risk/network/graph/__init__.py +6 -0
- risk/network/graph/api.py +194 -0
- risk/network/{graph.py → graph/network.py} +87 -37
- risk/network/graph/summary.py +254 -0
- risk/network/io.py +56 -47
- risk/network/plotter/__init__.py +6 -0
- risk/network/plotter/api.py +54 -0
- risk/network/{plot → plotter}/canvas.py +7 -4
- risk/network/{plot → plotter}/contour.py +22 -19
- risk/network/{plot → plotter}/labels.py +69 -74
- risk/network/{plot → plotter}/network.py +170 -34
- risk/network/{plot/utils/color.py → plotter/utils/colors.py} +104 -112
- risk/network/{plot → plotter}/utils/layout.py +8 -5
- risk/risk.py +11 -500
- risk/stats/__init__.py +8 -4
- risk/stats/binom.py +51 -0
- risk/stats/chi2.py +69 -0
- risk/stats/hypergeom.py +27 -17
- risk/stats/permutation/__init__.py +1 -1
- risk/stats/permutation/permutation.py +44 -38
- risk/stats/permutation/test_functions.py +25 -17
- risk/stats/poisson.py +15 -9
- risk/stats/stats.py +15 -13
- risk/stats/zscore.py +68 -0
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
- risk_network-0.0.9b26.dist-info/RECORD +44 -0
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
- risk/network/plot/__init__.py +0 -6
- risk/network/plot/plotter.py +0 -137
- risk_network-0.0.8b26.dist-info/RECORD +0 -37
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ risk/network/plot/contour
|
|
3
3
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
from typing import List, Tuple, Union
|
6
|
+
from typing import Any, Dict, List, Tuple, Union
|
7
7
|
|
8
8
|
import matplotlib.pyplot as plt
|
9
9
|
import numpy as np
|
@@ -12,8 +12,8 @@ from scipy.ndimage import label
|
|
12
12
|
from scipy.stats import gaussian_kde
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
|
-
from risk.network.graph import NetworkGraph
|
16
|
-
from risk.network.
|
15
|
+
from risk.network.graph.network import NetworkGraph
|
16
|
+
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
@@ -63,6 +63,8 @@ class Contour:
|
|
63
63
|
contour_color=(
|
64
64
|
"custom" if isinstance(color, np.ndarray) else color
|
65
65
|
), # np.ndarray usually indicates custom colors
|
66
|
+
contour_linestyle=linestyle,
|
67
|
+
contour_linewidth=linewidth,
|
66
68
|
contour_alpha=alpha,
|
67
69
|
contour_fill_alpha=fill_alpha,
|
68
70
|
)
|
@@ -82,7 +84,7 @@ class Contour:
|
|
82
84
|
self.ax,
|
83
85
|
node_coordinates,
|
84
86
|
node_ids,
|
85
|
-
color=
|
87
|
+
color=color_rgba[idx],
|
86
88
|
levels=levels,
|
87
89
|
bandwidth=bandwidth,
|
88
90
|
grid_size=grid_size,
|
@@ -195,7 +197,7 @@ class Contour:
|
|
195
197
|
# Extract the positions of the specified nodes
|
196
198
|
points = np.array([pos[n] for n in nodes])
|
197
199
|
if len(points) <= 1:
|
198
|
-
return
|
200
|
+
return # Not enough points to form a contour
|
199
201
|
|
200
202
|
# Check if the KDE forms a single connected component
|
201
203
|
connected = False
|
@@ -219,12 +221,12 @@ class Contour:
|
|
219
221
|
except Exception as e:
|
220
222
|
# Catch any other exceptions and log them
|
221
223
|
logger.error(f"Unexpected error when drawing KDE contour: {e}")
|
222
|
-
return
|
224
|
+
return
|
223
225
|
|
224
226
|
# If z is still None, the KDE computation failed
|
225
227
|
if z is None:
|
226
228
|
logger.error("Failed to compute KDE. Skipping contour plot for these nodes.")
|
227
|
-
return
|
229
|
+
return
|
228
230
|
|
229
231
|
# Define contour levels based on the density
|
230
232
|
min_density, max_density = z.min(), z.max()
|
@@ -232,15 +234,15 @@ class Contour:
|
|
232
234
|
logger.warning(
|
233
235
|
"Contour levels could not be created due to lack of variation in density."
|
234
236
|
)
|
235
|
-
return
|
237
|
+
return
|
236
238
|
|
237
239
|
# Create contour levels based on the density values
|
238
240
|
contour_levels = np.linspace(min_density, max_density, levels)[1:]
|
239
241
|
if len(contour_levels) < 2 or not np.all(np.diff(contour_levels) > 0):
|
240
242
|
logger.error("Contour levels must be strictly increasing. Skipping contour plot.")
|
241
|
-
return
|
243
|
+
return
|
242
244
|
|
243
|
-
# Set the contour color and linestyle
|
245
|
+
# Set the contour color, fill, and linestyle
|
244
246
|
contour_colors = [color for _ in range(levels - 1)]
|
245
247
|
# Plot the filled contours using fill_alpha for transparency
|
246
248
|
if fill_alpha and fill_alpha > 0:
|
@@ -256,21 +258,19 @@ class Contour:
|
|
256
258
|
alpha=fill_alpha,
|
257
259
|
)
|
258
260
|
|
259
|
-
# Plot the contour
|
260
|
-
|
261
|
+
# Plot the base contour line with the specified RGBA alpha for transparency
|
262
|
+
base_contour_color = [color]
|
263
|
+
base_contour_level = [contour_levels[0]]
|
264
|
+
ax.contour(
|
261
265
|
x,
|
262
266
|
y,
|
263
267
|
z,
|
264
|
-
levels=
|
265
|
-
colors=
|
268
|
+
levels=base_contour_level,
|
269
|
+
colors=base_contour_color,
|
266
270
|
linestyles=linestyle,
|
267
271
|
linewidths=linewidth,
|
268
272
|
)
|
269
273
|
|
270
|
-
# Set linewidth for the contour lines to 0 for levels other than the base level
|
271
|
-
for i in range(1, len(contour_levels)):
|
272
|
-
c.collections[i].set_linewidth(0)
|
273
|
-
|
274
274
|
def get_annotated_contour_colors(
|
275
275
|
self,
|
276
276
|
cmap: str = "gist_rainbow",
|
@@ -280,6 +280,7 @@ class Contour:
|
|
280
280
|
min_scale: float = 0.8,
|
281
281
|
max_scale: float = 1.0,
|
282
282
|
scale_factor: float = 1.0,
|
283
|
+
ids_to_colors: Union[Dict[int, Any], None] = None,
|
283
284
|
random_seed: int = 888,
|
284
285
|
) -> np.ndarray:
|
285
286
|
"""Get colors for the contours based on node annotations or a specified colormap.
|
@@ -294,8 +295,9 @@ class Contour:
|
|
294
295
|
Controls the dimmest colors. Defaults to 0.8.
|
295
296
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
296
297
|
Controls the brightest colors. Defaults to 1.0.
|
297
|
-
scale_factor (float, optional): Exponent for adjusting color scaling based on
|
298
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
|
298
299
|
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
300
|
+
ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
|
299
301
|
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
300
302
|
|
301
303
|
Returns:
|
@@ -310,6 +312,7 @@ class Contour:
|
|
310
312
|
min_scale=min_scale,
|
311
313
|
max_scale=max_scale,
|
312
314
|
scale_factor=scale_factor,
|
315
|
+
ids_to_colors=ids_to_colors,
|
313
316
|
random_seed=random_seed,
|
314
317
|
)
|
315
318
|
|
@@ -11,9 +11,9 @@ import numpy as np
|
|
11
11
|
import pandas as pd
|
12
12
|
|
13
13
|
from risk.log import params
|
14
|
-
from risk.network.graph import NetworkGraph
|
15
|
-
from risk.network.
|
16
|
-
from risk.network.
|
14
|
+
from risk.network.graph.network import NetworkGraph
|
15
|
+
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
16
|
+
from risk.network.plotter.utils.layout import calculate_bounding_box
|
17
17
|
|
18
18
|
TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
|
19
19
|
|
@@ -54,7 +54,7 @@ class Labels:
|
|
54
54
|
words_to_omit: Union[List, None] = None,
|
55
55
|
overlay_ids: bool = False,
|
56
56
|
ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
|
57
|
-
|
57
|
+
ids_to_labels: Union[Dict[int, str], None] = None,
|
58
58
|
) -> None:
|
59
59
|
"""Annotate the network graph with labels for different domains, positioned around the network for clarity.
|
60
60
|
|
@@ -62,7 +62,7 @@ class Labels:
|
|
62
62
|
scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
|
63
63
|
offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
|
64
64
|
font (str, optional): Font name for the labels. Defaults to "Arial".
|
65
|
-
fontcase (
|
65
|
+
fontcase (str, Dict[str, str], or None, optional): Defines how to transform the case of words.
|
66
66
|
- If a string (e.g., 'upper', 'lower', 'title'), applies the transformation to all words.
|
67
67
|
- If a dictionary, maps specific cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
|
68
68
|
- If None, no transformation is applied.
|
@@ -87,7 +87,7 @@ class Labels:
|
|
87
87
|
overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
|
88
88
|
ids_to_keep (List, Tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
|
89
89
|
you can set `overlay_ids=True`. Defaults to None.
|
90
|
-
|
90
|
+
ids_to_labels (Dict[int, str], optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
|
91
91
|
space-separated words. If provided, the custom labels will replace the default domain terms. To discover domain IDs, you
|
92
92
|
can set `overlay_ids=True`. Defaults to None.
|
93
93
|
|
@@ -119,7 +119,7 @@ class Labels:
|
|
119
119
|
label_words_to_omit=words_to_omit,
|
120
120
|
label_overlay_ids=overlay_ids,
|
121
121
|
label_ids_to_keep=ids_to_keep,
|
122
|
-
|
122
|
+
label_ids_to_labels=ids_to_labels,
|
123
123
|
)
|
124
124
|
|
125
125
|
# Convert ids_to_keep to a tuple if it is not None
|
@@ -152,7 +152,7 @@ class Labels:
|
|
152
152
|
self._process_ids_to_keep(
|
153
153
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
154
154
|
ids_to_keep=ids_to_keep,
|
155
|
-
|
155
|
+
ids_to_labels=ids_to_labels,
|
156
156
|
words_to_omit=words_to_omit,
|
157
157
|
max_labels=max_labels,
|
158
158
|
min_label_lines=min_label_lines,
|
@@ -173,7 +173,7 @@ class Labels:
|
|
173
173
|
self._process_remaining_domains(
|
174
174
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
175
175
|
ids_to_keep=ids_to_keep,
|
176
|
-
|
176
|
+
ids_to_labels=ids_to_labels,
|
177
177
|
words_to_omit=words_to_omit,
|
178
178
|
remaining_labels=remaining_labels,
|
179
179
|
min_chars_per_line=min_chars_per_line,
|
@@ -218,13 +218,14 @@ class Labels:
|
|
218
218
|
fontsize=fontsize,
|
219
219
|
fontname=font,
|
220
220
|
color=fontcolor_rgba[idx],
|
221
|
-
arrowprops=
|
222
|
-
arrowstyle
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
221
|
+
arrowprops={
|
222
|
+
"arrowstyle": arrow_style,
|
223
|
+
"linewidth": arrow_linewidth,
|
224
|
+
"color": arrow_color_rgba[idx],
|
225
|
+
"alpha": arrow_alpha,
|
226
|
+
"shrinkA": arrow_base_shrink,
|
227
|
+
"shrinkB": arrow_tip_shrink,
|
228
|
+
},
|
228
229
|
)
|
229
230
|
|
230
231
|
# Overlay domain ID at the centroid regardless of max_labels if requested
|
@@ -334,13 +335,14 @@ class Labels:
|
|
334
335
|
fontsize=fontsize,
|
335
336
|
fontname=font,
|
336
337
|
color=fontcolor_rgba[idx],
|
337
|
-
arrowprops=
|
338
|
-
arrowstyle
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
338
|
+
arrowprops={
|
339
|
+
"arrowstyle": arrow_style,
|
340
|
+
"linewidth": arrow_linewidth,
|
341
|
+
"color": arrow_color_rgba[idx],
|
342
|
+
"alpha": arrow_alpha,
|
343
|
+
"shrinkA": arrow_base_shrink,
|
344
|
+
"shrinkB": arrow_tip_shrink,
|
345
|
+
},
|
344
346
|
)
|
345
347
|
|
346
348
|
def _calculate_domain_centroid(self, nodes: List) -> tuple:
|
@@ -368,7 +370,7 @@ class Labels:
|
|
368
370
|
self,
|
369
371
|
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
370
372
|
ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
|
371
|
-
|
373
|
+
ids_to_labels: Union[Dict[int, str], None],
|
372
374
|
words_to_omit: Union[List[str], None],
|
373
375
|
max_labels: Union[int, None],
|
374
376
|
min_label_lines: int,
|
@@ -384,7 +386,7 @@ class Labels:
|
|
384
386
|
Args:
|
385
387
|
domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
|
386
388
|
ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
|
387
|
-
|
389
|
+
ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
388
390
|
words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
|
389
391
|
max_labels (int, optional): Maximum number of labels allowed.
|
390
392
|
min_label_lines (int): Minimum number of lines in a label.
|
@@ -419,7 +421,7 @@ class Labels:
|
|
419
421
|
domain=domain,
|
420
422
|
domain_centroid=domain_centroid,
|
421
423
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
422
|
-
|
424
|
+
ids_to_labels=ids_to_labels,
|
423
425
|
words_to_omit=words_to_omit,
|
424
426
|
min_label_lines=min_label_lines,
|
425
427
|
max_label_lines=max_label_lines,
|
@@ -434,7 +436,7 @@ class Labels:
|
|
434
436
|
self,
|
435
437
|
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
436
438
|
ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
|
437
|
-
|
439
|
+
ids_to_labels: Union[Dict[int, str], None],
|
438
440
|
words_to_omit: Union[List[str], None],
|
439
441
|
remaining_labels: int,
|
440
442
|
min_label_lines: int,
|
@@ -450,7 +452,7 @@ class Labels:
|
|
450
452
|
Args:
|
451
453
|
domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
|
452
454
|
ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
|
453
|
-
|
455
|
+
ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
454
456
|
words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
|
455
457
|
remaining_labels (int): The remaining number of labels that can be generated.
|
456
458
|
min_label_lines (int): Minimum number of lines in a label.
|
@@ -515,7 +517,7 @@ class Labels:
|
|
515
517
|
domain=domain,
|
516
518
|
domain_centroid=domain_centroid,
|
517
519
|
domain_id_to_centroid_map=domain_id_to_centroid_map,
|
518
|
-
|
520
|
+
ids_to_labels=ids_to_labels,
|
519
521
|
words_to_omit=words_to_omit,
|
520
522
|
min_label_lines=min_label_lines,
|
521
523
|
max_label_lines=max_label_lines,
|
@@ -536,7 +538,7 @@ class Labels:
|
|
536
538
|
domain: str,
|
537
539
|
domain_centroid: np.ndarray,
|
538
540
|
domain_id_to_centroid_map: Dict[str, np.ndarray],
|
539
|
-
|
541
|
+
ids_to_labels: Union[Dict[int, str], None],
|
540
542
|
words_to_omit: Union[List[str], None],
|
541
543
|
min_label_lines: int,
|
542
544
|
max_label_lines: int,
|
@@ -552,7 +554,7 @@ class Labels:
|
|
552
554
|
domain (str): Domain ID to process.
|
553
555
|
domain_centroid (np.ndarray): Centroid position of the domain.
|
554
556
|
domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
|
555
|
-
|
557
|
+
ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
|
556
558
|
words_to_omit (List[str], None, optional): List of words to omit from the labels. Defaults to None.
|
557
559
|
min_label_lines (int): Minimum number of lines required in a label.
|
558
560
|
max_label_lines (int): Maximum number of lines allowed in a label.
|
@@ -564,39 +566,39 @@ class Labels:
|
|
564
566
|
|
565
567
|
Returns:
|
566
568
|
bool: True if the domain is valid and added to the filtered dictionaries, False otherwise.
|
567
|
-
|
568
|
-
Note:
|
569
|
-
The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
|
570
569
|
"""
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
570
|
+
if ids_to_labels and domain in ids_to_labels:
|
571
|
+
# Directly use custom labels without filtering
|
572
|
+
domain_terms = ids_to_labels[domain]
|
573
|
+
else:
|
574
|
+
# Process the domain terms automatically
|
575
|
+
domain_terms = self._process_terms(
|
576
|
+
domain=domain,
|
577
|
+
words_to_omit=words_to_omit,
|
578
|
+
max_label_lines=max_label_lines,
|
579
|
+
min_chars_per_line=min_chars_per_line,
|
580
|
+
max_chars_per_line=max_chars_per_line,
|
581
|
+
)
|
582
|
+
# If no valid terms are generated, skip further processing
|
583
|
+
if not domain_terms:
|
584
|
+
return False
|
585
|
+
|
586
|
+
# Split the terms by TERM_DELIMITER and count the number of lines
|
587
|
+
num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
|
588
|
+
# Check if the number of lines meets the minimum requirement
|
589
|
+
if num_domain_lines < min_label_lines:
|
590
|
+
return False
|
591
|
+
|
592
|
+
# Store the valid terms and centroids
|
593
|
+
filtered_domain_centroids[domain] = domain_centroid
|
594
|
+
filtered_domain_terms[domain] = domain_terms
|
595
|
+
valid_indices.append(list(domain_id_to_centroid_map.keys()).index(domain))
|
596
|
+
|
597
|
+
return True
|
595
598
|
|
596
599
|
def _process_terms(
|
597
600
|
self,
|
598
601
|
domain: str,
|
599
|
-
ids_to_replace: Union[Dict[str, str], None],
|
600
602
|
words_to_omit: Union[List[str], None],
|
601
603
|
max_label_lines: int,
|
602
604
|
min_chars_per_line: int,
|
@@ -606,8 +608,7 @@ class Labels:
|
|
606
608
|
|
607
609
|
Args:
|
608
610
|
domain (str): The domain being processed.
|
609
|
-
|
610
|
-
words_to_omit (List, optional): List of words to omit from the labels.
|
611
|
+
words_to_omit (List[str], None): List of words to omit from the labels.
|
611
612
|
max_label_lines (int): Maximum number of lines in a label.
|
612
613
|
min_chars_per_line (int): Minimum number of characters in a line to display.
|
613
614
|
max_chars_per_line (int): Maximum number of characters in a line to display.
|
@@ -615,14 +616,8 @@ class Labels:
|
|
615
616
|
Returns:
|
616
617
|
str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
|
617
618
|
"""
|
618
|
-
#
|
619
|
-
|
620
|
-
terms = ids_to_replace[domain].replace(" ", TERM_DELIMITER)
|
621
|
-
return terms
|
622
|
-
|
623
|
-
else:
|
624
|
-
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
625
|
-
|
619
|
+
# Set custom labels from significant terms
|
620
|
+
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
626
621
|
# Apply words_to_omit and word length constraints
|
627
622
|
if words_to_omit:
|
628
623
|
terms = [
|
@@ -645,6 +640,7 @@ class Labels:
|
|
645
640
|
min_scale: float = 0.8,
|
646
641
|
max_scale: float = 1.0,
|
647
642
|
scale_factor: float = 1.0,
|
643
|
+
ids_to_colors: Union[Dict[int, Any], None] = None,
|
648
644
|
random_seed: int = 888,
|
649
645
|
) -> np.ndarray:
|
650
646
|
"""Get colors for the labels based on node annotations or a specified colormap.
|
@@ -659,8 +655,9 @@ class Labels:
|
|
659
655
|
Controls the dimmest colors. Defaults to 0.8.
|
660
656
|
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
661
657
|
Controls the brightest colors. Defaults to 1.0.
|
662
|
-
scale_factor (float, optional): Exponent for adjusting color scaling based on
|
658
|
+
scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
|
663
659
|
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
660
|
+
ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
|
664
661
|
random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
|
665
662
|
|
666
663
|
Returns:
|
@@ -675,6 +672,7 @@ class Labels:
|
|
675
672
|
min_scale=min_scale,
|
676
673
|
max_scale=max_scale,
|
677
674
|
scale_factor=scale_factor,
|
675
|
+
ids_to_colors=ids_to_colors,
|
678
676
|
random_seed=random_seed,
|
679
677
|
)
|
680
678
|
|
@@ -763,10 +761,7 @@ def _calculate_best_label_positions(
|
|
763
761
|
center, radius, offset, num_domains
|
764
762
|
)
|
765
763
|
# Create a mapping of domains to their initial label positions
|
766
|
-
label_positions =
|
767
|
-
domain: position
|
768
|
-
for domain, position in zip(filtered_domain_centroids.keys(), equidistant_positions)
|
769
|
-
}
|
764
|
+
label_positions = dict(zip(filtered_domain_centroids.keys(), equidistant_positions))
|
770
765
|
# Optimize the label positions to minimize distance to domain centroids
|
771
766
|
return _optimize_label_positions(label_positions, filtered_domain_centroids)
|
772
767
|
|