risk-network 0.0.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/stats/stats.py ADDED
@@ -0,0 +1,447 @@
1
+ """
2
+ risk/stats/stats
3
+ ~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import sys
7
+ from contextlib import contextmanager
8
+ from multiprocessing import get_context, Lock
9
+ from typing import Any, Callable, Generator, Union
10
+
11
+ import numpy as np
12
+ from statsmodels.stats.multitest import fdrcorrection
13
+ from threadpoolctl import threadpool_limits
14
+
15
+
16
def _is_notebook() -> bool:
    """Report whether the code is running inside a Jupyter kernel.

    ``get_ipython`` is only defined when running under IPython. In a plain
    Python interpreter the lookup raises ``NameError`` and this returns False.
    A terminal IPython session (``TerminalInteractiveShell``) and any other
    shell type also return False.

    Returns:
        bool: True if the active IPython shell is ``ZMQInteractiveShell``
        (Jupyter notebook or qtconsole), False otherwise.
    """
    try:
        shell_name = get_ipython().__class__.__name__  # noqa: F821 - only defined under IPython
    except NameError:
        return False  # Standard Python interpreter
    # Only the ZMQ-based kernel shell counts as a notebook environment
    return shell_name == "ZMQInteractiveShell"
32
+
33
+
34
+ if _is_notebook():
35
+ from tqdm.notebook import tqdm
36
+ else:
37
+ from tqdm import tqdm
38
+
39
+
40
+ from risk.stats.permutation import (
41
+ compute_neighborhood_score_by_sum,
42
+ compute_neighborhood_score_by_stdev,
43
+ compute_neighborhood_score_by_z_score,
44
+ )
45
+
46
# Lookup from the public ``score_metric`` name to the neighborhood scoring function.
DISPATCH_PERMUTATION_TABLE = dict(
    sum=compute_neighborhood_score_by_sum,
    stdev=compute_neighborhood_score_by_stdev,
    z_score=compute_neighborhood_score_by_z_score,
)
51
+
52
+
53
def compute_permutation(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    score_metric: str = "sum",
    null_distribution: str = "network",
    num_permutations: int = 1000,
    random_seed: int = 888,
    max_workers: int = 1,
) -> dict:
    """Compute permutation test for enrichment and depletion in neighborhoods.

    Args:
        neighborhoods (np.ndarray): Binary matrix representing neighborhoods.
        annotations (np.ndarray): Binary matrix representing annotations. NaN
            entries are treated as 0 when scoring; the caller's array is not modified.
        score_metric (str, optional): Neighborhood scoring metric; one of the keys
            of ``DISPATCH_PERMUTATION_TABLE`` ('sum', 'stdev', 'z_score').
            Defaults to "sum".
        null_distribution (str, optional): "network" permutes over every annotation
            row; any other value permutes only rows with at least one non-NaN
            annotation. Defaults to "network".
        num_permutations (int, optional): Number of permutations to run; must be at
            least 1. Defaults to 1000.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.
        max_workers (int, optional): Number of workers for multiprocessing. Defaults to 1.

    Returns:
        dict: ``{"depletion_pvals": ..., "enrichment_pvals": ...}``. Each p-value is
        the fraction of permutations at least as extreme as the observed score
        (ties included), floored at ``1 / num_permutations`` so it is never zero.

    Raises:
        KeyError: If ``score_metric`` is not a supported metric.
        ValueError: If ``num_permutations`` is less than 1.
    """
    # Validate cheap arguments before copying potentially large matrices
    if num_permutations < 1:
        raise ValueError(f"num_permutations must be at least 1, got {num_permutations}.")
    try:
        neighborhood_score_func = DISPATCH_PERMUTATION_TABLE[score_metric]
    except KeyError:
        # Keep KeyError for backward compatibility, but say what is accepted
        raise KeyError(
            f"Unsupported score_metric {score_metric!r}; "
            f"expected one of {sorted(DISPATCH_PERMUTATION_TABLE)}."
        ) from None

    # astype() returns float32 copies, so the in-place NaN replacement performed
    # by the permutation test never touches the caller's arrays
    neighborhoods = neighborhoods.astype(np.float32)
    annotations = annotations.astype(np.float32)

    # Run the permutation test to calculate depletion and enrichment counts
    counts_depletion, counts_enrichment = _run_permutation_test(
        neighborhoods=neighborhoods,
        annotations=annotations,
        neighborhood_score_func=neighborhood_score_func,
        null_distribution=null_distribution,
        num_permutations=num_permutations,
        random_seed=random_seed,
        max_workers=max_workers,
    )

    # A count of 0 is raised to 1 so the smallest reportable p-value is 1/num_permutations
    depletion_pvals = np.maximum(counts_depletion, 1) / num_permutations
    enrichment_pvals = np.maximum(counts_enrichment, 1) / num_permutations

    return {
        "depletion_pvals": depletion_pvals,
        "enrichment_pvals": enrichment_pvals,
    }
101
+
102
+
103
def _run_permutation_test(
    neighborhoods: np.ndarray,
    annotations: np.ndarray,
    neighborhood_score_func: Callable,
    null_distribution: str = "network",
    num_permutations: int = 1000,
    random_seed: int = 888,
    max_workers: int = 4,
) -> tuple:
    """Run a permutation test to calculate enrichment and depletion counts.

    Permutations are split across ``max_workers`` processes started with the
    "spawn" method. Spawned processes start with fresh NumPy RNG state, so
    seeding the parent would not reach them. Instead, each worker task gets its
    own seed derived from ``random_seed`` via ``np.random.SeedSequence``. Counts
    are therefore reproducible for a fixed ``(random_seed, max_workers)`` pair.

    Args:
        neighborhoods (np.ndarray): The neighborhood matrix. Its columns are
            selected with the same indices as the rows of ``annotations``.
        annotations (np.ndarray): The annotation matrix. NaN entries are
            replaced with 0 in place.
        neighborhood_score_func (Callable): Called as
            ``func(neighborhoods_subset, annotations_subset)``. It is sent to
            worker processes, so it must be picklable.
        null_distribution (str, optional): "network" permutes over every
            annotation row; any other value permutes only rows with at least
            one non-NaN entry. Defaults to "network".
        num_permutations (int, optional): Total permutations across all workers.
            Defaults to 1000.
        random_seed (int, optional): Root seed for the per-worker seeds.
            Defaults to 888.
        max_workers (int, optional): Number of worker processes; must be at
            least 1. Defaults to 4.

    Returns:
        tuple: ``(counts_depletion, counts_enrichment)``, float arrays shaped like
        the observed scores. They count the permutations whose score was
        ``<=`` and ``>=`` the observed score, respectively.

    Raises:
        ValueError: If ``max_workers`` is less than 1.
    """
    if max_workers < 1:
        raise ValueError(f"max_workers must be at least 1, got {max_workers}.")

    # Determine the row indices to permute based on the null distribution type
    if null_distribution == "network":
        idxs = np.arange(annotations.shape[0])
    else:
        # Must run before the NaN replacement below, which would hide missing rows
        idxs = np.nonzero(np.sum(~np.isnan(annotations), axis=1))[0]

    # Replace NaNs with zeros in the annotations matrix
    annotations[np.isnan(annotations)] = 0
    annotation_matrix_obsv = annotations[idxs]
    neighborhoods_matrix_obsv = neighborhoods.T[idxs].T
    # Calculate observed neighborhood scores
    with np.errstate(invalid="ignore", divide="ignore"):
        observed_neighborhood_scores = neighborhood_score_func(
            neighborhoods_matrix_obsv, annotation_matrix_obsv
        )

    counts_depletion = np.zeros(observed_neighborhood_scores.shape)
    counts_enrichment = np.zeros(observed_neighborhood_scores.shape)
    # Spread permutations as evenly as possible; the first `remainder` workers get one extra
    subset_size, remainder = divmod(num_permutations, max_workers)

    # Independent, reproducible seeds (one per worker task) derived from the root seed
    worker_seeds = [
        int(child.generate_state(1)[0])
        for child in np.random.SeedSequence(random_seed).spawn(max_workers)
    ]

    ctx = get_context("spawn")
    # The lock must come from the same context as the pool that receives it
    with ctx.Pool(max_workers, initializer=_init, initargs=(ctx.Lock(),)) as pool:
        params_list = [
            (
                worker_seeds[i],
                annotations,
                idxs,
                neighborhoods_matrix_obsv,
                observed_neighborhood_scores,
                neighborhood_score_func,
                subset_size + (1 if i < remainder else 0),
                i,
                max_workers,
                True,
            )
            for i in range(max_workers)
        ]
        results = pool.starmap(_seeded_permutation_process_subset, params_list)

    for local_counts_depletion, local_counts_enrichment in results:
        counts_depletion += local_counts_depletion
        counts_enrichment += local_counts_enrichment

    return counts_depletion, counts_enrichment


def _seeded_permutation_process_subset(seed: int, *subset_args: Any) -> tuple:
    """Seed the worker's global NumPy RNG, cap BLAS threads, then run one subset.

    This runs inside a pool worker process. Seeding here, rather than in the
    parent, is what makes the permutations reproducible under the "spawn"
    start method. ``threadpool_limits`` likewise only affects the process it is
    called in, so it is applied here to avoid oversubscribing cores.

    Args:
        seed (int): Seed for ``np.random.seed`` in this task.
        *subset_args (Any): Positional arguments forwarded unchanged to
            ``_permutation_process_subset``.

    Returns:
        tuple: The ``(local_counts_depletion, local_counts_enrichment)`` pair
        returned by ``_permutation_process_subset``.
    """
    np.random.seed(seed)
    with threadpool_limits(limits=1, user_api="blas"):
        return _permutation_process_subset(*subset_args)
175
+
176
+
177
def _permutation_process_subset(
    annotation_matrix: np.ndarray,
    idxs: np.ndarray,
    neighborhoods_matrix_obsv: np.ndarray,
    observed_neighborhood_scores: np.ndarray,
    neighborhood_score_func: Callable,
    subset_size: int,
    worker_id: int,
    max_workers: int,
    use_lock: bool,
) -> tuple:
    """Run one worker's share of permutations and tally extreme scores.

    Each iteration reorders the rows ``idxs`` of ``annotation_matrix`` using
    NumPy's global RNG (``np.random.permutation``), rescores the neighborhoods,
    and counts, element-wise, how often the permuted score is ``<=`` (depletion)
    or ``>=`` (enrichment) the observed score.

    Args:
        annotation_matrix (np.ndarray): Full annotation matrix (NaNs already zeroed).
        idxs (np.ndarray): Row indices that are shuffled each iteration.
        neighborhoods_matrix_obsv (np.ndarray): Neighborhood matrix restricted to ``idxs``.
        observed_neighborhood_scores (np.ndarray): Scores from the unpermuted data.
        neighborhood_score_func (Callable): Scoring function ``f(neighborhoods, annotations)``.
        subset_size (int): Number of permutations this worker performs.
        worker_id (int): Zero-based index of this worker.
        max_workers (int): Total number of workers (used for the progress label).
        use_lock (bool): Forwarded to ``_tqdm_context``.

    Returns:
        tuple: ``(local_counts_depletion, local_counts_enrichment)`` float arrays
        shaped like ``observed_neighborhood_scores``.
    """
    score_shape = observed_neighborhood_scores.shape
    depletion_hits = np.zeros(score_shape)
    enrichment_hits = np.zeros(score_shape)

    bar_label = f"Worker {worker_id + 1} of {max_workers} progress"
    keep_bar = worker_id == max_workers - 1  # only the last worker's bar stays on screen

    with _tqdm_context(
        total=subset_size,
        desc=bar_label,
        position=0,
        leave=keep_bar,
        use_lock=use_lock,
    ) as progress:
        for _ in range(subset_size):
            shuffled_annotations = annotation_matrix[np.random.permutation(idxs)]
            with np.errstate(invalid="ignore", divide="ignore"):
                permuted_scores = neighborhood_score_func(
                    neighborhoods_matrix_obsv, shuffled_annotations
                )
            depletion_hits = depletion_hits + (permuted_scores <= observed_neighborhood_scores)
            enrichment_hits = enrichment_hits + (permuted_scores >= observed_neighborhood_scores)
            progress.update(1)

    return depletion_hits, enrichment_hits
235
+
236
+
237
def _init(lock_: Any) -> None:
    """Pool initializer: publish ``lock_`` as the module-level name ``lock``.

    Args:
        lock_ (Any): Lock shared by the worker processes.
    """
    globals()["lock"] = lock_
245
+
246
+
247
@contextmanager
def _tqdm_context(
    total: int, desc: str, position: int, leave: bool = False, use_lock: bool = False
) -> Generator:
    """A context manager for a `tqdm` progress bar.

    When ``use_lock`` is True, the module-global ``lock`` (installed by
    ``_init``, which must have run in this process) guards only the creation and
    closing of the bar. The lock is deliberately NOT held across the ``yield``.
    Holding it for the body would let one worker keep it for its entire
    permutation loop and serialize the whole pool.

    Args:
        total (int): The total number of iterations for the progress bar.
        desc (str): Description for the progress bar.
        position (int): The position of the progress bar (useful for multiple bars).
        leave (bool): Whether to leave the progress bar after completion.
        use_lock (bool): Whether to guard bar creation/closing with the shared lock.

    Yields:
        tqdm: A `tqdm` progress bar object writing to stderr.
    """
    from contextlib import nullcontext  # local import keeps this block self-contained

    guard = lock if use_lock else nullcontext()
    with guard:
        progress = tqdm(
            total=total,
            desc=desc,
            position=position,
            leave=leave,
            mininterval=0.1,
            file=sys.stderr,
        )
    try:
        yield progress  # caller does its work without holding the lock
    finally:
        with guard:
            progress.close()  # Ensure the progress bar is closed properly
295
+
296
+
297
def calculate_significance_matrices(
    depletion_pvals: np.ndarray,
    enrichment_pvals: np.ndarray,
    tail: str = "right",
    pval_cutoff: float = 0.05,
    fdr_cutoff: float = 0.05,
) -> dict:
    """Turn depletion/enrichment p-values into score and significance matrices.

    If ``fdr_cutoff`` is below 1.0, p-values are corrected row by row with
    ``fdrcorrection``. Each direction's score matrix is then ``q**2 * sqrt(p)``,
    and an entry is significant only if both ``p <= pval_cutoff`` and
    ``q <= fdr_cutoff``. Otherwise the raw p-values are used and only the
    p-value cutoff applies. Scores are ``-log10`` transformed before the tail
    is selected.

    Args:
        depletion_pvals (np.ndarray): Matrix of depletion p-values.
        enrichment_pvals (np.ndarray): Matrix of enrichment p-values.
        tail (str, optional): 'left', 'right', or 'both'. Defaults to 'right'.
        pval_cutoff (float, optional): P-value significance cutoff. Defaults to 0.05.
        fdr_cutoff (float, optional): FDR significance cutoff; 1.0 or above
            disables correction. Defaults to 0.05.

    Returns:
        dict: ``enrichment_matrix`` (tail-selected scores),
        ``binary_enrichment_matrix`` (1 where significant, else 0), and
        ``significant_enrichment_matrix`` (scores zeroed where not significant).
    """
    apply_fdr = fdr_cutoff < 1.0

    def _score_one_direction(pvals: np.ndarray) -> tuple:
        # Returns (-log10 score matrix, 0/1 threshold matrix) for one direction
        if apply_fdr:
            qvals = np.apply_along_axis(fdrcorrection, 1, pvals)[:, 1, :]
            threshold = _compute_threshold_matrix(
                pvals, qvals, pval_cutoff=pval_cutoff, fdr_cutoff=fdr_cutoff
            )
            combined = (qvals**2) * (pvals**0.5)
        else:
            threshold = _compute_threshold_matrix(pvals, pval_cutoff=pval_cutoff)
            combined = pvals
        return -np.log10(combined), threshold

    log_depletion, depletion_threshold = _score_one_direction(depletion_pvals)
    log_enrichment, enrichment_threshold = _score_one_direction(enrichment_pvals)

    score_matrix, binary_matrix = _select_significance_matrices(
        tail,
        log_depletion,
        depletion_threshold,
        log_enrichment,
        enrichment_threshold,
    )

    return {
        "enrichment_matrix": score_matrix,
        "binary_enrichment_matrix": binary_matrix,
        "significant_enrichment_matrix": np.where(binary_matrix == 1, score_matrix, 0),
    }
366
+
367
+
368
def _select_significance_matrices(
    tail: str,
    log_depletion_matrix: np.ndarray,
    depletion_alpha_threshold_matrix: np.ndarray,
    log_enrichment_matrix: np.ndarray,
    enrichment_alpha_threshold_matrix: np.ndarray,
) -> tuple:
    """Pick the score and significance matrices for the requested tail.

    - 'left': negated depletion scores with the depletion thresholds.
    - 'right': enrichment scores with the enrichment thresholds.
    - 'both': per element, whichever score has the larger magnitude (depletion
      wins ties, and depletion scores are negated); significant if either
      threshold is set.

    Args:
        tail (str): 'left', 'right', or 'both'.
        log_depletion_matrix (np.ndarray): -log10-transformed depletion scores.
        depletion_alpha_threshold_matrix (np.ndarray): Depletion significance flags.
        log_enrichment_matrix (np.ndarray): -log10-transformed enrichment scores.
        enrichment_alpha_threshold_matrix (np.ndarray): Enrichment significance flags.

    Returns:
        tuple: ``(enrichment_matrix, binary_enrichment_matrix)``. The binary matrix
        is float64, with NaN thresholds mapped to 0.

    Raises:
        ValueError: If ``tail`` is not 'left', 'right', or 'both'.
    """
    # Each entry lazily builds the (score matrix, threshold matrix) pair for its tail
    selectors = {
        "left": lambda: (-log_depletion_matrix, depletion_alpha_threshold_matrix),
        "right": lambda: (log_enrichment_matrix, enrichment_alpha_threshold_matrix),
        "both": lambda: (
            np.where(
                np.abs(log_depletion_matrix) >= np.abs(log_enrichment_matrix),
                -log_depletion_matrix,
                log_enrichment_matrix,
            ),
            np.logical_or(depletion_alpha_threshold_matrix, enrichment_alpha_threshold_matrix),
        ),
    }
    if tail not in selectors:
        raise ValueError("Invalid value for 'tail'. Must be 'left', 'right', or 'both'.")

    enrichment_matrix, alpha_threshold_matrix = selectors[tail]()

    # NaN thresholds count as "not significant"; all other values are copied as floats
    binary_enrichment_matrix = np.where(
        np.isnan(alpha_threshold_matrix), 0.0, alpha_threshold_matrix
    ).astype(np.float64)

    return enrichment_matrix, binary_enrichment_matrix
419
+
420
+
421
def _compute_threshold_matrix(
    pvals: np.ndarray,
    fdr_pvals: Union[np.ndarray, None] = None,
    pval_cutoff: float = 0.05,
    fdr_cutoff: float = 0.05,
) -> np.ndarray:
    """Flag entries that pass the p-value cutoff and, if given, the FDR cutoff.

    Args:
        pvals (np.ndarray): Raw p-values.
        fdr_pvals (np.ndarray, optional): FDR-corrected p-values aligned with
            ``pvals``. When None, only the p-value cutoff is applied. Defaults to None.
        pval_cutoff (float, optional): Inclusive p-value cutoff. Defaults to 0.05.
        fdr_cutoff (float, optional): Inclusive FDR cutoff. Defaults to 0.05.

    Returns:
        np.ndarray: Integer matrix with 1 where every applicable cutoff is met, else 0.
    """
    significant = pvals <= pval_cutoff
    if fdr_pvals is not None:
        significant = np.logical_and(significant, fdr_pvals <= fdr_cutoff)
    return significant.astype(int)