risk-network 0.0.6b9__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/network/graph.py CHANGED
@@ -3,7 +3,6 @@ risk/network/graph
3
3
  ~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
- import random
7
6
  from collections import defaultdict
8
7
  from typing import Any, Dict, List, Tuple, Union
9
8
 
@@ -55,10 +54,9 @@ class NetworkGraph:
55
54
  self.node_label_to_node_id_map = node_label_to_node_id_map
56
55
  # NOTE: Below this point, instance attributes (i.e., self) will be used!
57
56
  self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
58
- # self.network and self.node_coordinates are properly declared in _initialize_network
59
- self.network = None
60
- self.node_coordinates = None
61
- self._initialize_network(network)
57
+ # Unfold the network's 3D coordinates to 2D and extract node coordinates
58
+ self.network = _unfold_sphere_to_plane(network)
59
+ self.node_coordinates = _extract_node_coordinates(self.network)
62
60
 
63
61
  def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
64
62
  """Create a mapping from domains to the list of node IDs belonging to each domain.
@@ -109,19 +107,6 @@ class NetworkGraph:
109
107
 
110
108
  return domain_id_to_label_map
111
109
 
112
- def _initialize_network(self, G: nx.Graph) -> None:
113
- """Initialize the network by unfolding it and extracting node coordinates.
114
-
115
- Args:
116
- G (nx.Graph): The input network graph with 3D node coordinates.
117
- """
118
- # Unfold the network's 3D coordinates to 2D
119
- G_2d = _unfold_sphere_to_plane(G)
120
- # Assign the unfolded graph to self.network
121
- self.network = G_2d
122
- # Extract 2D coordinates of nodes
123
- self.node_coordinates = _extract_node_coordinates(G_2d)
124
-
125
110
  def get_domain_colors(
126
111
  self,
127
112
  cmap: str = "gist_rainbow",
@@ -200,14 +185,15 @@ class NetworkGraph:
200
185
  Returns:
201
186
  dict: A dictionary mapping domain keys to their corresponding RGBA colors.
202
187
  """
203
- # Exclude non-numeric domain columns
204
- numeric_domains = [
205
- col for col in self.domains.columns if isinstance(col, (int, np.integer))
206
- ]
207
- domains = np.sort(numeric_domains)
188
+ # Get colors for each domain based on node positions
208
189
  domain_colors = _get_colors(
209
- num_colors_to_generate=len(domains), cmap=cmap, color=color, random_seed=random_seed
190
+ self.network,
191
+ self.domain_id_to_node_ids_map,
192
+ cmap=cmap,
193
+ color=color,
194
+ random_seed=random_seed,
210
195
  )
196
+ self.network, self.domain_id_to_node_ids_map
211
197
  return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
212
198
 
213
199
 
@@ -300,35 +286,100 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
300
286
 
301
287
 
302
288
  def _get_colors(
303
- num_colors_to_generate: int = 10,
289
+ network,
290
+ domain_id_to_node_ids_map,
304
291
  cmap: str = "gist_rainbow",
305
292
  color: Union[str, None] = None,
306
293
  random_seed: int = 888,
307
294
  ) -> List[Tuple]:
308
- """Generate a list of RGBA colors from a specified colormap or use a direct color string.
295
+ """Generate a list of RGBA colors based on domain centroids, ensuring that domains
296
+ close in space get maximally separated colors, while keeping some randomness.
309
297
 
310
298
  Args:
311
- num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
299
+ network (NetworkX graph): The graph representing the network.
300
+ domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
312
301
  cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
313
302
  color (str or None, optional): A specific color to use for all generated colors.
314
303
  random_seed (int): Seed for random number generation. Defaults to 888.
315
- Defaults to None.
316
304
 
317
305
  Returns:
318
- list of tuple: List of RGBA colors.
306
+ List[Tuple]: List of RGBA colors.
319
307
  """
320
308
  # Set random seed for reproducibility
321
- random.seed(random_seed)
309
+ np.random.seed(random_seed)
310
+ # Determine the number of colors to generate based on the number of domains
311
+ num_colors_to_generate = len(domain_id_to_node_ids_map)
322
312
  if color:
323
- # If a direct color is provided, generate a list with that color
313
+ # Generate all colors as the same specified color
324
314
  rgba = matplotlib.colors.to_rgba(color)
325
- rgbas = [rgba] * num_colors_to_generate
326
- else:
327
- colormap = matplotlib.colormaps.get_cmap(cmap)
328
- # Generate evenly distributed color positions
329
- color_positions = np.linspace(0, 1, num_colors_to_generate)
330
- random.shuffle(color_positions) # Shuffle the positions to randomize colors
331
- # Generate colors based on shuffled positions
332
- rgbas = [colormap(pos) for pos in color_positions]
333
-
334
- return rgbas
315
+ return [rgba] * num_colors_to_generate
316
+
317
+ # Load colormap
318
+ colormap = matplotlib.colormaps.get_cmap(cmap)
319
+ # Step 1: Calculate centroids for each domain
320
+ centroids = _calculate_centroids(network, domain_id_to_node_ids_map)
321
+ # Step 2: Calculate pairwise distances between centroids
322
+ centroid_array = np.array(centroids)
323
+ dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
324
+ # Step 3: Assign distant colors to close centroids
325
+ color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
326
+ # Step 4: Randomly shift the entire color palette while maintaining relative distances
327
+ global_shift = np.random.uniform(-0.1, 0.1) # Small global shift to change the overall palette
328
+ color_positions = (color_positions + global_shift) % 1 # Wrap around to keep within [0, 1]
329
+ # Step 5: Ensure that all positions remain between 0 and 1
330
+ color_positions = np.clip(color_positions, 0, 1)
331
+
332
+ # Step 6: Generate RGBA colors based on positions
333
+ return [colormap(pos) for pos in color_positions]
334
+
335
+
336
+ def _calculate_centroids(network, domain_id_to_node_ids_map):
337
+ """Calculate the centroid for each domain based on node x and y coordinates in the network.
338
+
339
+ Args:
340
+ network (NetworkX graph): The graph representing the network.
341
+ domain_id_to_node_ids_map (dict): Mapping from domain IDs to lists of node IDs.
342
+
343
+ Returns:
344
+ List[Tuple[float, float]]: List of centroids (x, y) for each domain.
345
+ """
346
+ centroids = []
347
+ for domain_id, node_ids in domain_id_to_node_ids_map.items():
348
+ # Extract x and y coordinates from the network nodes
349
+ node_positions = np.array(
350
+ [[network.nodes[node_id]["x"], network.nodes[node_id]["y"]] for node_id in node_ids]
351
+ )
352
+ # Compute the centroid as the mean of the x and y coordinates
353
+ centroid = np.mean(node_positions, axis=0)
354
+ centroids.append(tuple(centroid))
355
+
356
+ return centroids
357
+
358
+
359
+ def _assign_distant_colors(dist_matrix, num_colors_to_generate):
360
+ """Assign colors to centroids that are close in space, ensuring stark color differences.
361
+
362
+ Args:
363
+ dist_matrix (ndarray): Matrix of pairwise centroid distances.
364
+ num_colors_to_generate (int): Number of colors to generate.
365
+
366
+ Returns:
367
+ np.array: Array of color positions in the range [0, 1].
368
+ """
369
+ color_positions = np.zeros(num_colors_to_generate)
370
+ # Step 1: Sort indices by centroid proximity (based on sum of distances to others)
371
+ proximity_order = sorted(
372
+ range(num_colors_to_generate), key=lambda idx: np.sum(dist_matrix[idx])
373
+ )
374
+ # Step 2: Assign colors starting with the most distant points in proximity order
375
+ for i, idx in enumerate(proximity_order):
376
+ color_positions[idx] = i / num_colors_to_generate
377
+
378
+ # Step 3: Adjust colors so that centroids close to one another are maximally distant on the color spectrum
379
+ half_spectrum = int(num_colors_to_generate / 2)
380
+ for i in range(half_spectrum):
381
+ # Split the spectrum so that close centroids are assigned distant colors
382
+ color_positions[proximity_order[i]] = (i * 2) / num_colors_to_generate
383
+ color_positions[proximity_order[-(i + 1)]] = ((i * 2) + 1) / num_colors_to_generate
384
+
385
+ return color_positions
risk/network/io.py CHANGED
@@ -16,7 +16,7 @@ import networkx as nx
16
16
  import pandas as pd
17
17
 
18
18
  from risk.network.geometry import assign_edge_lengths
19
- from risk.log import params, print_header
19
+ from risk.log import params, logger, log_header
20
20
 
21
21
 
22
22
  class NetworkIO:
@@ -57,9 +57,8 @@ class NetworkIO:
57
57
  weight_label=weight_label,
58
58
  )
59
59
 
60
- @classmethod
60
+ @staticmethod
61
61
  def load_gpickle_network(
62
- cls,
63
62
  filepath: str,
64
63
  compute_sphere: bool = True,
65
64
  surface_depth: float = 0.0,
@@ -80,7 +79,7 @@ class NetworkIO:
80
79
  Returns:
81
80
  nx.Graph: Loaded and processed network.
82
81
  """
83
- networkio = cls(
82
+ networkio = NetworkIO(
84
83
  compute_sphere=compute_sphere,
85
84
  surface_depth=surface_depth,
86
85
  min_edges_per_node=min_edges_per_node,
@@ -109,9 +108,8 @@ class NetworkIO:
109
108
  # Initialize the graph
110
109
  return self._initialize_graph(G)
111
110
 
112
- @classmethod
111
+ @staticmethod
113
112
  def load_networkx_network(
114
- cls,
115
113
  network: nx.Graph,
116
114
  compute_sphere: bool = True,
117
115
  surface_depth: float = 0.0,
@@ -132,7 +130,7 @@ class NetworkIO:
132
130
  Returns:
133
131
  nx.Graph: Loaded and processed network.
134
132
  """
135
- networkio = cls(
133
+ networkio = NetworkIO(
136
134
  compute_sphere=compute_sphere,
137
135
  surface_depth=surface_depth,
138
136
  min_edges_per_node=min_edges_per_node,
@@ -158,9 +156,8 @@ class NetworkIO:
158
156
  # Initialize the graph
159
157
  return self._initialize_graph(network)
160
158
 
161
- @classmethod
159
+ @staticmethod
162
160
  def load_cytoscape_network(
163
- cls,
164
161
  filepath: str,
165
162
  source_label: str = "source",
166
163
  target_label: str = "target",
@@ -187,7 +184,7 @@ class NetworkIO:
187
184
  Returns:
188
185
  nx.Graph: Loaded and processed network.
189
186
  """
190
- networkio = cls(
187
+ networkio = NetworkIO(
191
188
  compute_sphere=compute_sphere,
192
189
  surface_depth=surface_depth,
193
190
  min_edges_per_node=min_edges_per_node,
@@ -312,9 +309,8 @@ class NetworkIO:
312
309
  if os.path.exists(tmp_dir):
313
310
  shutil.rmtree(tmp_dir)
314
311
 
315
- @classmethod
312
+ @staticmethod
316
313
  def load_cytoscape_json_network(
317
- cls,
318
314
  filepath: str,
319
315
  source_label: str = "source",
320
316
  target_label: str = "target",
@@ -339,7 +335,7 @@ class NetworkIO:
339
335
  Returns:
340
336
  NetworkX graph: Loaded and processed network.
341
337
  """
342
- networkio = cls(
338
+ networkio = NetworkIO(
343
339
  compute_sphere=compute_sphere,
344
340
  surface_depth=surface_depth,
345
341
  min_edges_per_node=min_edges_per_node,
@@ -455,10 +451,10 @@ class NetworkIO:
455
451
  # Log the number of nodes and edges before and after cleaning
456
452
  num_final_nodes = G.number_of_nodes()
457
453
  num_final_edges = G.number_of_edges()
458
- print(f"Initial node count: {num_initial_nodes}")
459
- print(f"Final node count: {num_final_nodes}")
460
- print(f"Initial edge count: {num_initial_edges}")
461
- print(f"Final edge count: {num_final_edges}")
454
+ logger.debug(f"Initial node count: {num_initial_nodes}")
455
+ logger.debug(f"Final node count: {num_final_nodes}")
456
+ logger.debug(f"Initial edge count: {num_initial_edges}")
457
+ logger.debug(f"Final edge count: {num_final_edges}")
462
458
 
463
459
  def _assign_edge_weights(self, G: nx.Graph) -> None:
464
460
  """Assign weights to the edges in the graph.
@@ -476,7 +472,7 @@ class NetworkIO:
476
472
  ) # Default to 1.0 if 'weight' not present
477
473
 
478
474
  if self.include_edge_weight and missing_weights:
479
- print(f"Total edges missing weights: {missing_weights}")
475
+ logger.debug(f"Total edges missing weights: {missing_weights}")
480
476
 
481
477
  def _validate_nodes(self, G: nx.Graph) -> None:
482
478
  """Validate the graph structure and attributes.
@@ -514,14 +510,14 @@ class NetworkIO:
514
510
  filetype (str): The type of the file being loaded (e.g., 'CSV', 'JSON').
515
511
  filepath (str, optional): The path to the file being loaded. Defaults to "".
516
512
  """
517
- print_header("Loading network")
518
- print(f"Filetype: {filetype}")
513
+ log_header("Loading network")
514
+ logger.debug(f"Filetype: {filetype}")
519
515
  if filepath:
520
- print(f"Filepath: {filepath}")
521
- print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
516
+ logger.debug(f"Filepath: {filepath}")
517
+ logger.debug(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
522
518
  if self.include_edge_weight:
523
- print(f"Weight label: {self.weight_label}")
524
- print(f"Minimum edges per node: {self.min_edges_per_node}")
525
- print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
519
+ logger.debug(f"Weight label: {self.weight_label}")
520
+ logger.debug(f"Minimum edges per node: {self.min_edges_per_node}")
521
+ logger.debug(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
526
522
  if self.compute_sphere:
527
- print(f"Surface depth: {self.surface_depth}")
523
+ logger.debug(f"Surface depth: {self.surface_depth}")
risk/network/plot.py CHANGED
@@ -9,10 +9,12 @@ import matplotlib.colors as mcolors
9
9
  import matplotlib.pyplot as plt
10
10
  import networkx as nx
11
11
  import numpy as np
12
+ import pandas as pd
13
+ from scipy import linalg
12
14
  from scipy.ndimage import label
13
15
  from scipy.stats import gaussian_kde
14
16
 
15
- from risk.log import params
17
+ from risk.log import params, logger
16
18
  from risk.network.graph import NetworkGraph
17
19
 
18
20
 
@@ -85,6 +87,83 @@ class NetworkPlotter:
85
87
 
86
88
  return ax
87
89
 
90
def plot_title(
    self,
    title: Union[str, None] = None,
    subtitle: Union[str, None] = None,
    title_fontsize: int = 20,
    subtitle_fontsize: int = 14,
    font: str = "Arial",
    title_color: str = "black",
    subtitle_color: str = "gray",
    title_y: float = 0.975,
    title_space_offset: float = 0.075,
    subtitle_offset: float = 0.025,
) -> None:
    """Plot title and subtitle on the network graph with customizable parameters.

    Args:
        title (str, optional): Title of the plot. Defaults to None.
        subtitle (str, optional): Subtitle of the plot. Defaults to None.
        title_fontsize (int, optional): Font size for the title. Defaults to 20.
        subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
        font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
        title_color (str, optional): Color of the title text. Defaults to "black".
        subtitle_color (str, optional): Color of the subtitle text. Defaults to "gray".
        title_y (float, optional): Y position (figure coordinates) of the title's top edge.
            Defaults to 0.975.
        title_space_offset (float, optional): Fraction of figure height left free above the
            plot by `tight_layout`. Defaults to 0.075.
        subtitle_offset (float, optional): Distance (figure coordinates) between the title's
            top edge and the subtitle's top edge. Defaults to 0.025.
    """
    # Log the title and subtitle parameters
    params.log_plotter(
        title=title,
        subtitle=subtitle,
        title_fontsize=title_fontsize,
        subtitle_fontsize=subtitle_fontsize,
        title_subtitle_font=font,
        title_color=title_color,
        subtitle_color=subtitle_color,
        subtitle_offset=subtitle_offset,
        title_y=title_y,
        title_space_offset=title_space_offset,
    )

    fig = self.ax.figure
    # Lay out the axes in the lower part of the figure, reserving the top band for the text.
    # This runs even when no title/subtitle is given.
    fig.tight_layout(rect=[0, 0, 1, 1 - title_space_offset])

    if title:
        # suptitle centers the title on the whole figure rather than on the axes
        fig.suptitle(
            title,
            fontsize=title_fontsize,
            color=title_color,
            fontname=font,
            x=0.5,
            ha="center",
            va="top",
            y=title_y,
        )

    if subtitle:
        # The subtitle's top edge sits `subtitle_offset` below the title's top edge
        subtitle_y_position = title_y - subtitle_offset
        fig.text(
            0.5,
            subtitle_y_position,
            subtitle,
            ha="center",
            va="top",
            fontname=font,
            fontsize=subtitle_fontsize,
            color=subtitle_color,
        )
166
+
88
167
  def plot_circle_perimeter(
89
168
  self,
90
169
  scale: float = 1.0,
@@ -509,26 +588,52 @@ class NetworkPlotter:
509
588
  # Extract the positions of the specified nodes
510
589
  points = np.array([pos[n] for n in nodes])
511
590
  if len(points) <= 1:
512
- return # Not enough points to form a contour
591
+ return None # Not enough points to form a contour
513
592
 
593
+ # Check if the KDE forms a single connected component
514
594
  connected = False
595
+ z = None # Initialize z to None to avoid UnboundLocalError
515
596
  while not connected and bandwidth <= 100.0:
516
- # Perform KDE on the points with the given bandwidth
517
- kde = gaussian_kde(points.T, bw_method=bandwidth)
518
- xmin, ymin = points.min(axis=0) - bandwidth
519
- xmax, ymax = points.max(axis=0) + bandwidth
520
- x, y = np.mgrid[
521
- xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
522
- ]
523
- z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
524
- # Check if the KDE forms a single connected component
525
- connected = _is_connected(z)
526
- if not connected:
527
- bandwidth += 0.05 # Increase bandwidth slightly and retry
597
+ try:
598
+ # Perform KDE on the points with the given bandwidth
599
+ kde = gaussian_kde(points.T, bw_method=bandwidth)
600
+ xmin, ymin = points.min(axis=0) - bandwidth
601
+ xmax, ymax = points.max(axis=0) + bandwidth
602
+ x, y = np.mgrid[
603
+ xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
604
+ ]
605
+ z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
606
+ # Check if the KDE forms a single connected component
607
+ connected = _is_connected(z)
608
+ if not connected:
609
+ bandwidth += 0.05 # Increase bandwidth slightly and retry
610
+ except linalg.LinAlgError:
611
+ bandwidth += 0.05 # Increase bandwidth and retry
612
+ except Exception as e:
613
+ # Catch any other exceptions and log them
614
+ logger.error(f"Unexpected error when drawing KDE contour: {e}")
615
+ return None
616
+
617
+ # If z is still None, the KDE computation failed
618
+ if z is None:
619
+ logger.error("Failed to compute KDE. Skipping contour plot for these nodes.")
620
+ return None
528
621
 
529
622
  # Define contour levels based on the density
530
623
  min_density, max_density = z.min(), z.max()
624
+ if min_density == max_density:
625
+ logger.warning(
626
+ "Contour levels could not be created due to lack of variation in density."
627
+ )
628
+ return None
629
+
630
+ # Create contour levels based on the density values
531
631
  contour_levels = np.linspace(min_density, max_density, levels)[1:]
632
+ if len(contour_levels) < 2 or not np.all(np.diff(contour_levels) > 0):
633
+ logger.error("Contour levels must be strictly increasing. Skipping contour plot.")
634
+ return None
635
+
636
+ # Set the contour color and linestyle
532
637
  contour_colors = [color for _ in range(levels - 1)]
533
638
  # Plot the filled contours using fill_alpha for transparency
534
639
  if fill_alpha > 0:
@@ -553,6 +658,7 @@ class NetworkPlotter:
553
658
  linewidths=linewidth,
554
659
  alpha=alpha,
555
660
  )
661
+
556
662
  # Set linewidth for the contour lines to 0 for levels other than the base level
557
663
  for i in range(1, len(contour_levels)):
558
664
  c.collections[i].set_linewidth(0)
@@ -601,7 +707,7 @@ class NetworkPlotter:
601
707
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
602
708
  max_word_length (int, optional): Maximum number of characters in a word to display. Defaults to 20.
603
709
  min_word_length (int, optional): Minimum number of characters in a word to display. Defaults to 1.
604
- words_to_omit (List, optional): List of words to omit from the labels. Defaults to None.
710
+ words_to_omit (list, optional): List of words to omit from the labels. Defaults to None.
605
711
  overlay_ids (bool, optional): Whether to overlay domain IDs in the center of the centroids. Defaults to False.
606
712
  ids_to_keep (list, tuple, np.ndarray, or None, optional): IDs of domains that must be labeled. To discover domain IDs,
607
713
  you can set `overlay_ids=True`. Defaults to None.
@@ -710,6 +816,9 @@ class NetworkPlotter:
710
816
  # Process remaining domains to fill in additional labels, if there are slots left
711
817
  if remaining_labels and remaining_labels > 0:
712
818
  for idx, (domain, centroid) in enumerate(domain_centroids.items()):
819
+ # Check if the domain is NaN and continue if true
820
+ if pd.isna(domain) or (isinstance(domain, float) and np.isnan(domain)):
821
+ continue # Skip NaN domains
713
822
  if ids_to_keep and domain in ids_to_keep:
714
823
  continue # Skip domains already handled by ids_to_keep
715
824
 
@@ -1086,14 +1195,16 @@ class NetworkPlotter:
1086
1195
  return np.array(annotated_colors)
1087
1196
 
1088
1197
@staticmethod
def savefig(*args, pad_inches: float = 0.5, dpi: int = 100, **kwargs) -> None:
    """Save the current plot to a file with additional export options.

    Args:
        *args: Positional arguments passed to `plt.savefig` (e.g. the output filename).
        pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
        dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 100.
        **kwargs: Keyword arguments passed to `plt.savefig`, such as format. `bbox_inches`
            defaults to "tight" but may be overridden by the caller.
    """
    # setdefault keeps "tight" as the default. A caller-supplied bbox_inches no longer
    # collides with the hard-coded keyword, which used to raise a duplicate-argument TypeError.
    kwargs.setdefault("bbox_inches", "tight")
    plt.savefig(*args, pad_inches=pad_inches, dpi=dpi, **kwargs)
1097
1208
 
1098
1209
  @staticmethod
1099
1210
  def show(*args, **kwargs) -> None:
@@ -1123,7 +1234,9 @@ def _to_rgba(
1123
1234
  """
1124
1235
  # Handle single color case (string, RGB, or RGBA)
1125
1236
  if isinstance(color, str) or (
1126
- isinstance(color, (list, tuple, np.ndarray)) and len(color) in [3, 4]
1237
+ isinstance(color, (list, tuple, np.ndarray))
1238
+ and len(color) in [3, 4]
1239
+ and not any(isinstance(c, (list, tuple, np.ndarray)) for c in color)
1127
1240
  ):
1128
1241
  rgba_color = np.array(mcolors.to_rgba(color))
1129
1242
  # Only set alpha if the input is an RGB color or a string (not RGBA)