workbench 0.8.174__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of workbench might be problematic.

Files changed (145)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +7 -7
  27. workbench/core/artifacts/data_capture_core.py +8 -1
  28. workbench/core/artifacts/df_store_core.py +114 -0
  29. workbench/core/artifacts/endpoint_core.py +323 -205
  30. workbench/core/artifacts/feature_set_core.py +249 -45
  31. workbench/core/artifacts/model_core.py +133 -101
  32. workbench/core/artifacts/parameter_store_core.py +98 -0
  33. workbench/core/cloud_platform/aws/aws_account_clamp.py +48 -2
  34. workbench/core/cloud_platform/cloud_meta.py +0 -1
  35. workbench/core/pipelines/pipeline_executor.py +1 -1
  36. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +277 -0
  45. workbench/model_scripts/chemprop/chemprop.template +774 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +18 -7
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +80 -58
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  61. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  62. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  63. workbench/model_scripts/meta_model/meta_model.template +209 -0
  64. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  65. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  66. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  67. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  68. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  69. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  70. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  71. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  72. workbench/model_scripts/script_generation.py +15 -12
  73. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  74. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  75. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  76. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  77. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  78. workbench/repl/workbench_shell.py +18 -14
  79. workbench/resources/open_source_api.key +1 -1
  80. workbench/scripts/endpoint_test.py +162 -0
  81. workbench/scripts/lambda_test.py +73 -0
  82. workbench/scripts/meta_model_sim.py +35 -0
  83. workbench/scripts/ml_pipeline_sqs.py +122 -6
  84. workbench/scripts/training_test.py +85 -0
  85. workbench/themes/dark/custom.css +59 -0
  86. workbench/themes/dark/plotly.json +5 -5
  87. workbench/themes/light/custom.css +153 -40
  88. workbench/themes/light/plotly.json +9 -9
  89. workbench/themes/midnight_blue/custom.css +59 -0
  90. workbench/utils/aws_utils.py +0 -1
  91. workbench/utils/chem_utils/fingerprints.py +87 -46
  92. workbench/utils/chem_utils/mol_descriptors.py +18 -7
  93. workbench/utils/chem_utils/mol_standardize.py +80 -58
  94. workbench/utils/chem_utils/projections.py +16 -6
  95. workbench/utils/chem_utils/vis.py +25 -27
  96. workbench/utils/chemprop_utils.py +141 -0
  97. workbench/utils/config_manager.py +2 -6
  98. workbench/utils/endpoint_utils.py +5 -7
  99. workbench/utils/license_manager.py +2 -6
  100. workbench/utils/markdown_utils.py +57 -0
  101. workbench/utils/meta_model_simulator.py +499 -0
  102. workbench/utils/metrics_utils.py +256 -0
  103. workbench/utils/model_utils.py +274 -87
  104. workbench/utils/pipeline_utils.py +0 -1
  105. workbench/utils/plot_utils.py +159 -34
  106. workbench/utils/pytorch_utils.py +87 -0
  107. workbench/utils/shap_utils.py +11 -57
  108. workbench/utils/theme_manager.py +95 -30
  109. workbench/utils/xgboost_local_crossfold.py +267 -0
  110. workbench/utils/xgboost_model_utils.py +127 -220
  111. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  112. workbench/web_interface/components/model_plot.py +16 -2
  113. workbench/web_interface/components/plugin_unit_test.py +5 -3
  114. workbench/web_interface/components/plugins/ag_table.py +2 -4
  115. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  116. workbench/web_interface/components/plugins/model_details.py +48 -80
  117. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  118. workbench/web_interface/components/settings_menu.py +184 -0
  119. workbench/web_interface/page_views/main_page.py +0 -1
  120. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  121. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/RECORD +125 -111
  122. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  123. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  124. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  125. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  126. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  127. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  128. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  129. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  130. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -502
  131. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  132. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  133. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  134. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  135. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  136. workbench/themes/quartz/base_css.url +0 -1
  137. workbench/themes/quartz/custom.css +0 -117
  138. workbench/themes/quartz/plotly.json +0 -642
  139. workbench/themes/quartz_dark/base_css.url +0 -1
  140. workbench/themes/quartz_dark/custom.css +0 -131
  141. workbench/themes/quartz_dark/plotly.json +0 -642
  142. workbench/utils/fast_inference.py +0 -167
  143. workbench/utils/resource_utils.py +0 -39
  144. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  145. {workbench-0.8.174.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -81,6 +81,8 @@ Usage:
  import logging
  from typing import Optional, Tuple
  import pandas as pd
+ import time
+ from contextlib import contextmanager
  from rdkit import Chem
  from rdkit.Chem import Mol
  from rdkit.Chem.MolStandardize import rdMolStandardize
@@ -90,6 +92,14 @@ log = logging.getLogger("workbench")
  RDLogger.DisableLog("rdApp.warning")


+ # Helper context manager for timing
+ @contextmanager
+ def timer(name):
+ start = time.time()
+ yield
+ print(f"{name}: {time.time() - start:.2f}s")
+
+
  class MolStandardizer:
  """
  Streamlined molecular standardizer for ADMET preprocessing
@@ -116,6 +126,7 @@ class MolStandardizer:
  Pipeline:
  1. Cleanup (remove Hs, disconnect metals, normalize)
  2. Get largest fragment (optional - only if remove_salts=True)
+ 2a. Extract salt information BEFORE further modifications
  3. Neutralize charges
  4. Canonicalize tautomer (optional)

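Note: the pipeline steps listed in the docstring above correspond to standard RDKit MolStandardize calls. A minimal illustrative sketch in plain RDKit (not the package's actual implementation; the molecule and SMILES are placeholders):

    from rdkit import Chem
    from rdkit.Chem.MolStandardize import rdMolStandardize

    mol = Chem.MolFromSmiles("CC(=O)O.CCN")                               # parse a salt-containing SMILES
    mol = rdMolStandardize.Cleanup(mol)                                   # 1. cleanup (remove Hs, disconnect metals, normalize)
    parent = rdMolStandardize.FragmentParent(mol)                         # 2. keep the largest fragment (salt removal)
    parent = rdMolStandardize.Uncharger().uncharge(parent)                # 3. neutralize charges
    parent = rdMolStandardize.TautomerEnumerator().Canonicalize(parent)   # 4. canonical tautomer
    print(Chem.MolToSmiles(parent))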
@@ -130,18 +141,24 @@ class MolStandardizer:

  try:
  # Step 1: Cleanup
- mol = rdMolStandardize.Cleanup(mol, self.params)
- if mol is None:
+ cleaned_mol = rdMolStandardize.Cleanup(mol, self.params)
+ if cleaned_mol is None:
  return None, None

+ # If not doing any transformations, return early
+ if not self.remove_salts and not self.canonicalize_tautomer:
+ return cleaned_mol, None
+
  salt_smiles = None
+ mol = cleaned_mol

  # Step 2: Fragment handling (conditional based on remove_salts)
  if self.remove_salts:
- # Get parent molecule and extract salt information
- parent_mol = rdMolStandardize.FragmentParent(mol, self.params)
+ # Get parent molecule
+ parent_mol = rdMolStandardize.FragmentParent(cleaned_mol, self.params)
  if parent_mol:
- salt_smiles = self._extract_salt(mol, parent_mol)
+ # Extract salt BEFORE any modifications to parent
+ salt_smiles = self._extract_salt(cleaned_mol, parent_mol)
  mol = parent_mol
  else:
  return None, None
@@ -153,7 +170,7 @@ class MolStandardizer:
  if mol is None:
  return None, salt_smiles

- # Step 4: Canonicalize tautomer
+ # Step 4: Canonicalize tautomer (LAST STEP)
  if self.canonicalize_tautomer:
  mol = self.tautomer_enumerator.Canonicalize(mol)

@@ -172,13 +189,22 @@ class MolStandardizer:
  - Mixtures: multiple large neutral organic fragments

  Args:
- orig_mol: Original molecule (before FragmentParent)
- parent_mol: Parent molecule (after FragmentParent)
+ orig_mol: Original molecule (after Cleanup, before FragmentParent)
+ parent_mol: Parent molecule (after FragmentParent, before tautomerization)

  Returns:
  SMILES string of salt components or None if no salts/mixture detected
  """
  try:
+ # Quick atom count check
+ if orig_mol.GetNumAtoms() == parent_mol.GetNumAtoms():
+ return None
+
+ # Quick heavy atom difference check
+ heavy_diff = orig_mol.GetNumHeavyAtoms() - parent_mol.GetNumHeavyAtoms()
+ if heavy_diff <= 0:
+ return None
+
  # Get all fragments from original molecule
  orig_frags = Chem.GetMolFrags(orig_mol, asMols=True)

@@ -268,7 +294,7 @@ def standardize(
  if "orig_smiles" not in result.columns:
  result["orig_smiles"] = result[smiles_column]

- # Initialize standardizer with salt removal control
+ # Initialize standardizer
  standardizer = MolStandardizer(canonicalize_tautomer=canonicalize_tautomer, remove_salts=extract_salts)

  def process_smiles(smiles: str) -> pd.Series:
@@ -286,6 +312,11 @@ def standardize(
  log.error("Encountered missing or empty SMILES string")
  return pd.Series({"smiles": None, "salt": None})

+ # Early check for unreasonably long SMILES
+ if len(smiles) > 1000:
+ log.error(f"SMILES too long ({len(smiles)} chars): {smiles[:50]}...")
+ return pd.Series({"smiles": None, "salt": None})
+
  # Parse molecule
  mol = Chem.MolFromSmiles(smiles)
  if mol is None:
@@ -299,7 +330,9 @@ def standardize(
  if std_mol is not None:
  # Check if molecule is reasonable
  if std_mol.GetNumAtoms() == 0 or std_mol.GetNumAtoms() > 200: # Arbitrary limits
- log.error(f"Unusual molecule size: {std_mol.GetNumAtoms()} atoms")
+ log.error(f"Rejecting molecule size: {std_mol.GetNumAtoms()} atoms")
+ log.error(f"Original SMILES: {smiles}")
+ return pd.Series({"smiles": None, "salt": salt_smiles})

  if std_mol is None:
  return pd.Series(
@@ -325,8 +358,11 @@ def standardize(


  if __name__ == "__main__":
- import time
- from workbench.api import DataSource
+
+ # Pandas display options for better readability
+ pd.set_option("display.max_columns", None)
+ pd.set_option("display.width", 1000)
+ pd.set_option("display.max_colwidth", 100)

  # Test with DataFrame including various salt forms
  test_data = pd.DataFrame(
@@ -362,67 +398,53 @@ if __name__ == "__main__":
  )

  # General test
+ print("Testing standardization with full dataset...")
  standardize(test_data)

  # Remove the last two rows to avoid errors with None and INVALID
  test_data = test_data.iloc[:-2].reset_index(drop=True)

  # Test WITHOUT salt removal (keeps full molecule)
- print("\nStandardization KEEPING salts (extract_salts=False):")
- print("This preserves the full molecule including counterions")
+ print("\nStandardization KEEPING salts (extract_salts=False) Tautomerization: True")
  result_keep = standardize(test_data, extract_salts=False, canonicalize_tautomer=True)
- display_cols = ["compound_id", "orig_smiles", "smiles", "salt"]
- print(result_keep[display_cols].to_string())
+ display_order = ["compound_id", "orig_smiles", "smiles", "salt"]
+ print(result_keep[display_order])

  # Test WITH salt removal
  print("\n" + "=" * 70)
  print("Standardization REMOVING salts (extract_salts=True):")
- print("This extracts parent molecule and records salt information")
  result_remove = standardize(test_data, extract_salts=True, canonicalize_tautomer=True)
- print(result_remove[display_cols].to_string())
+ print(result_remove[display_order])

- # Test WITHOUT tautomerization (keeping salts)
+ # Test with problematic cases specifically
  print("\n" + "=" * 70)
- print("Standardization KEEPING salts, NO tautomerization:")
- result_no_taut = standardize(test_data, extract_salts=False, canonicalize_tautomer=False)
- print(result_no_taut[display_cols].to_string())
+ print("Testing specific problematic cases:")
+ problem_cases = pd.DataFrame(
+ {
+ "smiles": [
+ "CC(=O)O.CCN", # Should extract CC(=O)O as salt
+ "CCO.CC", # Should return CC as salt
+ ],
+ "compound_id": ["TEST_C002", "TEST_C005"],
+ }
+ )
+
+ problem_result = standardize(problem_cases, extract_salts=True, canonicalize_tautomer=True)
+ print(problem_result[display_order])
+
+ # Performance test with larger dataset
+ from workbench.api import DataSource

- # Show the difference for salt-containing molecules
- print("\n" + "=" * 70)
- print("Comparison showing differences:")
- for idx, row in result_keep.iterrows():
- keep_smiles = row["smiles"]
- remove_smiles = result_remove.loc[idx, "smiles"]
- no_taut_smiles = result_no_taut.loc[idx, "smiles"]
- salt = result_remove.loc[idx, "salt"]
-
- # Show differences when they exist
- if keep_smiles != remove_smiles or keep_smiles != no_taut_smiles:
- print(f"\n{row['compound_id']} ({row['orig_smiles']}):")
- if keep_smiles != no_taut_smiles:
- print(f" With salt + taut: {keep_smiles}")
- print(f" With salt, no taut: {no_taut_smiles}")
- if keep_smiles != remove_smiles:
- print(f" Parent only + taut: {remove_smiles}")
- if salt:
- print(f" Extracted salt: {salt}")
-
- # Summary statistics
  print("\n" + "=" * 70)
- print("Summary:")
- print(f"Total molecules: {len(result_remove)}")
- print(f"Molecules with salts: {result_remove['salt'].notna().sum()}")
- unique_salts = result_remove["salt"].dropna().unique()
- print(f"Unique salts found: {unique_salts[:5].tolist()}")

- # Get a real dataset from Workbench and time the standardization
  ds = DataSource("aqsol_data")
- df = ds.pull_dataframe()[["id", "smiles"]]
- start_time = time.time()
- std_df = standardize(df, extract_salts=True, canonicalize_tautomer=True)
- end_time = time.time()
- print(f"\nStandardized {len(std_df)} molecules from Workbench in {end_time - start_time:.2f} seconds")
- print(std_df.head())
- print(f"Molecules with salts: {std_df['salt'].notna().sum()}")
- unique_salts = std_df["salt"].dropna().unique()
- print(f"Unique salts found: {unique_salts[:5].tolist()}")
+ df = ds.pull_dataframe()[["id", "smiles"]][:1000]
+
+ for tautomer in [True, False]:
+ for extract in [True, False]:
+ print(f"Performance test with AQSol dataset: tautomer={tautomer} extract_salts={extract}:")
+ start_time = time.time()
+ std_df = standardize(df, canonicalize_tautomer=tautomer, extract_salts=extract)
+ elapsed = time.time() - start_time
+ mol_per_sec = len(df) / elapsed
+ print(f"{elapsed:.2f}s ({mol_per_sec:.0f} mol/s)")
@@ -15,7 +15,6 @@ import json
  from mol_standardize import standardize
  from mol_descriptors import compute_descriptors

-
  # TRAINING SECTION
  #
  # This section (__main__) is where SageMaker will execute the training job
@@ -15,8 +15,7 @@ import pandas as pd
  import json

  # Local imports
- from local_utils import compute_morgan_fingerprints
-
+ from fingerprints import compute_morgan_fingerprints

  # TRAINING SECTION
  #
@@ -0,0 +1,194 @@
+ import pandas as pd
+ import numpy as np
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.neighbors import NearestNeighbors
+ from typing import List, Optional
+ import logging
+
+ # Workbench Imports
+ from workbench.algorithms.dataframe.proximity import Proximity
+ from workbench.algorithms.dataframe.projection_2d import Projection2D
+
+ # Set up logging
+ log = logging.getLogger("workbench")
+
+
+ class FeatureSpaceProximity(Proximity):
+ """Proximity computations for numeric feature spaces using Euclidean distance."""
+
+ def __init__(
+ self,
+ df: pd.DataFrame,
+ id_column: str,
+ features: List[str],
+ target: Optional[str] = None,
+ include_all_columns: bool = False,
+ ):
+ """
+ Initialize the FeatureSpaceProximity class.
+
+ Args:
+ df: DataFrame containing data for neighbor computations.
+ id_column: Name of the column used as the identifier.
+ features: List of feature column names to be used for neighbor computations.
+ target: Name of the target column. Defaults to None.
+ include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
+ """
+ # Validate and filter features before calling parent init
+ self._raw_features = features
+ super().__init__(
+ df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+ )
+
+ def _prepare_data(self) -> None:
+ """Filter out non-numeric features and drop NaN rows."""
+ # Validate features
+ self.features = self._validate_features(self.df, self._raw_features)
+
+ # Drop NaN rows for the features we're using
+ self.df = self.df.dropna(subset=self.features).copy()
+
+ def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+ """Remove non-numeric features and log warnings."""
+ non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+ if non_numeric:
+ log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+ return [f for f in features if f not in non_numeric]
+
+ def _build_model(self) -> None:
+ """Standardize features and fit Nearest Neighbors model."""
+ self.scaler = StandardScaler()
+ X = self.scaler.fit_transform(self.df[self.features])
+ self.nn = NearestNeighbors().fit(X)
+
+ def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+ """Transform features using the fitted scaler."""
+ return self.scaler.transform(df[self.features])
+
+ def _project_2d(self) -> None:
+ """Project the numeric features to 2D for visualization."""
+ if len(self.features) >= 2:
+ self.df = Projection2D().fit_transform(self.df, features=self.features)
+
+
+ # Testing the FeatureSpaceProximity class
+ if __name__ == "__main__":
+
+ pd.set_option("display.max_columns", None)
+ pd.set_option("display.width", 1000)
+
+ # Create a sample DataFrame
+ data = {
+ "ID": [1, 2, 3, 4, 5],
+ "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+ "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+ "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
+ }
+ df = pd.DataFrame(data)
+
+ # Test the FeatureSpaceProximity class
+ features = ["Feature1", "Feature2", "Feature3"]
+ prox = FeatureSpaceProximity(df, id_column="ID", features=features)
+ print(prox.neighbors(1, n_neighbors=2))
+
+ # Test the neighbors method with radius
+ print(prox.neighbors(1, radius=2.0))
+
+ # Test with Features list
+ prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
+ print(prox.neighbors(1))
+
+ # Create a sample DataFrame
+ data = {
+ "id": ["a", "b", "c", "d", "e"], # Testing string IDs
+ "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+ "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+ "target": [1, 0, 1, 0, 5],
+ }
+ df = pd.DataFrame(data)
+
+ # Test with String Ids
+ prox = FeatureSpaceProximity(
+ df,
+ id_column="id",
+ features=["Feature1", "Feature2"],
+ target="target",
+ include_all_columns=True,
+ )
+ print(prox.neighbors(["a", "b"]))
+
+ # Test duplicate IDs
+ data = {
+ "id": ["a", "b", "c", "d", "d"], # Duplicate ID (d)
+ "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+ "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+ "target": [1, 0, 1, 0, 5],
+ }
+ df = pd.DataFrame(data)
+ prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
+ print(df.equals(prox.df))
+
+ # Test on real data from Workbench
+ from workbench.api import FeatureSet, Model
+
+ fs = FeatureSet("aqsol_features")
+ model = Model("aqsol-regression")
+ features = model.features()
+ df = fs.pull_dataframe()
+ prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
+ print("\n" + "=" * 80)
+ print("Testing Neighbors...")
+ print("=" * 80)
+ test_id = df[fs.id_column].tolist()[0]
+ print(f"\nNeighbors for ID {test_id}:")
+ print(prox.neighbors(test_id))
+
+ print("\n" + "=" * 80)
+ print("Testing isolated_compounds...")
+ print("=" * 80)
+
+ # Test isolated data in the top 1%
+ isolated_1pct = prox.isolated(top_percent=1.0)
+ print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+ print(isolated_1pct)
+
+ # Test isolated data in the top 5%
+ isolated_5pct = prox.isolated(top_percent=5.0)
+ print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+ print(isolated_5pct)
+
+ print("\n" + "=" * 80)
+ print("Testing target_gradients...")
+ print("=" * 80)
+
+ # Test with different parameters
+ gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+ print(f"\nTop 1% target gradients (min_delta=5.0) (n={len(gradients_1pct)}):")
+ print(gradients_1pct)
+
+ gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+ print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+ print(gradients_5pct)
+
+ # Test proximity_stats
+ print("\n" + "=" * 80)
+ print("Testing proximity_stats...")
+ print("=" * 80)
+ stats = prox.proximity_stats()
+ print(stats)
+
+ # Plot the distance distribution using pandas
+ print("\n" + "=" * 80)
+ print("Plotting distance distribution...")
+ print("=" * 80)
+ prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
+
+ # Visualize the 2D projection
+ print("\n" + "=" * 80)
+ print("Visualizing 2D Projection...")
+ print("=" * 80)
+ from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
+ from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
+
+ unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
+ unit_test.run()
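The new FeatureSpaceProximity module above follows a subclass-hook pattern on the reworked Proximity base class: subclasses override _prepare_data, _build_model, and _transform_features (plus an optional _project_2d). A rough sketch of what another subclass could look like, assuming the base class drives those hooks (hypothetical code, not part of the package):

    # Hypothetical Proximity subclass using the same hooks as FeatureSpaceProximity
    class RawFeatureProximity(Proximity):
        def _prepare_data(self) -> None:
            # drop rows with missing feature values
            self.df = self.df.dropna(subset=self.features).copy()

        def _build_model(self) -> None:
            # fit a NearestNeighbors index directly on the unscaled features
            self.nn = NearestNeighbors().fit(self.df[self.features])

        def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
            return df[self.features].to_numpy()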
@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
  "id_column": "{{id_column}}",
  "features": "{{feature_list}}",
  "target": "{{target_column}}",
- "track_columns": "{{track_columns}}"
+ "include_all_columns": "{{include_all_columns}}",
  }

  from io import StringIO
@@ -18,7 +18,7 @@ import os
  import pandas as pd

  # Local Imports
- from proximity import Proximity
+ from feature_space_proximity import FeatureSpaceProximity


  # Function to check if dataframe is empty
@@ -61,7 +61,7 @@ if __name__ == "__main__":
  id_column = TEMPLATE_PARAMS["id_column"]
  features = TEMPLATE_PARAMS["features"]
  target = TEMPLATE_PARAMS["target"] # Can be None for unsupervised models
- track_columns = TEMPLATE_PARAMS["track_columns"] # Can be None
+ include_all_columns = TEMPLATE_PARAMS["include_all_columns"] # Defaults to False

  # Script arguments for input/output directories
  parser = argparse.ArgumentParser()
@@ -73,26 +73,24 @@ if __name__ == "__main__":
  args = parser.parse_args()

  # Load training data from the specified directory
- training_files = [
- os.path.join(args.train, file)
- for file in os.listdir(args.train) if file.endswith(".csv")
- ]
+ training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
  all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

  # Check if the DataFrame is empty
  check_dataframe(all_df, "training_df")

- # Create the Proximity model
- model = Proximity(all_df, id_column, features, target, track_columns=track_columns)
+ # Create the FeatureSpaceProximity model
+ model = FeatureSpaceProximity(all_df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns)

  # Now serialize the model
  model.serialize(args.model_dir)

+
  # Model loading and prediction functions
  def model_fn(model_dir):

  # Deserialize the model
- model = Proximity.deserialize(model_dir)
+ model = FeatureSpaceProximity.deserialize(model_dir)
  return model


@@ -14,7 +14,7 @@ import pandas as pd
  TEMPLATE_PARAMS = {
  "features": "{{feature_list}}",
  "target": "{{target_column}}",
- "train_all_data": "{{train_all_data}}"
+ "train_all_data": "{{train_all_data}}",
  }


@@ -37,7 +37,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
  """
  Matches and renames DataFrame columns to match model feature names (case-insensitive).
  Prioritizes exact matches, then case-insensitive matches.
-
+
  Raises ValueError if any model features cannot be matched.
  """
  df_columns_lower = {col.lower(): col for col in df.columns}
@@ -81,10 +81,7 @@ if __name__ == "__main__":
  args = parser.parse_args()

  # Load training data from the specified directory
- training_files = [
- os.path.join(args.train, file)
- for file in os.listdir(args.train) if file.endswith(".csv")
- ]
+ training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
  df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

  # Check if the DataFrame is empty
@@ -109,8 +106,10 @@ if __name__ == "__main__":
  # Create and train the Regression/Confidence model
  # model = BayesianRidge()
  model = BayesianRidge(
- alpha_1=1e-6, alpha_2=1e-6, # Noise precision
- lambda_1=1e-6, lambda_2=1e-6, # Weight precision
+ alpha_1=1e-6,
+ alpha_2=1e-6, # Noise precision
+ lambda_1=1e-6,
+ lambda_2=1e-6, # Weight precision
  fit_intercept=True,
  )

@@ -4,13 +4,10 @@ import awswrangler as wr
  import numpy as np

  # Model Performance Scores
- from sklearn.metrics import (
- mean_absolute_error,
- r2_score,
- root_mean_squared_error
- )
+ from sklearn.metrics import mean_absolute_error, median_absolute_error, r2_score, root_mean_squared_error
  from sklearn.model_selection import KFold
  from scipy.optimize import minimize
+ from scipy.stats import spearmanr

  from io import StringIO
  import json
@@ -23,7 +20,7 @@ TEMPLATE_PARAMS = {
  "features": "{{feature_list}}",
  "target": "{{target_column}}",
  "train_all_data": "{{train_all_data}}",
- "model_metrics_s3_path": "{{model_metrics_s3_path}}"
+ "model_metrics_s3_path": "{{model_metrics_s3_path}}",
  }


@@ -47,7 +44,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
  """
  Matches and renames DataFrame columns to match model feature names (case-insensitive).
  Prioritizes exact matches, then case-insensitive matches.
-
+
  Raises ValueError if any model features cannot be matched.
  """
  df_columns_lower = {col.lower(): col for col in df.columns}
@@ -90,10 +87,7 @@ if __name__ == "__main__":
  args = parser.parse_args()

  # Load training data from the specified directory
- training_files = [
- os.path.join(args.train, file)
- for file in os.listdir(args.train) if file.endswith(".csv")
- ]
+ training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
  df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

  # Check if the DataFrame is empty
@@ -172,16 +166,14 @@ if __name__ == "__main__":
  cv_residuals = np.array(cv_residuals)
  cv_uncertainties = np.array(cv_uncertainties)

-
  # Optimize calibration parameters: σ_cal = a * σ_uc + b
  def neg_log_likelihood(params):
  a, b = params
  sigma_cal = a * cv_uncertainties + b
  sigma_cal = np.maximum(sigma_cal, 1e-8) # Prevent division by zero
- return np.sum(0.5 * np.log(2 * np.pi * sigma_cal ** 2) + 0.5 * (cv_residuals ** 2) / (sigma_cal ** 2))
+ return np.sum(0.5 * np.log(2 * np.pi * sigma_cal**2) + 0.5 * (cv_residuals**2) / (sigma_cal**2))

-
- result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method='Nelder-Mead')
+ result = minimize(neg_log_likelihood, x0=[1.0, 0.1], method="Nelder-Mead")
  cal_a, cal_b = result.x

  print(f"Calibration parameters: a={cal_a:.4f}, b={cal_b:.4f}")
@@ -205,7 +197,9 @@ if __name__ == "__main__":
205
197
  result_df["prediction"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].mean(axis=1)
206
198
 
207
199
  # Compute uncalibrated uncertainty
208
- result_df["prediction_std_uc"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].std(axis=1)
200
+ result_df["prediction_std_uc"] = result_df[[name for name in result_df.columns if name.startswith("m_")]].std(
201
+ axis=1
202
+ )
209
203
 
210
204
  # Apply calibration to uncertainty
211
205
  result_df["prediction_std"] = cal_a * result_df["prediction_std_uc"] + cal_b
@@ -224,11 +218,16 @@ if __name__ == "__main__":
224
218
  # Report Performance Metrics
225
219
  rmse = root_mean_squared_error(result_df[target], result_df["prediction"])
226
220
  mae = mean_absolute_error(result_df[target], result_df["prediction"])
221
+ medae = median_absolute_error(result_df[target], result_df["prediction"])
227
222
  r2 = r2_score(result_df[target], result_df["prediction"])
228
- print(f"RMSE: {rmse:.3f}")
229
- print(f"MAE: {mae:.3f}")
230
- print(f"R2: {r2:.3f}")
231
- print(f"NumRows: {len(result_df)}")
223
+ spearman_corr = spearmanr(result_df[target], result_df["prediction"]).correlation
224
+ support = len(result_df)
225
+ print(f"rmse: {rmse:.3f}")
226
+ print(f"mae: {mae:.3f}")
227
+ print(f"medae: {medae:.3f}")
228
+ print(f"r2: {r2:.3f}")
229
+ print(f"spearmanr: {spearman_corr:.3f}")
230
+ print(f"support: {support}")
232
231
 
233
232
  # Now save the models
234
233
  for name, model in models.items():
@@ -352,4 +351,4 @@ def predict_fn(df, models) -> pd.DataFrame:
352
351
  df = df.reindex(sorted(df.columns), axis=1)
353
352
 
354
353
  # All done, return the DataFrame
355
- return df
354
+ return df