workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -1,769 +0,0 @@
1
- """Local utilities for models that work with chemical information."""
2
-
3
- import logging
4
- import pandas as pd
5
- import numpy as np
6
- from typing import List, Optional
7
-
8
- # Molecular Descriptor Imports
9
- from rdkit import Chem
10
- from rdkit.Chem import Mol, Descriptors, rdFingerprintGenerator
11
- from rdkit.ML.Descriptors import MoleculeDescriptors
12
- from rdkit.Chem.MolStandardize.rdMolStandardize import TautomerEnumerator
13
- from rdkit.Chem import rdCIPLabeler
14
- from rdkit.Chem.rdMolDescriptors import CalcNumHBD, CalcExactMolWt
15
- from rdkit import RDLogger
16
- from rdkit.Chem import FunctionalGroups as FG
17
- from mordred import Calculator as MordredCalculator
18
- from mordred import AcidBase, Aromatic, Polarizability, RotatableBond
19
-
20
- # Load functional group hierarchy once during initialization
21
- fgroup_hierarchy = FG.BuildFuncGroupHierarchy()
22
-
23
- # Set RDKit logger to only show errors or critical messages
24
- lg = RDLogger.logger()
25
- lg.setLevel(RDLogger.ERROR)
26
-
27
- # Set up the logger
28
- log = logging.getLogger("workbench")
29
-
30
-
31
- def remove_disconnected_fragments(mol: Chem.Mol) -> Chem.Mol:
32
- """
33
- Remove disconnected fragments from a molecule, keeping the fragment with the most heavy atoms.
34
-
35
- Args:
36
- mol (Mol): RDKit molecule object.
37
-
38
- Returns:
39
- Mol: The fragment with the most heavy atoms, or None if no such fragment exists.
40
- """
41
- if mol is None or mol.GetNumAtoms() == 0:
42
- return None
43
- fragments = Chem.GetMolFrags(mol, asMols=True)
44
- return max(fragments, key=lambda frag: frag.GetNumHeavyAtoms()) if fragments else None
45
-
46
-
47
- def contains_heavy_metals(mol: Mol) -> bool:
48
- """
49
- Check if a molecule contains any heavy metals (broad filter).
50
-
51
- Args:
52
- mol: RDKit molecule object.
53
-
54
- Returns:
55
- bool: True if any heavy metals are detected, False otherwise.
56
- """
57
- heavy_metals = {"Zn", "Cu", "Fe", "Mn", "Co", "Pb", "Hg", "Cd", "As"}
58
- return any(atom.GetSymbol() in heavy_metals for atom in mol.GetAtoms())
59
-
60
-
61
- def halogen_toxicity_score(mol: Mol) -> (int, int):
62
- """
63
- Calculate the halogen count and toxicity threshold for a molecule.
64
-
65
- Args:
66
- mol: RDKit molecule object.
67
-
68
- Returns:
69
- Tuple[int, int]: (halogen_count, halogen_threshold), where the threshold
70
- scales with molecule size (minimum of 2 or 20% of atom count).
71
- """
72
- # Define halogens and count their occurrences
73
- halogens = {"Cl", "Br", "I", "F"}
74
- halogen_count = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() in halogens)
75
-
76
- # Define threshold: small molecules tolerate fewer halogens
77
- # Threshold scales with molecule size to account for reasonable substitution
78
- molecule_size = mol.GetNumAtoms()
79
- halogen_threshold = max(2, int(molecule_size * 0.2)) # Minimum 2, scaled by 20% of molecule size
80
-
81
- return halogen_count, halogen_threshold
82
-
83
-
84
- def toxic_elements(mol: Mol) -> Optional[List[str]]:
85
- """
86
- Identifies toxic elements or specific forms of elements in a molecule.
87
-
88
- Args:
89
- mol: RDKit molecule object.
90
-
91
- Returns:
92
- Optional[List[str]]: List of toxic elements or specific forms if found, otherwise None.
93
-
94
- Notes:
95
- Halogen toxicity logic integrates with `halogen_toxicity_score` and scales thresholds
96
- based on molecule size.
97
- """
98
- # Always toxic elements (heavy metals and known toxic single elements)
99
- always_toxic = {"Pb", "Hg", "Cd", "As", "Be", "Tl", "Sb"}
100
- toxic_found = set()
101
-
102
- for atom in mol.GetAtoms():
103
- symbol = atom.GetSymbol()
104
- formal_charge = atom.GetFormalCharge()
105
-
106
- # Check for always toxic elements
107
- if symbol in always_toxic:
108
- toxic_found.add(symbol)
109
-
110
- # Conditionally toxic nitrogen (positively charged)
111
- if symbol == "N" and formal_charge > 0:
112
- # Exclude benign quaternary ammonium (e.g., choline-like structures)
113
- if mol.HasSubstructMatch(Chem.MolFromSmarts("[N+](C)(C)(C)C")): # Example benign structure
114
- continue
115
- toxic_found.add("N+")
116
-
117
- # Halogen toxicity: Uses halogen_toxicity_score to flag excessive halogenation
118
- if symbol in {"Cl", "Br", "I", "F"}:
119
- halogen_count, halogen_threshold = halogen_toxicity_score(mol)
120
- if halogen_count > halogen_threshold:
121
- toxic_found.add(symbol)
122
-
123
- return list(toxic_found) if toxic_found else None
124
-
125
-
126
- # Precompiled SMARTS patterns for custom toxic functional groups
127
- toxic_smarts_patterns = [
128
- ("C(=S)N"), # Dithiocarbamate
129
- ("P(=O)(O)(O)O"), # Phosphate Ester
130
- ("[As](=O)(=O)-[OH]"), # Arsenic Oxide
131
- ("[C](Cl)(Cl)(Cl)"), # Trichloromethyl
132
- ("[Cr](=O)(=O)=O"), # Chromium(VI)
133
- ("[N+](C)(C)(C)(C)"), # Quaternary Ammonium
134
- ("[Se][Se]"), # Diselenide
135
- ("c1c(Cl)c(Cl)c(Cl)c1"), # Trichlorinated Aromatic Ring
136
- ("[CX3](=O)[CX4][Cl,Br,F,I]"), # Halogenated Carbonyl
137
- ("[P+](C*)(C*)(C*)(C*)"), # Phosphonium Group
138
- ("NC(=S)c1c(Cl)cccc1Cl"), # Chlorobenzene Thiocarbamate
139
- ("NC(=S)Nc1ccccc1"), # Phenyl Thiocarbamate
140
- ("S=C1NCCN1"), # Thiourea Derivative
141
- ]
142
- compiled_toxic_smarts = [Chem.MolFromSmarts(smarts) for smarts in toxic_smarts_patterns]
143
-
144
- # Precompiled SMARTS patterns for exemptions
145
- exempt_smarts_patterns = [
146
- "c1ccc(O)c(O)c1", # Phenols
147
- ]
148
- compiled_exempt_smarts = [Chem.MolFromSmarts(smarts) for smarts in exempt_smarts_patterns]
149
-
150
-
151
- def toxic_groups(mol: Chem.Mol) -> Optional[List[str]]:
152
- """
153
- Check if a molecule contains known toxic functional groups using RDKit's functional groups and SMARTS patterns.
154
-
155
- Args:
156
- mol (rdkit.Chem.Mol): The molecule to evaluate.
157
-
158
- Returns:
159
- Optional[List[str]]: List of SMARTS patterns for toxic groups if found, otherwise None.
160
- """
161
- toxic_smarts_matches = []
162
-
163
- # Use RDKit's functional group definitions
164
- toxic_group_names = ["Nitro", "Azide", "Alcohol", "Aldehyde", "Halogen", "TerminalAlkyne"]
165
- for group_name in toxic_group_names:
166
- group_node = next(node for node in fgroup_hierarchy if node.label == group_name)
167
- if mol.HasSubstructMatch(Chem.MolFromSmarts(group_node.smarts)):
168
- toxic_smarts_matches.append(group_node.smarts) # Use group_node's SMARTS directly
169
-
170
- # Check for custom precompiled toxic SMARTS patterns
171
- for smarts, compiled in zip(toxic_smarts_patterns, compiled_toxic_smarts):
172
- if mol.HasSubstructMatch(compiled): # Use precompiled SMARTS
173
- toxic_smarts_matches.append(smarts)
174
-
175
- # Special handling for N+
176
- if mol.HasSubstructMatch(Chem.MolFromSmarts("[N+]")):
177
- if not mol.HasSubstructMatch(Chem.MolFromSmarts("C[N+](C)(C)C")): # Exclude benign
178
- toxic_smarts_matches.append("[N+]") # Append as SMARTS
179
-
180
- # Exempt stabilizing functional groups using precompiled patterns
181
- for compiled in compiled_exempt_smarts:
182
- if mol.HasSubstructMatch(compiled):
183
- return None
184
-
185
- return toxic_smarts_matches if toxic_smarts_matches else None
186
-
187
-
188
- def contains_metalloenzyme_relevant_metals(mol: Mol) -> bool:
189
- """
190
- Check if a molecule contains metals relevant to metalloenzymes.
191
-
192
- Args:
193
- mol: RDKit molecule object.
194
-
195
- Returns:
196
- bool: True if metalloenzyme-relevant metals are detected, False otherwise.
197
- """
198
- metalloenzyme_metals = {"Zn", "Cu", "Fe", "Mn", "Co"}
199
- return any(atom.GetSymbol() in metalloenzyme_metals for atom in mol.GetAtoms())
200
-
201
-
202
- def contains_salts(mol: Mol) -> bool:
203
- """
204
- Check if a molecule contains common salts or counterions.
205
-
206
- Args:
207
- mol: RDKit molecule object.
208
-
209
- Returns:
210
- bool: True if salts are detected, False otherwise.
211
- """
212
- # Define common inorganic salt fragments (SMARTS patterns)
213
- salt_patterns = ["[Na+]", "[K+]", "[Cl-]", "[Mg+2]", "[Ca+2]", "[NH4+]", "[SO4--]"]
214
- for pattern in salt_patterns:
215
- if mol.HasSubstructMatch(Chem.MolFromSmarts(pattern)):
216
- return True
217
- return False
218
-
219
-
220
- def is_druglike_compound(mol: Mol) -> bool:
221
- """
222
- Filter for drug-likeness and QSAR relevance based on Lipinski's Rule of Five.
223
- Returns False for molecules unlikely to be orally bioavailable.
224
-
225
- Args:
226
- mol: RDKit molecule object.
227
-
228
- Returns:
229
- bool: True if the molecule is drug-like, False otherwise.
230
- """
231
-
232
- # Lipinski's Rule of Five
233
- mw = Descriptors.MolWt(mol)
234
- logp = Descriptors.MolLogP(mol)
235
- hbd = Descriptors.NumHDonors(mol)
236
- hba = Descriptors.NumHAcceptors(mol)
237
- if mw > 500 or logp > 5 or hbd > 5 or hba > 10:
238
- return False
239
-
240
- # Allow exceptions for linear molecules that meet strict RO5 criteria
241
- if mol.GetRingInfo().NumRings() == 0:
242
- if mw <= 300 and logp <= 3 and hbd <= 3 and hba <= 3:
243
- pass # Allow small, non-cyclic druglike compounds
244
- else:
245
- return False
246
-
247
- return True
248
-
249
-
250
- def add_compound_tags(df, mol_column="molecule") -> pd.DataFrame:
251
- """
252
- Adds a 'tags' column to a DataFrame, tagging compounds based on their properties.
253
-
254
- Args:
255
- df (pd.DataFrame): Input DataFrame containing molecular data.
256
- mol_column (str): Column name containing RDKit molecule objects.
257
-
258
- Returns:
259
- pd.DataFrame: Updated DataFrame with a 'tags' column.
260
- """
261
- # Initialize the tags column
262
- df["tags"] = [[] for _ in range(len(df))]
263
- df["meta"] = [{} for _ in range(len(df))]
264
-
265
- # Process each molecule in the DataFrame
266
- for idx, row in df.iterrows():
267
- mol = row[mol_column]
268
- tags = []
269
-
270
- # Check for salts
271
- if contains_salts(mol):
272
- tags.append("salt")
273
-
274
- # Check for fragments (should be done after salt check)
275
- fragments = Chem.GetMolFrags(mol, asMols=True)
276
- if len(fragments) > 1:
277
- tags.append("frag")
278
-
279
- # Check for heavy metals
280
- if contains_heavy_metals(mol):
281
- tags.append("heavy_metals")
282
-
283
- # Check for toxic elements
284
- te = toxic_elements(mol)
285
- if te:
286
- tags.append("toxic_element")
287
- df.at[idx, "meta"]["toxic_elements"] = te
288
-
289
- # Check for toxic groups
290
- tg = toxic_groups(mol)
291
- if tg:
292
- tags.append("toxic_group")
293
- df.at[idx, "meta"]["toxic_groups"] = tg
294
-
295
- # Check for metalloenzyme-relevant metals
296
- if contains_metalloenzyme_relevant_metals(mol):
297
- tags.append("metalloenzyme")
298
-
299
- # Check for drug-likeness
300
- if is_druglike_compound(mol):
301
- tags.append("druglike")
302
-
303
- # Update tags
304
- df.at[idx, "tags"] = tags
305
-
306
- return df
307
-
308
-
309
- def compute_molecular_descriptors(df: pd.DataFrame, tautomerize=True) -> pd.DataFrame:
310
- """Compute and add all the Molecular Descriptors
311
-
312
- Args:
313
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
314
- tautomerize (bool): Whether to tautomerize the SMILES strings.
315
-
316
- Returns:
317
- pd.DataFrame: The input DataFrame with all the RDKit Descriptors added
318
- """
319
-
320
- # Check for the smiles column (any capitalization)
321
- smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
322
- if smiles_column is None:
323
- raise ValueError("Input DataFrame must have a 'smiles' column")
324
-
325
- # Compute/add all the Molecular Descriptors
326
- log.info("Computing Molecular Descriptors...")
327
-
328
- # Convert SMILES to RDKit molecule objects (vectorized)
329
- log.info("Converting SMILES to RDKit Molecules...")
330
- df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
331
-
332
- # Make sure our molecules are not None
333
- failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
334
- if failed_smiles:
335
- log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
336
- df = df.dropna(subset=["molecule"])
337
-
338
- # If we have fragments in our compounds, get the largest fragment before computing descriptors
339
- df["molecule"] = df["molecule"].apply(remove_disconnected_fragments)
340
-
341
- # Tautomerize the molecules if requested
342
- if tautomerize:
343
- log.info("Tautomerizing molecules...")
344
- tautomer_enumerator = TautomerEnumerator()
345
- df["molecule"] = df["molecule"].apply(tautomer_enumerator.Canonicalize)
346
-
347
- # Now get all the RDKIT Descriptors
348
- all_descriptors = [x[0] for x in Descriptors._descList]
349
-
350
- # There's an overflow issue that happens with the IPC descriptor, so we'll remove it
351
- # See: https://github.com/rdkit/rdkit/issues/1527
352
- if "Ipc" in all_descriptors:
353
- all_descriptors.remove("Ipc")
354
-
355
- # Make sure we don't have duplicates
356
- all_descriptors = list(set(all_descriptors))
357
-
358
- # RDKit Molecular Descriptor Calculator Class
359
- log.info("Computing RDKit Descriptors...")
360
- calc = MoleculeDescriptors.MolecularDescriptorCalculator(all_descriptors)
361
- descriptor_values = [calc.CalcDescriptors(m) for m in df["molecule"]]
362
-
363
- # Lowercase the column names
364
- column_names = [name.lower() for name in calc.GetDescriptorNames()]
365
- rdkit_features_df = pd.DataFrame(descriptor_values, columns=column_names)
366
-
367
- # Now compute Mordred Features
368
- log.info("Computing Mordred Descriptors...")
369
- descriptor_choice = [AcidBase, Aromatic, Polarizability, RotatableBond]
370
- calc = MordredCalculator()
371
- for des in descriptor_choice:
372
- calc.register(des)
373
- mordred_df = calc.pandas(df["molecule"], nproc=1)
374
-
375
- # Lowercase the column names
376
- mordred_df.columns = [col.lower() for col in mordred_df.columns]
377
-
378
- # Compute stereochemistry descriptors
379
- stereo_df = compute_stereochemistry_descriptors(df)
380
-
381
- # Combine the DataFrame with the RDKit and Mordred Descriptors added
382
- # Note: This will overwrite any existing columns with the same name. This is a good thing
383
- # since we want computed descriptors to overwrite anything in the input dataframe
384
- output_df = stereo_df.combine_first(mordred_df).combine_first(rdkit_features_df)
385
-
386
- # Ensure no duplicate column names
387
- output_df = output_df.loc[:, ~output_df.columns.duplicated()]
388
-
389
- # Reorder the columns to have all the ones in the input df first and then the descriptors
390
- input_columns = df.columns.tolist()
391
- output_df = output_df[input_columns + [col for col in output_df.columns if col not in input_columns]]
392
-
393
- # Drop the intermediate 'molecule' column
394
- del output_df["molecule"]
395
-
396
- # Return the DataFrame with the RDKit and Mordred Descriptors added
397
- return output_df
398
-
399
-
400
- def compute_stereochemistry_descriptors(df: pd.DataFrame) -> pd.DataFrame:
401
- """Compute stereochemistry descriptors for molecules in a DataFrame.
402
-
403
- This function analyzes the stereochemical properties of molecules, including:
404
- - Chiral centers (R/S configuration)
405
- - Double bond stereochemistry (E/Z configuration)
406
-
407
- Args:
408
- df (pd.DataFrame): Input DataFrame with RDKit molecule objects in 'molecule' column
409
-
410
- Returns:
411
- pd.DataFrame: DataFrame with added stereochemistry descriptors
412
- """
413
- if "molecule" not in df.columns:
414
- raise ValueError("Input DataFrame must have a 'molecule' column")
415
-
416
- log.info("Computing stereochemistry descriptors...")
417
- output_df = df.copy()
418
-
419
- # Create helper functions to process a single molecule
420
- def process_molecule(mol):
421
- if mol is None:
422
- log.warning("Found a None molecule, skipping...")
423
- return {
424
- "chiral_centers": 0,
425
- "r_cnt": 0,
426
- "s_cnt": 0,
427
- "db_stereo": 0,
428
- "e_cnt": 0,
429
- "z_cnt": 0,
430
- "chiral_fp": 0,
431
- "db_fp": 0,
432
- }
433
-
434
- try:
435
- # Use the more accurate CIP labeling algorithm (Cahn-Ingold-Prelog rules)
436
- # This assigns R/S to chiral centers and E/Z to double bonds based on
437
- # the priority of substituents (atomic number, mass, etc.)
438
- rdCIPLabeler.AssignCIPLabels(mol)
439
-
440
- # Find all potential stereochemistry sites in the molecule
441
- stereo_info = Chem.FindPotentialStereo(mol)
442
-
443
- # Initialize counters
444
- specified_centers = 0 # Number of chiral centers with defined stereochemistry
445
- r_cnt = 0 # Count of R configured centers
446
- s_cnt = 0 # Count of S configured centers
447
- stereo_atoms = [] # List to store atom indices and their R/S configuration
448
-
449
- specified_bonds = 0 # Number of double bonds with defined stereochemistry
450
- e_cnt = 0 # Count of E (trans) configured double bonds
451
- z_cnt = 0 # Count of Z (cis) configured double bonds
452
- stereo_bonds = [] # List to store bond indices and their E/Z configuration
453
-
454
- # Process all stereo information found in the molecule
455
- for element in stereo_info:
456
- # Handle tetrahedral chiral centers
457
- if element.type == Chem.StereoType.Atom_Tetrahedral:
458
- atom_idx = element.centeredOn
459
-
460
- # Only count centers where stereochemistry is explicitly defined
461
- if element.specified == Chem.StereoSpecified.Specified:
462
- specified_centers += 1
463
- if element.descriptor == Chem.StereoDescriptor.Tet_CCW:
464
- r_cnt += 1
465
- stereo_atoms.append((atom_idx, "R"))
466
- elif element.descriptor == Chem.StereoDescriptor.Tet_CW:
467
- s_cnt += 1
468
- stereo_atoms.append((atom_idx, "S"))
469
-
470
- # Handle double bond stereochemistry
471
- elif element.type == Chem.StereoType.Bond_Double:
472
- bond_idx = element.centeredOn
473
-
474
- # Only count bonds where stereochemistry is explicitly defined
475
- if element.specified == Chem.StereoSpecified.Specified:
476
- specified_bonds += 1
477
- if element.descriptor == Chem.StereoDescriptor.Bond_Trans:
478
- e_cnt += 1
479
- stereo_bonds.append((bond_idx, "E"))
480
- elif element.descriptor == Chem.StereoDescriptor.Bond_Cis:
481
- z_cnt += 1
482
- stereo_bonds.append((bond_idx, "Z"))
483
-
484
- # Calculate chiral center fingerprint - unique bit vector for stereochemical configuration
485
- chiral_fp = 0
486
- if stereo_atoms:
487
- for i, (idx, stereo) in enumerate(sorted(stereo_atoms, key=lambda x: x[0])):
488
- bit_val = 1 if stereo == "R" else 0
489
- chiral_fp += bit_val << i # Shift bits to create a unique fingerprint
490
-
491
- # Calculate double bond fingerprint - bit vector for E/Z configurations
492
- db_fp = 0
493
- if stereo_bonds:
494
- for i, (idx, stereo) in enumerate(sorted(stereo_bonds, key=lambda x: x[0])):
495
- bit_val = 1 if stereo == "E" else 0
496
- db_fp += bit_val << i # Shift bits to create a unique fingerprint
497
-
498
- return {
499
- "chiral_centers": specified_centers,
500
- "r_cnt": r_cnt,
501
- "s_cnt": s_cnt,
502
- "db_stereo": specified_bonds,
503
- "e_cnt": e_cnt,
504
- "z_cnt": z_cnt,
505
- "chiral_fp": chiral_fp,
506
- "db_fp": db_fp,
507
- }
508
-
509
- except Exception as e:
510
- log.warning(f"Error processing stereochemistry: {str(e)}")
511
- return {
512
- "chiral_centers": 0,
513
- "r_cnt": 0,
514
- "s_cnt": 0,
515
- "db_stereo": 0,
516
- "e_cnt": 0,
517
- "z_cnt": 0,
518
- "chiral_fp": 0,
519
- "db_fp": 0,
520
- }
521
-
522
- # Process all molecules and collect results
523
- results = []
524
- for mol in df["molecule"]:
525
- results.append(process_molecule(mol))
526
-
527
- # Add all descriptors to the output dataframe
528
- for key in results[0].keys():
529
- output_df[key] = [r[key] for r in results]
530
-
531
- # Boolean flag indicating if the molecule has any stereochemistry defined
532
- output_df["has_stereo"] = (output_df["chiral_centers"] > 0) | (output_df["db_stereo"] > 0)
533
-
534
- return output_df
535
-
536
-
537
- def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
538
- """Compute and add Morgan fingerprints to the DataFrame.
539
-
540
- Args:
541
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
542
- radius (int): Radius for the Morgan fingerprint.
543
- n_bits (int): Number of bits for the fingerprint.
544
- counts (bool): Count simulation for the fingerprint.
545
-
546
- Returns:
547
- pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
548
-
549
- Note:
550
- See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
551
- """
552
- delete_mol_column = False
553
-
554
- # Check for the SMILES column (case-insensitive)
555
- smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
556
- if smiles_column is None:
557
- raise ValueError("Input DataFrame must have a 'smiles' column")
558
-
559
- # Sanity check the molecule column (sometimes it gets serialized, which doesn't work)
560
- if "molecule" in df.columns and df["molecule"].dtype == "string":
561
- log.warning("Detected serialized molecules in 'molecule' column. Removing...")
562
- del df["molecule"]
563
-
564
- # Convert SMILES to RDKit molecule objects (vectorized)
565
- if "molecule" not in df.columns:
566
- log.info("Converting SMILES to RDKit Molecules...")
567
- delete_mol_column = True
568
- df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
569
- # Make sure our molecules are not None
570
- failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
571
- if failed_smiles:
572
- log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
573
- df = df.dropna(subset=["molecule"])
574
-
575
- # If we have fragments in our compounds, get the largest fragment before computing fingerprints
576
- largest_frags = df["molecule"].apply(remove_disconnected_fragments)
577
-
578
- # Create a Morgan fingerprint generator
579
- if counts:
580
- n_bits *= 4 # Multiply by 4 to simulate counts
581
- morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
582
-
583
- # Compute Morgan fingerprints (vectorized)
584
- fingerprints = largest_frags.apply(
585
- lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
586
- )
587
-
588
- # Add the fingerprints to the DataFrame
589
- df["fingerprint"] = fingerprints
590
-
591
- # Drop the intermediate 'molecule' column if it was added
592
- if delete_mol_column:
593
- del df["molecule"]
594
- return df
595
-
596
-
597
- def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
598
- """
599
- Convert bitstring fingerprints to numpy matrix.
600
-
601
- Args:
602
- fingerprints: pandas Series or list of bitstring fingerprints
603
- dtype: numpy data type (uint8 is default: np.bool_ is good for Jaccard computations
604
-
605
- Returns:
606
- dense numpy array of shape (n_molecules, n_bits)
607
- """
608
-
609
- # Dense matrix representation (we might support sparse in the future)
610
- return np.array([list(fp) for fp in fingerprints], dtype=dtype)
611
-
612
-
613
def canonicalize(df: pd.DataFrame, remove_mol_col: bool = True) -> pd.DataFrame:
    """
    Generate RDKit's canonical SMILES for each molecule in the input DataFrame.

    Note: the work is done on the DataFrame that is passed in. Columns are added to
    it and rows whose SMILES cannot be parsed are dropped from it in place; the same
    object is then returned.

    Args:
        df (pd.DataFrame): Input DataFrame containing a column named 'SMILES' (case-insensitive).
        remove_mol_col (bool): Whether to drop the intermediate 'molecule' column. Default is True.

    Returns:
        pd.DataFrame: The DataFrame with an additional 'smiles_canonical' column and,
        optionally, the 'molecule' column.

    Raises:
        ValueError: If no column is named 'smiles' (compared case-insensitively).
    """

    def _largest_fragment(mol):
        """Strip disconnected fragments, passing missing molecules through as None."""
        if not mol:
            return None
        return remove_disconnected_fragments(mol)

    def _to_isomeric_smiles(mol):
        """Render a molecule as canonical SMILES, keeping isomeric information."""
        if not mol:
            return None
        return Chem.MolToSmiles(mol, isomericSmiles=True)

    # Find the first column whose name is 'smiles' in any capitalisation
    for candidate in df.columns:
        if candidate.lower() == "smiles":
            smiles_column = candidate
            break
    else:
        raise ValueError("Input DataFrame must have a 'SMILES' column")

    # Parse every SMILES string into an RDKit molecule
    df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)

    # Report the rows that failed to parse ...
    unparsed = df["molecule"].isna()
    invalid_indices = df.loc[unparsed].index
    if len(invalid_indices) > 0:
        log.critical(f"Invalid SMILES strings at indices: {invalid_indices.tolist()}")

    # ... and remove them from the DataFrame
    df.dropna(subset=["molecule"], inplace=True)

    # Keep only the largest fragment of each molecule
    df["molecule"] = df["molecule"].apply(_largest_fragment)

    # Canonical SMILES (isomeric information preserved)
    df["smiles_canonical"] = df["molecule"].apply(_to_isomeric_smiles)

    # The RDKit molecule column is only an intermediate unless the caller asks to keep it
    if remove_mol_col:
        df.drop(columns=["molecule"], inplace=True)

    return df
654
-
655
-
656
def custom_tautomer_canonicalization(mol: Mol) -> str:
    """Domain-specific processing of a molecule to select the canonical tautomer.

    All tautomers of the molecule are enumerated and ranked with an example scoring
    rule; the highest-scoring tautomer wins, and on a tie the one enumerated first
    is kept.

    Args:
        mol (Mol): The RDKit molecule for which the canonical tautomer is to be determined.

    Returns:
        str: The SMILES string of the selected canonical tautomer.
    """

    def _score(candidate) -> float:
        """Example scoring: balance hydrogen bond donors, molecular weight and aromaticity."""
        donors = CalcNumHBD(candidate)  # fewer hydrogen bond donors scores higher
        weight = CalcExactMolWt(candidate)  # lighter forms score (slightly) higher
        ring_count = candidate.GetRingInfo().NumAromaticRings()  # aromatic rings score higher
        return -donors - 0.01 * weight + ring_count * 2

    candidates = TautomerEnumerator().Enumerate(mol)

    # max() keeps the first of several equally scored tautomers; an empty
    # enumeration yields None, which is handed to MolToSmiles unchanged
    best_tautomer = max(candidates, key=_score, default=None)

    return Chem.MolToSmiles(best_tautomer)
695
-
696
-
697
def standard_tautomer_canonicalization(mol: Mol) -> str:
    """Standard processing of a molecule to select the canonical tautomer.

    Delegates the choice of tautomer to RDKit's `TautomerEnumerator.Canonicalize`
    and renders the result as SMILES.

    Args:
        mol (Mol): The RDKit molecule for which the canonical tautomer is to be determined.

    Returns:
        str: The SMILES string of the canonical tautomer.
    """
    return Chem.MolToSmiles(TautomerEnumerator().Canonicalize(mol))
712
-
713
-
714
def tautomerize_smiles(df: pd.DataFrame) -> pd.DataFrame:
    """
    Perform tautomer enumeration and canonicalization on a DataFrame.

    The DataFrame is first run through `canonicalize()`, which works on the passed-in
    DataFrame directly and drops rows whose SMILES cannot be parsed. The canonical
    tautomer of every remaining molecule is then computed and the SMILES columns are
    switched so that the tautomer becomes the primary 'smiles' column.

    Args:
        df (pd.DataFrame): Input DataFrame containing a column named 'SMILES' (case-insensitive).

    Returns:
        pd.DataFrame: The DataFrame with
            - 'smiles_orig': the original SMILES strings
            - 'smiles_canonical': RDKit canonical SMILES
            - 'smiles': SMILES of the canonical tautomer (pd.NA if tautomerization failed)

    Raises:
        ValueError: If the DataFrame has no SMILES column (raised by `canonicalize()`).
    """
    # Standardize SMILES strings and create 'molecule' column for further processing
    df = canonicalize(df, remove_mol_col=False)

    # canonicalize() accepts the SMILES column in any capitalisation, so find it the
    # same way here. Renaming a hard-coded "smiles" would leave e.g. a "SMILES"
    # column untouched and produce both "SMILES" and "smiles" in the result.
    smiles_column = next(col for col in df.columns if col.lower() == "smiles")

    # Helper function to safely canonicalize a molecule's tautomer
    def safe_tautomerize(mol):
        """Safely canonicalize a molecule's tautomer, handling errors gracefully."""
        if not mol:
            return pd.NA
        try:
            # Use RDKit's standard Tautomer enumeration and canonicalization
            # For custom logic, replace with custom_tautomer_canonicalization(mol)
            return standard_tautomer_canonicalization(mol)
        except Exception as e:
            # Deliberately best-effort: one bad molecule must not abort the whole batch
            log.warning(f"Tautomerization failed: {str(e)}")
            return pd.NA

    # Apply tautomer canonicalization to each molecule
    df["smiles_tautomer"] = df["molecule"].apply(safe_tautomerize)

    # Drop intermediate RDKit molecule column to clean up the DataFrame
    df.drop(columns=["molecule"], inplace=True)

    # Now switch the smiles columns: original -> 'smiles_orig', tautomer -> 'smiles'
    df.rename(columns={smiles_column: "smiles_orig", "smiles_tautomer": "smiles"}, inplace=True)

    return df
750
-
751
-
752
if __name__ == "__main__":

    # Small smoke test on a single example compound
    example_smiles = "O=C(CCl)c1ccc(Cl)cc1Cl"
    mol = Chem.MolFromSmiles(example_smiles)

    # Five identical rows; the same DataFrame is handed to each step in turn
    df = pd.DataFrame({"smiles": [example_smiles] * 5})

    # Molecular descriptors
    print(compute_molecular_descriptors(df))

    # Morgan fingerprints
    print(compute_morgan_fingerprints(df))

    # Tautomerization (this step renames the SMILES columns, so it runs last)
    print(tautomerize_smiles(df))