workbench 0.8.172__py3-none-any.whl → 0.8.174__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/data_capture_core.py +348 -0
- workbench/core/artifacts/endpoint_core.py +9 -3
- workbench/core/artifacts/monitor_core.py +33 -249
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +471 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +428 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +95 -204
- workbench/model_scripts/xgb_model/generated_model_script.py +5 -5
- workbench/repl/workbench_shell.py +3 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +471 -0
- workbench/utils/chem_utils/mol_standardize.py +428 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- {workbench-0.8.172.dist-info → workbench-0.8.174.dist-info}/METADATA +1 -1
- {workbench-0.8.172.dist-info → workbench-0.8.174.dist-info}/RECORD +33 -22
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/utils/chem_utils.py +0 -1556
- {workbench-0.8.172.dist-info → workbench-0.8.174.dist-info}/WHEEL +0 -0
- {workbench-0.8.172.dist-info → workbench-0.8.174.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.172.dist-info → workbench-0.8.174.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.172.dist-info → workbench-0.8.174.dist-info}/top_level.txt +0 -0
workbench/model_scripts/custom_models/uq_models/generated_model_script.py
@@ -1,7 +1,7 @@
-# Model:
-from
-from
-from xgboost import XGBRegressor
+# Model: NGBoost Regressor with Distribution output
+from ngboost import NGBRegressor
+from ngboost.distns import Cauchy, T
+from xgboost import XGBRegressor  # Point Estimator
 from sklearn.model_selection import train_test_split

 # Model Performance Scores
@@ -20,12 +20,19 @@ import numpy as np
 import pandas as pd
 from typing import List, Tuple

+# Local Imports
+from proximity import Proximity
+
+
+
 # Template Placeholders
 TEMPLATE_PARAMS = {
+    "id_column": "udm_mol_id",
     "target": "udm_asy_res_value",
"features": ['bcut2d_logplow', 'numradicalelectrons', 'smr_vsa5', 'fr_lactam', 'fr_morpholine', 'fr_aldehyde', 'slogp_vsa1', 'fr_amidine', 'bpol', 'fr_ester', 'fr_azo', 'kappa3', 'peoe_vsa5', 'fr_ketone_topliss', 'vsa_estate9', 'estate_vsa9', 'bcut2d_mrhi', 'fr_ndealkylation1', 'numrotatablebonds', 'minestateindex', 'fr_quatn', 'peoe_vsa3', 'fr_epoxide', 'fr_aniline', 'minpartialcharge', 'fr_nitroso', 'fpdensitymorgan2', 'fr_oxime', 'fr_sulfone', 'smr_vsa1', 'kappa1', 'fr_pyridine', 'numaromaticrings', 'vsa_estate6', 'molmr', 'estate_vsa1', 'fr_dihydropyridine', 'vsa_estate10', 'fr_alkyl_halide', 'chi2n', 'fr_thiocyan', 'fpdensitymorgan1', 'fr_unbrch_alkane', 'slogp_vsa9', 'chi4n', 'fr_nitro_arom', 'fr_al_oh', 'fr_furan', 'fr_c_s', 'peoe_vsa8', 'peoe_vsa14', 'numheteroatoms', 'fr_ndealkylation2', 'maxabspartialcharge', 'vsa_estate2', 'peoe_vsa7', 'apol', 'numhacceptors', 'fr_tetrazole', 'vsa_estate1', 'peoe_vsa9', 'naromatom', 'bcut2d_chghi', 'fr_sh', 'fr_halogen', 'slogp_vsa4', 'fr_benzodiazepine', 'molwt', 'fr_isocyan', 'fr_prisulfonamd', 'maxabsestateindex', 'minabsestateindex', 'peoe_vsa11', 'slogp_vsa12', 'estate_vsa5', 'numaliphaticcarbocycles', 'bcut2d_mwlow', 'slogp_vsa7', 'fr_allylic_oxid', 'fr_methoxy', 'fr_nh0', 'fr_coo2', 'fr_phenol', 'nacid', 'nbase', 'chi3v', 'fr_ar_nh', 'fr_nitrile', 'fr_imidazole', 'fr_urea', 'bcut2d_mrlow', 'chi1', 'smr_vsa6', 'fr_aryl_methyl', 'narombond', 'fr_alkyl_carbamate', 'fr_piperzine', 'exactmolwt', 'qed', 'chi0n', 'fr_sulfonamd', 'fr_thiazole', 'numvalenceelectrons', 'fr_phos_acid', 'peoe_vsa12', 'fr_nh1', 'fr_hdrzine', 'fr_c_o_nocoo', 'fr_lactone', 'estate_vsa6', 'bcut2d_logphi', 'vsa_estate7', 'peoe_vsa13', 'numsaturatedcarbocycles', 'fr_nitro', 'fr_phenol_noorthohbond', 'rotratio', 'fr_barbitur', 'fr_isothiocyan', 'balabanj', 'fr_arn', 'fr_imine', 'maxpartialcharge', 'fr_sulfide', 'slogp_vsa11', 'fr_hoccn', 'fr_n_o', 'peoe_vsa1', 'slogp_vsa6', 'heavyatommolwt', 'fractioncsp3', 'estate_vsa8', 'peoe_vsa10', 'numaliphaticrings', 'fr_thiophene', 'maxestateindex', 'smr_vsa10', 'labuteasa', 'smr_vsa2', 'fpdensitymorgan3', 'smr_vsa9', 'slogp_vsa10', 'numaromaticheterocycles', 'fr_nh2', 'fr_diazo', 'chi3n', 'fr_ar_coo', 'slogp_vsa5', 'fr_bicyclic', 'fr_amide', 'estate_vsa10', 'fr_guanido', 'chi1n', 'numsaturatedrings', 'fr_piperdine', 'fr_term_acetylene', 'estate_vsa4', 'slogp_vsa3', 'fr_coo', 'fr_ether', 'estate_vsa7', 'bcut2d_chglo', 'fr_oxazole', 'peoe_vsa6', 'hallkieralpha', 'peoe_vsa2', 'chi2v', 'nocount', 'vsa_estate5', 'fr_nhpyrrole', 'fr_al_coo', 'bertzct', 'estate_vsa11', 'minabspartialcharge', 'slogp_vsa8', 'fr_imide', 'kappa2', 'numaliphaticheterocycles', 'numsaturatedheterocycles', 'fr_hdrzone', 'smr_vsa4', 'fr_ar_n', 'nrot', 'smr_vsa8', 'slogp_vsa2', 'chi4v', 'fr_phos_ester', 'fr_para_hydroxylation', 'smr_vsa3', 'nhohcount', 'estate_vsa2', 'mollogp', 'tpsa', 'fr_azide', 'peoe_vsa4', 'numhdonors', 'fr_al_oh_notert', 'fr_c_o', 'chi0', 'fr_nitro_arom_nonortho', 'vsa_estate3', 'fr_benzene', 'fr_ketone', 'vsa_estate8', 'smr_vsa7', 'fr_ar_oh', 'fr_priamide', 'ringcount', 'estate_vsa3', 'numaromaticcarbocycles', 'bcut2d_mwhi', 'chi1v', 'heavyatomcount', 'vsa_estate4', 'chi0v', 'chiral_centers', 'r_cnt', 's_cnt', 'db_stereo', 'e_cnt', 'z_cnt', 'chiral_fp', 'db_fp'],
     "compressed_features": [],
-    "train_all_data":
+    "train_all_data": False,
+    "track_columns": "udm_asy_res_value"
 }


@@ -101,7 +108,7 @@ def convert_categorical_types(df: pd.DataFrame, features: list, category_mapping


 def decompress_features(
-
+    df: pd.DataFrame, features: List[str], compressed_features: List[str]
 ) -> Tuple[pd.DataFrame, List[str]]:
     """Prepare features for the model by decompressing bitstring features

@@ -157,11 +164,13 @@ def decompress_features(

 if __name__ == "__main__":
     # Template Parameters
+    id_column = TEMPLATE_PARAMS["id_column"]
     target = TEMPLATE_PARAMS["target"]
     features = TEMPLATE_PARAMS["features"]
     orig_features = features.copy()
     compressed_features = TEMPLATE_PARAMS["compressed_features"]
     train_all_data = TEMPLATE_PARAMS["train_all_data"]
+    track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
     validation_split = 0.2

     # Script arguments for input/output directories
@@ -219,175 +228,78 @@ if __name__ == "__main__":
     print(f"FIT/TRAIN: {df_train.shape}")
     print(f"VALIDATION: {df_val.shape}")

+    # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
+    xgb_model = XGBRegressor()
+    ngb_model = NGBRegressor()  # Dist=Cauchy) Seems to give HUGE prediction intervals
+    ngb_model = NGBRegressor(
+        Dist=T,
+        learning_rate=0.005,
+        minibatch_frac=0.1,  # Very small batches
+        col_sample=0.8  # This parameter DOES exist
+    )  # Testing this out
+    print("NGBoost using T distribution for uncertainty quantification")
+
     # Prepare features and targets for training
     X_train = df_train[features]
     X_validate = df_val[features]
     y_train = df_train[target]
     y_validate = df_val[target]

-    # Train
-    print("\nTraining XGBoost for point predictions...")
-    xgb_model = XGBRegressor(
-        n_estimators=1000,
-        max_depth=6,
-        learning_rate=0.01,
-        subsample=0.8,
-        colsample_bytree=0.8,
-        random_state=42,
-        verbosity=0
-    )
+    # Train both models using the training data
     xgb_model.fit(X_train, y_train)
-
-
-
-
-
-
-
-
-
-
-    print(f"
-
-
-    confidence_levels = [0.50, 0.80, 0.90, 0.95]  # 50%, 80%, 90%, 95% confidence intervals
-
-    # Store MAPIE models for each confidence level
-    mapie_models = {}
-
-    # Train models for each confidence level
-    for confidence_level in confidence_levels:
-        alpha = 1 - confidence_level
-        lower_q = alpha / 2
-        upper_q = 1 - alpha / 2
-
-        print(f"\nTraining quantile models for {confidence_level * 100:.0f}% confidence interval...")
-        print(f"  Quantiles: {lower_q:.3f}, {upper_q:.3f}, 0.500")
-
-        # Train three models for this confidence level
-        quantile_estimators = []
-        for q in [lower_q, upper_q, 0.5]:
-            print(f"  Training model for quantile {q:.3f}...")
-            est = LGBMRegressor(
-                objective="quantile",
-                alpha=q,
-                n_estimators=1000,
-                max_depth=6,
-                learning_rate=0.01,
-                num_leaves=31,
-                min_child_samples=20,
-                subsample=0.8,
-                colsample_bytree=0.8,
-                random_state=42,
-                verbose=-1,
-                force_col_wise=True
-            )
-            est.fit(X_train, y_train)
-            quantile_estimators.append(est)
-
-        # Create MAPIE CQR model for this confidence level
-        print(f"  Setting up MAPIE CQR for {confidence_level * 100:.0f}% confidence...")
-        mapie_model = ConformalizedQuantileRegressor(
-            quantile_estimators,
-            confidence_level=confidence_level,
-            prefit=True
-        )
-
-        # Conformalize the model
-        print(f"  Conformalizing with validation data...")
-        mapie_model.conformalize(X_validate, y_validate)
-
-        # Store the model
-        mapie_models[f"mapie_{confidence_level:.2f}"] = mapie_model
-
-        # Validate coverage for this confidence level
-        y_pred, y_pis = mapie_model.predict_interval(X_validate)
-        coverage = np.mean((y_validate >= y_pis[:, 0, 0]) & (y_validate <= y_pis[:, 1, 0]))
-        print(f"  Coverage: Target={confidence_level * 100:.0f}%, Empirical={coverage * 100:.1f}%")
-
-    print(f"\nOverall Model Performance Summary:")
-    print(f"XGBoost RMSE: {xgb_rmse:.3f}")
-    print(f"XGBoost MAE: {xgb_mae:.3f}")
-    print(f"XGBoost R2: {xgb_r2:.3f}")
+    ngb_model.fit(X_train, y_train, X_val=X_validate, Y_val=y_validate)
+
+    # Make Predictions on the Validation Set
+    print(f"Making Predictions on Validation Set...")
+    preds = xgb_model.predict(X_validate)
+
+    # Calculate various model performance metrics (regression)
+    rmse = root_mean_squared_error(y_validate, preds)
+    mae = mean_absolute_error(y_validate, preds)
+    r2 = r2_score(y_validate, preds)
+    print(f"RMSE: {rmse:.3f}")
+    print(f"MAE: {mae:.3f}")
+    print(f"R2: {r2:.3f}")
     print(f"NumRows: {len(df_val)}")

-    # Analyze interval widths across confidence levels
-    print(f"\nInterval Width Analysis:")
-    for conf_level in confidence_levels:
-        model = mapie_models[f"mapie_{conf_level:.2f}"]
-        _, y_pis = model.predict_interval(X_validate)
-        widths = y_pis[:, 1, 0] - y_pis[:, 0, 0]
-        print(f"  {conf_level * 100:.0f}% CI: Mean width={np.mean(widths):.3f}, Std={np.std(widths):.3f}")
-
     # Save the trained XGBoost model
     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))

-    # Save
-
-        joblib.dump(model, os.path.join(args.model_dir, f"{model_name}.joblib"))
+    # Save the trained NGBoost model
+    joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))

-    # Save the
+    # Save the features (this will validate input during predictions)
     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-        json.dump(
-
-    #
-
-    with open(os.path.join(args.model_dir, "category_mappings.json"), "w") as fp:
-        json.dump(category_mappings, fp)
-
-    # Save model configuration
-    model_config = {
-        "model_type": "XGBoost_MAPIE_CQR_LightGBM",
-        "confidence_levels": confidence_levels,
-        "n_features": len(features),
-        "target": target,
-        "validation_metrics": {
-            "xgb_rmse": float(xgb_rmse),
-            "xgb_mae": float(xgb_mae),
-            "xgb_r2": float(xgb_r2),
-            "n_validation": len(df_val)
-        }
-    }
-    with open(os.path.join(args.model_dir, "model_config.json"), "w") as fp:
-        json.dump(model_config, fp, indent=2)
+        json.dump(orig_features, fp)  # We save the original features, not the decompressed ones
+
+    # Now the Proximity model
+    model = Proximity(df_train, id_column, features, target, track_columns=track_columns)

-
-
+    # Now serialize the model
+    model.serialize(args.model_dir)


 #
 # Inference Section
 #
 def model_fn(model_dir) -> dict:
-    """Load XGBoost and
-
-    # Load model configuration to know which models to load
-    with open(os.path.join(model_dir, "model_config.json")) as fp:
-        config = json.load(fp)
+    """Load and return XGBoost, NGBoost, and Prox Model from model directory."""

     # Load XGBoost regressor
     xgb_path = os.path.join(model_dir, "xgb_model.json")
     xgb_model = XGBRegressor(enable_categorical=True)
     xgb_model.load_model(xgb_path)

-    # Load
-
-    for conf_level in config["confidence_levels"]:
-        model_name = f"mapie_{conf_level:.2f}"
-        mapie_models[model_name] = joblib.load(os.path.join(model_dir, f"{model_name}.joblib"))
+    # Load NGBoost regressor
+    ngb_model = joblib.load(os.path.join(model_dir, "ngb_model.joblib"))

-    #
-
-    category_path = os.path.join(model_dir, "category_mappings.json")
-    if os.path.exists(category_path):
-        with open(category_path) as fp:
-            category_mappings = json.load(fp)
+    # Deserialize the proximity model
+    prox_model = Proximity.deserialize(model_dir)

     return {
-        "
-        "
-        "
-        "category_mappings": category_mappings
+        "xgboost": xgb_model,
+        "ngboost": ngb_model,
+        "proximity": prox_model
     }

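For orientation, the pattern introduced above pairs XGBoost (point estimate) with NGBoost (predictive distribution). A minimal, self-contained sketch of that hybrid on synthetic data follows; everything in it is illustrative and not part of the package:

# Illustrative sketch of the XGBoost + NGBoost hybrid (synthetic data, hypothetical names)
import numpy as np
from ngboost import NGBRegressor
from ngboost.distns import T
from xgboost import XGBRegressor

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 5))
y = 2.0 * X[:, 0] + rng.normal(scale=0.5, size=500)

xgb_model = XGBRegressor()  # point estimator
ngb_model = NGBRegressor(Dist=T, learning_rate=0.005, minibatch_frac=0.1, col_sample=0.8)
xgb_model.fit(X, y)
ngb_model.fit(X, y)

point = xgb_model.predict(X[:5])  # point predictions
dists = ngb_model.pred_dist(X[:5])  # per-row predictive distributions
lower, upper = dists.ppf(0.025), dists.ppf(0.975)  # 95% interval from the fitted T distribution
print(point, lower, upper)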
@@ -403,7 +315,7 @@ def input_fn(input_data, content_type):
     if "text/csv" in content_type:
         return pd.read_csv(StringIO(input_data))
     elif "application/json" in content_type:
-        return pd.DataFrame(json.loads(input_data))
+        return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
     else:
         raise ValueError(f"{content_type} not supported!")

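For reference, the two request formats that input_fn accepts, shown with made-up column names (illustrative only, not package code):

# Illustrative payloads; column names are hypothetical
import json
from io import StringIO
import pandas as pd

csv_payload = "smiles,udm_mol_id\nCCO,mol-001\n"  # text/csv body
json_payload = '[{"smiles": "CCO", "udm_mol_id": "mol-001"}]'  # application/json body: array of records
print(pd.read_csv(StringIO(csv_payload)))
print(pd.DataFrame(json.loads(json_payload)))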
@@ -411,26 +323,23 @@ def input_fn(input_data, content_type):
 def output_fn(output_df, accept_type):
     """Supports both CSV and JSON output formats."""
     if "text/csv" in accept_type:
-
-        for col in output_df.select_dtypes(include=['category']).columns:
-            output_df[col] = output_df[col].astype(str)
-        csv_output = output_df.fillna("N/A").to_csv(index=False)
+        csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
         return csv_output, "text/csv"
     elif "application/json" in accept_type:
-        return output_df.to_json(orient="records"), "application/json"
+        return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
     else:
         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


 def predict_fn(df, models) -> pd.DataFrame:
-    """Make
+    """Make Predictions with our XGB Quantile Regression Model

     Args:
         df (pd.DataFrame): The input DataFrame
-        models (dict):
+        models (dict): The dictionary of models to use for predictions

     Returns:
-        pd.DataFrame: DataFrame with
+        pd.DataFrame: The DataFrame with the predictions added
     """

     # Grab our feature columns (from training)
@@ -441,62 +350,44 @@ def predict_fn(df, models) -> pd.DataFrame:
     # Match features in a case-insensitive manner
     matched_df = match_features_case_insensitive(df, model_features)

-    #
-
-
-
-
-
-
+    # Use XGBoost for point predictions
+    df["prediction"] = models["xgboost"].predict(matched_df[model_features])
+
+    # NGBoost predict returns distribution objects
+    y_dists = models["ngboost"].pred_dist(matched_df[model_features])
+
+    # Extract parameters from distribution
+    dist_params = y_dists.params

-    #
-
-
-
-
-
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        elif conf_level == 0.90:  # 90% CI
-            df["q_05"] = y_pis[:, 0, 0]
-            df["q_95"] = y_pis[:, 1, 0]
-        elif conf_level == 0.95:  # 95% CI
-            df["q_025"] = y_pis[:, 0, 0]
-            df["q_975"] = y_pis[:, 1, 0]
-
-    # Add median (q_50) from XGBoost prediction
-    df["q_50"] = df["prediction"]
-
-    # Calculate uncertainty metrics based on 95% interval
-    interval_width = df["q_975"] - df["q_025"]
-    df["prediction_std"] = interval_width / 3.92
+    # Extract mean and std from distribution parameters
+    df["prediction_uq"] = dist_params['loc']  # mean
+    df["prediction_std"] = dist_params['scale']  # standard deviation
+
+    # Add 95% prediction intervals using ppf (percent point function)
+    # Note: Our hybrid model uses XGB point prediction and NGBoost UQ
+    # so we need to adjust the bounds to include the point prediction
+    df["q_025"] = np.minimum(y_dists.ppf(0.025), df["prediction"])
+    df["q_975"] = np.maximum(y_dists.ppf(0.975), df["prediction"])
+
+    # Add 90% prediction intervals
+    df["q_05"] = y_dists.ppf(0.05)  # 5th percentile
+    df["q_95"] = y_dists.ppf(0.95)  # 95th percentile
+
+    # Add 80% prediction intervals
+    df["q_10"] = y_dists.ppf(0.10)  # 10th percentile
+    df["q_90"] = y_dists.ppf(0.90)  # 90th percentile
+
+    # Add 50% prediction intervals
+    df["q_25"] = y_dists.ppf(0.25)  # 25th percentile
+    df["q_75"] = y_dists.ppf(0.75)  # 75th percentile

     # Reorder the quantile columns for easier reading
     quantile_cols = ["q_025", "q_05", "q_10", "q_25", "q_75", "q_90", "q_95", "q_975"]
     other_cols = [col for col in df.columns if col not in quantile_cols]
     df = df[other_cols + quantile_cols]

-    #
-
-
-    # Confidence bands
-    df["confidence_band"] = pd.cut(
-        df["uncertainty_score"],
-        bins=[0, 0.5, 1.0, 2.0, np.inf],
-        labels=["high", "medium", "low", "very_low"]
-    )
+    # Compute Nearest neighbors with Proximity model
+    models["proximity"].neighbors(df)

+    # Return the modified DataFrame
     return df
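All of the q_* columns added in predict_fn come from the percent point function (ppf) of one fitted distribution per row. A small scipy stand-in (not package code) shows how the loc/scale parameters relate to those interval bounds; the parameter values are hypothetical:

# scipy stand-in for the ppf-based interval columns; values are hypothetical
from scipy.stats import t

loc, scale, dof = 2.0, 0.5, 3.0  # per-row distribution parameters, analogous to NGBoost's dist_params
dist = t(df=dof, loc=loc, scale=scale)
q_025, q_975 = dist.ppf(0.025), dist.ppf(0.975)  # 95% prediction interval
q_05, q_95 = dist.ppf(0.05), dist.ppf(0.95)  # 90% prediction interval
q_10, q_90 = dist.ppf(0.10), dist.ppf(0.90)  # 80% prediction interval
q_25, q_75 = dist.ppf(0.25), dist.ppf(0.75)  # 50% prediction interval
print(q_025, q_05, q_10, q_25, q_75, q_90, q_95, q_975)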
workbench/model_scripts/xgb_model/generated_model_script.py
@@ -28,12 +28,12 @@ from typing import List, Tuple

 # Template Parameters
 TEMPLATE_PARAMS = {
-    "model_type": "
-    "target": "
-    "features": ['bcut2d_logplow', '
+    "model_type": "classifier",
+    "target": "class",
"features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/
-    "train_all_data":
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/sol-class-f1-100/training",
+    "train_all_data": True
 }

 # Function to check if dataframe is empty
workbench/repl/workbench_shell.py
@@ -41,7 +41,7 @@ from workbench.cached.cached_meta import CachedMeta
 try:
     import rdkit  # noqa
     import mordred  # noqa
-    from workbench.utils import
+    from workbench.utils.chem_utils import vis

     HAVE_CHEM_UTILS = True
 except ImportError:
@@ -178,12 +178,12 @@ class WorkbenchShell:

         # Add cheminformatics utils if available
         if HAVE_CHEM_UTILS:
-            self.commands["show"] =
+            self.commands["show"] = vis.show

     def start(self):
         """Start the Workbench IPython shell"""
         cprint("magenta", "\nWelcome to Workbench!")
-        if self.aws_status
+        if not self.aws_status:
             cprint("red", "AWS Account Connection Failed...Review/Fix the Workbench Config:")
             cprint("red", f"Path: {self.cm.site_config_path}")
             self.show_config()
workbench/utils/chem_utils/fingerprints.py
@@ -0,0 +1,134 @@
+"""Molecular fingerprint computation utilities"""
+
+import logging
+import pandas as pd
+
+# Molecular Descriptor Imports
+from rdkit import Chem
+from rdkit.Chem import rdFingerprintGenerator
+from rdkit.Chem.MolStandardize import rdMolStandardize
+
+# Set up the logger
+log = logging.getLogger("workbench")
+
+
+def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
+    """Compute and add Morgan fingerprints to the DataFrame.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame containing SMILES strings.
+        radius (int): Radius for the Morgan fingerprint.
+        n_bits (int): Number of bits for the fingerprint.
+        counts (bool): Count simulation for the fingerprint.
+
+    Returns:
+        pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
+
+    Note:
+        See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
+    """
+    delete_mol_column = False
+
+    # Check for the SMILES column (case-insensitive)
+    smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
+    if smiles_column is None:
+        raise ValueError("Input DataFrame must have a 'smiles' column")
+
+    # Sanity check the molecule column (sometimes it gets serialized, which doesn't work)
+    if "molecule" in df.columns and df["molecule"].dtype == "string":
+        log.warning("Detected serialized molecules in 'molecule' column. Removing...")
+        del df["molecule"]
+
+    # Convert SMILES to RDKit molecule objects (vectorized)
+    if "molecule" not in df.columns:
+        log.info("Converting SMILES to RDKit Molecules...")
+        delete_mol_column = True
+        df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
+    # Make sure our molecules are not None
+    failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
+    if failed_smiles:
+        log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
+    df = df.dropna(subset=["molecule"])
+
+    # If we have fragments in our compounds, get the largest fragment before computing fingerprints
+    largest_frags = df["molecule"].apply(
+        lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
+    )
+
+    # Create a Morgan fingerprint generator
+    if counts:
+        n_bits *= 4  # Multiply by 4 to simulate counts
+    morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
+
+    # Compute Morgan fingerprints (vectorized)
+    fingerprints = largest_frags.apply(
+        lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
+    )
+
+    # Add the fingerprints to the DataFrame
+    df["fingerprint"] = fingerprints
+
+    # Drop the intermediate 'molecule' column if it was added
+    if delete_mol_column:
+        del df["molecule"]
+    return df
+
+
+if __name__ == "__main__":
+    print("Running molecular fingerprint tests...")
+    print("Note: This requires molecular_screening module to be available")
+
+    # Test molecules
+    test_molecules = {
+        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
+        "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
+        "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
+        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt
+        "benzene": "c1ccccc1",
+        "butene_e": "C/C=C/C",  # E-butene
+        "butene_z": "C/C=C\\C",  # Z-butene
+    }
+
+    # Test 1: Morgan Fingerprints
+    print("\n1. Testing Morgan fingerprint generation...")
+
+    test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
+
+    fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
+
+    print("  Fingerprint generation results:")
+    for _, row in fp_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        fp_len = len(fp) if fp != "N/A" else 0
+        print(f"    {row['name']:15} → {fp_len} bits")
+
+    # Test 2: Different fingerprint parameters
+    print("\n2. Testing different fingerprint parameters...")
+
+    # Test with counts enabled
+    fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
+
+    print("  With count simulation (256 bits * 4):")
+    for _, row in fp_counts_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        fp_len = len(fp) if fp != "N/A" else 0
+        print(f"    {row['name']:15} → {fp_len} bits")
+
+    # Test 3: Edge cases
+    print("\n3. Testing edge cases...")
+
+    # Invalid SMILES
+    invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
+    try:
+        fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+        print(f"  ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
+    except Exception as e:
+        print(f"  ✓ Invalid SMILES properly raised error: {type(e).__name__}")
+
+    # Test with pre-existing molecule column
+    mol_df = test_df.copy()
+    mol_df["molecule"] = mol_df["SMILES"].apply(Chem.MolFromSmiles)
+    fp_with_mol = compute_morgan_fingerprints(mol_df)
+    print(f"  ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
+
+    print("\n✅ All fingerprint tests completed!")