workbench 0.8.175__py3-none-any.whl → 0.8.177__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -238,8 +238,8 @@ class Artifact(ABC):
         """
 
         # Check for ReadOnly Role
-        if self.aws_account_clamp.read_only_role:
-            self.log.info("Cannot add metadata with a ReadOnly Role...")
+        if self.aws_account_clamp.read_only:
+            self.log.info("Cannot add metadata with ReadOnly Permissions...")
             return
 
         # Sanity check
@@ -32,11 +32,11 @@ from sagemaker import Predictor
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
 from workbench.utils.endpoint_metrics import EndpointMetrics
-from workbench.utils.fast_inference import fast_inference
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
 from workbench.utils.xgboost_model_utils import cross_fold_inference
+from workbench_bridges.endpoints.fast_inference import fast_inference
 
 
 class EndpointCore(Artifact):
@@ -1061,6 +1061,9 @@ if __name__ == "__main__":
     assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"
 
     # Now we put in an invalid value
+    print("*" * 80)
+    print("NOW TESTING ERROR CONDITIONS...")
+    print("*" * 80)
     my_eval_df.at[42, "length"] = "invalid_value"
     pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
     print(f"Sent rows: {len(my_eval_df)}")
@@ -37,35 +37,6 @@ class ModelType(Enum):
     UNKNOWN = "unknown"
 
 
-# Deprecated Images
-"""
-# US East 1 images
-"py312-general-ml-training"
-("us-east-1", "training", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-),
-("us-east-1", "inference", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-),
-
-# US West 2 images
-("us-west-2", "training", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-),
-("us-west-2", "inference", "0.1", "x86_64"): (
-    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-),
-
-# ARM64 images
-("us-east-1", "inference", "0.1", "arm64"): (
-    "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-),
-("us-west-2", "inference", "0.1", "arm64"): (
-    "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-),
-"""
-
-
 class ModelImages:
     """Class for retrieving workbench inference images"""
 
@@ -890,6 +861,14 @@ class ModelCore(Artifact):
             shap_data[key] = self.df_store.get(df_location)
         return shap_data or None
 
+    def cross_folds(self) -> dict:
+        """Retrieve the cross-fold inference results (only works for XGBoost models)
+
+        Returns:
+            dict: Dictionary with the cross-fold inference results
+        """
+        return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
     def supported_inference_instances(self) -> Optional[list]:
         """Retrieve the supported endpoint inference instance types
 
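Note: the new cross_folds() accessor just reads previously captured cross-fold results back out of the parameter store. A minimal usage sketch, assuming a deployed model named "abalone-regression" (a placeholder name, not taken from this hunk):

    from workbench.core.artifacts.model_core import ModelCore

    model = ModelCore("abalone-regression")  # hypothetical model name
    cv_results = model.cross_folds()  # dict stored under /workbench/models/<name>/inference/cross_fold
    print(cv_results)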
@@ -55,9 +55,10 @@ class AWSAccountClamp:
         # Check our Assume Role
         self.log.info("Checking Workbench Assumed Role...")
         role_info = self.aws_session.assumed_role_info()
+        self.log.info(f"Assumed Role: {role_info}")
 
-        # Check if the Role is a 'ReadOnly' role
-        self.read_only_role = "readonly" in role_info["AssumedRoleArn"].lower()
+        # Check if we have tag write permissions (if we don't, we are read-only)
+        self.read_only = not self.check_tag_permissions()
 
         # Check our Workbench API Key and Load the License
         self.log.info("Checking Workbench API License...")
@@ -141,6 +142,45 @@ class AWSAccountClamp:
         """
         return self.boto3_session.client("sagemaker")
 
+    def check_tag_permissions(self):
+        """Check if current role has permission to add tags to SageMaker endpoints.
+
+        Returns:
+            bool: True if AddTags is allowed, False otherwise
+        """
+        try:
+            sagemaker = self.boto3_session.client("sagemaker")
+
+            # Use a non-existent endpoint name
+            fake_endpoint = "workbench-permission-check-dummy-endpoint"
+
+            # Try to add tags to the non-existent endpoint
+            sagemaker.add_tags(
+                ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
+                Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
+            )
+
+            # If we get here, we have permission (but endpoint doesn't exist)
+            return True
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+
+            # AccessDeniedException = no permission
+            if error_code == "AccessDeniedException":
+                self.log.debug("No AddTags permission (AccessDeniedException)")
+                return False
+
+            # ResourceNotFound = we have permission, but endpoint doesn't exist
+            elif error_code in ["ResourceNotFound", "ValidationException"]:
+                self.log.debug("AddTags permission verified (resource not found)")
+                return True
+
+            # Unexpected error, assume no permission for safety
+            else:
+                self.log.debug(f"Unexpected error checking permissions: {error_code}")
+                return False
+
 
 if __name__ == "__main__":
     """Exercise the AWS Account Clamp Class"""
@@ -165,3 +205,9 @@ if __name__ == "__main__":
     print("\n\n*** AWS Sagemaker Session/Client Check ***")
     sm_client = aws_account_clamp.sagemaker_client()
     print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
+
+    print("\n\n*** AWS Tag Permission Check ***")
+    if aws_account_clamp.check_tag_permissions():
+        print("Tag Permission Check Success...")
+    else:
+        print("Tag Permission Check Failed...")
@@ -91,16 +91,27 @@ import logging
 import pandas as pd
 import numpy as np
 import re
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Descriptors, rdCIPLabeler
 from rdkit.ML.Descriptors import MoleculeDescriptors
 from mordred import Calculator as MordredCalculator
 from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes
 
+
 logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 def compute_stereochemistry_features(mol):
     """
     Compute stereochemistry descriptors using modern RDKit methods.
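Note: the timer helper added to both chem_info scripts is the standard context-manager timing idiom; any block run under `with timer(...)` gets its wall-clock duration printed on exit. A hedged usage sketch (the workload function is a placeholder):

    with timer("descriptor pass"):
        compute_heavy_work()  # hypothetical workload to be timed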
@@ -280,9 +291,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             descriptor_values.append([np.nan] * len(all_descriptors))
 
     # Create RDKit features DataFrame
-    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames(), index=result.index)
+    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())
 
     # Add RDKit features to result
+    # Remove any columns from result that exist in rdkit_features_df
+    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
    result = pd.concat([result, rdkit_features_df], axis=1)
 
     # Compute Mordred descriptors
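Note: dropping `result.columns.intersection(...)` before each concat matters because `pd.concat(axis=1)` keeps duplicate column names rather than overwriting them. A small self-contained illustration of the pattern:

    import pandas as pd

    result = pd.DataFrame({"smiles": ["CCO"], "molwt": [0.0]})  # stale precomputed column
    fresh = pd.DataFrame({"molwt": [46.07]})                    # newly computed descriptor
    naive = pd.concat([result, fresh], axis=1)
    assert list(naive.columns).count("molwt") == 2              # duplicated column
    deduped = result.drop(columns=result.columns.intersection(fresh.columns))
    clean = pd.concat([deduped, fresh], axis=1)
    assert list(clean.columns).count("molwt") == 1              # single fresh column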
@@ -299,7 +312,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
     # Compute Mordred descriptors
     valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-    mordred_df = calc.pandas(valid_mols, nproc=1)  # For serverless, use nproc=1
+    mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1
 
     # Replace values for invalid molecules with NaN
     for i, mol in enumerate(molecules):
@@ -310,10 +323,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
     for col in mordred_df.columns:
         mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")
 
-    # Set index to match result DataFrame
-    mordred_df.index = result.index
-
     # Add Mordred features to result
+    # Remove any columns from result that exist in mordred
+    result = result.drop(columns=result.columns.intersection(mordred_df.columns))
     result = pd.concat([result, mordred_df], axis=1)
 
     # Compute stereochemistry features if requested
@@ -326,9 +338,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             stereo_features.append(stereo_dict)
 
         # Create stereochemistry DataFrame
-        stereo_df = pd.DataFrame(stereo_features, index=result.index)
+        stereo_df = pd.DataFrame(stereo_features)
 
         # Add stereochemistry features to result
+        result = result.drop(columns=result.columns.intersection(stereo_df.columns))
         result = pd.concat([result, stereo_df], axis=1)
 
         logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")
@@ -357,7 +370,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
 
 if __name__ == "__main__":
-    import time
     from mol_standardize import standardize
     from workbench.api import DataSource
 
@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
         Pipeline:
         1. Cleanup (remove Hs, disconnect metals, normalize)
         2. Get largest fragment (optional - only if remove_salts=True)
+        2a. Extract salt information BEFORE further modifications
         3. Neutralize charges
         4. Canonicalize tautomer (optional)
 
@@ -130,18 +141,24 @@ class MolStandardizer:
 
         try:
             # Step 1: Cleanup
-            mol = rdMolStandardize.Cleanup(mol, self.params)
-            if mol is None:
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None
 
+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol
 
             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule and extract salt information
-                parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-                    salt_smiles = self._extract_salt(mol, parent_mol)
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None
@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles
 
-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)
 
@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments
 
         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)
 
         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
 
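Note: the two new guards let _extract_salt bail out cheaply whenever FragmentParent removed nothing. The fragment walk that follows relies on RDKit treating dot-separated SMILES as disconnected fragments; a quick illustration using one of the test cases from this diff:

    from rdkit import Chem

    mol = Chem.MolFromSmiles("CC(=O)O.CCN")      # acetic acid + ethylamine
    frags = Chem.GetMolFrags(mol, asMols=True)
    print([Chem.MolToSmiles(f) for f in frags])  # ['CC(=O)O', 'CCN']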
@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]
 
-    # Initialize standardizer with salt removal control
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
 
     def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})
 
+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:
@@ -299,7 +330,9 @@ def standardize(
        if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})
 
         if std_mol is None:
             return pd.Series(
@@ -325,8 +358,11 @@ def standardize(
 
 
 if __name__ == "__main__":
-    import time
-    from workbench.api import DataSource
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)
 
     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )
 
     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)
 
     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)
 
     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-    display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
-    print(result_keep[display_cols].to_string())
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])
 
     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[display_cols].to_string())
+    print(result_remove[display_order])
 
-    # Test WITHOUT tautomerization (keeping salts)
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("Standardization KEEPING salts, NO tautomerization:")
-    result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
-    print(result_no_taut[display_cols].to_string())
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource
 
-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f"  With salt + taut:   {keep_smiles}")
-                print(f"  With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f"  Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f"  Extracted salt: {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
 
-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-    start_time = time.time()
-    std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
-    end_time = time.time()
-    print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
-    print(std_df.head())
-    print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
-    unique_salts = std_df["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
@@ -472,9 +472,9 @@ def predict_fn(df, models) -> pd.DataFrame:
     # Add median (q_50) from XGBoost prediction
     df["q_50"] = df["prediction"]
 
-    # Calculate uncertainty metrics based on 95% interval
-    interval_width = df["q_975"] - df["q_025"]
-    df["prediction_std"] = interval_width / 3.92
+    # Calculate uncertainty metrics based on 50% interval
+    interval_width = df["q_75"] - df["q_25"]
+    df["prediction_std"] = interval_width / 1.348
 
     # Reorder the quantile columns for easier reading
     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
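Note: both divisors are normal-distribution interval widths expressed in standard deviations: the central 95% interval spans 2 × 1.96σ ≈ 3.92σ, and the interquartile range (q_75 − q_25) spans 2 × 0.674σ ≈ 1.349σ (the template rounds to 1.348). The change just backs the std estimate out of a narrower, more robust interval. A quick check of the constants:

    from scipy.stats import norm

    print(norm.ppf(0.975) - norm.ppf(0.025))  # ~3.920, 95% interval width in sigmas
    print(norm.ppf(0.75) - norm.ppf(0.25))    # ~1.349, IQR width in sigmas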
@@ -28,11 +28,11 @@ from typing import List, Tuple
 
 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "classifier",
-    "target": "class",
+    "model_type": "regressor",
+    "target": "udm_asy_res_value",
     "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/sol-class-f1-100/training",
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/pka-a1-reg-0-nightly-100-test/training",
     "train_all_data": True
 }
 
@@ -91,16 +91,27 @@ import logging
 import pandas as pd
 import numpy as np
 import re
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Descriptors, rdCIPLabeler
 from rdkit.ML.Descriptors import MoleculeDescriptors
 from mordred import Calculator as MordredCalculator
 from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes
 
+
 logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
++

+
 def compute_stereochemistry_features(mol):
     """
     Compute stereochemistry descriptors using modern RDKit methods.
@@ -280,9 +291,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             descriptor_values.append([np.nan] * len(all_descriptors))
 
     # Create RDKit features DataFrame
-    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames(), index=result.index)
+    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())
 
     # Add RDKit features to result
+    # Remove any columns from result that exist in rdkit_features_df
+    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
     result = pd.concat([result, rdkit_features_df], axis=1)
 
     # Compute Mordred descriptors
@@ -299,7 +312,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
     # Compute Mordred descriptors
     valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-    mordred_df = calc.pandas(valid_mols, nproc=1)  # For serverless, use nproc=1
+    mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1
 
     # Replace values for invalid molecules with NaN
     for i, mol in enumerate(molecules):
@@ -310,10 +323,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
     for col in mordred_df.columns:
         mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")
 
-    # Set index to match result DataFrame
-    mordred_df.index = result.index
-
     # Add Mordred features to result
+    # Remove any columns from result that exist in mordred
+    result = result.drop(columns=result.columns.intersection(mordred_df.columns))
     result = pd.concat([result, mordred_df], axis=1)
 
     # Compute stereochemistry features if requested
@@ -326,9 +338,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             stereo_features.append(stereo_dict)
 
         # Create stereochemistry DataFrame
-        stereo_df = pd.DataFrame(stereo_features, index=result.index)
+        stereo_df = pd.DataFrame(stereo_features)
 
         # Add stereochemistry features to result
+        result = result.drop(columns=result.columns.intersection(stereo_df.columns))
         result = pd.concat([result, stereo_df], axis=1)
 
         logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")
@@ -357,7 +370,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
 
 if __name__ == "__main__":
-    import time
     from mol_standardize import standardize
     from workbench.api import DataSource
 
@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
         Pipeline:
         1. Cleanup (remove Hs, disconnect metals, normalize)
         2. Get largest fragment (optional - only if remove_salts=True)
+        2a. Extract salt information BEFORE further modifications
         3. Neutralize charges
         4. Canonicalize tautomer (optional)
 
@@ -130,18 +141,24 @@ class MolStandardizer:
 
         try:
             # Step 1: Cleanup
-            mol = rdMolStandardize.Cleanup(mol, self.params)
-            if mol is None:
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None
 
+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol
 
             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule and extract salt information
-                parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-                    salt_smiles = self._extract_salt(mol, parent_mol)
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None
@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles
 
-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)
 
@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments
 
         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)
 
         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
 
@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]
 
-    # Initialize standardizer with salt removal control
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
 
     def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})
 
+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:
@@ -299,7 +330,9 @@ def standardize(
         if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})
 
         if std_mol is None:
             return pd.Series(
@@ -325,8 +358,11 @@ def standardize(
 
 
 if __name__ == "__main__":
-    import time
-    from workbench.api import DataSource
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)
 
     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )
 
     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)
 
     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)
 
     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-    display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
-    print(result_keep[display_cols].to_string())
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])
 
     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[display_cols].to_string())
+    print(result_remove[display_order])
 
-    # Test WITHOUT tautomerization (keeping salts)
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("Standardization KEEPING salts, NO tautomerization:")
-    result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
-    print(result_no_taut[display_cols].to_string())
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource
 
-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f"  With salt + taut:   {keep_smiles}")
-                print(f"  With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f"  Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f"  Extracted salt: {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
 
-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-    start_time = time.time()
-    std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
-    end_time = time.time()
-    print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
-    print(std_df.head())
-    print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
-    unique_salts = std_df["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
@@ -226,12 +226,18 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    avg_std = df["prediction_std"].mean()
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
     coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
     coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
@@ -242,12 +248,9 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     avg_width_50 = np.mean(upper_50 - lower_50)
 
     # --- CRPS (measures calibration + sharpness) ---
-    if "prediction_std" in df.columns:
-        z = (df[target_col] - df["prediction"]) / df["prediction_std"]
-        crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)
 
     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
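Note: dropping the guard means uq_metrics now requires a prediction_std column in all cases. The expression itself is the closed-form CRPS of a Gaussian predictive distribution N(mu, sigma^2) evaluated at the target y, with z = (y - mu) / sigma:

    CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))

where Phi and phi are the standard normal CDF and PDF — exactly the norm.cdf / norm.pdf terms in the code above.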
@@ -265,27 +268,33 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
 
     # Collect results
     results = {
-        "coverage_95": coverage_95,
-        "coverage_90": coverage_90,
-        "coverage_80": coverage_80,
         "coverage_50": coverage_50,
-        "avg_width_95": avg_width_95,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
+        "coverage_95": coverage_95,
+        "avg_std": avg_std,
+        "median_std": median_std,
         "avg_width_50": avg_width_50,
-        "crps": mean_crps,
-        "interval_score_95": mean_is_95,
-        "adaptive_calibration": adaptive_calibration,
+        "avg_width_80": avg_width_80,
+        "avg_width_90": avg_width_90,
+        "avg_width_95": avg_width_95,
+        # "crps": mean_crps,
+        # "interval_score_95": mean_is_95,
+        # "adaptive_calibration": adaptive_calibration,
         "n_samples": len(df),
     }
 
     print("\n=== UQ Metrics ===")
-    print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
-    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
     print(f"Coverage @ 50%: {coverage_50:.3f} (target: 0.50)")
-    print(f"Average 95% Width: {avg_width_95:.3f}")
-    print(f"Average 90% Width: {avg_width_90:.3f}")
-    print(f"Average 80% Width: {avg_width_80:.3f}")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
+    print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
+    print(f"Avg Prediction StdDev: {avg_std:.3f}")
+    print(f"Median Prediction StdDev: {median_std:.3f}")
     print(f"Average 50% Width: {avg_width_50:.3f}")
+    print(f"Average 80% Width: {avg_width_80:.3f}")
+    print(f"Average 90% Width: {avg_width_90:.3f}")
+    print(f"Average 95% Width: {avg_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
     print(f"Adaptive Calibration: {adaptive_calibration:.3f} (higher is better, target: >0.5)")
@@ -325,9 +334,3 @@ if __name__ == "__main__":
     df = end.auto_inference(capture=True)
     results = uq_metrics(df, target_col="solubility")
     print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: workbench
-Version: 0.8.175
+Version: 0.8.177
 Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
 Author-email: SuperCowPowers LLC <support@supercowpowers.com>
 License-Expression: MIT
@@ -48,19 +48,19 @@ workbench/cached/cached_model.py,sha256=iMc_fySUE5qau3feduVXMNb24JY0sBjt1g6WeLLc
 workbench/cached/cached_pipeline.py,sha256=QOVnEKu5RbIdlNpJUi-0Ebh0_-C68RigSPwKh4dvZTM,1948
 workbench/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/core/artifacts/__init__.py,sha256=ps7rA_rbWnDbvWbg4kvu--IKMY8WmbPRyv4Si0xub1Q,965
-workbench/core/artifacts/artifact.py,sha256=AtTw8wfMd-fi7cHJHsBAXHUk53kRW_6lyBwwsIbHw54,17750
+workbench/core/artifacts/artifact.py,sha256=WFGC1F61d7uFSRB7UTWYOF8O_wk8F9rn__THJL2veLM,17752
 workbench/core/artifacts/athena_source.py,sha256=RNmCe7s6uH4gVHpcdJcL84aSbF5Q1ahJBLLGwHYRXEU,26081
 workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv_7W6XCvxVGXXSfzzaft8,3775
 workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
 workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
 workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
-workbench/core/artifacts/endpoint_core.py,sha256=lwgiz0jttW8C4YqcKaA8nf231WI3kol-nLnKcAbFJko,49049
+workbench/core/artifacts/endpoint_core.py,sha256=Q6wL0IpMgCkVssX-BvPwawgogQjq9klSaoBUZ6tEIuc,49146
 workbench/core/artifacts/feature_set_core.py,sha256=055VdSYR09HP4ygAuYvIYtHQ7Ec4XxsZygpgEl5H5jQ,29136
-workbench/core/artifacts/model_core.py,sha256=6d5dV4DGUBgD9E_Gpk0F5x7OEc4oiDKokvA8m42vnK4,51724
+workbench/core/artifacts/model_core.py,sha256=ECDwQ0qM5qb1yGJ07U70BVdfkrW9m7p9e6YJWib3uR0,50855
 workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
 workbench/core/cloud_platform/cloud_meta.py,sha256=-g4-LTC3D0PXb3VfaXdLR1ERijKuHdffeMK_zhD-koQ,8809
 workbench/core/cloud_platform/aws/README.md,sha256=QT5IQXoUHbIA0qQ2wO6_2P2lYjYQFVYuezc22mWY4i8,97
-workbench/core/cloud_platform/aws/aws_account_clamp.py,sha256=vAVC_HEk1YGSlo5F2bhQlWUxPN2QgRe3ht73O42faWQ,6452
+workbench/core/cloud_platform/aws/aws_account_clamp.py,sha256=V5iVsoGvSRilARtTdExnt27QptzAcJaW0s3nm2B8-ow,8286
 workbench/core/cloud_platform/aws/aws_df_store.py,sha256=utRIlTCPwFneHHZ8_Z3Hw3rOJSeryiFA4wBtucxULRQ,15055
 workbench/core/cloud_platform/aws/aws_graph_store.py,sha256=ytYxQTplUmeWbsPmxyZbf6mO9qyTl60ewlJG8MyfyEY,9414
 workbench/core/cloud_platform/aws/aws_meta.py,sha256=eY9Pn6pl2yAyseACFb2nitR-0vLwG4i8CSEXe8Iaswc,34778
@@ -124,8 +124,8 @@ workbench/core/views/view_utils.py,sha256=y0YuPW-90nAfgAD1UW_49-j7Mvncfm7-5rV8I_
 workbench/core/views/storage/mdq_view.py,sha256=qf_ep1KwaXOIfO930laEwNIiCYP7VNOqjE3VdHfopRE,5195
 workbench/model_scripts/script_generation.py,sha256=dL23XYwEsHIStc7i53DtF_47FqOrI9gq0kQAT6sNpZ8,7923
 workbench/model_scripts/custom_models/chem_info/Readme.md,sha256=mH1lxJ4Pb7F5nBnVXaiuxpi8zS_yjUw_LBJepVKXhlA,574
-workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=N07kGqyLd9DE9S23WfPqXGO5NMQzNxe0jtl1RgtC4yY,18315
-workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256=-BMtNzZSbXFnfoxFESHdfg7yjXO83JVecpIEsj39eDM,17145
+workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
+workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
 workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py,sha256=xljMjdfh4Idi4v1Afq1zZxvF1SDa7pDOLSAhvGBEj88,2891
 workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py,sha256=tMyMmeN1xajVWkqkV5mobYB8CYkzW9FRH8Vi3t81uo8,3231
 workbench/model_scripts/custom_models/chem_info/requirements.txt,sha256=7HBUzvNiM8lOir-UfQabXYlUp3gxdGJ42u18EuSMGjc,39
@@ -141,7 +141,7 @@ workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=U
 workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=0IJnSBACQ556ldEiPqR7yPCOOLJs1hQhHmPBvB2d9tY,13491
 workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=QbDUfkiPCwJ-c-4Twgu4utZuYZaAyeW_3T1IP-_tutw,6683
 workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=AcLf-vXOmn_vpTeiKpNKCW_dRhR8Co1sMFC84EPT4IE,22392
-workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=VkFM0eZM2d-hzDbngk9s08DD5vn2nQRD4coCUfj36Fk,18181
+workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=Vou_g0ux-KOrs36S98g27Y8ckU9sdYrKWwypJjasQX4,18180
 workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=eawh0Fp3DhbdCXzWN6KloczT5ZS_ou4ayW65yUTTE4o,14109
 workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=9-O6P-SW50ul5Wl6es2DMWXSbrwOg7HWsdc8Qdln0MM,8278
 workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
@@ -159,7 +159,7 @@ workbench/model_scripts/quant_regression/requirements.txt,sha256=jWlGc7HH7vqyukT
 workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=c73ZpJBlU5k13Nx-ZDkLXu7da40CYyhwjwwmuPq6uLg,12870
 workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
 workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=d4pgeZYFezUQsB-7iIsjsUgB1FM6d27651wpfDdXmI0,12640
-workbench/model_scripts/xgb_model/generated_model_script.py,sha256=vxM9dxRwrAZoDwAkj-a7LNNcBNd3KpHdNrublpAIVQo,22194
+workbench/model_scripts/xgb_model/generated_model_script.py,sha256=BPhr2gfJQC1C26knsyktfLGL7Jp0YBKCIQjplCuHUg0,22218
 workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
 workbench/model_scripts/xgb_model/xgb_model.template,sha256=HViJRsMWn393hP8VJRS45UQBzUVBhwR5sKc8Ern-9f4,17963
 workbench/repl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -211,7 +211,6 @@ workbench/utils/ecs_info.py,sha256=Gs9jNb4vcj2pziufIOI4BVIH1J-3XBMtWm1phVh8oRY,2
 workbench/utils/endpoint_metrics.py,sha256=_4WVU6cLLuV0t_i0PSvhi0EoA5ss5aDFe7ZDpumx2R8,7822
 workbench/utils/endpoint_utils.py,sha256=3-njrhMSAIOaEEiH7qMA9vgD3I7J2S9iUAcqXKx3OBo,7104
 workbench/utils/extract_model_artifact.py,sha256=sFwkJd5mfJ1PU37pIHVmUIQS-taIUJdqi3D9-qRmy8g,7870
-workbench/utils/fast_inference.py,sha256=Sm0EV1oPsYYGqiDBVUu3Nj6Ti68JV-UR2S0ZliBDPTk,6148
 workbench/utils/glue_utils.py,sha256=dslfXQcJ4C-mGmsD6LqeK8vsXBez570t3fZBVZLV7HA,2039
 workbench/utils/graph_utils.py,sha256=T4aslYVbzPmFe0_qKCQP6PZnaw1KATNXQNVO-yDGBxY,10839
 workbench/utils/ipython_utils.py,sha256=skbdbBwUT-iuY3FZwy3ACS7-FWSe9M2qVXfLlQWnikE,700
@@ -220,7 +219,7 @@ workbench/utils/lambda_utils.py,sha256=7GhGRPyXn9o-toWb9HBGSnI8-DhK9YRkwhCSk_mNK
 workbench/utils/license_manager.py,sha256=sDuhk1mZZqUbFmnuFXehyGnui_ALxrmYBg7gYwoo7ho,6975
 workbench/utils/log_utils.py,sha256=7n1NJXO_jUX82e6LWAQug6oPo3wiPDBYsqk9gsYab_A,3167
 workbench/utils/markdown_utils.py,sha256=4lEqzgG4EVmLcvvKKNUwNxVCySLQKJTJmWDiaDroI1w,8306
-workbench/utils/model_utils.py,sha256=JeEztmFyDJ7yqRozDX0L6apuhLgKx1sgNlO5duB73qc,11938
+workbench/utils/model_utils.py,sha256=7TYxTa2KCoLJfJ47QcnzmibMwKHX3bP37-sPvfqgdVM,12273
 workbench/utils/monitor_utils.py,sha256=kVaJ7BgUXs3VPMFYfLC03wkIV4Dq-pEhoXS0wkJFxCc,7858
 workbench/utils/pandas_utils.py,sha256=uTUx-d1KYfjbS9PMQp2_9FogCV7xVZR6XLzU5YAGmfs,39371
 workbench/utils/performance_utils.py,sha256=WDNvz-bOdC99cDuXl0urAV4DJ7alk_V3yzKPwvqgST4,1329
@@ -247,8 +246,8 @@ workbench/utils/xgboost_model_utils.py,sha256=iiDJH0O81aO6aOTwgssqQygvTgjE7lRDRz
 workbench/utils/chem_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workbench/utils/chem_utils/fingerprints.py,sha256=Qvs8jaUwguWUq3Q3j695MY0t0Wk3BvroW-oWBwalMUo,5255
 workbench/utils/chem_utils/misc.py,sha256=Nevf8_opu-uIPrv_1_0ubuFVVo2_fGUkMoLAHB3XAeo,7372
-workbench/utils/chem_utils/mol_descriptors.py,sha256=N07kGqyLd9DE9S23WfPqXGO5NMQzNxe0jtl1RgtC4yY,18315
-workbench/utils/chem_utils/mol_standardize.py,sha256=-BMtNzZSbXFnfoxFESHdfg7yjXO83JVecpIEsj39eDM,17145
+workbench/utils/chem_utils/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
+workbench/utils/chem_utils/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
 workbench/utils/chem_utils/mol_tagging.py,sha256=8Bt6gHvyN8B2jvVuz12JgYMHVLDkCLnEPAfqkyMEoMc,9995
 workbench/utils/chem_utils/projections.py,sha256=smV-VTB-pqRrgn4DXyDIpuCYcopJdPZ54YoCQv60JY0,7480
 workbench/utils/chem_utils/salts.py,sha256=ZzFb6Z71Z_kMjVF-PKwHx0fn9pN9rPMj-oEY8Nt5JWA,9095
@@ -288,9 +287,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
 workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
 workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
 workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
-workbench-0.8.175.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
-workbench-0.8.175.dist-info/METADATA,sha256=hAjhM-oXEqxffYyDwawIsSdTv3iKsRs5_OiZw1sv2RQ,9210
-workbench-0.8.175.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-workbench-0.8.175.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
-workbench-0.8.175.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
-workbench-0.8.175.dist-info/RECORD,,
+workbench-0.8.177.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
+workbench-0.8.177.dist-info/METADATA,sha256=sjKEEHLha3-tDo9uYsRtpjPTHV_pj5PkucHuc2WWxBM,9210
+workbench-0.8.177.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+workbench-0.8.177.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
+workbench-0.8.177.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
+workbench-0.8.177.dist-info/RECORD,,
@@ -1,167 +0,0 @@
-"""Fast Inference on SageMaker Endpoints"""
-
-import pandas as pd
-from io import StringIO
-import logging
-from concurrent.futures import ThreadPoolExecutor
-
-# Sagemaker Imports
-import sagemaker
-from sagemaker.serializers import CSVSerializer
-from sagemaker.deserializers import CSVDeserializer
-from sagemaker import Predictor
-
-log = logging.getLogger("workbench")
-
-_CACHED_SM_SESSION = None
-
-
-def get_or_create_sm_session():
-    global _CACHED_SM_SESSION
-    if _CACHED_SM_SESSION is None:
-        _CACHED_SM_SESSION = sagemaker.Session()
-    return _CACHED_SM_SESSION
-
-
-def fast_inference(endpoint_name: str, eval_df: pd.DataFrame, sm_session=None, threads: int = 4) -> pd.DataFrame:
-    """Run inference on the Endpoint using the provided DataFrame
-
-    Args:
-        endpoint_name (str): The name of the Endpoint
-        eval_df (pd.DataFrame): The DataFrame to run predictions on
-        sm_session (sagemaker.session.Session, optional): SageMaker Session. If None, a cached session is created.
-        threads (int): The number of threads to use (default: 4)
-
-    Returns:
-        pd.DataFrame: The DataFrame with predictions
-    """
-    # Use cached session if none is provided
-    if sm_session is None:
-        sm_session = get_or_create_sm_session()
-
-    predictor = Predictor(
-        endpoint_name,
-        sagemaker_session=sm_session,
-        serializer=CSVSerializer(),
-        deserializer=CSVDeserializer(),
-    )
-
-    total_rows = len(eval_df)
-
-    def process_chunk(chunk_df: pd.DataFrame, start_index: int) -> pd.DataFrame:
-        log.info(f"Processing {start_index}:{min(start_index + chunk_size, total_rows)} out of {total_rows} rows...")
-        csv_buffer = StringIO()
-        chunk_df.to_csv(csv_buffer, index=False)
-        response = predictor.predict(csv_buffer.getvalue())
-        # CSVDeserializer returns a nested list: first row is headers
-        return pd.DataFrame.from_records(response[1:], columns=response[0])
-
-    # Sagemaker has a connection pool limit of 10
-    if threads > 10:
-        log.warning("Sagemaker has a connection pool limit of 10. Reducing threads to 10.")
-        threads = 10
-
-    # Compute the chunk size (divide number of threads)
-    chunk_size = max(1, total_rows // threads)
-
-    # We also need to ensure that the chunk size is not too big
-    if chunk_size > 100:
-        chunk_size = 100
-
-    # Split DataFrame into chunks and process them concurrently
-    chunks = [(eval_df[i : i + chunk_size], i) for i in range(0, total_rows, chunk_size)]
-    with ThreadPoolExecutor(max_workers=threads) as executor:
-        df_list = list(executor.map(lambda p: process_chunk(*p), chunks))
-
-    combined_df = pd.concat(df_list, ignore_index=True)
-
-    # Convert the types of the dataframe
-    combined_df = df_type_conversions(combined_df)
-    return combined_df
-
-
-def df_type_conversions(df: pd.DataFrame) -> pd.DataFrame:
-    """Convert the types of the dataframe that we get from an endpoint
-
-    Args:
-        df (pd.DataFrame): DataFrame to convert
-
-    Returns:
-        pd.DataFrame: Converted DataFrame
-    """
-    # Some endpoints will put in "N/A" values (for CSV serialization)
-    # We need to convert these to NaN and then run the conversions below
-    # Report on the number of N/A values in each column in the DataFrame
-    # For any count above 0 list the column name and the number of N/A values
-    na_counts = df.isin(["N/A"]).sum()
-    for column, count in na_counts.items():
-        if count > 0:
-            log.warning(f"{column} has {count} N/A values, converting to NaN")
-    pd.set_option("future.no_silent_downcasting", True)
-    df = df.replace("N/A", float("nan"))
-
-    # Convert data to numeric
-    # Note: Since we're using CSV serializers numeric columns often get changed to generic 'object' types
-
-    # Hard Conversion
-    # Note: We explicitly catch exceptions for columns that cannot be converted to numeric
-    for column in df.columns:
-        try:
-            df[column] = pd.to_numeric(df[column])
-        except ValueError:
-            # If a ValueError is raised, the column cannot be converted to numeric, so we keep it as is
-            pass
-        except TypeError:
-            # This typically means a duplicated column name, so confirm duplicate (more than 1) and log it
-            column_count = (df.columns == column).sum()
-            log.critical(f"{column} occurs {column_count} times in the DataFrame.")
-            pass
-
-    # Soft Conversion
-    # Convert columns to the best possible dtype that supports the pd.NA missing value.
-    df = df.convert_dtypes()
-
-    # Convert pd.NA placeholders to pd.NA
-    # Note: CSV serialization converts pd.NA to blank strings, so we have to put in placeholders
-    df.replace("__NA__", pd.NA, inplace=True)
-
-    # Check for True/False values in the string columns
-    for column in df.select_dtypes(include=["string"]).columns:
-        if df[column].str.lower().isin(["true", "false"]).all():
-            df[column] = df[column].str.lower().map({"true": True, "false": False})
-
-    # Return the Dataframe
-    return df
-
-
-if __name__ == "__main__":
-    """Exercise the Endpoint Utilities"""
-    import time
-    from workbench.api.endpoint import Endpoint
-    from workbench.utils.endpoint_utils import fs_training_data, fs_evaluation_data
-
-    # Create an Endpoint
-    my_endpoint_name = "abalone-regression"
-    my_endpoint = Endpoint(my_endpoint_name)
-    if not my_endpoint.exists():
-        print(f"Endpoint {my_endpoint_name} does not exist.")
-        exit(1)
-
-    # Get the training data
-    my_train_df = fs_training_data(my_endpoint)
-    print(my_train_df)
-
-    # Run Fast Inference and time it
-    my_sm_session = my_endpoint.sm_session
-    my_eval_df = fs_evaluation_data(my_endpoint)
-    start_time = time.time()
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df, my_sm_session)
-    end_time = time.time()
-    print(f"Fast Inference took {end_time - start_time} seconds")
-    print(my_results_df)
-    print(my_results_df.info())
-
-    # Test with no session
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df)
-    print(my_results_df)
-    print(my_results_df.info())
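Note: workbench/utils/fast_inference.py is removed outright rather than rewritten; per the import change in endpoint_core.py near the top of this diff, the same fast_inference entry point is now supplied by the separate workbench_bridges package. A hedged usage sketch, assuming the relocated function keeps the signature of the deleted module:

    from workbench_bridges.endpoints.fast_inference import fast_inference

    # Same call shape as before: endpoint name, eval DataFrame, optional session and thread count
    results_df = fast_inference("abalone-regression", my_eval_df, threads=4)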