risk-network 0.0.9b26__py3-none-any.whl → 0.0.9b28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
- risk/annotations/annotations.py +39 -38
- risk/neighborhoods/api.py +1 -5
- risk/neighborhoods/community.py +140 -95
- risk/neighborhoods/neighborhoods.py +34 -18
- risk/network/geometry.py +24 -27
- risk/network/graph/api.py +6 -6
- risk/network/graph/{network.py → graph.py} +7 -7
- risk/network/graph/summary.py +3 -3
- risk/network/io.py +39 -15
- risk/network/plotter/__init__.py +2 -2
- risk/network/plotter/api.py +12 -12
- risk/network/plotter/canvas.py +7 -7
- risk/network/plotter/contour.py +6 -6
- risk/network/plotter/labels.py +5 -5
- risk/network/plotter/network.py +6 -136
- risk/network/plotter/plotter.py +143 -0
- risk/network/plotter/utils/colors.py +11 -11
- risk/network/plotter/utils/layout.py +2 -2
- risk/stats/__init__.py +8 -6
- risk/stats/{stats.py → significance.py} +2 -2
- risk/stats/stat_tests.py +272 -0
- {risk_network-0.0.9b26.dist-info → risk_network-0.0.9b28.dist-info}/METADATA +1 -1
- risk_network-0.0.9b28.dist-info/RECORD +41 -0
- risk/stats/binom.py +0 -51
- risk/stats/chi2.py +0 -69
- risk/stats/hypergeom.py +0 -64
- risk/stats/poisson.py +0 -50
- risk/stats/zscore.py +0 -68
- risk_network-0.0.9b26.dist-info/RECORD +0 -44
- {risk_network-0.0.9b26.dist-info → risk_network-0.0.9b28.dist-info}/LICENSE +0 -0
- {risk_network-0.0.9b26.dist-info → risk_network-0.0.9b28.dist-info}/WHEEL +0 -0
- {risk_network-0.0.9b26.dist-info → risk_network-0.0.9b28.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple, Union
|
|
9
9
|
|
10
10
|
import networkx as nx
|
11
11
|
import numpy as np
|
12
|
+
from scipy.sparse import csr_matrix
|
12
13
|
from sklearn.exceptions import DataConversionWarning
|
13
14
|
from sklearn.metrics.pairwise import cosine_similarity
|
14
15
|
|
@@ -34,43 +35,43 @@ def get_network_neighborhoods(
|
|
34
35
|
louvain_resolution: float = 0.1,
|
35
36
|
leiden_resolution: float = 1.0,
|
36
37
|
random_seed: int = 888,
|
37
|
-
) ->
|
38
|
-
"""Calculate the combined neighborhoods for each node
|
38
|
+
) -> csr_matrix:
|
39
|
+
"""Calculate the combined neighborhoods for each node using sparse matrices.
|
39
40
|
|
40
41
|
Args:
|
41
42
|
network (nx.Graph): The network graph.
|
42
43
|
distance_metric (str, List, Tuple, or np.ndarray, optional): The distance metric(s) to use.
|
43
|
-
fraction_shortest_edges (float, List, Tuple, or np.ndarray, optional): Shortest edge rank fraction
|
44
|
+
fraction_shortest_edges (float, List, Tuple, or np.ndarray, optional): Shortest edge rank fraction thresholds.
|
44
45
|
louvain_resolution (float, optional): Resolution parameter for the Louvain method.
|
45
46
|
leiden_resolution (float, optional): Resolution parameter for the Leiden method.
|
46
47
|
random_seed (int, optional): Random seed for methods requiring random initialization.
|
47
48
|
|
48
49
|
Returns:
|
49
|
-
|
50
|
+
csr_matrix: The combined neighborhood matrix.
|
50
51
|
"""
|
51
52
|
# Set random seed for reproducibility
|
52
53
|
random.seed(random_seed)
|
53
54
|
np.random.seed(random_seed)
|
54
55
|
|
55
|
-
# Ensure distance_metric is a list
|
56
|
+
# Ensure distance_metric is a list for multi-algorithm handling
|
56
57
|
if isinstance(distance_metric, (str, np.ndarray)):
|
57
58
|
distance_metric = [distance_metric]
|
58
|
-
# Ensure fraction_shortest_edges is a list
|
59
|
+
# Ensure fraction_shortest_edges is a list for multi-threshold handling
|
59
60
|
if isinstance(fraction_shortest_edges, (float, int)):
|
60
61
|
fraction_shortest_edges = [fraction_shortest_edges] * len(distance_metric)
|
61
|
-
#
|
62
|
+
# Validate matching lengths of distance metrics and thresholds
|
62
63
|
if len(distance_metric) != len(fraction_shortest_edges):
|
63
64
|
raise ValueError(
|
64
65
|
"The number of distance metrics must match the number of edge length thresholds."
|
65
66
|
)
|
66
67
|
|
67
|
-
# Initialize
|
68
|
+
# Initialize a sparse LIL matrix for incremental updates
|
68
69
|
num_nodes = network.number_of_nodes()
|
69
|
-
|
70
|
-
|
70
|
+
# Initialize a sparse matrix with the same shape as the network
|
71
|
+
combined_neighborhoods = csr_matrix((num_nodes, num_nodes), dtype=np.uint8)
|
71
72
|
# Loop through each distance metric and corresponding edge rank fraction
|
72
73
|
for metric, percentile in zip(distance_metric, fraction_shortest_edges):
|
73
|
-
#
|
74
|
+
# Compute neighborhoods for the specified metric
|
74
75
|
if metric == "greedy_modularity":
|
75
76
|
neighborhoods = calculate_greedy_modularity_neighborhoods(
|
76
77
|
network, fraction_shortest_edges=percentile
|
@@ -107,22 +108,37 @@ def get_network_neighborhoods(
|
|
107
108
|
)
|
108
109
|
else:
|
109
110
|
raise ValueError(
|
110
|
-
"
|
111
|
+
"Invalid distance metric. Choose from: 'greedy_modularity', 'label_propagation',"
|
111
112
|
"'leiden', 'louvain', 'markov_clustering', 'spinglass', 'walktrap'."
|
112
113
|
)
|
113
114
|
|
114
|
-
#
|
115
|
+
# Add the sparse neighborhood matrix
|
115
116
|
combined_neighborhoods += neighborhoods
|
116
117
|
|
117
|
-
# Ensure
|
118
|
-
|
119
|
-
# while all other values are reset to 0. This transformation simplifies the neighborhood matrix by
|
120
|
-
# focusing on the most significant connection per row (or nodes).
|
121
|
-
combined_neighborhoods = _set_max_row_value_to_one(combined_neighborhoods)
|
118
|
+
# Ensure maximum value in each row is set to 1
|
119
|
+
combined_neighborhoods = _set_max_row_value_to_one_sparse(combined_neighborhoods)
|
122
120
|
|
123
121
|
return combined_neighborhoods
|
124
122
|
|
125
123
|
|
124
|
+
def _set_max_row_value_to_one_sparse(matrix: csr_matrix) -> csr_matrix:
|
125
|
+
"""Set the maximum value in each row of a sparse matrix to 1.
|
126
|
+
|
127
|
+
Args:
|
128
|
+
matrix (csr_matrix): The input sparse matrix.
|
129
|
+
|
130
|
+
Returns:
|
131
|
+
csr_matrix: The modified sparse matrix where only the maximum value in each row is set to 1.
|
132
|
+
"""
|
133
|
+
# Iterate over each row and set the maximum value to 1
|
134
|
+
for i in range(matrix.shape[0]):
|
135
|
+
row_data = matrix[i].data
|
136
|
+
if len(row_data) > 0:
|
137
|
+
row_data[:] = (row_data == max(row_data)).astype(int)
|
138
|
+
|
139
|
+
return matrix
|
140
|
+
|
141
|
+
|
126
142
|
def _set_max_row_value_to_one(matrix: np.ndarray) -> np.ndarray:
|
127
143
|
"""For each row in the input matrix, set the maximum value(s) to 1 and all other values to 0. This is particularly
|
128
144
|
useful for neighborhood matrices that have undergone multiple neighborhood detection algorithms, where the
|
risk/network/geometry.py
CHANGED
@@ -3,8 +3,6 @@ risk/network/geometry
|
|
3
3
|
~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
|
-
import copy
|
7
|
-
|
8
6
|
import networkx as nx
|
9
7
|
import numpy as np
|
10
8
|
|
@@ -31,44 +29,43 @@ def assign_edge_lengths(
|
|
31
29
|
"""Compute distances between pairs of coordinates."""
|
32
30
|
u_coords, v_coords = coords[:, 0, :], coords[:, 1, :]
|
33
31
|
if is_sphere:
|
34
|
-
|
35
|
-
|
36
|
-
u_coords /= u_norm
|
37
|
-
v_coords /= v_norm
|
32
|
+
u_coords /= np.linalg.norm(u_coords, axis=1, keepdims=True)
|
33
|
+
v_coords /= np.linalg.norm(v_coords, axis=1, keepdims=True)
|
38
34
|
dot_products = np.einsum("ij,ij->i", u_coords, v_coords)
|
39
35
|
return np.arccos(np.clip(dot_products, -1.0, 1.0))
|
40
|
-
|
41
36
|
return np.linalg.norm(u_coords - v_coords, axis=1)
|
42
37
|
|
43
38
|
# Normalize graph coordinates and weights
|
44
39
|
_normalize_graph_coordinates(G)
|
45
40
|
_normalize_weights(G)
|
41
|
+
|
46
42
|
# Map nodes to sphere and adjust depth if required
|
47
43
|
if compute_sphere:
|
48
44
|
_map_to_sphere(G)
|
49
|
-
G_depth = _create_depth(
|
45
|
+
G_depth = _create_depth(G, surface_depth=surface_depth)
|
50
46
|
else:
|
51
|
-
G_depth =
|
52
|
-
|
53
|
-
# Precompute edge coordinate arrays
|
54
|
-
edge_data =
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
47
|
+
G_depth = G
|
48
|
+
|
49
|
+
# Precompute edge coordinate arrays and compute distances in bulk
|
50
|
+
edge_data = np.array(
|
51
|
+
[
|
52
|
+
[
|
53
|
+
np.array(
|
54
|
+
[G_depth.nodes[u]["x"], G_depth.nodes[u]["y"], G_depth.nodes[u].get("z", 0)]
|
55
|
+
),
|
56
|
+
np.array(
|
57
|
+
[G_depth.nodes[v]["x"], G_depth.nodes[v]["y"], G_depth.nodes[v].get("z", 0)]
|
58
|
+
),
|
59
|
+
]
|
60
|
+
for u, v in G_depth.edges
|
61
|
+
]
|
62
|
+
)
|
63
|
+
# Compute distances
|
64
|
+
distances = compute_distance_vectorized(edge_data, compute_sphere)
|
68
65
|
# Assign distances back to the graph
|
69
|
-
for (u, v), distance in zip(
|
66
|
+
for (u, v), distance in zip(G_depth.edges, distances):
|
70
67
|
if include_edge_weight:
|
71
|
-
weight = G.edges[u, v].get("normalized_weight",
|
68
|
+
weight = G.edges[u, v].get("normalized_weight", 1e-6) # Avoid divide-by-zero
|
72
69
|
G.edges[u, v]["length"] = distance / np.sqrt(weight)
|
73
70
|
else:
|
74
71
|
G.edges[u, v]["length"] = distance
|
risk/network/graph/api.py
CHANGED
@@ -16,7 +16,7 @@ from risk.neighborhoods import (
|
|
16
16
|
process_neighborhoods,
|
17
17
|
trim_domains,
|
18
18
|
)
|
19
|
-
from risk.network.graph.
|
19
|
+
from risk.network.graph.graph import Graph
|
20
20
|
from risk.stats import calculate_significance_matrices
|
21
21
|
|
22
22
|
|
@@ -44,7 +44,7 @@ class GraphAPI:
|
|
44
44
|
linkage_metric: str = "yule",
|
45
45
|
min_cluster_size: int = 5,
|
46
46
|
max_cluster_size: int = 1000,
|
47
|
-
) ->
|
47
|
+
) -> Graph:
|
48
48
|
"""Load and process the network graph, defining top annotations and domains.
|
49
49
|
|
50
50
|
Args:
|
@@ -63,7 +63,7 @@ class GraphAPI:
|
|
63
63
|
max_cluster_size (int, optional): Maximum size for clusters. Defaults to 1000.
|
64
64
|
|
65
65
|
Returns:
|
66
|
-
|
66
|
+
Graph: A fully initialized and processed Graph object.
|
67
67
|
"""
|
68
68
|
# Log the parameters and display headers
|
69
69
|
log_header("Finding significant neighborhoods")
|
@@ -139,13 +139,13 @@ class GraphAPI:
|
|
139
139
|
max_cluster_size=max_cluster_size,
|
140
140
|
)
|
141
141
|
|
142
|
-
# Prepare node mapping and significance sums for the final
|
142
|
+
# Prepare node mapping and significance sums for the final Graph object
|
143
143
|
ordered_nodes = annotations["ordered_nodes"]
|
144
144
|
node_label_to_id = dict(zip(ordered_nodes, range(len(ordered_nodes))))
|
145
145
|
node_significance_sums = processed_neighborhoods["node_significance_sums"]
|
146
146
|
|
147
|
-
# Return the fully initialized
|
148
|
-
return
|
147
|
+
# Return the fully initialized Graph object
|
148
|
+
return Graph(
|
149
149
|
network=network,
|
150
150
|
annotations=annotations,
|
151
151
|
neighborhoods=neighborhoods,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/graph/
|
3
|
-
|
2
|
+
risk/network/graph/graph
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
6
|
from collections import defaultdict
|
@@ -10,13 +10,13 @@ import networkx as nx
|
|
10
10
|
import numpy as np
|
11
11
|
import pandas as pd
|
12
12
|
|
13
|
-
from risk.network.graph.summary import
|
13
|
+
from risk.network.graph.summary import Summary
|
14
14
|
|
15
15
|
|
16
|
-
class
|
16
|
+
class Graph:
|
17
17
|
"""A class to represent a network graph and process its nodes and edges.
|
18
18
|
|
19
|
-
The
|
19
|
+
The Graph class provides functionality to handle and manipulate a network graph,
|
20
20
|
including managing domains, annotations, and node significance data. It also includes methods
|
21
21
|
for transforming and mapping graph coordinates, as well as generating colors based on node
|
22
22
|
significance.
|
@@ -32,7 +32,7 @@ class NetworkGraph:
|
|
32
32
|
node_label_to_node_id_map: Dict[str, Any],
|
33
33
|
node_significance_sums: np.ndarray,
|
34
34
|
):
|
35
|
-
"""Initialize the
|
35
|
+
"""Initialize the Graph object.
|
36
36
|
|
37
37
|
Args:
|
38
38
|
network (nx.Graph): The network graph.
|
@@ -69,7 +69,7 @@ class NetworkGraph:
|
|
69
69
|
self.node_coordinates = _extract_node_coordinates(self.network)
|
70
70
|
|
71
71
|
# NOTE: Only after the above attributes are initialized, we can create the summary
|
72
|
-
self.summary =
|
72
|
+
self.summary = Summary(annotations, neighborhoods, self)
|
73
73
|
|
74
74
|
def pop(self, domain_id: str) -> None:
|
75
75
|
"""Remove domain ID from instance domain ID mappings. This can be useful for cleaning up
|
risk/network/graph/summary.py
CHANGED
@@ -12,7 +12,7 @@ from statsmodels.stats.multitest import fdrcorrection
|
|
12
12
|
from risk.log.console import logger, log_header
|
13
13
|
|
14
14
|
|
15
|
-
class
|
15
|
+
class Summary:
|
16
16
|
"""Handles the processing, storage, and export of network analysis results.
|
17
17
|
|
18
18
|
The Results class provides methods to process significance and depletion data, compute
|
@@ -25,14 +25,14 @@ class AnalysisSummary:
|
|
25
25
|
self,
|
26
26
|
annotations: Dict[str, Any],
|
27
27
|
neighborhoods: Dict[str, Any],
|
28
|
-
graph, # Avoid type hinting
|
28
|
+
graph, # Avoid type hinting Graph to prevent circular imports
|
29
29
|
):
|
30
30
|
"""Initialize the Results object with analysis components.
|
31
31
|
|
32
32
|
Args:
|
33
33
|
annotations (Dict[str, Any]): Annotation data, including ordered annotations and matrix of associations.
|
34
34
|
neighborhoods (Dict[str, Any]): Neighborhood data containing p-values for significance and depletion analysis.
|
35
|
-
graph (
|
35
|
+
graph (Graph): Graph object representing domain-to-node and node-to-label mappings.
|
36
36
|
"""
|
37
37
|
self.annotations = annotations
|
38
38
|
self.neighborhoods = neighborhoods
|
risk/network/io.py
CHANGED
@@ -217,6 +217,9 @@ class NetworkIO:
|
|
217
217
|
|
218
218
|
Returns:
|
219
219
|
nx.Graph: Loaded and processed network.
|
220
|
+
|
221
|
+
Raises:
|
222
|
+
ValueError: If no matching attribute metadata file is found.
|
220
223
|
"""
|
221
224
|
filetype = "Cytoscape"
|
222
225
|
# Log the loading of the Cytoscape file
|
@@ -258,13 +261,29 @@ class NetworkIO:
|
|
258
261
|
|
259
262
|
# Read the node attributes (from /tables/)
|
260
263
|
attribute_metadata_keywords = ["/tables/", "SHARED_ATTRS", "edge.cytable"]
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
264
|
+
# Use a generator to find the first matching file
|
265
|
+
attribute_metadata = next(
|
266
|
+
(
|
267
|
+
os.path.join(tmp_dir, cf)
|
268
|
+
for cf in cys_files
|
269
|
+
if all(keyword in cf for keyword in attribute_metadata_keywords)
|
270
|
+
),
|
271
|
+
None, # Default if no file matches
|
272
|
+
)
|
273
|
+
if attribute_metadata:
|
274
|
+
# Optimize `read_csv` by leveraging proper options
|
275
|
+
attribute_table = pd.read_csv(
|
276
|
+
attribute_metadata,
|
277
|
+
sep=",",
|
278
|
+
header=None,
|
279
|
+
skiprows=1,
|
280
|
+
dtype=str, # Use specific dtypes to reduce memory usage
|
281
|
+
engine="c", # Use the C engine for parsing if compatible
|
282
|
+
low_memory=False, # Optimize memory handling for large files
|
283
|
+
)
|
284
|
+
else:
|
285
|
+
raise ValueError("No matching attribute metadata file found.")
|
286
|
+
|
268
287
|
# Set columns
|
269
288
|
attribute_table.columns = attribute_table.iloc[0]
|
270
289
|
# Skip first four rows
|
@@ -464,14 +483,19 @@ class NetworkIO:
|
|
464
483
|
Args:
|
465
484
|
G (nx.Graph): A NetworkX graph object.
|
466
485
|
"""
|
467
|
-
|
468
|
-
|
469
|
-
nx.set_edge_attributes(G,
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
486
|
+
# Set default weight for all edges in bulk
|
487
|
+
default_weight = 1.0
|
488
|
+
nx.set_edge_attributes(G, default_weight, "weight")
|
489
|
+
# Check and assign user-defined edge weights if available
|
490
|
+
weight_attributes = nx.get_edge_attributes(G, self.weight_label)
|
491
|
+
if weight_attributes:
|
492
|
+
nx.set_edge_attributes(G, weight_attributes, "weight")
|
493
|
+
|
494
|
+
# Log missing weights if include_edge_weight is enabled
|
495
|
+
if self.include_edge_weight:
|
496
|
+
missing_weights = len(G.edges) - len(weight_attributes)
|
497
|
+
if missing_weights > 0:
|
498
|
+
logger.debug(f"Total edges missing weights: {missing_weights}")
|
475
499
|
|
476
500
|
def _validate_nodes(self, G: nx.Graph) -> None:
|
477
501
|
"""Validate the graph structure and attributes with attribute fallback for positions and labels.
|
risk/network/plotter/__init__.py
CHANGED
risk/network/plotter/api.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/api
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
6
|
from typing import List, Tuple, Union
|
@@ -8,14 +8,14 @@ from typing import List, Tuple, Union
|
|
8
8
|
import numpy as np
|
9
9
|
|
10
10
|
from risk.log import log_header
|
11
|
-
from risk.network.graph.
|
12
|
-
from risk.network.plotter.
|
11
|
+
from risk.network.graph.graph import Graph
|
12
|
+
from risk.network.plotter.plotter import Plotter
|
13
13
|
|
14
14
|
|
15
15
|
class PlotterAPI:
|
16
16
|
"""Handles the loading of network plotter objects.
|
17
17
|
|
18
|
-
The PlotterAPI class provides methods to load and configure
|
18
|
+
The PlotterAPI class provides methods to load and configure Plotter objects for plotting network graphs.
|
19
19
|
"""
|
20
20
|
|
21
21
|
def __init__() -> None:
|
@@ -23,16 +23,16 @@ class PlotterAPI:
|
|
23
23
|
|
24
24
|
def load_plotter(
|
25
25
|
self,
|
26
|
-
graph:
|
26
|
+
graph: Graph,
|
27
27
|
figsize: Union[List, Tuple, np.ndarray] = (10, 10),
|
28
28
|
background_color: str = "white",
|
29
29
|
background_alpha: Union[float, None] = 1.0,
|
30
30
|
pad: float = 0.3,
|
31
|
-
) ->
|
32
|
-
"""Get a
|
31
|
+
) -> Plotter:
|
32
|
+
"""Get a Plotter object for plotting.
|
33
33
|
|
34
34
|
Args:
|
35
|
-
graph (
|
35
|
+
graph (Graph): The graph to plot.
|
36
36
|
figsize (List, Tuple, or np.ndarray, optional): Size of the plot. Defaults to (10, 10)., optional): Size of the figure. Defaults to (10, 10).
|
37
37
|
background_color (str, optional): Background color of the plot. Defaults to "white".
|
38
38
|
background_alpha (float, None, optional): Transparency level of the background color. If provided, it overrides
|
@@ -40,12 +40,12 @@ class PlotterAPI:
|
|
40
40
|
pad (float, optional): Padding value to adjust the axis limits. Defaults to 0.3.
|
41
41
|
|
42
42
|
Returns:
|
43
|
-
|
43
|
+
Plotter: A Plotter object configured with the given parameters.
|
44
44
|
"""
|
45
45
|
log_header("Loading plotter")
|
46
46
|
|
47
|
-
# Initialize and return a
|
48
|
-
return
|
47
|
+
# Initialize and return a Plotter object
|
48
|
+
return Plotter(
|
49
49
|
graph,
|
50
50
|
figsize=figsize,
|
51
51
|
background_color=background_color,
|
risk/network/plotter/canvas.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/canvas
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
6
|
from typing import List, Tuple, Union
|
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
|
|
9
9
|
import numpy as np
|
10
10
|
|
11
11
|
from risk.log import params
|
12
|
-
from risk.network.graph.
|
12
|
+
from risk.network.graph.graph import Graph
|
13
13
|
from risk.network.plotter.utils.colors import to_rgba
|
14
14
|
from risk.network.plotter.utils.layout import calculate_bounding_box
|
15
15
|
|
@@ -17,11 +17,11 @@ from risk.network.plotter.utils.layout import calculate_bounding_box
|
|
17
17
|
class Canvas:
|
18
18
|
"""A class for laying out the canvas in a network graph."""
|
19
19
|
|
20
|
-
def __init__(self, graph:
|
21
|
-
"""Initialize the Canvas with a
|
20
|
+
def __init__(self, graph: Graph, ax: plt.Axes) -> None:
|
21
|
+
"""Initialize the Canvas with a Graph and axis for plotting.
|
22
22
|
|
23
23
|
Args:
|
24
|
-
graph (
|
24
|
+
graph (Graph): The Graph object containing the network data.
|
25
25
|
ax (plt.Axes): The axis to plot the canvas on.
|
26
26
|
"""
|
27
27
|
self.graph = graph
|
@@ -236,7 +236,7 @@ class Canvas:
|
|
236
236
|
# Scale the node coordinates if needed
|
237
237
|
scaled_coordinates = node_coordinates * scale
|
238
238
|
# Use the existing _draw_kde_contour method
|
239
|
-
# NOTE: This is a technical debt that should be refactored in the future - only works when inherited by
|
239
|
+
# NOTE: This is a technical debt that should be refactored in the future - only works when inherited by Plotter
|
240
240
|
self._draw_kde_contour(
|
241
241
|
ax=self.ax,
|
242
242
|
pos=scaled_coordinates,
|
risk/network/plotter/contour.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/contour
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
6
|
from typing import Any, Dict, List, Tuple, Union
|
@@ -12,18 +12,18 @@ from scipy.ndimage import label
|
|
12
12
|
from scipy.stats import gaussian_kde
|
13
13
|
|
14
14
|
from risk.log import params, logger
|
15
|
-
from risk.network.graph.
|
15
|
+
from risk.network.graph.graph import Graph
|
16
16
|
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
17
17
|
|
18
18
|
|
19
19
|
class Contour:
|
20
20
|
"""Class to generate Kernel Density Estimate (KDE) contours for nodes in a network graph."""
|
21
21
|
|
22
|
-
def __init__(self, graph:
|
23
|
-
"""Initialize the Contour with a
|
22
|
+
def __init__(self, graph: Graph, ax: plt.Axes) -> None:
|
23
|
+
"""Initialize the Contour with a Graph and axis for plotting.
|
24
24
|
|
25
25
|
Args:
|
26
|
-
graph (
|
26
|
+
graph (Graph): The Graph object containing the network data.
|
27
27
|
ax (plt.Axes): The axis to plot the contours on.
|
28
28
|
"""
|
29
29
|
self.graph = graph
|
risk/network/plotter/labels.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""
|
2
|
-
risk/network/
|
3
|
-
|
2
|
+
risk/network/plotter/labels
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
4
4
|
"""
|
5
5
|
|
6
6
|
import copy
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
11
11
|
import pandas as pd
|
12
12
|
|
13
13
|
from risk.log import params
|
14
|
-
from risk.network.graph.
|
14
|
+
from risk.network.graph.graph import Graph
|
15
15
|
from risk.network.plotter.utils.colors import get_annotated_domain_colors, to_rgba
|
16
16
|
from risk.network.plotter.utils.layout import calculate_bounding_box
|
17
17
|
|
@@ -21,11 +21,11 @@ TERM_DELIMITER = "::::" # String used to separate multiple domain terms when co
|
|
21
21
|
class Labels:
|
22
22
|
"""Class to handle the annotation of network graphs with labels for different domains."""
|
23
23
|
|
24
|
-
def __init__(self, graph:
|
24
|
+
def __init__(self, graph: Graph, ax: plt.Axes):
|
25
25
|
"""Initialize the Labeler object with a network graph and matplotlib axes.
|
26
26
|
|
27
27
|
Args:
|
28
|
-
graph (
|
28
|
+
graph (Graph): Graph object containing the network data.
|
29
29
|
ax (plt.Axes): Matplotlib axes object to plot the labels on.
|
30
30
|
"""
|
31
31
|
self.graph = graph
|