workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of workbench might be problematic.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -2,275 +2,307 @@ import pandas as pd
 import numpy as np
 from sklearn.preprocessing import StandardScaler
 from sklearn.neighbors import NearestNeighbors
-from typing import List, Dict
+from typing import List, Dict, Optional, Union
 import logging
-import pickle
-import os
-import json
-from pathlib import Path
-from enum import Enum

 # Set up logging
 log = logging.getLogger("workbench")


-# Enumerated Proximity Types (distance or similarity)
-class ProximityType(Enum):
-    DISTANCE = "distance"
-    SIMILARITY = "similarity"
-
-
 class Proximity:
     def __init__(
         self,
         df: pd.DataFrame,
         id_column: str,
         features: List[str],
-        target: str = None,
-        track_columns: List[str] = None,
-        n_neighbors: int = 10,
+        target: Optional[str] = None,
+        track_columns: Optional[List[str]] = None,
     ):
         """
         Initialize the Proximity class.

         Args:
-            df (pd.DataFrame): DataFrame containing data for neighbor computations.
-            id_column (str): Name of the column used as the identifier.
-            features (List[str]): List of feature column names to be used for neighbor computations.
-            target (str, optional): Name of the target column. Defaults to None.
-            track_columns (List[str], optional): Additional columns to track in results. Defaults to None.
-            n_neighbors (int): Number of neighbors to compute. Defaults to 10.
+            df: DataFrame containing data for neighbor computations.
+            id_column: Name of the column used as the identifier.
+            features: List of feature column names to be used for neighbor computations.
+            target: Name of the target column. Defaults to None.
+            track_columns: Additional columns to track in results. Defaults to None.
         """
-        self.df = df.dropna(subset=features).copy()
         self.id_column = id_column
-        self.n_neighbors = min(n_neighbors, len(self.df) - 1)
         self.target = target
-        self.features = features
-        self.scaler = None
-        self.X = None
-        self.nn = None
-        self.proximity_type = None
         self.track_columns = track_columns or []

-        # Right now we only support numeric features, so remove any columns that are not numeric
-        non_numeric_features = self.df[self.features].select_dtypes(exclude=["number"]).columns.tolist()
-        if non_numeric_features:
-            log.warning(f"Non-numeric features {non_numeric_features} aren't currently supported...")
-            self.features = [f for f in self.features if f not in non_numeric_features]
+        # Filter out non-numeric features
+        self.features = self._validate_features(df, features)
+
+        # Drop NaN rows and set up DataFrame
+        self.df = df.dropna(subset=self.features).copy()
+
+        # Compute target range if target is provided
+        self.target_range = None
+        if self.target and self.target in self.df.columns:
+            self.target_range = self.df[self.target].max() - self.df[self.target].min()

         # Build the proximity model
-        self.build_proximity_model()
+        self._build_model()

-    def build_proximity_model(self) -> None:
-        """Standardize features and fit Nearest Neighbors model.
-        Note: This method can be overridden in subclasses for custom behavior."""
-        self.proximity_type = ProximityType.DISTANCE
-        self.scaler = StandardScaler()
-        self.X = self.scaler.fit_transform(self.df[self.features])
-        self.nn = NearestNeighbors(n_neighbors=self.n_neighbors + 1).fit(self.X)
+        # Precompute landscape metrics
+        self._precompute_metrics()

-    def all_neighbors(self) -> pd.DataFrame:
+    def isolated(self, top_percent: float = 1.0) -> pd.DataFrame:
         """
-        Compute nearest neighbors for all rows in the dataset.
+        Find isolated data points based on distance to nearest neighbor.
+
+        Args:
+            top_percent: Percentage of most isolated data points to return (e.g., 1.0 returns top 1%)

         Returns:
-            pd.DataFrame: A DataFrame of neighbors and their distances.
+            DataFrame of observations above the percentile threshold, sorted by distance (descending)
         """
-        distances, indices = self.nn.kneighbors(self.X)
-        results = []
+        percentile = 100 - top_percent
+        threshold = np.percentile(self.df["nn_distance"], percentile)
+        isolated = self.df[self.df["nn_distance"] >= threshold].copy()
+        return isolated.sort_values("nn_distance", ascending=False).reset_index(drop=True)

-        for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-            query_id = self.df.iloc[i][self.id_column]
+    def target_gradients(
+        self,
+        top_percent: float = 1.0,
+        min_delta: Optional[float] = None,
+        k_neighbors: int = 4,
+        only_coincident: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Find compounds with steep target gradients (data quality issues and activity cliffs).

-            # Process neighbors
-            for neighbor_idx, dist in zip(nbrs, dists):
-                # Skip self (neighbor index == current row index)
-                if neighbor_idx == i:
-                    continue
-                results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
+        Uses a two-phase approach:
+        1. Quick filter using nearest neighbor gradient
+        2. Verify using k-neighbor median to handle cases where the nearest neighbor is the outlier
+
+        Args:
+            top_percent: Percentage of compounds with steepest gradients to return (e.g., 1.0 = top 1%)
+            min_delta: Minimum absolute target difference to consider. If None, defaults to target_range/100
+            k_neighbors: Number of neighbors to use for median calculation (default: 4)
+            only_coincident: If True, only consider compounds that are coincident (default: False)
+
+        Returns:
+            DataFrame of compounds with steepest gradients, sorted by gradient (descending)
+        """
+        if self.target is None:
+            raise ValueError("Target column must be specified")
+
+        epsilon = 1e-5
+
+        # Phase 1: Quick filter using precomputed nearest neighbor
+        candidates = self.df.copy()
+        candidates["gradient"] = candidates["nn_target_diff"] / (candidates["nn_distance"] + epsilon)
+
+        # Apply min_delta
+        if min_delta is None:
+            min_delta = self.target_range / 100.0 if self.target_range > 0 else 0.0
+        candidates = candidates[candidates["nn_target_diff"] >= min_delta]
+
+        # Filter based on mode
+        if only_coincident:
+            # Only keep coincident points (nn_distance ~= 0)
+            candidates = candidates[candidates["nn_distance"] < epsilon].copy()
+        else:
+            # Get top X% by initial gradient
+            percentile = 100 - top_percent
+            threshold = np.percentile(candidates["gradient"], percentile)
+            candidates = candidates[candidates["gradient"] >= threshold].copy()
+
+        # Phase 2: Verify with k-neighbor median to filter out cases where nearest neighbor is the outlier
+        results = []
+        for _, row in candidates.iterrows():
+            cmpd_id = row[self.id_column]
+            cmpd_target = row[self.target]
+
+            # Get k nearest neighbors (excluding self)
+            nbrs = self.neighbors(cmpd_id, n_neighbors=k_neighbors, include_self=False)
+
+            # Calculate median target of k neighbors, excluding the nearest neighbor (index 0)
+            neighbor_median = nbrs.iloc[1:k_neighbors][self.target].median()
+            median_diff = abs(cmpd_target - neighbor_median)
+
+            # Only keep if compound differs from neighborhood median
+            # This filters out cases where the nearest neighbor is the outlier
+            if median_diff >= min_delta:
+                results.append(
+                    {
+                        self.id_column: cmpd_id,
+                        self.target: cmpd_target,
+                        "nn_target": row["nn_target"],
+                        "nn_target_diff": row["nn_target_diff"],
+                        "nn_distance": row["nn_distance"],
+                        "gradient": row["gradient"],  # Keep Phase 1 gradient
+                        "neighbor_median": neighbor_median,
+                        "neighbor_median_diff": median_diff,
+                    }
+                )

-        return pd.DataFrame(results)
+        # Handle empty results
+        if not results:
+            return pd.DataFrame(
+                columns=[
+                    self.id_column,
+                    self.target,
+                    "neighbor_median",
+                    "neighbor_median_diff",
+                    "mean_distance",
+                    "gradient",
+                ]
+            )
+
+        results_df = pd.DataFrame(results)
+        results_df = results_df.sort_values("gradient", ascending=False).reset_index(drop=True)
+        return results_df

     def neighbors(
         self,
-        query_df: pd.DataFrame,
-        radius: float = None,
+        id_or_ids: Union[str, int, List[Union[str, int]]],
+        n_neighbors: Optional[int] = 5,
+        radius: Optional[float] = None,
         include_self: bool = True,
     ) -> pd.DataFrame:
         """
-        Return neighbors for rows in a query DataFrame.
+        Return neighbors for ID(s) from the existing dataset.

         Args:
-            query_df: DataFrame containing query points
+            id_or_ids: Single ID or list of IDs to look up
+            n_neighbors: Number of neighbors to return (default: 5, ignored if radius is set)
             radius: If provided, find all neighbors within this radius
-            include_self: Whether to include self in results (if present)
+            include_self: Whether to include self in results (default: True)

         Returns:
             DataFrame containing neighbors and distances
-
-        Note: The query DataFrame must include the feature columns. The id_column is optional.
         """
-        # Check if all required features are present
-        missing = set(self.features) - set(query_df.columns)
-        if missing:
-            raise ValueError(f"Query DataFrame is missing required feature columns: {missing}")
-
-        # Check if id_column is present
-        id_column_present = self.id_column in query_df.columns
+        # Normalize to list
+        ids = [id_or_ids] if not isinstance(id_or_ids, list) else id_or_ids

-        # None of the features can be NaNs, so report rows with NaNs and then drop them
-        rows_with_nan = query_df[self.features].isna().any(axis=1)
+        # Validate IDs exist
+        missing_ids = set(ids) - set(self.df[self.id_column])
+        if missing_ids:
+            raise ValueError(f"IDs not found in dataset: {missing_ids}")

-        # Print the ID column for rows with NaNs
-        if rows_with_nan.any():
-            log.warning(f"Found {rows_with_nan.sum()} rows with NaNs in feature columns:")
-            log.warning(query_df.loc[rows_with_nan, self.id_column])
+        # Filter to requested IDs and preserve order
+        query_df = self.df[self.df[self.id_column].isin(ids)]
+        query_df = query_df.set_index(self.id_column).loc[ids].reset_index()

-        # Drop rows with NaNs in feature columns and reassign to query_df
-        query_df = query_df.dropna(subset=self.features)
-
-        # Transform the query features using the model's scaler
+        # Transform query features
         X_query = self.scaler.transform(query_df[self.features])

-        # Get neighbors using either radius or k-nearest neighbors
+        # Get neighbors
         if radius is not None:
             distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
         else:
-            distances, indices = self.nn.kneighbors(X_query)
+            distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)

         # Build results
-        all_results = []
+        results = []
         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-            # Use the ID from the query DataFrame if available, otherwise use the row index
-            query_id = query_df.iloc[i][self.id_column] if id_column_present else f"query_{i}"
+            query_id = query_df.iloc[i][self.id_column]

             for neighbor_idx, dist in zip(nbrs, dists):
-                # Skip if the neighbor is the query itself and include_self is False
                 neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
+
+                # Skip self if requested
                 if not include_self and neighbor_id == query_id:
                     continue

-                all_results.append(
-                    self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist)
-                )
+                results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))

-        return pd.DataFrame(all_results)
+        df_results = pd.DataFrame(results)
+        df_results["is_self"] = df_results["neighbor_id"] == df_results[self.id_column]
+        df_results = df_results.sort_values([self.id_column, "is_self", "distance"], ascending=[True, False, True])
+        return df_results.drop("is_self", axis=1).reset_index(drop=True)

-    def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
-        """
-        Internal: Build a result dictionary for a single neighbor.
+    def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+        """Remove non-numeric features and log warnings."""
+        non_numeric = [f for f in features if f not in df.select_dtypes(include=["number"]).columns]
+        if non_numeric:
+            log.warning(f"Non-numeric features {non_numeric} aren't currently supported, excluding them")
+        return [f for f in features if f not in non_numeric]

-        Args:
-            query_id: ID of the query point
-            neighbor_idx: Index of the neighbor in the original DataFrame
-            distance: Distance between query and neighbor
+    def _build_model(self) -> None:
+        """Standardize features and fit Nearest Neighbors model."""
+        self.scaler = StandardScaler()
+        X = self.scaler.fit_transform(self.df[self.features])
+        self.nn = NearestNeighbors().fit(X)

-        Returns:
-            Dictionary containing neighbor information
+    def _precompute_metrics(self, n_neighbors: int = 10) -> None:
         """
-        neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
-
-        # Basic neighbor info
-        neighbor_info = {
-            self.id_column: query_id,
-            "neighbor_id": neighbor_id,
-            "distance": distance,
-        }
-
-        # Determine which additional columns to include
-        relevant_cols = [self.target, "prediction"] if self.target else []
-        relevant_cols += [c for c in self.df.columns if "_proba" in c or "residual" in c]
-        relevant_cols += ["outlier"]
-
-        # Add user-specified columns
-        relevant_cols += self.track_columns
+        Precompute landscape metrics for all compounds.

-        # Add values for each relevant column that exists in the dataframe
-        for col in filter(lambda c: c in self.df.columns, relevant_cols):
-            neighbor_info[col] = self.df.iloc[neighbor_idx][col]
+        Adds columns to self.df:
+        - nn_distance: Distance to nearest neighbor
+        - nn_id: ID of nearest neighbor

-        return neighbor_info
-
-    def serialize(self, directory: str) -> None:
+        If target is specified, also adds:
+        - nn_target: Target value of nearest neighbor
+        - nn_target_diff: Absolute difference from nearest neighbor target
         """
-        Serialize the Proximity model to a directory.
+        log.info("Precomputing proximity metrics...")

-        Args:
-            directory: Directory path to save the model components
-        """
-        # Create directory if it doesn't exist
-        os.makedirs(directory, exist_ok=True)
-
-        # Save metadata
-        metadata = {
-            "id_column": self.id_column,
-            "features": self.features,
-            "target": self.target,
-            "track_columns": self.track_columns,
-            "n_neighbors": self.n_neighbors,
-        }
+        # Make sure n_neighbors isn't greater than dataset size
+        n_neighbors = min(n_neighbors, len(self.df) - 1)

-        with open(os.path.join(directory, "metadata.json"), "w") as f:
-            json.dump(metadata, f)
+        # Get nearest neighbors for all points (including self)
+        X = self.scaler.transform(self.df[self.features])
+        distances, indices = self.nn.kneighbors(X, n_neighbors=2)  # Just need nearest neighbor

-        # Save the DataFrame
-        self.df.to_pickle(os.path.join(directory, "df.pkl"))
+        # Extract nearest neighbor (index 1, since index 0 is self)
+        self.df["nn_distance"] = distances[:, 1]
+        self.df["nn_id"] = self.df.iloc[indices[:, 1]][self.id_column].values

-        # Save the scaler and nearest neighbors model
-        with open(os.path.join(directory, "scaler.pkl"), "wb") as f:
-            pickle.dump(self.scaler, f)
+        # If target exists, compute target-based metrics
+        if self.target and self.target in self.df.columns:
+            # Get target values for nearest neighbor
+            nn_target_values = self.df.iloc[indices[:, 1]][self.target].values
+            self.df["nn_target"] = nn_target_values
+            self.df["nn_target_diff"] = np.abs(self.df[self.target].values - nn_target_values)

-        with open(os.path.join(directory, "nn_model.pkl"), "wb") as f:
-            pickle.dump(self.nn, f)
+            # Precompute target range for min_delta default
+            self.target_range = self.df[self.target].max() - self.df[self.target].min()

-        log.info(f"Proximity model serialized to {directory}")
+        log.info("Proximity metrics precomputed successfully")

-    @classmethod
-    def deserialize(cls, directory: str) -> "Proximity":
+    def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
         """
-        Deserialize a Proximity model from a directory.
+        Build a result dictionary for a single neighbor.

         Args:
-            directory: Directory path containing the serialized model components
+            query_id: ID of the query point
+            neighbor_idx: Index of the neighbor in the original DataFrame
+            distance: Distance between query and neighbor

         Returns:
-            Proximity: A new Proximity instance
+            Dictionary containing neighbor information
         """
-        directory_path = Path(directory)
-        if not directory_path.exists() or not directory_path.is_dir():
-            raise ValueError(f"Directory {directory} does not exist or is not a directory")
+        neighbor_row = self.df.iloc[neighbor_idx]
+        neighbor_id = neighbor_row[self.id_column]

-        # Load metadata
-        with open(os.path.join(directory, "metadata.json"), "r") as f:
-            metadata = json.load(f)
-
-        # Load DataFrame
-        df_path = os.path.join(directory, "df.pkl")
-        if not os.path.exists(df_path):
-            raise FileNotFoundError(f"DataFrame file not found at {df_path}")
-        df = pd.read_pickle(df_path)
-
-        # Create instance but skip _prepare_data
-        instance = cls.__new__(cls)
-        instance.df = df
-        instance.id_column = metadata["id_column"]
-        instance.features = metadata["features"]
-        instance.target = metadata["target"]
-        instance.track_columns = metadata["track_columns"]
-        instance.n_neighbors = metadata["n_neighbors"]
+        # Start with basic info
+        result = {
+            self.id_column: query_id,
+            "neighbor_id": neighbor_id,
+            "distance": 0.0 if distance < 1e-5 else distance,
+        }

-        # Load scaler and nn model
-        with open(os.path.join(directory, "scaler.pkl"), "rb") as f:
-            instance.scaler = pickle.load(f)
+        # Add target if present
+        if self.target and self.target in self.df.columns:
+            result[self.target] = neighbor_row[self.target]

-        with open(os.path.join(directory, "nn_model.pkl"), "rb") as f:
-            instance.nn = pickle.load(f)
+        # Add tracked columns
+        for col in self.track_columns:
+            if col in self.df.columns:
+                result[col] = neighbor_row[col]

-        # Load X from scaler transform
-        instance.X = instance.scaler.transform(instance.df[instance.features])
+        # Add prediction/probability columns if they exist
+        for col in self.df.columns:
+            if col == "prediction" or "_proba" in col or "residual" in col or col == "in_model":
+                result[col] = neighbor_row[col]

-        log.info(f"Proximity model deserialized from {directory}")
-        return instance
+        return result


 # Testing the Proximity class
@@ -290,28 +322,15 @@ if __name__ == "__main__":

     # Test the Proximity class
     features = ["Feature1", "Feature2", "Feature3"]
-    prox = Proximity(df, id_column="ID", features=features, n_neighbors=3)
-    print(prox.all_neighbors())
-
-    # Test the neighbors method
-    print(prox.neighbors(query_df=df.iloc[[0]]))
+    prox = Proximity(df, id_column="ID", features=features)
+    print(prox.neighbors(1, n_neighbors=2))

     # Test the neighbors method with radius
-    print(prox.neighbors(query_df=df.iloc[0:2], radius=2.0))
-
-    # Test with data that isn't in the 'train' dataframe
-    query_data = {
-        "ID": [6],
-        "Feature1": [0.31],
-        "Feature2": [0.31],
-        "Feature3": [2.31],
-    }
-    query_df = pd.DataFrame(query_data)
-    print(prox.neighbors(query_df=query_df))
+    print(prox.neighbors(1, radius=2.0))

     # Test with Features list
-    prox = Proximity(df, id_column="ID", features=["Feature1"], n_neighbors=2)
-    print(prox.all_neighbors())
+    prox = Proximity(df, id_column="ID", features=["Feature1"])
+    print(prox.neighbors(1))

     # Create a sample DataFrame
     data = {
@@ -329,39 +348,8 @@ if __name__ == "__main__":
         features=["Feature1", "Feature2"],
         target="target",
         track_columns=["Feature1", "Feature2"],
-        n_neighbors=3,
     )
-    print(prox.all_neighbors())
-
-    # Test the neighbors method
-    print(prox.neighbors(query_df=df.iloc[0:2]))
-
-    # Time neighbors with all IDs versus calling all_neighbors
-    import time
-
-    start_time = time.time()
-    prox_df = prox.neighbors(query_df=df, include_self=False)
-    end_time = time.time()
-    print(f"Time taken for neighbors: {end_time - start_time:.4f} seconds")
-    start_time = time.time()
-    prox_df_all = prox.all_neighbors()
-    end_time = time.time()
-    print(f"Time taken for all_neighbors: {end_time - start_time:.4f} seconds")
-
-    # Now compare the two dataframes
-    print("Neighbors DataFrame:")
-    print(prox_df)
-    print("\nAll Neighbors DataFrame:")
-    print(prox_df_all)
-    # Check for any discrepancies
-    if prox_df.equals(prox_df_all):
-        print("The two DataFrames are equal :)")
-    else:
-        print("ERROR: The two DataFrames are not equal!")
-
-    # Test querying without the id_column
-    df_no_id = df.drop(columns=["foo_id"])
-    print(prox.neighbors(query_df=df_no_id, include_self=False))
+    print(prox.neighbors(["a", "b"]))

     # Test duplicate IDs
     data = {
@@ -371,14 +359,52 @@ if __name__ == "__main__":
         "target": [1, 0, 1, 0, 5],
     }
     df = pd.DataFrame(data)
-    prox = Proximity(df, id_column="foo_id", features=["Feature1", "Feature2"], target="target", n_neighbors=3)
+    prox = Proximity(df, id_column="foo_id", features=["Feature1", "Feature2"], target="target")
     print(df.equals(prox.df))

     # Test with a categorical feature
     from workbench.api import FeatureSet, Model

-    fs = FeatureSet("abalone_features")
-    model = Model("abalone-regression")
+    fs = FeatureSet("aqsol_features")
+    model = Model("aqsol-regression")
+    features = model.features()
     df = fs.pull_dataframe()
-    prox = Proximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
-    print(prox.neighbors(query_df=df[0:2]))
+    prox = Proximity(
+        df, id_column=fs.id_column, features=model.features(), target=model.target(), track_columns=features
+    )
+    print(prox.neighbors(df[fs.id_column].tolist()[:3]))
+
+    print("\n" + "=" * 80)
+    print("Testing isolated_compounds...")
+    print("=" * 80)
+
+    # Test isolated data in the top 1%
+    isolated_1pct = prox.isolated(top_percent=1.0)
+    print(f"\nTop 1% most isolated compounds (n={len(isolated_1pct)}):")
+    print(isolated_1pct[[fs.id_column, "nn_distance", "nn_id"]].head(10))
+
+    # Test isolated data in the top 5%
+    isolated_5pct = prox.isolated(top_percent=5.0)
+    print(f"\nTop 5% most isolated compounds (n={len(isolated_5pct)}):")
+    print(isolated_5pct[[fs.id_column, "nn_distance", "nn_id"]].head(10))
+
+    print("\n" + "=" * 80)
+    print("Testing target_gradients...")
+    print("=" * 80)
+
+    # Test with different parameters
+    gradients_1pct = prox.target_gradients(top_percent=1.0, min_delta=1.0)
+    print(f"\nTop 1% target gradients (min_delta=5.0) (n={len(gradients_1pct)}):")
+    print(
+        gradients_1pct[
+            [fs.id_column, model.target(), "neighbor_median", "neighbor_median_diff", "mean_distance", "gradient"]
+        ].head(10)
+    )
+
+    gradients_5pct = prox.target_gradients(top_percent=5.0, min_delta=5.0)
+    print(f"\nTop 5% target gradients (min_delta=5.0) (n={len(gradients_5pct)}):")
+    print(
+        gradients_5pct[
+            [fs.id_column, model.target(), "neighbor_median", "neighbor_median_diff", "mean_distance", "gradient"]
+        ].head(10)
+    )
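
The hunks above rework the Proximity class: neighbors() now looks up IDs in the fitted dataset instead of accepting a query DataFrame, pickle-based serialize/deserialize is removed, and the new isolated() and target_gradients() helpers operate on the precomputed nn_distance / nn_target columns. A minimal usage sketch of that new API follows; the import path and the toy data (id, feat1, feat2, activity) are illustrative assumptions, not content taken from the diff.

# Sketch only: assumed import path and synthetic data for illustration
import pandas as pd
from workbench.algorithms.dataframe.proximity import Proximity  # assumed location of the class above

df = pd.DataFrame(
    {
        "id": ["a", "b", "c", "d", "e"],
        "feat1": [0.10, 0.20, 0.25, 3.00, 3.10],
        "feat2": [1.00, 1.10, 1.05, 5.00, 5.20],
        "activity": [2.0, 2.1, 9.0, 4.0, 4.1],
    }
)
prox = Proximity(df, id_column="id", features=["feat1", "feat2"], target="activity")

# Neighbors are looked up by ID (single value or list) rather than a query DataFrame
print(prox.neighbors("a", n_neighbors=3))
print(prox.neighbors(["a", "d"], radius=2.0))

# New helpers built on the precomputed nn_distance / nn_target_diff columns
print(prox.isolated(top_percent=20.0))          # most isolated points
print(prox.target_gradients(top_percent=20.0))  # steep target gradients / activity cliffs

With only numeric features and a target column supplied, the constructor standardizes the features, fits a NearestNeighbors model, and precomputes the nearest-neighbor metrics the two helpers rely on.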
@@ -14,7 +14,7 @@ import pandas as pd
 TEMPLATE_PARAMS = {
     "features": "{{feature_list}}",
     "target": "{{target_column}}",
-    "train_all_data": "{{train_all_data}}"
+    "train_all_data": "{{train_all_data}}",
 }


@@ -37,7 +37,7 @@ def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> p
     """
     Matches and renames DataFrame columns to match model feature names (case-insensitive).
     Prioritizes exact matches, then case-insensitive matches.
-
+
     Raises ValueError if any model features cannot be matched.
     """
     df_columns_lower = {col.lower(): col for col in df.columns}
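
The hunk above only strips trailing whitespace inside match_features_case_insensitive; since the rest of that function is not shown in this diff, here is a hedged sketch of the behavior its docstring describes (exact matches take priority, then case-insensitive renames, with a ValueError for anything unmatched). The helper below is an illustration, not the shipped template code.

# Sketch only: reimplements the documented behavior, not copied from the template
import pandas as pd

def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
    """Illustrative: rename df columns to the model's feature names, case-insensitively."""
    df_columns_lower = {col.lower(): col for col in df.columns}
    rename_map = {}
    missing = []
    for feature in model_features:
        if feature in df.columns:
            continue  # exact match takes priority, no rename needed
        if feature.lower() in df_columns_lower:
            rename_map[df_columns_lower[feature.lower()]] = feature  # case-insensitive match
        else:
            missing.append(feature)
    if missing:
        raise ValueError(f"Features not found in DataFrame: {missing}")
    return df.rename(columns=rename_map)

# Example: the model expects "LogP" but the incoming CSV delivered "logp"
df = pd.DataFrame({"logp": [1.2, 3.4], "MolWt": [180.2, 250.1]})
print(match_features_case_insensitive(df, ["LogP", "MolWt"]).columns.tolist())  # ['LogP', 'MolWt']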
@@ -81,10 +81,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])

     # Check if the DataFrame is empty
@@ -109,8 +106,10 @@ if __name__ == "__main__":
     # Create and train the Regression/Confidence model
     # model = BayesianRidge()
     model = BayesianRidge(
-        alpha_1=1e-6, alpha_2=1e-6,  # Noise precision
-        lambda_1=1e-6, lambda_2=1e-6,  # Weight precision
+        alpha_1=1e-6,
+        alpha_2=1e-6,  # Noise precision
+        lambda_1=1e-6,
+        lambda_2=1e-6,  # Weight precision
         fit_intercept=True,
     )
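
The final hunk only reflows the BayesianRidge constructor in the UQ template. For context, a short sketch (standard scikit-learn API, synthetic data, not template code) of why a Bayesian ridge model suits uncertainty quantification: predict(..., return_std=True) returns a per-row standard deviation alongside the point prediction.

# Sketch only: synthetic data; the priors mirror the constructor shown in the hunk above
import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.3, size=200)

model = BayesianRidge(
    alpha_1=1e-6,
    alpha_2=1e-6,  # noise precision priors
    lambda_1=1e-6,
    lambda_2=1e-6,  # weight precision priors
    fit_intercept=True,
)
model.fit(X, y)

# return_std=True yields a per-row predictive standard deviation alongside the mean
y_mean, y_std = model.predict(X[:5], return_std=True)
print(y_mean, y_std)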