workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/utils/chem_utils/mol_descriptors.py (the same change also appears in workbench/model_scripts/custom_models/chem_info/mol_descriptors.py):

@@ -91,6 +91,8 @@ import logging
 import pandas as pd
 import numpy as np
 import re
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Descriptors, rdCIPLabeler
 from rdkit.ML.Descriptors import MoleculeDescriptors

@@ -101,6 +103,14 @@ logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 def compute_stereochemistry_features(mol):
     """
     Compute stereochemistry descriptors using modern RDKit methods.
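The timer helper added above is a plain stdlib context manager. A minimal usage sketch (the label string and the timed block are arbitrary stand-ins):

    import time
    from contextlib import contextmanager

    @contextmanager
    def timer(name):
        start = time.time()
        yield
        print(f"{name}: {time.time() - start:.2f}s")

    # Wraps any block and prints its wall-clock duration on exit
    with timer("descriptor computation"):
        time.sleep(0.1)  # stand-in for a call like compute_descriptors(df)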
@@ -280,9 +290,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             descriptor_values.append([np.nan] * len(all_descriptors))
 
     # Create RDKit features DataFrame
-    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames()
+    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())
 
     # Add RDKit features to result
+    # Remove any columns from result that exist in rdkit_features_df
+    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
     result = pd.concat([result, rdkit_features_df], axis=1)
 
     # Compute Mordred descriptors

@@ -299,7 +311,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
         # Compute Mordred descriptors
         valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-        mordred_df = calc.pandas(valid_mols, nproc=1)  #
+        mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1
 
         # Replace values for invalid molecules with NaN
         for i, mol in enumerate(molecules):

@@ -310,10 +322,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
         for col in mordred_df.columns:
             mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")
 
-        # Set index to match result DataFrame
-        mordred_df.index = result.index
-
         # Add Mordred features to result
+        # Remove any columns from result that exist in mordred
+        result = result.drop(columns=result.columns.intersection(mordred_df.columns))
         result = pd.concat([result, mordred_df], axis=1)
 
         # Compute stereochemistry features if requested

@@ -326,9 +337,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             stereo_features.append(stereo_dict)
 
         # Create stereochemistry DataFrame
-        stereo_df = pd.DataFrame(stereo_features
+        stereo_df = pd.DataFrame(stereo_features)
 
         # Add stereochemistry features to result
+        result = result.drop(columns=result.columns.intersection(stereo_df.columns))
         result = pd.concat([result, stereo_df], axis=1)
 
         logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")

@@ -357,7 +369,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
 
 if __name__ == "__main__":
-    import time
     from mol_standardize import standardize
     from workbench.api import DataSource
 
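A pattern repeated in the hunks above is dropping intersecting columns before each pd.concat. The reason: pd.concat(axis=1) keeps duplicate column names rather than overwriting them, so recomputing descriptors over an existing result would silently double up columns. A small self-contained illustration:

    import pandas as pd

    a = pd.DataFrame({"x": [1], "y": [2]})
    b = pd.DataFrame({"y": [3], "z": [4]})

    # Naive concat keeps both "y" columns
    print(pd.concat([a, b], axis=1).columns.tolist())  # ['x', 'y', 'y', 'z']

    # Dropping the intersection first keeps column names unique
    a = a.drop(columns=a.columns.intersection(b.columns))
    print(pd.concat([a, b], axis=1).columns.tolist())  # ['x', 'y', 'z']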
workbench/utils/chem_utils/mol_standardize.py (the same change also appears in workbench/model_scripts/custom_models/chem_info/mol_standardize.py):

@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize

@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing

@@ -116,6 +126,7 @@ class MolStandardizer:
         Pipeline:
         1. Cleanup (remove Hs, disconnect metals, normalize)
         2. Get largest fragment (optional - only if remove_salts=True)
+        2a. Extract salt information BEFORE further modifications
         3. Neutralize charges
        4. Canonicalize tautomer (optional)
 

@@ -130,18 +141,24 @@ class MolStandardizer:
 
         try:
             # Step 1: Cleanup
-
-            if
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None
 
+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol
 
             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule
-                parent_mol = rdMolStandardize.FragmentParent(
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None

@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles
 
-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)
 

@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments
 
         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)
 
         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
 
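The two new guards short-circuit the common no-salt case before any fragment comparison is done. A minimal sketch of what they test, reusing one of the salt cases that appears in the __main__ block further down:

    from rdkit import Chem

    orig = Chem.MolFromSmiles("CC(=O)O.CCN")  # two fragments after Cleanup
    parent = Chem.MolFromSmiles("CC(=O)O")    # largest fragment only

    # Identical atom counts would mean nothing was stripped -> return None early
    print(orig.GetNumAtoms() == parent.GetNumAtoms())  # False (7 vs 4 heavy atoms)

    # A positive heavy-atom difference means salt extraction should proceed
    print(orig.GetNumHeavyAtoms() - parent.GetNumHeavyAtoms())  # 3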
@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]
 
-    # Initialize standardizer
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
 
     def process_smiles(smiles: str) -> pd.Series:

@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})
 
+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:

@@ -299,7 +330,9 @@ def standardize(
         if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})
 
         if std_mol is None:
             return pd.Series(

@@ -325,8 +358,11 @@ def standardize(
 
 
 if __name__ == "__main__":
-
-
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)
 
     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(

@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )
 
     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)
 
     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)
 
     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-
-    print(result_keep[
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])
 
     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[
+    print(result_remove[display_order])
 
-    # Test
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("
-
-
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource
 
-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f"  With salt + taut: {keep_smiles}")
-                print(f"  With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f"  Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f"  Extracted salt: {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
 
-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-
-
-
-
-
-
-
-
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
workbench/utils/chem_utils/fingerprints.py:

@@ -17,18 +17,28 @@ log = logging.getLogger("workbench")
 
 def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
     """
-    Convert
+    Convert fingerprints to numpy matrix.
+
+    Supports two formats (auto-detected):
+    - Bitstrings: "10110010..." → matrix of 0s and 1s
+    - Count vectors: "0,3,0,1,5,..." → matrix of counts (or binary if dtype=np.bool_)
 
     Args:
-        fingerprints: pandas Series or list of
-        dtype: numpy data type (uint8 is default
+        fingerprints: pandas Series or list of fingerprints
+        dtype: numpy data type (uint8 is default; np.bool_ for Jaccard computations)
 
     Returns:
         dense numpy array of shape (n_molecules, n_bits)
     """
-
-
-
+    # Auto-detect format based on first fingerprint
+    sample = str(fingerprints.iloc[0] if hasattr(fingerprints, "iloc") else fingerprints[0])
+    if "," in sample:
+        # Count vector format: comma-separated integers
+        matrix = np.array([list(map(int, fp.split(","))) for fp in fingerprints], dtype=dtype)
+    else:
+        # Bitstring format: each character is a bit
+        matrix = np.array([list(fp) for fp in fingerprints], dtype=dtype)
+    return matrix
 
 
 def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame:
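A quick sketch of the auto-detection above, with the detection logic copied inline so the example stands alone (in the package it would be imported from the fingerprints module):

    import numpy as np
    import pandas as pd

    def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
        # Same detection logic as the hunk above
        sample = str(fingerprints.iloc[0] if hasattr(fingerprints, "iloc") else fingerprints[0])
        if "," in sample:
            return np.array([list(map(int, fp.split(","))) for fp in fingerprints], dtype=dtype)
        return np.array([list(fp) for fp in fingerprints], dtype=dtype)

    bits = pd.Series(["1011", "0110"])          # bitstring format
    counts = pd.Series(["0,3,0,1", "2,0,0,5"])  # count-vector format

    print(fingerprints_to_matrix(bits))                    # [[1 0 1 1] [0 1 1 0]]
    print(fingerprints_to_matrix(counts))                  # [[0 3 0 1] [2 0 0 5]]
    print(fingerprints_to_matrix(counts, dtype=np.bool_))  # counts collapse to True/False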
workbench/utils/chem_utils/vis.py:

@@ -2,34 +2,18 @@
 
 import logging
 import base64
-import re
 from typing import Optional, Tuple
 from rdkit import Chem
 from rdkit.Chem import AllChem, Draw
 from rdkit.Chem.Draw import rdMolDraw2D
 
+# Workbench Imports
+from workbench.utils.color_utils import is_dark
+
 # Set up the logger
 log = logging.getLogger("workbench")
 
 
-def _is_dark(color: str) -> bool:
-    """Determine if an rgba color is dark based on RGB average.
-
-    Args:
-        color: Color in rgba(...) format
-
-    Returns:
-        True if the color is dark, False otherwise
-    """
-    match = re.match(r"rgba?\((\d+),\s*(\d+),\s*(\d+)", color)
-    if not match:
-        log.warning(f"Invalid color format: {color}, defaulting to dark")
-        return True  # Default to dark mode on error
-
-    r, g, b = map(int, match.groups())
-    return (r + g + b) / 3 < 128
-
-
 def _rgba_to_tuple(rgba: str) -> Tuple[float, float, float, float]:
     """Convert rgba string to normalized tuple (R, G, B, A).
 
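The brightness check now lives in workbench.utils.color_utils. Judging from the new except-clause below and the updated tests further down, the behavioral change is that an unparseable color now raises ValueError instead of silently returning True; _configure_draw_options catches that and falls back to dark mode. The expected contract, as a sketch:

    from workbench.utils.color_utils import is_dark

    print(is_dark("rgba(0, 0, 0, 1)"))        # True
    print(is_dark("rgba(255, 255, 255, 1)"))  # False
    is_dark("invalid_color")                  # raises ValueError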
@@ -75,7 +59,13 @@ def _configure_draw_options(options: Draw.MolDrawOptions, background: str) -> No
         options: RDKit drawing options object
         background: Background color string
     """
-
+    try:
+        if is_dark(background):
+            rdMolDraw2D.SetDarkMode(options)
+        # Light backgrounds use RDKit defaults (no action needed)
+    except ValueError:
+        # Default to dark mode if color format is invalid
+        log.warning(f"Invalid color format: {background}, defaulting to dark mode")
         rdMolDraw2D.SetDarkMode(options)
     options.setBackgroundColour(_rgba_to_tuple(background))
 

@@ -137,7 +127,7 @@ def svg_from_smiles(
     drawer.DrawMolecule(mol)
     drawer.FinishDrawing()
 
-    # Encode SVG
+    # Encode SVG as base64 data URI
     svg = drawer.GetDrawingText()
     encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
     return f"data:image/svg+xml;base64,{encoded_svg}"

@@ -222,7 +212,7 @@ if __name__ == "__main__":
     # Test 6: Color parsing functions
     print("\n6. Testing color utility functions...")
     test_colors = [
-        ("invalid_color",
+        ("invalid_color", None, (0.25, 0.25, 0.25, 1.0)),  # Should raise ValueError
         ("rgba(255, 255, 255, 1)", False, (1.0, 1.0, 1.0, 1.0)),
         ("rgba(0, 0, 0, 1)", True, (0.0, 0.0, 0.0, 1.0)),
         ("rgba(64, 64, 64, 0.5)", True, (0.251, 0.251, 0.251, 0.5)),

@@ -230,12 +220,20 @@ if __name__ == "__main__":
     ]
 
     for color, expected_dark, expected_tuple in test_colors:
-
-
-
-
-
+        try:
+            is_dark_result = is_dark(color)
+            if expected_dark is None:
+                print(f"  ✗ is_dark('{color[:20]}...'): Expected ValueError but got {is_dark_result}")
+            else:
+                dark_status = "✓" if is_dark_result == expected_dark else "✗"
+                print(f"  {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
+        except ValueError:
+            if expected_dark is None:
+                print(f"  ✓ is_dark('{color[:20]}...'): Correctly raised ValueError")
+            else:
+                print(f"  ✗ is_dark('{color[:20]}...'): Unexpected ValueError")
 
+        tuple_result = _rgba_to_tuple(color)
         # Check tuple values with tolerance for floating point
         tuple_match = all(abs(a - b) < 0.01 for a, b in zip(tuple_result, expected_tuple))
         tuple_status = "✓" if tuple_match else "✗"
workbench/utils/chemprop_utils.py (new file):

@@ -0,0 +1,141 @@
+"""ChemProp utilities for Workbench models."""
+
+import logging
+import os
+from typing import Any, Tuple
+
+import pandas as pd
+
+from workbench.utils.aws_utils import pull_s3_data
+from workbench.utils.metrics_utils import compute_metrics_from_predictions
+from workbench.utils.model_utils import safe_extract_tarfile
+
+log = logging.getLogger("workbench")
+
+
+def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
+    """Download model artifact from S3 and extract it.
+
+    Args:
+        s3_uri: S3 URI to the model artifact (model.tar.gz)
+        model_dir: Directory to extract model artifacts to
+    """
+    import awswrangler as wr
+
+    log.info(f"Downloading model from {s3_uri}...")
+
+    # Download to temp file
+    local_tar_path = os.path.join(model_dir, "model.tar.gz")
+    wr.s3.download(path=s3_uri, local_file=local_tar_path)
+
+    # Extract using safe extraction
+    log.info(f"Extracting to {model_dir}...")
+    safe_extract_tarfile(local_tar_path, model_dir)
+
+    # Cleanup tar file
+    os.unlink(local_tar_path)
+
+
+def load_chemprop_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
+    """Load ChemProp MPNN model and artifacts from an extracted model directory.
+
+    Args:
+        model_dir: Directory containing extracted model artifacts
+
+    Returns:
+        Tuple of (MPNN model, artifacts_dict).
+        artifacts_dict contains 'label_encoder' and 'feature_metadata' if present.
+    """
+    import joblib
+    from chemprop import models
+
+    model_path = os.path.join(model_dir, "chemprop_model.pt")
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"No chemprop_model.pt found in {model_dir}")
+
+    model = models.MPNN.load_from_file(model_path)
+    model.eval()
+
+    # Load additional artifacts
+    artifacts = {}
+
+    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(label_encoder_path):
+        artifacts["label_encoder"] = joblib.load(label_encoder_path)
+
+    feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
+    if os.path.exists(feature_metadata_path):
+        artifacts["feature_metadata"] = joblib.load(feature_metadata_path)
+
+    return model, artifacts
+
+
+def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Pull cross-validation results from AWS training artifacts.
+
+    This retrieves the validation predictions saved during model training and
+    computes metrics directly from them.
+
+    Note:
+        - Regression: Supports both single-target and multi-target models
+        - Classification: Only single-target is supported (with any number of classes)
+
+    Args:
+        workbench_model: Workbench model object
+
+    Returns:
+        Tuple of:
+        - DataFrame with computed metrics
+        - DataFrame with validation predictions
+    """
+
+    # Get the validation predictions from S3
+    s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+    predictions_df = pull_s3_data(s3_path)
+
+    if predictions_df is None:
+        raise ValueError(f"No validation predictions found at {s3_path}")
+
+    log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+    # Get target and class labels
+    target = workbench_model.target()
+    class_labels = workbench_model.class_labels()
+
+    # If single target just use the "prediction" column
+    if isinstance(target, str):
+        metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
+        return metrics_df, predictions_df
+
+    # Multi-target regression
+    metrics_list = []
+    for t in target:
+        # Prediction will be {target}_pred in multi-target case
+        pred_col = f"{t}_pred"
+
+        # Drop NaNs for this target
+        target_preds_df = predictions_df.dropna(subset=[t, pred_col])
+        metrics_df = compute_metrics_from_predictions(target_preds_df, t, class_labels, prediction_col=pred_col)
+        metrics_df.insert(0, "target", t)
+        metrics_list.append(metrics_df)
+    metrics_df = pd.concat(metrics_list, ignore_index=True) if metrics_list else pd.DataFrame()
+
+    return metrics_df, predictions_df
+
+
+if __name__ == "__main__":
+
+    # Tests for the ChemProp utilities
+    from workbench.api import Model
+
+    # Initialize Workbench model
+    model_name = "open-admet-chemprop-mt"
+    print(f"Loading Workbench model: {model_name}")
+    model = Model(model_name)
+    print(f"Model Framework: {model.model_framework}")
+
+    # Pull CV results
+    metrics_df, predictions_df = pull_cv_results(model)
+    print("\nTraining Metrics:")
+    print(metrics_df.to_string(index=False))
+    print(f"\nSample Predictions:\n{predictions_df.head().to_string(index=False)}")
workbench/utils/config_manager.py:

@@ -4,16 +4,13 @@ import os
 import sys
 import platform
 import logging
-import importlib.resources as resources  # noqa: F401 Python 3.9 compatibility
 from typing import Any, Dict
+from importlib.resources import files, as_file
 
 # Workbench imports
 from workbench.utils.license_manager import LicenseManager
 from workbench_bridges.utils.execution_environment import running_as_service
 
-# Python 3.9 compatibility
-from workbench.utils.resource_utils import get_resource_path
-
 
 class FatalConfigError(Exception):
     """Exception raised for errors in the configuration."""

@@ -172,8 +169,7 @@ class ConfigManager:
         Returns:
             str: The open source API key.
         """
-
-        with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
+        with as_file(files("workbench.resources").joinpath("open_source_api.key")) as open_source_key_path:
            with open(open_source_key_path, "r") as key_file:
                return key_file.read().strip()
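The replacement drops the resource_utils compatibility shim in favor of the stdlib importlib.resources API (files/as_file, both available since Python 3.9). The general pattern, shown with the resource used above:

    from importlib.resources import files, as_file

    # files() returns a Traversable; as_file() materializes it as a real
    # filesystem path, even when the package is installed as a zip
    resource = files("workbench.resources").joinpath("open_source_api.key")
    with as_file(resource) as path:
        key = path.read_text().strip()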
workbench/utils/endpoint_utils.py:

@@ -7,9 +7,7 @@ from typing import Union, Optional
 import pandas as pd
 
 # Workbench Imports
-from workbench.api
-from workbench.api.model import Model
-from workbench.api.endpoint import Endpoint
+from workbench.api import FeatureSet, Model, Endpoint
 
 # Set up the log
 log = logging.getLogger("workbench")

@@ -77,7 +75,7 @@ def internal_model_data_url(endpoint_config_name: str, session: boto3.Session) -
         return None
 
 
-def
+def get_training_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the training data from the FeatureSet used to train the Model
 
     Args:

@@ -100,7 +98,7 @@ def fs_training_data(end: Endpoint) -> pd.DataFrame:
     return train_df
 
 
-def
+def get_evaluation_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the evaluation data from the FeatureSet NOT used for training
 
     Args:

@@ -178,11 +176,11 @@ if __name__ == "__main__":
     print(model_data_url)
 
     # Get the training data
-    my_train_df =
+    my_train_df = get_training_data(my_endpoint)
     print(my_train_df)
 
     # Get the evaluation data
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     print(my_eval_df)
 
     # Backtrack to the FeatureSet