risk-network 0.0.6b0__tar.gz → 0.0.6b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/PKG-INFO +1 -1
  2. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/__init__.py +1 -1
  3. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/annotations/io.py +15 -3
  4. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/neighborhoods/community.py +7 -3
  5. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/neighborhoods/neighborhoods.py +2 -2
  6. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/network/io.py +11 -0
  7. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/network/plot.py +338 -173
  8. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/risk.py +6 -18
  9. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk_network.egg-info/PKG-INFO +1 -1
  10. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/LICENSE +0 -0
  11. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/MANIFEST.in +0 -0
  12. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/README.md +0 -0
  13. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/pyproject.toml +0 -0
  14. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/annotations/__init__.py +0 -0
  15. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/annotations/annotations.py +0 -0
  16. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/constants.py +0 -0
  17. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/log/__init__.py +0 -0
  18. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/log/console.py +0 -0
  19. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/log/params.py +0 -0
  20. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/neighborhoods/__init__.py +0 -0
  21. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/neighborhoods/domains.py +0 -0
  22. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/network/__init__.py +0 -0
  23. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/network/geometry.py +0 -0
  24. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/network/graph.py +0 -0
  25. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/__init__.py +0 -0
  26. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/fisher_exact.py +0 -0
  27. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/hypergeom.py +0 -0
  28. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/permutation/__init__.py +0 -0
  29. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/permutation/permutation.py +0 -0
  30. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/permutation/test_functions.py +0 -0
  31. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk/stats/stats.py +0 -0
  32. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk_network.egg-info/SOURCES.txt +0 -0
  33. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk_network.egg-info/dependency_links.txt +0 -0
  34. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk_network.egg-info/requires.txt +0 -0
  35. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/risk_network.egg-info/top_level.txt +0 -0
  36. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/setup.cfg +0 -0
  37. {risk_network-0.0.6b0 → risk_network-0.0.6b1}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.6b0
3
+ Version: 0.0.6b1
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -7,4 +7,4 @@ RISK: RISK Infers Spatial Kinships
7
7
 
8
8
  from risk.risk import RISK
9
9
 
10
- __version__ = "0.0.6-beta.0"
10
+ __version__ = "0.0.6-beta.1"
@@ -36,13 +36,15 @@ class AnnotationsIO:
36
36
  dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
37
37
  """
38
38
  filetype = "JSON"
39
+ # Log the loading of the JSON file
39
40
  params.log_annotations(filepath=filepath, filetype=filetype)
40
41
  _log_loading(filetype, filepath=filepath)
42
+
41
43
  # Open and read the JSON file
42
44
  with open(filepath, "r") as file:
43
45
  annotations_input = json.load(file)
44
46
 
45
- # Process the JSON data and return it in the context of the network
47
+ # Load the annotations into the provided network
46
48
  return load_annotations(network, annotations_input)
47
49
 
48
50
  def load_excel_annotation(
@@ -69,14 +71,18 @@ class AnnotationsIO:
69
71
  linked to the provided network.
70
72
  """
71
73
  filetype = "Excel"
74
+ # Log the loading of the Excel file
72
75
  params.log_annotations(filepath=filepath, filetype=filetype)
73
76
  _log_loading(filetype, filepath=filepath)
77
+
74
78
  # Load the specified sheet from the Excel file
75
79
  df = pd.read_excel(filepath, sheet_name=sheet_name)
76
80
  # Split the nodes column by the specified nodes_delimiter
77
81
  df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(nodes_delimiter))
78
82
  # Convert the DataFrame to a dictionary pairing labels with their corresponding nodes
79
83
  label_node_dict = df.set_index(label_colname)[nodes_colname].to_dict()
84
+
85
+ # Load the annotations into the provided network
80
86
  return load_annotations(network, label_node_dict)
81
87
 
82
88
  def load_csv_annotation(
@@ -101,13 +107,16 @@ class AnnotationsIO:
101
107
  linked to the provided network.
102
108
  """
103
109
  filetype = "CSV"
110
+ # Log the loading of the CSV file
104
111
  params.log_annotations(filepath=filepath, filetype=filetype)
105
112
  _log_loading(filetype, filepath=filepath)
113
+
106
114
  # Load the CSV file into a dictionary
107
115
  annotations_input = _load_matrix_file(
108
116
  filepath, label_colname, nodes_colname, delimiter=",", nodes_delimiter=nodes_delimiter
109
117
  )
110
- # Process and return the annotations in the context of the network
118
+
119
+ # Load the annotations into the provided network
111
120
  return load_annotations(network, annotations_input)
112
121
 
113
122
  def load_tsv_annotation(
@@ -132,13 +141,16 @@ class AnnotationsIO:
132
141
  linked to the provided network.
133
142
  """
134
143
  filetype = "TSV"
144
+ # Log the loading of the TSV file
135
145
  params.log_annotations(filepath=filepath, filetype=filetype)
136
146
  _log_loading(filetype, filepath=filepath)
147
+
137
148
  # Load the TSV file into a dictionary
138
149
  annotations_input = _load_matrix_file(
139
150
  filepath, label_colname, nodes_colname, delimiter="\t", nodes_delimiter=nodes_delimiter
140
151
  )
141
- # Process and return the annotations in the context of the network
152
+
153
+ # Load the annotations into the provided network
142
154
  return load_annotations(network, annotations_input)
143
155
 
144
156
 
@@ -25,10 +25,14 @@ def calculate_dijkstra_neighborhoods(network: nx.Graph) -> np.ndarray:
25
25
 
26
26
  # Populate the neighborhoods matrix based on Dijkstra's distances
27
27
  for source, targets in all_dijkstra_paths.items():
28
+ max_length = max(targets.values()) if targets else 1 # Handle cases with no targets
28
29
  for target, length in targets.items():
29
- neighborhoods[source, target] = (
30
- 1 if np.isnan(length) or length == 0 else np.sqrt(1 / length)
31
- )
30
+ if np.isnan(length):
31
+ neighborhoods[source, target] = max_length # Use max distance for NaN
32
+ elif length == 0:
33
+ neighborhoods[source, target] = 1 # Assign 1 for zero-length paths (self-loops)
34
+ else:
35
+ neighborhoods[source, target] = 1 / length # Inverse of the distance
32
36
 
33
37
  return neighborhoods
34
38
 
@@ -4,7 +4,7 @@ risk/neighborhoods/neighborhoods
4
4
  """
5
5
 
6
6
  import warnings
7
- from typing import Any, Dict, Tuple
7
+ from typing import Any, Dict, List, Tuple
8
8
 
9
9
  import networkx as nx
10
10
  import numpy as np
@@ -305,7 +305,7 @@ def _get_node_position(network: nx.Graph, node: Any) -> np.ndarray:
305
305
  )
306
306
 
307
307
 
308
- def _calculate_threshold(average_distances: list, distance_threshold: float) -> float:
308
+ def _calculate_threshold(average_distances: List, distance_threshold: float) -> float:
309
309
  """Calculate the distance threshold based on the given average distances and a percentile threshold.
310
310
 
311
311
  Args:
@@ -48,6 +48,7 @@ class NetworkIO:
48
48
  self.min_edges_per_node = min_edges_per_node
49
49
  self.include_edge_weight = include_edge_weight
50
50
  self.weight_label = weight_label
51
+ # Log the initialization of the NetworkIO class
51
52
  params.log_network(
52
53
  compute_sphere=compute_sphere,
53
54
  surface_depth=surface_depth,
@@ -98,11 +99,14 @@ class NetworkIO:
98
99
  nx.Graph: Loaded and processed network.
99
100
  """
100
101
  filetype = "GPickle"
102
+ # Log the loading of the GPickle file
101
103
  params.log_network(filetype=filetype, filepath=filepath)
102
104
  self._log_loading(filetype, filepath=filepath)
105
+
103
106
  with open(filepath, "rb") as f:
104
107
  G = pickle.load(f)
105
108
 
109
+ # Initialize the graph
106
110
  return self._initialize_graph(G)
107
111
 
108
112
  @classmethod
@@ -147,8 +151,11 @@ class NetworkIO:
147
151
  nx.Graph: Processed network.
148
152
  """
149
153
  filetype = "NetworkX"
154
+ # Log the loading of the NetworkX graph
150
155
  params.log_network(filetype=filetype)
151
156
  self._log_loading(filetype)
157
+
158
+ # Initialize the graph
152
159
  return self._initialize_graph(network)
153
160
 
154
161
  @classmethod
@@ -213,8 +220,10 @@ class NetworkIO:
213
220
  nx.Graph: Loaded and processed network.
214
221
  """
215
222
  filetype = "Cytoscape"
223
+ # Log the loading of the Cytoscape file
216
224
  params.log_network(filetype=filetype, filepath=str(filepath))
217
225
  self._log_loading(filetype, filepath=filepath)
226
+
218
227
  cys_files = []
219
228
  tmp_dir = ".tmp_cytoscape"
220
229
  # Try / finally to remove unzipped files
@@ -295,6 +304,7 @@ class NetworkIO:
295
304
  node
296
305
  ] # Assuming you have a dict `node_y_positions` for y coordinates
297
306
 
307
+ # Initialize the graph
298
308
  return self._initialize_graph(G)
299
309
 
300
310
  finally:
@@ -354,6 +364,7 @@ class NetworkIO:
354
364
  NetworkX graph: Loaded and processed network.
355
365
  """
356
366
  filetype = "Cytoscape JSON"
367
+ # Log the loading of the Cytoscape JSON file
357
368
  params.log_network(filetype=filetype, filepath=str(filepath))
358
369
  self._log_loading(filetype, filepath=filepath)
359
370
 
@@ -5,7 +5,6 @@ risk/network/plot
5
5
 
6
6
  from typing import Any, Dict, List, Tuple, Union
7
7
 
8
- import matplotlib
9
8
  import matplotlib.colors as mcolors
10
9
  import matplotlib.pyplot as plt
11
10
  import networkx as nx
@@ -18,66 +17,42 @@ from risk.network.graph import NetworkGraph
18
17
 
19
18
 
20
19
  class NetworkPlotter:
21
- """A class responsible for visualizing network graphs with various customization options.
20
+ """A class for visualizing network graphs with customizable options.
22
21
 
23
- The NetworkPlotter class takes in a NetworkGraph object, which contains the network's data and attributes,
24
- and provides methods for plotting the network with customizable node and edge properties,
25
- as well as optional features like drawing the network's perimeter and setting background colors.
22
+ The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
23
+ flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
24
+ perimeter, and adjusting background colors.
26
25
  """
27
26
 
28
27
  def __init__(
29
28
  self,
30
29
  graph: NetworkGraph,
31
- figsize: tuple = (10, 10),
32
- background_color: str = "white",
33
- plot_outline: bool = True,
34
- outline_color: str = "black",
35
- outline_scale: float = 1.0,
36
- linestyle: str = "dashed",
30
+ figsize: Tuple = (10, 10),
31
+ background_color: Union[str, List, Tuple, np.ndarray] = "white",
37
32
  ) -> None:
38
33
  """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
39
34
 
40
35
  Args:
41
36
  graph (NetworkGraph): The network data and attributes to be visualized.
42
37
  figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
43
- background_color (str, optional): Background color of the plot. Defaults to "white".
44
- plot_outline (bool, optional): Whether to plot the network perimeter circle. Defaults to True.
45
- outline_color (str, optional): Color of the network perimeter circle. Defaults to "black".
46
- outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
47
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
38
+ background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
48
39
  """
49
40
  self.graph = graph
50
41
  # Initialize the plot with the specified parameters
51
- self.ax = self._initialize_plot(
52
- graph,
53
- figsize,
54
- background_color,
55
- plot_outline,
56
- outline_color,
57
- outline_scale,
58
- linestyle,
59
- )
42
+ self.ax = self._initialize_plot(graph, figsize, background_color)
60
43
 
61
44
  def _initialize_plot(
62
45
  self,
63
46
  graph: NetworkGraph,
64
- figsize: tuple,
65
- background_color: str,
66
- plot_outline: bool,
67
- outline_color: str,
68
- outline_scale: float,
69
- linestyle: str,
47
+ figsize: Tuple,
48
+ background_color: Union[str, List, Tuple, np.ndarray],
70
49
  ) -> plt.Axes:
71
- """Set up the plot with figure size, optional circle perimeter, and background color.
50
+ """Set up the plot with figure size and background color.
72
51
 
73
52
  Args:
74
53
  graph (NetworkGraph): The network data and attributes to be visualized.
75
54
  figsize (tuple): Size of the figure in inches (width, height).
76
55
  background_color (str): Background color of the plot.
77
- plot_outline (bool): Whether to plot the network perimeter circle.
78
- outline_color (str): Color of the network perimeter circle.
79
- outline_scale (float): Outline scaling factor for the perimeter diameter.
80
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid).
81
56
 
82
57
  Returns:
83
58
  plt.Axes: The axis object for the plot.
@@ -86,31 +61,19 @@ class NetworkPlotter:
86
61
  node_coordinates = graph.node_coordinates
87
62
  # Calculate the center and radius of the bounding box around the network
88
63
  center, radius = _calculate_bounding_box(node_coordinates)
89
- # Scale the radius by the outline_scale factor
90
- scaled_radius = radius * outline_scale
91
64
 
92
65
  # Create a new figure and axis for plotting
93
66
  fig, ax = plt.subplots(figsize=figsize)
94
67
  fig.tight_layout() # Adjust subplot parameters to give specified padding
95
- if plot_outline:
96
- # Draw a circle to represent the network perimeter
97
- circle = plt.Circle(
98
- center,
99
- scaled_radius,
100
- linestyle=linestyle, # Use the linestyle argument here
101
- color=outline_color,
102
- fill=False,
103
- linewidth=1.5,
104
- )
105
- ax.add_artist(circle) # Add the circle to the plot
106
-
107
- # Set axis limits based on the calculated bounding box and scaled radius
108
- ax.set_xlim([center[0] - scaled_radius - 0.3, center[0] + scaled_radius + 0.3])
109
- ax.set_ylim([center[1] - scaled_radius - 0.3, center[1] + scaled_radius + 0.3])
68
+ # Set axis limits based on the calculated bounding box and radius
69
+ ax.set_xlim([center[0] - radius - 0.3, center[0] + radius + 0.3])
70
+ ax.set_ylim([center[1] - radius - 0.3, center[1] + radius + 0.3])
110
71
  ax.set_aspect("equal") # Ensure the aspect ratio is equal
111
- fig.patch.set_facecolor(background_color) # Set the figure background color
112
- ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
113
72
 
73
+ # Set the background color of the plot
74
+ # Convert color to RGBA using the _to_rgba helper function
75
+ fig.patch.set_facecolor(_to_rgba(background_color, 1.0))
76
+ ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
114
77
  # Remove axis spines for a cleaner look
115
78
  for spine in ax.spines.values():
116
79
  spine.set_visible(False)
@@ -122,44 +85,182 @@ class NetworkPlotter:
122
85
 
123
86
  return ax
124
87
 
88
+ def plot_circle_perimeter(
89
+ self,
90
+ scale: float = 1.0,
91
+ linestyle: str = "dashed",
92
+ linewidth: float = 1.5,
93
+ color: Union[str, List, Tuple, np.ndarray] = "black",
94
+ outline_alpha: float = 1.0,
95
+ fill_alpha: float = 0.0,
96
+ ) -> None:
97
+ """Plot a circle around the network graph to represent the network perimeter.
98
+
99
+ Args:
100
+ scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
101
+ linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
102
+ linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
103
+ color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
104
+ outline_alpha (float, optional): Transparency level of the circle outline. Defaults to 1.0.
105
+ fill_alpha (float, optional): Transparency level of the circle fill. Defaults to 0.0.
106
+ """
107
+ # Log the circle perimeter plotting parameters
108
+ params.log_plotter(
109
+ perimeter_type="circle",
110
+ perimeter_scale=scale,
111
+ perimeter_linestyle=linestyle,
112
+ perimeter_linewidth=linewidth,
113
+ perimeter_color=(
114
+ "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
115
+ ), # np.ndarray usually indicates custom colors
116
+ perimeter_outline_alpha=outline_alpha,
117
+ perimeter_fill_alpha=fill_alpha,
118
+ )
119
+
120
+ # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
121
+ color = _to_rgba(color, outline_alpha)
122
+ # Extract node coordinates from the network graph
123
+ node_coordinates = self.graph.node_coordinates
124
+ # Calculate the center and radius of the bounding box around the network
125
+ center, radius = _calculate_bounding_box(node_coordinates)
126
+ # Scale the radius by the scale factor
127
+ scaled_radius = radius * scale
128
+
129
+ # Draw a circle to represent the network perimeter
130
+ circle = plt.Circle(
131
+ center,
132
+ scaled_radius,
133
+ linestyle=linestyle,
134
+ linewidth=linewidth,
135
+ color=color,
136
+ fill=fill_alpha > 0, # Fill the circle if fill_alpha is greater than 0
137
+ )
138
+ # Set the transparency of the fill if applicable
139
+ if fill_alpha > 0:
140
+ circle.set_facecolor(
141
+ _to_rgba(color, fill_alpha)
142
+ ) # Use _to_rgba to set the fill color with transparency
143
+
144
+ self.ax.add_artist(circle)
145
+
146
+ def plot_contour_perimeter(
147
+ self,
148
+ scale: float = 1.0,
149
+ levels: int = 3,
150
+ bandwidth: float = 0.8,
151
+ grid_size: int = 250,
152
+ color: Union[str, List, Tuple, np.ndarray] = "black",
153
+ linestyle: str = "solid",
154
+ linewidth: float = 1.5,
155
+ outline_alpha: float = 1.0,
156
+ fill_alpha: float = 0.0,
157
+ ) -> None:
158
+ """
159
+ Plot a KDE-based contour around the network graph to represent the network perimeter.
160
+
161
+ Args:
162
+ scale (float, optional): Scaling factor for the perimeter size. Defaults to 1.0.
163
+ levels (int, optional): Number of contour levels. Defaults to 3.
164
+ bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
165
+ grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
166
+ color (str, list, tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
167
+ linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
168
+ linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
169
+ outline_alpha (float, optional): Transparency level of the contour outline. Defaults to 1.0.
170
+ fill_alpha (float, optional): Transparency level of the contour fill. Defaults to 0.0.
171
+ """
172
+ # Log the contour perimeter plotting parameters
173
+ params.log_plotter(
174
+ perimeter_type="contour",
175
+ perimeter_scale=scale,
176
+ perimeter_levels=levels,
177
+ perimeter_bandwidth=bandwidth,
178
+ perimeter_grid_size=grid_size,
179
+ perimeter_linestyle=linestyle,
180
+ perimeter_linewidth=linewidth,
181
+ perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
182
+ perimeter_outline_alpha=outline_alpha,
183
+ perimeter_fill_alpha=fill_alpha,
184
+ )
185
+
186
+ # Convert color to RGBA using the _to_rgba helper function - use outline_alpha for the perimeter
187
+ color = _to_rgba(color, outline_alpha)
188
+ # Extract node coordinates from the network graph
189
+ node_coordinates = self.graph.node_coordinates
190
+ # Scale the node coordinates if needed
191
+ scaled_coordinates = node_coordinates * scale
192
+ # Use the existing _draw_kde_contour method
193
+ self._draw_kde_contour(
194
+ ax=self.ax,
195
+ pos=scaled_coordinates,
196
+ nodes=list(range(len(node_coordinates))), # All nodes are included
197
+ levels=levels,
198
+ bandwidth=bandwidth,
199
+ grid_size=grid_size,
200
+ color=color,
201
+ linestyle=linestyle,
202
+ linewidth=linewidth,
203
+ alpha=fill_alpha, # Use fill_alpha for the fill
204
+ )
205
+
125
206
  def plot_network(
126
207
  self,
127
208
  node_size: Union[int, np.ndarray] = 50,
128
- edge_width: float = 1.0,
129
- node_color: Union[str, np.ndarray] = "white",
130
- node_edgecolor: str = "black",
131
- edge_color: str = "black",
132
209
  node_shape: str = "o",
210
+ edge_width: float = 1.0,
211
+ node_color: Union[str, List, Tuple, np.ndarray] = "white",
212
+ node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
213
+ edge_color: Union[str, List, Tuple, np.ndarray] = "black",
214
+ node_alpha: float = 1.0,
215
+ edge_alpha: float = 1.0,
133
216
  ) -> None:
134
217
  """Plot the network graph with customizable node colors, sizes, and edge widths.
135
218
 
136
219
  Args:
137
220
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
138
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
139
- node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
140
- node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
141
- edge_color (str, optional): Color of the edges. Defaults to "black".
142
221
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
222
+ edge_width (float, optional): Width of the edges. Defaults to 1.0.
223
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
224
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
225
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
226
+ node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0.
227
+ edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
143
228
  """
144
229
  # Log the plotting parameters
145
230
  params.log_plotter(
146
- network_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
231
+ network_node_size=(
232
+ "custom" if isinstance(node_size, np.ndarray) else node_size
233
+ ), # np.ndarray usually indicates custom sizes
234
+ network_node_shape=node_shape,
147
235
  network_edge_width=edge_width,
148
- network_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
236
+ network_node_color=(
237
+ "custom" if isinstance(node_color, np.ndarray) else node_color
238
+ ), # np.ndarray usually indicates custom colors
149
239
  network_node_edgecolor=node_edgecolor,
150
240
  network_edge_color=edge_color,
151
- network_node_shape=node_shape,
241
+ network_node_alpha=node_alpha,
242
+ network_edge_alpha=edge_alpha,
152
243
  )
244
+
245
+ # Convert colors to RGBA using the _to_rgba helper function
246
+ # If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
247
+ node_color = _to_rgba(node_color, node_alpha, num_repeats=len(self.graph.network.nodes))
248
+ # Convert other colors to RGBA using the _to_rgba helper function
249
+ node_edgecolor = _to_rgba(
250
+ node_edgecolor, 1.0, num_repeats=len(self.graph.network.nodes)
251
+ ) # Node edges are usually fully opaque
252
+ edge_color = _to_rgba(edge_color, edge_alpha, num_repeats=len(self.graph.network.edges))
253
+
153
254
  # Extract node coordinates from the network graph
154
255
  node_coordinates = self.graph.node_coordinates
256
+
155
257
  # Draw the nodes of the graph
156
258
  nx.draw_networkx_nodes(
157
259
  self.graph.network,
158
260
  pos=node_coordinates,
159
261
  node_size=node_size,
160
- node_color=node_color,
161
262
  node_shape=node_shape,
162
- alpha=1.00,
263
+ node_color=node_color,
163
264
  edgecolors=node_edgecolor,
164
265
  ax=self.ax,
165
266
  )
@@ -174,37 +275,33 @@ class NetworkPlotter:
174
275
 
175
276
  def plot_subnetwork(
176
277
  self,
177
- nodes: list,
278
+ nodes: List,
178
279
  node_size: Union[int, np.ndarray] = 50,
179
- edge_width: float = 1.0,
180
- node_color: Union[str, np.ndarray] = "white",
181
- node_edgecolor: str = "black",
182
- edge_color: str = "black",
183
280
  node_shape: str = "o",
281
+ edge_width: float = 1.0,
282
+ node_color: Union[str, List, Tuple, np.ndarray] = "white",
283
+ node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
284
+ edge_color: Union[str, List, Tuple, np.ndarray] = "black",
285
+ node_alpha: float = 1.0,
286
+ edge_alpha: float = 1.0,
184
287
  ) -> None:
185
288
  """Plot a subnetwork of selected nodes with customizable node and edge attributes.
186
289
 
187
290
  Args:
188
291
  nodes (list): List of node labels to include in the subnetwork.
189
292
  node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
190
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
191
- node_color (str or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
192
- node_edgecolor (str, optional): Color of the node edges. Defaults to "black".
193
- edge_color (str, optional): Color of the edges. Defaults to "black".
194
293
  node_shape (str, optional): Shape of the nodes. Defaults to "o".
294
+ edge_width (float, optional): Width of the edges. Defaults to 1.0.
295
+ node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors. Defaults to "white".
296
+ node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
297
+ edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
298
+ node_alpha (float, optional): Alpha value (transparency) for the nodes. Defaults to 1.0.
299
+ edge_alpha (float, optional): Alpha value (transparency) for the edges. Defaults to 1.0.
195
300
 
196
301
  Raises:
197
302
  ValueError: If no valid nodes are found in the network graph.
198
303
  """
199
- # Log the plotting parameters for the subnetwork
200
- params.log_plotter(
201
- subnetwork_node_size="custom" if isinstance(node_size, np.ndarray) else node_size,
202
- subnetwork_edge_width=edge_width,
203
- subnetwork_node_color="custom" if isinstance(node_color, np.ndarray) else node_color,
204
- subnetwork_node_edgecolor=node_edgecolor,
205
- subnetwork_edge_color=edge_color,
206
- subnet_node_shape=node_shape,
207
- )
304
+ # Don't log subnetwork parameters as they are specific to individual annotations
208
305
  # Filter to get node IDs and their coordinates
209
306
  node_ids = [
210
307
  self.graph.node_label_to_id_map.get(node)
@@ -214,17 +311,21 @@ class NetworkPlotter:
214
311
  if not node_ids:
215
312
  raise ValueError("No nodes found in the network graph.")
216
313
 
314
+ # Convert colors to RGBA using the _to_rgba helper function
315
+ node_color = _to_rgba(node_color, node_alpha)
316
+ node_edgecolor = _to_rgba(node_edgecolor, 1.0) # Node edges usually fully opaque
317
+ edge_color = _to_rgba(edge_color, edge_alpha)
217
318
  # Get the coordinates of the filtered nodes
218
319
  node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
320
+
219
321
  # Draw the nodes in the subnetwork
220
322
  nx.draw_networkx_nodes(
221
323
  self.graph.network,
222
324
  pos=node_coordinates,
223
325
  nodelist=node_ids,
224
326
  node_size=node_size,
225
- node_color=node_color,
226
327
  node_shape=node_shape,
227
- alpha=1.00,
328
+ node_color=node_color,
228
329
  edgecolors=node_edgecolor,
229
330
  ax=self.ax,
230
331
  )
@@ -243,8 +344,10 @@ class NetworkPlotter:
243
344
  levels: int = 5,
244
345
  bandwidth: float = 0.8,
245
346
  grid_size: int = 250,
347
+ color: Union[str, List, Tuple, np.ndarray] = "white",
348
+ linestyle: str = "solid",
349
+ linewidth: float = 1.5,
246
350
  alpha: float = 0.2,
247
- color: Union[str, np.ndarray] = "white",
248
351
  ) -> None:
249
352
  """Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.
250
353
 
@@ -252,21 +355,24 @@ class NetworkPlotter:
252
355
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
253
356
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
254
357
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
358
+ color (str, list, tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors. Defaults to "white".
359
+ linestyle (str, optional): Line style for the contours. Defaults to "solid".
360
+ linewidth (float, optional): Line width for the contours. Defaults to 1.5.
255
361
  alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
256
- color (str or np.ndarray, optional): Color of the contours. Can be a string (e.g., 'white') or an array of colors. Defaults to "white".
257
362
  """
258
363
  # Log the contour plotting parameters
259
364
  params.log_plotter(
260
365
  contour_levels=levels,
261
366
  contour_bandwidth=bandwidth,
262
367
  contour_grid_size=grid_size,
368
+ contour_color=(
369
+ "custom" if isinstance(color, np.ndarray) else color
370
+ ), # np.ndarray usually indicates custom colors
263
371
  contour_alpha=alpha,
264
- contour_color="custom" if isinstance(color, np.ndarray) else color,
265
372
  )
266
- # Convert color string to RGBA array if necessary
267
- if isinstance(color, str):
268
- color = self.get_annotated_contour_colors(color=color)
269
373
 
374
+ # Ensure color is converted to RGBA with repetition matching the number of domains
375
+ color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_to_nodes))
270
376
  # Extract node coordinates from the network graph
271
377
  node_coordinates = self.graph.node_coordinates
272
378
  # Draw contours for each domain in the network
@@ -280,17 +386,21 @@ class NetworkPlotter:
280
386
  levels=levels,
281
387
  bandwidth=bandwidth,
282
388
  grid_size=grid_size,
389
+ linestyle=linestyle,
390
+ linewidth=linewidth,
283
391
  alpha=alpha,
284
392
  )
285
393
 
286
394
  def plot_subcontour(
287
395
  self,
288
- nodes: list,
396
+ nodes: List,
289
397
  levels: int = 5,
290
398
  bandwidth: float = 0.8,
291
399
  grid_size: int = 250,
400
+ color: Union[str, List, Tuple, np.ndarray] = "white",
401
+ linestyle: str = "solid",
402
+ linewidth: float = 1.5,
292
403
  alpha: float = 0.2,
293
- color: Union[str, np.ndarray] = "white",
294
404
  ) -> None:
295
405
  """Plot a subcontour for a given set of nodes using Kernel Density Estimation (KDE).
296
406
 
@@ -299,20 +409,15 @@ class NetworkPlotter:
299
409
  levels (int, optional): Number of contour levels to plot. Defaults to 5.
300
410
  bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
301
411
  grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
412
+ color (str, list, tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
413
+ linestyle (str, optional): Line style for the contour. Defaults to "solid".
414
+ linewidth (float, optional): Line width for the contour. Defaults to 1.5.
302
415
  alpha (float, optional): Transparency level of the contour fill. Defaults to 0.2.
303
- color (str or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array. Defaults to "white".
304
416
 
305
417
  Raises:
306
418
  ValueError: If no valid nodes are found in the network graph.
307
419
  """
308
- # Log the plotting parameters
309
- params.log_plotter(
310
- subcontour_levels=levels,
311
- subcontour_bandwidth=bandwidth,
312
- subcontour_grid_size=grid_size,
313
- subcontour_alpha=alpha,
314
- subcontour_color="custom" if isinstance(color, np.ndarray) else color,
315
- )
420
+ # Don't log subcontour parameters as they are specific to individual annotations
316
421
  # Filter to get node IDs and their coordinates
317
422
  node_ids = [
318
423
  self.graph.node_label_to_id_map.get(node)
@@ -322,6 +427,8 @@ class NetworkPlotter:
322
427
  if not node_ids or len(node_ids) == 1:
323
428
  raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
324
429
 
430
+ # Convert color to RGBA using the _to_rgba helper function
431
+ color = _to_rgba(color, alpha)
325
432
  # Draw the KDE contour for the specified nodes
326
433
  node_coordinates = self.graph.node_coordinates
327
434
  self._draw_kde_contour(
@@ -332,6 +439,8 @@ class NetworkPlotter:
332
439
  levels=levels,
333
440
  bandwidth=bandwidth,
334
441
  grid_size=grid_size,
442
+ linestyle=linestyle,
443
+ linewidth=linewidth,
335
444
  alpha=alpha,
336
445
  )
337
446
 
@@ -339,11 +448,13 @@ class NetworkPlotter:
339
448
  self,
340
449
  ax: plt.Axes,
341
450
  pos: np.ndarray,
342
- nodes: list,
343
- color: Union[str, np.ndarray],
451
+ nodes: List,
344
452
  levels: int = 5,
345
453
  bandwidth: float = 0.8,
346
454
  grid_size: int = 250,
455
+ color: Union[str, np.ndarray] = "white",
456
+ linestyle: str = "solid",
457
+ linewidth: float = 1.5,
347
458
  alpha: float = 0.5,
348
459
  ) -> None:
349
460
  """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.
@@ -352,10 +463,12 @@ class NetworkPlotter:
352
463
  ax (plt.Axes): The axis to draw the contour on.
353
464
  pos (np.ndarray): Array of node positions (x, y).
354
465
  nodes (list): List of node indices to include in the contour.
355
- color (str or np.ndarray): Color for the contour.
356
466
  levels (int, optional): Number of contour levels. Defaults to 5.
357
467
  bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
358
468
  grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
469
+ color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
470
+ linestyle (str, optional): Line style for the contour. Defaults to "solid".
471
+ linewidth (float, optional): Line width for the contour. Defaults to 1.5.
359
472
  alpha (float, optional): Transparency level for the contour fill. Defaults to 0.5.
360
473
  """
361
474
  # Extract the positions of the specified nodes
@@ -382,8 +495,7 @@ class NetworkPlotter:
382
495
  min_density, max_density = z.min(), z.max()
383
496
  contour_levels = np.linspace(min_density, max_density, levels)[1:]
384
497
  contour_colors = [color for _ in range(levels - 1)]
385
-
386
- # Plot the filled contours if alpha > 0
498
+ # Plot the filled contours only if alpha > 0
387
499
  if alpha > 0:
388
500
  ax.contourf(
389
501
  x,
@@ -391,13 +503,20 @@ class NetworkPlotter:
391
503
  z,
392
504
  levels=contour_levels,
393
505
  colors=contour_colors,
394
- alpha=alpha,
395
- extend="neither",
396
506
  antialiased=True,
507
+ alpha=alpha,
397
508
  )
398
509
 
399
- # Plot the contour lines without antialiasing for clarity
400
- c = ax.contour(x, y, z, levels=contour_levels, colors=contour_colors)
510
+ # Plot the contour lines without any change in behavior
511
+ c = ax.contour(
512
+ x,
513
+ y,
514
+ z,
515
+ levels=contour_levels,
516
+ colors=contour_colors,
517
+ linestyles=linestyle,
518
+ linewidths=linewidth,
519
+ )
401
520
  for i in range(1, len(contour_levels)):
402
521
  c.collections[i].set_linewidth(0)
403
522
 
@@ -407,9 +526,11 @@ class NetworkPlotter:
407
526
  offset: float = 0.10,
408
527
  font: str = "Arial",
409
528
  fontsize: int = 10,
410
- fontcolor: Union[str, np.ndarray] = "black",
529
+ fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
530
+ fontalpha: float = 1.0,
411
531
  arrow_linewidth: float = 1,
412
- arrow_color: Union[str, np.ndarray] = "black",
532
+ arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
533
+ arrow_alpha: float = 1.0,
413
534
  max_labels: Union[int, None] = None,
414
535
  max_words: int = 10,
415
536
  min_words: int = 1,
@@ -424,9 +545,11 @@ class NetworkPlotter:
424
545
  offset (float, optional): Offset distance for labels from the perimeter. Defaults to 0.10.
425
546
  font (str, optional): Font name for the labels. Defaults to "Arial".
426
547
  fontsize (int, optional): Font size for the labels. Defaults to 10.
427
- fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
548
+ fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
549
+ fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
428
550
  arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
429
- arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
551
+ arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrows. Defaults to "black".
552
+ arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
430
553
  max_labels (int, optional): Maximum number of labels to plot. Defaults to None (no limit).
431
554
  max_words (int, optional): Maximum number of words in a label. Defaults to 10.
432
555
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
@@ -440,9 +563,13 @@ class NetworkPlotter:
440
563
  label_offset=offset,
441
564
  label_font=font,
442
565
  label_fontsize=fontsize,
443
- label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
566
+ label_fontcolor=(
567
+ "custom" if isinstance(fontcolor, np.ndarray) else fontcolor
568
+ ), # np.ndarray usually indicates custom colors
569
+ label_fontalpha=fontalpha,
444
570
  label_arrow_linewidth=arrow_linewidth,
445
571
  label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
572
+ label_arrow_alpha=arrow_alpha,
446
573
  label_max_labels=max_labels,
447
574
  label_max_words=max_words,
448
575
  label_min_words=min_words,
@@ -451,11 +578,12 @@ class NetworkPlotter:
451
578
  label_words_to_omit=words_to_omit,
452
579
  )
453
580
 
454
- # Convert color strings to RGBA arrays if necessary
455
- if isinstance(fontcolor, str):
456
- fontcolor = self.get_annotated_label_colors(color=fontcolor)
457
- if isinstance(arrow_color, str):
458
- arrow_color = self.get_annotated_label_colors(color=arrow_color)
581
+ # Convert colors to RGBA using the _to_rgba helper function, applying alpha separately for font and arrow
582
+ fontcolor = _to_rgba(fontcolor, fontalpha, num_repeats=len(self.graph.domain_to_nodes))
583
+ arrow_color = _to_rgba(
584
+ arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_to_nodes)
585
+ )
586
+
459
587
  # Normalize words_to_omit to lowercase
460
588
  if words_to_omit:
461
589
  words_to_omit = set(word.lower() for word in words_to_omit)
@@ -535,16 +663,18 @@ class NetworkPlotter:
535
663
 
536
664
  def plot_sublabel(
537
665
  self,
538
- nodes: list,
666
+ nodes: List,
539
667
  label: str,
540
668
  radial_position: float = 0.0,
541
669
  perimeter_scale: float = 1.05,
542
670
  offset: float = 0.10,
543
671
  font: str = "Arial",
544
672
  fontsize: int = 10,
545
- fontcolor: str = "black",
673
+ fontcolor: Union[str, List, Tuple, np.ndarray] = "black",
674
+ fontalpha: float = 1.0,
546
675
  arrow_linewidth: float = 1,
547
- arrow_color: str = "black",
676
+ arrow_color: Union[str, List, Tuple, np.ndarray] = "black",
677
+ arrow_alpha: float = 1.0,
548
678
  ) -> None:
549
679
  """Annotate the network graph with a single label for the given nodes, positioned at a specified radial angle.
550
680
 
@@ -556,22 +686,13 @@ class NetworkPlotter:
556
686
  offset (float, optional): Offset distance for the label from the perimeter. Defaults to 0.10.
557
687
  font (str, optional): Font name for the label. Defaults to "Arial".
558
688
  fontsize (int, optional): Font size for the label. Defaults to 10.
559
- fontcolor (str, optional): Color of the label text. Defaults to "black".
689
+ fontcolor (str, list, tuple, or np.ndarray, optional): Color of the label text. Defaults to "black".
690
+ fontalpha (float, optional): Transparency level for the font color. Defaults to 1.0.
560
691
  arrow_linewidth (float, optional): Line width of the arrow pointing to the centroid. Defaults to 1.
561
- arrow_color (str, optional): Color of the arrow. Defaults to "black".
692
+ arrow_color (str, list, tuple, or np.ndarray, optional): Color of the arrow. Defaults to "black".
693
+ arrow_alpha (float, optional): Transparency level for the arrow color. Defaults to 1.0.
562
694
  """
563
- # Log the plotting parameters
564
- params.log_plotter(
565
- sublabel_perimeter_scale=perimeter_scale,
566
- sublabel_offset=offset,
567
- sublabel_font=font,
568
- sublabel_fontsize=fontsize,
569
- sublabel_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
570
- sublabel_arrow_linewidth=arrow_linewidth,
571
- sublabel_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
572
- sublabel_radial_position=radial_position,
573
- )
574
-
695
+ # Don't log sublabel parameters as they are specific to individual annotations
575
696
  # Map node labels to IDs
576
697
  node_ids = [
577
698
  self.graph.node_label_to_id_map.get(node)
@@ -581,6 +702,9 @@ class NetworkPlotter:
581
702
  if not node_ids or len(node_ids) == 1:
582
703
  raise ValueError("No nodes found in the network graph or insufficient nodes to plot.")
583
704
 
705
+ # Convert fontcolor and arrow_color to RGBA using the _to_rgba helper function
706
+ fontcolor = _to_rgba(fontcolor, fontalpha)
707
+ arrow_color = _to_rgba(arrow_color, arrow_alpha)
584
708
  # Calculate the centroid of the provided nodes
585
709
  centroid = self._calculate_domain_centroid(node_ids)
586
710
  # Calculate the bounding box around the network
@@ -608,7 +732,7 @@ class NetworkPlotter:
608
732
  arrowprops=dict(arrowstyle="->", color=arrow_color, linewidth=arrow_linewidth),
609
733
  )
610
734
 
611
- def _calculate_domain_centroid(self, nodes: list) -> tuple:
735
+ def _calculate_domain_centroid(self, nodes: List) -> tuple:
612
736
  """Calculate the most centrally located node in .
613
737
 
614
738
  Args:
@@ -631,44 +755,36 @@ class NetworkPlotter:
631
755
 
632
756
  def get_annotated_node_colors(
633
757
  self,
758
+ alpha: float = 1.0,
634
759
  nonenriched_color: Union[str, List, Tuple, np.ndarray] = "white",
760
+ nonenriched_alpha: float = 1.0,
635
761
  random_seed: int = 888,
636
762
  **kwargs,
637
763
  ) -> np.ndarray:
638
764
  """Adjust the colors of nodes in the network graph based on enrichment.
639
765
 
640
766
  Args:
641
- nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes.
642
- Defaults to "white". If an alpha channel is provided, it will be used to darken the RGB values.
767
+ alpha (float, optional): Alpha value for enriched nodes. Defaults to 1.0.
768
+ nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Defaults to "white".
769
+ nonenriched_alpha (float, optional): Alpha value for non-enriched nodes. Defaults to 1.0.
643
770
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
644
771
  **kwargs: Additional keyword arguments for `get_domain_colors`.
645
772
 
646
773
  Returns:
647
774
  np.ndarray: Array of RGBA colors adjusted for enrichment status.
648
775
  """
649
- # Get the initial domain colors for each node
776
+ # Get the initial domain colors for each node, which are returned as RGBA
650
777
  network_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
651
- if isinstance(nonenriched_color, str):
652
- # Convert the non-enriched color from string to RGBA
653
- nonenriched_color = mcolors.to_rgba(nonenriched_color)
654
-
655
- # Ensure nonenriched_color is a NumPy array
656
- nonenriched_color = np.array(nonenriched_color)
657
- # If alpha is provided (4th value), darken the RGB values based on alpha
658
- if len(nonenriched_color) == 4 and nonenriched_color[3] < 1.0:
659
- alpha = nonenriched_color[3]
660
- # Adjust RGB based on alpha (darken)
661
- nonenriched_color[:3] = nonenriched_color[:3] * alpha
662
- # Set alpha to 1.0 after darkening the color
663
- nonenriched_color[3] = 1.0
664
-
665
- # Adjust node colors: replace any fully transparent nodes (enriched) with the non-enriched color
778
+ # Apply the alpha value for enriched nodes
779
+ network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
780
+ # Convert the non-enriched color to RGBA using the _to_rgba helper function
781
+ nonenriched_color = _to_rgba(nonenriched_color, nonenriched_alpha)
782
+ # Adjust node colors: replace any fully black nodes (RGB == 0) with the non-enriched color and its alpha
666
783
  adjusted_network_colors = np.where(
667
784
  np.all(network_colors[:, :3] == 0, axis=1, keepdims=True), # Check RGB values only
668
- np.array([nonenriched_color]), # Apply the non-enriched color (with adjusted alpha)
669
- network_colors, # Keep the original colors
785
+ np.array([nonenriched_color]), # Apply the non-enriched color with alpha
786
+ network_colors, # Keep the original colors for enriched nodes
670
787
  )
671
-
672
788
  return adjusted_network_colors
673
789
 
674
790
  def get_annotated_node_sizes(
@@ -722,35 +838,43 @@ class NetworkPlotter:
722
838
  return self._get_annotated_domain_colors(**kwargs, random_seed=random_seed)
723
839
 
724
840
  def _get_annotated_domain_colors(
725
- self, color: Union[str, list, None] = None, random_seed: int = 888, **kwargs
841
+ self,
842
+ color: Union[str, List, Tuple, np.ndarray, None] = None,
843
+ random_seed: int = 888,
844
+ **kwargs,
726
845
  ) -> np.ndarray:
727
846
  """Get colors for the domains based on node annotations.
728
847
 
729
848
  Args:
730
- color (str, list, or None, optional): If provided, use this color or list of colors for domains. Defaults to None.
849
+ color (str, list, tuple, np.ndarray, or None, optional): If provided, use this color or list of colors for domains. Defaults to None.
731
850
  random_seed (int, optional): Seed for random number generation. Defaults to 888.
732
851
  **kwargs: Additional keyword arguments for `get_domain_colors`.
733
852
 
734
853
  Returns:
735
854
  np.ndarray: Array of RGBA colors for each domain.
736
855
  """
737
- if isinstance(color, str):
738
- # If a single color string is provided, convert it to RGBA and apply to all domains
739
- rgba_color = np.array(matplotlib.colors.to_rgba(color))
740
- return np.array([rgba_color for _ in self.graph.domain_to_nodes])
856
+ if color is not None:
857
+ # Convert color(s) to RGBA using _to_rgba helper
858
+ if isinstance(color, (list, tuple, np.ndarray)):
859
+ return np.array([_to_rgba(c) for c in color])
860
+ else:
861
+ rgba_color = _to_rgba(color)
862
+ return np.array([rgba_color for _ in self.graph.domain_to_nodes])
741
863
 
742
- # Generate colors for each domain using the provided arguments and random seed
864
+ # Generate colors for each domain using provided arguments and random seed
743
865
  node_colors = self.graph.get_domain_colors(**kwargs, random_seed=random_seed)
744
866
  annotated_colors = []
745
867
  for _, nodes in self.graph.domain_to_nodes.items():
746
868
  if len(nodes) > 1:
747
- # For domains with multiple nodes, choose the brightest color (sum of RGB values)
869
+ # For domains with multiple nodes, select the brightest color (sum of RGB values)
748
870
  domain_colors = np.array([node_colors[node] for node in nodes])
749
- brightest_color = domain_colors[np.argmax(domain_colors.sum(axis=1))]
871
+ brightest_color = domain_colors[
872
+ np.argmax(domain_colors[:, :3].sum(axis=1))
873
+ ] # Only consider RGB
750
874
  annotated_colors.append(brightest_color)
751
875
  else:
752
876
  # Assign a default color (white) for single-node domains
753
- default_color = np.array([1.0, 1.0, 1.0, 1.0])
877
+ default_color = np.array([1.0, 1.0, 1.0, 1.0]) # RGBA for white
754
878
  annotated_colors.append(default_color)
755
879
 
756
880
  return np.array(annotated_colors)
@@ -776,6 +900,47 @@ class NetworkPlotter:
776
900
  plt.show(*args, **kwargs)
777
901
 
778
902
 
903
def _to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: float = 1.0,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Convert a color or a collection of colors to RGBA format, applying alpha as needed.

    A *single* color is a matplotlib color string (e.g. "white", "#ff0000") or a flat RGB/RGBA
    sequence of numbers. Any other non-empty list, tuple or array is treated as a collection of
    colors. Elements are inspected rather than just the length, because a collection of three or
    four colors also has length 3 or 4.

    Args:
        color (Union[str, list, tuple, np.ndarray]): The color(s) to convert.
        alpha (float, optional): Alpha value (transparency) to apply. It always overrides the alpha
            of a single color. In a collection, it applies to string and RGB entries, while RGBA
            entries keep their own alpha. Defaults to 1.0.
        num_repeats (int or None, optional): If provided, a single color (or a collection holding
            exactly one color) is repeated this many times. Defaults to None.

    Returns:
        np.ndarray: A length-4 RGBA array for a single color, or an (N, 4) array of RGBA colors.

    Raises:
        ValueError: If `color` is empty, of an unsupported type, or not a valid RGB/RGBA color
            (or collection of such colors).
    """
    error_message = "Color must be a valid RGB/RGBA or array of RGB/RGBA colors."

    if isinstance(color, str):
        is_single_color = True
    elif isinstance(color, (list, tuple, np.ndarray)):
        if len(color) == 0:
            raise ValueError(error_message)
        # Flat sequence of numbers -> one color; any string/sequence element -> collection of colors
        is_single_color = not any(isinstance(c, (str, list, tuple, np.ndarray)) for c in color)
        if is_single_color and len(color) not in (3, 4):
            raise ValueError(error_message)
    else:
        raise ValueError(error_message)

    # Handle single color case
    if is_single_color:
        # mcolors.to_rgba forces the returned alpha to `alpha`, even when the input is RGBA
        rgba_color = np.array(mcolors.to_rgba(color, alpha))
        # Repeat the color if repeat argument is provided
        if num_repeats is not None:
            return np.array([rgba_color] * num_repeats)
        return rgba_color

    # Handle collection of colors case: strings and RGB triples receive `alpha`,
    # while explicit RGBA entries keep the alpha they were given
    rgba_colors = [
        (
            np.array(c, dtype=float)
            if not isinstance(c, str) and len(c) == 4
            else np.array(mcolors.to_rgba(c, alpha))
        )
        for c in color
    ]
    # Repeat the color if a one-element collection and a repeat count are provided
    if num_repeats is not None and len(rgba_colors) == 1:
        return np.array([rgba_colors[0]] * num_repeats)

    return np.array(rgba_colors)
942
+
943
+
779
944
  def _is_connected(z: np.ndarray) -> bool:
780
945
  """Determine if a thresholded grid represents a single, connected component.
781
946
 
@@ -3,7 +3,7 @@ risk/risk
3
3
  ~~~~~~~~~
4
4
  """
5
5
 
6
- from typing import Any, Dict
6
+ from typing import Any, Dict, Tuple
7
7
 
8
8
  import networkx as nx
9
9
  import numpy as np
@@ -114,6 +114,7 @@ class RISK(NetworkIO, AnnotationsIO):
114
114
  max_workers=max_workers,
115
115
  )
116
116
 
117
+ # Return the computed neighborhood significance
117
118
  return neighborhood_significance
118
119
 
119
120
  def load_neighborhoods_by_fisher_exact(
@@ -169,6 +170,7 @@ class RISK(NetworkIO, AnnotationsIO):
169
170
  max_workers=max_workers,
170
171
  )
171
172
 
173
+ # Return the computed neighborhood significance
172
174
  return neighborhood_significance
173
175
 
174
176
  def load_neighborhoods_by_hypergeom(
@@ -224,6 +226,7 @@ class RISK(NetworkIO, AnnotationsIO):
224
226
  max_workers=max_workers,
225
227
  )
226
228
 
229
+ # Return the computed neighborhood significance
227
230
  return neighborhood_significance
228
231
 
229
232
  def load_graph(
@@ -347,12 +350,8 @@ class RISK(NetworkIO, AnnotationsIO):
347
350
  def load_plotter(
348
351
  self,
349
352
  graph: NetworkGraph,
350
- figsize: tuple = (10, 10),
353
+ figsize: Tuple = (10, 10),
351
354
  background_color: str = "white",
352
- plot_outline: bool = True,
353
- outline_color: str = "black",
354
- outline_scale: float = 1.00,
355
- linestyle: str = "dashed",
356
355
  ) -> NetworkPlotter:
357
356
  """Get a NetworkPlotter object for plotting.
358
357
 
@@ -360,10 +359,6 @@ class RISK(NetworkIO, AnnotationsIO):
360
359
  graph (NetworkGraph): The graph to plot.
361
360
  figsize (tuple, optional): Size of the figure. Defaults to (10, 10).
362
361
  background_color (str, optional): Background color of the plot. Defaults to "white".
363
- plot_outline (bool, optional): Whether to plot the network outline. Defaults to True.
364
- outline_color (str, optional): Color of the outline. Defaults to "black".
365
- outline_scale (float, optional): Scaling factor for the outline. Defaults to 1.00.
366
- linestyle (str): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
367
362
 
368
363
  Returns:
369
364
  NetworkPlotter: A NetworkPlotter object configured with the given parameters.
@@ -373,10 +368,6 @@ class RISK(NetworkIO, AnnotationsIO):
373
368
  params.log_plotter(
374
369
  figsize=figsize,
375
370
  background_color=background_color,
376
- plot_outline=plot_outline,
377
- outline_color=outline_color,
378
- outline_scale=outline_scale,
379
- linestyle=linestyle,
380
371
  )
381
372
 
382
373
  # Initialize and return a NetworkPlotter object
@@ -384,10 +375,6 @@ class RISK(NetworkIO, AnnotationsIO):
384
375
  graph,
385
376
  figsize=figsize,
386
377
  background_color=background_color,
387
- plot_outline=plot_outline,
388
- outline_color=outline_color,
389
- outline_scale=outline_scale,
390
- linestyle=linestyle,
391
378
  )
392
379
 
393
380
  def _load_neighborhoods(
@@ -430,6 +417,7 @@ class RISK(NetworkIO, AnnotationsIO):
430
417
  random_seed=random_seed,
431
418
  )
432
419
 
420
+ # Return the computed neighborhoods
433
421
  return neighborhoods
434
422
 
435
423
  def _define_top_annotations(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: risk-network
3
- Version: 0.0.6b0
3
+ Version: 0.0.6b1
4
4
  Summary: A Python package for biological network analysis
5
5
  Author: Ira Horecka
6
6
  Author-email: Ira Horecka <ira89@icloud.com>
File without changes
File without changes
File without changes
File without changes