workbench 0.8.213__py3-none-any.whl → 0.8.217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/api/__init__.py +3 -0
  9. workbench/api/endpoint.py +10 -5
  10. workbench/api/feature_set.py +76 -6
  11. workbench/api/meta_model.py +289 -0
  12. workbench/api/model.py +43 -4
  13. workbench/core/artifacts/endpoint_core.py +63 -115
  14. workbench/core/artifacts/feature_set_core.py +1 -1
  15. workbench/core/artifacts/model_core.py +6 -4
  16. workbench/core/pipelines/pipeline_executor.py +1 -1
  17. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  18. workbench/model_script_utils/pytorch_utils.py +11 -1
  19. workbench/model_scripts/chemprop/chemprop.template +145 -69
  20. workbench/model_scripts/chemprop/generated_model_script.py +147 -71
  21. workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
  22. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  23. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  24. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  25. workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
  26. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  27. workbench/model_scripts/meta_model/meta_model.template +209 -0
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
  29. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  30. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  31. workbench/model_scripts/script_generation.py +4 -0
  32. workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
  33. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  34. workbench/repl/workbench_shell.py +0 -5
  35. workbench/scripts/endpoint_test.py +2 -2
  36. workbench/utils/chem_utils/fingerprints.py +7 -3
  37. workbench/utils/chemprop_utils.py +23 -5
  38. workbench/utils/meta_model_simulator.py +471 -0
  39. workbench/utils/metrics_utils.py +94 -10
  40. workbench/utils/model_utils.py +91 -9
  41. workbench/utils/pytorch_utils.py +1 -1
  42. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  43. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
  44. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
  45. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  47. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
  48. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
  49. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
  50. {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
@@ -7,39 +7,30 @@
 # - Sample weights support
 # - Categorical feature handling
 # - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.

-import argparse
 import json
 import os

-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import xgboost as xgb
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder

 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     convert_categorical_types,
     decompress_features,
     expand_proba_column,
     input_fn,
     match_features_case_insensitive,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 from uq_harness import (
     compute_confidence,
     load_uq_models,
     predict_intervals,
-    save_uq_models,
-    train_uq_models,
 )

 # =============================================================================
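The import restructuring above follows a simple rule: only what the serving handlers need stays at module scope; everything else moves inside the `__main__` guard so a cold serverless endpoint never pays for sklearn or awswrangler. A minimal sketch of the pattern (module names here are illustrative, not the template's actual dependency list):

    import json  # light import: needed on the serving path

    def predict_fn(payload: dict) -> str:
        # Inference code sees only the module-scope (light) imports
        return json.dumps(payload)

    if __name__ == "__main__":
        # Heavy imports are paid only when the script runs as a training job
        from sklearn.model_selection import train_test_split  # illustrative heavy import
        print(train_test_split([1, 2, 3, 4], test_size=0.5))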
@@ -49,25 +40,27 @@ DEFAULT_HYPERPARAMETERS = {
     # Training parameters
     "n_folds": 5,  # Number of CV folds (1 = single train/val split)
     # Core tree parameters
-    "n_estimators": 200,
-    "max_depth": 6,
+    "n_estimators": 300,
+    "max_depth": 7,
     "learning_rate": 0.05,
-    # Sampling parameters
-    "subsample": 0.7,
-    "colsample_bytree": 0.6,
-    "colsample_bylevel": 0.8,
-    # Regularization
-    "min_child_weight": 5,
-    "gamma": 0.2,
-    "reg_alpha": 0.5,
-    "reg_lambda": 2.0,
+    # Sampling parameters (less aggressive - ensemble provides regularization)
+    "subsample": 0.8,
+    "colsample_bytree": 0.8,
+    # Regularization (lighter - ensemble averaging reduces overfitting)
+    "min_child_weight": 3,
+    "gamma": 0.1,
+    "reg_alpha": 0.1,
+    "reg_lambda": 1.0,
     # Random seed
-    "random_state": 42,
+    "seed": 42,
 }

 # Workbench-specific parameters (not passed to XGBoost)
 WORKBENCH_PARAMS = {"n_folds"}

+# Regression-only parameters (filtered out for classifiers)
+REGRESSION_ONLY_PARAMS = {"objective"}
+
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "{{model_type}}",
@@ -80,10 +73,140 @@ TEMPLATE_PARAMS = {
 }


+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load XGBoost ensemble from the specified directory."""
+    # Load ensemble metadata
+    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
+    if os.path.exists(metadata_path):
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+        n_ensemble = metadata["n_ensemble"]
+    else:
+        n_ensemble = 1  # Legacy single model
+
+    # Load ensemble models
+    ensemble_models = []
+    for i in range(n_ensemble):
+        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
+        if not os.path.exists(model_path):
+            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
+        ensemble_models.append(joblib.load(model_path))
+
+    print(f"Loaded {len(ensemble_models)} model(s)")
+
+    # Load label encoder (classifier only)
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    # Load category mappings
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as f:
+            category_mappings = json.load(f)
+
+    # Load UQ models (regression only)
+    uq_models, uq_metadata = None, None
+    uq_path = os.path.join(model_dir, "uq_metadata.json")
+    if os.path.exists(uq_path):
+        uq_models, uq_metadata = load_uq_models(model_dir)
+
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": n_ensemble,
+        "label_encoder": label_encoder,
+        "category_mappings": category_mappings,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+    }
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with XGBoost ensemble."""
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as f:
+        features = json.load(f)
+    print(f"Model Features: {features}")
+
+    # Extract model components
+    ensemble_models = model_dict["ensemble_models"]
+    label_encoder = model_dict.get("label_encoder")
+    category_mappings = model_dict.get("category_mappings", {})
+    uq_models = model_dict.get("uq_models")
+    uq_metadata = model_dict.get("uq_metadata")
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+
+    # Prepare features
+    matched_df = match_features_case_insensitive(df, features)
+    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
+
+    if compressed_features:
+        print("Decompressing features for prediction...")
+        matched_df, features = decompress_features(matched_df, features, compressed_features)
+
+    X = matched_df[features]
+
+    # Collect ensemble predictions
+    all_preds = [m.predict(X) for m in ensemble_models]
+    ensemble_preds = np.stack(all_preds, axis=0)
+
+    if label_encoder is not None:
+        # Classification: average probabilities, then argmax
+        all_probs = [m.predict_proba(X) for m in ensemble_models]
+        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
+        class_preds = np.argmax(avg_probs, axis=1)
+
+        df["prediction"] = label_encoder.inverse_transform(class_preds)
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
+        df = expand_proba_column(df, label_encoder.classes_)
+    else:
+        # Regression: average predictions
+        df["prediction"] = np.mean(ensemble_preds, axis=0)
+        df["prediction_std"] = np.std(ensemble_preds, axis=0)
+
+        # Add UQ intervals if available
+        if uq_models and uq_metadata:
+            df = predict_intervals(df, X, uq_models, uq_metadata)
+            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
+    return df
+
+
 # =============================================================================
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
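For orientation, SageMaker's Python serving stack invokes these handlers in a fixed order: model_fn once at container startup, then input_fn → predict_fn → output_fn per request, which is why the handlers now sit above the training block. A hedged local smoke test of the relocated handlers might look like this (the artifact directory and CSV payload are assumptions, not part of the diff):

    import os
    import pandas as pd

    # Hypothetical local check: assumes model.tar.gz was extracted to ./model_artifacts
    os.environ["SM_MODEL_DIR"] = os.path.abspath("model_artifacts")

    model_dict = model_fn(os.environ["SM_MODEL_DIR"])  # startup: load ensemble + metadata
    df = pd.read_csv("eval_rows.csv")                  # stand-in for input_fn's parsed payload
    preds = predict_fn(df, model_dict)                 # per-request: ensemble predict
    print(preds["prediction"].head())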
@@ -123,7 +246,7 @@ if __name__ == "__main__":
         all_df, features = decompress_features(all_df, features, compressed_features)

     # -------------------------------------------------------------------------
-    # Classification setup: Encode target labels
+    # Classification setup
     # -------------------------------------------------------------------------
     label_encoder = None
     if model_type == "classifier":
@@ -136,6 +259,18 @@
     # -------------------------------------------------------------------------
     n_folds = hyperparameters["n_folds"]
     xgb_params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
+
+    # Map 'seed' to 'random_state' for XGBoost
+    if "seed" in xgb_params:
+        xgb_params["random_state"] = xgb_params.pop("seed")
+
+    # Handle objective: filter regression-only params for classifiers, set default for regressors
+    if model_type == "classifier":
+        xgb_params = {k: v for k, v in xgb_params.items() if k not in REGRESSION_ONLY_PARAMS}
+    else:
+        # Default to MAE (reg:absoluteerror) for regression if not specified
+        xgb_params.setdefault("objective", "reg:absoluteerror")
+
     print(f"XGBoost params: {xgb_params}")

     if n_folds == 1:
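Tracing the new parameter handling with the defaults above: n_folds is stripped as a Workbench-only key, seed is renamed to random_state for XGBoost's sklearn API, and regressors get reg:absoluteerror (MAE) unless an objective was supplied, while classifiers drop regression-only keys. A standalone sketch of that logic:

    DEFAULTS = {"n_folds": 5, "n_estimators": 300, "seed": 42}
    WORKBENCH_PARAMS = {"n_folds"}
    REGRESSION_ONLY_PARAMS = {"objective"}

    def shape_params(hyperparameters: dict, model_type: str) -> dict:
        # Strip Workbench-only keys, then normalize for XGBoost
        params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
        if "seed" in params:
            params["random_state"] = params.pop("seed")
        if model_type == "classifier":
            params = {k: v for k, v in params.items() if k not in REGRESSION_ONLY_PARAMS}
        else:
            params.setdefault("objective", "reg:absoluteerror")  # default to MAE
        return params

    print(shape_params(DEFAULTS, "classifier"))  # {'n_estimators': 300, 'random_state': 42}
    print(shape_params(DEFAULTS, "regressor"))   # adds 'objective': 'reg:absoluteerror'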
@@ -285,12 +420,10 @@
     # -------------------------------------------------------------------------
     # Save model artifacts
     # -------------------------------------------------------------------------
-    # Ensemble models
-    for idx, ens_model in enumerate(ensemble_models):
-        joblib.dump(ens_model, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
-    print(f"Saved {len(ensemble_models)} XGBoost model(s)")
+    for idx, m in enumerate(ensemble_models):
+        joblib.dump(m, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
+    print(f"Saved {len(ensemble_models)} model(s)")

-    # Metadata files
     with open(os.path.join(args.model_dir, "ensemble_metadata.json"), "w") as f:
         json.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, f)

@@ -310,125 +443,3 @@
         save_uq_models(uq_models, uq_metadata, args.model_dir)

     print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
-
-
-# =============================================================================
-# Model Loading (for SageMaker inference)
-# =============================================================================
-def model_fn(model_dir: str) -> dict:
-    """Load XGBoost ensemble and associated artifacts.
-
-    Args:
-        model_dir: Directory containing model artifacts
-
-    Returns:
-        Dictionary with ensemble_models, label_encoder, category_mappings, uq_models, etc.
-    """
-    # Load ensemble metadata
-    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
-    if os.path.exists(metadata_path):
-        with open(metadata_path) as f:
-            metadata = json.load(f)
-        n_ensemble = metadata["n_ensemble"]
-    else:
-        n_ensemble = 1  # Legacy single model
-
-    # Load ensemble models
-    ensemble_models = []
-    for i in range(n_ensemble):
-        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
-        if not os.path.exists(model_path):
-            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
-        ensemble_models.append(joblib.load(model_path))
-
-    # Load label encoder (classifier only)
-    label_encoder = None
-    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
-    if os.path.exists(encoder_path):
-        label_encoder = joblib.load(encoder_path)
-
-    # Load category mappings
-    category_mappings = {}
-    category_path = os.path.join(model_dir, "category_mappings.json")
-    if os.path.exists(category_path):
-        with open(category_path) as f:
-            category_mappings = json.load(f)
-
-    # Load UQ models (regression only)
-    uq_models, uq_metadata = None, None
-    uq_path = os.path.join(model_dir, "uq_metadata.json")
-    if os.path.exists(uq_path):
-        uq_models, uq_metadata = load_uq_models(model_dir)
-
-    return {
-        "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
-        "label_encoder": label_encoder,
-        "category_mappings": category_mappings,
-        "uq_models": uq_models,
-        "uq_metadata": uq_metadata,
-    }
-
-
-# =============================================================================
-# Inference (for SageMaker inference)
-# =============================================================================
-def predict_fn(df: pd.DataFrame, models: dict) -> pd.DataFrame:
-    """Make predictions with XGBoost ensemble.
-
-    Args:
-        df: Input DataFrame with features
-        models: Dictionary from model_fn containing ensemble and metadata
-
-    Returns:
-        DataFrame with predictions added
-    """
-    # Load feature columns
-    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-    with open(os.path.join(model_dir, "feature_columns.json")) as f:
-        features = json.load(f)
-    print(f"Model Features: {features}")
-
-    # Extract model components
-    ensemble_models = models["ensemble_models"]
-    label_encoder = models.get("label_encoder")
-    category_mappings = models.get("category_mappings", {})
-    uq_models = models.get("uq_models")
-    uq_metadata = models.get("uq_metadata")
-    compressed_features = TEMPLATE_PARAMS["compressed_features"]
-
-    # Prepare features
-    matched_df = match_features_case_insensitive(df, features)
-    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
-
-    if compressed_features:
-        print("Decompressing features for prediction...")
-        matched_df, features = decompress_features(matched_df, features, compressed_features)
-
-    X = matched_df[features]
-
-    # Collect ensemble predictions
-    all_preds = [m.predict(X) for m in ensemble_models]
-    ensemble_preds = np.stack(all_preds, axis=0)
-
-    if label_encoder is not None:
-        # Classification: average probabilities, then argmax
-        all_probs = [m.predict_proba(X) for m in ensemble_models]
-        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
-        class_preds = np.argmax(avg_probs, axis=1)
-
-        df["prediction"] = label_encoder.inverse_transform(class_preds)
-        df["pred_proba"] = [p.tolist() for p in avg_probs]
-        df = expand_proba_column(df, label_encoder.classes_)
-    else:
-        # Regression: average predictions
-        df["prediction"] = np.mean(ensemble_preds, axis=0)
-        df["prediction_std"] = np.std(ensemble_preds, axis=0)
-
-        # Add UQ intervals if available
-        if uq_models and uq_metadata:
-            df = predict_intervals(df, X, uq_models, uq_metadata)
-            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
-
-    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
-    return df
@@ -302,11 +302,6 @@ class WorkbenchShell:
         self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
         self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline

-        # Algorithms
-        self.commands["FSP"] = importlib.import_module(
-            "workbench.algorithms.dataframe.feature_space_proximity"
-        ).FeatureSpaceProximity
-
         # These are 'nice to have' imports
         self.commands["pd"] = importlib.import_module("pandas")
         self.commands["wr"] = importlib.import_module("awswrangler")
@@ -5,7 +5,7 @@ Usage:
    python model_script_harness.py <local_script.py> <model_name>

Example:
-    python model_script_harness.py pytorch.py aqsol-pytorch-reg
+    python model_script_harness.py pytorch.py aqsol-reg-pytorch

This allows you to test LOCAL changes to a model script against deployed model artifacts.
Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
@@ -72,7 +72,7 @@ def main():
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print("  local_script.py - Path to your LOCAL model script to test")
-        print("  model_name      - Workbench model name (e.g., aqsol-pytorch-reg)")
+        print("  model_name      - Workbench model name (e.g., aqsol-reg-pytorch)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)

@@ -4,10 +4,14 @@ import logging
 import pandas as pd

 # Molecular Descriptor Imports
-from rdkit import Chem
+from rdkit import Chem, RDLogger
 from rdkit.Chem import rdFingerprintGenerator
 from rdkit.Chem.MolStandardize import rdMolStandardize

+# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
+# Keep errors enabled so we see actual problems
+RDLogger.DisableLog("rdApp.warning")
+
 # Set up the logger
 log = logging.getLogger("workbench")

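One caveat worth noting: RDKit's logger is process-global, so the DisableLog call above silences warnings for every caller that imports this module. Callers that want warnings back can re-enable them; a minimal sketch:

    from rdkit import Chem, RDLogger

    RDLogger.DisableLog("rdApp.warning")       # warnings off; rdApp.error stays on
    mol = Chem.MolFromSmiles("not-a-smiles")   # parse failure still logs an error
    print(mol)                                 # None

    RDLogger.EnableLog("rdApp.warning")        # restore warnings for downstream code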
@@ -47,8 +51,8 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Make sure our molecules are not None
     failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
     if failed_smiles:
-        log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
-        df = df.dropna(subset=["molecule"])
+        log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
+        df = df.dropna(subset=["molecule"]).copy()

     # If we have fragments in our compounds, get the largest fragment before computing fingerprints
     largest_frags = df["molecule"].apply(
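The added .copy() is presumably there because assigning new columns to the frame returned by dropna() can trigger pandas' SettingWithCopyWarning; copying makes ownership explicit before the fragment and fingerprint columns are added. A small illustration of the pattern (toy data):

    import pandas as pd

    df = pd.DataFrame({"molecule": [object(), None], "smiles": ["CCO", "bad"]})
    clean = df.dropna(subset=["molecule"]).copy()  # own the data before mutating
    clean["fingerprint"] = "placeholder"           # safe: no chained-assignment warning
    print(len(clean))                              # 1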
@@ -76,6 +76,10 @@ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
     This retrieves the validation predictions saved during model training and
     computes metrics directly from them.

+    Note:
+        - Regression: Supports both single-target and multi-target models
+        - Classification: Only single-target is supported (with any number of classes)
+
     Args:
         workbench_model: Workbench model object

@@ -84,6 +88,7 @@ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
         - DataFrame with computed metrics
         - DataFrame with validation predictions
     """
+
     # Get the validation predictions from S3
     s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
     predictions_df = pull_s3_data(s3_path)
@@ -93,14 +98,27 @@

     log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")

-    # Compute metrics from predictions
+    # Get target and class labels
     target = workbench_model.target()
     class_labels = workbench_model.class_labels()

-    if target in predictions_df.columns and "prediction" in predictions_df.columns:
+    # If single target just use the "prediction" column
+    if isinstance(target, str):
         metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
-    else:
-        metrics_df = pd.DataFrame()
+        return metrics_df, predictions_df
+
+    # Multi-target regression
+    metrics_list = []
+    for t in target:
+        # Prediction will be {target}_pred in multi-target case
+        pred_col = f"{t}_pred"
+
+        # Drop NaNs for this target
+        target_preds_df = predictions_df.dropna(subset=[t, pred_col])
+        metrics_df = compute_metrics_from_predictions(target_preds_df, t, class_labels, prediction_col=pred_col)
+        metrics_df.insert(0, "target", t)
+        metrics_list.append(metrics_df)
+    metrics_df = pd.concat(metrics_list, ignore_index=True) if metrics_list else pd.DataFrame()

     return metrics_df, predictions_df

@@ -111,7 +129,7 @@ if __name__ == "__main__":
    from workbench.api import Model

    # Initialize Workbench model
-    model_name = "logd-reg-chemprop"
+    model_name = "open-admet-chemprop-mt"
    print(f"Loading Workbench model: {model_name}")
    model = Model(model_name)
    print(f"Model Framework: {model.model_framework}")