workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/core/views/training_view.py
@@ -3,14 +3,18 @@
  from typing import Union

  # Workbench Imports
- from workbench.api import DataSource, FeatureSet
+ from workbench.api import FeatureSet
  from workbench.core.views.view import View
  from workbench.core.views.create_view import CreateView
  from workbench.core.views.view_utils import get_column_list


  class TrainingView(CreateView):
- """TrainingView Class: A View with an additional training column that marks holdout ids
+ """TrainingView Class: A View with an additional training column (80/20 or holdout ids).
+ The TrainingView class creates a SQL view that includes all columns from the source table
+ along with an additional boolean column named "training". This view can also include
+ a SQL filter expression to filter the rows included in the view.
+

  Common Usage:
  ```python
@@ -19,8 +23,9 @@ class TrainingView(CreateView):
  training_view = TrainingView.create(fs)
  df = training_view.pull_dataframe()

- # Create a TrainingView with a specific set of columns
- training_view = TrainingView.create(fs, column_list=["my_col1", "my_col2"])
+ # Create a TrainingView with a specific filter expression
+ training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="age > 30")
+ df = training_view.pull_dataframe()

  # Query the view
  df = training_view.query(f"SELECT * FROM {training_view.table} where training = TRUE")
@@ -31,17 +36,21 @@ class TrainingView(CreateView):
  def create(
  cls,
  feature_set: FeatureSet,
- source_table: str = None,
+ *, # Enforce keyword arguments after feature_set
  id_column: str = None,
  holdout_ids: Union[list[str], list[int], None] = None,
+ filter_expression: str = None,
+ source_table: str = None,
  ) -> Union[View, None]:
  """Factory method to create and return a TrainingView instance.

  Args:
  feature_set (FeatureSet): A FeatureSet object
- source_table (str, optional): The table/view to create the view from. Defaults to None.
  id_column (str, optional): The name of the id column. Defaults to None.
  holdout_ids (Union[list[str], list[int], None], optional): A list of holdout ids. Defaults to None.
+ filter_expression (str, optional): SQL filter expression (e.g., "age > 25 AND status = 'active'").
+ Defaults to None.
+ source_table (str, optional): The table/view to create the view from. Defaults to None.

  Returns:
  Union[View, None]: The created View object (or None if failed to create the view)
@@ -69,28 +78,36 @@ class TrainingView(CreateView):
  else:
  id_column = instance.auto_id_column

- # If we don't have holdout ids, create a default training view
- if not holdout_ids:
- instance._default_training_view(instance.data_source, id_column)
- return View(instance.data_source, instance.view_name, auto_create_view=False)
+ # Enclose each column name in double quotes
+ sql_columns = ", ".join([f'"{column}"' for column in column_list])
+
+ # Build the training assignment logic
+ if holdout_ids:
+ # Format the list of holdout ids for SQL IN clause
+ if all(isinstance(id, str) for id in holdout_ids):
+ formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+ else:
+ formatted_holdout_ids = ", ".join(map(str, holdout_ids))

- # Format the list of holdout ids for SQL IN clause
- if holdout_ids and all(isinstance(id, str) for id in holdout_ids):
- formatted_holdout_ids = ", ".join(f"'{id}'" for id in holdout_ids)
+ training_logic = f"""CASE
+ WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
+ ELSE True
+ END AS training"""
  else:
- formatted_holdout_ids = ", ".join(map(str, holdout_ids))
+ # Default 80/20 split using modulo
+ training_logic = f"""CASE
+ WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+ ELSE False
+ END AS training"""

- # Enclose each column name in double quotes
- sql_columns = ", ".join([f'"{column}"' for column in column_list])
+ # Build WHERE clause if filter_expression is provided
+ where_clause = f"\nWHERE {filter_expression}" if filter_expression else ""

  # Construct the CREATE VIEW query
  create_view_query = f"""
  CREATE OR REPLACE VIEW {instance.table} AS
- SELECT {sql_columns}, CASE
- WHEN {id_column} IN ({formatted_holdout_ids}) THEN False
- ELSE True
- END AS training
- FROM {instance.source_table}
+ SELECT {sql_columns}, {training_logic}
+ FROM {instance.source_table}{where_clause}
  """

  # Execute the CREATE VIEW query
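
To make the assembled statement concrete, here is a minimal sketch (not part of the diff) of the Athena SQL the default 80/20 branch would emit when a filter_expression is supplied; the view, table, and column names are illustrative assumptions rather than values from the package:

```python
# Illustrative only: approximate output of TrainingView.create() for the default
# 80/20 split plus filter_expression="diameter > 0.5". The table and column names
# ("abalone_features___training", "auto_id", "diameter", "rings") are assumptions.
example_query = """
CREATE OR REPLACE VIEW abalone_features___training AS
SELECT "auto_id", "diameter", "rings", CASE
    WHEN MOD(ROW_NUMBER() OVER (ORDER BY auto_id), 10) < 8 THEN True
    ELSE False
END AS training
FROM abalone_features___base
WHERE diameter > 0.5
"""
print(example_query)
```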
@@ -99,35 +116,56 @@ class TrainingView(CreateView):
  # Return the View
  return View(instance.data_source, instance.view_name, auto_create_view=False)

- # This is an internal method that's used to create a default training view
- def _default_training_view(self, data_source: DataSource, id_column: str):
- """Create a default view in Athena that assigns roughly 80% of the data to training
+ @classmethod
+ def create_with_sql(
+ cls,
+ feature_set: FeatureSet,
+ *,
+ sql_query: str,
+ id_column: str = None,
+ ) -> Union[View, None]:
+ """Factory method to create a TrainingView from a custom SQL query.
+
+ This method takes a complete SQL query and adds the default 80/20 training split.
+ Use this when you need complex queries like UNION ALL for oversampling.

  Args:
- data_source (DataSource): The Workbench DataSource object
- id_column (str): The name of the id column
+ feature_set (FeatureSet): A FeatureSet object
+ sql_query (str): Complete SELECT query (without the final semicolon)
+ id_column (str, optional): The name of the id column for training split. Defaults to None.
+
+ Returns:
+ Union[View, None]: The created View object (or None if failed)
  """
- self.log.important(f"Creating default Training View {self.table}...")
+ # Instantiate the TrainingView
+ instance = cls("training", feature_set)

- # Drop any columns generated from AWS
- aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
- column_list = [col for col in data_source.columns if col not in aws_cols]
+ # Sanity check on the id column
+ if not id_column:
+ instance.log.important("No id column specified, using auto_id_column")
+ if not instance.auto_id_column:
+ instance.log.error("No id column specified and no auto_id_column found, aborting")
+ return None
+ id_column = instance.auto_id_column

- # Enclose each column name in double quotes
- sql_columns = ", ".join([f'"{column}"' for column in column_list])
+ # Default 80/20 split using modulo
+ training_logic = f"""CASE
+ WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True
+ ELSE False
+ END AS training"""

- # Construct the CREATE VIEW query with a simple modulo operation for the 80/20 split
+ # Wrap the custom query and add training column
  create_view_query = f"""
- CREATE OR REPLACE VIEW "{self.table}" AS
- SELECT {sql_columns}, CASE
- WHEN MOD(ROW_NUMBER() OVER (ORDER BY {id_column}), 10) < 8 THEN True -- Assign 80% to training
- ELSE False -- Assign roughly 20% to validation/test
- END AS training
- FROM {self.base_table_name}
+ CREATE OR REPLACE VIEW {instance.table} AS
+ SELECT *, {training_logic}
+ FROM ({sql_query}) AS custom_source
  """

  # Execute the CREATE VIEW query
- data_source.execute_statement(create_view_query)
+ instance.data_source.execute_statement(create_view_query)
+
+ # Return the View
+ return View(instance.data_source, instance.view_name, auto_create_view=False)


  if __name__ == "__main__":
@@ -135,7 +173,7 @@ if __name__ == "__main__":
  from workbench.api import FeatureSet

  # Get the FeatureSet
- fs = FeatureSet("test_features")
+ fs = FeatureSet("abalone_features")

  # Delete the existing training view
  training_view = TrainingView.create(fs)
@@ -152,9 +190,42 @@ if __name__ == "__main__":

  # Create a TrainingView with holdout ids
  my_holdout_ids = list(range(10))
- training_view = TrainingView.create(fs, id_column="id", holdout_ids=my_holdout_ids)
+ training_view = TrainingView.create(fs, id_column="auto_id", holdout_ids=my_holdout_ids)

  # Pull the training data
  df = training_view.pull_dataframe()
  print(df.head())
  print(df["training"].value_counts())
+ print(f"Shape: {df.shape}")
+ print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+ # Test the filter expression
+ training_view = TrainingView.create(fs, id_column="auto_id", filter_expression="diameter > 0.5")
+ df = training_view.pull_dataframe()
+ print(df.head())
+ print(f"Shape with filter: {df.shape}")
+ print(f"Diameter min: {df['diameter'].min()}, max: {df['diameter'].max()}")
+
+ # Test create_with_sql with a custom query (UNION ALL for oversampling)
+ print("\n--- Testing create_with_sql with oversampling ---")
+ base_table = fs.table
+ replicate_ids = [0, 1, 2] # Oversample these IDs
+
+ custom_sql = f"""
+ SELECT * FROM {base_table}
+
+ UNION ALL
+
+ SELECT * FROM {base_table}
+ WHERE auto_id IN ({', '.join(map(str, replicate_ids))})
+ """
+
+ training_view = TrainingView.create_with_sql(fs, sql_query=custom_sql, id_column="auto_id")
+ df = training_view.pull_dataframe()
+ print(f"Shape with custom SQL: {df.shape}")
+ print(df["training"].value_counts())
+
+ # Verify oversampling - check if replicated IDs appear twice
+ for rep_id in replicate_ids:
+ count = len(df[df["auto_id"] == rep_id])
+ print(f"ID {rep_id} appears {count} times")
workbench/core/views/view.py
@@ -91,11 +91,11 @@ class View:
  self.table, self.data_source.database, self.data_source.boto3_session
  )

- def pull_dataframe(self, limit: int = 50000) -> Union[pd.DataFrame, None]:
+ def pull_dataframe(self, limit: int = 100000) -> Union[pd.DataFrame, None]:
  """Pull a DataFrame based on the view type

  Args:
- limit (int): The maximum number of rows to pull (default: 50000)
+ limit (int): The maximum number of rows to pull (default: 100000)

  Returns:
  Union[pd.DataFrame, None]: The DataFrame for the view or None if it doesn't exist
@@ -196,12 +196,52 @@ class View:

  # The BaseView always exists
  if self.view_name == "base":
- return True
+ return

  # Check the database directly
  if not self._check_database():
  self._auto_create_view()

+ def copy(self, dest_view_name: str) -> "View":
+ """Copy this view to a new view with a different name
+
+ Args:
+ dest_view_name (str): The destination view name (e.g. "training_v1")
+
+ Returns:
+ View: A new View object for the destination view
+ """
+ # Can't copy the base view
+ if self.view_name == "base":
+ self.log.error("Cannot copy the base view")
+ return None
+
+ # Get the view definition
+ get_view_query = f"""
+ SELECT view_definition
+ FROM information_schema.views
+ WHERE table_schema = '{self.database}'
+ AND table_name = '{self.table}'
+ """
+ df = self.data_source.query(get_view_query)
+
+ if df.empty:
+ self.log.error(f"View {self.table} not found")
+ return None
+
+ view_definition = df.iloc[0]["view_definition"]
+
+ # Create the new view with the destination name
+ dest_table = f"{self.base_table_name}___{dest_view_name.lower()}"
+ create_view_query = f'CREATE OR REPLACE VIEW "{dest_table}" AS {view_definition}'
+
+ self.log.important(f"Copying view {self.table} to {dest_table}...")
+ self.data_source.execute_statement(create_view_query)
+
+ # Return a new View object for the destination
+ artifact = FeatureSet(self.artifact_name) if self.is_feature_set else DataSource(self.artifact_name)
+ return View(artifact, dest_view_name, auto_create_view=False)
+
  def _check_database(self) -> bool:
  """Internal: Check if the view exists in the database

@@ -324,3 +364,13 @@ if __name__ == "__main__":
  # Test supplemental data tables deletion
  view = View(fs, "test_view")
  view.delete()
+
+ # Test copying a view
+ fs = FeatureSet("test_features")
+ display_view = View(fs, "display")
+ copied_view = display_view.copy("display_copy")
+ print(copied_view)
+ print(copied_view.pull_dataframe().head())
+
+ # Clean up copied view
+ copied_view.delete()
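
The new View.copy() makes it easy to snapshot a view before rebuilding it. A minimal sketch (the FeatureSet name and id column are assumptions) that backs up the current training view and then re-creates it with different holdout ids:

```python
from workbench.api import FeatureSet
from workbench.core.views.view import View
from workbench.core.views.training_view import TrainingView

fs = FeatureSet("abalone_features")                 # assumed to already exist
backup = View(fs, "training").copy("training_v1")   # copied as <base_table>___training_v1
TrainingView.create(fs, id_column="auto_id", holdout_ids=[5, 6, 7])  # rebuild the live view
print(backup.pull_dataframe()["training"].value_counts())            # snapshot is unchanged
```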
workbench/core/views/view_utils.py
@@ -296,15 +296,15 @@
  print("View Details on the FeatureSet Table...")
  print(view_details(my_data_source.table, my_data_source.database, my_data_source.boto3_session))

- print("View Details on the Training View...")
- training_view = fs.view("training")
+ print("View Details on the Display View...")
+ training_view = fs.view("display")
  print(view_details(training_view.table, training_view.database, my_data_source.boto3_session))

  # Test get_column_list
  print(get_column_list(my_data_source))

- # Test get_column_list (with training view)
- training_table = fs.view("training").table
+ # Test get_column_list (with display view)
+ training_table = fs.view("display").table
  print(get_column_list(my_data_source, training_table))

  # Test list_views
workbench/model_script_utils/model_script_utils.py (new file; identical copies are also added under several model_scripts/ directories)
@@ -0,0 +1,339 @@
+ """Shared utility functions for model training scripts (templates).
+
+ These functions are used across multiple model templates (XGBoost, PyTorch, ChemProp)
+ to reduce code duplication and ensure consistent behavior.
+ """
+
+ from io import StringIO
+ import json
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import (
+ confusion_matrix,
+ mean_absolute_error,
+ median_absolute_error,
+ precision_recall_fscore_support,
+ r2_score,
+ root_mean_squared_error,
+ )
+ from scipy.stats import spearmanr
+
+
+ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
+ """Check if the provided dataframe is empty and raise an exception if it is.
+
+ Args:
+ df: DataFrame to check
+ df_name: Name of the DataFrame (for error message)
+
+ Raises:
+ ValueError: If the DataFrame is empty
+ """
+ if df.empty:
+ msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
+ print(msg)
+ raise ValueError(msg)
+
+
+ def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
+ """Expands a column containing a list of probabilities into separate columns.
+
+ Handles None values for rows where predictions couldn't be made.
+
+ Args:
+ df: DataFrame containing a "pred_proba" column
+ class_labels: List of class labels
+
+ Returns:
+ DataFrame with the "pred_proba" expanded into separate columns (e.g., "class1_proba")
+
+ Raises:
+ ValueError: If DataFrame does not contain a "pred_proba" column
+ """
+ proba_column = "pred_proba"
+ if proba_column not in df.columns:
+ raise ValueError('DataFrame does not contain a "pred_proba" column')
+
+ proba_splits = [f"{label}_proba" for label in class_labels]
+ n_classes = len(class_labels)
+
+ # Handle None values by replacing with list of NaNs
+ proba_values = []
+ for val in df[proba_column]:
+ if val is None:
+ proba_values.append([np.nan] * n_classes)
+ else:
+ proba_values.append(val)
+
+ proba_df = pd.DataFrame(proba_values, columns=proba_splits)
+
+ # Drop any existing proba columns and reset index for concat
+ df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
+ df = df.reset_index(drop=True)
+ df = pd.concat([df, proba_df], axis=1)
+ return df
+
+
+ def match_features_case_insensitive(df: pd.DataFrame, model_features: list[str]) -> pd.DataFrame:
+ """Matches and renames DataFrame columns to match model feature names (case-insensitive).
+
+ Prioritizes exact matches, then case-insensitive matches.
+
+ Args:
+ df: Input DataFrame
+ model_features: List of feature names expected by the model
+
+ Returns:
+ DataFrame with columns renamed to match model features
+
+ Raises:
+ ValueError: If any model features cannot be matched
+ """
+ df_columns_lower = {col.lower(): col for col in df.columns}
+ rename_dict = {}
+ missing = []
+ for feature in model_features:
+ if feature in df.columns:
+ continue # Exact match
+ elif feature.lower() in df_columns_lower:
+ rename_dict[df_columns_lower[feature.lower()]] = feature
+ else:
+ missing.append(feature)
+
+ if missing:
+ raise ValueError(f"Features not found: {missing}")
+
+ return df.rename(columns=rename_dict)
+
+
+ def convert_categorical_types(
+ df: pd.DataFrame, features: list[str], category_mappings: dict[str, list[str]] | None = None
+ ) -> tuple[pd.DataFrame, dict[str, list[str]]]:
+ """Converts appropriate columns to categorical type with consistent mappings.
+
+ In training mode (category_mappings is None or empty), detects object/string columns
+ with <20 unique values and converts them to categorical.
+ In inference mode (category_mappings provided), applies the stored mappings.
+
+ Args:
+ df: The DataFrame to process
+ features: List of feature names to consider for conversion
+ category_mappings: Existing category mappings. If None or empty, training mode.
+ If populated, inference mode.
+
+ Returns:
+ Tuple of (processed DataFrame, category mappings dictionary)
+ """
+ if category_mappings is None:
+ category_mappings = {}
+
+ # Training mode
+ if not category_mappings:
+ for col in df.select_dtypes(include=["object", "string"]):
+ if col in features and df[col].nunique() < 20:
+ print(f"Training mode: Converting {col} to category")
+ df[col] = df[col].astype("category")
+ category_mappings[col] = df[col].cat.categories.tolist()
+
+ # Inference mode
+ else:
+ for col, categories in category_mappings.items():
+ if col in df.columns:
+ print(f"Inference mode: Applying categorical mapping for {col}")
+ df[col] = pd.Categorical(df[col], categories=categories)
+
+ return df, category_mappings
+
+
+ def decompress_features(
+ df: pd.DataFrame, features: list[str], compressed_features: list[str]
+ ) -> tuple[pd.DataFrame, list[str]]:
+ """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+ Supports two formats (auto-detected):
+ - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+ - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
+
+ Args:
+ df: The features DataFrame
+ features: Full list of feature names
+ compressed_features: List of feature names to decompress
+
+ Returns:
+ Tuple of (DataFrame with decompressed features, updated feature list)
+ """
+ # Check for any missing values in the required features
+ missing_counts = df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(
+ f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
+ "WARNING: You might want to remove/replace all NaN values before processing."
+ )
+
+ # Make a copy to avoid mutating the original list
+ decompressed_features = features.copy()
+
+ for feature in compressed_features:
+ if (feature not in df.columns) or (feature not in decompressed_features):
+ print(f"Feature '{feature}' not in the features list, skipping decompression.")
+ continue
+
+ # Remove the feature from the list to avoid duplication
+ decompressed_features.remove(feature)
+
+ # Auto-detect format and parse: comma-separated counts or bitstring
+ sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+ parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+ feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
+
+ # Create new columns with prefix from feature name
+ prefix = feature[:3]
+ new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+ new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
+
+ # Update features list and dataframe
+ decompressed_features.extend(new_col_names)
+ df = df.drop(columns=[feature])
+ df = pd.concat([df, new_df], axis=1)
+
+ return df, decompressed_features
+
+
+ def input_fn(input_data, content_type: str) -> pd.DataFrame:
+ """Parse input data and return a DataFrame.
+
+ Args:
+ input_data: Raw input data (bytes or string)
+ content_type: MIME type of the input data
+
+ Returns:
+ Parsed DataFrame
+
+ Raises:
+ ValueError: If input is empty or content_type is not supported
+ """
+ if not input_data:
+ raise ValueError("Empty input data is not supported!")
+
+ if isinstance(input_data, bytes):
+ input_data = input_data.decode("utf-8")
+
+ if "text/csv" in content_type:
+ return pd.read_csv(StringIO(input_data))
+ elif "application/json" in content_type:
+ return pd.DataFrame(json.loads(input_data))
+ else:
+ raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
+ """Convert output DataFrame to requested format.
+
+ Args:
+ output_df: DataFrame to convert
+ accept_type: Requested MIME type
+
+ Returns:
+ Tuple of (formatted output string, MIME type)
+
+ Raises:
+ RuntimeError: If accept_type is not supported
+ """
+ if "text/csv" in accept_type:
+ csv_output = output_df.fillna("N/A").to_csv(index=False)
+ return csv_output, "text/csv"
+ elif "application/json" in accept_type:
+ return output_df.to_json(orient="records"), "application/json"
+ else:
+ raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
+ """Compute standard regression metrics.
+
+ Args:
+ y_true: Ground truth target values
+ y_pred: Predicted values
+
+ Returns:
+ Dictionary with keys: rmse, mae, medae, r2, spearmanr, support
+ """
+ return {
+ "rmse": root_mean_squared_error(y_true, y_pred),
+ "mae": mean_absolute_error(y_true, y_pred),
+ "medae": median_absolute_error(y_true, y_pred),
+ "r2": r2_score(y_true, y_pred),
+ "spearmanr": spearmanr(y_true, y_pred).correlation,
+ "support": len(y_true),
+ }
+
+
+ def print_regression_metrics(metrics: dict[str, float]) -> None:
+ """Print regression metrics in the format expected by SageMaker metric definitions.
+
+ Args:
+ metrics: Dictionary of metric name -> value
+ """
+ print(f"rmse: {metrics['rmse']:.3f}")
+ print(f"mae: {metrics['mae']:.3f}")
+ print(f"medae: {metrics['medae']:.3f}")
+ print(f"r2: {metrics['r2']:.3f}")
+ print(f"spearmanr: {metrics['spearmanr']:.3f}")
+ print(f"support: {metrics['support']}")
+
+
+ def compute_classification_metrics(
+ y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str], target_col: str
+ ) -> pd.DataFrame:
+ """Compute per-class classification metrics.
+
+ Args:
+ y_true: Ground truth labels
+ y_pred: Predicted labels
+ label_names: List of class label names
+ target_col: Name of the target column (for DataFrame output)
+
+ Returns:
+ DataFrame with columns: target_col, precision, recall, f1, support
+ """
+ scores = precision_recall_fscore_support(y_true, y_pred, average=None, labels=label_names)
+ return pd.DataFrame(
+ {
+ target_col: label_names,
+ "precision": scores[0],
+ "recall": scores[1],
+ "f1": scores[2],
+ "support": scores[3],
+ }
+ )
+
+
+ def print_classification_metrics(score_df: pd.DataFrame, target_col: str, label_names: list[str]) -> None:
+ """Print per-class classification metrics in the format expected by SageMaker.
+
+ Args:
+ score_df: DataFrame from compute_classification_metrics
+ target_col: Name of the target column
+ label_names: List of class label names
+ """
+ metrics = ["precision", "recall", "f1", "support"]
+ for t in label_names:
+ for m in metrics:
+ value = score_df.loc[score_df[target_col] == t, m].iloc[0]
+ print(f"Metrics:{t}:{m} {value}")
+
+
+ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str]) -> None:
+ """Print confusion matrix in the format expected by SageMaker.
+
+ Args:
+ y_true: Ground truth labels
+ y_pred: Predicted labels
+ label_names: List of class label names
+ """
+ conf_mtx = confusion_matrix(y_true, y_pred, labels=label_names)
+ for i, row_name in enumerate(label_names):
+ for j, col_name in enumerate(label_names):
+ value = conf_mtx[i, j]
+ print(f"ConfusionMatrix:{row_name}:{col_name} {value}")