workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic; consult the package registry's advisory page for more details.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,22 @@
1
1
  """XGBoost Model Utilities"""
2
2
 
3
+ import glob
4
+ import hashlib
3
5
  import logging
4
6
  import os
5
- import tempfile
6
- import tarfile
7
7
  import pickle
8
- import glob
8
+ import tempfile
9
+ from typing import Any, List, Optional, Tuple
10
+
9
11
  import awswrangler as wr
10
- from typing import Optional, List, Tuple
11
- import hashlib
12
+ import joblib
12
13
  import pandas as pd
13
- import numpy as np
14
14
  import xgboost as xgb
15
- from typing import Dict, Any
16
- from sklearn.model_selection import KFold, StratifiedKFold
17
- from sklearn.metrics import (
18
- precision_recall_fscore_support,
19
- confusion_matrix,
20
- mean_squared_error,
21
- mean_absolute_error,
22
- r2_score,
23
- )
24
- from sklearn.preprocessing import LabelEncoder
25
15
 
26
16
  # Workbench Imports
27
- from workbench.utils.model_utils import load_category_mappings_from_s3
17
+ from workbench.utils.aws_utils import pull_s3_data
18
+ from workbench.utils.metrics_utils import compute_metrics_from_predictions
19
+ from workbench.utils.model_utils import load_category_mappings_from_s3, safe_extract_tarfile
28
20
  from workbench.utils.pandas_utils import convert_categorical_types
29
21
 
30
22
  # Set up the log
@@ -34,14 +26,12 @@ log = logging.getLogger("workbench")
34
26
  def xgboost_model_from_s3(model_artifact_uri: str):
35
27
  """
36
28
  Download and extract XGBoost model artifact from S3, then load the model into memory.
37
- Handles both direct XGBoost model files and pickled models.
38
- Ensures categorical feature support is enabled.
39
29
 
40
30
  Args:
41
31
  model_artifact_uri (str): S3 URI of the model artifact.
42
32
 
43
33
  Returns:
44
- Loaded XGBoost model or None if unavailable.
34
+ Loaded XGBoost model (XGBClassifier, XGBRegressor, or Booster) or None if unavailable.
45
35
  """
46
36
 
47
37
  with tempfile.TemporaryDirectory() as tmpdir:
@@ -50,68 +40,90 @@ def xgboost_model_from_s3(model_artifact_uri: str):
50
40
  wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
51
41
 
52
42
  # Extract tarball
53
- with tarfile.open(local_tar_path, "r:gz") as tar:
54
- tar.extractall(path=tmpdir, filter="data")
43
+ safe_extract_tarfile(local_tar_path, tmpdir)
55
44
 
56
45
  # Define model file patterns to search for (in order of preference)
57
46
  patterns = [
58
- # Direct XGBoost model files
59
- os.path.join(tmpdir, "xgboost-model"),
60
- os.path.join(tmpdir, "model"),
61
- os.path.join(tmpdir, "*.bin"),
47
+ # Joblib models (preferred - preserves everything)
48
+ os.path.join(tmpdir, "*model*.joblib"),
49
+ os.path.join(tmpdir, "xgb*.joblib"),
50
+ os.path.join(tmpdir, "**", "*model*.joblib"),
51
+ os.path.join(tmpdir, "**", "xgb*.joblib"),
52
+ # Pickle models (also preserves everything)
53
+ os.path.join(tmpdir, "*model*.pkl"),
54
+ os.path.join(tmpdir, "xgb*.pkl"),
55
+ os.path.join(tmpdir, "**", "*model*.pkl"),
56
+ os.path.join(tmpdir, "**", "xgb*.pkl"),
57
+ # JSON models (fallback - requires reconstruction)
58
+ os.path.join(tmpdir, "*model*.json"),
59
+ os.path.join(tmpdir, "xgb*.json"),
62
60
  os.path.join(tmpdir, "**", "*model*.json"),
63
- os.path.join(tmpdir, "**", "rmse.json"),
64
- # Pickled models
65
- os.path.join(tmpdir, "*.pkl"),
66
- os.path.join(tmpdir, "**", "*.pkl"),
67
- os.path.join(tmpdir, "*.pickle"),
68
- os.path.join(tmpdir, "**", "*.pickle"),
61
+ os.path.join(tmpdir, "**", "xgb*.json"),
69
62
  ]
70
63
 
71
64
  # Try each pattern
72
65
  for pattern in patterns:
73
- # Use glob to find all matching files
74
66
  for model_path in glob.glob(pattern, recursive=True):
75
- # Determine file type by extension
67
+ # Skip files that are clearly not XGBoost models
68
+ filename = os.path.basename(model_path).lower()
69
+ if any(skip in filename for skip in ["label_encoder", "scaler", "preprocessor", "transformer"]):
70
+ log.debug(f"Skipping non-model file: {model_path}")
71
+ continue
72
+
76
73
  _, ext = os.path.splitext(model_path)
77
74
 
78
75
  try:
79
- if ext.lower() in [".pkl", ".pickle"]:
80
- # Handle pickled models
76
+ if ext == ".joblib":
77
+ model = joblib.load(model_path)
78
+ # Verify it's actually an XGBoost model
79
+ if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
80
+ log.important(f"Loaded XGBoost model from joblib: {model_path}")
81
+ return model
82
+ else:
83
+ log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
84
+
85
+ elif ext in [".pkl", ".pickle"]:
81
86
  with open(model_path, "rb") as f:
82
87
  model = pickle.load(f)
83
-
84
- # Handle different model types
85
- if isinstance(model, xgb.Booster):
86
- log.important(f"Loaded XGBoost Booster from pickle: {model_path}")
88
+ # Verify it's actually an XGBoost model
89
+ if isinstance(model, (xgb.XGBClassifier, xgb.XGBRegressor, xgb.Booster)):
90
+ log.important(f"Loaded XGBoost model from pickle: {model_path}")
87
91
  return model
88
- elif hasattr(model, "get_booster"):
89
- log.important(f"Loaded XGBoost model from pipeline: {model_path}")
90
- booster = model.get_booster()
91
- return booster
92
- else:
93
- # Handle direct XGBoost model files
92
+ else:
93
+ log.debug(f"Skipping non-XGBoost object from {model_path}: {type(model)}")
94
+
95
+ elif ext == ".json":
96
+ # JSON files should be XGBoost models by definition
94
97
  booster = xgb.Booster()
95
98
  booster.load_model(model_path)
96
- log.important(f"Loaded XGBoost model directly: {model_path}")
99
+ log.important(f"Loaded XGBoost booster from JSON: {model_path}")
97
100
  return booster
101
+
98
102
  except Exception as e:
99
- log.info(f"Failed to load model from {model_path}: {e}")
100
- continue # Try the next file
103
+ log.debug(f"Failed to load {model_path}: {e}")
104
+ continue
101
105
 
102
- # If no model found
103
106
  log.error("No XGBoost model found in the artifact.")
104
107
  return None
105
108
 
106
109
 
107
- def feature_importance(workbench_model, importance_type: str = "weight") -> Optional[List[Tuple[str, float]]]:
110
+ def feature_importance(workbench_model, importance_type: str = "gain") -> Optional[List[Tuple[str, float]]]:
108
111
  """
109
112
  Get sorted feature importances from a Workbench Model object.
110
113
 
111
114
  Args:
112
115
  workbench_model: Workbench model object
113
- importance_type: Type of feature importance.
114
- Options: 'weight', 'gain', 'cover', 'total_gain', 'total_cover'
116
+ importance_type: Type of feature importance. Options:
117
+ - 'gain' (default): Average improvement in loss/objective when feature is used.
118
+ Best for understanding predictive power of features.
119
+ - 'weight': Number of times a feature appears in trees (split count).
120
+ Useful for understanding model complexity and feature usage frequency.
121
+ - 'cover': Average number of samples affected when feature is used.
122
+ Shows the relative quantity of observations related to this feature.
123
+ - 'total_gain': Total improvement in loss/objective across all splits.
124
+ Similar to 'gain' but not averaged (can be biased toward frequent features).
125
+ - 'total_cover': Total number of samples affected across all splits.
126
+ Similar to 'cover' but not averaged.
115
127
 
116
128
  Returns:
117
129
  List of tuples (feature, importance) sorted by importance value (descending).
@@ -120,7 +132,8 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
120
132
 
121
133
  Note:
122
134
  XGBoost's get_score() only returns features with non-zero importance.
123
- This function ensures all model features are included in the output.
135
+ This function ensures all model features are included in the output,
136
+ adding zero values for features that weren't used in any tree splits.
124
137
  """
125
138
  model_artifact_uri = workbench_model.model_data_url()
126
139
  xgb_model = xgboost_model_from_s3(model_artifact_uri)
@@ -128,11 +141,18 @@ def feature_importance(workbench_model, importance_type: str = "weight") -> Opti
128
141
  log.error("No XGBoost model found in the artifact.")
129
142
  return None
130
143
 
131
- # Get feature importances (only non-zero features)
132
- importances = xgb_model.get_score(importance_type=importance_type)
144
+ # Check if we got a full sklearn model or just a booster (for backwards compatibility)
145
+ if hasattr(xgb_model, "get_booster"):
146
+ # Full sklearn model - get the booster for feature importance
147
+ booster = xgb_model.get_booster()
148
+ all_features = booster.feature_names
149
+ else:
150
+ # Already a booster (legacy JSON load)
151
+ booster = xgb_model
152
+ all_features = xgb_model.feature_names
133
153
 
134
- # Get all feature names from the model
135
- all_features = xgb_model.feature_names
154
+ # Get feature importances (only non-zero features)
155
+ importances = booster.get_score(importance_type=importance_type)
136
156
 
137
157
  # Create complete importance dict with zeros for missing features
138
158
  complete_importances = {feat: importances.get(feat, 0.0) for feat in all_features}
@@ -229,149 +249,45 @@ def leaf_stats(df: pd.DataFrame, target_col: str) -> pd.DataFrame:
229
249
  return result_df
230
250
 
231
251
 
232
- def cross_fold_inference(workbench_model: Any, nfolds: int = 5) -> Dict[str, Any]:
233
- """
234
- Performs K-fold cross-validation with detailed metrics.
252
+ def pull_cv_results(workbench_model: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
253
+ """Pull cross-validation results from AWS training artifacts.
254
+
255
+ This retrieves the validation predictions saved during model training and
256
+ computes metrics directly from them. For XGBoost models trained with
257
+ n_folds > 1, these are out-of-fold predictions from k-fold cross-validation.
258
+
235
259
  Args:
236
260
  workbench_model: Workbench model object
237
- nfolds: Number of folds for cross-validation (default is 5)
261
+
238
262
  Returns:
239
- Dictionary containing:
240
- - folds: Dictionary of formatted strings for each fold
241
- - summary_metrics: Summary metrics across folds
242
- - overall_metrics: Overall metrics for all folds
263
+ Tuple of:
264
+ - DataFrame with computed metrics
265
+ - DataFrame with validation predictions
243
266
  """
244
- from workbench.api import FeatureSet
267
+ # Get the validation predictions from S3
268
+ s3_path = f"{workbench_model.model_training_path}/validation_predictions.csv"
269
+ predictions_df = pull_s3_data(s3_path)
245
270
 
246
- # Load model
247
- model_type = workbench_model.model_type.value
248
- model_artifact_uri = workbench_model.model_data_url()
249
- loaded_booster = xgboost_model_from_s3(model_artifact_uri)
250
- if loaded_booster is None:
251
- log.error("No XGBoost model found in the artifact.")
252
- return {}
253
- # Create the model wrapper
254
- is_classifier = model_type == "classifier"
255
- xgb_model = (
256
- xgb.XGBClassifier(enable_categorical=True) if is_classifier else xgb.XGBRegressor(enable_categorical=True)
257
- )
258
- xgb_model._Booster = loaded_booster
259
- # Prepare data
260
- fs = FeatureSet(workbench_model.get_input())
261
- df = fs.pull_dataframe()
262
- feature_cols = workbench_model.features()
263
- # Convert string features to categorical
264
- for col in feature_cols:
265
- if df[col].dtype in ["object", "string"]:
266
- df[col] = df[col].astype("category")
267
- # Split X and y
268
- X = df[workbench_model.features()]
269
- y = df[workbench_model.target()]
270
-
271
- # Encode target if it's a classification problem
272
- label_encoder = LabelEncoder() if is_classifier else None
273
- if label_encoder:
274
- y = pd.Series(label_encoder.fit_transform(y), name=workbench_model.target())
275
- # Prepare KFold
276
- kfold = (
277
- StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=42)
278
- if is_classifier
279
- else KFold(n_splits=nfolds, shuffle=True, random_state=42)
280
- )
271
+ if predictions_df is None:
272
+ raise ValueError(f"No validation predictions found at {s3_path}")
281
273
 
282
- fold_results = []
283
- all_predictions = []
284
- all_actuals = []
285
- for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
286
- X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
287
- y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
288
-
289
- # Train the model
290
- xgb_model.fit(X_train, y_train)
291
- preds = xgb_model.predict(X_val)
292
- all_predictions.extend(preds)
293
- all_actuals.extend(y_val)
294
-
295
- # Calculate metrics for this fold
296
- fold_metrics = {"fold": fold_idx + 1}
297
-
298
- if is_classifier:
299
- y_val_original = label_encoder.inverse_transform(y_val)
300
- preds_original = label_encoder.inverse_transform(preds.astype(int))
301
- scores = precision_recall_fscore_support(
302
- y_val_original, preds_original, average="weighted", zero_division=0
303
- )
304
- fold_metrics.update({"precision": float(scores[0]), "recall": float(scores[1]), "fscore": float(scores[2])})
305
- else:
306
- fold_metrics.update(
307
- {
308
- "rmse": float(np.sqrt(mean_squared_error(y_val, preds))),
309
- "mae": float(mean_absolute_error(y_val, preds)),
310
- "r2": float(r2_score(y_val, preds)),
311
- }
312
- )
313
-
314
- fold_results.append(fold_metrics)
315
- # Calculate overall metrics
316
- overall_metrics = {}
317
- if is_classifier:
318
- all_actuals_original = label_encoder.inverse_transform(all_actuals)
319
- all_predictions_original = label_encoder.inverse_transform(all_predictions)
320
- scores = precision_recall_fscore_support(
321
- all_actuals_original, all_predictions_original, average="weighted", zero_division=0
322
- )
323
- overall_metrics.update(
324
- {
325
- "precision": float(scores[0]),
326
- "recall": float(scores[1]),
327
- "fscore": float(scores[2]),
328
- "confusion_matrix": confusion_matrix(
329
- all_actuals_original, all_predictions_original, labels=label_encoder.classes_
330
- ).tolist(),
331
- "label_names": list(label_encoder.classes_),
332
- }
333
- )
274
+ log.info(f"Pulled {len(predictions_df)} validation predictions from {s3_path}")
275
+
276
+ # Compute metrics from predictions
277
+ target = workbench_model.target()
278
+ class_labels = workbench_model.class_labels()
279
+
280
+ if target in predictions_df.columns and "prediction" in predictions_df.columns:
281
+ metrics_df = compute_metrics_from_predictions(predictions_df, target, class_labels)
334
282
  else:
335
- overall_metrics.update(
336
- {
337
- "rmse": float(np.sqrt(mean_squared_error(all_actuals, all_predictions))),
338
- "mae": float(mean_absolute_error(all_actuals, all_predictions)),
339
- "r2": float(r2_score(all_actuals, all_predictions)),
340
- }
341
- )
342
- # Calculate summary metrics across folds
343
- summary_metrics = {}
344
- metrics_to_aggregate = ["precision", "recall", "fscore"] if is_classifier else ["rmse", "mae", "r2"]
345
-
346
- for metric in metrics_to_aggregate:
347
- values = [fold[metric] for fold in fold_results]
348
- summary_metrics[metric] = f"{float(np.mean(values)):.3f} ±{float(np.std(values)):.3f}"
349
- # Format fold results as strings (TBD section)
350
- formatted_folds = {}
351
- for fold_data in fold_results:
352
- fold_key = f"Fold {fold_data['fold']}"
353
- if is_classifier:
354
- formatted_folds[fold_key] = (
355
- f"precision: {fold_data['precision']:.3f} "
356
- f"recall: {fold_data['recall']:.3f} "
357
- f"fscore: {fold_data['fscore']:.3f}"
358
- )
359
- else:
360
- formatted_folds[fold_key] = (
361
- f"rmse: {fold_data['rmse']:.3f} mae: {fold_data['mae']:.3f} r2: {fold_data['r2']:.3f}"
362
- )
363
- # Return the results
364
- return {
365
- "summary_metrics": summary_metrics,
366
- # "overall_metrics": overall_metrics,
367
- "folds": formatted_folds,
368
- }
283
+ metrics_df = pd.DataFrame()
284
+
285
+ return metrics_df, predictions_df
369
286
 
370
287
 
371
288
  if __name__ == "__main__":
372
289
  """Exercise the Model Utilities"""
373
- from workbench.api import Model, FeatureSet
374
- from pprint import pprint
290
+ from workbench.api import Model
375
291
 
376
292
  # Test the XGBoost model loading and feature importance
377
293
  model = Model("abalone-regression")
@@ -383,34 +299,26 @@ if __name__ == "__main__":
383
299
  model_artifact_uri = model.model_data_url()
384
300
  xgb_model = xgboost_model_from_s3(model_artifact_uri)
385
301
 
386
- # Test with UQ Model
387
- uq_model = Model("aqsol-uq")
388
- _xgb_model = xgboost_model_from_s3(uq_model.model_data_url())
389
-
390
- # Test XGBoost add_leaf_hash
391
- input_df = FeatureSet(model.get_input()).pull_dataframe()
392
- leaf_df = add_leaf_hash(model, input_df)
393
- print("DataFrame with Leaf Hash:")
394
- print(leaf_df)
395
-
396
- # Okay, we're going to copy row 3 and insert it into row 7 to make sure the leaf_hash is the same
397
- input_df.iloc[7] = input_df.iloc[3]
398
- print("DataFrame with Leaf Hash (3 and 7 should match):")
399
- leaf_df = add_leaf_hash(model, input_df)
400
- print(leaf_df)
401
-
402
- # Test leaf_stats
403
- target_col = "class_number_of_rings"
404
- stats_df = leaf_stats(leaf_df, target_col)
405
- print("DataFrame with Leaf Statistics:")
406
- print(stats_df)
407
-
408
- print("\n=== CROSS FOLD REGRESSION EXAMPLE ===")
409
- model = Model("abalone-regression")
410
- results = cross_fold_inference(model)
411
- pprint(results)
302
+ # Verify enable_categorical is preserved (for debugging/confidence)
303
+ print(f"Model parameters: {xgb_model.get_params()}")
304
+ print(f"enable_categorical: {xgb_model.enable_categorical}")
412
305
 
413
- print("\n=== CROSS FOLD CLASSIFICATION EXAMPLE ===")
306
+ print("\n=== PULL CV RESULTS EXAMPLE ===")
307
+ model = Model("abalone-regression")
308
+ metrics_df, predictions_df = pull_cv_results(model)
309
+ print(f"\nMetrics:\n{metrics_df}")
310
+ print(f"\nPredictions shape: {predictions_df.shape}")
311
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
312
+ print(predictions_df.head())
313
+
314
+ # Test on a Classifier model
315
+ print("\n=== CLASSIFIER MODEL TEST ===")
414
316
  model = Model("wine-classification")
415
- results = cross_fold_inference(model)
416
- pprint(results)
317
+ features = feature_importance(model)
318
+ print("Feature Importance:")
319
+ print(features)
320
+ metrics_df, predictions_df = pull_cv_results(model)
321
+ print(f"\nMetrics:\n{metrics_df}")
322
+ print(f"\nPredictions shape: {predictions_df.shape}")
323
+ print(f"Predictions columns: {predictions_df.columns.tolist()}")
324
+ print(predictions_df.head())
@@ -36,10 +36,22 @@ class ModelPlot(ComponentInterface):
36
36
  if df is None:
37
37
  return self.display_text("No Data")
38
38
 
39
- # Calculate the distance from the diagonal for each point
39
+ # Grab the target(s) for this model
40
40
  target = model.target()
41
+
42
+ # For multi-task models, match target to inference_run name or default to first
43
+ if isinstance(target, list):
44
+ target = next((t for t in target if t in inference_run), target[0])
45
+
46
+ # Compute error for coloring
41
47
  df["error"] = abs(df["prediction"] - df[target])
42
- return ScatterPlot().update_properties(df, color="error", regression_line=True)[0]
48
+ return ScatterPlot().update_properties(
49
+ df,
50
+ color="error",
51
+ regression_line=True,
52
+ x=target,
53
+ y="prediction",
54
+ )[0]
43
55
  else:
44
56
  return self.display_text(f"Model Type: {model.model_type}\n\n Awesome Plot Coming Soon!")
45
57
 
@@ -156,10 +156,13 @@ class PluginUnitTest:
156
156
  """Run the Dash server for the plugin, handling common errors gracefully."""
157
157
  while self.is_port_in_use(self.port):
158
158
  log.info(f"Port {self.port} is in use. Trying the next one...")
159
- self.port += 1 # Increment the port number until an available one is found
159
+ self.port += 1
160
160
 
161
161
  log.info(f"Starting Dash server on port {self.port}...")
162
- self.app.run(debug=True, use_reloader=False, port=self.port)
162
+ try:
163
+ self.app.run(debug=True, use_reloader=False, port=self.port)
164
+ except KeyboardInterrupt:
165
+ log.info("Shutting down Dash server...")
163
166
 
164
167
  @staticmethod
165
168
  def is_port_in_use(port):
@@ -72,7 +72,9 @@ class DashboardStatus(PluginInterface):
72
72
  details = "**Redis:** 🔴 Failed to Connect<br>"
73
73
 
74
74
  # Fill in the license details
75
- details += f"**Redis Server:** {config_info['REDIS_HOST']}:{config_info.get('REDIS_PORT', 6379)}<br>"
75
+ redis_host = config_info.get("REDIS_HOST", "NOT SET")
76
+ redis_port = config_info.get("REDIS_PORT", "NOT SET")
77
+ details += f"**Redis Server:** {redis_host}:{redis_port}<br>"
76
78
  details += f"**Workbench S3 Bucket:** {config_info['WORKBENCH_BUCKET']}<br>"
77
79
  details += f"**Plugin Path:** {config_info.get('WORKBENCH_PLUGINS', 'unknown')}<br>"
78
80
  details += f"**Themes Path:** {config_info.get('WORKBENCH_THEMES', 'unknown')}<br>"
@@ -5,7 +5,7 @@ import dash_bootstrap_components as dbc
5
5
 
6
6
  # Workbench Imports
7
7
  from workbench.api.compound import Compound
8
- from workbench.utils.chem_utils import svg_from_smiles
8
+ from workbench.utils.chem_utils.vis import svg_from_smiles
9
9
  from workbench.web_interface.components.plugin_interface import PluginInterface, PluginPage, PluginInputType
10
10
  from workbench.utils.theme_manager import ThemeManager
11
11
  from workbench.utils.ai_summary import AISummary