workbench 0.8.205__py3-none-any.whl → 0.8.213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/models/noise_model.py +388 -0
  2. workbench/api/endpoint.py +3 -6
  3. workbench/api/feature_set.py +1 -1
  4. workbench/api/model.py +5 -11
  5. workbench/cached/cached_model.py +4 -4
  6. workbench/core/artifacts/endpoint_core.py +63 -153
  7. workbench/core/artifacts/model_core.py +21 -19
  8. workbench/core/transforms/features_to_model/features_to_model.py +2 -2
  9. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +1 -1
  10. workbench/model_script_utils/model_script_utils.py +335 -0
  11. workbench/model_script_utils/pytorch_utils.py +395 -0
  12. workbench/model_script_utils/uq_harness.py +278 -0
  13. workbench/model_scripts/chemprop/chemprop.template +289 -666
  14. workbench/model_scripts/chemprop/generated_model_script.py +292 -669
  15. workbench/model_scripts/chemprop/model_script_utils.py +335 -0
  16. workbench/model_scripts/chemprop/requirements.txt +2 -10
  17. workbench/model_scripts/pytorch_model/generated_model_script.py +355 -612
  18. workbench/model_scripts/pytorch_model/model_script_utils.py +335 -0
  19. workbench/model_scripts/pytorch_model/pytorch.template +350 -607
  20. workbench/model_scripts/pytorch_model/pytorch_utils.py +395 -0
  21. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  22. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  23. workbench/model_scripts/script_generation.py +2 -5
  24. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  25. workbench/model_scripts/xgb_model/generated_model_script.py +349 -412
  26. workbench/model_scripts/xgb_model/model_script_utils.py +335 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  28. workbench/model_scripts/xgb_model/xgb_model.template +344 -407
  29. workbench/scripts/training_test.py +85 -0
  30. workbench/utils/chemprop_utils.py +18 -656
  31. workbench/utils/metrics_utils.py +172 -0
  32. workbench/utils/model_utils.py +104 -47
  33. workbench/utils/pytorch_utils.py +32 -472
  34. workbench/utils/xgboost_local_crossfold.py +267 -0
  35. workbench/utils/xgboost_model_utils.py +49 -356
  36. workbench/web_interface/components/plugins/model_details.py +30 -68
  37. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/METADATA +5 -5
  38. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/RECORD +42 -31
  39. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/entry_points.txt +1 -0
  40. workbench/model_scripts/uq_models/mapie.template +0 -605
  41. workbench/model_scripts/uq_models/requirements.txt +0 -1
  42. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/WHEEL +0 -0
  43. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.205.dist-info → workbench-0.8.213.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,85 @@
1
+ """
2
+ Local test harness for SageMaker training scripts.
3
+
4
+ Usage:
5
+ python training_test.py <model_script.py> <featureset_name>
6
+
7
+ Example:
8
+ python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
9
+ """
10
+
11
+ import os
12
+ import shutil
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+
17
+ import pandas as pd
18
+
19
+ from workbench.api import FeatureSet
20
+
21
+
22
def get_training_data(featureset_name: str) -> pd.DataFrame:
    """Load the full contents of a Workbench FeatureSet.

    Args:
        featureset_name: Name of the FeatureSet to read.

    Returns:
        pd.DataFrame: Whatever ``FeatureSet.pull_dataframe()`` returns for it.
    """
    return FeatureSet(featureset_name).pull_dataframe()
26
+
27
+
28
def main(argv=None):
    """Run a SageMaker training script locally against a Workbench FeatureSet.

    The FeatureSet is pulled and written as CSV into a temporary "train"
    channel directory. The script then runs in a subprocess with the
    SageMaker environment variables (``SM_MODEL_DIR``, ``SM_CHANNEL_TRAIN``,
    ``SM_OUTPUT_DATA_DIR``) and the ``--model-dir`` / ``--train`` arguments
    pointing at temporary directories. All temporary directories are
    removed afterwards (best effort).

    Args:
        argv: Arguments excluding the program name, i.e.
            ``[model_script.py, featureset_name]``. Defaults to
            ``sys.argv[1:]``, so calling ``main()`` behaves as before.

    Raises:
        SystemExit: Code 1 on bad usage or a missing script. When the
            training script fails, a non-zero code (its own code, or 1 if it
            was killed by a signal), so shells/CI can detect the failure.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        print("Usage: python training_test.py <model_script.py> <featureset_name>")
        sys.exit(1)

    script_path, featureset_name = args[0], args[1]

    if not os.path.exists(script_path):
        print(f"Error: Script not found: {script_path}")
        sys.exit(1)

    # One root temp dir: every sub-directory is created inside the try, so a
    # failure part-way through setup can no longer leak earlier directories.
    root_dir = tempfile.mkdtemp(prefix="training_test_")
    try:
        model_dir = os.path.join(root_dir, "model")
        train_dir = os.path.join(root_dir, "data")
        output_dir = os.path.join(root_dir, "output")
        for directory in (model_dir, train_dir, output_dir):
            os.makedirs(directory)

        print(f"Model dir: {model_dir}")
        print(f"Train dir: {train_dir}")
        print(f"Output dir: {output_dir}")

        # Get training data and save to CSV in the "train" channel
        print(f"Loading FeatureSet: {featureset_name}")
        df = get_training_data(featureset_name)
        print(f"Data shape: {df.shape}")

        train_file = os.path.join(train_dir, "training_data.csv")
        df.to_csv(train_file, index=False)

        # Mimic the SageMaker training container environment
        env = os.environ.copy()
        env["SM_MODEL_DIR"] = model_dir
        env["SM_CHANNEL_TRAIN"] = train_dir
        env["SM_OUTPUT_DATA_DIR"] = output_dir

        print("\n" + "=" * 60)
        print("Starting training...")
        print("=" * 60 + "\n")

        # Run the script with the same interpreter running this harness
        cmd = [sys.executable, script_path, "--model-dir", model_dir, "--train", train_dir]
        returncode = subprocess.run(cmd, env=env).returncode

        print("\n" + "=" * 60)
        if returncode == 0:
            print("Training completed successfully!")
        else:
            print(f"Training failed with return code: {returncode}")
        print("=" * 60)
    finally:
        shutil.rmtree(root_dir, ignore_errors=True)

    if returncode != 0:
        # Propagate failure to the caller. A negative code means the child was
        # killed by a signal, which is reported here as a generic failure (1).
        sys.exit(returncode if returncode > 0 else 1)
82
+
83
+
84
if __name__ == "__main__":
    # Only launch the harness when executed directly, never on import.
    main()