risk-network 0.0.3b4__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -4
- risk/annotations/annotations.py +4 -2
- risk/annotations/io.py +1 -1
- risk/neighborhoods/neighborhoods.py +15 -2
- risk/network/geometry.py +2 -2
- risk/network/graph.py +4 -4
- risk/network/io.py +234 -53
- risk/network/plot.py +179 -58
- risk/risk.py +187 -75
- risk/stats/__init__.py +4 -1
- risk/stats/fisher_exact.py +132 -0
- risk/stats/hypergeom.py +131 -0
- risk/stats/permutation/__init__.py +6 -0
- risk/stats/permutation/permutation.py +212 -0
- risk/stats/{permutation.py → permutation/test_functions.py} +12 -39
- risk/stats/stats.py +1 -212
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/METADATA +6 -6
- risk_network-0.0.4.dist-info/RECORD +30 -0
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/WHEEL +1 -1
- risk_network-0.0.3b4.dist-info/RECORD +0 -26
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/LICENSE +0 -0
- {risk_network-0.0.3b4.dist-info → risk_network-0.0.4.dist-info}/top_level.txt +0 -0
risk/risk.py
CHANGED
@@ -6,6 +6,7 @@ risk/risk
|
|
6
6
|
from typing import Any, Dict
|
7
7
|
|
8
8
|
import networkx as nx
|
9
|
+
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
|
11
12
|
from risk.annotations import AnnotationsIO, define_top_annotations
|
@@ -17,7 +18,12 @@ from risk.neighborhoods import (
|
|
17
18
|
trim_domains_and_top_annotations,
|
18
19
|
)
|
19
20
|
from risk.network import NetworkIO, NetworkGraph, NetworkPlotter
|
20
|
-
from risk.stats import
|
21
|
+
from risk.stats import (
|
22
|
+
calculate_significance_matrices,
|
23
|
+
compute_fisher_exact_test,
|
24
|
+
compute_hypergeom_test,
|
25
|
+
compute_permutation_test,
|
26
|
+
)
|
21
27
|
|
22
28
|
|
23
29
|
class RISK(NetworkIO, AnnotationsIO):
|
@@ -27,85 +33,39 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
27
33
|
and performing network-based statistical analysis, such as neighborhood significance testing.
|
28
34
|
"""
|
29
35
|
|
30
|
-
def __init__(
|
31
|
-
|
32
|
-
compute_sphere: bool = True,
|
33
|
-
surface_depth: float = 0.0,
|
34
|
-
distance_metric: str = "dijkstra",
|
35
|
-
louvain_resolution: float = 0.1,
|
36
|
-
min_edges_per_node: int = 0,
|
37
|
-
edge_length_threshold: float = 0.5,
|
38
|
-
include_edge_weight: bool = True,
|
39
|
-
weight_label: str = "weight",
|
40
|
-
):
|
41
|
-
"""Initialize the RISK class with configuration settings.
|
42
|
-
|
43
|
-
Args:
|
44
|
-
compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
|
45
|
-
surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
|
46
|
-
distance_metric (str, optional): Distance metric to use in network analysis. Defaults to "dijkstra".
|
47
|
-
louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
|
48
|
-
min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
|
49
|
-
edge_length_threshold (float, optional): Edge length threshold for analysis. Defaults to 0.5.
|
50
|
-
include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
|
51
|
-
weight_label (str, optional): Label for edge weights. Defaults to "weight".
|
52
|
-
"""
|
36
|
+
def __init__(self, *args, **kwargs):
|
37
|
+
"""Initialize the RISK class with configuration settings."""
|
53
38
|
# Initialize and log network parameters
|
54
39
|
params.initialize()
|
55
|
-
|
56
|
-
|
57
|
-
surface_depth=surface_depth,
|
58
|
-
distance_metric=distance_metric,
|
59
|
-
louvain_resolution=louvain_resolution,
|
60
|
-
min_edges_per_node=min_edges_per_node,
|
61
|
-
edge_length_threshold=edge_length_threshold,
|
62
|
-
include_edge_weight=include_edge_weight,
|
63
|
-
weight_label=weight_label,
|
64
|
-
)
|
65
|
-
# Initialize parent classes
|
66
|
-
NetworkIO.__init__(
|
67
|
-
self,
|
68
|
-
compute_sphere=compute_sphere,
|
69
|
-
surface_depth=surface_depth,
|
70
|
-
distance_metric=distance_metric,
|
71
|
-
louvain_resolution=louvain_resolution,
|
72
|
-
min_edges_per_node=min_edges_per_node,
|
73
|
-
edge_length_threshold=edge_length_threshold,
|
74
|
-
include_edge_weight=include_edge_weight,
|
75
|
-
weight_label=weight_label,
|
76
|
-
)
|
77
|
-
AnnotationsIO.__init__(self)
|
78
|
-
|
79
|
-
# Set class attributes
|
80
|
-
self.compute_sphere = compute_sphere
|
81
|
-
self.surface_depth = surface_depth
|
82
|
-
self.distance_metric = distance_metric
|
83
|
-
self.louvain_resolution = louvain_resolution
|
84
|
-
self.min_edges_per_node = min_edges_per_node
|
85
|
-
self.edge_length_threshold = edge_length_threshold
|
86
|
-
self.include_edge_weight = include_edge_weight
|
87
|
-
self.weight_label = weight_label
|
40
|
+
# Initialize the parent classes
|
41
|
+
super().__init__(*args, **kwargs)
|
88
42
|
|
89
43
|
@property
def params(self):
    """Return the module-level ``params`` object that records the logged settings."""
    return params
|
93
47
|
|
94
|
-
def
|
48
|
+
def load_neighborhoods_by_permutation(
|
95
49
|
self,
|
96
50
|
network: nx.Graph,
|
97
51
|
annotations: Dict[str, Any],
|
52
|
+
distance_metric: str = "dijkstra",
|
53
|
+
louvain_resolution: float = 0.1,
|
54
|
+
edge_length_threshold: float = 0.5,
|
98
55
|
score_metric: str = "sum",
|
99
56
|
null_distribution: str = "network",
|
100
57
|
num_permutations: int = 1000,
|
101
58
|
random_seed: int = 888,
|
102
59
|
max_workers: int = 1,
|
103
60
|
) -> Dict[str, Any]:
|
104
|
-
"""Load significant neighborhoods for the network.
|
61
|
+
"""Load significant neighborhoods for the network using the permutation test.
|
105
62
|
|
106
63
|
Args:
|
107
64
|
network (nx.Graph): The network graph.
|
108
65
|
annotations (pd.DataFrame): The matrix of annotations associated with the network.
|
66
|
+
distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
|
67
|
+
louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
|
68
|
+
edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
|
109
69
|
score_metric (str, optional): Scoring metric for neighborhood significance. Defaults to "sum".
|
110
70
|
null_distribution (str, optional): Distribution used for permutation tests. Defaults to "network".
|
111
71
|
num_permutations (int, optional): Number of permutations for significance testing. Defaults to 1000.
|
@@ -118,6 +78,10 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
118
78
|
print_header("Running permutation test")
|
119
79
|
# Log neighborhood analysis parameters
|
120
80
|
params.log_neighborhoods(
|
81
|
+
distance_metric=distance_metric,
|
82
|
+
louvain_resolution=louvain_resolution,
|
83
|
+
edge_length_threshold=edge_length_threshold,
|
84
|
+
statistical_test_function="permutation",
|
121
85
|
score_metric=score_metric,
|
122
86
|
null_distribution=null_distribution,
|
123
87
|
num_permutations=num_permutations,
|
@@ -125,27 +89,22 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
125
89
|
max_workers=max_workers,
|
126
90
|
)
|
127
91
|
|
128
|
-
#
|
129
|
-
|
130
|
-
for_print_distance_metric = f"louvain (resolution={self.louvain_resolution})"
|
131
|
-
else:
|
132
|
-
for_print_distance_metric = self.distance_metric
|
133
|
-
print(f"Distance metric: '{for_print_distance_metric}'")
|
134
|
-
# Compute neighborhoods based on the network and distance metric
|
135
|
-
neighborhoods = get_network_neighborhoods(
|
92
|
+
# Load neighborhoods based on the network and distance metric
|
93
|
+
neighborhoods = self._load_neighborhoods(
|
136
94
|
network,
|
137
|
-
|
138
|
-
|
139
|
-
|
95
|
+
distance_metric,
|
96
|
+
louvain_resolution=louvain_resolution,
|
97
|
+
edge_length_threshold=edge_length_threshold,
|
140
98
|
random_seed=random_seed,
|
141
99
|
)
|
142
100
|
|
143
101
|
# Log and display permutation test settings
|
144
|
-
print(f"Null distribution: '{null_distribution}'")
|
145
102
|
print(f"Neighborhood scoring metric: '{score_metric}'")
|
103
|
+
print(f"Null distribution: '{null_distribution}'")
|
146
104
|
print(f"Number of permutations: {num_permutations}")
|
147
|
-
|
148
|
-
|
105
|
+
print(f"Maximum workers: {max_workers}")
|
106
|
+
# Run permutation test to compute neighborhood significance
|
107
|
+
neighborhood_significance = compute_permutation_test(
|
149
108
|
neighborhoods=neighborhoods,
|
150
109
|
annotations=annotations["matrix"],
|
151
110
|
score_metric=score_metric,
|
@@ -157,6 +116,116 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
157
116
|
|
158
117
|
return neighborhood_significance
|
159
118
|
|
119
|
+
def load_neighborhoods_by_fisher_exact(
    self,
    network: nx.Graph,
    annotations: Dict[str, Any],
    distance_metric: str = "dijkstra",
    louvain_resolution: float = 0.1,
    edge_length_threshold: float = 0.5,
    random_seed: int = 888,
    max_workers: int = 1,
) -> Dict[str, Any]:
    """Compute neighborhood significance for the network with Fisher's exact test.

    Args:
        network (nx.Graph): The network graph.
        annotations (dict): Annotation data; its "matrix" entry is handed to the test.
        distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
        louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
        edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.

    Returns:
        dict: The result of ``compute_fisher_exact_test`` for the computed neighborhoods.
    """
    print_header("Running Fisher's exact test")

    # Record the settings of this run before any computation starts
    logged_settings = {
        "distance_metric": distance_metric,
        "louvain_resolution": louvain_resolution,
        "edge_length_threshold": edge_length_threshold,
        "statistical_test_function": "fisher_exact",
        "random_seed": random_seed,
        "max_workers": max_workers,
    }
    params.log_neighborhoods(**logged_settings)

    # Build the neighborhoods from the network using the requested distance metric
    neighborhoods = self._load_neighborhoods(
        network,
        distance_metric,
        louvain_resolution=louvain_resolution,
        edge_length_threshold=edge_length_threshold,
        random_seed=random_seed,
    )

    # Display the test settings, then run the test and hand back its result directly
    print(f"Maximum workers: {max_workers}")
    return compute_fisher_exact_test(
        neighborhoods=neighborhoods,
        annotations=annotations["matrix"],
        max_workers=max_workers,
    )
|
173
|
+
|
174
|
+
def load_neighborhoods_by_hypergeom(
    self,
    network: nx.Graph,
    annotations: Dict[str, Any],
    distance_metric: str = "dijkstra",
    louvain_resolution: float = 0.1,
    edge_length_threshold: float = 0.5,
    random_seed: int = 888,
    max_workers: int = 1,
) -> Dict[str, Any]:
    """Compute neighborhood significance for the network with the hypergeometric test.

    Args:
        network (nx.Graph): The network graph.
        annotations (dict): Annotation data; its "matrix" entry is handed to the test.
        distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
        louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
        edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.

    Returns:
        dict: The result of ``compute_hypergeom_test`` for the computed neighborhoods.
    """
    print_header("Running hypergeometric test")

    # Record the settings of this run before any computation starts
    logged_settings = {
        "distance_metric": distance_metric,
        "louvain_resolution": louvain_resolution,
        "edge_length_threshold": edge_length_threshold,
        "statistical_test_function": "hypergeom",
        "random_seed": random_seed,
        "max_workers": max_workers,
    }
    params.log_neighborhoods(**logged_settings)

    # Build the neighborhoods from the network using the requested distance metric
    neighborhoods = self._load_neighborhoods(
        network,
        distance_metric,
        louvain_resolution=louvain_resolution,
        edge_length_threshold=edge_length_threshold,
        random_seed=random_seed,
    )

    # Display the test settings, then run the test and hand back its result directly
    print(f"Maximum workers: {max_workers}")
    return compute_hypergeom_test(
        neighborhoods=neighborhoods,
        annotations=annotations["matrix"],
        max_workers=max_workers,
    )
|
228
|
+
|
160
229
|
def load_graph(
|
161
230
|
self,
|
162
231
|
network: nx.Graph,
|
@@ -180,7 +249,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
180
249
|
annotations (pd.DataFrame): DataFrame containing annotation data for the network.
|
181
250
|
neighborhoods (dict): Neighborhood enrichment data.
|
182
251
|
tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
|
183
|
-
pval_cutoff (float, optional):
|
252
|
+
pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
|
184
253
|
fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
|
185
254
|
impute_depth (int, optional): Depth for imputing neighbors. Defaults to 1.
|
186
255
|
prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
|
@@ -208,7 +277,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
208
277
|
max_cluster_size=max_cluster_size,
|
209
278
|
)
|
210
279
|
|
211
|
-
print(f"
|
280
|
+
print(f"p-value cutoff: {pval_cutoff}")
|
212
281
|
print(f"FDR BH cutoff: {fdr_cutoff}")
|
213
282
|
print(
|
214
283
|
f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
|
@@ -306,6 +375,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
306
375
|
outline_color=outline_color,
|
307
376
|
outline_scale=outline_scale,
|
308
377
|
)
|
378
|
+
|
309
379
|
# Initialize and return a NetworkPlotter object
|
310
380
|
return NetworkPlotter(
|
311
381
|
graph,
|
@@ -316,6 +386,48 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
316
386
|
outline_scale=outline_scale,
|
317
387
|
)
|
318
388
|
|
389
|
+
def _load_neighborhoods(
    self,
    network: nx.Graph,
    distance_metric: str = "dijkstra",
    louvain_resolution: float = 0.1,
    edge_length_threshold: float = 0.5,
    random_seed: int = 888,
) -> np.ndarray:
    """Print the neighborhood settings and compute the network neighborhoods.

    Args:
        network (nx.Graph): The network graph.
        distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
        louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
        edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        np.ndarray: Whatever ``get_network_neighborhoods`` returns for these settings.
    """
    # The Louvain resolution is only shown when the Louvain metric is selected
    metric_label = (
        f"louvain (resolution={louvain_resolution})"
        if distance_metric == "louvain"
        else distance_metric
    )
    print(f"Distance metric: '{metric_label}'")
    print(f"Edge length threshold: {edge_length_threshold}")
    print(f"Random seed: {random_seed}")

    # Delegate the actual computation and return its result unchanged
    return get_network_neighborhoods(
        network,
        distance_metric,
        edge_length_threshold,
        louvain_resolution=louvain_resolution,
        random_seed=random_seed,
    )
|
430
|
+
|
319
431
|
def _define_top_annotations(
|
320
432
|
self,
|
321
433
|
network: nx.Graph,
|
risk/stats/__init__.py
CHANGED
@@ -3,4 +3,7 @@ risk/stats
|
|
3
3
|
~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
from .stats import calculate_significance_matrices
|
6
|
+
from .stats import calculate_significance_matrices
|
7
|
+
from .fisher_exact import compute_fisher_exact_test
|
8
|
+
from .hypergeom import compute_hypergeom_test
|
9
|
+
from .permutation import compute_permutation_test
|
@@ -0,0 +1,132 @@
|
|
1
|
+
"""
|
2
|
+
risk/stats/fisher_exact
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from multiprocessing import get_context, Manager
|
7
|
+
from tqdm import tqdm
|
8
|
+
from typing import Any, Dict
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from scipy.stats import fisher_exact
|
12
|
+
|
13
|
+
|
14
|
+
def compute_fisher_exact_test(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    max_workers: int = 4,
) -> Dict[str, Any]:
    """Compute Fisher's exact test for enrichment and depletion in neighborhoods.

    Every (neighborhood column, annotation column) pair is one task. Tasks are
    split into ``max_workers`` contiguous index ranges and evaluated in a
    "spawn" process pool while a progress bar follows a shared counter.

    Args:
        neighborhoods (np.ndarray): Matrix whose columns are neighborhoods; coerced to boolean.
        annotations (np.ndarray): Matrix whose columns are annotations; coerced to boolean.
        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 4.

    Returns:
        dict: "depletion_pvals" and "enrichment_pvals", each an array of shape
            (number of neighborhood columns, number of annotation columns).

    Raises:
        ValueError: If ``max_workers`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError further down
    if max_workers < 1:
        raise ValueError(f"max_workers must be at least 1, got {max_workers}")

    # Coerce to boolean; note that any non-zero entry (NaN included) becomes True
    neighborhoods = neighborhoods.astype(bool)
    annotations = annotations.astype(bool)

    num_neighborhoods = neighborhoods.shape[1]
    num_annotations = annotations.shape[1]
    total_tasks = num_neighborhoods * num_annotations

    # Calculate the workload per worker; the first `remainder` workers take one extra task
    chunk_size, remainder = divmod(total_tasks, max_workers)

    ctx = get_context("spawn")
    # The manager is created from the same context as the pool and used as a context
    # manager so that its server process is always shut down, even on error
    with ctx.Manager() as manager, ctx.Pool(max_workers) as pool:
        progress_counter = manager.Value("i", 0)

        params_list = []
        start_idx = 0
        for worker in range(max_workers):
            end_idx = start_idx + chunk_size + (1 if worker < remainder else 0)
            params_list.append(
                (neighborhoods, annotations, start_idx, end_idx, progress_counter)
            )
            start_idx = end_idx

        with tqdm(total=total_tasks, desc="Total progress", position=0) as progress:
            # Start the Fisher's exact test process in parallel
            results = pool.starmap_async(_fisher_exact_process_subset, params_list, chunksize=1)

            # Update progress bar based on progress_counter
            while not results.ready():
                progress.update(progress_counter.value - progress.n)
                results.wait(0.05)  # Wait for 50ms
            # Ensure progress bar reaches 100%
            progress.update(total_tasks - progress.n)

        # Accumulate results from each worker; starmap keeps them in task order
        depletion_pvals, enrichment_pvals = [], []
        for dp, ep in results.get():
            depletion_pvals.extend(dp)
            enrichment_pvals.extend(ep)

    # Reshape the flat results back into (neighborhoods, annotations) arrays
    result_shape = (num_neighborhoods, num_annotations)
    return {
        "depletion_pvals": np.array(depletion_pvals).reshape(result_shape),
        "enrichment_pvals": np.array(enrichment_pvals).reshape(result_shape),
    }
|
83
|
+
|
84
|
+
|
85
|
+
def _fisher_exact_process_subset(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    start_idx: int,
    end_idx: int,
    progress_counter,
) -> tuple:
    """Process a subset of neighborhood-annotation pairs using Fisher's exact test.

    A flat pair index ``idx`` maps to neighborhood column ``idx // A`` and
    annotation column ``idx % A``, where ``A`` is the number of annotation columns.

    Args:
        neighborhoods (np.ndarray): The full neighborhood matrix (expected to be boolean).
        annotations (np.ndarray): The annotation matrix (expected to be boolean).
        start_idx (int): Starting index (inclusive) of the pairs to process.
        end_idx (int): Ending index (exclusive) of the pairs to process.
        progress_counter: Shared counter whose ``value`` attribute tracks finished pairs.

    Returns:
        tuple: Two lists (depletion p-values, enrichment p-values), one entry per pair.
    """
    num_nodes, num_annotations = annotations.shape
    depletion_pvals = []
    enrichment_pvals = []

    # The counter may be a manager proxy, where every update is an inter-process
    # round-trip; report progress in small batches instead of once per pair
    progress_batch = 64
    pending_progress = 0

    # Consecutive indices share a neighborhood column, so cache it and its size
    current_i = None
    neighborhood = None
    neighborhood_size = 0

    for idx in range(start_idx, end_idx):
        i, j = divmod(idx, num_annotations)  # Neighborhood index, annotation index
        if i != current_i:
            current_i = i
            neighborhood = neighborhoods[:, i]
            neighborhood_size = int(np.sum(neighborhood))
        annotation = annotations[:, j]
        annotation_size = int(np.sum(annotation))

        # Contingency table: only the overlap needs a full pass, the other
        # three cells follow from the column totals
        TP = int(np.sum(neighborhood & annotation))
        FP = neighborhood_size - TP
        FN = annotation_size - TP
        TN = num_nodes - neighborhood_size - FN
        table = np.array([[TP, FP], [FN, TN]])

        # Perform Fisher's exact test for depletion (alternative='less')
        _, p_value_depletion = fisher_exact(table, alternative="less")
        depletion_pvals.append(p_value_depletion)
        # Perform Fisher's exact test for enrichment (alternative='greater')
        _, p_value_enrichment = fisher_exact(table, alternative="greater")
        enrichment_pvals.append(p_value_enrichment)

        pending_progress += 1
        if pending_progress >= progress_batch:
            progress_counter.value += pending_progress
            pending_progress = 0

    # Flush whatever progress has not been reported yet
    if pending_progress:
        progress_counter.value += pending_progress

    return depletion_pvals, enrichment_pvals
|
risk/stats/hypergeom.py
ADDED
@@ -0,0 +1,131 @@
|
|
1
|
+
"""
|
2
|
+
risk/stats/hypergeom
|
3
|
+
~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from multiprocessing import get_context, Manager
|
7
|
+
from tqdm import tqdm
|
8
|
+
from typing import Any, Dict
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from scipy.stats import hypergeom
|
12
|
+
|
13
|
+
|
14
|
+
def compute_hypergeom_test(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    max_workers: int = 4,
) -> Dict[str, Any]:
    """Compute hypergeometric test for enrichment and depletion in neighborhoods.

    Every (neighborhood column, annotation column) pair is one task. Tasks are
    split into ``max_workers`` contiguous index ranges and evaluated in a
    "spawn" process pool while a progress bar follows a shared counter.

    Args:
        neighborhoods (np.ndarray): Matrix whose columns are neighborhoods; coerced to boolean.
        annotations (np.ndarray): Matrix whose columns are annotations; coerced to boolean.
        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 4.

    Returns:
        dict: "depletion_pvals" and "enrichment_pvals", each an array of shape
            (number of neighborhood columns, number of annotation columns).

    Raises:
        ValueError: If ``max_workers`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError further down
    if max_workers < 1:
        raise ValueError(f"max_workers must be at least 1, got {max_workers}")

    # Coerce to boolean; note that any non-zero entry (NaN included) becomes True
    neighborhoods = neighborhoods.astype(bool)
    annotations = annotations.astype(bool)

    num_neighborhoods = neighborhoods.shape[1]
    num_annotations = annotations.shape[1]
    total_tasks = num_neighborhoods * num_annotations

    # Calculate the workload per worker; the first `remainder` workers take one extra task
    chunk_size, remainder = divmod(total_tasks, max_workers)

    ctx = get_context("spawn")
    # The manager is created from the same context as the pool and used as a context
    # manager so that its server process is always shut down, even on error
    with ctx.Manager() as manager, ctx.Pool(max_workers) as pool:
        progress_counter = manager.Value("i", 0)

        params_list = []
        start_idx = 0
        for worker in range(max_workers):
            end_idx = start_idx + chunk_size + (1 if worker < remainder else 0)
            params_list.append(
                (neighborhoods, annotations, start_idx, end_idx, progress_counter)
            )
            start_idx = end_idx

        with tqdm(total=total_tasks, desc="Total progress", position=0) as progress:
            # Start the hypergeometric test process in parallel
            results = pool.starmap_async(_hypergeom_process_subset, params_list, chunksize=1)

            # Update progress bar based on progress_counter
            while not results.ready():
                progress.update(progress_counter.value - progress.n)
                results.wait(0.05)  # Wait for 50ms
            # Ensure progress bar reaches 100%
            progress.update(total_tasks - progress.n)

        # Accumulate results from each worker; starmap keeps them in task order
        depletion_pvals, enrichment_pvals = [], []
        for dp, ep in results.get():
            depletion_pvals.extend(dp)
            enrichment_pvals.extend(ep)

    # Reshape the flat results back into (neighborhoods, annotations) arrays
    result_shape = (num_neighborhoods, num_annotations)
    return {
        "depletion_pvals": np.array(depletion_pvals).reshape(result_shape),
        "enrichment_pvals": np.array(enrichment_pvals).reshape(result_shape),
    }
|
83
|
+
|
84
|
+
|
85
|
+
def _hypergeom_process_subset(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    start_idx: int,
    end_idx: int,
    progress_counter,
) -> tuple:
    """Process a subset of neighborhood-annotation pairs using the hypergeometric test.

    A flat pair index ``idx`` maps to neighborhood column ``idx // A`` and
    annotation column ``idx % A``, where ``A`` is the number of annotation columns.

    Args:
        neighborhoods (np.ndarray): The full neighborhood matrix (expected to be boolean).
        annotations (np.ndarray): The annotation matrix (expected to be boolean).
        start_idx (int): Starting index (inclusive) of the pairs to process.
        end_idx (int): Ending index (exclusive) of the pairs to process.
        progress_counter: Shared counter whose ``value`` attribute tracks finished pairs.

    Returns:
        tuple: Two lists (depletion p-values, enrichment p-values), one entry per pair.
    """
    # Population size and column count are the same for every pair
    M, num_annotations = annotations.shape
    depletion_pvals = []
    enrichment_pvals = []

    # The counter may be a manager proxy, where every update is an inter-process
    # round-trip; report progress in small batches instead of once per pair
    progress_batch = 64
    pending_progress = 0

    # Consecutive indices share a neighborhood column, so cache it and its size
    current_i = None
    neighborhood = None
    N = 0

    for idx in range(start_idx, end_idx):
        i, j = divmod(idx, num_annotations)  # Neighborhood index, annotation index
        if i != current_i:
            current_i = i
            neighborhood = neighborhoods[:, i]
            N = np.sum(neighborhood)  # Total number of draws (sample size)
        annotation = annotations[:, j]

        n = np.sum(annotation)  # Total number of successes in population
        k = np.sum(neighborhood & annotation)  # Number of successes in sample

        # Depletion: probability of observing k or fewer successes
        depletion_pvals.append(hypergeom.cdf(k, M, n, N))
        # Enrichment: probability of observing k or more successes (sf is P(X > k - 1))
        enrichment_pvals.append(hypergeom.sf(k - 1, M, n, N))

        pending_progress += 1
        if pending_progress >= progress_batch:
            progress_counter.value += pending_progress
            pending_progress = 0

    # Flush whatever progress has not been reported yet
    if pending_progress:
        progress_counter.value += pending_progress

    return depletion_pvals, enrichment_pvals
|