workbench 0.8.219__py3-none-any.whl → 0.8.231__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  5. workbench/algorithms/dataframe/projection_2d.py +8 -2
  6. workbench/algorithms/dataframe/proximity.py +3 -0
  7. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  8. workbench/algorithms/sql/column_stats.py +0 -1
  9. workbench/algorithms/sql/correlations.py +0 -1
  10. workbench/algorithms/sql/descriptive_stats.py +0 -1
  11. workbench/api/feature_set.py +0 -1
  12. workbench/api/meta.py +0 -1
  13. workbench/cached/cached_meta.py +0 -1
  14. workbench/cached/cached_model.py +37 -7
  15. workbench/core/artifacts/endpoint_core.py +12 -2
  16. workbench/core/artifacts/feature_set_core.py +238 -225
  17. workbench/core/cloud_platform/cloud_meta.py +0 -1
  18. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  19. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  20. workbench/model_script_utils/model_script_utils.py +30 -0
  21. workbench/model_script_utils/uq_harness.py +0 -1
  22. workbench/model_scripts/chemprop/chemprop.template +196 -68
  23. workbench/model_scripts/chemprop/generated_model_script.py +197 -72
  24. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  25. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  26. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  27. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  28. workbench/model_scripts/pytorch_model/generated_model_script.py +52 -34
  29. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  30. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  31. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  32. workbench/model_scripts/script_generation.py +0 -1
  33. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  34. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  35. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  36. workbench/scripts/ml_pipeline_sqs.py +71 -2
  37. workbench/themes/dark/custom.css +85 -8
  38. workbench/themes/dark/plotly.json +6 -6
  39. workbench/themes/light/custom.css +172 -64
  40. workbench/themes/light/plotly.json +9 -9
  41. workbench/themes/midnight_blue/custom.css +82 -29
  42. workbench/themes/midnight_blue/plotly.json +1 -1
  43. workbench/utils/aws_utils.py +0 -1
  44. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  45. workbench/utils/chem_utils/projections.py +16 -6
  46. workbench/utils/chem_utils/vis.py +137 -27
  47. workbench/utils/clientside_callbacks.py +41 -0
  48. workbench/utils/markdown_utils.py +57 -0
  49. workbench/utils/model_utils.py +0 -1
  50. workbench/utils/pipeline_utils.py +0 -1
  51. workbench/utils/plot_utils.py +52 -36
  52. workbench/utils/theme_manager.py +95 -30
  53. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  54. workbench/web_interface/components/model_plot.py +2 -0
  55. workbench/web_interface/components/plugin_unit_test.py +0 -1
  56. workbench/web_interface/components/plugins/ag_table.py +2 -4
  57. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  58. workbench/web_interface/components/plugins/model_details.py +10 -6
  59. workbench/web_interface/components/plugins/scatter_plot.py +184 -85
  60. workbench/web_interface/components/settings_menu.py +185 -0
  61. workbench/web_interface/page_views/main_page.py +0 -1
  62. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/METADATA +34 -41
  63. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/RECORD +67 -69
  64. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
  65. workbench/themes/quartz/base_css.url +0 -1
  66. workbench/themes/quartz/custom.css +0 -117
  67. workbench/themes/quartz/plotly.json +0 -642
  68. workbench/themes/quartz_dark/base_css.url +0 -1
  69. workbench/themes/quartz_dark/custom.css +0 -131
  70. workbench/themes/quartz_dark/plotly.json +0 -642
  71. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
  72. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
  73. {workbench-0.8.219.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.

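For context on the new helper, here is a minimal standalone sketch of the same IQR capping applied to a 1-D array (illustrative values, NumPy only):

    import numpy as np

    # Hypothetical ensemble spreads with one extreme outlier
    stds = np.array([0.10, 0.12, 0.11, 0.15, 9.0])

    q1, q3 = np.percentile(stds, [25, 75])   # 0.11, 0.15
    upper_fence = q3 + 1.5 * (q3 - q1)       # 0.15 + 1.5 * 0.04 = 0.21
    capped = np.minimum(stds, upper_fence)   # [0.10, 0.12, 0.11, 0.15, 0.21]

The outlier is pulled down to the fence while the relative ordering of the remaining values is untouched, which is the property the docstring above calls out.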
@@ -152,24 +152,30 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     print("Decompressing features for prediction...")
     matched_df, features = decompress_features(matched_df, features, compressed_features)

-    # Track missing features
-    missing_mask = matched_df[features].isna().any(axis=1)
-    if missing_mask.any():
-        print(f"Warning: {missing_mask.sum()} rows have missing features")
+    # Impute missing values (categorical with mode, continuous handled by scaler)
+    missing_counts = matched_df[features].isna().sum()
+    if missing_counts.any():
+        missing_features = missing_counts[missing_counts > 0]
+        print(f"Imputing missing values: {missing_features.to_dict()}")
+
+        # Load categorical imputation values if available
+        impute_path = os.path.join(model_dir, "categorical_impute.json")
+        if os.path.exists(impute_path):
+            with open(impute_path) as f:
+                cat_impute_values = json.load(f)
+            for col in categorical_cols:
+                if col in cat_impute_values and matched_df[col].isna().any():
+                    matched_df[col] = matched_df[col].fillna(cat_impute_values[col])
+        # Continuous features are imputed by FeatureScaler.transform() using column means

     # Initialize output columns
     df["prediction"] = np.nan
     if model_type in ["regressor", "uq_regressor"]:
         df["prediction_std"] = np.nan

-    complete_df = matched_df[~missing_mask].copy()
-    if len(complete_df) == 0:
-        print("Warning: No complete rows to predict on")
-        return df
-
-    # Prepare data for inference (with standardization)
+    # Prepare data for inference (with standardization and continuous imputation)
     x_cont, x_cat, _, _, _ = prepare_data(
-        complete_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
+        matched_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
     )

     # Collect ensemble predictions
@@ -191,28 +197,20 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
         class_preds = np.argmax(avg_probs, axis=1)
         predictions = label_encoder.inverse_transform(class_preds)

-        all_proba = pd.Series([None] * len(df), index=df.index, dtype=object)
-        all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
-        df["pred_proba"] = all_proba
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
         df = expand_proba_column(df, label_encoder.classes_)
     else:
         # Regression
         predictions = preds.flatten()
-        df.loc[~missing_mask, "prediction_std"] = preds_std.flatten()
+        df["prediction_std"] = preds_std.flatten()

         # Add UQ intervals if available
         if uq_models and uq_metadata:
-            X_complete = complete_df[features]
-            df_complete = df.loc[~missing_mask].copy()
-            df_complete["prediction"] = predictions  # Set prediction before compute_confidence
-            df_complete = predict_intervals(df_complete, X_complete, uq_models, uq_metadata)
-            df_complete = compute_confidence(df_complete, uq_metadata["median_interval_width"], "q_10", "q_90")
-            # Copy UQ columns back to main dataframe
-            for col in df_complete.columns:
-                if col.startswith("q_") or col == "confidence":
-                    df.loc[~missing_mask, col] = df_complete[col].values
-
-    df.loc[~missing_mask, "prediction"] = predictions
+            df["prediction"] = predictions  # Set prediction before compute_confidence
+            df = predict_intervals(df, matched_df[features], uq_models, uq_metadata)
+            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    df["prediction"] = predictions
     return df


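The inference-side change swaps row dropping for imputation: categorical gaps are filled from a mode lookup saved at training time, and continuous gaps are left to the scaler's mean imputation. A minimal sketch of the fill step, assuming a categorical_impute.json of the shape written by the training hunk below (column names and values are made up):

    import json
    import pandas as pd

    # Hypothetical lookup written at training time, e.g. {"solvent": "water"}
    with open("categorical_impute.json") as f:
        cat_impute_values = json.load(f)

    df = pd.DataFrame({"solvent": ["water", None, "dmso"]})

    # Fill each categorical column with its training-time mode
    for col, mode_val in cat_impute_values.items():
        if col in df.columns and df[col].isna().any():
            df[col] = df[col].fillna(mode_val)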
@@ -275,11 +273,11 @@ if __name__ == "__main__":
     all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
     check_dataframe(all_df, "training_df")

-    # Drop rows with missing features
+    # Drop rows with missing target (required for training)
     initial_count = len(all_df)
-    all_df = all_df.dropna(subset=features)
+    all_df = all_df.dropna(subset=[target])
     if len(all_df) < initial_count:
-        print(f"Dropped {initial_count - len(all_df)} rows with missing features")
+        print(f"Dropped {initial_count - len(all_df)} rows with missing target")

     print(f"Target: {target}")
     print(f"Features: {features}")
@@ -301,6 +299,23 @@ if __name__ == "__main__":
     print(f"Categorical: {categorical_cols}")
     print(f"Continuous: {len(continuous_cols)} columns")

+    # Report and handle missing values in features
+    # Compute categorical imputation values (mode) for use at inference time
+    cat_impute_values = {}
+    for col in categorical_cols:
+        mode_val = all_df[col].mode().iloc[0] if not all_df[col].mode().empty else all_df[col].cat.categories[0]
+        cat_impute_values[col] = str(mode_val)  # Convert to string for JSON serialization
+
+    missing_counts = all_df[features].isna().sum()
+    if missing_counts.any():
+        missing_features = missing_counts[missing_counts > 0]
+        print(f"Missing values in features (will be imputed): {missing_features.to_dict()}")
+        # Impute categorical features with mode (most frequent value)
+        for col in categorical_cols:
+            if all_df[col].isna().any():
+                all_df[col] = all_df[col].fillna(cat_impute_values[col])
+        # Continuous features are imputed by FeatureScaler.transform() using column means
+
     # -------------------------------------------------------------------------
     # Classification setup
     # -------------------------------------------------------------------------
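One subtlety in the mode computation above: pandas .mode() drops NaN, so an all-missing column yields an empty Series, and the code falls back to the first declared category. A small sketch of that edge case (the category name is made up):

    import pandas as pd

    # All-NaN categorical column with one declared category
    col = pd.Series([None, None, None], dtype="category").cat.add_categories(["unknown"])

    mode = col.mode()  # empty Series: no observed values to take a mode from
    impute_val = mode.iloc[0] if not mode.empty else col.cat.categories[0]
    print(impute_val)  # unknown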
@@ -506,6 +521,9 @@ if __name__ == "__main__":
     with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
         json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)

+    with open(os.path.join(args.model_dir, "categorical_impute.json"), "w") as f:
+        json.dump(cat_impute_values, f)
+
     with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
         json.dump(hyperparameters, f, indent=2)

@@ -22,7 +22,6 @@ import joblib
 from lightgbm import LGBMRegressor
 from mapie.regression import ConformalizedQuantileRegressor

-
 # Default confidence levels for prediction intervals
 DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]

@@ -6,7 +6,6 @@ import logging
 from pathlib import Path
 import importlib.util

-
 # Setup the logger
 log = logging.getLogger("workbench")

@@ -65,11 +65,11 @@ REGRESSION_ONLY_PARAMS = {"objective"}
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
     "target": "udm_asy_res_efflux_ratio",
-    "features": ['fingerprint'],
+    "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
     "id_column": "udm_mol_bat_id",
     "compressed_features": ['fingerprint'],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-fp/training",
-    "hyperparameters": {},
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-temporal/training",
+    "hyperparameters": {'n_folds': 1},
 }

@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.

@@ -22,7 +22,6 @@ import joblib
 from lightgbm import LGBMRegressor
 from mapie.regression import ConformalizedQuantileRegressor

-
 # Default confidence levels for prediction intervals
 DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]

@@ -1,6 +1,8 @@
 import argparse
+import ast
 import logging
 import json
+import re
 from pathlib import Path

 # Workbench Imports
@@ -13,6 +15,56 @@ cm = ConfigManager()
 workbench_bucket = cm.get_config("WORKBENCH_BUCKET")


+def parse_workbench_batch(script_content: str) -> dict | None:
+    """Parse WORKBENCH_BATCH config from a script.
+
+    Looks for a dictionary assignment like:
+        WORKBENCH_BATCH = {
+            "outputs": ["feature_set_xyz"],
+        }
+    or:
+        WORKBENCH_BATCH = {
+            "inputs": ["feature_set_xyz"],
+        }
+
+    Args:
+        script_content: The Python script content as a string
+
+    Returns:
+        The parsed dictionary or None if not found
+    """
+    pattern = r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})"
+    match = re.search(pattern, script_content, re.DOTALL)
+    if match:
+        try:
+            return ast.literal_eval(match.group(1))
+        except (ValueError, SyntaxError) as e:
+            print(f"⚠️ Warning: Failed to parse WORKBENCH_BATCH: {e}")
+            return None
+    return None
+
+
+def get_message_group_id(batch_config: dict | None) -> str:
+    """Derive MessageGroupId from outputs or inputs.
+
+    - Scripts with outputs use first output as group
+    - Scripts with inputs use first input as group
+    - Default to "ml-pipeline-jobs" if no config
+    """
+    if not batch_config:
+        return "ml-pipeline-jobs"
+
+    outputs = batch_config.get("outputs", [])
+    inputs = batch_config.get("inputs", [])
+
+    if outputs:
+        return outputs[0]
+    elif inputs:
+        return inputs[0]
+    else:
+        return "ml-pipeline-jobs"
+
+
 def submit_to_sqs(
     script_path: str,
     size: str = "small",
@@ -44,12 +96,24 @@ def submit_to_sqs(
     if not script_file.exists():
         raise FileNotFoundError(f"Script not found: {script_path}")

+    # Read script content and parse WORKBENCH_BATCH config
+    script_content = script_file.read_text()
+    batch_config = parse_workbench_batch(script_content)
+    group_id = get_message_group_id(batch_config)
+    outputs = (batch_config or {}).get("outputs", [])
+    inputs = (batch_config or {}).get("inputs", [])
+
     print(f"📄 Script: {script_file.name}")
     print(f"📏 Size tier: {size}")
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
     print(f"🪣 Bucket: {workbench_bucket}")
+    if outputs:
+        print(f"📤 Outputs: {outputs}")
+    if inputs:
+        print(f"📥 Inputs: {inputs}")
+    print(f"📦 Batch Group: {group_id}")
     sqs = AWSAccountClamp().boto3_session.client("sqs")
     script_name = script_file.name
@@ -75,7 +139,7 @@ def submit_to_sqs(
     print(f" Destination: {s3_path}")

     try:
-        upload_content_to_s3(script_file.read_text(), s3_path)
+        upload_content_to_s3(script_content, s3_path)
         print("✅ Script uploaded successfully")
     except Exception as e:
         print(f"❌ Upload failed: {e}")
@@ -118,7 +182,7 @@ def submit_to_sqs(
     response = sqs.send_message(
         QueueUrl=queue_url,
         MessageBody=json.dumps(message, indent=2),
-        MessageGroupId="ml-pipeline-jobs",  # Required for FIFO
+        MessageGroupId=group_id,  # From WORKBENCH_BATCH or default
     )
     message_id = response["MessageId"]
     print("✅ Message sent successfully!")
@@ -136,6 +200,11 @@ def submit_to_sqs(
     print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
     print(f"🔄 DynamicTraining: {dt}")
     print(f"🆕 Promote: {promote}")
+    if outputs:
+        print(f"📤 Outputs: {outputs}")
+    if inputs:
+        print(f"📥 Inputs: {inputs}")
+    print(f"📦 Batch Group: {group_id}")
     print(f"🆔 Message ID: {message_id}")
     print("\n🔍 MONITORING LOCATIONS:")
     print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
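To see the new routing end to end, here is a small sketch that applies the same regex-plus-literal_eval extraction to a made-up script body and derives the FIFO MessageGroupId with the documented fallback:

    import ast
    import re

    sample_script = '''
    WORKBENCH_BATCH = {
        "inputs": ["feature_set_xyz"],
    }
    '''

    match = re.search(r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})", sample_script, re.DOTALL)
    config = ast.literal_eval(match.group(1)) if match else None

    # Outputs win over inputs; fall back to the shared default group
    outputs = (config or {}).get("outputs", [])
    inputs = (config or {}).get("inputs", [])
    group_id = outputs[0] if outputs else (inputs[0] if inputs else "ml-pipeline-jobs")
    print(group_id)  # feature_set_xyz

Grouping by the first output (or input) serializes jobs that touch the same artifact within the FIFO queue, while unrelated pipelines can still proceed in parallel.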
@@ -3,6 +3,7 @@ h1, h2, h3, h4 {
 }
 body {
     color: rgb(180, 180, 180); /* We want the text dim white */
+    background: linear-gradient(90deg, rgb(45, 45, 45) 0%, rgb(35, 35, 35) 100%);
 }

 /* Custom CSS to style bold text */
@@ -36,21 +37,38 @@ a:hover {

 /* AgGrid custom CSS */

-/* There's a one pixel border around the grid that we want to remove */
-.ag-root-wrapper {
-    border: none !important; /* Force removal with !important */
-}
-
-
-/* Box shadow and rounded corners for all AgGrid themes */
+/* AG Grid 33+ uses CSS variables for theming - set them at the theme level */
 [class*="ag-theme-"] {
+    --ag-background-color: rgb(40, 40, 40);
+    --ag-odd-row-background-color: rgb(40, 40, 40);
+    --ag-row-background-color: rgb(50, 50, 50);
+    --ag-selected-row-background-color: rgb(60, 70, 90);
+    --ag-row-hover-color: rgb(55, 55, 65);
+    --ag-header-background-color: rgb(35, 35, 35);
+    --ag-border-color: rgba(80, 80, 80, 0.5);
+    --ag-foreground-color: rgb(180, 180, 180);
+    --ag-header-foreground-color: rgb(220, 220, 220);
+    --ag-wrapper-border-radius: 12px;
+
+    /* Box shadow and rounded corners */
     box-shadow: 2px 2px 6px 5px rgba(0, 0, 0, 0.25);
-    border-radius: 12px; /* Rounded corners */
+    border-radius: 12px;
     border: 0.5px solid rgba(0, 0, 0, 0.5);
     margin: 0;
     padding: 0;
 }

+/* Remove border from the grid wrapper */
+.ag-root-wrapper {
+    border: none !important;
+}
+
+/* AG Grid container - remove padding but allow shadow overflow */
+div:has(> [class*="ag-theme-"]) {
+    padding: 0 !important;
+    overflow: visible !important;
+}
+
 /* Apply styling to Workbench containers */
 .workbench-container {
     box-shadow: 2px 2px 6px 5px rgba(0, 0, 0, 0.25);
@@ -110,6 +128,40 @@ a:hover {
     color: rgb(100, 255, 100);
 }

+/* Dropdown styling (dcc.Dropdown) - override Bootstrap's variables */
+.dash-dropdown {
+    --bs-body-bg: rgb(35, 35, 35);
+    --bs-body-color: rgb(210, 210, 210);
+    --bs-border-color: rgb(60, 60, 60);
+}
+
+/* Bootstrap form controls (dbc components) */
+.form-select, .form-control {
+    background-color: rgb(35, 35, 35) !important;
+    border: 1px solid rgb(60, 60, 60) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.form-select:focus, .form-control:focus {
+    background-color: rgb(45, 45, 45) !important;
+    border-color: rgb(80, 80, 80) !important;
+    box-shadow: 0 0 0 0.2rem rgba(80, 80, 80, 0.25) !important;
+}
+
+.dropdown-menu {
+    background-color: rgb(35, 35, 35) !important;
+    border: 1px solid rgb(60, 60, 60) !important;
+}
+
+.dropdown-item {
+    color: rgb(210, 210, 210) !important;
+}
+
+.dropdown-item:hover, .dropdown-item:focus {
+    background-color: rgb(50, 50, 50) !important;
+    color: rgb(230, 230, 230) !important;
+}
+
 /* Table styling */
 table {
     width: 100%;
@@ -128,4 +180,29 @@ td {
     padding: 5px;
     border: 0.5px solid #444;
     text-align: center !important;
+}
+
+/* AG Grid table header colors - gradient theme */
+/* Data Sources tables - red gradient */
+#main_data_sources .ag-header,
+#data_sources_table .ag-header {
+    background: linear-gradient(180deg, rgb(140, 60, 60) 0%, rgb(80, 35, 35) 100%) !important;
+}
+
+/* Feature Sets tables - yellow/olive gradient */
+#main_feature_sets .ag-header,
+#feature_sets_table .ag-header {
+    background: linear-gradient(180deg, rgb(120, 115, 55) 0%, rgb(70, 65, 30) 100%) !important;
+}
+
+/* Models tables - green gradient */
+#main_models .ag-header,
+#models_table .ag-header {
+    background: linear-gradient(180deg, rgb(55, 110, 55) 0%, rgb(30, 60, 30) 100%) !important;
+}
+
+/* Endpoints tables - purple gradient */
+#main_endpoints .ag-header,
+#endpoints_table .ag-header {
+    background: linear-gradient(180deg, rgb(100, 60, 120) 0%, rgb(55, 30, 70) 100%) !important;
 }
@@ -483,11 +483,11 @@
                 [1.0, "rgb(200, 100, 100)"]
             ],
             "sequential": [
-                [0.0, "rgb(100, 100, 200)"],
-                [0.4, "rgb(100, 200, 100)"],
-                [0.65, "rgb(180, 180, 50)"],
-                [0.85, "rgb(200, 100, 100)"],
-                [1.0, "rgb(200, 100, 100)"]
+                [0.0, "rgba(80, 100, 255, 1.0)"],
+                [0.25, "rgba(70, 145, 220, 1.0)"],
+                [0.5, "rgba(70, 220, 100, 1.0)"],
+                [0.75, "rgba(255, 181, 80, 1.0)"],
+                [1.0, "rgba(232, 50, 131, 1.0)"]
             ],
             "sequentialminus": [
                 [0.0, "rgb(255, 100, 100)"],
@@ -527,7 +527,7 @@
         "style": "dark"
     },
     "paper_bgcolor": "rgba(0, 0, 0, 0.0)",
-    "plot_bgcolor": "rgba(0, 0, 0, 0.0)",
+    "plot_bgcolor": "rgb(40, 40, 40)",
     "polar": {
         "angularaxis": {
             "gridcolor": "#506784",