risk-network 0.0.16b0__py3-none-any.whl → 0.0.16b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -2
- risk/{_annotation → annotation}/__init__.py +2 -2
- risk/{_annotation → annotation}/_nltk_setup.py +3 -3
- risk/{_annotation/_annotation.py → annotation/annotation.py} +22 -25
- risk/{_annotation/_io.py → annotation/io.py} +4 -4
- risk/cluster/__init__.py +8 -0
- risk/{_neighborhoods → cluster}/_community.py +37 -37
- risk/cluster/api.py +273 -0
- risk/{_neighborhoods/_neighborhoods.py → cluster/cluster.py} +127 -98
- risk/{_neighborhoods/_domains.py → cluster/label.py} +18 -12
- risk/{_log → log}/__init__.py +2 -2
- risk/{_log/_console.py → log/console.py} +2 -2
- risk/{_log/_parameters.py → log/parameters.py} +20 -10
- risk/network/__init__.py +8 -0
- risk/network/graph/__init__.py +7 -0
- risk/{_network/_graph → network/graph}/_stats.py +2 -2
- risk/{_network/_graph → network/graph}/_summary.py +13 -13
- risk/{_network/_graph/_api.py → network/graph/api.py} +37 -39
- risk/{_network/_graph/_graph.py → network/graph/graph.py} +5 -5
- risk/{_network/_io.py → network/io.py} +9 -4
- risk/network/plotter/__init__.py +6 -0
- risk/{_network/_plotter → network/plotter}/_canvas.py +6 -6
- risk/{_network/_plotter → network/plotter}/_contour.py +4 -4
- risk/{_network/_plotter → network/plotter}/_labels.py +6 -6
- risk/{_network/_plotter → network/plotter}/_network.py +7 -7
- risk/{_network/_plotter → network/plotter}/_plotter.py +5 -5
- risk/network/plotter/_utils/__init__.py +7 -0
- risk/{_network/_plotter/_utils/_colors.py → network/plotter/_utils/colors.py} +3 -3
- risk/{_network/_plotter/_utils/_layout.py → network/plotter/_utils/layout.py} +2 -2
- risk/{_network/_plotter/_api.py → network/plotter/api.py} +5 -5
- risk/{_risk.py → risk.py} +9 -8
- risk/stats/__init__.py +6 -0
- risk/stats/_stats/__init__.py +11 -0
- risk/stats/_stats/permutation/__init__.py +6 -0
- risk/stats/_stats/permutation/_test_functions.py +72 -0
- risk/{_neighborhoods/_stats/_permutation/_permutation.py → stats/_stats/permutation/permutation.py} +35 -37
- risk/{_neighborhoods/_stats/_tests.py → stats/_stats/tests.py} +32 -34
- risk/stats/api.py +202 -0
- {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/METADATA +2 -2
- risk_network-0.0.16b2.dist-info/RECORD +43 -0
- risk/_neighborhoods/__init__.py +0 -8
- risk/_neighborhoods/_api.py +0 -354
- risk/_neighborhoods/_stats/__init__.py +0 -11
- risk/_neighborhoods/_stats/_permutation/__init__.py +0 -6
- risk/_neighborhoods/_stats/_permutation/_test_functions.py +0 -72
- risk/_network/__init__.py +0 -8
- risk/_network/_graph/__init__.py +0 -7
- risk/_network/_plotter/__init__.py +0 -6
- risk/_network/_plotter/_utils/__init__.py +0 -7
- risk_network-0.0.16b0.dist-info/RECORD +0 -41
- {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/WHEEL +0 -0
- {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/licenses/LICENSE +0 -0
- {risk_network-0.0.16b0.dist-info → risk_network-0.0.16b2.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/network/plotter/_network
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from typing import Any, Dict, List, Tuple, Union
|
|
@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Tuple, Union
|
|
|
8
8
|
import networkx as nx
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
11
|
-
from ...
|
|
12
|
-
from ..
|
|
11
|
+
from ...log import params
|
|
12
|
+
from ..graph import Graph
|
|
13
13
|
from ._utils import get_domain_colors, to_rgba
|
|
14
14
|
|
|
15
15
|
|
|
@@ -273,14 +273,14 @@ class Network:
|
|
|
273
273
|
return adjusted_network_colors
|
|
274
274
|
|
|
275
275
|
def get_annotated_node_sizes(
|
|
276
|
-
self, significant_size: int = 50, nonsignificant_size: int = 25
|
|
276
|
+
self, significant_size: Union[int, float] = 50, nonsignificant_size: Union[int, float] = 25
|
|
277
277
|
) -> np.ndarray:
|
|
278
278
|
"""
|
|
279
279
|
Adjust the sizes of nodes in the network graph based on whether they are significant or not.
|
|
280
280
|
|
|
281
281
|
Args:
|
|
282
|
-
significant_size (int): Size for significant nodes. Defaults to 50.
|
|
283
|
-
nonsignificant_size (int): Size for non-significant nodes. Defaults to 25.
|
|
282
|
+
significant_size (int or float): Size for significant nodes. Can be an integer or float value. Defaults to 50.
|
|
283
|
+
nonsignificant_size (int or float): Size for non-significant nodes. Can be an integer or float value. Defaults to 25.
|
|
284
284
|
|
|
285
285
|
Returns:
|
|
286
286
|
np.ndarray: Array of node sizes, with significant nodes larger than non-significant ones.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/network/plotter/_plotter
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from typing import List, Tuple, Union
|
|
@@ -8,8 +8,8 @@ from typing import List, Tuple, Union
|
|
|
8
8
|
import matplotlib.pyplot as plt
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
11
|
-
from ...
|
|
12
|
-
from ..
|
|
11
|
+
from ...log import params
|
|
12
|
+
from ..graph.graph import Graph
|
|
13
13
|
from ._canvas import Canvas
|
|
14
14
|
from ._contour import Contour
|
|
15
15
|
from ._labels import Labels
|
|
@@ -123,7 +123,7 @@ class Plotter(Canvas, Network, Contour, Labels):
|
|
|
123
123
|
Args:
|
|
124
124
|
*args: Positional arguments passed to `plt.savefig`.
|
|
125
125
|
pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
|
|
126
|
-
dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to
|
|
126
|
+
dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
|
|
127
127
|
**kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
|
|
128
128
|
"""
|
|
129
129
|
# Ensure user-provided kwargs take precedence
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/network/plotter/_utils/colors
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from typing import Any, Dict, List, Tuple, Union
|
|
@@ -9,7 +9,7 @@ import matplotlib
|
|
|
9
9
|
import matplotlib.colors as mcolors
|
|
10
10
|
import numpy as np
|
|
11
11
|
|
|
12
|
-
from ...
|
|
12
|
+
from ...graph import Graph
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def get_annotated_domain_colors(
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/network/plotter/api
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from typing import List, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
-
from ...
|
|
11
|
-
from ..
|
|
10
|
+
from ...log import log_header
|
|
11
|
+
from ..graph import Graph
|
|
12
12
|
from ._plotter import Plotter
|
|
13
13
|
|
|
14
14
|
|
|
@@ -32,7 +32,7 @@ class PlotterAPI:
|
|
|
32
32
|
|
|
33
33
|
Args:
|
|
34
34
|
graph (Graph): The graph to plot.
|
|
35
|
-
figsize (List, Tuple, or np.ndarray, optional):
|
|
35
|
+
figsize (List, Tuple, or np.ndarray, optional): Figure size in inches (width, height). Defaults to (10, 10).
|
|
36
36
|
background_color (str, optional): Background color of the plot. Defaults to "white".
|
|
37
37
|
background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
|
|
38
38
|
any existing alpha values found in background_color. Defaults to 1.0.
|
risk/{_risk.py → risk.py}
RENAMED
|
@@ -1,20 +1,21 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/risk
|
|
3
|
+
~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
from .
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
from .
|
|
6
|
+
from .annotation import AnnotationHandler
|
|
7
|
+
from .cluster import ClusterAPI
|
|
8
|
+
from .log import params, set_global_verbosity
|
|
9
|
+
from .network import GraphAPI, NetworkAPI, PlotterAPI
|
|
10
|
+
from .stats import StatsAPI
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
class RISK(NetworkAPI, AnnotationHandler,
|
|
13
|
+
class RISK(NetworkAPI, AnnotationHandler, ClusterAPI, StatsAPI, GraphAPI, PlotterAPI):
|
|
13
14
|
"""
|
|
14
15
|
RISK: A class for network analysis and visualization.
|
|
15
16
|
|
|
16
17
|
The RISK class integrates functionalities for loading networks, processing annotations,
|
|
17
|
-
performing network-based statistical analysis to quantify
|
|
18
|
+
performing network-based statistical analysis to quantify cluster relationships,
|
|
18
19
|
and visualizing networks and their properties.
|
|
19
20
|
"""
|
|
20
21
|
|
risk/stats/__init__.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""
|
|
2
|
+
risk/stats/_stats/permutation/_test_functions
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from scipy.sparse import csr_matrix
|
|
8
|
+
|
|
9
|
+
# NOTE: Cython optimizations provided minimal performance benefits.
|
|
10
|
+
# The final version with Cython is archived in the `cython_permutation` branch.
|
|
11
|
+
|
|
12
|
+
# DISPATCH_TEST_FUNCTIONS can be found at the end of the file.
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compute_cluster_score_by_sum(
|
|
16
|
+
clusters_matrix: csr_matrix, annotation_matrix: csr_matrix
|
|
17
|
+
) -> np.ndarray:
|
|
18
|
+
"""
|
|
19
|
+
Compute the sum of attribute values for each cluster using sparse matrices.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
clusters_matrix (csr_matrix): Sparse binary matrix representing clusters.
|
|
23
|
+
annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
np.ndarray: Dense array of summed attribute values for each cluster.
|
|
27
|
+
"""
|
|
28
|
+
# Calculate the cluster score as the dot product of clusters and annotation
|
|
29
|
+
cluster_score = clusters_matrix @ annotation_matrix # Sparse matrix multiplication
|
|
30
|
+
# Convert the result to a dense array for downstream calculations
|
|
31
|
+
cluster_score_dense = cluster_score.toarray()
|
|
32
|
+
return cluster_score_dense
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def compute_cluster_score_by_stdev(
|
|
36
|
+
clusters_matrix: csr_matrix, annotation_matrix: csr_matrix
|
|
37
|
+
) -> np.ndarray:
|
|
38
|
+
"""
|
|
39
|
+
Compute the standard deviation of cluster scores for sparse matrices.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
clusters_matrix (csr_matrix): Sparse binary matrix representing clusters.
|
|
43
|
+
annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
np.ndarray: Standard deviation of the cluster scores.
|
|
47
|
+
"""
|
|
48
|
+
# Calculate the cluster score as the dot product of clusters and annotation
|
|
49
|
+
cluster_score = clusters_matrix @ annotation_matrix # Sparse matrix multiplication
|
|
50
|
+
# Calculate the number of elements in each cluster (sum of rows)
|
|
51
|
+
N = clusters_matrix.sum(axis=1).A.flatten() # Convert to 1D array
|
|
52
|
+
# Avoid division by zero by replacing zeros in N with np.nan temporarily
|
|
53
|
+
N[N == 0] = np.nan
|
|
54
|
+
# Compute the mean of the cluster scores
|
|
55
|
+
M = cluster_score.multiply(1 / N[:, None]).toarray() # Sparse element-wise division
|
|
56
|
+
# Compute the mean of squares (EXX) directly using squared annotation matrix
|
|
57
|
+
annotation_squared = annotation_matrix.multiply(annotation_matrix) # Element-wise squaring
|
|
58
|
+
EXX = (clusters_matrix @ annotation_squared).multiply(1 / N[:, None]).toarray()
|
|
59
|
+
# Calculate variance as EXX - M^2
|
|
60
|
+
variance = EXX - np.power(M, 2)
|
|
61
|
+
# Compute the standard deviation as the square root of the variance
|
|
62
|
+
cluster_stdev = np.sqrt(variance)
|
|
63
|
+
# Replace np.nan back with zeros in case N was 0 (no elements in the cluster)
|
|
64
|
+
cluster_stdev[np.isnan(cluster_stdev)] = 0
|
|
65
|
+
return cluster_stdev
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Dictionary to dispatch statistical test functions based on the score metric
|
|
69
|
+
DISPATCH_TEST_FUNCTIONS = {
|
|
70
|
+
"sum": compute_cluster_score_by_sum,
|
|
71
|
+
"stdev": compute_cluster_score_by_stdev,
|
|
72
|
+
}
|
risk/{_neighborhoods/_stats/_permutation/_permutation.py → stats/_stats/permutation/permutation.py}
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/stats/_stats/permutation/permutation
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from multiprocessing import Manager, get_context
|
|
@@ -16,7 +16,7 @@ from ._test_functions import DISPATCH_TEST_FUNCTIONS
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def compute_permutation_test(
|
|
19
|
-
|
|
19
|
+
clusters: csr_matrix,
|
|
20
20
|
annotation: csr_matrix,
|
|
21
21
|
score_metric: str = "sum",
|
|
22
22
|
null_distribution: str = "network",
|
|
@@ -25,10 +25,10 @@ def compute_permutation_test(
|
|
|
25
25
|
max_workers: int = 1,
|
|
26
26
|
) -> Dict[str, Any]:
|
|
27
27
|
"""
|
|
28
|
-
Compute permutation test for enrichment and depletion in
|
|
28
|
+
Compute permutation test for enrichment and depletion in clusters.
|
|
29
29
|
|
|
30
30
|
Args:
|
|
31
|
-
|
|
31
|
+
clusters (csr_matrix): Sparse binary matrix representing clusters.
|
|
32
32
|
annotation (csr_matrix): Sparse binary matrix representing annotation.
|
|
33
33
|
score_metric (str, optional): Metric to use for scoring ('sum' or 'stdev'). Defaults to "sum".
|
|
34
34
|
null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
|
|
@@ -41,16 +41,16 @@ def compute_permutation_test(
|
|
|
41
41
|
"""
|
|
42
42
|
# Ensure that the matrices are in the correct format and free of NaN values
|
|
43
43
|
# NOTE: Keep the data type as float32 to avoid locking issues with dot product operations
|
|
44
|
-
|
|
44
|
+
clusters = clusters.astype(np.float32)
|
|
45
45
|
annotation = annotation.astype(np.float32)
|
|
46
|
-
# Retrieve the appropriate
|
|
47
|
-
|
|
46
|
+
# Retrieve the appropriate cluster score function based on the metric
|
|
47
|
+
cluster_score_func = DISPATCH_TEST_FUNCTIONS[score_metric]
|
|
48
48
|
|
|
49
49
|
# Run the permutation test to calculate depletion and enrichment counts
|
|
50
50
|
counts_depletion, counts_enrichment = _run_permutation_test(
|
|
51
|
-
|
|
51
|
+
clusters=clusters,
|
|
52
52
|
annotation=annotation,
|
|
53
|
-
|
|
53
|
+
cluster_score_func=cluster_score_func,
|
|
54
54
|
null_distribution=null_distribution,
|
|
55
55
|
num_permutations=num_permutations,
|
|
56
56
|
random_seed=random_seed,
|
|
@@ -68,9 +68,9 @@ def compute_permutation_test(
|
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
def _run_permutation_test(
|
|
71
|
-
|
|
71
|
+
clusters: csr_matrix,
|
|
72
72
|
annotation: csr_matrix,
|
|
73
|
-
|
|
73
|
+
cluster_score_func: Callable,
|
|
74
74
|
null_distribution: str = "network",
|
|
75
75
|
num_permutations: int = 1000,
|
|
76
76
|
random_seed: int = 888,
|
|
@@ -80,9 +80,9 @@ def _run_permutation_test(
|
|
|
80
80
|
Run the permutation test to calculate depletion and enrichment counts.
|
|
81
81
|
|
|
82
82
|
Args:
|
|
83
|
-
|
|
83
|
+
clusters (csr_matrix): Sparse binary matrix representing clusters.
|
|
84
84
|
annotation (csr_matrix): Sparse binary matrix representing annotation.
|
|
85
|
-
|
|
85
|
+
cluster_score_func (Callable): Function to calculate cluster scores.
|
|
86
86
|
null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
|
|
87
87
|
num_permutations (int, optional): Number of permutations. Defaults to 1000.
|
|
88
88
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
|
@@ -109,16 +109,14 @@ def _run_permutation_test(
|
|
|
109
109
|
# Replace NaNs with zeros in the sparse annotation matrix
|
|
110
110
|
annotation.data[np.isnan(annotation.data)] = 0
|
|
111
111
|
annotation_matrix_obsv = annotation[idxs]
|
|
112
|
-
|
|
113
|
-
# Calculate observed
|
|
112
|
+
clusters_matrix_obsv = clusters.T[idxs].T
|
|
113
|
+
# Calculate observed cluster scores
|
|
114
114
|
with np.errstate(invalid="ignore", divide="ignore"):
|
|
115
|
-
|
|
116
|
-
neighborhoods_matrix_obsv, annotation_matrix_obsv
|
|
117
|
-
)
|
|
115
|
+
observed_cluster_scores = cluster_score_func(clusters_matrix_obsv, annotation_matrix_obsv)
|
|
118
116
|
|
|
119
117
|
# Initialize count matrices for depletion and enrichment
|
|
120
|
-
counts_depletion = np.zeros(
|
|
121
|
-
counts_enrichment = np.zeros(
|
|
118
|
+
counts_depletion = np.zeros(observed_cluster_scores.shape)
|
|
119
|
+
counts_enrichment = np.zeros(observed_cluster_scores.shape)
|
|
122
120
|
# Determine the number of permutations to run in each worker process
|
|
123
121
|
subset_size = num_permutations // max_workers
|
|
124
122
|
remainder = num_permutations % max_workers
|
|
@@ -145,9 +143,9 @@ def _run_permutation_test(
|
|
|
145
143
|
(
|
|
146
144
|
permutation_batches[i], # Pass the batch of precomputed permutations
|
|
147
145
|
annotation,
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
146
|
+
clusters_matrix_obsv,
|
|
147
|
+
observed_cluster_scores,
|
|
148
|
+
cluster_score_func,
|
|
151
149
|
num_permutations,
|
|
152
150
|
progress_counter,
|
|
153
151
|
max_workers,
|
|
@@ -176,9 +174,9 @@ def _run_permutation_test(
|
|
|
176
174
|
def _permutation_process_batch(
|
|
177
175
|
permutations: Union[List, Tuple, np.ndarray],
|
|
178
176
|
annotation_matrix: csr_matrix,
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
177
|
+
clusters_matrix_obsv: csr_matrix,
|
|
178
|
+
observed_cluster_scores: np.ndarray,
|
|
179
|
+
cluster_score_func: Callable,
|
|
182
180
|
num_permutations: int,
|
|
183
181
|
progress_counter: ValueProxy,
|
|
184
182
|
max_workers: int,
|
|
@@ -189,9 +187,9 @@ def _permutation_process_batch(
|
|
|
189
187
|
Args:
|
|
190
188
|
permutations (Union[List, Tuple, np.ndarray]): Permutation batch to process.
|
|
191
189
|
annotation_matrix (csr_matrix): Sparse binary matrix representing annotation.
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
190
|
+
clusters_matrix_obsv (csr_matrix): Sparse binary matrix representing observed clusters.
|
|
191
|
+
observed_cluster_scores (np.ndarray): Observed cluster scores.
|
|
192
|
+
cluster_score_func (Callable): Function to calculate cluster scores.
|
|
195
193
|
num_permutations (int): Number of total permutations across all subsets.
|
|
196
194
|
progress_counter (multiprocessing.managers.ValueProxy): Shared counter for tracking progress.
|
|
197
195
|
max_workers (int): Number of workers for multiprocessing.
|
|
@@ -200,8 +198,8 @@ def _permutation_process_batch(
|
|
|
200
198
|
tuple: Local counts of depletion and enrichment.
|
|
201
199
|
"""
|
|
202
200
|
# Initialize local count matrices for this worker
|
|
203
|
-
local_counts_depletion = np.zeros(
|
|
204
|
-
local_counts_enrichment = np.zeros(
|
|
201
|
+
local_counts_depletion = np.zeros(observed_cluster_scores.shape)
|
|
202
|
+
local_counts_enrichment = np.zeros(observed_cluster_scores.shape)
|
|
205
203
|
|
|
206
204
|
# Limit the number of threads used by NumPy's BLAS implementation to 1 when more than one worker is used
|
|
207
205
|
# NOTE: This does not work for Mac M chips due to a bug in the threadpoolctl package
|
|
@@ -216,19 +214,19 @@ def _permutation_process_batch(
|
|
|
216
214
|
for permuted_idxs in permutations:
|
|
217
215
|
# Apply precomputed permutation
|
|
218
216
|
annotation_matrix_permut = annotation_matrix[permuted_idxs]
|
|
219
|
-
# Calculate permuted
|
|
217
|
+
# Calculate permuted cluster scores
|
|
220
218
|
with np.errstate(invalid="ignore", divide="ignore"):
|
|
221
|
-
|
|
222
|
-
|
|
219
|
+
permuted_cluster_scores = cluster_score_func(
|
|
220
|
+
clusters_matrix_obsv, annotation_matrix_permut
|
|
223
221
|
)
|
|
224
222
|
|
|
225
223
|
# Update local depletion and enrichment counts
|
|
226
224
|
local_counts_depletion = np.add(
|
|
227
|
-
local_counts_depletion,
|
|
225
|
+
local_counts_depletion, permuted_cluster_scores <= observed_cluster_scores
|
|
228
226
|
)
|
|
229
227
|
local_counts_enrichment = np.add(
|
|
230
228
|
local_counts_enrichment,
|
|
231
|
-
|
|
229
|
+
permuted_cluster_scores >= observed_cluster_scores,
|
|
232
230
|
)
|
|
233
231
|
|
|
234
232
|
# Update progress
|
|
@@ -1,25 +1,25 @@
|
|
|
1
1
|
"""
|
|
2
|
-
risk/
|
|
3
|
-
|
|
2
|
+
risk/stats/_stats/tests
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from typing import Any, Dict
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.sparse import csr_matrix
|
|
10
|
-
from scipy.stats import binom, chi2, hypergeom
|
|
10
|
+
from scipy.stats import binom, chi2, hypergeom
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def compute_binom_test(
|
|
14
|
-
|
|
14
|
+
clusters: csr_matrix,
|
|
15
15
|
annotation: csr_matrix,
|
|
16
16
|
null_distribution: str = "network",
|
|
17
17
|
) -> Dict[str, Any]:
|
|
18
18
|
"""
|
|
19
|
-
Compute Binomial test for enrichment and depletion in
|
|
19
|
+
Compute Binomial test for enrichment and depletion in clusters with selectable null distribution.
|
|
20
20
|
|
|
21
21
|
Args:
|
|
22
|
-
|
|
22
|
+
clusters (csr_matrix): Sparse binary matrix representing clusters.
|
|
23
23
|
annotation (csr_matrix): Sparse binary matrix representing annotation.
|
|
24
24
|
null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
|
|
25
25
|
|
|
@@ -30,10 +30,10 @@ def compute_binom_test(
|
|
|
30
30
|
ValueError: If an invalid null_distribution value is provided.
|
|
31
31
|
"""
|
|
32
32
|
# Get the total number of nodes in the network
|
|
33
|
-
total_nodes =
|
|
33
|
+
total_nodes = clusters.shape[1]
|
|
34
34
|
|
|
35
35
|
# Compute sums (remain sparse here)
|
|
36
|
-
|
|
36
|
+
cluster_sizes = clusters.sum(axis=1) # Row sums
|
|
37
37
|
annotation_totals = annotation.sum(axis=0) # Column sums
|
|
38
38
|
# Compute probabilities (convert to dense)
|
|
39
39
|
if null_distribution == "network":
|
|
@@ -46,26 +46,26 @@ def compute_binom_test(
|
|
|
46
46
|
)
|
|
47
47
|
|
|
48
48
|
# Observed counts (sparse matrix multiplication)
|
|
49
|
-
annotated_counts =
|
|
49
|
+
annotated_counts = clusters @ annotation # Sparse result
|
|
50
50
|
annotated_counts_dense = annotated_counts.toarray() # Convert for dense operations
|
|
51
51
|
|
|
52
52
|
# Compute enrichment and depletion p-values
|
|
53
|
-
enrichment_pvals = 1 - binom.cdf(annotated_counts_dense - 1,
|
|
54
|
-
depletion_pvals = binom.cdf(annotated_counts_dense,
|
|
53
|
+
enrichment_pvals = 1 - binom.cdf(annotated_counts_dense - 1, cluster_sizes.A, p_values)
|
|
54
|
+
depletion_pvals = binom.cdf(annotated_counts_dense, cluster_sizes.A, p_values)
|
|
55
55
|
|
|
56
56
|
return {"enrichment_pvals": enrichment_pvals, "depletion_pvals": depletion_pvals}
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
def compute_chi2_test(
|
|
60
|
-
|
|
60
|
+
clusters: csr_matrix,
|
|
61
61
|
annotation: csr_matrix,
|
|
62
62
|
null_distribution: str = "network",
|
|
63
63
|
) -> Dict[str, Any]:
|
|
64
64
|
"""
|
|
65
|
-
Compute chi-squared test for enrichment and depletion in
|
|
65
|
+
Compute chi-squared test for enrichment and depletion in clusters with selectable null distribution.
|
|
66
66
|
|
|
67
67
|
Args:
|
|
68
|
-
|
|
68
|
+
clusters (csr_matrix): Sparse binary matrix representing clusters.
|
|
69
69
|
annotation (csr_matrix): Sparse binary matrix representing annotation.
|
|
70
70
|
null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
|
|
71
71
|
|
|
@@ -76,12 +76,12 @@ def compute_chi2_test(
|
|
|
76
76
|
ValueError: If an invalid null_distribution value is provided.
|
|
77
77
|
"""
|
|
78
78
|
# Total number of nodes in the network
|
|
79
|
-
total_node_count =
|
|
79
|
+
total_node_count = clusters.shape[0]
|
|
80
80
|
|
|
81
81
|
if null_distribution == "network":
|
|
82
82
|
# Case 1: Use all nodes as the background
|
|
83
83
|
background_population = total_node_count
|
|
84
|
-
|
|
84
|
+
cluster_sums = clusters.sum(axis=0) # Column sums of clusters
|
|
85
85
|
annotation_sums = annotation.sum(axis=0) # Column sums of annotations
|
|
86
86
|
elif null_distribution == "annotation":
|
|
87
87
|
# Case 2: Only consider nodes with at least one annotation
|
|
@@ -89,9 +89,7 @@ def compute_chi2_test(
|
|
|
89
89
|
np.ravel(annotation.sum(axis=1)) > 0
|
|
90
90
|
) # Row-wise sum to filter nodes with annotations
|
|
91
91
|
background_population = annotated_nodes.sum() # Total number of annotated nodes
|
|
92
|
-
|
|
93
|
-
axis=0
|
|
94
|
-
) # Neighborhood sums for annotated nodes
|
|
92
|
+
cluster_sums = clusters[annotated_nodes].sum(axis=0) # Cluster sums for annotated nodes
|
|
95
93
|
annotation_sums = annotation[annotated_nodes].sum(
|
|
96
94
|
axis=0
|
|
97
95
|
) # Annotation sums for annotated nodes
|
|
@@ -101,13 +99,13 @@ def compute_chi2_test(
|
|
|
101
99
|
)
|
|
102
100
|
|
|
103
101
|
# Convert to dense arrays for downstream computations
|
|
104
|
-
|
|
102
|
+
cluster_sums = np.asarray(cluster_sums).reshape(-1, 1) # Ensure column vector shape
|
|
105
103
|
annotation_sums = np.asarray(annotation_sums).reshape(1, -1) # Ensure row vector shape
|
|
106
104
|
|
|
107
|
-
# Observed values: number of annotated nodes in each
|
|
108
|
-
observed =
|
|
105
|
+
# Observed values: number of annotated nodes in each cluster
|
|
106
|
+
observed = clusters.T @ annotation # Shape: (clusters, annotation)
|
|
109
107
|
# Expected values under the null
|
|
110
|
-
expected = (
|
|
108
|
+
expected = (cluster_sums @ annotation_sums) / background_population
|
|
111
109
|
# Chi-squared statistic: sum((observed - expected)^2 / expected)
|
|
112
110
|
with np.errstate(divide="ignore", invalid="ignore"): # Handle divide-by-zero
|
|
113
111
|
chi2_stat = np.where(expected > 0, np.power(observed - expected, 2) / expected, 0)
|
|
@@ -120,15 +118,15 @@ def compute_chi2_test(
|
|
|
120
118
|
|
|
121
119
|
|
|
122
120
|
def compute_hypergeom_test(
|
|
123
|
-
|
|
121
|
+
clusters: csr_matrix,
|
|
124
122
|
annotation: csr_matrix,
|
|
125
123
|
null_distribution: str = "network",
|
|
126
124
|
) -> Dict[str, Any]:
|
|
127
125
|
"""
|
|
128
|
-
Compute hypergeometric test for enrichment and depletion in
|
|
126
|
+
Compute hypergeometric test for enrichment and depletion in clusters with selectable null distribution.
|
|
129
127
|
|
|
130
128
|
Args:
|
|
131
|
-
|
|
129
|
+
clusters (csr_matrix): Sparse binary matrix representing clusters.
|
|
132
130
|
annotation (csr_matrix): Sparse binary matrix representing annotation.
|
|
133
131
|
null_distribution (str, optional): Type of null distribution ('network' or 'annotation'). Defaults to "network".
|
|
134
132
|
|
|
@@ -139,10 +137,10 @@ def compute_hypergeom_test(
|
|
|
139
137
|
ValueError: If an invalid null_distribution value is provided.
|
|
140
138
|
"""
|
|
141
139
|
# Get the total number of nodes in the network
|
|
142
|
-
total_nodes =
|
|
140
|
+
total_nodes = clusters.shape[1]
|
|
143
141
|
|
|
144
142
|
# Compute sums
|
|
145
|
-
|
|
143
|
+
cluster_sums = clusters.sum(axis=0).A.flatten() # Convert to dense array
|
|
146
144
|
annotation_sums = annotation.sum(axis=0).A.flatten() # Convert to dense array
|
|
147
145
|
|
|
148
146
|
if null_distribution == "network":
|
|
@@ -150,7 +148,7 @@ def compute_hypergeom_test(
|
|
|
150
148
|
elif null_distribution == "annotation":
|
|
151
149
|
annotated_nodes = annotation.sum(axis=1).A.flatten() > 0 # Boolean mask
|
|
152
150
|
background_population = annotated_nodes.sum()
|
|
153
|
-
|
|
151
|
+
cluster_sums = clusters[annotated_nodes].sum(axis=0).A.flatten()
|
|
154
152
|
annotation_sums = annotation[annotated_nodes].sum(axis=0).A.flatten()
|
|
155
153
|
else:
|
|
156
154
|
raise ValueError(
|
|
@@ -158,19 +156,19 @@ def compute_hypergeom_test(
|
|
|
158
156
|
)
|
|
159
157
|
|
|
160
158
|
# Observed counts
|
|
161
|
-
|
|
162
|
-
|
|
159
|
+
annotated_in_cluster = clusters.T @ annotation # Sparse result
|
|
160
|
+
annotated_in_cluster = annotated_in_cluster.toarray() # Convert to dense
|
|
163
161
|
# Align shapes for broadcasting
|
|
164
|
-
|
|
162
|
+
cluster_sums = cluster_sums.reshape(-1, 1)
|
|
165
163
|
annotation_sums = annotation_sums.reshape(1, -1)
|
|
166
164
|
background_population = np.array(background_population).reshape(1, 1)
|
|
167
165
|
|
|
168
166
|
# Compute hypergeometric p-values
|
|
169
167
|
depletion_pvals = hypergeom.cdf(
|
|
170
|
-
|
|
168
|
+
annotated_in_cluster, background_population, annotation_sums, cluster_sums
|
|
171
169
|
)
|
|
172
170
|
enrichment_pvals = hypergeom.sf(
|
|
173
|
-
|
|
171
|
+
annotated_in_cluster - 1, background_population, annotation_sums, cluster_sums
|
|
174
172
|
)
|
|
175
173
|
|
|
176
174
|
return {"depletion_pvals": depletion_pvals, "enrichment_pvals": enrichment_pvals}
|