workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
- workbench/algorithms/dataframe/projection_2d.py +8 -2
- workbench/algorithms/dataframe/proximity.py +3 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/feature_set.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +238 -225
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/transforms/features_to_model/features_to_model.py +2 -8
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +196 -68
- workbench/model_scripts/chemprop/generated_model_script.py +197 -72
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/scripts/ml_pipeline_sqs.py +71 -2
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -64
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +82 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/model_utils.py +0 -1
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +52 -36
- workbench/utils/theme_manager.py +95 -30
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +184 -85
- workbench/web_interface/components/settings_menu.py +185 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
workbench/__init__.py
CHANGED
|
@@ -9,10 +9,12 @@ from .proximity import Proximity
|
|
|
9
9
|
from .feature_space_proximity import FeatureSpaceProximity
|
|
10
10
|
from .fingerprint_proximity import FingerprintProximity
|
|
11
11
|
from .projection_2d import Projection2D
|
|
12
|
+
from .smart_aggregator import smart_aggregator
|
|
12
13
|
|
|
13
14
|
__all__ = [
|
|
14
15
|
"Proximity",
|
|
15
16
|
"FeatureSpaceProximity",
|
|
16
17
|
"FingerprintProximity",
|
|
17
18
|
"Projection2D",
|
|
19
|
+
"smart_aggregator",
|
|
18
20
|
]
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""Compound Dataset Overlap Analysis
|
|
2
|
+
|
|
3
|
+
This module provides utilities for comparing two molecular datasets based on
|
|
4
|
+
Tanimoto similarity in fingerprint space. It helps quantify the "overlap"
|
|
5
|
+
between datasets in chemical space.
|
|
6
|
+
|
|
7
|
+
Use cases:
|
|
8
|
+
- Train/test split validation: Ensure test set isn't too similar to training
|
|
9
|
+
- Dataset comparison: Compare proprietary vs public datasets
|
|
10
|
+
- Novelty assessment: Find compounds in query dataset that are novel vs reference
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
from typing import Optional, Tuple
|
|
15
|
+
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity
|
|
19
|
+
|
|
20
|
+
# Set up logging
|
|
21
|
+
log = logging.getLogger("workbench")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CompoundDatasetOverlap:
    """Compare two molecular datasets using Tanimoto similarity.

    Builds a FingerprintProximity model on the reference dataset, then queries
    it with the SMILES from the query dataset to find the nearest neighbor in
    the reference for each query compound. This guarantees cross-dataset matches.

    Attributes:
        prox: FingerprintProximity instance on reference dataset
        overlap_df: Results DataFrame with one row per query compound and the columns
            ``id``, ``smiles``, ``nearest_neighbor_id``, ``tanimoto_similarity`` and
            ``nearest_neighbor_smiles``, sorted by similarity (descending)
    """

    # Explicit column layout so that an empty query set still yields a well-formed DataFrame
    _RESULT_COLUMNS = ["id", "smiles", "nearest_neighbor_id", "tanimoto_similarity"]

    def __init__(
        self,
        df_reference: pd.DataFrame,
        df_query: pd.DataFrame,
        id_column_reference: str = "id",
        id_column_query: str = "id",
        radius: int = 2,
        n_bits: int = 2048,
    ) -> None:
        """
        Initialize the CompoundDatasetOverlap analysis.

        Args:
            df_reference: Reference dataset (DataFrame with SMILES)
            df_query: Query dataset (DataFrame with SMILES)
            id_column_reference: ID column name in df_reference
            id_column_query: ID column name in df_query
            radius: Morgan fingerprint radius (default: 2 = ECFP4)
            n_bits: Number of fingerprint bits (default: 2048)

        Raises:
            ValueError: If either dataset has no SMILES column
        """
        self.id_column_reference = id_column_reference
        self.id_column_query = id_column_query
        self._radius = radius
        self._n_bits = n_bits

        # Store copies of the dataframes so the caller's data is never modified
        self.df_reference = df_reference.copy()
        self.df_query = df_query.copy()

        # Find SMILES columns
        self._smiles_col_reference = self._find_smiles_column(self.df_reference)
        self._smiles_col_query = self._find_smiles_column(self.df_query)

        if self._smiles_col_reference is None:
            raise ValueError("Reference dataset must have a SMILES column")
        if self._smiles_col_query is None:
            raise ValueError("Query dataset must have a SMILES column")

        log.info(f"Reference dataset: {len(self.df_reference)} compounds")
        log.info(f"Query dataset: {len(self.df_query)} compounds")

        # Build FingerprintProximity on reference dataset only
        self.prox = FingerprintProximity(
            self.df_reference,
            id_column=id_column_reference,
            radius=radius,
            n_bits=n_bits,
        )

        # Compute cross-dataset overlap
        self.overlap_df = self._compute_cross_dataset_overlap()

    @staticmethod
    def _find_smiles_column(df: pd.DataFrame) -> Optional[str]:
        """Find the SMILES column in a DataFrame (case-insensitive).

        Args:
            df: DataFrame to search

        Returns:
            The first column label equal to "smiles" (ignoring case), or None if absent
        """
        for col in df.columns:
            # Column labels are not guaranteed to be strings (e.g. integer labels)
            if isinstance(col, str) and col.lower() == "smiles":
                return col
        return None

    def _compute_cross_dataset_overlap(self) -> pd.DataFrame:
        """For each query compound, find nearest neighbor in reference using neighbors_from_smiles.

        Returns:
            DataFrame with one row per query compound, sorted by Tanimoto similarity (descending).
            Queries without a neighbor row get ``nearest_neighbor_id=None`` and similarity 0.0.
        """
        log.info(f"Computing nearest neighbors in reference for {len(self.df_query)} query compounds")

        # Get SMILES and ID lists from query dataset
        query_smiles = self.df_query[self._smiles_col_query].tolist()
        query_ids = self.df_query[self.id_column_query].tolist()

        # Query all compounds against reference (get only nearest neighbor);
        # with nothing to query we skip the model call entirely
        if query_smiles:
            neighbors_df = self.prox.neighbors_from_smiles(query_smiles, n_neighbors=1)
        else:
            neighbors_df = pd.DataFrame()

        # Index the first (best) neighbor row per query SMILES once, instead of
        # re-scanning the whole neighbor table for every query compound
        best_hits = {}
        if "query_id" in neighbors_df.columns:
            first_rows = neighbors_df.drop_duplicates(subset="query_id", keep="first")
            best_hits = {
                smi: (nbr_id, sim)
                for smi, nbr_id, sim in zip(
                    first_rows["query_id"], first_rows["neighbor_id"], first_rows["similarity"]
                )
            }

        # Build results with query IDs
        records = []
        for q_id, q_smi in zip(query_ids, query_smiles):
            # Should always be found, but fall back gracefully when it is not
            neighbor_id, similarity = best_hits.get(q_smi, (None, 0.0))
            records.append(
                {
                    "id": q_id,
                    "smiles": q_smi,
                    "nearest_neighbor_id": neighbor_id,
                    "tanimoto_similarity": similarity,
                }
            )
        result_df = pd.DataFrame(records, columns=self._RESULT_COLUMNS)

        # Add nearest neighbor SMILES from reference (first occurrence wins for
        # duplicated reference IDs; Series.map needs a uniquely valued index)
        ref_smiles_map = self.df_reference.drop_duplicates(subset=self.id_column_reference, keep="first").set_index(
            self.id_column_reference
        )[self._smiles_col_reference]
        result_df["nearest_neighbor_smiles"] = result_df["nearest_neighbor_id"].map(ref_smiles_map)

        return result_df.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)

    def summary_stats(self) -> pd.DataFrame:
        """Return distribution statistics for nearest-neighbor Tanimoto similarities."""
        return (
            self.overlap_df["tanimoto_similarity"]
            .describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
            .to_frame()
        )

    def novel_compounds(self, threshold: float = 0.4) -> pd.DataFrame:
        """Return query compounds that are novel (low similarity to reference).

        Args:
            threshold: Maximum Tanimoto similarity to consider "novel" (default: 0.4)

        Returns:
            DataFrame of query compounds with similarity below threshold (least similar first)
        """
        novel = self.overlap_df[self.overlap_df["tanimoto_similarity"] < threshold].copy()
        return novel.sort_values("tanimoto_similarity", ascending=True).reset_index(drop=True)

    def similar_compounds(self, threshold: float = 0.7) -> pd.DataFrame:
        """Return query compounds that are similar to reference (high overlap).

        Args:
            threshold: Minimum Tanimoto similarity to consider "similar" (default: 0.7)

        Returns:
            DataFrame of query compounds with similarity at or above threshold (most similar first)
        """
        similar = self.overlap_df[self.overlap_df["tanimoto_similarity"] >= threshold].copy()
        return similar.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)

    def overlap_fraction(self, threshold: float = 0.7) -> float:
        """Return fraction of query compounds that overlap with reference above similarity threshold.

        Args:
            threshold: Minimum Tanimoto similarity to consider "overlapping"

        Returns:
            Fraction of query compounds with nearest neighbor similarity >= threshold
            (0.0 when there are no query compounds)
        """
        total = len(self.overlap_df)
        if total == 0:
            # Avoid a division by zero for an empty query dataset
            return 0.0
        n_overlapping = (self.overlap_df["tanimoto_similarity"] >= threshold).sum()
        return float(n_overlapping / total)

    def plot_histogram(
        self,
        bins: int = 50,
        figsize: Tuple[int, int] = (10, 6),
        novel_threshold: float = 0.4,
        similar_threshold: float = 0.7,
    ) -> None:
        """Plot histogram of nearest-neighbor Tanimoto similarities.

        Args:
            bins: Number of histogram bins
            figsize: Figure size (width, height)
            novel_threshold: Similarity at which the "novel" marker line is drawn (default: 0.4)
            similar_threshold: Similarity at which the "similar" marker line is drawn (default: 0.7)
        """
        # Imported lazily so matplotlib is only needed when plotting
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=figsize)
        ax.hist(self.overlap_df["tanimoto_similarity"], bins=bins, edgecolor="black", alpha=0.7)
        ax.set_xlabel("Tanimoto Similarity (query → nearest in reference)")
        ax.set_ylabel("Count")
        ax.set_title(f"Dataset Overlap: {len(self.overlap_df)} query compounds")
        ax.axvline(x=novel_threshold, color="red", linestyle="--", label=f"Novel threshold ({novel_threshold})")
        ax.axvline(
            x=similar_threshold, color="green", linestyle="--", label=f"Similar threshold ({similar_threshold})"
        )
        ax.legend()

        # Add summary stats as text
        stats = self.overlap_df["tanimoto_similarity"]
        textstr = f"Mean: {stats.mean():.3f}\nMedian: {stats.median():.3f}\nStd: {stats.std():.3f}"
        ax.text(
            0.02,
            0.98,
            textstr,
            transform=ax.transAxes,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

        plt.tight_layout()
        plt.show()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# =============================================================================
|
|
220
|
+
# Testing
|
|
221
|
+
# =============================================================================
|
|
222
|
+
if __name__ == "__main__":
    # Smoke test for CompoundDatasetOverlap: a small synthetic run, followed by an
    # optional run against a Workbench FeatureSet when one is reachable.
    print("=" * 80)
    print("Testing CompoundDatasetOverlap")
    print("=" * 80)

    # Test 1: Basic functionality with SMILES data
    print("\n1. Testing with SMILES data...")

    # Reference dataset: Known drug-like compounds
    reference_data = {
        "id": ["aspirin", "caffeine", "glucose", "ibuprofen", "naproxen", "ethanol", "methanol", "propanol"],
        "smiles": [
            "CC(=O)OC1=CC=CC=C1C(=O)O",  # aspirin
            "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # caffeine
            "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # glucose
            "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # ibuprofen
            "COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1",  # naproxen
            "CCO",  # ethanol
            "CO",  # methanol
            "CCCO",  # propanol
        ],
    }

    # Query dataset: Compounds to compare against reference
    query_data = {
        "id": ["acetaminophen", "theophylline", "benzene", "toluene", "phenol", "aniline"],
        "smiles": [
            "CC(=O)NC1=CC=C(C=C1)O",  # acetaminophen - similar to aspirin
            "CN1C=NC2=C1C(=O)NC(=O)N2",  # theophylline - similar to caffeine
            "c1ccccc1",  # benzene - simple aromatic
            "Cc1ccccc1",  # toluene - similar to benzene
            "Oc1ccccc1",  # phenol - hydroxyl benzene
            "Nc1ccccc1",  # aniline - amino benzene
        ],
    }

    df_reference = pd.DataFrame(reference_data)
    df_query = pd.DataFrame(query_data)

    print(f"   Reference: {len(df_reference)} compounds, Query: {len(df_query)} compounds")

    overlap = CompoundDatasetOverlap(
        df_reference, df_query, id_column_reference="id", id_column_query="id", radius=2, n_bits=1024
    )

    print("\n   Overlap results:")
    print(overlap.overlap_df[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))

    print("\n   Summary statistics:")
    print(overlap.summary_stats())

    # Test 2: Novel and similar compound identification
    print("\n2. Testing novel/similar compound identification...")

    similar = overlap.similar_compounds(threshold=0.3)
    print(f"   Similar compounds (sim >= 0.3): {len(similar)}")
    if len(similar) > 0:
        print(similar[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))

    novel = overlap.novel_compounds(threshold=0.3)
    print(f"\n   Novel compounds (sim < 0.3): {len(novel)}")
    if len(novel) > 0:
        print(novel[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))

    # Test 3: With Workbench data (if available)
    print("\n3. Testing with Workbench FeatureSet (if available)...")

    try:
        from workbench.api import FeatureSet

        fs = FeatureSet("aqsol_features")
        full_df = fs.pull_dataframe()[:1000]  # Limit to first 1000 for testing

        # Split into reference and query sets
        df_reference = full_df.sample(frac=0.8, random_state=42)
        df_query = full_df.drop(df_reference.index)

        print(f"   Reference set: {len(df_reference)} compounds")
        print(f"   Query set: {len(df_query)} compounds")

        overlap = CompoundDatasetOverlap(
            df_reference, df_query, id_column_reference=fs.id_column, id_column_query=fs.id_column
        )

        print("\n   Summary statistics:")
        print(overlap.summary_stats())

        print(f"\n   Overlap fraction (sim >= 0.7): {overlap.overlap_fraction(0.7):.2%}")
        print(f"   Overlap fraction (sim >= 0.5): {overlap.overlap_fraction(0.5):.2%}")
        print(f"   Novel compounds (sim < 0.4): {len(overlap.novel_compounds(0.4))}")

        # Uncomment to show histogram (plt.show() blocks until the window is closed,
        # so it is kept off for unattended runs)
        # overlap.plot_histogram()

    except Exception as e:
        # Best-effort: this section needs a reachable Workbench FeatureSet
        print(f"   Skipping Workbench test: {e}")

    print("\n" + "=" * 80)
    print("✅ All CompoundDatasetOverlap tests completed!")
    print("=" * 80)
|
|
@@ -29,7 +29,6 @@ class FingerprintProximity(Proximity):
|
|
|
29
29
|
include_all_columns: bool = False,
|
|
30
30
|
radius: int = 2,
|
|
31
31
|
n_bits: int = 1024,
|
|
32
|
-
counts: bool = False,
|
|
33
32
|
) -> None:
|
|
34
33
|
"""
|
|
35
34
|
Initialize the FingerprintProximity class for binary fingerprint similarity.
|
|
@@ -43,12 +42,10 @@ class FingerprintProximity(Proximity):
|
|
|
43
42
|
include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
|
|
44
43
|
radius: Radius for Morgan fingerprint computation (default: 2).
|
|
45
44
|
n_bits: Number of bits for fingerprint (default: 1024).
|
|
46
|
-
counts: Whether to use count simulation (default: False).
|
|
47
45
|
"""
|
|
48
46
|
# Store fingerprint computation parameters
|
|
49
47
|
self._fp_radius = radius
|
|
50
48
|
self._fp_n_bits = n_bits
|
|
51
|
-
self._fp_counts = counts
|
|
52
49
|
|
|
53
50
|
# Store the requested fingerprint column (may be None)
|
|
54
51
|
self._fingerprint_column_arg = fingerprint_column
|
|
@@ -107,54 +104,77 @@ class FingerprintProximity(Proximity):
|
|
|
107
104
|
# If fingerprint column doesn't exist yet, compute it
|
|
108
105
|
if self.fingerprint_column not in self.df.columns:
|
|
109
106
|
log.info(f"Computing Morgan fingerprints (radius={self._fp_radius}, n_bits={self._fp_n_bits})...")
|
|
110
|
-
self.df = compute_morgan_fingerprints(
|
|
111
|
-
self.df, radius=self._fp_radius, n_bits=self._fp_n_bits, counts=self._fp_counts
|
|
112
|
-
)
|
|
107
|
+
self.df = compute_morgan_fingerprints(self.df, radius=self._fp_radius, n_bits=self._fp_n_bits)
|
|
113
108
|
|
|
114
109
|
def _build_model(self) -> None:
    """
    Fit the nearest-neighbor model used for Tanimoto similarity lookups.

    The fingerprint column is converted to a matrix first; its detected format
    decides the distance:
      - binary fingerprints: Jaccard distance (1 - Tanimoto)
      - count fingerprints: Ruzicka distance (1 - weighted Tanimoto)
    """
    # Matrix form of the fingerprints plus the auto-detected format flag
    self.X, self._is_count_fp = self._fingerprints_to_matrix(self.df)

    if not self._is_count_fp:
        # Standard Jaccard for binary fingerprints
        log.info("Building NearestNeighbors model (Jaccard/Tanimoto for binary fingerprints)...")
        self.nn = NearestNeighbors(metric="jaccard", algorithm="ball_tree").fit(self.X)
        return

    def ruzicka_distance(a, b):
        """Ruzicka distance = 1 - Σmin(A,B)/Σmax(A,B) (weighted Tanimoto similarity)."""
        shared = np.minimum(a, b).sum()
        total = np.maximum(a, b).sum()
        # Two all-zero vectors are treated as identical
        return 0.0 if total == 0 else 1.0 - (shared / total)

    log.info("Building NearestNeighbors model (weighted Tanimoto for count fingerprints)...")
    self.nn = NearestNeighbors(metric=ruzicka_distance, algorithm="ball_tree").fit(self.X)
|
|
131
136
|
|
|
132
137
|
def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
    """
    Convert the fingerprints in a DataFrame into a query matrix.

    Args:
        df: DataFrame containing fingerprints to transform.

    Returns:
        The fingerprint matrix produced by ``_fingerprints_to_matrix`` (the
        format flag it also returns is discarded here).
    """
    return self._fingerprints_to_matrix(df)[0]
|
|
143
149
|
|
|
144
|
-
def _fingerprints_to_matrix(self, df: pd.DataFrame) -> np.ndarray:
|
|
150
|
+
def _fingerprints_to_matrix(self, df: pd.DataFrame) -> tuple[np.ndarray, bool]:
|
|
145
151
|
"""
|
|
146
|
-
Convert fingerprint strings to a
|
|
152
|
+
Convert fingerprint strings to a numpy matrix.
|
|
153
|
+
|
|
154
|
+
Supports two formats (auto-detected):
|
|
155
|
+
- Bitstrings: "10110010..." → binary matrix (bool), is_count=False
|
|
156
|
+
- Count vectors: "0,3,0,1,5,..." → count matrix (uint8), is_count=True
|
|
147
157
|
|
|
148
158
|
Args:
|
|
149
159
|
df: DataFrame containing fingerprint column.
|
|
150
160
|
|
|
151
161
|
Returns:
|
|
152
|
-
2D numpy array
|
|
162
|
+
Tuple of (2D numpy array, is_count_fingerprint boolean)
|
|
153
163
|
"""
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
164
|
+
# Auto-detect format based on first fingerprint
|
|
165
|
+
sample = str(df[self.fingerprint_column].iloc[0])
|
|
166
|
+
if "," in sample:
|
|
167
|
+
# Count vector format: preserve counts for weighted Tanimoto
|
|
168
|
+
fingerprint_values = df[self.fingerprint_column].apply(
|
|
169
|
+
lambda fp: np.array([int(x) for x in fp.split(",")], dtype=np.uint8)
|
|
170
|
+
)
|
|
171
|
+
return np.vstack(fingerprint_values), True
|
|
172
|
+
else:
|
|
173
|
+
# Bitstring format: binary values
|
|
174
|
+
fingerprint_bits = df[self.fingerprint_column].apply(
|
|
175
|
+
lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
|
|
176
|
+
)
|
|
177
|
+
return np.vstack(fingerprint_bits), False
|
|
158
178
|
|
|
159
179
|
def _precompute_metrics(self) -> None:
|
|
160
180
|
"""Precompute metrics, adding Tanimoto similarity alongside distance."""
|
|
@@ -171,8 +191,13 @@ class FingerprintProximity(Proximity):
|
|
|
171
191
|
self.core_columns.extend([self.target, "nn_target", "nn_target_diff"])
|
|
172
192
|
|
|
173
193
|
def _project_2d(self) -> None:
    """Add a 2D projection of the fingerprint matrix to ``self.df`` for visualization.

    The projection is always requested with the Jaccard metric, which needs
    binary input, so count fingerprints are first reduced to presence/absence.
    """
    features = (self.X > 0).astype(np.bool_) if self._is_count_fp else self.X
    self.df = Projection2D().fit_transform(self.df, feature_matrix=features, metric="jaccard")
|
|
176
201
|
|
|
177
202
|
def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
|
|
178
203
|
"""
|
|
@@ -240,6 +265,81 @@ class FingerprintProximity(Proximity):
|
|
|
240
265
|
|
|
241
266
|
return neighbors_df
|
|
242
267
|
|
|
268
|
+
def neighbors_from_smiles(
    self,
    smiles: Union[str, List[str]],
    n_neighbors: int = 5,
    min_similarity: Optional[float] = None,
) -> pd.DataFrame:
    """
    Find neighbors for SMILES strings not in the reference dataset.

    Args:
        smiles: Single SMILES string or list (any iterable) of SMILES to query
        n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
        min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)

    Returns:
        DataFrame containing neighbors with Tanimoto similarity scores, sorted by
        'query_id' and then by descending similarity. The 'query_id' column holds
        the query SMILES string. When there is nothing to report, an empty
        DataFrame with the 'query_id', 'neighbor_id' and 'similarity' columns is
        returned.
    """
    base_columns = ["query_id", "neighbor_id", "similarity"]

    # Normalize to a plain list (also accepts tuples, Series and other iterables)
    smiles_list = [smiles] if isinstance(smiles, str) else list(smiles)
    if not smiles_list:
        # Nothing to query: avoid failing downstream on an empty fingerprint frame
        return pd.DataFrame(columns=base_columns)

    # Build a temporary DataFrame with the query SMILES
    query_df = pd.DataFrame({"smiles": smiles_list})

    # Compute fingerprints using same parameters as the reference dataset
    query_df = compute_morgan_fingerprints(query_df, radius=self._fp_radius, n_bits=self._fp_n_bits)
    if len(query_df) == 0:
        return pd.DataFrame(columns=base_columns)

    # Keep result labels aligned with the matrix rows. If the fingerprint step
    # returned fewer rows than were submitted, map the surviving rows back through
    # their index labels (assumes the original positional index is preserved --
    # TODO confirm against compute_morgan_fingerprints).
    if len(query_df) == len(smiles_list):
        row_labels = smiles_list
    else:
        row_labels = [smiles_list[pos] for pos in query_df.index]

    # Transform to matrix (use same format detection as reference)
    X_query, _ = self._fingerprints_to_matrix(query_df)

    # Query the model
    if min_similarity is not None:
        # Distance = 1 - similarity, so a similarity floor is a distance radius
        distances, indices = self.nn.radius_neighbors(X_query, radius=1 - min_similarity)
    else:
        distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)

    # Loop invariants: whether to report the target, and which extra columns to copy
    has_target = bool(self.target) and self.target in self.df.columns
    extra_columns = []
    if self.include_all_columns:
        skipped = {self.id_column, "query_id", "neighbor_id", "similarity"}
        extra_columns = [col for col in self.df.columns if col not in skipped]

    # Build results
    results = []
    for query_id, dists, nbrs in zip(row_labels, distances, indices):
        for neighbor_idx, dist in zip(nbrs, dists):
            neighbor_row = self.df.iloc[neighbor_idx]
            result = {
                "query_id": query_id,
                "neighbor_id": neighbor_row[self.id_column],
                # Snap near-zero distances to an exact similarity of 1.0
                "similarity": 1.0 - dist if dist > 1e-6 else 1.0,
            }

            # Add target if present
            if has_target:
                result[self.target] = neighbor_row[self.target]

            # Include all columns if requested
            for col in extra_columns:
                result[f"neighbor_{col}"] = neighbor_row[col]

            results.append(result)

    if not results:
        # e.g. a min_similarity query that matched nothing
        return pd.DataFrame(columns=base_columns)

    # Sort by query_id then similarity descending
    df_results = pd.DataFrame(results)
    return df_results.sort_values(["query_id", "similarity"], ascending=[True, False]).reset_index(drop=True)
|
|
342
|
+
|
|
243
343
|
|
|
244
344
|
# Testing the FingerprintProximity class
|
|
245
345
|
if __name__ == "__main__":
|
|
@@ -273,12 +373,71 @@ if __name__ == "__main__":
|
|
|
273
373
|
)
|
|
274
374
|
print(prox.neighbors(["a", "b"]))
|
|
275
375
|
|
|
376
|
+
# Regression test: include_all_columns should not break neighbor sorting
|
|
377
|
+
print("\n" + "=" * 80)
|
|
378
|
+
print("Regression test: include_all_columns neighbor sorting...")
|
|
379
|
+
print("=" * 80)
|
|
380
|
+
neighbors_all_cols = prox.neighbors("a", n_neighbors=4)
|
|
381
|
+
# Verify neighbors are sorted by similarity (descending), not alphabetically by neighbor_id
|
|
382
|
+
similarities = neighbors_all_cols["similarity"].tolist()
|
|
383
|
+
assert similarities == sorted(
|
|
384
|
+
similarities, reverse=True
|
|
385
|
+
), f"Neighbors not sorted by similarity! Got: {similarities}"
|
|
386
|
+
# Verify query_id column has correct value (the query, not the neighbor)
|
|
387
|
+
assert all(
|
|
388
|
+
neighbors_all_cols["id"] == "a"
|
|
389
|
+
), f"Query ID column corrupted! Expected all 'a', got: {neighbors_all_cols['id'].tolist()}"
|
|
390
|
+
print("PASSED: Neighbors correctly sorted by similarity with include_all_columns=True")
|
|
391
|
+
|
|
392
|
+
# Test neighbors_from_smiles with synthetic data
|
|
393
|
+
print("\n" + "=" * 80)
|
|
394
|
+
print("Testing neighbors_from_smiles...")
|
|
395
|
+
print("=" * 80)
|
|
396
|
+
|
|
397
|
+
# Create reference dataset with known SMILES
|
|
398
|
+
ref_data = {
|
|
399
|
+
"id": ["aspirin", "ibuprofen", "naproxen", "caffeine", "ethanol"],
|
|
400
|
+
"smiles": [
|
|
401
|
+
"CC(=O)OC1=CC=CC=C1C(=O)O", # aspirin
|
|
402
|
+
"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # ibuprofen
|
|
403
|
+
"COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1", # naproxen
|
|
404
|
+
"CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # caffeine
|
|
405
|
+
"CCO", # ethanol
|
|
406
|
+
],
|
|
407
|
+
"activity": [1.0, 2.0, 2.5, 3.0, 0.5],
|
|
408
|
+
}
|
|
409
|
+
ref_df = pd.DataFrame(ref_data)
|
|
410
|
+
|
|
411
|
+
prox_ref = FingerprintProximity(ref_df, id_column="id", target="activity", radius=2, n_bits=1024)
|
|
412
|
+
|
|
413
|
+
# Query with a single SMILES (acetaminophen - similar to aspirin)
|
|
414
|
+
query_smiles = "CC(=O)NC1=CC=C(C=C1)O" # acetaminophen
|
|
415
|
+
print(f"\nQuery: acetaminophen ({query_smiles})")
|
|
416
|
+
neighbors = prox_ref.neighbors_from_smiles(query_smiles, n_neighbors=3)
|
|
417
|
+
print(neighbors)
|
|
418
|
+
|
|
419
|
+
# Query with multiple SMILES
|
|
420
|
+
print("\nQuery: multiple SMILES (theophylline, methanol)")
|
|
421
|
+
multi_query = [
|
|
422
|
+
"CN1C=NC2=C1C(=O)NC(=O)N2", # theophylline - similar to caffeine
|
|
423
|
+
"CO", # methanol - similar to ethanol
|
|
424
|
+
]
|
|
425
|
+
neighbors_multi = prox_ref.neighbors_from_smiles(multi_query, n_neighbors=2)
|
|
426
|
+
print(neighbors_multi)
|
|
427
|
+
|
|
428
|
+
# Test with min_similarity threshold
|
|
429
|
+
print("\nQuery with min_similarity=0.3:")
|
|
430
|
+
neighbors_thresh = prox_ref.neighbors_from_smiles(query_smiles, min_similarity=0.3)
|
|
431
|
+
print(neighbors_thresh)
|
|
432
|
+
|
|
433
|
+
print("PASSED: neighbors_from_smiles working correctly")
|
|
434
|
+
|
|
276
435
|
# Test on real data from Workbench
|
|
277
436
|
from workbench.api import FeatureSet, Model
|
|
278
437
|
|
|
279
438
|
fs = FeatureSet("aqsol_features")
|
|
280
439
|
model = Model("aqsol-regression")
|
|
281
|
-
df = fs.pull_dataframe()
|
|
440
|
+
df = fs.pull_dataframe()[:1000] # Limit to 1000 for testing
|
|
282
441
|
prox = FingerprintProximity(df, id_column=fs.id_column, target=model.target())
|
|
283
442
|
|
|
284
443
|
print("\n" + "=" * 80)
|
|
@@ -106,8 +106,14 @@ class Projection2D:
|
|
|
106
106
|
return PCA(n_components=2)
|
|
107
107
|
|
|
108
108
|
if projection == "UMAP" and UMAP_AVAILABLE:
|
|
109
|
-
|
|
110
|
-
|
|
109
|
+
# UMAP default n_neighbors=15, adjust if dataset is smaller
|
|
110
|
+
n_neighbors = min(15, len(df) - 1)
|
|
111
|
+
if n_neighbors < 15:
|
|
112
|
+
self.log.warning(
|
|
113
|
+
f"Dataset size ({len(df)}) smaller than default n_neighbors, using n_neighbors={n_neighbors}"
|
|
114
|
+
)
|
|
115
|
+
self.log.info(f"Projection: UMAP with metric={metric}, n_neighbors={n_neighbors}")
|
|
116
|
+
return umap.UMAP(n_components=2, metric=metric, n_neighbors=n_neighbors)
|
|
111
117
|
|
|
112
118
|
self.log.warning(
|
|
113
119
|
f"Projection method '{projection}' not recognized or UMAP not available. Falling back to TSNE."
|
|
@@ -331,5 +331,8 @@ class Proximity(ABC):
|
|
|
331
331
|
# Include all columns if requested
|
|
332
332
|
if self.include_all_columns:
|
|
333
333
|
result.update(neighbor_row.to_dict())
|
|
334
|
+
# Restore query_id after update (neighbor_row may have overwritten id column)
|
|
335
|
+
result[self.id_column] = query_id
|
|
336
|
+
result["neighbor_id"] = neighbor_id
|
|
334
337
|
|
|
335
338
|
return result
|