workbench-0.8.213-py3-none-any.whl → workbench-0.8.217-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
- workbench/algorithms/dataframe/projection_2d.py +38 -21
- workbench/algorithms/dataframe/proximity.py +75 -150
- workbench/algorithms/graph/light/proximity_graph.py +5 -5
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +2 -2
- workbench/api/__init__.py +3 -0
- workbench/api/endpoint.py +10 -5
- workbench/api/feature_set.py +76 -6
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +43 -4
- workbench/core/artifacts/endpoint_core.py +63 -115
- workbench/core/artifacts/feature_set_core.py +1 -1
- workbench/core/artifacts/model_core.py +6 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
- workbench/model_script_utils/pytorch_utils.py +11 -1
- workbench/model_scripts/chemprop/chemprop.template +145 -69
- workbench/model_scripts/chemprop/generated_model_script.py +147 -71
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +7 -3
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +6 -6
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +42 -24
- workbench/model_scripts/pytorch_model/pytorch.template +42 -24
- workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
- workbench/model_scripts/script_generation.py +4 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +169 -158
- workbench/model_scripts/xgb_model/xgb_model.template +163 -152
- workbench/repl/workbench_shell.py +0 -5
- workbench/scripts/endpoint_test.py +2 -2
- workbench/utils/chem_utils/fingerprints.py +7 -3
- workbench/utils/chemprop_utils.py +23 -5
- workbench/utils/meta_model_simulator.py +471 -0
- workbench/utils/metrics_utils.py +94 -10
- workbench/utils/model_utils.py +91 -9
- workbench/utils/pytorch_utils.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +4 -8
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/METADATA +2 -1
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/RECORD +48 -43
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/WHEEL +0 -0
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.213.dist-info → workbench-0.8.217.dist-info}/top_level.txt +0 -0
@@ -7,39 +7,30 @@
 # - Sample weights support
 # - Categorical feature handling
 # - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.
 
-import argparse
 import json
 import os
 
-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import xgboost as xgb
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder
 
 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     convert_categorical_types,
     decompress_features,
     expand_proba_column,
     input_fn,
     match_features_case_insensitive,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 from uq_harness import (
     compute_confidence,
     load_uq_models,
     predict_intervals,
-    save_uq_models,
-    train_uq_models,
 )
 
 # =============================================================================
@@ -49,25 +40,27 @@ DEFAULT_HYPERPARAMETERS = {
     # Training parameters
    "n_folds": 5,  # Number of CV folds (1 = single train/val split)
     # Core tree parameters
-    "n_estimators":
-    "max_depth":
+    "n_estimators": 300,
+    "max_depth": 7,
     "learning_rate": 0.05,
-    # Sampling parameters
-    "subsample": 0.
-    "colsample_bytree": 0.
-
-
-    "
-    "
-    "
-    "reg_lambda": 2.0,
+    # Sampling parameters (less aggressive - ensemble provides regularization)
+    "subsample": 0.8,
+    "colsample_bytree": 0.8,
+    # Regularization (lighter - ensemble averaging reduces overfitting)
+    "min_child_weight": 3,
+    "gamma": 0.1,
+    "reg_alpha": 0.1,
+    "reg_lambda": 1.0,
     # Random seed
-    "
+    "seed": 42,
 }
 
 # Workbench-specific parameters (not passed to XGBoost)
 WORKBENCH_PARAMS = {"n_folds"}
 
+# Regression-only parameters (filtered out for classifiers)
+REGRESSION_ONLY_PARAMS = {"objective"}
+
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "{{model_type}}",
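
The split between defaults and Workbench-only keys matters at training time. A minimal sketch of how the pieces might fit together (the merge with user-supplied hyperparameters is an assumption; the WORKBENCH_PARAMS filter appears verbatim in the training hunk further down):

    user_params = {"max_depth": 5, "n_folds": 3}
    hyperparameters = {**DEFAULT_HYPERPARAMETERS, **user_params}  # assumed merge order
    # Workbench-only keys (e.g. "n_folds") never reach XGBoost
    xgb_params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}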
@@ -80,10 +73,140 @@ TEMPLATE_PARAMS = {
 }
 
 
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load XGBoost ensemble from the specified directory."""
+    # Load ensemble metadata
+    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
+    if os.path.exists(metadata_path):
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+        n_ensemble = metadata["n_ensemble"]
+    else:
+        n_ensemble = 1  # Legacy single model
+
+    # Load ensemble models
+    ensemble_models = []
+    for i in range(n_ensemble):
+        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
+        if not os.path.exists(model_path):
+            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
+        ensemble_models.append(joblib.load(model_path))
+
+    print(f"Loaded {len(ensemble_models)} model(s)")
+
+    # Load label encoder (classifier only)
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    # Load category mappings
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as f:
+            category_mappings = json.load(f)
+
+    # Load UQ models (regression only)
+    uq_models, uq_metadata = None, None
+    uq_path = os.path.join(model_dir, "uq_metadata.json")
+    if os.path.exists(uq_path):
+        uq_models, uq_metadata = load_uq_models(model_dir)
+
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": n_ensemble,
+        "label_encoder": label_encoder,
+        "category_mappings": category_mappings,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+    }
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with XGBoost ensemble."""
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as f:
+        features = json.load(f)
+    print(f"Model Features: {features}")
+
+    # Extract model components
+    ensemble_models = model_dict["ensemble_models"]
+    label_encoder = model_dict.get("label_encoder")
+    category_mappings = model_dict.get("category_mappings", {})
+    uq_models = model_dict.get("uq_models")
+    uq_metadata = model_dict.get("uq_metadata")
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+
+    # Prepare features
+    matched_df = match_features_case_insensitive(df, features)
+    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
+
+    if compressed_features:
+        print("Decompressing features for prediction...")
+        matched_df, features = decompress_features(matched_df, features, compressed_features)
+
+    X = matched_df[features]
+
+    # Collect ensemble predictions
+    all_preds = [m.predict(X) for m in ensemble_models]
+    ensemble_preds = np.stack(all_preds, axis=0)
+
+    if label_encoder is not None:
+        # Classification: average probabilities, then argmax
+        all_probs = [m.predict_proba(X) for m in ensemble_models]
+        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
+        class_preds = np.argmax(avg_probs, axis=1)
+
+        df["prediction"] = label_encoder.inverse_transform(class_preds)
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
+        df = expand_proba_column(df, label_encoder.classes_)
+    else:
+        # Regression: average predictions
+        df["prediction"] = np.mean(ensemble_preds, axis=0)
+        df["prediction_std"] = np.std(ensemble_preds, axis=0)
+
+    # Add UQ intervals if available
+    if uq_models and uq_metadata:
+        df = predict_intervals(df, X, uq_models, uq_metadata)
+        df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
+    return df
+
+
 # =============================================================================
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
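
Since model_fn and predict_fn now live at module level, they can be exercised locally without SageMaker. A hedged sketch (importing the generated script as a module; the /tmp/model path and eval.csv file are placeholders, not part of the package):

    import os
    import pandas as pd

    os.environ["SM_MODEL_DIR"] = "/tmp/model"  # predict_fn reads feature_columns.json from here
    model_dict = model_fn("/tmp/model")        # loads ensemble, encoders, and UQ artifacts
    df = predict_fn(pd.read_csv("eval.csv"), model_dict)
    print(df[["prediction", "prediction_std"]].head())  # std = ensemble spread (regression only)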
@@ -123,7 +246,7 @@ if __name__ == "__main__":
     all_df, features = decompress_features(all_df, features, compressed_features)
 
     # -------------------------------------------------------------------------
-    # Classification setup
+    # Classification setup
     # -------------------------------------------------------------------------
     label_encoder = None
     if model_type == "classifier":
@@ -136,6 +259,18 @@ if __name__ == "__main__":
     # -------------------------------------------------------------------------
     n_folds = hyperparameters["n_folds"]
     xgb_params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
+
+    # Map 'seed' to 'random_state' for XGBoost
+    if "seed" in xgb_params:
+        xgb_params["random_state"] = xgb_params.pop("seed")
+
+    # Handle objective: filter regression-only params for classifiers, set default for regressors
+    if model_type == "classifier":
+        xgb_params = {k: v for k, v in xgb_params.items() if k not in REGRESSION_ONLY_PARAMS}
+    else:
+        # Default to MAE (reg:absoluteerror) for regression if not specified
+        xgb_params.setdefault("objective", "reg:absoluteerror")
+
     print(f"XGBoost params: {xgb_params}")
 
     if n_folds == 1:
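
The new regression default trains on absolute error, which is more robust to outliers than the usual squared-error objective. A sketch of the estimator this mapping produces (assuming an XGBoost version with reg:absoluteerror, i.e. 1.7+):

    import xgboost as xgb

    # "seed": 42 becomes random_state=42; the objective comes from the setdefault above
    model = xgb.XGBRegressor(objective="reg:absoluteerror", random_state=42, max_depth=7)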
@@ -285,12 +420,10 @@ if __name__ == "__main__":
     # -------------------------------------------------------------------------
     # Save model artifacts
     # -------------------------------------------------------------------------
-
-
-
-    print(f"Saved {len(ensemble_models)} XGBoost model(s)")
+    for idx, m in enumerate(ensemble_models):
+        joblib.dump(m, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
+    print(f"Saved {len(ensemble_models)} model(s)")
 
-    # Metadata files
     with open(os.path.join(args.model_dir, "ensemble_metadata.json"), "w") as f:
         json.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, f)
 
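
Pieced together from the save calls here and the load paths in model_fn, a 5-fold model directory would look roughly like this:

    model/
    ├── xgb_model_0.joblib ... xgb_model_4.joblib   # one model per CV fold
    ├── ensemble_metadata.json                      # {"n_ensemble": 5, "n_folds": 5}
    ├── feature_columns.json
    ├── label_encoder.joblib                        # classifiers only
    ├── category_mappings.json
    └── uq_metadata.json                            # regression UQ only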
@@ -310,125 +443,3 @@ if __name__ == "__main__":
     save_uq_models(uq_models, uq_metadata, args.model_dir)
 
     print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
-
-
-# =============================================================================
-# Model Loading (for SageMaker inference)
-# =============================================================================
-def model_fn(model_dir: str) -> dict:
-    """Load XGBoost ensemble and associated artifacts.
-
-    Args:
-        model_dir: Directory containing model artifacts
-
-    Returns:
-        Dictionary with ensemble_models, label_encoder, category_mappings, uq_models, etc.
-    """
-    # Load ensemble metadata
-    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
-    if os.path.exists(metadata_path):
-        with open(metadata_path) as f:
-            metadata = json.load(f)
-        n_ensemble = metadata["n_ensemble"]
-    else:
-        n_ensemble = 1  # Legacy single model
-
-    # Load ensemble models
-    ensemble_models = []
-    for i in range(n_ensemble):
-        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
-        if not os.path.exists(model_path):
-            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
-        ensemble_models.append(joblib.load(model_path))
-
-    # Load label encoder (classifier only)
-    label_encoder = None
-    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
-    if os.path.exists(encoder_path):
-        label_encoder = joblib.load(encoder_path)
-
-    # Load category mappings
-    category_mappings = {}
-    category_path = os.path.join(model_dir, "category_mappings.json")
-    if os.path.exists(category_path):
-        with open(category_path) as f:
-            category_mappings = json.load(f)
-
-    # Load UQ models (regression only)
-    uq_models, uq_metadata = None, None
-    uq_path = os.path.join(model_dir, "uq_metadata.json")
-    if os.path.exists(uq_path):
-        uq_models, uq_metadata = load_uq_models(model_dir)
-
-    return {
-        "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
-        "label_encoder": label_encoder,
-        "category_mappings": category_mappings,
-        "uq_models": uq_models,
-        "uq_metadata": uq_metadata,
-    }
-
-
-# =============================================================================
-# Inference (for SageMaker inference)
-# =============================================================================
-def predict_fn(df: pd.DataFrame, models: dict) -> pd.DataFrame:
-    """Make predictions with XGBoost ensemble.
-
-    Args:
-        df: Input DataFrame with features
-        models: Dictionary from model_fn containing ensemble and metadata
-
-    Returns:
-        DataFrame with predictions added
-    """
-    # Load feature columns
-    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-    with open(os.path.join(model_dir, "feature_columns.json")) as f:
-        features = json.load(f)
-    print(f"Model Features: {features}")
-
-    # Extract model components
-    ensemble_models = models["ensemble_models"]
-    label_encoder = models.get("label_encoder")
-    category_mappings = models.get("category_mappings", {})
-    uq_models = models.get("uq_models")
-    uq_metadata = models.get("uq_metadata")
-    compressed_features = TEMPLATE_PARAMS["compressed_features"]
-
-    # Prepare features
-    matched_df = match_features_case_insensitive(df, features)
-    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
-
-    if compressed_features:
-        print("Decompressing features for prediction...")
-        matched_df, features = decompress_features(matched_df, features, compressed_features)
-
-    X = matched_df[features]
-
-    # Collect ensemble predictions
-    all_preds = [m.predict(X) for m in ensemble_models]
-    ensemble_preds = np.stack(all_preds, axis=0)
-
-    if label_encoder is not None:
-        # Classification: average probabilities, then argmax
-        all_probs = [m.predict_proba(X) for m in ensemble_models]
-        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
-        class_preds = np.argmax(avg_probs, axis=1)
-
-        df["prediction"] = label_encoder.inverse_transform(class_preds)
-        df["pred_proba"] = [p.tolist() for p in avg_probs]
-        df = expand_proba_column(df, label_encoder.classes_)
-    else:
-        # Regression: average predictions
-        df["prediction"] = np.mean(ensemble_preds, axis=0)
-        df["prediction_std"] = np.std(ensemble_preds, axis=0)
-
-    # Add UQ intervals if available
-    if uq_models and uq_metadata:
-        df = predict_intervals(df, X, uq_models, uq_metadata)
-        df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
-
-    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
-    return df
@@ -302,11 +302,6 @@ class WorkbenchShell:
         self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
         self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
 
-        # Algorithms
-        self.commands["FSP"] = importlib.import_module(
-            "workbench.algorithms.dataframe.feature_space_proximity"
-        ).FeatureSpaceProximity
-
         # These are 'nice to have' imports
         self.commands["pd"] = importlib.import_module("pandas")
         self.commands["wr"] = importlib.import_module("awswrangler")
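
The FSP shortcut is gone from the shell, but the class remains available via a direct import (module path taken verbatim from the removed lines):

    from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity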
@@ -5,7 +5,7 @@ Usage:
    python model_script_harness.py <local_script.py> <model_name>
 
 Example:
-    python model_script_harness.py pytorch.py aqsol-pytorch
+    python model_script_harness.py pytorch.py aqsol-reg-pytorch
 
 This allows you to test LOCAL changes to a model script against deployed model artifacts.
 Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
@@ -72,7 +72,7 @@ def main():
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print("  local_script.py - Path to your LOCAL model script to test")
-       print("  model_name - Workbench model name (e.g., aqsol-pytorch
+       print("  model_name - Workbench model name (e.g., aqsol-reg-pytorch)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)
 
@@ -4,10 +4,14 @@ import logging
 import pandas as pd
 
 # Molecular Descriptor Imports
-from rdkit import Chem
+from rdkit import Chem, RDLogger
 from rdkit.Chem import rdFingerprintGenerator
 from rdkit.Chem.MolStandardize import rdMolStandardize
 
+# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
+# Keep errors enabled so we see actual problems
+RDLogger.DisableLog("rdApp.warning")
+
 # Set up the logger
 log = logging.getLogger("workbench")
 
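
RDKit's logger is hierarchical, so disabling only "rdApp.warning" leaves errors visible. A quick sketch of the distinction:

    from rdkit import Chem, RDLogger

    RDLogger.DisableLog("rdApp.warning")      # silences warnings only
    mol = Chem.MolFromSmiles("not-a-smiles")  # parse failure still logs via rdApp.error
    assert mol is None
    # RDLogger.DisableLog("rdApp.*") would silence everything, errors included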
@@ -47,8 +51,8 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Make sure our molecules are not None
     failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
     if failed_smiles:
-        log.
-        df = df.dropna(subset=["molecule"])
+        log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
+        df = df.dropna(subset=["molecule"]).copy()
 
     # If we have fragments in our compounds, get the largest fragment before computing fingerprints
     largest_frags = df["molecule"].apply(
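
The added .copy() matters because dropna() can return a view-like slice; assigning new columns to it later (as the fingerprint code does) would trigger pandas' SettingWithCopyWarning. A minimal illustration:

    import pandas as pd

    df = pd.DataFrame({"molecule": [None, "mol"], "smiles": ["bad", "CCO"]})
    clean = df.dropna(subset=["molecule"]).copy()
    clean["fingerprint"] = ["..."]  # safe: clean owns its own data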
@@ -76,6 +76,10 @@ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
     This retrieves the validation predictions saved during model training and
     computes metrics directly from them.
 
+    Note:
+        - Regression: Supports both single-target and multi-target models
+        - Classification: Only single-target is supported (with any number of classes)
+
     Args:
         workbench_model: Workbench model object
 
@@ -84,6 +88,7 @@ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
         - DataFrame with computed metrics
         - DataFrame with validation predictions
     """
+
     # Get the validation predictions from S3
     s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
     predictions_df = pull_s3_data(s3_path)
@@ -93,14 +98,27 @@
 
     log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
 
-    #
+    # Get target and class labels
     target = workbench_model.target()
     class_labels = workbench_model.class_labels()
 
-
+    # If single target just use the "prediction" column
+    if isinstance(target, str):
         metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
-
-
+        return metrics_df, predictions_df
+
+    # Multi-target regression
+    metrics_list = []
+    for t in target:
+        # Prediction will be {target}_pred in multi-target case
+        pred_col = f"{t}_pred"
+
+        # Drop NaNs for this target
+        target_preds_df = predictions_df.dropna(subset=[t, pred_col])
+        metrics_df = compute_metrics_from_predictions(target_preds_df, t, class_labels, prediction_col=pred_col)
+        metrics_df.insert(0, "target", t)
+        metrics_list.append(metrics_df)
+    metrics_df = pd.concat(metrics_list, ignore_index=True) if metrics_list else pd.DataFrame()
 
     return metrics_df, predictions_df
 
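
For multi-target models the predictions CSV carries one actual column and one "<target>_pred" column per target, and NaNs are dropped per target rather than per row. A sketch with hypothetical target names:

    import pandas as pd

    preds = pd.DataFrame({
        "logd": [1.2, 0.8], "logd_pred": [1.1, 0.9],
        "sol": [None, -3.0], "sol_pred": [None, -2.8],
    })
    for t in ["logd", "sol"]:
        valid = preds.dropna(subset=[t, f"{t}_pred"])  # per-target NaN filtering
        print(t, len(valid))  # logd 2, sol 1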
@@ -111,7 +129,7 @@ if __name__ == "__main__":
     from workbench.api import Model
 
     # Initialize Workbench model
-    model_name = "
+    model_name = "open-admet-chemprop-mt"
     print(f"Loading Workbench model: {model_name}")
     model = Model(model_name)
     print(f"Model Framework: {model.model_framework}")