risk-network 0.0.8b20__py3-none-any.whl → 0.0.8b22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
 
  from risk.risk import RISK
 
- __version__ = "0.0.8-beta.20"
+ __version__ = "0.0.8-beta.22"
risk/annotations/__init__.py CHANGED
@@ -3,5 +3,5 @@ risk/annotations
  ~~~~~~~~~~~~~~~~
  """
 
- from .annotations import define_top_annotations, get_description
+ from .annotations import define_top_annotations, get_weighted_description
  from .io import AnnotationsIO
risk/annotations/annotations.py CHANGED
@@ -30,6 +30,8 @@ def _setup_nltk():
 
  # Ensure you have the necessary NLTK data
  _setup_nltk()
+ # Initialize English stopwords
+ stop_words = set(stopwords.words("english"))
 
 
  def load_annotations(network: nx.Graph, annotations_input: Dict[str, Any]) -> Dict[str, Any]:
@@ -47,11 +49,11 @@ def load_annotations(network: nx.Graph, annotations_input: Dict[str, Any]) -> Di
  (node, annotation) for annotation, nodes in annotations_input.items() for node in nodes
  ]
  # Create a DataFrame from the flattened list
- annotations = pd.DataFrame(flattened_annotations, columns=["Node", "Annotations"])
- annotations["Is Member"] = 1
+ annotations = pd.DataFrame(flattened_annotations, columns=["node", "annotations"])
+ annotations["is_member"] = 1
  # Pivot to create a binary matrix with nodes as rows and annotations as columns
  annotations_pivot = annotations.pivot_table(
- index="Node", columns="Annotations", values="Is Member", fill_value=0, dropna=False
+ index="node", columns="annotations", values="is_member", fill_value=0, dropna=False
  )
  # Reindex the annotations matrix based on the node labels from the network
  node_label_order = list(nx.get_node_attributes(network, "label").values())
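The renamed snake_case columns feed the same binary-membership pivot as before. A minimal sketch of that pivot with hypothetical node/annotation pairs (the real function then reindexes against the network's node labels):

import pandas as pd

# Hypothetical flattened (node, annotation) pairs
flattened_annotations = [("geneA", "transport"), ("geneB", "transport"), ("geneB", "fusion")]
annotations = pd.DataFrame(flattened_annotations, columns=["node", "annotations"])
annotations["is_member"] = 1

# Pivot into a binary membership matrix: one row per node, one column per term
annotations_pivot = annotations.pivot_table(
    index="node", columns="annotations", values="is_member", fill_value=0, dropna=False
)
# geneA belongs to {transport}; geneB belongs to {transport, fusion}
print(annotations_pivot)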
@@ -81,7 +83,8 @@ def define_top_annotations(
  network: nx.Graph,
  ordered_annotation_labels: List[str],
  neighborhood_enrichment_sums: List[int],
- binary_enrichment_matrix: np.ndarray,
+ significant_enrichment_matrix: np.ndarray,
+ significant_binary_enrichment_matrix: np.ndarray,
  min_cluster_size: int = 5,
  max_cluster_size: int = 1000,
  ) -> pd.DataFrame:
@@ -91,42 +94,52 @@ def define_top_annotations(
  network (NetworkX graph): The network graph.
  ordered_annotation_labels (list of str): List of ordered annotation labels.
  neighborhood_enrichment_sums (list of int): List of neighborhood enrichment sums.
- binary_enrichment_matrix (np.ndarray): Binary enrichment matrix below alpha threshold.
+ significant_enrichment_matrix (np.ndarray): Enrichment matrix below alpha threshold.
+ significant_binary_enrichment_matrix (np.ndarray): Binary enrichment matrix below alpha threshold.
  min_cluster_size (int, optional): Minimum cluster size. Defaults to 5.
  max_cluster_size (int, optional): Maximum cluster size. Defaults to 1000.
 
  Returns:
  pd.DataFrame: DataFrame with top annotations and their properties.
  """
- # Create DataFrame to store annotations and their neighborhood enrichment sums
+ # Sum the columns of the significant enrichment matrix (positive floating point values)
+ significant_enrichment_scores = significant_enrichment_matrix.sum(axis=0)
+ # Create DataFrame to store annotations, their neighborhood enrichment sums, and enrichment scores
  annotations_enrichment_matrix = pd.DataFrame(
  {
  "id": range(len(ordered_annotation_labels)),
- "words": ordered_annotation_labels,
- "neighborhood enrichment sums": neighborhood_enrichment_sums,
+ "full_terms": ordered_annotation_labels,
+ "significant_neighborhood_enrichment_sums": neighborhood_enrichment_sums,
+ "significant_enrichment_score": significant_enrichment_scores,
  }
  )
- annotations_enrichment_matrix["top attributes"] = False
- # Apply size constraints to identify potential top attributes
+ annotations_enrichment_matrix["significant_annotations"] = False
+ # Apply size constraints to identify potential significant annotations
  annotations_enrichment_matrix.loc[
- (annotations_enrichment_matrix["neighborhood enrichment sums"] >= min_cluster_size)
- & (annotations_enrichment_matrix["neighborhood enrichment sums"] <= max_cluster_size),
- "top attributes",
+ (
+ annotations_enrichment_matrix["significant_neighborhood_enrichment_sums"]
+ >= min_cluster_size
+ )
+ & (
+ annotations_enrichment_matrix["significant_neighborhood_enrichment_sums"]
+ <= max_cluster_size
+ ),
+ "significant_annotations",
  ] = True
  # Initialize columns for connected components analysis
- annotations_enrichment_matrix["num connected components"] = 0
- annotations_enrichment_matrix["size connected components"] = None
- annotations_enrichment_matrix["size connected components"] = annotations_enrichment_matrix[
- "size connected components"
+ annotations_enrichment_matrix["num_connected_components"] = 0
+ annotations_enrichment_matrix["size_connected_components"] = None
+ annotations_enrichment_matrix["size_connected_components"] = annotations_enrichment_matrix[
+ "size_connected_components"
  ].astype(object)
- annotations_enrichment_matrix["num large connected components"] = 0
+ annotations_enrichment_matrix["num_large_connected_components"] = 0
 
  for attribute in annotations_enrichment_matrix.index.values[
- annotations_enrichment_matrix["top attributes"]
+ annotations_enrichment_matrix["significant_annotations"]
  ]:
  # Identify enriched neighborhoods based on the binary enrichment matrix
  enriched_neighborhoods = list(
- compress(list(network), binary_enrichment_matrix[:, attribute])
+ compress(list(network), significant_binary_enrichment_matrix[:, attribute])
  )
  enriched_network = nx.subgraph(network, enriched_neighborhoods)
  # Analyze connected components within the enriched subnetwork
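The `compress(list(network), ...)` idiom above selects the nodes whose binary column is 1 and then checks how many connected components they form. A small self-contained illustration with a hypothetical path graph:

import networkx as nx
from itertools import compress

# Hypothetical 4-node path graph and one column of significant_binary_enrichment_matrix
network = nx.path_graph(4)
enriched_mask = [1, 1, 0, 1]

enriched_neighborhoods = list(compress(list(network), enriched_mask))
enriched_network = nx.subgraph(network, enriched_neighborhoods)

# Nodes 0-1 are connected; node 3 is cut off from them, so two components remain
print(nx.number_connected_components(enriched_network))  # 2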
@@ -145,55 +158,67 @@ def define_top_annotations(
  num_large_connected_components = len(filtered_size_connected_components)
 
  # Assign the number of connected components
- annotations_enrichment_matrix.loc[attribute, "num connected components"] = (
+ annotations_enrichment_matrix.loc[attribute, "num_connected_components"] = (
  num_connected_components
  )
  # Filter out attributes with more than one connected component
  annotations_enrichment_matrix.loc[
- annotations_enrichment_matrix["num connected components"] > 1, "top attributes"
+ annotations_enrichment_matrix["num_connected_components"] > 1, "significant_annotations"
  ] = False
  # Assign the number of large connected components
- annotations_enrichment_matrix.loc[attribute, "num large connected components"] = (
+ annotations_enrichment_matrix.loc[attribute, "num_large_connected_components"] = (
  num_large_connected_components
  )
  # Assign the size of connected components, ensuring it is always a list
- annotations_enrichment_matrix.at[attribute, "size connected components"] = (
+ annotations_enrichment_matrix.at[attribute, "size_connected_components"] = (
  filtered_size_connected_components.tolist()
  )
 
  return annotations_enrichment_matrix
 
 
- def get_description(words_column: pd.Series) -> str:
- """Process input Series to identify and return the top frequent, significant words,
- filtering based on stopwords and gracefully handling numerical strings.
+ def get_weighted_description(words_column: pd.Series, scores_column: pd.Series) -> str:
+ """Generate a weighted description from words and their corresponding scores,
+ with support for stopwords filtering and improved weighting logic.
 
  Args:
  words_column (pd.Series): A pandas Series containing strings to process.
+ scores_column (pd.Series): A pandas Series containing enrichment scores to weigh the terms.
 
  Returns:
- str: A coherent description formed from the most frequent and significant words.
+ str: A coherent description formed from the most frequent and significant words, weighed by enrichment scores.
  """
- # Concatenate all rows into a single string and tokenize into words
- all_words = words_column.str.cat(sep=" ")
- tokens = word_tokenize(all_words)
+ # Handle case where all scores are the same
+ if scores_column.max() == scores_column.min():
+ normalized_scores = pd.Series([1] * len(scores_column))
+ else:
+ # Normalize the enrichment scores to be between 0 and 1
+ normalized_scores = (scores_column - scores_column.min()) / (
+ scores_column.max() - scores_column.min()
+ )
 
+ # Combine words and normalized scores to create weighted words
+ weighted_words = []
+ for word, score in zip(words_column, normalized_scores):
+ word = str(word)
+ if word not in stop_words: # Skip stopwords
+ weight = max(1, int((0 if pd.isna(score) else score) * 10))
+ weighted_words.extend([word] * weight)
+
+ # Tokenize the weighted words
+ tokens = word_tokenize(" ".join(weighted_words))
  # Separate numeric tokens
  numeric_tokens = [token for token in tokens if token.replace(".", "", 1).isdigit()]
- # If there's only one unique numeric value, return it directly as a string
  unique_numeric_values = set(numeric_tokens)
  if len(unique_numeric_values) == 1:
  return f"{list(unique_numeric_values)[0]}"
 
- # Ensure that all values in 'words' are strings and include both alphabetic and numeric tokens
- words = [
- str(word) # Convert to string to ensure consistent processing
- for word in tokens
- if word.isalpha()
- or word.replace(".", "", 1).isdigit() # Keep alphabetic words and numeric strings
- ]
+ # Filter alphabetic and numeric tokens
+ words = [word for word in tokens if word.isalpha() or word.replace(".", "", 1).isdigit()]
+ # Apply word similarity filtering to remove redundant terms
+ simplified_words = _simplify_word_list(words)
  # Generate a coherent description from the processed words
- description = _generate_coherent_description(words)
+ description = _generate_coherent_description(simplified_words)
 
  return description
224
 
risk/neighborhoods/domains.py CHANGED
@@ -13,7 +13,7 @@ import pandas as pd
  from scipy.cluster.hierarchy import linkage, fcluster
  from sklearn.metrics import silhouette_score
 
- from risk.annotations import get_description
+ from risk.annotations import get_weighted_description
  from risk.constants import GROUP_LINKAGE_METHODS, GROUP_DISTANCE_METRICS
  from risk.log import logger
 
@@ -40,7 +40,7 @@ def define_domains(
  """
  try:
  # Transpose the matrix to cluster annotations
- m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
+ m = significant_neighborhoods_enrichment[:, top_annotations["significant_annotations"]].T
  best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
  m, linkage_criterion, linkage_method, linkage_metric
  )
@@ -55,7 +55,7 @@ def define_domains(
  # Assign domains to the annotations matrix
  domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
  top_annotations["domain"] = 0
- top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
+ top_annotations.loc[top_annotations["significant_annotations"], "domain"] = domains
  except ValueError:
  # If a ValueError is encountered, handle it by assigning unique domains
  n_rows = len(top_annotations)
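For reference, the linkage/fcluster pair used here cuts a hierarchical clustering of annotation profiles at a distance threshold; the renamed significant_annotations mask only changes which rows get clustered. A toy run under assumed method/metric values:

import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

# Hypothetical binary enrichment profiles for five annotations (rows)
m = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 1], [0, 1, 1], [1, 0, 0]])
Z = linkage(m, method="average", metric="euclidean")

# Cut the dendrogram at distance 1.0 to get one domain label per annotation
domains = fcluster(Z, 1.0, criterion="distance")
# Annotations 0, 1, and 4 share a domain; 2 and 3 each get their own
print(domains)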
@@ -77,11 +77,11 @@ def define_domains(
  t_idxmax[t_max == 0] = 0
 
  # Assign all domains where the score is greater than 0
- node_to_domain["all domains"] = node_to_domain.loc[:, 1:].apply(
+ node_to_domain["all_domains"] = node_to_domain.loc[:, 1:].apply(
  lambda row: list(row[row > 0].index), axis=1
  )
  # Assign primary domain
- node_to_domain["primary domain"] = t_idxmax
+ node_to_domain["primary_domain"] = t_idxmax
 
  return node_to_domain
 
@@ -107,7 +107,7 @@ def trim_domains_and_top_annotations(
  - A DataFrame with domain labels (pd.DataFrame)
  """
  # Identify domains to remove based on size criteria
- domain_counts = domains["primary domain"].value_counts()
+ domain_counts = domains["primary_domain"].value_counts()
  to_remove = set(
  domain_counts[(domain_counts < min_cluster_size) | (domain_counts > max_cluster_size)].index
  )
@@ -117,32 +117,51 @@ def trim_domains_and_top_annotations(
  invalid_domain_ids = {0, invalid_domain_id}
  # Mark domains to be removed
  top_annotations["domain"].replace(to_remove, invalid_domain_id, inplace=True)
- domains.loc[domains["primary domain"].isin(to_remove), ["primary domain"]] = invalid_domain_id
+ domains.loc[domains["primary_domain"].isin(to_remove), ["primary_domain"]] = invalid_domain_id
 
  # Normalize "num enriched neighborhoods" by percentile for each domain and scale to 0-10
  top_annotations["normalized_value"] = top_annotations.groupby("domain")[
- "neighborhood enrichment sums"
+ "significant_neighborhood_enrichment_sums"
  ].transform(lambda x: (x.rank(pct=True) * 10).apply(np.ceil).astype(int))
- # Multiply 'words' column by normalized values
- top_annotations["words"] = top_annotations.apply(
- lambda row: " ".join([str(row["words"])] * row["normalized_value"]), axis=1
+ # Modify the lambda function to pass both full_terms and significant_enrichment_score
+ top_annotations["combined_terms"] = top_annotations.apply(
+ lambda row: " ".join([str(row["full_terms"])] * row["normalized_value"]), axis=1
  )
 
- # Generate domain labels
- domain_labels = top_annotations.groupby("domain")["words"].apply(get_description).reset_index()
+ # Perform the groupby operation while retaining the other columns and adding the weighting with enrichment scores
+ domain_labels = (
+ top_annotations.groupby("domain")
+ .agg(
+ full_terms=("full_terms", lambda x: list(x)),
+ enrichment_scores=("significant_enrichment_score", lambda x: list(x)),
+ )
+ .reset_index()
+ )
+ domain_labels["combined_terms"] = domain_labels.apply(
+ lambda row: get_weighted_description(
+ pd.Series(row["full_terms"]), pd.Series(row["enrichment_scores"])
+ ),
+ axis=1,
+ )
+
+ # Rename the columns as necessary
  trimmed_domains_matrix = domain_labels.rename(
- columns={"domain": "id", "words": "label"}
+ columns={
+ "domain": "id",
+ "combined_terms": "normalized_description",
+ "full_terms": "full_descriptions",
+ "enrichment_scores": "enrichment_scores",
+ }
  ).set_index("id")
 
  # Remove invalid domains
  valid_annotations = top_annotations[~top_annotations["domain"].isin(invalid_domain_ids)].drop(
  columns=["normalized_value"]
  )
- valid_domains = domains[~domains["primary domain"].isin(invalid_domain_ids)]
+ valid_domains = domains[~domains["primary_domain"].isin(invalid_domain_ids)]
  valid_trimmed_domains_matrix = trimmed_domains_matrix[
  ~trimmed_domains_matrix.index.isin(invalid_domain_ids)
  ]
- 
  return valid_annotations, valid_domains, valid_trimmed_domains_matrix
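The reworked label step is a plain named-aggregation groupby: per domain it retains the raw terms and their scores as lists, then hands both Series to get_weighted_description. A minimal sketch of the aggregation with hypothetical data:

import pandas as pd

top_annotations = pd.DataFrame(
    {
        "domain": [1, 1, 2],
        "full_terms": ["golgi transport", "vesicle fusion", "ribosome assembly"],
        "significant_enrichment_score": [4.2, 7.9, 3.1],
    }
)

# Per-domain lists of terms and scores, ready for the weighted-description call
domain_labels = (
    top_annotations.groupby("domain")
    .agg(
        full_terms=("full_terms", lambda x: list(x)),
        enrichment_scores=("significant_enrichment_score", lambda x: list(x)),
    )
    .reset_index()
)
# domain 1 -> ["golgi transport", "vesicle fusion"] with scores [4.2, 7.9]
print(domain_labels)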
risk/neighborhoods/neighborhoods.py CHANGED
@@ -171,7 +171,7 @@ def process_neighborhoods(
 
  Args:
  network (nx.Graph): The network data structure used for imputing and pruning neighbors.
- neighborhoods (Dict[str, Any]): Dictionary containing 'enrichment_matrix', 'binary_enrichment_matrix', and 'significant_enrichment_matrix'.
+ neighborhoods (Dict[str, Any]): Dictionary containing 'enrichment_matrix', 'significant_binary_enrichment_matrix', and 'significant_enrichment_matrix'.
  impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
  prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
 
@@ -179,18 +179,18 @@ def process_neighborhoods(
  Dict[str, Any]: Processed neighborhoods data, including the updated matrices and enrichment counts.
  """
  enrichment_matrix = neighborhoods["enrichment_matrix"]
- binary_enrichment_matrix = neighborhoods["binary_enrichment_matrix"]
+ significant_binary_enrichment_matrix = neighborhoods["significant_binary_enrichment_matrix"]
  significant_enrichment_matrix = neighborhoods["significant_enrichment_matrix"]
  logger.debug(f"Imputation depth: {impute_depth}")
  if impute_depth:
  (
  enrichment_matrix,
- binary_enrichment_matrix,
+ significant_binary_enrichment_matrix,
  significant_enrichment_matrix,
  ) = _impute_neighbors(
  network,
  enrichment_matrix,
- binary_enrichment_matrix,
+ significant_binary_enrichment_matrix,
  max_depth=impute_depth,
  )
 
@@ -198,20 +198,20 @@ def process_neighborhoods(
  if prune_threshold:
  (
  enrichment_matrix,
- binary_enrichment_matrix,
+ significant_binary_enrichment_matrix,
  significant_enrichment_matrix,
  ) = _prune_neighbors(
  network,
  enrichment_matrix,
- binary_enrichment_matrix,
+ significant_binary_enrichment_matrix,
  distance_threshold=prune_threshold,
  )
 
- neighborhood_enrichment_counts = np.sum(binary_enrichment_matrix, axis=0)
+ neighborhood_enrichment_counts = np.sum(significant_binary_enrichment_matrix, axis=0)
  node_enrichment_sums = np.sum(enrichment_matrix, axis=1)
  return {
  "enrichment_matrix": enrichment_matrix,
- "binary_enrichment_matrix": binary_enrichment_matrix,
+ "significant_binary_enrichment_matrix": significant_binary_enrichment_matrix,
  "significant_enrichment_matrix": significant_enrichment_matrix,
  "neighborhood_enrichment_counts": neighborhood_enrichment_counts,
  "node_enrichment_sums": node_enrichment_sums,
@@ -221,7 +221,7 @@ def process_neighborhoods(
  def _impute_neighbors(
  network: nx.Graph,
  enrichment_matrix: np.ndarray,
- binary_enrichment_matrix: np.ndarray,
+ significant_binary_enrichment_matrix: np.ndarray,
  max_depth: int = 3,
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Impute rows with sums of zero in the enrichment matrix based on the closest non-zero neighbors in the network graph.
@@ -229,7 +229,7 @@ def _impute_neighbors(
  Args:
  network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
  enrichment_matrix (np.ndarray): The enrichment matrix with rows to be imputed.
- binary_enrichment_matrix (np.ndarray): The alpha threshold matrix to be imputed similarly.
+ significant_binary_enrichment_matrix (np.ndarray): The alpha threshold matrix to be imputed similarly.
  max_depth (int): Maximum depth of nodes to traverse for imputing values.
 
  Returns:
@@ -239,19 +239,21 @@ def _impute_neighbors(
  - np.ndarray: The significant enrichment matrix with non-significant entries set to zero.
  """
  # Calculate the distance threshold value based on the shortest distances
- enrichment_matrix, binary_enrichment_matrix = _impute_neighbors_with_similarity(
- network, enrichment_matrix, binary_enrichment_matrix, max_depth=max_depth
+ enrichment_matrix, significant_binary_enrichment_matrix = _impute_neighbors_with_similarity(
+ network, enrichment_matrix, significant_binary_enrichment_matrix, max_depth=max_depth
  )
  # Create a matrix where non-significant entries are set to zero
- significant_enrichment_matrix = np.where(binary_enrichment_matrix == 1, enrichment_matrix, 0)
+ significant_enrichment_matrix = np.where(
+ significant_binary_enrichment_matrix == 1, enrichment_matrix, 0
+ )
 
- return enrichment_matrix, binary_enrichment_matrix, significant_enrichment_matrix
+ return enrichment_matrix, significant_binary_enrichment_matrix, significant_enrichment_matrix
 
 
  def _impute_neighbors_with_similarity(
  network: nx.Graph,
  enrichment_matrix: np.ndarray,
- binary_enrichment_matrix: np.ndarray,
+ significant_binary_enrichment_matrix: np.ndarray,
  max_depth: int = 3,
  ) -> Tuple[np.ndarray, np.ndarray]:
  """Impute non-enriched nodes based on the closest enriched neighbors' profiles and their similarity.
@@ -259,7 +261,7 @@ def _impute_neighbors_with_similarity(
  Args:
  network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
  enrichment_matrix (np.ndarray): The enrichment matrix with rows to be imputed.
- binary_enrichment_matrix (np.ndarray): The alpha threshold matrix to be imputed similarly.
+ significant_binary_enrichment_matrix (np.ndarray): The alpha threshold matrix to be imputed similarly.
  max_depth (int): Maximum depth of nodes to traverse for imputing values.
 
  Returns:
@@ -268,27 +270,31 @@ def _impute_neighbors_with_similarity(
  - The imputed alpha threshold matrix.
  """
  depth = 1
- rows_to_impute = np.where(binary_enrichment_matrix.sum(axis=1) == 0)[0]
+ rows_to_impute = np.where(significant_binary_enrichment_matrix.sum(axis=1) == 0)[0]
  while len(rows_to_impute) and depth <= max_depth:
  # Iterate over all enriched nodes
- for row_index in range(binary_enrichment_matrix.shape[0]):
- if binary_enrichment_matrix[row_index].sum() != 0:
- enrichment_matrix, binary_enrichment_matrix = _process_node_imputation(
- row_index, network, enrichment_matrix, binary_enrichment_matrix, depth
+ for row_index in range(significant_binary_enrichment_matrix.shape[0]):
+ if significant_binary_enrichment_matrix[row_index].sum() != 0:
+ enrichment_matrix, significant_binary_enrichment_matrix = _process_node_imputation(
+ row_index,
+ network,
+ enrichment_matrix,
+ significant_binary_enrichment_matrix,
+ depth,
  )
 
  # Update rows to impute for the next iteration
- rows_to_impute = np.where(binary_enrichment_matrix.sum(axis=1) == 0)[0]
+ rows_to_impute = np.where(significant_binary_enrichment_matrix.sum(axis=1) == 0)[0]
  depth += 1
 
- return enrichment_matrix, binary_enrichment_matrix
+ return enrichment_matrix, significant_binary_enrichment_matrix
 
 
  def _process_node_imputation(
  row_index: int,
  network: nx.Graph,
  enrichment_matrix: np.ndarray,
- binary_enrichment_matrix: np.ndarray,
+ significant_binary_enrichment_matrix: np.ndarray,
  depth: int,
  ) -> Tuple[np.ndarray, np.ndarray]:
  """Process the imputation for a single node based on its enriched neighbors.
@@ -297,7 +303,7 @@ def _process_node_imputation(
  row_index (int): The index of the enriched node being processed.
  network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
  enrichment_matrix (np.ndarray): The enrichment matrix with rows to be imputed.
- binary_enrichment_matrix (np.ndarray): The alpha threshold matrix to be imputed similarly.
+ significant_binary_enrichment_matrix (np.ndarray): The alpha threshold matrix to be imputed similarly.
  depth (int): Current depth for traversal.
 
  Returns:
@@ -310,7 +316,7 @@ def _process_node_imputation(
  n
  for n in neighbors
  if n != row_index
- and binary_enrichment_matrix[n].sum() != 0
+ and significant_binary_enrichment_matrix[n].sum() != 0
  and enrichment_matrix[n].sum() != 0
  ]
  # Filter non-enriched neighbors
@@ -318,7 +324,7 @@ def _process_node_imputation(
  n
  for n in neighbors
  if n != row_index
- and binary_enrichment_matrix[n].sum() == 0
+ and significant_binary_enrichment_matrix[n].sum() == 0
  and enrichment_matrix[n].sum() == 0
  ]
  # If there are valid non-enriched neighbors
@@ -363,15 +369,17 @@ def _process_node_imputation(
  enrichment_matrix[most_similar_neighbor] = enrichment_matrix[row_index] / np.sqrt(
  depth + 1
  )
- binary_enrichment_matrix[most_similar_neighbor] = binary_enrichment_matrix[row_index]
+ significant_binary_enrichment_matrix[most_similar_neighbor] = (
+ significant_binary_enrichment_matrix[row_index]
+ )
 
- return enrichment_matrix, binary_enrichment_matrix
+ return enrichment_matrix, significant_binary_enrichment_matrix
 
 
  def _prune_neighbors(
  network: nx.Graph,
  enrichment_matrix: np.ndarray,
- binary_enrichment_matrix: np.ndarray,
+ significant_binary_enrichment_matrix: np.ndarray,
  distance_threshold: float = 0.9,
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Remove outliers based on their rank for edge lengths.
@@ -379,7 +387,7 @@ def _prune_neighbors(
  Args:
  network (nx.Graph): The network graph with nodes having IDs matching the matrix indices.
  enrichment_matrix (np.ndarray): The enrichment matrix.
- binary_enrichment_matrix (np.ndarray): The alpha threshold matrix.
+ significant_binary_enrichment_matrix (np.ndarray): The alpha threshold matrix.
  distance_threshold (float): Rank threshold (0 to 1) to determine outliers.
 
  Returns:
@@ -389,10 +397,12 @@ def _prune_neighbors(
  - np.ndarray: The significant enrichment matrix, where non-significant entries are set to zero.
  """
  # Identify indices with non-zero rows in the binary enrichment matrix
- non_zero_indices = np.where(binary_enrichment_matrix.sum(axis=1) != 0)[0]
+ non_zero_indices = np.where(significant_binary_enrichment_matrix.sum(axis=1) != 0)[0]
  median_distances = []
  for node in non_zero_indices:
- neighbors = [n for n in network.neighbors(node) if binary_enrichment_matrix[n].sum() != 0]
+ neighbors = [
+ n for n in network.neighbors(node) if significant_binary_enrichment_matrix[n].sum() != 0
+ ]
  if neighbors:
  median_distance = np.median(
  [_get_euclidean_distance(node, n, network) for n in neighbors]
@@ -404,7 +414,9 @@ def _prune_neighbors(
  # Prune nodes that are outliers based on the distance threshold
  for row_index in non_zero_indices:
  neighbors = [
- n for n in network.neighbors(row_index) if binary_enrichment_matrix[n].sum() != 0
+ n
+ for n in network.neighbors(row_index)
+ if significant_binary_enrichment_matrix[n].sum() != 0
  ]
  if neighbors:
  median_distance = np.median(
@@ -412,12 +424,14 @@ def _prune_neighbors(
  )
  if median_distance >= distance_threshold_value:
  enrichment_matrix[row_index] = 0
- binary_enrichment_matrix[row_index] = 0
+ significant_binary_enrichment_matrix[row_index] = 0
 
  # Create a matrix where non-significant entries are set to zero
- significant_enrichment_matrix = np.where(binary_enrichment_matrix == 1, enrichment_matrix, 0)
+ significant_enrichment_matrix = np.where(
+ significant_binary_enrichment_matrix == 1, enrichment_matrix, 0
+ )
 
- return enrichment_matrix, binary_enrichment_matrix, significant_enrichment_matrix
+ return enrichment_matrix, significant_binary_enrichment_matrix, significant_enrichment_matrix
 
 
  def _get_euclidean_distance(node1: Any, node2: Any, network: nx.Graph) -> float:
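The pruning pass zeroes out rows whose median neighbor distance sits at or above distance_threshold_value. The diff does not show how that cutoff is derived from the rank threshold, so the percentile call below is an assumption; the flagging logic itself mirrors the code above:

import numpy as np

# Hypothetical median edge lengths for four enriched nodes
median_distances = np.array([0.2, 0.25, 0.3, 1.5])

# Assumed rank-based cutoff: the 90th-percentile median distance
distance_threshold_value = np.quantile(median_distances, 0.9)

# Only the 1.5 outlier meets the pruning condition
print(median_distances >= distance_threshold_value)  # [False False False  True]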
risk/network/graph.py CHANGED
@@ -45,6 +45,10 @@ class NetworkGraph:
  self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
  trimmed_domains
  )
+ self.domain_id_to_domain_info_map = self._create_domain_id_to_domain_info_map(
+ trimmed_domains
+ )
+ self.trimmed_domains = trimmed_domains
  self.node_enrichment_sums = node_enrichment_sums
  self.node_id_to_domain_ids_and_enrichments_map = (
  self._create_node_id_to_domain_ids_and_enrichments(domains)
@@ -60,7 +64,8 @@ class NetworkGraph:
  self.network = _unfold_sphere_to_plane(network)
  self.node_coordinates = _extract_node_coordinates(self.network)
 
- def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[int, Any]:
+ @staticmethod
+ def _create_domain_id_to_node_ids_map(domains: pd.DataFrame) -> Dict[int, Any]:
  """Create a mapping from domains to the list of node IDs belonging to each domain.
 
  Args:
@@ -69,17 +74,16 @@ class NetworkGraph:
  Returns:
  Dict[int, Any]: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
  """
- cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
- node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
+ cleaned_domains_matrix = domains.reset_index()[["index", "primary_domain"]]
+ node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary_domain"].to_dict()
  domain_id_to_node_ids_map = defaultdict(list)
  for k, v in node_to_domains_map.items():
  domain_id_to_node_ids_map[v].append(k)
 
  return domain_id_to_node_ids_map
 
- def _create_domain_id_to_domain_terms_map(
- self, trimmed_domains: pd.DataFrame
- ) -> Dict[int, Any]:
+ @staticmethod
+ def _create_domain_id_to_domain_terms_map(trimmed_domains: pd.DataFrame) -> Dict[int, Any]:
  """Create a mapping from domain IDs to their corresponding terms.
 
  Args:
@@ -91,13 +95,32 @@ class NetworkGraph:
  return dict(
  zip(
  trimmed_domains.index,
- trimmed_domains["label"],
+ trimmed_domains["normalized_description"],
  )
  )
 
- def _create_node_id_to_domain_ids_and_enrichments(
- self, domains: pd.DataFrame
- ) -> Dict[int, Dict]:
+ @staticmethod
+ def _create_domain_id_to_domain_info_map(
+ trimmed_domains: pd.DataFrame,
+ ) -> Dict[int, Dict[str, Any]]:
+ """Create a mapping from domain IDs to their corresponding full description and enrichment score.
+
+ Args:
+ trimmed_domains (pd.DataFrame): DataFrame containing domain IDs, full descriptions, and enrichment scores.
+
+ Returns:
+ Dict[int, Dict[str, Any]]: A dictionary mapping domain IDs (int) to a dictionary with 'full_descriptions' and 'enrichment_scores'.
+ """
+ return {
+ int(id_): {
+ "full_descriptions": trimmed_domains.at[id_, "full_descriptions"],
+ "enrichment_scores": trimmed_domains.at[id_, "enrichment_scores"],
+ }
+ for id_ in trimmed_domains.index
+ }
+
+ @staticmethod
+ def _create_node_id_to_domain_ids_and_enrichments(domains: pd.DataFrame) -> Dict[int, Dict]:
  """Creates a dictionary mapping each node ID to its corresponding domain IDs and enrichment values.
 
  Args:
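The new _create_domain_id_to_domain_info_map is a dict comprehension over the trimmed-domains frame. Exercised in isolation with a hypothetical two-domain frame:

import pandas as pd

trimmed_domains = pd.DataFrame(
    {
        "full_descriptions": [["golgi transport", "vesicle fusion"], ["ribosome assembly"]],
        "enrichment_scores": [[4.2, 7.9], [3.1]],
    },
    index=[1, 2],
)

domain_id_to_domain_info_map = {
    int(id_): {
        "full_descriptions": trimmed_domains.at[id_, "full_descriptions"],
        "enrichment_scores": trimmed_domains.at[id_, "enrichment_scores"],
    }
    for id_ in trimmed_domains.index
}
print(domain_id_to_domain_info_map[1]["enrichment_scores"])  # [4.2, 7.9]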
risk/network/io.py CHANGED
@@ -491,7 +491,7 @@ class NetworkIO:
  if "x" not in attrs or "y" not in attrs:
  if (
  "pos" in attrs
- and isinstance(attrs["pos"], (List, Tuple, np.ndarray))
+ and isinstance(attrs["pos"], (list, tuple, np.ndarray))
  and len(attrs["pos"]) >= 2
  ):
  attrs["x"], attrs["y"] = attrs["pos"][
risk/network/plot/canvas.py CHANGED
@@ -137,18 +137,12 @@ class Canvas:
  perimeter_linestyle=linestyle,
  perimeter_linewidth=linewidth,
  perimeter_color=(
- "custom" if isinstance(color, (List, Tuple, np.ndarray)) else color
+ "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
  ), # np.ndarray usually indicates custom colors
  perimeter_outline_alpha=outline_alpha,
  perimeter_fill_alpha=fill_alpha,
  )
 
- # Convert color to RGBA using the to_rgba helper function - use outline_alpha for the perimeter
- color = to_rgba(
- color=color, alpha=outline_alpha, num_repeats=1
- ) # num_repeats=1 for a single color
- # Set the fill_alpha to 0 if not provided
- fill_alpha = fill_alpha if fill_alpha is not None else 0.0
  # Extract node coordinates from the network graph
  node_coordinates = self.graph.node_coordinates
  # Calculate the center and radius of the bounding box around the network
@@ -156,20 +150,26 @@ class Canvas:
  # Scale the radius by the scale factor
  scaled_radius = radius * scale
 
+ # Convert color to RGBA using the to_rgba helper function - use outline_alpha for the perimeter
+ outline_color_rgba = to_rgba(
+ color=color, alpha=outline_alpha, num_repeats=1
+ ) # num_repeats=1 for a single color
+ fill_color_rgba = to_rgba(
+ color=color, alpha=fill_alpha, num_repeats=1
+ ) # num_repeats=1 for a single color
+
  # Draw a circle to represent the network perimeter
  circle = plt.Circle(
  center,
  scaled_radius,
  linestyle=linestyle,
  linewidth=linewidth,
- color=color,
- fill=fill_alpha > 0, # Fill the circle if fill_alpha is greater than 0
+ color=outline_color_rgba,
  )
  # Set the transparency of the fill if applicable
- if fill_alpha > 0:
- circle.set_facecolor(
- to_rgba(color=color, alpha=fill_alpha, num_repeats=1)
- ) # num_repeats=1 for a single color
+ circle.set_facecolor(
+ to_rgba(color=fill_color_rgba, num_repeats=1)
+ ) # num_repeats=1 for a single color
 
  self.ax.add_artist(circle)
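The refactor draws the perimeter with two independent RGBA colors instead of a shared color plus a fill flag, so outline_alpha and fill_alpha can no longer interfere. The same pattern, with matplotlib's own to_rgba standing in for the package's helper:

import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba

fig, ax = plt.subplots()

# Independent alphas for the outline and the fill
outline_color_rgba = to_rgba("black", alpha=1.0)
fill_color_rgba = to_rgba("black", alpha=0.1)

circle = plt.Circle((0.5, 0.5), 0.4, linestyle="solid", linewidth=1.5, color=outline_color_rgba)
circle.set_facecolor(fill_color_rgba)  # fill no longer depends on a fill_alpha > 0 gate
ax.add_artist(circle)
plt.show()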
@@ -210,13 +210,13 @@ class Canvas:
  perimeter_grid_size=grid_size,
  perimeter_linestyle=linestyle,
  perimeter_linewidth=linewidth,
- perimeter_color=("custom" if isinstance(color, (List, Tuple, np.ndarray)) else color),
+ perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
  perimeter_outline_alpha=outline_alpha,
  perimeter_fill_alpha=fill_alpha,
  )
 
  # Convert color to RGBA using outline_alpha for the line (outline)
- outline_color = to_rgba(color=color, num_repeats=1) # num_repeats=1 for a single color
+ outline_color_rgba = to_rgba(color=color, num_repeats=1) # num_repeats=1 for a single color
  # Extract node coordinates from the network graph
  node_coordinates = self.graph.node_coordinates
  # Scale the node coordinates if needed
@@ -229,9 +229,8 @@ class Canvas:
  levels=levels,
  bandwidth=bandwidth,
  grid_size=grid_size,
- color=outline_color,
+ color=outline_color_rgba,
  linestyle=linestyle,
  linewidth=linewidth,
- alpha=outline_alpha,
  fill_alpha=fill_alpha,
  )
risk/network/plot/contour.py CHANGED
@@ -68,13 +68,15 @@ class Contour:
  )
 
  # Ensure color is converted to RGBA with repetition matching the number of domains
- color = to_rgba(
+ color_rgba = to_rgba(
  color=color, alpha=alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
  )
  # Extract node coordinates from the network graph
  node_coordinates = self.graph.node_coordinates
  # Draw contours for each domain in the network
  for idx, (_, node_ids) in enumerate(self.graph.domain_id_to_node_ids_map.items()):
+ # Use the provided alpha value if it's not None, otherwise use the color's alpha
+ current_fill_alpha = fill_alpha if fill_alpha is not None else color_rgba[idx][3]
  if len(node_ids) > 1:
  self._draw_kde_contour(
  self.ax,
@@ -86,8 +88,7 @@ class Contour:
  grid_size=grid_size,
  linestyle=linestyle,
  linewidth=linewidth,
- alpha=alpha,
- fill_alpha=fill_alpha,
+ fill_alpha=current_fill_alpha,
  )
 
  def plot_subcontour(
@@ -122,7 +123,7 @@ class Contour:
  ValueError: If no valid nodes are found in the network graph.
  """
  # Check if nodes is a list of lists or a flat list
- if any(isinstance(item, (List, Tuple, np.ndarray)) for item in nodes):
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
  # If it's a list of lists, iterate over sublists
  node_groups = nodes
  # Convert color to RGBA arrays to match the number of groups
@@ -148,6 +149,8 @@ class Contour:
 
  # Draw the KDE contour for the specified nodes
  node_coordinates = self.graph.node_coordinates
+ # Use the provided alpha value if it's not None, otherwise use the color's alpha
+ current_fill_alpha = fill_alpha if fill_alpha is not None else color_rgba[idx][3]
  self._draw_kde_contour(
  self.ax,
  node_coordinates,
@@ -158,8 +161,7 @@ class Contour:
  grid_size=grid_size,
  linestyle=linestyle,
  linewidth=linewidth,
- alpha=alpha,
- fill_alpha=fill_alpha,
+ fill_alpha=current_fill_alpha,
  )
 
  def _draw_kde_contour(
@@ -173,7 +175,6 @@ class Contour:
  color: Union[str, np.ndarray] = "white",
  linestyle: str = "solid",
  linewidth: float = 1.5,
- alpha: Union[float, None] = 1.0,
  fill_alpha: Union[float, None] = 0.2,
  ) -> None:
  """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
@@ -188,8 +189,6 @@ class Contour:
  color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
  linestyle (str, optional): Line style for the contour. Defaults to "solid".
  linewidth (float, optional): Line width for the contour. Defaults to 1.5.
- alpha (float, None, optional): Transparency level for the contour lines. If provided, it overrides any existing alpha
- values found in color. Defaults to 1.0.
  fill_alpha (float, None, optional): Transparency level for the contour fill. If provided, it overrides any existing
  alpha values found in color. Defaults to 0.2.
  """
@@ -245,6 +244,8 @@ class Contour:
  contour_colors = [color for _ in range(levels - 1)]
  # Plot the filled contours using fill_alpha for transparency
  if fill_alpha and fill_alpha > 0:
+ # Fill alpha works differently than alpha for contour lines
+ # Contour fill cannot be specified by RGBA, while contour lines can
  ax.contourf(
  x,
  y,
@@ -255,7 +256,7 @@ class Contour:
  alpha=fill_alpha,
  )
 
- # Plot the contour lines with the specified alpha for transparency
+ # Plot the contour lines with the specified RGBA alpha for transparency
  c = ax.contour(
  x,
  y,
@@ -264,7 +265,6 @@ class Contour:
  colors=contour_colors,
  linestyles=linestyle,
  linewidths=linewidth,
- alpha=alpha,
  )
 
  # Set linewidth for the contour lines to 0 for levels other than the base level
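The dropped alpha argument reflects how the diff now routes transparency: fill transparency goes through contourf's alpha kwarg, while line transparency rides inside the RGBA color handed to contour. A minimal sketch of that split on a toy density field:

import numpy as np
import matplotlib.pyplot as plt

x, y = np.meshgrid(np.linspace(-2, 2, 50), np.linspace(-2, 2, 50))
z = np.exp(-(x**2 + y**2))

fig, ax = plt.subplots()
# Fill transparency via the alpha kwarg, as in the contourf call above
ax.contourf(x, y, z, levels=5, colors="black", alpha=0.2)
# Line transparency baked into the RGBA color, so no alpha kwarg is needed
ax.contour(x, y, z, levels=5, colors=[(0.0, 0.0, 0.0, 0.8)], linestyles="solid", linewidths=1.5)
plt.show()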
risk/network/plot/labels.py CHANGED
@@ -191,10 +191,10 @@ class Labels:
  filtered_domain_centroids, center, radius, offset
  )
  # Convert all domain colors to RGBA using the to_rgba helper function
- fontcolor = to_rgba(
+ fontcolor_rgba = to_rgba(
  color=fontcolor, alpha=fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
  )
- arrow_color = to_rgba(
+ arrow_color_rgba = to_rgba(
  color=arrow_color,
  alpha=arrow_alpha,
  num_repeats=len(self.graph.domain_id_to_node_ids_map),
@@ -216,10 +216,10 @@ class Labels:
  va="center",
  fontsize=fontsize,
  fontname=font,
- color=fontcolor[idx],
+ color=fontcolor_rgba[idx],
  arrowprops=dict(
  arrowstyle=arrow_style,
- color=arrow_color[idx],
+ color=arrow_color_rgba[idx],
  linewidth=arrow_linewidth,
  shrinkA=arrow_base_shrink,
  shrinkB=arrow_tip_shrink,
@@ -238,8 +238,7 @@ class Labels:
  va="center",
  fontsize=fontsize,
  fontname=font,
- color=fontcolor[idx],
- alpha=fontalpha,
+ color=fontcolor_rgba[idx],
  )
 
  def plot_sublabel(
@@ -282,7 +281,7 @@ class Labels:
  arrow_tip_shrink (float, optional): Distance between the arrow tip and the centroid. Defaults to 0.0.
  """
  # Check if nodes is a list of lists or a flat list
- if any(isinstance(item, (List, Tuple, np.ndarray)) for item in nodes):
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
  # If it's a list of lists, iterate over sublists
  node_groups = nodes
  # Convert fontcolor and arrow_color to RGBA arrays to match the number of groups
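The *_rgba renames feed standard matplotlib annotate calls, with one RGBA row per domain indexed by idx. The call shape, with hypothetical coordinates and colors:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.annotate(
    "domain label",
    xy=(0.5, 0.5),      # centroid the arrow points at
    xytext=(0.8, 0.8),  # where the label text sits
    ha="center",
    va="center",
    fontsize=10,
    color=(0.0, 0.0, 0.0, 1.0),  # one row of fontcolor_rgba
    arrowprops=dict(arrowstyle="->", color=(0.0, 0.0, 0.0, 0.6), linewidth=1.0),
)
plt.show()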
risk/network/plot/network.py CHANGED
@@ -75,13 +75,13 @@ class Network:
 
  # Convert colors to RGBA using the to_rgba helper function
  # If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
- node_color = to_rgba(
+ node_color_rgba = to_rgba(
  color=node_color, alpha=node_alpha, num_repeats=len(self.graph.network.nodes)
  )
- node_edgecolor = to_rgba(
+ node_edgecolor_rgba = to_rgba(
  color=node_edgecolor, alpha=1.0, num_repeats=len(self.graph.network.nodes)
  )
- edge_color = to_rgba(
+ edge_color_rgba = to_rgba(
  color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
  )
 
@@ -94,8 +94,8 @@ class Network:
  pos=node_coordinates,
  node_size=node_size,
  node_shape=node_shape,
- node_color=node_color,
- edgecolors=node_edgecolor,
+ node_color=node_color_rgba,
+ edgecolors=node_edgecolor_rgba,
  linewidths=node_edgewidth,
  ax=self.ax,
  )
@@ -104,7 +104,7 @@ class Network:
  self.graph.network,
  pos=node_coordinates,
  width=edge_width,
- edge_color=edge_color,
+ edge_color=edge_color_rgba,
  ax=self.ax,
  )
 
@@ -141,7 +141,7 @@ class Network:
  ValueError: If no valid nodes are found in the network graph.
  """
  # Flatten nested lists of nodes, if necessary
- if any(isinstance(item, (List, Tuple, np.ndarray)) for item in nodes):
+ if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
  nodes = [node for sublist in nodes for node in sublist]
 
  # Filter to get node IDs and their coordinates
@@ -162,9 +162,9 @@ class Network:
  ]
 
  # Convert colors to RGBA using the to_rgba helper function
- node_color = to_rgba(color=node_color, alpha=node_alpha, num_repeats=len(node_ids))
- node_edgecolor = to_rgba(color=node_edgecolor, alpha=1.0, num_repeats=len(node_ids))
- edge_color = to_rgba(
+ node_color_rgba = to_rgba(color=node_color, alpha=node_alpha, num_repeats=len(node_ids))
+ node_edgecolor_rgba = to_rgba(color=node_edgecolor, alpha=1.0, num_repeats=len(node_ids))
+ edge_color_rgba = to_rgba(
  color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
  )
 
@@ -178,8 +178,8 @@ class Network:
  nodelist=node_ids,
  node_size=node_size,
  node_shape=node_shape,
- node_color=node_color,
- edgecolors=node_edgecolor,
+ node_color=node_color_rgba,
+ edgecolors=node_edgecolor_rgba,
  linewidths=node_edgewidth,
  ax=self.ax,
  )
@@ -189,7 +189,7 @@ class Network:
  subgraph,
  pos=node_coordinates,
  width=edge_width,
- edge_color=edge_color,
+ edge_color=edge_color_rgba,
  ax=self.ax,
  )
 
@@ -244,7 +244,7 @@ class Network:
  # Apply the alpha value for enriched nodes
  network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
  # Convert the non-enriched color to RGBA using the to_rgba helper function
- nonenriched_color = to_rgba(
+ nonenriched_color_rgba = to_rgba(
  color=nonenriched_color, alpha=nonenriched_alpha, num_repeats=1
  ) # num_repeats=1 for a single color
  # Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
@@ -255,7 +255,7 @@ class Network:
  & np.all(network_colors[:, :3] == network_colors[:, 0:1], axis=1)
  )[:, None],
  np.tile(
- np.array(nonenriched_color), (network_colors.shape[0], 1)
+ np.array(nonenriched_color_rgba), (network_colors.shape[0], 1)
  ), # Replace with the full RGBA non-enriched color
  network_colors, # Keep the original colors where no match is found
  )
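All of these renames pass per-node and per-edge RGBA arrays straight into networkx's drawing functions. A self-contained sketch of that interface:

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

G = nx.karate_club_graph()
pos = nx.spring_layout(G, seed=888)

# One RGBA row per node and per edge, as the *_rgba variables hold
node_color_rgba = np.tile((0.2, 0.4, 0.8, 0.9), (len(G.nodes), 1))
edge_color_rgba = np.tile((0.5, 0.5, 0.5, 0.3), (len(G.edges), 1))

fig, ax = plt.subplots()
nx.draw_networkx_nodes(G, pos=pos, node_size=50, node_color=node_color_rgba, ax=ax)
nx.draw_networkx_edges(G, pos=pos, width=1.0, edge_color=edge_color_rgba, ax=ax)
plt.show()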
risk/network/plot/utils/color.py CHANGED
@@ -377,7 +377,7 @@ def to_rgba(
  if isinstance(c, str):
  # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
  rgba = np.array(mcolors.to_rgba(c))
- elif isinstance(c, (List, Tuple, np.ndarray)) and len(c) in [3, 4]:
+ elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
  # Convert RGB (3) or RGBA (4) values to RGBA format
  rgba = np.array(mcolors.to_rgba(c))
  else:
@@ -396,8 +396,8 @@ def to_rgba(
  # Handle a single color (string or RGB/RGBA list/tuple)
  if (
  isinstance(color, str)
- or isinstance(color, (List, Tuple, np.ndarray))
- and not any(isinstance(c, (str, List, Tuple, np.ndarray)) for c in color)
+ or isinstance(color, (list, tuple, np.ndarray))
+ and not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
  ):
  rgba_color = convert_to_rgba(color)
  if num_repeats:
@@ -407,7 +407,7 @@ def to_rgba(
  return np.array([rgba_color]) # Return a single color wrapped in a numpy array
 
  # Handle a list/array of colors
- elif isinstance(color, (List, Tuple, np.ndarray)):
+ elif isinstance(color, (list, tuple, np.ndarray)):
  rgba_colors = np.array(
  [convert_to_rgba(c) for c in color]
  ) # Convert each color in the list to RGBA
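The sweep from typing.List/typing.Tuple to the builtins in every isinstance check is more than style: the deprecated typing aliases do forward bare isinstance checks to the builtins, but subscripted forms raise TypeError, and the builtins are the idiomatic runtime check. A quick demonstration:

from typing import List

print(isinstance([1, 2], list))  # True - the idiomatic runtime check
print(isinstance([1, 2], List))  # also True today, but deprecated usage

try:
    isinstance([1, 2], List[int])  # subscripted generics refuse isinstance outright
except TypeError as err:
    print(err)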
risk/risk.py CHANGED
@@ -3,7 +3,7 @@ risk/risk
  ~~~~~~~~~
  """
 
- from typing import Any, Dict, Tuple, Union
+ from typing import Any, Dict, List, Tuple, Union
 
  import networkx as nx
  import numpy as np
@@ -58,9 +58,9 @@ class RISK(NetworkIO, AnnotationsIO):
  self,
  network: nx.Graph,
  annotations: Dict[str, Any],
- distance_metric: str = "louvain",
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
  louvain_resolution: float = 0.1,
- edge_length_threshold: float = 0.5,
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
  null_distribution: str = "network",
  random_seed: int = 888,
  ) -> Dict[str, Any]:
@@ -69,9 +69,13 @@ class RISK(NetworkIO, AnnotationsIO):
  Args:
  network (nx.Graph): The network graph.
  annotations (Dict[str, Any]): The annotations associated with the network.
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
+ Defaults to 0.5.
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
 
@@ -111,9 +115,9 @@ class RISK(NetworkIO, AnnotationsIO):
  self,
  network: nx.Graph,
  annotations: Dict[str, Any],
- distance_metric: str = "louvain",
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
  louvain_resolution: float = 0.1,
- edge_length_threshold: float = 0.5,
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
  null_distribution: str = "network",
  random_seed: int = 888,
  ) -> Dict[str, Any]:
@@ -122,9 +126,13 @@ class RISK(NetworkIO, AnnotationsIO):
  Args:
  network (nx.Graph): The network graph.
  annotations (Dict[str, Any]): The annotations associated with the network.
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
+ Defaults to 0.5.
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
 
@@ -164,9 +172,9 @@ class RISK(NetworkIO, AnnotationsIO):
  self,
  network: nx.Graph,
  annotations: Dict[str, Any],
- distance_metric: str = "louvain",
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
  louvain_resolution: float = 0.1,
- edge_length_threshold: float = 0.5,
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
  score_metric: str = "sum",
  null_distribution: str = "network",
  num_permutations: int = 1000,
@@ -178,9 +186,13 @@ class RISK(NetworkIO, AnnotationsIO):
  Args:
  network (nx.Graph): The network graph.
  annotations (Dict[str, Any]): The annotations associated with the network.
- distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "louvain".
+ distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use. Can be a string for one
+ metric or a list/tuple/ndarray of metrics ('greedy_modularity', 'louvain', 'label_propagation',
+ 'markov_clustering', 'walktrap', 'spinglass'). Defaults to 'louvain'.
  louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
- edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
+ edge_length_threshold (float, List, Tuple, or np.ndarray, optional): Edge length threshold(s) for creating subgraphs.
+ Can be a single float for one threshold or a list/tuple of floats corresponding to multiple thresholds.
+ Defaults to 0.5.
  score_metric (str, optional): Scoring metric for neighborhood significance. Defaults to "sum".
  null_distribution (str, optional): Type of null distribution ('network' or 'annotations'). Defaults to "network".
  num_permutations (int, optional): Number of permutations for significance testing. Defaults to 1000.
@@ -353,7 +365,7 @@ class RISK(NetworkIO, AnnotationsIO):
  def load_plotter(
  self,
  graph: NetworkGraph,
- figsize: Tuple = (10, 10),
+ figsize: Union[List, Tuple, np.ndarray] = (10, 10),
  background_color: str = "white",
  background_alpha: Union[float, None] = 1.0,
  pad: float = 0.3,
@@ -362,7 +374,7 @@ class RISK(NetworkIO, AnnotationsIO):
 
  Args:
  graph (NetworkGraph): The graph to plot.
- figsize (Tuple, optional): Size of the figure. Defaults to (10, 10).
+ figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
  background_color (str, optional): Background color of the plot. Defaults to "white".
  background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
  any existing alpha values found in background_color. Defaults to 1.0.
@@ -385,9 +397,9 @@ class RISK(NetworkIO, AnnotationsIO):
  def _load_neighborhoods(
  self,
  network: nx.Graph,
- distance_metric: str = "louvain",
+ distance_metric: Union[str, List, Tuple, np.ndarray] = "louvain",
  louvain_resolution: float = 0.1,
- edge_length_threshold: float = 0.5,
+ edge_length_threshold: Union[float, List, Tuple, np.ndarray] = 0.5,
  random_seed: int = 888,
  ) -> np.ndarray:
  """Load significant neighborhoods for the network.
@@ -452,13 +464,15 @@ class RISK(NetworkIO, AnnotationsIO):
  # Extract necessary data from annotations and neighborhoods
  ordered_annotations = annotations["ordered_annotations"]
  neighborhood_enrichment_sums = neighborhoods["neighborhood_enrichment_counts"]
- neighborhoods_binary_enrichment_matrix = neighborhoods["binary_enrichment_matrix"]
+ significant_enrichment_matrix = neighborhoods["significant_enrichment_matrix"]
+ significant_binary_enrichment_matrix = neighborhoods["significant_binary_enrichment_matrix"]
  # Call external function to define top annotations
  return define_top_annotations(
  network=network,
  ordered_annotation_labels=ordered_annotations,
  neighborhood_enrichment_sums=neighborhood_enrichment_sums,
- binary_enrichment_matrix=neighborhoods_binary_enrichment_matrix,
+ significant_enrichment_matrix=significant_enrichment_matrix,
+ significant_binary_enrichment_matrix=significant_binary_enrichment_matrix,
  min_cluster_size=min_cluster_size,
  max_cluster_size=max_cluster_size,
  )
risk/stats/stats.py CHANGED
@@ -62,7 +62,7 @@ def calculate_significance_matrices(
  log_enrichment_matrix = -np.log10(enrichment_matrix)
 
  # Select the appropriate significance matrices based on the specified tail
- enrichment_matrix, binary_enrichment_matrix = _select_significance_matrices(
+ enrichment_matrix, significant_binary_enrichment_matrix = _select_significance_matrices(
  tail,
  log_depletion_matrix,
  depletion_alpha_threshold_matrix,
@@ -71,11 +71,13 @@ def calculate_significance_matrices(
  )
 
  # Filter the enrichment matrix using the binary significance matrix
- significant_enrichment_matrix = np.where(binary_enrichment_matrix == 1, enrichment_matrix, 0)
+ significant_enrichment_matrix = np.where(
+ significant_binary_enrichment_matrix == 1, enrichment_matrix, 0
+ )
 
  return {
  "enrichment_matrix": enrichment_matrix,
- "binary_enrichment_matrix": binary_enrichment_matrix,
+ "significant_binary_enrichment_matrix": significant_binary_enrichment_matrix,
  "significant_enrichment_matrix": significant_enrichment_matrix,
  }
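The multi-line np.where here is purely a formatting change; the masking semantics are unchanged. Numerically:

import numpy as np

significant_binary_enrichment_matrix = np.array([[1, 0], [0, 1]])
enrichment_matrix = np.array([[2.5, 1.0], [0.5, 3.0]])

# Keep enrichment values only where the binary matrix flags significance
significant_enrichment_matrix = np.where(
    significant_binary_enrichment_matrix == 1, enrichment_matrix, 0
)
print(significant_enrichment_matrix)  # [[2.5 0. ] [0.  3. ]]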
@@ -127,10 +129,10 @@
 
  # Create a binary significance matrix where valid indices meet the alpha threshold
  valid_idxs = ~np.isnan(alpha_threshold_matrix)
- binary_enrichment_matrix = np.zeros(alpha_threshold_matrix.shape)
- binary_enrichment_matrix[valid_idxs] = alpha_threshold_matrix[valid_idxs]
+ significant_binary_enrichment_matrix = np.zeros(alpha_threshold_matrix.shape)
+ significant_binary_enrichment_matrix[valid_idxs] = alpha_threshold_matrix[valid_idxs]
 
- return enrichment_matrix, binary_enrichment_matrix
+ return enrichment_matrix, significant_binary_enrichment_matrix
 
 
  def _compute_threshold_matrix(
risk_network-0.0.8b20.dist-info/METADATA → risk_network-0.0.8b22.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: risk-network
- Version: 0.0.8b20
+ Version: 0.0.8b22
  Summary: A Python package for biological network analysis
  Author: Ira Horecka
  Author-email: Ira Horecka <ira89@icloud.com>
risk_network-0.0.8b20.dist-info/RECORD → risk_network-0.0.8b22.dist-info/RECORD CHANGED
@@ -1,37 +1,37 @@
- risk/__init__.py,sha256=MOWrmv2B-I1GTr0sgWTb-CLgeMWceWa6q2E6oeGC2CA,113
+ risk/__init__.py,sha256=fz5ZBsLUlWdBQ5uJS0qBge1qwPNox0OYhi4OXkjQxwI,113
  risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
- risk/risk.py,sha256=_ufeTLOAAT4QwrRysvDJOQeE0qMvpp3BSSayfFLhGJE,21720
- risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
- risk/annotations/annotations.py,sha256=KHGeF5vBDmX711nA08DfhxI9z7Z1Oaeo91ueWhM6vs8,11370
+ risk/risk.py,sha256=rjV0hllegCX978QaUo175FworKxNXlhQEQaQAPjHqos,23397
+ risk/annotations/__init__.py,sha256=kXgadEXaCh0z8OyhOhTj7c3qXGmWgOhaSZ4gSzSb59U,147
+ risk/annotations/annotations.py,sha256=giLJht0tPtf4UdtH_d0kbCZQU5H5fZoupGDFKaNbC_Q,12700
  risk/annotations/io.py,sha256=powWzeimVdE0WCwlBCXyu5otMyZZHQujC0DS3m5DC0c,9505
  risk/log/__init__.py,sha256=aDUz5LMFQsz0UlsQI2EdXtiBKRLfml1UMeZKC7QQIGU,134
  risk/log/config.py,sha256=m8pzj-hN4vI_2JdJUfyOoSvzT8_lhoIfBt27sKbnOes,4535
  risk/log/params.py,sha256=rvyg86RnkHwotST7x42RgsiYfq2HB-9BZxp6KkT_04o,6415
  risk/neighborhoods/__init__.py,sha256=tKKEg4lsbqFukpgYlUGxU_v_9FOqK7V0uvM9T2QzoL0,206
  risk/neighborhoods/community.py,sha256=MAgIblbuisEPwVU6mFZd4Yd9NUKlaHK99suw51r1Is0,7065
- risk/neighborhoods/domains.py,sha256=DbhUFsvbr8wuvrNr7a0PaAJO-cdv6U3-T4CXB4-j5Qw,10930
- risk/neighborhoods/neighborhoods.py,sha256=OPGNfeGQR533vWjger7f34ZPSgw9250LQXcTEIAhQvg,21165
+ risk/neighborhoods/domains.py,sha256=3iV0-nRLF2sL9_7epHY5b9AtTU-QQ84hOWO76VwFcrs,11685
+ risk/neighborhoods/neighborhoods.py,sha256=cT9CCi1uQLn9Kv9Lxt8AN_4s63SKIlOZspvUZnx27nE,21832
  risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
  risk/network/geometry.py,sha256=Y3Brp0XYWoBL2VHJX7I-gW5x-q7lGiEMqr2kqtutgkQ,6811
- risk/network/graph.py,sha256=-91JL84LYbdWohzybKFQ3NdWnervxP-wwbpaUOdRVLE,8576
- risk/network/io.py,sha256=w_9fUcZUVXAPRKGhLBc7xhIJs8l83szHiBQTdaNN0gk,22942
+ risk/network/graph.py,sha256=-tslu8nSbuBaqNGf6TQ8ON7C27v-BLH_37J2aC6Ke14,9602
+ risk/network/io.py,sha256=u0PPcKjp6Xze--7eDOlvalYkjQ9S2sjiC-ac2476PUI,22942
  risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
- risk/network/plot/canvas.py,sha256=hdrmGd2TCuii8wn6jDQfyJTI5YXDNGYFLiU4TyqAYbE,10778
- risk/network/plot/contour.py,sha256=xxTf6iNSlpe2S8aalt2mzivmR0wuGUOh_F3-IL6UbEU,15027
- risk/network/plot/labels.py,sha256=bFsP9NA3Fp0GhX62ArRP9tSqPCgUthKE9aFe0imoPcI,45115
- risk/network/plot/network.py,sha256=nfTmQxx1YwS3taXwq8WSCfu6nfKFOyxj7T5605qLXVM,13615
+ risk/network/plot/canvas.py,sha256=ZO6bHw1chIsUqtE7IkPKdgX4tFLA-T5OwN5SojqGSNU,10672
+ risk/network/plot/contour.py,sha256=CwX4i3uE5HL0W4kfx34U7YyoTTqMxyb7xaXKRVoNLzY,15265
+ risk/network/plot/labels.py,sha256=ozkqwhBOTHKJLaAz4dJopXuykAvssSZUer2W5V0x2jM,45103
+ risk/network/plot/network.py,sha256=6RURL1OdBFyQ34qNcwM_uH3LSQGYZZ8tZT51dggH1a0,13685
  risk/network/plot/plotter.py,sha256=iTPMiTnTTatM_-q1Ox_bjt5Pvv-Lo8gceiYB6TVzDcw,5770
- risk/network/plot/utils/color.py,sha256=HtUaGnqJPVNbRyUhQMlBonfHc_2Ci8BtTI3y424p8Cs,19626
+ risk/network/plot/utils/color.py,sha256=WSs1ge2oZ8yXwyVk2QqBF-avRd0aYT-sYZr9cxxAn7M,19626
  risk/network/plot/utils/layout.py,sha256=5DpRLvabgnPWwVJ-J3W6oFBBvbjCrudvvW4HDOzzoTo,1960
  risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
  risk/stats/hypergeom.py,sha256=oc39f02ViB1vQ-uaDrxG_tzAT6dxQBRjc88EK2EGn78,2282
  risk/stats/poisson.py,sha256=polLgwS08MTCNzupYdmMUoEUYrJOjAbcYtYwjlfeE5Y,1803
- risk/stats/stats.py,sha256=07yMULKlCurK62x674SHKJavZtz9ge2K2ZsHix_z_pw,7088
+ risk/stats/stats.py,sha256=6iGi0-oN05mTmupg6X_VEBxEQvi2rujNhfPk4aLjwNI,7186
  risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
  risk/stats/permutation/permutation.py,sha256=meBNSrbRa9P8WJ54n485l0H7VQJlMSfHqdN4aCKYCtQ,10105
  risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
- risk_network-0.0.8b20.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
- risk_network-0.0.8b20.dist-info/METADATA,sha256=UnAgNaBf77W4-Vo5YGPJktwy5WQaEwWU2ByhSbyfEVg,47498
- risk_network-0.0.8b20.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
- risk_network-0.0.8b20.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
- risk_network-0.0.8b20.dist-info/RECORD,,
+ risk_network-0.0.8b22.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
+ risk_network-0.0.8b22.dist-info/METADATA,sha256=9trSkrh2Od_B2qltA2n_uVcvX1kUlMy-QmLO4WThrds,47498
+ risk_network-0.0.8b22.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+ risk_network-0.0.8b22.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
+ risk_network-0.0.8b22.dist-info/RECORD,,