risk-network 0.0.3b4__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
risk/network/plot.py CHANGED
@@ -27,7 +27,7 @@ class NetworkPlotter:
27
27
 
28
28
  def __init__(
29
29
  self,
30
- network_graph: NetworkGraph,
30
+ graph: NetworkGraph,
31
31
  figsize: tuple = (10, 10),
32
32
  background_color: str = "white",
33
33
  plot_outline: bool = True,
@@ -37,22 +37,22 @@ class NetworkPlotter:
37
37
  """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
38
38
 
39
39
  Args:
40
- network_graph (NetworkGraph): The network data and attributes to be visualized.
40
+ graph (NetworkGraph): The network data and attributes to be visualized.
41
41
  figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
42
42
  background_color (str, optional): Background color of the plot. Defaults to "white".
43
43
  plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
44
44
  outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
45
45
  outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
46
46
  """
47
- self.network_graph = network_graph
47
+ self.graph = graph
48
48
  # Initialize the plot with the specified parameters
49
49
  self.ax = self._initialize_plot(
50
- network_graph, figsize, background_color, plot_outline, outline_color, outline_scale
50
+ graph, figsize, background_color, plot_outline, outline_color, outline_scale
51
51
  )
52
52
 
53
53
  def _initialize_plot(
54
54
  self,
55
- network_graph: NetworkGraph,
55
+ graph: NetworkGraph,
56
56
  figsize: tuple,
57
57
  background_color: str,
58
58
  plot_outline: bool,
@@ -62,7 +62,7 @@ class NetworkPlotter:
62
62
  """Set up the plot with figure size, optional circle perimeter, and background color.
63
63
 
64
64
  Args:
65
- network_graph (NetworkGraph): The network data and attributes to be visualized.
65
+ graph (NetworkGraph): The network data and attributes to be visualized.
66
66
  figsize (tuple): Size of the figure in inches (width, height).
67
67
  background_color (str): Background color of the plot.
68
68
  plot_outline (bool): Whether to plot the network perimeter circle.
@@ -73,7 +73,7 @@ class NetworkPlotter:
73
73
  plt.Axes: The axis object for the plot.
74
74
  """
75
75
  # Extract node coordinates from the network graph
76
- node_coordinates = network_graph.node_coordinates
76
+ node_coordinates = graph.node_coordinates
77
77
  # Calculate the center and radius of the bounding box around the network
78
78
  center, radius = _calculate_bounding_box(node_coordinates)
79
79
  # Scale the radius by the outline_scale factor
@@ -141,10 +141,10 @@ class NetworkPlotter:
141
141
  network_node_shape=node_shape,
142
142
  )
143
143
  # Extract node coordinates from the network graph
144
- node_coordinates = self.network_graph.node_coordinates
144
+ node_coordinates = self.graph.node_coordinates
145
145
  # Draw the nodes of the graph
146
146
  nx.draw_networkx_nodes(
147
- self.network_graph.G,
147
+ self.graph.network,
148
148
  pos=node_coordinates,
149
149
  node_size=node_size,
150
150
  node_color=node_color,
@@ -155,7 +155,7 @@ class NetworkPlotter:
155
155
  )
156
156
  # Draw the edges of the graph
157
157
  nx.draw_networkx_edges(
158
- self.network_graph.G,
158
+ self.graph.network,
159
159
  pos=node_coordinates,
160
160
  width=edge_width,
161
161
  edge_color=edge_color,
@@ -197,20 +197,18 @@ class NetworkPlotter:
197
197
  )
198
198
  # Filter to get node IDs and their coordinates
199
199
  node_ids = [
200
- self.network_graph.node_label_to_id_map.get(node)
200
+ self.graph.node_label_to_id_map.get(node)
201
201
  for node in nodes
202
- if node in self.network_graph.node_label_to_id_map
202
+ if node in self.graph.node_label_to_id_map
203
203
  ]
204
204
  if not node_ids:
205
205
  raise ValueError("No nodes found in the network graph.")
206
206
 
207
207
  # Get the coordinates of the filtered nodes
208
- node_coordinates = {
209
- node_id: self.network_graph.node_coordinates[node_id] for node_id in node_ids
210
- }
208
+ node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
211
209
  # Draw the nodes in the subnetwork
212
210
  nx.draw_networkx_nodes(
213
- self.network_graph.G,
211
+ self.graph.network,
214
212
  pos=node_coordinates,
215
213
  nodelist=node_ids,
216
214
  node_size=node_size,
@@ -221,7 +219,7 @@ class NetworkPlotter:
221
219
  ax=self.ax,
222
220
  )
223
221
  # Draw the edges between the specified nodes in the subnetwork
224
- subgraph = self.network_graph.G.subgraph(node_ids)
222
+ subgraph = self.graph.network.subgraph(node_ids)
225
223
  nx.draw_networkx_edges(
226
224
  subgraph,
227
225
  pos=node_coordinates,
@@ -260,9 +258,9 @@ class NetworkPlotter:
260
258
  color = self.get_annotated_contour_colors(color=color)
261
259
 
262
260
  # Extract node coordinates from the network graph
263
- node_coordinates = self.network_graph.node_coordinates
261
+ node_coordinates = self.graph.node_coordinates
264
262
  # Draw contours for each domain in the network
265
- for idx, (_, nodes) in enumerate(self.network_graph.domain_to_nodes.items()):
263
+ for idx, (_, nodes) in enumerate(self.graph.domain_to_nodes.items()):
266
264
  if len(nodes) > 1:
267
265
  self._draw_kde_contour(
268
266
  self.ax,
@@ -299,23 +297,23 @@ class NetworkPlotter:
299
297
  """
300
298
  # Log the plotting parameters
301
299
  params.log_plotter(
302
- contour_levels=levels,
303
- contour_bandwidth=bandwidth,
304
- contour_grid_size=grid_size,
305
- contour_alpha=alpha,
306
- contour_color="custom" if isinstance(color, np.ndarray) else color,
300
+ subcontour_levels=levels,
301
+ subcontour_bandwidth=bandwidth,
302
+ subcontour_grid_size=grid_size,
303
+ subcontour_alpha=alpha,
304
+ subcontour_color="custom" if isinstance(color, np.ndarray) else color,
307
305
  )
308
306
  # Filter to get node IDs and their coordinates
309
307
  node_ids = [
310
- self.network_graph.node_label_to_id_map.get(node)
308
+ self.graph.node_label_to_id_map.get(node)
311
309
  for node in nodes
312
- if node in self.network_graph.node_label_to_id_map
310
+ if node in self.graph.node_label_to_id_map
313
311
  ]
314
312
  if not node_ids or len(node_ids) == 1:
315
313
  raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
316
314
 
317
315
  # Draw the KDE contour for the specified nodes
318
- node_coordinates = self.network_graph.node_coordinates
316
+ node_coordinates = self.graph.node_coordinates
319
317
  self._draw_kde_contour(
320
318
  self.ax,
321
319
  node_coordinates,
@@ -402,8 +400,10 @@ class NetworkPlotter:
402
400
  fontcolor: Union[str, np.ndarray] = "black",
403
401
  arrow_linewidth: float = 1,
404
402
  arrow_color: Union[str, np.ndarray] = "black",
405
- num_words: int = 10,
403
+ max_labels: Union[int, None] = None,
404
+ max_words: int = 10,
406
405
  min_words: int = 1,
406
+ words_to_omit: Union[List[str], None] = None,
407
407
  ) -> None:
408
408
  """Annotate the network graph with labels for different domains, positioned around the network for clarity.
409
409
 
@@ -415,8 +415,10 @@ class NetworkPlotter:
415
415
  fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
416
416
  arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
417
417
  arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
418
- num_words (int, optional): Maximum number of words in a label. Defaults to 10.
418
+ max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
419
+ max_words (int, optional): Maximum number of words in a label. Defaults to 10.
419
420
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
421
+ words_to_omit (List[str], optional): List of words to omit from the labels. Defaults to None.
420
422
  """
421
423
  # Log the plotting parameters
422
424
  params.log_plotter(
@@ -427,39 +429,82 @@ class NetworkPlotter:
427
429
  label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
428
430
  label_arrow_linewidth=arrow_linewidth,
429
431
  label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
430
- label_num_words=num_words,
432
+ label_max_labels=max_labels,
433
+ label_max_words=max_words,
431
434
  label_min_words=min_words,
435
+ label_words_to_omit=words_to_omit, # Log words_to_omit parameter
432
436
  )
437
+
433
438
  # Convert color strings to RGBA arrays if necessary
434
439
  if isinstance(fontcolor, str):
435
- fontcolor = self.get_annotated_contour_colors(color=fontcolor)
440
+ fontcolor = self.get_annotated_label_colors(color=fontcolor)
436
441
  if isinstance(arrow_color, str):
437
- arrow_color = self.get_annotated_contour_colors(color=arrow_color)
442
+ arrow_color = self.get_annotated_label_colors(color=arrow_color)
443
+ # Normalize words_to_omit to lowercase
444
+ if words_to_omit:
445
+ words_to_omit = set(word.lower() for word in words_to_omit)
438
446
 
439
447
  # Calculate the center and radius of the network
440
448
  domain_centroids = {}
441
- for domain, nodes in self.network_graph.domain_to_nodes.items():
449
+ for domain, nodes in self.graph.domain_to_nodes.items():
442
450
  if nodes: # Skip if the domain has no nodes
443
451
  domain_centroids[domain] = self._calculate_domain_centroid(nodes)
444
452
 
453
+ # Initialize empty lists to collect valid indices
454
+ valid_indices = []
455
+ filtered_domain_centroids = {}
456
+ filtered_domain_terms = {}
457
+ # Loop through domain_centroids with index
458
+ for idx, (domain, centroid) in enumerate(domain_centroids.items()):
459
+ # Process the domain term
460
+ terms = self.graph.trimmed_domain_to_term[domain].split(" ")
461
+ # Remove words_to_omit
462
+ if words_to_omit:
463
+ terms = [term for term in terms if term.lower() not in words_to_omit]
464
+ # Trim to max_words
465
+ terms = terms[:max_words]
466
+ # Check if the domain passes the word count condition
467
+ if len(terms) >= min_words:
468
+ # Add to filtered_domain_centroids
469
+ filtered_domain_centroids[domain] = centroid
470
+ # Store the trimmed terms
471
+ filtered_domain_terms[domain] = " ".join(terms)
472
+ # Keep track of the valid index
473
+ valid_indices.append(idx)
474
+
475
+ # If max_labels is specified and less than the available labels
476
+ if max_labels is not None and max_labels < len(filtered_domain_centroids):
477
+ step = len(filtered_domain_centroids) / max_labels
478
+ selected_indices = [int(i * step) for i in range(max_labels)]
479
+ filtered_domain_centroids = {
480
+ k: v
481
+ for i, (k, v) in enumerate(filtered_domain_centroids.items())
482
+ if i in selected_indices
483
+ }
484
+ filtered_domain_terms = {
485
+ k: v
486
+ for i, (k, v) in enumerate(filtered_domain_terms.items())
487
+ if i in selected_indices
488
+ }
489
+ fontcolor = fontcolor[selected_indices]
490
+ arrow_color = arrow_color[selected_indices]
491
+
492
+ # Update the terms in the graph after omitting words and filtering
493
+ for domain, terms in filtered_domain_terms.items():
494
+ self.graph.trimmed_domain_to_term[domain] = terms
495
+
445
496
  # Calculate the bounding box around the network
446
497
  center, radius = _calculate_bounding_box(
447
- self.network_graph.node_coordinates, radius_margin=perimeter_scale
498
+ self.graph.node_coordinates, radius_margin=perimeter_scale
448
499
  )
449
-
450
- # Filter out domains with insufficient words for labeling
451
- filtered_domains = {
452
- domain: centroid
453
- for domain, centroid in domain_centroids.items()
454
- if len(self.network_graph.trimmed_domain_to_term[domain].split(" ")[:num_words])
455
- >= min_words
456
- }
457
500
  # Calculate the best positions for labels around the perimeter
458
- best_label_positions = _best_label_positions(filtered_domains, center, radius, offset)
501
+ best_label_positions = _best_label_positions(
502
+ filtered_domain_centroids, center, radius, offset
503
+ )
459
504
  # Annotate the network with labels
460
505
  for idx, (domain, pos) in enumerate(best_label_positions.items()):
461
- centroid = filtered_domains[domain]
462
- annotations = self.network_graph.trimmed_domain_to_term[domain].split(" ")[:num_words]
506
+ centroid = filtered_domain_centroids[domain]
507
+ annotations = self.graph.trimmed_domain_to_term[domain].split(" ")[:max_words]
463
508
  self.ax.annotate(
464
509
  "\n".join(annotations),
465
510
  xy=centroid,
@@ -473,6 +518,81 @@ class NetworkPlotter:
473
518
  arrowprops=dict(arrowstyle="->", color=arrow_color[idx], linewidth=arrow_linewidth),
474
519
  )
475
520
 
521
+ def plot_sublabel(
522
+ self,
523
+ nodes: list,
524
+ label: str,
525
+ radial_position: float = 0.0,
526
+ perimeter_scale: float = 1.05,
527
+ offset: float = 0.10,
528
+ font: str = "Arial",
529
+ fontsize: int = 10,
530
+ fontcolor: str = "black",
531
+ arrow_linewidth: float = 1,
532
+ arrow_color: str = "black",
533
+ ) -> None:
534
+ """Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
535
+
536
+ Args:
537
+ nodes (List[str]): List of node labels to be used for calculating the centroid.
538
+ label (str): The label to be annotated on the network.
539
+ radial_position (float, optional): Radial angle for positioning the label, in degrees (0-360). Defaults to 0.0.
540
+ perimeter_scale (float, optional): Scale factor for positioning the label around the perimeter. Defaults to 1.05.
541
+ offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
542
+ font (str, optional): Font name for the label. Defaults to "Arial".
543
+ fontsize (int, optional): Font size for the label. Defaults to 10.
544
+ fontcolor (str, optional): Color of the label text. Defaults to "black".
545
+ arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
546
+ arrow_color (str, optional): Color of the arrow. Defaults to "black".
547
+ """
548
+ # Log the plotting parameters
549
+ params.log_plotter(
550
+ sublabel_perimeter_scale=perimeter_scale,
551
+ sublabel_offset=offset,
552
+ sublabel_font=font,
553
+ sublabel_fontsize=fontsize,
554
+ sublabel_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
555
+ sublabel_arrow_linewidth=arrow_linewidth,
556
+ sublabel_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
557
+ sublabel_radial_position=radial_position,
558
+ )
559
+
560
+ # Map node labels to IDs
561
+ node_ids = [
562
+ self.graph.node_label_to_id_map.get(node)
563
+ for node in nodes
564
+ if node in self.graph.node_label_to_id_map
565
+ ]
566
+ if not node_ids or len(node_ids) == 1:
567
+ raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
568
+
569
+ # Calculate the centroid of the provided nodes
570
+ centroid = self._calculate_domain_centroid(node_ids)
571
+ # Calculate the bounding box around the network
572
+ center, radius = _calculate_bounding_box(
573
+ self.graph.node_coordinates, radius_margin=perimeter_scale
574
+ )
575
+ # Convert radial position to radians, adjusting for a 90-degree rotation
576
+ radial_radians = np.deg2rad(radial_position - 90)
577
+ label_position = (
578
+ center[0] + (radius + offset) * np.cos(radial_radians),
579
+ center[1] + (radius + offset) * np.sin(radial_radians),
580
+ )
581
+
582
+ # Annotate the network with the label
583
+ self.ax.annotate(
584
+ label,
585
+ xy=centroid,
586
+ xytext=label_position,
587
+ textcoords="data",
588
+ ha="center",
589
+ va="center",
590
+ fontsize=fontsize,
591
+ fontname=font,
592
+ color=fontcolor,
593
+ arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
594
+ )
595
+
476
596
  def _calculate_domain_centroid(self, nodes: list) -> tuple:
477
597
  """Calculate the most centrally located node in the domain.
478
598
 
@@ -483,7 +603,7 @@ class NetworkPlotter:
483
603
  tuple: A tuple containing the domain's central node coordinates.
484
604
  """
485
605
  # Extract positions of all nodes in the domain
486
- node_positions = self.network_graph.node_coordinates[nodes, :]
606
+ node_positions = self.graph.node_coordinates[nodes, :]
487
607
  # Calculate the pairwise distance matrix between all nodes in the domain
488
608
  distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
489
609
  # Sum the distances for each node to all other nodes in the domain
@@ -508,7 +628,7 @@ class NetworkPlotter:
508
628
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
509
629
  """
510
630
  # Get the initial domain colors for each node
511
- network_colors = self.network_graph.get_domain_colors(**kwargs, random_seed=random_seed)
631
+ network_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
512
632
  if isinstance(nonenriched_color, str):
513
633
  # Convert the non-enriched color from string to RGBA
514
634
  nonenriched_color = mcolors.to_rgba(nonenriched_color)
@@ -535,14 +655,14 @@ class NetworkPlotter:
535
655
  """
536
656
  # Merge all enriched nodes from the domain_to_nodes dictionary
537
657
  enriched_nodes = set()
538
- for _, nodes in self.network_graph.domain_to_nodes.items():
658
+ for _, nodes in self.graph.domain_to_nodes.items():
539
659
  enriched_nodes.update(nodes)
540
660
 
541
661
  # Initialize all node sizes to the non-enriched size
542
- node_sizes = np.full(len(self.network_graph.G.nodes), nonenriched_nodesize)
662
+ node_sizes = np.full(len(self.graph.network.nodes), nonenriched_nodesize)
543
663
  # Set the size for enriched nodes
544
664
  for node in enriched_nodes:
545
- if node in self.network_graph.G.nodes:
665
+ if node in self.graph.network.nodes:
546
666
  node_sizes[node] = enriched_nodesize
547
667
 
548
668
  return node_sizes
@@ -587,12 +707,12 @@ class NetworkPlotter:
587
707
  if isinstance(color, str):
588
708
  # If a single color string is provided, convert it to RGBA and apply to all domains
589
709
  rgba_color = np.array(matplotlib.colors.to_rgba(color))
590
- return np.array([rgba_color for _ in self.network_graph.domain_to_nodes])
710
+ return np.array([rgba_color for _ in self.graph.domain_to_nodes])
591
711
 
592
712
  # Generate colors for each domain using the provided arguments and random seed
593
- node_colors = self.network_graph.get_domain_colors(**kwargs, random_seed=random_seed)
713
+ node_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
594
714
  annotated_colors = []
595
- for _, nodes in self.network_graph.domain_to_nodes.items():
715
+ for _, nodes in self.graph.domain_to_nodes.items():
596
716
  if len(nodes) > 1:
597
717
  # For domains with multiple nodes, choose the brightest color (sum of RGB values)
598
718
  domain_colors = np.array([node_colors[node] for node in nodes])
@@ -662,12 +782,12 @@ def _calculate_bounding_box(
662
782
 
663
783
 
664
784
  def _best_label_positions(
665
- filtered_domains: Dict[str, Any], center: np.ndarray, radius: float, offset: float
785
+ filtered_domain_centroids: Dict[str, Any], center: np.ndarray, radius: float, offset: float
666
786
  ) -> Dict[str, Any]:
667
787
  """Calculate and optimize label positions for clarity.
668
788
 
669
789
  Args:
670
- filtered_domains (dict): Centroids of the filtered domains.
790
+ filtered_domain_centroids (dict): Centroids of the filtered domains.
671
791
  center (np.ndarray): The center coordinates for label positioning.
672
792
  radius (float): The radius for positioning labels around the center.
673
793
  offset (float): The offset distance from the radius for positioning labels.
@@ -675,15 +795,16 @@ def _best_label_positions(
675
795
  Returns:
676
796
  dict: Optimized positions for labels.
677
797
  """
678
- num_domains = len(filtered_domains)
798
+ num_domains = len(filtered_domain_centroids)
679
799
  # Calculate equidistant positions around the center for initial label placement
680
800
  equidistant_positions = _equidistant_angles_around_center(center, radius, offset, num_domains)
681
801
  # Create a mapping of domains to their initial label positions
682
802
  label_positions = {
683
- domain: position for domain, position in zip(filtered_domains.keys(), equidistant_positions)
803
+ domain: position
804
+ for domain, position in zip(filtered_domain_centroids.keys(), equidistant_positions)
684
805
  }
685
806
  # Optimize the label positions to minimize distance to domain centroids
686
- return _optimize_label_positions(label_positions, filtered_domains)
807
+ return _optimize_label_positions(label_positions, filtered_domain_centroids)
687
808
 
688
809
 
689
810
  def _equidistant_angles_around_center(