workbench 0.8.217__py3-none-any.whl → 0.8.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/core/artifacts/endpoint_core.py +2 -2
- workbench/core/artifacts/feature_set_core.py +2 -2
- workbench/model_script_utils/model_script_utils.py +15 -11
- workbench/model_scripts/chemprop/chemprop.template +2 -2
- workbench/model_scripts/chemprop/generated_model_script.py +6 -6
- workbench/model_scripts/chemprop/model_script_utils.py +15 -11
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
- workbench/model_scripts/pytorch_model/generated_model_script.py +3 -3
- workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
- workbench/model_scripts/xgb_model/generated_model_script.py +6 -6
- workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/utils/chem_utils/fingerprints.py +80 -43
- workbench/utils/meta_model_simulator.py +41 -13
- workbench/utils/shap_utils.py +1 -55
- {workbench-0.8.217.dist-info → workbench-0.8.219.dist-info}/METADATA +1 -1
- {workbench-0.8.217.dist-info → workbench-0.8.219.dist-info}/RECORD +22 -23
- {workbench-0.8.217.dist-info → workbench-0.8.219.dist-info}/entry_points.txt +1 -0
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- {workbench-0.8.217.dist-info → workbench-0.8.219.dist-info}/WHEEL +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.219.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.217.dist-info → workbench-0.8.219.dist-info}/top_level.txt +0 -0
|
@@ -1,11 +1,19 @@
|
|
|
1
|
-
"""Molecular fingerprint computation utilities
|
|
1
|
+
"""Molecular fingerprint computation utilities for ADMET modeling.
|
|
2
|
+
|
|
3
|
+
This module provides Morgan count fingerprints, the standard for ADMET prediction.
|
|
4
|
+
Count fingerprints outperform binary fingerprints for molecular property prediction.
|
|
5
|
+
|
|
6
|
+
References:
|
|
7
|
+
- Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
|
|
8
|
+
- ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
|
|
9
|
+
"""
|
|
2
10
|
|
|
3
11
|
import logging
|
|
4
|
-
import pandas as pd
|
|
5
12
|
|
|
6
|
-
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
7
15
|
from rdkit import Chem, RDLogger
|
|
8
|
-
from rdkit.Chem import
|
|
16
|
+
from rdkit.Chem import AllChem
|
|
9
17
|
from rdkit.Chem.MolStandardize import rdMolStandardize
|
|
10
18
|
|
|
11
19
|
# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
|
|
@@ -16,20 +24,25 @@ RDLogger.DisableLog("rdApp.warning")
|
|
|
16
24
|
log = logging.getLogger("workbench")
|
|
17
25
|
|
|
18
26
|
|
|
19
|
-
def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048
|
|
20
|
-
"""Compute
|
|
27
|
+
def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
|
|
28
|
+
"""Compute Morgan count fingerprints for ADMET modeling.
|
|
29
|
+
|
|
30
|
+
Generates true count fingerprints where each bit position contains the
|
|
31
|
+
number of times that substructure appears in the molecule (clamped to 0-255).
|
|
32
|
+
This is the recommended approach for ADMET prediction per 2025 research.
|
|
21
33
|
|
|
22
34
|
Args:
|
|
23
|
-
df
|
|
24
|
-
radius
|
|
25
|
-
n_bits
|
|
26
|
-
counts (bool): Count simulation for the fingerprint.
|
|
35
|
+
df: Input DataFrame containing SMILES strings.
|
|
36
|
+
radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
|
|
37
|
+
n_bits: Number of bits for the fingerprint (default 2048).
|
|
27
38
|
|
|
28
39
|
Returns:
|
|
29
|
-
pd.DataFrame:
|
|
40
|
+
pd.DataFrame: Input DataFrame with 'fingerprint' column added.
|
|
41
|
+
Values are comma-separated uint8 counts.
|
|
30
42
|
|
|
31
43
|
Note:
|
|
32
|
-
|
|
44
|
+
Count fingerprints outperform binary for ADMET prediction.
|
|
45
|
+
See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
|
|
33
46
|
"""
|
|
34
47
|
delete_mol_column = False
|
|
35
48
|
|
|
@@ -43,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
|
|
|
43
56
|
log.warning("Detected serialized molecules in 'molecule' column. Removing...")
|
|
44
57
|
del df["molecule"]
|
|
45
58
|
|
|
46
|
-
# Convert SMILES to RDKit molecule objects
|
|
59
|
+
# Convert SMILES to RDKit molecule objects
|
|
47
60
|
if "molecule" not in df.columns:
|
|
48
61
|
log.info("Converting SMILES to RDKit Molecules...")
|
|
49
62
|
delete_mol_column = True
|
|
@@ -59,15 +72,24 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
|
|
|
59
72
|
lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
|
|
60
73
|
)
|
|
61
74
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
75
|
+
def mol_to_count_string(mol):
|
|
76
|
+
"""Convert molecule to comma-separated count fingerprint string."""
|
|
77
|
+
if mol is None:
|
|
78
|
+
return pd.NA
|
|
66
79
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
80
|
+
# Get hashed Morgan fingerprint with counts
|
|
81
|
+
fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
|
|
82
|
+
|
|
83
|
+
# Initialize array and populate with counts (clamped to uint8 range)
|
|
84
|
+
counts = np.zeros(n_bits, dtype=np.uint8)
|
|
85
|
+
for idx, count in fp.GetNonzeroElements().items():
|
|
86
|
+
counts[idx] = min(count, 255)
|
|
87
|
+
|
|
88
|
+
# Return as comma-separated string
|
|
89
|
+
return ",".join(map(str, counts))
|
|
90
|
+
|
|
91
|
+
# Compute Morgan count fingerprints
|
|
92
|
+
fingerprints = largest_frags.apply(mol_to_count_string)
|
|
71
93
|
|
|
72
94
|
# Add the fingerprints to the DataFrame
|
|
73
95
|
df["fingerprint"] = fingerprints
|
|
@@ -75,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
|
|
|
75
97
|
# Drop the intermediate 'molecule' column if it was added
|
|
76
98
|
if delete_mol_column:
|
|
77
99
|
del df["molecule"]
|
|
100
|
+
|
|
78
101
|
return df
|
|
79
102
|
|
|
80
103
|
|
|
81
104
|
if __name__ == "__main__":
|
|
82
|
-
print("Running
|
|
83
|
-
print("Note: This requires molecular_screening module to be available")
|
|
105
|
+
print("Running Morgan count fingerprint tests...")
|
|
84
106
|
|
|
85
107
|
# Test molecules
|
|
86
108
|
test_molecules = {
|
|
87
109
|
"aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
|
|
88
110
|
"caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
|
|
89
111
|
"glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # With stereochemistry
|
|
90
|
-
"sodium_acetate": "CC(=O)[O-].[Na+]", # Salt
|
|
112
|
+
"sodium_acetate": "CC(=O)[O-].[Na+]", # Salt (largest fragment used)
|
|
91
113
|
"benzene": "c1ccccc1",
|
|
92
114
|
"butene_e": "C/C=C/C", # E-butene
|
|
93
115
|
"butene_z": "C/C=C\\C", # Z-butene
|
|
94
116
|
}
|
|
95
117
|
|
|
96
|
-
# Test 1: Morgan Fingerprints
|
|
97
|
-
print("\n1. Testing Morgan fingerprint generation...")
|
|
118
|
+
# Test 1: Morgan Count Fingerprints (default parameters)
|
|
119
|
+
print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
|
|
98
120
|
|
|
99
121
|
test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
|
|
100
|
-
|
|
101
|
-
fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
|
|
122
|
+
fp_df = compute_morgan_fingerprints(test_df.copy())
|
|
102
123
|
|
|
103
124
|
print(" Fingerprint generation results:")
|
|
104
125
|
for _, row in fp_df.iterrows():
|
|
105
126
|
fp = row.get("fingerprint", "N/A")
|
|
106
|
-
|
|
107
|
-
|
|
127
|
+
if pd.notna(fp):
|
|
128
|
+
counts = [int(x) for x in fp.split(",")]
|
|
129
|
+
non_zero = sum(1 for c in counts if c > 0)
|
|
130
|
+
max_count = max(counts)
|
|
131
|
+
print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
|
|
132
|
+
else:
|
|
133
|
+
print(f" {row['name']:15} → N/A")
|
|
108
134
|
|
|
109
|
-
# Test 2: Different
|
|
110
|
-
print("\n2. Testing different
|
|
135
|
+
# Test 2: Different parameters
|
|
136
|
+
print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
|
|
111
137
|
|
|
112
|
-
|
|
113
|
-
fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
|
|
138
|
+
fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
|
|
114
139
|
|
|
115
|
-
|
|
116
|
-
for _, row in fp_counts_df.iterrows():
|
|
140
|
+
for _, row in fp_df_custom.iterrows():
|
|
117
141
|
fp = row.get("fingerprint", "N/A")
|
|
118
|
-
|
|
119
|
-
|
|
142
|
+
if pd.notna(fp):
|
|
143
|
+
counts = [int(x) for x in fp.split(",")]
|
|
144
|
+
non_zero = sum(1 for c in counts if c > 0)
|
|
145
|
+
print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
|
|
146
|
+
else:
|
|
147
|
+
print(f" {row['name']:15} → N/A")
|
|
120
148
|
|
|
121
149
|
# Test 3: Edge cases
|
|
122
150
|
print("\n3. Testing edge cases...")
|
|
123
151
|
|
|
124
152
|
# Invalid SMILES
|
|
125
153
|
invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
|
|
129
|
-
except Exception as e:
|
|
130
|
-
print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
|
|
154
|
+
fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
|
|
155
|
+
print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
|
|
131
156
|
|
|
132
157
|
# Test with pre-existing molecule column
|
|
133
158
|
mol_df = test_df.copy()
|
|
@@ -135,4 +160,16 @@ if __name__ == "__main__":
|
|
|
135
160
|
fp_with_mol = compute_morgan_fingerprints(mol_df)
|
|
136
161
|
print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
|
|
137
162
|
|
|
163
|
+
# Test 4: Verify count values are reasonable
|
|
164
|
+
print("\n4. Verifying count distribution...")
|
|
165
|
+
all_counts = []
|
|
166
|
+
for _, row in fp_df.iterrows():
|
|
167
|
+
fp = row.get("fingerprint", "N/A")
|
|
168
|
+
if pd.notna(fp):
|
|
169
|
+
counts = [int(x) for x in fp.split(",")]
|
|
170
|
+
all_counts.extend([c for c in counts if c > 0])
|
|
171
|
+
|
|
172
|
+
if all_counts:
|
|
173
|
+
print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
|
|
174
|
+
|
|
138
175
|
print("\n✅ All fingerprint tests completed!")
|
|
@@ -61,6 +61,13 @@ class MetaModelSimulator:
|
|
|
61
61
|
df["abs_residual"] = df["residual"].abs()
|
|
62
62
|
self._dfs[name] = df
|
|
63
63
|
|
|
64
|
+
# Find common rows across all models
|
|
65
|
+
id_sets = {name: set(df[self.id_column]) for name, df in self._dfs.items()}
|
|
66
|
+
common_ids = set.intersection(*id_sets.values())
|
|
67
|
+
sizes = ", ".join(f"{name}: {len(ids)}" for name, ids in id_sets.items())
|
|
68
|
+
log.info(f"Row counts before alignment: {sizes} -> common: {len(common_ids)}")
|
|
69
|
+
self._dfs = {name: df[df[self.id_column].isin(common_ids)] for name, df in self._dfs.items()}
|
|
70
|
+
|
|
64
71
|
# Align DataFrames by sorting on id column
|
|
65
72
|
self._dfs = {name: df.sort_values(self.id_column).reset_index(drop=True) for name, df in self._dfs.items()}
|
|
66
73
|
log.info(f"Loaded {len(self._dfs)} models, {len(list(self._dfs.values())[0])} samples each")
|
|
@@ -372,13 +379,13 @@ class MetaModelSimulator:
|
|
|
372
379
|
return weight_df
|
|
373
380
|
|
|
374
381
|
def ensemble_failure_analysis(self) -> dict:
|
|
375
|
-
"""Compare ensemble vs best
|
|
382
|
+
"""Compare best ensemble strategy vs best individual model.
|
|
376
383
|
|
|
377
384
|
Returns:
|
|
378
385
|
Dict with comparison statistics
|
|
379
386
|
"""
|
|
380
387
|
print("\n" + "=" * 60)
|
|
381
|
-
print("ENSEMBLE VS BEST MODEL COMPARISON")
|
|
388
|
+
print("BEST ENSEMBLE VS BEST MODEL COMPARISON")
|
|
382
389
|
print("=" * 60)
|
|
383
390
|
|
|
384
391
|
model_names = list(self._dfs.keys())
|
|
@@ -393,35 +400,55 @@ class MetaModelSimulator:
|
|
|
393
400
|
combined[f"{name}_abs_err"] = df["abs_residual"].values
|
|
394
401
|
|
|
395
402
|
pred_cols = [f"{name}_pred" for name in model_names]
|
|
403
|
+
conf_cols = [f"{name}_conf" for name in model_names]
|
|
404
|
+
pred_arr = combined[pred_cols].values
|
|
405
|
+
conf_arr = combined[conf_cols].values
|
|
396
406
|
|
|
397
|
-
# Calculate ensemble prediction (inverse-MAE weighted)
|
|
398
407
|
mae_scores = {name: self._dfs[name]["abs_residual"].mean() for name in model_names}
|
|
399
408
|
inv_mae_weights = np.array([1.0 / mae_scores[name] for name in model_names])
|
|
400
409
|
inv_mae_weights = inv_mae_weights / inv_mae_weights.sum()
|
|
401
|
-
pred_arr = combined[pred_cols].values
|
|
402
|
-
combined["ensemble_pred"] = (pred_arr * inv_mae_weights).sum(axis=1)
|
|
403
|
-
combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
|
|
404
410
|
|
|
405
|
-
#
|
|
411
|
+
# Compute all ensemble strategies (true ensembles that combine multiple models)
|
|
412
|
+
ensemble_strategies = {}
|
|
413
|
+
ensemble_strategies["Simple Mean"] = combined[pred_cols].mean(axis=1)
|
|
414
|
+
conf_sum = conf_arr.sum(axis=1, keepdims=True) + 1e-8
|
|
415
|
+
ensemble_strategies["Confidence-Weighted"] = (pred_arr * (conf_arr / conf_sum)).sum(axis=1)
|
|
416
|
+
ensemble_strategies["Inverse-MAE Weighted"] = (pred_arr * inv_mae_weights).sum(axis=1)
|
|
417
|
+
scaled_conf = conf_arr * inv_mae_weights
|
|
418
|
+
scaled_conf_sum = scaled_conf.sum(axis=1, keepdims=True) + 1e-8
|
|
419
|
+
ensemble_strategies["Scaled Conf-Weighted"] = (pred_arr * (scaled_conf / scaled_conf_sum)).sum(axis=1)
|
|
420
|
+
worst_model = max(mae_scores, key=mae_scores.get)
|
|
421
|
+
remaining = [n for n in model_names if n != worst_model]
|
|
422
|
+
remaining_cols = [f"{n}_pred" for n in remaining]
|
|
423
|
+
# Only add Drop Worst if it still combines multiple models
|
|
424
|
+
if len(remaining) > 1:
|
|
425
|
+
ensemble_strategies[f"Drop Worst ({worst_model})"] = combined[remaining_cols].mean(axis=1)
|
|
426
|
+
|
|
427
|
+
# Find best individual model
|
|
406
428
|
best_model = min(mae_scores, key=mae_scores.get)
|
|
407
429
|
combined["best_model_abs_err"] = combined[f"{best_model}_abs_err"]
|
|
430
|
+
best_model_mae = mae_scores[best_model]
|
|
408
431
|
|
|
409
|
-
#
|
|
432
|
+
# Find best true ensemble strategy
|
|
433
|
+
strategy_maes = {name: (preds - combined["target"]).abs().mean() for name, preds in ensemble_strategies.items()}
|
|
434
|
+
best_strategy = min(strategy_maes, key=strategy_maes.get)
|
|
435
|
+
combined["ensemble_pred"] = ensemble_strategies[best_strategy]
|
|
436
|
+
combined["ensemble_abs_err"] = (combined["ensemble_pred"] - combined["target"]).abs()
|
|
437
|
+
ensemble_mae = strategy_maes[best_strategy]
|
|
438
|
+
|
|
439
|
+
# Compare
|
|
410
440
|
combined["ensemble_better"] = combined["ensemble_abs_err"] < combined["best_model_abs_err"]
|
|
411
441
|
n_better = combined["ensemble_better"].sum()
|
|
412
442
|
n_total = len(combined)
|
|
413
443
|
|
|
414
|
-
ensemble_mae = combined["ensemble_abs_err"].mean()
|
|
415
|
-
best_model_mae = mae_scores[best_model]
|
|
416
|
-
|
|
417
444
|
print(f"\nBest individual model: {best_model} (MAE={best_model_mae:.4f})")
|
|
418
|
-
print(f"
|
|
445
|
+
print(f"Best ensemble strategy: {best_strategy} (MAE={ensemble_mae:.4f})")
|
|
419
446
|
if ensemble_mae < best_model_mae:
|
|
420
447
|
improvement = (best_model_mae - ensemble_mae) / best_model_mae * 100
|
|
421
448
|
print(f"Ensemble improves over best model by {improvement:.1f}%")
|
|
422
449
|
else:
|
|
423
450
|
degradation = (ensemble_mae - best_model_mae) / best_model_mae * 100
|
|
424
|
-
print(f"
|
|
451
|
+
print(f"No ensemble benefit: best single model outperforms all ensemble strategies by {degradation:.1f}%")
|
|
425
452
|
|
|
426
453
|
print("\nPer-row comparison:")
|
|
427
454
|
print(f" Ensemble wins: {n_better}/{n_total} ({100*n_better/n_total:.1f}%)")
|
|
@@ -443,6 +470,7 @@ class MetaModelSimulator:
|
|
|
443
470
|
|
|
444
471
|
return {
|
|
445
472
|
"ensemble_mae": ensemble_mae,
|
|
473
|
+
"best_strategy": best_strategy,
|
|
446
474
|
"best_model": best_model,
|
|
447
475
|
"best_model_mae": best_model_mae,
|
|
448
476
|
"ensemble_win_rate": n_better / n_total,
|
workbench/utils/shap_utils.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple, Dict, Union
|
|
|
9
9
|
from workbench.utils.xgboost_model_utils import xgboost_model_from_s3
|
|
10
10
|
from workbench.utils.model_utils import load_category_mappings_from_s3
|
|
11
11
|
from workbench.utils.pandas_utils import convert_categorical_types
|
|
12
|
+
from workbench.model_script_utils.model_script_utils import decompress_features
|
|
12
13
|
|
|
13
14
|
# Set up the log
|
|
14
15
|
log = logging.getLogger("workbench")
|
|
@@ -111,61 +112,6 @@ def shap_values_data(
|
|
|
111
112
|
return result_df, feature_df
|
|
112
113
|
|
|
113
114
|
|
|
114
|
-
def decompress_features(
|
|
115
|
-
df: pd.DataFrame, features: List[str], compressed_features: List[str]
|
|
116
|
-
) -> Tuple[pd.DataFrame, List[str]]:
|
|
117
|
-
"""Prepare features for the XGBoost model
|
|
118
|
-
|
|
119
|
-
Args:
|
|
120
|
-
df (pd.DataFrame): The features DataFrame
|
|
121
|
-
features (List[str]): Full list of feature names
|
|
122
|
-
compressed_features (List[str]): List of feature names to decompress (bitstrings)
|
|
123
|
-
|
|
124
|
-
Returns:
|
|
125
|
-
pd.DataFrame: DataFrame with the decompressed features
|
|
126
|
-
List[str]: Updated list of feature names after decompression
|
|
127
|
-
|
|
128
|
-
Raises:
|
|
129
|
-
ValueError: If any missing values are found in the specified features
|
|
130
|
-
"""
|
|
131
|
-
|
|
132
|
-
# Check for any missing values in the required features
|
|
133
|
-
missing_counts = df[features].isna().sum()
|
|
134
|
-
if missing_counts.any():
|
|
135
|
-
missing_features = missing_counts[missing_counts > 0]
|
|
136
|
-
print(
|
|
137
|
-
f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
|
|
138
|
-
"WARNING: You might want to remove/replace all NaN values before processing."
|
|
139
|
-
)
|
|
140
|
-
|
|
141
|
-
# Decompress the specified compressed features
|
|
142
|
-
decompressed_features = features
|
|
143
|
-
for feature in compressed_features:
|
|
144
|
-
if (feature not in df.columns) or (feature not in features):
|
|
145
|
-
print(f"Feature '{feature}' not in the features list, skipping decompression.")
|
|
146
|
-
continue
|
|
147
|
-
|
|
148
|
-
# Remove the feature from the list of features to avoid duplication
|
|
149
|
-
decompressed_features.remove(feature)
|
|
150
|
-
|
|
151
|
-
# Handle all compressed features as bitstrings
|
|
152
|
-
bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
|
|
153
|
-
prefix = feature[:3]
|
|
154
|
-
|
|
155
|
-
# Create all new columns at once - avoids fragmentation
|
|
156
|
-
new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
|
|
157
|
-
new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
|
|
158
|
-
|
|
159
|
-
# Add to features list
|
|
160
|
-
decompressed_features.extend(new_col_names)
|
|
161
|
-
|
|
162
|
-
# Drop original column and concatenate new ones
|
|
163
|
-
df = df.drop(columns=[feature])
|
|
164
|
-
df = pd.concat([df, new_df], axis=1)
|
|
165
|
-
|
|
166
|
-
return df, decompressed_features
|
|
167
|
-
|
|
168
|
-
|
|
169
115
|
def _calculate_shap_values(workbench_model, sample_df: pd.DataFrame = None):
|
|
170
116
|
"""
|
|
171
117
|
Internal function to calculate SHAP values for Workbench Models.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: workbench
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.219
|
|
4
4
|
Summary: Workbench: A Dashboard and Python API for creating and deploying AWS SageMaker Model Pipelines
|
|
5
5
|
Author-email: SuperCowPowers LLC <support@supercowpowers.com>
|
|
6
6
|
License: MIT License
|
|
@@ -26,7 +26,7 @@ workbench/algorithms/sql/__init__.py,sha256=TbOZQwCfx6Tjc3pCCLCiM31wpCX26j5MBNQ6
|
|
|
26
26
|
workbench/algorithms/sql/column_stats.py,sha256=IwgddvPVITdAvUgxaK_px2IVSEX-jA-8cPIVFoVkbN8,5943
|
|
27
27
|
workbench/algorithms/sql/correlations.py,sha256=0DMgAkzIdR0cApQ_5vs4CxPSRz1qItcAToz7GAOFqzI,3935
|
|
28
28
|
workbench/algorithms/sql/descriptive_stats.py,sha256=VxSR5zQi8NmAWrJvOCO3wrmgVHYrwhenSy5Gl0AOqoo,4075
|
|
29
|
-
workbench/algorithms/sql/outliers.py,sha256=
|
|
29
|
+
workbench/algorithms/sql/outliers.py,sha256=LbOYaE3bNR4x-aEIrA2KAX3Aq07ZowRgrW9buCeKisQ,10663
|
|
30
30
|
workbench/algorithms/sql/sample_rows.py,sha256=SRYoGb24QP_iPvOoW9bGZ95yZuseYDtyoNhilfoLu34,2688
|
|
31
31
|
workbench/algorithms/sql/value_counts.py,sha256=F-rZoLTTKv1cHYl2_tDlvWDjczy76uLTr3EMHa-WrEk,3340
|
|
32
32
|
workbench/api/__init__.py,sha256=1JAQKD82biia4h07BRA9ytjxuJUYQqgHvkf8FwpnlVQ,1195
|
|
@@ -58,8 +58,8 @@ workbench/core/artifacts/data_capture_core.py,sha256=q8f79rRTYiZ7T4IQRWXl8ZvPpcv
|
|
|
58
58
|
workbench/core/artifacts/data_source_abstract.py,sha256=5IRCzFVK-17cd4NXPMRfx99vQAmQ0WHE5jcm5RfsVTg,10619
|
|
59
59
|
workbench/core/artifacts/data_source_factory.py,sha256=YL_tA5fsgubbB3dPF6T4tO0rGgz-6oo3ge4i_YXVC-M,2380
|
|
60
60
|
workbench/core/artifacts/df_store_core.py,sha256=AueNr_JvuLLu_ByE7cb3u-isH9u0Q7cMP-UCgCX-Ctg,3536
|
|
61
|
-
workbench/core/artifacts/endpoint_core.py,sha256=
|
|
62
|
-
workbench/core/artifacts/feature_set_core.py,sha256=
|
|
61
|
+
workbench/core/artifacts/endpoint_core.py,sha256=fLOxgwNmbsrOpKafXN8zLCzazKdpJQZr2zanKJ14KRc,54057
|
|
62
|
+
workbench/core/artifacts/feature_set_core.py,sha256=zR6gia7V6JeUHaKYzQRGQwF1j0Z5DBcM8oqGPS1pox4,39344
|
|
63
63
|
workbench/core/artifacts/model_core.py,sha256=wPkpdRlxnAXMqsDtJGPotGFO146Hm7NCfYbImHwZo9c,52343
|
|
64
64
|
workbench/core/artifacts/monitor_core.py,sha256=M307yz7tEzOEHgv-LmtVy9jKjSbM98fHW3ckmNYrwlU,27897
|
|
65
65
|
workbench/core/artifacts/parameter_store_core.py,sha256=sHvjJMuybM4qdcKhH-Sx6Ur6Yn5ozA3QHwtidsnhyG8,2867
|
|
@@ -125,22 +125,21 @@ workbench/core/views/training_view.py,sha256=7HwhbQhDBhT3Zo_gssS-b4eueJ0h9nqqT8Y
|
|
|
125
125
|
workbench/core/views/view.py,sha256=DvmEA1xdvL980GET_cnbmHzqSy6IhlNaZcoQnVTtYis,13534
|
|
126
126
|
workbench/core/views/view_utils.py,sha256=CwOlpqXpumCr6REi-ey7Qjz5_tpg-s4oWHmlOVu8POQ,12270
|
|
127
127
|
workbench/core/views/storage/mdq_view.py,sha256=qf_ep1KwaXOIfO930laEwNIiCYP7VNOqjE3VdHfopRE,5195
|
|
128
|
-
workbench/model_script_utils/model_script_utils.py,sha256=
|
|
128
|
+
workbench/model_script_utils/model_script_utils.py,sha256=rGPdjxmQUPcZNXK_8nKYQWb7IPQ5ietne7UMYRQZpMo,11841
|
|
129
129
|
workbench/model_script_utils/pytorch_utils.py,sha256=vr8ybK45U0H8Jhjb5qx6xbJNozdcl7bVqubknDwh6U0,13704
|
|
130
130
|
workbench/model_script_utils/uq_harness.py,sha256=70b7dI9Wls03ff6zm2TpfKIsboVBKsj7P7fNzmMe6c0,10305
|
|
131
131
|
workbench/model_scripts/script_generation.py,sha256=w3L2VYGnGUvBtd01BWzH38DuHKULtYsc_Xz_3_Eavvo,8258
|
|
132
|
-
workbench/model_scripts/chemprop/chemprop.template,sha256=
|
|
133
|
-
workbench/model_scripts/chemprop/generated_model_script.py,sha256=
|
|
134
|
-
workbench/model_scripts/chemprop/model_script_utils.py,sha256=
|
|
132
|
+
workbench/model_scripts/chemprop/chemprop.template,sha256=EF1otxEJGPKm_iZibbWBUvjWhQY0G8jnPK8d_A7OnS8,29416
|
|
133
|
+
workbench/model_scripts/chemprop/generated_model_script.py,sha256=4WqqqkUlUSf1EEgzZk-OAFSwoif5drjwitEko0rlI38,30093
|
|
134
|
+
workbench/model_scripts/chemprop/model_script_utils.py,sha256=rGPdjxmQUPcZNXK_8nKYQWb7IPQ5ietne7UMYRQZpMo,11841
|
|
135
135
|
workbench/model_scripts/chemprop/requirements.txt,sha256=2IBHZZNYqhX9Ed7AmRVgN06tO3EHeBbN2EM8-tjWZhs,216
|
|
136
136
|
workbench/model_scripts/custom_models/chem_info/Readme.md,sha256=mH1lxJ4Pb7F5nBnVXaiuxpi8zS_yjUw_LBJepVKXhlA,574
|
|
137
|
-
workbench/model_scripts/custom_models/chem_info/fingerprints.py,sha256=
|
|
137
|
+
workbench/model_scripts/custom_models/chem_info/fingerprints.py,sha256=ECDzjZs4wSx3ZvAQipMl2NEqI2isCWHLYBv7mp0NVgk,6939
|
|
138
138
|
workbench/model_scripts/custom_models/chem_info/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
|
|
139
139
|
workbench/model_scripts/custom_models/chem_info/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
|
|
140
140
|
workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py,sha256=xljMjdfh4Idi4v1Afq1zZxvF1SDa7pDOLSAhvGBEj88,2891
|
|
141
141
|
workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py,sha256=LqVh_AHObo0uxHt_uNmeemScTLjM2j9C3I_QFJXdmUI,3232
|
|
142
142
|
workbench/model_scripts/custom_models/chem_info/requirements.txt,sha256=7HBUzvNiM8lOir-UfQabXYlUp3gxdGJ42u18EuSMGjc,39
|
|
143
|
-
workbench/model_scripts/custom_models/meta_endpoints/example.py,sha256=hzOAuLhIGB8vei-555ruNxpsE1GhuByHGjGB0zw8GSs,1726
|
|
144
143
|
workbench/model_scripts/custom_models/network_security/Readme.md,sha256=Z2gtiu0hLHvEJ1x-_oFq3qJZcsK81sceBAGAGltpqQ8,222
|
|
145
144
|
workbench/model_scripts/custom_models/proximity/Readme.md,sha256=RlMFAJZgAT2mCgDk-UwR_R0Y_NbCqeI5-8DUsxsbpWQ,289
|
|
146
145
|
workbench/model_scripts/custom_models/proximity/feature_space_proximity.py,sha256=FYsQd5Lf5CrSWi-1Dcs_NVFN86izifxkWk1-EOvEV54,6950
|
|
@@ -151,7 +150,6 @@ workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template,sha256=c
|
|
|
151
150
|
workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template,sha256=449Enh4-7RrMrxt1oS_SHJHGV8yYcFlWHsLrCVTFQGI,13778
|
|
152
151
|
workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py,sha256=FYsQd5Lf5CrSWi-1Dcs_NVFN86izifxkWk1-EOvEV54,6950
|
|
153
152
|
workbench/model_scripts/custom_models/uq_models/gaussian_process.template,sha256=3nMlCi8nEbc4N-MQTzjfIcljfDQkUmWeLBfmd18m5fg,6632
|
|
154
|
-
workbench/model_scripts/custom_models/uq_models/meta_uq.template,sha256=wLilHll9Hzwyo-y9Vsqx7PjzdMca4xkUt3Ed1zcgOBE,14412
|
|
155
153
|
workbench/model_scripts/custom_models/uq_models/ngboost.template,sha256=_ukYcsL4pnWvFV1oA89_wfVpxWbvoEx6MGwKxc38kSI,8512
|
|
156
154
|
workbench/model_scripts/custom_models/uq_models/requirements.txt,sha256=fw7T7t_YJAXK3T6Ysbesxh_Agx_tv0oYx72cEBTqRDY,98
|
|
157
155
|
workbench/model_scripts/custom_script_example/custom_model_script.py,sha256=T8aydawgRVAdSlDimoWpXxG2YuWWQkbcjBVjAeSG2_0,6408
|
|
@@ -160,8 +158,8 @@ workbench/model_scripts/ensemble_xgb/ensemble_xgb.template,sha256=lMEx0IkawcpTI5
|
|
|
160
158
|
workbench/model_scripts/ensemble_xgb/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
|
|
161
159
|
workbench/model_scripts/meta_model/generated_model_script.py,sha256=ncPrHd9-R8l_98vAiuTUJ92C9PKpEgAtpIrmd7TuqSQ,8341
|
|
162
160
|
workbench/model_scripts/meta_model/meta_model.template,sha256=viz-AKVq3YRwOUBt8-rUO1TwdEPFzyP7nnifqcIJurw,8244
|
|
163
|
-
workbench/model_scripts/pytorch_model/generated_model_script.py,sha256=
|
|
164
|
-
workbench/model_scripts/pytorch_model/model_script_utils.py,sha256=
|
|
161
|
+
workbench/model_scripts/pytorch_model/generated_model_script.py,sha256=ma2JOiCxCZfq94jvIsDoCa2VQBwKf-trj9QMpaa0VEQ,21108
|
|
162
|
+
workbench/model_scripts/pytorch_model/model_script_utils.py,sha256=rGPdjxmQUPcZNXK_8nKYQWb7IPQ5ietne7UMYRQZpMo,11841
|
|
165
163
|
workbench/model_scripts/pytorch_model/pytorch.template,sha256=KOH7nhq_3u0pHmjGymY5aycF0_ZlwLQ16qmDKUQcE9k,21091
|
|
166
164
|
workbench/model_scripts/pytorch_model/pytorch_utils.py,sha256=vr8ybK45U0H8Jhjb5qx6xbJNozdcl7bVqubknDwh6U0,13704
|
|
167
165
|
workbench/model_scripts/pytorch_model/requirements.txt,sha256=ES7YehHEL4E5oV8FScHm3oNQmkMI4ODgbC1fSbaY7T4,183
|
|
@@ -170,8 +168,8 @@ workbench/model_scripts/scikit_learn/generated_model_script.py,sha256=xhQIglpAgP
|
|
|
170
168
|
workbench/model_scripts/scikit_learn/requirements.txt,sha256=aVvwiJ3LgBUhM_PyFlb2gHXu_kpGPho3ANBzlOkfcvs,107
|
|
171
169
|
workbench/model_scripts/scikit_learn/scikit_learn.template,sha256=QQvqx-eX9ZTbYmyupq6R6vIQwosmsmY_MRBPaHyfjdk,12586
|
|
172
170
|
workbench/model_scripts/uq_models/generated_model_script.py,sha256=kgcIWghY6eazcBWS77MukhQUyYFmfJcS8SQ8RmjM82I,9006
|
|
173
|
-
workbench/model_scripts/xgb_model/generated_model_script.py,sha256=
|
|
174
|
-
workbench/model_scripts/xgb_model/model_script_utils.py,sha256=
|
|
171
|
+
workbench/model_scripts/xgb_model/generated_model_script.py,sha256=VkgU9jXvWzTjPsq9JoIRJGKYJE-aj3-z7gTOc5f6hH4,18376
|
|
172
|
+
workbench/model_scripts/xgb_model/model_script_utils.py,sha256=rGPdjxmQUPcZNXK_8nKYQWb7IPQ5ietne7UMYRQZpMo,11841
|
|
175
173
|
workbench/model_scripts/xgb_model/requirements.txt,sha256=jWlGc7HH7vqyukTm38LN4EyDi8jDUPEay4n45z-30uc,104
|
|
176
174
|
workbench/model_scripts/xgb_model/uq_harness.py,sha256=70b7dI9Wls03ff6zm2TpfKIsboVBKsj7P7fNzmMe6c0,10305
|
|
177
175
|
workbench/model_scripts/xgb_model/xgb_model.template,sha256=w4-yx82yws-_esObZQIq13S8WKXXnZxqe86ZuyWoP5w,18367
|
|
@@ -183,6 +181,7 @@ workbench/scripts/check_double_bond_stereo.py,sha256=p5hnL54Weq77ES0HCELq9JeoM-P
|
|
|
183
181
|
workbench/scripts/endpoint_test.py,sha256=RV52DZZTOD_ou-ywZjaxQ2_wqnSJqvlnHQZbvf4iM6I,5339
|
|
184
182
|
workbench/scripts/glue_launcher.py,sha256=bIKQvfGxpAhzbeNvTnHfRW_5kQhY-169_868ZnCejJk,10692
|
|
185
183
|
workbench/scripts/lambda_test.py,sha256=SLAPIXeGQn82neQ6-Hif3VS3LWLwT0-dGw8yWw2aXRQ,2077
|
|
184
|
+
workbench/scripts/meta_model_sim.py,sha256=6iGpInA-nH6DSjk0z63fcoL8P7icqnZmKLE5Sqyrh7E,1026
|
|
186
185
|
workbench/scripts/ml_pipeline_batch.py,sha256=1T5JnLlUJR7bwAGBLHmLPOuj1xFRqVIQX8PsuDhHy8o,4907
|
|
187
186
|
workbench/scripts/ml_pipeline_sqs.py,sha256=5c8qX-SoV4htOUcSXk4OzD7BQskCnaA7cLMiF4Et24c,6666
|
|
188
187
|
workbench/scripts/monitor_cloud_watch.py,sha256=s7MY4bsHts0nup9G0lWESCvgJZ9Mw1Eo-c8aKRgLjMw,9235
|
|
@@ -236,7 +235,7 @@ workbench/utils/lambda_utils.py,sha256=7GhGRPyXn9o-toWb9HBGSnI8-DhK9YRkwhCSk_mNK
|
|
|
236
235
|
workbench/utils/license_manager.py,sha256=lNE9zZIglmX3zqqCKBdN1xqTgHCEZgJDxavF6pdG7fc,6825
|
|
237
236
|
workbench/utils/log_utils.py,sha256=7n1NJXO_jUX82e6LWAQug6oPo3wiPDBYsqk9gsYab_A,3167
|
|
238
237
|
workbench/utils/markdown_utils.py,sha256=4lEqzgG4EVmLcvvKKNUwNxVCySLQKJTJmWDiaDroI1w,8306
|
|
239
|
-
workbench/utils/meta_model_simulator.py,sha256=
|
|
238
|
+
workbench/utils/meta_model_simulator.py,sha256=fMKZoLi_VEJohNVvbZSMvZWNdUbIpGlB6Bg6mJQW33s,20630
|
|
240
239
|
workbench/utils/metrics_utils.py,sha256=iAoKrAM4iRX8wFSjSJhfNKbbW1BqB3eI_U3wvdhUdhE,9496
|
|
241
240
|
workbench/utils/model_utils.py,sha256=jiybuv6gGE-p2i2JEQcyAY-ffigtuzZFNvp_rHKCi3A,19284
|
|
242
241
|
workbench/utils/monitor_utils.py,sha256=kVaJ7BgUXs3VPMFYfLC03wkIV4Dq-pEhoXS0wkJFxCc,7858
|
|
@@ -250,7 +249,7 @@ workbench/utils/pytorch_utils.py,sha256=RoltE9-fOX2UixzaEmnxN6oJtBEKQ9Jklu0LRzYK
|
|
|
250
249
|
workbench/utils/redis_cache.py,sha256=39LFSWmOlNNcah02D3sBnmibc-DPeKC3SNq71K4HaB4,12893
|
|
251
250
|
workbench/utils/repl_utils.py,sha256=rWOMv2HiEIp8ZL6Ps6DlwiJlGr-pOhv9OZQhm3aR-1A,4668
|
|
252
251
|
workbench/utils/s3_utils.py,sha256=Xme_o_cftC_jWnw6R9YKS6-6C11zaCBAoQDlY3dZb5o,7337
|
|
253
|
-
workbench/utils/shap_utils.py,sha256=
|
|
252
|
+
workbench/utils/shap_utils.py,sha256=FeFNRH5mJTbuHlpHyFJgjHcU5BU7UthJL1Gb5Gl8_zw,10590
|
|
254
253
|
workbench/utils/shapley_values.py,sha256=3DvQz4HIPnxW42idgtuQ5vtzU-oF4_lToaWzLRjU-E4,3673
|
|
255
254
|
workbench/utils/symbols.py,sha256=PioF1yAQyOabw7kLg8nhvaZBPFe7ABkpfpPPE0qz_2k,1265
|
|
256
255
|
workbench/utils/test_data_generator.py,sha256=gqRXL7IUKG4wVfO1onflY3wg7vLkgx402_Zy3iqY7NU,11921
|
|
@@ -264,7 +263,7 @@ workbench/utils/workbench_sqs.py,sha256=RwM80z7YWwdtMaCKh7KWF8v38f7eBRU7kyC7ZhTR
|
|
|
264
263
|
workbench/utils/xgboost_local_crossfold.py,sha256=GY61F6-avQDiteIb1LAgvkHvAKvLg6H85xBDvfgCVDM,10718
|
|
265
264
|
workbench/utils/xgboost_model_utils.py,sha256=qEnB1viCIXMYLW0LJuyCioKMSilbmKTMuppaxBZqwhc,12967
|
|
266
265
|
workbench/utils/chem_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
267
|
-
workbench/utils/chem_utils/fingerprints.py,sha256=
|
|
266
|
+
workbench/utils/chem_utils/fingerprints.py,sha256=ECDzjZs4wSx3ZvAQipMl2NEqI2isCWHLYBv7mp0NVgk,6939
|
|
268
267
|
workbench/utils/chem_utils/misc.py,sha256=Nevf8_opu-uIPrv_1_0ubuFVVo2_fGUkMoLAHB3XAeo,7372
|
|
269
268
|
workbench/utils/chem_utils/mol_descriptors.py,sha256=c8gkHZ-8s3HJaW9zN9pnYGK7YVW8Y0xFqQ1G_ysrF2Y,18789
|
|
270
269
|
workbench/utils/chem_utils/mol_standardize.py,sha256=qPLCdVMSXMOWN-01O1isg2zq7eQyFAI0SNatHkRq1uw,17524
|
|
@@ -307,9 +306,9 @@ workbench/web_interface/page_views/main_page.py,sha256=X4-KyGTKLAdxR-Zk2niuLJB2Y
|
|
|
307
306
|
workbench/web_interface/page_views/models_page_view.py,sha256=M0bdC7bAzLyIaE2jviY12FF4abdMFZmg6sFuOY_LaGI,2650
|
|
308
307
|
workbench/web_interface/page_views/page_view.py,sha256=Gh6YnpOGlUejx-bHZAf5pzqoQ1H1R0OSwOpGhOBO06w,455
|
|
309
308
|
workbench/web_interface/page_views/pipelines_page_view.py,sha256=v2pxrIbsHBcYiblfius3JK766NZ7ciD2yPx0t3E5IJo,2656
|
|
310
|
-
workbench-0.8.
|
|
311
|
-
workbench-0.8.
|
|
312
|
-
workbench-0.8.
|
|
313
|
-
workbench-0.8.
|
|
314
|
-
workbench-0.8.
|
|
315
|
-
workbench-0.8.
|
|
309
|
+
workbench-0.8.219.dist-info/licenses/LICENSE,sha256=RTBoTMeEwTgEhS-n8vgQ-VUo5qig0PWVd8xFPKU6Lck,1080
|
|
310
|
+
workbench-0.8.219.dist-info/METADATA,sha256=1Sks6KYtjjg1QqIH6p4Q8d9Dazr3EfuQdcvv0wgsXgE,10525
|
|
311
|
+
workbench-0.8.219.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
312
|
+
workbench-0.8.219.dist-info/entry_points.txt,sha256=t_9tY7iYku9z96qFZZtUgbWDh_nHtehXxLPLBSpAzeM,566
|
|
313
|
+
workbench-0.8.219.dist-info/top_level.txt,sha256=Dhy72zTxaA_o_yRkPZx5zw-fwumnjGaeGf0hBN3jc_w,10
|
|
314
|
+
workbench-0.8.219.dist-info/RECORD,,
|
|
@@ -3,6 +3,7 @@ cloud_watch = workbench.scripts.monitor_cloud_watch:main
|
|
|
3
3
|
endpoint_test = workbench.scripts.endpoint_test:main
|
|
4
4
|
glue_launcher = workbench.scripts.glue_launcher:main
|
|
5
5
|
lambda_test = workbench.scripts.lambda_test:main
|
|
6
|
+
meta_model_sim = workbench.scripts.meta_model_sim:main
|
|
6
7
|
ml_pipeline_batch = workbench.scripts.ml_pipeline_batch:main
|
|
7
8
|
ml_pipeline_sqs = workbench.scripts.ml_pipeline_sqs:main
|
|
8
9
|
training_test = workbench.scripts.training_test:main
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# Model: Meta Endpoint Example
|
|
2
|
-
# This script is a template for creating a custom meta endpoint in AWS Workbench.
|
|
3
|
-
from io import StringIO
|
|
4
|
-
import pandas as pd
|
|
5
|
-
import json
|
|
6
|
-
|
|
7
|
-
# Workbench Bridges imports
|
|
8
|
-
try:
|
|
9
|
-
from workbench_bridges.endpoints.fast_inference import fast_inference
|
|
10
|
-
except ImportError:
|
|
11
|
-
print("workbench_bridges not found, this is fine for training...")
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
# Not Used: We need to define this function for SageMaker
|
|
15
|
-
def model_fn(model_dir):
|
|
16
|
-
return None
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def input_fn(input_data, content_type):
|
|
20
|
-
"""Parse input data and return a DataFrame."""
|
|
21
|
-
if not input_data:
|
|
22
|
-
raise ValueError("Empty input data is not supported!")
|
|
23
|
-
|
|
24
|
-
# Decode bytes to string if necessary
|
|
25
|
-
if isinstance(input_data, bytes):
|
|
26
|
-
input_data = input_data.decode("utf-8")
|
|
27
|
-
|
|
28
|
-
# Support CSV and JSON input formats
|
|
29
|
-
if "text/csv" in content_type:
|
|
30
|
-
return pd.read_csv(StringIO(input_data))
|
|
31
|
-
elif "application/json" in content_type:
|
|
32
|
-
return pd.DataFrame(json.loads(input_data)) # Assumes JSON array of records
|
|
33
|
-
else:
|
|
34
|
-
raise ValueError(f"{content_type} not supported!")
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def output_fn(output_df, accept_type):
|
|
38
|
-
"""Supports both CSV and JSON output formats."""
|
|
39
|
-
if "text/csv" in accept_type:
|
|
40
|
-
csv_output = output_df.to_csv(index=False)
|
|
41
|
-
return csv_output, "text/csv"
|
|
42
|
-
elif "application/json" in accept_type:
|
|
43
|
-
return output_df.to_json(orient="records"), "application/json" # JSON array of records (NaNs -> null)
|
|
44
|
-
else:
|
|
45
|
-
raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
# Prediction function
|
|
49
|
-
def predict_fn(df, model):
|
|
50
|
-
|
|
51
|
-
# Call inference on an endpoint
|
|
52
|
-
df = fast_inference("abalone-regression", df)
|
|
53
|
-
return df
|