workbench 0.8.202__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of workbench has been flagged as potentially problematic.

Files changed (84)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  3. workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
  4. workbench/algorithms/dataframe/projection_2d.py +44 -21
  5. workbench/algorithms/dataframe/proximity.py +78 -150
  6. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  7. workbench/algorithms/models/cleanlab_model.py +382 -0
  8. workbench/algorithms/models/noise_model.py +388 -0
  9. workbench/algorithms/sql/outliers.py +3 -3
  10. workbench/api/__init__.py +3 -0
  11. workbench/api/df_store.py +17 -108
  12. workbench/api/endpoint.py +13 -11
  13. workbench/api/feature_set.py +111 -8
  14. workbench/api/meta_model.py +289 -0
  15. workbench/api/model.py +45 -12
  16. workbench/api/parameter_store.py +3 -52
  17. workbench/cached/cached_model.py +4 -4
  18. workbench/core/artifacts/artifact.py +5 -5
  19. workbench/core/artifacts/df_store_core.py +114 -0
  20. workbench/core/artifacts/endpoint_core.py +228 -237
  21. workbench/core/artifacts/feature_set_core.py +185 -230
  22. workbench/core/artifacts/model_core.py +34 -26
  23. workbench/core/artifacts/parameter_store_core.py +98 -0
  24. workbench/core/pipelines/pipeline_executor.py +1 -1
  25. workbench/core/transforms/features_to_model/features_to_model.py +22 -10
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  28. workbench/model_script_utils/model_script_utils.py +339 -0
  29. workbench/model_script_utils/pytorch_utils.py +405 -0
  30. workbench/model_script_utils/uq_harness.py +278 -0
  31. workbench/model_scripts/chemprop/chemprop.template +428 -631
  32. workbench/model_scripts/chemprop/generated_model_script.py +432 -635
  33. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  34. workbench/model_scripts/chemprop/requirements.txt +2 -10
  35. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  36. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  37. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  38. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  39. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  40. workbench/model_scripts/meta_model/meta_model.template +209 -0
  41. workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
  42. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  43. workbench/model_scripts/pytorch_model/pytorch.template +370 -609
  44. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  45. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  46. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  47. workbench/model_scripts/script_generation.py +6 -5
  48. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  49. workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
  50. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  51. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  52. workbench/model_scripts/xgb_model/xgb_model.template +366 -396
  53. workbench/repl/workbench_shell.py +0 -5
  54. workbench/resources/open_source_api.key +1 -1
  55. workbench/scripts/endpoint_test.py +2 -2
  56. workbench/scripts/meta_model_sim.py +35 -0
  57. workbench/scripts/training_test.py +85 -0
  58. workbench/utils/chem_utils/fingerprints.py +87 -46
  59. workbench/utils/chem_utils/projections.py +16 -6
  60. workbench/utils/chemprop_utils.py +36 -655
  61. workbench/utils/meta_model_simulator.py +499 -0
  62. workbench/utils/metrics_utils.py +256 -0
  63. workbench/utils/model_utils.py +192 -54
  64. workbench/utils/pytorch_utils.py +33 -472
  65. workbench/utils/shap_utils.py +1 -55
  66. workbench/utils/xgboost_local_crossfold.py +267 -0
  67. workbench/utils/xgboost_model_utils.py +49 -356
  68. workbench/web_interface/components/model_plot.py +7 -1
  69. workbench/web_interface/components/plugins/model_details.py +30 -68
  70. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  71. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
  72. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
  73. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
  74. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  75. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  76. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  77. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  78. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  79. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  80. workbench/model_scripts/uq_models/mapie.template +0 -605
  81. workbench/model_scripts/uq_models/requirements.txt +0 -1
  82. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  83. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
  84. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/api/feature_set.py CHANGED
@@ -58,10 +58,7 @@ class FeatureSet(FeatureSetCore):
             include_aws_columns (bool): Include the AWS columns in the DataFrame (default: False)

         Returns:
-            pd.DataFrame: A DataFrame of ALL the data from this FeatureSet
-
-        Note:
-            Obviously this is not recommended for large datasets :)
+            pd.DataFrame: A DataFrame of all the data from this FeatureSet up to the limit
         """

         # Get the table associated with the data
@@ -83,7 +80,7 @@ class FeatureSet(FeatureSetCore):
         tags: list = None,
         description: str = None,
         feature_list: list = None,
-        target_column: str = None,
+        target_column: Union[str, list[str]] = None,
         model_class: str = None,
         model_import_str: str = None,
         custom_script: Union[str, Path] = None,
@@ -103,7 +100,7 @@ class FeatureSet(FeatureSetCore):
             tags (list, optional): Set the tags for the model. If not given tags will be generated.
             description (str, optional): Set the description for the model. If not given a description is generated.
             feature_list (list, optional): Set the feature list for the model. If not given a feature list is generated.
-            target_column (str, optional): The target column for the model (use None for unsupervised model)
+            target_column (str or list[str], optional): Target column(s) for the model (use None for unsupervised model)
             model_class (str, optional): Model class to use (e.g. "KMeans", default: None)
             model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
             custom_script (str, optional): The custom script to use for the model (default: None)
@@ -131,7 +128,7 @@ class FeatureSet(FeatureSetCore):
         tags = [name] if tags is None else tags

         # If the model framework is PyTorch or ChemProp, ensure we set the training and inference images
-        if model_framework in (ModelFramework.PYTORCH_TABULAR, ModelFramework.CHEMPROP):
+        if model_framework in (ModelFramework.PYTORCH, ModelFramework.CHEMPROP):
            training_image = "pytorch_training"
            inference_image = "pytorch_inference"

@@ -157,6 +154,93 @@ class FeatureSet(FeatureSetCore):
         # Return the Model
         return Model(name)

+    def prox_model(
+        self, target: str, features: list, include_all_columns: bool = False
+    ) -> "FeatureSpaceProximity":  # noqa: F821
+        """Create a local FeatureSpaceProximity Model for this FeatureSet
+
+        Args:
+            target (str): The target column name
+            features (list): The list of feature column names
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
+        Returns:
+            FeatureSpaceProximity: A local FeatureSpaceProximity Model
+        """
+        from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity  # noqa: F401
+
+        # Create the Proximity Model from the full FeatureSet dataframe
+        full_df = self.pull_dataframe()
+
+        # Create and return the FeatureSpaceProximity Model
+        return FeatureSpaceProximity(
+            full_df, id_column=self.id_column, features=features, target=target, include_all_columns=include_all_columns
+        )
+
+    def fp_prox_model(
+        self,
+        target: str,
+        fingerprint_column: str = None,
+        include_all_columns: bool = False,
+        radius: int = 2,
+        n_bits: int = 1024,
+        counts: bool = False,
+    ) -> "FingerprintProximity":  # noqa: F821
+        """Create a local FingerprintProximity Model for this FeatureSet
+
+        Args:
+            target (str): The target column name
+            fingerprint_column (str): Column containing fingerprints. If None, uses existing 'fingerprint'
+                column or computes from SMILES column.
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+            radius (int): Radius for Morgan fingerprint computation (default: 2)
+            n_bits (int): Number of bits for fingerprint (default: 1024)
+            counts (bool): Whether to use count simulation (default: False)
+
+        Returns:
+            FingerprintProximity: A local FingerprintProximity Model
+        """
+        from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity  # noqa: F401
+
+        # Create the Proximity Model from the full FeatureSet dataframe
+        full_df = self.pull_dataframe()
+
+        # Create and return the FingerprintProximity Model
+        return FingerprintProximity(
+            full_df,
+            id_column=self.id_column,
+            fingerprint_column=fingerprint_column,
+            target=target,
+            include_all_columns=include_all_columns,
+            radius=radius,
+            n_bits=n_bits,
+        )
+
+    def cleanlab_model(
+        self,
+        target: str,
+        features: list,
+        model_type: ModelType = ModelType.REGRESSOR,
+    ) -> "CleanLearning":  # noqa: F821
+        """Create a CleanLearning model for detecting label issues in this FeatureSet
+
+        Args:
+            target (str): The target column name
+            features (list): The list of feature column names
+            model_type (ModelType): The model type (REGRESSOR or CLASSIFIER). Defaults to REGRESSOR.
+
+        Returns:
+            CleanLearning: A fitted cleanlab model. Use get_label_issues() to get
+            a DataFrame with id_column, label_quality, predicted_label, given_label, is_label_issue.
+        """
+        from workbench.algorithms.models.cleanlab_model import create_cleanlab_model  # noqa: F401
+
+        # Get the full FeatureSet dataframe
+        full_df = self.pull_dataframe()
+
+        # Create and return the CleanLearning model
+        return create_cleanlab_model(full_df, self.id_column, features, target, model_type=model_type)
+

 if __name__ == "__main__":
     """Exercise the FeatureSet Class"""
@@ -167,5 +251,24 @@ if __name__ == "__main__":
     pprint(my_features.summary())
     pprint(my_features.details())

+    # Pull the full DataFrame
+    df = my_features.pull_dataframe()
+    print(df.head())
+
+    # Create a Proximity Model from the FeatureSet
+    features = ["height", "weight", "age", "iq_score", "likes_dogs", "food"]
+    my_prox = my_features.prox_model(target="salary", features=features)
+    neighbors = my_prox.neighbors(42)
+    print("Neighbors for ID 42:")
+    print(neighbors)
+
     # Create a Model from the FeatureSet
-    my_model = my_features.to_model(name="test-model", model_type=ModelType.REGRESSOR, target_column="iq_score")
+    """
+    my_model = my_features.to_model(
+        name="test-model",
+        model_type=ModelType.REGRESSOR,
+        target_column="salary",
+        feature_list=features
+    )
+    pprint(my_model.summary())
+    """
workbench/api/meta_model.py ADDED
@@ -0,0 +1,289 @@
+"""MetaModel: A Model that aggregates predictions from multiple child endpoints.
+
+MetaModels don't train on feature data - they combine predictions from existing
+endpoints using confidence-weighted voting. This provides ensemble benefits
+across different model frameworks (XGBoost, PyTorch, ChemProp, etc.).
+"""
+
+from pathlib import Path
+import time
+import logging
+
+from sagemaker.estimator import Estimator
+
+# Workbench Imports
+from workbench.api.model import Model
+from workbench.api.endpoint import Endpoint
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework, ModelImages
+from workbench.core.artifacts.artifact import Artifact
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.model_scripts.script_generation import generate_model_script
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.model_utils import supported_instance_types
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+class MetaModel(Model):
+    """MetaModel: A Model that aggregates predictions from child endpoints.
+
+    Common Usage:
+        ```python
+        # Create a meta model from existing endpoints
+        meta = MetaModel.create(
+            name="my-meta-model",
+            child_endpoints=["endpoint-1", "endpoint-2", "endpoint-3"],
+            target_column="target"
+        )
+
+        # Deploy like any other model
+        endpoint = meta.to_endpoint()
+        ```
+    """
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        target_column: str,
+        description: str = None,
+        tags: list[str] = None,
+    ) -> "MetaModel":
+        """Create a new MetaModel from a list of child endpoints.
+
+        Args:
+            name: Name for the meta model
+            child_endpoints: List of endpoint names to aggregate
+            target_column: Name of the target column (for metadata)
+            description: Optional description for the model
+            tags: Optional list of tags
+
+        Returns:
+            MetaModel: The created meta model
+        """
+        Artifact.is_name_valid(name, delimiter="-", lower_case=False)
+
+        # Validate endpoints and get lineage info from primary endpoint
+        feature_list, feature_set_name, model_weights = cls._validate_and_get_lineage(child_endpoints)
+
+        # Delete existing model if it exists
+        log.important(f"Trying to delete existing model {name}...")
+        ModelCore.managed_delete(name)
+
+        # Run training and register model
+        aws_clamp = AWSAccountClamp()
+        estimator = cls._run_training(name, child_endpoints, target_column, model_weights, aws_clamp)
+        cls._register_model(name, child_endpoints, description, tags, estimator, aws_clamp)
+
+        # Set metadata and onboard
+        cls._set_metadata(name, target_column, feature_list, feature_set_name, child_endpoints)
+
+        log.important(f"MetaModel {name} created successfully!")
+        return cls(name)
+
+    @classmethod
+    def _validate_and_get_lineage(cls, child_endpoints: list[str]) -> tuple[list[str], str, dict[str, float]]:
+        """Validate child endpoints exist and get lineage info from primary endpoint.
+
+        Args:
+            child_endpoints: List of endpoint names
+
+        Returns:
+            tuple: (feature_list, feature_set_name, model_weights) from the primary endpoint's model
+        """
+        log.info("Verifying child endpoints and gathering model metrics...")
+        mae_scores = {}
+
+        for ep_name in child_endpoints:
+            ep = Endpoint(ep_name)
+            if not ep.exists():
+                raise ValueError(f"Child endpoint '{ep_name}' does not exist")
+
+            # Get model MAE from full_inference metrics
+            model = Model(ep.get_input())
+            metrics = model.get_inference_metrics("full_inference")
+            if metrics is not None and "mae" in metrics.columns:
+                mae = float(metrics["mae"].iloc[0])
+                mae_scores[ep_name] = mae
+                log.info(f"  {ep_name} -> {model.name}: MAE={mae:.4f}")
+            else:
+                log.warning(f"  {ep_name}: No full_inference metrics found, using default weight")
+                mae_scores[ep_name] = None
+
+        # Compute inverse-MAE weights (higher weight for lower MAE)
+        valid_mae = {k: v for k, v in mae_scores.items() if v is not None}
+        if valid_mae:
+            inv_mae = {k: 1.0 / v for k, v in valid_mae.items()}
+            total = sum(inv_mae.values())
+            model_weights = {k: v / total for k, v in inv_mae.items()}
+            # Fill in missing weights with equal share of remaining weight
+            missing = [k for k in mae_scores if mae_scores[k] is None]
+            if missing:
+                equal_weight = (1.0 - sum(model_weights.values())) / len(missing)
+                for k in missing:
+                    model_weights[k] = equal_weight
+        else:
+            # No metrics available, use equal weights
+            model_weights = {k: 1.0 / len(child_endpoints) for k in child_endpoints}
+            log.warning("No MAE metrics found, using equal weights")
+
+        log.info(f"Model weights: {model_weights}")
+
+        # Use first endpoint as primary - backtrack to get model and feature set
+        primary_endpoint = Endpoint(child_endpoints[0])
+        primary_model = Model(primary_endpoint.get_input())
+        feature_list = primary_model.features()
+        feature_set_name = primary_model.get_input()
+
+        log.info(
+            f"Primary endpoint: {child_endpoints[0]} -> Model: {primary_model.name} -> FeatureSet: {feature_set_name}"
+        )
+        return feature_list, feature_set_name, model_weights
+
+    @classmethod
+    def _run_training(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        target_column: str,
+        model_weights: dict[str, float],
+        aws_clamp: AWSAccountClamp,
+    ) -> Estimator:
+        """Run the minimal training job that saves the meta model config.
+
+        Args:
+            name: Model name
+            child_endpoints: List of endpoint names
+            target_column: Target column name
+            model_weights: Dict mapping endpoint name to weight
+            aws_clamp: AWS account clamp
+
+        Returns:
+            Estimator: The fitted estimator
+        """
+        sm_session = aws_clamp.sagemaker_session()
+        cm = ConfigManager()
+        workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+        models_s3_path = f"s3://{workbench_bucket}/models"
+
+        # Generate the model script from template
+        template_params = {
+            "model_type": ModelType.REGRESSOR,
+            "model_framework": ModelFramework.META,
+            "child_endpoints": child_endpoints,
+            "target_column": target_column,
+            "model_weights": model_weights,
+            "model_metrics_s3_path": f"{models_s3_path}/{name}/training",
+            "aws_region": sm_session.boto_region_name,
+        }
+        script_path = generate_model_script(template_params)
+
+        # Create estimator
+        training_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_training")
+        log.info(f"Using Meta Training Image: {training_image}")
+        estimator = Estimator(
+            entry_point=Path(script_path).name,
+            source_dir=str(Path(script_path).parent),
+            role=aws_clamp.aws_session.get_workbench_execution_role_arn(),
+            instance_count=1,
+            instance_type="ml.m5.large",
+            sagemaker_session=sm_session,
+            image_uri=training_image,
+        )
+
+        # Run training (no input data needed - just saves config)
+        log.important(f"Creating MetaModel {name}...")
+        estimator.fit()
+
+        return estimator
+
+    @classmethod
+    def _register_model(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        description: str,
+        tags: list[str],
+        estimator: Estimator,
+        aws_clamp: AWSAccountClamp,
+    ):
+        """Create model group and register the model.
+
+        Args:
+            name: Model name
+            child_endpoints: List of endpoint names
+            description: Model description
+            tags: Model tags
+            estimator: Fitted estimator
+            aws_clamp: AWS account clamp
+        """
+        sm_session = aws_clamp.sagemaker_session()
+        model_description = description or f"Meta model aggregating: {', '.join(child_endpoints)}"
+
+        # Create model group
+        aws_clamp.sagemaker_client().create_model_package_group(
+            ModelPackageGroupName=name,
+            ModelPackageGroupDescription=model_description,
+            Tags=[{"Key": "workbench_tags", "Value": "::".join(tags or [name])}],
+        )
+
+        # Register the model with meta inference image
+        inference_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_inference")
+        log.important(f"Registering model {name} with Inference Image {inference_image}...")
+        estimator.create_model(role=aws_clamp.aws_session.get_workbench_execution_role_arn()).register(
+            model_package_group_name=name,
+            image_uri=inference_image,
+            content_types=["text/csv"],
+            response_types=["text/csv"],
+            inference_instances=supported_instance_types("x86_64"),
+            transform_instances=["ml.m5.large", "ml.m5.xlarge"],
+            approval_status="Approved",
+            description=model_description,
+        )
+
+    @classmethod
+    def _set_metadata(
+        cls, name: str, target_column: str, feature_list: list[str], feature_set_name: str, child_endpoints: list[str]
+    ):
+        """Set model metadata and onboard.
+
+        Args:
+            name: Model name
+            target_column: Target column name
+            feature_list: List of feature names
+            feature_set_name: Name of the input FeatureSet
+            child_endpoints: List of child endpoint names
+        """
+        time.sleep(3)
+        output_model = ModelCore(name)
+        output_model._set_model_type(ModelType.UQ_REGRESSOR)
+        output_model._set_model_framework(ModelFramework.META)
+        output_model.set_input(feature_set_name, force=True)
+        output_model.upsert_workbench_meta({"workbench_model_target": target_column})
+        output_model.upsert_workbench_meta({"workbench_model_features": feature_list})
+        output_model.upsert_workbench_meta({"child_endpoints": child_endpoints})
+        output_model.onboard_with_args(ModelType.UQ_REGRESSOR, target_column, feature_list=feature_list)
+
+
+if __name__ == "__main__":
+    """Exercise the MetaModel Class"""
+
+    meta = MetaModel.create(
+        name="logd-meta",
+        child_endpoints=["logd-xgb", "logd-pytorch", "logd-chemprop"],
+        target_column="logd",
+        description="Meta model for LogD prediction",
+        tags=["meta", "logd", "ensemble"],
+    )
+    print(meta.summary())
+
+    # Create an endpoint for the meta model
+    end = meta.to_endpoint(tags=["meta", "logd"])
+    end.set_owner("BW")
+    end.auto_inference()
+
+    # Test loading an existing meta model
+    meta = MetaModel("logd-meta")
+    print(meta.details())
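
The inverse-MAE weighting in `_validate_and_get_lineage` is easier to follow with concrete numbers. Below is a standalone replay of the same arithmetic with made-up MAE values; note what it implies for endpoints that lack metrics:

```python
# Standalone sketch of the weighting logic above, with made-up MAE values.
mae_scores = {"logd-xgb": 0.40, "logd-pytorch": 0.50, "logd-chemprop": None}

valid_mae = {k: v for k, v in mae_scores.items() if v is not None}
inv_mae = {k: 1.0 / v for k, v in valid_mae.items()}  # {xgb: 2.5, pytorch: 2.0}
total = sum(inv_mae.values())                          # 4.5
weights = {k: v / total for k, v in inv_mae.items()}   # {xgb: ~0.556, pytorch: ~0.444}

# Endpoints with no metrics split whatever weight remains. Since the valid
# weights already sum to 1.0, "logd-chemprop" ends up with weight ~0.0 here.
missing = [k for k in mae_scores if mae_scores[k] is None]
if missing:
    equal_weight = (1.0 - sum(weights.values())) / len(missing)
    for k in missing:
        weights[k] = equal_weight

print(weights)  # {'logd-xgb': ~0.556, 'logd-pytorch': ~0.444, 'logd-chemprop': ~0.0}
```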
workbench/api/model.py CHANGED
@@ -10,7 +10,12 @@ from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework  # noqa: F401
 from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
 from workbench.api.endpoint import Endpoint
-from workbench.utils.model_utils import proximity_model_local, uq_model
+from workbench.utils.model_utils import (
+    proximity_model_local,
+    fingerprint_prox_model_local,
+    noise_model_local,
+    cleanlab_model_local,
+)


 class Model(ModelCore):
@@ -83,27 +88,55 @@ class Model(ModelCore):
         end.set_owner(self.get_owner())
         return end

-    def prox_model(self):
+    def prox_model(self, include_all_columns: bool = False):
         """Create a local Proximity Model for this Model

+        Args:
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
         Returns:
-            Proximity: A local Proximity Model
+            FeatureSpaceProximity: A local FeatureSpaceProximity Model
         """
-        return proximity_model_local(self)
+        return proximity_model_local(self, include_all_columns=include_all_columns)

-    def uq_model(self, uq_model_name: str = None, train_all_data: bool = False) -> "Model":
-        """Create a Uncertainty Quantification Model for this Model
+    def fp_prox_model(
+        self,
+        include_all_columns: bool = False,
+        radius: int = 2,
+        n_bits: int = 1024,
+        counts: bool = False,
+    ):
+        """Create a local Fingerprint Proximity Model for this Model

         Args:
-            uq_model_name (str, optional): Name of the UQ Model (if not specified, a name will be generated)
-            train_all_data (bool, optional): Whether to train the UQ Model on all data (default: False)
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+            radius (int): Morgan fingerprint radius (default: 2)
+            n_bits (int): Number of bits for the fingerprint (default: 1024)
+            counts (bool): Use count fingerprints instead of binary (default: False)
+
+        Returns:
+            FingerprintProximity: A local FingerprintProximity Model
+        """
+        return fingerprint_prox_model_local(
+            self, include_all_columns=include_all_columns, radius=radius, n_bits=n_bits, counts=counts
+        )
+
+    def noise_model(self):
+        """Create a local Noise Model for this Model
+
+        Returns:
+            NoiseModel: A local Noise Model
+        """
+        return noise_model_local(self)
+
+    def cleanlab_model(self):
+        """Create a CleanLearning model for this Model's training data.

         Returns:
-            Model: The UQ Model
+            CleanLearning: A fitted cleanlab model. Use get_label_issues() to get
+            a DataFrame with id_column, label_quality, predicted_label, given_label, is_label_issue.
         """
-        if uq_model_name is None:
-            uq_model_name = self.model_name + "-uq"
-        return uq_model(self, uq_model_name, train_all_data=train_all_data)
+        return cleanlab_model_local(self)


 if __name__ == "__main__":
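
On the Model side, the `uq_model()` wrapper is removed and the new local helpers take over. A short sketch of the updated surface; "logd-xgb" is a hypothetical Model name, and the signatures follow the methods shown in the hunk above:

```python
# Sketch only - "logd-xgb" is a hypothetical Model name.
from workbench.api.model import Model

model = Model("logd-xgb")

# Feature-space neighbors, optionally carrying along all DataFrame columns
prox = model.prox_model(include_all_columns=True)

# Fingerprint proximity with explicit Morgan parameters
fp_prox = model.fp_prox_model(radius=2, n_bits=1024, counts=False)

# New local label-noise helpers (uq_model() is gone)
noise = model.noise_model()
issues = model.cleanlab_model().get_label_issues()
print(issues[issues["is_label_issue"]])
```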
workbench/api/parameter_store.py CHANGED
@@ -1,13 +1,10 @@
 """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store."""

-from typing import Union
-import logging
-
 # Workbench Imports
-from workbench.core.cloud_platform.aws.aws_parameter_store import AWSParameterStore
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore


-class ParameterStore(AWSParameterStore):
+class ParameterStore(ParameterStoreCore):
     """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store.

     Common Usage:
@@ -43,56 +40,10 @@ class ParameterStore(AWSParameterStore):

     def __init__(self):
         """ParameterStore Init Method"""
-        self.log = logging.getLogger("workbench")

-        # Initialize the SuperClass
+        # Initialize parent class
         super().__init__()

-    def list(self, prefix: str = None) -> list:
-        """List all parameters in the AWS Parameter Store, optionally filtering by a prefix.
-
-        Args:
-            prefix (str, optional): A prefix to filter the parameters by. Defaults to None.
-
-        Returns:
-            list: A list of parameter names and details.
-        """
-        return super().list(prefix=prefix)
-
-    def get(self, name: str, warn: bool = True, decrypt: bool = True) -> Union[str, list, dict, None]:
-        """Retrieve a parameter value from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to retrieve.
-            warn (bool): Whether to log a warning if the parameter is not found.
-            decrypt (bool): Whether to decrypt secure string parameters.
-
-        Returns:
-            Union[str, list, dict, None]: The value of the parameter or None if not found.
-        """
-        return super().get(name=name, warn=warn, decrypt=decrypt)
-
-    def upsert(self, name: str, value):
-        """Insert or update a parameter in the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter.
-            value (str | list | dict): The value of the parameter.
-        """
-        super().upsert(name=name, value=value)
-
-    def delete(self, name: str):
-        """Delete a parameter from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to delete.
-        """
-        super().delete(name=name)
-
-    def __repr__(self):
-        """Return a string representation of the ParameterStore object."""
-        return super().__repr__()
-


 if __name__ == "__main__":
     """Exercise the ParameterStore Class"""
workbench/cached/cached_model.py CHANGED
@@ -72,11 +72,11 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return super().list_inference_runs()

     @CachedArtifactMixin.cache_result
-    def get_inference_metrics(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def get_inference_metrics(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the captured prediction results for this model

         Args:
-            capture_name (str, optional): Specific capture_name (default: latest)
+            capture_name (str, optional): Specific capture_name (default: auto)

         Returns:
             pd.DataFrame: DataFrame of the Captured Metrics (might be None)
@@ -101,11 +101,11 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return df

     @CachedArtifactMixin.cache_result
-    def confusion_matrix(self, capture_name: str = "latest") -> Union[pd.DataFrame, None]:
+    def confusion_matrix(self, capture_name: str = "auto") -> Union[pd.DataFrame, None]:
         """Retrieve the confusion matrix for the model

         Args:
-            capture_name (str, optional): Specific capture_name (default: latest)
+            capture_name (str, optional): Specific capture_name (default: auto)

         Returns:
             pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
workbench/core/artifacts/artifact.py CHANGED
@@ -8,8 +8,8 @@ from typing import Union

 # Workbench Imports
 from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
-from workbench.core.cloud_platform.aws.aws_parameter_store import AWSParameterStore as ParameterStore
-from workbench.core.cloud_platform.aws.aws_df_store import AWSDFStore as DFStore
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore
+from workbench.core.artifacts.df_store_core import DFStoreCore
 from workbench.utils.aws_utils import dict_to_aws_tags
 from workbench.utils.config_manager import ConfigManager, FatalConfigError
 from workbench.core.cloud_platform.cloud_meta import CloudMeta
@@ -48,11 +48,11 @@ class Artifact(ABC):
     tag_delimiter = "::"

     # Grab our Dataframe Cache Storage
-    df_cache = DFStore(path_prefix="/workbench/dataframe_cache")
+    df_cache = DFStoreCore(path_prefix="/workbench/dataframe_cache")

     # Artifact may want to use the Parameter Store or Dataframe Store
-    param_store = ParameterStore()
-    df_store = DFStore()
+    param_store = ParameterStoreCore()
+    df_store = DFStoreCore()

     def __init__(self, name: str, use_cached_meta: bool = False):
         """Initialize the Artifact Base Class