workbench 0.8.176__py3-none-any.whl → 0.8.178__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/core/artifacts/endpoint_core.py +4 -1
- workbench/core/artifacts/feature_set_core.py +37 -8
- workbench/core/artifacts/model_core.py +8 -29
- workbench/core/views/training_view.py +38 -48
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +19 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +11 -15
- workbench/model_scripts/custom_models/uq_models/mapie.template +10 -14
- workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
- workbench/scripts/ml_pipeline_sqs.py +14 -2
- workbench/utils/chem_utils/mol_descriptors.py +19 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/model_utils.py +37 -25
- workbench/utils/xgboost_model_utils.py +1 -1
- {workbench-0.8.176.dist-info → workbench-0.8.178.dist-info}/METADATA +1 -1
- {workbench-0.8.176.dist-info → workbench-0.8.178.dist-info}/RECORD +20 -21
- workbench/utils/fast_inference.py +0 -167
- {workbench-0.8.176.dist-info → workbench-0.8.178.dist-info}/WHEEL +0 -0
- {workbench-0.8.176.dist-info → workbench-0.8.178.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.176.dist-info → workbench-0.8.178.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.176.dist-info → workbench-0.8.178.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: workbench
|
|
3
|
-
Version: 0.8.176
|
|
3
|
+
Version: 0.8.178
|
|
4
4
|
Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
|
|
5
5
|
Author-email: SuperCowPowers LLC <support@supercowpowers.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -54,9 +54,9 @@ workbench/core/artifacts/cached_artifact_mixin.py,sha256=ngqFLZ4cQx_TFouXZgXZQsv
|
|
|
54
54
|
workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcvZyNxYERwvo8o0OQc,14858
|
|
55
55
|
workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
|
|
56
56
|
workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
|
|
57
|
-
workbench/core/artifacts/endpoint_core.py,sha256=
|
|
58
|
-
workbench/core/artifacts/feature_set_core.py,sha256=
|
|
59
|
-
workbench/core/artifacts/model_core.py,sha256=
|
|
57
|
+
workbench/core/artifacts/endpoint_core.py,sha256=Q6wL0IpMgCkVssX-BvPwawgogQjq9klSaoBUZ6tEIuc,49146
|
|
58
|
+
workbench/core/artifacts/feature_set_core.py,sha256=0wvW4VyZii0GmO6tBudoGEqZktLtb6spDyIkn7MkDcw,30292
|
|
59
|
+
workbench/core/artifacts/model_core.py,sha256=ECDwQ0qM5qb1yGJ07U70BVdfkrW9m7p9e6YJWib3uR0,50855
|
|
60
60
|
workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
|
|
61
61
|
workbench/core/cloud_platform/cloud_meta.py,sha256=-g4-LTC3D0PXb3VfaXdLR1ERijKuHdffeMK_zhD-koQ,8809
|
|
62
62
|
workbench/core/cloud_platform/aws/README.md,sha256=QT5IQXoUHbIA0qQ2wO6_2P2lYjYQFVYuezc22mWY4i8,97
|
|
@@ -118,14 +118,14 @@ workbench/core/views/create_view.py,sha256=2Ykzb2NvJGoD4PP4k2Bka46GDog9iGG5SWnAc
|
|
|
118
118
|
workbench/core/views/display_view.py,sha256=9K4O77ZnKOh93aMRhxcQJQ1lqScLhuJnU_tHtYZ_U4E,2598
|
|
119
119
|
workbench/core/views/inference_view.py,sha256=9s70M0dFdGq0tWvzMZfgUK7EPKtuvcQhux0uyRZuuLM,3293
|
|
120
120
|
workbench/core/views/pandas_to_view.py,sha256=20uCsnG2iMh-U1VxqVUUtnrWAY98SeuHjmfJK_wcq1I,6422
|
|
121
|
-
workbench/core/views/training_view.py,sha256=
|
|
121
|
+
workbench/core/views/training_view.py,sha256=UWW8Asxtm_kV7Z8NooitMA4xC5vTc7lSWwTGbLdifqY,5900
|
|
122
122
|
workbench/core/views/view.py,sha256=Ujzw6zLROP9oKfKm3zJwaOyfpyjh5uM9fAu1i3kUOig,11764
|
|
123
123
|
workbench/core/views/view_utils.py,sha256=y0YuPW-90nAfgAD1UW_49-j7Mvncfm7-5rV8I_97CK8,12274
|
|
124
124
|
workbench/core/views/storage/mdq_view.py,sha256=qf_ep1KwaXOIfO930laEwNIiCYP7VNOqjE3VdHfopRE,5195
|
|
125
125
|
workbench/model_scripts/script_generation.py,sha256=dL23XYwEsHIStc7i53DtF_47FqOrI9gq0kQAT6sNpZ8,7923
|
|
126
126
|
workbench/model_scripts/custom_models/chem_info/Readme.md,sha256=mH1lxJ4Pb7F5nBnVXaiuxpi8zS_yjUw_LBJepVKXhlA,574
|
|
127
|
-
workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=
|
|
128
|
-
workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256
|
|
127
|
+
workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
|
|
128
|
+
workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
|
|
129
129
|
workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py,sha256=xljMjdfh4Idi4v1Afq1zZxvF1SDa7pDOLSAhvGBEj88,2891
|
|
130
130
|
workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py,sha256=tMyMmeN1xajVWkqkV5mobYB8CYkzW9FRH8Vi3t81uo8,3231
|
|
131
131
|
workbench/model_scripts/custom_models/chem_info/requirements.txt,sha256=7HBUzvNiM8lOir-UfQabXYlUp3gxdGJ42u18EuSMGjc,39
|
|
@@ -140,8 +140,8 @@ workbench/model_scripts/custom_models/uq_models/Readme.md,sha256=UVpL-lvtTrLqwBe
|
|
|
140
140
|
workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=U4LIlpp8Rbu3apyzPR7-55lvlutpTsCro_PUvQ5pklY,6457
|
|
141
141
|
workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=0IJnSBACQ556ldEiPqR7yPCOOLJs1hQhHmPBvB2d9tY,13491
|
|
142
142
|
workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=QbDUfkiPCwJ-c-4Twgu4utZuYZaAyeW_3T1IP-_tutw,6683
|
|
143
|
-
workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=
|
|
144
|
-
workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=
|
|
143
|
+
workbench/model_scripts/custom_models/uq_models/generated_model_script.py,sha256=DUH80Y-We_-3OomUNjvBdRPrNQLQb3zlSsKZIPiglU4,22402
|
|
144
|
+
workbench/model_scripts/custom_models/uq_models/mapie.template,sha256=SHP1Sd-nWMVF5sgB9Ski6C4IkQlm4g0EqpnJT1GfHl4,18204
|
|
145
145
|
workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=eawh0Fp3DhbdCXzWN6KloczT5ZS_ou4ayW65yUTTE4o,14109
|
|
146
146
|
workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=9-O6P-SW50ul5Wl6es2DMWXSbrwOg7HWsdc8Qdln0MM,8278
|
|
147
147
|
workbench/model_scripts/custom_models/uq_models/proximity.py,sha256=zqmNlX70LnWXr5fdtFFQppSNTLjlOciQVrjGr-g9jRE,13716
|
|
@@ -159,7 +159,7 @@ workbench/model_scripts/quant_regression/requirements.txt,sha256=jWlGc7HH7vqyukT
|
|
|
159
159
|
workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=c73ZpJBlU5k13Nx-ZDkLXu7da40CYyhwjwwmuPq6uLg,12870
|
|
160
160
|
workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
|
|
161
161
|
workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=d4pgeZYFezUQsB-7iIsjsUgB1FM6d27651wpfDdXmI0,12640
|
|
162
|
-
workbench/model_scripts/xgb_model/generated_model_script.py,sha256=
|
|
162
|
+
workbench/model_scripts/xgb_model/generated_model_script.py,sha256=BPhr2gfJQC1C26knsyktfLGL7Jp0YBKCIQjplCuHUg0,22218
|
|
163
163
|
workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
|
|
164
164
|
workbench/model_scripts/xgb_model/xgb_model.template,sha256=HViJRsMWn393hP8VJRS45UQBzUVBhwR5sKc8Ern-9f4,17963
|
|
165
165
|
workbench/repl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -169,7 +169,7 @@ workbench/resources/signature_verify_pub.pem,sha256=V3-u-3_z2PH-805ybkKvzDOBwAbv
|
|
|
169
169
|
workbench/scripts/check_double_bond_stereo.py,sha256=p5hnL54Weq77ES0HCELq9JeoM-PyUGkvVSeWYF2dKyo,7776
|
|
170
170
|
workbench/scripts/glue_launcher.py,sha256=bIKQvfGxpAhzbeNvTnHfRW_5kQhY-169_868ZnCejJk,10692
|
|
171
171
|
workbench/scripts/ml_pipeline_batch.py,sha256=1T5JnLlUJR7bwAGBLHmLPOuj1xFRqVIQX8PsuDhHy8o,4907
|
|
172
|
-
workbench/scripts/ml_pipeline_sqs.py,sha256=
|
|
172
|
+
workbench/scripts/ml_pipeline_sqs.py,sha256=COewJcFYuv5Pa_l0q0PA4ZZb-AQ_7opKJP4JTEKBQ2c,5847
|
|
173
173
|
workbench/scripts/monitor_cloud_watch.py,sha256=s7MY4bsHts0nup9G0lWESCvgJZ9Mw1Eo-c8aKRgLjMw,9235
|
|
174
174
|
workbench/scripts/redis_expire.py,sha256=DxI_RKSNlrW2BsJZXcsSbaWGBgPZdPhtzHjV9SUtElE,1120
|
|
175
175
|
workbench/scripts/redis_report.py,sha256=iaJSuGPyLCs6e0TMcZDoT0YyJ43xJ1u74YD8FLnnUg4,990
|
|
@@ -211,7 +211,6 @@ workbench/utils/ecs_info.py,sha256=Gs9jNb4vcj2pziufIOI4BVIH1J-3XBMtWm1phVh8oRY,2
|
|
|
211
211
|
workbench/utils/endpoint_metrics.py,sha256=_4WVU6cLLuV0t_i0PSvhi0EoA5ss5aDFe7ZDpumx2R8,7822
|
|
212
212
|
workbench/utils/endpoint_utils.py,sha256=3-njrhMSAIOaEEiH7qMA9vgD3I7J2S9iUAcqXKx3OBo,7104
|
|
213
213
|
workbench/utils/extract_model_artifact.py,sha256=sFwkJd5mfJ1PU37pIHVmUIQS-taIUJdqi3D9-qRmy8g,7870
|
|
214
|
-
workbench/utils/fast_inference.py,sha256=Sm0EV1oPsYYGqiDBVUu3Nj6Ti68JV-UR2S0ZliBDPTk,6148
|
|
215
214
|
workbench/utils/glue_utils.py,sha256=dslfXQcJ4C-mGmsD6LqeK8vsXBez570t3fZBVZLV7HA,2039
|
|
216
215
|
workbench/utils/graph_utils.py,sha256=T4aslYVbzPmFe0_qKCQP6PZnaw1KATNXQNVO-yDGBxY,10839
|
|
217
216
|
workbench/utils/ipython_utils.py,sha256=skbdbBwUT-iuY3FZwy3ACS7-FWSe9M2qVXfLlQWnikE,700
|
|
@@ -220,7 +219,7 @@ workbench/utils/lambda_utils.py,sha256=7GhGRPyXn9o-toWb9HBGSnI8-DhK9YRkwhCSk_mNK
|
|
|
220
219
|
workbench/utils/license_manager.py,sha256=sDuhk1mZZqUbFmnuFXehyGnui_ALxrmYBg7gYwoo7ho,6975
|
|
221
220
|
workbench/utils/log_utils.py,sha256=7n1NJXO_jUX82e6LWAQug6oPo3wiPDBYsqk9gsYab_A,3167
|
|
222
221
|
workbench/utils/markdown_utils.py,sha256=4lEqzgG4EVmLcvvKKNUwNxVCySLQKJTJmWDiaDroI1w,8306
|
|
223
|
-
workbench/utils/model_utils.py,sha256=
|
|
222
|
+
workbench/utils/model_utils.py,sha256=97yqEEeGLV8KSDt_RTGexcUEK1wU_UnmLj-cfuryPOs,12779
|
|
224
223
|
workbench/utils/monitor_utils.py,sha256=kVaJ7BgUXs3VPMFYfLC03wkIV4Dq-pEhoXS0wkJFxCc,7858
|
|
225
224
|
workbench/utils/pandas_utils.py,sha256=uTUx-d1KYfjbS9PMQp2_9FogCV7xVZR6XLzU5YAGmfs,39371
|
|
226
225
|
workbench/utils/performance_utils.py,sha256=WDNvz-bOdC99cDuXl0urAV4DJ7alk_V3yzKPwvqgST4,1329
|
|
@@ -243,12 +242,12 @@ workbench/utils/workbench_cache.py,sha256=IQchxB81iR4eVggHBxUJdXxUCRkqWz1jKe5gxN
|
|
|
243
242
|
workbench/utils/workbench_event_bridge.py,sha256=z1GmXOB-Qs7VOgC6Hjnp2DI9nSEWepaSXejACxTIR7o,4150
|
|
244
243
|
workbench/utils/workbench_logging.py,sha256=WCuMWhQwibrvcGAyj96h2wowh6dH7zNlDJ7sWUzdCeI,10263
|
|
245
244
|
workbench/utils/workbench_sqs.py,sha256=RwM80z7YWwdtMaCKh7KWF8v38f7eBRU7kyC7ZhTRuI0,2072
|
|
246
|
-
workbench/utils/xgboost_model_utils.py,sha256=
|
|
245
|
+
workbench/utils/xgboost_model_utils.py,sha256=NNcALcBNOveqkIJiG7Wh7DS0O95RlGE3GZJbdSB8XWY,15571
|
|
247
246
|
workbench/utils/chem_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
248
247
|
workbench/utils/chem_utils/fingerprints.py,sha256=Qvs8jaUwguWUq3Q3j695MY0t0Wk3BvroW-oWBwalMUo,5255
|
|
249
248
|
workbench/utils/chem_utils/misc.py,sha256=Nevf8_opu-uIPrv_1_0ubuFVVo2_fGUkMoLAHB3XAeo,7372
|
|
250
|
-
workbench/utils/chem_utils/mol_descriptors.py,sha256=
|
|
251
|
-
workbench/utils/chem_utils/mol_standardize.py,sha256
|
|
249
|
+
workbench/utils/chem_utils/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
|
|
250
|
+
workbench/utils/chem_utils/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
|
|
252
251
|
workbench/utils/chem_utils/mol_tagging.py,sha256=8Bt6gHvyN8B2jvVuz12JgYMHVLDkCLnEPAfqkyMEoMc,9995
|
|
253
252
|
workbench/utils/chem_utils/projections.py,sha256=smV-VTB-pqRrgn4DXyDIpuCYcopJdPZ54YoCQv60JY0,7480
|
|
254
253
|
workbench/utils/chem_utils/salts.py,sha256=ZzFb6Z71Z_kMjVF-PKwHx0fn9pN9rPMj-oEY8Nt5JWA,9095
|
|
@@ -288,9 +287,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
|
|
|
288
287
|
workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
|
|
289
288
|
workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
|
|
290
289
|
workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
|
|
291
|
-
workbench-0.8.
|
|
292
|
-
workbench-0.8.
|
|
293
|
-
workbench-0.8.
|
|
294
|
-
workbench-0.8.
|
|
295
|
-
workbench-0.8.
|
|
296
|
-
workbench-0.8.
|
|
290
|
+
workbench-0.8.178.dist-info/licenses/LICENSE,sha256=z4QMMPlLJkZjU8VOKqJkZiQZCEZ--saIU2Z8-p3aVc0,1080
|
|
291
|
+
workbench-0.8.178.dist-info/METADATA,sha256=kS1snm2EjzaXVrpsg3TX28OmXqYDdZD1K7kQ0lXhNg8,9210
|
|
292
|
+
workbench-0.8.178.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
293
|
+
workbench-0.8.178.dist-info/entry_points.txt,sha256=zPFPruY9uayk8-wsKrhfnIyIB6jvZOW_ibyllEIsLWo,356
|
|
294
|
+
workbench-0.8.178.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
|
|
295
|
+
workbench-0.8.178.dist-info/RECORD,,
|
|
@@ -1,167 +0,0 @@
|
|
|
1
|
-
"""Fast Inference on SageMaker Endpoints"""
|
|
2
|
-
|
|
3
|
-
import pandas as pd
|
|
4
|
-
from io import StringIO
|
|
5
|
-
import logging
|
|
6
|
-
from concurrent.futures import ThreadPoolExecutor
|
|
7
|
-
|
|
8
|
-
# Sagemaker Imports
|
|
9
|
-
import sagemaker
|
|
10
|
-
from sagemaker.serializers import CSVSerializer
|
|
11
|
-
from sagemaker.deserializers import CSVDeserializer
|
|
12
|
-
from sagemaker import Predictor
|
|
13
|
-
|
|
14
|
-
log = logging.getLogger("workbench")
|
|
15
|
-
|
|
16
|
-
_CACHED_SM_SESSION = None
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def get_or_create_sm_session():
|
|
20
|
-
global _CACHED_SM_SESSION
|
|
21
|
-
if _CACHED_SM_SESSION is None:
|
|
22
|
-
_CACHED_SM_SESSION = sagemaker.Session()
|
|
23
|
-
return _CACHED_SM_SESSION
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def fast_inference(endpoint_name: str, eval_df: pd.DataFrame, sm_session=None, threads: int = 4) -> pd.DataFrame:
|
|
27
|
-
"""Run inference on the Endpoint using the provided DataFrame
|
|
28
|
-
|
|
29
|
-
Args:
|
|
30
|
-
endpoint_name (str): The name of the Endpoint
|
|
31
|
-
eval_df (pd.DataFrame): The DataFrame to run predictions on
|
|
32
|
-
sm_session (sagemaker.session.Session, optional): SageMaker Session. If None, a cached session is created.
|
|
33
|
-
threads (int): The number of threads to use (default: 4)
|
|
34
|
-
|
|
35
|
-
Returns:
|
|
36
|
-
pd.DataFrame: The DataFrame with predictions
|
|
37
|
-
"""
|
|
38
|
-
# Use cached session if none is provided
|
|
39
|
-
if sm_session is None:
|
|
40
|
-
sm_session = get_or_create_sm_session()
|
|
41
|
-
|
|
42
|
-
predictor = Predictor(
|
|
43
|
-
endpoint_name,
|
|
44
|
-
sagemaker_session=sm_session,
|
|
45
|
-
serializer=CSVSerializer(),
|
|
46
|
-
deserializer=CSVDeserializer(),
|
|
47
|
-
)
|
|
48
|
-
|
|
49
|
-
total_rows = len(eval_df)
|
|
50
|
-
|
|
51
|
-
def process_chunk(chunk_df: pd.DataFrame, start_index: int) -> pd.DataFrame:
|
|
52
|
-
log.info(f"Processing {start_index}:{min(start_index + chunk_size, total_rows)} out of {total_rows} rows...")
|
|
53
|
-
csv_buffer = StringIO()
|
|
54
|
-
chunk_df.to_csv(csv_buffer, index=False)
|
|
55
|
-
response = predictor.predict(csv_buffer.getvalue())
|
|
56
|
-
# CSVDeserializer returns a nested list: first row is headers
|
|
57
|
-
return pd.DataFrame.from_records(response[1:], columns=response[0])
|
|
58
|
-
|
|
59
|
-
# Sagemaker has a connection pool limit of 10
|
|
60
|
-
if threads > 10:
|
|
61
|
-
log.warning("Sagemaker has a connection pool limit of 10. Reducing threads to 10.")
|
|
62
|
-
threads = 10
|
|
63
|
-
|
|
64
|
-
# Compute the chunk size (divide number of threads)
|
|
65
|
-
chunk_size = max(1, total_rows // threads)
|
|
66
|
-
|
|
67
|
-
# We also need to ensure that the chunk size is not too big
|
|
68
|
-
if chunk_size > 100:
|
|
69
|
-
chunk_size = 100
|
|
70
|
-
|
|
71
|
-
# Split DataFrame into chunks and process them concurrently
|
|
72
|
-
chunks = [(eval_df[i : i + chunk_size], i) for i in range(0, total_rows, chunk_size)]
|
|
73
|
-
with ThreadPoolExecutor(max_workers=threads) as executor:
|
|
74
|
-
df_list = list(executor.map(lambda p: process_chunk(*p), chunks))
|
|
75
|
-
|
|
76
|
-
combined_df = pd.concat(df_list, ignore_index=True)
|
|
77
|
-
|
|
78
|
-
# Convert the types of the dataframe
|
|
79
|
-
combined_df = df_type_conversions(combined_df)
|
|
80
|
-
return combined_df
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def df_type_conversions(df: pd.DataFrame) -> pd.DataFrame:
|
|
84
|
-
"""Convert the types of the dataframe that we get from an endpoint
|
|
85
|
-
|
|
86
|
-
Args:
|
|
87
|
-
df (pd.DataFrame): DataFrame to convert
|
|
88
|
-
|
|
89
|
-
Returns:
|
|
90
|
-
pd.DataFrame: Converted DataFrame
|
|
91
|
-
"""
|
|
92
|
-
# Some endpoints will put in "N/A" values (for CSV serialization)
|
|
93
|
-
# We need to convert these to NaN and the run the conversions below
|
|
94
|
-
# Report on the number of N/A values in each column in the DataFrame
|
|
95
|
-
# For any count above 0 list the column name and the number of N/A values
|
|
96
|
-
na_counts = df.isin(["N/A"]).sum()
|
|
97
|
-
for column, count in na_counts.items():
|
|
98
|
-
if count > 0:
|
|
99
|
-
log.warning(f"{column} has {count} N/A values, converting to NaN")
|
|
100
|
-
pd.set_option("future.no_silent_downcasting", True)
|
|
101
|
-
df = df.replace("N/A", float("nan"))
|
|
102
|
-
|
|
103
|
-
# Convert data to numeric
|
|
104
|
-
# Note: Since we're using CSV serializers numeric columns often get changed to generic 'object' types
|
|
105
|
-
|
|
106
|
-
# Hard Conversion
|
|
107
|
-
# Note: We explicitly catch exceptions for columns that cannot be converted to numeric
|
|
108
|
-
for column in df.columns:
|
|
109
|
-
try:
|
|
110
|
-
df[column] = pd.to_numeric(df[column])
|
|
111
|
-
except ValueError:
|
|
112
|
-
# If a ValueError is raised, the column cannot be converted to numeric, so we keep it as is
|
|
113
|
-
pass
|
|
114
|
-
except TypeError:
|
|
115
|
-
# This typically means a duplicated column name, so confirm duplicate (more than 1) and log it
|
|
116
|
-
column_count = (df.columns == column).sum()
|
|
117
|
-
log.critical(f"{column} occurs {column_count} times in the DataFrame.")
|
|
118
|
-
pass
|
|
119
|
-
|
|
120
|
-
# Soft Conversion
|
|
121
|
-
# Convert columns to the best possible dtype that supports the pd.NA missing value.
|
|
122
|
-
df = df.convert_dtypes()
|
|
123
|
-
|
|
124
|
-
# Convert pd.NA placeholders to pd.NA
|
|
125
|
-
# Note: CSV serialization converts pd.NA to blank strings, so we have to put in placeholders
|
|
126
|
-
df.replace("__NA__", pd.NA, inplace=True)
|
|
127
|
-
|
|
128
|
-
# Check for True/False values in the string columns
|
|
129
|
-
for column in df.select_dtypes(include=["string"]).columns:
|
|
130
|
-
if df[column].str.lower().isin(["true", "false"]).all():
|
|
131
|
-
df[column] = df[column].str.lower().map({"true": True, "false": False})
|
|
132
|
-
|
|
133
|
-
# Return the Dataframe
|
|
134
|
-
return df
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
if __name__ == "__main__":
|
|
138
|
-
"""Exercise the Endpoint Utilities"""
|
|
139
|
-
import time
|
|
140
|
-
from workbench.api.endpoint import Endpoint
|
|
141
|
-
from workbench.utils.endpoint_utils import fs_training_data, fs_evaluation_data
|
|
142
|
-
|
|
143
|
-
# Create an Endpoint
|
|
144
|
-
my_endpoint_name = "abalone-regression"
|
|
145
|
-
my_endpoint = Endpoint(my_endpoint_name)
|
|
146
|
-
if not my_endpoint.exists():
|
|
147
|
-
print(f"Endpoint {my_endpoint_name} does not exist.")
|
|
148
|
-
exit(1)
|
|
149
|
-
|
|
150
|
-
# Get the training data
|
|
151
|
-
my_train_df = fs_training_data(my_endpoint)
|
|
152
|
-
print(my_train_df)
|
|
153
|
-
|
|
154
|
-
# Run Fast Inference and time it
|
|
155
|
-
my_sm_session = my_endpoint.sm_session
|
|
156
|
-
my_eval_df = fs_evaluation_data(my_endpoint)
|
|
157
|
-
start_time = time.time()
|
|
158
|
-
my_results_df = fast_inference(my_endpoint_name, my_eval_df, my_sm_session)
|
|
159
|
-
end_time = time.time()
|
|
160
|
-
print(f"Fast Inference took {end_time - start_time} seconds")
|
|
161
|
-
print(my_results_df)
|
|
162
|
-
print(my_results_df.info())
|
|
163
|
-
|
|
164
|
-
# Test with no session
|
|
165
|
-
my_results_df = fast_inference(my_endpoint_name, my_eval_df)
|
|
166
|
-
print(my_results_df)
|
|
167
|
-
print(my_results_df.info())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|