risk-network 0.0.9b23__py3-none-any.whl → 0.0.9b25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/__init__.py +2 -2
- risk/annotations/annotations.py +9 -9
- risk/annotations/io.py +0 -2
- risk/log/__init__.py +2 -2
- risk/neighborhoods/__init__.py +3 -5
- risk/neighborhoods/api.py +446 -0
- risk/neighborhoods/community.py +4 -2
- risk/neighborhoods/domains.py +28 -1
- risk/network/__init__.py +1 -3
- risk/network/graph/__init__.py +1 -1
- risk/network/graph/api.py +194 -0
- risk/network/graph/summary.py +6 -2
- risk/network/io.py +0 -2
- risk/network/plotter/__init__.py +6 -0
- risk/network/plotter/api.py +54 -0
- risk/network/{plot → plotter}/canvas.py +3 -3
- risk/network/{plot → plotter}/contour.py +2 -2
- risk/network/{plot → plotter}/labels.py +3 -3
- risk/network/{plot → plotter}/network.py +136 -3
- risk/network/{plot → plotter}/utils/colors.py +15 -6
- risk/risk.py +10 -483
- risk/stats/__init__.py +8 -4
- risk/stats/binom.py +51 -0
- risk/stats/chi2.py +69 -0
- risk/stats/hypergeom.py +27 -17
- risk/stats/permutation/__init__.py +1 -1
- risk/stats/permutation/permutation.py +44 -55
- risk/stats/permutation/test_functions.py +25 -17
- risk/stats/poisson.py +15 -9
- risk/stats/zscore.py +68 -0
- {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/METADATA +1 -1
- risk_network-0.0.9b25.dist-info/RECORD +44 -0
- risk/network/plot/__init__.py +0 -6
- risk/network/plot/plotter.py +0 -143
- risk_network-0.0.9b23.dist-info/RECORD +0 -39
- /risk/network/{plot → plotter}/utils/layout.py +0 -0
- {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/LICENSE +0 -0
- {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/WHEEL +0 -0
- {risk_network-0.0.9b23.dist-info → risk_network-0.0.9b25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/graph/api
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
import copy
|
7
|
+
from typing import Any, Dict
|
8
|
+
|
9
|
+
import networkx as nx
|
10
|
+
import pandas as pd
|
11
|
+
|
12
|
+
from risk.annotations import define_top_annotations
|
13
|
+
from risk.log import logger, log_header, params
|
14
|
+
from risk.neighborhoods import (
|
15
|
+
define_domains,
|
16
|
+
process_neighborhoods,
|
17
|
+
trim_domains,
|
18
|
+
)
|
19
|
+
from risk.network.graph.network import NetworkGraph
|
20
|
+
from risk.stats import calculate_significance_matrices
|
21
|
+
|
22
|
+
|
23
|
+
class GraphAPI:
    """Handles the loading of network graphs and associated data.

    The GraphAPI class provides methods to load and process network graphs, annotations, and neighborhoods.
    """

    def __init__(self) -> None:
        # `self` is required: without it, instantiating the class (`GraphAPI()`) raises
        # "TypeError: __init__() takes 0 positional arguments but 1 was given".
        pass

    def load_graph(
        self,
        network: nx.Graph,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        tail: str = "right",
        pval_cutoff: float = 0.01,
        fdr_cutoff: float = 0.9999,
        impute_depth: int = 0,
        prune_threshold: float = 0.0,
        linkage_criterion: str = "distance",
        linkage_method: str = "average",
        linkage_metric: str = "yule",
        min_cluster_size: int = 5,
        max_cluster_size: int = 1000,
    ) -> NetworkGraph:
        """Load and process the network graph, defining top annotations and domains.

        The pipeline runs in this order:
        1. Compute significant neighborhoods from the p-value matrices.
        2. Impute and prune neighborhoods.
        3. Define top annotations.
        4. Cluster them into domains.
        5. Trim domains by cluster size.

        The input ``network`` is deep-copied first, so the caller's graph object is not modified.

        Args:
            network (nx.Graph): The network graph.
            annotations (Dict[str, Any]): The annotations associated with the network. Must provide
                the "ordered_nodes" and "ordered_annotations" keys.
            neighborhoods (Dict[str, Any]): Neighborhood significance data. Must provide the
                "depletion_pvals" and "enrichment_pvals" keys.
            tail (str, optional): Type of significance tail ("right", "left", "both"). Defaults to "right".
            pval_cutoff (float, optional): p-value cutoff for significance. Defaults to 0.01.
            fdr_cutoff (float, optional): FDR cutoff for significance. Defaults to 0.9999.
            impute_depth (int, optional): Depth for imputing neighbors. Defaults to 0.
            prune_threshold (float, optional): Distance threshold for pruning neighbors. Defaults to 0.0.
            linkage_criterion (str, optional): Clustering criterion for defining domains. Defaults to "distance".
            linkage_method (str, optional): Clustering method to use. Defaults to "average".
            linkage_metric (str, optional): Metric to use for calculating distances. Defaults to "yule".
            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.

        Returns:
            NetworkGraph: A fully initialized and processed NetworkGraph object.
        """
        # Log the parameters and display headers
        log_header("Finding significant neighborhoods")
        params.log_graph(
            tail=tail,
            pval_cutoff=pval_cutoff,
            fdr_cutoff=fdr_cutoff,
            impute_depth=impute_depth,
            prune_threshold=prune_threshold,
            linkage_criterion=linkage_criterion,
            linkage_method=linkage_method,
            linkage_metric=linkage_metric,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )

        # Make a copy of the network to avoid modifying the original
        network = copy.deepcopy(network)

        logger.debug(f"p-value cutoff: {pval_cutoff}")
        logger.debug(f"FDR BH cutoff: {fdr_cutoff}")
        logger.debug(
            f"Significance tail: '{tail}' ({'enrichment' if tail == 'right' else 'depletion' if tail == 'left' else 'both'})"
        )
        # Calculate significant neighborhoods based on the provided parameters
        significant_neighborhoods = calculate_significance_matrices(
            neighborhoods["depletion_pvals"],
            neighborhoods["enrichment_pvals"],
            tail=tail,
            pval_cutoff=pval_cutoff,
            fdr_cutoff=fdr_cutoff,
        )

        log_header("Processing neighborhoods")
        # Process neighborhoods by imputing and pruning based on the given settings
        processed_neighborhoods = process_neighborhoods(
            network=network,
            neighborhoods=significant_neighborhoods,
            impute_depth=impute_depth,
            prune_threshold=prune_threshold,
        )

        log_header("Finding top annotations")
        logger.debug(f"Min cluster size: {min_cluster_size}")
        logger.debug(f"Max cluster size: {max_cluster_size}")
        # Define top annotations based on processed neighborhoods
        top_annotations = self._define_top_annotations(
            network=network,
            annotations=annotations,
            neighborhoods=processed_neighborhoods,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )

        log_header("Optimizing distance threshold for domains")
        # Extract the significant significance matrix from the neighborhoods data
        significant_neighborhoods_significance = processed_neighborhoods[
            "significant_significance_matrix"
        ]
        # Define domains in the network using the specified clustering settings
        domains = define_domains(
            top_annotations=top_annotations,
            significant_neighborhoods_significance=significant_neighborhoods_significance,
            linkage_criterion=linkage_criterion,
            linkage_method=linkage_method,
            linkage_metric=linkage_metric,
        )
        # Trim domains and top annotations based on cluster size constraints
        domains, trimmed_domains = trim_domains(
            domains=domains,
            top_annotations=top_annotations,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )

        # Map each node label to its position in the annotation's node ordering
        # (duplicate labels keep their last index, same as dict(zip(...)))
        ordered_nodes = annotations["ordered_nodes"]
        node_label_to_id = {label: idx for idx, label in enumerate(ordered_nodes)}
        node_significance_sums = processed_neighborhoods["node_significance_sums"]

        # Return the fully initialized NetworkGraph object
        return NetworkGraph(
            network=network,
            annotations=annotations,
            neighborhoods=neighborhoods,
            domains=domains,
            trimmed_domains=trimmed_domains,
            node_label_to_node_id_map=node_label_to_id,
            node_significance_sums=node_significance_sums,
        )

    def _define_top_annotations(
        self,
        network: nx.Graph,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        min_cluster_size: int = 5,
        max_cluster_size: int = 1000,
    ) -> pd.DataFrame:
        """Define top annotations for the network.

        Args:
            network (nx.Graph): The network graph.
            annotations (Dict[str, Any]): Annotations data for the network. Must provide the
                "ordered_annotations" key.
            neighborhoods (Dict[str, Any]): Processed neighborhood significance data. Must provide the
                "neighborhood_significance_counts", "significant_significance_matrix" and
                "significant_binary_significance_matrix" keys.
            min_cluster_size (int, optional): Minimum size for clusters. Defaults to 5.
            max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.

        Returns:
            pd.DataFrame: Top annotations identified within the network.
        """
        # Extract necessary data from annotations and neighborhoods
        ordered_annotations = annotations["ordered_annotations"]
        # NOTE(review): the "counts" entry is passed as the "sums" argument below; confirm naming upstream
        neighborhood_significance_sums = neighborhoods["neighborhood_significance_counts"]
        significant_significance_matrix = neighborhoods["significant_significance_matrix"]
        significant_binary_significance_matrix = neighborhoods[
            "significant_binary_significance_matrix"
        ]
        # Call external function to define top annotations
        return define_top_annotations(
            network=network,
            ordered_annotation_labels=ordered_annotations,
            neighborhood_significance_sums=neighborhood_significance_sums,
            significant_significance_matrix=significant_significance_matrix,
            significant_binary_significance_matrix=significant_binary_significance_matrix,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        )
|
risk/network/graph/summary.py
CHANGED
@@ -240,8 +240,12 @@ class AnalysisSummary:
|
|
240
240
|
except ValueError:
|
241
241
|
return "" # Description not found
|
242
242
|
|
243
|
-
# Get
|
244
|
-
|
243
|
+
# Get the column (safely) from the sparse matrix
|
244
|
+
column = self.annotations["matrix"][:, annotation_idx]
|
245
|
+
# Convert the column to a dense array if needed
|
246
|
+
column = column.toarray().ravel() # Convert to a 1D dense array
|
247
|
+
# Get nodes present for the annotation and sort by node label - use np.where on the dense array
|
248
|
+
nodes_present = np.where(column == 1)[0]
|
245
249
|
node_labels = sorted(
|
246
250
|
self.graph.node_id_to_node_label_map[node_id]
|
247
251
|
for node_id in nodes_present
|
risk/network/io.py
CHANGED
@@ -0,0 +1,54 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plotter/api
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import List, Tuple, Union
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from risk.log import log_header
|
11
|
+
from risk.network.graph.network import NetworkGraph
|
12
|
+
from risk.network.plotter.network import NetworkPlotter
|
13
|
+
|
14
|
+
|
15
|
+
class PlotterAPI:
    """Handles the loading of network plotter objects.

    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
    """

    def __init__(self) -> None:
        # `self` is required: without it, instantiating the class (`PlotterAPI()`) raises
        # "TypeError: __init__() takes 0 positional arguments but 1 was given".
        pass

    def load_plotter(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: Union[str, List, Tuple, np.ndarray] = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> NetworkPlotter:
        """Get a NetworkPlotter object for plotting.

        All arguments are forwarded unchanged to the NetworkPlotter constructor.

        Args:
            graph (NetworkGraph): The graph to plot.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure in inches (width, height).
                Defaults to (10, 10).
            background_color (str, List, Tuple, or np.ndarray, optional): Background color of the plot.
                Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
                any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.

        Returns:
            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
        """
        log_header("Loading plotter")

        # Initialize and return a NetworkPlotter object
        return NetworkPlotter(
            graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
|
@@ -9,9 +9,9 @@ import matplotlib.pyplot as plt
|
|
9
9
|
import numpy as np
|
10
10
|
|
11
11
|
from risk.log import params
|
12
|
-
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.
|
14
|
-
from risk.network.
|
12
|
+
from risk.network.graph.network import NetworkGraph
|
13
|
+
from risk.network.plotter.utils.colors import to_rgba
|
14
|
+
from risk.network.plotter.utils.layout import calculate_bounding_box
|
15
15
|
|
16
16
|
|
17
17
|
class Canvas:
|
@@ -12,8 +12,8 @@ from scipy.ndimage import label
|
|
12
12
|
from scipy.stats import gaussian_kde
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
|
-
from risk.network.graph import NetworkGraph
|
16
|
-
from risk.network.
|
15
|
+
from risk.network.graph.network import NetworkGraph
|
16
|
+
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
@@ -11,9 +11,9 @@ import numpy as np
|
|
11
11
|
import pandas as pd
|
12
12
|
|
13
13
|
from risk.log import params
|
14
|
-
from risk.network.graph import NetworkGraph
|
15
|
-
from risk.network.
|
16
|
-
from risk.network.
|
14
|
+
from risk.network.graph.network import NetworkGraph
|
15
|
+
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
16
|
+
from risk.network.plotter.utils.layout import calculate_bounding_box
|
17
17
|
|
18
18
|
TERM_DELIMITER = "::::" # String used to separate multiple domain terms when constructing composite domain labels
|
19
19
|
|
@@ -5,16 +5,24 @@ risk/network/plot/network
|
|
5
5
|
|
6
6
|
from typing import Any, Dict, List, Tuple, Union
|
7
7
|
|
8
|
+
import matplotlib.pyplot as plt
|
8
9
|
import networkx as nx
|
9
10
|
import numpy as np
|
10
11
|
|
11
12
|
from risk.log import params
|
12
|
-
from risk.network.graph import NetworkGraph
|
13
|
-
from risk.network.
|
13
|
+
from risk.network.graph.network import NetworkGraph
|
14
|
+
from risk.network.plotter.canvas import Canvas
|
15
|
+
from risk.network.plotter.contour import Contour
|
16
|
+
from risk.network.plotter.labels import Labels
|
17
|
+
from risk.network.plotter.utils.colors import get_domain_colors, to_rgba
|
18
|
+
from risk.network.plotter.utils.layout import calculate_bounding_box
|
14
19
|
|
15
20
|
|
16
21
|
class Network:
|
17
|
-
"""
|
22
|
+
"""A class for plotting network graphs with customizable options.
|
23
|
+
|
24
|
+
The Network class provides methods to plot network graphs with flexible node and edge properties.
|
25
|
+
"""
|
18
26
|
|
19
27
|
def __init__(self, graph: NetworkGraph, ax: Any = None) -> None:
|
20
28
|
"""Initialize the NetworkPlotter class.
|
@@ -289,3 +297,128 @@ class Network:
|
|
289
297
|
node_sizes[node] = significant_size
|
290
298
|
|
291
299
|
return node_sizes
|
300
|
+
|
301
|
+
|
302
|
+
class NetworkPlotter(Canvas, Network, Contour, Labels):
    """A class for visualizing network graphs with customizable options.

    The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
    flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
    perimeter, and adjusting background colors.
    """

    def __init__(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: Union[str, List, Tuple, np.ndarray] = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> None:
        """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.

        Args:
            graph (NetworkGraph): The network data and attributes to be visualized.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure in inches (width, height).
                Defaults to (10, 10).
            background_color (str, List, Tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
                any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
        """
        self.graph = graph
        # Initialize the plot with the specified parameters
        self.ax = self._initialize_plot(
            graph=graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
        # NOTE(review): with MRO (Canvas, Network, Contour, Labels) this resolves to the first
        # base's __init__; confirm the remaining bases do not need their own initialization.
        super().__init__(graph=graph, ax=self.ax)

    def _initialize_plot(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray],
        background_color: Union[str, List, Tuple, np.ndarray],
        background_alpha: Union[float, None],
        pad: float,
    ) -> plt.Axes:
        """Set up the plot with figure size and background color.

        Creates a new figure. The axis limits are a square around the bounding box of the node
        coordinates, expanded by ``pad``. The y-axis is inverted, and spines, ticks and the axis
        background are hidden.

        Args:
            graph (NetworkGraph): The network data and attributes to be visualized.
            figsize (List, Tuple, or np.ndarray): Size of the figure in inches (width, height).
            background_color (str, List, Tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
            background_alpha (float, None): Transparency level of the background color. If provided, it overrides any existing
                alpha values found in `background_color`.
            pad (float): Padding value to adjust the axis limits.

        Returns:
            plt.Axes: The axis object for the plot.
        """
        # Log the plotter settings
        params.log_plotter(
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )

        # Extract node coordinates from the network graph
        node_coordinates = graph.node_coordinates
        # Calculate the center and radius of the bounding box around the network
        center, radius = calculate_bounding_box(node_coordinates)

        # Create a new figure and axis for plotting
        fig, ax = plt.subplots(figsize=figsize)
        fig.tight_layout()  # Adjust subplot parameters to give specified padding
        # Set axis limits based on the calculated bounding box and radius
        ax.set_xlim([center[0] - radius - pad, center[0] + radius + pad])
        ax.set_ylim([center[1] - radius - pad, center[1] + radius + pad])
        ax.set_aspect("equal")  # Ensure the aspect ratio is equal

        # Set the background color of the plot
        # Convert color to RGBA using the to_rgba helper function
        fig.patch.set_facecolor(
            to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
        )  # num_repeats=1 for single color
        ax.invert_yaxis()  # Invert the y-axis to match typical image coordinates
        # Remove axis spines for a cleaner look
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Hide axis ticks and labels
        ax.set_xticks([])
        ax.set_yticks([])
        ax.patch.set_visible(False)  # Hide the axis background

        return ax

    @staticmethod
    def savefig(*args, pad_inches: float = 0.5, dpi: int = 100, **kwargs) -> None:
        """Save the current plot to a file with additional export options.

        ``bbox_inches`` defaults to "tight" unless the caller supplies it.

        Args:
            *args: Positional arguments passed to `plt.savefig`.
            pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
            dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
            **kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
        """
        # Ensure user-provided kwargs take precedence
        kwargs.setdefault("dpi", dpi)
        kwargs.setdefault("pad_inches", pad_inches)
        # Ensure the plot is saved with tight bounding box if not specified
        kwargs.setdefault("bbox_inches", "tight")
        # Call plt.savefig with combined arguments
        plt.savefig(*args, **kwargs)

    @staticmethod
    def show(*args, **kwargs) -> None:
        """Display the current plot.

        Args:
            *args: Positional arguments passed to `plt.show`.
            **kwargs: Keyword arguments passed to `plt.show`.
        """
        plt.show(*args, **kwargs)
|
@@ -9,7 +9,7 @@ import matplotlib
|
|
9
9
|
import matplotlib.colors as mcolors
|
10
10
|
import numpy as np
|
11
11
|
|
12
|
-
from risk.network.graph import NetworkGraph
|
12
|
+
from risk.network.graph.network import NetworkGraph
|
13
13
|
|
14
14
|
|
15
15
|
def get_annotated_domain_colors(
|
@@ -68,7 +68,7 @@ def get_annotated_domain_colors(
|
|
68
68
|
|
69
69
|
annotated_colors.append(color)
|
70
70
|
|
71
|
-
return
|
71
|
+
return annotated_colors
|
72
72
|
|
73
73
|
|
74
74
|
def get_domain_colors(
|
@@ -311,19 +311,28 @@ def _transform_colors(
|
|
311
311
|
if min_scale == max_scale:
|
312
312
|
min_scale = max_scale - 10e-6 # Avoid division by zero
|
313
313
|
|
314
|
+
# Replace invalid values in colors early
|
315
|
+
colors = np.nan_to_num(colors, nan=0.0) # Replace NaN with black
|
314
316
|
# Replace black colors (#000000) with very dark grey (#1A1A1A)
|
315
317
|
black_color = np.array([0.0, 0.0, 0.0]) # Pure black RGB
|
316
318
|
dark_grey = np.array([0.1, 0.1, 0.1]) # Very dark grey RGB (#1A1A1A)
|
317
|
-
# Check where colors are black (very close to [0, 0, 0]) and replace with dark grey
|
318
319
|
is_black = np.all(colors[:, :3] == black_color, axis=1)
|
319
320
|
colors[is_black, :3] = dark_grey
|
320
321
|
|
321
|
-
#
|
322
|
-
|
323
|
-
|
322
|
+
# Handle invalid or zero significance sums
|
323
|
+
max_significance = np.max(significance_sums)
|
324
|
+
if max_significance == 0:
|
325
|
+
max_significance = 1 # Avoid division by zero
|
326
|
+
normalized_sums = significance_sums / max_significance
|
327
|
+
# Replace NaN values in normalized sums
|
328
|
+
normalized_sums = np.nan_to_num(normalized_sums, nan=0.0)
|
329
|
+
|
330
|
+
# Apply power scaling to emphasize higher significance values
|
324
331
|
scaled_sums = normalized_sums**scale_factor
|
325
332
|
# Linearly scale the normalized sums to the range [min_scale, max_scale]
|
326
333
|
scaled_sums = min_scale + (max_scale - min_scale) * scaled_sums
|
334
|
+
# Replace NaN or invalid scaled sums
|
335
|
+
scaled_sums = np.nan_to_num(scaled_sums, nan=min_scale)
|
327
336
|
# Adjust RGB values based on scaled sums
|
328
337
|
for i in range(3): # Only adjust RGB values
|
329
338
|
colors[:, i] = scaled_sums * colors[:, i]
|