workbench-0.8.174-py3-none-any.whl → workbench-0.8.227-py3-none-any.whl
This diff reflects the content of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of workbench has been marked as potentially problematic.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
workbench/utils/fast_inference.py (file removed)
@@ -1,167 +0,0 @@
-"""Fast Inference on SageMaker Endpoints"""
-
-import pandas as pd
-from io import StringIO
-import logging
-from concurrent.futures import ThreadPoolExecutor
-
-# Sagemaker Imports
-import sagemaker
-from sagemaker.serializers import CSVSerializer
-from sagemaker.deserializers import CSVDeserializer
-from sagemaker import Predictor
-
-log = logging.getLogger("workbench")
-
-_CACHED_SM_SESSION = None
-
-
-def get_or_create_sm_session():
-    global _CACHED_SM_SESSION
-    if _CACHED_SM_SESSION is None:
-        _CACHED_SM_SESSION = sagemaker.Session()
-    return _CACHED_SM_SESSION
-
-
-def fast_inference(endpoint_name: str, eval_df: pd.DataFrame, sm_session=None, threads: int = 4) -> pd.DataFrame:
-    """Run inference on the Endpoint using the provided DataFrame
-
-    Args:
-        endpoint_name (str): The name of the Endpoint
-        eval_df (pd.DataFrame): The DataFrame to run predictions on
-        sm_session (sagemaker.session.Session, optional): SageMaker Session. If None, a cached session is created.
-        threads (int): The number of threads to use (default: 4)
-
-    Returns:
-        pd.DataFrame: The DataFrame with predictions
-    """
-    # Use cached session if none is provided
-    if sm_session is None:
-        sm_session = get_or_create_sm_session()
-
-    predictor = Predictor(
-        endpoint_name,
-        sagemaker_session=sm_session,
-        serializer=CSVSerializer(),
-        deserializer=CSVDeserializer(),
-    )
-
-    total_rows = len(eval_df)
-
-    def process_chunk(chunk_df: pd.DataFrame, start_index: int) -> pd.DataFrame:
-        log.info(f"Processing {start_index}:{min(start_index + chunk_size, total_rows)} out of {total_rows} rows...")
-        csv_buffer = StringIO()
-        chunk_df.to_csv(csv_buffer, index=False)
-        response = predictor.predict(csv_buffer.getvalue())
-        # CSVDeserializer returns a nested list: first row is headers
-        return pd.DataFrame.from_records(response[1:], columns=response[0])
-
-    # Sagemaker has a connection pool limit of 10
-    if threads > 10:
-        log.warning("Sagemaker has a connection pool limit of 10. Reducing threads to 10.")
-        threads = 10
-
-    # Compute the chunk size (divide number of threads)
-    chunk_size = max(1, total_rows // threads)
-
-    # We also need to ensure that the chunk size is not too big
-    if chunk_size > 100:
-        chunk_size = 100
-
-    # Split DataFrame into chunks and process them concurrently
-    chunks = [(eval_df[i : i + chunk_size], i) for i in range(0, total_rows, chunk_size)]
-    with ThreadPoolExecutor(max_workers=threads) as executor:
-        df_list = list(executor.map(lambda p: process_chunk(*p), chunks))
-
-    combined_df = pd.concat(df_list, ignore_index=True)
-
-    # Convert the types of the dataframe
-    combined_df = df_type_conversions(combined_df)
-    return combined_df
-
-
-def df_type_conversions(df: pd.DataFrame) -> pd.DataFrame:
-    """Convert the types of the dataframe that we get from an endpoint
-
-    Args:
-        df (pd.DataFrame): DataFrame to convert
-
-    Returns:
-        pd.DataFrame: Converted DataFrame
-    """
-    # Some endpoints will put in "N/A" values (for CSV serialization)
-    # We need to convert these to NaN and the run the conversions below
-    # Report on the number of N/A values in each column in the DataFrame
-    # For any count above 0 list the column name and the number of N/A values
-    na_counts = df.isin(["N/A"]).sum()
-    for column, count in na_counts.items():
-        if count > 0:
-            log.warning(f"{column} has {count} N/A values, converting to NaN")
-    pd.set_option("future.no_silent_downcasting", True)
-    df = df.replace("N/A", float("nan"))
-
-    # Convert data to numeric
-    # Note: Since we're using CSV serializers numeric columns often get changed to generic 'object' types
-
-    # Hard Conversion
-    # Note: We explicitly catch exceptions for columns that cannot be converted to numeric
-    for column in df.columns:
-        try:
-            df[column] = pd.to_numeric(df[column])
-        except ValueError:
-            # If a ValueError is raised, the column cannot be converted to numeric, so we keep it as is
-            pass
-        except TypeError:
-            # This typically means a duplicated column name, so confirm duplicate (more than 1) and log it
-            column_count = (df.columns == column).sum()
-            log.critical(f"{column} occurs {column_count} times in the DataFrame.")
-            pass
-
-    # Soft Conversion
-    # Convert columns to the best possible dtype that supports the pd.NA missing value.
-    df = df.convert_dtypes()
-
-    # Convert pd.NA placeholders to pd.NA
-    # Note: CSV serialization converts pd.NA to blank strings, so we have to put in placeholders
-    df.replace("__NA__", pd.NA, inplace=True)
-
-    # Check for True/False values in the string columns
-    for column in df.select_dtypes(include=["string"]).columns:
-        if df[column].str.lower().isin(["true", "false"]).all():
-            df[column] = df[column].str.lower().map({"true": True, "false": False})
-
-    # Return the Dataframe
-    return df
-
-
-if __name__ == "__main__":
-    """Exercise the Endpoint Utilities"""
-    import time
-    from workbench.api.endpoint import Endpoint
-    from workbench.utils.endpoint_utils import fs_training_data, fs_evaluation_data
-
-    # Create an Endpoint
-    my_endpoint_name = "abalone-regression"
-    my_endpoint = Endpoint(my_endpoint_name)
-    if not my_endpoint.exists():
-        print(f"Endpoint {my_endpoint_name} does not exist.")
-        exit(1)
-
-    # Get the training data
-    my_train_df = fs_training_data(my_endpoint)
-    print(my_train_df)
-
-    # Run Fast Inference and time it
-    my_sm_session = my_endpoint.sm_session
-    my_eval_df = fs_evaluation_data(my_endpoint)
-    start_time = time.time()
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df, my_sm_session)
-    end_time = time.time()
-    print(f"Fast Inference took {end_time - start_time} seconds")
-    print(my_results_df)
-    print(my_results_df.info())
-
-    # Test with no session
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df)
-    print(my_results_df)
-    print(my_results_df.info())
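For context, the deleted helper chunked the input DataFrame (at most 100 rows per request), fanned the chunks out over a thread pool capped at 10 workers to respect SageMaker's connection pool limit, then concatenated and type-converted the predictions. A minimal usage sketch as it would have looked before the removal; the endpoint name comes from the file's own `__main__` block, while the tiny inline DataFrame and its column names are hypothetical stand-ins for real evaluation data:

```python
import pandas as pd

from workbench.utils.fast_inference import fast_inference  # module deleted in this release range

# Hypothetical stand-in frame; the original __main__ block pulled real
# evaluation data from the FeatureSet via fs_evaluation_data(my_endpoint)
eval_df = pd.DataFrame({"length": [0.455, 0.350], "diameter": [0.365, 0.265]})

# Splits the frame into chunks, sends them concurrently to the
# "abalone-regression" endpoint, then reassembles and type-converts
results_df = fast_inference("abalone-regression", eval_df, threads=4)
print(results_df)
```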
workbench/utils/resource_utils.py (file removed)
@@ -1,39 +0,0 @@
-"""Resource utilities for Workbench"""
-
-import sys
-import importlib.resources as resources
-import pathlib
-import pkg_resources
-
-
-def get_resource_path(package: str, resource: str) -> pathlib.Path:
-    """Get the path to a resource file, compatible with Python 3.9 and higher.
-
-    Args:
-        package (str): The package where the resource is located.
-        resource (str): The name of the resource file.
-
-    Returns:
-        pathlib.Path: The path to the resource file.
-    """
-    if sys.version_info >= (3, 10):
-        # Python 3.10 and higher: use importlib.resources.path
-        with resources.path(package, resource) as path:
-            return path
-    else:
-        # Python 3.9 and lower: manually construct the path based on package location
-        # Get the location of the installed package
-        package_location = pathlib.Path(pkg_resources.get_distribution(package.split(".")[0]).location)
-        resource_path = package_location / package.replace(".", "/") / resource
-
-        if resource_path.exists():
-            return resource_path
-        else:
-            raise FileNotFoundError(f"Resource '{resource}' not found in package '{package}'.")
-
-
-if __name__ == "__main__":
-    # Test the resource utilities
-    with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
-        with open(open_source_key_path, "r") as key_file:
-            print(key_file.read().strip())
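The deleted `get_resource_path` helper existed mainly to bridge resource lookup across Python 3.9 and 3.10+. A minimal sketch of the stdlib equivalent that makes the version branch (and the `pkg_resources` dependency) unnecessary, assuming Python 3.9+ where `importlib.resources.files` is available; the package and resource names are taken from the file's own `__main__` block:

```python
from importlib.resources import files

# files() returns a Traversable whose joinpath()/read_text() work
# uniformly on Python 3.9+ without version branching
key_text = files("workbench.resources").joinpath("open_source_api.key").read_text()
print(key_text.strip())
```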