workbench 0.8.224__py3-none-any.whl → 0.8.231__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  4. workbench/algorithms/sql/column_stats.py +0 -1
  5. workbench/algorithms/sql/correlations.py +0 -1
  6. workbench/algorithms/sql/descriptive_stats.py +0 -1
  7. workbench/api/meta.py +0 -1
  8. workbench/cached/cached_meta.py +0 -1
  9. workbench/cached/cached_model.py +37 -7
  10. workbench/core/artifacts/endpoint_core.py +12 -2
  11. workbench/core/artifacts/feature_set_core.py +66 -8
  12. workbench/core/cloud_platform/cloud_meta.py +0 -1
  13. workbench/model_script_utils/model_script_utils.py +30 -0
  14. workbench/model_script_utils/uq_harness.py +0 -1
  15. workbench/model_scripts/chemprop/chemprop.template +3 -0
  16. workbench/model_scripts/chemprop/generated_model_script.py +3 -3
  17. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  18. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  19. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  20. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  21. workbench/model_scripts/pytorch_model/generated_model_script.py +50 -32
  22. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  23. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  24. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  25. workbench/model_scripts/script_generation.py +0 -1
  26. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  28. workbench/themes/dark/custom.css +85 -8
  29. workbench/themes/dark/plotly.json +6 -6
  30. workbench/themes/light/custom.css +172 -70
  31. workbench/themes/light/plotly.json +9 -9
  32. workbench/themes/midnight_blue/custom.css +48 -29
  33. workbench/themes/midnight_blue/plotly.json +1 -1
  34. workbench/utils/aws_utils.py +0 -1
  35. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  36. workbench/utils/chem_utils/vis.py +137 -27
  37. workbench/utils/clientside_callbacks.py +41 -0
  38. workbench/utils/markdown_utils.py +57 -0
  39. workbench/utils/pipeline_utils.py +0 -1
  40. workbench/utils/plot_utils.py +8 -110
  41. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  42. workbench/web_interface/components/model_plot.py +2 -0
  43. workbench/web_interface/components/plugin_unit_test.py +0 -1
  44. workbench/web_interface/components/plugins/ag_table.py +2 -4
  45. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  46. workbench/web_interface/components/plugins/model_details.py +10 -6
  47. workbench/web_interface/components/plugins/scatter_plot.py +56 -43
  48. workbench/web_interface/components/settings_menu.py +2 -1
  49. workbench/web_interface/page_views/main_page.py +0 -1
  50. {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/METADATA +31 -29
  51. {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/RECORD +55 -59
  52. {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  53. workbench/themes/quartz/base_css.url +0 -1
  54. workbench/themes/quartz/custom.css +0 -117
  55. workbench/themes/quartz/plotly.json +0 -642
  56. workbench/themes/quartz_dark/base_css.url +0 -1
  57. workbench/themes/quartz_dark/custom.css +0 -131
  58. workbench/themes/quartz_dark/plotly.json +0 -642
  59. {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  60. {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  61. {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -59,12 +59,12 @@ DEFAULT_HYPERPARAMETERS = {
 
  # Template parameters (filled in by Workbench)
  TEMPLATE_PARAMS = {
- "model_type": "uq_regressor",
- "target": "udm_asy_res_efflux_ratio",
+ "model_type": "classifier",
+ "target": "class",
  "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
  "id_column": "udm_mol_bat_id",
  "compressed_features": [],
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-pytorch-260113/training",
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-class-pytorch-1-fr/training",
  "hyperparameters": {},
  }
 
@@ -152,24 +152,30 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  print("Decompressing features for prediction...")
  matched_df, features = decompress_features(matched_df, features, compressed_features)
 
- # Track missing features
- missing_mask = matched_df[features].isna().any(axis=1)
- if missing_mask.any():
- print(f"Warning: {missing_mask.sum()} rows have missing features")
+ # Impute missing values (categorical with mode, continuous handled by scaler)
+ missing_counts = matched_df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(f"Imputing missing values: {missing_features.to_dict()}")
+
+ # Load categorical imputation values if available
+ impute_path = os.path.join(model_dir, "categorical_impute.json")
+ if os.path.exists(impute_path):
+ with open(impute_path) as f:
+ cat_impute_values = json.load(f)
+ for col in categorical_cols:
+ if col in cat_impute_values and matched_df[col].isna().any():
+ matched_df[col] = matched_df[col].fillna(cat_impute_values[col])
+ # Continuous features are imputed by FeatureScaler.transform() using column means
 
  # Initialize output columns
  df["prediction"] = np.nan
  if model_type in ["regressor", "uq_regressor"]:
  df["prediction_std"] = np.nan
 
- complete_df = matched_df[~missing_mask].copy()
- if len(complete_df) == 0:
- print("Warning: No complete rows to predict on")
- return df
-
- # Prepare data for inference (with standardization)
+ # Prepare data for inference (with standardization and continuous imputation)
  x_cont, x_cat, _, _, _ = prepare_data(
- complete_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
+ matched_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
  )
 
  # Collect ensemble predictions
@@ -191,28 +197,20 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  class_preds = np.argmax(avg_probs, axis=1)
  predictions = label_encoder.inverse_transform(class_preds)
 
- all_proba = pd.Series([None] * len(df), index=df.index, dtype=object)
- all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
- df["pred_proba"] = all_proba
+ df["pred_proba"] = [p.tolist() for p in avg_probs]
  df = expand_proba_column(df, label_encoder.classes_)
  else:
  # Regression
  predictions = preds.flatten()
- df.loc[~missing_mask, "prediction_std"] = preds_std.flatten()
+ df["prediction_std"] = preds_std.flatten()
 
  # Add UQ intervals if available
  if uq_models and uq_metadata:
- X_complete = complete_df[features]
- df_complete = df.loc[~missing_mask].copy()
- df_complete["prediction"] = predictions # Set prediction before compute_confidence
- df_complete = predict_intervals(df_complete, X_complete, uq_models, uq_metadata)
- df_complete = compute_confidence(df_complete, uq_metadata["median_interval_width"], "q_10", "q_90")
- # Copy UQ columns back to main dataframe
- for col in df_complete.columns:
- if col.startswith("q_") or col == "confidence":
- df.loc[~missing_mask, col] = df_complete[col].values
-
- df.loc[~missing_mask, "prediction"] = predictions
+ df["prediction"] = predictions # Set prediction before compute_confidence
+ df = predict_intervals(df, matched_df[features], uq_models, uq_metadata)
+ df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+ df["prediction"] = predictions
  return df
 
 
@@ -275,11 +273,11 @@ if __name__ == "__main__":
  all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
  check_dataframe(all_df, "training_df")
 
- # Drop rows with missing features
+ # Drop rows with missing target (required for training)
  initial_count = len(all_df)
- all_df = all_df.dropna(subset=features)
+ all_df = all_df.dropna(subset=[target])
  if len(all_df) < initial_count:
- print(f"Dropped {initial_count - len(all_df)} rows with missing features")
+ print(f"Dropped {initial_count - len(all_df)} rows with missing target")
 
  print(f"Target: {target}")
  print(f"Features: {features}")
@@ -301,6 +299,23 @@ if __name__ == "__main__":
  print(f"Categorical: {categorical_cols}")
  print(f"Continuous: {len(continuous_cols)} columns")
 
+ # Report and handle missing values in features
+ # Compute categorical imputation values (mode) for use at inference time
+ cat_impute_values = {}
+ for col in categorical_cols:
+ mode_val = all_df[col].mode().iloc[0] if not all_df[col].mode().empty else all_df[col].cat.categories[0]
+ cat_impute_values[col] = str(mode_val) # Convert to string for JSON serialization
+
+ missing_counts = all_df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(f"Missing values in features (will be imputed): {missing_features.to_dict()}")
+ # Impute categorical features with mode (most frequent value)
+ for col in categorical_cols:
+ if all_df[col].isna().any():
+ all_df[col] = all_df[col].fillna(cat_impute_values[col])
+ # Continuous features are imputed by FeatureScaler.transform() using column means
+
  # -------------------------------------------------------------------------
  # Classification setup
  # -------------------------------------------------------------------------
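
The hunk above computes per-column modes for the categorical features and imputes the training frame with them; the next hunk persists those modes as categorical_impute.json, and predict_fn reloads them at inference time. A minimal sketch of that round trip, not part of the diff, with an illustrative column name and a local file path:

import json
import pandas as pd

# Training side: record the mode of an illustrative categorical column ("assay")
train_df = pd.DataFrame({"assay": ["A", "A", "B", None]})
cat_impute_values = {"assay": str(train_df["assay"].mode().iloc[0])}  # mode is "A"
with open("categorical_impute.json", "w") as f:
    json.dump(cat_impute_values, f)

# Inference side: reload the saved modes and fill gaps before encoding
infer_df = pd.DataFrame({"assay": ["B", None]})
with open("categorical_impute.json") as f:
    saved = json.load(f)
infer_df["assay"] = infer_df["assay"].fillna(saved["assay"])
print(infer_df["assay"].tolist())  # ['B', 'A']
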
@@ -506,6 +521,9 @@ if __name__ == "__main__":
  with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
  json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)
 
+ with open(os.path.join(args.model_dir, "categorical_impute.json"), "w") as f:
+ json.dump(cat_impute_values, f)
+
  with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
  json.dump(hyperparameters, f, indent=2)
 
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
  raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+ def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+ """Cap extreme outliers in prediction_std using IQR method.
+
+ Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+ This prevents unreasonably large std values while preserving the
+ relative ordering and keeping meaningful high-uncertainty signals.
+
+ Args:
+ std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+ Returns:
+ Array with outliers capped at the upper fence
+ """
+ if std_array.ndim == 1:
+ std_array = std_array.reshape(-1, 1)
+ squeeze = True
+ else:
+ squeeze = False
+
+ capped = std_array.copy()
+ for col in range(capped.shape[1]):
+ col_data = capped[:, col]
+ q1, q3 = np.percentile(col_data, [25, 75])
+ iqr = q3 - q1
+ upper_bound = q3 + 1.5 * iqr
+ capped[:, col] = np.minimum(col_data, upper_bound)
+
+ return capped.squeeze() if squeeze else capped
+
+
  def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
  """Compute standard regression metrics.
 
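For context, the IQR fence that cap_std_outliers applies, as a small standalone numpy sketch with toy values (not part of the diff):

import numpy as np

std = np.array([0.10, 0.12, 0.15, 0.20, 5.00])  # illustrative prediction_std values with one outlier
q1, q3 = np.percentile(std, [25, 75])
upper_bound = q3 + 1.5 * (q3 - q1)  # Q3 + 1.5*IQR = 0.32 for these values
print(np.minimum(std, upper_bound))  # [0.1  0.12 0.15 0.2  0.32] -- outlier capped, ordering preserved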
 
@@ -152,24 +152,30 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  print("Decompressing features for prediction...")
  matched_df, features = decompress_features(matched_df, features, compressed_features)
 
- # Track missing features
- missing_mask = matched_df[features].isna().any(axis=1)
- if missing_mask.any():
- print(f"Warning: {missing_mask.sum()} rows have missing features")
+ # Impute missing values (categorical with mode, continuous handled by scaler)
+ missing_counts = matched_df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(f"Imputing missing values: {missing_features.to_dict()}")
+
+ # Load categorical imputation values if available
+ impute_path = os.path.join(model_dir, "categorical_impute.json")
+ if os.path.exists(impute_path):
+ with open(impute_path) as f:
+ cat_impute_values = json.load(f)
+ for col in categorical_cols:
+ if col in cat_impute_values and matched_df[col].isna().any():
+ matched_df[col] = matched_df[col].fillna(cat_impute_values[col])
+ # Continuous features are imputed by FeatureScaler.transform() using column means
 
  # Initialize output columns
  df["prediction"] = np.nan
  if model_type in ["regressor", "uq_regressor"]:
  df["prediction_std"] = np.nan
 
- complete_df = matched_df[~missing_mask].copy()
- if len(complete_df) == 0:
- print("Warning: No complete rows to predict on")
- return df
-
- # Prepare data for inference (with standardization)
+ # Prepare data for inference (with standardization and continuous imputation)
  x_cont, x_cat, _, _, _ = prepare_data(
- complete_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
+ matched_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
  )
 
  # Collect ensemble predictions
@@ -191,28 +197,20 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
  class_preds = np.argmax(avg_probs, axis=1)
  predictions = label_encoder.inverse_transform(class_preds)
 
- all_proba = pd.Series([None] * len(df), index=df.index, dtype=object)
- all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
- df["pred_proba"] = all_proba
+ df["pred_proba"] = [p.tolist() for p in avg_probs]
  df = expand_proba_column(df, label_encoder.classes_)
  else:
  # Regression
  predictions = preds.flatten()
- df.loc[~missing_mask, "prediction_std"] = preds_std.flatten()
+ df["prediction_std"] = preds_std.flatten()
 
  # Add UQ intervals if available
  if uq_models and uq_metadata:
- X_complete = complete_df[features]
- df_complete = df.loc[~missing_mask].copy()
- df_complete["prediction"] = predictions # Set prediction before compute_confidence
- df_complete = predict_intervals(df_complete, X_complete, uq_models, uq_metadata)
- df_complete = compute_confidence(df_complete, uq_metadata["median_interval_width"], "q_10", "q_90")
- # Copy UQ columns back to main dataframe
- for col in df_complete.columns:
- if col.startswith("q_") or col == "confidence":
- df.loc[~missing_mask, col] = df_complete[col].values
-
- df.loc[~missing_mask, "prediction"] = predictions
+ df["prediction"] = predictions # Set prediction before compute_confidence
+ df = predict_intervals(df, matched_df[features], uq_models, uq_metadata)
+ df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+ df["prediction"] = predictions
  return df
 
 
@@ -275,11 +273,11 @@ if __name__ == "__main__":
  all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
  check_dataframe(all_df, "training_df")
 
- # Drop rows with missing features
+ # Drop rows with missing target (required for training)
  initial_count = len(all_df)
- all_df = all_df.dropna(subset=features)
+ all_df = all_df.dropna(subset=[target])
  if len(all_df) < initial_count:
- print(f"Dropped {initial_count - len(all_df)} rows with missing features")
+ print(f"Dropped {initial_count - len(all_df)} rows with missing target")
 
  print(f"Target: {target}")
  print(f"Features: {features}")
@@ -301,6 +299,23 @@ if __name__ == "__main__":
  print(f"Categorical: {categorical_cols}")
  print(f"Continuous: {len(continuous_cols)} columns")
 
+ # Report and handle missing values in features
+ # Compute categorical imputation values (mode) for use at inference time
+ cat_impute_values = {}
+ for col in categorical_cols:
+ mode_val = all_df[col].mode().iloc[0] if not all_df[col].mode().empty else all_df[col].cat.categories[0]
+ cat_impute_values[col] = str(mode_val) # Convert to string for JSON serialization
+
+ missing_counts = all_df[features].isna().sum()
+ if missing_counts.any():
+ missing_features = missing_counts[missing_counts > 0]
+ print(f"Missing values in features (will be imputed): {missing_features.to_dict()}")
+ # Impute categorical features with mode (most frequent value)
+ for col in categorical_cols:
+ if all_df[col].isna().any():
+ all_df[col] = all_df[col].fillna(cat_impute_values[col])
+ # Continuous features are imputed by FeatureScaler.transform() using column means
+
  # -------------------------------------------------------------------------
  # Classification setup
  # -------------------------------------------------------------------------
@@ -506,6 +521,9 @@ if __name__ == "__main__":
  with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
  json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)
 
+ with open(os.path.join(args.model_dir, "categorical_impute.json"), "w") as f:
+ json.dump(cat_impute_values, f)
+
  with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
  json.dump(hyperparameters, f, indent=2)
 
@@ -22,7 +22,6 @@ import joblib
  from lightgbm import LGBMRegressor
  from mapie.regression import ConformalizedQuantileRegressor
 
-
  # Default confidence levels for prediction intervals
  DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]
 
@@ -6,7 +6,6 @@ import logging
  from pathlib import Path
  import importlib.util
 
-
  # Setup the logger
  log = logging.getLogger("workbench")
 
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
  raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+ def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+ """Cap extreme outliers in prediction_std using IQR method.
+
+ Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+ This prevents unreasonably large std values while preserving the
+ relative ordering and keeping meaningful high-uncertainty signals.
+
+ Args:
+ std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+ Returns:
+ Array with outliers capped at the upper fence
+ """
+ if std_array.ndim == 1:
+ std_array = std_array.reshape(-1, 1)
+ squeeze = True
+ else:
+ squeeze = False
+
+ capped = std_array.copy()
+ for col in range(capped.shape[1]):
+ col_data = capped[:, col]
+ q1, q3 = np.percentile(col_data, [25, 75])
+ iqr = q3 - q1
+ upper_bound = q3 + 1.5 * iqr
+ capped[:, col] = np.minimum(col_data, upper_bound)
+
+ return capped.squeeze() if squeeze else capped
+
+
  def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
  """Compute standard regression metrics.
 
@@ -22,7 +22,6 @@ import joblib
  from lightgbm import LGBMRegressor
  from mapie.regression import ConformalizedQuantileRegressor
 
-
  # Default confidence levels for prediction intervals
  DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]
 
@@ -3,6 +3,7 @@ h1, h2, h3, h4 {
  }
  body {
  color: rgb(180, 180, 180); /* We want the text dim white */
+ background: linear-gradient(90deg, rgb(45, 45, 45) 0%, rgb(35, 35, 35) 100%);
  }
 
  /* Custom CSS to style bold text */
@@ -36,21 +37,38 @@ a:hover {
 
  /* AgGrid custom CSS */
 
- /* There's a one pixel border around the grid that we want to remove */
- .ag-root-wrapper {
- border: none !important; /* Force removal with !important */
- }
-
-
- /* Box shadow and rounded corners for all AgGrid themes */
+ /* AG Grid 33+ uses CSS variables for theming - set them at the theme level */
  [class*="ag-theme-"] {
+ --ag-background-color: rgb(40, 40, 40);
+ --ag-odd-row-background-color: rgb(40, 40, 40);
+ --ag-row-background-color: rgb(50, 50, 50);
+ --ag-selected-row-background-color: rgb(60, 70, 90);
+ --ag-row-hover-color: rgb(55, 55, 65);
+ --ag-header-background-color: rgb(35, 35, 35);
+ --ag-border-color: rgba(80, 80, 80, 0.5);
+ --ag-foreground-color: rgb(180, 180, 180);
+ --ag-header-foreground-color: rgb(220, 220, 220);
+ --ag-wrapper-border-radius: 12px;
+
+ /* Box shadow and rounded corners */
  box-shadow: 2px 2px 6px 5px rgba(0, 0, 0, 0.25);
- border-radius: 12px; /* Rounded corners */
+ border-radius: 12px;
  border: 0.5px solid rgba(0, 0, 0, 0.5);
  margin: 0;
  padding: 0;
  }
 
+ /* Remove border from the grid wrapper */
+ .ag-root-wrapper {
+ border: none !important;
+ }
+
+ /* AG Grid container - remove padding but allow shadow overflow */
+ div:has(> [class*="ag-theme-"]) {
+ padding: 0 !important;
+ overflow: visible !important;
+ }
+
  /* Apply styling to Workbench containers */
  .workbench-container {
  box-shadow: 2px 2px 6px 5px rgba(0, 0, 0, 0.25);
@@ -110,6 +128,40 @@ a:hover {
  color: rgb(100, 255, 100);
  }
 
+ /* Dropdown styling (dcc.Dropdown) - override Bootstrap's variables */
+ .dash-dropdown {
+ --bs-body-bg: rgb(35, 35, 35);
+ --bs-body-color: rgb(210, 210, 210);
+ --bs-border-color: rgb(60, 60, 60);
+ }
+
+ /* Bootstrap form controls (dbc components) */
+ .form-select, .form-control {
+ background-color: rgb(35, 35, 35) !important;
+ border: 1px solid rgb(60, 60, 60) !important;
+ color: rgb(210, 210, 210) !important;
+ }
+
+ .form-select:focus, .form-control:focus {
+ background-color: rgb(45, 45, 45) !important;
+ border-color: rgb(80, 80, 80) !important;
+ box-shadow: 0 0 0 0.2rem rgba(80, 80, 80, 0.25) !important;
+ }
+
+ .dropdown-menu {
+ background-color: rgb(35, 35, 35) !important;
+ border: 1px solid rgb(60, 60, 60) !important;
+ }
+
+ .dropdown-item {
+ color: rgb(210, 210, 210) !important;
+ }
+
+ .dropdown-item:hover, .dropdown-item:focus {
+ background-color: rgb(50, 50, 50) !important;
+ color: rgb(230, 230, 230) !important;
+ }
+
  /* Table styling */
  table {
  width: 100%;
@@ -128,4 +180,29 @@ td {
  padding: 5px;
  border: 0.5px solid #444;
  text-align: center !important;
+ }
+
+ /* AG Grid table header colors - gradient theme */
+ /* Data Sources tables - red gradient */
+ #main_data_sources .ag-header,
+ #data_sources_table .ag-header {
+ background: linear-gradient(180deg, rgb(140, 60, 60) 0%, rgb(80, 35, 35) 100%) !important;
+ }
+
+ /* Feature Sets tables - yellow/olive gradient */
+ #main_feature_sets .ag-header,
+ #feature_sets_table .ag-header {
+ background: linear-gradient(180deg, rgb(120, 115, 55) 0%, rgb(70, 65, 30) 100%) !important;
+ }
+
+ /* Models tables - green gradient */
+ #main_models .ag-header,
+ #models_table .ag-header {
+ background: linear-gradient(180deg, rgb(55, 110, 55) 0%, rgb(30, 60, 30) 100%) !important;
+ }
+
+ /* Endpoints tables - purple gradient */
+ #main_endpoints .ag-header,
+ #endpoints_table .ag-header {
+ background: linear-gradient(180deg, rgb(100, 60, 120) 0%, rgb(55, 30, 70) 100%) !important;
  }
@@ -483,11 +483,11 @@
  [1.0, "rgb(200, 100, 100)"]
  ],
  "sequential": [
- [0.0, "rgb(100, 100, 200)"],
- [0.4, "rgb(100, 200, 100)"],
- [0.65, "rgb(180, 180, 50)"],
- [0.85, "rgb(200, 100, 100)"],
- [1.0, "rgb(200, 100, 100)"]
+ [0.0, "rgba(80, 100, 255, 1.0)"],
+ [0.25, "rgba(70, 145, 220, 1.0)"],
+ [0.5, "rgba(70, 220, 100, 1.0)"],
+ [0.75, "rgba(255, 181, 80, 1.0)"],
+ [1.0, "rgba(232, 50, 131, 1.0)"]
  ],
  "sequentialminus": [
  [0.0, "rgb(255, 100, 100)"],
@@ -527,7 +527,7 @@
  "style": "dark"
  },
  "paper_bgcolor": "rgba(0, 0, 0, 0.0)",
- "plot_bgcolor": "rgba(0, 0, 0, 0.0)",
+ "plot_bgcolor": "rgb(40, 40, 40)",
  "polar": {
  "angularaxis": {
  "gridcolor": "#506784"