workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/core/artifacts/model_core.py

@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
 from workbench.utils.deprecated_utils import deprecated
+from workbench.utils.model_utils import published_proximity_model, get_model_hyperparameters


 class ModelType(Enum):
@@ -29,69 +30,64 @@ class ModelType(Enum):
     CLASSIFIER = "classifier"
     REGRESSOR = "regressor"
     CLUSTERER = "clusterer"
-    TRANSFORMER = "transformer"
     PROXIMITY = "proximity"
     PROJECTION = "projection"
     UQ_REGRESSOR = "uq_regressor"
     ENSEMBLE_REGRESSOR = "ensemble_regressor"
+    TRANSFORMER = "transformer"
+    UNKNOWN = "unknown"
+
+
+class ModelFramework(Enum):
+    """Enumerated Types for Workbench Model Frameworks"""
+
+    SKLEARN = "sklearn"
+    XGBOOST = "xgboost"
+    LIGHTGBM = "lightgbm"
+    PYTORCH = "pytorch"
+    CHEMPROP = "chemprop"
+    TRANSFORMER = "transformer"
+    META = "meta"
     UNKNOWN = "unknown"


 class ModelImages:
     """Class for retrieving workbench inference images"""

-    image_uris = {
-        # US East 1 images
-        ("us-east-1", "xgb_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-        ),
-        ("us-east-1", "xgb_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-        ),
-        ("us-east-1", "pytorch_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
-        ),
-        ("us-east-1", "pytorch_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
-        ),
-        # US West 2 images
-        ("us-west-2", "xgb_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-training:0.1"
-        ),
-        ("us-west-2", "xgb_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1"
-        ),
-        ("us-west-2", "pytorch_training", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-training:0.1"
-        ),
-        ("us-west-2", "pytorch_inference", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-pytorch-inference:0.1"
-        ),
-        # ARM64 images
-        ("us-east-1", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
-        ("us-west-2", "xgb_inference", "0.1", "arm64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-sklearn-xgb-inference:0.1-arm64"
-        ),
-        # Meta Endpoint inference images
-        ("us-east-1", "meta-endpoint", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
-        ),
-        ("us-west-2", "meta-endpoint", "0.1", "x86_64"): (
-            "507740646243.dkr.ecr.us-west-2.amazonaws.com/aws-ml-images/py312-meta-endpoint:0.1"
-        ),
+    # Account ID
+    ACCOUNT_ID = "507740646243"
+
+    # Image name mappings
+    IMAGE_NAMES = {
+        "training": "py312-general-ml-training",
+        "inference": "py312-general-ml-inference",
+        "pytorch_training": "py312-pytorch-training",
+        "pytorch_inference": "py312-pytorch-inference",
+        "meta_training": "py312-meta-training",
+        "meta_inference": "py312-meta-inference",
     }

     @classmethod
-    def get_image_uri(cls, region, image_type, version="0.1", architecture="x86_64"):
-        key = (region, image_type, version, architecture)
-        if key in cls.image_uris:
-            return cls.image_uris[key]
-        else:
-            raise ValueError(
-                f"No matching image found for region: {region}, image_type: {image_type}, version: {version}"
-            )
+    def get_image_uri(cls, region, image_type, version="latest", architecture="x86_64"):
+        """
+        Dynamically construct ECR image URI.
+
+        Args:
+            region: AWS region (e.g., 'us-east-1', 'us-west-2')
+            image_type: Type of image (e.g., 'training', 'inference', 'pytorch_training')
+            version: Image version (e.g., '0.1', '0.2' defaults to 'latest')
+            architecture: CPU architecture (default: 'x86_64', currently unused but kept for compatibility)
+
+        Returns:
+            ECR image URI string
+        """
+        if image_type not in cls.IMAGE_NAMES:
+            raise ValueError(f"Unknown image_type: {image_type}. Valid types: {list(cls.IMAGE_NAMES.keys())}")
+
+        image_name = cls.IMAGE_NAMES[image_type]
+        uri = f"{cls.ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/aws-ml-images/{image_name}:{version}"
+
+        return uri


 class ModelCore(Artifact):
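For illustration, a minimal sketch of what the new dynamic lookup produces (assuming the import path workbench.core.artifacts.model_core, matching the file list above; the region and image_type values are just example inputs):

    from workbench.core.artifacts.model_core import ModelImages

    # The URI is assembled from ACCOUNT_ID + region + IMAGE_NAMES[image_type] + version tag
    uri = ModelImages.get_image_uri("us-east-1", "inference")
    print(uri)
    # 507740646243.dkr.ecr.us-east-1.amazonaws.com/aws-ml-images/py312-general-ml-inference:latest

    # Keys from the old hard-coded table (e.g. "xgb_inference") are no longer valid and raise ValueError
    # ModelImages.get_image_uri("us-east-1", "xgb_inference")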
@@ -105,11 +101,10 @@ class ModelCore(Artifact):
         ```
     """

-    def __init__(self, model_name: str, model_type: ModelType = None, **kwargs):
+    def __init__(self, model_name: str, **kwargs):
         """ModelCore Initialization
         Args:
             model_name (str): Name of Model in Workbench.
-            model_type (ModelType, optional): Set this for newly created Models. Defaults to None.
             **kwargs: Additional keyword arguments
         """

@@ -143,10 +138,8 @@ class ModelCore(Artifact):
             self.latest_model = self.model_meta["ModelPackageList"][0]
             self.description = self.latest_model.get("ModelPackageDescription", "-")
             self.training_job_name = self._extract_training_job_name()
-            if model_type:
-                self._set_model_type(model_type)
-            else:
-                self.model_type = self._get_model_type()
+            self.model_type = self._get_model_type()
+            self.model_framework = self._get_model_framework()
         except (IndexError, KeyError):
             self.log.critical(f"Model {self.model_name} appears to be malformed. Delete and recreate it!")
             return
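With the model_type parameter removed from __init__, a model is loaded by name only and its type and framework are read back from stored Workbench metadata. A usage sketch (reusing the "abalone-regression" name from the __main__ block further below; the import path is assumed):

    from workbench.core.artifacts.model_core import ModelCore

    my_model = ModelCore("abalone-regression")   # no model_type argument anymore
    print(my_model.model_type)                   # resolved via _get_model_type() from metadata
    print(my_model.model_framework)              # new attribute, resolved via _get_model_framework()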
@@ -272,21 +265,25 @@ class ModelCore(Artifact):
         else:
             self.log.important(f"No inference data found for {self.model_name}!")

-    def get_inference_metrics(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the inference performance metrics for this model

         Args:
-            capture_name (str, optional): Specific capture_name or "training" (default: "latest")
+            capture_name (str, optional): Specific capture_name (default: "auto")
         Returns:
             pd.DataFrame: DataFrame of the Model Metrics

         Note:
-            If a capture_name isn't specified this will try to return something reasonable
+            If a capture_name isn't specified this will try to return the 'first' available metrics
         """
         # Try to get the auto_capture 'training_holdout' or the training
-        if capture_name == "latest":
-            metrics_df = self.get_inference_metrics("auto_inference")
-            return metrics_df if metrics_df is not None else self.get_inference_metrics("model_training")
+        if capture_name == "auto":
+            metric_list = self.list_inference_runs()
+            if metric_list:
+                return self.get_inference_metrics(metric_list[0])
+            else:
+                self.log.warning(f"No performance metrics found for {self.model_name}!")
+                return None

         # Grab the metrics captured during model training (could return None)
         if capture_name == "model_training":
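The "auto" default no longer hard-codes "auto_inference"/"model_training"; it simply returns the metrics for the first entry of list_inference_runs(). A usage sketch (my_model as in the sketch above):

    runs = my_model.list_inference_runs()
    metrics = my_model.get_inference_metrics()                      # "auto": metrics for runs[0] if any runs exist, else None
    training = my_model.get_inference_metrics("model_training")     # an explicit capture_name still works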
@@ -308,11 +305,11 @@ class ModelCore(Artifact):
         self.log.warning(f"Performance metrics {capture_name} not found for {self.model_name}!")
         return None

-    def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the confusion_matrix for this model

         Args:
-            capture_name (str, optional): Specific capture_name or "training" (default: "latest")
+            capture_name (str, optional): Specific capture_name or "training" (default: "auto")
         Returns:
             pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
         """
@@ -324,7 +321,7 @@ class ModelCore(Artifact):
             raise ValueError(error_msg)

         # Grab the metrics from the Workbench Metadata (try inference first, then training)
-        if capture_name == "latest":
+        if capture_name == "auto":
             cm = self.confusion_matrix("auto_inference")
             return cm if cm is not None else self.confusion_matrix("model_training")

@@ -546,6 +543,17 @@ class ModelCore(Artifact):
         else:
             self.log.error(f"Model {self.model_name} is not a classifier!")

+    def summary(self) -> dict:
+        """Summary information about this Model
+
+        Returns:
+            dict: Dictionary of summary information about this Model
+        """
+        self.log.info("Computing Model Summary...")
+        summary = super().summary()
+        summary["hyperparameters"] = get_model_hyperparameters(self)
+        return summary
+
     def details(self) -> dict:
         """Additional Details about this Model

@@ -570,6 +578,7 @@ class ModelCore(Artifact):
         details["status"] = self.latest_model["ModelPackageStatus"]
         details["approval_status"] = self.latest_model.get("ModelApprovalStatus", "unknown")
         details["image"] = self.container_image().split("/")[-1]  # Shorten the image uri
+        details["hyperparameters"] = get_model_hyperparameters(self)

         # Grab the inference and container info
         inference_spec = self.latest_model["InferenceSpecification"]
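Both summary() and details() now expose the hyperparameters via get_model_hyperparameters(), while the heavier model_metrics / confusion_matrix / predictions entries are dropped from details() (see the next hunk). A usage sketch:

    info = my_model.summary()
    print(info["hyperparameters"])   # same payload that details()["hyperparameters"] now carries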
@@ -580,16 +589,6 @@ class ModelCore(Artifact):
         details["transform_types"] = inference_spec["SupportedTransformInstanceTypes"]
         details["content_types"] = inference_spec["SupportedContentTypes"]
         details["response_types"] = inference_spec["SupportedResponseMIMETypes"]
-        details["model_metrics"] = self.get_inference_metrics()
-        if self.model_type == ModelType.CLASSIFIER:
-            details["confusion_matrix"] = self.confusion_matrix()
-            details["predictions"] = None
-        elif self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
-            details["confusion_matrix"] = None
-            details["predictions"] = self.get_inference_predictions()
-        else:
-            details["confusion_matrix"] = None
-            details["predictions"] = None

         # Grab the inference metadata
         details["inference_meta"] = self.get_inference_metadata()
@@ -597,6 +596,24 @@ class ModelCore(Artifact):
         # Return the details
         return details

+    # Training View for this model
+    def training_view(self):
+        """Get the training view for this model"""
+        from workbench.core.artifacts.feature_set_core import FeatureSetCore
+        from workbench.core.views import View
+
+        # Grab our FeatureSet
+        fs = FeatureSetCore(self.get_input())
+
+        # See if we have a training view for this model
+        my_model_training_view = f"{self.name.replace('-', '_')}_training".lower()
+        view = View(fs, my_model_training_view, auto_create_view=False)
+        if view.exists():
+            return view
+        else:
+            self.log.important(f"No specific training view {my_model_training_view}, returning default training view")
+            return fs.view("training")
+
     # Pipeline for this model
     def get_pipeline(self) -> str:
         """Get the pipeline for this model"""
@@ -860,7 +877,7 @@ class ModelCore(Artifact):
             return self.df_store.get(f"/workbench/models/{self.name}/shap_data")
         else:
             # Loop over the SHAP data and return a dict of DataFrames
-            shap_dfs = self.df_store.list_subfiles(f"/workbench/models/{self.name}/shap_data")
+            shap_dfs = self.df_store.list(f"/workbench/models/{self.name}/shap_data")
             shap_data = {}
             for df_location in shap_dfs:
                 key = df_location.split("/")[-1]
@@ -879,10 +896,24 @@ class ModelCore(Artifact):
         except (KeyError, IndexError, TypeError):
             return None

+    def publish_prox_model(self, prox_model_name: str = None, include_all_columns: bool = False):
+        """Create and publish a Proximity Model for this Model
+
+        Args:
+            prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
+        Returns:
+            Model: The published Proximity Model
+        """
+        if prox_model_name is None:
+            prox_model_name = self.model_name + "-prox"
+        return published_proximity_model(self, prox_model_name, include_all_columns=include_all_columns)
+
     def delete(self):
         """Delete the Model Packages and the Model Group"""
         if not self.exists():
-            self.log.warning(f"Trying to delete an Model that doesn't exist: {self.name}")
+            self.log.warning(f"Trying to delete a Model that doesn't exist: {self.name}")

         # Call the Class Method to delete the Model Group
         ModelCore.managed_delete(model_group_name=self.name)
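publish_prox_model() defaults the proximity model name to "<model_name>-prox" and delegates to the published_proximity_model() helper imported from workbench.utils.model_utils. A sketch (the explicit name in the second call is hypothetical):

    prox = my_model.publish_prox_model()   # e.g. "abalone-regression-prox"
    prox = my_model.publish_prox_model("abalone-prox-knn", include_all_columns=True)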
@@ -958,6 +989,27 @@ class ModelCore(Artifact):
             self.log.warning(f"Could not determine model type for {self.model_name}!")
             return ModelType.UNKNOWN

+    def _set_model_framework(self, model_framework: ModelFramework):
+        """Internal: Set the Model Framework for this Model"""
+        self.model_framework = model_framework
+        self.upsert_workbench_meta({"workbench_model_framework": self.model_framework.value})
+        self.remove_health_tag("model_framework_unknown")
+
+    def _get_model_framework(self) -> ModelFramework:
+        """Internal: Query the Workbench Metadata to get the model framework
+        Returns:
+            ModelFramework: The ModelFramework of this Model
+        Notes:
+            This is an internal method that should not be called directly
+            Use the model_framework attribute instead
+        """
+        model_framework = self.workbench_meta().get("workbench_model_framework")
+        try:
+            return ModelFramework(model_framework)
+        except ValueError:
+            self.log.warning(f"Could not determine model framework for {self.model_name}!")
+            return ModelFramework.UNKNOWN
+
     def _load_training_metrics(self):
         """Internal: Retrieve the training metrics and Confusion Matrix for this model
         and load the data into the Workbench Metadata
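The framework round-trips through the Workbench metadata key "workbench_model_framework"; a missing or unrecognized value falls back to ModelFramework.UNKNOWN. A sketch of the same resolution logic (import path assumed):

    from workbench.core.artifacts.model_core import ModelFramework

    raw = my_model.workbench_meta().get("workbench_model_framework")   # e.g. "xgboost" or None
    try:
        framework = ModelFramework(raw)
    except ValueError:
        framework = ModelFramework.UNKNOWN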
@@ -1149,13 +1201,11 @@ if __name__ == "__main__":
     # Grab a ModelCore object and pull some information from it
     my_model = ModelCore("abalone-regression")

-    # Call the various methods
-
     # Let's do a check/validation of the Model
     print(f"Model Check: {my_model.exists()}")

     # Make sure the model is 'ready'
-    # my_model.onboard()
+    my_model.onboard()

     # Get the ARN of the Model Group
     print(f"Model Group ARN: {my_model.group_arn()}")
@@ -1221,5 +1271,10 @@ if __name__ == "__main__":
     # Delete the Model
     # ModelCore.managed_delete("wine-classification")

+    # Check the training view logic
+    model = ModelCore("wine-class-test-251112-BW")
+    training_view = model.training_view()
+    print(f"Training View Name: {training_view.name}")
+
     # Check for a model that doesn't exist
     my_model = ModelCore("empty-model-group")