risk-network 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. risk/__init__.py +1 -1
  2. risk/annotation/__init__.py +10 -0
  3. risk/{annotations/annotations.py → annotation/annotation.py} +44 -44
  4. risk/{annotations → annotation}/io.py +93 -92
  5. risk/{annotations → annotation}/nltk_setup.py +6 -5
  6. risk/log/__init__.py +1 -1
  7. risk/log/parameters.py +26 -27
  8. risk/neighborhoods/__init__.py +0 -1
  9. risk/neighborhoods/api.py +38 -38
  10. risk/neighborhoods/community.py +33 -4
  11. risk/neighborhoods/domains.py +26 -28
  12. risk/neighborhoods/neighborhoods.py +8 -2
  13. risk/neighborhoods/stats/__init__.py +13 -0
  14. risk/neighborhoods/stats/permutation/__init__.py +6 -0
  15. risk/{stats → neighborhoods/stats}/permutation/permutation.py +24 -21
  16. risk/{stats → neighborhoods/stats}/permutation/test_functions.py +4 -4
  17. risk/{stats/stat_tests.py → neighborhoods/stats/tests.py} +62 -54
  18. risk/network/__init__.py +0 -2
  19. risk/network/graph/__init__.py +0 -2
  20. risk/network/graph/api.py +19 -19
  21. risk/network/graph/graph.py +73 -68
  22. risk/{stats/significance.py → network/graph/stats.py} +2 -2
  23. risk/network/graph/summary.py +12 -13
  24. risk/network/io.py +163 -20
  25. risk/network/plotter/__init__.py +0 -2
  26. risk/network/plotter/api.py +1 -1
  27. risk/network/plotter/canvas.py +36 -36
  28. risk/network/plotter/contour.py +14 -15
  29. risk/network/plotter/labels.py +303 -294
  30. risk/network/plotter/network.py +6 -6
  31. risk/network/plotter/plotter.py +8 -10
  32. risk/network/plotter/utils/colors.py +15 -8
  33. risk/network/plotter/utils/layout.py +3 -3
  34. risk/risk.py +6 -6
  35. risk_network-0.0.12.dist-info/METADATA +122 -0
  36. risk_network-0.0.12.dist-info/RECORD +40 -0
  37. {risk_network-0.0.11.dist-info → risk_network-0.0.12.dist-info}/WHEEL +1 -1
  38. risk/annotations/__init__.py +0 -7
  39. risk/network/geometry.py +0 -150
  40. risk/stats/__init__.py +0 -15
  41. risk/stats/permutation/__init__.py +0 -6
  42. risk_network-0.0.11.dist-info/METADATA +0 -798
  43. risk_network-0.0.11.dist-info/RECORD +0 -41
  44. {risk_network-0.0.11.dist-info → risk_network-0.0.12.dist-info/licenses}/LICENSE +0 -0
  45. {risk_network-0.0.11.dist-info → risk_network-0.0.12.dist-info}/top_level.txt +0 -0
risk/network/graph/graph.py CHANGED
@@ -17,7 +17,7 @@ class Graph:
     """A class to represent a network graph and process its nodes and edges.
 
     The Graph class provides functionality to handle and manipulate a network graph,
-    including managing domains, annotations, and node significance data. It also includes methods
+    including managing domains, annotation, and node significance data. It also includes methods
     for transforming and mapping graph coordinates, as well as generating colors based on node
     significance.
     """
@@ -25,7 +25,7 @@ class Graph:
     def __init__(
         self,
         network: nx.Graph,
-        annotations: Dict[str, Any],
+        annotation: Dict[str, Any],
         neighborhoods: Dict[str, Any],
         domains: pd.DataFrame,
         trimmed_domains: pd.DataFrame,
@@ -36,7 +36,7 @@ class Graph:
 
         Args:
             network (nx.Graph): The network graph.
-            annotations (Dict[str, Any]): The annotations associated with the network.
+            annotation (Dict[str, Any]): The annotation associated with the network.
             neighborhoods (Dict[str, Any]): Neighborhood significance data.
             domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
             trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
@@ -65,20 +65,24 @@ class Graph:
         # NOTE: Below this point, instance attributes (i.e., self) will be used!
         self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
         # Unfold the network's 3D coordinates to 2D and extract node coordinates
-        self.network = _unfold_sphere_to_plane(network)
-        self.node_coordinates = _extract_node_coordinates(self.network)
+        self.network = self._unfold_sphere_to_plane(network)
+        self.node_coordinates = self._extract_node_coordinates(self.network)
 
         # NOTE: Only after the above attributes are initialized, we can create the summary
-        self.summary = Summary(annotations, neighborhoods, self)
+        self.summary = Summary(annotation, neighborhoods, self)
 
-    def pop(self, domain_id: str) -> None:
-        """Remove domain ID from instance domain ID mappings. This can be useful for cleaning up
-        domain-specific mappings based on a given criterion, as domain attributes are stored and
-        accessed only in dictionaries modified by this method.
+    def pop(self, domain_id: int) -> List[str]:
+        """Remove a domain ID from the graph and return the corresponding node labels.
 
         Args:
-            key (str): The domain ID key to be removed from each mapping.
+            key (int): The domain ID key to be removed from each mapping.
+
+        Returns:
+            List[str]: A list of node labels associated with the domain ID.
         """
+        # Get the node labels associated with the domain ID
+        node_labels = self.domain_id_to_node_labels_map.get(domain_id, [])
+
         # Define the domain mappings to be updated
         domain_mappings = [
             self.domain_id_to_node_ids_map,
@@ -97,8 +101,9 @@ class Graph:
                 domain_info["domains"].remove(domain_id)
                 domain_info["significances"].pop(domain_id)
 
-    @staticmethod
-    def _create_domain_id_to_node_ids_map(domains: pd.DataFrame) -> Dict[int, Any]:
+        return node_labels
+
+    def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[int, Any]:
         """Create a mapping from domains to the list of node IDs belonging to each domain.
 
         Args:
@@ -115,8 +120,9 @@ class Graph:
 
         return domain_id_to_node_ids_map
 
-    @staticmethod
-    def _create_domain_id_to_domain_terms_map(trimmed_domains: pd.DataFrame) -> Dict[int, Any]:
+    def _create_domain_id_to_domain_terms_map(
+        self, trimmed_domains: pd.DataFrame
+    ) -> Dict[int, Any]:
         """Create a mapping from domain IDs to their corresponding terms.
 
         Args:
@@ -132,8 +138,8 @@ class Graph:
             )
         )
 
-    @staticmethod
     def _create_domain_id_to_domain_info_map(
+        self,
         trimmed_domains: pd.DataFrame,
     ) -> Dict[int, Dict[str, Any]]:
         """Create a mapping from domain IDs to their corresponding full description and significance score,
@@ -163,14 +169,15 @@ class Graph:
             sorted_descriptions, sorted_scores = zip(*descriptions_and_scores)
             # Assign to the domain info map
             domain_info_map[int(domain_id)] = {
-                "full_descriptions": list(sorted_descriptions),
-                "significance_scores": list(sorted_scores),
+                "full_descriptions": sorted_descriptions,
+                "significance_scores": sorted_scores,
             }
 
         return domain_info_map
 
-    @staticmethod
-    def _create_node_id_to_domain_ids_and_significances(domains: pd.DataFrame) -> Dict[int, Dict]:
+    def _create_node_id_to_domain_ids_and_significances(
+        self, domains: pd.DataFrame
+    ) -> Dict[int, Dict]:
         """Creates a dictionary mapping each node ID to its corresponding domain IDs and significance values.
 
         Args:
@@ -216,54 +223,52 @@ class Graph:
 
         return domain_id_to_label_map
 
+    def _unfold_sphere_to_plane(self, G: nx.Graph) -> nx.Graph:
+        """Convert 3D coordinates to 2D by unfolding a sphere to a plane.
+
+        Args:
+            G (nx.Graph): A network graph with 3D coordinates. Each node should have 'x', 'y', and 'z' attributes.
 
-def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
-    """Convert 3D coordinates to 2D by unfolding a sphere to a plane.
+        Returns:
+            nx.Graph: The network graph with updated 2D coordinates (only 'x' and 'y').
+        """
+        for node in G.nodes():
+            if "z" in G.nodes[node]:
+                # Extract 3D coordinates
+                x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
+                # Calculate spherical coordinates theta and phi from Cartesian coordinates
+                r = np.sqrt(x**2 + y**2 + z**2)
+                theta = np.arctan2(y, x)
+                phi = np.arccos(z / r)
+
+                # Convert spherical coordinates to 2D plane coordinates
+                unfolded_x = (theta + np.pi) / (2 * np.pi)  # Shift and normalize theta to [0, 1]
+                unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
+                unfolded_y = (np.pi - phi) / np.pi  # Reflect phi and normalize to [0, 1]
+                # Update network node attributes
+                G.nodes[node]["x"] = unfolded_x
+                G.nodes[node]["y"] = -unfolded_y
+                # Remove the 'z' coordinate as it's no longer needed
+                del G.nodes[node]["z"]
+
+        return G
+
+    def _extract_node_coordinates(self, G: nx.Graph) -> np.ndarray:
+        """Extract 2D coordinates of nodes from the graph.
 
-    Args:
-        G (nx.Graph): A network graph with 3D coordinates. Each node should have 'x', 'y', and 'z' attributes.
+        Args:
+            G (nx.Graph): The network graph with node coordinates.
 
-    Returns:
-        nx.Graph: The network graph with updated 2D coordinates (only 'x' and 'y').
-    """
-    for node in G.nodes():
-        if "z" in G.nodes[node]:
-            # Extract 3D coordinates
-            x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
-            # Calculate spherical coordinates theta and phi from Cartesian coordinates
-            r = np.sqrt(x**2 + y**2 + z**2)
-            theta = np.arctan2(y, x)
-            phi = np.arccos(z / r)
-
-            # Convert spherical coordinates to 2D plane coordinates
-            unfolded_x = (theta + np.pi) / (2 * np.pi)  # Shift and normalize theta to [0, 1]
-            unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
-            unfolded_y = (np.pi - phi) / np.pi  # Reflect phi and normalize to [0, 1]
-            # Update network node attributes
-            G.nodes[node]["x"] = unfolded_x
-            G.nodes[node]["y"] = -unfolded_y
-            # Remove the 'z' coordinate as it's no longer needed
-            del G.nodes[node]["z"]
-
-    return G
-
-
-def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
-    """Extract 2D coordinates of nodes from the graph.
-
-    Args:
-        G (nx.Graph): The network graph with node coordinates.
-
-    Returns:
-        np.ndarray: Array of node coordinates with shape (num_nodes, 2).
-    """
-    # Extract x and y coordinates from graph nodes
-    x_coords = dict(G.nodes.data("x"))
-    y_coords = dict(G.nodes.data("y"))
-    coordinates_dicts = [x_coords, y_coords]
-    # Combine x and y coordinates into a single array
-    node_positions = {
-        node: np.array([coords[node] for coords in coordinates_dicts]) for node in x_coords
-    }
-    node_coordinates = np.vstack(list(node_positions.values()))
-    return node_coordinates
+        Returns:
+            np.ndarray: Array of node coordinates with shape (num_nodes, 2).
+        """
+        # Extract x and y coordinates from graph nodes
+        x_coords = dict(G.nodes.data("x"))
+        y_coords = dict(G.nodes.data("y"))
+        coordinates_dicts = [x_coords, y_coords]
+        # Combine x and y coordinates into a single array
+        node_positions = {
+            node: np.array([coords[node] for coords in coordinates_dicts]) for node in x_coords
+        }
+        node_coordinates = np.vstack(list(node_positions.values()))
+        return node_coordinates
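
For orientation, the sphere-to-plane unfolding that moved into Graph above can be exercised on its own. The following is a minimal, hypothetical sketch: the toy graph and the standalone function wrapper are illustrative assumptions, and only the coordinate math mirrors the package code.

import networkx as nx
import numpy as np

def unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
    """Project 3D node coordinates (x, y, z) onto a 2D plane, mirroring Graph._unfold_sphere_to_plane."""
    for node in G.nodes():
        if "z" in G.nodes[node]:
            x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
            r = np.sqrt(x**2 + y**2 + z**2)
            theta = np.arctan2(y, x)   # azimuth angle
            phi = np.arccos(z / r)     # polar angle
            unfolded_x = (theta + np.pi) / (2 * np.pi)  # normalize azimuth to [0, 1]
            unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
            unfolded_y = (np.pi - phi) / np.pi          # reflect and normalize polar angle
            G.nodes[node]["x"], G.nodes[node]["y"] = unfolded_x, -unfolded_y
            del G.nodes[node]["z"]
    return G

# Toy example: two nodes on the unit sphere (hypothetical data, not from the package)
G = nx.Graph()
G.add_node("a", x=1.0, y=0.0, z=0.0)
G.add_node("b", x=0.0, y=0.0, z=1.0)
G = unfold_sphere_to_plane(G)
coords = np.vstack([[G.nodes[n]["x"], G.nodes[n]["y"]] for n in G.nodes()])
print(coords.shape)  # (2, 2), the same shape contract as _extract_node_coordinates
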
risk/{stats/significance.py → network/graph/stats.py} RENAMED
@@ -1,6 +1,6 @@
 """
-risk/stats/significance
-~~~~~~~~~~~~~~~~~~~~~~~
+risk/network/graph/stats
+~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from typing import Any, Dict, Union
risk/network/graph/summary.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 import pandas as pd
 from statsmodels.stats.multitest import fdrcorrection
 
-from risk.log.console import logger, log_header
+from risk.log.console import log_header, logger
 
 
 class Summary:
@@ -23,18 +23,18 @@ class Summary:
 
     def __init__(
         self,
-        annotations: Dict[str, Any],
+        annotation: Dict[str, Any],
         neighborhoods: Dict[str, Any],
         graph,  # Avoid type hinting Graph to prevent circular imports
     ):
         """Initialize the Results object with analysis components.
 
         Args:
-            annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
+            annotation (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
             neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
             graph (Graph): Graph object representing domain-to-node and node-to-label mappings.
         """
-        self.annotations = annotations
+        self.annotation = annotation
         self.neighborhoods = neighborhoods
         self.graph = graph
 
@@ -81,7 +81,7 @@ class Summary:
         and annotation member information.
         """
         log_header("Loading analysis summary")
-        # Calculate significance and depletion q-values from p-value matrices in `annotations`
+        # Calculate significance and depletion q-values from p-value matrices in annotation
         enrichment_pvals = self.neighborhoods["enrichment_pvals"]
         depletion_pvals = self.neighborhoods["depletion_pvals"]
         enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
@@ -147,10 +147,10 @@ class Summary:
             .reset_index(drop=True)
         )
 
-        # Convert annotations list to a DataFrame for comparison then merge with results
-        ordered_annotations = pd.DataFrame({"Annotation": self.annotations["ordered_annotations"]})
+        # Convert annotation list to a DataFrame for comparison then merge with results
+        ordered_annotation = pd.DataFrame({"Annotation": self.annotation["ordered_annotation"]})
         # Merge to ensure all annotations are present, filling missing rows with defaults
-        results = pd.merge(ordered_annotations, results, on="Annotation", how="left").fillna(
+        results = pd.merge(ordered_annotation, results, on="Annotation", how="left").fillna(
             {
                 "Domain ID": -1,
                 "Annotation Members in Network": "",
@@ -170,8 +170,7 @@ class Summary:
 
         return results
 
-    @staticmethod
-    def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
+    def _calculate_qvalues(self, pvals: np.ndarray) -> np.ndarray:
         """Calculate q-values (FDR) for each row of a p-value matrix.
 
         Args:
@@ -206,7 +205,7 @@ class Summary:
             Minimum significance p-value, significance q-value, depletion p-value, depletion q-value.
         """
         try:
-            annotation_idx = self.annotations["ordered_annotations"].index(description)
+            annotation_idx = self.annotation["ordered_annotation"].index(description)
         except ValueError:
             return None, None, None, None  # Description not found
 
@@ -236,12 +235,12 @@ class Summary:
             str: ';'-separated string of node labels that are associated with the annotation.
         """
         try:
-            annotation_idx = self.annotations["ordered_annotations"].index(description)
+            annotation_idx = self.annotation["ordered_annotation"].index(description)
         except ValueError:
             return ""  # Description not found
 
         # Get the column (safely) from the sparse matrix
-        column = self.annotations["matrix"][:, annotation_idx]
+        column = self.annotation["matrix"][:, annotation_idx]
         # Convert the column to a dense array if needed
         column = column.toarray().ravel()  # Convert to a 1D dense array
         # Get nodes present for the annotation and sort by node label - use np.where on the dense array
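
Summary._calculate_qvalues (now an instance method, per the hunk above) is documented as computing FDR q-values for each row of a p-value matrix, and summary.py imports statsmodels' fdrcorrection. A minimal sketch of that row-wise correction is shown below; the helper name and the use of np.apply_along_axis with default fdrcorrection settings are assumptions, since the diff does not show the method body.

import numpy as np
from statsmodels.stats.multitest import fdrcorrection

def calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
    """Apply Benjamini-Hochberg FDR correction independently to each row of a p-value matrix."""
    # fdrcorrection returns (reject_flags, corrected_pvals); keep only the corrected values
    return np.apply_along_axis(lambda row: fdrcorrection(row)[1], axis=1, arr=pvals)

# Hypothetical p-value matrix: rows are neighborhoods, columns are annotations
pvals = np.array([[0.001, 0.04, 0.20],
                  [0.500, 0.80, 0.03]])
qvals = calculate_qvalues(pvals)
print(qvals.shape)  # (2, 3): one q-value per p-value
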
risk/network/io.py CHANGED
@@ -15,8 +15,7 @@ import networkx as nx
 import numpy as np
 import pandas as pd
 
-from risk.network.geometry import assign_edge_lengths
-from risk.log import params, logger, log_header
+from risk.log import log_header, logger, params
 
 
 class NetworkIO:
@@ -49,8 +48,8 @@ class NetworkIO:
             min_edges_per_node=min_edges_per_node,
         )
 
-    @staticmethod
-    def load_gpickle_network(
+    def load_network_gpickle(
+        self,
         filepath: str,
         compute_sphere: bool = True,
         surface_depth: float = 0.0,
@@ -72,9 +71,9 @@ class NetworkIO:
             surface_depth=surface_depth,
             min_edges_per_node=min_edges_per_node,
         )
-        return networkio._load_gpickle_network(filepath=filepath)
+        return networkio._load_network_gpickle(filepath=filepath)
 
-    def _load_gpickle_network(self, filepath: str) -> nx.Graph:
+    def _load_network_gpickle(self, filepath: str) -> nx.Graph:
         """Private method to load a network from a GPickle file.
 
         Args:
@@ -94,8 +93,8 @@ class NetworkIO:
         # Initialize the graph
         return self._initialize_graph(G)
 
-    @staticmethod
-    def load_networkx_network(
+    def load_network_networkx(
+        self,
         network: nx.Graph,
         compute_sphere: bool = True,
         surface_depth: float = 0.0,
@@ -117,9 +116,9 @@ class NetworkIO:
             surface_depth=surface_depth,
             min_edges_per_node=min_edges_per_node,
         )
-        return networkio._load_networkx_network(network=network)
+        return networkio._load_network_networkx(network=network)
 
-    def _load_networkx_network(self, network: nx.Graph) -> nx.Graph:
+    def _load_network_networkx(self, network: nx.Graph) -> nx.Graph:
         """Private method to load a NetworkX graph.
 
         Args:
@@ -138,8 +137,8 @@ class NetworkIO:
         # Initialize the graph
         return self._initialize_graph(network_copy)
 
-    @staticmethod
-    def load_cytoscape_network(
+    def load_network_cytoscape(
+        self,
         filepath: str,
         source_label: str = "source",
         target_label: str = "target",
@@ -167,14 +166,14 @@ class NetworkIO:
             surface_depth=surface_depth,
             min_edges_per_node=min_edges_per_node,
         )
-        return networkio._load_cytoscape_network(
+        return networkio._load_network_cytoscape(
             filepath=filepath,
             source_label=source_label,
             target_label=target_label,
             view_name=view_name,
         )
 
-    def _load_cytoscape_network(
+    def _load_network_cytoscape(
         self,
         filepath: str,
         source_label: str = "source",
@@ -194,6 +193,7 @@ class NetworkIO:
 
         Raises:
             ValueError: If no matching attribute metadata file is found.
+            KeyError: If the source or target label is not found in the attribute table.
         """
         filetype = "Cytoscape"
         # Log the loading of the Cytoscape file
@@ -307,8 +307,8 @@ class NetworkIO:
             if os.path.exists(tmp_dir):
                 shutil.rmtree(tmp_dir)
 
-    @staticmethod
-    def load_cytoscape_json_network(
+    def load_network_cyjs(
+        self,
         filepath: str,
         source_label: str = "source",
         target_label: str = "target",
@@ -334,13 +334,13 @@ class NetworkIO:
             surface_depth=surface_depth,
             min_edges_per_node=min_edges_per_node,
         )
-        return networkio._load_cytoscape_json_network(
+        return networkio._load_network_cyjs(
             filepath=filepath,
             source_label=source_label,
             target_label=target_label,
         )
 
-    def _load_cytoscape_json_network(self, filepath, source_label="source", target_label="target"):
+    def _load_network_cyjs(self, filepath, source_label="source", target_label="target"):
         """Private method to load a network from a Cytoscape JSON (.cyjs) file.
 
         Args:
@@ -437,7 +437,8 @@ class NetworkIO:
         G.remove_nodes_from(nodes_to_remove)
 
         # Remove isolated nodes
-        G.remove_nodes_from(nx.isolates(G))
+        isolates = list(nx.isolates(G))
+        G.remove_nodes_from(isolates)
 
         # Log the number of nodes and edges before and after cleaning
         num_final_nodes = G.number_of_nodes()
523
524
  Args:
524
525
  G (nx.Graph): The input network graph.
525
526
  """
526
- assign_edge_lengths(
527
+ G_transformed = self._prepare_graph_for_edge_length_assignment(
527
528
  G,
528
529
  compute_sphere=self.compute_sphere,
529
530
  surface_depth=self.surface_depth,
530
531
  )
532
+ self._calculate_and_set_edge_lengths(G_transformed, self.compute_sphere)
533
+
534
+ def _prepare_graph_for_edge_length_assignment(
535
+ self,
536
+ G: nx.Graph,
537
+ compute_sphere: bool = True,
538
+ surface_depth: float = 0.0,
539
+ ) -> nx.Graph:
540
+ """Prepare the graph by normalizing coordinates and optionally mapping nodes to a sphere.
541
+
542
+ Args:
543
+ G (nx.Graph): The input graph.
544
+ compute_sphere (bool): Whether to map nodes to a sphere. Defaults to True.
545
+ surface_depth (float): The surface depth for mapping to a sphere. Defaults to 0.0.
546
+
547
+ Returns:
548
+ nx.Graph: The graph with transformed coordinates.
549
+ """
550
+ self._normalize_graph_coordinates(G)
551
+
552
+ if compute_sphere:
553
+ self._map_to_sphere(G)
554
+ G_depth = self._create_depth(G, surface_depth=surface_depth)
555
+ else:
556
+ G_depth = G
557
+
558
+ return G_depth
559
+
560
+ def _calculate_and_set_edge_lengths(self, G: nx.Graph, compute_sphere: bool) -> None:
561
+ """Compute and assign edge lengths in the graph.
562
+
563
+ Args:
564
+ G (nx.Graph): The input graph.
565
+ compute_sphere (bool): Whether to compute spherical distances.
566
+ """
567
+
568
+ def compute_distance_vectorized(coords, is_sphere):
569
+ """Compute Euclidean or spherical distances between edges in bulk."""
570
+ u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
571
+ if is_sphere:
572
+ u_coords /= np.linalg.norm(u_coords, axis=1, keepdims=True)
573
+ v_coords /= np.linalg.norm(v_coords, axis=1, keepdims=True)
574
+ dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
575
+ return np.arccos(np.clip(dot_products, -1.0, 1.0))
576
+ return np.linalg.norm(u_coords - v_coords, axis=1)
577
+
578
+ # Precompute edge coordinate arrays and compute distances in bulk
579
+ edge_data = np.array(
580
+ [
581
+ [
582
+ np.array([G.nodes[u]["x"], G.nodes[u]["y"], G.nodes[u].get("z", 0)]),
583
+ np.array([G.nodes[v]["x"], G.nodes[v]["y"], G.nodes[v].get("z", 0)]),
584
+ ]
585
+ for u, v in G.edges
586
+ ]
587
+ )
588
+ # Compute distances
589
+ distances = compute_distance_vectorized(edge_data, compute_sphere)
590
+ # Assign Euclidean or spherical distances to edges
591
+ for (u, v), distance in zip(G.edges, distances):
592
+ G.edges[u, v]["length"] = distance
593
+
594
+ def _map_to_sphere(self, G: nx.Graph) -> None:
595
+ """Map the x and y coordinates of graph nodes onto a 3D sphere.
596
+
597
+ Args:
598
+ G (nx.Graph): The input graph with nodes having 'x' and 'y' coordinates.
599
+ """
600
+ # Extract x, y coordinates as a NumPy array
601
+ nodes = list(G.nodes)
602
+ xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in nodes])
603
+ # Normalize coordinates between [0, 1]
604
+ min_vals = xy_coords.min(axis=0)
605
+ max_vals = xy_coords.max(axis=0)
606
+ normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)
607
+ # Convert normalized coordinates to spherical coordinates
608
+ theta = normalized_xy[:, 0] * np.pi * 2
609
+ phi = normalized_xy[:, 1] * np.pi
610
+ # Compute 3D Cartesian coordinates
611
+ x = np.sin(phi) * np.cos(theta)
612
+ y = np.sin(phi) * np.sin(theta)
613
+ z = np.cos(phi)
614
+ # Assign coordinates back to graph nodes in bulk
615
+ xyz_coords = {node: {"x": x[i], "y": y[i], "z": z[i]} for i, node in enumerate(nodes)}
616
+ nx.set_node_attributes(G, xyz_coords)
617
+
618
+ def _normalize_graph_coordinates(self, G: nx.Graph) -> None:
619
+ """Normalize the x and y coordinates of the nodes in the graph to the [0, 1] range.
620
+
621
+ Args:
622
+ G (nx.Graph): The input graph with nodes having 'x' and 'y' coordinates.
623
+ """
624
+ # Extract x, y coordinates from the graph nodes
625
+ xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in G.nodes()])
626
+ # Calculate min and max values for x and y
627
+ min_vals = np.min(xy_coords, axis=0)
628
+ max_vals = np.max(xy_coords, axis=0)
629
+ # Normalize the coordinates to [0, 1]
630
+ normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)
631
+ # Update the node coordinates with the normalized values
632
+ for i, node in enumerate(G.nodes()):
633
+ G.nodes[node]["x"], G.nodes[node]["y"] = normalized_xy[i]
634
+
635
+ def _create_depth(self, G: nx.Graph, surface_depth: float = 0.0) -> nx.Graph:
636
+ """Adjust the 'z' attribute of each node based on the subcluster strengths and normalized surface depth.
637
+
638
+ Args:
639
+ G (nx.Graph): The input graph.
640
+ surface_depth (float): The maximum surface depth to apply for the strongest subcluster.
641
+
642
+ Returns:
643
+ nx.Graph: The graph with adjusted 'z' attribute for each node.
644
+ """
645
+ if surface_depth >= 1.0:
646
+ surface_depth -= 1e-6 # Cap the surface depth to prevent a value of 1.0
647
+
648
+ # Compute subclusters as connected components
649
+ connected_components = list(nx.connected_components(G))
650
+ subcluster_strengths = {}
651
+ max_strength = 0
652
+ # Precompute strengths and track the maximum strength
653
+ for component in connected_components:
654
+ size = len(component)
655
+ max_strength = max(max_strength, size)
656
+ for node in component:
657
+ subcluster_strengths[node] = size
658
+
659
+ # Avoid repeated lookups and computations by pre-fetching node data
660
+ nodes = list(G.nodes(data=True))
661
+ node_updates = {}
662
+ for node, attrs in nodes:
663
+ strength = subcluster_strengths[node]
664
+ normalized_surface_depth = (strength / max_strength) * surface_depth
665
+ x, y, z = attrs["x"], attrs["y"], attrs["z"]
666
+ norm = np.sqrt(x**2 + y**2 + z**2)
667
+ adjusted_z = z - (z / norm) * normalized_surface_depth
668
+ node_updates[node] = {"z": adjusted_z}
669
+
670
+ # Batch update node attributes
671
+ nx.set_node_attributes(G, node_updates)
672
+
673
+ return G
531
674
 
532
675
  def _log_loading(
533
676
  self,
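
The edge-length code above (which replaces the removed risk/network/geometry.py) computes either great-circle or Euclidean distances for all edges in one vectorized pass. The standalone sketch below reproduces only that distance step outside the NetworkIO class; the function name and toy graph are illustrative assumptions, not part of the package API.

import networkx as nx
import numpy as np

def edge_lengths(G: nx.Graph, compute_sphere: bool) -> dict:
    """Return {edge: length} using arc length on the unit sphere or Euclidean distance."""
    # Stack per-edge endpoint coordinates into an (n_edges, 2, 3) array
    coords = np.array(
        [
            [
                [G.nodes[u]["x"], G.nodes[u]["y"], G.nodes[u].get("z", 0)],
                [G.nodes[v]["x"], G.nodes[v]["y"], G.nodes[v].get("z", 0)],
            ]
            for u, v in G.edges
        ],
        dtype=float,
    )
    u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
    if compute_sphere:
        # Normalize to unit vectors, then take the angle between them (arc length on the unit sphere)
        u_coords /= np.linalg.norm(u_coords, axis=1, keepdims=True)
        v_coords /= np.linalg.norm(v_coords, axis=1, keepdims=True)
        dots = np.einsum("ij,ij->i", u_coords, v_coords)
        distances = np.arccos(np.clip(dots, -1.0, 1.0))
    else:
        distances = np.linalg.norm(u_coords - v_coords, axis=1)
    return dict(zip(G.edges, distances))

# Hypothetical example: three nodes on the unit sphere connected in a path
G = nx.path_graph(3)
nx.set_node_attributes(G, {0: {"x": 1.0, "y": 0.0, "z": 0.0},
                           1: {"x": 0.0, "y": 1.0, "z": 0.0},
                           2: {"x": 0.0, "y": 0.0, "z": 1.0}})
print(edge_lengths(G, compute_sphere=True))   # two arc lengths of pi/2
print(edge_lengths(G, compute_sphere=False))  # two Euclidean lengths of sqrt(2)
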
risk/network/plotter/__init__.py CHANGED
@@ -2,5 +2,3 @@
 risk/network/plotter
 ~~~~~~~~~~~~~~~~~~~~
 """
-
-from risk.network.plotter.api import PlotterAPI
risk/network/plotter/api.py CHANGED
@@ -18,7 +18,7 @@ class PlotterAPI:
     The PlotterAPI class provides methods to load and configure Plotter objects for plotting network graphs.
     """
 
-    def __init__() -> None:
+    def __init__(self) -> None:
        pass
 
    def load_plotter(