risk-network 0.0.4b2__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -5
- risk/annotations/annotations.py +3 -3
- risk/constants.py +2 -2
- risk/neighborhoods/neighborhoods.py +5 -1
- risk/network/geometry.py +2 -2
- risk/network/graph.py +45 -19
- risk/network/io.py +45 -30
- risk/network/plot.py +70 -18
- risk/risk.py +175 -19
- risk/stats/__init__.py +4 -1
- risk/stats/fisher_exact.py +132 -0
- risk/stats/hypergeom.py +131 -0
- risk/stats/permutation/__init__.py +6 -0
- risk/stats/permutation/permutation.py +212 -0
- risk/stats/{permutation.py → permutation/test_functions.py} +12 -39
- risk/stats/stats.py +1 -212
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/METADATA +4 -5
- risk_network-0.0.5.dist-info/RECORD +30 -0
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/WHEEL +1 -1
- risk_network-0.0.4b2.dist-info/RECORD +0 -26
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/LICENSE +0 -0
- {risk_network-0.0.4b2.dist-info → risk_network-0.0.5.dist-info}/top_level.txt +0 -0
risk/risk.py
CHANGED
@@ -6,6 +6,7 @@ risk/risk
|
|
6
6
|
from typing import Any, Dict
|
7
7
|
|
8
8
|
import networkx as nx
|
9
|
+
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
|
11
12
|
from risk.annotations import AnnotationsIO, define_top_annotations
|
@@ -17,7 +18,12 @@ from risk.neighborhoods import (
|
|
17
18
|
trim_domains_and_top_annotations,
|
18
19
|
)
|
19
20
|
from risk.network import NetworkIO, NetworkGraph, NetworkPlotter
|
20
|
-
from risk.stats import
|
21
|
+
from risk.stats import (
|
22
|
+
calculate_significance_matrices,
|
23
|
+
compute_fisher_exact_test,
|
24
|
+
compute_hypergeom_test,
|
25
|
+
compute_permutation_test,
|
26
|
+
)
|
21
27
|
|
22
28
|
|
23
29
|
class RISK(NetworkIO, AnnotationsIO):
|
@@ -39,7 +45,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
39
45
|
"""Access the logged parameters."""
|
40
46
|
return params
|
41
47
|
|
42
|
-
def
|
48
|
+
def load_neighborhoods_by_permutation(
|
43
49
|
self,
|
44
50
|
network: nx.Graph,
|
45
51
|
annotations: Dict[str, Any],
|
@@ -52,7 +58,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
52
58
|
random_seed: int = 888,
|
53
59
|
max_workers: int = 1,
|
54
60
|
) -> Dict[str, Any]:
|
55
|
-
"""Load significant neighborhoods for the network.
|
61
|
+
"""Load significant neighborhoods for the network using the permutation test.
|
56
62
|
|
57
63
|
Args:
|
58
64
|
network (nx.Graph): The network graph.
|
@@ -75,6 +81,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
75
81
|
distance_metric=distance_metric,
|
76
82
|
louvain_resolution=louvain_resolution,
|
77
83
|
edge_length_threshold=edge_length_threshold,
|
84
|
+
statistical_test_function="permutation",
|
78
85
|
score_metric=score_metric,
|
79
86
|
null_distribution=null_distribution,
|
80
87
|
num_permutations=num_permutations,
|
@@ -82,30 +89,22 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
82
89
|
max_workers=max_workers,
|
83
90
|
)
|
84
91
|
|
85
|
-
#
|
86
|
-
|
87
|
-
for_print_distance_metric = f"louvain (resolution={louvain_resolution})"
|
88
|
-
else:
|
89
|
-
for_print_distance_metric = distance_metric
|
90
|
-
print(f"Distance metric: '{for_print_distance_metric}'")
|
91
|
-
print(f"Edge length threshold: {edge_length_threshold}")
|
92
|
-
# Compute neighborhoods based on the network and distance metric
|
93
|
-
neighborhoods = get_network_neighborhoods(
|
92
|
+
# Load neighborhoods based on the network and distance metric
|
93
|
+
neighborhoods = self._load_neighborhoods(
|
94
94
|
network,
|
95
95
|
distance_metric,
|
96
|
-
edge_length_threshold,
|
97
96
|
louvain_resolution=louvain_resolution,
|
97
|
+
edge_length_threshold=edge_length_threshold,
|
98
98
|
random_seed=random_seed,
|
99
99
|
)
|
100
100
|
|
101
101
|
# Log and display permutation test settings
|
102
|
-
print(f"Null distribution: '{null_distribution}'")
|
103
102
|
print(f"Neighborhood scoring metric: '{score_metric}'")
|
103
|
+
print(f"Null distribution: '{null_distribution}'")
|
104
104
|
print(f"Number of permutations: {num_permutations}")
|
105
|
-
print(f"Random seed: {random_seed}")
|
106
105
|
print(f"Maximum workers: {max_workers}")
|
107
|
-
# Run
|
108
|
-
neighborhood_significance =
|
106
|
+
# Run permutation test to compute neighborhood significance
|
107
|
+
neighborhood_significance = compute_permutation_test(
|
109
108
|
neighborhoods=neighborhoods,
|
110
109
|
annotations=annotations["matrix"],
|
111
110
|
score_metric=score_metric,
|
@@ -117,6 +116,116 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
117
116
|
|
118
117
|
return neighborhood_significance
|
119
118
|
|
119
|
+
def load_neighborhoods_by_fisher_exact(
    self,
    network: nx.Graph,
    annotations: Dict[str, Any],
    distance_metric: str = "dijkstra",
    louvain_resolution: float = 0.1,
    edge_length_threshold: float = 0.5,
    random_seed: int = 888,
    max_workers: int = 1,
) -> Dict[str, Any]:
    """Load significant neighborhoods for the network using the Fisher's exact test.

    The parameters are first logged via ``params.log_neighborhoods``. The
    neighborhood matrix is then built with ``self._load_neighborhoods``.
    Finally, ``compute_fisher_exact_test`` is run on that matrix against
    ``annotations["matrix"]``.

    Args:
        network (nx.Graph): The network graph.
        annotations (dict): Annotations associated with the network. Only its
            "matrix" entry (the annotation matrix) is used here.
        distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
        louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
        edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
        random_seed (int, optional): Seed forwarded to neighborhood construction. Defaults to 888.
        max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.

    Returns:
        dict: Computed significance of neighborhoods, as returned by
        ``compute_fisher_exact_test``.
    """
    print_header("Running Fisher's exact test")
    # Log neighborhood analysis parameters
    params.log_neighborhoods(
        distance_metric=distance_metric,
        louvain_resolution=louvain_resolution,
        edge_length_threshold=edge_length_threshold,
        statistical_test_function="fisher_exact",
        random_seed=random_seed,
        max_workers=max_workers,
    )

    # Load neighborhoods based on the network and distance metric
    neighborhoods = self._load_neighborhoods(
        network,
        distance_metric,
        louvain_resolution=louvain_resolution,
        edge_length_threshold=edge_length_threshold,
        random_seed=random_seed,
    )

    # Log and display Fisher's exact test settings
    print(f"Maximum workers: {max_workers}")
    # Run Fisher's exact test to compute neighborhood significance
    neighborhood_significance = compute_fisher_exact_test(
        neighborhoods=neighborhoods,
        annotations=annotations["matrix"],
        max_workers=max_workers,
    )

    return neighborhood_significance
|
173
|
+
|
174
|
+
def load_neighborhoods_by_hypergeom(
    self,
    network: nx.Graph,
    annotations: Dict[str, Any],
    distance_metric: str = "dijkstra",
    louvain_resolution: float = 0.1,
    edge_length_threshold: float = 0.5,
    random_seed: int = 888,
    max_workers: int = 1,
) -> Dict[str, Any]:
    """Load significant neighborhoods for the network using the hypergeometric test.

    The parameters are first logged via ``params.log_neighborhoods``. The
    neighborhood matrix is then built with ``self._load_neighborhoods``.
    Finally, ``compute_hypergeom_test`` is run on that matrix against
    ``annotations["matrix"]``.

    Args:
        network (nx.Graph): The network graph.
        annotations (dict): Annotations associated with the network. Only its
            "matrix" entry (the annotation matrix) is used here.
        distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
        louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
        edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
        random_seed (int, optional): Seed forwarded to neighborhood construction. Defaults to 888.
        max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.

    Returns:
        dict: Computed significance of neighborhoods, as returned by
        ``compute_hypergeom_test``.
    """
    print_header("Running hypergeometric test")
    # Log neighborhood analysis parameters
    params.log_neighborhoods(
        distance_metric=distance_metric,
        louvain_resolution=louvain_resolution,
        edge_length_threshold=edge_length_threshold,
        statistical_test_function="hypergeom",
        random_seed=random_seed,
        max_workers=max_workers,
    )

    # Load neighborhoods based on the network and distance metric
    neighborhoods = self._load_neighborhoods(
        network,
        distance_metric,
        louvain_resolution=louvain_resolution,
        edge_length_threshold=edge_length_threshold,
        random_seed=random_seed,
    )

    # Log and display hypergeometric test settings
    print(f"Maximum workers: {max_workers}")
    # Run hypergeometric test to compute neighborhood significance
    neighborhood_significance = compute_hypergeom_test(
        neighborhoods=neighborhoods,
        annotations=annotations["matrix"],
        max_workers=max_workers,
    )

    return neighborhood_significance
|
228
|
+
|
120
229
|
def load_graph(
|
121
230
|
self,
|
122
231
|
network: nx.Graph,
|
@@ -140,7 +249,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
140
249
|
annotations (pd.DataFrame): DataFrame containing annotation data for the network.
|
141
250
|
neighborhoods (dict): Neighborhood enrichment data.
|
142
251
|
tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
|
143
|
-
pval_cutoff (float, optional):
|
252
|
+
pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
|
144
253
|
fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
|
145
254
|
impute_depth (int, optional): Depth for imputing neighbors. Defaults to 1.
|
146
255
|
prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
|
@@ -168,7 +277,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
168
277
|
max_cluster_size=max_cluster_size,
|
169
278
|
)
|
170
279
|
|
171
|
-
print(f"
|
280
|
+
print(f"p-value cutoff: {pval_cutoff}")
|
172
281
|
print(f"FDR BH cutoff: {fdr_cutoff}")
|
173
282
|
print(
|
174
283
|
f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
|
@@ -243,6 +352,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
243
352
|
plot_outline: bool = True,
|
244
353
|
outline_color: str = "black",
|
245
354
|
outline_scale: float = 1.00,
|
355
|
+
linestyle: str = "dashed",
|
246
356
|
) -> NetworkPlotter:
|
247
357
|
"""Get a NetworkPlotter object for plotting.
|
248
358
|
|
@@ -253,6 +363,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
253
363
|
plot_outline (bool, optional): Whether to plot the network outline. Defaults to True.
|
254
364
|
outline_color (str, optional): Color of the outline. Defaults to "black".
|
255
365
|
outline_scale (float, optional): Scaling factor for the outline. Defaults to 1.00.
|
366
|
+
linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
|
256
367
|
|
257
368
|
Returns:
|
258
369
|
NetworkPlotter: A NetworkPlotter object configured with the given parameters.
|
@@ -265,7 +376,9 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
265
376
|
plot_outline=plot_outline,
|
266
377
|
outline_color=outline_color,
|
267
378
|
outline_scale=outline_scale,
|
379
|
+
linestyle=linestyle,
|
268
380
|
)
|
381
|
+
|
269
382
|
# Initialize and return a NetworkPlotter object
|
270
383
|
return NetworkPlotter(
|
271
384
|
graph,
|
@@ -274,8 +387,51 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
274
387
|
plot_outline=plot_outline,
|
275
388
|
outline_color=outline_color,
|
276
389
|
outline_scale=outline_scale,
|
390
|
+
linestyle=linestyle,
|
277
391
|
)
|
278
392
|
|
393
|
+
def _load_neighborhoods(
    self,
    network: nx.Graph,
    distance_metric: str = "dijkstra",
    louvain_resolution: float = 0.1,
    edge_length_threshold: float = 0.5,
    random_seed: int = 888,
) -> np.ndarray:
    """Compute the neighborhood matrix for the network, printing the settings used.

    No significance testing happens here. The matrix is produced by
    ``get_network_neighborhoods`` and is later passed to a statistical test
    by the public ``load_neighborhoods_by_*`` methods.

    Args:
        network (nx.Graph): The network graph.
        distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
        louvain_resolution (float, optional): Resolution parameter for Louvain clustering. It is
            always forwarded, but only shown in the printed settings when
            ``distance_metric`` is "louvain". Defaults to 0.1.
        edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        np.ndarray: Neighborhood matrix calculated based on the selected distance metric.
    """
    # Show the resolution alongside the metric name only when it applies
    shown_metric = (
        f"louvain (resolution={louvain_resolution})"
        if distance_metric == "louvain"
        else distance_metric
    )
    # Log and display neighborhood settings
    print(f"Distance metric: '{shown_metric}'")
    print(f"Edge length threshold: {edge_length_threshold}")
    print(f"Random seed: {random_seed}")

    # Compute neighborhoods based on the network and distance metric
    return get_network_neighborhoods(
        network,
        distance_metric,
        edge_length_threshold,
        louvain_resolution=louvain_resolution,
        random_seed=random_seed,
    )
|
434
|
+
|
279
435
|
def _define_top_annotations(
|
280
436
|
self,
|
281
437
|
network: nx.Graph,
|
risk/stats/__init__.py
CHANGED
@@ -3,4 +3,7 @@ risk/stats
|
|
3
3
|
~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
from .stats import calculate_significance_matrices
|
6
|
+
from .stats import calculate_significance_matrices
|
7
|
+
from .fisher_exact import compute_fisher_exact_test
|
8
|
+
from .hypergeom import compute_hypergeom_test
|
9
|
+
from .permutation import compute_permutation_test
|
risk/stats/fisher_exact.py
ADDED
@@ -0,0 +1,132 @@
|
|
1
|
+
"""
|
2
|
+
risk/stats/fisher_exact
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from multiprocessing import get_context, Manager
|
7
|
+
from tqdm import tqdm
|
8
|
+
from typing import Any, Dict
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from scipy.stats import fisher_exact
|
12
|
+
|
13
|
+
|
14
|
+
def compute_fisher_exact_test(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    max_workers: int = 4,
) -> Dict[str, Any]:
    """Compute Fisher's exact test for enrichment and depletion in neighborhoods.

    Every (neighborhood column, annotation column) pair is tested with
    one-sided Fisher's exact tests. The pairs are split into contiguous
    ranges, and each range is processed by ``_fisher_exact_process_subset``
    in a worker process started with the "spawn" method.

    Args:
        neighborhoods (np.ndarray): Matrix whose columns are neighborhoods. Its
            rows must line up with the rows of ``annotations``. Non-zero entries
            count as membership; NaN entries count as non-membership.
        annotations (np.ndarray): Matrix whose columns are annotations. Non-zero
            entries count as annotated; NaN entries count as unannotated.
        max_workers (int, optional): Maximum number of worker processes. Defaults to 4.
            No more workers than there are tests are started.

    Returns:
        dict: ``"depletion_pvals"`` and ``"enrichment_pvals"``, each an array of
        shape (number of neighborhoods, number of annotations).

    Raises:
        ValueError: If ``max_workers`` is less than 1.
    """
    if max_workers < 1:
        raise ValueError(f"max_workers must be at least 1, got {max_workers}")

    # Ensure that the matrices are binary (boolean), with NaN treated as False
    neighborhoods = _to_bool_matrix(neighborhoods)
    annotations = _to_bool_matrix(annotations)
    num_neighborhoods = neighborhoods.shape[1]
    num_annotations = annotations.shape[1]
    total_tasks = num_neighborhoods * num_annotations

    # Each "spawn" worker is a fresh interpreter and costly to start, so never
    # start more workers than there are tests (but always at least one).
    num_workers = max(1, min(max_workers, total_tasks))
    chunk_size, remainder = divmod(total_tasks, num_workers)

    ctx = get_context("spawn")
    # Use the same start method for the manager as for the pool, and shut the
    # manager's server process down deterministically when done.
    with ctx.Manager() as manager:
        progress_counter = manager.Value("i", 0)

        # Split flat pair indices [0, total_tasks) into contiguous ranges; the
        # first `remainder` workers each take one extra pair.
        params_list = []
        start_idx = 0
        for worker_idx in range(num_workers):
            end_idx = start_idx + chunk_size + (1 if worker_idx < remainder else 0)
            params_list.append(
                (neighborhoods, annotations, start_idx, end_idx, progress_counter)
            )
            start_idx = end_idx

        # Execute the Fisher's exact test using multiprocessing
        with ctx.Pool(num_workers) as pool, tqdm(
            total=total_tasks, desc="Total progress", position=0
        ) as progress:
            results = pool.starmap_async(
                _fisher_exact_process_subset, params_list, chunksize=1
            )
            # Drive the progress bar from the shared counter. Workers update it
            # with an unsynchronized `+= 1`, so it is display-only and may lag.
            while not results.ready():
                progress.update(progress_counter.value - progress.n)
                results.wait(0.05)  # Wait for 50ms
            # Ensure progress bar reaches 100%
            progress.update(total_tasks - progress.n)
            # Re-raises any exception raised inside a worker
            chunk_results = results.get()

    # starmap preserves argument order and the ranges were assigned in order,
    # so concatenation restores flat (row-major) pair order.
    depletion_pvals, enrichment_pvals = [], []
    for dp, ep in chunk_results:
        depletion_pvals.extend(dp)
        enrichment_pvals.extend(ep)

    return {
        "depletion_pvals": np.array(depletion_pvals).reshape(
            num_neighborhoods, num_annotations
        ),
        "enrichment_pvals": np.array(enrichment_pvals).reshape(
            num_neighborhoods, num_annotations
        ),
    }


def _to_bool_matrix(matrix) -> np.ndarray:
    """Return ``matrix`` as a boolean ndarray, treating NaN entries as False."""
    matrix = np.asarray(matrix)
    # NaN is truthy, so a bare astype(bool) would turn missing values into True
    if np.issubdtype(matrix.dtype, np.floating):
        matrix = np.nan_to_num(matrix, nan=0.0)
    return matrix.astype(bool)
|
83
|
+
|
84
|
+
|
85
|
+
def _fisher_exact_process_subset(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    start_idx: int,
    end_idx: int,
    progress_counter,
) -> tuple:
    """Run one-sided Fisher's exact tests for a contiguous range of flat pair indices.

    Flat index ``idx`` encodes the pair (neighborhood ``idx // n_annotations``,
    annotation ``idx % n_annotations``), i.e. row-major order over an
    (n_neighborhoods, n_annotations) grid.

    Args:
        neighborhoods (np.ndarray): Boolean matrix; column ``i`` is neighborhood ``i``.
        annotations (np.ndarray): Boolean matrix; column ``j`` is annotation ``j``.
        start_idx (int): First flat pair index to process (inclusive).
        end_idx (int): Flat pair index at which to stop (exclusive).
        progress_counter: Shared object whose ``value`` attribute is incremented
            once per processed pair.

    Returns:
        tuple: ``(depletion_pvals, enrichment_pvals)`` lists in flat-index order.
    """
    n_annotations = annotations.shape[1]
    depletion, enrichment = [], []

    for flat in range(start_idx, end_idx):
        nbhd_col, annot_col = divmod(flat, n_annotations)
        in_nbhd = neighborhoods[:, nbhd_col]
        has_annot = annotations[:, annot_col]
        out_nbhd = ~in_nbhd
        lacks_annot = ~has_annot

        # 2x2 contingency table: rows = inside/outside the neighborhood,
        # columns = annotated/not annotated.
        contingency = np.array(
            [
                [np.sum(in_nbhd & has_annot), np.sum(in_nbhd & lacks_annot)],
                [np.sum(out_nbhd & has_annot), np.sum(out_nbhd & lacks_annot)],
            ]
        )

        # "less": fewer annotated members than expected (depletion);
        # "greater": more annotated members than expected (enrichment).
        _, p_less = fisher_exact(contingency, alternative="less")
        depletion.append(p_less)
        _, p_greater = fisher_exact(contingency, alternative="greater")
        enrichment.append(p_greater)

        progress_counter.value += 1

    return depletion, enrichment
|
risk/stats/hypergeom.py
ADDED
@@ -0,0 +1,131 @@
|
|
1
|
+
"""
|
2
|
+
risk/stats/hypergeom
|
3
|
+
~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from multiprocessing import get_context, Manager
|
7
|
+
from tqdm import tqdm
|
8
|
+
from typing import Any, Dict
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from scipy.stats import hypergeom
|
12
|
+
|
13
|
+
|
14
|
+
def compute_hypergeom_test(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    max_workers: int = 4,
) -> Dict[str, Any]:
    """Compute hypergeometric test for enrichment and depletion in neighborhoods.

    Every (neighborhood column, annotation column) pair is tested with the
    hypergeometric distribution. The pairs are split into contiguous ranges,
    and each range is processed by ``_hypergeom_process_subset`` in a worker
    process started with the "spawn" method.

    Args:
        neighborhoods (np.ndarray): Matrix whose columns are neighborhoods. Its
            rows must line up with the rows of ``annotations``. Non-zero entries
            count as membership; NaN entries count as non-membership.
        annotations (np.ndarray): Matrix whose columns are annotations. Non-zero
            entries count as annotated; NaN entries count as unannotated.
        max_workers (int, optional): Maximum number of worker processes. Defaults to 4.
            No more workers than there are tests are started.

    Returns:
        dict: ``"depletion_pvals"`` and ``"enrichment_pvals"``, each an array of
        shape (number of neighborhoods, number of annotations).

    Raises:
        ValueError: If ``max_workers`` is less than 1.
    """
    if max_workers < 1:
        raise ValueError(f"max_workers must be at least 1, got {max_workers}")

    # Ensure that the matrices are binary (boolean), with NaN treated as False
    neighborhoods = _to_bool_matrix(neighborhoods)
    annotations = _to_bool_matrix(annotations)
    num_neighborhoods = neighborhoods.shape[1]
    num_annotations = annotations.shape[1]
    total_tasks = num_neighborhoods * num_annotations

    # Each "spawn" worker is a fresh interpreter and costly to start, so never
    # start more workers than there are tests (but always at least one).
    num_workers = max(1, min(max_workers, total_tasks))
    chunk_size, remainder = divmod(total_tasks, num_workers)

    ctx = get_context("spawn")
    # Use the same start method for the manager as for the pool, and shut the
    # manager's server process down deterministically when done.
    with ctx.Manager() as manager:
        progress_counter = manager.Value("i", 0)

        # Split flat pair indices [0, total_tasks) into contiguous ranges; the
        # first `remainder` workers each take one extra pair.
        params_list = []
        start_idx = 0
        for worker_idx in range(num_workers):
            end_idx = start_idx + chunk_size + (1 if worker_idx < remainder else 0)
            params_list.append(
                (neighborhoods, annotations, start_idx, end_idx, progress_counter)
            )
            start_idx = end_idx

        # Execute the hypergeometric test using multiprocessing
        with ctx.Pool(num_workers) as pool, tqdm(
            total=total_tasks, desc="Total progress", position=0
        ) as progress:
            results = pool.starmap_async(
                _hypergeom_process_subset, params_list, chunksize=1
            )
            # Drive the progress bar from the shared counter. Workers update it
            # with an unsynchronized `+= 1`, so it is display-only and may lag.
            while not results.ready():
                progress.update(progress_counter.value - progress.n)
                results.wait(0.05)  # Wait for 50ms
            # Ensure progress bar reaches 100%
            progress.update(total_tasks - progress.n)
            # Re-raises any exception raised inside a worker
            chunk_results = results.get()

    # starmap preserves argument order and the ranges were assigned in order,
    # so concatenation restores flat (row-major) pair order.
    depletion_pvals, enrichment_pvals = [], []
    for dp, ep in chunk_results:
        depletion_pvals.extend(dp)
        enrichment_pvals.extend(ep)

    return {
        "depletion_pvals": np.array(depletion_pvals).reshape(
            num_neighborhoods, num_annotations
        ),
        "enrichment_pvals": np.array(enrichment_pvals).reshape(
            num_neighborhoods, num_annotations
        ),
    }


def _to_bool_matrix(matrix) -> np.ndarray:
    """Return ``matrix`` as a boolean ndarray, treating NaN entries as False."""
    matrix = np.asarray(matrix)
    # NaN is truthy, so a bare astype(bool) would turn missing values into True
    if np.issubdtype(matrix.dtype, np.floating):
        matrix = np.nan_to_num(matrix, nan=0.0)
    return matrix.astype(bool)
|
83
|
+
|
84
|
+
|
85
|
+
def _hypergeom_process_subset(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    start_idx: int,
    end_idx: int,
    progress_counter,
) -> tuple:
    """Run hypergeometric tests for a contiguous range of flat pair indices.

    Flat index ``idx`` encodes the pair (neighborhood ``idx // n_annotations``,
    annotation ``idx % n_annotations``), i.e. row-major order over an
    (n_neighborhoods, n_annotations) grid. The population is every row of
    ``annotations``.

    Args:
        neighborhoods (np.ndarray): Boolean matrix; column ``i`` is neighborhood ``i``.
        annotations (np.ndarray): Boolean matrix; column ``j`` is annotation ``j``.
        start_idx (int): First flat pair index to process (inclusive).
        end_idx (int): Flat pair index at which to stop (exclusive).
        progress_counter: Shared object whose ``value`` attribute is incremented
            once per processed pair.

    Returns:
        tuple: ``(depletion_pvals, enrichment_pvals)`` lists in flat-index order.
    """
    population_size = annotations.shape[0]
    n_annotations = annotations.shape[1]
    depletion, enrichment = [], []

    for flat in range(start_idx, end_idx):
        nbhd_col, annot_col = divmod(flat, n_annotations)
        in_nbhd = neighborhoods[:, nbhd_col]
        has_annot = annotations[:, annot_col]

        successes = np.sum(has_annot)  # annotated items in the population
        draws = np.sum(in_nbhd)  # size of the neighborhood (sample)
        hits = np.sum(in_nbhd & has_annot)  # annotated items in the neighborhood

        # Depletion: P(X <= hits); enrichment: P(X >= hits) == sf(hits - 1)
        depletion.append(hypergeom.cdf(hits, population_size, successes, draws))
        enrichment.append(hypergeom.sf(hits - 1, population_size, successes, draws))

        progress_counter.value += 1

    return depletion, enrichment
|