workbench 0.8.204__py3-none-any.whl → 0.8.212__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +83 -146
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.204.dist-info → workbench-0.8.212.dist-info}/METADATA +5 -5
  38. {workbench-0.8.204.dist-info → workbench-0.8.212.dist-info}/RECORD +42 -31
  39. {workbench-0.8.204.dist-info → workbench-0.8.212.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.204.dist-info → workbench-0.8.212.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.204.dist-info → workbench-0.8.212.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.204.dist-info → workbench-0.8.212.dist-info}/top_level.txt +0 -0
@@ -12,16 +12,8 @@ from typing import Union, Optional
12
12
  import hashlib
13
13
 
14
14
  # Model Performance Scores
15
- from sklearn.metrics import (
16
- mean_absolute_error,
17
- r2_score,
18
- median_absolute_error,
19
- roc_auc_score,
20
- confusion_matrix,
21
- precision_recall_fscore_support,
22
- mean_squared_error,
23
- )
24
- from sklearn.preprocessing import OneHotEncoder
15
+ from sklearn.metrics import confusion_matrix
16
+ from workbench.utils.metrics_utils import compute_regression_metrics, compute_classification_metrics
25
17
 
26
18
  # SageMaker Imports
27
19
  from sagemaker.serializers import CSVSerializer
@@ -35,7 +27,7 @@ from workbench.utils.endpoint_metrics import EndpointMetrics
35
27
  from workbench.utils.cache import Cache
36
28
  from workbench.utils.s3_utils import compute_s3_object_hash
37
29
  from workbench.utils.model_utils import uq_metrics
38
- from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
30
+ from workbench.utils.xgboost_model_utils import pull_cv_results as xgboost_pull_cv
39
31
  from workbench.utils.pytorch_utils import pull_cv_results as pytorch_pull_cv
40
32
  from workbench.utils.chemprop_utils import pull_cv_results as chemprop_pull_cv
41
33
  from workbench_bridges.endpoints.fast_inference import fast_inference
@@ -397,7 +389,6 @@ class EndpointCore(Artifact):
397
389
  self.log.warning("No predictions were made. Returning empty DataFrame.")
398
390
  return prediction_df
399
391
 
400
- # FIXME: Multi-target support - currently uses first target for metrics
401
392
  # Normalize targets to handle both string and list formats
402
393
  if isinstance(targets, list):
403
394
  primary_target = targets[0] if targets else None
@@ -436,19 +427,16 @@ class EndpointCore(Artifact):
436
427
 
437
428
  # Normalize targets to a list for iteration
438
429
  target_list = targets if isinstance(targets, list) else [targets]
430
+ primary_target = target_list[0]
439
431
 
440
- # For multi-target models, use target-specific capture names (e.g., auto_target1, auto_target2)
441
- # For single-target models, use the original capture name for backward compatibility
442
- for target in target_list:
443
- # Determine capture name: use prefix for multi-target, original name for single-target
444
- if len(target_list) > 1:
445
- prefix = "auto" if "auto" in capture_name else capture_name
446
- target_capture_name = f"{prefix}_{target}"
447
- else:
448
- target_capture_name = capture_name
432
+ # For single-target models (99% of cases), just save with capture_name
433
+ # For multi-target models, save each as {prefix}_{target} plus primary as capture_name
434
+ is_multi_target = len(target_list) > 1
449
435
 
450
- description = target_capture_name.replace("_", " ").title()
436
+ if is_multi_target:
437
+ prefix = "auto" if capture_name == "auto_inference" else capture_name
451
438
 
439
+ for target in target_list:
452
440
  # Drop rows with NaN target values for metrics/plots
453
441
  target_df = prediction_df.dropna(subset=[target])
454
442
 
@@ -460,30 +448,44 @@ class EndpointCore(Artifact):
460
448
  else:
461
449
  target_metrics = pd.DataFrame()
462
450
 
463
- self._capture_inference_results(
464
- target_capture_name,
465
- target_df,
466
- target,
467
- model.model_type,
468
- target_metrics,
469
- description,
470
- features,
471
- id_column,
472
- )
451
+ if is_multi_target:
452
+ # Multi-target: save as {prefix}_{target}
453
+ target_capture_name = f"{prefix}_{target}"
454
+ description = target_capture_name.replace("_", " ").title()
455
+ self._capture_inference_results(
456
+ target_capture_name,
457
+ target_df,
458
+ target,
459
+ model.model_type,
460
+ target_metrics,
461
+ description,
462
+ features,
463
+ id_column,
464
+ )
465
+
466
+ # Save primary target (or single target) with original capture_name
467
+ if target == primary_target:
468
+ self._capture_inference_results(
469
+ capture_name,
470
+ target_df,
471
+ target,
472
+ model.model_type,
473
+ target_metrics,
474
+ capture_name.replace("_", " ").title(),
475
+ features,
476
+ id_column,
477
+ )
473
478
 
474
479
  # For UQ Models we also capture the uncertainty metrics
475
- if model.model_type in [ModelType.UQ_REGRESSOR]:
480
+ if model.model_type == ModelType.UQ_REGRESSOR:
476
481
  metrics = uq_metrics(prediction_df, primary_target)
477
482
  self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
478
483
 
479
484
  # Return the prediction DataFrame
480
485
  return prediction_df
481
486
 
482
- def cross_fold_inference(self, nfolds: int = 5) -> pd.DataFrame:
483
- """Run cross-fold inference (only works for XGBoost models)
484
-
485
- Args:
486
- nfolds (int): Number of folds to use for cross-fold (default: 5)
487
+ def cross_fold_inference(self) -> pd.DataFrame:
488
+ """Pull cross-fold inference training results for this Endpoint's model
487
489
 
488
490
  Returns:
489
491
  pd.DataFrame: A DataFrame with cross fold predictions
@@ -495,8 +497,8 @@ class EndpointCore(Artifact):
495
497
  # Compute CrossFold (Metrics and Prediction Dataframe)
496
498
  # For PyTorch and ChemProp, pull pre-computed CV results from training
497
499
  if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
498
- cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
499
- elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
500
+ cross_fold_metrics, out_of_fold_df = xgboost_pull_cv(model)
501
+ elif model.model_framework == ModelFramework.PYTORCH:
500
502
  cross_fold_metrics, out_of_fold_df = pytorch_pull_cv(model)
501
503
  elif model.model_framework == ModelFramework.CHEMPROP:
502
504
  cross_fold_metrics, out_of_fold_df = chemprop_pull_cv(model)
@@ -523,51 +525,27 @@ class EndpointCore(Artifact):
523
525
  fs = FeatureSetCore(model.get_input())
524
526
  id_column = fs.id_column
525
527
 
526
- # Is this a UQ Model? If so, run full inference and merge the results
528
+ # For UQ models, get UQ columns from training CV results and compute metrics
529
+ # Note: XGBoost training now saves all UQ columns (q_*, confidence, prediction_std)
527
530
  additional_columns = []
528
- if model.model_framework == ModelFramework.XGBOOST and model_type == ModelType.UQ_REGRESSOR:
529
- self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
530
-
531
- # Get the training view dataframe for inference
532
- training_df = model.training_view().pull_dataframe()
533
-
534
- # Run inference on the endpoint to get UQ outputs
535
- uq_df = self.inference(training_df)
536
-
537
- # Identify UQ-specific columns (quantiles, prediction_std, *_pred_std)
538
- uq_columns = [
539
- col
540
- for col in uq_df.columns
541
- if col.startswith("q_") or col == "prediction_std" or col.endswith("_pred_std") or col == "confidence"
542
- ]
543
-
544
- # Merge UQ columns with out-of-fold predictions
531
+ if model_type == ModelType.UQ_REGRESSOR:
532
+ uq_columns = [col for col in out_of_fold_df.columns if col.startswith("q_") or col == "confidence"]
545
533
  if uq_columns:
546
- # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
547
- uq_df = uq_df[[id_column] + uq_columns]
548
-
549
- # Drop duplicates in uq_df based on id_column
550
- uq_df = uq_df.drop_duplicates(subset=[id_column])
551
-
552
- # Merge UQ columns into out_of_fold_df
553
- out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
554
534
  additional_columns = uq_columns
555
- self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
556
-
557
- # Also compute UQ metrics (use first target for multi-target models)
535
+ self.log.info(f"UQ columns from training: {', '.join(uq_columns)}")
558
536
  primary_target = targets[0] if isinstance(targets, list) else targets
559
537
  metrics = uq_metrics(out_of_fold_df, primary_target)
560
538
  self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
561
539
 
562
540
  # Normalize targets to a list for iteration
563
541
  target_list = targets if isinstance(targets, list) else [targets]
542
+ primary_target = target_list[0]
564
543
 
565
- # For multi-target models, use target-specific capture names (e.g., cv_target1, cv_target2)
566
- # For single-target models, use "full_cross_fold" for backward compatibility
567
- for target in target_list:
568
- capture_name = f"cv_{target}"
569
- description = capture_name.replace("_", " ").title()
544
+ # For single-target models (99% of cases), just save as "full_cross_fold"
545
+ # For multi-target models, save each as cv_{target} plus primary as "full_cross_fold"
546
+ is_multi_target = len(target_list) > 1
570
547
 
548
+ for target in target_list:
571
549
  # Drop rows with NaN target values for metrics/plots
572
550
  target_df = out_of_fold_df.dropna(subset=[target])
573
551
 
@@ -579,16 +557,33 @@ class EndpointCore(Artifact):
579
557
  else:
580
558
  target_metrics = pd.DataFrame()
581
559
 
582
- self._capture_inference_results(
583
- capture_name,
584
- target_df,
585
- target,
586
- model_type,
587
- target_metrics,
588
- description,
589
- features=additional_columns,
590
- id_column=id_column,
591
- )
560
+ if is_multi_target:
561
+ # Multi-target: save as cv_{target}
562
+ capture_name = f"cv_{target}"
563
+ description = capture_name.replace("_", " ").title()
564
+ self._capture_inference_results(
565
+ capture_name,
566
+ target_df,
567
+ target,
568
+ model_type,
569
+ target_metrics,
570
+ description,
571
+ features=additional_columns,
572
+ id_column=id_column,
573
+ )
574
+
575
+ # Save primary target (or single target) as "full_cross_fold"
576
+ if target == primary_target:
577
+ self._capture_inference_results(
578
+ "full_cross_fold",
579
+ target_df,
580
+ target,
581
+ model_type,
582
+ target_metrics,
583
+ "Full Cross Fold",
584
+ features=additional_columns,
585
+ id_column=id_column,
586
+ )
592
587
 
593
588
  return out_of_fold_df
594
589
 
@@ -935,29 +930,9 @@ class EndpointCore(Artifact):
935
930
  self.log.warning("Dropping NaN rows for metric computation.")
936
931
  prediction_df = prediction_df.dropna(subset=[target_column, "prediction"])
937
932
 
938
- # Compute the metrics
933
+ # Compute the metrics using shared utilities
939
934
  try:
940
- y_true = prediction_df[target_column]
941
- y_pred = prediction_df["prediction"]
942
-
943
- mae = mean_absolute_error(y_true, y_pred)
944
- rmse = np.sqrt(mean_squared_error(y_true, y_pred))
945
- r2 = r2_score(y_true, y_pred)
946
- # Mean Absolute Percentage Error
947
- mape = np.mean(np.where(y_true != 0, np.abs((y_true - y_pred) / y_true), np.abs(y_true - y_pred))) * 100
948
- # Median Absolute Error
949
- medae = median_absolute_error(y_true, y_pred)
950
-
951
- # Organize and return the metrics
952
- metrics = {
953
- "MAE": round(mae, 3),
954
- "RMSE": round(rmse, 3),
955
- "R2": round(r2, 3),
956
- "MAPE": round(mape, 3),
957
- "MedAE": round(medae, 3),
958
- "NumRows": len(prediction_df),
959
- }
960
- return pd.DataFrame.from_records([metrics])
935
+ return compute_regression_metrics(prediction_df, target_column)
961
936
  except Exception as e:
962
937
  self.log.warning(f"Error computing regression metrics: {str(e)}")
963
938
  return pd.DataFrame()
@@ -1040,46 +1015,8 @@ class EndpointCore(Artifact):
1040
1015
  else:
1041
1016
  self.validate_proba_columns(prediction_df, class_labels)
1042
1017
 
1043
- # Calculate precision, recall, f1, and support, handling zero division
1044
- scores = precision_recall_fscore_support(
1045
- prediction_df[target_column],
1046
- prediction_df["prediction"],
1047
- average=None,
1048
- labels=class_labels,
1049
- zero_division=0,
1050
- )
1051
-
1052
- # Identify the probability columns and keep them as a Pandas DataFrame
1053
- proba_columns = [f"{label}_proba" for label in class_labels]
1054
- y_score = prediction_df[proba_columns]
1055
-
1056
- # One-hot encode the true labels using all class labels (fit with class_labels)
1057
- encoder = OneHotEncoder(categories=[class_labels], sparse_output=False)
1058
- y_true = encoder.fit_transform(prediction_df[[target_column]])
1059
-
1060
- # Calculate ROC AUC per label and handle exceptions for missing classes
1061
- roc_auc_per_label = []
1062
- for i, label in enumerate(class_labels):
1063
- try:
1064
- roc_auc = roc_auc_score(y_true[:, i], y_score.iloc[:, i])
1065
- except ValueError as e:
1066
- self.log.warning(f"ROC AUC calculation failed for label {label}.")
1067
- self.log.warning(f"{str(e)}")
1068
- roc_auc = 0.0
1069
- roc_auc_per_label.append(roc_auc)
1070
-
1071
- # Put the scores into a DataFrame
1072
- score_df = pd.DataFrame(
1073
- {
1074
- target_column: class_labels,
1075
- "precision": scores[0],
1076
- "recall": scores[1],
1077
- "f1": scores[2],
1078
- "roc_auc": roc_auc_per_label,
1079
- "support": scores[3],
1080
- }
1081
- )
1082
- return score_df
1018
+ # Compute the metrics using shared utilities (returns per-class + 'all' row)
1019
+ return compute_classification_metrics(prediction_df, target_column, class_labels)
1083
1020
 
1084
1021
  def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
1085
1022
  """Compute the confusion matrix for this Endpoint
@@ -21,7 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
21
21
  from workbench.utils.s3_utils import compute_s3_object_hash
22
22
  from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
23
23
  from workbench.utils.deprecated_utils import deprecated
24
- from workbench.utils.model_utils import proximity_model
24
+ from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters
25
25
 
26
26
 
27
27
  class ModelType(Enum):
@@ -44,7 +44,7 @@ class ModelFramework(Enum):
44
44
  SKLEARN = "sklearn"
45
45
  XGBOOST = "xgboost"
46
46
  LIGHTGBM = "lightgbm"
47
- PYTORCH_TABULAR = "pytorch_tabular"
47
+ PYTORCH = "pytorch"
48
48
  CHEMPROP = "chemprop"
49
49
  TRANSFORMER = "transformer"
50
50
  UNKNOWN = "unknown"
@@ -263,11 +263,11 @@ class ModelCore(Artifact):
263
263
  else:
264
264
  self.log.important(f"No inference data found for {self.model_name}!")
265
265
 
266
- def get_inference_metrics(self, capture_name: str = "any") -> Union[pd.DataFrame, None]:
266
+ def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
267
267
  """Retrieve the inference performance metrics for this model
268
268
 
269
269
  Args:
270
- capture_name (str, optional): Specific capture_name (default: "any")
270
+ capture_name (str, optional): Specific capture_name (default: "auto")
271
271
  Returns:
272
272
  pd.DataFrame: DataFrame of the Model Metrics
273
273
 
@@ -275,7 +275,7 @@ class ModelCore(Artifact):
275
275
  If a capture_name isn't specified this will try to the 'first' available metrics
276
276
  """
277
277
  # Try to get the auto_capture 'training_holdout' or the training
278
- if capture_name == "any":
278
+ if capture_name == "auto":
279
279
  metric_list = self.list_inference_runs()
280
280
  if metric_list:
281
281
  return self.get_inference_metrics(metric_list[0])
@@ -303,11 +303,11 @@ class ModelCore(Artifact):
303
303
  self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
304
304
  return None
305
305
 
306
- def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
306
+ def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
307
307
  """Retrieve the confusion_matrix for this model
308
308
 
309
309
  Args:
310
- capture_name (str, optional): Specific capture_name or "training" (default: "latest")
310
+ capture_name (str, optional): Specific capture_name or "training" (default: "auto")
311
311
  Returns:
312
312
  pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
313
313
  """
@@ -319,7 +319,7 @@ class ModelCore(Artifact):
319
319
  raise ValueError(error_msg)
320
320
 
321
321
  # Grab the metrics from the Workbench Metadata (try inference first, then training)
322
- if capture_name == "latest":
322
+ if capture_name == "auto":
323
323
  cm = self.confusion_matrix("auto_inference")
324
324
  return cm if cm is not None else self.confusion_matrix("model_training")
325
325
 
@@ -541,6 +541,17 @@ class ModelCore(Artifact):
541
541
  else:
542
542
  self.log.error(f"Model {self.model_name} is not a classifier!")
543
543
 
544
+ def summary(self) -> dict:
545
+ """Summary information about this Model
546
+
547
+ Returns:
548
+ dict: Dictionary of summary information about this Model
549
+ """
550
+ self.log.info("Computing Model Summary...")
551
+ summary = super().summary()
552
+ summary["hyperparameters"] = get_model_hyperparameters(self)
553
+ return summary
554
+
544
555
  def details(self) -> dict:
545
556
  """Additional Details about this Model
546
557
 
@@ -565,6 +576,7 @@ class ModelCore(Artifact):
565
576
  details["status"] = self.latest_model["ModelPackageStatus"]
566
577
  details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
567
578
  details["image"] = self.container_image().split("/")[-1] # Shorten the image uri
579
+ details["hyperparameters"] = get_model_hyperparameters(self)
568
580
 
569
581
  # Grab the inference and container info
570
582
  inference_spec = self.latest_model["InferenceSpecification"]
@@ -575,16 +587,6 @@ class ModelCore(Artifact):
575
587
  details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
576
588
  details["content_types"] = inference_spec["SupportedContentTypes"]
577
589
  details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
578
- details["model_metrics"] = self.get_inference_metrics()
579
- if self.model_type == ModelType.CLASSIFIER:
580
- details["confusion_matrix"] = self.confusion_matrix()
581
- details["predictions"] = None
582
- elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
583
- details["confusion_matrix"] = None
584
- details["predictions"] = self.get_inference_predictions()
585
- else:
586
- details["confusion_matrix"] = None
587
- details["predictions"] = None
588
590
 
589
591
  # Grab the inference metadata
590
592
  details["inference_meta"] = self.get_inference_metadata()
@@ -904,7 +906,7 @@ class ModelCore(Artifact):
904
906
  """
905
907
  if prox_model_name is None:
906
908
  prox_model_name = self.model_name + "-prox"
907
- return proximity_model(self, prox_model_name, track_columns=track_columns)
909
+ return published_proximity_model(self, prox_model_name, track_columns=track_columns)
908
910
 
909
911
  def delete(self):
910
912
  """Delete the Model Packages and the Model Group"""
@@ -228,7 +228,7 @@ class FeaturesToModel(Transform):
228
228
  raise ValueError(msg)
229
229
 
230
230
  # Dynamically create the metric definitions
231
- metrics = ["precision", "recall", "f1"]
231
+ metrics = ["precision", "recall", "f1", "support"]
232
232
  metric_definitions = []
233
233
  for t in self.class_labels:
234
234
  for m in metrics:
@@ -254,7 +254,7 @@ class FeaturesToModel(Transform):
254
254
  image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image)
255
255
 
256
256
  # Use GPU instance for ChemProp/PyTorch, CPU for others
257
- if self.model_framework in [ModelFramework.CHEMPROP, ModelFramework.PYTORCH_TABULAR]:
257
+ if self.model_framework in [ModelFramework.CHEMPROP, ModelFramework.PYTORCH]:
258
258
  train_instance_type = "ml.g6.xlarge" # NVIDIA L4 GPU, ~$0.80/hr
259
259
  self.log.important(f"Using GPU instance {train_instance_type} for {self.model_framework.value}")
260
260
  else:
@@ -106,7 +106,7 @@ class ModelToEndpoint(Transform):
106
106
  from workbench.api import ModelFramework
107
107
 
108
108
  self.log.info(f"Model Framework: {workbench_model.model_framework}")
109
- if workbench_model.model_framework in [ModelFramework.PYTORCH_TABULAR, ModelFramework.CHEMPROP]:
109
+ if workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]:
110
110
  if mem_size < 4096:
111
111
  self.log.important(
112
112
  f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)"