workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff reflects the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/fingerprint_proximity.py
@@ -1,162 +1,498 @@
 import pandas as pd
 import numpy as np
 from sklearn.neighbors import NearestNeighbors
-from typing import Union, List
+from typing import Union, List, Optional
 import logging
 
 # Workbench Imports
-from workbench.algorithms.dataframe.proximity import Proximity
+from workbench.algorithms.dataframe.proximity import Proximity
+from workbench.algorithms.dataframe.projection_2d import Projection2D
+from workbench.utils.chem_utils.fingerprints import compute_morgan_fingerprints
 
 # Set up logging
 log = logging.getLogger("workbench")
 
 
 class FingerprintProximity(Proximity):
+    """Proximity computations for binary fingerprints using Tanimoto similarity.
+
+    Note: Tanimoto similarity is equivalent to Jaccard similarity for binary vectors.
+    Tanimoto(A, B) = |A ∩ B| / |A ∪ B|
+    """
+
     def __init__(
-        self,
+        self,
+        df: pd.DataFrame,
+        id_column: str,
+        fingerprint_column: Optional[str] = None,
+        target: Optional[str] = None,
+        include_all_columns: bool = False,
+        radius: int = 2,
+        n_bits: int = 1024,
     ) -> None:
         """
         Initialize the FingerprintProximity class for binary fingerprint similarity.
 
         Args:
-            df
-            id_column
-            fingerprint_column
-
+            df: DataFrame containing fingerprints or SMILES.
+            id_column: Name of the column used as an identifier.
+            fingerprint_column: Name of the column containing fingerprints (bit strings).
+                If None, looks for existing "fingerprint" column or computes from SMILES.
+            target: Name of the target column. Defaults to None.
+            include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
+            radius: Radius for Morgan fingerprint computation (default: 2).
+            n_bits: Number of bits for fingerprint (default: 1024).
+        """
+        # Store fingerprint computation parameters
+        self._fp_radius = radius
+        self._fp_n_bits = n_bits
+
+        # Store the requested fingerprint column (may be None)
+        self._fingerprint_column_arg = fingerprint_column
+
+        # Determine fingerprint column name (but don't compute yet - that happens in _prepare_data)
+        self.fingerprint_column = self._resolve_fingerprint_column_name(df, fingerprint_column)
+
+        # Call parent constructor with fingerprint_column as the only "feature"
+        super().__init__(
+            df,
+            id_column=id_column,
+            features=[self.fingerprint_column],
+            target=target,
+            include_all_columns=include_all_columns,
+        )
+
+    @staticmethod
+    def _resolve_fingerprint_column_name(df: pd.DataFrame, fingerprint_column: Optional[str]) -> str:
+        """
+        Determine the fingerprint column name, validating it exists or can be computed.
+
+        Args:
+            df: Input DataFrame.
+            fingerprint_column: Explicitly specified fingerprint column, or None.
+
+        Returns:
+            Name of the fingerprint column to use.
+
+        Raises:
+            ValueError: If no fingerprint column exists and no SMILES column found.
         """
-
+        # If explicitly provided, validate it exists
+        if fingerprint_column is not None:
+            if fingerprint_column not in df.columns:
+                raise ValueError(f"Fingerprint column '{fingerprint_column}' not found in DataFrame")
+            return fingerprint_column
+
+        # Check for existing "fingerprint" column
+        if "fingerprint" in df.columns:
+            log.info("Using existing 'fingerprint' column")
+            return "fingerprint"
+
+        # Will need to compute from SMILES - validate SMILES column exists
+        smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
+        if smiles_column is None:
+            raise ValueError(
+                "No fingerprint column provided and no SMILES column found. "
+                "Either provide a fingerprint_column or include a 'smiles' column in the DataFrame."
+            )
 
-        #
-
+        # Fingerprints will be computed in _prepare_data
+        return "fingerprint"
 
-
-
+    def _prepare_data(self) -> None:
+        """Compute fingerprints from SMILES if needed."""
+        # If fingerprint column doesn't exist yet, compute it
+        if self.fingerprint_column not in self.df.columns:
+            log.info(f"Computing Morgan fingerprints (radius={self._fp_radius}, n_bits={self._fp_n_bits})...")
+            self.df = compute_morgan_fingerprints(self.df, radius=self._fp_radius, n_bits=self._fp_n_bits)
+
+    def _build_model(self) -> None:
         """
-
-
+        Build the fingerprint proximity model for Tanimoto similarity.
+
+        For binary fingerprints: uses Jaccard distance (1 - Tanimoto)
+        For count fingerprints: uses weighted Tanimoto (Ruzicka) distance
         """
-
-        self.
+        # Convert fingerprint strings to matrix and detect format
+        self.X, self._is_count_fp = self._fingerprints_to_matrix(self.df)
 
-
+        if self._is_count_fp:
+            # Weighted Tanimoto (Ruzicka) for count vectors: 1 - Σmin(A,B)/Σmax(A,B)
+            log.info("Building NearestNeighbors model (weighted Tanimoto for count fingerprints)...")
 
-
-
-
-
+            def ruzicka_distance(a, b):
+                """Ruzicka distance = 1 - weighted Tanimoto similarity."""
+                min_sum = np.minimum(a, b).sum()
+                max_sum = np.maximum(a, b).sum()
+                if max_sum == 0:
+                    return 0.0
+                return 1.0 - (min_sum / max_sum)
 
-
-
-
+            self.nn = NearestNeighbors(metric=ruzicka_distance, algorithm="ball_tree").fit(self.X)
+        else:
+            # Standard Jaccard for binary fingerprints
+            log.info("Building NearestNeighbors model (Jaccard/Tanimoto for binary fingerprints)...")
+            self.nn = NearestNeighbors(metric="jaccard", algorithm="ball_tree").fit(self.X)
 
-
-    def prep_features_for_query(self, query_df: pd.DataFrame) -> np.ndarray:
+    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
         """
-
+        Transform fingerprints to matrix for querying.
 
         Args:
-
+            df: DataFrame containing fingerprints to transform.
 
         Returns:
-
+            Feature matrix for the fingerprints (binary or count based on self._is_count_fp).
         """
-
-
-        )
-        return np.vstack(fingerprint_bits)
+        matrix, _ = self._fingerprints_to_matrix(df)
+        return matrix
 
-    def
-        self,
-        min_similarity: float = None,
-        include_self: bool = False,
-        add_columns: List[str] = None,
-    ) -> pd.DataFrame:
+    def _fingerprints_to_matrix(self, df: pd.DataFrame) -> tuple[np.ndarray, bool]:
         """
-
+        Convert fingerprint strings to a numpy matrix.
+
+        Supports two formats (auto-detected):
+        - Bitstrings: "10110010..." → binary matrix (bool), is_count=False
+        - Count vectors: "0,3,0,1,5,..." → count matrix (uint8), is_count=True
 
         Args:
-
-            include_self: Whether to include self in results
-            add_columns: Additional columns to include in results
+            df: DataFrame containing fingerprint column.
 
         Returns:
-
+            Tuple of (2D numpy array, is_count_fingerprint boolean)
         """
+        # Auto-detect format based on first fingerprint
+        sample = str(df[self.fingerprint_column].iloc[0])
+        if "," in sample:
+            # Count vector format: preserve counts for weighted Tanimoto
+            fingerprint_values = df[self.fingerprint_column].apply(
+                lambda fp: np.array([int(x) for x in fp.split(",")], dtype=np.uint8)
+            )
+            return np.vstack(fingerprint_values), True
+        else:
+            # Bitstring format: binary values
+            fingerprint_bits = df[self.fingerprint_column].apply(
+                lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
+            )
+            return np.vstack(fingerprint_bits), False
 
-
-
-
-
-
-
+    def _precompute_metrics(self) -> None:
+        """Precompute metrics, adding Tanimoto similarity alongside distance."""
+        # Call parent to compute nn_distance (Jaccard), nn_id, nn_target, nn_target_diff
+        super()._precompute_metrics()
+
+        # Add Tanimoto similarity (keep nn_distance for internal use by target_gradients)
+        self.df["nn_similarity"] = 1 - self.df["nn_distance"]
+
+    def _set_core_columns(self) -> None:
+        """Set core columns using nn_similarity instead of nn_distance."""
+        self.core_columns = [self.id_column, "nn_similarity", "nn_id"]
+        if self.target:
+            self.core_columns.extend([self.target, "nn_target", "nn_target_diff"])
+
+    def _project_2d(self) -> None:
+        """Project the fingerprint matrix to 2D for visualization using UMAP."""
+        if self._is_count_fp:
+            # For count fingerprints, convert to binary for UMAP projection (Jaccard needs binary)
+            X_binary = (self.X > 0).astype(np.bool_)
+            self.df = Projection2D().fit_transform(self.df, feature_matrix=X_binary, metric="jaccard")
+        else:
+            self.df = Projection2D().fit_transform(self.df, feature_matrix=self.X, metric="jaccard")
+
+    def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
+        """
+        Find isolated data points based on Tanimoto similarity to nearest neighbor.
+
+        Args:
+            top_percent: Percentage of most isolated data points to return (e.g., 1.0 returns top 1%)
+
+        Returns:
+            DataFrame of observations with lowest Tanimoto similarity, sorted ascending
+        """
+        # For Tanimoto similarity, isolated means LOW similarity to nearest neighbor
+        percentile = top_percent
+        threshold = np.percentile(self.df["nn_similarity"], percentile)
+        isolated = self.df[self.df["nn_similarity"] <= threshold].copy()
+        isolated = isolated.sort_values("nn_similarity", ascending=True).reset_index(drop=True)
+        return isolated if self.include_all_columns else isolated[self.core_columns]
+
+    def proximity_stats(self) -> pd.DataFrame:
+        """
+        Return distribution statistics for nearest neighbor Tanimoto similarity.
+
+        Returns:
+            DataFrame with similarity distribution statistics (count, mean, std, percentiles)
+        """
+        return (
+            self.df["nn_similarity"]
+            .describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+            .to_frame()
         )
 
     def neighbors(
         self,
-
-
-
-
+        id_or_ids: Union[str, int, List[Union[str, int]]],
+        n_neighbors: Optional[int] = 5,
+        min_similarity: Optional[float] = None,
+        include_self: bool = True,
     ) -> pd.DataFrame:
         """
-
+        Return neighbors for ID(s) from the existing dataset.
 
         Args:
-
-
-
-
+            id_or_ids: Single ID or list of IDs to look up
+            n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
+            min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)
+            include_self: Whether to include self in results (default: True)
 
         Returns:
-            DataFrame containing neighbors
-
-        Note: The query DataFrame must include the feature columns. The id_column is optional.
+            DataFrame containing neighbors with Tanimoto similarity scores
         """
-
-        # Calculate radius from similarity if provided
+        # Convert min_similarity to radius (Jaccard distance = 1 - Tanimoto similarity)
         radius = 1 - min_similarity if min_similarity is not None else None
 
-        # Call
+        # Call parent method (returns Jaccard distance)
         neighbors_df = super().neighbors(
-
+            id_or_ids=id_or_ids,
+            n_neighbors=n_neighbors,
             radius=radius,
             include_self=include_self,
-            add_columns=add_columns,
         )
 
-        # Convert
+        # Convert Jaccard distance to Tanimoto similarity
        neighbors_df["similarity"] = 1 - neighbors_df["distance"]
        neighbors_df.drop(columns=["distance"], inplace=True)
+
        return neighbors_df
 
+    def neighbors_from_smiles(
+        self,
+        smiles: Union[str, List[str]],
+        n_neighbors: int = 5,
+        min_similarity: Optional[float] = None,
+    ) -> pd.DataFrame:
+        """
+        Find neighbors for SMILES strings not in the reference dataset.
+
+        Args:
+            smiles: Single SMILES string or list of SMILES to query
+            n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
+            min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)
+
+        Returns:
+            DataFrame containing neighbors with Tanimoto similarity scores.
+            The 'query_id' column contains the SMILES string (or index if list).
+        """
+        # Normalize to list
+        smiles_list = [smiles] if isinstance(smiles, str) else smiles
+
+        # Build a temporary DataFrame with the query SMILES
+        query_df = pd.DataFrame({"smiles": smiles_list})
+
+        # Compute fingerprints using same parameters as the reference dataset
+        query_df = compute_morgan_fingerprints(query_df, radius=self._fp_radius, n_bits=self._fp_n_bits)
+
+        # Transform to matrix (use same format detection as reference)
+        X_query, _ = self._fingerprints_to_matrix(query_df)
+
+        # Query the model
+        if min_similarity is not None:
+            radius = 1 - min_similarity
+            distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
+        else:
+            distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)
+
+        # Build results
+        results = []
+        for i, (dists, nbrs) in enumerate(zip(distances, indices)):
+            query_id = smiles_list[i]
+
+            for neighbor_idx, dist in zip(nbrs, dists):
+                neighbor_row = self.df.iloc[neighbor_idx]
+                neighbor_id = neighbor_row[self.id_column]
+                similarity = 1.0 - dist if dist > 1e-6 else 1.0
+
+                result = {
+                    "query_id": query_id,
+                    "neighbor_id": neighbor_id,
+                    "similarity": similarity,
+                }
+
+                # Add target if present
+                if self.target and self.target in self.df.columns:
+                    result[self.target] = neighbor_row[self.target]
+
+                # Include all columns if requested
+                if self.include_all_columns:
+                    for col in self.df.columns:
+                        if col not in [self.id_column, "query_id", "neighbor_id", "similarity"]:
+                            result[f"neighbor_{col}"] = neighbor_row[col]
+
+                results.append(result)
+
+        df_results = pd.DataFrame(results)
+
+        # Sort by query_id then similarity descending
+        if len(df_results) > 0:
+            df_results = df_results.sort_values(["query_id", "similarity"], ascending=[True, False]).reset_index(
+                drop=True
+            )
+
+        return df_results
+
 
 # Testing the FingerprintProximity class
 if __name__ == "__main__":
     pd.set_option("display.max_columns", None)
     pd.set_option("display.width", 1000)
 
-    # Example DataFrame
+    # Create an Example DataFrame with fingerprints
     data = {
-        "id": ["a", "b", "c", "d"],
-        "fingerprint": ["101010", "111010", "101110", "011100"],
+        "id": ["a", "b", "c", "d", "e"],
+        "fingerprint": ["101010", "111010", "101110", "011100", "000111"],
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
     }
     df = pd.DataFrame(data)
 
-    #
-
+    # Test basic FingerprintProximity with explicit fingerprint column
+    prox = FingerprintProximity(df, fingerprint_column="fingerprint", id_column="id", target="target")
+    print(prox.neighbors("a", n_neighbors=3))
+
+    # Test neighbors with similarity threshold
+    print(prox.neighbors("a", min_similarity=0.5))
+
+    # Test with include_all_columns=True
+    prox = FingerprintProximity(
+        df,
+        fingerprint_column="fingerprint",
+        id_column="id",
+        target="target",
+        include_all_columns=True,
+    )
+    print(prox.neighbors(["a", "b"]))
+
+    # Regression test: include_all_columns should not break neighbor sorting
+    print("\n" + "=" * 80)
+    print("Regression test: include_all_columns neighbor sorting...")
+    print("=" * 80)
+    neighbors_all_cols = prox.neighbors("a", n_neighbors=4)
+    # Verify neighbors are sorted by similarity (descending), not alphabetically by neighbor_id
+    similarities = neighbors_all_cols["similarity"].tolist()
+    assert similarities == sorted(
+        similarities, reverse=True
+    ), f"Neighbors not sorted by similarity! Got: {similarities}"
+    # Verify query_id column has correct value (the query, not the neighbor)
+    assert all(
+        neighbors_all_cols["id"] == "a"
+    ), f"Query ID column corrupted! Expected all 'a', got: {neighbors_all_cols['id'].tolist()}"
+    print("PASSED: Neighbors correctly sorted by similarity with include_all_columns=True")
+
+    # Test neighbors_from_smiles with synthetic data
+    print("\n" + "=" * 80)
+    print("Testing neighbors_from_smiles...")
+    print("=" * 80)
+
+    # Create reference dataset with known SMILES
+    ref_data = {
+        "id": ["aspirin", "ibuprofen", "naproxen", "caffeine", "ethanol"],
+        "smiles": [
+            "CC(=O)OC1=CC=CC=C1C(=O)O",  # aspirin
+            "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # ibuprofen
+            "COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1",  # naproxen
+            "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # caffeine
+            "CCO",  # ethanol
+        ],
+        "activity": [1.0, 2.0, 2.5, 3.0, 0.5],
+    }
+    ref_df = pd.DataFrame(ref_data)
+
+    prox_ref = FingerprintProximity(ref_df, id_column="id", target="activity", radius=2, n_bits=1024)
+
+    # Query with a single SMILES (acetaminophen - similar to aspirin)
+    query_smiles = "CC(=O)NC1=CC=C(C=C1)O"  # acetaminophen
+    print(f"\nQuery: acetaminophen ({query_smiles})")
+    neighbors = prox_ref.neighbors_from_smiles(query_smiles, n_neighbors=3)
+    print(neighbors)
+
+    # Query with multiple SMILES
+    print("\nQuery: multiple SMILES (theophylline, methanol)")
+    multi_query = [
+        "CN1C=NC2=C1C(=O)NC(=O)N2",  # theophylline - similar to caffeine
+        "CO",  # methanol - similar to ethanol
+    ]
+    neighbors_multi = prox_ref.neighbors_from_smiles(multi_query, n_neighbors=2)
+    print(neighbors_multi)
+
+    # Test with min_similarity threshold
+    print("\nQuery with min_similarity=0.3:")
+    neighbors_thresh = prox_ref.neighbors_from_smiles(query_smiles, min_similarity=0.3)
+    print(neighbors_thresh)
+
+    print("PASSED: neighbors_from_smiles working correctly")
+
+    # Test on real data from Workbench
+    from workbench.api import FeatureSet, Model
+
+    fs = FeatureSet("aqsol_features")
+    model = Model("aqsol-regression")
+    df = fs.pull_dataframe()[:1000]  # Limit to 1000 for testing
+    prox = FingerprintProximity(df, id_column=fs.id_column, target=model.target())
+
+    print("\n" + "=" * 80)
+    print("Testing Neighbors...")
+    print("=" * 80)
+    test_id = df[fs.id_column].tolist()[0]
+    print(f"\nNeighbors for ID {test_id}:")
+    print(prox.neighbors(test_id))
+
+    print("\n" + "=" * 80)
+    print("Testing isolated compounds...")
+    print("=" * 80)
+
+    # Test isolated data in the top 1%
+    isolated_1pct = prox.isolated(top_percent=1.0)
+    print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+    print(isolated_1pct)
+
+    # Test isolated data in the top 5%
+    isolated_5pct = prox.isolated(top_percent=5.0)
+    print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+    print(isolated_5pct)
+
+    print("\n" + "=" * 80)
+    print("Testing target_gradients...")
+    print("=" * 80)
+
+    # Test with different parameters
+    gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+    print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
+    print(gradients_1pct)
+
+    gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+    print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+    print(gradients_5pct)
+
+    # Test proximity_stats
+    print("\n" + "=" * 80)
+    print("Testing proximity_stats...")
+    print("=" * 80)
+    stats = prox.proximity_stats()
+    print(stats)
 
-    #
-    print("\n
-
-    print(
+    # Plot the similarity distribution using pandas
+    print("\n" + "=" * 80)
+    print("Plotting similarity distribution...")
+    print("=" * 80)
+    prox.df["nn_similarity"].hist(bins=50, figsize=(10, 6), edgecolor="black")
 
-    #
-    print("\n
-
-
-
+    # Visualize the 2D projection
+    print("\n" + "=" * 80)
+    print("Visualizing 2D Projection...")
+    print("=" * 80)
+    from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
+    from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
 
-
-
-    query_neighbors_sim_df = proximity.neighbors(query_df=query_df, min_similarity=0.5)
-    print(query_neighbors_sim_df)
+    unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
+    unit_test.run()