workbench-0.8.213-py3-none-any.whl → workbench-0.8.219-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/algorithms/sql/outliers.py +3 -3
  9. workbench/api/__init__.py +3 -0
  10. workbench/api/endpoint.py +10 -5
  11. workbench/api/feature_set.py +76 -6
  12. workbench/api/meta_model.py +289 -0
  13. workbench/api/model.py +43 -4
  14. workbench/core/artifacts/endpoint_core.py +65 -117
  15. workbench/core/artifacts/feature_set_core.py +3 -3
  16. workbench/core/artifacts/model_core.py +6 -4
  17. workbench/core/pipelines/pipeline_executor.py +1 -1
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  19. workbench/model_script_utils/model_script_utils.py +15 -11
  20. workbench/model_script_utils/pytorch_utils.py +11 -1
  21. workbench/model_scripts/chemprop/chemprop.template +147 -71
  22. workbench/model_scripts/chemprop/generated_model_script.py +151 -75
  23. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  24. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  25. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  27. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  28. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  29. workbench/model_scripts/meta_model/meta_model.template +209 -0
  30. workbench/model_scripts/pytorch_model/generated_model_script.py +45 -27
  31. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  32. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  33. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  34. workbench/model_scripts/script_generation.py +4 -0
  35. workbench/model_scripts/xgb_model/generated_model_script.py +167 -156
  36. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  37. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  38. workbench/repl/workbench_shell.py +0 -5
  39. workbench/scripts/endpoint_test.py +2 -2
  40. workbench/scripts/meta_model_sim.py +35 -0
  41. workbench/utils/chem_utils/fingerprints.py +87 -46
  42. workbench/utils/chemprop_utils.py +23 -5
  43. workbench/utils/meta_model_simulator.py +499 -0
  44. workbench/utils/metrics_utils.py +94 -10
  45. workbench/utils/model_utils.py +91 -9
  46. workbench/utils/pytorch_utils.py +1 -1
  47. workbench/utils/shap_utils.py +1 -55
  48. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  49. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/METADATA +2 -1
  50. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/RECORD +54 -50
  51. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/entry_points.txt +1 -0
  52. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  53. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  54. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  55. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  56. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/WHEEL +0 -0
  57. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/licenses/LICENSE +0 -0
  58. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/top_level.txt +0 -0
workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46

@@ -1,31 +1,48 @@
-"""Molecular fingerprint computation utilities"""
+"""Molecular fingerprint computation utilities for ADMET modeling.
+
+This module provides Morgan count fingerprints, the standard for ADMET prediction.
+Count fingerprints outperform binary fingerprints for molecular property prediction.
+
+References:
+    - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
+    - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
+"""
 
 import logging
-import pandas as pd
 
-# Molecular Descriptor Imports
-from rdkit import Chem
-from rdkit.Chem import rdFingerprintGenerator
+import numpy as np
+import pandas as pd
+from rdkit import Chem, RDLogger
+from rdkit.Chem import AllChem
 from rdkit.Chem.MolStandardize import rdMolStandardize
 
+# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
+# Keep errors enabled so we see actual problems
+RDLogger.DisableLog("rdApp.warning")
+
 # Set up the logger
 log = logging.getLogger("workbench")
 
 
-def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
-    """Compute and add Morgan fingerprints to the DataFrame.
+def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
+    """Compute Morgan count fingerprints for ADMET modeling.
+
+    Generates true count fingerprints where each bit position contains the
+    number of times that substructure appears in the molecule (clamped to 0-255).
+    This is the recommended approach for ADMET prediction per 2025 research.
 
     Args:
-        df (pd.DataFrame): Input DataFrame containing SMILES strings.
-        radius (int): Radius for the Morgan fingerprint.
-        n_bits (int): Number of bits for the fingerprint.
-        counts (bool): Count simulation for the fingerprint.
+        df: Input DataFrame containing SMILES strings.
+        radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
+        n_bits: Number of bits for the fingerprint (default 2048).
 
     Returns:
-        pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
+        pd.DataFrame: Input DataFrame with 'fingerprint' column added.
+            Values are comma-separated uint8 counts.
 
     Note:
-        See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
+        Count fingerprints outperform binary for ADMET prediction.
+        See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
     """
     delete_mol_column = False
 
@@ -39,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
         log.warning("Detected serialized molecules in 'molecule' column. Removing...")
         del df["molecule"]
 
-    # Convert SMILES to RDKit molecule objects (vectorized)
+    # Convert SMILES to RDKit molecule objects
    if "molecule" not in df.columns:
         log.info("Converting SMILES to RDKit Molecules...")
         delete_mol_column = True
@@ -47,23 +64,32 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Make sure our molecules are not None
     failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
     if failed_smiles:
-        log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
-        df = df.dropna(subset=["molecule"])
+        log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
+        df = df.dropna(subset=["molecule"]).copy()
 
     # If we have fragments in our compounds, get the largest fragment before computing fingerprints
     largest_frags = df["molecule"].apply(
         lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
     )
 
-    # Create a Morgan fingerprint generator
-    if counts:
-        n_bits *= 4  # Multiply by 4 to simulate counts
-    morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
+    def mol_to_count_string(mol):
+        """Convert molecule to comma-separated count fingerprint string."""
+        if mol is None:
+            return pd.NA
 
-    # Compute Morgan fingerprints (vectorized)
-    fingerprints = largest_frags.apply(
-        lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
-    )
+        # Get hashed Morgan fingerprint with counts
+        fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
+
+        # Initialize array and populate with counts (clamped to uint8 range)
+        counts = np.zeros(n_bits, dtype=np.uint8)
+        for idx, count in fp.GetNonzeroElements().items():
+            counts[idx] = min(count, 255)
+
+        # Return as comma-separated string
+        return ",".join(map(str, counts))
+
+    # Compute Morgan count fingerprints
+    fingerprints = largest_frags.apply(mol_to_count_string)
 
     # Add the fingerprints to the DataFrame
     df["fingerprint"] = fingerprints
@@ -71,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Drop the intermediate 'molecule' column if it was added
     if delete_mol_column:
         del df["molecule"]
+
     return df
 
 
 if __name__ == "__main__":
-    print("Running molecular fingerprint tests...")
-    print("Note: This requires molecular_screening module to be available")
+    print("Running Morgan count fingerprint tests...")
 
     # Test molecules
     test_molecules = {
         "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
         "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
         "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
-        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt
+        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt (largest fragment used)
         "benzene": "c1ccccc1",
         "butene_e": "C/C=C/C",  # E-butene
         "butene_z": "C/C=C\\C",  # Z-butene
     }
 
-    # Test 1: Morgan Fingerprints
-    print("\n1. Testing Morgan fingerprint generation...")
+    # Test 1: Morgan Count Fingerprints (default parameters)
+    print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
 
     test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
-
-    fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
+    fp_df = compute_morgan_fingerprints(test_df.copy())
 
     print(" Fingerprint generation results:")
     for _, row in fp_df.iterrows():
         fp = row.get("fingerprint", "N/A")
-        fp_len = len(fp) if fp != "N/A" else 0
-        print(f" {row['name']:15} {fp_len} bits")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            max_count = max(counts)
+            print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
+        else:
+            print(f" {row['name']:15} → N/A")
 
-    # Test 2: Different fingerprint parameters
-    print("\n2. Testing different fingerprint parameters...")
+    # Test 2: Different parameters
+    print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
 
-    # Test with counts enabled
-    fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
+    fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
 
-    print(" With count simulation (256 bits * 4):")
-    for _, row in fp_counts_df.iterrows():
+    for _, row in fp_df_custom.iterrows():
         fp = row.get("fingerprint", "N/A")
-        fp_len = len(fp) if fp != "N/A" else 0
-        print(f" {row['name']:15} {fp_len} bits")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
+        else:
+            print(f" {row['name']:15} → N/A")
 
     # Test 3: Edge cases
     print("\n3. Testing edge cases...")
 
     # Invalid SMILES
     invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
-    try:
-        fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
-        print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
-    except Exception as e:
-        print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
+    fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+    print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
 
     # Test with pre-existing molecule column
     mol_df = test_df.copy()
@@ -131,4 +160,16 @@ if __name__ == "__main__":
     fp_with_mol = compute_morgan_fingerprints(mol_df)
     print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
 
+    # Test 4: Verify count values are reasonable
+    print("\n4. Verifying count distribution...")
+    all_counts = []
+    for _, row in fp_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            all_counts.extend([c for c in counts if c > 0])
+
+    if all_counts:
+        print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
+
     print("\n✅ All fingerprint tests completed!")
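The new 'fingerprint' format is worth a concrete illustration: each value is a string of n_bits comma-separated uint8 counts. Below is a minimal sketch of how a consumer might decode that column back into a model-ready matrix; the decode_fingerprints helper is hypothetical (not part of workbench) and assumes only the comma-separated uint8 format documented in the docstring above.

import numpy as np
import pandas as pd


def decode_fingerprints(df: pd.DataFrame, fp_column: str = "fingerprint") -> np.ndarray:
    """Hypothetical helper: decode comma-separated count strings into an (n_rows, n_bits) uint8 matrix."""
    valid = df[fp_column].dropna()
    return np.array([[int(c) for c in fp.split(",")] for fp in valid], dtype=np.uint8)


# Demonstration with a synthetic 8-bit fingerprint column
demo = pd.DataFrame({"fingerprint": ["0,2,0,0,1,0,3,0", "1,0,0,0,0,5,0,0"]})
X = decode_fingerprints(demo)
print(X.shape, X.dtype)  # (2, 8) uint8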
workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0

@@ -0,0 +1,194 @@
+import pandas as pd
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+from sklearn.neighbors import NearestNeighbors
+from typing import List, Optional
+import logging
+
+# Workbench Imports
+from workbench.algorithms.dataframe.proximity import Proximity
+from workbench.algorithms.dataframe.projection_2d import Projection2D
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+class FeatureSpaceProximity(Proximity):
+    """Proximity computations for numeric feature spaces using Euclidean distance."""
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        id_column: str,
+        features: List[str],
+        target: Optional[str] = None,
+        include_all_columns: bool = False,
+    ):
+        """
+        Initialize the FeatureSpaceProximity class.
+
+        Args:
+            df: DataFrame containing data for neighbor computations.
+            id_column: Name of the column used as the identifier.
+            features: List of feature column names to be used for neighbor computations.
+            target: Name of the target column. Defaults to None.
+            include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
+        """
+        # Validate and filter features before calling parent init
+        self._raw_features = features
+        super().__init__(
+            df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+        )
+
+    def _prepare_data(self) -> None:
+        """Filter out non-numeric features and drop NaN rows."""
+        # Validate features
+        self.features = self._validate_features(self.df, self._raw_features)
+
+        # Drop NaN rows for the features we're using
+        self.df = self.df.dropna(subset=self.features).copy()
+
+    def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+        """Remove non-numeric features and log warnings."""
+        non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+        if non_numeric:
+            log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+        return [f for f in features if f not in non_numeric]
+
+    def _build_model(self) -> None:
+        """Standardize features and fit Nearest Neighbors model."""
+        self.scaler = StandardScaler()
+        X = self.scaler.fit_transform(self.df[self.features])
+        self.nn = NearestNeighbors().fit(X)
+
+    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+        """Transform features using the fitted scaler."""
+        return self.scaler.transform(df[self.features])
+
+    def _project_2d(self) -> None:
+        """Project the numeric features to 2D for visualization."""
+        if len(self.features) >= 2:
+            self.df = Projection2D().fit_transform(self.df, features=self.features)
+
+
+# Testing the FeatureSpaceProximity class
+if __name__ == "__main__":
+
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+
+    # Create a sample DataFrame
+    data = {
+        "ID": [1, 2, 3, 4, 5],
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
+    }
+    df = pd.DataFrame(data)
+
+    # Test the FeatureSpaceProximity class
+    features = ["Feature1", "Feature2", "Feature3"]
+    prox = FeatureSpaceProximity(df, id_column="ID", features=features)
+    print(prox.neighbors(1, n_neighbors=2))
+
+    # Test the neighbors method with radius
+    print(prox.neighbors(1, radius=2.0))
+
+    # Test with Features list
+    prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
+    print(prox.neighbors(1))
+
+    # Create a sample DataFrame
+    data = {
+        "id": ["a", "b", "c", "d", "e"],  # Testing string IDs
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
+    }
+    df = pd.DataFrame(data)
+
+    # Test with String Ids
+    prox = FeatureSpaceProximity(
+        df,
+        id_column="id",
+        features=["Feature1", "Feature2"],
+        target="target",
+        include_all_columns=True,
+    )
+    print(prox.neighbors(["a", "b"]))
+
+    # Test duplicate IDs
+    data = {
+        "id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
+        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+        "target": [1, 0, 1, 0, 5],
+    }
+    df = pd.DataFrame(data)
+    prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
+    print(df.equals(prox.df))
+
+    # Test on real data from Workbench
+    from workbench.api import FeatureSet, Model
+
+    fs = FeatureSet("aqsol_features")
+    model = Model("aqsol-regression")
+    features = model.features()
+    df = fs.pull_dataframe()
+    prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
+    print("\n" + "=" * 80)
+    print("Testing Neighbors...")
+    print("=" * 80)
+    test_id = df[fs.id_column].tolist()[0]
+    print(f"\nNeighbors for ID {test_id}:")
+    print(prox.neighbors(test_id))
+
+    print("\n" + "=" * 80)
+    print("Testing isolated_compounds...")
+    print("=" * 80)
+
+    # Test isolated data in the top 1%
+    isolated_1pct = prox.isolated(top_percent=1.0)
+    print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+    print(isolated_1pct)
+
+    # Test isolated data in the top 5%
+    isolated_5pct = prox.isolated(top_percent=5.0)
+    print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+    print(isolated_5pct)
+
+    print("\n" + "=" * 80)
+    print("Testing target_gradients...")
+    print("=" * 80)
+
+    # Test with different parameters
+    gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+    print(f"\nTop 1% target gradients (min_delta=5.0) (n={len(gradients_1pct)}):")
+    print(gradients_1pct)
+
+    gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+    print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+    print(gradients_5pct)
+
+    # Test proximity_stats
+    print("\n" + "=" * 80)
+    print("Testing proximity_stats...")
+    print("=" * 80)
+    stats = prox.proximity_stats()
+    print(stats)
+
+    # Plot the distance distribution using pandas
+    print("\n" + "=" * 80)
+    print("Plotting distance distribution...")
+    print("=" * 80)
+    prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
+
+    # Visualize the 2D projection
+    print("\n" + "=" * 80)
+    print("Visualizing 2D Projection...")
+    print("=" * 80)
+    from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
+    from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
+
+    unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
+    unit_test.run()
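The new file also documents the subclassing contract inherited from Proximity: _prepare_data, _build_model, and _transform_features are the override points, while neighbor queries live in the base class. As a hedged sketch of that contract (the hook names come from the code above; the cosine metric and the assumption that the base class sets self.df and self.features before invoking the hooks are illustrative, not from the package):

import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors

from workbench.algorithms.dataframe.proximity import Proximity


class CosineProximity(Proximity):
    """Illustrative subclass: same hook structure as FeatureSpaceProximity, cosine distance, no scaling."""

    def _prepare_data(self) -> None:
        # Keep only rows with complete feature values
        self.df = self.df.dropna(subset=self.features).copy()

    def _build_model(self) -> None:
        # Cosine distance is insensitive to per-row magnitude, so skip StandardScaler
        self.nn = NearestNeighbors(metric="cosine").fit(self.df[self.features])

    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
        # No fitted transformation needed; pass raw feature values through
        return df[self.features].to_numpy()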
workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6

@@ -8,7 +8,7 @@ TEMPLATE_PARAMS = {
     "id_column": "{{id_column}}",
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "track_columns": "{{track_columns}}",
+    "include_all_columns": "{{include_all_columns}}",
 }
 
 from io import StringIO
@@ -18,7 +18,7 @@ import os
 import pandas as pd
 
 # Local Imports
-from proximity import Proximity
+from feature_space_proximity import FeatureSpaceProximity
 
 
 # Function to check if dataframe is empty
@@ -61,7 +61,7 @@ if __name__ == "__main__":
     id_column = TEMPLATE_PARAMS["id_column"]
     features = TEMPLATE_PARAMS["features"]
     target = TEMPLATE_PARAMS["target"]  # Can be None for unsupervised models
-    track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
+    include_all_columns = TEMPLATE_PARAMS["include_all_columns"]  # Defaults to False
 
     # Script arguments for input/output directories
     parser = argparse.ArgumentParser()
@@ -79,8 +79,8 @@ if __name__ == "__main__":
     # Check if the DataFrame is empty
     check_dataframe(all_df, "training_df")
 
-    # Create the Proximity model
-    model = Proximity(all_df, id_column, features, target, track_columns=track_columns)
+    # Create the FeatureSpaceProximity model
+    model = FeatureSpaceProximity(all_df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns)
 
     # Now serialize the model
     model.serialize(args.model_dir)
@@ -90,7 +90,7 @@
 def model_fn(model_dir):
 
     # Deserialize the model
-    model = Proximity.deserialize(model_dir)
+    model = FeatureSpaceProximity.deserialize(model_dir)
     return model
 
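For readers unfamiliar with these templates: the {{...}} placeholders are substituted at model-creation time (presumably by workbench/model_scripts/script_generation.py, which also changed in this release), so the generated script ships with literal values. An illustrative rendering, with hypothetical values:

# Hypothetical rendered TEMPLATE_PARAMS after script generation;
# actual values depend on the FeatureSet and Model being deployed.
TEMPLATE_PARAMS = {
    "id_column": "compound_id",
    "features": ["feature_1", "feature_2", "feature_3"],
    "target": "solubility",
    "include_all_columns": False,
}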
workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0

@@ -0,0 +1,194 @@
(Content identical to workbench/model_scripts/custom_models/proximity/feature_space_proximity.py shown above.)