workbench 0.8.202__py3-none-any.whl → 0.8.220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic; see the release details for more information.
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +78 -150
- workbench/algorithms/graph/light/proximity_graph.py +5 -5
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +3 -0
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +13 -11
- workbench/api/feature_set.py +111 -8
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +45 -12
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/artifact.py +5 -5
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +228 -237
- workbench/core/artifacts/feature_set_core.py +185 -230
- workbench/core/artifacts/model_core.py +34 -26
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/features_to_model/features_to_model.py +22 -10
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +428 -631
- workbench/model_scripts/chemprop/generated_model_script.py +432 -635
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +2 -10
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +370 -609
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/script_generation.py +6 -5
- workbench/model_scripts/uq_models/generated_model_script.py +65 -422
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +366 -396
- workbench/repl/workbench_shell.py +0 -5
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +2 -2
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/training_test.py +85 -0
- workbench/utils/chem_utils/fingerprints.py +87 -46
- workbench/utils/chem_utils/projections.py +16 -6
- workbench/utils/chemprop_utils.py +36 -655
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +192 -54
- workbench/utils/pytorch_utils.py +33 -472
- workbench/utils/shap_utils.py +1 -55
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +49 -356
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugins/model_details.py +30 -68
- workbench/web_interface/components/plugins/scatter_plot.py +4 -8
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
- workbench/model_scripts/uq_models/mapie.template +0 -605
- workbench/model_scripts/uq_models/requirements.txt +0 -1
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""Shared utility functions for model training scripts (templates).
|
|
2
|
+
|
|
3
|
+
These functions are used across multiple model templates (XGBoost, PyTorch, ChemProp)
|
|
4
|
+
to reduce code duplication and ensure consistent behavior.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from io import StringIO
|
|
8
|
+
import json
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from sklearn.metrics import (
|
|
12
|
+
confusion_matrix,
|
|
13
|
+
mean_absolute_error,
|
|
14
|
+
median_absolute_error,
|
|
15
|
+
precision_recall_fscore_support,
|
|
16
|
+
r2_score,
|
|
17
|
+
root_mean_squared_error,
|
|
18
|
+
)
|
|
19
|
+
from scipy.stats import spearmanr
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Validate that a training DataFrame contains at least one row.

    Args:
        df: DataFrame to validate
        df_name: Name of the DataFrame (used in the error message)

    Raises:
        ValueError: If the DataFrame has no rows
    """
    if not df.empty:
        return
    msg = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
    print(msg)
    raise ValueError(msg)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
    """Expand the "pred_proba" list column into one probability column per class.

    Rows where the prediction is None become NaN in every expanded column.

    Args:
        df: DataFrame containing a "pred_proba" column
        class_labels: List of class labels

    Returns:
        DataFrame with "pred_proba" replaced by per-class columns (e.g., "class1_proba")

    Raises:
        ValueError: If DataFrame does not contain a "pred_proba" column
    """
    proba_column = "pred_proba"
    if proba_column not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    expanded_names = [f"{label}_proba" for label in class_labels]
    nan_row = [np.nan] * len(class_labels)

    # None (prediction couldn't be made) -> a full row of NaNs; otherwise keep the list
    rows = [nan_row if val is None else val for val in df[proba_column]]
    proba_df = pd.DataFrame(rows, columns=expanded_names)

    # Drop the packed column (and any stale expanded columns), then append fresh ones
    base = df.drop(columns=[proba_column] + expanded_names, errors="ignore").reset_index(drop=True)
    return pd.concat([base, proba_df], axis=1)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def match_features_case_insensitive(df: pd.DataFrame, model_features: list[str]) -> pd.DataFrame:
    """Rename DataFrame columns so they match the model's feature names.

    Exact column-name matches are kept as-is; otherwise a case-insensitive
    match is used and the column is renamed to the model's spelling.

    Args:
        df: Input DataFrame
        model_features: List of feature names expected by the model

    Returns:
        DataFrame with columns renamed to match model features

    Raises:
        ValueError: If any model features cannot be matched
    """
    lower_to_actual = {col.lower(): col for col in df.columns}
    renames: dict[str, str] = {}
    unmatched: list[str] = []

    for feat in model_features:
        if feat in df.columns:
            continue  # exact match, nothing to do
        actual = lower_to_actual.get(feat.lower())
        if actual is not None:
            renames[actual] = feat
        else:
            unmatched.append(feat)

    if unmatched:
        raise ValueError(f"Features not found: {unmatched}")

    return df.rename(columns=renames)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def convert_categorical_types(
|
|
110
|
+
df: pd.DataFrame, features: list[str], category_mappings: dict[str, list[str]] | None = None
|
|
111
|
+
) -> tuple[pd.DataFrame, dict[str, list[str]]]:
|
|
112
|
+
"""Converts appropriate columns to categorical type with consistent mappings.
|
|
113
|
+
|
|
114
|
+
In training mode (category_mappings is None or empty), detects object/string columns
|
|
115
|
+
with <20 unique values and converts them to categorical.
|
|
116
|
+
In inference mode (category_mappings provided), applies the stored mappings.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
df: The DataFrame to process
|
|
120
|
+
features: List of feature names to consider for conversion
|
|
121
|
+
category_mappings: Existing category mappings. If None or empty, training mode.
|
|
122
|
+
If populated, inference mode.
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
Tuple of (processed DataFrame, category mappings dictionary)
|
|
126
|
+
"""
|
|
127
|
+
if category_mappings is None:
|
|
128
|
+
category_mappings = {}
|
|
129
|
+
|
|
130
|
+
# Training mode
|
|
131
|
+
if not category_mappings:
|
|
132
|
+
for col in df.select_dtypes(include=["object", "string"]):
|
|
133
|
+
if col in features and df[col].nunique() < 20:
|
|
134
|
+
print(f"Training mode: Converting {col} to category")
|
|
135
|
+
df[col] = df[col].astype("category")
|
|
136
|
+
category_mappings[col] = df[col].cat.categories.tolist()
|
|
137
|
+
|
|
138
|
+
# Inference mode
|
|
139
|
+
else:
|
|
140
|
+
for col, categories in category_mappings.items():
|
|
141
|
+
if col in df.columns:
|
|
142
|
+
print(f"Inference mode: Applying categorical mapping for {col}")
|
|
143
|
+
df[col] = pd.Categorical(df[col], categories=categories)
|
|
144
|
+
|
|
145
|
+
return df, category_mappings
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def decompress_features(
    df: pd.DataFrame, features: list[str], compressed_features: list[str]
) -> tuple[pd.DataFrame, list[str]]:
    """Decompress compressed features (bitstrings or count vectors) into individual columns.

    Supports two formats (auto-detected from the first non-null value):
    - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
    - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)

    Args:
        df: The features DataFrame
        features: Full list of feature names
        compressed_features: List of feature names to decompress

    Returns:
        Tuple of (DataFrame with decompressed features, updated feature list)
    """
    # Warn up front: NaN values in a compressed column will break the parse below
    missing_counts = df[features].isna().sum()
    if missing_counts.any():
        missing_features = missing_counts[missing_counts > 0]
        print(
            f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
            "WARNING: You might want to remove/replace all NaN values before processing."
        )

    # Make a copy to avoid mutating the caller's feature list
    decompressed_features = features.copy()

    for feature in compressed_features:
        # Bug fix: report the actual reason for skipping. Previously both cases
        # printed "not in the features list", even when the real problem was a
        # missing DataFrame column.
        if feature not in df.columns:
            print(f"Feature '{feature}' not in the DataFrame columns, skipping decompression.")
            continue
        if feature not in decompressed_features:
            print(f"Feature '{feature}' not in the features list, skipping decompression.")
            continue

        # Remove the compressed feature so it isn't duplicated in the output list
        decompressed_features.remove(feature)

        # Auto-detect format from the first non-null value: comma-separated counts or bitstring
        sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
        parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
        feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)

        # One column per position, prefixed with the first 3 chars of the feature name
        prefix = feature[:3]
        new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)

        # Swap the compressed column for its expanded columns
        decompressed_features.extend(new_col_names)
        df = df.drop(columns=[feature])
        df = pd.concat([df, new_df], axis=1)

    return df, decompressed_features
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def input_fn(input_data, content_type: str) -> pd.DataFrame:
    """Parse raw request data into a DataFrame.

    Args:
        input_data: Raw input data (bytes or string)
        content_type: MIME type of the input data

    Returns:
        Parsed DataFrame

    Raises:
        ValueError: If input is empty or content_type is not supported
    """
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    # Normalize bytes payloads to text before parsing
    text = input_data.decode("utf-8") if isinstance(input_data, bytes) else input_data

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(text))
    if "application/json" in content_type:
        return pd.DataFrame(json.loads(text))
    raise ValueError(f"{content_type} not supported!")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    """Serialize the output DataFrame into the requested MIME type.

    Args:
        output_df: DataFrame to convert
        accept_type: Requested MIME type

    Returns:
        Tuple of (formatted output string, MIME type)

    Raises:
        RuntimeError: If accept_type is not supported
    """
    if "text/csv" in accept_type:
        # NaN -> "N/A" so CSV consumers see an explicit marker instead of a blank
        return output_df.fillna("N/A").to_csv(index=False), "text/csv"
    if "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Compute standard regression metrics.

    Args:
        y_true: Ground truth target values
        y_pred: Predicted values

    Returns:
        Dictionary with keys: rmse, mae, medae, r2, spearmanr, support
    """
    metrics = {
        "rmse": root_mean_squared_error(y_true, y_pred),
        "mae": mean_absolute_error(y_true, y_pred),
        "medae": median_absolute_error(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
    }
    # spearmanr returns a result object; keep only the correlation coefficient
    metrics["spearmanr"] = spearmanr(y_true, y_pred).correlation
    metrics["support"] = len(y_true)
    return metrics
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def print_regression_metrics(metrics: dict[str, float]) -> None:
    """Print regression metrics in the format expected by SageMaker metric definitions.

    Args:
        metrics: Dictionary of metric name -> value
    """
    # Float metrics get 3 decimal places; support is printed as-is (an integer count)
    for name in ("rmse", "mae", "medae", "r2", "spearmanr"):
        print(f"{name}: {metrics[name]:.3f}")
    print(f"support: {metrics['support']}")
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def compute_classification_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str], target_col: str
) -> pd.DataFrame:
    """Compute per-class classification metrics.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels
        label_names: List of class label names
        target_col: Name of the target column (for DataFrame output)

    Returns:
        DataFrame with columns: target_col, precision, recall, f1, support
    """
    # average=None yields one score per label, in label_names order
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, labels=label_names
    )
    return pd.DataFrame(
        {
            target_col: label_names,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "support": support,
        }
    )
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def print_classification_metrics(score_df: pd.DataFrame, target_col: str, label_names: list[str]) -> None:
    """Print per-class classification metrics in the format expected by SageMaker.

    Args:
        score_df: DataFrame from compute_classification_metrics
        target_col: Name of the target column
        label_names: List of class label names
    """
    for label in label_names:
        for metric in ("precision", "recall", "f1", "support"):
            # Look up the single row for this label, then the metric column
            val = score_df.loc[score_df[target_col] == label, metric].iloc[0]
            print(f"Metrics:{label}:{metric} {val}")
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str]) -> None:
    """Print confusion matrix in the format expected by SageMaker.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels
        label_names: List of class label names
    """
    matrix = confusion_matrix(y_true, y_pred, labels=label_names)
    # Rows are actual labels, columns are predicted labels
    for row_idx, actual in enumerate(label_names):
        for col_idx, predicted in enumerate(label_names):
            print(f"ConfusionMatrix:{actual}:{predicted} {matrix[row_idx, col_idx]}")
|
|
@@ -1,11 +1,3 @@
|
|
|
1
1
|
# Requirements for ChemProp model scripts
|
|
2
|
-
# Note:
|
|
3
|
-
|
|
4
|
-
rdkit==2025.9.1
|
|
5
|
-
torch>=2.0.0
|
|
6
|
-
lightning>=2.0.0
|
|
7
|
-
pandas>=2.0.0
|
|
8
|
-
numpy>=1.24.0
|
|
9
|
-
scikit-learn>=1.3.0
|
|
10
|
-
awswrangler>=3.0.0
|
|
11
|
-
joblib>=1.3.0
|
|
2
|
+
# Note: The training and inference images already have torch and chemprop installed.
|
|
3
|
+
# So we only need to install packages that are not already included in the images.
|
|
@@ -1,31 +1,48 @@
|
|
|
1
|
-
"""Molecular fingerprint computation utilities
|
|
1
|
+
"""Molecular fingerprint computation utilities for ADMET modeling.
|
|
2
|
+
|
|
3
|
+
This module provides Morgan count fingerprints, the standard for ADMET prediction.
|
|
4
|
+
Count fingerprints outperform binary fingerprints for molecular property prediction.
|
|
5
|
+
|
|
6
|
+
References:
|
|
7
|
+
- Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
|
|
8
|
+
- ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
|
|
9
|
+
"""
|
|
2
10
|
|
|
3
11
|
import logging
|
|
4
|
-
import pandas as pd
|
|
5
12
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from rdkit
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from rdkit import Chem, RDLogger
|
|
16
|
+
from rdkit.Chem import AllChem
|
|
9
17
|
from rdkit.Chem.MolStandardize import rdMolStandardize
|
|
10
18
|
|
|
19
|
+
# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
|
|
20
|
+
# Keep errors enabled so we see actual problems
|
|
21
|
+
RDLogger.DisableLog("rdApp.warning")
|
|
22
|
+
|
|
11
23
|
# Set up the logger
|
|
12
24
|
log = logging.getLogger("workbench")
|
|
13
25
|
|
|
14
26
|
|
|
15
|
-
def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048
|
|
16
|
-
"""Compute
|
|
27
|
+
def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
|
|
28
|
+
"""Compute Morgan count fingerprints for ADMET modeling.
|
|
29
|
+
|
|
30
|
+
Generates true count fingerprints where each bit position contains the
|
|
31
|
+
number of times that substructure appears in the molecule (clamped to 0-255).
|
|
32
|
+
This is the recommended approach for ADMET prediction per 2025 research.
|
|
17
33
|
|
|
18
34
|
Args:
|
|
19
|
-
df
|
|
20
|
-
radius
|
|
21
|
-
n_bits
|
|
22
|
-
counts (bool): Count simulation for the fingerprint.
|
|
35
|
+
df: Input DataFrame containing SMILES strings.
|
|
36
|
+
radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
|
|
37
|
+
n_bits: Number of bits for the fingerprint (default 2048).
|
|
23
38
|
|
|
24
39
|
Returns:
|
|
25
|
-
pd.DataFrame:
|
|
40
|
+
pd.DataFrame: Input DataFrame with 'fingerprint' column added.
|
|
41
|
+
Values are comma-separated uint8 counts.
|
|
26
42
|
|
|
27
43
|
Note:
|
|
28
|
-
|
|
44
|
+
Count fingerprints outperform binary for ADMET prediction.
|
|
45
|
+
See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
|
|
29
46
|
"""
|
|
30
47
|
delete_mol_column = False
|
|
31
48
|
|
|
@@ -39,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
|
|
|
39
56
|
log.warning("Detected serialized molecules in 'molecule' column. Removing...")
|
|
40
57
|
del df["molecule"]
|
|
41
58
|
|
|
42
|
-
# Convert SMILES to RDKit molecule objects
|
|
59
|
+
# Convert SMILES to RDKit molecule objects
|
|
43
60
|
if "molecule" not in df.columns:
|
|
44
61
|
log.info("Converting SMILES to RDKit Molecules...")
|
|
45
62
|
delete_mol_column = True
|
|
@@ -47,23 +64,32 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
|
|
|
47
64
|
# Make sure our molecules are not None
|
|
48
65
|
failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
|
|
49
66
|
if failed_smiles:
|
|
50
|
-
log.
|
|
51
|
-
df = df.dropna(subset=["molecule"])
|
|
67
|
+
log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
|
|
68
|
+
df = df.dropna(subset=["molecule"]).copy()
|
|
52
69
|
|
|
53
70
|
# If we have fragments in our compounds, get the largest fragment before computing fingerprints
|
|
54
71
|
largest_frags = df["molecule"].apply(
|
|
55
72
|
lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
|
|
56
73
|
)
|
|
57
74
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
75
|
+
def mol_to_count_string(mol):
|
|
76
|
+
"""Convert molecule to comma-separated count fingerprint string."""
|
|
77
|
+
if mol is None:
|
|
78
|
+
return pd.NA
|
|
62
79
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
80
|
+
# Get hashed Morgan fingerprint with counts
|
|
81
|
+
fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
|
|
82
|
+
|
|
83
|
+
# Initialize array and populate with counts (clamped to uint8 range)
|
|
84
|
+
counts = np.zeros(n_bits, dtype=np.uint8)
|
|
85
|
+
for idx, count in fp.GetNonzeroElements().items():
|
|
86
|
+
counts[idx] = min(count, 255)
|
|
87
|
+
|
|
88
|
+
# Return as comma-separated string
|
|
89
|
+
return ",".join(map(str, counts))
|
|
90
|
+
|
|
91
|
+
# Compute Morgan count fingerprints
|
|
92
|
+
fingerprints = largest_frags.apply(mol_to_count_string)
|
|
67
93
|
|
|
68
94
|
# Add the fingerprints to the DataFrame
|
|
69
95
|
df["fingerprint"] = fingerprints
|
|
@@ -71,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
|
|
|
71
97
|
# Drop the intermediate 'molecule' column if it was added
|
|
72
98
|
if delete_mol_column:
|
|
73
99
|
del df["molecule"]
|
|
100
|
+
|
|
74
101
|
return df
|
|
75
102
|
|
|
76
103
|
|
|
77
104
|
if __name__ == "__main__":
|
|
78
|
-
print("Running
|
|
79
|
-
print("Note: This requires molecular_screening module to be available")
|
|
105
|
+
print("Running Morgan count fingerprint tests...")
|
|
80
106
|
|
|
81
107
|
# Test molecules
|
|
82
108
|
test_molecules = {
|
|
83
109
|
"aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
|
|
84
110
|
"caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
|
|
85
111
|
"glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # With stereochemistry
|
|
86
|
-
"sodium_acetate": "CC(=O)[O-].[Na+]", # Salt
|
|
112
|
+
"sodium_acetate": "CC(=O)[O-].[Na+]", # Salt (largest fragment used)
|
|
87
113
|
"benzene": "c1ccccc1",
|
|
88
114
|
"butene_e": "C/C=C/C", # E-butene
|
|
89
115
|
"butene_z": "C/C=C\\C", # Z-butene
|
|
90
116
|
}
|
|
91
117
|
|
|
92
|
-
# Test 1: Morgan Fingerprints
|
|
93
|
-
print("\n1. Testing Morgan fingerprint generation...")
|
|
118
|
+
# Test 1: Morgan Count Fingerprints (default parameters)
|
|
119
|
+
print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
|
|
94
120
|
|
|
95
121
|
test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
|
|
96
|
-
|
|
97
|
-
fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
|
|
122
|
+
fp_df = compute_morgan_fingerprints(test_df.copy())
|
|
98
123
|
|
|
99
124
|
print(" Fingerprint generation results:")
|
|
100
125
|
for _, row in fp_df.iterrows():
|
|
101
126
|
fp = row.get("fingerprint", "N/A")
|
|
102
|
-
|
|
103
|
-
|
|
127
|
+
if pd.notna(fp):
|
|
128
|
+
counts = [int(x) for x in fp.split(",")]
|
|
129
|
+
non_zero = sum(1 for c in counts if c > 0)
|
|
130
|
+
max_count = max(counts)
|
|
131
|
+
print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
|
|
132
|
+
else:
|
|
133
|
+
print(f" {row['name']:15} → N/A")
|
|
104
134
|
|
|
105
|
-
# Test 2: Different
|
|
106
|
-
print("\n2. Testing different
|
|
135
|
+
# Test 2: Different parameters
|
|
136
|
+
print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
|
|
107
137
|
|
|
108
|
-
|
|
109
|
-
fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
|
|
138
|
+
fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
|
|
110
139
|
|
|
111
|
-
|
|
112
|
-
for _, row in fp_counts_df.iterrows():
|
|
140
|
+
for _, row in fp_df_custom.iterrows():
|
|
113
141
|
fp = row.get("fingerprint", "N/A")
|
|
114
|
-
|
|
115
|
-
|
|
142
|
+
if pd.notna(fp):
|
|
143
|
+
counts = [int(x) for x in fp.split(",")]
|
|
144
|
+
non_zero = sum(1 for c in counts if c > 0)
|
|
145
|
+
print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
|
|
146
|
+
else:
|
|
147
|
+
print(f" {row['name']:15} → N/A")
|
|
116
148
|
|
|
117
149
|
# Test 3: Edge cases
|
|
118
150
|
print("\n3. Testing edge cases...")
|
|
119
151
|
|
|
120
152
|
# Invalid SMILES
|
|
121
153
|
invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
|
|
125
|
-
except Exception as e:
|
|
126
|
-
print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
|
|
154
|
+
fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
|
|
155
|
+
print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
|
|
127
156
|
|
|
128
157
|
# Test with pre-existing molecule column
|
|
129
158
|
mol_df = test_df.copy()
|
|
@@ -131,4 +160,16 @@ if __name__ == "__main__":
|
|
|
131
160
|
fp_with_mol = compute_morgan_fingerprints(mol_df)
|
|
132
161
|
print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
|
|
133
162
|
|
|
163
|
+
# Test 4: Verify count values are reasonable
|
|
164
|
+
print("\n4. Verifying count distribution...")
|
|
165
|
+
all_counts = []
|
|
166
|
+
for _, row in fp_df.iterrows():
|
|
167
|
+
fp = row.get("fingerprint", "N/A")
|
|
168
|
+
if pd.notna(fp):
|
|
169
|
+
counts = [int(x) for x in fp.split(",")]
|
|
170
|
+
all_counts.extend([c for c in counts if c > 0])
|
|
171
|
+
|
|
172
|
+
if all_counts:
|
|
173
|
+
print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
|
|
174
|
+
|
|
134
175
|
print("\n✅ All fingerprint tests completed!")
|