risk-network 0.0.8b18__py3-none-any.whl → 0.0.9b26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -2
- risk/annotations/__init__.py +2 -2
- risk/annotations/annotations.py +133 -72
- risk/annotations/io.py +50 -34
- risk/log/__init__.py +4 -2
- risk/log/{config.py → console.py} +5 -3
- risk/log/{params.py → parameters.py} +21 -46
- risk/neighborhoods/__init__.py +3 -5
- risk/neighborhoods/api.py +446 -0
- risk/neighborhoods/community.py +281 -96
- risk/neighborhoods/domains.py +92 -38
- risk/neighborhoods/neighborhoods.py +210 -149
- risk/network/__init__.py +1 -3
- risk/network/geometry.py +69 -58
- risk/network/graph/__init__.py +6 -0
- risk/network/graph/api.py +194 -0
- risk/network/graph/network.py +269 -0
- risk/network/graph/summary.py +254 -0
- risk/network/io.py +58 -48
- risk/network/plotter/__init__.py +6 -0
- risk/network/plotter/api.py +54 -0
- risk/network/{plot → plotter}/canvas.py +80 -26
- risk/network/{plot → plotter}/contour.py +43 -34
- risk/network/{plot → plotter}/labels.py +123 -113
- risk/network/plotter/network.py +424 -0
- risk/network/plotter/utils/colors.py +416 -0
- risk/network/plotter/utils/layout.py +94 -0
- risk/risk.py +11 -469
- risk/stats/__init__.py +8 -4
- risk/stats/binom.py +51 -0
- risk/stats/chi2.py +69 -0
- risk/stats/hypergeom.py +28 -18
- risk/stats/permutation/__init__.py +1 -1
- risk/stats/permutation/permutation.py +45 -39
- risk/stats/permutation/test_functions.py +25 -17
- risk/stats/poisson.py +17 -11
- risk/stats/stats.py +20 -16
- risk/stats/zscore.py +68 -0
- {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
- risk_network-0.0.9b26.dist-info/RECORD +44 -0
- {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
- risk/network/graph.py +0 -159
- risk/network/plot/__init__.py +0 -6
- risk/network/plot/network.py +0 -282
- risk/network/plot/plotter.py +0 -137
- risk/network/plot/utils/color.py +0 -353
- risk/network/plot/utils/layout.py +0 -53
- risk_network-0.0.8b18.dist-info/RECORD +0 -37
- {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/graph/summary
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import Any, Dict, Tuple, Union
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import pandas as pd
|
10
|
+
from statsmodels.stats.multitest import fdrcorrection
|
11
|
+
|
12
|
+
from risk.log.console import logger, log_header
|
13
|
+
|
14
|
+
|
15
|
+
class AnalysisSummary:
    """Handles the processing, storage, and export of network analysis results.

    The AnalysisSummary class provides methods to process significance and depletion data, compute
    FDR-corrected q-values, and structure information on domains and annotations into a
    DataFrame. It also offers functionality to export the processed data in CSV, JSON,
    and text formats for analysis and reporting.
    """

    def __init__(
        self,
        annotations: Dict[str, Any],
        neighborhoods: Dict[str, Any],
        graph,  # Avoid type hinting NetworkGraph to prevent circular imports
    ):
        """Initialize the AnalysisSummary object with analysis components.

        Args:
            annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
            neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
            graph (NetworkGraph): Graph object representing domain-to-node and node-to-label mappings.
        """
        self.annotations = annotations
        self.neighborhoods = neighborhoods
        self.graph = graph

    def to_csv(self, filepath: str) -> None:
        """Export significance results to a CSV file.

        Args:
            filepath (str): The path where the CSV file will be saved.
        """
        # Load results and export directly to CSV
        results = self.load()
        results.to_csv(filepath, index=False)
        logger.info(f"Analysis summary exported to CSV file: {filepath}")

    def to_json(self, filepath: str) -> None:
        """Export significance results to a JSON file.

        Args:
            filepath (str): The path where the JSON file will be saved.
        """
        # Load results and export directly to JSON
        results = self.load()
        results.to_json(filepath, orient="records", indent=4)
        logger.info(f"Analysis summary exported to JSON file: {filepath}")

    def to_txt(self, filepath: str) -> None:
        """Export significance results to a text file.

        Args:
            filepath (str): The path where the text file will be saved.
        """
        # Load results and export directly to text file
        results = self.load()
        with open(filepath, "w", encoding="utf-8") as txt_file:
            txt_file.write(results.to_string(index=False))

        logger.info(f"Analysis summary exported to text file: {filepath}")

    def load(self) -> pd.DataFrame:
        """Load and process domain and annotation data into a DataFrame with significance metrics.

        Every annotation listed in `annotations["ordered_annotations"]` appears in the output;
        annotations that are not assigned to any domain receive default values (Domain ID -1,
        no members, score 0.0, and p/q-values of 1.0).

        Returns:
            pd.DataFrame: Processed DataFrame containing significance scores, p-values, q-values,
                and annotation member information.
        """
        log_header("Loading analysis summary")
        # Calculate enrichment and depletion q-values from the p-value matrices in `neighborhoods`
        enrichment_pvals = self.neighborhoods["enrichment_pvals"]
        depletion_pvals = self.neighborhoods["depletion_pvals"]
        enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
        depletion_qvals = self._calculate_qvalues(depletion_pvals)

        # Initialize DataFrame with domain and annotation details. The columns are named
        # explicitly so that an empty domain map still yields a sortable, well-formed frame.
        results = pd.DataFrame(
            [
                {"Domain ID": domain_id, "Annotation": desc, "Summed Significance Score": score}
                for domain_id, info in self.graph.domain_id_to_domain_info_map.items()
                for desc, score in zip(info["full_descriptions"], info["significance_scores"])
            ],
            columns=["Domain ID", "Annotation", "Summed Significance Score"],
        )
        # Sort by Domain ID and Summed Significance Score
        results = results.sort_values(
            by=["Domain ID", "Summed Significance Score"], ascending=[True, False]
        ).reset_index(drop=True)

        # Add minimum p-values and q-values to DataFrame; missing values (None) become NaN
        significance_columns = [
            "Enrichment P-value",
            "Enrichment Q-value",
            "Depletion P-value",
            "Depletion Q-value",
        ]
        significance = pd.DataFrame(
            [
                self._get_significance_values(
                    domain_id,
                    description,
                    enrichment_pvals,
                    depletion_pvals,
                    enrichment_qvals,
                    depletion_qvals,
                )
                for domain_id, description in zip(results["Domain ID"], results["Annotation"])
            ],
            columns=significance_columns,
            index=results.index,
        ).astype(float)
        results = pd.concat([results, significance], axis=1)

        # Add annotation members and their counts
        results["Annotation Members in Network"] = results["Annotation"].apply(
            self._get_annotation_members
        )
        results["Annotation Members in Network Count"] = results[
            "Annotation Members in Network"
        ].apply(lambda members: len(members.split(";")) if members else 0)

        # Reorder columns and drop rows with NaN values
        results = (
            results[
                [
                    "Domain ID",
                    "Annotation",
                    "Annotation Members in Network",
                    "Annotation Members in Network Count",
                    "Summed Significance Score",
                    *significance_columns,
                ]
            ]
            .dropna()
            .reset_index(drop=True)
        )

        # Convert annotations list to a DataFrame for comparison then merge with results
        ordered_annotations = pd.DataFrame({"Annotation": self.annotations["ordered_annotations"]})
        # Merge to ensure all annotations are present, filling missing rows with defaults
        results = pd.merge(ordered_annotations, results, on="Annotation", how="left").fillna(
            {
                "Domain ID": -1,
                "Annotation Members in Network": "",
                "Annotation Members in Network Count": 0,
                "Summed Significance Score": 0.0,
                "Enrichment P-value": 1.0,
                "Enrichment Q-value": 1.0,
                "Depletion P-value": 1.0,
                "Depletion Q-value": 1.0,
            }
        )
        # Convert "Domain ID" and "Annotation Members in Network Count" to integers
        results["Domain ID"] = results["Domain ID"].astype(int)
        results["Annotation Members in Network Count"] = results[
            "Annotation Members in Network Count"
        ].astype(int)

        return results

    @staticmethod
    def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
        """Calculate q-values (FDR) for each row of a p-value matrix.

        Args:
            pvals (np.ndarray): 2D array of p-values.

        Returns:
            np.ndarray: 2D array of q-values, with FDR correction applied row-wise.
        """
        return np.apply_along_axis(lambda row: fdrcorrection(row)[1], 1, pvals)

    def _get_significance_values(
        self,
        domain_id: int,
        description: str,
        enrichment_pvals: np.ndarray,
        depletion_pvals: np.ndarray,
        enrichment_qvals: np.ndarray,
        depletion_qvals: np.ndarray,
    ) -> Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
        """Retrieve the most significant p-values and q-values (FDR) for a given annotation.

        Args:
            domain_id (int): The domain ID associated with the annotation.
            description (str): The annotation description.
            enrichment_pvals (np.ndarray): Matrix of enrichment p-values.
            depletion_pvals (np.ndarray): Matrix of depletion p-values.
            enrichment_qvals (np.ndarray): Matrix of enrichment q-values.
            depletion_qvals (np.ndarray): Matrix of depletion q-values.

        Returns:
            Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
                Minimum enrichment p-value, enrichment q-value, depletion p-value, depletion q-value.
                All four are None when the description is unknown or the domain has no nodes.
        """
        try:
            annotation_idx = self.annotations["ordered_annotations"].index(description)
        except ValueError:
            return None, None, None, None  # Description not found

        node_indices = self.graph.domain_id_to_node_ids_map.get(domain_id, [])
        # Use len() rather than truthiness: a NumPy array of node IDs has no unambiguous
        # truth value, and a single-element array holding node 0 would be read as empty
        if len(node_indices) == 0:
            return None, None, None, None  # No associated nodes

        sig_p = enrichment_pvals[node_indices, annotation_idx]
        dep_p = depletion_pvals[node_indices, annotation_idx]
        sig_q = enrichment_qvals[node_indices, annotation_idx]
        dep_q = depletion_qvals[node_indices, annotation_idx]

        return (
            np.min(sig_p) if sig_p.size > 0 else None,
            np.min(sig_q) if sig_q.size > 0 else None,
            np.min(dep_p) if dep_p.size > 0 else None,
            np.min(dep_q) if dep_q.size > 0 else None,
        )

    def _get_annotation_members(self, description: str) -> str:
        """Retrieve node labels associated with a given annotation description.

        Args:
            description (str): The annotation description.

        Returns:
            str: ';'-separated string of node labels that are associated with the annotation.
                Empty when the description is unknown or no labeled node carries the annotation.
        """
        try:
            annotation_idx = self.annotations["ordered_annotations"].index(description)
        except ValueError:
            return ""  # Description not found

        # Get the annotation column from the (sparse or dense) matrix
        column = self.annotations["matrix"][:, annotation_idx]
        # Densify only when the column is sparse; dense input passes straight through
        if hasattr(column, "toarray"):
            column = column.toarray()
        column = np.asarray(column).ravel()  # Always work on a 1D dense array
        # Get nodes present for the annotation and sort by node label
        nodes_present = np.where(column == 1)[0]
        label_map = self.graph.node_id_to_node_label_map
        node_labels = sorted(
            label_map[node_id] for node_id in nodes_present if node_id in label_map
        )
        # Stringify on output so non-string labels do not break the join
        return ";".join(str(label) for label in node_labels)
|
risk/network/io.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
1
|
"""
|
2
2
|
risk/network/io
|
3
3
|
~~~~~~~~~~~~~~~
|
4
|
-
|
5
|
-
This file contains the code for the RISK class and command-line access.
|
6
4
|
"""
|
7
5
|
|
6
|
+
import copy
|
8
7
|
import json
|
9
8
|
import os
|
10
9
|
import pickle
|
@@ -155,7 +154,7 @@ class NetworkIO:
|
|
155
154
|
self._log_loading(filetype)
|
156
155
|
|
157
156
|
# Important: Make a copy of the network to avoid modifying the original
|
158
|
-
network_copy =
|
157
|
+
network_copy = copy.deepcopy(network)
|
159
158
|
# Initialize the graph
|
160
159
|
return self._initialize_graph(network_copy)
|
161
160
|
|
@@ -164,12 +163,12 @@ class NetworkIO:
|
|
164
163
|
filepath: str,
|
165
164
|
source_label: str = "source",
|
166
165
|
target_label: str = "target",
|
166
|
+
view_name: str = "",
|
167
167
|
compute_sphere: bool = True,
|
168
168
|
surface_depth: float = 0.0,
|
169
169
|
min_edges_per_node: int = 0,
|
170
170
|
include_edge_weight: bool = True,
|
171
171
|
weight_label: str = "weight",
|
172
|
-
view_name: str = "",
|
173
172
|
) -> nx.Graph:
|
174
173
|
"""Load a network from a Cytoscape file.
|
175
174
|
|
@@ -177,7 +176,7 @@ class NetworkIO:
|
|
177
176
|
filepath (str): Path to the Cytoscape file.
|
178
177
|
source_label (str, optional): Source node label. Defaults to "source".
|
179
178
|
target_label (str, optional): Target node label. Defaults to "target".
|
180
|
-
view_name (str, optional): Specific view name to load. Defaults to
|
179
|
+
view_name (str, optional): Specific view name to load. Defaults to "".
|
181
180
|
compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
|
182
181
|
surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
|
183
182
|
min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
|
@@ -214,7 +213,7 @@ class NetworkIO:
|
|
214
213
|
filepath (str): Path to the Cytoscape file.
|
215
214
|
source_label (str, optional): Source node label. Defaults to "source".
|
216
215
|
target_label (str, optional): Target node label. Defaults to "target".
|
217
|
-
view_name (str, optional): Specific view name to load. Defaults to
|
216
|
+
view_name (str, optional): Specific view name to load. Defaults to "".
|
218
217
|
|
219
218
|
Returns:
|
220
219
|
nx.Graph: Loaded and processed network.
|
@@ -297,12 +296,8 @@ class NetworkIO:
|
|
297
296
|
# Add node attributes
|
298
297
|
for node in G.nodes():
|
299
298
|
G.nodes[node]["label"] = node
|
300
|
-
G.nodes[node]["x"] = node_x_positions[
|
301
|
-
|
302
|
-
] # Assuming you have a dict `node_x_positions` for x coordinates
|
303
|
-
G.nodes[node]["y"] = node_y_positions[
|
304
|
-
node
|
305
|
-
] # Assuming you have a dict `node_y_positions` for y coordinates
|
299
|
+
G.nodes[node]["x"] = node_x_positions[node]
|
300
|
+
G.nodes[node]["y"] = node_y_positions[node]
|
306
301
|
|
307
302
|
# Initialize the graph
|
308
303
|
return self._initialize_graph(G)
|
@@ -378,15 +373,17 @@ class NetworkIO:
|
|
378
373
|
node_y_positions = {}
|
379
374
|
for node in cyjs_data["elements"]["nodes"]:
|
380
375
|
node_data = node["data"]
|
381
|
-
|
376
|
+
# Use the original node ID if available, otherwise use the default ID
|
377
|
+
node_id = node_data.get("id_original", node_data.get("id"))
|
382
378
|
node_x_positions[node_id] = node["position"]["x"]
|
383
379
|
node_y_positions[node_id] = node["position"]["y"]
|
384
380
|
|
385
381
|
# Process edges and add them to the graph
|
386
382
|
for edge in cyjs_data["elements"]["edges"]:
|
387
383
|
edge_data = edge["data"]
|
388
|
-
source
|
389
|
-
|
384
|
+
# Use the original source and target labels if available, otherwise fall back to default labels
|
385
|
+
source = edge_data.get(f"{source_label}_original", edge_data.get(source_label))
|
386
|
+
target = edge_data.get(f"{target_label}_original", edge_data.get(target_label))
|
390
387
|
# Add the edge to the graph, optionally including weights
|
391
388
|
if self.weight_label is not None and self.weight_label in edge_data:
|
392
389
|
weight = float(edge_data[self.weight_label])
|
@@ -424,7 +421,7 @@ class NetworkIO:
|
|
424
421
|
self._remove_invalid_graph_properties(G)
|
425
422
|
# IMPORTANT: This is where the graph node labels are converted to integers
|
426
423
|
# Make sure to perform this step after all other processing
|
427
|
-
G = nx.
|
424
|
+
G = nx.convert_node_labels_to_integers(G)
|
428
425
|
return G
|
429
426
|
|
430
427
|
def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
|
@@ -434,22 +431,24 @@ class NetworkIO:
|
|
434
431
|
Args:
|
435
432
|
G (nx.Graph): A NetworkX graph object.
|
436
433
|
"""
|
437
|
-
# Count number of nodes and edges before cleaning
|
434
|
+
# Count the number of nodes and edges before cleaning
|
438
435
|
num_initial_nodes = G.number_of_nodes()
|
439
436
|
num_initial_edges = G.number_of_edges()
|
440
437
|
# Remove self-loops to ensure correct edge count
|
441
|
-
G.remove_edges_from(
|
438
|
+
G.remove_edges_from(nx.selfloop_edges(G))
|
442
439
|
# Iteratively remove nodes with fewer edges than the threshold
|
443
440
|
while True:
|
444
|
-
nodes_to_remove = [
|
441
|
+
nodes_to_remove = [
|
442
|
+
node
|
443
|
+
for node, degree in dict(G.degree()).items()
|
444
|
+
if degree < self.min_edges_per_node
|
445
|
+
]
|
445
446
|
if not nodes_to_remove:
|
446
|
-
break # Exit loop if no
|
447
|
+
break # Exit loop if no nodes meet the condition
|
447
448
|
G.remove_nodes_from(nodes_to_remove)
|
448
449
|
|
449
450
|
# Remove isolated nodes
|
450
|
-
|
451
|
-
if isolated_nodes:
|
452
|
-
G.remove_nodes_from(isolated_nodes)
|
451
|
+
G.remove_nodes_from(nx.isolates(G))
|
453
452
|
|
454
453
|
# Log the number of nodes and edges before and after cleaning
|
455
454
|
num_final_nodes = G.number_of_nodes()
|
@@ -467,12 +466,9 @@ class NetworkIO:
|
|
467
466
|
"""
|
468
467
|
missing_weights = 0
|
469
468
|
# Assign user-defined edge weights to the "weight" attribute
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
data["weight"] = data.get(
|
474
|
-
self.weight_label, 1.0
|
475
|
-
) # Default to 1.0 if 'weight' not present
|
469
|
+
nx.set_edge_attributes(G, 1.0, "weight") # Set default weight
|
470
|
+
if self.weight_label in nx.get_edge_attributes(G, self.weight_label):
|
471
|
+
nx.set_edge_attributes(G, nx.get_edge_attributes(G, self.weight_label), "weight")
|
476
472
|
|
477
473
|
if self.include_edge_weight and missing_weights:
|
478
474
|
logger.debug(f"Total edges missing weights: {missing_weights}")
|
@@ -482,41 +478,55 @@ class NetworkIO:
|
|
482
478
|
|
483
479
|
Args:
|
484
480
|
G (nx.Graph): A NetworkX graph object.
|
481
|
+
|
482
|
+
Raises:
|
483
|
+
ValueError: If a node is missing 'x', 'y', and a valid 'pos' attribute.
|
485
484
|
"""
|
486
|
-
#
|
485
|
+
# Retrieve all relevant attributes in bulk
|
486
|
+
pos_attrs = nx.get_node_attributes(G, "pos")
|
487
|
+
name_attrs = nx.get_node_attributes(G, "name")
|
488
|
+
id_attrs = nx.get_node_attributes(G, "id")
|
489
|
+
# Dictionaries to hold missing or fallback attributes
|
490
|
+
x_attrs = {}
|
491
|
+
y_attrs = {}
|
492
|
+
label_attrs = {}
|
487
493
|
nodes_with_missing_labels = []
|
488
494
|
|
489
|
-
|
490
|
-
|
495
|
+
# Iterate through nodes to validate and assign missing attributes
|
496
|
+
for node in G.nodes:
|
497
|
+
attrs = G.nodes[node]
|
498
|
+
# Validate and assign 'x' and 'y' attributes
|
491
499
|
if "x" not in attrs or "y" not in attrs:
|
492
500
|
if (
|
493
|
-
|
494
|
-
and isinstance(
|
495
|
-
and len(
|
501
|
+
node in pos_attrs
|
502
|
+
and isinstance(pos_attrs[node], (list, tuple, np.ndarray))
|
503
|
+
and len(pos_attrs[node]) >= 2
|
496
504
|
):
|
497
|
-
|
498
|
-
:2
|
499
|
-
] # Use only x and y, ignoring z if present
|
505
|
+
x_attrs[node], y_attrs[node] = pos_attrs[node][:2]
|
500
506
|
else:
|
501
507
|
raise ValueError(
|
502
508
|
f"Node {node} is missing 'x', 'y', and a valid 'pos' attribute."
|
503
509
|
)
|
504
510
|
|
505
|
-
#
|
511
|
+
# Validate and assign 'label' attribute
|
506
512
|
if "label" not in attrs:
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
attrs["label"] = attrs["id"]
|
513
|
+
if node in name_attrs:
|
514
|
+
label_attrs[node] = name_attrs[node]
|
515
|
+
elif node in id_attrs:
|
516
|
+
label_attrs[node] = id_attrs[node]
|
512
517
|
else:
|
513
|
-
#
|
518
|
+
# Assign node ID as label and log the missing label
|
519
|
+
label_attrs[node] = str(node)
|
514
520
|
nodes_with_missing_labels.append(node)
|
515
|
-
attrs["label"] = str(node) # Use node ID as the label
|
516
521
|
|
517
|
-
#
|
522
|
+
# Batch update attributes in the graph
|
523
|
+
nx.set_node_attributes(G, x_attrs, "x")
|
524
|
+
nx.set_node_attributes(G, y_attrs, "y")
|
525
|
+
nx.set_node_attributes(G, label_attrs, "label")
|
526
|
+
|
527
|
+
# Log a warning if any labels were missing
|
518
528
|
if nodes_with_missing_labels:
|
519
|
-
total_nodes =
|
529
|
+
total_nodes = G.number_of_nodes()
|
520
530
|
fraction_missing_labels = len(nodes_with_missing_labels) / total_nodes
|
521
531
|
logger.warning(
|
522
532
|
f"{len(nodes_with_missing_labels)} out of {total_nodes} nodes "
|
@@ -0,0 +1,54 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/plotter/api
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
from typing import List, Tuple, Union
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from risk.log import log_header
|
11
|
+
from risk.network.graph.network import NetworkGraph
|
12
|
+
from risk.network.plotter.network import NetworkPlotter
|
13
|
+
|
14
|
+
|
15
|
+
class PlotterAPI:
    """Handles the loading of network plotter objects.

    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
    """

    def __init__(self) -> None:
        # `self` is required here: without it, instantiating the class raises TypeError
        pass

    def load_plotter(
        self,
        graph: NetworkGraph,
        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
        background_color: str = "white",
        background_alpha: Union[float, None] = 1.0,
        pad: float = 0.3,
    ) -> NetworkPlotter:
        """Get a NetworkPlotter object for plotting.

        Args:
            graph (NetworkGraph): The graph to plot.
            figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
            background_color (str, optional): Background color of the plot. Defaults to "white".
            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
                any existing alpha values found in background_color. Defaults to 1.0.
            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.

        Returns:
            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
        """
        log_header("Loading plotter")

        # Initialize and return a NetworkPlotter object
        return NetworkPlotter(
            graph,
            figsize=figsize,
            background_color=background_color,
            background_alpha=background_alpha,
            pad=pad,
        )
|