workbench 0.8.171__py3-none-any.whl → 0.8.173__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  2. workbench/api/compound.py +1 -1
  3. workbench/api/feature_set.py +4 -4
  4. workbench/api/monitor.py +1 -16
  5. workbench/core/artifacts/artifact.py +11 -3
  6. workbench/core/artifacts/data_capture_core.py +315 -0
  7. workbench/core/artifacts/endpoint_core.py +9 -3
  8. workbench/core/artifacts/model_core.py +37 -14
  9. workbench/core/artifacts/monitor_core.py +33 -249
  10. workbench/core/cloud_platform/aws/aws_account_clamp.py +4 -1
  11. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  12. workbench/core/transforms/features_to_model/features_to_model.py +4 -4
  13. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +471 -0
  14. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +428 -0
  15. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  16. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +19 -9
  17. workbench/model_scripts/custom_models/uq_models/mapie.template +502 -0
  18. workbench/model_scripts/custom_models/uq_models/meta_uq.template +8 -5
  19. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  20. workbench/model_scripts/script_generation.py +5 -0
  21. workbench/model_scripts/xgb_model/generated_model_script.py +5 -5
  22. workbench/repl/workbench_shell.py +3 -3
  23. workbench/utils/chem_utils/__init__.py +0 -0
  24. workbench/utils/chem_utils/fingerprints.py +134 -0
  25. workbench/utils/chem_utils/misc.py +194 -0
  26. workbench/utils/chem_utils/mol_descriptors.py +471 -0
  27. workbench/utils/chem_utils/mol_standardize.py +428 -0
  28. workbench/utils/chem_utils/mol_tagging.py +348 -0
  29. workbench/utils/chem_utils/projections.py +209 -0
  30. workbench/utils/chem_utils/salts.py +256 -0
  31. workbench/utils/chem_utils/sdf.py +292 -0
  32. workbench/utils/chem_utils/toxicity.py +250 -0
  33. workbench/utils/chem_utils/vis.py +253 -0
  34. workbench/utils/model_utils.py +1 -1
  35. workbench/utils/monitor_utils.py +49 -56
  36. workbench/utils/pandas_utils.py +3 -3
  37. workbench/utils/workbench_sqs.py +1 -1
  38. workbench/utils/xgboost_model_utils.py +1 -0
  39. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  40. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/METADATA +1 -1
  41. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/RECORD +45 -34
  42. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  43. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  44. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  45. workbench/utils/chem_utils.py +0 -1556
  46. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/WHEEL +0 -0
  47. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/entry_points.txt +0 -0
  48. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/licenses/LICENSE +0 -0
  49. {workbench-0.8.171.dist-info → workbench-0.8.173.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,428 @@
1
+ """
2
+ mol_standardize.py - Molecular Standardization for ADMET Preprocessing
3
+ Following ChEMBL structure standardization pipeline
4
+
5
+ Purpose:
6
+ Standardizes chemical structures to ensure consistent molecular representations
7
+ for ADMET modeling. Handles tautomers, salts, charges, and structural variations
8
+ that can cause the same compound to be represented differently.
9
+
10
+ Standardization Pipeline:
11
+ 1. Cleanup
12
+ - Removes explicit hydrogens
13
+ - Disconnects metal atoms from organic fragments
14
+ - Normalizes functional groups (e.g., nitro, sulfoxide representations)
15
+
16
+ 2. Fragment Parent Selection (optional, controlled by extract_salts parameter)
17
+ - Identifies and keeps the largest organic fragment
18
+ - Removes salts, solvents, and counterions
19
+ - Example: [Na+].CC(=O)[O-] → CC(=O)O (keeps acetate, removes sodium)
20
+
21
+ 3. Charge Neutralization (optional, controlled by extract_salts parameter)
22
+ - Neutralizes charges where possible
23
+ - Only applied when extract_salts=True (following ChEMBL pipeline)
24
+ - Skipped when extract_salts=False to preserve ionic character
25
+ - Example: CC(=O)[O-] → CC(=O)O
26
+
27
+ 4. Tautomer Canonicalization (optional, default=True)
28
+ - Generates canonical tautomer form for consistency
29
+ - Example: Oc1ccccn1 → O=c1cccc[nH]1 (2-hydroxypyridine → 2-pyridone)
30
+
31
+ Output DataFrame Columns:
32
+ - orig_smiles: Original input SMILES (preserved for traceability)
33
+ - smiles: Standardized molecule (with or without salts based on extract_salts)
34
+ - salt: Removed salt/counterion as SMILES (only populated if extract_salts=True)
35
+
36
+ Salt Handling:
37
+ Salt forms can dramatically affect properties like solubility.
38
+ This module offers two modes for handling salts:
39
+
40
+ When extract_salts=True (default, ChEMBL standard):
41
+ - Removes salts/counterions to get parent molecule
42
+ - Neutralizes charges on the parent
43
+ - Records removed salts in 'salt' column
44
+ Input: [Na+].CC(=O)[O-] → Parent: CC(=O)O, Salt: [Na+]
45
+
46
+ When extract_salts=False (preserve full salt form):
47
+ - Keeps all fragments including salts/counterions
48
+ - Preserves ionic charges (no neutralization)
49
+
50
+ Mixture Detection:
51
+ The module detects and logs potential mixtures (vs true salt forms):
52
+ - Multiple large neutral organic fragments indicate a mixture
53
+ - Mixtures are logged but NOT recorded in the salt column
54
+ - True salts (small/charged fragments) are properly extracted
55
+
56
+ Downstream modeling options:
57
+ 1. Use parent only (standard approach for most ADMET properties)
58
+ 2. Include salt as a categorical or computed feature
59
+ 3. Model parent + salt effects hierarchically
60
+ 4. Use full salt form for properties like solubility/formulation
61
+
62
+ References:
63
+ - "ChEMBL Structure Pipeline" (Bento et al., 2020)
64
+ https://doi.org/10.1186/s13321-020-00456-1
65
+ - "Standardization and Validation with the RDKit" (Greg Landrum, RSC Open Science 2021)
66
+ https://github.com/greglandrum/RSC_OpenScience_Standardization_202104/blob/main/Standardization%20and%20Validation%20with%20the%20RDKit.ipynb
67
+
68
+ Usage:
69
+ from mol_standardize import standardize
70
+
71
+ # Basic usage (removes salts by default, ChEMBL standard)
72
+ df_std = standardize(df)  # the 'smiles' column is auto-detected (any capitalization)
73
+
74
+ # Keep salts in the molecule (preserve ionic forms)
75
+ df_std = standardize(df, extract_salts=False)
76
+
77
+ # Without tautomer canonicalization (faster, less aggressive)
78
+ df_std = standardize(df, canonicalize_tautomer=False)
79
+ """
80
+
81
+ import logging
82
+ from typing import Optional, Tuple
83
+ import pandas as pd
84
+ from rdkit import Chem
85
+ from rdkit.Chem import Mol
86
+ from rdkit.Chem.MolStandardize import rdMolStandardize
87
+ from rdkit import RDLogger
88
+
89
+ log = logging.getLogger("workbench")
90
+ RDLogger.DisableLog("rdApp.warning")
91
+
92
+
93
+ class MolStandardizer:
94
+ """
95
+ Streamlined molecular standardizer for ADMET preprocessing
96
+ Uses ChEMBL standardization pipeline with RDKit
97
+ """
98
+
99
+ def __init__(self, canonicalize_tautomer: bool = True, remove_salts: bool = True):
100
+ """
101
+ Initialize standardizer with ChEMBL defaults
102
+
103
+ Args:
104
+ canonicalize_tautomer: Whether to canonicalize tautomers (default True)
105
+ remove_salts: Whether to remove salts/counterions (default True)
106
+ """
107
+ self.canonicalize_tautomer = canonicalize_tautomer
108
+ self.remove_salts = remove_salts
109
+ self.params = rdMolStandardize.CleanupParameters()
110
+ self.tautomer_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
111
+
112
+ def standardize(self, mol: Mol) -> Tuple[Optional[Mol], Optional[str]]:
113
+ """
114
+ Main standardization pipeline for ADMET
115
+
116
+ Pipeline:
117
+ 1. Cleanup (remove Hs, disconnect metals, normalize)
118
+ 2. Get largest fragment (optional - only if remove_salts=True)
119
+ 3. Neutralize charges (only if remove_salts=True)
120
+ 4. Canonicalize tautomer (optional)
121
+
122
+ Args:
123
+ mol: RDKit molecule object
124
+
125
+ Returns:
126
+ Tuple of (standardized molecule or None if failed, salt SMILES or None)
127
+ """
128
+ if mol is None:
129
+ return None, None
130
+
131
+ try:
132
+ # Step 1: Cleanup
133
+ mol = rdMolStandardize.Cleanup(mol, self.params)
134
+ if mol is None:
135
+ return None, None
136
+
137
+ salt_smiles = None
138
+
139
+ # Step 2: Fragment handling (conditional based on remove_salts)
140
+ if self.remove_salts:
141
+ # Get parent molecule and extract salt information
142
+ parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
143
+ if parent_mol:
144
+ salt_smiles = self._extract_salt(mol, parent_mol)
145
+ mol = parent_mol
146
+ else:
147
+ return None, None
148
+ # If not removing salts, keep the full molecule intact
149
+
150
+ # Step 3: Neutralize charges (skip if keeping salts to preserve ionic forms)
151
+ if self.remove_salts:
152
+ mol = rdMolStandardize.ChargeParent(mol, self.params, skipStandardize=True)
153
+ if mol is None:
154
+ return None, salt_smiles
155
+
156
+ # Step 4: Canonicalize tautomer
157
+ if self.canonicalize_tautomer:
158
+ mol = self.tautomer_enumerator.Canonicalize(mol)
159
+
160
+ return mol, salt_smiles
161
+
162
+ except Exception as e:
163
+ log.warning(f"Standardization failed: {e}")
164
+ return None, None
165
+
166
+ def _extract_salt(self, orig_mol: Mol, parent_mol: Mol) -> Optional[str]:
167
+ """
168
+ Extract salt/counterion by comparing original and parent molecules.
169
+
170
+ Detects and handles mixtures vs true salt forms:
171
+ - True salts: small neutral fragments (<= 6 heavy atoms) or small charged fragments (<= 10 heavy atoms)
172
+ - Mixtures: multiple large neutral organic fragments
173
+
174
+ Args:
175
+ orig_mol: Original molecule (before FragmentParent)
176
+ parent_mol: Parent molecule (after FragmentParent)
177
+
178
+ Returns:
179
+ SMILES string of salt components or None if no salts/mixture detected
180
+ """
181
+ try:
182
+ # Get all fragments from original molecule
183
+ orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
184
+
185
+ # If only one fragment, no salt
186
+ if len(orig_frags) <= 1:
187
+ return None
188
+
189
+ # Get canonical SMILES of parent for comparison
190
+ parent_smiles = Chem.MolToSmiles(parent_mol, canonical=True)
191
+
192
+ # Separate fragments into salts vs potential mixture components
193
+ salt_frags = []
194
+ mixture_frags = []
195
+
196
+ for frag in orig_frags:
197
+ frag_smiles = Chem.MolToSmiles(frag, canonical=True)
198
+
199
+ # Skip the parent fragment
200
+ if frag_smiles == parent_smiles:
201
+ continue
202
+
203
+ # Classify fragment as salt or mixture component
204
+ num_heavy = frag.GetNumHeavyAtoms()
205
+ has_charge = any(atom.GetFormalCharge() != 0 for atom in frag.GetAtoms())
206
+
207
+ # More nuanced classification
208
+ if has_charge and num_heavy <= 10: # Small charged fragment - likely a salt
209
+ salt_frags.append(frag_smiles)
210
+ elif not has_charge and num_heavy <= 6: # Small neutral - could be solvent/salt
211
+ salt_frags.append(frag_smiles)
212
+ else:
213
+ # Large fragment (neutral > 6 or charged > 10 heavy atoms) - likely part of a mixture
214
+ mixture_frags.append(frag_smiles)
215
+
216
+ # Check if this looks like a mixture
217
+ if mixture_frags:
218
+ # Log mixture detection
219
+ total_frags = len(orig_frags)
220
+ log.warning(
221
+ f"Mixture detected: {total_frags} total fragments, "
222
+ f"{len(mixture_frags)} large neutral organics. "
223
+ f"Removing: {'.'.join(mixture_frags + salt_frags)}"
224
+ )
225
+ # Return None for mixtures - don't pollute the salt column
226
+ return None
227
+
228
+ # Return actual salts only
229
+ return ".".join(salt_frags) if salt_frags else None
230
+
231
+ except Exception as e:
232
+ log.info(f"Salt extraction failed: {e}")
233
+ return None
234
+
235
+
236
+ def standardize(
237
+ df: pd.DataFrame,
238
+ canonicalize_tautomer: bool = True,
239
+ extract_salts: bool = True,
240
+ ) -> pd.DataFrame:
241
+ """
242
+ Standardize molecules in a DataFrame for ADMET modeling
243
+
244
+ Args:
245
+ df: Input DataFrame with SMILES column
246
+ canonicalize_tautomer: Whether to canonicalize tautomers (default: True)
247
+ extract_salts: Whether to remove and extract salts (default: True)
248
+ If False, keeps full molecule with salts/counterions intact,
249
+ skipping charge neutralization to preserve ionic character
250
+
251
+ Returns:
252
+ DataFrame with:
253
+ - orig_smiles: Original SMILES (preserved)
254
+ - smiles: Standardized SMILES (working column for downstream)
255
+ - salt: Removed salt/counterion SMILES (only if extract_salts=True)
256
+ None for mixtures or when no true salts present
257
+ """
258
+
259
+ # Check for the smiles column (any capitalization)
260
+ smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
261
+ if smiles_column is None:
262
+ raise ValueError("Input DataFrame must have a 'smiles' column")
263
+
264
+ # Copy input DataFrame to avoid modifying original
265
+ result = df.copy()
266
+
267
+ # Preserve original SMILES if not already saved
268
+ if "orig_smiles" not in result.columns:
269
+ result["orig_smiles"] = result[smiles_column]
270
+
271
+ # Initialize standardizer with salt removal control
272
+ standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
273
+
274
+ def process_smiles(smiles: str) -> pd.Series:
275
+ """
276
+ Process a single SMILES string through standardization pipeline
277
+
278
+ Args:
279
+ smiles: Input SMILES string
280
+
281
+ Returns:
282
+ Series with standardized SMILES and extracted salt (if applicable)
283
+ """
284
+ # Handle missing values
285
+ if pd.isna(smiles) or smiles == "":
286
+ log.error("Encountered missing or empty SMILES string")
287
+ return pd.Series({"smiles": None, "salt": None})
288
+
289
+ # Parse molecule
290
+ mol = Chem.MolFromSmiles(smiles)
291
+ if mol is None:
292
+ log.error(f"Invalid SMILES: {smiles}")
293
+ return pd.Series({"smiles": None, "salt": None})
294
+
295
+ # Full standardization with optional salt removal
296
+ std_mol, salt_smiles = standardizer.standardize(mol)
297
+
298
+ # After standardization, validate the result
299
+ if std_mol is not None:
300
+ # Check if molecule is reasonable
301
+ if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200: # Arbitrary limits
302
+ log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
303
+
304
+ if std_mol is None:
305
+ return pd.Series(
306
+ {
307
+ "smiles": None,
308
+ "salt": salt_smiles, # May have extracted salt even if full standardization failed
309
+ }
310
+ )
311
+
312
+ # Convert back to SMILES
313
+ return pd.Series(
314
+ {"smiles": Chem.MolToSmiles(std_mol, canonical=True), "salt": salt_smiles if extract_salts else None}
315
+ )
316
+
317
+ # Process molecules
318
+ processed = result[smiles_column].apply(process_smiles)
319
+
320
+ # Update the dataframe with processed results
321
+ for col in ["smiles", "salt"]:
322
+ result[col] = processed[col]
323
+
324
+ return result
325
+
326
+
327
+ if __name__ == "__main__":
328
+ import time
329
+ from workbench.api import DataSource
330
+
331
+ # Test with DataFrame including various salt forms
332
+ test_data = pd.DataFrame(
333
+ {
334
+ "smiles": [
335
+ # Organic salts
336
+ "[Na+].CC(=O)[O-]", # Sodium acetate
337
+ "CC(=O)O.CCN", # Acetic acid + ethylamine (acid-base pair)
338
+ # Tautomers
339
+ "CC(=O)CC(C)=O", # Acetylacetone - tautomer
340
+ "c1ccc(O)nc1", # 2-hydroxypyridine/2-pyridone - tautomer
341
+ # Multi-fragment
342
+ "CCO.CC", # Ethanol + ethane mixture
343
+ # Simple organics
344
+ "CC(C)(C)c1ccccc1", # tert-butylbenzene
345
+ # Carbonate salts
346
+ "[Na+].[Na+].[O-]C([O-])=O", # Sodium carbonate
347
+ "[Li+].[Li+].[O-]C([O-])=O", # Lithium carbonate
348
+ "[K+].[K+].[O-]C([O-])=O", # Potassium carbonate
349
+ "[Mg++].[O-]C([O-])=O", # Magnesium carbonate
350
+ "[Ca++].[O-]C([O-])=O", # Calcium carbonate
351
+ # Drug salts
352
+ "CC(C)NCC(O)c1ccc(O)c(O)c1.Cl", # Isoproterenol HCl
353
+ "CN1CCC[C@H]1c2cccnc2.[Cl-]", # Nicotine HCl
354
+ # Tautomer with salt
355
+ "c1ccc(O)nc1.Cl", # 2-hydroxypyridine with HCl
356
+ # Edge cases
357
+ None, # Missing value
358
+ "INVALID", # Invalid SMILES
359
+ ],
360
+ "compound_id": [f"C{i:03d}" for i in range(1, 17)],
361
+ }
362
+ )
363
+
364
+ # General test
365
+ standardize(test_data)
366
+
367
+ # Remove the last two rows to avoid errors with None and INVALID
368
+ test_data = test_data.iloc[:-2].reset_index(drop=True)
369
+
370
+ # Test WITHOUT salt removal (keeps full molecule)
371
+ print("\nStandardization KEEPING salts (extract_salts=False):")
372
+ print("This preserves the full molecule including counterions")
373
+ result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
374
+ display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
375
+ print(result_keep[display_cols].to_string())
376
+
377
+ # Test WITH salt removal
378
+ print("\n" + "=" * 70)
379
+ print("Standardization REMOVING salts (extract_salts=True):")
380
+ print("This extracts parent molecule and records salt information")
381
+ result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
382
+ print(result_remove[display_cols].to_string())
383
+
384
+ # Test WITHOUT tautomerization (keeping salts)
385
+ print("\n" + "=" * 70)
386
+ print("Standardization KEEPING salts, NO tautomerization:")
387
+ result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
388
+ print(result_no_taut[display_cols].to_string())
389
+
390
+ # Show the difference for salt-containing molecules
391
+ print("\n" + "=" * 70)
392
+ print("Comparison showing differences:")
393
+ for idx, row in result_keep.iterrows():
394
+ keep_smiles = row["smiles"]
395
+ remove_smiles = result_remove.loc[idx, "smiles"]
396
+ no_taut_smiles = result_no_taut.loc[idx, "smiles"]
397
+ salt = result_remove.loc[idx, "salt"]
398
+
399
+ # Show differences when they exist
400
+ if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
401
+ print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
402
+ if keep_smiles != no_taut_smiles:
403
+ print(f" With salt + taut: {keep_smiles}")
404
+ print(f" With salt, no taut: {no_taut_smiles}")
405
+ if keep_smiles != remove_smiles:
406
+ print(f" Parent only + taut: {remove_smiles}")
407
+ if salt:
408
+ print(f" Extracted salt: {salt}")
409
+
410
+ # Summary statistics
411
+ print("\n" + "=" * 70)
412
+ print("Summary:")
413
+ print(f"Total molecules: {len(result_remove)}")
414
+ print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
415
+ unique_salts = result_remove["salt"].dropna().unique()
416
+ print(f"Unique salts found: {unique_salts[:5].tolist()}")
417
+
418
+ # Get a real dataset from Workbench and time the standardization
419
+ ds = DataSource("aqsol_data")
420
+ df = ds.pull_dataframe()[["id", "smiles"]]
421
+ start_time = time.time()
422
+ std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
423
+ end_time = time.time()
424
+ print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
425
+ print(std_df.head())
426
+ print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
427
+ unique_salts = std_df["salt"].dropna().unique()
428
+ print(f"Unique salts found: {unique_salts[:5].tolist()}")
@@ -7,13 +7,13 @@
7
7
  #
8
8
  import argparse
9
9
  import os
10
- import joblib
11
10
  from io import StringIO
12
11
  import pandas as pd
13
12
  import json
14
13
 
15
14
  # Local imports
16
- from local_utils import compute_molecular_descriptors
15
+ from mol_standardize import standardize
16
+ from mol_descriptors import compute_descriptors
17
17
 
18
18
 
19
19
  # TRAINING SECTION
@@ -32,15 +32,12 @@ if __name__ == "__main__":
32
32
  args = parser.parse_args()
33
33
 
34
34
  # This model doesn't get trained, it's just a feature creation 'model'
35
-
36
- # Sagemaker seems to get upset if we don't save a model, so we'll create a placeholder model
37
- placeholder_model = {}
38
- joblib.dump(placeholder_model, os.path.join(args.model_dir, "model.joblib"))
35
+ # So we don't need to do anything here
39
36
 
40
37
 
41
38
  # Model loading and prediction functions
42
39
  def model_fn(model_dir):
43
- return joblib.load(os.path.join(model_dir, "model.joblib"))
40
+ return None
44
41
 
45
42
 
46
43
  def input_fn(input_data, content_type):
@@ -78,6 +75,7 @@ def output_fn(output_df, accept_type):
78
75
  # Prediction function
79
76
  def predict_fn(df, model):
80
77
 
81
- # Compute the Molecular Descriptors
82
- df = compute_molecular_descriptors(df)
78
+ # Standardize the molecule (extract salts) and then compute descriptors
79
+ df = standardize(df, extract_salts=True)
80
+ df = compute_descriptors(df)
83
81
  return df
@@ -1,6 +1,7 @@
1
1
  # Model: NGBoost Regressor with Distribution output
2
2
  from ngboost import NGBRegressor
3
- from xgboost import XGBRegressor # Base Estimator
3
+ from ngboost.distns import Cauchy, T
4
+ from xgboost import XGBRegressor # Point Estimator
4
5
  from sklearn.model_selection import train_test_split
5
6
 
6
7
  # Model Performance Scores
@@ -26,12 +27,12 @@ from proximity import Proximity
26
27
 
27
28
  # Template Placeholders
28
29
  TEMPLATE_PARAMS = {
29
- "id_column": "udm_mol_bat_id",
30
- "target": "udm_asy_res_intrinsic_clearance_ul_per_min_per_mg_protein",
31
- "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 
'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v'],
30
+ "id_column": "udm_mol_id",
31
+ "target": "udm_asy_res_value",
32
+ "features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 
'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v', 'chiral_centers', 'r_cnt', 's_cnt', 'db_stereo', 'e_cnt', 'z_cnt', 'chiral_fp', 'db_fp'],
32
33
  "compressed_features": [],
33
34
  "train_all_data": False,
34
- "track_columns": ['udm_asy_res_intrinsic_clearance_ul_per_min_per_mg_protein']
35
+ "track_columns": "udm_asy_res_value"
35
36
  }
36
37
 
37
38
 
@@ -106,8 +107,10 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping
106
107
  return df, category_mappings
107
108
 
108
109
 
109
- def decompress_features(df: pd.DataFrame, features: List[str], compressed_features: List[str]) -> Tuple[pd.DataFrame, List[str]]:
110
- """Prepare features for the XGBoost model
110
+ def decompress_features(
111
+ df: pd.DataFrame, features: List[str], compressed_features: List[str]
112
+ ) -> Tuple[pd.DataFrame, List[str]]:
113
+ """Prepare features for the model by decompressing bitstring features
111
114
 
112
115
  Args:
113
116
  df (pd.DataFrame): The features DataFrame
@@ -132,7 +135,7 @@ def decompress_features(df: pd.DataFrame, features: List[str], compressed_featur
132
135
  )
133
136
 
134
137
  # Decompress the specified compressed features
135
- decompressed_features = features
138
+ decompressed_features = features.copy()
136
139
  for feature in compressed_features:
137
140
  if (feature not in df.columns) or (feature not in features):
138
141
  print(f"Feature '{feature}' not in the features list, skipping decompression.")
@@ -227,7 +230,14 @@ if __name__ == "__main__":
227
230
 
228
231
  # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
229
232
  xgb_model = XGBRegressor()
230
- ngb_model = NGBRegressor()
233
+ ngb_model = NGBRegressor() # Dist=Cauchy) Seems to give HUGE prediction intervals
234
+ ngb_model = NGBRegressor(
235
+ Dist=T,
236
+ learning_rate=0.005,
237
+ minibatch_frac=0.1, # Very small batches
238
+ col_sample=0.8 # This parameter DOES exist
239
+ ) # Testing this out
240
+ print("NGBoost using T distribution for uncertainty quantification")
231
241
 
232
242
  # Prepare features and targets for training
233
243
  X_train = df_train[features]