workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/api/feature_set.py CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
 from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.feature_set_core import FeatureSetCore
 from workbench.core.transforms.features_to_model.features_to_model import FeaturesToModel
-from workbench.api.model import Model, ModelType
+from workbench.api.model import Model, ModelType, ModelFramework
 
 
 class FeatureSet(FeatureSetCore):
@@ -58,10 +58,7 @@ class FeatureSet(FeatureSetCore):
             include_aws_columns (bool): Include the AWS columns in the DataFrame (default: False)
 
         Returns:
-            pd.DataFrame: A DataFrame of ALL the data from this FeatureSet
-
-        Note:
-            Obviously this is not recommended for large datasets :)
+            pd.DataFrame: A DataFrame of all the data from this FeatureSet up to the limit
         """
 
         # Get the table associated with the data
@@ -79,10 +76,11 @@ class FeatureSet(FeatureSetCore):
         self,
         name: str,
         model_type: ModelType,
+        model_framework: ModelFramework = ModelFramework.XGBOOST,
         tags: list = None,
         description: str = None,
         feature_list: list = None,
-        target_column: str = None,
+        target_column: Union[str, list[str]] = None,
         model_class: str = None,
         model_import_str: str = None,
         custom_script: Union[str, Path] = None,
@@ -98,11 +96,12 @@ class FeatureSet(FeatureSetCore):
 
             name (str): The name of the Model to create
             model_type (ModelType): The type of model to create (See workbench.model.ModelType)
+            model_framework (ModelFramework, optional): The framework to use for the model (default: XGBOOST)
             tags (list, optional): Set the tags for the model. If not given tags will be generated.
             description (str, optional): Set the description for the model. If not given a description is generated.
             feature_list (list, optional): Set the feature list for the model. If not given a feature list is generated.
-            target_column (str, optional): The target column for the model (use None for unsupervised model)
-            model_class (str, optional): Model class to use (e.g. "KMeans", "PyTorch", default: None)
+            target_column (str or list[str], optional): Target column(s) for the model (use None for unsupervised model)
+            model_class (str, optional): Model class to use (e.g. "KMeans", default: None)
             model_import_str (str, optional): The import for the model (e.g. "from sklearn.cluster import KMeans")
             custom_script (str, optional): The custom script to use for the model (default: None)
             training_image (str, optional): The training image to use (default: "training")
@@ -128,8 +127,8 @@ class FeatureSet(FeatureSetCore):
         # Create the Model Tags
         tags = [name] if tags is None else tags
 
-        # If the model_class is PyTorch, ensure we set the training and inference images
-        if model_class and model_class.lower() == "pytorch":
+        # If the model framework is PyTorch or ChemProp, ensure we set the training and inference images
+        if model_framework in (ModelFramework.PYTORCH, ModelFramework.CHEMPROP):
             training_image = "pytorch_training"
             inference_image = "pytorch_inference"
@@ -138,6 +137,7 @@ class FeatureSet(FeatureSetCore):
             feature_name=self.name,
             model_name=name,
             model_type=model_type,
+            model_framework=model_framework,
             model_class=model_class,
             model_import_str=model_import_str,
             custom_script=custom_script,
@@ -154,6 +154,93 @@ class FeatureSet(FeatureSetCore):
         # Return the Model
         return Model(name)
 
+    def prox_model(
+        self, target: str, features: list, include_all_columns: bool = False
+    ) -> "FeatureSpaceProximity":  # noqa: F821
+        """Create a local FeatureSpaceProximity Model for this FeatureSet
+
+        Args:
+            target (str): The target column name
+            features (list): The list of feature column names
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+
+        Returns:
+            FeatureSpaceProximity: A local FeatureSpaceProximity Model
+        """
+        from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity  # noqa: F401
+
+        # Create the Proximity Model from the full FeatureSet dataframe
+        full_df = self.pull_dataframe()
+
+        # Create and return the FeatureSpaceProximity Model
+        return FeatureSpaceProximity(
+            full_df, id_column=self.id_column, features=features, target=target, include_all_columns=include_all_columns
+        )
+
+    def fp_prox_model(
+        self,
+        target: str,
+        fingerprint_column: str = None,
+        include_all_columns: bool = False,
+        radius: int = 2,
+        n_bits: int = 1024,
+        counts: bool = False,
+    ) -> "FingerprintProximity":  # noqa: F821
+        """Create a local FingerprintProximity Model for this FeatureSet
+
+        Args:
+            target (str): The target column name
+            fingerprint_column (str): Column containing fingerprints. If None, uses existing 'fingerprint'
+                column or computes from SMILES column.
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+            radius (int): Radius for Morgan fingerprint computation (default: 2)
+            n_bits (int): Number of bits for fingerprint (default: 1024)
+            counts (bool): Whether to use count simulation (default: False)
+
+        Returns:
+            FingerprintProximity: A local FingerprintProximity Model
+        """
+        from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity  # noqa: F401
+
+        # Create the Proximity Model from the full FeatureSet dataframe
+        full_df = self.pull_dataframe()
+
+        # Create and return the FingerprintProximity Model
+        return FingerprintProximity(
+            full_df,
+            id_column=self.id_column,
+            fingerprint_column=fingerprint_column,
+            target=target,
+            include_all_columns=include_all_columns,
+            radius=radius,
+            n_bits=n_bits,
+        )
+
+    def cleanlab_model(
+        self,
+        target: str,
+        features: list,
+        model_type: ModelType = ModelType.REGRESSOR,
+    ) -> "CleanLearning":  # noqa: F821
+        """Create a CleanLearning model for detecting label issues in this FeatureSet
+
+        Args:
+            target (str): The target column name
+            features (list): The list of feature column names
+            model_type (ModelType): The model type (REGRESSOR or CLASSIFIER). Defaults to REGRESSOR.
+
+        Returns:
+            CleanLearning: A fitted cleanlab model. Use get_label_issues() to get
+                a DataFrame with id_column, label_quality, predicted_label, given_label, is_label_issue.
+        """
+        from workbench.algorithms.models.cleanlab_model import create_cleanlab_model  # noqa: F401
+
+        # Get the full FeatureSet dataframe
+        full_df = self.pull_dataframe()
+
+        # Create and return the CleanLearning model
+        return create_cleanlab_model(full_df, self.id_column, features, target, model_type=model_type)
+
 
 if __name__ == "__main__":
     """Exercise the FeatureSet Class"""
@@ -164,5 +251,24 @@ if __name__ == "__main__":
     pprint(my_features.summary())
     pprint(my_features.details())
 
+    # Pull the full DataFrame
+    df = my_features.pull_dataframe()
+    print(df.head())
+
+    # Create a Proximity Model from the FeatureSet
+    features = ["height", "weight", "age", "iq_score", "likes_dogs", "food"]
+    my_prox = my_features.prox_model(target="salary", features=features)
+    neighbors = my_prox.neighbors(42)
+    print("Neighbors for ID 42:")
+    print(neighbors)
+
     # Create a Model from the FeatureSet
-    my_model = my_features.to_model(name="test-model", model_type=ModelType.REGRESSOR, target_column="iq_score")
+    """
+    my_model = my_features.to_model(
+        name="test-model",
+        model_type=ModelType.REGRESSOR,
+        target_column="salary",
+        feature_list=features
+    )
+    pprint(my_model.summary())
+    """
workbench/api/meta.py CHANGED
@@ -6,7 +6,6 @@ such as Data Sources, Feature Sets, Models, and Endpoints.
 from typing import Union
 import pandas as pd
 
-
 # Workbench Imports
 from workbench.core.cloud_platform.cloud_meta import CloudMeta
 
workbench/api/meta_model.py ADDED
@@ -0,0 +1,289 @@
+"""MetaModel: A Model that aggregates predictions from multiple child endpoints.
+
+MetaModels don't train on feature data - they combine predictions from existing
+endpoints using confidence-weighted voting. This provides ensemble benefits
+across different model frameworks (XGBoost, PyTorch, ChemProp, etc.).
+"""
+
+from pathlib import Path
+import time
+import logging
+
+from sagemaker.estimator import Estimator
+
+# Workbench Imports
+from workbench.api.model import Model
+from workbench.api.endpoint import Endpoint
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework, ModelImages
+from workbench.core.artifacts.artifact import Artifact
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+from workbench.model_scripts.script_generation import generate_model_script
+from workbench.utils.config_manager import ConfigManager
+from workbench.utils.model_utils import supported_instance_types
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+class MetaModel(Model):
+    """MetaModel: A Model that aggregates predictions from child endpoints.
+
+    Common Usage:
+        ```python
+        # Create a meta model from existing endpoints
+        meta = MetaModel.create(
+            name="my-meta-model",
+            child_endpoints=["endpoint-1", "endpoint-2", "endpoint-3"],
+            target_column="target"
+        )
+
+        # Deploy like any other model
+        endpoint = meta.to_endpoint()
+        ```
+    """
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        target_column: str,
+        description: str = None,
+        tags: list[str] = None,
+    ) -> "MetaModel":
+        """Create a new MetaModel from a list of child endpoints.
+
+        Args:
+            name: Name for the meta model
+            child_endpoints: List of endpoint names to aggregate
+            target_column: Name of the target column (for metadata)
+            description: Optional description for the model
+            tags: Optional list of tags
+
+        Returns:
+            MetaModel: The created meta model
+        """
+        Artifact.is_name_valid(name, delimiter="-", lower_case=False)
+
+        # Validate endpoints and get lineage info from primary endpoint
+        feature_list, feature_set_name, model_weights = cls._validate_and_get_lineage(child_endpoints)
+
+        # Delete existing model if it exists
+        log.important(f"Trying to delete existing model {name}...")
+        ModelCore.managed_delete(name)
+
+        # Run training and register model
+        aws_clamp = AWSAccountClamp()
+        estimator = cls._run_training(name, child_endpoints, target_column, model_weights, aws_clamp)
+        cls._register_model(name, child_endpoints, description, tags, estimator, aws_clamp)
+
+        # Set metadata and onboard
+        cls._set_metadata(name, target_column, feature_list, feature_set_name, child_endpoints)
+
+        log.important(f"MetaModel {name} created successfully!")
+        return cls(name)
+
+    @classmethod
+    def _validate_and_get_lineage(cls, child_endpoints: list[str]) -> tuple[list[str], str, dict[str, float]]:
+        """Validate child endpoints exist and get lineage info from primary endpoint.
+
+        Args:
+            child_endpoints: List of endpoint names
+
+        Returns:
+            tuple: (feature_list, feature_set_name, model_weights) from the primary endpoint's model
+        """
+        log.info("Verifying child endpoints and gathering model metrics...")
+        mae_scores = {}
+
+        for ep_name in child_endpoints:
+            ep = Endpoint(ep_name)
+            if not ep.exists():
+                raise ValueError(f"Child endpoint '{ep_name}' does not exist")
+
+            # Get model MAE from full_inference metrics
+            model = Model(ep.get_input())
+            metrics = model.get_inference_metrics("full_inference")
+            if metrics is not None and "mae" in metrics.columns:
+                mae = float(metrics["mae"].iloc[0])
+                mae_scores[ep_name] = mae
+                log.info(f"  {ep_name} -> {model.name}: MAE={mae:.4f}")
+            else:
+                log.warning(f"  {ep_name}: No full_inference metrics found, using default weight")
+                mae_scores[ep_name] = None
+
+        # Compute inverse-MAE weights (higher weight for lower MAE)
+        valid_mae = {k: v for k, v in mae_scores.items() if v is not None}
+        if valid_mae:
+            inv_mae = {k: 1.0 / v for k, v in valid_mae.items()}
+            total = sum(inv_mae.values())
+            model_weights = {k: v / total for k, v in inv_mae.items()}
+            # Fill in missing weights with equal share of remaining weight
+            missing = [k for k in mae_scores if mae_scores[k] is None]
+            if missing:
+                equal_weight = (1.0 - sum(model_weights.values())) / len(missing)
+                for k in missing:
+                    model_weights[k] = equal_weight
+        else:
+            # No metrics available, use equal weights
+            model_weights = {k: 1.0 / len(child_endpoints) for k in child_endpoints}
+            log.warning("No MAE metrics found, using equal weights")
+
+        log.info(f"Model weights: {model_weights}")
+
+        # Use first endpoint as primary - backtrack to get model and feature set
+        primary_endpoint = Endpoint(child_endpoints[0])
+        primary_model = Model(primary_endpoint.get_input())
+        feature_list = primary_model.features()
+        feature_set_name = primary_model.get_input()
+
+        log.info(
+            f"Primary endpoint: {child_endpoints[0]} -> Model: {primary_model.name} -> FeatureSet: {feature_set_name}"
+        )
+        return feature_list, feature_set_name, model_weights
+
+    @classmethod
+    def _run_training(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        target_column: str,
+        model_weights: dict[str, float],
+        aws_clamp: AWSAccountClamp,
+    ) -> Estimator:
+        """Run the minimal training job that saves the meta model config.
+
+        Args:
+            name: Model name
+            child_endpoints: List of endpoint names
+            target_column: Target column name
+            model_weights: Dict mapping endpoint name to weight
+            aws_clamp: AWS account clamp
+
+        Returns:
+            Estimator: The fitted estimator
+        """
+        sm_session = aws_clamp.sagemaker_session()
+        cm = ConfigManager()
+        workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
+        models_s3_path = f"s3://{workbench_bucket}/models"
+
+        # Generate the model script from template
+        template_params = {
+            "model_type": ModelType.REGRESSOR,
+            "model_framework": ModelFramework.META,
+            "child_endpoints": child_endpoints,
+            "target_column": target_column,
+            "model_weights": model_weights,
+            "model_metrics_s3_path": f"{models_s3_path}/{name}/training",
+            "aws_region": sm_session.boto_region_name,
+        }
+        script_path = generate_model_script(template_params)
+
+        # Create estimator
+        training_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_training")
+        log.info(f"Using Meta Training Image: {training_image}")
+        estimator = Estimator(
+            entry_point=Path(script_path).name,
+            source_dir=str(Path(script_path).parent),
+            role=aws_clamp.aws_session.get_workbench_execution_role_arn(),
+            instance_count=1,
+            instance_type="ml.m5.large",
+            sagemaker_session=sm_session,
+            image_uri=training_image,
+        )
+
+        # Run training (no input data needed - just saves config)
+        log.important(f"Creating MetaModel {name}...")
+        estimator.fit()
+
+        return estimator
+
+    @classmethod
+    def _register_model(
+        cls,
+        name: str,
+        child_endpoints: list[str],
+        description: str,
+        tags: list[str],
+        estimator: Estimator,
+        aws_clamp: AWSAccountClamp,
+    ):
+        """Create model group and register the model.
+
+        Args:
+            name: Model name
+            child_endpoints: List of endpoint names
+            description: Model description
+            tags: Model tags
+            estimator: Fitted estimator
+            aws_clamp: AWS account clamp
+        """
+        sm_session = aws_clamp.sagemaker_session()
+        model_description = description or f"Meta model aggregating: {', '.join(child_endpoints)}"
+
+        # Create model group
+        aws_clamp.sagemaker_client().create_model_package_group(
+            ModelPackageGroupName=name,
+            ModelPackageGroupDescription=model_description,
+            Tags=[{"Key": "workbench_tags", "Value": "::".join(tags or [name])}],
+        )
+
+        # Register the model with meta inference image
+        inference_image = ModelImages.get_image_uri(sm_session.boto_region_name, "meta_inference")
+        log.important(f"Registering model {name} with Inference Image {inference_image}...")
+        estimator.create_model(role=aws_clamp.aws_session.get_workbench_execution_role_arn()).register(
+            model_package_group_name=name,
+            image_uri=inference_image,
+            content_types=["text/csv"],
+            response_types=["text/csv"],
+            inference_instances=supported_instance_types("x86_64"),
+            transform_instances=["ml.m5.large", "ml.m5.xlarge"],
+            approval_status="Approved",
+            description=model_description,
+        )
+
+    @classmethod
+    def _set_metadata(
+        cls, name: str, target_column: str, feature_list: list[str], feature_set_name: str, child_endpoints: list[str]
+    ):
+        """Set model metadata and onboard.
+
+        Args:
+            name: Model name
+            target_column: Target column name
+            feature_list: List of feature names
+            feature_set_name: Name of the input FeatureSet
+            child_endpoints: List of child endpoint names
+        """
+        time.sleep(3)
+        output_model = ModelCore(name)
+        output_model._set_model_type(ModelType.UQ_REGRESSOR)
+        output_model._set_model_framework(ModelFramework.META)
+        output_model.set_input(feature_set_name, force=True)
+        output_model.upsert_workbench_meta({"workbench_model_target": target_column})
+        output_model.upsert_workbench_meta({"workbench_model_features": feature_list})
+        output_model.upsert_workbench_meta({"child_endpoints": child_endpoints})
+        output_model.onboard_with_args(ModelType.UQ_REGRESSOR, target_column, feature_list=feature_list)
+
+
+if __name__ == "__main__":
+    """Exercise the MetaModel Class"""
+
+    meta = MetaModel.create(
+        name="logd-meta",
+        child_endpoints=["logd-xgb", "logd-pytorch", "logd-chemprop"],
+        target_column="logd",
+        description="Meta model for LogD prediction",
+        tags=["meta", "logd", "ensemble"],
+    )
+    print(meta.summary())
+
+    # Create an endpoint for the meta model
+    end = meta.to_endpoint(tags=["meta", "logd"])
+    end.set_owner("BW")
+    end.auto_inference()
+
+    # Test loading an existing meta model
+    meta = MetaModel("logd-meta")
+    print(meta.details())
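The endpoint weighting in `_validate_and_get_lineage` is plain inverse-MAE normalization: each child endpoint gets weight (1/MAE) / Σ(1/MAE), so lower-error models dominate the ensemble. A standalone sketch of the arithmetic, with made-up MAE values that do not come from this diff:

```python
# Made-up MAE values for three hypothetical child endpoints
mae_scores = {"ep-a": 0.40, "ep-b": 0.50, "ep-c": 0.80}

# Invert (lower MAE -> larger share), then normalize to sum to 1.0
inv_mae = {k: 1.0 / v for k, v in mae_scores.items()}   # 2.5, 2.0, 1.25
total = sum(inv_mae.values())                           # 5.75
model_weights = {k: v / total for k, v in inv_mae.items()}

# ~0.435, ~0.348, ~0.217
print({k: round(v, 3) for k, v in model_weights.items()})
```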
workbench/api/model.py CHANGED
@@ -7,10 +7,15 @@ Dashboard UI, which provides additional model details and performance metrics
 
 # Workbench Imports
 from workbench.core.artifacts.artifact import Artifact
-from workbench.core.artifacts.model_core import ModelCore, ModelType  # noqa: F401
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework  # noqa: F401
 from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
 from workbench.api.endpoint import Endpoint
-from workbench.utils.model_utils import proximity_model, uq_model
+from workbench.utils.model_utils import (
+    proximity_model_local,
+    fingerprint_prox_model_local,
+    noise_model_local,
+    cleanlab_model_local,
+)
 
 
 class Model(ModelCore):
@@ -83,33 +88,55 @@ class Model(ModelCore):
         end.set_owner(self.get_owner())
         return end
 
-    def prox_model(self, prox_model_name: str = None, track_columns: list = None) -> "Model":
-        """Create a Proximity Model for this Model
+    def prox_model(self, include_all_columns: bool = False):
+        """Create a local Proximity Model for this Model
 
         Args:
-            prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
-            track_columns (list, optional): List of columns to track in the Proximity Model.
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
 
         Returns:
-            Model: The Proximity Model
+            FeatureSpaceProximity: A local FeatureSpaceProximity Model
         """
-        if prox_model_name is None:
-            prox_model_name = self.model_name + "-prox"
-        return proximity_model(self, prox_model_name, track_columns=track_columns)
+        return proximity_model_local(self, include_all_columns=include_all_columns)
 
-    def uq_model(self, uq_model_name: str = None, train_all_data: bool = False) -> "Model":
-        """Create a Uncertainty Quantification Model for this Model
+    def fp_prox_model(
+        self,
+        include_all_columns: bool = False,
+        radius: int = 2,
+        n_bits: int = 1024,
+        counts: bool = False,
+    ):
+        """Create a local Fingerprint Proximity Model for this Model
 
         Args:
-            uq_model_name (str, optional): Name of the UQ Model (if not specified, a name will be generated)
-            train_all_data (bool, optional): Whether to train the UQ Model on all data (default: False)
+            include_all_columns (bool): Include all DataFrame columns in results (default: False)
+            radius (int): Morgan fingerprint radius (default: 2)
+            n_bits (int): Number of bits for the fingerprint (default: 1024)
+            counts (bool): Use count fingerprints instead of binary (default: False)
+
+        Returns:
+            FingerprintProximity: A local FingerprintProximity Model
+        """
+        return fingerprint_prox_model_local(
+            self, include_all_columns=include_all_columns, radius=radius, n_bits=n_bits, counts=counts
+        )
+
+    def noise_model(self):
+        """Create a local Noise Model for this Model
 
         Returns:
-            Model: The UQ Model
+            NoiseModel: A local Noise Model
         """
-        if uq_model_name is None:
-            uq_model_name = self.model_name + "-uq"
-        return uq_model(self, uq_model_name, train_all_data=train_all_data)
+        return noise_model_local(self)
+
+    def cleanlab_model(self):
+        """Create a CleanLearning model for this Model's training data.
+
+        Returns:
+            CleanLearning: A fitted cleanlab model. Use get_label_issues() to get
+                a DataFrame with id_column, label_quality, predicted_label, given_label, is_label_issue.
+        """
+        return cleanlab_model_local(self)
 
 
 if __name__ == "__main__":
@@ -121,6 +148,10 @@ if __name__ == "__main__":
     pprint(my_model.summary())
     pprint(my_model.details())
 
-    # Create an Endpoint from the Model
-    my_endpoint = my_model.to_endpoint()
-    pprint(my_endpoint.summary())
+    # Create an Endpoint from the Model (commented out for now)
+    # my_endpoint = my_model.to_endpoint()
+    # pprint(my_endpoint.summary())
+
+    # Create a local Proximity Model for this Model
+    prox_model = my_model.prox_model()
+    print(prox_model.neighbors(3398))
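These helpers replace the old deployed `-prox`/`-uq` companion models with purely local objects. A hedged usage sketch of the new surface; the model name and row id below are illustrative, not taken from this diff:

```python
from workbench.api import Model

model = Model("my-regression-model")  # hypothetical model name

# Local proximity model: built in-process from the model's training data,
# no separate SageMaker model or endpoint is created
prox = model.prox_model(include_all_columns=True)
print(prox.neighbors(3398))  # row id is illustrative

# Label-quality check via cleanlab; result columns per the docstring above
issues = model.cleanlab_model().get_label_issues()
print(issues[issues["is_label_issue"]].head())
```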
workbench/api/parameter_store.py CHANGED
@@ -1,13 +1,10 @@
 """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store."""
 
-from typing import Union
-import logging
-
 # Workbench Imports
-from workbench.core.cloud_platform.aws.aws_parameter_store import AWSParameterStore
+from workbench.core.artifacts.parameter_store_core import ParameterStoreCore
 
 
-class ParameterStore(AWSParameterStore):
+class ParameterStore(ParameterStoreCore):
     """ParameterStore: Manages Workbench parameters in a Cloud Based Parameter Store.
 
     Common Usage:
@@ -43,56 +40,10 @@ class ParameterStore(ParameterStoreCore):
 
     def __init__(self):
         """ParameterStore Init Method"""
-        self.log = logging.getLogger("workbench")
 
-        # Initialize the SuperClass
+        # Initialize parent class
         super().__init__()
 
-    def list(self, prefix: str = None) -> list:
-        """List all parameters in the AWS Parameter Store, optionally filtering by a prefix.
-
-        Args:
-            prefix (str, optional): A prefix to filter the parameters by. Defaults to None.
-
-        Returns:
-            list: A list of parameter names and details.
-        """
-        return super().list(prefix=prefix)
-
-    def get(self, name: str, warn: bool = True, decrypt: bool = True) -> Union[str, list, dict, None]:
-        """Retrieve a parameter value from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to retrieve.
-            warn (bool): Whether to log a warning if the parameter is not found.
-            decrypt (bool): Whether to decrypt secure string parameters.
-
-        Returns:
-            Union[str, list, dict, None]: The value of the parameter or None if not found.
-        """
-        return super().get(name=name, warn=warn, decrypt=decrypt)
-
-    def upsert(self, name: str, value):
-        """Insert or update a parameter in the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter.
-            value (str | list | dict): The value of the parameter.
-        """
-        super().upsert(name=name, value=value)
-
-    def delete(self, name: str):
-        """Delete a parameter from the AWS Parameter Store.
-
-        Args:
-            name (str): The name of the parameter to delete.
-        """
-        super().delete(name=name)
-
-    def __repr__(self):
-        """Return a string representation of the ParameterStore object."""
-        return super().__repr__()
-
 
 if __name__ == "__main__":
     """Exercise the ParameterStore Class"""
workbench/cached/cached_meta.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
 from functools import wraps
 from concurrent.futures import ThreadPoolExecutor
 
-
 # Workbench Imports
 from workbench.core.cloud_platform.cloud_meta import CloudMeta
 from workbench.utils.workbench_cache import WorkbenchCache