workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
--- a/workbench/core/transforms/features_to_model/features_to_model.py
+++ b/workbench/core/transforms/features_to_model/features_to_model.py
@@ -1,6 +1,7 @@
 """FeaturesToModel: Train/Create a Model from a Feature Set"""
 
 from pathlib import Path
+from typing import Union
 from sagemaker.estimator import Estimator
 import awswrangler as wr
 from datetime import datetime, timezone
@@ -9,7 +10,7 @@ import time
 # Local Imports
 from workbench.core.transforms.transform import Transform, TransformInput, TransformOutput
 from workbench.core.artifacts.feature_set_core import FeatureSetCore
-from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelImages
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework, ModelImages
 from workbench.core.artifacts.artifact import Artifact
 from workbench.model_scripts.script_generation import generate_model_script, fill_template
 from workbench.utils.model_utils import supported_instance_types
@@ -33,6 +34,7 @@ class FeaturesToModel(Transform):
         feature_name: str,
         model_name: str,
         model_type: ModelType,
+        model_framework=ModelFramework.XGBOOST,
         model_class=None,
         model_import_str=None,
         custom_script=None,
@@ -46,6 +48,7 @@ class FeaturesToModel(Transform):
             feature_name (str): Name of the FeatureSet to use as input
             model_name (str): Name of the Model to create as output
             model_type (ModelType): ModelType.REGRESSOR or ModelType.CLASSIFIER, etc.
+            model_framework (ModelFramework, optional): The model framework (default ModelFramework.XGBOOST)
             model_class (str, optional): The scikit model (e.g. KNeighborsRegressor) (default None)
             model_import_str (str, optional): The import string for the model (default None)
             custom_script (str, optional): Custom script to use for the model (default None)
@@ -65,6 +68,7 @@ class FeaturesToModel(Transform):
         self.input_type = TransformInput.FEATURE_SET
         self.output_type = TransformOutput.MODEL
         self.model_type = model_type
+        self.model_framework = model_framework
         self.model_class = model_class
         self.model_import_str = model_import_str
         self.custom_script = str(custom_script) if custom_script else None
@@ -80,12 +84,17 @@ class FeaturesToModel(Transform):
         self.inference_arch = inference_arch
 
     def transform_impl(
-        self, target_column: str, description: str = None, feature_list: list = None, train_all_data=False, **kwargs
+        self,
+        target_column: Union[str, list[str]],
+        description: str = None,
+        feature_list: list = None,
+        train_all_data=False,
+        **kwargs,
     ):
         """Generic Features to Model: Note you should create a new class and inherit from
         this one to include specific logic for your Feature Set/Model
         Args:
-            target_column (str): Column name of the target variable
+            target_column (str or list[str]): Column name(s) of the target variable(s)
             description (str): Description of the model (optional)
             feature_list (list[str]): A list of columns for the features (default None, will try to guess)
             train_all_data (bool): Train on ALL (100%) of the data (default False)
@@ -102,9 +111,11 @@ class FeaturesToModel(Transform):
         s3_training_path = feature_set.create_s3_training_data()
         self.log.info(f"Created new training data {s3_training_path}...")
 
-        # Report the target column
+        # Report the target column(s)
         self.target_column = target_column
-        self.log.info(f"Target column: {self.target_column}")
+        # Normalize target_column to a list for internal use
+        target_list = [target_column] if isinstance(target_column, str) else (target_column or [])
+        self.log.info(f"Target column(s): {self.target_column}")
 
         # Did they specify a feature list?
         if feature_list:
@@ -131,7 +142,7 @@ class FeaturesToModel(Transform):
             "is_deleted",
             "event_time",
             "training",
-        ] + [self.target_column]
+        ] + target_list
         feature_list = [c for c in all_columns if c not in filter_list]
 
         # AWS Feature Store has 3 user column types (String, Integral, Fractional)
@@ -154,11 +165,14 @@ class FeaturesToModel(Transform):
         self.log.important(f"Feature List for Modeling: {self.model_feature_list}")
 
         # Set up our parameters for the model script
+        # ChemProp expects target_column as a list; other templates expect a string
+        target_for_template = target_list if self.model_framework == ModelFramework.CHEMPROP else self.target_column
         template_params = {
             "model_imports": self.model_import_str,
             "model_type": self.model_type,
+            "model_framework": self.model_framework,
             "model_class": self.model_class,
-            "target_column": self.target_column,
+            "target_column": target_for_template,
             "feature_list": self.model_feature_list,
             "compressed_features": feature_set.get_compressed_features(),
             "model_metrics_s3_path": self.model_training_root,
@@ -184,23 +198,27 @@ class FeaturesToModel(Transform):
         # Generate our model script
         script_path = generate_model_script(template_params)
 
-        # Metric Definitions for Regression
+        # Metric Definitions for Regression (matches model script output format)
         if self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
             metric_definitions = [
-                {"Name": "RMSE", "Regex": "RMSE: ([0-9.]+)"},
-                {"Name": "MAE", "Regex": "MAE: ([0-9.]+)"},
-                {"Name": "R2", "Regex": "R2: ([0-9.]+)"},
-                {"Name": "NumRows", "Regex": "NumRows: ([0-9]+)"},
+                {"Name": "rmse", "Regex": r"rmse: ([0-9.]+)"},
+                {"Name": "mae", "Regex": r"mae: ([0-9.]+)"},
+                {"Name": "medae", "Regex": r"medae: ([0-9.]+)"},
+                {"Name": "r2", "Regex": r"r2: ([0-9.-]+)"},
+                {"Name": "spearmanr", "Regex": r"spearmanr: ([0-9.-]+)"},
+                {"Name": "support", "Regex": r"support: ([0-9]+)"},
             ]
 
         # Metric Definitions for Classification
         elif self.model_type == ModelType.CLASSIFIER:
             # We need to get creative with the Classification Metrics
+            # Note: Classification only supports single target
+            class_target = target_list[0] if target_list else self.target_column
 
             # Grab all the target column class values (class labels)
             table = feature_set.data_source.table
-            self.class_labels = feature_set.query(f'select DISTINCT {self.target_column} FROM "{table}"')[
-                self.target_column
+            self.class_labels = feature_set.query(f'select DISTINCT {class_target} FROM "{table}"')[
+                class_target
             ].to_list()
 
             # Sanity check on the targets
@@ -209,20 +227,14 @@ class FeaturesToModel(Transform):
                 self.log.critical(msg)
                 raise ValueError(msg)
 
-            # Dynamically create the metric definitions
-            metrics = ["precision", "recall", "fscore"]
+            # Dynamically create the metric definitions (per-class precision/recall/f1/support)
+            # Note: Confusion matrix metrics are skipped to stay under SageMaker's 40 metric limit
+            metrics = ["precision", "recall", "f1", "support"]
             metric_definitions = []
             for t in self.class_labels:
                 for m in metrics:
                     metric_definitions.append({"Name": f"Metrics:{t}:{m}", "Regex": f"Metrics:{t}:{m} ([0-9.]+)"})
 
-            # Add the confusion matrix metrics
-            for row in self.class_labels:
-                for col in self.class_labels:
-                    metric_definitions.append(
-                        {"Name": f"ConfusionMatrix:{row}:{col}", "Regex": f"ConfusionMatrix:{row}:{col} ([0-9.]+)"}
-                    )
-
         # If the model type is UNKNOWN, our metric_definitions will be empty
         else:
             self.log.important(f"ModelType is {self.model_type}, skipping metric_definitions...")
@@ -233,13 +245,21 @@ class FeaturesToModel(Transform):
         source_dir = str(Path(script_path).parent)
 
         # Create a Sagemaker Model with our script
-        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image, "0.1")
+        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image)
+
+        # Use GPU instance for ChemProp/PyTorch, CPU for others
+        if self.model_framework in [ModelFramework.CHEMPROP, ModelFramework.PYTORCH]:
+            train_instance_type = "ml.g6.xlarge"  # NVIDIA L4 GPU, ~$0.80/hr
+            self.log.important(f"Using GPU instance {train_instance_type} for {self.model_framework.value}")
+        else:
+            train_instance_type = "ml.m5.xlarge"
+
         self.estimator = Estimator(
             entry_point=entry_point,
             source_dir=source_dir,
             role=self.workbench_role_arn,
             instance_count=1,
-            instance_type="ml.m5.xlarge",
+            instance_type=train_instance_type,
             sagemaker_session=self.sm_session,
             image_uri=image,
             metric_definitions=metric_definitions,
@@ -264,13 +284,20 @@ class FeaturesToModel(Transform):
         self.log.important(f"Creating new model {self.output_name}...")
         self.create_and_register_model(**kwargs)
 
+        # Make a copy of the training view, to lock-in the training data used for this model
+        model_training_view_name = f"{self.output_name.replace('-', '_')}_training"
+        self.log.important(f"Creating Model Training View: {model_training_view_name}...")
+        feature_set.view("training").copy(f"{model_training_view_name}")
+
     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() on the Model"""
         self.log.info("Post-Transform: Calling onboard() on the Model...")
         time.sleep(3)  # Give AWS time to complete Model register
 
-        # Store the model feature_list and target_column in the workbench_meta
-        output_model = ModelCore(self.output_name, model_type=self.model_type)
+        # Store the model metadata information
+        output_model = ModelCore(self.output_name)
+        output_model._set_model_type(self.model_type)
+        output_model._set_model_framework(self.model_framework)
         output_model.upsert_workbench_meta({"workbench_model_features": self.model_feature_list})
         output_model.upsert_workbench_meta({"workbench_model_target": self.target_column})
 
@@ -301,7 +328,7 @@ class FeaturesToModel(Transform):
 
         # Register our model
         image = ModelImages.get_image_uri(
-            self.sm_session.boto_region_name, self.inference_image, "0.1", self.inference_arch
+            self.sm_session.boto_region_name, self.inference_image, architecture=self.inference_arch
         )
         self.log.important(f"Registering model {self.output_name} with Inference Image {image}...")
         model = self.estimator.create_model(role=self.workbench_role_arn)
@@ -325,12 +352,11 @@ if __name__ == "__main__":
 
     # Regression Model
     input_name = "abalone_features"
-    output_name = "test-abalone-regression"
+    output_name = "abalone-regression"
    to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.REGRESSOR)
    to_model.set_output_tags(["test"])
    to_model.transform(target_column="class_number_of_rings", description="Test Abalone Regression")
 
-    """
     # Classification Model
     input_name = "wine_features"
     output_name = "wine-classification"
@@ -340,10 +366,10 @@ if __name__ == "__main__":
 
     # Quantile Regression Model (Abalone)
     input_name = "abalone_features"
-    output_name = "abalone-quantile-reg"
+    output_name = "abalone-regression-uq"
     to_model = FeaturesToModel(input_name, output_name, ModelType.UQ_REGRESSOR)
-    to_model.set_output_tags(["abalone", "quantiles"])
-    to_model.transform(target_column="class_number_of_rings", description="Abalone Quantile Regression")
+    to_model.set_output_tags(["abalone", "uq"])
+    to_model.transform(target_column="class_number_of_rings", description="Abalone UQ Regression")
 
     # Scikit-Learn Kmeans Clustering Model
     input_name = "wine_features"
@@ -397,7 +423,7 @@ if __name__ == "__main__":
     scripts_root = Path(__file__).resolve().parents[3] / "model_scripts"
     my_script = scripts_root / "custom_models" / "chem_info" / "molecular_descriptors.py"
     input_name = "aqsol_features"
-    output_name = "smiles-to-taut-md-stereo-v0"
+    output_name = "test-smiles-to-taut-md-stereo"
     to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.TRANSFORMER, custom_script=my_script)
     to_model.set_output_tags(["smiles", "molecular descriptors"])
     to_model.transform(target_column=None, feature_list=["smiles"], description="Smiles to Molecular Descriptors")
@@ -410,13 +436,3 @@ if __name__ == "__main__":
     to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.TRANSFORMER, custom_script=my_script)
     to_model.set_output_tags(["smiles", "morgan fingerprints"])
     to_model.transform(target_column=None, feature_list=["smiles"], description="Smiles to Morgan Fingerprints")
-
-    # Tautomerization Model
-    scripts_root = Path(__file__).resolve().parents[3] / "model_scripts"
-    my_script = scripts_root / "custom_models" / "chem_info" / "tautomerize.py"
-    input_name = "aqsol_features"
-    output_name = "tautomerize-v0"
-    to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.TRANSFORMER, custom_script=my_script)
-    to_model.set_output_tags(["smiles", "tautomerization"])
-    to_model.transform(target_column=None, feature_list=["smiles"], description="Tautomerize Smiles")
-    """
--- a/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
+++ b/workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
@@ -1,6 +1,7 @@
 """ModelToEndpoint: Deploy an Endpoint for a Model"""
 
 import time
+from botocore.exceptions import ClientError
 from sagemaker import ModelPackage
 from sagemaker.serializers import CSVSerializer
 from sagemaker.deserializers import CSVDeserializer
@@ -102,10 +103,21 @@ class ModelToEndpoint(Transform):
         # Is this a serverless deployment?
         serverless_config = None
         if self.serverless:
+            # For PyTorch or ChemProp we need at least 4GB of memory
+            from workbench.api import ModelFramework
+
+            self.log.info(f"Model Framework: {workbench_model.model_framework}")
+            if workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]:
+                if mem_size < 4096:
+                    self.log.important(
+                        f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)"
+                    )
+                    mem_size = 4096
             serverless_config = ServerlessInferenceConfig(
                 memory_size_in_mb=mem_size,
                 max_concurrency=max_concurrency,
             )
+            self.log.important(f"Serverless Config: Memory={mem_size}MB, MaxConcurrency={max_concurrency}")
 
         # Configure data capture if requested (and not serverless)
         data_capture_config = None
@@ -126,16 +138,37 @@ class ModelToEndpoint(Transform):
 
         # Deploy the Endpoint
         self.log.important(f"Deploying the Endpoint {self.output_name}...")
-        model_package.deploy(
-            initial_instance_count=1,
-            instance_type=self.instance_type,
-            serverless_inference_config=serverless_config,
-            endpoint_name=self.output_name,
-            serializer=CSVSerializer(),
-            deserializer=CSVDeserializer(),
-            data_capture_config=data_capture_config,
-            tags=aws_tags,
-        )
+        try:
+            model_package.deploy(
+                initial_instance_count=1,
+                instance_type=self.instance_type,
+                serverless_inference_config=serverless_config,
+                endpoint_name=self.output_name,
+                serializer=CSVSerializer(),
+                deserializer=CSVDeserializer(),
+                data_capture_config=data_capture_config,
+                tags=aws_tags,
+                container_startup_health_check_timeout=300,
+            )
+        except ClientError as e:
+            # Check if this is the "endpoint config already exists" error
+            if "Cannot create already existing endpoint configuration" in str(e):
+                self.log.warning("Endpoint config already exists, deleting and retrying...")
+                self.sm_client.delete_endpoint_config(EndpointConfigName=self.output_name)
+                # Retry the deploy
+                model_package.deploy(
+                    initial_instance_count=1,
+                    instance_type=self.instance_type,
+                    serverless_inference_config=serverless_config,
+                    endpoint_name=self.output_name,
+                    serializer=CSVSerializer(),
+                    deserializer=CSVDeserializer(),
+                    data_capture_config=data_capture_config,
+                    tags=aws_tags,
+                    container_startup_health_check_timeout=300,
+                )
+            else:
+                raise
 
     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() for the Endpoint"""
--- a/workbench/core/transforms/pandas_transforms/pandas_to_features.py
+++ b/workbench/core/transforms/pandas_transforms/pandas_to_features.py
@@ -68,6 +68,15 @@ class PandasToFeatures(Transform):
         self.output_df = input_df.copy()
         self.one_hot_columns = one_hot_columns or []
 
+        # Warn about known AWS Iceberg bug with event_time_column
+        if event_time_column is not None:
+            self.log.warning(
+                f"event_time_column='{event_time_column}' specified. Note: AWS has a known bug with "
+                "Iceberg FeatureGroups where varying event times across multiple days can cause "
+                "duplicate rows in the offline store. Setting event_time_column=None."
+            )
+            self.event_time_column = None
+
         # Now Prepare the DataFrame for its journey into an AWS FeatureGroup
         self.prep_dataframe()
 
@@ -327,9 +336,36 @@ class PandasToFeatures(Transform):
         self.delete_existing()
         self.output_feature_group = self.create_feature_group()
 
+    def mac_spawn_hack(self):
+        """Workaround for macOS Tahoe fork/spawn issue with SageMaker FeatureStore ingest.
+
+        See: https://github.com/aws/sagemaker-python-sdk/issues/5312
+        macOS Tahoe 26+ has issues with forked processes creating boto3 sessions.
+        This forces spawn mode on macOS to avoid the hang.
+        """
+        import platform
+
+        if platform.system() == "Darwin":  # macOS
+            self.log.warning("macOS detected, forcing 'spawn' mode for multiprocessing (Tahoe hang workaround)")
+            import multiprocessing
+
+            try:
+                import multiprocess
+
+                multiprocess.set_start_method("spawn", force=True)
+            except (RuntimeError, ImportError):
+                pass  # Already set or multiprocess not available
+            try:
+                multiprocessing.set_start_method("spawn", force=True)
+            except RuntimeError:
+                pass  # Already set
+
     def transform_impl(self):
         """Transform Implementation: Ingest the data into the Feature Group"""
 
+        # Workaround for macOS Tahoe hang issue
+        self.mac_spawn_hack()
+
         # Now we actually push the data into the Feature Group (called ingestion)
         self.log.important(f"Ingesting rows into Feature Group {self.output_name}...")
         ingest_manager = self.output_feature_group.ingest(self.output_df, max_workers=8, max_processes=4, wait=False)
@@ -373,7 +409,7 @@ class PandasToFeatures(Transform):
 
         # Set Hold Out Ids (if we got them during creation)
         if self.incoming_hold_out_ids:
-            self.output_feature_set.set_training_holdouts(self.id_column, self.incoming_hold_out_ids)
+            self.output_feature_set.set_training_holdouts(self.incoming_hold_out_ids)
 
     def ensure_feature_group_created(self, feature_group):
         status = feature_group.describe().get("FeatureGroupStatus")
@@ -435,7 +471,7 @@ if __name__ == "__main__":
 
     # Create my DF to Feature Set Transform (with one-hot encoding)
     df_to_features = PandasToFeatures("test_features")
-    df_to_features.set_input(data_df, id_column="id", one_hot_columns=["food"])
+    df_to_features.set_input(data_df, id_column="id", event_time_column="date", one_hot_columns=["food"])
    df_to_features.set_output_tags(["test", "small"])
    df_to_features.transform()

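The new `mac_spawn_hack` works around a hang where forked ingest workers create boto3 sessions on macOS. Scripts that drive FeatureStore ingest outside of `PandasToFeatures` can apply the same workaround up front; a standalone sketch under the same assumptions (the optional `multiprocess` package is tried first, mirroring the method above):

```python
import multiprocessing
import platform


def force_spawn_on_macos() -> None:
    """Force the 'spawn' start method on macOS (sagemaker-python-sdk issue #5312)."""
    if platform.system() != "Darwin":
        return
    try:
        import multiprocess  # optional package used by the ingest path when available

        multiprocess.set_start_method("spawn", force=True)
    except (ImportError, RuntimeError):
        pass  # not installed, or the start method was already set
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass  # already set


force_spawn_on_macos()  # call before any FeatureStore ingest kicks off worker processes
```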
--- a/workbench/core/views/training_view.py
+++ b/workbench/core/views/training_view.py
@@ -3,14 +3,18 @@
 from typing import Union
 
 # Workbench Imports
-from workbench.api import DataSource, FeatureSet
+from workbench.api import FeatureSet
 from workbench.core.views.view import View
 from workbench.core.views.create_view import CreateView
 from workbench.core.views.view_utils import get_column_list
 
 
 class TrainingView(CreateView):
-    """TrainingView Class: A View with an additional training column that marks holdout ids
+    """TrainingView Class: A View with an additional training column (80/20 or holdout ids).
+    The TrainingView class creates a SQL view that includes all columns from the source table
+    along with an additional boolean column named "training". This view can also include
+    a SQL filter expression to filter the rows included in the view.
+
 
     Common Usage:
         ```python
@@ -19,8 +23,9 @@ class TrainingView(CreateView):
         training_view = TrainingView.create(fs)
         df = training_view.pull_dataframe()
 
-        # Create a TrainingView with a specific set of columns
-        training_view = TrainingView.create(fs, column_list=["my_col1", "my_col2"])
+        # Create a TrainingView with a specific filter expression
+        training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="age > 30")
+        df = training_view.pull_dataframe()
 
         # Query the view
         df = training_view.query(f"SELECT * FROM {training_view.table} where training = TRUE")
@@ -31,17 +36,21 @@ class TrainingView(CreateView):
     def create(
         cls,
         feature_set: FeatureSet,
-        source_table: str = None,
+        *,  # Enforce keyword arguments after feature_set
         id_column: str = None,
         holdout_ids: Union[list[str], list[int], None] = None,
+        filter_expression: str = None,
+        source_table: str = None,
     ) -> Union[View, None]:
         """Factory method to create and return a TrainingView instance.
 
         Args:
             feature_set (FeatureSet): A FeatureSet object
-            source_table (str, optional): The table/view to create the view from. Defaults to None.
             id_column (str, optional): The name of the id column. Defaults to None.
             holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+            filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                Defaults to None.
+            source_table (str, optional): The table/view to create the view from. Defaults to None.
 
         Returns:
             Union[View, None]: The created View object (or None if failed to create the view)
@@ -69,28 +78,36 @@ class TrainingView(CreateView):
         else:
             id_column = instance.auto_id_column
 
-        # If we don't have holdout ids, create a default training view
-        if not holdout_ids:
-            instance._default_training_view(instance.data_source, id_column)
-            return View(instance.data_source, instance.view_name, auto_create_view=False)
+        # Enclose each column name in double quotes
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+        # Build the training assignment logic
+        if holdout_ids:
+            # Format the list of holdout ids for SQL IN clause
+            if all(isinstance(id, str) for id in holdout_ids):
+                formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            else:
+                formatted_holdout_ids = ", ".join(map(str, holdout_ids))
 
-        # Format the list of holdout ids for SQL IN clause
-        if holdout_ids and all(isinstance(id, str) for id in holdout_ids):
-            formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            training_logic = f"""CASE
+                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                ELSE True
+            END AS training"""
         else:
-            formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+            # Default 80/20 split using modulo
+            training_logic = f"""CASE
+                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                ELSE False
+            END AS training"""
 
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+        # Build WHERE clause if filter_expression is provided
+        where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""
 
         # Construct the CREATE VIEW query
         create_view_query = f"""
             CREATE OR REPLACE VIEW {instance.table} AS
-            SELECT {sql_columns}, CASE
-                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
-                ELSE True
-            END AS training
-            FROM {instance.source_table}
+            SELECT {sql_columns}, {training_logic}
+            FROM {instance.source_table}{where_clause}
         """
 
         # Execute the CREATE VIEW query
@@ -99,35 +116,56 @@ class TrainingView(CreateView):
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)
 
-    # This is an internal method that's used to create a default training view
-    def _default_training_view(self, data_source: DataSource, id_column: str):
-        """Create a default view in Athena that assigns roughly 80% of the data to training
+    @classmethod
+    def create_with_sql(
+        cls,
+        feature_set: FeatureSet,
+        *,
+        sql_query: str,
+        id_column: str = None,
+    ) -> Union[View, None]:
+        """Factory method to create a TrainingView from a custom SQL query.
+
+        This method takes a complete SQL query and adds the default 80/20 training split.
+        Use this when you need complex queries like UNION ALL for oversampling.
 
         Args:
-            data_source (DataSource): The Workbench DataSource object
-            id_column (str): The name of the id column
+            feature_set (FeatureSet): A FeatureSet object
+            sql_query (str): Complete SELECT query (without the final semicolon)
+            id_column (str, optional): The name of the id column for training split. Defaults to None.
+
+        Returns:
+            Union[View, None]: The created View object (or None if failed)
         """
-        self.log.important(f"Creating default Training View {self.table}...")
+        # Instantiate the TrainingView
+        instance = cls("training", feature_set)
 
-        # Drop any columns generated from AWS
-        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
-        column_list = [col for col in data_source.columns if col not in aws_cols]
+        # Sanity check on the id column
+        if not id_column:
+            instance.log.important("No id column specified, using auto_id_column")
+            if not instance.auto_id_column:
+                instance.log.error("No id column specified and no auto_id_column found, aborting")
+                return None
+            id_column = instance.auto_id_column
 
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+        # Default 80/20 split using modulo
+        training_logic = f"""CASE
+            WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+            ELSE False
+        END AS training"""
 
-        # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
+        # Wrap the custom query and add training column
         create_view_query = f"""
-            CREATE OR REPLACE VIEW "{self.table}" AS
-            SELECT {sql_columns}, CASE
-                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True -- Assign 80% to training
-                ELSE False -- Assign roughly 20% to validation/test
-            END AS training
-            FROM {self.base_table_name}
+            CREATE OR REPLACE VIEW {instance.table} AS
+            SELECT *, {training_logic}
+            FROM ({sql_query}) AS custom_source
         """
 
         # Execute the CREATE VIEW query
-        data_source.execute_statement(create_view_query)
+        instance.data_source.execute_statement(create_view_query)
+
+        # Return the View
+        return View(instance.data_source, instance.view_name, auto_create_view=False)
 
 
 if __name__ == "__main__":
@@ -135,7 +173,7 @@ if __name__ == "__main__":
     from workbench.api import FeatureSet
 
     # Get the FeatureSet
-    fs = FeatureSet("test_features")
+    fs = FeatureSet("abalone_features")
 
     # Delete the existing training view
     training_view = TrainingView.create(fs)
@@ -152,9 +190,42 @@ if __name__ == "__main__":
 
     # Create a TrainingView with holdout ids
     my_holdout_ids = list(range(10))
-    training_view = TrainingView.create(fs, id_column="id", holdout_ids=my_holdout_ids)
+    training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)
 
     # Pull the training data
     df = training_view.pull_dataframe()
     print(df.head())
     print(df["training"].value_counts())
+    print(f"Shape: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test the filter expression
+    training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+    df = training_view.pull_dataframe()
+    print(df.head())
+    print(f"Shape with filter: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test create_with_sql with a custom query (UNION ALL for oversampling)
+    print("\n--- Testing create_with_sql with oversampling ---")
+    base_table = fs.table
+    replicate_ids = [0, 1, 2]  # Oversample these IDs
+
+    custom_sql = f"""
+    SELECT * FROM {base_table}
+
+    UNION ALL
+
+    SELECT * FROM {base_table}
+    WHERE auto_id IN ({', '.join(map(str, replicate_ids))})
+    """
+
+    training_view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")
+    df = training_view.pull_dataframe()
+    print(f"Shape with custom SQL: {df.shape}")
+    print(df["training"].value_counts())
+
+    # Verify oversampling - check if replicated IDs appear twice
+    for rep_id in replicate_ids:
+        count = len(df[df["auto_id"] == rep_id])
+        print(f"ID {rep_id} appears {count} times")