risk-network 0.0.8b18__py3-none-any.whl → 0.0.9b26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. risk/__init__.py +2 -2
  2. risk/annotations/__init__.py +2 -2
  3. risk/annotations/annotations.py +133 -72
  4. risk/annotations/io.py +50 -34
  5. risk/log/__init__.py +4 -2
  6. risk/log/{config.py → console.py} +5 -3
  7. risk/log/{params.py → parameters.py} +21 -46
  8. risk/neighborhoods/__init__.py +3 -5
  9. risk/neighborhoods/api.py +446 -0
  10. risk/neighborhoods/community.py +281 -96
  11. risk/neighborhoods/domains.py +92 -38
  12. risk/neighborhoods/neighborhoods.py +210 -149
  13. risk/network/__init__.py +1 -3
  14. risk/network/geometry.py +69 -58
  15. risk/network/graph/__init__.py +6 -0
  16. risk/network/graph/api.py +194 -0
  17. risk/network/graph/network.py +269 -0
  18. risk/network/graph/summary.py +254 -0
  19. risk/network/io.py +58 -48
  20. risk/network/plotter/__init__.py +6 -0
  21. risk/network/plotter/api.py +54 -0
  22. risk/network/{plot → plotter}/canvas.py +80 -26
  23. risk/network/{plot → plotter}/contour.py +43 -34
  24. risk/network/{plot → plotter}/labels.py +123 -113
  25. risk/network/plotter/network.py +424 -0
  26. risk/network/plotter/utils/colors.py +416 -0
  27. risk/network/plotter/utils/layout.py +94 -0
  28. risk/risk.py +11 -469
  29. risk/stats/__init__.py +8 -4
  30. risk/stats/binom.py +51 -0
  31. risk/stats/chi2.py +69 -0
  32. risk/stats/hypergeom.py +28 -18
  33. risk/stats/permutation/__init__.py +1 -1
  34. risk/stats/permutation/permutation.py +45 -39
  35. risk/stats/permutation/test_functions.py +25 -17
  36. risk/stats/poisson.py +17 -11
  37. risk/stats/stats.py +20 -16
  38. risk/stats/zscore.py +68 -0
  39. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/METADATA +9 -5
  40. risk_network-0.0.9b26.dist-info/RECORD +44 -0
  41. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/WHEEL +1 -1
  42. risk/network/graph.py +0 -159
  43. risk/network/plot/__init__.py +0 -6
  44. risk/network/plot/network.py +0 -282
  45. risk/network/plot/plotter.py +0 -137
  46. risk/network/plot/utils/color.py +0 -353
  47. risk/network/plot/utils/layout.py +0 -53
  48. risk_network-0.0.8b18.dist-info/RECORD +0 -37
  49. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/LICENSE +0 -0
  50. {risk_network-0.0.8b18.dist-info → risk_network-0.0.9b26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,44 @@
1
+ risk/__init__.py,sha256=vcWTNpBnxRYMGh8X2IkLccT4MfT6c0J94_SFHMe66Rk,127
2
+ risk/constants.py,sha256=XInRaH78Slnw_sWgAsBFbUHkyA0h0jL0DKGuQNbOvjM,550
3
+ risk/risk.py,sha256=s827_lRknFseOP9O4zW8sP-IcCd2EzrpV_tnVY_tz5s,1104
4
+ risk/annotations/__init__.py,sha256=parsbcux1U4urpUqh9AdzbDWuLj9HlMidycMPkpSQFo,179
5
+ risk/annotations/annotations.py,sha256=XmVuLL5NFAj6F30fZY22N8nb4LK6sig7fE0NXL1iZp8,14497
6
+ risk/annotations/io.py,sha256=z1AJySsU-KL_IYuHa7j3nvuczmOHgK3WfaQ4TRunvrA,10499
7
+ risk/log/__init__.py,sha256=7LxDysQu7doi0LAvlY2YbjN6iJH0fNknqy8lSLgeljo,217
8
+ risk/log/console.py,sha256=PgjyEvyhYLUSHXPUKEqOmxsDsfrjPICIgqo_cAHq0N8,4575
9
+ risk/log/parameters.py,sha256=VtwfMzLU1xI4yji3-Ch5vHjH-KdwTfwaEMmi7hFQTs0,5716
10
+ risk/neighborhoods/__init__.py,sha256=Q74HwTH7okI-vaskJPy2bYwb5sNjGASTzJ6m8V8arCU,234
11
+ risk/neighborhoods/api.py,sha256=KdUouMHJPwvePJGdz7Ck1GWYhN96QDb_SuPyTt3KwAc,23515
12
+ risk/neighborhoods/community.py,sha256=VIDvB-SsMDDvWkUaYXf_E-gcg0HELMVv2MKshPwJAFQ,15480
13
+ risk/neighborhoods/domains.py,sha256=jMJ4-Qzwgmo6Hya8h0E2_IcMaLpbuH_FWlmSjJl2ikc,12832
14
+ risk/neighborhoods/neighborhoods.py,sha256=bBUY7hXqcsOoAEkPdRoRNuj36WsllXicmz_LxZfEuyw,21186
15
+ risk/network/__init__.py,sha256=oVi3FA1XXKD84014Cykq-9bpX4_s0F3aAUfNOU-07Qw,73
16
+ risk/network/geometry.py,sha256=omyb9afSKMUtQ-RKVHUoRyxJifOW0ASenHjyCjg43kg,6836
17
+ risk/network/io.py,sha256=JV5hqf1oIwWUVw07BjhD0qACQGbtIeA8NSMDcFql88k,23465
18
+ risk/network/graph/__init__.py,sha256=ziGJew3yhtqvrb9LUuneDu_LwW2Wa9vd4UuhoL5l1CA,91
19
+ risk/network/graph/api.py,sha256=Ag4PjFTX6BUvmW7ZdfIgwdsr8URigX9jD9yEFRXUxrU,8220
20
+ risk/network/graph/network.py,sha256=KdIBM_-flHMWcBK4RUjU_QRfOZIf_yv9fv4L7AOLkqU,12199
21
+ risk/network/graph/summary.py,sha256=8IenFZfhyzcg5aGNJp7Zjb0Umy0mFNmJlfwXcO7y8MU,10311
22
+ risk/network/plotter/__init__.py,sha256=ixXQxpBVpNIz1y9tUHZ7CiJmGfewvbvjuB1LQ-AIf1s,93
23
+ risk/network/plotter/api.py,sha256=cLZHq-rn_5FJwIWM5hYlQMobPmaxCE-P2iqgxTDIOTQ,1860
24
+ risk/network/plotter/canvas.py,sha256=l-Se86DMDJMHh8Yn-_hsl0_ipoazHLJGRCqXcc9HK4M,13498
25
+ risk/network/plotter/contour.py,sha256=svi76suYlVYq2VoDQxXmun8Hmo0lI2CQRjAyHg0qdhk,15490
26
+ risk/network/plotter/labels.py,sha256=QesD1ybseA6ldLmWMqVaAqSPR34yVEgEzXzg1AKQD6o,45513
27
+ risk/network/plotter/network.py,sha256=wcBf1GaM1wPzW-iXTrLzOmlG2_9wwfll_hJUzUO2u2Y,19917
28
+ risk/network/plotter/utils/colors.py,sha256=EFlIUZ3MGSKoHeZi9cgR6uLKK5GGJ4QzE6lmnrHViLw,18967
29
+ risk/network/plotter/utils/layout.py,sha256=2P4Bqi1dGiX9KsriLYqiq1KlHpsMdZemAUza4WcYoNA,3634
30
+ risk/stats/__init__.py,sha256=1CPRtT1LDwudrvFgkVtSom8cp4cM7b4X6b4fHPaNHw0,405
31
+ risk/stats/binom.py,sha256=8Qwcxnq1u-AycwQs_sQxwuxgkgDpES-A-kIcj4fRc3g,2032
32
+ risk/stats/chi2.py,sha256=MGFNrWP40i9TxnMsZYbDgqdMrN_Fe0xFsnWU8xNsVSs,3046
33
+ risk/stats/hypergeom.py,sha256=VfQBtpgSGG826uBP1WyBMavP3ylZnhponUZ2rHFdGAE,2502
34
+ risk/stats/poisson.py,sha256=_KHe9g8XNRD4-Q486zx2UgHCO2QyvBOiHuX3hRZLEqc,2050
35
+ risk/stats/stats.py,sha256=y2DMJF3uKRIWRyYiCd2Kwxa-EqOzX5HsMBms_Vw6wK8,7322
36
+ risk/stats/zscore.py,sha256=Jx9cLKAHiDnrgW_Su9KZYYQiTVsuyJMC7vXBusnEI-c,2648
37
+ risk/stats/permutation/__init__.py,sha256=OLmYLm2uj96hPsSaUs0vUqFYw6Thwch_aHtpL7L0ZFw,127
38
+ risk/stats/permutation/permutation.py,sha256=BWjgdBpLVcHvmwHy0bmD4aJFccxifNBSrrCBPppyKf4,10569
39
+ risk/stats/permutation/test_functions.py,sha256=D3XMPM8CasUNytWSRce22TI6KK6XulYn5uGG4lWxaHs,3120
40
+ risk_network-0.0.9b26.dist-info/LICENSE,sha256=jOtLnuWt7d5Hsx6XXB2QxzrSe2sWWh3NgMfFRetluQM,35147
41
+ risk_network-0.0.9b26.dist-info/METADATA,sha256=7nzWmesgQYNnymMbG9WejhChWxobBvbFgQQD0VKd3n0,47627
42
+ risk_network-0.0.9b26.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
43
+ risk_network-0.0.9b26.dist-info/top_level.txt,sha256=NX7C2PFKTvC1JhVKv14DFlFAIFnKc6Lpsu1ZfxvQwVw,5
44
+ risk_network-0.0.9b26.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
risk/network/graph.py DELETED
@@ -1,159 +0,0 @@
1
- """
2
- risk/network/graph
3
- ~~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from collections import defaultdict
7
- from typing import Any, Dict, List
8
-
9
- import networkx as nx
10
- import numpy as np
11
- import pandas as pd
12
-
13
-
14
- class NetworkGraph:
15
- """A class to represent a network graph and process its nodes and edges.
16
-
17
- The NetworkGraph class provides functionality to handle and manipulate a network graph,
18
- including managing domains, annotations, and node enrichment data. It also includes methods
19
- for transforming and mapping graph coordinates, as well as generating colors based on node
20
- enrichment.
21
- """
22
-
23
- def __init__(
24
- self,
25
- network: nx.Graph,
26
- top_annotations: pd.DataFrame,
27
- domains: pd.DataFrame,
28
- trimmed_domains: pd.DataFrame,
29
- node_label_to_node_id_map: Dict[str, Any],
30
- node_enrichment_sums: np.ndarray,
31
- ):
32
- """Initialize the NetworkGraph object.
33
-
34
- Args:
35
- network (nx.Graph): The network graph.
36
- top_annotations (pd.DataFrame): DataFrame containing annotations data for the network nodes.
37
- domains (pd.DataFrame): DataFrame containing domain data for the network nodes.
38
- trimmed_domains (pd.DataFrame): DataFrame containing trimmed domain data for the network nodes.
39
- node_label_to_node_id_map (dict): A dictionary mapping node labels to their corresponding IDs.
40
- node_enrichment_sums (np.ndarray): Array containing the enrichment sums for the nodes.
41
- """
42
- self.top_annotations = top_annotations
43
- self.domain_id_to_node_ids_map = self._create_domain_id_to_node_ids_map(domains)
44
- self.domains = domains
45
- self.domain_id_to_domain_terms_map = self._create_domain_id_to_domain_terms_map(
46
- trimmed_domains
47
- )
48
- self.node_enrichment_sums = node_enrichment_sums
49
- self.node_id_to_node_label_map = {v: k for k, v in node_label_to_node_id_map.items()}
50
- self.node_label_to_enrichment_map = dict(
51
- zip(node_label_to_node_id_map.keys(), node_enrichment_sums)
52
- )
53
- self.node_label_to_node_id_map = node_label_to_node_id_map
54
- # NOTE: Below this point, instance attributes (i.e., self) will be used!
55
- self.domain_id_to_node_labels_map = self._create_domain_id_to_node_labels_map()
56
- # Unfold the network's 3D coordinates to 2D and extract node coordinates
57
- self.network = _unfold_sphere_to_plane(network)
58
- self.node_coordinates = _extract_node_coordinates(self.network)
59
-
60
- def _create_domain_id_to_node_ids_map(self, domains: pd.DataFrame) -> Dict[str, Any]:
61
- """Create a mapping from domains to the list of node IDs belonging to each domain.
62
-
63
- Args:
64
- domains (pd.DataFrame): DataFrame containing domain information, including the 'primary domain' for each node.
65
-
66
- Returns:
67
- dict: A dictionary where keys are domain IDs and values are lists of node IDs belonging to each domain.
68
- """
69
- cleaned_domains_matrix = domains.reset_index()[["index", "primary domain"]]
70
- node_to_domains_map = cleaned_domains_matrix.set_index("index")["primary domain"].to_dict()
71
- domain_id_to_node_ids_map = defaultdict(list)
72
- for k, v in node_to_domains_map.items():
73
- domain_id_to_node_ids_map[v].append(k)
74
-
75
- return domain_id_to_node_ids_map
76
-
77
- def _create_domain_id_to_domain_terms_map(
78
- self, trimmed_domains: pd.DataFrame
79
- ) -> Dict[str, Any]:
80
- """Create a mapping from domain IDs to their corresponding terms.
81
-
82
- Args:
83
- trimmed_domains (pd.DataFrame): DataFrame containing domain IDs and their corresponding labels.
84
-
85
- Returns:
86
- dict: A dictionary mapping domain IDs to their corresponding terms.
87
- """
88
- return dict(
89
- zip(
90
- trimmed_domains.index,
91
- trimmed_domains["label"],
92
- )
93
- )
94
-
95
- def _create_domain_id_to_node_labels_map(self) -> Dict[int, List[str]]:
96
- """Create a map from domain IDs to node labels.
97
-
98
- Returns:
99
- dict: A dictionary mapping domain IDs to the corresponding node labels.
100
- """
101
- domain_id_to_label_map = {}
102
- for domain_id, node_ids in self.domain_id_to_node_ids_map.items():
103
- domain_id_to_label_map[domain_id] = [
104
- self.node_id_to_node_label_map[node_id] for node_id in node_ids
105
- ]
106
-
107
- return domain_id_to_label_map
108
-
109
-
110
- def _unfold_sphere_to_plane(G: nx.Graph) -> nx.Graph:
111
- """Convert 3D coordinates to 2D by unfolding a sphere to a plane.
112
-
113
- Args:
114
- G (nx.Graph): A network graph with 3D coordinates. Each node should have 'x', 'y', and 'z' attributes.
115
-
116
- Returns:
117
- nx.Graph: The network graph with updated 2D coordinates (only 'x' and 'y').
118
- """
119
- for node in G.nodes():
120
- if "z" in G.nodes[node]:
121
- # Extract 3D coordinates
122
- x, y, z = G.nodes[node]["x"], G.nodes[node]["y"], G.nodes[node]["z"]
123
- # Calculate spherical coordinates theta and phi from Cartesian coordinates
124
- r = np.sqrt(x**2 + y**2 + z**2)
125
- theta = np.arctan2(y, x)
126
- phi = np.arccos(z / r)
127
-
128
- # Convert spherical coordinates to 2D plane coordinates
129
- unfolded_x = (theta + np.pi) / (2 * np.pi) # Shift and normalize theta to [0, 1]
130
- unfolded_x = unfolded_x + 0.5 if unfolded_x < 0.5 else unfolded_x - 0.5
131
- unfolded_y = (np.pi - phi) / np.pi # Reflect phi and normalize to [0, 1]
132
- # Update network node attributes
133
- G.nodes[node]["x"] = unfolded_x
134
- G.nodes[node]["y"] = -unfolded_y
135
- # Remove the 'z' coordinate as it's no longer needed
136
- del G.nodes[node]["z"]
137
-
138
- return G
139
-
140
-
141
- def _extract_node_coordinates(G: nx.Graph) -> np.ndarray:
142
- """Extract 2D coordinates of nodes from the graph.
143
-
144
- Args:
145
- G (nx.Graph): The network graph with node coordinates.
146
-
147
- Returns:
148
- np.ndarray: Array of node coordinates with shape (num_nodes, 2).
149
- """
150
- # Extract x and y coordinates from graph nodes
151
- x_coords = dict(G.nodes.data("x"))
152
- y_coords = dict(G.nodes.data("y"))
153
- coordinates_dicts = [x_coords, y_coords]
154
- # Combine x and y coordinates into a single array
155
- node_positions = {
156
- node: np.array([coords[node] for coords in coordinates_dicts]) for node in x_coords
157
- }
158
- node_coordinates = np.vstack(list(node_positions.values()))
159
- return node_coordinates
@@ -1,6 +0,0 @@
1
- """
2
- risk/network/plot
3
- ~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from risk.network.plot.plotter import NetworkPlotter
@@ -1,282 +0,0 @@
1
- """
2
- risk/network/plot/network
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from typing import Any, List, Tuple, Union
7
-
8
- import networkx as nx
9
- import numpy as np
10
-
11
- from risk.log import params
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.utils.color import get_domain_colors, to_rgba
14
-
15
-
16
- class Network:
17
- """Class for plotting nodes and edges in a network graph."""
18
-
19
- def __init__(self, graph: NetworkGraph, ax: Any = None) -> None:
20
- """Initialize the NetworkPlotter class.
21
-
22
- Args:
23
- graph (NetworkGraph): The network data and attributes to be visualized.
24
- ax (Any, optional): Axes object to plot the network graph. Defaults to None.
25
- """
26
- self.graph = graph
27
- self.ax = ax
28
-
29
- def plot_network(
30
- self,
31
- node_size: Union[int, np.ndarray] = 50,
32
- node_shape: str = "o",
33
- node_edgewidth: float = 1.0,
34
- edge_width: float = 1.0,
35
- node_color: Union[str, List, Tuple, np.ndarray] = "white",
36
- node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
37
- edge_color: Union[str, List, Tuple, np.ndarray] = "black",
38
- node_alpha: Union[float, None] = 1.0,
39
- edge_alpha: Union[float, None] = 1.0,
40
- ) -> None:
41
- """Plot the network graph with customizable node colors, sizes, edge widths, and node edge widths.
42
-
43
- Args:
44
- node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
45
- node_shape (str, optional): Shape of the nodes. Defaults to "o".
46
- node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
47
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
48
- node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Can be a single color or an array of colors.
49
- Defaults to "white".
50
- node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Can be a single color or an array of colors.
51
- Defaults to "black".
52
- edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Can be a single color or an array of colors.
53
- Defaults to "black".
54
- node_alpha (float, None, optional): Alpha value (transparency) for the nodes. If provided, it overrides any existing alpha
55
- values found in node_color. Defaults to 1.0. Annotated node_color alphas will override this value.
56
- edge_alpha (float, None, optional): Alpha value (transparency) for the edges. If provided, it overrides any existing alpha
57
- values found in edge_color. Defaults to 1.0.
58
- """
59
- # Log the plotting parameters
60
- params.log_plotter(
61
- network_node_size=(
62
- "custom" if isinstance(node_size, np.ndarray) else node_size
63
- ), # np.ndarray usually indicates custom sizes
64
- network_node_shape=node_shape,
65
- network_node_edgewidth=node_edgewidth,
66
- network_edge_width=edge_width,
67
- network_node_color=(
68
- "custom" if isinstance(node_color, np.ndarray) else node_color
69
- ), # np.ndarray usually indicates custom colors
70
- network_node_edgecolor=node_edgecolor,
71
- network_edge_color=edge_color,
72
- network_node_alpha=node_alpha,
73
- network_edge_alpha=edge_alpha,
74
- )
75
-
76
- # Convert colors to RGBA using the to_rgba helper function
77
- # If node_colors was generated using get_annotated_node_colors, its alpha values will override node_alpha
78
- node_color = to_rgba(
79
- color=node_color, alpha=node_alpha, num_repeats=len(self.graph.network.nodes)
80
- )
81
- node_edgecolor = to_rgba(
82
- color=node_edgecolor, alpha=1.0, num_repeats=len(self.graph.network.nodes)
83
- )
84
- edge_color = to_rgba(
85
- color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
86
- )
87
-
88
- # Extract node coordinates from the network graph
89
- node_coordinates = self.graph.node_coordinates
90
-
91
- # Draw the nodes of the graph
92
- nx.draw_networkx_nodes(
93
- self.graph.network,
94
- pos=node_coordinates,
95
- node_size=node_size,
96
- node_shape=node_shape,
97
- node_color=node_color,
98
- edgecolors=node_edgecolor,
99
- linewidths=node_edgewidth,
100
- ax=self.ax,
101
- )
102
- # Draw the edges of the graph
103
- nx.draw_networkx_edges(
104
- self.graph.network,
105
- pos=node_coordinates,
106
- width=edge_width,
107
- edge_color=edge_color,
108
- ax=self.ax,
109
- )
110
-
111
- def plot_subnetwork(
112
- self,
113
- nodes: Union[List, Tuple, np.ndarray],
114
- node_size: Union[int, np.ndarray] = 50,
115
- node_shape: str = "o",
116
- node_edgewidth: float = 1.0,
117
- edge_width: float = 1.0,
118
- node_color: Union[str, List, Tuple, np.ndarray] = "white",
119
- node_edgecolor: Union[str, List, Tuple, np.ndarray] = "black",
120
- edge_color: Union[str, List, Tuple, np.ndarray] = "black",
121
- node_alpha: Union[float, None] = None,
122
- edge_alpha: Union[float, None] = None,
123
- ) -> None:
124
- """Plot a subnetwork of selected nodes with customizable node and edge attributes.
125
-
126
- Args:
127
- nodes (list, tuple, or np.ndarray): List of node labels to include in the subnetwork. Accepts nested lists.
128
- node_size (int or np.ndarray, optional): Size of the nodes. Can be a single integer or an array of sizes. Defaults to 50.
129
- node_shape (str, optional): Shape of the nodes. Defaults to "o".
130
- node_edgewidth (float, optional): Width of the node edges. Defaults to 1.0.
131
- edge_width (float, optional): Width of the edges. Defaults to 1.0.
132
- node_color (str, list, tuple, or np.ndarray, optional): Color of the nodes. Defaults to "white".
133
- node_edgecolor (str, list, tuple, or np.ndarray, optional): Color of the node edges. Defaults to "black".
134
- edge_color (str, list, tuple, or np.ndarray, optional): Color of the edges. Defaults to "black".
135
- node_alpha (float, None, optional): Transparency for the nodes. If provided, it overrides any existing alpha values
136
- found in node_color. Defaults to 1.0.
137
- edge_alpha (float, None, optional): Transparency for the edges. If provided, it overrides any existing alpha values
138
- found in node_color. Defaults to 1.0.
139
-
140
- Raises:
141
- ValueError: If no valid nodes are found in the network graph.
142
- """
143
- # Flatten nested lists of nodes, if necessary
144
- if any(isinstance(item, (list, tuple, np.ndarray)) for item in nodes):
145
- nodes = [node for sublist in nodes for node in sublist]
146
-
147
- # Filter to get node IDs and their coordinates
148
- node_ids = [
149
- self.graph.node_label_to_node_id_map.get(node)
150
- for node in nodes
151
- if node in self.graph.node_label_to_node_id_map
152
- ]
153
- if not node_ids:
154
- raise ValueError("No nodes found in the network graph.")
155
-
156
- # Check if node_color is a single color or a list of colors
157
- if not isinstance(node_color, (str, tuple, np.ndarray)):
158
- node_color = [
159
- node_color[nodes.index(node)]
160
- for node in nodes
161
- if node in self.graph.node_label_to_node_id_map
162
- ]
163
-
164
- # Convert colors to RGBA using the to_rgba helper function
165
- node_color = to_rgba(color=node_color, alpha=node_alpha, num_repeats=len(node_ids))
166
- node_edgecolor = to_rgba(color=node_edgecolor, alpha=1.0, num_repeats=len(node_ids))
167
- edge_color = to_rgba(
168
- color=edge_color, alpha=edge_alpha, num_repeats=len(self.graph.network.edges)
169
- )
170
-
171
- # Get the coordinates of the filtered nodes
172
- node_coordinates = {node_id: self.graph.node_coordinates[node_id] for node_id in node_ids}
173
-
174
- # Draw the nodes in the subnetwork
175
- nx.draw_networkx_nodes(
176
- self.graph.network,
177
- pos=node_coordinates,
178
- nodelist=node_ids,
179
- node_size=node_size,
180
- node_shape=node_shape,
181
- node_color=node_color,
182
- edgecolors=node_edgecolor,
183
- linewidths=node_edgewidth,
184
- ax=self.ax,
185
- )
186
- # Draw the edges between the specified nodes in the subnetwork
187
- subgraph = self.graph.network.subgraph(node_ids)
188
- nx.draw_networkx_edges(
189
- subgraph,
190
- pos=node_coordinates,
191
- width=edge_width,
192
- edge_color=edge_color,
193
- ax=self.ax,
194
- )
195
-
196
- def get_annotated_node_colors(
197
- self,
198
- cmap: str = "gist_rainbow",
199
- color: Union[str, list, tuple, np.ndarray, None] = None,
200
- min_scale: float = 0.8,
201
- max_scale: float = 1.0,
202
- scale_factor: float = 1.0,
203
- alpha: Union[float, None] = 1.0,
204
- nonenriched_color: Union[str, list, tuple, np.ndarray] = "white",
205
- nonenriched_alpha: Union[float, None] = 1.0,
206
- random_seed: int = 888,
207
- ) -> np.ndarray:
208
- """Adjust the colors of nodes in the network graph based on enrichment.
209
-
210
- Args:
211
- cmap (str, optional): Colormap to use for coloring the nodes. Defaults to "gist_rainbow".
212
- color (str, list, tuple, np.ndarray, or None, optional): Color to use for the nodes. Can be a single color or an array of colors.
213
- If None, the colormap will be used. Defaults to None.
214
- min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
215
- max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
216
- scale_factor (float, optional): Factor for adjusting the color scaling intensity. Defaults to 1.0.
217
- alpha (float, None, optional): Alpha value for enriched nodes. If provided, it overrides any existing alpha values found in `color`.
218
- Defaults to 1.0.
219
- nonenriched_color (str, list, tuple, or np.ndarray, optional): Color for non-enriched nodes. Can be a single color or an array of colors.
220
- Defaults to "white".
221
- nonenriched_alpha (float, None, optional): Alpha value for non-enriched nodes. If provided, it overrides any existing alpha values found
222
- in `nonenriched_color`. Defaults to 1.0.
223
- random_seed (int, optional): Seed for random number generation. Defaults to 888.
224
-
225
- Returns:
226
- np.ndarray: Array of RGBA colors adjusted for enrichment status.
227
- """
228
- # Get the initial domain colors for each node, which are returned as RGBA
229
- network_colors = get_domain_colors(
230
- graph=self.graph,
231
- cmap=cmap,
232
- color=color,
233
- min_scale=min_scale,
234
- max_scale=max_scale,
235
- scale_factor=scale_factor,
236
- random_seed=random_seed,
237
- )
238
- # Apply the alpha value for enriched nodes
239
- network_colors[:, 3] = alpha # Apply the alpha value to the enriched nodes' A channel
240
- # Convert the non-enriched color to RGBA using the to_rgba helper function
241
- nonenriched_color = to_rgba(
242
- color=nonenriched_color, alpha=nonenriched_alpha, num_repeats=1
243
- ) # num_repeats=1 for a single color
244
- # Adjust node colors: replace any nodes where all three RGB values are equal and less than 0.1
245
- # 0.1 is a predefined threshold for the minimum color intensity
246
- adjusted_network_colors = np.where(
247
- (
248
- np.all(network_colors[:, :3] < 0.1, axis=1)
249
- & np.all(network_colors[:, :3] == network_colors[:, 0:1], axis=1)
250
- )[:, None],
251
- np.tile(
252
- np.array(nonenriched_color), (network_colors.shape[0], 1)
253
- ), # Replace with the full RGBA non-enriched color
254
- network_colors, # Keep the original colors where no match is found
255
- )
256
- return adjusted_network_colors
257
-
258
- def get_annotated_node_sizes(
259
- self, enriched_size: int = 50, nonenriched_size: int = 25
260
- ) -> np.ndarray:
261
- """Adjust the sizes of nodes in the network graph based on whether they are enriched or not.
262
-
263
- Args:
264
- enriched_size (int): Size for enriched nodes. Defaults to 50.
265
- nonenriched_size (int): Size for non-enriched nodes. Defaults to 25.
266
-
267
- Returns:
268
- np.ndarray: Array of node sizes, with enriched nodes larger than non-enriched ones.
269
- """
270
- # Merge all enriched nodes from the domain_id_to_node_ids_map dictionary
271
- enriched_nodes = set()
272
- for _, node_ids in self.graph.domain_id_to_node_ids_map.items():
273
- enriched_nodes.update(node_ids)
274
-
275
- # Initialize all node sizes to the non-enriched size
276
- node_sizes = np.full(len(self.graph.network.nodes), nonenriched_size)
277
- # Set the size for enriched nodes
278
- for node in enriched_nodes:
279
- if node in self.graph.network.nodes:
280
- node_sizes[node] = enriched_size
281
-
282
- return node_sizes
@@ -1,137 +0,0 @@
1
- """
2
- risk/network/plot/plotter
3
- ~~~~~~~~~~~~~~~~~~~~~~~~~
4
- """
5
-
6
- from typing import List, Tuple, Union
7
-
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
-
11
- from risk.log import params
12
- from risk.network.graph import NetworkGraph
13
- from risk.network.plot.canvas import Canvas
14
- from risk.network.plot.contour import Contour
15
- from risk.network.plot.labels import Labels
16
- from risk.network.plot.network import Network
17
- from risk.network.plot.utils.color import to_rgba
18
- from risk.network.plot.utils.layout import calculate_bounding_box
19
-
20
-
21
- class NetworkPlotter(Canvas, Network, Contour, Labels):
22
- """A class for visualizing network graphs with customizable options.
23
-
24
- The NetworkPlotter class uses a NetworkGraph object and provides methods to plot the network with
25
- flexible node and edge properties. It also supports plotting labels, contours, drawing the network's
26
- perimeter, and adjusting background colors.
27
- """
28
-
29
- def __init__(
30
- self,
31
- graph: NetworkGraph,
32
- figsize: Tuple = (10, 10),
33
- background_color: Union[str, List, Tuple, np.ndarray] = "white",
34
- background_alpha: Union[float, None] = 1.0,
35
- pad: float = 0.3,
36
- ) -> None:
37
- """Initialize the NetworkPlotter with a NetworkGraph object and plotting parameters.
38
-
39
- Args:
40
- graph (NetworkGraph): The network data and attributes to be visualized.
41
- figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (10, 10).
42
- background_color (str, list, tuple, np.ndarray, optional): Background color of the plot. Defaults to "white".
43
- background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
44
- any existing alpha values found in background_color. Defaults to 1.0.
45
- pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
46
- """
47
- self.graph = graph
48
- # Initialize the plot with the specified parameters
49
- self.ax = self._initialize_plot(
50
- graph=graph,
51
- figsize=figsize,
52
- background_color=background_color,
53
- background_alpha=background_alpha,
54
- pad=pad,
55
- )
56
- super().__init__(graph=graph, ax=self.ax)
57
-
58
- def _initialize_plot(
59
- self,
60
- graph: NetworkGraph,
61
- figsize: Tuple,
62
- background_color: Union[str, list, tuple, np.ndarray],
63
- background_alpha: Union[float, None],
64
- pad: float,
65
- ) -> plt.Axes:
66
- """Set up the plot with figure size and background color.
67
-
68
- Args:
69
- graph (NetworkGraph): The network data and attributes to be visualized.
70
- figsize (tuple): Size of the figure in inches (width, height).
71
- background_color (str, list, tuple, or np.ndarray): Background color of the plot. Can be a single color or an array of colors.
72
- background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides any existing
73
- alpha values found in `background_color`.
74
- pad (float, optional): Padding value to adjust the axis limits.
75
-
76
- Returns:
77
- plt.Axes: The axis object for the plot.
78
- """
79
- # Log the plotter settings
80
- params.log_plotter(
81
- figsize=figsize,
82
- background_color=background_color,
83
- background_alpha=background_alpha,
84
- pad=pad,
85
- )
86
-
87
- # Extract node coordinates from the network graph
88
- node_coordinates = graph.node_coordinates
89
- # Calculate the center and radius of the bounding box around the network
90
- center, radius = calculate_bounding_box(node_coordinates)
91
-
92
- # Create a new figure and axis for plotting
93
- fig, ax = plt.subplots(figsize=figsize)
94
- fig.tight_layout() # Adjust subplot parameters to give specified padding
95
- # Set axis limits based on the calculated bounding box and radius
96
- ax.set_xlim([center[0] - radius - pad, center[0] + radius + pad])
97
- ax.set_ylim([center[1] - radius - pad, center[1] + radius + pad])
98
- ax.set_aspect("equal") # Ensure the aspect ratio is equal
99
-
100
- # Set the background color of the plot
101
- # Convert color to RGBA using the to_rgba helper function
102
- fig.patch.set_facecolor(
103
- to_rgba(color=background_color, alpha=background_alpha, num_repeats=1)
104
- ) # num_repeats=1 for single color
105
- ax.invert_yaxis() # Invert the y-axis to match typical image coordinates
106
- # Remove axis spines for a cleaner look
107
- for spine in ax.spines.values():
108
- spine.set_visible(False)
109
-
110
- # Hide axis ticks and labels
111
- ax.set_xticks([])
112
- ax.set_yticks([])
113
- ax.patch.set_visible(False) # Hide the axis background
114
-
115
- return ax
116
-
117
- @staticmethod
118
- def savefig(*args, pad_inches: float = 0.5, dpi: int = 100, **kwargs) -> None:
119
- """Save the current plot to a file with additional export options.
120
-
121
- Args:
122
- *args: Positional arguments passed to `plt.savefig`.
123
- pad_inches (float, optional): Padding around the figure when saving. Defaults to 0.5.
124
- dpi (int, optional): Dots per inch (DPI) for the exported image. Defaults to 300.
125
- **kwargs: Keyword arguments passed to `plt.savefig`, such as filename and format.
126
- """
127
- plt.savefig(*args, bbox_inches="tight", pad_inches=pad_inches, dpi=dpi, **kwargs)
128
-
129
- @staticmethod
130
- def show(*args, **kwargs) -> None:
131
- """Display the current plot.
132
-
133
- Args:
134
- *args: Positional arguments passed to `plt.show`.
135
- **kwargs: Keyword arguments passed to `plt.show`.
136
- """
137
- plt.show(*args, **kwargs)