risk-network 0.0.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/network/graph.py ADDED
@@ -0,0 +1,280 @@
1
+ """
2
+ risk/network/graph
3
+ ~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import random
7
+ from collections import defaultdict
8
+ from typing import Any, Dict, List, Tuple
9
+
10
+ import networkx as nx
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib
14
+ import matplotlib.cm as cm
15
+
16
+
17
class NetworkGraph:
    """A class to represent a network graph and process its nodes and edges.

    The NetworkGraph class provides functionality to handle and manipulate a network graph,
    including managing domains, annotations, and node enrichment data. It also includes methods
    for transforming and mapping graph coordinates, as well as generating colors based on node
    enrichment.
    """

    def __init__(
        self,
        network: nx.Graph,
        top_annotations: pd.DataFrame,
        domains: pd.DataFrame,
        trimmed_domains: pd.DataFrame,
        node_label_to_id_map: Dict[str, Any],
        node_enrichment_sums: np.ndarray,
    ):
        """Initialize the NetworkGraph object.

        Args:
            network (nx.Graph): The network graph. Nodes carrying a 'z' attribute are unfolded
                from 3D to 2D in place.
            top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
            domains (pd.DataFrame): DataFrame containing domain data for the network nodes; must
                have a 'primary domain' column indexed by node.
            trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network
                nodes; must have a 'label' column indexed by domain ID.
            node_label_to_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
            node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
        """
        self.top_annotations = top_annotations
        self.domain_to_nodes = self._create_domain_to_nodes_map(domains)
        self.domains = domains
        self.trimmed_domain_to_term = self._create_domain_to_term_map(trimmed_domains)
        self.trimmed_domains = trimmed_domains
        self.node_label_to_id_map = node_label_to_id_map
        self.node_enrichment_sums = node_enrichment_sums
        # NOTE: self.G and self.node_coordinates are populated by _initialize_network
        self.G = None
        self.node_coordinates = None
        self._initialize_network(network)

    def _create_domain_to_nodes_map(self, domains: pd.DataFrame) -> Dict[Any, List[Any]]:
        """Create a mapping from domains to the list of nodes belonging to each domain.

        Args:
            domains (pd.DataFrame): DataFrame containing domain information, including the
                'primary domain' for each node (the DataFrame index is the node).

        Returns:
            dict: A dictionary where keys are domain IDs and values are lists of nodes belonging
            to each domain, in index order.
        """
        # Series.to_dict() maps index label -> value directly. Unlike going through
        # reset_index()["index"], it works whether or not the index is named.
        node_to_domain = domains["primary domain"].to_dict()
        domain_to_nodes = defaultdict(list)
        for node, domain in node_to_domain.items():
            domain_to_nodes[domain].append(node)

        return domain_to_nodes

    def _create_domain_to_term_map(self, trimmed_domains: pd.DataFrame) -> Dict[Any, Any]:
        """Create a mapping from domain IDs to their corresponding terms.

        Args:
            trimmed_domains (pd.DataFrame): DataFrame indexed by domain ID with a 'label' column.

        Returns:
            dict: A dictionary mapping domain IDs to their corresponding terms.
        """
        return dict(zip(trimmed_domains.index, trimmed_domains["label"]))

    def _initialize_network(self, G: nx.Graph) -> None:
        """Unfold the network to 2D and extract node coordinates.

        Args:
            G (nx.Graph): The input network graph; nodes with a 'z' attribute are unfolded in place.
        """
        # Unfold any 3D (x, y, z) node coordinates onto the 2D plane
        self.G = _unfold_sphere_to_plane(G)
        # Rows follow the graph's node iteration order
        self.node_coordinates = _extract_node_coordinates(self.G)

    def get_domain_colors(
        self, min_scale: float = 0.8, max_scale: float = 1.0, random_seed: int = 888, **kwargs
    ) -> np.ndarray:
        """Generate per-node RGBA colors from domain assignments and enrichment.

        Each domain is assigned a color from ``_get_colors``, and every node takes the color of its
        primary domain. Nodes not listed in any domain keep (0, 0, 0, 0). The RGB channels are then
        scaled by the nodes' log-normalized enrichment sums into [min_scale, max_scale]; alpha is
        left as assigned.

        Args:
            min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
            max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
            random_seed (int, optional): Seed for shuffling colormap positions. Defaults to 888.
            **kwargs: Forwarded to ``_get_colors`` (e.g. ``cmap`` or ``color``).

        Returns:
            np.ndarray: Array of shape (num_nodes, 4) with the transformed RGBA colors.
        """
        domain_colors = self._get_domain_colors(**kwargs, random_seed=random_seed)
        node_colors = self._get_composite_node_colors(domain_colors)
        return _transform_colors(
            node_colors,
            self.node_enrichment_sums,
            min_scale=min_scale,
            max_scale=max_scale,
        )

    def _get_composite_node_colors(self, domain_colors: Dict[Any, Any]) -> np.ndarray:
        """Assign each node the color of the domain it belongs to.

        Args:
            domain_colors (dict): Mapping from domain ID to an RGBA color.

        Returns:
            np.ndarray: Array of shape (num_nodes, 4); nodes in no domain stay all zeros.
        """
        num_nodes = len(self.node_coordinates)
        composite_colors = np.zeros((num_nodes, 4))
        for domain, nodes in self.domain_to_nodes.items():
            # Fancy indexing writes this domain's color to all member rows at once.
            # Node values are used directly as row indices.
            composite_colors[nodes] = domain_colors[domain]

        return composite_colors

    def _get_domain_colors(self, **kwargs) -> Dict[str, Any]:
        """Get colors for each domain.

        Args:
            **kwargs: Forwarded to ``_get_colors`` (e.g. ``cmap``, ``color``, ``random_seed``).

        Returns:
            dict: A dictionary mapping domain keys to their corresponding RGBA colors.
        """
        # Only integer-named columns of `domains` count as domains; only their number matters
        num_domains = sum(
            1 for col in self.domains.columns if isinstance(col, (int, np.integer))
        )
        domain_colors = _get_colors(**kwargs, num_colors_to_generate=num_domains)
        # NOTE(review): colors are paired with domain_to_nodes keys in first-appearance order.
        # If there are more keys than integer columns, zip truncates, and
        # _get_composite_node_colors then raises KeyError for the unpaired domains.
        return dict(zip(self.domain_to_nodes.keys(), domain_colors))
169
+
170
+
171
def _transform_colors(
    colors: np.ndarray, enrichment_sums: np.ndarray, min_scale: float = 0.8, max_scale: float = 1.0
) -> np.ndarray:
    """Scale RGB intensities by log-normalized enrichment sums.

    Each row's RGB channels are multiplied by a factor in [min_scale, max_scale]. The factor grows
    with ``log1p(enrichment_sum)`` divided by the largest ``log1p`` value. The alpha channel is
    left unchanged. If every enrichment sum is zero, every row gets ``min_scale`` instead of NaN.

    Args:
        colors (np.ndarray): An array of RGBA colors with shape (num_nodes, 4).
        enrichment_sums (np.ndarray): An array of enrichment sums, one per row of ``colors``.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.

    Returns:
        np.ndarray: A new float array of RGBA colors with adjusted intensities; the input
        ``colors`` array is not modified.
    """
    # Work on a float copy so the caller's array is left untouched
    transformed = np.array(colors, dtype=float)
    log_enrichment_sums = np.log1p(np.asarray(enrichment_sums, dtype=float))  # log1p avoids log(0)
    max_log_sum = log_enrichment_sums.max() if log_enrichment_sums.size else 0.0
    if max_log_sum == 0:
        # All sums are zero (or there are none): dividing would yield 0/0 = NaN colors
        normalized_sums = np.zeros_like(log_enrichment_sums)
    else:
        # Normalize so the most enriched node maps to 1
        normalized_sums = log_enrichment_sums / max_log_sum

    # Map normalized sums onto [min_scale, max_scale]
    scaled_sums = min_scale + (max_scale - min_scale) * normalized_sums
    # Scale the RGB channels (not alpha) of each row by its factor
    transformed[:, :3] *= scaled_sums[:, np.newaxis]
    return transformed
198
+
199
+
200
def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
    """Convert 3D coordinates to 2D by unfolding a sphere to a plane.

    Nodes that have a 'z' attribute get new 'x' in [0, 1] (azimuth, rotated by half a turn) and
    'y' in [-1, 0] (negated, normalized polar angle); their 'z' attribute is deleted. Nodes
    without 'z' are left unchanged. The graph is modified in place and also returned.

    Args:
        G (nx.Graph): A network graph with 3D coordinates. Each node should have 'x', 'y', and 'z' attributes.

    Returns:
        nx.Graph: The same network graph with updated 2D coordinates (only 'x' and 'y').
    """
    for _, attrs in G.nodes(data=True):
        if "z" not in attrs:
            continue

        x, y, z = attrs["x"], attrs["y"], attrs["z"]
        # Spherical coordinates: azimuth theta and polar angle phi
        r = np.sqrt(x**2 + y**2 + z**2)
        theta = np.arctan2(y, x)
        if r > 0:
            # Clip guards against |z / r| drifting just past 1 from floating-point rounding,
            # which would make arccos return NaN
            phi = np.arccos(np.clip(z / r, -1.0, 1.0))
        else:
            # A node at the origin has no defined polar angle; place it on the equator
            phi = np.pi / 2

        # Shift and normalize theta to [0, 1], then rotate by half a turn
        unfolded_x = (theta + np.pi) / (2 * np.pi)
        unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
        # Reflect phi and normalize to [0, 1]
        unfolded_y = (np.pi - phi) / np.pi

        attrs["x"] = unfolded_x
        attrs["y"] = -unfolded_y
        # The 'z' coordinate is no longer needed once the node lies on the plane
        del attrs["z"]

    return G
229
+
230
+
231
def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
    """Extract 2D coordinates of nodes from the graph.

    Args:
        G (nx.Graph): The network graph with 'x' and 'y' node attributes.

    Returns:
        np.ndarray: Array of node coordinates with shape (num_nodes, 2), one row per node in the
        graph's node iteration order. An empty graph yields an array of shape (0, 2).
    """
    # node -> attribute value (None where the attribute is missing)
    x_coords = dict(G.nodes.data("x"))
    y_coords = dict(G.nodes.data("y"))
    if not x_coords:
        # np.vstack raises on an empty sequence; return a well-shaped empty array instead
        return np.empty((0, 2))

    return np.vstack([np.array([x_coords[node], y_coords[node]]) for node in x_coords])
250
+
251
+
252
def _get_colors(
    num_colors_to_generate: int = 10, cmap: str = "hsv", random_seed: int = 888, **kwargs
) -> List[Tuple]:
    """Generate a list of RGBA colors from a specified colormap or use a direct color string.

    When ``color`` is given in kwargs, every entry is that color. Otherwise
    ``num_colors_to_generate`` evenly spaced positions in [0, 1] are shuffled with a seeded RNG
    and looked up in the colormap. The global ``random`` state is not touched.

    Args:
        num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
        cmap (str): The name of the colormap to use (a Colormap instance is also accepted).
            Defaults to "hsv".
        random_seed (int): Seed for the shuffle. Defaults to 888.
        **kwargs: Additional keyword arguments, such as 'color' for a specific color.

    Returns:
        list of tuple: List of RGBA colors.
    """
    if kwargs.get("color"):
        # A direct color string repeats that single color for every entry
        rgba = matplotlib.colors.to_rgba(kwargs["color"])
        return [rgba] * num_colors_to_generate

    if isinstance(cmap, matplotlib.colors.Colormap):
        colormap = cmap
    else:
        try:
            # Matplotlib >= 3.5 colormap registry; cm.get_cmap was removed in 3.9
            colormap = matplotlib.colormaps[cmap]
        except AttributeError:
            colormap = cm.get_cmap(cmap)  # older Matplotlib without matplotlib.colormaps

    # A private Random(seed) yields the same shuffle as random.seed(seed) + random.shuffle,
    # without resetting the global RNG that other code may rely on
    rng = random.Random(random_seed)
    color_positions = np.linspace(0, 1, num_colors_to_generate)
    rng.shuffle(color_positions)
    return [colormap(pos) for pos in color_positions]
risk/network/io.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ risk/network/io
3
+ ~~~~~~~~~~~~~~~
4
+
5
+ This file contains the NetworkIO class for loading and processing network files.
6
+ """
7
+
8
+ import json
9
+ import pickle
10
+ import shutil
11
+ import zipfile
12
+ from xml.dom import minidom
13
+
14
+ import networkx as nx
15
+ import pandas as pd
16
+
17
+ from risk.network.geometry import apply_edge_lengths
18
+ from risk.log import params, print_header
19
+
20
+
21
class NetworkIO:
    """A class for loading, processing, and managing network data.

    The NetworkIO class provides methods to load network data from various formats (GPickle,
    NetworkX, Cytoscape .cys, Cytoscape JSON). Each loaded network is then processed:
    self-loops and low-degree nodes are removed, nodes are relabeled to contiguous integers,
    node and edge attributes are validated, and edge lengths are calculated.
    """

    def __init__(
        self,
        compute_sphere: bool = True,
        surface_depth: float = 0.0,
        distance_metric: str = "dijkstra",
        edge_length_threshold: float = 0.5,
        louvain_resolution: float = 0.1,
        min_edges_per_node: int = 0,
        include_edge_weight: bool = True,
        weight_label: str = "weight",
    ):
        """Store the settings used when loading and processing networks.

        Args:
            compute_sphere (bool, optional): Passed to apply_edge_lengths; also shown as
                'Sphere'/'Plane' in the loading log. Defaults to True.
            surface_depth (float, optional): Passed to apply_edge_lengths. Defaults to 0.0.
            distance_metric (str, optional): Stored on the instance; not used by this class.
                Defaults to "dijkstra".
            edge_length_threshold (float, optional): Stored and printed in the loading log.
                Defaults to 0.5.
            louvain_resolution (float, optional): Stored on the instance; not used by this class.
                Defaults to 0.1.
            min_edges_per_node (int, optional): Nodes whose degree is <= this value are removed
                on load. Defaults to 0.
            include_edge_weight (bool, optional): Whether the .cys loader reads weights; also
                passed to apply_edge_lengths. Defaults to True.
            weight_label (str, optional): Edge attribute / table column holding user edge weights;
                its value is copied into each edge's "weight" attribute. Defaults to "weight".
        """
        self.compute_sphere = compute_sphere
        self.surface_depth = surface_depth
        self.include_edge_weight = include_edge_weight
        self.weight_label = weight_label
        self.distance_metric = distance_metric
        self.edge_length_threshold = edge_length_threshold
        self.louvain_resolution = louvain_resolution
        self.min_edges_per_node = min_edges_per_node

    def load_gpickle_network(self, filepath: str) -> nx.Graph:
        """Load a network from a GPickle file.

        Args:
            filepath (str): Path to the GPickle file.

        Returns:
            nx.Graph: Loaded and processed network.
        """
        filetype = "GPickle"
        params.log_network(filetype=filetype, filepath=str(filepath))
        self._log_loading(filetype, filepath=filepath)
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load files you trust
        with open(filepath, "rb") as f:
            G = pickle.load(f)

        return self._initialize_graph(G)

    def load_networkx_network(self, G: nx.Graph) -> nx.Graph:
        """Load a NetworkX graph.

        Args:
            G (nx.Graph): A NetworkX graph object. It is copied, not modified.

        Returns:
            nx.Graph: Processed network.
        """
        filetype = "NetworkX"
        params.log_network(filetype=filetype)
        self._log_loading(filetype)
        return self._initialize_graph(G)

    def load_cytoscape_network(
        self,
        filepath: str,
        source_label: str = "source",
        target_label: str = "target",
        view_name: str = "",
    ) -> nx.Graph:
        """Load a network from a Cytoscape session (.cys) file.

        The archive is extracted into a temporary directory, which is removed afterwards. Node
        positions come from the selected view's XGMML file. Edges come from the shared edge table,
        including weights from the ``weight_label`` column when include_edge_weight is set.

        Args:
            filepath (str): Path to the Cytoscape file.
            source_label (str, optional): Source node column. Defaults to "source".
            target_label (str, optional): Target node column. Defaults to "target".
            view_name (str, optional): Specific view name to load. Defaults to "" (first view).

        Returns:
            nx.Graph: Loaded and processed network.
        """
        import os
        import tempfile

        filetype = "Cytoscape"
        params.log_network(filetype=filetype, filepath=str(filepath))
        self._log_loading(filetype, filepath=filepath)
        # Extract into a private temporary directory instead of "./". The previous cleanup
        # rmtree'd every top-level archive entry name in the working directory. That could
        # delete a pre-existing directory of the same name and failed on top-level files.
        with tempfile.TemporaryDirectory() as extract_dir:
            with zipfile.ZipFile(filepath, "r") as zip_ref:
                cys_files = zip_ref.namelist()
                zip_ref.extractall(extract_dir)

            # Use the first view, or the view whose file name ends with view_name
            cys_view_files = [cf for cf in cys_files if "/views/" in cf]
            cys_view_file = (
                cys_view_files[0]
                if not view_name
                else [cvf for cvf in cys_view_files if cvf.endswith(view_name + ".xgmml")][0]
            )

            # Parse node positions from the view's <graphics> elements
            cys_view_dom = minidom.parse(os.path.join(extract_dir, cys_view_file))
            node_x_positions = {}
            node_y_positions = {}
            for node in cys_view_dom.getElementsByTagName("node"):
                # Node ID is found in 'label'
                node_id = str(node.attributes["label"].value)
                for child in node.childNodes:
                    if child.nodeType == minidom.Node.ELEMENT_NODE and child.tagName == "graphics":
                        node_x_positions[node_id] = float(child.attributes["x"].value)
                        node_y_positions[node_id] = float(child.attributes["y"].value)

            # Read the shared edge attribute table (from /tables/)
            attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
            attribute_metadata = [
                cf
                for cf in cys_files
                if all(keyword in cf for keyword in attribute_metadata_keywords)
            ][0]
            attribute_table = pd.read_csv(
                os.path.join(extract_dir, attribute_metadata), sep=",", header=None, skiprows=1
            )

        # First remaining row holds the column names; skip the first four rows
        attribute_table.columns = attribute_table.iloc[0]
        attribute_table = attribute_table.iloc[4:, :]
        columns = [source_label, target_label]
        if self.include_edge_weight:
            columns.append(self.weight_label)
        attribute_table = attribute_table[columns].dropna().reset_index(drop=True)

        G = nx.Graph()
        for _, row in attribute_table.iterrows():
            source = row[source_label]
            target = row[target_label]
            if self.include_edge_weight:
                # Store under weight_label so _validate_edges copies it into "weight". Storing it
                # as "weight" directly was overwritten with 1.0 whenever weight_label != "weight".
                G.add_edge(source, target, **{self.weight_label: float(row[self.weight_label])})
            else:
                G.add_edge(source, target)

        # Attach label and view coordinates to every node
        for node in G.nodes():
            G.nodes[node]["label"] = node
            G.nodes[node]["x"] = node_x_positions[node]
            G.nodes[node]["y"] = node_y_positions[node]

        return self._initialize_graph(G)

    def load_cytoscape_json_network(self, filepath, source_label="source", target_label="target"):
        """Load a network from a Cytoscape JSON (.cyjs) file.

        Args:
            filepath (str): Path to the Cytoscape JSON file.
            source_label (str, optional): Source node key in edge data. Default is "source".
            target_label (str, optional): Target node key in edge data. Default is "target".

        Returns:
            NetworkX graph: Loaded and processed network.
        """
        filetype = "Cytoscape JSON"
        params.log_network(filetype=filetype, filepath=str(filepath))
        self._log_loading(filetype, filepath=filepath)
        # JSON text is UTF-8; don't depend on the platform's default encoding
        with open(filepath, "r", encoding="utf-8") as f:
            cyjs_data = json.load(f)

        G = nx.Graph()
        # Nodes: label falls back to the node ID when no 'name' is given
        for node in cyjs_data["elements"]["nodes"]:
            node_data = node["data"]
            node_id = node_data["id"]
            G.add_node(
                node_id,
                label=node_data.get("name", node_id),
                x=node["position"]["x"],
                y=node["position"]["y"],
            )

        # Edges: keep the weight under weight_label so _validate_edges copies it into "weight"
        for edge in cyjs_data["elements"]["edges"]:
            edge_data = edge["data"]
            source = edge_data[source_label]
            target = edge_data[target_label]
            if self.weight_label is not None and self.weight_label in edge_data:
                G.add_edge(source, target, **{self.weight_label: float(edge_data[self.weight_label])})
            else:
                G.add_edge(source, target)

        return self._initialize_graph(G)

    def _initialize_graph(self, G: nx.Graph) -> nx.Graph:
        """Initialize the graph by processing and validating its nodes and edges.

        Args:
            G (nx.Graph): The input NetworkX graph; it is copied, not modified.

        Returns:
            nx.Graph: The processed and validated graph with nodes labeled 0..n-1.
        """
        # Work on a copy so pruning never mutates the caller's graph
        G = G.copy()
        self._remove_invalid_graph_properties(G)
        # IMPORTANT: This is where the graph node labels are converted to integers.
        # Relabel only after pruning so labels stay contiguous (0..n-1). Downstream code such as
        # NetworkGraph._get_composite_node_colors indexes arrays of length n by node id.
        G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
        self._validate_edges(G)
        self._validate_nodes(G)
        self._process_graph(G)
        return G

    def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
        """Remove self-loops and nodes with too few edges, in place.

        Args:
            G (nx.Graph): A NetworkX graph object.
        """
        print(f"Minimum edges per node: {self.min_edges_per_node}")
        # Remove self-loops first. NetworkX counts a self-loop twice in a node's degree, so a
        # node whose only edge is a self-loop would otherwise pass the filter and end up isolated.
        G.remove_edges_from(list(nx.selfloop_edges(G)))
        # Remove nodes whose degree is at or below the threshold
        nodes_with_few_edges = [
            node for node, degree in G.degree() if degree <= self.min_edges_per_node
        ]
        G.remove_nodes_from(nodes_with_few_edges)

    def _validate_edges(self, G: nx.Graph) -> None:
        """Copy the user-defined weight into each edge's "weight" attribute (default 1.0).

        Args:
            G (nx.Graph): A NetworkX graph object.
        """
        missing_weights = 0
        for _, _, data in G.edges(data=True):
            if self.weight_label not in data:
                missing_weights += 1
            # Default to 1.0 if the weight_label attribute is absent
            data["weight"] = data.get(self.weight_label, 1.0)

        if self.include_edge_weight and missing_weights:
            print(f"Total edges missing weights: {missing_weights}")

    def _validate_nodes(self, G: nx.Graph) -> None:
        """Check that every node has 'x', 'y', and 'label' attributes.

        Args:
            G (nx.Graph): A NetworkX graph object.

        Raises:
            AssertionError: If a node lacks a required attribute. NOTE(review): these checks use
                ``assert`` and are skipped when Python runs with -O.
        """
        for node, attrs in G.nodes(data=True):
            assert (
                "x" in attrs and "y" in attrs
            ), f"Node {node} is missing 'x' or 'y' position attributes."
            assert "label" in attrs, f"Node {node} is missing a 'label' attribute."

    def _process_graph(self, G: nx.Graph) -> None:
        """Compute edge lengths via apply_edge_lengths using this loader's settings.

        Args:
            G (nx.Graph): The input network graph.
        """
        apply_edge_lengths(
            G,
            compute_sphere=self.compute_sphere,
            surface_depth=self.surface_depth,
            include_edge_weight=self.include_edge_weight,
        )

    def _log_loading(
        self,
        filetype: str,
        filepath: str = "",
    ) -> None:
        """Print the settings used for loading a network.

        Args:
            filetype (str): The type of the file being loaded (e.g., 'GPickle', 'Cytoscape').
            filepath (str, optional): The path to the file being loaded. Defaults to "".
        """
        print_header("Loading network")
        print(f"Filetype: {filetype}")
        if filepath:
            print(f"Filepath: {filepath}")
        print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
        if self.compute_sphere:
            print(f"Surface depth: {self.surface_depth}")
        print(f"Edge length threshold: {self.edge_length_threshold}")
        print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
        if self.include_edge_weight:
            print(f"Weight label: {self.weight_label}")