workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +12 -0
- workbench/api/feature_set.py +4 -4
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -12
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +168 -78
- workbench/core/artifacts/feature_set_core.py +72 -13
- workbench/core/artifacts/model_core.py +50 -15
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +9 -4
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +49 -53
- workbench/core/views/view.py +51 -1
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
- workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
- workbench/model_scripts/pytorch_model/pytorch.template +19 -20
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +7 -2
- workbench/model_scripts/uq_models/mapie.template +492 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/xgb_model.template +31 -40
- workbench/repl/workbench_shell.py +11 -6
- workbench/scripts/lambda_launcher.py +63 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +76 -30
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +283 -145
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
- workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/cloud_platform/aws/aws_session.py
CHANGED
@@ -10,7 +10,7 @@ import logging

 # Workbench Imports
 from workbench.utils.config_manager import ConfigManager
-from …
+from workbench_bridges.utils.execution_environment import running_as_service

 # Attempt to import IPython-related utilities
 try:
@@ -66,10 +66,10 @@ class AWSSession:
         return self._cached_boto3_session

     def _create_boto3_session(self):
-        """Internal: Get the AWS Boto3 Session, …
+        """Internal: Get the AWS Boto3 Session, assuming the Workbench Role if necessary."""

-        # Check …
-        if …
+        # Check if we're running as a service or already using the Workbench Role
+        if running_as_service() or self.is_workbench_role():
             self.log.important("Using the default Boto3 session...")
             return boto3.Session(region_name=self.region)

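Net effect of the two aws_session.py hunks above: session creation now short-circuits to the default Boto3 credential chain whenever the code is running as a service (per workbench_bridges' running_as_service()) or is already running as the Workbench role. A minimal sketch of the new decision path; the role-assumption fallback is not shown in this diff, so it is passed in here as a hypothetical callable:

    import boto3

    def create_boto3_session(region, running_as_service, is_workbench_role, assume_workbench_role):
        # New behavior: trust ambient credentials when running as a service
        # or when the current identity is already the Workbench role
        if running_as_service() or is_workbench_role():
            return boto3.Session(region_name=region)
        # Otherwise assume the Workbench role (this fallback is not shown in the diff)
        return assume_workbench_role()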
workbench/core/transforms/data_to_features/light/molecular_descriptors.py
CHANGED
@@ -1,7 +1,7 @@
 """MolecularDescriptors: Compute a Feature Set based on RDKit Descriptors

-Note: An alternative to using this class is to use the `…
-    df_features = …
+Note: An alternative to using this class is to use the `compute_descriptors` function directly.
+    df_features = compute_descriptors(df)
     to_features = PandasToFeatures("my_feature_set")
     to_features.set_input(df_features, id_column="id")
     to_features.set_output_tags(["blah", "whatever"])
@@ -10,7 +10,7 @@ Note: An alternative to using this class is to use the `compute_molecular_descri…

 # Local Imports
 from workbench.core.transforms.data_to_features.light.data_to_features_light import DataToFeaturesLight
-from workbench.utils.chem_utils import …
+from workbench.utils.chem_utils.mol_descriptors import compute_descriptors


 class MolecularDescriptors(DataToFeaturesLight):
@@ -39,7 +39,7 @@ class MolecularDescriptors(DataToFeaturesLight):
         """Compute a Feature Set based on RDKit Descriptors"""

         # Compute/add all the Molecular Descriptors
-        self.output_df = …
+        self.output_df = compute_descriptors(self.input_df)


 if __name__ == "__main__":
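The three molecular_descriptors.py hunks above swap the old monolithic workbench.utils.chem_utils import for the new workbench.utils.chem_utils.mol_descriptors module. A short sketch of the direct-function path the updated docstring describes; the input columns ("id", "smiles") are assumptions, not confirmed by this diff:

    import pandas as pd
    from workbench.utils.chem_utils.mol_descriptors import compute_descriptors
    from workbench.core.transforms.pandas_transforms.pandas_to_features import PandasToFeatures

    # Hypothetical input frame: an id column plus a SMILES column
    df = pd.DataFrame({"id": [1, 2], "smiles": ["CCO", "c1ccccc1"]})

    df_features = compute_descriptors(df)  # adds RDKit descriptor columns

    to_features = PandasToFeatures("my_feature_set")
    to_features.set_input(df_features, id_column="id")
    to_features.set_output_tags(["blah", "whatever"])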
workbench/core/transforms/features_to_model/features_to_model.py
CHANGED
@@ -37,8 +37,8 @@ class FeaturesToModel(Transform):
         model_import_str=None,
         custom_script=None,
         custom_args=None,
-        training_image="…
-        inference_image="…
+        training_image="training",
+        inference_image="inference",
         inference_arch="x86_64",
     ):
         """FeaturesToModel Initialization
@@ -50,8 +50,8 @@ class FeaturesToModel(Transform):
             model_import_str (str, optional): The import string for the model (default None)
             custom_script (str, optional): Custom script to use for the model (default None)
             custom_args (dict, optional): Custom arguments to pass to custom model scripts (default None)
-            training_image (str, optional): Training image (default "…
-            inference_image (str, optional): Inference image (default "…
+            training_image (str, optional): Training image (default "training")
+            inference_image (str, optional): Inference image (default "inference")
             inference_arch (str, optional): Inference architecture (default "x86_64")
         """

@@ -264,6 +264,11 @@ class FeaturesToModel(Transform):
         self.log.important(f"Creating new model {self.output_name}...")
         self.create_and_register_model(**kwargs)

+        # Make a copy of the training view, to lock-in the training data used for this model
+        model_training_view_name = f"{self.output_name.replace('-', '_')}_training"
+        self.log.important(f"Creating Model Training View: {model_training_view_name}...")
+        feature_set.view("training").copy(f"{model_training_view_name}")
+
     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() on the Model"""
         self.log.info("Post-Transform: Calling onboard() on the Model...")
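The new lock-in step derives the copied view's name from the model name by replacing dashes with underscores and appending "_training". A small worked example (the model name is hypothetical):

    output_name = "abalone-regression"  # hypothetical model name
    model_training_view_name = f"{output_name.replace('-', '_')}_training"
    # -> "abalone_regression_training"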
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py
CHANGED
@@ -5,6 +5,7 @@ from sagemaker import ModelPackage
 from sagemaker.serializers import CSVSerializer
 from sagemaker.deserializers import CSVDeserializer
 from sagemaker.serverless import ServerlessInferenceConfig
+from sagemaker.model_monitor import DataCaptureConfig

 # Local Imports
 from workbench.core.transforms.transform import Transform, TransformInput, TransformOutput
@@ -51,27 +52,38 @@ class ModelToEndpoint(Transform):
         EndpointCore.managed_delete(self.output_name)

         # Get the Model Package ARN for our input model
-        …
-        model_package_arn = input_model.model_package_arn()
+        workbench_model = ModelCore(self.input_name)

         # Deploy the model
-        self._deploy_model(…
+        self._deploy_model(workbench_model, **kwargs)

         # Add this endpoint to the set of registered endpoints for the model
-        …
+        workbench_model.register_endpoint(self.output_name)

         # This ensures that the endpoint is ready for use
         time.sleep(5)  # We wait for AWS Lag
         end = EndpointCore(self.output_name)
         self.log.important(f"Endpoint {end.name} is ready for use")

-    def _deploy_model(…
+    def _deploy_model(
+        self,
+        workbench_model: ModelCore,
+        mem_size: int = 2048,
+        max_concurrency: int = 5,
+        data_capture: bool = False,
+        capture_percentage: int = 100,
+    ):
         """Internal Method: Deploy the Model

         Args:
-            …
+            workbench_model(ModelCore): The Workbench ModelCore object to deploy
+            mem_size(int): Memory size for serverless deployment
+            max_concurrency(int): Max concurrency for serverless deployment
+            data_capture(bool): Enable data capture during deployment
+            capture_percentage(int): Percentage of data to capture. Defaults to 100.
         """
         # Grab the specified Model Package
+        model_package_arn = workbench_model.model_package_arn()
         model_package = ModelPackage(
             role=self.workbench_role_arn,
             model_package_arn=model_package_arn,
@@ -95,6 +107,23 @@ class ModelToEndpoint(Transform):
             max_concurrency=max_concurrency,
         )

+        # Configure data capture if requested (and not serverless)
+        data_capture_config = None
+        if data_capture and not self.serverless:
+            # Set up the S3 path for data capture
+            base_endpoint_path = f"{workbench_model.endpoints_s3_path}/{self.output_name}"
+            data_capture_path = f"{base_endpoint_path}/data_capture"
+            self.log.important(f"Configuring Data Capture --> {data_capture_path}")
+            data_capture_config = DataCaptureConfig(
+                enable_capture=True,
+                sampling_percentage=capture_percentage,
+                destination_s3_uri=data_capture_path,
+            )
+        elif data_capture and self.serverless:
+            self.log.warning(
+                "Data capture is not supported for serverless endpoints. Skipping data capture configuration."
+            )
+
         # Deploy the Endpoint
         self.log.important(f"Deploying the Endpoint {self.output_name}...")
         model_package.deploy(
@@ -104,6 +133,7 @@ class ModelToEndpoint(Transform):
             endpoint_name=self.output_name,
             serializer=CSVSerializer(),
             deserializer=CSVDeserializer(),
+            data_capture_config=data_capture_config,
             tags=aws_tags,
         )

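The capture configuration built above is the standard SageMaker DataCaptureConfig, and it is only wired up for real-time (non-serverless) endpoints. A standalone sketch with a hypothetical destination prefix; the real path is derived from the model's endpoints_s3_path and the endpoint name as shown in the hunk:

    from sagemaker.model_monitor import DataCaptureConfig

    data_capture_config = DataCaptureConfig(
        enable_capture=True,
        sampling_percentage=100,
        destination_s3_uri="s3://my-workbench-bucket/endpoints/my-endpoint/data_capture",  # hypothetical
    )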
workbench/core/transforms/pandas_transforms/pandas_to_features.py
CHANGED
@@ -327,9 +327,36 @@ class PandasToFeatures(Transform):
         self.delete_existing()
         self.output_feature_group = self.create_feature_group()

+    def mac_spawn_hack(self):
+        """Workaround for macOS Tahoe fork/spawn issue with SageMaker FeatureStore ingest.
+
+        See: https://github.com/aws/sagemaker-python-sdk/issues/5312
+        macOS Tahoe 26+ has issues with forked processes creating boto3 sessions.
+        This forces spawn mode on macOS to avoid the hang.
+        """
+        import platform
+
+        if platform.system() == "Darwin":  # macOS
+            self.log.warning("macOS detected, forcing 'spawn' mode for multiprocessing (Tahoe hang workaround)")
+            import multiprocessing
+
+            try:
+                import multiprocess
+
+                multiprocess.set_start_method("spawn", force=True)
+            except (RuntimeError, ImportError):
+                pass  # Already set or multiprocess not available
+            try:
+                multiprocessing.set_start_method("spawn", force=True)
+            except RuntimeError:
+                pass  # Already set
+
     def transform_impl(self):
         """Transform Implementation: Ingest the data into the Feature Group"""

+        # Workaround for macOS Tahoe hang issue
+        self.mac_spawn_hack()
+
         # Now we actually push the data into the Feature Group (called ingestion)
         self.log.important(f"Ingesting rows into Feature Group {self.output_name}...")
         ingest_manager = self.output_feature_group.ingest(self.output_df, max_workers=8, max_processes=4, wait=False)
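The core of the workaround can also be applied in a caller's own driver script before ingesting; a minimal standard-library sketch:

    import multiprocessing
    import platform

    # Force 'spawn' on macOS so worker processes don't inherit forked state
    # (e.g., boto3 sessions) that can hang on newer macOS releases
    if platform.system() == "Darwin":
        try:
            multiprocessing.set_start_method("spawn", force=True)
        except RuntimeError:
            pass  # start method already set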
workbench/core/views/training_view.py
CHANGED
@@ -3,14 +3,18 @@
 from typing import Union

 # Workbench Imports
-from workbench.api import …
+from workbench.api import FeatureSet
 from workbench.core.views.view import View
 from workbench.core.views.create_view import CreateView
 from workbench.core.views.view_utils import get_column_list


 class TrainingView(CreateView):
-    """TrainingView Class: A View with an additional training column …
+    """TrainingView Class: A View with an additional training column (80/20 or holdout ids).
+    The TrainingView class creates a SQL view that includes all columns from the source table
+    along with an additional boolean column named "training". This view can also include
+    a SQL filter expression to filter the rows included in the view.
+

     Common Usage:
         ```python
@@ -19,8 +23,9 @@ class TrainingView(CreateView):
         training_view = TrainingView.create(fs)
         df = training_view.pull_dataframe()

-        # Create a TrainingView with a specific …
-        training_view = TrainingView.create(fs, …
+        # Create a TrainingView with a specific filter expression
+        training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="age > 30")
+        df = training_view.pull_dataframe()

         # Query the view
         df = training_view.query(f"SELECT * FROM {training_view.table} where training = TRUE")
@@ -31,17 +36,21 @@ class TrainingView(CreateView):
     def create(
         cls,
         feature_set: FeatureSet,
-        …
+        *,  # Enforce keyword arguments after feature_set
         id_column: str = None,
         holdout_ids: Union[list[str], list[int], None] = None,
+        filter_expression: str = None,
+        source_table: str = None,
     ) -> Union[View, None]:
         """Factory method to create and return a TrainingView instance.

         Args:
             feature_set (FeatureSet): A FeatureSet object
-            source_table (str, optional): The table/view to create the view from. Defaults to None.
             id_column (str, optional): The name of the id column. Defaults to None.
             holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+            filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                Defaults to None.
+            source_table (str, optional): The table/view to create the view from. Defaults to None.

         Returns:
             Union[View, None]: The created View object (or None if failed to create the view)
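One practical consequence of the bare `*` in the new create() signature: every argument after feature_set must now be passed by keyword. A quick sketch (the feature set name comes from the __main__ block later in this file):

    from workbench.api import FeatureSet
    from workbench.core.views.training_view import TrainingView

    fs = FeatureSet("abalone_features")

    # OK under the new signature: keyword arguments
    view = TrainingView.create(fs, id_column="auto_id", holdout_ids=[1, 2, 3])

    # TypeError under the new signature: positional arguments after feature_set
    # view = TrainingView.create(fs, "auto_id", [1, 2, 3])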
@@ -69,28 +78,36 @@ class TrainingView(CreateView):
         else:
             id_column = instance.auto_id_column

-        # …
-        …
-        instance._default_training_view(instance.data_source, id_column)
-        return View(instance.data_source, instance.view_name, auto_create_view=False)
+        # Enclose each column name in double quotes
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])

-        # …
-        if holdout_ids …
-        …
+        # Build the training assignment logic
+        if holdout_ids:
+            # Format the list of holdout ids for SQL IN clause
+            if all(isinstance(id, str) for id in holdout_ids):
+                formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            else:
+                formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+
+            training_logic = f"""CASE
+                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                ELSE True
+            END AS training"""
         else:
-            …
+            # Default 80/20 split using modulo
+            training_logic = f"""CASE
+                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                ELSE False
+            END AS training"""

-        # …
-        …
+        # Build WHERE clause if filter_expression is provided
+        where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""

         # Construct the CREATE VIEW query
         create_view_query = f"""
         CREATE OR REPLACE VIEW {instance.table} AS
-            SELECT {sql_columns},
-            …
-            ELSE True
-            END AS training
-            FROM {instance.source_table}
+            SELECT {sql_columns}, {training_logic}
+            FROM {instance.source_table}{where_clause}
         """

         # Execute the CREATE VIEW query
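To make the generated SQL concrete, here is roughly what create_view_query renders to for holdout_ids=[1, 2, 3] and filter_expression="diameter > 0.5"; the view table, source table, and column names are illustrative, not taken from this diff:

    # Illustrative rendering of the CREATE VIEW statement built above
    rendered_query = """
    CREATE OR REPLACE VIEW abalone_features___training AS
        SELECT "auto_id", "diameter", "rings", CASE
            WHEN auto_id IN (1, 2, 3) THEN False
            ELSE True
        END AS training
        FROM abalone_features
    WHERE diameter > 0.5
    """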
@@ -99,43 +116,13 @@ class TrainingView(CreateView):
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)

-    # This is an internal method that's used to create a default training view
-    def _default_training_view(self, data_source: DataSource, id_column: str):
-        """Create a default view in Athena that assigns roughly 80% of the data to training
-
-        Args:
-            data_source (DataSource): The Workbench DataSource object
-            id_column (str): The name of the id column
-        """
-        self.log.important(f"Creating default Training View {self.table}...")
-
-        # Drop any columns generated from AWS
-        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
-        column_list = [col for col in data_source.columns if col not in aws_cols]
-
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
-
-        # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
-        create_view_query = f"""
-        CREATE OR REPLACE VIEW "{self.table}" AS
-            SELECT {sql_columns}, CASE
-                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True  -- Assign 80% to training
-                ELSE False  -- Assign roughly 20% to validation/test
-            END AS training
-        FROM {self.base_table_name}
-        """
-
-        # Execute the CREATE VIEW query
-        data_source.execute_statement(create_view_query)
-

 if __name__ == "__main__":
     """Exercise the Training View functionality"""
     from workbench.api import FeatureSet

     # Get the FeatureSet
-    fs = FeatureSet("…
+    fs = FeatureSet("abalone_features")

     # Delete the existing training view
     training_view = TrainingView.create(fs)
@@ -152,9 +139,18 @@ if __name__ == "__main__":
|
|
|
152
139
|
|
|
153
140
|
# Create a TrainingView with holdout ids
|
|
154
141
|
my_holdout_ids = list(range(10))
|
|
155
|
-
training_view = TrainingView.create(fs, id_column="
|
|
142
|
+
training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)
|
|
156
143
|
|
|
157
144
|
# Pull the training data
|
|
158
145
|
df = training_view.pull_dataframe()
|
|
159
146
|
print(df.head())
|
|
160
147
|
print(df["training"].value_counts())
|
|
148
|
+
print(f"Shape: {df.shape}")
|
|
149
|
+
print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
|
|
150
|
+
|
|
151
|
+
# Test the filter expression
|
|
152
|
+
training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
|
|
153
|
+
df = training_view.pull_dataframe()
|
|
154
|
+
print(df.head())
|
|
155
|
+
print(f"Shape with filter: {df.shape}")
|
|
156
|
+
print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
|
workbench/core/views/view.py
CHANGED
@@ -196,12 +196,52 @@ class View:

         # The BaseView always exists
         if self.view_name == "base":
-            return
+            return

         # Check the database directly
         if not self._check_database():
             self._auto_create_view()

+    def copy(self, dest_view_name: str) -> "View":
+        """Copy this view to a new view with a different name
+
+        Args:
+            dest_view_name (str): The destination view name (e.g. "training_v1")
+
+        Returns:
+            View: A new View object for the destination view
+        """
+        # Can't copy the base view
+        if self.view_name == "base":
+            self.log.error("Cannot copy the base view")
+            return None
+
+        # Get the view definition
+        get_view_query = f"""
+            SELECT view_definition
+            FROM information_schema.views
+            WHERE table_schema = '{self.database}'
+            AND table_name = '{self.table}'
+        """
+        df = self.data_source.query(get_view_query)
+
+        if df.empty:
+            self.log.error(f"View {self.table} not found")
+            return None
+
+        view_definition = df.iloc[0]["view_definition"]
+
+        # Create the new view with the destination name
+        dest_table = f"{self.base_table_name}___{dest_view_name}"
+        create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'
+
+        self.log.important(f"Copying view {self.table} to {dest_table}...")
+        self.data_source.execute_statement(create_view_query)
+
+        # Return a new View object for the destination
+        artifact = FeatureSet(self.artifact_name) if self.is_feature_set else DataSource(self.artifact_name)
+        return View(artifact, dest_view_name, auto_create_view=False)
+
     def _check_database(self) -> bool:
         """Internal: Check if the view exists in the database

@@ -324,3 +364,13 @@ if __name__ == "__main__":
|
|
|
324
364
|
# Test supplemental data tables deletion
|
|
325
365
|
view = View(fs, "test_view")
|
|
326
366
|
view.delete()
|
|
367
|
+
|
|
368
|
+
# Test copying a view
|
|
369
|
+
fs = FeatureSet("test_features")
|
|
370
|
+
display_view = View(fs, "display")
|
|
371
|
+
copied_view = display_view.copy("display_copy")
|
|
372
|
+
print(copied_view)
|
|
373
|
+
print(copied_view.pull_dataframe().head())
|
|
374
|
+
|
|
375
|
+
# Clean up copied view
|
|
376
|
+
copied_view.delete()
|
|
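Based on the copy() implementation earlier in this file, the destination table name is the source's base table name joined to the new view name with a triple underscore. A small worked example (the base table name is hypothetical):

    base_table_name = "test_features_1234567890"  # hypothetical offline-store base table
    dest_view_name = "display_copy"
    dest_table = f"{base_table_name}___{dest_view_name}"
    # -> "test_features_1234567890___display_copy"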
@@ -296,15 +296,15 @@ if __name__ == "__main__":
|
|
|
296
296
|
print("View Details on the FeatureSet Table...")
|
|
297
297
|
print(view_details(my_data_source.table, my_data_source.database, my_data_source.boto3_session))
|
|
298
298
|
|
|
299
|
-
print("View Details on the
|
|
300
|
-
training_view = fs.view("
|
|
299
|
+
print("View Details on the Display View...")
|
|
300
|
+
training_view = fs.view("display")
|
|
301
301
|
print(view_details(training_view.table, training_view.database, my_data_source.boto3_session))
|
|
302
302
|
|
|
303
303
|
# Test get_column_list
|
|
304
304
|
print(get_column_list(my_data_source))
|
|
305
305
|
|
|
306
|
-
# Test get_column_list (with
|
|
307
|
-
training_table = fs.view("
|
|
306
|
+
# Test get_column_list (with display view)
|
|
307
|
+
training_table = fs.view("display").table
|
|
308
308
|
print(get_column_list(my_data_source, training_table))
|
|
309
309
|
|
|
310
310
|
# Test list_views
|