risk-network 0.0.12b0__py3-none-any.whl → 0.0.12b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. risk/__init__.py +1 -1
  2. risk/annotations/__init__.py +10 -0
  3. risk/annotations/annotations.py +354 -0
  4. risk/annotations/io.py +241 -0
  5. risk/annotations/nltk_setup.py +86 -0
  6. risk/log/__init__.py +11 -0
  7. risk/log/console.py +141 -0
  8. risk/log/parameters.py +171 -0
  9. risk/neighborhoods/__init__.py +7 -0
  10. risk/neighborhoods/api.py +442 -0
  11. risk/neighborhoods/community.py +441 -0
  12. risk/neighborhoods/domains.py +360 -0
  13. risk/neighborhoods/neighborhoods.py +514 -0
  14. risk/neighborhoods/stats/__init__.py +13 -0
  15. risk/neighborhoods/stats/permutation/__init__.py +6 -0
  16. risk/neighborhoods/stats/permutation/permutation.py +240 -0
  17. risk/neighborhoods/stats/permutation/test_functions.py +70 -0
  18. risk/neighborhoods/stats/tests.py +275 -0
  19. risk/network/__init__.py +4 -0
  20. risk/network/graph/__init__.py +4 -0
  21. risk/network/graph/api.py +200 -0
  22. risk/network/graph/graph.py +274 -0
  23. risk/network/graph/stats.py +166 -0
  24. risk/network/graph/summary.py +253 -0
  25. risk/network/io.py +693 -0
  26. risk/network/plotter/__init__.py +4 -0
  27. risk/network/plotter/api.py +54 -0
  28. risk/network/plotter/canvas.py +291 -0
  29. risk/network/plotter/contour.py +329 -0
  30. risk/network/plotter/labels.py +935 -0
  31. risk/network/plotter/network.py +294 -0
  32. risk/network/plotter/plotter.py +141 -0
  33. risk/network/plotter/utils/colors.py +419 -0
  34. risk/network/plotter/utils/layout.py +94 -0
  35. risk_network-0.0.12b2.dist-info/METADATA +122 -0
  36. risk_network-0.0.12b2.dist-info/RECORD +40 -0
  37. {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b2.dist-info}/WHEEL +1 -1
  38. risk_network-0.0.12b0.dist-info/METADATA +0 -796
  39. risk_network-0.0.12b0.dist-info/RECORD +0 -7
  40. {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b2.dist-info}/licenses/LICENSE +0 -0
  41. {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,935 @@
1
+ """
2
+ risk/network/plotter/labels
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import copy
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from risk.log import params
14
+ from risk.network.graph.graph import Graph
15
+ from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plotter.utils.layout import calculate_bounding_box
17
+
18
+ TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
19
+
20
+
21
+ class Labels:
22
+ """Class to handle the annotation of network graphs with labels for different domains."""
23
+
24
def __init__(self, graph: Graph, ax: plt.Axes):
    """Bind a Labels instance to a network graph and the matplotlib axes it draws on.

    Args:
        graph (Graph): Graph object containing the network data.
        ax (plt.Axes): Matplotlib axes object to plot the labels on.
    """
    # Both objects are stored by reference; nothing is copied or validated here
    self.graph, self.ax = graph, ax
33
+
34
def plot_labels(
    self,
    scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontcase: Union[str, Dict[str, str], None] = None,
    fontsize: int = 10,
    fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
    fontalpha: Union[float, None] = 1.0,
    arrow_linewidth: float = 1,
    arrow_style: str = "->",
    arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
    arrow_alpha: Union[float, None] = 1.0,
    arrow_base_shrink: float = 0.0,
    arrow_tip_shrink: float = 0.0,
    max_labels: Union[int, None] = None,
    max_label_lines: Union[int, None] = None,
    min_label_lines: int = 1,
    max_chars_per_line: Union[int, None] = None,
    min_chars_per_line: int = 1,
    words_to_omit: Union[List, None] = None,
    overlay_ids: bool = False,
    ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
    ids_to_labels: Union[Dict[int, str], None] = None,
) -> None:
    """Annotate the network graph with labels for different domains, positioned around the network for clarity.

    Args:
        scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
        offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
        font (str, optional): Font name for the labels. Defaults to "Arial".
        fontcase (str, Dict[str, str], or None, optional): Defines how to transform the case of words.
            - If a string (e.g., 'upper', 'lower', 'title'), applies the transformation to all words.
            - If a dictionary, maps specific cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
            - If None, no transformation is applied.
        fontsize (int, optional): Font size for the labels. Defaults to 10.
        fontcolor (str, List, Tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array.
            Defaults to "black".
        fontalpha (float, None, optional): Transparency level for the font color. If provided, it overrides any existing alpha
            values found in fontcolor. Defaults to 1.0.
        arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
        arrow_style (str, optional): Style of the arrows pointing to centroids. Defaults to "->".
        arrow_color (str, List, Tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
        arrow_alpha (float, None, optional): Transparency level for the arrow color. If provided, it overrides any existing alpha
            values found in arrow_color. Defaults to 1.0.
        arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
        arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
        max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
        max_label_lines (int, optional): Maximum number of lines in a label. Defaults to None (no limit).
        min_label_lines (int, optional): Minimum number of lines in a label. Defaults to 1.
        max_chars_per_line (int, optional): Maximum number of characters in a line to display. Defaults to None (no limit).
        min_chars_per_line (int, optional): Minimum number of characters in a line to display. Defaults to 1.
        words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
        overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
        ids_to_keep (List, Tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
            you can set `overlay_ids=True`. Defaults to None.
        ids_to_labels (Dict[int, str], optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
            space-separated words. If provided, the custom labels will replace the default domain terms. To discover domain IDs, you
            can set `overlay_ids=True`. Defaults to None.

    Raises:
        ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
    """
    # Log the plotting parameters
    params.log_plotter(
        label_perimeter_scale=scale,
        label_offset=offset,
        label_font=font,
        label_fontcase=fontcase,
        label_fontsize=fontsize,
        label_fontcolor=(
            "custom" if isinstance(fontcolor, np.ndarray) else fontcolor
        ),  # np.ndarray usually indicates custom colors
        label_fontalpha=fontalpha,
        label_arrow_linewidth=arrow_linewidth,
        label_arrow_style=arrow_style,
        label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
        label_arrow_alpha=arrow_alpha,
        label_arrow_base_shrink=arrow_base_shrink,
        label_arrow_tip_shrink=arrow_tip_shrink,
        label_max_labels=max_labels,
        label_min_label_lines=min_label_lines,
        label_max_label_lines=max_label_lines,
        label_max_chars_per_line=max_chars_per_line,
        label_min_chars_per_line=min_chars_per_line,
        label_words_to_omit=words_to_omit,
        label_overlay_ids=overlay_ids,
        label_ids_to_keep=ids_to_keep,
        label_ids_to_labels=ids_to_labels,
    )

    # Normalize ids_to_keep to a tuple. The None check must be explicit: truth-testing a
    # multi-element np.ndarray (an accepted input type) raises ValueError.
    ids_to_keep = tuple(ids_to_keep) if ids_to_keep is not None else ()
    num_domains = len(self.graph.domain_id_to_node_ids_map)
    # Set max_labels to the total number of domains if not provided (None)
    if max_labels is None:
        max_labels = num_domains
    # Set max_label_lines and max_chars_per_line to large numbers if not provided (None)
    if max_label_lines is None:
        max_label_lines = int(1e6)
    if max_chars_per_line is None:
        max_chars_per_line = int(1e6)
    # Normalize words_to_omit to a lowercase set for case-insensitive membership tests
    if words_to_omit:
        words_to_omit = {word.lower() for word in words_to_omit}

    # Calculate the centroid of every non-empty domain; empty domains get no entry
    domain_id_to_centroid_map = {
        domain_id: self._calculate_domain_centroid(node_ids)
        for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items()
        if node_ids
    }

    # Output containers filled in-place by the two processing helpers below
    valid_indices = []  # Indices used to pick the font/arrow color of each plotted label
    filtered_domain_centroids = {}  # Filtered domain centroids to plot
    filtered_domain_terms = {}  # Filtered domain terms to plot
    # Domains the caller explicitly asked for are processed first
    if ids_to_keep:
        self._process_ids_to_keep(
            domain_id_to_centroid_map=domain_id_to_centroid_map,
            ids_to_keep=ids_to_keep,
            ids_to_labels=ids_to_labels,
            words_to_omit=words_to_omit,
            max_labels=max_labels,
            min_label_lines=min_label_lines,
            max_label_lines=max_label_lines,
            min_chars_per_line=min_chars_per_line,
            max_chars_per_line=max_chars_per_line,
            filtered_domain_centroids=filtered_domain_centroids,
            filtered_domain_terms=filtered_domain_terms,
            valid_indices=valid_indices,
        )

    # Slots left after the forced labels; fill them with other domains if any remain
    remaining_labels = max_labels - len(valid_indices)
    if remaining_labels > 0:
        self._process_remaining_domains(
            domain_id_to_centroid_map=domain_id_to_centroid_map,
            ids_to_keep=ids_to_keep,
            ids_to_labels=ids_to_labels,
            words_to_omit=words_to_omit,
            remaining_labels=remaining_labels,
            min_chars_per_line=min_chars_per_line,
            max_chars_per_line=max_chars_per_line,
            max_label_lines=max_label_lines,
            min_label_lines=min_label_lines,
            filtered_domain_centroids=filtered_domain_centroids,
            filtered_domain_terms=filtered_domain_terms,
            valid_indices=valid_indices,
        )

    # Calculate the bounding box around the network
    center, radius = calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
    # Calculate the best positions for labels
    best_label_positions = self._calculate_best_label_positions(
        filtered_domain_centroids, center, radius, offset
    )
    # Convert the font and arrow colors to RGBA arrays with one entry per domain
    fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=num_domains)
    arrow_color_rgba = to_rgba(color=arrow_color, alpha=arrow_alpha, num_repeats=num_domains)

    # Annotate the network with labels
    # NOTE(review): valid_indices are positions within domain_id_to_centroid_map, while the
    # color arrays are sized by domain_id_to_node_ids_map; the two orderings only coincide
    # when no domain is empty - confirm against the color helpers.
    for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
        centroid = filtered_domain_centroids[domain]
        # Split by special key TERM_DELIMITER to split annotation into multiple lines
        annotations = filtered_domain_terms[domain].split(TERM_DELIMITER)
        if fontcase is not None:
            annotations = self._apply_str_transformation(
                words=annotations, transformation=fontcase
            )
        self.ax.annotate(
            "\n".join(annotations),
            xy=centroid,
            xytext=pos,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=fontcolor_rgba[idx],
            arrowprops={
                "arrowstyle": arrow_style,
                "linewidth": arrow_linewidth,
                "color": arrow_color_rgba[idx],
                "alpha": arrow_alpha,
                "shrinkA": arrow_base_shrink,
                "shrinkB": arrow_tip_shrink,
            },
        )

    # Overlay domain ID at the centroid regardless of max_labels if requested
    if overlay_ids:
        for idx, domain in enumerate(self.graph.domain_id_to_node_ids_map):
            centroid = domain_id_to_centroid_map.get(domain)
            if centroid is None:
                # Empty domains never received a centroid, so there is nowhere to draw the ID
                continue
            self.ax.text(
                centroid[0],
                centroid[1],
                domain,
                ha="center",
                va="center",
                fontsize=fontsize,
                fontname=font,
                color=fontcolor_rgba[idx],
            )
247
+
248
def plot_sublabel(
    self,
    nodes: Union[List, Tuple, np.ndarray],
    label: str,
    radial_position: float = 0.0,
    scale: float = 1.05,
    offset: float = 0.10,
    font: str = "Arial",
    fontsize: int = 10,
    fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
    fontalpha: Union[float, None] = 1.0,
    arrow_linewidth: float = 1,
    arrow_style: str = "->",
    arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
    arrow_alpha: Union[float, None] = 1.0,
    arrow_base_shrink: float = 0.0,
    arrow_tip_shrink: float = 0.0,
) -> None:
    """Place a single label on the network with one arrow per node group, each arrow ending at that group's centroid.

    Args:
        nodes (List, Tuple, or np.ndarray): List of node labels or list of lists of node labels.
        label (str): The label to be annotated on the network.
        radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
        scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
        offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
        font (str, optional): Font name for the label. Defaults to "Arial".
        fontsize (int, optional): Font size for the label. Defaults to 10.
        fontcolor (str, List, Tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
        fontalpha (float, None, optional): Transparency level for the font color. If provided, it overrides any existing alpha values
            found in fontcolor. Defaults to 1.0.
        arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
        arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
        arrow_color (str, List, Tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
        arrow_alpha (float, None, optional): Transparency level for the arrow color. If provided, it overrides any existing alpha
            values found in arrow_color. Defaults to 1.0.
        arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
        arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.

    Raises:
        ValueError: If a group resolves to fewer than two nodes present in the network graph.
    """
    # Any nested sequence means the caller passed several groups; otherwise it is one flat group
    has_subgroups = any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes)
    if has_subgroups:
        node_groups = nodes
        # One RGBA entry per group for both the text and the arrows
        fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=len(node_groups))
        arrow_color_rgba = to_rgba(
            color=arrow_color, alpha=arrow_alpha, num_repeats=len(node_groups)
        )
    else:
        node_groups = [nodes]
        # Wrap the converted colors in arrays so they can be indexed like the multi-group case
        fontcolor_rgba = np.array(to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=1))
        arrow_color_rgba = np.array(
            to_rgba(color=arrow_color, alpha=arrow_alpha, num_repeats=1)
        )

    # Bounding box of the whole network, used as the circle the label sits on
    center, radius = calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
    # Degrees to radians, shifted by 90 degrees as in the original layout convention
    theta = np.deg2rad(radial_position - 90)
    label_position = (
        center[0] + (radius + offset) * np.cos(theta),
        center[1] + (radius + offset) * np.sin(theta),
    )

    for group_idx, group in enumerate(node_groups):
        # Translate node labels to IDs, silently dropping labels the graph does not know
        label_to_id = self.graph.node_label_to_node_id_map
        node_ids = [label_to_id.get(node) for node in group if node in label_to_id]
        if len(node_ids) < 2:
            raise ValueError(
                "No nodes found in the network graph or insufficient nodes to plot."
            )

        # Arrow target for this group
        group_centroid = self._calculate_domain_centroid(node_ids)
        # Every group shares the same text and text position; only the arrow target differs
        self.ax.annotate(
            label,
            xy=group_centroid,
            xytext=label_position,
            textcoords="data",
            ha="center",
            va="center",
            fontsize=fontsize,
            fontname=font,
            color=fontcolor_rgba[group_idx],
            arrowprops={
                "arrowstyle": arrow_style,
                "linewidth": arrow_linewidth,
                "color": arrow_color_rgba[group_idx],
                "alpha": arrow_alpha,
                "shrinkA": arrow_base_shrink,
                "shrinkB": arrow_tip_shrink,
            },
        )
352
+
353
+ def _calculate_domain_centroid(self, nodes: List) -> tuple:
354
+ """Calculate the most centrally located node in .
355
+
356
+ Args:
357
+ nodes (List): List of node labels to include in the subnetwork.
358
+
359
+ Returns:
360
+ tuple: A tuple containing the domain's central node coordinates.
361
+ """
362
+ # Extract positions of all nodes in the domain
363
+ node_positions = self.graph.node_coordinates[nodes, :]
364
+ # Calculate the pairwise distance matrix between all nodes in the domain
365
+ distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
366
+ # Sum the distances for each node to all other nodes in the domain
367
+ sum_distances = np.sum(distances_matrix, axis=1)
368
+ # Identify the node with the smallest total distance to others (the centroid)
369
+ central_node_idx = np.argmin(sum_distances)
370
+ # Map the domain to the coordinates of its central node
371
+ domain_central_node = node_positions[central_node_idx]
372
+ return domain_central_node
373
+
374
+ def _process_ids_to_keep(
375
+ self,
376
+ domain_id_to_centroid_map: Dict[str, np.ndarray],
377
+ ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
378
+ ids_to_labels: Union[Dict[int, str], None],
379
+ words_to_omit: Union[List[str], None],
380
+ max_labels: Union[int, None],
381
+ min_label_lines: int,
382
+ max_label_lines: int,
383
+ min_chars_per_line: int,
384
+ max_chars_per_line: int,
385
+ filtered_domain_centroids: Dict[str, np.ndarray],
386
+ filtered_domain_terms: Dict[str, str],
387
+ valid_indices: List[int],
388
+ ) -> None:
389
+ """Process the ids_to_keep, apply filtering, and store valid domain centroids and terms.
390
+
391
+ Args:
392
+ domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
393
+ ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
394
+ ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
395
+ words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
396
+ max_labels (int, optional): Maximum number of labels allowed.
397
+ min_label_lines (int): Minimum number of lines in a label.
398
+ max_label_lines (int): Maximum number of lines in a label.
399
+ min_chars_per_line (int): Minimum number of characters in a line to display.
400
+ max_chars_per_line (int): Maximum number of characters in a line to display.
401
+ filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store filtered domain centroids (output).
402
+ filtered_domain_terms (Dict[str, str]): Dictionary to store filtered domain terms (output).
403
+ valid_indices (List): List to store valid indices (output).
404
+
405
+ Note:
406
+ The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
407
+
408
+ Raises:
409
+ ValueError: If the number of provided `ids_to_keep` exceeds `max_labels`.
410
+ """
411
+ # Check if the number of provided ids_to_keep exceeds max_labels
412
+ if max_labels is not None and len(ids_to_keep) > max_labels:
413
+ raise ValueError(
414
+ f"Number of provided IDs ({len(ids_to_keep)}) exceeds max_labels ({max_labels})."
415
+ )
416
+
417
+ # Process each domain in ids_to_keep
418
+ for domain in ids_to_keep:
419
+ if (
420
+ domain in self.graph.domain_id_to_domain_terms_map
421
+ and domain in domain_id_to_centroid_map
422
+ ):
423
+ domain_centroid = domain_id_to_centroid_map[domain]
424
+ # No need to filter the domain terms if it is in ids_to_keep
425
+ _ = self._validate_and_update_domain(
426
+ domain=domain,
427
+ domain_centroid=domain_centroid,
428
+ domain_id_to_centroid_map=domain_id_to_centroid_map,
429
+ ids_to_labels=ids_to_labels,
430
+ words_to_omit=words_to_omit,
431
+ min_label_lines=min_label_lines,
432
+ max_label_lines=max_label_lines,
433
+ min_chars_per_line=min_chars_per_line,
434
+ max_chars_per_line=max_chars_per_line,
435
+ filtered_domain_centroids=filtered_domain_centroids,
436
+ filtered_domain_terms=filtered_domain_terms,
437
+ valid_indices=valid_indices,
438
+ )
439
+
440
+ def _process_remaining_domains(
441
+ self,
442
+ domain_id_to_centroid_map: Dict[str, np.ndarray],
443
+ ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
444
+ ids_to_labels: Union[Dict[int, str], None],
445
+ words_to_omit: Union[List[str], None],
446
+ remaining_labels: int,
447
+ min_label_lines: int,
448
+ max_label_lines: int,
449
+ min_chars_per_line: int,
450
+ max_chars_per_line: int,
451
+ filtered_domain_centroids: Dict[str, np.ndarray],
452
+ filtered_domain_terms: Dict[str, str],
453
+ valid_indices: List[int],
454
+ ) -> None:
455
+ """Process remaining domains to fill in additional labels, respecting the remaining_labels limit.
456
+
457
+ Args:
458
+ domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
459
+ ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
460
+ ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
461
+ words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
462
+ remaining_labels (int): The remaining number of labels that can be generated.
463
+ min_label_lines (int): Minimum number of lines in a label.
464
+ max_label_lines (int): Maximum number of lines in a label.
465
+ min_chars_per_line (int): Minimum number of characters in a line to display.
466
+ max_chars_per_line (int): Maximum number of characters in a line to display.
467
+ filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store filtered domain centroids (output).
468
+ filtered_domain_terms (Dict[str, str]): Dictionary to store filtered domain terms (output).
469
+ valid_indices (List): List to store valid indices (output).
470
+
471
+ Note:
472
+ The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
473
+ """
474
+ # Counter to track how many labels have been created
475
+ label_count = 0
476
+ # Collect domains not in ids_to_keep
477
+ remaining_domains = {
478
+ domain: centroid
479
+ for domain, centroid in domain_id_to_centroid_map.items()
480
+ if domain not in ids_to_keep and not pd.isna(domain)
481
+ }
482
+
483
+ # Function to calculate distance between two centroids
484
+ def calculate_distance(centroid1, centroid2):
485
+ return np.linalg.norm(centroid1 - centroid2)
486
+
487
+ # Domains to plot on network
488
+ selected_domains = []
489
+ # Find the farthest apart domains using centroids
490
+ if remaining_domains and remaining_labels:
491
+ first_domain = next(iter(remaining_domains)) # Pick the first domain to start
492
+ selected_domains.append(first_domain)
493
+
494
+ while len(selected_domains) < remaining_labels:
495
+ farthest_domain = None
496
+ max_distance = -1
497
+ # Find the domain farthest from any already selected domain
498
+ for candidate_domain, candidate_centroid in remaining_domains.items():
499
+ if candidate_domain in selected_domains:
500
+ continue
501
+
502
+ # Calculate the minimum distance to any selected domain
503
+ min_distance = min(
504
+ calculate_distance(candidate_centroid, remaining_domains[dom])
505
+ for dom in selected_domains
506
+ )
507
+ # Update the farthest domain if the minimum distance is greater
508
+ if min_distance > max_distance:
509
+ max_distance = min_distance
510
+ farthest_domain = candidate_domain
511
+
512
+ # Add the farthest domain to the selected domains
513
+ if farthest_domain:
514
+ selected_domains.append(farthest_domain)
515
+ else:
516
+ break # No more domains to select
517
+
518
+ # Process the selected domains and add to filtered lists
519
+ for domain in selected_domains:
520
+ domain_centroid = remaining_domains[domain]
521
+ is_domain_valid = self._validate_and_update_domain(
522
+ domain=domain,
523
+ domain_centroid=domain_centroid,
524
+ domain_id_to_centroid_map=domain_id_to_centroid_map,
525
+ ids_to_labels=ids_to_labels,
526
+ words_to_omit=words_to_omit,
527
+ min_label_lines=min_label_lines,
528
+ max_label_lines=max_label_lines,
529
+ min_chars_per_line=min_chars_per_line,
530
+ max_chars_per_line=max_chars_per_line,
531
+ filtered_domain_centroids=filtered_domain_centroids,
532
+ filtered_domain_terms=filtered_domain_terms,
533
+ valid_indices=valid_indices,
534
+ )
535
+ # Increment the label count if the domain is valid
536
+ if is_domain_valid:
537
+ label_count += 1
538
+ if label_count >= remaining_labels:
539
+ break
540
+
541
+ def _validate_and_update_domain(
542
+ self,
543
+ domain: str,
544
+ domain_centroid: np.ndarray,
545
+ domain_id_to_centroid_map: Dict[str, np.ndarray],
546
+ ids_to_labels: Union[Dict[int, str], None],
547
+ words_to_omit: Union[List[str], None],
548
+ min_label_lines: int,
549
+ max_label_lines: int,
550
+ min_chars_per_line: int,
551
+ max_chars_per_line: int,
552
+ filtered_domain_centroids: Dict[str, np.ndarray],
553
+ filtered_domain_terms: Dict[str, str],
554
+ valid_indices: List[int],
555
+ ) -> bool:
556
+ """Validate and process the domain terms, updating relevant dictionaries if valid.
557
+
558
+ Args:
559
+ domain (str): Domain ID to process.
560
+ domain_centroid (np.ndarray): Centroid position of the domain.
561
+ domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
562
+ ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
563
+ words_to_omit (List[str], None, optional): List of words to omit from the labels. Defaults to None.
564
+ min_label_lines (int): Minimum number of lines required in a label.
565
+ max_label_lines (int): Maximum number of lines allowed in a label.
566
+ min_chars_per_line (int): Minimum number of characters allowed per line.
567
+ max_chars_per_line (int): Maximum number of characters allowed per line.
568
+ filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store valid domain centroids.
569
+ filtered_domain_terms (Dict[str, str]): Dictionary to store valid domain terms.
570
+ valid_indices (List[int]): List of valid domain indices.
571
+
572
+ Returns:
573
+ bool: True if the domain is valid and added to the filtered dictionaries, False otherwise.
574
+ """
575
+ if ids_to_labels and domain in ids_to_labels:
576
+ # Directly use custom labels without filtering
577
+ domain_terms = ids_to_labels[domain]
578
+ else:
579
+ # Process the domain terms automatically
580
+ domain_terms = self._process_terms(
581
+ domain=domain,
582
+ words_to_omit=words_to_omit,
583
+ max_label_lines=max_label_lines,
584
+ min_chars_per_line=min_chars_per_line,
585
+ max_chars_per_line=max_chars_per_line,
586
+ )
587
+ # If no valid terms are generated, skip further processing
588
+ if not domain_terms:
589
+ return False
590
+
591
+ # Split the terms by TERM_DELIMITER and count the number of lines
592
+ num_domain_lines = len(domain_terms.split(TERM_DELIMITER))
593
+ # Check if the number of lines meets the minimum requirement
594
+ if num_domain_lines < min_label_lines:
595
+ return False
596
+
597
+ # Store the valid terms and centroids
598
+ filtered_domain_centroids[domain] = domain_centroid
599
+ filtered_domain_terms[domain] = domain_terms
600
+ valid_indices.append(list(domain_id_to_centroid_map.keys()).index(domain))
601
+
602
+ return True
603
+
604
+ def _process_terms(
605
+ self,
606
+ domain: str,
607
+ words_to_omit: Union[List[str], None],
608
+ max_label_lines: int,
609
+ min_chars_per_line: int,
610
+ max_chars_per_line: int,
611
+ ) -> List[str]:
612
+ """Process terms for a domain, applying word length constraints and combining words where appropriate.
613
+
614
+ Args:
615
+ domain (str): The domain being processed.
616
+ words_to_omit (List[str], None): List of words to omit from the labels.
617
+ max_label_lines (int): Maximum number of lines in a label.
618
+ min_chars_per_line (int): Minimum number of characters in a line to display.
619
+ max_chars_per_line (int): Maximum number of characters in a line to display.
620
+
621
+ Returns:
622
+ str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
623
+ """
624
+ # Set custom labels from significant terms
625
+ terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
626
+ # Apply words_to_omit and word length constraints
627
+ if words_to_omit:
628
+ terms = [
629
+ term
630
+ for term in terms
631
+ if term.lower() not in words_to_omit and len(term) >= min_chars_per_line
632
+ ]
633
+
634
+ # Use the combine_words function directly to handle word combinations and length constraints
635
+ compressed_terms = self._combine_words(tuple(terms), max_chars_per_line, max_label_lines)
636
+
637
+ return compressed_terms
638
+
639
def get_annotated_label_colors(
    self,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    blend_colors: bool = False,
    blend_gamma: float = 2.2,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    ids_to_colors: Union[Dict[int, Any], None] = None,
    random_seed: int = 888,
) -> np.ndarray:
    """Return label colors by forwarding every option to `get_annotated_domain_colors` for this graph.

    Args:
        cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): Color to use for the labels. Can be a single color or an array
            of colors. If None, the colormap will be used. Defaults to None.
        blend_colors (bool, optional): Whether to blend colors for nodes with multiple domains. Defaults to False.
        blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
        min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
            Controls the dimmest colors. Defaults to 0.8.
        max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
            Controls the brightest colors. Defaults to 1.0.
        scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
            A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
        ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors for label annotations.
    """
    # Gather the styling options once, then hand them over together with the graph
    color_options = {
        "cmap": cmap,
        "color": color,
        "blend_colors": blend_colors,
        "blend_gamma": blend_gamma,
        "min_scale": min_scale,
        "max_scale": max_scale,
        "scale_factor": scale_factor,
        "ids_to_colors": ids_to_colors,
        "random_seed": random_seed,
    }
    return get_annotated_domain_colors(graph=self.graph, **color_options)
683
+
684
def _combine_words(
    self, words: List[str], max_chars_per_line: int, max_label_lines: int
) -> str:
    """Pack words into at most max_label_lines lines of at most max_chars_per_line characters.

    Words are first taken in a batch of max_label_lines; neighbouring words in a batch are
    joined with a space while the joined text still fits on one line. If joining leaves
    lines unused, further words are pulled from the remainder (skipping words already
    placed) until the lines are filled or the words run out.

    Args:
        words (List[str]): Ordered words to combine. Only slicing and iteration are used,
            so a tuple works as well.
        max_chars_per_line (int): Maximum number of characters in a line to display.
        max_label_lines (int): Maximum number of lines in a label.

    Returns:
        str: The packed lines joined with TERM_DELIMITER.
    """

    def pack_lines(batch: List[str]) -> List[str]:
        """Greedily merge consecutive words of a batch into lines that fit the width limit."""
        lines = []
        cursor = 0
        batch_size = len(batch)
        while cursor < batch_size:
            # Open a new line with the next unused word
            line = batch[cursor]
            cursor += 1
            # Keep appending following words while the line (plus a space) still fits
            while (
                cursor < batch_size
                and len(line) + len(batch[cursor]) + 1 <= max_chars_per_line
            ):
                line = f"{line} {batch[cursor]}"
                cursor += 1

            # A single word that is itself too long is dropped rather than emitted
            if len(line) <= max_chars_per_line:
                lines.append(line)

            # Stop as soon as this batch alone has produced the allowed number of lines
            if len(lines) >= max_label_lines:
                break

        return lines

    # Seed the label with the first max_label_lines words
    packed = pack_lines(words[:max_label_lines])
    pending = words[max_label_lines:]
    # Individual words already present in the label, used to skip repeats
    seen = set(" ".join(packed).split())

    # Refill from the remaining words while there are free lines
    while pending and len(packed) < max_label_lines:
        free_slots = max_label_lines - len(packed)
        candidates, pending = pending[:free_slots], pending[free_slots:]
        fresh = [word for word in candidates if word not in seen]
        seen.update(fresh)
        packed.extend(pack_lines(fresh))

    # TERM_DELIMITER marks the line breaks for plotting
    return TERM_DELIMITER.join(packed[:max_label_lines])
750
+
751
+ def _calculate_best_label_positions(
752
+ self,
753
+ filtered_domain_centroids: Dict[str, Any],
754
+ center: np.ndarray,
755
+ radius: float,
756
+ offset: float,
757
+ ) -> Dict[str, Any]:
758
+ """Calculate and optimize label positions for clarity.
759
+
760
+ Args:
761
+ filtered_domain_centroids (Dict[str, Any]): Centroids of the filtered domains.
762
+ center (np.ndarray): The center coordinates for label positioning.
763
+ radius (float): The radius for positioning labels around the center.
764
+ offset (float): The offset distance from the radius for positioning labels.
765
+
766
+ Returns:
767
+ Dict[str, Any]: Optimized positions for labels.
768
+ """
769
+ num_domains = len(filtered_domain_centroids)
770
+ # Calculate equidistant positions around the center for initial label placement
771
+ equidistant_positions = self._calculate_equidistant_positions_around_center(
772
+ center, radius, offset, num_domains
773
+ )
774
+ # Create a mapping of domains to their initial label positions
775
+ label_positions = dict(zip(filtered_domain_centroids.keys(), equidistant_positions))
776
+ # Optimize the label positions to minimize distance to domain centroids
777
+ return self._optimize_label_positions(label_positions, filtered_domain_centroids)
778
+
779
+ def _calculate_equidistant_positions_around_center(
780
+ self, center: np.ndarray, radius: float, label_offset: float, num_domains: int
781
+ ) -> List[np.ndarray]:
782
+ """Calculate positions around a center at equidistant angles.
783
+
784
+ Args:
785
+ center (np.ndarray): The central point around which positions are calculated.
786
+ radius (float): The radius at which positions are calculated.
787
+ label_offset (float): The offset added to the radius for label positioning.
788
+ num_domains (int): The number of positions (or domains) to calculate.
789
+
790
+ Returns:
791
+ List[np.ndarray]: List of positions (as 2D numpy arrays) around the center.
792
+ """
793
+ # Calculate equidistant angles in radians around the center
794
+ angles = np.linspace(0, 2 * np.pi, num_domains, endpoint=False)
795
+ # Compute the positions around the center using the angles
796
+ return [
797
+ center + (radius + label_offset) * np.array([np.cos(angle), np.sin(angle)])
798
+ for angle in angles
799
+ ]
800
+
801
+ def _optimize_label_positions(
802
+ self, best_label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
803
+ ) -> Dict[str, Any]:
804
+ """Optimize label positions around the perimeter to minimize total distance to centroids.
805
+
806
+ Args:
807
+ best_label_positions (Dict[str, Any]): Initial positions of labels around the perimeter.
808
+ domain_centroids (Dict[str, Any]): Centroid positions of the domains.
809
+
810
+ Returns:
811
+ Dict[str, Any]: Optimized label positions.
812
+ """
813
+ while True:
814
+ improvement = False # Start each iteration assuming no improvement
815
+ # Iterate through each pair of labels to check for potential improvements
816
+ for i in range(len(domain_centroids)):
817
+ for j in range(i + 1, len(domain_centroids)):
818
+ # Calculate the current total distance
819
+ current_distance = self._calculate_total_distance(
820
+ best_label_positions, domain_centroids
821
+ )
822
+ # Evaluate the total distance after swapping two labels
823
+ swapped_distance = self._swap_and_evaluate(
824
+ best_label_positions, i, j, domain_centroids
825
+ )
826
+ # If the swap improves the total distance, perform the swap
827
+ if swapped_distance < current_distance:
828
+ labels = list(best_label_positions.keys())
829
+ best_label_positions[labels[i]], best_label_positions[labels[j]] = (
830
+ best_label_positions[labels[j]],
831
+ best_label_positions[labels[i]],
832
+ )
833
+ improvement = True # Found an improvement, so continue optimizing
834
+
835
+ if not improvement:
836
+ break # Exit the loop if no improvement was found in this iteration
837
+
838
+ return best_label_positions
839
+
840
+ def _calculate_total_distance(
841
+ self, label_positions: Dict[str, Any], domain_centroids: Dict[str, Any]
842
+ ) -> float:
843
+ """Calculate the total distance from label positions to their domain centroids.
844
+
845
+ Args:
846
+ label_positions (Dict[str, Any]): Positions of labels around the perimeter.
847
+ domain_centroids (Dict[str, Any]): Centroid positions of the domains.
848
+
849
+ Returns:
850
+ float: The total distance from labels to centroids.
851
+ """
852
+ total_distance = 0
853
+ # Iterate through each domain and calculate the distance to its centroid
854
+ for domain, pos in label_positions.items():
855
+ centroid = domain_centroids[domain]
856
+ total_distance += np.linalg.norm(centroid - pos)
857
+
858
+ return total_distance
859
+
860
+ def _swap_and_evaluate(
861
+ self,
862
+ label_positions: Dict[str, Any],
863
+ i: int,
864
+ j: int,
865
+ domain_centroids: Dict[str, Any],
866
+ ) -> float:
867
+ """Swap two labels and evaluate the total distance after the swap.
868
+
869
+ Args:
870
+ label_positions (Dict[str, Any]): Positions of labels around the perimeter.
871
+ i (int): Index of the first label to swap.
872
+ j (int): Index of the second label to swap.
873
+ domain_centroids (Dict[str, Any]): Centroid positions of the domains.
874
+
875
+ Returns:
876
+ float: The total distance after swapping the two labels.
877
+ """
878
+ # Get the list of labels from the dictionary keys
879
+ labels = list(label_positions.keys())
880
+ swapped_positions = copy.deepcopy(label_positions)
881
+ # Swap the positions of the two specified labels
882
+ swapped_positions[labels[i]], swapped_positions[labels[j]] = (
883
+ swapped_positions[labels[j]],
884
+ swapped_positions[labels[i]],
885
+ )
886
+ # Calculate and return the total distance after the swap
887
+ return self._calculate_total_distance(swapped_positions, domain_centroids)
888
+
889
+ def _apply_str_transformation(
890
+ self, words: List[str], transformation: Union[str, Dict[str, str]]
891
+ ) -> List[str]:
892
+ """Apply a user-specified case transformation to each word in the list without appending duplicates.
893
+
894
+ Args:
895
+ words (List[str]): A list of words to transform.
896
+ transformation (Union[str, Dict[str, str]]): A single transformation (e.g., 'lower', 'upper', 'title', 'capitalize')
897
+ or a dictionary mapping cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
898
+
899
+ Returns:
900
+ List[str]: A list of transformed words with no duplicates.
901
+ """
902
+ # Initialize a list to store transformed words
903
+ transformed_words = []
904
+ for word in words:
905
+ # Split word into subwords by space
906
+ subwords = word.split(" ")
907
+ transformed_subwords = []
908
+ # Apply transformation to each subword
909
+ for subword in subwords:
910
+ transformed_subword = subword # Start with the original subword
911
+ # If transformation is a string, apply it to all subwords
912
+ if isinstance(transformation, str):
913
+ if hasattr(subword, transformation):
914
+ transformed_subword = getattr(subword, transformation)()
915
+
916
+ # If transformation is a dictionary, apply case-specific transformations
917
+ elif isinstance(transformation, dict):
918
+ for case_type, transform in transformation.items():
919
+ if case_type == "lower" and subword.islower() and transform:
920
+ transformed_subword = getattr(subword, transform)()
921
+ elif case_type == "upper" and subword.isupper() and transform:
922
+ transformed_subword = getattr(subword, transform)()
923
+ elif case_type == "title" and subword.istitle() and transform:
924
+ transformed_subword = getattr(subword, transform)()
925
+
926
+ # Append the transformed subword to the list
927
+ transformed_subwords.append(transformed_subword)
928
+
929
+ # Rejoin the transformed subwords into a single string to preserve structure
930
+ transformed_word = " ".join(transformed_subwords)
931
+ # Only append if the transformed word is not already in the list
932
+ if transformed_word not in transformed_words:
933
+ transformed_words.append(transformed_word)
934
+
935
+ return transformed_words