risk-network 0.0.8b15__py3-none-any.whl → 0.0.8b17__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
risk/__init__.py CHANGED
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.8-beta.15"
10
+ __version__ = "0.0.8-beta.17"
risk/network/graph.py CHANGED
@@ -4,12 +4,11 @@ risk/network/graph
4
4
  """
5
5
 
6
6
  from collections import defaultdict
7
- from typing import Any, Dict, List, Tuple, Union
7
+ from typing import Any, Dict, List
8
8
 
9
9
  import networkx as nx
10
10
  import numpy as np
11
11
  import pandas as pd
12
- import matplotlib
13
12
 
14
13
 
15
14
  class NetworkGraph:
@@ -107,139 +106,6 @@ class NetworkGraph:
107
106
 
108
107
  return domain_id_to_label_map
109
108
 
110
- def get_domain_colors(
111
- self,
112
- cmap: str = "gist_rainbow",
113
- color: Union[str, None] = None,
114
- min_scale: float = 0.8,
115
- max_scale: float = 1.0,
116
- scale_factor: float = 1.0,
117
- random_seed: int = 888,
118
- ) -> np.ndarray:
119
- """Generate composite colors for domains based on enrichment or specified colors.
120
-
121
- Args:
122
- cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
123
- color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
124
- min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
125
- Controls the dimmest colors. Defaults to 0.8.
126
- max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
127
- Controls the brightest colors. Defaults to 1.0.
128
- scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores.
129
- A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
130
- random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments.
131
- Defaults to 888.
132
-
133
- Returns:
134
- np.ndarray: Array of RGBA colors generated for each domain, based on enrichment or the specified color.
135
- """
136
- # Get colors for each domain
137
- domain_colors = self._get_domain_colors(cmap=cmap, color=color, random_seed=random_seed)
138
- # Generate composite colors for nodes
139
- node_colors = self._get_composite_node_colors(domain_colors)
140
- # Transform colors to ensure proper alpha values and intensity
141
- transformed_colors = _transform_colors(
142
- node_colors,
143
- self.node_enrichment_sums,
144
- min_scale=min_scale,
145
- max_scale=max_scale,
146
- scale_factor=scale_factor,
147
- )
148
-
149
- return transformed_colors
150
-
151
- def _get_domain_colors(
152
- self,
153
- cmap: str = "gist_rainbow",
154
- color: Union[str, None] = None,
155
- random_seed: int = 888,
156
- ) -> Dict[str, Any]:
157
- """Get colors for each domain.
158
-
159
- Args:
160
- cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
161
- color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
162
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
163
-
164
- Returns:
165
- dict: A dictionary mapping domain keys to their corresponding RGBA colors.
166
- """
167
- # Get colors for each domain based on node positions
168
- domain_colors = _get_colors(
169
- self.network,
170
- self.domain_id_to_node_ids_map,
171
- cmap=cmap,
172
- color=color,
173
- random_seed=random_seed,
174
- )
175
- return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
176
-
177
- def _get_composite_node_colors(self, domain_colors: np.ndarray) -> np.ndarray:
178
- """Generate composite colors for nodes based on domain colors and counts.
179
-
180
- Args:
181
- domain_colors (np.ndarray): Array of colors corresponding to each domain.
182
-
183
- Returns:
184
- np.ndarray: Array of composite colors for each node.
185
- """
186
- # Determine the number of nodes
187
- num_nodes = len(self.node_coordinates)
188
- # Initialize composite colors array with shape (number of nodes, 4) for RGBA
189
- composite_colors = np.zeros((num_nodes, 4))
190
- # Assign colors to nodes based on domain_colors
191
- for domain_id, nodes in self.domain_id_to_node_ids_map.items():
192
- color = domain_colors[domain_id]
193
- for node in nodes:
194
- composite_colors[node] = color
195
-
196
- return composite_colors
197
-
198
-
199
- def _transform_colors(
200
- colors: np.ndarray,
201
- enrichment_sums: np.ndarray,
202
- min_scale: float = 0.8,
203
- max_scale: float = 1.0,
204
- scale_factor: float = 1.0,
205
- ) -> np.ndarray:
206
- """Transform colors using power scaling to emphasize high enrichment sums more. Black colors are replaced with
207
- very dark grey to avoid issues with color scaling (rgb(0.1, 0.1, 0.1)).
208
-
209
- Args:
210
- colors (np.ndarray): An array of RGBA colors.
211
- enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
212
- min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
213
- max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
214
- scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
215
- values more. Defaults to 1.0.
216
-
217
- Returns:
218
- np.ndarray: The transformed array of RGBA colors with adjusted intensities.
219
- """
220
- # Ensure that min_scale is less than max_scale
221
- if min_scale == max_scale:
222
- min_scale = max_scale - 10e-6 # Avoid division by zero
223
-
224
- # Replace black colors (#000000) with very dark grey (#1A1A1A)
225
- black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
226
- dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
227
- # Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
228
- is_black = np.all(colors[:, :3] == black_color, axis=1)
229
- colors[is_black, :3] = dark_grey
230
-
231
- # Normalize the enrichment sums to the range [0, 1]
232
- normalized_sums = enrichment_sums / np.max(enrichment_sums)
233
- # Apply power scaling to dim lower values and emphasize higher values
234
- scaled_sums = normalized_sums**scale_factor
235
- # Linearly scale the normalized sums to the range [min_scale, max_scale]
236
- scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
237
- # Adjust RGB values based on scaled sums
238
- for i in range(3): # Only adjust RGB values
239
- colors[:, i] = scaled_sums * colors[:, i]
240
-
241
- return colors
242
-
243
109
 
244
110
  def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
245
111
  """Convert 3D coordinates to 2D by unfolding a sphere to a plane.
@@ -291,103 +157,3 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
291
157
  }
292
158
  node_coordinates = np.vstack(list(node_positions.values()))
293
159
  return node_coordinates
294
-
295
-
296
- def _get_colors(
297
- network,
298
- domain_id_to_node_ids_map,
299
- cmap: str = "gist_rainbow",
300
- color: Union[str, None] = None,
301
- random_seed: int = 888,
302
- ) -> List[Tuple]:
303
- """Generate a list of RGBA colors based on domain centroids, ensuring that domains
304
- close in space get maximally separated colors, while keeping some randomness.
305
-
306
- Args:
307
- network (NetworkX graph): The graph representing the network.
308
- domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
309
- cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
310
- color (str or None, optional): A specific color to use for all generated colors.
311
- random_seed (int): Seed for random number generation. Defaults to 888.
312
-
313
- Returns:
314
- List[Tuple]: List of RGBA colors.
315
- """
316
- # Set random seed for reproducibility
317
- np.random.seed(random_seed)
318
- # Determine the number of colors to generate based on the number of domains
319
- num_colors_to_generate = len(domain_id_to_node_ids_map)
320
- if color:
321
- # Generate all colors as the same specified color
322
- rgba = matplotlib.colors.to_rgba(color)
323
- return [rgba] * num_colors_to_generate
324
-
325
- # Load colormap
326
- colormap = matplotlib.colormaps.get_cmap(cmap)
327
- # Step 1: Calculate centroids for each domain
328
- centroids = _calculate_centroids(network, domain_id_to_node_ids_map)
329
- # Step 2: Calculate pairwise distances between centroids
330
- centroid_array = np.array(centroids)
331
- dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
332
- # Step 3: Assign distant colors to close centroids
333
- color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
334
- # Step 4: Randomly shift the entire color palette while maintaining relative distances
335
- global_shift = np.random.uniform(-0.1, 0.1) # Small global shift to change the overall palette
336
- color_positions = (color_positions + global_shift) % 1 # Wrap around to keep within [0, 1]
337
- # Step 5: Ensure that all positions remain between 0 and 1
338
- color_positions = np.clip(color_positions, 0, 1)
339
-
340
- # Step 6: Generate RGBA colors based on positions
341
- return [colormap(pos) for pos in color_positions]
342
-
343
-
344
- def _calculate_centroids(network, domain_id_to_node_ids_map):
345
- """Calculate the centroid for each domain based on node x and y coordinates in the network.
346
-
347
- Args:
348
- network (NetworkX graph): The graph representing the network.
349
- domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
350
-
351
- Returns:
352
- List[Tuple[float, float]]: List of centroids (x, y) for each domain.
353
- """
354
- centroids = []
355
- for domain_id, node_ids in domain_id_to_node_ids_map.items():
356
- # Extract x and y coordinates from the network nodes
357
- node_positions = np.array(
358
- [[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
359
- )
360
- # Compute the centroid as the mean of the x and y coordinates
361
- centroid = np.mean(node_positions, axis=0)
362
- centroids.append(tuple(centroid))
363
-
364
- return centroids
365
-
366
-
367
- def _assign_distant_colors(dist_matrix, num_colors_to_generate):
368
- """Assign colors to centroids that are close in space, ensuring stark color differences.
369
-
370
- Args:
371
- dist_matrix (ndarray): Matrix of pairwise centroid distances.
372
- num_colors_to_generate (int): Number of colors to generate.
373
-
374
- Returns:
375
- np.array: Array of color positions in the range [0, 1].
376
- """
377
- color_positions = np.zeros(num_colors_to_generate)
378
- # Step 1: Sort indices by centroid proximity (based on sum of distances to others)
379
- proximity_order = sorted(
380
- range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
381
- )
382
- # Step 2: Assign colors starting with the most distant points in proximity order
383
- for i, idx in enumerate(proximity_order):
384
- color_positions[idx] = i / num_colors_to_generate
385
-
386
- # Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
387
- half_spectrum = int(num_colors_to_generate / 2)
388
- for i in range(half_spectrum):
389
- # Split the spectrum so that close centroids are assigned distant colors
390
- color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
391
- color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
392
-
393
- return color_positions
@@ -10,7 +10,8 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils import calculate_bounding_box, to_rgba
13
+ from risk.network.plot.utils.color import to_rgba
14
+ from risk.network.plot.utils.layout import calculate_bounding_box
14
15
 
15
16
 
16
17
  class Canvas:
@@ -33,8 +34,8 @@ class Canvas:
33
34
  title_fontsize: int = 20,
34
35
  subtitle_fontsize: int = 14,
35
36
  font: str = "Arial",
36
- title_color: str = "black",
37
- subtitle_color: str = "gray",
37
+ title_color: Union[str, list, tuple, np.ndarray] = "black",
38
+ subtitle_color: Union[str, list, tuple, np.ndarray] = "gray",
38
39
  title_y: float = 0.975,
39
40
  title_space_offset: float = 0.075,
40
41
  subtitle_offset: float = 0.025,
@@ -47,8 +48,10 @@ class Canvas:
47
48
  title_fontsize (int, optional): Font size for the title. Defaults to 20.
48
49
  subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
49
50
  font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
50
- title_color (str, optional): Color of the title text. Defaults to "black".
51
- subtitle_color (str, optional): Color of the subtitle text. Defaults to "gray".
51
+ title_color (str, list, tuple, or np.ndarray, optional): Color of the title text. Can be a string or an array of colors.
52
+ Defaults to "black".
53
+ subtitle_color (str, list, tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
54
+ Defaults to "gray".
52
55
  title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
53
56
  title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
54
57
  subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
@@ -141,7 +144,9 @@ class Canvas:
141
144
  )
142
145
 
143
146
  # Convert color to RGBA using the to_rgba helper function - use outline_alpha for the perimeter
144
- color = to_rgba(color=color, alpha=outline_alpha)
147
+ color = to_rgba(
148
+ color=color, alpha=outline_alpha, num_repeats=1
149
+ ) # num_repeats=1 for a single color
145
150
  # Set the fill_alpha to 0 if not provided
146
151
  fill_alpha = fill_alpha if fill_alpha is not None else 0.0
147
152
  # Extract node coordinates from the network graph
@@ -162,7 +167,9 @@ class Canvas:
162
167
  )
163
168
  # Set the transparency of the fill if applicable
164
169
  if fill_alpha > 0:
165
- circle.set_facecolor(to_rgba(color=color, alpha=fill_alpha))
170
+ circle.set_facecolor(
171
+ to_rgba(color=color, alpha=fill_alpha, num_repeats=1)
172
+ ) # num_repeats=1 for a single color
166
173
 
167
174
  self.ax.add_artist(circle)
168
175
 
@@ -209,7 +216,7 @@ class Canvas:
209
216
  )
210
217
 
211
218
  # Convert color to RGBA using outline_alpha for the line (outline)
212
- outline_color = to_rgba(color=color)
219
+ outline_color = to_rgba(color=color, num_repeats=1) # num_repeats=1 for a single color
213
220
  # Extract node coordinates from the network graph
214
221
  node_coordinates = self.graph.node_coordinates
215
222
  # Scale the node coordinates if needed
@@ -13,7 +13,7 @@ from scipy.stats import gaussian_kde
13
13
 
14
14
  from risk.log import params, logger
15
15
  from risk.network.graph import NetworkGraph
16
- from risk.network.plot.utils import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
17
17
 
18
18
 
19
19
  class Contour:
@@ -110,7 +110,7 @@ class Contour:
110
110
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
111
111
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
112
112
  color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array.
113
- Defaults to "white".
113
+ Can be a single color or an array of colors. Defaults to "white".
114
114
  linestyle (str, optional): Line style for the contour. Defaults to "solid".
115
115
  linewidth (float, optional): Line width for the contour. Defaults to 1.5.
116
116
  alpha (float, None, optional): Transparency level of the contour lines. If provided, it overrides any existing alpha values
@@ -125,15 +125,16 @@ class Contour:
125
125
  if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
126
126
  # If it's a list of lists, iterate over sublists
127
127
  node_groups = nodes
128
+ # Convert color to RGBA arrays to match the number of groups
129
+ color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=len(node_groups))
128
130
  else:
129
131
  # If it's a flat list of nodes, treat it as a single group
130
132
  node_groups = [nodes]
131
-
132
- # Convert color to RGBA using the to_rgba helper function
133
- color_rgba = to_rgba(color=color, alpha=alpha)
133
+ # Wrap the RGBA color in an array to index the first element
134
+ color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=1)
134
135
 
135
136
  # Iterate over each group of nodes (either sublists or flat list)
136
- for sublist in node_groups:
137
+ for idx, sublist in enumerate(node_groups):
137
138
  # Filter to get node IDs and their coordinates for each sublist
138
139
  node_ids = [
139
140
  self.graph.node_label_to_node_id_map.get(node)
@@ -151,7 +152,7 @@ class Contour:
151
152
  self.ax,
152
153
  node_coordinates,
153
154
  node_ids,
154
- color=color_rgba,
155
+ color=color_rgba[idx],
155
156
  levels=levels,
156
157
  bandwidth=bandwidth,
157
158
  grid_size=grid_size,
@@ -273,7 +274,7 @@ class Contour:
273
274
  def get_annotated_contour_colors(
274
275
  self,
275
276
  cmap: str = "gist_rainbow",
276
- color: Union[str, None] = None,
277
+ color: Union[str, list, tuple, np.ndarray, None] = None,
277
278
  min_scale: float = 0.8,
278
279
  max_scale: float = 1.0,
279
280
  scale_factor: float = 1.0,
@@ -283,7 +284,8 @@ class Contour:
283
284
 
284
285
  Args:
285
286
  cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
286
- color (str or None, optional): Color to use for the contours. If None, the colormap will be used. Defaults to None.
287
+ color (str, list, tuple, np.ndarray, or None, optional): Color to use for the contours. Can be a single color or an array of colors.
288
+ If None, the colormap will be used. Defaults to None.
287
289
  min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
288
290
  Controls the dimmest colors. Defaults to 0.8.
289
291
  max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
@@ -11,7 +11,8 @@ import pandas as pd
11
11
 
12
12
  from risk.log import params
13
13
  from risk.network.graph import NetworkGraph
14
- from risk.network.plot.utils import calculate_bounding_box, get_annotated_domain_colors, to_rgba
14
+ from risk.network.plot.utils.color import get_annotated_domain_colors, to_rgba
15
+ from risk.network.plot.utils.layout import calculate_bounding_box
15
16
 
16
17
  TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
17
18
 
@@ -284,13 +285,19 @@ class Labels:
284
285
  if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
285
286
  # If it's a list of lists, iterate over sublists
286
287
  node_groups = nodes
288
+ # Convert fontcolor and arrow_color to RGBA arrays to match the number of groups
289
+ fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=len(node_groups))
290
+ arrow_color_rgba = to_rgba(
291
+ color=arrow_color, alpha=arrow_alpha, num_repeats=len(node_groups)
292
+ )
287
293
  else:
288
294
  # If it's a flat list of nodes, treat it as a single group
289
295
  node_groups = [nodes]
290
-
291
- # Convert fontcolor and arrow_color to RGBA
292
- fontcolor_rgba = to_rgba(color=fontcolor, alpha=fontalpha)
293
- arrow_color_rgba = to_rgba(color=arrow_color, alpha=arrow_alpha)
296
+ # Wrap the RGBA fontcolor and arrow_color in an array to index the first element
297
+ fontcolor_rgba = np.array(to_rgba(color=fontcolor, alpha=fontalpha, num_repeats=1))
298
+ arrow_color_rgba = np.array(
299
+ to_rgba(color=arrow_color, alpha=arrow_alpha, num_repeats=1)
300
+ )
294
301
 
295
302
  # Calculate the bounding box around the network
296
303
  center, radius = calculate_bounding_box(self.graph.node_coordinates, radius_margin=scale)
@@ -302,7 +309,7 @@ class Labels:
302
309
  )
303
310
 
304
311
  # Iterate over each group of nodes (either sublists or flat list)
305
- for sublist in node_groups:
312
+ for idx, sublist in enumerate(node_groups):
306
313
  # Map node labels to IDs
307
314
  node_ids = [
308
315
  self.graph.node_label_to_node_id_map.get(node)
@@ -326,10 +333,10 @@ class Labels:
326
333
  va="center",
327
334
  fontsize=fontsize,
328
335
  fontname=font,
329
- color=fontcolor_rgba,
336
+ color=fontcolor_rgba[idx],
330
337
  arrowprops=dict(
331
338
  arrowstyle=arrow_style,
332
- color=arrow_color_rgba,
339
+ color=arrow_color_rgba[idx],
333
340
  linewidth=arrow_linewidth,
334
341
  shrinkA=arrow_base_shrink,
335
342
  shrinkB=arrow_tip_shrink,
@@ -630,7 +637,7 @@ class Labels:
630
637
  def get_annotated_label_colors(
631
638
  self,
632
639
  cmap: str = "gist_rainbow",
633
- color: Union[str, None] = None,
640
+ color: Union[str, list, tuple, np.ndarray, None] = None,
634
641
  min_scale: float = 0.8,
635
642
  max_scale: float = 1.0,
636
643
  scale_factor: float = 1.0,
@@ -640,7 +647,8 @@ class Labels:
640
647
 
641
648
  Args:
642
649
  cmap (str, optional): Name of the colormap to use for generating label colors. Defaults to "gist_rainbow".
643
- color (str or None, optional): Color to use for the labels. If None, the colormap will be used. Defaults to None.
650
+ color (str, list, tuple, np.ndarray, or None, optional): Color to use for the labels. Can be a single color or an array
651
+ of colors. If None, the colormap will be used. Defaults to None.
644
652
  min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
645
653
  Controls the dimmest colors. Defaults to 0.8.
646
654
  max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
@@ -10,7 +10,7 @@ import numpy as np
10
10
 
11
11
  from risk.log import params
12
12
  from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils import to_rgba
13
+ from risk.network.plot.utils.color import get_domain_colors, to_rgba
14
14
 
15
15
 
16
16
  class Network:
@@ -47,8 +47,10 @@ class Network:
47
47
  edge_width (float, optional): Width of the edges. Defaults to 1.0.
48
48
  node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors.
49
49
  Defaults to "white".
50
- node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
51
- edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
50
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Can be a single color or an array of colors.
51
+ Defaults to "black".
52
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Can be a single color or an array of colors.
53
+ Defaults to "black".
52
54
  node_alpha (float, None, optional): Alpha value (transparency) for the nodes. If provided, it overrides any existing alpha
53
55
  values found in node_color. Defaults to 1.0. Annotated node_color alphas will override this value.
54
56
  edge_alpha (float, None, optional): Alpha value (transparency) for the edges. If provided, it overrides any existing alpha
@@ -194,12 +196,12 @@ class Network:
194
196
  def get_annotated_node_colors(
195
197
  self,
196
198
  cmap: str = "gist_rainbow",
197
- color: Union[str, None] = None,
199
+ color: Union[str, list, tuple, np.ndarray, None] = None,
198
200
  min_scale: float = 0.8,
199
201
  max_scale: float = 1.0,
200
202
  scale_factor: float = 1.0,
201
203
  alpha: Union[float, None] = 1.0,
202
- nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
204
+ nonenriched_color: Union[str, list, tuple, np.ndarray] = "white",
203
205
  nonenriched_alpha: Union[float, None] = 1.0,
204
206
  random_seed: int = 888,
205
207
  ) -> np.ndarray:
@@ -207,22 +209,25 @@ class Network:
207
209
 
208
210
  Args:
209
211
  cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
210
- color (str or None, optional): Color to use for the nodes. If None, the colormap will be used. Defaults to None.
212
+ color (str, list, tuple, np.ndarray, or None, optional): Color to use for the nodes. Can be a single color or an array of colors.
213
+ If None, the colormap will be used. Defaults to None.
211
214
  min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
212
215
  max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
213
216
  scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
214
- alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values
215
- found in color. Defaults to 1.0.
216
- nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
217
- nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing
218
- alpha values found in nonenriched_color. Defaults to 1.0.
217
+ alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values found in `color`.
218
+ Defaults to 1.0.
219
+ nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Can be a single color or an array of colors.
220
+ Defaults to "white".
221
+ nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing alpha values found
222
+ in `nonenriched_color`. Defaults to 1.0.
219
223
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
220
224
 
221
225
  Returns:
222
226
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
223
227
  """
224
228
  # Get the initial domain colors for each node, which are returned as RGBA
225
- network_colors = self.graph.get_domain_colors(
229
+ network_colors = get_domain_colors(
230
+ graph=self.graph,
226
231
  cmap=cmap,
227
232
  color=color,
228
233
  min_scale=min_scale,
@@ -233,7 +238,9 @@ class Network:
233
238
  # Apply the alpha value for enriched nodes
234
239
  network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
235
240
  # Convert the non-enriched color to RGBA using the to_rgba helper function
236
- nonenriched_color = to_rgba(color=nonenriched_color, alpha=nonenriched_alpha)
241
+ nonenriched_color = to_rgba(
242
+ color=nonenriched_color, alpha=nonenriched_alpha, num_repeats=1
243
+ ) # num_repeats=1 for a single color
237
244
  # Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
238
245
  # 0.1 is a predefined threshold for the minimum color intensity
239
246
  adjusted_network_colors = np.where(
@@ -14,7 +14,8 @@ from risk.network.plot.canvas import Canvas
14
14
  from risk.network.plot.contour import Contour
15
15
  from risk.network.plot.labels import Labels
16
16
  from risk.network.plot.network import Network
17
- from risk.network.plot.utils import calculate_bounding_box, to_rgba
17
+ from risk.network.plot.utils.color import to_rgba
18
+ from risk.network.plot.utils.layout import calculate_bounding_box
18
19
 
19
20
 
20
21
  class NetworkPlotter(Canvas, Network, Contour, Labels):
@@ -58,7 +59,7 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
58
59
  self,
59
60
  graph: NetworkGraph,
60
61
  figsize: Tuple,
61
- background_color: Union[str, List, Tuple, np.ndarray],
62
+ background_color: Union[str, list, tuple, np.ndarray],
62
63
  background_alpha: Union[float, None],
63
64
  pad: float,
64
65
  ) -> plt.Axes:
@@ -67,9 +68,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
67
68
  Args:
68
69
  graph (NetworkGraph): The network data and attributes to be visualized.
69
70
  figsize (tuple): Size of the figure in inches (width, height).
70
- background_color (str): Background color of the plot.
71
- background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any
72
- existing alpha values found in background_color.
71
+ background_color (str, list, tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
72
+ background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any existing
73
+ alpha values found in `background_color`.
73
74
  pad (float, optional): Padding value to adjust the axis limits.
74
75
 
75
76
  Returns:
@@ -98,7 +99,9 @@ class NetworkPlotter(Canvas, Network, Contour, Labels):
98
99
 
99
100
  # Set the background color of the plot
100
101
  # Convert color to RGBA using the to_rgba helper function
101
- fig.patch.set_facecolor(to_rgba(color=background_color, alpha=background_alpha))
102
+ fig.patch.set_facecolor(
103
+ to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
104
+ ) # num_repeats=1 for single color
102
105
  ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
103
106
  # Remove axis spines for a cleaner look
104
107
  for spine in ax.spines.values():
@@ -0,0 +1,353 @@
1
+ """
2
+ risk/network/plot/utils/color
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict, List, Tuple, Union
7
+
8
+ import matplotlib
9
+ import matplotlib.colors as mcolors
10
+ import numpy as np
11
+
12
+ from risk.network.graph import NetworkGraph
13
+ from risk.network.plot.utils.layout import calculate_centroids
14
+
15
+
16
def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain.

    Per-node colors are first computed with ``get_domain_colors``. For each domain with more than
    one node, the member color with the largest R+G+B sum (the "brightest") is chosen. Domains
    with zero or one node get opaque white.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap used when ``color`` is None. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): Fixed color (or colors) for the domains.
            If None, the colormap will be used. Defaults to None.
        min_scale (float, optional): Lower bound of the enrichment-based intensity scaling. Defaults to 0.8.
        max_scale (float, optional): Upper bound of the enrichment-based intensity scaling. Defaults to 1.0.
        scale_factor (float, optional): Contrast exponent applied to normalized enrichment; larger values
            increase contrast. Defaults to 1.0.
        random_seed (int, optional): Seed forwarded to the color generator for reproducibility. Defaults to 888.

    Returns:
        np.ndarray: One RGBA row per domain, in the iteration order of ``graph.domain_id_to_node_ids_map``.
    """
    per_node_colors = get_domain_colors(
        graph=graph,
        cmap=cmap,
        color=color,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )
    opaque_white = np.array([1.0, 1.0, 1.0, 1.0])

    def _representative(node_ids):
        # Zero- or single-node domains fall back to white.
        if len(node_ids) <= 1:
            return opaque_white
        member_colors = np.array([per_node_colors[node] for node in node_ids])
        # Brightest member = largest sum over the RGB channels (alpha ignored).
        return member_colors[np.argmax(member_colors[:, :3].sum(axis=1))]

    return np.array(
        [_representative(node_ids) for node_ids in graph.domain_id_to_node_ids_map.values()]
    )
66
+
67
+
68
def get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Compute per-node RGBA colors from domain colors and node enrichment.

    Pipeline:
        1. choose one color per domain (``_get_domain_colors``);
        2. copy each domain's color onto its member nodes (``_get_composite_node_colors``);
        3. rescale every node's RGB intensity by its enrichment sum (``_transform_colors``).

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap used when ``color`` is None. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): Fixed color (or colors) for all domains.
            If None, the colormap will be used. Defaults to None.
        min_scale (float, optional): Intensity multiplier for the least-enriched nodes (dimmest). Defaults to 0.8.
        max_scale (float, optional): Intensity multiplier for the most-enriched nodes (brightest). Defaults to 1.0.
        scale_factor (float, optional): Exponent applied to normalized enrichment sums; values above 1 dim
            low scores more strongly. Defaults to 1.0.
        random_seed (int, optional): Seed forwarded to the color generator for reproducible palettes. Defaults to 888.

    Returns:
        np.ndarray: RGBA colors, one row per node, as produced by ``_transform_colors``.
    """
    per_domain_colors = _get_domain_colors(
        graph=graph, cmap=cmap, color=color, random_seed=random_seed
    )
    return _transform_colors(
        _get_composite_node_colors(graph=graph, domain_colors=per_domain_colors),
        graph.node_enrichment_sums,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
    )
108
+
109
+
110
def _get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> Dict[str, Any]:
    """Map every domain ID of ``graph`` to an RGBA color.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap used when ``color`` is None. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): Fixed color (or colors) for the domains.
            If None, the colormap will be used. Defaults to None.
        random_seed (int, optional): Seed forwarded to the color generator. Defaults to 888.

    Returns:
        dict: Domain ID -> RGBA color, pairing the keys of ``graph.domain_id_to_node_ids_map`` (in
        iteration order) with the colors generated by ``_get_colors``.
    """
    network = graph.network
    domain_map = graph.domain_id_to_node_ids_map
    palette = _get_colors(
        network,
        domain_map,
        cmap=cmap,
        color=color,
        random_seed=random_seed,
    )
    return {domain_id: rgba for domain_id, rgba in zip(domain_map.keys(), palette)}
137
+
138
+
139
def _get_composite_node_colors(graph: NetworkGraph, domain_colors: Dict[Any, Any]) -> np.ndarray:
    """Spread per-domain colors onto the individual nodes.

    Every node listed in ``graph.domain_id_to_node_ids_map`` receives the color of its domain.
    Nodes that belong to no domain keep ``[0, 0, 0, 0]``. When a node appears in several domains,
    the domain visited last (in the map's iteration order) wins; colors are overwritten, not blended.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        domain_colors (dict): Mapping from domain ID (the keys of ``graph.domain_id_to_node_ids_map``)
            to an RGBA color, as returned by ``_get_domain_colors``.

    Returns:
        np.ndarray: Array of shape ``(len(graph.node_coordinates), 4)`` holding one RGBA row per node.
    """
    # One RGBA row per node; node IDs are used directly as row indices
    # (assumes IDs are integers in [0, len(node_coordinates)) -- confirm against NetworkGraph).
    num_nodes = len(graph.node_coordinates)
    composite_colors = np.zeros((num_nodes, 4))
    for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
        color = domain_colors[domain_id]
        for node in nodes:
            composite_colors[node] = color

    return composite_colors
160
+
161
+
162
def _get_colors(
    network,
    domain_id_to_node_ids_map,
    cmap: str = "gist_rainbow",
    color: Union[str, list, tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> List[Tuple]:
    """Generate one RGBA color per domain.

    If ``color`` is given, every domain receives that color (a collection of colors is cycled).
    Otherwise colors are sampled from ``cmap`` at positions chosen from the domain centroids by
    ``_assign_distant_colors``. The whole palette is then shifted by a small seeded random offset.

    Args:
        network (NetworkX graph): The graph representing the network; nodes must carry "x" and "y" attributes.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, list, tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
            If None (or an empty string/sequence), the colormap will be used. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        List[Tuple] or np.ndarray: One RGBA color per domain, in the iteration order of
        ``domain_id_to_node_ids_map``. This is a list of colormap tuples when the colormap is used,
        otherwise the RGBA array produced by ``to_rgba``.
    """
    # Seed NumPy's global RNG so the palette shift below is reproducible.
    # NOTE: this mutates global RNG state; kept as-is for backward compatibility.
    np.random.seed(random_seed)
    num_colors_to_generate = len(domain_id_to_node_ids_map)

    # `if color:` raised "truth value of an array is ambiguous" for multi-element NumPy arrays,
    # which this function explicitly accepts. Check for None explicitly, and keep the old
    # behaviour of falling back to the colormap for empty strings/sequences.
    is_empty_color = (isinstance(color, (str, list, tuple)) and len(color) == 0) or (
        isinstance(color, np.ndarray) and color.size == 0
    )
    if color is not None and not is_empty_color:
        # Every domain gets the specified color(s)
        return to_rgba(color, num_repeats=num_colors_to_generate)

    colormap = matplotlib.colormaps.get_cmap(cmap)
    # Step 1: centroid of each domain from node x/y coordinates
    centroids = calculate_centroids(network, domain_id_to_node_ids_map)
    # Step 2: pairwise Euclidean distances between centroids
    centroid_array = np.array(centroids)
    dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
    # Step 3: spread colormap positions across domains by proximity ranking
    color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
    # Step 4: shift the whole palette slightly (relative spacing preserved), wrapping into [0, 1)
    global_shift = np.random.uniform(-0.1, 0.1)
    color_positions = (color_positions + global_shift) % 1
    # Step 5: defensive clip into the colormap's valid input range
    color_positions = np.clip(color_positions, 0, 1)

    # Step 6: sample the colormap
    return [colormap(pos) for pos in color_positions]
209
+
210
+
211
def _assign_distant_colors(dist_matrix, num_colors_to_generate):
    """Assign a distinct colormap position in [0, 1) to each domain.

    Domains are ranked by the sum of their distances to all other domains, in ascending order
    (most central first). Positions ``0, 1/n, 2/n, ...`` are then handed out in pairs: position
    ``2i/n`` goes to the i-th domain from the front of the ranking and ``(2i+1)/n`` to the i-th
    from the back. Adjacent spectrum positions therefore go to domains at opposite ends of the
    ranking. With an odd count, the middle-ranked domain takes the remaining slot ``(n-1)/n``, so
    every domain receives a different position.

    Args:
        dist_matrix (ndarray): Square matrix of pairwise centroid distances.
        num_colors_to_generate (int): Number of colors to generate (one per row of ``dist_matrix``).

    Returns:
        np.array: Array of color positions in the range [0, 1), one per domain.
    """
    n = num_colors_to_generate
    color_positions = np.zeros(n)
    # Rank domains by total distance to the others (ascending = most central first)
    proximity_order = sorted(range(n), key=lambda idx: np.sum(dist_matrix[idx]))

    # Pair the front and back of the ranking onto adjacent spectrum positions
    half_spectrum = n // 2
    for i in range(half_spectrum):
        color_positions[proximity_order[i]] = (i * 2) / n
        color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / n

    # For odd n, the pairing loop skips the middle-ranked domain. Give it the one unused slot.
    # Previously it kept (n // 2) / n, which duplicated another domain's position.
    if n % 2 == 1:
        color_positions[proximity_order[half_spectrum]] = (n - 1) / n

    return color_positions
238
+
239
+
240
def _transform_colors(
    colors: np.ndarray,
    enrichment_sums: np.ndarray,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
) -> np.ndarray:
    """Scale RGB intensities by normalized enrichment sums using power scaling.

    Pure-black RGB values (exact ``[0, 0, 0]``) are first replaced with very dark grey
    ``(0.1, 0.1, 0.1)`` to avoid issues with color scaling. Enrichment sums are divided by their
    maximum and raised to ``scale_factor``. They are then mapped linearly onto
    ``[min_scale, max_scale]`` and multiplied into the RGB channels. Alpha is left untouched.
    If every enrichment sum is zero, all nodes are scaled by ``min_scale`` rather than producing
    NaNs. The input array is not modified; a float copy is transformed and returned.

    Args:
        colors (np.ndarray): An array of RGBA colors, one row per entry.
        enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
        scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
            values more. Defaults to 1.0.

    Returns:
        np.ndarray: A new array of RGBA colors with adjusted intensities.
    """
    # Work on a float copy so the caller's array is never mutated in place
    colors = np.array(colors, dtype=float)
    enrichment_sums = np.asarray(enrichment_sums, dtype=float)

    # Legacy guard kept for output compatibility: when the bounds coincide, open a tiny gap.
    # (Nothing below divides by max_scale - min_scale.)
    if min_scale == max_scale:
        min_scale = max_scale - 10e-6

    # Replace exact black (#000000) with very dark grey (#1A1A1A)
    is_black = np.all(colors[:, :3] == 0.0, axis=1)
    colors[is_black, :3] = 0.1

    # Normalize the enrichment sums to [0, 1]. An all-zero input would otherwise give 0/0 = NaN
    # and poison every color, so it is treated as "no enrichment" instead.
    max_sum = np.max(enrichment_sums)
    if max_sum != 0:
        normalized_sums = enrichment_sums / max_sum
    else:
        normalized_sums = np.zeros_like(enrichment_sums)
    # Power scaling dims low values and emphasizes high ones
    scaled_sums = normalized_sums**scale_factor
    # Linearly map onto [min_scale, max_scale]
    scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
    # Scale the RGB channels only (alpha column untouched)
    colors[:, :3] *= np.reshape(scaled_sums, (-1, 1))

    return colors
283
+
284
+
285
def to_rgba(
    color: Union[str, list, tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Normalize one color, or a collection of colors, into a 2D array of RGBA rows.

    Args:
        color (str, list, tuple, np.ndarray): Either a single color (a matplotlib color string such as
            'red' or '#FF5733', or a flat RGB/RGBA sequence) or a collection of such colors (a list/tuple
            of colors, or a 2D array with one color per row).
        alpha (float, None, optional): When not None, replaces the alpha channel of every resulting
            color. Defaults to None.
        num_repeats (int, None, optional): When truthy, the output has exactly this many rows: a single
            color is tiled and a collection is cycled in order. Defaults to None.

    Returns:
        np.ndarray: Array of RGBA rows. Without ``num_repeats``, a single color yields one row and a
        collection yields one row per color.

    Raises:
        ValueError: If a color is neither a string nor a 3/4-element sequence, if matplotlib cannot
            parse it, or if ``color`` is not a str/list/tuple/ndarray.
    """

    def _one(value: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        # mcolors.to_rgba supplies alpha=1.0 when the input carries no alpha channel
        is_rgb_sequence = isinstance(value, (list, tuple, np.ndarray)) and len(value) in [3, 4]
        if not (isinstance(value, str) or is_rgb_sequence):
            raise ValueError(
                f"Invalid color format: {value}. Must be a valid string or RGB/RGBA sequence."
            )
        rgba = np.array(mcolors.to_rgba(value))
        if alpha is not None:
            rgba[3] = alpha
        return rgba

    # A 2D RGBA array is handled as a list of per-row colors
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(row) for row in color]

    container_types = (list, tuple, np.ndarray)
    if isinstance(color, str):
        is_single_color = True
    elif isinstance(color, container_types):
        # A flat sequence (no nested strings/sequences) is one RGB/RGBA color
        is_single_color = not any(
            isinstance(item, (str,) + container_types) for item in color
        )
    else:
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    if is_single_color:
        rgba_row = _one(color)
        if num_repeats:
            return np.tile(rgba_row, (num_repeats, 1))
        return np.array([rgba_row])

    rows = np.array([_one(item) for item in color])
    if not num_repeats:
        return rows
    # Cycle through the converted colors until num_repeats rows are produced
    return np.array([rows[k % len(rows)] for k in range(num_repeats)])
@@ -0,0 +1,53 @@
1
+ """
2
+ risk/network/plot/utils/layout
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import numpy as np
9
+
10
+
11
def calculate_bounding_box(
    node_coordinates: np.ndarray, radius_margin: float = 1.05
) -> Tuple[np.ndarray, float]:
    """Compute the center and padded half-extent of the axis-aligned box enclosing all nodes.

    Args:
        node_coordinates (np.ndarray): Node positions, one ``(x, y)`` row per node.
        radius_margin (float, optional): Multiplier applied to the half-extent. Defaults to 1.05.

    Returns:
        tuple: ``(center, radius)``, where ``center`` is the box midpoint as a length-2 array and
        ``radius`` is half the longer side of the box, multiplied by ``radius_margin``.
    """
    lower_x, lower_y = np.min(node_coordinates, axis=0)
    upper_x, upper_y = np.max(node_coordinates, axis=0)
    width = upper_x - lower_x
    height = upper_y - lower_y
    midpoint = np.array([(lower_x + upper_x) / 2, (lower_y + upper_y) / 2])
    half_side = max(width, height) / 2
    return midpoint, half_side * radius_margin
31
+
32
+
33
def calculate_centroids(network, domain_id_to_node_ids_map):
    """Compute the mean (x, y) position of each domain's nodes.

    Args:
        network (NetworkX graph): The graph representing the network; each node must carry "x" and "y" attributes.
        domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.

    Returns:
        List[Tuple[float, float]]: One ``(x, y)`` centroid per domain, in the mapping's iteration order.
    """
    node_attributes = network.nodes

    def _centroid(node_ids):
        positions = np.array(
            [[node_attributes[nid]["x"], node_attributes[nid]["y"]] for nid in node_ids]
        )
        return tuple(np.mean(positions, axis=0))

    return [_centroid(node_ids) for node_ids in domain_id_to_node_ids_map.values()]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.8b15
3
+ Version: 0.0.8b17
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -1,4 +1,4 @@
1
- risk/__init__.py,sha256=IvvDrNrYMIFlFEmfZ6J1gTBoHqc33MwADdxoLjnyZg4,113
1
+ risk/__init__.py,sha256=pF-sA0CxulV6zIC6axfgeXWkkQGyVZGuoGF_u4RgoQM,113
2
2
  risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
3
3
  risk/risk.py,sha256=slJXca_a726_D7oXwe765HaKTv3ZrOvhttyrWdCGPkA,21231
4
4
  risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
@@ -13,15 +13,16 @@ risk/neighborhoods/domains.py,sha256=D5MUIghbwyKKCAE8PN_HXvsO9NxLTGejQmyEqetD1Bk
13
13
  risk/neighborhoods/neighborhoods.py,sha256=M-wL4xB_BUTlSZg90swygO5NdrZ6hFUFqs6jsiZaqHk,18260
14
14
  risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
15
15
  risk/network/geometry.py,sha256=H1yGVVqgbfpzBzJwEheDLfvGLSA284jGQQTn612L4Vc,6759
16
- risk/network/graph.py,sha256=WxWh1BZvH9E_kgDb5HrwcQ-30hM8aNHPkBm2l3vzkAw,16973
16
+ risk/network/graph.py,sha256=x5cur1meitkR0YuE5vGxX0s_IFa5wkx8z44f_C1vK7U,6509
17
17
  risk/network/io.py,sha256=u0PPcKjp6Xze--7eDOlvalYkjQ9S2sjiC-ac2476PUI,22942
18
18
  risk/network/plot/__init__.py,sha256=MfmaXJgAZJgXZ2wrhK8pXwzETlcMaLChhWXKAozniAo,98
19
- risk/network/plot/canvas.py,sha256=LXHndwanWIBShChoPag8zgGHF2P9MFWYdEnLKc2eeb0,10295
20
- risk/network/plot/contour.py,sha256=YPG8Uz0VlJ4skLdGaTH_FmQN6A_ArK8XSTNo1LzkSws,14276
21
- risk/network/plot/labels.py,sha256=MQqM6iJJW91JI_dHzaFRbkrjK0yxJc5WBhTc_ea7Unk,43371
22
- risk/network/plot/network.py,sha256=t5eMh7mBJOh_Wa19aK8g_1zpL7maiXZAYw-TEFHxAVM,12819
23
- risk/network/plot/plotter.py,sha256=rQV4Db6Ud86FJm11uaBvgSuzpmGsrZxnsRnUKjg6w84,5572
24
- risk/network/plot/utils.py,sha256=jZgI8EysSjviQmdYAceZk2MwJXcdeFAkYp-odZNqV0k,6316
19
+ risk/network/plot/canvas.py,sha256=JnjPQaryRb_J6LP36BT2-rlsbJO3T4tTBornL8Oqqbs,10778
20
+ risk/network/plot/contour.py,sha256=8uwJ7K-Z6VMyr_uQ5VUyoQSqDHA7zDvR_nYAmLn60-I,14647
21
+ risk/network/plot/labels.py,sha256=Gt2HIbYsC5QZNZa_Mk29hijDVv0_V0pXEl1IILaNKkY,44016
22
+ risk/network/plot/network.py,sha256=9blVFeCp5x5XoGhPwOOdADegXC4gC72c2vrM2u4QPe0,13235
23
+ risk/network/plot/plotter.py,sha256=lN-_GDXRk9V3IFu8q7QmPjJGBZiP0QYwSvU6dVVDV2E,5770
24
+ risk/network/plot/utils/color.py,sha256=4W4EoQ_Fs4tmbngdczXnFkkAjvyYP5EV_P2Vu-TCCwY,15573
25
+ risk/network/plot/utils/layout.py,sha256=znssSqe2VZzzSz47hLZtTuXwMTpHR9b8lkQPL0BX7OA,1950
25
26
  risk/stats/__init__.py,sha256=WcgoETQ-hS0LQqKRsAMIPtP15xZ-4eul6VUBuUx4Wzc,220
26
27
  risk/stats/hypergeom.py,sha256=o6Qnj31gCAKxr2uQirXrbv7XvdDJGEq69MFW-ubx_hA,2272
27
28
  risk/stats/poisson.py,sha256=8x9hB4DCukq4gNIlIKO-c_jYG1-BTwTX53oLauFyfj8,1793
@@ -29,8 +30,8 @@ risk/stats/stats.py,sha256=kvShov-94W6ffgDUTb522vB9hDJQSyTsYif_UIaFfSM,7059
29
30
  risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
30
31
  risk/stats/permutation/permutation.py,sha256=D84Rcpt6iTQniK0PfQGcw9bLcHbMt9p-ARcurUnIXZQ,10095
31
32
  risk/stats/permutation/test_functions.py,sha256=lftOude6hee0pyR80HlBD32522JkDoN5hrKQ9VEbuoY,2345
32
- risk_network-0.0.8b15.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
33
- risk_network-0.0.8b15.dist-info/METADATA,sha256=4hN4nT8j1a52ghqyJQ0mhjaVuhw2mt3EgKMAjpYTagY,47498
34
- risk_network-0.0.8b15.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
35
- risk_network-0.0.8b15.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
36
- risk_network-0.0.8b15.dist-info/RECORD,,
33
+ risk_network-0.0.8b17.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
34
+ risk_network-0.0.8b17.dist-info/METADATA,sha256=HmkNnAzSs7IdEGzYirOsrzy0euWVd0Ezn4OztCtXhnE,47498
35
+ risk_network-0.0.8b17.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
36
+ risk_network-0.0.8b17.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
37
+ risk_network-0.0.8b17.dist-info/RECORD,,
@@ -1,153 +0,0 @@
1
- """
2
- risk/network/plot/utils
3
- ~~~~~~~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from typing import List, Tuple, Union
7
-
8
- import matplotlib.colors as mcolors
9
- import numpy as np
10
-
11
- from risk.network.graph import NetworkGraph
12
-
13
-
14
- def get_annotated_domain_colors(
15
- graph: NetworkGraph,
16
- cmap: str = "gist_rainbow",
17
- color: Union[str, None] = None,
18
- min_scale: float = 0.8,
19
- max_scale: float = 1.0,
20
- scale_factor: float = 1.0,
21
- random_seed: int = 888,
22
- ) -> np.ndarray:
23
- """Get colors for the domains based on node annotations, or use a specified color.
24
-
25
- Args:
26
- graph (NetworkGraph): The network data and attributes to be visualized.
27
- cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
28
- color (str or None, optional): Color to use for the domains. If None, the colormap will be used. Defaults to None.
29
- min_scale (float, optional): Minimum scale for color intensity when generating domain colors.
30
- Defaults to 0.8.
31
- max_scale (float, optional): Maximum scale for color intensity when generating domain colors.
32
- Defaults to 1.0.
33
- scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on
34
- enrichment. Higher values increase the contrast. Defaults to 1.0.
35
- random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.
36
-
37
- Returns:
38
- np.ndarray: Array of RGBA colors for each domain.
39
- """
40
- # Generate domain colors based on the enrichment data
41
- node_colors = graph.get_domain_colors(
42
- cmap=cmap,
43
- color=color,
44
- min_scale=min_scale,
45
- max_scale=max_scale,
46
- scale_factor=scale_factor,
47
- random_seed=random_seed,
48
- )
49
- annotated_colors = []
50
- for _, node_ids in graph.domain_id_to_node_ids_map.items():
51
- if len(node_ids) > 1:
52
- # For multi-node domains, choose the brightest color based on RGB sum
53
- domain_colors = np.array([node_colors[node] for node in node_ids])
54
- brightest_color = domain_colors[
55
- np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
56
- ]
57
- annotated_colors.append(brightest_color)
58
- else:
59
- # Single-node domains default to white (RGBA)
60
- default_color = np.array([1.0, 1.0, 1.0, 1.0])
61
- annotated_colors.append(default_color)
62
-
63
- return np.array(annotated_colors)
64
-
65
-
66
- def calculate_bounding_box(
67
- node_coordinates: np.ndarray, radius_margin: float = 1.05
68
- ) -> Tuple[np.ndarray, float]:
69
- """Calculate the bounding box of the network based on node coordinates.
70
-
71
- Args:
72
- node_coordinates (np.ndarray): Array of node coordinates (x, y).
73
- radius_margin (float, optional): Margin factor to apply to the bounding box radius. Defaults to 1.05.
74
-
75
- Returns:
76
- tuple: Center of the bounding box and the radius (adjusted by the radius margin).
77
- """
78
- # Find minimum and maximum x, y coordinates
79
- x_min, y_min = np.min(node_coordinates, axis=0)
80
- x_max, y_max = np.max(node_coordinates, axis=0)
81
- # Calculate the center of the bounding box
82
- center = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2])
83
- # Calculate the radius of the bounding box, adjusted by the margin
84
- radius = max(x_max - x_min, y_max - y_min) / 2 * radius_margin
85
- return center, radius
86
-
87
-
88
- def to_rgba(
89
- color: Union[str, List, Tuple, np.ndarray],
90
- alpha: Union[float, None] = None,
91
- num_repeats: Union[int, None] = None,
92
- ) -> np.ndarray:
93
- """Convert color(s) to RGBA format, applying alpha and repeating as needed.
94
-
95
- Args:
96
- color (Union[str, list, tuple, np.ndarray]): The color(s) to convert. Can be a string, list, tuple, or np.ndarray.
97
- alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values
98
- found in color.
99
- num_repeats (int, None, optional): If provided, the color(s) will be repeated this many times. Defaults to None.
100
-
101
- Returns:
102
- np.ndarray: Array of RGBA colors repeated `num_repeats` times, if applicable.
103
- """
104
-
105
- def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
106
- """Convert a single color to RGBA format, handling strings, hex, and RGB/RGBA lists."""
107
- # Note: if no alpha is provided, the default alpha value is 1.0 by mcolors.to_rgba
108
- if isinstance(c, str):
109
- # Convert color names or hex values (e.g., 'red', '#FF5733') to RGBA
110
- rgba = np.array(mcolors.to_rgba(c))
111
- elif isinstance(c, (list, tuple, np.ndarray)) and len(c) in [3, 4]:
112
- # Convert RGB (3) or RGBA (4) values to RGBA format
113
- rgba = np.array(mcolors.to_rgba(c))
114
- else:
115
- raise ValueError(
116
- f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
117
- )
118
-
119
- if alpha is not None: # Override alpha if provided
120
- rgba[3] = alpha
121
- return rgba
122
-
123
- # If color is a 2D array of RGBA values, convert it to a list of lists
124
- if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
125
- color = [list(c) for c in color]
126
-
127
- # Handle a single color (string or RGB/RGBA list/tuple)
128
- if isinstance(color, (str, list, tuple)) and not any(
129
- isinstance(c, (list, tuple, np.ndarray)) for c in color
130
- ):
131
- rgba_color = convert_to_rgba(color)
132
- if num_repeats:
133
- return np.tile(
134
- rgba_color, (num_repeats, 1)
135
- ) # Repeat the color if num_repeats is provided
136
- return np.array([rgba_color]) # Return a single color wrapped in a numpy array
137
-
138
- # Handle a list/array of colors
139
- elif isinstance(color, (list, tuple, np.ndarray)):
140
- rgba_colors = np.array(
141
- [convert_to_rgba(c) for c in color]
142
- ) # Convert each color in the list to RGBA
143
- # Handle repetition if num_repeats is provided
144
- if num_repeats:
145
- repeated_colors = np.array(
146
- [rgba_colors[i % len(rgba_colors)] for i in range(num_repeats)]
147
- )
148
- return repeated_colors
149
-
150
- return rgba_colors
151
-
152
- else:
153
- raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")