workbench 0.8.213__py3-none-any.whl → 0.8.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +257 -80
  3. workbench/algorithms/dataframe/projection_2d.py +38 -21
  4. workbench/algorithms/dataframe/proximity.py +75 -150
  5. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  6. workbench/algorithms/models/cleanlab_model.py +382 -0
  7. workbench/algorithms/models/noise_model.py +2 -2
  8. workbench/algorithms/sql/outliers.py +3 -3
  9. workbench/api/__init__.py +3 -0
  10. workbench/api/endpoint.py +10 -5
  11. workbench/api/feature_set.py +76 -6
  12. workbench/api/meta_model.py +289 -0
  13. workbench/api/model.py +43 -4
  14. workbench/core/artifacts/endpoint_core.py +65 -117
  15. workbench/core/artifacts/feature_set_core.py +3 -3
  16. workbench/core/artifacts/model_core.py +6 -4
  17. workbench/core/pipelines/pipeline_executor.py +1 -1
  18. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +30 -10
  19. workbench/model_script_utils/model_script_utils.py +15 -11
  20. workbench/model_script_utils/pytorch_utils.py +11 -1
  21. workbench/model_scripts/chemprop/chemprop.template +147 -71
  22. workbench/model_scripts/chemprop/generated_model_script.py +151 -75
  23. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  24. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  25. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  26. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  27. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  28. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  29. workbench/model_scripts/meta_model/meta_model.template +209 -0
  30. workbench/model_scripts/pytorch_model/generated_model_script.py +45 -27
  31. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  32. workbench/model_scripts/pytorch_model/pytorch.template +42 -24
  33. workbench/model_scripts/pytorch_model/pytorch_utils.py +11 -1
  34. workbench/model_scripts/script_generation.py +4 -0
  35. workbench/model_scripts/xgb_model/generated_model_script.py +167 -156
  36. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  37. workbench/model_scripts/xgb_model/xgb_model.template +163 -152
  38. workbench/repl/workbench_shell.py +0 -5
  39. workbench/scripts/endpoint_test.py +2 -2
  40. workbench/scripts/meta_model_sim.py +35 -0
  41. workbench/utils/chem_utils/fingerprints.py +87 -46
  42. workbench/utils/chemprop_utils.py +23 -5
  43. workbench/utils/meta_model_simulator.py +499 -0
  44. workbench/utils/metrics_utils.py +94 -10
  45. workbench/utils/model_utils.py +91 -9
  46. workbench/utils/pytorch_utils.py +1 -1
  47. workbench/utils/shap_utils.py +1 -55
  48. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  49. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/METADATA +2 -1
  50. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/RECORD +54 -50
  51. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/entry_points.txt +1 -0
  52. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  53. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  54. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  55. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  56. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/WHEEL +0 -0
  57. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/licenses/LICENSE +0 -0
  58. {workbench-0.8.213.dist-info → workbench-0.8.219.dist-info}/top_level.txt +0 -0
@@ -7,39 +7,30 @@
 # - Sample weights support
 # - Categorical feature handling
 # - Compressed feature decompression
+#
+# NOTE: Imports are structured to minimize serverless endpoint startup time.
+# Heavy imports (sklearn, awswrangler) are deferred to training time.
 
-import argparse
 import json
 import os
 
-import awswrangler as wr
 import joblib
 import numpy as np
 import pandas as pd
 import xgboost as xgb
-from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
-from sklearn.preprocessing import LabelEncoder
 
 from model_script_utils import (
-    check_dataframe,
-    compute_classification_metrics,
-    compute_regression_metrics,
     convert_categorical_types,
     decompress_features,
     expand_proba_column,
     input_fn,
     match_features_case_insensitive,
     output_fn,
-    print_classification_metrics,
-    print_confusion_matrix,
-    print_regression_metrics,
 )
 from uq_harness import (
     compute_confidence,
     load_uq_models,
     predict_intervals,
-    save_uq_models,
-    train_uq_models,
 )
 
 # =============================================================================
@@ -49,25 +40,27 @@ DEFAULT_HYPERPARAMETERS = {
     # Training parameters
     "n_folds": 5,  # Number of CV folds (1 = single train/val split)
     # Core tree parameters
-    "n_estimators": 200,
-    "max_depth": 6,
+    "n_estimators": 300,
+    "max_depth": 7,
     "learning_rate": 0.05,
-    # Sampling parameters
-    "subsample": 0.7,
-    "colsample_bytree": 0.6,
-    "colsample_bylevel": 0.8,
-    # Regularization
-    "min_child_weight": 5,
-    "gamma": 0.2,
-    "reg_alpha": 0.5,
-    "reg_lambda": 2.0,
+    # Sampling parameters (less aggressive - ensemble provides regularization)
+    "subsample": 0.8,
+    "colsample_bytree": 0.8,
+    # Regularization (lighter - ensemble averaging reduces overfitting)
+    "min_child_weight": 3,
+    "gamma": 0.1,
+    "reg_alpha": 0.1,
+    "reg_lambda": 1.0,
     # Random seed
-    "random_state": 42,
+    "seed": 42,
 }
 
 # Workbench-specific parameters (not passed to XGBoost)
 WORKBENCH_PARAMS = {"n_folds"}
 
+# Regression-only parameters (filtered out for classifiers)
+REGRESSION_ONLY_PARAMS = {"objective"}
+
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "{{model_type}}",
@@ -80,10 +73,140 @@ TEMPLATE_PARAMS = {
 }
 
 
+# =============================================================================
+# Model Loading (for SageMaker inference)
+# =============================================================================
+def model_fn(model_dir: str) -> dict:
+    """Load XGBoost ensemble from the specified directory."""
+    # Load ensemble metadata
+    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
+    if os.path.exists(metadata_path):
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+        n_ensemble = metadata["n_ensemble"]
+    else:
+        n_ensemble = 1  # Legacy single model
+
+    # Load ensemble models
+    ensemble_models = []
+    for i in range(n_ensemble):
+        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
+        if not os.path.exists(model_path):
+            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
+        ensemble_models.append(joblib.load(model_path))
+
+    print(f"Loaded {len(ensemble_models)} model(s)")
+
+    # Load label encoder (classifier only)
+    label_encoder = None
+    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
+    if os.path.exists(encoder_path):
+        label_encoder = joblib.load(encoder_path)
+
+    # Load category mappings
+    category_mappings = {}
+    category_path = os.path.join(model_dir, "category_mappings.json")
+    if os.path.exists(category_path):
+        with open(category_path) as f:
+            category_mappings = json.load(f)
+
+    # Load UQ models (regression only)
+    uq_models, uq_metadata = None, None
+    uq_path = os.path.join(model_dir, "uq_metadata.json")
+    if os.path.exists(uq_path):
+        uq_models, uq_metadata = load_uq_models(model_dir)
+
+    return {
+        "ensemble_models": ensemble_models,
+        "n_ensemble": n_ensemble,
+        "label_encoder": label_encoder,
+        "category_mappings": category_mappings,
+        "uq_models": uq_models,
+        "uq_metadata": uq_metadata,
+    }
+
+
+# =============================================================================
+# Inference (for SageMaker inference)
+# =============================================================================
+def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
+    """Make predictions with XGBoost ensemble."""
+    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
+    with open(os.path.join(model_dir, "feature_columns.json")) as f:
+        features = json.load(f)
+    print(f"Model Features: {features}")
+
+    # Extract model components
+    ensemble_models = model_dict["ensemble_models"]
+    label_encoder = model_dict.get("label_encoder")
+    category_mappings = model_dict.get("category_mappings", {})
+    uq_models = model_dict.get("uq_models")
+    uq_metadata = model_dict.get("uq_metadata")
+    compressed_features = TEMPLATE_PARAMS["compressed_features"]
+
+    # Prepare features
+    matched_df = match_features_case_insensitive(df, features)
+    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
+
+    if compressed_features:
+        print("Decompressing features for prediction...")
+        matched_df, features = decompress_features(matched_df, features, compressed_features)
+
+    X = matched_df[features]
+
+    # Collect ensemble predictions
+    all_preds = [m.predict(X) for m in ensemble_models]
+    ensemble_preds = np.stack(all_preds, axis=0)
+
+    if label_encoder is not None:
+        # Classification: average probabilities, then argmax
+        all_probs = [m.predict_proba(X) for m in ensemble_models]
+        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
+        class_preds = np.argmax(avg_probs, axis=1)
+
+        df["prediction"] = label_encoder.inverse_transform(class_preds)
+        df["pred_proba"] = [p.tolist() for p in avg_probs]
+        df = expand_proba_column(df, label_encoder.classes_)
+    else:
+        # Regression: average predictions
+        df["prediction"] = np.mean(ensemble_preds, axis=0)
+        df["prediction_std"] = np.std(ensemble_preds, axis=0)
+
+        # Add UQ intervals if available
+        if uq_models and uq_metadata:
+            df = predict_intervals(df, X, uq_models, uq_metadata)
+            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
+
+    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
+    return df
+
+
 # =============================================================================
 # Training
 # =============================================================================
 if __name__ == "__main__":
+    # -------------------------------------------------------------------------
+    # Training-only imports (deferred to reduce serverless startup time)
+    # -------------------------------------------------------------------------
+    import argparse
+
+    import awswrangler as wr
+    from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+    from sklearn.preprocessing import LabelEncoder
+
+    from model_script_utils import (
+        check_dataframe,
+        compute_classification_metrics,
+        compute_regression_metrics,
+        print_classification_metrics,
+        print_confusion_matrix,
+        print_regression_metrics,
+    )
+    from uq_harness import (
+        save_uq_models,
+        train_uq_models,
+    )
+
     # -------------------------------------------------------------------------
     # Setup: Parse arguments and load data
     # -------------------------------------------------------------------------
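Note: the inference handlers now sit above the training block, so `model_fn` and `predict_fn` are available without pulling in the training-only imports. A minimal local smoke-test sketch for the handler pair, assuming a directory of artifacts from a completed training run (the directory and CSV names here are illustrative, not part of the diff):

    # Hypothetical local check of the SageMaker handler pair defined above.
    # Assumes ./model_artifacts holds xgb_model_*.joblib, feature_columns.json, etc.
    import os
    import pandas as pd

    os.environ["SM_MODEL_DIR"] = "./model_artifacts"      # predict_fn reads feature_columns.json from here

    model_dict = model_fn("./model_artifacts")             # loads ensemble members, encoders, UQ models
    input_df = pd.read_csv("eval_rows.csv")                # rows containing the training feature columns
    output_df = predict_fn(input_df, model_dict)           # adds "prediction" (+ "prediction_std"/UQ for regressors)
    print(output_df["prediction"].head())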
@@ -123,7 +246,7 @@ if __name__ == "__main__":
     all_df, features = decompress_features(all_df, features, compressed_features)
 
     # -------------------------------------------------------------------------
-    # Classification setup: Encode target labels
+    # Classification setup
     # -------------------------------------------------------------------------
     label_encoder = None
     if model_type == "classifier":
@@ -136,6 +259,18 @@
     # -------------------------------------------------------------------------
     n_folds = hyperparameters["n_folds"]
     xgb_params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
+
+    # Map 'seed' to 'random_state' for XGBoost
+    if "seed" in xgb_params:
+        xgb_params["random_state"] = xgb_params.pop("seed")
+
+    # Handle objective: filter regression-only params for classifiers, set default for regressors
+    if model_type == "classifier":
+        xgb_params = {k: v for k, v in xgb_params.items() if k not in REGRESSION_ONLY_PARAMS}
+    else:
+        # Default to MAE (reg:absoluteerror) for regression if not specified
+        xgb_params.setdefault("objective", "reg:absoluteerror")
+
     print(f"XGBoost params: {xgb_params}")
 
     if n_folds == 1:
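To make the new parameter handling concrete, here is a small sketch that mirrors the logic above using this file's constants; the helper name is illustrative only:

    # Illustrative only: same filtering/remapping as the training block above.
    def resolve_xgb_params(hyperparameters: dict, model_type: str) -> dict:
        params = {k: v for k, v in hyperparameters.items() if k not in WORKBENCH_PARAMS}
        if "seed" in params:
            params["random_state"] = params.pop("seed")   # XGBoost's sklearn API expects random_state
        if model_type == "classifier":
            params = {k: v for k, v in params.items() if k not in REGRESSION_ONLY_PARAMS}
        else:
            params.setdefault("objective", "reg:absoluteerror")  # MAE default for regression
        return params

    # resolve_xgb_params(DEFAULT_HYPERPARAMETERS, "regressor")
    #   -> no "n_folds", "random_state": 42, "objective": "reg:absoluteerror"
    # resolve_xgb_params(DEFAULT_HYPERPARAMETERS, "classifier")
    #   -> same, minus the regression-only "objective"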
@@ -285,12 +420,10 @@
     # -------------------------------------------------------------------------
     # Save model artifacts
     # -------------------------------------------------------------------------
-    # Ensemble models
-    for idx, ens_model in enumerate(ensemble_models):
-        joblib.dump(ens_model, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
-    print(f"Saved {len(ensemble_models)} XGBoost model(s)")
+    for idx, m in enumerate(ensemble_models):
+        joblib.dump(m, os.path.join(args.model_dir, f"xgb_model_{idx}.joblib"))
+    print(f"Saved {len(ensemble_models)} model(s)")
 
-    # Metadata files
     with open(os.path.join(args.model_dir, "ensemble_metadata.json"), "w") as f:
         json.dump({"n_ensemble": len(ensemble_models), "n_folds": n_folds}, f)
 
@@ -310,125 +443,3 @@
         save_uq_models(uq_models, uq_metadata, args.model_dir)
 
     print(f"\nModel training complete! Artifacts saved to {args.model_dir}")
-
-
-# =============================================================================
-# Model Loading (for SageMaker inference)
-# =============================================================================
-def model_fn(model_dir: str) -> dict:
-    """Load XGBoost ensemble and associated artifacts.
-
-    Args:
-        model_dir: Directory containing model artifacts
-
-    Returns:
-        Dictionary with ensemble_models, label_encoder, category_mappings, uq_models, etc.
-    """
-    # Load ensemble metadata
-    metadata_path = os.path.join(model_dir, "ensemble_metadata.json")
-    if os.path.exists(metadata_path):
-        with open(metadata_path) as f:
-            metadata = json.load(f)
-        n_ensemble = metadata["n_ensemble"]
-    else:
-        n_ensemble = 1  # Legacy single model
-
-    # Load ensemble models
-    ensemble_models = []
-    for i in range(n_ensemble):
-        model_path = os.path.join(model_dir, f"xgb_model_{i}.joblib")
-        if not os.path.exists(model_path):
-            model_path = os.path.join(model_dir, "xgb_model.joblib")  # Legacy fallback
-        ensemble_models.append(joblib.load(model_path))
-
-    # Load label encoder (classifier only)
-    label_encoder = None
-    encoder_path = os.path.join(model_dir, "label_encoder.joblib")
-    if os.path.exists(encoder_path):
-        label_encoder = joblib.load(encoder_path)
-
-    # Load category mappings
-    category_mappings = {}
-    category_path = os.path.join(model_dir, "category_mappings.json")
-    if os.path.exists(category_path):
-        with open(category_path) as f:
-            category_mappings = json.load(f)
-
-    # Load UQ models (regression only)
-    uq_models, uq_metadata = None, None
-    uq_path = os.path.join(model_dir, "uq_metadata.json")
-    if os.path.exists(uq_path):
-        uq_models, uq_metadata = load_uq_models(model_dir)
-
-    return {
-        "ensemble_models": ensemble_models,
-        "n_ensemble": n_ensemble,
-        "label_encoder": label_encoder,
-        "category_mappings": category_mappings,
-        "uq_models": uq_models,
-        "uq_metadata": uq_metadata,
-    }
-
-
-# =============================================================================
-# Inference (for SageMaker inference)
-# =============================================================================
-def predict_fn(df: pd.DataFrame, models: dict) -> pd.DataFrame:
-    """Make predictions with XGBoost ensemble.
-
-    Args:
-        df: Input DataFrame with features
-        models: Dictionary from model_fn containing ensemble and metadata
-
-    Returns:
-        DataFrame with predictions added
-    """
-    # Load feature columns
-    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
-    with open(os.path.join(model_dir, "feature_columns.json")) as f:
-        features = json.load(f)
-    print(f"Model Features: {features}")
-
-    # Extract model components
-    ensemble_models = models["ensemble_models"]
-    label_encoder = models.get("label_encoder")
-    category_mappings = models.get("category_mappings", {})
-    uq_models = models.get("uq_models")
-    uq_metadata = models.get("uq_metadata")
-    compressed_features = TEMPLATE_PARAMS["compressed_features"]
-
-    # Prepare features
-    matched_df = match_features_case_insensitive(df, features)
-    matched_df, _ = convert_categorical_types(matched_df, features, category_mappings)
-
-    if compressed_features:
-        print("Decompressing features for prediction...")
-        matched_df, features = decompress_features(matched_df, features, compressed_features)
-
-    X = matched_df[features]
-
-    # Collect ensemble predictions
-    all_preds = [m.predict(X) for m in ensemble_models]
-    ensemble_preds = np.stack(all_preds, axis=0)
-
-    if label_encoder is not None:
-        # Classification: average probabilities, then argmax
-        all_probs = [m.predict_proba(X) for m in ensemble_models]
-        avg_probs = np.mean(np.stack(all_probs, axis=0), axis=0)
-        class_preds = np.argmax(avg_probs, axis=1)
-
-        df["prediction"] = label_encoder.inverse_transform(class_preds)
-        df["pred_proba"] = [p.tolist() for p in avg_probs]
-        df = expand_proba_column(df, label_encoder.classes_)
-    else:
-        # Regression: average predictions
-        df["prediction"] = np.mean(ensemble_preds, axis=0)
-        df["prediction_std"] = np.std(ensemble_preds, axis=0)
-
-        # Add UQ intervals if available
-        if uq_models and uq_metadata:
-            df = predict_intervals(df, X, uq_models, uq_metadata)
-            df = compute_confidence(df, uq_metadata["median_interval_width"], "q_10", "q_90")
-
-    print(f"Inference complete: {len(df)} predictions, {len(ensemble_models)} ensemble members")
-    return df
@@ -302,11 +302,6 @@ class WorkbenchShell:
         self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
         self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
 
-        # Algorithms
-        self.commands["FSP"] = importlib.import_module(
-            "workbench.algorithms.dataframe.feature_space_proximity"
-        ).FeatureSpaceProximity
-
         # These are 'nice to have' imports
         self.commands["pd"] = importlib.import_module("pandas")
         self.commands["wr"] = importlib.import_module("awswrangler")
@@ -5,7 +5,7 @@ Usage:
     python model_script_harness.py <local_script.py> <model_name>
 
 Example:
-    python model_script_harness.py pytorch.py aqsol-pytorch-reg
+    python model_script_harness.py pytorch.py aqsol-reg-pytorch
 
 This allows you to test LOCAL changes to a model script against deployed model artifacts.
 Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
@@ -72,7 +72,7 @@ def main():
         print("Usage: python model_script_harness.py <local_script.py> <model_name>")
         print("\nArguments:")
         print("  local_script.py - Path to your LOCAL model script to test")
-        print("  model_name      - Workbench model name (e.g., aqsol-pytorch-reg)")
+        print("  model_name      - Workbench model name (e.g., aqsol-reg-pytorch)")
        print("\nOptional: testing/env.json with additional environment variables")
        sys.exit(1)
 
@@ -0,0 +1,35 @@
+"""MetaModelSimulator: Simulate and analyze ensemble model performance.
+
+This class helps evaluate whether a meta model (ensemble) would outperform
+individual child models by analyzing endpoint inference predictions.
+"""
+
+import argparse
+from workbench.utils.meta_model_simulator import MetaModelSimulator
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Simulate and analyze ensemble model performance using MetaModelSimulator."
+    )
+    parser.add_argument(
+        "models",
+        nargs="+",
+        help="List of model endpoint names to include in the ensemble simulation.",
+    )
+    parser.add_argument(
+        "--id-column",
+        default="molecule_name",
+        help="Name of the ID column (default: molecule_name)",
+    )
+    args = parser.parse_args()
+    models = args.models
+    id_column = args.id_column
+
+    # Create MetaModelSimulator instance and generate report
+    sim = MetaModelSimulator(models, id_column=id_column)
+    sim.report()
+
+
+if __name__ == "__main__":
+    main()
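This new script (workbench/scripts/meta_model_sim.py) is a thin CLI over the MetaModelSimulator class added in workbench/utils/meta_model_simulator.py. Driving the simulator directly from Python would look roughly like this; the model names are placeholders, not endpoints from this release:

    # Hypothetical usage sketch; "aqsol-reg-xgb" and "aqsol-reg-pytorch" are placeholder model names.
    from workbench.utils.meta_model_simulator import MetaModelSimulator

    sim = MetaModelSimulator(["aqsol-reg-xgb", "aqsol-reg-pytorch"], id_column="molecule_name")
    sim.report()  # analyzes whether the simulated ensemble would outperform the individual child models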
@@ -1,31 +1,48 @@
-"""Molecular fingerprint computation utilities"""
+"""Molecular fingerprint computation utilities for ADMET modeling.
+
+This module provides Morgan count fingerprints, the standard for ADMET prediction.
+Count fingerprints outperform binary fingerprints for molecular property prediction.
+
+References:
+    - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
+    - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
+"""
 
 import logging
-import pandas as pd
 
-# Molecular Descriptor Imports
-from rdkit import Chem
-from rdkit.Chem import rdFingerprintGenerator
+import numpy as np
+import pandas as pd
+from rdkit import Chem, RDLogger
+from rdkit.Chem import AllChem
 from rdkit.Chem.MolStandardize import rdMolStandardize
 
+# Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
+# Keep errors enabled so we see actual problems
+RDLogger.DisableLog("rdApp.warning")
+
 # Set up the logger
 log = logging.getLogger("workbench")
 
 
-def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
-    """Compute and add Morgan fingerprints to the DataFrame.
+def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
+    """Compute Morgan count fingerprints for ADMET modeling.
+
+    Generates true count fingerprints where each bit position contains the
+    number of times that substructure appears in the molecule (clamped to 0-255).
+    This is the recommended approach for ADMET prediction per 2025 research.
 
     Args:
-        df (pd.DataFrame): Input DataFrame containing SMILES strings.
-        radius (int): Radius for the Morgan fingerprint.
-        n_bits (int): Number of bits for the fingerprint.
-        counts (bool): Count simulation for the fingerprint.
+        df: Input DataFrame containing SMILES strings.
+        radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
+        n_bits: Number of bits for the fingerprint (default 2048).
 
     Returns:
-        pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
+        pd.DataFrame: Input DataFrame with 'fingerprint' column added.
+            Values are comma-separated uint8 counts.
 
     Note:
-        See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
+        Count fingerprints outperform binary for ADMET prediction.
+        See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
     """
     delete_mol_column = False
 
@@ -39,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
         log.warning("Detected serialized molecules in 'molecule' column. Removing...")
         del df["molecule"]
 
-    # Convert SMILES to RDKit molecule objects (vectorized)
+    # Convert SMILES to RDKit molecule objects
     if "molecule" not in df.columns:
         log.info("Converting SMILES to RDKit Molecules...")
         delete_mol_column = True
@@ -47,23 +64,32 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Make sure our molecules are not None
     failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
     if failed_smiles:
-        log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
-        df = df.dropna(subset=["molecule"])
+        log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
+        df = df.dropna(subset=["molecule"]).copy()
 
     # If we have fragments in our compounds, get the largest fragment before computing fingerprints
     largest_frags = df["molecule"].apply(
         lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
     )
 
-    # Create a Morgan fingerprint generator
-    if counts:
-        n_bits *= 4  # Multiply by 4 to simulate counts
-    morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
+    def mol_to_count_string(mol):
+        """Convert molecule to comma-separated count fingerprint string."""
+        if mol is None:
+            return pd.NA
 
-    # Compute Morgan fingerprints (vectorized)
-    fingerprints = largest_frags.apply(
-        lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
-    )
+        # Get hashed Morgan fingerprint with counts
+        fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
+
+        # Initialize array and populate with counts (clamped to uint8 range)
+        counts = np.zeros(n_bits, dtype=np.uint8)
+        for idx, count in fp.GetNonzeroElements().items():
+            counts[idx] = min(count, 255)
+
+        # Return as comma-separated string
+        return ",".join(map(str, counts))
+
+    # Compute Morgan count fingerprints
+    fingerprints = largest_frags.apply(mol_to_count_string)
 
     # Add the fingerprints to the DataFrame
     df["fingerprint"] = fingerprints
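For reference, the count-fingerprint encoding used by mol_to_count_string can be reproduced on a single molecule with standard RDKit calls. A minimal standalone sketch, using the aspirin SMILES that also appears in the tests below:

    # Minimal sketch mirroring mol_to_count_string for one molecule (aspirin).
    import numpy as np
    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles("CC(=O)OC1=CC=CC=C1C(=O)O")
    fp = AllChem.GetHashedMorganFingerprint(mol, 2, nBits=2048)  # radius=2, hashed counts

    counts = np.zeros(2048, dtype=np.uint8)
    for idx, count in fp.GetNonzeroElements().items():
        counts[idx] = min(count, 255)                            # clamp to uint8 range

    fingerprint = ",".join(map(str, counts))                     # comma-separated uint8 counts
    print(int((counts > 0).sum()), "non-zero positions")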
@@ -71,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
     # Drop the intermediate 'molecule' column if it was added
     if delete_mol_column:
         del df["molecule"]
+
     return df
 
 
 if __name__ == "__main__":
-    print("Running molecular fingerprint tests...")
-    print("Note: This requires molecular_screening module to be available")
+    print("Running Morgan count fingerprint tests...")
 
     # Test molecules
     test_molecules = {
         "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
         "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
         "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",  # With stereochemistry
-        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt
+        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Salt (largest fragment used)
         "benzene": "c1ccccc1",
         "butene_e": "C/C=C/C",  # E-butene
         "butene_z": "C/C=C\\C",  # Z-butene
     }
 
-    # Test 1: Morgan Fingerprints
-    print("\n1. Testing Morgan fingerprint generation...")
+    # Test 1: Morgan Count Fingerprints (default parameters)
+    print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
 
     test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
-
-    fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
+    fp_df = compute_morgan_fingerprints(test_df.copy())
 
     print(" Fingerprint generation results:")
     for _, row in fp_df.iterrows():
         fp = row.get("fingerprint", "N/A")
-        fp_len = len(fp) if fp != "N/A" else 0
-        print(f" {row['name']:15} {fp_len} bits")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            max_count = max(counts)
+            print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
+        else:
+            print(f" {row['name']:15} → N/A")
 
-    # Test 2: Different fingerprint parameters
-    print("\n2. Testing different fingerprint parameters...")
+    # Test 2: Different parameters
+    print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
 
-    # Test with counts enabled
-    fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
+    fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
 
-    print(" With count simulation (256 bits * 4):")
-    for _, row in fp_counts_df.iterrows():
+    for _, row in fp_df_custom.iterrows():
         fp = row.get("fingerprint", "N/A")
-        fp_len = len(fp) if fp != "N/A" else 0
-        print(f" {row['name']:15} {fp_len} bits")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            non_zero = sum(1 for c in counts if c > 0)
+            print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
+        else:
+            print(f" {row['name']:15} → N/A")
 
     # Test 3: Edge cases
     print("\n3. Testing edge cases...")
 
     # Invalid SMILES
     invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
-    try:
-        fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
-        print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
-    except Exception as e:
-        print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
+    fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
+    print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
 
     # Test with pre-existing molecule column
     mol_df = test_df.copy()
@@ -131,4 +160,16 @@ if __name__ == "__main__":
     fp_with_mol = compute_morgan_fingerprints(mol_df)
     print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
 
+    # Test 4: Verify count values are reasonable
+    print("\n4. Verifying count distribution...")
+    all_counts = []
+    for _, row in fp_df.iterrows():
+        fp = row.get("fingerprint", "N/A")
+        if pd.notna(fp):
+            counts = [int(x) for x in fp.split(",")]
+            all_counts.extend([c for c in counts if c > 0])
+
+    if all_counts:
+        print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
+
     print("\n✅ All fingerprint tests completed!")