risk-network 0.0.2.tar.gz → 0.0.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {risk_network-0.0.2 → risk_network-0.0.3}/MANIFEST.in +1 -2
  2. {risk_network-0.0.2 → risk_network-0.0.3}/PKG-INFO +15 -9
  3. risk_network-0.0.3/README.md +40 -0
  4. {risk_network-0.0.2 → risk_network-0.0.3}/pyproject.toml +2 -2
  5. {risk_network-0.0.2 → risk_network-0.0.3}/risk/__init__.py +1 -1
  6. {risk_network-0.0.2 → risk_network-0.0.3}/risk/annotations/annotations.py +9 -9
  7. {risk_network-0.0.2 → risk_network-0.0.3}/risk/annotations/io.py +62 -49
  8. {risk_network-0.0.2 → risk_network-0.0.3}/risk/log/params.py +5 -8
  9. risk_network-0.0.2/risk/neighborhoods/graph.py → risk_network-0.0.3/risk/neighborhoods/community.py +2 -2
  10. {risk_network-0.0.2 → risk_network-0.0.3}/risk/neighborhoods/neighborhoods.py +1 -1
  11. {risk_network-0.0.2 → risk_network-0.0.3}/risk/network/io.py +2 -9
  12. {risk_network-0.0.2 → risk_network-0.0.3}/risk/network/plot.py +38 -47
  13. {risk_network-0.0.2 → risk_network-0.0.3}/risk/risk.py +19 -29
  14. {risk_network-0.0.2/risk/stats/permutation/_python → risk_network-0.0.3/risk/stats}/permutation.py +25 -20
  15. {risk_network-0.0.2 → risk_network-0.0.3}/risk/stats/stats.py +76 -146
  16. {risk_network-0.0.2 → risk_network-0.0.3}/risk_network.egg-info/PKG-INFO +15 -9
  17. {risk_network-0.0.2 → risk_network-0.0.3}/risk_network.egg-info/SOURCES.txt +2 -5
  18. {risk_network-0.0.2 → risk_network-0.0.3}/risk_network.egg-info/requires.txt +1 -1
  19. {risk_network-0.0.2 → risk_network-0.0.3}/setup.py +3 -15
  20. risk_network-0.0.2/README.md +0 -34
  21. risk_network-0.0.2/risk/stats/permutation/__init__.py +0 -15
  22. risk_network-0.0.2/risk/stats/permutation/_cython/permutation.pyx +0 -82
  23. risk_network-0.0.2/risk/stats/permutation/_cython/setup.py +0 -11
  24. {risk_network-0.0.2 → risk_network-0.0.3}/LICENSE +0 -0
  25. {risk_network-0.0.2 → risk_network-0.0.3}/risk/annotations/__init__.py +0 -0
  26. {risk_network-0.0.2 → risk_network-0.0.3}/risk/constants.py +0 -0
  27. {risk_network-0.0.2 → risk_network-0.0.3}/risk/log/__init__.py +0 -0
  28. {risk_network-0.0.2 → risk_network-0.0.3}/risk/log/console.py +0 -0
  29. {risk_network-0.0.2 → risk_network-0.0.3}/risk/neighborhoods/__init__.py +0 -0
  30. {risk_network-0.0.2 → risk_network-0.0.3}/risk/neighborhoods/domains.py +0 -0
  31. {risk_network-0.0.2 → risk_network-0.0.3}/risk/network/__init__.py +0 -0
  32. {risk_network-0.0.2 → risk_network-0.0.3}/risk/network/geometry.py +0 -0
  33. {risk_network-0.0.2 → risk_network-0.0.3}/risk/network/graph.py +0 -0
  34. {risk_network-0.0.2 → risk_network-0.0.3}/risk/stats/__init__.py +0 -0
  35. {risk_network-0.0.2 → risk_network-0.0.3}/risk_network.egg-info/dependency_links.txt +0 -0
  36. {risk_network-0.0.2 → risk_network-0.0.3}/risk_network.egg-info/top_level.txt +0 -0
  37. {risk_network-0.0.2 → risk_network-0.0.3}/setup.cfg +0 -0
{risk_network-0.0.2 → risk_network-0.0.3}/MANIFEST.in
@@ -1,6 +1,5 @@
- # Include all Python and Cython source files
+ # Include all Python source files
  recursive-include risk *.py
- recursive-include risk *.pyx

  # Include important project files in the distribution
  include README.md
{risk_network-0.0.2 → risk_network-0.0.3}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: risk-network
- Version: 0.0.2
+ Version: 0.0.3
  Summary: A Python package for biological network analysis
  Author: Ira Horecka
  Author-email: Ira Horecka <ira89@icloud.com>
@@ -694,13 +694,13 @@ Classifier: Development Status :: 4 - Beta
  Requires-Python: >=3.7
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: cython
  Requires-Dist: ipywidgets
  Requires-Dist: markov_clustering
  Requires-Dist: matplotlib
  Requires-Dist: networkx
  Requires-Dist: nltk==3.8.1
  Requires-Dist: numpy
+ Requires-Dist: openpyxl
  Requires-Dist: pandas
  Requires-Dist: python-louvain
  Requires-Dist: scikit-learn
@@ -709,15 +709,21 @@ Requires-Dist: statsmodels
  Requires-Dist: threadpoolctl
  Requires-Dist: tqdm

- # RISK
-
- <ins>Regional Inference of Significant Kinships</ins>
+ <p align="center">
+ <img src="./docs/github/risk-logo-dark.png#gh-dark-mode-only" width="400" />
+ <img src="./docs/github/risk-logo-light.png#gh-light-mode-only" width="400" />
+ </p>

- <p align="left">
- <img src="./docs/github/risk-logo-dark.png#gh-dark-mode-only" width="40%" />
- <img src="./docs/github/risk-logo-light.png#gh-light-mode-only" width="40%" />
+ <p align="center">
+ <a href="https://pypi.python.org/pypi/risk-network"><img src="https://img.shields.io/pypi/v/risk-network.svg" alt="pypiv"></a>
+ <a href="https://www.python.org/downloads/"><img src="https://img.shields.io/badge/python-3.7+-blue.svg" alt="Python 3.7+"></a>
+ <a href="https://raw.githubusercontent.com/irahorecka/chrono24/main/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
  </p>

+ ## RISK
+
+ #### Regional Inference of Significant Kinships
+
  RISK is a software tool for visualizing spatial relationships in networks. It aims to enhance network analysis by integrating advanced network annotation algorithms, such as Louvain and Markov Clustering, to identify key functional modules and pathways.

  ## Features
@@ -730,7 +736,7 @@ RISK is a software tool for visualizing spatial relationships in networks. It ai

  *Saccharomyces cerevisiae* proteins oriented by physical interactions discovered through affinity enrichment and mass spectrometry (Michaelis et al., 2023).

- ![Metabolic Network Demo](./docs/github/network.png)
+ ![PPI Network Demo](./docs/github/network.png)

  ## Installation

risk_network-0.0.3/README.md
@@ -0,0 +1,40 @@
+ <p align="center">
+ <img src="./docs/github/risk-logo-dark.png#gh-dark-mode-only" width="400" />
+ <img src="./docs/github/risk-logo-light.png#gh-light-mode-only" width="400" />
+ </p>
+
+ <p align="center">
+ <a href="https://pypi.python.org/pypi/risk-network"><img src="https://img.shields.io/pypi/v/risk-network.svg" alt="pypiv"></a>
+ <a href="https://www.python.org/downloads/"><img src="https://img.shields.io/badge/python-3.7+-blue.svg" alt="Python 3.7+"></a>
+ <a href="https://raw.githubusercontent.com/irahorecka/chrono24/main/LICENSE"><img src="https://img.shields.io/badge/License-GPLv3-blue.svg" alt="License: GPL v3"></a>
+ </p>
+
+ ## RISK
+
+ #### Regional Inference of Significant Kinships
+
+ RISK is a software tool for visualizing spatial relationships in networks. It aims to enhance network analysis by integrating advanced network annotation algorithms, such as Louvain and Markov Clustering, to identify key functional modules and pathways.
+
+ ## Features
+
+ - Spatial analysis of biological networks
+ - Functional enrichment detection
+ - Optimized performance
+
+ ## Example
+
+ *Saccharomyces cerevisiae* proteins oriented by physical interactions discovered through affinity enrichment and mass spectrometry (Michaelis et al., 2023).
+
+ ![PPI Network Demo](./docs/github/network.png)
+
+ ## Installation
+
+ Coming soon...
+
+ ## Usage
+
+ Coming soon...
+
+ ## License
+
+ This project is licensed under the GPL-3.0 license.
{risk_network-0.0.2 → risk_network-0.0.3}/pyproject.toml
@@ -1,5 +1,5 @@
  [build-system]
- requires = ["setuptools", "wheel", "Cython", "numpy"]
+ requires = ["setuptools", "wheel", "numpy"]
  build-backend = "setuptools.build_meta"

  [project]
@@ -26,13 +26,13 @@ classifiers = [
  "Development Status :: 4 - Beta",
  ]
  dependencies = [
- "cython",
  "ipywidgets",
  "markov_clustering",
  "matplotlib",
  "networkx",
  "nltk==3.8.1",
  "numpy",
+ "openpyxl",
  "pandas",
  "python-louvain",
  "scikit-learn",
{risk_network-0.0.2 → risk_network-0.0.3}/risk/__init__.py
@@ -10,4 +10,4 @@ RISK: RISK Infers Spatial Kinship

  from risk.risk import RISK

- __version__ = "0.0.2"
+ __version__ = "0.0.3"
{risk_network-0.0.2 → risk_network-0.0.3}/risk/annotations/annotations.py
@@ -139,15 +139,15 @@ def define_top_annotations(
  size_connected_components <= max_cluster_size,
  )
  )
- annotations_enrichment_matrix.loc[
- attribute, "num connected components"
- ] = num_connected_components
- annotations_enrichment_matrix.at[
- attribute, "size connected components"
- ] = size_connected_components
- annotations_enrichment_matrix.loc[
- attribute, "num large connected components"
- ] = num_large_connected_components
+ annotations_enrichment_matrix.loc[attribute, "num connected components"] = (
+ num_connected_components
+ )
+ annotations_enrichment_matrix.at[attribute, "size connected components"] = (
+ size_connected_components
+ )
+ annotations_enrichment_matrix.loc[attribute, "num large connected components"] = (
+ num_large_connected_components
+ )

  # Filter out attributes with more than one connected component
  annotations_enrichment_matrix.loc[
{risk_network-0.0.2 → risk_network-0.0.3}/risk/annotations/io.py
@@ -45,66 +45,70 @@ class AnnotationsIO:
  # Process the JSON data and return it in the context of the network
  return load_annotations(network, annotations_input)

- def load_csv_annotation(
+ def load_excel_annotation(
  self,
  filepath: str,
  network: nx.Graph,
  label_colname: str = "label",
  nodes_colname: str = "nodes",
- delimiter: str = ";",
+ sheet_name: str = "Sheet1",
+ nodes_delimiter: str = ";",
  ) -> Dict[str, Any]:
- """Load annotations from a CSV file and convert them to a DataFrame.
+ """Load annotations from an Excel file and associate them with the network.

  Args:
- filepath (str): Path to the CSV annotations file.
- network (NetworkX graph): The network to which the annotations are related.
- label_colname (str): Name of the column containing the labels.
- nodes_colname (str): Name of the column containing the nodes.
- delimiter (str): Delimiter used to parse the nodes column (default is ';').
+ filepath (str): Path to the Excel annotations file.
+ network (nx.Graph): The NetworkX graph to which the annotations are related.
+ label_colname (str): Name of the column containing the labels (e.g., GO terms).
+ nodes_colname (str): Name of the column containing the nodes associated with each label.
+ sheet_name (str, optional): The name of the Excel sheet to load (default is 'Sheet1').
+ nodes_delimiter (str, optional): Delimiter used to separate multiple nodes within the nodes column (default is ';').

  Returns:
- pd.DataFrame: DataFrame containing the labels and parsed nodes.
+ Dict[str, Any]: A dictionary where each label is paired with its respective list of nodes,
+ linked to the provided network.
  """
- filetype = "CSV"
+ filetype = "Excel"
  params.log_annotations(filepath=filepath, filetype=filetype)
  _log_loading(filetype, filepath=filepath)
- # Load the CSV file into a dictionary
- annotations_input = _load_matrix_file(filepath, label_colname, nodes_colname, delimiter)
- # Process and return the annotations in the context of the network
- return load_annotations(network, annotations_input)
+ # Load the specified sheet from the Excel file
+ df = pd.read_excel(filepath, sheet_name=sheet_name)
+ # Split the nodes column by the specified nodes_delimiter
+ df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(nodes_delimiter))
+ # Convert the DataFrame to a dictionary pairing labels with their corresponding nodes
+ label_node_dict = df.set_index(label_colname)[nodes_colname].to_dict()
+ return load_annotations(network, label_node_dict)

- def load_excel_annotation(
+ def load_csv_annotation(
  self,
  filepath: str,
  network: nx.Graph,
  label_colname: str = "label",
  nodes_colname: str = "nodes",
- sheet_name: str = "Sheet1",
- delimiter: str = ";",
+ nodes_delimiter: str = ";",
  ) -> Dict[str, Any]:
- """Load annotations from an Excel file and convert them to a dictionary.
+ """Load annotations from a CSV file and associate them with the network.

  Args:
- filepath (str): Path to the Excel annotations file.
- network (NetworkX graph): The network to which the annotations are related.
- label_colname (str): Name of the column containing the labels.
- nodes_colname (str): Name of the column containing the nodes.
- sheet_name (str): The name of the Excel sheet to load (default is 'Sheet1').
- delimiter (str): Delimiter used to parse the nodes column (default is ';').
+ filepath (str): Path to the CSV annotations file.
+ network (nx.Graph): The NetworkX graph to which the annotations are related.
+ label_colname (str): Name of the column containing the labels (e.g., GO terms).
+ nodes_colname (str): Name of the column containing the nodes associated with each label.
+ nodes_delimiter (str, optional): Delimiter used to separate multiple nodes within the nodes column (default is ';').

  Returns:
- dict: A dictionary where each label is paired with its respective list of nodes.
+ Dict[str, Any]: A dictionary where each label is paired with its respective list of nodes,
+ linked to the provided network.
  """
- filetype = "Excel"
+ filetype = "CSV"
  params.log_annotations(filepath=filepath, filetype=filetype)
  _log_loading(filetype, filepath=filepath)
- # Load the specified sheet from the Excel file
- df = pd.read_excel(filepath, sheet_name=sheet_name)
- # Split the nodes column by the specified delimiter
- df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(delimiter))
- # Convert the DataFrame to a dictionary pairing labels with their corresponding nodes
- label_node_dict = df.set_index(label_colname)[nodes_colname].to_dict()
- return load_annotations(network, label_node_dict)
+ # Load the CSV file into a dictionary
+ annotations_input = _load_matrix_file(
+ filepath, label_colname, nodes_colname, delimiter=",", nodes_delimiter=nodes_delimiter
+ )
+ # Process and return the annotations in the context of the network
+ return load_annotations(network, annotations_input)

  def load_tsv_annotation(
  self,
@@ -112,47 +116,56 @@ class AnnotationsIO:
  network: nx.Graph,
  label_colname: str = "label",
  nodes_colname: str = "nodes",
+ nodes_delimiter: str = ";",
  ) -> Dict[str, Any]:
- """Load annotations from a TSV file and convert them to a DataFrame.
+ """Load annotations from a TSV file and associate them with the network.

  Args:
  filepath (str): Path to the TSV annotations file.
- network (NetworkX graph): The network to which the annotations are related.
- label_colname (str): Name of the column containing the labels.
- nodes_colname (str): Name of the column containing the nodes.
+ network (nx.Graph): The NetworkX graph to which the annotations are related.
+ label_colname (str): Name of the column containing the labels (e.g., GO terms).
+ nodes_colname (str): Name of the column containing the nodes associated with each label.
+ nodes_delimiter (str, optional): Delimiter used to separate multiple nodes within the nodes column (default is ';').

  Returns:
- pd.DataFrame: DataFrame containing the labels and parsed nodes.
+ Dict[str, Any]: A dictionary where each label is paired with its respective list of nodes,
+ linked to the provided network.
  """
  filetype = "TSV"
  params.log_annotations(filepath=filepath, filetype=filetype)
  _log_loading(filetype, filepath=filepath)
- # Load the TSV file with tab delimiter and convert to dictionary
+ # Load the TSV file into a dictionary
  annotations_input = _load_matrix_file(
- filepath, label_colname, nodes_colname, delimiter="\t"
+ filepath, label_colname, nodes_colname, delimiter="\t", nodes_delimiter=nodes_delimiter
  )
  # Process and return the annotations in the context of the network
  return load_annotations(network, annotations_input)


  def _load_matrix_file(
- filepath: str, label_colname: str, nodes_colname: str, delimiter: str = ";"
+ filepath: str,
+ label_colname: str,
+ nodes_colname: str,
+ delimiter: str = ",",
+ nodes_delimiter: str = ";",
  ) -> Dict[str, Any]:
  """Load annotations from a CSV or TSV file and convert them to a dictionary.

  Args:
  filepath (str): Path to the annotation file.
- label_colname (str): Name of the column containing the labels.
- nodes_colname (str): Name of the column containing the nodes.
- delimiter (str): Delimiter used to parse the nodes column (default is ';').
+ label_colname (str): Name of the column containing the labels (e.g., GO terms).
+ nodes_colname (str): Name of the column containing the nodes associated with each label.
+ delimiter (str, optional): Delimiter used to separate columns in the file (default is ',').
+ nodes_delimiter (str, optional): Delimiter used to separate multiple nodes within the nodes column (default is ';').

  Returns:
- dict: A dictionary where each label is paired with its respective list of nodes.
+ Dict[str, Any]: A dictionary where each label is paired with its respective list of nodes.
  """
- df = pd.read_csv(filepath)
- # Split the nodes column by the delimiter
- df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(delimiter))
- # Create a dictionary pairing labels with their corresponding nodes
+ # Load the CSV or TSV file into a DataFrame
+ df = pd.read_csv(filepath, delimiter=delimiter)
+ # Split the nodes column by the nodes_delimiter to handle multiple nodes per label
+ df[nodes_colname] = df[nodes_colname].apply(lambda x: x.split(nodes_delimiter))
+ # Create a dictionary pairing labels with their corresponding list of nodes
  label_node_dict = df.set_index(label_colname)[nodes_colname].to_dict()
  return label_node_dict
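With these changes, the column delimiter (comma for CSV, tab for TSV) is kept separate from the `nodes_delimiter` that splits multiple node names inside the nodes column. A hedged usage sketch of the new signature — the file path, column values, and no-argument `AnnotationsIO()` constructor are illustrative assumptions, not confirmed package details:

```python
import networkx as nx
from risk.annotations.io import AnnotationsIO

network = nx.Graph()  # placeholder; in practice, the biological network being annotated
annotations_io = AnnotationsIO()  # assumes a no-argument constructor

# A CSV whose "nodes" column holds semicolon-separated node names per label, e.g.
#   label,nodes
#   GO:0006412,RPL1;RPL2;RPS3
annotations = annotations_io.load_csv_annotation(
    filepath="annotations.csv",
    network=network,
    label_colname="label",
    nodes_colname="nodes",
    nodes_delimiter=";",
)
```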
{risk_network-0.0.2 → risk_network-0.0.3}/risk/log/params.py
@@ -147,14 +147,11 @@ class Params:
  params = self.load()
  # Open the file in write mode
  with open(filepath, "w") as txt_file:
- for key, nested_dict in params.items():
- # Write the key
- txt_file.write(f"{key}:\n")
- # Write the nested dictionary values, one per line
- for nested_key, nested_value in nested_dict.items():
- txt_file.write(f" {nested_key}: {nested_value}\n")
- # Add a blank line between different keys
- txt_file.write("\n")
+ for key, value in params.items():
+ # Write the key and its corresponding value
+ txt_file.write(f"{key}: {value}\n")
+ # Add a blank line after each entry
+ txt_file.write("\n")

  def load(self) -> Dict[str, Any]:
  """Load and process various parameters, converting any np.ndarray values to lists.
risk_network-0.0.2/risk/neighborhoods/graph.py → risk_network-0.0.3/risk/neighborhoods/community.py
@@ -1,6 +1,6 @@
  """
- risk/neighborhoods/graph
- ~~~~~~~~~~~~~~~~~~~~~~~~
+ risk/neighborhoods/community
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  """

  import community as community_louvain
{risk_network-0.0.2 → risk_network-0.0.3}/risk/neighborhoods/neighborhoods.py
@@ -10,7 +10,7 @@ import networkx as nx
  import numpy as np
  from sklearn.exceptions import DataConversionWarning

- from risk.neighborhoods.graph import (
+ from risk.neighborhoods.community import (
  calculate_dijkstra_neighborhoods,
  calculate_label_propagation_neighborhoods,
  calculate_louvain_neighborhoods,
{risk_network-0.0.2 → risk_network-0.0.3}/risk/network/io.py
@@ -29,9 +29,6 @@ class NetworkIO:
  self,
  compute_sphere: bool = True,
  surface_depth: float = 0.0,
- distance_metric: str = "dijkstra",
- edge_length_threshold: float = 0.5,
- louvain_resolution: float = 0.1,
  min_edges_per_node: int = 0,
  include_edge_weight: bool = True,
  weight_label: str = "weight",
@@ -40,9 +37,6 @@
  self.surface_depth = surface_depth
  self.include_edge_weight = include_edge_weight
  self.weight_label = weight_label
- self.distance_metric = distance_metric
- self.edge_length_threshold = edge_length_threshold
- self.louvain_resolution = louvain_resolution
  self.min_edges_per_node = min_edges_per_node

  def load_gpickle_network(self, filepath: str) -> nx.Graph:
@@ -317,10 +311,9 @@
  print(f"Filetype: {filetype}")
  if filepath:
  print(f"Filepath: {filepath}")
- print(f"Project to sphere: {self.compute_sphere}")
+ print(f"Projection: {'Sphere' if self.compute_sphere else 'Plane'}")
  if self.compute_sphere:
  print(f"Surface depth: {self.surface_depth}")
- print(f"Edge length threshold: {self.edge_length_threshold}")
- print(f"Include edge weights: {self.include_edge_weight}")
+ print(f"Edge weight: {'Included' if self.include_edge_weight else 'Excluded'}")
  if self.include_edge_weight:
  print(f"Weight label: {self.weight_label}")
{risk_network-0.0.2 → risk_network-0.0.3}/risk/network/plot.py
@@ -45,21 +45,24 @@ class NetworkPlotter:
  outline_scale (float, optional): Outline scaling factor for the perimeter diameter. Defaults to 1.0.
  """
  self.network_graph = network_graph
- self.ax = None # Initialize the axis attribute
- # Initialize the plot with the given parameters
- self._initialize_plot(figsize, background_color, plot_outline, outline_color, outline_scale)
+ # Initialize the plot with the specified parameters
+ self.ax = self._initialize_plot(
+ network_graph, figsize, background_color, plot_outline, outline_color, outline_scale
+ )

  def _initialize_plot(
  self,
+ network_graph: NetworkGraph,
  figsize: tuple,
  background_color: str,
  plot_outline: bool,
  outline_color: str,
  outline_scale: float,
- ) -> tuple:
+ ) -> plt.Axes:
  """Set up the plot with figure size, optional circle perimeter, and background color.

  Args:
+ network_graph (NetworkGraph): The network data and attributes to be visualized.
  figsize (tuple): Size of the figure in inches (width, height).
  background_color (str): Background color of the plot.
  plot_outline (bool): Whether to plot the network perimeter circle.
@@ -67,10 +70,10 @@ class NetworkPlotter:
  outline_scale (float): Outline scaling factor for the perimeter diameter.

  Returns:
- tuple: The created matplotlib figure and axis.
+ plt.Axes: The axis object for the plot.
  """
  # Extract node coordinates from the network graph
- node_coordinates = self.network_graph.node_coordinates
+ node_coordinates = network_graph.node_coordinates
  # Calculate the center and radius of the bounding box around the network
  center, radius = _calculate_bounding_box(node_coordinates)
  # Scale the radius by the outline_scale factor
@@ -107,9 +110,7 @@ class NetworkPlotter:
  ax.set_yticks([])
  ax.patch.set_visible(False) # Hide the axis background

- # Store the axis for further use and return the figure and axis
- self.ax = ax
- return fig, ax
+ return ax

  def plot_network(
  self,
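The practical effect of this refactor is that `_initialize_plot` now returns the matplotlib `Axes` and `__init__` assigns it once, instead of setting `self.ax` as a side effect. A hedged sketch of the resulting call pattern (construction of `network_graph` is elided; it is not shown in this diff):

```python
# Sketch only: assumes `network_graph` is an existing NetworkGraph instance
plotter = NetworkPlotter(network_graph)  # self.ax is now assigned during construction
ax = plotter.ax                          # the Axes object returned by _initialize_plot
ax.set_title("RISK network")             # ordinary matplotlib calls work on the stored axis
```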
@@ -401,7 +402,7 @@ class NetworkPlotter:
  fontcolor: Union[str, np.ndarray] = "black",
  arrow_linewidth: float = 1,
  arrow_color: Union[str, np.ndarray] = "black",
- num_words: int = 10,
+ max_words: int = 10,
  min_words: int = 1,
  ) -> None:
  """Annotate the network graph with labels for different domains, positioned around the network for clarity.
@@ -414,7 +415,7 @@ class NetworkPlotter:
  fontcolor (str or np.ndarray, optional): Color of the label text. Can be a string or RGBA array. Defaults to "black".
  arrow_linewidth (float, optional): Line width of the arrows pointing to centroids. Defaults to 1.
  arrow_color (str or np.ndarray, optional): Color of the arrows. Can be a string or RGBA array. Defaults to "black".
- num_words (int, optional): Maximum number of words in a label. Defaults to 10.
+ max_words (int, optional): Maximum number of words in a label. Defaults to 10.
  min_words (int, optional): Minimum number of words required to display a label. Defaults to 1.
  """
  # Log the plotting parameters
@@ -426,7 +427,7 @@ class NetworkPlotter:
  label_fontcolor="custom" if isinstance(fontcolor, np.ndarray) else fontcolor,
  label_arrow_linewidth=arrow_linewidth,
  label_arrow_color="custom" if isinstance(arrow_color, np.ndarray) else arrow_color,
- label_num_words=num_words,
+ label_max_words=max_words,
  label_min_words=min_words,
  )
  # Convert color strings to RGBA arrays if necessary
@@ -436,7 +437,12 @@ class NetworkPlotter:
  arrow_color = self.get_annotated_contour_colors(color=arrow_color)

  # Calculate the center and radius of the network
- domain_centroids = self._calculate_domain_centroids()
+ domain_centroids = {}
+ for domain, nodes in self.network_graph.domain_to_nodes.items():
+ if nodes: # Skip if the domain has no nodes
+ domain_centroids[domain] = self._calculate_domain_centroid(nodes)
+
+ # Calculate the bounding box around the network
  center, radius = _calculate_bounding_box(
  self.network_graph.node_coordinates, radius_margin=perimeter_scale
  )
@@ -445,7 +451,7 @@ class NetworkPlotter:
  filtered_domains = {
  domain: centroid
  for domain, centroid in domain_centroids.items()
- if len(self.network_graph.trimmed_domain_to_term[domain].split(" ")[:num_words])
+ if len(self.network_graph.trimmed_domain_to_term[domain].split(" ")[:max_words])
  >= min_words
  }
  # Calculate the best positions for labels around the perimeter
@@ -453,7 +459,7 @@ class NetworkPlotter:
  # Annotate the network with labels
  for idx, (domain, pos) in enumerate(best_label_positions.items()):
  centroid = filtered_domains[domain]
- annotations = self.network_graph.trimmed_domain_to_term[domain].split(" ")[:num_words]
+ annotations = self.network_graph.trimmed_domain_to_term[domain].split(" ")[:max_words]
  self.ax.annotate(
  "\n".join(annotations),
  xy=centroid,
@@ -467,31 +473,26 @@ class NetworkPlotter:
  arrowprops=dict(arrowstyle="->", color=arrow_color[idx], linewidth=arrow_linewidth),
  )

- def _calculate_domain_centroids(self) -> Dict[Any, np.ndarray]:
- """Calculate the most centrally located node within each domain based on the node positions.
+ def _calculate_domain_centroid(self, nodes: list) -> tuple:
+ """Calculate the most centrally located node in .
+
+ Args:
+ nodes (list): List of node labels to include in the subnetwork.

  Returns:
- Dict[Any, np.ndarray]: A dictionary mapping each domain to its central node's coordinates.
+ tuple: A tuple containing the domain's central node coordinates.
  """
- domain_central_nodes = {}
- for domain, nodes in self.network_graph.domain_to_nodes.items():
- if not nodes: # Skip if the domain has no nodes
- continue
-
- # Extract positions of all nodes in the domain
- node_positions = self.network_graph.node_coordinates[nodes, :]
- # Calculate the pairwise distance matrix between all nodes in the domain
- distances_matrix = np.linalg.norm(
- node_positions[:, np.newaxis] - node_positions, axis=2
- )
- # Sum the distances for each node to all other nodes in the domain
- sum_distances = np.sum(distances_matrix, axis=1)
- # Identify the node with the smallest total distance to others (the centroid)
- central_node_idx = np.argmin(sum_distances)
- # Map the domain to the coordinates of its central node
- domain_central_nodes[domain] = node_positions[central_node_idx]
-
- return domain_central_nodes
+ # Extract positions of all nodes in the domain
+ node_positions = self.network_graph.node_coordinates[nodes, :]
+ # Calculate the pairwise distance matrix between all nodes in the domain
+ distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
+ # Sum the distances for each node to all other nodes in the domain
+ sum_distances = np.sum(distances_matrix, axis=1)
+ # Identify the node with the smallest total distance to others (the centroid)
+ central_node_idx = np.argmin(sum_distances)
+ # Map the domain to the coordinates of its central node
+ domain_central_node = node_positions[central_node_idx]
+ return domain_central_node

  def get_annotated_node_colors(
  self, nonenriched_color: str = "white", random_seed: int = 888, **kwargs
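Both the old and new versions compute the same statistic: the domain's medoid, i.e. the node whose summed Euclidean distance to every other node in the domain is smallest. A standalone sketch of that computation with toy coordinates (names and values are illustrative):

```python
import numpy as np

# Toy 2D coordinates for the nodes of one domain
node_positions = np.array([[0.0, 0.0], [1.0, 0.0], [0.9, 0.1], [5.0, 5.0]])

# Pairwise Euclidean distances between all nodes in the domain
distances_matrix = np.linalg.norm(node_positions[:, np.newaxis] - node_positions, axis=2)
# Sum of distances from each node to every other node
sum_distances = distances_matrix.sum(axis=1)
# The medoid is the node with the smallest total distance to the others
central_node = node_positions[np.argmin(sum_distances)]
print(central_node)  # -> [0.9 0.1], the most centrally located node
```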
@@ -604,16 +605,6 @@ class NetworkPlotter:

  return np.array(annotated_colors)

- @staticmethod
- def close(*args, **kwargs) -> None:
- """Close the current plot.
-
- Args:
- *args: Positional arguments passed to `plt.close`.
- **kwargs: Keyword arguments passed to `plt.close`.
- """
- plt.close(*args, **kwargs)
-
  @staticmethod
  def savefig(*args, **kwargs) -> None:
  """Save the current plot to a file.