workbench 0.8.162-py3-none-any.whl → 0.8.220-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0

workbench/utils/xgboost_model_utils.py

```diff
@@ -1,30 +1,22 @@
 """XGBoost Model Utilities"""
 
+import glob
+import hashlib
 import logging
 import os
-import tempfile
-import tarfile
 import pickle
-import
+import tempfile
+from typing import Any, List, Optional, Tuple
+
 import awswrangler as wr
-
-import hashlib
+import joblib
 import pandas as pd
-import numpy as np
 import xgboost as xgb
-from typing import Dict, Any
-from sklearn.model_selection import KFold, StratifiedKFold
-from sklearn.metrics import (
-    precision_recall_fscore_support,
-    confusion_matrix,
-    mean_squared_error,
-    mean_absolute_error,
-    r2_score,
-)
-from sklearn.preprocessing import LabelEncoder
 
 # Workbench Imports
-from workbench.utils.
+from workbench.utils.aws_utils import pull_s3_data
+from workbench.utils.metrics_utils import compute_metrics_from_predictions
+from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
 from workbench.utils.pandas_utils import convert_categorical_types
 
 # Set up the log
```
```diff
@@ -34,14 +26,12 @@ log = logging.getLogger("workbench")
 def xgboost_model_from_s3(model_artifact_uri: str):
     """
     Download and extract XGBoost model artifact from S3, then load the model into memory.
-    Handles both direct XGBoost model files and pickled models.
-    Ensures categorical feature support is enabled.
 
     Args:
         model_artifact_uri (str): S3 URI of the model artifact.
 
     Returns:
-        Loaded XGBoost model or None if unavailable.
+        Loaded XGBoost model (XGBClassifier, XGBRegressor, or Booster) or None if unavailable.
     """
 
     with tempfile.TemporaryDirectory() as tmpdir:
```
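As a quick orientation for the new return contract, here is a minimal usage sketch mirroring the `__main__` exercise at the bottom of this module (it assumes a deployed Workbench model named `abalone-regression` and valid AWS credentials):

```python
from workbench.api import Model
from workbench.utils.xgboost_model_utils import xgboost_model_from_s3

# Pull the raw XGBoost artifact that backs a Workbench model
model = Model("abalone-regression")
xgb_model = xgboost_model_from_s3(model.model_data_url())
if xgb_model is not None:
    print(type(xgb_model))  # XGBClassifier, XGBRegressor, or Booster
```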
```diff
@@ -50,68 +40,90 @@ def xgboost_model_from_s3(model_artifact_uri: str):
         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
 
         # Extract tarball
-
-            tar.extractall(path=tmpdir, filter="data")
+        safe_extract_tarfile(local_tar_path, tmpdir)
 
         # Define model file patterns to search for (in order of preference)
         patterns = [
-            #
-            os.path.join(tmpdir, "
-            os.path.join(tmpdir, "
-            os.path.join(tmpdir, "*.
+            # Joblib models (preferred - preserves everything)
+            os.path.join(tmpdir, "*model*.joblib"),
+            os.path.join(tmpdir, "xgb*.joblib"),
+            os.path.join(tmpdir, "**", "*model*.joblib"),
+            os.path.join(tmpdir, "**", "xgb*.joblib"),
+            # Pickle models (also preserves everything)
+            os.path.join(tmpdir, "*model*.pkl"),
+            os.path.join(tmpdir, "xgb*.pkl"),
+            os.path.join(tmpdir, "**", "*model*.pkl"),
+            os.path.join(tmpdir, "**", "xgb*.pkl"),
+            # JSON models (fallback - requires reconstruction)
+            os.path.join(tmpdir, "*model*.json"),
+            os.path.join(tmpdir, "xgb*.json"),
             os.path.join(tmpdir, "**", "*model*.json"),
-            os.path.join(tmpdir, "**", "
-            # Pickled models
-            os.path.join(tmpdir, "*.pkl"),
-            os.path.join(tmpdir, "**", "*.pkl"),
-            os.path.join(tmpdir, "*.pickle"),
-            os.path.join(tmpdir, "**", "*.pickle"),
+            os.path.join(tmpdir, "**", "xgb*.json"),
         ]
 
         # Try each pattern
         for pattern in patterns:
-            # Use glob to find all matching files
             for model_path in glob.glob(pattern, recursive=True):
-                #
+                # Skip files that are clearly not XGBoost models
+                filename = os.path.basename(model_path).lower()
+                if any(skip in filename for skip in ["label_encoder", "scaler", "preprocessor", "transformer"]):
+                    log.debug(f"Skipping non-model file: {model_path}")
+                    continue
+
                 _, ext = os.path.splitext(model_path)
 
                 try:
-                    if ext
-
+                    if ext == ".joblib":
+                        model = joblib.load(model_path)
+                        # Verify it's actually an XGBoost model
+                        if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
+                            log.important(f"Loaded XGBoost model from joblib: {model_path}")
+                            return model
+                        else:
+                            log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
+
+                    elif ext in [".pkl", ".pickle"]:
                         with open(model_path, "rb") as f:
                             model = pickle.load(f)
-
-
-
-                        log.important(f"Loaded XGBoost Booster from pickle: {model_path}")
+                        # Verify it's actually an XGBoost model
+                        if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
+                            log.important(f"Loaded XGBoost model from pickle: {model_path}")
                             return model
-
-                    log.
-
-
-
-                    # Handle direct XGBoost model files
+                        else:
+                            log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
+
+                    elif ext == ".json":
+                        # JSON files should be XGBoost models by definition
                         booster = xgb.Booster()
                         booster.load_model(model_path)
-                        log.important(f"Loaded XGBoost
+                        log.important(f"Loaded XGBoost booster from JSON: {model_path}")
                         return booster
+
                 except Exception as e:
-                    log.
-                    continue
+                    log.debug(f"Failed to load {model_path}: {e}")
+                    continue
 
-        # If no model found
         log.error("No XGBoost model found in the artifact.")
         return None
 
 
-def feature_importance(workbench_model, importance_type: str = "
+def feature_importance(workbench_model, importance_type: str = "gain") -> Optional[List[Tuple[str, float]]]:
     """
     Get sorted feature importances from a Workbench Model object.
 
     Args:
         workbench_model: Workbench model object
-        importance_type: Type of feature importance.
-
+        importance_type: Type of feature importance. Options:
+            - 'gain' (default): Average improvement in loss/objective when feature is used.
+              Best for understanding predictive power of features.
+            - 'weight': Number of times a feature appears in trees (split count).
+              Useful for understanding model complexity and feature usage frequency.
+            - 'cover': Average number of samples affected when feature is used.
+              Shows the relative quantity of observations related to this feature.
+            - 'total_gain': Total improvement in loss/objective across all splits.
+              Similar to 'gain' but not averaged (can be biased toward frequent features).
+            - 'total_cover': Total number of samples affected across all splits.
+              Similar to 'cover' but not averaged.
 
     Returns:
         List of tuples (feature, importance) sorted by importance value (descending).
```
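The inline `tar.extractall(path=tmpdir, filter="data")` call has been replaced by the shared `safe_extract_tarfile` helper from `workbench.utils.model_utils`. Its body is not part of this diff; a minimal sketch of the path-traversal guard such a helper typically applies (the member check below is an assumption, not the confirmed implementation):

```python
import os
import tarfile

def safe_extract_tarfile(tar_path: str, dest_dir: str) -> None:
    """Extract a tarball, refusing members that would escape dest_dir."""
    dest = os.path.realpath(dest_dir)
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest, member.name))
            if os.path.commonpath([dest, target]) != dest:
                raise ValueError(f"Blocked suspicious tar member: {member.name}")
        tar.extractall(path=dest)
```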
```diff
@@ -120,7 +132,8 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
 
     Note:
         XGBoost's get_score() only returns features with non-zero importance.
-        This function ensures all model features are included in the output
+        This function ensures all model features are included in the output,
+        adding zero values for features that weren't used in any tree splits.
     """
     model_artifact_uri = workbench_model.model_data_url()
     xgb_model = xgboost_model_from_s3(model_artifact_uri)
```
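The zero-fill behavior described in the Note is easy to demonstrate in isolation: `get_score()` silently drops features that were never used in a split, so the completion dict backfills them with 0.0. A self-contained toy example (synthetic data, not Workbench-specific):

```python
import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(42)
X = pd.DataFrame(rng.random((50, 3)), columns=["a", "b", "c"])
X["c"] = 0.0  # constant feature: never chosen for a split
y = X["a"] * 2.0

booster = xgb.train({"max_depth": 3}, xgb.DMatrix(X, label=y), num_boost_round=5)
importances = booster.get_score(importance_type="gain")  # omits unused features
complete = {f: importances.get(f, 0.0) for f in booster.feature_names}
print(complete)  # "c" appears with 0.0 even though get_score() dropped it
```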
```diff
@@ -128,11 +141,18 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
         log.error("No XGBoost model found in the artifact.")
         return None
 
-    #
-
+    # Check if we got a full sklearn model or just a booster (for backwards compatibility)
+    if hasattr(xgb_model, "get_booster"):
+        # Full sklearn model - get the booster for feature importance
+        booster = xgb_model.get_booster()
+        all_features = booster.feature_names
+    else:
+        # Already a booster (legacy JSON load)
+        booster = xgb_model
+        all_features = xgb_model.feature_names
 
-    # Get
-
+    # Get feature importances (only non-zero features)
+    importances = booster.get_score(importance_type=importance_type)
 
     # Create complete importance dict with zeros for missing features
     complete_importances = {feat: importances.get(feat, 0.0) for feat in all_features}
```
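The `hasattr(xgb_model, "get_booster")` check is plain duck-typing between the two artifact formats the loader can now return. In miniature:

```python
import xgboost as xgb

# sklearn wrappers expose get_booster(); a raw Booster is used directly
print(hasattr(xgb.XGBRegressor(), "get_booster"))  # True  -> unwrap it
print(hasattr(xgb.Booster(), "get_booster"))       # False -> already a booster
```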
```diff
@@ -229,149 +249,45 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
     return result_df
 
 
-def
-    """
-
+def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Pull cross-validation results from AWS training artifacts.
+
+    This retrieves the validation predictions saved during model training and
+    computes metrics directly from them. For XGBoost models trained with
+    n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
+
     Args:
         workbench_model: Workbench model object
-
+
     Returns:
-
-        -
-        -
-        - overall_metrics: Overall metrics for all folds
+        Tuple of:
+        - DataFrame with computed metrics
+        - DataFrame with validation predictions
     """
-
+    # Get the validation predictions from S3
+    s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
+    predictions_df = pull_s3_data(s3_path)
 
-
-
-    model_artifact_uri = workbench_model.model_data_url()
-    loaded_booster = xgboost_model_from_s3(model_artifact_uri)
-    if loaded_booster is None:
-        log.error("No XGBoost model found in the artifact.")
-        return {}
-    # Create the model wrapper
-    is_classifier = model_type == "classifier"
-    xgb_model = (
-        xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
-    )
-    xgb_model._Booster = loaded_booster
-    # Prepare data
-    fs = FeatureSet(workbench_model.get_input())
-    df = fs.pull_dataframe()
-    feature_cols = workbench_model.features()
-    # Convert string features to categorical
-    for col in feature_cols:
-        if df[col].dtype in ["object", "string"]:
-            df[col] = df[col].astype("category")
-    # Split X and y
-    X = df[workbench_model.features()]
-    y = df[workbench_model.target()]
-
-    # Encode target if it's a classification problem
-    label_encoder = LabelEncoder() if is_classifier else None
-    if label_encoder:
-        y = pd.Series(label_encoder.fit_transform(y), name=workbench_model.target())
-    # Prepare KFold
-    kfold = (
-        StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=42)
-        if is_classifier
-        else KFold(n_splits=nfolds, shuffle=True, random_state=42)
-    )
+    if predictions_df is None:
+        raise ValueError(f"No validation predictions found at {s3_path}")
 
-
-
-
-
-
-
-
-
-            xgb_model.fit(X_train, y_train)
-            preds = xgb_model.predict(X_val)
-            all_predictions.extend(preds)
-            all_actuals.extend(y_val)
-
-            # Calculate metrics for this fold
-            fold_metrics = {"fold": fold_idx + 1}
-
-            if is_classifier:
-                y_val_original = label_encoder.inverse_transform(y_val)
-                preds_original = label_encoder.inverse_transform(preds.astype(int))
-                scores = precision_recall_fscore_support(
-                    y_val_original, preds_original, average="weighted", zero_division=0
-                )
-                fold_metrics.update({"precision": float(scores[0]), "recall": float(scores[1]), "fscore": float(scores[2])})
-            else:
-                fold_metrics.update(
-                    {
-                        "rmse": float(np.sqrt(mean_squared_error(y_val, preds))),
-                        "mae": float(mean_absolute_error(y_val, preds)),
-                        "r2": float(r2_score(y_val, preds)),
-                    }
-                )
-
-            fold_results.append(fold_metrics)
-    # Calculate overall metrics
-    overall_metrics = {}
-    if is_classifier:
-        all_actuals_original = label_encoder.inverse_transform(all_actuals)
-        all_predictions_original = label_encoder.inverse_transform(all_predictions)
-        scores = precision_recall_fscore_support(
-            all_actuals_original, all_predictions_original, average="weighted", zero_division=0
-        )
-        overall_metrics.update(
-            {
-                "precision": float(scores[0]),
-                "recall": float(scores[1]),
-                "fscore": float(scores[2]),
-                "confusion_matrix": confusion_matrix(
-                    all_actuals_original, all_predictions_original, labels=label_encoder.classes_
-                ).tolist(),
-                "label_names": list(label_encoder.classes_),
-            }
-        )
+    log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
+
+    # Compute metrics from predictions
+    target = workbench_model.target()
+    class_labels = workbench_model.class_labels()
+
+    if target in predictions_df.columns and "prediction" in predictions_df.columns:
+        metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
     else:
-
-
-
-                "mae": float(mean_absolute_error(all_actuals, all_predictions)),
-                "r2": float(r2_score(all_actuals, all_predictions)),
-            }
-        )
-    # Calculate summary metrics across folds
-    summary_metrics = {}
-    metrics_to_aggregate = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
-
-    for metric in metrics_to_aggregate:
-        values = [fold[metric] for fold in fold_results]
-        summary_metrics[metric] = f"{float(np.mean(values)):.3f} ±{float(np.std(values)):.3f}"
-    # Format fold results as strings (TBD section)
-    formatted_folds = {}
-    for fold_data in fold_results:
-        fold_key = f"Fold {fold_data['fold']}"
-        if is_classifier:
-            formatted_folds[fold_key] = (
-                f"precision: {fold_data['precision']:.3f} "
-                f"recall: {fold_data['recall']:.3f} "
-                f"fscore: {fold_data['fscore']:.3f}"
-            )
-        else:
-            formatted_folds[fold_key] = (
-                f"rmse: {fold_data['rmse']:.3f} mae: {fold_data['mae']:.3f} r2: {fold_data['r2']:.3f}"
-            )
-    # Return the results
-    return {
-        "summary_metrics": summary_metrics,
-        # "overall_metrics": overall_metrics,
-        "folds": formatted_folds,
-    }
+        metrics_df = pd.DataFrame()
+
+    return metrics_df, predictions_df
 
 
 if __name__ == "__main__":
     """Exercise the Model Utilities"""
-    from workbench.api import Model
-    from pprint import pprint
+    from workbench.api import Model
 
     # Test the XGBoost model loading and feature importance
     model = Model("abalone-regression")
```
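`pull_cv_results` delegates the metric math to `compute_metrics_from_predictions`, which lives in the new `workbench/utils/metrics_utils.py` (+256 lines in the file list above) and is not shown in this diff. A hedged sketch of what the regression path plausibly looks like, assuming the predictions frame carries the target column plus a `prediction` column as checked above:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def compute_regression_metrics(predictions_df: pd.DataFrame, target: str) -> pd.DataFrame:
    """Derive metrics from saved out-of-fold predictions (regression case only)."""
    y_true = predictions_df[target]
    y_pred = predictions_df["prediction"]
    return pd.DataFrame([{
        "rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }])
```

The real helper also takes `class_labels` and handles classifiers; this shows only the shape of the computation.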
```diff
@@ -383,34 +299,26 @@ if __name__ == "__main__":
     model_artifact_uri = model.model_data_url()
     xgb_model = xgboost_model_from_s3(model_artifact_uri)
 
-    #
-
-
-
-    # Test XGBoost add_leaf_hash
-    input_df = FeatureSet(model.get_input()).pull_dataframe()
-    leaf_df = add_leaf_hash(model, input_df)
-    print("DataFrame with Leaf Hash:")
-    print(leaf_df)
-
-    # Okay, we're going to copy row 3 and insert it into row 7 to make sure the leaf_hash is the same
-    input_df.iloc[7] = input_df.iloc[3]
-    print("DataFrame with Leaf Hash (3 and 7 should match):")
-    leaf_df = add_leaf_hash(model, input_df)
-    print(leaf_df)
-
-    # Test leaf_stats
-    target_col = "class_number_of_rings"
-    stats_df = leaf_stats(leaf_df, target_col)
-    print("DataFrame with Leaf Statistics:")
-    print(stats_df)
-
-    print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
-    model = Model("abalone-regression")
-    results = cross_fold_inference(model)
-    pprint(results)
+    # Verify enable_categorical is preserved (for debugging/confidence)
+    print(f"Model parameters: {xgb_model.get_params()}")
+    print(f"enable_categorical: {xgb_model.enable_categorical}")
 
-    print("\n===
+    print("\n=== PULL CV RESULTS EXAMPLE ===")
+    model = Model("abalone-regression")
+    metrics_df, predictions_df = pull_cv_results(model)
+    print(f"\nMetrics:\n{metrics_df}")
+    print(f"\nPredictions shape: {predictions_df.shape}")
+    print(f"Predictions columns: {predictions_df.columns.tolist()}")
+    print(predictions_df.head())
+
+    # Test on a Classifier model
+    print("\n=== CLASSIFIER MODEL TEST ===")
     model = Model("wine-classification")
-
-
+    features = feature_importance(model)
+    print("Feature Importance:")
+    print(features)
+    metrics_df, predictions_df = pull_cv_results(model)
+    print(f"\nMetrics:\n{metrics_df}")
+    print(f"\nPredictions shape: {predictions_df.shape}")
+    print(f"Predictions columns: {predictions_df.columns.tolist()}")
+    print(predictions_df.head())
```
workbench/web_interface/components/model_plot.py

```diff
@@ -36,10 +36,22 @@ class ModelPlot(ComponentInterface):
             if df is None:
                 return self.display_text("No Data")
 
-            #
+            # Grab the target(s) for this model
             target = model.target()
+
+            # For multi-task models, match target to inference_run name or default to first
+            if isinstance(target, list):
+                target = next((t for t in target if t in inference_run), target[0])
+
+            # Compute error for coloring
             df["error"] = abs(df["prediction"] - df[target])
-            return ScatterPlot().update_properties(
+            return ScatterPlot().update_properties(
+                df,
+                color="error",
+                regression_line=True,
+                x=target,
+                y="prediction",
+            )[0]
         else:
             return self.display_text(f"Model Type: {model.model_type}\n\n Awesome Plot Coming Soon!")
 
```
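The multi-task branch picks whichever target name appears in the inference-run string and falls back to the first target otherwise. A hypothetical illustration (the target and run names here are invented):

```python
targets = ["solubility", "logd"]
inference_run = "auto_inference_logd"
target = next((t for t in targets if t in inference_run), targets[0])
print(target)  # "logd"; with no match we would fall back to "solubility"
```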
workbench/web_interface/components/plugin_unit_test.py

```diff
@@ -156,10 +156,13 @@ class PluginUnitTest:
         """Run the Dash server for the plugin, handling common errors gracefully."""
         while self.is_port_in_use(self.port):
             log.info(f"Port {self.port} is in use. Trying the next one...")
-            self.port += 1
+            self.port += 1
 
         log.info(f"Starting Dash server on port {self.port}...")
-
+        try:
+            self.app.run(debug=True, use_reloader=False, port=self.port)
+        except KeyboardInterrupt:
+            log.info("Shutting down Dash server...")
 
     @staticmethod
     def is_port_in_use(port):
```
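`run()` now steps past busy ports and exits cleanly on Ctrl-C. The `is_port_in_use` helper it loops on is outside this hunk; a typical socket-probe implementation looks like this (assumed, not confirmed by the diff):

```python
import socket

def is_port_in_use(port: int) -> bool:
    """Return True if something is already listening on localhost:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0
```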
workbench/web_interface/components/plugins/dashboard_status.py

```diff
@@ -72,7 +72,9 @@ class DashboardStatus(PluginInterface):
             details = "**Redis:** 🔴 Failed to Connect<br>"
 
         # Fill in the license details
-
+        redis_host = config_info.get("REDIS_HOST", "NOT SET")
+        redis_port = config_info.get("REDIS_PORT", "NOT SET")
+        details += f"**Redis Server:** {redis_host}:{redis_port}<br>"
         details += f"**Workbench S3 Bucket:** {config_info['WORKBENCH_BUCKET']}<br>"
         details += f"**Plugin Path:** {config_info.get('WORKBENCH_PLUGINS', 'unknown')}<br>"
         details += f"**Themes Path:** {config_info.get('WORKBENCH_THEMES', 'unknown')}<br>"
```
workbench/web_interface/components/plugins/generated_compounds.py

```diff
@@ -5,7 +5,7 @@ import dash_bootstrap_components as dbc
 
 # Workbench Imports
 from workbench.api.compound import Compound
-from workbench.utils.chem_utils import svg_from_smiles
+from workbench.utils.chem_utils.vis import svg_from_smiles
 from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
 from workbench.utils.theme_manager import ThemeManager
 from workbench.utils.ai_summary import AISummary
```