workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/core/views/training_view.py  CHANGED

@@ -3,14 +3,18 @@
 from typing import Union
 
 # Workbench Imports
-from workbench.api import …
+from workbench.api import FeatureSet
 from workbench.core.views.view import View
 from workbench.core.views.create_view import CreateView
 from workbench.core.views.view_utils import get_column_list
 
 
 class TrainingView(CreateView):
-    """TrainingView Class: A View with an additional training column
+    """TrainingView Class: A View with an additional training column (80/20 or holdout ids).
+
+    The TrainingView class creates a SQL view that includes all columns from the source table
+    along with an additional boolean column named "training". This view can also include
+    a SQL filter expression to filter the rows included in the view.
+
 
     Common Usage:
         ```python

@@ -19,8 +23,9 @@ class TrainingView(CreateView):
         training_view = TrainingView.create(fs)
         df = training_view.pull_dataframe()
 
-        # Create a TrainingView with a specific …
-        training_view = TrainingView.create(fs, …
+        # Create a TrainingView with a specific filter expression
+        training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="age > 30")
+        df = training_view.pull_dataframe()
 
         # Query the view
         df = training_view.query(f"SELECT * FROM {training_view.table} where training = TRUE")

@@ -31,17 +36,21 @@ class TrainingView(CreateView):
     def create(
         cls,
         feature_set: FeatureSet,
-        …
+        *,  # Enforce keyword arguments after feature_set
         id_column: str = None,
         holdout_ids: Union[list[str], list[int], None] = None,
+        filter_expression: str = None,
+        source_table: str = None,
     ) -> Union[View, None]:
         """Factory method to create and return a TrainingView instance.
 
         Args:
             feature_set (FeatureSet): A FeatureSet object
-            source_table (str, optional): The table/view to create the view from. Defaults to None.
             id_column (str, optional): The name of the id column. Defaults to None.
             holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+            filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+                Defaults to None.
+            source_table (str, optional): The table/view to create the view from. Defaults to None.
 
         Returns:
             Union[View, None]: The created View object (or None if failed to create the view)

@@ -69,28 +78,36 @@ class TrainingView(CreateView):
         else:
             id_column = instance.auto_id_column
 
-        # …
-        …
-        …
-        …
+        # Enclose each column name in double quotes
+        sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+        # Build the training assignment logic
+        if holdout_ids:
+            # Format the list of holdout ids for SQL IN clause
+            if all(isinstance(id, str) for id in holdout_ids):
+                formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+            else:
+                formatted_holdout_ids = ", ".join(map(str, holdout_ids))
 
-        …
-        …
-        …
+            training_logic = f"""CASE
+                WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+                ELSE True
+            END AS training"""
         else:
-            …
+            # Default 80/20 split using modulo
+            training_logic = f"""CASE
+                WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+                ELSE False
+            END AS training"""
 
-        # …
-        …
+        # Build WHERE clause if filter_expression is provided
+        where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""
 
         # Construct the CREATE VIEW query
         create_view_query = f"""
         CREATE OR REPLACE VIEW {instance.table} AS
-            SELECT {sql_columns},
-                …
-                ELSE True
-            END AS training
-            FROM {instance.source_table}
+            SELECT {sql_columns}, {training_logic}
+            FROM {instance.source_table}{where_clause}
         """
 
         # Execute the CREATE VIEW query

@@ -99,35 +116,56 @@ class TrainingView(CreateView):
         # Return the View
         return View(instance.data_source, instance.view_name, auto_create_view=False)
 
-    …
-    def …
-    …
+    @classmethod
+    def create_with_sql(
+        cls,
+        feature_set: FeatureSet,
+        *,
+        sql_query: str,
+        id_column: str = None,
+    ) -> Union[View, None]:
+        """Factory method to create a TrainingView from a custom SQL query.
+
+        This method takes a complete SQL query and adds the default 80/20 training split.
+        Use this when you need complex queries like UNION ALL for oversampling.
 
         Args:
-            …
-            …
+            feature_set (FeatureSet): A FeatureSet object
+            sql_query (str): Complete SELECT query (without the final semicolon)
+            id_column (str, optional): The name of the id column for training split. Defaults to None.
+
+        Returns:
+            Union[View, None]: The created View object (or None if failed)
         """
-        …
+        # Instantiate the TrainingView
+        instance = cls("training", feature_set)
 
-        # …
-        …
-        …
+        # Sanity check on the id column
+        if not id_column:
+            instance.log.important("No id column specified, using auto_id_column")
+            if not instance.auto_id_column:
+                instance.log.error("No id column specified and no auto_id_column found, aborting")
+                return None
+            id_column = instance.auto_id_column
 
-        # …
-        …
+        # Default 80/20 split using modulo
+        training_logic = f"""CASE
+            WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+            ELSE False
+        END AS training"""
 
-        # …
+        # Wrap the custom query and add training column
         create_view_query = f"""
-        CREATE OR REPLACE VIEW …
-            SELECT { …
-            …
-                ELSE False  -- Assign roughly 20% to validation/test
-            END AS training
-            FROM {self.base_table_name}
+        CREATE OR REPLACE VIEW {instance.table} AS
+            SELECT *, {training_logic}
+            FROM ({sql_query}) AS custom_source
         """
 
         # Execute the CREATE VIEW query
-        data_source.execute_statement(create_view_query)
+        instance.data_source.execute_statement(create_view_query)
+
+        # Return the View
+        return View(instance.data_source, instance.view_name, auto_create_view=False)
 
 
 if __name__ == "__main__":

@@ -135,7 +173,7 @@ if __name__ == "__main__":
     from workbench.api import FeatureSet
 
     # Get the FeatureSet
-    fs = FeatureSet("…
+    fs = FeatureSet("abalone_features")
 
     # Delete the existing training view
     training_view = TrainingView.create(fs)

@@ -152,9 +190,42 @@ if __name__ == "__main__":
 
     # Create a TrainingView with holdout ids
     my_holdout_ids = list(range(10))
-    training_view = TrainingView.create(fs, id_column="…
+    training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)
 
     # Pull the training data
     df = training_view.pull_dataframe()
     print(df.head())
     print(df["training"].value_counts())
+    print(f"Shape: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test the filter expression
+    training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+    df = training_view.pull_dataframe()
+    print(df.head())
+    print(f"Shape with filter: {df.shape}")
+    print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+    # Test create_with_sql with a custom query (UNION ALL for oversampling)
+    print("\n--- Testing create_with_sql with oversampling ---")
+    base_table = fs.table
+    replicate_ids = [0, 1, 2]  # Oversample these IDs
+
+    custom_sql = f"""
+        SELECT * FROM {base_table}
+
+        UNION ALL
+
+        SELECT * FROM {base_table}
+        WHERE auto_id IN ({', '.join(map(str, replicate_ids))})
+    """
+
+    training_view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")
+    df = training_view.pull_dataframe()
+    print(f"Shape with custom SQL: {df.shape}")
+    print(df["training"].value_counts())
+
+    # Verify oversampling - check if replicated IDs appear twice
+    for rep_id in replicate_ids:
+        count = len(df[df["auto_id"] == rep_id])
+        print(f"ID {rep_id} appears {count} times")
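Taken together, these changes make `TrainingView.create()` keyword-only and add `filter_expression` plus a new `create_with_sql()` factory. A minimal usage sketch based on the `__main__` block above; the FeatureSet name "abalone_features" and the "auto_id"/"diameter" columns come from that test code, and everything else is illustrative rather than a guaranteed recipe:

```python
from workbench.api import FeatureSet
from workbench.core.views.training_view import TrainingView

fs = FeatureSet("abalone_features")  # example FeatureSet from the __main__ block above

# Default 80/20 split (modulo over the id column)
view = TrainingView.create(fs, id_column="auto_id")

# Hold out specific ids and filter rows with a SQL WHERE expression
view = TrainingView.create(
    fs,
    id_column="auto_id",
    holdout_ids=[1, 2, 3],
    filter_expression="diameter > 0.5",  # becomes the view's WHERE clause
)

# Custom SQL (e.g. UNION ALL oversampling); the 80/20 training column is added on top
custom_sql = f"SELECT * FROM {fs.table} UNION ALL SELECT * FROM {fs.table} WHERE auto_id IN (0, 1, 2)"
view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")

df = view.pull_dataframe()
print(df["training"].value_counts())
```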
workbench/core/views/view.py  CHANGED

@@ -91,11 +91,11 @@ class View:
             self.table, self.data_source.database, self.data_source.boto3_session
         )
 
-    def pull_dataframe(self, limit: int = …
+    def pull_dataframe(self, limit: int = 100000) -> Union[pd.DataFrame, None]:
         """Pull a DataFrame based on the view type
 
         Args:
-            limit (int): The maximum number of rows to pull (default: …
+            limit (int): The maximum number of rows to pull (default: 100000)
 
         Returns:
             Union[pd.DataFrame, None]: The DataFrame for the view or None if it doesn't exist

@@ -196,12 +196,52 @@ class View:
 
         # The BaseView always exists
         if self.view_name == "base":
-            return …
+            return
 
         # Check the database directly
         if not self._check_database():
             self._auto_create_view()
 
+    def copy(self, dest_view_name: str) -> "View":
+        """Copy this view to a new view with a different name
+
+        Args:
+            dest_view_name (str): The destination view name (e.g. "training_v1")
+
+        Returns:
+            View: A new View object for the destination view
+        """
+        # Can't copy the base view
+        if self.view_name == "base":
+            self.log.error("Cannot copy the base view")
+            return None
+
+        # Get the view definition
+        get_view_query = f"""
+            SELECT view_definition
+            FROM information_schema.views
+            WHERE table_schema = '{self.database}'
+            AND table_name = '{self.table}'
+        """
+        df = self.data_source.query(get_view_query)
+
+        if df.empty:
+            self.log.error(f"View {self.table} not found")
+            return None
+
+        view_definition = df.iloc[0]["view_definition"]
+
+        # Create the new view with the destination name
+        dest_table = f"{self.base_table_name}___{dest_view_name.lower()}"
+        create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'
+
+        self.log.important(f"Copying view {self.table} to {dest_table}...")
+        self.data_source.execute_statement(create_view_query)
+
+        # Return a new View object for the destination
+        artifact = FeatureSet(self.artifact_name) if self.is_feature_set else DataSource(self.artifact_name)
+        return View(artifact, dest_view_name, auto_create_view=False)
+
     def _check_database(self) -> bool:
         """Internal: Check if the view exists in the database
 

@@ -324,3 +364,13 @@ if __name__ == "__main__":
     # Test supplemental data tables deletion
     view = View(fs, "test_view")
     view.delete()
+
+    # Test copying a view
+    fs = FeatureSet("test_features")
+    display_view = View(fs, "display")
+    copied_view = display_view.copy("display_copy")
+    print(copied_view)
+    print(copied_view.pull_dataframe().head())
+
+    # Clean up copied view
+    copied_view.delete()
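The new `View.copy()` duplicates an existing view definition under a new name, and `pull_dataframe()` now defaults to a 100000-row limit. A short sketch reusing the names from the `__main__` test code above (the FeatureSet and view names are illustrative):

```python
from workbench.api import FeatureSet
from workbench.core.views.view import View

fs = FeatureSet("test_features")    # illustrative FeatureSet name from the test block
display_view = View(fs, "display")  # an existing view on the FeatureSet

# Copies the view definition into <base_table>___display_copy via CREATE OR REPLACE VIEW;
# returns None if the source is the base view or its definition can't be found
copied_view = display_view.copy("display_copy")
print(copied_view.pull_dataframe(limit=100).head())

copied_view.delete()  # clean up the copy
```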
workbench/core/views/view_utils.py  CHANGED

@@ -296,15 +296,15 @@ if __name__ == "__main__":
     print("View Details on the FeatureSet Table...")
     print(view_details(my_data_source.table, my_data_source.database, my_data_source.boto3_session))
 
-    print("View Details on the …
-    training_view = fs.view("…
+    print("View Details on the Display View...")
+    training_view = fs.view("display")
     print(view_details(training_view.table, training_view.database, my_data_source.boto3_session))
 
     # Test get_column_list
     print(get_column_list(my_data_source))
 
-    # Test get_column_list (with …
-    training_table = fs.view("…
+    # Test get_column_list (with display view)
+    training_table = fs.view("display").table
     print(get_column_list(my_data_source, training_table))
 
     # Test list_views
model_script_utils.py  ADDED (new file, 339 lines)

@@ -0,0 +1,339 @@
"""Shared utility functions for model training scripts (templates).

These functions are used across multiple model templates (XGBoost, PyTorch, ChemProp)
to reduce code duplication and ensure consistent behavior.
"""

from io import StringIO
import json
import numpy as np
import pandas as pd
from sklearn.metrics import (
    confusion_matrix,
    mean_absolute_error,
    median_absolute_error,
    precision_recall_fscore_support,
    r2_score,
    root_mean_squared_error,
)
from scipy.stats import spearmanr


def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Check if the provided dataframe is empty and raise an exception if it is.

    Args:
        df: DataFrame to check
        df_name: Name of the DataFrame (for error message)

    Raises:
        ValueError: If the DataFrame is empty
    """
    if df.empty:
        msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
        print(msg)
        raise ValueError(msg)


def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
    """Expands a column containing a list of probabilities into separate columns.

    Handles None values for rows where predictions couldn't be made.

    Args:
        df: DataFrame containing a "pred_proba" column
        class_labels: List of class labels

    Returns:
        DataFrame with the "pred_proba" expanded into separate columns (e.g., "class1_proba")

    Raises:
        ValueError: If DataFrame does not contain a "pred_proba" column
    """
    proba_column = "pred_proba"
    if proba_column not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    proba_splits = [f"{label}_proba" for label in class_labels]
    n_classes = len(class_labels)

    # Handle None values by replacing with list of NaNs
    proba_values = []
    for val in df[proba_column]:
        if val is None:
            proba_values.append([np.nan] * n_classes)
        else:
            proba_values.append(val)

    proba_df = pd.DataFrame(proba_values, columns=proba_splits)

    # Drop any existing proba columns and reset index for concat
    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
    df = df.reset_index(drop=True)
    df = pd.concat([df, proba_df], axis=1)
    return df


def match_features_case_insensitive(df: pd.DataFrame, model_features: list[str]) -> pd.DataFrame:
    """Matches and renames DataFrame columns to match model feature names (case-insensitive).

    Prioritizes exact matches, then case-insensitive matches.

    Args:
        df: Input DataFrame
        model_features: List of feature names expected by the model

    Returns:
        DataFrame with columns renamed to match model features

    Raises:
        ValueError: If any model features cannot be matched
    """
    df_columns_lower = {col.lower(): col for col in df.columns}
    rename_dict = {}
    missing = []
    for feature in model_features:
        if feature in df.columns:
            continue  # Exact match
        elif feature.lower() in df_columns_lower:
            rename_dict[df_columns_lower[feature.lower()]] = feature
        else:
            missing.append(feature)

    if missing:
        raise ValueError(f"Features not found: {missing}")

    return df.rename(columns=rename_dict)


def convert_categorical_types(
    df: pd.DataFrame, features: list[str], category_mappings: dict[str, list[str]] | None = None
) -> tuple[pd.DataFrame, dict[str, list[str]]]:
    """Converts appropriate columns to categorical type with consistent mappings.

    In training mode (category_mappings is None or empty), detects object/string columns
    with <20 unique values and converts them to categorical.
    In inference mode (category_mappings provided), applies the stored mappings.

    Args:
        df: The DataFrame to process
        features: List of feature names to consider for conversion
        category_mappings: Existing category mappings. If None or empty, training mode.
            If populated, inference mode.

    Returns:
        Tuple of (processed DataFrame, category mappings dictionary)
    """
    if category_mappings is None:
        category_mappings = {}

    # Training mode
    if not category_mappings:
        for col in df.select_dtypes(include=["object", "string"]):
            if col in features and df[col].nunique() < 20:
                print(f"Training mode: Converting {col} to category")
                df[col] = df[col].astype("category")
                category_mappings[col] = df[col].cat.categories.tolist()

    # Inference mode
    else:
        for col, categories in category_mappings.items():
            if col in df.columns:
                print(f"Inference mode: Applying categorical mapping for {col}")
                df[col] = pd.Categorical(df[col], categories=categories)

    return df, category_mappings


def decompress_features(
    df: pd.DataFrame, features: list[str], compressed_features: list[str]
) -> tuple[pd.DataFrame, list[str]]:
    """Decompress compressed features (bitstrings or count vectors) into individual columns.

    Supports two formats (auto-detected):
    - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
    - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)

    Args:
        df: The features DataFrame
        features: Full list of feature names
        compressed_features: List of feature names to decompress

    Returns:
        Tuple of (DataFrame with decompressed features, updated feature list)
    """
    # Check for any missing values in the required features
    missing_counts = df[features].isna().sum()
    if missing_counts.any():
        missing_features = missing_counts[missing_counts > 0]
        print(
            f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
            "WARNING: You might want to remove/replace all NaN values before processing."
        )

    # Make a copy to avoid mutating the original list
    decompressed_features = features.copy()

    for feature in compressed_features:
        if (feature not in df.columns) or (feature not in decompressed_features):
            print(f"Feature '{feature}' not in the features list, skipping decompression.")
            continue

        # Remove the feature from the list to avoid duplication
        decompressed_features.remove(feature)

        # Auto-detect format and parse: comma-separated counts or bitstring
        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
        feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)

        # Create new columns with prefix from feature name
        prefix = feature[:3]
        new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)

        # Update features list and dataframe
        decompressed_features.extend(new_col_names)
        df = df.drop(columns=[feature])
        df = pd.concat([df, new_df], axis=1)

    return df, decompressed_features


def input_fn(input_data, content_type: str) -> pd.DataFrame:
    """Parse input data and return a DataFrame.

    Args:
        input_data: Raw input data (bytes or string)
        content_type: MIME type of the input data

    Returns:
        Parsed DataFrame

    Raises:
        ValueError: If input is empty or content_type is not supported
    """
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    if isinstance(input_data, bytes):
        input_data = input_data.decode("utf-8")

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(input_data))
    elif "application/json" in content_type:
        return pd.DataFrame(json.loads(input_data))
    else:
        raise ValueError(f"{content_type} not supported!")


def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    """Convert output DataFrame to requested format.

    Args:
        output_df: DataFrame to convert
        accept_type: Requested MIME type

    Returns:
        Tuple of (formatted output string, MIME type)

    Raises:
        RuntimeError: If accept_type is not supported
    """
    if "text/csv" in accept_type:
        csv_output = output_df.fillna("N/A").to_csv(index=False)
        return csv_output, "text/csv"
    elif "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    else:
        raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Compute standard regression metrics.

    Args:
        y_true: Ground truth target values
        y_pred: Predicted values

    Returns:
        Dictionary with keys: rmse, mae, medae, r2, spearmanr, support
    """
    return {
        "rmse": root_mean_squared_error(y_true, y_pred),
        "mae": mean_absolute_error(y_true, y_pred),
        "medae": median_absolute_error(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
        "spearmanr": spearmanr(y_true, y_pred).correlation,
        "support": len(y_true),
    }


def print_regression_metrics(metrics: dict[str, float]) -> None:
    """Print regression metrics in the format expected by SageMaker metric definitions.

    Args:
        metrics: Dictionary of metric name -> value
    """
    print(f"rmse: {metrics['rmse']:.3f}")
    print(f"mae: {metrics['mae']:.3f}")
    print(f"medae: {metrics['medae']:.3f}")
    print(f"r2: {metrics['r2']:.3f}")
    print(f"spearmanr: {metrics['spearmanr']:.3f}")
    print(f"support: {metrics['support']}")


def compute_classification_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str], target_col: str
) -> pd.DataFrame:
    """Compute per-class classification metrics.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels
        label_names: List of class label names
        target_col: Name of the target column (for DataFrame output)

    Returns:
        DataFrame with columns: target_col, precision, recall, f1, support
    """
    scores = precision_recall_fscore_support(y_true, y_pred, average=None, labels=label_names)
    return pd.DataFrame(
        {
            target_col: label_names,
            "precision": scores[0],
            "recall": scores[1],
            "f1": scores[2],
            "support": scores[3],
        }
    )


def print_classification_metrics(score_df: pd.DataFrame, target_col: str, label_names: list[str]) -> None:
    """Print per-class classification metrics in the format expected by SageMaker.

    Args:
        score_df: DataFrame from compute_classification_metrics
        target_col: Name of the target column
        label_names: List of class label names
    """
    metrics = ["precision", "recall", "f1", "support"]
    for t in label_names:
        for m in metrics:
            value = score_df.loc[score_df[target_col] == t, m].iloc[0]
            print(f"Metrics:{t}:{m} {value}")


def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str]) -> None:
    """Print confusion matrix in the format expected by SageMaker.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels
        label_names: List of class label names
    """
    conf_mtx = confusion_matrix(y_true, y_pred, labels=label_names)
    for i, row_name in enumerate(label_names):
        for j, col_name in enumerate(label_names):
            value = conf_mtx[i, j]
            print(f"ConfusionMatrix:{row_name}:{col_name} {value}")