workbench 0.8.176-py3-none-any.whl → 0.8.178-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -22,7 +22,7 @@ from typing import List, Tuple
 
 # Template Placeholders
 TEMPLATE_PARAMS = {
-    "target": "udm_asy_res_value",
+    "target": "logs",
  "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
     "train_all_data": True
@@ -242,7 +242,7 @@ if __name__ == "__main__":
     print(f"R2: {xgb_r2:.3f}")
 
     # Define confidence levels we want to model
-    confidence_levels = [0.50, 0.80, 0.90, 0.95]  # 50%, 80%, 90%, 95% confidence intervals
+    confidence_levels = [0.50, 0.68, 0.80, 0.90, 0.95]  # 50%, 68%, 80%, 90%, 95% confidence intervals
 
     # Store MAPIE models for each confidence level
     mapie_models = {}
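For context on how these levels are consumed: MAPIE expresses a confidence level as `alpha = 1 - confidence` at prediction time and returns paired lower/upper bounds. A minimal sketch under that assumption — `estimator`, `X_train`, `y_train`, and `X_test` are hypothetical names, not taken from this diff:

```python
from mapie.regression import MapieRegressor

# Sketch only: fit one MAPIE wrapper per confidence level.
mapie_models = {}
for conf_level in confidence_levels:
    mapie = MapieRegressor(estimator)  # estimator: the XGBoost regressor (assumed)
    mapie.fit(X_train, y_train)
    mapie_models[conf_level] = mapie

# predict() returns (point predictions, intervals); y_pis has shape
# (n_samples, 2, n_alpha), hence the y_pis[:, 0, 0] / y_pis[:, 1, 0]
# indexing seen in the predict_fn hunk below.
y_pred, y_pis = mapie_models[0.68].predict(X_test, alpha=1 - 0.68)
```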
@@ -459,6 +459,9 @@ def predict_fn(df, models) -> pd.DataFrame:
         if conf_level == 0.50:  # 50% CI
             df["q_25"] = y_pis[:, 0, 0]
             df["q_75"] = y_pis[:, 1, 0]
+        elif conf_level == 0.68:  # 68% CI
+            df["q_16"] = y_pis[:, 0, 0]
+            df["q_84"] = y_pis[:, 1, 0]
         elif conf_level == 0.80:  # 80% CI
             df["q_10"] = y_pis[:, 0, 0]
             df["q_90"] = y_pis[:, 1, 0]
@@ -472,23 +475,16 @@ def predict_fn(df, models) -> pd.DataFrame:
     # Add median (q_50) from XGBoost prediction
     df["q_50"] = df["prediction"]
 
-    # Calculate uncertainty metrics based on 95% interval
-    interval_width = df["q_975"] - df["q_025"]
-    df["prediction_std"] = interval_width / 3.92
+    # Calculate a pseudo-standard deviation from the 68% interval width
+    df["prediction_std"] = (df["q_84"] - df["q_16"]) / 2.0
 
     # Reorder the quantile columns for easier reading
-    quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
+    quantile_cols = ["q_025", "q_05", "q_10", "q_16", "q_25", "q_75", "q_84", "q_90", "q_95", "q_975"]
     other_cols = [col for col in df.columns if col not in quantile_cols]
     df = df[other_cols + quantile_cols]
 
-    # Uncertainty score
-    df["uncertainty_score"] = interval_width / (np.abs(df["prediction"]) + 1e-6)
-
-    # Confidence bands
-    df["confidence_band"] = pd.cut(
-        df["uncertainty_score"],
-        bins=[0, 0.5, 1.0, 2.0, np.inf],
-        labels=["high", "medium", "low", "very_low"]
-    )
+    # Adjust the outer quantiles to ensure they encompass the prediction
+    df["q_025"] = np.minimum(df["q_025"], df["prediction"])
+    df["q_975"] = np.maximum(df["q_975"], df["prediction"])
 
     return df
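Both divisors come straight from normal quantiles: for a Gaussian, the central 95% interval spans 2 × 1.96σ = 3.92σ, while the central 68% interval (q_16 to q_84) spans roughly 2σ. A quick sanity check (assumes scipy is available):

```python
from scipy.stats import norm

# Central interval widths for a standard normal (sigma = 1):
print(norm.ppf(0.975) - norm.ppf(0.025))  # ~3.92, the old divisor
print(norm.ppf(0.84) - norm.ppf(0.16))    # ~1.99, the new divisor of 2.0
```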
@@ -28,11 +28,11 @@ from typing import List, Tuple
 
 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "classifier",
-    "target": "class",
+    "model_type": "regressor",
+    "target": "udm_asy_res_value",
  "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/sol-class-f1-100/training",
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/pka-a1-reg-0-nightly-100-test/training",
     "train_all_data": True
 }
 
@@ -13,12 +13,13 @@ cm = ConfigManager()
 workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
 
 
-def submit_to_sqs(script_path: str, size: str = "small") -> None:
+def submit_to_sqs(script_path: str, size: str = "small", realtime: bool = False) -> None:
     """
     Upload script to S3 and submit message to SQS queue for processing.
     Args:
         script_path: Local path to the ML pipeline script
         size: Job size tier - "small" (default), "medium", or "large"
+        realtime: If True, sets serverless=False for real-time processing (default: False, meaning serverless=True)
     """
     print(f"\n{'=' * 60}")
     print("🚀 SUBMITTING ML PIPELINE JOB")
@@ -33,6 +34,7 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
 
     print(f"📄 Script: {script_file.name}")
     print(f"📏 Size tier: {size}")
+    print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
     print(f"🪣 Bucket: {workbench_bucket}")
     sqs = AWSAccountClamp().boto3_session.client("sqs")
     script_name = script_file.name
@@ -88,6 +90,10 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
 
     # Prepare message
     message = {"script_path": s3_path, "size": size}
+
+    # Set serverless environment variable (defaults to True, False if --realtime)
+    message["environment"] = {"SERVERLESS": "False" if realtime else "True"}
+
     print("\n📨 Sending message to SQS...")
 
     # Send the message to SQS
@@ -110,6 +116,7 @@ def submit_to_sqs(script_path: str, size: str = "small") -> None:
     print(f"{'=' * 60}")
     print(f"📄 Script: {script_name}")
     print(f"📏 Size: {size}")
+    print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
     print(f"🆔 Message ID: {message_id}")
     print("\n🔍 MONITORING LOCATIONS:")
     print(f"  • SQS Queue: AWS Console → SQS → {queue_name}")
@@ -126,9 +133,14 @@ def main():
     parser.add_argument(
         "--size", default="small", choices=["small", "medium", "large"], help="Job size tier (default: small)"
     )
+    parser.add_argument(
+        "--realtime",
+        action="store_true",
+        help="Run in real-time mode (sets serverless=False). Default is serverless mode (serverless=True)",
+    )
     args = parser.parse_args()
     try:
-        submit_to_sqs(args.script_file, args.size)
+        submit_to_sqs(args.script_file, args.size, realtime=args.realtime)
     except Exception as e:
         print(f"\n❌ ERROR: {e}")
         log.error(f"Error: {e}")
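As a hedged usage sketch (the submitter script's filename isn't shown in this diff), the new flag flows from the CLI through to the SQS message:

```python
# CLI (hypothetical script name):
#   python submit_to_sqs.py my_pipeline.py --size medium --realtime
# Equivalent direct call; per the code above, this sends
# {"script_path": <s3 path>, "size": "medium",
#  "environment": {"SERVERLESS": "False"}} to the queue.
submit_to_sqs("my_pipeline.py", size="medium", realtime=True)
```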
@@ -91,16 +91,27 @@ import logging
 import pandas as pd
 import numpy as np
 import re
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Descriptors, rdCIPLabeler
 from rdkit.ML.Descriptors import MoleculeDescriptors
 from mordred import Calculator as MordredCalculator
 from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes
 
+
 logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 def compute_stereochemistry_features(mol):
     """
     Compute stereochemistry descriptors using modern RDKit methods.
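A usage note on the `timer` helper (illustrative call, not part of the diff): it wraps any block and prints elapsed wall-clock time.

```python
# Hypothetical timing of the descriptor computation defined in this script:
with timer("compute_descriptors"):
    features_df = compute_descriptors(df, include_mordred=True)
```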
@@ -280,9 +291,11 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             descriptor_values.append([np.nan] * len(all_descriptors))
 
     # Create RDKit features DataFrame
-    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames(), index=result.index)
+    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames())
 
     # Add RDKit features to result
+    # Remove any columns from result that exist in rdkit_features_df
+    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
     result = pd.concat([result, rdkit_features_df], axis=1)
 
     # Compute Mordred descriptors
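The drop-then-concat pattern introduced here (and in the Mordred and stereochemistry hunks below) prevents pandas from producing duplicate column labels when a feature is recomputed. A self-contained illustration with made-up values:

```python
import pandas as pd

result = pd.DataFrame({"smiles": ["CCO"], "molwt": [0.0]})    # stale molwt
features = pd.DataFrame({"molwt": [46.07], "tpsa": [20.23]})  # fresh values

# Without the drop, concat would keep both copies of "molwt"
result = result.drop(columns=result.columns.intersection(features.columns))
result = pd.concat([result, features], axis=1)
print(list(result.columns))  # ['smiles', 'molwt', 'tpsa']
```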
@@ -299,7 +312,7 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
     # Compute Mordred descriptors
     valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
-    mordred_df = calc.pandas(valid_mols, nproc=1)  # For serverless, use nproc=1
+    mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1
 
     # Replace values for invalid molecules with NaN
     for i, mol in enumerate(molecules):
@@ -310,10 +323,9 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
     for col in mordred_df.columns:
         mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")
 
-    # Set index to match result DataFrame
-    mordred_df.index = result.index
-
     # Add Mordred features to result
+    # Remove any columns from result that exist in mordred
+    result = result.drop(columns=result.columns.intersection(mordred_df.columns))
     result = pd.concat([result, mordred_df], axis=1)
 
     # Compute stereochemistry features if requested
@@ -326,9 +338,10 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
             stereo_features.append(stereo_dict)
 
     # Create stereochemistry DataFrame
-    stereo_df = pd.DataFrame(stereo_features, index=result.index)
+    stereo_df = pd.DataFrame(stereo_features)
 
     # Add stereochemistry features to result
+    result = result.drop(columns=result.columns.intersection(stereo_df.columns))
     result = pd.concat([result, stereo_df], axis=1)
 
     logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")
@@ -357,7 +370,6 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
 
 
 if __name__ == "__main__":
-    import time
     from mol_standardize import standardize
     from workbench.api import DataSource
 
@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
         Pipeline:
         1. Cleanup (remove Hs, disconnect metals, normalize)
         2. Get largest fragment (optional - only if remove_salts=True)
+        2a. Extract salt information BEFORE further modifications
         3. Neutralize charges
         4. Canonicalize tautomer (optional)
 
@@ -130,18 +141,24 @@ class MolStandardizer:
 
         try:
             # Step 1: Cleanup
-            mol = rdMolStandardize.Cleanup(mol, self.params)
-            if mol is None:
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None
 
+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol
 
             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule and extract salt information
-                parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-                    salt_smiles = self._extract_salt(mol, parent_mol)
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None
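For readers unfamiliar with the RDKit calls involved, a minimal, illustrative sketch of the Cleanup → FragmentParent sequence this pipeline builds on (the example molecule is ours, not from the package):

```python
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize

# Sodium acetate: FragmentParent keeps the largest organic fragment,
# so the [Na+] counterion is what _extract_salt would report.
mol = Chem.MolFromSmiles("CC(=O)[O-].[Na+]")
cleaned = rdMolStandardize.Cleanup(mol)
parent = rdMolStandardize.FragmentParent(cleaned)
print(Chem.MolToSmiles(parent))  # the acetate fragment, counterion removed
```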
@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles
 
-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)
 
@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments
 
         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)
 
         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
 
@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]
 
-    # Initialize standardizer with salt removal control
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
 
     def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})
 
+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:
@@ -299,7 +330,9 @@ def standardize(
         if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})
 
         if std_mol is None:
             return pd.Series(
@@ -325,8 +358,11 @@ def standardize(
 
 
 if __name__ == "__main__":
-    import time
-    from workbench.api import DataSource
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)
 
     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )
 
     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)
 
     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)
 
     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-    display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
-    print(result_keep[display_cols].to_string())
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])
 
     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[display_cols].to_string())
+    print(result_remove[display_order])
 
-    # Test WITHOUT tautomerization (keeping salts)
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("Standardization KEEPING salts, NO tautomerization:")
-    result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
-    print(result_no_taut[display_cols].to_string())
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource
 
-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f"  With salt + taut: {keep_smiles}")
-                print(f"  With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f"  Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f"  Extracted salt: {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
 
-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-    start_time = time.time()
-    std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
-    end_time = time.time()
-    print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
-    print(std_df.head())
-    print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
-    unique_salts = std_df["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
@@ -222,32 +222,40 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
         lower_95, upper_95 = df["q_025"], df["q_975"]
         lower_90, upper_90 = df["q_05"], df["q_95"]
         lower_80, upper_80 = df["q_10"], df["q_90"]
+        lower_68, upper_68 = df["q_16"], df["q_84"]
         lower_50, upper_50 = df["q_25"], df["q_75"]
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+        lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+        upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    avg_std = df["prediction_std"].mean()
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
     coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
     coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+    coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
     coverage_50 = np.mean((df[target_col] >= lower_50) & (df[target_col] <= upper_50))
     avg_width_95 = np.mean(upper_95 - lower_95)
     avg_width_90 = np.mean(upper_90 - lower_90)
     avg_width_80 = np.mean(upper_80 - lower_80)
     avg_width_50 = np.mean(upper_50 - lower_50)
+    avg_width_68 = np.mean(upper_68 - lower_68)
 
     # --- CRPS (measures calibration + sharpness) ---
-    if "prediction_std" in df.columns:
-        z = (df[target_col] - df["prediction"]) / df["prediction_std"]
-        crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)
 
     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
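For reference, the CRPS line above is the standard closed form for a Gaussian predictive distribution (Gneiting & Raftery), with $z = (y - \mu)/\sigma$:

$$\mathrm{CRPS}\bigl(\mathcal{N}(\mu,\sigma^2),\,y\bigr) = \sigma\left[z\bigl(2\Phi(z)-1\bigr) + 2\varphi(z) - \frac{1}{\sqrt{\pi}}\right]$$

where $\Phi$ and $\varphi$ are the standard normal CDF and PDF; the code's `norm.cdf` and `norm.pdf` terms map one-to-one. Note that the refactor now computes CRPS unconditionally, which assumes the incoming dataframe always carries `prediction_std` — true once the updated `predict_fn` above writes it from the 68% interval.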
@@ -265,27 +273,37 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
 
     # Collect results
     results = {
-        "coverage_95": coverage_95,
-        "coverage_90": coverage_90,
-        "coverage_80": coverage_80,
         "coverage_50": coverage_50,
-        "avg_width_95": avg_width_95,
+        "coverage_68": coverage_68,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
+        "coverage_95": coverage_95,
+        "median_std": median_std,
+        "avg_std": avg_std,
         "avg_width_50": avg_width_50,
-        "crps": mean_crps,
-        "interval_score_95": mean_is_95,
-        "adaptive_calibration": adaptive_calibration,
+        "avg_width_68": avg_width_68,
+        "avg_width_80": avg_width_80,
+        "avg_width_90": avg_width_90,
+        "avg_width_95": avg_width_95,
+        # "crps": mean_crps,
+        # "interval_score_95": mean_is_95,
+        # "adaptive_calibration": adaptive_calibration,
         "n_samples": len(df),
     }
 
     print("\n=== UQ Metrics ===")
-    print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
-    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
     print(f"Coverage @ 50%: {coverage_50:.3f} (target: 0.50)")
-    print(f"Average 95% Width: {avg_width_95:.3f}")
-    print(f"Average 90% Width: {avg_width_90:.3f}")
-    print(f"Average 80% Width: {avg_width_80:.3f}")
+    print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
+    print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
+    print(f"Median Prediction StdDev: {median_std:.3f}")
+    print(f"Avg Prediction StdDev: {avg_std:.3f}")
     print(f"Average 50% Width: {avg_width_50:.3f}")
+    print(f"Average 68% Width: {avg_width_68:.3f}")
+    print(f"Average 80% Width: {avg_width_80:.3f}")
+    print(f"Average 90% Width: {avg_width_90:.3f}")
+    print(f"Average 95% Width: {avg_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
     print(f"Adaptive Calibration: {adaptive_calibration:.3f} (higher is better, target: >0.5)")
@@ -325,9 +343,3 @@ if __name__ == "__main__":
     df = end.auto_inference(capture=True)
     results = uq_metrics(df, target_col="solubility")
     print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
@@ -259,7 +259,7 @@ def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Dict[str, Any]:
     xgb_model._Booster = loaded_booster
     # Prepare data
     fs = FeatureSet(workbench_model.get_input())
-    df = fs.pull_dataframe()
+    df = fs.view("training").pull_dataframe()
     feature_cols = workbench_model.features()
     # Convert string features to categorical
     for col in feature_cols: