workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/__init__.py
CHANGED
workbench/algorithms/dataframe/__init__.py
CHANGED
|
@@ -5,14 +5,13 @@ These classes provide functionality for Pandas Dataframes
|
|
|
5
5
|
- TBD: TBD
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from .proximity import Proximity, ProximityType
|
|
8
|
+
from .proximity import Proximity
|
|
9
9
|
from .feature_space_proximity import FeatureSpaceProximity
|
|
10
10
|
from .fingerprint_proximity import FingerprintProximity
|
|
11
11
|
from .projection_2d import Projection2D
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
14
14
|
"Proximity",
|
|
15
|
-
"ProximityType",
|
|
16
15
|
"FeatureSpaceProximity",
|
|
17
16
|
"FingerprintProximity",
|
|
18
17
|
"Projection2D",
|
|
workbench/algorithms/dataframe/compound_dataset_overlap.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""Compound Dataset Overlap Analysis
|
|
2
|
+
|
|
3
|
+
This module provides utilities for comparing two molecular datasets based on
|
|
4
|
+
Tanimoto similarity in fingerprint space. It helps quantify the "overlap"
|
|
5
|
+
between datasets in chemical space.
|
|
6
|
+
|
|
7
|
+
Use cases:
|
|
8
|
+
- Train/test split validation: Ensure test set isn't too similar to training
|
|
9
|
+
- Dataset comparison: Compare proprietary vs public datasets
|
|
10
|
+
- Novelty assessment: Find compounds in query dataset that are novel vs reference
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
from typing import Optional, Tuple
|
|
15
|
+
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity
|
|
19
|
+
|
|
20
|
+
# Set up logging
|
|
21
|
+
log = logging.getLogger("workbench")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CompoundDatasetOverlap:
|
|
25
|
+
"""Compare two molecular datasets using Tanimoto similarity.
|
|
26
|
+
|
|
27
|
+
Builds a FingerprintProximity model on the reference dataset, then queries
|
|
28
|
+
with SMILES from the query dataset to find the nearest neighbor in the
|
|
29
|
+
reference for each query compound. This guarantees cross-dataset matches.
|
|
30
|
+
|
|
31
|
+
Attributes:
|
|
32
|
+
prox: FingerprintProximity instance on reference dataset
|
|
33
|
+
overlap_df: Results DataFrame with similarity scores for each query compound
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
df_reference: pd.DataFrame,
|
|
39
|
+
df_query: pd.DataFrame,
|
|
40
|
+
id_column_reference: str = "id",
|
|
41
|
+
id_column_query: str = "id",
|
|
42
|
+
radius: int = 2,
|
|
43
|
+
n_bits: int = 2048,
|
|
44
|
+
) -> None:
|
|
45
|
+
"""
|
|
46
|
+
Initialize the CompoundDatasetOverlap analysis.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
df_reference: Reference dataset (DataFrame with SMILES)
|
|
50
|
+
df_query: Query dataset (DataFrame with SMILES)
|
|
51
|
+
id_column_reference: ID column name in df_reference
|
|
52
|
+
id_column_query: ID column name in df_query
|
|
53
|
+
radius: Morgan fingerprint radius (default: 2 = ECFP4)
|
|
54
|
+
n_bits: Number of fingerprint bits (default: 2048)
|
|
55
|
+
"""
|
|
56
|
+
self.id_column_reference = id_column_reference
|
|
57
|
+
self.id_column_query = id_column_query
|
|
58
|
+
self._radius = radius
|
|
59
|
+
self._n_bits = n_bits
|
|
60
|
+
|
|
61
|
+
# Store copies of the dataframes
|
|
62
|
+
self.df_reference = df_reference.copy()
|
|
63
|
+
self.df_query = df_query.copy()
|
|
64
|
+
|
|
65
|
+
# Find SMILES columns
|
|
66
|
+
self._smiles_col_reference = self._find_smiles_column(self.df_reference)
|
|
67
|
+
self._smiles_col_query = self._find_smiles_column(self.df_query)
|
|
68
|
+
|
|
69
|
+
if self._smiles_col_reference is None:
|
|
70
|
+
raise ValueError("Reference dataset must have a SMILES column")
|
|
71
|
+
if self._smiles_col_query is None:
|
|
72
|
+
raise ValueError("Query dataset must have a SMILES column")
|
|
73
|
+
|
|
74
|
+
log.info(f"Reference dataset: {len(self.df_reference)} compounds")
|
|
75
|
+
log.info(f"Query dataset: {len(self.df_query)} compounds")
|
|
76
|
+
|
|
77
|
+
# Build FingerprintProximity on reference dataset only
|
|
78
|
+
self.prox = FingerprintProximity(
|
|
79
|
+
self.df_reference,
|
|
80
|
+
id_column=id_column_reference,
|
|
81
|
+
radius=radius,
|
|
82
|
+
n_bits=n_bits,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Compute cross-dataset overlap
|
|
86
|
+
self.overlap_df = self._compute_cross_dataset_overlap()
|
|
87
|
+
|
|
88
|
+
@staticmethod
|
|
89
|
+
def _find_smiles_column(df: pd.DataFrame) -> Optional[str]:
|
|
90
|
+
"""Find the SMILES column in a DataFrame (case-insensitive)."""
|
|
91
|
+
for col in df.columns:
|
|
92
|
+
if col.lower() == "smiles":
|
|
93
|
+
return col
|
|
94
|
+
return None
|
|
95
|
+
|
|
96
|
+
def _compute_cross_dataset_overlap(self) -> pd.DataFrame:
|
|
97
|
+
"""For each query compound, find nearest neighbor in reference using neighbors_from_smiles."""
|
|
98
|
+
log.info(f"Computing nearest neighbors in reference for {len(self.df_query)} query compounds")
|
|
99
|
+
|
|
100
|
+
# Get SMILES list from query dataset
|
|
101
|
+
query_smiles = self.df_query[self._smiles_col_query].tolist()
|
|
102
|
+
query_ids = self.df_query[self.id_column_query].tolist()
|
|
103
|
+
|
|
104
|
+
# Query all compounds against reference (get only nearest neighbor)
|
|
105
|
+
neighbors_df = self.prox.neighbors_from_smiles(query_smiles, n_neighbors=1)
|
|
106
|
+
|
|
107
|
+
# Build results with query IDs
|
|
108
|
+
results = []
|
|
109
|
+
for i, (q_id, q_smi) in enumerate(zip(query_ids, query_smiles)):
|
|
110
|
+
# Find the row for this query SMILES
|
|
111
|
+
match = neighbors_df[neighbors_df["query_id"] == q_smi]
|
|
112
|
+
if len(match) > 0:
|
|
113
|
+
row = match.iloc[0]
|
|
114
|
+
results.append(
|
|
115
|
+
{
|
|
116
|
+
"id": q_id,
|
|
117
|
+
"smiles": q_smi,
|
|
118
|
+
"nearest_neighbor_id": row["neighbor_id"],
|
|
119
|
+
"tanimoto_similarity": row["similarity"],
|
|
120
|
+
}
|
|
121
|
+
)
|
|
122
|
+
else:
|
|
123
|
+
# Should not happen, but handle gracefully
|
|
124
|
+
results.append(
|
|
125
|
+
{
|
|
126
|
+
"id": q_id,
|
|
127
|
+
"smiles": q_smi,
|
|
128
|
+
"nearest_neighbor_id": None,
|
|
129
|
+
"tanimoto_similarity": 0.0,
|
|
130
|
+
}
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
result_df = pd.DataFrame(results)
|
|
134
|
+
|
|
135
|
+
# Add nearest neighbor SMILES from reference
|
|
136
|
+
ref_smiles_map = self.df_reference.set_index(self.id_column_reference)[self._smiles_col_reference]
|
|
137
|
+
result_df["nearest_neighbor_smiles"] = result_df["nearest_neighbor_id"].map(ref_smiles_map)
|
|
138
|
+
|
|
139
|
+
return result_df.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)
|
|
140
|
+
|
|
141
|
+
def summary_stats(self) -> pd.DataFrame:
|
|
142
|
+
"""Return distribution statistics for nearest-neighbor Tanimoto similarities."""
|
|
143
|
+
return (
|
|
144
|
+
self.overlap_df["tanimoto_similarity"]
|
|
145
|
+
.describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
|
|
146
|
+
.to_frame()
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
def novel_compounds(self, threshold: float = 0.4) -> pd.DataFrame:
|
|
150
|
+
"""Return query compounds that are novel (low similarity to reference).
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
threshold: Maximum Tanimoto similarity to consider "novel" (default: 0.4)
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
DataFrame of query compounds with similarity below threshold
|
|
157
|
+
"""
|
|
158
|
+
novel = self.overlap_df[self.overlap_df["tanimoto_similarity"] < threshold].copy()
|
|
159
|
+
return novel.sort_values("tanimoto_similarity", ascending=True).reset_index(drop=True)
|
|
160
|
+
|
|
161
|
+
def similar_compounds(self, threshold: float = 0.7) -> pd.DataFrame:
|
|
162
|
+
"""Return query compounds that are similar to reference (high overlap).
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
threshold: Minimum Tanimoto similarity to consider "similar" (default: 0.7)
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
DataFrame of query compounds with similarity above threshold
|
|
169
|
+
"""
|
|
170
|
+
similar = self.overlap_df[self.overlap_df["tanimoto_similarity"] >= threshold].copy()
|
|
171
|
+
return similar.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)
|
|
172
|
+
|
|
173
|
+
def overlap_fraction(self, threshold: float = 0.7) -> float:
|
|
174
|
+
"""Return fraction of query compounds that overlap with reference above similarity threshold.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
threshold: Minimum Tanimoto similarity to consider "overlapping"
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
Fraction of query compounds with nearest neighbor similarity >= threshold
|
|
181
|
+
"""
|
|
182
|
+
n_overlapping = (self.overlap_df["tanimoto_similarity"] >= threshold).sum()
|
|
183
|
+
return n_overlapping / len(self.overlap_df)
|
|
184
|
+
|
|
185
|
+
def plot_histogram(self, bins: int = 50, figsize: Tuple[int, int] = (10, 6)) -> None:
|
|
186
|
+
"""Plot histogram of nearest-neighbor Tanimoto similarities.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
bins: Number of histogram bins
|
|
190
|
+
figsize: Figure size (width, height)
|
|
191
|
+
"""
|
|
192
|
+
import matplotlib.pyplot as plt
|
|
193
|
+
|
|
194
|
+
fig, ax = plt.subplots(figsize=figsize)
|
|
195
|
+
ax.hist(self.overlap_df["tanimoto_similarity"], bins=bins, edgecolor="black", alpha=0.7)
|
|
196
|
+
ax.set_xlabel("Tanimoto Similarity (query → nearest in reference)")
|
|
197
|
+
ax.set_ylabel("Count")
|
|
198
|
+
ax.set_title(f"Dataset Overlap: {len(self.overlap_df)} query compounds")
|
|
199
|
+
ax.axvline(x=0.4, color="red", linestyle="--", label="Novel threshold (0.4)")
|
|
200
|
+
ax.axvline(x=0.7, color="green", linestyle="--", label="Similar threshold (0.7)")
|
|
201
|
+
ax.legend()
|
|
202
|
+
|
|
203
|
+
# Add summary stats as text
|
|
204
|
+
stats = self.overlap_df["tanimoto_similarity"]
|
|
205
|
+
textstr = f"Mean: {stats.mean():.3f}\nMedian: {stats.median():.3f}\nStd: {stats.std():.3f}"
|
|
206
|
+
ax.text(
|
|
207
|
+
0.02,
|
|
208
|
+
0.98,
|
|
209
|
+
textstr,
|
|
210
|
+
transform=ax.transAxes,
|
|
211
|
+
verticalalignment="top",
|
|
212
|
+
bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
plt.tight_layout()
|
|
216
|
+
plt.show()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# =============================================================================
|
|
220
|
+
# Testing
|
|
221
|
+
# =============================================================================
|
|
222
|
+
if __name__ == "__main__":
|
|
223
|
+
print("=" * 80)
|
|
224
|
+
print("Testing CompoundDatasetOverlap")
|
|
225
|
+
print("=" * 80)
|
|
226
|
+
|
|
227
|
+
# Test 1: Basic functionality with SMILES data
|
|
228
|
+
print("\n1. Testing with SMILES data...")
|
|
229
|
+
|
|
230
|
+
# Reference dataset: Known drug-like compounds
|
|
231
|
+
reference_data = {
|
|
232
|
+
"id": ["aspirin", "caffeine", "glucose", "ibuprofen", "naproxen", "ethanol", "methanol", "propanol"],
|
|
233
|
+
"smiles": [
|
|
234
|
+
"CC(=O)OC1=CC=CC=C1C(=O)O", # aspirin
|
|
235
|
+
"CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # caffeine
|
|
236
|
+
"C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # glucose
|
|
237
|
+
"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # ibuprofen
|
|
238
|
+
"COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1", # naproxen
|
|
239
|
+
"CCO", # ethanol
|
|
240
|
+
"CO", # methanol
|
|
241
|
+
"CCCO", # propanol
|
|
242
|
+
],
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
# Query dataset: Compounds to compare against reference
|
|
246
|
+
query_data = {
|
|
247
|
+
"id": ["acetaminophen", "theophylline", "benzene", "toluene", "phenol", "aniline"],
|
|
248
|
+
"smiles": [
|
|
249
|
+
"CC(=O)NC1=CC=C(C=C1)O", # acetaminophen - similar to aspirin
|
|
250
|
+
"CN1C=NC2=C1C(=O)NC(=O)N2", # theophylline - similar to caffeine
|
|
251
|
+
"c1ccccc1", # benzene - simple aromatic
|
|
252
|
+
"Cc1ccccc1", # toluene - similar to benzene
|
|
253
|
+
"Oc1ccccc1", # phenol - hydroxyl benzene
|
|
254
|
+
"Nc1ccccc1", # aniline - amino benzene
|
|
255
|
+
],
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
df_reference = pd.DataFrame(reference_data)
|
|
259
|
+
df_query = pd.DataFrame(query_data)
|
|
260
|
+
|
|
261
|
+
print(f" Reference: {len(df_reference)} compounds, Query: {len(df_query)} compounds")
|
|
262
|
+
|
|
263
|
+
overlap = CompoundDatasetOverlap(
|
|
264
|
+
df_reference, df_query, id_column_reference="id", id_column_query="id", radius=2, n_bits=1024
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
print("\n Overlap results:")
|
|
268
|
+
print(overlap.overlap_df[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
|
|
269
|
+
|
|
270
|
+
print("\n Summary statistics:")
|
|
271
|
+
print(overlap.summary_stats())
|
|
272
|
+
|
|
273
|
+
# Test 2: Novel and similar compound identification
|
|
274
|
+
print("\n2. Testing novel/similar compound identification...")
|
|
275
|
+
|
|
276
|
+
similar = overlap.similar_compounds(threshold=0.3)
|
|
277
|
+
print(f" Similar compounds (sim >= 0.3): {len(similar)}")
|
|
278
|
+
if len(similar) > 0:
|
|
279
|
+
print(similar[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
|
|
280
|
+
|
|
281
|
+
novel = overlap.novel_compounds(threshold=0.3)
|
|
282
|
+
print(f"\n Novel compounds (sim < 0.3): {len(novel)}")
|
|
283
|
+
if len(novel) > 0:
|
|
284
|
+
print(novel[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
|
|
285
|
+
|
|
286
|
+
# Test 3: With Workbench data (if available)
|
|
287
|
+
print("\n3. Testing with Workbench FeatureSet (if available)...")
|
|
288
|
+
|
|
289
|
+
try:
|
|
290
|
+
from workbench.api import FeatureSet
|
|
291
|
+
|
|
292
|
+
fs = FeatureSet("aqsol_features")
|
|
293
|
+
full_df = fs.pull_dataframe()[:1000] # Limit to first 1000 for testing
|
|
294
|
+
|
|
295
|
+
# Split into reference and query sets
|
|
296
|
+
df_reference = full_df.sample(frac=0.8, random_state=42)
|
|
297
|
+
df_query = full_df.drop(df_reference.index)
|
|
298
|
+
|
|
299
|
+
print(f" Reference set: {len(df_reference)} compounds")
|
|
300
|
+
print(f" Query set: {len(df_query)} compounds")
|
|
301
|
+
|
|
302
|
+
overlap = CompoundDatasetOverlap(
|
|
303
|
+
df_reference, df_query, id_column_reference=fs.id_column, id_column_query=fs.id_column
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
print("\n Summary statistics:")
|
|
307
|
+
print(overlap.summary_stats())
|
|
308
|
+
|
|
309
|
+
print(f"\n Overlap fraction (sim >= 0.7): {overlap.overlap_fraction(0.7):.2%}")
|
|
310
|
+
print(f" Overlap fraction (sim >= 0.5): {overlap.overlap_fraction(0.5):.2%}")
|
|
311
|
+
print(f" Novel compounds (sim < 0.4): {len(overlap.novel_compounds(0.4))}")
|
|
312
|
+
|
|
313
|
+
# Show the histogram (comment this call out to skip plotting)
|
|
314
|
+
overlap.plot_histogram()
|
|
315
|
+
|
|
316
|
+
except Exception as e:
|
|
317
|
+
print(f" Skipping Workbench test: {e}")
|
|
318
|
+
|
|
319
|
+
print("\n" + "=" * 80)
|
|
320
|
+
print("✅ All CompoundDatasetOverlap tests completed!")
|
|
321
|
+
print("=" * 80)
|
|
workbench/algorithms/dataframe/feature_space_proximity.py
CHANGED
|
@@ -1,101 +1,194 @@
|
|
|
1
1
|
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from sklearn.preprocessing import StandardScaler
|
|
4
|
+
from sklearn.neighbors import NearestNeighbors
|
|
5
|
+
from typing import List, Optional
|
|
2
6
|
import logging
|
|
3
7
|
|
|
4
8
|
# Workbench Imports
|
|
5
9
|
from workbench.algorithms.dataframe.proximity import Proximity
|
|
6
10
|
from workbench.algorithms.dataframe.projection_2d import Projection2D
|
|
7
|
-
from workbench.core.views.inference_view import InferenceView
|
|
8
|
-
from workbench.api import FeatureSet, Model
|
|
9
11
|
|
|
10
12
|
# Set up logging
|
|
11
13
|
log = logging.getLogger("workbench")
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
class FeatureSpaceProximity(Proximity):
|
|
15
|
-
|
|
17
|
+
"""Proximity computations for numeric feature spaces using Euclidean distance."""
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
df: pd.DataFrame,
|
|
22
|
+
id_column: str,
|
|
23
|
+
features: List[str],
|
|
24
|
+
target: Optional[str] = None,
|
|
25
|
+
include_all_columns: bool = False,
|
|
26
|
+
):
|
|
16
27
|
"""
|
|
17
28
|
Initialize the FeatureSpaceProximity class.
|
|
18
29
|
|
|
19
30
|
Args:
|
|
20
|
-
|
|
21
|
-
|
|
31
|
+
df: DataFrame containing data for neighbor computations.
|
|
32
|
+
id_column: Name of the column used as the identifier.
|
|
33
|
+
features: List of feature column names to be used for neighbor computations.
|
|
34
|
+
target: Name of the target column. Defaults to None.
|
|
35
|
+
include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
|
|
22
36
|
"""
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
#
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
37
|
+
# Stash the requested features so _prepare_data() can validate and filter them
|
|
38
|
+
self._raw_features = features
|
|
39
|
+
super().__init__(
|
|
40
|
+
df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def _prepare_data(self) -> None:
|
|
44
|
+
"""Filter out non-numeric features and drop NaN rows."""
|
|
45
|
+
# Validate features
|
|
46
|
+
self.features = self._validate_features(self.df, self._raw_features)
|
|
47
|
+
|
|
48
|
+
# Drop NaN rows for the features we're using
|
|
49
|
+
self.df = self.df.dropna(subset=self.features).copy()
|
|
50
|
+
|
|
51
|
+
def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
|
|
52
|
+
"""Remove non-numeric features and log warnings."""
|
|
53
|
+
non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
|
|
54
|
+
if non_numeric:
|
|
55
|
+
log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
|
|
56
|
+
return [f for f in features if f not in non_numeric]
|
|
57
|
+
|
|
58
|
+
def _build_model(self) -> None:
|
|
59
|
+
"""Standardize features and fit Nearest Neighbors model."""
|
|
60
|
+
self.scaler = StandardScaler()
|
|
61
|
+
X = self.scaler.fit_transform(self.df[self.features])
|
|
62
|
+
self.nn = NearestNeighbors().fit(X)
|
|
63
|
+
|
|
64
|
+
def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
|
|
65
|
+
"""Transform features using the fitted scaler."""
|
|
66
|
+
return self.scaler.transform(df[self.features])
|
|
67
|
+
|
|
68
|
+
def _project_2d(self) -> None:
|
|
69
|
+
"""Project the numeric features to 2D for visualization."""
|
|
70
|
+
if len(self.features) >= 2:
|
|
71
|
+
self.df = Projection2D().fit_transform(self.df, features=self.features)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Testing the FeatureSpaceProximity class
|
|
48
75
|
if __name__ == "__main__":
|
|
76
|
+
|
|
49
77
|
pd.set_option("display.max_columns", None)
|
|
50
78
|
pd.set_option("display.width", 1000)
|
|
51
79
|
|
|
52
|
-
#
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
80
|
+
# Create a sample DataFrame
|
|
81
|
+
data = {
|
|
82
|
+
"ID": [1, 2, 3, 4, 5],
|
|
83
|
+
"Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
|
|
84
|
+
"Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
|
|
85
|
+
"Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
|
|
86
|
+
}
|
|
87
|
+
df = pd.DataFrame(data)
|
|
88
|
+
|
|
89
|
+
# Test the FeatureSpaceProximity class
|
|
90
|
+
features = ["Feature1", "Feature2", "Feature3"]
|
|
91
|
+
prox = FeatureSpaceProximity(df, id_column="ID", features=features)
|
|
92
|
+
print(prox.neighbors(1, n_neighbors=2))
|
|
93
|
+
|
|
94
|
+
# Test the neighbors method with radius
|
|
95
|
+
print(prox.neighbors(1, radius=2.0))
|
|
96
|
+
|
|
97
|
+
# Test with Features list
|
|
98
|
+
prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
|
|
99
|
+
print(prox.neighbors(1))
|
|
100
|
+
|
|
101
|
+
# Create a sample DataFrame
|
|
102
|
+
data = {
|
|
103
|
+
"id": ["a", "b", "c", "d", "e"], # Testing string IDs
|
|
104
|
+
"Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
|
|
105
|
+
"Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
|
|
106
|
+
"target": [1, 0, 1, 0, 5],
|
|
107
|
+
}
|
|
108
|
+
df = pd.DataFrame(data)
|
|
109
|
+
|
|
110
|
+
# Test with String Ids
|
|
111
|
+
prox = FeatureSpaceProximity(
|
|
112
|
+
df,
|
|
113
|
+
id_column="id",
|
|
114
|
+
features=["Feature1", "Feature2"],
|
|
115
|
+
target="target",
|
|
116
|
+
include_all_columns=True,
|
|
117
|
+
)
|
|
118
|
+
print(prox.neighbors(["a", "b"]))
|
|
119
|
+
|
|
120
|
+
# Test duplicate IDs
|
|
121
|
+
data = {
|
|
122
|
+
"id": ["a", "b", "c", "d", "d"], # Duplicate ID (d)
|
|
123
|
+
"Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
|
|
124
|
+
"Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
|
|
125
|
+
"target": [1, 0, 1, 0, 5],
|
|
126
|
+
}
|
|
127
|
+
df = pd.DataFrame(data)
|
|
128
|
+
prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
|
|
129
|
+
print(df.equals(prox.df))
|
|
130
|
+
|
|
131
|
+
# Test on real data from Workbench
|
|
132
|
+
from workbench.api import FeatureSet, Model
|
|
133
|
+
|
|
134
|
+
fs = FeatureSet("aqsol_features")
|
|
135
|
+
model = Model("aqsol-regression")
|
|
136
|
+
features = model.features()
|
|
58
137
|
df = fs.pull_dataframe()
|
|
59
|
-
|
|
60
|
-
print("\
|
|
61
|
-
print(
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
print("
|
|
83
|
-
print(
|
|
84
|
-
|
|
85
|
-
#
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
print("\
|
|
92
|
-
print(
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
138
|
+
prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
|
|
139
|
+
print("\n" + "=" * 80)
|
|
140
|
+
print("Testing Neighbors...")
|
|
141
|
+
print("=" * 80)
|
|
142
|
+
test_id = df[fs.id_column].tolist()[0]
|
|
143
|
+
print(f"\nNeighbors for ID {test_id}:")
|
|
144
|
+
print(prox.neighbors(test_id))
|
|
145
|
+
|
|
146
|
+
print("\n" + "=" * 80)
|
|
147
|
+
print("Testing isolated_compounds...")
|
|
148
|
+
print("=" * 80)
|
|
149
|
+
|
|
150
|
+
# Test isolated data in the top 1%
|
|
151
|
+
isolated_1pct = prox.isolated(top_percent=1.0)
|
|
152
|
+
print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
|
|
153
|
+
print(isolated_1pct)
|
|
154
|
+
|
|
155
|
+
# Test isolated data in the top 5%
|
|
156
|
+
isolated_5pct = prox.isolated(top_percent=5.0)
|
|
157
|
+
print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
|
|
158
|
+
print(isolated_5pct)
|
|
159
|
+
|
|
160
|
+
print("\n" + "=" * 80)
|
|
161
|
+
print("Testing target_gradients...")
|
|
162
|
+
print("=" * 80)
|
|
163
|
+
|
|
164
|
+
# Test with different parameters
|
|
165
|
+
gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
|
|
166
|
+
print(f"\nTop 1% target gradients (min_delta=5.0) (n={len(gradients_1pct)}):")
|
|
167
|
+
print(gradients_1pct)
|
|
168
|
+
|
|
169
|
+
gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
|
|
170
|
+
print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
|
|
171
|
+
print(gradients_5pct)
|
|
172
|
+
|
|
173
|
+
# Test proximity_stats
|
|
174
|
+
print("\n" + "=" * 80)
|
|
175
|
+
print("Testing proximity_stats...")
|
|
176
|
+
print("=" * 80)
|
|
177
|
+
stats = prox.proximity_stats()
|
|
178
|
+
print(stats)
|
|
179
|
+
|
|
180
|
+
# Plot the distance distribution using pandas
|
|
181
|
+
print("\n" + "=" * 80)
|
|
182
|
+
print("Plotting distance distribution...")
|
|
183
|
+
print("=" * 80)
|
|
184
|
+
prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
|
|
185
|
+
|
|
186
|
+
# Visualize the 2D projection
|
|
187
|
+
print("\n" + "=" * 80)
|
|
188
|
+
print("Visualizing 2D Projection...")
|
|
189
|
+
print("=" * 80)
|
|
96
190
|
from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
|
|
97
191
|
from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
|
|
98
192
|
|
|
99
|
-
|
|
100
|
-
unit_test = PluginUnitTest(ScatterPlot, input_data=fsp.df, x="x", y="y")
|
|
193
|
+
unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
|
|
101
194
|
unit_test.run()
|