workbench 0.8.176__py3-none-any.whl → 0.8.178__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of workbench might be problematic.

@@ -32,11 +32,11 @@ from sagemaker import Predictor
  from workbench.core.artifacts.artifact import Artifact
  from workbench.core.artifacts import FeatureSetCore, ModelCore, ModelType
  from workbench.utils.endpoint_metrics import EndpointMetrics
- from workbench.utils.fast_inference import fast_inference
  from workbench.utils.cache import Cache
  from workbench.utils.s3_utils import compute_s3_object_hash
  from workbench.utils.model_utils import uq_metrics
  from workbench.utils.xgboost_model_utils import cross_fold_inference
+ from workbench_bridges.endpoints.fast_inference import fast_inference


  class EndpointCore(Artifact):
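The only change in this hunk is the import path: `fast_inference` now comes from the separate `workbench_bridges` package instead of `workbench.utils`. For code that must run against both versions, a version-tolerant import is one option; this shim is our sketch, not part of the package:

```python
# Compatibility sketch (not from the package): resolve fast_inference across versions
try:
    # workbench 0.8.178+: fast_inference lives in the workbench_bridges package
    from workbench_bridges.endpoints.fast_inference import fast_inference
except ImportError:
    # Pre-0.8.178 fallback: the old workbench.utils location
    from workbench.utils.fast_inference import fast_inference
```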
@@ -1061,6 +1061,9 @@ if __name__ == "__main__":
      assert len(pred_results) == len(my_eval_df), "Predictions should match the number of sent rows"

      # Now we put in an invalid value
+     print("*" * 80)
+     print("NOW TESTING ERROR CONDITIONS...")
+     print("*" * 80)
      my_eval_df.at[42, "length"] = "invalid_value"
      pred_results = my_endpoint.inference(my_eval_df, drop_error_rows=True)
      print(f"Sent rows: {len(my_eval_df)}")
@@ -17,7 +17,7 @@ from workbench.core.artifacts.artifact import Artifact
  from workbench.core.artifacts.data_source_factory import DataSourceFactory
  from workbench.core.artifacts.athena_source import AthenaSource

- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Optional

  from workbench.utils.aws_utils import aws_throttle

@@ -509,6 +509,25 @@ class FeatureSetCore(Artifact):
          ].tolist()
          return hold_out_ids

+     def set_training_filter(self, filter_expression: Optional[str] = None):
+         """Set a filter expression for the training view for this FeatureSet
+
+         Args:
+             filter_expression (Optional[str]): A SQL filter expression (e.g., "age > 25 AND status = 'active'")
+                 If None or empty string, will reset to default training view with no filter
+                 (default: None)
+         """
+         from workbench.core.views import TrainingView
+
+         # Grab the existing holdout ids
+         holdout_ids = self.get_training_holdouts()
+
+         # Create a NEW training view
+         self.log.important(f"Setting Training Filter: {filter_expression}")
+         TrainingView.create(
+             self, id_column=self.id_column, holdout_ids=holdout_ids, filter_expression=filter_expression
+         )
+
      @classmethod
      def delete_views(cls, table: str, database: str):
          """Delete any views associated with this FeatureSet
@@ -707,7 +726,7 @@ if __name__ == "__main__":

      # Test getting the holdout ids
      print("Getting the hold out ids...")
-     holdout_ids = my_features.get_training_holdouts("id")
+     holdout_ids = my_features.get_training_holdouts()
      print(f"Holdout IDs: {holdout_ids}")

      # Get a sample of the data
@@ -729,16 +748,26 @@ if __name__ == "__main__":
      table = my_features.view("training").table
      df = my_features.query(f'SELECT id, name FROM "{table}"')
      my_holdout_ids = [id for id in df["id"] if id < 20]
-     my_features.set_training_holdouts("id", my_holdout_ids)
-
-     # Test the hold out set functionality with strings
-     print("Setting hold out ids (strings)...")
-     my_holdout_ids = [name for name in df["name"] if int(name.split(" ")[1]) > 80]
-     my_features.set_training_holdouts("name", my_holdout_ids)
+     my_features.set_training_holdouts(my_holdout_ids)

      # Get the training data
      print("Getting the training data...")
      training_data = my_features.get_training_data()
+     print(f"Training Data: {training_data.shape}")
+
+     # Test the filter expression functionality
+     print("Setting a filter expression...")
+     my_features.set_training_filter("id < 50 AND height > 65.0")
+     training_data = my_features.get_training_data()
+     print(f"Training Data: {training_data.shape}")
+     print(training_data)
+
+     # Remove training filter
+     print("Removing the filter expression...")
+     my_features.set_training_filter(None)
+     training_data = my_features.get_training_data()
+     print(f"Training Data: {training_data.shape}")
+     print(training_data)

      # Now delete the AWS artifacts associated with this Feature Set
      # print("Deleting Workbench Feature Set...")
@@ -37,35 +37,6 @@ class ModelType(Enum):
      UNKNOWN = "unknown"


- # Deprecated Images
- """
- # US East 1 images
- "py312-general-ml-training"
- ("us-east-1", "training", "0.1", "x86_64"): (
-     "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
- ),
- ("us-east-1", "inference", "0.1", "x86_64"): (
-     "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
- ),
-
- # US West 2 images
- ("us-west-2", "training", "0.1", "x86_64"): (
-     "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
- ),
- ("us-west-2", "inference", "0.1", "x86_64"): (
-     "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
- ),
-
- # ARM64 images
- ("us-east-1", "inference", "0.1", "arm64"): (
-     "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
- ("us-west-2", "inference", "0.1", "arm64"): (
-     "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
- """
-
-
  class ModelImages:
      """Class for retrieving workbench inference images"""

@@ -890,6 +861,14 @@ class ModelCore(Artifact):
              shap_data[key] = self.df_store.get(df_location)
          return shap_data or None

+     def cross_folds(self) -> dict:
+         """Retrieve the cross-fold inference results (only works for XGBoost models)
+
+         Returns:
+             dict: Dictionary with the cross-fold inference results
+         """
+         return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
+
      def supported_inference_instances(self) -> Optional[list]:
          """Retrieve the supported endpoint inference instance types

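`cross_folds()` is a thin accessor over the parameter-store path that cross-fold inference writes to. A retrieval sketch, assuming the public API re-exports a `Model` wrapper around `ModelCore`, that a model named `abalone-regression` exists, and that `param_store.get()` returns None for a missing key:

```python
from workbench.api import Model  # assumption: workbench.api exposes a Model wrapper

model = Model("abalone-regression")  # hypothetical model name
cross_fold_results = model.cross_folds()
if cross_fold_results is None:  # assumption: a miss in the param store yields None
    print("No cross-fold inference stored for this model")
else:
    print(cross_fold_results)
```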
@@ -3,7 +3,7 @@
  from typing import Union

  # Workbench Imports
- from workbench.api import DataSource, FeatureSet
+ from workbench.api import FeatureSet
  from workbench.core.views.view import View
  from workbench.core.views.create_view import CreateView
  from workbench.core.views.view_utils import get_column_list
@@ -34,6 +34,7 @@ class TrainingView(CreateView):
          source_table: str = None,
          id_column: str = None,
          holdout_ids: Union[list[str], list[int], None] = None,
+         filter_expression: str = None,
      ) -> Union[View, None]:
          """Factory method to create and return a TrainingView instance.

@@ -42,6 +43,8 @@
              source_table (str, optional): The table/view to create the view from. Defaults to None.
              id_column (str, optional): The name of the id column. Defaults to None.
              holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+             filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                 Defaults to None.

          Returns:
              Union[View, None]: The created View object (or None if failed to create the view)
@@ -69,28 +72,36 @@ class TrainingView(CreateView):
          else:
              id_column = instance.auto_id_column

-         # If we don't have holdout ids, create a default training view
-         if not holdout_ids:
-             instance._default_training_view(instance.data_source, id_column)
-             return View(instance.data_source, instance.view_name, auto_create_view=False)
+         # Enclose each column name in double quotes
+         sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+         # Build the training assignment logic
+         if holdout_ids:
+             # Format the list of holdout ids for SQL IN clause
+             if all(isinstance(id, str) for id in holdout_ids):
+                 formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+             else:
+                 formatted_holdout_ids = ", ".join(map(str, holdout_ids))

-         # Format the list of holdout ids for SQL IN clause
-         if holdout_ids and all(isinstance(id, str) for id in holdout_ids):
-             formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+             training_logic = f"""CASE
+                 WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                 ELSE True
+                 END AS training"""
          else:
-             formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+             # Default 80/20 split using modulo
+             training_logic = f"""CASE
+                 WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                 ELSE False
+                 END AS training"""

-         # Enclose each column name in double quotes
-         sql_columns = ", ".join([f'"{column}"' for column in column_list])
+         # Build WHERE clause if filter_expression is provided
+         where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""

          # Construct the CREATE VIEW query
          create_view_query = f"""
              CREATE OR REPLACE VIEW {instance.table} AS
-             SELECT {sql_columns}, CASE
-                 WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
-                 ELSE True
-             END AS training
-             FROM {instance.source_table}
+             SELECT {sql_columns}, {training_logic}
+             FROM {instance.source_table}{where_clause}
          """

          # Execute the CREATE VIEW query
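With the holdout CASE and the default 80/20 modulo split now two arms of a single `training_logic` expression, a `filter_expression` composes with either path. For illustration, roughly the view this would generate (column list abbreviated, whitespace approximate):

```python
from workbench.api import FeatureSet
from workbench.core.views import TrainingView

fs = FeatureSet("abalone_features")
view = TrainingView.create(
    fs, id_column="auto_id", holdout_ids=[1, 2, 3], filter_expression="diameter > 0.5"
)
# Approximate generated SQL:
#   CREATE OR REPLACE VIEW <table> AS
#   SELECT "auto_id", "diameter", ..., CASE
#       WHEN auto_id IN (1, 2, 3) THEN False
#       ELSE True
#       END AS training
#   FROM <source_table>
#   WHERE diameter > 0.5
df = view.pull_dataframe()
```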
@@ -99,43 +110,13 @@ class TrainingView(CreateView):
          # Return the View
          return View(instance.data_source, instance.view_name, auto_create_view=False)

-     # This is an internal method that's used to create a default training view
-     def _default_training_view(self, data_source: DataSource, id_column: str):
-         """Create a default view in Athena that assigns roughly 80% of the data to training
-
-         Args:
-             data_source (DataSource): The Workbench DataSource object
-             id_column (str): The name of the id column
-         """
-         self.log.important(f"Creating default Training View {self.table}...")
-
-         # Drop any columns generated from AWS
-         aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
-         column_list = [col for col in data_source.columns if col not in aws_cols]
-
-         # Enclose each column name in double quotes
-         sql_columns = ", ".join([f'"{column}"' for column in column_list])
-
-         # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
-         create_view_query = f"""
-             CREATE OR REPLACE VIEW "{self.table}" AS
-             SELECT {sql_columns}, CASE
-                 WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True -- Assign 80% to training
-                 ELSE False -- Assign roughly 20% to validation/test
-             END AS training
-             FROM {self.base_table_name}
-         """
-
-         # Execute the CREATE VIEW query
-         data_source.execute_statement(create_view_query)
-

  if __name__ == "__main__":
      """Exercise the Training View functionality"""
      from workbench.api import FeatureSet

      # Get the FeatureSet
-     fs = FeatureSet("test_features")
+     fs = FeatureSet("abalone_features")

      # Delete the existing training view
      training_view = TrainingView.create(fs)
@@ -152,9 +133,18 @@ if __name__ == "__main__":

      # Create a TrainingView with holdout ids
      my_holdout_ids = list(range(10))
-     training_view = TrainingView.create(fs, id_column="id", holdout_ids=my_holdout_ids)
+     training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)

      # Pull the training data
      df = training_view.pull_dataframe()
      print(df.head())
      print(df["training"].value_counts())
+     print(f"Shape: {df.shape}")
+     print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+     # Test the filter expression
+     training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+     df = training_view.pull_dataframe()
+     print(df.head())
+     print(f"Shape with filter: {df.shape}")
+     print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
@@ -91,16 +91,27 @@ import logging
  import pandas as pd
  import numpy as np
  import re
+ import time
+ from contextlib import contextmanager
  from rdkit import Chem
  from rdkit.Chem import Descriptors, rdCIPLabeler
  from rdkit.ML.Descriptors import MoleculeDescriptors
  from mordred import Calculator as MordredCalculator
  from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes

+
  logger = logging.getLogger("workbench")
  logger.setLevel(logging.DEBUG)


+ # Helper context manager for timing
+ @contextmanager
+ def timer(name):
+     start = time.time()
+     yield
+     print(f"{name}: {time.time() - start:.2f}s")
+
+
  def compute_stereochemistry_features(mol):
      """
      Compute stereochemistry descriptors using modern RDKit methods.
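The `timer` helper added above is a plain `contextlib` context manager: everything inside the `with` block is wall-clock timed and reported under a label. A self-contained sketch of the same pattern:

```python
import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    start = time.time()
    yield
    print(f"{name}: {time.time() - start:.2f}s")

# Any block of work can be timed by name
with timer("descriptor computation"):
    time.sleep(0.25)  # stand-in for a compute_descriptors(df) call
```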
@@ -280,9 +291,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
              descriptor_values.append([np.nan] * len(all_descriptors))

      # Create RDKit features DataFrame
-     rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames(), index=result.index)
+     rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())

      # Add RDKit features to result
+     # Remove any columns from result that exist in rdkit_features_df
+     result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
      result = pd.concat([result, rdkit_features_df], axis=1)

      # Compute Mordred descriptors
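Dropping the explicit `index=result.index` and instead removing overlapping columns before `pd.concat` prevents duplicate column labels when the input already carries descriptor columns (e.g., on a re-run). A standalone pandas illustration of the pattern:

```python
import pandas as pd

result = pd.DataFrame({"smiles": ["CCO"], "MolWt": [0.0]})    # stale MolWt from a prior run
features = pd.DataFrame({"MolWt": [46.07], "TPSA": [20.23]})  # freshly computed descriptors

# Without the drop, concat would produce two "MolWt" columns
result = result.drop(columns=result.columns.intersection(features.columns))
result = pd.concat([result, features], axis=1)
print(result.columns.tolist())  # ['smiles', 'MolWt', 'TPSA']
```

One caveat: with the explicit index gone, the axis-1 concat lines up row-for-row only when `result` itself carries a default RangeIndex.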
@@ -299,7 +312,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_

      # Compute Mordred descriptors
      valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-     mordred_df = calc.pandas(valid_mols, nproc=1)  # For serverless, use nproc=1
+     mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1

      # Replace values for invalid molecules with NaN
      for i, mol in enumerate(molecules):
@@ -310,10 +323,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
      for col in mordred_df.columns:
          mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")

-     # Set index to match result DataFrame
-     mordred_df.index = result.index
-
      # Add Mordred features to result
+     # Remove any columns from result that exist in mordred
+     result = result.drop(columns=result.columns.intersection(mordred_df.columns))
      result = pd.concat([result, mordred_df], axis=1)

      # Compute stereochemistry features if requested
@@ -326,9 +338,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
              stereo_features.append(stereo_dict)

          # Create stereochemistry DataFrame
-         stereo_df = pd.DataFrame(stereo_features, index=result.index)
+         stereo_df = pd.DataFrame(stereo_features)

          # Add stereochemistry features to result
+         result = result.drop(columns=result.columns.intersection(stereo_df.columns))
          result = pd.concat([result, stereo_df], axis=1)

          logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")
@@ -357,7 +370,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_


  if __name__ == "__main__":
-     import time
      from mol_standardize import standardize
      from workbench.api import DataSource

@@ -81,6 +81,8 @@ Usage:
  import logging
  from typing import Optional, Tuple
  import pandas as pd
+ import time
+ from contextlib import contextmanager
  from rdkit import Chem
  from rdkit.Chem import Mol
  from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
  RDLogger.DisableLog("rdApp.warning")


+ # Helper context manager for timing
+ @contextmanager
+ def timer(name):
+     start = time.time()
+     yield
+     print(f"{name}: {time.time() - start:.2f}s")
+
+
  class MolStandardizer:
      """
      Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
          Pipeline:
              1. Cleanup (remove Hs, disconnect metals, normalize)
              2. Get largest fragment (optional - only if remove_salts=True)
+             2a. Extract salt information BEFORE further modifications
              3. Neutralize charges
              4. Canonicalize tautomer (optional)

@@ -130,18 +141,24 @@ class MolStandardizer:

          try:
              # Step 1: Cleanup
-             mol = rdMolStandardize.Cleanup(mol, self.params)
-             if mol is None:
+             cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+             if cleaned_mol is None:
                  return None, None

+             # If not doing any transformations, return early
+             if not self.remove_salts and not self.canonicalize_tautomer:
+                 return cleaned_mol, None
+
              salt_smiles = None
+             mol = cleaned_mol

              # Step 2: Fragment handling (conditional based on remove_salts)
              if self.remove_salts:
-                 # Get parent molecule and extract salt information
-                 parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
+                 # Get parent molecule
+                 parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                  if parent_mol:
-                     salt_smiles = self._extract_salt(mol, parent_mol)
+                     # Extract salt BEFORE any modifications to parent
+                     salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                      mol = parent_mol
                  else:
                      return None, None
@@ -153,7 +170,7 @@ class MolStandardizer:
              if mol is None:
                  return None, salt_smiles

-             # Step 4: Canonicalize tautomer
+             # Step 4: Canonicalize tautomer (LAST STEP)
              if self.canonicalize_tautomer:
                  mol = self.tautomer_enumerator.Canonicalize(mol)

@@ -172,13 +189,22 @@ class MolStandardizer:
          - Mixtures: multiple large neutral organic fragments

          Args:
-             orig_mol: Original molecule (before FragmentParent)
-             parent_mol: Parent molecule (after FragmentParent)
+             orig_mol: Original molecule (after Cleanup, before FragmentParent)
+             parent_mol: Parent molecule (after FragmentParent, before tautomerization)

          Returns:
              SMILES string of salt components or None if no salts/mixture detected
          """
          try:
+             # Quick atom count check
+             if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                 return None
+
+             # Quick heavy atom difference check
+             heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+             if heavy_diff <= 0:
+                 return None
+
              # Get all fragments from original molecule
              orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)

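The two new guards let `_extract_salt` bail out before any fragment enumeration. A small RDKit illustration using the acetic acid/ethylamine pair from the test data (both counts are heavy-atom counts, since these molecules carry only implicit hydrogens):

```python
from rdkit import Chem

orig_mol = Chem.MolFromSmiles("CC(=O)O.CCN")  # cleaned molecule, counterion still attached
parent_mol = Chem.MolFromSmiles("CCN")        # what FragmentParent would keep

# Guard 1: identical atom counts mean nothing was stripped, so no salt
print(orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms())  # False (7 vs 3)

# Guard 2: the parent must be strictly smaller in heavy atoms
heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
print(heavy_diff)  # 4 -> positive, so proceed to extract "CC(=O)O" as the salt
```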
@@ -268,7 +294,7 @@ def standardize(
      if "orig_smiles" not in result.columns:
          result["orig_smiles"] = result[smiles_column]

-     # Initialize standardizer with salt removal control
+     # Initialize standardizer
      standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)

      def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
              log.error("Encountered missing or empty SMILES string")
              return pd.Series({"smiles": None, "salt": None})

+         # Early check for unreasonably long SMILES
+         if len(smiles) > 1000:
+             log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+             return pd.Series({"smiles": None, "salt": None})
+
          # Parse molecule
          mol = Chem.MolFromSmiles(smiles)
          if mol is None:
@@ -299,7 +330,9 @@ def standardize(
          if std_mol is not None:
              # Check if molecule is reasonable
              if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                 log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
+                 log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                 log.error(f"Original SMILES: {smiles}")
+                 return pd.Series({"smiles": None, "salt": salt_smiles})

          if std_mol is None:
              return pd.Series(
@@ -325,8 +358,11 @@ def standardize(


  if __name__ == "__main__":
-     import time
-     from workbench.api import DataSource
+
+     # Pandas display options for better readability
+     pd.set_option("display.max_columns", None)
+     pd.set_option("display.width", 1000)
+     pd.set_option("display.max_colwidth", 100)

      # Test with DataFrame including various salt forms
      test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
      )

      # General test
+     print("Testing standardization with full dataset...")
      standardize(test_data)

      # Remove the last two rows to avoid errors with None and INVALID
      test_data = test_data.iloc[:-2].reset_index(drop=True)

      # Test WITHOUT salt removal (keeps full molecule)
-     print("\nStandardization KEEPING salts (extract_salts=False):")
-     print("This preserves the full molecule including counterions")
+     print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
      result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-     display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
-     print(result_keep[display_cols].to_string())
+     display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+     print(result_keep[display_order])

      # Test WITH salt removal
      print("\n" + "=" * 70)
      print("Standardization REMOVING salts (extract_salts=True):")
-     print("This extracts parent molecule and records salt information")
      result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-     print(result_remove[display_cols].to_string())
+     print(result_remove[display_order])

-     # Test WITHOUT tautomerization (keeping salts)
+     # Test with problematic cases specifically
      print("\n" + "=" * 70)
-     print("Standardization KEEPING salts, NO tautomerization:")
-     result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
-     print(result_no_taut[display_cols].to_string())
+     print("Testing specific problematic cases:")
+     problem_cases = pd.DataFrame(
+         {
+             "smiles": [
+                 "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                 "CCO.CC",  # Should return CC as salt
+             ],
+             "compound_id": ["TEST_C002", "TEST_C005"],
+         }
+     )
+
+     problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+     print(problem_result[display_order])
+
+     # Performance test with larger dataset
+     from workbench.api import DataSource

-     # Show the difference for salt-containing molecules
-     print("\n" + "=" * 70)
-     print("Comparison showing differences:")
-     for idx, row in result_keep.iterrows():
-         keep_smiles = row["smiles"]
-         remove_smiles = result_remove.loc[idx, "smiles"]
-         no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-         salt = result_remove.loc[idx, "salt"]
-
-         # Show differences when they exist
-         if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-             print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-             if keep_smiles != no_taut_smiles:
-                 print(f"  With salt + taut: {keep_smiles}")
-                 print(f"  With salt, no taut: {no_taut_smiles}")
-             if keep_smiles != remove_smiles:
-                 print(f"  Parent only + taut: {remove_smiles}")
-             if salt:
-                 print(f"  Extracted salt: {salt}")
-
-     # Summary statistics
      print("\n" + "=" * 70)
-     print("Summary:")
-     print(f"Total molecules: {len(result_remove)}")
-     print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-     unique_salts = result_remove["salt"].dropna().unique()
-     print(f"Unique salts found: {unique_salts[:5].tolist()}")

-     # Get a real dataset from Workbench and time the standardization
      ds = DataSource("aqsol_data")
-     df = ds.pull_dataframe()[["id", "smiles"]]
-     start_time = time.time()
-     std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
-     end_time = time.time()
-     print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
-     print(std_df.head())
-     print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
-     unique_salts = std_df["salt"].dropna().unique()
-     print(f"Unique salts found: {unique_salts[:5].tolist()}")
+     df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+     for tautomer in [True, False]:
+         for extract in [True, False]:
+             print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+             start_time = time.time()
+             std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+             elapsed = time.time() - start_time
+             mol_per_sec = len(df) / elapsed
+             print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")