risk-network 0.0.3b4__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
risk/risk.py CHANGED
@@ -6,6 +6,7 @@ risk/risk
 from typing import Any, Dict

 import networkx as nx
+import numpy as np
 import pandas as pd

 from risk.annotations import AnnotationsIO, define_top_annotations
@@ -17,7 +18,12 @@ from risk.neighborhoods import (
     trim_domains_and_top_annotations,
 )
 from risk.network import NetworkIO, NetworkGraph, NetworkPlotter
-from risk.stats import compute_permutation, calculate_significance_matrices
+from risk.stats import (
+    calculate_significance_matrices,
+    compute_fisher_exact_test,
+    compute_hypergeom_test,
+    compute_permutation_test,
+)


 class RISK(NetworkIO, AnnotationsIO):
@@ -27,85 +33,39 @@ class RISK(NetworkIO, AnnotationsIO):
     and performing network-based statistical analysis, such as neighborhood significance testing.
     """

-    def __init__(
-        self,
-        compute_sphere: bool = True,
-        surface_depth: float = 0.0,
-        distance_metric: str = "dijkstra",
-        louvain_resolution: float = 0.1,
-        min_edges_per_node: int = 0,
-        edge_length_threshold: float = 0.5,
-        include_edge_weight: bool = True,
-        weight_label: str = "weight",
-    ):
-        """Initialize the RISK class with configuration settings.
-
-        Args:
-            compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
-            surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
-            distance_metric (str, optional): Distance metric to use in network analysis. Defaults to "dijkstra".
-            louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
-            min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
-            edge_length_threshold (float, optional): Edge length threshold for analysis. Defaults to 0.5.
-            include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
-            weight_label (str, optional): Label for edge weights. Defaults to "weight".
-        """
+    def __init__(self, *args, **kwargs):
+        """Initialize the RISK class with configuration settings."""
         # Initialize and log network parameters
         params.initialize()
-        params.log_network(
-            compute_sphere=compute_sphere,
-            surface_depth=surface_depth,
-            distance_metric=distance_metric,
-            louvain_resolution=louvain_resolution,
-            min_edges_per_node=min_edges_per_node,
-            edge_length_threshold=edge_length_threshold,
-            include_edge_weight=include_edge_weight,
-            weight_label=weight_label,
-        )
-        # Initialize parent classes
-        NetworkIO.__init__(
-            self,
-            compute_sphere=compute_sphere,
-            surface_depth=surface_depth,
-            distance_metric=distance_metric,
-            louvain_resolution=louvain_resolution,
-            min_edges_per_node=min_edges_per_node,
-            edge_length_threshold=edge_length_threshold,
-            include_edge_weight=include_edge_weight,
-            weight_label=weight_label,
-        )
-        AnnotationsIO.__init__(self)
-
-        # Set class attributes
-        self.compute_sphere = compute_sphere
-        self.surface_depth = surface_depth
-        self.distance_metric = distance_metric
-        self.louvain_resolution = louvain_resolution
-        self.min_edges_per_node = min_edges_per_node
-        self.edge_length_threshold = edge_length_threshold
-        self.include_edge_weight = include_edge_weight
-        self.weight_label = weight_label
+        # Initialize the parent classes
+        super().__init__(*args, **kwargs)

     @property
     def params(self):
         """Access the logged parameters."""
         return params

-    def load_neighborhoods(
+    def load_neighborhoods_by_permutation(
         self,
         network: nx.Graph,
         annotations: Dict[str, Any],
+        distance_metric: str = "dijkstra",
+        louvain_resolution: float = 0.1,
+        edge_length_threshold: float = 0.5,
         score_metric: str = "sum",
         null_distribution: str = "network",
         num_permutations: int = 1000,
         random_seed: int = 888,
         max_workers: int = 1,
     ) -> Dict[str, Any]:
-        """Load significant neighborhoods for the network.
+        """Load significant neighborhoods for the network using the permutation test.

         Args:
             network (nx.Graph): The network graph.
             annotations (pd.DataFrame): The matrix of annotations associated with the network.
+            distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
+            louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
+            edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
             score_metric (str, optional): Scoring metric for neighborhood significance. Defaults to "sum".
             null_distribution (str, optional): Distribution used for permutation tests. Defaults to "network".
             num_permutations (int, optional): Number of permutations for significance testing. Defaults to 1000.
@@ -118,6 +78,10 @@ class RISK(NetworkIO, AnnotationsIO):
         print_header("Running permutation test")
         # Log neighborhood analysis parameters
         params.log_neighborhoods(
+            distance_metric=distance_metric,
+            louvain_resolution=louvain_resolution,
+            edge_length_threshold=edge_length_threshold,
+            statistical_test_function="permutation",
             score_metric=score_metric,
             null_distribution=null_distribution,
             num_permutations=num_permutations,
@@ -125,27 +89,22 @@ class RISK(NetworkIO, AnnotationsIO):
             max_workers=max_workers,
         )

-        # Display the chosen distance metric
-        if self.distance_metric == "louvain":
-            for_print_distance_metric = f"louvain (resolution={self.louvain_resolution})"
-        else:
-            for_print_distance_metric = self.distance_metric
-        print(f"Distance metric: '{for_print_distance_metric}'")
-        # Compute neighborhoods based on the network and distance metric
-        neighborhoods = get_network_neighborhoods(
+        # Load neighborhoods based on the network and distance metric
+        neighborhoods = self._load_neighborhoods(
             network,
-            self.distance_metric,
-            self.edge_length_threshold,
-            louvain_resolution=self.louvain_resolution,
+            distance_metric,
+            louvain_resolution=louvain_resolution,
+            edge_length_threshold=edge_length_threshold,
             random_seed=random_seed,
         )

         # Log and display permutation test settings
-        print(f"Null distribution: '{null_distribution}'")
         print(f"Neighborhood scoring metric: '{score_metric}'")
+        print(f"Null distribution: '{null_distribution}'")
         print(f"Number of permutations: {num_permutations}")
-        # Run the permutation test to compute neighborhood significance
-        neighborhood_significance = compute_permutation(
+        print(f"Maximum workers: {max_workers}")
+        # Run permutation test to compute neighborhood significance
+        neighborhood_significance = compute_permutation_test(
             neighborhoods=neighborhoods,
             annotations=annotations["matrix"],
             score_metric=score_metric,
@@ -157,6 +116,116 @@ class RISK(NetworkIO, AnnotationsIO):

         return neighborhood_significance

+    def load_neighborhoods_by_fisher_exact(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        distance_metric: str = "dijkstra",
+        louvain_resolution: float = 0.1,
+        edge_length_threshold: float = 0.5,
+        random_seed: int = 888,
+        max_workers: int = 1,
+    ) -> Dict[str, Any]:
+        """Load significant neighborhoods for the network using Fisher's exact test.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (pd.DataFrame): The matrix of annotations associated with the network.
+            distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
+            louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
+            edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
+            random_seed (int, optional): Seed for random number generation. Defaults to 888.
+            max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.
+
+        Returns:
+            dict: Computed significance of neighborhoods.
+        """
+        print_header("Running Fisher's exact test")
+        # Log neighborhood analysis parameters
+        params.log_neighborhoods(
+            distance_metric=distance_metric,
+            louvain_resolution=louvain_resolution,
+            edge_length_threshold=edge_length_threshold,
+            statistical_test_function="fisher_exact",
+            random_seed=random_seed,
+            max_workers=max_workers,
+        )
+
+        # Load neighborhoods based on the network and distance metric
+        neighborhoods = self._load_neighborhoods(
+            network,
+            distance_metric,
+            louvain_resolution=louvain_resolution,
+            edge_length_threshold=edge_length_threshold,
+            random_seed=random_seed,
+        )
+
+        # Log and display Fisher's exact test settings
+        print(f"Maximum workers: {max_workers}")
+        # Run Fisher's exact test to compute neighborhood significance
+        neighborhood_significance = compute_fisher_exact_test(
+            neighborhoods=neighborhoods,
+            annotations=annotations["matrix"],
+            max_workers=max_workers,
+        )
+
+        return neighborhood_significance
+
+    def load_neighborhoods_by_hypergeom(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        distance_metric: str = "dijkstra",
+        louvain_resolution: float = 0.1,
+        edge_length_threshold: float = 0.5,
+        random_seed: int = 888,
+        max_workers: int = 1,
+    ) -> Dict[str, Any]:
+        """Load significant neighborhoods for the network using the hypergeometric test.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (pd.DataFrame): The matrix of annotations associated with the network.
+            distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
+            louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
+            edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
+            random_seed (int, optional): Seed for random number generation. Defaults to 888.
+            max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.
+
+        Returns:
+            dict: Computed significance of neighborhoods.
+        """
+        print_header("Running hypergeometric test")
+        # Log neighborhood analysis parameters
+        params.log_neighborhoods(
+            distance_metric=distance_metric,
+            louvain_resolution=louvain_resolution,
+            edge_length_threshold=edge_length_threshold,
+            statistical_test_function="hypergeom",
+            random_seed=random_seed,
+            max_workers=max_workers,
+        )
+
+        # Load neighborhoods based on the network and distance metric
+        neighborhoods = self._load_neighborhoods(
+            network,
+            distance_metric,
+            louvain_resolution=louvain_resolution,
+            edge_length_threshold=edge_length_threshold,
+            random_seed=random_seed,
+        )
+
+        # Log and display hypergeometric test settings
+        print(f"Maximum workers: {max_workers}")
+        # Run hypergeometric test to compute neighborhood significance
+        neighborhood_significance = compute_hypergeom_test(
+            neighborhoods=neighborhoods,
+            annotations=annotations["matrix"],
+            max_workers=max_workers,
+        )
+
+        return neighborhood_significance
+
     def load_graph(
         self,
         network: nx.Graph,
@@ -180,7 +249,7 @@ class RISK(NetworkIO, AnnotationsIO):
             annotations (pd.DataFrame): DataFrame containing annotation data for the network.
             neighborhoods (dict): Neighborhood enrichment data.
             tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
-            pval_cutoff (float, optional): P-value cutoff for significance. Defaults to 0.01.
+            pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
             fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
             impute_depth (int, optional): Depth for imputing neighbors. Defaults to 1.
             prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
@@ -208,7 +277,7 @@ class RISK(NetworkIO, AnnotationsIO):
             max_cluster_size=max_cluster_size,
         )

-        print(f"P-value cutoff: {pval_cutoff}")
+        print(f"p-value cutoff: {pval_cutoff}")
         print(f"FDR BH cutoff: {fdr_cutoff}")
         print(
             f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
@@ -306,6 +375,7 @@ class RISK(NetworkIO, AnnotationsIO):
             outline_color=outline_color,
             outline_scale=outline_scale,
         )
+
         # Initialize and return a NetworkPlotter object
         return NetworkPlotter(
             graph,
@@ -316,6 +386,48 @@ class RISK(NetworkIO, AnnotationsIO):
             outline_scale=outline_scale,
         )

+    def _load_neighborhoods(
+        self,
+        network: nx.Graph,
+        distance_metric: str = "dijkstra",
+        louvain_resolution: float = 0.1,
+        edge_length_threshold: float = 0.5,
+        random_seed: int = 888,
+    ) -> np.ndarray:
+        """Load significant neighborhoods for the network.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (pd.DataFrame): The matrix of annotations associated with the network.
+            distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
+            louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
+            edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
+            random_seed (int, optional): Seed for random number generation. Defaults to 888.
+
+        Returns:
+            np.ndarray: Neighborhood matrix calculated based on the selected distance metric.
+        """
+        # Display the chosen distance metric
+        if distance_metric == "louvain":
+            for_print_distance_metric = f"louvain (resolution={louvain_resolution})"
+        else:
+            for_print_distance_metric = distance_metric
+        # Log and display neighborhood settings
+        print(f"Distance metric: '{for_print_distance_metric}'")
+        print(f"Edge length threshold: {edge_length_threshold}")
+        print(f"Random seed: {random_seed}")
+
+        # Compute neighborhoods based on the network and distance metric
+        neighborhoods = get_network_neighborhoods(
+            network,
+            distance_metric,
+            edge_length_threshold,
+            louvain_resolution=louvain_resolution,
+            random_seed=random_seed,
+        )
+
+        return neighborhoods
+
     def _define_top_annotations(
         self,
         network: nx.Graph,
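
For orientation, here is a minimal usage sketch of the reworked risk/risk.py API. The class name, the three per-test loader methods, and their parameters are taken from the diff above; everything else (how the network and annotations inputs are produced) is an assumption for illustration only, not part of this diff.

    # Hypothetical driver script; the network/annotations inputs are stand-ins,
    # since the NetworkIO/AnnotationsIO loader methods are outside this diff.
    risk = RISK()  # 0.0.4: network configuration no longer lives on the constructor

    network = ...      # an nx.Graph prepared through the NetworkIO interface
    annotations = ...  # a dict carrying a "matrix" key, per the calls above

    # Each statistical test is now its own entry point, and the neighborhood
    # parameters travel with the call rather than with the RISK instance:
    neighborhoods = risk.load_neighborhoods_by_hypergeom(
        network,
        annotations,
        distance_metric="dijkstra",
        louvain_resolution=0.1,
        edge_length_threshold=0.5,
        random_seed=888,
        max_workers=4,
    )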
risk/stats/__init__.py CHANGED
@@ -3,4 +3,7 @@ risk/stats
 ~~~~~~~~~~
 """

-from .stats import calculate_significance_matrices, compute_permutation
+from .stats import calculate_significance_matrices
+from .fisher_exact import compute_fisher_exact_test
+from .hypergeom import compute_hypergeom_test
+from .permutation import compute_permutation_test
risk/stats/fisher_exact.py ADDED
@@ -0,0 +1,132 @@
+"""
+risk/stats/fisher_exact
+~~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from multiprocessing import get_context, Manager
+from tqdm import tqdm
+from typing import Any, Dict
+
+import numpy as np
+from scipy.stats import fisher_exact
+
+
+def compute_fisher_exact_test(
+    neighborhoods: np.ndarray,
+    annotations: np.ndarray,
+    max_workers: int = 4,
+) -> Dict[str, Any]:
+    """Compute Fisher's exact test for enrichment and depletion in neighborhoods.
+
+    Args:
+        neighborhoods (np.ndarray): Binary matrix representing neighborhoods.
+        annotations (np.ndarray): Binary matrix representing annotations.
+        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 4.
+
+    Returns:
+        dict: Dictionary containing depletion and enrichment p-values.
+    """
+    # Ensure that the matrices are binary (boolean) and free of NaN values
+    neighborhoods = neighborhoods.astype(bool)  # Convert to boolean
+    annotations = annotations.astype(bool)  # Convert to boolean
+
+    # Initialize the process of calculating p-values using multiprocessing
+    ctx = get_context("spawn")
+    manager = Manager()
+    progress_counter = manager.Value("i", 0)
+    total_tasks = neighborhoods.shape[1] * annotations.shape[1]
+
+    # Calculate the workload per worker
+    chunk_size = total_tasks // max_workers
+    remainder = total_tasks % max_workers
+
+    # Execute the Fisher's exact test using multiprocessing
+    with ctx.Pool(max_workers) as pool:
+        with tqdm(total=total_tasks, desc="Total progress", position=0) as progress:
+            params_list = []
+            start_idx = 0
+            for i in range(max_workers):
+                end_idx = start_idx + chunk_size + (1 if i < remainder else 0)
+                params_list.append(
+                    (neighborhoods, annotations, start_idx, end_idx, progress_counter)
+                )
+                start_idx = end_idx
+
+            # Start the Fisher's exact test process in parallel
+            results = pool.starmap_async(_fisher_exact_process_subset, params_list, chunksize=1)
+
+            # Update progress bar based on progress_counter
+            while not results.ready():
+                progress.update(progress_counter.value - progress.n)
+                results.wait(0.05)  # Wait for 50ms
+            # Ensure progress bar reaches 100%
+            progress.update(total_tasks - progress.n)
+
+    # Accumulate results from each worker
+    depletion_pvals, enrichment_pvals = [], []
+    for dp, ep in results.get():
+        depletion_pvals.extend(dp)
+        enrichment_pvals.extend(ep)
+
+    # Reshape the results back into arrays with the appropriate dimensions
+    depletion_pvals = np.array(depletion_pvals).reshape(
+        neighborhoods.shape[1], annotations.shape[1]
+    )
+    enrichment_pvals = np.array(enrichment_pvals).reshape(
+        neighborhoods.shape[1], annotations.shape[1]
+    )
+
+    return {
+        "depletion_pvals": depletion_pvals,
+        "enrichment_pvals": enrichment_pvals,
+    }
+
+
+def _fisher_exact_process_subset(
+    neighborhoods: np.ndarray,
+    annotations: np.ndarray,
+    start_idx: int,
+    end_idx: int,
+    progress_counter,
+) -> tuple:
+    """Process a subset of neighborhoods using Fisher's exact test.
+
+    Args:
+        neighborhoods (np.ndarray): The full neighborhood matrix.
+        annotations (np.ndarray): The annotation matrix.
+        start_idx (int): Starting index of the neighborhood-annotation pairs to process.
+        end_idx (int): Ending index of the neighborhood-annotation pairs to process.
+        progress_counter: Shared counter for tracking progress.
+
+    Returns:
+        tuple: Local p-values for depletion and enrichment.
+    """
+    # Initialize lists to store p-values for depletion and enrichment
+    depletion_pvals = []
+    enrichment_pvals = []
+    # Process the subset of tasks assigned to this worker
+    for idx in range(start_idx, end_idx):
+        i = idx // annotations.shape[1]  # Neighborhood index
+        j = idx % annotations.shape[1]  # Annotation index
+
+        neighborhood = neighborhoods[:, i]
+        annotation = annotations[:, j]
+
+        # Calculate the contingency table values
+        TP = np.sum(neighborhood & annotation)
+        FP = np.sum(neighborhood & ~annotation)
+        FN = np.sum(~neighborhood & annotation)
+        TN = np.sum(~neighborhood & ~annotation)
+        table = np.array([[TP, FP], [FN, TN]])
+
+        # Perform Fisher's exact test for depletion (alternative='less')
+        _, p_value_depletion = fisher_exact(table, alternative="less")
+        depletion_pvals.append(p_value_depletion)
+        # Perform Fisher's exact test for enrichment (alternative='greater')
+        _, p_value_enrichment = fisher_exact(table, alternative="greater")
+        enrichment_pvals.append(p_value_enrichment)
+
+        # Update the shared progress counter
+        progress_counter.value += 1
+
+    return depletion_pvals, enrichment_pvals
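
To make the contingency-table layout above concrete, here is a small self-contained check of a single neighborhood-annotation pair. It uses only numpy and scipy.stats.fisher_exact, both already imported by the new module; the node counts are made up for illustration.

    import numpy as np
    from scipy.stats import fisher_exact

    # One hypothetical neighborhood/annotation pair over 8 nodes.
    neighborhood = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=bool)
    annotation = np.array([1, 1, 1, 0, 0, 0, 0, 1], dtype=bool)

    # Same table construction as _fisher_exact_process_subset above.
    TP = np.sum(neighborhood & annotation)    # in both: 3
    FP = np.sum(neighborhood & ~annotation)   # neighborhood only: 1
    FN = np.sum(~neighborhood & annotation)   # annotation only: 1
    TN = np.sum(~neighborhood & ~annotation)  # in neither: 3
    table = np.array([[TP, FP], [FN, TN]])

    # One-sided tests, matching the worker: "less" probes depletion,
    # "greater" probes enrichment of the annotation in the neighborhood.
    _, p_depletion = fisher_exact(table, alternative="less")
    _, p_enrichment = fisher_exact(table, alternative="greater")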
risk/stats/hypergeom.py ADDED
@@ -0,0 +1,131 @@
+"""
+risk/stats/hypergeom
+~~~~~~~~~~~~~~~~~~~~
+"""
+
+from multiprocessing import get_context, Manager
+from tqdm import tqdm
+from typing import Any, Dict
+
+import numpy as np
+from scipy.stats import hypergeom
+
+
+def compute_hypergeom_test(
+    neighborhoods: np.ndarray,
+    annotations: np.ndarray,
+    max_workers: int = 4,
+) -> Dict[str, Any]:
+    """Compute hypergeometric test for enrichment and depletion in neighborhoods.
+
+    Args:
+        neighborhoods (np.ndarray): Binary matrix representing neighborhoods.
+        annotations (np.ndarray): Binary matrix representing annotations.
+        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 4.
+
+    Returns:
+        dict: Dictionary containing depletion and enrichment p-values.
+    """
+    # Ensure that the matrices are binary (boolean) and free of NaN values
+    neighborhoods = neighborhoods.astype(bool)  # Convert to boolean
+    annotations = annotations.astype(bool)  # Convert to boolean
+
+    # Initialize the process of calculating p-values using multiprocessing
+    ctx = get_context("spawn")
+    manager = Manager()
+    progress_counter = manager.Value("i", 0)
+    total_tasks = neighborhoods.shape[1] * annotations.shape[1]
+
+    # Calculate the workload per worker
+    chunk_size = total_tasks // max_workers
+    remainder = total_tasks % max_workers
+
+    # Execute the hypergeometric test using multiprocessing
+    with ctx.Pool(max_workers) as pool:
+        with tqdm(total=total_tasks, desc="Total progress", position=0) as progress:
+            params_list = []
+            start_idx = 0
+            for i in range(max_workers):
+                end_idx = start_idx + chunk_size + (1 if i < remainder else 0)
+                params_list.append(
+                    (neighborhoods, annotations, start_idx, end_idx, progress_counter)
+                )
+                start_idx = end_idx
+
+            # Start the hypergeometric test process in parallel
+            results = pool.starmap_async(_hypergeom_process_subset, params_list, chunksize=1)
+
+            # Update progress bar based on progress_counter
+            while not results.ready():
+                progress.update(progress_counter.value - progress.n)
+                results.wait(0.05)  # Wait for 50ms
+            # Ensure progress bar reaches 100%
+            progress.update(total_tasks - progress.n)
+
+    # Accumulate results from each worker
+    depletion_pvals, enrichment_pvals = [], []
+    for dp, ep in results.get():
+        depletion_pvals.extend(dp)
+        enrichment_pvals.extend(ep)
+
+    # Reshape the results back into arrays with the appropriate dimensions
+    depletion_pvals = np.array(depletion_pvals).reshape(
+        neighborhoods.shape[1], annotations.shape[1]
+    )
+    enrichment_pvals = np.array(enrichment_pvals).reshape(
+        neighborhoods.shape[1], annotations.shape[1]
+    )
+
+    return {
+        "depletion_pvals": depletion_pvals,
+        "enrichment_pvals": enrichment_pvals,
+    }
+
+
+def _hypergeom_process_subset(
+    neighborhoods: np.ndarray,
+    annotations: np.ndarray,
+    start_idx: int,
+    end_idx: int,
+    progress_counter,
+) -> tuple:
+    """Process a subset of neighborhoods using the hypergeometric test.
+
+    Args:
+        neighborhoods (np.ndarray): The full neighborhood matrix.
+        annotations (np.ndarray): The annotation matrix.
+        start_idx (int): Starting index of the neighborhood-annotation pairs to process.
+        end_idx (int): Ending index of the neighborhood-annotation pairs to process.
+        progress_counter: Shared counter for tracking progress.
+
+    Returns:
+        tuple: Local p-values for depletion and enrichment.
+    """
+    # Initialize lists to store p-values for depletion and enrichment
+    depletion_pvals = []
+    enrichment_pvals = []
+    # Process the subset of tasks assigned to this worker
+    for idx in range(start_idx, end_idx):
+        i = idx // annotations.shape[1]  # Neighborhood index
+        j = idx % annotations.shape[1]  # Annotation index
+
+        neighborhood = neighborhoods[:, i]
+        annotation = annotations[:, j]
+
+        # Calculate the required values for the hypergeometric test
+        M = annotations.shape[0]  # Total number of items (population size)
+        n = np.sum(annotation)  # Total number of successes in population
+        N = np.sum(neighborhood)  # Total number of draws (sample size)
+        k = np.sum(neighborhood & annotation)  # Number of successes in sample
+
+        # Perform hypergeometric test for depletion
+        p_value_depletion = hypergeom.cdf(k, M, n, N)
+        depletion_pvals.append(p_value_depletion)
+        # Perform hypergeometric test for enrichment
+        p_value_enrichment = hypergeom.sf(k - 1, M, n, N)
+        enrichment_pvals.append(p_value_enrichment)
+
+        # Update the shared progress counter
+        progress_counter.value += 1
+
+    return depletion_pvals, enrichment_pvals
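
SciPy's hypergeometric parameterization is easy to misread, so a quick cross-check of the tail conventions used above may help: with population size M, n annotated nodes, and a neighborhood of size N, hypergeom.cdf(k, M, n, N) is the lower tail P(X <= k) used for depletion, and hypergeom.sf(k - 1, M, n, N) is the upper tail P(X >= k) used for enrichment. A minimal sketch with made-up counts:

    from scipy.stats import hypergeom

    # population, annotated nodes, neighborhood size, observed overlap
    M, n, N, k = 100, 20, 10, 5

    p_depletion = hypergeom.cdf(k, M, n, N)      # P(X <= k)
    p_enrichment = hypergeom.sf(k - 1, M, n, N)  # P(X >= k)

    # The two tails overlap exactly at X == k, so they sum to 1 + P(X == k).
    assert abs(p_depletion + p_enrichment - (1 + hypergeom.pmf(k, M, n, N))) < 1e-12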
risk/stats/permutation/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""
+risk/stats/permutation
+~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from .permutation import compute_permutation_test