workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of workbench has been flagged as potentially problematic.
Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ import getpass
 import time  # For managing send intervals
 
 # Workbench imports
-from workbench.utils.execution_environment import (
+from workbench_bridges.utils.execution_environment import (
     running_on_lambda,
     running_on_glue,
     running_on_ecs,
@@ -0,0 +1,137 @@
+"""AWS CloudWatch utility functions for Workbench."""
+
+import time
+import logging
+from datetime import datetime, timezone
+from typing import List, Optional, Dict, Generator
+from urllib.parse import quote
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+log = logging.getLogger("workbench")
+
+
+def get_cloudwatch_client():
+    """Get the CloudWatch Logs client using the Workbench assumed role session."""
+    session = AWSAccountClamp().boto3_session
+    return session.client("logs")
+
+
+def get_cloudwatch_logs_url(log_group: str, log_stream: str) -> Optional[str]:
+    """
+    Generate CloudWatch logs URL for the specified log group and stream.
+
+    Args:
+        log_group: Log group name (e.g., '/aws/batch/job')
+        log_stream: Log stream name
+
+    Returns:
+        CloudWatch console URL or None if unable to generate
+    """
+    try:
+        region = AWSAccountClamp().region
+
+        # URL encode the log group and stream
+        encoded_group = quote(log_group, safe="")
+        encoded_stream = quote(log_stream, safe="")
+
+        return (
+            f"https://{region}.console.aws.amazon.com/cloudwatch/home?"
+            f"region={region}#logsV2:log-groups/log-group/{encoded_group}"
+            f"/log-events/{encoded_stream}"
+        )
+    except Exception as e:  # noqa: BLE001
+        log.warning(f"Failed to generate CloudWatch logs URL: {e}")
+        return None
+
+
+def get_active_log_streams(
+    log_group_name: str, start_time_ms: int, stream_filter: Optional[str] = None, client=None
+) -> List[str]:
+    """Retrieve log streams that have events after the specified start time."""
+    if not client:
+        client = get_cloudwatch_client()
+    active_streams = []
+    stream_params = {
+        "logGroupName": log_group_name,
+        "orderBy": "LastEventTime",
+        "descending": True,
+    }
+    while True:
+        response = client.describe_log_streams(**stream_params)
+        log_streams = response.get("logStreams", [])
+        for log_stream in log_streams:
+            log_stream_name = log_stream["logStreamName"]
+            last_event_timestamp = log_stream.get("lastEventTimestamp", 0)
+            if last_event_timestamp >= start_time_ms:
+                active_streams.append(log_stream_name)
+            else:
+                break
+        if "nextToken" in response:
+            stream_params["nextToken"] = response["nextToken"]
+        else:
+            break
+    # Sort and filter streams
+    active_streams.sort()
+    if stream_filter and active_streams:
+        active_streams = [stream for stream in active_streams if stream_filter in stream]
+    return active_streams
+
+
+def stream_log_events(
+    log_group_name: str,
+    log_stream_name: str,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    follow: bool = False,
+    client=None,
+) -> Generator[Dict, None, None]:
+    """
+    Stream log events from a specific log stream.
+    Yields:
+        Log events as dictionaries
+    """
+    if not client:
+        client = get_cloudwatch_client()
+    params = {"logGroupName": log_group_name, "logStreamName": log_stream_name, "startFromHead": True}
+    if start_time:
+        params["startTime"] = int(start_time.timestamp() * 1000)
+    if end_time:
+        params["endTime"] = int(end_time.timestamp() * 1000)
+    next_token = None
+    while True:
+        if next_token:
+            params["nextToken"] = next_token
+            params.pop("startTime", None)
+        try:
+            response = client.get_log_events(**params)
+            events = response.get("events", [])
+            for event in events:
+                event["logStreamName"] = log_stream_name
+                yield event
+            next_token = response.get("nextForwardToken")
+            # Break if no more events or same token
+            if not next_token or next_token == params.get("nextToken"):
+                if not follow:
+                    break
+                time.sleep(2)
+        except client.exceptions.ResourceNotFoundException:
+            if not follow:
+                break
+            time.sleep(2)
+
+
+def print_log_event(
+    event: dict, show_stream: bool = True, local_time: bool = True, custom_format: Optional[str] = None
+):
+    """Print a formatted log event."""
+    timestamp = datetime.fromtimestamp(event["timestamp"] / 1000, tz=timezone.utc)
+    if local_time:
+        timestamp = timestamp.astimezone()
+    message = event["message"].rstrip()
+    if custom_format:
+        # Allow custom formatting
+        print(custom_format.format(stream=event.get("logStreamName", ""), time=timestamp, message=message))
+    elif show_stream and "logStreamName" in event:
+        print(f"[{event['logStreamName']}] [{timestamp:%Y-%m-%d %I:%M%p}] {message}")
+    else:
+        print(f"[{timestamp:%H:%M:%S}] {message}")
@@ -4,15 +4,12 @@ import os
 import sys
 import platform
 import logging
-import importlib.resources as resources  # noqa: F401 Python 3.9 compatibility
 from typing import Any, Dict
+from importlib.resources import files, as_file
 
 # Workbench imports
 from workbench.utils.license_manager import LicenseManager
-from workbench.utils.execution_environment import running_as_service
-
-# Python 3.9 compatibility
-from workbench.utils.resource_utils import get_resource_path
+from workbench_bridges.utils.execution_environment import running_as_service
 
 
 class FatalConfigError(Exception):
@@ -172,8 +169,7 @@ class ConfigManager:
         Returns:
             str: The open source API key.
         """
-        # Python 3.9 compatibility
-        with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
+        with as_file(files("workbench.resources").joinpath("open_source_api.key")) as open_source_key_path:
            with open(open_source_key_path, "r") as key_file:
                return key_file.read().strip()
 
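This hunk (and the license_manager.py hunk further down) drops the Python 3.9-era get_resource_path() shim in favor of the standard importlib.resources API. A minimal sketch of the replacement pattern, mirroring the call in the diff above:

from importlib.resources import files, as_file

# as_file() materializes the packaged resource as a real filesystem path,
# even when the package is installed as a zip/wheel, and cleans up on exit
with as_file(files("workbench.resources").joinpath("open_source_api.key")) as key_path:
    with open(key_path, "r") as key_file:
        api_key = key_file.read().strip()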
@@ -7,9 +7,7 @@ from typing import Union, Optional
 import pandas as pd
 
 # Workbench Imports
-from workbench.api.feature_set import FeatureSet
-from workbench.api.model import Model
-from workbench.api.endpoint import Endpoint
+from workbench.api import FeatureSet, Model, Endpoint
 
 # Set up the log
 log = logging.getLogger("workbench")
@@ -77,7 +75,7 @@ def internal_model_data_url(endpoint_config_name: str, session: boto3.Session) -
     return None
 
 
-def fs_training_data(end: Endpoint) -> pd.DataFrame:
+def get_training_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the training data from the FeatureSet used to train the Model
 
     Args:
@@ -100,7 +98,7 @@ def fs_training_data(end: Endpoint) -> pd.DataFrame:
     return train_df
 
 
-def fs_evaluation_data(end: Endpoint) -> pd.DataFrame:
+def get_evaluation_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the evaluation data from the FeatureSet NOT used for training
 
     Args:
@@ -178,11 +176,11 @@ if __name__ == "__main__":
     print(model_data_url)
 
     # Get the training data
-    my_train_df = fs_training_data(my_endpoint)
+    my_train_df = get_training_data(my_endpoint)
     print(my_train_df)
 
     # Get the evaluation data
-    my_eval_df = fs_evaluation_data(my_endpoint)
+    my_eval_df = get_evaluation_data(my_endpoint)
     print(my_eval_df)
 
     # Backtrack to the FeatureSet
@@ -6,15 +6,12 @@ import json
 import logging
 import requests
 from typing import Union
-import importlib.resources as resources  # noqa: F401 Python 3.9 compatibility
 from datetime import datetime
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.backends import default_backend
-
-# Python 3.9 compatibility
-from workbench.utils.resource_utils import get_resource_path
+from importlib.resources import files, as_file
 
 
 class FatalLicenseError(Exception):
@@ -140,8 +137,7 @@ class LicenseManager:
         Returns:
             The public key as an object.
         """
-        # Python 3.9 compatibility
-        with get_resource_path("workbench.resources", "signature_verify_pub.pem") as public_key_path:
+        with as_file(files("workbench.resources").joinpath("signature_verify_pub.pem")) as public_key_path:
            with open(public_key_path, "rb") as key_file:
                public_key_data = key_file.read()
 
@@ -3,6 +3,7 @@
 import logging
 import pandas as pd
 import numpy as np
+from scipy.stats import spearmanr
 import importlib.resources
 from pathlib import Path
 import os
@@ -92,6 +93,38 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
     return script_path
 
 
+def proximity_model_local(model: "Model"):
+    """Create a Proximity Model for this Model
+
+    Args:
+        model (Model): The Model/FeatureSet used to create the proximity model
+
+    Returns:
+        Proximity: The proximity model
+    """
+    from workbench.algorithms.dataframe.proximity import Proximity  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the Proximity Model
+    return Proximity(full_df, id_column, features, target, track_columns=features)
+
+
 def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
     """Create a proximity model based on the given model
 
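A quick sketch of how the new proximity_model_local() helper might be called. The model name below is hypothetical, and the import path assumes this hunk belongs to workbench/utils/model_utils.py from the file list:

from workbench.api import Model
from workbench.utils.model_utils import proximity_model_local

model = Model("aqsol-regression")    # hypothetical existing model
prox = proximity_model_local(model)  # Proximity object built from the model's FeatureSet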
@@ -139,9 +172,6 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
     """
     from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)
 
-    # Get the custom script path for the UQ model
-    script_path = get_custom_script_path("uq_models", "meta_uq.template")
-
     # Get Feature and Target Columns from the existing given Model
     features = model.features()
     target = model.target()
@@ -156,12 +186,25 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
         description=f"UQ Model for {model.name}",
         tags=["uq", model.name],
         train_all_data=train_all_data,
-        custom_script=script_path,
         custom_args={"id_column": fs.id_column, "track_columns": [target]},
     )
     return uq_model
 
 
+def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
+    """
+    Extract a tarball safely, using data filter if available.
+
+    The filter parameter was backported to Python 3.8+, 3.9+, 3.10.13+, 3.11+
+    as a security patch, but may not be present in older patch versions.
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        if hasattr(tarfile, "data_filter"):
+            tar.extractall(path=extract_path, filter="data")
+        else:
+            tar.extractall(path=extract_path)
+
+
 def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     """
     Download and extract category mappings from a model artifact in S3.
@@ -180,8 +223,7 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
         wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)
 
         # Extract tarball
-        with tarfile.open(local_tar_path, "r:gz") as tar:
-            tar.extractall(path=tmpdir, filter="data")
+        safe_extract_tarfile(local_tar_path, tmpdir)
 
         # Look for category mappings in base directory only
         mappings_path = os.path.join(tmpdir, "category_mappings.json")
@@ -220,28 +262,41 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     # --- Coverage and Interval Width ---
     if "q_025" in df.columns and "q_975" in df.columns:
         lower_95, upper_95 = df["q_025"], df["q_975"]
+        lower_90, upper_90 = df["q_05"], df["q_95"]
+        lower_80, upper_80 = df["q_10"], df["q_90"]
+        lower_68 = df.get("q_16", df["q_10"])  # fallback to 80% interval
+        upper_68 = df.get("q_84", df["q_90"])  # fallback to 80% interval
         lower_50, upper_50 = df["q_25"], df["q_75"]
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+        lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+        upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
-    coverage_50 = np.mean((df[target_col] >= lower_50) & (df[target_col] <= upper_50))
-    avg_width_95 = np.mean(upper_95 - lower_95)
-    avg_width_50 = np.mean(upper_50 - lower_50)
+    coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
+    coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+    coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
+    median_width_95 = np.median(upper_95 - lower_95)
+    median_width_90 = np.median(upper_90 - lower_90)
+    median_width_80 = np.median(upper_80 - lower_80)
+    median_width_50 = np.median(upper_50 - lower_50)
+    median_width_68 = np.median(upper_68 - lower_68)
 
     # --- CRPS (measures calibration + sharpness) ---
-    if "prediction_std" in df.columns:
-        z = (df[target_col] - df["prediction"]) / df["prediction_std"]
-        crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)
 
     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
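The hard-coded multipliers above (1.96, 1.645, 1.282, 1.0, 0.674) are the two-sided standard-normal quantiles for the 95/90/80/68/50% central intervals (the 68% value is rounded from 0.994). A quick sketch to reproduce them:

from scipy.stats import norm

# Two-sided z multiplier for a central interval under a normal assumption
for level in (0.95, 0.90, 0.80, 0.68, 0.50):
    z = norm.ppf(0.5 + level / 2)  # e.g. level=0.95 -> z ~= 1.96
    print(f"{level:.0%} interval: prediction +/- {z:.3f} * prediction_std")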
@@ -252,31 +307,43 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     )
     mean_is_95 = np.mean(is_95)
 
-    # --- Adaptive Calibration (correlation between errors and uncertainty) ---
+    # --- Interval to Error Correlation ---
     abs_residuals = np.abs(df[target_col] - df["prediction"])
-    width_95 = upper_95 - lower_95
-    adaptive_calibration = np.corrcoef(abs_residuals, width_95)[0, 1]
+    width_68 = upper_68 - lower_68
+
+    # Spearman correlation for robustness
+    interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]
 
     # Collect results
     results = {
+        "coverage_68": coverage_68,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
         "coverage_95": coverage_95,
-        "coverage_50": coverage_50,
-        "avg_width_95": avg_width_95,
-        "avg_width_50": avg_width_50,
-        "crps": mean_crps,
-        "interval_score_95": mean_is_95,
-        "adaptive_calibration": adaptive_calibration,
+        "median_std": median_std,
+        "median_width_50": median_width_50,
+        "median_width_68": median_width_68,
+        "median_width_80": median_width_80,
+        "median_width_90": median_width_90,
+        "median_width_95": median_width_95,
+        "interval_to_error_corr": interval_to_error_corr,
         "n_samples": len(df),
     }
 
     print("\n=== UQ Metrics ===")
+    print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
     print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"Coverage @ 50%: {coverage_50:.3f} (target: 0.50)")
-    print(f"Average 95% Width: {avg_width_95:.3f}")
-    print(f"Average 50% Width: {avg_width_50:.3f}")
+    print(f"Median Prediction StdDev: {median_std:.3f}")
+    print(f"Median 50% Width: {median_width_50:.3f}")
+    print(f"Median 68% Width: {median_width_68:.3f}")
+    print(f"Median 80% Width: {median_width_80:.3f}")
+    print(f"Median 90% Width: {median_width_90:.3f}")
+    print(f"Median 95% Width: {median_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
-    print(f"Adaptive Calibration: {adaptive_calibration:.3f} (higher is better, target: >0.5)")
+    print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
     print(f"Samples: {len(df)}")
     return results
 
@@ -313,9 +380,3 @@ if __name__ == "__main__":
     df = end.auto_inference(capture=True)
     results = uq_metrics(df, target_col="solubility")
     print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
 log = logging.getLogger("workbench")
 
 
-def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
+def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
     """
     Read and process captured data from S3.
 
@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
 
     Returns:
         Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
+
+    Notes:
+        This method is really only for testing and debugging.
     """
+    log.important("This method is for testing and debugging only.")
+
     # List files in the specified S3 path
     files = wr.s3.list_objects(data_capture_path)
     if not files:
@@ -64,59 +69,53 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
 def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Process the captured data DataFrame to extract input and output data.
-    Continues processing even if individual files are malformed.
+    Handles cases where input or output might not be captured.
+
     Args:
         df (DataFrame): DataFrame with captured data.
     Returns:
         tuple[DataFrame, DataFrame]: Input and output DataFrames.
     """
+
+    def parse_endpoint_data(data: dict) -> pd.DataFrame:
+        """Parse endpoint data based on encoding type."""
+        encoding = data["encoding"].upper()
+
+        if encoding == "CSV":
+            return pd.read_csv(StringIO(data["data"]))
+        elif encoding == "JSON":
+            json_data = json.loads(data["data"])
+            if isinstance(json_data, dict):
+                return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+            else:
+                return pd.DataFrame(json_data)
+        else:
+            return None  # Unknown encoding
+
     input_dfs = []
     output_dfs = []
 
-    for idx, row in df.iterrows():
+    # Use itertuples() instead of iterrows() for better performance
+    for row in df.itertuples(index=True):
         try:
-            capture_data = row["captureData"]
-
-            # Check if this capture has the required fields (all or nothing)
-            if "endpointInput" not in capture_data:
-                log.warning(f"Row {idx}: No endpointInput found in capture data.")
-                continue
-
-            if "endpointOutput" not in capture_data:
-                log.critical(
-                    f"Row {idx}: No endpointOutput found in capture data. DataCapture needs to include Output capture!"
-                )
-                continue
-
-            # Process input data
-            input_data = capture_data["endpointInput"]
-            if input_data["encoding"].upper() == "CSV":
-                input_df = pd.read_csv(StringIO(input_data["data"]))
-            elif input_data["encoding"].upper() == "JSON":
-                json_data = json.loads(input_data["data"])
-                if isinstance(json_data, dict):
-                    input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    input_df = pd.DataFrame(json_data)
-
-            # Process output data
-            output_data = capture_data["endpointOutput"]
-            if output_data["encoding"].upper() == "CSV":
-                output_df = pd.read_csv(StringIO(output_data["data"]))
-            elif output_data["encoding"].upper() == "JSON":
-                json_data = json.loads(output_data["data"])
-                if isinstance(json_data, dict):
-                    output_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    output_df = pd.DataFrame(json_data)
-
-            # If we get here, both processed successfully
-            input_dfs.append(input_df)
-            output_dfs.append(output_df)
+            capture_data = row.captureData
+
+            # Process input data if present
+            if "endpointInput" in capture_data:
+                input_df = parse_endpoint_data(capture_data["endpointInput"])
+                if input_df is not None:
+                    input_dfs.append(input_df)
+
+            # Process output data if present
+            if "endpointOutput" in capture_data:
+                output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                if output_df is not None:
+                    output_dfs.append(output_df)
 
         except Exception as e:
-            log.error(f"Row {idx}: Failed to process row: {e}")
+            log.debug(f"Row {row.Index}: Failed to process row: {e}")
             continue
+
     # Combine and return results
     return (
         pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
@@ -178,23 +177,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
         return {"error": str(e)}
 
 
-"""TEMP
-    # If the status is "CompletedWithViolations", we grab the lastest
-    # violation file and add it to the result
-    if status == "CompletedWithViolations":
-        violation_file = f"{self.monitoring_path}/
-        {last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
-        if wr.s3.does_object_exist(violation_file):
-            violations_json = read_content_from_s3(violation_file)
-            violations = parse_monitoring_results(violations_json)
-            result["violations"] = violations.get("constraint_violations", [])
-            result["violation_count"] = len(result["violations"])
-        else:
-            result["violations"] = []
-            result["violation_count"] = 0
-"""
-
-
 def preprocessing_script(feature_list: list[str]) -> str:
     """
     A preprocessing script for monitoring jobs.
@@ -245,8 +227,8 @@ if __name__ == "__main__":
     from workbench.api.monitor import Monitor
 
     # Test pulling data capture
-    mon = Monitor("caco2-pappab-class-0")
-    df = pull_data_capture(mon.data_capture_path)
+    mon = Monitor("abalone-regression-rt")
+    df = pull_data_capture_for_testing(mon.data_capture_path)
     print("Data Capture:")
     print(df.head())
 
@@ -262,4 +244,4 @@ if __name__ == "__main__":
     # Test preprocessing script
     script = preprocessing_script(["feature1", "feature2", "feature3"])
     print("\nPreprocessing Script:")
-    print(script)
+    # print(script)
@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
 
     # Check for differences in common columns
     for column in common_columns:
-        if pd.api.types.is_string_dtype(df1[column]) or pd.api.types.is_string_dtype(df2[column]):
+        if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
             # String comparison with NaNs treated as equal
             differences = ~(df1[column].fillna("") == df2[column].fillna(""))
         elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
@@ -161,8 +161,8 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
                 pd.isna(df1[column]) & pd.isna(df2[column])
             )
         else:
-            # Other types (e.g., int) with NaNs treated as equal
-            differences = ~(df1[column].fillna(0) == df2[column].fillna(0))
+            # Other types (int, Int64, etc.) - compare with NaNs treated as equal
+            differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))
 
         # If differences exist, display them
        if differences.any():