risk-network 0.0.12b0__py3-none-any.whl → 0.0.12b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. risk/__init__.py +1 -1
  2. risk/annotations/__init__.py +10 -0
  3. risk/annotations/annotations.py +354 -0
  4. risk/annotations/io.py +241 -0
  5. risk/annotations/nltk_setup.py +86 -0
  6. risk/log/__init__.py +11 -0
  7. risk/log/console.py +141 -0
  8. risk/log/parameters.py +171 -0
  9. risk/neighborhoods/__init__.py +7 -0
  10. risk/neighborhoods/api.py +442 -0
  11. risk/neighborhoods/community.py +441 -0
  12. risk/neighborhoods/domains.py +360 -0
  13. risk/neighborhoods/neighborhoods.py +514 -0
  14. risk/neighborhoods/stats/__init__.py +13 -0
  15. risk/neighborhoods/stats/permutation/__init__.py +6 -0
  16. risk/neighborhoods/stats/permutation/permutation.py +240 -0
  17. risk/neighborhoods/stats/permutation/test_functions.py +70 -0
  18. risk/neighborhoods/stats/tests.py +275 -0
  19. risk/network/__init__.py +4 -0
  20. risk/network/graph/__init__.py +4 -0
  21. risk/network/graph/api.py +200 -0
  22. risk/network/graph/graph.py +268 -0
  23. risk/network/graph/stats.py +166 -0
  24. risk/network/graph/summary.py +253 -0
  25. risk/network/io.py +693 -0
  26. risk/network/plotter/__init__.py +4 -0
  27. risk/network/plotter/api.py +54 -0
  28. risk/network/plotter/canvas.py +291 -0
  29. risk/network/plotter/contour.py +329 -0
  30. risk/network/plotter/labels.py +935 -0
  31. risk/network/plotter/network.py +294 -0
  32. risk/network/plotter/plotter.py +141 -0
  33. risk/network/plotter/utils/colors.py +419 -0
  34. risk/network/plotter/utils/layout.py +94 -0
  35. risk_network-0.0.12b1.dist-info/METADATA +122 -0
  36. risk_network-0.0.12b1.dist-info/RECORD +40 -0
  37. {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b1.dist-info}/WHEEL +1 -1
  38. risk_network-0.0.12b0.dist-info/METADATA +0 -796
  39. risk_network-0.0.12b0.dist-info/RECORD +0 -7
  40. {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b1.dist-info}/licenses/LICENSE +0 -0
  41. {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,54 @@
1
+ """
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import List, Tuple, Union
7
+
8
+ import numpy as np
9
+
10
+ from risk.log import log_header
11
+ from risk.network.graph.graph import Graph
12
+ from risk.network.plotter.plotter import Plotter
13
+
14
+
15
class PlotterAPI:
    """Handles the loading of network plotter objects.

    The PlotterAPI class provides methods to load and configure Plotter objects for plotting network graphs.
    """

    def __init__(self) -> None:
        """Initialize the PlotterAPI. No state is required."""
        pass

    def load_plotter(
        self,
        graph: Graph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: str = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> Plotter:
        """Get a Plotter object for plotting.

        Args:
            graph (Graph): The graph to plot.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
            background_color (str, optional): Background color of the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it
                overrides any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.

        Returns:
            Plotter: A Plotter object configured with the given parameters.
        """
        log_header("Loading plotter")

        # Forward every configuration value unchanged to the Plotter constructor
        return Plotter(
            graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
@@ -0,0 +1,291 @@
1
+ """
2
+ risk/network/plotter/canvas
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import List, Tuple, Union
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+
11
+ from risk.log import params
12
+ from risk.network.graph.graph import Graph
13
+ from risk.network.plotter.utils.colors import to_rgba
14
+ from risk.network.plotter.utils.layout import calculate_bounding_box
15
+
16
+
17
class Canvas:
    """A class for laying out the canvas in a network graph."""

    def __init__(self, graph: Graph, ax: plt.Axes) -> None:
        """Initialize the Canvas with a Graph and axis for plotting.

        Args:
            graph (Graph): The Graph object containing the network data.
            ax (plt.Axes): The axis to plot the canvas on.
        """
        self.graph = graph
        self.ax = ax

    def plot_title(
        self,
        title: Union[str, None] = None,
        subtitle: Union[str, None] = None,
        title_fontsize: int = 20,
        subtitle_fontsize: int = 14,
        font: str = "Arial",
        title_color: Union[str, List, Tuple, np.ndarray] = "black",
        subtitle_color: Union[str, List, Tuple, np.ndarray] = "gray",
        title_x: float = 0.5,
        title_y: float = 0.975,
        title_space_offset: float = 0.075,
        subtitle_offset: float = 0.025,
    ) -> None:
        """Plot title and subtitle on the network graph with customizable parameters.

        Args:
            title (str, optional): Title of the plot. Defaults to None.
            subtitle (str, optional): Subtitle of the plot. Defaults to None.
            title_fontsize (int, optional): Font size for the title. Defaults to 20.
            subtitle_fontsize (int, optional): Font size for the subtitle. Defaults to 14.
            font (str, optional): Font family used for both title and subtitle. Defaults to "Arial".
            title_color (str, List, Tuple, or np.ndarray, optional): Color of the title text. Can be a string or an array of colors.
                Defaults to "black".
            subtitle_color (str, List, Tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
                Defaults to "gray".
            title_x (float, optional): X-axis position of the title (figure fraction). The subtitle is always
                centered at x=0.5 regardless of this value. Defaults to 0.5.
            title_y (float, optional): Y-axis position of the title (figure fraction). Defaults to 0.975.
            title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
            subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
        """
        # Log the title and subtitle parameters
        params.log_plotter(
            title=title,
            subtitle=subtitle,
            title_fontsize=title_fontsize,
            subtitle_fontsize=subtitle_fontsize,
            title_subtitle_font=font,
            title_color=title_color,
            subtitle_color=subtitle_color,
            subtitle_offset=subtitle_offset,
            title_y=title_y,
            title_space_offset=title_space_offset,
        )

        fig = self.ax.figure
        # Reserve the top `title_space_offset` fraction of the figure so the tight layout
        # keeps the axes clear of the title/subtitle area
        fig.tight_layout(rect=[0, 0, 1, 1 - title_space_offset])

        if title:
            # suptitle is placed in figure coordinates, so it is centered on the figure, not the axes
            fig.suptitle(
                title,
                fontsize=title_fontsize,
                color=title_color,
                fontname=font,
                x=title_x,
                ha="center",
                va="top",
                y=title_y,
            )

        if subtitle:
            # Approximate the title height as a fraction of the figure height.
            # NOTE(review): title_fontsize is in points while fig.bbox.height is in display pixels,
            # so this ratio varies with figure DPI; kept unchanged to preserve existing layouts.
            title_height = title_fontsize / fig.bbox.height
            # Position the subtitle below the title's vertical center, then shift by subtitle_offset
            subtitle_y_position = title_y - (title_height / 2) - subtitle_offset
            fig.text(
                0.5,  # the subtitle is always horizontally centered on the figure
                subtitle_y_position,
                subtitle,
                ha="center",
                va="top",
                fontname=font,
                fontsize=subtitle_fontsize,
                color=subtitle_color,
            )

    def plot_circle_perimeter(
        self,
        scale: float = 1.0,
        center_offset_x: float = 0.0,
        center_offset_y: float = 0.0,
        linestyle: str = "dashed",
        linewidth: float = 1.5,
        color: Union[str, List, Tuple, np.ndarray] = "black",
        outline_alpha: Union[float, None] = 1.0,
        fill_alpha: Union[float, None] = 0.0,
    ) -> None:
        """Plot a circle around the network graph to represent the network perimeter.

        Args:
            scale (float, optional): Scaling factor for the perimeter diameter. Defaults to 1.0.
            center_offset_x (float, optional): Horizontal offset as a fraction of the diameter.
                Negative values shift the center left, positive values shift it right. Defaults to 0.0.
            center_offset_y (float, optional): Vertical offset as a fraction of the diameter.
                Negative values shift the center down, positive values shift it up. Defaults to 0.0.
            linestyle (str, optional): Line style for the network perimeter circle (e.g., dashed, solid). Defaults to "dashed".
            linewidth (float, optional): Width of the circle's outline. Defaults to 1.5.
            color (str, List, Tuple, or np.ndarray, optional): Color of the network perimeter circle. Defaults to "black".
            outline_alpha (float, None, optional): Transparency level of the circle outline. If provided, it overrides any existing alpha
                values found in color. Defaults to 1.0.
            fill_alpha (float, None, optional): Transparency level of the circle fill. If provided, it overrides any existing alpha values
                found in color. Defaults to 0.0.

        Raises:
            ValueError: If either center offset lies outside [-1, 1].
        """
        # Log the circle perimeter plotting parameters
        params.log_plotter(
            perimeter_type="circle",
            perimeter_scale=scale,
            perimeter_center_offset_x=center_offset_x,
            perimeter_center_offset_y=center_offset_y,
            perimeter_linestyle=linestyle,
            perimeter_linewidth=linewidth,
            perimeter_color=(
                "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
            ),  # array-like colors usually indicate custom colors
            perimeter_outline_alpha=outline_alpha,
            perimeter_fill_alpha=fill_alpha,
        )

        # Bounding circle of all node coordinates, shifted by the user offsets and scaled
        center, radius = calculate_bounding_box(self.graph.node_coordinates)
        adjusted_center = self._calculate_adjusted_center(
            center, radius, center_offset_x, center_offset_y
        )
        scaled_radius = radius * scale

        # One RGBA color each for the outline and the fill (num_repeats=1 -> single color)
        outline_color_rgba = to_rgba(color=color, alpha=outline_alpha, num_repeats=1)
        fill_color_rgba = to_rgba(color=color, alpha=fill_alpha, num_repeats=1)

        # `color=` sets both edge and face color; the face is then overridden with the fill color
        circle = plt.Circle(
            adjusted_center,
            scaled_radius,
            linestyle=linestyle,
            linewidth=linewidth,
            color=outline_color_rgba,
        )
        # fill_color_rgba is already RGBA with fill_alpha applied; no second conversion needed
        circle.set_facecolor(fill_color_rgba)

        self.ax.add_artist(circle)

    def plot_contour_perimeter(
        self,
        scale: float = 1.0,
        levels: int = 3,
        bandwidth: float = 0.8,
        grid_size: int = 250,
        color: Union[str, List, Tuple, np.ndarray] = "black",
        linestyle: str = "solid",
        linewidth: float = 1.5,
        outline_alpha: Union[float, None] = 1.0,
        fill_alpha: Union[float, None] = 0.0,
    ) -> None:
        """Plot a KDE-based contour around the network graph to represent the network perimeter.

        Args:
            scale (float, optional): Scaling factor for the perimeter size. Defaults to 1.0.
            levels (int, optional): Number of contour levels. Defaults to 3.
            bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
            grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
            color (str, List, Tuple, or np.ndarray, optional): Color of the network perimeter contour. Defaults to "black".
            linestyle (str, optional): Line style for the network perimeter contour (e.g., dashed, solid). Defaults to "solid".
            linewidth (float, optional): Width of the contour's outline. Defaults to 1.5.
            outline_alpha (float, None, optional): Transparency level of the contour outline. If provided, it overrides any existing
                alpha values found in color. Defaults to 1.0.
            fill_alpha (float, None, optional): Transparency level of the contour fill. If provided, it overrides any existing alpha
                values found in color. Defaults to 0.0.
        """
        # Log the contour perimeter plotting parameters
        params.log_plotter(
            perimeter_type="contour",
            perimeter_scale=scale,
            perimeter_levels=levels,
            perimeter_bandwidth=bandwidth,
            perimeter_grid_size=grid_size,
            perimeter_linestyle=linestyle,
            perimeter_linewidth=linewidth,
            perimeter_color=("custom" if isinstance(color, (list, tuple, np.ndarray)) else color),
            perimeter_outline_alpha=outline_alpha,
            perimeter_fill_alpha=fill_alpha,
        )

        # Convert color to RGBA, applying outline_alpha to the perimeter line (previously ignored)
        outline_color_rgba = to_rgba(color=color, alpha=outline_alpha, num_repeats=1)
        node_coordinates = self.graph.node_coordinates
        scaled_coordinates = node_coordinates * scale
        # NOTE: This is a technical debt that should be refactored in the future - only works when
        # inherited by Plotter, which supplies _draw_kde_contour from the Contour class.
        self._draw_kde_contour(
            ax=self.ax,
            pos=scaled_coordinates,
            nodes=list(range(len(node_coordinates))),  # All nodes are included
            levels=levels,
            bandwidth=bandwidth,
            grid_size=grid_size,
            # Pass the single RGBA row, matching how Contour.plot_contours calls this helper
            color=outline_color_rgba[0],
            linestyle=linestyle,
            linewidth=linewidth,
            fill_alpha=fill_alpha,
        )

    def _calculate_adjusted_center(
        self,
        center: Tuple[float, float],
        radius: float,
        center_offset_x: float = 0.0,
        center_offset_y: float = 0.0,
    ) -> Tuple[float, float]:
        """Calculate the adjusted center for the network perimeter circle based on user-defined offsets.

        Args:
            center (Tuple[float, float]): Original center coordinates of the network graph.
            radius (float): Radius of the bounding box around the network.
            center_offset_x (float, optional): Horizontal offset as a fraction of the diameter.
                Negative values shift the center left, positive values shift it right. Allowed
                values are in the range [-1, 1]. Defaults to 0.0.
            center_offset_y (float, optional): Vertical offset as a fraction of the diameter.
                Negative values shift the center down, positive values shift it up. Allowed
                values are in the range [-1, 1]. Defaults to 0.0.

        Returns:
            Tuple[float, float]: Adjusted center coordinates after applying the offsets.

        Raises:
            ValueError: If the center offsets are outside the valid range [-1, 1].
        """
        # Validate the center offsets before using them
        if not -1 <= center_offset_x <= 1:
            raise ValueError("Horizontal center offset must be in the range [-1, 1].")
        if not -1 <= center_offset_y <= 1:
            raise ValueError("Vertical center offset must be in the range [-1, 1].")

        diameter = radius * 2
        # The y offset is negated to match the plot orientation
        adjusted_center_x = center[0] + center_offset_x * diameter
        adjusted_center_y = center[1] + (-center_offset_y) * diameter

        return adjusted_center_x, adjusted_center_y
@@ -0,0 +1,329 @@
1
+ """
2
+ risk/network/plotter/contour
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict, List, Tuple, Union
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from scipy import linalg
11
+ from scipy.ndimage import label
12
+ from scipy.stats import gaussian_kde
13
+
14
+ from risk.log import logger, params
15
+ from risk.network.graph.graph import Graph
16
+ from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
17
+
18
+
19
class Contour:
    """Class to generate Kernel Density Estimate (KDE) contours for nodes in a network graph."""

    def __init__(self, graph: Graph, ax: plt.Axes) -> None:
        """Initialize the Contour with a Graph and axis for plotting.

        Args:
            graph (Graph): The Graph object containing the network data.
            ax (plt.Axes): The axis to plot the contours on.
        """
        self.graph = graph
        self.ax = ax

    def plot_contours(
        self,
        levels: int = 5,
        bandwidth: float = 0.8,
        grid_size: int = 250,
        color: Union[str, List, Tuple, np.ndarray] = "white",
        linestyle: str = "solid",
        linewidth: float = 1.5,
        alpha: Union[float, None] = 1.0,
        fill_alpha: Union[float, None] = None,
    ) -> None:
        """Draw KDE contours for nodes in various domains of a network graph, highlighting areas of high density.

        Domains with fewer than two nodes are skipped, since a single point cannot form a contour.

        Args:
            levels (int, optional): Number of contour levels to plot. Defaults to 5.
            bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
            grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
            color (str, List, Tuple, or np.ndarray, optional): Color of the contours. Can be a single color or an array of colors.
                Defaults to "white".
            linestyle (str, optional): Line style for the contours. Defaults to "solid".
            linewidth (float, optional): Line width for the contours. Defaults to 1.5.
            alpha (float, None, optional): Transparency level of the contour lines. If provided, it overrides any existing alpha values
                found in color. Defaults to 1.0.
            fill_alpha (float, None, optional): Transparency level of the contour fill. If provided, it overrides any existing alpha
                values found in color. Defaults to None.
        """
        # Log the contour plotting parameters
        params.log_plotter(
            contour_levels=levels,
            contour_bandwidth=bandwidth,
            contour_grid_size=grid_size,
            contour_color=(
                "custom" if isinstance(color, (list, tuple, np.ndarray)) else color
            ),  # array-like colors usually indicate custom colors
            contour_linestyle=linestyle,
            contour_linewidth=linewidth,
            contour_alpha=alpha,
            contour_fill_alpha=fill_alpha,
        )

        domain_node_ids = list(self.graph.domain_id_to_node_ids_map.values())
        # One RGBA row per domain, in the same order as domain_node_ids
        color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=len(domain_node_ids))
        node_coordinates = self.graph.node_coordinates
        for idx, node_ids in enumerate(domain_node_ids):
            if len(node_ids) <= 1:
                continue  # a single node cannot form a contour
            # Explicit fill_alpha wins; otherwise fall back to the alpha channel of the domain color
            current_fill_alpha = fill_alpha if fill_alpha is not None else color_rgba[idx][3]
            self._draw_kde_contour(
                self.ax,
                node_coordinates,
                node_ids,
                color=color_rgba[idx],
                levels=levels,
                bandwidth=bandwidth,
                grid_size=grid_size,
                linestyle=linestyle,
                linewidth=linewidth,
                fill_alpha=current_fill_alpha,
            )

    def plot_subcontour(
        self,
        nodes: Union[List, Tuple, np.ndarray],
        levels: int = 5,
        bandwidth: float = 0.8,
        grid_size: int = 250,
        color: Union[str, List, Tuple, np.ndarray] = "white",
        linestyle: str = "solid",
        linewidth: float = 1.5,
        alpha: Union[float, None] = 1.0,
        fill_alpha: Union[float, None] = None,
    ) -> None:
        """Plot a subcontour for a given set of nodes or a list of node sets using Kernel Density Estimation (KDE).

        All node groups are validated before anything is drawn, so an invalid group does not
        leave a partially drawn plot behind.

        Args:
            nodes (List, Tuple, or np.ndarray): List of node labels or list of lists of node labels to plot the contour for.
            levels (int, optional): Number of contour levels to plot. Defaults to 5.
            bandwidth (float, optional): Bandwidth for KDE. Controls the smoothness of the contour. Defaults to 0.8.
            grid_size (int, optional): Resolution of the grid for KDE. Higher values create finer contours. Defaults to 250.
            color (str, List, Tuple, or np.ndarray, optional): Color of the contour. Can be a string (e.g., 'white') or RGBA array.
                Can be a single color or an array of colors. Defaults to "white".
            linestyle (str, optional): Line style for the contour. Defaults to "solid".
            linewidth (float, optional): Line width for the contour. Defaults to 1.5.
            alpha (float, None, optional): Transparency level of the contour lines. If provided, it overrides any existing alpha values
                found in color. Defaults to 1.0.
            fill_alpha (float, None, optional): Transparency level of the contour fill. If provided, it overrides any existing alpha
                values found in color. Defaults to None.

        Raises:
            ValueError: If any group has fewer than two nodes present in the network graph.
        """
        # A nested sequence means several groups; a flat sequence is treated as one group
        if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
            node_groups = nodes
        else:
            node_groups = [nodes]
        # One RGBA row per group (a single row for a flat list)
        color_rgba = to_rgba(color=color, alpha=alpha, num_repeats=len(node_groups))

        # Resolve labels to node IDs for every group up front; unknown labels are ignored
        label_to_id = self.graph.node_label_to_node_id_map
        resolved_groups = []
        for group in node_groups:
            node_ids = [label_to_id[node] for node in group if node in label_to_id]
            if len(node_ids) < 2:
                raise ValueError(
                    "No nodes found in the network graph or insufficient nodes to plot."
                )
            resolved_groups.append(node_ids)

        node_coordinates = self.graph.node_coordinates
        for idx, node_ids in enumerate(resolved_groups):
            # Explicit fill_alpha wins; otherwise fall back to the alpha channel of the group color
            current_fill_alpha = fill_alpha if fill_alpha is not None else color_rgba[idx][3]
            self._draw_kde_contour(
                self.ax,
                node_coordinates,
                node_ids,
                color=color_rgba[idx],
                levels=levels,
                bandwidth=bandwidth,
                grid_size=grid_size,
                linestyle=linestyle,
                linewidth=linewidth,
                fill_alpha=current_fill_alpha,
            )

    def _draw_kde_contour(
        self,
        ax: plt.Axes,
        pos: np.ndarray,
        nodes: List,
        levels: int = 5,
        bandwidth: float = 0.8,
        grid_size: int = 250,
        color: Union[str, np.ndarray] = "white",
        linestyle: str = "solid",
        linewidth: float = 1.5,
        fill_alpha: Union[float, None] = 0.2,
    ) -> None:
        """Draw a Kernel Density Estimate (KDE) contour plot for a set of nodes on a given axis.

        The bandwidth is increased in steps of 0.05 (up to 100.0) until the KDE grid forms a single
        connected component; if that limit is reached, the last computed grid is used anyway.
        Failures are logged rather than raised.

        Args:
            ax (plt.Axes): The axis to draw the contour on.
            pos (np.ndarray): Array of node positions (x, y).
            nodes (List): List of node indices to include in the contour.
            levels (int, optional): Number of contour levels. Defaults to 5.
            bandwidth (float, optional): Bandwidth for the KDE. Controls smoothness. Defaults to 0.8.
            grid_size (int, optional): Grid resolution for the KDE. Higher values yield finer contours. Defaults to 250.
            color (str or np.ndarray): Color for the contour. Can be a string or RGBA array. Defaults to "white".
            linestyle (str, optional): Line style for the contour. Defaults to "solid".
            linewidth (float, optional): Line width for the contour. Defaults to 1.5.
            fill_alpha (float, None, optional): Transparency level for the contour fill. If provided, it overrides any existing
                alpha values found in color. Defaults to 0.2.
        """
        points = np.array([pos[n] for n in nodes])
        if len(points) <= 1:
            return  # Not enough points to form a contour

        connected = False
        z = None  # stays None if no KDE could be computed
        while not connected and bandwidth <= 100.0:
            try:
                kde = gaussian_kde(points.T, bw_method=bandwidth)
                xmin, ymin = points.min(axis=0) - bandwidth
                xmax, ymax = points.max(axis=0) + bandwidth
                x, y = np.mgrid[
                    xmin : xmax : complex(0, grid_size), ymin : ymax : complex(0, grid_size)
                ]
                z = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
                connected = self._is_connected(z)
                if not connected:
                    bandwidth += 0.05  # Increase bandwidth slightly and retry
            except linalg.LinAlgError:
                # gaussian_kde factors the covariance of the *data*; the scalar bandwidth only
                # rescales it, so a singular covariance (e.g. collinear points) fails identically
                # for every bandwidth. Retrying is futile - stop and report below.
                break
            except Exception as e:
                # Catch any other exceptions and log them
                logger.error(f"Unexpected error when drawing KDE contour: {e}")
                return

        if z is None:
            logger.error("Failed to compute KDE. Skipping contour plot for these nodes.")
            return

        min_density, max_density = z.min(), z.max()
        if min_density == max_density:
            logger.warning(
                "Contour levels could not be created due to lack of variation in density."
            )
            return

        # Evenly spaced levels, dropping the lowest (minimum-density) level
        contour_levels = np.linspace(min_density, max_density, levels)[1:]
        if len(contour_levels) < 2 or not np.all(np.diff(contour_levels) > 0):
            logger.error("Contour levels must be strictly increasing. Skipping contour plot.")
            return

        if fill_alpha is not None and fill_alpha > 0:
            # Contour fill transparency is applied via `alpha`, while contour lines take RGBA colors
            ax.contourf(
                x,
                y,
                z,
                levels=contour_levels,
                colors=[color] * len(contour_levels),
                antialiased=True,
                alpha=fill_alpha,
            )

        # Outline only the lowest retained level, using the RGBA color's own alpha
        ax.contour(
            x,
            y,
            z,
            levels=[contour_levels[0]],
            colors=[color],
            linestyles=linestyle,
            linewidths=linewidth,
        )

    def get_annotated_contour_colors(
        self,
        cmap: str = "gist_rainbow",
        color: Union[str, List, Tuple, np.ndarray, None] = None,
        blend_colors: bool = False,
        blend_gamma: float = 2.2,
        min_scale: float = 0.8,
        max_scale: float = 1.0,
        scale_factor: float = 1.0,
        ids_to_colors: Union[Dict[int, Any], None] = None,
        random_seed: int = 888,
    ) -> np.ndarray:
        """Get colors for the contours based on node annotations or a specified colormap.

        Thin wrapper around get_annotated_domain_colors for this object's graph.

        Args:
            cmap (str, optional): Name of the colormap to use for generating contour colors. Defaults to "gist_rainbow".
            color (str, List, Tuple, np.ndarray, or None, optional): Color to use for the contours. Can be a single color or an array of colors.
                If None, the colormap will be used. Defaults to None.
            blend_colors (bool, optional): Whether to blend colors for nodes with multiple domains. Defaults to False.
            blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
            min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
                Controls the dimmest colors. Defaults to 0.8.
            max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
                Controls the brightest colors. Defaults to 1.0.
            scale_factor (float, optional): Exponent for adjusting color scaling based on significance scores.
                A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
            ids_to_colors (Dict[int, Any], None, optional): Mapping of domain IDs to specific colors. Defaults to None.
            random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

        Returns:
            np.ndarray: Array of RGBA colors for contour annotations.
        """
        return get_annotated_domain_colors(
            graph=self.graph,
            cmap=cmap,
            color=color,
            blend_colors=blend_colors,
            blend_gamma=blend_gamma,
            min_scale=min_scale,
            max_scale=max_scale,
            scale_factor=scale_factor,
            ids_to_colors=ids_to_colors,
            random_seed=random_seed,
        )

    def _is_connected(self, z: np.ndarray) -> bool:
        """Determine whether the non-zero cells of a grid form a single connected component.

        Args:
            z (np.ndarray): A density grid. scipy.ndimage.label treats every non-zero cell as
                foreground and zero cells as background.

        Returns:
            bool: True if exactly one connected component of non-zero cells exists, False otherwise.
        """
        _, num_features = label(z)
        return num_features == 1