risk-network 0.0.9b24__tar.gz → 0.0.9b26__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/PKG-INFO +1 -1
  2. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/__init__.py +1 -1
  3. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/annotations/annotations.py +9 -9
  4. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/neighborhoods/__init__.py +1 -1
  5. risk_network-0.0.9b24/risk/neighborhoods/io.py → risk_network-0.0.9b26/risk/neighborhoods/api.py +15 -10
  6. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/neighborhoods/domains.py +11 -3
  7. risk_network-0.0.9b26/risk/network/__init__.py +6 -0
  8. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/network/graph/__init__.py +1 -1
  9. risk_network-0.0.9b24/risk/network/graph/io.py → risk_network-0.0.9b26/risk/network/graph/api.py +4 -4
  10. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/network/graph/summary.py +6 -2
  11. risk_network-0.0.9b26/risk/network/plotter/__init__.py +6 -0
  12. risk_network-0.0.9b24/risk/network/plot/io.py → risk_network-0.0.9b26/risk/network/plotter/api.py +5 -5
  13. {risk_network-0.0.9b24/risk/network/plot → risk_network-0.0.9b26/risk/network/plotter}/canvas.py +2 -2
  14. {risk_network-0.0.9b24/risk/network/plot → risk_network-0.0.9b26/risk/network/plotter}/contour.py +1 -1
  15. {risk_network-0.0.9b24/risk/network/plot → risk_network-0.0.9b26/risk/network/plotter}/labels.py +2 -2
  16. {risk_network-0.0.9b24/risk/network/plot → risk_network-0.0.9b26/risk/network/plotter}/network.py +5 -5
  17. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/risk.py +6 -3
  18. risk_network-0.0.9b26/risk/stats/binom.py +51 -0
  19. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/stats/chi2.py +25 -13
  20. risk_network-0.0.9b26/risk/stats/hypergeom.py +64 -0
  21. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/stats/permutation/permutation.py +45 -56
  22. risk_network-0.0.9b26/risk/stats/permutation/test_functions.py +69 -0
  23. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/stats/poisson.py +15 -9
  24. risk_network-0.0.9b26/risk/stats/zscore.py +68 -0
  25. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk_network.egg-info/PKG-INFO +1 -1
  26. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk_network.egg-info/SOURCES.txt +10 -10
  27. risk_network-0.0.9b24/risk/network/__init__.py +0 -8
  28. risk_network-0.0.9b24/risk/network/plot/__init__.py +0 -6
  29. risk_network-0.0.9b24/risk/stats/binom.py +0 -47
  30. risk_network-0.0.9b24/risk/stats/hypergeom.py +0 -54
  31. risk_network-0.0.9b24/risk/stats/permutation/test_functions.py +0 -61
  32. risk_network-0.0.9b24/risk/stats/zscore.py +0 -62
  33. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/LICENSE +0 -0
  34. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/MANIFEST.in +0 -0
  35. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/README.md +0 -0
  36. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/pyproject.toml +0 -0
  37. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/annotations/__init__.py +0 -0
  38. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/annotations/io.py +0 -0
  39. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/constants.py +0 -0
  40. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/log/__init__.py +0 -0
  41. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/log/console.py +0 -0
  42. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/log/parameters.py +0 -0
  43. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/neighborhoods/community.py +0 -0
  44. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/neighborhoods/neighborhoods.py +0 -0
  45. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/network/geometry.py +0 -0
  46. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/network/graph/network.py +0 -0
  47. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/network/io.py +0 -0
  48. {risk_network-0.0.9b24/risk/network/plot → risk_network-0.0.9b26/risk/network/plotter}/utils/colors.py +0 -0
  49. {risk_network-0.0.9b24/risk/network/plot → risk_network-0.0.9b26/risk/network/plotter}/utils/layout.py +0 -0
  50. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/stats/__init__.py +0 -0
  51. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/stats/permutation/__init__.py +0 -0
  52. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk/stats/stats.py +0 -0
  53. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk_network.egg-info/dependency_links.txt +0 -0
  54. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk_network.egg-info/requires.txt +0 -0
  55. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/risk_network.egg-info/top_level.txt +0 -0
  56. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/setup.cfg +0 -0
  57. {risk_network-0.0.9b24 → risk_network-0.0.9b26}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: risk-network
3
- Version: 0.0.9b24
3
+ Version: 0.0.9b26
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -7,4 +7,4 @@ RISK: Regional Inference of Significant Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.9-beta.24"
10
+ __version__ = "0.0.9-beta.26"
@@ -16,6 +16,7 @@ from nltk.tokenize import word_tokenize
16
16
  from nltk.corpus import stopwords
17
17
 
18
18
  from risk.log import logger
19
+ from scipy.sparse import csr_matrix
19
20
 
20
21
 
21
22
  def _setup_nltk():
@@ -47,17 +48,15 @@ def load_annotations(
47
48
  annotations_input (Dict[str, Any]): A dictionary with annotations.
48
49
  min_nodes_per_term (int, optional): The minimum number of network nodes required for each annotation
49
50
  term to be included. Defaults to 2.
51
+ use_sparse (bool, optional): Whether to return the annotations matrix as a sparse matrix. Defaults to True.
50
52
 
51
53
  Returns:
52
- Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the binary annotations matrix.
54
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the sparse binary annotations
55
+ matrix.
53
56
 
54
57
  Raises:
55
58
  ValueError: If no annotations are found for the nodes in the network.
56
59
  ValueError: If no annotations have at least min_nodes_per_term nodes in the network.
57
-
58
- Comment:
59
- This function should be optimized to handle large networks and annotations efficiently. An attempt
60
- to use sparse matrices did not yield significant performance improvements, so it was not implemented.
61
60
  """
62
61
  # Flatten the dictionary to a list of tuples for easier DataFrame creation
63
62
  flattened_annotations = [
@@ -78,7 +77,6 @@ def load_annotations(
78
77
  raise ValueError("No terms found in the annotation file for the nodes in the network.")
79
78
 
80
79
  # Filter out annotations with fewer than min_nodes_per_term occurrences
81
- # This assists in reducing noise and focusing on more relevant annotations for statistical analysis
82
80
  num_terms_before_filtering = annotations_pivot.shape[1]
83
81
  annotations_pivot = annotations_pivot.loc[
84
82
  :, (annotations_pivot.sum(axis=0) >= min_nodes_per_term)
@@ -96,13 +94,15 @@ def load_annotations(
96
94
  # Extract ordered nodes and annotations
97
95
  ordered_nodes = tuple(annotations_pivot.index)
98
96
  ordered_annotations = tuple(annotations_pivot.columns)
99
- # Convert the annotations_pivot matrix to a numpy array and ensure it's binary
100
- annotations_pivot_numpy = (annotations_pivot.fillna(0).to_numpy() > 0).astype(int)
97
+ # Convert the annotations_pivot matrix to a numpy array or sparse matrix
98
+ annotations_pivot_binary = (annotations_pivot.fillna(0).to_numpy() > 0).astype(int)
99
+ # Convert the binary annotations matrix to a sparse matrix
100
+ annotations_pivot_binary = csr_matrix(annotations_pivot_binary)
101
101
 
102
102
  return {
103
103
  "ordered_nodes": ordered_nodes,
104
104
  "ordered_annotations": ordered_annotations,
105
- "matrix": annotations_pivot_numpy,
105
+ "matrix": annotations_pivot_binary,
106
106
  }
107
107
 
108
108
 
@@ -4,5 +4,5 @@ risk/neighborhoods
4
4
  """
5
5
 
6
6
  from risk.neighborhoods.domains import define_domains, trim_domains
7
- from risk.neighborhoods.io import NeighborhoodsIO
7
+ from risk.neighborhoods.api import NeighborhoodsAPI
8
8
  from risk.neighborhoods.neighborhoods import process_neighborhoods
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/neighborhoods/io
3
- ~~~~~~~~~~~~~~~~~~~~~
2
+ risk/neighborhoods/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  import copy
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Tuple, Union
8
8
 
9
9
  import networkx as nx
10
10
  import numpy as np
11
+ from scipy.sparse import csr_matrix
11
12
 
12
13
  from risk.log import logger, log_header, params
13
14
  from risk.neighborhoods.neighborhoods import get_network_neighborhoods
@@ -21,10 +22,10 @@ from risk.stats import (
21
22
  )
22
23
 
23
24
 
24
- class NeighborhoodsIO:
25
+ class NeighborhoodsAPI:
25
26
  """Handles the loading of statistical results and annotation significance for neighborhoods.
26
27
 
27
- The NeighborhoodsIO class provides methods to load neighborhood results from statistical tests.
28
+ The NeighborhoodsAPI class provides methods to load neighborhood results from statistical tests.
28
29
  """
29
30
 
30
31
  def __init__() -> None:
@@ -86,7 +87,7 @@ class NeighborhoodsIO:
86
87
  null_distribution: str = "network",
87
88
  random_seed: int = 888,
88
89
  ) -> Dict[str, Any]:
89
- """Load significant neighborhoods for the network using the Chi-squared test.
90
+ """Load significant neighborhoods for the network using the chi-squared test.
90
91
 
91
92
  Args:
92
93
  network (nx.Graph): The network graph.
@@ -396,12 +397,11 @@ class NeighborhoodsIO:
396
397
  leiden_resolution: float = 1.0,
397
398
  fraction_shortest_edges: Union[float, List, Tuple, np.ndarray] = 0.5,
398
399
  random_seed: int = 888,
399
- ) -> np.ndarray:
400
+ ) -> csr_matrix:
400
401
  """Load significant neighborhoods for the network.
401
402
 
402
403
  Args:
403
404
  network (nx.Graph): The network graph.
404
- annotations (pd.DataFrame): The matrix of annotations associated with the network.
405
405
  distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
406
406
  metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'leiden', 'label_propagation',
407
407
  'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
@@ -413,7 +413,7 @@ class NeighborhoodsIO:
413
413
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
414
414
 
415
415
  Returns:
416
- np.ndarray: Neighborhood matrix calculated based on the selected distance metric.
416
+ csr_matrix: Sparse neighborhood matrix calculated based on the selected distance metric.
417
417
  """
418
418
  # Display the chosen distance metric
419
419
  if distance_metric == "louvain":
@@ -422,12 +422,13 @@ class NeighborhoodsIO:
422
422
  for_print_distance_metric = f"leiden (resolution={leiden_resolution})"
423
423
  else:
424
424
  for_print_distance_metric = distance_metric
425
+
425
426
  # Log and display neighborhood settings
426
427
  logger.debug(f"Distance metric: '{for_print_distance_metric}'")
427
428
  logger.debug(f"Edge length threshold: {fraction_shortest_edges}")
428
429
  logger.debug(f"Random seed: {random_seed}")
429
430
 
430
- # Compute neighborhoods based on the network and distance metric
431
+ # Compute neighborhoods
431
432
  neighborhoods = get_network_neighborhoods(
432
433
  network,
433
434
  distance_metric,
@@ -437,5 +438,9 @@ class NeighborhoodsIO:
437
438
  random_seed=random_seed,
438
439
  )
439
440
 
440
- # Return the computed neighborhoods
441
+ # Ensure the neighborhood matrix is in sparse format
442
+ if not isinstance(neighborhoods, csr_matrix):
443
+ neighborhoods = csr_matrix(neighborhoods)
444
+
445
+ # Return the sparse neighborhood matrix
441
446
  return neighborhoods
@@ -39,6 +39,9 @@ def define_domains(
39
39
  pd.DataFrame: DataFrame with the primary domain for each node.
40
40
  """
41
41
  try:
42
+ if linkage_criterion == "off":
43
+ raise ValueError("Clustering is turned off.")
44
+
42
45
  # Transpose the matrix to cluster annotations
43
46
  m = significant_neighborhoods_significance[:, top_annotations["significant_annotations"]].T
44
47
  # Safeguard the matrix by replacing NaN, Inf, and -Inf values
@@ -62,9 +65,14 @@ def define_domains(
62
65
  except ValueError:
63
66
  # If a ValueError is encountered, handle it by assigning unique domains
64
67
  n_rows = len(top_annotations)
65
- logger.error(
66
- f"Error encountered. Skipping clustering and assigning {n_rows} unique domains."
67
- )
68
+ if linkage_criterion == "off":
69
+ logger.warning(
70
+ f"Clustering is turned off. Skipping clustering and assigning {n_rows} unique domains."
71
+ )
72
+ else:
73
+ logger.error(
74
+ f"Error encountered. Skipping clustering and assigning {n_rows} unique domains."
75
+ )
68
76
  top_annotations["domain"] = range(1, n_rows + 1) # Assign unique domains
69
77
 
70
78
  # Create DataFrames to store domain information
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network
3
+ ~~~~~~~~~~~~
4
+ """
5
+
6
+ from risk.network.io import NetworkIO
@@ -3,4 +3,4 @@ risk/network/graph
3
3
  ~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
- from risk.network.graph.io import GraphIO
6
+ from risk.network.graph.api import GraphAPI
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/network/graph/io
3
- ~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/graph/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  import copy
@@ -20,10 +20,10 @@ from risk.network.graph.network import NetworkGraph
20
20
  from risk.stats import calculate_significance_matrices
21
21
 
22
22
 
23
- class GraphIO:
23
+ class GraphAPI:
24
24
  """Handles the loading of network graphs and associated data.
25
25
 
26
- The GraphIO class provides methods to load and process network graphs, annotations, and neighborhoods.
26
+ The GraphAPI class provides methods to load and process network graphs, annotations, and neighborhoods.
27
27
  """
28
28
 
29
29
  def __init__() -> None:
@@ -240,8 +240,12 @@ class AnalysisSummary:
240
240
  except ValueError:
241
241
  return "" # Description not found
242
242
 
243
- # Get nodes present for the annotation and sort by node label
244
- nodes_present = np.where(self.annotations["matrix"][:, annotation_idx] == 1)[0]
243
+ # Get the column (safely) from the sparse matrix
244
+ column = self.annotations["matrix"][:, annotation_idx]
245
+ # Convert the column to a dense array if needed
246
+ column = column.toarray().ravel() # Convert to a 1D dense array
247
+ # Get nodes present for the annotation and sort by node label - use np.where on the dense array
248
+ nodes_present = np.where(column == 1)[0]
245
249
  node_labels = sorted(
246
250
  self.graph.node_id_to_node_label_map[node_id]
247
251
  for node_id in nodes_present
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network/plotter
3
+ ~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from risk.network.plotter.api import PlotterAPI
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/network/graph/io
3
- ~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
@@ -9,13 +9,13 @@ import numpy as np
9
9
 
10
10
  from risk.log import log_header
11
11
  from risk.network.graph.network import NetworkGraph
12
- from risk.network.plot.network import NetworkPlotter
12
+ from risk.network.plotter.network import NetworkPlotter
13
13
 
14
14
 
15
- class PlotterIO:
15
+ class PlotterAPI:
16
16
  """Handles the loading of network plotter objects.
17
17
 
18
- The PlotterIO class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
18
+ The PlotterAPI class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
19
19
  """
20
20
 
21
21
  def __init__() -> None:
@@ -10,8 +10,8 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph.network import NetworkGraph
13
- from risk.network.plot.utils.colors import to_rgba
14
- from risk.network.plot.utils.layout import calculate_bounding_box
13
+ from risk.network.plotter.utils.colors import to_rgba
14
+ from risk.network.plotter.utils.layout import calculate_bounding_box
15
15
 
16
16
 
17
17
  class Canvas:
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
13
13
 
14
14
  from risk.log import params, logger
15
15
  from risk.network.graph.network import NetworkGraph
16
- from risk.network.plot.utils.colors import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
17
17
 
18
18
 
19
19
  class Contour:
@@ -12,8 +12,8 @@ import pandas as pd
12
12
 
13
13
  from risk.log import params
14
14
  from risk.network.graph.network import NetworkGraph
15
- from risk.network.plot.utils.colors import get_annotated_domain_colors, to_rgba
16
- from risk.network.plot.utils.layout import calculate_bounding_box
15
+ from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plotter.utils.layout import calculate_bounding_box
17
17
 
18
18
  TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
19
19
 
@@ -11,11 +11,11 @@ import numpy as np
11
11
 
12
12
  from risk.log import params
13
13
  from risk.network.graph.network import NetworkGraph
14
- from risk.network.plot.canvas import Canvas
15
- from risk.network.plot.contour import Contour
16
- from risk.network.plot.labels import Labels
17
- from risk.network.plot.utils.colors import get_domain_colors, to_rgba
18
- from risk.network.plot.utils.layout import calculate_bounding_box
14
+ from risk.network.plotter.canvas import Canvas
15
+ from risk.network.plotter.contour import Contour
16
+ from risk.network.plotter.labels import Labels
17
+ from risk.network.plotter.utils.colors import get_domain_colors, to_rgba
18
+ from risk.network.plotter.utils.layout import calculate_bounding_box
19
19
 
20
20
 
21
21
  class Network:
@@ -3,13 +3,16 @@ risk/risk
3
3
  ~~~~~~~~~
4
4
  """
5
5
 
6
+ from risk.network import NetworkIO
6
7
  from risk.annotations import AnnotationsIO
8
+ from risk.neighborhoods import NeighborhoodsAPI
9
+ from risk.network.graph import GraphAPI
10
+ from risk.network.plotter import PlotterAPI
11
+
7
12
  from risk.log import params, set_global_verbosity
8
- from risk.neighborhoods import NeighborhoodsIO
9
- from risk.network import GraphIO, NetworkIO, PlotterIO
10
13
 
11
14
 
12
- class RISK(NetworkIO, AnnotationsIO, NeighborhoodsIO, GraphIO, PlotterIO):
15
+ class RISK(NetworkIO, AnnotationsIO, NeighborhoodsAPI, GraphAPI, PlotterAPI):
13
16
  """RISK: A class for network analysis and visualization.
14
17
 
15
18
  The RISK class integrates functionalities for loading networks, processing annotations,
@@ -0,0 +1,51 @@
1
+ """
2
+ risk/stats/binom
3
+ ~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict
7
+
8
+ from scipy.sparse import csr_matrix
9
+ from scipy.stats import binom
10
+
11
+
12
def compute_binom_test(
    neighborhoods: csr_matrix,
    annotations: csr_matrix,
    null_distribution: str = "network",
) -> Dict[str, Any]:
    """Compute binomial test for enrichment and depletion in neighborhoods with selectable null distribution.

    Each row of ``neighborhoods`` is treated as one neighborhood (its row sum is the
    number of trials) and each column of ``annotations`` as one annotation term.

    Args:
        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
        annotations (csr_matrix): Sparse binary matrix representing annotations.
        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".

    Returns:
        Dict[str, Any]: Dictionary containing depletion and enrichment p-values, each a dense array
            with one row per neighborhood and one column per annotation.

    Raises:
        ValueError: If null_distribution is neither 'network' nor 'annotations'.
    """
    # Reject an unknown null model before doing any matrix work
    if null_distribution not in ("network", "annotations"):
        raise ValueError(
            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
        )

    # Get the total number of nodes in the network
    total_nodes = neighborhoods.shape[1]

    # Row sums give the number of trials per neighborhood; column sums give per-annotation totals
    neighborhood_sizes = neighborhoods.sum(axis=1).A  # Dense column vector
    annotation_totals = annotations.sum(axis=0)

    # Success probability per annotation: normalize by the node count ('network')
    # or by the total number of annotation assignments ('annotations')
    denominator = total_nodes if null_distribution == "network" else annotations.sum()
    success_probs = (annotation_totals / denominator).A.flatten()  # Dense 1D array

    # Observed counts (sparse matrix multiplication), densified for scipy.stats
    annotated_counts = (neighborhoods @ annotations).toarray()

    # Use the survival function for the upper tail: P(X >= k) = sf(k - 1). Computing it as
    # 1 - cdf(k - 1) cancels catastrophically and collapses small p-values to exactly 0.
    enrichment_pvals = binom.sf(annotated_counts - 1, neighborhood_sizes, success_probs)
    depletion_pvals = binom.cdf(annotated_counts, neighborhood_sizes, success_probs)

    return {"enrichment_pvals": enrichment_pvals, "depletion_pvals": depletion_pvals}
@@ -4,51 +4,63 @@ risk/stats/chi2
4
4
  """
5
5
 
6
6
  from typing import Any, Dict
7
+
7
8
  import numpy as np
9
+ from scipy.sparse import csr_matrix
8
10
  from scipy.stats import chi2
9
11
 
10
12
 
11
13
  def compute_chi2_test(
12
- neighborhoods: np.ndarray, annotations: np.ndarray, null_distribution: str = "network"
14
+ neighborhoods: csr_matrix,
15
+ annotations: csr_matrix,
16
+ null_distribution: str = "network",
13
17
  ) -> Dict[str, Any]:
14
18
  """Compute chi-squared test for enrichment and depletion in neighborhoods with selectable null distribution.
15
19
 
16
20
  Args:
17
- neighborhoods (np.ndarray): Binary matrix representing neighborhoods.
18
- annotations (np.ndarray): Binary matrix representing annotations.
21
+ neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
22
+ annotations (csr_matrix): Sparse binary matrix representing annotations.
19
23
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
20
24
 
21
25
  Returns:
22
26
  Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
23
27
  """
24
- # Get the total number of nodes in the network
28
+ # Total number of nodes in the network
25
29
  total_node_count = neighborhoods.shape[0]
26
30
 
27
31
  if null_distribution == "network":
28
32
  # Case 1: Use all nodes as the background
29
33
  background_population = total_node_count
30
- neighborhood_sums = np.sum(
31
- neighborhoods, axis=0, keepdims=True
32
- ).T # Column sums of neighborhoods
33
- annotation_sums = np.sum(annotations, axis=0, keepdims=True) # Column sums of annotations
34
+ neighborhood_sums = neighborhoods.sum(axis=0) # Column sums of neighborhoods
35
+ annotation_sums = annotations.sum(axis=0) # Column sums of annotations
34
36
  elif null_distribution == "annotations":
35
37
  # Case 2: Only consider nodes with at least one annotation
36
- annotated_nodes = np.sum(annotations, axis=1) > 0
37
- background_population = np.sum(annotated_nodes)
38
- neighborhood_sums = np.sum(neighborhoods[annotated_nodes], axis=0, keepdims=True).T
39
- annotation_sums = np.sum(annotations[annotated_nodes], axis=0, keepdims=True)
38
+ annotated_nodes = (
39
+ np.ravel(annotations.sum(axis=1)) > 0
40
+ ) # Row-wise sum to filter nodes with annotations
41
+ background_population = annotated_nodes.sum() # Total number of annotated nodes
42
+ neighborhood_sums = neighborhoods[annotated_nodes].sum(
43
+ axis=0
44
+ ) # Neighborhood sums for annotated nodes
45
+ annotation_sums = annotations[annotated_nodes].sum(
46
+ axis=0
47
+ ) # Annotation sums for annotated nodes
40
48
  else:
41
49
  raise ValueError(
42
50
  "Invalid null_distribution value. Choose either 'network' or 'annotations'."
43
51
  )
44
52
 
53
+ # Convert to dense arrays for downstream computations
54
+ neighborhood_sums = np.asarray(neighborhood_sums).reshape(-1, 1) # Ensure column vector shape
55
+ annotation_sums = np.asarray(annotation_sums).reshape(1, -1) # Ensure row vector shape
56
+
45
57
  # Observed values: number of annotated nodes in each neighborhood
46
58
  observed = neighborhoods.T @ annotations # Shape: (neighborhoods, annotations)
47
59
  # Expected values under the null
48
60
  expected = (neighborhood_sums @ annotation_sums) / background_population
49
61
  # Chi-squared statistic: sum((observed - expected)^2 / expected)
50
62
  with np.errstate(divide="ignore", invalid="ignore"): # Handle divide-by-zero
51
- chi2_stat = np.where(expected > 0, (observed - expected) ** 2 / expected, 0)
63
+ chi2_stat = np.where(expected > 0, np.power(observed - expected, 2) / expected, 0)
52
64
 
53
65
  # Compute p-values for enrichment (upper tail) and depletion (lower tail)
54
66
  enrichment_pvals = chi2.sf(chi2_stat, df=1) # Survival function for upper tail
@@ -0,0 +1,64 @@
1
+ """
2
+ risk/stats/hypergeom
3
+ ~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict
7
+
8
+ import numpy as np
9
+ from scipy.sparse import csr_matrix
10
+ from scipy.stats import hypergeom
11
+
12
+
13
def compute_hypergeom_test(
    neighborhoods: csr_matrix,
    annotations: csr_matrix,
    null_distribution: str = "network",
) -> Dict[str, Any]:
    """Compute hypergeometric test for enrichment and depletion in neighborhoods with selectable null distribution.

    Rows of both matrices index nodes: neighborhood sizes and annotation totals are
    column sums, and observed overlaps come from ``neighborhoods.T @ annotations``.

    Args:
        neighborhoods (csr_matrix): Sparse binary matrix representing neighborhoods.
        annotations (csr_matrix): Sparse binary matrix representing annotations.
        null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".

    Returns:
        Dict[str, Any]: Dictionary containing depletion and enrichment p-values, each a dense array
            with one row per neighborhood and one column per annotation.

    Raises:
        ValueError: If null_distribution is neither 'network' nor 'annotations'.
    """
    if null_distribution == "network":
        # Every node is part of the background. Nodes are indexed by rows (the axis that is
        # summed over and contracted below), so the population size is shape[0].
        background_population = neighborhoods.shape[0]
        neighborhood_sums = neighborhoods.sum(axis=0).A.flatten()  # Convert to dense array
        annotation_sums = annotations.sum(axis=0).A.flatten()  # Convert to dense array
    elif null_distribution == "annotations":
        # Restrict the background to nodes carrying at least one annotation
        annotated_nodes = annotations.sum(axis=1).A.flatten() > 0  # Boolean mask
        background_population = annotated_nodes.sum()
        neighborhood_sums = neighborhoods[annotated_nodes].sum(axis=0).A.flatten()
        annotation_sums = annotations[annotated_nodes].sum(axis=0).A.flatten()
    else:
        raise ValueError(
            "Invalid null_distribution value. Choose either 'network' or 'annotations'."
        )

    # Observed counts: annotated nodes per (neighborhood, annotation) pair, densified for scipy.stats
    annotated_in_neighborhood = (neighborhoods.T @ annotations).toarray()
    # Align shapes for broadcasting against the (neighborhoods, annotations) grid
    neighborhood_sums = neighborhood_sums.reshape(-1, 1)
    annotation_sums = annotation_sums.reshape(1, -1)
    background_population = np.array(background_population).reshape(1, 1)

    # Compute hypergeometric p-values: lower tail P(X <= k) and upper tail P(X >= k) = sf(k - 1)
    depletion_pvals = hypergeom.cdf(
        annotated_in_neighborhood, background_population, annotation_sums, neighborhood_sums
    )
    enrichment_pvals = hypergeom.sf(
        annotated_in_neighborhood - 1, background_population, annotation_sums, neighborhood_sums
    )

    return {"depletion_pvals": depletion_pvals, "enrichment_pvals": enrichment_pvals}