risk-network 0.0.3b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +13 -0
- risk/annotations/__init__.py +7 -0
- risk/annotations/annotations.py +259 -0
- risk/annotations/io.py +183 -0
- risk/constants.py +31 -0
- risk/log/__init__.py +9 -0
- risk/log/console.py +16 -0
- risk/log/params.py +198 -0
- risk/neighborhoods/__init__.py +10 -0
- risk/neighborhoods/community.py +189 -0
- risk/neighborhoods/domains.py +257 -0
- risk/neighborhoods/neighborhoods.py +319 -0
- risk/network/__init__.py +8 -0
- risk/network/geometry.py +165 -0
- risk/network/graph.py +280 -0
- risk/network/io.py +326 -0
- risk/network/plot.py +795 -0
- risk/risk.py +382 -0
- risk/stats/__init__.py +6 -0
- risk/stats/permutation.py +88 -0
- risk/stats/stats.py +447 -0
- risk_network-0.0.3b1.dist-info/LICENSE +674 -0
- risk_network-0.0.3b1.dist-info/METADATA +751 -0
- risk_network-0.0.3b1.dist-info/RECORD +26 -0
- risk_network-0.0.3b1.dist-info/WHEEL +5 -0
- risk_network-0.0.3b1.dist-info/top_level.txt +1 -0
risk/log/params.py
ADDED
@@ -0,0 +1,198 @@
"""
risk/log/params
~~~~~~~~~~~~~~~
"""

import csv
import json
import warnings
from datetime import datetime
from typing import Any, Dict

import numpy as np

from .console import print_header

# Suppress all warnings - this is to resolve warnings from multiprocessing
warnings.filterwarnings("ignore")


def _safe_param_export(func):
    """A decorator to wrap parameter export functions in a try-except block for safe execution.

    Args:
        func (function): The function to be wrapped.

    Returns:
        function: The wrapped function with error handling.
    """

    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
            filepath = (
                kwargs.get("filepath") or args[1]
            )  # Assuming filepath is always the second argument
            print(f"Parameters successfully exported to filepath: {filepath}")
            return result
        except Exception as e:
            filepath = kwargs.get("filepath") or args[1]
            print(f"An error occurred while exporting parameters to {filepath}: {e}")
            return None

    return wrapper


class Params:
    """Handles the storage and logging of various parameters for network analysis.

    The Params class provides methods to log parameters related to different components of the analysis,
    such as the network, annotations, neighborhoods, graph, and plotter settings. It also stores
    the current datetime when the parameters were initialized.
    """

    def __init__(self):
        """Initialize the Params object with default settings and current datetime."""
        self.initialize()
        self.datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def initialize(self) -> None:
        """Initialize the parameter dictionaries for different components."""
        self.network = {}
        self.annotations = {}
        self.neighborhoods = {}
        self.graph = {}
        self.plotter = {}

    def log_network(self, **kwargs) -> None:
        """Log network-related parameters.

        Args:
            **kwargs: Network parameters to log.
        """
        self.network = {**self.network, **kwargs}

    def log_annotations(self, **kwargs) -> None:
        """Log annotation-related parameters.

        Args:
            **kwargs: Annotation parameters to log.
        """
        self.annotations = {**self.annotations, **kwargs}

    def log_neighborhoods(self, **kwargs) -> None:
        """Log neighborhood-related parameters.

        Args:
            **kwargs: Neighborhood parameters to log.
        """
        self.neighborhoods = {**self.neighborhoods, **kwargs}

    def log_graph(self, **kwargs) -> None:
        """Log graph-related parameters.

        Args:
            **kwargs: Graph parameters to log.
        """
        self.graph = {**self.graph, **kwargs}

    def log_plotter(self, **kwargs) -> None:
        """Log plotter-related parameters.

        Args:
            **kwargs: Plotter parameters to log.
        """
        self.plotter = {**self.plotter, **kwargs}

    @_safe_param_export
    def to_csv(self, filepath: str) -> None:
        """Export the parameters to a CSV file.

        Args:
            filepath (str): The path where the CSV file will be saved.
        """
        # Load the parameter dictionary
        params = self.load()
        # Open the file in write mode
        with open(filepath, "w", newline="") as csv_file:
            writer = csv.writer(csv_file)
            # Write the header
            writer.writerow(["parent_key", "child_key", "value"])
            # Write the rows
            for parent_key, parent_value in params.items():
                if isinstance(parent_value, dict):
                    for child_key, child_value in parent_value.items():
                        writer.writerow([parent_key, child_key, child_value])
                else:
                    writer.writerow([parent_key, "", parent_value])

    @_safe_param_export
    def to_json(self, filepath: str) -> None:
        """Export the parameters to a JSON file.

        Args:
            filepath (str): The path where the JSON file will be saved.
        """
        with open(filepath, "w") as json_file:
            json.dump(self.load(), json_file, indent=4)

    @_safe_param_export
    def to_txt(self, filepath: str) -> None:
        """Export the parameters to a text file.

        Args:
            filepath (str): The path where the text file will be saved.
        """
        # Load the parameter dictionary
        params = self.load()
        # Open the file in write mode
        with open(filepath, "w") as txt_file:
            for key, value in params.items():
                # Write the key
                txt_file.write(f"{key}:\n")
                if isinstance(value, dict):
                    # Write the nested dictionary values, one per line
                    for nested_key, nested_value in value.items():
                        txt_file.write(f"  {nested_key}: {nested_value}\n")
                else:
                    # Scalar entries such as "datetime" are strings, not dicts;
                    # writing them directly avoids an AttributeError on .items()
                    txt_file.write(f"  {value}\n")
                # Add a blank line between different keys
                txt_file.write("\n")

    def load(self) -> Dict[str, Any]:
        """Load and process various parameters, converting any np.ndarray values to lists.

        Returns:
            dict: A dictionary containing the processed parameters.
        """
        print_header("Loading parameters")
        return _convert_ndarray_to_list(
            {
                "annotations": self.annotations,
                "datetime": self.datetime,
                "graph": self.graph,
                "neighborhoods": self.neighborhoods,
                "network": self.network,
                "plotter": self.plotter,
            }
        )


def _convert_ndarray_to_list(d: Any) -> Any:
    """Recursively convert all np.ndarray values in the dictionary to lists.

    Args:
        d (dict): The dictionary to process.

    Returns:
        dict: The processed dictionary with np.ndarray values converted to lists.
    """
    if isinstance(d, dict):
        # Recursively process each value in the dictionary
        return {k: _convert_ndarray_to_list(v) for k, v in d.items()}
    elif isinstance(d, list):
        # Recursively process each item in the list
        return [_convert_ndarray_to_list(v) for v in d]
    elif isinstance(d, np.ndarray):
        # Convert numpy arrays to lists
        return d.tolist()
    else:
        # Return the value unchanged if it's not a dict, list, or ndarray
        return d
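
For orientation, a minimal usage sketch of this logger (not part of the packaged code): the keyword arguments below are illustrative, since the log_* methods accept arbitrary keywords, and the file paths are placeholders.

    from risk.log.params import Params

    params = Params()
    params.log_network(filetype="cytoscape", filepath="network.cys")
    params.log_neighborhoods(distance_metric="dijkstra", edge_length_threshold=0.5)

    # Params.load() converts any np.ndarray values to lists before export,
    # so the dumped JSON/CSV contain only plain Python types
    params.to_json("params.json")
    params.to_csv("params.csv")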

risk/neighborhoods/community.py
ADDED
@@ -0,0 +1,189 @@
"""
risk/neighborhoods/community
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

import community as community_louvain
import markov_clustering as mc
import networkx as nx
import numpy as np
from networkx.algorithms.community import asyn_lpa_communities


def calculate_dijkstra_neighborhoods(network: nx.Graph) -> np.ndarray:
    """Calculate neighborhoods using Dijkstra's shortest path distances.

    Args:
        network (nx.Graph): The network graph.

    Returns:
        np.ndarray: Neighborhood matrix based on Dijkstra's distances.
    """
    # Compute Dijkstra's distance for all pairs of nodes in the network
    all_dijkstra_paths = dict(nx.all_pairs_dijkstra_path_length(network, weight="length"))
    # Use a float dtype: the values stored below are square roots in (0, 1],
    # which an integer matrix would truncate to 0
    neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=float)

    # Populate the neighborhoods matrix based on Dijkstra's distances
    for source, targets in all_dijkstra_paths.items():
        for target, length in targets.items():
            neighborhoods[source, target] = (
                1 if np.isnan(length) or length == 0 else np.sqrt(1 / length)
            )

    return neighborhoods


def calculate_label_propagation_neighborhoods(network: nx.Graph) -> np.ndarray:
    """Apply Label Propagation to the network to detect communities.

    Args:
        network (nx.Graph): The network graph.

    Returns:
        np.ndarray: Neighborhood matrix based on Label Propagation.
    """
    # Apply Label Propagation
    communities = nx.algorithms.community.label_propagation.label_propagation_communities(network)

    # Create a mapping from node to community
    community_dict = {}
    for community_id, community in enumerate(communities):
        for node in community:
            community_dict[node] = community_id

    # Create a neighborhood matrix
    num_nodes = network.number_of_nodes()
    neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)

    # Assign neighborhoods based on community labels
    for node_i, community_i in community_dict.items():
        for node_j, community_j in community_dict.items():
            if community_i == community_j:
                neighborhoods[node_i, node_j] = 1

    return neighborhoods

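
# --- Illustrative aside (not part of the packaged module) ---
# The nested loop above, repeated in the builders below, fills the n x n
# co-membership matrix in O(n^2) Python iterations. An equivalent vectorized
# construction, assuming nodes are labeled 0..n-1 as these functions already
# require:
#
#     labels = np.empty(num_nodes, dtype=int)
#     for node, community_id in community_dict.items():
#         labels[node] = community_id
#     neighborhoods = (labels[:, None] == labels[None, :]).astype(int)
# -------------------------------------------------------------
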
def calculate_louvain_neighborhoods(
    network: nx.Graph, resolution: float, random_seed: int = 888
) -> np.ndarray:
    """Calculate neighborhoods using the Louvain method.

    Args:
        network (nx.Graph): The network graph.
        resolution (float): Resolution parameter for the Louvain method.
        random_seed (int, optional): Random seed for reproducibility. Defaults to 888.

    Returns:
        np.ndarray: Neighborhood matrix based on the Louvain method.
    """
    # Apply Louvain method to partition the network
    partition = community_louvain.best_partition(
        network, resolution=resolution, random_state=random_seed
    )
    neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=int)

    # Assign neighborhoods based on community partitions
    for node_i, community_i in partition.items():
        for node_j, community_j in partition.items():
            if community_i == community_j:
                neighborhoods[node_i, node_j] = 1

    return neighborhoods


def calculate_markov_clustering_neighborhoods(network: nx.Graph) -> np.ndarray:
    """Apply Markov Clustering (MCL) to the network.

    Args:
        network (nx.Graph): The network graph.

    Returns:
        np.ndarray: Neighborhood matrix based on Markov Clustering.
    """
    # Convert the graph to an adjacency matrix
    adjacency_matrix = nx.to_numpy_array(network)
    # Run Markov Clustering
    result = mc.run_mcl(adjacency_matrix)  # Run MCL with default parameters
    # Get clusters
    clusters = mc.get_clusters(result)

    # Create a community label for each node
    community_dict = {}
    for community_id, community in enumerate(clusters):
        for node in community:
            community_dict[node] = community_id

    # Create a neighborhood matrix
    num_nodes = network.number_of_nodes()
    neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)

    # Assign neighborhoods based on community labels
    for node_i, community_i in community_dict.items():
        for node_j, community_j in community_dict.items():
            if community_i == community_j:
                neighborhoods[node_i, node_j] = 1

    return neighborhoods


def calculate_spinglass_neighborhoods(network: nx.Graph) -> np.ndarray:
    """Apply Spin Glass community detection to the network, approximated here
    with asynchronous label propagation.

    Args:
        network (nx.Graph): The network graph.

    Returns:
        np.ndarray: Neighborhood matrix based on the detected communities.
    """
    # Use the asynchronous label propagation algorithm as a proxy for Spin Glass
    communities = asyn_lpa_communities(network)

    # Create a community label for each node
    community_dict = {}
    for community_id, community in enumerate(communities):
        for node in community:
            community_dict[node] = community_id

    # Create a neighborhood matrix
    num_nodes = network.number_of_nodes()
    neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)

    # Assign neighborhoods based on community labels
    for node_i, community_i in community_dict.items():
        for node_j, community_j in community_dict.items():
            if community_i == community_j:
                neighborhoods[node_i, node_j] = 1

    return neighborhoods


def calculate_walktrap_neighborhoods(network: nx.Graph) -> np.ndarray:
    """Apply Walktrap community detection to the network, approximated here
    with asynchronous label propagation.

    Args:
        network (nx.Graph): The network graph.

    Returns:
        np.ndarray: Neighborhood matrix based on the detected communities.
    """
    # Use the asynchronous label propagation algorithm as a proxy for Walktrap
    communities = asyn_lpa_communities(network)

    # Create a community label for each node
    community_dict = {}
    for community_id, community in enumerate(communities):
        for node in community:
            community_dict[node] = community_id

    # Create a neighborhood matrix
    num_nodes = network.number_of_nodes()
    neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)

    # Assign neighborhoods based on community labels
    for node_i, community_i in community_dict.items():
        for node_j, community_j in community_dict.items():
            if community_i == community_j:
                neighborhoods[node_i, node_j] = 1

    return neighborhoods
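
A quick smoke test of the builders above (not part of the packaged code), assuming the package is installed: networkx's bundled karate-club graph has integer node labels 0 through 33, matching the 0..n-1 indexing these functions assume.

    import networkx as nx
    from risk.neighborhoods.community import calculate_louvain_neighborhoods

    network = nx.karate_club_graph()
    matrix = calculate_louvain_neighborhoods(network, resolution=1.0)
    assert matrix.shape == (34, 34)  # row i flags node i's community members with 1s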

risk/neighborhoods/domains.py
ADDED
@@ -0,0 +1,257 @@
"""
risk/neighborhoods/domains
~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

from contextlib import suppress
from typing import Tuple

import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import fcluster, linkage
from sklearn.metrics import silhouette_score
from tqdm import tqdm

from risk.annotations import get_description
from risk.constants import GROUP_DISTANCE_METRICS, GROUP_LINKAGE_METHODS


def define_domains(
    top_annotations: pd.DataFrame,
    significant_neighborhoods_enrichment: np.ndarray,
    linkage_criterion: str,
    linkage_method: str,
    linkage_metric: str,
) -> pd.DataFrame:
    """Define domains and assign nodes to these domains based on their enrichment scores and clustering.

    Args:
        top_annotations (pd.DataFrame): DataFrame of top annotations data for the network nodes.
        significant_neighborhoods_enrichment (np.ndarray): The binary enrichment matrix below alpha.
        linkage_criterion (str): The clustering criterion for defining groups.
        linkage_method (str): The linkage method for clustering.
        linkage_metric (str): The linkage metric for clustering.

    Returns:
        pd.DataFrame: DataFrame with the primary domain for each node.
    """
    # Perform hierarchical clustering on the binary enrichment matrix
    m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
    best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
        m, linkage_criterion, linkage_method, linkage_metric
    )
    try:
        Z = linkage(m, method=best_linkage, metric=best_metric)
    except ValueError as e:
        raise ValueError("No significant annotations found.") from e

    print(
        f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
    )
    print(f"Optimal linkage threshold: {round(best_threshold, 3)}")

    # Convert the relative threshold into an absolute cut height and cluster
    max_d_optimal = np.max(Z[:, 2]) * best_threshold
    domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
    # Assign domains to the annotations matrix
    top_annotations["domain"] = 0
    top_annotations.loc[top_annotations["top attributes"], "domain"] = domains

    # Create DataFrames to store domain information
    node_to_enrichment = pd.DataFrame(
        data=significant_neighborhoods_enrichment,
        columns=[top_annotations.index.values, top_annotations["domain"]],
    )
    node_to_domain = node_to_enrichment.groupby(level="domain", axis=1).sum()

    # For each node, pick the domain with the highest enrichment sum; nodes with
    # no enrichment fall back to domain 0
    t_max = node_to_domain.loc[:, 1:].max(axis=1)
    t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
    t_idxmax[t_max == 0] = 0

    # Assign primary domain
    node_to_domain["primary domain"] = t_idxmax

    return node_to_domain

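
# --- Illustrative aside (not part of the packaged module) ---
# define_domains() expresses the dendrogram cut height as a fraction of the
# tallest merge distance in Z. Stated on its own, with example parameters:
#
#     Z = linkage(m, method="average", metric="jaccard")
#     max_d = np.max(Z[:, 2]) * best_threshold   # relative threshold -> absolute cut
#     domains = fcluster(Z, max_d, criterion="distance")
# -------------------------------------------------------------
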
def trim_domains_and_top_annotations(
    domains: pd.DataFrame,
    top_annotations: pd.DataFrame,
    min_cluster_size: int = 5,
    max_cluster_size: int = 1000,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Trim domains and top annotations that do not meet size criteria and find outliers.

    Args:
        domains (pd.DataFrame): DataFrame of domain data for the network nodes.
        top_annotations (pd.DataFrame): DataFrame of top annotations data for the network nodes.
        min_cluster_size (int, optional): Minimum size of a cluster to be retained. Defaults to 5.
        max_cluster_size (int, optional): Maximum size of a cluster to be retained. Defaults to 1000.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing:
            - Trimmed annotations (pd.DataFrame)
            - Trimmed domains (pd.DataFrame)
            - A DataFrame with domain labels (pd.DataFrame)
    """
    # Identify domains to remove based on size criteria
    domain_counts = domains["primary domain"].value_counts()
    to_remove = set(
        domain_counts[(domain_counts < min_cluster_size) | (domain_counts > max_cluster_size)].index
    )

    # Add invalid domain IDs
    invalid_domain_id = 888888
    invalid_domain_ids = {0, invalid_domain_id}
    # Mark domains to be removed
    top_annotations["domain"].replace(to_remove, invalid_domain_id, inplace=True)
    domains.loc[domains["primary domain"].isin(to_remove), ["primary domain"]] = invalid_domain_id

    # Normalize "neighborhood enrichment sums" by percentile rank within each domain and scale to 0-10
    top_annotations["normalized_value"] = top_annotations.groupby("domain")[
        "neighborhood enrichment sums"
    ].transform(lambda x: (x.rank(pct=True) * 10).apply(np.ceil).astype(int))
    # Multiply 'words' column by normalized values
    top_annotations["words"] = top_annotations.apply(
        lambda row: " ".join([row["words"]] * row["normalized_value"]), axis=1
    )

    # Generate domain labels
    domain_labels = top_annotations.groupby("domain")["words"].apply(get_description).reset_index()
    trimmed_domains_matrix = domain_labels.rename(
        columns={"domain": "id", "words": "label"}
    ).set_index("id")

    # Remove invalid domains
    valid_annotations = top_annotations[~top_annotations["domain"].isin(invalid_domain_ids)].drop(
        columns=["normalized_value"]
    )
    valid_domains = domains[~domains["primary domain"].isin(invalid_domain_ids)]
    valid_trimmed_domains_matrix = trimmed_domains_matrix[
        ~trimmed_domains_matrix.index.isin(invalid_domain_ids)
    ]

    return valid_annotations, valid_domains, valid_trimmed_domains_matrix


def _optimize_silhouette_across_linkage_and_metrics(
    m: np.ndarray, linkage_criterion: str, linkage_method: str, linkage_metric: str
) -> Tuple[str, str, float]:
    """Optimize silhouette score across different linkage methods and distance metrics.

    Args:
        m (np.ndarray): Data matrix.
        linkage_criterion (str): Clustering criterion.
        linkage_method (str): Linkage method for clustering.
        linkage_metric (str): Linkage metric for clustering.

    Returns:
        tuple[str, str, float]: A tuple containing:
            - Best linkage method (str)
            - Best linkage metric (str)
            - Best threshold (float)
    """
    best_overall_method = linkage_method
    best_overall_metric = linkage_metric
    best_overall_score = -np.inf
    best_overall_threshold = 1

    linkage_methods = GROUP_LINKAGE_METHODS if linkage_method == "auto" else [linkage_method]
    linkage_metrics = GROUP_DISTANCE_METRICS if linkage_metric == "auto" else [linkage_metric]
    total_combinations = len(linkage_methods) * len(linkage_metrics)

    # Evaluating optimal linkage method and metric
    for method in tqdm(
        linkage_methods,
        desc="Evaluating optimal linkage method and metric",
        total=total_combinations,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
    ):
        for metric in linkage_metrics:
            with suppress(Exception):
                Z = linkage(m, method=method, metric=metric)
                threshold, score = _find_best_silhouette_score(Z, m, metric, linkage_criterion)
                if score > best_overall_score:
                    best_overall_score = score
                    best_overall_threshold = threshold
                    best_overall_method = method
                    best_overall_metric = metric

    return best_overall_method, best_overall_metric, best_overall_threshold


def _find_best_silhouette_score(
    Z: np.ndarray,
    m: np.ndarray,
    linkage_metric: str,
    linkage_criterion: str,
    lower_bound: float = 0.001,
    upper_bound: float = 1.0,
    resolution: float = 0.001,
) -> Tuple[float, float]:
    """Find the best silhouette score using binary search.

    Args:
        Z (np.ndarray): Linkage matrix.
        m (np.ndarray): Data matrix.
        linkage_metric (str): Linkage metric for silhouette score calculation.
        linkage_criterion (str): Clustering criterion.
        lower_bound (float, optional): Lower bound for search. Defaults to 0.001.
        upper_bound (float, optional): Upper bound for search. Defaults to 1.0.
        resolution (float, optional): Desired resolution for the best threshold. Defaults to 0.001.

    Returns:
        tuple[float, float]: A tuple containing:
            - Best threshold (float): The threshold that yields the best silhouette score.
            - Best silhouette score (float): The highest silhouette score achieved.
    """
    best_score = -np.inf
    best_threshold = None

    # Test lower bound
    max_d_lower = np.max(Z[:, 2]) * lower_bound
    clusters_lower = fcluster(Z, max_d_lower, criterion=linkage_criterion)
    try:
        score_lower = silhouette_score(m, clusters_lower, metric=linkage_metric)
    except ValueError:
        score_lower = -np.inf

    # Test upper bound
    max_d_upper = np.max(Z[:, 2]) * upper_bound
    clusters_upper = fcluster(Z, max_d_upper, criterion=linkage_criterion)
    try:
        score_upper = silhouette_score(m, clusters_upper, metric=linkage_metric)
    except ValueError:
        score_upper = -np.inf

    # Determine initial bounds for binary search
    if score_lower > score_upper:
        best_score = score_lower
        best_threshold = lower_bound
        upper_bound = (lower_bound + upper_bound) / 2
    else:
        best_score = score_upper
        best_threshold = upper_bound
        lower_bound = (lower_bound + upper_bound) / 2

    # Binary search loop
    while upper_bound - lower_bound > resolution:
        mid_threshold = (upper_bound + lower_bound) / 2
        max_d_mid = np.max(Z[:, 2]) * mid_threshold
        clusters_mid = fcluster(Z, max_d_mid, criterion=linkage_criterion)
        try:
            score_mid = silhouette_score(m, clusters_mid, metric=linkage_metric)
        except ValueError:
            score_mid = -np.inf

        # Update best score and threshold if mid-point is better
        if score_mid > best_score:
            best_score = score_mid
            best_threshold = mid_threshold

        # Move the bound whose endpoint currently scores worse, refreshing that
        # endpoint's score with the mid-point result so later comparisons use
        # current values rather than the initial endpoint scores
        if score_lower > score_upper:
            upper_bound = mid_threshold
            score_upper = score_mid
        else:
            lower_bound = mid_threshold
            score_lower = score_mid

    return best_threshold, float(best_score)
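
Finally, a toy sketch of define_domains() on synthetic data (not part of the packaged code). The column names follow the DataFrame contract above ('top attributes', 'neighborhood enrichment sums', 'words'); the annotation terms are invented, and a pandas version that still supports DataFrame.groupby(..., axis=1) is assumed.

    import numpy as np
    import pandas as pd
    from risk.neighborhoods.domains import define_domains

    rng = np.random.default_rng(0)
    enrichment = (rng.random((50, 8)) < 0.3).astype(int)  # 50 neighborhoods x 8 annotations
    top_annotations = pd.DataFrame(
        {
            "top attributes": [True] * 8,
            "neighborhood enrichment sums": enrichment.sum(axis=0),
            "words": [f"term {i}" for i in range(8)],
        }
    )

    node_to_domain = define_domains(
        top_annotations,
        enrichment,
        linkage_criterion="distance",
        linkage_method="average",
        linkage_metric="jaccard",
    )
    print(node_to_domain["primary domain"].value_counts())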