risk-network 0.0.3b0__cp38-cp38-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,82 @@
1
+ # cython: language_level=3
2
+ import numpy as np
3
+ cimport numpy as np
4
+ cimport cython
5
+ from threadpoolctl import threadpool_limits
6
+
7
+
8
@cython.boundscheck(False)  # No bounds checks anywhere in this function
@cython.wraparound(False)  # No negative-index wrapping anywhere in this function
def compute_neighborhood_score_by_sum_cython(
    np.ndarray[np.float32_t, ndim=2] neighborhoods,
    np.ndarray[np.float32_t, ndim=2] annotation_matrix,
):
    """Return the matrix product of ``neighborhoods`` and ``annotation_matrix``.

    Args:
        neighborhoods: 2-D float32 array; one row per neighborhood.
        annotation_matrix: 2-D float32 array; one column per annotation.

    Returns:
        np.ndarray: ``np.dot(neighborhoods, annotation_matrix)`` as a float32 array.
    """
    cdef np.float32_t[:, :] score_view
    # Keep BLAS single-threaded for the duration of the product
    with threadpool_limits(limits=1, user_api='blas'):
        score_view = np.dot(neighborhoods, annotation_matrix)

    # Hand the typed memoryview back to the caller as a regular ndarray
    return np.asarray(score_view)
20
+
21
+
22
@cython.boundscheck(False)
@cython.wraparound(False)
def compute_neighborhood_score_by_stdev_cython(
    np.ndarray[np.float32_t, ndim=2] neighborhoods,
    np.ndarray[np.float32_t, ndim=2] annotation_matrix,
):
    """Compute the per-neighborhood standard deviation of annotation values.

    The standard deviation is derived as ``sqrt(E[X^2] - E[X]^2)``, where both
    expectations are neighborhood-weighted sums divided by the row sum of
    ``neighborhoods``.

    Args:
        neighborhoods: 2-D float32 array; one row per neighborhood.
        annotation_matrix: 2-D float32 array; one column per annotation.

    Returns:
        np.ndarray: float32 array of standard deviations, one row per neighborhood
            and one column per annotation. Rows of ``neighborhoods`` that sum to
            zero yield NaN (0/0).
    """
    cdef np.ndarray[np.float32_t, ndim=2] neighborhood_score
    cdef np.ndarray[np.float32_t, ndim=2] EXX
    # Perform dot product directly using the inputs with limited threads
    with threadpool_limits(limits=1, user_api='blas'):
        neighborhood_score = np.dot(neighborhoods, annotation_matrix)

    # Sum across rows for neighborhoods to get N, reshape for broadcasting
    cdef np.ndarray[np.float32_t, ndim=1] N = np.sum(neighborhoods, axis=1)
    cdef np.ndarray[np.float32_t, ndim=2] N_reshaped = N[:, None]
    # Mean of the dot product
    cdef np.ndarray[np.float32_t, ndim=2] M = neighborhood_score / N_reshaped
    # Compute the mean of squares (EXX) with limited threads
    with threadpool_limits(limits=1, user_api='blas'):
        EXX = np.dot(neighborhoods, np.power(annotation_matrix, 2)) / N_reshaped

    # Variance computation
    cdef np.ndarray[np.float32_t, ndim=2] variance = EXX - M**2
    # float32 cancellation can push a (near-)zero variance slightly below zero, which
    # would make sqrt return NaN; clip at 0 first. np.maximum propagates existing NaNs,
    # so empty neighborhoods still yield NaN.
    cdef np.ndarray[np.float32_t, ndim=2] stdev = np.sqrt(np.maximum(variance, 0))

    return stdev
49
+
50
+
51
@cython.boundscheck(False)
@cython.wraparound(False)
def compute_neighborhood_score_by_z_score_cython(
    np.ndarray[np.float32_t, ndim=2] neighborhoods,
    np.ndarray[np.float32_t, ndim=2] annotation_matrix,
):
    """Compute mean / standard-deviation scores per neighborhood and annotation.

    Args:
        neighborhoods: 2-D float32 array; one row per neighborhood.
        annotation_matrix: 2-D float32 array; one column per annotation.

    Returns:
        np.ndarray: float32 array of ``mean / stdev`` values with every NaN
            replaced by 0. Infinite values (non-zero mean, zero stdev) are
            returned unchanged.
    """
    # NOTE(review): the pure-Python z-score function in this package marks
    # std == 0 and N < 3 entries as NaN instead of mapping NaN to 0 - confirm
    # which convention is intended before relying on both interchangeably.
    cdef np.ndarray[np.float32_t, ndim=2] weighted_sum
    cdef np.ndarray[np.float32_t, ndim=2] mean_of_squares
    cdef np.ndarray[np.float32_t, ndim=1] row_totals
    cdef np.ndarray[np.float32_t, ndim=2] row_totals_col
    cdef np.ndarray[np.float32_t, ndim=2] mean
    cdef np.ndarray[np.float32_t, ndim=2] variance
    cdef np.ndarray[np.float32_t, ndim=2] stdev
    cdef np.ndarray[np.float32_t, ndim=2] z_scores

    # Neighborhood-weighted sum of the annotation values (single BLAS thread)
    with threadpool_limits(limits=1, user_api='blas'):
        weighted_sum = np.dot(neighborhoods, annotation_matrix)

    # Row sums of the neighborhood matrix, shaped as a column for broadcasting
    row_totals = np.sum(neighborhoods, axis=1)
    row_totals_col = row_totals[:, None]
    mean = weighted_sum / row_totals_col

    # Neighborhood-weighted mean of the squared annotation values (single BLAS thread)
    with threadpool_limits(limits=1, user_api='blas'):
        mean_of_squares = np.dot(neighborhoods, np.power(annotation_matrix, 2)) / row_totals_col

    # Var[X] = E[X^2] - E[X]^2, then its square root
    variance = mean_of_squares - mean**2
    stdev = np.sqrt(variance)

    # Divide quietly; 0/0 and NaN operands produce NaN, which is reset to 0 below
    with np.errstate(divide='ignore', invalid='ignore'):
        z_scores = np.divide(mean, stdev)
    z_scores[np.isnan(z_scores)] = 0

    return z_scores
@@ -0,0 +1,11 @@
1
+ """
2
+ risk/stats/permutation/_cython/setup
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ # setup.py
7
+ from setuptools import setup
8
+ from Cython.Build import cythonize
9
+ import numpy as np
10
+
11
# Cythonize permutation.pyx (path is resolved relative to the working directory) and
# expose the NumPy C headers to the extension build.
extensions = cythonize("permutation.pyx")
setup(ext_modules=extensions, include_dirs=[np.get_include()])
@@ -0,0 +1,83 @@
1
+ """
2
+ risk/stats/permutation/_python/permutation
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import numpy as np
7
+
8
+
9
def compute_neighborhood_score_by_sum_python(
    neighborhoods_matrix: np.ndarray, annotation_matrix: np.ndarray
) -> np.ndarray:
    """Sum the annotation values that fall inside each neighborhood.

    Args:
        neighborhoods_matrix (np.ndarray): Matrix with one row per neighborhood.
        annotation_matrix (np.ndarray): Matrix with one column per annotation.

    Returns:
        np.ndarray: The product ``np.dot(neighborhoods_matrix, annotation_matrix)``.
    """
    # A single matrix product yields every neighborhood/annotation sum at once
    return np.dot(neighborhoods_matrix, annotation_matrix)
24
+
25
+
26
def compute_neighborhood_score_by_stdev_python(
    neighborhoods_matrix: np.ndarray, annotation_matrix: np.ndarray
) -> np.ndarray:
    """Compute the standard deviation of annotation values within each neighborhood.

    The standard deviation is derived as ``sqrt(E[X^2] - E[X]^2)``, where both
    expectations are neighborhood-weighted sums divided by the row sum of
    ``neighborhoods_matrix``.

    Args:
        neighborhoods_matrix (np.ndarray): 2-D matrix with one row per neighborhood.
        annotation_matrix (np.ndarray): 2-D matrix with one column per annotation.

    Returns:
        np.ndarray: Standard deviation for every neighborhood/annotation pair. Rows of
            ``neighborhoods_matrix`` that sum to zero yield NaN (0/0).
    """
    # Neighborhood-weighted sum of the annotation values
    neighborhood_score = np.dot(neighborhoods_matrix, annotation_matrix)
    # Row sums of the neighborhood matrix, shaped as a column for broadcasting
    N_reshaped = np.sum(neighborhoods_matrix, axis=1)[:, None]
    # Mean of the annotation values in each neighborhood
    M = neighborhood_score / N_reshaped
    # Mean of the squared annotation values in each neighborhood
    EXX = np.dot(neighborhoods_matrix, np.power(annotation_matrix, 2)) / N_reshaped
    # Variance as E[X^2] - E[X]^2
    variance = EXX - np.power(M, 2)
    # Floating-point cancellation can push a (near-)zero variance slightly below zero,
    # which would make sqrt return NaN; clip at 0 first. np.maximum propagates existing
    # NaNs, so empty neighborhoods still yield NaN.
    return np.sqrt(np.maximum(variance, 0))
52
+
53
+
54
def compute_neighborhood_score_by_z_score_python(
    neighborhoods_matrix: np.ndarray, annotation_matrix: np.ndarray
) -> np.ndarray:
    """Compute mean / standard-deviation scores for each neighborhood.

    Args:
        neighborhoods_matrix (np.ndarray): 2-D matrix with one row per neighborhood.
        annotation_matrix (np.ndarray): Matrix of annotation values (one column per
            annotation); a 1-D vector is also accepted.

    Returns:
        np.ndarray: ``mean / std`` for every neighborhood/annotation pair, computed in
            float64. Entries are NaN where the standard deviation is 0 or where the
            neighborhood's row sum is below 3.
    """
    # Neighborhood-weighted sum of the annotation values
    neighborhood_score = np.dot(neighborhoods_matrix, annotation_matrix)
    # Row sum of the neighborhood matrix. A plain sum replaces the former matrix product
    # with a ones matrix, which produced the same value repeated in every column.
    # float64 keeps the dtype the ones-matrix product used to produce.
    N = np.sum(neighborhoods_matrix, axis=-1, dtype=np.float64)
    if annotation_matrix.ndim > 1:
        N = N[..., None]  # column shape so it broadcasts across annotations

    # Everything below can hit 0/0 for empty neighborhoods; those entries are masked
    # to NaN at the end, so the floating-point warnings carry no information.
    with np.errstate(divide="ignore", invalid="ignore"):
        # Mean of the annotation values in each neighborhood
        M = neighborhood_score / N
        # Mean of squares (EXX) and squared mean (EEX)
        EXX = np.dot(neighborhoods_matrix, np.power(annotation_matrix, 2)) / N
        EEX = np.power(M, 2)
        std = np.sqrt(EXX - EEX)
        z_scores = np.divide(M, std)

    # Mask zero-spread entries and neighborhoods below the minimum size of 3
    invalid = np.logical_or(std == 0, N < 3)
    z_scores[invalid] = np.nan

    return z_scores
risk/stats/stats.py ADDED
@@ -0,0 +1,443 @@
1
+ """
2
+ risk/stats/stats
3
+ ~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from multiprocessing import Pool, Lock
7
+ from typing import Any, Callable, Union
8
+
9
+ import numpy as np
10
+ from statsmodels.stats.multitest import fdrcorrection
11
+
12
+
13
+ def _is_notebook() -> bool:
14
+ """Determine the type of interactive environment and return it as a dictionary.
15
+
16
+ Returns:
17
+ bool: True if the environment is a Jupyter notebook, False otherwise.
18
+ """
19
+ try:
20
+ shell = get_ipython().__class__.__name__
21
+ if shell == "ZMQInteractiveShell":
22
+ return True # Jupyter notebook or qtconsole
23
+ elif shell == "TerminalInteractiveShell":
24
+ return False # Terminal running IPython
25
+ else:
26
+ return False # Other types of shell
27
+ except NameError:
28
+ return False # Standard Python interpreter
29
+
30
+
31
+ if _is_notebook():
32
+ from tqdm.notebook import tqdm
33
+ else:
34
+ from tqdm import tqdm
35
+
36
+
37
+ from risk.stats.permutation import (
38
+ compute_neighborhood_score_by_sum_cython,
39
+ compute_neighborhood_score_by_stdev_cython,
40
+ compute_neighborhood_score_by_z_score_cython,
41
+ compute_neighborhood_score_by_sum_python,
42
+ compute_neighborhood_score_by_stdev_python,
43
+ compute_neighborhood_score_by_z_score_python,
44
+ )
45
+
46
# Scoring functions keyed by the `score_metric` name accepted by `compute_permutation`
CYTHON_DISPATCH_PERMUTATION_TABLE = {
    "sum": compute_neighborhood_score_by_sum_cython,
    "stdev": compute_neighborhood_score_by_stdev_cython,
    "z_score": compute_neighborhood_score_by_z_score_cython,
}
PYTHON_DISPATCH_PERMUTATION_TABLE = {
    "sum": compute_neighborhood_score_by_sum_python,
    "stdev": compute_neighborhood_score_by_stdev_python,
    "z_score": compute_neighborhood_score_by_z_score_python,
}


def compute_permutation(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    score_metric: str = "sum",
    null_distribution: str = "network",
    num_permutations: int = 1000,
    use_cython: bool = True,
    random_seed: int = 888,
    max_workers: int = 1,
) -> dict:
    """Compute permutation test for enrichment and depletion in neighborhoods.

    Args:
        neighborhoods (np.ndarray): Matrix representing neighborhoods.
        annotations (np.ndarray): Matrix representing annotations.
        score_metric (str, optional): Metric to use for scoring; one of 'sum', 'stdev'
            or 'z_score'. Defaults to "sum".
        null_distribution (str, optional): Type of null distribution ('network' or other).
            Defaults to "network".
        num_permutations (int, optional): Number of permutations to run. Defaults to 1000.
        use_cython (bool, optional): Whether to use the Cython scoring functions instead
            of the pure-Python ones. Defaults to True.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 1.

    Returns:
        dict: Dictionary with the keys "depletion_pvals" and "enrichment_pvals".

    Raises:
        KeyError: If `score_metric` is not a known metric name.
        ValueError: If `num_permutations` is less than 1.
    """
    # Retrieve the appropriate scoring function based on the metric and Cython usage
    dispatch_table = (
        CYTHON_DISPATCH_PERMUTATION_TABLE if use_cython else PYTHON_DISPATCH_PERMUTATION_TABLE
    )
    try:
        neighborhood_score_func = dispatch_table[score_metric]
    except KeyError:
        raise KeyError(
            f"Unknown score_metric {score_metric!r}; expected one of {sorted(dispatch_table)}."
        ) from None
    if num_permutations < 1:
        # A non-positive count would turn the p-value division below into inf/NaN
        raise ValueError("num_permutations must be a positive integer.")

    # Work on float32 copies: the Cython scoring functions require float32 input, and
    # the copies keep the caller's arrays untouched.
    neighborhoods = neighborhoods.astype(np.float32)
    annotations = annotations.astype(np.float32)
    # Run the permutation test to calculate depletion and enrichment counts
    counts_depletion, counts_enrichment = _run_permutation_test(
        neighborhoods=neighborhoods,
        annotations=annotations,
        neighborhood_score_func=neighborhood_score_func,
        null_distribution=null_distribution,
        num_permutations=num_permutations,
        random_seed=random_seed,
        max_workers=max_workers,
    )

    # Compute p-values for depletion and enrichment
    # If counts are 0, set p-value to 1/num_permutations to avoid zero p-values
    depletion_pvals = np.maximum(counts_depletion, 1) / num_permutations
    enrichment_pvals = np.maximum(counts_enrichment, 1) / num_permutations

    return {
        "depletion_pvals": depletion_pvals,
        "enrichment_pvals": enrichment_pvals,
    }
111
+
112
+
113
def _run_permutation_test(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    neighborhood_score_func: Callable,
    null_distribution: str = "network",
    num_permutations: int = 1000,
    random_seed: int = 888,
    max_workers: int = 4,
) -> tuple:
    """Run a permutation test to calculate enrichment and depletion counts.

    The observed score matrix is computed once. The annotation rows are then shuffled
    `num_permutations` times and, for every entry, the number of permuted scores that
    are <= (depletion) or >= (enrichment) the observed score is counted.

    Args:
        neighborhoods (np.ndarray): The neighborhood matrix.
        annotations (np.ndarray): The annotation matrix. It is not modified; NaNs are
            replaced by zeros on an internal copy.
        neighborhood_score_func (Callable): Function to calculate neighborhood scores.
        null_distribution (str, optional): Type of null distribution. With "network" all
            annotation rows are permuted; otherwise only rows holding at least one
            non-NaN value. Defaults to "network".
        num_permutations (int, optional): Number of permutations. Defaults to 1000.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 4.

    Returns:
        tuple: Depletion and enrichment counts, each shaped like the observed scores.

    Raises:
        ValueError: If `max_workers` is less than 1.
    """
    # NOTE(review): entries whose observed score is NaN never satisfy either comparison,
    # so their counts stay 0 - confirm callers handle that as intended.
    if max_workers < 1:
        raise ValueError("max_workers must be a positive integer.")

    # Set the random seed for reproducibility (drives the single-process path)
    np.random.seed(random_seed)
    # Determine indices based on null distribution type
    if null_distribution == "network":
        idxs = np.arange(annotations.shape[0])
    else:
        idxs = np.nonzero(np.sum(~np.isnan(annotations), axis=1))[0]

    # Replace NaNs with zeros on a copy so the caller's array is left untouched
    annotations = annotations.copy()
    annotations[np.isnan(annotations)] = 0
    annotation_matrix_obsv = annotations[idxs]
    neighborhoods_matrix_obsv = neighborhoods.T[idxs].T
    # Calculate observed neighborhood scores
    with np.errstate(invalid="ignore", divide="ignore"):
        observed_neighborhood_scores = neighborhood_score_func(
            neighborhoods_matrix_obsv, annotation_matrix_obsv
        )

    # Initialize count matrices for depletion and enrichment
    counts_depletion = np.zeros(observed_neighborhood_scores.shape)
    counts_enrichment = np.zeros(observed_neighborhood_scores.shape)

    if max_workers == 1:
        # Single process: draw permutations from the global RNG seeded above
        local_counts_depletion, local_counts_enrichment = _permutation_process_subset(
            annotations,
            idxs,
            neighborhoods_matrix_obsv,
            observed_neighborhood_scores,
            neighborhood_score_func,
            num_permutations,
            0,
            False,
        )
        counts_depletion += local_counts_depletion
        counts_enrichment += local_counts_enrichment
    else:
        # Split the permutations as evenly as possible across the workers
        subset_size, remainder = divmod(num_permutations, max_workers)
        # Give every worker its own seed derived from `random_seed`. Without this,
        # forked workers inherit identical RNG state and repeat the same permutations,
        # while spawned workers start unseeded and the run is not reproducible.
        seed_sequences = np.random.SeedSequence(random_seed).spawn(max_workers)
        params_list = [
            (
                annotations,
                idxs,
                neighborhoods_matrix_obsv,
                observed_neighborhood_scores,
                neighborhood_score_func,
                subset_size + (1 if i < remainder else 0),
                i,
                True,
                int(seed_sequences[i].generate_state(1)[0]),
            )
            for i in range(max_workers)
        ]

        # Initialize a multiprocessing pool whose workers share one lock
        lock = Lock()
        with Pool(max_workers, initializer=_init, initargs=(lock,)) as pool:
            results = pool.starmap(_permutation_process_subset, params_list)
            # Accumulate results from each worker
            for local_counts_depletion, local_counts_enrichment in results:
                counts_depletion += local_counts_depletion
                counts_enrichment += local_counts_enrichment

    return counts_depletion, counts_enrichment


def _permutation_process_subset(
    annotation_matrix: np.ndarray,
    idxs: np.ndarray,
    neighborhoods_matrix_obsv: np.ndarray,
    observed_neighborhood_scores: np.ndarray,
    neighborhood_score_func: Callable,
    subset_size: int,
    worker_id: int,
    use_lock: bool,
    random_seed: Union[int, None] = None,
) -> tuple:
    """Process a subset of permutations for the permutation test.

    Args:
        annotation_matrix (np.ndarray): The annotation matrix.
        idxs (np.ndarray): Indices of the annotation rows to shuffle.
        neighborhoods_matrix_obsv (np.ndarray): Observed neighborhoods matrix.
        observed_neighborhood_scores (np.ndarray): Observed neighborhood scores.
        neighborhood_score_func (Callable): Function to calculate neighborhood scores.
        subset_size (int): Number of permutations to run in this subset.
        worker_id (int): ID of the worker process.
        use_lock (bool): Whether to guard progress-bar calls with the global lock
            installed by `_init`.
        random_seed (int, optional): Seed for a private random stream used by this
            subset. When None, the global NumPy random state is used. Defaults to None.

    Returns:
        tuple: Local counts of depletion and enrichment.
    """
    # Pick the permutation source: a private, seeded stream or the global state
    if random_seed is None:
        permute = np.random.permutation
    else:
        permute = np.random.RandomState(random_seed).permutation

    # Initialize local count matrices for this worker
    local_counts_depletion = np.zeros(observed_neighborhood_scores.shape)
    local_counts_enrichment = np.zeros(observed_neighborhood_scores.shape)

    if _is_notebook():
        # Hack to ensure progress bar displays correctly in Jupyter notebooks
        print(" ", end="", flush=True)

    # Initialize progress bar for tracking permutation progress
    text = f"Worker {worker_id + 1} Progress"
    # Set mininterval to 0.1 to prevent rapid updates and improve performance
    progress_options = dict(
        total=subset_size, desc=text, position=worker_id, leave=False, mininterval=0.1
    )
    if use_lock:
        with lock:
            progress = tqdm(**progress_options)
    else:
        progress = tqdm(**progress_options)

    try:
        for _ in range(subset_size):
            # Permute the annotation matrix
            annotation_matrix_permut = annotation_matrix[permute(idxs)]
            # Calculate permuted neighborhood scores
            with np.errstate(invalid="ignore", divide="ignore"):
                permuted_neighborhood_scores = neighborhood_score_func(
                    neighborhoods_matrix_obsv, annotation_matrix_permut
                )
            # Update local depletion and enrichment counts in place
            local_counts_depletion += permuted_neighborhood_scores <= observed_neighborhood_scores
            local_counts_enrichment += permuted_neighborhood_scores >= observed_neighborhood_scores
            # Update progress bar
            if use_lock:
                with lock:
                    progress.update(1)
            else:
                progress.update(1)
    finally:
        # Close the progress bar even if the scoring function raised
        if use_lock:
            with lock:
                progress.close()
        else:
            progress.close()

    return local_counts_depletion, local_counts_enrichment


def _init(lock_: Any) -> None:
    """Initialize a global lock for multiprocessing.

    Used as the `Pool` initializer so that every worker process stores the shared lock
    in the module-level name `lock`, which `_permutation_process_subset` reads.

    Args:
        lock_ (Any): A lock object to be used in multiprocessing.
    """
    global lock
    lock = lock_  # Assign the provided lock to a global variable
289
+
290
+
291
def calculate_significance_matrices(
    depletion_pvals: np.ndarray,
    enrichment_pvals: np.ndarray,
    tail: str = "right",
    pval_cutoff: float = 0.05,
    apply_fdr: bool = False,
    fdr_cutoff: float = 0.05,
) -> dict:
    """Build enrichment, binary-significance and significant-enrichment matrices.

    Args:
        depletion_pvals (np.ndarray): Matrix of depletion p-values.
        enrichment_pvals (np.ndarray): Matrix of enrichment p-values.
        tail (str, optional): Which tail to report ('left', 'right', 'both'). Defaults to 'right'.
        pval_cutoff (float, optional): Cutoff for p-value significance. Defaults to 0.05.
        apply_fdr (bool, optional): Whether to apply FDR correction row by row. Defaults to False.
        fdr_cutoff (float, optional): Cutoff for FDR significance if applied. Defaults to 0.05.

    Returns:
        dict: The keys "enrichment_matrix", "binary_enrichment_matrix" and
            "significant_enrichment_matrix".
    """
    if not apply_fdr:
        # Without FDR the raw p-values are both the scores and the thresholding input
        depletion_scores = depletion_pvals
        enrichment_scores = enrichment_pvals
        depletion_alpha = _compute_threshold_matrix(depletion_pvals, pval_cutoff=pval_cutoff)
        enrichment_alpha = _compute_threshold_matrix(enrichment_pvals, pval_cutoff=pval_cutoff)
    else:
        # Row-wise FDR correction; index 1 of the stacked result holds the corrected values
        depletion_qvals = np.apply_along_axis(fdrcorrection, 1, depletion_pvals)[:, 1, :]
        enrichment_qvals = np.apply_along_axis(fdrcorrection, 1, enrichment_pvals)[:, 1, :]
        # Significance requires passing both the p-value and the q-value cutoff
        depletion_alpha = _compute_threshold_matrix(
            depletion_pvals, depletion_qvals, pval_cutoff=pval_cutoff, fdr_cutoff=fdr_cutoff
        )
        enrichment_alpha = _compute_threshold_matrix(
            enrichment_pvals, enrichment_qvals, pval_cutoff=pval_cutoff, fdr_cutoff=fdr_cutoff
        )
        # Scores combine q-values and p-values
        depletion_scores = (depletion_qvals**2) * (depletion_pvals**0.5)
        enrichment_scores = (enrichment_qvals**2) * (enrichment_pvals**0.5)

    # Negative log10 transform, then pick the matrices for the requested tail
    selected_matrix, binary_matrix = _select_significance_matrices(
        tail,
        -np.log10(depletion_scores),
        depletion_alpha,
        -np.log10(enrichment_scores),
        enrichment_alpha,
    )

    # Keep the selected values only where the binary matrix marks significance
    significant_matrix = np.where(binary_matrix == 1, selected_matrix, 0)

    return {
        "enrichment_matrix": selected_matrix,
        "binary_enrichment_matrix": binary_matrix,
        "significant_enrichment_matrix": significant_matrix,
    }
362
+
363
+
364
+ def _select_significance_matrices(
365
+ tail: str,
366
+ log_depletion_matrix: np.ndarray,
367
+ depletion_alpha_threshold_matrix: np.ndarray,
368
+ log_enrichment_matrix: np.ndarray,
369
+ enrichment_alpha_threshold_matrix: np.ndarray,
370
+ ) -> tuple:
371
+ """Select significance matrices based on the specified tail type.
372
+
373
+ Args:
374
+ tail (str): The tail type for significance selection. Options are 'left', 'right', or 'both'.
375
+ log_depletion_matrix (np.ndarray): Matrix of log-transformed depletion values.
376
+ depletion_alpha_threshold_matrix (np.ndarray): Alpha threshold matrix for depletion significance.
377
+ log_enrichment_matrix (np.ndarray): Matrix of log-transformed enrichment values.
378
+ enrichment_alpha_threshold_matrix (np.ndarray): Alpha threshold matrix for enrichment significance.
379
+
380
+ Returns:
381
+ tuple: A tuple containing the selected enrichment matrix and binary significance matrix.
382
+
383
+ Raises:
384
+ ValueError: If the provided tail type is not 'left', 'right', or 'both'.
385
+ """
386
+ if tail not in {"left", "right", "both"}:
387
+ raise ValueError("Invalid value for 'tail'. Must be 'left', 'right', or 'both'.")
388
+
389
+ if tail == "left":
390
+ # Select depletion matrix and corresponding alpha threshold for left-tail analysis
391
+ enrichment_matrix = -log_depletion_matrix
392
+ alpha_threshold_matrix = depletion_alpha_threshold_matrix
393
+ elif tail == "right":
394
+ # Select enrichment matrix and corresponding alpha threshold for right-tail analysis
395
+ enrichment_matrix = log_enrichment_matrix
396
+ alpha_threshold_matrix = enrichment_alpha_threshold_matrix
397
+ elif tail == "both":
398
+ # Select the matrix with the highest absolute values while preserving the sign
399
+ enrichment_matrix = np.where(
400
+ np.abs(log_depletion_matrix) >= np.abs(log_enrichment_matrix),
401
+ -log_depletion_matrix,
402
+ log_enrichment_matrix,
403
+ )
404
+ # Combine alpha thresholds using a logical OR operation
405
+ alpha_threshold_matrix = np.logical_or(
406
+ depletion_alpha_threshold_matrix, enrichment_alpha_threshold_matrix
407
+ )
408
+
409
+ # Create a binary significance matrix where valid indices meet the alpha threshold
410
+ valid_idxs = ~np.isnan(alpha_threshold_matrix)
411
+ binary_enrichment_matrix = np.zeros(alpha_threshold_matrix.shape)
412
+ binary_enrichment_matrix[valid_idxs] = alpha_threshold_matrix[valid_idxs]
413
+
414
+ return enrichment_matrix, binary_enrichment_matrix
415
+
416
+
417
+ def _compute_threshold_matrix(
418
+ pvals: np.ndarray,
419
+ fdr_pvals: Union[np.ndarray, None] = None,
420
+ pval_cutoff: float = 0.05,
421
+ fdr_cutoff: float = 0.05,
422
+ ) -> np.ndarray:
423
+ """Compute a threshold matrix indicating significance based on p-value and FDR cutoffs.
424
+
425
+ Args:
426
+ pvals (np.ndarray): Array of p-values for statistical tests.
427
+ fdr_pvals (np.ndarray, optional): Array of FDR-corrected p-values corresponding to the p-values. Defaults to None.
428
+ pval_cutoff (float, optional): Cutoff for p-value significance. Defaults to 0.05.
429
+ fdr_cutoff (float, optional): Cutoff for FDR significance. Defaults to 0.05.
430
+
431
+ Returns:
432
+ np.ndarray: A threshold matrix where 1 indicates significance based on the provided cutoffs, 0 otherwise.
433
+ """
434
+ if fdr_pvals is not None:
435
+ # Compute the threshold matrix based on both p-value and FDR cutoffs
436
+ pval_below_cutoff = pvals <= pval_cutoff
437
+ fdr_below_cutoff = fdr_pvals <= fdr_cutoff
438
+ threshold_matrix = np.logical_and(pval_below_cutoff, fdr_below_cutoff).astype(int)
439
+ else:
440
+ # Compute the threshold matrix based only on p-value cutoff
441
+ threshold_matrix = (pvals <= pval_cutoff).astype(int)
442
+
443
+ return threshold_matrix