risk-network 0.0.8b27__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +195 -118
  4. risk/annotations/io.py +47 -31
  5. risk/log/__init__.py +4 -2
  6. risk/log/console.py +3 -1
  7. risk/log/{params.py → parameters.py} +17 -42
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +442 -0
  10. risk/neighborhoods/community.py +324 -101
  11. risk/neighborhoods/domains.py +125 -52
  12. risk/neighborhoods/neighborhoods.py +177 -165
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +71 -89
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +200 -0
  17. risk/network/{graph.py → graph/graph.py} +90 -40
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +103 -114
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +9 -8
  23. risk/network/{plot → plotter}/contour.py +27 -24
  24. risk/network/{plot → plotter}/labels.py +73 -78
  25. risk/network/{plot → plotter}/network.py +45 -39
  26. risk/network/{plot → plotter}/plotter.py +23 -17
  27. risk/network/{plot/utils/color.py → plotter/utils/colors.py} +114 -122
  28. risk/network/{plot → plotter}/utils/layout.py +10 -7
  29. risk/risk.py +11 -500
  30. risk/stats/__init__.py +10 -4
  31. risk/stats/permutation/__init__.py +1 -1
  32. risk/stats/permutation/permutation.py +44 -38
  33. risk/stats/permutation/test_functions.py +26 -18
  34. risk/stats/{stats.py → significance.py} +17 -15
  35. risk/stats/stat_tests.py +267 -0
  36. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/METADATA +31 -46
  37. risk_network-0.0.9.dist-info/RECORD +40 -0
  38. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/WHEEL +1 -1
  39. risk/constants.py +0 -31
  40. risk/network/plot/__init__.py +0 -6
  41. risk/stats/hypergeom.py +0 -54
  42. risk/stats/poisson.py +0 -44
  43. risk_network-0.0.8b27.dist-info/RECORD +0 -37
  44. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/LICENSE +0 -0
  45. {risk_network-0.0.8b27.dist-info → risk_network-0.0.9.dist-info}/top_level.txt +0 -0
risk/network/geometry.py CHANGED
@@ -3,8 +3,6 @@ risk/network/geometry
 ~~~~~~~~~~~~~~~~~~~~~
 """
 
-import copy
-
 import networkx as nx
 import numpy as np
 
@@ -13,70 +11,57 @@ def assign_edge_lengths(
     G: nx.Graph,
     compute_sphere: bool = True,
     surface_depth: float = 0.0,
-    include_edge_weight: bool = False,
 ) -> nx.Graph:
-    """Assign edge lengths in the graph, optionally mapping nodes to a sphere and including edge weights.
+    """Assign edge lengths in the graph, optionally mapping nodes to a sphere.
 
     Args:
         G (nx.Graph): The input graph.
         compute_sphere (bool): Whether to map nodes to a sphere. Defaults to True.
        surface_depth (float): The surface depth for mapping to a sphere. Defaults to 0.0.
-        include_edge_weight (bool): Whether to include edge weights in the calculation. Defaults to False.
 
     Returns:
         nx.Graph: The graph with applied edge lengths.
     """
 
-    def compute_distance(
-        u_coords: np.ndarray, v_coords: np.ndarray, is_sphere: bool = False
-    ) -> float:
-        """Compute the distance between two coordinate vectors.
-
-        Args:
-            u_coords (np.ndarray): Coordinates of the first point.
-            v_coords (np.ndarray): Coordinates of the second point.
-            is_sphere (bool, optional): If True, compute spherical distance. Defaults to False.
-
-        Returns:
-            float: The computed distance between the two points.
-        """
+    def compute_distance_vectorized(coords, is_sphere):
+        """Compute distances between pairs of coordinates."""
+        u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
         if is_sphere:
-            # Normalize vectors and compute spherical distance using the dot product
-            u_coords /= np.linalg.norm(u_coords)
-            v_coords /= np.linalg.norm(v_coords)
-            return np.arccos(np.clip(np.dot(u_coords, v_coords), -1.0, 1.0))
-        else:
-            # Compute Euclidean distance
-            return np.linalg.norm(u_coords - v_coords)
+            u_coords /= np.linalg.norm(u_coords, axis=1, keepdims=True)
+            v_coords /= np.linalg.norm(v_coords, axis=1, keepdims=True)
+            dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
+            return np.arccos(np.clip(dot_products, -1.0, 1.0))
+        return np.linalg.norm(u_coords - v_coords, axis=1)
 
     # Normalize graph coordinates
     _normalize_graph_coordinates(G)
-    # Normalize weights
-    _normalize_weights(G)
-    # Use G_depth for edge length calculation
+
+    # Map nodes to sphere and adjust depth if required
     if compute_sphere:
-        # Map to sphere and adjust depth
         _map_to_sphere(G)
-        G_depth = _create_depth(copy.deepcopy(G), surface_depth=surface_depth)
+        G_depth = _create_depth(G, surface_depth=surface_depth)
     else:
-        # Calculate edge lengths directly on the plane
-        G_depth = copy.deepcopy(G)
-
-    for u, v, _ in G_depth.edges(data=True):
-        u_coords = np.array([G_depth.nodes[u]["x"], G_depth.nodes[u]["y"]])
-        v_coords = np.array([G_depth.nodes[v]["x"], G_depth.nodes[v]["y"]])
-        if compute_sphere:
-            u_coords = np.append(u_coords, G_depth.nodes[u].get("z", 0))
-            v_coords = np.append(v_coords, G_depth.nodes[v].get("z", 0))
-
-        distance = compute_distance(u_coords, v_coords, is_sphere=compute_sphere)
-        # Assign edge lengths to the original graph
-        if include_edge_weight:
-            # Square root of the normalized weight is used to minimize the effect of large weights
-            G.edges[u, v]["length"] = distance / np.sqrt(G.edges[u, v]["normalized_weight"] + 1e-6)
-        else:
-            # Use calculated distance directly
-            G.edges[u, v]["length"] = distance
+        G_depth = G
+
+    # Precompute edge coordinate arrays and compute distances in bulk
+    edge_data = np.array(
+        [
+            [
+                np.array(
+                    [G_depth.nodes[u]["x"], G_depth.nodes[u]["y"], G_depth.nodes[u].get("z", 0)]
+                ),
+                np.array(
+                    [G_depth.nodes[v]["x"], G_depth.nodes[v]["y"], G_depth.nodes[v].get("z", 0)]
+                ),
+            ]
+            for u, v in G_depth.edges
+        ]
+    )
+    # Compute distances
+    distances = compute_distance_vectorized(edge_data, compute_sphere)
+    # Assign distances back to the graph
+    for (u, v), distance in zip(G_depth.edges, distances):
+        G.edges[u, v]["length"] = distance
 
     return G
 
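The vectorized helper above replaces the per-edge loop with one pass over an (n_edges, 2, 3) array. A standalone sketch of the same logic on toy data (run outside the package; the coordinate values are illustrative only):

import numpy as np

def compute_distance_vectorized(coords, is_sphere):
    # Split the array into per-edge source and target coordinates
    u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
    if is_sphere:
        # Normalize each row and take per-row dot products: the arc length on the unit sphere
        u_coords = u_coords / np.linalg.norm(u_coords, axis=1, keepdims=True)
        v_coords = v_coords / np.linalg.norm(v_coords, axis=1, keepdims=True)
        dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
        return np.arccos(np.clip(dot_products, -1.0, 1.0))
    # Planar case: plain Euclidean distance per edge
    return np.linalg.norm(u_coords - v_coords, axis=1)

pairs = np.array([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # orthogonal unit vectors
    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]],  # identical points
])
print(compute_distance_vectorized(pairs, is_sphere=True))   # [1.5708..., 0.0]
print(compute_distance_vectorized(pairs, is_sphere=False))  # [1.4142..., 0.0]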
@@ -87,23 +72,23 @@ def _map_to_sphere(G: nx.Graph) -> None:
     Args:
         G (nx.Graph): The input graph with nodes having 'x' and 'y' coordinates.
     """
-    # Extract x, y coordinates from the graph nodes
-    xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in G.nodes()])
-    # Normalize the coordinates between [0, 1]
-    min_vals = np.min(xy_coords, axis=0)
-    max_vals = np.max(xy_coords, axis=0)
+    # Extract x, y coordinates as a NumPy array
+    nodes = list(G.nodes)
+    xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in nodes])
+    # Normalize coordinates between [0, 1]
+    min_vals = xy_coords.min(axis=0)
+    max_vals = xy_coords.max(axis=0)
     normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)
-    # Map normalized coordinates to theta and phi on a sphere
+    # Convert normalized coordinates to spherical coordinates
     theta = normalized_xy[:, 0] * np.pi * 2
     phi = normalized_xy[:, 1] * np.pi
-    # Convert spherical coordinates to Cartesian coordinates for 3D sphere
-    for i, node in enumerate(G.nodes()):
-        x = np.sin(phi[i]) * np.cos(theta[i])
-        y = np.sin(phi[i]) * np.sin(theta[i])
-        z = np.cos(phi[i])
-        G.nodes[node]["x"] = x
-        G.nodes[node]["y"] = y
-        G.nodes[node]["z"] = z
+    # Compute 3D Cartesian coordinates
+    x = np.sin(phi) * np.cos(theta)
+    y = np.sin(phi) * np.sin(theta)
+    z = np.cos(phi)
+    # Assign coordinates back to graph nodes in bulk
+    xyz_coords = {node: {"x": x[i], "y": y[i], "z": z[i]} for i, node in enumerate(nodes)}
+    nx.set_node_attributes(G, xyz_coords)
 
 
 def _normalize_graph_coordinates(G: nx.Graph) -> None:
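The rewritten _map_to_sphere computes all angles at once instead of looping node by node. The same arithmetic on three toy layout points, standalone:

import numpy as np

xy_coords = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])  # toy 2D layout coordinates
min_vals, max_vals = xy_coords.min(axis=0), xy_coords.max(axis=0)
normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)  # scale each axis to [0, 1]
theta = normalized_xy[:, 0] * np.pi * 2  # azimuth in [0, 2*pi]
phi = normalized_xy[:, 1] * np.pi        # polar angle in [0, pi]
xyz = np.column_stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)])
print(np.linalg.norm(xyz, axis=1))  # [1. 1. 1.]: every point lands on the unit sphere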
@@ -124,22 +109,6 @@ def _normalize_graph_coordinates(G: nx.Graph) -> None:
         G.nodes[node]["x"], G.nodes[node]["y"] = normalized_xy[i]
 
 
-def _normalize_weights(G: nx.Graph) -> None:
-    """Normalize the weights of the edges in the graph.
-
-    Args:
-        G (nx.Graph): The input graph with weighted edges.
-    """
-    # "weight" is present for all edges - weights are 1.0 if weight was not specified by the user
-    weights = [data["weight"] for _, _, data in G.edges(data=True)]
-    if weights:  # Ensure there are weighted edges
-        min_weight = min(weights)
-        max_weight = max(weights)
-        range_weight = max_weight - min_weight if max_weight > min_weight else 1
-        for _, _, data in G.edges(data=True):
-            data["normalized_weight"] = (data["weight"] - min_weight) / range_weight
-
-
 def _create_depth(G: nx.Graph, surface_depth: float = 0.0) -> nx.Graph:
     """Adjust the 'z' attribute of each node based on the subcluster strengths and normalized surface depth.
 
@@ -151,18 +120,31 @@ def _create_depth(G: nx.Graph, surface_depth: float = 0.0) -> nx.Graph:
         nx.Graph: The graph with adjusted 'z' attribute for each node.
     """
     if surface_depth >= 1.0:
-        surface_depth = surface_depth - 1e-6  # Cap the surface depth to prevent value of 1.0
-
-    # Compute subclusters as connected components (subclusters can be any other method)
-    subclusters = {node: set(nx.node_connected_component(G, node)) for node in G.nodes}
-    # Create a strength metric for subclusters (here using size)
-    subcluster_strengths = {node: len(neighbors) for node, neighbors in subclusters.items()}
-    # Normalize the subcluster strengths and apply depths
-    max_strength = max(subcluster_strengths.values())
-    for node, strength in subcluster_strengths.items():
+        surface_depth -= 1e-6  # Cap the surface depth to prevent a value of 1.0
+
+    # Compute subclusters as connected components
+    connected_components = list(nx.connected_components(G))
+    subcluster_strengths = {}
+    max_strength = 0
+    # Precompute strengths and track the maximum strength
+    for component in connected_components:
+        size = len(component)
+        max_strength = max(max_strength, size)
+        for node in component:
+            subcluster_strengths[node] = size
+
+    # Avoid repeated lookups and computations by pre-fetching node data
+    nodes = list(G.nodes(data=True))
+    node_updates = {}
+    for node, attrs in nodes:
+        strength = subcluster_strengths[node]
         normalized_surface_depth = (strength / max_strength) * surface_depth
-        x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
+        x, y, z = attrs["x"], attrs["y"], attrs["z"]
         norm = np.sqrt(x**2 + y**2 + z**2)
-        G.nodes[node]["z"] -= (z / norm) * normalized_surface_depth  # Adjust Z for a depth
+        adjusted_z = z - (z / norm) * normalized_surface_depth
+        node_updates[node] = {"z": adjusted_z}
+
+    # Batch update node attributes
+    nx.set_node_attributes(G, node_updates)
 
     return G
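Taken together, assign_edge_lengths now expects only 'x' and 'y' node attributes and writes a 'length' attribute on every edge. A hypothetical usage sketch (the import path follows the file shown above; the layout and parameter values are illustrative):

import networkx as nx
from risk.network.geometry import assign_edge_lengths

G = nx.karate_club_graph()
# Seed node positions with any 2D layout; only "x" and "y" attributes are required
for node, (x, y) in nx.spring_layout(G, seed=0).items():
    G.nodes[node]["x"], G.nodes[node]["y"] = float(x), float(y)

G = assign_edge_lengths(G, compute_sphere=True, surface_depth=0.1)
print(list(G.edges(data="length"))[:3])  # each edge now carries a "length" value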
risk/network/graph/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""
+risk/network/graph
+~~~~~~~~~~~~~~~~~~
+"""
+
+from risk.network.graph.api import GraphAPI
risk/network/graph/api.py ADDED
@@ -0,0 +1,200 @@
+"""
+risk/network/graph/api
+~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+import copy
+from typing import Any, Dict, Union
+
+import networkx as nx
+import pandas as pd
+
+from risk.annotations import define_top_annotations
+from risk.log import logger, log_header, params
+from risk.neighborhoods import (
+    define_domains,
+    process_neighborhoods,
+    trim_domains,
+)
+from risk.network.graph.graph import Graph
+from risk.stats import calculate_significance_matrices
+
+
+class GraphAPI:
+    """Handles the loading of network graphs and associated data.
+
+    The GraphAPI class provides methods to load and process network graphs, annotations, and neighborhoods.
+    """
+
+    def __init__() -> None:
+        pass
+
+    def load_graph(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
+        tail: str = "right",
+        pval_cutoff: float = 0.01,
+        fdr_cutoff: float = 0.9999,
+        impute_depth: int = 0,
+        prune_threshold: float = 0.0,
+        linkage_criterion: str = "distance",
+        linkage_method: str = "average",
+        linkage_metric: str = "yule",
+        linkage_threshold: Union[float, str] = 0.2,
+        min_cluster_size: int = 5,
+        max_cluster_size: int = 1000,
+    ) -> Graph:
+        """Load and process the network graph, defining top annotations and domains.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (Dict[str, Any]): The annotations associated with the network.
+            neighborhoods (Dict[str, Any]): Neighborhood significance data.
+            tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
+            pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
+            fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
+            impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
+            prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
+            linkage_criterion (str, optional): Clustering criterion for defining domains. Defaults to "distance".
+            linkage_method (str, optional): Clustering method to use. Choose "auto" to optimize. Defaults to "average".
+            linkage_metric (str, optional): Metric to use for calculating distances. Choose "auto" to optimize.
+                Defaults to "yule".
+            linkage_threshold (float, str, optional): Threshold for clustering. Choose "auto" to optimize.
+                Defaults to 0.2.
+            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
+            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
+
+        Returns:
+            Graph: A fully initialized and processed Graph object.
+        """
+        # Log the parameters and display headers
+        log_header("Finding significant neighborhoods")
+        params.log_graph(
+            tail=tail,
+            pval_cutoff=pval_cutoff,
+            fdr_cutoff=fdr_cutoff,
+            impute_depth=impute_depth,
+            prune_threshold=prune_threshold,
+            linkage_criterion=linkage_criterion,
+            linkage_method=linkage_method,
+            linkage_metric=linkage_metric,
+            linkage_threshold=linkage_threshold,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
+
+        # Make a copy of the network to avoid modifying the original
+        network = copy.deepcopy(network)
+
+        logger.debug(f"p-value cutoff: {pval_cutoff}")
+        logger.debug(f"FDR BH cutoff: {fdr_cutoff}")
+        logger.debug(
+            f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
+        )
+        # Calculate significant neighborhoods based on the provided parameters
+        significant_neighborhoods = calculate_significance_matrices(
+            neighborhoods["depletion_pvals"],
+            neighborhoods["enrichment_pvals"],
+            tail=tail,
+            pval_cutoff=pval_cutoff,
+            fdr_cutoff=fdr_cutoff,
+        )
+
+        log_header("Processing neighborhoods")
+        # Process neighborhoods by imputing and pruning based on the given settings
+        processed_neighborhoods = process_neighborhoods(
+            network=network,
+            neighborhoods=significant_neighborhoods,
+            impute_depth=impute_depth,
+            prune_threshold=prune_threshold,
+        )
+
+        log_header("Finding top annotations")
+        logger.debug(f"Min cluster size: {min_cluster_size}")
+        logger.debug(f"Max cluster size: {max_cluster_size}")
+        # Define top annotations based on processed neighborhoods
+        top_annotations = self._define_top_annotations(
+            network=network,
+            annotations=annotations,
+            neighborhoods=processed_neighborhoods,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
+
+        log_header("Optimizing distance threshold for domains")
+        # Extract the significant significance matrix from the neighborhoods data
+        significant_neighborhoods_significance = processed_neighborhoods[
+            "significant_significance_matrix"
+        ]
+        # Define domains in the network using the specified clustering settings
+        domains = define_domains(
+            top_annotations=top_annotations,
+            significant_neighborhoods_significance=significant_neighborhoods_significance,
+            linkage_criterion=linkage_criterion,
+            linkage_method=linkage_method,
+            linkage_metric=linkage_metric,
+            linkage_threshold=linkage_threshold,
+        )
+        # Trim domains and top annotations based on cluster size constraints
+        domains, trimmed_domains = trim_domains(
+            domains=domains,
+            top_annotations=top_annotations,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
+
+        # Prepare node mapping and significance sums for the final Graph object
+        ordered_nodes = annotations["ordered_nodes"]
+        node_label_to_id = dict(zip(ordered_nodes, range(len(ordered_nodes))))
+        node_significance_sums = processed_neighborhoods["node_significance_sums"]
+
+        # Return the fully initialized Graph object
+        return Graph(
+            network=network,
+            annotations=annotations,
+            neighborhoods=neighborhoods,
+            domains=domains,
+            trimmed_domains=trimmed_domains,
+            node_label_to_node_id_map=node_label_to_id,
+            node_significance_sums=node_significance_sums,
+        )
+
+    def _define_top_annotations(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
+        min_cluster_size: int = 5,
+        max_cluster_size: int = 1000,
+    ) -> pd.DataFrame:
+        """Define top annotations for the network.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (Dict[str, Any]): Annotations data for the network.
+            neighborhoods (Dict[str, Any]): Neighborhood significance data.
+            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
+            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
+
+        Returns:
+            Dict[str, Any]: Top annotations identified within the network.
+        """
+        # Extract necessary data from annotations and neighborhoods
+        ordered_annotations = annotations["ordered_annotations"]
+        neighborhood_significance_sums = neighborhoods["neighborhood_significance_counts"]
+        significant_significance_matrix = neighborhoods["significant_significance_matrix"]
+        significant_binary_significance_matrix = neighborhoods[
+            "significant_binary_significance_matrix"
+        ]
+        # Call external function to define top annotations
+        return define_top_annotations(
+            network=network,
+            ordered_annotation_labels=ordered_annotations,
+            neighborhood_significance_sums=neighborhood_significance_sums,
+            significant_significance_matrix=significant_significance_matrix,
+            significant_binary_significance_matrix=significant_binary_significance_matrix,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
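Before constructing the Graph object, load_graph builds a node label to node ID mapping from the annotation order, and Graph later inverts it. The two mappings reduced to a toy, standalone snippet (the labels below are made up):

ordered_nodes = ["nodeA", "nodeB", "nodeC"]  # stands in for annotations["ordered_nodes"]
node_label_to_id = dict(zip(ordered_nodes, range(len(ordered_nodes))))
node_id_to_label = {v: k for k, v in node_label_to_id.items()}
print(node_label_to_id)  # {'nodeA': 0, 'nodeB': 1, 'nodeC': 2}
print(node_id_to_label)  # {0: 'nodeA', 1: 'nodeB', 2: 'nodeC'}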
risk/network/{graph.py → graph/graph.py} RENAMED
@@ -1,6 +1,6 @@
 """
-risk/network/graph
-~~~~~~~~~~~~~~~~~~
+risk/network/graph/graph
+~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from collections import defaultdict
@@ -10,60 +10,93 @@ import networkx as nx
 import numpy as np
 import pandas as pd
 
+from risk.network.graph.summary import Summary
 
-class NetworkGraph:
+
+class Graph:
     """A class to represent a network graph and process its nodes and edges.
 
-    The NetworkGraph class provides functionality to handle and manipulate a network graph,
-    including managing domains, annotations, and node enrichment data. It also includes methods
+    The Graph class provides functionality to handle and manipulate a network graph,
+    including managing domains, annotations, and node significance data. It also includes methods
     for transforming and mapping graph coordinates, as well as generating colors based on node
-    enrichment.
+    significance.
     """
 
     def __init__(
         self,
         network: nx.Graph,
-        top_annotations: pd.DataFrame,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
         domains: pd.DataFrame,
         trimmed_domains: pd.DataFrame,
         node_label_to_node_id_map: Dict[str, Any],
-        node_enrichment_sums: np.ndarray,
+        node_significance_sums: np.ndarray,
    ):
-        """Initialize the NetworkGraph object.
+        """Initialize the Graph object.
 
         Args:
             network (nx.Graph): The network graph.
-            top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
+            annotations (Dict[str, Any]): The annotations associated with the network.
+            neighborhoods (Dict[str, Any]): Neighborhood significance data.
             domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
             trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
             node_label_to_node_id_map (Dict[str, Any]): A dictionary mapping node labels to their corresponding IDs.
-            node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
+            node_significance_sums (np.ndarray): Array containing the significant sums for the nodes.
         """
-        self.top_annotations = top_annotations
+        # Initialize self.network downstream of the other attributes
+        # All public attributes can be accessed after initialization
         self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
-        self.domains = domains
         self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
             trimmed_domains
         )
         self.domain_id_to_domain_info_map = self._create_domain_id_to_domain_info_map(
             trimmed_domains
         )
-        self.trimmed_domains = trimmed_domains
-        self.node_enrichment_sums = node_enrichment_sums
-        self.node_id_to_domain_ids_and_enrichments_map = (
-            self._create_node_id_to_domain_ids_and_enrichments(domains)
+        self.node_id_to_domain_ids_and_significance_map = (
+            self._create_node_id_to_domain_ids_and_significances(domains)
         )
         self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
-        self.node_label_to_enrichment_map = dict(
-            zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
+        self.node_label_to_significance_map = dict(
+            zip(node_label_to_node_id_map.keys(), node_significance_sums)
         )
+        self.node_significance_sums = node_significance_sums
         self.node_label_to_node_id_map = node_label_to_node_id_map
+
         # NOTE: Below this point, instance attributes (i.e., self) will be used!
         self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
         # Unfold the network's 3D coordinates to 2D and extract node coordinates
         self.network = _unfold_sphere_to_plane(network)
         self.node_coordinates = _extract_node_coordinates(self.network)
 
+        # NOTE: Only after the above attributes are initialized, we can create the summary
+        self.summary = Summary(annotations, neighborhoods, self)
+
+    def pop(self, domain_id: str) -> None:
+        """Remove domain ID from instance domain ID mappings. This can be useful for cleaning up
+        domain-specific mappings based on a given criterion, as domain attributes are stored and
+        accessed only in dictionaries modified by this method.
+
+        Args:
+            key (str): The domain ID key to be removed from each mapping.
+        """
+        # Define the domain mappings to be updated
+        domain_mappings = [
+            self.domain_id_to_node_ids_map,
+            self.domain_id_to_domain_terms_map,
+            self.domain_id_to_domain_info_map,
+            self.domain_id_to_node_labels_map,
+        ]
+        # Remove the specified domain_id key from each mapping if it exists
+        for mapping in domain_mappings:
+            if domain_id in mapping:
+                mapping.pop(domain_id)
+
+        # Remove the domain_id from the node_id_to_domain_ids_and_significance_map
+        for _, domain_info in self.node_id_to_domain_ids_and_significance_map.items():
+            if domain_id in domain_info["domains"]:
+                domain_info["domains"].remove(domain_id)
+                domain_info["significances"].pop(domain_id)
+
     @staticmethod
     def _create_domain_id_to_node_ids_map(domains: pd.DataFrame) -> Dict[int, Any]:
         """Create a mapping from domains to the list of node IDs belonging to each domain.
@@ -103,25 +136,42 @@ class NetworkGraph:
     def _create_domain_id_to_domain_info_map(
         trimmed_domains: pd.DataFrame,
     ) -> Dict[int, Dict[str, Any]]:
-        """Create a mapping from domain IDs to their corresponding full description and enrichment score.
+        """Create a mapping from domain IDs to their corresponding full description and significance score,
+        with scores sorted in descending order.
 
         Args:
-            trimmed_domains (pd.DataFrame): DataFrame containing domain IDs, full descriptions, and enrichment scores.
+            trimmed_domains (pd.DataFrame): DataFrame containing domain IDs, full descriptions, and significance scores.
 
         Returns:
-            Dict[int, Dict[str, Any]]: A dictionary mapping domain IDs (int) to a dictionary with 'full_descriptions' and 'enrichment_scores'.
+            Dict[int, Dict[str, Any]]: A dictionary mapping domain IDs (int) to a dictionary with 'full_descriptions' and
+                'significance_scores', both sorted by significance score in descending order.
         """
-        return {
-            int(id_): {
-                "full_descriptions": trimmed_domains.at[id_, "full_descriptions"],
-                "enrichment_scores": trimmed_domains.at[id_, "enrichment_scores"],
+        # Initialize an empty dictionary to store full descriptions and significance scores of domains
+        domain_info_map = {}
+        # Domain IDs are the index of the DataFrame (it's common for some IDs to be missing)
+        for domain_id in trimmed_domains.index:
+            # Sort full_descriptions and significance_scores by significance_scores in descending order
+            descriptions_and_scores = sorted(
+                zip(
+                    trimmed_domains.at[domain_id, "full_descriptions"],
+                    trimmed_domains.at[domain_id, "significance_scores"],
+                ),
+                key=lambda x: x[1],  # Sort by significance score
+                reverse=True,  # Descending order
+            )
+            # Unzip the sorted tuples back into separate lists
+            sorted_descriptions, sorted_scores = zip(*descriptions_and_scores)
+            # Assign to the domain info map
+            domain_info_map[int(domain_id)] = {
+                "full_descriptions": list(sorted_descriptions),
+                "significance_scores": list(sorted_scores),
             }
-            for id_ in trimmed_domains.index
-        }
+
+        return domain_info_map
 
     @staticmethod
-    def _create_node_id_to_domain_ids_and_enrichments(domains: pd.DataFrame) -> Dict[int, Dict]:
-        """Creates a dictionary mapping each node ID to its corresponding domain IDs and enrichment values.
+    def _create_node_id_to_domain_ids_and_significances(domains: pd.DataFrame) -> Dict[int, Dict]:
+        """Creates a dictionary mapping each node ID to its corresponding domain IDs and significance values.
 
         Args:
             domains (pd.DataFrame): A DataFrame containing domain information for each node. Assumes the last
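The reworked _create_domain_id_to_domain_info_map keeps descriptions and scores paired while ordering them by score. The core of that sort on toy lists, standalone:

full_descriptions = ["term A", "term B", "term C"]  # toy values
significance_scores = [0.2, 0.9, 0.5]
descriptions_and_scores = sorted(
    zip(full_descriptions, significance_scores),
    key=lambda x: x[1],  # sort by significance score
    reverse=True,        # descending order
)
sorted_descriptions, sorted_scores = zip(*descriptions_and_scores)
print(list(sorted_descriptions))  # ['term B', 'term C', 'term A']
print(list(sorted_scores))        # [0.9, 0.5, 0.2]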
@@ -129,28 +179,28 @@
 
         Returns:
             Dict[int, Dict]: A dictionary where the key is the node ID (index of the DataFrame), and the value is another dictionary
-                with 'domain' (a list of domain IDs with non-zero enrichment) and 'enrichment'
-                (a dict of domain IDs and their corresponding enrichment values).
+                with 'domain' (a list of domain IDs with non-zero significance) and 'significance'
+                (a dict of domain IDs and their corresponding significance values).
         """
         # Initialize an empty dictionary to store the result
-        node_id_to_domain_ids_and_enrichments = {}
+        node_id_to_domain_ids_and_significances = {}
         # Get the list of domain columns (excluding 'all domains' and 'primary domain')
         domain_columns = domains.columns[
             :-2
         ]  # The last two columns are 'all domains' and 'primary domain'
         # Iterate over each row in the dataframe
         for idx, row in domains.iterrows():
-            # Get the domains (column names) where the enrichment score is greater than 0
+            # Get the domains (column names) where the significance score is greater than 0
             all_domains = domain_columns[row[domain_columns] > 0].tolist()
-            # Get the enrichment values for those domains
-            enrichment_values = row[all_domains].to_dict()
+            # Get the significance values for those domains
+            significance_values = row[all_domains].to_dict()
             # Store the result in the dictionary with index as the key
-            node_id_to_domain_ids_and_enrichments[idx] = {
-                "domains": all_domains,  # The column names where enrichment > 0
-                "enrichments": enrichment_values,  # The actual enrichment values for those columns
+            node_id_to_domain_ids_and_significances[idx] = {
+                "domains": all_domains,  # The column names where significance > 0
+                "significances": significance_values,  # The actual significance values for those columns
             }
 
-        return node_id_to_domain_ids_and_enrichments
+        return node_id_to_domain_ids_and_significances
 
     def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
         """Create a map from domain IDs to node labels.