risk-network 0.0.9b2__py3-none-any.whl → 0.0.9b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
 
 from risk.risk import RISK
 
-__version__ = "0.0.9-beta.2"
+__version__ = "0.0.9-beta.4"
risk/network/graph/network.py CHANGED
@@ -10,7 +10,7 @@ import networkx as nx
 import numpy as np
 import pandas as pd
 
-from risk.network.graph.summary import Summary
+from risk.network.graph.summary import AnalysisSummary
 
 
 class NetworkGraph:
@@ -73,7 +73,7 @@ class NetworkGraph:
         self.node_coordinates = _extract_node_coordinates(self.network)
 
         # NOTE: Only after the above attributes are initialized, we can create the summary
-        self.summary = Summary(annotations, neighborhoods, self)
+        self.summary = AnalysisSummary(annotations, neighborhoods, self)
 
     @staticmethod
     def _create_domain_id_to_node_ids_map(domains: pd.DataFrame) -> Dict[int, Any]:
@@ -114,21 +114,38 @@ class NetworkGraph:
     def _create_domain_id_to_domain_info_map(
         trimmed_domains: pd.DataFrame,
     ) -> Dict[int, Dict[str, Any]]:
-        """Create a mapping from domain IDs to their corresponding full description and significance score.
+        """Create a mapping from domain IDs to their corresponding full description and significance score,
+        with scores sorted in descending order.
 
         Args:
             trimmed_domains (pd.DataFrame): DataFrame containing domain IDs, full descriptions, and significance scores.
 
         Returns:
-            Dict[int, Dict[str, Any]]: A dictionary mapping domain IDs (int) to a dictionary with 'full_descriptions' and 'significance_scores'.
+            Dict[int, Dict[str, Any]]: A dictionary mapping domain IDs (int) to a dictionary with 'full_descriptions' and
+                'significance_scores', both sorted by significance score in descending order.
         """
-        return {
-            int(id_): {
-                "full_descriptions": trimmed_domains.at[id_, "full_descriptions"],
-                "significance_scores": trimmed_domains.at[id_, "significance_scores"],
+        # Initialize an empty dictionary to store full descriptions and significance scores of domains
+        domain_info_map = {}
+        # Domain IDs are the index of the DataFrame (it's common for some IDs to be missing)
+        for domain_id in trimmed_domains.index:
+            # Sort full_descriptions and significance_scores by significance_scores in descending order
+            descriptions_and_scores = sorted(
+                zip(
+                    trimmed_domains.at[domain_id, "full_descriptions"],
+                    trimmed_domains.at[domain_id, "significance_scores"],
+                ),
+                key=lambda x: x[1],  # Sort by significance score
+                reverse=True,  # Descending order
+            )
+            # Unzip the sorted tuples back into separate lists
+            sorted_descriptions, sorted_scores = zip(*descriptions_and_scores)
+            # Assign to the domain info map
+            domain_info_map[int(domain_id)] = {
+                "full_descriptions": list(sorted_descriptions),
+                "significance_scores": list(sorted_scores),
             }
-            for id_ in trimmed_domains.index
-        }
+
+        return domain_info_map
 
     @staticmethod
     def _create_node_id_to_domain_ids_and_significances(domains: pd.DataFrame) -> Dict[int, Dict]:
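The rewritten `_create_domain_id_to_domain_info_map` above now orders each domain's descriptions by score. A minimal standalone sketch of that sorting step, using made-up toy data rather than real package output:

```python
import pandas as pd

# Toy stand-in for `trimmed_domains`: the index holds domain IDs and each cell holds a
# list, as implied by the hunk above (values here are invented for illustration).
trimmed_domains = pd.DataFrame(
    {
        "full_descriptions": [["ribosome biogenesis", "rRNA processing"]],
        "significance_scores": [[2.1, 5.7]],
    },
    index=[3],
)

domain_info_map = {}
for domain_id in trimmed_domains.index:
    # Pair descriptions with scores and sort by score, descending (the step added in 0.0.9b4)
    pairs = sorted(
        zip(
            trimmed_domains.at[domain_id, "full_descriptions"],
            trimmed_domains.at[domain_id, "significance_scores"],
        ),
        key=lambda x: x[1],
        reverse=True,
    )
    descriptions, scores = zip(*pairs)
    domain_info_map[int(domain_id)] = {
        "full_descriptions": list(descriptions),
        "significance_scores": list(scores),
    }

print(domain_info_map)
# {3: {'full_descriptions': ['rRNA processing', 'ribosome biogenesis'],
#      'significance_scores': [5.7, 2.1]}}
```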
risk/network/graph/summary.py CHANGED
@@ -3,8 +3,6 @@ risk/network/graph/summary
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
-import warnings
-from functools import lru_cache
 from typing import Any, Dict, Tuple, Union
 
 import numpy as np
@@ -14,11 +14,7 @@ from statsmodels.stats.multitest import fdrcorrection
 from risk.log.console import logger, log_header
 
 
-# Suppress all warnings - this is to resolve warnings from multiprocessing
-warnings.filterwarnings("ignore")
-
-
-class Summary:
+class AnalysisSummary:
     """Handles the processing, storage, and export of network analysis results.
 
     The Results class provides methods to process significance and depletion data, compute
@@ -31,7 +25,7 @@ class Summary:
         self,
         annotations: Dict[str, Any],
         neighborhoods: Dict[str, Any],
-        graph,  # Avoid type hinting NetworkGraph to avoid circular import
+        graph,  # Avoid type hinting NetworkGraph to prevent circular imports
     ):
         """Initialize the Results object with analysis components.
 
@@ -53,7 +47,7 @@ class Summary:
         # Load results and export directly to CSV
         results = self.load()
         results.to_csv(filepath, index=False)
-        logger.info(f"Results summary exported to CSV file: {filepath}")
+        logger.info(f"Analysis summary exported to CSV file: {filepath}")
 
     def to_json(self, filepath: str) -> None:
         """Export significance results to a JSON file.
@@ -64,7 +58,7 @@ class Summary:
         # Load results and export directly to JSON
         results = self.load()
         results.to_json(filepath, orient="records", indent=4)
-        logger.info(f"Results summary exported to JSON file: {filepath}")
+        logger.info(f"Analysis summary exported to JSON file: {filepath}")
 
     def to_txt(self, filepath: str) -> None:
         """Export significance results to a text file.
@@ -77,21 +71,16 @@ class Summary:
         with open(filepath, "w") as txt_file:
             txt_file.write(results.to_string(index=False))
 
-        logger.info(f"Results summary exported to text file: {filepath}")
+        logger.info(f"Analysis summary exported to text file: {filepath}")
 
-    @lru_cache(maxsize=None)
     def load(self) -> pd.DataFrame:
         """Load and process domain and annotation data into a DataFrame with significance metrics.
 
-        Args:
-            graph (Any): Graph object containing domain-to-node and node-to-label mappings.
-            annotations (Dict[str, Any]): Annotation details, including ordered annotations and matrix.
-
         Returns:
             pd.DataFrame: Processed DataFrame containing significance scores, p-values, q-values,
                 and annotation member information.
         """
-        log_header("Loading parameters")
+        log_header("Loading analysis summary")
         # Calculate significance and depletion q-values from p-value matrices in `annotations`
        enrichment_pvals = self.neighborhoods["enrichment_pvals"]
        depletion_pvals = self.neighborhoods["depletion_pvals"]
@@ -132,12 +121,12 @@ class Summary:
             result_type="expand",
         )
         # Add annotation members and their counts
-        results["Annotation Members"] = results["Annotation"].apply(
+        results["Annotation Members in Network"] = results["Annotation"].apply(
             lambda desc: self._get_annotation_members(desc)
         )
-        results["Annotation Member Count"] = results["Annotation Members"].apply(
-            lambda x: len(x.split(";")) if x else 0
-        )
+        results["Annotation Members in Network Count"] = results[
+            "Annotation Members in Network"
+        ].apply(lambda x: len(x.split(";")) if x else 0)
 
         # Reorder columns and drop rows with NaN values
         results = (
@@ -145,8 +134,8 @@ class Summary:
             [
                 "Domain ID",
                 "Annotation",
-                "Annotation Members",
-                "Annotation Member Count",
+                "Annotation Members in Network",
+                "Annotation Members in Network Count",
                 "Summed Significance Score",
                 "Enrichment P-value",
                 "Enrichment Q-value",
@@ -230,6 +219,7 @@ class Summary:
         except ValueError:
             return ""  # Description not found
 
+        # Get nodes present for the annotation and sort by node label
         nodes_present = np.where(self.annotations["matrix"][:, annotation_idx] == 1)[0]
         node_labels = sorted(
             self.graph.node_id_to_node_label_map[node_id]
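Taken together, the summary.py hunks rename the results object (`Summary` to `AnalysisSummary`), drop the `lru_cache` on `load()`, and rename two output columns. A hedged usage sketch; it assumes `graph` is a `NetworkGraph` produced by a prior RISK analysis run, and relies only on the attribute, methods, and column names shown above:

```python
from risk.network.graph import NetworkGraph


def export_summary(graph: NetworkGraph, prefix: str = "analysis_summary") -> None:
    """Export the per-domain analysis summary attached to a NetworkGraph."""
    summary = graph.summary  # an AnalysisSummary in 0.0.9b4 (formerly Summary)
    results = summary.load()  # recomputed on each call now that lru_cache is gone
    # Columns renamed in 0.0.9b4:
    print(results[["Domain ID", "Annotation", "Annotation Members in Network Count"]].head())
    summary.to_csv(f"{prefix}.csv")   # logs "Analysis summary exported to CSV file: ..."
    summary.to_json(f"{prefix}.json")
    summary.to_txt(f"{prefix}.txt")
```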
risk/network/plot/contour.py CHANGED
@@ -3,7 +3,7 @@ risk/network/plot/contour
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -63,6 +63,8 @@ class Contour:
             contour_color=(
                 "custom" if isinstance(color, np.ndarray) else color
             ),  # np.ndarray usually indicates custom colors
+            contour_linestyle=linestyle,
+            contour_linewidth=linewidth,
             contour_alpha=alpha,
             contour_fill_alpha=fill_alpha,
         )
@@ -280,6 +282,7 @@ class Contour:
         min_scale: float = 0.8,
         max_scale: float = 1.0,
         scale_factor: float = 1.0,
+        ids_to_colors: Union[Dict[int, Any], None] = None,
         random_seed: int = 888,
     ) -> np.ndarray:
         """Get colors for the contours based on node annotations or a specified colormap.
@@ -296,6 +299,7 @@ class Contour:
                 Controls the brightest colors. Defaults to 1.0.
             scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
                 A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
+            ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
             random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
 
         Returns:
@@ -310,6 +314,7 @@ class Contour:
             min_scale=min_scale,
             max_scale=max_scale,
             scale_factor=scale_factor,
+            ids_to_colors=ids_to_colors,
             random_seed=random_seed,
         )
 
risk/network/plot/labels.py CHANGED
@@ -54,7 +54,7 @@ class Labels:
         words_to_omit: Union[List, None] = None,
         overlay_ids: bool = False,
         ids_to_keep: Union[List, Tuple, np.ndarray, None] = None,
-        ids_to_replace: Union[Dict, None] = None,
+        ids_to_labels: Union[Dict[int, str], None] = None,
     ) -> None:
         """Annotate the network graph with labels for different domains, positioned around the network for clarity.
 
@@ -62,7 +62,7 @@ class Labels:
             scale (float, optional): Scale factor for positioning labels around the perimeter. Defaults to 1.05.
             offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
             font (str, optional): Font name for the labels. Defaults to "Arial".
-            fontcase (Union[str, Dict[str, str], None]): Defines how to transform the case of words.
+            fontcase (str, Dict[str, str], or None, optional): Defines how to transform the case of words.
                 - If a string (e.g., 'upper', 'lower', 'title'), applies the transformation to all words.
                 - If a dictionary, maps specific cases ('lower', 'upper', 'title') to transformations (e.g., 'lower'='upper').
                 - If None, no transformation is applied.
@@ -87,7 +87,7 @@ class Labels:
             overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
             ids_to_keep (List, Tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
                 you can set `overlay_ids=True`. Defaults to None.
-            ids_to_replace (Dict, optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
+            ids_to_labels (Dict[int, str], optional): A dictionary mapping domain IDs to custom labels (strings). The labels should be
                 space-separated words. If provided, the custom labels will replace the default domain terms. To discover domain IDs, you
                 can set `overlay_ids=True`. Defaults to None.
 
@@ -119,7 +119,7 @@ class Labels:
             label_words_to_omit=words_to_omit,
             label_overlay_ids=overlay_ids,
             label_ids_to_keep=ids_to_keep,
-            label_ids_to_replace=ids_to_replace,
+            label_ids_to_labels=ids_to_labels,
         )
 
         # Convert ids_to_keep to a tuple if it is not None
@@ -152,7 +152,7 @@ class Labels:
         self._process_ids_to_keep(
             domain_id_to_centroid_map=domain_id_to_centroid_map,
             ids_to_keep=ids_to_keep,
-            ids_to_replace=ids_to_replace,
+            ids_to_labels=ids_to_labels,
             words_to_omit=words_to_omit,
             max_labels=max_labels,
             min_label_lines=min_label_lines,
@@ -173,7 +173,7 @@ class Labels:
         self._process_remaining_domains(
             domain_id_to_centroid_map=domain_id_to_centroid_map,
             ids_to_keep=ids_to_keep,
-            ids_to_replace=ids_to_replace,
+            ids_to_labels=ids_to_labels,
             words_to_omit=words_to_omit,
             remaining_labels=remaining_labels,
             min_chars_per_line=min_chars_per_line,
@@ -368,7 +368,7 @@ class Labels:
         self,
         domain_id_to_centroid_map: Dict[str, np.ndarray],
         ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
-        ids_to_replace: Union[Dict[str, str], None],
+        ids_to_labels: Union[Dict[int, str], None],
         words_to_omit: Union[List[str], None],
         max_labels: Union[int, None],
         min_label_lines: int,
@@ -384,7 +384,7 @@ class Labels:
         Args:
             domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
             ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
-            ids_to_replace (Dict[str, str], optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
+            ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
             words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
             max_labels (int, optional): Maximum number of labels allowed.
             min_label_lines (int): Minimum number of lines in a label.
@@ -419,7 +419,7 @@ class Labels:
             domain=domain,
             domain_centroid=domain_centroid,
             domain_id_to_centroid_map=domain_id_to_centroid_map,
-            ids_to_replace=ids_to_replace,
+            ids_to_labels=ids_to_labels,
             words_to_omit=words_to_omit,
             min_label_lines=min_label_lines,
            max_label_lines=max_label_lines,
@@ -434,7 +434,7 @@ class Labels:
         self,
         domain_id_to_centroid_map: Dict[str, np.ndarray],
         ids_to_keep: Union[List[str], Tuple[str], np.ndarray],
-        ids_to_replace: Union[Dict[str, str], None],
+        ids_to_labels: Union[Dict[int, str], None],
         words_to_omit: Union[List[str], None],
         remaining_labels: int,
         min_label_lines: int,
@@ -450,7 +450,7 @@ class Labels:
         Args:
             domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
             ids_to_keep (List, Tuple, or np.ndarray, optional): IDs of domains that must be labeled.
-            ids_to_replace (Dict[str, str], optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
+            ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
             words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
             remaining_labels (int): The remaining number of labels that can be generated.
             min_label_lines (int): Minimum number of lines in a label.
@@ -515,7 +515,7 @@ class Labels:
             domain=domain,
             domain_centroid=domain_centroid,
             domain_id_to_centroid_map=domain_id_to_centroid_map,
-            ids_to_replace=ids_to_replace,
+            ids_to_labels=ids_to_labels,
             words_to_omit=words_to_omit,
             min_label_lines=min_label_lines,
             max_label_lines=max_label_lines,
@@ -536,7 +536,7 @@ class Labels:
         domain: str,
         domain_centroid: np.ndarray,
         domain_id_to_centroid_map: Dict[str, np.ndarray],
-        ids_to_replace: Union[Dict[str, str], None],
+        ids_to_labels: Union[Dict[int, str], None],
         words_to_omit: Union[List[str], None],
         min_label_lines: int,
         max_label_lines: int,
@@ -552,7 +552,7 @@ class Labels:
            domain (str): Domain ID to process.
            domain_centroid (np.ndarray): Centroid position of the domain.
            domain_id_to_centroid_map (Dict[str, np.ndarray]): Mapping of domain IDs to their centroids.
-            ids_to_replace (Dict[str, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
+            ids_to_labels (Dict[int, str], None, optional): A dictionary mapping domain IDs to custom labels. Defaults to None.
            words_to_omit (List[str], None, optional): List of words to omit from the labels. Defaults to None.
            min_label_lines (int): Minimum number of lines required in a label.
            max_label_lines (int): Maximum number of lines allowed in a label.
@@ -571,7 +571,7 @@ class Labels:
         # Process the domain terms
         domain_terms = self._process_terms(
             domain=domain,
-            ids_to_replace=ids_to_replace,
+            ids_to_labels=ids_to_labels,
             words_to_omit=words_to_omit,
             max_label_lines=max_label_lines,
             min_chars_per_line=min_chars_per_line,
@@ -596,7 +596,7 @@ class Labels:
     def _process_terms(
         self,
         domain: str,
-        ids_to_replace: Union[Dict[str, str], None],
+        ids_to_labels: Union[Dict[int, str], None],
         words_to_omit: Union[List[str], None],
         max_label_lines: int,
         min_chars_per_line: int,
@@ -606,8 +606,8 @@ class Labels:
 
         Args:
             domain (str): The domain being processed.
-            ids_to_replace (Dict[str, str], optional): Dictionary mapping domain IDs to custom labels.
-            words_to_omit (List, optional): List of words to omit from the labels.
+            ids_to_labels (Dict[int, str], None): Dictionary mapping domain IDs to custom labels.
+            words_to_omit (List[str], None): List of words to omit from the labels.
             max_label_lines (int): Maximum number of lines in a label.
             min_chars_per_line (int): Minimum number of characters in a line to display.
             max_chars_per_line (int): Maximum number of characters in a line to display.
@@ -615,9 +615,9 @@ class Labels:
         Returns:
             str: Processed terms separated by TERM_DELIMITER, with words combined if necessary to fit within constraints.
         """
-        # Return custom labels if domain is in ids_to_replace
-        if ids_to_replace and domain in ids_to_replace:
-            terms = ids_to_replace[domain].replace(" ", TERM_DELIMITER)
+        # Return custom labels if domain is in ids_to_labels
+        if ids_to_labels and domain in ids_to_labels:
+            terms = ids_to_labels[domain].replace(" ", TERM_DELIMITER)
             return terms
 
         else:
@@ -645,6 +645,7 @@ class Labels:
         min_scale: float = 0.8,
         max_scale: float = 1.0,
         scale_factor: float = 1.0,
+        ids_to_colors: Union[Dict[int, Any], None] = None,
         random_seed: int = 888,
     ) -> np.ndarray:
         """Get colors for the labels based on node annotations or a specified colormap.
@@ -661,6 +662,7 @@ class Labels:
                 Controls the brightest colors. Defaults to 1.0.
             scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
                 A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
+            ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
             random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
 
         Returns:
@@ -675,6 +677,7 @@ class Labels:
             min_scale=min_scale,
             max_scale=max_scale,
             scale_factor=scale_factor,
+            ids_to_colors=ids_to_colors,
             random_seed=random_seed,
         )
 
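The labels.py hunks above rename `ids_to_replace` to `ids_to_labels` and thread the new `ids_to_colors` option through label coloring. A hedged call sketch; `plotter.plot_labels` is an assumed entry point (its name is not shown in this diff), while the keyword names come from the hunks above:

```python
def label_selected_domains(plotter) -> None:
    # `plotter` is assumed to expose a plot_labels method that forwards these
    # keywords to Labels; run it once with overlay_ids=True to discover domain IDs.
    plotter.plot_labels(
        ids_to_keep=[1, 5],  # domains that must be labeled
        ids_to_labels={1: "DNA repair", 5: "chromatin remodeling"},  # was ids_to_replace in 0.0.9b2
    )
```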
risk/network/plot/network.py CHANGED
@@ -3,7 +3,7 @@ risk/network/plot/network
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
-from typing import Any, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import networkx as nx
 import numpy as np
@@ -205,6 +205,7 @@ class Network:
         alpha: Union[float, None] = 1.0,
         nonsignificant_color: Union[str, List, Tuple, np.ndarray] = "white",
         nonsignificant_alpha: Union[float, None] = 1.0,
+        ids_to_colors: Union[Dict[int, Any], None] = None,
         random_seed: int = 888,
     ) -> np.ndarray:
         """Adjust the colors of nodes in the network graph based on significance.
@@ -224,6 +225,7 @@ class Network:
                 Defaults to "white".
             nonsignificant_alpha (float, None, optional): Alpha value for non-significant nodes. If provided, it overrides any existing alpha values found
                 in `nonsignificant_color`. Defaults to 1.0.
+            ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
             random_seed (int, optional): Seed for random number generation. Defaults to 888.
 
         Returns:
@@ -239,6 +241,7 @@ class Network:
             min_scale=min_scale,
             max_scale=max_scale,
             scale_factor=scale_factor,
+            ids_to_colors=ids_to_colors,
             random_seed=random_seed,
         )
         # Apply the alpha value for significant nodes
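As with contours and labels, node coloring now accepts a per-domain override. A hedged sketch; the method names below are assumptions (only the `ids_to_colors` keyword and its meaning are taken from the hunks above):

```python
def plot_with_pinned_domains(plotter) -> None:
    # Hypothetical plotter methods; only the keyword `ids_to_colors` appears in this diff.
    node_colors = plotter.get_annotated_node_colors(
        ids_to_colors={2: "tomato", 7: (0.2, 0.4, 0.8, 1.0)},  # pin domains 2 and 7
        nonsignificant_color="white",
        random_seed=888,
    )
    plotter.plot_network(node_color=node_colors)
```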
risk/network/plot/utils/color.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Tuple, Union
 
 import matplotlib
 import matplotlib.colors as mcolors
+import networkx as nx
 import numpy as np
 
 from risk.network.graph import NetworkGraph
@@ -22,6 +23,7 @@ def get_annotated_domain_colors(
     min_scale: float = 0.8,
     max_scale: float = 1.0,
     scale_factor: float = 1.0,
+    ids_to_colors: Union[Dict[int, Any], None] = None,
     random_seed: int = 888,
 ) -> np.ndarray:
     """Get colors for the domains based on node annotations, or use a specified color.
@@ -37,6 +39,7 @@ def get_annotated_domain_colors(
         max_scale (float, optional): Maximum scale for color intensity when generating domain colors. Defaults to 1.0.
         scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on significance. Higher values
             increase the contrast. Defaults to 1.0.
+        ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
         random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
 
     Returns:
@@ -52,6 +55,7 @@ def get_annotated_domain_colors(
         min_scale=min_scale,
         max_scale=max_scale,
         scale_factor=scale_factor,
+        ids_to_colors=ids_to_colors,
         random_seed=random_seed,
     )
     annotated_colors = []
@@ -59,14 +63,12 @@ def get_annotated_domain_colors(
         if len(node_ids) > 1:
             # For multi-node domains, choose the brightest color based on RGB sum
             domain_colors = np.array([node_colors[node] for node in node_ids])
-            brightest_color = domain_colors[
-                np.argmax(domain_colors[:, :3].sum(axis=1))  # Sum the RGB values
-            ]
-            annotated_colors.append(brightest_color)
+            color = domain_colors[np.argmax(domain_colors[:, :3].sum(axis=1))]  # Sum the RGB values
         else:
             # Single-node domains default to white (RGBA)
-            default_color = np.array([1.0, 1.0, 1.0, 1.0])
-            annotated_colors.append(default_color)
+            color = np.array([1.0, 1.0, 1.0, 1.0])
+
+        annotated_colors.append(color)
 
     return np.array(annotated_colors)
 
@@ -80,6 +82,7 @@ def get_domain_colors(
     min_scale: float = 0.8,
     max_scale: float = 1.0,
     scale_factor: float = 1.0,
+    ids_to_colors: Union[Dict[int, Any], None] = None,
     random_seed: int = 888,
 ) -> np.ndarray:
     """Generate composite colors for domains based on significance or specified colors.
@@ -97,16 +100,22 @@ def get_domain_colors(
             Defaults to 1.0.
         scale_factor (float, optional): Exponent for adjusting the color scaling based on significance scores. Higher values increase
             contrast by dimming lower scores more. Defaults to 1.0.
+        ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
         random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments. Defaults to 888.
 
     Returns:
         np.ndarray: Array of RGBA colors generated for each domain, based on significance or the specified color.
     """
     # Get colors for each domain
-    domain_colors = _get_domain_colors(graph=graph, cmap=cmap, color=color, random_seed=random_seed)
+    domain_ids_to_colors = _get_domain_ids_to_colors(
+        graph=graph, cmap=cmap, color=color, ids_to_colors=ids_to_colors, random_seed=random_seed
+    )
     # Generate composite colors for nodes
     node_colors = _get_composite_node_colors(
-        graph=graph, domain_colors=domain_colors, blend_colors=blend_colors, blend_gamma=blend_gamma
+        graph=graph,
+        domain_ids_to_colors=domain_ids_to_colors,
+        blend_colors=blend_colors,
+        blend_gamma=blend_gamma,
     )
     # Transform colors to ensure proper alpha values and intensity
     transformed_colors = _transform_colors(
@@ -119,10 +128,11 @@ def get_domain_colors(
     return transformed_colors
 
 
-def _get_domain_colors(
+def _get_domain_ids_to_colors(
     graph: NetworkGraph,
     cmap: str = "gist_rainbow",
     color: Union[str, List, Tuple, np.ndarray, None] = None,
+    ids_to_colors: Union[Dict[int, Any], None] = None,
     random_seed: int = 888,
 ) -> Dict[int, Any]:
     """Get colors for each domain.
@@ -132,6 +142,7 @@ def _get_domain_colors(
         cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
         color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
             If None, the colormap will be used. Defaults to None.
+        ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
         random_seed (int, optional): Seed for random number generation. Defaults to 888.
 
     Returns:
@@ -145,17 +156,30 @@ def _get_domain_colors(
         color=color,
         random_seed=random_seed,
     )
-    return dict(zip(graph.domain_id_to_node_ids_map.keys(), domain_colors))
+    # Assign colors to domains either based on the generated colormap or the user-specified colors
+    domain_ids_to_colors = {}
+    for domain_id, domain_color in zip(graph.domain_id_to_node_ids_map.keys(), domain_colors):
+        if ids_to_colors and domain_id in ids_to_colors:
+            # Convert user-specified colors to RGBA format
+            user_rgba = to_rgba(ids_to_colors[domain_id])
+            domain_ids_to_colors[domain_id] = user_rgba
+        else:
+            domain_ids_to_colors[domain_id] = domain_color
+
+    return domain_ids_to_colors
 
 
 def _get_composite_node_colors(
-    graph, domain_colors: np.ndarray, blend_colors: bool = False, blend_gamma: float = 2.2
+    graph: NetworkGraph,
+    domain_ids_to_colors: Dict[int, Any],
+    blend_colors: bool = False,
+    blend_gamma: float = 2.2,
 ) -> np.ndarray:
     """Generate composite colors for nodes based on domain colors and significance values, with optional color blending.
 
     Args:
         graph (NetworkGraph): The network data and attributes to be visualized.
-        domain_colors (np.ndarray): Array or list of RGBA colors corresponding to each domain.
+        domain_ids_to_colors (Dict[int, Any]): Mapping of domain IDs to RGBA colors.
         blend_colors (bool): Whether to blend colors for nodes with multiple domains. Defaults to False.
         blend_gamma (float, optional): Gamma correction factor to be used for perceptual color blending.
             This parameter is only relevant if blend_colors is True. Defaults to 2.2.
@@ -167,11 +191,10 @@ def _get_composite_node_colors(
     num_nodes = len(graph.node_coordinates)
     # Initialize composite colors array with shape (number of nodes, 4) for RGBA
     composite_colors = np.zeros((num_nodes, 4))
-
     # If blending is not required, directly assign domain colors to nodes
     if not blend_colors:
         for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
-            color = domain_colors[domain_id]
+            color = domain_ids_to_colors[domain_id]
             for node in nodes:
                 composite_colors[node] = color
 
@@ -180,11 +203,11 @@ def _get_composite_node_colors(
         for node, node_info in graph.node_id_to_domain_ids_and_significance_map.items():
             domains = node_info["domains"]  # List of domain IDs
             significances = node_info["significances"]  # List of significance values
-            # Filter domains and significances to keep only those with corresponding colors in domain_colors
+            # Filter domains and significances to keep only those with corresponding colors in domain_ids_to_colors
             filtered_domains_significances = [
                 (domain_id, significance)
                 for domain_id, significance in zip(domains, significances)
-                if domain_id in domain_colors
+                if domain_id in domain_ids_to_colors
             ]
             # If no valid domains exist, skip this node
             if not filtered_domains_significances:
@@ -193,7 +216,7 @@ def _get_composite_node_colors(
             # Unpack filtered domains and significances
             filtered_domains, filtered_significances = zip(*filtered_domains_significances)
             # Get the colors corresponding to the valid filtered domains
-            colors = [domain_colors[domain_id] for domain_id in filtered_domains]
+            colors = [domain_ids_to_colors[domain_id] for domain_id in filtered_domains]
             # Blend the colors using the given gamma (default is 2.2 if None)
             gamma = blend_gamma if blend_gamma is not None else 2.2
             composite_color = _blend_colors_perceptually(colors, filtered_significances, gamma)
@@ -204,8 +227,8 @@
 
 
 def _get_colors(
-    network,
-    domain_id_to_node_ids_map,
+    network: nx.Graph,
+    domain_id_to_node_ids_map: Dict[int, Any],
     cmap: str = "gist_rainbow",
     color: Union[str, List, Tuple, np.ndarray, None] = None,
     random_seed: int = 888,
@@ -214,7 +237,7 @@ def _get_colors(
     close in space get maximally separated colors, while keeping some randomness.
 
     Args:
-        network (NetworkX graph): The graph representing the network.
+        network (nx.Graph): The graph representing the network.
         domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
         cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
         color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
@@ -252,7 +275,7 @@ def _get_colors(
     return [colormap(pos) for pos in color_positions]
 
 
-def _assign_distant_colors(dist_matrix, num_colors_to_generate):
+def _assign_distant_colors(dist_matrix: np.ndarray, num_colors_to_generate: int) -> np.ndarray:
     """Assign colors to centroids that are close in space, ensuring stark color differences.
 
     Args:
@@ -404,7 +427,8 @@ def to_rgba(
             return np.tile(
                 rgba_color, (num_repeats, 1)
             )  # Repeat the color if num_repeats is provided
-        return np.array([rgba_color])  # Return a single color wrapped in a numpy array
+
+        return rgba_color
 
     # Handle a list/array of colors
     elif isinstance(color, (list, tuple, np.ndarray)):
@@ -421,4 +445,4 @@ def to_rgba(
         return rgba_colors
 
     else:
-        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")
+        raise ValueError("Color must be a valid string, RGB/RGBA, or array of RGB/RGBA colors.")
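The mechanic behind `ids_to_colors` is the override loop added to `_get_domain_ids_to_colors`: colormap-derived colors are kept unless the caller pinned a domain. A self-contained sketch of that step with toy inputs (matplotlib's `to_rgba` stands in for the module's own helper):

```python
from matplotlib.colors import to_rgba

# Toy inputs: colormap-derived colors for three domains, plus one user override.
domain_ids = [0, 1, 2]
domain_colors = [(1.0, 0.0, 0.0, 1.0), (0.0, 1.0, 0.0, 1.0), (0.0, 0.0, 1.0, 1.0)]
ids_to_colors = {1: "gold"}  # pin domain 1 to a named color

domain_ids_to_colors = {}
for domain_id, domain_color in zip(domain_ids, domain_colors):
    if ids_to_colors and domain_id in ids_to_colors:
        # User-specified colors win and are normalized to RGBA
        domain_ids_to_colors[domain_id] = to_rgba(ids_to_colors[domain_id])
    else:
        domain_ids_to_colors[domain_id] = domain_color

print(domain_ids_to_colors[1])  # (1.0, 0.843..., 0.0, 1.0), the pinned gold as RGBA
```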
risk/network/plot/utils/layout.py CHANGED
@@ -3,8 +3,9 @@ risk/network/plot/utils/layout
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
-from typing import Tuple
+from typing import Any, Dict, List, Tuple
 
+import networkx as nx
 import numpy as np
 
 
@@ -68,11 +69,13 @@ def refine_center_iteratively(
     return center, new_radius
 
 
-def calculate_centroids(network, domain_id_to_node_ids_map):
+def calculate_centroids(
+    network: nx.Graph, domain_id_to_node_ids_map: Dict[int, Any]
+) -> List[Tuple[float, float]]:
     """Calculate the centroid for each domain based on node x and y coordinates in the network.
 
     Args:
-        network (NetworkX graph): The graph representing the network.
+        network (nx.Graph): The graph representing the network.
         domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
 
     Returns:
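`calculate_centroids` now carries type hints; it averages node coordinates per domain and returns one `(x, y)` tuple per domain. A toy sketch consistent with the annotated signature (the `"x"`/`"y"` node attribute names are assumptions, not taken from this diff):

```python
import networkx as nx
import numpy as np

network = nx.Graph()
network.add_nodes_from(
    [(0, {"x": 0.0, "y": 0.0}), (1, {"x": 2.0, "y": 2.0}), (2, {"x": 4.0, "y": 0.0})]
)
domain_id_to_node_ids_map = {0: [0, 1], 1: [2]}

centroids = [
    (
        float(np.mean([network.nodes[n]["x"] for n in node_ids])),
        float(np.mean([network.nodes[n]["y"] for n in node_ids])),
    )
    for node_ids in domain_id_to_node_ids_map.values()
]
print(centroids)  # [(1.0, 1.0), (4.0, 0.0)]
```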
{risk_network-0.0.9b2.dist-info → risk_network-0.0.9b4.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: risk-network
-Version: 0.0.9b2
+Version: 0.0.9b4
 Summary: A Python package for biological network analysis
 Author: Ira Horecka
 Author-email: Ira Horecka <ira89@icloud.com>
{risk_network-0.0.9b2.dist-info → risk_network-0.0.9b4.dist-info}/RECORD RENAMED
@@ -1,4 +1,4 @@
-risk/__init__.py,sha256=ICwmPcHgPSGMULpN30hPbaL4pu5GlN6TVBB5aduUyVM,112
+risk/__init__.py,sha256=dZ3VXcE-IJQdsTfZrd9B8BAcyb3WnO6J_uyXuiDu7AQ,112
 risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
 risk/risk.py,sha256=De1vn8Xc-TKz6aTL0bvJI-SVrIqU3k0IWAbKc7dde1c,23618
 risk/annotations/__init__.py,sha256=kXgadEXaCh0z8OyhOhTj7c3qXGmWgOhaSZ4gSzSb59U,147
@@ -15,16 +15,16 @@ risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
 risk/network/geometry.py,sha256=gFtYUj9j9aul4paKq_qSGJn39Nazxu_MXv8m-tYYtrk,6840
 risk/network/io.py,sha256=AWSbZGLZHtl72KSlafQlcYoG00YLSznG7UYDi_wDT7M,22958
 risk/network/graph/__init__.py,sha256=H0YEiwqZ02LBTkH4blPwUjQ-DOUnhaTTNHM0BcXii6U,81
-risk/network/graph/network.py,sha256=4tDtQExFo9U1smaaxf-CaoxHOY99aagM2G11ap_DKbY,10192
-risk/network/graph/summary.py,sha256=zxkI9VyrYN5y41jlqLOIcDX0fF4wt24khtr6to36_uc,9239
+risk/network/graph/network.py,sha256=j75Lfwd5VGIDv0GlVoYIgN6RRua7i-PNg5D-ssgRhfo,11190
+risk/network/graph/summary.py,sha256=h2bpUjfwI1NMflkKwplGQEGPswfAtunormdTIEQYbvs,8987
 risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
 risk/network/plot/canvas.py,sha256=TlCpNtvoceizAumNr9I02JcBrBO6FiAFAa2ZC0bx3SU,13356
-risk/network/plot/contour.py,sha256=2ZVOlduo4Y4yIpXDJMIKO-v7eULcJ2QacQyOc7pUAxE,15267
-risk/network/plot/labels.py,sha256=YqeOhE7nah16kK3L88JnOJVqE6WWj7lm23niVdEc8cU,45504
-risk/network/plot/network.py,sha256=_yyOUoxJ_jelZV3TMCCTcGnt014TBYMUfecSLOiUb7E,13788
+risk/network/plot/contour.py,sha256=91-K9jlV3K82Ax13BoLdDeZkBR5_8AnZujAblHBK-3A,15580
+risk/network/plot/labels.py,sha256=S1UOYUB1WbsaQEVkEr5cUKQbUuebKJPKILqhL8zmQWM,45733
+risk/network/plot/network.py,sha256=vcm53MlaWd-wmaPDo8-Ap2gJPAp0SNEGZ2NiGamkJKo,14014
 risk/network/plot/plotter.py,sha256=iTPMiTnTTatM_-q1Ox_bjt5Pvv-Lo8gceiYB6TVzDcw,5770
-risk/network/plot/utils/color.py,sha256=rGOx4WAdjyaeWFALybkKzJabm9VSmsWb5_hsb__pcNg,19701
-risk/network/plot/utils/layout.py,sha256=RnJq0yODpoheZnDl7KKFPQeXrnrsS3FLIdxupoYVZq4,3553
+risk/network/plot/utils/color.py,sha256=T5qOlPhHzgLwileaVqT8Vq9A1ZQvaSMWvsCTnauAgTs,20802
+risk/network/plot/utils/layout.py,sha256=6o7idoWQnyzujSWOFXQykUvyPy8NuRtJV04TnlbXXBo,3647
 risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
 risk/stats/hypergeom.py,sha256=oc39f02ViB1vQ-uaDrxG_tzAT6dxQBRjc88EK2EGn78,2282
 risk/stats/poisson.py,sha256=polLgwS08MTCNzupYdmMUoEUYrJOjAbcYtYwjlfeE5Y,1803
@@ -32,8 +32,8 @@ risk/stats/stats.py,sha256=z8NrhiVj4BzJ250bVLfytpmfC7RzYu7mBuIZD_l0aCA,7222
 risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
 risk/stats/permutation/permutation.py,sha256=meBNSrbRa9P8WJ54n485l0H7VQJlMSfHqdN4aCKYCtQ,10105
 risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
-risk_network-0.0.9b2.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
-risk_network-0.0.9b2.dist-info/METADATA,sha256=ccV7eOOwXIFCkiyfNIvCwLkTNQed1DuOmwXL84Nx2dY,47497
-risk_network-0.0.9b2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-risk_network-0.0.9b2.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
-risk_network-0.0.9b2.dist-info/RECORD,,
+risk_network-0.0.9b4.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
+risk_network-0.0.9b4.dist-info/METADATA,sha256=8iJaySThTILJJ7-8Au42vD89VFD_UJ64xYrSCKUSPNo,47497
+risk_network-0.0.9b4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+risk_network-0.0.9b4.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
+risk_network-0.0.9b4.dist-info/RECORD,,