workbench 0.8.161__py3-none-any.whl → 0.8.192__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. workbench/algorithms/dataframe/proximity.py +143 -102
  2. workbench/algorithms/graph/light/proximity_graph.py +2 -1
  3. workbench/api/compound.py +1 -1
  4. workbench/api/endpoint.py +12 -0
  5. workbench/api/feature_set.py +4 -4
  6. workbench/api/meta.py +5 -2
  7. workbench/api/model.py +16 -12
  8. workbench/api/monitor.py +1 -16
  9. workbench/core/artifacts/artifact.py +11 -3
  10. workbench/core/artifacts/data_capture_core.py +355 -0
  11. workbench/core/artifacts/endpoint_core.py +168 -78
  12. workbench/core/artifacts/feature_set_core.py +72 -13
  13. workbench/core/artifacts/model_core.py +50 -15
  14. workbench/core/artifacts/monitor_core.py +33 -248
  15. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  16. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  17. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  18. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  19. workbench/core/transforms/features_to_model/features_to_model.py +9 -4
  20. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  21. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  22. workbench/core/views/training_view.py +49 -53
  23. workbench/core/views/view.py +51 -1
  24. workbench/core/views/view_utils.py +4 -4
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  26. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  27. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  28. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  29. workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
  30. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  31. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +10 -17
  32. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  33. workbench/model_scripts/custom_models/uq_models/meta_uq.template +156 -58
  34. workbench/model_scripts/custom_models/uq_models/ngboost.template +20 -14
  35. workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
  36. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  37. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +5 -13
  38. workbench/model_scripts/pytorch_model/pytorch.template +19 -20
  39. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  40. workbench/model_scripts/script_generation.py +7 -2
  41. workbench/model_scripts/uq_models/mapie.template +492 -0
  42. workbench/model_scripts/uq_models/requirements.txt +1 -0
  43. workbench/model_scripts/xgb_model/xgb_model.template +31 -40
  44. workbench/repl/workbench_shell.py +11 -6
  45. workbench/scripts/lambda_launcher.py +63 -0
  46. workbench/scripts/ml_pipeline_batch.py +137 -0
  47. workbench/scripts/ml_pipeline_sqs.py +186 -0
  48. workbench/scripts/monitor_cloud_watch.py +20 -100
  49. workbench/utils/aws_utils.py +4 -3
  50. workbench/utils/chem_utils/__init__.py +0 -0
  51. workbench/utils/chem_utils/fingerprints.py +134 -0
  52. workbench/utils/chem_utils/misc.py +194 -0
  53. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  54. workbench/utils/chem_utils/mol_standardize.py +450 -0
  55. workbench/utils/chem_utils/mol_tagging.py +348 -0
  56. workbench/utils/chem_utils/projections.py +209 -0
  57. workbench/utils/chem_utils/salts.py +256 -0
  58. workbench/utils/chem_utils/sdf.py +292 -0
  59. workbench/utils/chem_utils/toxicity.py +250 -0
  60. workbench/utils/chem_utils/vis.py +253 -0
  61. workbench/utils/cloudwatch_handler.py +1 -1
  62. workbench/utils/cloudwatch_utils.py +137 -0
  63. workbench/utils/config_manager.py +3 -7
  64. workbench/utils/endpoint_utils.py +5 -7
  65. workbench/utils/license_manager.py +2 -6
  66. workbench/utils/model_utils.py +76 -30
  67. workbench/utils/monitor_utils.py +44 -62
  68. workbench/utils/pandas_utils.py +3 -3
  69. workbench/utils/shap_utils.py +10 -2
  70. workbench/utils/workbench_logging.py +0 -3
  71. workbench/utils/workbench_sqs.py +1 -1
  72. workbench/utils/xgboost_model_utils.py +283 -145
  73. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  74. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  75. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  76. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/METADATA +4 -4
  77. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/RECORD +81 -76
  78. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/entry_points.txt +3 -0
  79. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  80. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  81. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  82. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  83. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  84. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  85. workbench/model_scripts/pytorch_model/generated_model_script.py +0 -565
  86. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  87. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  88. workbench/model_scripts/scikit_learn/generated_model_script.py +0 -307
  89. workbench/model_scripts/xgb_model/generated_model_script.py +0 -477
  90. workbench/utils/chem_utils.py +0 -1556
  91. workbench/utils/execution_environment.py +0 -211
  92. workbench/utils/fast_inference.py +0 -167
  93. workbench/utils/resource_utils.py +0 -39
  94. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/WHEEL +0 -0
  95. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/licenses/LICENSE +0 -0
  96. {workbench-0.8.161.dist-info → workbench-0.8.192.dist-info}/top_level.txt +0 -0
workbench/core/cloud_platform/aws/aws_session.py

@@ -10,7 +10,7 @@ import logging
 
 # Workbench Imports
 from workbench.utils.config_manager import ConfigManager
-from workbench.utils.execution_environment import running_on_lambda, running_on_glue
+from workbench_bridges.utils.execution_environment import running_as_service
 
 # Attempt to import IPython-related utilities
 try:
@@ -66,10 +66,10 @@ class AWSSession:
         return self._cached_boto3_session
 
     def _create_boto3_session(self):
-        """Internal: Get the AWS Boto3 Session, defaulting to the Workbench Role if possible."""
+        """Internal: Get the AWS Boto3 Session, assuming the Workbench Role if necessary."""
 
-        # Check the execution environment and determine if we need to assume the Workbench Role
-        if running_on_lambda() or running_on_glue() or self.is_workbench_role():
+        # Check if we're running as a service or already using the Workbench Role
+        if running_as_service() or self.is_workbench_role():
             self.log.important("Using the default Boto3 session...")
             return boto3.Session(region_name=self.region)
 
workbench/core/transforms/data_to_features/light/molecular_descriptors.py

@@ -1,7 +1,7 @@
 """MolecularDescriptors: Compute a Feature Set based on RDKit Descriptors
 
-Note: An alternative to using this class is to use the `compute_molecular_descriptors` function directly.
-    df_features = compute_molecular_descriptors(df)
+Note: An alternative to using this class is to use the `compute_descriptors` function directly.
+    df_features = compute_descriptors(df)
     to_features = PandasToFeatures("my_feature_set")
     to_features.set_input(df_features, id_column="id")
     to_features.set_output_tags(["blah", "whatever"])
@@ -10,7 +10,7 @@ Note: An alternative to using this class is to use the `compute_molecular_descri
 
 # Local Imports
 from workbench.core.transforms.data_to_features.light.data_to_features_light import DataToFeaturesLight
-from workbench.utils.chem_utils import compute_molecular_descriptors
+from workbench.utils.chem_utils.mol_descriptors import compute_descriptors
 
 
 class MolecularDescriptors(DataToFeaturesLight):
@@ -39,7 +39,7 @@ class MolecularDescriptors(DataToFeaturesLight):
         """Compute a Feature Set based on RDKit Descriptors"""
 
         # Compute/add all the Molecular Descriptors
-        self.output_df = compute_molecular_descriptors(self.input_df)
+        self.output_df = compute_descriptors(self.input_df)
 
 
 if __name__ == "__main__":
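For reference, a minimal sketch of calling the relocated helper directly, per the module docstring above (the two-row DataFrame and its `id`/`smiles` columns are illustrative assumptions):

```python
import pandas as pd

from workbench.utils.chem_utils.mol_descriptors import compute_descriptors

# Illustrative input; real data would carry an id column plus SMILES strings
df = pd.DataFrame({"id": [1, 2], "smiles": ["CCO", "c1ccccc1"]})

# Adds the RDKit descriptor columns to the frame (see the docstring above)
df_features = compute_descriptors(df)
print(df_features.head())
```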
workbench/core/transforms/features_to_model/features_to_model.py

@@ -37,8 +37,8 @@ class FeaturesToModel(Transform):
         model_import_str=None,
         custom_script=None,
         custom_args=None,
-        training_image="xgb_training",
-        inference_image="xgb_inference",
+        training_image="training",
+        inference_image="inference",
         inference_arch="x86_64",
     ):
         """FeaturesToModel Initialization
@@ -50,8 +50,8 @@ class FeaturesToModel(Transform):
             model_import_str (str, optional): The import string for the model (default None)
             custom_script (str, optional): Custom script to use for the model (default None)
             custom_args (dict, optional): Custom arguments to pass to custom model scripts (default None)
-            training_image (str, optional): Training image (default "xgb_training")
-            inference_image (str, optional): Inference image (default "xgb_inference")
+            training_image (str, optional): Training image (default "training")
+            inference_image (str, optional): Inference image (default "inference")
             inference_arch (str, optional): Inference architecture (default "x86_64")
         """
 
@@ -264,6 +264,11 @@
             self.log.important(f"Creating new model {self.output_name}...")
             self.create_and_register_model(**kwargs)
 
+        # Make a copy of the training view, to lock-in the training data used for this model
+        model_training_view_name = f"{self.output_name.replace('-', '_')}_training"
+        self.log.important(f"Creating Model Training View: {model_training_view_name}...")
+        feature_set.view("training").copy(f"{model_training_view_name}")
+
     def post_transform(self, **kwargs):
         """Post-Transform: Calling onboard() on the Model"""
         self.log.info("Post-Transform: Calling onboard() on the Model...")
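The new post-creation step above snapshots the feature set's `training` view under `<output_name>_training` (dashes mapped to underscores), locking in the exact rows used to fit the model. A hedged sketch of reading that snapshot back, assuming a model named `abalone-regression` built from `abalone_features` (both names are illustrative):

```python
from workbench.api import FeatureSet

fs = FeatureSet("abalone_features")

# Per the hunk above: "<output_name>_training" with '-' replaced by '_'
locked_view = fs.view("abalone_regression_training")
df_train = locked_view.pull_dataframe()
print(df_train["training"].value_counts())
```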
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py

@@ -5,6 +5,7 @@ from sagemaker import ModelPackage
 from sagemaker.serializers import CSVSerializer
 from sagemaker.deserializers import CSVDeserializer
 from sagemaker.serverless import ServerlessInferenceConfig
+from sagemaker.model_monitor import DataCaptureConfig
 
 # Local Imports
 from workbench.core.transforms.transform import Transform, TransformInput, TransformOutput
@@ -51,27 +52,38 @@ class ModelToEndpoint(Transform):
         EndpointCore.managed_delete(self.output_name)
 
         # Get the Model Package ARN for our input model
-        input_model = ModelCore(self.input_name)
-        model_package_arn = input_model.model_package_arn()
+        workbench_model = ModelCore(self.input_name)
 
         # Deploy the model
-        self._deploy_model(model_package_arn, **kwargs)
+        self._deploy_model(workbench_model, **kwargs)
 
         # Add this endpoint to the set of registered endpoints for the model
-        input_model.register_endpoint(self.output_name)
+        workbench_model.register_endpoint(self.output_name)
 
         # This ensures that the endpoint is ready for use
         time.sleep(5)  # We wait for AWS Lag
         end = EndpointCore(self.output_name)
         self.log.important(f"Endpoint {end.name} is ready for use")
 
-    def _deploy_model(self, model_package_arn: str, mem_size: int = 2048, max_concurrency: int = 5):
+    def _deploy_model(
+        self,
+        workbench_model: ModelCore,
+        mem_size: int = 2048,
+        max_concurrency: int = 5,
+        data_capture: bool = False,
+        capture_percentage: int = 100,
+    ):
         """Internal Method: Deploy the Model
 
         Args:
-            model_package_arn(str): The Model Package ARN used to deploy the Endpoint
+            workbench_model(ModelCore): The Workbench ModelCore object to deploy
+            mem_size(int): Memory size for serverless deployment
+            max_concurrency(int): Max concurrency for serverless deployment
+            data_capture(bool): Enable data capture during deployment
+            capture_percentage(int): Percentage of data to capture. Defaults to 100.
         """
         # Grab the specified Model Package
+        model_package_arn = workbench_model.model_package_arn()
         model_package = ModelPackage(
             role=self.workbench_role_arn,
             model_package_arn=model_package_arn,
@@ -95,6 +107,23 @@
                 max_concurrency=max_concurrency,
             )
 
+        # Configure data capture if requested (and not serverless)
+        data_capture_config = None
+        if data_capture and not self.serverless:
+            # Set up the S3 path for data capture
+            base_endpoint_path = f"{workbench_model.endpoints_s3_path}/{self.output_name}"
+            data_capture_path = f"{base_endpoint_path}/data_capture"
+            self.log.important(f"Configuring Data Capture --> {data_capture_path}")
+            data_capture_config = DataCaptureConfig(
+                enable_capture=True,
+                sampling_percentage=capture_percentage,
+                destination_s3_uri=data_capture_path,
+            )
+        elif data_capture and self.serverless:
+            self.log.warning(
+                "Data capture is not supported for serverless endpoints. Skipping data capture configuration."
+            )
+
         # Deploy the Endpoint
         self.log.important(f"Deploying the Endpoint {self.output_name}...")
         model_package.deploy(
@@ -104,6 +133,7 @@
             endpoint_name=self.output_name,
             serializer=CSVSerializer(),
             deserializer=CSVDeserializer(),
+            data_capture_config=data_capture_config,
             tags=aws_tags,
         )
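Because the hunks above forward `**kwargs` into `_deploy_model()`, data capture can be requested at deployment time. A hedged sketch, assuming kwargs passed to `transform()` reach `transform_impl()` and that the constructor accepts a `serverless` flag (model and endpoint names are illustrative):

```python
from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint

# Realtime (non-serverless) endpoint, since data capture is skipped for serverless
to_endpoint = ModelToEndpoint("abalone-regression", "abalone-regression-end", serverless=False)
to_endpoint.set_output_tags(["abalone", "realtime"])

# These kwargs flow through to _deploy_model() per the hunks above
to_endpoint.transform(data_capture=True, capture_percentage=50)
```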
workbench/core/transforms/pandas_transforms/pandas_to_features.py

@@ -327,9 +327,36 @@
         self.delete_existing()
         self.output_feature_group = self.create_feature_group()
 
+    def mac_spawn_hack(self):
+        """Workaround for macOS Tahoe fork/spawn issue with SageMaker FeatureStore ingest.
+
+        See: https://github.com/aws/sagemaker-python-sdk/issues/5312
+        macOS Tahoe 26+ has issues with forked processes creating boto3 sessions.
+        This forces spawn mode on macOS to avoid the hang.
+        """
+        import platform
+
+        if platform.system() == "Darwin":  # macOS
+            self.log.warning("macOS detected, forcing 'spawn' mode for multiprocessing (Tahoe hang workaround)")
+            import multiprocessing
+
+            try:
+                import multiprocess
+
+                multiprocess.set_start_method("spawn", force=True)
+            except (RuntimeError, ImportError):
+                pass  # Already set or multiprocess not available
+            try:
+                multiprocessing.set_start_method("spawn", force=True)
+            except RuntimeError:
+                pass  # Already set
+
     def transform_impl(self):
         """Transform Implementation: Ingest the data into the Feature Group"""
 
+        # Workaround for macOS Tahoe hang issue
+        self.mac_spawn_hack()
+
         # Now we actually push the data into the Feature Group (called ingestion)
         self.log.important(f"Ingesting rows into Feature Group {self.output_name}...")
         ingest_manager = self.output_feature_group.ingest(self.output_df, max_workers=8, max_processes=4, wait=False)
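The workaround runs automatically at the top of `transform_impl()`, so existing ingestion code is unchanged. A minimal usage sketch, mirroring the `PandasToFeatures` pattern shown in the MolecularDescriptors docstring above (feature set name, columns, and tags are illustrative):

```python
import pandas as pd

from workbench.core.transforms.pandas_transforms.pandas_to_features import PandasToFeatures

df = pd.DataFrame({"id": [1, 2, 3], "feat": [0.1, 0.2, 0.3]})

to_features = PandasToFeatures("my_feature_set")
to_features.set_input(df, id_column="id")
to_features.set_output_tags(["example"])
to_features.transform()  # transform_impl() now calls mac_spawn_hack() before ingest
```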
workbench/core/views/training_view.py

@@ -3,14 +3,18 @@
 from typing import Union
 
 # Workbench Imports
-from workbench.api import DataSource, FeatureSet
+from workbench.api import FeatureSet
 from workbench.core.views.view import View
 from workbench.core.views.create_view import CreateView
 from workbench.core.views.view_utils import get_column_list
 
 
 class TrainingView(CreateView):
-    """TrainingView Class: A View with an additional training column that marks holdout ids
+    """TrainingView Class: A View with an additional training column (80/20 or holdout ids).
+    The TrainingView class creates a SQL view that includes all columns from the source table
+    along with an additional boolean column named "training". This view can also include
+    a SQL filter expression to filter the rows included in the view.
+
 
     Common Usage:
         ```python
@@ -19,8 +23,9 @@ class TrainingView(CreateView):
         training_view = TrainingView.create(fs)
         df = training_view.pull_dataframe()
 
-        # Create a TrainingView with a specific set of columns
-        training_view = TrainingView.create(fs, column_list=["my_col1", "my_col2"])
+        # Create a TrainingView with a specific filter expression
+        training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="age > 30")
+        df = training_view.pull_dataframe()
 
         # Query the view
         df = training_view.query(f"SELECT * FROM {training_view.table} where training = TRUE")
@@ -31,17 +36,21 @@
     def create(
         cls,
         feature_set: FeatureSet,
-        source_table: str = None,
+        *,  # Enforce keyword arguments after feature_set
         id_column: str = None,
         holdout_ids: Union[list[str], list[int], None] = None,
+        filter_expression: str = None,
+        source_table: str = None,
     ) -> Union[View, None]:
         """Factory method to create and return a TrainingView instance.
 
         Args:
             feature_set (FeatureSet): A FeatureSet object
-            source_table (str, optional): The table/view to create the view from. Defaults to None.
             id_column (str, optional): The name of the id column. Defaults to None.
             holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+            filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                Defaults to None.
+            source_table (str, optional): The table/view to create the view from. Defaults to None.
 
         Returns:
             Union[View, None]: The created View object (or None if failed to create the view)
@@ -69,28 +78,36 @@
         else:
             id_column = instance.auto_id_column
 
-        # If we don't have holdout ids, create a default training view
-        if not holdout_ids:
-            instance._default_training_view(instance.data_source, id_column)
-            return View(instance.data_source, instance.view_name, auto_create_view=False)
+        # Enclose each column name in double quotes
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])
 
-        # Format the list of holdout ids for SQL IN clause
-        if holdout_ids and all(isinstance(id, str) for id in holdout_ids):
-            formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+        # Build the training assignment logic
+        if holdout_ids:
+            # Format the list of holdout ids for SQL IN clause
+            if all(isinstance(id, str) for id in holdout_ids):
+                formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            else:
+                formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+
+            training_logic = f"""CASE
+                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                ELSE True
+            END AS training"""
         else:
-            formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+            # Default 80/20 split using modulo
+            training_logic = f"""CASE
+                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                ELSE False
+            END AS training"""
 
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+        # Build WHERE clause if filter_expression is provided
+        where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""
 
         # Construct the CREATE VIEW query
         create_view_query = f"""
         CREATE OR REPLACE VIEW {instance.table} AS
-        SELECT {sql_columns}, CASE
-            WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
-            ELSE True
-        END AS training
-        FROM {instance.source_table}
+        SELECT {sql_columns}, {training_logic}
+        FROM {instance.source_table}{where_clause}
         """
 
         # Execute the CREATE VIEW query
@@ -99,43 +116,13 @@
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)
 
-    # This is an internal method that's used to create a default training view
-    def _default_training_view(self, data_source: DataSource, id_column: str):
-        """Create a default view in Athena that assigns roughly 80% of the data to training
-
-        Args:
-            data_source (DataSource): The Workbench DataSource object
-            id_column (str): The name of the id column
-        """
-        self.log.important(f"Creating default Training View {self.table}...")
-
-        # Drop any columns generated from AWS
-        aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
-        column_list = [col for col in data_source.columns if col not in aws_cols]
-
-        # Enclose each column name in double quotes
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
-
-        # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
-        create_view_query = f"""
-        CREATE OR REPLACE VIEW "{self.table}" AS
-        SELECT {sql_columns}, CASE
-            WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True  -- Assign 80% to training
-            ELSE False  -- Assign roughly 20% to validation/test
-        END AS training
-        FROM {self.base_table_name}
-        """
-
-        # Execute the CREATE VIEW query
-        data_source.execute_statement(create_view_query)
-
 
 if __name__ == "__main__":
     """Exercise the Training View functionality"""
     from workbench.api import FeatureSet
 
     # Get the FeatureSet
-    fs = FeatureSet("test_features")
+    fs = FeatureSet("abalone_features")
 
     # Delete the existing training view
     training_view = TrainingView.create(fs)
@@ -152,9 +139,18 @@
 
     # Create a TrainingView with holdout ids
     my_holdout_ids = list(range(10))
-    training_view = TrainingView.create(fs, id_column="id", holdout_ids=my_holdout_ids)
+    training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)
 
     # Pull the training data
     df = training_view.pull_dataframe()
     print(df.head())
     print(df["training"].value_counts())
+    print(f"Shape: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test the filter expression
+    training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+    df = training_view.pull_dataframe()
+    print(df.head())
+    print(f"Shape with filter: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
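The default split assigns a row to training when `MOD(ROW_NUMBER() OVER (ORDER BY id), 10) < 8`, i.e. 8 of every 10 consecutive row numbers, which yields the 80/20 split. A small pandas sketch simulating that CASE expression locally:

```python
import pandas as pd

# Simulate the view's CASE expression: rows numbered 1..N, MOD(n, 10) < 8 -> training
n_rows = 1000
row_number = pd.Series(range(1, n_rows + 1))
training = (row_number % 10) < 8

# 8 of every 10 consecutive rows land in training: an 80/20 split
print(training.value_counts(normalize=True))  # True 0.8, False 0.2
```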
workbench/core/views/view.py

@@ -196,12 +196,52 @@ class View:
 
         # The BaseView always exists
         if self.view_name == "base":
-            return True
+            return
 
         # Check the database directly
         if not self._check_database():
             self._auto_create_view()
 
+    def copy(self, dest_view_name: str) -> "View":
+        """Copy this view to a new view with a different name
+
+        Args:
+            dest_view_name (str): The destination view name (e.g. "training_v1")
+
+        Returns:
+            View: A new View object for the destination view
+        """
+        # Can't copy the base view
+        if self.view_name == "base":
+            self.log.error("Cannot copy the base view")
+            return None
+
+        # Get the view definition
+        get_view_query = f"""
+        SELECT view_definition
+        FROM information_schema.views
+        WHERE table_schema = '{self.database}'
+        AND table_name = '{self.table}'
+        """
+        df = self.data_source.query(get_view_query)
+
+        if df.empty:
+            self.log.error(f"View {self.table} not found")
+            return None
+
+        view_definition = df.iloc[0]["view_definition"]
+
+        # Create the new view with the destination name
+        dest_table = f"{self.base_table_name}___{dest_view_name}"
+        create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'
+
+        self.log.important(f"Copying view {self.table} to {dest_table}...")
+        self.data_source.execute_statement(create_view_query)
+
+        # Return a new View object for the destination
+        artifact = FeatureSet(self.artifact_name) if self.is_feature_set else DataSource(self.artifact_name)
+        return View(artifact, dest_view_name, auto_create_view=False)
+
     def _check_database(self) -> bool:
         """Internal: Check if the view exists in the database
 
@@ -324,3 +364,13 @@
     # Test supplemental data tables deletion
     view = View(fs, "test_view")
     view.delete()
+
+    # Test copying a view
+    fs = FeatureSet("test_features")
+    display_view = View(fs, "display")
+    copied_view = display_view.copy("display_copy")
+    print(copied_view)
+    print(copied_view.pull_dataframe().head())
+
+    # Clean up copied view
+    copied_view.delete()
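`View.copy()` reads the view's SQL definition from `information_schema.views` and replays it under the destination table name `<base_table>___<dest_view_name>`; it is the primitive `FeaturesToModel` uses above to lock in training data. A hedged sketch of snapshotting a training split by hand (feature set and view names are illustrative):

```python
from workbench.api import FeatureSet

fs = FeatureSet("abalone_features")

# Snapshot the current training split under a versioned name
snapshot = fs.view("training").copy("training_v1")
print(snapshot.pull_dataframe()["training"].value_counts())
```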
workbench/core/views/view_utils.py

@@ -296,15 +296,15 @@ if __name__ == "__main__":
     print("View Details on the FeatureSet Table...")
     print(view_details(my_data_source.table, my_data_source.database, my_data_source.boto3_session))
 
-    print("View Details on the Training View...")
-    training_view = fs.view("training")
+    print("View Details on the Display View...")
+    training_view = fs.view("display")
     print(view_details(training_view.table, training_view.database, my_data_source.boto3_session))
 
     # Test get_column_list
     print(get_column_list(my_data_source))
 
-    # Test get_column_list (with training view)
-    training_table = fs.view("training").table
+    # Test get_column_list (with display view)
+    training_table = fs.view("display").table
     print(get_column_list(my_data_source, training_table))
 
     # Test list_views