workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench might be problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
--- a/workbench/utils/execution_environment.py
+++ /dev/null
@@ -1,211 +0,0 @@
-"""ExecutionEnvironment provides logic/functionality to figure out the current execution environment"""
-
-import os
-import sys
-import logging
-import requests
-from typing import Union
-import boto3
-from datetime import datetime, timezone
-
-# Workbench imports
-from workbench.utils.glue_utils import get_resolved_options
-from workbench.utils.deprecated_utils import deprecated
-
-# Set up the logger
-log = logging.getLogger("workbench")
-
-
-def running_on_glue():
-    """
-    Check if the current execution environment is an AWS Glue job.
-
-    Returns:
-        bool: True if running in AWS Glue environment, False otherwise.
-    """
-    # Check if GLUE_VERSION or GLUE_PYTHON_VERSION is in the environment
-    if "GLUE_VERSION" in os.environ or "GLUE_PYTHON_VERSION" in os.environ:
-        log.info("Running in AWS Glue Environment...")
-        return True
-    else:
-        return False
-
-
-def running_on_lambda():
-    """
-    Check if the current execution environment is an AWS Lambda function.
-
-    Returns:
-        bool: True if running in AWS Lambda environment, False otherwise.
-    """
-    if "AWS_LAMBDA_FUNCTION_NAME" in os.environ:
-        log.info("Running in AWS Lambda Environment...")
-        return True
-    else:
-        return False
-
-
-def running_on_docker() -> bool:
-    """Check if the current environment is running on a Docker container.
-
-    Returns:
-        bool: True if running in a Docker container, False otherwise.
-    """
-    try:
-        # Docker creates a .dockerenv file at the root of the directory tree inside the container.
-        # If this file exists, it is very likely that we are running inside a Docker container.
-        with open("/.dockerenv") as f:
-            return True
-    except FileNotFoundError:
-        pass
-
-    try:
-        # Another method is to check the contents of /proc/self/cgroup which should be different
-        # inside a Docker container.
-        with open("/proc/self/cgroup") as f:
-            if any("docker" in line for line in f):
-                return True
-    except FileNotFoundError:
-        pass
-
-    # Check if we are running on ECS
-    if running_on_ecs():
-        return True
-
-    # Probably not running in a Docker container
-    return False
-
-
-def running_on_ecs() -> bool:
-    """
-    Check if the current environment is running on AWS ECS.
-
-    Returns:
-        bool: True if running on AWS ECS, False otherwise.
-    """
-    indicators = [
-        "ECS_SERVICE_NAME",
-        "ECS_CONTAINER_METADATA_URI",
-        "ECS_CONTAINER_METADATA_URI_V4",
-        "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
-        "AWS_EXECUTION_ENV",
-    ]
-    return any(indicator in os.environ for indicator in indicators)
-
-
-def running_as_service() -> bool:
-    """
-    Check if the current environment is running as a service (e.g. Docker, ECS, Glue, Lambda).
-
-    Returns:
-        bool: True if running as a service, False otherwise.
-    """
-    return running_on_docker() or running_on_glue() or running_on_lambda()
-
-
-def _glue_job_from_script_name(args):
-    """Get the Glue Job Name from the script name"""
-    try:
-        script_name = args["scriptLocation"]
-        return os.path.splitext(os.path.basename(script_name))[0]
-    except Exception:
-        return "unknown"
-
-
-def glue_job_name():
-    """Get the Glue Job Name from the environment or script name"""
-    # Define the required argument
-    args = get_resolved_options(sys.argv)
-
-    # Get the job name
-    job_name = args.get("JOB_NAME") or _glue_job_from_script_name(args)
-    return job_name
-
-
-@deprecated(version=0.9)
-def glue_job_run_id(job_name: str, session: boto3.Session) -> Union[str, None]:
-    """Retrieve the Glue Job Run ID closest to the current time for the given job name.
-    Note: This mostly doesn't work, it will grab A glue job id but often not the correct one.
-          For now, I would just skip using this
-    """
-    try:
-        # Set current time in UTC
-        current_time = datetime.now(timezone.utc)
-
-        job_runs = session.client("glue").get_job_runs(JobName=job_name)
-        if job_runs["JobRuns"]:
-            # Find the job run with the StartedOn time closest to the current time
-            closest_job_run = min(job_runs["JobRuns"], key=lambda run: abs(run["StartedOn"] - current_time))
-            job_id = closest_job_run["Id"]
-            return job_id[:9]  # Shorten the Job Run ID to 9 characters
-
-        log.error(f"No runs found for Glue Job '{job_name}', returning None for Job Run ID.")
-        return None
-
-    except session.client("glue").exceptions.EntityNotFoundException:
-        log.error(f"Glue Job '{job_name}' not found, returning None for Job Run ID.")
-        return None
-    except Exception as e:
-        log.error(f"An error occurred while retrieving job run ID: {e}")
-        return None
-
-
-def ecs_job_name():
-    """Get the ECS Job Name from the metadata endpoint or environment variables."""
-    # Attempt to get the job name from ECS metadata
-    ecs_metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")
-
-    if ecs_metadata_uri:
-        try:
-            response = requests.get(f"{ecs_metadata_uri}/task")
-            if response.status_code == 200:
-                metadata = response.json()
-                job_name = metadata.get("Family")  # 'Family' represents the ECS task definition family name
-                if job_name:
-                    return job_name
-        except requests.RequestException as e:
-            # Log the error or handle it as needed
-            log.error(f"Failed to fetch ECS metadata: {e}")
-
-    # Fallback to environment variables if metadata is not available
-    job_name = os.environ.get("ECS_SERVICE_NAME", "unknown")
-    return job_name
-
-
-if __name__ == "__main__":
-    """Test the Execution Environment utilities"""
-
-    # Test running_on_glue
-    assert running_on_glue() is False
-    os.environ["GLUE_VERSION"] = "1.0"
-    assert running_on_glue() is True
-    del os.environ["GLUE_VERSION"]
-
-    # Test running_on_lambda
-    assert running_on_lambda() is False
-    os.environ["AWS_LAMBDA_FUNCTION_NAME"] = "my_lambda_function"
-    assert running_on_lambda() is True
-    del os.environ["AWS_LAMBDA_FUNCTION_NAME"]
-
-    # Test running_on_docker
-    assert running_on_docker() is False
-    os.environ["ECS_CONTAINER_METADATA_URI"] = "http://localhost:8080"
-    assert running_on_docker() is True
-    del os.environ["ECS_CONTAINER_METADATA_URI"]
-
-    # Test running_on_ecs
-    assert running_on_ecs() is False
-    os.environ["ECS_CONTAINER_METADATA_URI"] = "http://localhost:8080"
-    assert running_on_ecs() is True
-    del os.environ["ECS_CONTAINER_METADATA_URI"]
-
-    # Test getting the Glue Job Name
-    print(glue_job_name())
-
-    # Test getting the Glue Job Run ID
-    from workbench.core.cloud_platform.aws.aws_session import AWSSession
-
-    session = AWSSession().boto3_session
-    print(glue_job_run_id("Test_Workbench_Shell", session))

-    print("All tests passed!")
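No one-for-one replacement for these helpers is visible in the file list above. If downstream code relied on them, a minimal standalone sketch of the same environment-variable detection follows, using only the indicator variables the removed module itself checked (the /.dockerenv and cgroup probes are omitted for brevity):

```python
import os

# Indicator variables taken from the removed execution_environment.py
_SERVICE_VARS = (
    "GLUE_VERSION",
    "GLUE_PYTHON_VERSION",            # AWS Glue
    "AWS_LAMBDA_FUNCTION_NAME",       # AWS Lambda
    "ECS_SERVICE_NAME",
    "ECS_CONTAINER_METADATA_URI",
    "ECS_CONTAINER_METADATA_URI_V4",  # AWS ECS
    "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
    "AWS_EXECUTION_ENV",
)


def running_as_service() -> bool:
    """True if any Glue/Lambda/ECS indicator variable is set."""
    return any(var in os.environ for var in _SERVICE_VARS)
```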
--- a/workbench/utils/fast_inference.py
+++ /dev/null
@@ -1,167 +0,0 @@
-"""Fast Inference on SageMaker Endpoints"""
-
-import pandas as pd
-from io import StringIO
-import logging
-from concurrent.futures import ThreadPoolExecutor
-
-# Sagemaker Imports
-import sagemaker
-from sagemaker.serializers import CSVSerializer
-from sagemaker.deserializers import CSVDeserializer
-from sagemaker import Predictor
-
-log = logging.getLogger("workbench")
-
-_CACHED_SM_SESSION = None
-
-
-def get_or_create_sm_session():
-    global _CACHED_SM_SESSION
-    if _CACHED_SM_SESSION is None:
-        _CACHED_SM_SESSION = sagemaker.Session()
-    return _CACHED_SM_SESSION
-
-
-def fast_inference(endpoint_name: str, eval_df: pd.DataFrame, sm_session=None, threads: int = 4) -> pd.DataFrame:
-    """Run inference on the Endpoint using the provided DataFrame
-
-    Args:
-        endpoint_name (str): The name of the Endpoint
-        eval_df (pd.DataFrame): The DataFrame to run predictions on
-        sm_session (sagemaker.session.Session, optional): SageMaker Session. If None, a cached session is created.
-        threads (int): The number of threads to use (default: 4)
-
-    Returns:
-        pd.DataFrame: The DataFrame with predictions
-    """
-    # Use cached session if none is provided
-    if sm_session is None:
-        sm_session = get_or_create_sm_session()
-
-    predictor = Predictor(
-        endpoint_name,
-        sagemaker_session=sm_session,
-        serializer=CSVSerializer(),
-        deserializer=CSVDeserializer(),
-    )
-
-    total_rows = len(eval_df)
-
-    def process_chunk(chunk_df: pd.DataFrame, start_index: int) -> pd.DataFrame:
-        log.info(f"Processing {start_index}:{min(start_index + chunk_size, total_rows)} out of {total_rows} rows...")
-        csv_buffer = StringIO()
-        chunk_df.to_csv(csv_buffer, index=False)
-        response = predictor.predict(csv_buffer.getvalue())
-        # CSVDeserializer returns a nested list: first row is headers
-        return pd.DataFrame.from_records(response[1:], columns=response[0])
-
-    # Sagemaker has a connection pool limit of 10
-    if threads > 10:
-        log.warning("Sagemaker has a connection pool limit of 10. Reducing threads to 10.")
-        threads = 10
-
-    # Compute the chunk size (divide number of threads)
-    chunk_size = max(1, total_rows // threads)
-
-    # We also need to ensure that the chunk size is not too big
-    if chunk_size > 100:
-        chunk_size = 100
-
-    # Split DataFrame into chunks and process them concurrently
-    chunks = [(eval_df[i : i + chunk_size], i) for i in range(0, total_rows, chunk_size)]
-    with ThreadPoolExecutor(max_workers=threads) as executor:
-        df_list = list(executor.map(lambda p: process_chunk(*p), chunks))
-
-    combined_df = pd.concat(df_list, ignore_index=True)
-
-    # Convert the types of the dataframe
-    combined_df = df_type_conversions(combined_df)
-    return combined_df
-
-
-def df_type_conversions(df: pd.DataFrame) -> pd.DataFrame:
-    """Convert the types of the dataframe that we get from an endpoint
-
-    Args:
-        df (pd.DataFrame): DataFrame to convert
-
-    Returns:
-        pd.DataFrame: Converted DataFrame
-    """
-    # Some endpoints will put in "N/A" values (for CSV serialization)
-    # We need to convert these to NaN and the run the conversions below
-    # Report on the number of N/A values in each column in the DataFrame
-    # For any count above 0 list the column name and the number of N/A values
-    na_counts = df.isin(["N/A"]).sum()
-    for column, count in na_counts.items():
-        if count > 0:
-            log.warning(f"{column} has {count} N/A values, converting to NaN")
-    pd.set_option("future.no_silent_downcasting", True)
-    df = df.replace("N/A", float("nan"))
-
-    # Convert data to numeric
-    # Note: Since we're using CSV serializers numeric columns often get changed to generic 'object' types
-
-    # Hard Conversion
-    # Note: We explicitly catch exceptions for columns that cannot be converted to numeric
-    for column in df.columns:
-        try:
-            df[column] = pd.to_numeric(df[column])
-        except ValueError:
-            # If a ValueError is raised, the column cannot be converted to numeric, so we keep it as is
-            pass
-        except TypeError:
-            # This typically means a duplicated column name, so confirm duplicate (more than 1) and log it
-            column_count = (df.columns == column).sum()
-            log.critical(f"{column} occurs {column_count} times in the DataFrame.")
-            pass
-
-    # Soft Conversion
-    # Convert columns to the best possible dtype that supports the pd.NA missing value.
-    df = df.convert_dtypes()
-
-    # Convert pd.NA placeholders to pd.NA
-    # Note: CSV serialization converts pd.NA to blank strings, so we have to put in placeholders
-    df.replace("__NA__", pd.NA, inplace=True)
-
-    # Check for True/False values in the string columns
-    for column in df.select_dtypes(include=["string"]).columns:
-        if df[column].str.lower().isin(["true", "false"]).all():
-            df[column] = df[column].str.lower().map({"true": True, "false": False})
-
-    # Return the Dataframe
-    return df
-
-
-if __name__ == "__main__":
-    """Exercise the Endpoint Utilities"""
-    import time
-    from workbench.api.endpoint import Endpoint
-    from workbench.utils.endpoint_utils import fs_training_data, fs_evaluation_data
-
-    # Create an Endpoint
-    my_endpoint_name = "abalone-regression"
-    my_endpoint = Endpoint(my_endpoint_name)
-    if not my_endpoint.exists():
-        print(f"Endpoint {my_endpoint_name} does not exist.")
-        exit(1)
-
-    # Get the training data
-    my_train_df = fs_training_data(my_endpoint)
-    print(my_train_df)
-
-    # Run Fast Inference and time it
-    my_sm_session = my_endpoint.sm_session
-    my_eval_df = fs_evaluation_data(my_endpoint)
-    start_time = time.time()
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df, my_sm_session)
-    end_time = time.time()
-    print(f"Fast Inference took {end_time - start_time} seconds")
-    print(my_results_df)
-    print(my_results_df.info())
-
-    # Test with no session
-    my_results_df = fast_inference(my_endpoint_name, my_eval_df)
-    print(my_results_df)
-    print(my_results_df.info())
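The removed fast_inference helper wrapped sagemaker.Predictor with CSV (de)serialization and fanned DataFrame chunks across a thread pool. As a rough migration sketch (not a Workbench API), the same CSV round-trip can be done with boto3's sagemaker-runtime client; the endpoint name below is illustrative, borrowed from the removed file's own __main__ block:

```python
import boto3
import pandas as pd

runtime = boto3.client("sagemaker-runtime")


def invoke_csv(endpoint_name: str, chunk_df: pd.DataFrame) -> pd.DataFrame:
    """Send one DataFrame chunk as CSV and parse the CSV response."""
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",
        Accept="text/csv",
        Body=chunk_df.to_csv(index=False),
    )
    # The response Body is a stream that pandas can read directly
    return pd.read_csv(response["Body"])


# Illustrative usage (endpoint name from the removed file's __main__ block):
# predictions = invoke_csv("abalone-regression", eval_df)
```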
--- a/workbench/utils/resource_utils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Resource utilities for Workbench"""
-
-import sys
-import importlib.resources as resources
-import pathlib
-import pkg_resources
-
-
-def get_resource_path(package: str, resource: str) -> pathlib.Path:
-    """Get the path to a resource file, compatible with Python 3.9 and higher.
-
-    Args:
-        package (str): The package where the resource is located.
-        resource (str): The name of the resource file.
-
-    Returns:
-        pathlib.Path: The path to the resource file.
-    """
-    if sys.version_info >= (3, 10):
-        # Python 3.10 and higher: use importlib.resources.path
-        with resources.path(package, resource) as path:
-            return path
-    else:
-        # Python 3.9 and lower: manually construct the path based on package location
-        # Get the location of the installed package
-        package_location = pathlib.Path(pkg_resources.get_distribution(package.split(".")[0]).location)
-        resource_path = package_location / package.replace(".", "/") / resource
-
-        if resource_path.exists():
-            return resource_path
-        else:
-            raise FileNotFoundError(f"Resource '{resource}' not found in package '{package}'.")
-
-
-if __name__ == "__main__":
-    # Test the resource utilities
-    with get_resource_path("workbench.resources", "open_source_api.key") as open_source_key_path:
-        with open(open_source_key_path, "r") as key_file:
-            print(key_file.read().strip())
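The removed helper predates the importlib.resources.files() API that the standard library now recommends over both resources.path() and pkg_resources. A minimal modern equivalent, assuming Python 3.9+ and reading the same resource the file's __main__ block used:

```python
from importlib.resources import as_file, files

# Same lookup the removed get_resource_path() performed, via files()/as_file()
resource = files("workbench.resources").joinpath("open_source_api.key")
with as_file(resource) as key_path:
    print(key_path.read_text().strip())
```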