risk-network 0.0.8b18__py3-none-any.whl → 0.0.8b20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.8-beta.18"
10
+ __version__ = "0.0.8-beta.20"
@@ -36,10 +36,11 @@ def load_annotations(network: nx.Graph, annotations_input: Dict[str, Any]) -> Di
36
36
  """Convert annotations input to a DataFrame and reindex based on the network's node labels.
37
37
 
38
38
  Args:
39
- annotations_input (dict): A dictionary with annotations.
39
+ network (nx.Graph): The network graph.
40
+ annotations_input (Dict[str, Any]): A dictionary with annotations.
40
41
 
41
42
  Returns:
42
- dict: A dictionary containing ordered nodes, ordered annotations, and the binary annotations matrix.
43
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the binary annotations matrix.
43
44
  """
44
45
  # Flatten the dictionary to a list of tuples for easier DataFrame creation
45
46
  flattened_annotations = [
@@ -255,7 +256,7 @@ def _generate_coherent_description(words: List[str]) -> str:
255
256
  If there is only one unique entry, return it directly.
256
257
 
257
258
  Args:
258
- words (list): A list of words or numerical string values.
259
+ words (List): A list of words or numerical string values.
259
260
 
260
261
  Returns:
261
262
  str: A coherent description formed by arranging the words in a logical sequence.
risk/annotations/io.py CHANGED
@@ -33,7 +33,7 @@ class AnnotationsIO:
33
33
  filepath (str): Path to the JSON annotations file.
34
34
 
35
35
  Returns:
36
- dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
36
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
37
37
  """
38
38
  filetype = "JSON"
39
39
  # Log the loading of the JSON file
@@ -158,10 +158,10 @@ class AnnotationsIO:
158
158
 
159
159
  Args:
160
160
  network (NetworkX graph): The network to which the annotations are related.
161
- content (dict): The annotations dictionary to load.
161
+ content (Dict[str, Any]): The annotations dictionary to load.
162
162
 
163
163
  Returns:
164
- dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
164
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
165
165
  """
166
166
  # Ensure the input content is a dictionary
167
167
  if not isinstance(content, dict):
risk/log/params.py CHANGED
@@ -159,7 +159,7 @@ class Params:
159
159
  """Load and process various parameters, converting any np.ndarray values to lists.
160
160
 
161
161
  Returns:
162
- dict: A dictionary containing the processed parameters.
162
+ Dict[str, Any]: A dictionary containing the processed parameters.
163
163
  """
164
164
  log_header("Loading parameters")
165
165
  return _convert_ndarray_to_list(
@@ -174,14 +174,14 @@ class Params:
174
174
  )
175
175
 
176
176
 
177
- def _convert_ndarray_to_list(d: Any) -> Any:
177
+ def _convert_ndarray_to_list(d: Dict[str, Any]) -> Dict[str, Any]:
178
178
  """Recursively convert all np.ndarray values in the dictionary to lists.
179
179
 
180
180
  Args:
181
- d (dict): The dictionary to process.
181
+ d (Dict[str, Any]): The dictionary to process.
182
182
 
183
183
  Returns:
184
- dict: The processed dictionary with np.ndarray values converted to lists.
184
+ Dict[str, Any]: The processed dictionary with np.ndarray values converted to lists.
185
185
  """
186
186
  if isinstance(d, dict):
187
187
  # Recursively process each value in the dictionary
@@ -193,5 +193,5 @@ def _convert_ndarray_to_list(d: Any) -> Any:
193
193
  # Convert numpy arrays to lists
194
194
  return d.tolist()
195
195
  else:
196
- # Return the value unchanged if it's not a dict, list, or ndarray
196
+ # Return the value unchanged if it's not a dict, list, or ndarray
197
197
  return d
@@ -21,15 +21,20 @@ def calculate_greedy_modularity_neighborhoods(network: nx.Graph) -> np.ndarray:
21
21
  """
22
22
  # Detect communities using the Greedy Modularity method
23
23
  communities = greedy_modularity_communities(network)
24
- # Create a mapping from node to community
25
- community_dict = {node: idx for idx, community in enumerate(communities) for node in community}
26
24
  # Create a binary neighborhood matrix
27
- neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=int)
25
+ n_nodes = network.number_of_nodes()
26
+ neighborhoods = np.zeros((n_nodes, n_nodes), dtype=int)
27
+ # Create a mapping from node to index in the matrix
28
28
  node_index = {node: i for i, node in enumerate(network.nodes())}
29
- for node_i, community_i in community_dict.items():
30
- for node_j, community_j in community_dict.items():
31
- if community_i == community_j:
32
- neighborhoods[node_index[node_i], node_index[node_j]] = 1
29
+ # Fill in the neighborhood matrix for nodes in the same community
30
+ for community in communities:
31
+ # Iterate through all pairs of nodes in the same community
32
+ for node_i in community:
33
+ idx_i = node_index[node_i]
34
+ for node_j in community:
35
+ idx_j = node_index[node_j]
36
+ # Set them as neighbors (1) in the binary matrix
37
+ neighborhoods[idx_i, idx_j] = 1
33
38
 
34
39
  return neighborhoods
35
40
 
@@ -43,22 +48,20 @@ def calculate_label_propagation_neighborhoods(network: nx.Graph) -> np.ndarray:
43
48
  Returns:
44
49
  np.ndarray: Binary neighborhood matrix on Label Propagation.
45
50
  """
46
- # Apply Label Propagation
51
+ # Apply Label Propagation for community detection
47
52
  communities = nx.algorithms.community.label_propagation.label_propagation_communities(network)
48
- # Create a mapping from node to community
49
- community_dict = {}
50
- for community_id, community in enumerate(communities):
51
- for node in community:
52
- community_dict[node] = community_id
53
-
54
53
  # Create a binary neighborhood matrix
55
54
  num_nodes = network.number_of_nodes()
56
55
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
56
+ # Create a mapping from node to index in the matrix
57
+ node_index = {node: i for i, node in enumerate(network.nodes())}
57
58
  # Assign neighborhoods based on community labels
58
- for node_i, community_i in community_dict.items():
59
- for node_j, community_j in community_dict.items():
60
- if community_i == community_j:
61
- neighborhoods[node_i, node_j] = 1
59
+ for community in communities:
60
+ for node_i in community:
61
+ idx_i = node_index[node_i]
62
+ for node_j in community:
63
+ idx_j = node_index[node_j]
64
+ neighborhoods[idx_i, idx_j] = 1
62
65
 
63
66
  return neighborhoods
64
67
 
@@ -81,12 +84,22 @@ def calculate_louvain_neighborhoods(
81
84
  network, resolution=resolution, random_state=random_seed
82
85
  )
83
86
  # Create a binary neighborhood matrix
84
- neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=int)
87
+ num_nodes = network.number_of_nodes()
88
+ neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
89
+ # Create a mapping from node to index in the matrix
90
+ node_index = {node: i for i, node in enumerate(network.nodes())}
91
+ # Group nodes by community
92
+ community_groups = {}
93
+ for node, community in partition.items():
94
+ community_groups.setdefault(community, []).append(node)
95
+
85
96
  # Assign neighborhoods based on community partitions
86
- for node_i, community_i in partition.items():
87
- for node_j, community_j in partition.items():
88
- if community_i == community_j:
89
- neighborhoods[node_i, node_j] = 1
97
+ for community, nodes in community_groups.items():
98
+ for node_i in nodes:
99
+ idx_i = node_index[node_i]
100
+ for node_j in nodes:
101
+ idx_j = node_index[node_j]
102
+ neighborhoods[idx_i, idx_j] = 1
90
103
 
91
104
  return neighborhoods
92
105
 
@@ -102,24 +115,22 @@ def calculate_markov_clustering_neighborhoods(network: nx.Graph) -> np.ndarray:
102
115
  """
103
116
  # Convert the graph to an adjacency matrix
104
117
  adjacency_matrix = nx.to_numpy_array(network)
105
- # Run Markov Clustering
106
- result = mc.run_mcl(adjacency_matrix) # Run MCL with default parameters
107
- # Get clusters
118
+ # Run Markov Clustering (MCL)
119
+ result = mc.run_mcl(adjacency_matrix) # MCL with default parameters
120
+ # Get clusters (communities) from MCL result
108
121
  clusters = mc.get_clusters(result)
109
- # Create a community label for each node
110
- community_dict = {}
111
- for community_id, community in enumerate(clusters):
112
- for node in community:
113
- community_dict[node] = community_id
114
-
115
122
  # Create a binary neighborhood matrix
116
123
  num_nodes = network.number_of_nodes()
117
124
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
118
- # Assign neighborhoods based on community labels
119
- for node_i, community_i in community_dict.items():
120
- for node_j, community_j in community_dict.items():
121
- if community_i == community_j:
122
- neighborhoods[node_i, node_j] = 1
125
+ # Create a mapping from node to index in the matrix
126
+ node_index = {node: i for i, node in enumerate(network.nodes())}
127
+ # Assign neighborhoods based on MCL clusters
128
+ for cluster in clusters:
129
+ for node_i in cluster:
130
+ idx_i = node_index[node_i]
131
+ for node_j in cluster:
132
+ idx_j = node_index[node_j]
133
+ neighborhoods[idx_i, idx_j] = 1
123
134
 
124
135
  return neighborhoods
125
136
 
@@ -133,22 +144,20 @@ def calculate_spinglass_neighborhoods(network: nx.Graph) -> np.ndarray:
133
144
  Returns:
134
145
  np.ndarray: Binary neighborhood matrix on Spin Glass communities.
135
146
  """
136
- # Use the asynchronous label propagation algorithm as a proxy for Spin Glass
147
+ # Apply Asynchronous Label Propagation (LPA)
137
148
  communities = asyn_lpa_communities(network)
138
- # Create a community label for each node
139
- community_dict = {}
140
- for community_id, community in enumerate(communities):
141
- for node in community:
142
- community_dict[node] = community_id
143
-
144
149
  # Create a binary neighborhood matrix
145
150
  num_nodes = network.number_of_nodes()
146
151
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
147
- # Assign neighborhoods based on community labels
148
- for node_i, community_i in community_dict.items():
149
- for node_j, community_j in community_dict.items():
150
- if community_i == community_j:
151
- neighborhoods[node_i, node_j] = 1
152
+ # Create a mapping from node to index in the matrix
153
+ node_index = {node: i for i, node in enumerate(network.nodes())}
154
+ # Assign neighborhoods based on community labels from LPA
155
+ for community in communities:
156
+ for node_i in community:
157
+ idx_i = node_index[node_i]
158
+ for node_j in community:
159
+ idx_j = node_index[node_j]
160
+ neighborhoods[idx_i, idx_j] = 1
152
161
 
153
162
  return neighborhoods
154
163
 
@@ -162,21 +171,19 @@ def calculate_walktrap_neighborhoods(network: nx.Graph) -> np.ndarray:
162
171
  Returns:
163
172
  np.ndarray: Binary neighborhood matrix on Walktrap communities.
164
173
  """
165
- # Use the asynchronous label propagation algorithm as a proxy for Walktrap
174
+ # Apply Asynchronous Label Propagation (LPA)
166
175
  communities = asyn_lpa_communities(network)
167
- # Create a community label for each node
168
- community_dict = {}
169
- for community_id, community in enumerate(communities):
170
- for node in community:
171
- community_dict[node] = community_id
172
-
173
176
  # Create a binary neighborhood matrix
174
177
  num_nodes = network.number_of_nodes()
175
178
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
176
- # Assign neighborhoods based on community labels
177
- for node_i, community_i in community_dict.items():
178
- for node_j, community_j in community_dict.items():
179
- if community_i == community_j:
180
- neighborhoods[node_i, node_j] = 1
179
+ # Create a mapping from node to index in the matrix
180
+ node_index = {node: i for i, node in enumerate(network.nodes())}
181
+ # Assign neighborhoods based on community labels from LPA
182
+ for community in communities:
183
+ for node_i in community:
184
+ idx_i = node_index[node_i]
185
+ for node_j in community:
186
+ idx_j = node_index[node_j]
187
+ neighborhoods[idx_i, idx_j] = 1
181
188
 
182
189
  return neighborhoods
@@ -76,6 +76,10 @@ def define_domains(
76
76
  t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
77
77
  t_idxmax[t_max == 0] = 0
78
78
 
79
+ # Assign all domains where the score is greater than 0
80
+ node_to_domain["all domains"] = node_to_domain.loc[:, 1:].apply(
81
+ lambda row: list(row[row > 0].index), axis=1
82
+ )
79
83
  # Assign primary domain
80
84
  node_to_domain["primary domain"] = t_idxmax
81
85
 
@@ -97,7 +101,7 @@ def trim_domains_and_top_annotations(
97
101
  max_cluster_size (int, optional): Maximum size of a cluster to be retained. Defaults to 1000.
98
102
 
99
103
  Returns:
100
- tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing:
104
+ Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing:
101
105
  - Trimmed annotations (pd.DataFrame)
102
106
  - Trimmed domains (pd.DataFrame)
103
107
  - A DataFrame with domain labels (pd.DataFrame)
@@ -154,7 +158,7 @@ def _optimize_silhouette_across_linkage_and_metrics(
154
158
  linkage_metric (str): Linkage metric for clustering.
155
159
 
156
160
  Returns:
157
- tuple[str, str, float]: A tuple containing:
161
+ Tuple[str, str, float]: A tuple containing:
158
162
  - Best linkage method (str)
159
163
  - Best linkage metric (str)
160
164
  - Best threshold (float)
@@ -208,7 +212,7 @@ def _find_best_silhouette_score(
208
212
  resolution (float, optional): Desired resolution for the best threshold. Defaults to 0.001.
209
213
 
210
214
  Returns:
211
- tuple[float, float]: A tuple containing:
215
+ Tuple[float, float]: A tuple containing:
212
216
  - Best threshold (float): The threshold that yields the best silhouette score.
213
217
  - Best silhouette score (float): The highest silhouette score achieved.
214
218
  """
@@ -5,7 +5,7 @@ risk/neighborhoods/neighborhoods
5
5
 
6
6
  import random
7
7
  import warnings
8
- from typing import Any, Dict, List, Tuple
8
+ from typing import Any, Dict, List, Tuple, Union
9
9
 
10
10
  import networkx as nx
11
11
  import numpy as np
@@ -28,50 +28,82 @@ warnings.filterwarnings(action="ignore", category=DataConversionWarning)
28
28
 
29
29
  def get_network_neighborhoods(
30
30
  network: nx.Graph,
31
- distance_metric: str = "louvain",
32
- edge_length_threshold: float = 1.0,
31
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
32
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 1.0,
33
33
  louvain_resolution: float = 1.0,
34
34
  random_seed: int = 888,
35
35
  ) -> np.ndarray:
36
- """Calculate the neighborhoods for each node in the network based on the specified distance metric.
36
+ """Calculate the combined neighborhoods for each node based on the specified community detection algorithm(s).
37
37
 
38
38
  Args:
39
39
  network (nx.Graph): The network graph.
40
- distance_metric (str): The distance metric to use ('greedy_modularity', 'louvain', 'label_propagation',
41
- 'markov_clustering', 'walktrap', 'spinglass').
42
- edge_length_threshold (float): The edge length threshold for the neighborhoods.
40
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
41
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
42
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
43
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
44
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
45
+ Defaults to 1.0.
43
46
  louvain_resolution (float, optional): Resolution parameter for the Louvain method. Defaults to 1.0.
44
47
  random_seed (int, optional): Random seed for methods requiring random initialization. Defaults to 888.
45
48
 
46
49
  Returns:
47
- np.ndarray: Neighborhood matrix calculated based on the selected distance metric.
50
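+ np.ndarray: Binary neighborhood matrix derived from the sum of all selected algorithms' matrices; in each row, the entries equal to that row's maximum are 1 and all others are 0.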
48
51
  """
49
- # Set random seed for reproducibility in all methods besides Louvain, which requires a separate seed
52
+ # Set random seed for reproducibility
50
53
  random.seed(random_seed)
51
54
  np.random.seed(random_seed)
52
55
 
53
- # Create a subgraph based on the edge length percentile threshold
54
- network = _create_percentile_limited_subgraph(
55
- network, edge_length_percentile=edge_length_threshold
56
- )
56
+ # Ensure distance_metric is a list/tuple for multi-algorithm handling
57
+ if isinstance(distance_metric, (str, np.ndarray)):
58
+ distance_metric = [distance_metric]
59
+ # Ensure edge_length_threshold is a list/tuple for multi-threshold handling
60
+ if isinstance(edge_length_threshold, (float, int)):
61
+ edge_length_threshold = [edge_length_threshold] * len(distance_metric)
62
+ # Check that the number of distance metrics matches the number of edge length thresholds
63
+ if len(distance_metric) != len(edge_length_threshold):
64
+ raise ValueError(
65
+ "The number of distance metrics must match the number of edge length thresholds."
66
+ )
57
67
 
58
- if distance_metric == "louvain":
59
- return calculate_louvain_neighborhoods(network, louvain_resolution, random_seed=random_seed)
60
- if distance_metric == "greedy_modularity":
61
- return calculate_greedy_modularity_neighborhoods(network)
62
- if distance_metric == "label_propagation":
63
- return calculate_label_propagation_neighborhoods(network)
64
- if distance_metric == "markov_clustering":
65
- return calculate_markov_clustering_neighborhoods(network)
66
- if distance_metric == "walktrap":
67
- return calculate_walktrap_neighborhoods(network)
68
- if distance_metric == "spinglass":
69
- return calculate_spinglass_neighborhoods(network)
70
-
71
- raise ValueError(
72
- "Incorrect distance metric specified. Please choose from 'greedy_modularity', 'louvain',"
73
- "'label_propagation', 'markov_clustering', 'walktrap', 'spinglass'."
74
- )
68
+ # Initialize combined neighborhood matrix
69
+ num_nodes = network.number_of_nodes()
70
+ combined_neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
71
+
72
+ # Loop through each distance metric and corresponding edge length threshold
73
+ for metric, threshold in zip(distance_metric, edge_length_threshold):
74
+ # Create a subgraph based on the specific edge length threshold for this algorithm
75
+ subgraph = _create_percentile_limited_subgraph(network, edge_length_percentile=threshold)
76
+ # Call the appropriate neighborhood function based on the metric
77
+ if metric == "louvain":
78
+ neighborhoods = calculate_louvain_neighborhoods(
79
+ subgraph, louvain_resolution, random_seed=random_seed
80
+ )
81
+ elif metric == "greedy_modularity":
82
+ neighborhoods = calculate_greedy_modularity_neighborhoods(subgraph)
83
+ elif metric == "label_propagation":
84
+ neighborhoods = calculate_label_propagation_neighborhoods(subgraph)
85
+ elif metric == "markov_clustering":
86
+ neighborhoods = calculate_markov_clustering_neighborhoods(subgraph)
87
+ elif metric == "walktrap":
88
+ neighborhoods = calculate_walktrap_neighborhoods(subgraph)
89
+ elif metric == "spinglass":
90
+ neighborhoods = calculate_spinglass_neighborhoods(subgraph)
91
+ else:
92
+ raise ValueError(
93
+ "Incorrect distance metric specified. Please choose from 'greedy_modularity', 'louvain',"
94
+ "'label_propagation', 'markov_clustering', 'walktrap', 'spinglass'."
95
+ )
96
+
97
+ # Sum the neighborhood matrices
98
+ combined_neighborhoods += neighborhoods
99
+
100
+ # Ensure that the maximum value in each row is set to 1
101
+ # This ensures that for each row, only the strongest relationship (the maximum value) is retained,
102
+ # while all other values are reset to 0. This transformation simplifies the neighborhood matrix by
103
+ # focusing on the most significant connection per row.
104
+ combined_neighborhoods = _set_max_to_one(combined_neighborhoods)
105
+
106
+ return combined_neighborhoods
75
107
 
76
108
 
77
109
  def _create_percentile_limited_subgraph(G: nx.Graph, edge_length_percentile: float) -> nx.Graph:
@@ -110,6 +142,25 @@ def _create_percentile_limited_subgraph(G: nx.Graph, edge_length_percentile: flo
110
142
  return subgraph
111
143
 
112
144
 
145
+ def _set_max_to_one(matrix: np.ndarray) -> np.ndarray:
146
+ """For each row in the input matrix, set the maximum value(s) to 1 and all other values to 0.
147
+
148
+ Args:
149
+ matrix (np.ndarray): A 2D numpy array representing the neighborhood matrix.
150
+
151
+ Returns:
152
+ np.ndarray: The modified matrix where only the maximum value(s) in each row is set to 1, and others are set to 0.
153
+ """
154
+ # Find the maximum value in each row (reduce along axis=1, keeping dims for broadcasting)
155
+ max_values = np.max(matrix, axis=1, keepdims=True)
156
+ # Create a boolean mask where elements are True if they are the max value in their row
157
+ max_mask = matrix == max_values
158
+ # Set all elements to 0, and then set the maximum value positions to 1
159
+ matrix[:] = 0 # Set everything to 0
160
+ matrix[max_mask] = 1 # Set only the max values to 1
161
+ return matrix
162
+
163
+
113
164
  def process_neighborhoods(
114
165
  network: nx.Graph,
115
166
  neighborhoods: Dict[str, Any],
@@ -120,12 +171,12 @@ def process_neighborhoods(
120
171
 
121
172
  Args:
122
173
  network (nx.Graph): The network data structure used for imputing and pruning neighbors.
123
- neighborhoods (dict): Dictionary containing 'enrichment_matrix', 'binary_enrichment_matrix', and 'significant_enrichment_matrix'.
174
+ neighborhoods (Dict[str, Any]): Dictionary containing 'enrichment_matrix', 'binary_enrichment_matrix', and 'significant_enrichment_matrix'.
124
175
  impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
125
176
  prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
126
177
 
127
178
  Returns:
128
- dict: Processed neighborhoods data, including the updated matrices and enrichment counts.
179
+ Dict[str, Any]: Processed neighborhoods data, including the updated matrices and enrichment counts.
129
180
  """
130
181
  enrichment_matrix = neighborhoods["enrichment_matrix"]
131
182
  binary_enrichment_matrix = neighborhoods["binary_enrichment_matrix"]
@@ -408,7 +459,7 @@ def _calculate_threshold(median_distances: List, distance_threshold: float) -> f
408
459
  """Calculate the distance threshold based on the given median distances and a percentile threshold.
409
460
 
410
461
  Args:
411
- median_distances (list): An array of median distances.
462
+ median_distances (List): An array of median distances.
412
463
  distance_threshold (float): A percentile threshold (0 to 1) used to determine the distance cutoff.
413
464
 
414
465
  Returns:
risk/network/geometry.py CHANGED
@@ -68,6 +68,7 @@ def assign_edge_lengths(
68
68
  v_coords = np.append(v_coords, G_depth.nodes[v].get("z", 0))
69
69
 
70
70
  distance = compute_distance(u_coords, v_coords, is_sphere=compute_sphere)
71
+ # Assign edge lengths to the original graph
71
72
  if include_edge_weight:
72
73
  # Square root of the normalized weight is used to minimize the effect of large weights
73
74
  G.edges[u, v]["length"] = distance / np.sqrt(G.edges[u, v]["normalized_weight"] + 1e-6)
risk/network/graph.py CHANGED
@@ -36,7 +36,7 @@ class NetworkGraph:
36
36
  top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
37
37
  domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
38
38
  trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
39
- node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
39
+ node_label_to_node_id_map (Dict[str, Any]): A dictionary mapping node labels to their corresponding IDs.
40
40
  node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
41
41
  """
42
42
  self.top_annotations = top_annotations
@@ -46,6 +46,9 @@ class NetworkGraph:
46
46
  trimmed_domains
47
47
  )
48
48
  self.node_enrichment_sums = node_enrichment_sums
49
+ self.node_id_to_domain_ids_and_enrichments_map = (
50
+ self._create_node_id_to_domain_ids_and_enrichments(domains)
51
+ )
49
52
  self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
50
53
  self.node_label_to_enrichment_map = dict(
51
54
  zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
@@ -57,14 +60,14 @@ class NetworkGraph:
57
60
  self.network = _unfold_sphere_to_plane(network)
58
61
  self.node_coordinates = _extract_node_coordinates(self.network)
59
62
 
60
- def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
63
+ def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[int, Any]:
61
64
  """Create a mapping from domains to the list of node IDs belonging to each domain.
62
65
 
63
66
  Args:
64
67
  domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
65
68
 
66
69
  Returns:
67
- dict: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
70
+ Dict[int, Any]: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
68
71
  """
69
72
  cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
70
73
  node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
@@ -76,14 +79,14 @@ class NetworkGraph:
76
79
 
77
80
  def _create_domain_id_to_domain_terms_map(
78
81
  self, trimmed_domains: pd.DataFrame
79
- ) -> Dict[str, Any]:
82
+ ) -> Dict[int, Any]:
80
83
  """Create a mapping from domain IDs to their corresponding terms.
81
84
 
82
85
  Args:
83
86
  trimmed_domains (pd.DataFrame): DataFrame containing domain IDs and their corresponding labels.
84
87
 
85
88
  Returns:
86
- dict: A dictionary mapping domain IDs to their corresponding terms.
89
+ Dict[int, Any]: A dictionary mapping domain IDs to their corresponding terms.
87
90
  """
88
91
  return dict(
89
92
  zip(
@@ -92,11 +95,45 @@ class NetworkGraph:
92
95
  )
93
96
  )
94
97
 
98
+ def _create_node_id_to_domain_ids_and_enrichments(
99
+ self, domains: pd.DataFrame
100
+ ) -> Dict[int, Dict]:
101
+ """Creates a dictionary mapping each node ID to its corresponding domain IDs and enrichment values.
102
+
103
+ Args:
104
+ domains (pd.DataFrame): A DataFrame containing domain information for each node. Assumes the last
105
+ two columns are 'all domains' and 'primary domain', which are excluded from processing.
106
+
107
+ Returns:
108
+ Dict[int, Dict]: A dictionary where the key is the node ID (index of the DataFrame), and the value is another dictionary
109
+ with 'domains' (a list of domain IDs with non-zero enrichment) and 'enrichments'
110
+ (a dict of domain IDs and their corresponding enrichment values).
111
+ """
112
+ # Initialize an empty dictionary to store the result
113
+ node_id_to_domain_ids_and_enrichments = {}
114
+ # Get the list of domain columns (excluding 'all domains' and 'primary domain')
115
+ domain_columns = domains.columns[
116
+ :-2
117
+ ] # The last two columns are 'all domains' and 'primary domain'
118
+ # Iterate over each row in the dataframe
119
+ for idx, row in domains.iterrows():
120
+ # Get the domains (column names) where the enrichment score is greater than 0
121
+ all_domains = domain_columns[row[domain_columns] > 0].tolist()
122
+ # Get the enrichment values for those domains
123
+ enrichment_values = row[all_domains].to_dict()
124
+ # Store the result in the dictionary with index as the key
125
+ node_id_to_domain_ids_and_enrichments[idx] = {
126
+ "domains": all_domains, # The column names where enrichment > 0
127
+ "enrichments": enrichment_values, # The actual enrichment values for those columns
128
+ }
129
+
130
+ return node_id_to_domain_ids_and_enrichments
131
+
95
132
  def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
96
133
  """Create a map from domain IDs to node labels.
97
134
 
98
135
  Returns:
99
- dict: A dictionary mapping domain IDs to the corresponding node labels.
136
+ Dict[int, List[str]]: A dictionary mapping domain IDs to the corresponding node labels.
100
137
  """
101
138
  domain_id_to_label_map = {}
102
139
  for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
risk/network/io.py CHANGED
@@ -491,7 +491,7 @@ class NetworkIO:
491
491
  if "x" not in attrs or "y" not in attrs:
492
492
  if (
493
493
  "pos" in attrs
494
- and isinstance(attrs["pos"], (list, tuple, np.ndarray))
494
+ and isinstance(attrs["pos"], (List, Tuple, np.ndarray))
495
495
  and len(attrs["pos"]) >= 2
496
496
  ):
497
497
  attrs["x"], attrs["y"] = attrs["pos"][