risk-network 0.0.9b26__py3-none-any.whl → 0.0.9b28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple, Union
 
 import networkx as nx
 import numpy as np
+from scipy.sparse import csr_matrix
 from sklearn.exceptions import DataConversionWarning
 from sklearn.metrics.pairwise import cosine_similarity
 
@@ -34,43 +35,43 @@ def get_network_neighborhoods(
     louvain_resolution: float = 0.1,
     leiden_resolution: float = 1.0,
     random_seed: int = 888,
-) -> np.ndarray:
-    """Calculate the combined neighborhoods for each node based on the specified community detection algorithm(s).
+) -> csr_matrix:
+    """Calculate the combined neighborhoods for each node using sparse matrices.
 
     Args:
         network (nx.Graph): The network graph.
         distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use.
-        fraction_shortest_edges (float, List, Tuple, or np.ndarray, optional): Shortest edge rank fraction threshold(s) for creating subgraphs.
+        fraction_shortest_edges (float, List, Tuple, or np.ndarray, optional): Shortest edge rank fraction thresholds.
         louvain_resolution (float, optional): Resolution parameter for the Louvain method.
         leiden_resolution (float, optional): Resolution parameter for the Leiden method.
         random_seed (int, optional): Random seed for methods requiring random initialization.
 
     Returns:
-        np.ndarray: Summed neighborhood matrix from all selected algorithms.
+        csr_matrix: The combined neighborhood matrix.
     """
     # Set random seed for reproducibility
     random.seed(random_seed)
     np.random.seed(random_seed)
 
-    # Ensure distance_metric is a list/tuple for multi-algorithm handling
+    # Ensure distance_metric is a list for multi-algorithm handling
     if isinstance(distance_metric, (str, np.ndarray)):
         distance_metric = [distance_metric]
-    # Ensure fraction_shortest_edges is a list/tuple for multi-threshold handling
+    # Ensure fraction_shortest_edges is a list for multi-threshold handling
     if isinstance(fraction_shortest_edges, (float, int)):
         fraction_shortest_edges = [fraction_shortest_edges] * len(distance_metric)
-    # Check that the number of distance metrics matches the number of edge length thresholds
+    # Validate matching lengths of distance metrics and thresholds
     if len(distance_metric) != len(fraction_shortest_edges):
         raise ValueError(
             "The number of distance metrics must match the number of edge length thresholds."
         )
 
-    # Initialize combined neighborhood matrix
+    # Initialize a sparse LIL matrix for incremental updates
     num_nodes = network.number_of_nodes()
-    combined_neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
-
+    # Initialize a sparse matrix with the same shape as the network
+    combined_neighborhoods = csr_matrix((num_nodes, num_nodes), dtype=np.uint8)
     # Loop through each distance metric and corresponding edge rank fraction
     for metric, percentile in zip(distance_metric, fraction_shortest_edges):
-        # Call the appropriate neighborhood function based on the metric
+        # Compute neighborhoods for the specified metric
        if metric == "greedy_modularity":
            neighborhoods = calculate_greedy_modularity_neighborhoods(
                network, fraction_shortest_edges=percentile
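
Editor's note: the hunk above replaces the dense `np.zeros((num_nodes, num_nodes), dtype=int)` accumulator with `csr_matrix((num_nodes, num_nodes), dtype=np.uint8)`. The payoff is memory: neighborhood-membership matrices are overwhelmingly zero, and a dense N×N integer matrix grows quadratically with node count while CSR storage grows only with the number of stored entries. A back-of-the-envelope sketch, with an illustrative node count and neighbor count that are not taken from the package:

```python
import numpy as np

n = 20_000                 # illustrative node count
nnz = n * 50               # assume ~50 stored neighbors per node

# Dense accumulator: one machine integer per cell
dense_bytes = n * n * np.dtype(np.int64).itemsize

# CSR accumulator: data + column indices + row pointers (assuming 32-bit indices)
csr_bytes = (
    nnz * np.dtype(np.uint8).itemsize
    + nnz * np.dtype(np.int32).itemsize
    + (n + 1) * np.dtype(np.int32).itemsize
)
print(f"dense ~{dense_bytes / 1e9:.1f} GB vs CSR ~{csr_bytes / 1e6:.1f} MB")
```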
@@ -107,22 +108,37 @@ def get_network_neighborhoods(
             )
         else:
             raise ValueError(
-                "Incorrect distance metric specified. Please choose from 'greedy_modularity', 'label_propagation',"
+                "Invalid distance metric. Choose from: 'greedy_modularity', 'label_propagation',"
                 "'leiden', 'louvain', 'markov_clustering', 'spinglass', 'walktrap'."
             )
 
-        # Sum the neighborhood matrices
+        # Add the sparse neighborhood matrix
         combined_neighborhoods += neighborhoods
 
-    # Ensure that the maximum value in each row is set to 1
-    # This ensures that for each row, only the strongest relationship (the maximum value) is retained,
-    # while all other values are reset to 0. This transformation simplifies the neighborhood matrix by
-    # focusing on the most significant connection per row (or nodes).
-    combined_neighborhoods = _set_max_row_value_to_one(combined_neighborhoods)
+    # Ensure maximum value in each row is set to 1
+    combined_neighborhoods = _set_max_row_value_to_one_sparse(combined_neighborhoods)
 
     return combined_neighborhoods
 
 
+def _set_max_row_value_to_one_sparse(matrix: csr_matrix) -> csr_matrix:
+    """Set the maximum value in each row of a sparse matrix to 1.
+
+    Args:
+        matrix (csr_matrix): The input sparse matrix.
+
+    Returns:
+        csr_matrix: The modified sparse matrix where only the maximum value in each row is set to 1.
+    """
+    # Iterate over each row and set the maximum value to 1
+    for i in range(matrix.shape[0]):
+        row_data = matrix[i].data
+        if len(row_data) > 0:
+            row_data[:] = (row_data == max(row_data)).astype(int)
+
+    return matrix
+
+
 def _set_max_row_value_to_one(matrix: np.ndarray) -> np.ndarray:
     """For each row in the input matrix, set the maximum value(s) to 1 and all other values to 0. This is particularly
     useful for neighborhood matrices that have undergone multiple neighborhood detection algorithms, where the
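
Editor's note: one caveat on the new `_set_max_row_value_to_one_sparse` shown above. Indexing a single row of a `csr_matrix` (`matrix[i]`) returns a new sparse matrix whose `.data` buffer is a copy, so writes into `row_data` may not propagate back to the original matrix, and Python's builtin `max` on a NumPy array is slower than the array's own `.max()`. A sketch of the same per-row-maximum idea that edits the CSR buffers directly is given below; this is an illustration, not the package's code:

```python
import numpy as np
from scipy.sparse import csr_matrix

def set_max_row_value_to_one(matrix: csr_matrix) -> csr_matrix:
    """Keep only each row's maximum entry, set it to 1, and drop everything else."""
    matrix = matrix.tocsr(copy=True)
    for i in range(matrix.shape[0]):
        start, end = matrix.indptr[i], matrix.indptr[i + 1]
        if start == end:
            continue  # row has no stored entries
        row = matrix.data[start:end]                       # view into the matrix's own buffer
        matrix.data[start:end] = (row == row.max()).astype(matrix.dtype)
    matrix.eliminate_zeros()                               # drop entries that became explicit zeros
    return matrix
```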
risk/network/geometry.py CHANGED
@@ -3,8 +3,6 @@ risk/network/geometry
 ~~~~~~~~~~~~~~~~~~~~~
 """
 
-import copy
-
 import networkx as nx
 import numpy as np
 
@@ -31,44 +29,43 @@ def assign_edge_lengths(
         """Compute distances between pairs of coordinates."""
         u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
         if is_sphere:
-            u_norm = np.linalg.norm(u_coords, axis=1, keepdims=True)
-            v_norm = np.linalg.norm(v_coords, axis=1, keepdims=True)
-            u_coords /= u_norm
-            v_coords /= v_norm
+            u_coords /= np.linalg.norm(u_coords, axis=1, keepdims=True)
+            v_coords /= np.linalg.norm(v_coords, axis=1, keepdims=True)
             dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
             return np.arccos(np.clip(dot_products, -1.0, 1.0))
-
         return np.linalg.norm(u_coords - v_coords, axis=1)
 
     # Normalize graph coordinates and weights
     _normalize_graph_coordinates(G)
     _normalize_weights(G)
+
     # Map nodes to sphere and adjust depth if required
     if compute_sphere:
         _map_to_sphere(G)
-        G_depth = _create_depth(copy.deepcopy(G), surface_depth=surface_depth)
+        G_depth = _create_depth(G, surface_depth=surface_depth)
     else:
-        G_depth = copy.deepcopy(G)
-
-    # Precompute edge coordinate arrays for vectorized computation
-    edge_data = []
-    for u, v in G_depth.edges:
-        u_coords = np.array([G_depth.nodes[u]["x"], G_depth.nodes[u]["y"]])
-        v_coords = np.array([G_depth.nodes[v]["x"], G_depth.nodes[v]["y"]])
-        if compute_sphere:
-            u_coords = np.append(u_coords, G_depth.nodes[u].get("z", 0))
-            v_coords = np.append(v_coords, G_depth.nodes[v].get("z", 0))
-        edge_data.append([u_coords, v_coords, (u, v)])
-
-    # Convert to numpy for faster processing
-    edge_coords = np.array([(e[0], e[1]) for e in edge_data])
-    edge_indices = [e[2] for e in edge_data]
-    # Compute distances in bulk
-    distances = compute_distance_vectorized(edge_coords, compute_sphere)
+        G_depth = G
+
+    # Precompute edge coordinate arrays and compute distances in bulk
+    edge_data = np.array(
+        [
+            [
+                np.array(
+                    [G_depth.nodes[u]["x"], G_depth.nodes[u]["y"], G_depth.nodes[u].get("z", 0)]
+                ),
+                np.array(
+                    [G_depth.nodes[v]["x"], G_depth.nodes[v]["y"], G_depth.nodes[v].get("z", 0)]
+                ),
+            ]
+            for u, v in G_depth.edges
+        ]
+    )
+    # Compute distances
+    distances = compute_distance_vectorized(edge_data, compute_sphere)
     # Assign distances back to the graph
-    for (u, v), distance in zip(edge_indices, distances):
+    for (u, v), distance in zip(G_depth.edges, distances):
         if include_edge_weight:
-            weight = G.edges[u, v].get("normalized_weight", 0) + 1e-6
+            weight = G.edges[u, v].get("normalized_weight", 1e-6)  # Avoid divide-by-zero
             G.edges[u, v]["length"] = distance / np.sqrt(weight)
         else:
             G.edges[u, v]["length"] = distance
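
Editor's note: the spherical branch above computes great-circle (angular) distances. After projecting both endpoints onto the unit sphere, the row-wise dot product `np.einsum("ij,ij->i", u, v)` yields cos(θ) per edge, and `np.arccos` with clipping guards against floating-point values that drift just outside [-1, 1]. The refactor also always builds 3-component coordinates; when `compute_sphere` is False the `z` defaults to 0, so the Euclidean branch is unaffected. A standalone sketch with made-up coordinates:

```python
import numpy as np

# Two endpoints per edge, stacked as (num_edges, 2, 3)
edge_coords = np.array([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # orthogonal unit vectors -> angle pi/2
    [[2.0, 0.0, 0.0], [1.0, 0.0, 0.0]],  # parallel vectors        -> angle 0
])

u, v = edge_coords[:, 0, :], edge_coords[:, 1, :]
u = u / np.linalg.norm(u, axis=1, keepdims=True)    # project onto the unit sphere
v = v / np.linalg.norm(v, axis=1, keepdims=True)
cos_theta = np.einsum("ij,ij->i", u, v)             # per-edge dot products
angles = np.arccos(np.clip(cos_theta, -1.0, 1.0))   # great-circle distance in radians
print(angles)                                       # [1.5707963 0.       ]
```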
risk/network/graph/api.py CHANGED
@@ -16,7 +16,7 @@ from risk.neighborhoods import (
     process_neighborhoods,
     trim_domains,
 )
-from risk.network.graph.network import NetworkGraph
+from risk.network.graph.graph import Graph
 from risk.stats import calculate_significance_matrices
 
 
@@ -44,7 +44,7 @@ class GraphAPI:
         linkage_metric: str = "yule",
         min_cluster_size: int = 5,
         max_cluster_size: int = 1000,
-    ) -> NetworkGraph:
+    ) -> Graph:
         """Load and process the network graph, defining top annotations and domains.
 
         Args:
@@ -63,7 +63,7 @@ class GraphAPI:
             max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
 
         Returns:
-            NetworkGraph: A fully initialized and processed NetworkGraph object.
+            Graph: A fully initialized and processed Graph object.
         """
         # Log the parameters and display headers
         log_header("Finding significant neighborhoods")
@@ -139,13 +139,13 @@ class GraphAPI:
             max_cluster_size=max_cluster_size,
         )
 
-        # Prepare node mapping and significance sums for the final NetworkGraph object
+        # Prepare node mapping and significance sums for the final Graph object
         ordered_nodes = annotations["ordered_nodes"]
         node_label_to_id = dict(zip(ordered_nodes, range(len(ordered_nodes))))
         node_significance_sums = processed_neighborhoods["node_significance_sums"]
 
-        # Return the fully initialized NetworkGraph object
-        return NetworkGraph(
+        # Return the fully initialized Graph object
+        return Graph(
             network=network,
             annotations=annotations,
             neighborhoods=neighborhoods,
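
Editor's note: the api.py hunks are part of a broader rename in this release: `NetworkGraph` becomes `Graph`, and the module `risk/network/graph/network` becomes `risk/network/graph/graph`. For downstream code the change is limited to the import path and type name. The sketch below only shows the updated hint; the attributes it touches (`network`, `node_coordinates`) appear elsewhere in this diff, and nothing here should be read as a complete API listing:

```python
# Before (0.0.9b26):
# from risk.network.graph.network import NetworkGraph
# After (0.0.9b28):
from risk.network.graph.graph import Graph

def describe(graph: Graph) -> None:
    """Example downstream helper updated to the renamed type."""
    print(graph.network.number_of_nodes())   # underlying nx.Graph
    print(graph.node_coordinates.shape)      # node layout coordinates
```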
@@ -1,6 +1,6 @@
 """
-risk/network/graph/network
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/network/graph/graph
+~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from collections import defaultdict
@@ -10,13 +10,13 @@ import networkx as nx
 import numpy as np
 import pandas as pd
 
-from risk.network.graph.summary import AnalysisSummary
+from risk.network.graph.summary import Summary
 
 
-class NetworkGraph:
+class Graph:
     """A class to represent a network graph and process its nodes and edges.
 
-    The NetworkGraph class provides functionality to handle and manipulate a network graph,
+    The Graph class provides functionality to handle and manipulate a network graph,
     including managing domains, annotations, and node significance data. It also includes methods
     for transforming and mapping graph coordinates, as well as generating colors based on node
     significance.
@@ -32,7 +32,7 @@ class NetworkGraph:
         node_label_to_node_id_map: Dict[str, Any],
         node_significance_sums: np.ndarray,
     ):
-        """Initialize the NetworkGraph object.
+        """Initialize the Graph object.
 
         Args:
             network (nx.Graph): The network graph.
@@ -69,7 +69,7 @@ class NetworkGraph:
         self.node_coordinates = _extract_node_coordinates(self.network)
 
         # NOTE: Only after the above attributes are initialized, we can create the summary
-        self.summary = AnalysisSummary(annotations, neighborhoods, self)
+        self.summary = Summary(annotations, neighborhoods, self)
 
     def pop(self, domain_id: str) -> None:
         """Remove domain ID from instance domain ID mappings. This can be useful for cleaning up
@@ -12,7 +12,7 @@ from statsmodels.stats.multitest import fdrcorrection
 from risk.log.console import logger, log_header
 
 
-class AnalysisSummary:
+class Summary:
     """Handles the processing, storage, and export of network analysis results.
 
     The Results class provides methods to process significance and depletion data, compute
@@ -25,14 +25,14 @@ class AnalysisSummary:
         self,
         annotations: Dict[str, Any],
         neighborhoods: Dict[str, Any],
-        graph,  # Avoid type hinting NetworkGraph to prevent circular imports
+        graph,  # Avoid type hinting Graph to prevent circular imports
     ):
         """Initialize the Results object with analysis components.
 
         Args:
             annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
             neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
-            graph (NetworkGraph): Graph object representing domain-to-node and node-to-label mappings.
+            graph (Graph): Graph object representing domain-to-node and node-to-label mappings.
         """
         self.annotations = annotations
         self.neighborhoods = neighborhoods
risk/network/io.py CHANGED
@@ -217,6 +217,9 @@ class NetworkIO:
 
         Returns:
             nx.Graph: Loaded and processed network.
+
+        Raises:
+            ValueError: If no matching attribute metadata file is found.
         """
         filetype = "Cytoscape"
         # Log the loading of the Cytoscape file
@@ -258,13 +261,29 @@ class NetworkIO:
 
             # Read the node attributes (from /tables/)
             attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
-            attribute_metadata = [
-                os.path.join(tmp_dir, cf)
-                for cf in cys_files
-                if all(keyword in cf for keyword in attribute_metadata_keywords)
-            ][0]
-            # Load attributes file from Cytoscape as pandas data frame
-            attribute_table = pd.read_csv(attribute_metadata, sep=",", header=None, skiprows=1)
+            # Use a generator to find the first matching file
+            attribute_metadata = next(
+                (
+                    os.path.join(tmp_dir, cf)
+                    for cf in cys_files
+                    if all(keyword in cf for keyword in attribute_metadata_keywords)
+                ),
+                None,  # Default if no file matches
+            )
+            if attribute_metadata:
+                # Optimize `read_csv` by leveraging proper options
+                attribute_table = pd.read_csv(
+                    attribute_metadata,
+                    sep=",",
+                    header=None,
+                    skiprows=1,
+                    dtype=str,  # Use specific dtypes to reduce memory usage
+                    engine="c",  # Use the C engine for parsing if compatible
+                    low_memory=False,  # Optimize memory handling for large files
+                )
+            else:
+                raise ValueError("No matching attribute metadata file found.")
+
             # Set columns
             attribute_table.columns = attribute_table.iloc[0]
             # Skip first four rows
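
Editor's note: behaviorally, the switch from `[...][0]` to `next(generator, None)` changes what happens when no archive member matches. The old list comprehension materialized every candidate and raised a bare `IndexError` on an empty list, while `next` stops at the first hit and falls back to `None`, letting the code raise the more descriptive `ValueError` documented in the new `Raises` section. A self-contained sketch of the pattern (file names and the temporary directory are made up):

```python
import os

cys_files = ["archive/tables/net_SHARED_ATTRS_edge.cytable", "archive/views/view.xml"]
keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
tmp_dir = "/tmp/cys_extract"

# Old pattern: builds the whole list, IndexError if nothing matches
# match = [os.path.join(tmp_dir, f) for f in cys_files if all(k in f for k in keywords)][0]

# New pattern: lazy, first match or an explicit default
match = next(
    (os.path.join(tmp_dir, f) for f in cys_files if all(k in f for k in keywords)),
    None,
)
if match is None:
    raise ValueError("No matching attribute metadata file found.")
print(match)  # /tmp/cys_extract/archive/tables/net_SHARED_ATTRS_edge.cytable
```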
@@ -464,14 +483,19 @@
         Args:
             G (nx.Graph): A NetworkX graph object.
         """
-        missing_weights = 0
-        # Assign user-defined edge weights to the "weight" attribute
-        nx.set_edge_attributes(G, 1.0, "weight")  # Set default weight
-        if self.weight_label in nx.get_edge_attributes(G, self.weight_label):
-            nx.set_edge_attributes(G, nx.get_edge_attributes(G, self.weight_label), "weight")
-
-        if self.include_edge_weight and missing_weights:
-            logger.debug(f"Total edges missing weights: {missing_weights}")
+        # Set default weight for all edges in bulk
+        default_weight = 1.0
+        nx.set_edge_attributes(G, default_weight, "weight")
+        # Check and assign user-defined edge weights if available
+        weight_attributes = nx.get_edge_attributes(G, self.weight_label)
+        if weight_attributes:
+            nx.set_edge_attributes(G, weight_attributes, "weight")
+
+        # Log missing weights if include_edge_weight is enabled
+        if self.include_edge_weight:
+            missing_weights = len(G.edges) - len(weight_attributes)
+            if missing_weights > 0:
+                logger.debug(f"Total edges missing weights: {missing_weights}")
 
     def _validate_nodes(self, G: nx.Graph) -> None:
         """Validate the graph structure and attributes with attribute fallback for positions and labels.
@@ -1,6 +1,6 @@
 """
-risk/network/plot
-~~~~~~~~~~~~~~~~~
+risk/network/plotter
+~~~~~~~~~~~~~~~~~~~~
 """
 
 from risk.network.plotter.api import PlotterAPI
@@ -1,6 +1,6 @@
 """
-risk/network/graph/api
-~~~~~~~~~~~~~~~~~~~~~~
+risk/network/plotter/api
+~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from typing import List, Tuple, Union
@@ -8,14 +8,14 @@ from typing import List, Tuple, Union
 import numpy as np
 
 from risk.log import log_header
-from risk.network.graph.network import NetworkGraph
-from risk.network.plotter.network import NetworkPlotter
+from risk.network.graph.graph import Graph
+from risk.network.plotter.plotter import Plotter
 
 
 class PlotterAPI:
     """Handles the loading of network plotter objects.
 
-    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
+    The PlotterAPI class provides methods to load and configure Plotter objects for plotting network graphs.
     """
 
     def __init__() -> None:
@@ -23,16 +23,16 @@ class PlotterAPI:
 
     def load_plotter(
         self,
-        graph: NetworkGraph,
+        graph: Graph,
         figsize: Union[List, Tuple, np.ndarray] = (10, 10),
         background_color: str = "white",
         background_alpha: Union[float, None] = 1.0,
         pad: float = 0.3,
-    ) -> NetworkPlotter:
-        """Get a NetworkPlotter object for plotting.
+    ) -> Plotter:
+        """Get a Plotter object for plotting.
 
         Args:
-            graph (NetworkGraph): The graph to plot.
+            graph (Graph): The graph to plot.
             figsize (List, Tuple, or np.ndarray, optional): Size of the plot. Defaults to (10, 10)., optional): Size of the figure. Defaults to (10, 10).
             background_color (str, optional): Background color of the plot. Defaults to "white".
             background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
@@ -40,12 +40,12 @@ class PlotterAPI:
             pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
 
         Returns:
-            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
+            Plotter: A Plotter object configured with the given parameters.
         """
         log_header("Loading plotter")
 
-        # Initialize and return a NetworkPlotter object
-        return NetworkPlotter(
+        # Initialize and return a Plotter object
+        return Plotter(
             graph,
             figsize=figsize,
             background_color=background_color,
@@ -1,6 +1,6 @@
 """
-risk/network/plot/canvas
-~~~~~~~~~~~~~~~~~~~~~~~~
+risk/network/plotter/canvas
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from typing import List, Tuple, Union
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 
 from risk.log import params
-from risk.network.graph.network import NetworkGraph
+from risk.network.graph.graph import Graph
 from risk.network.plotter.utils.colors import to_rgba
 from risk.network.plotter.utils.layout import calculate_bounding_box
 
@@ -17,11 +17,11 @@ from risk.network.plotter.utils.layout import calculate_bounding_box
 class Canvas:
     """A class for laying out the canvas in a network graph."""
 
-    def __init__(self, graph: NetworkGraph, ax: plt.Axes) -> None:
-        """Initialize the Canvas with a NetworkGraph and axis for plotting.
+    def __init__(self, graph: Graph, ax: plt.Axes) -> None:
+        """Initialize the Canvas with a Graph and axis for plotting.
 
         Args:
-            graph (NetworkGraph): The NetworkGraph object containing the network data.
+            graph (Graph): The Graph object containing the network data.
             ax (plt.Axes): The axis to plot the canvas on.
         """
         self.graph = graph
@@ -236,7 +236,7 @@ class Canvas:
         # Scale the node coordinates if needed
         scaled_coordinates = node_coordinates * scale
         # Use the existing _draw_kde_contour method
-        # NOTE: This is a technical debt that should be refactored in the future - only works when inherited by NetworkPlotter
+        # NOTE: This is a technical debt that should be refactored in the future - only works when inherited by Plotter
        self._draw_kde_contour(
            ax=self.ax,
            pos=scaled_coordinates,
@@ -1,6 +1,6 @@
 """
-risk/network/plot/contour
-~~~~~~~~~~~~~~~~~~~~~~~~~
+risk/network/plotter/contour
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 from typing import Any, Dict, List, Tuple, Union
@@ -12,18 +12,18 @@ from scipy.ndimage import label
 from scipy.stats import gaussian_kde
 
 from risk.log import params, logger
-from risk.network.graph.network import NetworkGraph
+from risk.network.graph.graph import Graph
 from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
 
 
 class Contour:
     """Class to generate Kernel Density Estimate (KDE) contours for nodes in a network graph."""
 
-    def __init__(self, graph: NetworkGraph, ax: plt.Axes) -> None:
-        """Initialize the Contour with a NetworkGraph and axis for plotting.
+    def __init__(self, graph: Graph, ax: plt.Axes) -> None:
+        """Initialize the Contour with a Graph and axis for plotting.
 
         Args:
-            graph (NetworkGraph): The NetworkGraph object containing the network data.
+            graph (Graph): The Graph object containing the network data.
             ax (plt.Axes): The axis to plot the contours on.
         """
         self.graph = graph
@@ -1,6 +1,6 @@
 """
-risk/network/plot/labels
-~~~~~~~~~~~~~~~~~~~~~~~~
+risk/network/plotter/labels
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
 """
 
 import copy
@@ -11,7 +11,7 @@ import numpy as np
 import pandas as pd
 
 from risk.log import params
-from risk.network.graph.network import NetworkGraph
+from risk.network.graph.graph import Graph
 from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
 from risk.network.plotter.utils.layout import calculate_bounding_box
 
@@ -21,11 +21,11 @@ TERM_DELIMITER = "::::"  # String used to separate multiple domain terms when co
 class Labels:
     """Class to handle the annotation of network graphs with labels for different domains."""
 
-    def __init__(self, graph: NetworkGraph, ax: plt.Axes):
+    def __init__(self, graph: Graph, ax: plt.Axes):
         """Initialize the Labeler object with a network graph and matplotlib axes.
 
         Args:
-            graph (NetworkGraph): NetworkGraph object containing the network data.
+            graph (Graph): Graph object containing the network data.
             ax (plt.Axes): Matplotlib axes object to plot the labels on.
         """
         self.graph = graph