workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
 import logging
 import pandas as pd
 import numpy as np
+from scipy.stats import spearmanr
 import importlib.resources
 from pathlib import Path
 import os
@@ -92,13 +93,158 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
     return script_path


-def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
-    """Create a proximity model based on the given model
+def proximity_model_local(model: "Model", include_all_columns: bool = False):
+    """Create a FeatureSpaceProximity Model for this Model
+
+    Args:
+        model (Model): The Model/FeatureSet used to create the proximity model
+        include_all_columns (bool): Include all DataFrame columns in neighbor results (default: False)
+
+    Returns:
+        FeatureSpaceProximity: The proximity model
+    """
+    from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity # noqa: F401
+    from workbench.api import Model, FeatureSet # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the FeatureSpaceProximity Model
+    return FeatureSpaceProximity(
+        full_df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+    )
+
+
+def fingerprint_prox_model_local(
+    model: "Model",
+    include_all_columns: bool = False,
+    radius: int = 2,
+    n_bits: int = 1024,
+    counts: bool = False,
+):
+    """Create a FingerprintProximity Model for this Model
+
+    Args:
+        model (Model): The Model used to create the fingerprint proximity model
+        include_all_columns (bool): Include all DataFrame columns in neighbor results (default: False)
+        radius (int): Morgan fingerprint radius (default: 2)
+        n_bits (int): Number of bits for the fingerprint (default: 1024)
+        counts (bool): Use count fingerprints instead of binary (default: False)
+
+    Returns:
+        FingerprintProximity: The fingerprint proximity model
+    """
+    from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity # noqa: F401
+    from workbench.api import Model, FeatureSet # noqa: F401 (avoid circular import)
+
+    # Get Target Column from the existing given Model
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the FingerprintProximity Model
+    return FingerprintProximity(
+        full_df,
+        id_column=id_column,
+        target=target,
+        include_all_columns=include_all_columns,
+        radius=radius,
+        n_bits=n_bits,
+    )
+
+
+def noise_model_local(model: "Model"):
+    """Create a NoiseModel for detecting noisy/problematic samples in a Model's training data.
+
+    Args:
+        model (Model): The Model used to create the noise model
+
+    Returns:
+        NoiseModel: The noise model with precomputed noise scores for all samples
+    """
+    from workbench.algorithms.models.noise_model import NoiseModel # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the NoiseModel from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the NoiseModel
+    return NoiseModel(full_df, id_column, features, target)
+
+
+def cleanlab_model_local(model: "Model"):
+    """Create a CleanlabModels instance for detecting data quality issues in a Model's training data.
+
+    Args:
+        model (Model): The Model used to create the cleanlab models
+
+    Returns:
+        CleanlabModels: Factory providing access to CleanLearning and Datalab models.
+            - clean_learning(): CleanLearning model with enhanced get_label_issues()
+            - datalab(): Datalab instance with report(), get_issues()
+    """
+    from workbench.algorithms.models.cleanlab_model import create_cleanlab_model # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+    model_type = model.model_type
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Get the full FeatureSet data
+    full_df = fs.pull_dataframe()
+
+    # Create and return the CleanLearning model
+    return create_cleanlab_model(full_df, id_column, features, target, model_type=model_type)
+
+
+def published_proximity_model(model: "Model", prox_model_name: str, include_all_columns: bool = False) -> "Model":
+    """Create a published proximity model based on the given model

     Args:
         model (Model): The model to create the proximity model from
         prox_model_name (str): The name of the proximity model to create
-        track_columns (list, optional): List of columns to track in the proximity model
+        include_all_columns (bool): Include all DataFrame columns in results (default: False)
     Returns:
         Model: The proximity model
     """
@@ -121,45 +267,23 @@ def proximity_model(model: "Model", prox_model_name: str, track_columns: list =
         description=f"Proximity Model for {model.name}",
         tags=["proximity", model.name],
         custom_script=script_path,
-        custom_args={"track_columns": track_columns},
+        custom_args={"include_all_columns": include_all_columns},
     )
     return prox_model


-def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -> "Model":
-    """Create a Uncertainty Quantification (UQ) model based on the given model
-
-    Args:
-        model (Model): The model to create the UQ model from
-        uq_model_name (str): The name of the UQ model to create
-        train_all_data (bool, optional): Whether to train the UQ model on all data (default: False)
-
-    Returns:
-        Model: The UQ model
+def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
     """
-    from workbench.api import Model, ModelType, FeatureSet # noqa: F401 (avoid circular import)
+    Extract a tarball safely, using data filter if available.

-    # Get the custom script path for the UQ model
-    script_path = get_custom_script_path("uq_models", "meta_uq.template")
-
-    # Get Feature and Target Columns from the existing given Model
-    features = model.features()
-    target = model.target()
-
-    # Create the Proximity Model from our FeatureSet
-    fs = FeatureSet(model.get_input())
-    uq_model = fs.to_model(
-        name=uq_model_name,
-        model_type=ModelType.UQ_REGRESSOR,
-        feature_list=features,
-        target_column=target,
-        description=f"UQ Model for {model.name}",
-        tags=["uq", model.name],
-        train_all_data=train_all_data,
-        custom_script=script_path,
-        custom_args={"id_column": fs.id_column, "track_columns": [target]},
-    )
-    return uq_model
+    The filter parameter was backported to Python 3.8+, 3.9+, 3.10.13+, 3.11+
+    as a security patch, but may not be present in older patch versions.
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        if hasattr(tarfile, "data_filter"):
+            tar.extractall(path=extract_path, filter="data")
+        else:
+            tar.extractall(path=extract_path)


 def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
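The safe_extract_tarfile() helper above centralizes the tarfile "data" filter check so every extraction site degrades gracefully on older Python patch releases; a self-contained sketch of the same pattern (illustrative paths only):

    import tarfile

    def extract_tarball(tar_path: str, extract_path: str) -> None:
        # Prefer the "data" extraction filter (shipped as a CPython security backport) when available
        with tarfile.open(tar_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=extract_path, filter="data")
            else:
                tar.extractall(path=extract_path)

    # extract_tarball("model.tar.gz", "/tmp/model")  # hypothetical local paths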
@@ -180,8 +304,7 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)

         # Extract tarball
-        with tarfile.open(local_tar_path, "r:gz") as tar:
-            tar.extractall(path=tmpdir, filter="data")
+        safe_extract_tarfile(local_tar_path, tmpdir)

         # Look for category mappings in base directory only
         mappings_path = os.path.join(tmpdir, "category_mappings.json")
@@ -197,6 +320,63 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     return category_mappings


+def load_hyperparameters_from_s3(model_artifact_uri: str) -> Optional[dict]:
+    """
+    Download and extract hyperparameters from a model artifact in S3.
+
+    Args:
+        model_artifact_uri (str): S3 URI of the model artifact (model.tar.gz).
+
+    Returns:
+        dict: The loaded hyperparameters or None if not found.
+    """
+    hyperparameters = None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Download model artifact
+        local_tar_path = os.path.join(tmpdir, "model.tar.gz")
+        wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
+
+        # Extract tarball
+        safe_extract_tarfile(local_tar_path, tmpdir)
+
+        # Look for hyperparameters in base directory only
+        hyperparameters_path = os.path.join(tmpdir, "hyperparameters.json")
+
+        if os.path.exists(hyperparameters_path):
+            try:
+                with open(hyperparameters_path, "r") as f:
+                    hyperparameters = json.load(f)
+                log.info(f"Loaded hyperparameters from {hyperparameters_path}")
+            except Exception as e:
+                log.warning(f"Failed to load hyperparameters from {hyperparameters_path}: {e}")
+
+    return hyperparameters
+
+
+def get_model_hyperparameters(workbench_model: Any) -> Optional[dict]:
+    """Get the hyperparameters used to train a Workbench model.
+
+    This retrieves the hyperparameters.json file from the model artifacts
+    that was saved during model training.
+
+    Args:
+        workbench_model: Workbench model object
+
+    Returns:
+        dict: The hyperparameters used during training, or None if not found
+    """
+    # Get the model artifact URI
+    model_artifact_uri = workbench_model.model_data_url()
+
+    if model_artifact_uri is None:
+        log.warning(f"No model artifact found for {workbench_model.uuid}")
+        return None
+
+    log.info(f"Loading hyperparameters from {model_artifact_uri}")
+    return load_hyperparameters_from_s3(model_artifact_uri)
+
+
 def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     """
     Evaluate uncertainty quantification model with essential metrics.
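Calling the new hyperparameter helpers is exercised in the __main__ block later in this diff; the short version (same example model name used there):

    from workbench.api import Model
    from workbench.utils.model_utils import get_model_hyperparameters  # assumed import path

    m = Model("aqsol-regression")
    hyperparams = get_model_hyperparameters(m)  # dict from hyperparameters.json, or None if the artifact lacks one
    print(hyperparams)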
@@ -217,31 +397,51 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     if "prediction" not in df.columns:
         raise ValueError("Prediction column 'prediction' not found in DataFrame.")

+    # Drop rows with NaN predictions (e.g., from models that can't handle missing features)
+    n_total = len(df)
+    df = df.dropna(subset=["prediction", target_col])
+    n_valid = len(df)
+    if n_valid < n_total:
+        log.info(f"UQ metrics: dropped {n_total - n_valid} rows with NaN predictions")
+
     # --- Coverage and Interval Width ---
     if "q_025" in df.columns and "q_975" in df.columns:
         lower_95, upper_95 = df["q_025"], df["q_975"]
+        lower_90, upper_90 = df["q_05"], df["q_95"]
+        lower_80, upper_80 = df["q_10"], df["q_90"]
+        lower_68 = df.get("q_16", df["q_10"]) # fallback to 80% interval
+        upper_68 = df.get("q_84", df["q_90"]) # fallback to 80% interval
         lower_50, upper_50 = df["q_25"], df["q_75"]
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+        lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+        upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
-    coverage_50 = np.mean((df[target_col] >= lower_50) & (df[target_col] <= upper_50))
-    avg_width_95 = np.mean(upper_95 - lower_95)
-    avg_width_50 = np.mean(upper_50 - lower_50)
+    coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
+    coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+    coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
+    median_width_95 = np.median(upper_95 - lower_95)
+    median_width_90 = np.median(upper_90 - lower_90)
+    median_width_80 = np.median(upper_80 - lower_80)
+    median_width_50 = np.median(upper_50 - lower_50)
+    median_width_68 = np.median(upper_68 - lower_68)

     # --- CRPS (measures calibration + sharpness) ---
-    if "prediction_std" in df.columns:
-        z = (df[target_col] - df["prediction"]) / df["prediction_std"]
-        crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)

     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
@@ -252,38 +452,50 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     )
     mean_is_95 = np.mean(is_95)

-    # --- Adaptive Calibration (correlation between errors and uncertainty) ---
+    # --- Interval to Error Correlation ---
     abs_residuals = np.abs(df[target_col] - df["prediction"])
-    width_95 = upper_95 - lower_95
-    adaptive_calibration = np.corrcoef(abs_residuals, width_95)[0, 1]
+    width_68 = upper_68 - lower_68
+
+    # Spearman correlation for robustness
+    interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]

     # Collect results
     results = {
+        "coverage_68": coverage_68,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
         "coverage_95": coverage_95,
-        "coverage_50": coverage_50,
-        "avg_width_95": avg_width_95,
-        "avg_width_50": avg_width_50,
-        "crps": mean_crps,
-        "interval_score_95": mean_is_95,
-        "adaptive_calibration": adaptive_calibration,
+        "median_std": median_std,
+        "median_width_50": median_width_50,
+        "median_width_68": median_width_68,
+        "median_width_80": median_width_80,
+        "median_width_90": median_width_90,
+        "median_width_95": median_width_95,
+        "interval_to_error_corr": interval_to_error_corr,
         "n_samples": len(df),
     }

     print("\n=== UQ Metrics ===")
+    print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
     print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"Coverage @ 50%: {coverage_50:.3f} (target: 0.50)")
-    print(f"Average 95% Width: {avg_width_95:.3f}")
-    print(f"Average 50% Width: {avg_width_50:.3f}")
+    print(f"Median Prediction StdDev: {median_std:.3f}")
+    print(f"Median 50% Width: {median_width_50:.3f}")
+    print(f"Median 68% Width: {median_width_68:.3f}")
+    print(f"Median 80% Width: {median_width_80:.3f}")
+    print(f"Median 90% Width: {median_width_90:.3f}")
+    print(f"Median 95% Width: {median_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
-    print(f"Adaptive Calibration: {adaptive_calibration:.3f} (higher is better, target: >0.5)")
+    print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
    print(f"Samples: {len(df)}")
    return results


 if __name__ == "__main__":
     """Exercise the Model Utilities"""
-    from workbench.api import Model, Endpoint
+    from workbench.api import Model

     # Get the instance information
     print(model_instance_info())
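The coverage and CRPS arithmetic in the uq_metrics() hunks above follows the usual Gaussian closed forms; a standalone sketch with a toy frame (not the packaged function, just the same math):

    import numpy as np
    import pandas as pd
    from scipy.stats import norm

    # Toy predictions with a prediction_std column, mirroring the std-dev branch of uq_metrics()
    df = pd.DataFrame({
        "solubility": [1.0, 2.0, 3.0],
        "prediction": [1.1, 1.8, 3.5],
        "prediction_std": [0.3, 0.4, 0.5],
    })

    # 68% interval is prediction +/- 1 std; coverage is the fraction of targets that fall inside
    lower_68 = df["prediction"] - df["prediction_std"]
    upper_68 = df["prediction"] + df["prediction_std"]
    coverage_68 = np.mean((df["solubility"] >= lower_68) & (df["solubility"] <= upper_68))

    # Closed-form CRPS for a Gaussian predictive distribution:
    #   CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)),  z = (y - mu) / sigma
    z = (df["solubility"] - df["prediction"]) / df["prediction_std"]
    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
    print(coverage_68, crps.mean())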
@@ -298,24 +510,11 @@ if __name__ == "__main__":
     # Get the custom script path
     print(get_custom_script_path("chem_info", "molecular_descriptors.py"))

-    # Test the proximity model
+    # Test loading hyperparameters
     m = Model("aqsol-regression")
+    hyperparams = get_model_hyperparameters(m)
+    print(hyperparams)
+
+    # Test the proximity model
     # prox_model = proximity_model(m, "aqsol-prox")
     # print(prox_model)#
-
-    # Test the UQ model
-    # uq_model_instance = uq_model(m, "aqsol-uq")
-    # print(uq_model_instance)
-    # uq_model_instance.to_endpoint()
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
 log = logging.getLogger("workbench")


-def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
+def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
     """
     Read and process captured data from S3.

@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non

     Returns:
         Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
+
+    Notes:
+        This method is really only for testing and debugging.
     """
+    log.important("This method is for testing and debugging only.")
+
     # List files in the specified S3 path
     files = wr.s3.list_objects(data_capture_path)
     if not files:
@@ -64,59 +69,53 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
 def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Process the captured data DataFrame to extract input and output data.
-    Continues processing even if individual files are malformed.
+    Handles cases where input or output might not be captured.
+
     Args:
         df (DataFrame): DataFrame with captured data.
     Returns:
         tuple[DataFrame, DataFrame]: Input and output DataFrames.
     """
+
+    def parse_endpoint_data(data: dict) -> pd.DataFrame:
+        """Parse endpoint data based on encoding type."""
+        encoding = data["encoding"].upper()
+
+        if encoding == "CSV":
+            return pd.read_csv(StringIO(data["data"]))
+        elif encoding == "JSON":
+            json_data = json.loads(data["data"])
+            if isinstance(json_data, dict):
+                return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+            else:
+                return pd.DataFrame(json_data)
+        else:
+            return None # Unknown encoding
+
     input_dfs = []
     output_dfs = []

-    for idx, row in df.iterrows():
+    # Use itertuples() instead of iterrows() for better performance
+    for row in df.itertuples(index=True):
         try:
-            capture_data = row["captureData"]
-
-            # Check if this capture has the required fields (all or nothing)
-            if "endpointInput" not in capture_data:
-                log.warning(f"Row {idx}: No endpointInput found in capture data.")
-                continue
-
-            if "endpointOutput" not in capture_data:
-                log.critical(
-                    f"Row {idx}: No endpointOutput found in capture data. DataCapture needs to include Output capture!"
-                )
-                continue
-
-            # Process input data
-            input_data = capture_data["endpointInput"]
-            if input_data["encoding"].upper() == "CSV":
-                input_df = pd.read_csv(StringIO(input_data["data"]))
-            elif input_data["encoding"].upper() == "JSON":
-                json_data = json.loads(input_data["data"])
-                if isinstance(json_data, dict):
-                    input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    input_df = pd.DataFrame(json_data)
-
-            # Process output data
-            output_data = capture_data["endpointOutput"]
-            if output_data["encoding"].upper() == "CSV":
-                output_df = pd.read_csv(StringIO(output_data["data"]))
-            elif output_data["encoding"].upper() == "JSON":
-                json_data = json.loads(output_data["data"])
-                if isinstance(json_data, dict):
-                    output_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    output_df = pd.DataFrame(json_data)
-
-            # If we get here, both processed successfully
-            input_dfs.append(input_df)
-            output_dfs.append(output_df)
+            capture_data = row.captureData
+
+            # Process input data if present
+            if "endpointInput" in capture_data:
+                input_df = parse_endpoint_data(capture_data["endpointInput"])
+                if input_df is not None:
+                    input_dfs.append(input_df)
+
+            # Process output data if present
+            if "endpointOutput" in capture_data:
+                output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                if output_df is not None:
+                    output_dfs.append(output_df)

         except Exception as e:
-            log.error(f"Row {idx}: Failed to process row: {e}")
+            log.debug(f"Row {row.Index}: Failed to process row: {e}")
             continue
+
     # Combine and return results
     return (
         pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
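The new parse_endpoint_data() helper above only inspects the "encoding" and "data" keys of each captured payload; a quick illustration of the two encodings it understands (made-up payloads):

    import json
    from io import StringIO
    import pandas as pd

    csv_part = {"encoding": "CSV", "data": "feature1,feature2\n1.0,2.0\n"}
    json_part = {"encoding": "JSON", "data": json.dumps({"prediction": [0.42]})}

    # CSV payloads go through read_csv; JSON dicts/lists go through pd.DataFrame
    print(pd.read_csv(StringIO(csv_part["data"])))
    print(pd.DataFrame(json.loads(json_part["data"])))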
@@ -178,23 +177,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
         return {"error": str(e)}


-"""TEMP
-# If the status is "CompletedWithViolations", we grab the lastest
-# violation file and add it to the result
-if status == "CompletedWithViolations":
-    violation_file = f"{self.monitoring_path}/
-    {last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
-    if wr.s3.does_object_exist(violation_file):
-        violations_json = read_content_from_s3(violation_file)
-        violations = parse_monitoring_results(violations_json)
-        result["violations"] = violations.get("constraint_violations", [])
-        result["violation_count"] = len(result["violations"])
-    else:
-        result["violations"] = []
-        result["violation_count"] = 0
-"""
-
-
 def preprocessing_script(feature_list: list[str]) -> str:
     """
     A preprocessing script for monitoring jobs.
@@ -245,8 +227,8 @@ if __name__ == "__main__":
     from workbench.api.monitor import Monitor

     # Test pulling data capture
-    mon = Monitor("caco2-pappab-class-0")
-    df = pull_data_capture(mon.data_capture_path)
+    mon = Monitor("abalone-regression-rt")
+    df = pull_data_capture_for_testing(mon.data_capture_path)
     print("Data Capture:")
     print(df.head())

@@ -262,4 +244,4 @@
     # Test preprocessing script
     script = preprocessing_script(["feature1", "feature2", "feature3"])
     print("\nPreprocessing Script:")
-    print(script)
+    # print(script)
@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li

     # Check for differences in common columns
     for column in common_columns:
-        if pd.api.types.is_string_dtype(df1[column]) or pd.api.types.is_string_dtype(df2[column]):
+        if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
            # String comparison with NaNs treated as equal
            differences = ~(df1[column].fillna("") == df2[column].fillna(""))
        elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
@@ -161,8 +161,8 @@
                pd.isna(df1[column]) & pd.isna(df2[column])
            )
        else:
-            # Other types (e.g., int) with NaNs treated as equal
-            differences = ~(df1[column].fillna(0) == df2[column].fillna(0))
+            # Other types (int, Int64, etc.) - compare with NaNs treated as equal
+            differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))

        # If differences exist, display them
        if differences.any():
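The rewritten fallback branch treats a pair of missing values as equal without the old fillna(0) trick, which could make a genuine 0 look identical to a missing value; a small illustration with datetime columns (one of the non-string, non-float dtypes this branch handles):

    import pandas as pd

    s1 = pd.to_datetime(pd.Series(["2024-01-01", None, "2024-03-01"]))
    s2 = pd.to_datetime(pd.Series(["2024-01-01", None, "2024-03-02"]))

    # Rows differ unless the values match or both are missing (NaT != NaT is True on its own)
    differences = (s1 != s2) & ~(pd.isna(s1) & pd.isna(s2))
    print(differences.tolist())  # [False, False, True]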