workbench 0.8.213__py3-none-any.whl → 0.8.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/algorithms/sql/outliers.py +3 -3
  9. workbench/api/__init__.py +3 -0
  10. workbench/api/endpoint.py +10 -5
  11. workbench/api/feature_set.py +76 -6
  12. workbench/api/meta_model.py +289 -0
  13. workbench/api/model.py +43 -4
  14. workbench/core/artifacts/endpoint_core.py +65 -117
  15. workbench/core/artifacts/feature_set_core.py +3 -3
  16. workbench/core/artifacts/model_core.py +6 -4
  17. workbench/core/pipelines/pipeline_executor.py +1 -1
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  19. workbench/model_script_utils/model_script_utils.py +15 -11
  20. workbench/model_script_utils/pytorch_utils.py +11 -1
  21. workbench/model_scripts/chemprop/chemprop.template +147 -71
  22. workbench/model_scripts/chemprop/generated_model_script.py +151 -75
  23. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  24. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  25. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  27. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  28. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  29. workbench/model_scripts/meta_model/meta_model.template +209 -0
  30. workbench/model_scripts/pytorch_model/generated_model_script.py +45 -27
  31. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  32. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  33. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  34. workbench/model_scripts/script_generation.py +4 -0
  35. workbench/model_scripts/xgb_model/generated_model_script.py +167 -156
  36. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  37. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  38. workbench/repl/workbench_shell.py +0 -5
  39. workbench/scripts/endpoint_test.py +2 -2
  40. workbench/scripts/meta_model_sim.py +35 -0
  41. workbench/utils/chem_utils/fingerprints.py +87 -46
  42. workbench/utils/chemprop_utils.py +23 -5
  43. workbench/utils/meta_model_simulator.py +499 -0
  44. workbench/utils/metrics_utils.py +94 -10
  45. workbench/utils/model_utils.py +91 -9
  46. workbench/utils/pytorch_utils.py +1 -1
  47. workbench/utils/shap_utils.py +1 -55
  48. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  49. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/METADATA +2 -1
  50. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/RECORD +54 -50
  51. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/entry_points.txt +1 -0
  52. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  53. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  54. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  55. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  56. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/WHEEL +0 -0
  57. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/licenses/LICENSE +0 -0
  58. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,209 @@
+ # Meta Model Template for Workbench
+ #
+ # NOTE: This is called a "meta model" but it's really a "meta endpoint" - it aggregates
+ # predictions from multiple child endpoints. We call it a "model" because Workbench
+ # creates Model artifacts that get deployed as Endpoints, so this follows that pattern.
+ #
+ # Assumptions/Shortcuts:
+ # - All child endpoints are regression models
+ # - All child endpoints output 'prediction' and 'confidence' columns
+ # - Aggregation uses model weights (provided at meta model creation time)
+ #
+ # This template:
+ # - Has no real training phase (just saves metadata including model weights)
+ # - At inference time, calls child endpoints and aggregates their predictions
+
+ import argparse
+ import json
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from io import StringIO
+
+ import pandas as pd
+
+ from workbench_bridges.endpoints.fast_inference import fast_inference
+
+ # Template parameters (filled in by Workbench)
+ TEMPLATE_PARAMS = {
+     "child_endpoints": ['logd-reg-pytorch', 'logd-reg-chemprop'],
+     "target_column": "logd",
+     "model_weights": {'logd-reg-pytorch': 0.4228205813233993, 'logd-reg-chemprop': 0.5771794186766008},
+     "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-meta/training",
+     "aws_region": "us-west-2",
+ }
+
+
+ def invoke_endpoints_parallel(endpoint_names: list[str], df: pd.DataFrame) -> dict[str, pd.DataFrame]:
+     """Call multiple child endpoints in parallel and collect their results.
+
+     Args:
+         endpoint_names: List of endpoint names to call
+         df: Input DataFrame to send to each endpoint
+
+     Returns:
+         Dict mapping endpoint_name -> result DataFrame (or None if failed)
+     """
+     results = {}
+
+     def call_endpoint(name: str) -> tuple[str, pd.DataFrame | None]:
+         try:
+             return name, fast_inference(name, df)
+         except Exception as e:
+             print(f"Error calling endpoint {name}: {e}")
+             return name, None
+
+     with ThreadPoolExecutor(max_workers=len(endpoint_names)) as executor:
+         futures = {executor.submit(call_endpoint, name): name for name in endpoint_names}
+         for future in as_completed(futures):
+             name, result = future.result()
+             results[name] = result
+
+     return results
+
+
+ def aggregate_predictions(results: dict[str, pd.DataFrame], model_weights: dict[str, float]) -> pd.DataFrame:
+     """Aggregate predictions from multiple endpoints using model weights.
+
+     Args:
+         results: Dict mapping endpoint_name -> predictions DataFrame
+             Each DataFrame must have 'prediction' and 'confidence' columns
+         model_weights: Dict mapping endpoint_name -> weight
+
+     Returns:
+         DataFrame with aggregated prediction, prediction_std, and confidence
+     """
+     # Filter out failed endpoints
+     valid_results = {k: v for k, v in results.items() if v is not None}
+     if not valid_results:
+         raise ValueError("All child endpoints failed")
+
+     # Use first result as base (for id columns, etc.)
+     first_df = list(valid_results.values())[0]
+     output_df = first_df.drop(columns=["prediction", "confidence", "prediction_std"], errors="ignore").copy()
+
+     # Build DataFrames of predictions and confidences from all endpoints
+     pred_df = pd.DataFrame({name: df["prediction"] for name, df in valid_results.items()})
+     conf_df = pd.DataFrame({name: df["confidence"] for name, df in valid_results.items()})
+
+     # Apply model weights (renormalize for valid endpoints only)
+     valid_weights = {k: model_weights.get(k, 1.0) for k in valid_results}
+     weight_sum = sum(valid_weights.values())
+     normalized_weights = {k: v / weight_sum for k, v in valid_weights.items()}
+
+     # Weighted average
+     output_df["prediction"] = sum(pred_df[name] * w for name, w in normalized_weights.items())
+
+     # Ensemble std across child endpoints
+     output_df["prediction_std"] = pred_df.std(axis=1)
+
+     # Aggregated confidence: weighted mean of child confidences
+     output_df["confidence"] = sum(conf_df[name] * w for name, w in normalized_weights.items())
+
+     return output_df
+
+
+ # =============================================================================
+ # Model Loading (for SageMaker inference)
+ # =============================================================================
+ def model_fn(model_dir: str) -> dict:
+     """Load meta model configuration."""
+     with open(os.path.join(model_dir, "meta_config.json")) as f:
+         config = json.load(f)
+
+     # Set AWS_REGION for fast_inference (baked in at training time)
+     if config.get("aws_region"):
+         os.environ["AWS_REGION"] = config["aws_region"]
+
+     print(f"Meta model loaded: {len(config['child_endpoints'])} child endpoints")
+     print(f"Model weights: {config.get('model_weights')}")
+     print(f"AWS region: {config.get('aws_region')}")
+     return config
+
+
+ def input_fn(input_data, content_type):
+     """Parse input data and return a DataFrame."""
+     if not input_data:
+         raise ValueError("Empty input data is not supported!")
+
+     # Decode bytes to string if necessary
+     if isinstance(input_data, bytes):
+         input_data = input_data.decode("utf-8")
+
+     if "text/csv" in content_type:
+         return pd.read_csv(StringIO(input_data))
+     elif "application/json" in content_type:
+         return pd.DataFrame(json.loads(input_data))
+     else:
+         raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df, accept_type):
+     """Supports both CSV and JSON output formats."""
+     if "text/csv" in accept_type:
+         return output_df.to_csv(index=False), "text/csv"
+     elif "application/json" in accept_type:
+         return output_df.to_json(orient="records"), "application/json"
+     else:
+         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ # =============================================================================
+ # Inference (for SageMaker inference)
+ # =============================================================================
+ def predict_fn(df: pd.DataFrame, config: dict) -> pd.DataFrame:
+     """Run inference by calling child endpoints and aggregating results."""
+     child_endpoints = config["child_endpoints"]
+     model_weights = config.get("model_weights", {})
+
+     print(f"Calling {len(child_endpoints)} child endpoints: {child_endpoints}")
+
+     # Call all child endpoints
+     results = invoke_endpoints_parallel(child_endpoints, df)
+
+     # Report status
+     for name, result in results.items():
+         status = f"{len(result)} rows" if result is not None else "FAILED"
+         print(f"  {name}: {status}")
+
+     # Aggregate predictions using model weights
+     output_df = aggregate_predictions(results, model_weights)
+
+     print(f"Aggregated {len(output_df)} predictions from {len(results)} endpoints")
+     return output_df
+
+
+ # =============================================================================
+ # Training (just saves configuration - no actual training)
+ # =============================================================================
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+     parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+     args = parser.parse_args()
+
+     child_endpoints = TEMPLATE_PARAMS["child_endpoints"]
+     target_column = TEMPLATE_PARAMS["target_column"]
+     model_weights = TEMPLATE_PARAMS["model_weights"]
+     aws_region = TEMPLATE_PARAMS["aws_region"]
+
+     print("=" * 60)
+     print("Meta Model Configuration")
+     print("=" * 60)
+     print(f"Child endpoints: {child_endpoints}")
+     print(f"Target column: {target_column}")
+     print(f"Model weights: {model_weights}")
+     print(f"AWS region: {aws_region}")
+
+     # Save configuration for inference
+     config = {
+         "child_endpoints": child_endpoints,
+         "target_column": target_column,
+         "model_weights": model_weights,
+         "aws_region": aws_region,
+     }
+
+     with open(os.path.join(args.model_dir, "meta_config.json"), "w") as f:
+         json.dump(config, f, indent=2)
+
+     print(f"\nMeta model configuration saved to {args.model_dir}")
@@ -0,0 +1,209 @@
+ # Meta Model Template for Workbench
+ #
+ # NOTE: This is called a "meta model" but it's really a "meta endpoint" - it aggregates
+ # predictions from multiple child endpoints. We call it a "model" because Workbench
+ # creates Model artifacts that get deployed as Endpoints, so this follows that pattern.
+ #
+ # Assumptions/Shortcuts:
+ # - All child endpoints are regression models
+ # - All child endpoints output 'prediction' and 'confidence' columns
+ # - Aggregation uses model weights (provided at meta model creation time)
+ #
+ # This template:
+ # - Has no real training phase (just saves metadata including model weights)
+ # - At inference time, calls child endpoints and aggregates their predictions
+
+ import argparse
+ import json
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from io import StringIO
+
+ import pandas as pd
+
+ from workbench_bridges.endpoints.fast_inference import fast_inference
+
+ # Template parameters (filled in by Workbench)
+ TEMPLATE_PARAMS = {
+     "child_endpoints": "{{child_endpoints}}",
+     "target_column": "{{target_column}}",
+     "model_weights": "{{model_weights}}",
+     "model_metrics_s3_path": "{{model_metrics_s3_path}}",
+     "aws_region": "{{aws_region}}",
+ }
+
+
+ def invoke_endpoints_parallel(endpoint_names: list[str], df: pd.DataFrame) -> dict[str, pd.DataFrame]:
+     """Call multiple child endpoints in parallel and collect their results.
+
+     Args:
+         endpoint_names: List of endpoint names to call
+         df: Input DataFrame to send to each endpoint
+
+     Returns:
+         Dict mapping endpoint_name -> result DataFrame (or None if failed)
+     """
+     results = {}
+
+     def call_endpoint(name: str) -> tuple[str, pd.DataFrame | None]:
+         try:
+             return name, fast_inference(name, df)
+         except Exception as e:
+             print(f"Error calling endpoint {name}: {e}")
+             return name, None
+
+     with ThreadPoolExecutor(max_workers=len(endpoint_names)) as executor:
+         futures = {executor.submit(call_endpoint, name): name for name in endpoint_names}
+         for future in as_completed(futures):
+             name, result = future.result()
+             results[name] = result
+
+     return results
+
+
+ def aggregate_predictions(results: dict[str, pd.DataFrame], model_weights: dict[str, float]) -> pd.DataFrame:
+     """Aggregate predictions from multiple endpoints using model weights.
+
+     Args:
+         results: Dict mapping endpoint_name -> predictions DataFrame
+             Each DataFrame must have 'prediction' and 'confidence' columns
+         model_weights: Dict mapping endpoint_name -> weight
+
+     Returns:
+         DataFrame with aggregated prediction, prediction_std, and confidence
+     """
+     # Filter out failed endpoints
+     valid_results = {k: v for k, v in results.items() if v is not None}
+     if not valid_results:
+         raise ValueError("All child endpoints failed")
+
+     # Use first result as base (for id columns, etc.)
+     first_df = list(valid_results.values())[0]
+     output_df = first_df.drop(columns=["prediction", "confidence", "prediction_std"], errors="ignore").copy()
+
+     # Build DataFrames of predictions and confidences from all endpoints
+     pred_df = pd.DataFrame({name: df["prediction"] for name, df in valid_results.items()})
+     conf_df = pd.DataFrame({name: df["confidence"] for name, df in valid_results.items()})
+
+     # Apply model weights (renormalize for valid endpoints only)
+     valid_weights = {k: model_weights.get(k, 1.0) for k in valid_results}
+     weight_sum = sum(valid_weights.values())
+     normalized_weights = {k: v / weight_sum for k, v in valid_weights.items()}
+
+     # Weighted average
+     output_df["prediction"] = sum(pred_df[name] * w for name, w in normalized_weights.items())
+
+     # Ensemble std across child endpoints
+     output_df["prediction_std"] = pred_df.std(axis=1)
+
+     # Aggregated confidence: weighted mean of child confidences
+     output_df["confidence"] = sum(conf_df[name] * w for name, w in normalized_weights.items())
+
+     return output_df
+
+
+ # =============================================================================
+ # Model Loading (for SageMaker inference)
+ # =============================================================================
+ def model_fn(model_dir: str) -> dict:
+     """Load meta model configuration."""
+     with open(os.path.join(model_dir, "meta_config.json")) as f:
+         config = json.load(f)
+
+     # Set AWS_REGION for fast_inference (baked in at training time)
+     if config.get("aws_region"):
+         os.environ["AWS_REGION"] = config["aws_region"]
+
+     print(f"Meta model loaded: {len(config['child_endpoints'])} child endpoints")
+     print(f"Model weights: {config.get('model_weights')}")
+     print(f"AWS region: {config.get('aws_region')}")
+     return config
+
+
+ def input_fn(input_data, content_type):
+     """Parse input data and return a DataFrame."""
+     if not input_data:
+         raise ValueError("Empty input data is not supported!")
+
+     # Decode bytes to string if necessary
+     if isinstance(input_data, bytes):
+         input_data = input_data.decode("utf-8")
+
+     if "text/csv" in content_type:
+         return pd.read_csv(StringIO(input_data))
+     elif "application/json" in content_type:
+         return pd.DataFrame(json.loads(input_data))
+     else:
+         raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df, accept_type):
+     """Supports both CSV and JSON output formats."""
+     if "text/csv" in accept_type:
+         return output_df.to_csv(index=False), "text/csv"
+     elif "application/json" in accept_type:
+         return output_df.to_json(orient="records"), "application/json"
+     else:
+         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ # =============================================================================
+ # Inference (for SageMaker inference)
+ # =============================================================================
+ def predict_fn(df: pd.DataFrame, config: dict) -> pd.DataFrame:
+     """Run inference by calling child endpoints and aggregating results."""
+     child_endpoints = config["child_endpoints"]
+     model_weights = config.get("model_weights", {})
+
+     print(f"Calling {len(child_endpoints)} child endpoints: {child_endpoints}")
+
+     # Call all child endpoints
+     results = invoke_endpoints_parallel(child_endpoints, df)
+
+     # Report status
+     for name, result in results.items():
+         status = f"{len(result)} rows" if result is not None else "FAILED"
+         print(f"  {name}: {status}")
+
+     # Aggregate predictions using model weights
+     output_df = aggregate_predictions(results, model_weights)
+
+     print(f"Aggregated {len(output_df)} predictions from {len(results)} endpoints")
+     return output_df
+
+
+ # =============================================================================
+ # Training (just saves configuration - no actual training)
+ # =============================================================================
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+     parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+     args = parser.parse_args()
+
+     child_endpoints = TEMPLATE_PARAMS["child_endpoints"]
+     target_column = TEMPLATE_PARAMS["target_column"]
+     model_weights = TEMPLATE_PARAMS["model_weights"]
+     aws_region = TEMPLATE_PARAMS["aws_region"]
+
+     print("=" * 60)
+     print("Meta Model Configuration")
+     print("=" * 60)
+     print(f"Child endpoints: {child_endpoints}")
+     print(f"Target column: {target_column}")
+     print(f"Model weights: {model_weights}")
+     print(f"AWS region: {aws_region}")
+
+     # Save configuration for inference
+     config = {
+         "child_endpoints": child_endpoints,
+         "target_column": target_column,
+         "model_weights": model_weights,
+         "aws_region": aws_region,
+     }
+
+     with open(os.path.join(args.model_dir, "meta_config.json"), "w") as f:
+         json.dump(config, f, indent=2)
+
+     print(f"\nMeta model configuration saved to {args.model_dir}")
@@ -5,51 +5,36 @@
  # - Out-of-fold predictions for validation metrics
  # - Categorical feature embedding via TabularMLP
  # - Compressed feature decompression
+ #
+ # NOTE: Imports are structured to minimize serverless endpoint startup time.
+ # Heavy imports (sklearn, awswrangler) are deferred to training time.

- import argparse
  import json
  import os

- import awswrangler as wr
  import joblib
  import numpy as np
  import pandas as pd
  import torch
- from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
- from sklearn.preprocessing import LabelEncoder
-
- # Enable Tensor Core optimization for GPUs that support it
- torch.set_float32_matmul_precision("medium")

  from model_script_utils import (
-     check_dataframe,
-     compute_classification_metrics,
-     compute_regression_metrics,
      convert_categorical_types,
      decompress_features,
      expand_proba_column,
      input_fn,
      match_features_case_insensitive,
      output_fn,
-     print_classification_metrics,
-     print_confusion_matrix,
-     print_regression_metrics,
  )
  from pytorch_utils import (
      FeatureScaler,
-     create_model,
      load_model,
      predict,
      prepare_data,
-     save_model,
-     train_model,
  )
  from uq_harness import (
      compute_confidence,
      load_uq_models,
      predict_intervals,
-     save_uq_models,
-     train_uq_models,
  )

  # =============================================================================
@@ -59,13 +44,15 @@ DEFAULT_HYPERPARAMETERS = {
      # Training parameters
      "n_folds": 5,
      "max_epochs": 200,
-     "early_stopping_patience": 20,
+     "early_stopping_patience": 30,
      "batch_size": 128,
-     # Model architecture
-     "layers": "256-128-64",
+     # Model architecture (larger capacity - ensemble provides regularization)
+     "layers": "512-256-128",
      "learning_rate": 1e-3,
-     "dropout": 0.1,
+     "dropout": 0.05,
      "use_batch_norm": True,
+     # Loss function for regression (L1Loss=MAE, MSELoss=MSE, HuberLoss, SmoothL1Loss)
+     "loss": "L1Loss",
      # Random seed
      "seed": 42,
  }
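The new "loss" hyperparameter names a torch.nn loss class (L1Loss, MSELoss, HuberLoss, SmoothL1Loss). The diff does not show how pytorch_utils.train_model consumes it, but a plausible name-to-criterion resolution looks like this (an assumption, not the actual implementation):

import torch.nn as nn

def resolve_loss(name: str = "L1Loss") -> nn.Module:
    """Map a loss name such as 'L1Loss', 'MSELoss', 'HuberLoss', or 'SmoothL1Loss' to an instance."""
    loss_cls = getattr(nn, name, None)
    if loss_cls is None:
        raise ValueError(f"Unknown loss function: {name}")
    return loss_cls()

criterion = resolve_loss("L1Loss")  # nn.L1Loss(), i.e. mean absolute error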
@@ -74,10 +61,10 @@ DEFAULT_HYPERPARAMETERS = {
  TEMPLATE_PARAMS = {
      "model_type": "uq_regressor",
      "target": "udm_asy_res_efflux_ratio",
- "features": ['smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'nbase', 'vsa_estate3', 'fr_guanido', 'mollogp', 'peoe_vsa8', 'peoe_vsa1', 'fr_imine', 'vsa_estate2', 'estate_vsa10', 'asphericity', 'xc_3dv', 'smr_vsa3', 'charge_centroid_distance', 'c3sp3', 'nitrogen_span', 'estate_vsa2', 'minpartialcharge', 'hba_hbd_ratio', 'slogp_vsa1', 'axp_7d', 'nocount', 'vsa_estate4', 'vsa_estate6', 'estate_vsa4', 'xc_4dv', 'xc_4d', 'num_s_centers', 'vsa_estate9', 'chi2v', 'axp_5d', 'mi', 'mse', 'bcut2d_mrhi', 'smr_vsa6', 'hallkieralpha', 'balabanj', 'amphiphilic_moment', 'type_ii_pattern_count', 'minabsestateindex', 'bcut2d_mwlow', 'axp_0dv', 'slogp_vsa5', 'axp_2d', 'axp_1dv', 'xch_5d', 'peoe_vsa10', 'molecular_asymmetry', 'kappa3', 'estate_vsa3', 'sse', 'bcut2d_logphi', 'fr_imidazole', 'molecular_volume_3d', 'bertzct', 'maxestateindex', 'aromatic_interaction_score', 'axp_3d', 'radius_of_gyration', 'vsa_estate7', 'si', 'axp_5dv', 'molecular_axis_length', 'estate_vsa6', 'fpdensitymorgan1', 'axp_6d', 'estate_vsa9', 'fpdensitymorgan2', 'xp_0dv', 'xp_6dv', 'molmr', 'qed', 'estate_vsa8', 'peoe_vsa9', 'xch_6dv', 'xp_7d', 'slogp_vsa2', 'xp_5dv', 'bcut2d_chghi', 'xch_6d', 'chi0n', 'slogp_vsa3', 'chi1v', 'chi3v', 'bcut2d_chglo', 'axp_1d', 'mp', 'num_defined_stereocenters', 'xp_3dv', 'bcut2d_mrlow', 'fr_al_oh', 'peoe_vsa7', 'chi2n', 'axp_6dv', 'axp_2dv', 'chi4n', 'xc_3d', 'axp_7dv', 'vsa_estate8', 'xch_7d', 'maxpartialcharge', 'chi1n', 'peoe_vsa2', 'axp_3dv', 'bcut2d_logplow', 'mv', 'xpc_5dv', 'kappa2', 'vsa_estate5', 'xp_5d', 'mm', 'maxabspartialcharge', 'axp_4dv', 'maxabsestateindex', 'axp_4d', 'xch_4dv', 'xp_2dv', 'heavyatommolwt', 'numatomstereocenters', 'xp_7dv', 'numsaturatedheterocycles', 'xp_3d', 'kappa1', 'mz', 'axp_0d', 'chi1', 'xch_4d', 'smr_vsa1', 'xp_2d', 'estate_vsa5', 'phi', 'fr_ether', 'xc_5d', 'c1sp3', 'estate_vsa7', 'estate_vsa1', 'vsa_estate1', 'slogp_vsa4', 'avgipc', 'smr_vsa10', 'numvalenceelectrons', 'xc_5dv', 'peoe_vsa12', 'peoe_vsa6', 'xpc_5d', 'xpc_6d', 'minestateindex', 'chi3n', 'smr_vsa5', 'xp_4d', 'numheteroatoms', 'fpdensitymorgan3', 'xpc_4d', 'sps', 'xp_1d', 'sv', 'fr_ar_n', 'slogp_vsa10', 'c2sp3', 'xpc_4dv', 'chi0v', 'xpc_6dv', 'xp_1dv', 'vsa_estate10', 'sare', 'c2sp2', 'mpe', 'xch_7dv', 'chi4v', 'type_i_pattern_count', 'sp', 'slogp_vsa8', 'amide_count', 'num_stereocenters', 'num_r_centers', 'tertiary_amine_count', 'spe', 'xp_4dv', 'numsaturatedrings', 'mare', 'numhacceptors', 'chi0', 'fractioncsp3', 'fr_nh0', 'xch_5dv', 'fr_aniline', 'smr_vsa7', 'labuteasa', 'c3sp2', 'xp_0d', 'xp_6d', 'peoe_vsa11', 'fr_ar_nh', 'molwt', 'intramolecular_hbond_potential', 'peoe_vsa3', 'fr_nhpyrrole', 'numaliphaticrings', 'hybratio', 'smr_vsa9', 'peoe_vsa13', 'bcut2d_mwhi', 'c1sp2', 'slogp_vsa11', 'numrotatablebonds', 'numaliphaticcarbocycles', 'slogp_vsa6', 'peoe_vsa4', 'numunspecifiedatomstereocenters', 'xc_6d', 'xc_6dv', 'num_unspecified_stereocenters', 'sz', 'minabspartialcharge', 'fcsp3', 'c1sp1', 'fr_piperzine', 'numaliphaticheterocycles', 'numamidebonds', 'fr_benzene', 'numaromaticheterocycles', 'sm', 'fr_priamide', 'fr_piperdine', 'fr_methoxy', 'c4sp3', 'fr_c_o_nocoo', 'exactmolwt', 'stereo_complexity', 'fr_hoccn', 'numaromaticcarbocycles', 'fr_nh2', 'numheterocycles', 'fr_morpholine', 'fr_ketone', 'fr_nh1', 'frac_defined_stereo', 'fr_aryl_methyl', 'fr_alkyl_halide', 'fr_phenol', 'fr_al_oh_notert', 'fr_ar_oh', 'fr_pyridine', 'fr_amide', 'slogp_vsa7', 'fr_halogen', 'numsaturatedcarbocycles', 'slogp_vsa12', 'fr_ndealkylation1', 'xch_3d', 'fr_bicyclic', 'naromatom', 'narombond'],
+     "features": ['fingerprint'],
      "id_column": "udm_mol_bat_id",
-     "compressed_features": [],
-     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-pytorch/training",
+     "compressed_features": ['fingerprint'],
+     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-fp-pytorch/training",
      "hyperparameters": {},
  }

@@ -86,7 +73,7 @@ TEMPLATE_PARAMS = {
  # Model Loading (for SageMaker inference)
  # =============================================================================
  def model_fn(model_dir: str) -> dict:
-     """Load TabularMLP ensemble from the specified directory."""
+     """Load PyTorch TabularMLP ensemble from the specified directory."""
      # Load ensemble metadata
      metadata_path = os.path.join(model_dir, "ensemble_metadata.joblib")
      if os.path.exists(metadata_path):
@@ -129,7 +116,7 @@ def model_fn(model_dir: str) -> dict:
  # Inference (for SageMaker inference)
  # =============================================================================
  def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
-     """Make predictions with TabularMLP ensemble."""
+     """Make predictions with PyTorch TabularMLP ensemble."""
      model_type = TEMPLATE_PARAMS["model_type"]
      compressed_features = TEMPLATE_PARAMS["compressed_features"]
      model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
@@ -233,6 +220,36 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  # Training
  # =============================================================================
  if __name__ == "__main__":
+     # -------------------------------------------------------------------------
+     # Training-only imports (deferred to reduce serverless startup time)
+     # -------------------------------------------------------------------------
+     import argparse
+
+     import awswrangler as wr
+     from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+     from sklearn.preprocessing import LabelEncoder
+
+     # Enable Tensor Core optimization for GPUs that support it
+     torch.set_float32_matmul_precision("medium")
+
+     from model_script_utils import (
+         check_dataframe,
+         compute_classification_metrics,
+         compute_regression_metrics,
+         print_classification_metrics,
+         print_confusion_matrix,
+         print_regression_metrics,
+     )
+     from pytorch_utils import (
+         create_model,
+         save_model,
+         train_model,
+     )
+     from uq_harness import (
+         save_uq_models,
+         train_uq_models,
+     )
+
      # -------------------------------------------------------------------------
      # Setup: Parse arguments and load data
      # -------------------------------------------------------------------------
@@ -377,6 +394,7 @@ if __name__ == "__main__":
              patience=hyperparameters["early_stopping_patience"],
              batch_size=hyperparameters["batch_size"],
              learning_rate=hyperparameters["learning_rate"],
+             loss=hyperparameters.get("loss", "L1Loss"),
              device=device,
          )
          ensemble_models.append(model)
@@ -148,12 +148,16 @@ def convert_categorical_types(
  def decompress_features(
      df: pd.DataFrame, features: list[str], compressed_features: list[str]
  ) -> tuple[pd.DataFrame, list[str]]:
-     """Decompress bitstring features into individual bit columns.
+     """Decompress compressed features (bitstrings or count vectors) into individual columns.
+
+     Supports two formats (auto-detected):
+     - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
+     - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)

      Args:
          df: The features DataFrame
          features: Full list of feature names
-         compressed_features: List of feature names to decompress (bitstrings)
+         compressed_features: List of feature names to decompress

      Returns:
          Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@ def decompress_features(
          # Remove the feature from the list to avoid duplication
          decompressed_features.remove(feature)

-         # Handle all compressed features as bitstrings
-         bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
-         prefix = feature[:3]
+         # Auto-detect format and parse: comma-separated counts or bitstring
+         sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
+         parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
+         feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)

-         # Create all new columns at once - avoids fragmentation
-         new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
-         new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
+         # Create new columns with prefix from feature name
+         prefix = feature[:3]
+         new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
+         new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)

-         # Add to features list
+         # Update features list and dataframe
          decompressed_features.extend(new_col_names)
-
-         # Drop original column and concatenate new ones
          df = df.drop(columns=[feature])
          df = pd.concat([df, new_df], axis=1)
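For reference, here is a standalone sketch of the new auto-detection logic in decompress_features applied to a toy DataFrame. The column names and values are invented, and the snippet mirrors the diff above rather than calling the library function itself:

import numpy as np
import pandas as pd

df = pd.DataFrame({
    "id": [1, 2],
    "fingerprint": ["10110", "01101"],  # bitstrings; "0,3,0,1,5"-style count vectors also work
})

feature = "fingerprint"
# Peek at one non-null value: commas mean a count vector, otherwise treat it as a bitstring
sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)

# Expand into one uint8 column per position, prefixed by the first 3 chars of the feature name
cols = [f"{feature[:3]}_{i}" for i in range(matrix.shape[1])]
df = pd.concat([df.drop(columns=[feature]), pd.DataFrame(matrix, columns=cols, index=df.index)], axis=1)
print(df)  # columns: id, fin_0 ... fin_4 holding 0/1 values (or 0-255 counts)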