risk-network 0.0.6b10__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/annotations.py +61 -42
- risk/annotations/io.py +14 -14
- risk/log/__init__.py +1 -1
- risk/log/config.py +139 -0
- risk/log/params.py +4 -4
- risk/neighborhoods/community.py +25 -36
- risk/neighborhoods/domains.py +29 -27
- risk/neighborhoods/neighborhoods.py +171 -72
- risk/network/graph.py +92 -41
- risk/network/io.py +22 -26
- risk/network/plot.py +132 -19
- risk/risk.py +81 -78
- risk/stats/__init__.py +2 -2
- risk/stats/hypergeom.py +30 -107
- risk/stats/permutation/permutation.py +23 -17
- risk/stats/permutation/test_functions.py +2 -2
- risk/stats/poisson.py +44 -0
- {risk_network-0.0.6b10.dist-info → risk_network-0.0.7.dist-info}/METADATA +1 -1
- risk_network-0.0.7.dist-info/RECORD +30 -0
- risk/log/console.py +0 -16
- risk/stats/fisher_exact.py +0 -132
- risk_network-0.0.6b10.dist-info/RECORD +0 -30
- {risk_network-0.0.6b10.dist-info → risk_network-0.0.7.dist-info}/LICENSE +0 -0
- {risk_network-0.0.6b10.dist-info → risk_network-0.0.7.dist-info}/WHEEL +0 -0
- {risk_network-0.0.6b10.dist-info → risk_network-0.0.7.dist-info}/top_level.txt +0 -0
risk/network/graph.py
CHANGED
@@ -3,7 +3,6 @@ risk/network/graph
|
|
3
3
|
~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
import random
|
7
6
|
from collections import defaultdict
|
8
7
|
from typing import Any, Dict, List, Tuple, Union
|
9
8
|
|
@@ -55,10 +54,9 @@ class NetworkGraph:
|
|
55
54
|
self.node_label_to_node_id_map = node_label_to_node_id_map
|
56
55
|
# NOTE: Below this point, instance attributes (i.e., self) will be used!
|
57
56
|
self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
|
58
|
-
#
|
59
|
-
self.network =
|
60
|
-
self.node_coordinates =
|
61
|
-
self._initialize_network(network)
|
57
|
+
# Unfold the network's 3D coordinates to 2D and extract node coordinates
|
58
|
+
self.network = _unfold_sphere_to_plane(network)
|
59
|
+
self.node_coordinates = _extract_node_coordinates(self.network)
|
62
60
|
|
63
61
|
def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
|
64
62
|
"""Create a mapping from domains to the list of node IDs belonging to each domain.
|
@@ -109,19 +107,6 @@ class NetworkGraph:
|
|
109
107
|
|
110
108
|
return domain_id_to_label_map
|
111
109
|
|
112
|
-
def _initialize_network(self, G: nx.Graph) -> None:
|
113
|
-
"""Initialize the network by unfolding it and extracting node coordinates.
|
114
|
-
|
115
|
-
Args:
|
116
|
-
G (nx.Graph): The input network graph with 3D node coordinates.
|
117
|
-
"""
|
118
|
-
# Unfold the network's 3D coordinates to 2D
|
119
|
-
G_2d = _unfold_sphere_to_plane(G)
|
120
|
-
# Assign the unfolded graph to self.network
|
121
|
-
self.network = G_2d
|
122
|
-
# Extract 2D coordinates of nodes
|
123
|
-
self.node_coordinates = _extract_node_coordinates(G_2d)
|
124
|
-
|
125
110
|
def get_domain_colors(
|
126
111
|
self,
|
127
112
|
cmap: str = "gist_rainbow",
|
@@ -200,14 +185,15 @@ class NetworkGraph:
|
|
200
185
|
Returns:
|
201
186
|
dict: A dictionary mapping domain keys to their corresponding RGBA colors.
|
202
187
|
"""
|
203
|
-
#
|
204
|
-
numeric_domains = [
|
205
|
-
col for col in self.domains.columns if isinstance(col, (int, np.integer))
|
206
|
-
]
|
207
|
-
domains = np.sort(numeric_domains)
|
188
|
+
# Get colors for each domain based on node positions
|
208
189
|
domain_colors = _get_colors(
|
209
|
-
|
190
|
+
self.network,
|
191
|
+
self.domain_id_to_node_ids_map,
|
192
|
+
cmap=cmap,
|
193
|
+
color=color,
|
194
|
+
random_seed=random_seed,
|
210
195
|
)
|
196
|
+
self.network, self.domain_id_to_node_ids_map
|
211
197
|
return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
|
212
198
|
|
213
199
|
|
@@ -300,35 +286,100 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
|
|
300
286
|
|
301
287
|
|
302
288
|
def _get_colors(
|
303
|
-
|
289
|
+
network,
|
290
|
+
domain_id_to_node_ids_map,
|
304
291
|
cmap: str = "gist_rainbow",
|
305
292
|
color: Union[str, None] = None,
|
306
293
|
random_seed: int = 888,
|
307
294
|
) -> List[Tuple]:
|
308
|
-
"""Generate a list of RGBA colors
|
295
|
+
"""Generate a list of RGBA colors based on domain centroids, ensuring that domains
|
296
|
+
close in space get maximally separated colors, while keeping some randomness.
|
309
297
|
|
310
298
|
Args:
|
311
|
-
|
299
|
+
network (NetworkX graph): The graph representing the network.
|
300
|
+
domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
|
312
301
|
cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
|
313
302
|
color (str or None, optional): A specific color to use for all generated colors.
|
314
303
|
random_seed (int): Seed for random number generation. Defaults to 888.
|
315
|
-
Defaults to None.
|
316
304
|
|
317
305
|
Returns:
|
318
|
-
|
306
|
+
List[Tuple]: List of RGBA colors.
|
319
307
|
"""
|
320
308
|
# Set random seed for reproducibility
|
321
|
-
random.seed(random_seed)
|
309
|
+
np.random.seed(random_seed)
|
310
|
+
# Determine the number of colors to generate based on the number of domains
|
311
|
+
num_colors_to_generate = len(domain_id_to_node_ids_map)
|
322
312
|
if color:
|
323
|
-
#
|
313
|
+
# Generate all colors as the same specified color
|
324
314
|
rgba = matplotlib.colors.to_rgba(color)
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
315
|
+
return [rgba] * num_colors_to_generate
|
316
|
+
|
317
|
+
# Load colormap
|
318
|
+
colormap = matplotlib.colormaps.get_cmap(cmap)
|
319
|
+
# Step 1: Calculate centroids for each domain
|
320
|
+
centroids = _calculate_centroids(network, domain_id_to_node_ids_map)
|
321
|
+
# Step 2: Calculate pairwise distances between centroids
|
322
|
+
centroid_array = np.array(centroids)
|
323
|
+
dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
|
324
|
+
# Step 3: Assign distant colors to close centroids
|
325
|
+
color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
|
326
|
+
# Step 4: Randomly shift the entire color palette while maintaining relative distances
|
327
|
+
global_shift = np.random.uniform(-0.1, 0.1) # Small global shift to change the overall palette
|
328
|
+
color_positions = (color_positions + global_shift) % 1 # Wrap around to keep within [0, 1]
|
329
|
+
# Step 5: Ensure that all positions remain between 0 and 1
|
330
|
+
color_positions = np.clip(color_positions, 0, 1)
|
331
|
+
|
332
|
+
# Step 6: Generate RGBA colors based on positions
|
333
|
+
return [colormap(pos) for pos in color_positions]
|
334
|
+
|
335
|
+
|
336
|
+
def _calculate_centroids(network, domain_id_to_node_ids_map):
|
337
|
+
"""Calculate the centroid for each domain based on node x and y coordinates in the network.
|
338
|
+
|
339
|
+
Args:
|
340
|
+
network (NetworkX graph): The graph representing the network.
|
341
|
+
domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
|
342
|
+
|
343
|
+
Returns:
|
344
|
+
List[Tuple[float, float]]: List of centroids (x, y) for each domain.
|
345
|
+
"""
|
346
|
+
centroids = []
|
347
|
+
for domain_id, node_ids in domain_id_to_node_ids_map.items():
|
348
|
+
# Extract x and y coordinates from the network nodes
|
349
|
+
node_positions = np.array(
|
350
|
+
[[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
|
351
|
+
)
|
352
|
+
# Compute the centroid as the mean of the x and y coordinates
|
353
|
+
centroid = np.mean(node_positions, axis=0)
|
354
|
+
centroids.append(tuple(centroid))
|
355
|
+
|
356
|
+
return centroids
|
357
|
+
|
358
|
+
|
359
|
+
def _assign_distant_colors(dist_matrix, num_colors_to_generate):
|
360
|
+
"""Assign colors to centroids that are close in space, ensuring stark color differences.
|
361
|
+
|
362
|
+
Args:
|
363
|
+
dist_matrix (ndarray): Matrix of pairwise centroid distances.
|
364
|
+
num_colors_to_generate (int): Number of colors to generate.
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
np.array: Array of color positions in the range [0, 1].
|
368
|
+
"""
|
369
|
+
color_positions = np.zeros(num_colors_to_generate)
|
370
|
+
# Step 1: Sort indices by centroid proximity (based on sum of distances to others)
|
371
|
+
proximity_order = sorted(
|
372
|
+
range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
|
373
|
+
)
|
374
|
+
# Step 2: Assign evenly spaced color positions following the proximity order (smallest total distance first)
|
375
|
+
for i, idx in enumerate(proximity_order):
|
376
|
+
color_positions[idx] = i / num_colors_to_generate
|
377
|
+
|
378
|
+
# Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
|
379
|
+
half_spectrum = int(num_colors_to_generate / 2)
|
380
|
+
for i in range(half_spectrum):
|
381
|
+
# Split the spectrum so that close centroids are assigned distant colors
|
382
|
+
color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
|
383
|
+
color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
|
384
|
+
|
385
|
+
return color_positions
|
risk/network/io.py
CHANGED
@@ -16,7 +16,7 @@ import networkx as nx
|
|
16
16
|
import pandas as pd
|
17
17
|
|
18
18
|
from risk.network.geometry import assign_edge_lengths
|
19
|
-
from risk.log import params,
|
19
|
+
from risk.log import params, logger, log_header
|
20
20
|
|
21
21
|
|
22
22
|
class NetworkIO:
|
@@ -57,9 +57,8 @@ class NetworkIO:
|
|
57
57
|
weight_label=weight_label,
|
58
58
|
)
|
59
59
|
|
60
|
-
@
|
60
|
+
@staticmethod
|
61
61
|
def load_gpickle_network(
|
62
|
-
cls,
|
63
62
|
filepath: str,
|
64
63
|
compute_sphere: bool = True,
|
65
64
|
surface_depth: float = 0.0,
|
@@ -80,7 +79,7 @@ class NetworkIO:
|
|
80
79
|
Returns:
|
81
80
|
nx.Graph: Loaded and processed network.
|
82
81
|
"""
|
83
|
-
networkio =
|
82
|
+
networkio = NetworkIO(
|
84
83
|
compute_sphere=compute_sphere,
|
85
84
|
surface_depth=surface_depth,
|
86
85
|
min_edges_per_node=min_edges_per_node,
|
@@ -109,9 +108,8 @@ class NetworkIO:
|
|
109
108
|
# Initialize the graph
|
110
109
|
return self._initialize_graph(G)
|
111
110
|
|
112
|
-
@
|
111
|
+
@staticmethod
|
113
112
|
def load_networkx_network(
|
114
|
-
cls,
|
115
113
|
network: nx.Graph,
|
116
114
|
compute_sphere: bool = True,
|
117
115
|
surface_depth: float = 0.0,
|
@@ -132,7 +130,7 @@ class NetworkIO:
|
|
132
130
|
Returns:
|
133
131
|
nx.Graph: Loaded and processed network.
|
134
132
|
"""
|
135
|
-
networkio =
|
133
|
+
networkio = NetworkIO(
|
136
134
|
compute_sphere=compute_sphere,
|
137
135
|
surface_depth=surface_depth,
|
138
136
|
min_edges_per_node=min_edges_per_node,
|
@@ -158,9 +156,8 @@ class NetworkIO:
|
|
158
156
|
# Initialize the graph
|
159
157
|
return self._initialize_graph(network)
|
160
158
|
|
161
|
-
@
|
159
|
+
@staticmethod
|
162
160
|
def load_cytoscape_network(
|
163
|
-
cls,
|
164
161
|
filepath: str,
|
165
162
|
source_label: str = "source",
|
166
163
|
target_label: str = "target",
|
@@ -187,7 +184,7 @@ class NetworkIO:
|
|
187
184
|
Returns:
|
188
185
|
nx.Graph: Loaded and processed network.
|
189
186
|
"""
|
190
|
-
networkio =
|
187
|
+
networkio = NetworkIO(
|
191
188
|
compute_sphere=compute_sphere,
|
192
189
|
surface_depth=surface_depth,
|
193
190
|
min_edges_per_node=min_edges_per_node,
|
@@ -312,9 +309,8 @@ class NetworkIO:
|
|
312
309
|
if os.path.exists(tmp_dir):
|
313
310
|
shutil.rmtree(tmp_dir)
|
314
311
|
|
315
|
-
@
|
312
|
+
@staticmethod
|
316
313
|
def load_cytoscape_json_network(
|
317
|
-
cls,
|
318
314
|
filepath: str,
|
319
315
|
source_label: str = "source",
|
320
316
|
target_label: str = "target",
|
@@ -339,7 +335,7 @@ class NetworkIO:
|
|
339
335
|
Returns:
|
340
336
|
NetworkX graph: Loaded and processed network.
|
341
337
|
"""
|
342
|
-
networkio =
|
338
|
+
networkio = NetworkIO(
|
343
339
|
compute_sphere=compute_sphere,
|
344
340
|
surface_depth=surface_depth,
|
345
341
|
min_edges_per_node=min_edges_per_node,
|
@@ -455,10 +451,10 @@ class NetworkIO:
|
|
455
451
|
# Log the number of nodes and edges before and after cleaning
|
456
452
|
num_final_nodes = G.number_of_nodes()
|
457
453
|
num_final_edges = G.number_of_edges()
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
454
|
+
logger.debug(f"Initial node count: {num_initial_nodes}")
|
455
|
+
logger.debug(f"Final node count: {num_final_nodes}")
|
456
|
+
logger.debug(f"Initial edge count: {num_initial_edges}")
|
457
|
+
logger.debug(f"Final edge count: {num_final_edges}")
|
462
458
|
|
463
459
|
def _assign_edge_weights(self, G: nx.Graph) -> None:
|
464
460
|
"""Assign weights to the edges in the graph.
|
@@ -476,7 +472,7 @@ class NetworkIO:
|
|
476
472
|
) # Default to 1.0 if 'weight' not present
|
477
473
|
|
478
474
|
if self.include_edge_weight and missing_weights:
|
479
|
-
|
475
|
+
logger.debug(f"Total edges missing weights: {missing_weights}")
|
480
476
|
|
481
477
|
def _validate_nodes(self, G: nx.Graph) -> None:
|
482
478
|
"""Validate the graph structure and attributes.
|
@@ -514,14 +510,14 @@ class NetworkIO:
|
|
514
510
|
filetype (str): The type of the file being loaded (e.g., 'CSV', 'JSON').
|
515
511
|
filepath (str, optional): The path to the file being loaded. Defaults to "".
|
516
512
|
"""
|
517
|
-
|
518
|
-
|
513
|
+
log_header("Loading network")
|
514
|
+
logger.debug(f"Filetype: {filetype}")
|
519
515
|
if filepath:
|
520
|
-
|
521
|
-
|
516
|
+
logger.debug(f"Filepath: {filepath}")
|
517
|
+
logger.debug(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
|
522
518
|
if self.include_edge_weight:
|
523
|
-
|
524
|
-
|
525
|
-
|
519
|
+
logger.debug(f"Weight label: {self.weight_label}")
|
520
|
+
logger.debug(f"Minimum edges per node: {self.min_edges_per_node}")
|
521
|
+
logger.debug(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
|
526
522
|
if self.compute_sphere:
|
527
|
-
|
523
|
+
logger.debug(f"Surface depth: {self.surface_depth}")
|
risk/network/plot.py
CHANGED
@@ -9,10 +9,12 @@ import matplotlib.colors as mcolors
|
|
9
9
|
import matplotlib.pyplot as plt
|
10
10
|
import networkx as nx
|
11
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
13
|
+
from scipy import linalg
|
12
14
|
from scipy.ndimage import label
|
13
15
|
from scipy.stats import gaussian_kde
|
14
16
|
|
15
|
-
from risk.log import params
|
17
|
+
from risk.log import params, logger
|
16
18
|
from risk.network.graph import NetworkGraph
|
17
19
|
|
18
20
|
|
@@ -85,6 +87,83 @@ class NetworkPlotter:
|
|
85
87
|
|
86
88
|
return ax
|
87
89
|
|
90
|
+
def plot_title(
|
91
|
+
self,
|
92
|
+
title: Union[str, None] = None,
|
93
|
+
subtitle: Union[str, None] = None,
|
94
|
+
title_fontsize: int = 20,
|
95
|
+
subtitle_fontsize: int = 14,
|
96
|
+
font: str = "Arial",
|
97
|
+
title_color: str = "black",
|
98
|
+
subtitle_color: str = "gray",
|
99
|
+
title_y: float = 0.975,
|
100
|
+
title_space_offset: float = 0.075,
|
101
|
+
subtitle_offset: float = 0.025,
|
102
|
+
) -> None:
|
103
|
+
"""Plot title and subtitle on the network graph with customizable parameters.
|
104
|
+
|
105
|
+
Args:
|
106
|
+
title (str, optional): Title of the plot. Defaults to None.
|
107
|
+
subtitle (str, optional): Subtitle of the plot. Defaults to None.
|
108
|
+
title_fontsize (int, optional): Font size for the title. Defaults to 20.
|
109
|
+
subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
|
110
|
+
font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
|
111
|
+
title_color (str, optional): Color of the title text. Defaults to "black".
|
112
|
+
subtitle_color (str, optional): Color of the subtitle text. Defaults to "gray".
|
113
|
+
title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
|
114
|
+
title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
|
115
|
+
subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
|
116
|
+
"""
|
117
|
+
# Log the title and subtitle parameters
|
118
|
+
params.log_plotter(
|
119
|
+
title=title,
|
120
|
+
subtitle=subtitle,
|
121
|
+
title_fontsize=title_fontsize,
|
122
|
+
subtitle_fontsize=subtitle_fontsize,
|
123
|
+
title_subtitle_font=font,
|
124
|
+
title_color=title_color,
|
125
|
+
subtitle_color=subtitle_color,
|
126
|
+
subtitle_offset=subtitle_offset,
|
127
|
+
title_y=title_y,
|
128
|
+
title_space_offset=title_space_offset,
|
129
|
+
)
|
130
|
+
|
131
|
+
# Get the current figure and axis dimensions
|
132
|
+
fig = self.ax.figure
|
133
|
+
# Use a tight layout to ensure that title and subtitle do not overlap with the original plot
|
134
|
+
fig.tight_layout(
|
135
|
+
rect=[0, 0, 1, 1 - title_space_offset]
|
136
|
+
) # Leave space above the plot for title
|
137
|
+
|
138
|
+
# Plot title if provided
|
139
|
+
if title:
|
140
|
+
# Set the title using figure's suptitle to ensure centering
|
141
|
+
self.ax.figure.suptitle(
|
142
|
+
title,
|
143
|
+
fontsize=title_fontsize,
|
144
|
+
color=title_color,
|
145
|
+
fontname=font,
|
146
|
+
x=0.5, # Center the title horizontally
|
147
|
+
ha="center",
|
148
|
+
va="top",
|
149
|
+
y=title_y,
|
150
|
+
)
|
151
|
+
|
152
|
+
# Plot subtitle if provided
|
153
|
+
if subtitle:
|
154
|
+
# Calculate the subtitle's y position based on title's position and subtitle_offset
|
155
|
+
subtitle_y_position = title_y - subtitle_offset
|
156
|
+
self.ax.figure.text(
|
157
|
+
0.5, # Ensure horizontal centering for subtitle
|
158
|
+
subtitle_y_position,
|
159
|
+
subtitle,
|
160
|
+
ha="center",
|
161
|
+
va="top",
|
162
|
+
fontname=font,
|
163
|
+
fontsize=subtitle_fontsize,
|
164
|
+
color=subtitle_color,
|
165
|
+
)
|
166
|
+
|
88
167
|
def plot_circle_perimeter(
|
89
168
|
self,
|
90
169
|
scale: float = 1.0,
|
@@ -509,26 +588,52 @@ class NetworkPlotter:
|
|
509
588
|
# Extract the positions of the specified nodes
|
510
589
|
points = np.array([pos[n] for n in nodes])
|
511
590
|
if len(points) <= 1:
|
512
|
-
return # Not enough points to form a contour
|
591
|
+
return None # Not enough points to form a contour
|
513
592
|
|
593
|
+
# Check if the KDE forms a single connected component
|
514
594
|
connected = False
|
595
|
+
z = None # Initialize z to None to avoid UnboundLocalError
|
515
596
|
while not connected and bandwidth <= 100.0:
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
597
|
+
try:
|
598
|
+
# Perform KDE on the points with the given bandwidth
|
599
|
+
kde = gaussian_kde(points.T, bw_method=bandwidth)
|
600
|
+
xmin, ymin = points.min(axis=0) - bandwidth
|
601
|
+
xmax, ymax = points.max(axis=0) + bandwidth
|
602
|
+
x, y = np.mgrid[
|
603
|
+
xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
|
604
|
+
]
|
605
|
+
z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
|
606
|
+
# Check if the KDE forms a single connected component
|
607
|
+
connected = _is_connected(z)
|
608
|
+
if not connected:
|
609
|
+
bandwidth += 0.05 # Increase bandwidth slightly and retry
|
610
|
+
except linalg.LinAlgError:
|
611
|
+
bandwidth += 0.05 # Increase bandwidth and retry
|
612
|
+
except Exception as e:
|
613
|
+
# Catch any other exceptions and log them
|
614
|
+
logger.error(f"Unexpected error when drawing KDE contour: {e}")
|
615
|
+
return None
|
616
|
+
|
617
|
+
# If z is still None, the KDE computation failed
|
618
|
+
if z is None:
|
619
|
+
logger.error("Failed to compute KDE. Skipping contour plot for these nodes.")
|
620
|
+
return None
|
528
621
|
|
529
622
|
# Define contour levels based on the density
|
530
623
|
min_density, max_density = z.min(), z.max()
|
624
|
+
if min_density == max_density:
|
625
|
+
logger.warning(
|
626
|
+
"Contour levels could not be created due to lack of variation in density."
|
627
|
+
)
|
628
|
+
return None
|
629
|
+
|
630
|
+
# Create contour levels based on the density values
|
531
631
|
contour_levels = np.linspace(min_density, max_density, levels)[1:]
|
632
|
+
if len(contour_levels) < 2 or not np.all(np.diff(contour_levels) > 0):
|
633
|
+
logger.error("Contour levels must be strictly increasing. Skipping contour plot.")
|
634
|
+
return None
|
635
|
+
|
636
|
+
# Set the contour color and linestyle
|
532
637
|
contour_colors = [color for _ in range(levels - 1)]
|
533
638
|
# Plot the filled contours using fill_alpha for transparency
|
534
639
|
if fill_alpha > 0:
|
@@ -553,6 +658,7 @@ class NetworkPlotter:
|
|
553
658
|
linewidths=linewidth,
|
554
659
|
alpha=alpha,
|
555
660
|
)
|
661
|
+
|
556
662
|
# Set linewidth for the contour lines to 0 for levels other than the base level
|
557
663
|
for i in range(1, len(contour_levels)):
|
558
664
|
c.collections[i].set_linewidth(0)
|
@@ -601,7 +707,7 @@ class NetworkPlotter:
|
|
601
707
|
min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
|
602
708
|
max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
|
603
709
|
min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
|
604
|
-
words_to_omit (
|
710
|
+
words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
|
605
711
|
overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
|
606
712
|
ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
|
607
713
|
you can set `overlay_ids=True`. Defaults to None.
|
@@ -710,6 +816,9 @@ class NetworkPlotter:
|
|
710
816
|
# Process remaining domains to fill in additional labels, if there are slots left
|
711
817
|
if remaining_labels and remaining_labels > 0:
|
712
818
|
for idx, (domain, centroid) in enumerate(domain_centroids.items()):
|
819
|
+
# Check if the domain is NaN and continue if true
|
820
|
+
if pd.isna(domain) or (isinstance(domain, float) and np.isnan(domain)):
|
821
|
+
continue # Skip NaN domains
|
713
822
|
if ids_to_keep and domain in ids_to_keep:
|
714
823
|
continue # Skip domains already handled by ids_to_keep
|
715
824
|
|
@@ -1086,14 +1195,16 @@ class NetworkPlotter:
|
|
1086
1195
|
return np.array(annotated_colors)
|
1087
1196
|
|
1088
1197
|
@staticmethod
|
1089
|
-
def savefig(*args, **kwargs) -> None:
|
1090
|
-
"""Save the current plot to a file.
|
1198
|
+
def savefig(*args, pad_inches: float = 0.5, dpi: int = 100, **kwargs) -> None:
|
1199
|
+
"""Save the current plot to a file with additional export options.
|
1091
1200
|
|
1092
1201
|
Args:
|
1093
1202
|
*args: Positional arguments passed to `plt.savefig`.
|
1203
|
+
pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
|
1204
|
+
dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
|
1094
1205
|
**kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
|
1095
1206
|
"""
|
1096
|
-
plt.savefig(*args, bbox_inches="tight", **kwargs)
|
1207
|
+
plt.savefig(*args, bbox_inches="tight", pad_inches=pad_inches, dpi=dpi, **kwargs)
|
1097
1208
|
|
1098
1209
|
@staticmethod
|
1099
1210
|
def show(*args, **kwargs) -> None:
|
@@ -1123,7 +1234,9 @@ def _to_rgba(
|
|
1123
1234
|
"""
|
1124
1235
|
# Handle single color case (string, RGB, or RGBA)
|
1125
1236
|
if isinstance(color, str) or (
|
1126
|
-
isinstance(color, (list, tuple, np.ndarray))
|
1237
|
+
isinstance(color, (list, tuple, np.ndarray))
|
1238
|
+
and len(color) in [3, 4]
|
1239
|
+
and not any(isinstance(c, (list, tuple, np.ndarray)) for c in color)
|
1127
1240
|
):
|
1128
1241
|
rgba_color = np.array(mcolors.to_rgba(color))
|
1129
1242
|
# Only set alpha if the input is an RGB color or a string (not RGBA)
|