workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.
Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -91,6 +91,8 @@ import logging
 import pandas as pd
 import numpy as np
 import re
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Descriptors, rdCIPLabeler
 from rdkit.ML.Descriptors import MoleculeDescriptors
@@ -101,6 +103,14 @@ logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)


+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 def compute_stereochemistry_features(mol):
     """
     Compute stereochemistry descriptors using modern RDKit methods.
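Note: a minimal usage sketch of the timer helper added above (illustrative, not part of the diff):

    # Hypothetical use of the new timer() context manager
    with timer("compute_descriptors"):
        features_df = compute_descriptors(df, include_mordred=True)
    # Prints e.g. "compute_descriptors: 3.42s" on exit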
@@ -280,9 +290,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             descriptor_values.append([np.nan] * len(all_descriptors))

     # Create RDKit features DataFrame
-    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames(), index=result.index)
+    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())

     # Add RDKit features to result
+    # Remove any columns from result that exist in rdkit_features_df
+    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
     result = pd.concat([result, rdkit_features_df], axis=1)

     # Compute Mordred descriptors
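Note: the drop-before-concat pattern above (repeated for the Mordred and stereochemistry blocks below) avoids duplicate column labels when a descriptor column already exists in the input. A tiny illustration with made-up values:

    import pandas as pd

    result = pd.DataFrame({"id": [1, 2], "MolWt": [0.0, 0.0]})
    rdkit_features_df = pd.DataFrame({"MolWt": [180.2, 46.1]})

    # Without the drop, concat would keep two "MolWt" columns
    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
    result = pd.concat([result, rdkit_features_df], axis=1)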
@@ -299,7 +311,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_

         # Compute Mordred descriptors
         valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-        mordred_df = calc.pandas(valid_mols, nproc=1)  # For serverless, use nproc=1
+        mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1

         # Replace values for invalid molecules with NaN
         for i, mol in enumerate(molecules):
@@ -310,10 +322,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
         for col in mordred_df.columns:
             mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")

-        # Set index to match result DataFrame
-        mordred_df.index = result.index
-
         # Add Mordred features to result
+        # Remove any columns from result that exist in mordred
+        result = result.drop(columns=result.columns.intersection(mordred_df.columns))
         result = pd.concat([result, mordred_df], axis=1)

         # Compute stereochemistry features if requested
@@ -326,9 +337,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             stereo_features.append(stereo_dict)

         # Create stereochemistry DataFrame
-        stereo_df = pd.DataFrame(stereo_features, index=result.index)
+        stereo_df = pd.DataFrame(stereo_features)

         # Add stereochemistry features to result
+        result = result.drop(columns=result.columns.intersection(stereo_df.columns))
         result = pd.concat([result, stereo_df], axis=1)

         logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")
@@ -357,7 +369,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_


 if __name__ == "__main__":
-    import time
     from mol_standardize import standardize
     from workbench.api import DataSource

@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")


+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
         Pipeline:
         1. Cleanup (remove Hs, disconnect metals, normalize)
         2. Get largest fragment (optional - only if remove_salts=True)
+        2a. Extract salt information BEFORE further modifications
         3. Neutralize charges
         4. Canonicalize tautomer (optional)

@@ -130,18 +141,24 @@ class MolStandardizer:

         try:
             # Step 1: Cleanup
-            mol = rdMolStandardize.Cleanup(mol, self.params)
-            if mol is None:
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None

+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol

             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule and extract salt information
-                parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-                    salt_smiles = self._extract_salt(mol, parent_mol)
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None
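Note: with this change the pipeline order is cleanup, fragment/salt extraction, neutralization, then tautomer canonicalization, and salt extraction always compares the cleaned (not yet tautomerized) molecule against its parent. A rough sketch of the underlying RDKit calls (hedged; self.params corresponds to an rdMolStandardize.CleanupParameters instance):

    from rdkit import Chem
    from rdkit.Chem.MolStandardize import rdMolStandardize

    params = rdMolStandardize.CleanupParameters()
    mol = Chem.MolFromSmiles("CC(=O)O.CCN")  # two-fragment input
    cleaned = rdMolStandardize.Cleanup(mol, params)            # Step 1
    parent = rdMolStandardize.FragmentParent(cleaned, params)  # Step 2
    # Salt info is derived by comparing `cleaned` and `parent` here, before Steps 3-4
    print(Chem.MolToSmiles(parent))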
@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles

-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)

@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments

         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)

         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)

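Note: the two early exits above skip fragment enumeration entirely when FragmentParent removed nothing. An illustration of the counts involved (hedged example, not package code):

    from rdkit import Chem

    orig = Chem.MolFromSmiles("CC(=O)O.CCN")  # two fragments, 7 heavy atoms total
    parent = Chem.MolFromSmiles("CCN")        # 3 heavy atoms
    print(orig.GetNumAtoms() == parent.GetNumAtoms())           # False -> something was removed
    print(orig.GetNumHeavyAtoms() - parent.GetNumHeavyAtoms())  # 4 -> salt atoms to account for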
@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]

-    # Initialize standardizer with salt removal control
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)

     def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})

+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:
@@ -299,7 +330,9 @@ def standardize(
         if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})

         if std_mol is None:
             return pd.Series(
@@ -325,8 +358,11 @@ def standardize(


 if __name__ == "__main__":
-    import time
-    from workbench.api import DataSource
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)

     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )

     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)

     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)

     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-    display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
-    print(result_keep[display_cols].to_string())
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])

     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[display_cols].to_string())
+    print(result_remove[display_order])

-    # Test WITHOUT tautomerization (keeping salts)
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("Standardization KEEPING salts, NO tautomerization:")
-    result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
-    print(result_no_taut[display_cols].to_string())
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource

-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f"  With salt + taut:   {keep_smiles}")
-                print(f"  With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f"  Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f"  Extracted salt:     {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")

-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-    start_time = time.time()
-    std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
-    end_time = time.time()
-    print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
-    print(std_df.head())
-    print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
-    unique_salts = std_df["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
@@ -17,18 +17,28 @@ log = logging.getLogger("workbench")

 def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
     """
-    Convert bitstring fingerprints to numpy matrix.
+    Convert fingerprints to numpy matrix.
+
+    Supports two formats (auto-detected):
+    - Bitstrings: "10110010..." → matrix of 0s and 1s
+    - Count vectors: "0,3,0,1,5,..." → matrix of counts (or binary if dtype=np.bool_)

     Args:
-        fingerprints: pandas Series or list of bitstring fingerprints
-        dtype: numpy data type (uint8 is default: np.bool_ is good for Jaccard computations
+        fingerprints: pandas Series or list of fingerprints
+        dtype: numpy data type (uint8 is default; np.bool_ for Jaccard computations)

     Returns:
         dense numpy array of shape (n_molecules, n_bits)
     """
-
-    # Dense matrix representation (we might support sparse in the future)
-    return np.array([list(fp) for fp in fingerprints], dtype=dtype)
+    # Auto-detect format based on first fingerprint
+    sample = str(fingerprints.iloc[0] if hasattr(fingerprints, "iloc") else fingerprints[0])
+    if "," in sample:
+        # Count vector format: comma-separated integers
+        matrix = np.array([list(map(int, fp.split(","))) for fp in fingerprints], dtype=dtype)
+    else:
+        # Bitstring format: each character is a bit
+        matrix = np.array([list(fp) for fp in fingerprints], dtype=dtype)
+    return matrix


 def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame:
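Note: a quick sketch of the two auto-detected input formats (illustrative 4-bit values):

    import numpy as np
    import pandas as pd

    bitstrings = pd.Series(["1011", "0010"])
    counts = pd.Series(["0,3,0,1", "2,0,0,5"])

    fingerprints_to_matrix(bitstrings)                  # uint8 matrix of 0s and 1s
    fingerprints_to_matrix(counts)                      # uint8 matrix of counts
    fingerprints_to_matrix(bitstrings, dtype=np.bool_)  # boolean matrix for Jaccard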
@@ -2,34 +2,18 @@

 import logging
 import base64
-import re
 from typing import Optional, Tuple
 from rdkit import Chem
 from rdkit.Chem import AllChem, Draw
 from rdkit.Chem.Draw import rdMolDraw2D

+# Workbench Imports
+from workbench.utils.color_utils import is_dark
+
 # Set up the logger
 log = logging.getLogger("workbench")


-def _is_dark(color: str) -> bool:
-    """Determine if an rgba color is dark based on RGB average.
-
-    Args:
-        color: Color in rgba(...) format
-
-    Returns:
-        True if the color is dark, False otherwise
-    """
-    match = re.match(r"rgba?\((\d+),\s*(\d+),\s*(\d+)", color)
-    if not match:
-        log.warning(f"Invalid color format: {color}, defaulting to dark")
-        return True  # Default to dark mode on error
-
-    r, g, b = map(int, match.groups())
-    return (r + g + b) / 3 < 128
-
-
 def _rgba_to_tuple(rgba: str) -> Tuple[float, float, float, float]:
     """Convert rgba string to normalized tuple (R, G, B, A).

@@ -75,7 +59,13 @@ def _configure_draw_options(options: Draw.MolDrawOptions, background: str) -> None:
         options: RDKit drawing options object
         background: Background color string
     """
-    if _is_dark(background):
+    try:
+        if is_dark(background):
+            rdMolDraw2D.SetDarkMode(options)
+        # Light backgrounds use RDKit defaults (no action needed)
+    except ValueError:
+        # Default to dark mode if color format is invalid
+        log.warning(f"Invalid color format: {background}, defaulting to dark mode")
         rdMolDraw2D.SetDarkMode(options)
     options.setBackgroundColour(_rgba_to_tuple(background))

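Note: the shared is_dark() helper now raises ValueError on unparseable colors instead of silently returning True, so each caller picks its own fallback (here: dark mode). A sketch of the new contract, assuming workbench.utils.color_utils.is_dark behaves as the tests below expect:

    from workbench.utils.color_utils import is_dark

    is_dark("rgba(0, 0, 0, 1)")        # True
    is_dark("rgba(255, 255, 255, 1)")  # False
    try:
        is_dark("invalid_color")
    except ValueError:
        pass  # caller-defined fallback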
@@ -137,7 +127,7 @@ def svg_from_smiles(
     drawer.DrawMolecule(mol)
     drawer.FinishDrawing()

-    # Encode SVG
+    # Encode SVG as base64 data URI
     svg = drawer.GetDrawingText()
     encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
     return f"data:image/svg+xml;base64,{encoded_svg}"
@@ -222,7 +212,7 @@ if __name__ == "__main__":
     # Test 6: Color parsing functions
     print("\n6. Testing color utility functions...")
     test_colors = [
-        ("invalid_color", True, (0.25, 0.25, 0.25, 1.0)),  # Should use defaults
+        ("invalid_color", None, (0.25, 0.25, 0.25, 1.0)),  # Should raise ValueError
         ("rgba(255, 255, 255, 1)", False, (1.0, 1.0, 1.0, 1.0)),
         ("rgba(0, 0, 0, 1)", True, (0.0, 0.0, 0.0, 1.0)),
         ("rgba(64, 64, 64, 0.5)", True, (0.251, 0.251, 0.251, 0.5)),
@@ -230,12 +220,20 @@ if __name__ == "__main__":
     ]

     for color, expected_dark, expected_tuple in test_colors:
-        is_dark_result = _is_dark(color)
-        tuple_result = _rgba_to_tuple(color)
-
-        dark_status = "✓" if is_dark_result == expected_dark else "✗"
-        print(f"  {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
+        try:
+            is_dark_result = is_dark(color)
+            if expected_dark is None:
+                print(f"  is_dark('{color[:20]}...'): Expected ValueError but got {is_dark_result}")
+            else:
+                dark_status = "✓" if is_dark_result == expected_dark else "✗"
+                print(f"  {dark_status} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")
+        except ValueError:
+            if expected_dark is None:
+                print(f"  ✓ is_dark('{color[:20]}...'): Correctly raised ValueError")
+            else:
+                print(f"  ✗ is_dark('{color[:20]}...'): Unexpected ValueError")

+        tuple_result = _rgba_to_tuple(color)
         # Check tuple values with tolerance for floating point
         tuple_match = all(abs(a - b) < 0.01 for a, b in zip(tuple_result, expected_tuple))
         tuple_status = "✓" if tuple_match else "✗"
@@ -0,0 +1,141 @@
+"""ChemProp utilities for Workbench models."""
+
+import logging
+import os
+from typing import Any, Tuple
+
+import pandas as pd
+
+from workbench.utils.aws_utils import pull_s3_data
+from workbench.utils.metrics_utils import compute_metrics_from_predictions
+from workbench.utils.model_utils import safe_extract_tarfile
+
+log = logging.getLogger("workbench")
+
+
+def download_and_extract_model(s3_uri: str, model_dir: str) -> None:
+    """Download model artifact from S3 and extract it.
+
+    Args:
+        s3_uri: S3 URI to the model artifact (model.tar.gz)
+        model_dir: Directory to extract model artifacts to
+    """
+    import awswrangler as wr
+
+    log.info(f"Downloading model from {s3_uri}...")
+
+    # Download to temp file
+    local_tar_path = os.path.join(model_dir, "model.tar.gz")
+    wr.s3.download(path=s3_uri, local_file=local_tar_path)
+
+    # Extract using safe extraction
+    log.info(f"Extracting to {model_dir}...")
+    safe_extract_tarfile(local_tar_path, model_dir)
+
+    # Cleanup tar file
+    os.unlink(local_tar_path)
+
+
+def load_chemprop_model_artifacts(model_dir: str) -> Tuple[Any, dict]:
+    """Load ChemProp MPNN model and artifacts from an extracted model directory.
+
+    Args:
+        model_dir: Directory containing extracted model artifacts
+
+    Returns:
+        Tuple of (MPNN model, artifacts_dict).
+        artifacts_dict contains 'label_encoder' and 'feature_metadata' if present.
+    """
+    import joblib
+    from chemprop import models
+
+    model_path = os.path.join(model_dir, "chemprop_model.pt")
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"No chemprop_model.pt found in {model_dir}")
+
+    model = models.MPNN.load_from_file(model_path)
+    model.eval()
+
+    # Load additional artifacts
+    artifacts = {}
+
+    label_encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(label_encoder_path):
+        artifacts["label_encoder"] = joblib.load(label_encoder_path)
+
+    feature_metadata_path = os.path.join(model_dir, "feature_metadata.joblib")
+    if os.path.exists(feature_metadata_path):
+        artifacts["feature_metadata"] = joblib.load(feature_metadata_path)
+
+    return model, artifacts
+
+
+def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Pull cross-validation results from AWS training artifacts.
+
+    This retrieves the validation predictions saved during model training and
+    computes metrics directly from them.
+
+    Note:
+        - Regression: Supports both single-target and multi-target models
+        - Classification: Only single-target is supported (with any number of classes)
+
+    Args:
+        workbench_model: Workbench model object
+
+    Returns:
+        Tuple of:
+            - DataFrame with computed metrics
+            - DataFrame with validation predictions
+    """
+
+    # Get the validation predictions from S3
+    s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+    predictions_df = pull_s3_data(s3_path)
+
+    if predictions_df is None:
+        raise ValueError(f"No validation predictions found at {s3_path}")
+
+    log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+    # Get target and class labels
+    target = workbench_model.target()
+    class_labels = workbench_model.class_labels()
+
+    # If single target just use the "prediction" column
+    if isinstance(target, str):
+        metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
+        return metrics_df, predictions_df
+
+    # Multi-target regression
+    metrics_list = []
+    for t in target:
+        # Prediction will be {target}_pred in multi-target case
+        pred_col = f"{t}_pred"
+
+        # Drop NaNs for this target
+        target_preds_df = predictions_df.dropna(subset=[t, pred_col])
+        metrics_df = compute_metrics_from_predictions(target_preds_df, t, class_labels, prediction_col=pred_col)
+        metrics_df.insert(0, "target", t)
+        metrics_list.append(metrics_df)
+    metrics_df = pd.concat(metrics_list, ignore_index=True) if metrics_list else pd.DataFrame()
+
+    return metrics_df, predictions_df
+
+
+if __name__ == "__main__":
+
+    # Tests for the ChemProp utilities
+    from workbench.api import Model
+
+    # Initialize Workbench model
+    model_name = "open-admet-chemprop-mt"
+    print(f"Loading Workbench model: {model_name}")
+    model = Model(model_name)
+    print(f"Model Framework: {model.model_framework}")
+
+    # Pull CV results
+    metrics_df, predictions_df = pull_cv_results(model)
+    print("\nTraining Metrics:")
+    print(metrics_df.to_string(index=False))
+    print(f"\nSample Predictions:\n{predictions_df.head().to_string(index=False)}")
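Note: a minimal sketch of using the new helpers together (the S3 URI is a hypothetical placeholder; artifact keys follow the loader above):

    import tempfile
    from workbench.utils.chemprop_utils import (
        download_and_extract_model,
        load_chemprop_model_artifacts,
    )

    s3_uri = "s3://my-bucket/models/model.tar.gz"  # hypothetical artifact location
    with tempfile.TemporaryDirectory() as model_dir:
        download_and_extract_model(s3_uri, model_dir)
        mpnn, artifacts = load_chemprop_model_artifacts(model_dir)
        encoder = artifacts.get("label_encoder")  # present only if saved at training time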
@@ -4,16 +4,13 @@ import os
 import sys
 import platform
 import logging
-import importlib.resources as resources  # noqa: F401 Python 3.9 compatibility
 from typing import Any, Dict
+from importlib.resources import files, as_file

 # Workbench imports
 from workbench.utils.license_manager import LicenseManager
 from workbench_bridges.utils.execution_environment import running_as_service

-# Python 3.9 compatibility
-from workbench.utils.resource_utils import get_resource_path
-

 class FatalConfigError(Exception):
     """Exception raised for errors in the configuration."""
@@ -172,8 +169,7 @@ class ConfigManager:
         Returns:
             str: The open source API key.
         """
-        # Python 3.9 compatibility
-        with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
+        with as_file(files("workbench.resources").joinpath("open_source_api.key")) as open_source_key_path:
             with open(open_source_key_path, "r") as key_file:
                 return key_file.read().strip()

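Note: this swaps the compatibility shim for the stdlib files()/as_file() API; as_file() yields a real filesystem path even when the package is installed inside a zip or wheel. A standalone sketch of the pattern:

    from importlib.resources import files, as_file

    resource = files("workbench.resources").joinpath("open_source_api.key")
    with as_file(resource) as key_path:
        key = key_path.read_text().strip()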
@@ -7,9 +7,7 @@ from typing import Union, Optional
 import pandas as pd

 # Workbench Imports
-from workbench.api.feature_set import FeatureSet
-from workbench.api.model import Model
-from workbench.api.endpoint import Endpoint
+from workbench.api import FeatureSet, Model, Endpoint

 # Set up the log
 log = logging.getLogger("workbench")
@@ -77,7 +75,7 @@ def internal_model_data_url(endpoint_config_name: str, session: boto3.Session) -
     return None


-def fs_training_data(end: Endpoint) -> pd.DataFrame:
+def get_training_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the training data from the FeatureSet used to train the Model

     Args:
@@ -100,7 +98,7 @@ def fs_training_data(end: Endpoint) -> pd.DataFrame:
     return train_df


-def fs_evaluation_data(end: Endpoint) -> pd.DataFrame:
+def get_evaluation_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the evaluation data from the FeatureSet NOT used for training

     Args:
@@ -178,11 +176,11 @@ if __name__ == "__main__":
     print(model_data_url)

     # Get the training data
-    my_train_df = fs_training_data(my_endpoint)
+    my_train_df = get_training_data(my_endpoint)
     print(my_train_df)

     # Get the evaluation data
-    my_eval_df = fs_evaluation_data(my_endpoint)
+    my_eval_df = get_evaluation_data(my_endpoint)
     print(my_eval_df)

     # Backtrack to the FeatureSet