workbench-0.8.224-py3-none-any.whl → workbench-0.8.234-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +66 -8
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +3 -0
- workbench/model_scripts/chemprop/generated_model_script.py +3 -3
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +50 -32
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -70
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +48 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +61 -0
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +8 -110
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +28 -11
- workbench/web_interface/components/plugins/scatter_plot.py +56 -43
- workbench/web_interface/components/settings_menu.py +2 -1
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/METADATA +31 -29
- {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/RECORD +55 -59
- {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/top_level.txt +0 -0
workbench/model_scripts/pytorch_model/generated_model_script.py CHANGED

@@ -59,12 +59,12 @@ DEFAULT_HYPERPARAMETERS = {

# Template parameters (filled in by Workbench)
TEMPLATE_PARAMS = {
-    "model_type": "
-    "target": "
+    "model_type": "classifier",
+    "target": "class",
"features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 
'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
    "id_column": "udm_mol_bat_id",
    "compressed_features": [],
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-pytorch-
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-class-pytorch-1-fr/training",
    "hyperparameters": {},
}

@@ -152,24 +152,30 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
        print("Decompressing features for prediction...")
        matched_df, features = decompress_features(matched_df, features, compressed_features)

-    #
-
-    if
-
+    # Impute missing values (categorical with mode, continuous handled by scaler)
+    missing_counts = matched_df[features].isna().sum()
+    if missing_counts.any():
+        missing_features = missing_counts[missing_counts > 0]
+        print(f"Imputing missing values: {missing_features.to_dict()}")
+
+        # Load categorical imputation values if available
+        impute_path = os.path.join(model_dir, "categorical_impute.json")
+        if os.path.exists(impute_path):
+            with open(impute_path) as f:
+                cat_impute_values = json.load(f)
+            for col in categorical_cols:
+                if col in cat_impute_values and matched_df[col].isna().any():
+                    matched_df[col] = matched_df[col].fillna(cat_impute_values[col])
+    # Continuous features are imputed by FeatureScaler.transform() using column means

    # Initialize output columns
    df["prediction"] = np.nan
    if model_type in ["regressor", "uq_regressor"]:
        df["prediction_std"] = np.nan

-
-    if len(complete_df) == 0:
-        print("Warning: No complete rows to predict on")
-        return df
-
-    # Prepare data for inference (with standardization)
+    # Prepare data for inference (with standardization and continuous imputation)
    x_cont, x_cat, _, _, _ = prepare_data(
-
+        matched_df, continuous_cols, categorical_cols, category_mappings=category_mappings, scaler=scaler
    )

    # Collect ensemble predictions

@@ -191,28 +197,20 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
        class_preds = np.argmax(avg_probs, axis=1)
        predictions = label_encoder.inverse_transform(class_preds)

-
-        all_proba.loc[~missing_mask] = [p.tolist() for p in avg_probs]
-        df["pred_proba"] = all_proba
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
        df = expand_proba_column(df, label_encoder.classes_)
    else:
        # Regression
        predictions = preds.flatten()
-        df
+        df["prediction_std"] = preds_std.flatten()

    # Add UQ intervals if available
    if uq_models and uq_metadata:
-
-
-
-
-
-        # Copy UQ columns back to main dataframe
-        for col in df_complete.columns:
-            if col.startswith("q_") or col == "confidence":
-                df.loc[~missing_mask, col] = df_complete[col].values
-
-    df.loc[~missing_mask, "prediction"] = predictions
+        df["prediction"] = predictions  # Set prediction before compute_confidence
+        df = predict_intervals(df, matched_df[features], uq_models, uq_metadata)
+        df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    df["prediction"] = predictions
    return df


@@ -275,11 +273,11 @@ if __name__ == "__main__":
    all_df = pd.concat([pd.read_csv(f, engine="python") for f in training_files])
    check_dataframe(all_df, "training_df")

-    # Drop rows with missing
+    # Drop rows with missing target (required for training)
    initial_count = len(all_df)
-    all_df = all_df.dropna(subset=
+    all_df = all_df.dropna(subset=[target])
    if len(all_df) < initial_count:
-        print(f"Dropped {initial_count - len(all_df)} rows with missing
+        print(f"Dropped {initial_count - len(all_df)} rows with missing target")

    print(f"Target: {target}")
    print(f"Features: {features}")

@@ -301,6 +299,23 @@ if __name__ == "__main__":
    print(f"Categorical: {categorical_cols}")
    print(f"Continuous: {len(continuous_cols)} columns")

+    # Report and handle missing values in features
+    # Compute categorical imputation values (mode) for use at inference time
+    cat_impute_values = {}
+    for col in categorical_cols:
+        mode_val = all_df[col].mode().iloc[0] if not all_df[col].mode().empty else all_df[col].cat.categories[0]
+        cat_impute_values[col] = str(mode_val)  # Convert to string for JSON serialization
+
+    missing_counts = all_df[features].isna().sum()
+    if missing_counts.any():
+        missing_features = missing_counts[missing_counts > 0]
+        print(f"Missing values in features (will be imputed): {missing_features.to_dict()}")
+        # Impute categorical features with mode (most frequent value)
+        for col in categorical_cols:
+            if all_df[col].isna().any():
+                all_df[col] = all_df[col].fillna(cat_impute_values[col])
+    # Continuous features are imputed by FeatureScaler.transform() using column means
+
    # -------------------------------------------------------------------------
    # Classification setup
    # -------------------------------------------------------------------------

@@ -506,6 +521,9 @@ if __name__ == "__main__":
    with open(os.path.join(args.model_dir, "feature_metadata.json"), "w") as f:
        json.dump({"continuous_cols": continuous_cols, "categorical_cols": categorical_cols}, f)

+    with open(os.path.join(args.model_dir, "categorical_impute.json"), "w") as f:
+        json.dump(cat_impute_values, f)
+
    with open(os.path.join(args.model_dir, "hyperparameters.json"), "w") as f:
        json.dump(hyperparameters, f, indent=2)

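Taken together, the inference-side and training-side hunks introduce a small artifact contract: training writes the per-column mode of each categorical feature to categorical_impute.json in the model directory, and predict_fn reads it back to fill missing categoricals before prepare_data. A minimal round-trip sketch (illustrative only; the "solvent" column and temp directory are made up, and the real script uses the SageMaker model_dir):

import json
import os
import tempfile

import pandas as pd

model_dir = tempfile.mkdtemp()  # stand-in for the SageMaker model directory

# Training side: persist the mode of each categorical feature
train_df = pd.DataFrame({"solvent": ["dmso", "dmso", None, "water"]})
cat_impute_values = {"solvent": str(train_df["solvent"].mode().iloc[0])}
with open(os.path.join(model_dir, "categorical_impute.json"), "w") as f:
    json.dump(cat_impute_values, f)

# Inference side: reload the modes and fill missing categoricals
infer_df = pd.DataFrame({"solvent": ["water", None]})
with open(os.path.join(model_dir, "categorical_impute.json")) as f:
    for col, mode_val in json.load(f).items():
        if infer_df[col].isna().any():
            infer_df[col] = infer_df[col].fillna(mode_val)
print(infer_df["solvent"].tolist())  # ['water', 'dmso']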
workbench/model_scripts/pytorch_model/model_script_utils.py CHANGED

@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    raise RuntimeError(f"{accept_type} accept type is not supported by this script.")


+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Compute standard regression metrics.

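For a quick sense of the IQR fence, here is a worked call (the import path is an assumption; in the deployed script the helper lives alongside the model code):

import numpy as np

from model_script_utils import cap_std_outliers  # assumed import path

std = np.array([0.10, 0.12, 0.11, 0.13, 5.00])  # one wildly disagreeing ensemble member
print(cap_std_outliers(std))
# [0.1  0.12 0.11 0.13 0.16]  -- Q1=0.11, Q3=0.13, fence = 0.13 + 1.5 * 0.02 = 0.16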
workbench/model_scripts/pytorch_model/pytorch.template CHANGED
workbench/model_scripts/xgb_model/model_script_utils.py CHANGED

(Identical hunks to the ones above: pytorch.template receives the same predict_fn imputation, UQ, target-dropping, mode-computation, and categorical_impute.json changes as the generated model script, and xgb_model/model_script_utils.py receives the same cap_std_outliers addition. The verbatim-duplicated hunks are not repeated here.)
workbench/themes/dark/custom.css CHANGED

@@ -3,6 +3,7 @@ h1, h2, h3, h4 {
}
body {
    color: rgb(180, 180, 180); /* We want the text dim white */
+    background: linear-gradient(90deg, rgb(45, 45, 45) 0%, rgb(35, 35, 35) 100%);
}

/* Custom CSS to style bold text */

@@ -36,21 +37,38 @@ a:hover {

/* AgGrid custom CSS */

-/*
-.ag-root-wrapper {
-    border: none !important; /* Force removal with !important */
-}
-
-
-/* Box shadow and rounded corners for all AgGrid themes */
+/* AG Grid 33+ uses CSS variables for theming - set them at the theme level */
[class*="ag-theme-"] {
+    --ag-background-color: rgb(40, 40, 40);
+    --ag-odd-row-background-color: rgb(40, 40, 40);
+    --ag-row-background-color: rgb(50, 50, 50);
+    --ag-selected-row-background-color: rgb(60, 70, 90);
+    --ag-row-hover-color: rgb(55, 55, 65);
+    --ag-header-background-color: rgb(35, 35, 35);
+    --ag-border-color: rgba(80, 80, 80, 0.5);
+    --ag-foreground-color: rgb(180, 180, 180);
+    --ag-header-foreground-color: rgb(220, 220, 220);
+    --ag-wrapper-border-radius: 12px;
+
+    /* Box shadow and rounded corners */
    box-shadow: 2px 2px 6px 5px rgba(0, 0, 0, 0.25);
-    border-radius: 12px;
+    border-radius: 12px;
    border: 0.5px solid rgba(0, 0, 0, 0.5);
    margin: 0;
    padding: 0;
}

+/* Remove border from the grid wrapper */
+.ag-root-wrapper {
+    border: none !important;
+}
+
+/* AG Grid container - remove padding but allow shadow overflow */
+div:has(> [class*="ag-theme-"]) {
+    padding: 0 !important;
+    overflow: visible !important;
+}
+
/* Apply styling to Workbench containers */
.workbench-container {
    box-shadow: 2px 2px 6px 5px rgba(0, 0, 0, 0.25);

@@ -110,6 +128,40 @@ a:hover {
    color: rgb(100, 255, 100);
}

+/* Dropdown styling (dcc.Dropdown) - override Bootstrap's variables */
+.dash-dropdown {
+    --bs-body-bg: rgb(35, 35, 35);
+    --bs-body-color: rgb(210, 210, 210);
+    --bs-border-color: rgb(60, 60, 60);
+}
+
+/* Bootstrap form controls (dbc components) */
+.form-select, .form-control {
+    background-color: rgb(35, 35, 35) !important;
+    border: 1px solid rgb(60, 60, 60) !important;
+    color: rgb(210, 210, 210) !important;
+}
+
+.form-select:focus, .form-control:focus {
+    background-color: rgb(45, 45, 45) !important;
+    border-color: rgb(80, 80, 80) !important;
+    box-shadow: 0 0 0 0.2rem rgba(80, 80, 80, 0.25) !important;
+}
+
+.dropdown-menu {
+    background-color: rgb(35, 35, 35) !important;
+    border: 1px solid rgb(60, 60, 60) !important;
+}
+
+.dropdown-item {
+    color: rgb(210, 210, 210) !important;
+}
+
+.dropdown-item:hover, .dropdown-item:focus {
+    background-color: rgb(50, 50, 50) !important;
+    color: rgb(230, 230, 230) !important;
+}
+
/* Table styling */
table {
    width: 100%;

@@ -128,4 +180,29 @@ td {
    padding: 5px;
    border: 0.5px solid #444;
    text-align: center !important;
+}
+
+/* AG Grid table header colors - gradient theme */
+/* Data Sources tables - red gradient */
+#main_data_sources .ag-header,
+#data_sources_table .ag-header {
+    background: linear-gradient(180deg, rgb(140, 60, 60) 0%, rgb(80, 35, 35) 100%) !important;
+}
+
+/* Feature Sets tables - yellow/olive gradient */
+#main_feature_sets .ag-header,
+#feature_sets_table .ag-header {
+    background: linear-gradient(180deg, rgb(120, 115, 55) 0%, rgb(70, 65, 30) 100%) !important;
+}
+
+/* Models tables - green gradient */
+#main_models .ag-header,
+#models_table .ag-header {
+    background: linear-gradient(180deg, rgb(55, 110, 55) 0%, rgb(30, 60, 30) 100%) !important;
+}
+
+/* Endpoints tables - purple gradient */
+#main_endpoints .ag-header,
+#endpoints_table .ag-header {
+    background: linear-gradient(180deg, rgb(100, 60, 120) 0%, rgb(55, 30, 70) 100%) !important;
}

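These rules key off the ag-theme-* class and component ids rather than AG Grid's own theming API, so any dash-ag-grid table picks them up without a custom theme build. A sketch of the wiring on the Dash side (assumed layout code, not from this diff; dash-ag-grid applies an ag-theme-* className by default, and the "models_table" id is borrowed from the CSS above):

import dash_ag_grid as dag
from dash import Dash, html

app = Dash(__name__)
app.layout = html.Div(
    dag.AgGrid(
        id="models_table",  # matches the green header-gradient rule above
        className="ag-theme-alpine-dark",  # matches [class*="ag-theme-"]
        rowData=[{"model": "caco2-er", "rmse": 0.42}],
        columnDefs=[{"field": "model"}, {"field": "rmse"}],
    )
)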
workbench/themes/dark/plotly.json CHANGED

@@ -483,11 +483,11 @@
        [1.0, "rgb(200, 100, 100)"]
    ],
    "sequential": [
-        [0.0, "
-        [0.
-        [0.
-        [0.
-        [1.0, "
+        [0.0, "rgba(80, 100, 255, 1.0)"],
+        [0.25, "rgba(70, 145, 220, 1.0)"],
+        [0.5, "rgba(70, 220, 100, 1.0)"],
+        [0.75, "rgba(255, 181, 80, 1.0)"],
+        [1.0, "rgba(232, 50, 131, 1.0)"]
    ],
    "sequentialminus": [
        [0.0, "rgb(255, 100, 100)"],

@@ -527,7 +527,7 @@
|
|
|
527
527
|
"style": "dark"
|
|
528
528
|
},
|
|
529
529
|
"paper_bgcolor": "rgba(0, 0, 0, 0.0)",
|
|
530
|
-
"plot_bgcolor": "
|
|
530
|
+
"plot_bgcolor": "rgb(40, 40, 40)",
|
|
531
531
|
"polar": {
|
|
532
532
|
"angularaxis": {
|
|
533
533
|
"gridcolor": "#506784",
|