risk-network 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. risk/__init__.py +1 -1
  2. risk/annotation/__init__.py +10 -0
  3. risk/{annotations/annotations.py → annotation/annotation.py} +62 -102
  4. risk/{annotations → annotation}/io.py +93 -92
  5. risk/annotation/nltk_setup.py +86 -0
  6. risk/log/__init__.py +1 -1
  7. risk/log/parameters.py +26 -27
  8. risk/neighborhoods/__init__.py +0 -1
  9. risk/neighborhoods/api.py +38 -38
  10. risk/neighborhoods/community.py +33 -4
  11. risk/neighborhoods/domains.py +26 -28
  12. risk/neighborhoods/neighborhoods.py +8 -2
  13. risk/neighborhoods/stats/__init__.py +13 -0
  14. risk/neighborhoods/stats/permutation/__init__.py +6 -0
  15. risk/{stats → neighborhoods/stats}/permutation/permutation.py +24 -21
  16. risk/{stats → neighborhoods/stats}/permutation/test_functions.py +5 -4
  17. risk/{stats/stat_tests.py → neighborhoods/stats/tests.py} +62 -54
  18. risk/network/__init__.py +0 -2
  19. risk/network/graph/__init__.py +0 -2
  20. risk/network/graph/api.py +19 -19
  21. risk/network/graph/graph.py +73 -68
  22. risk/{stats/significance.py → network/graph/stats.py} +2 -2
  23. risk/network/graph/summary.py +12 -13
  24. risk/network/io.py +163 -20
  25. risk/network/plotter/__init__.py +0 -2
  26. risk/network/plotter/api.py +1 -1
  27. risk/network/plotter/canvas.py +36 -36
  28. risk/network/plotter/contour.py +14 -15
  29. risk/network/plotter/labels.py +303 -294
  30. risk/network/plotter/network.py +6 -6
  31. risk/network/plotter/plotter.py +8 -10
  32. risk/network/plotter/utils/colors.py +15 -8
  33. risk/network/plotter/utils/layout.py +3 -3
  34. risk/risk.py +6 -7
  35. risk_network-0.0.12.dist-info/METADATA +122 -0
  36. risk_network-0.0.12.dist-info/RECORD +40 -0
  37. {risk_network-0.0.10.dist-info → risk_network-0.0.12.dist-info}/WHEEL +1 -1
  38. risk/annotations/__init__.py +0 -7
  39. risk/network/geometry.py +0 -150
  40. risk/stats/__init__.py +0 -15
  41. risk/stats/permutation/__init__.py +0 -6
  42. risk_network-0.0.10.dist-info/METADATA +0 -798
  43. risk_network-0.0.10.dist-info/RECORD +0 -40
  44. {risk_network-0.0.10.dist-info → risk_network-0.0.12.dist-info/licenses}/LICENSE +0 -0
  45. {risk_network-0.0.10.dist-info → risk_network-0.0.12.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ import numpy as np
9
9
  import pandas as pd
10
10
  from statsmodels.stats.multitest import fdrcorrection
11
11
 
12
- from risk.log.console import logger, log_header
12
+ from risk.log.console import log_header, logger
13
13
 
14
14
 
15
15
  class Summary:
@@ -23,18 +23,18 @@ class Summary:
23
23
 
24
24
  def __init__(
25
25
  self,
26
- annotations: Dict[str, Any],
26
+ annotation: Dict[str, Any],
27
27
  neighborhoods: Dict[str, Any],
28
28
  graph, # Avoid type hinting Graph to prevent circular imports
29
29
  ):
30
30
  """Initialize the Results object with analysis components.
31
31
 
32
32
  Args:
33
- annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
33
+ annotation (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
34
34
  neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
35
35
  graph (Graph): Graph object representing domain-to-node and node-to-label mappings.
36
36
  """
37
- self.annotations = annotations
37
+ self.annotation = annotation
38
38
  self.neighborhoods = neighborhoods
39
39
  self.graph = graph
40
40
 
@@ -81,7 +81,7 @@ class Summary:
81
81
  and annotation member information.
82
82
  """
83
83
  log_header("Loading analysis summary")
84
- # Calculate significance and depletion q-values from p-value matrices in `annotations`
84
+ # Calculate significance and depletion q-values from p-value matrices in annotation
85
85
  enrichment_pvals = self.neighborhoods["enrichment_pvals"]
86
86
  depletion_pvals = self.neighborhoods["depletion_pvals"]
87
87
  enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
@@ -147,10 +147,10 @@ class Summary:
147
147
  .reset_index(drop=True)
148
148
  )
149
149
 
150
- # Convert annotations list to a DataFrame for comparison then merge with results
151
- ordered_annotations = pd.DataFrame({"Annotation": self.annotations["ordered_annotations"]})
150
+ # Convert annotation list to a DataFrame for comparison then merge with results
151
+ ordered_annotation = pd.DataFrame({"Annotation": self.annotation["ordered_annotation"]})
152
152
  # Merge to ensure all annotations are present, filling missing rows with defaults
153
- results = pd.merge(ordered_annotations, results, on="Annotation", how="left").fillna(
153
+ results = pd.merge(ordered_annotation, results, on="Annotation", how="left").fillna(
154
154
  {
155
155
  "Domain ID": -1,
156
156
  "Annotation Members in Network": "",
@@ -170,8 +170,7 @@ class Summary:
170
170
 
171
171
  return results
172
172
 
173
- @staticmethod
174
- def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
173
+ def _calculate_qvalues(self, pvals: np.ndarray) -> np.ndarray:
175
174
  """Calculate q-values (FDR) for each row of a p-value matrix.
176
175
 
177
176
  Args:
@@ -206,7 +205,7 @@ class Summary:
206
205
  Minimum significance p-value, significance q-value, depletion p-value, depletion q-value.
207
206
  """
208
207
  try:
209
- annotation_idx = self.annotations["ordered_annotations"].index(description)
208
+ annotation_idx = self.annotation["ordered_annotation"].index(description)
210
209
  except ValueError:
211
210
  return None, None, None, None # Description not found
212
211
 
@@ -236,12 +235,12 @@ class Summary:
236
235
  str: ';'-separated string of node labels that are associated with the annotation.
237
236
  """
238
237
  try:
239
- annotation_idx = self.annotations["ordered_annotations"].index(description)
238
+ annotation_idx = self.annotation["ordered_annotation"].index(description)
240
239
  except ValueError:
241
240
  return "" # Description not found
242
241
 
243
242
  # Get the column (safely) from the sparse matrix
244
- column = self.annotations["matrix"][:, annotation_idx]
243
+ column = self.annotation["matrix"][:, annotation_idx]
245
244
  # Convert the column to a dense array if needed
246
245
  column = column.toarray().ravel() # Convert to a 1D dense array
247
246
  # Get nodes present for the annotation and sort by node label - use np.where on the dense array
risk/network/io.py CHANGED
@@ -15,8 +15,7 @@ import networkx as nx
15
15
  import numpy as np
16
16
  import pandas as pd
17
17
 
18
- from risk.network.geometry import assign_edge_lengths
19
- from risk.log import params, logger, log_header
18
+ from risk.log import log_header, logger, params
20
19
 
21
20
 
22
21
  class NetworkIO:
@@ -49,8 +48,8 @@ class NetworkIO:
49
48
  min_edges_per_node=min_edges_per_node,
50
49
  )
51
50
 
52
- @staticmethod
53
- def load_gpickle_network(
51
+ def load_network_gpickle(
52
+ self,
54
53
  filepath: str,
55
54
  compute_sphere: bool = True,
56
55
  surface_depth: float = 0.0,
@@ -72,9 +71,9 @@ class NetworkIO:
72
71
  surface_depth=surface_depth,
73
72
  min_edges_per_node=min_edges_per_node,
74
73
  )
75
- return networkio._load_gpickle_network(filepath=filepath)
74
+ return networkio._load_network_gpickle(filepath=filepath)
76
75
 
77
- def _load_gpickle_network(self, filepath: str) -> nx.Graph:
76
+ def _load_network_gpickle(self, filepath: str) -> nx.Graph:
78
77
  """Private method to load a network from a GPickle file.
79
78
 
80
79
  Args:
@@ -94,8 +93,8 @@ class NetworkIO:
94
93
  # Initialize the graph
95
94
  return self._initialize_graph(G)
96
95
 
97
- @staticmethod
98
- def load_networkx_network(
96
+ def load_network_networkx(
97
+ self,
99
98
  network: nx.Graph,
100
99
  compute_sphere: bool = True,
101
100
  surface_depth: float = 0.0,
@@ -117,9 +116,9 @@ class NetworkIO:
117
116
  surface_depth=surface_depth,
118
117
  min_edges_per_node=min_edges_per_node,
119
118
  )
120
- return networkio._load_networkx_network(network=network)
119
+ return networkio._load_network_networkx(network=network)
121
120
 
122
- def _load_networkx_network(self, network: nx.Graph) -> nx.Graph:
121
+ def _load_network_networkx(self, network: nx.Graph) -> nx.Graph:
123
122
  """Private method to load a NetworkX graph.
124
123
 
125
124
  Args:
@@ -138,8 +137,8 @@ class NetworkIO:
138
137
  # Initialize the graph
139
138
  return self._initialize_graph(network_copy)
140
139
 
141
- @staticmethod
142
- def load_cytoscape_network(
140
+ def load_network_cytoscape(
141
+ self,
143
142
  filepath: str,
144
143
  source_label: str = "source",
145
144
  target_label: str = "target",
@@ -167,14 +166,14 @@ class NetworkIO:
167
166
  surface_depth=surface_depth,
168
167
  min_edges_per_node=min_edges_per_node,
169
168
  )
170
- return networkio._load_cytoscape_network(
169
+ return networkio._load_network_cytoscape(
171
170
  filepath=filepath,
172
171
  source_label=source_label,
173
172
  target_label=target_label,
174
173
  view_name=view_name,
175
174
  )
176
175
 
177
- def _load_cytoscape_network(
176
+ def _load_network_cytoscape(
178
177
  self,
179
178
  filepath: str,
180
179
  source_label: str = "source",
@@ -194,6 +193,7 @@ class NetworkIO:
194
193
 
195
194
  Raises:
196
195
  ValueError: If no matching attribute metadata file is found.
196
+ KeyError: If the source or target label is not found in the attribute table.
197
197
  """
198
198
  filetype = "Cytoscape"
199
199
  # Log the loading of the Cytoscape file
@@ -307,8 +307,8 @@ class NetworkIO:
307
307
  if os.path.exists(tmp_dir):
308
308
  shutil.rmtree(tmp_dir)
309
309
 
310
- @staticmethod
311
- def load_cytoscape_json_network(
310
+ def load_network_cyjs(
311
+ self,
312
312
  filepath: str,
313
313
  source_label: str = "source",
314
314
  target_label: str = "target",
@@ -334,13 +334,13 @@ class NetworkIO:
334
334
  surface_depth=surface_depth,
335
335
  min_edges_per_node=min_edges_per_node,
336
336
  )
337
- return networkio._load_cytoscape_json_network(
337
+ return networkio._load_network_cyjs(
338
338
  filepath=filepath,
339
339
  source_label=source_label,
340
340
  target_label=target_label,
341
341
  )
342
342
 
343
- def _load_cytoscape_json_network(self, filepath, source_label="source", target_label="target"):
343
+ def _load_network_cyjs(self, filepath, source_label="source", target_label="target"):
344
344
  """Private method to load a network from a Cytoscape JSON (.cyjs) file.
345
345
 
346
346
  Args:
@@ -437,7 +437,8 @@ class NetworkIO:
437
437
  G.remove_nodes_from(nodes_to_remove)
438
438
 
439
439
  # Remove isolated nodes
440
- G.remove_nodes_from(nx.isolates(G))
440
+ isolates = list(nx.isolates(G))
441
+ G.remove_nodes_from(isolates)
441
442
 
442
443
  # Log the number of nodes and edges before and after cleaning
443
444
  num_final_nodes = G.number_of_nodes()
@@ -523,11 +524,153 @@ class NetworkIO:
523
524
  Args:
524
525
  G (nx.Graph): The input network graph.
525
526
  """
526
- assign_edge_lengths(
527
+ G_transformed = self._prepare_graph_for_edge_length_assignment(
527
528
  G,
528
529
  compute_sphere=self.compute_sphere,
529
530
  surface_depth=self.surface_depth,
530
531
  )
532
+ self._calculate_and_set_edge_lengths(G_transformed, self.compute_sphere)
533
+
534
+ def _prepare_graph_for_edge_length_assignment(
535
+ self,
536
+ G: nx.Graph,
537
+ compute_sphere: bool = True,
538
+ surface_depth: float = 0.0,
539
+ ) -> nx.Graph:
540
+ """Prepare the graph by normalizing coordinates and optionally mapping nodes to a sphere.
541
+
542
+ Args:
543
+ G (nx.Graph): The input graph.
544
+ compute_sphere (bool): Whether to map nodes to a sphere. Defaults to True.
545
+ surface_depth (float): The surface depth for mapping to a sphere. Defaults to 0.0.
546
+
547
+ Returns:
548
+ nx.Graph: The graph with transformed coordinates.
549
+ """
550
+ self._normalize_graph_coordinates(G)
551
+
552
+ if compute_sphere:
553
+ self._map_to_sphere(G)
554
+ G_depth = self._create_depth(G, surface_depth=surface_depth)
555
+ else:
556
+ G_depth = G
557
+
558
+ return G_depth
559
+
560
+ def _calculate_and_set_edge_lengths(self, G: nx.Graph, compute_sphere: bool) -> None:
561
+ """Compute and assign edge lengths in the graph.
562
+
563
+ Args:
564
+ G (nx.Graph): The input graph.
565
+ compute_sphere (bool): Whether to compute spherical distances.
566
+ """
567
+
568
+ def compute_distance_vectorized(coords, is_sphere):
569
+ """Compute Euclidean or spherical distances between edges in bulk."""
570
+ u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
571
+ if is_sphere:
572
+ u_coords /= np.linalg.norm(u_coords, axis=1, keepdims=True)
573
+ v_coords /= np.linalg.norm(v_coords, axis=1, keepdims=True)
574
+ dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
575
+ return np.arccos(np.clip(dot_products, -1.0, 1.0))
576
+ return np.linalg.norm(u_coords - v_coords, axis=1)
577
+
578
+ # Precompute edge coordinate arrays and compute distances in bulk
579
+ edge_data = np.array(
580
+ [
581
+ [
582
+ np.array([G.nodes[u]["x"], G.nodes[u]["y"], G.nodes[u].get("z", 0)]),
583
+ np.array([G.nodes[v]["x"], G.nodes[v]["y"], G.nodes[v].get("z", 0)]),
584
+ ]
585
+ for u, v in G.edges
586
+ ]
587
+ )
588
+ # Compute distances
589
+ distances = compute_distance_vectorized(edge_data, compute_sphere)
590
+ # Assign Euclidean or spherical distances to edges
591
+ for (u, v), distance in zip(G.edges, distances):
592
+ G.edges[u, v]["length"] = distance
593
+
594
+ def _map_to_sphere(self, G: nx.Graph) -> None:
595
+ """Map the x and y coordinates of graph nodes onto a 3D sphere.
596
+
597
+ Args:
598
+ G (nx.Graph): The input graph with nodes having 'x' and 'y' coordinates.
599
+ """
600
+ # Extract x, y coordinates as a NumPy array
601
+ nodes = list(G.nodes)
602
+ xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in nodes])
603
+ # Normalize coordinates between [0, 1]
604
+ min_vals = xy_coords.min(axis=0)
605
+ max_vals = xy_coords.max(axis=0)
606
+ normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)
607
+ # Convert normalized coordinates to spherical coordinates
608
+ theta = normalized_xy[:, 0] * np.pi * 2
609
+ phi = normalized_xy[:, 1] * np.pi
610
+ # Compute 3D Cartesian coordinates
611
+ x = np.sin(phi) * np.cos(theta)
612
+ y = np.sin(phi) * np.sin(theta)
613
+ z = np.cos(phi)
614
+ # Assign coordinates back to graph nodes in bulk
615
+ xyz_coords = {node: {"x": x[i], "y": y[i], "z": z[i]} for i, node in enumerate(nodes)}
616
+ nx.set_node_attributes(G, xyz_coords)
617
+
618
+ def _normalize_graph_coordinates(self, G: nx.Graph) -> None:
619
+ """Normalize the x and y coordinates of the nodes in the graph to the [0, 1] range.
620
+
621
+ Args:
622
+ G (nx.Graph): The input graph with nodes having 'x' and 'y' coordinates.
623
+ """
624
+ # Extract x, y coordinates from the graph nodes
625
+ xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in G.nodes()])
626
+ # Calculate min and max values for x and y
627
+ min_vals = np.min(xy_coords, axis=0)
628
+ max_vals = np.max(xy_coords, axis=0)
629
+ # Normalize the coordinates to [0, 1]
630
+ normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)
631
+ # Update the node coordinates with the normalized values
632
+ for i, node in enumerate(G.nodes()):
633
+ G.nodes[node]["x"], G.nodes[node]["y"] = normalized_xy[i]
634
+
635
+ def _create_depth(self, G: nx.Graph, surface_depth: float = 0.0) -> nx.Graph:
636
+ """Adjust the 'z' attribute of each node based on the subcluster strengths and normalized surface depth.
637
+
638
+ Args:
639
+ G (nx.Graph): The input graph.
640
+ surface_depth (float): The maximum surface depth to apply for the strongest subcluster.
641
+
642
+ Returns:
643
+ nx.Graph: The graph with adjusted 'z' attribute for each node.
644
+ """
645
+ if surface_depth >= 1.0:
646
+ surface_depth -= 1e-6 # Cap the surface depth to prevent a value of 1.0
647
+
648
+ # Compute subclusters as connected components
649
+ connected_components = list(nx.connected_components(G))
650
+ subcluster_strengths = {}
651
+ max_strength = 0
652
+ # Precompute strengths and track the maximum strength
653
+ for component in connected_components:
654
+ size = len(component)
655
+ max_strength = max(max_strength, size)
656
+ for node in component:
657
+ subcluster_strengths[node] = size
658
+
659
+ # Avoid repeated lookups and computations by pre-fetching node data
660
+ nodes = list(G.nodes(data=True))
661
+ node_updates = {}
662
+ for node, attrs in nodes:
663
+ strength = subcluster_strengths[node]
664
+ normalized_surface_depth = (strength / max_strength) * surface_depth
665
+ x, y, z = attrs["x"], attrs["y"], attrs["z"]
666
+ norm = np.sqrt(x**2 + y**2 + z**2)
667
+ adjusted_z = z - (z / norm) * normalized_surface_depth
668
+ node_updates[node] = {"z": adjusted_z}
669
+
670
+ # Batch update node attributes
671
+ nx.set_node_attributes(G, node_updates)
672
+
673
+ return G
531
674
 
532
675
  def _log_loading(
533
676
  self,
@@ -2,5 +2,3 @@
2
2
  risk/network/plotter
3
3
  ~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
-
6
- from risk.network.plotter.api import PlotterAPI
@@ -18,7 +18,7 @@ class PlotterAPI:
18
18
  The PlotterAPI class provides methods to load and configure Plotter objects for plotting network graphs.
19
19
  """
20
20
 
21
- def __init__() -> None:
21
+ def __init__(self) -> None:
22
22
  pass
23
23
 
24
24
  def load_plotter(
@@ -76,7 +76,7 @@ class Canvas:
76
76
  fig = self.ax.figure
77
77
  # Use a tight layout to ensure that title and subtitle do not overlap with the original plot
78
78
  fig.tight_layout(
79
- rect=[0, 0, 1, 1 - title_space_offset]
79
+ rect=(0, 0, 1, 1 - title_space_offset)
80
80
  ) # Leave space above the plot for title
81
81
 
82
82
  # Plot title if provided
@@ -158,7 +158,7 @@ class Canvas:
158
158
  # Calculate the center and radius of the bounding box around the network
159
159
  center, radius = calculate_bounding_box(node_coordinates)
160
160
  # Adjust the center based on user-defined offsets
161
- adjusted_center = _calculate_adjusted_center(
161
+ adjusted_center = self._calculate_adjusted_center(
162
162
  center, radius, center_offset_x, center_offset_y
163
163
  )
164
164
  # Scale the radius by the scale factor
@@ -250,42 +250,42 @@ class Canvas:
250
250
  fill_alpha=fill_alpha,
251
251
  )
252
252
 
253
+ def _calculate_adjusted_center(
254
+ self,
255
+ center: Tuple[float, float],
256
+ radius: float,
257
+ center_offset_x: float = 0.0,
258
+ center_offset_y: float = 0.0,
259
+ ) -> Tuple[float, float]:
260
+ """Calculate the adjusted center for the network perimeter circle based on user-defined offsets.
253
261
 
254
- def _calculate_adjusted_center(
255
- center: Tuple[float, float],
256
- radius: float,
257
- center_offset_x: float = 0.0,
258
- center_offset_y: float = 0.0,
259
- ) -> Tuple[float, float]:
260
- """Calculate the adjusted center for the network perimeter circle based on user-defined offsets.
261
-
262
- Args:
263
- center (Tuple[float, float]): Original center coordinates of the network graph.
264
- radius (float): Radius of the bounding box around the network.
265
- center_offset_x (float, optional): Horizontal offset as a fraction of the diameter.
266
- Negative values shift the center left, positive values shift it right. Allowed
267
- values are in the range [-1, 1]. Defaults to 0.0.
268
- center_offset_y (float, optional): Vertical offset as a fraction of the diameter.
269
- Negative values shift the center down, positive values shift it up. Allowed
270
- values are in the range [-1, 1]. Defaults to 0.0.
262
+ Args:
263
+ center (Tuple[float, float]): Original center coordinates of the network graph.
264
+ radius (float): Radius of the bounding box around the network.
265
+ center_offset_x (float, optional): Horizontal offset as a fraction of the diameter.
266
+ Negative values shift the center left, positive values shift it right. Allowed
267
+ values are in the range [-1, 1]. Defaults to 0.0.
268
+ center_offset_y (float, optional): Vertical offset as a fraction of the diameter.
269
+ Negative values shift the center down, positive values shift it up. Allowed
270
+ values are in the range [-1, 1]. Defaults to 0.0.
271
271
 
272
- Returns:
273
- Tuple[float, float]: Adjusted center coordinates after applying the offsets.
272
+ Returns:
273
+ Tuple[float, float]: Adjusted center coordinates after applying the offsets.
274
274
 
275
- Raises:
276
- ValueError: If the center offsets are outside the valid range [-1, 1].
277
- """
278
- # Flip the y-axis to match the plot orientation
279
- flipped_center_offset_y = -center_offset_y
280
- # Validate the center offsets
281
- if not -1 <= center_offset_x <= 1:
282
- raise ValueError("Horizontal center offset must be in the range [-1, 1].")
283
- if not -1 <= center_offset_y <= 1:
284
- raise ValueError("Vertical center offset must be in the range [-1, 1].")
275
+ Raises:
276
+ ValueError: If the center offsets are outside the valid range [-1, 1].
277
+ """
278
+ # Flip the y-axis to match the plot orientation
279
+ flipped_center_offset_y = -center_offset_y
280
+ # Validate the center offsets
281
+ if not -1 <= center_offset_x <= 1:
282
+ raise ValueError("Horizontal center offset must be in the range [-1, 1].")
283
+ if not -1 <= center_offset_y <= 1:
284
+ raise ValueError("Vertical center offset must be in the range [-1, 1].")
285
285
 
286
- # Calculate adjusted center by applying offset fractions of the diameter
287
- adjusted_center_x = center[0] + (center_offset_x * radius * 2)
288
- adjusted_center_y = center[1] + (flipped_center_offset_y * radius * 2)
286
+ # Calculate adjusted center by applying offset fractions of the diameter
287
+ adjusted_center_x = center[0] + (center_offset_x * radius * 2)
288
+ adjusted_center_y = center[1] + (flipped_center_offset_y * radius * 2)
289
289
 
290
- # Return the adjusted center coordinates
291
- return adjusted_center_x, adjusted_center_y
290
+ # Return the adjusted center coordinates
291
+ return adjusted_center_x, adjusted_center_y
@@ -11,7 +11,7 @@ from scipy import linalg
11
11
  from scipy.ndimage import label
12
12
  from scipy.stats import gaussian_kde
13
13
 
14
- from risk.log import params, logger
14
+ from risk.log import logger, params
15
15
  from risk.network.graph.graph import Graph
16
16
  from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
17
17
 
@@ -213,7 +213,7 @@ class Contour:
213
213
  ]
214
214
  z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
215
215
  # Check if the KDE forms a single connected component
216
- connected = _is_connected(z)
216
+ connected = self._is_connected(z)
217
217
  if not connected:
218
218
  bandwidth += 0.05 # Increase bandwidth slightly and retry
219
219
  except linalg.LinAlgError:
@@ -282,8 +282,8 @@ class Contour:
282
282
  scale_factor: float = 1.0,
283
283
  ids_to_colors: Union[Dict[int, Any], None] = None,
284
284
  random_seed: int = 888,
285
- ) -> np.ndarray:
286
- """Get colors for the contours based on node annotations or a specified colormap.
285
+ ) -> List[Tuple]:
286
+ """Get colors for the contours based on node annotation or a specified colormap.
287
287
 
288
288
  Args:
289
289
  cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
@@ -301,7 +301,7 @@ class Contour:
301
301
  random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
302
302
 
303
303
  Returns:
304
- np.ndarray: Array of RGBA colors for contour annotations.
304
+ List[Tuple]: List of RGBA colors for the contours, one for each domain in the network graph.
305
305
  """
306
306
  return get_annotated_domain_colors(
307
307
  graph=self.graph,
@@ -316,15 +316,14 @@ class Contour:
316
316
  random_seed=random_seed,
317
317
  )
318
318
 
319
+ def _is_connected(self, z: np.ndarray) -> bool:
320
+ """Determine if a thresholded grid represents a single, connected component.
319
321
 
320
- def _is_connected(z: np.ndarray) -> bool:
321
- """Determine if a thresholded grid represents a single, connected component.
322
-
323
- Args:
324
- z (np.ndarray): A binary grid where the component connectivity is evaluated.
322
+ Args:
323
+ z (np.ndarray): A binary grid where the component connectivity is evaluated.
325
324
 
326
- Returns:
327
- bool: True if the grid represents a single connected component, False otherwise.
328
- """
329
- _, num_features = label(z)
330
- return num_features == 1 # Return True if only one connected component is found
325
+ Returns:
326
+ bool: True if the grid represents a single connected component, False otherwise.
327
+ """
328
+ _, num_features = label(z)
329
+ return num_features == 1 # Return True if only one connected component is found