risk-network 0.0.8b27__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +195 -118
  4. risk/annotations/io.py +47 -31
  5. risk/log/__init__.py +4 -2
  6. risk/log/console.py +3 -1
  7. risk/log/{params.py → parameters.py} +17 -42
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +442 -0
  10. risk/neighborhoods/community.py +324 -101
  11. risk/neighborhoods/domains.py +125 -52
  12. risk/neighborhoods/neighborhoods.py +177 -165
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +71 -89
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +200 -0
  17. risk/network/{graph.py → graph/graph.py} +90 -40
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +103 -114
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +9 -8
  23. risk/network/{plot → plotter}/contour.py +27 -24
  24. risk/network/{plot → plotter}/labels.py +73 -78
  25. risk/network/{plot → plotter}/network.py +45 -39
  26. risk/network/{plot → plotter}/plotter.py +23 -17
  27. risk/network/{plot/utils/color.py → plotter/utils/colors.py} +114 -122
  28. risk/network/{plot → plotter}/utils/layout.py +10 -7
  29. risk/risk.py +11 -500
  30. risk/stats/__init__.py +10 -4
  31. risk/stats/permutation/__init__.py +1 -1
  32. risk/stats/permutation/permutation.py +44 -38
  33. risk/stats/permutation/test_functions.py +26 -18
  34. risk/stats/{stats.py → significance.py} +17 -15
  35. risk/stats/stat_tests.py +267 -0
  36. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/METADATA +31 -46
  37. risk_network-0.0.9.dist-info/RECORD +40 -0
  38. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/WHEEL +1 -1
  39. risk/constants.py +0 -31
  40. risk/network/plot/__init__.py +0 -6
  41. risk/stats/hypergeom.py +0 -54
  42. risk/stats/poisson.py +0 -44
  43. risk_network-0.0.8b27.dist-info/RECORD +0 -37
  44. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/LICENSE +0 -0
  45. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/top_level.txt +0 -0
@@ -5,18 +5,19 @@ risk/stats/permutation/permutation
 
 from multiprocessing import get_context, Manager
 from multiprocessing.managers import ValueProxy
-from tqdm import tqdm
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 import numpy as np
+from scipy.sparse import csr_matrix
 from threadpoolctl import threadpool_limits
+from tqdm import tqdm
 
 from risk.stats.permutation.test_functions import DISPATCH_TEST_FUNCTIONS
 
 
 def compute_permutation_test(
-    neighborhoods: np.ndarray,
-    annotations: np.ndarray,
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
     score_metric: str = "sum",
     null_distribution: str = "network",
     num_permutations: int = 1000,
@@ -26,9 +27,9 @@ def compute_permutation_test(
     """Compute permutation test for enrichment and depletion in neighborhoods.
 
     Args:
-        neighborhoods (np.ndarray): Binary matrix representing neighborhoods.
-        annotations (np.ndarray): Binary matrix representing annotations.
-        score_metric (str, optional): Metric to use for scoring ('sum', 'mean', etc.). Defaults to "sum".
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
+        score_metric (str, optional): Metric to use for scoring ('sum' or 'stdev'). Defaults to "sum".
         null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
         num_permutations (int, optional): Number of permutations to run. Defaults to 1000.
         random_seed (int, optional): Seed for random number generation. Defaults to 888.
@@ -38,6 +39,7 @@ def compute_permutation_test(
         Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
     """
     # Ensure that the matrices are in the correct format and free of NaN values
+    # NOTE: Keep the data type as float32 to avoid locking issues with dot product operations
     neighborhoods = neighborhoods.astype(np.float32)
     annotations = annotations.astype(np.float32)
     # Retrieve the appropriate neighborhood score function based on the metric
@@ -65,19 +67,19 @@ def compute_permutation_test(
 
 
 def _run_permutation_test(
-    neighborhoods: np.ndarray,
-    annotations: np.ndarray,
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
     neighborhood_score_func: Callable,
     null_distribution: str = "network",
     num_permutations: int = 1000,
     random_seed: int = 888,
     max_workers: int = 4,
 ) -> tuple:
-    """Run a permutation test to calculate enrichment and depletion counts.
+    """Run the permutation test to calculate depletion and enrichment counts.
 
     Args:
-        neighborhoods (np.ndarray): The neighborhood matrix.
-        annotations (np.ndarray): The annotation matrix.
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
         neighborhood_score_func (Callable): Function to calculate neighborhood scores.
         null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
         num_permutations (int, optional): Number of permutations. Defaults to 1000.
@@ -93,14 +95,14 @@ def _run_permutation_test(
     if null_distribution == "network":
         idxs = range(annotations.shape[0])
     elif null_distribution == "annotations":
-        idxs = np.nonzero(np.sum(~np.isnan(annotations), axis=1))[0]
+        idxs = np.nonzero(annotations.getnnz(axis=1) > 0)[0]
     else:
         raise ValueError(
             "Invalid null_distribution value. Choose either 'network' or 'annotations'."
         )
 
-    # Replace NaNs with zeros in the annotations matrix
-    annotations[np.isnan(annotations)] = 0
+    # Replace NaNs with zeros in the sparse annotations matrix
+    annotations.data[np.isnan(annotations.data)] = 0
     annotation_matrix_obsv = annotations[idxs]
     neighborhoods_matrix_obsv = neighborhoods.T[idxs].T
     # Calculate observed neighborhood scores
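The sparse NaN cleanup above differs from the old dense masking in an important way: `annotations.data` exposes only the stored entries, so the operation is O(nnz) and never densifies the matrix. A minimal standalone sketch of the idiom (toy values, not package data):

```python
import numpy as np
from scipy.sparse import csr_matrix

# Hypothetical annotations matrix with one NaN entry
annotations = csr_matrix(np.array([[1.0, 0.0], [np.nan, 2.0]]))

# Rows with at least one stored entry, as in the 'annotations' null branch
idxs = np.nonzero(annotations.getnnz(axis=1) > 0)[0]
# Zero out NaNs among the stored values only; the sparsity pattern is untouched
annotations.data[np.isnan(annotations.data)] = 0

print(idxs)                   # [0 1]
print(annotations.toarray())  # [[1. 0.]
                              #  [0. 2.]]
```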
@@ -121,28 +123,35 @@ def _run_permutation_test(
     manager = Manager()
     progress_counter = manager.Value("i", 0)
     total_progress = num_permutations
+
+    # Generate precomputed permutations
+    permutations = [rng.permutation(idxs) for _ in range(num_permutations)]
+    # Divide permutations into batches for workers
+    batch_size = subset_size + (1 if remainder > 0 else 0)
+    permutation_batches = [
+        permutations[i * batch_size : (i + 1) * batch_size] for i in range(max_workers)
+    ]
+
     # Execute the permutation test using multiprocessing
     with ctx.Pool(max_workers) as pool:
         with tqdm(total=total_progress, desc="Total progress", position=0) as progress:
             # Prepare parameters for multiprocessing
             params_list = [
                 (
+                    permutation_batches[i],  # Pass the batch of precomputed permutations
                     annotations,
-                    np.array(idxs),
                     neighborhoods_matrix_obsv,
                     observed_neighborhood_scores,
                     neighborhood_score_func,
-                    subset_size + (1 if i < remainder else 0),
                     num_permutations,
                     progress_counter,
                     max_workers,
-                    rng,  # Pass the random number generator to each worker
                 )
                 for i in range(max_workers)
             ]
 
             # Start the permutation process in parallel
-            results = pool.starmap_async(_permutation_process_subset, params_list, chunksize=1)
+            results = pool.starmap_async(_permutation_process_batch, params_list, chunksize=1)
 
             # Update progress bar based on progress_counter
             while not results.ready():
@@ -159,31 +168,27 @@ def _run_permutation_test(
     return counts_depletion, counts_enrichment
 
 
-def _permutation_process_subset(
-    annotation_matrix: np.ndarray,
-    idxs: np.ndarray,
-    neighborhoods_matrix_obsv: np.ndarray,
+def _permutation_process_batch(
+    permutations: Union[List, Tuple, np.ndarray],
+    annotation_matrix: csr_matrix,
+    neighborhoods_matrix_obsv: csr_matrix,
     observed_neighborhood_scores: np.ndarray,
     neighborhood_score_func: Callable,
-    subset_size: int,
     num_permutations: int,
     progress_counter: ValueProxy,
     max_workers: int,
-    rng: np.random.Generator,
 ) -> tuple:
-    """Process a subset of permutations for the permutation test.
+    """Process a batch of permutations in a worker process.
 
     Args:
-        annotation_matrix (np.ndarray): The annotation matrix.
-        idxs (np.ndarray): Indices of valid rows in the matrix.
-        neighborhoods_matrix_obsv (np.ndarray): Observed neighborhoods matrix.
+        permutations (Union[List, Tuple, np.ndarray]): Permutation batch to process.
+        annotation_matrix (csr_matrix): Sparse binary matrix representing annotations.
+        neighborhoods_matrix_obsv (csr_matrix): Sparse binary matrix representing observed neighborhoods.
         observed_neighborhood_scores (np.ndarray): Observed neighborhood scores.
         neighborhood_score_func (Callable): Function to calculate neighborhood scores.
-        subset_size (int): Number of permutations to run in this subset.
         num_permutations (int): Number of total permutations across all subsets.
         progress_counter (multiprocessing.managers.ValueProxy): Shared counter for tracking progress.
         max_workers (int): Number of workers for multiprocessing.
-        rng (np.random.Generator): Random number generator object.
 
     Returns:
         tuple: Local counts of depletion and enrichment.
@@ -192,7 +197,9 @@ def _permutation_process_subset(
     local_counts_depletion = np.zeros(observed_neighborhood_scores.shape)
     local_counts_enrichment = np.zeros(observed_neighborhood_scores.shape)
 
-    # NOTE: Limit the number of threads used by NumPy's BLAS implementation to 1 when more than one worker is used.
+    # Limit the number of threads used by NumPy's BLAS implementation to 1 when more than one worker is used
+    # NOTE: This does not work for Mac M chips due to a bug in the threadpoolctl package
+    # This is currently a known issue and is being addressed by the maintainers [https://github.com/joblib/threadpoolctl/issues/135]
     limits = None if max_workers == 1 else 1
     with threadpool_limits(limits=limits, user_api="blas"):
         # Initialize a local counter for batched progress updates
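For context on the threadpoolctl guard referenced in the note: capping BLAS threads inside each worker prevents `max_workers` processes from each spawning a full thread pool and oversubscribing the CPU. A standalone sketch of the pattern (not package code; sizes are arbitrary):

```python
import numpy as np
from threadpoolctl import threadpool_limits

a = np.random.rand(200, 200)
b = np.random.rand(200, 200)

max_workers = 4
limits = None if max_workers == 1 else 1  # same guard as in the diff
with threadpool_limits(limits=limits, user_api="blas"):
    _ = a @ b  # BLAS-backed matmul runs on a single thread here
```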
@@ -200,16 +207,16 @@
         # Calculate the modulo value based on total permutations for 1/100th frequency updates
         modulo_value = max(1, num_permutations // 100)
 
-        for _ in range(subset_size):
-            # Permute the annotation matrix using the RNG
-            annotation_matrix_permut = annotation_matrix[rng.permutation(idxs)]
+        for permuted_idxs in permutations:
+            # Apply precomputed permutation
+            annotation_matrix_permut = annotation_matrix[permuted_idxs]
             # Calculate permuted neighborhood scores
             with np.errstate(invalid="ignore", divide="ignore"):
                 permuted_neighborhood_scores = neighborhood_score_func(
                     neighborhoods_matrix_obsv, annotation_matrix_permut
                 )
 
-            # Update local depletion and enrichment counts based on permuted scores
+            # Update local depletion and enrichment counts
             local_counts_depletion = np.add(
                 local_counts_depletion, permuted_neighborhood_scores <= observed_neighborhood_scores
             )
@@ -218,9 +225,8 @@
                 permuted_neighborhood_scores >= observed_neighborhood_scores,
             )
 
-            # Update local progress counter
+            # Update progress
             local_progress += 1
-            # Update shared progress counter every 1/100th of total permutations
             if local_progress % modulo_value == 0:
                 progress_counter.value += modulo_value
 
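The net effect of this file's changes: permutations are now drawn once from a single seeded RNG in the parent and shipped to workers in batches, rather than re-drawn inside each worker from a copied RNG, which keeps the draw sequence independent of worker scheduling. A standalone sketch of the batching arithmetic, assuming `subset_size` and `remainder` come from `divmod(num_permutations, max_workers)` earlier in the function (values here are hypothetical):

```python
import numpy as np

num_permutations, max_workers = 10, 4
idxs = np.arange(6)  # hypothetical annotation row indices
rng = np.random.default_rng(888)

subset_size, remainder = divmod(num_permutations, max_workers)
permutations = [rng.permutation(idxs) for _ in range(num_permutations)]
batch_size = subset_size + (1 if remainder > 0 else 0)
batches = [
    permutations[i * batch_size : (i + 1) * batch_size] for i in range(max_workers)
]

print([len(b) for b in batches])  # [3, 3, 3, 1] -- all 10 permutations covered
```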
@@ -4,53 +4,61 @@ risk/stats/permutation/test_functions
 """
 
 import numpy as np
+from scipy.sparse import csr_matrix
 
-# Note: Cython optimizations provided minimal performance benefits.
+# NOTE: Cython optimizations provided minimal performance benefits.
 # The final version with Cython is archived in the `cython_permutation` branch.
 # DISPATCH_TEST_FUNCTIONS can be found at the end of the file.
 
 
 def compute_neighborhood_score_by_sum(
-    neighborhoods_matrix: np.ndarray, annotation_matrix: np.ndarray
+    neighborhoods_matrix: csr_matrix, annotation_matrix: csr_matrix
 ) -> np.ndarray:
-    """Compute the sum of attribute values for each neighborhood.
+    """Compute the sum of attribute values for each neighborhood using sparse matrices.
 
     Args:
-        neighborhoods_matrix (np.ndarray): Binary matrix representing neighborhoods.
-        annotation_matrix (np.ndarray): Matrix representing annotation values.
+        neighborhoods_matrix (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
 
     Returns:
-        np.ndarray: Sum of attribute values for each neighborhood.
+        np.ndarray: Dense array of summed attribute values for each neighborhood.
     """
     # Calculate the neighborhood score as the dot product of neighborhoods and annotations
-    neighborhood_sum = np.dot(neighborhoods_matrix, annotation_matrix)
-    return neighborhood_sum
+    neighborhood_score = neighborhoods_matrix @ annotation_matrix  # Sparse matrix multiplication
+    # Convert the result to a dense array for downstream calculations
+    neighborhood_score_dense = neighborhood_score.toarray()
+    return neighborhood_score_dense
 
 
 def compute_neighborhood_score_by_stdev(
-    neighborhoods_matrix: np.ndarray, annotation_matrix: np.ndarray
+    neighborhoods_matrix: csr_matrix, annotation_matrix: csr_matrix
 ) -> np.ndarray:
-    """Compute the standard deviation of neighborhood scores.
+    """Compute the standard deviation of neighborhood scores for sparse matrices.
 
     Args:
-        neighborhoods_matrix (np.ndarray): Binary matrix representing neighborhoods.
-        annotation_matrix (np.ndarray): Matrix representing annotation values.
+        neighborhoods_matrix (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
 
     Returns:
         np.ndarray: Standard deviation of the neighborhood scores.
     """
     # Calculate the neighborhood score as the dot product of neighborhoods and annotations
-    neighborhood_score = np.dot(neighborhoods_matrix, annotation_matrix)
-    # Calculate the number of elements in each neighborhood
-    N = np.sum(neighborhoods_matrix, axis=1)
+    neighborhood_score = neighborhoods_matrix @ annotation_matrix  # Sparse matrix multiplication
+    # Calculate the number of elements in each neighborhood (sum of rows)
+    N = neighborhoods_matrix.sum(axis=1).A.flatten()  # Convert to 1D array
+    # Avoid division by zero by replacing zeros in N with np.nan temporarily
+    N[N == 0] = np.nan
     # Compute the mean of the neighborhood scores
-    M = neighborhood_score / N[:, None]
+    M = neighborhood_score.multiply(1 / N[:, None]).toarray()  # Sparse element-wise division
     # Compute the mean of squares (EXX) directly using squared annotation matrix
-    EXX = np.dot(neighborhoods_matrix, annotation_matrix**2) / N[:, None]
+    annotation_squared = annotation_matrix.multiply(annotation_matrix)  # Element-wise squaring
+    EXX = (neighborhoods_matrix @ annotation_squared).multiply(1 / N[:, None]).toarray()
     # Calculate variance as EXX - M^2
-    variance = EXX - M**2
+    variance = EXX - np.power(M, 2)
     # Compute the standard deviation as the square root of the variance
     neighborhood_stdev = np.sqrt(variance)
+    # Replace np.nan back with zeros in case N was 0 (no elements in the neighborhood)
+    neighborhood_stdev[np.isnan(neighborhood_stdev)] = 0
     return neighborhood_stdev
 
 
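The reworked stdev routine keeps everything sparse until the final division. To convince yourself the algebra (variance = EXX − M², i.e. E[X²] minus the squared mean) matches a per-neighborhood population standard deviation, here is a small self-contained check with hypothetical toy matrices:

```python
import numpy as np
from scipy.sparse import csr_matrix

rng = np.random.default_rng(0)
nbrs = csr_matrix((rng.random((5, 8)) < 0.5).astype(np.float64))   # toy neighborhoods
annot = csr_matrix((rng.random((8, 3)) < 0.4).astype(np.float64))  # toy annotations

# Sparse path, mirroring compute_neighborhood_score_by_stdev
score = nbrs @ annot
N = nbrs.sum(axis=1).A.flatten()
N[N == 0] = np.nan
M = score.multiply(1 / N[:, None]).toarray()
EXX = (nbrs @ annot.multiply(annot)).multiply(1 / N[:, None]).toarray()
stdev = np.sqrt(EXX - np.power(M, 2))
stdev[np.isnan(stdev)] = 0

# Dense reference: population std of annotation values over each neighborhood's members
dense = np.vstack([
    annot.toarray()[row.toarray().ravel() > 0].std(axis=0) if row.getnnz() else np.zeros(3)
    for row in nbrs
])
assert np.allclose(stdev, dense, atol=1e-6)
```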
@@ -1,6 +1,6 @@
 """
-risk/stats/stats
-~~~~~~~~~~~~~~~~
+risk/stats/significance
+~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from typing import Any, Dict, Union
@@ -44,7 +44,7 @@ def calculate_significance_matrices(
             enrichment_pvals, enrichment_qvals, pval_cutoff=pval_cutoff, fdr_cutoff=fdr_cutoff
         )
         # Compute the enrichment matrix using both q-values and p-values
-        enrichment_matrix = (enrichment_qvals**2) * (enrichment_pvals**0.5)
+        enrichment_matrix = (enrichment_pvals**0.5) * (enrichment_qvals**2)
     else:
         # Compute threshold matrices based on p-value cutoffs only
         depletion_alpha_threshold_matrix = _compute_threshold_matrix(
@@ -62,7 +62,7 @@ def calculate_significance_matrices(
     log_enrichment_matrix = -np.log10(enrichment_matrix)
 
     # Select the appropriate significance matrices based on the specified tail
-    enrichment_matrix, significant_binary_enrichment_matrix = _select_significance_matrices(
+    significance_matrix, significant_binary_significance_matrix = _select_significance_matrices(
         tail,
         log_depletion_matrix,
         depletion_alpha_threshold_matrix,
@@ -71,14 +71,14 @@ def calculate_significance_matrices(
     )
 
     # Filter the enrichment matrix using the binary significance matrix
-    significant_enrichment_matrix = np.where(
-        significant_binary_enrichment_matrix == 1, enrichment_matrix, 0
+    significant_significance_matrix = np.where(
+        significant_binary_significance_matrix == 1, significance_matrix, 0
     )
 
     return {
-        "enrichment_matrix": enrichment_matrix,
-        "significant_binary_enrichment_matrix": significant_binary_enrichment_matrix,
-        "significant_enrichment_matrix": significant_enrichment_matrix,
+        "significance_matrix": significance_matrix,
+        "significant_significance_matrix": significant_significance_matrix,
+        "significant_binary_significance_matrix": significant_binary_significance_matrix,
     }
 
 
@@ -109,15 +109,15 @@ def _select_significance_matrices(
 
     if tail == "left":
         # Select depletion matrix and corresponding alpha threshold for left-tail analysis
-        enrichment_matrix = -log_depletion_matrix
+        significance_matrix = -log_depletion_matrix
         alpha_threshold_matrix = depletion_alpha_threshold_matrix
     elif tail == "right":
         # Select enrichment matrix and corresponding alpha threshold for right-tail analysis
-        enrichment_matrix = log_enrichment_matrix
+        significance_matrix = log_enrichment_matrix
         alpha_threshold_matrix = enrichment_alpha_threshold_matrix
     elif tail == "both":
         # Select the matrix with the highest absolute values while preserving the sign
-        enrichment_matrix = np.where(
+        significance_matrix = np.where(
             np.abs(log_depletion_matrix) >= np.abs(log_enrichment_matrix),
             -log_depletion_matrix,
             log_enrichment_matrix,
@@ -126,13 +126,15 @@
         alpha_threshold_matrix = np.logical_or(
             depletion_alpha_threshold_matrix, enrichment_alpha_threshold_matrix
         )
+    else:
+        raise ValueError("Invalid value for 'tail'. Must be 'left', 'right', or 'both'.")
 
     # Create a binary significance matrix where valid indices meet the alpha threshold
     valid_idxs = ~np.isnan(alpha_threshold_matrix)
-    significant_binary_enrichment_matrix = np.zeros(alpha_threshold_matrix.shape)
-    significant_binary_enrichment_matrix[valid_idxs] = alpha_threshold_matrix[valid_idxs]
+    significant_binary_significance_matrix = np.zeros(alpha_threshold_matrix.shape)
+    significant_binary_significance_matrix[valid_idxs] = alpha_threshold_matrix[valid_idxs]
 
-    return enrichment_matrix, significant_binary_enrichment_matrix
+    return significance_matrix, significant_binary_significance_matrix
 
 
 def _compute_threshold_matrix(
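The rename from "enrichment" to "significance" reflects that the returned matrix is signed: under `tail="both"`, depletion and enrichment compete per cell and the sign records which side won. A toy illustration of that selection (values are illustrative, not package output):

```python
import numpy as np

log_depletion = np.array([[3.0, 0.2], [1.0, 4.0]])   # -log10 depletion p-values
log_enrichment = np.array([[1.0, 2.5], [0.5, 1.0]])  # -log10 enrichment p-values

significance = np.where(
    np.abs(log_depletion) >= np.abs(log_enrichment),
    -log_depletion,   # depletion wins: negative sign marks depletion
    log_enrichment,   # enrichment wins: positive sign marks enrichment
)
print(significance)
# [[-3.   2.5]
#  [-1.  -4. ]]
```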
@@ -0,0 +1,267 @@
+"""
+risk/stats/stat_tests
+~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from typing import Any, Dict
+
+import numpy as np
+from scipy.sparse import csr_matrix
+from scipy.stats import binom
+from scipy.stats import chi2
+from scipy.stats import hypergeom
+from scipy.stats import norm
+from scipy.stats import poisson
+
+
+def compute_binom_test(
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
+    null_distribution: str = "network",
+) -> Dict[str, Any]:
+    """Compute Binomial test for enrichment and depletion in neighborhoods with selectable null distribution.
+
+    Args:
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
+        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
+
+    Returns:
+        Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
+    """
+    # Get the total number of nodes in the network
+    total_nodes = neighborhoods.shape[1]
+
+    # Compute sums (remain sparse here)
+    neighborhood_sizes = neighborhoods.sum(axis=1)  # Row sums
+    annotation_totals = annotations.sum(axis=0)  # Column sums
+    # Compute probabilities (convert to dense)
+    if null_distribution == "network":
+        p_values = (annotation_totals / total_nodes).A.flatten()  # Dense 1D array
+    elif null_distribution == "annotations":
+        p_values = (annotation_totals / annotations.sum()).A.flatten()  # Dense 1D array
+    else:
+        raise ValueError(
+            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
+        )
+
+    # Observed counts (sparse matrix multiplication)
+    annotated_counts = neighborhoods @ annotations  # Sparse result
+    annotated_counts_dense = annotated_counts.toarray()  # Convert for dense operations
+
+    # Compute enrichment and depletion p-values
+    enrichment_pvals = 1 - binom.cdf(annotated_counts_dense - 1, neighborhood_sizes.A, p_values)
+    depletion_pvals = binom.cdf(annotated_counts_dense, neighborhood_sizes.A, p_values)
+
+    return {"enrichment_pvals": enrichment_pvals, "depletion_pvals": depletion_pvals}
+
+
+def compute_chi2_test(
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
+    null_distribution: str = "network",
+) -> Dict[str, Any]:
+    """Compute chi-squared test for enrichment and depletion in neighborhoods with selectable null distribution.
+
+    Args:
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
+        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
+
+    Returns:
+        Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
+    """
+    # Total number of nodes in the network
+    total_node_count = neighborhoods.shape[0]
+
+    if null_distribution == "network":
+        # Case 1: Use all nodes as the background
+        background_population = total_node_count
+        neighborhood_sums = neighborhoods.sum(axis=0)  # Column sums of neighborhoods
+        annotation_sums = annotations.sum(axis=0)  # Column sums of annotations
+    elif null_distribution == "annotations":
+        # Case 2: Only consider nodes with at least one annotation
+        annotated_nodes = (
+            np.ravel(annotations.sum(axis=1)) > 0
+        )  # Row-wise sum to filter nodes with annotations
+        background_population = annotated_nodes.sum()  # Total number of annotated nodes
+        neighborhood_sums = neighborhoods[annotated_nodes].sum(
+            axis=0
+        )  # Neighborhood sums for annotated nodes
+        annotation_sums = annotations[annotated_nodes].sum(
+            axis=0
+        )  # Annotation sums for annotated nodes
+    else:
+        raise ValueError(
+            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
+        )
+
+    # Convert to dense arrays for downstream computations
+    neighborhood_sums = np.asarray(neighborhood_sums).reshape(-1, 1)  # Ensure column vector shape
+    annotation_sums = np.asarray(annotation_sums).reshape(1, -1)  # Ensure row vector shape
+
+    # Observed values: number of annotated nodes in each neighborhood
+    observed = neighborhoods.T @ annotations  # Shape: (neighborhoods, annotations)
+    # Expected values under the null
+    expected = (neighborhood_sums @ annotation_sums) / background_population
+    # Chi-squared statistic: sum((observed - expected)^2 / expected)
+    with np.errstate(divide="ignore", invalid="ignore"):  # Handle divide-by-zero
+        chi2_stat = np.where(expected > 0, np.power(observed - expected, 2) / expected, 0)
+
+    # Compute p-values for enrichment (upper tail) and depletion (lower tail)
+    enrichment_pvals = chi2.sf(chi2_stat, df=1)  # Survival function for upper tail
+    depletion_pvals = chi2.cdf(chi2_stat, df=1)  # Cumulative distribution for lower tail
+
+    return {"depletion_pvals": depletion_pvals, "enrichment_pvals": enrichment_pvals}
+
+
+def compute_hypergeom_test(
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
+    null_distribution: str = "network",
+) -> Dict[str, Any]:
+    """
+    Compute hypergeometric test for enrichment and depletion in neighborhoods with selectable null distribution.
+
+    Args:
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
+        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
+
+    Returns:
+        Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
+    """
+    # Get the total number of nodes in the network
+    total_nodes = neighborhoods.shape[1]
+
+    # Compute sums
+    neighborhood_sums = neighborhoods.sum(axis=0).A.flatten()  # Convert to dense array
+    annotation_sums = annotations.sum(axis=0).A.flatten()  # Convert to dense array
+
+    if null_distribution == "network":
+        background_population = total_nodes
+    elif null_distribution == "annotations":
+        annotated_nodes = annotations.sum(axis=1).A.flatten() > 0  # Boolean mask
+        background_population = annotated_nodes.sum()
+        neighborhood_sums = neighborhoods[annotated_nodes].sum(axis=0).A.flatten()
+        annotation_sums = annotations[annotated_nodes].sum(axis=0).A.flatten()
+    else:
+        raise ValueError(
+            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
+        )
+
+    # Observed counts
+    annotated_in_neighborhood = neighborhoods.T @ annotations  # Sparse result
+    annotated_in_neighborhood = annotated_in_neighborhood.toarray()  # Convert to dense
+    # Align shapes for broadcasting
+    neighborhood_sums = neighborhood_sums.reshape(-1, 1)
+    annotation_sums = annotation_sums.reshape(1, -1)
+    background_population = np.array(background_population).reshape(1, 1)
+
+    # Compute hypergeometric p-values
+    depletion_pvals = hypergeom.cdf(
+        annotated_in_neighborhood, background_population, annotation_sums, neighborhood_sums
+    )
+    enrichment_pvals = hypergeom.sf(
+        annotated_in_neighborhood - 1, background_population, annotation_sums, neighborhood_sums
+    )
+
+    return {"depletion_pvals": depletion_pvals, "enrichment_pvals": enrichment_pvals}
+
+
+def compute_poisson_test(
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
+    null_distribution: str = "network",
+) -> Dict[str, Any]:
+    """
+    Compute Poisson test for enrichment and depletion in neighborhoods with selectable null distribution.
+
+    Args:
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
+        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
+
+    Returns:
+        Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
+    """
+    # Matrix multiplication to get the number of annotated nodes in each neighborhood
+    annotated_in_neighborhood = neighborhoods @ annotations  # Sparse result
+    # Convert annotated counts to dense for downstream calculations
+    annotated_in_neighborhood_dense = annotated_in_neighborhood.toarray()
+
+    # Compute lambda_expected based on the chosen null distribution
+    if null_distribution == "network":
+        # Use the mean across neighborhoods (axis=1)
+        lambda_expected = np.mean(annotated_in_neighborhood_dense, axis=1, keepdims=True)
+    elif null_distribution == "annotations":
+        # Use the mean across annotations (axis=0)
+        lambda_expected = np.mean(annotated_in_neighborhood_dense, axis=0, keepdims=True)
+    else:
+        raise ValueError(
+            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
+        )
+
+    # Compute p-values for enrichment and depletion using Poisson distribution
+    enrichment_pvals = 1 - poisson.cdf(annotated_in_neighborhood_dense - 1, lambda_expected)
+    depletion_pvals = poisson.cdf(annotated_in_neighborhood_dense, lambda_expected)
+
+    return {"enrichment_pvals": enrichment_pvals, "depletion_pvals": depletion_pvals}
+
+
+def compute_zscore_test(
+    neighborhoods: csr_matrix,
+    annotations: csr_matrix,
+    null_distribution: str = "network",
+) -> Dict[str, Any]:
+    """
+    Compute z-score test for enrichment and depletion in neighborhoods with selectable null distribution.
+
+    Args:
+        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
+        annotations (csr_matrix): Sparse binary matrix representing annotations.
+        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
+
+    Returns:
+        Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
+    """
+    # Total number of nodes in the network
+    total_node_count = neighborhoods.shape[1]
+
+    # Compute sums
+    if null_distribution == "network":
+        background_population = total_node_count
+        neighborhood_sums = neighborhoods.sum(axis=0).A.flatten()  # Dense column sums
+        annotation_sums = annotations.sum(axis=0).A.flatten()  # Dense row sums
+    elif null_distribution == "annotations":
+        annotated_nodes = annotations.sum(axis=1).A.flatten() > 0  # Dense boolean mask
+        background_population = annotated_nodes.sum()
+        neighborhood_sums = neighborhoods[annotated_nodes].sum(axis=0).A.flatten()
+        annotation_sums = annotations[annotated_nodes].sum(axis=0).A.flatten()
+    else:
+        raise ValueError(
+            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
+        )
+
+    # Observed values
+    observed = (neighborhoods.T @ annotations).toarray()  # Convert sparse result to dense
+    # Expected values under the null
+    neighborhood_sums = neighborhood_sums.reshape(-1, 1)  # Ensure correct shape
+    annotation_sums = annotation_sums.reshape(1, -1)  # Ensure correct shape
+    expected = (neighborhood_sums @ annotation_sums) / background_population
+
+    # Standard deviation under the null
+    std_dev = np.sqrt(
+        expected
+        * (1 - annotation_sums / background_population)
+        * (1 - neighborhood_sums / background_population)
+    )
+    std_dev[std_dev == 0] = np.nan  # Avoid division by zero
+    # Compute z-scores
+    z_scores = (observed - expected) / std_dev
+
+    # Convert z-scores to depletion and enrichment p-values
+    enrichment_pvals = norm.sf(z_scores)  # Upper tail
+    depletion_pvals = norm.cdf(z_scores)  # Lower tail
+
+    return {"depletion_pvals": depletion_pvals, "enrichment_pvals": enrichment_pvals}
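The five new tests share one calling convention: sparse neighborhoods and annotations in, a dict of dense `depletion_pvals` / `enrichment_pvals` arrays out. A usage sketch with toy inputs (a 4-node network and 2 annotation terms; matrix contents are hypothetical, not from the package docs):

```python
import numpy as np
from scipy.sparse import csr_matrix

from risk.stats.stat_tests import compute_hypergeom_test

# Rows = nodes, columns = nodes: entry (i, j) = 1 if node j is in node i's neighborhood
neighborhoods = csr_matrix(np.array([
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 1, 1, 1],
    [0, 0, 1, 1],
], dtype=np.float32))
# Rows = nodes, columns = annotation terms
annotations = csr_matrix(np.array([
    [1, 0],
    [1, 0],
    [0, 1],
    [0, 1],
], dtype=np.float32))

result = compute_hypergeom_test(neighborhoods, annotations, null_distribution="network")
print(result["enrichment_pvals"].shape)  # (4, 2): one p-value per (neighborhood, term)
```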