risk-network 0.0.13b4__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (51)
  1. risk/__init__.py +3 -2
  2. risk/_annotation/__init__.py +10 -0
  3. risk/{annotation/annotation.py → _annotation/_annotation.py} +18 -11
  4. risk/{annotation/io.py → _annotation/_io.py} +22 -14
  5. risk/{annotation/nltk_setup.py → _annotation/_nltk_setup.py} +7 -5
  6. risk/_log/__init__.py +11 -0
  7. risk/{log/console.py → _log/_console.py} +22 -12
  8. risk/{log/parameters.py → _log/_parameters.py} +25 -14
  9. risk/_neighborhoods/__init__.py +8 -0
  10. risk/{neighborhoods/api.py → _neighborhoods/_api.py} +23 -17
  11. risk/{neighborhoods/community.py → _neighborhoods/_community.py} +19 -11
  12. risk/{neighborhoods/domains.py → _neighborhoods/_domains.py} +92 -35
  13. risk/{neighborhoods/neighborhoods.py → _neighborhoods/_neighborhoods.py} +69 -58
  14. risk/_neighborhoods/_stats/__init__.py +13 -0
  15. risk/_neighborhoods/_stats/_permutation/__init__.py +6 -0
  16. risk/{neighborhoods/stats/permutation/permutation.py → _neighborhoods/_stats/_permutation/_permutation.py} +9 -6
  17. risk/{neighborhoods/stats/permutation/test_functions.py → _neighborhoods/_stats/_permutation/_test_functions.py} +6 -4
  18. risk/{neighborhoods/stats/tests.py → _neighborhoods/_stats/_tests.py} +12 -7
  19. risk/_network/__init__.py +8 -0
  20. risk/_network/_graph/__init__.py +7 -0
  21. risk/{network/graph/api.py → _network/_graph/_api.py} +13 -13
  22. risk/{network/graph/graph.py → _network/_graph/_graph.py} +24 -13
  23. risk/{network/graph/stats.py → _network/_graph/_stats.py} +8 -5
  24. risk/{network/graph/summary.py → _network/_graph/_summary.py} +39 -32
  25. risk/{network/io.py → _network/_io.py} +166 -148
  26. risk/_network/_plotter/__init__.py +6 -0
  27. risk/{network/plotter/api.py → _network/_plotter/_api.py} +9 -10
  28. risk/{network/plotter/canvas.py → _network/_plotter/_canvas.py} +14 -10
  29. risk/{network/plotter/contour.py → _network/_plotter/_contour.py} +17 -11
  30. risk/{network/plotter/labels.py → _network/_plotter/_labels.py} +38 -23
  31. risk/{network/plotter/network.py → _network/_plotter/_network.py} +17 -11
  32. risk/{network/plotter/plotter.py → _network/_plotter/_plotter.py} +19 -15
  33. risk/_network/_plotter/_utils/__init__.py +7 -0
  34. risk/{network/plotter/utils/colors.py → _network/_plotter/_utils/_colors.py} +19 -11
  35. risk/{network/plotter/utils/layout.py → _network/_plotter/_utils/_layout.py} +8 -5
  36. risk/{risk.py → _risk.py} +11 -11
  37. risk_network-0.0.14.dist-info/METADATA +115 -0
  38. risk_network-0.0.14.dist-info/RECORD +41 -0
  39. {risk_network-0.0.13b4.dist-info → risk_network-0.0.14.dist-info}/WHEEL +1 -1
  40. risk/annotation/__init__.py +0 -10
  41. risk/log/__init__.py +0 -11
  42. risk/neighborhoods/__init__.py +0 -7
  43. risk/neighborhoods/stats/__init__.py +0 -13
  44. risk/neighborhoods/stats/permutation/__init__.py +0 -6
  45. risk/network/__init__.py +0 -4
  46. risk/network/graph/__init__.py +0 -4
  47. risk/network/plotter/__init__.py +0 -4
  48. risk_network-0.0.13b4.dist-info/METADATA +0 -125
  49. risk_network-0.0.13b4.dist-info/RECORD +0 -40
  50. {risk_network-0.0.13b4.dist-info → risk_network-0.0.14.dist-info}/licenses/LICENSE +0 -0
  51. {risk_network-0.0.13b4.dist-info → risk_network-0.0.14.dist-info}/top_level.txt +0 -0
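Most of this release is a layout change: every implementation module moves behind a leading underscore, and public names are re-exported from package `__init__` files (see the new `risk/_network/__init__.py` at the end of this diff). A hedged sketch of what that means for imports — the exact public surface depends on `risk/__init__.py`, which is changed in this release but not shown in full here:

# 0.0.13b4 — public module path:
from risk.neighborhoods.community import calculate_louvain_neighborhoods

# 0.0.14 — the module is private, and imports inside the package are relative
# ("from ._community import ..." in the hunks below); reaching in directly
# still works but now clearly signals private API:
from risk._neighborhoods._community import calculate_louvain_neighborhoods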
risk/{neighborhoods/community.py → _neighborhoods/_community.py}
@@ -1,6 +1,6 @@
 """
-risk/neighborhoods/community
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/_neighborhoods/_community
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 import community as community_louvain
@@ -12,13 +12,14 @@ from leidenalg import RBConfigurationVertexPartition, find_partition
 from networkx.algorithms.community import greedy_modularity_communities
 from scipy.sparse import csr_matrix
 
-from risk.log import logger
+from .._log import logger
 
 
 def calculate_greedy_modularity_neighborhoods(
     network: nx.Graph, fraction_shortest_edges: float = 1.0
 ) -> csr_matrix:
-    """Calculate neighborhoods using the Greedy Modularity method with CSR matrix output.
+    """
+    Calculate neighborhoods using the Greedy Modularity method with CSR matrix output.
 
     Args:
         network (nx.Graph): The network graph.
@@ -62,7 +63,8 @@ def calculate_greedy_modularity_neighborhoods(
 def calculate_label_propagation_neighborhoods(
     network: nx.Graph, fraction_shortest_edges: float = 1.0
 ) -> csr_matrix:
-    """Apply Label Propagation to the network to detect communities.
+    """
+    Apply Label Propagation to the network to detect communities.
 
     Args:
         network (nx.Graph): The network graph.
@@ -112,7 +114,8 @@ def calculate_leiden_neighborhoods(
     fraction_shortest_edges: float = 1.0,
     random_seed: int = 888,
 ) -> csr_matrix:
-    """Calculate neighborhoods using the Leiden method with CSR matrix output.
+    """
+    Calculate neighborhoods using the Leiden method with CSR matrix output.
 
     Args:
         network (nx.Graph): The network graph.
@@ -168,7 +171,8 @@ def calculate_louvain_neighborhoods(
     fraction_shortest_edges: float = 1.0,
     random_seed: int = 888,
 ) -> csr_matrix:
-    """Calculate neighborhoods using the Louvain method.
+    """
+    Calculate neighborhoods using the Louvain method.
 
     Args:
         network (nx.Graph): The network graph.
@@ -221,7 +225,8 @@ def calculate_louvain_neighborhoods(
 def calculate_markov_clustering_neighborhoods(
     network: nx.Graph, fraction_shortest_edges: float = 1.0
 ) -> csr_matrix:
-    """Apply Markov Clustering (MCL) to the network and return a binary neighborhood matrix (CSR).
+    """
+    Apply Markov Clustering (MCL) to the network and return a binary neighborhood matrix (CSR).
 
     Args:
         network (nx.Graph): The network graph.
@@ -291,7 +296,8 @@ def calculate_markov_clustering_neighborhoods(
 def calculate_spinglass_neighborhoods(
     network: nx.Graph, fraction_shortest_edges: float = 1.0
 ) -> csr_matrix:
-    """Apply Spinglass Community Detection to the network, handling disconnected components.
+    """
+    Apply Spinglass Community Detection to the network, handling disconnected components.
 
     Args:
         network (nx.Graph): The network graph.
@@ -355,7 +361,8 @@ def calculate_spinglass_neighborhoods(
 def calculate_walktrap_neighborhoods(
     network: nx.Graph, fraction_shortest_edges: float = 1.0
 ) -> csr_matrix:
-    """Apply Walktrap Community Detection to the network with CSR matrix output.
+    """
+    Apply Walktrap Community Detection to the network with CSR matrix output.
 
     Args:
         network (nx.Graph): The network graph.
@@ -399,7 +406,8 @@ def calculate_walktrap_neighborhoods(
 
 
 def _create_percentile_limited_subgraph(G: nx.Graph, fraction_shortest_edges: float) -> nx.Graph:
-    """Create a subgraph containing the shortest edges based on the specified rank fraction
+    """
+    Create a subgraph containing the shortest edges based on the specified rank fraction
     of all edge lengths in the input graph.
 
     Args:
risk/{neighborhoods/domains.py → _neighborhoods/_domains.py}
@@ -1,6 +1,6 @@
 """
-risk/neighborhoods/domains
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/_neighborhoods/_domains
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from itertools import product
@@ -13,8 +13,9 @@ from scipy.cluster.hierarchy import fcluster, linkage
 from sklearn.metrics import silhouette_score
 from tqdm import tqdm
 
-from risk.annotation import get_weighted_description
-from risk.log import logger
+from risk._annotation import get_weighted_description
+
+from .._log import logger
 
 # Define constants for clustering
 # fmt: off
@@ -35,7 +36,8 @@ def define_domains(
     linkage_metric: str,
     linkage_threshold: Union[float, str],
 ) -> pd.DataFrame:
-    """Define domains and assign nodes to these domains based on their significance scores and clustering,
+    """
+    Define domains and assign nodes to these domains based on their significance scores and clustering,
     handling errors by assigning unique domains when clustering fails.
 
     Args:
@@ -52,37 +54,48 @@ def define_domains(
     Raises:
         ValueError: If the clustering criterion is set to "off" or if an error occurs during clustering.
     """
-    try:
-        if linkage_criterion == "off":
-            raise ValueError("Clustering is turned off.")
+    # Validate args first; let user mistakes raise immediately
+    clustering_off = _validate_clustering_args(
+        linkage_criterion, linkage_method, linkage_metric, linkage_threshold
+    )
 
+    # If clustering is turned off, assign unique domains and skip
+    if clustering_off:
+        n_rows = len(top_annotation)
+        logger.warning("Clustering is turned off. Skipping clustering.")
+        top_annotation["domain"] = range(1, n_rows + 1)
+    else:
         # Transpose the matrix to cluster annotations
        m = significant_neighborhoods_significance[:, top_annotation["significant_annotation"]].T
         # Safeguard the matrix by replacing NaN, Inf, and -Inf values
         m = _safeguard_matrix(m)
-        # Optimize silhouette score across different linkage methods and distance metrics
-        best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
-            m, linkage_criterion, linkage_method, linkage_metric, linkage_threshold
-        )
-        # Perform hierarchical clustering
-        Z = linkage(m, method=best_linkage, metric=best_metric)
-        logger.warning(
-            f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'\nLinkage threshold: {round(best_threshold, 3)}"
-        )
-        # Calculate the optimal threshold for clustering
-        max_d_optimal = np.max(Z[:, 2]) * best_threshold
-        # Assign domains to the annotation matrix
-        domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
-        top_annotation["domain"] = 0
-        top_annotation.loc[top_annotation["significant_annotation"], "domain"] = domains
-    except (ValueError, LinAlgError):
-        # If a ValueError is encountered, handle it by assigning unique domains
-        n_rows = len(top_annotation)
-        if linkage_criterion == "off":
-            logger.warning("Clustering is turned off. Skipping clustering.")
-        else:
-            logger.error("Error encountered. Skipping clustering.")
-        top_annotation["domain"] = range(1, n_rows + 1)  # Assign unique domains
+        try:
+            # Optimize silhouette score across different linkage methods and distance metrics
+            (
+                best_linkage,
+                best_metric,
+                best_threshold,
+            ) = _optimize_silhouette_across_linkage_and_metrics(
+                m, linkage_criterion, linkage_method, linkage_metric, linkage_threshold
+            )
+            # Perform hierarchical clustering
+            Z = linkage(m, method=best_linkage, metric=best_metric)
+            logger.warning(
+                f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'\nLinkage threshold: {round(best_threshold, 3)}"
+            )
+            # Calculate the optimal threshold for clustering
+            max_d_optimal = np.max(Z[:, 2]) * best_threshold
+            # Assign domains to the annotation matrix
+            domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
+            top_annotation["domain"] = 0
+            top_annotation.loc[top_annotation["significant_annotation"], "domain"] = domains
+        except (LinAlgError, ValueError):
+            # Numerical errors or degenerate input are handled gracefully (not user error)
+            n_rows = len(top_annotation)
+            logger.error(
+                "Clustering failed due to numerical or data degeneracy. Assigning unique domains."
+            )
+            top_annotation["domain"] = range(1, n_rows + 1)
 
     # Create DataFrames to store domain information
     node_to_significance = pd.DataFrame(
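The restructured `define_domains` changes behavior in two user-visible ways: `linkage_criterion="off"` now short-circuits to unique domains instead of raising into the exception handler, and only `LinAlgError`/`ValueError` raised by the clustering itself is swallowed. The thresholding math is unchanged; here is a standalone, invented illustration of that step (toy matrix, made-up threshold):

# Minimal sketch of the thresholding shown above: cut a hierarchical
# clustering at a fraction of its maximum merge distance.
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

rng = np.random.default_rng(888)
m = rng.random((12, 4))            # stand-in for the transposed significance matrix
Z = linkage(m, method="average", metric="euclidean")
best_threshold = 0.5               # stand-in for the silhouette-optimized threshold
max_d_optimal = np.max(Z[:, 2]) * best_threshold   # same formula as in define_domains
domains = fcluster(Z, max_d_optimal, criterion="distance")
print(domains)                     # one cluster label per row of m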
@@ -112,7 +125,8 @@ def trim_domains(
     min_cluster_size: int = 5,
     max_cluster_size: int = 1000,
 ) -> Tuple[pd.DataFrame, pd.DataFrame]:
-    """Trim domains that do not meet size criteria and find outliers.
+    """
+    Trim domains that do not meet size criteria and find outliers.
 
     Args:
         domains (pd.DataFrame): DataFrame of domain data for the network nodes.
@@ -181,8 +195,49 @@ def trim_domains(
     return valid_domains, valid_trimmed_domains_matrix
 
 
+def _validate_clustering_args(
+    linkage_criterion: str,
+    linkage_method: str,
+    linkage_metric: str,
+    linkage_threshold: Union[float, str],
+) -> bool:
+    """
+    Validate user-provided clustering arguments.
+
+    Returns:
+        bool: True if clustering is turned off (criterion == 'off'); False otherwise.
+
+    Raises:
+        ValueError: If any argument is invalid (user error).
+    """
+    # Allow opting out of clustering without raising
+    if linkage_criterion == "off":
+        return True
+    # Validate linkage method (allow "auto")
+    if linkage_method != "auto" and linkage_method not in LINKAGE_METHODS:
+        raise ValueError(
+            f"Invalid linkage_method '{linkage_method}'. Allowed values are 'auto' or one of: {sorted(LINKAGE_METHODS)}"
+        )
+    # Validate linkage metric (allow "auto")
+    if linkage_metric != "auto" and linkage_metric not in LINKAGE_METRICS:
+        raise ValueError(
+            f"Invalid linkage_metric '{linkage_metric}'. Allowed values are 'auto' or one of: {sorted(LINKAGE_METRICS)}"
+        )
+    # Validate linkage threshold (allow "auto"; otherwise must be float in (0, 1])
+    if linkage_threshold != "auto":
+        try:
+            lt = float(linkage_threshold)
+        except (TypeError, ValueError):
+            raise ValueError("linkage_threshold must be 'auto' or a float in the interval (0, 1].")
+        if not (0.0 < lt <= 1.0):
+            raise ValueError(f"linkage_threshold must be within (0, 1]. Received: {lt}")
+
+    return False
+
+
 def _safeguard_matrix(matrix: np.ndarray) -> np.ndarray:
-    """Safeguard the matrix by replacing NaN, Inf, and -Inf values.
+    """
+    Safeguard the matrix by replacing NaN, Inf, and -Inf values.
 
     Args:
         matrix (np.ndarray): Data matrix.
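The new `_validate_clustering_args` helper above turns bad arguments into immediate `ValueError`s. An illustrative call, assuming `"average"` and `"euclidean"` appear in the module's `LINKAGE_METHODS` and `LINKAGE_METRICS` constants (their contents are not shown in this diff), and using the private import path purely for demonstration:

from risk._neighborhoods._domains import _validate_clustering_args

# "off" is now a supported opt-out rather than an internal exception
assert _validate_clustering_args("off", "average", "euclidean", 0.5) is True

# An out-of-range threshold raises immediately instead of silently
# falling back to unique domains
try:
    _validate_clustering_args("distance", "average", "euclidean", 1.5)
except ValueError as exc:
    print(exc)  # linkage_threshold must be within (0, 1]. Received: 1.5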
@@ -211,7 +266,8 @@ def _optimize_silhouette_across_linkage_and_metrics(
     linkage_metric: str,
     linkage_threshold: Union[str, float],
 ) -> Tuple[str, str, float]:
-    """Optimize silhouette score across different linkage methods and distance metrics.
+    """
+    Optimize silhouette score across different linkage methods and distance metrics.
 
     Args:
         m (np.ndarray): Data matrix.
@@ -287,7 +343,8 @@ def _find_best_silhouette_score(
     lower_bound: float = 0.001,
     upper_bound: float = 1.0,
 ) -> Tuple[float, float]:
-    """Find the best silhouette score using binary search.
+    """
+    Find the best silhouette score using binary search.
 
     Args:
         Z (np.ndarray): Linkage matrix.
risk/{neighborhoods/neighborhoods.py → _neighborhoods/_neighborhoods.py}
@@ -1,6 +1,6 @@
 """
-risk/neighborhoods/neighborhoods
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/_neighborhoods/_neighborhoods
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 import random
@@ -13,8 +13,8 @@ from scipy.sparse import csr_matrix
 from sklearn.exceptions import DataConversionWarning
 from sklearn.metrics.pairwise import cosine_similarity
 
-from risk.log import logger
-from risk.neighborhoods.community import (
+from .._log import logger
+from ._community import (
     calculate_greedy_modularity_neighborhoods,
     calculate_label_propagation_neighborhoods,
     calculate_leiden_neighborhoods,
@@ -36,7 +36,8 @@ def get_network_neighborhoods(
     leiden_resolution: float = 1.0,
     random_seed: int = 888,
 ) -> csr_matrix:
-    """Calculate the combined neighborhoods for each node using sparse matrices.
+    """
+    Calculate the combined neighborhoods for each node using sparse matrices.
 
     Args:
         network (nx.Graph): The network graph.
@@ -125,7 +126,8 @@ def get_network_neighborhoods(
 
 
 def _set_max_row_value_to_one_sparse(matrix: csr_matrix) -> csr_matrix:
-    """Set the maximum value in each row of a sparse matrix to 1.
+    """
+    Set the maximum value in each row of a sparse matrix to 1.
 
     Args:
         matrix (csr_matrix): The input sparse matrix.
@@ -142,34 +144,14 @@ def _set_max_row_value_to_one_sparse(matrix: csr_matrix) -> csr_matrix:
     return matrix
 
 
-def _set_max_row_value_to_one(matrix: np.ndarray) -> np.ndarray:
-    """For each row in the input matrix, set the maximum value(s) to 1 and all other values to 0. This is particularly
-    useful for neighborhood matrices that have undergone multiple neighborhood detection algorithms, where the
-    maximum value in each row represents the most significant relationship per node in the combined neighborhoods.
-
-    Args:
-        matrix (np.ndarray): A 2D numpy array representing the neighborhood matrix.
-
-    Returns:
-        np.ndarray: The modified matrix where only the maximum value(s) in each row is set to 1, and others are set to 0.
-    """
-    # Find the maximum value in each row (column-wise max operation)
-    max_values = np.max(matrix, axis=1, keepdims=True)
-    # Create a boolean mask where elements are True if they are the max value in their row
-    max_mask = matrix == max_values
-    # Set all elements to 0, and then set the maximum value positions to 1
-    matrix[:] = 0  # Set everything to 0
-    matrix[max_mask] = 1  # Set only the max values to 1
-    return matrix
-
-
 def process_neighborhoods(
     network: nx.Graph,
     neighborhoods: Dict[str, Any],
     impute_depth: int = 0,
     prune_threshold: float = 0.0,
 ) -> Dict[str, Any]:
-    """Process neighborhoods based on the imputation and pruning settings.
+    """
+    Process neighborhoods based on the imputation and pruning settings.
 
     Args:
         network (nx.Graph): The network data structure used for imputing and pruning neighbors.
@@ -226,7 +208,8 @@ def _impute_neighbors(
     significant_binary_significance_matrix: np.ndarray,
     max_depth: int = 3,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Impute rows with sums of zero in the significance matrix based on the closest non-zero neighbors in the network graph.
+    """
+    Impute rows with sums of zero in the significance matrix based on the closest non-zero neighbors in the network graph.
 
     Args:
         network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
@@ -262,7 +245,8 @@ def _impute_neighbors_with_similarity(
     significant_binary_significance_matrix: np.ndarray,
     max_depth: int = 3,
 ) -> Tuple[np.ndarray, np.ndarray]:
-    """Impute non-significant nodes based on the closest significant neighbors' profiles and their similarity.
+    """
+    Impute non-significant nodes based on the closest significant neighbors' profiles and their similarity.
 
     Args:
         network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
@@ -306,7 +290,8 @@ def _process_node_imputation(
     significant_binary_significance_matrix: np.ndarray,
     depth: int,
 ) -> Tuple[np.ndarray, np.ndarray]:
-    """Process the imputation for a single node based on its significant neighbors.
+    """
+    Process the imputation for a single node based on its significant neighbors.
 
     Args:
         row_index (int): The index of the significant node being processed.
@@ -391,7 +376,8 @@ def _prune_neighbors(
     significant_binary_significance_matrix: np.ndarray,
     distance_threshold: float = 0.9,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """Remove outliers based on their rank for edge lengths.
+    """
+    Remove outliers based on their rank for edge lengths.
 
     Args:
         network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
@@ -408,34 +394,33 @@ def _prune_neighbors(
     # Identify indices with non-zero rows in the binary significance matrix
     non_zero_indices = np.where(significant_binary_significance_matrix.sum(axis=1) != 0)[0]
     median_distances = []
+    distance_lookup = {}
     for node in non_zero_indices:
-        neighbors = [
-            n
-            for n in network.neighbors(node)
-            if significant_binary_significance_matrix[n].sum() != 0
-        ]
-        if neighbors:
-            median_distance = np.median(
-                [_get_euclidean_distance(node, n, network) for n in neighbors]
-            )
-            median_distances.append(median_distance)
+        dist = _median_distance_to_significant_neighbors(
+            node, network, significant_binary_significance_matrix
+        )
+        if dist is not None:
+            median_distances.append(dist)
+            distance_lookup[node] = dist
+
+    if not median_distances:
+        logger.warning("No significant neighbors found for pruning.")
+        significant_significance_matrix = np.where(
+            significant_binary_significance_matrix == 1, significance_matrix, 0
+        )
+        return (
+            significance_matrix,
+            significant_binary_significance_matrix,
+            significant_significance_matrix,
+        )
 
     # Calculate the distance threshold value based on rank
     distance_threshold_value = _calculate_threshold(median_distances, 1 - distance_threshold)
     # Prune nodes that are outliers based on the distance threshold
-    for row_index in non_zero_indices:
-        neighbors = [
-            n
-            for n in network.neighbors(row_index)
-            if significant_binary_significance_matrix[n].sum() != 0
-        ]
-        if neighbors:
-            median_distance = np.median(
-                [_get_euclidean_distance(row_index, n, network) for n in neighbors]
-            )
-            if median_distance >= distance_threshold_value:
-                significance_matrix[row_index] = 0
-                significant_binary_significance_matrix[row_index] = 0
+    for node, dist in distance_lookup.items():
+        if dist >= distance_threshold_value:
+            significance_matrix[node] = 0
+            significant_binary_significance_matrix[node] = 0
 
     # Create a matrix where non-significant entries are set to zero
     significant_significance_matrix = np.where(
@@ -449,8 +434,32 @@ def _prune_neighbors(
     )
 
 
+def _median_distance_to_significant_neighbors(
+    node, network, significance_mask
+) -> Union[float, None]:
+    """
+    Calculate the median distance from a node to its significant neighbors.
+
+    Args:
+        node (Any): The node for which the median distance is being calculated.
+        network (nx.Graph): The network graph containing the nodes.
+        significance_mask (np.ndarray): Binary matrix indicating significant nodes.
+
+    Returns:
+        Union[float, None]: The median distance to significant neighbors, or None if no significant neighbors exist.
+    """
+    neighbors = [n for n in network.neighbors(node) if significance_mask[n].sum() != 0]
+    if not neighbors:
+        return None
+    # Calculate distances to significant neighbors
+    distances = [_get_euclidean_distance(node, n, network) for n in neighbors]
+
+    return np.median(distances)
+
+
 def _get_euclidean_distance(node1: Any, node2: Any, network: nx.Graph) -> float:
-    """Calculate the Euclidean distance between two nodes in the network.
+    """
+    Calculate the Euclidean distance between two nodes in the network.
 
     Args:
         node1 (Any): The first node.
@@ -466,7 +475,8 @@ def _get_euclidean_distance(node1: Any, node2: Any, network: nx.Graph) -> float:
 
 
 def _get_node_position(network: nx.Graph, node: Any) -> np.ndarray:
-    """Retrieve the position of a node in the network as a numpy array.
+    """
+    Retrieve the position of a node in the network as a numpy array.
 
     Args:
         network (nx.Graph): The network graph containing node positions.
@@ -485,7 +495,8 @@ def _get_node_position(network: nx.Graph, node: Any) -> np.ndarray:
 
 
 def _calculate_threshold(median_distances: List, distance_threshold: float) -> float:
-    """Calculate the distance threshold based on the given median distances and a percentile threshold.
+    """
+    Calculate the distance threshold based on the given median distances and a percentile threshold.
 
     Args:
         median_distances (List): An array of median distances.
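With this refactor, `_prune_neighbors` computes each node's median distance to significant neighbors once via the new `_median_distance_to_significant_neighbors` helper, caches it in `distance_lookup`, and returns early when no node has significant neighbors. A self-contained sketch of what the helper computes — the two-node graph, coordinates, and mask are invented, and `np.linalg.norm` stands in for the package's `_get_euclidean_distance`/`_get_node_position`:

import networkx as nx
import numpy as np

network = nx.Graph()
network.add_edge(0, 1)
positions = {0: np.array([0.0, 0.0]), 1: np.array([3.0, 4.0])}  # invented layout
significance_mask = np.array([[1], [1]])  # both nodes significant

def median_distance_to_significant_neighbors(node):
    # Keep only neighbors flagged significant in the mask
    neighbors = [n for n in network.neighbors(node) if significance_mask[n].sum() != 0]
    if not neighbors:
        return None
    return np.median([np.linalg.norm(positions[node] - positions[n]) for n in neighbors])

print(median_distance_to_significant_neighbors(0))  # 5.0 for the 3-4-5 triangle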
risk/_neighborhoods/_stats/__init__.py
@@ -0,0 +1,13 @@
+"""
+risk/_neighborhoods/_stats
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from ._permutation import compute_permutation_test
+from ._tests import (
+    compute_binom_test,
+    compute_chi2_test,
+    compute_hypergeom_test,
+    compute_poisson_test,
+    compute_zscore_test,
+)
risk/_neighborhoods/_stats/_permutation/__init__.py
@@ -0,0 +1,6 @@
+"""
+risk/_neighborhoods/_stats/_permutation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from ._permutation import compute_permutation_test
risk/{neighborhoods/stats/permutation/permutation.py → _neighborhoods/_stats/_permutation/_permutation.py}
@@ -1,6 +1,6 @@
 """
-risk/neighborhoods/stats/permutation/permutation
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/_neighborhoods/_stats/_permutation/_permutation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from multiprocessing import Manager, get_context
@@ -12,7 +12,7 @@ from scipy.sparse import csr_matrix
 from threadpoolctl import threadpool_limits
 from tqdm import tqdm
 
-from risk.neighborhoods.stats.permutation.test_functions import DISPATCH_TEST_FUNCTIONS
+from ._test_functions import DISPATCH_TEST_FUNCTIONS
 
 
 def compute_permutation_test(
@@ -24,7 +24,8 @@ def compute_permutation_test(
     random_seed: int = 888,
     max_workers: int = 1,
 ) -> Dict[str, Any]:
-    """Compute permutation test for enrichment and depletion in neighborhoods.
+    """
+    Compute permutation test for enrichment and depletion in neighborhoods.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -75,7 +76,8 @@ def _run_permutation_test(
     random_seed: int = 888,
     max_workers: int = 4,
 ) -> tuple:
-    """Run the permutation test to calculate depletion and enrichment counts.
+    """
+    Run the permutation test to calculate depletion and enrichment counts.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -181,7 +183,8 @@ def _permutation_process_batch(
     progress_counter: ValueProxy,
     max_workers: int,
 ) -> tuple:
-    """Process a batch of permutations in a worker process.
+    """
+    Process a batch of permutations in a worker process.
 
     Args:
         permutations (Union[List, Tuple, np.ndarray]): Permutation batch to process.
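The imports in this module (`Manager`, `get_context`, `threadpool_limits`) outline its concurrency pattern: permutation batches go to worker processes, a shared counter reports progress, and BLAS threads are limited so workers don't oversubscribe the CPU. A generic, invented sketch of that pattern — not the package's actual worker code:

from multiprocessing import Manager, get_context

import numpy as np
from threadpoolctl import threadpool_limits

def process_batch(batch, counter):
    with threadpool_limits(limits=1):          # one BLAS thread per worker
        total = 0.0
        for seed in batch:
            rng = np.random.default_rng(seed)  # stand-in for one permutation
            total += rng.random()
            counter.value += 1                 # progress report (not atomic; fine for a demo)
        return total

if __name__ == "__main__":
    ctx = get_context("spawn")
    with Manager() as manager:
        counter = manager.Value("i", 0)
        batches = np.array_split(np.arange(100), 4)
        with ctx.Pool(processes=4) as pool:
            results = pool.starmap(process_batch, [(b, counter) for b in batches])
        print(sum(results), counter.value)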
risk/{neighborhoods/stats/permutation/test_functions.py → _neighborhoods/_stats/_permutation/_test_functions.py}
@@ -1,6 +1,6 @@
 """
-risk/neighborhoods/stats/permutation/test_functions
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/_neighborhoods/_stats/_permutation/_test_functions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 import numpy as np
@@ -15,7 +15,8 @@ from scipy.sparse import csr_matrix
 def compute_neighborhood_score_by_sum(
     neighborhoods_matrix: csr_matrix, annotation_matrix: csr_matrix
 ) -> np.ndarray:
-    """Compute the sum of attribute values for each neighborhood using sparse matrices.
+    """
+    Compute the sum of attribute values for each neighborhood using sparse matrices.
 
     Args:
         neighborhoods_matrix (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -34,7 +35,8 @@ def compute_neighborhood_score_by_sum(
 def compute_neighborhood_score_by_stdev(
     neighborhoods_matrix: csr_matrix, annotation_matrix: csr_matrix
 ) -> np.ndarray:
-    """Compute the standard deviation of neighborhood scores for sparse matrices.
+    """
+    Compute the standard deviation of neighborhood scores for sparse matrices.
 
     Args:
         neighborhoods_matrix (csr_matrix): Sparse binary matrix representing neighborhoods.
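`compute_neighborhood_score_by_sum` sums annotation values over each neighborhood's members with sparse operations. A minimal illustration of the underlying algebra (invented toy matrices, not the package's implementation): with a binary neighborhoods-by-nodes matrix N and a nodes-by-terms annotation matrix A, the product N @ A gives per-neighborhood sums.

import numpy as np
from scipy.sparse import csr_matrix

neighborhoods = csr_matrix(np.array([[1, 1, 0],
                                     [0, 1, 1]]))   # 2 neighborhoods x 3 nodes
annotation = csr_matrix(np.array([[1.0, 0.0],
                                  [2.0, 1.0],
                                  [0.0, 3.0]]))     # 3 nodes x 2 annotation terms

scores = (neighborhoods @ annotation).toarray()     # neighborhood-by-term sums
print(scores)  # [[3. 1.], [2. 4.]]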
risk/{neighborhoods/stats/tests.py → _neighborhoods/_stats/_tests.py}
@@ -1,6 +1,6 @@
 """
-risk/neighborhoods/stats/tests
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/_neighborhoods/_stats/_tests
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from typing import Any, Dict
@@ -15,7 +15,8 @@ def compute_binom_test(
     annotation: csr_matrix,
     null_distribution: str = "network",
 ) -> Dict[str, Any]:
-    """Compute Binomial test for enrichment and depletion in neighborhoods with selectable null distribution.
+    """
+    Compute Binomial test for enrichment and depletion in neighborhoods with selectable null distribution.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -60,7 +61,8 @@ def compute_chi2_test(
     annotation: csr_matrix,
     null_distribution: str = "network",
 ) -> Dict[str, Any]:
-    """Compute chi-squared test for enrichment and depletion in neighborhoods with selectable null distribution.
+    """
+    Compute chi-squared test for enrichment and depletion in neighborhoods with selectable null distribution.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -122,7 +124,8 @@ def compute_hypergeom_test(
     annotation: csr_matrix,
     null_distribution: str = "network",
 ) -> Dict[str, Any]:
-    """Compute hypergeometric test for enrichment and depletion in neighborhoods with selectable null distribution.
+    """
+    Compute hypergeometric test for enrichment and depletion in neighborhoods with selectable null distribution.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -178,7 +181,8 @@ def compute_poisson_test(
     annotation: csr_matrix,
     null_distribution: str = "network",
 ) -> Dict[str, Any]:
-    """Compute Poisson test for enrichment and depletion in neighborhoods with selectable null distribution.
+    """
+    Compute Poisson test for enrichment and depletion in neighborhoods with selectable null distribution.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
@@ -220,7 +224,8 @@ def compute_zscore_test(
     annotation: csr_matrix,
     null_distribution: str = "network",
 ) -> Dict[str, Any]:
-    """Compute z-score test for enrichment and depletion in neighborhoods with selectable null distribution.
+    """
+    Compute z-score test for enrichment and depletion in neighborhoods with selectable null distribution.
 
     Args:
         neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
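All five tests share the signature shown above: two sparse matrices in, a `Dict[str, Any]` of depletion/enrichment results out, all re-exported from `risk._neighborhoods._stats`. A hypothetical usage sketch — the toy matrices are invented, and this diff does not show the returned dictionary's keys, so the sketch only inspects them:

import numpy as np
from scipy.sparse import csr_matrix
from risk._neighborhoods._stats import compute_hypergeom_test

rng = np.random.default_rng(888)
neighborhoods = csr_matrix((rng.random((5, 5)) > 0.5).astype(int))  # 5 neighborhoods x 5 nodes
annotation = csr_matrix((rng.random((5, 3)) > 0.5).astype(int))     # 5 nodes x 3 terms

result = compute_hypergeom_test(neighborhoods, annotation, null_distribution="network")
print(sorted(result))  # inspect which enrichment/depletion arrays are returned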
risk/_network/__init__.py
@@ -0,0 +1,8 @@
+"""
+risk/_network
+~~~~~~~~~~~~~
+"""
+
+from ._graph import GraphAPI
+from ._io import NetworkAPI
+from ._plotter import PlotterAPI
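This new `__init__.py` is the release's pattern in miniature: the API classes live in private modules, and the package namespace re-exports them, so the names stay importable even as the private file layout changes. Note that `risk._network` is itself underscore-private; the fully public names are presumably re-exported once more from the top-level `risk/__init__.py` (changed in this release but not shown here).

# Re-exported names, per the __init__ above:
from risk._network import GraphAPI, NetworkAPI, PlotterAPI

# The direct module path still resolves, but is now explicitly private:
from risk._network._io import NetworkAPI  # same class, underscore path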