workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
21
21
  from workbench.utils.s3_utils import compute_s3_object_hash
22
22
  from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
23
23
  from workbench.utils.deprecated_utils import deprecated
24
+ from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters
24
25
 
25
26
 
26
27
  class ModelType(Enum):
@@ -29,63 +30,64 @@ class ModelType(Enum):
29
30
  CLASSIFIER = "classifier"
30
31
  REGRESSOR = "regressor"
31
32
  CLUSTERER = "clusterer"
32
- TRANSFORMER = "transformer"
33
33
  PROXIMITY = "proximity"
34
34
  PROJECTION = "projection"
35
35
  UQ_REGRESSOR = "uq_regressor"
36
36
  ENSEMBLE_REGRESSOR = "ensemble_regressor"
37
+ TRANSFORMER = "transformer"
38
+ UNKNOWN = "unknown"
39
+
40
+
41
+ class ModelFramework(Enum):
42
+ """Enumerated Types for Workbench Model Frameworks"""
43
+
44
+ SKLEARN = "sklearn"
45
+ XGBOOST = "xgboost"
46
+ LIGHTGBM = "lightgbm"
47
+ PYTORCH = "pytorch"
48
+ CHEMPROP = "chemprop"
49
+ TRANSFORMER = "transformer"
50
+ META = "meta"
37
51
  UNKNOWN = "unknown"
38
52
 
39
53
 
40
54
  class ModelImages:
41
55
  """Class for retrieving workbench inference images"""
42
56
 
43
- image_uris = {
44
- # US East 1 images
45
- ("us-east-1", "training", "0.1", "x86_64"): (
46
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
47
- ),
48
- ("us-east-1", "inference", "0.1", "x86_64"): (
49
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
50
- ),
51
- ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
52
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
53
- ),
54
- ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
55
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
56
- ),
57
- # US West 2 images
58
- ("us-west-2", "training", "0.1", "x86_64"): (
59
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
60
- ),
61
- ("us-west-2", "inference", "0.1", "x86_64"): (
62
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
63
- ),
64
- ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
65
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
66
- ),
67
- ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
68
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
69
- ),
70
- # ARM64 images
71
- # Meta Endpoint inference images
72
- ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
73
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
74
- ),
75
- ("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
76
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
77
- ),
57
+ # Account ID
58
+ ACCOUNT_ID = "507740646243"
59
+
60
+ # Image name mappings
61
+ IMAGE_NAMES = {
62
+ "training": "py312-general-ml-training",
63
+ "inference": "py312-general-ml-inference",
64
+ "pytorch_training": "py312-pytorch-training",
65
+ "pytorch_inference": "py312-pytorch-inference",
66
+ "meta_training": "py312-meta-training",
67
+ "meta_inference": "py312-meta-inference",
78
68
  }
79
69
 
80
70
  @classmethod
81
- def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
82
- key = (region, image_type, version, architecture)
83
- if key in cls.image_uris:
84
- return cls.image_uris[key]
85
- else:
86
- raise ValueError(
87
- f"No matching image found for region: {region}, image_type: {image_type}, version: {version}"
88
- )
71
+ def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
72
+ """
73
+ Dynamically construct ECR image URI.
74
+
75
+ Args:
76
+ region: AWS region (e.g., 'us-east-1', 'us-west-2')
77
+ image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
78
+ version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
79
+ architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
80
+
81
+ Returns:
82
+ ECR image URI string
83
+ """
84
+ if image_type not in cls.IMAGE_NAMES:
85
+ raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
86
+
87
+ image_name = cls.IMAGE_NAMES[image_type]
88
+ uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
89
+
90
+ return uri
89
91
 
90
92
 
91
93
  class ModelCore(Artifact):
@@ -99,11 +101,10 @@ class ModelCore(Artifact):
99
101
  ```
100
102
  """
101
103
 
102
- def __init__(self, model_name: str, model_type: ModelType = None, **kwargs):
104
+ def __init__(self, model_name: str, **kwargs):
103
105
  """ModelCore Initialization
104
106
  Args:
105
107
  model_name (str): Name of Model in Workbench.
106
- model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
107
108
  **kwargs: Additional keyword arguments
108
109
  """
109
110
 
@@ -137,10 +138,8 @@ class ModelCore(Artifact):
137
138
  self.latest_model = self.model_meta["ModelPackageList"][0]
138
139
  self.description = self.latest_model.get("ModelPackageDescription", "-")
139
140
  self.training_job_name = self._extract_training_job_name()
140
- if model_type:
141
- self._set_model_type(model_type)
142
- else:
143
- self.model_type = self._get_model_type()
141
+ self.model_type = self._get_model_type()
142
+ self.model_framework = self._get_model_framework()
144
143
  except (IndexError, KeyError):
145
144
  self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
146
145
  return
@@ -266,21 +265,25 @@ class ModelCore(Artifact):
266
265
  else:
267
266
  self.log.important(f"No inference data found for {self.model_name}!")
268
267
 
269
- def get_inference_metrics(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
268
+ def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
270
269
  """Retrieve the inference performance metrics for this model
271
270
 
272
271
  Args:
273
- capture_name (str, optional): Specific capture_name or "training" (default: "latest")
272
+ capture_name (str, optional): Specific capture_name (default: "auto")
274
273
  Returns:
275
274
  pd.DataFrame: DataFrame of the Model Metrics
276
275
 
277
276
  Note:
278
- If a capture_name isn't specified this will try to return something reasonable
277
+ If a capture_name isn't specified this will try to return the 'first' available metrics
279
278
  """
280
279
  # Try to get the auto_capture 'training_holdout' or the training
281
- if capture_name == "latest":
282
- metrics_df = self.get_inference_metrics("auto_inference")
283
- return metrics_df if metrics_df is not None else self.get_inference_metrics("model_training")
280
+ if capture_name == "auto":
281
+ metric_list = self.list_inference_runs()
282
+ if metric_list:
283
+ return self.get_inference_metrics(metric_list[0])
284
+ else:
285
+ self.log.warning(f"No performance metrics found for {self.model_name}!")
286
+ return None
284
287
 
285
288
  # Grab the metrics captured during model training (could return None)
286
289
  if capture_name == "model_training":
@@ -302,11 +305,11 @@ class ModelCore(Artifact):
302
305
  self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
303
306
  return None
304
307
 
305
- def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
308
+ def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
306
309
  """Retrieve the confusion_matrix for this model
307
310
 
308
311
  Args:
309
- capture_name (str, optional): Specific capture_name or "training" (default: "latest")
312
+ capture_name (str, optional): Specific capture_name or "training" (default: "auto")
310
313
  Returns:
311
314
  pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
312
315
  """
@@ -318,7 +321,7 @@ class ModelCore(Artifact):
318
321
  raise ValueError(error_msg)
319
322
 
320
323
  # Grab the metrics from the Workbench Metadata (try inference first, then training)
321
- if capture_name == "latest":
324
+ if capture_name == "auto":
322
325
  cm = self.confusion_matrix("auto_inference")
323
326
  return cm if cm is not None else self.confusion_matrix("model_training")
324
327
 
@@ -540,6 +543,17 @@ class ModelCore(Artifact):
540
543
  else:
541
544
  self.log.error(f"Model {self.model_name} is not a classifier!")
542
545
 
546
+ def summary(self) -> dict:
547
+ """Summary information about this Model
548
+
549
+ Returns:
550
+ dict: Dictionary of summary information about this Model
551
+ """
552
+ self.log.info("Computing Model Summary...")
553
+ summary = super().summary()
554
+ summary["hyperparameters"] = get_model_hyperparameters(self)
555
+ return summary
556
+
543
557
  def details(self) -> dict:
544
558
  """Additional Details about this Model
545
559
 
@@ -564,6 +578,7 @@ class ModelCore(Artifact):
564
578
  details["status"] = self.latest_model["ModelPackageStatus"]
565
579
  details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
566
580
  details["image"] = self.container_image().split("/")[-1] # Shorten the image uri
581
+ details["hyperparameters"] = get_model_hyperparameters(self)
567
582
 
568
583
  # Grab the inference and container info
569
584
  inference_spec = self.latest_model["InferenceSpecification"]
@@ -574,16 +589,6 @@ class ModelCore(Artifact):
574
589
  details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
575
590
  details["content_types"] = inference_spec["SupportedContentTypes"]
576
591
  details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
577
- details["model_metrics"] = self.get_inference_metrics()
578
- if self.model_type == ModelType.CLASSIFIER:
579
- details["confusion_matrix"] = self.confusion_matrix()
580
- details["predictions"] = None
581
- elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
582
- details["confusion_matrix"] = None
583
- details["predictions"] = self.get_inference_predictions()
584
- else:
585
- details["confusion_matrix"] = None
586
- details["predictions"] = None
587
592
 
588
593
  # Grab the inference metadata
589
594
  details["inference_meta"] = self.get_inference_metadata()
@@ -591,6 +596,24 @@ class ModelCore(Artifact):
591
596
  # Return the details
592
597
  return details
593
598
 
599
+ # Training View for this model
600
+ def training_view(self):
601
+ """Get the training view for this model"""
602
+ from workbench.core.artifacts.feature_set_core import FeatureSetCore
603
+ from workbench.core.views import View
604
+
605
+ # Grab our FeatureSet
606
+ fs = FeatureSetCore(self.get_input())
607
+
608
+ # See if we have a training view for this model
609
+ my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
610
+ view = View(fs, my_model_training_view, auto_create_view=False)
611
+ if view.exists():
612
+ return view
613
+ else:
614
+ self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
615
+ return fs.view("training")
616
+
594
617
  # Pipeline for this model
595
618
  def get_pipeline(self) -> str:
596
619
  """Get the pipeline for this model"""
@@ -854,21 +877,13 @@ class ModelCore(Artifact):
854
877
  return self.df_store.get(f"/workbench/models/{self.name}/shap_data")
855
878
  else:
856
879
  # Loop over the SHAP data and return a dict of DataFrames
857
- shap_dfs = self.df_store.list_subfiles(f"/workbench/models/{self.name}/shap_data")
880
+ shap_dfs = self.df_store.list(f"/workbench/models/{self.name}/shap_data")
858
881
  shap_data = {}
859
882
  for df_location in shap_dfs:
860
883
  key = df_location.split("/")[-1]
861
884
  shap_data[key] = self.df_store.get(df_location)
862
885
  return shap_data or None
863
886
 
864
- def cross_folds(self) -> dict:
865
- """Retrieve the cross-fold inference results(only works for XGBoost models)
866
-
867
- Returns:
868
- dict: Dictionary with the cross-fold inference results
869
- """
870
- return self.param_store.get(f"/workbench/models/{self.name}/inference/cross_fold")
871
-
872
887
  def supported_inference_instances(self) -> Optional[list]:
873
888
  """Retrieve the supported endpoint inference instance types
874
889
 
@@ -881,10 +896,24 @@ class ModelCore(Artifact):
881
896
  except (KeyError, IndexError, TypeError):
882
897
  return None
883
898
 
899
+ def publish_prox_model(self, prox_model_name: str = None, include_all_columns: bool = False):
900
+ """Create and publish a Proximity Model for this Model
901
+
902
+ Args:
903
+ prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
904
+ include_all_columns (bool): Include all DataFrame columns in results (default: False)
905
+
906
+ Returns:
907
+ Model: The published Proximity Model
908
+ """
909
+ if prox_model_name is None:
910
+ prox_model_name = self.model_name + "-prox"
911
+ return published_proximity_model(self, prox_model_name, include_all_columns=include_all_columns)
912
+
884
913
  def delete(self):
885
914
  """Delete the Model Packages and the Model Group"""
886
915
  if not self.exists():
887
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
916
+ self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")
888
917
 
889
918
  # Call the Class Method to delete the Model Group
890
919
  ModelCore.managed_delete(model_group_name=self.name)
@@ -960,6 +989,27 @@ class ModelCore(Artifact):
960
989
  self.log.warning(f"Could not determine model type for {self.model_name}!")
961
990
  return ModelType.UNKNOWN
962
991
 
992
+ def _set_model_framework(self, model_framework: ModelFramework):
993
+ """Internal: Set the Model Framework for this Model"""
994
+ self.model_framework = model_framework
995
+ self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
996
+ self.remove_health_tag("model_framework_unknown")
997
+
998
+ def _get_model_framework(self) -> ModelFramework:
999
+ """Internal: Query the Workbench Metadata to get the model framework
1000
+ Returns:
1001
+ ModelFramework: The ModelFramework of this Model
1002
+ Notes:
1003
+ This is an internal method that should not be called directly
1004
+ Use the model_framework attribute instead
1005
+ """
1006
+ model_framework = self.workbench_meta().get("workbench_model_framework")
1007
+ try:
1008
+ return ModelFramework(model_framework)
1009
+ except ValueError:
1010
+ self.log.warning(f"Could not determine model framework for {self.model_name}!")
1011
+ return ModelFramework.UNKNOWN
1012
+
963
1013
  def _load_training_metrics(self):
964
1014
  """Internal: Retrieve the training metrics and Confusion Matrix for this model
965
1015
  and load the data into the Workbench Metadata
@@ -1151,13 +1201,11 @@ if __name__ == "__main__":
1151
1201
  # Grab a ModelCore object and pull some information from it
1152
1202
  my_model = ModelCore("abalone-regression")
1153
1203
 
1154
- # Call the various methods
1155
-
1156
1204
  # Let's do a check/validation of the Model
1157
1205
  print(f"Model Check: {my_model.exists()}")
1158
1206
 
1159
1207
  # Make sure the model is 'ready'
1160
- # my_model.onboard()
1208
+ my_model.onboard()
1161
1209
 
1162
1210
  # Get the ARN of the Model Group
1163
1211
  print(f"Model Group ARN: {my_model.group_arn()}")
@@ -1223,5 +1271,10 @@ if __name__ == "__main__":
1223
1271
  # Delete the Model
1224
1272
  # ModelCore.managed_delete("wine-classification")
1225
1273
 
1274
+ # Check the training view logic
1275
+ model = ModelCore("wine-class-test-251112-BW")
1276
+ training_view = model.training_view()
1277
+ print(f"Training View Name: {training_view.name}")
1278
+
1226
1279
  # Check for a model that doesn't exist
1227
1280
  my_model = ModelCore("empty-model-group")
@@ -0,0 +1,98 @@
1
+ """ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store."""
2
+
3
+ import logging
4
+
5
+ # Workbench Imports
6
+ from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
7
+
8
+ # Workbench Bridges Import
9
+ from workbench_bridges.api import ParameterStore as BridgesParameterStore
10
+
11
+
12
+ class ParameterStoreCore(BridgesParameterStore):
13
+ """ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store.
14
+
15
+ Common Usage:
16
+ ```python
17
+ params = ParameterStoreCore()
18
+
19
+ # List Parameters
20
+ params.list()
21
+
22
+ ['/workbench/abalone_info',
23
+ '/workbench/my_data',
24
+ '/workbench/test',
25
+ '/workbench/pipelines/my_pipeline']
26
+
27
+ # Add Key
28
+ params.upsert("key", "value")
29
+ value = params.get("key")
30
+
31
+ # Add any data (lists, dictionaries, etc..)
32
+ my_data = {"key": "value", "number": 4.2, "list": [1,2,3]}
33
+ params.upsert("my_data", my_data)
34
+
35
+ # Retrieve data
36
+ return_value = params.get("my_data")
37
+ pprint(return_value)
38
+
39
+ {'key': 'value', 'list': [1, 2, 3], 'number': 4.2}
40
+
41
+ # Delete parameters
42
+ param_store.delete("my_data")
43
+ ```
44
+ """
45
+
46
+ def __init__(self):
47
+ """ParameterStoreCore Init Method"""
48
+ session = AWSAccountClamp().boto3_session
49
+
50
+ # Initialize parent with workbench config
51
+ super().__init__(boto3_session=session)
52
+ self.log = logging.getLogger("workbench")
53
+
54
+
55
+ if __name__ == "__main__":
56
+ """Exercise the ParameterStoreCore Class"""
57
+
58
+ # Create a ParameterStoreCore manager
59
+ param_store = ParameterStoreCore()
60
+
61
+ # List the parameters
62
+ print("Listing Parameters...")
63
+ print(param_store.list())
64
+
65
+ # Add a new parameter
66
+ param_store.upsert("/workbench/test", "value")
67
+
68
+ # Get the parameter
69
+ print(f"Getting parameter 'test': {param_store.get('/workbench/test')}")
70
+
71
+ # Add a dictionary as a parameter
72
+ sample_dict = {"key": "str_value", "awesome_value": 4.2}
73
+ param_store.upsert("/workbench/my_data", sample_dict)
74
+
75
+ # Retrieve the parameter as a dictionary
76
+ retrieved_value = param_store.get("/workbench/my_data")
77
+ print("Retrieved value:", retrieved_value)
78
+
79
+ # List the parameters
80
+ print("Listing Parameters...")
81
+ print(param_store.list())
82
+
83
+ # List the parameters with a prefix
84
+ print("Listing Parameters with prefix '/workbench':")
85
+ print(param_store.list("/workbench"))
86
+
87
+ # Delete the parameters
88
+ param_store.delete("/workbench/test")
89
+ param_store.delete("/workbench/my_data")
90
+
91
+ # Out of scope tests
92
+ param_store.upsert("test", "value")
93
+ param_store.delete("test")
94
+
95
+ # Recursive delete test
96
+ param_store.upsert("/workbench/test/test1", "value1")
97
+ param_store.upsert("/workbench/test/test2", "value2")
98
+ param_store.delete_recursive("workbench/test/")
@@ -7,7 +7,6 @@ import logging
7
7
  from typing import Union
8
8
  import pandas as pd
9
9
 
10
-
11
10
  # Workbench Imports
12
11
  from workbench.core.cloud_platform.aws.aws_meta import AWSMeta
13
12
 
@@ -123,7 +123,7 @@ class PipelineExecutor:
123
123
  if "model" in workbench_objects and (not subset or "endpoint" in subset):
124
124
  workbench_objects["model"].to_endpoint(**kwargs)
125
125
  endpoint = Endpoint(kwargs["name"])
126
- endpoint.auto_inference(capture=True)
126
+ endpoint.auto_inference()
127
127
 
128
128
  # Found something weird
129
129
  else: