workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of workbench has been flagged as possibly problematic.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +12 -11
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +14 -12
- workbench/api/feature_set.py +117 -11
- workbench/api/meta.py +0 -1
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +52 -21
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +49 -11
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +7 -7
- workbench/core/artifacts/data_capture_core.py +8 -1
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +323 -205
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +133 -101
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +60 -44
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +277 -0
- workbench/model_scripts/chemprop/chemprop.template +774 -0
- workbench/model_scripts/chemprop/generated_model_script.py +774 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +15 -12
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +277 -0
- workbench/model_scripts/xgb_model/xgb_model.template +367 -399
- workbench/repl/workbench_shell.py +18 -14
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_sqs.py +122 -6
- workbench/scripts/training_test.py +85 -0
- workbench/themes/dark/custom.css +59 -0
- workbench/themes/dark/plotly.json +5 -5
- workbench/themes/light/custom.css +153 -40
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +59 -0
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/mol_descriptors.py +18 -7
- workbench/utils/chem_utils/mol_standardize.py +80 -58
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chem_utils/vis.py +25 -27
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/config_manager.py +2 -6
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +274 -87
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +159 -34
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/theme_manager.py +95 -30
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -220
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +16 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -3
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +48 -80
- workbench/web_interface/components/plugins/scatter_plot.py +192 -92
- workbench/web_interface/components/settings_menu.py +184 -0
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
- {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -81,6 +81,8 @@ Usage:
 import logging
 from typing import Optional, Tuple
 import pandas as pd
+import time
+from contextlib import contextmanager
 from rdkit import Chem
 from rdkit.Chem import Mol
 from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
 RDLogger.DisableLog("rdApp.warning")
 
 
+# Helper context manager for timing
+@contextmanager
+def timer(name):
+    start = time.time()
+    yield
+    print(f"{name}: {time.time() - start:.2f}s")
+
+
 class MolStandardizer:
     """
     Streamlined molecular standardizer for ADMET preprocessing
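The timer() helper added above is the standard contextlib timing idiom: setup before yield, report after the block finishes. A minimal usage sketch (the wrapped standardize() call is illustrative; any block of work can be timed this way):

    with timer("standardize"):  # prints "standardize: <elapsed>s" when the block exits
        result = standardize(test_data)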
@@ -116,6 +126,7 @@ class MolStandardizer:
     Pipeline:
     1. Cleanup (remove Hs, disconnect metals, normalize)
     2. Get largest fragment (optional - only if remove_salts=True)
+    2a. Extract salt information BEFORE further modifications
     3. Neutralize charges
     4. Canonicalize tautomer (optional)
 
@@ -130,18 +141,24 @@ class MolStandardizer:
 
         try:
             # Step 1: Cleanup
-
-            if …
+            cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+            if cleaned_mol is None:
                 return None, None
 
+            # If not doing any transformations, return early
+            if not self.remove_salts and not self.canonicalize_tautomer:
+                return cleaned_mol, None
+
             salt_smiles = None
+            mol = cleaned_mol
 
             # Step 2: Fragment handling (conditional based on remove_salts)
             if self.remove_salts:
-                # Get parent molecule
-                parent_mol = rdMolStandardize.FragmentParent(…
+                # Get parent molecule
+                parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
                 if parent_mol:
-
+                    # Extract salt BEFORE any modifications to parent
+                    salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
                     mol = parent_mol
                 else:
                     return None, None
@@ -153,7 +170,7 @@ class MolStandardizer:
             if mol is None:
                 return None, salt_smiles
 
-            # Step 4: Canonicalize tautomer
+            # Step 4: Canonicalize tautomer (LAST STEP)
             if self.canonicalize_tautomer:
                 mol = self.tautomer_enumerator.Canonicalize(mol)
 
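For reference, self.tautomer_enumerator here is RDKit's rdMolStandardize.TautomerEnumerator; running canonicalization as the last step keeps salt extraction operating on untautomerized fragments. A minimal standalone sketch (the example molecule is illustrative, not from this diff):

    from rdkit import Chem
    from rdkit.Chem.MolStandardize import rdMolStandardize

    enumerator = rdMolStandardize.TautomerEnumerator()
    mol = Chem.MolFromSmiles("Oc1ccccn1")  # 2-hydroxypyridine
    canon = enumerator.Canonicalize(mol)   # canonical tautomer (typically the 2-pyridone form)
    print(Chem.MolToSmiles(canon))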
@@ -172,13 +189,22 @@ class MolStandardizer:
         - Mixtures: multiple large neutral organic fragments
 
         Args:
-            orig_mol: Original molecule (before FragmentParent)
-            parent_mol: Parent molecule (after FragmentParent)
+            orig_mol: Original molecule (after Cleanup, before FragmentParent)
+            parent_mol: Parent molecule (after FragmentParent, before tautomerization)
 
         Returns:
             SMILES string of salt components or None if no salts/mixture detected
         """
         try:
+            # Quick atom count check
+            if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+                return None
+
+            # Quick heavy atom difference check
+            heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+            if heavy_diff <= 0:
+                return None
+
             # Get all fragments from original molecule
             orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)
 
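An illustrative sketch of the parent/salt split these new early-exit checks short-circuit. Cleanup, FragmentParent, GetMolFrags, and GetNumAtoms are standard RDKit calls; the SMILES-comparison salt pick is a simplification of the module's _extract_salt logic, not a copy of it:

    from rdkit import Chem
    from rdkit.Chem.MolStandardize import rdMolStandardize

    mol = Chem.MolFromSmiles("CCO.CC")  # two-fragment mixture from the tests below
    cleaned = rdMolStandardize.Cleanup(mol)
    parent = rdMolStandardize.FragmentParent(cleaned)  # keeps the largest fragment (CCO)
    if cleaned.GetNumAtoms() != parent.GetNumAtoms():  # same quick check as above
        parent_smiles = Chem.MolToSmiles(parent)
        salts = [Chem.MolToSmiles(f) for f in Chem.GetMolFrags(cleaned, asMols=True)
                 if Chem.MolToSmiles(f) != parent_smiles]
        print(parent_smiles, salts)  # e.g. "CCO" with ["CC"] left over as the salt part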
@@ -268,7 +294,7 @@ def standardize(
     if "orig_smiles" not in result.columns:
         result["orig_smiles"] = result[smiles_column]
 
-    # Initialize standardizer
+    # Initialize standardizer
     standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)
 
     def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
             log.error("Encountered missing or empty SMILES string")
             return pd.Series({"smiles": None, "salt": None})
 
+        # Early check for unreasonably long SMILES
+        if len(smiles) > 1000:
+            log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+            return pd.Series({"smiles": None, "salt": None})
+
         # Parse molecule
         mol = Chem.MolFromSmiles(smiles)
         if mol is None:
@@ -299,7 +330,9 @@ def standardize(
         if std_mol is not None:
             # Check if molecule is reasonable
             if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200:  # Arbitrary limits
-                log.error(f"…
+                log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+                log.error(f"Original SMILES: {smiles}")
+                return pd.Series({"smiles": None, "salt": salt_smiles})
 
         if std_mol is None:
             return pd.Series(
@@ -325,8 +358,11 @@ def standardize(
 
 
 if __name__ == "__main__":
-
-
+
+    # Pandas display options for better readability
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+    pd.set_option("display.max_colwidth", 100)
 
     # Test with DataFrame including various salt forms
     test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
     )
 
     # General test
+    print("Testing standardization with full dataset...")
     standardize(test_data)
 
     # Remove the last two rows to avoid errors with None and INVALID
     test_data = test_data.iloc[:-2].reset_index(drop=True)
 
     # Test WITHOUT salt removal (keeps full molecule)
-    print("\nStandardization KEEPING salts (extract_salts=False):")
-    print("This preserves the full molecule including counterions")
+    print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
     result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
-
-    print(result_keep[…
+    display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+    print(result_keep[display_order])
 
     # Test WITH salt removal
     print("\n" + "=" * 70)
     print("Standardization REMOVING salts (extract_salts=True):")
-    print("This extracts parent molecule and records salt information")
     result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
-    print(result_remove[…
+    print(result_remove[display_order])
 
-    # Test …
+    # Test with problematic cases specifically
     print("\n" + "=" * 70)
-    print("…
-
-
+    print("Testing specific problematic cases:")
+    problem_cases = pd.DataFrame(
+        {
+            "smiles": [
+                "CC(=O)O.CCN",  # Should extract CC(=O)O as salt
+                "CCO.CC",  # Should return CC as salt
+            ],
+            "compound_id": ["TEST_C002", "TEST_C005"],
+        }
+    )
+
+    problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+    print(problem_result[display_order])
+
+    # Performance test with larger dataset
+    from workbench.api import DataSource
 
-    # Show the difference for salt-containing molecules
-    print("\n" + "=" * 70)
-    print("Comparison showing differences:")
-    for idx, row in result_keep.iterrows():
-        keep_smiles = row["smiles"]
-        remove_smiles = result_remove.loc[idx, "smiles"]
-        no_taut_smiles = result_no_taut.loc[idx, "smiles"]
-        salt = result_remove.loc[idx, "salt"]
-
-        # Show differences when they exist
-        if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
-            print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
-            if keep_smiles != no_taut_smiles:
-                print(f" With salt + taut: {keep_smiles}")
-                print(f" With salt, no taut: {no_taut_smiles}")
-            if keep_smiles != remove_smiles:
-                print(f" Parent only + taut: {remove_smiles}")
-            if salt:
-                print(f" Extracted salt: {salt}")
-
-    # Summary statistics
     print("\n" + "=" * 70)
-    print("Summary:")
-    print(f"Total molecules: {len(result_remove)}")
-    print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
-    unique_salts = result_remove["salt"].dropna().unique()
-    print(f"Unique salts found: {unique_salts[:5].tolist()}")
 
-    # Get a real dataset from Workbench and time the standardization
     ds = DataSource("aqsol_data")
-    df = ds.pull_dataframe()[["id", "smiles"]]
-
-
-
-
-
-
-
-
+    df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+    for tautomer in [True, False]:
+        for extract in [True, False]:
+            print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+            start_time = time.time()
+            std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+            elapsed = time.time() - start_time
+            mol_per_sec = len(df) / elapsed
+            print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
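Taken together, the revised __main__ block documents the standardize() contract this module exports: the input DataFrame needs a SMILES column ("smiles" by default), and the output carries orig_smiles, the standardized smiles, and a salt column when extract_salts=True. A hedged sketch of a minimal call (column names taken from the tests above):

    df = pd.DataFrame({"smiles": ["CC(=O)O.CCN"], "compound_id": ["C001"]})
    out = standardize(df, extract_salts=True, canonicalize_tautomer=True)
    print(out[["compound_id", "orig_smiles", "smiles", "salt"]])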
@@ -0,0 +1,194 @@
+import pandas as pd
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+from sklearn.neighbors import NearestNeighbors
+from typing import List, Optional
+import logging
+
+# Workbench Imports
+from workbench.algorithms.dataframe.proximity import Proximity
+from workbench.algorithms.dataframe.projection_2d import Projection2D
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+class FeatureSpaceProximity(Proximity):
+    """Proximity computations for numeric feature spaces using Euclidean distance."""
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        id_column: str,
+        features: List[str],
+        target: Optional[str] = None,
+        include_all_columns: bool = False,
+    ):
+        """
+        Initialize the FeatureSpaceProximity class.
+
+        Args:
+            df: DataFrame containing data for neighbor computations.
+            id_column: Name of the column used as the identifier.
+            features: List of feature column names to be used for neighbor computations.
+            target: Name of the target column. Defaults to None.
+            include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
+        """
+        # Validate and filter features before calling parent init
+        self._raw_features = features
+        super().__init__(
+            df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+        )
+
+    def _prepare_data(self) -> None:
+        """Filter out non-numeric features and drop NaN rows."""
+        # Validate features
+        self.features = self._validate_features(self.df, self._raw_features)
+
+        # Drop NaN rows for the features we're using
+        self.df = self.df.dropna(subset=self.features).copy()
+
+    def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+        """Remove non-numeric features and log warnings."""
+        non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+        if non_numeric:
+            log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+        return [f for f in features if f not in non_numeric]
+
+    def _build_model(self) -> None:
+        """Standardize features and fit Nearest Neighbors model."""
+        self.scaler = StandardScaler()
+        X = self.scaler.fit_transform(self.df[self.features])
+        self.nn = NearestNeighbors().fit(X)
+
+    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+        """Transform features using the fitted scaler."""
+        return self.scaler.transform(df[self.features])
+
+    def _project_2d(self) -> None:
+        """Project the numeric features to 2D for visualization."""
+        if len(self.features) >= 2:
+            self.df = Projection2D().fit_transform(self.df, features=self.features)
+
+
+# Testing the FeatureSpaceProximity class
+if __name__ == "__main__":
+
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+
+    # Create a sample DataFrame
+    data = {
+        "ID": [1, 2, 3, 4, 5],
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
+    }
+    df = pd.DataFrame(data)
+
+    # Test the FeatureSpaceProximity class
+    features = ["Feature1", "Feature2", "Feature3"]
+    prox = FeatureSpaceProximity(df, id_column="ID", features=features)
+    print(prox.neighbors(1, n_neighbors=2))
+
+    # Test the neighbors method with radius
+    print(prox.neighbors(1, radius=2.0))
+
+    # Test with Features list
+    prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
+    print(prox.neighbors(1))
+
+    # Create a sample DataFrame
+    data = {
+        "id": ["a", "b", "c", "d", "e"],  # Testing string IDs
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
+    }
+    df = pd.DataFrame(data)
+
+    # Test with String Ids
+    prox = FeatureSpaceProximity(
+        df,
+        id_column="id",
+        features=["Feature1", "Feature2"],
+        target="target",
+        include_all_columns=True,
+    )
+    print(prox.neighbors(["a", "b"]))
+
+    # Test duplicate IDs
+    data = {
+        "id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
+    }
+    df = pd.DataFrame(data)
+    prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
+    print(df.equals(prox.df))
+
+    # Test on real data from Workbench
+    from workbench.api import FeatureSet, Model
+
+    fs = FeatureSet("aqsol_features")
+    model = Model("aqsol-regression")
+    features = model.features()
+    df = fs.pull_dataframe()
+    prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
+    print("\n" + "=" * 80)
+    print("Testing Neighbors...")
+    print("=" * 80)
+    test_id = df[fs.id_column].tolist()[0]
+    print(f"\nNeighbors for ID {test_id}:")
+    print(prox.neighbors(test_id))
+
+    print("\n" + "=" * 80)
+    print("Testing isolated_compounds...")
+    print("=" * 80)
+
+    # Test isolated data in the top 1%
+    isolated_1pct = prox.isolated(top_percent=1.0)
+    print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+    print(isolated_1pct)
+
+    # Test isolated data in the top 5%
+    isolated_5pct = prox.isolated(top_percent=5.0)
+    print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+    print(isolated_5pct)
+
+    print("\n" + "=" * 80)
+    print("Testing target_gradients...")
+    print("=" * 80)
+
+    # Test with different parameters
+    gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+    print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
+    print(gradients_1pct)
+
+    gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+    print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+    print(gradients_5pct)
+
+    # Test proximity_stats
+    print("\n" + "=" * 80)
+    print("Testing proximity_stats...")
+    print("=" * 80)
+    stats = prox.proximity_stats()
+    print(stats)
+
+    # Plot the distance distribution using pandas
+    print("\n" + "=" * 80)
+    print("Plotting distance distribution...")
+    print("=" * 80)
+    prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
+
+    # Visualize the 2D projection
+    print("\n" + "=" * 80)
+    print("Visualizing 2D Projection...")
+    print("=" * 80)
+    from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
+    from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
+
+    unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
+    unit_test.run()
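The new FeatureSpaceProximity class is built around a scale-then-nearest-neighbors pattern: StandardScaler keeps any single wide-range feature from dominating the Euclidean distance. A self-contained sketch of that core pattern, independent of the workbench Proximity base class (the data is illustrative):

    import pandas as pd
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import NearestNeighbors

    df = pd.DataFrame({"f1": [0.1, 0.2, 0.3, 0.4], "f2": [100.0, 90.0, 80.0, 70.0]})
    scaler = StandardScaler()
    X = scaler.fit_transform(df[["f1", "f2"]])
    nn = NearestNeighbors(n_neighbors=2).fit(X)
    distances, indices = nn.kneighbors(X)  # each row's nearest neighbors (itself first)
    print(distances, indices)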
@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
     "id_column": "{{id_column}}",
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "…
+    "include_all_columns": "{{include_all_columns}}",
 }
 
 from io import StringIO
@@ -18,7 +18,7 @@ import os
 import pandas as pd
 
 # Local Imports
-from …
+from feature_space_proximity import FeatureSpaceProximity
 
 
 # Function to check if dataframe is empty
@@ -61,7 +61,7 @@ if __name__ == "__main__":
     id_column = TEMPLATE_PARAMS["id_column"]
     features = TEMPLATE_PARAMS["features"]
    target = TEMPLATE_PARAMS["target"]  # Can be None for unsupervised models
-
+    include_all_columns = TEMPLATE_PARAMS["include_all_columns"]  # Defaults to False
 
     # Script arguments for input/output directories
     parser = argparse.ArgumentParser()
@@ -73,26 +73,24 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
     check_dataframe(all_df, "training_df")
 
-    # Create the …
-    model = …
+    # Create the FeatureSpaceProximity model
+    model = FeatureSpaceProximity(all_df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns)
 
     # Now serialize the model
     model.serialize(args.model_dir)
 
+
 # Model loading and prediction functions
 def model_fn(model_dir):
 
     # Deserialize the model
-    model = …
+    model = FeatureSpaceProximity.deserialize(model_dir)
     return model
 
 
@@ -14,7 +14,7 @@ import pandas as pd
 TEMPLATE_PARAMS = {
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
 }
 
 
@@ -37,7 +37,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
@@ -81,10 +81,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -109,8 +106,10 @@ if __name__ == "__main__":
     # Create and train the Regression/Confidence model
     # model = BayesianRidge()
     model = BayesianRidge(
-        alpha_1=1e-6,
-
+        alpha_1=1e-6,
+        alpha_2=1e-6,  # Noise precision
+        lambda_1=1e-6,
+        lambda_2=1e-6,  # Weight precision
         fit_intercept=True,
     )
 
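For context on the now-explicit hyperparameters: in scikit-learn's BayesianRidge, alpha_1/alpha_2 parameterize the Gamma prior over the noise precision (alpha) and lambda_1/lambda_2 the prior over the weight precision (lambda); 1e-6 keeps both priors nearly uninformative. A hedged sketch on synthetic data, also showing the per-sample predictive std BayesianRidge can emit:

    import numpy as np
    from sklearn.linear_model import BayesianRidge

    rng = np.random.default_rng(42)
    X = rng.normal(size=(200, 3))
    y = X @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.3, size=200)

    model = BayesianRidge(alpha_1=1e-6, alpha_2=1e-6, lambda_1=1e-6, lambda_2=1e-6, fit_intercept=True)
    model.fit(X, y)
    y_pred, y_std = model.predict(X, return_std=True)  # predictive mean and std
    print(y_pred[:3], y_std[:3])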
@@ -4,13 +4,10 @@ import awswrangler as wr
 import numpy as np
 
 # Model Performance Scores
-from sklearn.metrics import (
-    mean_absolute_error,
-    r2_score,
-    root_mean_squared_error
-)
+from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
 from sklearn.model_selection import KFold
 from scipy.optimize import minimize
+from scipy.stats import spearmanr
 
 from io import StringIO
 import json
@@ -23,7 +20,7 @@ TEMPLATE_PARAMS = {
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
     "train_all_data": "{{train_all_data}}",
-    "model_metrics_s3_path": "{{model_metrics_s3_path}}"
+    "model_metrics_s3_path": "{{model_metrics_s3_path}}",
 }
 
 
@@ -47,7 +44,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
@@ -90,10 +87,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -172,16 +166,14 @@
     cv_residuals = np.array(cv_residuals)
     cv_uncertainties = np.array(cv_uncertainties)
 
-
     # Optimize calibration parameters: σ_cal = a * σ_uc + b
     def neg_log_likelihood(params):
         a, b = params
         sigma_cal = a * cv_uncertainties + b
         sigma_cal = np.maximum(sigma_cal, 1e-8)  # Prevent division by zero
-        return np.sum(0.5 * np.log(2 * np.pi * sigma_cal…
+        return np.sum(0.5 * np.log(2 * np.pi * sigma_cal**2) + 0.5 * (cv_residuals**2) / (sigma_cal**2))
 
-
-    result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method='Nelder-Mead')
+    result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method="Nelder-Mead")
     cal_a, cal_b = result.x
 
     print(f"Calibration parameters: a={cal_a:.4f}, b={cal_b:.4f}")
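The calibration above fits (a, b) so that σ_cal = a·σ_uc + b maximizes the Gaussian likelihood of the cross-validation residuals, i.e. it minimizes Σ_i [0.5·log(2π·σ_cal_i²) + 0.5·r_i²/σ_cal_i²]. A self-contained sketch on synthetic data (the true a=2.0, b=0.05 are illustrative and should be roughly recovered):

    import numpy as np
    from scipy.optimize import minimize

    rng = np.random.default_rng(0)
    sigma_uc = rng.uniform(0.1, 1.0, size=500)           # uncalibrated ensemble spread
    residuals = rng.normal(scale=2.0 * sigma_uc + 0.05)  # noise actually scales as 2*sigma_uc + 0.05

    def neg_log_likelihood(params):
        a, b = params
        sigma_cal = np.maximum(a * sigma_uc + b, 1e-8)  # prevent division by zero
        return np.sum(0.5 * np.log(2 * np.pi * sigma_cal**2) + 0.5 * residuals**2 / sigma_cal**2)

    result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method="Nelder-Mead")
    print(result.x)  # expect roughly [2.0, 0.05]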
@@ -205,7 +197,9 @@
     result_df["prediction"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].mean(axis=1)
 
     # Compute uncalibrated uncertainty
-    result_df["prediction_std_uc"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].std(…
+    result_df["prediction_std_uc"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].std(
+        axis=1
+    )
 
     # Apply calibration to uncertainty
     result_df["prediction_std"] = cal_a * result_df["prediction_std_uc"] + cal_b
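A toy illustration of the aggregation above: per-member ensemble predictions live in m_* columns, the point estimate is their row-wise mean, and the uncalibrated uncertainty is their row-wise std (the values are illustrative):

    import pandas as pd

    result_df = pd.DataFrame({"m_0": [1.0, 2.0], "m_1": [1.2, 1.8], "m_2": [0.8, 2.2]})
    member_cols = [c for c in result_df.columns if c.startswith("m_")]
    result_df["prediction"] = result_df[member_cols].mean(axis=1)
    result_df["prediction_std_uc"] = result_df[member_cols].std(axis=1)
    print(result_df)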
@@ -224,11 +218,16 @@
     # Report Performance Metrics
     rmse = root_mean_squared_error(result_df[target], result_df["prediction"])
     mae = mean_absolute_error(result_df[target], result_df["prediction"])
+    medae = median_absolute_error(result_df[target], result_df["prediction"])
     r2 = r2_score(result_df[target], result_df["prediction"])
-
-
-    print(f"…
-    print(f"…
+    spearman_corr = spearmanr(result_df[target], result_df["prediction"]).correlation
+    support = len(result_df)
+    print(f"rmse: {rmse:.3f}")
+    print(f"mae: {mae:.3f}")
+    print(f"medae: {medae:.3f}")
+    print(f"r2: {r2:.3f}")
+    print(f"spearmanr: {spearman_corr:.3f}")
+    print(f"support: {support}")
 
     # Now save the models
     for name, model in models.items():
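A quick sanity check of the expanded metric set on toy data; root_mean_squared_error requires scikit-learn >= 1.4, and spearmanr comes from scipy as imported above (the arrays are illustrative):

    import numpy as np
    from scipy.stats import spearmanr
    from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error

    y_true = np.array([3.0, -0.5, 2.0, 7.0])
    y_pred = np.array([2.5, 0.0, 2.1, 7.8])
    print(f"rmse: {root_mean_squared_error(y_true, y_pred):.3f}")
    print(f"mae: {mean_absolute_error(y_true, y_pred):.3f}")
    print(f"medae: {median_absolute_error(y_true, y_pred):.3f}")
    print(f"r2: {r2_score(y_true, y_pred):.3f}")
    print(f"spearmanr: {spearmanr(y_true, y_pred).correlation:.3f}")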
@@ -352,4 +351,4 @@ def predict_fn(df, models) -> pd.DataFrame:
     df = df.reindex(sorted(df.columns), axis=1)
 
     # All done, return the DataFrame
-    return df
\ No newline at end of file
+    return df