workbench-0.8.205-py3-none-any.whl → workbench-0.8.213-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +63 -153
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/METADATA +5 -5
  38. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/RECORD +42 -31
  39. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/endpoint_core.py

@@ -12,16 +12,8 @@ from typing import Union, Optional
 import hashlib
 
 # Model Performance Scores
-from sklearn.metrics import (
-    mean_absolute_error,
-    r2_score,
-    median_absolute_error,
-    roc_auc_score,
-    confusion_matrix,
-    precision_recall_fscore_support,
-    mean_squared_error,
-)
-from sklearn.preprocessing import OneHotEncoder
+from sklearn.metrics import confusion_matrix
+from workbench.utils.metrics_utils import compute_regression_metrics, compute_classification_metrics
 
 # SageMaker Imports
 from sagemaker.serializers import CSVSerializer
@@ -35,7 +27,7 @@ from workbench.utils.endpoint_metrics import EndpointMetrics
 from workbench.utils.cache import Cache
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.model_utils import uq_metrics
-from workbench.utils.xgboost_model_utils import cross_fold_inference as xgboost_cross_fold
+from workbench.utils.xgboost_model_utils import pull_cv_results as xgboost_pull_cv
 from workbench.utils.pytorch_utils import pull_cv_results as pytorch_pull_cv
 from workbench.utils.chemprop_utils import pull_cv_results as chemprop_pull_cv
 from workbench_bridges.endpoints.fast_inference import fast_inference
@@ -397,7 +389,6 @@ class EndpointCore(Artifact):
             self.log.warning("No predictions were made. Returning empty DataFrame.")
             return prediction_df
 
-        # FIXME: Multi-target support - currently uses first target for metrics
         # Normalize targets to handle both string and list formats
         if isinstance(targets, list):
             primary_target = targets[0] if targets else None
@@ -438,11 +429,13 @@ class EndpointCore(Artifact):
         target_list = targets if isinstance(targets, list) else [targets]
         primary_target = target_list[0]
 
-        # For auto_inference, use shorter "auto_{target}" naming
-        # Otherwise use "{capture_name}_{target}"
-        prefix = "auto" if capture_name == "auto_inference" else capture_name
+        # For single-target models (99% of cases), just save with capture_name
+        # For multi-target models, save each as {prefix}_{target} plus primary as capture_name
+        is_multi_target = len(target_list) > 1
+
+        if is_multi_target:
+            prefix = "auto" if capture_name == "auto_inference" else capture_name
 
-        # Save results for each target, plus primary target with original capture_name
         for target in target_list:
             # Drop rows with NaN target values for metrics/plots
             target_df = prediction_df.dropna(subset=[target])
@@ -455,21 +448,22 @@ class EndpointCore(Artifact):
             else:
                 target_metrics = pd.DataFrame()
 
-            # Save as {prefix}_{target}
-            target_capture_name = f"{prefix}_{target}"
-            description = target_capture_name.replace("_", " ").title()
-            self._capture_inference_results(
-                target_capture_name,
-                target_df,
-                target,
-                model.model_type,
-                target_metrics,
-                description,
-                features,
-                id_column,
-            )
+            if is_multi_target:
+                # Multi-target: save as {prefix}_{target}
+                target_capture_name = f"{prefix}_{target}"
+                description = target_capture_name.replace("_", " ").title()
+                self._capture_inference_results(
+                    target_capture_name,
+                    target_df,
+                    target,
+                    model.model_type,
+                    target_metrics,
+                    description,
+                    features,
+                    id_column,
+                )
 
-            # Also save primary target with original capture_name for backward compatibility
+            # Save primary target (or single target) with original capture_name
             if target == primary_target:
                 self._capture_inference_results(
                     capture_name,
@@ -482,19 +476,16 @@ class EndpointCore(Artifact):
                     id_column,
                 )
 
-        # For UQ Models we also capture the uncertainty metrics
-        if model.model_type in [ModelType.UQ_REGRESSOR]:
+        # Capture uncertainty metrics if prediction_std is available (UQ, ChemProp, etc.)
+        if "prediction_std" in prediction_df.columns:
             metrics = uq_metrics(prediction_df, primary_target)
             self.param_store.upsert(f"/workbench/models/{model.name}/inference/{capture_name}", metrics)
 
         # Return the prediction DataFrame
         return prediction_df
 
-    def cross_fold_inference(self, nfolds: int = 5) -> pd.DataFrame:
-        """Run cross-fold inference (only works for XGBoost models)
-
-        Args:
-            nfolds (int): Number of folds to use for cross-fold (default: 5)
+    def cross_fold_inference(self) -> pd.DataFrame:
+        """Pull cross-fold inference training results for this Endpoint's model
 
         Returns:
             pd.DataFrame: A DataFrame with cross fold predictions
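The nfolds argument is gone because cross-fold predictions are no longer recomputed against the endpoint; they are pulled from results saved during training. A minimal usage sketch, assuming the public Endpoint class in workbench.api simply exposes this EndpointCore method (the endpoint name is illustrative):

from workbench.api import Endpoint  # assumed public wrapper around EndpointCore

end = Endpoint("aqsol-regression-end")  # illustrative endpoint name
oof_df = end.cross_fold_inference()     # no nfolds parameter in 0.8.213
print(oof_df.head())                    # out-of-fold predictions saved at training time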
@@ -506,8 +497,8 @@ class EndpointCore(Artifact):
         # Compute CrossFold (Metrics and Prediction Dataframe)
         # For PyTorch and ChemProp, pull pre-computed CV results from training
         if model.model_framework in [ModelFramework.UNKNOWN, ModelFramework.XGBOOST]:
-            cross_fold_metrics, out_of_fold_df = xgboost_cross_fold(model, nfolds=nfolds)
-        elif model.model_framework == ModelFramework.PYTORCH_TABULAR:
+            cross_fold_metrics, out_of_fold_df = xgboost_pull_cv(model)
+        elif model.model_framework == ModelFramework.PYTORCH:
             cross_fold_metrics, out_of_fold_df = pytorch_pull_cv(model)
         elif model.model_framework == ModelFramework.CHEMPROP:
             cross_fold_metrics, out_of_fold_df = chemprop_pull_cv(model)
@@ -534,48 +525,24 @@ class EndpointCore(Artifact):
         fs = FeatureSetCore(model.get_input())
         id_column = fs.id_column
 
-        # Is this a UQ Model? If so, run full inference and merge the results
-        additional_columns = []
-        if model.model_framework == ModelFramework.XGBOOST and model_type == ModelType.UQ_REGRESSOR:
-            self.log.important("UQ Regressor detected, running full inference to get uncertainty estimates...")
-
-            # Get the training view dataframe for inference
-            training_df = model.training_view().pull_dataframe()
-
-            # Run inference on the endpoint to get UQ outputs
-            uq_df = self.inference(training_df)
-
-            # Identify UQ-specific columns (quantiles, prediction_std, *_pred_std)
-            uq_columns = [
-                col
-                for col in uq_df.columns
-                if col.startswith("q_") or col == "prediction_std" or col.endswith("_pred_std") or col == "confidence"
-            ]
-
-            # Merge UQ columns with out-of-fold predictions
-            if uq_columns:
-                # Keep id_column and UQ columns, drop 'prediction' to avoid conflict when merging
-                uq_df = uq_df[[id_column] + uq_columns]
-
-                # Drop duplicates in uq_df based on id_column
-                uq_df = uq_df.drop_duplicates(subset=[id_column])
-
-                # Merge UQ columns into out_of_fold_df
-                out_of_fold_df = pd.merge(out_of_fold_df, uq_df, on=id_column, how="left")
-                additional_columns = uq_columns
-                self.log.info(f"Added UQ columns: {', '.join(additional_columns)}")
-
-            # Also compute UQ metrics (use first target for multi-target models)
-            primary_target = targets[0] if isinstance(targets, list) else targets
-            metrics = uq_metrics(out_of_fold_df, primary_target)
-            self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
-
         # Normalize targets to a list for iteration
         target_list = targets if isinstance(targets, list) else [targets]
         primary_target = target_list[0]
 
-        # Save results for each target as cv_{target}
-        # Also save primary target as "full_cross_fold" for backward compatibility
+        # Collect UQ columns (q_*, confidence) for additional tracking
+        additional_columns = [col for col in out_of_fold_df.columns if col.startswith("q_") or col == "confidence"]
+        if additional_columns:
+            self.log.info(f"UQ columns from training: {', '.join(additional_columns)}")
+
+        # Capture uncertainty metrics if prediction_std is available (UQ, ChemProp, etc.)
+        if "prediction_std" in out_of_fold_df.columns:
+            metrics = uq_metrics(out_of_fold_df, primary_target)
+            self.param_store.upsert(f"/workbench/models/{model.name}/inference/full_cross_fold", metrics)
+
+        # For single-target models (99% of cases), just save as "full_cross_fold"
+        # For multi-target models, save each as cv_{target} plus primary as "full_cross_fold"
+        is_multi_target = len(target_list) > 1
+
         for target in target_list:
             # Drop rows with NaN target values for metrics/plots
             target_df = out_of_fold_df.dropna(subset=[target])
@@ -588,21 +555,22 @@ class EndpointCore(Artifact):
             else:
                 target_metrics = pd.DataFrame()
 
-            # Save as cv_{target}
-            capture_name = f"cv_{target}"
-            description = capture_name.replace("_", " ").title()
-            self._capture_inference_results(
-                capture_name,
-                target_df,
-                target,
-                model_type,
-                target_metrics,
-                description,
-                features=additional_columns,
-                id_column=id_column,
-            )
+            if is_multi_target:
+                # Multi-target: save as cv_{target}
+                capture_name = f"cv_{target}"
+                description = capture_name.replace("_", " ").title()
+                self._capture_inference_results(
+                    capture_name,
+                    target_df,
+                    target,
+                    model_type,
+                    target_metrics,
+                    description,
+                    features=additional_columns,
+                    id_column=id_column,
+                )
 
-            # Also save primary target as "full_cross_fold" for backward compatibility
+            # Save primary target (or single target) as "full_cross_fold"
             if target == primary_target:
                 self._capture_inference_results(
                     "full_cross_fold",
@@ -960,29 +928,9 @@ class EndpointCore(Artifact):
             self.log.warning("Dropping NaN rows for metric computation.")
             prediction_df = prediction_df.dropna(subset=[target_column, "prediction"])
 
-        # Compute the metrics
+        # Compute the metrics using shared utilities
         try:
-            y_true = prediction_df[target_column]
-            y_pred = prediction_df["prediction"]
-
-            mae = mean_absolute_error(y_true, y_pred)
-            rmse = np.sqrt(mean_squared_error(y_true, y_pred))
-            r2 = r2_score(y_true, y_pred)
-            # Mean Absolute Percentage Error
-            mape = np.mean(np.where(y_true != 0, np.abs((y_true - y_pred) / y_true), np.abs(y_true - y_pred))) * 100
-            # Median Absolute Error
-            medae = median_absolute_error(y_true, y_pred)
-
-            # Organize and return the metrics
-            metrics = {
-                "MAE": round(mae, 3),
-                "RMSE": round(rmse, 3),
-                "R2": round(r2, 3),
-                "MAPE": round(mape, 3),
-                "MedAE": round(medae, 3),
-                "NumRows": len(prediction_df),
-            }
-            return pd.DataFrame.from_records([metrics])
+            return compute_regression_metrics(prediction_df, target_column)
         except Exception as e:
            self.log.warning(f"Error computing regression metrics: {str(e)}")
            return pd.DataFrame()
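The inline regression math moves into the new workbench/utils/metrics_utils.py (file 31 above), which is not shown in this diff. The sketch below is only a reconstruction of what compute_regression_metrics presumably does, pieced together from the code it replaces:

import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, median_absolute_error, r2_score


def compute_regression_metrics(prediction_df: pd.DataFrame, target_column: str) -> pd.DataFrame:
    # Actual vs. predicted values
    y_true = prediction_df[target_column]
    y_pred = prediction_df["prediction"]

    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    # MAPE, falling back to absolute error where the true value is zero
    mape = np.mean(np.where(y_true != 0, np.abs((y_true - y_pred) / y_true), np.abs(y_true - y_pred))) * 100
    medae = median_absolute_error(y_true, y_pred)

    metrics = {
        "MAE": round(mae, 3),
        "RMSE": round(rmse, 3),
        "R2": round(r2, 3),
        "MAPE": round(mape, 3),
        "MedAE": round(medae, 3),
        "NumRows": len(prediction_df),
    }
    return pd.DataFrame.from_records([metrics])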
@@ -1065,46 +1013,8 @@ class EndpointCore(Artifact):
         else:
             self.validate_proba_columns(prediction_df, class_labels)
 
-        # Calculate precision, recall, f1, and support, handling zero division
-        scores = precision_recall_fscore_support(
-            prediction_df[target_column],
-            prediction_df["prediction"],
-            average=None,
-            labels=class_labels,
-            zero_division=0,
-        )
-
-        # Identify the probability columns and keep them as a Pandas DataFrame
-        proba_columns = [f"{label}_proba" for label in class_labels]
-        y_score = prediction_df[proba_columns]
-
-        # One-hot encode the true labels using all class labels (fit with class_labels)
-        encoder = OneHotEncoder(categories=[class_labels], sparse_output=False)
-        y_true = encoder.fit_transform(prediction_df[[target_column]])
-
-        # Calculate ROC AUC per label and handle exceptions for missing classes
-        roc_auc_per_label = []
-        for i, label in enumerate(class_labels):
-            try:
-                roc_auc = roc_auc_score(y_true[:, i], y_score.iloc[:, i])
-            except ValueError as e:
-                self.log.warning(f"ROC AUC calculation failed for label {label}.")
-                self.log.warning(f"{str(e)}")
-                roc_auc = 0.0
-            roc_auc_per_label.append(roc_auc)
-
-        # Put the scores into a DataFrame
-        score_df = pd.DataFrame(
-            {
-                target_column: class_labels,
-                "precision": scores[0],
-                "recall": scores[1],
-                "f1": scores[2],
-                "roc_auc": roc_auc_per_label,
-                "support": scores[3],
-            }
-        )
-        return score_df
+        # Compute the metrics using shared utilities (returns per-class + 'all' row)
+        return compute_classification_metrics(prediction_df, target_column, class_labels)
 
     def generate_confusion_matrix(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
         """Compute the confusion matrix for this Endpoint
workbench/core/artifacts/model_core.py

@@ -21,7 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
 from workbench.utils.deprecated_utils import deprecated
-from workbench.utils.model_utils import proximity_model
+from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters
 
 
 class ModelType(Enum):
@@ -44,7 +44,7 @@ class ModelFramework(Enum):
     SKLEARN = "sklearn"
     XGBOOST = "xgboost"
     LIGHTGBM = "lightgbm"
-    PYTORCH_TABULAR = "pytorch_tabular"
+    PYTORCH = "pytorch"
     CHEMPROP = "chemprop"
     TRANSFORMER = "transformer"
     UNKNOWN = "unknown"
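Renaming the enum member (PYTORCH_TABULAR → PYTORCH) is a breaking change for any caller that dispatches on the framework. A small sketch of the kind of check that has to be updated, mirroring the GPU-selection hunk further down (the helper name is made up):

from workbench.api import ModelFramework  # import path taken from the ModelToEndpoint hunk below


def needs_gpu(framework: ModelFramework) -> bool:
    # PYTORCH replaces the former PYTORCH_TABULAR member in 0.8.213
    return framework in (ModelFramework.CHEMPROP, ModelFramework.PYTORCH)


print(needs_gpu(ModelFramework.XGBOOST))  # False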
@@ -263,11 +263,11 @@ class ModelCore(Artifact):
         else:
             self.log.important(f"No inference data found for {self.model_name}!")
 
-    def get_inference_metrics(self, capture_name: str = "any") -> Union[pd.DataFrame, None]:
+    def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the inference performance metrics for this model
 
         Args:
-            capture_name (str, optional): Specific capture_name (default: "any")
+            capture_name (str, optional): Specific capture_name (default: "auto")
         Returns:
             pd.DataFrame: DataFrame of the Model Metrics
 
@@ -275,7 +275,7 @@ class ModelCore(Artifact):
             If a capture_name isn't specified this will try to the 'first' available metrics
         """
         # Try to get the auto_capture 'training_holdout' or the training
-        if capture_name == "any":
+        if capture_name == "auto":
             metric_list = self.list_inference_runs()
             if metric_list:
                 return self.get_inference_metrics(metric_list[0])
@@ -303,11 +303,11 @@ class ModelCore(Artifact):
             self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
             return None
 
-    def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the confusion_matrix for this model
 
         Args:
-            capture_name (str, optional): Specific capture_name or "training" (default: "latest")
+            capture_name (str, optional): Specific capture_name or "training" (default: "auto")
         Returns:
             pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
         """
@@ -319,7 +319,7 @@ class ModelCore(Artifact):
             raise ValueError(error_msg)
 
         # Grab the metrics from the Workbench Metadata (try inference first, then training)
-        if capture_name == "latest":
+        if capture_name == "auto":
             cm = self.confusion_matrix("auto_inference")
             return cm if cm is not None else self.confusion_matrix("model_training")
 
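Both lookup methods now use "auto" as the sentinel default instead of "any"/"latest". A hedged call-site sketch, assuming the public Model class in workbench.api exposes these ModelCore methods (the model name is illustrative):

from workbench.api import Model  # assumed public wrapper around ModelCore

model = Model("wine-classification")      # illustrative model name
metrics = model.get_inference_metrics()   # default capture_name is now "auto"
cm = model.confusion_matrix()             # "auto" tries auto_inference, then falls back to model_training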
@@ -541,6 +541,17 @@ class ModelCore(Artifact):
         else:
             self.log.error(f"Model {self.model_name} is not a classifier!")
 
+    def summary(self) -> dict:
+        """Summary information about this Model
+
+        Returns:
+            dict: Dictionary of summary information about this Model
+        """
+        self.log.info("Computing Model Summary...")
+        summary = super().summary()
+        summary["hyperparameters"] = get_model_hyperparameters(self)
+        return summary
+
     def details(self) -> dict:
         """Additional Details about this Model
 
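The new summary() override simply augments the base Artifact summary with hyperparameters pulled via get_model_hyperparameters. A usage sketch under the same assumption about the public Model wrapper:

from workbench.api import Model  # assumed public wrapper around ModelCore

model = Model("aqsol-regression")  # illustrative model name
info = model.summary()             # base Artifact summary plus a "hyperparameters" entry
print(info["hyperparameters"])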
@@ -565,6 +576,7 @@ class ModelCore(Artifact):
         details["status"] = self.latest_model["ModelPackageStatus"]
         details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
         details["image"] = self.container_image().split("/")[-1]  # Shorten the image uri
+        details["hyperparameters"] = get_model_hyperparameters(self)
 
         # Grab the inference and container info
         inference_spec = self.latest_model["InferenceSpecification"]
@@ -575,16 +587,6 @@ class ModelCore(Artifact):
         details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
         details["content_types"] = inference_spec["SupportedContentTypes"]
         details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
-        details["model_metrics"] = self.get_inference_metrics()
-        if self.model_type == ModelType.CLASSIFIER:
-            details["confusion_matrix"] = self.confusion_matrix()
-            details["predictions"] = None
-        elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-            details["confusion_matrix"] = None
-            details["predictions"] = self.get_inference_predictions()
-        else:
-            details["confusion_matrix"] = None
-            details["predictions"] = None
 
         # Grab the inference metadata
         details["inference_meta"] = self.get_inference_metadata()
@@ -904,7 +906,7 @@ class ModelCore(Artifact):
         """
         if prox_model_name is None:
             prox_model_name = self.model_name + "-prox"
-        return proximity_model(self, prox_model_name, track_columns=track_columns)
+        return published_proximity_model(self, prox_model_name, track_columns=track_columns)
 
     def delete(self):
         """Delete the Model Packages and the Model Group"""
workbench/core/transforms/features_to_model/features_to_model.py

@@ -228,7 +228,7 @@ class FeaturesToModel(Transform):
             raise ValueError(msg)
 
         # Dynamically create the metric definitions
-        metrics = ["precision", "recall", "f1"]
+        metrics = ["precision", "recall", "f1", "support"]
         metric_definitions = []
         for t in self.class_labels:
             for m in metrics:
@@ -254,7 +254,7 @@ class FeaturesToModel(Transform):
         image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image)
 
         # Use GPU instance for ChemProp/PyTorch, CPU for others
-        if self.model_framework in [ModelFramework.CHEMPROP, ModelFramework.PYTORCH_TABULAR]:
+        if self.model_framework in [ModelFramework.CHEMPROP, ModelFramework.PYTORCH]:
             train_instance_type = "ml.g6.xlarge"  # NVIDIA L4 GPU, ~$0.80/hr
             self.log.important(f"Using GPU instance {train_instance_type} for {self.model_framework.value}")
         else:
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py

@@ -106,7 +106,7 @@ class ModelToEndpoint(Transform):
         from workbench.api import ModelFramework
 
         self.log.info(f"Model Framework: {workbench_model.model_framework}")
-        if workbench_model.model_framework in [ModelFramework.PYTORCH_TABULAR, ModelFramework.CHEMPROP]:
+        if workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]:
             if mem_size < 4096:
                 self.log.important(
                     f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)"