workbench 0.8.162-py3-none-any.whl → 0.8.220-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
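For readers who want to reproduce a comparison like this locally, here is a minimal sketch (not part of the diff below). It assumes pip and GNU diff are available on the PATH and that the wheels can be fetched from PyPI under the name "workbench"; the helper fetch_and_unpack is illustrative, not a workbench API.

import pathlib
import subprocess
import tempfile


def fetch_and_unpack(version: str, work_dir: pathlib.Path) -> pathlib.Path:
    """Download the wheel for one version and unpack it into work_dir/<version>."""
    subprocess.run(
        ["pip", "download", f"workbench=={version}", "--no-deps", "--dest", str(work_dir)],
        check=True,
    )
    wheel = next(work_dir.glob(f"workbench-{version}-*.whl"))
    target = work_dir / version
    # A .whl is a zip archive, so the stdlib zipfile CLI can extract it
    subprocess.run(["python", "-m", "zipfile", "-e", str(wheel), str(target)], check=True)
    return target


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        work_dir = pathlib.Path(tmp)
        old = fetch_and_unpack("0.8.162", work_dir)
        new = fetch_and_unpack("0.8.220", work_dir)
        # Recursive unified diff of the two unpacked wheels
        subprocess.run(["diff", "-ruN", str(old), str(new)])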

Potentially problematic release.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,273 +0,0 @@
- # Model: NGBoost Regressor with Distribution output
- from ngboost import NGBRegressor
- from xgboost import XGBRegressor  # Base Estimator
- from sklearn.model_selection import train_test_split
-
- # Model Performance Scores
- from sklearn.metrics import (
-     mean_absolute_error,
-     r2_score,
-     root_mean_squared_error
- )
-
- from io import StringIO
- import json
- import argparse
- import joblib
- import os
- import pandas as pd
-
- # Local Imports
- from proximity import Proximity
-
-
-
- # Template Placeholders
- TEMPLATE_PARAMS = {
-     "id_column": "{{id_column}}",
-     "features": "{{feature_list}}",
-     "target": "{{target_column}}",
-     "train_all_data": "{{train_all_data}}",
-     "track_columns": "{{track_columns}}"
- }
-
-
- # Function to check if dataframe is empty
- def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
-     """
-     Check if the provided dataframe is empty and raise an exception if it is.
-
-     Args:
-         df (pd.DataFrame): DataFrame to check
-         df_name (str): Name of the DataFrame
-     """
-     if df.empty:
-         msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
-         print(msg)
-         raise ValueError(msg)
-
-
- def match_features_case_insensitive(df: pd.DataFrame, model_features: list) -> pd.DataFrame:
-     """
-     Matches and renames DataFrame columns to match model feature names (case-insensitive).
-     Prioritizes exact matches, then case-insensitive matches.
-
-     Raises ValueError if any model features cannot be matched.
-     """
-     df_columns_lower = {col.lower(): col for col in df.columns}
-     rename_dict = {}
-     missing = []
-     for feature in model_features:
-         if feature in df.columns:
-             continue  # Exact match
-         elif feature.lower() in df_columns_lower:
-             rename_dict[df_columns_lower[feature.lower()]] = feature
-         else:
-             missing.append(feature)
-
-     if missing:
-         raise ValueError(f"Features not found: {missing}")
-
-     # Rename the DataFrame columns to match the model features
-     return df.rename(columns=rename_dict)
-
-
- # TRAINING SECTION
- #
- # This section (__main__) is where SageMaker will execute the training job
- # and save the model artifacts to the model directory.
- #
- if __name__ == "__main__":
-     # Template Parameters
-     id_column = TEMPLATE_PARAMS["id_column"]
-     features = TEMPLATE_PARAMS["features"]
-     target = TEMPLATE_PARAMS["target"]
-     train_all_data = TEMPLATE_PARAMS["train_all_data"]
-     track_columns = TEMPLATE_PARAMS["track_columns"]  # Can be None
-     validation_split = 0.2
-
-     # Script arguments for input/output directories
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
-     parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
-     parser.add_argument(
-         "--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/opt/ml/output/data")
-     )
-     args = parser.parse_args()
-
-     # Load training data from the specified directory
-     training_files = [
-         os.path.join(args.train, file)
-         for file in os.listdir(args.train) if file.endswith(".csv")
-     ]
-     print(f"Training Files: {training_files}")
-
-     # Combine files and read them all into a single pandas dataframe
-     df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
-
-     # Check if the DataFrame is empty
-     check_dataframe(df, "training_df")
-
-     # Training data split logic
-     if train_all_data:
-         # Use all data for both training and validation
-         print("Training on all data...")
-         df_train = df.copy()
-         df_val = df.copy()
-     elif "training" in df.columns:
-         # Split data based on a 'training' column if it exists
-         print("Splitting data based on 'training' column...")
-         df_train = df[df["training"]].copy()
-         df_val = df[~df["training"]].copy()
-     else:
-         # Perform a random split if no 'training' column is found
-         print("Splitting data randomly...")
-         df_train, df_val = train_test_split(df, test_size=validation_split, random_state=42)
-
-     # We're using XGBoost for point predictions and NGBoost for uncertainty quantification
-     xgb_model = XGBRegressor()
-     ngb_model = NGBRegressor()
-
-     # Prepare features and targets for training
-     X_train = df_train[features]
-     X_val = df_val[features]
-     y_train = df_train[target]
-     y_val = df_val[target]
-
-     # Train both models using the training data
-     xgb_model.fit(X_train, y_train)
-     ngb_model.fit(X_train, y_train, X_val=X_val, Y_val=y_val)
-
-     # Make Predictions on the Validation Set
-     print("Making Predictions on Validation Set...")
-     y_validate = df_val[target]
-     X_validate = df_val[features]
-     preds = xgb_model.predict(X_validate)
-
-     # Calculate various model performance metrics (regression)
-     rmse = root_mean_squared_error(y_validate, preds)
-     mae = mean_absolute_error(y_validate, preds)
-     r2 = r2_score(y_validate, preds)
-     print(f"RMSE: {rmse:.3f}")
-     print(f"MAE: {mae:.3f}")
-     print(f"R2: {r2:.3f}")
-     print(f"NumRows: {len(df_val)}")
-
-     # Save the trained XGBoost model
-     xgb_model.save_model(os.path.join(args.model_dir, "xgb_model.json"))
-
-     # Save the trained NGBoost model
-     joblib.dump(ngb_model, os.path.join(args.model_dir, "ngb_model.joblib"))
-
-     # Save the feature list to validate input during predictions
-     with open(os.path.join(args.model_dir, "feature_columns.json"), "w") as fp:
-         json.dump(features, fp)
-
-     # Now the Proximity model
-     model = Proximity(df_train, id_column, features, target, track_columns=track_columns)
-
-     # Now serialize the model
-     model.serialize(args.model_dir)
-
-
- #
- # Inference Section
- #
- def model_fn(model_dir) -> dict:
-     """Load and return XGBoost, NGBoost, and Prox Model from model directory."""
-
-     # Load XGBoost regressor
-     xgb_path = os.path.join(model_dir, "xgb_model.json")
-     xgb_model = XGBRegressor(enable_categorical=True)
-     xgb_model.load_model(xgb_path)
-
-     # Load NGBoost regressor
-     ngb_model = joblib.load(os.path.join(model_dir, "ngb_model.joblib"))
-
-     # Deserialize the proximity model
-     prox_model = Proximity.deserialize(model_dir)
-
-     return {
-         "xgboost": xgb_model,
-         "ngboost": ngb_model,
-         "proximity": prox_model
-     }
-
-
- def input_fn(input_data, content_type):
-     """Parse input data and return a DataFrame."""
-     if not input_data:
-         raise ValueError("Empty input data is not supported!")
-
-     # Decode bytes to string if necessary
-     if isinstance(input_data, bytes):
-         input_data = input_data.decode("utf-8")
-
-     if "text/csv" in content_type:
-         return pd.read_csv(StringIO(input_data))
-     elif "application/json" in content_type:
-         return pd.DataFrame(json.loads(input_data))  # Assumes JSON array of records
-     else:
-         raise ValueError(f"{content_type} not supported!")
-
-
- def output_fn(output_df, accept_type):
-     """Supports both CSV and JSON output formats."""
-     if "text/csv" in accept_type:
-         csv_output = output_df.fillna("N/A").to_csv(index=False)  # CSV with N/A for missing values
-         return csv_output, "text/csv"
-     elif "application/json" in accept_type:
-         return output_df.to_json(orient="records"), "application/json"  # JSON array of records (NaNs -> null)
-     else:
-         raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
-
-
- def predict_fn(df, models) -> pd.DataFrame:
- """Make Predictions with our XGB Quantile Regression Model
227
-
-     Args:
-         df (pd.DataFrame): The input DataFrame
-         models (dict): The dictionary of models to use for predictions
-
-     Returns:
-         pd.DataFrame: The DataFrame with the predictions added
-     """
-
-     # Grab our feature columns (from training)
-     model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-     with open(os.path.join(model_dir, "feature_columns.json")) as fp:
-         model_features = json.load(fp)
-
-     # Match features in a case-insensitive manner
-     matched_df = match_features_case_insensitive(df, model_features)
-
-     # Use XGBoost for point predictions
-     df["prediction"] = models["xgboost"].predict(matched_df[model_features])
-
-     # NGBoost predict returns distribution objects
-     y_dists = models["ngboost"].pred_dist(matched_df[model_features])
-
-     # Extract parameters from distribution
-     dist_params = y_dists.params
-
-     # Extract mean and std from distribution parameters
-     df["prediction_uq"] = dist_params['loc']  # mean
-     df["prediction_std"] = dist_params['scale']  # standard deviation
-
-     # Add 95% prediction intervals using ppf (percent point function)
-     df["q_025"] = y_dists.ppf(0.025)  # 2.5th percentile
-     df["q_975"] = y_dists.ppf(0.975)  # 97.5th percentile
-
-     # Add 50% prediction intervals
-     df["q_25"] = y_dists.ppf(0.25)  # 25th percentile
-     df["q_75"] = y_dists.ppf(0.75)  # 75th percentile
-
-     # Adjust prediction intervals to include point predictions
-     df["q_025"] = df[["q_025", "prediction"]].min(axis=1)
-     df["q_975"] = df[["q_975", "prediction"]].max(axis=1)
-
-     # Compute Nearest neighbors with Proximity model
-     models["proximity"].neighbors(df)
-
-     # Return the modified DataFrame
-     return df
@@ -1,384 +0,0 @@
- import pandas as pd
- import numpy as np
- from sklearn.preprocessing import StandardScaler
- from sklearn.neighbors import NearestNeighbors
- from typing import List, Dict
- import logging
- import pickle
- import os
- import json
- from pathlib import Path
- from enum import Enum
-
- # Set up logging
- log = logging.getLogger("workbench")
-
-
- # Enumerated Proximity Types (distance or similarity)
- class ProximityType(Enum):
-     DISTANCE = "distance"
-     SIMILARITY = "similarity"
-
-
- class Proximity:
-     def __init__(
-         self,
-         df: pd.DataFrame,
-         id_column: str,
-         features: List[str],
-         target: str = None,
-         track_columns: List[str] = None,
-         n_neighbors: int = 10,
-     ):
-         """
-         Initialize the Proximity class.
-
-         Args:
-             df (pd.DataFrame): DataFrame containing data for neighbor computations.
-             id_column (str): Name of the column used as the identifier.
-             features (List[str]): List of feature column names to be used for neighbor computations.
-             target (str, optional): Name of the target column. Defaults to None.
-             track_columns (List[str], optional): Additional columns to track in results. Defaults to None.
-             n_neighbors (int): Number of neighbors to compute. Defaults to 10.
-         """
-         self.df = df.dropna(subset=features).copy()
-         self.id_column = id_column
-         self.n_neighbors = min(n_neighbors, len(self.df) - 1)
-         self.target = target
-         self.features = features
-         self.scaler = None
-         self.X = None
-         self.nn = None
-         self.proximity_type = None
-         self.track_columns = track_columns or []
-
-         # Right now we only support numeric features, so remove any columns that are not numeric
-         non_numeric_features = self.df[self.features].select_dtypes(exclude=["number"]).columns.tolist()
-         if non_numeric_features:
-             log.warning(f"Non-numeric features {non_numeric_features} aren't currently supported...")
-             self.features = [f for f in self.features if f not in non_numeric_features]
-
-         # Build the proximity model
-         self.build_proximity_model()
-
-     def build_proximity_model(self) -> None:
-         """Standardize features and fit Nearest Neighbors model.
-         Note: This method can be overridden in subclasses for custom behavior."""
-         self.proximity_type = ProximityType.DISTANCE
-         self.scaler = StandardScaler()
-         self.X = self.scaler.fit_transform(self.df[self.features])
-         self.nn = NearestNeighbors(n_neighbors=self.n_neighbors + 1).fit(self.X)
-
-     def all_neighbors(self) -> pd.DataFrame:
-         """
-         Compute nearest neighbors for all rows in the dataset.
-
-         Returns:
-             pd.DataFrame: A DataFrame of neighbors and their distances.
-         """
-         distances, indices = self.nn.kneighbors(self.X)
-         results = []
-
-         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-             query_id = self.df.iloc[i][self.id_column]
-
-             # Process neighbors
-             for neighbor_idx, dist in zip(nbrs, dists):
-                 # Skip self (neighbor index == current row index)
-                 if neighbor_idx == i:
-                     continue
-                 results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
-
-         return pd.DataFrame(results)
-
-     def neighbors(
-         self,
-         query_df: pd.DataFrame,
-         radius: float = None,
-         include_self: bool = True,
-     ) -> pd.DataFrame:
-         """
-         Return neighbors for rows in a query DataFrame.
-
-         Args:
-             query_df: DataFrame containing query points
-             radius: If provided, find all neighbors within this radius
-             include_self: Whether to include self in results (if present)
-
-         Returns:
-             DataFrame containing neighbors and distances
-
-         Note: The query DataFrame must include the feature columns. The id_column is optional.
-         """
-         # Check if all required features are present
-         missing = set(self.features) - set(query_df.columns)
-         if missing:
-             raise ValueError(f"Query DataFrame is missing required feature columns: {missing}")
-
-         # Check if id_column is present
-         id_column_present = self.id_column in query_df.columns
-
-         # None of the features can be NaNs, so report rows with NaNs and then drop them
-         rows_with_nan = query_df[self.features].isna().any(axis=1)
-
-         # Print the ID column for rows with NaNs
-         if rows_with_nan.any():
-             log.warning(f"Found {rows_with_nan.sum()} rows with NaNs in feature columns:")
-             log.warning(query_df.loc[rows_with_nan, self.id_column])
-
-         # Drop rows with NaNs in feature columns and reassign to query_df
-         query_df = query_df.dropna(subset=self.features)
-
-         # Transform the query features using the model's scaler
-         X_query = self.scaler.transform(query_df[self.features])
-
-         # Get neighbors using either radius or k-nearest neighbors
-         if radius is not None:
-             distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
-         else:
-             distances, indices = self.nn.kneighbors(X_query)
-
-         # Build results
-         all_results = []
-         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-             # Use the ID from the query DataFrame if available, otherwise use the row index
-             query_id = query_df.iloc[i][self.id_column] if id_column_present else f"query_{i}"
-
-             for neighbor_idx, dist in zip(nbrs, dists):
-                 # Skip if the neighbor is the query itself and include_self is False
-                 neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
-                 if not include_self and neighbor_id == query_id:
-                     continue
-
-                 all_results.append(
-                     self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist)
-                 )
-
-         return pd.DataFrame(all_results)
-
-     def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
-         """
-         Internal: Build a result dictionary for a single neighbor.
-
-         Args:
-             query_id: ID of the query point
-             neighbor_idx: Index of the neighbor in the original DataFrame
-             distance: Distance between query and neighbor
-
-         Returns:
-             Dictionary containing neighbor information
-         """
-         neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
-
-         # Basic neighbor info
-         neighbor_info = {
-             self.id_column: query_id,
-             "neighbor_id": neighbor_id,
-             "distance": distance,
-         }
-
-         # Determine which additional columns to include
-         relevant_cols = [self.target, "prediction"] if self.target else []
-         relevant_cols += [c for c in self.df.columns if "_proba" in c or "residual" in c]
-         relevant_cols += ["outlier"]
-
-         # Add user-specified columns
-         relevant_cols += self.track_columns
-
-         # Add values for each relevant column that exists in the dataframe
-         for col in filter(lambda c: c in self.df.columns, relevant_cols):
-             neighbor_info[col] = self.df.iloc[neighbor_idx][col]
-
-         return neighbor_info
-
-     def serialize(self, directory: str) -> None:
-         """
-         Serialize the Proximity model to a directory.
-
-         Args:
-             directory: Directory path to save the model components
-         """
-         # Create directory if it doesn't exist
-         os.makedirs(directory, exist_ok=True)
-
-         # Save metadata
-         metadata = {
-             "id_column": self.id_column,
-             "features": self.features,
-             "target": self.target,
-             "track_columns": self.track_columns,
-             "n_neighbors": self.n_neighbors,
-         }
-
-         with open(os.path.join(directory, "metadata.json"), "w") as f:
-             json.dump(metadata, f)
-
-         # Save the DataFrame
-         self.df.to_pickle(os.path.join(directory, "df.pkl"))
-
-         # Save the scaler and nearest neighbors model
-         with open(os.path.join(directory, "scaler.pkl"), "wb") as f:
-             pickle.dump(self.scaler, f)
-
-         with open(os.path.join(directory, "nn_model.pkl"), "wb") as f:
-             pickle.dump(self.nn, f)
-
-         log.info(f"Proximity model serialized to {directory}")
-
-     @classmethod
-     def deserialize(cls, directory: str) -> "Proximity":
-         """
-         Deserialize a Proximity model from a directory.
-
-         Args:
-             directory: Directory path containing the serialized model components
-
-         Returns:
-             Proximity: A new Proximity instance
-         """
-         directory_path = Path(directory)
-         if not directory_path.exists() or not directory_path.is_dir():
-             raise ValueError(f"Directory {directory} does not exist or is not a directory")
-
-         # Load metadata
-         with open(os.path.join(directory, "metadata.json"), "r") as f:
-             metadata = json.load(f)
-
-         # Load DataFrame
-         df_path = os.path.join(directory, "df.pkl")
-         if not os.path.exists(df_path):
-             raise FileNotFoundError(f"DataFrame file not found at {df_path}")
-         df = pd.read_pickle(df_path)
-
-         # Create instance but skip _prepare_data
-         instance = cls.__new__(cls)
-         instance.df = df
-         instance.id_column = metadata["id_column"]
-         instance.features = metadata["features"]
-         instance.target = metadata["target"]
-         instance.track_columns = metadata["track_columns"]
-         instance.n_neighbors = metadata["n_neighbors"]
-
-         # Load scaler and nn model
-         with open(os.path.join(directory, "scaler.pkl"), "rb") as f:
-             instance.scaler = pickle.load(f)
-
-         with open(os.path.join(directory, "nn_model.pkl"), "rb") as f:
-             instance.nn = pickle.load(f)
-
-         # Load X from scaler transform
-         instance.X = instance.scaler.transform(instance.df[instance.features])
-
-         log.info(f"Proximity model deserialized from {directory}")
-         return instance
-
-
- # Testing the Proximity class
- if __name__ == "__main__":
-
-     pd.set_option("display.max_columns", None)
-     pd.set_option("display.width", 1000)
-
-     # Create a sample DataFrame
-     data = {
-         "ID": [1, 2, 3, 4, 5],
-         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-         "Feature3": [2.5, 2.4, 2.3, 2.3, np.nan],
-     }
-     df = pd.DataFrame(data)
-
-     # Test the Proximity class
-     features = ["Feature1", "Feature2", "Feature3"]
-     prox = Proximity(df, id_column="ID", features=features, n_neighbors=3)
-     print(prox.all_neighbors())
-
-     # Test the neighbors method
-     print(prox.neighbors(query_df=df.iloc[[0]]))
-
-     # Test the neighbors method with radius
-     print(prox.neighbors(query_df=df.iloc[0:2], radius=2.0))
-
-     # Test with data that isn't in the 'train' dataframe
-     query_data = {
-         "ID": [6],
-         "Feature1": [0.31],
-         "Feature2": [0.31],
-         "Feature3": [2.31],
-     }
-     query_df = pd.DataFrame(query_data)
-     print(prox.neighbors(query_df=query_df))
-
-     # Test with Features list
-     prox = Proximity(df, id_column="ID", features=["Feature1"], n_neighbors=2)
-     print(prox.all_neighbors())
-
-     # Create a sample DataFrame
-     data = {
-         "foo_id": ["a", "b", "c", "d", "e"],  # Testing string IDs
-         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-         "target": [1, 0, 1, 0, 5],
-     }
-     df = pd.DataFrame(data)
-
-     # Test with String Ids
-     prox = Proximity(
-         df,
-         id_column="foo_id",
-         features=["Feature1", "Feature2"],
-         target="target",
-         track_columns=["Feature1", "Feature2"],
-         n_neighbors=3,
-     )
-     print(prox.all_neighbors())
-
-     # Test the neighbors method
-     print(prox.neighbors(query_df=df.iloc[0:2]))
-
-     # Time neighbors with all IDs versus calling all_neighbors
-     import time
-
-     start_time = time.time()
-     prox_df = prox.neighbors(query_df=df, include_self=False)
-     end_time = time.time()
-     print(f"Time taken for neighbors: {end_time - start_time:.4f} seconds")
-     start_time = time.time()
-     prox_df_all = prox.all_neighbors()
-     end_time = time.time()
-     print(f"Time taken for all_neighbors: {end_time - start_time:.4f} seconds")
-
-     # Now compare the two dataframes
-     print("Neighbors DataFrame:")
-     print(prox_df)
-     print("\nAll Neighbors DataFrame:")
-     print(prox_df_all)
-     # Check for any discrepancies
-     if prox_df.equals(prox_df_all):
-         print("The two DataFrames are equal :)")
-     else:
-         print("ERROR: The two DataFrames are not equal!")
-
-     # Test querying without the id_column
-     df_no_id = df.drop(columns=["foo_id"])
-     print(prox.neighbors(query_df=df_no_id, include_self=False))
-
-     # Test duplicate IDs
-     data = {
-         "foo_id": ["a", "b", "c", "d", "d"],  # Duplicate ID (d)
-         "Feature1": [0.1, 0.2, 0.3, 0.4, 0.5],
-         "Feature2": [0.5, 0.4, 0.3, 0.2, 0.1],
-         "target": [1, 0, 1, 0, 5],
-     }
-     df = pd.DataFrame(data)
-     prox = Proximity(df, id_column="foo_id", features=["Feature1", "Feature2"], target="target", n_neighbors=3)
-     print(df.equals(prox.df))
-
-     # Test with a categorical feature
-     from workbench.api import FeatureSet, Model
-
-     fs = FeatureSet("abalone_features")
-     model = Model("abalone-regression")
-     df = fs.pull_dataframe()
-     prox = Proximity(df, id_column=fs.id_column, features=model.features(), target=model.target())
-     print(prox.neighbors(query_df=df[0:2]))