risk-network 0.0.12b0__py3-none-any.whl → 0.0.12b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- risk/__init__.py +1 -1
 - risk/annotations/__init__.py +10 -0
 - risk/annotations/annotations.py +354 -0
 - risk/annotations/io.py +241 -0
 - risk/annotations/nltk_setup.py +86 -0
 - risk/log/__init__.py +11 -0
 - risk/log/console.py +141 -0
 - risk/log/parameters.py +171 -0
 - risk/neighborhoods/__init__.py +7 -0
 - risk/neighborhoods/api.py +442 -0
 - risk/neighborhoods/community.py +441 -0
 - risk/neighborhoods/domains.py +360 -0
 - risk/neighborhoods/neighborhoods.py +514 -0
 - risk/neighborhoods/stats/__init__.py +13 -0
 - risk/neighborhoods/stats/permutation/__init__.py +6 -0
 - risk/neighborhoods/stats/permutation/permutation.py +240 -0
 - risk/neighborhoods/stats/permutation/test_functions.py +70 -0
 - risk/neighborhoods/stats/tests.py +275 -0
 - risk/network/__init__.py +4 -0
 - risk/network/graph/__init__.py +4 -0
 - risk/network/graph/api.py +200 -0
 - risk/network/graph/graph.py +268 -0
 - risk/network/graph/stats.py +166 -0
 - risk/network/graph/summary.py +253 -0
 - risk/network/io.py +693 -0
 - risk/network/plotter/__init__.py +4 -0
 - risk/network/plotter/api.py +54 -0
 - risk/network/plotter/canvas.py +291 -0
 - risk/network/plotter/contour.py +329 -0
 - risk/network/plotter/labels.py +935 -0
 - risk/network/plotter/network.py +294 -0
 - risk/network/plotter/plotter.py +141 -0
 - risk/network/plotter/utils/colors.py +419 -0
 - risk/network/plotter/utils/layout.py +94 -0
 - risk_network-0.0.12b1.dist-info/METADATA +122 -0
 - risk_network-0.0.12b1.dist-info/RECORD +40 -0
 - {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b1.dist-info}/WHEEL +1 -1
 - risk_network-0.0.12b0.dist-info/METADATA +0 -796
 - risk_network-0.0.12b0.dist-info/RECORD +0 -7
 - {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b1.dist-info}/licenses/LICENSE +0 -0
 - {risk_network-0.0.12b0.dist-info → risk_network-0.0.12b1.dist-info}/top_level.txt +0 -0
 
| 
         @@ -0,0 +1,360 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            """
         
     | 
| 
      
 2 
     | 
    
         
            +
            risk/neighborhoods/domains
         
     | 
| 
      
 3 
     | 
    
         
            +
            ~~~~~~~~~~~~~~~~~~~~~~~~~~
         
     | 
| 
      
 4 
     | 
    
         
            +
            """
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            from itertools import product
         
     | 
| 
      
 7 
     | 
    
         
            +
            from typing import Tuple, Union
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 10 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 11 
     | 
    
         
            +
            from numpy.linalg import LinAlgError
         
     | 
| 
      
 12 
     | 
    
         
            +
            from scipy.cluster.hierarchy import fcluster, linkage
         
     | 
| 
      
 13 
     | 
    
         
            +
            from sklearn.metrics import silhouette_score
         
     | 
| 
      
 14 
     | 
    
         
            +
            from tqdm import tqdm
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            from risk.annotations import get_weighted_description
         
     | 
| 
      
 17 
     | 
    
         
            +
            from risk.log import logger
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            # Define constants for clustering
         
     | 
| 
      
 20 
     | 
    
         
            +
            # fmt: off
         
     | 
| 
      
 21 
     | 
    
         
            +
            LINKAGE_METHODS = {"single", "complete", "average", "weighted", "centroid", "median", "ward"}
         
     | 
| 
      
 22 
     | 
    
         
            +
            LINKAGE_METRICS = {
         
     | 
| 
      
 23 
     | 
    
         
            +
                "braycurtis", "canberra", "chebyshev", "cityblock", "correlation", "cosine", "dice", "euclidean",
         
     | 
| 
      
 24 
     | 
    
         
            +
                "hamming", "jaccard", "jensenshannon", "kulczynski1", "mahalanobis", "matching", "minkowski",
         
     | 
| 
      
 25 
     | 
    
         
            +
                "rogerstanimoto", "russellrao", "seuclidean", "sokalmichener", "sokalsneath", "sqeuclidean", "yule",
         
     | 
| 
      
 26 
     | 
    
         
            +
            }
         
     | 
| 
      
 27 
     | 
    
         
            +
            # fmt: on
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
             
     | 
| 
      
 30 
     | 
    
         
            +
            def define_domains(
         
     | 
| 
      
 31 
     | 
    
         
            +
                top_annotations: pd.DataFrame,
         
     | 
| 
      
 32 
     | 
    
         
            +
                significant_neighborhoods_significance: np.ndarray,
         
     | 
| 
      
 33 
     | 
    
         
            +
                linkage_criterion: str,
         
     | 
| 
      
 34 
     | 
    
         
            +
                linkage_method: str,
         
     | 
| 
      
 35 
     | 
    
         
            +
                linkage_metric: str,
         
     | 
| 
      
 36 
     | 
    
         
            +
                linkage_threshold: Union[float, str],
         
     | 
| 
      
 37 
     | 
    
         
            +
            ) -> pd.DataFrame:
         
     | 
| 
      
 38 
     | 
    
         
            +
                """Define domains and assign nodes to these domains based on their significance scores and clustering,
         
     | 
| 
      
 39 
     | 
    
         
            +
                handling errors by assigning unique domains when clustering fails.
         
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 42 
     | 
    
         
            +
                    top_annotations (pd.DataFrame): DataFrame of top annotations data for the network nodes.
         
     | 
| 
      
 43 
     | 
    
         
            +
                    significant_neighborhoods_significance (np.ndarray): The binary significance matrix below alpha.
         
     | 
| 
      
 44 
     | 
    
         
            +
                    linkage_criterion (str): The clustering criterion for defining groups. Choose "off" to disable clustering.
         
     | 
| 
      
 45 
     | 
    
         
            +
                    linkage_method (str): The linkage method for clustering. Choose "auto" to optimize.
         
     | 
| 
      
 46 
     | 
    
         
            +
                    linkage_metric (str): The linkage metric for clustering. Choose "auto" to optimize.
         
     | 
| 
      
 47 
     | 
    
         
            +
                    linkage_threshold (float, str): The threshold for clustering. Choose "auto" to optimize.
         
     | 
| 
      
 48 
     | 
    
         
            +
             
     | 
| 
      
 49 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 50 
     | 
    
         
            +
                    pd.DataFrame: DataFrame with the primary domain for each node.
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
                Raises:
         
     | 
| 
      
 53 
     | 
    
         
            +
                    ValueError: If the clustering criterion is set to "off" or if an error occurs during clustering.
         
     | 
| 
      
 54 
     | 
    
         
            +
                """
         
     | 
| 
      
 55 
     | 
    
         
            +
                try:
         
     | 
| 
      
 56 
     | 
    
         
            +
                    if linkage_criterion == "off":
         
     | 
| 
      
 57 
     | 
    
         
            +
                        raise ValueError("Clustering is turned off.")
         
     | 
| 
      
 58 
     | 
    
         
            +
             
     | 
| 
      
 59 
     | 
    
         
            +
                    # Transpose the matrix to cluster annotations
         
     | 
| 
      
 60 
     | 
    
         
            +
                    m = significant_neighborhoods_significance[:, top_annotations["significant_annotations"]].T
         
     | 
| 
      
 61 
     | 
    
         
            +
                    # Safeguard the matrix by replacing NaN, Inf, and -Inf values
         
     | 
| 
      
 62 
     | 
    
         
            +
                    m = _safeguard_matrix(m)
         
     | 
| 
      
 63 
     | 
    
         
            +
                    # Optimize silhouette score across different linkage methods and distance metrics
         
     | 
| 
      
 64 
     | 
    
         
            +
                    best_linkage, best_metric, best_threshold = _optimize_silhouette_across_linkage_and_metrics(
         
     | 
| 
      
 65 
     | 
    
         
            +
                        m, linkage_criterion, linkage_method, linkage_metric, linkage_threshold
         
     | 
| 
      
 66 
     | 
    
         
            +
                    )
         
     | 
| 
      
 67 
     | 
    
         
            +
                    # Perform hierarchical clustering
         
     | 
| 
      
 68 
     | 
    
         
            +
                    Z = linkage(m, method=best_linkage, metric=best_metric)
         
     | 
| 
      
 69 
     | 
    
         
            +
                    logger.warning(
         
     | 
| 
      
 70 
     | 
    
         
            +
                        f"Linkage criterion: '{linkage_criterion}'\nLinkage method: '{best_linkage}'\nLinkage metric: '{best_metric}'\nLinkage threshold: {round(best_threshold, 3)}"
         
     | 
| 
      
 71 
     | 
    
         
            +
                    )
         
     | 
| 
      
 72 
     | 
    
         
            +
                    # Calculate the optimal threshold for clustering
         
     | 
| 
      
 73 
     | 
    
         
            +
                    max_d_optimal = np.max(Z[:, 2]) * best_threshold
         
     | 
| 
      
 74 
     | 
    
         
            +
                    # Assign domains to the annotations matrix
         
     | 
| 
      
 75 
     | 
    
         
            +
                    domains = fcluster(Z, max_d_optimal, criterion=linkage_criterion)
         
     | 
| 
      
 76 
     | 
    
         
            +
                    top_annotations["domain"] = 0
         
     | 
| 
      
 77 
     | 
    
         
            +
                    top_annotations.loc[top_annotations["significant_annotations"], "domain"] = domains
         
     | 
| 
      
 78 
     | 
    
         
            +
                except (ValueError, LinAlgError):
         
     | 
| 
      
 79 
     | 
    
         
            +
                    # If a ValueError is encountered, handle it by assigning unique domains
         
     | 
| 
      
 80 
     | 
    
         
            +
                    n_rows = len(top_annotations)
         
     | 
| 
      
 81 
     | 
    
         
            +
                    if linkage_criterion == "off":
         
     | 
| 
      
 82 
     | 
    
         
            +
                        logger.warning(
         
     | 
| 
      
 83 
     | 
    
         
            +
                            f"Clustering is turned off. Skipping clustering and assigning {n_rows} unique domains."
         
     | 
| 
      
 84 
     | 
    
         
            +
                        )
         
     | 
| 
      
 85 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 86 
     | 
    
         
            +
                        logger.error(
         
     | 
| 
      
 87 
     | 
    
         
            +
                            f"Error encountered. Skipping clustering and assigning {n_rows} unique domains."
         
     | 
| 
      
 88 
     | 
    
         
            +
                        )
         
     | 
| 
      
 89 
     | 
    
         
            +
                    top_annotations["domain"] = range(1, n_rows + 1)  # Assign unique domains
         
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
                # Create DataFrames to store domain information
         
     | 
| 
      
 92 
     | 
    
         
            +
                node_to_significance = pd.DataFrame(
         
     | 
| 
      
 93 
     | 
    
         
            +
                    data=significant_neighborhoods_significance,
         
     | 
| 
      
 94 
     | 
    
         
            +
                    columns=[top_annotations.index.values, top_annotations["domain"]],
         
     | 
| 
      
 95 
     | 
    
         
            +
                )
         
     | 
| 
      
 96 
     | 
    
         
            +
                node_to_domain = node_to_significance.T.groupby(level="domain").sum().T
         
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
                # Find the maximum significance score for each node
         
     | 
| 
      
 99 
     | 
    
         
            +
                t_max = node_to_domain.loc[:, 1:].max(axis=1)
         
     | 
| 
      
 100 
     | 
    
         
            +
                t_idxmax = node_to_domain.loc[:, 1:].idxmax(axis=1)
         
     | 
| 
      
 101 
     | 
    
         
            +
                t_idxmax[t_max == 0] = 0
         
     | 
| 
      
 102 
     | 
    
         
            +
             
     | 
| 
      
 103 
     | 
    
         
            +
                # Assign all domains where the score is greater than 0
         
     | 
| 
      
 104 
     | 
    
         
            +
                node_to_domain["all_domains"] = node_to_domain.loc[:, 1:].apply(
         
     | 
| 
      
 105 
     | 
    
         
            +
                    lambda row: list(row[row > 0].index), axis=1
         
     | 
| 
      
 106 
     | 
    
         
            +
                )
         
     | 
| 
      
 107 
     | 
    
         
            +
                # Assign primary domain
         
     | 
| 
      
 108 
     | 
    
         
            +
                node_to_domain["primary_domain"] = t_idxmax
         
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
                return node_to_domain
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
             
     | 
| 
      
 113 
     | 
    
         
            +
            def trim_domains(
         
     | 
| 
      
 114 
     | 
    
         
            +
                domains: pd.DataFrame,
         
     | 
| 
      
 115 
     | 
    
         
            +
                top_annotations: pd.DataFrame,
         
     | 
| 
      
 116 
     | 
    
         
            +
                min_cluster_size: int = 5,
         
     | 
| 
      
 117 
     | 
    
         
            +
                max_cluster_size: int = 1000,
         
     | 
| 
      
 118 
     | 
    
         
            +
            ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         
     | 
| 
      
 119 
     | 
    
         
            +
                """Trim domains that do not meet size criteria and find outliers.
         
     | 
| 
      
 120 
     | 
    
         
            +
             
     | 
| 
      
 121 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 122 
     | 
    
         
            +
                    domains (pd.DataFrame): DataFrame of domain data for the network nodes.
         
     | 
| 
      
 123 
     | 
    
         
            +
                    top_annotations (pd.DataFrame): DataFrame of top annotations data for the network nodes.
         
     | 
| 
      
 124 
     | 
    
         
            +
                    min_cluster_size (int, optional): Minimum size of a cluster to be retained. Defaults to 5.
         
     | 
| 
      
 125 
     | 
    
         
            +
                    max_cluster_size (int, optional): Maximum size of a cluster to be retained. Defaults to 1000.
         
     | 
| 
      
 126 
     | 
    
         
            +
             
     | 
| 
      
 127 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 128 
     | 
    
         
            +
                    Tuple[pd.DataFrame, pd.DataFrame]:
         
     | 
| 
      
 129 
     | 
    
         
            +
                        - Trimmed domains (pd.DataFrame)
         
     | 
| 
      
 130 
     | 
    
         
            +
                        - A DataFrame with domain labels (pd.DataFrame)
         
     | 
| 
      
 131 
     | 
    
         
            +
                """
         
     | 
| 
      
 132 
     | 
    
         
            +
                # Identify domains to remove based on size criteria
         
     | 
| 
      
 133 
     | 
    
         
            +
                domain_counts = domains["primary_domain"].value_counts()
         
     | 
| 
      
 134 
     | 
    
         
            +
                to_remove = set(
         
     | 
| 
      
 135 
     | 
    
         
            +
                    domain_counts[(domain_counts < min_cluster_size) | (domain_counts > max_cluster_size)].index
         
     | 
| 
      
 136 
     | 
    
         
            +
                )
         
     | 
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
                # Add invalid domain IDs
         
     | 
| 
      
 139 
     | 
    
         
            +
                invalid_domain_id = 888888
         
     | 
| 
      
 140 
     | 
    
         
            +
                invalid_domain_ids = {0, invalid_domain_id}
         
     | 
| 
      
 141 
     | 
    
         
            +
                # Mark domains to be removed
         
     | 
| 
      
 142 
     | 
    
         
            +
                top_annotations["domain"] = top_annotations["domain"].replace(to_remove, invalid_domain_id)
         
     | 
| 
      
 143 
     | 
    
         
            +
                domains.loc[domains["primary_domain"].isin(to_remove), ["primary_domain"]] = invalid_domain_id
         
     | 
| 
      
 144 
     | 
    
         
            +
             
     | 
| 
      
 145 
     | 
    
         
            +
                # Normalize "num significant neighborhoods" by percentile for each domain and scale to 0-10
         
     | 
| 
      
 146 
     | 
    
         
            +
                top_annotations["normalized_value"] = top_annotations.groupby("domain")[
         
     | 
| 
      
 147 
     | 
    
         
            +
                    "significant_neighborhood_significance_sums"
         
     | 
| 
      
 148 
     | 
    
         
            +
                ].transform(lambda x: (x.rank(pct=True) * 10).apply(np.ceil).astype(int))
         
     | 
| 
      
 149 
     | 
    
         
            +
                # Modify the lambda function to pass both full_terms and significant_significance_score
         
     | 
| 
      
 150 
     | 
    
         
            +
                top_annotations["combined_terms"] = top_annotations.apply(
         
     | 
| 
      
 151 
     | 
    
         
            +
                    lambda row: " ".join([str(row["full_terms"])] * row["normalized_value"]), axis=1
         
     | 
| 
      
 152 
     | 
    
         
            +
                )
         
     | 
| 
      
 153 
     | 
    
         
            +
             
     | 
| 
      
 154 
     | 
    
         
            +
                # Perform the groupby operation while retaining the other columns and adding the weighting with significance scores
         
     | 
| 
      
 155 
     | 
    
         
            +
                domain_labels = (
         
     | 
| 
      
 156 
     | 
    
         
            +
                    top_annotations.groupby("domain")
         
     | 
| 
      
 157 
     | 
    
         
            +
                    .agg(
         
     | 
| 
      
 158 
     | 
    
         
            +
                        full_terms=("full_terms", lambda x: list(x)),
         
     | 
| 
      
 159 
     | 
    
         
            +
                        significance_scores=("significant_significance_score", lambda x: list(x)),
         
     | 
| 
      
 160 
     | 
    
         
            +
                    )
         
     | 
| 
      
 161 
     | 
    
         
            +
                    .reset_index()
         
     | 
| 
      
 162 
     | 
    
         
            +
                )
         
     | 
| 
      
 163 
     | 
    
         
            +
                domain_labels["combined_terms"] = domain_labels.apply(
         
     | 
| 
      
 164 
     | 
    
         
            +
                    lambda row: get_weighted_description(
         
     | 
| 
      
 165 
     | 
    
         
            +
                        pd.Series(row["full_terms"]), pd.Series(row["significance_scores"])
         
     | 
| 
      
 166 
     | 
    
         
            +
                    ),
         
     | 
| 
      
 167 
     | 
    
         
            +
                    axis=1,
         
     | 
| 
      
 168 
     | 
    
         
            +
                )
         
     | 
| 
      
 169 
     | 
    
         
            +
             
     | 
| 
      
 170 
     | 
    
         
            +
                # Rename the columns as necessary
         
     | 
| 
      
 171 
     | 
    
         
            +
                trimmed_domains_matrix = domain_labels.rename(
         
     | 
| 
      
 172 
     | 
    
         
            +
                    columns={
         
     | 
| 
      
 173 
     | 
    
         
            +
                        "domain": "id",
         
     | 
| 
      
 174 
     | 
    
         
            +
                        "combined_terms": "normalized_description",
         
     | 
| 
      
 175 
     | 
    
         
            +
                        "full_terms": "full_descriptions",
         
     | 
| 
      
 176 
     | 
    
         
            +
                        "significance_scores": "significance_scores",
         
     | 
| 
      
 177 
     | 
    
         
            +
                    }
         
     | 
| 
      
 178 
     | 
    
         
            +
                ).set_index("id")
         
     | 
| 
      
 179 
     | 
    
         
            +
             
     | 
| 
      
 180 
     | 
    
         
            +
                # Remove invalid domains
         
     | 
| 
      
 181 
     | 
    
         
            +
                valid_domains = domains[~domains["primary_domain"].isin(invalid_domain_ids)]
         
     | 
| 
      
 182 
     | 
    
         
            +
                valid_trimmed_domains_matrix = trimmed_domains_matrix[
         
     | 
| 
      
 183 
     | 
    
         
            +
                    ~trimmed_domains_matrix.index.isin(invalid_domain_ids)
         
     | 
| 
      
 184 
     | 
    
         
            +
                ]
         
     | 
| 
      
 185 
     | 
    
         
            +
                return valid_domains, valid_trimmed_domains_matrix
         
     | 
| 
      
 186 
     | 
    
         
            +
             
     | 
| 
      
 187 
     | 
    
         
            +
             
     | 
| 
      
 188 
     | 
    
         
            +
            def _safeguard_matrix(matrix: np.ndarray) -> np.ndarray:
         
     | 
| 
      
 189 
     | 
    
         
            +
                """Safeguard the matrix by replacing NaN, Inf, and -Inf values.
         
     | 
| 
      
 190 
     | 
    
         
            +
             
     | 
| 
      
 191 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 192 
     | 
    
         
            +
                    matrix (np.ndarray): Data matrix.
         
     | 
| 
      
 193 
     | 
    
         
            +
             
     | 
| 
      
 194 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 195 
     | 
    
         
            +
                    np.ndarray: Safeguarded data matrix.
         
     | 
| 
      
 196 
     | 
    
         
            +
                """
         
     | 
| 
      
 197 
     | 
    
         
            +
                # Replace NaN with column mean
         
     | 
| 
      
 198 
     | 
    
         
            +
                nan_replacement = np.nanmean(matrix, axis=0)
         
     | 
| 
      
 199 
     | 
    
         
            +
                matrix = np.where(np.isnan(matrix), nan_replacement, matrix)
         
     | 
| 
      
 200 
     | 
    
         
            +
                # Replace Inf/-Inf with maximum/minimum finite values
         
     | 
| 
      
 201 
     | 
    
         
            +
                finite_max = np.nanmax(matrix[np.isfinite(matrix)])
         
     | 
| 
      
 202 
     | 
    
         
            +
                finite_min = np.nanmin(matrix[np.isfinite(matrix)])
         
     | 
| 
      
 203 
     | 
    
         
            +
                matrix = np.where(np.isposinf(matrix), finite_max, matrix)
         
     | 
| 
      
 204 
     | 
    
         
            +
                matrix = np.where(np.isneginf(matrix), finite_min, matrix)
         
     | 
| 
      
 205 
     | 
    
         
            +
                # Ensure rows have non-zero variance (optional step)
         
     | 
| 
      
 206 
     | 
    
         
            +
                row_variance = np.var(matrix, axis=1)
         
     | 
| 
      
 207 
     | 
    
         
            +
                matrix = matrix[row_variance > 0]
         
     | 
| 
      
 208 
     | 
    
         
            +
                return matrix
         
     | 
| 
      
 209 
     | 
    
         
            +
             
     | 
| 
      
 210 
     | 
    
         
            +
             
     | 
| 
      
 211 
     | 
    
         
            +
            def _optimize_silhouette_across_linkage_and_metrics(
         
     | 
| 
      
 212 
     | 
    
         
            +
                m: np.ndarray,
         
     | 
| 
      
 213 
     | 
    
         
            +
                linkage_criterion: str,
         
     | 
| 
      
 214 
     | 
    
         
            +
                linkage_method: str,
         
     | 
| 
      
 215 
     | 
    
         
            +
                linkage_metric: str,
         
     | 
| 
      
 216 
     | 
    
         
            +
                linkage_threshold: Union[str, float],
         
     | 
| 
      
 217 
     | 
    
         
            +
            ) -> Tuple[str, str, float]:
         
     | 
| 
      
 218 
     | 
    
         
            +
                """Optimize silhouette score across different linkage methods and distance metrics.
         
     | 
| 
      
 219 
     | 
    
         
            +
             
     | 
| 
      
 220 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 221 
     | 
    
         
            +
                    m (np.ndarray): Data matrix.
         
     | 
| 
      
 222 
     | 
    
         
            +
                    linkage_criterion (str): Clustering criterion.
         
     | 
| 
      
 223 
     | 
    
         
            +
                    linkage_method (str): Linkage method for clustering. Choose "auto" to optimize.
         
     | 
| 
      
 224 
     | 
    
         
            +
                    linkage_metric (str): Linkage metric for clustering. Choose "auto" to optimize.
         
     | 
| 
      
 225 
     | 
    
         
            +
                    linkage_threshold (Union[str, float]): Threshold for clustering. Choose "auto" to optimize.
         
     | 
| 
      
 226 
     | 
    
         
            +
             
     | 
| 
      
 227 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 228 
     | 
    
         
            +
                    Tuple[str, str, float]:
         
     | 
| 
      
 229 
     | 
    
         
            +
                        - Best linkage method (str)
         
     | 
| 
      
 230 
     | 
    
         
            +
                        - Best linkage metric (str)
         
     | 
| 
      
 231 
     | 
    
         
            +
                        - Best threshold (float)
         
     | 
| 
      
 232 
     | 
    
         
            +
                """
         
     | 
| 
      
 233 
     | 
    
         
            +
                # Initialize best overall values
         
     | 
| 
      
 234 
     | 
    
         
            +
                best_overall_method = linkage_method
         
     | 
| 
      
 235 
     | 
    
         
            +
                best_overall_metric = linkage_metric
         
     | 
| 
      
 236 
     | 
    
         
            +
                best_overall_threshold = linkage_threshold
         
     | 
| 
      
 237 
     | 
    
         
            +
                best_overall_score = -np.inf
         
     | 
| 
      
 238 
     | 
    
         
            +
             
     | 
| 
      
 239 
     | 
    
         
            +
                # Set linkage methods and metrics to all combinations if "auto" is selected
         
     | 
| 
      
 240 
     | 
    
         
            +
                linkage_methods = LINKAGE_METHODS if linkage_method == "auto" else [linkage_method]
         
     | 
| 
      
 241 
     | 
    
         
            +
                linkage_metrics = LINKAGE_METRICS if linkage_metric == "auto" else [linkage_metric]
         
     | 
| 
      
 242 
     | 
    
         
            +
                total_combinations = len(linkage_methods) * len(linkage_metrics)
         
     | 
| 
      
 243 
     | 
    
         
            +
             
     | 
| 
      
 244 
     | 
    
         
            +
                # Evaluating optimal linkage method and metric
         
     | 
| 
      
 245 
     | 
    
         
            +
                for method, metric in tqdm(
         
     | 
| 
      
 246 
     | 
    
         
            +
                    product(linkage_methods, linkage_metrics),
         
     | 
| 
      
 247 
     | 
    
         
            +
                    desc="Evaluating linkage methods and metrics",
         
     | 
| 
      
 248 
     | 
    
         
            +
                    total=total_combinations,
         
     | 
| 
      
 249 
     | 
    
         
            +
                    bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
         
     | 
| 
      
 250 
     | 
    
         
            +
                ):
         
     | 
| 
      
 251 
     | 
    
         
            +
                    # Some linkage methods and metrics may not work with certain data
         
     | 
| 
      
 252 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 253 
     | 
    
         
            +
                        Z = linkage(m, method=method, metric=metric)
         
     | 
| 
      
 254 
     | 
    
         
            +
                        if linkage_threshold == "auto":
         
     | 
| 
      
 255 
     | 
    
         
            +
                            try:
         
     | 
| 
      
 256 
     | 
    
         
            +
                                threshold, score = _find_best_silhouette_score(Z, m, metric, linkage_criterion)
         
     | 
| 
      
 257 
     | 
    
         
            +
                            except (ValueError, LinAlgError):
         
     | 
| 
      
 258 
     | 
    
         
            +
                                continue  # Skip to the next combination
         
     | 
| 
      
 259 
     | 
    
         
            +
                            current_threshold = threshold
         
     | 
| 
      
 260 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 261 
     | 
    
         
            +
                            score = silhouette_score(
         
     | 
| 
      
 262 
     | 
    
         
            +
                                m,
         
     | 
| 
      
 263 
     | 
    
         
            +
                                fcluster(Z, linkage_threshold * np.max(Z[:, 2]), criterion=linkage_criterion),
         
     | 
| 
      
 264 
     | 
    
         
            +
                                metric=metric,
         
     | 
| 
      
 265 
     | 
    
         
            +
                            )
         
     | 
| 
      
 266 
     | 
    
         
            +
                            current_threshold = linkage_threshold
         
     | 
| 
      
 267 
     | 
    
         
            +
                    except (ValueError, LinAlgError):
         
     | 
| 
      
 268 
     | 
    
         
            +
                        continue  # Skip to the next combination
         
     | 
| 
      
 269 
     | 
    
         
            +
             
     | 
| 
      
 270 
     | 
    
         
            +
                    if score > best_overall_score:
         
     | 
| 
      
 271 
     | 
    
         
            +
                        best_overall_score = score
         
     | 
| 
      
 272 
     | 
    
         
            +
                        best_overall_threshold = float(current_threshold)  # Ensure it's a float
         
     | 
| 
      
 273 
     | 
    
         
            +
                        best_overall_method = method
         
     | 
| 
      
 274 
     | 
    
         
            +
                        best_overall_metric = metric
         
     | 
| 
      
 275 
     | 
    
         
            +
             
     | 
| 
      
 276 
     | 
    
         
            +
                # Ensure that we always return a valid tuple:
         
     | 
| 
      
 277 
     | 
    
         
            +
                if best_overall_score == -np.inf:
         
     | 
| 
      
 278 
     | 
    
         
            +
                    # No valid linkage was found; return default values.
         
     | 
| 
      
 279 
     | 
    
         
            +
                    best_overall_threshold = float(linkage_threshold) if linkage_threshold != "auto" else 0.0
         
     | 
| 
      
 280 
     | 
    
         
            +
                    best_overall_method = linkage_method
         
     | 
| 
      
 281 
     | 
    
         
            +
                    best_overall_metric = linkage_metric
         
     | 
| 
      
 282 
     | 
    
         
            +
             
     | 
| 
      
 283 
     | 
    
         
            +
                return best_overall_method, best_overall_metric, best_overall_threshold
         
     | 
| 
      
 284 
     | 
    
         
            +
             
     | 
| 
      
 285 
     | 
    
         
            +
             
     | 
| 
      
 286 
     | 
    
         
            +
            def _find_best_silhouette_score(
                Z: np.ndarray,
                m: np.ndarray,
                linkage_metric: str,
                linkage_criterion: str,
                lower_bound: float = 0.001,
                upper_bound: float = 1.0,
            ) -> Tuple[float, float]:
                """Find the best silhouette score using an adaptive bisection over the threshold.

                Thresholds are fractions of the largest merge distance in the linkage matrix
                (``np.max(Z[:, 2])``). Both bounds are scored first. Each iteration then scores
                the midpoint and replaces the worse-scoring endpoint with it, so the interval
                narrows toward the better-scoring side. The best threshold seen at any evaluated
                point is returned. This is a heuristic search and is not guaranteed to find the
                global optimum.

                Args:
                    Z (np.ndarray): Linkage matrix.
                    m (np.ndarray): Data matrix.
                    linkage_metric (str): Linkage metric for silhouette score calculation.
                    linkage_criterion (str): Clustering criterion passed to ``fcluster``.
                    lower_bound (float, optional): Lower bound for search. Defaults to 0.001.
                    upper_bound (float, optional): Upper bound for search. Defaults to 1.0.

                Returns:
                    Tuple[float, float]:
                        - Best threshold (float): The threshold that yields the best silhouette score.
                        - Best silhouette score (float): The highest silhouette score achieved, or
                          ``-inf`` if no evaluated threshold produced a valid silhouette score.
                """
                # Stop once the search interval is narrower than this.
                minimum_linkage_threshold = 1e-6
                # Loop-invariant: every threshold is scaled by the largest merge distance.
                max_merge_distance = np.max(Z[:, 2])

                def _score_at(threshold: float) -> float:
                    """Return the silhouette score at ``threshold``, or -inf if it cannot be computed."""
                    clusters = fcluster(Z, max_merge_distance * threshold, criterion=linkage_criterion)
                    try:
                        return silhouette_score(m, clusters, metric=linkage_metric)
                    except ValueError:
                        # silhouette_score rejects degenerate labelings (e.g. a single cluster,
                        # or one cluster per sample); treat them as the worst possible score.
                        return -np.inf

                # Score both ends of the search interval; the better one seeds the best result.
                score_lower = _score_at(lower_bound)
                score_upper = _score_at(upper_bound)
                if score_lower > score_upper:
                    best_threshold, best_score = lower_bound, score_lower
                else:
                    best_threshold, best_score = upper_bound, score_upper

                while upper_bound - lower_bound > minimum_linkage_threshold:
                    mid_threshold = (lower_bound + upper_bound) / 2
                    score_mid = _score_at(mid_threshold)

                    # Track the best threshold among all evaluated points.
                    if score_mid > best_score:
                        best_threshold, best_score = mid_threshold, score_mid

                    # Move the worse-scoring endpoint to the midpoint and carry its score along,
                    # so the next comparison reflects the current interval rather than the
                    # initial bounds.
                    if score_lower > score_upper:
                        upper_bound, score_upper = mid_threshold, score_mid
                    else:
                        lower_bound, score_lower = mid_threshold, score_mid

                return best_threshold, float(best_score)
         
     |