workbench 0.8.205__py3-none-any.whl → 0.8.212__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/api/endpoint.py +3 -6
- workbench/api/feature_set.py +1 -1
- workbench/api/model.py +5 -11
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/endpoint_core.py +57 -145
- workbench/core/artifacts/model_core.py +21 -19
- workbench/core/transforms/features_to_model/features_to_model.py +2 -2
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
- workbench/model_script_utils/model_script_utils.py +335 -0
- workbench/model_script_utils/pytorch_utils.py +395 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +289 -666
- workbench/model_scripts/chemprop/generated_model_script.py +292 -669
- workbench/model_scripts/chemprop/model_script_utils.py +335 -0
- workbench/model_scripts/chemprop/requirements.txt +2 -10
- workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
- workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
- workbench/model_scripts/pytorch_model/pytorch.template +350 -607
- workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/script_generation.py +2 -5
- workbench/model_scripts/uq_models/generated_model_script.py +65 -422
- workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
- workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +344 -407
- workbench/scripts/training_test.py +85 -0
- workbench/utils/chemprop_utils.py +18 -656
- workbench/utils/metrics_utils.py +172 -0
- workbench/utils/model_utils.py +104 -47
- workbench/utils/pytorch_utils.py +32 -472
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +49 -356
- workbench/web_interface/components/plugins/model_details.py +30 -68
- {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/METADATA +5 -5
- {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/RECORD +42 -31
- {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/entry_points.txt +1 -0
- workbench/model_scripts/uq_models/mapie.template +0 -605
- workbench/model_scripts/uq_models/requirements.txt +0 -1
- {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/WHEEL +0 -0
- {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.205.dist-info → workbench-0.8.212.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local test harness for SageMaker training scripts.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
python training_test.py <model_script.py> <featureset_name>
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import shutil
|
|
13
|
+
import subprocess
|
|
14
|
+
import sys
|
|
15
|
+
import tempfile
|
|
16
|
+
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
from workbench.api import FeatureSet
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_training_data(featureset_name: str) -> pd.DataFrame:
    """Load the complete dataset behind a Workbench FeatureSet.

    Args:
        featureset_name: Name of the FeatureSet to load.

    Returns:
        The DataFrame produced by ``FeatureSet.pull_dataframe()``.
    """
    return FeatureSet(featureset_name).pull_dataframe()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main(argv=None):
    """Run a SageMaker-style training script locally against a FeatureSet's data.

    Pulls the FeatureSet into ``training_data.csv`` inside a temporary
    "train" channel directory. Points ``SM_MODEL_DIR``, ``SM_CHANNEL_TRAIN``
    and ``SM_OUTPUT_DATA_DIR`` at temporary directories, then runs the script
    in a subprocess with the current Python interpreter. All temporary
    directories are removed afterwards, including anything the script wrote
    to the model or output directories.

    Args:
        argv: Command-line arguments in ``sys.argv`` form
            (``[prog, model_script.py, featureset_name]``). Defaults to
            ``sys.argv`` so existing callers are unaffected.

    Raises:
        SystemExit: With code 1 on a usage error or a missing script, and
            with the training script's return code when training fails.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 3:
        print("Usage: python training_test.py <model_script.py> <featureset_name>", file=sys.stderr)
        sys.exit(1)

    script_path = args[1]
    featureset_name = args[2]

    if not os.path.exists(script_path):
        print(f"Error: Script not found: {script_path}", file=sys.stderr)
        sys.exit(1)

    # Record each temp dir as soon as it exists, so a failure while creating a
    # later one (or anything afterwards) still cleans up the earlier ones.
    created_dirs = []
    try:
        for prefix in ("training_model_", "training_data_", "training_output_"):
            created_dirs.append(tempfile.mkdtemp(prefix=prefix))
        model_dir, train_dir, output_dir = created_dirs

        print(f"Model dir: {model_dir}")
        print(f"Train dir: {train_dir}")
        print(f"Output dir: {output_dir}")

        # Get training data and save it to CSV in the train channel directory
        print(f"Loading FeatureSet: {featureset_name}")
        df = get_training_data(featureset_name)
        print(f"Data shape: {df.shape}")

        train_file = os.path.join(train_dir, "training_data.csv")
        df.to_csv(train_file, index=False)

        # Mirror the SageMaker environment variables the script may read
        env = os.environ.copy()
        env["SM_MODEL_DIR"] = model_dir
        env["SM_CHANNEL_TRAIN"] = train_dir
        env["SM_OUTPUT_DATA_DIR"] = output_dir

        print("\n" + "=" * 60)
        print("Starting training...")
        print("=" * 60 + "\n")

        # Run the script with the same interpreter running this harness
        cmd = [sys.executable, script_path, "--model-dir", model_dir, "--train", train_dir]
        # Flush our buffered output first; otherwise, when stdout is a pipe or
        # file, the child's output can appear before the banner above.
        sys.stdout.flush()
        result = subprocess.run(cmd, env=env)
        returncode = result.returncode

        print("\n" + "=" * 60)
        if returncode == 0:
            print("Training completed successfully!")
        else:
            print(f"Training failed with return code: {returncode}")
        print("=" * 60)

    finally:
        # Best-effort cleanup, matching the original ignore_errors behavior
        for path in created_dirs:
            shutil.rmtree(path, ignore_errors=True)

    # Propagate failure so shells and CI see a non-zero exit status
    if returncode != 0:
        sys.exit(returncode)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
if __name__ == "__main__":
    # Only run the harness when executed as a script, not when imported.
    main()
|