risk-network 0.0.8b18__py3-none-any.whl → 0.0.9b26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (50)
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +133 -72
  4. risk/annotations/io.py +50 -34
  5. risk/log/__init__.py +4 -2
  6. risk/log/{config.py → console.py} +5 -3
  7. risk/log/{params.py → parameters.py} +21 -46
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +446 -0
  10. risk/neighborhoods/community.py +281 -96
  11. risk/neighborhoods/domains.py +92 -38
  12. risk/neighborhoods/neighborhoods.py +210 -149
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +69 -58
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +194 -0
  17. risk/network/graph/network.py +269 -0
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +58 -48
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +80 -26
  23. risk/network/{plot → plotter}/contour.py +43 -34
  24. risk/network/{plot → plotter}/labels.py +123 -113
  25. risk/network/plotter/network.py +424 -0
  26. risk/network/plotter/utils/colors.py +416 -0
  27. risk/network/plotter/utils/layout.py +94 -0
  28. risk/risk.py +11 -469
  29. risk/stats/__init__.py +8 -4
  30. risk/stats/binom.py +51 -0
  31. risk/stats/chi2.py +69 -0
  32. risk/stats/hypergeom.py +28 -18
  33. risk/stats/permutation/__init__.py +1 -1
  34. risk/stats/permutation/permutation.py +45 -39
  35. risk/stats/permutation/test_functions.py +25 -17
  36. risk/stats/poisson.py +17 -11
  37. risk/stats/stats.py +20 -16
  38. risk/stats/zscore.py +68 -0
  39. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
  40. risk_network-0.0.9b26.dist-info/RECORD +44 -0
  41. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
  42. risk/network/graph.py +0 -159
  43. risk/network/plot/__init__.py +0 -6
  44. risk/network/plot/network.py +0 -282
  45. risk/network/plot/plotter.py +0 -137
  46. risk/network/plot/utils/color.py +0 -353
  47. risk/network/plot/utils/layout.py +0 -53
  48. risk_network-0.0.8b18.dist-info/RECORD +0 -37
  49. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
  50. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
1
+ """
2
+ risk/network/graph/summary
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict, Tuple, Union
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from statsmodels.stats.multitest import fdrcorrection
11
+
12
+ from risk.log.console import logger, log_header
13
+
14
+
15
class AnalysisSummary:
    """Handles the processing, storage, and export of network analysis results.

    The AnalysisSummary class processes enrichment and depletion p-values, computes
    FDR-corrected q-values, and structures information on domains and annotations into a
    DataFrame. It also offers functionality to export the processed data in CSV, JSON,
    and text formats for analysis and reporting.
    """

    def __init__(
        self,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        graph,  # Avoid type hinting NetworkGraph to prevent circular imports
    ):
        """Initialize the AnalysisSummary object with analysis components.

        Args:
            annotations (Dict[str, Any]): Annotation data. The keys read here are
                "ordered_annotations" (annotation descriptions, in column order) and "matrix"
                (node-by-annotation membership matrix, sparse or dense).
            neighborhoods (Dict[str, Any]): Neighborhood data. The keys read here are
                "enrichment_pvals" and "depletion_pvals" (2D node-by-annotation p-value arrays).
            graph (NetworkGraph): Graph object providing `domain_id_to_domain_info_map`,
                `domain_id_to_node_ids_map`, and `node_id_to_node_label_map`.
        """
        self.annotations = annotations
        self.neighborhoods = neighborhoods
        self.graph = graph

    def to_csv(self, filepath: str) -> None:
        """Export significance results to a CSV file.

        Args:
            filepath (str): The path where the CSV file will be saved.
        """
        # Load results and export directly to CSV
        results = self.load()
        results.to_csv(filepath, index=False)
        logger.info(f"Analysis summary exported to CSV file: {filepath}")

    def to_json(self, filepath: str) -> None:
        """Export significance results to a JSON file.

        Args:
            filepath (str): The path where the JSON file will be saved.
        """
        # Load results and export directly to JSON
        results = self.load()
        results.to_json(filepath, orient="records", indent=4)
        logger.info(f"Analysis summary exported to JSON file: {filepath}")

    def to_txt(self, filepath: str) -> None:
        """Export significance results to a text file.

        Args:
            filepath (str): The path where the text file will be saved.
        """
        # Load results and export directly to text file
        results = self.load()
        with open(filepath, "w", encoding="utf-8") as txt_file:
            txt_file.write(results.to_string(index=False))

        logger.info(f"Analysis summary exported to text file: {filepath}")

    def load(self) -> pd.DataFrame:
        """Load and process domain and annotation data into a DataFrame with significance metrics.

        Every annotation in `annotations["ordered_annotations"]` appears in the output. An
        annotation with no usable domain row is filled with Domain ID -1, no members, a
        significance score of 0.0, and p/q-values of 1.0.

        Returns:
            pd.DataFrame: Processed DataFrame containing significance scores, p-values, q-values,
                and annotation member information.
        """
        log_header("Loading analysis summary")
        # Calculate enrichment and depletion q-values from the p-value matrices in `neighborhoods`
        enrichment_pvals = self.neighborhoods["enrichment_pvals"]
        depletion_pvals = self.neighborhoods["depletion_pvals"]
        enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
        depletion_qvals = self._calculate_qvalues(depletion_pvals)
        # Build the description -> column index lookup once instead of scanning the list per row
        annotation_index = self._build_annotation_index()

        # Initialize DataFrame with domain and annotation details. Explicit columns keep the
        # frame well-formed (sortable, assignable) even when there are no domains at all.
        results = pd.DataFrame(
            [
                {"Domain ID": domain_id, "Annotation": desc, "Summed Significance Score": score}
                for domain_id, info in self.graph.domain_id_to_domain_info_map.items()
                for desc, score in zip(info["full_descriptions"], info["significance_scores"])
            ],
            columns=["Domain ID", "Annotation", "Summed Significance Score"],
        )
        # Sort by Domain ID and Summed Significance Score
        results = results.sort_values(
            by=["Domain ID", "Summed Significance Score"], ascending=[True, False]
        ).reset_index(drop=True)

        # Add minimum p-values and q-values to DataFrame
        significance_columns = [
            "Enrichment P-value",
            "Enrichment Q-value",
            "Depletion P-value",
            "Depletion Q-value",
        ]
        significance_values = [
            self._get_significance_values(
                domain_id,
                desc,
                enrichment_pvals,
                depletion_pvals,
                enrichment_qvals,
                depletion_qvals,
                annotation_index=annotation_index,
            )
            for domain_id, desc in zip(results["Domain ID"], results["Annotation"])
        ]
        # Building the frame with explicit columns also works for zero rows
        results[significance_columns] = pd.DataFrame(
            significance_values, columns=significance_columns, index=results.index
        )
        # Add annotation members and their counts
        results["Annotation Members in Network"] = [
            self._get_annotation_members(desc, annotation_index=annotation_index)
            for desc in results["Annotation"]
        ]
        results["Annotation Members in Network Count"] = results[
            "Annotation Members in Network"
        ].apply(lambda members: len(members.split(";")) if members else 0)

        # Reorder columns and drop rows with NaN values (e.g. domains without nodes)
        results = (
            results[
                [
                    "Domain ID",
                    "Annotation",
                    "Annotation Members in Network",
                    "Annotation Members in Network Count",
                    "Summed Significance Score",
                    "Enrichment P-value",
                    "Enrichment Q-value",
                    "Depletion P-value",
                    "Depletion Q-value",
                ]
            ]
            .dropna()
            .reset_index(drop=True)
        )

        # Convert annotations list to a DataFrame for comparison then merge with results
        ordered_annotations = pd.DataFrame({"Annotation": self.annotations["ordered_annotations"]})
        # Merge to ensure all annotations are present, filling missing rows with defaults
        results = pd.merge(ordered_annotations, results, on="Annotation", how="left").fillna(
            {
                "Domain ID": -1,
                "Annotation Members in Network": "",
                "Annotation Members in Network Count": 0,
                "Summed Significance Score": 0.0,
                "Enrichment P-value": 1.0,
                "Enrichment Q-value": 1.0,
                "Depletion P-value": 1.0,
                "Depletion Q-value": 1.0,
            }
        )
        # Convert "Domain ID" and "Annotation Members in Network Count" to integers
        results["Domain ID"] = results["Domain ID"].astype(int)
        results["Annotation Members in Network Count"] = results[
            "Annotation Members in Network Count"
        ].astype(int)

        return results

    @staticmethod
    def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
        """Calculate q-values (FDR) for each row of a p-value matrix.

        Args:
            pvals (np.ndarray): 2D array of p-values.

        Returns:
            np.ndarray: 2D array of q-values, with FDR correction applied row-wise.
        """
        return np.apply_along_axis(lambda row: fdrcorrection(row)[1], 1, pvals)

    def _build_annotation_index(self) -> Dict[str, int]:
        """Map each annotation description to its first position in `ordered_annotations`.

        Returns:
            Dict[str, int]: Description -> column index. Duplicate descriptions keep the
                first index, matching `list.index` semantics.
        """
        annotation_index: Dict[str, int] = {}
        for idx, description in enumerate(self.annotations["ordered_annotations"]):
            annotation_index.setdefault(description, idx)
        return annotation_index

    def _lookup_annotation_idx(
        self, description: str, annotation_index: Union[Dict[str, int], None] = None
    ) -> Union[int, None]:
        """Return the column index of `description`, or None if it is not a known annotation.

        Args:
            description (str): The annotation description.
            annotation_index (Dict[str, int], None, optional): Precomputed lookup from
                `_build_annotation_index`. Built on demand when None. Defaults to None.

        Returns:
            int or None: Column index of the annotation, or None when not found.
        """
        if annotation_index is None:
            annotation_index = self._build_annotation_index()
        return annotation_index.get(description)

    def _get_significance_values(
        self,
        domain_id: int,
        description: str,
        enrichment_pvals: np.ndarray,
        depletion_pvals: np.ndarray,
        enrichment_qvals: np.ndarray,
        depletion_qvals: np.ndarray,
        annotation_index: Union[Dict[str, int], None] = None,
    ) -> Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
        """Retrieve the most significant p-values and q-values (FDR) for a given annotation.

        Args:
            domain_id (int): The domain ID associated with the annotation.
            description (str): The annotation description.
            enrichment_pvals (np.ndarray): Node-by-annotation matrix of enrichment p-values.
            depletion_pvals (np.ndarray): Node-by-annotation matrix of depletion p-values.
            enrichment_qvals (np.ndarray): Node-by-annotation matrix of enrichment q-values.
            depletion_qvals (np.ndarray): Node-by-annotation matrix of depletion q-values.
            annotation_index (Dict[str, int], None, optional): Precomputed description -> index
                lookup. Built on demand when None. Defaults to None.

        Returns:
            Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
                Minimum enrichment p-value, enrichment q-value, depletion p-value, and depletion
                q-value over the domain's nodes. All None when the description is unknown or
                the domain has no nodes.
        """
        annotation_idx = self._lookup_annotation_idx(description, annotation_index)
        if annotation_idx is None:
            return None, None, None, None  # Description not found

        node_indices = self.graph.domain_id_to_node_ids_map.get(domain_id, [])
        # len() rather than truthiness so NumPy arrays of node IDs are accepted too
        if len(node_indices) == 0:
            return None, None, None, None  # No associated nodes

        sig_p = enrichment_pvals[node_indices, annotation_idx]
        dep_p = depletion_pvals[node_indices, annotation_idx]
        sig_q = enrichment_qvals[node_indices, annotation_idx]
        dep_q = depletion_qvals[node_indices, annotation_idx]

        return (
            np.min(sig_p) if sig_p.size > 0 else None,
            np.min(sig_q) if sig_q.size > 0 else None,
            np.min(dep_p) if dep_p.size > 0 else None,
            np.min(dep_q) if dep_q.size > 0 else None,
        )

    def _get_annotation_members(
        self, description: str, annotation_index: Union[Dict[str, int], None] = None
    ) -> str:
        """Retrieve node labels associated with a given annotation description.

        Args:
            description (str): The annotation description.
            annotation_index (Dict[str, int], None, optional): Precomputed description -> index
                lookup. Built on demand when None. Defaults to None.

        Returns:
            str: ';'-separated, sorted string of labels of nodes whose matrix entry for the
                annotation equals 1. Empty when the description is unknown or has no members.
        """
        annotation_idx = self._lookup_annotation_idx(description, annotation_index)
        if annotation_idx is None:
            return ""  # Description not found

        column = self.annotations["matrix"][:, annotation_idx]
        # Sparse matrices expose .toarray(); dense arrays are used as-is
        if hasattr(column, "toarray"):
            column = column.toarray()
        column = np.asarray(column).ravel()  # Flatten to a 1D dense array
        # Rows (node IDs) present for the annotation
        nodes_present = np.where(column == 1)[0]
        label_map = self.graph.node_id_to_node_label_map
        node_labels = sorted(
            label_map[node_id] for node_id in nodes_present if node_id in label_map
        )
        return ";".join(node_labels)
risk/network/io.py CHANGED
@@ -1,10 +1,9 @@
1
1
  """
2
2
  risk/network/io
3
3
  ~~~~~~~~~~~~~~~
4
-
5
- This file contains the code for the RISK class and command-line access.
6
4
  """
7
5
 
6
+ import copy
8
7
  import json
9
8
  import os
10
9
  import pickle
@@ -155,7 +154,7 @@ class NetworkIO:
155
154
  self._log_loading(filetype)
156
155
 
157
156
  # Important: Make a copy of the network to avoid modifying the original
158
- network_copy = network.copy()
157
+ network_copy = copy.deepcopy(network)
159
158
  # Initialize the graph
160
159
  return self._initialize_graph(network_copy)
161
160
 
@@ -164,12 +163,12 @@ class NetworkIO:
164
163
  filepath: str,
165
164
  source_label: str = "source",
166
165
  target_label: str = "target",
166
+ view_name: str = "",
167
167
  compute_sphere: bool = True,
168
168
  surface_depth: float = 0.0,
169
169
  min_edges_per_node: int = 0,
170
170
  include_edge_weight: bool = True,
171
171
  weight_label: str = "weight",
172
- view_name: str = "",
173
172
  ) -> nx.Graph:
174
173
  """Load a network from a Cytoscape file.
175
174
 
@@ -177,7 +176,7 @@ class NetworkIO:
177
176
  filepath (str): Path to the Cytoscape file.
178
177
  source_label (str, optional): Source node label. Defaults to "source".
179
178
  target_label (str, optional): Target node label. Defaults to "target".
180
- view_name (str, optional): Specific view name to load. Defaults to None.
179
+ view_name (str, optional): Specific view name to load. Defaults to "".
181
180
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
182
181
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
183
182
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
@@ -214,7 +213,7 @@ class NetworkIO:
214
213
  filepath (str): Path to the Cytoscape file.
215
214
  source_label (str, optional): Source node label. Defaults to "source".
216
215
  target_label (str, optional): Target node label. Defaults to "target".
217
- view_name (str, optional): Specific view name to load. Defaults to None.
216
+ view_name (str, optional): Specific view name to load. Defaults to "".
218
217
 
219
218
  Returns:
220
219
  nx.Graph: Loaded and processed network.
@@ -297,12 +296,8 @@ class NetworkIO:
297
296
  # Add node attributes
298
297
  for node in G.nodes():
299
298
  G.nodes[node]["label"] = node
300
- G.nodes[node]["x"] = node_x_positions[
301
- node
302
- ] # Assuming you have a dict `node_x_positions` for x coordinates
303
- G.nodes[node]["y"] = node_y_positions[
304
- node
305
- ] # Assuming you have a dict `node_y_positions` for y coordinates
299
+ G.nodes[node]["x"] = node_x_positions[node]
300
+ G.nodes[node]["y"] = node_y_positions[node]
306
301
 
307
302
  # Initialize the graph
308
303
  return self._initialize_graph(G)
@@ -378,15 +373,17 @@ class NetworkIO:
378
373
  node_y_positions = {}
379
374
  for node in cyjs_data["elements"]["nodes"]:
380
375
  node_data = node["data"]
381
- node_id = node_data["id_original"]
376
+ # Use the original node ID if available, otherwise use the default ID
377
+ node_id = node_data.get("id_original", node_data.get("id"))
382
378
  node_x_positions[node_id] = node["position"]["x"]
383
379
  node_y_positions[node_id] = node["position"]["y"]
384
380
 
385
381
  # Process edges and add them to the graph
386
382
  for edge in cyjs_data["elements"]["edges"]:
387
383
  edge_data = edge["data"]
388
- source = edge_data[f"{source_label}_original"]
389
- target = edge_data[f"{target_label}_original"]
384
+ # Use the original source and target labels if available, otherwise fall back to default labels
385
+ source = edge_data.get(f"{source_label}_original", edge_data.get(source_label))
386
+ target = edge_data.get(f"{target_label}_original", edge_data.get(target_label))
390
387
  # Add the edge to the graph, optionally including weights
391
388
  if self.weight_label is not None and self.weight_label in edge_data:
392
389
  weight = float(edge_data[self.weight_label])
@@ -424,7 +421,7 @@ class NetworkIO:
424
421
  self._remove_invalid_graph_properties(G)
425
422
  # IMPORTANT: This is where the graph node labels are converted to integers
426
423
  # Make sure to perform this step after all other processing
427
- G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
424
+ G = nx.convert_node_labels_to_integers(G)
428
425
  return G
429
426
 
430
427
  def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
@@ -434,22 +431,24 @@ class NetworkIO:
434
431
  Args:
435
432
  G (nx.Graph): A NetworkX graph object.
436
433
  """
437
- # Count number of nodes and edges before cleaning
434
+ # Count the number of nodes and edges before cleaning
438
435
  num_initial_nodes = G.number_of_nodes()
439
436
  num_initial_edges = G.number_of_edges()
440
437
  # Remove self-loops to ensure correct edge count
441
- G.remove_edges_from(list(nx.selfloop_edges(G)))
438
+ G.remove_edges_from(nx.selfloop_edges(G))
442
439
  # Iteratively remove nodes with fewer edges than the threshold
443
440
  while True:
444
- nodes_to_remove = [node for node in G.nodes if G.degree(node) < self.min_edges_per_node]
441
+ nodes_to_remove = [
442
+ node
443
+ for node, degree in dict(G.degree()).items()
444
+ if degree < self.min_edges_per_node
445
+ ]
445
446
  if not nodes_to_remove:
446
- break # Exit loop if no more nodes need removal
447
+ break # Exit loop if no nodes meet the condition
447
448
  G.remove_nodes_from(nodes_to_remove)
448
449
 
449
450
  # Remove isolated nodes
450
- isolated_nodes = list(nx.isolates(G))
451
- if isolated_nodes:
452
- G.remove_nodes_from(isolated_nodes)
451
+ G.remove_nodes_from(nx.isolates(G))
453
452
 
454
453
  # Log the number of nodes and edges before and after cleaning
455
454
  num_final_nodes = G.number_of_nodes()
@@ -467,12 +466,9 @@ class NetworkIO:
467
466
  """
468
467
  missing_weights = 0
469
468
  # Assign user-defined edge weights to the "weight" attribute
470
- for _, _, data in G.edges(data=True):
471
- if self.weight_label not in data:
472
- missing_weights += 1
473
- data["weight"] = data.get(
474
- self.weight_label, 1.0
475
- ) # Default to 1.0 if 'weight' not present
469
+ nx.set_edge_attributes(G, 1.0, "weight") # Set default weight
470
+ if self.weight_label in nx.get_edge_attributes(G, self.weight_label):
471
+ nx.set_edge_attributes(G, nx.get_edge_attributes(G, self.weight_label), "weight")
476
472
 
477
473
  if self.include_edge_weight and missing_weights:
478
474
  logger.debug(f"Total edges missing weights: {missing_weights}")
@@ -482,41 +478,55 @@ class NetworkIO:
482
478
 
483
479
  Args:
484
480
  G (nx.Graph): A NetworkX graph object.
481
+
482
+ Raises:
483
+ ValueError: If a node is missing 'x', 'y', and a valid 'pos' attribute.
485
484
  """
486
- # Keep track of nodes missing labels
485
+ # Retrieve all relevant attributes in bulk
486
+ pos_attrs = nx.get_node_attributes(G, "pos")
487
+ name_attrs = nx.get_node_attributes(G, "name")
488
+ id_attrs = nx.get_node_attributes(G, "id")
489
+ # Dictionaries to hold missing or fallback attributes
490
+ x_attrs = {}
491
+ y_attrs = {}
492
+ label_attrs = {}
487
493
  nodes_with_missing_labels = []
488
494
 
489
- for node, attrs in G.nodes(data=True):
490
- # Attribute fallback for 'x' and 'y' attributes
495
+ # Iterate through nodes to validate and assign missing attributes
496
+ for node in G.nodes:
497
+ attrs = G.nodes[node]
498
+ # Validate and assign 'x' and 'y' attributes
491
499
  if "x" not in attrs or "y" not in attrs:
492
500
  if (
493
- "pos" in attrs
494
- and isinstance(attrs["pos"], (list, tuple, np.ndarray))
495
- and len(attrs["pos"]) >= 2
501
+ node in pos_attrs
502
+ and isinstance(pos_attrs[node], (list, tuple, np.ndarray))
503
+ and len(pos_attrs[node]) >= 2
496
504
  ):
497
- attrs["x"], attrs["y"] = attrs["pos"][
498
- :2
499
- ] # Use only x and y, ignoring z if present
505
+ x_attrs[node], y_attrs[node] = pos_attrs[node][:2]
500
506
  else:
501
507
  raise ValueError(
502
508
  f"Node {node} is missing 'x', 'y', and a valid 'pos' attribute."
503
509
  )
504
510
 
505
- # Attribute fallback for 'label' attribute
511
+ # Validate and assign 'label' attribute
506
512
  if "label" not in attrs:
507
- # Try alternative attribute names for label
508
- if "name" in attrs:
509
- attrs["label"] = attrs["name"]
510
- elif "id" in attrs:
511
- attrs["label"] = attrs["id"]
513
+ if node in name_attrs:
514
+ label_attrs[node] = name_attrs[node]
515
+ elif node in id_attrs:
516
+ label_attrs[node] = id_attrs[node]
512
517
  else:
513
- # Collect nodes with missing labels
518
+ # Assign node ID as label and log the missing label
519
+ label_attrs[node] = str(node)
514
520
  nodes_with_missing_labels.append(node)
515
- attrs["label"] = str(node) # Use node ID as the label
516
521
 
517
- # Issue a single warning if any labels were missing
522
+ # Batch update attributes in the graph
523
+ nx.set_node_attributes(G, x_attrs, "x")
524
+ nx.set_node_attributes(G, y_attrs, "y")
525
+ nx.set_node_attributes(G, label_attrs, "label")
526
+
527
+ # Log a warning if any labels were missing
518
528
  if nodes_with_missing_labels:
519
- total_nodes = len(G.nodes)
529
+ total_nodes = G.number_of_nodes()
520
530
  fraction_missing_labels = len(nodes_with_missing_labels) / total_nodes
521
531
  logger.warning(
522
532
  f"{len(nodes_with_missing_labels)} out of {total_nodes} nodes "
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network/plotter
3
+ ~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from risk.network.plotter.api import PlotterAPI
@@ -0,0 +1,54 @@
1
+ """
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import List, Tuple, Union
7
+
8
+ import numpy as np
9
+
10
+ from risk.log import log_header
11
+ from risk.network.graph.network import NetworkGraph
12
+ from risk.network.plotter.network import NetworkPlotter
13
+
14
+
15
class PlotterAPI:
    """Handles the loading of network plotter objects.

    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for
    plotting network graphs.
    """

    def __init__(self) -> None:
        """Initialize the PlotterAPI. This class keeps no state of its own."""
        # NOTE: the previous signature `__init__()` omitted `self`, so `PlotterAPI()` raised
        # TypeError; accepting `self` makes direct construction and super() chains work.

    def load_plotter(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: str = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> NetworkPlotter:
        """Get a NetworkPlotter object for plotting.

        Args:
            graph (NetworkGraph): The graph to plot.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
            background_color (str, optional): Background color of the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
                any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.

        Returns:
            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
        """
        log_header("Loading plotter")

        # Initialize and return a NetworkPlotter object
        return NetworkPlotter(
            graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )