workbench-0.8.176-py3-none-any.whl → workbench-0.8.177-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/core/artifacts/endpoint_core.py +4 -1
- workbench/core/artifacts/model_core.py +8 -29
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +19 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/uq_models/mapie.template +3 -3
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/utils/chem_utils/mol_descriptors.py +19 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/model_utils.py +28 -25
- {workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/METADATA +1 -1
- {workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/RECORD +15 -16
- workbench/utils/fast_inference.py +0 -167
- {workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/WHEEL +0 -0
- {workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/endpoint_core.py
CHANGED
@@ -32,11 +32,11 @@ from sagemaker import Predictor
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
 from workbench.utils.endpoint_metrics import EndpointMetrics
-from workbench.utils.fast_inference import fast_inference
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
 from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench_bridges.endpoints.fast_inference import fast_inference


 class EndpointCore(Artifact):
@@ -1061,6 +1061,9 @@ if __name__ == "__main__":
     assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"

     # Now we put in an invalid value
+    print("*" * 80)
+    print("NOW TESTING ERROR CONDITIONS...")
+    print("*" * 80)
     my_eval_df.at[42, "length"] = "invalid_value"
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
workbench/core/artifacts/model_core.py
CHANGED
@@ -37,35 +37,6 @@ class ModelType(Enum):
     UNKNOWN = "unknown"


-# Deprecated Images
-"""
-# US East 1 images
-"py312-general-ml-training"
-("us-east-1", "training", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-),
-("us-east-1", "inference", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-),
-
-# US West 2 images
-("us-west-2", "training", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-),
-("us-west-2", "inference", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-),
-
-# ARM64 images
-("us-east-1", "inference", "0.1", "arm64"): (
-    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-),
-("us-west-2", "inference", "0.1", "arm64"): (
-    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-),
-"""
-
-
 class ModelImages:
     """Class for retrieving workbench inference images"""

@@ -890,6 +861,14 @@ class ModelCore(Artifact):
             shap_data[key] = self.df_store.get(df_location)
         return shap_data or None

+    def cross_folds(self) -> dict:
+        """Retrieve the cross-fold inference results(only works for XGBoost models)
+
+        Returns:
+            dict: Dictionary with the cross-fold inference results
+        """
+        return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
     def supported_inference_instances(self) -> Optional[list]:
         """Retrieve the supported endpoint inference instance types

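The new `cross_folds()` is a thin accessor over the parameter store: it returns whatever `cross_fold_inference` previously stored under the model's `/inference/cross_fold` key (presumably None if nothing was ever captured). A usage sketch, with a hypothetical model name, assuming `Model` is exposed via `workbench.api` the way `DataSource` and `Endpoint` are elsewhere in this diff:

    from workbench.api import Model

    model = Model("aqsol-regression")          # hypothetical XGBoost model
    cross_fold_results = model.cross_folds()   # dict of cross-fold inference results, or None
    if cross_fold_results:
        print(cross_fold_results.keys())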
workbench/model_scripts/custom_models/chem_info/mol_descriptors.py
CHANGED
@@ -91,16 +91,27 @@ import logging
 import pandas as pd
 import numpy as np
 import re
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Descriptors, rdCIPLabeler
 from rdkit.ML.Descriptors import MoleculeDescriptors
 from mordred import Calculator as MordredCalculator
 from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes

+
 logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)


+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 def compute_stereochemistry_features(mol):
     """
     Compute stereochemistry descriptors using modern RDKit methods.
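The `timer` helper added above is the standard context-manager timing idiom: the start time is captured before `yield`, and the `print` fires when the `with` block exits. A usage sketch (the call inside the block is illustrative):

    with timer("descriptor computation"):
        features_df = compute_descriptors(df)  # prints "descriptor computation: 12.34s" on exit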
@@ -280,9 +291,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             descriptor_values.append([np.nan] * len(all_descriptors))

     # Create RDKit features DataFrame
-    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames()
+    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())

     # Add RDKit features to result
+    # Remove any columns from result that exist in rdkit_features_df
+    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
     result = pd.concat([result, rdkit_features_df], axis=1)

     # Compute Mordred descriptors
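The drop-then-concat pattern introduced here (and repeated below for the Mordred and stereochemistry frames) exists because `pd.concat(..., axis=1)` happily produces duplicate column labels when the incoming frame shares names with `result`; dropping the intersection first keeps only the freshly computed values. A toy illustration (column names are made up):

    import pandas as pd

    result = pd.DataFrame({"smiles": ["CCO"], "molwt": [0.0]})   # stale molwt
    new = pd.DataFrame({"molwt": [46.07], "tpsa": [20.23]})      # freshly computed
    result = result.drop(columns=result.columns.intersection(new.columns))
    result = pd.concat([result, new], axis=1)
    print(result.columns.tolist())  # ['smiles', 'molwt', 'tpsa'], a single molwt column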
@@ -299,7 +312,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_

         # Compute Mordred descriptors
         valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-        mordred_df = calc.pandas(valid_mols, nproc=1)  #
+        mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1

         # Replace values for invalid molecules with NaN
         for i, mol in enumerate(molecules):
@@ -310,10 +323,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
         for col in mordred_df.columns:
             mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")

-        # Set index to match result DataFrame
-        mordred_df.index = result.index
-
         # Add Mordred features to result
+        # Remove any columns from result that exist in mordred
+        result = result.drop(columns=result.columns.intersection(mordred_df.columns))
         result = pd.concat([result, mordred_df], axis=1)

     # Compute stereochemistry features if requested
@@ -326,9 +338,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             stereo_features.append(stereo_dict)

         # Create stereochemistry DataFrame
-        stereo_df = pd.DataFrame(stereo_features
+        stereo_df = pd.DataFrame(stereo_features)

         # Add stereochemistry features to result
+        result = result.drop(columns=result.columns.intersection(stereo_df.columns))
         result = pd.concat([result, stereo_df], axis=1)

         logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")
@@ -357,7 +370,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_


 if __name__ == "__main__":
-    import time
     from mol_standardize import standardize
     from workbench.api import DataSource

workbench/model_scripts/custom_models/chem_info/mol_standardize.py
CHANGED
@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")


+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
     Pipeline:
         1. Cleanup (remove Hs, disconnect metals, normalize)
         2. Get largest fragment (optional - only if remove_salts=True)
+        2a. Extract salt information BEFORE further modifications
         3. Neutralize charges
         4. Canonicalize tautomer (optional)

@@ -130,18 +141,24 @@ class MolStandardizer:

         try:
             # Step 1: Cleanup
-
-            if
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None

+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol

             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule
-                parent_mol = rdMolStandardize.FragmentParent(
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None
@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles

-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)

@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments

         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)

         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)

@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]

-    # Initialize standardizer
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)

     def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})

+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:
@@ -299,7 +330,9 @@ def standardize(
         if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})

         if std_mol is None:
             return pd.Series(
@@ -325,8 +358,11 @@ def standardize(


 if __name__ == "__main__":
-
-
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)

     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )

     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)

     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)

     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-
-    print(result_keep[
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])

     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[
+    print(result_remove[display_order])

-    # Test
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("
-
-
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource

-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f"  With salt + taut: {keep_smiles}")
-                print(f"  With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f"  Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f"  Extracted salt: {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")

-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-
-
-
-
-
-
-
-
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
workbench/model_scripts/custom_models/uq_models/mapie.template
CHANGED
@@ -472,9 +472,9 @@ def predict_fn(df, models) -> pd.DataFrame:
     # Add median (q_50) from XGBoost prediction
     df["q_50"] = df["prediction"]

-    # Calculate uncertainty metrics based on
-    interval_width = df["
-    df["prediction_std"] = interval_width /
+    # Calculate uncertainty metrics based on 50% interval
+    interval_width = df["q_75"] - df["q_25"]
+    df["prediction_std"] = interval_width / 1.348

     # Reorder the quantile columns for easier reading
     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
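The 1.348 divisor is the normal-theory conversion from a 50% interquantile width to a standard deviation: for a Gaussian, q_75 − q_25 = 2 × 0.6745σ ≈ 1.349σ. A quick sanity check of the constant (assumes scipy is available, as it is elsewhere in these model scripts):

    from scipy.stats import norm

    # Width of the central 50% interval of a standard normal, in units of sigma
    iqr_in_sigmas = norm.ppf(0.75) - norm.ppf(0.25)
    print(iqr_in_sigmas)  # ~1.349, hence prediction_std = (q_75 - q_25) / ~1.35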
workbench/model_scripts/xgb_model/generated_model_script.py
CHANGED
@@ -28,11 +28,11 @@ from typing import List, Tuple

 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "
-    "target": "
+    "model_type": "regressor",
+    "target": "udm_asy_res_value",
"features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/pka-a1-reg-0-nightly-100-test/training",
     "train_all_data": True
 }

workbench/utils/chem_utils/mol_descriptors.py
CHANGED
workbench/utils/chem_utils/mol_standardize.py
CHANGED
(The changes to these two files are byte-for-byte identical to the changes shown above for workbench/model_scripts/custom_models/chem_info/mol_descriptors.py and mol_standardize.py; the RECORD entries below confirm the matching file hashes.)
workbench/utils/model_utils.py
CHANGED
@@ -226,12 +226,18 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    avg_std = df["prediction_std"].mean()
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
     coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
     coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
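The interval multipliers here are the two-sided standard normal quantiles z_(1−α/2): 1.96 for 95%, 1.645 for 90%, 1.282 for 80%, and 0.674 for 50%. A one-liner to regenerate them (assumes scipy, which this module presumably already imports for `norm` below):

    from scipy.stats import norm

    for level in (0.95, 0.90, 0.80, 0.50):
        print(level, round(norm.ppf(0.5 + level / 2), 3))  # 1.96, 1.645, 1.282, 0.674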
@@ -242,12 +248,9 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     avg_width_50 = np.mean(upper_50 - lower_50)

     # --- CRPS (measures calibration + sharpness) ---
-
-
-
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)

     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
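The three new lines are the closed-form CRPS of a Gaussian predictive distribution (the standard result from Gneiting & Raftery, 2007). With z = (y − μ)/σ, Φ the standard normal CDF, and φ its density:

    \mathrm{CRPS}\left(\mathcal{N}(\mu,\sigma^2),\, y\right) = \sigma\left[ z\,(2\Phi(z) - 1) + 2\varphi(z) - \tfrac{1}{\sqrt{\pi}} \right], \qquad z = \frac{y - \mu}{\sigma}

Here μ is the `prediction` column, σ is `prediction_std`, and `norm` is presumably `scipy.stats.norm`; lower is better, and as σ → 0 the score reduces to the absolute error |y − μ|.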
@@ -265,27 +268,33 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:

     # Collect results
     results = {
-        "coverage_95": coverage_95,
-        "coverage_90": coverage_90,
-        "coverage_80": coverage_80,
         "coverage_50": coverage_50,
-        "
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
+        "coverage_95": coverage_95,
+        "avg_std": avg_std,
+        "median_std": median_std,
         "avg_width_50": avg_width_50,
-        "
-        "
-        "
+        "avg_width_80": avg_width_80,
+        "avg_width_90": avg_width_90,
+        "avg_width_95": avg_width_95,
+        # "crps": mean_crps,
+        # "interval_score_95": mean_is_95,
+        # "adaptive_calibration": adaptive_calibration,
         "n_samples": len(df),
     }

     print("\n=== UQ Metrics ===")
-    print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
-    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
     print(f"Coverage @ 50%: {coverage_50:.3f} (target: 0.50)")
-    print(f"
-    print(f"
-    print(f"
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
+    print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
+    print(f"Avg Prediction StdDev: {avg_std:.3f}")
+    print(f"Median Prediction StdDev: {median_std:.3f}")
     print(f"Average 50% Width: {avg_width_50:.3f}")
+    print(f"Average 80% Width: {avg_width_80:.3f}")
+    print(f"Average 90% Width: {avg_width_90:.3f}")
+    print(f"Average 95% Width: {avg_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
     print(f"Adaptive Calibration: {adaptive_calibration:.3f} (higher is better, target: >0.5)")
@@ -325,9 +334,3 @@ if __name__ == "__main__":
     df = end.auto_inference(capture=True)
     results = uq_metrics(df, target_col="solubility")
     print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
{workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.176
+Version: 0.8.177
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License-Expression: MIT
{workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/RECORD
CHANGED
@@ -54,9 +54,9 @@ workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv
 workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
 workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
 workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
-workbench/core/artifacts/endpoint_core.py,sha256=
+workbench/core/artifacts/endpoint_core.py,sha256=Q6wL0IpMgCkVssX-BvPwawgogQjq9klSaoBUZ6tEIuc,49146
 workbench/core/artifacts/feature_set_core.py,sha256=055VdSYR09HP4ygAuYvIYtHQ7Ec4XxsZygpgEl5H5jQ,29136
-workbench/core/artifacts/model_core.py,sha256=
+workbench/core/artifacts/model_core.py,sha256=ECDwQ0qM5qb1yGJ07U70BVdfkrW9m7p9e6YJWib3uR0,50855
 workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
 workbench/core/cloud_platform/cloud_meta.py,sha256=-g4-LTC3D0PXb3VfaXdLR1ERijKuHdffeMK_zhD-koQ,8809
 workbench/core/cloud_platform/aws/README.md,sha256=QT5IQXoUHbIA0qQ2wO6_2P2lYjYQFVYuezc22mWY4i8,97
@@ -124,8 +124,8 @@ workbench/core/views/view_utils.py,sha256=y0YuPW-90nAfgAD1UW_49-j7Mvncfm7-5rV8I_
 workbench/core/views/storage/mdq_view.py,sha256=qf_ep1KwaXOIfO930laEwNIiCYP7VNOqjE3VdHfopRE,5195
 workbench/model_scripts/script_generation.py,sha256=dL23XYwEsHIStc7i53DtF_47FqOrI9gq0kQAT6sNpZ8,7923
 workbench/model_scripts/custom_models/chem_info/Readme.md,sha256=mH1lxJ4Pb7F5nBnVXaiuxpi8zS_yjUw_LBJepVKXhlA,574
-workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=
-workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256
+workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
+workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
 workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py,sha256=xljMjdfh4Idi4v1Afq1zZxvF1SDa7pDOLSAhvGBEj88,2891
 workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py,sha256=tMyMmeN1xajVWkqkV5mobYB8CYkzW9FRH8Vi3t81uo8,3231
 workbench/model_scripts/custom_models/chem_info/requirements.txt,sha256=7HBUzvNiM8lOir-UfQabXYlUp3gxdGJ42u18EuSMGjc,39
@@ -141,7 +141,7 @@ workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=U
 workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=0IJnSBACQ556ldEiPqR7yPCOOLJs1hQhHmPBvB2d9tY,13491
 workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=QbDUfkiPCwJ-c-4Twgu4utZuYZaAyeW_3T1IP-_tutw,6683
 workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=AcLf-vXOmn_vpTeiKpNKCW_dRhR8Co1sMFC84EPT4IE,22392
-workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=
+workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=Vou_g0ux-KOrs36S98g27Y8ckU9sdYrKWwypJjasQX4,18180
 workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=eawh0Fp3DhbdCXzWN6KloczT5ZS_ou4ayW65yUTTE4o,14109
 workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=9-O6P-SW50ul5Wl6es2DMWXSbrwOg7HWsdc8Qdln0MM,8278
 workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
@@ -159,7 +159,7 @@ workbench/model_scripts/quant_regression/requirements.txt,sha256=jWlGc7HH7vqyukT
 workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=c73ZpJBlU5k13Nx-ZDkLXu7da40CYyhwjwwmuPq6uLg,12870
 workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
 workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=d4pgeZYFezUQsB-7iIsjsUgB1FM6d27651wpfDdXmI0,12640
-workbench/model_scripts/xgb_model/generated_model_script.py,sha256=
+workbench/model_scripts/xgb_model/generated_model_script.py,sha256=BPhr2gfJQC1C26knsyktfLGL7Jp0YBKCIQjplCuHUg0,22218
 workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
 workbench/model_scripts/xgb_model/xgb_model.template,sha256=HViJRsMWn393hP8VJRS45UQBzUVBhwR5sKc8Ern-9f4,17963
 workbench/repl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -211,7 +211,6 @@ workbench/utils/ecs_info.py,sha256=Gs9jNb4vcj2pziufIOI4BVIH1J-3XBMtWm1phVh8oRY,2
 workbench/utils/endpoint_metrics.py,sha256=_4WVU6cLLuV0t_i0PSvhi0EoA5ss5aDFe7ZDpumx2R8,7822
 workbench/utils/endpoint_utils.py,sha256=3-njrhMSAIOaEEiH7qMA9vgD3I7J2S9iUAcqXKx3OBo,7104
 workbench/utils/extract_model_artifact.py,sha256=sFwkJd5mfJ1PU37pIHVmUIQS-taIUJdqi3D9-qRmy8g,7870
-workbench/utils/fast_inference.py,sha256=Sm0EV1oPsYYGqiDBVUu3Nj6Ti68JV-UR2S0ZliBDPTk,6148
 workbench/utils/glue_utils.py,sha256=dslfXQcJ4C-mGmsD6LqeK8vsXBez570t3fZBVZLV7HA,2039
 workbench/utils/graph_utils.py,sha256=T4aslYVbzPmFe0_qKCQP6PZnaw1KATNXQNVO-yDGBxY,10839
 workbench/utils/ipython_utils.py,sha256=skbdbBwUT-iuY3FZwy3ACS7-FWSe9M2qVXfLlQWnikE,700
@@ -220,7 +219,7 @@ workbench/utils/lambda_utils.py,sha256=7GhGRPyXn9o-toWb9HBGSnI8-DhK9YRkwhCSk_mNK
 workbench/utils/license_manager.py,sha256=sDuhk1mZZqUbFmnuFXehyGnui_ALxrmYBg7gYwoo7ho,6975
 workbench/utils/log_utils.py,sha256=7n1NJXO_jUX82e6LWAQug6oPo3wiPDBYsqk9gsYab_A,3167
 workbench/utils/markdown_utils.py,sha256=4lEqzgG4EVmLcvvKKNUwNxVCySLQKJTJmWDiaDroI1w,8306
-workbench/utils/model_utils.py,sha256=
+workbench/utils/model_utils.py,sha256=7TYxTa2KCoLJfJ47QcnzmibMwKHX3bP37-sPvfqgdVM,12273
 workbench/utils/monitor_utils.py,sha256=kVaJ7BgUXs3VPMFYfLC03wkIV4Dq-pEhoXS0wkJFxCc,7858
 workbench/utils/pandas_utils.py,sha256=uTUx-d1KYfjbS9PMQp2_9FogCV7xVZR6XLzU5YAGmfs,39371
 workbench/utils/performance_utils.py,sha256=WDNvz-bOdC99cDuXl0urAV4DJ7alk_V3yzKPwvqgST4,1329
@@ -247,8 +246,8 @@ workbench/utils/xgboost_model_utils.py,sha256=iiDJH0O81aO6aOTwgssqQygvTgjE7lRDRz
 workbench/utils/chem_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/utils/chem_utils/fingerprints.py,sha256=Qvs8jaUwguWUq3Q3j695MY0t0Wk3BvroW-oWBwalMUo,5255
 workbench/utils/chem_utils/misc.py,sha256=Nevf8_opu-uIPrv_1_0ubuFVVo2_fGUkMoLAHB3XAeo,7372
-workbench/utils/chem_utils/mol_descriptors.py,sha256=
-workbench/utils/chem_utils/mol_standardize.py,sha256
+workbench/utils/chem_utils/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
+workbench/utils/chem_utils/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
 workbench/utils/chem_utils/mol_tagging.py,sha256=8Bt6gHvyN8B2jvVuz12JgYMHVLDkCLnEPAfqkyMEoMc,9995
 workbench/utils/chem_utils/projections.py,sha256=smV-VTB-pqRrgn4DXyDIpuCYcopJdPZ54YoCQv60JY0,7480
 workbench/utils/chem_utils/salts.py,sha256=ZzFb6Z71Z_kMjVF-PKwHx0fn9pN9rPMj-oEY8Nt5JWA,9095
@@ -288,9 +287,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
 workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
 workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
 workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
-workbench-0.8.
+workbench-0.8.177.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
+workbench-0.8.177.dist-info/METADATA,sha256=sjKEEHLha3-tDo9uYsRtpjPTHV_pj5PkucHuc2WWxBM,9210
+workbench-0.8.177.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+workbench-0.8.177.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
+workbench-0.8.177.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+workbench-0.8.177.dist-info/RECORD,,
workbench/utils/fast_inference.py
DELETED
@@ -1,167 +0,0 @@
-"""Fast Inference on SageMaker Endpoints"""
-
-import pandas as pd
-from io import StringIO
-import logging
-from concurrent.futures import ThreadPoolExecutor
-
-# Sagemaker Imports
-import sagemaker
-from sagemaker.serializers import CSVSerializer
-from sagemaker.deserializers import CSVDeserializer
-from sagemaker import Predictor
-
-log = logging.getLogger("workbench")
-
-_CACHED_SM_SESSION = None
-
-
-def get_or_create_sm_session():
-    global _CACHED_SM_SESSION
-    if _CACHED_SM_SESSION is None:
-        _CACHED_SM_SESSION = sagemaker.Session()
-    return _CACHED_SM_SESSION
-
-
-def fast_inference(endpoint_name: str, eval_df: pd.DataFrame, sm_session=None, threads: int = 4) -> pd.DataFrame:
-    """Run inference on the Endpoint using the provided DataFrame
-
-    Args:
-        endpoint_name (str): The name of the Endpoint
-        eval_df (pd.DataFrame): The DataFrame to run predictions on
-        sm_session (sagemaker.session.Session, optional): SageMaker Session. If None, a cached session is created.
-        threads (int): The number of threads to use (default: 4)
-
-    Returns:
-        pd.DataFrame: The DataFrame with predictions
-    """
-    # Use cached session if none is provided
-    if sm_session is None:
-        sm_session = get_or_create_sm_session()
-
-    predictor = Predictor(
-        endpoint_name,
-        sagemaker_session=sm_session,
-        serializer=CSVSerializer(),
-        deserializer=CSVDeserializer(),
-    )
-
-    total_rows = len(eval_df)
-
-    def process_chunk(chunk_df: pd.DataFrame, start_index: int) -> pd.DataFrame:
-        log.info(f"Processing {start_index}:{min(start_index + chunk_size, total_rows)} out of {total_rows} rows...")
-        csv_buffer = StringIO()
-        chunk_df.to_csv(csv_buffer, index=False)
-        response = predictor.predict(csv_buffer.getvalue())
-        # CSVDeserializer returns a nested list: first row is headers
-        return pd.DataFrame.from_records(response[1:], columns=response[0])
-
-    # Sagemaker has a connection pool limit of 10
-    if threads > 10:
-        log.warning("Sagemaker has a connection pool limit of 10. Reducing threads to 10.")
-        threads = 10
-
-    # Compute the chunk size (divide number of threads)
-    chunk_size = max(1, total_rows // threads)
-
-    # We also need to ensure that the chunk size is not too big
-    if chunk_size > 100:
-        chunk_size = 100
-
-    # Split DataFrame into chunks and process them concurrently
-    chunks = [(eval_df[i : i + chunk_size], i) for i in range(0, total_rows, chunk_size)]
-    with ThreadPoolExecutor(max_workers=threads) as executor:
-        df_list = list(executor.map(lambda p: process_chunk(*p), chunks))
-
-    combined_df = pd.concat(df_list, ignore_index=True)
-
-    # Convert the types of the dataframe
-    combined_df = df_type_conversions(combined_df)
-    return combined_df
-
-
-def df_type_conversions(df: pd.DataFrame) -> pd.DataFrame:
-    """Convert the types of the dataframe that we get from an endpoint
-
-    Args:
-        df (pd.DataFrame): DataFrame to convert
-
-    Returns:
-        pd.DataFrame: Converted DataFrame
-    """
-    # Some endpoints will put in "N/A" values (for CSV serialization)
-    # We need to convert these to NaN and then run the conversions below
-    # Report on the number of N/A values in each column in the DataFrame
-    # For any count above 0 list the column name and the number of N/A values
-    na_counts = df.isin(["N/A"]).sum()
-    for column, count in na_counts.items():
-        if count > 0:
-            log.warning(f"{column} has {count} N/A values, converting to NaN")
-    pd.set_option("future.no_silent_downcasting", True)
-    df = df.replace("N/A", float("nan"))
-
-    # Convert data to numeric
-    # Note: Since we're using CSV serializers numeric columns often get changed to generic 'object' types
-
-    # Hard Conversion
-    # Note: We explicitly catch exceptions for columns that cannot be converted to numeric
-    for column in df.columns:
-        try:
-            df[column] = pd.to_numeric(df[column])
-        except ValueError:
-            # If a ValueError is raised, the column cannot be converted to numeric, so we keep it as is
-            pass
-        except TypeError:
-            # This typically means a duplicated column name, so confirm duplicate (more than 1) and log it
-            column_count = (df.columns == column).sum()
-            log.critical(f"{column} occurs {column_count} times in the DataFrame.")
-            pass
-
-    # Soft Conversion
-    # Convert columns to the best possible dtype that supports the pd.NA missing value.
-    df = df.convert_dtypes()
-
-    # Convert pd.NA placeholders to pd.NA
-    # Note: CSV serialization converts pd.NA to blank strings, so we have to put in placeholders
-    df.replace("__NA__", pd.NA, inplace=True)
-
-    # Check for True/False values in the string columns
-    for column in df.select_dtypes(include=["string"]).columns:
-        if df[column].str.lower().isin(["true", "false"]).all():
-            df[column] = df[column].str.lower().map({"true": True, "false": False})
-
-    # Return the Dataframe
-    return df
-
-
-if __name__ == "__main__":
-    """Exercise the Endpoint Utilities"""
-    import time
-    from workbench.api.endpoint import Endpoint
-    from workbench.utils.endpoint_utils import fs_training_data, fs_evaluation_data
-
-    # Create an Endpoint
-    my_endpoint_name = "abalone-regression"
-    my_endpoint = Endpoint(my_endpoint_name)
-    if not my_endpoint.exists():
-        print(f"Endpoint {my_endpoint_name} does not exist.")
-        exit(1)
-
-    # Get the training data
-    my_train_df = fs_training_data(my_endpoint)
-    print(my_train_df)
-
-    # Run Fast Inference and time it
-    my_sm_session = my_endpoint.sm_session
-    my_eval_df = fs_evaluation_data(my_endpoint)
-    start_time = time.time()
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df, my_sm_session)
-    end_time = time.time()
-    print(f"Fast Inference took {end_time - start_time} seconds")
-    print(my_results_df)
-    print(my_results_df.info())
-
-    # Test with no session
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df)
-    print(my_results_df)
-    print(my_results_df.info())
{workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/WHEEL
File without changes
{workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/entry_points.txt
File without changes
{workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/licenses/LICENSE
File without changes
{workbench-0.8.176.dist-info → workbench-0.8.177.dist-info}/top_level.txt
File without changes