workbench 0.8.202-py3-none-any.whl → 0.8.220-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (84)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  3. workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
  4. workbench/algorithms/dataframe/projection_2d.py +44 -21
  5. workbench/algorithms/dataframe/proximity.py +78 -150
  6. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  7. workbench/algorithms/models/cleanlab_model.py +382 -0
  8. workbench/algorithms/models/noise_model.py +388 -0
  9. workbench/algorithms/sql/outliers.py +3 -3
  10. workbench/api/__init__.py +3 -0
  11. workbench/api/df_store.py +17 -108
  12. workbench/api/endpoint.py +13 -11
  13. workbench/api/feature_set.py +111 -8
  14. workbench/api/meta_model.py +289 -0
  15. workbench/api/model.py +45 -12
  16. workbench/api/parameter_store.py +3 -52
  17. workbench/cached/cached_model.py +4 -4
  18. workbench/core/artifacts/artifact.py +5 -5
  19. workbench/core/artifacts/df_store_core.py +114 -0
  20. workbench/core/artifacts/endpoint_core.py +228 -237
  21. workbench/core/artifacts/feature_set_core.py +185 -230
  22. workbench/core/artifacts/model_core.py +34 -26
  23. workbench/core/artifacts/parameter_store_core.py +98 -0
  24. workbench/core/pipelines/pipeline_executor.py +1 -1
  25. workbench/core/transforms/features_to_model/features_to_model.py +22 -10
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  28. workbench/model_script_utils/model_script_utils.py +339 -0
  29. workbench/model_script_utils/pytorch_utils.py +405 -0
  30. workbench/model_script_utils/uq_harness.py +278 -0
  31. workbench/model_scripts/chemprop/chemprop.template +428 -631
  32. workbench/model_scripts/chemprop/generated_model_script.py +432 -635
  33. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  34. workbench/model_scripts/chemprop/requirements.txt +2 -10
  35. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  36. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  37. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  38. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  39. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  40. workbench/model_scripts/meta_model/meta_model.template +209 -0
  41. workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
  42. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  43. workbench/model_scripts/pytorch_model/pytorch.template +370 -609
  44. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  45. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  46. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  47. workbench/model_scripts/script_generation.py +6 -5
  48. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  49. workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
  50. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  51. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  52. workbench/model_scripts/xgb_model/xgb_model.template +366 -396
  53. workbench/repl/workbench_shell.py +0 -5
  54. workbench/resources/open_source_api.key +1 -1
  55. workbench/scripts/endpoint_test.py +2 -2
  56. workbench/scripts/meta_model_sim.py +35 -0
  57. workbench/scripts/training_test.py +85 -0
  58. workbench/utils/chem_utils/fingerprints.py +87 -46
  59. workbench/utils/chem_utils/projections.py +16 -6
  60. workbench/utils/chemprop_utils.py +36 -655
  61. workbench/utils/meta_model_simulator.py +499 -0
  62. workbench/utils/metrics_utils.py +256 -0
  63. workbench/utils/model_utils.py +192 -54
  64. workbench/utils/pytorch_utils.py +33 -472
  65. workbench/utils/shap_utils.py +1 -55
  66. workbench/utils/xgboost_local_crossfold.py +267 -0
  67. workbench/utils/xgboost_model_utils.py +49 -356
  68. workbench/web_interface/components/model_plot.py +7 -1
  69. workbench/web_interface/components/plugins/model_details.py +30 -68
  70. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  71. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
  72. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
  73. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
  74. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  75. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  76. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  77. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  78. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  79. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  80. workbench/model_scripts/uq_models/mapie.template +0 -605
  81. workbench/model_scripts/uq_models/requirements.txt +0 -1
  82. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  83. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
  84. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/compound_dataset_overlap.py
@@ -0,0 +1,321 @@
+ """Compound Dataset Overlap Analysis
+
+ This module provides utilities for comparing two molecular datasets based on
+ Tanimoto similarity in fingerprint space. It helps quantify the "overlap"
+ between datasets in chemical space.
+
+ Use cases:
+ - Train/test split validation: Ensure test set isn't too similar to training
+ - Dataset comparison: Compare proprietary vs public datasets
+ - Novelty assessment: Find compounds in query dataset that are novel vs reference
+ """
+
+ import logging
+ from typing import Optional, Tuple
+
+ import pandas as pd
+
+ from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity
+
+ # Set up logging
+ log = logging.getLogger("workbench")
+
+
+ class CompoundDatasetOverlap:
+     """Compare two molecular datasets using Tanimoto similarity.
+
+     Builds a FingerprintProximity model on the reference dataset, then queries
+     with SMILES from the query dataset to find the nearest neighbor in the
+     reference for each query compound. This guarantees cross-dataset matches.
+
+     Attributes:
+         prox: FingerprintProximity instance on reference dataset
+         overlap_df: Results DataFrame with similarity scores for each query compound
+     """
+
+     def __init__(
+         self,
+         df_reference: pd.DataFrame,
+         df_query: pd.DataFrame,
+         id_column_reference: str = "id",
+         id_column_query: str = "id",
+         radius: int = 2,
+         n_bits: int = 2048,
+     ) -> None:
+         """
+         Initialize the CompoundDatasetOverlap analysis.
+
+         Args:
+             df_reference: Reference dataset (DataFrame with SMILES)
+             df_query: Query dataset (DataFrame with SMILES)
+             id_column_reference: ID column name in df_reference
+             id_column_query: ID column name in df_query
+             radius: Morgan fingerprint radius (default: 2 = ECFP4)
+             n_bits: Number of fingerprint bits (default: 2048)
+         """
+         self.id_column_reference = id_column_reference
+         self.id_column_query = id_column_query
+         self._radius = radius
+         self._n_bits = n_bits
+
+         # Store copies of the dataframes
+         self.df_reference = df_reference.copy()
+         self.df_query = df_query.copy()
+
+         # Find SMILES columns
+         self._smiles_col_reference = self._find_smiles_column(self.df_reference)
+         self._smiles_col_query = self._find_smiles_column(self.df_query)
+
+         if self._smiles_col_reference is None:
+             raise ValueError("Reference dataset must have a SMILES column")
+         if self._smiles_col_query is None:
+             raise ValueError("Query dataset must have a SMILES column")
+
+         log.info(f"Reference dataset: {len(self.df_reference)} compounds")
+         log.info(f"Query dataset: {len(self.df_query)} compounds")
+
+         # Build FingerprintProximity on reference dataset only
+         self.prox = FingerprintProximity(
+             self.df_reference,
+             id_column=id_column_reference,
+             radius=radius,
+             n_bits=n_bits,
+         )
+
+         # Compute cross-dataset overlap
+         self.overlap_df = self._compute_cross_dataset_overlap()
+
+     @staticmethod
+     def _find_smiles_column(df: pd.DataFrame) -> Optional[str]:
+         """Find the SMILES column in a DataFrame (case-insensitive)."""
+         for col in df.columns:
+             if col.lower() == "smiles":
+                 return col
+         return None
+
+     def _compute_cross_dataset_overlap(self) -> pd.DataFrame:
+         """For each query compound, find nearest neighbor in reference using neighbors_from_smiles."""
+         log.info(f"Computing nearest neighbors in reference for {len(self.df_query)} query compounds")
+
+         # Get SMILES list from query dataset
+         query_smiles = self.df_query[self._smiles_col_query].tolist()
+         query_ids = self.df_query[self.id_column_query].tolist()
+
+         # Query all compounds against reference (get only nearest neighbor)
+         neighbors_df = self.prox.neighbors_from_smiles(query_smiles, n_neighbors=1)
+
+         # Build results with query IDs
+         results = []
+         for i, (q_id, q_smi) in enumerate(zip(query_ids, query_smiles)):
+             # Find the row for this query SMILES
+             match = neighbors_df[neighbors_df["query_id"] == q_smi]
+             if len(match) > 0:
+                 row = match.iloc[0]
+                 results.append(
+                     {
+                         "id": q_id,
+                         "smiles": q_smi,
+                         "nearest_neighbor_id": row["neighbor_id"],
+                         "tanimoto_similarity": row["similarity"],
+                     }
+                 )
+             else:
+                 # Should not happen, but handle gracefully
+                 results.append(
+                     {
+                         "id": q_id,
+                         "smiles": q_smi,
+                         "nearest_neighbor_id": None,
+                         "tanimoto_similarity": 0.0,
+                     }
+                 )
+
+         result_df = pd.DataFrame(results)
+
+         # Add nearest neighbor SMILES from reference
+         ref_smiles_map = self.df_reference.set_index(self.id_column_reference)[self._smiles_col_reference]
+         result_df["nearest_neighbor_smiles"] = result_df["nearest_neighbor_id"].map(ref_smiles_map)
+
+         return result_df.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)
+
+     def summary_stats(self) -> pd.DataFrame:
+         """Return distribution statistics for nearest-neighbor Tanimoto similarities."""
+         return (
+             self.overlap_df["tanimoto_similarity"]
+             .describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+             .to_frame()
+         )
+
+     def novel_compounds(self, threshold: float = 0.4) -> pd.DataFrame:
+         """Return query compounds that are novel (low similarity to reference).
+
+         Args:
+             threshold: Maximum Tanimoto similarity to consider "novel" (default: 0.4)
+
+         Returns:
+             DataFrame of query compounds with similarity below threshold
+         """
+         novel = self.overlap_df[self.overlap_df["tanimoto_similarity"] < threshold].copy()
+         return novel.sort_values("tanimoto_similarity", ascending=True).reset_index(drop=True)
+
+     def similar_compounds(self, threshold: float = 0.7) -> pd.DataFrame:
+         """Return query compounds that are similar to reference (high overlap).
+
+         Args:
+             threshold: Minimum Tanimoto similarity to consider "similar" (default: 0.7)
+
+         Returns:
+             DataFrame of query compounds with similarity above threshold
+         """
+         similar = self.overlap_df[self.overlap_df["tanimoto_similarity"] >= threshold].copy()
+         return similar.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)
+
+     def overlap_fraction(self, threshold: float = 0.7) -> float:
+         """Return fraction of query compounds that overlap with reference above similarity threshold.
+
+         Args:
+             threshold: Minimum Tanimoto similarity to consider "overlapping"
+
+         Returns:
+             Fraction of query compounds with nearest neighbor similarity >= threshold
+         """
+         n_overlapping = (self.overlap_df["tanimoto_similarity"] >= threshold).sum()
+         return n_overlapping / len(self.overlap_df)
+
+     def plot_histogram(self, bins: int = 50, figsize: Tuple[int, int] = (10, 6)) -> None:
+         """Plot histogram of nearest-neighbor Tanimoto similarities.
+
+         Args:
+             bins: Number of histogram bins
+             figsize: Figure size (width, height)
+         """
+         import matplotlib.pyplot as plt
+
+         fig, ax = plt.subplots(figsize=figsize)
+         ax.hist(self.overlap_df["tanimoto_similarity"], bins=bins, edgecolor="black", alpha=0.7)
+         ax.set_xlabel("Tanimoto Similarity (query → nearest in reference)")
+         ax.set_ylabel("Count")
+         ax.set_title(f"Dataset Overlap: {len(self.overlap_df)} query compounds")
+         ax.axvline(x=0.4, color="red", linestyle="--", label="Novel threshold (0.4)")
+         ax.axvline(x=0.7, color="green", linestyle="--", label="Similar threshold (0.7)")
+         ax.legend()
+
+         # Add summary stats as text
+         stats = self.overlap_df["tanimoto_similarity"]
+         textstr = f"Mean: {stats.mean():.3f}\nMedian: {stats.median():.3f}\nStd: {stats.std():.3f}"
+         ax.text(
+             0.02,
+             0.98,
+             textstr,
+             transform=ax.transAxes,
+             verticalalignment="top",
+             bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
+         )
+
+         plt.tight_layout()
+         plt.show()
+
+
+ # =============================================================================
+ # Testing
+ # =============================================================================
+ if __name__ == "__main__":
+     print("=" * 80)
+     print("Testing CompoundDatasetOverlap")
+     print("=" * 80)
+
+     # Test 1: Basic functionality with SMILES data
+     print("\n1. Testing with SMILES data...")
+
+     # Reference dataset: Known drug-like compounds
+     reference_data = {
+         "id": ["aspirin", "caffeine", "glucose", "ibuprofen", "naproxen", "ethanol", "methanol", "propanol"],
+         "smiles": [
+             "CC(=O)OC1=CC=CC=C1C(=O)O",  # aspirin
+             "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # caffeine
+             "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # glucose
+             "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # ibuprofen
+             "COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1",  # naproxen
+             "CCO",  # ethanol
+             "CO",  # methanol
+             "CCCO",  # propanol
+         ],
+     }
+
+     # Query dataset: Compounds to compare against reference
+     query_data = {
+         "id": ["acetaminophen", "theophylline", "benzene", "toluene", "phenol", "aniline"],
+         "smiles": [
+             "CC(=O)NC1=CC=C(C=C1)O",  # acetaminophen - similar to aspirin
+             "CN1C=NC2=C1C(=O)NC(=O)N2",  # theophylline - similar to caffeine
+             "c1ccccc1",  # benzene - simple aromatic
+             "Cc1ccccc1",  # toluene - similar to benzene
+             "Oc1ccccc1",  # phenol - hydroxyl benzene
+             "Nc1ccccc1",  # aniline - amino benzene
+         ],
+     }
+
+     df_reference = pd.DataFrame(reference_data)
+     df_query = pd.DataFrame(query_data)
+
+     print(f" Reference: {len(df_reference)} compounds, Query: {len(df_query)} compounds")
+
+     overlap = CompoundDatasetOverlap(
+         df_reference, df_query, id_column_reference="id", id_column_query="id", radius=2, n_bits=1024
+     )
+
+     print("\n Overlap results:")
+     print(overlap.overlap_df[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
+
+     print("\n Summary statistics:")
+     print(overlap.summary_stats())
+
+     # Test 2: Novel and similar compound identification
+     print("\n2. Testing novel/similar compound identification...")
+
+     similar = overlap.similar_compounds(threshold=0.3)
+     print(f" Similar compounds (sim >= 0.3): {len(similar)}")
+     if len(similar) > 0:
+         print(similar[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
+
+     novel = overlap.novel_compounds(threshold=0.3)
+     print(f"\n Novel compounds (sim < 0.3): {len(novel)}")
+     if len(novel) > 0:
+         print(novel[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
+
+     # Test 3: With Workbench data (if available)
+     print("\n3. Testing with Workbench FeatureSet (if available)...")
+
+     try:
+         from workbench.api import FeatureSet
+
+         fs = FeatureSet("aqsol_features")
+         full_df = fs.pull_dataframe()[:1000]  # Limit to first 1000 for testing
+
+         # Split into reference and query sets
+         df_reference = full_df.sample(frac=0.8, random_state=42)
+         df_query = full_df.drop(df_reference.index)
+
+         print(f" Reference set: {len(df_reference)} compounds")
+         print(f" Query set: {len(df_query)} compounds")
+
+         overlap = CompoundDatasetOverlap(
+             df_reference, df_query, id_column_reference=fs.id_column, id_column_query=fs.id_column
+         )
+
+         print("\n Summary statistics:")
+         print(overlap.summary_stats())
+
+         print(f"\n Overlap fraction (sim >= 0.7): {overlap.overlap_fraction(0.7):.2%}")
+         print(f" Overlap fraction (sim >= 0.5): {overlap.overlap_fraction(0.5):.2%}")
+         print(f" Novel compounds (sim < 0.4): {len(overlap.novel_compounds(0.4))}")
+
+         # Show the histogram
+         overlap.plot_histogram()
+
+     except Exception as e:
+         print(f" Skipping Workbench test: {e}")
+
+     print("\n" + "=" * 80)
+     print("✅ All CompoundDatasetOverlap tests completed!")
+     print("=" * 80)
workbench/algorithms/dataframe/feature_space_proximity.py
@@ -1,101 +1,194 @@
  import pandas as pd
+ import numpy as np
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.neighbors import NearestNeighbors
+ from typing import List, Optional
  import logging

  # Workbench Imports
  from workbench.algorithms.dataframe.proximity import Proximity
  from workbench.algorithms.dataframe.projection_2d import Projection2D
- from workbench.core.views.inference_view import InferenceView
- from workbench.api import FeatureSet, Model

  # Set up logging
  log = logging.getLogger("workbench")


  class FeatureSpaceProximity(Proximity):
-     def __init__(self, model: Model, n_neighbors: int = 10) -> None:
+     """Proximity computations for numeric feature spaces using Euclidean distance."""
+
+     def __init__(
+         self,
+         df: pd.DataFrame,
+         id_column: str,
+         features: List[str],
+         target: Optional[str] = None,
+         include_all_columns: bool = False,
+     ):
          """
          Initialize the FeatureSpaceProximity class.

          Args:
-             model (Model): A Workbench model object.
-             n_neighbors (int): Number of neighbors to compute. Defaults to 10.
+             df: DataFrame containing data for neighbor computations.
+             id_column: Name of the column used as the identifier.
+             features: List of feature column names to be used for neighbor computations.
+             target: Name of the target column. Defaults to None.
+             include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
          """
-
-         # Grab the features and target from the model
-         features = model.features()
-         target = model.target()
-
-         # Grab the feature set for the model
-         fs = FeatureSet(model.get_input())
-
-         # If we have a "inference" view, pull the data from that view
-         view_name = f"inf_{model.name.replace('-', '_')}"
-         if view_name in fs.views():
-             self.df = fs.view(view_name).pull_dataframe()
-
-         # Otherwise, pull the data from the feature set and run inference
-         else:
-             inf_view = InferenceView.create(model)
-             self.df = inf_view.pull_dataframe()
-
-         # Call the parent class constructor
-         super().__init__(self.df, id_column=fs.id_column, features=features, target=target, n_neighbors=n_neighbors)
-
-         # Project the data to 2D
-         self.df = Projection2D().fit_transform(self.df, features=features)
-
-
+         # Validate and filter features before calling parent init
+         self._raw_features = features
+         super().__init__(
+             df, id_column=id_column, features=features, target=target, include_all_columns=include_all_columns
+         )
+
+     def _prepare_data(self) -> None:
+         """Filter out non-numeric features and drop NaN rows."""
+         # Validate features
+         self.features = self._validate_features(self.df, self._raw_features)
+
+         # Drop NaN rows for the features we're using
+         self.df = self.df.dropna(subset=self.features).copy()
+
+     def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+         """Remove non-numeric features and log warnings."""
+         non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+         if non_numeric:
+             log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+         return [f for f in features if f not in non_numeric]
+
+     def _build_model(self) -> None:
+         """Standardize features and fit Nearest Neighbors model."""
+         self.scaler = StandardScaler()
+         X = self.scaler.fit_transform(self.df[self.features])
+         self.nn = NearestNeighbors().fit(X)
+
+     def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+         """Transform features using the fitted scaler."""
+         return self.scaler.transform(df[self.features])
+
+     def _project_2d(self) -> None:
+         """Project the numeric features to 2D for visualization."""
+         if len(self.features) >= 2:
+             self.df = Projection2D().fit_transform(self.df, features=self.features)
+
+
+ # Testing the FeatureSpaceProximity class
  if __name__ == "__main__":
+
      pd.set_option("display.max_columns", None)
      pd.set_option("display.width", 1000)

-     # Test a Workbench classification Model
-     m = Model("wine-classification")
-     fsp = FeatureSpaceProximity(m)
-
-     # Neighbors Test using a single row from FeatureSet
-     fs = FeatureSet(m.get_input())
+     # Create a sample DataFrame
+     data = {
+         "ID": [1, 2, 3, 4, 5],
+         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+         "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
+     }
+     df = pd.DataFrame(data)
+
+     # Test the FeatureSpaceProximity class
+     features = ["Feature1", "Feature2", "Feature3"]
+     prox = FeatureSpaceProximity(df, id_column="ID", features=features)
+     print(prox.neighbors(1, n_neighbors=2))
+
+     # Test the neighbors method with radius
+     print(prox.neighbors(1, radius=2.0))
+
+     # Test with Features list
+     prox = FeatureSpaceProximity(df, id_column="ID", features=["Feature1"])
+     print(prox.neighbors(1))
+
+     # Create a sample DataFrame
+     data = {
+         "id": ["a", "b", "c", "d", "e"],  # Testing string IDs
+         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+         "target": [1, 0, 1, 0, 5],
+     }
+     df = pd.DataFrame(data)
+
+     # Test with String Ids
+     prox = FeatureSpaceProximity(
+         df,
+         id_column="id",
+         features=["Feature1", "Feature2"],
+         target="target",
+         include_all_columns=True,
+     )
+     print(prox.neighbors(["a", "b"]))
+
+     # Test duplicate IDs
+     data = {
+         "id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
+         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
+         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
+         "target": [1, 0, 1, 0, 5],
+     }
+     df = pd.DataFrame(data)
+     prox = FeatureSpaceProximity(df, id_column="id", features=["Feature1", "Feature2"], target="target")
+     print(df.equals(prox.df))
+
+     # Test on real data from Workbench
+     from workbench.api import FeatureSet, Model
+
+     fs = FeatureSet("aqsol_features")
+     model = Model("aqsol-regression")
+     features = model.features()
      df = fs.pull_dataframe()
-     single_query_neighbors = fsp.neighbors(df.iloc[[0]])
-     print("\nNeighbors for Query ID:", df.iloc[0][fs.id_column])
-     print(single_query_neighbors)
-
-     # Test a Workbench regression model
-     m = Model("abalone-regression")
-     fsp = FeatureSpaceProximity(m)
-
-     # Neighbors Test using a multiple rows from FeatureSet
-     fs = FeatureSet(m.get_input())
-     df = fs.pull_dataframe()
-     query_neighbors = fsp.neighbors(df.iloc[0:2])
-     print("\nNeighbors for Query ID:", df.iloc[0][fs.id_column])
-     print(query_neighbors)
-
-     # Test a Workbench regression model
-     m = Model("aqsol-regression")
-     fsp = FeatureSpaceProximity(m)
-
-     # Neighbors Test using a multiple rows from FeatureSet
-     fs = FeatureSet(m.get_input())
-     df = fs.pull_dataframe()
-     query_neighbors = fsp.neighbors(df.iloc[5:7])
-     print("\nNeighbors for Query ID:", df.iloc[5][fs.id_column])
-     print(query_neighbors)
-
-     # Time the all_neighbors method
-     import time
-
-     start_time = time.time()
-     all_neighbors_df = fsp.all_neighbors()
-     end_time = time.time()
-     print("\nTime taken for all_neighbors:", end_time - start_time)
-     print("\nAll Neighbors DataFrame:")
-     print(all_neighbors_df)
-
-     # Show a scatter plot of the data
+     prox = FeatureSpaceProximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
+     print("\n" + "=" * 80)
+     print("Testing Neighbors...")
+     print("=" * 80)
+     test_id = df[fs.id_column].tolist()[0]
+     print(f"\nNeighbors for ID {test_id}:")
+     print(prox.neighbors(test_id))
+
+     print("\n" + "=" * 80)
+     print("Testing isolated_compounds...")
+     print("=" * 80)
+
+     # Test isolated data in the top 1%
+     isolated_1pct = prox.isolated(top_percent=1.0)
+     print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+     print(isolated_1pct)
+
+     # Test isolated data in the top 5%
+     isolated_5pct = prox.isolated(top_percent=5.0)
+     print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+     print(isolated_5pct)
+
+     print("\n" + "=" * 80)
+     print("Testing target_gradients...")
+     print("=" * 80)
+
+     # Test with different parameters
+     gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+     print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
+     print(gradients_1pct)
+
+     gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+     print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+     print(gradients_5pct)
+
+     # Test proximity_stats
+     print("\n" + "=" * 80)
+     print("Testing proximity_stats...")
+     print("=" * 80)
+     stats = prox.proximity_stats()
+     print(stats)
+
+     # Plot the distance distribution using pandas
+     print("\n" + "=" * 80)
+     print("Plotting distance distribution...")
+     print("=" * 80)
+     prox.df["nn_distance"].hist(bins=50, figsize=(10, 6), edgecolor="black")
+
+     # Visualize the 2D projection
+     print("\n" + "=" * 80)
+     print("Visualizing 2D Projection...")
+     print("=" * 80)
      from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
      from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot

-     # Run the Unit Test on the Plugin using the new DataFrame with 'x' and 'y'
-     unit_test = PluginUnitTest(ScatterPlot, input_data=fsp.df, x="x", y="y")
+     unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
      unit_test.run()
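
The visible API break in this file: FeatureSpaceProximity no longer takes a Model and pulls its data through a FeatureSet or inference view; callers now pass the DataFrame, id column, and feature list directly, and n_neighbors moved from the constructor to the neighbors() call. A minimal migration sketch under that reading of the diff (column names are illustrative):

    import pandas as pd
    from workbench.algorithms.dataframe.feature_space_proximity import FeatureSpaceProximity

    # Illustrative data: an id column plus numeric feature columns
    df = pd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "f1": [0.1, 0.2, 0.3, 0.4],
            "f2": [0.4, 0.3, 0.2, 0.1],
        }
    )

    # 0.8.202: FeatureSpaceProximity(model, n_neighbors=10) derived data and features from the Model
    # 0.8.220: data, id column, and features are passed in explicitly
    prox = FeatureSpaceProximity(df, id_column="id", features=["f1", "f2"])
    print(prox.neighbors(1, n_neighbors=2))

Per the new _validate_features and _prepare_data hooks, non-numeric feature columns are excluded with a warning rather than raising, and rows with NaN in the selected features are dropped before the NearestNeighbors model is fit.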