risk-network 0.0.8b26__py3-none-any.whl → 0.0.9b26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +74 -47
  4. risk/annotations/io.py +47 -31
  5. risk/log/__init__.py +4 -2
  6. risk/log/{config.py → console.py} +5 -3
  7. risk/log/{params.py → parameters.py} +17 -42
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +446 -0
  10. risk/neighborhoods/community.py +255 -77
  11. risk/neighborhoods/domains.py +62 -31
  12. risk/neighborhoods/neighborhoods.py +156 -160
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +65 -57
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +194 -0
  17. risk/network/{graph.py → graph/network.py} +87 -37
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +56 -47
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +7 -4
  23. risk/network/{plot → plotter}/contour.py +22 -19
  24. risk/network/{plot → plotter}/labels.py +69 -74
  25. risk/network/{plot → plotter}/network.py +170 -34
  26. risk/network/{plot/utils/color.py → plotter/utils/colors.py} +104 -112
  27. risk/network/{plot → plotter}/utils/layout.py +8 -5
  28. risk/risk.py +11 -500
  29. risk/stats/__init__.py +8 -4
  30. risk/stats/binom.py +51 -0
  31. risk/stats/chi2.py +69 -0
  32. risk/stats/hypergeom.py +27 -17
  33. risk/stats/permutation/__init__.py +1 -1
  34. risk/stats/permutation/permutation.py +44 -38
  35. risk/stats/permutation/test_functions.py +25 -17
  36. risk/stats/poisson.py +15 -9
  37. risk/stats/stats.py +15 -13
  38. risk/stats/zscore.py +68 -0
  39. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
  40. risk_network-0.0.9b26.dist-info/RECORD +44 -0
  41. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
  42. risk/network/plot/__init__.py +0 -6
  43. risk/network/plot/plotter.py +0 -137
  44. risk_network-0.0.8b26.dist-info/RECORD +0 -37
  45. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
  46. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
1
+ """
2
+ risk/network/graph/summary
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import Any, Dict, Tuple, Union
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from statsmodels.stats.multitest import fdrcorrection
11
+
12
+ from risk.log.console import logger, log_header
13
+
14
+
15
class AnalysisSummary:
    """Handles the processing, storage, and export of network analysis results.

    The AnalysisSummary class computes FDR-corrected q-values from the enrichment and
    depletion p-value matrices, structures information on domains and annotations into a
    DataFrame, and exports the processed data in CSV, JSON, and text formats for analysis
    and reporting.
    """

    def __init__(
        self,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        graph,  # Avoid type hinting NetworkGraph to prevent circular imports
    ):
        """Initialize the AnalysisSummary object with analysis components.

        Args:
            annotations (Dict[str, Any]): Annotation data. Read keys: "ordered_annotations"
                (sequence of annotation descriptions, one per matrix column) and "matrix"
                (node-by-annotation association matrix, sparse or dense).
            neighborhoods (Dict[str, Any]): Neighborhood data. Read keys: "enrichment_pvals"
                and "depletion_pvals" (node-by-annotation p-value matrices).
            graph (NetworkGraph): Graph object providing `domain_id_to_domain_info_map`,
                `domain_id_to_node_ids_map`, and `node_id_to_node_label_map`.
        """
        self.annotations = annotations
        self.neighborhoods = neighborhoods
        self.graph = graph

    def to_csv(self, filepath: str) -> None:
        """Export significance results to a CSV file.

        Args:
            filepath (str): The path where the CSV file will be saved.
        """
        # Load results and export directly to CSV
        results = self.load()
        results.to_csv(filepath, index=False)
        logger.info(f"Analysis summary exported to CSV file: {filepath}")

    def to_json(self, filepath: str) -> None:
        """Export significance results to a JSON file.

        Args:
            filepath (str): The path where the JSON file will be saved.
        """
        # Load results and export directly to JSON
        results = self.load()
        results.to_json(filepath, orient="records", indent=4)
        logger.info(f"Analysis summary exported to JSON file: {filepath}")

    def to_txt(self, filepath: str) -> None:
        """Export significance results to a text file.

        Args:
            filepath (str): The path where the text file will be saved.
        """
        # Load results and export directly to text file
        results = self.load()
        with open(filepath, "w", encoding="utf-8") as txt_file:
            txt_file.write(results.to_string(index=False))

        logger.info(f"Analysis summary exported to text file: {filepath}")

    def load(self) -> pd.DataFrame:
        """Load and process domain and annotation data into a DataFrame with significance metrics.

        Every annotation in `annotations["ordered_annotations"]` appears in the output.
        Annotations without a usable domain row get Domain ID -1, no members, a score of 0.0,
        and p-/q-values of 1.0.

        Returns:
            pd.DataFrame: Processed DataFrame containing significance scores, p-values, q-values,
            and annotation member information.
        """
        log_header("Loading analysis summary")
        # Calculate enrichment and depletion q-values from the p-value matrices in `neighborhoods`
        enrichment_pvals = self.neighborhoods["enrichment_pvals"]
        depletion_pvals = self.neighborhoods["depletion_pvals"]
        enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
        depletion_qvals = self._calculate_qvalues(depletion_pvals)
        # Build the description -> column lookup once instead of scanning the list for every row
        annotation_to_idx = self._build_annotation_index_map()

        significance_columns = [
            "Enrichment P-value",
            "Enrichment Q-value",
            "Depletion P-value",
            "Depletion Q-value",
        ]
        # One row per (domain, annotation) pair, with the minimum p-/q-values over the
        # domain's nodes appended
        rows = [
            (
                domain_id,
                desc,
                score,
                *self._get_significance_values(
                    domain_id,
                    desc,
                    enrichment_pvals,
                    depletion_pvals,
                    enrichment_qvals,
                    depletion_qvals,
                    annotation_to_idx=annotation_to_idx,
                ),
            )
            for domain_id, info in self.graph.domain_id_to_domain_info_map.items()
            for desc, score in zip(info["full_descriptions"], info["significance_scores"])
        ]
        # Explicit column names keep the frame well-formed even when there are no domains
        results = pd.DataFrame(
            rows,
            columns=["Domain ID", "Annotation", "Summed Significance Score", *significance_columns],
        )
        # Sort by Domain ID and Summed Significance Score
        results = results.sort_values(
            by=["Domain ID", "Summed Significance Score"], ascending=[True, False]
        ).reset_index(drop=True)

        # Add annotation members and their counts
        results["Annotation Members in Network"] = [
            self._get_annotation_members(desc, annotation_to_idx=annotation_to_idx)
            for desc in results["Annotation"]
        ]
        results["Annotation Members in Network Count"] = [
            len(members.split(";")) if members else 0
            for members in results["Annotation Members in Network"]
        ]

        # Reorder columns and drop rows with missing values (unknown annotation or no domain nodes)
        results = (
            results[
                [
                    "Domain ID",
                    "Annotation",
                    "Annotation Members in Network",
                    "Annotation Members in Network Count",
                    "Summed Significance Score",
                    *significance_columns,
                ]
            ]
            .dropna()
            .reset_index(drop=True)
        )

        # Convert annotations list to a DataFrame for comparison then merge with results
        ordered_annotations = pd.DataFrame({"Annotation": self.annotations["ordered_annotations"]})
        # Merge to ensure all annotations are present, filling missing rows with defaults
        results = pd.merge(ordered_annotations, results, on="Annotation", how="left").fillna(
            {
                "Domain ID": -1,
                "Annotation Members in Network": "",
                "Annotation Members in Network Count": 0,
                "Summed Significance Score": 0.0,
                "Enrichment P-value": 1.0,
                "Enrichment Q-value": 1.0,
                "Depletion P-value": 1.0,
                "Depletion Q-value": 1.0,
            }
        )
        # Convert "Domain ID" and "Annotation Members in Network Count" to integers
        results["Domain ID"] = results["Domain ID"].astype(int)
        results["Annotation Members in Network Count"] = results[
            "Annotation Members in Network Count"
        ].astype(int)

        return results

    @staticmethod
    def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
        """Calculate q-values (FDR) for each row of a p-value matrix.

        Each row is corrected independently with statsmodels' `fdrcorrection` (default method).

        Args:
            pvals (np.ndarray): 2D array of p-values.

        Returns:
            np.ndarray: 2D array of q-values, with FDR correction applied row-wise.
        """
        return np.apply_along_axis(lambda row: fdrcorrection(row)[1], 1, pvals)

    def _build_annotation_index_map(self) -> Dict[str, int]:
        """Map each annotation description to its column index.

        The first occurrence wins for duplicated descriptions, matching `list.index` semantics.

        Returns:
            Dict[str, int]: Description -> column index in the annotation/p-value matrices.
        """
        annotation_to_idx = {}
        for idx, annotation in enumerate(self.annotations["ordered_annotations"]):
            annotation_to_idx.setdefault(annotation, idx)
        return annotation_to_idx

    def _lookup_annotation_index(
        self, description: str, annotation_to_idx: Union[Dict[str, int], None] = None
    ) -> Union[int, None]:
        """Return the column index of `description`, or None if it is not a known annotation.

        Args:
            description (str): The annotation description.
            annotation_to_idx (Dict[str, int], None, optional): Precomputed lookup; built on the
                fly when omitted.

        Returns:
            int or None: Column index, or None when the description is unknown.
        """
        if annotation_to_idx is None:
            annotation_to_idx = self._build_annotation_index_map()
        return annotation_to_idx.get(description)

    def _get_significance_values(
        self,
        domain_id: int,
        description: str,
        enrichment_pvals: np.ndarray,
        depletion_pvals: np.ndarray,
        enrichment_qvals: np.ndarray,
        depletion_qvals: np.ndarray,
        annotation_to_idx: Union[Dict[str, int], None] = None,
    ) -> Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
        """Retrieve the most significant p-values and q-values (FDR) for a given annotation.

        Args:
            domain_id (int): The domain ID associated with the annotation.
            description (str): The annotation description.
            enrichment_pvals (np.ndarray): Node-by-annotation matrix of enrichment p-values.
            depletion_pvals (np.ndarray): Node-by-annotation matrix of depletion p-values.
            enrichment_qvals (np.ndarray): Node-by-annotation matrix of enrichment q-values.
            depletion_qvals (np.ndarray): Node-by-annotation matrix of depletion q-values.
            annotation_to_idx (Dict[str, int], None, optional): Precomputed description -> column
                lookup. Built on the fly when omitted.

        Returns:
            Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
                Minimum enrichment p-value, enrichment q-value, depletion p-value, and depletion
                q-value over the domain's nodes. All four are None if the description is unknown
                or the domain has no nodes.
        """
        annotation_idx = self._lookup_annotation_index(description, annotation_to_idx)
        if annotation_idx is None:
            return None, None, None, None  # Description not found

        # Materialize as a list: a numpy array's truthiness is ambiguous (np.array([0]) is falsy)
        node_indices = list(self.graph.domain_id_to_node_ids_map.get(domain_id, []))
        if not node_indices:
            return None, None, None, None  # No associated nodes

        # node_indices is non-empty, so each selection has at least one element
        return (
            np.min(enrichment_pvals[node_indices, annotation_idx]),
            np.min(enrichment_qvals[node_indices, annotation_idx]),
            np.min(depletion_pvals[node_indices, annotation_idx]),
            np.min(depletion_qvals[node_indices, annotation_idx]),
        )

    def _get_annotation_members(
        self, description: str, annotation_to_idx: Union[Dict[str, int], None] = None
    ) -> str:
        """Retrieve node labels associated with a given annotation description.

        Args:
            description (str): The annotation description.
            annotation_to_idx (Dict[str, int], None, optional): Precomputed description -> column
                lookup. Built on the fly when omitted.

        Returns:
            str: ';'-separated, sorted string of node labels associated with the annotation.
            Empty if the description is unknown.
        """
        annotation_idx = self._lookup_annotation_index(description, annotation_to_idx)
        if annotation_idx is None:
            return ""  # Description not found

        column = self.annotations["matrix"][:, annotation_idx]
        # Sparse matrices expose `toarray`; dense arrays/matrices are used as-is
        if hasattr(column, "toarray"):
            column = column.toarray()
        column = np.asarray(column).ravel()
        # Nodes present for the annotation, labelled and sorted by label
        nodes_present = np.flatnonzero(column == 1)
        label_map = self.graph.node_id_to_node_label_map
        node_labels = sorted(
            label_map[node_id] for node_id in nodes_present if node_id in label_map
        )
        return ";".join(node_labels)
risk/network/io.py CHANGED
@@ -1,8 +1,6 @@
1
1
  """
2
2
  risk/network/io
3
3
  ~~~~~~~~~~~~~~~
4
-
5
- This file contains the code for the RISK class and command-line access.
6
4
  """
7
5
 
8
6
  import copy
@@ -165,12 +163,12 @@ class NetworkIO:
165
163
  filepath: str,
166
164
  source_label: str = "source",
167
165
  target_label: str = "target",
166
+ view_name: str = "",
168
167
  compute_sphere: bool = True,
169
168
  surface_depth: float = 0.0,
170
169
  min_edges_per_node: int = 0,
171
170
  include_edge_weight: bool = True,
172
171
  weight_label: str = "weight",
173
- view_name: str = "",
174
172
  ) -> nx.Graph:
175
173
  """Load a network from a Cytoscape file.
176
174
 
@@ -178,7 +176,7 @@ class NetworkIO:
178
176
  filepath (str): Path to the Cytoscape file.
179
177
  source_label (str, optional): Source node label. Defaults to "source".
180
178
  target_label (str, optional): Target node label. Defaults to "target".
181
- view_name (str, optional): Specific view name to load. Defaults to None.
179
+ view_name (str, optional): Specific view name to load. Defaults to "".
182
180
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
183
181
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
184
182
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
@@ -215,7 +213,7 @@ class NetworkIO:
215
213
  filepath (str): Path to the Cytoscape file.
216
214
  source_label (str, optional): Source node label. Defaults to "source".
217
215
  target_label (str, optional): Target node label. Defaults to "target".
218
- view_name (str, optional): Specific view name to load. Defaults to None.
216
+ view_name (str, optional): Specific view name to load. Defaults to "".
219
217
 
220
218
  Returns:
221
219
  nx.Graph: Loaded and processed network.
@@ -298,12 +296,8 @@ class NetworkIO:
298
296
  # Add node attributes
299
297
  for node in G.nodes():
300
298
  G.nodes[node]["label"] = node
301
- G.nodes[node]["x"] = node_x_positions[
302
- node
303
- ] # Assuming you have a dict `node_x_positions` for x coordinates
304
- G.nodes[node]["y"] = node_y_positions[
305
- node
306
- ] # Assuming you have a dict `node_y_positions` for y coordinates
299
+ G.nodes[node]["x"] = node_x_positions[node]
300
+ G.nodes[node]["y"] = node_y_positions[node]
307
301
 
308
302
  # Initialize the graph
309
303
  return self._initialize_graph(G)
@@ -379,15 +373,17 @@ class NetworkIO:
379
373
  node_y_positions = {}
380
374
  for node in cyjs_data["elements"]["nodes"]:
381
375
  node_data = node["data"]
382
- node_id = node_data["id_original"]
376
+ # Use the original node ID if available, otherwise use the default ID
377
+ node_id = node_data.get("id_original", node_data.get("id"))
383
378
  node_x_positions[node_id] = node["position"]["x"]
384
379
  node_y_positions[node_id] = node["position"]["y"]
385
380
 
386
381
  # Process edges and add them to the graph
387
382
  for edge in cyjs_data["elements"]["edges"]:
388
383
  edge_data = edge["data"]
389
- source = edge_data[f"{source_label}_original"]
390
- target = edge_data[f"{target_label}_original"]
384
+ # Use the original source and target labels if available, otherwise fall back to default labels
385
+ source = edge_data.get(f"{source_label}_original", edge_data.get(source_label))
386
+ target = edge_data.get(f"{target_label}_original", edge_data.get(target_label))
391
387
  # Add the edge to the graph, optionally including weights
392
388
  if self.weight_label is not None and self.weight_label in edge_data:
393
389
  weight = float(edge_data[self.weight_label])
@@ -425,7 +421,7 @@ class NetworkIO:
425
421
  self._remove_invalid_graph_properties(G)
426
422
  # IMPORTANT: This is where the graph node labels are converted to integers
427
423
  # Make sure to perform this step after all other processing
428
- G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
424
+ G = nx.convert_node_labels_to_integers(G)
429
425
  return G
430
426
 
431
427
  def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
@@ -435,22 +431,24 @@ class NetworkIO:
435
431
  Args:
436
432
  G (nx.Graph): A NetworkX graph object.
437
433
  """
438
- # Count number of nodes and edges before cleaning
434
+ # Count the number of nodes and edges before cleaning
439
435
  num_initial_nodes = G.number_of_nodes()
440
436
  num_initial_edges = G.number_of_edges()
441
437
  # Remove self-loops to ensure correct edge count
442
- G.remove_edges_from(list(nx.selfloop_edges(G)))
438
+ G.remove_edges_from(nx.selfloop_edges(G))
443
439
  # Iteratively remove nodes with fewer edges than the threshold
444
440
  while True:
445
- nodes_to_remove = [node for node in G.nodes if G.degree(node) < self.min_edges_per_node]
441
+ nodes_to_remove = [
442
+ node
443
+ for node, degree in dict(G.degree()).items()
444
+ if degree < self.min_edges_per_node
445
+ ]
446
446
  if not nodes_to_remove:
447
- break # Exit loop if no more nodes need removal
447
+ break # Exit loop if no nodes meet the condition
448
448
  G.remove_nodes_from(nodes_to_remove)
449
449
 
450
450
  # Remove isolated nodes
451
- isolated_nodes = list(nx.isolates(G))
452
- if isolated_nodes:
453
- G.remove_nodes_from(isolated_nodes)
451
+ G.remove_nodes_from(nx.isolates(G))
454
452
 
455
453
  # Log the number of nodes and edges before and after cleaning
456
454
  num_final_nodes = G.number_of_nodes()
@@ -468,12 +466,9 @@ class NetworkIO:
468
466
  """
469
467
  missing_weights = 0
470
468
  # Assign user-defined edge weights to the "weight" attribute
471
- for _, _, data in G.edges(data=True):
472
- if self.weight_label not in data:
473
- missing_weights += 1
474
- data["weight"] = data.get(
475
- self.weight_label, 1.0
476
- ) # Default to 1.0 if 'weight' not present
469
+ nx.set_edge_attributes(G, 1.0, "weight") # Set default weight
470
+ if self.weight_label in nx.get_edge_attributes(G, self.weight_label):
471
+ nx.set_edge_attributes(G, nx.get_edge_attributes(G, self.weight_label), "weight")
477
472
 
478
473
  if self.include_edge_weight and missing_weights:
479
474
  logger.debug(f"Total edges missing weights: {missing_weights}")
@@ -483,41 +478,55 @@ class NetworkIO:
483
478
 
484
479
  Args:
485
480
  G (nx.Graph): A NetworkX graph object.
481
+
482
+ Raises:
483
+ ValueError: If a node is missing 'x', 'y', and a valid 'pos' attribute.
486
484
  """
487
- # Keep track of nodes missing labels
485
+ # Retrieve all relevant attributes in bulk
486
+ pos_attrs = nx.get_node_attributes(G, "pos")
487
+ name_attrs = nx.get_node_attributes(G, "name")
488
+ id_attrs = nx.get_node_attributes(G, "id")
489
+ # Dictionaries to hold missing or fallback attributes
490
+ x_attrs = {}
491
+ y_attrs = {}
492
+ label_attrs = {}
488
493
  nodes_with_missing_labels = []
489
494
 
490
- for node, attrs in G.nodes(data=True):
491
- # Attribute fallback for 'x' and 'y' attributes
495
+ # Iterate through nodes to validate and assign missing attributes
496
+ for node in G.nodes:
497
+ attrs = G.nodes[node]
498
+ # Validate and assign 'x' and 'y' attributes
492
499
  if "x" not in attrs or "y" not in attrs:
493
500
  if (
494
- "pos" in attrs
495
- and isinstance(attrs["pos"], (list, tuple, np.ndarray))
496
- and len(attrs["pos"]) >= 2
501
+ node in pos_attrs
502
+ and isinstance(pos_attrs[node], (list, tuple, np.ndarray))
503
+ and len(pos_attrs[node]) >= 2
497
504
  ):
498
- attrs["x"], attrs["y"] = attrs["pos"][
499
- :2
500
- ] # Use only x and y, ignoring z if present
505
+ x_attrs[node], y_attrs[node] = pos_attrs[node][:2]
501
506
  else:
502
507
  raise ValueError(
503
508
  f"Node {node} is missing 'x', 'y', and a valid 'pos' attribute."
504
509
  )
505
510
 
506
- # Attribute fallback for 'label' attribute
511
+ # Validate and assign 'label' attribute
507
512
  if "label" not in attrs:
508
- # Try alternative attribute names for label
509
- if "name" in attrs:
510
- attrs["label"] = attrs["name"]
511
- elif "id" in attrs:
512
- attrs["label"] = attrs["id"]
513
+ if node in name_attrs:
514
+ label_attrs[node] = name_attrs[node]
515
+ elif node in id_attrs:
516
+ label_attrs[node] = id_attrs[node]
513
517
  else:
514
- # Collect nodes with missing labels
518
+ # Assign node ID as label and log the missing label
519
+ label_attrs[node] = str(node)
515
520
  nodes_with_missing_labels.append(node)
516
- attrs["label"] = str(node) # Use node ID as the label
517
521
 
518
- # Issue a single warning if any labels were missing
522
+ # Batch update attributes in the graph
523
+ nx.set_node_attributes(G, x_attrs, "x")
524
+ nx.set_node_attributes(G, y_attrs, "y")
525
+ nx.set_node_attributes(G, label_attrs, "label")
526
+
527
+ # Log a warning if any labels were missing
519
528
  if nodes_with_missing_labels:
520
- total_nodes = len(G.nodes)
529
+ total_nodes = G.number_of_nodes()
521
530
  fraction_missing_labels = len(nodes_with_missing_labels) / total_nodes
522
531
  logger.warning(
523
532
  f"{len(nodes_with_missing_labels)} out of {total_nodes} nodes "
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network/plotter
3
+ ~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from risk.network.plotter.api import PlotterAPI
@@ -0,0 +1,54 @@
1
+ """
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import List, Tuple, Union
7
+
8
+ import numpy as np
9
+
10
+ from risk.log import log_header
11
+ from risk.network.graph.network import NetworkGraph
12
+ from risk.network.plotter.network import NetworkPlotter
13
+
14
+
15
class PlotterAPI:
    """Handles the loading of network plotter objects.

    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for
    plotting network graphs.
    """

    def __init__(self) -> None:
        """Initialize the PlotterAPI. The class holds no state of its own."""
        # `self` is required here: without it, `PlotterAPI()` raises TypeError
        pass

    def load_plotter(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: str = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> NetworkPlotter:
        """Get a NetworkPlotter object for plotting.

        Args:
            graph (NetworkGraph): The graph to plot.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
            background_color (str, optional): Background color of the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color.
                If provided, it overrides any existing alpha values found in background_color.
                Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.

        Returns:
            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
        """
        log_header("Loading plotter")

        # Initialize and return a NetworkPlotter object
        return NetworkPlotter(
            graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
@@ -9,9 +9,9 @@ import matplotlib.pyplot as plt
9
9
  import numpy as np
10
10
 
11
11
  from risk.log import params
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils.color import to_rgba
14
- from risk.network.plot.utils.layout import calculate_bounding_box
12
+ from risk.network.graph.network import NetworkGraph
13
+ from risk.network.plotter.utils.colors import to_rgba
14
+ from risk.network.plotter.utils.layout import calculate_bounding_box
15
15
 
16
16
 
17
17
  class Canvas:
@@ -36,6 +36,7 @@ class Canvas:
36
36
  font: str = "Arial",
37
37
  title_color: Union[str, List, Tuple, np.ndarray] = "black",
38
38
  subtitle_color: Union[str, List, Tuple, np.ndarray] = "gray",
39
+ title_x: float = 0.5,
39
40
  title_y: float = 0.975,
40
41
  title_space_offset: float = 0.075,
41
42
  subtitle_offset: float = 0.025,
@@ -52,6 +53,7 @@ class Canvas:
52
53
  Defaults to "black".
53
54
  subtitle_color (str, List, Tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
54
55
  Defaults to "gray".
56
+ title_x (float, optional): X-axis position of the title. Defaults to 0.5.
55
57
  title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
56
58
  title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
57
59
  subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
@@ -85,7 +87,7 @@ class Canvas:
85
87
  fontsize=title_fontsize,
86
88
  color=title_color,
87
89
  fontname=font,
88
- x=0.5, # Center the title horizontally
90
+ x=title_x,
89
91
  ha="center",
90
92
  va="top",
91
93
  y=title_y,
@@ -234,6 +236,7 @@ class Canvas:
234
236
  # Scale the node coordinates if needed
235
237
  scaled_coordinates = node_coordinates * scale
236
238
  # Use the existing _draw_kde_contour method
239
+ # NOTE: This is a technical debt that should be refactored in the future - only works when inherited by NetworkPlotter
237
240
  self._draw_kde_contour(
238
241
  ax=self.ax,
239
242
  pos=scaled_coordinates,