risk-network 0.0.16b1__py3-none-any.whl → 0.0.16b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. risk/__init__.py +2 -2
  2. risk/{_annotation → annotation}/__init__.py +2 -2
  3. risk/{_annotation → annotation}/_nltk_setup.py +3 -3
  4. risk/{_annotation/_annotation.py → annotation/annotation.py} +13 -13
  5. risk/{_annotation/_io.py → annotation/io.py} +4 -4
  6. risk/cluster/__init__.py +8 -0
  7. risk/{_neighborhoods → cluster}/_community.py +37 -37
  8. risk/cluster/api.py +273 -0
  9. risk/{_neighborhoods/_neighborhoods.py → cluster/cluster.py} +127 -98
  10. risk/{_neighborhoods/_domains.py → cluster/label.py} +18 -12
  11. risk/{_log → log}/__init__.py +2 -2
  12. risk/{_log/_console.py → log/console.py} +2 -2
  13. risk/{_log/_parameters.py → log/parameters.py} +20 -10
  14. risk/network/__init__.py +8 -0
  15. risk/network/graph/__init__.py +7 -0
  16. risk/{_network/_graph → network/graph}/_stats.py +2 -2
  17. risk/{_network/_graph → network/graph}/_summary.py +13 -13
  18. risk/{_network/_graph/_api.py → network/graph/api.py} +37 -39
  19. risk/{_network/_graph/_graph.py → network/graph/graph.py} +5 -5
  20. risk/{_network/_io.py → network/io.py} +9 -4
  21. risk/network/plotter/__init__.py +6 -0
  22. risk/{_network/_plotter → network/plotter}/_canvas.py +6 -6
  23. risk/{_network/_plotter → network/plotter}/_contour.py +4 -4
  24. risk/{_network/_plotter → network/plotter}/_labels.py +6 -6
  25. risk/{_network/_plotter → network/plotter}/_network.py +7 -7
  26. risk/{_network/_plotter → network/plotter}/_plotter.py +5 -5
  27. risk/network/plotter/_utils/__init__.py +7 -0
  28. risk/{_network/_plotter/_utils/_colors.py → network/plotter/_utils/colors.py} +3 -3
  29. risk/{_network/_plotter/_utils/_layout.py → network/plotter/_utils/layout.py} +2 -2
  30. risk/{_network/_plotter/_api.py → network/plotter/api.py} +5 -5
  31. risk/{_risk.py → risk.py} +9 -8
  32. risk/stats/__init__.py +6 -0
  33. risk/stats/_stats/__init__.py +11 -0
  34. risk/stats/_stats/permutation/__init__.py +6 -0
  35. risk/stats/_stats/permutation/_test_functions.py +72 -0
  36. risk/{_neighborhoods/_stats/_permutation/_permutation.py → stats/_stats/permutation/permutation.py} +35 -37
  37. risk/{_neighborhoods/_stats/_tests.py → stats/_stats/tests.py} +32 -34
  38. risk/stats/api.py +202 -0
  39. {risk_network-0.0.16b1.dist-info → risk_network-0.0.16b2.dist-info}/METADATA +2 -2
  40. risk_network-0.0.16b2.dist-info/RECORD +43 -0
  41. risk/_neighborhoods/__init__.py +0 -8
  42. risk/_neighborhoods/_api.py +0 -354
  43. risk/_neighborhoods/_stats/__init__.py +0 -11
  44. risk/_neighborhoods/_stats/_permutation/__init__.py +0 -6
  45. risk/_neighborhoods/_stats/_permutation/_test_functions.py +0 -72
  46. risk/_network/__init__.py +0 -8
  47. risk/_network/_graph/__init__.py +0 -7
  48. risk/_network/_plotter/__init__.py +0 -6
  49. risk/_network/_plotter/_utils/__init__.py +0 -7
  50. risk_network-0.0.16b1.dist-info/RECORD +0 -41
  51. {risk_network-0.0.16b1.dist-info → risk_network-0.0.16b2.dist-info}/WHEEL +0 -0
  52. {risk_network-0.0.16b1.dist-info → risk_network-0.0.16b2.dist-info}/licenses/LICENSE +0 -0
  53. {risk_network-0.0.16b1.dist-info → risk_network-0.0.16b2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_graph/_api
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/graph/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  import copy
@@ -9,14 +9,14 @@ from typing import Any, Dict, Union
9
9
  import networkx as nx
10
10
  import pandas as pd
11
11
 
12
- from ..._annotation import define_top_annotation
13
- from ..._log import log_header, logger, params
14
- from ..._neighborhoods import (
12
+ from ...annotation import define_top_annotation
13
+ from ...log import log_header, logger, params
14
+ from ...cluster import (
15
15
  define_domains,
16
- process_neighborhoods,
16
+ process_significant_clusters,
17
17
  trim_domains,
18
18
  )
19
- from ._graph import Graph
19
+ from .graph import Graph
20
20
  from ._stats import calculate_significance_matrices
21
21
 
22
22
 
@@ -24,14 +24,14 @@ class GraphAPI:
24
24
  """
25
25
  Handles the loading of network graphs and associated data.
26
26
 
27
- The GraphAPI class provides methods to load and process network graphs, annotations, and neighborhoods.
27
+ The GraphAPI class provides methods to load and process network graphs, annotations, and cluster results.
28
28
  """
29
29
 
30
30
  def load_graph(
31
31
  self,
32
32
  network: nx.Graph,
33
33
  annotation: Dict[str, Any],
34
- neighborhoods: Dict[str, Any],
34
+ stats_results: Dict[str, Any],
35
35
  tail: str = "right",
36
36
  pval_cutoff: float = 0.01,
37
37
  fdr_cutoff: float = 0.9999,
@@ -50,7 +50,7 @@ class GraphAPI:
50
50
  Args:
51
51
  network (nx.Graph): The network graph.
52
52
  annotation (Dict[str, Any]): The annotation associated with the network.
53
- neighborhoods (Dict[str, Any]): Neighborhood significance data.
53
+ stats_results (Dict[str, Any]): Cluster significance data.
54
54
  tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
55
55
  pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
56
56
  fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
@@ -62,14 +62,14 @@ class GraphAPI:
62
62
  Defaults to "yule".
63
63
  linkage_threshold (float, str, optional): Threshold for clustering. Choose "auto" to optimize.
64
64
  Defaults to 0.2.
65
- min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
66
- max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
65
+ min_cluster_size (int, optional): Minimum size for significant clusters. Defaults to 5.
66
+ max_cluster_size (int, optional): Maximum size for significant clusters. Defaults to 1000.
67
67
 
68
68
  Returns:
69
69
  Graph: A fully initialized and processed Graph object.
70
70
  """
71
71
  # Log the parameters and display headers
72
- log_header("Finding significant neighborhoods")
72
+ log_header("Finding significant clusters")
73
73
  params.log_graph(
74
74
  tail=tail,
75
75
  pval_cutoff=pval_cutoff,
@@ -92,20 +92,20 @@ class GraphAPI:
92
92
  logger.debug(
93
93
  f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
94
94
  )
95
- # Calculate significant neighborhoods based on the provided parameters
96
- significant_neighborhoods = calculate_significance_matrices(
97
- neighborhoods["depletion_pvals"],
98
- neighborhoods["enrichment_pvals"],
95
+ # Calculate significant clusters based on the provided parameters
96
+ significant_clusters = calculate_significance_matrices(
97
+ stats_results["depletion_pvals"],
98
+ stats_results["enrichment_pvals"],
99
99
  tail=tail,
100
100
  pval_cutoff=pval_cutoff,
101
101
  fdr_cutoff=fdr_cutoff,
102
102
  )
103
103
 
104
- log_header("Processing neighborhoods")
105
- # Process neighborhoods by imputing and pruning based on the given settings
106
- processed_neighborhoods = process_neighborhoods(
104
+ log_header("Processing significant clusters")
105
+ # Process significant clusters by imputing and pruning based on the given settings
106
+ processed_clusters = process_significant_clusters(
107
107
  network=network,
108
- neighborhoods=significant_neighborhoods,
108
+ significant_clusters=significant_clusters,
109
109
  impute_depth=impute_depth,
110
110
  prune_threshold=prune_threshold,
111
111
  )
@@ -113,24 +113,22 @@ class GraphAPI:
113
113
  log_header("Finding top annotations")
114
114
  logger.debug(f"Min cluster size: {min_cluster_size}")
115
115
  logger.debug(f"Max cluster size: {max_cluster_size}")
116
- # Define top annotations based on processed neighborhoods
116
+ # Define top annotations based on processed significant clusters
117
117
  top_annotation = self._define_top_annotation(
118
118
  network=network,
119
119
  annotation=annotation,
120
- neighborhoods=processed_neighborhoods,
120
+ processed_clusters=processed_clusters,
121
121
  min_cluster_size=min_cluster_size,
122
122
  max_cluster_size=max_cluster_size,
123
123
  )
124
124
 
125
- log_header("Optimizing distance threshold for domains")
126
- # Extract the significant significance matrix from the neighborhoods data
127
- significant_neighborhoods_significance = processed_neighborhoods[
128
- "significant_significance_matrix"
129
- ]
125
+ log_header("Grouping clusters into domains")
126
+ # Extract the significant significance matrix from the processed_clusters data
127
+ significant_clusters_significance = processed_clusters["significant_significance_matrix"]
130
128
  # Define domains in the network using the specified clustering settings
131
129
  domains = define_domains(
132
130
  top_annotation=top_annotation,
133
- significant_neighborhoods_significance=significant_neighborhoods_significance,
131
+ significant_clusters_significance=significant_clusters_significance,
134
132
  linkage_criterion=linkage_criterion,
135
133
  linkage_method=linkage_method,
136
134
  linkage_metric=linkage_metric,
@@ -147,13 +145,13 @@ class GraphAPI:
147
145
  # Prepare node mapping and significance sums for the final Graph object
148
146
  ordered_nodes = annotation["ordered_nodes"]
149
147
  node_label_to_id = dict(zip(ordered_nodes, range(len(ordered_nodes))))
150
- node_significance_sums = processed_neighborhoods["node_significance_sums"]
148
+ node_significance_sums = processed_clusters["node_significance_sums"]
151
149
 
152
150
  # Return the fully initialized Graph object
153
151
  return Graph(
154
152
  network=network,
155
153
  annotation=annotation,
156
- neighborhoods=neighborhoods,
154
+ stats_results=stats_results,
157
155
  domains=domains,
158
156
  trimmed_domains=trimmed_domains,
159
157
  node_label_to_node_id_map=node_label_to_id,
@@ -164,7 +162,7 @@ class GraphAPI:
164
162
  self,
165
163
  network: nx.Graph,
166
164
  annotation: Dict[str, Any],
167
- neighborhoods: Dict[str, Any],
165
+ processed_clusters: Dict[str, Any],
168
166
  min_cluster_size: int = 5,
169
167
  max_cluster_size: int = 1000,
170
168
  ) -> pd.DataFrame:
@@ -174,25 +172,25 @@ class GraphAPI:
174
172
  Args:
175
173
  network (nx.Graph): The network graph.
176
174
  annotation (Dict[str, Any]): Annotation data for the network.
177
- neighborhoods (Dict[str, Any]): Neighborhood significance data.
175
+ processed_clusters (Dict[str, Any]): Processed cluster significance data.
178
176
  min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
179
177
  max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
180
178
 
181
179
  Returns:
182
- Dict[str, Any]: Top annotations identified within the network.
180
+ pd.DataFrame: Top annotations identified within the network.
183
181
  """
184
- # Extract necessary data from annotation and neighborhoods
182
+ # Extract necessary data from annotation and processed_clusters
185
183
  ordered_annotation = annotation["ordered_annotation"]
186
- neighborhood_significance_sums = neighborhoods["neighborhood_significance_counts"]
187
- significant_significance_matrix = neighborhoods["significant_significance_matrix"]
188
- significant_binary_significance_matrix = neighborhoods[
184
+ cluster_significance_sums = processed_clusters["cluster_significance_counts"]
185
+ significant_significance_matrix = processed_clusters["significant_significance_matrix"]
186
+ significant_binary_significance_matrix = processed_clusters[
189
187
  "significant_binary_significance_matrix"
190
188
  ]
191
189
  # Call external function to define top annotations
192
190
  return define_top_annotation(
193
191
  network=network,
194
192
  ordered_annotation_labels=ordered_annotation,
195
- neighborhood_significance_sums=neighborhood_significance_sums,
193
+ cluster_significance_sums=cluster_significance_sums,
196
194
  significant_significance_matrix=significant_significance_matrix,
197
195
  significant_binary_significance_matrix=significant_binary_significance_matrix,
198
196
  min_cluster_size=min_cluster_size,
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_graph/_graph
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/graph/graph
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from collections import defaultdict
@@ -27,7 +27,7 @@ class Graph:
27
27
  self,
28
28
  network: nx.Graph,
29
29
  annotation: Dict[str, Any],
30
- neighborhoods: Dict[str, Any],
30
+ stats_results: Dict[str, Any],
31
31
  domains: pd.DataFrame,
32
32
  trimmed_domains: pd.DataFrame,
33
33
  node_label_to_node_id_map: Dict[str, Any],
@@ -40,7 +40,7 @@ class Graph:
40
40
  Args:
41
41
  network (nx.Graph): The network graph.
42
42
  annotation (Dict[str, Any]): The annotation associated with the network.
43
- neighborhoods (Dict[str, Any]): Neighborhood significance data.
43
+ stats_results (Dict[str, Any]): Cluster significance data.
44
44
  domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
45
45
  trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
46
46
  node_label_to_node_id_map (Dict[str, Any]): A dictionary mapping node labels to their corresponding IDs.
@@ -72,7 +72,7 @@ class Graph:
72
72
  self.node_coordinates = self._extract_node_coordinates(self.network)
73
73
 
74
74
  # NOTE: Only after the above attributes are initialized, we can create the summary
75
- self.summary = Summary(annotation, neighborhoods, self)
75
+ self.summary = Summary(annotation, stats_results, self)
76
76
 
77
77
  def pop(self, domain_id: int) -> List[str]:
78
78
  """
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_io
3
- ~~~~~~~~~~~~~~~~~
2
+ risk/network/io
3
+ ~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  import copy
@@ -15,7 +15,7 @@ import networkx as nx
15
15
  import numpy as np
16
16
  import pandas as pd
17
17
 
18
- from .._log import log_header, logger, params
18
+ from ..log import log_header, logger, params
19
19
 
20
20
 
21
21
  class NetworkAPI:
@@ -370,7 +370,7 @@ class NetworkIO:
370
370
  self._log_loading_network(filetype, filepath=filepath)
371
371
 
372
372
  # Load the Cytoscape JSON file
373
- with open(filepath, "r") as f:
373
+ with open(filepath, "r", encoding="utf-8") as f:
374
374
  cyjs_data = json.load(f)
375
375
 
376
376
  # Create a graph
@@ -603,6 +603,11 @@ class NetworkIO:
603
603
  distances = compute_distance_vectorized(edge_data, compute_sphere)
604
604
  # Assign Euclidean or spherical distances to edges
605
605
  for (u, v), distance in zip(G.edges, distances):
606
+ if not np.isfinite(distance) or distance <= 0:
607
+ logger.warning(
608
+ f"Edge ({u},{v}) has invalid or non-positive length ({distance}); replaced with minimal fallback 1e-12."
609
+ )
610
+ distance = 1e-12
606
611
  G.edges[u, v]["length"] = distance
607
612
 
608
613
  def _map_to_sphere(self, G: nx.Graph) -> None:
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network/plotter
3
+ ~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .api import PlotterAPI
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_canvas
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_canvas
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
@@ -8,10 +8,10 @@ from typing import List, Tuple, Union
8
8
  import matplotlib.pyplot as plt
9
9
  import numpy as np
10
10
 
11
- from ..._log import params
12
- from .._graph import Graph
13
- from ._utils._colors import to_rgba
14
- from ._utils._layout import calculate_bounding_box
11
+ from ...log import params
12
+ from ..graph import Graph
13
+ from ._utils.colors import to_rgba
14
+ from ._utils.layout import calculate_bounding_box
15
15
 
16
16
 
17
17
  class Canvas:
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_contour
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_contour
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
@@ -11,8 +11,8 @@ from scipy import linalg
11
11
  from scipy.ndimage import label
12
12
  from scipy.stats import gaussian_kde
13
13
 
14
- from ..._log import logger, params
15
- from .._graph import Graph
14
+ from ...log import logger, params
15
+ from ..graph import Graph
16
16
  from ._utils import get_annotated_domain_colors, to_rgba
17
17
 
18
18
 
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_labels
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_labels
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  import copy
@@ -10,8 +10,8 @@ import matplotlib.pyplot as plt
10
10
  import numpy as np
11
11
  import pandas as pd
12
12
 
13
- from ..._log import params
14
- from .._graph import Graph
13
+ from ...log import params
14
+ from ..graph import Graph
15
15
  from ._utils import calculate_bounding_box, get_annotated_domain_colors, to_rgba
16
16
 
17
17
  TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
@@ -275,12 +275,12 @@ class Labels:
275
275
  fontsize (int, optional): Font size for the label. Defaults to 10.
276
276
  fontcolor (str, List, Tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
277
277
  fontalpha (float, None, optional): Transparency level for the font color. If provided, it overrides any existing alpha values found
278
- in fontalpha. Defaults to 1.0.
278
+ in fontcolor. Defaults to 1.0.
279
279
  arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
280
280
  arrow_style (str, optional): Style of the arrows pointing to the centroid. Defaults to "->".
281
281
  arrow_color (str, List, Tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
282
282
  arrow_alpha (float, None, optional): Transparency level for the arrow color. If provided, it overrides any existing alpha values
283
- found in arrow_alpha. Defaults to 1.0.
283
+ found in arrow_color. Defaults to 1.0.
284
284
  arrow_base_shrink (float, optional): Distance between the text and the base of the arrow. Defaults to 0.0.
285
285
  arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
286
286
 
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_network
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_network
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Tuple, Union
8
8
  import networkx as nx
9
9
  import numpy as np
10
10
 
11
- from ..._log import params
12
- from .._graph import Graph
11
+ from ...log import params
12
+ from ..graph import Graph
13
13
  from ._utils import get_domain_colors, to_rgba
14
14
 
15
15
 
@@ -273,14 +273,14 @@ class Network:
273
273
  return adjusted_network_colors
274
274
 
275
275
  def get_annotated_node_sizes(
276
- self, significant_size: int = 50, nonsignificant_size: int = 25
276
+ self, significant_size: Union[int, float] = 50, nonsignificant_size: Union[int, float] = 25
277
277
  ) -> np.ndarray:
278
278
  """
279
279
  Adjust the sizes of nodes in the network graph based on whether they are significant or not.
280
280
 
281
281
  Args:
282
- significant_size (int): Size for significant nodes. Defaults to 50.
283
- nonsignificant_size (int): Size for non-significant nodes. Defaults to 25.
282
+ significant_size (int or float): Size for significant nodes. Can be an integer or float value. Defaults to 50.
283
+ nonsignificant_size (int or float): Size for non-significant nodes. Can be an integer or float value. Defaults to 25.
284
284
 
285
285
  Returns:
286
286
  np.ndarray: Array of node sizes, with significant nodes larger than non-significant ones.
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_plotter
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_plotter
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
@@ -8,8 +8,8 @@ from typing import List, Tuple, Union
8
8
  import matplotlib.pyplot as plt
9
9
  import numpy as np
10
10
 
11
- from ..._log import params
12
- from .._graph._graph import Graph
11
+ from ...log import params
12
+ from ..graph.graph import Graph
13
13
  from ._canvas import Canvas
14
14
  from ._contour import Contour
15
15
  from ._labels import Labels
@@ -123,7 +123,7 @@ class Plotter(Canvas, Network, Contour, Labels):
123
123
  Args:
124
124
  *args: Positional arguments passed to `plt.savefig`.
125
125
  pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
126
- dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 300.
126
+ dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
127
127
  **kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
128
128
  """
129
129
  # Ensure user-provided kwargs take precedence
@@ -0,0 +1,7 @@
1
+ """
2
+ risk/network/plotter/_utils
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .colors import get_annotated_domain_colors, get_domain_colors, to_rgba
7
+ from .layout import calculate_bounding_box, calculate_centroids
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_utils/_colors
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_utils/colors
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
@@ -9,7 +9,7 @@ import matplotlib
9
9
  import matplotlib.colors as mcolors
10
10
  import numpy as np
11
11
 
12
- from ..._graph import Graph
12
+ from ...graph import Graph
13
13
 
14
14
 
15
15
  def get_annotated_domain_colors(
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/_network/_plotter/_utils/_layout
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/_utils/layout
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import Any, Dict, List, Tuple
@@ -1,14 +1,14 @@
1
1
  """
2
- risk/_network/_plotter/_api
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
7
7
 
8
8
  import numpy as np
9
9
 
10
- from ..._log import log_header
11
- from .._graph import Graph
10
+ from ...log import log_header
11
+ from ..graph import Graph
12
12
  from ._plotter import Plotter
13
13
 
14
14
 
@@ -32,7 +32,7 @@ class PlotterAPI:
32
32
 
33
33
  Args:
34
34
  graph (Graph): The graph to plot.
35
- figsize (List, Tuple, or np.ndarray, optional): Size of the plot. Defaults to (10, 10)., optional): Size of the figure. Defaults to (10, 10).
35
+ figsize (List, Tuple, or np.ndarray, optional): Figure size in inches (width, height). Defaults to (10, 10).
36
36
  background_color (str, optional): Background color of the plot. Defaults to "white".
37
37
  background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
38
38
  any existing alpha values found in background_color. Defaults to 1.0.
@@ -1,20 +1,21 @@
1
1
  """
2
- risk/_risk
3
- ~~~~~~~~~~
2
+ risk/risk
3
+ ~~~~~~~~~
4
4
  """
5
5
 
6
- from ._annotation import AnnotationHandler
7
- from ._log import params, set_global_verbosity
8
- from ._neighborhoods import NeighborhoodsAPI
9
- from ._network import GraphAPI, NetworkAPI, PlotterAPI
6
+ from .annotation import AnnotationHandler
7
+ from .cluster import ClusterAPI
8
+ from .log import params, set_global_verbosity
9
+ from .network import GraphAPI, NetworkAPI, PlotterAPI
10
+ from .stats import StatsAPI
10
11
 
11
12
 
12
- class RISK(NetworkAPI, AnnotationHandler, NeighborhoodsAPI, GraphAPI, PlotterAPI):
13
+ class RISK(NetworkAPI, AnnotationHandler, ClusterAPI, StatsAPI, GraphAPI, PlotterAPI):
13
14
  """
14
15
  RISK: A class for network analysis and visualization.
15
16
 
16
17
  The RISK class integrates functionalities for loading networks, processing annotations,
17
- performing network-based statistical analysis to quantify neighborhood relationships,
18
+ performing network-based statistical analysis to quantify cluster relationships,
18
19
  and visualizing networks and their properties.
19
20
  """
20
21
 
risk/stats/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/stats
3
+ ~~~~~~~~~~
4
+ """
5
+
6
+ from .api import StatsAPI
@@ -0,0 +1,11 @@
1
+ """
2
+ risk/stats/_stats
3
+ ~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .permutation import compute_permutation_test
7
+ from .tests import (
8
+ compute_binom_test,
9
+ compute_chi2_test,
10
+ compute_hypergeom_test,
11
+ )
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/stats/_stats/permutation
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from .permutation import compute_permutation_test
@@ -0,0 +1,72 @@
1
+ """
2
+ risk/stats/_stats/permutation/_test_functions
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import numpy as np
7
+ from scipy.sparse import csr_matrix
8
+
9
+ # NOTE: Cython optimizations provided minimal performance benefits.
10
+ # The final version with Cython is archived in the `cython_permutation` branch.
11
+
12
+ # DISPATCH_TEST_FUNCTIONS can be found at the end of the file.
13
+
14
+
15
+ def compute_cluster_score_by_sum(
16
+ clusters_matrix: csr_matrix, annotation_matrix: csr_matrix
17
+ ) -> np.ndarray:
18
+ """
19
+ Compute the sum of attribute values for each cluster using sparse matrices.
20
+
21
+ Args:
22
+ clusters_matrix (csr_matrix): Sparse binary matrix representing clusters.
23
+ annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
24
+
25
+ Returns:
26
+ np.ndarray: Dense array of summed attribute values for each cluster.
27
+ """
28
+ # Calculate the cluster score as the dot product of clusters and annotation
29
+ cluster_score = clusters_matrix @ annotation_matrix # Sparse matrix multiplication
30
+ # Convert the result to a dense array for downstream calculations
31
+ cluster_score_dense = cluster_score.toarray()
32
+ return cluster_score_dense
33
+
34
+
35
+ def compute_cluster_score_by_stdev(
36
+ clusters_matrix: csr_matrix, annotation_matrix: csr_matrix
37
+ ) -> np.ndarray:
38
+ """
39
+ Compute the standard deviation of cluster scores for sparse matrices.
40
+
41
+ Args:
42
+ clusters_matrix (csr_matrix): Sparse binary matrix representing clusters.
43
+ annotation_matrix (csr_matrix): Sparse matrix representing annotation values.
44
+
45
+ Returns:
46
+ np.ndarray: Standard deviation of the cluster scores.
47
+ """
48
+ # Calculate the cluster score as the dot product of clusters and annotation
49
+ cluster_score = clusters_matrix @ annotation_matrix # Sparse matrix multiplication
50
+ # Calculate the number of elements in each cluster (sum of rows)
51
+ N = clusters_matrix.sum(axis=1).A.flatten() # Convert to 1D array
52
+ # Avoid division by zero by replacing zeros in N with np.nan temporarily
53
+ N[N == 0] = np.nan
54
+ # Compute the mean of the cluster scores
55
+ M = cluster_score.multiply(1 / N[:, None]).toarray() # Sparse element-wise division
56
+ # Compute the mean of squares (EXX) directly using squared annotation matrix
57
+ annotation_squared = annotation_matrix.multiply(annotation_matrix) # Element-wise squaring
58
+ EXX = (clusters_matrix @ annotation_squared).multiply(1 / N[:, None]).toarray()
59
+ # Calculate variance as EXX - M^2
60
+ variance = EXX - np.power(M, 2)
61
+ # Compute the standard deviation as the square root of the variance
62
+ cluster_stdev = np.sqrt(variance)
63
+ # Replace np.nan back with zeros in case N was 0 (no elements in the cluster)
64
+ cluster_stdev[np.isnan(cluster_stdev)] = 0
65
+ return cluster_stdev
66
+
67
+
68
+ # Dictionary to dispatch statistical test functions based on the score metric
69
+ DISPATCH_TEST_FUNCTIONS = {
70
+ "sum": compute_cluster_score_by_sum,
71
+ "stdev": compute_cluster_score_by_stdev,
72
+ }