risk-network 0.0.8b26__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +195 -118
  4. risk/annotations/io.py +47 -31
  5. risk/log/__init__.py +4 -2
  6. risk/log/{config.py → console.py} +5 -3
  7. risk/log/{params.py → parameters.py} +17 -42
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +442 -0
  10. risk/neighborhoods/community.py +324 -101
  11. risk/neighborhoods/domains.py +125 -52
  12. risk/neighborhoods/neighborhoods.py +177 -165
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +71 -89
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +200 -0
  17. risk/network/{graph.py → graph/graph.py} +90 -40
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +103 -114
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +12 -9
  23. risk/network/{plot → plotter}/contour.py +27 -24
  24. risk/network/{plot → plotter}/labels.py +73 -78
  25. risk/network/{plot → plotter}/network.py +45 -39
  26. risk/network/{plot → plotter}/plotter.py +23 -17
  27. risk/network/{plot/utils/color.py → plotter/utils/colors.py} +114 -122
  28. risk/network/{plot → plotter}/utils/layout.py +10 -7
  29. risk/risk.py +11 -500
  30. risk/stats/__init__.py +10 -4
  31. risk/stats/permutation/__init__.py +1 -1
  32. risk/stats/permutation/permutation.py +44 -38
  33. risk/stats/permutation/test_functions.py +26 -18
  34. risk/stats/{stats.py → significance.py} +17 -15
  35. risk/stats/stat_tests.py +267 -0
  36. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9.dist-info}/METADATA +31 -46
  37. risk_network-0.0.9.dist-info/RECORD +40 -0
  38. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9.dist-info}/WHEEL +1 -1
  39. risk/constants.py +0 -31
  40. risk/network/plot/__init__.py +0 -6
  41. risk/stats/hypergeom.py +0 -54
  42. risk/stats/poisson.py +0 -44
  43. risk_network-0.0.8b26.dist-info/RECORD +0 -37
  44. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9.dist-info}/LICENSE +0 -0
  45. {risk_network-0.0.8b26.dist-info → risk_network-0.0.9.dist-info}/top_level.txt +0 -0
risk/network/io.py CHANGED
@@ -1,8 +1,6 @@
1
1
  """
2
2
  risk/network/io
3
3
  ~~~~~~~~~~~~~~~
4
-
5
- This file contains the code for the RISK class and command-line access.
6
4
  """
7
5
 
8
6
  import copy
@@ -33,8 +31,6 @@ class NetworkIO:
33
31
  compute_sphere: bool = True,
34
32
  surface_depth: float = 0.0,
35
33
  min_edges_per_node: int = 0,
36
- include_edge_weight: bool = True,
37
- weight_label: str = "weight",
38
34
  ):
39
35
  """Initialize the NetworkIO class.
40
36
 
@@ -42,21 +38,15 @@ class NetworkIO:
42
38
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
43
39
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
44
40
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
45
- include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
46
- weight_label (str, optional): Label for edge weights. Defaults to "weight".
47
41
  """
48
42
  self.compute_sphere = compute_sphere
49
43
  self.surface_depth = surface_depth
50
44
  self.min_edges_per_node = min_edges_per_node
51
- self.include_edge_weight = include_edge_weight
52
- self.weight_label = weight_label
53
45
  # Log the initialization of the NetworkIO class
54
46
  params.log_network(
55
47
  compute_sphere=compute_sphere,
56
48
  surface_depth=surface_depth,
57
49
  min_edges_per_node=min_edges_per_node,
58
- include_edge_weight=include_edge_weight,
59
- weight_label=weight_label,
60
50
  )
61
51
 
62
52
  @staticmethod
@@ -65,8 +55,6 @@ class NetworkIO:
65
55
  compute_sphere: bool = True,
66
56
  surface_depth: float = 0.0,
67
57
  min_edges_per_node: int = 0,
68
- include_edge_weight: bool = True,
69
- weight_label: str = "weight",
70
58
  ) -> nx.Graph:
71
59
  """Load a network from a GPickle file.
72
60
 
@@ -75,8 +63,6 @@ class NetworkIO:
75
63
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
76
64
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
77
65
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
78
- include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
79
- weight_label (str, optional): Label for edge weights. Defaults to "weight".
80
66
 
81
67
  Returns:
82
68
  nx.Graph: Loaded and processed network.
@@ -85,8 +71,6 @@ class NetworkIO:
85
71
  compute_sphere=compute_sphere,
86
72
  surface_depth=surface_depth,
87
73
  min_edges_per_node=min_edges_per_node,
88
- include_edge_weight=include_edge_weight,
89
- weight_label=weight_label,
90
74
  )
91
75
  return networkio._load_gpickle_network(filepath=filepath)
92
76
 
@@ -116,8 +100,6 @@ class NetworkIO:
116
100
  compute_sphere: bool = True,
117
101
  surface_depth: float = 0.0,
118
102
  min_edges_per_node: int = 0,
119
- include_edge_weight: bool = True,
120
- weight_label: str = "weight",
121
103
  ) -> nx.Graph:
122
104
  """Load a NetworkX graph.
123
105
 
@@ -126,8 +108,6 @@ class NetworkIO:
126
108
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
127
109
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
128
110
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
129
- include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
130
- weight_label (str, optional): Label for edge weights. Defaults to "weight".
131
111
 
132
112
  Returns:
133
113
  nx.Graph: Loaded and processed network.
@@ -136,8 +116,6 @@ class NetworkIO:
136
116
  compute_sphere=compute_sphere,
137
117
  surface_depth=surface_depth,
138
118
  min_edges_per_node=min_edges_per_node,
139
- include_edge_weight=include_edge_weight,
140
- weight_label=weight_label,
141
119
  )
142
120
  return networkio._load_networkx_network(network=network)
143
121
 
@@ -165,12 +143,10 @@ class NetworkIO:
165
143
  filepath: str,
166
144
  source_label: str = "source",
167
145
  target_label: str = "target",
146
+ view_name: str = "",
168
147
  compute_sphere: bool = True,
169
148
  surface_depth: float = 0.0,
170
149
  min_edges_per_node: int = 0,
171
- include_edge_weight: bool = True,
172
- weight_label: str = "weight",
173
- view_name: str = "",
174
150
  ) -> nx.Graph:
175
151
  """Load a network from a Cytoscape file.
176
152
 
@@ -178,12 +154,10 @@ class NetworkIO:
178
154
  filepath (str): Path to the Cytoscape file.
179
155
  source_label (str, optional): Source node label. Defaults to "source".
180
156
  target_label (str, optional): Target node label. Defaults to "target".
181
- view_name (str, optional): Specific view name to load. Defaults to None.
157
+ view_name (str, optional): Specific view name to load. Defaults to "".
182
158
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
183
159
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
184
160
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
185
- include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
186
- weight_label (str, optional): Label for edge weights. Defaults to "weight".
187
161
 
188
162
  Returns:
189
163
  nx.Graph: Loaded and processed network.
@@ -192,8 +166,6 @@ class NetworkIO:
192
166
  compute_sphere=compute_sphere,
193
167
  surface_depth=surface_depth,
194
168
  min_edges_per_node=min_edges_per_node,
195
- include_edge_weight=include_edge_weight,
196
- weight_label=weight_label,
197
169
  )
198
170
  return networkio._load_cytoscape_network(
199
171
  filepath=filepath,
@@ -215,10 +187,13 @@ class NetworkIO:
215
187
  filepath (str): Path to the Cytoscape file.
216
188
  source_label (str, optional): Source node label. Defaults to "source".
217
189
  target_label (str, optional): Target node label. Defaults to "target".
218
- view_name (str, optional): Specific view name to load. Defaults to None.
190
+ view_name (str, optional): Specific view name to load. Defaults to "".
219
191
 
220
192
  Returns:
221
193
  nx.Graph: Loaded and processed network.
194
+
195
+ Raises:
196
+ ValueError: If no matching attribute metadata file is found.
222
197
  """
223
198
  filetype = "Cytoscape"
224
199
  # Log the loading of the Cytoscape file
@@ -260,36 +235,59 @@ class NetworkIO:
260
235
 
261
236
  # Read the node attributes (from /tables/)
262
237
  attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
263
- attribute_metadata = [
264
- os.path.join(tmp_dir, cf)
265
- for cf in cys_files
266
- if all(keyword in cf for keyword in attribute_metadata_keywords)
267
- ][0]
268
- # Load attributes file from Cytoscape as pandas data frame
269
- attribute_table = pd.read_csv(attribute_metadata, sep=",", header=None, skiprows=1)
238
+ # Use a generator to find the first matching file
239
+ attribute_metadata = next(
240
+ (
241
+ os.path.join(tmp_dir, cf)
242
+ for cf in cys_files
243
+ if all(keyword in cf for keyword in attribute_metadata_keywords)
244
+ ),
245
+ None, # Default if no file matches
246
+ )
247
+ if attribute_metadata:
248
+ # Optimize `read_csv` by leveraging proper options
249
+ attribute_table = pd.read_csv(
250
+ attribute_metadata,
251
+ sep=",",
252
+ header=None,
253
+ skiprows=1,
254
+ dtype=str, # Use specific dtypes to reduce memory usage
255
+ engine="c", # Use the C engine for parsing if compatible
256
+ low_memory=False, # Optimize memory handling for large files
257
+ )
258
+ else:
259
+ raise ValueError("No matching attribute metadata file found.")
260
+
270
261
  # Set columns
271
262
  attribute_table.columns = attribute_table.iloc[0]
272
- # Skip first four rows
263
+ # Skip first four rows, select source and target columns, and reset index
273
264
  attribute_table = attribute_table.iloc[4:, :]
274
- # Conditionally select columns based on include_edge_weight
275
- if self.include_edge_weight:
276
- attribute_table = attribute_table[[source_label, target_label, self.weight_label]]
277
- else:
265
+ try:
266
+ # Attempt to filter the attribute_table with the given labels
278
267
  attribute_table = attribute_table[[source_label, target_label]]
268
+ except KeyError as e:
269
+ # Find which key(s) caused the issue
270
+ missing_keys = [
271
+ key
272
+ for key in [source_label, target_label]
273
+ if key not in attribute_table.columns
274
+ ]
275
+ # Raise the KeyError with details about the issue and available options
276
+ available_columns = ", ".join(attribute_table.columns)
277
+ raise KeyError(
278
+ f"The column(s) '{', '.join(missing_keys)}' do not exist in the table. "
279
+ f"Available columns are: {available_columns}."
280
+ ) from e
279
281
 
280
282
  attribute_table = attribute_table.dropna().reset_index(drop=True)
283
+
281
284
  # Create a graph
282
285
  G = nx.Graph()
283
- # Add edges and nodes, conditionally including weights
286
+ # Add edges and nodes
284
287
  for _, row in attribute_table.iterrows():
285
288
  source = row[source_label]
286
289
  target = row[target_label]
287
- if self.include_edge_weight:
288
- weight = float(row[self.weight_label])
289
- G.add_edge(source, target, weight=weight)
290
- else:
291
- G.add_edge(source, target)
292
-
290
+ G.add_edge(source, target)
293
291
  if source not in G:
294
292
  G.add_node(source) # Optionally add x, y coordinates here if available
295
293
  if target not in G:
@@ -298,12 +296,8 @@ class NetworkIO:
298
296
  # Add node attributes
299
297
  for node in G.nodes():
300
298
  G.nodes[node]["label"] = node
301
- G.nodes[node]["x"] = node_x_positions[
302
- node
303
- ] # Assuming you have a dict `node_x_positions` for x coordinates
304
- G.nodes[node]["y"] = node_y_positions[
305
- node
306
- ] # Assuming you have a dict `node_y_positions` for y coordinates
299
+ G.nodes[node]["x"] = node_x_positions[node]
300
+ G.nodes[node]["y"] = node_y_positions[node]
307
301
 
308
302
  # Initialize the graph
309
303
  return self._initialize_graph(G)
@@ -321,8 +315,6 @@ class NetworkIO:
321
315
  compute_sphere: bool = True,
322
316
  surface_depth: float = 0.0,
323
317
  min_edges_per_node: int = 0,
324
- include_edge_weight: bool = True,
325
- weight_label: str = "weight",
326
318
  ) -> nx.Graph:
327
319
  """Load a network from a Cytoscape JSON (.cyjs) file.
328
320
 
@@ -333,8 +325,6 @@ class NetworkIO:
333
325
  compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
334
326
  surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
335
327
  min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
336
- include_edge_weight (bool, optional): Whether to include edge weights in calculations. Defaults to True.
337
- weight_label (str, optional): Label for edge weights. Defaults to "weight".
338
328
 
339
329
  Returns:
340
330
  NetworkX graph: Loaded and processed network.
@@ -343,8 +333,6 @@ class NetworkIO:
343
333
  compute_sphere=compute_sphere,
344
334
  surface_depth=surface_depth,
345
335
  min_edges_per_node=min_edges_per_node,
346
- include_edge_weight=include_edge_weight,
347
- weight_label=weight_label,
348
336
  )
349
337
  return networkio._load_cytoscape_json_network(
350
338
  filepath=filepath,
@@ -379,21 +367,18 @@ class NetworkIO:
379
367
  node_y_positions = {}
380
368
  for node in cyjs_data["elements"]["nodes"]:
381
369
  node_data = node["data"]
382
- node_id = node_data["id_original"]
370
+ # Use the original node ID if available, otherwise use the default ID
371
+ node_id = node_data.get("id_original", node_data.get("id"))
383
372
  node_x_positions[node_id] = node["position"]["x"]
384
373
  node_y_positions[node_id] = node["position"]["y"]
385
374
 
386
375
  # Process edges and add them to the graph
387
376
  for edge in cyjs_data["elements"]["edges"]:
388
377
  edge_data = edge["data"]
389
- source = edge_data[f"{source_label}_original"]
390
- target = edge_data[f"{target_label}_original"]
391
- # Add the edge to the graph, optionally including weights
392
- if self.weight_label is not None and self.weight_label in edge_data:
393
- weight = float(edge_data[self.weight_label])
394
- G.add_edge(source, target, weight=weight)
395
- else:
396
- G.add_edge(source, target)
378
+ # Use the original source and target labels if available, otherwise fall back to default labels
379
+ source = edge_data.get(f"{source_label}_original", edge_data.get(source_label))
380
+ target = edge_data.get(f"{target_label}_original", edge_data.get(target_label))
381
+ G.add_edge(source, target)
397
382
 
398
383
  # Ensure nodes exist in the graph and add them if not present
399
384
  if source not in G:
@@ -425,7 +410,7 @@ class NetworkIO:
425
410
  self._remove_invalid_graph_properties(G)
426
411
  # IMPORTANT: This is where the graph node labels are converted to integers
427
412
  # Make sure to perform this step after all other processing
428
- G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
413
+ G = nx.convert_node_labels_to_integers(G)
429
414
  return G
430
415
 
431
416
  def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
@@ -435,22 +420,24 @@ class NetworkIO:
435
420
  Args:
436
421
  G (nx.Graph): A NetworkX graph object.
437
422
  """
438
- # Count number of nodes and edges before cleaning
423
+ # Count the number of nodes and edges before cleaning
439
424
  num_initial_nodes = G.number_of_nodes()
440
425
  num_initial_edges = G.number_of_edges()
441
426
  # Remove self-loops to ensure correct edge count
442
- G.remove_edges_from(list(nx.selfloop_edges(G)))
427
+ G.remove_edges_from(nx.selfloop_edges(G))
443
428
  # Iteratively remove nodes with fewer edges than the threshold
444
429
  while True:
445
- nodes_to_remove = [node for node in G.nodes if G.degree(node) < self.min_edges_per_node]
430
+ nodes_to_remove = [
431
+ node
432
+ for node, degree in dict(G.degree()).items()
433
+ if degree < self.min_edges_per_node
434
+ ]
446
435
  if not nodes_to_remove:
447
- break # Exit loop if no more nodes need removal
436
+ break # Exit loop if no nodes meet the condition
448
437
  G.remove_nodes_from(nodes_to_remove)
449
438
 
450
439
  # Remove isolated nodes
451
- isolated_nodes = list(nx.isolates(G))
452
- if isolated_nodes:
453
- G.remove_nodes_from(isolated_nodes)
440
+ G.remove_nodes_from(nx.isolates(G))
454
441
 
455
442
  # Log the number of nodes and edges before and after cleaning
456
443
  num_final_nodes = G.number_of_nodes()
@@ -461,63 +448,69 @@ class NetworkIO:
461
448
  logger.debug(f"Final edge count: {num_final_edges}")
462
449
 
463
450
  def _assign_edge_weights(self, G: nx.Graph) -> None:
464
- """Assign weights to the edges in the graph.
451
+ """Assign default edge weights to the graph.
465
452
 
466
453
  Args:
467
454
  G (nx.Graph): A NetworkX graph object.
468
455
  """
469
- missing_weights = 0
470
- # Assign user-defined edge weights to the "weight" attribute
471
- for _, _, data in G.edges(data=True):
472
- if self.weight_label not in data:
473
- missing_weights += 1
474
- data["weight"] = data.get(
475
- self.weight_label, 1.0
476
- ) # Default to 1.0 if 'weight' not present
477
-
478
- if self.include_edge_weight and missing_weights:
479
- logger.debug(f"Total edges missing weights: {missing_weights}")
456
+ # Set default weight for all edges in bulk
457
+ default_weight = 1
458
+ nx.set_edge_attributes(G, default_weight, "weight")
480
459
 
481
460
  def _validate_nodes(self, G: nx.Graph) -> None:
482
461
  """Validate the graph structure and attributes with attribute fallback for positions and labels.
483
462
 
484
463
  Args:
485
464
  G (nx.Graph): A NetworkX graph object.
465
+
466
+ Raises:
467
+ ValueError: If a node is missing 'x', 'y', and a valid 'pos' attribute.
486
468
  """
487
- # Keep track of nodes missing labels
469
+ # Retrieve all relevant attributes in bulk
470
+ pos_attrs = nx.get_node_attributes(G, "pos")
471
+ name_attrs = nx.get_node_attributes(G, "name")
472
+ id_attrs = nx.get_node_attributes(G, "id")
473
+ # Dictionaries to hold missing or fallback attributes
474
+ x_attrs = {}
475
+ y_attrs = {}
476
+ label_attrs = {}
488
477
  nodes_with_missing_labels = []
489
478
 
490
- for node, attrs in G.nodes(data=True):
491
- # Attribute fallback for 'x' and 'y' attributes
479
+ # Iterate through nodes to validate and assign missing attributes
480
+ for node in G.nodes:
481
+ attrs = G.nodes[node]
482
+ # Validate and assign 'x' and 'y' attributes
492
483
  if "x" not in attrs or "y" not in attrs:
493
484
  if (
494
- "pos" in attrs
495
- and isinstance(attrs["pos"], (list, tuple, np.ndarray))
496
- and len(attrs["pos"]) >= 2
485
+ node in pos_attrs
486
+ and isinstance(pos_attrs[node], (list, tuple, np.ndarray))
487
+ and len(pos_attrs[node]) >= 2
497
488
  ):
498
- attrs["x"], attrs["y"] = attrs["pos"][
499
- :2
500
- ] # Use only x and y, ignoring z if present
489
+ x_attrs[node], y_attrs[node] = pos_attrs[node][:2]
501
490
  else:
502
491
  raise ValueError(
503
492
  f"Node {node} is missing 'x', 'y', and a valid 'pos' attribute."
504
493
  )
505
494
 
506
- # Attribute fallback for 'label' attribute
495
+ # Validate and assign 'label' attribute
507
496
  if "label" not in attrs:
508
- # Try alternative attribute names for label
509
- if "name" in attrs:
510
- attrs["label"] = attrs["name"]
511
- elif "id" in attrs:
512
- attrs["label"] = attrs["id"]
497
+ if node in name_attrs:
498
+ label_attrs[node] = name_attrs[node]
499
+ elif node in id_attrs:
500
+ label_attrs[node] = id_attrs[node]
513
501
  else:
514
- # Collect nodes with missing labels
502
+ # Assign node ID as label and log the missing label
503
+ label_attrs[node] = str(node)
515
504
  nodes_with_missing_labels.append(node)
516
- attrs["label"] = str(node) # Use node ID as the label
517
505
 
518
- # Issue a single warning if any labels were missing
506
+ # Batch update attributes in the graph
507
+ nx.set_node_attributes(G, x_attrs, "x")
508
+ nx.set_node_attributes(G, y_attrs, "y")
509
+ nx.set_node_attributes(G, label_attrs, "label")
510
+
511
+ # Log a warning if any labels were missing
519
512
  if nodes_with_missing_labels:
520
- total_nodes = len(G.nodes)
513
+ total_nodes = G.number_of_nodes()
521
514
  fraction_missing_labels = len(nodes_with_missing_labels) / total_nodes
522
515
  logger.warning(
523
516
  f"{len(nodes_with_missing_labels)} out of {total_nodes} nodes "
@@ -534,7 +527,6 @@ class NetworkIO:
534
527
  G,
535
528
  compute_sphere=self.compute_sphere,
536
529
  surface_depth=self.surface_depth,
537
- include_edge_weight=self.include_edge_weight,
538
530
  )
539
531
 
540
532
  def _log_loading(
@@ -552,9 +544,6 @@ class NetworkIO:
552
544
  logger.debug(f"Filetype: {filetype}")
553
545
  if filepath:
554
546
  logger.debug(f"Filepath: {filepath}")
555
- logger.debug(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
556
- if self.include_edge_weight:
557
- logger.debug(f"Weight label: {self.weight_label}")
558
547
  logger.debug(f"Minimum edges per node: {self.min_edges_per_node}")
559
548
  logger.debug(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
560
549
  if self.compute_sphere:
@@ -0,0 +1,6 @@
1
+ """
2
+ risk/network/plotter
3
+ ~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from risk.network.plotter.api import PlotterAPI
@@ -0,0 +1,54 @@
1
+ """
2
+ risk/network/plotter/api
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
+ """
5
+
6
+ from typing import List, Tuple, Union
7
+
8
+ import numpy as np
9
+
10
+ from risk.log import log_header
11
+ from risk.network.graph.graph import Graph
12
+ from risk.network.plotter.plotter import Plotter
13
+
14
+
15
+ class PlotterAPI:
16
+ """Handles the loading of network plotter objects.
17
+
18
+ The PlotterAPI class provides methods to load and configure Plotter objects for plotting network graphs.
19
+ """
20
+
21
+ def __init__() -> None:
22
+ pass
23
+
24
+ def load_plotter(
25
+ self,
26
+ graph: Graph,
27
+ figsize: Union[List, Tuple, np.ndarray] = (10, 10),
28
+ background_color: str = "white",
29
+ background_alpha: Union[float, None] = 1.0,
30
+ pad: float = 0.3,
31
+ ) -> Plotter:
32
+ """Get a Plotter object for plotting.
33
+
34
+ Args:
35
+ graph (Graph): The graph to plot.
36
+ figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
37
+ background_color (str, optional): Background color of the plot. Defaults to "white".
38
+ background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
39
+ any existing alpha values found in background_color. Defaults to 1.0.
40
+ pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
41
+
42
+ Returns:
43
+ Plotter: A Plotter object configured with the given parameters.
44
+ """
45
+ log_header("Loading plotter")
46
+
47
+ # Initialize and return a Plotter object
48
+ return Plotter(
49
+ graph,
50
+ figsize=figsize,
51
+ background_color=background_color,
52
+ background_alpha=background_alpha,
53
+ pad=pad,
54
+ )
@@ -1,6 +1,6 @@
1
1
  """
2
- risk/network/plot/canvas
3
- ~~~~~~~~~~~~~~~~~~~~~~~~
2
+ risk/network/plotter/canvas
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
4
  """
5
5
 
6
6
  from typing import List, Tuple, Union
@@ -9,19 +9,19 @@ import matplotlib.pyplot as plt
9
9
  import numpy as np
10
10
 
11
11
  from risk.log import params
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils.color import to_rgba
14
- from risk.network.plot.utils.layout import calculate_bounding_box
12
+ from risk.network.graph.graph import Graph
13
+ from risk.network.plotter.utils.colors import to_rgba
14
+ from risk.network.plotter.utils.layout import calculate_bounding_box
15
15
 
16
16
 
17
17
  class Canvas:
18
18
  """A class for laying out the canvas in a network graph."""
19
19
 
20
- def __init__(self, graph: NetworkGraph, ax: plt.Axes) -> None:
21
- """Initialize the Canvas with a NetworkGraph and axis for plotting.
20
+ def __init__(self, graph: Graph, ax: plt.Axes) -> None:
21
+ """Initialize the Canvas with a Graph and axis for plotting.
22
22
 
23
23
  Args:
24
- graph (NetworkGraph): The NetworkGraph object containing the network data.
24
+ graph (Graph): The Graph object containing the network data.
25
25
  ax (plt.Axes): The axis to plot the canvas on.
26
26
  """
27
27
  self.graph = graph
@@ -36,6 +36,7 @@ class Canvas:
36
36
  font: str = "Arial",
37
37
  title_color: Union[str, List, Tuple, np.ndarray] = "black",
38
38
  subtitle_color: Union[str, List, Tuple, np.ndarray] = "gray",
39
+ title_x: float = 0.5,
39
40
  title_y: float = 0.975,
40
41
  title_space_offset: float = 0.075,
41
42
  subtitle_offset: float = 0.025,
@@ -52,6 +53,7 @@ class Canvas:
52
53
  Defaults to "black".
53
54
  subtitle_color (str, List, Tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
54
55
  Defaults to "gray".
56
+ title_x (float, optional): X-axis position of the title. Defaults to 0.5.
55
57
  title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
56
58
  title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
57
59
  subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
@@ -85,7 +87,7 @@ class Canvas:
85
87
  fontsize=title_fontsize,
86
88
  color=title_color,
87
89
  fontname=font,
88
- x=0.5, # Center the title horizontally
90
+ x=title_x,
89
91
  ha="center",
90
92
  va="top",
91
93
  y=title_y,
@@ -234,6 +236,7 @@ class Canvas:
234
236
  # Scale the node coordinates if needed
235
237
  scaled_coordinates = node_coordinates * scale
236
238
  # Use the existing _draw_kde_contour method
239
+ # NOTE: This is a technical debt that should be refactored in the future - only works when inherited by Plotter
237
240
  self._draw_kde_contour(
238
241
  ax=self.ax,
239
242
  pos=scaled_coordinates,