risk-network 0.0.8b26-py3-none-any.whl → 0.0.9b26-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +2 -2
- risk/annotations/__init__.py +2 -2
- risk/annotations/annotations.py +74 -47
- risk/annotations/io.py +47 -31
- risk/log/__init__.py +4 -2
- risk/log/{config.py → console.py} +5 -3
- risk/log/{params.py → parameters.py} +17 -42
- risk/neighborhoods/__init__.py +3 -5
- risk/neighborhoods/api.py +446 -0
- risk/neighborhoods/community.py +255 -77
- risk/neighborhoods/domains.py +62 -31
- risk/neighborhoods/neighborhoods.py +156 -160
- risk/network/__init__.py +1 -3
- risk/network/geometry.py +65 -57
- risk/network/graph/__init__.py +6 -0
- risk/network/graph/api.py +194 -0
- risk/network/{graph.py → graph/network.py} +87 -37
- risk/network/graph/summary.py +254 -0
- risk/network/io.py +56 -47
- risk/network/plotter/__init__.py +6 -0
- risk/network/plotter/api.py +54 -0
- risk/network/{plot → plotter}/canvas.py +7 -4
- risk/network/{plot → plotter}/contour.py +22 -19
- risk/network/{plot → plotter}/labels.py +69 -74
- risk/network/{plot → plotter}/network.py +170 -34
- risk/network/{plot/utils/color.py → plotter/utils/colors.py} +104 -112
- risk/network/{plot → plotter}/utils/layout.py +8 -5
- risk/risk.py +11 -500
- risk/stats/__init__.py +8 -4
- risk/stats/binom.py +51 -0
- risk/stats/chi2.py +69 -0
- risk/stats/hypergeom.py +27 -17
- risk/stats/permutation/__init__.py +1 -1
- risk/stats/permutation/permutation.py +44 -38
- risk/stats/permutation/test_functions.py +25 -17
- risk/stats/poisson.py +15 -9
- risk/stats/stats.py +15 -13
- risk/stats/zscore.py +68 -0
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
- risk_network-0.0.9b26.dist-info/RECORD +44 -0
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
- risk/network/plot/__init__.py +0 -6
- risk/network/plot/plotter.py +0 -137
- risk_network-0.0.8b26.dist-info/RECORD +0 -37
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
- {risk_network-0.0.8b26.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
risk/network/graph/summary.py
ADDED
@@ -0,0 +1,254 @@
+"""
+risk/network/graph/summary
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from statsmodels.stats.multitest import fdrcorrection
+
+from risk.log.console import logger, log_header
+
+
+class AnalysisSummary:
+    """Handles the processing, storage, and export of network analysis results.
+
+    The Results class provides methods to process significance and depletion data, compute
+    FDR-corrected q-values, and structure information on domains and annotations into a
+    DataFrame. It also offers functionality to export the processed data in CSV, JSON,
+    and text formats for analysis and reporting.
+    """
+
+    def __init__(
+        self,
+        annotations: Dict[str, Any],
+        neighborhoods: Dict[str, Any],
+        graph,  # Avoid type hinting NetworkGraph to prevent circular imports
+    ):
+        """Initialize the Results object with analysis components.
+
+        Args:
+            annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
+            neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
+            graph (NetworkGraph): Graph object representing domain-to-node and node-to-label mappings.
+        """
+        self.annotations = annotations
+        self.neighborhoods = neighborhoods
+        self.graph = graph
+
+    def to_csv(self, filepath: str) -> None:
+        """Export significance results to a CSV file.
+
+        Args:
+            filepath (str): The path where the CSV file will be saved.
+        """
+        # Load results and export directly to CSV
+        results = self.load()
+        results.to_csv(filepath, index=False)
+        logger.info(f"Analysis summary exported to CSV file: {filepath}")
+
+    def to_json(self, filepath: str) -> None:
+        """Export significance results to a JSON file.
+
+        Args:
+            filepath (str): The path where the JSON file will be saved.
+        """
+        # Load results and export directly to JSON
+        results = self.load()
+        results.to_json(filepath, orient="records", indent=4)
+        logger.info(f"Analysis summary exported to JSON file: {filepath}")
+
+    def to_txt(self, filepath: str) -> None:
+        """Export significance results to a text file.
+
+        Args:
+            filepath (str): The path where the text file will be saved.
+        """
+        # Load results and export directly to text file
+        results = self.load()
+        with open(filepath, "w", encoding="utf-8") as txt_file:
+            txt_file.write(results.to_string(index=False))
+
+        logger.info(f"Analysis summary exported to text file: {filepath}")
+
+    def load(self) -> pd.DataFrame:
+        """Load and process domain and annotation data into a DataFrame with significance metrics.
+
+        Returns:
+            pd.DataFrame: Processed DataFrame containing significance scores, p-values, q-values,
+                and annotation member information.
+        """
+        log_header("Loading analysis summary")
+        # Calculate significance and depletion q-values from p-value matrices in `annotations`
+        enrichment_pvals = self.neighborhoods["enrichment_pvals"]
+        depletion_pvals = self.neighborhoods["depletion_pvals"]
+        enrichment_qvals = self._calculate_qvalues(enrichment_pvals)
+        depletion_qvals = self._calculate_qvalues(depletion_pvals)
+
+        # Initialize DataFrame with domain and annotation details
+        results = pd.DataFrame(
+            [
+                {"Domain ID": domain_id, "Annotation": desc, "Summed Significance Score": score}
+                for domain_id, info in self.graph.domain_id_to_domain_info_map.items()
+                for desc, score in zip(info["full_descriptions"], info["significance_scores"])
+            ]
+        )
+        # Sort by Domain ID and Summed Significance Score
+        results = results.sort_values(
+            by=["Domain ID", "Summed Significance Score"], ascending=[True, False]
+        ).reset_index(drop=True)
+
+        # Add minimum p-values and q-values to DataFrame
+        results[
+            [
+                "Enrichment P-value",
+                "Enrichment Q-value",
+                "Depletion P-value",
+                "Depletion Q-value",
+            ]
+        ] = results.apply(
+            lambda row: self._get_significance_values(
+                row["Domain ID"],
+                row["Annotation"],
+                enrichment_pvals,
+                depletion_pvals,
+                enrichment_qvals,
+                depletion_qvals,
+            ),
+            axis=1,
+            result_type="expand",
+        )
+        # Add annotation members and their counts
+        results["Annotation Members in Network"] = results["Annotation"].apply(
+            lambda desc: self._get_annotation_members(desc)
+        )
+        results["Annotation Members in Network Count"] = results[
+            "Annotation Members in Network"
+        ].apply(lambda x: len(x.split(";")) if x else 0)
+
+        # Reorder columns and drop rows with NaN values
+        results = (
+            results[
+                [
+                    "Domain ID",
+                    "Annotation",
+                    "Annotation Members in Network",
+                    "Annotation Members in Network Count",
+                    "Summed Significance Score",
+                    "Enrichment P-value",
+                    "Enrichment Q-value",
+                    "Depletion P-value",
+                    "Depletion Q-value",
+                ]
+            ]
+            .dropna()
+            .reset_index(drop=True)
+        )
+
+        # Convert annotations list to a DataFrame for comparison then merge with results
+        ordered_annotations = pd.DataFrame({"Annotation": self.annotations["ordered_annotations"]})
+        # Merge to ensure all annotations are present, filling missing rows with defaults
+        results = pd.merge(ordered_annotations, results, on="Annotation", how="left").fillna(
+            {
+                "Domain ID": -1,
+                "Annotation Members in Network": "",
+                "Annotation Members in Network Count": 0,
+                "Summed Significance Score": 0.0,
+                "Enrichment P-value": 1.0,
+                "Enrichment Q-value": 1.0,
+                "Depletion P-value": 1.0,
+                "Depletion Q-value": 1.0,
+            }
+        )
+        # Convert "Domain ID" and "Annotation Members in Network Count" to integers
+        results["Domain ID"] = results["Domain ID"].astype(int)
+        results["Annotation Members in Network Count"] = results[
+            "Annotation Members in Network Count"
+        ].astype(int)
+
+        return results
+
+    @staticmethod
+    def _calculate_qvalues(pvals: np.ndarray) -> np.ndarray:
+        """Calculate q-values (FDR) for each row of a p-value matrix.
+
+        Args:
+            pvals (np.ndarray): 2D array of p-values.
+
+        Returns:
+            np.ndarray: 2D array of q-values, with FDR correction applied row-wise.
+        """
+        return np.apply_along_axis(lambda row: fdrcorrection(row)[1], 1, pvals)
+
+    def _get_significance_values(
+        self,
+        domain_id: int,
+        description: str,
+        enrichment_pvals: np.ndarray,
+        depletion_pvals: np.ndarray,
+        enrichment_qvals: np.ndarray,
+        depletion_qvals: np.ndarray,
+    ) -> Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
+        """Retrieve the most significant p-values and q-values (FDR) for a given annotation.
+
+        Args:
+            domain_id (int): The domain ID associated with the annotation.
+            description (str): The annotation description.
+            enrichment_pvals (np.ndarray): Matrix of significance p-values.
+            depletion_pvals (np.ndarray): Matrix of depletion p-values.
+            enrichment_qvals (np.ndarray): Matrix of significance q-values.
+            depletion_qvals (np.ndarray): Matrix of depletion q-values.
+
+        Returns:
+            Tuple[Union[float, None], Union[float, None], Union[float, None], Union[float, None]]:
+                Minimum significance p-value, significance q-value, depletion p-value, depletion q-value.
+        """
+        try:
+            annotation_idx = self.annotations["ordered_annotations"].index(description)
+        except ValueError:
+            return None, None, None, None  # Description not found
+
+        node_indices = self.graph.domain_id_to_node_ids_map.get(domain_id, [])
+        if not node_indices:
+            return None, None, None, None  # No associated nodes
+
+        sig_p = enrichment_pvals[node_indices, annotation_idx]
+        dep_p = depletion_pvals[node_indices, annotation_idx]
+        sig_q = enrichment_qvals[node_indices, annotation_idx]
+        dep_q = depletion_qvals[node_indices, annotation_idx]
+
+        return (
+            np.min(sig_p) if sig_p.size > 0 else None,
+            np.min(sig_q) if sig_q.size > 0 else None,
+            np.min(dep_p) if dep_p.size > 0 else None,
+            np.min(dep_q) if dep_q.size > 0 else None,
+        )
+
+    def _get_annotation_members(self, description: str) -> str:
+        """Retrieve node labels associated with a given annotation description.
+
+        Args:
+            description (str): The annotation description.
+
+        Returns:
+            str: ';'-separated string of node labels that are associated with the annotation.
+        """
+        try:
+            annotation_idx = self.annotations["ordered_annotations"].index(description)
+        except ValueError:
+            return ""  # Description not found
+
+        # Get the column (safely) from the sparse matrix
+        column = self.annotations["matrix"][:, annotation_idx]
+        # Convert the column to a dense array if needed
+        column = column.toarray().ravel()  # Convert to a 1D dense array
+        # Get nodes present for the annotation and sort by node label - use np.where on the dense array
+        nodes_present = np.where(column == 1)[0]
+        node_labels = sorted(
+            self.graph.node_id_to_node_label_map[node_id]
+            for node_id in nodes_present
+            if node_id in self.graph.node_id_to_node_label_map
+        )
+        return ";".join(node_labels)
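For orientation, the new AnalysisSummary._calculate_qvalues helper applies Benjamini-Hochberg FDR correction row by row via statsmodels. A minimal standalone sketch of that pattern follows; the toy p-value matrix is illustrative and not taken from the package:

    import numpy as np
    from statsmodels.stats.multitest import fdrcorrection

    # Rows are neighborhoods, columns are annotations (shapes chosen for illustration only)
    pvals = np.array([
        [0.001, 0.20, 0.04],
        [0.03, 0.50, 0.01],
    ])

    # fdrcorrection returns (reject_flags, corrected_pvals); index [1] keeps the q-values,
    # and np.apply_along_axis corrects each row independently, as in _calculate_qvalues
    qvals = np.apply_along_axis(lambda row: fdrcorrection(row)[1], 1, pvals)
    print(qvals)  # same shape as pvals, each row corrected on its own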
risk/network/io.py
CHANGED
@@ -1,8 +1,6 @@
 """
 risk/network/io
 ~~~~~~~~~~~~~~~
-
-This file contains the code for the RISK class and command-line access.
 """
 
 import copy
@@ -165,12 +163,12 @@ class NetworkIO:
         filepath: str,
         source_label: str = "source",
         target_label: str = "target",
+        view_name: str = "",
         compute_sphere: bool = True,
         surface_depth: float = 0.0,
         min_edges_per_node: int = 0,
         include_edge_weight: bool = True,
         weight_label: str = "weight",
-        view_name: str = "",
     ) -> nx.Graph:
         """Load a network from a Cytoscape file.
 
@@ -178,7 +176,7 @@ class NetworkIO:
             filepath (str): Path to the Cytoscape file.
             source_label (str, optional): Source node label. Defaults to "source".
             target_label (str, optional): Target node label. Defaults to "target".
-            view_name (str, optional): Specific view name to load. Defaults to
+            view_name (str, optional): Specific view name to load. Defaults to "".
             compute_sphere (bool, optional): Whether to map nodes to a sphere. Defaults to True.
             surface_depth (float, optional): Surface depth for the sphere. Defaults to 0.0.
             min_edges_per_node (int, optional): Minimum number of edges per node. Defaults to 0.
@@ -215,7 +213,7 @@ class NetworkIO:
             filepath (str): Path to the Cytoscape file.
             source_label (str, optional): Source node label. Defaults to "source".
            target_label (str, optional): Target node label. Defaults to "target".
-            view_name (str, optional): Specific view name to load. Defaults to
+            view_name (str, optional): Specific view name to load. Defaults to "".
 
         Returns:
             nx.Graph: Loaded and processed network.
@@ -298,12 +296,8 @@ class NetworkIO:
         # Add node attributes
         for node in G.nodes():
             G.nodes[node]["label"] = node
-            G.nodes[node]["x"] = node_x_positions[
-                node
-            ]  # Assuming you have a dict `node_x_positions` for x coordinates
-            G.nodes[node]["y"] = node_y_positions[
-                node
-            ]  # Assuming you have a dict `node_y_positions` for y coordinates
+            G.nodes[node]["x"] = node_x_positions[node]
+            G.nodes[node]["y"] = node_y_positions[node]
 
         # Initialize the graph
         return self._initialize_graph(G)
@@ -379,15 +373,17 @@ class NetworkIO:
         node_y_positions = {}
         for node in cyjs_data["elements"]["nodes"]:
             node_data = node["data"]
-
+            # Use the original node ID if available, otherwise use the default ID
+            node_id = node_data.get("id_original", node_data.get("id"))
             node_x_positions[node_id] = node["position"]["x"]
             node_y_positions[node_id] = node["position"]["y"]
 
         # Process edges and add them to the graph
         for edge in cyjs_data["elements"]["edges"]:
             edge_data = edge["data"]
-            source
-
+            # Use the original source and target labels if available, otherwise fall back to default labels
+            source = edge_data.get(f"{source_label}_original", edge_data.get(source_label))
+            target = edge_data.get(f"{target_label}_original", edge_data.get(target_label))
             # Add the edge to the graph, optionally including weights
             if self.weight_label is not None and self.weight_label in edge_data:
                 weight = float(edge_data[self.weight_label])
@@ -425,7 +421,7 @@ class NetworkIO:
         self._remove_invalid_graph_properties(G)
         # IMPORTANT: This is where the graph node labels are converted to integers
         # Make sure to perform this step after all other processing
-        G = nx.
+        G = nx.convert_node_labels_to_integers(G)
         return G
 
     def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
@@ -435,22 +431,24 @@ class NetworkIO:
         Args:
             G (nx.Graph): A NetworkX graph object.
         """
-        # Count number of nodes and edges before cleaning
+        # Count the number of nodes and edges before cleaning
         num_initial_nodes = G.number_of_nodes()
         num_initial_edges = G.number_of_edges()
         # Remove self-loops to ensure correct edge count
-        G.remove_edges_from(
+        G.remove_edges_from(nx.selfloop_edges(G))
         # Iteratively remove nodes with fewer edges than the threshold
         while True:
-            nodes_to_remove = [
+            nodes_to_remove = [
+                node
+                for node, degree in dict(G.degree()).items()
+                if degree < self.min_edges_per_node
+            ]
             if not nodes_to_remove:
-                break  # Exit loop if no
+                break  # Exit loop if no nodes meet the condition
             G.remove_nodes_from(nodes_to_remove)
 
         # Remove isolated nodes
-
-        if isolated_nodes:
-            G.remove_nodes_from(isolated_nodes)
+        G.remove_nodes_from(nx.isolates(G))
 
         # Log the number of nodes and edges before and after cleaning
         num_final_nodes = G.number_of_nodes()
@@ -468,12 +466,9 @@ class NetworkIO:
         """
         missing_weights = 0
         # Assign user-defined edge weights to the "weight" attribute
-
-
-
-            data["weight"] = data.get(
-                self.weight_label, 1.0
-            )  # Default to 1.0 if 'weight' not present
+        nx.set_edge_attributes(G, 1.0, "weight")  # Set default weight
+        if self.weight_label in nx.get_edge_attributes(G, self.weight_label):
+            nx.set_edge_attributes(G, nx.get_edge_attributes(G, self.weight_label), "weight")
 
         if self.include_edge_weight and missing_weights:
             logger.debug(f"Total edges missing weights: {missing_weights}")
@@ -483,41 +478,55 @@ class NetworkIO:
 
         Args:
             G (nx.Graph): A NetworkX graph object.
+
+        Raises:
+            ValueError: If a node is missing 'x', 'y', and a valid 'pos' attribute.
         """
-        #
+        # Retrieve all relevant attributes in bulk
+        pos_attrs = nx.get_node_attributes(G, "pos")
+        name_attrs = nx.get_node_attributes(G, "name")
+        id_attrs = nx.get_node_attributes(G, "id")
+        # Dictionaries to hold missing or fallback attributes
+        x_attrs = {}
+        y_attrs = {}
+        label_attrs = {}
         nodes_with_missing_labels = []
 
-
-
+        # Iterate through nodes to validate and assign missing attributes
+        for node in G.nodes:
+            attrs = G.nodes[node]
+            # Validate and assign 'x' and 'y' attributes
             if "x" not in attrs or "y" not in attrs:
                 if (
-
-                    and isinstance(
-                    and len(
+                    node in pos_attrs
+                    and isinstance(pos_attrs[node], (list, tuple, np.ndarray))
+                    and len(pos_attrs[node]) >= 2
                 ):
-
-                        :2
-                    ]  # Use only x and y, ignoring z if present
+                    x_attrs[node], y_attrs[node] = pos_attrs[node][:2]
                 else:
                     raise ValueError(
                         f"Node {node} is missing 'x', 'y', and a valid 'pos' attribute."
                     )
 
-            #
+            # Validate and assign 'label' attribute
            if "label" not in attrs:
-
-
-
-
-                attrs["label"] = attrs["id"]
+                if node in name_attrs:
+                    label_attrs[node] = name_attrs[node]
+                elif node in id_attrs:
+                    label_attrs[node] = id_attrs[node]
                 else:
-                    #
+                    # Assign node ID as label and log the missing label
+                    label_attrs[node] = str(node)
                     nodes_with_missing_labels.append(node)
-                    attrs["label"] = str(node)  # Use node ID as the label
 
-            #
+        # Batch update attributes in the graph
+        nx.set_node_attributes(G, x_attrs, "x")
+        nx.set_node_attributes(G, y_attrs, "y")
+        nx.set_node_attributes(G, label_attrs, "label")
+
+        # Log a warning if any labels were missing
         if nodes_with_missing_labels:
-            total_nodes =
+            total_nodes = G.number_of_nodes()
             fraction_missing_labels = len(nodes_with_missing_labels) / total_nodes
             logger.warning(
                 f"{len(nodes_with_missing_labels)} out of {total_nodes} nodes "
risk/network/plotter/api.py
ADDED
@@ -0,0 +1,54 @@
+"""
+risk/network/graph/api
+~~~~~~~~~~~~~~~~~~~~~~
+"""
+
+from typing import List, Tuple, Union
+
+import numpy as np
+
+from risk.log import log_header
+from risk.network.graph.network import NetworkGraph
+from risk.network.plotter.network import NetworkPlotter
+
+
+class PlotterAPI:
+    """Handles the loading of network plotter objects.
+
+    The PlotterAPI class provides methods to load and configure NetworkPlotter objects for plotting network graphs.
+    """
+
+    def __init__() -> None:
+        pass
+
+    def load_plotter(
+        self,
+        graph: NetworkGraph,
+        figsize: Union[List, Tuple, np.ndarray] = (10, 10),
+        background_color: str = "white",
+        background_alpha: Union[float, None] = 1.0,
+        pad: float = 0.3,
+    ) -> NetworkPlotter:
+        """Get a NetworkPlotter object for plotting.
+
+        Args:
+            graph (NetworkGraph): The graph to plot.
+            figsize (List, Tuple, or np.ndarray, optional): Size of the figure. Defaults to (10, 10).
+            background_color (str, optional): Background color of the plot. Defaults to "white".
+            background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
+                any existing alpha values found in background_color. Defaults to 1.0.
+            pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
+
+        Returns:
+            NetworkPlotter: A NetworkPlotter object configured with the given parameters.
+        """
+        log_header("Loading plotter")
+
+        # Initialize and return a NetworkPlotter object
+        return NetworkPlotter(
+            graph,
+            figsize=figsize,
+            background_color=background_color,
+            background_alpha=background_alpha,
+            pad=pad,
+        )
risk/network/{plot → plotter}/canvas.py
RENAMED
@@ -9,9 +9,9 @@ import matplotlib.pyplot as plt
 import numpy as np
 
 from risk.log import params
-from risk.network.graph import NetworkGraph
-from risk.network.
-from risk.network.
+from risk.network.graph.network import NetworkGraph
+from risk.network.plotter.utils.colors import to_rgba
+from risk.network.plotter.utils.layout import calculate_bounding_box
 
 
 class Canvas:
@@ -36,6 +36,7 @@ class Canvas:
         font: str = "Arial",
         title_color: Union[str, List, Tuple, np.ndarray] = "black",
         subtitle_color: Union[str, List, Tuple, np.ndarray] = "gray",
+        title_x: float = 0.5,
         title_y: float = 0.975,
         title_space_offset: float = 0.075,
         subtitle_offset: float = 0.025,
@@ -52,6 +53,7 @@ class Canvas:
                 Defaults to "black".
             subtitle_color (str, List, Tuple, or np.ndarray, optional): Color of the subtitle text. Can be a string or an array of colors.
                 Defaults to "gray".
+            title_x (float, optional): X-axis position of the title. Defaults to 0.5.
             title_y (float, optional): Y-axis position of the title. Defaults to 0.975.
             title_space_offset (float, optional): Fraction of figure height to leave for the space above the plot. Defaults to 0.075.
             subtitle_offset (float, optional): Offset factor to position the subtitle below the title. Defaults to 0.025.
@@ -85,7 +87,7 @@ class Canvas:
             fontsize=title_fontsize,
             color=title_color,
             fontname=font,
-            x=
+            x=title_x,
             ha="center",
             va="top",
             y=title_y,
@@ -234,6 +236,7 @@ class Canvas:
         # Scale the node coordinates if needed
         scaled_coordinates = node_coordinates * scale
         # Use the existing _draw_kde_contour method
+        # NOTE: This is a technical debt that should be refactored in the future - only works when inherited by NetworkPlotter
         self._draw_kde_contour(
             ax=self.ax,
             pos=scaled_coordinates,