risk-network 0.0.16b0__py3-none-any.whl → 0.0.16b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. risk/__init__.py +2 -2
  2. risk/{_annotation → annotation}/__init__.py +2 -2
  3. risk/{_annotation → annotation}/_nltk_setup.py +3 -3
  4. risk/{_annotation/_annotation.py → annotation/annotation.py} +22 -25
  5. risk/{_annotation/_io.py → annotation/io.py} +4 -4
  6. risk/cluster/__init__.py +8 -0
  7. risk/{_neighborhoods → cluster}/_community.py +37 -37
  8. risk/cluster/api.py +273 -0
  9. risk/{_neighborhoods/_neighborhoods.py → cluster/cluster.py} +127 -98
  10. risk/{_neighborhoods/_domains.py → cluster/label.py} +18 -12
  11. risk/{_log → log}/__init__.py +2 -2
  12. risk/{_log/_console.py → log/console.py} +2 -2
  13. risk/{_log/_parameters.py → log/parameters.py} +20 -10
  14. risk/network/__init__.py +8 -0
  15. risk/network/graph/__init__.py +7 -0
  16. risk/{_network/_graph → network/graph}/_stats.py +2 -2
  17. risk/{_network/_graph → network/graph}/_summary.py +13 -13
  18. risk/{_network/_graph/_api.py → network/graph/api.py} +37 -39
  19. risk/{_network/_graph/_graph.py → network/graph/graph.py} +5 -5
  20. risk/{_network/_io.py → network/io.py} +9 -4
  21. risk/network/plotter/__init__.py +6 -0
  22. risk/{_network/_plotter → network/plotter}/_canvas.py +6 -6
  23. risk/{_network/_plotter → network/plotter}/_contour.py +4 -4
  24. risk/{_network/_plotter → network/plotter}/_labels.py +6 -6
  25. risk/{_network/_plotter → network/plotter}/_network.py +7 -7
  26. risk/{_network/_plotter → network/plotter}/_plotter.py +5 -5
  27. risk/network/plotter/_utils/__init__.py +7 -0
  28. risk/{_network/_plotter/_utils/_colors.py → network/plotter/_utils/colors.py} +3 -3
  29. risk/{_network/_plotter/_utils/_layout.py → network/plotter/_utils/layout.py} +2 -2
  30. risk/{_network/_plotter/_api.py → network/plotter/api.py} +5 -5
  31. risk/{_risk.py → risk.py} +9 -8
  32. risk/stats/__init__.py +6 -0
  33. risk/stats/_stats/__init__.py +11 -0
  34. risk/stats/_stats/permutation/__init__.py +6 -0
  35. risk/stats/_stats/permutation/_test_functions.py +72 -0
  36. risk/{_neighborhoods/_stats/_permutation/_permutation.py → stats/_stats/permutation/permutation.py} +35 -37
  37. risk/{_neighborhoods/_stats/_tests.py → stats/_stats/tests.py} +32 -34
  38. risk/stats/api.py +202 -0
  39. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/METADATA +2 -2
  40. risk_network-0.0.16b2.dist-info/RECORD +43 -0
  41. risk/_neighborhoods/__init__.py +0 -8
  42. risk/_neighborhoods/_api.py +0 -354
  43. risk/_neighborhoods/_stats/__init__.py +0 -11
  44. risk/_neighborhoods/_stats/_permutation/__init__.py +0 -6
  45. risk/_neighborhoods/_stats/_permutation/_test_functions.py +0 -72
  46. risk/_network/__init__.py +0 -8
  47. risk/_network/_graph/__init__.py +0 -7
  48. risk/_network/_plotter/__init__.py +0 -6
  49. risk/_network/_plotter/_utils/__init__.py +0 -7
  50. risk_network-0.0.16b0.dist-info/RECORD +0 -41
  51. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/WHEEL +0 -0
  52. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/licenses/LICENSE +0 -0
  53. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/top_level.txt +0 -0
risk/cluster/api.py ADDED
@@ -0,0 +1,273 @@
1
+ """
2
+ risk/cluster/api
3
+ ~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import copy
7
+ import random
8
+
9
+ import networkx as nx
10
+ import numpy as np
11
+ from scipy.sparse import csr_matrix
12
+
13
+ from ..log import log_header, logger, params
14
+ from .cluster import cluster_method
15
+ from ._community import (
16
+ calculate_greedy_modularity_clusters,
17
+ calculate_label_propagation_clusters,
18
+ calculate_leiden_clusters,
19
+ calculate_louvain_clusters,
20
+ calculate_markov_clustering_clusters,
21
+ calculate_spinglass_clusters,
22
+ calculate_walktrap_clusters,
23
+ )
24
+
25
+
26
+ class ClusterAPI:
27
+ """
28
+ Handles the computation of network clusters using a selectable community detection algorithm.
29
+
30
+ The ClusterAPI class provides methods for explicit clustering algorithms.
31
+ """
32
+
33
+ @cluster_method
34
+ def cluster_greedy(
35
+ self,
36
+ network: nx.Graph,
37
+ fraction_shortest_edges: float = 0.5,
38
+ ) -> csr_matrix:
39
+ """
40
+ Compute greedy modularity clusters for the given network.
41
+
42
+ Args:
43
+ network (nx.Graph): The network graph to cluster.
44
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
45
+ retained when building the clustering subgraph. Defaults to 0.5.
46
+
47
+ Returns:
48
+ csr_matrix: Sparse matrix representing cluster assignments.
49
+ """
50
+ self._log_clustering_params(
51
+ clustering="greedy",
52
+ fraction_shortest_edges=fraction_shortest_edges,
53
+ )
54
+ network = copy.copy(network)
55
+ return calculate_greedy_modularity_clusters(
56
+ network,
57
+ fraction_shortest_edges=fraction_shortest_edges,
58
+ )
59
+
60
+ @cluster_method
61
+ def cluster_labelprop(
62
+ self,
63
+ network: nx.Graph,
64
+ fraction_shortest_edges: float = 0.5,
65
+ ) -> csr_matrix:
66
+ """
67
+ Compute label propagation clusters for the given network.
68
+
69
+ Args:
70
+ network (nx.Graph): The network graph to cluster.
71
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
72
+ retained when building the clustering subgraph. Defaults to 0.5.
73
+
74
+ Returns:
75
+ csr_matrix: Sparse matrix representing cluster assignments.
76
+ """
77
+ self._log_clustering_params(
78
+ clustering="labelprop",
79
+ fraction_shortest_edges=fraction_shortest_edges,
80
+ )
81
+ network = copy.copy(network)
82
+ return calculate_label_propagation_clusters(
83
+ network,
84
+ fraction_shortest_edges=fraction_shortest_edges,
85
+ )
86
+
87
+ @cluster_method
88
+ def cluster_leiden(
89
+ self,
90
+ network: nx.Graph,
91
+ fraction_shortest_edges: float = 0.5,
92
+ resolution: float = 1.0,
93
+ random_seed: int = 888,
94
+ ) -> csr_matrix:
95
+ """
96
+ Compute Leiden clusters for the given network.
97
+
98
+ Args:
99
+ network (nx.Graph): The network graph to cluster.
100
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
101
+ retained when building the clustering subgraph. Defaults to 0.5.
102
+ resolution (float, optional): Resolution parameter for Leiden algorithm. Defaults to 1.0.
103
+ random_seed (int, optional): Random seed for reproducibility. Defaults to 888.
104
+
105
+ Returns:
106
+ csr_matrix: Sparse matrix representing cluster assignments.
107
+ """
108
+ self._log_clustering_params(
109
+ clustering="leiden",
110
+ fraction_shortest_edges=fraction_shortest_edges,
111
+ resolution=resolution,
112
+ random_seed=random_seed,
113
+ )
114
+ # Additional logging for specific parameters
115
+ logger.debug(f"Resolution: {resolution}")
116
+ logger.debug(f"Random seed: {random_seed}")
117
+ # Set random seed for reproducibility
118
+ random.seed(random_seed)
119
+ np.random.seed(random_seed)
120
+ network = copy.copy(network)
121
+ return calculate_leiden_clusters(
122
+ network,
123
+ fraction_shortest_edges=fraction_shortest_edges,
124
+ resolution=resolution,
125
+ random_seed=random_seed,
126
+ )
127
+
128
+ @cluster_method
129
+ def cluster_louvain(
130
+ self,
131
+ network: nx.Graph,
132
+ fraction_shortest_edges: float = 0.5,
133
+ resolution: float = 0.1,
134
+ random_seed: int = 888,
135
+ ) -> csr_matrix:
136
+ """
137
+ Compute Louvain clusters for the given network.
138
+
139
+ Args:
140
+ network (nx.Graph): The network graph to cluster.
141
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
142
+ retained when building the clustering subgraph. Defaults to 0.5.
143
+ resolution (float, optional): Resolution parameter for Louvain algorithm. Defaults to 0.1.
144
+ random_seed (int, optional): Random seed for reproducibility. Defaults to 888.
145
+
146
+ Returns:
147
+ csr_matrix: Sparse matrix representing cluster assignments.
148
+ """
149
+ self._log_clustering_params(
150
+ clustering="louvain",
151
+ fraction_shortest_edges=fraction_shortest_edges,
152
+ resolution=resolution,
153
+ random_seed=random_seed,
154
+ )
155
+ # Additional logging for specific parameters
156
+ logger.debug(f"Resolution: {resolution}")
157
+ logger.debug(f"Random seed: {random_seed}")
158
+ # Set random seed for reproducibility
159
+ random.seed(random_seed)
160
+ np.random.seed(random_seed)
161
+ network = copy.copy(network)
162
+ return calculate_louvain_clusters(
163
+ network,
164
+ fraction_shortest_edges=fraction_shortest_edges,
165
+ resolution=resolution,
166
+ random_seed=random_seed,
167
+ )
168
+
169
+ @cluster_method
170
+ def cluster_markov(
171
+ self,
172
+ network: nx.Graph,
173
+ fraction_shortest_edges: float = 0.5,
174
+ ) -> csr_matrix:
175
+ """
176
+ Compute Markov clustering clusters for the given network.
177
+
178
+ Args:
179
+ network (nx.Graph): The network graph to cluster.
180
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
181
+ retained when building the clustering subgraph. Defaults to 0.5.
182
+
183
+ Returns:
184
+ csr_matrix: Sparse matrix representing cluster assignments.
185
+ """
186
+ self._log_clustering_params(
187
+ clustering="markov",
188
+ fraction_shortest_edges=fraction_shortest_edges,
189
+ )
190
+ network = copy.copy(network)
191
+ return calculate_markov_clustering_clusters(
192
+ network,
193
+ fraction_shortest_edges=fraction_shortest_edges,
194
+ )
195
+
196
+ @cluster_method
197
+ def cluster_spinglass(
198
+ self,
199
+ network: nx.Graph,
200
+ fraction_shortest_edges: float = 0.5,
201
+ ) -> csr_matrix:
202
+ """
203
+ Compute spinglass clusters for the given network.
204
+
205
+ Args:
206
+ network (nx.Graph): The network graph to cluster.
207
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
208
+ retained when building the clustering subgraph. Defaults to 0.5.
209
+
210
+ Returns:
211
+ csr_matrix: Sparse matrix representing cluster assignments.
212
+ """
213
+ self._log_clustering_params(
214
+ clustering="spinglass",
215
+ fraction_shortest_edges=fraction_shortest_edges,
216
+ )
217
+ network = copy.copy(network)
218
+ return calculate_spinglass_clusters(
219
+ network,
220
+ fraction_shortest_edges=fraction_shortest_edges,
221
+ )
222
+
223
+ @cluster_method
224
+ def cluster_walktrap(
225
+ self,
226
+ network: nx.Graph,
227
+ fraction_shortest_edges: float = 0.5,
228
+ ) -> csr_matrix:
229
+ """
230
+ Compute walktrap clusters for the given network.
231
+
232
+ Args:
233
+ network (nx.Graph): The network graph to cluster.
234
+ fraction_shortest_edges (float, optional): Rank-based fraction (0, 1] of the shortest edges
235
+ retained when building the clustering subgraph. Defaults to 0.5.
236
+
237
+ Returns:
238
+ csr_matrix: Sparse matrix representing cluster assignments.
239
+ """
240
+ self._log_clustering_params(
241
+ clustering="walktrap",
242
+ fraction_shortest_edges=fraction_shortest_edges,
243
+ )
244
+ network = copy.copy(network)
245
+ return calculate_walktrap_clusters(
246
+ network,
247
+ fraction_shortest_edges=fraction_shortest_edges,
248
+ )
249
+
250
+ def _log_clustering_params(
251
+ self,
252
+ clustering: str,
253
+ fraction_shortest_edges: float,
254
+ **kwargs,
255
+ ) -> None:
256
+ """
257
+ Log clustering parameters for debugging and reproducibility.
258
+
259
+ Args:
260
+ clustering (str): The display name of the clustering method.
261
+ fraction_shortest_edges (float): Rank-based fraction (0, 1] of the shortest edges used for clustering.
262
+ **kwargs: Additional clustering parameters to log.
263
+ """
264
+ log_header("Computing clusters")
265
+ # Log and display cluster settings
266
+ logger.debug(f"Clustering: '{clustering}'")
267
+ logger.debug(f"Edge length threshold: {fraction_shortest_edges}")
268
+ # Log clustering parameters
269
+ params.log_clusters(
270
+ clustering=clustering,
271
+ fraction_shortest_edges=fraction_shortest_edges,
272
+ **kwargs,
273
+ )
risk/{_neighborhoods/_neighborhoods.py → cluster/cluster.py} RENAMED
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_neighborhoods/_neighborhoods
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/cluster/cluster
3
+ ~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  import random
@@ -13,116 +13,124 @@ from scipy.sparse import csr_matrix
13
13
  from sklearn.exceptions import DataConversionWarning
14
14
  from sklearn.metrics.pairwise import cosine_similarity
15
15
 
16
- from .._log import logger
16
+ from ..log import logger
17
17
  from ._community import (
18
- calculate_greedy_modularity_neighborhoods,
19
- calculate_label_propagation_neighborhoods,
20
- calculate_leiden_neighborhoods,
21
- calculate_louvain_neighborhoods,
22
- calculate_markov_clustering_neighborhoods,
23
- calculate_spinglass_neighborhoods,
24
- calculate_walktrap_neighborhoods,
18
+ calculate_greedy_modularity_clusters,
19
+ calculate_label_propagation_clusters,
20
+ calculate_leiden_clusters,
21
+ calculate_louvain_clusters,
22
+ calculate_markov_clustering_clusters,
23
+ calculate_spinglass_clusters,
24
+ calculate_walktrap_clusters,
25
25
  )
26
26
 
27
27
  # Suppress DataConversionWarning
28
28
  warnings.filterwarnings(action="ignore", category=DataConversionWarning)
29
29
 
30
30
 
31
- def get_network_neighborhoods(
31
+ def cluster_method(func):
32
+ """
33
+ Decorator for clustering functions that normalizes the returned cluster matrix.
33
+ Seeding and network copying are left to the wrapped function; the wrapper only post-processes its output.
35
+
36
+ Args:
37
+ func (callable): The clustering function to be decorated.
38
+
39
+ Returns:
40
+ callable: The wrapped clustering function with added functionality.
41
+ """
42
+
43
+ def wrapper(*args, **kwargs):
44
+ """
45
+ Call the wrapped clustering function and normalize its output.
46
+
47
+ Args:
48
+ *args: Positional arguments for the clustering function.
49
+ **kwargs: Keyword arguments for the clustering function.
50
+
51
+ Returns:
52
+ csr_matrix: Sparse matrix representing cluster assignments.
53
+ """
54
+ clusters = func(*args, **kwargs)
55
+ return _set_max_row_value_to_one_sparse(clusters)
56
+
57
+ return wrapper
58
+
59
+
60
+ def get_network_clusters(
32
61
  network: nx.Graph,
33
- distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
34
- fraction_shortest_edges: Union[float, List, Tuple, np.ndarray] = 1.0,
62
+ clustering: str = "louvain",
63
+ fraction_shortest_edges: float = 0.5,
35
64
  louvain_resolution: float = 0.1,
36
65
  leiden_resolution: float = 1.0,
37
66
  random_seed: int = 888,
38
67
  ) -> csr_matrix:
39
68
  """
40
- Calculate the combined neighborhoods for each node using sparse matrices.
69
+ Calculate clusters for the network using a single method.
41
70
 
42
71
  Args:
43
72
  network (nx.Graph): The network graph.
44
- distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use.
45
- fraction_shortest_edges (float, List, Tuple, or np.ndarray, optional): Shortest edge rank fraction thresholds.
46
- louvain_resolution (float, optional): Resolution parameter for the Louvain method.
47
- leiden_resolution (float, optional): Resolution parameter for the Leiden method.
48
- random_seed (int, optional): Random seed for methods requiring random initialization.
73
+ clustering (str, optional): The clustering method ('greedy', 'labelprop', 'leiden', 'louvain', 'markov', 'spinglass', 'walktrap').
74
+ fraction_shortest_edges (float, optional): Fraction of shortest edges to consider for creating subgraphs. Defaults to 0.5.
75
+ louvain_resolution (float, optional): Resolution for Louvain.
76
+ leiden_resolution (float, optional): Resolution for Leiden.
77
+ random_seed (int, optional): Random seed.
49
78
 
50
79
  Returns:
51
- csr_matrix: The combined neighborhood matrix.
80
+ csr_matrix: Sparse cluster matrix.
52
81
 
53
82
  Raises:
54
- ValueError: If the number of distance metrics does not match the number of edge length thresholds.
83
+ ValueError: If invalid clustering method is provided.
55
84
  """
56
- # Set random seed for reproducibility
85
+ # Set random seed for cluster reproducibility
57
86
  random.seed(random_seed)
58
87
  np.random.seed(random_seed)
59
88
 
60
- # Ensure distance_metric is a list for multi-algorithm handling
61
- if isinstance(distance_metric, (str, np.ndarray)):
62
- distance_metric = [distance_metric]
63
- # Ensure fraction_shortest_edges is a list for multi-threshold handling
64
- if isinstance(fraction_shortest_edges, (float, int)):
65
- fraction_shortest_edges = [fraction_shortest_edges] * len(distance_metric)
66
- # Validate matching lengths of distance metrics and thresholds
67
- if len(distance_metric) != len(fraction_shortest_edges):
89
+ clusters = None
90
+ # Determine clustering method and compute clusters
91
+ if clustering == "greedy":
92
+ clusters = calculate_greedy_modularity_clusters(
93
+ network, fraction_shortest_edges=fraction_shortest_edges
94
+ )
95
+ elif clustering == "labelprop":
96
+ clusters = calculate_label_propagation_clusters(
97
+ network, fraction_shortest_edges=fraction_shortest_edges
98
+ )
99
+ elif clustering == "leiden":
100
+ clusters = calculate_leiden_clusters(
101
+ network,
102
+ resolution=leiden_resolution,
103
+ fraction_shortest_edges=fraction_shortest_edges,
104
+ random_seed=random_seed,
105
+ )
106
+ elif clustering == "louvain":
107
+ clusters = calculate_louvain_clusters(
108
+ network,
109
+ resolution=louvain_resolution,
110
+ fraction_shortest_edges=fraction_shortest_edges,
111
+ random_seed=random_seed,
112
+ )
113
+ elif clustering == "markov":
114
+ clusters = calculate_markov_clustering_clusters(
115
+ network, fraction_shortest_edges=fraction_shortest_edges
116
+ )
117
+ elif clustering == "spinglass":
118
+ clusters = calculate_spinglass_clusters(
119
+ network, fraction_shortest_edges=fraction_shortest_edges
120
+ )
121
+ elif clustering == "walktrap":
122
+ clusters = calculate_walktrap_clusters(
123
+ network, fraction_shortest_edges=fraction_shortest_edges
124
+ )
125
+ else:
68
126
  raise ValueError(
69
- "The number of distance metrics must match the number of edge length thresholds."
127
+ "Invalid clustering method. Choose from: 'greedy', 'labelprop', 'leiden', 'louvain', 'markov', 'spinglass', 'walktrap'."
70
128
  )
71
129
 
72
- # Initialize a sparse LIL matrix for incremental updates
73
- num_nodes = network.number_of_nodes()
74
- # Initialize a sparse matrix with the same shape as the network
75
- combined_neighborhoods = csr_matrix((num_nodes, num_nodes), dtype=np.uint8)
76
- # Loop through each distance metric and corresponding edge rank fraction
77
- for metric, percentile in zip(distance_metric, fraction_shortest_edges):
78
- # Compute neighborhoods for the specified metric
79
- if metric == "greedy_modularity":
80
- neighborhoods = calculate_greedy_modularity_neighborhoods(
81
- network, fraction_shortest_edges=percentile
82
- )
83
- elif metric == "label_propagation":
84
- neighborhoods = calculate_label_propagation_neighborhoods(
85
- network, fraction_shortest_edges=percentile
86
- )
87
- elif metric == "leiden":
88
- neighborhoods = calculate_leiden_neighborhoods(
89
- network,
90
- resolution=leiden_resolution,
91
- fraction_shortest_edges=percentile,
92
- random_seed=random_seed,
93
- )
94
- elif metric == "louvain":
95
- neighborhoods = calculate_louvain_neighborhoods(
96
- network,
97
- resolution=louvain_resolution,
98
- fraction_shortest_edges=percentile,
99
- random_seed=random_seed,
100
- )
101
- elif metric == "markov_clustering":
102
- neighborhoods = calculate_markov_clustering_neighborhoods(
103
- network, fraction_shortest_edges=percentile
104
- )
105
- elif metric == "spinglass":
106
- neighborhoods = calculate_spinglass_neighborhoods(
107
- network, fraction_shortest_edges=percentile
108
- )
109
- elif metric == "walktrap":
110
- neighborhoods = calculate_walktrap_neighborhoods(
111
- network, fraction_shortest_edges=percentile
112
- )
113
- else:
114
- raise ValueError(
115
- "Invalid distance metric. Choose from: 'greedy_modularity', 'label_propagation',"
116
- "'leiden', 'louvain', 'markov_clustering', 'spinglass', 'walktrap'."
117
- )
118
-
119
- # Add the sparse neighborhood matrix
120
- combined_neighborhoods += neighborhoods
121
-
122
- # Ensure maximum value in each row is set to 1
123
- combined_neighborhoods = _set_max_row_value_to_one_sparse(combined_neighborhoods)
130
+ # Ensure maximum per row set to 1
131
+ clusters = _set_max_row_value_to_one_sparse(clusters)
124
132
 
125
- return combined_neighborhoods
133
+ return clusters
126
134
 
127
135
 
128
136
  def _set_max_row_value_to_one_sparse(matrix: csr_matrix) -> csr_matrix:
@@ -144,27 +152,29 @@ def _set_max_row_value_to_one_sparse(matrix: csr_matrix) -> csr_matrix:
144
152
  return matrix
145
153
 
146
154
 
147
- def process_neighborhoods(
155
+ def process_significant_clusters(
148
156
  network: nx.Graph,
149
- neighborhoods: Dict[str, Any],
157
+ significant_clusters: Dict[str, Any],
150
158
  impute_depth: int = 0,
151
159
  prune_threshold: float = 0.0,
152
160
  ) -> Dict[str, Any]:
153
161
  """
154
- Process neighborhoods based on the imputation and pruning settings.
162
+ Process clusters based on the imputation and pruning settings.
155
163
 
156
164
  Args:
157
165
  network (nx.Graph): The network data structure used for imputing and pruning neighbors.
158
- neighborhoods (Dict[str, Any]): Dictionary containing 'significance_matrix', 'significant_binary_significance_matrix', and 'significant_significance_matrix'.
166
+ significant_clusters (Dict[str, Any]): Dictionary containing 'significance_matrix', 'significant_binary_significance_matrix', and 'significant_significance_matrix'.
159
167
  impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
160
168
  prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
161
169
 
162
170
  Returns:
163
- Dict[str, Any]: Processed neighborhoods data, including the updated matrices and significance counts.
171
+ Dict[str, Any]: Processed clusters data, including the updated matrices and significance counts.
164
172
  """
165
- significance_matrix = neighborhoods["significance_matrix"]
166
- significant_binary_significance_matrix = neighborhoods["significant_binary_significance_matrix"]
167
- significant_significance_matrix = neighborhoods["significant_significance_matrix"]
173
+ significance_matrix = significant_clusters["significance_matrix"]
174
+ significant_binary_significance_matrix = significant_clusters[
175
+ "significant_binary_significance_matrix"
176
+ ]
177
+ significant_significance_matrix = significant_clusters["significant_significance_matrix"]
168
178
  logger.debug(f"Imputation depth: {impute_depth}")
169
179
  if impute_depth:
170
180
  (
@@ -191,13 +201,13 @@ def process_neighborhoods(
191
201
  distance_threshold=prune_threshold,
192
202
  )
193
203
 
194
- neighborhood_significance_counts = np.sum(significant_binary_significance_matrix, axis=0)
204
+ cluster_significance_counts = np.sum(significant_binary_significance_matrix, axis=0)
195
205
  node_significance_sums = np.sum(significance_matrix, axis=1)
196
206
  return {
197
207
  "significance_matrix": significance_matrix,
198
208
  "significant_binary_significance_matrix": significant_binary_significance_matrix,
199
209
  "significant_significance_matrix": significant_significance_matrix,
200
- "neighborhood_significance_counts": neighborhood_significance_counts,
210
+ "cluster_significance_counts": cluster_significance_counts,
201
211
  "node_significance_sums": node_significance_sums,
202
212
  }
203
213
 
@@ -395,6 +405,7 @@ def _prune_neighbors(
395
405
  non_zero_indices = np.where(significant_binary_significance_matrix.sum(axis=1) != 0)[0]
396
406
  median_distances = []
397
407
  distance_lookup = {}
408
+ isolated_nodes = [] # Track nodes with no significant neighbors
398
409
  for node in non_zero_indices:
399
410
  dist = _median_distance_to_significant_neighbors(
400
411
  node, network, significant_binary_significance_matrix
@@ -402,6 +413,8 @@ def _prune_neighbors(
402
413
  if dist is not None:
403
414
  median_distances.append(dist)
404
415
  distance_lookup[node] = dist
416
+ else:
417
+ isolated_nodes.append(node) # Node has no significant neighbors
405
418
 
406
419
  if not median_distances:
407
420
  logger.warning("No significant neighbors found for pruning.")
@@ -422,6 +435,11 @@ def _prune_neighbors(
422
435
  significance_matrix[node] = 0
423
436
  significant_binary_significance_matrix[node] = 0
424
437
 
438
+ # Prune isolated nodes (no significant neighbors)
439
+ for node in isolated_nodes:
440
+ significance_matrix[node] = 0
441
+ significant_binary_significance_matrix[node] = 0
442
+
425
443
  # Create a matrix where non-significant entries are set to zero
426
444
  significant_significance_matrix = np.where(
427
445
  significant_binary_significance_matrix == 1, significance_matrix, 0
@@ -436,7 +454,7 @@ def _prune_neighbors(
436
454
 
437
455
  def _median_distance_to_significant_neighbors(
438
456
  node, network, significance_mask
439
- ) -> Union[float, None]:
457
+ ) -> Union[float, Any, None]:
440
458
  """
441
459
  Calculate the median distance from a node to its significant neighbors.
442
460
 
@@ -448,11 +466,22 @@ def _median_distance_to_significant_neighbors(
448
466
  Returns:
449
467
  Union[float, None]: The median distance to significant neighbors, or None if no significant neighbors exist.
450
468
  """
451
- neighbors = [n for n in network.neighbors(node) if significance_mask[n].sum() != 0]
469
+ # Get all neighbors at once
470
+ neighbors = list(network.neighbors(node))
452
471
  if not neighbors:
453
472
  return None
454
- # Calculate distances to significant neighbors
455
- distances = [_get_euclidean_distance(node, n, network) for n in neighbors]
473
+
474
+ # Vectorized check for significant neighbors
475
+ neighbors = np.array(neighbors)
476
+ significant_mask = significance_mask[neighbors].sum(axis=1) != 0
477
+ significant_neighbors = neighbors[significant_mask]
478
+ if len(significant_neighbors) == 0:
479
+ return None
480
+
481
+ # Vectorized distance calculation
482
+ node_pos = _get_node_position(network, node)
483
+ neighbor_positions = np.array([_get_node_position(network, n) for n in significant_neighbors])
484
+ distances = np.linalg.norm(neighbor_positions - node_pos, axis=1)
456
485
 
457
486
  return np.median(distances)
458
487