workbench 0.8.174-py3-none-any.whl → 0.8.227-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
+ import pandas as pd
+ import numpy as np
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.neighbors import NearestNeighbors
+ from typing import List, Optional
+ import logging
+
+ # Workbench Imports
+ from workbench.algorithms.dataframe.proximity import Proximity
+ from workbench.algorithms.dataframe.projection_2d import Projection2D
+
+ # Set up logging
+ log = logging.getLogger("workbench")
+
+
+ class FeatureSpaceProximity(Proximity):
+     """Proximity computations for numeric feature spaces using Euclidean distance."""
+
+     def __init__(
+         self,
+         df: pd.DataFrame,
+         id_column: str,
+         features: List[str],
+         target: Optional[str] = None,
+         include_all_columns: bool = False,
+     ):
+         """
+         Initialize the FeatureSpaceProximity class.
+
+         Args:
+             df: DataFrame containing data for neighbor computations.
+             id_column: Name of the column used as the identifier.
+             features: List of feature column names to be used for neighbor computations.
+             target: Name of the target column. Defaults to None.
+             include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
+         """
+         # Validate and filter features before calling parent init
+         self._raw_features = features
+         super().__init__(
+             df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+         )
+
+     def _prepare_data(self) -> None:
+         """Filter out non-numeric features and drop NaN rows."""
+         # Validate features
+         self.features = self._validate_features(self.df, self._raw_features)
+
+         # Drop NaN rows for the features we're using
+         self.df = self.df.dropna(subset=self.features).copy()
+
+     def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+         """Remove non-numeric features and log warnings."""
+         non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+         if non_numeric:
+             log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+         return [f for f in features if f not in non_numeric]
+
+     def _build_model(self) -> None:
+         """Standardize features and fit Nearest Neighbors model."""
+         self.scaler = StandardScaler()
+         X = self.scaler.fit_transform(self.df[self.features])
+         self.nn = NearestNeighbors().fit(X)
+
+     def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+         """Transform features using the fitted scaler."""
+         return self.scaler.transform(df[self.features])
+
+     def _project_2d(self) -> None:
+         """Project the numeric features to 2D for visualization."""
+         if len(self.features) >= 2:
+             self.df = Projection2D().fit_transform(self.df, features=self.features)
+
+
+ # Testing the FeatureSpaceProximity class
+ if __name__ == "__main__":
+
+     pd.set_option("display.max_columns", None)
+     pd.set_option("display.width", 1000)
+
+     # Create a sample DataFrame
+     data = {
+         "ID": [1, 2, 3, 4, 5],
+         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+         "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
+     }
+     df = pd.DataFrame(data)
+
+     # Test the FeatureSpaceProximity class
+     features = ["Feature1", "Feature2", "Feature3"]
+     prox = FeatureSpaceProximity(df, id_column="ID", features=features)
+     print(prox.neighbors(1, n_neighbors=2))
+
+     # Test the neighbors method with radius
+     print(prox.neighbors(1, radius=2.0))
+
+     # Test with Features list
+     prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
+     print(prox.neighbors(1))
+
+     # Create a sample DataFrame
+     data = {
+         "id": ["a", "b", "c", "d", "e"],  # Testing string IDs
+         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+         "target": [1, 0, 1, 0, 5],
+     }
+     df = pd.DataFrame(data)
+
+     # Test with String Ids
+     prox = FeatureSpaceProximity(
+         df,
+         id_column="id",
+         features=["Feature1", "Feature2"],
+         target="target",
+         include_all_columns=True,
+     )
+     print(prox.neighbors(["a", "b"]))
+
+     # Test duplicate IDs
+     data = {
+         "id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
+         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+         "target": [1, 0, 1, 0, 5],
+     }
+     df = pd.DataFrame(data)
+     prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
+     print(df.equals(prox.df))
+
+     # Test on real data from Workbench
+     from workbench.api import FeatureSet, Model
+
+     fs = FeatureSet("aqsol_features")
+     model = Model("aqsol-regression")
+     features = model.features()
+     df = fs.pull_dataframe()
+     prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
+     print("\n" + "=" * 80)
+     print("Testing Neighbors...")
+     print("=" * 80)
+     test_id = df[fs.id_column].tolist()[0]
+     print(f"\nNeighbors for ID {test_id}:")
+     print(prox.neighbors(test_id))
+
+     print("\n" + "=" * 80)
+     print("Testing isolated_compounds...")
+     print("=" * 80)
+
+     # Test isolated data in the top 1%
+     isolated_1pct = prox.isolated(top_percent=1.0)
+     print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+     print(isolated_1pct)
+
+     # Test isolated data in the top 5%
+     isolated_5pct = prox.isolated(top_percent=5.0)
+     print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+     print(isolated_5pct)
+
+     print("\n" + "=" * 80)
+     print("Testing target_gradients...")
+     print("=" * 80)
+
+     # Test with different parameters
+     gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+     print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
+     print(gradients_1pct)
+
+     gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+     print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+     print(gradients_5pct)
+
+     # Test proximity_stats
+     print("\n" + "=" * 80)
+     print("Testing proximity_stats...")
+     print("=" * 80)
+     stats = prox.proximity_stats()
+     print(stats)
+
+     # Plot the distance distribution using pandas
+     print("\n" + "=" * 80)
+     print("Plotting distance distribution...")
+     print("=" * 80)
+     prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
+
+     # Visualize the 2D projection
+     print("\n" + "=" * 80)
+     print("Visualizing 2D Projection...")
+     print("=" * 80)
+     from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
+     from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
+
+     unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
+     unit_test.run()
@@ -9,7 +9,7 @@ from sklearn.model_selection import train_test_split
  TEMPLATE_PARAMS = {
      "features": "{{feature_list}}",
      "target": "{{target_column}}",
-     "train_all_data": "{{train_all_data}}"
+     "train_all_data": "{{train_all_data}}",
  }

  from io import StringIO
@@ -33,7 +33,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
      """
      Matches and renames DataFrame columns to match model feature names (case-insensitive).
      Prioritizes exact matches, then case-insensitive matches.
-
+
      Raises ValueError if any model features cannot be matched.
      """
      df_columns_lower = {col.lower(): col for col in df.columns}
@@ -46,7 +46,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
              rename_dict[df_columns_lower[feature.lower()]] = feature
          else:
              missing.append(feature)
-
+
      if missing:
          raise ValueError(f"Features not found: {missing}")

@@ -76,10 +76,7 @@ if __name__ == "__main__":
      args = parser.parse_args()

      # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
+     training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
      df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

      # Check if the DataFrame is empty
@@ -112,10 +109,7 @@ if __name__ == "__main__":
      )

      # Create a Pipeline with StandardScaler
-     model = Pipeline([
-         ("scaler", StandardScaler()),
-         ("model", model)
-     ])
+     model = Pipeline([("scaler", StandardScaler()), ("model", model)])

      # Prepare features and targets for training
      X_train = df_train[features]
@@ -3,11 +3,8 @@ from ngboost import NGBRegressor
  from sklearn.model_selection import train_test_split

  # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
+ from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
+ from scipy.stats import spearmanr

  from io import StringIO
  import json
@@ -21,7 +18,7 @@ import pandas as pd
  TEMPLATE_PARAMS = {
      "features": "{{feature_list}}",
      "target": "{{target_column}}",
-     "train_all_data": "{{train_all_data}}"
+     "train_all_data": "{{train_all_data}}",
  }


@@ -87,10 +84,7 @@ if __name__ == "__main__":
      args = parser.parse_args()

      # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
+     training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
      print(f"Training Files: {training_files}")

      # Combine files and read them all into a single pandas dataframe
@@ -136,11 +130,16 @@ if __name__ == "__main__":
      # Calculate various model performance metrics (regression)
      rmse = root_mean_squared_error(y_validate, preds)
      mae = mean_absolute_error(y_validate, preds)
+     medae = median_absolute_error(y_validate, preds)
      r2 = r2_score(y_validate, preds)
-     print(f"RMSE: {rmse:.3f}")
-     print(f"MAE: {mae:.3f}")
-     print(f"R2: {r2:.3f}")
-     print(f"NumRows: {len(df_val)}")
+     spearman_corr = spearmanr(y_validate, preds).correlation
+     support = len(df_val)
+     print(f"rmse: {rmse:.3f}")
+     print(f"mae: {mae:.3f}")
+     print(f"medae: {medae:.3f}")
+     print(f"r2: {r2:.3f}")
+     print(f"spearmanr: {spearman_corr:.3f}")
+     print(f"support: {support}")

      # Save the trained NGBoost model
      joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))
@@ -212,8 +211,8 @@ def predict_fn(df, model) -> pd.DataFrame:
      dist_params = y_dists.params

      # Extract mean and std from distribution parameters
-     df["prediction"] = dist_params['loc']  # mean
-     df["prediction_std"] = dist_params['scale']  # standard deviation
+     df["prediction"] = dist_params["loc"]  # mean
+     df["prediction_std"] = dist_params["scale"]  # standard deviation

      # Add 95% prediction intervals using ppf (percent point function)
      df["q_025"] = y_dists.ppf(0.025)  # 2.5th percentile
@@ -3,7 +3,7 @@ TEMPLATE_PARAMS = {
      "model_type": "{{model_type}}",
      "target_column": "{{target_column}}",
      "feature_list": "{{feature_list}}",
-     "model_metrics_s3_path": "{{model_metrics_s3_path}}"
+     "model_metrics_s3_path": "{{model_metrics_s3_path}}",
  }

  # Imports for XGB Model
@@ -12,11 +12,8 @@ import awswrangler as wr
  import numpy as np

  # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
+ from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
+ from scipy.stats import spearmanr

  from io import StringIO
  import json
@@ -39,6 +36,7 @@ def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
          print(msg)
          raise ValueError(msg)

+
  def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
      """
      Matches and renames the DataFrame's column names to match the model's feature names (case-insensitive).
@@ -95,11 +93,7 @@ if __name__ == "__main__":
      args = parser.parse_args()

      # Read the training data into DataFrames
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train)
-         if file.endswith(".csv")
-     ]
+     training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
      print(f"Training Files: {training_files}")

      # Combine files and read them all into a single pandas dataframe
@@ -150,7 +144,6 @@ if __name__ == "__main__":
      result_df["residual"] = result_df[target] - result_df["prediction"]
      result_df["residual_abs"] = result_df["residual"].abs()

-
      # Save the results dataframe to S3
      wr.s3.to_csv(
          result_df,
@@ -161,11 +154,16 @@ if __name__ == "__main__":
      # Report Performance Metrics
      rmse = root_mean_squared_error(result_df[target], result_df["prediction"])
      mae = mean_absolute_error(result_df[target], result_df["prediction"])
+     medae = median_absolute_error(result_df[target], result_df["prediction"])
      r2 = r2_score(result_df[target], result_df["prediction"])
-     print(f"RMSE: {rmse:.3f}")
-     print(f"MAE: {mae:.3f}")
-     print(f"R2: {r2:.3f}")
-     print(f"NumRows: {len(result_df)}")
+     spearman_corr = spearmanr(result_df[target], result_df["prediction"]).correlation
+     support = len(result_df)
+     print(f"rmse: {rmse:.3f}")
+     print(f"mae: {mae:.3f}")
+     print(f"medae: {medae:.3f}")
+     print(f"r2: {r2:.3f}")
+     print(f"spearmanr: {spearman_corr:.3f}")
+     print(f"support: {support}")

      # Now save the models
      for name, model in models.items():
@@ -210,7 +208,7 @@ def input_fn(input_data, content_type):
      """Parse input data and return a DataFrame."""
      if not input_data:
          raise ValueError("Empty input data is not supported!")
-
+
      # Decode bytes to string if necessary
      if isinstance(input_data, bytes):
          input_data = input_data.decode("utf-8")
@@ -0,0 +1,209 @@
+ # Meta Model Template for Workbench
+ #
+ # NOTE: This is called a "meta model" but it's really a "meta endpoint" - it aggregates
+ # predictions from multiple child endpoints. We call it a "model" because Workbench
+ # creates Model artifacts that get deployed as Endpoints, so this follows that pattern.
+ #
+ # Assumptions/Shortcuts:
+ # - All child endpoints are regression models
+ # - All child endpoints output 'prediction' and 'confidence' columns
+ # - Aggregation uses model weights (provided at meta model creation time)
+ #
+ # This template:
+ # - Has no real training phase (just saves metadata including model weights)
+ # - At inference time, calls child endpoints and aggregates their predictions
+
+ import argparse
+ import json
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from io import StringIO
+
+ import pandas as pd
+
+ from workbench_bridges.endpoints.fast_inference import fast_inference
+
+ # Template parameters (filled in by Workbench)
+ TEMPLATE_PARAMS = {
+     "child_endpoints": ['logd-reg-pytorch', 'logd-reg-chemprop'],
+     "target_column": "logd",
+     "model_weights": {'logd-reg-pytorch': 0.4228205813233993, 'logd-reg-chemprop': 0.5771794186766008},
+     "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/logd-meta/training",
+     "aws_region": "us-west-2",
+ }
+
+
+ def invoke_endpoints_parallel(endpoint_names: list[str], df: pd.DataFrame) -> dict[str, pd.DataFrame]:
+     """Call multiple child endpoints in parallel and collect their results.
+
+     Args:
+         endpoint_names: List of endpoint names to call
+         df: Input DataFrame to send to each endpoint
+
+     Returns:
+         Dict mapping endpoint_name -> result DataFrame (or None if failed)
+     """
+     results = {}
+
+     def call_endpoint(name: str) -> tuple[str, pd.DataFrame | None]:
+         try:
+             return name, fast_inference(name, df)
+         except Exception as e:
+             print(f"Error calling endpoint {name}: {e}")
+             return name, None
+
+     with ThreadPoolExecutor(max_workers=len(endpoint_names)) as executor:
+         futures = {executor.submit(call_endpoint, name): name for name in endpoint_names}
+         for future in as_completed(futures):
+             name, result = future.result()
+             results[name] = result
+
+     return results
+
+
+ def aggregate_predictions(results: dict[str, pd.DataFrame], model_weights: dict[str, float]) -> pd.DataFrame:
+     """Aggregate predictions from multiple endpoints using model weights.
+
+     Args:
+         results: Dict mapping endpoint_name -> predictions DataFrame
+             Each DataFrame must have 'prediction' and 'confidence' columns
+         model_weights: Dict mapping endpoint_name -> weight
+
+     Returns:
+         DataFrame with aggregated prediction, prediction_std, and confidence
+     """
+     # Filter out failed endpoints
+     valid_results = {k: v for k, v in results.items() if v is not None}
+     if not valid_results:
+         raise ValueError("All child endpoints failed")
+
+     # Use first result as base (for id columns, etc.)
+     first_df = list(valid_results.values())[0]
+     output_df = first_df.drop(columns=["prediction", "confidence", "prediction_std"], errors="ignore").copy()
+
+     # Build DataFrames of predictions and confidences from all endpoints
+     pred_df = pd.DataFrame({name: df["prediction"] for name, df in valid_results.items()})
+     conf_df = pd.DataFrame({name: df["confidence"] for name, df in valid_results.items()})
+
+     # Apply model weights (renormalize for valid endpoints only)
+     valid_weights = {k: model_weights.get(k, 1.0) for k in valid_results}
+     weight_sum = sum(valid_weights.values())
+     normalized_weights = {k: v / weight_sum for k, v in valid_weights.items()}
+
+     # Weighted average
+     output_df["prediction"] = sum(pred_df[name] * w for name, w in normalized_weights.items())
+
+     # Ensemble std across child endpoints
+     output_df["prediction_std"] = pred_df.std(axis=1)
+
+     # Aggregated confidence: weighted mean of child confidences
+     output_df["confidence"] = sum(conf_df[name] * w for name, w in normalized_weights.items())
+
+     return output_df
+
+
+ # =============================================================================
+ # Model Loading (for SageMaker inference)
+ # =============================================================================
+ def model_fn(model_dir: str) -> dict:
+     """Load meta model configuration."""
+     with open(os.path.join(model_dir, "meta_config.json")) as f:
+         config = json.load(f)
+
+     # Set AWS_REGION for fast_inference (baked in at training time)
+     if config.get("aws_region"):
+         os.environ["AWS_REGION"] = config["aws_region"]
+
+     print(f"Meta model loaded: {len(config['child_endpoints'])} child endpoints")
+     print(f"Model weights: {config.get('model_weights')}")
+     print(f"AWS region: {config.get('aws_region')}")
+     return config
+
+
+ def input_fn(input_data, content_type):
+     """Parse input data and return a DataFrame."""
+     if not input_data:
+         raise ValueError("Empty input data is not supported!")
+
+     # Decode bytes to string if necessary
+     if isinstance(input_data, bytes):
+         input_data = input_data.decode("utf-8")
+
+     if "text/csv" in content_type:
+         return pd.read_csv(StringIO(input_data))
+     elif "application/json" in content_type:
+         return pd.DataFrame(json.loads(input_data))
+     else:
+         raise ValueError(f"{content_type} not supported!")
+
+
+ def output_fn(output_df, accept_type):
+     """Supports both CSV and JSON output formats."""
+     if "text/csv" in accept_type:
+         return output_df.to_csv(index=False), "text/csv"
+     elif "application/json" in accept_type:
+         return output_df.to_json(orient="records"), "application/json"
+     else:
+         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
+
+
+ # =============================================================================
+ # Inference (for SageMaker inference)
+ # =============================================================================
+ def predict_fn(df: pd.DataFrame, config: dict) -> pd.DataFrame:
+     """Run inference by calling child endpoints and aggregating results."""
+     child_endpoints = config["child_endpoints"]
+     model_weights = config.get("model_weights", {})
+
+     print(f"Calling {len(child_endpoints)} child endpoints: {child_endpoints}")
+
+     # Call all child endpoints
+     results = invoke_endpoints_parallel(child_endpoints, df)
+
+     # Report status
+     for name, result in results.items():
+         status = f"{len(result)} rows" if result is not None else "FAILED"
+         print(f"  {name}: {status}")
+
+     # Aggregate predictions using model weights
+     output_df = aggregate_predictions(results, model_weights)
+
+     print(f"Aggregated {len(output_df)} predictions from {len(results)} endpoints")
+     return output_df
+
+
+ # =============================================================================
+ # Training (just saves configuration - no actual training)
+ # =============================================================================
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
+     parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data"))
+     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
+     args = parser.parse_args()
+
+     child_endpoints = TEMPLATE_PARAMS["child_endpoints"]
+     target_column = TEMPLATE_PARAMS["target_column"]
+     model_weights = TEMPLATE_PARAMS["model_weights"]
+     aws_region = TEMPLATE_PARAMS["aws_region"]
+
+     print("=" * 60)
+     print("Meta Model Configuration")
+     print("=" * 60)
+     print(f"Child endpoints: {child_endpoints}")
+     print(f"Target column: {target_column}")
+     print(f"Model weights: {model_weights}")
+     print(f"AWS region: {aws_region}")
+
+     # Save configuration for inference
+     config = {
+         "child_endpoints": child_endpoints,
+         "target_column": target_column,
+         "model_weights": model_weights,
+         "aws_region": aws_region,
+     }
+
+     with open(os.path.join(args.model_dir, "meta_config.json"), "w") as f:
+         json.dump(config, f, indent=2)
+
+     print(f"\nMeta model configuration saved to {args.model_dir}")
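For context on the meta model hunk above: aggregation is a weighted average of child-endpoint predictions, with the weights renormalized over whichever endpoints actually responded. The following is an illustrative sketch only, not part of the package; the endpoint names, weights, and values are made up, and the toy DataFrames stand in for fast_inference() results.

import pandas as pd

# Hypothetical child-endpoint outputs (stand-ins for fast_inference() results)
results = {
    "child-a": pd.DataFrame({"prediction": [1.0, 2.0], "confidence": [0.9, 0.8]}),
    "child-b": pd.DataFrame({"prediction": [3.0, 4.0], "confidence": [0.5, 0.6]}),
}
weights = {"child-a": 0.25, "child-b": 0.75}

# Renormalize weights over the endpoints that responded (here: both)
total = sum(weights.values())
norm = {name: w / total for name, w in weights.items()}

# Weighted-average prediction/confidence plus ensemble std across children
pred_df = pd.DataFrame({name: df["prediction"] for name, df in results.items()})
conf_df = pd.DataFrame({name: df["confidence"] for name, df in results.items()})
out = pd.DataFrame()
out["prediction"] = sum(pred_df[name] * w for name, w in norm.items())
out["prediction_std"] = pred_df.std(axis=1)
out["confidence"] = sum(conf_df[name] * w for name, w in norm.items())
print(out)  # row 0: prediction = 0.25 * 1.0 + 0.75 * 3.0 = 2.5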