workbench 0.8.182-py3-none-any.whl → 0.8.184-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/proximity.py +143 -102
- workbench/api/model.py +13 -12
- workbench/core/artifacts/model_core.py +15 -0
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +5 -7
- workbench/model_scripts/custom_models/proximity/proximity.py +143 -102
- workbench/model_scripts/custom_models/uq_models/proximity.py +143 -102
- workbench/model_scripts/uq_models/generated_model_script.py +1 -1
- workbench/scripts/ml_pipeline_sqs.py +14 -4
- workbench/utils/model_utils.py +27 -0
- {workbench-0.8.182.dist-info → workbench-0.8.184.dist-info}/METADATA +1 -1
- {workbench-0.8.182.dist-info → workbench-0.8.184.dist-info}/RECORD +15 -15
- {workbench-0.8.182.dist-info → workbench-0.8.184.dist-info}/WHEEL +0 -0
- {workbench-0.8.182.dist-info → workbench-0.8.184.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.182.dist-info → workbench-0.8.184.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.182.dist-info → workbench-0.8.184.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/proximity.py
CHANGED
@@ -2,10 +2,9 @@ import pandas as pd
 import numpy as np
 from sklearn.preprocessing import StandardScaler
 from sklearn.neighbors import NearestNeighbors
-from typing import List, Dict
+from typing import List, Dict, Optional
 import logging
 import pickle
-import os
 import json
 from pathlib import Path
 from enum import Enum
@@ -14,7 +13,6 @@ from enum import Enum
 log = logging.getLogger("workbench")
 
 
-# Enumerated Proximity Types (distance or similarity)
 class ProximityType(Enum):
     DISTANCE = "distance"
     SIMILARITY = "similarity"
@@ -26,44 +24,49 @@ class Proximity:
         df: pd.DataFrame,
         id_column: str,
         features: List[str],
-        target: str = None,
-        track_columns: List[str] = None,
+        target: Optional[str] = None,
+        track_columns: Optional[List[str]] = None,
         n_neighbors: int = 10,
     ):
         """
         Initialize the Proximity class.
 
         Args:
-            df
-            id_column
-            features
-            target
-            track_columns
-            n_neighbors
+            df: DataFrame containing data for neighbor computations.
+            id_column: Name of the column used as the identifier.
+            features: List of feature column names to be used for neighbor computations.
+            target: Name of the target column. Defaults to None.
+            track_columns: Additional columns to track in results. Defaults to None.
+            n_neighbors: Number of neighbors to compute. Defaults to 10.
         """
-        self.df = df.dropna(subset=features).copy()
         self.id_column = id_column
-        self.n_neighbors = min(n_neighbors, len(self.df) - 1)
         self.target = target
-        self.
+        self.track_columns = track_columns or []
+        self.proximity_type = None
         self.scaler = None
         self.X = None
         self.nn = None
-        self.proximity_type = None
-        self.track_columns = track_columns or []
 
-        #
-
-
-
-
+        # Filter out non-numeric features
+        self.features = self._validate_features(df, features)
+
+        # Drop NaN rows and set up DataFrame
+        self.df = df.dropna(subset=self.features).copy()
+        self.n_neighbors = min(n_neighbors, len(self.df) - 1)
 
         # Build the proximity model
         self.build_proximity_model()
 
+    def _validate_features(self, df: pd.DataFrame, features: List[str]) -> List[str]:
+        """Remove non-numeric features and log warnings."""
+        non_numeric = df[features].select_dtypes(exclude=["number"]).columns.tolist()
+        if non_numeric:
+            log.warning(f"Non-numeric features {non_numeric} aren't currently supported...")
+            return [f for f in features if f not in non_numeric]
+        return features
+
     def build_proximity_model(self) -> None:
-        """Standardize features and fit Nearest Neighbors model.
-        Note: This method can be overridden in subclasses for custom behavior."""
+        """Standardize features and fit Nearest Neighbors model."""
         self.proximity_type = ProximityType.DISTANCE
         self.scaler = StandardScaler()
         self.X = self.scaler.fit_transform(self.df[self.features])
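The rewritten constructor validates features before dropping NaN rows, so non-numeric columns are now filtered out with a warning instead of failing later inside StandardScaler. A minimal sketch of the new behavior (the toy DataFrame and column names below are illustrative, not from the package):

import pandas as pd
from workbench.algorithms.dataframe.proximity import Proximity

# "color" is non-numeric, so _validate_features() should drop it with a warning
df = pd.DataFrame({
    "ID": [1, 2, 3, 4, 5, 6],
    "Feature1": [0.1, 0.5, 0.9, 1.3, 1.7, 2.1],
    "Feature2": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    "color": ["red", "green", "blue", "red", "green", "blue"],
})
prox = Proximity(df, id_column="ID", features=["Feature1", "Feature2", "color"])
print(prox.features)  # expected: ['Feature1', 'Feature2']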
@@ -74,27 +77,60 @@ class Proximity:
         Compute nearest neighbors for all rows in the dataset.
 
         Returns:
-
+            DataFrame of neighbors and their distances.
         """
         distances, indices = self.nn.kneighbors(self.X)
-        results = []
 
-
-
-
-        for
-
-
-
-            results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
+        results = [
+            self._build_neighbor_result(
+                query_id=self.df.iloc[i][self.id_column], neighbor_idx=neighbor_idx, distance=dist
+            )
+            for i, (dists, nbrs) in enumerate(zip(distances, indices))
+            for neighbor_idx, dist in zip(nbrs, dists)
+            if neighbor_idx != i  # Skip self
+        ]
 
         return pd.DataFrame(results)
 
     def neighbors(
+        self,
+        id_or_ids,
+        n_neighbors: Optional[int] = 5,
+        radius: Optional[float] = None,
+        include_self: bool = True,
+    ) -> pd.DataFrame:
+        """
+        Return neighbors for ID(s) from the existing dataset.
+
+        Args:
+            id_or_ids: Single ID or list of IDs to look up
+            n_neighbors: Number of neighbors to return (default: 5)
+            radius: If provided, find all neighbors within this radius
+            include_self: Whether to include self in results (if present)
+
+        Returns:
+            DataFrame containing neighbors and distances
+        """
+        # Normalize to list
+        ids = [id_or_ids] if not isinstance(id_or_ids, list) else id_or_ids
+
+        # Validate IDs exist
+        missing_ids = set(ids) - set(self.df[self.id_column])
+        if missing_ids:
+            raise ValueError(f"IDs not found in dataset: {missing_ids}")
+
+        # Filter to requested IDs and preserve order
+        query_df = self.df[self.df[self.id_column].isin(ids)]
+        query_df = query_df.set_index(self.id_column).loc[ids].reset_index()
+
+        # Use the core implementation
+        return self.find_neighbors(query_df, n_neighbors=n_neighbors, radius=radius, include_self=include_self)
+
+    def find_neighbors(
         self,
         query_df: pd.DataFrame,
-
+        n_neighbors: Optional[int] = 5,
+        radius: Optional[float] = None,
         include_self: bool = True,
     ) -> pd.DataFrame:
         """
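The old query-based neighbors() has effectively been renamed find_neighbors(), and neighbors() is now an ID-based convenience wrapper: it validates the requested IDs against the stored DataFrame, preserves their order, and delegates to find_neighbors(). Continuing the toy instance from the sketch above (IDs are illustrative):

print(prox.neighbors(1))                           # single ID, k-NN lookup (default n_neighbors=5)
print(prox.neighbors([1, 3], include_self=False))  # multiple IDs; drop each query row from its own results
print(prox.neighbors(2, radius=2.0))               # radius query instead of k-NN
# prox.neighbors(99)                               # unknown ID -> ValueError: IDs not found in dataset: {99}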
@@ -102,63 +138,63 @@ class Proximity:
 
         Args:
             query_df: DataFrame containing query points
+            n_neighbors: Number of neighbors to return (default: 5)
             radius: If provided, find all neighbors within this radius
             include_self: Whether to include self in results (if present)
 
         Returns:
             DataFrame containing neighbors and distances
-
-        Note: The query DataFrame must include the feature columns. The id_column is optional.
         """
-        #
+        # Validate features
         missing = set(self.features) - set(query_df.columns)
         if missing:
             raise ValueError(f"Query DataFrame is missing required feature columns: {missing}")
 
-        # Check if id_column is present
         id_column_present = self.id_column in query_df.columns
 
-        #
-
-
-        # Print the ID column for rows with NaNs
-        if rows_with_nan.any():
-            log.warning(f"Found {rows_with_nan.sum()} rows with NaNs in feature columns:")
-            log.warning(query_df.loc[rows_with_nan, self.id_column])
-
-        # Drop rows with NaNs in feature columns and reassign to query_df
-        query_df = query_df.dropna(subset=self.features)
+        # Handle NaN rows
+        query_df = self._handle_nan_rows(query_df, id_column_present)
 
-        # Transform
+        # Transform query features
         X_query = self.scaler.transform(query_df[self.features])
 
-        # Get neighbors
+        # Get neighbors
         if radius is not None:
             distances, indices = self.nn.radius_neighbors(X_query, radius=radius)
         else:
-            distances, indices = self.nn.kneighbors(X_query)
+            distances, indices = self.nn.kneighbors(X_query, n_neighbors=n_neighbors)
 
         # Build results
-
+        results = []
         for i, (dists, nbrs) in enumerate(zip(distances, indices)):
-            # Use the ID from the query DataFrame if available, otherwise use the row index
             query_id = query_df.iloc[i][self.id_column] if id_column_present else f"query_{i}"
 
             for neighbor_idx, dist in zip(nbrs, dists):
-                # Skip if the neighbor is the query itself and include_self is False
                 neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
+
+                # Skip if neighbor is self and include_self is False
                 if not include_self and neighbor_id == query_id:
                     continue
 
-
-
-
+                results.append(self._build_neighbor_result(query_id=query_id, neighbor_idx=neighbor_idx, distance=dist))
+
+        results_df = pd.DataFrame(results).sort_values([self.id_column, "distance"]).reset_index(drop=True)
+        return results_df
+
+    def _handle_nan_rows(self, query_df: pd.DataFrame, id_column_present: bool) -> pd.DataFrame:
+        """Drop rows with NaN values in feature columns and log warnings."""
+        rows_with_nan = query_df[self.features].isna().any(axis=1)
+
+        if rows_with_nan.any():
+            log.warning(f"Found {rows_with_nan.sum()} rows with NaNs in feature columns:")
+            if id_column_present:
+                log.warning(query_df.loc[rows_with_nan, self.id_column])
 
-        return
+        return query_df.dropna(subset=self.features)
 
     def _build_neighbor_result(self, query_id, neighbor_idx: int, distance: float) -> Dict:
         """
-
+        Build a result dictionary for a single neighbor.
 
         Args:
             query_id: ID of the query point
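find_neighbors() keeps the prior out-of-sample behavior: the query frame must contain the feature columns, the id column is optional (missing ids become "query_0", "query_1", ...), NaN feature rows are dropped via the new _handle_nan_rows() helper, and results are now sorted by query id and distance. A small sketch, again using the toy instance above:

query = pd.DataFrame({
    "Feature1": [0.42, None],  # second row has a NaN feature and is dropped with a warning
    "Feature2": [2.2, 3.3],
})
print(prox.find_neighbors(query_df=query, n_neighbors=2))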
@@ -169,27 +205,30 @@ class Proximity:
             Dictionary containing neighbor information
         """
         neighbor_id = self.df.iloc[neighbor_idx][self.id_column]
+        neighbor_row = self.df.iloc[neighbor_idx]
 
-        #
-
+        # Start with basic info
+        result = {
             self.id_column: query_id,
             "neighbor_id": neighbor_id,
             "distance": distance,
         }
 
-        #
-
-
-
+        # Columns to automatically include if they exist
+        auto_include = (
+            ([self.target, "prediction"] if self.target else [])
+            + self.track_columns
+            + [col for col in self.df.columns if "_proba" in col or "residual" in col or col == "outlier"]
+        )
 
-        # Add
-
+        # Add values for existing columns
+        for col in auto_include:
+            if col in self.df.columns:
+                result[col] = neighbor_row[col]
 
-        #
-
-
-
-        return neighbor_info
+        # Truncate very small distances to zero
+        result["distance"] = 0.0 if distance < 1e-7 else distance
+        return result
 
     def serialize(self, directory: str) -> None:
         """
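_build_neighbor_result() now auto-includes the target, "prediction", any tracked columns, and any *_proba/residual/outlier columns present in the stored DataFrame, and truncates sub-1e-7 distances to zero. A hypothetical shape of one result row (the target and prediction values are illustrative):

row = {
    "ID": 1,             # query id, keyed by the instance's id_column
    "neighbor_id": 3,
    "distance": 0.0,     # truncated from ~1e-9
    "solubility": 0.72,  # target column, when set
    "prediction": 0.69,  # auto-included because the column exists in self.df
}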
@@ -198,8 +237,8 @@ class Proximity:
         Args:
             directory: Directory path to save the model components
         """
-
-
+        dir_path = Path(directory)
+        dir_path.mkdir(parents=True, exist_ok=True)
 
         # Save metadata
         metadata = {
@@ -210,17 +249,16 @@ class Proximity:
             "n_neighbors": self.n_neighbors,
         }
 
-
-            json.dump(metadata, f)
+        (dir_path / "metadata.json").write_text(json.dumps(metadata))
 
-        # Save
-        self.df.to_pickle(
+        # Save DataFrame
+        self.df.to_pickle(dir_path / "df.pkl")
 
-        # Save
-        with open(
+        # Save models
+        with open(dir_path / "scaler.pkl", "wb") as f:
             pickle.dump(self.scaler, f)
 
-        with open(
+        with open(dir_path / "nn_model.pkl", "wb") as f:
             pickle.dump(self.nn, f)
 
         log.info(f"Proximity model serialized to {directory}")
@@ -234,23 +272,22 @@ class Proximity:
             directory: Directory path containing the serialized model components
 
         Returns:
-
+            A new Proximity instance
         """
-
-        if not
+        dir_path = Path(directory)
+        if not dir_path.is_dir():
             raise ValueError(f"Directory {directory} does not exist or is not a directory")
 
         # Load metadata
-
-            metadata = json.load(f)
+        metadata = json.loads((dir_path / "metadata.json").read_text())
 
         # Load DataFrame
-        df_path =
-        if not
+        df_path = dir_path / "df.pkl"
+        if not df_path.exists():
             raise FileNotFoundError(f"DataFrame file not found at {df_path}")
         df = pd.read_pickle(df_path)
 
-        # Create instance
+        # Create instance without calling __init__
         instance = cls.__new__(cls)
         instance.df = df
         instance.id_column = metadata["id_column"]
@@ -259,15 +296,16 @@ class Proximity:
         instance.track_columns = metadata["track_columns"]
         instance.n_neighbors = metadata["n_neighbors"]
 
-        # Load
-        with open(
+        # Load models
+        with open(dir_path / "scaler.pkl", "rb") as f:
             instance.scaler = pickle.load(f)
 
-        with open(
+        with open(dir_path / "nn_model.pkl", "rb") as f:
             instance.nn = pickle.load(f)
 
-        #
+        # Restore X
         instance.X = instance.scaler.transform(instance.df[instance.features])
+        instance.proximity_type = ProximityType.DISTANCE
 
         log.info(f"Proximity model deserialized from {directory}")
         return instance
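serialize()/deserialize() now build paths with pathlib and read/write four artifacts: metadata.json, df.pkl, scaler.pkl, and nn_model.pkl. A round-trip sketch, assuming the loader is exposed as the classmethod Proximity.deserialize() (the name is implied by the log messages, not shown directly in these hunks):

import tempfile

with tempfile.TemporaryDirectory() as tmp:
    prox.serialize(tmp)
    restored = Proximity.deserialize(tmp)
    print(restored.neighbors(1))  # the restored instance answers queries like the original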
@@ -294,10 +332,10 @@ if __name__ == "__main__":
     print(prox.all_neighbors())
 
     # Test the neighbors method
-    print(prox.neighbors(
+    print(prox.neighbors(1))
 
     # Test the neighbors method with radius
-    print(prox.neighbors(
+    print(prox.neighbors(1, radius=2.0))
 
     # Test with data that isn't in the 'train' dataframe
     query_data = {
@@ -307,7 +345,7 @@ if __name__ == "__main__":
         "Feature3": [2.31],
     }
     query_df = pd.DataFrame(query_data)
-    print(prox.
+    print(prox.find_neighbors(query_df=query_df))  # For new data we use find_neighbors()
 
     # Test with Features list
     prox = Proximity(df, id_column="ID", features=["Feature1"], n_neighbors=2)
@@ -334,13 +372,13 @@ if __name__ == "__main__":
     print(prox.all_neighbors())
 
     # Test the neighbors method
-    print(prox.neighbors(
+    print(prox.neighbors(["a", "b"]))
 
     # Time neighbors with all IDs versus calling all_neighbors
     import time
 
     start_time = time.time()
-    prox_df = prox.
+    prox_df = prox.find_neighbors(query_df=df, include_self=False)
     end_time = time.time()
     print(f"Time taken for neighbors: {end_time - start_time:.4f} seconds")
     start_time = time.time()
@@ -361,7 +399,7 @@ if __name__ == "__main__":
 
     # Test querying without the id_column
     df_no_id = df.drop(columns=["foo_id"])
-    print(prox.
+    print(prox.find_neighbors(query_df=df_no_id, include_self=False))
 
     # Test duplicate IDs
     data = {
@@ -379,6 +417,9 @@ if __name__ == "__main__":
 
     fs = FeatureSet("abalone_features")
     model = Model("abalone-regression")
+    features = model.features()
     df = fs.pull_dataframe()
-    prox = Proximity(
-
+    prox = Proximity(
+        df, id_column=fs.id_column, features=model.features(), target=model.target(), track_columns=features
+    )
+    print(prox.find_neighbors(query_df=df[0:2]))
workbench/api/model.py
CHANGED
@@ -10,7 +10,7 @@ from workbench.core.artifacts.artifact import Artifact
 from workbench.core.artifacts.model_core import ModelCore, ModelType  # noqa: F401
 from workbench.core.transforms.model_to_endpoint.model_to_endpoint import ModelToEndpoint
 from workbench.api.endpoint import Endpoint
-from workbench.utils.model_utils import
+from workbench.utils.model_utils import proximity_model_local, uq_model
 
 
 class Model(ModelCore):
@@ -83,19 +83,16 @@ class Model(ModelCore):
         end.set_owner(self.get_owner())
         return end
 
-    def prox_model(self,
-        """Create a Proximity Model for this Model
+    def prox_model(self, filtered: bool = True):
+        """Create a local Proximity Model for this Model
 
         Args:
-
-            track_columns (list, optional): List of columns to track in the Proximity Model.
+            filtered: bool, optional): Use filtered training data for the Proximity Model (default: True)
 
         Returns:
-
+            Proximity: A local Proximity Model
         """
-
-        prox_model_name = self.model_name + "-prox"
-        return proximity_model(self, prox_model_name, track_columns=track_columns)
+        return proximity_model_local(self, filtered=filtered)
 
     def uq_model(self, uq_model_name: str = None, train_all_data: bool = False) -> "Model":
         """Create a Uncertainty Quantification Model for this Model
@@ -121,6 +118,10 @@ if __name__ == "__main__":
     pprint(my_model.summary())
     pprint(my_model.details())
 
-    # Create an Endpoint from the Model
-    my_endpoint = my_model.to_endpoint()
-    pprint(my_endpoint.summary())
+    # Create an Endpoint from the Model (commented out for now)
+    # my_endpoint = my_model.to_endpoint()
+    # pprint(my_endpoint.summary())
+
+    # Create a local Proximity Model for this Model
+    prox_model = my_model.prox_model()
+    print(prox_model.neighbors(3398))
workbench/core/artifacts/model_core.py
CHANGED
@@ -21,6 +21,7 @@ from workbench.utils.aws_utils import newest_path, pull_s3_data
 from workbench.utils.s3_utils import compute_s3_object_hash
 from workbench.utils.shap_utils import shap_values_data, shap_feature_importance
 from workbench.utils.deprecated_utils import deprecated
+from workbench.utils.model_utils import proximity_model
 
 
 class ModelType(Enum):
@@ -881,6 +882,20 @@ class ModelCore(Artifact):
         except (KeyError, IndexError, TypeError):
             return None
 
+    def publish_prox_model(self, prox_model_name: str = None, track_columns: list = None):
+        """Create and publish a Proximity Model for this Model
+
+        Args:
+            prox_model_name (str, optional): Name of the Proximity Model (if not specified, a name will be generated)
+            track_columns (list, optional): List of columns to track in the Proximity Model.
+
+        Returns:
+            Model: The published Proximity Model
+        """
+        if prox_model_name is None:
+            prox_model_name = self.model_name + "-prox"
+        return proximity_model(self, prox_model_name, track_columns=track_columns)
+
     def delete(self):
         """Delete the Model Packages and the Model Group"""
         if not self.exists():
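Net effect of this release: Model.prox_model() now returns an in-process Proximity object (via proximity_model_local), while the old publish-to-registry path moved to ModelCore.publish_prox_model(). A usage sketch (model name taken from the package's own __main__ examples; the track_columns value is illustrative):

from workbench.api.model import Model

model = Model("abalone-regression")
local_prox = model.prox_model()        # local Proximity, no published artifacts
print(local_prox.neighbors(3398))      # ID-based lookup, as in the __main__ test above

published = model.publish_prox_model(  # registers "abalone-regression-prox" by default
    track_columns=["class_number_of_rings"],  # illustrative column name
)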
workbench/model_scripts/custom_models/proximity/generated_model_script.py
CHANGED
@@ -6,9 +6,9 @@
 # Template Placeholders
 TEMPLATE_PARAMS = {
     "id_column": "udm_mol_bat_id",
-    "features": ['bcut2d_logplow', '
-    "target": "
-    "track_columns": None
+    "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
+    "target": "udm_asy_res_free_percent",
+    "track_columns": None,
 }
 
 from io import StringIO
@@ -73,10 +73,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Load training data from the specified directory
-    training_files = [
-        os.path.join(args.train, file)
-        for file in os.listdir(args.train) if file.endswith(".csv")
-    ]
+    training_files = [os.path.join(args.train, file) for file in os.listdir(args.train) if file.endswith(".csv")]
     all_df = pd.concat([pd.read_csv(file, engine="python") for file in training_files])
 
     # Check if the DataFrame is empty
@@ -88,6 +85,7 @@ if __name__ == "__main__":
     # Now serialize the model
     model.serialize(args.model_dir)
 
+
     # Model loading and prediction functions
     def model_fn(model_dir):
 