risk-network 0.0.4b2__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -2,12 +2,9 @@
2
2
  risk
3
3
  ~~~~
4
4
 
5
- risk
6
- ~~~~
7
-
8
- RISK: RISK Infers Spatial Kinship
5
+ RISK: RISK Infers Spatial Kinships
9
6
  """
10
7
 
11
8
  from risk.risk import RISK
12
9
 
13
- __version__ = "0.0.4-beta.2"
10
+ __version__ = "0.0.5"
@@ -197,7 +197,7 @@ def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
197
197
  word_counts = Counter(words)
198
198
  filtered_words = []
199
199
  used_words = set()
200
-
200
+ # Iterate through the words to find similar words
201
201
  for word in word_counts:
202
202
  if word in used_words:
203
203
  continue
@@ -207,7 +207,7 @@ def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
207
207
  similar_words = [
208
208
  other_word
209
209
  for other_word in word_counts
210
- if _jaccard_index(word_set, set(other_word)) >= threshold
210
+ if _calculate_jaccard_index(word_set, set(other_word)) >= threshold
211
211
  ]
212
212
  # Sort by frequency and choose the most frequent word
213
213
  similar_words.sort(key=lambda w: word_counts[w], reverse=True)
@@ -220,7 +220,7 @@ def _simplify_word_list(words: List[str], threshold: float = 0.80) -> List[str]:
220
220
  return final_words
221
221
 
222
222
 
223
- def _jaccard_index(set1: Set[Any], set2: Set[Any]) -> float:
223
+ def _calculate_jaccard_index(set1: Set[Any], set2: Set[Any]) -> float:
224
224
  """Calculate the Jaccard Index of two sets.
225
225
 
226
226
  Args:
risk/constants.py CHANGED
@@ -3,6 +3,8 @@ risk/constants
3
3
  ~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
+ GROUP_LINKAGE_METHODS = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
7
+
6
8
  GROUP_DISTANCE_METRICS = [
7
9
  "braycurtis",
8
10
  "canberra",
@@ -27,5 +29,3 @@ GROUP_DISTANCE_METRICS = [
27
29
  "sqeuclidean",
28
30
  "yule",
29
31
  ]
30
-
31
- GROUP_LINKAGE_METHODS = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
@@ -321,7 +321,11 @@ def _calculate_threshold(average_distances: list, distance_threshold: float) ->
321
321
  rank_percentiles = np.linspace(0, 1, len(sorted_distances))
322
322
  # Interpolating the ranks to 1000 evenly spaced percentiles
323
323
  interpolated_percentiles = np.linspace(0, 1, 1000)
324
- smoothed_distances = np.interp(interpolated_percentiles, rank_percentiles, sorted_distances)
324
+ try:
325
+ smoothed_distances = np.interp(interpolated_percentiles, rank_percentiles, sorted_distances)
326
+ except ValueError as e:
327
+ raise ValueError("No significant annotations found.") from e
328
+
325
329
  # Determine the index corresponding to the distance threshold
326
330
  threshold_index = int(np.ceil(distance_threshold * len(smoothed_distances))) - 1
327
331
  # Return the smoothed distance at the calculated index
risk/network/geometry.py CHANGED
@@ -7,13 +7,13 @@ import networkx as nx
7
7
  import numpy as np
8
8
 
9
9
 
10
- def apply_edge_lengths(
10
+ def assign_edge_lengths(
11
11
  G: nx.Graph,
12
12
  compute_sphere: bool = True,
13
13
  surface_depth: float = 0.0,
14
14
  include_edge_weight: bool = False,
15
15
  ) -> nx.Graph:
16
- """Apply edge lengths in the graph, optionally mapping nodes to a sphere and including edge weights.
16
+ """Assign edge lengths in the graph, optionally mapping nodes to a sphere and including edge weights.
17
17
 
18
18
  Args:
19
19
  G (nx.Graph): The input graph.
risk/network/graph.py CHANGED
@@ -5,13 +5,12 @@ risk/network/graph
5
5
 
6
6
  import random
7
7
  from collections import defaultdict
8
- from typing import Any, Dict, List, Tuple
8
+ from typing import Any, Dict, List, Tuple, Union
9
9
 
10
10
  import networkx as nx
11
11
  import numpy as np
12
12
  import pandas as pd
13
13
  import matplotlib
14
- import matplotlib.cm as cm
15
14
 
16
15
 
17
16
  class NetworkGraph:
@@ -101,7 +100,12 @@ class NetworkGraph:
101
100
  self.node_coordinates = _extract_node_coordinates(G_2d)
102
101
 
103
102
  def get_domain_colors(
104
- self, min_scale: float = 0.8, max_scale: float = 1.0, random_seed: int = 888, **kwargs
103
+ self,
104
+ min_scale: float = 0.8,
105
+ max_scale: float = 1.0,
106
+ scale_factor: float = 1.0,
107
+ random_seed: int = 888,
108
+ **kwargs,
105
109
  ) -> np.ndarray:
106
110
  """Generate composite colors for domains.
107
111
 
@@ -112,6 +116,8 @@ class NetworkGraph:
112
116
  Args:
113
117
  min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
114
118
  max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
119
+ scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
120
+ values more. Defaults to 1.0.
115
121
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
116
122
  **kwargs: Additional keyword arguments for color generation.
117
123
 
@@ -119,7 +125,7 @@ class NetworkGraph:
119
125
  np.ndarray: Array of transformed colors.
120
126
  """
121
127
  # Get colors for each domain
122
- domain_colors = self._get_domain_colors(**kwargs, random_seed=random_seed)
128
+ domain_colors = self._get_domain_colors(random_seed=random_seed)
123
129
  # Generate composite colors for nodes
124
130
  node_colors = self._get_composite_node_colors(domain_colors)
125
131
  # Transform colors to ensure proper alpha values and intensity
@@ -128,6 +134,7 @@ class NetworkGraph:
128
134
  self.node_enrichment_sums,
129
135
  min_scale=min_scale,
130
136
  max_scale=max_scale,
137
+ scale_factor=scale_factor,
131
138
  )
132
139
 
133
140
  return transformed_colors
@@ -153,9 +160,15 @@ class NetworkGraph:
153
160
 
154
161
  return composite_colors
155
162
 
156
- def _get_domain_colors(self, **kwargs) -> Dict[str, Any]:
163
+ def _get_domain_colors(
164
+ self, color: Union[str, None] = None, random_seed: int = 888
165
+ ) -> Dict[str, Any]:
157
166
  """Get colors for each domain.
158
167
 
168
+ Args:
169
+ color (Union[str, None], optional): Specific color to use for all domains. If specified, it will overwrite the colormap.
170
+ random_seed (int, optional): Seed for random number generation. Defaults to 888.
171
+
159
172
  Returns:
160
173
  dict: A dictionary mapping domain keys to their corresponding RGBA colors.
161
174
  """
@@ -164,20 +177,28 @@ class NetworkGraph:
164
177
  col for col in self.domains.columns if isinstance(col, (int, np.integer))
165
178
  ]
166
179
  domains = np.sort(numeric_domains)
167
- domain_colors = _get_colors(**kwargs, num_colors_to_generate=len(domains))
180
+ domain_colors = _get_colors(
181
+ num_colors_to_generate=len(domains), color=color, random_seed=random_seed
182
+ )
168
183
  return dict(zip(self.domain_to_nodes.keys(), domain_colors))
169
184
 
170
185
 
171
186
  def _transform_colors(
172
- colors: np.ndarray, enrichment_sums: np.ndarray, min_scale: float = 0.8, max_scale: float = 1.0
187
+ colors: np.ndarray,
188
+ enrichment_sums: np.ndarray,
189
+ min_scale: float = 0.8,
190
+ max_scale: float = 1.0,
191
+ scale_factor: float = 1.0,
173
192
  ) -> np.ndarray:
174
- """Transform colors to ensure proper alpha values and intensity based on enrichment sums.
193
+ """Transform colors using power scaling to emphasize high enrichment sums more.
175
194
 
176
195
  Args:
177
196
  colors (np.ndarray): An array of RGBA colors.
178
197
  enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
179
198
  min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
180
199
  max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
200
+ scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
201
+ values more. Defaults to 1.0.
181
202
 
182
203
  Returns:
183
204
  np.ndarray: The transformed array of RGBA colors with adjusted intensities.
@@ -185,11 +206,12 @@ def _transform_colors(
185
206
  if min_scale == max_scale:
186
207
  min_scale = max_scale - 10e-6 # Avoid division by zero
187
208
 
188
- log_enrichment_sums = np.log1p(enrichment_sums) # Use log1p to avoid log(0)
189
- # Normalize the capped enrichment sums to the range [0, 1]
190
- normalized_sums = log_enrichment_sums / np.max(log_enrichment_sums)
191
- # Scale normalized sums to the specified color range [min_scale, max_scale]
192
- scaled_sums = min_scale + (max_scale - min_scale) * normalized_sums
209
+ # Normalize the enrichment sums to the range [0, 1]
210
+ normalized_sums = enrichment_sums / np.max(enrichment_sums)
211
+ # Apply power scaling to dim lower values and emphasize higher values
212
+ scaled_sums = normalized_sums**scale_factor
213
+ # Linearly scale the normalized sums to the range [min_scale, max_scale]
214
+ scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
193
215
  # Adjust RGB values based on scaled sums
194
216
  for i in range(3): # Only adjust RGB values
195
217
  colors[:, i] = scaled_sums * colors[:, i]
@@ -250,7 +272,10 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
250
272
 
251
273
 
252
274
  def _get_colors(
253
- num_colors_to_generate: int = 10, cmap: str = "hsv", random_seed: int = 888, **kwargs
275
+ num_colors_to_generate: int = 10,
276
+ cmap: str = "hsv",
277
+ random_seed: int = 888,
278
+ color: Union[str, None] = None,
254
279
  ) -> List[Tuple]:
255
280
  """Generate a list of RGBA colors from a specified colormap or use a direct color string.
256
281
 
@@ -258,19 +283,20 @@ def _get_colors(
258
283
  num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
259
284
  cmap (str): The name of the colormap to use. Defaults to "hsv".
260
285
  random_seed (int): Seed for random number generation. Defaults to 888.
261
- **kwargs: Additional keyword arguments, such as 'color' for a specific color.
286
+ color (str, optional): Specific color to use for all nodes. If specified, it will overwrite the colormap.
287
+ Defaults to None.
262
288
 
263
289
  Returns:
264
290
  list of tuple: List of RGBA colors.
265
291
  """
266
292
  # Set random seed for reproducibility
267
293
  random.seed(random_seed)
268
- if kwargs.get("color"):
269
- # If a direct color string is provided, generate a list with that color
270
- rgba = matplotlib.colors.to_rgba(kwargs["color"])
294
+ if color:
295
+ # If a direct color is provided, generate a list with that color
296
+ rgba = matplotlib.colors.to_rgba(color)
271
297
  rgbas = [rgba] * num_colors_to_generate
272
298
  else:
273
- colormap = cm.get_cmap(cmap)
299
+ colormap = matplotlib.colormaps.get_cmap(cmap)
274
300
  # Generate evenly distributed color positions
275
301
  color_positions = np.linspace(0, 1, num_colors_to_generate)
276
302
  random.shuffle(color_positions) # Shuffle the positions to randomize colors
risk/network/io.py CHANGED
@@ -6,6 +6,7 @@ This file contains the code for the RISK class and command-line access.
6
6
  """
7
7
 
8
8
  import json
9
+ import os
9
10
  import pickle
10
11
  import shutil
11
12
  import zipfile
@@ -14,7 +15,7 @@ from xml.dom import minidom
14
15
  import networkx as nx
15
16
  import pandas as pd
16
17
 
17
- from risk.network.geometry import apply_edge_lengths
18
+ from risk.network.geometry import assign_edge_lengths
18
19
  from risk.log import params, print_header
19
20
 
20
21
 
@@ -215,14 +216,20 @@ class NetworkIO:
215
216
  params.log_network(filetype=filetype, filepath=str(filepath))
216
217
  self._log_loading(filetype, filepath=filepath)
217
218
  cys_files = []
219
+ tmp_dir = ".tmp_cytoscape"
218
220
  # Try / finally to remove unzipped files
219
221
  try:
220
- # Unzip CYS file
222
+ # Create the temporary directory if it doesn't exist
223
+ if not os.path.exists(tmp_dir):
224
+ os.makedirs(tmp_dir)
225
+
226
+ # Unzip CYS file into the temporary directory
221
227
  with zipfile.ZipFile(filepath, "r") as zip_ref:
222
228
  cys_files = zip_ref.namelist()
223
- zip_ref.extractall("./")
229
+ zip_ref.extractall(tmp_dir)
230
+
224
231
  # Get first view and network instances
225
- cys_view_files = [cf for cf in cys_files if "/views/" in cf]
232
+ cys_view_files = [os.path.join(tmp_dir, cf) for cf in cys_files if "/views/" in cf]
226
233
  cys_view_file = (
227
234
  cys_view_files[0]
228
235
  if not view_name
@@ -244,7 +251,7 @@ class NetworkIO:
244
251
  # Read the node attributes (from /tables/)
245
252
  attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
246
253
  attribute_metadata = [
247
- cf
254
+ os.path.join(tmp_dir, cf)
248
255
  for cf in cys_files
249
256
  if all(keyword in cf for keyword in attribute_metadata_keywords)
250
257
  ][0]
@@ -291,10 +298,9 @@ class NetworkIO:
291
298
  return self._initialize_graph(G)
292
299
 
293
300
  finally:
294
- # Remove unzipped files/directories
295
- cys_dirnames = list(set([cf.split("/")[0] for cf in cys_files]))
296
- for dirname in cys_dirnames:
297
- shutil.rmtree(dirname)
301
+ # Remove the temporary directory and its contents
302
+ if os.path.exists(tmp_dir):
303
+ shutil.rmtree(tmp_dir)
298
304
 
299
305
  @classmethod
300
306
  def load_cytoscape_json_network(
@@ -402,12 +408,13 @@ class NetworkIO:
402
408
  Returns:
403
409
  nx.Graph: The processed and validated graph.
404
410
  """
411
+ self._validate_nodes(G)
412
+ self._assign_edge_weights(G)
413
+ self._assign_edge_lengths(G)
414
+ self._remove_invalid_graph_properties(G)
405
415
  # IMPORTANT: This is where the graph node labels are converted to integers
416
+ # Make sure to perform this step after all other processing
406
417
  G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
407
- self._remove_invalid_graph_properties(G)
408
- self._validate_edges(G)
409
- self._validate_nodes(G)
410
- self._process_graph(G)
411
418
  return G
412
419
 
413
420
  def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
@@ -416,18 +423,26 @@ class NetworkIO:
416
423
  Args:
417
424
  G (nx.Graph): A NetworkX graph object.
418
425
  """
419
- print(f"Minimum edges per node: {self.min_edges_per_node}")
420
- # Remove nodes with fewer edges than the specified threshold
421
- nodes_with_few_edges = [
422
- node for node in G.nodes() if G.degree(node) <= self.min_edges_per_node
423
- ]
424
- G.remove_nodes_from(nodes_with_few_edges)
425
- # Remove self-loop edges
426
- self_loops = list(nx.selfloop_edges(G))
427
- G.remove_edges_from(self_loops)
428
-
429
- def _validate_edges(self, G: nx.Graph) -> None:
430
- """Validate and assign weights to the edges in the graph.
426
+ # First, Remove self-loop edges to ensure correct edge count
427
+ G.remove_edges_from(list(nx.selfloop_edges(G)))
428
+ # Then, iteratively remove nodes with fewer edges than the specified threshold
429
+ while True:
430
+ nodes_to_remove = [
431
+ node for node in G.nodes() if G.degree(node) < self.min_edges_per_node
432
+ ]
433
+ if not nodes_to_remove:
434
+ break # Exit loop if no more nodes to remove
435
+
436
+ # Remove the nodes and their associated edges
437
+ G.remove_nodes_from(nodes_to_remove)
438
+
439
+ # Optionally: Remove any isolated nodes if needed
440
+ isolated_nodes = list(nx.isolates(G))
441
+ if isolated_nodes:
442
+ G.remove_nodes_from(isolated_nodes)
443
+
444
+ def _assign_edge_weights(self, G: nx.Graph) -> None:
445
+ """Assign weights to the edges in the graph.
431
446
 
432
447
  Args:
433
448
  G (nx.Graph): A NetworkX graph object.
@@ -456,13 +471,13 @@ class NetworkIO:
456
471
  ), f"Node {node} is missing 'x' or 'y' position attributes."
457
472
  assert "label" in attrs, f"Node {node} is missing a 'label' attribute."
458
473
 
459
- def _process_graph(self, G: nx.Graph) -> None:
474
+ def _assign_edge_lengths(self, G: nx.Graph) -> None:
460
475
  """Prepare the network by adjusting surface depth and calculating edge lengths.
461
476
 
462
477
  Args:
463
478
  G (nx.Graph): The input network graph.
464
479
  """
465
- apply_edge_lengths(
480
+ assign_edge_lengths(
466
481
  G,
467
482
  compute_sphere=self.compute_sphere,
468
483
  surface_depth=self.surface_depth,
@@ -484,9 +499,9 @@ class NetworkIO:
484
499
  print(f"Filetype: {filetype}")
485
500
  if filepath:
486
501
  print(f"Filepath: {filepath}")
487
- print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
488
- if self.compute_sphere:
489
- print(f"Surface depth: {self.surface_depth}")
490
502
  print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
491
503
  if self.include_edge_weight:
492
504
  print(f"Weight label: {self.weight_label}")
505
+ print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
506
+ if self.compute_sphere:
507
+ print(f"Surface depth: {self.surface_depth}")
risk/network/plot.py CHANGED
@@ -33,6 +33,7 @@ class NetworkPlotter:
33
33
  plot_outline: bool = True,
34
34
  outline_color: str = "black",
35
35
  outline_scale: float = 1.0,
36
+ linestyle: str = "dashed",
36
37
  ) -> None:
37
38
  """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
38
39
 
@@ -43,11 +44,18 @@ class NetworkPlotter:
43
44
  plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
44
45
  outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
45
46
  outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
47
+ linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
46
48
  """
47
49
  self.graph = graph
48
50
  # Initialize the plot with the specified parameters
49
51
  self.ax = self._initialize_plot(
50
- graph, figsize, background_color, plot_outline, outline_color, outline_scale
52
+ graph,
53
+ figsize,
54
+ background_color,
55
+ plot_outline,
56
+ outline_color,
57
+ outline_scale,
58
+ linestyle,
51
59
  )
52
60
 
53
61
  def _initialize_plot(
@@ -58,6 +66,7 @@ class NetworkPlotter:
58
66
  plot_outline: bool,
59
67
  outline_color: str,
60
68
  outline_scale: float,
69
+ linestyle: str,
61
70
  ) -> plt.Axes:
62
71
  """Set up the plot with figure size, optional circle perimeter, and background color.
63
72
 
@@ -68,6 +77,7 @@ class NetworkPlotter:
68
77
  plot_outline (bool): Whether to plot the network perimeter circle.
69
78
  outline_color (str): Color of the network perimeter circle.
70
79
  outline_scale (float): Outline scaling factor for the perimeter diameter.
80
+ linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
71
81
 
72
82
  Returns:
73
83
  plt.Axes: The axis object for the plot.
@@ -87,7 +97,7 @@ class NetworkPlotter:
87
97
  circle = plt.Circle(
88
98
  center,
89
99
  scaled_radius,
90
- linestyle="--",
100
+ linestyle=linestyle, # Use the linestyle argument here
91
101
  color=outline_color,
92
102
  fill=False,
93
103
  linewidth=1.5,
@@ -400,8 +410,12 @@ class NetworkPlotter:
400
410
  fontcolor: Union[str, np.ndarray] = "black",
401
411
  arrow_linewidth: float = 1,
402
412
  arrow_color: Union[str, np.ndarray] = "black",
413
+ max_labels: Union[int, None] = None,
403
414
  max_words: int = 10,
404
415
  min_words: int = 1,
416
+ max_word_length: int = 20,
417
+ min_word_length: int = 1,
418
+ words_to_omit: Union[List[str], None] = None,
405
419
  ) -> None:
406
420
  """Annotate the network graph with labels for different domains, positioned around the network for clarity.
407
421
 
@@ -413,8 +427,12 @@ class NetworkPlotter:
413
427
  fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
414
428
  arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
415
429
  arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
430
+ max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
416
431
  max_words (int, optional): Maximum number of words in a label. Defaults to 10.
417
432
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
433
+ max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
434
+ min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
435
+ words_to_omit (List[str], optional): List of words to omit from the labels. Defaults to None.
418
436
  """
419
437
  # Log the plotting parameters
420
438
  params.log_plotter(
@@ -425,14 +443,22 @@ class NetworkPlotter:
425
443
  label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
426
444
  label_arrow_linewidth=arrow_linewidth,
427
445
  label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
446
+ label_max_labels=max_labels,
428
447
  label_max_words=max_words,
429
448
  label_min_words=min_words,
449
+ label_max_word_length=max_word_length,
450
+ label_min_word_length=min_word_length,
451
+ label_words_to_omit=words_to_omit,
430
452
  )
453
+
431
454
  # Convert color strings to RGBA arrays if necessary
432
455
  if isinstance(fontcolor, str):
433
- fontcolor = self.get_annotated_contour_colors(color=fontcolor)
456
+ fontcolor = self.get_annotated_label_colors(color=fontcolor)
434
457
  if isinstance(arrow_color, str):
435
- arrow_color = self.get_annotated_contour_colors(color=arrow_color)
458
+ arrow_color = self.get_annotated_label_colors(color=arrow_color)
459
+ # Normalize words_to_omit to lowercase
460
+ if words_to_omit:
461
+ words_to_omit = set(word.lower() for word in words_to_omit)
436
462
 
437
463
  # Calculate the center and radius of the network
438
464
  domain_centroids = {}
@@ -443,31 +469,57 @@ class NetworkPlotter:
443
469
  # Initialize empty lists to collect valid indices
444
470
  valid_indices = []
445
471
  filtered_domain_centroids = {}
472
+ filtered_domain_terms = {}
446
473
  # Loop through domain_centroids with index
447
474
  for idx, (domain, centroid) in enumerate(domain_centroids.items()):
475
+ # Process the domain term
476
+ terms = self.graph.trimmed_domain_to_term[domain].split(" ")
477
+ # Remove words_to_omit
478
+ if words_to_omit:
479
+ terms = [term for term in terms if term.lower() not in words_to_omit]
480
+ # Filter words based on length
481
+ terms = [term for term in terms if min_word_length <= len(term) <= max_word_length]
482
+ # Trim to max_words
483
+ terms = terms[:max_words]
448
484
  # Check if the domain passes the word count condition
449
- if len(self.graph.trimmed_domain_to_term[domain].split(" ")[:max_words]) >= min_words:
485
+ if len(terms) >= min_words:
450
486
  # Add to filtered_domain_centroids
451
487
  filtered_domain_centroids[domain] = centroid
452
- # Keep track of the valid index
488
+ # Store the filtered and trimmed terms
489
+ filtered_domain_terms[domain] = " ".join(terms)
490
+ # Keep track of the valid index - used for fontcolor and arrow_color
453
491
  valid_indices.append(idx)
454
492
 
455
- # Now filter fontcolor and arrow_color to keep only the valid indices
456
- fontcolor = fontcolor[valid_indices]
457
- arrow_color = arrow_color[valid_indices]
493
+ # If max_labels is specified and less than the available labels
494
+ if max_labels is not None and max_labels < len(filtered_domain_centroids):
495
+ step = len(filtered_domain_centroids) / max_labels
496
+ selected_indices = [int(i * step) for i in range(max_labels)]
497
+ # Filter the centroids, terms, and valid_indices to only use the selected indices
498
+ filtered_domain_centroids = {
499
+ k: v
500
+ for i, (k, v) in enumerate(filtered_domain_centroids.items())
501
+ if i in selected_indices
502
+ }
503
+ filtered_domain_terms = {
504
+ k: v
505
+ for i, (k, v) in enumerate(filtered_domain_terms.items())
506
+ if i in selected_indices
507
+ }
508
+ # Update valid_indices to match selected indices
509
+ valid_indices = [valid_indices[i] for i in selected_indices]
458
510
 
459
511
  # Calculate the bounding box around the network
460
512
  center, radius = _calculate_bounding_box(
461
513
  self.graph.node_coordinates, radius_margin=perimeter_scale
462
514
  )
463
515
  # Calculate the best positions for labels around the perimeter
464
- best_label_positions = _best_label_positions(
516
+ best_label_positions = _calculate_best_label_positions(
465
517
  filtered_domain_centroids, center, radius, offset
466
518
  )
467
- # Annotate the network with labels
468
- for idx, (domain, pos) in enumerate(best_label_positions.items()):
519
+ # Annotate the network with labels - valid_indices is used for fontcolor and arrow_color
520
+ for idx, (domain, pos) in zip(valid_indices, best_label_positions.items()):
469
521
  centroid = filtered_domain_centroids[domain]
470
- annotations = self.graph.trimmed_domain_to_term[domain].split(" ")[:max_words]
522
+ annotations = filtered_domain_terms[domain].split(" ")[:max_words]
471
523
  self.ax.annotate(
472
524
  "\n".join(annotations),
473
525
  xy=centroid,
@@ -531,12 +583,10 @@ class NetworkPlotter:
531
583
 
532
584
  # Calculate the centroid of the provided nodes
533
585
  centroid = self._calculate_domain_centroid(node_ids)
534
-
535
586
  # Calculate the bounding box around the network
536
587
  center, radius = _calculate_bounding_box(
537
588
  self.graph.node_coordinates, radius_margin=perimeter_scale
538
589
  )
539
-
540
590
  # Convert radial position to radians, adjusting for a 90-degree rotation
541
591
  radial_radians = np.deg2rad(radial_position - 90)
542
592
  label_position = (
@@ -746,7 +796,7 @@ def _calculate_bounding_box(
746
796
  return center, radius
747
797
 
748
798
 
749
- def _best_label_positions(
799
+ def _calculate_best_label_positions(
750
800
  filtered_domain_centroids: Dict[str, Any], center: np.ndarray, radius: float, offset: float
751
801
  ) -> Dict[str, Any]:
752
802
  """Calculate and optimize label positions for clarity.
@@ -762,7 +812,9 @@ def _best_label_positions(
762
812
  """
763
813
  num_domains = len(filtered_domain_centroids)
764
814
  # Calculate equidistant positions around the center for initial label placement
765
- equidistant_positions = _equidistant_angles_around_center(center, radius, offset, num_domains)
815
+ equidistant_positions = _calculate_equidistant_positions_around_center(
816
+ center, radius, offset, num_domains
817
+ )
766
818
  # Create a mapping of domains to their initial label positions
767
819
  label_positions = {
768
820
  domain: position
@@ -772,7 +824,7 @@ def _best_label_positions(
772
824
  return _optimize_label_positions(label_positions, filtered_domain_centroids)
773
825
 
774
826
 
775
- def _equidistant_angles_around_center(
827
+ def _calculate_equidistant_positions_around_center(
776
828
  center: np.ndarray, radius: float, label_offset: float, num_domains: int
777
829
  ) -> List[np.ndarray]:
778
830
  """Calculate positions around a center at equidistant angles.