risk-network 0.0.9b23__py3-none-any.whl → 0.0.9b25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. risk/__init__.py +1 -1
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +9 -9
  4. risk/annotations/io.py +0 -2
  5. risk/log/__init__.py +2 -2
  6. risk/neighborhoods/__init__.py +3 -5
  7. risk/neighborhoods/api.py +446 -0
  8. risk/neighborhoods/community.py +4 -2
  9. risk/neighborhoods/domains.py +28 -1
  10. risk/network/__init__.py +1 -3
  11. risk/network/graph/__init__.py +1 -1
  12. risk/network/graph/api.py +194 -0
  13. risk/network/graph/summary.py +6 -2
  14. risk/network/io.py +0 -2
  15. risk/network/plotter/__init__.py +6 -0
  16. risk/network/plotter/api.py +54 -0
  17. risk/network/{plot → plotter}/canvas.py +3 -3
  18. risk/network/{plot → plotter}/contour.py +2 -2
  19. risk/network/{plot → plotter}/labels.py +3 -3
  20. risk/network/{plot → plotter}/network.py +136 -3
  21. risk/network/{plot → plotter}/utils/colors.py +15 -6
  22. risk/risk.py +10 -483
  23. risk/stats/__init__.py +8 -4
  24. risk/stats/binom.py +51 -0
  25. risk/stats/chi2.py +69 -0
  26. risk/stats/hypergeom.py +27 -17
  27. risk/stats/permutation/__init__.py +1 -1
  28. risk/stats/permutation/permutation.py +44 -55
  29. risk/stats/permutation/test_functions.py +25 -17
  30. risk/stats/poisson.py +15 -9
  31. risk/stats/zscore.py +68 -0
  32. {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/METADATA +1 -1
  33. risk_network-0.0.9b25.dist-info/RECORD +44 -0
  34. risk/network/plot/__init__.py +0 -6
  35. risk/network/plot/plotter.py +0 -143
  36. risk_network-0.0.9b23.dist-info/RECORD +0 -39
  37. risk/network/{plot → plotter}/utils/layout.py +0 -0
  38. {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/LICENSE +0 -0
  39. {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/WHEEL +0 -0
  40. {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
1
+ """
2
+ risk/network/graph/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ import copy
7
+ from typing import Any, Dict
8
+
9
+ import networkx as nx
10
+ import pandas as pd
11
+
12
+ from risk.annotations import define_top_annotations
13
+ from risk.log import logger, log_header, params
14
+ from risk.neighborhoods import (
15
+ define_domains,
16
+ process_neighborhoods,
17
+ trim_domains,
18
+ )
19
+ from risk.network.graph.network import NetworkGraph
20
+ from risk.stats import calculate_significance_matrices
21
+
22
+
23
class GraphAPI:
    """Handles the loading of network graphs and associated data.

    The GraphAPI class provides methods to load and process network graphs, annotations, and neighborhoods.
    """

    def __init__(self) -> None:
        """Initialize the GraphAPI.

        The class holds no state of its own. ``self`` must still be accepted: without it,
        ``GraphAPI()`` (or a ``super().__init__()`` chain that reaches this class) raises
        ``TypeError: __init__() takes 0 positional arguments but 1 was given``.
        """

    def load_graph(
        self,
        network: nx.Graph,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        tail: str = "right",
        pval_cutoff: float = 0.01,
        fdr_cutoff: float = 0.9999,
        impute_depth: int = 0,
        prune_threshold: float = 0.0,
        linkage_criterion: str = "distance",
        linkage_method: str = "average",
        linkage_metric: str = "yule",
        min_cluster_size: int = 5,
        max_cluster_size: int = 1000,
    ) -> NetworkGraph:
        """Load and process the network graph, defining top annotations and domains.

        Pipeline: compute significance matrices from the neighborhood p-values, impute/prune
        the significant neighborhoods, select top annotations, cluster them into domains,
        and trim domains by cluster size.

        Args:
            network (nx.Graph): The network graph. It is deep-copied first; only the copy is
                processed and stored on the returned object.
            annotations (Dict[str, Any]): The annotations associated with the network. Must
                provide the "ordered_nodes" and "ordered_annotations" keys.
            neighborhoods (Dict[str, Any]): Neighborhood significance data. Must provide the
                "depletion_pvals" and "enrichment_pvals" keys.
            tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
            pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
            fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
            impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
            prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
            linkage_criterion (str, optional): Clustering criterion for defining domains. Defaults to "distance".
            linkage_method (str, optional): Clustering method to use. Defaults to "average".
            linkage_metric (str, optional): Metric to use for calculating distances. Defaults to "yule".
            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.

        Returns:
            NetworkGraph: A fully initialized and processed NetworkGraph object.
        """
        # Log the parameters and display headers
        log_header("Finding significant neighborhoods")
        params.log_graph(
            tail=tail,
            pval_cutoff=pval_cutoff,
            fdr_cutoff=fdr_cutoff,
            impute_depth=impute_depth,
            prune_threshold=prune_threshold,
            linkage_criterion=linkage_criterion,
            linkage_method=linkage_method,
            linkage_metric=linkage_metric,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )

        # Work on a copy so the caller's graph object is never handed to downstream steps
        network = copy.deepcopy(network)

        logger.debug(f"p-value cutoff: {pval_cutoff}")
        logger.debug(f"FDR BH cutoff: {fdr_cutoff}")
        tail_description = (
            "enrichment" if tail == "right" else "depletion" if tail == "left" else "both"
        )
        logger.debug(f"Significance tail: '{tail}' ({tail_description})")

        # Calculate significant neighborhoods based on the provided parameters
        significant_neighborhoods = calculate_significance_matrices(
            neighborhoods["depletion_pvals"],
            neighborhoods["enrichment_pvals"],
            tail=tail,
            pval_cutoff=pval_cutoff,
            fdr_cutoff=fdr_cutoff,
        )

        log_header("Processing neighborhoods")
        # Process neighborhoods by imputing and pruning based on the given settings
        processed_neighborhoods = process_neighborhoods(
            network=network,
            neighborhoods=significant_neighborhoods,
            impute_depth=impute_depth,
            prune_threshold=prune_threshold,
        )

        log_header("Finding top annotations")
        logger.debug(f"Min cluster size: {min_cluster_size}")
        logger.debug(f"Max cluster size: {max_cluster_size}")
        # Define top annotations based on processed neighborhoods
        top_annotations = self._define_top_annotations(
            network=network,
            annotations=annotations,
            neighborhoods=processed_neighborhoods,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )

        log_header("Optimizing distance threshold for domains")
        # Extract the significant significance matrix from the processed neighborhoods data
        significant_neighborhoods_significance = processed_neighborhoods[
            "significant_significance_matrix"
        ]
        # Define domains in the network using the specified clustering settings
        domains = define_domains(
            top_annotations=top_annotations,
            significant_neighborhoods_significance=significant_neighborhoods_significance,
            linkage_criterion=linkage_criterion,
            linkage_method=linkage_method,
            linkage_metric=linkage_metric,
        )
        # Trim domains and top annotations based on cluster size constraints
        domains, trimmed_domains = trim_domains(
            domains=domains,
            top_annotations=top_annotations,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )

        # Map each node label to its position in the annotation node ordering
        ordered_nodes = annotations["ordered_nodes"]
        node_label_to_id = {label: idx for idx, label in enumerate(ordered_nodes)}
        node_significance_sums = processed_neighborhoods["node_significance_sums"]

        # NOTE: the original (unprocessed) ``neighborhoods`` dict is what gets stored here,
        # not ``processed_neighborhoods``.
        return NetworkGraph(
            network=network,
            annotations=annotations,
            neighborhoods=neighborhoods,
            domains=domains,
            trimmed_domains=trimmed_domains,
            node_label_to_node_id_map=node_label_to_id,
            node_significance_sums=node_significance_sums,
        )

    def _define_top_annotations(
        self,
        network: nx.Graph,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        min_cluster_size: int = 5,
        max_cluster_size: int = 1000,
    ) -> pd.DataFrame:
        """Define top annotations for the network.

        Args:
            network (nx.Graph): The network graph.
            annotations (Dict[str, Any]): Annotations data for the network. Must provide the
                "ordered_annotations" key.
            neighborhoods (Dict[str, Any]): Processed neighborhood significance data. Must provide
                the "neighborhood_significance_counts", "significant_significance_matrix", and
                "significant_binary_significance_matrix" keys.
            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.

        Returns:
            pd.DataFrame: Top annotations identified within the network.
        """
        # Extract necessary data from annotations and neighborhoods
        ordered_annotations = annotations["ordered_annotations"]
        # NOTE(review): read from the "..._counts" key but passed on as "sums"; confirm the
        # naming mismatch is intentional in process_neighborhoods.
        neighborhood_significance_sums = neighborhoods["neighborhood_significance_counts"]
        significant_significance_matrix = neighborhoods["significant_significance_matrix"]
        significant_binary_significance_matrix = neighborhoods[
            "significant_binary_significance_matrix"
        ]
        # Delegate to the module-level implementation
        return define_top_annotations(
            network=network,
            ordered_annotation_labels=ordered_annotations,
            neighborhood_significance_sums=neighborhood_significance_sums,
            significant_significance_matrix=significant_significance_matrix,
            significant_binary_significance_matrix=significant_binary_significance_matrix,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )
@@ -240,8 +240,12 @@ class AnalysisSummary:
240
240
  except ValueError:
241
241
  return "" # Description not found
242
242
 
243
- # Get nodes present for the annotation and sort by node label
244
- nodes_present = np.where(self.annotations["matrix"][:, annotation_idx] == 1)[0]
243
+ # Get the column (safely) from the sparse matrix
244
+ column = self.annotations["matrix"][:, annotation_idx]
245
+ # Convert the column to a dense array if needed
246
+ column = column.toarray().ravel() # Convert to a 1D dense array
247
+ # Get nodes present for the annotation and sort by node label - use np.where on the dense array
248
+ nodes_present = np.where(column == 1)[0]
245
249
  node_labels = sorted(
246
250
  self.graph.node_id_to_node_label_map[node_id]
247
251
  for node_id in nodes_present
risk/network/io.py CHANGED
@@ -1,8 +1,6 @@
1
1
  """
2
2
  risk/network/io
3
3
  ~~~~~~~~~~~~~~~
4
-
5
- This file contains the code for the RISK class and command-line access.
6
4
  """
7
5
 
8
6
  import copy
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network/plot
3
+ ~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from risk.network.plotter.api import PlotterAPI
@@ -0,0 +1,54 @@
1
+ """
2
+ risk/network/graph/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import List, Tuple, Union
7
+
8
+ import numpy as np
9
+
10
+ from risk.log import log_header
11
+ from risk.network.graph.network import NetworkGraph
12
+ from risk.network.plotter.network import NetworkPlotter
13
+
14
+
15
class PlotterAPI:
    """Handles the loading of network plotter objects.

    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
    """

    def __init__(self) -> None:
        """Initialize the PlotterAPI.

        The class holds no state of its own. ``self`` must still be accepted: without it,
        ``PlotterAPI()`` (or a ``super().__init__()`` chain that reaches this class) raises
        ``TypeError: __init__() takes 0 positional arguments but 1 was given``.
        """

    def load_plotter(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: Union[str, List, Tuple, np.ndarray] = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> NetworkPlotter:
        """Get a NetworkPlotter object for plotting.

        Args:
            graph (NetworkGraph): The graph to plot.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure in inches
                (width, height). Defaults to (10, 10).
            background_color (str, List, Tuple, or np.ndarray, optional): Background color of
                the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
                any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.

        Returns:
            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
        """
        log_header("Loading plotter")

        # Initialize and return a NetworkPlotter object
        return NetworkPlotter(
            graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
@@ -9,9 +9,9 @@ import matplotlib.pyplot as plt
9
9
  import numpy as np
10
10
 
11
11
  from risk.log import params
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils.colors import to_rgba
14
- from risk.network.plot.utils.layout import calculate_bounding_box
12
+ from risk.network.graph.network import NetworkGraph
13
+ from risk.network.plotter.utils.colors import to_rgba
14
+ from risk.network.plotter.utils.layout import calculate_bounding_box
15
15
 
16
16
 
17
17
  class Canvas:
@@ -12,8 +12,8 @@ from scipy.ndimage import label
12
12
  from scipy.stats import gaussian_kde
13
13
 
14
14
  from risk.log import params, logger
15
- from risk.network.graph import NetworkGraph
16
- from risk.network.plot.utils.colors import get_annotated_domain_colors, to_rgba
15
+ from risk.network.graph.network import NetworkGraph
16
+ from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
17
17
 
18
18
 
19
19
  class Contour:
@@ -11,9 +11,9 @@ import numpy as np
11
11
  import pandas as pd
12
12
 
13
13
  from risk.log import params
14
- from risk.network.graph import NetworkGraph
15
- from risk.network.plot.utils.colors import get_annotated_domain_colors, to_rgba
16
- from risk.network.plot.utils.layout import calculate_bounding_box
14
+ from risk.network.graph.network import NetworkGraph
15
+ from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
16
+ from risk.network.plotter.utils.layout import calculate_bounding_box
17
17
 
18
18
  TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
19
19
 
@@ -5,16 +5,24 @@ risk/network/plot/network
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
7
7
 
8
+ import matplotlib.pyplot as plt
8
9
  import networkx as nx
9
10
  import numpy as np
10
11
 
11
12
  from risk.log import params
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils.colors import get_domain_colors, to_rgba
13
+ from risk.network.graph.network import NetworkGraph
14
+ from risk.network.plotter.canvas import Canvas
15
+ from risk.network.plotter.contour import Contour
16
+ from risk.network.plotter.labels import Labels
17
+ from risk.network.plotter.utils.colors import get_domain_colors, to_rgba
18
+ from risk.network.plotter.utils.layout import calculate_bounding_box
14
19
 
15
20
 
16
21
  class Network:
17
- """Class for plotting nodes and edges in a network graph."""
22
+ """A class for plotting network graphs with customizable options.
23
+
24
+ The Network class provides methods to plot network graphs with flexible node and edge properties.
25
+ """
18
26
 
19
27
  def __init__(self, graph: NetworkGraph, ax: Any = None) -> None:
20
28
  """Initialize the NetworkPlotter class.
@@ -289,3 +297,128 @@ class Network:
289
297
  node_sizes[node] = significant_size
290
298
 
291
299
  return node_sizes
300
+
301
+
302
class NetworkPlotter(Canvas, Network, Contour, Labels):
    """A class for visualizing network graphs with customizable options.

    The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
    flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
    perimeter, and adjusting background colors.
    """

    def __init__(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: Union[str, List, Tuple, np.ndarray] = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> None:
        """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.

        Args:
            graph (NetworkGraph): The network data and attributes to be visualized.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure in inches (width, height).
                Defaults to (10, 10).
            background_color (str, List, Tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
                any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
        """
        self.graph = graph
        # Create the figure/axis before the mixin initializers run, since they receive the axis
        self.ax = self._initialize_plot(
            graph=graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
        # NOTE(review): super() resolves along the MRO starting at Canvas; whether the Network,
        # Contour and Labels initializers also run depends on each calling super().__init__.
        super().__init__(graph=graph, ax=self.ax)

    def _initialize_plot(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray],
        background_color: Union[str, List, Tuple, np.ndarray],
        background_alpha: Union[float, None],
        pad: float,
    ) -> plt.Axes:
        """Set up the plot with figure size and background color.

        Args:
            graph (NetworkGraph): The network data and attributes to be visualized.
            figsize (List, Tuple, or np.ndarray): Size of the figure in inches (width, height).
            background_color (str, List, Tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
            background_alpha (float, None): Transparency level of the background color. If provided, it overrides any existing
                alpha values found in `background_color`.
            pad (float): Padding added on every side of the bounding box when setting the axis limits.

        Returns:
            plt.Axes: The axis object for the plot.
        """
        # Log the plotter settings
        params.log_plotter(
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )

        # Calculate the center and radius of the bounding box around the network's node coordinates
        center, radius = calculate_bounding_box(graph.node_coordinates)

        # Create a new figure and axis for plotting
        fig, ax = plt.subplots(figsize=figsize)
        fig.tight_layout()
        # Square axis limits around the bounding box, widened by `pad` on each side
        ax.set_xlim([center[0] - radius - pad, center[0] + radius + pad])
        ax.set_ylim([center[1] - radius - pad, center[1] + radius + pad])
        ax.set_aspect("equal")

        # Set the figure background; num_repeats=1 because a single color is needed
        fig.patch.set_facecolor(
            to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
        )
        ax.invert_yaxis()  # Invert the y-axis to match typical image coordinates
        # Remove axis spines for a cleaner look
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Hide axis ticks and labels, and the axis background so the figure color shows through
        ax.set_xticks([])
        ax.set_yticks([])
        ax.patch.set_visible(False)

        return ax

    @staticmethod
    def savefig(*args, pad_inches: float = 0.5, dpi: int = 100, **kwargs) -> None:
        """Save the current plot to a file with additional export options.

        Args:
            *args: Positional arguments passed to `plt.savefig` (typically the output filename).
            pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
            dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
            **kwargs: Keyword arguments passed to `plt.savefig`, such as `format`. `bbox_inches`
                defaults to "tight" unless supplied here.
        """
        # `dpi` and `pad_inches` are bound to named parameters, so they can never appear in
        # `kwargs`; only `bbox_inches` needs a caller-overridable default.
        kwargs.setdefault("bbox_inches", "tight")
        plt.savefig(*args, dpi=dpi, pad_inches=pad_inches, **kwargs)

    @staticmethod
    def show(*args, **kwargs) -> None:
        """Display the current plot.

        Args:
            *args: Positional arguments passed to `plt.show`.
            **kwargs: Keyword arguments passed to `plt.show`.
        """
        plt.show(*args, **kwargs)
@@ -9,7 +9,7 @@ import matplotlib
9
9
  import matplotlib.colors as mcolors
10
10
  import numpy as np
11
11
 
12
- from risk.network.graph import NetworkGraph
12
+ from risk.network.graph.network import NetworkGraph
13
13
 
14
14
 
15
15
  def get_annotated_domain_colors(
@@ -68,7 +68,7 @@ def get_annotated_domain_colors(
68
68
 
69
69
  annotated_colors.append(color)
70
70
 
71
- return np.array(annotated_colors)
71
+ return annotated_colors
72
72
 
73
73
 
74
74
  def get_domain_colors(
@@ -311,19 +311,28 @@ def _transform_colors(
311
311
  if min_scale == max_scale:
312
312
  min_scale = max_scale - 10e-6 # Avoid division by zero
313
313
 
314
+ # Replace invalid values in colors early
315
+ colors = np.nan_to_num(colors, nan=0.0) # Replace NaN with black
314
316
  # Replace black colors (#000000) with very dark grey (#1A1A1A)
315
317
  black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
316
318
  dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
317
- # Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
318
319
  is_black = np.all(colors[:, :3] == black_color, axis=1)
319
320
  colors[is_black, :3] = dark_grey
320
321
 
321
- # Normalize the significance sums to the range [0, 1]
322
- normalized_sums = significance_sums / np.max(significance_sums)
323
- # Apply power scaling to dim lower values and emphasize higher values
322
+ # Handle invalid or zero significance sums
323
+ max_significance = np.max(significance_sums)
324
+ if max_significance == 0:
325
+ max_significance = 1 # Avoid division by zero
326
+ normalized_sums = significance_sums / max_significance
327
+ # Replace NaN values in normalized sums
328
+ normalized_sums = np.nan_to_num(normalized_sums, nan=0.0)
329
+
330
+ # Apply power scaling to emphasize higher significance values
324
331
  scaled_sums = normalized_sums**scale_factor
325
332
  # Linearly scale the normalized sums to the range [min_scale, max_scale]
326
333
  scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
334
+ # Replace NaN or invalid scaled sums
335
+ scaled_sums = np.nan_to_num(scaled_sums, nan=min_scale)
327
336
  # Adjust RGB values based on scaled sums
328
337
  for i in range(3): # Only adjust RGB values
329
338
  colors[:, i] = scaled_sums * colors[:, i]