workbench 0.8.202__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (84)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  3. workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
  4. workbench/algorithms/dataframe/projection_2d.py +44 -21
  5. workbench/algorithms/dataframe/proximity.py +78 -150
  6. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  7. workbench/algorithms/models/cleanlab_model.py +382 -0
  8. workbench/algorithms/models/noise_model.py +388 -0
  9. workbench/algorithms/sql/outliers.py +3 -3
  10. workbench/api/__init__.py +3 -0
  11. workbench/api/df_store.py +17 -108
  12. workbench/api/endpoint.py +13 -11
  13. workbench/api/feature_set.py +111 -8
  14. workbench/api/meta_model.py +289 -0
  15. workbench/api/model.py +45 -12
  16. workbench/api/parameter_store.py +3 -52
  17. workbench/cached/cached_model.py +4 -4
  18. workbench/core/artifacts/artifact.py +5 -5
  19. workbench/core/artifacts/df_store_core.py +114 -0
  20. workbench/core/artifacts/endpoint_core.py +228 -237
  21. workbench/core/artifacts/feature_set_core.py +185 -230
  22. workbench/core/artifacts/model_core.py +34 -26
  23. workbench/core/artifacts/parameter_store_core.py +98 -0
  24. workbench/core/pipelines/pipeline_executor.py +1 -1
  25. workbench/core/transforms/features_to_model/features_to_model.py +22 -10
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  28. workbench/model_script_utils/model_script_utils.py +339 -0
  29. workbench/model_script_utils/pytorch_utils.py +405 -0
  30. workbench/model_script_utils/uq_harness.py +278 -0
  31. workbench/model_scripts/chemprop/chemprop.template +428 -631
  32. workbench/model_scripts/chemprop/generated_model_script.py +432 -635
  33. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  34. workbench/model_scripts/chemprop/requirements.txt +2 -10
  35. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  36. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  37. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  38. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  39. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  40. workbench/model_scripts/meta_model/meta_model.template +209 -0
  41. workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
  42. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  43. workbench/model_scripts/pytorch_model/pytorch.template +370 -609
  44. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  45. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  46. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  47. workbench/model_scripts/script_generation.py +6 -5
  48. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  49. workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
  50. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  51. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  52. workbench/model_scripts/xgb_model/xgb_model.template +366 -396
  53. workbench/repl/workbench_shell.py +0 -5
  54. workbench/resources/open_source_api.key +1 -1
  55. workbench/scripts/endpoint_test.py +2 -2
  56. workbench/scripts/meta_model_sim.py +35 -0
  57. workbench/scripts/training_test.py +85 -0
  58. workbench/utils/chem_utils/fingerprints.py +87 -46
  59. workbench/utils/chem_utils/projections.py +16 -6
  60. workbench/utils/chemprop_utils.py +36 -655
  61. workbench/utils/meta_model_simulator.py +499 -0
  62. workbench/utils/metrics_utils.py +256 -0
  63. workbench/utils/model_utils.py +192 -54
  64. workbench/utils/pytorch_utils.py +33 -472
  65. workbench/utils/shap_utils.py +1 -55
  66. workbench/utils/xgboost_local_crossfold.py +267 -0
  67. workbench/utils/xgboost_model_utils.py +49 -356
  68. workbench/web_interface/components/model_plot.py +7 -1
  69. workbench/web_interface/components/plugins/model_details.py +30 -68
  70. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  71. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
  72. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
  73. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
  74. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  75. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  76. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  77. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  78. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  79. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  80. workbench/model_scripts/uq_models/mapie.template +0 -605
  81. workbench/model_scripts/uq_models/requirements.txt +0 -1
  82. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  83. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
  84. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,339 @@
1
+ """Shared utility functions for model training scripts (templates).
2
+
3
+ These functions are used across multiple model templates (XGBoost, PyTorch, ChemProp)
4
+ to reduce code duplication and ensure consistent behavior.
5
+ """
6
+
7
+ from io import StringIO
8
+ import json
9
+ import numpy as np
10
+ import pandas as pd
11
+ from sklearn.metrics import (
12
+ confusion_matrix,
13
+ mean_absolute_error,
14
+ median_absolute_error,
15
+ precision_recall_fscore_support,
16
+ r2_score,
17
+ root_mean_squared_error,
18
+ )
19
+ from scipy.stats import spearmanr
20
+
21
+
22
def check_dataframe(df: pd.DataFrame, df_name: str) -> None:
    """Fail fast when a training DataFrame is empty.

    Args:
        df: DataFrame to validate.
        df_name: Human-readable name of the DataFrame, used in the error message.

    Raises:
        ValueError: If ``df.empty`` is True. The same message is printed first so
            it also shows up in the training logs.
    """
    if not df.empty:
        return
    error_text = f"*** The training data {df_name} has 0 rows! ***STOPPING***"
    print(error_text)
    raise ValueError(error_text)
36
+
37
+
38
def expand_proba_column(df: pd.DataFrame, class_labels: list[str]) -> pd.DataFrame:
    """Expand the list-valued "pred_proba" column into one column per class.

    Rows whose "pred_proba" entry is missing (``None``, ``NaN`` or ``pd.NA``), e.g.
    rows where no prediction could be made, get NaN in every per-class column.

    Args:
        df: DataFrame containing a "pred_proba" column. Its non-missing entries are
            sequences of probabilities, presumably ordered like ``class_labels``.
        class_labels: Class labels; a column ``f"{label}_proba"`` is created for each.

    Returns:
        A DataFrame with its index reset to a RangeIndex, with "pred_proba" removed
        and the per-class "<label>_proba" columns appended. Any pre-existing
        "<label>_proba" columns are replaced.

    Raises:
        ValueError: If DataFrame does not contain a "pred_proba" column.
    """
    proba_column = "pred_proba"
    if proba_column not in df.columns:
        raise ValueError('DataFrame does not contain a "pred_proba" column')

    proba_splits = [f"{label}_proba" for label in class_labels]
    nan_row = [np.nan] * len(class_labels)

    def _is_missing(val) -> bool:
        # np.ndim is 0 for scalars (None, NaN, pd.NA) and >0 for probability vectors,
        # so pd.isna is only applied to scalars (on sequences it is element-wise).
        return val is None or (np.ndim(val) == 0 and bool(pd.isna(val)))

    proba_values = [nan_row if _is_missing(val) else val for val in df[proba_column]]
    proba_df = pd.DataFrame(proba_values, columns=proba_splits)

    # Drop any existing proba columns and reset the index so concat aligns row-by-row
    df = df.drop(columns=[proba_column] + proba_splits, errors="ignore")
    df = df.reset_index(drop=True)
    return pd.concat([df, proba_df], axis=1)
75
+
76
+
77
def match_features_case_insensitive(df: pd.DataFrame, model_features: list[str]) -> pd.DataFrame:
    """Rename DataFrame columns so they match the model's feature names.

    Exact (case-sensitive) matches always win. Each remaining feature is matched
    case-insensitively against columns that are NOT themselves an exact match for
    some model feature, and each candidate column is used for at most one feature.
    This way a rename can never take a column away from a feature that matched it
    exactly.

    Args:
        df: Input DataFrame.
        model_features: Feature names expected by the model.

    Returns:
        DataFrame with columns renamed to match model features.

    Raises:
        ValueError: If any model features cannot be matched.
    """
    feature_names = set(model_features)

    # Candidate columns for case-insensitive matching, keyed by lowercase name.
    # Exact-match columns are excluded so they stay with their own feature.
    # Non-string labels (e.g. integer columns) can never match a str feature.
    lower_to_column = {
        col.lower(): col for col in df.columns if isinstance(col, str) and col not in feature_names
    }

    rename_dict = {}
    missing = []
    # dict.fromkeys de-duplicates while preserving order, so a repeated feature name
    # does not try to claim a second column.
    for feature in dict.fromkeys(model_features):
        if feature in df.columns:
            continue  # Exact match
        source = lower_to_column.pop(feature.lower(), None)  # pop: one column per feature
        if source is None:
            missing.append(feature)
        else:
            rename_dict[source] = feature

    if missing:
        raise ValueError(f"Features not found: {missing}")

    return df.rename(columns=rename_dict)
107
+
108
+
109
+ def convert_categorical_types(
110
+ df: pd.DataFrame, features: list[str], category_mappings: dict[str, list[str]] | None = None
111
+ ) -> tuple[pd.DataFrame, dict[str, list[str]]]:
112
+ """Converts appropriate columns to categorical type with consistent mappings.
113
+
114
+ In training mode (category_mappings is None or empty), detects object/string columns
115
+ with <20 unique values and converts them to categorical.
116
+ In inference mode (category_mappings provided), applies the stored mappings.
117
+
118
+ Args:
119
+ df: The DataFrame to process
120
+ features: List of feature names to consider for conversion
121
+ category_mappings: Existing category mappings. If None or empty, training mode.
122
+ If populated, inference mode.
123
+
124
+ Returns:
125
+ Tuple of (processed DataFrame, category mappings dictionary)
126
+ """
127
+ if category_mappings is None:
128
+ category_mappings = {}
129
+
130
+ # Training mode
131
+ if not category_mappings:
132
+ for col in df.select_dtypes(include=["object", "string"]):
133
+ if col in features and df[col].nunique() < 20:
134
+ print(f"Training mode: Converting {col} to category")
135
+ df[col] = df[col].astype("category")
136
+ category_mappings[col] = df[col].cat.categories.tolist()
137
+
138
+ # Inference mode
139
+ else:
140
+ for col, categories in category_mappings.items():
141
+ if col in df.columns:
142
+ print(f"Inference mode: Applying categorical mapping for {col}")
143
+ df[col] = pd.Categorical(df[col], categories=categories)
144
+
145
+ return df, category_mappings
146
+
147
+
148
def decompress_features(
    df: pd.DataFrame, features: list[str], compressed_features: list[str]
) -> tuple[pd.DataFrame, list[str]]:
    """Decompress compressed features (bitstrings or count vectors) into individual columns.

    Supports two formats, auto-detected from the first non-null value of each feature:
      - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
      - Count vectors: "0,3,0,1,5,..." → individual uint8 columns; counts are
        clamped to 0-255 so they fit uint8 instead of overflowing.

    New columns are named "<first 3 chars of feature>_<i>". If those names would
    collide with existing columns or features (e.g. two compressed features sharing
    a 3-character prefix), the full feature name is used as the prefix instead.

    Args:
        df: The features DataFrame.
        features: Full list of feature names.
        compressed_features: List of feature names to decompress.

    Returns:
        Tuple of (DataFrame with decompressed features, updated feature list).
    """
    # Warn (but do not fail) about missing values in the required features
    missing_counts = df[features].isna().sum()
    if missing_counts.any():
        missing_features = missing_counts[missing_counts > 0]
        print(
            f"WARNING: Found missing values in features: {missing_features.to_dict()}. "
            "WARNING: You might want to remove/replace all NaN values before processing."
        )

    # Make a copy to avoid mutating the caller's list
    decompressed_features = features.copy()

    for feature in compressed_features:
        if (feature not in df.columns) or (feature not in decompressed_features):
            print(f"Feature '{feature}' not in the features list, skipping decompression.")
            continue

        # Remove the feature from the list to avoid duplication
        decompressed_features.remove(feature)

        # Auto-detect format from the first non-null value: comma-separated counts or bitstring
        non_null = df[feature].dropna()
        sample = str(non_null.iloc[0]) if not non_null.empty else ""
        if "," in sample:
            rows = [list(map(int, s.split(","))) for s in df[feature]]
            # Parse wide, then clamp: np.uint8 cannot hold counts > 255
            feature_matrix = np.clip(np.array(rows, dtype=np.int64), 0, 255).astype(np.uint8)
        else:
            feature_matrix = np.array([list(s) for s in df[feature]], dtype=np.uint8)

        # Create new columns with a short prefix derived from the feature name
        n_cols = feature_matrix.shape[1]
        new_col_names = [f"{feature[:3]}_{i}" for i in range(n_cols)]
        taken = set(df.columns) | set(decompressed_features)
        if taken.intersection(new_col_names):
            # Fall back to the full feature name so columns/features stay unique
            new_col_names = [f"{feature}_{i}" for i in range(n_cols)]
        new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)

        # Update features list and dataframe
        decompressed_features.extend(new_col_names)
        df = df.drop(columns=[feature])
        df = pd.concat([df, new_df], axis=1)

    return df, decompressed_features
201
+
202
+
203
def input_fn(input_data, content_type: str) -> pd.DataFrame:
    """Deserialize a request payload into a DataFrame.

    Args:
        input_data: Raw payload, either ``bytes`` (decoded as UTF-8) or ``str``.
        content_type: MIME type of the payload. It is matched by substring, so
            "text/csv; charset=utf-8" is accepted. CSV is checked before JSON.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the payload is empty/falsy or the content type is unsupported.
    """
    if not input_data:
        raise ValueError("Empty input data is not supported!")

    text = input_data.decode("utf-8") if isinstance(input_data, bytes) else input_data

    if "text/csv" in content_type:
        return pd.read_csv(StringIO(text))
    if "application/json" in content_type:
        return pd.DataFrame(json.loads(text))
    raise ValueError(f"{content_type} not supported!")
228
+
229
+
230
def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
    """Serialize the output DataFrame in the format requested by the client.

    Args:
        output_df: DataFrame to serialize.
        accept_type: Requested MIME type. It is matched by substring, so
            "text/csv; charset=utf-8" is accepted. CSV is checked before JSON.

    Returns:
        Tuple of (serialized body, MIME type). CSV output writes missing values as
        "N/A" and omits the index. JSON output is a list of records.

    Raises:
        RuntimeError: If accept_type is not supported.
    """
    if "text/csv" in accept_type:
        csv_df = output_df
        category_cols = output_df.select_dtypes(include="category").columns
        if len(category_cols) > 0:
            # fillna("N/A") raises TypeError on a categorical column that has missing
            # values but no "N/A" category (e.g. an unseen value mapped to NaN by
            # convert_categorical_types). Convert those columns to plain objects on a
            # copy so the caller's DataFrame is left untouched.
            csv_df = output_df.copy()
            for col in category_cols:
                csv_df[col] = csv_df[col].astype(object)
        csv_output = csv_df.fillna("N/A").to_csv(index=False)
        return csv_output, "text/csv"
    elif "application/json" in accept_type:
        return output_df.to_json(orient="records"), "application/json"
    else:
        raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
250
+
251
+
252
def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Score regression predictions with the standard metric set.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted values, aligned with ``y_true``.

    Returns:
        Dict with keys, in order: rmse, mae, medae, r2, spearmanr (the rank
        correlation from scipy), and support (the number of samples in ``y_true``).
    """
    rmse = root_mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    rank_corr = spearmanr(y_true, y_pred).correlation
    return {
        "rmse": rmse,
        "mae": mae,
        "medae": medae,
        "r2": r2,
        "spearmanr": rank_corr,
        "support": len(y_true),
    }
270
+
271
+
272
def print_regression_metrics(metrics: dict[str, float]) -> None:
    """Print regression metrics as ``name: value`` lines for SageMaker metric definitions.

    Float metrics are printed with three decimals. ``support`` is printed as-is.

    Args:
        metrics: Mapping containing at least rmse, mae, medae, r2, spearmanr, support.
    """
    for metric_name in ("rmse", "mae", "medae", "r2", "spearmanr"):
        print(f"{metric_name}: {metrics[metric_name]:.3f}")
    print(f"support: {metrics['support']}")
284
+
285
+
286
def compute_classification_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str], target_col: str
) -> pd.DataFrame:
    """Build a per-class precision/recall/F1/support table.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        label_names: Class labels, in the row order of the returned table.
        target_col: Column name used for the class-label column.

    Returns:
        DataFrame with columns: target_col, precision, recall, f1, support.
    """
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, labels=label_names
    )
    per_class = {
        target_col: label_names,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "support": support,
    }
    return pd.DataFrame(per_class)
310
+
311
+
312
def print_classification_metrics(score_df: pd.DataFrame, target_col: str, label_names: list[str]) -> None:
    """Print per-class metrics as ``Metrics:<label>:<metric> <value>`` lines for SageMaker.

    For each label, the value comes from the first row of ``score_df`` whose
    ``target_col`` equals that label.

    Args:
        score_df: DataFrame from compute_classification_metrics.
        target_col: Name of the class-label column in ``score_df``.
        label_names: Labels to report, in output order.
    """
    for label in label_names:
        row_mask = score_df[target_col] == label
        for metric in ("precision", "recall", "f1", "support"):
            value = score_df.loc[row_mask, metric].iloc[0]
            print(f"Metrics:{label}:{metric} {value}")
325
+
326
+
327
def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names: list[str]) -> None:
    """Print the confusion matrix as ``ConfusionMatrix:<row>:<col> <count>`` lines for SageMaker.

    Rows and columns follow the order of ``label_names``. Per sklearn's convention,
    rows are true labels and columns are predicted labels.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        label_names: Class labels defining the matrix order.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=label_names)
    for actual_label, matrix_row in zip(label_names, matrix):
        for predicted_label, count in zip(label_names, matrix_row):
            print(f"ConfusionMatrix:{actual_label}:{predicted_label} {count}")
@@ -1,11 +1,3 @@
1
1
  # Requirements for ChemProp model scripts
2
- # Note: These are the local dev requirements. The Docker images have their own requirements.txt
3
- chemprop==2.2.1
4
- rdkit==2025.9.1
5
- torch>=2.0.0
6
- lightning>=2.0.0
7
- pandas>=2.0.0
8
- numpy>=1.24.0
9
- scikit-learn>=1.3.0
10
- awswrangler>=3.0.0
11
- joblib>=1.3.0
2
+ # Note: The training and inference images already have torch and chemprop installed.
3
+ # So we only need to install packages that are not already included in the images.
@@ -1,31 +1,48 @@
1
- """Molecular fingerprint computation utilities"""
1
+ """Molecular fingerprint computation utilities for ADMET modeling.
2
+
3
+ This module provides Morgan count fingerprints, the standard for ADMET prediction.
4
+ Count fingerprints outperform binary fingerprints for molecular property prediction.
5
+
6
+ References:
7
+ - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
8
+ - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
9
+ """
2
10
 
3
11
  import logging
4
- import pandas as pd
5
12
 
6
- # Molecular Descriptor Imports
7
- from rdkit import Chem
8
- from rdkit.Chem import rdFingerprintGenerator
13
+ import numpy as np
14
+ import pandas as pd
15
+ from rdkit import Chem, RDLogger
16
+ from rdkit.Chem import AllChem
9
17
  from rdkit.Chem.MolStandardize import rdMolStandardize
10
18
 
19
+ # Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
20
+ # Keep errors enabled so we see actual problems
21
+ RDLogger.DisableLog("rdApp.warning")
22
+
11
23
  # Set up the logger
12
24
  log = logging.getLogger("workbench")
13
25
 
14
26
 
15
- def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
16
- """Compute and add Morgan fingerprints to the DataFrame.
27
+ def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
28
+ """Compute Morgan count fingerprints for ADMET modeling.
29
+
30
+ Generates true count fingerprints where each bit position contains the
31
+ number of times that substructure appears in the molecule (clamped to 0-255).
32
+ This is the recommended approach for ADMET prediction per 2025 research.
17
33
 
18
34
  Args:
19
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
20
- radius (int): Radius for the Morgan fingerprint.
21
- n_bits (int): Number of bits for the fingerprint.
22
- counts (bool): Count simulation for the fingerprint.
35
+ df: Input DataFrame containing SMILES strings.
36
+ radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
37
+ n_bits: Number of bits for the fingerprint (default 2048).
23
38
 
24
39
  Returns:
25
- pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
40
+ pd.DataFrame: Input DataFrame with 'fingerprint' column added.
41
+ Values are comma-separated uint8 counts.
26
42
 
27
43
  Note:
28
- See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
44
+ Count fingerprints outperform binary for ADMET prediction.
45
+ See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
29
46
  """
30
47
  delete_mol_column = False
31
48
 
@@ -39,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
39
56
  log.warning("Detected serialized molecules in 'molecule' column. Removing...")
40
57
  del df["molecule"]
41
58
 
42
- # Convert SMILES to RDKit molecule objects (vectorized)
59
+ # Convert SMILES to RDKit molecule objects
43
60
  if "molecule" not in df.columns:
44
61
  log.info("Converting SMILES to RDKit Molecules...")
45
62
  delete_mol_column = True
@@ -47,23 +64,32 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
47
64
  # Make sure our molecules are not None
48
65
  failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
49
66
  if failed_smiles:
50
- log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
51
- df = df.dropna(subset=["molecule"])
67
+ log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
68
+ df = df.dropna(subset=["molecule"]).copy()
52
69
 
53
70
  # If we have fragments in our compounds, get the largest fragment before computing fingerprints
54
71
  largest_frags = df["molecule"].apply(
55
72
  lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
56
73
  )
57
74
 
58
- # Create a Morgan fingerprint generator
59
- if counts:
60
- n_bits *= 4 # Multiply by 4 to simulate counts
61
- morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
75
+ def mol_to_count_string(mol):
76
+ """Convert molecule to comma-separated count fingerprint string."""
77
+ if mol is None:
78
+ return pd.NA
62
79
 
63
- # Compute Morgan fingerprints (vectorized)
64
- fingerprints = largest_frags.apply(
65
- lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
66
- )
80
+ # Get hashed Morgan fingerprint with counts
81
+ fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
82
+
83
+ # Initialize array and populate with counts (clamped to uint8 range)
84
+ counts = np.zeros(n_bits, dtype=np.uint8)
85
+ for idx, count in fp.GetNonzeroElements().items():
86
+ counts[idx] = min(count, 255)
87
+
88
+ # Return as comma-separated string
89
+ return ",".join(map(str, counts))
90
+
91
+ # Compute Morgan count fingerprints
92
+ fingerprints = largest_frags.apply(mol_to_count_string)
67
93
 
68
94
  # Add the fingerprints to the DataFrame
69
95
  df["fingerprint"] = fingerprints
@@ -71,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
71
97
  # Drop the intermediate 'molecule' column if it was added
72
98
  if delete_mol_column:
73
99
  del df["molecule"]
100
+
74
101
  return df
75
102
 
76
103
 
77
104
  if __name__ == "__main__":
78
- print("Running molecular fingerprint tests...")
79
- print("Note: This requires molecular_screening module to be available")
105
+ print("Running Morgan count fingerprint tests...")
80
106
 
81
107
  # Test molecules
82
108
  test_molecules = {
83
109
  "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
84
110
  "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
85
111
  "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # With stereochemistry
86
- "sodium_acetate": "CC(=O)[O-].[Na+]", # Salt
112
+ "sodium_acetate": "CC(=O)[O-].[Na+]", # Salt (largest fragment used)
87
113
  "benzene": "c1ccccc1",
88
114
  "butene_e": "C/C=C/C", # E-butene
89
115
  "butene_z": "C/C=C\\C", # Z-butene
90
116
  }
91
117
 
92
- # Test 1: Morgan Fingerprints
93
- print("\n1. Testing Morgan fingerprint generation...")
118
+ # Test 1: Morgan Count Fingerprints (default parameters)
119
+ print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
94
120
 
95
121
  test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
96
-
97
- fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
122
+ fp_df = compute_morgan_fingerprints(test_df.copy())
98
123
 
99
124
  print(" Fingerprint generation results:")
100
125
  for _, row in fp_df.iterrows():
101
126
  fp = row.get("fingerprint", "N/A")
102
- fp_len = len(fp) if fp != "N/A" else 0
103
- print(f" {row['name']:15} {fp_len} bits")
127
+ if pd.notna(fp):
128
+ counts = [int(x) for x in fp.split(",")]
129
+ non_zero = sum(1 for c in counts if c > 0)
130
+ max_count = max(counts)
131
+ print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
132
+ else:
133
+ print(f" {row['name']:15} → N/A")
104
134
 
105
- # Test 2: Different fingerprint parameters
106
- print("\n2. Testing different fingerprint parameters...")
135
+ # Test 2: Different parameters
136
+ print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
107
137
 
108
- # Test with counts enabled
109
- fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
138
+ fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
110
139
 
111
- print(" With count simulation (256 bits * 4):")
112
- for _, row in fp_counts_df.iterrows():
140
+ for _, row in fp_df_custom.iterrows():
113
141
  fp = row.get("fingerprint", "N/A")
114
- fp_len = len(fp) if fp != "N/A" else 0
115
- print(f" {row['name']:15} {fp_len} bits")
142
+ if pd.notna(fp):
143
+ counts = [int(x) for x in fp.split(",")]
144
+ non_zero = sum(1 for c in counts if c > 0)
145
+ print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
146
+ else:
147
+ print(f" {row['name']:15} → N/A")
116
148
 
117
149
  # Test 3: Edge cases
118
150
  print("\n3. Testing edge cases...")
119
151
 
120
152
  # Invalid SMILES
121
153
  invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
122
- try:
123
- fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
124
- print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
125
- except Exception as e:
126
- print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
154
+ fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
155
+ print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
127
156
 
128
157
  # Test with pre-existing molecule column
129
158
  mol_df = test_df.copy()
@@ -131,4 +160,16 @@ if __name__ == "__main__":
131
160
  fp_with_mol = compute_morgan_fingerprints(mol_df)
132
161
  print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
133
162
 
163
+ # Test 4: Verify count values are reasonable
164
+ print("\n4. Verifying count distribution...")
165
+ all_counts = []
166
+ for _, row in fp_df.iterrows():
167
+ fp = row.get("fingerprint", "N/A")
168
+ if pd.notna(fp):
169
+ counts = [int(x) for x in fp.split(",")]
170
+ all_counts.extend([c for c in counts if c > 0])
171
+
172
+ if all_counts:
173
+ print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
174
+
134
175
  print("\n✅ All fingerprint tests completed!")