risk-network 0.0.5b6__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.5-beta.6"
10
+ __version__ = "0.0.6"
risk/annotations/io.py CHANGED
@@ -36,13 +36,15 @@ class AnnotationsIO:
36
36
  dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
37
37
  """
38
38
  filetype = "JSON"
39
+ # Log the loading of the JSON file
39
40
  params.log_annotations(filepath=filepath, filetype=filetype)
40
41
  _log_loading(filetype, filepath=filepath)
42
+
41
43
  # Open and read the JSON file
42
44
  with open(filepath, "r") as file:
43
45
  annotations_input = json.load(file)
44
46
 
45
- # Process the JSON data and return it in the context of the network
47
+ # Load the annotations into the provided network
46
48
  return load_annotations(network, annotations_input)
47
49
 
48
50
  def load_excel_annotation(
@@ -69,14 +71,18 @@ class AnnotationsIO:
69
71
  linked to the provided network.
70
72
  """
71
73
  filetype = "Excel"
74
+ # Log the loading of the Excel file
72
75
  params.log_annotations(filepath=filepath, filetype=filetype)
73
76
  _log_loading(filetype, filepath=filepath)
77
+
74
78
  # Load the specified sheet from the Excel file
75
79
  df = pd.read_excel(filepath, sheet_name=sheet_name)
76
80
  # Split the nodes column by the specified nodes_delimiter
77
81
  df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(nodes_delimiter))
78
82
  # Convert the DataFrame to a dictionary pairing labels with their corresponding nodes
79
83
  label_node_dict = df.set_index(label_colname)[nodes_colname].to_dict()
84
+
85
+ # Load the annotations into the provided network
80
86
  return load_annotations(network, label_node_dict)
81
87
 
82
88
  def load_csv_annotation(
@@ -101,13 +107,16 @@ class AnnotationsIO:
101
107
  linked to the provided network.
102
108
  """
103
109
  filetype = "CSV"
110
+ # Log the loading of the CSV file
104
111
  params.log_annotations(filepath=filepath, filetype=filetype)
105
112
  _log_loading(filetype, filepath=filepath)
113
+
106
114
  # Load the CSV file into a dictionary
107
115
  annotations_input = _load_matrix_file(
108
116
  filepath, label_colname, nodes_colname, delimiter=",", nodes_delimiter=nodes_delimiter
109
117
  )
110
- # Process and return the annotations in the context of the network
118
+
119
+ # Load the annotations into the provided network
111
120
  return load_annotations(network, annotations_input)
112
121
 
113
122
  def load_tsv_annotation(
@@ -132,15 +141,47 @@ class AnnotationsIO:
132
141
  linked to the provided network.
133
142
  """
134
143
  filetype = "TSV"
144
+ # Log the loading of the TSV file
135
145
  params.log_annotations(filepath=filepath, filetype=filetype)
136
146
  _log_loading(filetype, filepath=filepath)
147
+
137
148
  # Load the TSV file into a dictionary
138
149
  annotations_input = _load_matrix_file(
139
150
  filepath, label_colname, nodes_colname, delimiter="\t", nodes_delimiter=nodes_delimiter
140
151
  )
141
- # Process and return the annotations in the context of the network
152
+
153
+ # Load the annotations into the provided network
142
154
  return load_annotations(network, annotations_input)
143
155
 
156
+ def load_dict_annotation(self, content: Dict[str, Any], network: nx.Graph) -> Dict[str, Any]:
157
+ """Load annotations from a provided dictionary and convert them to a dictionary annotation.
158
+
159
+ Args:
160
+ content (dict): The annotations dictionary to load.
161
+ network (NetworkX graph): The network to which the annotations are related.
162
+
163
+ Returns:
164
+ dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
165
+ """
166
+ # Ensure the input content is a dictionary
167
+ if not isinstance(content, dict):
168
+ raise TypeError(
169
+ f"Expected 'content' to be a dictionary, but got {type(content).__name__} instead."
170
+ )
171
+
172
+ filetype = "Dictionary"
173
+ # Log the loading of the annotations from the dictionary
174
+ params.log_annotations(filepath="In-memory dictionary", filetype=filetype)
175
+ _log_loading(filetype, "In-memory dictionary")
176
+
177
+ # Load the annotations into the provided network
178
+ annotations_dict = load_annotations(network, content)
179
+ # Ensure the output is a dictionary
180
+ if not isinstance(annotations_dict, dict):
181
+ raise ValueError("Expected output to be a dictionary")
182
+
183
+ return annotations_dict
184
+
144
185
 
145
186
  def _load_matrix_file(
146
187
  filepath: str,
risk/log/params.py CHANGED
@@ -7,6 +7,7 @@ import csv
7
7
  import json
8
8
  import warnings
9
9
  from datetime import datetime
10
+ from functools import wraps
10
11
  from typing import Any, Dict
11
12
 
12
13
  import numpy as np
@@ -27,6 +28,7 @@ def _safe_param_export(func):
27
28
  function: The wrapped function with error handling.
28
29
  """
29
30
 
31
+ @wraps(func)
30
32
  def wrapper(*args, **kwargs):
31
33
  try:
32
34
  result = func(*args, **kwargs)
@@ -25,10 +25,14 @@ def calculate_dijkstra_neighborhoods(network: nx.Graph) -> np.ndarray:
25
25
 
26
26
  # Populate the neighborhoods matrix based on Dijkstra's distances
27
27
  for source, targets in all_dijkstra_paths.items():
28
+ max_length = max(targets.values()) if targets else 1 # Handle cases with no targets
28
29
  for target, length in targets.items():
29
- neighborhoods[source, target] = (
30
- 1 if np.isnan(length) or length == 0 else np.sqrt(1 / length)
31
- )
30
+ if np.isnan(length):
31
+ neighborhoods[source, target] = max_length # Use max distance for NaN
32
+ elif length == 0:
33
+ neighborhoods[source, target] = 1 # Assign 1 for zero-length paths (self-loops)
34
+ else:
35
+ neighborhoods[source, target] = 1 / length # Inverse of the distance
32
36
 
33
37
  return neighborhoods
34
38
 
@@ -35,26 +35,31 @@ def define_domains(
35
35
  Returns:
36
36
  pd.DataFrame: DataFrame with the primary domain for each node.
37
37
  """
38
- # Perform hierarchical clustering on the binary enrichment matrix
39
- m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
40
- best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
41
- m, linkage_criterion, linkage_method, linkage_metric
42
- )
43
- try:
44
- Z = linkage(m, method=best_linkage, metric=best_metric)
45
- except ValueError as e:
46
- raise ValueError("No significant annotations found.") from e
38
+ # Check if there's more than one column in significant_neighborhoods_enrichment
39
+ if significant_neighborhoods_enrichment.shape[1] == 1:
40
+ print("Single annotation detected. Skipping clustering.")
41
+ top_annotations["domain"] = 1 # Assign a default domain or handle appropriately
42
+ else:
43
+ # Perform hierarchical clustering on the binary enrichment matrix
44
+ m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
45
+ best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
46
+ m, linkage_criterion, linkage_method, linkage_metric
47
+ )
48
+ try:
49
+ Z = linkage(m, method=best_linkage, metric=best_metric)
50
+ except ValueError as e:
51
+ raise ValueError("No significant annotations found.") from e
47
52
 
48
- print(
49
- f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
50
- )
51
- print(f"Optimal linkage threshold: {round(best_threshold, 3)}")
53
+ print(
54
+ f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
55
+ )
56
+ print(f"Optimal linkage threshold: {round(best_threshold, 3)}")
52
57
 
53
- max_d_optimal = np.max(Z[:, 2]) * best_threshold
54
- domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
55
- # Assign domains to the annotations matrix
56
- top_annotations["domain"] = 0
57
- top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
58
+ max_d_optimal = np.max(Z[:, 2]) * best_threshold
59
+ domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
60
+ # Assign domains to the annotations matrix
61
+ top_annotations["domain"] = 0
62
+ top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
58
63
 
59
64
  # Create DataFrames to store domain information
60
65
  node_to_enrichment = pd.DataFrame(
@@ -63,6 +68,7 @@ def define_domains(
63
68
  )
64
69
  node_to_domain = node_to_enrichment.groupby(level="domain", axis=1).sum()
65
70
 
71
+ # Find the maximum enrichment score for each node
66
72
  t_max = node_to_domain.loc[:, 1:].max(axis=1)
67
73
  t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
68
74
  t_idxmax[t_max == 0] = 0
@@ -4,7 +4,7 @@ risk/neighborhoods/neighborhoods
4
4
  """
5
5
 
6
6
  import warnings
7
- from typing import Any, Dict, Tuple
7
+ from typing import Any, Dict, List, Tuple
8
8
 
9
9
  import networkx as nx
10
10
  import numpy as np
@@ -305,7 +305,7 @@ def _get_node_position(network: nx.Graph, node: Any) -> np.ndarray:
305
305
  )
306
306
 
307
307
 
308
- def _calculate_threshold(average_distances: list, distance_threshold: float) -> float:
308
+ def _calculate_threshold(average_distances: List, distance_threshold: float) -> float:
309
309
  """Calculate the distance threshold based on the given average distances and a percentile threshold.
310
310
 
311
311
  Args:
risk/network/graph.py CHANGED
@@ -28,7 +28,7 @@ class NetworkGraph:
28
28
  top_annotations: pd.DataFrame,
29
29
  domains: pd.DataFrame,
30
30
  trimmed_domains: pd.DataFrame,
31
- node_label_to_id_map: Dict[str, Any],
31
+ node_label_to_node_id_map: Dict[str, Any],
32
32
  node_enrichment_sums: np.ndarray,
33
33
  ):
34
34
  """Initialize the NetworkGraph object.
@@ -38,39 +38,48 @@ class NetworkGraph:
38
38
  top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
39
39
  domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
40
40
  trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
41
- node_label_to_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
41
+ node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
42
42
  node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
43
43
  """
44
44
  self.top_annotations = top_annotations
45
- self.domain_to_nodes = self._create_domain_to_nodes_map(domains)
45
+ self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
46
46
  self.domains = domains
47
- self.trimmed_domain_to_term = self._create_domain_to_term_map(trimmed_domains)
48
- self.trimmed_domains = trimmed_domains
49
- self.node_label_to_id_map = node_label_to_id_map
47
+ self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
48
+ trimmed_domains
49
+ )
50
50
  self.node_enrichment_sums = node_enrichment_sums
51
- # NOTE: self.network and self.node_coordinates are declared in _initialize_network
51
+ self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
52
+ self.node_label_to_enrichment_map = dict(
53
+ zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
54
+ )
55
+ self.node_label_to_node_id_map = node_label_to_node_id_map
56
+ # NOTE: Below this point, instance attributes (i.e., self) will be used!
57
+ self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
58
+ # self.network and self.node_coordinates are properly declared in _initialize_network
52
59
  self.network = None
53
60
  self.node_coordinates = None
54
61
  self._initialize_network(network)
55
62
 
56
- def _create_domain_to_nodes_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
57
- """Create a mapping from domains to the list of nodes belonging to each domain.
63
+ def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
64
+ """Create a mapping from domains to the list of node IDs belonging to each domain.
58
65
 
59
66
  Args:
60
67
  domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
61
68
 
62
69
  Returns:
63
- dict: A dictionary where keys are domain IDs and values are lists of nodes belonging to each domain.
70
+ dict: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
64
71
  """
65
72
  cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
66
- node_to_domains = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
67
- domain_to_nodes = defaultdict(list)
68
- for k, v in node_to_domains.items():
69
- domain_to_nodes[v].append(k)
73
+ node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
74
+ domain_id_to_node_ids_map = defaultdict(list)
75
+ for k, v in node_to_domains_map.items():
76
+ domain_id_to_node_ids_map[v].append(k)
70
77
 
71
- return domain_to_nodes
78
+ return domain_id_to_node_ids_map
72
79
 
73
- def _create_domain_to_term_map(self, trimmed_domains: pd.DataFrame) -> Dict[str, Any]:
80
+ def _create_domain_id_to_domain_terms_map(
81
+ self, trimmed_domains: pd.DataFrame
82
+ ) -> Dict[str, Any]:
74
83
  """Create a mapping from domain IDs to their corresponding terms.
75
84
 
76
85
  Args:
@@ -86,6 +95,20 @@ class NetworkGraph:
86
95
  )
87
96
  )
88
97
 
98
+ def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
99
+ """Create a map from domain IDs to node labels.
100
+
101
+ Returns:
102
+ dict: A dictionary mapping domain IDs to the corresponding node labels.
103
+ """
104
+ domain_id_to_label_map = {}
105
+ for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
106
+ domain_id_to_label_map[domain_id] = [
107
+ self.node_id_to_node_label_map[node_id] for node_id in node_ids
108
+ ]
109
+
110
+ return domain_id_to_label_map
111
+
89
112
  def _initialize_network(self, G: nx.Graph) -> None:
90
113
  """Initialize the network by unfolding it and extracting node coordinates.
91
114
 
@@ -101,31 +124,32 @@ class NetworkGraph:
101
124
 
102
125
  def get_domain_colors(
103
126
  self,
127
+ cmap: str = "gist_rainbow",
128
+ color: Union[str, None] = None,
104
129
  min_scale: float = 0.8,
105
130
  max_scale: float = 1.0,
106
131
  scale_factor: float = 1.0,
107
132
  random_seed: int = 888,
108
- **kwargs,
109
133
  ) -> np.ndarray:
110
- """Generate composite colors for domains.
111
-
112
- This method generates composite colors for nodes based on their enrichment scores and transforms
113
- them to ensure proper alpha values and intensity. For nodes with alpha == 0, it assigns new colors
114
- based on the closest valid neighbors within a specified distance.
134
+ """Generate composite colors for domains based on enrichment or specified colors.
115
135
 
116
136
  Args:
117
- min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
118
- max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
119
- scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
120
- values more. Defaults to 1.0.
121
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
122
- **kwargs: Additional keyword arguments for color generation.
137
+ cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
138
+ color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
139
+ min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
140
+ Controls the dimmest colors. Defaults to 0.8.
141
+ max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
142
+ Controls the brightest colors. Defaults to 1.0.
143
+ scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores.
144
+ A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
145
+ random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments.
146
+ Defaults to 888.
123
147
 
124
148
  Returns:
125
- np.ndarray: Array of transformed colors.
149
+ np.ndarray: Array of RGBA colors generated for each domain, based on enrichment or the specified color.
126
150
  """
127
151
  # Get colors for each domain
128
- domain_colors = self._get_domain_colors(random_seed=random_seed)
152
+ domain_colors = self._get_domain_colors(cmap=cmap, color=color, random_seed=random_seed)
129
153
  # Generate composite colors for nodes
130
154
  node_colors = self._get_composite_node_colors(domain_colors)
131
155
  # Transform colors to ensure proper alpha values and intensity
@@ -153,20 +177,24 @@ class NetworkGraph:
153
177
  # Initialize composite colors array with shape (number of nodes, 4) for RGBA
154
178
  composite_colors = np.zeros((num_nodes, 4))
155
179
  # Assign colors to nodes based on domain_colors
156
- for domain_idx, nodes in self.domain_to_nodes.items():
157
- color = domain_colors[domain_idx]
180
+ for domain_id, nodes in self.domain_id_to_node_ids_map.items():
181
+ color = domain_colors[domain_id]
158
182
  for node in nodes:
159
183
  composite_colors[node] = color
160
184
 
161
185
  return composite_colors
162
186
 
163
187
  def _get_domain_colors(
164
- self, color: Union[str, None] = None, random_seed: int = 888
188
+ self,
189
+ cmap: str = "gist_rainbow",
190
+ color: Union[str, None] = None,
191
+ random_seed: int = 888,
165
192
  ) -> Dict[str, Any]:
166
193
  """Get colors for each domain.
167
194
 
168
195
  Args:
169
- color (Union[str, None], optional): Specific color to use for all domains. If specified, it will overwrite the colormap.
196
+ cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
197
+ color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
170
198
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
171
199
 
172
200
  Returns:
@@ -178,9 +206,9 @@ class NetworkGraph:
178
206
  ]
179
207
  domains = np.sort(numeric_domains)
180
208
  domain_colors = _get_colors(
181
- num_colors_to_generate=len(domains), color=color, random_seed=random_seed
209
+ num_colors_to_generate=len(domains), cmap=cmap, color=color, random_seed=random_seed
182
210
  )
183
- return dict(zip(self.domain_to_nodes.keys(), domain_colors))
211
+ return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
184
212
 
185
213
 
186
214
  def _transform_colors(
@@ -273,17 +301,17 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
273
301
 
274
302
  def _get_colors(
275
303
  num_colors_to_generate: int = 10,
276
- cmap: str = "hsv",
277
- random_seed: int = 888,
304
+ cmap: str = "gist_rainbow",
278
305
  color: Union[str, None] = None,
306
+ random_seed: int = 888,
279
307
  ) -> List[Tuple]:
280
308
  """Generate a list of RGBA colors from a specified colormap or use a direct color string.
281
309
 
282
310
  Args:
283
311
  num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
284
- cmap (str): The name of the colormap to use. Defaults to "hsv".
312
+ cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
313
+ color (str or None, optional): A specific color to use for all generated colors.
285
314
  random_seed (int): Seed for random number generation. Defaults to 888.
286
- color (str, optional): Specific color to use for all nodes. If specified, it will overwrite the colormap.
287
315
  Defaults to None.
288
316
 
289
317
  Returns:
risk/network/io.py CHANGED
@@ -48,6 +48,7 @@ class NetworkIO:
48
48
  self.min_edges_per_node = min_edges_per_node
49
49
  self.include_edge_weight = include_edge_weight
50
50
  self.weight_label = weight_label
51
+ # Log the initialization of the NetworkIO class
51
52
  params.log_network(
52
53
  compute_sphere=compute_sphere,
53
54
  surface_depth=surface_depth,
@@ -98,11 +99,14 @@ class NetworkIO:
98
99
  nx.Graph: Loaded and processed network.
99
100
  """
100
101
  filetype = "GPickle"
102
+ # Log the loading of the GPickle file
101
103
  params.log_network(filetype=filetype, filepath=filepath)
102
104
  self._log_loading(filetype, filepath=filepath)
105
+
103
106
  with open(filepath, "rb") as f:
104
107
  G = pickle.load(f)
105
108
 
109
+ # Initialize the graph
106
110
  return self._initialize_graph(G)
107
111
 
108
112
  @classmethod
@@ -147,8 +151,11 @@ class NetworkIO:
147
151
  nx.Graph: Processed network.
148
152
  """
149
153
  filetype = "NetworkX"
154
+ # Log the loading of the NetworkX graph
150
155
  params.log_network(filetype=filetype)
151
156
  self._log_loading(filetype)
157
+
158
+ # Initialize the graph
152
159
  return self._initialize_graph(network)
153
160
 
154
161
  @classmethod
@@ -213,8 +220,10 @@ class NetworkIO:
213
220
  nx.Graph: Loaded and processed network.
214
221
  """
215
222
  filetype = "Cytoscape"
223
+ # Log the loading of the Cytoscape file
216
224
  params.log_network(filetype=filetype, filepath=str(filepath))
217
225
  self._log_loading(filetype, filepath=filepath)
226
+
218
227
  cys_files = []
219
228
  tmp_dir = ".tmp_cytoscape"
220
229
  # Try / finally to remove unzipped files
@@ -295,6 +304,7 @@ class NetworkIO:
295
304
  node
296
305
  ] # Assuming you have a dict `node_y_positions` for y coordinates
297
306
 
307
+ # Initialize the graph
298
308
  return self._initialize_graph(G)
299
309
 
300
310
  finally:
@@ -354,6 +364,7 @@ class NetworkIO:
354
364
  NetworkX graph: Loaded and processed network.
355
365
  """
356
366
  filetype = "Cytoscape JSON"
367
+ # Log the loading of the Cytoscape JSON file
357
368
  params.log_network(filetype=filetype, filepath=str(filepath))
358
369
  self._log_loading(filetype, filepath=filepath)
359
370
 
@@ -418,29 +429,37 @@ class NetworkIO:
418
429
  return G
419
430
 
420
431
  def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
421
- """Remove invalid properties from the graph.
432
+ """Remove invalid properties from the graph, including self-loops, nodes with fewer edges than
433
+ the threshold, and isolated nodes.
422
434
 
423
435
  Args:
424
436
  G (nx.Graph): A NetworkX graph object.
425
437
  """
426
- # First, Remove self-loop edges to ensure correct edge count
438
+ # Count number of nodes and edges before cleaning
439
+ num_initial_nodes = G.number_of_nodes()
440
+ num_initial_edges = G.number_of_edges()
441
+ # Remove self-loops to ensure correct edge count
427
442
  G.remove_edges_from(list(nx.selfloop_edges(G)))
428
- # Then, iteratively remove nodes with fewer edges than the specified threshold
443
+ # Iteratively remove nodes with fewer edges than the threshold
429
444
  while True:
430
- nodes_to_remove = [
431
- node for node in G.nodes() if G.degree(node) < self.min_edges_per_node
432
- ]
445
+ nodes_to_remove = [node for node in G.nodes if G.degree(node) < self.min_edges_per_node]
433
446
  if not nodes_to_remove:
434
- break # Exit loop if no more nodes to remove
435
-
436
- # Remove the nodes and their associated edges
447
+ break # Exit loop if no more nodes need removal
437
448
  G.remove_nodes_from(nodes_to_remove)
438
449
 
439
- # Optionally: Remove any isolated nodes if needed
450
+ # Remove isolated nodes
440
451
  isolated_nodes = list(nx.isolates(G))
441
452
  if isolated_nodes:
442
453
  G.remove_nodes_from(isolated_nodes)
443
454
 
455
+ # Log the number of nodes and edges before and after cleaning
456
+ num_final_nodes = G.number_of_nodes()
457
+ num_final_edges = G.number_of_edges()
458
+ print(f"Initial node count: {num_initial_nodes}")
459
+ print(f"Final node count: {num_final_nodes}")
460
+ print(f"Initial edge count: {num_initial_edges}")
461
+ print(f"Final edge count: {num_final_edges}")
462
+
444
463
  def _assign_edge_weights(self, G: nx.Graph) -> None:
445
464
  """Assign weights to the edges in the graph.
446
465
 
@@ -502,6 +521,7 @@ class NetworkIO:
502
521
  print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
503
522
  if self.include_edge_weight:
504
523
  print(f"Weight label: {self.weight_label}")
524
+ print(f"Minimum edges per node: {self.min_edges_per_node}")
505
525
  print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
506
526
  if self.compute_sphere:
507
527
  print(f"Surface depth: {self.surface_depth}")