workbench 0.8.177__py3-none-any.whl → 0.8.227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (140)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +1 -2
  3. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  4. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  5. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  6. workbench/algorithms/dataframe/projection_2d.py +44 -21
  7. workbench/algorithms/dataframe/proximity.py +259 -305
  8. workbench/algorithms/graph/light/proximity_graph.py +12 -11
  9. workbench/algorithms/models/cleanlab_model.py +382 -0
  10. workbench/algorithms/models/noise_model.py +388 -0
  11. workbench/algorithms/sql/column_stats.py +0 -1
  12. workbench/algorithms/sql/correlations.py +0 -1
  13. workbench/algorithms/sql/descriptive_stats.py +0 -1
  14. workbench/algorithms/sql/outliers.py +3 -3
  15. workbench/api/__init__.py +5 -1
  16. workbench/api/df_store.py +17 -108
  17. workbench/api/endpoint.py +14 -12
  18. workbench/api/feature_set.py +117 -11
  19. workbench/api/meta.py +0 -1
  20. workbench/api/meta_model.py +289 -0
  21. workbench/api/model.py +52 -21
  22. workbench/api/parameter_store.py +3 -52
  23. workbench/cached/cached_meta.py +0 -1
  24. workbench/cached/cached_model.py +49 -11
  25. workbench/core/artifacts/__init__.py +11 -2
  26. workbench/core/artifacts/artifact.py +5 -5
  27. workbench/core/artifacts/df_store_core.py +114 -0
  28. workbench/core/artifacts/endpoint_core.py +319 -204
  29. workbench/core/artifacts/feature_set_core.py +249 -45
  30. workbench/core/artifacts/model_core.py +135 -82
  31. workbench/core/artifacts/parameter_store_core.py +98 -0
  32. workbench/core/cloud_platform/cloud_meta.py +0 -1
  33. workbench/core/pipelines/pipeline_executor.py +1 -1
  34. workbench/core/transforms/features_to_model/features_to_model.py +60 -44
  35. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +43 -10
  36. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  37. workbench/core/views/training_view.py +113 -42
  38. workbench/core/views/view.py +53 -3
  39. workbench/core/views/view_utils.py +4 -4
  40. workbench/model_script_utils/model_script_utils.py +339 -0
  41. workbench/model_script_utils/pytorch_utils.py +405 -0
  42. workbench/model_script_utils/uq_harness.py +277 -0
  43. workbench/model_scripts/chemprop/chemprop.template +774 -0
  44. workbench/model_scripts/chemprop/generated_model_script.py +774 -0
  45. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  46. workbench/model_scripts/chemprop/requirements.txt +3 -0
  47. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  48. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  49. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  50. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -2
  51. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  52. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  53. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  54. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  55. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  56. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  57. workbench/model_scripts/custom_models/uq_models/ngboost.template +15 -16
  58. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  59. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  60. workbench/model_scripts/meta_model/meta_model.template +209 -0
  61. workbench/model_scripts/pytorch_model/generated_model_script.py +443 -499
  62. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  63. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  64. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  65. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  66. workbench/model_scripts/pytorch_model/uq_harness.py +277 -0
  67. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  68. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  69. workbench/model_scripts/script_generation.py +15 -12
  70. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  71. workbench/model_scripts/xgb_model/generated_model_script.py +371 -403
  72. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  73. workbench/model_scripts/xgb_model/uq_harness.py +277 -0
  74. workbench/model_scripts/xgb_model/xgb_model.template +367 -399
  75. workbench/repl/workbench_shell.py +18 -14
  76. workbench/resources/open_source_api.key +1 -1
  77. workbench/scripts/endpoint_test.py +162 -0
  78. workbench/scripts/lambda_test.py +73 -0
  79. workbench/scripts/meta_model_sim.py +35 -0
  80. workbench/scripts/ml_pipeline_sqs.py +122 -6
  81. workbench/scripts/training_test.py +85 -0
  82. workbench/themes/dark/custom.css +59 -0
  83. workbench/themes/dark/plotly.json +5 -5
  84. workbench/themes/light/custom.css +153 -40
  85. workbench/themes/light/plotly.json +9 -9
  86. workbench/themes/midnight_blue/custom.css +59 -0
  87. workbench/utils/aws_utils.py +0 -1
  88. workbench/utils/chem_utils/fingerprints.py +87 -46
  89. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  90. workbench/utils/chem_utils/projections.py +16 -6
  91. workbench/utils/chem_utils/vis.py +25 -27
  92. workbench/utils/chemprop_utils.py +141 -0
  93. workbench/utils/config_manager.py +2 -6
  94. workbench/utils/endpoint_utils.py +5 -7
  95. workbench/utils/license_manager.py +2 -6
  96. workbench/utils/markdown_utils.py +57 -0
  97. workbench/utils/meta_model_simulator.py +499 -0
  98. workbench/utils/metrics_utils.py +256 -0
  99. workbench/utils/model_utils.py +260 -76
  100. workbench/utils/pipeline_utils.py +0 -1
  101. workbench/utils/plot_utils.py +159 -34
  102. workbench/utils/pytorch_utils.py +87 -0
  103. workbench/utils/shap_utils.py +11 -57
  104. workbench/utils/theme_manager.py +95 -30
  105. workbench/utils/xgboost_local_crossfold.py +267 -0
  106. workbench/utils/xgboost_model_utils.py +127 -220
  107. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  108. workbench/web_interface/components/model_plot.py +16 -2
  109. workbench/web_interface/components/plugin_unit_test.py +5 -3
  110. workbench/web_interface/components/plugins/ag_table.py +2 -4
  111. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  112. workbench/web_interface/components/plugins/model_details.py +48 -80
  113. workbench/web_interface/components/plugins/scatter_plot.py +192 -92
  114. workbench/web_interface/components/settings_menu.py +184 -0
  115. workbench/web_interface/page_views/main_page.py +0 -1
  116. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/METADATA +31 -17
  117. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/RECORD +121 -106
  118. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/entry_points.txt +4 -0
  119. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/licenses/LICENSE +1 -1
  120. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  121. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  122. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  123. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  124. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  125. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -494
  126. workbench/model_scripts/custom_models/uq_models/mapie.template +0 -494
  127. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -386
  128. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  129. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  130. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  131. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  132. workbench/themes/quartz/base_css.url +0 -1
  133. workbench/themes/quartz/custom.css +0 -117
  134. workbench/themes/quartz/plotly.json +0 -642
  135. workbench/themes/quartz_dark/base_css.url +0 -1
  136. workbench/themes/quartz_dark/custom.css +0 -131
  137. workbench/themes/quartz_dark/plotly.json +0 -642
  138. workbench/utils/resource_utils.py +0 -39
  139. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/WHEEL +0 -0
  140. {workbench-0.8.177.dist-info → workbench-0.8.227.dist-info}/top_level.txt +0 -0
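
The diff below covers workbench/algorithms/dataframe/proximity.py (item 7 above): the @@ -1,384 +1,338 @@ hunk header matches that file's -305/+259 change (384 - 305 + 259 = 338).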
@@ -1,384 +1,338 @@
 import pandas as pd
 import numpy as np
-from sklearn.preprocessing import StandardScaler
-from sklearn.neighbors import NearestNeighbors
-from typing import List, Dict
+from abc import ABC, abstractmethod
+from typing import List, Dict, Optional, Union
 import logging
-import pickle
-import os
-import json
-from pathlib import Path
-from enum import Enum
 
 # Set up logging
 log = logging.getLogger("workbench")
 
 
-# Enumerated Proximity Types (distance or similarity)
-class ProximityType(Enum):
-    DISTANCE = "distance"
-    SIMILARITY = "similarity"
+class Proximity(ABC):
+    """Abstract base class for proximity/neighbor computations."""
 
-
-class Proximity:
     def __init__(
         self,
         df: pd.DataFrame,
         id_column: str,
         features: List[str],
-        target: str = None,
-        track_columns: List[str] = None,
-        n_neighbors: int = 10,
+        target: Optional[str] = None,
+        include_all_columns: bool = False,
     ):
         """
         Initialize the Proximity class.
 
         Args:
-            df (pd.DataFrame): DataFrame containing data for neighbor computations.
-            id_column (str): Name of the column used as the identifier.
-            features (List[str]): List of feature column names to be used for neighbor computations.
-            target (str, optional): Name of the target column. Defaults to None.
-            track_columns (List[str], optional): Additional columns to track in results. Defaults to None.
-            n_neighbors (int): Number of neighbors to compute. Defaults to 10.
+            df: DataFrame containing data for neighbor computations.
+            id_column: Name of the column used as the identifier.
+            features: List of feature column names to be used for neighbor computations.
+            target: Name of the target column. Defaults to None.
+            include_all_columns: Include all DataFrame columns in neighbor results. Defaults to False.
         """
-        self.df = df.dropna(subset=features).copy()
         self.id_column = id_column
-        self.n_neighbors = min(n_neighbors, len(self.df) - 1)
-        self.target = target
         self.features = features
-        self.scaler = None
-        self.X = None
-        self.nn = None
-        self.proximity_type = None
-        self.track_columns = track_columns or []
-
-        # Right now we only support numeric features, so remove any columns that are not numeric
-        non_numeric_features = self.df[self.features].select_dtypes(exclude=["number"]).columns.tolist()
-        if non_numeric_features:
-            log.warning(f"Non-numeric features {non_numeric_features} aren't currently supported...")
-            self.features = [f for f in self.features if f not in non_numeric_features]
-
-        # Build the proximity model
-        self.build_proximity_model()
-
-    def build_proximity_model(self) -> None:
-        """Standardize features and fit Nearest Neighbors model.
-        Note: This method can be overridden in subclasses for custom behavior."""
-        self.proximity_type = ProximityType.DISTANCE
-        self.scaler = StandardScaler()
-        self.X = self.scaler.fit_transform(self.df[self.features])
-        self.nn = NearestNeighbors(n_neighbors=self.n_neighbors + 1).fit(self.X)
-
-    def all_neighbors(self) -> pd.DataFrame:
+        self.target = target
+        self.include_all_columns = include_all_columns
+
+        # Store the DataFrame (subclasses may filter/modify in _prepare_data)
+        self.df = df.copy()
+
+        # Prepare data (subclasses can override)
+        self._prepare_data()
+
+        # Compute target range if target is provided
+        self.target_range = None
+        if self.target and self.target in self.df.columns:
+            self.target_range = self.df[self.target].max() - self.df[self.target].min()
+
+        # Build the proximity model (subclass-specific)
+        self._build_model()
+
+        # Precompute landscape metrics
+        self._precompute_metrics()
+
+        # Define core columns for output (subclasses can override)
+        self._set_core_columns()
+
+        # Project the data to 2D (subclass-specific)
+        self._project_2d()
+
+    def _prepare_data(self) -> None:
+        """Prepare the data before building the model. Subclasses can override."""
+        pass
+
+    def _set_core_columns(self) -> None:
+        """Set the core columns for output. Subclasses can override."""
+        self.core_columns = [self.id_column, "nn_distance", "nn_id"]
+        if self.target:
+            self.core_columns.extend([self.target, "nn_target", "nn_target_diff"])
+
+    @abstractmethod
+    def _build_model(self) -> None:
+        """Build the proximity model. Must set self.nn (NearestNeighbors instance)."""
+        pass
+
+    @abstractmethod
+    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
+        """Transform features for querying. Returns feature matrix for nearest neighbor lookup."""
+        pass
+
+    @abstractmethod
+    def _project_2d(self) -> None:
+        """Project the data to 2D for visualization. Updates self.df with 'x' and 'y' columns."""
+        pass
+
+    def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
         """
-        Compute nearest neighbors for all rows in the dataset.
+        Find isolated data points based on distance to nearest neighbor.
+
+        Args:
+            top_percent: Percentage of most isolated data points to return (e.g., 1.0 returns top 1%)
 
         Returns:
-            pd.DataFrame: A DataFrame of neighbors and their distances.
+            DataFrame of observations above the percentile threshold, sorted by distance (descending)
         """
-        distances, indices = self.nn.kneighbors(self.X)
-        results = []
+        percentile = 100 - top_percent
+        threshold = np.percentile(self.df["nn_distance"], percentile)
+        isolated = self.df[self.df["nn_distance"] >= threshold].copy()
+        isolated = isolated.sort_values("nn_distance", ascending=False).reset_index(drop=True)
+        return isolated if self.include_all_columns else isolated[self.core_columns]
 
-        for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-            query_id = self.df.iloc[i][self.id_column]
+    def proximity_stats(self) -> pd.DataFrame:
+        """
+        Return distribution statistics for nearest neighbor distances.
 
-            # Process neighbors
-            for neighbor_idx, dist in zip(nbrs, dists):
-                # Skip self (neighbor index == current row index)
-                if neighbor_idx == i:
-                    continue
-                results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
+        Returns:
+            DataFrame with proximity distribution statistics (count, mean, std, percentiles)
+        """
+        return (
+            self.df["nn_distance"].describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_frame()
+        )
+
+    def target_gradients(
+        self,
+        top_percent: float = 1.0,
+        min_delta: Optional[float] = None,
+        k_neighbors: int = 4,
+        only_coincident: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Find compounds with steep target gradients (data quality issues and activity cliffs).
 
-        return pd.DataFrame(results)
+        Uses a two-phase approach:
+        1. Quick filter using nearest neighbor gradient
+        2. Verify using k-neighbor median to handle cases where the nearest neighbor is the outlier
+
+        Args:
+            top_percent: Percentage of compounds with steepest gradients to return (e.g., 1.0 = top 1%)
+            min_delta: Minimum absolute target difference to consider. If None, defaults to target_range/100
+            k_neighbors: Number of neighbors to use for median calculation (default: 4)
+            only_coincident: If True, only consider compounds that are coincident (default: False)
+
+        Returns:
+            DataFrame of compounds with steepest gradients, sorted by gradient (descending)
+        """
+        if self.target is None:
+            raise ValueError("Target column must be specified")
+
+        epsilon = 1e-6
+
+        # Phase 1: Quick filter using precomputed nearest neighbor
+        candidates = self.df.copy()
+        candidates["gradient"] = candidates["nn_target_diff"] / (candidates["nn_distance"] + epsilon)
+
+        # Apply min_delta
+        if min_delta is None:
+            min_delta = self.target_range / 100.0 if self.target_range > 0 else 0.0
+        candidates = candidates[candidates["nn_target_diff"] >= min_delta]
+
+        # Filter based on mode
+        if only_coincident:
+            # Only keep coincident points (nn_distance ~= 0)
+            candidates = candidates[candidates["nn_distance"] < epsilon].copy()
+        else:
+            # Get top X% by initial gradient
+            percentile = 100 - top_percent
+            threshold = np.percentile(candidates["gradient"], percentile)
+            candidates = candidates[candidates["gradient"] >= threshold].copy()
+
+        # Phase 2: Verify with K-neighbor median to filter out cases where nearest neighbor is the outlier
+        results = []
+        for _, row in candidates.iterrows():
+            cmpd_id = row[self.id_column]
+            cmpd_target = row[self.target]
+
+            # Get K nearest neighbors (excluding self)
+            nbrs = self.neighbors(cmpd_id, n_neighbors=k_neighbors, include_self=False)
+
+            # Calculate median target of k neighbors, excluding the nearest neighbor (index 0)
+            neighbor_median = nbrs.iloc[1:k_neighbors][self.target].median()
+            median_diff = abs(cmpd_target - neighbor_median)
+
+            # Only keep if compound differs from neighborhood median
+            # This filters out cases where the nearest neighbor is the outlier
+            if median_diff >= min_delta:
+                results.append(
+                    {
+                        self.id_column: cmpd_id,
+                        self.target: cmpd_target,
+                        "nn_target": row["nn_target"],
+                        "nn_target_diff": row["nn_target_diff"],
+                        "nn_distance": row["nn_distance"],
+                        "gradient": row["gradient"],  # Keep Phase 1 gradient
+                        "neighbor_median": neighbor_median,
+                        "neighbor_median_diff": median_diff,
+                    }
+                )
+
+        # Handle empty results
+        if not results:
+            return pd.DataFrame(
+                columns=[
+                    self.id_column,
+                    self.target,
+                    "nn_target",
+                    "nn_target_diff",
+                    "nn_distance",
+                    "gradient",
+                    "neighbor_median",
+                    "neighbor_median_diff",
+                ]
+            )
+
+        results_df = pd.DataFrame(results)
+        results_df = results_df.sort_values("gradient", ascending=False).reset_index(drop=True)
+        return results_df
 
     def neighbors(
         self,
-        query_df: pd.DataFrame,
-        radius: float = None,
+        id_or_ids: Union[str, int, List[Union[str, int]]],
+        n_neighbors: Optional[int] = 5,
+        radius: Optional[float] = None,
         include_self: bool = True,
     ) -> pd.DataFrame:
         """
-        Return neighbors for rows in a query DataFrame.
+        Return neighbors for ID(s) from the existing dataset.
 
         Args:
-            query_df: DataFrame containing query points
+            id_or_ids: Single ID or list of IDs to look up
+            n_neighbors: Number of neighbors to return (default: 5, ignored if radius is set)
             radius: If provided, find all neighbors within this radius
-            include_self: Whether to include self in results (if present)
+            include_self: Whether to include self in results (default: True)
 
         Returns:
             DataFrame containing neighbors and distances
-
-        Note: The query DataFrame must include the feature columns. The id_column is optional.
         """
-        # Check if all required features are present
-        missing = set(self.features) - set(query_df.columns)
-        if missing:
-            raise ValueError(f"Query DataFrame is missing required feature columns: {missing}")
+        # Normalize to list
+        ids = [id_or_ids] if not isinstance(id_or_ids, list) else id_or_ids
 
-        # Check if id_column is present
-        id_column_present = self.id_column in query_df.columns
+        # Validate IDs exist
+        missing_ids = set(ids) - set(self.df[self.id_column])
+        if missing_ids:
+            raise ValueError(f"IDs not found in dataset: {missing_ids}")
 
-        # None of the features can be NaNs, so report rows with NaNs and then drop them
-        rows_with_nan = query_df[self.features].isna().any(axis=1)
+        # Filter to requested IDs and preserve order
+        query_df = self.df[self.df[self.id_column].isin(ids)]
+        query_df = query_df.set_index(self.id_column).loc[ids].reset_index()
 
-        # Print the ID column for rows with NaNs
-        if rows_with_nan.any():
-            log.warning(f"Found {rows_with_nan.sum()} rows with NaNs in feature columns:")
-            log.warning(query_df.loc[rows_with_nan, self.id_column])
+        # Transform query features (subclass-specific)
+        X_query = self._transform_features(query_df)
 
-        # Drop rows with NaNs in feature columns and reassign to query_df
-        query_df = query_df.dropna(subset=self.features)
-
-        # Transform the query features using the model's scaler
-        X_query = self.scaler.transform(query_df[self.features])
-
-        # Get neighbors using either radius or k-nearest neighbors
+        # Get neighbors
        if radius is not None:
             distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
         else:
-            distances, indices = self.nn.kneighbors(X_query)
+            distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)
 
         # Build results
-        all_results = []
+        results = []
         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-            # Use the ID from the query DataFrame if available, otherwise use the row index
-            query_id = query_df.iloc[i][self.id_column] if id_column_present else f"query_{i}"
+            query_id = query_df.iloc[i][self.id_column]
 
             for neighbor_idx, dist in zip(nbrs, dists):
-                # Skip if the neighbor is the query itself and include_self is False
                 neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
+
+                # Skip self if requested
                 if not include_self and neighbor_id == query_id:
                     continue
 
-                all_results.append(
-                    self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist)
-                )
+                results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
 
-        return pd.DataFrame(all_results)
+        df_results = pd.DataFrame(results)
+        df_results["is_self"] = df_results["neighbor_id"] == df_results[self.id_column]
+        df_results = df_results.sort_values([self.id_column, "is_self", "distance"], ascending=[True, False, True])
+        return df_results.drop("is_self", axis=1).reset_index(drop=True)
 
-    def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
+    def _precompute_metrics(self) -> None:
         """
-        Internal: Build a result dictionary for a single neighbor.
+        Precompute landscape metrics for all compounds.
 
-        Args:
-            query_id: ID of the query point
-            neighbor_idx: Index of the neighbor in the original DataFrame
-            distance: Distance between query and neighbor
+        Adds columns to self.df:
+        - nn_distance: Distance to nearest neighbor
+        - nn_id: ID of nearest neighbor
 
-        Returns:
-            Dictionary containing neighbor information
+        If target is specified, also adds:
+        - nn_target: Target value of nearest neighbor
+        - nn_target_diff: Absolute difference from nearest neighbor target
         """
-        neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
+        log.info("Precomputing proximity metrics...")
 
-        # Basic neighbor info
-        neighbor_info = {
-            self.id_column: query_id,
-            "neighbor_id": neighbor_id,
-            "distance": distance,
-        }
+        # Get nearest neighbors for all points (n=2 because index 0 is self)
+        X = self._transform_features(self.df)
+        distances, indices = self.nn.kneighbors(X, n_neighbors=2)
 
-        # Determine which additional columns to include
-        relevant_cols = [self.target, "prediction"] if self.target else []
-        relevant_cols += [c for c in self.df.columns if "_proba" in c or "residual" in c]
-        relevant_cols += ["outlier"]
+        # Extract nearest neighbor (index 1, since index 0 is self)
+        self.df["nn_distance"] = distances[:, 1]
+        self.df["nn_id"] = self.df.iloc[indices[:, 1]][self.id_column].values
 
-        # Add user-specified columns
-        relevant_cols += self.track_columns
+        # If target exists, compute target-based metrics
+        if self.target and self.target in self.df.columns:
+            # Get target values for nearest neighbor
+            nn_target_values = self.df.iloc[indices[:, 1]][self.target].values
+            self.df["nn_target"] = nn_target_values
+            self.df["nn_target_diff"] = np.abs(self.df[self.target].values - nn_target_values)
 
-        # Add values for each relevant column that exists in the dataframe
-        for col in filter(lambda c: c in self.df.columns, relevant_cols):
-            neighbor_info[col] = self.df.iloc[neighbor_idx][col]
+            # Precompute target range for min_delta default
+            self.target_range = self.df[self.target].max() - self.df[self.target].min()
 
-        return neighbor_info
+        log.info("Proximity metrics precomputed successfully")
 
-    def serialize(self, directory: str) -> None:
+    def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
         """
-        Serialize the Proximity model to a directory.
+        Build a result dictionary for a single neighbor.
 
         Args:
-            directory: Directory path to save the model components
-        """
-        # Create directory if it doesn't exist
-        os.makedirs(directory, exist_ok=True)
-
-        # Save metadata
-        metadata = {
-            "id_column": self.id_column,
-            "features": self.features,
-            "target": self.target,
-            "track_columns": self.track_columns,
-            "n_neighbors": self.n_neighbors,
-        }
-
-        with open(os.path.join(directory, "metadata.json"), "w") as f:
-            json.dump(metadata, f)
-
-        # Save the DataFrame
-        self.df.to_pickle(os.path.join(directory, "df.pkl"))
+            query_id: ID of the query point
+            neighbor_idx: Index of the neighbor in the original DataFrame
+            distance: Distance between query and neighbor
 
-        # Save the scaler and nearest neighbors model
-        with open(os.path.join(directory, "scaler.pkl"), "wb") as f:
-            pickle.dump(self.scaler, f)
+        Returns:
+            Dictionary containing neighbor information
+        """
+        neighbor_row = self.df.iloc[neighbor_idx]
+        neighbor_id = neighbor_row[self.id_column]
 
-        with open(os.path.join(directory, "nn_model.pkl"), "wb") as f:
-            pickle.dump(self.nn, f)
+        # Start with basic info
+        result = {
+            self.id_column: query_id,
+            "neighbor_id": neighbor_id,
+            "distance": 0.0 if distance < 1e-6 else distance,
+        }
 
-        log.info(f"Proximity model serialized to {directory}")
+        # Add target if present
+        if self.target and self.target in self.df.columns:
+            result[self.target] = neighbor_row[self.target]
 
-    @classmethod
-    def deserialize(cls, directory: str) -> "Proximity":
-        """
-        Deserialize a Proximity model from a directory.
+        # Add prediction/probability columns if they exist
+        for col in self.df.columns:
+            if col == "prediction" or "_proba" in col or "residual" in col or col == "in_model":
+                result[col] = neighbor_row[col]
 
-        Args:
-            directory: Directory path containing the serialized model components
+        # Include all columns if requested
+        if self.include_all_columns:
+            result.update(neighbor_row.to_dict())
+            # Restore query_id after update (neighbor_row may have overwritten id column)
+            result[self.id_column] = query_id
+            result["neighbor_id"] = neighbor_id
 
-        Returns:
-            Proximity: A new Proximity instance
-        """
-        directory_path = Path(directory)
-        if not directory_path.exists() or not directory_path.is_dir():
-            raise ValueError(f"Directory {directory} does not exist or is not a directory")
-
-        # Load metadata
-        with open(os.path.join(directory, "metadata.json"), "r") as f:
-            metadata = json.load(f)
-
-        # Load DataFrame
-        df_path = os.path.join(directory, "df.pkl")
-        if not os.path.exists(df_path):
-            raise FileNotFoundError(f"DataFrame file not found at {df_path}")
-        df = pd.read_pickle(df_path)
-
-        # Create instance but skip _prepare_data
-        instance = cls.__new__(cls)
-        instance.df = df
-        instance.id_column = metadata["id_column"]
-        instance.features = metadata["features"]
-        instance.target = metadata["target"]
-        instance.track_columns = metadata["track_columns"]
-        instance.n_neighbors = metadata["n_neighbors"]
-
-        # Load scaler and nn model
-        with open(os.path.join(directory, "scaler.pkl"), "rb") as f:
-            instance.scaler = pickle.load(f)
-
-        with open(os.path.join(directory, "nn_model.pkl"), "rb") as f:
-            instance.nn = pickle.load(f)
-
-        # Load X from scaler transform
-        instance.X = instance.scaler.transform(instance.df[instance.features])
-
-        log.info(f"Proximity model deserialized from {directory}")
-        return instance
-
-
-# Testing the Proximity class
-if __name__ == "__main__":
-
-    pd.set_option("display.max_columns", None)
-    pd.set_option("display.width", 1000)
-
-    # Create a sample DataFrame
-    data = {
-        "ID": [1, 2, 3, 4, 5],
-        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-        "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
-    }
-    df = pd.DataFrame(data)
-
-    # Test the Proximity class
-    features = ["Feature1", "Feature2", "Feature3"]
-    prox = Proximity(df, id_column="ID", features=features, n_neighbors=3)
-    print(prox.all_neighbors())
-
-    # Test the neighbors method
-    print(prox.neighbors(query_df=df.iloc[[0]]))
-
-    # Test the neighbors method with radius
-    print(prox.neighbors(query_df=df.iloc[0:2], radius=2.0))
-
-    # Test with data that isn't in the 'train' dataframe
-    query_data = {
-        "ID": [6],
-        "Feature1": [0.31],
-        "Feature2": [0.31],
-        "Feature3": [2.31],
-    }
-    query_df = pd.DataFrame(query_data)
-    print(prox.neighbors(query_df=query_df))
-
-    # Test with Features list
-    prox = Proximity(df, id_column="ID", features=["Feature1"], n_neighbors=2)
-    print(prox.all_neighbors())
-
-    # Create a sample DataFrame
-    data = {
-        "foo_id": ["a", "b", "c", "d", "e"],  # Testing string IDs
-        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-        "target": [1, 0, 1, 0, 5],
-    }
-    df = pd.DataFrame(data)
-
-    # Test with String Ids
-    prox = Proximity(
-        df,
-        id_column="foo_id",
-        features=["Feature1", "Feature2"],
-        target="target",
-        track_columns=["Feature1", "Feature2"],
-        n_neighbors=3,
-    )
-    print(prox.all_neighbors())
-
-    # Test the neighbors method
-    print(prox.neighbors(query_df=df.iloc[0:2]))
-
-    # Time neighbors with all IDs versus calling all_neighbors
-    import time
-
-    start_time = time.time()
-    prox_df = prox.neighbors(query_df=df, include_self=False)
-    end_time = time.time()
-    print(f"Time taken for neighbors: {end_time - start_time:.4f} seconds")
-    start_time = time.time()
-    prox_df_all = prox.all_neighbors()
-    end_time = time.time()
-    print(f"Time taken for all_neighbors: {end_time - start_time:.4f} seconds")
-
-    # Now compare the two dataframes
-    print("Neighbors DataFrame:")
-    print(prox_df)
-    print("\nAll Neighbors DataFrame:")
-    print(prox_df_all)
-    # Check for any discrepancies
-    if prox_df.equals(prox_df_all):
-        print("The two DataFrames are equal :)")
-    else:
-        print("ERROR: The two DataFrames are not equal!")
-
-    # Test querying without the id_column
-    df_no_id = df.drop(columns=["foo_id"])
-    print(prox.neighbors(query_df=df_no_id, include_self=False))
-
-    # Test duplicate IDs
-    data = {
-        "foo_id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
-        "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-        "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-        "target": [1, 0, 1, 0, 5],
-    }
-    df = pd.DataFrame(data)
-    prox = Proximity(df, id_column="foo_id", features=["Feature1", "Feature2"], target="target", n_neighbors=3)
-    print(df.equals(prox.df))
-
-    # Test with a categorical feature
-    from workbench.api import FeatureSet, Model
-
-    fs = FeatureSet("abalone_features")
-    model = Model("abalone-regression")
-    df = fs.pull_dataframe()
-    prox = Proximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
-    print(prox.neighbors(query_df=df[0:2]))
+        return result
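
For reviewers, the practical upshot of this rewrite: Proximity is now an abstract base class, so callers subclass it instead of configuring track_columns/n_neighbors; serialize/deserialize and the __main__ test harness are gone; neighbors() looks up rows by ID rather than taking a query DataFrame; and target_gradients() flags points where the gradient |Δtarget| / (nn_distance + ε) is steep, then double-checks against the k-neighbor median. The following is a minimal sketch of the subclass contract implied by the diff, not code from the package: EuclideanProximity is a hypothetical name (the shipped concrete subclasses presumably live in feature_space_proximity.py and fingerprint_proximity.py, per the file list above), and the StandardScaler/NearestNeighbors choices mirror the deleted build_proximity_model only for illustration.

import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

from workbench.algorithms.dataframe.proximity import Proximity


class EuclideanProximity(Proximity):
    """Hypothetical subclass: scaled Euclidean distance over numeric features."""

    def _build_model(self) -> None:
        # Contract from the diff: must leave a fitted NearestNeighbors on self.nn
        self.scaler = StandardScaler()
        X = self.scaler.fit_transform(self.df[self.features])
        self.nn = NearestNeighbors(n_neighbors=min(10, len(self.df))).fit(X)

    def _transform_features(self, df: pd.DataFrame) -> np.ndarray:
        # Contract: return the matrix used for nearest-neighbor queries
        return self.scaler.transform(df[self.features])

    def _project_2d(self) -> None:
        # Contract: add 'x'/'y' columns; a real subclass would use a 2D projection
        X = self._transform_features(self.df)
        self.df["x"], self.df["y"] = X[:, 0], X[:, 1]


df = pd.DataFrame(
    {
        "id": ["a", "b", "c", "d", "e"],
        "f1": [0.1, 0.2, 0.3, 0.4, 0.5],
        "f2": [0.5, 0.4, 0.3, 0.2, 0.1],
        "activity": [1.0, 0.9, 0.8, 0.2, 0.1],
    }
)
prox = EuclideanProximity(df, id_column="id", features=["f1", "f2"], target="activity")
print(prox.neighbors("a", n_neighbors=3))       # lookup is by ID now, not by feature rows
print(prox.isolated(top_percent=20.0))          # rows with the largest nn_distance
print(prox.target_gradients(top_percent=20.0))  # steep |Δtarget| / (distance + ε) pairs

Restricting neighbors() to known IDs is what lets the base class drop the old NaN reporting and scaler plumbing: the query rows come from self.df, which the subclass has already transformed once in _build_model.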