workbench 0.8.197__py3-none-any.whl → 0.8.201__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. workbench/algorithms/dataframe/proximity.py +19 -12
  2. workbench/api/__init__.py +2 -1
  3. workbench/api/feature_set.py +7 -4
  4. workbench/api/model.py +1 -1
  5. workbench/core/artifacts/__init__.py +11 -2
  6. workbench/core/artifacts/endpoint_core.py +84 -46
  7. workbench/core/artifacts/feature_set_core.py +69 -1
  8. workbench/core/artifacts/model_core.py +37 -7
  9. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  10. workbench/core/transforms/features_to_model/features_to_model.py +23 -20
  11. workbench/core/views/view.py +2 -2
  12. workbench/model_scripts/chemprop/chemprop.template +931 -0
  13. workbench/model_scripts/chemprop/generated_model_script.py +931 -0
  14. workbench/model_scripts/chemprop/requirements.txt +11 -0
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  16. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  17. workbench/model_scripts/custom_models/proximity/proximity.py +19 -12
  18. workbench/model_scripts/custom_models/uq_models/proximity.py +19 -12
  19. workbench/model_scripts/pytorch_model/generated_model_script.py +130 -88
  20. workbench/model_scripts/pytorch_model/pytorch.template +128 -86
  21. workbench/model_scripts/scikit_learn/generated_model_script.py +302 -0
  22. workbench/model_scripts/script_generation.py +10 -7
  23. workbench/model_scripts/uq_models/generated_model_script.py +25 -18
  24. workbench/model_scripts/uq_models/mapie.template +23 -16
  25. workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
  26. workbench/model_scripts/xgb_model/xgb_model.template +2 -2
  27. workbench/repl/workbench_shell.py +14 -5
  28. workbench/scripts/endpoint_test.py +162 -0
  29. workbench/scripts/{lambda_launcher.py → lambda_test.py} +10 -0
  30. workbench/utils/chemprop_utils.py +724 -0
  31. workbench/utils/pytorch_utils.py +497 -0
  32. workbench/utils/xgboost_model_utils.py +12 -5
  33. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/METADATA +2 -2
  34. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/RECORD +38 -30
  35. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/entry_points.txt +2 -1
  36. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/WHEEL +0 -0
  37. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/licenses/LICENSE +0 -0
  38. {workbench-0.8.197.dist-info → workbench-0.8.201.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,162 @@
1
+ """
2
+ Local test harness for SageMaker model scripts.
3
+
4
+ Usage:
5
+ python model_script_harness.py <local_script.py> <model_name>
6
+
7
+ Example:
8
+ python model_script_harness.py pytorch.py aqsol-pytorch-reg
9
+
10
+ This allows you to test LOCAL changes to a model script against deployed model artifacts.
11
+ Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
12
+
13
+ Optional: testing/env.json with additional environment variables
14
+ """
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+ import importlib.util
20
+ import tempfile
21
+ import shutil
22
+ import pandas as pd
23
+ import torch
24
+
25
+ # Workbench Imports
26
+ from workbench.api import Model, FeatureSet
27
+ from workbench.utils.pytorch_utils import download_and_extract_model
28
+
29
# Force CPU mode before any model is loaded to avoid MPS/CUDA issues on Mac.
# NOTE(review): torch is already imported above, so these settings do not precede
# the import itself -- confirm that each one still takes effect when set after import.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
torch.set_default_device("cpu")
# Disable MPS entirely by replacing the availability check with one that always returns False
if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False
36
+
37
+
38
def get_eval_data(workbench_model: Model) -> pd.DataFrame:
    """Fetch the evaluation rows for a model from its input FeatureSet.

    Args:
        workbench_model: Workbench model whose input FeatureSet and training view are used.

    Returns:
        DataFrame of the rows in the model's training view where ``training = FALSE``.

    Raises:
        ValueError: If the model's input FeatureSet does not exist.
    """
    # Resolve the FeatureSet that feeds this model
    feature_set_name = workbench_model.get_input()
    feature_set = FeatureSet(feature_set_name)
    if not feature_set.exists():
        raise ValueError(f"No FeatureSet found: {feature_set_name}")

    # Evaluation rows are the ones flagged training = FALSE in the training view
    view_table = workbench_model.training_view().table
    print(f"Querying evaluation data from {view_table}...")
    sql = f'SELECT * FROM "{view_table}" WHERE training = FALSE'
    holdout_df = feature_set.query(sql)
    print(f"Retrieved {len(holdout_df)} evaluation rows")

    return holdout_df
53
+
54
+
55
def load_model_script(script_path: str):
    """Dynamically load a model script from a file path as the module ``model_script``.

    Args:
        script_path: Path to the local Python model script.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If ``script_path`` does not exist.
        ImportError: If no import spec/loader can be built for the path
            (e.g. the file does not have a recognized Python suffix).
    """
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Script not found: {script_path}")

    spec = importlib.util.spec_from_file_location("model_script", script_path)
    # spec_from_file_location returns None when it cannot choose a loader for the
    # path; fail with a clear message instead of an AttributeError further down
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load model script (not an importable Python file): {script_path}")
    module = importlib.util.module_from_spec(spec)

    # Add to sys.modules so imports within the script work
    sys.modules["model_script"] = module

    spec.loader.exec_module(module)
    return module
68
+
69
+
70
def main():
    """Run a local model script against a deployed Workbench model's artifacts.

    Reads ``<local_script.py>`` and ``<model_name>`` from ``sys.argv``, downloads the
    model artifacts into a temporary directory, calls the script's ``model_fn`` and
    ``predict_fn`` on the model's evaluation rows, and prints a summary of the result.
    Exits with status 1 when arguments are missing or the script path does not exist.
    The temporary model directory is removed on the way out, even on failure.

    Raises:
        AttributeError: If the script does not define ``model_fn`` or ``predict_fn``.
    """
    if len(sys.argv) < 3:
        print("Usage: python model_script_harness.py <local_script.py> <model_name>")
        print("\nArguments:")
        print(" local_script.py - Path to your LOCAL model script to test")
        print(" model_name - Workbench model name (e.g., aqsol-pytorch-reg)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)

    script_path = sys.argv[1]
    model_name = sys.argv[2]

    # Validate local script exists
    if not os.path.exists(script_path):
        print(f"Error: Local script not found: {script_path}")
        sys.exit(1)

    # Initialize Workbench model
    print(f"Loading Workbench model: {model_name}")
    workbench_model = Model(model_name)
    print(f"Model Framework: {workbench_model.model_framework}")
    print()

    # Create a temporary model directory
    model_dir = tempfile.mkdtemp(prefix="model_harness_")
    print(f"Using model directory: {model_dir}")

    try:
        # Load environment variables from env.json if it exists
        if os.path.exists("testing/env.json"):
            print("Loading environment variables from testing/env.json")
            with open("testing/env.json") as f:
                env_vars = json.load(f)
            for key, value in env_vars.items():
                # os.environ only accepts strings; JSON numbers/booleans would raise TypeError
                os.environ[key] = str(value)
                print(f" Set {key} = {value}")
            print()

        # Set up SageMaker environment variables
        os.environ["SM_MODEL_DIR"] = model_dir
        print(f"Set SM_MODEL_DIR = {model_dir}")

        # Download and extract model artifacts
        s3_uri = workbench_model.model_data_url()
        download_and_extract_model(s3_uri, model_dir)
        print()

        # Load the LOCAL model script
        print(f"Loading LOCAL model script: {script_path}")
        module = load_model_script(script_path)
        print()

        # Check for required functions
        if not hasattr(module, "model_fn"):
            raise AttributeError("Model script must have a model_fn function")
        if not hasattr(module, "predict_fn"):
            raise AttributeError("Model script must have a predict_fn function")

        # Load the model
        print("Calling model_fn...")
        print("-" * 50)
        model = module.model_fn(model_dir)
        print("-" * 50)
        print(f"Model loaded: {type(model)}")
        print()

        # Get evaluation data from FeatureSet
        print("Pulling evaluation data from FeatureSet...")
        df = get_eval_data(workbench_model)
        print(f"Input shape: {df.shape}")
        print(f"Columns: {df.columns.tolist()}")
        print()

        print("Calling predict_fn...")
        print("-" * 50)
        result = module.predict_fn(df, model)
        print("-" * 50)
        print()

        print("Prediction result:")
        print(f"Output shape: {result.shape}")
        print(f"Output columns: {result.columns.tolist()}")
        print()
        print(result.head(10).to_string())

    finally:
        # Cleanup
        print(f"\nCleaning up model directory: {model_dir}")
        shutil.rmtree(model_dir, ignore_errors=True)
159
+
160
+
161
# Run the harness only when this file is executed as a script
if __name__ == "__main__":
    main()
@@ -1,3 +1,13 @@
1
+ """
2
+ Local test harness for AWS Lambda scripts.
3
+
4
+ Usage:
5
+ lambda_test <lambda_script.py>
6
+
7
+ Required: testing/event.json with the event definition
8
+ Options: testing/env.json with a set of ENV vars
9
+ """
10
+
1
11
  import sys
2
12
  import os
3
13
  import json