workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of workbench has been flagged as a potentially problematic release.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
workbench/utils/cloudwatch_utils.py ADDED

@@ -0,0 +1,137 @@
+"""AWS CloudWatch utility functions for Workbench."""
+
+import time
+import logging
+from datetime import datetime, timezone
+from typing import List, Optional, Dict, Generator
+from urllib.parse import quote
+from workbench.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
+
+log = logging.getLogger("workbench")
+
+
+def get_cloudwatch_client():
+    """Get the CloudWatch Logs client using the Workbench assumed role session."""
+    session = AWSAccountClamp().boto3_session
+    return session.client("logs")
+
+
+def get_cloudwatch_logs_url(log_group: str, log_stream: str) -> Optional[str]:
+    """
+    Generate CloudWatch logs URL for the specified log group and stream.
+
+    Args:
+        log_group: Log group name (e.g., '/aws/batch/job')
+        log_stream: Log stream name
+
+    Returns:
+        CloudWatch console URL or None if unable to generate
+    """
+    try:
+        region = AWSAccountClamp().region
+
+        # URL encode the log group and stream
+        encoded_group = quote(log_group, safe="")
+        encoded_stream = quote(log_stream, safe="")
+
+        return (
+            f"https://{region}.console.aws.amazon.com/cloudwatch/home?"
+            f"region={region}#logsV2:log-groups/log-group/{encoded_group}"
+            f"/log-events/{encoded_stream}"
+        )
+    except Exception as e:  # noqa: BLE001
+        log.warning(f"Failed to generate CloudWatch logs URL: {e}")
+        return None
+
+
+def get_active_log_streams(
+    log_group_name: str, start_time_ms: int, stream_filter: Optional[str] = None, client=None
+) -> List[str]:
+    """Retrieve log streams that have events after the specified start time."""
+    if not client:
+        client = get_cloudwatch_client()
+    active_streams = []
+    stream_params = {
+        "logGroupName": log_group_name,
+        "orderBy": "LastEventTime",
+        "descending": True,
+    }
+    while True:
+        response = client.describe_log_streams(**stream_params)
+        log_streams = response.get("logStreams", [])
+        for log_stream in log_streams:
+            log_stream_name = log_stream["logStreamName"]
+            last_event_timestamp = log_stream.get("lastEventTimestamp", 0)
+            if last_event_timestamp >= start_time_ms:
+                active_streams.append(log_stream_name)
+            else:
+                break
+        if "nextToken" in response:
+            stream_params["nextToken"] = response["nextToken"]
+        else:
+            break
+    # Sort and filter streams
+    active_streams.sort()
+    if stream_filter and active_streams:
+        active_streams = [stream for stream in active_streams if stream_filter in stream]
+    return active_streams
+
+
+def stream_log_events(
+    log_group_name: str,
+    log_stream_name: str,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    follow: bool = False,
+    client=None,
+) -> Generator[Dict, None, None]:
+    """
+    Stream log events from a specific log stream.
+    Yields:
+        Log events as dictionaries
+    """
+    if not client:
+        client = get_cloudwatch_client()
+    params = {"logGroupName": log_group_name, "logStreamName": log_stream_name, "startFromHead": True}
+    if start_time:
+        params["startTime"] = int(start_time.timestamp() * 1000)
+    if end_time:
+        params["endTime"] = int(end_time.timestamp() * 1000)
+    next_token = None
+    while True:
+        if next_token:
+            params["nextToken"] = next_token
+            params.pop("startTime", None)
+        try:
+            response = client.get_log_events(**params)
+            events = response.get("events", [])
+            for event in events:
+                event["logStreamName"] = log_stream_name
+                yield event
+            next_token = response.get("nextForwardToken")
+            # Break if no more events or same token
+            if not next_token or next_token == params.get("nextToken"):
+                if not follow:
+                    break
+                time.sleep(2)
+        except client.exceptions.ResourceNotFoundException:
+            if not follow:
+                break
+            time.sleep(2)
+
+
+def print_log_event(
+    event: dict, show_stream: bool = True, local_time: bool = True, custom_format: Optional[str] = None
+):
+    """Print a formatted log event."""
+    timestamp = datetime.fromtimestamp(event["timestamp"] / 1000, tz=timezone.utc)
+    if local_time:
+        timestamp = timestamp.astimezone()
+    message = event["message"].rstrip()
+    if custom_format:
+        # Allow custom formatting
+        print(custom_format.format(stream=event.get("logStreamName", ""), time=timestamp, message=message))
+    elif show_stream and "logStreamName" in event:
+        print(f"[{event['logStreamName']}] [{timestamp:%Y-%m-%d %I:%M%p}] {message}")
+    else:
+        print(f"[{timestamp:%H:%M:%S}] {message}")
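For orientation, here is a minimal usage sketch (not part of the diff) showing how the new `cloudwatch_utils` helpers fit together. It assumes workbench 0.8.202 is installed, AWS credentials for the assumed Workbench role are available, and the log group name is only an example.

```python
# Hypothetical usage sketch for the new cloudwatch_utils helpers (not from the package docs).
from datetime import datetime, timedelta, timezone

from workbench.utils.cloudwatch_utils import (
    get_active_log_streams,
    stream_log_events,
    print_log_event,
    get_cloudwatch_logs_url,
)

log_group = "/aws/sagemaker/Endpoints/abalone-regression"  # example log group name
start = datetime.now(timezone.utc) - timedelta(hours=1)
start_ms = int(start.timestamp() * 1000)

# Find streams with events in the last hour, then print their recent events
for stream_name in get_active_log_streams(log_group, start_ms):
    print(get_cloudwatch_logs_url(log_group, stream_name))
    for event in stream_log_events(log_group, stream_name, start_time=start, follow=False):
        print_log_event(event, show_stream=True)
```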
workbench/utils/config_manager.py CHANGED

@@ -4,15 +4,12 @@ import os
 import sys
 import platform
 import logging
-import importlib.resources as resources  # noqa: F401 Python 3.9 compatibility
 from typing import Any, Dict
+from importlib.resources import files, as_file

 # Workbench imports
 from workbench.utils.license_manager import LicenseManager
-from
-
-# Python 3.9 compatibility
-from workbench.utils.resource_utils import get_resource_path
+from workbench_bridges.utils.execution_environment import running_as_service


 class FatalConfigError(Exception):
@@ -172,8 +169,7 @@ class ConfigManager:
         Returns:
             str: The open source API key.
         """
-
-        with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
+        with as_file(files("workbench.resources").joinpath("open_source_api.key")) as open_source_key_path:
            with open(open_source_key_path, "r") as key_file:
                return key_file.read().strip()

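The config_manager change replaces the removed `workbench.utils.resource_utils.get_resource_path` shim with the standard-library `importlib.resources` API. A minimal sketch of the general pattern, using placeholder package and resource names rather than Workbench's actual key file handling:

```python
# General pattern behind the change: files()/as_file() from importlib.resources
# (Python 3.9+) replace custom resource-path helpers. "my_package.resources" and
# "data.txt" are placeholders, not Workbench resources.
from importlib.resources import files, as_file


def read_packaged_text(package: str, resource: str) -> str:
    """Read a text resource bundled inside an installed package."""
    resource_ref = files(package).joinpath(resource)
    # as_file() guarantees a real filesystem path, even if the package
    # is imported from a zip archive.
    with as_file(resource_ref) as path:
        with open(path, "r") as fh:
            return fh.read().strip()


# Example call (placeholder names):
# api_key = read_packaged_text("my_package.resources", "data.txt")
```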
workbench/utils/endpoint_utils.py CHANGED

@@ -7,9 +7,7 @@ from typing import Union, Optional
 import pandas as pd

 # Workbench Imports
-from workbench.api
-from workbench.api.model import Model
-from workbench.api.endpoint import Endpoint
+from workbench.api import FeatureSet, Model, Endpoint

 # Set up the log
 log = logging.getLogger("workbench")
@@ -77,7 +75,7 @@ def internal_model_data_url(endpoint_config_name: str, session: boto3.Session) -
     return None


-def
+def get_training_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the training data from the FeatureSet used to train the Model

     Args:
@@ -100,7 +98,7 @@ def fs_training_data(end: Endpoint) -> pd.DataFrame:
     return train_df


-def
+def get_evaluation_data(end: Endpoint) -> pd.DataFrame:
     """Code to get the evaluation data from the FeatureSet NOT used for training

     Args:
@@ -178,11 +176,11 @@ if __name__ == "__main__":
     print(model_data_url)

     # Get the training data
-    my_train_df =
+    my_train_df = get_training_data(my_endpoint)
     print(my_train_df)

     # Get the evaluation data
-    my_eval_df =
+    my_eval_df = get_evaluation_data(my_endpoint)
     print(my_eval_df)

     # Backtrack to the FeatureSet
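The endpoint_utils helpers were renamed to `get_training_data` and `get_evaluation_data` (see the hunks above). A hedged usage sketch, assuming a deployed endpoint named "abalone-regression" exists in your Workbench account:

```python
# Sketch of the renamed endpoint_utils helpers; "abalone-regression" is only an
# example endpoint name and must exist in your Workbench/AWS account.
from workbench.api import Endpoint
from workbench.utils.endpoint_utils import get_training_data, get_evaluation_data

end = Endpoint("abalone-regression")

train_df = get_training_data(end)   # rows the model was trained on
eval_df = get_evaluation_data(end)  # held-out rows from the same FeatureSet

print(f"Training rows: {len(train_df)}, Evaluation rows: {len(eval_df)}")
```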
workbench/utils/license_manager.py CHANGED

@@ -6,15 +6,12 @@ import json
 import logging
 import requests
 from typing import Union
-import importlib.resources as resources  # noqa: F401 Python 3.9 compatibility
 from datetime import datetime
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.backends import default_backend
-
-# Python 3.9 compatibility
-from workbench.utils.resource_utils import get_resource_path
+from importlib.resources import files, as_file


 class FatalLicenseError(Exception):
@@ -140,8 +137,7 @@ class LicenseManager:
         Returns:
             The public key as an object.
         """
-
-        with get_resource_path("workbench.resources", "signature_verify_pub.pem") as public_key_path:
+        with as_file(files("workbench.resources").joinpath("signature_verify_pub.pem")) as public_key_path:
            with open(public_key_path, "rb") as key_file:
                public_key_data = key_file.read()

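license_manager follows the same `files()`/`as_file()` migration, and the loaded PEM bytes are then parsed with the `cryptography` package. A generic, standalone sketch of loading a public key and checking a signature; the key path, padding, and hash choices here are illustrative assumptions, not Workbench's actual license-verification scheme:

```python
# Generic sketch: load an RSA public key from a PEM file and verify a signature
# with the `cryptography` package. Path, padding, and hash are placeholders;
# Workbench bundles its own signature_verify_pub.pem and scheme.
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding


def verify_signature(pem_path: str, message: bytes, signature: bytes) -> bool:
    with open(pem_path, "rb") as fh:
        public_key = serialization.load_pem_public_key(fh.read())
    try:
        public_key.verify(signature, message, padding.PKCS1v15(), hashes.SHA256())
        return True
    except InvalidSignature:
        return False
```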
workbench/utils/model_utils.py CHANGED

@@ -3,6 +3,7 @@
 import logging
 import pandas as pd
 import numpy as np
+from scipy.stats import spearmanr
 import importlib.resources
 from pathlib import Path
 import os
@@ -92,6 +93,38 @@ def get_custom_script_path(package: str, script_name: str) -> Path:
     return script_path


+def proximity_model_local(model: "Model"):
+    """Create a Proximity Model for this Model
+
+    Args:
+        model (Model): The Model/FeatureSet used to create the proximity model
+
+    Returns:
+        Proximity: The proximity model
+    """
+    from workbench.algorithms.dataframe.proximity import Proximity  # noqa: F401 (avoid circular import)
+    from workbench.api import Model, FeatureSet  # noqa: F401 (avoid circular import)
+
+    # Get Feature and Target Columns from the existing given Model
+    features = model.features()
+    target = model.target()
+
+    # Backtrack our FeatureSet to get the ID column
+    fs = FeatureSet(model.get_input())
+    id_column = fs.id_column
+
+    # Create the Proximity Model from both the full FeatureSet and the Model training data
+    full_df = fs.pull_dataframe()
+    model_df = model.training_view().pull_dataframe()
+
+    # Mark rows that are in the model
+    model_ids = set(model_df[id_column])
+    full_df["in_model"] = full_df[id_column].isin(model_ids)
+
+    # Create and return the Proximity Model
+    return Proximity(full_df, id_column, features, target, track_columns=features)
+
+
 def proximity_model(model: "Model", prox_model_name: str, track_columns: list = None) -> "Model":
     """Create a proximity model based on the given model

@@ -139,9 +172,6 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
     """
     from workbench.api import Model, ModelType, FeatureSet  # noqa: F401 (avoid circular import)

-    # Get the custom script path for the UQ model
-    script_path = get_custom_script_path("uq_models", "meta_uq.template")
-
     # Get Feature and Target Columns from the existing given Model
     features = model.features()
     target = model.target()
@@ -156,12 +186,25 @@ def uq_model(model: "Model", uq_model_name: str, train_all_data: bool = False) -
         description=f"UQ Model for {model.name}",
         tags=["uq", model.name],
         train_all_data=train_all_data,
-        custom_script=script_path,
         custom_args={"id_column": fs.id_column, "track_columns": [target]},
     )
     return uq_model


+def safe_extract_tarfile(tar_path: str, extract_path: str) -> None:
+    """
+    Extract a tarball safely, using data filter if available.
+
+    The filter parameter was backported to Python 3.8+, 3.9+, 3.10.13+, 3.11+
+    as a security patch, but may not be present in older patch versions.
+    """
+    with tarfile.open(tar_path, "r:gz") as tar:
+        if hasattr(tarfile, "data_filter"):
+            tar.extractall(path=extract_path, filter="data")
+        else:
+            tar.extractall(path=extract_path)
+
+
 def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     """
     Download and extract category mappings from a model artifact in S3.
@@ -180,8 +223,7 @@ def load_category_mappings_from_s3(model_artifact_uri: str) -> Optional[dict]:
     wr.s3.download(path=model_artifact_uri, local_file=local_tar_path)

     # Extract tarball
-
-    tar.extractall(path=tmpdir, filter="data")
+    safe_extract_tarfile(local_tar_path, tmpdir)

     # Look for category mappings in base directory only
     mappings_path = os.path.join(tmpdir, "category_mappings.json")
@@ -220,28 +262,41 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     # --- Coverage and Interval Width ---
     if "q_025" in df.columns and "q_975" in df.columns:
         lower_95, upper_95 = df["q_025"], df["q_975"]
+        lower_90, upper_90 = df["q_05"], df["q_95"]
+        lower_80, upper_80 = df["q_10"], df["q_90"]
+        lower_68 = df.get("q_16", df["q_10"])  # fallback to 80% interval
+        upper_68 = df.get("q_84", df["q_90"])  # fallback to 80% interval
         lower_50, upper_50 = df["q_25"], df["q_75"]
     elif "prediction_std" in df.columns:
         lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
         upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
+        lower_90 = df["prediction"] - 1.645 * df["prediction_std"]
+        upper_90 = df["prediction"] + 1.645 * df["prediction_std"]
+        lower_80 = df["prediction"] - 1.282 * df["prediction_std"]
+        upper_80 = df["prediction"] + 1.282 * df["prediction_std"]
+        lower_68 = df["prediction"] - 1.0 * df["prediction_std"]
+        upper_68 = df["prediction"] + 1.0 * df["prediction_std"]
         lower_50 = df["prediction"] - 0.674 * df["prediction_std"]
         upper_50 = df["prediction"] + 0.674 * df["prediction_std"]
     else:
         raise ValueError(
             "Either quantile columns (q_025, q_975, q_25, q_75) or 'prediction_std' column must be present."
         )
+    median_std = df["prediction_std"].median()
     coverage_95 = np.mean((df[target_col] >= lower_95) & (df[target_col] <= upper_95))
-
-
-
+    coverage_90 = np.mean((df[target_col] >= lower_90) & (df[target_col] <= upper_90))
+    coverage_80 = np.mean((df[target_col] >= lower_80) & (df[target_col] <= upper_80))
+    coverage_68 = np.mean((df[target_col] >= lower_68) & (df[target_col] <= upper_68))
+    median_width_95 = np.median(upper_95 - lower_95)
+    median_width_90 = np.median(upper_90 - lower_90)
+    median_width_80 = np.median(upper_80 - lower_80)
+    median_width_50 = np.median(upper_50 - lower_50)
+    median_width_68 = np.median(upper_68 - lower_68)

     # --- CRPS (measures calibration + sharpness) ---
-
-
-
-        mean_crps = np.mean(crps)
-    else:
-        mean_crps = np.nan
+    z = (df[target_col] - df["prediction"]) / df["prediction_std"]
+    crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))
+    mean_crps = np.mean(crps)

     # --- Interval Score @ 95% (penalizes miscoverage) ---
     alpha_95 = 0.05
@@ -252,31 +307,43 @@ def uq_metrics(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
     )
     mean_is_95 = np.mean(is_95)

-    # ---
+    # --- Interval to Error Correlation ---
     abs_residuals = np.abs(df[target_col] - df["prediction"])
-
-
+    width_68 = upper_68 - lower_68
+
+    # Spearman correlation for robustness
+    interval_to_error_corr = spearmanr(width_68, abs_residuals)[0]

     # Collect results
     results = {
+        "coverage_68": coverage_68,
+        "coverage_80": coverage_80,
+        "coverage_90": coverage_90,
         "coverage_95": coverage_95,
-        "
-        "
-        "
-        "
-        "
-        "
+        "median_std": median_std,
+        "median_width_50": median_width_50,
+        "median_width_68": median_width_68,
+        "median_width_80": median_width_80,
+        "median_width_90": median_width_90,
+        "median_width_95": median_width_95,
+        "interval_to_error_corr": interval_to_error_corr,
         "n_samples": len(df),
     }

     print("\n=== UQ Metrics ===")
+    print(f"Coverage @ 68%: {coverage_68:.3f} (target: 0.68)")
+    print(f"Coverage @ 80%: {coverage_80:.3f} (target: 0.80)")
+    print(f"Coverage @ 90%: {coverage_90:.3f} (target: 0.90)")
     print(f"Coverage @ 95%: {coverage_95:.3f} (target: 0.95)")
-    print(f"
-    print(f"
-    print(f"
+    print(f"Median Prediction StdDev: {median_std:.3f}")
+    print(f"Median 50% Width: {median_width_50:.3f}")
+    print(f"Median 68% Width: {median_width_68:.3f}")
+    print(f"Median 80% Width: {median_width_80:.3f}")
+    print(f"Median 90% Width: {median_width_90:.3f}")
+    print(f"Median 95% Width: {median_width_95:.3f}")
     print(f"CRPS: {mean_crps:.3f} (lower is better)")
     print(f"Interval Score 95%: {mean_is_95:.3f} (lower is better)")
-    print(f"
+    print(f"Interval/Error Corr: {interval_to_error_corr:.3f} (higher is better, target: >0.5)")
     print(f"Samples: {len(df)}")
     return results

@@ -313,9 +380,3 @@ if __name__ == "__main__":
     df = end.auto_inference(capture=True)
     results = uq_metrics(df, target_col="solubility")
     print(results)
-
-    # Test the uq_metrics function
-    end = Endpoint("aqsol-uq-100")
-    df = end.auto_inference(capture=True)
-    results = uq_metrics(df, target_col="solubility")
-    print(results)
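The reworked `uq_metrics` adds multi-level coverage, median interval widths, a Spearman interval-to-error correlation, and the closed-form Gaussian CRPS, CRPS(mu, sigma; y) = sigma * [ z(2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi) ] with z = (y - mu)/sigma. Below is a small self-contained sanity check of those pieces on synthetic data; it does not import Workbench, and the column names ("prediction", "prediction_std") simply mirror the diff.

```python
# Self-contained check of the metric pieces used in uq_metrics(): Gaussian CRPS,
# 95% coverage, and interval-to-error correlation. Synthetic data only.
import numpy as np
import pandas as pd
from scipy.stats import norm, spearmanr

rng = np.random.default_rng(42)
n = 5_000
mu = rng.normal(0, 2, n)                   # point predictions
sigma = rng.uniform(0.5, 2.0, n)           # heteroscedastic predictive std
y = mu + sigma * rng.standard_normal(n)    # targets drawn from the predictive distribution

df = pd.DataFrame({"target": y, "prediction": mu, "prediction_std": sigma})

# Closed-form CRPS for a Gaussian predictive distribution (same formula as the diff)
z = (df["target"] - df["prediction"]) / df["prediction_std"]
crps = df["prediction_std"] * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))

# 95% coverage from +/- 1.96 sigma intervals (should land near 0.95 here)
lower_95 = df["prediction"] - 1.96 * df["prediction_std"]
upper_95 = df["prediction"] + 1.96 * df["prediction_std"]
coverage_95 = np.mean((df["target"] >= lower_95) & (df["target"] <= upper_95))

# Wider intervals should track larger absolute errors (positive Spearman correlation)
abs_residuals = np.abs(df["target"] - df["prediction"])
width_95 = upper_95 - lower_95
interval_to_error_corr = spearmanr(width_95, abs_residuals)[0]

print(f"mean CRPS: {crps.mean():.3f}, coverage@95: {coverage_95:.3f}, corr: {interval_to_error_corr:.3f}")
```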
workbench/utils/monitor_utils.py CHANGED

@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
 log = logging.getLogger("workbench")


-def
+def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
     """
     Read and process captured data from S3.

@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non

     Returns:
         Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
+
+    Notes:
+        This method is really only for testing and debugging.
     """
+    log.important("This method is for testing and debugging only.")
+
     # List files in the specified S3 path
     files = wr.s3.list_objects(data_capture_path)
     if not files:
@@ -64,59 +69,53 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
 def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Process the captured data DataFrame to extract input and output data.
-
+    Handles cases where input or output might not be captured.
+
     Args:
         df (DataFrame): DataFrame with captured data.
     Returns:
         tuple[DataFrame, DataFrame]: Input and output DataFrames.
     """
+
+    def parse_endpoint_data(data: dict) -> pd.DataFrame:
+        """Parse endpoint data based on encoding type."""
+        encoding = data["encoding"].upper()
+
+        if encoding == "CSV":
+            return pd.read_csv(StringIO(data["data"]))
+        elif encoding == "JSON":
+            json_data = json.loads(data["data"])
+            if isinstance(json_data, dict):
+                return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
+            else:
+                return pd.DataFrame(json_data)
+        else:
+            return None  # Unknown encoding
+
     input_dfs = []
     output_dfs = []

-
+    # Use itertuples() instead of iterrows() for better performance
+    for row in df.itertuples(index=True):
         try:
-            capture_data = row
-
-            #
-            if "endpointInput"
-
-
-
-
-
-
-            )
-
-
-            # Process input data
-            input_data = capture_data["endpointInput"]
-            if input_data["encoding"].upper() == "CSV":
-                input_df = pd.read_csv(StringIO(input_data["data"]))
-            elif input_data["encoding"].upper() == "JSON":
-                json_data = json.loads(input_data["data"])
-                if isinstance(json_data, dict):
-                    input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    input_df = pd.DataFrame(json_data)
-
-            # Process output data
-            output_data = capture_data["endpointOutput"]
-            if output_data["encoding"].upper() == "CSV":
-                output_df = pd.read_csv(StringIO(output_data["data"]))
-            elif output_data["encoding"].upper() == "JSON":
-                json_data = json.loads(output_data["data"])
-                if isinstance(json_data, dict):
-                    output_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
-                else:
-                    output_df = pd.DataFrame(json_data)
-
-            # If we get here, both processed successfully
-            input_dfs.append(input_df)
-            output_dfs.append(output_df)
+            capture_data = row.captureData
+
+            # Process input data if present
+            if "endpointInput" in capture_data:
+                input_df = parse_endpoint_data(capture_data["endpointInput"])
+                if input_df is not None:
+                    input_dfs.append(input_df)
+
+            # Process output data if present
+            if "endpointOutput" in capture_data:
+                output_df = parse_endpoint_data(capture_data["endpointOutput"])
+                if output_df is not None:
+                    output_dfs.append(output_df)

         except Exception as e:
-            log.
+            log.debug(f"Row {row.Index}: Failed to process row: {e}")
             continue
+
     # Combine and return results
     return (
         pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
@@ -178,23 +177,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
         return {"error": str(e)}


-"""TEMP
-# If the status is "CompletedWithViolations", we grab the lastest
-# violation file and add it to the result
-if status == "CompletedWithViolations":
-    violation_file = f"{self.monitoring_path}/
-    {last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
-    if wr.s3.does_object_exist(violation_file):
-        violations_json = read_content_from_s3(violation_file)
-        violations = parse_monitoring_results(violations_json)
-        result["violations"] = violations.get("constraint_violations", [])
-        result["violation_count"] = len(result["violations"])
-    else:
-        result["violations"] = []
-        result["violation_count"] = 0
-"""
-
-
 def preprocessing_script(feature_list: list[str]) -> str:
     """
     A preprocessing script for monitoring jobs.
@@ -245,8 +227,8 @@ if __name__ == "__main__":
     from workbench.api.monitor import Monitor

     # Test pulling data capture
-    mon = Monitor("
-    df =
+    mon = Monitor("abalone-regression-rt")
+    df = pull_data_capture_for_testing(mon.data_capture_path)
     print("Data Capture:")
     print(df.head())

@@ -262,4 +244,4 @@ if __name__ == "__main__":
     # Test preprocessing script
     script = preprocessing_script(["feature1", "feature2", "feature3"])
     print("\nPreprocessing Script:")
-    print(script)
+    # print(script)
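The data-capture parsing in `process_data_capture` now routes both `endpointInput` and `endpointOutput` through one encoding-aware helper. Here is a standalone sketch of that parsing logic run against a fabricated capture record; the record layout mirrors SageMaker data capture as used in the diff, and all values are made up.

```python
# Standalone sketch of the encoding-aware parsing used by process_data_capture().
# The capture record below is fabricated; real records come from SageMaker data capture.
import json
from io import StringIO
from typing import Optional

import pandas as pd


def parse_endpoint_data(data: dict) -> Optional[pd.DataFrame]:
    """Parse one endpointInput/endpointOutput payload based on its encoding."""
    encoding = data["encoding"].upper()
    if encoding == "CSV":
        return pd.read_csv(StringIO(data["data"]))
    if encoding == "JSON":
        json_data = json.loads(data["data"])
        if isinstance(json_data, dict):
            return pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
        return pd.DataFrame(json_data)
    return None  # unknown encoding


capture_data = {
    "endpointInput": {"encoding": "CSV", "data": "length,diameter\n0.455,0.365\n"},
    "endpointOutput": {"encoding": "JSON", "data": json.dumps({"prediction": [9.7]})},
}

print(parse_endpoint_data(capture_data["endpointInput"]))
print(parse_endpoint_data(capture_data["endpointOutput"]))
```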
workbench/utils/pandas_utils.py CHANGED

@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li

     # Check for differences in common columns
     for column in common_columns:
-        if pd.api.types.is_string_dtype(df1[column])
+        if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
             # String comparison with NaNs treated as equal
             differences = ~(df1[column].fillna("") == df2[column].fillna(""))
         elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
@@ -161,8 +161,8 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
                 pd.isna(df1[column]) & pd.isna(df2[column])
             )
         else:
-            # Other types (
-            differences =
+            # Other types (int, Int64, etc.) - compare with NaNs treated as equal
+            differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))

     # If differences exist, display them
     if differences.any():