risk-network 0.0.16b0__py3-none-any.whl → 0.0.16b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. risk/__init__.py +2 -2
  2. risk/{_annotation → annotation}/__init__.py +2 -2
  3. risk/{_annotation → annotation}/_nltk_setup.py +3 -3
  4. risk/{_annotation/_annotation.py → annotation/annotation.py} +22 -25
  5. risk/{_annotation/_io.py → annotation/io.py} +4 -4
  6. risk/cluster/__init__.py +8 -0
  7. risk/{_neighborhoods → cluster}/_community.py +37 -37
  8. risk/cluster/api.py +273 -0
  9. risk/{_neighborhoods/_neighborhoods.py → cluster/cluster.py} +127 -98
  10. risk/{_neighborhoods/_domains.py → cluster/label.py} +18 -12
  11. risk/{_log → log}/__init__.py +2 -2
  12. risk/{_log/_console.py → log/console.py} +2 -2
  13. risk/{_log/_parameters.py → log/parameters.py} +20 -10
  14. risk/network/__init__.py +8 -0
  15. risk/network/graph/__init__.py +7 -0
  16. risk/{_network/_graph → network/graph}/_stats.py +2 -2
  17. risk/{_network/_graph → network/graph}/_summary.py +13 -13
  18. risk/{_network/_graph/_api.py → network/graph/api.py} +37 -39
  19. risk/{_network/_graph/_graph.py → network/graph/graph.py} +5 -5
  20. risk/{_network/_io.py → network/io.py} +9 -4
  21. risk/network/plotter/__init__.py +6 -0
  22. risk/{_network/_plotter → network/plotter}/_canvas.py +6 -6
  23. risk/{_network/_plotter → network/plotter}/_contour.py +4 -4
  24. risk/{_network/_plotter → network/plotter}/_labels.py +6 -6
  25. risk/{_network/_plotter → network/plotter}/_network.py +7 -7
  26. risk/{_network/_plotter → network/plotter}/_plotter.py +5 -5
  27. risk/network/plotter/_utils/__init__.py +7 -0
  28. risk/{_network/_plotter/_utils/_colors.py → network/plotter/_utils/colors.py} +3 -3
  29. risk/{_network/_plotter/_utils/_layout.py → network/plotter/_utils/layout.py} +2 -2
  30. risk/{_network/_plotter/_api.py → network/plotter/api.py} +5 -5
  31. risk/{_risk.py → risk.py} +9 -8
  32. risk/stats/__init__.py +6 -0
  33. risk/stats/_stats/__init__.py +11 -0
  34. risk/stats/_stats/permutation/__init__.py +6 -0
  35. risk/stats/_stats/permutation/_test_functions.py +72 -0
  36. risk/{_neighborhoods/_stats/_permutation/_permutation.py → stats/_stats/permutation/permutation.py} +35 -37
  37. risk/{_neighborhoods/_stats/_tests.py → stats/_stats/tests.py} +32 -34
  38. risk/stats/api.py +202 -0
  39. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/METADATA +2 -2
  40. risk_network-0.0.16b2.dist-info/RECORD +43 -0
  41. risk/_neighborhoods/__init__.py +0 -8
  42. risk/_neighborhoods/_api.py +0 -354
  43. risk/_neighborhoods/_stats/__init__.py +0 -11
  44. risk/_neighborhoods/_stats/_permutation/__init__.py +0 -6
  45. risk/_neighborhoods/_stats/_permutation/_test_functions.py +0 -72
  46. risk/_network/__init__.py +0 -8
  47. risk/_network/_graph/__init__.py +0 -7
  48. risk/_network/_plotter/__init__.py +0 -6
  49. risk/_network/_plotter/_utils/__init__.py +0 -7
  50. risk_network-0.0.16b0.dist-info/RECORD +0 -41
  51. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/WHEEL +0 -0
  52. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/licenses/LICENSE +0 -0
  53. {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_network
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_network
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Tuple, Union
8
8
  import networkx as nx
9
9
  import numpy as np
10
10
 
11
- from ..._log import params
12
- from .._graph import Graph
11
+ from ...log import params
12
+ from ..graph import Graph
13
13
  from ._utils import get_domain_colors, to_rgba
14
14
 
15
15
 
@@ -273,14 +273,14 @@ class Network:
273
273
  return adjusted_network_colors
274
274
 
275
275
  def get_annotated_node_sizes(
276
- self, significant_size: int = 50, nonsignificant_size: int = 25
276
+ self, significant_size: Union[int, float] = 50, nonsignificant_size: Union[int, float] = 25
277
277
  ) -> np.ndarray:
278
278
  """
279
279
  Adjust the sizes of nodes in the network graph based on whether they are significant or not.
280
280
 
281
281
  Args:
282
- significant_size (int): Size for significant nodes. Defaults to 50.
283
- nonsignificant_size (int): Size for non-significant nodes. Defaults to 25.
282
+ significant_size (int or float): Size for significant nodes. Can be an integer or float value. Defaults to 50.
283
+ nonsignificant_size (int or float): Size for non-significant nodes. Can be an integer or float value. Defaults to 25.
284
284
 
285
285
  Returns:
286
286
  np.ndarray: Array of node sizes, with significant nodes larger than non-significant ones.
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_plotter
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_plotter
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
@@ -8,8 +8,8 @@ from typing import List, Tuple, Union
8
8
  import matplotlib.pyplot as plt
9
9
  import numpy as np
10
10
 
11
- from ..._log import params
12
- from .._graph._graph import Graph
11
+ from ...log import params
12
+ from ..graph.graph import Graph
13
13
  from ._canvas import Canvas
14
14
  from ._contour import Contour
15
15
  from ._labels import Labels
@@ -123,7 +123,7 @@ class Plotter(Canvas, Network, Contour, Labels):
123
123
  Args:
124
124
  *args: Positional arguments passed to `plt.savefig`.
125
125
  pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
126
- dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 300.
126
+ dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
127
127
  **kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
128
128
  """
129
129
  # Ensure user-provided kwargs take precedence
@@ -0,0 +1,7 @@
1
+ """
2
+ risk/network/plotter/_utils
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .colors import get_annotated_domain_colors, get_domain_colors, to_rgba
7
+ from .layout import calculate_bounding_box, calculate_centroids
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_utils/_colors
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_utils/colors
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
@@ -9,7 +9,7 @@ import matplotlib
9
9
  import matplotlib.colors as mcolors
10
10
  import numpy as np
11
11
 
12
- from ..._graph import Graph
12
+ from ...graph import Graph
13
13
 
14
14
 
15
15
  def get_annotated_domain_colors(
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_utils/_layout
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_utils/layout
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple
@@ -1,14 +1,14 @@
1
1
  """
2
- risk/_network/_plotter/_api
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
7
7
 
8
8
  import numpy as np
9
9
 
10
- from ..._log import log_header
11
- from .._graph import Graph
10
+ from ...log import log_header
11
+ from ..graph import Graph
12
12
  from ._plotter import Plotter
13
13
 
14
14
 
@@ -32,7 +32,7 @@ class PlotterAPI:
32
32
 
33
33
  Args:
34
34
  graph (Graph): The graph to plot.
35
- figsize (List, Tuple, or np.ndarray, optional): Size of the plot. Defaults to (10, 10)., optional): Size of the figure. Defaults to (10, 10).
35
+ figsize (List, Tuple, or np.ndarray, optional): Figure size in inches (width, height). Defaults to (10, 10).
36
36
  background_color (str, optional): Background color of the plot. Defaults to "white".
37
37
  background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
38
38
  any existing alpha values found in background_color. Defaults to 1.0.
@@ -1,20 +1,21 @@
1
1
  """
2
- risk/_risk
3
- ~~~~~~~~~~
2
+ risk/risk
3
+ ~~~~~~~~~
4
4
  """
5
5
 
6
- from ._annotation import AnnotationHandler
7
- from ._log import params, set_global_verbosity
8
- from ._neighborhoods import NeighborhoodsAPI
9
- from ._network import GraphAPI, NetworkAPI, PlotterAPI
6
+ from .annotation import AnnotationHandler
7
+ from .cluster import ClusterAPI
8
+ from .log import params, set_global_verbosity
9
+ from .network import GraphAPI, NetworkAPI, PlotterAPI
10
+ from .stats import StatsAPI
10
11
 
11
12
 
12
- class RISK(NetworkAPI, AnnotationHandler, NeighborhoodsAPI, GraphAPI, PlotterAPI):
13
+ class RISK(NetworkAPI, AnnotationHandler, ClusterAPI, StatsAPI, GraphAPI, PlotterAPI):
13
14
  """
14
15
  RISK: A class for network analysis and visualization.
15
16
 
16
17
  The RISK class integrates functionalities for loading networks, processing annotations,
17
- performing network-based statistical analysis to quantify neighborhood relationships,
18
+ performing network-based statistical analysis to quantify cluster relationships,
18
19
  and visualizing networks and their properties.
19
20
  """
20
21
 
risk/stats/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/stats
3
+ ~~~~~~~~~~
4
+ """
5
+
6
+ from .api import StatsAPI
@@ -0,0 +1,11 @@
1
+ """
2
+ risk/cluster/_stats
3
+ ~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .permutation import compute_permutation_test
7
+ from .tests import (
8
+ compute_binom_test,
9
+ compute_chi2_test,
10
+ compute_hypergeom_test,
11
+ )
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/stats/_stats/permutation
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .permutation import compute_permutation_test
@@ -0,0 +1,72 @@
1
+ """
2
+ risk/stats/_stats/permutation/_test_functions
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import numpy as np
7
+ from scipy.sparse import csr_matrix
8
+
9
+ # NOTE: Cython optimizations provided minimal performance benefits.
10
+ # The final version with Cython is archived in the `cython_permutation` branch.
11
+
12
+ # DISPATCH_TEST_FUNCTIONS can be found at the end of the file.
13
+
14
+
15
+ def compute_cluster_score_by_sum(
16
+ clusters_matrix: csr_matrix, annotation_matrix: csr_matrix
17
+ ) -> np.ndarray:
18
+ """
19
+ Compute the sum of attribute values for each cluster using sparse matrices.
20
+
21
+ Args:
22
+ clusters_matrix (csr_matrix): Sparse binary matrix representing clusters.
23
+ annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
24
+
25
+ Returns:
26
+ np.ndarray: Dense array of summed attribute values for each cluster.
27
+ """
28
+ # Calculate the cluster score as the dot product of clusters and annotation
29
+ cluster_score = clusters_matrix @ annotation_matrix # Sparse matrix multiplication
30
+ # Convert the result to a dense array for downstream calculations
31
+ cluster_score_dense = cluster_score.toarray()
32
+ return cluster_score_dense
33
+
34
+
35
+ def compute_cluster_score_by_stdev(
36
+ clusters_matrix: csr_matrix, annotation_matrix: csr_matrix
37
+ ) -> np.ndarray:
38
+ """
39
+ Compute the standard deviation of cluster scores for sparse matrices.
40
+
41
+ Args:
42
+ clusters_matrix (csr_matrix): Sparse binary matrix representing clusters.
43
+ annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
44
+
45
+ Returns:
46
+ np.ndarray: Standard deviation of the cluster scores.
47
+ """
48
+ # Calculate the cluster score as the dot product of clusters and annotation
49
+ cluster_score = clusters_matrix @ annotation_matrix # Sparse matrix multiplication
50
+ # Calculate the number of elements in each cluster (sum of rows)
51
+ N = clusters_matrix.sum(axis=1).A.flatten() # Convert to 1D array
52
+ # Avoid division by zero by replacing zeros in N with np.nan temporarily
53
+ N[N == 0] = np.nan
54
+ # Compute the mean of the cluster scores
55
+ M = cluster_score.multiply(1 / N[:, None]).toarray() # Sparse element-wise division
56
+ # Compute the mean of squares (EXX) directly using squared annotation matrix
57
+ annotation_squared = annotation_matrix.multiply(annotation_matrix) # Element-wise squaring
58
+ EXX = (clusters_matrix @ annotation_squared).multiply(1 / N[:, None]).toarray()
59
+ # Calculate variance as EXX - M^2
60
+ variance = EXX - np.power(M, 2)
61
+ # Compute the standard deviation as the square root of the variance
62
+ cluster_stdev = np.sqrt(variance)
63
+ # Replace np.nan back with zeros in case N was 0 (no elements in the cluster)
64
+ cluster_stdev[np.isnan(cluster_stdev)] = 0
65
+ return cluster_stdev
66
+
67
+
68
+ # Dictionary to dispatch statistical test functions based on the score metric
69
+ DISPATCH_TEST_FUNCTIONS = {
70
+ "sum": compute_cluster_score_by_sum,
71
+ "stdev": compute_cluster_score_by_stdev,
72
+ }
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_neighborhoods/_stats/_permutation/_permutation
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/stats/_stats/permutation/permutation
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from multiprocessing import Manager, get_context
@@ -16,7 +16,7 @@ from ._test_functions import DISPATCH_TEST_FUNCTIONS
16
16
 
17
17
 
18
18
  def compute_permutation_test(
19
- neighborhoods: csr_matrix,
19
+ clusters: csr_matrix,
20
20
  annotation: csr_matrix,
21
21
  score_metric: str = "sum",
22
22
  null_distribution: str = "network",
@@ -25,10 +25,10 @@ def compute_permutation_test(
25
25
  max_workers: int = 1,
26
26
  ) -> Dict[str, Any]:
27
27
  """
28
- Compute permutation test for enrichment and depletion in neighborhoods.
28
+ Compute permutation test for enrichment and depletion in clusters.
29
29
 
30
30
  Args:
31
- neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
31
+ clusters (csr_matrix): Sparse binary matrix representing clusters.
32
32
  annotation (csr_matrix): Sparse binary matrix representing annotation.
33
33
  score_metric (str, optional): Metric to use for scoring ('sum' or 'stdev'). Defaults to "sum".
34
34
  null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
@@ -41,16 +41,16 @@ def compute_permutation_test(
41
41
  """
42
42
  # Ensure that the matrices are in the correct format and free of NaN values
43
43
  # NOTE: Keep the data type as float32 to avoid locking issues with dot product operations
44
- neighborhoods = neighborhoods.astype(np.float32)
44
+ clusters = clusters.astype(np.float32)
45
45
  annotation = annotation.astype(np.float32)
46
- # Retrieve the appropriate neighborhood score function based on the metric
47
- neighborhood_score_func = DISPATCH_TEST_FUNCTIONS[score_metric]
46
+ # Retrieve the appropriate cluster score function based on the metric
47
+ cluster_score_func = DISPATCH_TEST_FUNCTIONS[score_metric]
48
48
 
49
49
  # Run the permutation test to calculate depletion and enrichment counts
50
50
  counts_depletion, counts_enrichment = _run_permutation_test(
51
- neighborhoods=neighborhoods,
51
+ clusters=clusters,
52
52
  annotation=annotation,
53
- neighborhood_score_func=neighborhood_score_func,
53
+ cluster_score_func=cluster_score_func,
54
54
  null_distribution=null_distribution,
55
55
  num_permutations=num_permutations,
56
56
  random_seed=random_seed,
@@ -68,9 +68,9 @@ def compute_permutation_test(
68
68
 
69
69
 
70
70
  def _run_permutation_test(
71
- neighborhoods: csr_matrix,
71
+ clusters: csr_matrix,
72
72
  annotation: csr_matrix,
73
- neighborhood_score_func: Callable,
73
+ cluster_score_func: Callable,
74
74
  null_distribution: str = "network",
75
75
  num_permutations: int = 1000,
76
76
  random_seed: int = 888,
@@ -80,9 +80,9 @@ def _run_permutation_test(
80
80
  Run the permutation test to calculate depletion and enrichment counts.
81
81
 
82
82
  Args:
83
- neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
83
+ clusters (csr_matrix): Sparse binary matrix representing clusters.
84
84
  annotation (csr_matrix): Sparse binary matrix representing annotation.
85
- neighborhood_score_func (Callable): Function to calculate neighborhood scores.
85
+ cluster_score_func (Callable): Function to calculate cluster scores.
86
86
  null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
87
87
  num_permutations (int, optional): Number of permutations. Defaults to 1000.
88
88
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
@@ -109,16 +109,14 @@ def _run_permutation_test(
109
109
  # Replace NaNs with zeros in the sparse annotation matrix
110
110
  annotation.data[np.isnan(annotation.data)] = 0
111
111
  annotation_matrix_obsv = annotation[idxs]
112
- neighborhoods_matrix_obsv = neighborhoods.T[idxs].T
113
- # Calculate observed neighborhood scores
112
+ clusters_matrix_obsv = clusters.T[idxs].T
113
+ # Calculate observed cluster scores
114
114
  with np.errstate(invalid="ignore", divide="ignore"):
115
- observed_neighborhood_scores = neighborhood_score_func(
116
- neighborhoods_matrix_obsv, annotation_matrix_obsv
117
- )
115
+ observed_cluster_scores = cluster_score_func(clusters_matrix_obsv, annotation_matrix_obsv)
118
116
 
119
117
  # Initialize count matrices for depletion and enrichment
120
- counts_depletion = np.zeros(observed_neighborhood_scores.shape)
121
- counts_enrichment = np.zeros(observed_neighborhood_scores.shape)
118
+ counts_depletion = np.zeros(observed_cluster_scores.shape)
119
+ counts_enrichment = np.zeros(observed_cluster_scores.shape)
122
120
  # Determine the number of permutations to run in each worker process
123
121
  subset_size = num_permutations // max_workers
124
122
  remainder = num_permutations % max_workers
@@ -145,9 +143,9 @@ def _run_permutation_test(
145
143
  (
146
144
  permutation_batches[i], # Pass the batch of precomputed permutations
147
145
  annotation,
148
- neighborhoods_matrix_obsv,
149
- observed_neighborhood_scores,
150
- neighborhood_score_func,
146
+ clusters_matrix_obsv,
147
+ observed_cluster_scores,
148
+ cluster_score_func,
151
149
  num_permutations,
152
150
  progress_counter,
153
151
  max_workers,
@@ -176,9 +174,9 @@ def _run_permutation_test(
176
174
  def _permutation_process_batch(
177
175
  permutations: Union[List, Tuple, np.ndarray],
178
176
  annotation_matrix: csr_matrix,
179
- neighborhoods_matrix_obsv: csr_matrix,
180
- observed_neighborhood_scores: np.ndarray,
181
- neighborhood_score_func: Callable,
177
+ clusters_matrix_obsv: csr_matrix,
178
+ observed_cluster_scores: np.ndarray,
179
+ cluster_score_func: Callable,
182
180
  num_permutations: int,
183
181
  progress_counter: ValueProxy,
184
182
  max_workers: int,
@@ -189,9 +187,9 @@ def _permutation_process_batch(
189
187
  Args:
190
188
  permutations (Union[List, Tuple, np.ndarray]): Permutation batch to process.
191
189
  annotation_matrix (csr_matrix): Sparse binary matrix representing annotation.
192
- neighborhoods_matrix_obsv (csr_matrix): Sparse binary matrix representing observed neighborhoods.
193
- observed_neighborhood_scores (np.ndarray): Observed neighborhood scores.
194
- neighborhood_score_func (Callable): Function to calculate neighborhood scores.
190
+ clusters_matrix_obsv (csr_matrix): Sparse binary matrix representing observed clusters.
191
+ observed_cluster_scores (np.ndarray): Observed cluster scores.
192
+ cluster_score_func (Callable): Function to calculate cluster scores.
195
193
  num_permutations (int): Number of total permutations across all subsets.
196
194
  progress_counter (multiprocessing.managers.ValueProxy): Shared counter for tracking progress.
197
195
  max_workers (int): Number of workers for multiprocessing.
@@ -200,8 +198,8 @@ def _permutation_process_batch(
200
198
  tuple: Local counts of depletion and enrichment.
201
199
  """
202
200
  # Initialize local count matrices for this worker
203
- local_counts_depletion = np.zeros(observed_neighborhood_scores.shape)
204
- local_counts_enrichment = np.zeros(observed_neighborhood_scores.shape)
201
+ local_counts_depletion = np.zeros(observed_cluster_scores.shape)
202
+ local_counts_enrichment = np.zeros(observed_cluster_scores.shape)
205
203
 
206
204
  # Limit the number of threads used by NumPy's BLAS implementation to 1 when more than one worker is used
207
205
  # NOTE: This does not work for Mac M chips due to a bug in the threadpoolctl package
@@ -216,19 +214,19 @@ def _permutation_process_batch(
216
214
  for permuted_idxs in permutations:
217
215
  # Apply precomputed permutation
218
216
  annotation_matrix_permut = annotation_matrix[permuted_idxs]
219
- # Calculate permuted neighborhood scores
217
+ # Calculate permuted cluster scores
220
218
  with np.errstate(invalid="ignore", divide="ignore"):
221
- permuted_neighborhood_scores = neighborhood_score_func(
222
- neighborhoods_matrix_obsv, annotation_matrix_permut
219
+ permuted_cluster_scores = cluster_score_func(
220
+ clusters_matrix_obsv, annotation_matrix_permut
223
221
  )
224
222
 
225
223
  # Update local depletion and enrichment counts
226
224
  local_counts_depletion = np.add(
227
- local_counts_depletion, permuted_neighborhood_scores <= observed_neighborhood_scores
225
+ local_counts_depletion, permuted_cluster_scores <= observed_cluster_scores
228
226
  )
229
227
  local_counts_enrichment = np.add(
230
228
  local_counts_enrichment,
231
- permuted_neighborhood_scores >= observed_neighborhood_scores,
229
+ permuted_cluster_scores >= observed_cluster_scores,
232
230
  )
233
231
 
234
232
  # Update progress
@@ -1,25 +1,25 @@
1
1
  """
2
- risk/_neighborhoods/_stats/_tests
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/stats/_stats/tests
3
+ ~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict
7
7
 
8
8
  import numpy as np
9
9
  from scipy.sparse import csr_matrix
10
- from scipy.stats import binom, chi2, hypergeom, norm
10
+ from scipy.stats import binom, chi2, hypergeom
11
11
 
12
12
 
13
13
  def compute_binom_test(
14
- neighborhoods: csr_matrix,
14
+ clusters: csr_matrix,
15
15
  annotation: csr_matrix,
16
16
  null_distribution: str = "network",
17
17
  ) -> Dict[str, Any]:
18
18
  """
19
- Compute Binomial test for enrichment and depletion in neighborhoods with selectable null distribution.
19
+ Compute Binomial test for enrichment and depletion in clusters with selectable null distribution.
20
20
 
21
21
  Args:
22
- neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
22
+ clusters (csr_matrix): Sparse binary matrix representing clusters.
23
23
  annotation (csr_matrix): Sparse binary matrix representing annotation.
24
24
  null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
25
25
 
@@ -30,10 +30,10 @@ def compute_binom_test(
30
30
  ValueError: If an invalid null_distribution value is provided.
31
31
  """
32
32
  # Get the total number of nodes in the network
33
- total_nodes = neighborhoods.shape[1]
33
+ total_nodes = clusters.shape[1]
34
34
 
35
35
  # Compute sums (remain sparse here)
36
- neighborhood_sizes = neighborhoods.sum(axis=1) # Row sums
36
+ cluster_sizes = clusters.sum(axis=1) # Row sums
37
37
  annotation_totals = annotation.sum(axis=0) # Column sums
38
38
  # Compute probabilities (convert to dense)
39
39
  if null_distribution == "network":
@@ -46,26 +46,26 @@ def compute_binom_test(
46
46
  )
47
47
 
48
48
  # Observed counts (sparse matrix multiplication)
49
- annotated_counts = neighborhoods @ annotation # Sparse result
49
+ annotated_counts = clusters @ annotation # Sparse result
50
50
  annotated_counts_dense = annotated_counts.toarray() # Convert for dense operations
51
51
 
52
52
  # Compute enrichment and depletion p-values
53
- enrichment_pvals = 1 - binom.cdf(annotated_counts_dense - 1, neighborhood_sizes.A, p_values)
54
- depletion_pvals = binom.cdf(annotated_counts_dense, neighborhood_sizes.A, p_values)
53
+ enrichment_pvals = 1 - binom.cdf(annotated_counts_dense - 1, cluster_sizes.A, p_values)
54
+ depletion_pvals = binom.cdf(annotated_counts_dense, cluster_sizes.A, p_values)
55
55
 
56
56
  return {"enrichment_pvals": enrichment_pvals, "depletion_pvals": depletion_pvals}
57
57
 
58
58
 
59
59
  def compute_chi2_test(
60
- neighborhoods: csr_matrix,
60
+ clusters: csr_matrix,
61
61
  annotation: csr_matrix,
62
62
  null_distribution: str = "network",
63
63
  ) -> Dict[str, Any]:
64
64
  """
65
- Compute chi-squared test for enrichment and depletion in neighborhoods with selectable null distribution.
65
+ Compute chi-squared test for enrichment and depletion in clusters with selectable null distribution.
66
66
 
67
67
  Args:
68
- neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
68
+ clusters (csr_matrix): Sparse binary matrix representing clusters.
69
69
  annotation (csr_matrix): Sparse binary matrix representing annotation.
70
70
  null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
71
71
 
@@ -76,12 +76,12 @@ def compute_chi2_test(
76
76
  ValueError: If an invalid null_distribution value is provided.
77
77
  """
78
78
  # Total number of nodes in the network
79
- total_node_count = neighborhoods.shape[0]
79
+ total_node_count = clusters.shape[0]
80
80
 
81
81
  if null_distribution == "network":
82
82
  # Case 1: Use all nodes as the background
83
83
  background_population = total_node_count
84
- neighborhood_sums = neighborhoods.sum(axis=0) # Column sums of neighborhoods
84
+ cluster_sums = clusters.sum(axis=0) # Column sums of clusters
85
85
  annotation_sums = annotation.sum(axis=0) # Column sums of annotations
86
86
  elif null_distribution == "annotation":
87
87
  # Case 2: Only consider nodes with at least one annotation
@@ -89,9 +89,7 @@ def compute_chi2_test(
89
89
  np.ravel(annotation.sum(axis=1)) > 0
90
90
  ) # Row-wise sum to filter nodes with annotations
91
91
  background_population = annotated_nodes.sum() # Total number of annotated nodes
92
- neighborhood_sums = neighborhoods[annotated_nodes].sum(
93
- axis=0
94
- ) # Neighborhood sums for annotated nodes
92
+ cluster_sums = clusters[annotated_nodes].sum(axis=0) # Cluster sums for annotated nodes
95
93
  annotation_sums = annotation[annotated_nodes].sum(
96
94
  axis=0
97
95
  ) # Annotation sums for annotated nodes
@@ -101,13 +99,13 @@ def compute_chi2_test(
101
99
  )
102
100
 
103
101
  # Convert to dense arrays for downstream computations
104
- neighborhood_sums = np.asarray(neighborhood_sums).reshape(-1, 1) # Ensure column vector shape
102
+ cluster_sums = np.asarray(cluster_sums).reshape(-1, 1) # Ensure column vector shape
105
103
  annotation_sums = np.asarray(annotation_sums).reshape(1, -1) # Ensure row vector shape
106
104
 
107
- # Observed values: number of annotated nodes in each neighborhood
108
- observed = neighborhoods.T @ annotation # Shape: (neighborhoods, annotation)
105
+ # Observed values: number of annotated nodes in each cluster
106
+ observed = clusters.T @ annotation # Shape: (clusters, annotation)
109
107
  # Expected values under the null
110
- expected = (neighborhood_sums @ annotation_sums) / background_population
108
+ expected = (cluster_sums @ annotation_sums) / background_population
111
109
  # Chi-squared statistic: sum((observed - expected)^2 / expected)
112
110
  with np.errstate(divide="ignore", invalid="ignore"): # Handle divide-by-zero
113
111
  chi2_stat = np.where(expected > 0, np.power(observed - expected, 2) / expected, 0)
@@ -120,15 +118,15 @@ def compute_chi2_test(
120
118
 
121
119
 
122
120
  def compute_hypergeom_test(
123
- neighborhoods: csr_matrix,
121
+ clusters: csr_matrix,
124
122
  annotation: csr_matrix,
125
123
  null_distribution: str = "network",
126
124
  ) -> Dict[str, Any]:
127
125
  """
128
- Compute hypergeometric test for enrichment and depletion in neighborhoods with selectable null distribution.
126
+ Compute hypergeometric test for enrichment and depletion in clusters with selectable null distribution.
129
127
 
130
128
  Args:
131
- neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
129
+ clusters (csr_matrix): Sparse binary matrix representing clusters.
132
130
  annotation (csr_matrix): Sparse binary matrix representing annotation.
133
131
  null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
134
132
 
@@ -139,10 +137,10 @@ def compute_hypergeom_test(
139
137
  ValueError: If an invalid null_distribution value is provided.
140
138
  """
141
139
  # Get the total number of nodes in the network
142
- total_nodes = neighborhoods.shape[1]
140
+ total_nodes = clusters.shape[1]
143
141
 
144
142
  # Compute sums
145
- neighborhood_sums = neighborhoods.sum(axis=0).A.flatten() # Convert to dense array
143
+ cluster_sums = clusters.sum(axis=0).A.flatten() # Convert to dense array
146
144
  annotation_sums = annotation.sum(axis=0).A.flatten() # Convert to dense array
147
145
 
148
146
  if null_distribution == "network":
@@ -150,7 +148,7 @@ def compute_hypergeom_test(
150
148
  elif null_distribution == "annotation":
151
149
  annotated_nodes = annotation.sum(axis=1).A.flatten() > 0 # Boolean mask
152
150
  background_population = annotated_nodes.sum()
153
- neighborhood_sums = neighborhoods[annotated_nodes].sum(axis=0).A.flatten()
151
+ cluster_sums = clusters[annotated_nodes].sum(axis=0).A.flatten()
154
152
  annotation_sums = annotation[annotated_nodes].sum(axis=0).A.flatten()
155
153
  else:
156
154
  raise ValueError(
@@ -158,19 +156,19 @@ def compute_hypergeom_test(
158
156
  )
159
157
 
160
158
  # Observed counts
161
- annotated_in_neighborhood = neighborhoods.T @ annotation # Sparse result
162
- annotated_in_neighborhood = annotated_in_neighborhood.toarray() # Convert to dense
159
+ annotated_in_cluster = clusters.T @ annotation # Sparse result
160
+ annotated_in_cluster = annotated_in_cluster.toarray() # Convert to dense
163
161
  # Align shapes for broadcasting
164
- neighborhood_sums = neighborhood_sums.reshape(-1, 1)
162
+ cluster_sums = cluster_sums.reshape(-1, 1)
165
163
  annotation_sums = annotation_sums.reshape(1, -1)
166
164
  background_population = np.array(background_population).reshape(1, 1)
167
165
 
168
166
  # Compute hypergeometric p-values
169
167
  depletion_pvals = hypergeom.cdf(
170
- annotated_in_neighborhood, background_population, annotation_sums, neighborhood_sums
168
+ annotated_in_cluster, background_population, annotation_sums, cluster_sums
171
169
  )
172
170
  enrichment_pvals = hypergeom.sf(
173
- annotated_in_neighborhood - 1, background_population, annotation_sums, neighborhood_sums
171
+ annotated_in_cluster - 1, background_population, annotation_sums, cluster_sums
174
172
  )
175
173
 
176
174
  return {"depletion_pvals": depletion_pvals, "enrichment_pvals": enrichment_pvals}