workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
  from workbench.utils.s3_utils import compute_s3_object_hash
  from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
  from workbench.utils.deprecated_utils import deprecated
+ from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters


  class ModelType(Enum):
@@ -29,92 +30,64 @@ class ModelType(Enum):
  CLASSIFIER = "classifier"
  REGRESSOR = "regressor"
  CLUSTERER = "clusterer"
- TRANSFORMER = "transformer"
  PROXIMITY = "proximity"
  PROJECTION = "projection"
  UQ_REGRESSOR = "uq_regressor"
  ENSEMBLE_REGRESSOR = "ensemble_regressor"
+ TRANSFORMER = "transformer"
  UNKNOWN = "unknown"


- # Deprecated Images
- """
- # US East 1 images
- "py312-general-ml-training"
- ("us-east-1", "training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
- ),
- ("us-east-1", "inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
- ),
-
- # US West 2 images
- ("us-west-2", "training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
- ),
- ("us-west-2", "inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
- ),
-
- # ARM64 images
- ("us-east-1", "inference", "0.1", "arm64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
- ("us-west-2", "inference", "0.1", "arm64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
- ),
- """
+ class ModelFramework(Enum):
+ """Enumerated Types for Workbench Model Frameworks"""
+
+ SKLEARN = "sklearn"
+ XGBOOST = "xgboost"
+ LIGHTGBM = "lightgbm"
+ PYTORCH = "pytorch"
+ CHEMPROP = "chemprop"
+ TRANSFORMER = "transformer"
+ META = "meta"
+ UNKNOWN = "unknown"


  class ModelImages:
  """Class for retrieving workbench inference images"""

- image_uris = {
- # US East 1 images
- ("us-east-1", "training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
- ),
- ("us-east-1", "inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
- ),
- ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
- ),
- ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
- ),
- # US West 2 images
- ("us-west-2", "training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-training:0.1"
- ),
- ("us-west-2", "inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
- ),
- ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
- ),
- ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
- ),
- # ARM64 images
- # Meta Endpoint inference images
- ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
- ),
- ("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
- "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
- ),
+ # Account ID
+ ACCOUNT_ID = "507740646243"
+
+ # Image name mappings
+ IMAGE_NAMES = {
+ "training": "py312-general-ml-training",
+ "inference": "py312-general-ml-inference",
+ "pytorch_training": "py312-pytorch-training",
+ "pytorch_inference": "py312-pytorch-inference",
+ "meta_training": "py312-meta-training",
+ "meta_inference": "py312-meta-inference",
  }

  @classmethod
- def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
- key = (region, image_type, version, architecture)
- if key in cls.image_uris:
- return cls.image_uris[key]
- else:
- raise ValueError(
- f"No matching image found for region: {region}, image_type: {image_type}, version: {version}"
- )
+ def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
+ """
+ Dynamically construct ECR image URI.
+
+ Args:
+ region: AWS region (e.g., 'us-east-1', 'us-west-2')
+ image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
+ version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
+ architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
+
+ Returns:
+ ECR image URI string
+ """
+ if image_type not in cls.IMAGE_NAMES:
+ raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
+
+ image_name = cls.IMAGE_NAMES[image_type]
+ uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
+
+ return uri


  class ModelCore(Artifact):
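Note: the refactored get_image_uri() above no longer looks URIs up in a hard-coded table; it assembles them from ACCOUNT_ID and IMAGE_NAMES, with version now defaulting to "latest". A rough illustration of the new behavior (assuming ModelImages remains importable from workbench.core.artifacts.model_core; this snippet is not part of the diff):

    from workbench.core.artifacts.model_core import ModelImages

    # URI is built dynamically from ACCOUNT_ID, IMAGE_NAMES, and the requested version
    uri = ModelImages.get_image_uri("us-east-1", "inference", version="0.1")
    # -> "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:0.1"
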
@@ -128,11 +101,10 @@ class ModelCore(Artifact):
  ```
  """

- def __init__(self, model_name: str, model_type: ModelType = None, **kwargs):
+ def __init__(self, model_name: str, **kwargs):
  """ModelCore Initialization
  Args:
  model_name (str): Name of Model in Workbench.
- model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
  **kwargs: Additional keyword arguments
  """

@@ -166,10 +138,8 @@ class ModelCore(Artifact):
  self.latest_model = self.model_meta["ModelPackageList"][0]
  self.description = self.latest_model.get("ModelPackageDescription", "-")
  self.training_job_name = self._extract_training_job_name()
- if model_type:
- self._set_model_type(model_type)
- else:
- self.model_type = self._get_model_type()
+ self.model_type = self._get_model_type()
+ self.model_framework = self._get_model_framework()
  except (IndexError, KeyError):
  self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
  return
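Note the constructor change above: ModelCore no longer accepts a model_type argument; both the model type and the new model framework are resolved from the stored Workbench metadata. A minimal before/after sketch (the enum values shown are illustrative, not from the diff):

    # 0.8.174: callers could set the type at construction time
    # model = ModelCore("abalone-regression", model_type=ModelType.REGRESSOR)

    # 0.8.227: type and framework are read back from metadata
    model = ModelCore("abalone-regression")
    print(model.model_type)       # e.g. ModelType.REGRESSOR
    print(model.model_framework)  # e.g. ModelFramework.XGBOOST, or ModelFramework.UNKNOWN if unset
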
@@ -295,21 +265,25 @@
  else:
  self.log.important(f"No inference data found for {self.model_name}!")

- def get_inference_metrics(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+ def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
  """Retrieve the inference performance metrics for this model

  Args:
- capture_name (str, optional): Specific capture_name or "training" (default: "latest")
+ capture_name (str, optional): Specific capture_name (default: "auto")
  Returns:
  pd.DataFrame: DataFrame of the Model Metrics

  Note:
- If a capture_name isn't specified this will try to return something reasonable
+ If a capture_name isn't specified this will try to the 'first' available metrics
  """
  # Try to get the auto_capture 'training_holdout' or the training
- if capture_name == "latest":
- metrics_df = self.get_inference_metrics("auto_inference")
- return metrics_df if metrics_df is not None else self.get_inference_metrics("model_training")
+ if capture_name == "auto":
+ metric_list = self.list_inference_runs()
+ if metric_list:
+ return self.get_inference_metrics(metric_list[0])
+ else:
+ self.log.warning(f"No performance metrics found for {self.model_name}!")
+ return None

  # Grab the metrics captured during model training (could return None)
  if capture_name == "model_training":
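With the new "auto" default shown above, the method no longer hard-codes a fallback from "auto_inference" to "model_training"; it simply returns the metrics for the first run reported by list_inference_runs(), or None (with a warning) when no runs exist. Roughly (illustrative, not from the diff):

    model = ModelCore("abalone-regression")
    runs = model.list_inference_runs()        # e.g. ["auto_inference", "model_training"]
    metrics = model.get_inference_metrics()   # "auto" -> metrics for runs[0], or None if runs is empty
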
@@ -331,11 +305,11 @@
  self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
  return None

- def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+ def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
  """Retrieve the confusion_matrix for this model

  Args:
- capture_name (str, optional): Specific capture_name or "training" (default: "latest")
+ capture_name (str, optional): Specific capture_name or "training" (default: "auto")
  Returns:
  pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
  """
@@ -347,7 +321,7 @@
  raise ValueError(error_msg)

  # Grab the metrics from the Workbench Metadata (try inference first, then training)
- if capture_name == "latest":
+ if capture_name == "auto":
  cm = self.confusion_matrix("auto_inference")
  return cm if cm is not None else self.confusion_matrix("model_training")

@@ -569,6 +543,17 @@
  else:
  self.log.error(f"Model {self.model_name} is not a classifier!")

+ def summary(self) -> dict:
+ """Summary information about this Model
+
+ Returns:
+ dict: Dictionary of summary information about this Model
+ """
+ self.log.info("Computing Model Summary...")
+ summary = super().summary()
+ summary["hyperparameters"] = get_model_hyperparameters(self)
+ return summary
+
  def details(self) -> dict:
  """Additional Details about this Model

@@ -593,6 +578,7 @@
  details["status"] = self.latest_model["ModelPackageStatus"]
  details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
  details["image"] = self.container_image().split("/")[-1] # Shorten the image uri
+ details["hyperparameters"] = get_model_hyperparameters(self)

  # Grab the inference and container info
  inference_spec = self.latest_model["InferenceSpecification"]
@@ -603,16 +589,6 @@
  details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
  details["content_types"] = inference_spec["SupportedContentTypes"]
  details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
- details["model_metrics"] = self.get_inference_metrics()
- if self.model_type == ModelType.CLASSIFIER:
- details["confusion_matrix"] = self.confusion_matrix()
- details["predictions"] = None
- elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
- details["confusion_matrix"] = None
- details["predictions"] = self.get_inference_predictions()
- else:
- details["confusion_matrix"] = None
- details["predictions"] = None

  # Grab the inference metadata
  details["inference_meta"] = self.get_inference_metadata()
@@ -620,6 +596,24 @@
  # Return the details
  return details

+ # Training View for this model
+ def training_view(self):
+ """Get the training view for this model"""
+ from workbench.core.artifacts.feature_set_core import FeatureSetCore
+ from workbench.core.views import View
+
+ # Grab our FeatureSet
+ fs = FeatureSetCore(self.get_input())
+
+ # See if we have a training view for this model
+ my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
+ view = View(fs, my_model_training_view, auto_create_view=False)
+ if view.exists():
+ return view
+ else:
+ self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
+ return fs.view("training")
+
  # Pipeline for this model
  def get_pipeline(self) -> str:
  """Get the pipeline for this model"""
@@ -883,7 +877,7 @@
  return self.df_store.get(f"/workbench/models/{self.name}/shap_data")
  else:
  # Loop over the SHAP data and return a dict of DataFrames
- shap_dfs = self.df_store.list_subfiles(f"/workbench/models/{self.name}/shap_data")
+ shap_dfs = self.df_store.list(f"/workbench/models/{self.name}/shap_data")
  shap_data = {}
  for df_location in shap_dfs:
  key = df_location.split("/")[-1]
@@ -902,10 +896,24 @@
  except (KeyError, IndexError, TypeError):
  return None

+ def publish_prox_model(self, prox_model_name: str = None, include_all_columns: bool = False):
+ """Create and publish a Proximity Model for this Model
+
+ Args:
+ prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
+ include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
+ Returns:
+ Model: The published Proximity Model
+ """
+ if prox_model_name is None:
+ prox_model_name = self.model_name + "-prox"
+ return published_proximity_model(self, prox_model_name, include_all_columns=include_all_columns)
+
  def delete(self):
  """Delete the Model Packages and the Model Group"""
  if not self.exists():
- self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
+ self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")

  # Call the Class Method to delete the Model Group
  ModelCore.managed_delete(model_group_name=self.name)
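The new publish_prox_model() wrapper above defaults the proximity model name to "<model_name>-prox" and delegates to published_proximity_model() from model_utils. For example (the explicit name in the second call is illustrative, not from the diff):

    model = ModelCore("abalone-regression")
    prox = model.publish_prox_model()                 # publishes "abalone-regression-prox"
    prox = model.publish_prox_model("abalone-prox", include_all_columns=True)
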
@@ -981,6 +989,27 @@
  self.log.warning(f"Could not determine model type for {self.model_name}!")
  return ModelType.UNKNOWN

+ def _set_model_framework(self, model_framework: ModelFramework):
+ """Internal: Set the Model Framework for this Model"""
+ self.model_framework = model_framework
+ self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
+ self.remove_health_tag("model_framework_unknown")
+
+ def _get_model_framework(self) -> ModelFramework:
+ """Internal: Query the Workbench Metadata to get the model framework
+ Returns:
+ ModelFramework: The ModelFramework of this Model
+ Notes:
+ This is an internal method that should not be called directly
+ Use the model_framework attribute instead
+ """
+ model_framework = self.workbench_meta().get("workbench_model_framework")
+ try:
+ return ModelFramework(model_framework)
+ except ValueError:
+ self.log.warning(f"Could not determine model framework for {self.model_name}!")
+ return ModelFramework.UNKNOWN
+
  def _load_training_metrics(self):
  """Internal: Retrieve the training metrics and Confusion Matrix for this model
  and load the data into the Workbench Metadata
@@ -1172,13 +1201,11 @@ if __name__ == "__main__":
  # Grab a ModelCore object and pull some information from it
  my_model = ModelCore("abalone-regression")

- # Call the various methods
-
  # Let's do a check/validation of the Model
  print(f"Model Check: {my_model.exists()}")

  # Make sure the model is 'ready'
- # my_model.onboard()
+ my_model.onboard()

  # Get the ARN of the Model Group
  print(f"Model Group ARN: {my_model.group_arn()}")
@@ -1244,5 +1271,10 @@
  # Delete the Model
  # ModelCore.managed_delete("wine-classification")

+ # Check the training view logic
+ model = ModelCore("wine-class-test-251112-BW")
+ training_view = model.training_view()
+ print(f"Training View Name: {training_view.name}")
+
  # Check for a model that doesn't exist
  my_model = ModelCore("empty-model-group")
@@ -0,0 +1,98 @@
+ """ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store."""
+
+ import logging
+
+ # Workbench Imports
+ from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+ # Workbench Bridges Import
+ from workbench_bridges.api import ParameterStore as BridgesParameterStore
+
+
+ class ParameterStoreCore(BridgesParameterStore):
+ """ParameterStoreCore: Manages Workbench parameters in a Cloud Based Parameter Store.
+
+ Common Usage:
+ ```python
+ params = ParameterStoreCore()
+
+ # List Parameters
+ params.list()
+
+ ['/workbench/abalone_info',
+ '/workbench/my_data',
+ '/workbench/test',
+ '/workbench/pipelines/my_pipeline']
+
+ # Add Key
+ params.upsert("key", "value")
+ value = params.get("key")
+
+ # Add any data (lists, dictionaries, etc..)
+ my_data = {"key": "value", "number": 4.2, "list": [1,2,3]}
+ params.upsert("my_data", my_data)
+
+ # Retrieve data
+ return_value = params.get("my_data")
+ pprint(return_value)
+
+ {'key': 'value', 'list': [1, 2, 3], 'number': 4.2}
+
+ # Delete parameters
+ param_store.delete("my_data")
+ ```
+ """
+
+ def __init__(self):
+ """ParameterStoreCore Init Method"""
+ session = AWSAccountClamp().boto3_session
+
+ # Initialize parent with workbench config
+ super().__init__(boto3_session=session)
+ self.log = logging.getLogger("workbench")
+
+
+ if __name__ == "__main__":
+ """Exercise the ParameterStoreCore Class"""
+
+ # Create a ParameterStoreCore manager
+ param_store = ParameterStoreCore()
+
+ # List the parameters
+ print("Listing Parameters...")
+ print(param_store.list())
+
+ # Add a new parameter
+ param_store.upsert("/workbench/test", "value")
+
+ # Get the parameter
+ print(f"Getting parameter 'test': {param_store.get('/workbench/test')}")
+
+ # Add a dictionary as a parameter
+ sample_dict = {"key": "str_value", "awesome_value": 4.2}
+ param_store.upsert("/workbench/my_data", sample_dict)
+
+ # Retrieve the parameter as a dictionary
+ retrieved_value = param_store.get("/workbench/my_data")
+ print("Retrieved value:", retrieved_value)
+
+ # List the parameters
+ print("Listing Parameters...")
+ print(param_store.list())
+
+ # List the parameters with a prefix
+ print("Listing Parameters with prefix '/workbench':")
+ print(param_store.list("/workbench"))
+
+ # Delete the parameters
+ param_store.delete("/workbench/test")
+ param_store.delete("/workbench/my_data")
+
+ # Out of scope tests
+ param_store.upsert("test", "value")
+ param_store.delete("test")
+
+ # Recursive delete test
+ param_store.upsert("/workbench/test/test1", "value1")
+ param_store.upsert("/workbench/test/test2", "value2")
+ param_store.delete_recursive("workbench/test/")
@@ -55,9 +55,10 @@ class AWSAccountClamp:
  # Check our Assume Role
  self.log.info("Checking Workbench Assumed Role...")
  role_info = self.aws_session.assumed_role_info()
+ self.log.info(f"Assumed Role: {role_info}")

- # Check if the Role is a 'ReadOnly' role
- self.read_only_role = "readonly" in role_info["AssumedRoleArn"].lower()
+ # Check if we have tag write permissions (if we don't, we are read-only)
+ self.read_only = not self.check_tag_permissions()

  # Check our Workbench API Key and Load the License
  self.log.info("Checking Workbench API License...")
@@ -141,6 +142,45 @@
  """
  return self.boto3_session.client("sagemaker")

+ def check_tag_permissions(self):
+ """Check if current role has permission to add tags to SageMaker endpoints.
+
+ Returns:
+ bool: True if AddTags is allowed, False otherwise
+ """
+ try:
+ sagemaker = self.boto3_session.client("sagemaker")
+
+ # Use a non-existent endpoint name
+ fake_endpoint = "workbench-permission-check-dummy-endpoint"
+
+ # Try to add tags to the non-existent endpoint
+ sagemaker.add_tags(
+ ResourceArn=f"arn:aws:sagemaker:{self.region}:{self.account_id}:endpoint/{fake_endpoint}",
+ Tags=[{"Key": "PermissionCheck", "Value": "Test"}],
+ )
+
+ # If we get here, we have permission (but endpoint doesn't exist)
+ return True
+
+ except ClientError as e:
+ error_code = e.response["Error"]["Code"]
+
+ # AccessDeniedException = no permission
+ if error_code == "AccessDeniedException":
+ self.log.debug("No AddTags permission (AccessDeniedException)")
+ return False
+
+ # ResourceNotFound = we have permission, but endpoint doesn't exist
+ elif error_code in ["ResourceNotFound", "ValidationException"]:
+ self.log.debug("AddTags permission verified (resource not found)")
+ return True
+
+ # Unexpected error, assume no permission for safety
+ else:
+ self.log.debug(f"Unexpected error checking permissions: {error_code}")
+ return False
+

  if __name__ == "__main__":
  """Exercise the AWS Account Clamp Class"""
@@ -165,3 +205,9 @@
  print("\n\n*** AWS Sagemaker Session/Client Check ***")
  sm_client = aws_account_clamp.sagemaker_client()
  print(sm_client.list_feature_groups()["FeatureGroupSummaries"])
+
+ print("\n\n*** AWS Tag Permission Check ***")
+ if aws_account_clamp.check_tag_permissions():
+ print("Tag Permission Check Success...")
+ else:
+ print("Tag Permission Check Failed...")
@@ -7,7 +7,6 @@ import logging
  from typing import Union
  import pandas as pd

-
  # Workbench Imports
  from workbench.core.cloud_platform.aws.aws_meta import AWSMeta

@@ -123,7 +123,7 @@ class PipelineExecutor:
  if "model" in workbench_objects and (not subset or "endpoint" in subset):
  workbench_objects["model"].to_endpoint(**kwargs)
  endpoint = Endpoint(kwargs["name"])
- endpoint.auto_inference(capture=True)
+ endpoint.auto_inference()

  # Found something weird
  else: