risk_network-0.0.8b18-py3-none-any.whl → risk_network-0.0.9b26-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +133 -72
  4. risk/annotations/io.py +50 -34
  5. risk/log/__init__.py +4 -2
  6. risk/log/{config.py → console.py} +5 -3
  7. risk/log/{params.py → parameters.py} +21 -46
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +446 -0
  10. risk/neighborhoods/community.py +281 -96
  11. risk/neighborhoods/domains.py +92 -38
  12. risk/neighborhoods/neighborhoods.py +210 -149
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +69 -58
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +194 -0
  17. risk/network/graph/network.py +269 -0
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +58 -48
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +80 -26
  23. risk/network/{plot → plotter}/contour.py +43 -34
  24. risk/network/{plot → plotter}/labels.py +123 -113
  25. risk/network/plotter/network.py +424 -0
  26. risk/network/plotter/utils/colors.py +416 -0
  27. risk/network/plotter/utils/layout.py +94 -0
  28. risk/risk.py +11 -469
  29. risk/stats/__init__.py +8 -4
  30. risk/stats/binom.py +51 -0
  31. risk/stats/chi2.py +69 -0
  32. risk/stats/hypergeom.py +28 -18
  33. risk/stats/permutation/__init__.py +1 -1
  34. risk/stats/permutation/permutation.py +45 -39
  35. risk/stats/permutation/test_functions.py +25 -17
  36. risk/stats/poisson.py +17 -11
  37. risk/stats/stats.py +20 -16
  38. risk/stats/zscore.py +68 -0
  39. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
  40. risk_network-0.0.9b26.dist-info/RECORD +44 -0
  41. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
  42. risk/network/graph.py +0 -159
  43. risk/network/plot/__init__.py +0 -6
  44. risk/network/plot/network.py +0 -282
  45. risk/network/plot/plotter.py +0 -137
  46. risk/network/plot/utils/color.py +0 -353
  47. risk/network/plot/utils/layout.py +0 -53
  48. risk_network-0.0.8b18.dist-info/RECORD +0 -37
  49. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
  50. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
risk/network/geometry.py CHANGED
@@ -3,6 +3,8 @@ risk/network/geometry
 ~~~~~~~~~~~~~~~~~~~~~
 """
 
+import copy
+
 import networkx as nx
 import numpy as np
 
@@ -25,54 +27,50 @@ def assign_edge_lengths(
         nx.Graph: The graph with applied edge lengths.
     """
 
-    def compute_distance(
-        u_coords: np.ndarray, v_coords: np.ndarray, is_sphere: bool = False
-    ) -> float:
-        """Compute the distance between two coordinate vectors.
-
-        Args:
-            u_coords (np.ndarray): Coordinates of the first point.
-            v_coords (np.ndarray): Coordinates of the second point.
-            is_sphere (bool, optional): If True, compute spherical distance. Defaults to False.
-
-        Returns:
-            float: The computed distance between the two points.
-        """
+    def compute_distance_vectorized(coords, is_sphere):
+        """Compute distances between pairs of coordinates."""
+        u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
         if is_sphere:
-            # Normalize vectors and compute spherical distance using the dot product
-            u_coords /= np.linalg.norm(u_coords)
-            v_coords /= np.linalg.norm(v_coords)
-            return np.arccos(np.clip(np.dot(u_coords, v_coords), -1.0, 1.0))
-        else:
-            # Compute Euclidean distance
-            return np.linalg.norm(u_coords - v_coords)
+            u_norm = np.linalg.norm(u_coords, axis=1, keepdims=True)
+            v_norm = np.linalg.norm(v_coords, axis=1, keepdims=True)
+            u_coords /= u_norm
+            v_coords /= v_norm
+            dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
+            return np.arccos(np.clip(dot_products, -1.0, 1.0))
+
+        return np.linalg.norm(u_coords - v_coords, axis=1)
 
-    # Normalize graph coordinates
+    # Normalize graph coordinates and weights
     _normalize_graph_coordinates(G)
-    # Normalize weights
     _normalize_weights(G)
-    # Use G_depth for edge length calculation
+    # Map nodes to sphere and adjust depth if required
    if compute_sphere:
-        # Map to sphere and adjust depth
        _map_to_sphere(G)
-        G_depth = _create_depth(G.copy(), surface_depth=surface_depth)
+        G_depth = _create_depth(copy.deepcopy(G), surface_depth=surface_depth)
    else:
-        # Calculate edge lengths directly on the plane
-        G_depth = G.copy()
+        G_depth = copy.deepcopy(G)
 
-    for u, v, _ in G_depth.edges(data=True):
+    # Precompute edge coordinate arrays for vectorized computation
+    edge_data = []
+    for u, v in G_depth.edges:
        u_coords = np.array([G_depth.nodes[u]["x"], G_depth.nodes[u]["y"]])
        v_coords = np.array([G_depth.nodes[v]["x"], G_depth.nodes[v]["y"]])
        if compute_sphere:
            u_coords = np.append(u_coords, G_depth.nodes[u].get("z", 0))
            v_coords = np.append(v_coords, G_depth.nodes[v].get("z", 0))
-
-        distance = compute_distance(u_coords, v_coords, is_sphere=compute_sphere)
+        edge_data.append([u_coords, v_coords, (u, v)])
+
+    # Convert to numpy for faster processing
+    edge_coords = np.array([(e[0], e[1]) for e in edge_data])
+    edge_indices = [e[2] for e in edge_data]
+    # Compute distances in bulk
+    distances = compute_distance_vectorized(edge_coords, compute_sphere)
+    # Assign distances back to the graph
+    for (u, v), distance in zip(edge_indices, distances):
        if include_edge_weight:
-            # Square root of the normalized weight is used to minimize the effect of large weights
-            G.edges[u, v]["length"] = distance / np.sqrt(G.edges[u, v]["normalized_weight"] + 1e-6)
+            weight = G.edges[u, v].get("normalized_weight", 0) + 1e-6
+            G.edges[u, v]["length"] = distance / np.sqrt(weight)
        else:
-            # Use calculated distance directly
            G.edges[u, v]["length"] = distance
 
    return G
@@ -84,23 +82,23 @@ def _map_to_sphere(G: nx.Graph) -> None:
    Args:
        G (nx.Graph): The input graph with nodes having 'x' and 'y' coordinates.
    """
-    # Extract x, y coordinates from the graph nodes
-    xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in G.nodes()])
-    # Normalize the coordinates between [0, 1]
-    min_vals = np.min(xy_coords, axis=0)
-    max_vals = np.max(xy_coords, axis=0)
+    # Extract x, y coordinates as a NumPy array
+    nodes = list(G.nodes)
+    xy_coords = np.array([[G.nodes[node]["x"], G.nodes[node]["y"]] for node in nodes])
+    # Normalize coordinates between [0, 1]
+    min_vals = xy_coords.min(axis=0)
+    max_vals = xy_coords.max(axis=0)
    normalized_xy = (xy_coords - min_vals) / (max_vals - min_vals)
-    # Map normalized coordinates to theta and phi on a sphere
+    # Convert normalized coordinates to spherical coordinates
    theta = normalized_xy[:, 0] * np.pi * 2
    phi = normalized_xy[:, 1] * np.pi
-    # Convert spherical coordinates to Cartesian coordinates for 3D sphere
-    for i, node in enumerate(G.nodes()):
-        x = np.sin(phi[i]) * np.cos(theta[i])
-        y = np.sin(phi[i]) * np.sin(theta[i])
-        z = np.cos(phi[i])
-        G.nodes[node]["x"] = x
-        G.nodes[node]["y"] = y
-        G.nodes[node]["z"] = z
+    # Compute 3D Cartesian coordinates
+    x = np.sin(phi) * np.cos(theta)
+    y = np.sin(phi) * np.sin(theta)
+    z = np.cos(phi)
+    # Assign coordinates back to graph nodes in bulk
+    xyz_coords = {node: {"x": x[i], "y": y[i], "z": z[i]} for i, node in enumerate(nodes)}
+    nx.set_node_attributes(G, xyz_coords)
 
 
 def _normalize_graph_coordinates(G: nx.Graph) -> None:
@@ -148,18 +146,31 @@ def _create_depth(G: nx.Graph, surface_depth: float = 0.0) -> nx.Graph:
        nx.Graph: The graph with adjusted 'z' attribute for each node.
    """
    if surface_depth >= 1.0:
-        surface_depth = surface_depth - 1e-6  # Cap the surface depth to prevent value of 1.0
-
-    # Compute subclusters as connected components (subclusters can be any other method)
-    subclusters = {node: set(nx.node_connected_component(G, node)) for node in G.nodes}
-    # Create a strength metric for subclusters (here using size)
-    subcluster_strengths = {node: len(neighbors) for node, neighbors in subclusters.items()}
-    # Normalize the subcluster strengths and apply depths
-    max_strength = max(subcluster_strengths.values())
-    for node, strength in subcluster_strengths.items():
+        surface_depth -= 1e-6  # Cap the surface depth to prevent a value of 1.0
+
+    # Compute subclusters as connected components
+    connected_components = list(nx.connected_components(G))
+    subcluster_strengths = {}
+    max_strength = 0
+    # Precompute strengths and track the maximum strength
+    for component in connected_components:
+        size = len(component)
+        max_strength = max(max_strength, size)
+        for node in component:
+            subcluster_strengths[node] = size
+
+    # Avoid repeated lookups and computations by pre-fetching node data
+    nodes = list(G.nodes(data=True))
+    node_updates = {}
+    for node, attrs in nodes:
+        strength = subcluster_strengths[node]
        normalized_surface_depth = (strength / max_strength) * surface_depth
-        x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
+        x, y, z = attrs["x"], attrs["y"], attrs["z"]
        norm = np.sqrt(x**2 + y**2 + z**2)
-        G.nodes[node]["z"] -= (z / norm) * normalized_surface_depth  # Adjust Z for a depth
+        adjusted_z = z - (z / norm) * normalized_surface_depth
+        node_updates[node] = {"z": adjusted_z}
+
+    # Batch update node attributes
+    nx.set_node_attributes(G, node_updates)
 
    return G
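For readers skimming the diff: the rewrite above replaces one `compute_distance` call per edge with a single vectorized pass over all edges. A minimal standalone sketch of the same einsum-based arc-length computation, with illustrative shapes and values rather than code from the package:

import numpy as np

# coords has shape (n_edges, 2, 3): one (u, v) pair of 3D points per edge
coords = np.array([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],   # orthogonal points -> arc length pi/2
    [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]],  # antipodal points  -> arc length pi
])
u, v = coords[:, 0, :], coords[:, 1, :]
# Project both endpoints onto the unit sphere
u = u / np.linalg.norm(u, axis=1, keepdims=True)
v = v / np.linalg.norm(v, axis=1, keepdims=True)
# Row-wise dot products; clip guards against floating-point drift outside [-1, 1]
arc = np.arccos(np.clip(np.einsum("ij,ij->i", u, v), -1.0, 1.0))
print(arc)  # approximately [1.5708, 3.1416]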
risk/network/graph/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""
+risk/network/graph
+~~~~~~~~~~~~~~~~~~
+"""
+
+from risk.network.graph.api import GraphAPI
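With this `__init__.py` in place, the new subpackage re-exports its public entry point, so both import forms below should be equivalent (a usage note, not part of the diff):

from risk.network.graph import GraphAPI       # via the package re-export
from risk.network.graph.api import GraphAPI   # direct module path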
risk/network/graph/api.py ADDED
@@ -0,0 +1,194 @@
+"""
+risk/network/graph/api
+~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+import copy
+from typing import Any, Dict
+
+import networkx as nx
+import pandas as pd
+
+from risk.annotations import define_top_annotations
+from risk.log import logger, log_header, params
+from risk.neighborhoods import (
+    define_domains,
+    process_neighborhoods,
+    trim_domains,
+)
+from risk.network.graph.network import NetworkGraph
+from risk.stats import calculate_significance_matrices
+
+
+class GraphAPI:
+    """Handles the loading of network graphs and associated data.
+
+    The GraphAPI class provides methods to load and process network graphs, annotations, and neighborhoods.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def load_graph(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
+        tail: str = "right",
+        pval_cutoff: float = 0.01,
+        fdr_cutoff: float = 0.9999,
+        impute_depth: int = 0,
+        prune_threshold: float = 0.0,
+        linkage_criterion: str = "distance",
+        linkage_method: str = "average",
+        linkage_metric: str = "yule",
+        min_cluster_size: int = 5,
+        max_cluster_size: int = 1000,
+    ) -> NetworkGraph:
+        """Load and process the network graph, defining top annotations and domains.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (Dict[str, Any]): The annotations associated with the network.
+            neighborhoods (Dict[str, Any]): Neighborhood significance data.
+            tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
+            pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
+            fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
+            impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
+            prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
+            linkage_criterion (str, optional): Clustering criterion for defining domains. Defaults to "distance".
+            linkage_method (str, optional): Clustering method to use. Defaults to "average".
+            linkage_metric (str, optional): Metric to use for calculating distances. Defaults to "yule".
+            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
+            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
+
+        Returns:
+            NetworkGraph: A fully initialized and processed NetworkGraph object.
+        """
+        # Log the parameters and display headers
+        log_header("Finding significant neighborhoods")
+        params.log_graph(
+            tail=tail,
+            pval_cutoff=pval_cutoff,
+            fdr_cutoff=fdr_cutoff,
+            impute_depth=impute_depth,
+            prune_threshold=prune_threshold,
+            linkage_criterion=linkage_criterion,
+            linkage_method=linkage_method,
+            linkage_metric=linkage_metric,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
+
+        # Make a copy of the network to avoid modifying the original
+        network = copy.deepcopy(network)
+
+        logger.debug(f"p-value cutoff: {pval_cutoff}")
+        logger.debug(f"FDR BH cutoff: {fdr_cutoff}")
+        logger.debug(
+            f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
+        )
+        # Calculate significant neighborhoods based on the provided parameters
+        significant_neighborhoods = calculate_significance_matrices(
+            neighborhoods["depletion_pvals"],
+            neighborhoods["enrichment_pvals"],
+            tail=tail,
+            pval_cutoff=pval_cutoff,
+            fdr_cutoff=fdr_cutoff,
+        )
+
+        log_header("Processing neighborhoods")
+        # Process neighborhoods by imputing and pruning based on the given settings
+        processed_neighborhoods = process_neighborhoods(
+            network=network,
+            neighborhoods=significant_neighborhoods,
+            impute_depth=impute_depth,
+            prune_threshold=prune_threshold,
+        )
+
+        log_header("Finding top annotations")
+        logger.debug(f"Min cluster size: {min_cluster_size}")
+        logger.debug(f"Max cluster size: {max_cluster_size}")
+        # Define top annotations based on processed neighborhoods
+        top_annotations = self._define_top_annotations(
+            network=network,
+            annotations=annotations,
+            neighborhoods=processed_neighborhoods,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
+
+        log_header("Optimizing distance threshold for domains")
+        # Extract the significant significance matrix from the neighborhoods data
+        significant_neighborhoods_significance = processed_neighborhoods[
+            "significant_significance_matrix"
+        ]
+        # Define domains in the network using the specified clustering settings
+        domains = define_domains(
+            top_annotations=top_annotations,
+            significant_neighborhoods_significance=significant_neighborhoods_significance,
+            linkage_criterion=linkage_criterion,
+            linkage_method=linkage_method,
+            linkage_metric=linkage_metric,
+        )
+        # Trim domains and top annotations based on cluster size constraints
+        domains, trimmed_domains = trim_domains(
+            domains=domains,
+            top_annotations=top_annotations,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
+
+        # Prepare node mapping and significance sums for the final NetworkGraph object
+        ordered_nodes = annotations["ordered_nodes"]
+        node_label_to_id = dict(zip(ordered_nodes, range(len(ordered_nodes))))
+        node_significance_sums = processed_neighborhoods["node_significance_sums"]
+
+        # Return the fully initialized NetworkGraph object
+        return NetworkGraph(
+            network=network,
+            annotations=annotations,
+            neighborhoods=neighborhoods,
+            domains=domains,
+            trimmed_domains=trimmed_domains,
+            node_label_to_node_id_map=node_label_to_id,
+            node_significance_sums=node_significance_sums,
+        )
+
+    def _define_top_annotations(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
+        min_cluster_size: int = 5,
+        max_cluster_size: int = 1000,
+    ) -> pd.DataFrame:
+        """Define top annotations for the network.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (Dict[str, Any]): Annotations data for the network.
+            neighborhoods (Dict[str, Any]): Neighborhood significance data.
+            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
+            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
+
+        Returns:
+            pd.DataFrame: Top annotations identified within the network.
+        """
+        # Extract necessary data from annotations and neighborhoods
+        ordered_annotations = annotations["ordered_annotations"]
+        neighborhood_significance_sums = neighborhoods["neighborhood_significance_counts"]
+        significant_significance_matrix = neighborhoods["significant_significance_matrix"]
+        significant_binary_significance_matrix = neighborhoods[
+            "significant_binary_significance_matrix"
+        ]
+        # Call external function to define top annotations
+        return define_top_annotations(
+            network=network,
+            ordered_annotation_labels=ordered_annotations,
+            neighborhood_significance_sums=neighborhood_significance_sums,
+            significant_significance_matrix=significant_significance_matrix,
+            significant_binary_significance_matrix=significant_binary_significance_matrix,
+            min_cluster_size=min_cluster_size,
+            max_cluster_size=max_cluster_size,
+        )
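To make the new API concrete, here is a hypothetical calling sketch. It assumes `network`, `annotations`, and `neighborhoods` were produced by the package's upstream loading and analysis steps; in the released library the class is likely consumed through the top-level facade in risk/risk.py, so the direct instantiation here is purely illustrative:

from risk.network.graph.api import GraphAPI

graph_api = GraphAPI()
graph = graph_api.load_graph(
    network=network,              # nx.Graph from an upstream loader (assumed)
    annotations=annotations,      # annotations dict from an upstream loader (assumed)
    neighborhoods=neighborhoods,  # neighborhood significance dict (assumed)
    tail="right",                 # test for enrichment
    pval_cutoff=0.01,
    fdr_cutoff=0.9999,
    min_cluster_size=5,
    max_cluster_size=1000,
)
# `graph` is a NetworkGraph; see risk/network/graph/network.py below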
risk/network/graph/network.py ADDED
@@ -0,0 +1,269 @@
+"""
+risk/network/graph/network
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from collections import defaultdict
+from typing import Any, Dict, List
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+
+from risk.network.graph.summary import AnalysisSummary
+
+
+class NetworkGraph:
+    """A class to represent a network graph and process its nodes and edges.
+
+    The NetworkGraph class provides functionality to handle and manipulate a network graph,
+    including managing domains, annotations, and node significance data. It also includes methods
+    for transforming and mapping graph coordinates, as well as generating colors based on node
+    significance.
+    """
+
+    def __init__(
+        self,
+        network: nx.Graph,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
+        domains: pd.DataFrame,
+        trimmed_domains: pd.DataFrame,
+        node_label_to_node_id_map: Dict[str, Any],
+        node_significance_sums: np.ndarray,
+    ):
+        """Initialize the NetworkGraph object.
+
+        Args:
+            network (nx.Graph): The network graph.
+            annotations (Dict[str, Any]): The annotations associated with the network.
+            neighborhoods (Dict[str, Any]): Neighborhood significance data.
+            domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
+            trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
+            node_label_to_node_id_map (Dict[str, Any]): A dictionary mapping node labels to their corresponding IDs.
+            node_significance_sums (np.ndarray): Array containing the significant sums for the nodes.
+        """
+        # Initialize self.network downstream of the other attributes
+        # All public attributes can be accessed after initialization
+        self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
+        self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
+            trimmed_domains
+        )
+        self.domain_id_to_domain_info_map = self._create_domain_id_to_domain_info_map(
+            trimmed_domains
+        )
+        self.node_id_to_domain_ids_and_significance_map = (
+            self._create_node_id_to_domain_ids_and_significances(domains)
+        )
+        self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
+        self.node_label_to_significance_map = dict(
+            zip(node_label_to_node_id_map.keys(), node_significance_sums)
+        )
+        self.node_significance_sums = node_significance_sums
+        self.node_label_to_node_id_map = node_label_to_node_id_map
+
+        # NOTE: Below this point, instance attributes (i.e., self) will be used!
+        self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
+        # Unfold the network's 3D coordinates to 2D and extract node coordinates
+        self.network = _unfold_sphere_to_plane(network)
+        self.node_coordinates = _extract_node_coordinates(self.network)
+
+        # NOTE: Only after the above attributes are initialized, we can create the summary
+        self.summary = AnalysisSummary(annotations, neighborhoods, self)
+
+    def pop(self, domain_id: str) -> None:
+        """Remove a domain ID from the instance's domain ID mappings. This can be useful for cleaning up
+        domain-specific mappings based on a given criterion, as domain attributes are stored and
+        accessed only in dictionaries modified by this method.
+
+        Args:
+            domain_id (str): The domain ID key to be removed from each mapping.
+        """
+        # Define the domain mappings to be updated
+        domain_mappings = [
+            self.domain_id_to_node_ids_map,
+            self.domain_id_to_domain_terms_map,
+            self.domain_id_to_domain_info_map,
+            self.domain_id_to_node_labels_map,
+        ]
+        # Remove the specified domain_id key from each mapping if it exists
+        for mapping in domain_mappings:
+            if domain_id in mapping:
+                mapping.pop(domain_id)
+
+        # Remove the domain_id from the node_id_to_domain_ids_and_significance_map
+        for _, domain_info in self.node_id_to_domain_ids_and_significance_map.items():
+            if domain_id in domain_info["domains"]:
+                domain_info["domains"].remove(domain_id)
+                domain_info["significances"].pop(domain_id)
+
+    @staticmethod
+    def _create_domain_id_to_node_ids_map(domains: pd.DataFrame) -> Dict[int, Any]:
+        """Create a mapping from domains to the list of node IDs belonging to each domain.
+
+        Args:
+            domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
+
+        Returns:
+            Dict[int, Any]: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
+        """
+        cleaned_domains_matrix = domains.reset_index()[["index", "primary_domain"]]
+        node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary_domain"].to_dict()
+        domain_id_to_node_ids_map = defaultdict(list)
+        for k, v in node_to_domains_map.items():
+            domain_id_to_node_ids_map[v].append(k)
+
+        return domain_id_to_node_ids_map
+
+    @staticmethod
+    def _create_domain_id_to_domain_terms_map(trimmed_domains: pd.DataFrame) -> Dict[int, Any]:
+        """Create a mapping from domain IDs to their corresponding terms.
+
+        Args:
+            trimmed_domains (pd.DataFrame): DataFrame containing domain IDs and their corresponding labels.
+
+        Returns:
+            Dict[int, Any]: A dictionary mapping domain IDs to their corresponding terms.
+        """
+        return dict(
+            zip(
+                trimmed_domains.index,
+                trimmed_domains["normalized_description"],
+            )
+        )
+
+    @staticmethod
+    def _create_domain_id_to_domain_info_map(
+        trimmed_domains: pd.DataFrame,
+    ) -> Dict[int, Dict[str, Any]]:
+        """Create a mapping from domain IDs to their corresponding full description and significance score,
+        with scores sorted in descending order.
+
+        Args:
+            trimmed_domains (pd.DataFrame): DataFrame containing domain IDs, full descriptions, and significance scores.
+
+        Returns:
+            Dict[int, Dict[str, Any]]: A dictionary mapping domain IDs (int) to a dictionary with 'full_descriptions' and
+                'significance_scores', both sorted by significance score in descending order.
+        """
+        # Initialize an empty dictionary to store full descriptions and significance scores of domains
+        domain_info_map = {}
+        # Domain IDs are the index of the DataFrame (it's common for some IDs to be missing)
+        for domain_id in trimmed_domains.index:
+            # Sort full_descriptions and significance_scores by significance_scores in descending order
+            descriptions_and_scores = sorted(
+                zip(
+                    trimmed_domains.at[domain_id, "full_descriptions"],
+                    trimmed_domains.at[domain_id, "significance_scores"],
+                ),
+                key=lambda x: x[1],  # Sort by significance score
+                reverse=True,  # Descending order
+            )
+            # Unzip the sorted tuples back into separate lists
+            sorted_descriptions, sorted_scores = zip(*descriptions_and_scores)
+            # Assign to the domain info map
+            domain_info_map[int(domain_id)] = {
+                "full_descriptions": list(sorted_descriptions),
+                "significance_scores": list(sorted_scores),
+            }
+
+        return domain_info_map
+
+    @staticmethod
+    def _create_node_id_to_domain_ids_and_significances(domains: pd.DataFrame) -> Dict[int, Dict]:
+        """Create a dictionary mapping each node ID to its corresponding domain IDs and significance values.
+
+        Args:
+            domains (pd.DataFrame): A DataFrame containing domain information for each node. Assumes the last
+                two columns are 'all domains' and 'primary domain', which are excluded from processing.
+
+        Returns:
+            Dict[int, Dict]: A dictionary where the key is the node ID (index of the DataFrame), and the value is another dictionary
+                with 'domains' (a list of domain IDs with non-zero significance) and 'significances'
+                (a dict of domain IDs and their corresponding significance values).
+        """
+        # Initialize an empty dictionary to store the result
+        node_id_to_domain_ids_and_significances = {}
+        # Get the list of domain columns (excluding 'all domains' and 'primary domain')
+        domain_columns = domains.columns[
+            :-2
+        ]  # The last two columns are 'all domains' and 'primary domain'
+        # Iterate over each row in the dataframe
+        for idx, row in domains.iterrows():
+            # Get the domains (column names) where the significance score is greater than 0
+            all_domains = domain_columns[row[domain_columns] > 0].tolist()
+            # Get the significance values for those domains
+            significance_values = row[all_domains].to_dict()
+            # Store the result in the dictionary with index as the key
+            node_id_to_domain_ids_and_significances[idx] = {
+                "domains": all_domains,  # The column names where significance > 0
+                "significances": significance_values,  # The actual significance values for those columns
+            }
+
+        return node_id_to_domain_ids_and_significances
+
+    def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
+        """Create a map from domain IDs to node labels.
+
+        Returns:
+            Dict[int, List[str]]: A dictionary mapping domain IDs to the corresponding node labels.
+        """
+        domain_id_to_label_map = {}
+        for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
+            domain_id_to_label_map[domain_id] = [
+                self.node_id_to_node_label_map[node_id] for node_id in node_ids
+            ]
+
+        return domain_id_to_label_map
+
+
+def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
+    """Convert 3D coordinates to 2D by unfolding a sphere to a plane.
+
+    Args:
+        G (nx.Graph): A network graph with 3D coordinates. Each node should have 'x', 'y', and 'z' attributes.
+
+    Returns:
+        nx.Graph: The network graph with updated 2D coordinates (only 'x' and 'y').
+    """
+    for node in G.nodes():
+        if "z" in G.nodes[node]:
+            # Extract 3D coordinates
+            x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
+            # Calculate spherical coordinates theta and phi from Cartesian coordinates
+            r = np.sqrt(x**2 + y**2 + z**2)
+            theta = np.arctan2(y, x)
+            phi = np.arccos(z / r)
+
+            # Convert spherical coordinates to 2D plane coordinates
+            unfolded_x = (theta + np.pi) / (2 * np.pi)  # Shift and normalize theta to [0, 1]
+            unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
+            unfolded_y = (np.pi - phi) / np.pi  # Reflect phi and normalize to [0, 1]
+            # Update network node attributes
+            G.nodes[node]["x"] = unfolded_x
+            G.nodes[node]["y"] = -unfolded_y
+            # Remove the 'z' coordinate as it's no longer needed
+            del G.nodes[node]["z"]
+
+    return G
+
+
+def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
+    """Extract 2D coordinates of nodes from the graph.
+
+    Args:
+        G (nx.Graph): The network graph with node coordinates.
+
+    Returns:
+        np.ndarray: Array of node coordinates with shape (num_nodes, 2).
+    """
+    # Extract x and y coordinates from graph nodes
+    x_coords = dict(G.nodes.data("x"))
+    y_coords = dict(G.nodes.data("y"))
+    coordinates_dicts = [x_coords, y_coords]
+    # Combine x and y coordinates into a single array
+    node_positions = {
+        node: np.array([coords[node] for coords in coordinates_dicts]) for node in x_coords
+    }
+    node_coordinates = np.vstack(list(node_positions.values()))
+    return node_coordinates
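The `_unfold_sphere_to_plane` helper inverts the spherical mapping applied in geometry.py: it recovers theta and phi from the Cartesian coordinates, normalizes both to [0, 1], and wraps theta by half a turn. A small numeric check of that arithmetic, standalone NumPy and illustrative values only:

import numpy as np

# A unit-sphere point on the equator, flattened with the same arithmetic
# as _unfold_sphere_to_plane above
x, y, z = 0.0, 1.0, 0.0
r = np.sqrt(x**2 + y**2 + z**2)          # = 1.0
theta = np.arctan2(y, x)                 # = pi/2
phi = np.arccos(z / r)                   # = pi/2
ux = (theta + np.pi) / (2 * np.pi)       # = 0.75
ux = ux + 0.5 if ux < 0.5 else ux - 0.5  # wrap by half a turn: 0.75 -> 0.25
uy = (np.pi - phi) / np.pi               # = 0.5
print(ux, -uy)                           # (0.25, -0.5): the stored 2D coordinates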