risk-network 0.0.6b8__py3-none-any.whl → 0.0.6b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/log/params.py +2 -0
- risk/network/graph.py +38 -19
- risk/network/plot.py +31 -26
- risk/risk.py +4 -4
- {risk_network-0.0.6b8.dist-info → risk_network-0.0.6b10.dist-info}/METADATA +1 -1
- {risk_network-0.0.6b8.dist-info → risk_network-0.0.6b10.dist-info}/RECORD +10 -10
- {risk_network-0.0.6b8.dist-info → risk_network-0.0.6b10.dist-info}/WHEEL +1 -1
- {risk_network-0.0.6b8.dist-info → risk_network-0.0.6b10.dist-info}/LICENSE +0 -0
- {risk_network-0.0.6b8.dist-info → risk_network-0.0.6b10.dist-info}/top_level.txt +0 -0
risk/__init__.py
CHANGED
risk/log/params.py
CHANGED
@@ -7,6 +7,7 @@ import csv
|
|
7
7
|
import json
|
8
8
|
import warnings
|
9
9
|
from datetime import datetime
|
10
|
+
from functools import wraps
|
10
11
|
from typing import Any, Dict
|
11
12
|
|
12
13
|
import numpy as np
|
@@ -27,6 +28,7 @@ def _safe_param_export(func):
|
|
27
28
|
function: The wrapped function with error handling.
|
28
29
|
"""
|
29
30
|
|
31
|
+
@wraps(func)
|
30
32
|
def wrapper(*args, **kwargs):
|
31
33
|
try:
|
32
34
|
result = func(*args, **kwargs)
|
risk/network/graph.py
CHANGED
@@ -28,7 +28,7 @@ class NetworkGraph:
|
|
28
28
|
top_annotations: pd.DataFrame,
|
29
29
|
domains: pd.DataFrame,
|
30
30
|
trimmed_domains: pd.DataFrame,
|
31
|
-
|
31
|
+
node_label_to_node_id_map: Dict[str, Any],
|
32
32
|
node_enrichment_sums: np.ndarray,
|
33
33
|
):
|
34
34
|
"""Initialize the NetworkGraph object.
|
@@ -38,43 +38,48 @@ class NetworkGraph:
|
|
38
38
|
top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
|
39
39
|
domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
|
40
40
|
trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
|
41
|
-
|
41
|
+
node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
|
42
42
|
node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
|
43
43
|
"""
|
44
44
|
self.top_annotations = top_annotations
|
45
|
-
self.
|
45
|
+
self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
|
46
46
|
self.domains = domains
|
47
|
-
self.
|
48
|
-
|
47
|
+
self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
|
48
|
+
trimmed_domains
|
49
|
+
)
|
49
50
|
self.node_enrichment_sums = node_enrichment_sums
|
50
|
-
self.
|
51
|
+
self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
|
51
52
|
self.node_label_to_enrichment_map = dict(
|
52
|
-
zip(
|
53
|
+
zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
|
53
54
|
)
|
54
|
-
self.
|
55
|
-
# NOTE:
|
55
|
+
self.node_label_to_node_id_map = node_label_to_node_id_map
|
56
|
+
# NOTE: Below this point, instance attributes (i.e., self) will be used!
|
57
|
+
self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
|
58
|
+
# self.network and self.node_coordinates are properly declared in _initialize_network
|
56
59
|
self.network = None
|
57
60
|
self.node_coordinates = None
|
58
61
|
self._initialize_network(network)
|
59
62
|
|
60
|
-
def
|
61
|
-
"""Create a mapping from domains to the list of
|
63
|
+
def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
|
64
|
+
"""Create a mapping from domains to the list of node IDs belonging to each domain.
|
62
65
|
|
63
66
|
Args:
|
64
67
|
domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
|
65
68
|
|
66
69
|
Returns:
|
67
|
-
dict: A dictionary where keys are domain IDs and values are lists of
|
70
|
+
dict: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
|
68
71
|
"""
|
69
72
|
cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
|
70
73
|
node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
|
71
|
-
|
74
|
+
domain_id_to_node_ids_map = defaultdict(list)
|
72
75
|
for k, v in node_to_domains_map.items():
|
73
|
-
|
76
|
+
domain_id_to_node_ids_map[v].append(k)
|
74
77
|
|
75
|
-
return
|
78
|
+
return domain_id_to_node_ids_map
|
76
79
|
|
77
|
-
def
|
80
|
+
def _create_domain_id_to_domain_terms_map(
|
81
|
+
self, trimmed_domains: pd.DataFrame
|
82
|
+
) -> Dict[str, Any]:
|
78
83
|
"""Create a mapping from domain IDs to their corresponding terms.
|
79
84
|
|
80
85
|
Args:
|
@@ -90,6 +95,20 @@ class NetworkGraph:
|
|
90
95
|
)
|
91
96
|
)
|
92
97
|
|
98
|
+
def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
|
99
|
+
"""Create a map from domain IDs to node labels.
|
100
|
+
|
101
|
+
Returns:
|
102
|
+
dict: A dictionary mapping domain IDs to the corresponding node labels.
|
103
|
+
"""
|
104
|
+
domain_id_to_label_map = {}
|
105
|
+
for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
|
106
|
+
domain_id_to_label_map[domain_id] = [
|
107
|
+
self.node_id_to_node_label_map[node_id] for node_id in node_ids
|
108
|
+
]
|
109
|
+
|
110
|
+
return domain_id_to_label_map
|
111
|
+
|
93
112
|
def _initialize_network(self, G: nx.Graph) -> None:
|
94
113
|
"""Initialize the network by unfolding it and extracting node coordinates.
|
95
114
|
|
@@ -158,8 +177,8 @@ class NetworkGraph:
|
|
158
177
|
# Initialize composite colors array with shape (number of nodes, 4) for RGBA
|
159
178
|
composite_colors = np.zeros((num_nodes, 4))
|
160
179
|
# Assign colors to nodes based on domain_colors
|
161
|
-
for
|
162
|
-
color = domain_colors[
|
180
|
+
for domain_id, nodes in self.domain_id_to_node_ids_map.items():
|
181
|
+
color = domain_colors[domain_id]
|
163
182
|
for node in nodes:
|
164
183
|
composite_colors[node] = color
|
165
184
|
|
@@ -189,7 +208,7 @@ class NetworkGraph:
|
|
189
208
|
domain_colors = _get_colors(
|
190
209
|
num_colors_to_generate=len(domains), cmap=cmap, color=color, random_seed=random_seed
|
191
210
|
)
|
192
|
-
return dict(zip(self.
|
211
|
+
return dict(zip(self.domain_id_to_node_ids_map.keys(), domain_colors))
|
193
212
|
|
194
213
|
|
195
214
|
def _transform_colors(
|
risk/network/plot.py
CHANGED
@@ -308,9 +308,9 @@ class NetworkPlotter:
|
|
308
308
|
|
309
309
|
# Filter to get node IDs and their coordinates
|
310
310
|
node_ids = [
|
311
|
-
self.graph.
|
311
|
+
self.graph.node_label_to_node_id_map.get(node)
|
312
312
|
for node in nodes
|
313
|
-
if node in self.graph.
|
313
|
+
if node in self.graph.node_label_to_node_id_map
|
314
314
|
]
|
315
315
|
if not node_ids:
|
316
316
|
raise ValueError("No nodes found in the network graph.")
|
@@ -320,7 +320,7 @@ class NetworkPlotter:
|
|
320
320
|
node_color = [
|
321
321
|
node_color[nodes.index(node)]
|
322
322
|
for node in nodes
|
323
|
-
if node in self.graph.
|
323
|
+
if node in self.graph.node_label_to_node_id_map
|
324
324
|
]
|
325
325
|
|
326
326
|
# Convert colors to RGBA using the _to_rgba helper function
|
@@ -389,16 +389,16 @@ class NetworkPlotter:
|
|
389
389
|
)
|
390
390
|
|
391
391
|
# Ensure color is converted to RGBA with repetition matching the number of domains
|
392
|
-
color = _to_rgba(color, alpha, num_repeats=len(self.graph.
|
392
|
+
color = _to_rgba(color, alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map))
|
393
393
|
# Extract node coordinates from the network graph
|
394
394
|
node_coordinates = self.graph.node_coordinates
|
395
395
|
# Draw contours for each domain in the network
|
396
|
-
for idx, (_,
|
397
|
-
if len(
|
396
|
+
for idx, (_, node_ids) in enumerate(self.graph.domain_id_to_node_ids_map.items()):
|
397
|
+
if len(node_ids) > 1:
|
398
398
|
self._draw_kde_contour(
|
399
399
|
self.ax,
|
400
400
|
node_coordinates,
|
401
|
-
|
401
|
+
node_ids,
|
402
402
|
color=color[idx],
|
403
403
|
levels=levels,
|
404
404
|
bandwidth=bandwidth,
|
@@ -452,9 +452,9 @@ class NetworkPlotter:
|
|
452
452
|
for sublist in node_groups:
|
453
453
|
# Filter to get node IDs and their coordinates for each sublist
|
454
454
|
node_ids = [
|
455
|
-
self.graph.
|
455
|
+
self.graph.node_label_to_node_id_map.get(node)
|
456
456
|
for node in sublist
|
457
|
-
if node in self.graph.
|
457
|
+
if node in self.graph.node_label_to_node_id_map
|
458
458
|
]
|
459
459
|
if not node_ids or len(node_ids) == 1:
|
460
460
|
raise ValueError(
|
@@ -641,12 +641,14 @@ class NetworkPlotter:
|
|
641
641
|
|
642
642
|
# Set max_labels to the total number of domains if not provided (None)
|
643
643
|
if max_labels is None:
|
644
|
-
max_labels = len(self.graph.
|
644
|
+
max_labels = len(self.graph.domain_id_to_node_ids_map)
|
645
645
|
|
646
646
|
# Convert colors to RGBA using the _to_rgba helper function
|
647
|
-
fontcolor = _to_rgba(
|
647
|
+
fontcolor = _to_rgba(
|
648
|
+
fontcolor, fontalpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
649
|
+
)
|
648
650
|
arrow_color = _to_rgba(
|
649
|
-
arrow_color, arrow_alpha, num_repeats=len(self.graph.
|
651
|
+
arrow_color, arrow_alpha, num_repeats=len(self.graph.domain_id_to_node_ids_map)
|
650
652
|
)
|
651
653
|
|
652
654
|
# Normalize words_to_omit to lowercase
|
@@ -655,9 +657,9 @@ class NetworkPlotter:
|
|
655
657
|
|
656
658
|
# Calculate the center and radius of the network
|
657
659
|
domain_centroids = {}
|
658
|
-
for
|
659
|
-
if
|
660
|
-
domain_centroids[
|
660
|
+
for domain_id, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
661
|
+
if node_ids: # Skip if the domain has no nodes
|
662
|
+
domain_centroids[domain_id] = self._calculate_domain_centroid(node_ids)
|
661
663
|
|
662
664
|
# Initialize dictionaries and lists for valid indices
|
663
665
|
valid_indices = []
|
@@ -675,12 +677,15 @@ class NetworkPlotter:
|
|
675
677
|
|
676
678
|
# Process the specified IDs first
|
677
679
|
for domain in ids_to_keep:
|
678
|
-
if
|
680
|
+
if (
|
681
|
+
domain in self.graph.domain_id_to_domain_terms_map
|
682
|
+
and domain in domain_centroids
|
683
|
+
):
|
679
684
|
# Handle ids_to_replace logic here for ids_to_keep
|
680
685
|
if ids_to_replace and domain in ids_to_replace:
|
681
686
|
terms = ids_to_replace[domain].split(" ")
|
682
687
|
else:
|
683
|
-
terms = self.graph.
|
688
|
+
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
684
689
|
|
685
690
|
# Apply words_to_omit, word length constraints, and max_words
|
686
691
|
if words_to_omit:
|
@@ -712,7 +717,7 @@ class NetworkPlotter:
|
|
712
717
|
if ids_to_replace and domain in ids_to_replace:
|
713
718
|
terms = ids_to_replace[domain].split(" ")
|
714
719
|
else:
|
715
|
-
terms = self.graph.
|
720
|
+
terms = self.graph.domain_id_to_domain_terms_map[domain].split(" ")
|
716
721
|
|
717
722
|
# Apply words_to_omit, word length constraints, and max_words
|
718
723
|
if words_to_omit:
|
@@ -835,9 +840,9 @@ class NetworkPlotter:
|
|
835
840
|
for sublist in node_groups:
|
836
841
|
# Map node labels to IDs
|
837
842
|
node_ids = [
|
838
|
-
self.graph.
|
843
|
+
self.graph.node_label_to_node_id_map.get(node)
|
839
844
|
for node in sublist
|
840
|
-
if node in self.graph.
|
845
|
+
if node in self.graph.node_label_to_node_id_map
|
841
846
|
]
|
842
847
|
if not node_ids or len(node_ids) == 1:
|
843
848
|
raise ValueError(
|
@@ -948,10 +953,10 @@ class NetworkPlotter:
|
|
948
953
|
Returns:
|
949
954
|
np.ndarray: Array of node sizes, with enriched nodes larger than non-enriched ones.
|
950
955
|
"""
|
951
|
-
# Merge all enriched nodes from the
|
956
|
+
# Merge all enriched nodes from the domain_id_to_node_ids_map dictionary
|
952
957
|
enriched_nodes = set()
|
953
|
-
for _,
|
954
|
-
enriched_nodes.update(
|
958
|
+
for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
959
|
+
enriched_nodes.update(node_ids)
|
955
960
|
|
956
961
|
# Initialize all node sizes to the non-enriched size
|
957
962
|
node_sizes = np.full(len(self.graph.network.nodes), nonenriched_size)
|
@@ -1065,10 +1070,10 @@ class NetworkPlotter:
|
|
1065
1070
|
random_seed=random_seed,
|
1066
1071
|
)
|
1067
1072
|
annotated_colors = []
|
1068
|
-
for _,
|
1069
|
-
if len(
|
1073
|
+
for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
|
1074
|
+
if len(node_ids) > 1:
|
1070
1075
|
# For multi-node domains, choose the brightest color based on RGB sum
|
1071
|
-
domain_colors = np.array([node_colors[node] for node in
|
1076
|
+
domain_colors = np.array([node_colors[node] for node in node_ids])
|
1072
1077
|
brightest_color = domain_colors[
|
1073
1078
|
np.argmax(domain_colors[:, :3].sum(axis=1)) # Sum the RGB values
|
1074
1079
|
]
|
risk/risk.py
CHANGED
@@ -62,7 +62,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
62
62
|
|
63
63
|
Args:
|
64
64
|
network (nx.Graph): The network graph.
|
65
|
-
annotations (
|
65
|
+
annotations (dict): The annotations associated with the network.
|
66
66
|
distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
|
67
67
|
louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
|
68
68
|
edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
|
@@ -131,7 +131,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
131
131
|
|
132
132
|
Args:
|
133
133
|
network (nx.Graph): The network graph.
|
134
|
-
annotations (
|
134
|
+
annotations (dict): The annotations associated with the network.
|
135
135
|
distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
|
136
136
|
louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
|
137
137
|
edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
|
@@ -187,7 +187,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
187
187
|
|
188
188
|
Args:
|
189
189
|
network (nx.Graph): The network graph.
|
190
|
-
annotations (
|
190
|
+
annotations (dict): The annotations associated with the network.
|
191
191
|
distance_metric (str, optional): Distance metric for neighborhood analysis. Defaults to "dijkstra".
|
192
192
|
louvain_resolution (float, optional): Resolution parameter for Louvain clustering. Defaults to 0.1.
|
193
193
|
edge_length_threshold (float, optional): Edge length threshold for neighborhood analysis. Defaults to 0.5.
|
@@ -343,7 +343,7 @@ class RISK(NetworkIO, AnnotationsIO):
|
|
343
343
|
top_annotations=top_annotations,
|
344
344
|
domains=domains,
|
345
345
|
trimmed_domains=trimmed_domains,
|
346
|
-
|
346
|
+
node_label_to_node_id_map=node_label_to_id,
|
347
347
|
node_enrichment_sums=node_enrichment_sums,
|
348
348
|
)
|
349
349
|
|
@@ -1,21 +1,21 @@
|
|
1
|
-
risk/__init__.py,sha256=
|
1
|
+
risk/__init__.py,sha256=nUkWz8VqnztwzCLGVbN6G8bFixZFkDyLsEOsb5mLAGc,113
|
2
2
|
risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
|
3
|
-
risk/risk.py,sha256=
|
3
|
+
risk/risk.py,sha256=CKDIzVo9Jvl-fgzIlk5ZtJL9pIBMma24WK6EYdVu5po,20648
|
4
4
|
risk/annotations/__init__.py,sha256=vUpVvMRE5if01Ic8QY6M2Ae3EFGJHdugEe9PdEkAW4Y,138
|
5
5
|
risk/annotations/annotations.py,sha256=DRUTdGzMdqo62NWSapBUksbvPr9CrzD76qtOcxeNKmo,10554
|
6
6
|
risk/annotations/io.py,sha256=lo7NKqOVkeeBp58JBxWJHtA0xjL5Yoxqe9Ox0daKlZk,9457
|
7
7
|
risk/log/__init__.py,sha256=xuLImfxFlKpnVhzi_gDYlr2_c9cLkrw2c_3iEsXb1as,107
|
8
8
|
risk/log/console.py,sha256=im9DRExwf6wHlcn9fewoDcKIpo3vPcorZIaNAl-0csY,355
|
9
|
-
risk/log/params.py,sha256=
|
9
|
+
risk/log/params.py,sha256=Rfdg5UcGCrG80m6V79FyORERWUqIzHFO7tGiY4zAImM,6347
|
10
10
|
risk/neighborhoods/__init__.py,sha256=tKKEg4lsbqFukpgYlUGxU_v_9FOqK7V0uvM9T2QzoL0,206
|
11
11
|
risk/neighborhoods/community.py,sha256=7ebo1Q5KokSQISnxZIh2SQxsKXdXm8aVkp-h_DiQ3K0,6818
|
12
12
|
risk/neighborhoods/domains.py,sha256=5V--Nj-TrSdubhD_2PI57ffcn_PMSEgpX_iY5OjT6R8,10626
|
13
13
|
risk/neighborhoods/neighborhoods.py,sha256=sHmjFFl2U5qV9YbQCRbpbI36j7dS7IFfFwwRb1_-AuM,13945
|
14
14
|
risk/network/__init__.py,sha256=iEPeJdZfqp0toxtbElryB8jbz9_t_k4QQ3iDvKE8C_0,126
|
15
15
|
risk/network/geometry.py,sha256=H1yGVVqgbfpzBzJwEheDLfvGLSA284jGQQTn612L4Vc,6759
|
16
|
-
risk/network/graph.py,sha256=
|
16
|
+
risk/network/graph.py,sha256=7haHu4M3fleqbrIzs6HC9jnKizSERzmmAYSmUwdoSXA,13953
|
17
17
|
risk/network/io.py,sha256=gG50kOknO-D3HkW1HsbHMkTMvjUtn3l4W4Jwd-rXNr8,21202
|
18
|
-
risk/network/plot.py,sha256=
|
18
|
+
risk/network/plot.py,sha256=_g5xHolMTAfZCBvYYEX1CYME4s4zA2hTHtN-utaMPik,61978
|
19
19
|
risk/stats/__init__.py,sha256=e-BE_Dr_jgiK6hKM-T-tlG4yvHnId8e5qjnM0pdwNVc,230
|
20
20
|
risk/stats/fisher_exact.py,sha256=-bPwzu76-ob0HzrTV20mXUTot7v-MLuqFaAoab-QxPg,4966
|
21
21
|
risk/stats/hypergeom.py,sha256=lrIFdhCWRjvM4apYw1MlOKqT_IY5OjtCwrjdtJdt6Tg,4954
|
@@ -23,8 +23,8 @@ risk/stats/stats.py,sha256=kvShov-94W6ffgDUTb522vB9hDJQSyTsYif_UIaFfSM,7059
|
|
23
23
|
risk/stats/permutation/__init__.py,sha256=neJp7FENC-zg_CGOXqv-iIvz1r5XUKI9Ruxhmq7kDOI,105
|
24
24
|
risk/stats/permutation/permutation.py,sha256=qLWdwxEY6nmkYPxpM8HLDcd2mbqYv9Qr7CKtJvhLqIM,9220
|
25
25
|
risk/stats/permutation/test_functions.py,sha256=HuDIM-V1jkkfE1rlaIqrWWBSKZt3dQ1f-YEDjWpnLSE,2343
|
26
|
-
risk_network-0.0.
|
27
|
-
risk_network-0.0.
|
28
|
-
risk_network-0.0.
|
29
|
-
risk_network-0.0.
|
30
|
-
risk_network-0.0.
|
26
|
+
risk_network-0.0.6b10.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
|
27
|
+
risk_network-0.0.6b10.dist-info/METADATA,sha256=aOz9JrsPIpByvzMCwsbBBcBHHMLu3JzLF9FZMp9-IuM,43143
|
28
|
+
risk_network-0.0.6b10.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
29
|
+
risk_network-0.0.6b10.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
|
30
|
+
risk_network-0.0.6b10.dist-info/RECORD,,
|
File without changes
|
File without changes
|