risk-network 0.0.5b5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/io.py +44 -3
- risk/constants.py +2 -2
- risk/log/params.py +2 -0
- risk/neighborhoods/community.py +7 -3
- risk/neighborhoods/domains.py +24 -18
- risk/neighborhoods/neighborhoods.py +2 -2
- risk/network/graph.py +68 -40
- risk/network/io.py +30 -10
- risk/network/plot.py +713 -308
- risk/risk.py +10 -22
- {risk_network-0.0.5b5.dist-info → risk_network-0.0.6.dist-info}/METADATA +3 -4
- risk_network-0.0.6.dist-info/RECORD +30 -0
- {risk_network-0.0.5b5.dist-info → risk_network-0.0.6.dist-info}/WHEEL +1 -1
- risk_network-0.0.5b5.dist-info/RECORD +0 -30
- {risk_network-0.0.5b5.dist-info → risk_network-0.0.6.dist-info}/LICENSE +0 -0
- {risk_network-0.0.5b5.dist-info → risk_network-0.0.6.dist-info}/top_level.txt +0 -0
risk/__init__.py
CHANGED
risk/annotations/io.py
CHANGED
@@ -36,13 +36,15 @@ class AnnotationsIO:
|
|
36
36
|
dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
|
37
37
|
"""
|
38
38
|
filetype = "JSON"
|
39
|
+
# Log the loading of the JSON file
|
39
40
|
params.log_annotations(filepath=filepath, filetype=filetype)
|
40
41
|
_log_loading(filetype, filepath=filepath)
|
42
|
+
|
41
43
|
# Open and read the JSON file
|
42
44
|
with open(filepath, "r") as file:
|
43
45
|
annotations_input = json.load(file)
|
44
46
|
|
45
|
-
#
|
47
|
+
# Load the annotations into the provided network
|
46
48
|
return load_annotations(network, annotations_input)
|
47
49
|
|
48
50
|
def load_excel_annotation(
|
@@ -69,14 +71,18 @@ class AnnotationsIO:
|
|
69
71
|
linked to the provided network.
|
70
72
|
"""
|
71
73
|
filetype = "Excel"
|
74
|
+
# Log the loading of the Excel file
|
72
75
|
params.log_annotations(filepath=filepath, filetype=filetype)
|
73
76
|
_log_loading(filetype, filepath=filepath)
|
77
|
+
|
74
78
|
# Load the specified sheet from the Excel file
|
75
79
|
df = pd.read_excel(filepath, sheet_name=sheet_name)
|
76
80
|
# Split the nodes column by the specified nodes_delimiter
|
77
81
|
df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(nodes_delimiter))
|
78
82
|
# Convert the DataFrame to a dictionary pairing labels with their corresponding nodes
|
79
83
|
label_node_dict = df.set_index(label_colname)[nodes_colname].to_dict()
|
84
|
+
|
85
|
+
# Load the annotations into the provided network
|
80
86
|
return load_annotations(network, label_node_dict)
|
81
87
|
|
82
88
|
def load_csv_annotation(
|
@@ -101,13 +107,16 @@ class AnnotationsIO:
|
|
101
107
|
linked to the provided network.
|
102
108
|
"""
|
103
109
|
filetype = "CSV"
|
110
|
+
# Log the loading of the CSV file
|
104
111
|
params.log_annotations(filepath=filepath, filetype=filetype)
|
105
112
|
_log_loading(filetype, filepath=filepath)
|
113
|
+
|
106
114
|
# Load the CSV file into a dictionary
|
107
115
|
annotations_input = _load_matrix_file(
|
108
116
|
filepath, label_colname, nodes_colname, delimiter=",", nodes_delimiter=nodes_delimiter
|
109
117
|
)
|
110
|
-
|
118
|
+
|
119
|
+
# Load the annotations into the provided network
|
111
120
|
return load_annotations(network, annotations_input)
|
112
121
|
|
113
122
|
def load_tsv_annotation(
|
@@ -132,15 +141,47 @@ class AnnotationsIO:
|
|
132
141
|
linked to the provided network.
|
133
142
|
"""
|
134
143
|
filetype = "TSV"
|
144
|
+
# Log the loading of the TSV file
|
135
145
|
params.log_annotations(filepath=filepath, filetype=filetype)
|
136
146
|
_log_loading(filetype, filepath=filepath)
|
147
|
+
|
137
148
|
# Load the TSV file into a dictionary
|
138
149
|
annotations_input = _load_matrix_file(
|
139
150
|
filepath, label_colname, nodes_colname, delimiter="\t", nodes_delimiter=nodes_delimiter
|
140
151
|
)
|
141
|
-
|
152
|
+
|
153
|
+
# Load the annotations into the provided network
|
142
154
|
return load_annotations(network, annotations_input)
|
143
155
|
|
156
|
+
def load_dict_annotation(self, content: Dict[str, Any], network: nx.Graph) -> Dict[str, Any]:
|
157
|
+
"""Load annotations from a provided dictionary and convert them to a dictionary annotation.
|
158
|
+
|
159
|
+
Args:
|
160
|
+
content (dict): The annotations dictionary to load.
|
161
|
+
network (NetworkX graph): The network to which the annotations are related.
|
162
|
+
|
163
|
+
Returns:
|
164
|
+
dict: A dictionary containing ordered nodes, ordered annotations, and the annotations matrix.
|
165
|
+
"""
|
166
|
+
# Ensure the input content is a dictionary
|
167
|
+
if not isinstance(content, dict):
|
168
|
+
raise TypeError(
|
169
|
+
f"Expected 'content' to be a dictionary, but got {type(content).__name__} instead."
|
170
|
+
)
|
171
|
+
|
172
|
+
filetype = "Dictionary"
|
173
|
+
# Log the loading of the annotations from the dictionary
|
174
|
+
params.log_annotations(filepath="In-memory dictionary", filetype=filetype)
|
175
|
+
_log_loading(filetype, "In-memory dictionary")
|
176
|
+
|
177
|
+
# Load the annotations into the provided network
|
178
|
+
annotations_dict = load_annotations(network, content)
|
179
|
+
# Ensure the output is a dictionary
|
180
|
+
if not isinstance(annotations_dict, dict):
|
181
|
+
raise ValueError("Expected output to be a dictionary")
|
182
|
+
|
183
|
+
return annotations_dict
|
184
|
+
|
144
185
|
|
145
186
|
def _load_matrix_file(
|
146
187
|
filepath: str,
|
risk/constants.py
CHANGED
@@ -3,6 +3,8 @@ risk/constants
|
|
3
3
|
~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
+
GROUP_LINKAGE_METHODS = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
|
7
|
+
|
6
8
|
GROUP_DISTANCE_METRICS = [
|
7
9
|
"braycurtis",
|
8
10
|
"canberra",
|
@@ -27,5 +29,3 @@ GROUP_DISTANCE_METRICS = [
|
|
27
29
|
"sqeuclidean",
|
28
30
|
"yule",
|
29
31
|
]
|
30
|
-
|
31
|
-
GROUP_LINKAGE_METHODS = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
|
risk/log/params.py
CHANGED
@@ -7,6 +7,7 @@ import csv
|
|
7
7
|
import json
|
8
8
|
import warnings
|
9
9
|
from datetime import datetime
|
10
|
+
from functools import wraps
|
10
11
|
from typing import Any, Dict
|
11
12
|
|
12
13
|
import numpy as np
|
@@ -27,6 +28,7 @@ def _safe_param_export(func):
|
|
27
28
|
function: The wrapped function with error handling.
|
28
29
|
"""
|
29
30
|
|
31
|
+
@wraps(func)
|
30
32
|
def wrapper(*args, **kwargs):
|
31
33
|
try:
|
32
34
|
result = func(*args, **kwargs)
|
risk/neighborhoods/community.py
CHANGED
@@ -25,10 +25,14 @@ def calculate_dijkstra_neighborhoods(network: nx.Graph) -> np.ndarray:
|
|
25
25
|
|
26
26
|
# Populate the neighborhoods matrix based on Dijkstra's distances
|
27
27
|
for source, targets in all_dijkstra_paths.items():
|
28
|
+
max_length = max(targets.values()) if targets else 1 # Handle cases with no targets
|
28
29
|
for target, length in targets.items():
|
29
|
-
|
30
|
-
|
31
|
-
|
30
|
+
if np.isnan(length):
|
31
|
+
neighborhoods[source, target] = max_length # Use max distance for NaN
|
32
|
+
elif length == 0:
|
33
|
+
neighborhoods[source, target] = 1 # Assign 1 for zero-length paths (self-loops)
|
34
|
+
else:
|
35
|
+
neighborhoods[source, target] = 1 / length # Inverse of the distance
|
32
36
|
|
33
37
|
return neighborhoods
|
34
38
|
|
risk/neighborhoods/domains.py
CHANGED
@@ -35,26 +35,31 @@ def define_domains(
|
|
35
35
|
Returns:
|
36
36
|
pd.DataFrame: DataFrame with the primary domain for each node.
|
37
37
|
"""
|
38
|
-
#
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
38
|
+
# Check if there's more than one column in significant_neighborhoods_enrichment
|
39
|
+
if significant_neighborhoods_enrichment.shape[1] == 1:
|
40
|
+
print("Single annotation detected. Skipping clustering.")
|
41
|
+
top_annotations["domain"] = 1 # Assign a default domain or handle appropriately
|
42
|
+
else:
|
43
|
+
# Perform hierarchical clustering on the binary enrichment matrix
|
44
|
+
m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
|
45
|
+
best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
|
46
|
+
m, linkage_criterion, linkage_method, linkage_metric
|
47
|
+
)
|
48
|
+
try:
|
49
|
+
Z = linkage(m, method=best_linkage, metric=best_metric)
|
50
|
+
except ValueError as e:
|
51
|
+
raise ValueError("No significant annotations found.") from e
|
47
52
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
53
|
+
print(
|
54
|
+
f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
|
55
|
+
)
|
56
|
+
print(f"Optimal linkage threshold: {round(best_threshold, 3)}")
|
52
57
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
+
max_d_optimal = np.max(Z[:, 2]) * best_threshold
|
59
|
+
domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
|
60
|
+
# Assign domains to the annotations matrix
|
61
|
+
top_annotations["domain"] = 0
|
62
|
+
top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
|
58
63
|
|
59
64
|
# Create DataFrames to store domain information
|
60
65
|
node_to_enrichment = pd.DataFrame(
|
@@ -63,6 +68,7 @@ def define_domains(
|
|
63
68
|
)
|
64
69
|
node_to_domain = node_to_enrichment.groupby(level="domain", axis=1).sum()
|
65
70
|
|
71
|
+
# Find the maximum enrichment score for each node
|
66
72
|
t_max = node_to_domain.loc[:, 1:].max(axis=1)
|
67
73
|
t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
|
68
74
|
t_idxmax[t_max == 0] = 0
|
@@ -4,7 +4,7 @@ risk/neighborhoods/neighborhoods
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import warnings
|
7
|
-
from typing import Any, Dict, Tuple
|
7
|
+
from typing import Any, Dict, List, Tuple
|
8
8
|
|
9
9
|
import networkx as nx
|
10
10
|
import numpy as np
|
@@ -305,7 +305,7 @@ def _get_node_position(network: nx.Graph, node: Any) -> np.ndarray:
|
|
305
305
|
)
|
306
306
|
|
307
307
|
|
308
|
-
def _calculate_threshold(average_distances:
|
308
|
+
def _calculate_threshold(average_distances: List, distance_threshold: float) -> float:
|
309
309
|
"""Calculate the distance threshold based on the given average distances and a percentile threshold.
|
310
310
|
|
311
311
|
Args:
|
risk/network/graph.py
CHANGED
@@ -28,7 +28,7 @@ class NetworkGraph:
|
|
28
28
|
top_annotations: pd.DataFrame,
|
29
29
|
domains: pd.DataFrame,
|
30
30
|
trimmed_domains: pd.DataFrame,
|
31
|
-
|
31
|
+
node_label_to_node_id_map: Dict[str, Any],
|
32
32
|
node_enrichment_sums: np.ndarray,
|
33
33
|
):
|
34
34
|
"""Initialize the NetworkGraph object.
|
@@ -38,39 +38,48 @@ class NetworkGraph:
|
|
38
38
|
top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
|
39
39
|
domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
|
40
40
|
trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
|
41
|
-
|
41
|
+
node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
|
42
42
|
node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
|
43
43
|
"""
|
44
44
|
self.top_annotations = top_annotations
|
45
|
-
self.
|
45
|
+
self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
|
46
46
|
self.domains = domains
|
47
|
-
self.
|
48
|
-
|
49
|
-
|
47
|
+
self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
|
48
|
+
trimmed_domains
|
49
|
+
)
|
50
50
|
self.node_enrichment_sums = node_enrichment_sums
|
51
|
-
|
51
|
+
self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
|
52
|
+
self.node_label_to_enrichment_map = dict(
|
53
|
+
zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
|
54
|
+
)
|
55
|
+
self.node_label_to_node_id_map = node_label_to_node_id_map
|
56
|
+
# NOTE: Below this point, instance attributes (i.e., self) will be used!
|
57
|
+
self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
|
58
|
+
# self.network and self.node_coordinates are properly declared in _initialize_network
|
52
59
|
self.network = None
|
53
60
|
self.node_coordinates = None
|
54
61
|
self._initialize_network(network)
|
55
62
|
|
56
|
-
def
|
57
|
-
"""Create a mapping from domains to the list of
|
63
|
+
def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
|
64
|
+
"""Create a mapping from domains to the list of node IDs belonging to each domain.
|
58
65
|
|
59
66
|
Args:
|
60
67
|
domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
|
61
68
|
|
62
69
|
Returns:
|
63
|
-
dict: A dictionary where keys are domain IDs and values are lists of
|
70
|
+
dict: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
|
64
71
|
"""
|
65
72
|
cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
|
66
|
-
|
67
|
-
|
68
|
-
for k, v in
|
69
|
-
|
73
|
+
node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
|
74
|
+
domain_id_to_node_ids_map = defaultdict(list)
|
75
|
+
for k, v in node_to_domains_map.items():
|
76
|
+
domain_id_to_node_ids_map[v].append(k)
|
70
77
|
|
71
|
-
return
|
78
|
+
return domain_id_to_node_ids_map
|
72
79
|
|
73
|
-
def
|
80
|
+
def _create_domain_id_to_domain_terms_map(
|
81
|
+
self, trimmed_domains: pd.DataFrame
|
82
|
+
) -> Dict[str, Any]:
|
74
83
|
"""Create a mapping from domain IDs to their corresponding terms.
|
75
84
|
|
76
85
|
Args:
|
@@ -86,6 +95,20 @@ class NetworkGraph:
|
|
86
95
|
)
|
87
96
|
)
|
88
97
|
|
98
|
+
def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
|
99
|
+
"""Create a map from domain IDs to node labels.
|
100
|
+
|
101
|
+
Returns:
|
102
|
+
dict: A dictionary mapping domain IDs to the corresponding node labels.
|
103
|
+
"""
|
104
|
+
domain_id_to_label_map = {}
|
105
|
+
for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
|
106
|
+
domain_id_to_label_map[domain_id] = [
|
107
|
+
self.node_id_to_node_label_map[node_id] for node_id in node_ids
|
108
|
+
]
|
109
|
+
|
110
|
+
return domain_id_to_label_map
|
111
|
+
|
89
112
|
def _initialize_network(self, G: nx.Graph) -> None:
|
90
113
|
"""Initialize the network by unfolding it and extracting node coordinates.
|
91
114
|
|
@@ -101,31 +124,32 @@ class NetworkGraph:
|
|
101
124
|
|
102
125
|
def get_domain_colors(
|
103
126
|
self,
|
127
|
+
cmap: str = "gist_rainbow",
|
128
|
+
color: Union[str, None] = None,
|
104
129
|
min_scale: float = 0.8,
|
105
130
|
max_scale: float = 1.0,
|
106
131
|
scale_factor: float = 1.0,
|
107
132
|
random_seed: int = 888,
|
108
|
-
**kwargs,
|
109
133
|
) -> np.ndarray:
|
110
|
-
"""Generate composite colors for domains.
|
111
|
-
|
112
|
-
This method generates composite colors for nodes based on their enrichment scores and transforms
|
113
|
-
them to ensure proper alpha values and intensity. For nodes with alpha == 0, it assigns new colors
|
114
|
-
based on the closest valid neighbors within a specified distance.
|
134
|
+
"""Generate composite colors for domains based on enrichment or specified colors.
|
115
135
|
|
116
136
|
Args:
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
137
|
+
cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
|
138
|
+
color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
|
139
|
+
min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap.
|
140
|
+
Controls the dimmest colors. Defaults to 0.8.
|
141
|
+
max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap.
|
142
|
+
Controls the brightest colors. Defaults to 1.0.
|
143
|
+
scale_factor (float, optional): Exponent for adjusting the color scaling based on enrichment scores.
|
144
|
+
A higher value increases contrast by dimming lower scores more. Defaults to 1.0.
|
145
|
+
random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments.
|
146
|
+
Defaults to 888.
|
123
147
|
|
124
148
|
Returns:
|
125
|
-
np.ndarray: Array of
|
149
|
+
np.ndarray: Array of RGBA colors generated for each domain, based on enrichment or the specified color.
|
126
150
|
"""
|
127
151
|
# Get colors for each domain
|
128
|
-
domain_colors = self._get_domain_colors(random_seed=random_seed)
|
152
|
+
domain_colors = self._get_domain_colors(cmap=cmap, color=color, random_seed=random_seed)
|
129
153
|
# Generate composite colors for nodes
|
130
154
|
node_colors = self._get_composite_node_colors(domain_colors)
|
131
155
|
# Transform colors to ensure proper alpha values and intensity
|
@@ -153,20 +177,24 @@ class NetworkGraph:
|
|
153
177
|
# Initialize composite colors array with shape (number of nodes, 4) for RGBA
|
154
178
|
composite_colors = np.zeros((num_nodes, 4))
|
155
179
|
# Assign colors to nodes based on domain_colors
|
156
|
-
for
|
157
|
-
color = domain_colors[
|
180
|
+
for domain_id, nodes in self.domain_id_to_node_ids_map.items():
|
181
|
+
color = domain_colors[domain_id]
|
158
182
|
for node in nodes:
|
159
183
|
composite_colors[node] = color
|
160
184
|
|
161
185
|
return composite_colors
|
162
186
|
|
163
187
|
def _get_domain_colors(
|
164
|
-
self,
|
188
|
+
self,
|
189
|
+
cmap: str = "gist_rainbow",
|
190
|
+
color: Union[str, None] = None,
|
191
|
+
random_seed: int = 888,
|
165
192
|
) -> Dict[str, Any]:
|
166
193
|
"""Get colors for each domain.
|
167
194
|
|
168
195
|
Args:
|
169
|
-
|
196
|
+
cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
|
197
|
+
color (str or None, optional): A specific color to use for all generated colors. Defaults to None.
|
170
198
|
random_seed (int, optional): Seed for random number generation. Defaults to 888.
|
171
199
|
|
172
200
|
Returns:
|
@@ -178,9 +206,9 @@ class NetworkGraph:
|
|
178
206
|
]
|
179
207
|
domains = np.sort(numeric_domains)
|
180
208
|
domain_colors = _get_colors(
|
181
|
-
num_colors_to_generate=len(domains), color=color, random_seed=random_seed
|
209
|
+
num_colors_to_generate=len(domains), cmap=cmap, color=color, random_seed=random_seed
|
182
210
|
)
|
183
|
-
return dict(zip(self.
|
211
|
+
return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
|
184
212
|
|
185
213
|
|
186
214
|
def _transform_colors(
|
@@ -273,17 +301,17 @@ def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
|
|
273
301
|
|
274
302
|
def _get_colors(
|
275
303
|
num_colors_to_generate: int = 10,
|
276
|
-
cmap: str = "
|
277
|
-
random_seed: int = 888,
|
304
|
+
cmap: str = "gist_rainbow",
|
278
305
|
color: Union[str, None] = None,
|
306
|
+
random_seed: int = 888,
|
279
307
|
) -> List[Tuple]:
|
280
308
|
"""Generate a list of RGBA colors from a specified colormap or use a direct color string.
|
281
309
|
|
282
310
|
Args:
|
283
311
|
num_colors_to_generate (int): The number of colors to generate. Defaults to 10.
|
284
|
-
cmap (str): The name of the colormap to use. Defaults to "
|
312
|
+
cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
|
313
|
+
color (str or None, optional): A specific color to use for all generated colors.
|
285
314
|
random_seed (int): Seed for random number generation. Defaults to 888.
|
286
|
-
color (str, optional): Specific color to use for all nodes. If specified, it will overwrite the colormap.
|
287
315
|
Defaults to None.
|
288
316
|
|
289
317
|
Returns:
|
risk/network/io.py
CHANGED
@@ -48,6 +48,7 @@ class NetworkIO:
|
|
48
48
|
self.min_edges_per_node = min_edges_per_node
|
49
49
|
self.include_edge_weight = include_edge_weight
|
50
50
|
self.weight_label = weight_label
|
51
|
+
# Log the initialization of the NetworkIO class
|
51
52
|
params.log_network(
|
52
53
|
compute_sphere=compute_sphere,
|
53
54
|
surface_depth=surface_depth,
|
@@ -98,11 +99,14 @@ class NetworkIO:
|
|
98
99
|
nx.Graph: Loaded and processed network.
|
99
100
|
"""
|
100
101
|
filetype = "GPickle"
|
102
|
+
# Log the loading of the GPickle file
|
101
103
|
params.log_network(filetype=filetype, filepath=filepath)
|
102
104
|
self._log_loading(filetype, filepath=filepath)
|
105
|
+
|
103
106
|
with open(filepath, "rb") as f:
|
104
107
|
G = pickle.load(f)
|
105
108
|
|
109
|
+
# Initialize the graph
|
106
110
|
return self._initialize_graph(G)
|
107
111
|
|
108
112
|
@classmethod
|
@@ -147,8 +151,11 @@ class NetworkIO:
|
|
147
151
|
nx.Graph: Processed network.
|
148
152
|
"""
|
149
153
|
filetype = "NetworkX"
|
154
|
+
# Log the loading of the NetworkX graph
|
150
155
|
params.log_network(filetype=filetype)
|
151
156
|
self._log_loading(filetype)
|
157
|
+
|
158
|
+
# Initialize the graph
|
152
159
|
return self._initialize_graph(network)
|
153
160
|
|
154
161
|
@classmethod
|
@@ -213,8 +220,10 @@ class NetworkIO:
|
|
213
220
|
nx.Graph: Loaded and processed network.
|
214
221
|
"""
|
215
222
|
filetype = "Cytoscape"
|
223
|
+
# Log the loading of the Cytoscape file
|
216
224
|
params.log_network(filetype=filetype, filepath=str(filepath))
|
217
225
|
self._log_loading(filetype, filepath=filepath)
|
226
|
+
|
218
227
|
cys_files = []
|
219
228
|
tmp_dir = ".tmp_cytoscape"
|
220
229
|
# Try / finally to remove unzipped files
|
@@ -295,6 +304,7 @@ class NetworkIO:
|
|
295
304
|
node
|
296
305
|
] # Assuming you have a dict `node_y_positions` for y coordinates
|
297
306
|
|
307
|
+
# Initialize the graph
|
298
308
|
return self._initialize_graph(G)
|
299
309
|
|
300
310
|
finally:
|
@@ -354,6 +364,7 @@ class NetworkIO:
|
|
354
364
|
NetworkX graph: Loaded and processed network.
|
355
365
|
"""
|
356
366
|
filetype = "Cytoscape JSON"
|
367
|
+
# Log the loading of the Cytoscape JSON file
|
357
368
|
params.log_network(filetype=filetype, filepath=str(filepath))
|
358
369
|
self._log_loading(filetype, filepath=filepath)
|
359
370
|
|
@@ -418,29 +429,37 @@ class NetworkIO:
|
|
418
429
|
return G
|
419
430
|
|
420
431
|
def _remove_invalid_graph_properties(self, G: nx.Graph) -> None:
|
421
|
-
"""Remove invalid properties from the graph
|
432
|
+
"""Remove invalid properties from the graph, including self-loops, nodes with fewer edges than
|
433
|
+
the threshold, and isolated nodes.
|
422
434
|
|
423
435
|
Args:
|
424
436
|
G (nx.Graph): A NetworkX graph object.
|
425
437
|
"""
|
426
|
-
#
|
438
|
+
# Count number of nodes and edges before cleaning
|
439
|
+
num_initial_nodes = G.number_of_nodes()
|
440
|
+
num_initial_edges = G.number_of_edges()
|
441
|
+
# Remove self-loops to ensure correct edge count
|
427
442
|
G.remove_edges_from(list(nx.selfloop_edges(G)))
|
428
|
-
#
|
443
|
+
# Iteratively remove nodes with fewer edges than the threshold
|
429
444
|
while True:
|
430
|
-
nodes_to_remove = [
|
431
|
-
node for node in G.nodes() if G.degree(node) < self.min_edges_per_node
|
432
|
-
]
|
445
|
+
nodes_to_remove = [node for node in G.nodes if G.degree(node) < self.min_edges_per_node]
|
433
446
|
if not nodes_to_remove:
|
434
|
-
break # Exit loop if no more nodes
|
435
|
-
|
436
|
-
# Remove the nodes and their associated edges
|
447
|
+
break # Exit loop if no more nodes need removal
|
437
448
|
G.remove_nodes_from(nodes_to_remove)
|
438
449
|
|
439
|
-
#
|
450
|
+
# Remove isolated nodes
|
440
451
|
isolated_nodes = list(nx.isolates(G))
|
441
452
|
if isolated_nodes:
|
442
453
|
G.remove_nodes_from(isolated_nodes)
|
443
454
|
|
455
|
+
# Log the number of nodes and edges before and after cleaning
|
456
|
+
num_final_nodes = G.number_of_nodes()
|
457
|
+
num_final_edges = G.number_of_edges()
|
458
|
+
print(f"Initial node count: {num_initial_nodes}")
|
459
|
+
print(f"Final node count: {num_final_nodes}")
|
460
|
+
print(f"Initial edge count: {num_initial_edges}")
|
461
|
+
print(f"Final edge count: {num_final_edges}")
|
462
|
+
|
444
463
|
def _assign_edge_weights(self, G: nx.Graph) -> None:
|
445
464
|
"""Assign weights to the edges in the graph.
|
446
465
|
|
@@ -502,6 +521,7 @@ class NetworkIO:
|
|
502
521
|
print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
|
503
522
|
if self.include_edge_weight:
|
504
523
|
print(f"Weight label: {self.weight_label}")
|
524
|
+
print(f"Minimum edges per node: {self.min_edges_per_node}")
|
505
525
|
print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
|
506
526
|
if self.compute_sphere:
|
507
527
|
print(f"Surface depth: {self.surface_depth}")
|