risk-network 0.0.8b19__py3-none-any.whl → 0.0.8b21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.8-beta.19"
10
+ __version__ = "0.0.8-beta.21"
@@ -36,10 +36,11 @@ def load_annotations(network: nx.Graph, annotations_input: Dict[str, Any]) -> Di
36
36
  """Convert annotations input to a DataFrame and reindex based on the network's node labels.
37
37
 
38
38
  Args:
39
- annotations_input (dict): A dictionary with annotations.
39
+ network (nx.Graph): The network graph.
40
+ annotations_input (Dict[str, Any]): A dictionary with annotations.
40
41
 
41
42
  Returns:
42
- dict: A dictionary containing ordered nodes, ordered annotations, and the binary annotations matrix.
43
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the binary annotations matrix.
43
44
  """
44
45
  # Flatten the dictionary to a list of tuples for easier DataFrame creation
45
46
  flattened_annotations = [
@@ -255,7 +256,7 @@ def _generate_coherent_description(words: List[str]) -> str:
255
256
  If there is only one unique entry, return it directly.
256
257
 
257
258
  Args:
258
- words (list): A list of words or numerical string values.
259
+ words (List): A list of words or numerical string values.
259
260
 
260
261
  Returns:
261
262
  str: A coherent description formed by arranging the words in a logical sequence.
risk/annotations/io.py CHANGED
@@ -33,7 +33,7 @@ class AnnotationsIO:
33
33
  filepath (str): Path to the JSON annotations file.
34
34
 
35
35
  Returns:
36
- dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
36
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
37
37
  """
38
38
  filetype = "JSON"
39
39
  # Log the loading of the JSON file
@@ -158,10 +158,10 @@ class AnnotationsIO:
158
158
 
159
159
  Args:
160
160
  network (NetworkX graph): The network to which the annotations are related.
161
- content (dict): The annotations dictionary to load.
161
+ content (Dict[str, Any]): The annotations dictionary to load.
162
162
 
163
163
  Returns:
164
- dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
164
+ Dict[str, Any]: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
165
165
  """
166
166
  # Ensure the input content is a dictionary
167
167
  if not isinstance(content, dict):
risk/log/params.py CHANGED
@@ -159,7 +159,7 @@ class Params:
159
159
  """Load and process various parameters, converting any np.ndarray values to lists.
160
160
 
161
161
  Returns:
162
- dict: A dictionary containing the processed parameters.
162
+ Dict[str, Any]: A dictionary containing the processed parameters.
163
163
  """
164
164
  log_header("Loading parameters")
165
165
  return _convert_ndarray_to_list(
@@ -174,14 +174,14 @@ class Params:
174
174
  )
175
175
 
176
176
 
177
- def _convert_ndarray_to_list(d: Any) -> Any:
177
+ def _convert_ndarray_to_list(d: Dict[str, Any]) -> Dict[str, Any]:
178
178
  """Recursively convert all np.ndarray values in the dictionary to lists.
179
179
 
180
180
  Args:
181
- d (dict): The dictionary to process.
181
+ d (Dict[str, Any]): The dictionary to process.
182
182
 
183
183
  Returns:
184
- dict: The processed dictionary with np.ndarray values converted to lists.
184
+ Dict[str, Any]: The processed dictionary with np.ndarray values converted to lists.
185
185
  """
186
186
  if isinstance(d, dict):
187
187
  # Recursively process each value in the dictionary
@@ -21,15 +21,20 @@ def calculate_greedy_modularity_neighborhoods(network: nx.Graph) -> np.ndarray:
21
21
  """
22
22
  # Detect communities using the Greedy Modularity method
23
23
  communities = greedy_modularity_communities(network)
24
- # Create a mapping from node to community
25
- community_dict = {node: idx for idx, community in enumerate(communities) for node in community}
26
24
  # Create a binary neighborhood matrix
27
- neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=int)
25
+ n_nodes = network.number_of_nodes()
26
+ neighborhoods = np.zeros((n_nodes, n_nodes), dtype=int)
27
+ # Create a mapping from node to index in the matrix
28
28
  node_index = {node: i for i, node in enumerate(network.nodes())}
29
- for node_i, community_i in community_dict.items():
30
- for node_j, community_j in community_dict.items():
31
- if community_i == community_j:
32
- neighborhoods[node_index[node_i], node_index[node_j]] = 1
29
+ # Fill in the neighborhood matrix for nodes in the same community
30
+ for community in communities:
31
+ # Iterate through all pairs of nodes in the same community
32
+ for node_i in community:
33
+ idx_i = node_index[node_i]
34
+ for node_j in community:
35
+ idx_j = node_index[node_j]
36
+ # Set them as neighbors (1) in the binary matrix
37
+ neighborhoods[idx_i, idx_j] = 1
33
38
 
34
39
  return neighborhoods
35
40
 
@@ -43,22 +48,20 @@ def calculate_label_propagation_neighborhoods(network: nx.Graph) -> np.ndarray:
43
48
  Returns:
44
49
  np.ndarray: Binary neighborhood matrix on Label Propagation.
45
50
  """
46
- # Apply Label Propagation
51
+ # Apply Label Propagation for community detection
47
52
  communities = nx.algorithms.community.label_propagation.label_propagation_communities(network)
48
- # Create a mapping from node to community
49
- community_dict = {}
50
- for community_id, community in enumerate(communities):
51
- for node in community:
52
- community_dict[node] = community_id
53
-
54
53
  # Create a binary neighborhood matrix
55
54
  num_nodes = network.number_of_nodes()
56
55
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
56
+ # Create a mapping from node to index in the matrix
57
+ node_index = {node: i for i, node in enumerate(network.nodes())}
57
58
  # Assign neighborhoods based on community labels
58
- for node_i, community_i in community_dict.items():
59
- for node_j, community_j in community_dict.items():
60
- if community_i == community_j:
61
- neighborhoods[node_i, node_j] = 1
59
+ for community in communities:
60
+ for node_i in community:
61
+ idx_i = node_index[node_i]
62
+ for node_j in community:
63
+ idx_j = node_index[node_j]
64
+ neighborhoods[idx_i, idx_j] = 1
62
65
 
63
66
  return neighborhoods
64
67
 
@@ -81,12 +84,22 @@ def calculate_louvain_neighborhoods(
81
84
  network, resolution=resolution, random_state=random_seed
82
85
  )
83
86
  # Create a binary neighborhood matrix
84
- neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=int)
87
+ num_nodes = network.number_of_nodes()
88
+ neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
89
+ # Create a mapping from node to index in the matrix
90
+ node_index = {node: i for i, node in enumerate(network.nodes())}
91
+ # Group nodes by community
92
+ community_groups = {}
93
+ for node, community in partition.items():
94
+ community_groups.setdefault(community, []).append(node)
95
+
85
96
  # Assign neighborhoods based on community partitions
86
- for node_i, community_i in partition.items():
87
- for node_j, community_j in partition.items():
88
- if community_i == community_j:
89
- neighborhoods[node_i, node_j] = 1
97
+ for community, nodes in community_groups.items():
98
+ for node_i in nodes:
99
+ idx_i = node_index[node_i]
100
+ for node_j in nodes:
101
+ idx_j = node_index[node_j]
102
+ neighborhoods[idx_i, idx_j] = 1
90
103
 
91
104
  return neighborhoods
92
105
 
@@ -102,24 +115,22 @@ def calculate_markov_clustering_neighborhoods(network: nx.Graph) -> np.ndarray:
102
115
  """
103
116
  # Convert the graph to an adjacency matrix
104
117
  adjacency_matrix = nx.to_numpy_array(network)
105
- # Run Markov Clustering
106
- result = mc.run_mcl(adjacency_matrix) # Run MCL with default parameters
107
- # Get clusters
118
+ # Run Markov Clustering (MCL)
119
+ result = mc.run_mcl(adjacency_matrix) # MCL with default parameters
120
+ # Get clusters (communities) from MCL result
108
121
  clusters = mc.get_clusters(result)
109
- # Create a community label for each node
110
- community_dict = {}
111
- for community_id, community in enumerate(clusters):
112
- for node in community:
113
- community_dict[node] = community_id
114
-
115
122
  # Create a binary neighborhood matrix
116
123
  num_nodes = network.number_of_nodes()
117
124
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
118
- # Assign neighborhoods based on community labels
119
- for node_i, community_i in community_dict.items():
120
- for node_j, community_j in community_dict.items():
121
- if community_i == community_j:
122
- neighborhoods[node_i, node_j] = 1
125
+ # Create a mapping from node to index in the matrix
126
+ node_index = {node: i for i, node in enumerate(network.nodes())}
127
+ # Assign neighborhoods based on MCL clusters
128
+ for cluster in clusters:
129
+ for node_i in cluster:
130
+ idx_i = node_index[node_i]
131
+ for node_j in cluster:
132
+ idx_j = node_index[node_j]
133
+ neighborhoods[idx_i, idx_j] = 1
123
134
 
124
135
  return neighborhoods
125
136
 
@@ -133,22 +144,20 @@ def calculate_spinglass_neighborhoods(network: nx.Graph) -> np.ndarray:
133
144
  Returns:
134
145
  np.ndarray: Binary neighborhood matrix on Spin Glass communities.
135
146
  """
136
- # Use the asynchronous label propagation algorithm as a proxy for Spin Glass
147
+ # Apply Asynchronous Label Propagation (LPA)
137
148
  communities = asyn_lpa_communities(network)
138
- # Create a community label for each node
139
- community_dict = {}
140
- for community_id, community in enumerate(communities):
141
- for node in community:
142
- community_dict[node] = community_id
143
-
144
149
  # Create a binary neighborhood matrix
145
150
  num_nodes = network.number_of_nodes()
146
151
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
147
- # Assign neighborhoods based on community labels
148
- for node_i, community_i in community_dict.items():
149
- for node_j, community_j in community_dict.items():
150
- if community_i == community_j:
151
- neighborhoods[node_i, node_j] = 1
152
+ # Create a mapping from node to index in the matrix
153
+ node_index = {node: i for i, node in enumerate(network.nodes())}
154
+ # Assign neighborhoods based on community labels from LPA
155
+ for community in communities:
156
+ for node_i in community:
157
+ idx_i = node_index[node_i]
158
+ for node_j in community:
159
+ idx_j = node_index[node_j]
160
+ neighborhoods[idx_i, idx_j] = 1
152
161
 
153
162
  return neighborhoods
154
163
 
@@ -162,21 +171,19 @@ def calculate_walktrap_neighborhoods(network: nx.Graph) -> np.ndarray:
162
171
  Returns:
163
172
  np.ndarray: Binary neighborhood matrix on Walktrap communities.
164
173
  """
165
- # Use the asynchronous label propagation algorithm as a proxy for Walktrap
174
+ # Apply Asynchronous Label Propagation (LPA)
166
175
  communities = asyn_lpa_communities(network)
167
- # Create a community label for each node
168
- community_dict = {}
169
- for community_id, community in enumerate(communities):
170
- for node in community:
171
- community_dict[node] = community_id
172
-
173
176
  # Create a binary neighborhood matrix
174
177
  num_nodes = network.number_of_nodes()
175
178
  neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
176
- # Assign neighborhoods based on community labels
177
- for node_i, community_i in community_dict.items():
178
- for node_j, community_j in community_dict.items():
179
- if community_i == community_j:
180
- neighborhoods[node_i, node_j] = 1
179
+ # Create a mapping from node to index in the matrix
180
+ node_index = {node: i for i, node in enumerate(network.nodes())}
181
+ # Assign neighborhoods based on community labels from LPA
182
+ for community in communities:
183
+ for node_i in community:
184
+ idx_i = node_index[node_i]
185
+ for node_j in community:
186
+ idx_j = node_index[node_j]
187
+ neighborhoods[idx_i, idx_j] = 1
181
188
 
182
189
  return neighborhoods
@@ -5,7 +5,7 @@ risk/neighborhoods/neighborhoods
5
5
 
6
6
  import random
7
7
  import warnings
8
- from typing import Any, Dict, List, Tuple
8
+ from typing import Any, Dict, List, Tuple, Union
9
9
 
10
10
  import networkx as nx
11
11
  import numpy as np
@@ -28,50 +28,82 @@ warnings.filterwarnings(action="ignore", category=DataConversionWarning)
28
28
 
29
29
  def get_network_neighborhoods(
30
30
  network: nx.Graph,
31
- distance_metric: str = "louvain",
32
- edge_length_threshold: float = 1.0,
31
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
32
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 1.0,
33
33
  louvain_resolution: float = 1.0,
34
34
  random_seed: int = 888,
35
35
  ) -> np.ndarray:
36
- """Calculate the neighborhoods for each node in the network based on the specified distance metric.
36
+ """Calculate the combined neighborhoods for each node based on the specified community detection algorithm(s).
37
37
 
38
38
  Args:
39
39
  network (nx.Graph): The network graph.
40
- distance_metric (str): The distance metric to use ('greedy_modularity', 'louvain', 'label_propagation',
41
- 'markov_clustering', 'walktrap', 'spinglass').
42
- edge_length_threshold (float): The edge length threshold for the neighborhoods.
40
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
41
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
42
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
43
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
44
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
45
+ Defaults to 1.0.
43
46
  louvain_resolution (float, optional): Resolution parameter for the Louvain method. Defaults to 1.0.
44
47
  random_seed (int, optional): Random seed for methods requiring random initialization. Defaults to 888.
45
48
 
46
49
  Returns:
47
- np.ndarray: Neighborhood matrix calculated based on the selected distance metric.
50
+ np.ndarray: Summed neighborhood matrix from all selected algorithms.
48
51
  """
49
- # Set random seed for reproducibility in all methods besides Louvain, which requires a separate seed
52
+ # Set random seed for reproducibility
50
53
  random.seed(random_seed)
51
54
  np.random.seed(random_seed)
52
55
 
53
- # Create a subgraph based on the edge length percentile threshold
54
- network = _create_percentile_limited_subgraph(
55
- network, edge_length_percentile=edge_length_threshold
56
- )
56
+ # Ensure distance_metric is a list/tuple for multi-algorithm handling
57
+ if isinstance(distance_metric, (str, np.ndarray)):
58
+ distance_metric = [distance_metric]
59
+ # Ensure edge_length_threshold is a list/tuple for multi-threshold handling
60
+ if isinstance(edge_length_threshold, (float, int)):
61
+ edge_length_threshold = [edge_length_threshold] * len(distance_metric)
62
+ # Check that the number of distance metrics matches the number of edge length thresholds
63
+ if len(distance_metric) != len(edge_length_threshold):
64
+ raise ValueError(
65
+ "The number of distance metrics must match the number of edge length thresholds."
66
+ )
57
67
 
58
- if distance_metric == "louvain":
59
- return calculate_louvain_neighborhoods(network, louvain_resolution, random_seed=random_seed)
60
- if distance_metric == "greedy_modularity":
61
- return calculate_greedy_modularity_neighborhoods(network)
62
- if distance_metric == "label_propagation":
63
- return calculate_label_propagation_neighborhoods(network)
64
- if distance_metric == "markov_clustering":
65
- return calculate_markov_clustering_neighborhoods(network)
66
- if distance_metric == "walktrap":
67
- return calculate_walktrap_neighborhoods(network)
68
- if distance_metric == "spinglass":
69
- return calculate_spinglass_neighborhoods(network)
70
-
71
- raise ValueError(
72
- "Incorrect distance metric specified. Please choose from 'greedy_modularity', 'louvain',"
73
- "'label_propagation', 'markov_clustering', 'walktrap', 'spinglass'."
74
- )
68
+ # Initialize combined neighborhood matrix
69
+ num_nodes = network.number_of_nodes()
70
+ combined_neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
71
+
72
+ # Loop through each distance metric and corresponding edge length threshold
73
+ for metric, threshold in zip(distance_metric, edge_length_threshold):
74
+ # Create a subgraph based on the specific edge length threshold for this algorithm
75
+ subgraph = _create_percentile_limited_subgraph(network, edge_length_percentile=threshold)
76
+ # Call the appropriate neighborhood function based on the metric
77
+ if metric == "louvain":
78
+ neighborhoods = calculate_louvain_neighborhoods(
79
+ subgraph, louvain_resolution, random_seed=random_seed
80
+ )
81
+ elif metric == "greedy_modularity":
82
+ neighborhoods = calculate_greedy_modularity_neighborhoods(subgraph)
83
+ elif metric == "label_propagation":
84
+ neighborhoods = calculate_label_propagation_neighborhoods(subgraph)
85
+ elif metric == "markov_clustering":
86
+ neighborhoods = calculate_markov_clustering_neighborhoods(subgraph)
87
+ elif metric == "walktrap":
88
+ neighborhoods = calculate_walktrap_neighborhoods(subgraph)
89
+ elif metric == "spinglass":
90
+ neighborhoods = calculate_spinglass_neighborhoods(subgraph)
91
+ else:
92
+ raise ValueError(
93
+ "Incorrect distance metric specified. Please choose from 'greedy_modularity', 'louvain',"
94
+ "'label_propagation', 'markov_clustering', 'walktrap', 'spinglass'."
95
+ )
96
+
97
+ # Sum the neighborhood matrices
98
+ combined_neighborhoods += neighborhoods
99
+
100
+ # Ensure that the maximum value in each row is set to 1
101
+ # This ensures that for each row, only the strongest relationship (the maximum value) is retained,
102
+ # while all other values are reset to 0. This transformation simplifies the neighborhood matrix by
103
+ # focusing on the most significant connection per row.
104
+ combined_neighborhoods = _set_max_to_one(combined_neighborhoods)
105
+
106
+ return combined_neighborhoods
75
107
 
76
108
 
77
109
  def _create_percentile_limited_subgraph(G: nx.Graph, edge_length_percentile: float) -> nx.Graph:
@@ -110,6 +142,25 @@ def _create_percentile_limited_subgraph(G: nx.Graph, edge_length_percentile: flo
110
142
  return subgraph
111
143
 
112
144
 
145
+ def _set_max_to_one(matrix: np.ndarray) -> np.ndarray:
146
+ """For each row in the input matrix, set the maximum value(s) to 1 and all other values to 0.
147
+
148
+ Args:
149
+ matrix (np.ndarray): A 2D numpy array representing the neighborhood matrix.
150
+
151
+ Returns:
152
+ np.ndarray: The modified matrix where only the maximum value(s) in each row is set to 1, and others are set to 0.
153
+ """
154
+ # Find the maximum value in each row (column-wise max operation)
155
+ max_values = np.max(matrix, axis=1, keepdims=True)
156
+ # Create a boolean mask where elements are True if they are the max value in their row
157
+ max_mask = matrix == max_values
158
+ # Set all elements to 0, and then set the maximum value positions to 1
159
+ matrix[:] = 0 # Set everything to 0
160
+ matrix[max_mask] = 1 # Set only the max values to 1
161
+ return matrix
162
+
163
+
113
164
  def process_neighborhoods(
114
165
  network: nx.Graph,
115
166
  neighborhoods: Dict[str, Any],
@@ -120,12 +171,12 @@ def process_neighborhoods(
120
171
 
121
172
  Args:
122
173
  network (nx.Graph): The network data structure used for imputing and pruning neighbors.
123
- neighborhoods (dict): Dictionary containing 'enrichment_matrix', 'binary_enrichment_matrix', and 'significant_enrichment_matrix'.
174
+ neighborhoods (Dict[str, Any]): Dictionary containing 'enrichment_matrix', 'binary_enrichment_matrix', and 'significant_enrichment_matrix'.
124
175
  impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
125
176
  prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
126
177
 
127
178
  Returns:
128
- dict: Processed neighborhoods data, including the updated matrices and enrichment counts.
179
+ Dict[str, Any]: Processed neighborhoods data, including the updated matrices and enrichment counts.
129
180
  """
130
181
  enrichment_matrix = neighborhoods["enrichment_matrix"]
131
182
  binary_enrichment_matrix = neighborhoods["binary_enrichment_matrix"]
@@ -408,7 +459,7 @@ def _calculate_threshold(median_distances: List, distance_threshold: float) -> f
408
459
  """Calculate the distance threshold based on the given median distances and a percentile threshold.
409
460
 
410
461
  Args:
411
- median_distances (list): An array of median distances.
462
+ median_distances (List): An array of median distances.
412
463
  distance_threshold (float): A percentile threshold (0 to 1) used to determine the distance cutoff.
413
464
 
414
465
  Returns:
risk/network/geometry.py CHANGED
@@ -68,6 +68,7 @@ def assign_edge_lengths(
68
68
  v_coords = np.append(v_coords, G_depth.nodes[v].get("z", 0))
69
69
 
70
70
  distance = compute_distance(u_coords, v_coords, is_sphere=compute_sphere)
71
+ # Assign edge lengths to the original graph
71
72
  if include_edge_weight:
72
73
  # Square root of the normalized weight is used to minimize the effect of large weights
73
74
  G.edges[u, v]["length"] = distance / np.sqrt(G.edges[u, v]["normalized_weight"] + 1e-6)
risk/network/graph.py CHANGED
@@ -36,7 +36,7 @@ class NetworkGraph:
36
36
  top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
37
37
  domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
38
38
  trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
39
- node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
39
+ node_label_to_node_id_map (Dict[str, Any]): A dictionary mapping node labels to their corresponding IDs.
40
40
  node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
41
41
  """
42
42
  self.top_annotations = top_annotations
@@ -60,14 +60,14 @@ class NetworkGraph:
60
60
  self.network = _unfold_sphere_to_plane(network)
61
61
  self.node_coordinates = _extract_node_coordinates(self.network)
62
62
 
63
- def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
63
+ def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[int, Any]:
64
64
  """Create a mapping from domains to the list of node IDs belonging to each domain.
65
65
 
66
66
  Args:
67
67
  domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
68
68
 
69
69
  Returns:
70
- dict: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
70
+ Dict[int, Any]: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
71
71
  """
72
72
  cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
73
73
  node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
@@ -79,14 +79,14 @@ class NetworkGraph:
79
79
 
80
80
  def _create_domain_id_to_domain_terms_map(
81
81
  self, trimmed_domains: pd.DataFrame
82
- ) -> Dict[str, Any]:
82
+ ) -> Dict[int, Any]:
83
83
  """Create a mapping from domain IDs to their corresponding terms.
84
84
 
85
85
  Args:
86
86
  trimmed_domains (pd.DataFrame): DataFrame containing domain IDs and their corresponding labels.
87
87
 
88
88
  Returns:
89
- dict: A dictionary mapping domain IDs to their corresponding terms.
89
+ Dict[int, Any]: A dictionary mapping domain IDs to their corresponding terms.
90
90
  """
91
91
  return dict(
92
92
  zip(
@@ -105,7 +105,7 @@ class NetworkGraph:
105
105
  two columns are 'all domains' and 'primary domain', which are excluded from processing.
106
106
 
107
107
  Returns:
108
- dict: A dictionary where the key is the node ID (index of the DataFrame), and the value is another dictionary
108
+ Dict[int, Dict]: A dictionary where the key is the node ID (index of the DataFrame), and the value is another dictionary
109
109
  with 'domain' (a list of domain IDs with non-zero enrichment) and 'enrichment'
110
110
  (a dict of domain IDs and their corresponding enrichment values).
111
111
  """
@@ -133,7 +133,7 @@ class NetworkGraph:
133
133
  """Create a map from domain IDs to node labels.
134
134
 
135
135
  Returns:
136
- dict: A dictionary mapping domain IDs to the corresponding node labels.
136
+ Dict[int, List[str]]: A dictionary mapping domain IDs to the corresponding node labels.
137
137
  """
138
138
  domain_id_to_label_map = {}
139
139
  for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
risk/network/io.py CHANGED
@@ -491,7 +491,7 @@ class NetworkIO:
491
491
  if "x" not in attrs or "y" not in attrs:
492
492
  if (
493
493
  "pos" in attrs
494
- and isinstance(attrs["pos"], (List, Tuple, np.ndarray))
494
+ and isinstance(attrs["pos"], (list, tuple, np.ndarray))
495
495
  and len(attrs["pos"]) >= 2
496
496
  ):
497
497
  attrs["x"], attrs["y"] = attrs["pos"][
@@ -137,7 +137,7 @@ class Canvas:
137
137
  perimeter_linestyle=linestyle,
138
138
  perimeter_linewidth=linewidth,
139
139
  perimeter_color=(
140
- "custom" if isinstance(color, (List, Tuple, np.ndarray)) else color
140
+ "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
141
141
  ), # np.ndarray usually indicates custom colors
142
142
  perimeter_outline_alpha=outline_alpha,
143
143
  perimeter_fill_alpha=fill_alpha,
@@ -210,7 +210,7 @@ class Canvas:
210
210
  perimeter_grid_size=grid_size,
211
211
  perimeter_linestyle=linestyle,
212
212
  perimeter_linewidth=linewidth,
213
- perimeter_color=("custom" if isinstance(color, (List, Tuple, np.ndarray)) else color),
213
+ perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
214
214
  perimeter_outline_alpha=outline_alpha,
215
215
  perimeter_fill_alpha=fill_alpha,
216
216
  )
@@ -122,7 +122,7 @@ class Contour:
122
122
  ValueError: If no valid nodes are found in the network graph.
123
123
  """
124
124
  # Check if nodes is a list of lists or a flat list
125
- if any(isinstance(item, (List, Tuple, np.ndarray)) for item in nodes):
125
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
126
126
  # If it's a list of lists, iterate over sublists
127
127
  node_groups = nodes
128
128
  # Convert color to RGBA arrays to match the number of groups
@@ -181,7 +181,7 @@ class Contour:
181
181
  Args:
182
182
  ax (plt.Axes): The axis to draw the contour on.
183
183
  pos (np.ndarray): Array of node positions (x, y).
184
- nodes (list): List of node indices to include in the contour.
184
+ nodes (List): List of node indices to include in the contour.
185
185
  levels (int, optional): Number of contour levels. Defaults to 5.
186
186
  bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
187
187
  grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
@@ -86,7 +86,7 @@ class Labels:
86
86
  overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
87
87
  ids_to_keep (List, Tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
88
88
  you can set `overlay_ids=True`. Defaults to None.
89
- ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
89
+ ids_to_replace (Dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
90
90
  space-separated words. If provided, the custom labels will replace the default domain terms. To discover domain IDs, you
91
91
  can set `overlay_ids=True`. Defaults to None.
92
92
 
@@ -282,7 +282,7 @@ class Labels:
282
282
  arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
283
283
  """
284
284
  # Check if nodes is a list of lists or a flat list
285
- if any(isinstance(item, (List, Tuple, np.ndarray)) for item in nodes):
285
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
286
286
  # If it's a list of lists, iterate over sublists
287
287
  node_groups = nodes
288
288
  # Convert fontcolor and arrow_color to RGBA arrays to match the number of groups
@@ -347,7 +347,7 @@ class Labels:
347
347
  """Calculate the most centrally located node in .
348
348
 
349
349
  Args:
350
- nodes (list): List of node labels to include in the subnetwork.
350
+ nodes (List): List of node labels to include in the subnetwork.
351
351
 
352
352
  Returns:
353
353
  tuple: A tuple containing the domain's central node coordinates.
@@ -382,18 +382,18 @@ class Labels:
382
382
  """Process the ids_to_keep, apply filtering, and store valid domain centroids and terms.
383
383
 
384
384
  Args:
385
- domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
385
+ domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
386
386
  ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
387
- ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
387
+ ids_to_replace (Dict[str, str], optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
388
388
  words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
389
389
  max_labels (int, optional): Maximum number of labels allowed.
390
390
  min_label_lines (int): Minimum number of lines in a label.
391
391
  max_label_lines (int): Maximum number of lines in a label.
392
392
  min_chars_per_line (int): Minimum number of characters in a line to display.
393
393
  max_chars_per_line (int): Maximum number of characters in a line to display.
394
- filtered_domain_centroids (dict): Dictionary to store filtered domain centroids (output).
395
- filtered_domain_terms (dict): Dictionary to store filtered domain terms (output).
396
- valid_indices (list): List to store valid indices (output).
394
+ filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store filtered domain centroids (output).
395
+ filtered_domain_terms (Dict[str, str]): Dictionary to store filtered domain terms (output).
396
+ valid_indices (List): List to store valid indices (output).
397
397
 
398
398
  Note:
399
399
  The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
@@ -448,18 +448,18 @@ class Labels:
448
448
  """Process remaining domains to fill in additional labels, respecting the remaining_labels limit.
449
449
 
450
450
  Args:
451
- domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
451
+ domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
452
452
  ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
453
- ids_to_replace (dict, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
453
+ ids_to_replace (Dict[str, str], optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
454
454
  words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
455
455
  remaining_labels (int): The remaining number of labels that can be generated.
456
456
  min_label_lines (int): Minimum number of lines in a label.
457
457
  max_label_lines (int): Maximum number of lines in a label.
458
458
  min_chars_per_line (int): Minimum number of characters in a line to display.
459
459
  max_chars_per_line (int): Maximum number of characters in a line to display.
460
- filtered_domain_centroids (dict): Dictionary to store filtered domain centroids (output).
461
- filtered_domain_terms (dict): Dictionary to store filtered domain terms (output).
462
- valid_indices (list): List to store valid indices (output).
460
+ filtered_domain_centroids (Dict[str, np.ndarray]): Dictionary to store filtered domain centroids (output).
461
+ filtered_domain_terms (Dict[str, str]): Dictionary to store filtered domain terms (output).
462
+ valid_indices (List): List to store valid indices (output).
463
463
 
464
464
  Note:
465
465
  The `filtered_domain_centroids`, `filtered_domain_terms`, and `valid_indices` are modified in-place.
@@ -551,9 +551,9 @@ class Labels:
551
551
  Args:
552
552
  domain (str): Domain ID to process.
553
553
  domain_centroid (np.ndarray): Centroid position of the domain.
554
- domain_id_to_centroid_map (dict): Mapping of domain IDs to their centroids.
555
- ids_to_replace (Union[Dict[str, str], None]): A dictionary mapping domain IDs to custom labels.
556
- words_to_omit (Union[List[str], None]): List of words to omit from the labels.
554
+ domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
555
+ ids_to_replace (Dict[str, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
556
+ words_to_omit (List[str], None, optional): List of words to omit from the labels. Defaults to None.
557
557
  min_label_lines (int): Minimum number of lines required in a label.
558
558
  max_label_lines (int): Maximum number of lines allowed in a label.
559
559
  min_chars_per_line (int): Minimum number of characters allowed per line.
@@ -606,7 +606,7 @@ class Labels:
606
606
 
607
607
  Args:
608
608
  domain (str): The domain being processed.
609
- ids_to_replace (dict, optional): Dictionary mapping domain IDs to custom labels.
609
+ ids_to_replace (Dict[str, str], optional): Dictionary mapping domain IDs to custom labels.
610
610
  words_to_omit (List, optional): List of words to omit from the labels.
611
611
  max_label_lines (int): Maximum number of lines in a label.
612
612
  min_chars_per_line (int): Minimum number of characters in a line to display.
@@ -740,13 +740,13 @@ def _calculate_best_label_positions(
740
740
  """Calculate and optimize label positions for clarity.
741
741
 
742
742
  Args:
743
- filtered_domain_centroids (dict): Centroids of the filtered domains.
743
+ filtered_domain_centroids (Dict[str, Any]): Centroids of the filtered domains.
744
744
  center (np.ndarray): The center coordinates for label positioning.
745
745
  radius (float): The radius for positioning labels around the center.
746
746
  offset (float): The offset distance from the radius for positioning labels.
747
747
 
748
748
  Returns:
749
- dict: Optimized positions for labels.
749
+ Dict[str, Any]: Optimized positions for labels.
750
750
  """
751
751
  num_domains = len(filtered_domain_centroids)
752
752
  # Calculate equidistant positions around the center for initial label placement
@@ -791,11 +791,11 @@ def _optimize_label_positions(
791
791
  """Optimize label positions around the perimeter to minimize total distance to centroids.
792
792
 
793
793
  Args:
794
- best_label_positions (dict): Initial positions of labels around the perimeter.
795
- domain_centroids (dict): Centroid positions of the domains.
794
+ best_label_positions (Dict[str, Any]): Initial positions of labels around the perimeter.
795
+ domain_centroids (Dict[str, Any]): Centroid positions of the domains.
796
796
 
797
797
  Returns:
798
- dict: Optimized label positions.
798
+ Dict[str, Any]: Optimized label positions.
799
799
  """
800
800
  while True:
801
801
  improvement = False # Start each iteration assuming no improvement
@@ -827,8 +827,8 @@ def _calculate_total_distance(
827
827
  """Calculate the total distance from label positions to their domain centroids.
828
828
 
829
829
  Args:
830
- label_positions (dict): Positions of labels around the perimeter.
831
- domain_centroids (dict): Centroid positions of the domains.
830
+ label_positions (Dict[str, Any]): Positions of labels around the perimeter.
831
+ domain_centroids (Dict[str, Any]): Centroid positions of the domains.
832
832
 
833
833
  Returns:
834
834
  float: The total distance from labels to centroids.
@@ -851,10 +851,10 @@ def _swap_and_evaluate(
851
851
  """Swap two labels and evaluate the total distance after the swap.
852
852
 
853
853
  Args:
854
- label_positions (dict): Positions of labels around the perimeter.
854
+ label_positions (Dict[str, Any]): Positions of labels around the perimeter.
855
855
  i (int): Index of the first label to swap.
856
856
  j (int): Index of the second label to swap.
857
- domain_centroids (dict): Centroid positions of the domains.
857
+ domain_centroids (Dict[str, Any]): Centroid positions of the domains.
858
858
 
859
859
  Returns:
860
860
  float: The total distance after swapping the two labels.
@@ -141,7 +141,7 @@ class Network:
141
141
  ValueError: If no valid nodes are found in the network graph.
142
142
  """
143
143
  # Flatten nested lists of nodes, if necessary
144
- if any(isinstance(item, (List, Tuple, np.ndarray)) for item in nodes):
144
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
145
145
  nodes = [node for sublist in nodes for node in sublist]
146
146
 
147
147
  # Filter to get node IDs and their coordinates
@@ -67,7 +67,7 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
67
67
 
68
68
  Args:
69
69
  graph (NetworkGraph): The network data and attributes to be visualized.
70
- figsize (tuple): Size of the figure in inches (width, height).
70
+ figsize (Tuple): Size of the figure in inches (width, height).
71
71
  background_color (str, List, Tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
72
72
  background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any existing
73
73
  alpha values found in `background_color`.
@@ -124,7 +124,7 @@ def _get_domain_colors(
124
124
  cmap: str = "gist_rainbow",
125
125
  color: Union[str, List, Tuple, np.ndarray, None] = None,
126
126
  random_seed: int = 888,
127
- ) -> Dict[str, Any]:
127
+ ) -> Dict[int, Any]:
128
128
  """Get colors for each domain.
129
129
 
130
130
  Args:
@@ -135,7 +135,7 @@ def _get_domain_colors(
135
135
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
136
136
 
137
137
  Returns:
138
- dict: A dictionary mapping domain keys to their corresponding RGBA colors.
138
+ Dict[int, Any]: A dictionary mapping domain keys to their corresponding RGBA colors.
139
139
  """
140
140
  # Get colors for each domain based on node positions
141
141
  domain_colors = _get_colors(
@@ -215,7 +215,7 @@ def _get_colors(
215
215
 
216
216
  Args:
217
217
  network (NetworkX graph): The graph representing the network.
218
- domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
218
+ domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
219
219
  cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
220
220
  color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
221
221
  If None, the colormap will be used. Defaults to None.
@@ -377,7 +377,7 @@ def to_rgba(
377
377
  if isinstance(c, str):
378
378
  # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
379
379
  rgba = np.array(mcolors.to_rgba(c))
380
- elif isinstance(c, (List, Tuple, np.ndarray)) and len(c) in [3, 4]:
380
+ elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
381
381
  # Convert RGB (3) or RGBA (4) values to RGBA format
382
382
  rgba = np.array(mcolors.to_rgba(c))
383
383
  else:
@@ -396,8 +396,8 @@ def to_rgba(
396
396
  # Handle a single color (string or RGB/RGBA list/tuple)
397
397
  if (
398
398
  isinstance(color, str)
399
- or isinstance(color, (List, Tuple, np.ndarray))
400
- and not any(isinstance(c, (str, List, Tuple, np.ndarray)) for c in color)
399
+ or isinstance(color, (list, tuple, np.ndarray))
400
+ and not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
401
401
  ):
402
402
  rgba_color = convert_to_rgba(color)
403
403
  if num_repeats:
@@ -407,7 +407,7 @@ def to_rgba(
407
407
  return np.array([rgba_color]) # Return a single color wrapped in a numpy array
408
408
 
409
409
  # Handle a list/array of colors
410
- elif isinstance(color, (List, Tuple, np.ndarray)):
410
+ elif isinstance(color, (list, tuple, np.ndarray)):
411
411
  rgba_colors = np.array(
412
412
  [convert_to_rgba(c) for c in color]
413
413
  ) # Convert each color in the list to RGBA
@@ -35,7 +35,7 @@ def calculate_centroids(network, domain_id_to_node_ids_map):
35
35
 
36
36
  Args:
37
37
  network (NetworkX graph): The graph representing the network.
38
- domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
38
+ domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
39
39
 
40
40
  Returns:
41
41
  List[Tuple[float, float]]: List of centroids (x, y) for each domain.
risk/risk.py CHANGED
@@ -3,7 +3,7 @@ risk/risk
3
3
  ~~~~~~~~~
4
4
  """
5
5
 
6
- from typing import Any, Dict, Tuple, Union
6
+ from typing import Any, Dict, List, Tuple, Union
7
7
 
8
8
  import networkx as nx
9
9
  import numpy as np
@@ -58,9 +58,9 @@ class RISK(NetworkIO, AnnotationsIO):
58
58
  self,
59
59
  network: nx.Graph,
60
60
  annotations: Dict[str, Any],
61
- distance_metric: str = "louvain",
61
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
62
62
  louvain_resolution: float = 0.1,
63
- edge_length_threshold: float = 0.5,
63
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
64
64
  null_distribution: str = "network",
65
65
  random_seed: int = 888,
66
66
  ) -> Dict[str, Any]:
@@ -68,15 +68,19 @@ class RISK(NetworkIO, AnnotationsIO):
68
68
 
69
69
  Args:
70
70
  network (nx.Graph): The network graph.
71
- annotations (dict): The annotations associated with the network.
72
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
71
+ annotations (Dict[str, Any]): The annotations associated with the network.
72
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
73
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
74
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
73
75
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
74
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
76
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
77
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
78
+ Defaults to 0.5.
75
79
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
76
80
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
77
81
 
78
82
  Returns:
79
- dict: Computed significance of neighborhoods.
83
+ Dict[str, Any]: Computed significance of neighborhoods.
80
84
  """
81
85
  log_header("Running hypergeometric test")
82
86
  # Log neighborhood analysis parameters
@@ -111,9 +115,9 @@ class RISK(NetworkIO, AnnotationsIO):
111
115
  self,
112
116
  network: nx.Graph,
113
117
  annotations: Dict[str, Any],
114
- distance_metric: str = "louvain",
118
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
115
119
  louvain_resolution: float = 0.1,
116
- edge_length_threshold: float = 0.5,
120
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
117
121
  null_distribution: str = "network",
118
122
  random_seed: int = 888,
119
123
  ) -> Dict[str, Any]:
@@ -121,15 +125,19 @@ class RISK(NetworkIO, AnnotationsIO):
121
125
 
122
126
  Args:
123
127
  network (nx.Graph): The network graph.
124
- annotations (dict): The annotations associated with the network.
125
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
128
+ annotations (Dict[str, Any]): The annotations associated with the network.
129
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
130
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
131
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
126
132
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
127
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
133
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
134
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
135
+ Defaults to 0.5.
128
136
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
129
137
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
130
138
 
131
139
  Returns:
132
- dict: Computed significance of neighborhoods.
140
+ Dict[str, Any]: Computed significance of neighborhoods.
133
141
  """
134
142
  log_header("Running Poisson test")
135
143
  # Log neighborhood analysis parameters
@@ -164,9 +172,9 @@ class RISK(NetworkIO, AnnotationsIO):
164
172
  self,
165
173
  network: nx.Graph,
166
174
  annotations: Dict[str, Any],
167
- distance_metric: str = "louvain",
175
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
168
176
  louvain_resolution: float = 0.1,
169
- edge_length_threshold: float = 0.5,
177
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
170
178
  score_metric: str = "sum",
171
179
  null_distribution: str = "network",
172
180
  num_permutations: int = 1000,
@@ -177,10 +185,14 @@ class RISK(NetworkIO, AnnotationsIO):
177
185
 
178
186
  Args:
179
187
  network (nx.Graph): The network graph.
180
- annotations (dict): The annotations associated with the network.
181
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
188
+ annotations (Dict[str, Any]): The annotations associated with the network.
189
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
190
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
191
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
182
192
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
183
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
193
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
194
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
195
+ Defaults to 0.5.
184
196
  score_metric (str, optional): Scoring metric for neighborhood significance. Defaults to "sum".
185
197
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
186
198
  num_permutations (int, optional): Number of permutations for significance testing. Defaults to 1000.
@@ -188,7 +200,7 @@ class RISK(NetworkIO, AnnotationsIO):
188
200
  max_workers (int, optional): Maximum number of workers for parallel computation. Defaults to 1.
189
201
 
190
202
  Returns:
191
- dict: Computed significance of neighborhoods.
203
+ Dict[str, Any]: Computed significance of neighborhoods.
192
204
  """
193
205
  log_header("Running permutation test")
194
206
  # Log neighborhood analysis parameters
@@ -253,7 +265,7 @@ class RISK(NetworkIO, AnnotationsIO):
253
265
  Args:
254
266
  network (nx.Graph): The network graph.
255
267
  annotations (pd.DataFrame): DataFrame containing annotation data for the network.
256
- neighborhoods (dict): Neighborhood enrichment data.
268
+ neighborhoods (Dict[str, Any]): Neighborhood enrichment data.
257
269
  tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
258
270
  pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
259
271
  fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
@@ -353,7 +365,7 @@ class RISK(NetworkIO, AnnotationsIO):
353
365
  def load_plotter(
354
366
  self,
355
367
  graph: NetworkGraph,
356
- figsize: Tuple = (10, 10),
368
+ figsize: Union[List, Tuple, np.ndarray] = (10, 10),
357
369
  background_color: str = "white",
358
370
  background_alpha: Union[float, None] = 1.0,
359
371
  pad: float = 0.3,
@@ -362,7 +374,7 @@ class RISK(NetworkIO, AnnotationsIO):
362
374
 
363
375
  Args:
364
376
  graph (NetworkGraph): The graph to plot.
365
- figsize (Tuple, optional): Size of the figure. Defaults to (10, 10).
377
+ figsize (List, Tuple, or np.ndarray, optional): Size of the plot. Defaults to (10, 10)., optional): Size of the figure. Defaults to (10, 10).
366
378
  background_color (str, optional): Background color of the plot. Defaults to "white".
367
379
  background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
368
380
  any existing alpha values found in background_color. Defaults to 1.0.
@@ -385,9 +397,9 @@ class RISK(NetworkIO, AnnotationsIO):
385
397
  def _load_neighborhoods(
386
398
  self,
387
399
  network: nx.Graph,
388
- distance_metric: str = "louvain",
400
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
389
401
  louvain_resolution: float = 0.1,
390
- edge_length_threshold: float = 0.5,
402
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
391
403
  random_seed: int = 888,
392
404
  ) -> np.ndarray:
393
405
  """Load significant neighborhoods for the network.
@@ -395,9 +407,13 @@ class RISK(NetworkIO, AnnotationsIO):
395
407
  Args:
396
408
  network (nx.Graph): The network graph.
397
409
  annotations (pd.DataFrame): The matrix of annotations associated with the network.
398
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
410
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
411
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
412
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
399
413
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
400
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
414
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
415
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
416
+ Defaults to 0.5.
401
417
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
402
418
 
403
419
  Returns:
@@ -437,13 +453,13 @@ class RISK(NetworkIO, AnnotationsIO):
437
453
 
438
454
  Args:
439
455
  network (nx.Graph): The network graph.
440
- annotations (dict): Annotations data for the network.
441
- neighborhoods (dict): Neighborhood enrichment data.
456
+ annotations (Dict[str, Any]): Annotations data for the network.
457
+ neighborhoods (Dict[str, Any]): Neighborhood enrichment data.
442
458
  min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
443
459
  max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
444
460
 
445
461
  Returns:
446
- dict: Top annotations identified within the network.
462
+ Dict[str, Any]: Top annotations identified within the network.
447
463
  """
448
464
  # Extract necessary data from annotations and neighborhoods
449
465
  ordered_annotations = annotations["ordered_annotations"]
@@ -470,7 +486,7 @@ class RISK(NetworkIO, AnnotationsIO):
470
486
  """Define domains in the network based on enrichment data.
471
487
 
472
488
  Args:
473
- neighborhoods (dict): Enrichment data for neighborhoods.
489
+ neighborhoods (Dict[str, Any]): Enrichment data for neighborhoods.
474
490
  top_annotations (pd.DataFrame): Enrichment matrix for top annotations.
475
491
  linkage_criterion (str): Clustering criterion for defining domains.
476
492
  linkage_method (str): Clustering method to use.
risk/stats/hypergeom.py CHANGED
@@ -20,7 +20,7 @@ def compute_hypergeom_test(
20
20
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
21
21
 
22
22
  Returns:
23
- dict: Dictionary containing depletion and enrichment p-values.
23
+ Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
24
24
  """
25
25
  # Get the total number of nodes in the network
26
26
  total_node_count = neighborhoods.shape[0]
@@ -35,7 +35,7 @@ def compute_permutation_test(
35
35
  max_workers (int, optional): Number of workers for multiprocessing. Defaults to 1.
36
36
 
37
37
  Returns:
38
- dict: Dictionary containing depletion and enrichment p-values.
38
+ Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
39
39
  """
40
40
  # Ensure that the matrices are in the correct format and free of NaN values
41
41
  neighborhoods = neighborhoods.astype(np.float32)
risk/stats/poisson.py CHANGED
@@ -3,7 +3,7 @@ risk/stats/poisson
3
3
  ~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
- from typing import Dict, Any
6
+ from typing import Any, Dict
7
7
 
8
8
  import numpy as np
9
9
  from scipy.stats import poisson
@@ -20,7 +20,7 @@ def compute_poisson_test(
20
20
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
21
21
 
22
22
  Returns:
23
- dict: Dictionary containing depletion and enrichment p-values.
23
+ Dict[str, Any]: Dictionary containing depletion and enrichment p-values.
24
24
  """
25
25
  # Matrix multiplication to get the number of annotated nodes in each neighborhood
26
26
  annotated_in_neighborhood = neighborhoods @ annotations
risk/stats/stats.py CHANGED
@@ -3,7 +3,7 @@ risk/stats/stats
3
3
  ~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
- from typing import Union
6
+ from typing import Any, Dict, Union
7
7
 
8
8
  import numpy as np
9
9
  from statsmodels.stats.multitest import fdrcorrection
@@ -15,7 +15,7 @@ def calculate_significance_matrices(
15
15
  tail: str = "right",
16
16
  pval_cutoff: float = 0.05,
17
17
  fdr_cutoff: float = 0.05,
18
- ) -> dict:
18
+ ) -> Dict[str, Any]:
19
19
  """Calculate significance matrices based on p-values and specified tail.
20
20
 
21
21
  Args:
@@ -26,8 +26,8 @@ def calculate_significance_matrices(
26
26
  fdr_cutoff (float, optional): Cutoff for FDR significance if applied. Defaults to 0.05.
27
27
 
28
28
  Returns:
29
- dict: Dictionary containing the enrichment matrix, binary significance matrix,
30
- and the matrix of significant enrichment values.
29
+ Dict[str, Any]: Dictionary containing the enrichment matrix, binary significance matrix,
30
+ and the matrix of significant enrichment values.
31
31
  """
32
32
  if fdr_cutoff < 1.0:
33
33
  # Apply FDR correction to depletion p-values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.8b19
3
+ Version: 0.0.8b21
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -0,0 +1,37 @@
1
+ risk/__init__.py,sha256=rwU6NXlpU-Cu31fRh5iXEwcYq9jIr2prM7bv_iAaksI,113
2
+ risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
3
+ risk/risk.py,sha256=8J2cvtXYy99PCoOHEYNXt0vcYrcHqxi2O1VvYQNNPt4,23217
4
+ risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
5
+ risk/annotations/annotations.py,sha256=KHGeF5vBDmX711nA08DfhxI9z7Z1Oaeo91ueWhM6vs8,11370
6
+ risk/annotations/io.py,sha256=powWzeimVdE0WCwlBCXyu5otMyZZHQujC0DS3m5DC0c,9505
7
+ risk/log/__init__.py,sha256=aDUz5LMFQsz0UlsQI2EdXtiBKRLfml1UMeZKC7QQIGU,134
8
+ risk/log/config.py,sha256=m8pzj-hN4vI_2JdJUfyOoSvzT8_lhoIfBt27sKbnOes,4535
9
+ risk/log/params.py,sha256=rvyg86RnkHwotST7x42RgsiYfq2HB-9BZxp6KkT_04o,6415
10
+ risk/neighborhoods/__init__.py,sha256=tKKEg4lsbqFukpgYlUGxU_v_9FOqK7V0uvM9T2QzoL0,206
11
+ risk/neighborhoods/community.py,sha256=MAgIblbuisEPwVU6mFZd4Yd9NUKlaHK99suw51r1Is0,7065
12
+ risk/neighborhoods/domains.py,sha256=DbhUFsvbr8wuvrNr7a0PaAJO-cdv6U3-T4CXB4-j5Qw,10930
13
+ risk/neighborhoods/neighborhoods.py,sha256=OPGNfeGQR533vWjger7f34ZPSgw9250LQXcTEIAhQvg,21165
14
+ risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
15
+ risk/network/geometry.py,sha256=Y3Brp0XYWoBL2VHJX7I-gW5x-q7lGiEMqr2kqtutgkQ,6811
16
+ risk/network/graph.py,sha256=-91JL84LYbdWohzybKFQ3NdWnervxP-wwbpaUOdRVLE,8576
17
+ risk/network/io.py,sha256=u0PPcKjp6Xze--7eDOlvalYkjQ9S2sjiC-ac2476PUI,22942
18
+ risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
19
+ risk/network/plot/canvas.py,sha256=P8XzcesrbxjLcrT40hf15QgLTvearER2Yid3QefQF20,10778
20
+ risk/network/plot/contour.py,sha256=vo1BeXrMKW-EipLB-9pB5AMfzmiJJduo2H_xgWUoDYo,15027
21
+ risk/network/plot/labels.py,sha256=wdIi5UfAGVeZ1UM6AAuOQ0I4dBqQqzVe1euGCdiv91o,45115
22
+ risk/network/plot/network.py,sha256=U-3oYxq-QTZolc72khmesS85pNlep6L40kIT-qZVljE,13615
23
+ risk/network/plot/plotter.py,sha256=iTPMiTnTTatM_-q1Ox_bjt5Pvv-Lo8gceiYB6TVzDcw,5770
24
+ risk/network/plot/utils/color.py,sha256=WSs1ge2oZ8yXwyVk2QqBF-avRd0aYT-sYZr9cxxAn7M,19626
25
+ risk/network/plot/utils/layout.py,sha256=5DpRLvabgnPWwVJ-J3W6oFBBvbjCrudvvW4HDOzzoTo,1960
26
+ risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
27
+ risk/stats/hypergeom.py,sha256=oc39f02ViB1vQ-uaDrxG_tzAT6dxQBRjc88EK2EGn78,2282
28
+ risk/stats/poisson.py,sha256=polLgwS08MTCNzupYdmMUoEUYrJOjAbcYtYwjlfeE5Y,1803
29
+ risk/stats/stats.py,sha256=07yMULKlCurK62x674SHKJavZtz9ge2K2ZsHix_z_pw,7088
30
+ risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
31
+ risk/stats/permutation/permutation.py,sha256=meBNSrbRa9P8WJ54n485l0H7VQJlMSfHqdN4aCKYCtQ,10105
32
+ risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
33
+ risk_network-0.0.8b21.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
34
+ risk_network-0.0.8b21.dist-info/METADATA,sha256=WAMLlJBw45mKR3apG1PxGAQfMOm9flaJxsz_h_rwnGs,47498
35
+ risk_network-0.0.8b21.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
36
+ risk_network-0.0.8b21.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
37
+ risk_network-0.0.8b21.dist-info/RECORD,,
@@ -1,37 +0,0 @@
1
- risk/__init__.py,sha256=jShwk2Z-jTDhGp_Y0GZIkJ1BoZFHOBTRGSKtBnS4Re0,113
2
- risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
3
- risk/risk.py,sha256=FQp5269IsCh-flmSWIpV7sBmvbGHjlrSy89SkImkxCE,21231
4
- risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
5
- risk/annotations/annotations.py,sha256=7ilzXxrlHqN75J3q8WeHz0n79D-jAtUQx5czvC9wfIM,11303
6
- risk/annotations/io.py,sha256=TTXVJQgUGAlKpnGBcx7Dow146IGyozA03nSbl3S7M5M,9475
7
- risk/log/__init__.py,sha256=aDUz5LMFQsz0UlsQI2EdXtiBKRLfml1UMeZKC7QQIGU,134
8
- risk/log/config.py,sha256=m8pzj-hN4vI_2JdJUfyOoSvzT8_lhoIfBt27sKbnOes,4535
9
- risk/log/params.py,sha256=lgwhtO_pQWLd2_Cpu0T7BMwH5NiA4GFW0aP6d1_rJTE,6363
10
- risk/neighborhoods/__init__.py,sha256=tKKEg4lsbqFukpgYlUGxU_v_9FOqK7V0uvM9T2QzoL0,206
11
- risk/neighborhoods/community.py,sha256=stYYBXeZlGLMV-k8ckQeIqThT6v9y-S3hETobAo9590,6817
12
- risk/neighborhoods/domains.py,sha256=DbhUFsvbr8wuvrNr7a0PaAJO-cdv6U3-T4CXB4-j5Qw,10930
13
- risk/neighborhoods/neighborhoods.py,sha256=M-wL4xB_BUTlSZg90swygO5NdrZ6hFUFqs6jsiZaqHk,18260
14
- risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
15
- risk/network/geometry.py,sha256=H1yGVVqgbfpzBzJwEheDLfvGLSA284jGQQTn612L4Vc,6759
16
- risk/network/graph.py,sha256=X63SNlWIov3oz0aMBMZfbHdmckbLqakZli5HP2Y5OdU,8519
17
- risk/network/io.py,sha256=w_9fUcZUVXAPRKGhLBc7xhIJs8l83szHiBQTdaNN0gk,22942
18
- risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
19
- risk/network/plot/canvas.py,sha256=hdrmGd2TCuii8wn6jDQfyJTI5YXDNGYFLiU4TyqAYbE,10778
20
- risk/network/plot/contour.py,sha256=ecmqNpyq512Koa14OGa58Z7EP_oUOxYfS4CuU1P9ras,15027
21
- risk/network/plot/labels.py,sha256=hoW6AL1dAUUt2WhWOEDu-Q28L_ojf13NkuxeqfHtqWc,44848
22
- risk/network/plot/network.py,sha256=nfTmQxx1YwS3taXwq8WSCfu6nfKFOyxj7T5605qLXVM,13615
23
- risk/network/plot/plotter.py,sha256=6534tKpOb2ZXn1imu_CDM_BkLaThi50eaRsJQyjsm84,5770
24
- risk/network/plot/utils/color.py,sha256=Y_AUIoj_zKzChz_aC-WoYqiZ5P4qil3k7lmm4MHDaPY,19606
25
- risk/network/plot/utils/layout.py,sha256=znssSqe2VZzzSz47hLZtTuXwMTpHR9b8lkQPL0BX7OA,1950
26
- risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
27
- risk/stats/hypergeom.py,sha256=o6Qnj31gCAKxr2uQirXrbv7XvdDJGEq69MFW-ubx_hA,2272
28
- risk/stats/poisson.py,sha256=8x9hB4DCukq4gNIlIKO-c_jYG1-BTwTX53oLauFyfj8,1793
29
- risk/stats/stats.py,sha256=kvShov-94W6ffgDUTb522vB9hDJQSyTsYif_UIaFfSM,7059
30
- risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
31
- risk/stats/permutation/permutation.py,sha256=D84Rcpt6iTQniK0PfQGcw9bLcHbMt9p-ARcurUnIXZQ,10095
32
- risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
33
- risk_network-0.0.8b19.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
34
- risk_network-0.0.8b19.dist-info/METADATA,sha256=zRIBsg3MXtdrk10GQ2Byns8YWleTkv0taFT67hCrBl4,47498
35
- risk_network-0.0.8b19.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
36
- risk_network-0.0.8b19.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
37
- risk_network-0.0.8b19.dist-info/RECORD,,