workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,450 @@
+"""
+mol_standardize.py - Molecular Standardization for ADMET Preprocessing
+Following the ChEMBL structure standardization pipeline
+
+Purpose:
+    Standardizes chemical structures to ensure consistent molecular representations
+    for ADMET modeling. Handles tautomers, salts, charges, and structural variations
+    that can cause the same compound to be represented differently.
+
+Standardization Pipeline:
+    1. Cleanup
+       - Removes explicit hydrogens
+       - Disconnects metal atoms from organic fragments
+       - Normalizes functional groups (e.g., nitro, sulfoxide representations)
+
+    2. Fragment Parent Selection (optional, controlled by extract_salts parameter)
+       - Identifies and keeps the largest organic fragment
+       - Removes salts, solvents, and counterions
+       - Example: [Na+].CC(=O)[O-] → CC(=O)O (keeps acetate, removes sodium)
+
+    3. Charge Neutralization (optional, controlled by extract_salts parameter)
+       - Neutralizes charges where possible
+       - Only applied when extract_salts=True (following the ChEMBL pipeline)
+       - Skipped when extract_salts=False to preserve ionic character
+       - Example: CC(=O)[O-] → CC(=O)O
+
+    4. Tautomer Canonicalization (optional, default=True)
+       - Generates the canonical tautomer form for consistency
+       - Example: Oc1ccccn1 → O=c1cccc[nH]1 (2-hydroxypyridine → 2-pyridone)
+
+Output DataFrame Columns:
+    - orig_smiles: Original input SMILES (preserved for traceability)
+    - smiles: Standardized molecule (with or without salts based on extract_salts)
+    - salt: Removed salt/counterion as SMILES (only populated if extract_salts=True)
+
+Salt Handling:
+    Salt forms can dramatically affect properties like solubility.
+    This module offers two modes for handling salts:
+
+    When extract_salts=True (default, ChEMBL standard):
+        - Removes salts/counterions to get the parent molecule
+        - Neutralizes charges on the parent
+        - Records removed salts in the 'salt' column
+        Input: [Na+].CC(=O)[O-] → Parent: CC(=O)O, Salt: [Na+]
+
+    When extract_salts=False (preserve full salt form):
+        - Keeps all fragments including salts/counterions
+        - Preserves ionic charges (no neutralization)
+
+Mixture Detection:
+    The module detects and logs potential mixtures (vs true salt forms):
+    - Multiple large neutral organic fragments indicate a mixture
+    - Mixtures are logged but NOT recorded in the salt column
+    - True salts (small/charged fragments) are properly extracted
+
+Downstream modeling options:
+    1. Use the parent only (standard approach for most ADMET properties)
+    2. Include the salt as a categorical or computed feature
+    3. Model parent + salt effects hierarchically
+    4. Use the full salt form for properties like solubility/formulation
+
+References:
+    - "ChEMBL Structure Pipeline" (Bento et al., 2020)
+      https://doi.org/10.1186/s13321-020-00456-1
+    - "Standardization and Validation with the RDKit" (Greg Landrum, RSC Open Science 2021)
+      https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/Standardization%20and%20Validation%20with%20the%20RDKit.ipynb
+
+Usage:
+    from mol_standardize import standardize
+
+    # Basic usage (removes salts by default, ChEMBL standard)
+    df_std = standardize(df)  # the 'smiles' column is auto-detected
+
+    # Keep salts in the molecule (preserve ionic forms)
+    df_std = standardize(df, extract_salts=False)
+
+    # Without tautomer canonicalization (faster, less aggressive)
+    df_std = standardize(df, canonicalize_tautomer=False)
+"""
+
+import logging
+from typing import Optional, Tuple
+import pandas as pd
+import time
+from contextlib import contextmanager
+from rdkit import Chem
+from rdkit.Chem import Mol
+from rdkit.Chem.MolStandardize import rdMolStandardize
+from rdkit import RDLogger
+
+log = logging.getLogger("workbench")
+RDLogger.DisableLog("rdApp.warning")
+
+
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
+class MolStandardizer:
+    """
+    Streamlined molecular standardizer for ADMET preprocessing.
+    Uses the ChEMBL standardization pipeline with RDKit.
+    """
+
+    def __init__(self, canonicalize_tautomer: bool = True, remove_salts: bool = True):
+        """
+        Initialize the standardizer with ChEMBL defaults
+
+        Args:
+            canonicalize_tautomer: Whether to canonicalize tautomers (default True)
+            remove_salts: Whether to remove salts/counterions (default True)
+        """
+        self.canonicalize_tautomer = canonicalize_tautomer
+        self.remove_salts = remove_salts
+        self.params = rdMolStandardize.CleanupParameters()
+        self.tautomer_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
+
+    def standardize(self, mol: Mol) -> Tuple[Optional[Mol], Optional[str]]:
+        """
+        Main standardization pipeline for ADMET
+
+        Pipeline:
+            1. Cleanup (remove Hs, disconnect metals, normalize)
+            2. Get largest fragment (optional - only if remove_salts=True)
+               2a. Extract salt information BEFORE further modifications
+            3. Neutralize charges
+            4. Canonicalize tautomer (optional)
+
+        Args:
+            mol: RDKit molecule object
+
+        Returns:
+            Tuple of (standardized molecule or None if failed, salt SMILES or None)
+        """
+        if mol is None:
+            return None, None
+
+        try:
+            # Step 1: Cleanup
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
+                return None, None
+
+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
+            salt_smiles = None
+            mol = cleaned_mol
+
+            # Step 2: Fragment handling (conditional based on remove_salts)
+            if self.remove_salts:
+                # Get the parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
+                if parent_mol:
+                    # Extract the salt BEFORE any modifications to the parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
+                    mol = parent_mol
+                else:
+                    return None, None
+            # If not removing salts, keep the full molecule intact
+
+            # Step 3: Neutralize charges (skip if keeping salts to preserve ionic forms)
+            if self.remove_salts:
+                mol = rdMolStandardize.ChargeParent(mol, self.params, skipStandardize=True)
+                if mol is None:
+                    return None, salt_smiles
+
+            # Step 4: Canonicalize tautomer (LAST STEP)
+            if self.canonicalize_tautomer:
+                mol = self.tautomer_enumerator.Canonicalize(mol)
+
+            return mol, salt_smiles
+
+        except Exception as e:
+            log.warning(f"Standardization failed: {e}")
+            return None, None
+
+    def _extract_salt(self, orig_mol: Mol, parent_mol: Mol) -> Optional[str]:
+        """
+        Extract salt/counterion by comparing the original and parent molecules.
+
+        Detects and handles mixtures vs true salt forms:
+        - True salts: small (<= 6 heavy atoms) or charged fragments
+        - Mixtures: multiple large neutral organic fragments
+
+        Args:
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)
+
+        Returns:
+            SMILES string of salt components or None if no salts/mixture detected
+        """
+        try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
+            # Get all fragments from the original molecule
+            orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
+
+            # If only one fragment, no salt
+            if len(orig_frags) <= 1:
+                return None
+
+            # Get the canonical SMILES of the parent for comparison
+            parent_smiles = Chem.MolToSmiles(parent_mol, canonical=True)
+
+            # Separate fragments into salts vs potential mixture components
+            salt_frags = []
+            mixture_frags = []
+
+            for frag in orig_frags:
+                frag_smiles = Chem.MolToSmiles(frag, canonical=True)
+
+                # Skip the parent fragment
+                if frag_smiles == parent_smiles:
+                    continue
+
+                # Classify the fragment as a salt or a mixture component
+                num_heavy = frag.GetNumHeavyAtoms()
+                has_charge = any(atom.GetFormalCharge() != 0 for atom in frag.GetAtoms())
+
+                # More nuanced classification
+                if has_charge and num_heavy <= 10:  # Small charged fragment - likely a salt
+                    salt_frags.append(frag_smiles)
+                elif not has_charge and num_heavy <= 6:  # Small neutral - could be solvent/salt
+                    salt_frags.append(frag_smiles)
+                else:
+                    # Large neutral fragment - likely part of a mixture
+                    mixture_frags.append(frag_smiles)
+
+            # Check if this looks like a mixture
+            if mixture_frags:
+                # Log mixture detection
+                total_frags = len(orig_frags)
+                log.warning(
+                    f"Mixture detected: {total_frags} total fragments, "
+                    f"{len(mixture_frags)} large neutral organics. "
+                    f"Removing: {'.'.join(mixture_frags + salt_frags)}"
+                )
+                # Return None for mixtures - don't pollute the salt column
+                return None
+
+            # Return actual salts only
+            return ".".join(salt_frags) if salt_frags else None
+
+        except Exception as e:
+            log.info(f"Salt extraction failed: {e}")
+            return None
+
+
+def standardize(
+    df: pd.DataFrame,
+    canonicalize_tautomer: bool = True,
+    extract_salts: bool = True,
+) -> pd.DataFrame:
+    """
+    Standardize molecules in a DataFrame for ADMET modeling
+
+    Args:
+        df: Input DataFrame with a SMILES column
+        canonicalize_tautomer: Whether to canonicalize tautomers (default: True)
+        extract_salts: Whether to remove and extract salts (default: True)
+            If False, keeps the full molecule with salts/counterions intact,
+            skipping charge neutralization to preserve ionic character
+
+    Returns:
+        DataFrame with:
+        - orig_smiles: Original SMILES (preserved)
+        - smiles: Standardized SMILES (working column for downstream)
+        - salt: Removed salt/counterion SMILES (only if extract_salts=True);
+          None for mixtures or when no true salts are present
+    """
+
+    # Check for the smiles column (any capitalization)
+    smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
+    if smiles_column is None:
+        raise ValueError("Input DataFrame must have a 'smiles' column")
+
+    # Copy the input DataFrame to avoid modifying the original
+    result = df.copy()
+
+    # Preserve the original SMILES if not already saved
+    if "orig_smiles" not in result.columns:
+        result["orig_smiles"] = result[smiles_column]
+
+    # Initialize the standardizer
+    standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
+
+    def process_smiles(smiles: str) -> pd.Series:
+        """
+        Process a single SMILES string through the standardization pipeline
+
+        Args:
+            smiles: Input SMILES string
+
+        Returns:
+            Series with the standardized SMILES and extracted salt (if applicable)
+        """
+        # Handle missing values
+        if pd.isna(smiles) or smiles == "":
+            log.error("Encountered missing or empty SMILES string")
+            return pd.Series({"smiles": None, "salt": None})
+
+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
+        # Parse the molecule
+        mol = Chem.MolFromSmiles(smiles)
+        if mol is None:
+            log.error(f"Invalid SMILES: {smiles}")
+            return pd.Series({"smiles": None, "salt": None})
+
+        # Full standardization with optional salt removal
+        std_mol, salt_smiles = standardizer.standardize(mol)
+
+        # After standardization, validate the result
+        if std_mol is not None:
+            # Check if the molecule is a reasonable size
+            if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})
+
+        if std_mol is None:
+            return pd.Series(
+                {
+                    "smiles": None,
+                    "salt": salt_smiles,  # May have extracted a salt even if full standardization failed
+                }
+            )
+
+        # Convert back to SMILES
+        return pd.Series(
+            {"smiles": Chem.MolToSmiles(std_mol, canonical=True), "salt": salt_smiles if extract_salts else None}
+        )
+
+    # Process the molecules
+    processed = result[smiles_column].apply(process_smiles)
+
+    # Update the DataFrame with the processed results
+    for col in ["smiles", "salt"]:
+        result[col] = processed[col]
+
+    return result
+
+
+if __name__ == "__main__":
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)
+
+    # Test with a DataFrame including various salt forms
+    test_data = pd.DataFrame(
+        {
+            "smiles": [
+                # Organic salts
+                "[Na+].CC(=O)[O-]",  # Sodium acetate
+                "CC(=O)O.CCN",  # Acetic acid + ethylamine (acid-base pair)
+                # Tautomers
+                "CC(=O)CC(C)=O",  # Acetylacetone - tautomer
+                "c1ccc(O)nc1",  # 2-hydroxypyridine/2-pyridone - tautomer
+                # Multi-fragment
+                "CCO.CC",  # Ethanol + methane mixture
+                # Simple organics
+                "CC(C)(C)c1ccccc1",  # tert-butylbenzene
+                # Carbonate salts
+                "[Na+].[Na+].[O-]C([O-])=O",  # Sodium carbonate
+                "[Li+].[Li+].[O-]C([O-])=O",  # Lithium carbonate
+                "[K+].[K+].[O-]C([O-])=O",  # Potassium carbonate
+                "[Mg++].[O-]C([O-])=O",  # Magnesium carbonate
+                "[Ca++].[O-]C([O-])=O",  # Calcium carbonate
+                # Drug salts
+                "CC(C)NCC(O)c1ccc(O)c(O)c1.Cl",  # Isoproterenol HCl
+                "CN1CCC[C@H]1c2cccnc2.[Cl-]",  # Nicotine HCl
+                # Tautomer with salt
+                "c1ccc(O)nc1.Cl",  # 2-hydroxypyridine with HCl
+                # Edge cases
+                None,  # Missing value
+                "INVALID",  # Invalid SMILES
+            ],
+            "compound_id": [f"C{i:03d}" for i in range(1, 17)],
+        }
+    )
+
+    # General test
+    print("Testing standardization with the full dataset...")
+    standardize(test_data)
+
+    # Remove the last two rows to avoid errors with None and INVALID
+    test_data = test_data.iloc[:-2].reset_index(drop=True)
+
+    # Test WITHOUT salt removal (keeps the full molecule)
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
+    result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])
+
+    # Test WITH salt removal
+    print("\n" + "=" * 70)
+    print("Standardization REMOVING salts (extract_salts=True):")
+    result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
+    print(result_remove[display_order])
+
+    # Test specific problematic cases
+    print("\n" + "=" * 70)
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with a larger dataset
+    from workbench.api import DataSource
+
+    print("\n" + "=" * 70)
+
+    ds = DataSource("aqsol_data")
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
@@ -7,13 +7,13 @@
 #
 import argparse
 import os
-import joblib
 from io import StringIO
 import pandas as pd
 import json

 # Local imports
-from local_utils import compute_molecular_descriptors
+from mol_standardize import standardize
+from mol_descriptors import compute_descriptors


 # TRAINING SECTION
@@ -32,15 +32,12 @@ if __name__ == "__main__":
     args = parser.parse_args()

     # This model doesn't get trained; it's just a feature creation 'model'
-
-    # SageMaker seems to get upset if we don't save a model, so we'll create a placeholder model
-    placeholder_model = {}
-    joblib.dump(placeholder_model, os.path.join(args.model_dir, "model.joblib"))
+    # So we don't need to do anything here


 # Model loading and prediction functions
 def model_fn(model_dir):
-    return joblib.load(os.path.join(model_dir, "model.joblib"))
+    return None


 def input_fn(input_data, content_type):
@@ -78,6 +75,7 @@ def output_fn(output_df, accept_type):
 # Prediction function
 def predict_fn(df, model):

-    # Compute the Molecular Descriptors
-    df = compute_molecular_descriptors(df)
+    # Standardize the molecule (extract salts) and then compute descriptors
+    df = standardize(df, extract_salts=True)
+    df = compute_descriptors(df)
     return df
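The net effect of these hunks: the descriptor "model" is now a pure transform. Nothing is trained or persisted (model_fn returns None), and inference chains standardization into descriptor computation. A rough local sketch of that path, assuming compute_descriptors (from the new mol_descriptors module) returns the input frame with descriptor columns appended:

import pandas as pd
from mol_standardize import standardize  # local imports, as in the script above
from mol_descriptors import compute_descriptors

def predict_fn(df: pd.DataFrame, model=None) -> pd.DataFrame:
    # Standardize first (salts extracted), then featurize the clean parent SMILES
    df = standardize(df, extract_salts=True)
    return compute_descriptors(df)

print(predict_fn(pd.DataFrame({"smiles": ["[Na+].CC(=O)[O-]", "c1ccccc1O"]})))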
@@ -15,7 +15,7 @@ import pandas as pd
 import json

 # Local imports
-from local_utils import compute_morgan_fingerprints
+from fingerprints import compute_morgan_fingerprints


 # TRAINING SECTION
@@ -0,0 +1,194 @@
+import pandas as pd
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+from sklearn.neighbors import NearestNeighbors
+from typing import List, Optional
+import logging
+
+# Workbench Imports
+from workbench.algorithms.dataframe.proximity import Proximity
+from workbench.algorithms.dataframe.projection_2d import Projection2D
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+class FeatureSpaceProximity(Proximity):
+    """Proximity computations for numeric feature spaces using Euclidean distance."""
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        id_column: str,
+        features: List[str],
+        target: Optional[str] = None,
+        include_all_columns: bool = False,
+    ):
+        """
+        Initialize the FeatureSpaceProximity class.
+
+        Args:
+            df: DataFrame containing data for neighbor computations.
+            id_column: Name of the column used as the identifier.
+            features: List of feature column names to be used for neighbor computations.
+            target: Name of the target column. Defaults to None.
+            include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
+        """
+        # Validate and filter features before calling the parent init
+        self._raw_features = features
+        super().__init__(
+            df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+        )
+
+    def _prepare_data(self) -> None:
+        """Filter out non-numeric features and drop NaN rows."""
+        # Validate features
+        self.features = self._validate_features(self.df, self._raw_features)
+
+        # Drop NaN rows for the features we're using
+        self.df = self.df.dropna(subset=self.features).copy()
+
+    def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+        """Remove non-numeric features and log warnings."""
+        non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+        if non_numeric:
+            log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+        return [f for f in features if f not in non_numeric]
+
+    def _build_model(self) -> None:
+        """Standardize features and fit the Nearest Neighbors model."""
+        self.scaler = StandardScaler()
+        X = self.scaler.fit_transform(self.df[self.features])
+        self.nn = NearestNeighbors().fit(X)
+
+    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+        """Transform features using the fitted scaler."""
+        return self.scaler.transform(df[self.features])
+
+    def _project_2d(self) -> None:
+        """Project the numeric features to 2D for visualization."""
+        if len(self.features) >= 2:
+            self.df = Projection2D().fit_transform(self.df, features=self.features)
+
+
+# Testing the FeatureSpaceProximity class
+if __name__ == "__main__":
+
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+
+    # Create a sample DataFrame
+    data = {
+        "ID": [1, 2, 3, 4, 5],
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
+    }
+    df = pd.DataFrame(data)
+
+    # Test the FeatureSpaceProximity class
+    features = ["Feature1", "Feature2", "Feature3"]
+    prox = FeatureSpaceProximity(df, id_column="ID", features=features)
+    print(prox.neighbors(1, n_neighbors=2))
+
+    # Test the neighbors method with a radius
+    print(prox.neighbors(1, radius=2.0))
+
+    # Test with a single-feature list
+    prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
+    print(prox.neighbors(1))
+
+    # Create a sample DataFrame
+    data = {
+        "id": ["a", "b", "c", "d", "e"],  # Testing string IDs
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
+    }
+    df = pd.DataFrame(data)
+
+    # Test with string IDs
+    prox = FeatureSpaceProximity(
+        df,
+        id_column="id",
+        features=["Feature1", "Feature2"],
+        target="target",
+        include_all_columns=True,
+    )
+    print(prox.neighbors(["a", "b"]))
+
+    # Test duplicate IDs
+    data = {
+        "id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
+    }
+    df = pd.DataFrame(data)
+    prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
+    print(df.equals(prox.df))
+
+    # Test on real data from Workbench
+    from workbench.api import FeatureSet, Model
+
+    fs = FeatureSet("aqsol_features")
+    model = Model("aqsol-regression")
+    features = model.features()
+    df = fs.pull_dataframe()
+    prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
+    print("\n" + "=" * 80)
+    print("Testing Neighbors...")
+    print("=" * 80)
+    test_id = df[fs.id_column].tolist()[0]
+    print(f"\nNeighbors for ID {test_id}:")
+    print(prox.neighbors(test_id))
+
+    print("\n" + "=" * 80)
+    print("Testing isolated_compounds...")
+    print("=" * 80)
+
+    # Test isolated data in the top 1%
+    isolated_1pct = prox.isolated(top_percent=1.0)
+    print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+    print(isolated_1pct)
+
+    # Test isolated data in the top 5%
+    isolated_5pct = prox.isolated(top_percent=5.0)
+    print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+    print(isolated_5pct)
+
+    print("\n" + "=" * 80)
+    print("Testing target_gradients...")
+    print("=" * 80)
+
+    # Test with different parameters
+    gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+    print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
+    print(gradients_1pct)
+
+    gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+    print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+    print(gradients_5pct)
+
+    # Test proximity_stats
+    print("\n" + "=" * 80)
+    print("Testing proximity_stats...")
+    print("=" * 80)
+    stats = prox.proximity_stats()
+    print(stats)
+
+    # Plot the distance distribution using pandas
+    print("\n" + "=" * 80)
+    print("Plotting distance distribution...")
+    print("=" * 80)
+    prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
+
+    # Visualize the 2D projection
+    print("\n" + "=" * 80)
+    print("Visualizing 2D Projection...")
+    print("=" * 80)
+    from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
+    from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
+
+    unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
+    unit_test.run()
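Condensing the self-test above, typical usage of the new class looks roughly like this sketch (column names here are illustrative, not from the package; the methods shown mirror the self-test):

import pandas as pd
from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity

df = pd.DataFrame(
    {
        "id": ["a", "b", "c", "d"],
        "f1": [0.1, 0.2, 0.3, 0.4],
        "f2": [1.0, 0.9, 0.8, 0.7],
        "target": [3.1, 2.9, 1.0, 0.5],
    }
)

prox = FeatureSpaceProximity(df, id_column="id", features=["f1", "f2"], target="target")
print(prox.neighbors("a", n_neighbors=2))  # nearest rows in scaled feature space
print(prox.isolated(top_percent=5.0))  # most isolated rows by neighbor distance
print(prox.target_gradients(top_percent=5.0, min_delta=1.0))  # close in X, far apart in target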