risk-network 0.0.3b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +13 -0
- risk/annotations/__init__.py +7 -0
- risk/annotations/annotations.py +259 -0
- risk/annotations/io.py +183 -0
- risk/constants.py +31 -0
- risk/log/__init__.py +9 -0
- risk/log/console.py +16 -0
- risk/log/params.py +198 -0
- risk/neighborhoods/__init__.py +10 -0
- risk/neighborhoods/community.py +189 -0
- risk/neighborhoods/domains.py +257 -0
- risk/neighborhoods/neighborhoods.py +319 -0
- risk/network/__init__.py +8 -0
- risk/network/geometry.py +165 -0
- risk/network/graph.py +280 -0
- risk/network/io.py +326 -0
- risk/network/plot.py +795 -0
- risk/risk.py +382 -0
- risk/stats/__init__.py +6 -0
- risk/stats/permutation.py +88 -0
- risk/stats/stats.py +447 -0
- risk_network-0.0.3b1.dist-info/LICENSE +674 -0
- risk_network-0.0.3b1.dist-info/METADATA +751 -0
- risk_network-0.0.3b1.dist-info/RECORD +26 -0
- risk_network-0.0.3b1.dist-info/WHEEL +5 -0
- risk_network-0.0.3b1.dist-info/top_level.txt +1 -0
risk/network/graph.py
ADDED
@@ -0,0 +1,280 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/graph
|
3
|
+
~~~~~~~~~~~~~~~~~~
|
4
|
+
"""
|
5
|
+
|
6
|
+
import random
|
7
|
+
from collections import defaultdict
|
8
|
+
from typing import Any, Dict, List, Tuple
|
9
|
+
|
10
|
+
import networkx as nx
|
11
|
+
import numpy as np
|
12
|
+
import pandas as pd
|
13
|
+
import matplotlib
|
14
|
+
import matplotlib.cm as cm
|
15
|
+
|
16
|
+
|
17
|
+
class NetworkGraph:
    """Hold a processed network together with its domain and annotation tables.

    On construction the input graph is passed through the sphere-to-plane unfolding
    helper and its 2D node coordinates are cached as an array. Lookup tables are built
    from the domain DataFrames, and per-node RGBA colors can be generated from each
    node's primary domain, scaled by the node's enrichment sum.
    """

    def __init__(
        self,
        network: nx.Graph,
        top_annotations: pd.DataFrame,
        domains: pd.DataFrame,
        trimmed_domains: pd.DataFrame,
        node_label_to_id_map: Dict[str, Any],
        node_enrichment_sums: np.ndarray,
    ):
        """Store the inputs, derive the lookup tables, and unfold the network.

        Args:
            network (nx.Graph): The network graph.
            top_annotations (pd.DataFrame): Annotations data for the network nodes.
            domains (pd.DataFrame): Domain data; must provide a 'primary domain' column.
            trimmed_domains (pd.DataFrame): Trimmed domain data; must provide a 'label' column.
            node_label_to_id_map (dict): Mapping from node labels to node IDs.
            node_enrichment_sums (np.ndarray): Enrichment sums for the nodes.
        """
        # Inputs kept verbatim
        self.top_annotations = top_annotations
        self.domains = domains
        self.trimmed_domains = trimmed_domains
        self.node_label_to_id_map = node_label_to_id_map
        self.node_enrichment_sums = node_enrichment_sums
        # Lookup tables derived from the domain DataFrames
        self.domain_to_nodes = self._create_domain_to_nodes_map(domains)
        self.trimmed_domain_to_term = self._create_domain_to_term_map(trimmed_domains)
        # NOTE: self.G and self.node_coordinates are populated by _initialize_network
        self.G = None
        self.node_coordinates = None
        self._initialize_network(network)

    def _create_domain_to_nodes_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
        """Group node identifiers by their primary domain.

        Args:
            domains (pd.DataFrame): Domain table whose (unnamed) index identifies the nodes
                and whose 'primary domain' column gives each node's domain.

        Returns:
            dict: A defaultdict(list) mapping each domain ID to the list of its nodes.
        """
        # reset_index() exposes the node index as an "index" column
        node_table = domains.reset_index()[["index", "primary domain"]]
        primary_domain_by_node = node_table.set_index("index")["primary domain"].to_dict()
        grouped = defaultdict(list)
        for node_id, domain_id in primary_domain_by_node.items():
            grouped[domain_id].append(node_id)

        return grouped

    def _create_domain_to_term_map(self, trimmed_domains: pd.DataFrame) -> Dict[str, Any]:
        """Pair every domain ID with its label.

        Args:
            trimmed_domains (pd.DataFrame): Table indexed by domain ID with a 'label' column.

        Returns:
            dict: Mapping from domain ID to the corresponding label.
        """
        domain_ids = trimmed_domains.index
        terms = trimmed_domains["label"]
        return dict(zip(domain_ids, terms))

    def _initialize_network(self, G: nx.Graph) -> None:
        """Unfold the network to 2D and cache its node coordinates.

        Args:
            G (nx.Graph): The input network graph; nodes carrying a 'z' attribute are unfolded.
        """
        # The unfolding helper returns the graph it was given, updated in place
        self.G = _unfold_sphere_to_plane(G)
        # Coordinates come from the unfolded graph
        self.node_coordinates = _extract_node_coordinates(self.G)

    def get_domain_colors(
        self, min_scale: float = 0.8, max_scale: float = 1.0, random_seed: int = 888, **kwargs
    ) -> np.ndarray:
        """Build per-node RGBA colors from the nodes' primary domains.

        A color is chosen for every domain, each node receives the color of its primary
        domain, and the RGB channels are then scaled according to the node enrichment sums.

        Args:
            min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
            max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
            random_seed (int, optional): Seed for random number generation. Defaults to 888.
            **kwargs: Additional keyword arguments forwarded to the color generation.

        Returns:
            np.ndarray: Array of transformed RGBA colors, one row per node.
        """
        per_domain = self._get_domain_colors(**kwargs, random_seed=random_seed)
        per_node = self._get_composite_node_colors(per_domain)
        return _transform_colors(
            per_node,
            self.node_enrichment_sums,
            min_scale=min_scale,
            max_scale=max_scale,
        )

    def _get_composite_node_colors(self, domain_colors: np.ndarray) -> np.ndarray:
        """Assign every node the color of the domain it belongs to.

        Args:
            domain_colors (np.ndarray): Colors indexed by domain ID (looked up with the keys
                of ``self.domain_to_nodes``).

        Returns:
            np.ndarray: Array of shape (number of nodes, 4); rows of nodes that belong to no
                domain stay all-zero.
        """
        # One RGBA row per node, zero-initialized
        node_rgba = np.zeros((len(self.node_coordinates), 4))
        for domain_id, members in self.domain_to_nodes.items():
            rgba = domain_colors[domain_id]
            for member in members:
                node_rgba[member] = rgba

        return node_rgba

    def _get_domain_colors(self, **kwargs) -> Dict[str, Any]:
        """Get colors for each domain.

        Args:
            **kwargs: Keyword arguments forwarded to the color generation.

        Returns:
            dict: Mapping from domain key to an RGBA color.
        """
        # Only integer-named columns of the domains table count as domains
        integer_columns = [
            column for column in self.domains.columns if isinstance(column, (int, np.integer))
        ]
        palette = _get_colors(**kwargs, num_colors_to_generate=len(integer_columns))
        # zip() stops at the shorter of the two sequences
        return dict(zip(self.domain_to_nodes.keys(), palette))
|
169
|
+
|
170
|
+
|
171
|
+
def _transform_colors(
    colors: np.ndarray, enrichment_sums: np.ndarray, min_scale: float = 0.8, max_scale: float = 1.0
) -> np.ndarray:
    """Scale the RGB channels of each color according to its enrichment sum.

    The enrichment sums are log-transformed with log1p, divided by their maximum, and
    mapped linearly onto [min_scale, max_scale]; each row's RGB values are multiplied by
    the resulting factor. The alpha channel is left untouched.

    Args:
        colors (np.ndarray): An array of RGBA colors, one row per enrichment sum.
            It is modified in place.
        enrichment_sums (np.ndarray): An array of enrichment sums corresponding to the colors.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.

    Returns:
        np.ndarray: The same ``colors`` array with adjusted RGB intensities.
    """
    if min_scale == max_scale:
        # Retained from the original implementation so numeric output is unchanged
        min_scale = max_scale - 10e-6

    log_enrichment_sums = np.log1p(enrichment_sums)  # log1p keeps zero sums finite
    if np.size(log_enrichment_sums) == 0:
        # Nothing to scale; np.max would raise on an empty array
        return colors

    peak = np.max(log_enrichment_sums)
    if peak == 0:
        # All sums are zero: dividing would yield 0/0 = NaN, so use the lowest intensity
        normalized_sums = np.zeros_like(log_enrichment_sums, dtype=float)
    else:
        normalized_sums = log_enrichment_sums / peak
    # Map the normalized sums onto the requested intensity range
    scaled_sums = min_scale + (max_scale - min_scale) * normalized_sums
    # Apply one factor per row to the RGB columns only (alpha stays as is)
    colors[:, :3] = colors[:, :3] * np.reshape(scaled_sums, (-1, 1))

    return colors
|
198
|
+
|
199
|
+
|
200
|
+
def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
    """Replace 3D node coordinates with 2D ones by unfolding a sphere onto a plane.

    Only nodes that carry a 'z' attribute are touched; the graph is modified in place.

    Args:
        G (nx.Graph): A network graph whose nodes may have 'x', 'y', and 'z' attributes.

    Returns:
        nx.Graph: The same graph object; converted nodes keep only 'x' and 'y' coordinates.
    """
    two_pi = 2 * np.pi
    for node_id in G.nodes():
        attrs = G.nodes[node_id]
        if "z" not in attrs:
            # Node has no 3D position to unfold
            continue

        x, y, z = attrs["x"], attrs["y"], attrs["z"]
        # Cartesian -> spherical: radius, azimuth (theta), polar angle (phi)
        radius = np.sqrt(x**2 + y**2 + z**2)
        azimuth = np.arctan2(y, x)
        polar = np.arccos(z / radius)

        # Azimuth shifted and normalized to [0, 1], then rotated by half a turn with wrap-around
        planar_x = (azimuth + np.pi) / two_pi
        if planar_x < 0.5:
            planar_x = planar_x + 0.5
        else:
            planar_x = planar_x - 0.5
        # Polar angle reflected and normalized to [0, 1]
        planar_y = (np.pi - polar) / np.pi

        # Store the 2D position (y is negated) and drop the obsolete 'z'
        attrs["x"] = planar_x
        attrs["y"] = -planar_y
        del attrs["z"]

    return G
|
229
|
+
|
230
|
+
|
231
|
+
def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
    """Collect the 'x' and 'y' attributes of every node into one array.

    Args:
        G (nx.Graph): The network graph with node coordinates.

    Returns:
        np.ndarray: Array of node coordinates with shape (num_nodes, 2), rows in node order.
    """
    # Both data views iterate the nodes in the same order, so they can be paired directly
    xy_rows = [
        np.array([x_value, y_value])
        for (_, x_value), (_, y_value) in zip(G.nodes.data("x"), G.nodes.data("y"))
    ]
    # Stack the per-node [x, y] pairs row by row
    return np.vstack(xy_rows)
|
250
|
+
|
251
|
+
|
252
|
+
def _get_colors(
    num_colors_to_generate: int = 10, cmap: str = "hsv", random_seed: int = 888, **kwargs
) -> List[Tuple]:
    """Generate a list of RGBA colors from a colormap, or repeat one explicit color.

    Args:
        num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
        cmap (str): The name of the colormap to use. Defaults to "hsv".
        random_seed (int): Seed for random number generation. Defaults to 888.
        **kwargs: Additional keyword arguments; a truthy 'color' entry selects a single
            color that is repeated instead of sampling the colormap.

    Returns:
        list of tuple: List of RGBA colors.

    Raises:
        ValueError: If ``cmap`` is not a known colormap name.
    """
    # Seed the module-level generator so the shuffle below is reproducible
    random.seed(random_seed)
    if kwargs.get("color"):
        # A direct color was requested: repeat it for every slot
        rgba = matplotlib.colors.to_rgba(kwargs["color"])
        return [rgba] * num_colors_to_generate

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # prefer the colormap registry when it is available.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and isinstance(cmap, str):
        try:
            colormap = registry[cmap]
        except KeyError as error:
            # Keep the exception type the legacy lookup raised for unknown names
            raise ValueError(f"{cmap!r} is not a known colormap name") from error
    else:
        # Older Matplotlib without the registry, or a non-string argument
        colormap = cm.get_cmap(cmap)

    # Evenly spaced positions along the colormap, visited in shuffled order
    color_positions = np.linspace(0, 1, num_colors_to_generate)
    random.shuffle(color_positions)
    return [colormap(pos) for pos in color_positions]
|
risk/network/io.py
ADDED
@@ -0,0 +1,326 @@
|
|
1
|
+
"""
|
2
|
+
risk/network/io
|
3
|
+
~~~~~~~~~~~~~~~
|
4
|
+
|
5
|
+
This file contains the code for the RISK class and command-line access.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
import pickle
|
10
|
+
import shutil
|
11
|
+
import zipfile
|
12
|
+
from xml.dom import minidom
|
13
|
+
|
14
|
+
import networkx as nx
|
15
|
+
import pandas as pd
|
16
|
+
|
17
|
+
from risk.network.geometry import apply_edge_lengths
|
18
|
+
from risk.log import params, print_header
|
19
|
+
|
20
|
+
|
21
|
+
class NetworkIO:
    """Load network data from several formats and prepare it for analysis.

    Supported inputs are GPickle files, in-memory NetworkX graphs, Cytoscape session
    files, and Cytoscape JSON (.cyjs) files. Every loader hands the raw graph to the same
    pipeline: low-degree nodes and self-loops are removed, nodes are relabeled with
    consecutive integers, edge weights and node attributes are validated, and
    ``apply_edge_lengths`` is applied.
    """

    def __init__(
        self,
        compute_sphere: bool = True,
        surface_depth: float = 0.0,
        distance_metric: str = "dijkstra",
        edge_length_threshold: float = 0.5,
        louvain_resolution: float = 0.1,
        min_edges_per_node: int = 0,
        include_edge_weight: bool = True,
        weight_label: str = "weight",
    ):
        """Store the loading and processing options.

        Args:
            compute_sphere (bool, optional): Forwarded to ``apply_edge_lengths``; reported as
                the "Sphere" (True) or "Plane" (False) projection. Defaults to True.
            surface_depth (float, optional): Forwarded to ``apply_edge_lengths``. Defaults to 0.0.
            distance_metric (str, optional): Stored on the instance; not used by this class
                itself. Defaults to "dijkstra".
            edge_length_threshold (float, optional): Stored on the instance and reported when
                loading. Defaults to 0.5.
            louvain_resolution (float, optional): Stored on the instance; not used by this
                class itself. Defaults to 0.1.
            min_edges_per_node (int, optional): Degree threshold; nodes whose degree is less
                than or equal to this value are removed. Defaults to 0.
            include_edge_weight (bool, optional): Whether edge weights are read from Cytoscape
                session files and forwarded to ``apply_edge_lengths``. Defaults to True.
            weight_label (str, optional): Name of the edge attribute holding the weight.
                Defaults to "weight".
        """
        self.compute_sphere = compute_sphere
        self.surface_depth = surface_depth
        self.include_edge_weight = include_edge_weight
        self.weight_label = weight_label
        self.distance_metric = distance_metric
        self.edge_length_threshold = edge_length_threshold
        self.louvain_resolution = louvain_resolution
        self.min_edges_per_node = min_edges_per_node

    def load_gpickle_network(self, filepath: str) -> nx.Graph:
        """Load a network from a GPickle file.

        Args:
            filepath (str): Path to the GPickle file.

        Returns:
            nx.Graph: Loaded and processed network.
        """
        filetype = "GPickle"
        params.log_network(filetype=filetype, filepath=filepath)
        self._log_loading(filetype, filepath=filepath)
        with open(filepath, "rb") as f:
            # SECURITY: unpickling can execute arbitrary code; only load trusted files
            G = pickle.load(f)

        return self._initialize_graph(G)

    def load_networkx_network(self, G: nx.Graph) -> nx.Graph:
        """Load a NetworkX graph.

        Args:
            G (nx.Graph): A NetworkX graph object.

        Returns:
            nx.Graph: Processed network.
        """
        filetype = "NetworkX"
        params.log_network(filetype=filetype)
        self._log_loading(filetype)
        return self._initialize_graph(G)

    def load_cytoscape_network(
        self,
        filepath: str,
        source_label: str = "source",
        target_label: str = "target",
        view_name: str = "",
    ) -> nx.Graph:
        """Load a network from a Cytoscape session file.

        The archive is unpacked into a private temporary directory that is removed again
        once the required files have been read.

        Args:
            filepath (str): Path to the Cytoscape file.
            source_label (str, optional): Source node label. Defaults to "source".
            target_label (str, optional): Target node label. Defaults to "target".
            view_name (str, optional): Name of the view to load; the first view found is
                used when empty. Defaults to "".

        Returns:
            nx.Graph: Loaded and processed network.
        """
        # Only this loader needs these modules, so they are imported locally
        import os
        import tempfile

        filetype = "Cytoscape"
        params.log_network(filetype=filetype, filepath=str(filepath))
        self._log_loading(filetype, filepath=filepath)
        # Extracting into a temporary directory (instead of the working directory) means
        # cleanup can never delete unrelated files or directories of the same name.
        with tempfile.TemporaryDirectory() as extract_dir:
            with zipfile.ZipFile(filepath, "r") as zip_ref:
                cys_files = zip_ref.namelist()
                zip_ref.extractall(extract_dir)

            # Select the requested view, or the first one when no name was given
            cys_view_files = [cf for cf in cys_files if "/views/" in cf]
            cys_view_file = (
                cys_view_files[0]
                if not view_name
                else [cvf for cvf in cys_view_files if cvf.endswith(view_name + ".xgmml")][0]
            )
            # Parse node positions from the view
            cys_view_dom = minidom.parse(os.path.join(extract_dir, cys_view_file))
            node_x_positions = {}
            node_y_positions = {}
            for node in cys_view_dom.getElementsByTagName("node"):
                # Node ID is found in 'label'
                node_id = str(node.attributes["label"].value)
                for child in node.childNodes:
                    # nodeType 1 is an element node
                    if child.nodeType == 1 and child.tagName == "graphics":
                        node_x_positions[node_id] = float(child.attributes["x"].value)
                        node_y_positions[node_id] = float(child.attributes["y"].value)

            # Locate the shared edge attribute table (from /tables/)
            attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
            attribute_metadata = [
                cf
                for cf in cys_files
                if all(keyword in cf for keyword in attribute_metadata_keywords)
            ][0]
            # Load the attribute table while the extracted file still exists
            attribute_table = pd.read_csv(
                os.path.join(extract_dir, attribute_metadata), sep=",", header=None, skiprows=1
            )

        # The first remaining row holds the column names
        attribute_table.columns = attribute_table.iloc[0]
        # Skip first four rows
        attribute_table = attribute_table.iloc[4:, :]
        # Keep only the columns needed to build the edges
        if self.include_edge_weight:
            attribute_table = attribute_table[[source_label, target_label, self.weight_label]]
        else:
            attribute_table = attribute_table[[source_label, target_label]]

        attribute_table = attribute_table.dropna().reset_index(drop=True)
        G = nx.Graph()
        # add_edge creates missing endpoint nodes automatically
        for _, row in attribute_table.iterrows():
            source = row[source_label]
            target = row[target_label]
            if self.include_edge_weight:
                # Store the weight under the configured label; _validate_edges reads it
                # from there and copies it to "weight"
                weight = float(row[self.weight_label])
                G.add_edge(source, target, **{self.weight_label: weight})
            else:
                G.add_edge(source, target)

        # Attach label and position to every node
        for node in G.nodes():
            G.nodes[node]["label"] = node
            G.nodes[node]["x"] = node_x_positions[node]
            G.nodes[node]["y"] = node_y_positions[node]

        return self._initialize_graph(G)

    def load_cytoscape_json_network(
        self, filepath: str, source_label: str = "source", target_label: str = "target"
    ) -> nx.Graph:
        """Load a network from a Cytoscape JSON (.cyjs) file.

        Args:
            filepath (str): Path to the Cytoscape JSON file.
            source_label (str, optional): Source node label. Default is "source".
            target_label (str, optional): Target node label. Default is "target".

        Returns:
            nx.Graph: Loaded and processed network.
        """
        filetype = "Cytoscape JSON"
        params.log_network(filetype=filetype, filepath=str(filepath))
        self._log_loading(filetype, filepath=filepath)
        # JSON text is UTF-8; do not depend on the platform default encoding
        with open(filepath, "r", encoding="utf-8") as f:
            cyjs_data = json.load(f)

        G = nx.Graph()
        # Process nodes
        for node in cyjs_data["elements"]["nodes"]:
            node_data = node["data"]
            node_id = node_data["id"]
            G.add_node(node_id)
            # Fall back to the node ID when no name is given
            G.nodes[node_id]["label"] = node_data.get("name", node_id)
            G.nodes[node_id]["x"] = node["position"]["x"]
            G.nodes[node_id]["y"] = node["position"]["y"]

        # Process edges
        for edge in cyjs_data["elements"]["edges"]:
            edge_data = edge["data"]
            source = edge_data[source_label]
            target = edge_data[target_label]
            if self.weight_label is not None and self.weight_label in edge_data:
                # Store the weight under the configured label; _validate_edges reads it
                # from there and copies it to "weight"
                weight = float(edge_data[self.weight_label])
                G.add_edge(source, target, **{self.weight_label: weight})
            else:
                G.add_edge(source, target)

        return self._initialize_graph(G)

    def _initialize_graph(self, G: nx.Graph) -> nx.Graph:
        """Initialize the graph by processing and validating its nodes and edges.

        Args:
            G (nx.Graph): The input NetworkX graph; it is not modified.

        Returns:
            nx.Graph: The processed and validated graph with integer node IDs 0..n-1.
        """
        # Work on a copy so the caller's graph is left untouched
        G = G.copy()
        # Clean up first: removing nodes after relabeling would leave gaps in the IDs
        self._remove_invalid_graph_properties(G)
        # IMPORTANT: This is where the graph node labels are converted to integers
        G = nx.relabel_nodes(G, {node: idx for idx, node in enumerate(G.nodes)})
        self._validate_edges(G)
        self._validate_nodes(G)
        self._process_graph(G)
        return G

    def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
        """Remove low-degree nodes and self-loop edges from the graph, in place.

        Args:
            G (nx.Graph): A NetworkX graph object.
        """
        print(f"Minimum edges per node: {self.min_edges_per_node}")
        # Remove nodes whose degree is at or below the threshold
        nodes_with_few_edges = [
            node for node in G.nodes() if G.degree(node) <= self.min_edges_per_node
        ]
        G.remove_nodes_from(nodes_with_few_edges)
        # Remove self-loop edges (materialized first: the graph changes while removing)
        self_loops = list(nx.selfloop_edges(G))
        G.remove_edges_from(self_loops)

    def _validate_edges(self, G: nx.Graph) -> None:
        """Copy the user-defined edge weights into the "weight" attribute.

        Edges lacking the ``weight_label`` attribute get a weight of 1.0 and are counted;
        the count is reported when edge weights are included.

        Args:
            G (nx.Graph): A NetworkX graph object.
        """
        missing_weights = 0
        for _, _, data in G.edges(data=True):
            if self.weight_label not in data:
                missing_weights += 1
            # Default to 1.0 if the weight attribute is not present
            data["weight"] = data.get(self.weight_label, 1.0)

        if self.include_edge_weight and missing_weights:
            print(f"Total edges missing weights: {missing_weights}")

    def _validate_nodes(self, G: nx.Graph) -> None:
        """Check that every node has position and label attributes.

        Args:
            G (nx.Graph): A NetworkX graph object.

        Raises:
            AssertionError: If a node lacks 'x', 'y', or 'label'.
        """
        for node, attrs in G.nodes(data=True):
            assert (
                "x" in attrs and "y" in attrs
            ), f"Node {node} is missing 'x' or 'y' position attributes."
            assert "label" in attrs, f"Node {node} is missing a 'label' attribute."

    def _process_graph(self, G: nx.Graph) -> None:
        """Apply ``apply_edge_lengths`` to the graph using the configured options.

        Args:
            G (nx.Graph): The input network graph.
        """
        apply_edge_lengths(
            G,
            compute_sphere=self.compute_sphere,
            surface_depth=self.surface_depth,
            include_edge_weight=self.include_edge_weight,
        )

    def _log_loading(
        self,
        filetype: str,
        filepath: str = "",
    ) -> None:
        """Print the settings used for loading a network.

        Args:
            filetype (str): The type of the file being loaded (e.g., 'GPickle', 'Cytoscape').
            filepath (str, optional): The path to the file being loaded. Defaults to "".
        """
        print_header("Loading network")
        print(f"Filetype: {filetype}")
        if filepath:
            print(f"Filepath: {filepath}")
        print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
        if self.compute_sphere:
            print(f"Surface depth: {self.surface_depth}")
        print(f"Edge length threshold: {self.edge_length_threshold}")
        print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
        if self.include_edge_weight:
            print(f"Weight label: {self.weight_label}")
|