workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic; see the registry's advisory page for more details.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
@@ -1,162 +1,498 @@
1
1
  import pandas as pd
2
2
  import numpy as np
3
3
  from sklearn.neighbors import NearestNeighbors
4
- from typing import Union, List
4
+ from typing import Union, List, Optional
5
5
  import logging
6
6
 
7
7
  # Workbench Imports
8
- from workbench.algorithms.dataframe.proximity import Proximity, ProximityType
8
+ from workbench.algorithms.dataframe.proximity import Proximity
9
+ from workbench.algorithms.dataframe.projection_2d import Projection2D
10
+ from workbench.utils.chem_utils.fingerprints import compute_morgan_fingerprints
9
11
 
10
12
  # Set up logging
11
13
  log = logging.getLogger("workbench")
12
14
 
13
15
 
14
16
  class FingerprintProximity(Proximity):
17
+ """Proximity computations for binary fingerprints using Tanimoto similarity.
18
+
19
+ Note: Tanimoto similarity is equivalent to Jaccard similarity for binary vectors.
20
+ Tanimoto(A, B) = |A ∩ B| / |A ∪ B|
21
+ """
22
+
15
23
  def __init__(
16
- self, df: pd.DataFrame, id_column: Union[int, str], fingerprint_column: str, n_neighbors: int = 5
24
+ self,
25
+ df: pd.DataFrame,
26
+ id_column: str,
27
+ fingerprint_column: Optional[str] = None,
28
+ target: Optional[str] = None,
29
+ include_all_columns: bool = False,
30
+ radius: int = 2,
31
+ n_bits: int = 1024,
17
32
  ) -> None:
18
33
  """
19
34
  Initialize the FingerprintProximity class for binary fingerprint similarity.
20
35
 
21
36
  Args:
22
- df (pd.DataFrame): DataFrame containing fingerprints.
23
- id_column (Union[int, str]): Name of the column used as an identifier.
24
- fingerprint_column (str): Name of the column containing fingerprints.
25
- n_neighbors (int): Default number of neighbors to compute.
37
+ df: DataFrame containing fingerprints or SMILES.
38
+ id_column: Name of the column used as an identifier.
39
+ fingerprint_column: Name of the column containing fingerprints (bit strings).
40
+ If None, looks for existing "fingerprint" column or computes from SMILES.
41
+ target: Name of the target column. Defaults to None.
42
+ include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
43
+ radius: Radius for Morgan fingerprint computation (default: 2).
44
+ n_bits: Number of bits for fingerprint (default: 1024).
45
+ """
46
+ # Store fingerprint computation parameters
47
+ self._fp_radius = radius
48
+ self._fp_n_bits = n_bits
49
+
50
+ # Store the requested fingerprint column (may be None)
51
+ self._fingerprint_column_arg = fingerprint_column
52
+
53
+ # Determine fingerprint column name (but don't compute yet - that happens in _prepare_data)
54
+ self.fingerprint_column = self._resolve_fingerprint_column_name(df, fingerprint_column)
55
+
56
+ # Call parent constructor with fingerprint_column as the only "feature"
57
+ super().__init__(
58
+ df,
59
+ id_column=id_column,
60
+ features=[self.fingerprint_column],
61
+ target=target,
62
+ include_all_columns=include_all_columns,
63
+ )
64
+
65
+ @staticmethod
66
+ def _resolve_fingerprint_column_name(df: pd.DataFrame, fingerprint_column: Optional[str]) -> str:
67
+ """
68
+ Determine the fingerprint column name, validating it exists or can be computed.
69
+
70
+ Args:
71
+ df: Input DataFrame.
72
+ fingerprint_column: Explicitly specified fingerprint column, or None.
73
+
74
+ Returns:
75
+ Name of the fingerprint column to use.
76
+
77
+ Raises:
78
+ ValueError: If no fingerprint column exists and no SMILES column found.
26
79
  """
27
- self.fingerprint_column = fingerprint_column
80
+ # If explicitly provided, validate it exists
81
+ if fingerprint_column is not None:
82
+ if fingerprint_column not in df.columns:
83
+ raise ValueError(f"Fingerprint column '{fingerprint_column}' not found in DataFrame")
84
+ return fingerprint_column
85
+
86
+ # Check for existing "fingerprint" column
87
+ if "fingerprint" in df.columns:
88
+ log.info("Using existing 'fingerprint' column")
89
+ return "fingerprint"
90
+
91
+ # Will need to compute from SMILES - validate SMILES column exists
92
+ smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
93
+ if smiles_column is None:
94
+ raise ValueError(
95
+ "No fingerprint column provided and no SMILES column found. "
96
+ "Either provide a fingerprint_column or include a 'smiles' column in the DataFrame."
97
+ )
28
98
 
29
- # Call the parent class constructor
30
- super().__init__(df, id_column=id_column, features=[fingerprint_column], n_neighbors=n_neighbors)
99
+ # Fingerprints will be computed in _prepare_data
100
+ return "fingerprint"
31
101
 
32
- # Override the build_proximity_model method
33
- def build_proximity_model(self) -> None:
102
+ def _prepare_data(self) -> None:
103
+ """Compute fingerprints from SMILES if needed."""
104
+ # If fingerprint column doesn't exist yet, compute it
105
+ if self.fingerprint_column not in self.df.columns:
106
+ log.info(f"Computing Morgan fingerprints (radius={self._fp_radius}, n_bits={self._fp_n_bits})...")
107
+ self.df = compute_morgan_fingerprints(self.df, radius=self._fp_radius, n_bits=self._fp_n_bits)
108
+
109
+ def _build_model(self) -> None:
34
110
  """
35
- Prepare the fingerprint data for nearest neighbor calculations.
36
- Converts fingerprint strings to binary arrays and initializes NearestNeighbors.
111
+ Build the fingerprint proximity model for Tanimoto similarity.
112
+
113
+ For binary fingerprints: uses Jaccard distance (1 - Tanimoto)
114
+ For count fingerprints: uses weighted Tanimoto (Ruzicka) distance
37
115
  """
38
- log.info("Converting fingerprints to binary feature matrix...")
39
- self.proximity_type = ProximityType.SIMILARITY
116
+ # Convert fingerprint strings to matrix and detect format
117
+ self.X, self._is_count_fp = self._fingerprints_to_matrix(self.df)
40
118
 
41
- # Convert fingerprint strings to binary arrays
119
+ if self._is_count_fp:
120
+ # Weighted Tanimoto (Ruzicka) for count vectors: 1 - Σmin(A,B)/Σmax(A,B)
121
+ log.info("Building NearestNeighbors model (weighted Tanimoto for count fingerprints)...")
42
122
 
43
- fingerprint_bits = self.df[self.fingerprint_column].apply(
44
- lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
45
- )
46
- self.X = np.vstack(fingerprint_bits)
123
+ def ruzicka_distance(a, b):
124
+ """Ruzicka distance = 1 - weighted Tanimoto similarity."""
125
+ min_sum = np.minimum(a, b).sum()
126
+ max_sum = np.maximum(a, b).sum()
127
+ if max_sum == 0:
128
+ return 0.0
129
+ return 1.0 - (min_sum / max_sum)
47
130
 
48
- # Use Jaccard similarity for binary fingerprints
49
- log.info("Computing NearestNeighbors with Jaccard metric...")
50
- self.nn = NearestNeighbors(metric="jaccard", n_neighbors=self.n_neighbors + 1).fit(self.X)
131
+ self.nn = NearestNeighbors(metric=ruzicka_distance, algorithm="ball_tree").fit(self.X)
132
+ else:
133
+ # Standard Jaccard for binary fingerprints
134
+ log.info("Building NearestNeighbors model (Jaccard/Tanimoto for binary fingerprints)...")
135
+ self.nn = NearestNeighbors(metric="jaccard", algorithm="ball_tree").fit(self.X)
51
136
 
52
- # Override the prep_features_for_query method
53
- def prep_features_for_query(self, query_df: pd.DataFrame) -> np.ndarray:
137
+ def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
54
138
  """
55
- Prepare the query DataFrame by converting fingerprints to binary arrays.
139
+ Transform fingerprints to matrix for querying.
56
140
 
57
141
  Args:
58
- query_df (pd.DataFrame): DataFrame containing query fingerprints.
142
+ df: DataFrame containing fingerprints to transform.
59
143
 
60
144
  Returns:
61
- np.ndarray: Binary feature matrix for the query fingerprints.
145
+ Feature matrix for the fingerprints (binary or count based on self._is_count_fp).
62
146
  """
63
- fingerprint_bits = query_df[self.fingerprint_column].apply(
64
- lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
65
- )
66
- return np.vstack(fingerprint_bits)
147
+ matrix, _ = self._fingerprints_to_matrix(df)
148
+ return matrix
67
149
 
68
- def all_neighbors(
69
- self,
70
- min_similarity: float = None,
71
- include_self: bool = False,
72
- add_columns: List[str] = None,
73
- ) -> pd.DataFrame:
150
+ def _fingerprints_to_matrix(self, df: pd.DataFrame) -> tuple[np.ndarray, bool]:
74
151
  """
75
- Find neighbors for all fingerprints in the dataset.
152
+ Convert fingerprint strings to a numpy matrix.
153
+
154
+ Supports two formats (auto-detected):
155
+ - Bitstrings: "10110010..." → binary matrix (bool), is_count=False
156
+ - Count vectors: "0,3,0,1,5,..." → count matrix (uint8), is_count=True
76
157
 
77
158
  Args:
78
- min_similarity: Minimum similarity threshold (0-1)
79
- include_self: Whether to include self in results
80
- add_columns: Additional columns to include in results
159
+ df: DataFrame containing fingerprint column.
81
160
 
82
161
  Returns:
83
- DataFrame containing neighbors and similarities
162
+ Tuple of (2D numpy array, is_count_fingerprint boolean)
84
163
  """
164
+ # Auto-detect format based on first fingerprint
165
+ sample = str(df[self.fingerprint_column].iloc[0])
166
+ if "," in sample:
167
+ # Count vector format: preserve counts for weighted Tanimoto
168
+ fingerprint_values = df[self.fingerprint_column].apply(
169
+ lambda fp: np.array([int(x) for x in fp.split(",")], dtype=np.uint8)
170
+ )
171
+ return np.vstack(fingerprint_values), True
172
+ else:
173
+ # Bitstring format: binary values
174
+ fingerprint_bits = df[self.fingerprint_column].apply(
175
+ lambda fp: np.array([int(bit) for bit in fp], dtype=np.bool_)
176
+ )
177
+ return np.vstack(fingerprint_bits), False
85
178
 
86
- # Call the parent class method to find neighbors
87
- return self.neighbors(
88
- query_df=self.df,
89
- min_similarity=min_similarity,
90
- include_self=include_self,
91
- add_columns=add_columns,
179
+ def _precompute_metrics(self) -> None:
180
+ """Precompute metrics, adding Tanimoto similarity alongside distance."""
181
+ # Call parent to compute nn_distance (Jaccard), nn_id, nn_target, nn_target_diff
182
+ super()._precompute_metrics()
183
+
184
+ # Add Tanimoto similarity (keep nn_distance for internal use by target_gradients)
185
+ self.df["nn_similarity"] = 1 - self.df["nn_distance"]
186
+
187
+ def _set_core_columns(self) -> None:
188
+ """Set core columns using nn_similarity instead of nn_distance."""
189
+ self.core_columns = [self.id_column, "nn_similarity", "nn_id"]
190
+ if self.target:
191
+ self.core_columns.extend([self.target, "nn_target", "nn_target_diff"])
192
+
193
+ def _project_2d(self) -> None:
194
+ """Project the fingerprint matrix to 2D for visualization using UMAP."""
195
+ if self._is_count_fp:
196
+ # For count fingerprints, convert to binary for UMAP projection (Jaccard needs binary)
197
+ X_binary = (self.X > 0).astype(np.bool_)
198
+ self.df = Projection2D().fit_transform(self.df, feature_matrix=X_binary, metric="jaccard")
199
+ else:
200
+ self.df = Projection2D().fit_transform(self.df, feature_matrix=self.X, metric="jaccard")
201
+
202
+ def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
203
+ """
204
+ Find isolated data points based on Tanimoto similarity to nearest neighbor.
205
+
206
+ Args:
207
+ top_percent: Percentage of most isolated data points to return (e.g., 1.0 returns top 1%)
208
+
209
+ Returns:
210
+ DataFrame of observations with lowest Tanimoto similarity, sorted ascending
211
+ """
212
+ # For Tanimoto similarity, isolated means LOW similarity to nearest neighbor
213
+ percentile = top_percent
214
+ threshold = np.percentile(self.df["nn_similarity"], percentile)
215
+ isolated = self.df[self.df["nn_similarity"] <= threshold].copy()
216
+ isolated = isolated.sort_values("nn_similarity", ascending=True).reset_index(drop=True)
217
+ return isolated if self.include_all_columns else isolated[self.core_columns]
218
+
219
+ def proximity_stats(self) -> pd.DataFrame:
220
+ """
221
+ Return distribution statistics for nearest neighbor Tanimoto similarity.
222
+
223
+ Returns:
224
+ DataFrame with similarity distribution statistics (count, mean, std, percentiles)
225
+ """
226
+ return (
227
+ self.df["nn_similarity"]
228
+ .describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
229
+ .to_frame()
92
230
  )
93
231
 
94
232
  def neighbors(
95
233
  self,
96
- query_df: pd.DataFrame,
97
- min_similarity: float = None,
98
- include_self: bool = False,
99
- add_columns: List[str] = None,
234
+ id_or_ids: Union[str, int, List[Union[str, int]]],
235
+ n_neighbors: Optional[int] = 5,
236
+ min_similarity: Optional[float] = None,
237
+ include_self: bool = True,
100
238
  ) -> pd.DataFrame:
101
239
  """
102
- Find neighbors for each row in the query DataFrame.
240
+ Return neighbors for ID(s) from the existing dataset.
103
241
 
104
242
  Args:
105
- query_df: DataFrame containing query fingerprints
106
- min_similarity: Minimum similarity threshold (0-1)
107
- include_self: Whether to include self in results (if present)
108
- add_columns: Additional columns to include in results
243
+ id_or_ids: Single ID or list of IDs to look up
244
+ n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
245
+ min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)
246
+ include_self: Whether to include self in results (default: True)
109
247
 
110
248
  Returns:
111
- DataFrame containing neighbors and similarities
112
-
113
- Note: The query DataFrame must include the feature columns. The id_column is optional.
249
+ DataFrame containing neighbors with Tanimoto similarity scores
114
250
  """
115
-
116
- # Calculate radius from similarity if provided
251
+ # Convert min_similarity to radius (Jaccard distance = 1 - Tanimoto similarity)
117
252
  radius = 1 - min_similarity if min_similarity is not None else None
118
253
 
119
- # Call the parent class method to find neighbors
254
+ # Call parent method (returns Jaccard distance)
120
255
  neighbors_df = super().neighbors(
121
- query_df=query_df,
256
+ id_or_ids=id_or_ids,
257
+ n_neighbors=n_neighbors,
122
258
  radius=radius,
123
259
  include_self=include_self,
124
- add_columns=add_columns,
125
260
  )
126
261
 
127
- # Convert distances to similarity
262
+ # Convert Jaccard distance to Tanimoto similarity
128
263
  neighbors_df["similarity"] = 1 - neighbors_df["distance"]
129
264
  neighbors_df.drop(columns=["distance"], inplace=True)
265
+
130
266
  return neighbors_df
131
267
 
268
+ def neighbors_from_smiles(
269
+ self,
270
+ smiles: Union[str, List[str]],
271
+ n_neighbors: int = 5,
272
+ min_similarity: Optional[float] = None,
273
+ ) -> pd.DataFrame:
274
+ """
275
+ Find neighbors for SMILES strings not in the reference dataset.
276
+
277
+ Args:
278
+ smiles: Single SMILES string or list of SMILES to query
279
+ n_neighbors: Number of neighbors to return (default: 5, ignored if min_similarity is set)
280
+ min_similarity: If provided, find all neighbors with Tanimoto similarity >= this value (0-1)
281
+
282
+ Returns:
283
+ DataFrame containing neighbors with Tanimoto similarity scores.
284
+ The 'query_id' column contains the SMILES string (or index if list).
285
+ """
286
+ # Normalize to list
287
+ smiles_list = [smiles] if isinstance(smiles, str) else smiles
288
+
289
+ # Build a temporary DataFrame with the query SMILES
290
+ query_df = pd.DataFrame({"smiles": smiles_list})
291
+
292
+ # Compute fingerprints using same parameters as the reference dataset
293
+ query_df = compute_morgan_fingerprints(query_df, radius=self._fp_radius, n_bits=self._fp_n_bits)
294
+
295
+ # Transform to matrix (use same format detection as reference)
296
+ X_query, _ = self._fingerprints_to_matrix(query_df)
297
+
298
+ # Query the model
299
+ if min_similarity is not None:
300
+ radius = 1 - min_similarity
301
+ distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
302
+ else:
303
+ distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)
304
+
305
+ # Build results
306
+ results = []
307
+ for i, (dists, nbrs) in enumerate(zip(distances, indices)):
308
+ query_id = smiles_list[i]
309
+
310
+ for neighbor_idx, dist in zip(nbrs, dists):
311
+ neighbor_row = self.df.iloc[neighbor_idx]
312
+ neighbor_id = neighbor_row[self.id_column]
313
+ similarity = 1.0 - dist if dist > 1e-6 else 1.0
314
+
315
+ result = {
316
+ "query_id": query_id,
317
+ "neighbor_id": neighbor_id,
318
+ "similarity": similarity,
319
+ }
320
+
321
+ # Add target if present
322
+ if self.target and self.target in self.df.columns:
323
+ result[self.target] = neighbor_row[self.target]
324
+
325
+ # Include all columns if requested
326
+ if self.include_all_columns:
327
+ for col in self.df.columns:
328
+ if col not in [self.id_column, "query_id", "neighbor_id", "similarity"]:
329
+ result[f"neighbor_{col}"] = neighbor_row[col]
330
+
331
+ results.append(result)
332
+
333
+ df_results = pd.DataFrame(results)
334
+
335
+ # Sort by query_id then similarity descending
336
+ if len(df_results) > 0:
337
+ df_results = df_results.sort_values(["query_id", "similarity"], ascending=[True, False]).reset_index(
338
+ drop=True
339
+ )
340
+
341
+ return df_results
342
+
132
343
 
133
344
  # Testing the FingerprintProximity class
134
345
  if __name__ == "__main__":
135
346
  pd.set_option("display.max_columns", None)
136
347
  pd.set_option("display.width", 1000)
137
348
 
138
- # Example DataFrame
349
+ # Create an Example DataFrame with fingerprints
139
350
  data = {
140
- "id": ["a", "b", "c", "d"],
141
- "fingerprint": ["101010", "111010", "101110", "011100"],
351
+ "id": ["a", "b", "c", "d", "e"],
352
+ "fingerprint": ["101010", "111010", "101110", "011100", "000111"],
353
+ "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
354
+ "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
355
+ "target": [1, 0, 1, 0, 5],
142
356
  }
143
357
  df = pd.DataFrame(data)
144
358
 
145
- # Initialize the FingerprintProximity class
146
- proximity = FingerprintProximity(df, fingerprint_column="fingerprint", id_column="id", n_neighbors=3)
359
+ # Test basic FingerprintProximity with explicit fingerprint column
360
+ prox = FingerprintProximity(df, fingerprint_column="fingerprint", id_column="id", target="target")
361
+ print(prox.neighbors("a", n_neighbors=3))
362
+
363
+ # Test neighbors with similarity threshold
364
+ print(prox.neighbors("a", min_similarity=0.5))
365
+
366
+ # Test with include_all_columns=True
367
+ prox = FingerprintProximity(
368
+ df,
369
+ fingerprint_column="fingerprint",
370
+ id_column="id",
371
+ target="target",
372
+ include_all_columns=True,
373
+ )
374
+ print(prox.neighbors(["a", "b"]))
375
+
376
+ # Regression test: include_all_columns should not break neighbor sorting
377
+ print("\n" + "=" * 80)
378
+ print("Regression test: include_all_columns neighbor sorting...")
379
+ print("=" * 80)
380
+ neighbors_all_cols = prox.neighbors("a", n_neighbors=4)
381
+ # Verify neighbors are sorted by similarity (descending), not alphabetically by neighbor_id
382
+ similarities = neighbors_all_cols["similarity"].tolist()
383
+ assert similarities == sorted(
384
+ similarities, reverse=True
385
+ ), f"Neighbors not sorted by similarity! Got: {similarities}"
386
+ # Verify query_id column has correct value (the query, not the neighbor)
387
+ assert all(
388
+ neighbors_all_cols["id"] == "a"
389
+ ), f"Query ID column corrupted! Expected all 'a', got: {neighbors_all_cols['id'].tolist()}"
390
+ print("PASSED: Neighbors correctly sorted by similarity with include_all_columns=True")
391
+
392
+ # Test neighbors_from_smiles with synthetic data
393
+ print("\n" + "=" * 80)
394
+ print("Testing neighbors_from_smiles...")
395
+ print("=" * 80)
396
+
397
+ # Create reference dataset with known SMILES
398
+ ref_data = {
399
+ "id": ["aspirin", "ibuprofen", "naproxen", "caffeine", "ethanol"],
400
+ "smiles": [
401
+ "CC(=O)OC1=CC=CC=C1C(=O)O", # aspirin
402
+ "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # ibuprofen
403
+ "COC1=CC2=CC(C(C)C(O)=O)=CC=C2C=C1", # naproxen
404
+ "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # caffeine
405
+ "CCO", # ethanol
406
+ ],
407
+ "activity": [1.0, 2.0, 2.5, 3.0, 0.5],
408
+ }
409
+ ref_df = pd.DataFrame(ref_data)
410
+
411
+ prox_ref = FingerprintProximity(ref_df, id_column="id", target="activity", radius=2, n_bits=1024)
412
+
413
+ # Query with a single SMILES (acetaminophen - similar to aspirin)
414
+ query_smiles = "CC(=O)NC1=CC=C(C=C1)O" # acetaminophen
415
+ print(f"\nQuery: acetaminophen ({query_smiles})")
416
+ neighbors = prox_ref.neighbors_from_smiles(query_smiles, n_neighbors=3)
417
+ print(neighbors)
418
+
419
+ # Query with multiple SMILES
420
+ print("\nQuery: multiple SMILES (theophylline, methanol)")
421
+ multi_query = [
422
+ "CN1C=NC2=C1C(=O)NC(=O)N2", # theophylline - similar to caffeine
423
+ "CO", # methanol - similar to ethanol
424
+ ]
425
+ neighbors_multi = prox_ref.neighbors_from_smiles(multi_query, n_neighbors=2)
426
+ print(neighbors_multi)
427
+
428
+ # Test with min_similarity threshold
429
+ print("\nQuery with min_similarity=0.3:")
430
+ neighbors_thresh = prox_ref.neighbors_from_smiles(query_smiles, min_similarity=0.3)
431
+ print(neighbors_thresh)
432
+
433
+ print("PASSED: neighbors_from_smiles working correctly")
434
+
435
+ # Test on real data from Workbench
436
+ from workbench.api import FeatureSet, Model
437
+
438
+ fs = FeatureSet("aqsol_features")
439
+ model = Model("aqsol-regression")
440
+ df = fs.pull_dataframe()[:1000] # Limit to 1000 for testing
441
+ prox = FingerprintProximity(df, id_column=fs.id_column, target=model.target())
442
+
443
+ print("\n" + "=" * 80)
444
+ print("Testing Neighbors...")
445
+ print("=" * 80)
446
+ test_id = df[fs.id_column].tolist()[0]
447
+ print(f"\nNeighbors for ID {test_id}:")
448
+ print(prox.neighbors(test_id))
449
+
450
+ print("\n" + "=" * 80)
451
+ print("Testing isolated compounds...")
452
+ print("=" * 80)
453
+
454
+ # Test isolated data in the top 1%
455
+ isolated_1pct = prox.isolated(top_percent=1.0)
456
+ print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
457
+ print(isolated_1pct)
458
+
459
+ # Test isolated data in the top 5%
460
+ isolated_5pct = prox.isolated(top_percent=5.0)
461
+ print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
462
+ print(isolated_5pct)
463
+
464
+ print("\n" + "=" * 80)
465
+ print("Testing target_gradients...")
466
+ print("=" * 80)
467
+
468
+ # Test with different parameters
469
+ gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
470
+ print(f"\nTop 1% target gradients (min_delta=1.0) (n={len(gradients_1pct)}):")
471
+ print(gradients_1pct)
472
+
473
+ gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
474
+ print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
475
+ print(gradients_5pct)
476
+
477
+ # Test proximity_stats
478
+ print("\n" + "=" * 80)
479
+ print("Testing proximity_stats...")
480
+ print("=" * 80)
481
+ stats = prox.proximity_stats()
482
+ print(stats)
147
483
 
148
- # Test 1: All neighbors
149
- print("\n--- Test 1: All Neighbors ---")
150
- all_neighbors_df = proximity.all_neighbors()
151
- print(all_neighbors_df)
484
+ # Plot the similarity distribution using pandas
485
+ print("\n" + "=" * 80)
486
+ print("Plotting similarity distribution...")
487
+ print("=" * 80)
488
+ prox.df["nn_similarity"].hist(bins=50, figsize=(10, 6), edgecolor="black")
152
489
 
153
- # Test 2: Neighbors for a specific query
154
- print("\n--- Test 2: Neighbors for Query ---")
155
- query_df = pd.DataFrame({"id": ["a"], "fingerprint": ["101010"]})
156
- query_neighbors_df = proximity.neighbors(query_df=query_df)
157
- print(query_neighbors_df)
490
+ # Visualize the 2D projection
491
+ print("\n" + "=" * 80)
492
+ print("Visualizing 2D Projection...")
493
+ print("=" * 80)
494
+ from workbench.web_interface.components.plugin_unit_test import PluginUnitTest
495
+ from workbench.web_interface.components.plugins.scatter_plot import ScatterPlot
158
496
 
159
- # Test 3: Neighbors with similarity threshold
160
- print("\n--- Test 3: Neighbors with Minimum Similarity 0.5 ---")
161
- query_neighbors_sim_df = proximity.neighbors(query_df=query_df, min_similarity=0.5)
162
- print(query_neighbors_sim_df)
497
+ unit_test = PluginUnitTest(ScatterPlot, input_data=prox.df[:1000], x="x", y="y", color=model.target())
498
+ unit_test.run()