risk-network 0.0.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
risk/log/params.py ADDED
@@ -0,0 +1,198 @@
+ """
+ risk/log/params
+ ~~~~~~~~~~~~~~~
+ """
+
+ import csv
+ import json
+ import warnings
+ from datetime import datetime
+ from typing import Any, Dict
+
+ import numpy as np
+
+ from .console import print_header
+
+ # Suppress all warnings - this is to resolve warnings from multiprocessing
+ warnings.filterwarnings("ignore")
+
+
+ def _safe_param_export(func):
+     """A decorator to wrap parameter export functions in a try-except block for safe execution.
+
+     Args:
+         func (function): The function to be wrapped.
+
+     Returns:
+         function: The wrapped function with error handling.
+     """
+
+     def wrapper(*args, **kwargs):
+         try:
+             result = func(*args, **kwargs)
+             filepath = (
+                 kwargs.get("filepath") or args[1]
+             )  # Assuming filepath is always the second argument
+             print(f"Parameters successfully exported to filepath: {filepath}")
+             return result
+         except Exception as e:
+             filepath = kwargs.get("filepath") or args[1]
+             print(f"An error occurred while exporting parameters to {filepath}: {e}")
+             return None
+
+     return wrapper
+
+
+ class Params:
+     """Handles the storage and logging of various parameters for network analysis.
+
+     The Params class provides methods to log parameters related to different components of the analysis,
+     such as the network, annotations, neighborhoods, graph, and plotter settings. It also stores
+     the current datetime when the parameters were initialized.
+     """
+
+     def __init__(self):
+         """Initialize the Params object with default settings and current datetime."""
+         self.initialize()
+         self.datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+     def initialize(self) -> None:
+         """Initialize the parameter dictionaries for different components."""
+         self.network = {}
+         self.annotations = {}
+         self.neighborhoods = {}
+         self.graph = {}
+         self.plotter = {}
+
+     def log_network(self, **kwargs) -> None:
+         """Log network-related parameters.
+
+         Args:
+             **kwargs: Network parameters to log.
+         """
+         self.network = {**self.network, **kwargs}
+
+     def log_annotations(self, **kwargs) -> None:
+         """Log annotation-related parameters.
+
+         Args:
+             **kwargs: Annotation parameters to log.
+         """
+         self.annotations = {**self.annotations, **kwargs}
+
+     def log_neighborhoods(self, **kwargs) -> None:
+         """Log neighborhood-related parameters.
+
+         Args:
+             **kwargs: Neighborhood parameters to log.
+         """
+         self.neighborhoods = {**self.neighborhoods, **kwargs}
+
+     def log_graph(self, **kwargs) -> None:
+         """Log graph-related parameters.
+
+         Args:
+             **kwargs: Graph parameters to log.
+         """
+         self.graph = {**self.graph, **kwargs}
+
+     def log_plotter(self, **kwargs) -> None:
+         """Log plotter-related parameters.
+
+         Args:
+             **kwargs: Plotter parameters to log.
+         """
+         self.plotter = {**self.plotter, **kwargs}
+
+     @_safe_param_export
+     def to_csv(self, filepath: str) -> None:
+         """Export the parameters to a CSV file.
+
+         Args:
+             filepath (str): The path where the CSV file will be saved.
+         """
+         # Load the parameter dictionary
+         params = self.load()
+         # Open the file in write mode
+         with open(filepath, "w", newline="") as csv_file:
+             writer = csv.writer(csv_file)
+             # Write the header
+             writer.writerow(["parent_key", "child_key", "value"])
+             # Write the rows
+             for parent_key, parent_value in params.items():
+                 if isinstance(parent_value, dict):
+                     for child_key, child_value in parent_value.items():
+                         writer.writerow([parent_key, child_key, child_value])
+                 else:
+                     writer.writerow([parent_key, "", parent_value])
+
+     @_safe_param_export
+     def to_json(self, filepath: str) -> None:
+         """Export the parameters to a JSON file.
+
+         Args:
+             filepath (str): The path where the JSON file will be saved.
+         """
+         with open(filepath, "w") as json_file:
+             json.dump(self.load(), json_file, indent=4)
+
+     @_safe_param_export
+     def to_txt(self, filepath: str) -> None:
+         """Export the parameters to a text file.
+
+         Args:
+             filepath (str): The path where the text file will be saved.
+         """
+         # Load the parameter dictionary
+         params = self.load()
+         # Open the file in write mode
+         with open(filepath, "w") as txt_file:
+             for key, nested_dict in params.items():
+                 # Write the key
+                 txt_file.write(f"{key}:\n")
+                 # Write the nested dictionary values, one per line
+                 for nested_key, nested_value in nested_dict.items():
+                     txt_file.write(f" {nested_key}: {nested_value}\n")
+                 # Add a blank line between different keys
+                 txt_file.write("\n")
+
+     def load(self) -> Dict[str, Any]:
+         """Load and process various parameters, converting any np.ndarray values to lists.
+
+         Returns:
+             dict: A dictionary containing the processed parameters.
+         """
+         print_header("Loading parameters")
+         return _convert_ndarray_to_list(
+             {
+                 "annotations": self.annotations,
+                 "datetime": self.datetime,
+                 "graph": self.graph,
+                 "neighborhoods": self.neighborhoods,
+                 "network": self.network,
+                 "plotter": self.plotter,
+             }
+         )
+
+
+ def _convert_ndarray_to_list(d: Any) -> Any:
+     """Recursively convert all np.ndarray values in the dictionary to lists.
+
+     Args:
+         d (dict): The dictionary to process.
+
+     Returns:
+         dict: The processed dictionary with np.ndarray values converted to lists.
+     """
+     if isinstance(d, dict):
+         # Recursively process each value in the dictionary
+         return {k: _convert_ndarray_to_list(v) for k, v in d.items()}
+     elif isinstance(d, list):
+         # Recursively process each item in the list
+         return [_convert_ndarray_to_list(v) for v in d]
+     elif isinstance(d, np.ndarray):
+         # Convert numpy arrays to lists
+         return d.tolist()
+     else:
+         # Return the value unchanged if it's not a dict, list, or ndarray
+         return d
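A minimal usage sketch of the Params logger added above, assuming the package is installed and the class is importable from risk.log.params as the file path suggests; the keyword arguments passed to the log_* methods are purely illustrative:

    from risk.log.params import Params

    params = Params()
    params.log_network(network_file="network.gpickle")   # illustrative kwargs
    params.log_annotations(min_cluster_size=5)            # illustrative kwargs
    params.to_json("params.json")  # wrapped by _safe_param_export; prints success or the error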
risk/neighborhoods/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ risk/neighborhoods
+ ~~~~~~~~~~~~~~~~~~
+ """
+
+ from .domains import define_domains, trim_domains_and_top_annotations
+ from .neighborhoods import (
+     get_network_neighborhoods,
+     process_neighborhoods,
+ )
risk/neighborhoods/community.py ADDED
@@ -0,0 +1,189 @@
+ """
+ risk/neighborhoods/community
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ """
+
+ import community as community_louvain
+ import networkx as nx
+ import numpy as np
+ import markov_clustering as mc
+ from networkx.algorithms.community import asyn_lpa_communities
+
+
+ def calculate_dijkstra_neighborhoods(network: nx.Graph) -> np.ndarray:
+     """Calculate neighborhoods using Dijkstra's shortest path distances.
+
+     Args:
+         network (nx.Graph): The network graph.
+
+     Returns:
+         np.ndarray: Neighborhood matrix based on Dijkstra's distances.
+     """
+     # Compute Dijkstra's distance for all pairs of nodes in the network
+     all_dijkstra_paths = dict(nx.all_pairs_dijkstra_path_length(network, weight="length"))
+     # Use a float matrix so the fractional np.sqrt(1 / length) weights are not truncated to zero
+     neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=float)
+
+     # Populate the neighborhoods matrix based on Dijkstra's distances
+     for source, targets in all_dijkstra_paths.items():
+         for target, length in targets.items():
+             neighborhoods[source, target] = (
+                 1 if np.isnan(length) or length == 0 else np.sqrt(1 / length)
+             )
+
+     return neighborhoods
+
+
+ def calculate_label_propagation_neighborhoods(network: nx.Graph) -> np.ndarray:
+     """Apply Label Propagation to the network to detect communities.
+
+     Args:
+         network (nx.Graph): The network graph.
+
+     Returns:
+         np.ndarray: Neighborhood matrix based on Label Propagation.
+     """
+     # Apply Label Propagation
+     communities = nx.algorithms.community.label_propagation.label_propagation_communities(network)
+
+     # Create a mapping from node to community
+     community_dict = {}
+     for community_id, community in enumerate(communities):
+         for node in community:
+             community_dict[node] = community_id
+
+     # Create a neighborhood matrix
+     num_nodes = network.number_of_nodes()
+     neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
+
+     # Assign neighborhoods based on community labels
+     for node_i, community_i in community_dict.items():
+         for node_j, community_j in community_dict.items():
+             if community_i == community_j:
+                 neighborhoods[node_i, node_j] = 1
+
+     return neighborhoods
+
+
+ def calculate_louvain_neighborhoods(
+     network: nx.Graph, resolution: float, random_seed: int = 888
+ ) -> np.ndarray:
+     """Calculate neighborhoods using the Louvain method.
+
+     Args:
+         network (nx.Graph): The network graph.
+         resolution (float): Resolution parameter for the Louvain method.
+         random_seed (int, optional): Random seed for reproducibility. Defaults to 888.
+
+     Returns:
+         np.ndarray: Neighborhood matrix based on the Louvain method.
+     """
+     # Apply Louvain method to partition the network
+     partition = community_louvain.best_partition(
+         network, resolution=resolution, random_state=random_seed
+     )
+     neighborhoods = np.zeros((network.number_of_nodes(), network.number_of_nodes()), dtype=int)
+
+     # Assign neighborhoods based on community partitions
+     for node_i, community_i in partition.items():
+         for node_j, community_j in partition.items():
+             if community_i == community_j:
+                 neighborhoods[node_i, node_j] = 1
+
+     return neighborhoods
+
+
+ def calculate_markov_clustering_neighborhoods(network: nx.Graph) -> np.ndarray:
+     """Apply Markov Clustering (MCL) to the network.
+
+     Args:
+         network (nx.Graph): The network graph.
+
+     Returns:
+         np.ndarray: Neighborhood matrix based on Markov Clustering.
+     """
+     # Convert the graph to an adjacency matrix
+     adjacency_matrix = nx.to_numpy_array(network)
+     # Run Markov Clustering
+     result = mc.run_mcl(adjacency_matrix)  # Run MCL with default parameters
+     # Get clusters
+     clusters = mc.get_clusters(result)
+
+     # Create a community label for each node
+     community_dict = {}
+     for community_id, community in enumerate(clusters):
+         for node in community:
+             community_dict[node] = community_id
+
+     # Create a neighborhood matrix
+     num_nodes = network.number_of_nodes()
+     neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
+
+     # Assign neighborhoods based on community labels
+     for node_i, community_i in community_dict.items():
+         for node_j, community_j in community_dict.items():
+             if community_i == community_j:
+                 neighborhoods[node_i, node_j] = 1
+
+     return neighborhoods
+
+
+ def calculate_spinglass_neighborhoods(network: nx.Graph) -> np.ndarray:
+     """Apply Spin Glass Community Detection to the network.
+
+     Args:
+         network (nx.Graph): The network graph.
+
+     Returns:
+         np.ndarray: Neighborhood matrix based on Spin Glass communities.
+     """
+     # Use the asynchronous label propagation algorithm as a proxy for Spin Glass
+     communities = asyn_lpa_communities(network)
+
+     # Create a community label for each node
+     community_dict = {}
+     for community_id, community in enumerate(communities):
+         for node in community:
+             community_dict[node] = community_id
+
+     # Create a neighborhood matrix
+     num_nodes = network.number_of_nodes()
+     neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
+
+     # Assign neighborhoods based on community labels
+     for node_i, community_i in community_dict.items():
+         for node_j, community_j in community_dict.items():
+             if community_i == community_j:
+                 neighborhoods[node_i, node_j] = 1
+
+     return neighborhoods
+
+
+ def calculate_walktrap_neighborhoods(network: nx.Graph) -> np.ndarray:
+     """Apply Walktrap Community Detection to the network.
+
+     Args:
+         network (nx.Graph): The network graph.
+
+     Returns:
+         np.ndarray: Neighborhood matrix based on Walktrap communities.
+     """
+     # Use the asynchronous label propagation algorithm as a proxy for Walktrap
+     communities = asyn_lpa_communities(network)
+
+     # Create a community label for each node
+     community_dict = {}
+     for community_id, community in enumerate(communities):
+         for node in community:
+             community_dict[node] = community_id
+
+     # Create a neighborhood matrix
+     num_nodes = network.number_of_nodes()
+     neighborhoods = np.zeros((num_nodes, num_nodes), dtype=int)
+
+     # Assign neighborhoods based on community labels
+     for node_i, community_i in community_dict.items():
+         for node_j, community_j in community_dict.items():
+             if community_i == community_j:
+                 neighborhoods[node_i, node_j] = 1
+
+     return neighborhoods
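Each of these helpers returns a dense node-by-node matrix and indexes it directly with the node labels, so they assume nodes are labeled 0..n-1. A minimal sketch under that assumption, with the import path inferred from the file layout shown in this diff:

    import networkx as nx
    from risk.neighborhoods.community import calculate_label_propagation_neighborhoods

    G = nx.karate_club_graph()  # nodes are the integers 0..33, matching the 0..n-1 indexing assumption
    neighborhoods = calculate_label_propagation_neighborhoods(G)
    print(neighborhoods.shape)         # (34, 34)
    print(neighborhoods[0].nonzero())  # nodes sharing node 0's community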
risk/neighborhoods/domains.py ADDED
@@ -0,0 +1,257 @@
+ """
+ risk/neighborhoods/domains
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
+ """
+
+ from contextlib import suppress
+ from tqdm import tqdm
+ from typing import Tuple
+
+ import numpy as np
+ import pandas as pd
+ from scipy.cluster.hierarchy import linkage, fcluster
+ from sklearn.metrics import silhouette_score
+
+ from risk.annotations import get_description
+ from risk.constants import GROUP_LINKAGE_METHODS, GROUP_DISTANCE_METRICS
+
+
+ def define_domains(
+     top_annotations: pd.DataFrame,
+     significant_neighborhoods_enrichment: np.ndarray,
+     linkage_criterion: str,
+     linkage_method: str,
+     linkage_metric: str,
+ ) -> pd.DataFrame:
+     """Define domains and assign nodes to these domains based on their enrichment scores and clustering.
+
+     Args:
+         top_annotations (pd.DataFrame): DataFrame of top annotations data for the network nodes.
+         significant_neighborhoods_enrichment (np.ndarray): The binary enrichment matrix below alpha.
+         linkage_criterion (str): The clustering criterion for defining groups.
+         linkage_method (str): The linkage method for clustering.
+         linkage_metric (str): The linkage metric for clustering.
+
+     Returns:
+         pd.DataFrame: DataFrame with the primary domain for each node.
+     """
+     # Perform hierarchical clustering on the binary enrichment matrix
+     m = significant_neighborhoods_enrichment[:, top_annotations["top attributes"]].T
+     best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
+         m, linkage_criterion, linkage_method, linkage_metric
+     )
+     try:
+         Z = linkage(m, method=best_linkage, metric=best_metric)
+     except ValueError as e:
+         raise ValueError("No significant annotations found.") from e
+
+     print(
+         f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'"
+     )
+     print(f"Optimal linkage threshold: {round(best_threshold, 3)}")
+
+     max_d_optimal = np.max(Z[:, 2]) * best_threshold
+     domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
+     # Assign domains to the annotations matrix
+     top_annotations["domain"] = 0
+     top_annotations.loc[top_annotations["top attributes"], "domain"] = domains
+
+     # Create DataFrames to store domain information
+     node_to_enrichment = pd.DataFrame(
+         data=significant_neighborhoods_enrichment,
+         columns=[top_annotations.index.values, top_annotations["domain"]],
+     )
+     node_to_domain = node_to_enrichment.groupby(level="domain", axis=1).sum()
+
+     t_max = node_to_domain.loc[:, 1:].max(axis=1)
+     t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
+     t_idxmax[t_max == 0] = 0
+
+     # Assign primary domain
+     node_to_domain["primary domain"] = t_idxmax
+
+     return node_to_domain
+
+
+ def trim_domains_and_top_annotations(
+     domains: pd.DataFrame,
+     top_annotations: pd.DataFrame,
+     min_cluster_size: int = 5,
+     max_cluster_size: int = 1000,
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+     """Trim domains and top annotations that do not meet size criteria and find outliers.
+
+     Args:
+         domains (pd.DataFrame): DataFrame of domain data for the network nodes.
+         top_annotations (pd.DataFrame): DataFrame of top annotations data for the network nodes.
+         min_cluster_size (int, optional): Minimum size of a cluster to be retained. Defaults to 5.
+         max_cluster_size (int, optional): Maximum size of a cluster to be retained. Defaults to 1000.
+
+     Returns:
+         tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing:
+             - Trimmed annotations (pd.DataFrame)
+             - Trimmed domains (pd.DataFrame)
+             - A DataFrame with domain labels (pd.DataFrame)
+     """
+     # Identify domains to remove based on size criteria
+     domain_counts = domains["primary domain"].value_counts()
+     to_remove = set(
+         domain_counts[(domain_counts < min_cluster_size) | (domain_counts > max_cluster_size)].index
+     )
+
+     # Add invalid domain IDs
+     invalid_domain_id = 888888
+     invalid_domain_ids = {0, invalid_domain_id}
+     # Mark domains to be removed
+     top_annotations["domain"].replace(to_remove, invalid_domain_id, inplace=True)
+     domains.loc[domains["primary domain"].isin(to_remove), ["primary domain"]] = invalid_domain_id
+
+     # Normalize "num enriched neighborhoods" by percentile for each domain and scale to 0-10
+     top_annotations["normalized_value"] = top_annotations.groupby("domain")[
+         "neighborhood enrichment sums"
+     ].transform(lambda x: (x.rank(pct=True) * 10).apply(np.ceil).astype(int))
+     # Multiply 'words' column by normalized values
+     top_annotations["words"] = top_annotations.apply(
+         lambda row: " ".join([row["words"]] * row["normalized_value"]), axis=1
+     )
+
+     # Generate domain labels
+     domain_labels = top_annotations.groupby("domain")["words"].apply(get_description).reset_index()
+     trimmed_domains_matrix = domain_labels.rename(
+         columns={"domain": "id", "words": "label"}
+     ).set_index("id")
+
+     # Remove invalid domains
+     valid_annotations = top_annotations[~top_annotations["domain"].isin(invalid_domain_ids)].drop(
+         columns=["normalized_value"]
+     )
+     valid_domains = domains[~domains["primary domain"].isin(invalid_domain_ids)]
+     valid_trimmed_domains_matrix = trimmed_domains_matrix[
+         ~trimmed_domains_matrix.index.isin(invalid_domain_ids)
+     ]
+
+     return valid_annotations, valid_domains, valid_trimmed_domains_matrix
+
+
+ def _optimize_silhouette_across_linkage_and_metrics(
+     m: np.ndarray, linkage_criterion: str, linkage_method: str, linkage_metric: str
+ ) -> Tuple[str, str, float]:
+     """Optimize silhouette score across different linkage methods and distance metrics.
+
+     Args:
+         m (np.ndarray): Data matrix.
+         linkage_criterion (str): Clustering criterion.
+         linkage_method (str): Linkage method for clustering.
+         linkage_metric (str): Linkage metric for clustering.
+
+     Returns:
+         tuple[str, str, float]: A tuple containing:
+             - Best linkage method (str)
+             - Best linkage metric (str)
+             - Best threshold (float)
+     """
+     best_overall_method = linkage_method
+     best_overall_metric = linkage_metric
+     best_overall_score = -np.inf
+     best_overall_threshold = 1
+
+     linkage_methods = GROUP_LINKAGE_METHODS if linkage_method == "auto" else [linkage_method]
+     linkage_metrics = GROUP_DISTANCE_METRICS if linkage_metric == "auto" else [linkage_metric]
+     total_combinations = len(linkage_methods) * len(linkage_metrics)
+
+     # Evaluating optimal linkage method and metric
+     for method in tqdm(
+         linkage_methods,
+         desc="Evaluating optimal linkage method and metric",
+         total=total_combinations,
+         bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
+     ):
+         for metric in linkage_metrics:
+             with suppress(Exception):
+                 Z = linkage(m, method=method, metric=metric)
+                 threshold, score = _find_best_silhouette_score(Z, m, metric, linkage_criterion)
+                 if score > best_overall_score:
+                     best_overall_score = score
+                     best_overall_threshold = threshold
+                     best_overall_method = method
+                     best_overall_metric = metric
+
+     return best_overall_method, best_overall_metric, best_overall_threshold
+
+
+ def _find_best_silhouette_score(
+     Z: np.ndarray,
+     m: np.ndarray,
+     linkage_metric: str,
+     linkage_criterion: str,
+     lower_bound: float = 0.001,
+     upper_bound: float = 1.0,
+     resolution: float = 0.001,
+ ) -> Tuple[float, float]:
+     """Find the best silhouette score using binary search.
+
+     Args:
+         Z (np.ndarray): Linkage matrix.
+         m (np.ndarray): Data matrix.
+         linkage_metric (str): Linkage metric for silhouette score calculation.
+         linkage_criterion (str): Clustering criterion.
+         lower_bound (float, optional): Lower bound for search. Defaults to 0.001.
+         upper_bound (float, optional): Upper bound for search. Defaults to 1.0.
+         resolution (float, optional): Desired resolution for the best threshold. Defaults to 0.001.
+
+     Returns:
+         tuple[float, float]: A tuple containing:
+             - Best threshold (float): The threshold that yields the best silhouette score.
+             - Best silhouette score (float): The highest silhouette score achieved.
+     """
+     best_score = -np.inf
+     best_threshold = None
+
+     # Test lower bound
+     max_d_lower = np.max(Z[:, 2]) * lower_bound
+     clusters_lower = fcluster(Z, max_d_lower, criterion=linkage_criterion)
+     try:
+         score_lower = silhouette_score(m, clusters_lower, metric=linkage_metric)
+     except ValueError:
+         score_lower = -np.inf
+
+     # Test upper bound
+     max_d_upper = np.max(Z[:, 2]) * upper_bound
+     clusters_upper = fcluster(Z, max_d_upper, criterion=linkage_criterion)
+     try:
+         score_upper = silhouette_score(m, clusters_upper, metric=linkage_metric)
+     except ValueError:
+         score_upper = -np.inf
+
+     # Determine initial bounds for binary search
+     if score_lower > score_upper:
+         best_score = score_lower
+         best_threshold = lower_bound
+         upper_bound = (lower_bound + upper_bound) / 2
+     else:
+         best_score = score_upper
+         best_threshold = upper_bound
+         lower_bound = (lower_bound + upper_bound) / 2
+
+     # Binary search loop
+     while upper_bound - lower_bound > resolution:
+         mid_threshold = (upper_bound + lower_bound) / 2
+         max_d_mid = np.max(Z[:, 2]) * mid_threshold
+         clusters_mid = fcluster(Z, max_d_mid, criterion=linkage_criterion)
+         try:
+             score_mid = silhouette_score(m, clusters_mid, metric=linkage_metric)
+         except ValueError:
+             score_mid = -np.inf
+
+         # Update best score and threshold if mid-point is better
+         if score_mid > best_score:
+             best_score = score_mid
+             best_threshold = mid_threshold
+
+         # Adjust bounds based on the scores
+         if score_lower > score_upper:
+             upper_bound = mid_threshold
+         else:
+             lower_bound = mid_threshold
+
+     return best_threshold, float(best_score)
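For reference, the threshold returned by _find_best_silhouette_score is a fraction of the largest merge distance in the linkage matrix, which define_domains converts back into an fcluster cutoff. A self-contained sketch of that relationship on toy data, using only scipy and scikit-learn rather than the risk package itself, with a fixed threshold standing in for the value the binary search would find:

    import numpy as np
    from scipy.cluster.hierarchy import linkage, fcluster
    from sklearn.datasets import make_blobs
    from sklearn.metrics import silhouette_score

    # Toy "enrichment" matrix: 30 rows with clear group structure
    m, _ = make_blobs(n_samples=30, centers=3, n_features=8, random_state=0)

    Z = linkage(m, method="average", metric="euclidean")
    threshold = 0.5  # stands in for the best_threshold found by the binary search
    max_d = np.max(Z[:, 2]) * threshold  # fraction of the largest merge distance
    clusters = fcluster(Z, max_d, criterion="distance")
    print(silhouette_score(m, clusters, metric="euclidean"))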