workbench 0.8.219__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/api/feature_set.py +0 -1
  6. workbench/core/artifacts/feature_set_core.py +183 -228
  7. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  9. workbench/model_scripts/chemprop/chemprop.template +193 -68
  10. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  11. workbench/model_scripts/pytorch_model/generated_model_script.py +3 -3
  12. workbench/model_scripts/xgb_model/generated_model_script.py +3 -3
  13. workbench/scripts/ml_pipeline_sqs.py +71 -2
  14. workbench/themes/light/custom.css +7 -1
  15. workbench/themes/midnight_blue/custom.css +34 -0
  16. workbench/utils/chem_utils/projections.py +16 -6
  17. workbench/utils/model_utils.py +0 -1
  18. workbench/utils/plot_utils.py +146 -28
  19. workbench/utils/theme_manager.py +95 -30
  20. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  21. workbench/web_interface/components/settings_menu.py +184 -0
  22. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  23. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/RECORD +27 -25
  24. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  25. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +0 -0
  26. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  27. {workbench-0.8.219.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,321 @@
1
+ """Compound Dataset Overlap Analysis
2
+
3
+ This module provides utilities for comparing two molecular datasets based on
4
+ Tanimoto similarity in fingerprint space. It helps quantify the "overlap"
5
+ between datasets in chemical space.
6
+
7
+ Use cases:
8
+ - Train/test split validation: Ensure test set isn't too similar to training
9
+ - Dataset comparison: Compare proprietary vs public datasets
10
+ - Novelty assessment: Find compounds in query dataset that are novel vs reference
11
+ """
12
+
13
+ import logging
14
+ from typing import Optional, Tuple
15
+
16
+ import pandas as pd
17
+
18
+ from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity
19
+
20
+ # Set up logging
21
+ log = logging.getLogger("workbench")
22
+
23
+
24
class CompoundDatasetOverlap:
    """Compare two molecular datasets using Tanimoto similarity.

    Builds a FingerprintProximity model on the reference dataset, then queries
    with SMILES from the query dataset to find the nearest neighbor in the
    reference for each query compound. This guarantees cross-dataset matches.

    Attributes:
        prox: FingerprintProximity instance on reference dataset
        overlap_df: Results DataFrame with similarity scores for each query compound
    """

    def __init__(
        self,
        df_reference: pd.DataFrame,
        df_query: pd.DataFrame,
        id_column_reference: str = "id",
        id_column_query: str = "id",
        radius: int = 2,
        n_bits: int = 2048,
    ) -> None:
        """
        Initialize the CompoundDatasetOverlap analysis.

        Args:
            df_reference: Reference dataset (DataFrame with SMILES)
            df_query: Query dataset (DataFrame with SMILES)
            id_column_reference: ID column name in df_reference
            id_column_query: ID column name in df_query
            radius: Morgan fingerprint radius (default: 2 = ECFP4)
            n_bits: Number of fingerprint bits (default: 2048)

        Raises:
            ValueError: If either dataset has no SMILES column.
        """
        self.id_column_reference = id_column_reference
        self.id_column_query = id_column_query
        self._radius = radius
        self._n_bits = n_bits

        # Work on copies so the caller's DataFrames are never mutated
        self.df_reference = df_reference.copy()
        self.df_query = df_query.copy()

        # Find SMILES columns (case-insensitive match on "smiles")
        self._smiles_col_reference = self._find_smiles_column(self.df_reference)
        self._smiles_col_query = self._find_smiles_column(self.df_query)

        if self._smiles_col_reference is None:
            raise ValueError("Reference dataset must have a SMILES column")
        if self._smiles_col_query is None:
            raise ValueError("Query dataset must have a SMILES column")

        log.info(f"Reference dataset: {len(self.df_reference)} compounds")
        log.info(f"Query dataset: {len(self.df_query)} compounds")

        # Build FingerprintProximity on reference dataset only
        self.prox = FingerprintProximity(
            self.df_reference,
            id_column=id_column_reference,
            radius=radius,
            n_bits=n_bits,
        )

        # Compute cross-dataset overlap
        self.overlap_df = self._compute_cross_dataset_overlap()

    @staticmethod
    def _find_smiles_column(df: pd.DataFrame) -> Optional[str]:
        """Find the SMILES column in a DataFrame (case-insensitive)."""
        for col in df.columns:
            if col.lower() == "smiles":
                return col
        return None

    def _compute_cross_dataset_overlap(self) -> pd.DataFrame:
        """For each query compound, find nearest neighbor in reference using neighbors_from_smiles."""
        log.info(f"Computing nearest neighbors in reference for {len(self.df_query)} query compounds")

        # Get SMILES list from query dataset
        query_smiles = self.df_query[self._smiles_col_query].tolist()
        query_ids = self.df_query[self.id_column_query].tolist()

        # Query all compounds against reference (get only nearest neighbor)
        neighbors_df = self.prox.neighbors_from_smiles(query_smiles, n_neighbors=1)

        # Index the nearest match per query SMILES once, instead of scanning
        # neighbors_df for every query compound (avoids an O(Q*N) loop).
        # neighbors_from_smiles sorts by similarity descending within each
        # query_id, so the first occurrence is the nearest neighbor.
        nearest_by_smiles = neighbors_df.drop_duplicates(subset="query_id").set_index("query_id")

        # Build results with query IDs
        results = []
        for q_id, q_smi in zip(query_ids, query_smiles):
            if q_smi in nearest_by_smiles.index:
                row = nearest_by_smiles.loc[q_smi]
                results.append(
                    {
                        "id": q_id,
                        "smiles": q_smi,
                        "nearest_neighbor_id": row["neighbor_id"],
                        "tanimoto_similarity": row["similarity"],
                    }
                )
            else:
                # No neighbor returned for this SMILES (should not happen,
                # but handle gracefully with zero similarity)
                results.append(
                    {
                        "id": q_id,
                        "smiles": q_smi,
                        "nearest_neighbor_id": None,
                        "tanimoto_similarity": 0.0,
                    }
                )

        result_df = pd.DataFrame(results)

        # Add nearest neighbor SMILES from reference
        ref_smiles_map = self.df_reference.set_index(self.id_column_reference)[self._smiles_col_reference]
        result_df["nearest_neighbor_smiles"] = result_df["nearest_neighbor_id"].map(ref_smiles_map)

        return result_df.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)

    def summary_stats(self) -> pd.DataFrame:
        """Return distribution statistics for nearest-neighbor Tanimoto similarities."""
        return (
            self.overlap_df["tanimoto_similarity"]
            .describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
            .to_frame()
        )

    def novel_compounds(self, threshold: float = 0.4) -> pd.DataFrame:
        """Return query compounds that are novel (low similarity to reference).

        Args:
            threshold: Maximum Tanimoto similarity to consider "novel" (default: 0.4)

        Returns:
            DataFrame of query compounds with similarity below threshold
        """
        novel = self.overlap_df[self.overlap_df["tanimoto_similarity"] < threshold].copy()
        return novel.sort_values("tanimoto_similarity", ascending=True).reset_index(drop=True)

    def similar_compounds(self, threshold: float = 0.7) -> pd.DataFrame:
        """Return query compounds that are similar to reference (high overlap).

        Args:
            threshold: Minimum Tanimoto similarity to consider "similar" (default: 0.7)

        Returns:
            DataFrame of query compounds with similarity above threshold
        """
        similar = self.overlap_df[self.overlap_df["tanimoto_similarity"] >= threshold].copy()
        return similar.sort_values("tanimoto_similarity", ascending=False).reset_index(drop=True)

    def overlap_fraction(self, threshold: float = 0.7) -> float:
        """Return fraction of query compounds that overlap with reference above similarity threshold.

        Args:
            threshold: Minimum Tanimoto similarity to consider "overlapping"

        Returns:
            Fraction of query compounds with nearest neighbor similarity >= threshold
            (0.0 if there are no query compounds).
        """
        # Guard against an empty query set (avoids division by zero)
        if len(self.overlap_df) == 0:
            return 0.0
        n_overlapping = (self.overlap_df["tanimoto_similarity"] >= threshold).sum()
        return n_overlapping / len(self.overlap_df)

    def plot_histogram(self, bins: int = 50, figsize: Tuple[int, int] = (10, 6)) -> None:
        """Plot histogram of nearest-neighbor Tanimoto similarities.

        Args:
            bins: Number of histogram bins
            figsize: Figure size (width, height)
        """
        # Imported lazily so matplotlib is only required when plotting
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=figsize)
        ax.hist(self.overlap_df["tanimoto_similarity"], bins=bins, edgecolor="black", alpha=0.7)
        ax.set_xlabel("Tanimoto Similarity (query → nearest in reference)")
        ax.set_ylabel("Count")
        ax.set_title(f"Dataset Overlap: {len(self.overlap_df)} query compounds")
        ax.axvline(x=0.4, color="red", linestyle="--", label="Novel threshold (0.4)")
        ax.axvline(x=0.7, color="green", linestyle="--", label="Similar threshold (0.7)")
        ax.legend()

        # Add summary stats as text
        stats = self.overlap_df["tanimoto_similarity"]
        textstr = f"Mean: {stats.mean():.3f}\nMedian: {stats.median():.3f}\nStd: {stats.std():.3f}"
        ax.text(
            0.02,
            0.98,
            textstr,
            transform=ax.transAxes,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

        plt.tight_layout()
        plt.show()
217
+
218
+
219
+ # =============================================================================
220
+ # Testing
221
+ # =============================================================================
222
+ if __name__ == "__main__":
223
+ print("=" * 80)
224
+ print("Testing CompoundDatasetOverlap")
225
+ print("=" * 80)
226
+
227
+ # Test 1: Basic functionality with SMILES data
228
+ print("\n1. Testing with SMILES data...")
229
+
230
+ # Reference dataset: Known drug-like compounds
231
+ reference_data = {
232
+ "id": ["aspirin", "caffeine", "glucose", "ibuprofen", "naproxen", "ethanol", "methanol", "propanol"],
233
+ "smiles": [
234
+ "CC(=O)OC1=CC=CC=C1C(=O)O", # aspirin
235
+ "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # caffeine
236
+ "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # glucose
237
+ "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # ibuprofen
238
+ "COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1", # naproxen
239
+ "CCO", # ethanol
240
+ "CO", # methanol
241
+ "CCCO", # propanol
242
+ ],
243
+ }
244
+
245
+ # Query dataset: Compounds to compare against reference
246
+ query_data = {
247
+ "id": ["acetaminophen", "theophylline", "benzene", "toluene", "phenol", "aniline"],
248
+ "smiles": [
249
+ "CC(=O)NC1=CC=C(C=C1)O", # acetaminophen - similar to aspirin
250
+ "CN1C=NC2=C1C(=O)NC(=O)N2", # theophylline - similar to caffeine
251
+ "c1ccccc1", # benzene - simple aromatic
252
+ "Cc1ccccc1", # toluene - similar to benzene
253
+ "Oc1ccccc1", # phenol - hydroxyl benzene
254
+ "Nc1ccccc1", # aniline - amino benzene
255
+ ],
256
+ }
257
+
258
+ df_reference = pd.DataFrame(reference_data)
259
+ df_query = pd.DataFrame(query_data)
260
+
261
+ print(f" Reference: {len(df_reference)} compounds, Query: {len(df_query)} compounds")
262
+
263
+ overlap = CompoundDatasetOverlap(
264
+ df_reference, df_query, id_column_reference="id", id_column_query="id", radius=2, n_bits=1024
265
+ )
266
+
267
+ print("\n Overlap results:")
268
+ print(overlap.overlap_df[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
269
+
270
+ print("\n Summary statistics:")
271
+ print(overlap.summary_stats())
272
+
273
+ # Test 2: Novel and similar compound identification
274
+ print("\n2. Testing novel/similar compound identification...")
275
+
276
+ similar = overlap.similar_compounds(threshold=0.3)
277
+ print(f" Similar compounds (sim >= 0.3): {len(similar)}")
278
+ if len(similar) > 0:
279
+ print(similar[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
280
+
281
+ novel = overlap.novel_compounds(threshold=0.3)
282
+ print(f"\n Novel compounds (sim < 0.3): {len(novel)}")
283
+ if len(novel) > 0:
284
+ print(novel[["id", "nearest_neighbor_id", "tanimoto_similarity"]].to_string(index=False))
285
+
286
+ # Test 3: With Workbench data (if available)
287
+ print("\n3. Testing with Workbench FeatureSet (if available)...")
288
+
289
+ try:
290
+ from workbench.api import FeatureSet
291
+
292
+ fs = FeatureSet("aqsol_features")
293
+ full_df = fs.pull_dataframe()[:1000] # Limit to first 1000 for testing
294
+
295
+ # Split into reference and query sets
296
+ df_reference = full_df.sample(frac=0.8, random_state=42)
297
+ df_query = full_df.drop(df_reference.index)
298
+
299
+ print(f" Reference set: {len(df_reference)} compounds")
300
+ print(f" Query set: {len(df_query)} compounds")
301
+
302
+ overlap = CompoundDatasetOverlap(
303
+ df_reference, df_query, id_column_reference=fs.id_column, id_column_query=fs.id_column
304
+ )
305
+
306
+ print("\n Summary statistics:")
307
+ print(overlap.summary_stats())
308
+
309
+ print(f"\n Overlap fraction (sim >= 0.7): {overlap.overlap_fraction(0.7):.2%}")
310
+ print(f" Overlap fraction (sim >= 0.5): {overlap.overlap_fraction(0.5):.2%}")
311
+ print(f" Novel compounds (sim < 0.4): {len(overlap.novel_compounds(0.4))}")
312
+
313
+ # Uncomment to show histogram
314
+ overlap.plot_histogram()
315
+
316
+ except Exception as e:
317
+ print(f" Skipping Workbench test: {e}")
318
+
319
+ print("\n" + "=" * 80)
320
+ print("✅ All CompoundDatasetOverlap tests completed!")
321
+ print("=" * 80)
@@ -29,7 +29,6 @@ class FingerprintProximity(Proximity):
29
29
  include_all_columns: bool = False,
30
30
  radius: int = 2,
31
31
  n_bits: int = 1024,
32
- counts: bool = False,
33
32
  ) -> None:
34
33
  """
35
34
  Initialize the FingerprintProximity class for binary fingerprint similarity.
@@ -43,12 +42,10 @@ class FingerprintProximity(Proximity):
43
42
  include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
44
43
  radius: Radius for Morgan fingerprint computation (default: 2).
45
44
  n_bits: Number of bits for fingerprint (default: 1024).
46
- counts: Whether to use count simulation (default: False).
47
45
  """
48
46
  # Store fingerprint computation parameters
49
47
  self._fp_radius = radius
50
48
  self._fp_n_bits = n_bits
51
- self._fp_counts = counts
52
49
 
53
50
  # Store the requested fingerprint column (may be None)
54
51
  self._fingerprint_column_arg = fingerprint_column
@@ -107,54 +104,77 @@ class FingerprintProximity(Proximity):
107
104
  # If fingerprint column doesn't exist yet, compute it
108
105
  if self.fingerprint_column not in self.df.columns:
109
106
  log.info(f"Computing Morgan fingerprints (radius={self._fp_radius}, n_bits={self._fp_n_bits})...")
110
- self.df = compute_morgan_fingerprints(
111
- self.df, radius=self._fp_radius, n_bits=self._fp_n_bits, counts=self._fp_counts
112
- )
107
+ self.df = compute_morgan_fingerprints(self.df, radius=self._fp_radius, n_bits=self._fp_n_bits)
113
108
 
114
109
  def _build_model(self) -> None:
115
110
  """
116
111
  Build the fingerprint proximity model for Tanimoto similarity.
117
- Converts fingerprint strings to binary arrays and initializes NearestNeighbors.
118
112
 
119
- Note: sklearn uses Jaccard distance internally (1 - Tanimoto similarity).
120
- We convert back to Tanimoto similarity in the output methods.
113
+ For binary fingerprints: uses Jaccard distance (1 - Tanimoto)
114
+ For count fingerprints: uses weighted Tanimoto (Ruzicka) distance
121
115
  """
122
- log.info("Converting fingerprints to binary feature matrix...")
123
-
124
- # Convert fingerprint strings to binary arrays and store for later use
125
- self.X = self._fingerprints_to_matrix(self.df)
126
-
127
- # sklearn uses Jaccard distance = 1 - Tanimoto similarity
128
- # We convert to Tanimoto similarity in neighbors() and _precompute_metrics()
129
- log.info("Building NearestNeighbors model (Jaccard/Tanimoto metric, BallTree)...")
130
- self.nn = NearestNeighbors(metric="jaccard", algorithm="ball_tree").fit(self.X)
116
+ # Convert fingerprint strings to matrix and detect format
117
+ self.X, self._is_count_fp = self._fingerprints_to_matrix(self.df)
118
+
119
+ if self._is_count_fp:
120
+ # Weighted Tanimoto (Ruzicka) for count vectors: 1 - Σmin(A,B)/Σmax(A,B)
121
+ log.info("Building NearestNeighbors model (weighted Tanimoto for count fingerprints)...")
122
+
123
+ def ruzicka_distance(a, b):
124
+ """Ruzicka distance = 1 - weighted Tanimoto similarity."""
125
+ min_sum = np.minimum(a, b).sum()
126
+ max_sum = np.maximum(a, b).sum()
127
+ if max_sum == 0:
128
+ return 0.0
129
+ return 1.0 - (min_sum / max_sum)
130
+
131
+ self.nn = NearestNeighbors(metric=ruzicka_distance, algorithm="ball_tree").fit(self.X)
132
+ else:
133
+ # Standard Jaccard for binary fingerprints
134
+ log.info("Building NearestNeighbors model (Jaccard/Tanimoto for binary fingerprints)...")
135
+ self.nn = NearestNeighbors(metric="jaccard", algorithm="ball_tree").fit(self.X)
131
136
 
132
137
  def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
133
138
  """
134
- Transform fingerprints to binary matrix for querying.
139
+ Transform fingerprints to matrix for querying.
135
140
 
136
141
  Args:
137
142
  df: DataFrame containing fingerprints to transform.
138
143
 
139
144
  Returns:
140
- Binary feature matrix for the fingerprints.
145
+ Feature matrix for the fingerprints (binary or count based on self._is_count_fp).
141
146
  """
142
- return self._fingerprints_to_matrix(df)
147
+ matrix, _ = self._fingerprints_to_matrix(df)
148
+ return matrix
143
149
 
144
- def _fingerprints_to_matrix(self, df: pd.DataFrame) -> np.ndarray:
150
+ def _fingerprints_to_matrix(self, df: pd.DataFrame) -> tuple[np.ndarray, bool]:
145
151
  """
146
- Convert fingerprint strings to a binary numpy matrix.
152
+ Convert fingerprint strings to a numpy matrix.
153
+
154
+ Supports two formats (auto-detected):
155
+ - Bitstrings: "10110010..." → binary matrix (bool), is_count=False
156
+ - Count vectors: "0,3,0,1,5,..." → count matrix (uint8), is_count=True
147
157
 
148
158
  Args:
149
159
  df: DataFrame containing fingerprint column.
150
160
 
151
161
  Returns:
152
- 2D numpy array of binary fingerprint bits.
162
+ Tuple of (2D numpy array, is_count_fingerprint boolean)
153
163
  """
154
- fingerprint_bits = df[self.fingerprint_column].apply(
155
- lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
156
- )
157
- return np.vstack(fingerprint_bits)
164
+ # Auto-detect format based on first fingerprint
165
+ sample = str(df[self.fingerprint_column].iloc[0])
166
+ if "," in sample:
167
+ # Count vector format: preserve counts for weighted Tanimoto
168
+ fingerprint_values = df[self.fingerprint_column].apply(
169
+ lambda fp: np.array([int(x) for x in fp.split(",")], dtype=np.uint8)
170
+ )
171
+ return np.vstack(fingerprint_values), True
172
+ else:
173
+ # Bitstring format: binary values
174
+ fingerprint_bits = df[self.fingerprint_column].apply(
175
+ lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
176
+ )
177
+ return np.vstack(fingerprint_bits), False
158
178
 
159
179
  def _precompute_metrics(self) -> None:
160
180
  """Precompute metrics, adding Tanimoto similarity alongside distance."""
@@ -171,8 +191,13 @@ class FingerprintProximity(Proximity):
171
191
  self.core_columns.extend([self.target, "nn_target", "nn_target_diff"])
172
192
 
173
193
  def _project_2d(self) -> None:
174
- """Project the fingerprint matrix to 2D for visualization using UMAP with Jaccard metric."""
175
- self.df = Projection2D().fit_transform(self.df, feature_matrix=self.X, metric="jaccard")
194
+ """Project the fingerprint matrix to 2D for visualization using UMAP."""
195
+ if self._is_count_fp:
196
+ # For count fingerprints, convert to binary for UMAP projection (Jaccard needs binary)
197
+ X_binary = (self.X > 0).astype(np.bool_)
198
+ self.df = Projection2D().fit_transform(self.df, feature_matrix=X_binary, metric="jaccard")
199
+ else:
200
+ self.df = Projection2D().fit_transform(self.df, feature_matrix=self.X, metric="jaccard")
176
201
 
177
202
  def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
178
203
  """
@@ -240,6 +265,81 @@ class FingerprintProximity(Proximity):
240
265
 
241
266
  return neighbors_df
242
267
 
268
+ def neighbors_from_smiles(
269
+ self,
270
+ smiles: Union[str, List[str]],
271
+ n_neighbors: int = 5,
272
+ min_similarity: Optional[float] = None,
273
+ ) -> pd.DataFrame:
274
+ """
275
+ Find neighbors for SMILES strings not in the reference dataset.
276
+
277
+ Args:
278
+ smiles: Single SMILES string or list of SMILES to query
279
+ n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
280
+ min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)
281
+
282
+ Returns:
283
+ DataFrame containing neighbors with Tanimoto similarity scores.
284
+ The 'query_id' column contains the SMILES string (or index if list).
285
+ """
286
+ # Normalize to list
287
+ smiles_list = [smiles] if isinstance(smiles, str) else smiles
288
+
289
+ # Build a temporary DataFrame with the query SMILES
290
+ query_df = pd.DataFrame({"smiles": smiles_list})
291
+
292
+ # Compute fingerprints using same parameters as the reference dataset
293
+ query_df = compute_morgan_fingerprints(query_df, radius=self._fp_radius, n_bits=self._fp_n_bits)
294
+
295
+ # Transform to matrix (use same format detection as reference)
296
+ X_query, _ = self._fingerprints_to_matrix(query_df)
297
+
298
+ # Query the model
299
+ if min_similarity is not None:
300
+ radius = 1 - min_similarity
301
+ distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
302
+ else:
303
+ distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)
304
+
305
+ # Build results
306
+ results = []
307
+ for i, (dists, nbrs) in enumerate(zip(distances, indices)):
308
+ query_id = smiles_list[i]
309
+
310
+ for neighbor_idx, dist in zip(nbrs, dists):
311
+ neighbor_row = self.df.iloc[neighbor_idx]
312
+ neighbor_id = neighbor_row[self.id_column]
313
+ similarity = 1.0 - dist if dist > 1e-6 else 1.0
314
+
315
+ result = {
316
+ "query_id": query_id,
317
+ "neighbor_id": neighbor_id,
318
+ "similarity": similarity,
319
+ }
320
+
321
+ # Add target if present
322
+ if self.target and self.target in self.df.columns:
323
+ result[self.target] = neighbor_row[self.target]
324
+
325
+ # Include all columns if requested
326
+ if self.include_all_columns:
327
+ for col in self.df.columns:
328
+ if col not in [self.id_column, "query_id", "neighbor_id", "similarity"]:
329
+ result[f"neighbor_{col}"] = neighbor_row[col]
330
+
331
+ results.append(result)
332
+
333
+ df_results = pd.DataFrame(results)
334
+
335
+ # Sort by query_id then similarity descending
336
+ if len(df_results) > 0:
337
+ df_results = df_results.sort_values(["query_id", "similarity"], ascending=[True, False]).reset_index(
338
+ drop=True
339
+ )
340
+
341
+ return df_results
342
+
243
343
 
244
344
  # Testing the FingerprintProximity class
245
345
  if __name__ == "__main__":
@@ -273,12 +373,71 @@ if __name__ == "__main__":
273
373
  )
274
374
  print(prox.neighbors(["a", "b"]))
275
375
 
376
+ # Regression test: include_all_columns should not break neighbor sorting
377
+ print("\n" + "=" * 80)
378
+ print("Regression test: include_all_columns neighbor sorting...")
379
+ print("=" * 80)
380
+ neighbors_all_cols = prox.neighbors("a", n_neighbors=4)
381
+ # Verify neighbors are sorted by similarity (descending), not alphabetically by neighbor_id
382
+ similarities = neighbors_all_cols["similarity"].tolist()
383
+ assert similarities == sorted(
384
+ similarities, reverse=True
385
+ ), f"Neighbors not sorted by similarity! Got: {similarities}"
386
+ # Verify query_id column has correct value (the query, not the neighbor)
387
+ assert all(
388
+ neighbors_all_cols["id"] == "a"
389
+ ), f"Query ID column corrupted! Expected all 'a', got: {neighbors_all_cols['id'].tolist()}"
390
+ print("PASSED: Neighbors correctly sorted by similarity with include_all_columns=True")
391
+
392
+ # Test neighbors_from_smiles with synthetic data
393
+ print("\n" + "=" * 80)
394
+ print("Testing neighbors_from_smiles...")
395
+ print("=" * 80)
396
+
397
+ # Create reference dataset with known SMILES
398
+ ref_data = {
399
+ "id": ["aspirin", "ibuprofen", "naproxen", "caffeine", "ethanol"],
400
+ "smiles": [
401
+ "CC(=O)OC1=CC=CC=C1C(=O)O", # aspirin
402
+ "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # ibuprofen
403
+ "COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1", # naproxen
404
+ "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # caffeine
405
+ "CCO", # ethanol
406
+ ],
407
+ "activity": [1.0, 2.0, 2.5, 3.0, 0.5],
408
+ }
409
+ ref_df = pd.DataFrame(ref_data)
410
+
411
+ prox_ref = FingerprintProximity(ref_df, id_column="id", target="activity", radius=2, n_bits=1024)
412
+
413
+ # Query with a single SMILES (acetaminophen - similar to aspirin)
414
+ query_smiles = "CC(=O)NC1=CC=C(C=C1)O" # acetaminophen
415
+ print(f"\nQuery: acetaminophen ({query_smiles})")
416
+ neighbors = prox_ref.neighbors_from_smiles(query_smiles, n_neighbors=3)
417
+ print(neighbors)
418
+
419
+ # Query with multiple SMILES
420
+ print("\nQuery: multiple SMILES (theophylline, methanol)")
421
+ multi_query = [
422
+ "CN1C=NC2=C1C(=O)NC(=O)N2", # theophylline - similar to caffeine
423
+ "CO", # methanol - similar to ethanol
424
+ ]
425
+ neighbors_multi = prox_ref.neighbors_from_smiles(multi_query, n_neighbors=2)
426
+ print(neighbors_multi)
427
+
428
+ # Test with min_similarity threshold
429
+ print("\nQuery with min_similarity=0.3:")
430
+ neighbors_thresh = prox_ref.neighbors_from_smiles(query_smiles, min_similarity=0.3)
431
+ print(neighbors_thresh)
432
+
433
+ print("PASSED: neighbors_from_smiles working correctly")
434
+
276
435
  # Test on real data from Workbench
277
436
  from workbench.api import FeatureSet, Model
278
437
 
279
438
  fs = FeatureSet("aqsol_features")
280
439
  model = Model("aqsol-regression")
281
- df = fs.pull_dataframe()
440
+ df = fs.pull_dataframe()[:1000] # Limit to 1000 for testing
282
441
  prox = FingerprintProximity(df, id_column=fs.id_column, target=model.target())
283
442
 
284
443
  print("\n" + "=" * 80)
@@ -106,8 +106,14 @@ class Projection2D:
106
106
  return PCA(n_components=2)
107
107
 
108
108
  if projection == "UMAP" and UMAP_AVAILABLE:
109
- self.log.info(f"Projection: UMAP with metric={metric}")
110
- return umap.UMAP(n_components=2, metric=metric)
109
+ # UMAP default n_neighbors=15, adjust if dataset is smaller
110
+ n_neighbors = min(15, len(df) - 1)
111
+ if n_neighbors < 15:
112
+ self.log.warning(
113
+ f"Dataset size ({len(df)}) smaller than default n_neighbors, using n_neighbors={n_neighbors}"
114
+ )
115
+ self.log.info(f"Projection: UMAP with metric={metric}, n_neighbors={n_neighbors}")
116
+ return umap.UMAP(n_components=2, metric=metric, n_neighbors=n_neighbors)
111
117
 
112
118
  self.log.warning(
113
119
  f"Projection method '{projection}' not recognized or UMAP not available. Falling back to TSNE."
@@ -331,5 +331,8 @@ class Proximity(ABC):
331
331
  # Include all columns if requested
332
332
  if self.include_all_columns:
333
333
  result.update(neighbor_row.to_dict())
334
+ # Restore query_id after update (neighbor_row may have overwritten id column)
335
+ result[self.id_column] = query_id
336
+ result["neighbor_id"] = neighbor_id
334
337
 
335
338
  return result
@@ -214,7 +214,6 @@ class FeatureSet(FeatureSetCore):
214
214
  include_all_columns=include_all_columns,
215
215
  radius=radius,
216
216
  n_bits=n_bits,
217
- counts=counts,
218
217
  )
219
218
 
220
219
  def cleanlab_model(