workbench 0.8.160__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


Files changed (114)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +12 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/api/parameter_store.py +5 -0
  13. workbench/core/artifacts/__init__.py +11 -2
  14. workbench/core/artifacts/artifact.py +11 -3
  15. workbench/core/artifacts/data_capture_core.py +355 -0
  16. workbench/core/artifacts/endpoint_core.py +256 -118
  17. workbench/core/artifacts/feature_set_core.py +265 -16
  18. workbench/core/artifacts/model_core.py +110 -63
  19. workbench/core/artifacts/monitor_core.py +33 -248
  20. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  21. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  22. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  23. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  24. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  25. workbench/core/transforms/features_to_model/features_to_model.py +45 -33
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  28. workbench/core/views/training_view.py +113 -42
  29. workbench/core/views/view.py +53 -3
  30. workbench/core/views/view_utils.py +4 -4
  31. workbench/model_scripts/chemprop/chemprop.template +852 -0
  32. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  33. workbench/model_scripts/chemprop/requirements.txt +11 -0
  34. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  36. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  37. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  38. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  39. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  40. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  41. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  42. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  43. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  44. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  45. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  46. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  47. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  48. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  49. workbench/model_scripts/pytorch_model/generated_model_script.py +390 -188
  50. workbench/model_scripts/pytorch_model/pytorch.template +387 -176
  51. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  52. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  53. workbench/model_scripts/script_generation.py +19 -10
  54. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  55. workbench/model_scripts/uq_models/mapie.template +605 -0
  56. workbench/model_scripts/uq_models/requirements.txt +1 -0
  57. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  58. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  59. workbench/repl/workbench_shell.py +28 -14
  60. workbench/scripts/endpoint_test.py +162 -0
  61. workbench/scripts/lambda_test.py +73 -0
  62. workbench/scripts/ml_pipeline_batch.py +137 -0
  63. workbench/scripts/ml_pipeline_sqs.py +186 -0
  64. workbench/scripts/monitor_cloud_watch.py +20 -100
  65. workbench/utils/aws_utils.py +4 -3
  66. workbench/utils/chem_utils/__init__.py +0 -0
  67. workbench/utils/chem_utils/fingerprints.py +134 -0
  68. workbench/utils/chem_utils/misc.py +194 -0
  69. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  70. workbench/utils/chem_utils/mol_standardize.py +450 -0
  71. workbench/utils/chem_utils/mol_tagging.py +348 -0
  72. workbench/utils/chem_utils/projections.py +209 -0
  73. workbench/utils/chem_utils/salts.py +256 -0
  74. workbench/utils/chem_utils/sdf.py +292 -0
  75. workbench/utils/chem_utils/toxicity.py +250 -0
  76. workbench/utils/chem_utils/vis.py +253 -0
  77. workbench/utils/chemprop_utils.py +760 -0
  78. workbench/utils/cloudwatch_handler.py +1 -1
  79. workbench/utils/cloudwatch_utils.py +137 -0
  80. workbench/utils/config_manager.py +3 -7
  81. workbench/utils/endpoint_utils.py +5 -7
  82. workbench/utils/license_manager.py +2 -6
  83. workbench/utils/model_utils.py +95 -34
  84. workbench/utils/monitor_utils.py +44 -62
  85. workbench/utils/pandas_utils.py +3 -3
  86. workbench/utils/pytorch_utils.py +526 -0
  87. workbench/utils/shap_utils.py +10 -2
  88. workbench/utils/workbench_logging.py +0 -3
  89. workbench/utils/workbench_sqs.py +1 -1
  90. workbench/utils/xgboost_model_utils.py +371 -156
  91. workbench/web_interface/components/model_plot.py +7 -1
  92. workbench/web_interface/components/plugin_unit_test.py +5 -2
  93. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  94. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  95. workbench/web_interface/components/plugins/model_details.py +9 -7
  96. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  97. {workbench-0.8.160.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  98. {workbench-0.8.160.dist-info → workbench-0.8.202.dist-info}/RECORD +102 -86
  99. {workbench-0.8.160.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  100. {workbench-0.8.160.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  101. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  102. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  103. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  104. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  105. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  106. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  107. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  108. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  109. workbench/utils/chem_utils.py +0 -1556
  110. workbench/utils/execution_environment.py +0 -211
  111. workbench/utils/fast_inference.py +0 -167
  112. workbench/utils/resource_utils.py +0 -39
  113. {workbench-0.8.160.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  114. {workbench-0.8.160.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
workbench/core/transforms/features_to_model/features_to_model.py

@@ -9,7 +9,7 @@ import time
 # Local Imports
 from workbench.core.transforms.transform import Transform, TransformInput, TransformOutput
 from workbench.core.artifacts.feature_set_core import FeatureSetCore
-from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelImages
+from workbench.core.artifacts.model_core import ModelCore, ModelType, ModelFramework, ModelImages
 from workbench.core.artifacts.artifact import Artifact
 from workbench.model_scripts.script_generation import generate_model_script, fill_template
 from workbench.utils.model_utils import supported_instance_types
@@ -33,12 +33,13 @@ class FeaturesToModel(Transform):
         feature_name: str,
         model_name: str,
         model_type: ModelType,
+        model_framework=ModelFramework.XGBOOST,
         model_class=None,
         model_import_str=None,
         custom_script=None,
         custom_args=None,
-        training_image="xgb_training",
-        inference_image="xgb_inference",
+        training_image="training",
+        inference_image="inference",
         inference_arch="x86_64",
     ):
         """FeaturesToModel Initialization
@@ -46,12 +47,13 @@ class FeaturesToModel(Transform):
             feature_name (str): Name of the FeatureSet to use as input
             model_name (str): Name of the Model to create as output
             model_type (ModelType): ModelType.REGRESSOR or ModelType.CLASSIFIER, etc.
+            model_framework (ModelFramework, optional): The model framework (default ModelFramework.XGBOOST)
             model_class (str, optional): The scikit model (e.g. KNeighborsRegressor) (default None)
             model_import_str (str, optional): The import string for the model (default None)
             custom_script (str, optional): Custom script to use for the model (default None)
             custom_args (dict, optional): Custom arguments to pass to custom model scripts (default None)
-            training_image (str, optional): Training image (default "xgb_training")
-            inference_image (str, optional): Inference image (default "xgb_inference")
+            training_image (str, optional): Training image (default "training")
+            inference_image (str, optional): Inference image (default "inference")
             inference_arch (str, optional): Inference architecture (default "x86_64")
         """
 
@@ -65,6 +67,7 @@ class FeaturesToModel(Transform):
         self.input_type = TransformInput.FEATURE_SET
         self.output_type = TransformOutput.MODEL
         self.model_type = model_type
+        self.model_framework = model_framework
         self.model_class = model_class
         self.model_import_str = model_import_str
         self.custom_script = str(custom_script) if custom_script else None
@@ -157,6 +160,7 @@ class FeaturesToModel(Transform):
         template_params = {
             "model_imports": self.model_import_str,
             "model_type": self.model_type,
+            "model_framework": self.model_framework,
             "model_class": self.model_class,
             "target_column": self.target_column,
             "feature_list": self.model_feature_list,
@@ -164,6 +168,7 @@ class FeaturesToModel(Transform):
             "model_metrics_s3_path": self.model_training_root,
             "train_all_data": train_all_data,
             "id_column": feature_set.id_column,
+            "hyperparameters": kwargs.get("hyperparameters", {}),
         }
 
         # Custom Script
@@ -183,13 +188,15 @@ class FeaturesToModel(Transform):
         # Generate our model script
         script_path = generate_model_script(template_params)
 
-        # Metric Definitions for Regression
+        # Metric Definitions for Regression (matches model script output format)
         if self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]:
             metric_definitions = [
-                {"Name": "RMSE", "Regex": "RMSE: ([0-9.]+)"},
-                {"Name": "MAE", "Regex": "MAE: ([0-9.]+)"},
-                {"Name": "R2", "Regex": "R2: ([0-9.]+)"},
-                {"Name": "NumRows", "Regex": "NumRows: ([0-9]+)"},
+                {"Name": "rmse", "Regex": r"rmse: ([0-9.]+)"},
+                {"Name": "mae", "Regex": r"mae: ([0-9.]+)"},
+                {"Name": "medae", "Regex": r"medae: ([0-9.]+)"},
+                {"Name": "r2", "Regex": r"r2: ([0-9.-]+)"},
+                {"Name": "spearmanr", "Regex": r"spearmanr: ([0-9.-]+)"},
+                {"Name": "support", "Regex": r"support: ([0-9]+)"},
             ]
 
         # Metric Definitions for Classification
@@ -209,7 +216,7 @@ class FeaturesToModel(Transform):
             raise ValueError(msg)
 
         # Dynamically create the metric definitions
-        metrics = ["precision", "recall", "fscore"]
+        metrics = ["precision", "recall", "f1"]
        metric_definitions = []
        for t in self.class_labels:
            for m in metrics:
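A note on the renamed metric definitions: SageMaker applies each `Regex` to the training job's log output, so the patterns must match what the generated model script actually prints. A minimal sketch of that contract (the log line here is hypothetical):

```python
import re

# Hypothetical line printed by the generated model script during training
log_line = "rmse: 2.187"

# Matching metric definition from the diff above; note that r2/spearmanr use
# the [0-9.-] class so negative values are captured
metric = {"Name": "rmse", "Regex": r"rmse: ([0-9.]+)"}

match = re.search(metric["Regex"], log_line)
assert match is not None
print(metric["Name"], float(match.group(1)))  # rmse 2.187
```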
@@ -232,13 +239,21 @@ class FeaturesToModel(Transform):
         source_dir = str(Path(script_path).parent)
 
         # Create a Sagemaker Model with our script
-        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image, "0.1")
+        image = ModelImages.get_image_uri(self.sm_session.boto_region_name, self.training_image)
+
+        # Use GPU instance for ChemProp/PyTorch, CPU for others
+        if self.model_framework in [ModelFramework.CHEMPROP, ModelFramework.PYTORCH_TABULAR]:
+            train_instance_type = "ml.g6.xlarge"  # NVIDIA L4 GPU, ~$0.80/hr
+            self.log.important(f"Using GPU instance {train_instance_type} for {self.model_framework.value}")
+        else:
+            train_instance_type = "ml.m5.xlarge"
+
         self.estimator = Estimator(
             entry_point=entry_point,
             source_dir=source_dir,
             role=self.workbench_role_arn,
             instance_count=1,
-            instance_type="ml.m5.xlarge",
+            instance_type=train_instance_type,
             sagemaker_session=self.sm_session,
             image_uri=image,
             metric_definitions=metric_definitions,
@@ -263,13 +278,20 @@ class FeaturesToModel(Transform):
         self.log.important(f"Creating new model {self.output_name}...")
         self.create_and_register_model(**kwargs)
 
+        # Make a copy of the training view, to lock-in the training data used for this model
+        model_training_view_name = f"{self.output_name.replace('-', '_')}_training"
+        self.log.important(f"Creating Model Training View: {model_training_view_name}...")
+        feature_set.view("training").copy(f"{model_training_view_name}")
+
     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() on the Model"""
         self.log.info("Post-Transform: Calling onboard() on the Model...")
         time.sleep(3)  # Give AWS time to complete Model register
 
-        # Store the model feature_list and target_column in the workbench_meta
-        output_model = ModelCore(self.output_name, model_type=self.model_type)
+        # Store the model metadata information
+        output_model = ModelCore(self.output_name)
+        output_model._set_model_type(self.model_type)
+        output_model._set_model_framework(self.model_framework)
         output_model.upsert_workbench_meta({"workbench_model_features": self.model_feature_list})
         output_model.upsert_workbench_meta({"workbench_model_target": self.target_column})
 
@@ -280,11 +302,12 @@ class FeaturesToModel(Transform):
         # Call the Model onboard method
         output_model.onboard_with_args(self.model_type, self.target_column, self.model_feature_list)
 
-    def create_and_register_model(self, aws_region=None):
+    def create_and_register_model(self, aws_region=None, **kwargs):
         """Create and Register the Model
 
         Args:
             aws_region (str, optional): AWS Region to use (default None)
+            **kwargs: Additional keyword arguments to pass to the model registration
         """
 
         # Get the metadata/tags to push into AWS
@@ -299,7 +322,7 @@ class FeaturesToModel(Transform):
 
         # Register our model
         image = ModelImages.get_image_uri(
-            self.sm_session.boto_region_name, self.inference_image, "0.1", self.inference_arch
+            self.sm_session.boto_region_name, self.inference_image, architecture=self.inference_arch
         )
         self.log.important(f"Registering model {self.output_name} with Inference Image {image}...")
         model = self.estimator.create_model(role=self.workbench_role_arn)
@@ -323,12 +346,11 @@ if __name__ == "__main__":
 
     # Regression Model
     input_name = "abalone_features"
-    output_name = "test-abalone-regression"
+    output_name = "abalone-regression"
     to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.REGRESSOR)
     to_model.set_output_tags(["test"])
     to_model.transform(target_column="class_number_of_rings", description="Test Abalone Regression")
 
-    """
     # Classification Model
     input_name = "wine_features"
     output_name = "wine-classification"
@@ -338,10 +360,10 @@
 
     # Quantile Regression Model (Abalone)
     input_name = "abalone_features"
-    output_name = "abalone-quantile-reg"
+    output_name = "abalone-regression-uq"
     to_model = FeaturesToModel(input_name, output_name, ModelType.UQ_REGRESSOR)
-    to_model.set_output_tags(["abalone", "quantiles"])
-    to_model.transform(target_column="class_number_of_rings", description="Abalone Quantile Regression")
+    to_model.set_output_tags(["abalone", "uq"])
+    to_model.transform(target_column="class_number_of_rings", description="Abalone UQ Regression")
 
     # Scikit-Learn Kmeans Clustering Model
     input_name = "wine_features"
@@ -395,7 +417,7 @@ if __name__ == "__main__":
     scripts_root = Path(__file__).resolve().parents[3] / "model_scripts"
     my_script = scripts_root / "custom_models" / "chem_info" / "molecular_descriptors.py"
     input_name = "aqsol_features"
-    output_name = "smiles-to-taut-md-stereo-v0"
+    output_name = "test-smiles-to-taut-md-stereo"
     to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.TRANSFORMER, custom_script=my_script)
     to_model.set_output_tags(["smiles", "molecular descriptors"])
     to_model.transform(target_column=None, feature_list=["smiles"], description="Smiles to Molecular Descriptors")
@@ -408,13 +430,3 @@ if __name__ == "__main__":
     to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.TRANSFORMER, custom_script=my_script)
     to_model.set_output_tags(["smiles", "morgan fingerprints"])
     to_model.transform(target_column=None, feature_list=["smiles"], description="Smiles to Morgan Fingerprints")
-
-    # Tautomerization Model
-    scripts_root = Path(__file__).resolve().parents[3] / "model_scripts"
-    my_script = scripts_root / "custom_models" / "chem_info" / "tautomerize.py"
-    input_name = "aqsol_features"
-    output_name = "tautomerize-v0"
-    to_model = FeaturesToModel(input_name, output_name, model_type=ModelType.TRANSFORMER, custom_script=my_script)
-    to_model.set_output_tags(["smiles", "tautomerization"])
-    to_model.transform(target_column=None, feature_list=["smiles"], description="Tautomerize Smiles")
-    """
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py

@@ -5,6 +5,7 @@ from sagemaker import ModelPackage
 from sagemaker.serializers import CSVSerializer
 from sagemaker.deserializers import CSVDeserializer
 from sagemaker.serverless import ServerlessInferenceConfig
+from sagemaker.model_monitor import DataCaptureConfig
 
 # Local Imports
 from workbench.core.transforms.transform import Transform, TransformInput, TransformOutput
@@ -51,27 +52,38 @@ class ModelToEndpoint(Transform):
         EndpointCore.managed_delete(self.output_name)
 
         # Get the Model Package ARN for our input model
-        input_model = ModelCore(self.input_name)
-        model_package_arn = input_model.model_package_arn()
+        workbench_model = ModelCore(self.input_name)
 
         # Deploy the model
-        self._deploy_model(model_package_arn, **kwargs)
+        self._deploy_model(workbench_model, **kwargs)
 
         # Add this endpoint to the set of registered endpoints for the model
-        input_model.register_endpoint(self.output_name)
+        workbench_model.register_endpoint(self.output_name)
 
         # This ensures that the endpoint is ready for use
         time.sleep(5)  # We wait for AWS Lag
         end = EndpointCore(self.output_name)
         self.log.important(f"Endpoint {end.name} is ready for use")
 
-    def _deploy_model(self, model_package_arn: str, mem_size: int = 2048, max_concurrency: int = 5):
+    def _deploy_model(
+        self,
+        workbench_model: ModelCore,
+        mem_size: int = 2048,
+        max_concurrency: int = 5,
+        data_capture: bool = False,
+        capture_percentage: int = 100,
+    ):
         """Internal Method: Deploy the Model
 
         Args:
-            model_package_arn(str): The Model Package ARN used to deploy the Endpoint
+            workbench_model(ModelCore): The Workbench ModelCore object to deploy
+            mem_size(int): Memory size for serverless deployment
+            max_concurrency(int): Max concurrency for serverless deployment
+            data_capture(bool): Enable data capture during deployment
+            capture_percentage(int): Percentage of data to capture. Defaults to 100.
         """
         # Grab the specified Model Package
+        model_package_arn = workbench_model.model_package_arn()
         model_package = ModelPackage(
             role=self.workbench_role_arn,
             model_package_arn=model_package_arn,
@@ -95,6 +107,23 @@ class ModelToEndpoint(Transform):
             max_concurrency=max_concurrency,
         )
 
+        # Configure data capture if requested (and not serverless)
+        data_capture_config = None
+        if data_capture and not self.serverless:
+            # Set up the S3 path for data capture
+            base_endpoint_path = f"{workbench_model.endpoints_s3_path}/{self.output_name}"
+            data_capture_path = f"{base_endpoint_path}/data_capture"
+            self.log.important(f"Configuring Data Capture --> {data_capture_path}")
+            data_capture_config = DataCaptureConfig(
+                enable_capture=True,
+                sampling_percentage=capture_percentage,
+                destination_s3_uri=data_capture_path,
+            )
+        elif data_capture and self.serverless:
+            self.log.warning(
+                "Data capture is not supported for serverless endpoints. Skipping data capture configuration."
+            )
+
         # Deploy the Endpoint
         self.log.important(f"Deploying the Endpoint {self.output_name}...")
         model_package.deploy(
@@ -104,6 +133,7 @@
             endpoint_name=self.output_name,
             serializer=CSVSerializer(),
             deserializer=CSVDeserializer(),
+            data_capture_config=data_capture_config,
             tags=aws_tags,
         )
 
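Data capture is opt-in and applies only to realtime endpoints. A sketch of enabling it, assuming the existing `ModelToEndpoint(model_name, endpoint_name, serverless=...)` constructor and that `transform()` forwards its keyword arguments into `_deploy_model()`, as the `**kwargs` plumbing above suggests:

```python
from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint

# Data capture requires a realtime (non-serverless) endpoint; on a serverless
# endpoint the code above logs a warning and skips the DataCaptureConfig
to_endpoint = ModelToEndpoint("abalone-regression", "abalone-regression-end", serverless=False)
to_endpoint.transform(data_capture=True, capture_percentage=50)
```

Captured requests and responses land under `{model.endpoints_s3_path}/{endpoint_name}/data_capture`, matching the path built in `_deploy_model()`.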
workbench/core/transforms/pandas_transforms/pandas_to_features.py

@@ -327,9 +327,36 @@ class PandasToFeatures(Transform):
         self.delete_existing()
         self.output_feature_group = self.create_feature_group()
 
+    def mac_spawn_hack(self):
+        """Workaround for macOS Tahoe fork/spawn issue with SageMaker FeatureStore ingest.
+
+        See: https://github.com/aws/sagemaker-python-sdk/issues/5312
+        macOS Tahoe 26+ has issues with forked processes creating boto3 sessions.
+        This forces spawn mode on macOS to avoid the hang.
+        """
+        import platform
+
+        if platform.system() == "Darwin":  # macOS
+            self.log.warning("macOS detected, forcing 'spawn' mode for multiprocessing (Tahoe hang workaround)")
+            import multiprocessing
+
+            try:
+                import multiprocess
+
+                multiprocess.set_start_method("spawn", force=True)
+            except (RuntimeError, ImportError):
+                pass  # Already set or multiprocess not available
+            try:
+                multiprocessing.set_start_method("spawn", force=True)
+            except RuntimeError:
+                pass  # Already set
+
     def transform_impl(self):
         """Transform Implementation: Ingest the data into the Feature Group"""
 
+        # Workaround for macOS Tahoe hang issue
+        self.mac_spawn_hack()
+
         # Now we actually push the data into the Feature Group (called ingestion)
         self.log.important(f"Ingesting rows into Feature Group {self.output_name}...")
         ingest_manager = self.output_feature_group.ingest(self.output_df, max_workers=8, max_processes=4, wait=False)
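The hack flips the start method for both the stdlib `multiprocessing` module and the third-party `multiprocess` package (a dill-based fork), presumably because the SageMaker ingest path may use either when available. The stdlib half of the workaround, isolated as a runnable sketch:

```python
import multiprocessing
import platform


def force_spawn_on_macos() -> None:
    """Prefer 'spawn' on macOS so worker processes re-import from scratch
    instead of forking a parent that already holds boto3/SSL state."""
    if platform.system() != "Darwin":
        return
    try:
        # force=True overrides a start method that was already chosen
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass  # the default context is already in use; leave it alone


if __name__ == "__main__":
    force_spawn_on_macos()
    print(multiprocessing.get_start_method())  # 'spawn' on macOS
```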
workbench/core/views/training_view.py

@@ -3,14 +3,18 @@
 from typing import Union
 
 # Workbench Imports
-from workbench.api import DataSource, FeatureSet
+from workbench.api import FeatureSet
 from workbench.core.views.view import View
 from workbench.core.views.create_view import CreateView
 from workbench.core.views.view_utils import get_column_list
 
 
 class TrainingView(CreateView):
-    """TrainingView Class: A View with an additional training column that marks holdout ids
+    """TrainingView Class: A View with an additional training column (80/20 or holdout ids).
+    The TrainingView class creates a SQL view that includes all columns from the source table
+    along with an additional boolean column named "training". This view can also include
+    a SQL filter expression to filter the rows included in the view.
+
 
     Common Usage:
         ```python
@@ -19,8 +23,9 @@ class TrainingView(CreateView):
         training_view = TrainingView.create(fs)
         df = training_view.pull_dataframe()
 
-        # Create a TrainingView with a specific set of columns
-        training_view = TrainingView.create(fs, column_list=["my_col1", "my_col2"])
+        # Create a TrainingView with a specific filter expression
+        training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="age > 30")
+        df = training_view.pull_dataframe()
 
         # Query the view
         df = training_view.query(f"SELECT * FROM {training_view.table} where training = TRUE")
@@ -31,17 +36,21 @@
     def create(
         cls,
         feature_set: FeatureSet,
-        source_table: str = None,
+        *,  # Enforce keyword arguments after feature_set
         id_column: str = None,
         holdout_ids: Union[list[str], list[int], None] = None,
+        filter_expression: str = None,
+        source_table: str = None,
     ) -> Union[View, None]:
         """Factory method to create and return a TrainingView instance.
 
         Args:
             feature_set (FeatureSet): A FeatureSet object
-            source_table (str, optional): The table/view to create the view from. Defaults to None.
             id_column (str, optional): The name of the id column. Defaults to None.
             holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+            filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                Defaults to None.
+            source_table (str, optional): The table/view to create the view from. Defaults to None.
 
         Returns:
             Union[View, None]: The created View object (or None if failed to create the view)
@@ -69,28 +78,36 @@
         else:
             id_column = instance.auto_id_column
 
-        # If we don't have holdout ids, create a default training view
-        if not holdout_ids:
-            instance._default_training_view(instance.data_source, id_column)
-            return View(instance.data_source, instance.view_name, auto_create_view=False)
+        # Enclose each column name in double quotes
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+        # Build the training assignment logic
+        if holdout_ids:
+            # Format the list of holdout ids for SQL IN clause
+            if all(isinstance(id, str) for id in holdout_ids):
+                formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            else:
+                formatted_holdout_ids = ", ".join(map(str, holdout_ids))
 
-        # Format the list of holdout ids for SQL IN clause
-        if holdout_ids and all(isinstance(id, str) for id in holdout_ids):
-            formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            training_logic = f"""CASE
+                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                ELSE True
+            END AS training"""
         else:
-            formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+            # Default 80/20 split using modulo
+            training_logic = f"""CASE
+                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                ELSE False
+            END AS training"""
 
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+        # Build WHERE clause if filter_expression is provided
+        where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""
 
         # Construct the CREATE VIEW query
         create_view_query = f"""
         CREATE OR REPLACE VIEW {instance.table} AS
-        SELECT {sql_columns}, CASE
-            WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
-            ELSE True
-        END AS training
-        FROM {instance.source_table}
+        SELECT {sql_columns}, {training_logic}
+        FROM {instance.source_table}{where_clause}
         """
 
         # Execute the CREATE VIEW query
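Note that the default split is deterministic rather than random: `MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8` assigns exactly 8 of every 10 consecutive row numbers to training. The same arithmetic mirrored in plain Python:

```python
# Mirror of the view's CASE expression: row numbers n with n % 10 < 8
# become training rows, giving an exact 80/20 assignment
rows = range(1, 101)
training = [n for n in rows if n % 10 < 8]
holdout = [n for n in rows if n % 10 >= 8]
print(len(training), len(holdout))  # 80 20
```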
@@ -99,35 +116,56 @@
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)
 
-    # This is an internal method that's used to create a default training view
-    def _default_training_view(self, data_source: DataSource, id_column: str):
-        """Create a default view in Athena that assigns roughly 80% of the data to training
+    @classmethod
+    def create_with_sql(
+        cls,
+        feature_set: FeatureSet,
+        *,
+        sql_query: str,
+        id_column: str = None,
+    ) -> Union[View, None]:
+        """Factory method to create a TrainingView from a custom SQL query.
+
+        This method takes a complete SQL query and adds the default 80/20 training split.
+        Use this when you need complex queries like UNION ALL for oversampling.
 
         Args:
-            data_source (DataSource): The Workbench DataSource object
-            id_column (str): The name of the id column
+            feature_set (FeatureSet): A FeatureSet object
+            sql_query (str): Complete SELECT query (without the final semicolon)
+            id_column (str, optional): The name of the id column for training split. Defaults to None.
+
+        Returns:
+            Union[View, None]: The created View object (or None if failed)
         """
-        self.log.important(f"Creating default Training View {self.table}...")
+        # Instantiate the TrainingView
+        instance = cls("training", feature_set)
 
-        # Drop any columns generated from AWS
-        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
-        column_list = [col for col in data_source.columns if col not in aws_cols]
+        # Sanity check on the id column
+        if not id_column:
+            instance.log.important("No id column specified, using auto_id_column")
+            if not instance.auto_id_column:
+                instance.log.error("No id column specified and no auto_id_column found, aborting")
+                return None
+            id_column = instance.auto_id_column
 
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+        # Default 80/20 split using modulo
+        training_logic = f"""CASE
+            WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+            ELSE False
+        END AS training"""
 
-        # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
+        # Wrap the custom query and add training column
         create_view_query = f"""
-        CREATE OR REPLACE VIEW "{self.table}" AS
-        SELECT {sql_columns}, CASE
-            WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True  -- Assign 80% to training
-            ELSE False  -- Assign roughly 20% to validation/test
-        END AS training
-        FROM {self.base_table_name}
+        CREATE OR REPLACE VIEW {instance.table} AS
+        SELECT *, {training_logic}
+        FROM ({sql_query}) AS custom_source
         """
 
         # Execute the CREATE VIEW query
-        data_source.execute_statement(create_view_query)
+        instance.data_source.execute_statement(create_view_query)
+
+        # Return the View
+        return View(instance.data_source, instance.view_name, auto_create_view=False)
 
 
 if __name__ == "__main__":
@@ -135,7 +173,7 @@ if __name__ == "__main__":
     from workbench.api import FeatureSet
 
     # Get the FeatureSet
-    fs = FeatureSet("test_features")
+    fs = FeatureSet("abalone_features")
 
     # Delete the existing training view
     training_view = TrainingView.create(fs)
@@ -152,9 +190,42 @@
 
     # Create a TrainingView with holdout ids
     my_holdout_ids = list(range(10))
-    training_view = TrainingView.create(fs, id_column="id", holdout_ids=my_holdout_ids)
+    training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)
 
     # Pull the training data
     df = training_view.pull_dataframe()
     print(df.head())
     print(df["training"].value_counts())
+    print(f"Shape: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test the filter expression
+    training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+    df = training_view.pull_dataframe()
+    print(df.head())
+    print(f"Shape with filter: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test create_with_sql with a custom query (UNION ALL for oversampling)
+    print("\n--- Testing create_with_sql with oversampling ---")
+    base_table = fs.table
+    replicate_ids = [0, 1, 2]  # Oversample these IDs
+
+    custom_sql = f"""
+    SELECT * FROM {base_table}
+
+    UNION ALL
+
+    SELECT * FROM {base_table}
+    WHERE auto_id IN ({', '.join(map(str, replicate_ids))})
+    """
+
+    training_view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")
+    df = training_view.pull_dataframe()
+    print(f"Shape with custom SQL: {df.shape}")
+    print(df["training"].value_counts())
+
+    # Verify oversampling - check if replicated IDs appear twice
+    for rep_id in replicate_ids:
+        count = len(df[df["auto_id"] == rep_id])
+        print(f"ID {rep_id} appears {count} times")
workbench/core/views/view.py

@@ -91,11 +91,11 @@ class View:
             self.table, self.data_source.database, self.data_source.boto3_session
         )
 
-    def pull_dataframe(self, limit: int = 50000) -> Union[pd.DataFrame, None]:
+    def pull_dataframe(self, limit: int = 100000) -> Union[pd.DataFrame, None]:
         """Pull a DataFrame based on the view type
 
         Args:
-            limit (int): The maximum number of rows to pull (default: 50000)
+            limit (int): The maximum number of rows to pull (default: 100000)
 
         Returns:
             Union[pd.DataFrame, None]: The DataFrame for the view or None if it doesn't exist
@@ -196,12 +196,52 @@
 
         # The BaseView always exists
         if self.view_name == "base":
-            return True
+            return
 
         # Check the database directly
         if not self._check_database():
             self._auto_create_view()
 
+    def copy(self, dest_view_name: str) -> "View":
+        """Copy this view to a new view with a different name
+
+        Args:
+            dest_view_name (str): The destination view name (e.g. "training_v1")
+
+        Returns:
+            View: A new View object for the destination view
+        """
+        # Can't copy the base view
+        if self.view_name == "base":
+            self.log.error("Cannot copy the base view")
+            return None
+
+        # Get the view definition
+        get_view_query = f"""
+        SELECT view_definition
+        FROM information_schema.views
+        WHERE table_schema = '{self.database}'
+        AND table_name = '{self.table}'
+        """
+        df = self.data_source.query(get_view_query)
+
+        if df.empty:
+            self.log.error(f"View {self.table} not found")
+            return None
+
+        view_definition = df.iloc[0]["view_definition"]
+
+        # Create the new view with the destination name
+        dest_table = f"{self.base_table_name}___{dest_view_name.lower()}"
+        create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'
+
+        self.log.important(f"Copying view {self.table} to {dest_table}...")
+        self.data_source.execute_statement(create_view_query)
+
+        # Return a new View object for the destination
+        artifact = FeatureSet(self.artifact_name) if self.is_feature_set else DataSource(self.artifact_name)
+        return View(artifact, dest_view_name, auto_create_view=False)
+
     def _check_database(self) -> bool:
         """Internal: Check if the view exists in the database
 
@@ -324,3 +364,13 @@ if __name__ == "__main__":
     # Test supplemental data tables deletion
     view = View(fs, "test_view")
     view.delete()
+
+    # Test copying a view
+    fs = FeatureSet("test_features")
+    display_view = View(fs, "display")
+    copied_view = display_view.copy("display_copy")
+    print(copied_view)
+    print(copied_view.pull_dataframe().head())
+
+    # Clean up copied view
+    copied_view.delete()
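View.copy() is what FeaturesToModel now uses to snapshot the training view as "<model_name>_training" (dashes mapped to underscores). A sketch of pulling that locked-in snapshot back later (the model and feature set names are hypothetical):

```python
from workbench.api import FeatureSet
from workbench.core.views.view import View

# FeaturesToModel("abalone_features", "abalone-regression", ...) snapshots the
# training view under "abalone_regression_training" (dashes -> underscores)
fs = FeatureSet("abalone_features")
locked = View(fs, "abalone_regression_training", auto_create_view=False)

# Exactly the rows (and 80/20 assignment) the model was trained on
df = locked.pull_dataframe()
print(df["training"].value_counts())
```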