workbench 0.8.219__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/api/feature_set.py +0 -1
  6. workbench/core/artifacts/feature_set_core.py +183 -228
  7. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  9. workbench/model_scripts/chemprop/chemprop.template +193 -68
  10. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  11. workbench/model_scripts/pytorch_model/generated_model_script.py +3 -3
  12. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  13. workbench/scripts/ml_pipeline_sqs.py +71 -2
  14. workbench/themes/light/custom.css +7 -1
  15. workbench/themes/midnight_blue/custom.css +34 -0
  16. workbench/utils/chem_utils/projections.py +16 -6
  17. workbench/utils/model_utils.py +0 -1
  18. workbench/utils/plot_utils.py +146 -28
  19. workbench/utils/theme_manager.py +95 -30
  20. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  21. workbench/web_interface/components/settings_menu.py +184 -0
  22. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  23. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/RECORD +27 -25
  24. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  25. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +0 -0
  26. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  27. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
workbench/model_scripts/chemprop/generated_model_script.py

@@ -34,7 +34,7 @@ DEFAULT_HYPERPARAMETERS = {
     "max_epochs": 400,
     "patience": 50,
     "batch_size": 32,
-    # Message Passing
+    # Message Passing (ignored when using foundation model)
     "hidden_dim": 700,
     "depth": 6,
     "dropout": 0.1,  # Lower dropout - ensemble provides regularization
@@ -45,16 +45,24 @@ DEFAULT_HYPERPARAMETERS = {
     "criterion": "mae",
     # Random seed
     "seed": 42,
+    # Foundation model support
+    # - "CheMeleon": Load CheMeleon pretrained weights (auto-downloads on first use)
+    # - Path to .pt file: Load custom pretrained Chemprop model
+    # - None: Train from scratch (default)
+    "from_foundation": None,
+    # Freeze MPNN for N epochs, then unfreeze (0 = no freezing, train all params from start)
+    # Recommended: 5-20 epochs when using foundation models to stabilize FFN before fine-tuning MPNN
+    "freeze_mpnn_epochs": 0,
 }
 
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "targets": ['logd'],
-    "feature_list": ['smiles', 'mollogp', 'fr_halogen', 'nbase', 'peoe_vsa6', 'bcut2d_mrlow', 'peoe_vsa7', 'peoe_vsa9', 'vsa_estate1', 'peoe_vsa1', 'numhdonors', 'vsa_estate5', 'smr_vsa3', 'slogp_vsa1', 'vsa_estate7', 'bcut2d_mwhi', 'axp_2dv', 'axp_3dv', 'mi', 'smr_vsa9', 'vsa_estate3', 'estate_vsa9', 'bcut2d_mwlow', 'tpsa', 'vsa_estate10', 'xch_5dv', 'slogp_vsa2', 'nhohcount', 'bcut2d_logplow', 'hallkieralpha', 'c2sp2', 'bcut2d_chglo', 'smr_vsa4', 'maxabspartialcharge', 'estate_vsa6', 'qed', 'slogp_vsa6', 'vsa_estate2', 'bcut2d_logphi', 'vsa_estate8', 'xch_7dv', 'fpdensitymorgan3', 'xpc_6d', 'smr_vsa10', 'axp_0d', 'fr_nh1', 'axp_4dv', 'peoe_vsa2', 'estate_vsa8', 'peoe_vsa5', 'vsa_estate6'],
-    "id_column": "molecule_name",
-    "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-reg-chemprop-hybrid/training",
-    "hyperparameters": {},
+    "targets": ['udm_asy_res_free_percent'],
+    "feature_list": ['smiles'],
+    "id_column": "udm_mol_bat_id",
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/ppb-human-free-reg-chemprop-foundation-1-dt/training",
+    "hyperparameters": {'from_foundation': 'CheMeleon', 'freeze_mpnn_epochs': 10, 'n_folds': 5, 'max_epochs': 100, 'patience': 20, 'ffn_hidden_dim': 512, 'dropout': 0.15},
 }
 
 
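Note: the two new keys are opt-in. The user-supplied TEMPLATE_PARAMS["hyperparameters"] appear to be merged over DEFAULT_HYPERPARAMETERS, so scripts that omit them keep the old train-from-scratch behavior. A minimal sketch of that merge (trimmed, illustrative values only, not the template's full defaults):

    DEFAULT_HYPERPARAMETERS = {"from_foundation": None, "freeze_mpnn_epochs": 0, "max_epochs": 400}
    user_hyperparameters = {"from_foundation": "CheMeleon", "freeze_mpnn_epochs": 10, "max_epochs": 100}
    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **user_hyperparameters}
    print(hyperparameters["from_foundation"])  # CheMeleon -> enables the foundation path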
@@ -114,26 +122,27 @@ def _create_molecule_datapoints(
 # Model Loading (for SageMaker inference)
 # =============================================================================
 def model_fn(model_dir: str) -> dict:
-    """Load ChemProp MPNN ensemble from the specified directory."""
-    from lightning import pytorch as pl
+    """Load ChemProp MPNN ensemble from the specified directory.
 
+    Optimized for serverless cold starts - uses direct PyTorch inference
+    instead of Lightning Trainer to minimize startup time.
+    """
     metadata = joblib.load(os.path.join(model_dir, "ensemble_metadata.joblib"))
+
+    # Load all ensemble models (keep on CPU for serverless compatibility)
+    # ChemProp handles device placement internally
     ensemble_models = []
     for i in range(metadata["n_ensemble"]):
         model = models.MPNN.load_from_file(os.path.join(model_dir, f"chemprop_model_{i}.pt"))
         model.eval()
         ensemble_models.append(model)
 
-    # Pre-initialize trainer once during model loading (expensive operation)
-    trainer = pl.Trainer(accelerator="auto", logger=False, enable_progress_bar=False)
-
     print(f"Loaded {len(ensemble_models)} model(s), targets={metadata['target_columns']}")
     return {
         "ensemble_models": ensemble_models,
         "n_ensemble": metadata["n_ensemble"],
         "target_columns": metadata["target_columns"],
         "median_std": metadata["median_std"],
-        "trainer": trainer,
     }
 
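Besides the per-fold .pt files, model_fn depends only on ensemble_metadata.joblib. A self-contained round-trip sketch of that artifact (keys as read above; the values are placeholders):

    import joblib

    metadata = {"n_ensemble": 5,
                "target_columns": ["udm_asy_res_free_percent"],
                "median_std": 0.12}  # placeholder values
    joblib.dump(metadata, "ensemble_metadata.joblib")
    assert joblib.load("ensemble_metadata.joblib")["n_ensemble"] == 5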
@@ -141,13 +150,15 @@ def model_fn(model_dir: str) -> dict:
 # Inference (for SageMaker inference)
 # =============================================================================
 def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-    """Make predictions with ChemProp MPNN ensemble."""
+    """Make predictions with ChemProp MPNN ensemble.
+
+    Uses direct PyTorch inference (no Lightning Trainer) for fast serverless inference.
+    """
     model_type = TEMPLATE_PARAMS["model_type"]
     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
 
     ensemble_models = model_dict["ensemble_models"]
     target_columns = model_dict["target_columns"]
-    trainer = model_dict["trainer"]  # Use pre-initialized trainer
 
     # Load artifacts
     label_encoder = None
@@ -202,18 +213,34 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         return df
 
     dataset = data.MoleculeDataset(datapoints)
-    dataloader = data.build_dataloader(dataset, shuffle=False)
+    dataloader = data.build_dataloader(dataset, shuffle=False, batch_size=64)
 
-    # Ensemble predictions
+    # Ensemble predictions using direct PyTorch inference (no Lightning Trainer)
     all_preds = []
     for model in ensemble_models:
+        model_preds = []
+        model.eval()
         with torch.inference_mode():
-            predictions = trainer.predict(model, dataloader)
-            preds = np.concatenate([p.numpy() for p in predictions], axis=0)
+            for batch in dataloader:
+                # TrainingBatch contains (bmg, V_d, X_d, targets, weights, lt_mask, gt_mask)
+                # For inference we only need bmg, V_d, X_d
+                bmg, V_d, X_d, *_ = batch
+                output = model(bmg, V_d, X_d)
+                model_preds.append(output.detach().cpu().numpy())
+
+        if len(model_preds) == 0:
+            print(f"Warning: No predictions generated. Dataset size: {len(datapoints)}")
+            continue
+
+        preds = np.concatenate(model_preds, axis=0)
         if preds.ndim == 3 and preds.shape[1] == 1:
             preds = preds.squeeze(axis=1)
         all_preds.append(preds)
 
+    if len(all_preds) == 0:
+        print("Error: No ensemble predictions generated")
+        return df
+
     preds = np.mean(np.stack(all_preds), axis=0)
     preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
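The aggregation after the loop is plain NumPy: stack the per-model outputs, then reduce to a mean prediction and a std that feeds the confidence score. A standalone illustration with mock predictions:

    import numpy as np

    # Three mock models' predictions for two molecules.
    all_preds = [np.array([1.0, 2.0]), np.array([1.2, 1.8]), np.array([0.8, 2.2])]
    stacked = np.stack(all_preds)    # shape: (n_models, n_samples)
    preds = stacked.mean(axis=0)     # ensemble prediction -> [1.0, 2.0]
    preds_std = stacked.std(axis=0)  # ensemble spread -> uncertainty estimate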
@@ -243,8 +270,11 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     df["prediction"] = df[f"{target_columns[0]}_pred"]
     df["prediction_std"] = df[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-    df = _compute_std_confidence(df, model_dict["median_std"])
+    # Compute confidence from ensemble std (or NaN if single model)
+    if model_dict["median_std"] is not None:
+        df = _compute_std_confidence(df, model_dict["median_std"])
+    else:
+        df["confidence"] = np.nan
 
     return df
 
@@ -279,54 +309,107 @@ if __name__ == "__main__":
     )
 
     # -------------------------------------------------------------------------
-    # Training-only helper function
+    # Training-only helper functions
     # -------------------------------------------------------------------------
-    def build_mpnn_model(
-        hyperparameters: dict,
-        task: str = "regression",
-        num_classes: int | None = None,
-        n_targets: int = 1,
-        n_extra_descriptors: int = 0,
-        x_d_transform: nn.ScaleTransform | None = None,
-        output_transform: nn.UnscaleTransform | None = None,
-        task_weights: np.ndarray | None = None,
-    ) -> models.MPNN:
-        """Build an MPNN model with specified hyperparameters."""
-        hidden_dim = hyperparameters["hidden_dim"]
-        depth = hyperparameters["depth"]
+    def _load_foundation_weights(from_foundation: str) -> tuple[nn.BondMessagePassing, nn.Aggregation]:
+        """Load pretrained MPNN weights from foundation model.
+
+        Args:
+            from_foundation: "CheMeleon" or path to .pt file
+
+        Returns:
+            Tuple of (message_passing, aggregation) modules
+        """
+        import urllib.request
+        from pathlib import Path
+
+        print(f"Loading foundation model: {from_foundation}")
+
+        if from_foundation.lower() == "chemeleon":
+            # Download from Zenodo if not cached
+            cache_dir = Path.home() / ".chemprop" / "foundation"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            chemeleon_path = cache_dir / "chemeleon_mp.pt"
+
+            if not chemeleon_path.exists():
+                print("  Downloading CheMeleon weights from Zenodo...")
+                urllib.request.urlretrieve(
+                    "https://zenodo.org/records/15460715/files/chemeleon_mp.pt", chemeleon_path
+                )
+                print(f"  Downloaded to {chemeleon_path}")
+
+            ckpt = torch.load(chemeleon_path, weights_only=True)
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded CheMeleon MPNN (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        if not os.path.exists(from_foundation):
+            raise ValueError(f"Foundation model not found: {from_foundation}. Use 'CheMeleon' or a valid .pt path.")
+
+        ckpt = torch.load(from_foundation, weights_only=False)
+        if "hyper_parameters" in ckpt and "state_dict" in ckpt:
+            # CheMeleon-style checkpoint
+            mp = nn.BondMessagePassing(**ckpt["hyper_parameters"])
+            mp.load_state_dict(ckpt["state_dict"])
+            print(f"  Loaded custom foundation weights (hidden_dim={mp.output_dim})")
+            return mp, nn.MeanAggregation()
+
+        # Full MPNN model file
+        pretrained = models.MPNN.load_from_file(from_foundation)
+        print(f"  Loaded custom MPNN (hidden_dim={pretrained.message_passing.output_dim})")
+        return pretrained.message_passing, pretrained.agg
+
+    def _build_ffn(
+        task: str, input_dim: int, hyperparameters: dict,
+        num_classes: int | None, n_targets: int,
+        output_transform: nn.UnscaleTransform | None, task_weights: np.ndarray | None,
+    ) -> nn.Predictor:
+        """Build task-specific FFN head."""
         dropout = hyperparameters["dropout"]
         ffn_hidden_dim = hyperparameters["ffn_hidden_dim"]
         ffn_num_layers = hyperparameters["ffn_num_layers"]
 
-        mp = nn.BondMessagePassing(d_h=hidden_dim, depth=depth, dropout=dropout)
-        agg = nn.NormAggregation()
-        ffn_input_dim = hidden_dim + n_extra_descriptors
-
         if task == "classification" and num_classes is not None:
-            ffn = nn.MulticlassClassificationFFN(
-                n_classes=num_classes, input_dim=ffn_input_dim,
+            return nn.MulticlassClassificationFFN(
+                n_classes=num_classes, input_dim=input_dim,
                 hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers, dropout=dropout,
             )
+
+        from chemprop.nn.metrics import MAE, MSE
+        criterion_map = {"mae": MAE, "mse": MSE}
+        criterion_name = hyperparameters.get("criterion", "mae")
+        if criterion_name not in criterion_map:
+            raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
+
+        weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
+        return nn.RegressionFFN(
+            input_dim=input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
+            dropout=dropout, n_tasks=n_targets, output_transform=output_transform,
+            task_weights=weights_tensor, criterion=criterion_map[criterion_name](),
+        )
+
+    def build_mpnn_model(
+        hyperparameters: dict, task: str = "regression", num_classes: int | None = None,
+        n_targets: int = 1, n_extra_descriptors: int = 0,
+        x_d_transform: nn.ScaleTransform | None = None,
+        output_transform: nn.UnscaleTransform | None = None, task_weights: np.ndarray | None = None,
+    ) -> models.MPNN:
+        """Build MPNN model, optionally loading pretrained weights."""
+        from_foundation = hyperparameters.get("from_foundation")
+
+        if from_foundation:
+            mp, agg = _load_foundation_weights(from_foundation)
+            ffn_input_dim = mp.output_dim + n_extra_descriptors
         else:
-            # Map criterion name to ChemProp metric class (must have .clone() method)
-            from chemprop.nn.metrics import MAE, MSE
-
-            criterion_map = {
-                "mae": MAE,
-                "mse": MSE,
-            }
-            criterion_name = hyperparameters.get("criterion", "mae")
-            if criterion_name not in criterion_map:
-                raise ValueError(f"Unknown criterion '{criterion_name}'. Supported: {list(criterion_map.keys())}")
-            criterion = criterion_map[criterion_name]()
-
-            weights_tensor = torch.tensor(task_weights, dtype=torch.float32) if task_weights is not None else None
-            ffn = nn.RegressionFFN(
-                input_dim=ffn_input_dim, hidden_dim=ffn_hidden_dim, n_layers=ffn_num_layers,
-                dropout=dropout, n_tasks=n_targets, output_transform=output_transform, task_weights=weights_tensor,
-                criterion=criterion,
+            mp = nn.BondMessagePassing(
+                d_h=hyperparameters["hidden_dim"], depth=hyperparameters["depth"],
+                dropout=hyperparameters["dropout"],
             )
+            agg = nn.NormAggregation()
+            ffn_input_dim = hyperparameters["hidden_dim"] + n_extra_descriptors
 
+        ffn = _build_ffn(task, ffn_input_dim, hyperparameters, num_classes, n_targets, output_transform, task_weights)
         return models.MPNN(message_passing=mp, agg=agg, predictor=ffn, batch_norm=True, metrics=None, X_d_transform=x_d_transform)
 
     # -------------------------------------------------------------------------
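_load_foundation_weights accepts two on-disk layouts: a checkpoint dict with "hyper_parameters" (constructor kwargs) and "state_dict" (weights), or a full MPNN saved via models.MPNN. A hedged sketch of the first layout using a stand-in torch module (the Linear layer and filename are illustrative only, not CheMeleon's actual contents):

    import torch
    from torch import nn as torch_nn  # aliased; `nn` in the script above is chemprop.nn

    module = torch_nn.Linear(4, 2)  # stand-in for a pretrained message-passing module
    torch.save({"hyper_parameters": {"in_features": 4, "out_features": 2},
                "state_dict": module.state_dict()}, "demo_foundation.pt")

    # Mirror of the loader's two-step restore: rebuild from kwargs, then load weights.
    ckpt = torch.load("demo_foundation.pt", weights_only=True)
    restored = torch_nn.Linear(**ckpt["hyper_parameters"])
    restored.load_state_dict(ckpt["state_dict"])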
@@ -359,6 +442,14 @@ if __name__ == "__main__":
     print(f"Extra features: {extra_feature_cols if use_extra_features else 'None (SMILES only)'}")
     print(f"Hyperparameters: {hyperparameters}")
 
+    # Log foundation model configuration
+    if hyperparameters.get("from_foundation"):
+        freeze_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        freeze_msg = f"MPNN frozen for {freeze_epochs} epochs" if freeze_epochs > 0 else "no freezing"
+        print(f"Foundation model: {hyperparameters['from_foundation']} ({freeze_msg})")
+    else:
+        print("Foundation model: None (training from scratch)")
+
     # Load training data
     training_files = [os.path.join(args.train, f) for f in os.listdir(args.train) if f.endswith(".csv")]
     print(f"Training Files: {training_files}")
@@ -456,7 +547,7 @@ if __name__ == "__main__":
         print(f"Fold {fold_idx + 1}/{len(folds)} - Train: {len(train_idx)}, Val: {len(val_idx)}")
         print(f"{'='*50}")
 
-        # Split data
+        # Split data (val_extra_raw preserves unscaled features for OOF predictions)
         df_train, df_val = all_df.iloc[train_idx].reset_index(drop=True), all_df.iloc[val_idx].reset_index(drop=True)
         train_targets, val_targets = all_targets[train_idx], all_targets[val_idx]
         train_extra = all_extra_features[train_idx] if all_extra_features is not None else None
@@ -484,7 +575,7 @@ if __name__ == "__main__":
         train_loader = data.build_dataloader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
         val_loader = data.build_dataloader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
 
-        # Build and train model
+        # Build model
        pl.seed_everything(hyperparameters["seed"] + fold_idx)
         mpnn = build_mpnn_model(
             hyperparameters, task=task, num_classes=num_classes, n_targets=n_targets,
@@ -492,14 +583,39 @@ if __name__ == "__main__":
             output_transform=output_transform, task_weights=task_weights,
         )
 
-        trainer = pl.Trainer(
-            accelerator="auto", max_epochs=hyperparameters["max_epochs"], logger=False, enable_progress_bar=True,
-            callbacks=[
-                pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min"),
-                pl.callbacks.ModelCheckpoint(dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1),
-            ],
-        )
-        trainer.fit(mpnn, train_loader, val_loader)
+        # Train model (with optional two-phase foundation training)
+        freeze_mpnn_epochs = hyperparameters.get("freeze_mpnn_epochs", 0)
+        use_two_phase = hyperparameters.get("from_foundation") and freeze_mpnn_epochs > 0
+
+        def _set_mpnn_frozen(frozen: bool):
+            for param in mpnn.message_passing.parameters():
+                param.requires_grad = not frozen
+            for param in mpnn.agg.parameters():
+                param.requires_grad = not frozen
+
+        def _make_trainer(max_epochs: int, save_checkpoint: bool = False):
+            callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=hyperparameters["patience"], mode="min")]
+            if save_checkpoint:
+                callbacks.append(pl.callbacks.ModelCheckpoint(
+                    dirpath=args.model_dir, filename=f"best_{fold_idx}", monitor="val_loss", mode="min", save_top_k=1
+                ))
+            return pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False, enable_progress_bar=True, callbacks=callbacks)
+
+        if use_two_phase:
+            # Phase 1: Freeze MPNN, train FFN only
+            print(f"Phase 1: Training with frozen MPNN for {freeze_mpnn_epochs} epochs...")
+            _set_mpnn_frozen(True)
+            _make_trainer(freeze_mpnn_epochs).fit(mpnn, train_loader, val_loader)
+
+            # Phase 2: Unfreeze and fine-tune all
+            print("Phase 2: Unfreezing MPNN, continuing training...")
+            _set_mpnn_frozen(False)
+            remaining_epochs = max(1, hyperparameters["max_epochs"] - freeze_mpnn_epochs)
+            trainer = _make_trainer(remaining_epochs, save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
+        else:
+            trainer = _make_trainer(hyperparameters["max_epochs"], save_checkpoint=True)
+            trainer.fit(mpnn, train_loader, val_loader)
 
         # Load best checkpoint
         if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
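The two-phase schedule reduces to flipping requires_grad on the backbone so phase 1 only updates the FFN head. A minimal standalone illustration of the pattern (generic torch modules, not the chemprop classes):

    from torch import nn as torch_nn

    backbone, head = torch_nn.Linear(8, 8), torch_nn.Linear(8, 1)

    def set_frozen(module: torch_nn.Module, frozen: bool) -> None:
        for p in module.parameters():
            p.requires_grad = not frozen

    set_frozen(backbone, True)   # phase 1: gradients flow only through the head
    # ... fit for freeze_mpnn_epochs ...
    set_frozen(backbone, False)  # phase 2: fine-tune everything
    assert all(p.requires_grad for p in backbone.parameters())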
@@ -509,7 +625,7 @@ if __name__ == "__main__":
         mpnn.eval()
         ensemble_models.append(mpnn)
 
-        # Out-of-fold predictions (using raw features)
+        # Out-of-fold predictions (using unscaled features - model's x_d_transform handles scaling)
         val_dps_raw, _ = _create_molecule_datapoints(df_val[smiles_column].tolist(), val_targets, val_extra_raw)
         val_loader_pred = data.build_dataloader(data.MoleculeDataset(val_dps_raw), batch_size=batch_size, shuffle=False)
 
@@ -599,11 +715,17 @@ if __name__ == "__main__":
     df_val["prediction"] = df_val[f"{target_columns[0]}_pred"]
     df_val["prediction_std"] = df_val[f"{target_columns[0]}_pred_std"]
 
-    # Compute confidence from ensemble std
-    median_std = float(np.median(preds_std[:, 0]))
-    print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
-    df_val = _compute_std_confidence(df_val, median_std)
-    print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    # Compute confidence from ensemble std (or NaN for single model)
+    if preds_std is not None:
+        median_std = float(np.median(preds_std[:, 0]))
+        print(f"\nComputing confidence scores (median_std={median_std:.6f})...")
+        df_val = _compute_std_confidence(df_val, median_std)
+        print(f"  Confidence: mean={df_val['confidence'].mean():.3f}, min={df_val['confidence'].min():.3f}, max={df_val['confidence'].max():.3f}")
+    else:
+        # Single model - no ensemble std available, confidence is undefined
+        median_std = None
+        df_val["confidence"] = np.nan
+        print("\nSingle model (n_folds=1): No ensemble std, confidence set to NaN")
 
     # -------------------------------------------------------------------------
     # Save validation predictions to S3
@@ -633,6 +755,9 @@ if __name__ == "__main__":
         "n_folds": n_folds,
         "target_columns": target_columns,
         "median_std": median_std,  # For confidence calculation during inference
+        # Foundation model provenance (for tracking/reproducibility)
+        "from_foundation": hyperparameters.get("from_foundation", None),
+        "freeze_mpnn_epochs": hyperparameters.get("freeze_mpnn_epochs", 0),
     }
     joblib.dump(ensemble_metadata, os.path.join(args.model_dir, "ensemble_metadata.joblib"))
 
workbench/model_scripts/pytorch_model/generated_model_script.py

@@ -61,10 +61,10 @@ DEFAULT_HYPERPARAMETERS = {
  TEMPLATE_PARAMS = {
  "model_type": "uq_regressor",
  "target": "udm_asy_res_efflux_ratio",
- "features": ['fingerprint'],
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
  "id_column": "udm_mol_bat_id",
- "compressed_features": ['fingerprint'],
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-fp-pytorch/training",
+ "compressed_features": [],
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-pytorch-260113/training",
  "hyperparameters": {},
  }
 
workbench/model_scripts/xgb_model/generated_model_script.py

@@ -65,11 +65,11 @@ REGRESSION_ONLY_PARAMS = {"objective"}
  TEMPLATE_PARAMS = {
  "model_type": "uq_regressor",
  "target": "udm_asy_res_efflux_ratio",
- "features": ['fingerprint'],
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
  "id_column": "udm_mol_bat_id",
  "compressed_features": ['fingerprint'],
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-fp/training",
- "hyperparameters": {},
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-temporal/training",
+ "hyperparameters": {'n_folds': 1},
  }
 
workbench/scripts/ml_pipeline_sqs.py

@@ -1,6 +1,8 @@
 import argparse
+import ast
 import logging
 import json
+import re
 from pathlib import Path
 
 # Workbench Imports
@@ -13,6 +15,56 @@ cm = ConfigManager()
 workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
 
 
+def parse_workbench_batch(script_content: str) -> dict | None:
+    """Parse WORKBENCH_BATCH config from a script.
+
+    Looks for a dictionary assignment like:
+        WORKBENCH_BATCH = {
+            "outputs": ["feature_set_xyz"],
+        }
+    or:
+        WORKBENCH_BATCH = {
+            "inputs": ["feature_set_xyz"],
+        }
+
+    Args:
+        script_content: The Python script content as a string
+
+    Returns:
+        The parsed dictionary or None if not found
+    """
+    pattern = r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})"
+    match = re.search(pattern, script_content, re.DOTALL)
+    if match:
+        try:
+            return ast.literal_eval(match.group(1))
+        except (ValueError, SyntaxError) as e:
+            print(f"⚠️ Warning: Failed to parse WORKBENCH_BATCH: {e}")
+            return None
+    return None
+
+
+def get_message_group_id(batch_config: dict | None) -> str:
+    """Derive MessageGroupId from outputs or inputs.
+
+    - Scripts with outputs use first output as group
+    - Scripts with inputs use first input as group
+    - Default to "ml-pipeline-jobs" if no config
+    """
+    if not batch_config:
+        return "ml-pipeline-jobs"
+
+    outputs = batch_config.get("outputs", [])
+    inputs = batch_config.get("inputs", [])
+
+    if outputs:
+        return outputs[0]
+    elif inputs:
+        return inputs[0]
+    else:
+        return "ml-pipeline-jobs"
+
+
 def submit_to_sqs(
     script_path: str,
     size: str = "small",
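Usage sketch for the two helpers above (the WORKBENCH_BATCH content is illustrative; note that importing the module runs its top-level ConfigManager setup, so a working Workbench config is assumed):

    from workbench.scripts.ml_pipeline_sqs import get_message_group_id, parse_workbench_batch

    script = '''
    WORKBENCH_BATCH = {
        "outputs": ["feature_set_xyz"],
    }
    '''
    config = parse_workbench_batch(script)  # {'outputs': ['feature_set_xyz']}
    print(get_message_group_id(config))     # feature_set_xyz
    print(get_message_group_id(None))       # ml-pipeline-jobs (fallback)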
@@ -44,12 +96,24 @@ def submit_to_sqs(
     if not script_file.exists():
         raise FileNotFoundError(f"Script not found: {script_path}")
 
+    # Read script content and parse WORKBENCH_BATCH config
+    script_content = script_file.read_text()
+    batch_config = parse_workbench_batch(script_content)
+    group_id = get_message_group_id(batch_config)
+    outputs = (batch_config or {}).get("outputs", [])
+    inputs = (batch_config or {}).get("inputs", [])
+
     print(f"📄 Script: {script_file.name}")
     print(f"📏 Size tier: {size}")
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
     print(f"🪣 Bucket: {workbench_bucket}")
+    if outputs:
+        print(f"📤 Outputs: {outputs}")
+    if inputs:
+        print(f"📥 Inputs: {inputs}")
+    print(f"📦 Batch Group: {group_id}")
     sqs = AWSAccountClamp().boto3_session.client("sqs")
     script_name = script_file.name
 
@@ -75,7 +139,7 @@ def submit_to_sqs(
         print(f"   Destination: {s3_path}")
 
     try:
-        upload_content_to_s3(script_file.read_text(), s3_path)
+        upload_content_to_s3(script_content, s3_path)
         print("✅ Script uploaded successfully")
     except Exception as e:
         print(f"❌ Upload failed: {e}")
@@ -118,7 +182,7 @@ def submit_to_sqs(
     response = sqs.send_message(
         QueueUrl=queue_url,
         MessageBody=json.dumps(message, indent=2),
-        MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
+        MessageGroupId=group_id,  # From WORKBENCH_BATCH or default
     )
     message_id = response["MessageId"]
     print("✅ Message sent successfully!")
@@ -136,6 +200,11 @@ def submit_to_sqs(
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
+    if outputs:
+        print(f"📤 Outputs: {outputs}")
+    if inputs:
+        print(f"📥 Inputs: {inputs}")
+    print(f"📦 Batch Group: {group_id}")
     print(f"🆔 Message ID: {message_id}")
     print("\n🔍 MONITORING LOCATIONS:")
     print(f"  • SQS Queue: AWS Console → SQS → {queue_name}")
workbench/themes/light/custom.css

@@ -30,9 +30,10 @@ ul, ol {
     --ag-header-background-color: rgba(150, 150, 195);
 }
 
-/* Adjust cell background */
+/* Adjust cell background and text color */
 .ag-cell {
     background-color: rgb(240, 240, 240);
+    color: rgb(80, 80, 80);
 }
 
 /* Alternate row colors */
@@ -40,6 +41,11 @@ ul, ol {
     background-color: rgb(230, 230, 230);
 }
 
+/* AgGrid header text color */
+.ag-header-cell-text {
+    color: rgb(60, 60, 60);
+}
+
 /* Selection color for the entire row */
 .ag-row.ag-row-selected .ag-cell {
     background-color: rgba(170, 170, 205, 1.0);
workbench/themes/midnight_blue/custom.css

@@ -133,6 +133,40 @@ a:hover {
     color: rgb(100, 255, 100);
 }
 
+/* Dropdown styling (dcc.Dropdown) - override Bootstrap's --bs-body-bg variable */
+.dash-dropdown {
+    --bs-body-bg: rgb(55, 60, 90);
+    --bs-border-color: rgb(80, 85, 115);
+}
+
+
+/* Bootstrap form controls (dbc components) */
+.form-select, .form-control {
+    background-color: rgb(55, 60, 90) !important;
+    border: 1px solid rgb(80, 85, 115) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.form-select:focus, .form-control:focus {
+    background-color: rgb(60, 65, 95) !important;
+    border-color: rgb(100, 105, 140) !important;
+    box-shadow: 0 0 0 0.2rem rgba(100, 105, 140, 0.25) !important;
+}
+
+.dropdown-menu {
+    background-color: rgb(55, 60, 90) !important;
+    border: 1px solid rgb(80, 85, 115) !important;
+}
+
+.dropdown-item {
+    color: rgb(210, 210, 210) !important;
+}
+
+.dropdown-item:hover, .dropdown-item:focus {
+    background-color: rgb(70, 75, 110) !important;
+    color: rgb(230, 230, 230) !important;
+}
+
 /* Table styling */
 table {
     width: 100%;