workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of workbench might be problematic; see the package registry listing for this release for more details.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -1,1556 +0,0 @@
1
- """Chem/RDKIT/Mordred utilities for Workbench"""
2
-
3
- import logging
4
- import numpy as np
5
- import pandas as pd
6
- from typing import List, Dict, Tuple, Optional, Union
7
- import base64
8
- from sklearn.manifold import TSNE
9
-
10
- # Try importing UMAP
11
- try:
12
- import umap
13
- except ImportError:
14
- umap = None
15
-
16
- # Workbench Imports
17
- from workbench.utils.color_utils import is_dark, rgba_to_tuple
18
-
19
- # Molecular Descriptor Imports
20
- from rdkit import Chem
21
- from rdkit.Chem import AllChem, Mol, Descriptors, rdFingerprintGenerator, Draw, SDWriter
22
- from rdkit.Chem.Draw import rdMolDraw2D
23
- from rdkit.ML.Descriptors import MoleculeDescriptors
24
- from rdkit.Chem.MolStandardize.rdMolStandardize import TautomerEnumerator
25
- from rdkit.Chem import rdCIPLabeler
26
- from rdkit.Chem.rdMolDescriptors import CalcNumHBD, CalcExactMolWt
27
- from rdkit import RDLogger
28
- from rdkit.Chem import FunctionalGroups as FG
29
- from mordred import Calculator as MordredCalculator
30
- from mordred import AcidBase, Aromatic, Polarizability, RotatableBond
31
-
32
-
33
- # Load functional group hierarchy once during initialization
34
- fgroup_hierarchy = FG.BuildFuncGroupHierarchy()
35
-
36
- # Set RDKit logger to only show errors or critical messages
37
- lg = RDLogger.logger()
38
- lg.setLevel(RDLogger.ERROR)
39
-
40
- # Set up the logger
41
- log = logging.getLogger("workbench")
42
-
43
- """FIXME:
44
- Let's figure out what we're not just using these RDKit methods
45
-
46
- from rdkit.Chem import PandasTools
47
- df = PandasTools.LoadSDF('file.sdf', molColName='ROMol', smilesName='SMILES')
48
- PandasTools.WriteSDF(df, 'file.sdf', molColName='ROMol', properties=list(df.columns))
49
-
50
- """
51
-
52
-
53
- def df_to_sdf_file(
54
- df: pd.DataFrame,
55
- output_file: str,
56
- smiles_col: str = "smiles",
57
- id_col: Optional[str] = None,
58
- include_cols: Optional[List[str]] = None,
59
- skip_invalid: bool = True,
60
- generate_3d: bool = True,
61
- ):
62
- """
63
- Convert DataFrame with SMILES to SDF file.
64
-
65
- Args:
66
- df: DataFrame containing SMILES and other data
67
- output_file: Path to output SDF file
68
- smiles_col: Column name containing SMILES strings
69
- id_col: Column to use as molecule ID/name
70
- include_cols: Specific columns to include as properties (default: all except smiles and molecule columns)
71
- skip_invalid: Skip invalid SMILES instead of raising error
72
- generate_3d: Generate 3D coordinates and optimize geometry
73
- """
74
- written_count = 0
75
-
76
- with SDWriter(output_file) as writer:
77
- writer.SetForceV3000(True)
78
- for idx, row in df.iterrows():
79
- mol = Chem.MolFromSmiles(row[smiles_col])
80
- if mol is None:
81
- if not skip_invalid:
82
- raise ValueError(f"Invalid SMILES at row {idx}: {row[smiles_col]}")
83
- continue
84
-
85
- # Generate 3D coordinates
86
- if generate_3d:
87
- mol = Chem.AddHs(mol)
88
-
89
- # Try progressively more aggressive embedding strategies
90
- embed_strategies = [
91
- {"maxAttempts": 1000, "randomSeed": 42},
92
- {"maxAttempts": 1000, "randomSeed": 42, "useRandomCoords": True},
93
- {"maxAttempts": 1000, "randomSeed": 42, "boxSizeMult": 5.0},
94
- ]
95
-
96
- embedded = False
97
- for strategy in embed_strategies:
98
- if AllChem.EmbedMolecule(mol, **strategy) != -1:
99
- embedded = True
100
- break
101
-
102
- if not embedded:
103
- if not skip_invalid:
104
- raise ValueError(f"Could not generate 3D coords for row {idx}")
105
- continue
106
-
107
- AllChem.MMFFOptimizeMolecule(mol)
108
-
109
- # Set molecule name/ID
110
- if id_col and id_col in df.columns:
111
- mol.SetProp("_Name", str(row[id_col]))
112
-
113
- # Determine which columns to include
114
- if include_cols:
115
- cols_to_add = [col for col in include_cols if col in df.columns and col != smiles_col]
116
- else:
117
- # Auto-exclude common molecule column names and SMILES column
118
- mol_col_names = ["mol", "molecule", "rdkit_mol", "Mol"]
119
- cols_to_add = [col for col in df.columns if col != smiles_col and col not in mol_col_names]
120
-
121
- # Add properties
122
- for col in cols_to_add:
123
- mol.SetProp(col, str(row[col]))
124
-
125
- writer.write(mol)
126
- written_count += 1
127
-
128
- log.important(f"Wrote {written_count} molecules to SDF: {output_file}")
129
-
130
-
131
- def sdf_file_to_df(
132
- sdf_file: str,
133
- include_smiles: bool = True,
134
- smiles_col: str = "smiles",
135
- id_col: Optional[str] = None,
136
- include_props: Optional[List[str]] = None,
137
- exclude_props: Optional[List[str]] = None,
138
- ) -> pd.DataFrame:
139
- """
140
- Convert SDF file to DataFrame.
141
-
142
- Args:
143
- sdf_file: Path to input SDF file
144
- include_smiles: Add SMILES column to output
145
- smiles_col: Name for SMILES column
146
- id_col: Column name for molecule ID/name (uses _Name property)
147
- include_props: Specific properties to include (default: all)
148
- exclude_props: Properties to exclude from output
149
-
150
- Returns:
151
- DataFrame with molecules and their properties
152
- """
153
- data = []
154
-
155
- suppl = Chem.SDMolSupplier(sdf_file)
156
- for idx, mol in enumerate(suppl):
157
- if mol is None:
158
- log.warning(f"Could not parse molecule at index {idx}")
159
- continue
160
-
161
- row_data = {}
162
-
163
- # Add SMILES if requested
164
- if include_smiles:
165
- row_data[smiles_col] = Chem.MolToSmiles(mol)
166
-
167
- # Add molecule name/ID if requested
168
- if id_col and mol.HasProp("_Name"):
169
- row_data[id_col] = mol.GetProp("_Name")
170
-
171
- # Get all properties
172
- prop_names = mol.GetPropNames()
173
-
174
- # Filter properties based on include/exclude lists
175
- if include_props:
176
- prop_names = [p for p in prop_names if p in include_props]
177
- if exclude_props:
178
- prop_names = [p for p in prop_names if p not in exclude_props]
179
-
180
- # Add properties to row
181
- for prop in prop_names:
182
- if prop != "_Name": # Skip _Name if we already handled it
183
- row_data[prop] = mol.GetProp(prop)
184
-
185
- data.append(row_data)
186
-
187
- df = pd.DataFrame(data)
188
- log.important(f"Read {len(df)} molecules from SDF: {sdf_file}")
189
-
190
- return df
191
-
192
-
193
- def img_from_smiles(
194
- smiles: str, width: int = 500, height: int = 500, background: str = "rgba(64, 64, 64, 1)"
195
- ) -> Optional[str]:
196
- """
197
- Generate an image of the molecule represented by the given SMILES string.
198
-
199
- Args:
200
- smiles (str): A SMILES string representing the molecule.
201
- width (int): Width of the image in pixels. Default is 500.
202
- height (int): Height of the image in pixels. Default is 500.
203
- background (str): Background color of the image. Default is dark grey
204
-
205
- Returns:
206
- str: PIL image of the molecule or None if the SMILES string is invalid.
207
- """
208
-
209
- # Set up the drawing options
210
- dos = Draw.MolDrawOptions()
211
- if is_dark(background):
212
- rdMolDraw2D.SetDarkMode(dos)
213
- dos.setBackgroundColour(rgba_to_tuple(background))
214
-
215
- # Convert the SMILES string to an RDKit molecule and generate the image
216
- mol = Chem.MolFromSmiles(smiles)
217
- if mol:
218
- img = Draw.MolToImage(mol, options=dos, size=(width, height))
219
- return img
220
- else:
221
- log.warning(f"Invalid SMILES: {smiles}")
222
- return None
223
-
224
-
225
- def svg_from_smiles(
226
- smiles: str, width: int = 500, height: int = 500, background: str = "rgba(64, 64, 64, 1)"
227
- ) -> Optional[str]:
228
- """
229
- Generate an SVG image of the molecule represented by the given SMILES string.
230
-
231
- Args:
232
- smiles (str): A SMILES string representing the molecule.
233
- width (int): Width of the image in pixels. Default is 500.
234
- height (int): Height of the image in pixels. Default is 500.
235
- background (str): Background color of the image. Default is dark grey.
236
-
237
- Returns:
238
- Optional[str]: Encoded SVG string of the molecule or None if the SMILES string is invalid.
239
- """
240
- # Convert the SMILES string to an RDKit molecule
241
- mol = Chem.MolFromSmiles(smiles)
242
- if not mol:
243
- return None
244
-
245
- # Compute 2D coordinates for the molecule
246
- AllChem.Compute2DCoords(mol)
247
-
248
- # Initialize the SVG drawer
249
- drawer = rdMolDraw2D.MolDraw2DSVG(width, height)
250
-
251
- # Configure drawing options
252
- options = drawer.drawOptions()
253
- if is_dark(background):
254
- rdMolDraw2D.SetDarkMode(options)
255
- options.setBackgroundColour(rgba_to_tuple(background))
256
-
257
- # Draw the molecule
258
- drawer.DrawMolecule(mol)
259
- drawer.FinishDrawing()
260
-
261
- # Clean and encode the SVG
262
- svg = drawer.GetDrawingText()
263
- encoded_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
264
- return f"data:image/svg+xml;base64,{encoded_svg}"
265
-
266
-
267
- def show(smiles: str, width: int = 500, height: int = 500) -> None:
268
- """
269
- Displays an image of the molecule represented by the given SMILES string.
270
-
271
- Args:
272
- smiles (str): A SMILES string representing the molecule.
273
- width (int): Width of the image in pixels. Default is 500.
274
- height (int): Height of the image in pixels. Default is 500.
275
-
276
- Returns:
277
- None
278
- """
279
- img = img_from_smiles(smiles, width, height)
280
- if img:
281
- img.show()
282
-
283
-
284
- def geometric_mean(series: pd.Series) -> float:
285
- """Computes the geometric mean manually to avoid using scipy."""
286
- return np.exp(np.log(series).mean())
287
-
288
-
289
- def rollup_experimental_data(
290
- df: pd.DataFrame, id: str, time: str, target: str, use_gmean: bool = False
291
- ) -> pd.DataFrame:
292
- """
293
- Rolls up a dataset by selecting the largest time per unique ID and averaging the target value
294
- if multiple records exist at that time. Supports both arithmetic and geometric mean.
295
-
296
- Parameters:
297
- df (pd.DataFrame): Input dataframe.
298
- id (str): Column representing the unique molecule ID.
299
- time (str): Column representing the time.
300
- target (str): Column representing the target value.
301
- use_gmean (bool): Whether to use the geometric mean instead of the arithmetic mean.
302
-
303
- Returns:
304
- pd.DataFrame: Rolled-up dataframe with all original columns retained.
305
- """
306
- # Find the max time per unique ID
307
- max_time_df = df.groupby(id)[time].transform("max")
308
- filtered_df = df[df[time] == max_time_df]
309
-
310
- # Define aggregation function
311
- agg_func = geometric_mean if use_gmean else np.mean
312
-
313
- # Perform aggregation on all columns
314
- agg_dict = {col: "first" for col in df.columns if col not in [target, id, time]}
315
- agg_dict[target] = lambda x: agg_func(x) if len(x) > 1 else x.iloc[0] # Apply mean or gmean
316
-
317
- rolled_up_df = filtered_df.groupby([id, time]).agg(agg_dict).reset_index()
318
- return rolled_up_df
319
-
320
-
321
- def micromolar_to_log(series_µM: pd.Series) -> pd.Series:
322
- """
323
- Convert a pandas Series of concentrations in µM (micromolar) to their logarithmic values (log10).
324
-
325
- Parameters:
326
- series_uM (pd.Series): Series of concentrations in micromolar.
327
-
328
- Returns:
329
- pd.Series: Series of logarithmic values (log10).
330
- """
331
- # Replace 0 or negative values with a small number to avoid log errors
332
- adjusted_series = series_µM.clip(lower=1e-9) # Alignment with another project
333
-
334
- series_mol_per_l = adjusted_series * 1e-6 # Convert µM/L to mol/L
335
- log_series = np.log10(series_mol_per_l)
336
- return log_series
337
-
338
-
339
- def log_to_micromolar(log_series: pd.Series) -> pd.Series:
340
- """
341
- Convert a pandas Series of logarithmic values (log10) back to concentrations in µM (micromolar).
342
-
343
- Parameters:
344
- log_series (pd.Series): Series of logarithmic values (log10).
345
-
346
- Returns:
347
- pd.Series: Series of concentrations in micromolar.
348
- """
349
- series_mol_per_l = 10**log_series # Convert log10 back to mol/L
350
- series_µM = series_mol_per_l * 1e6 # Convert mol/L to µM
351
- return series_µM
352
-
353
-
354
- def log_to_category(log_series: pd.Series) -> pd.Series:
355
- """
356
- Convert a pandas Series of log values to concentration categories.
357
-
358
- Parameters:
359
- log_series (pd.Series): Series of logarithmic values (log10).
360
-
361
- Returns:
362
- pd.Series: Series of concentration categories.
363
- """
364
- # Create a solubility classification column
365
- bins = [-float("inf"), -5, -4, float("inf")]
366
- labels = ["low", "medium", "high"]
367
- return pd.cut(log_series, bins=bins, labels=labels)
368
-
369
-
370
- def remove_disconnected_fragments(mol: Chem.Mol) -> Chem.Mol:
371
- """
372
- Remove disconnected fragments from a molecule, keeping the fragment with the most heavy atoms.
373
-
374
- Args:
375
- mol (Mol): RDKit molecule object.
376
-
377
- Returns:
378
- Mol: The fragment with the most heavy atoms, or None if no such fragment exists.
379
- """
380
- if mol is None or mol.GetNumAtoms() == 0:
381
- return None
382
- fragments = Chem.GetMolFrags(mol, asMols=True)
383
- return max(fragments, key=lambda frag: frag.GetNumHeavyAtoms()) if fragments else None
384
-
385
-
386
- def contains_heavy_metals(mol: Mol) -> bool:
387
- """
388
- Check if a molecule contains any heavy metals (broad filter).
389
-
390
- Args:
391
- mol: RDKit molecule object.
392
-
393
- Returns:
394
- bool: True if any heavy metals are detected, False otherwise.
395
- """
396
- heavy_metals = {"Zn", "Cu", "Fe", "Mn", "Co", "Pb", "Hg", "Cd", "As"}
397
- return any(atom.GetSymbol() in heavy_metals for atom in mol.GetAtoms())
398
-
399
-
400
- def halogen_toxicity_score(mol: Mol) -> (int, int):
401
- """
402
- Calculate the halogen count and toxicity threshold for a molecule.
403
-
404
- Args:
405
- mol: RDKit molecule object.
406
-
407
- Returns:
408
- Tuple[int, int]: (halogen_count, halogen_threshold), where the threshold
409
- scales with molecule size (minimum of 2 or 20% of atom count).
410
- """
411
- # Define halogens and count their occurrences
412
- halogens = {"Cl", "Br", "I", "F"}
413
- halogen_count = sum(1 for atom in mol.GetAtoms() if atom.GetSymbol() in halogens)
414
-
415
- # Define threshold: small molecules tolerate fewer halogens
416
- # Threshold scales with molecule size to account for reasonable substitution
417
- molecule_size = mol.GetNumAtoms()
418
- halogen_threshold = max(2, int(molecule_size * 0.2)) # Minimum 2, scaled by 20% of molecule size
419
-
420
- return halogen_count, halogen_threshold
421
-
422
-
423
- def toxic_elements(mol: Mol) -> Optional[List[str]]:
424
- """
425
- Identifies toxic elements or specific forms of elements in a molecule.
426
-
427
- Args:
428
- mol: RDKit molecule object.
429
-
430
- Returns:
431
- Optional[List[str]]: List of toxic elements or specific forms if found, otherwise None.
432
-
433
- Notes:
434
- Halogen toxicity logic integrates with `halogen_toxicity_score` and scales thresholds
435
- based on molecule size.
436
- """
437
- # Always toxic elements (heavy metals and known toxic single elements)
438
- always_toxic = {"Pb", "Hg", "Cd", "As", "Be", "Tl", "Sb"}
439
- toxic_found = set()
440
-
441
- for atom in mol.GetAtoms():
442
- symbol = atom.GetSymbol()
443
- formal_charge = atom.GetFormalCharge()
444
-
445
- # Check for always toxic elements
446
- if symbol in always_toxic:
447
- toxic_found.add(symbol)
448
-
449
- # Conditionally toxic nitrogen (positively charged)
450
- if symbol == "N" and formal_charge > 0:
451
- # Exclude benign quaternary ammonium (e.g., choline-like structures)
452
- if mol.HasSubstructMatch(Chem.MolFromSmarts("[N+](C)(C)(C)C")): # Example benign structure
453
- continue
454
- toxic_found.add("N+")
455
-
456
- # Halogen toxicity: Uses halogen_toxicity_score to flag excessive halogenation
457
- if symbol in {"Cl", "Br", "I", "F"}:
458
- halogen_count, halogen_threshold = halogen_toxicity_score(mol)
459
- if halogen_count > halogen_threshold:
460
- toxic_found.add(symbol)
461
-
462
- return list(toxic_found) if toxic_found else None
463
-
464
-
465
- # Precompiled SMARTS patterns for custom toxic functional groups
466
- toxic_smarts_patterns = [
467
- ("C(=S)N"), # Dithiocarbamate
468
- ("P(=O)(O)(O)O"), # Phosphate Ester
469
- ("[As](=O)(=O)-[OH]"), # Arsenic Oxide
470
- ("[C](Cl)(Cl)(Cl)"), # Trichloromethyl
471
- ("[Cr](=O)(=O)=O"), # Chromium(VI)
472
- ("[N+](C)(C)(C)(C)"), # Quaternary Ammonium
473
- ("[Se][Se]"), # Diselenide
474
- ("c1c(Cl)c(Cl)c(Cl)c1"), # Trichlorinated Aromatic Ring
475
- ("[CX3](=O)[CX4][Cl,Br,F,I]"), # Halogenated Carbonyl
476
- ("[P+](C*)(C*)(C*)(C*)"), # Phosphonium Group
477
- ("NC(=S)c1c(Cl)cccc1Cl"), # Chlorobenzene Thiocarbamate
478
- ("NC(=S)Nc1ccccc1"), # Phenyl Thiocarbamate
479
- ("S=C1NCCN1"), # Thiourea Derivative
480
- ]
481
- compiled_toxic_smarts = [Chem.MolFromSmarts(smarts) for smarts in toxic_smarts_patterns]
482
-
483
- # Precompiled SMARTS patterns for exemptions
484
- exempt_smarts_patterns = [
485
- "c1ccc(O)c(O)c1", # Phenols
486
- ]
487
- compiled_exempt_smarts = [Chem.MolFromSmarts(smarts) for smarts in exempt_smarts_patterns]
488
-
489
-
490
- def toxic_groups(mol: Chem.Mol) -> Optional[List[str]]:
491
- """
492
- Check if a molecule contains known toxic functional groups using RDKit's functional groups and SMARTS patterns.
493
-
494
- Args:
495
- mol (rdkit.Chem.Mol): The molecule to evaluate.
496
-
497
- Returns:
498
- Optional[List[str]]: List of SMARTS patterns for toxic groups if found, otherwise None.
499
- """
500
- toxic_smarts_matches = []
501
-
502
- # Use RDKit's functional group definitions
503
- toxic_group_names = ["Nitro", "Azide", "Alcohol", "Aldehyde", "Halogen", "TerminalAlkyne"]
504
- for group_name in toxic_group_names:
505
- group_node = next(node for node in fgroup_hierarchy if node.label == group_name)
506
- if mol.HasSubstructMatch(Chem.MolFromSmarts(group_node.smarts)):
507
- toxic_smarts_matches.append(group_node.smarts) # Use group_node's SMARTS directly
508
-
509
- # Check for custom precompiled toxic SMARTS patterns
510
- for smarts, compiled in zip(toxic_smarts_patterns, compiled_toxic_smarts):
511
- if mol.HasSubstructMatch(compiled): # Use precompiled SMARTS
512
- toxic_smarts_matches.append(smarts)
513
-
514
- # Special handling for N+
515
- if mol.HasSubstructMatch(Chem.MolFromSmarts("[N+]")):
516
- if not mol.HasSubstructMatch(Chem.MolFromSmarts("C[N+](C)(C)C")): # Exclude benign
517
- toxic_smarts_matches.append("[N+]") # Append as SMARTS
518
-
519
- # Exempt stabilizing functional groups using precompiled patterns
520
- for compiled in compiled_exempt_smarts:
521
- if mol.HasSubstructMatch(compiled):
522
- return None
523
-
524
- return toxic_smarts_matches if toxic_smarts_matches else None
525
-
526
-
527
- def contains_metalloenzyme_relevant_metals(mol: Mol) -> bool:
528
- """
529
- Check if a molecule contains metals relevant to metalloenzymes.
530
-
531
- Args:
532
- mol: RDKit molecule object.
533
-
534
- Returns:
535
- bool: True if metalloenzyme-relevant metals are detected, False otherwise.
536
- """
537
- metalloenzyme_metals = {"Zn", "Cu", "Fe", "Mn", "Co"}
538
- return any(atom.GetSymbol() in metalloenzyme_metals for atom in mol.GetAtoms())
539
-
540
-
541
- def contains_salts(mol: Mol) -> bool:
542
- """
543
- Check if a molecule contains common salts or counterions.
544
-
545
- Args:
546
- mol: RDKit molecule object.
547
-
548
- Returns:
549
- bool: True if salts are detected, False otherwise.
550
- """
551
- # Define common inorganic salt fragments (SMARTS patterns)
552
- salt_patterns = ["[Na+]", "[K+]", "[Cl-]", "[Mg+2]", "[Ca+2]", "[NH4+]", "[SO4--]"]
553
- for pattern in salt_patterns:
554
- if mol.HasSubstructMatch(Chem.MolFromSmarts(pattern)):
555
- return True
556
- return False
557
-
558
-
559
- def is_druglike_compound(mol: Mol) -> bool:
560
- """
561
- Filter for drug-likeness and QSAR relevance based on Lipinski's Rule of Five.
562
- Returns False for molecules unlikely to be orally bioavailable.
563
-
564
- Args:
565
- mol: RDKit molecule object.
566
-
567
- Returns:
568
- bool: True if the molecule is drug-like, False otherwise.
569
- """
570
-
571
- # Lipinski's Rule of Five
572
- mw = Descriptors.MolWt(mol)
573
- logp = Descriptors.MolLogP(mol)
574
- hbd = Descriptors.NumHDonors(mol)
575
- hba = Descriptors.NumHAcceptors(mol)
576
- if mw > 500 or logp > 5 or hbd > 5 or hba > 10:
577
- return False
578
-
579
- # Allow exceptions for linear molecules that meet strict RO5 criteria
580
- if mol.GetRingInfo().NumRings() == 0:
581
- if mw <= 300 and logp <= 3 and hbd <= 3 and hba <= 3:
582
- pass # Allow small, non-cyclic druglike compounds
583
- else:
584
- return False
585
-
586
- return True
587
-
588
-
589
- def add_compound_tags(df, mol_column="molecule") -> pd.DataFrame:
590
- """
591
- Adds a 'tags' column to a DataFrame, tagging compounds based on their properties.
592
-
593
- Args:
594
- df (pd.DataFrame): Input DataFrame containing molecular data.
595
- mol_column (str): Column name containing RDKit molecule objects.
596
-
597
- Returns:
598
- pd.DataFrame: Updated DataFrame with a 'tags' column.
599
- """
600
- # Initialize the tags column
601
- df["tags"] = [[] for _ in range(len(df))]
602
- df["meta"] = [{} for _ in range(len(df))]
603
-
604
- # Process each molecule in the DataFrame
605
- for idx, row in df.iterrows():
606
- mol = row[mol_column]
607
- tags = []
608
-
609
- # Check for salts
610
- if contains_salts(mol):
611
- tags.append("salt")
612
-
613
- # Check for fragments (should be done after salt check)
614
- fragments = Chem.GetMolFrags(mol, asMols=True)
615
- if len(fragments) > 1:
616
- tags.append("frag")
617
-
618
- # Check for heavy metals
619
- if contains_heavy_metals(mol):
620
- tags.append("heavy_metals")
621
-
622
- # Check for toxic elements
623
- te = toxic_elements(mol)
624
- if te:
625
- tags.append("toxic_element")
626
- df.at[idx, "meta"]["toxic_elements"] = te
627
-
628
- # Check for toxic groups
629
- tg = toxic_groups(mol)
630
- if tg:
631
- tags.append("toxic_group")
632
- df.at[idx, "meta"]["toxic_groups"] = tg
633
-
634
- # Check for metalloenzyme-relevant metals
635
- if contains_metalloenzyme_relevant_metals(mol):
636
- tags.append("metalloenzyme")
637
-
638
- # Check for drug-likeness
639
- if is_druglike_compound(mol):
640
- tags.append("druglike")
641
-
642
- # Update tags
643
- df.at[idx, "tags"] = tags
644
-
645
- return df
646
-
647
-
648
- def compute_molecular_descriptors(df: pd.DataFrame, tautomerize=True) -> pd.DataFrame:
649
- """Compute and add all the Molecular Descriptors
650
-
651
- Args:
652
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
653
- tautomerize (bool): Whether to tautomerize the SMILES strings.
654
-
655
- Returns:
656
- pd.DataFrame: The input DataFrame with all the RDKit Descriptors added
657
- """
658
-
659
- # Check for the smiles column (any capitalization)
660
- smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
661
- if smiles_column is None:
662
- raise ValueError("Input DataFrame must have a 'smiles' column")
663
-
664
- # Compute/add all the Molecular Descriptors
665
- log.info("Computing Molecular Descriptors...")
666
-
667
- # Convert SMILES to RDKit molecule objects (vectorized)
668
- log.info("Converting SMILES to RDKit Molecules...")
669
- df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
670
-
671
- # Make sure our molecules are not None
672
- failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
673
- if failed_smiles:
674
- log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
675
- df = df.dropna(subset=["molecule"])
676
-
677
- # If we have fragments in our compounds, get the largest fragment before computing descriptors
678
- df["molecule"] = df["molecule"].apply(remove_disconnected_fragments)
679
-
680
- # Tautomerize the molecules if requested
681
- if tautomerize:
682
- log.info("Tautomerizing molecules...")
683
- tautomer_enumerator = TautomerEnumerator()
684
- df["molecule"] = df["molecule"].apply(tautomer_enumerator.Canonicalize)
685
-
686
- # Now get all the RDKIT Descriptors
687
- all_descriptors = [x[0] for x in Descriptors._descList]
688
-
689
- # There's an overflow issue that happens with the IPC descriptor, so we'll remove it
690
- # See: https://github.com/rdkit/rdkit/issues/1527
691
- if "Ipc" in all_descriptors:
692
- all_descriptors.remove("Ipc")
693
-
694
- # Make sure we don't have duplicates
695
- all_descriptors = list(set(all_descriptors))
696
-
697
- # RDKit Molecular Descriptor Calculator Class
698
- log.info("Computing RDKit Descriptors...")
699
- calc = MoleculeDescriptors.MolecularDescriptorCalculator(all_descriptors)
700
- descriptor_values = [calc.CalcDescriptors(m) for m in df["molecule"]]
701
-
702
- # Lowercase the column names
703
- column_names = [name.lower() for name in calc.GetDescriptorNames()]
704
- rdkit_features_df = pd.DataFrame(descriptor_values, columns=column_names)
705
-
706
- # Now compute Mordred Features
707
- log.info("Computing Mordred Descriptors...")
708
- descriptor_choice = [AcidBase, Aromatic, Polarizability, RotatableBond]
709
- calc = MordredCalculator()
710
- for des in descriptor_choice:
711
- calc.register(des)
712
- mordred_df = calc.pandas(df["molecule"], nproc=1)
713
-
714
- # Lowercase the column names
715
- mordred_df.columns = [col.lower() for col in mordred_df.columns]
716
-
717
- # Compute stereochemistry descriptors
718
- stereo_df = compute_stereochemistry_descriptors(df)
719
-
720
- # Combine the DataFrame with the RDKit and Mordred Descriptors added
721
- # Note: This will overwrite any existing columns with the same name. This is a good thing
722
- # since we want computed descriptors to overwrite anything in the input dataframe
723
- output_df = stereo_df.combine_first(mordred_df).combine_first(rdkit_features_df)
724
-
725
- # Ensure no duplicate column names
726
- output_df = output_df.loc[:, ~output_df.columns.duplicated()]
727
-
728
- # Reorder the columns to have all the ones in the input df first and then the descriptors
729
- input_columns = df.columns.tolist()
730
- output_df = output_df[input_columns + [col for col in output_df.columns if col not in input_columns]]
731
-
732
- # Drop the intermediate 'molecule' column
733
- del output_df["molecule"]
734
-
735
- # Return the DataFrame with the RDKit and Mordred Descriptors added
736
- return output_df
737
-
738
-
739
- def compute_stereochemistry_descriptors(df: pd.DataFrame) -> pd.DataFrame:
740
- """Compute stereochemistry descriptors for molecules in a DataFrame.
741
-
742
- This function analyzes the stereochemical properties of molecules, including:
743
- - Chiral centers (R/S configuration)
744
- - Double bond stereochemistry (E/Z configuration)
745
-
746
- Args:
747
- df (pd.DataFrame): Input DataFrame with RDKit molecule objects in 'molecule' column
748
-
749
- Returns:
750
- pd.DataFrame: DataFrame with added stereochemistry descriptors
751
- """
752
- if "molecule" not in df.columns:
753
- raise ValueError("Input DataFrame must have a 'molecule' column")
754
-
755
- log.info("Computing stereochemistry descriptors...")
756
- output_df = df.copy()
757
-
758
- # Create helper functions to process a single molecule
759
- def process_molecule(mol):
760
- if mol is None:
761
- log.warning("Found a None molecule, skipping...")
762
- return {
763
- "chiral_centers": 0,
764
- "r_cnt": 0,
765
- "s_cnt": 0,
766
- "db_stereo": 0,
767
- "e_cnt": 0,
768
- "z_cnt": 0,
769
- "chiral_fp": 0,
770
- "db_fp": 0,
771
- }
772
-
773
- try:
774
- # Use the more accurate CIP labeling algorithm (Cahn-Ingold-Prelog rules)
775
- # This assigns R/S to chiral centers and E/Z to double bonds based on
776
- # the priority of substituents (atomic number, mass, etc.)
777
- rdCIPLabeler.AssignCIPLabels(mol)
778
-
779
- # Find all potential stereochemistry sites in the molecule
780
- stereo_info = Chem.FindPotentialStereo(mol)
781
-
782
- # Initialize counters
783
- specified_centers = 0 # Number of chiral centers with defined stereochemistry
784
- r_cnt = 0 # Count of R configured centers
785
- s_cnt = 0 # Count of S configured centers
786
- stereo_atoms = [] # List to store atom indices and their R/S configuration
787
-
788
- specified_bonds = 0 # Number of double bonds with defined stereochemistry
789
- e_cnt = 0 # Count of E (trans) configured double bonds
790
- z_cnt = 0 # Count of Z (cis) configured double bonds
791
- stereo_bonds = [] # List to store bond indices and their E/Z configuration
792
-
793
- # Process all stereo information found in the molecule
794
- for element in stereo_info:
795
- # Handle tetrahedral chiral centers
796
- if element.type == Chem.StereoType.Atom_Tetrahedral:
797
- atom_idx = element.centeredOn
798
-
799
- # Only count centers where stereochemistry is explicitly defined
800
- if element.specified == Chem.StereoSpecified.Specified:
801
- specified_centers += 1
802
- if element.descriptor == Chem.StereoDescriptor.Tet_CCW:
803
- r_cnt += 1
804
- stereo_atoms.append((atom_idx, "R"))
805
- elif element.descriptor == Chem.StereoDescriptor.Tet_CW:
806
- s_cnt += 1
807
- stereo_atoms.append((atom_idx, "S"))
808
-
809
- # Handle double bond stereochemistry
810
- elif element.type == Chem.StereoType.Bond_Double:
811
- bond_idx = element.centeredOn
812
-
813
- # Only count bonds where stereochemistry is explicitly defined
814
- if element.specified == Chem.StereoSpecified.Specified:
815
- specified_bonds += 1
816
- if element.descriptor == Chem.StereoDescriptor.Bond_Trans:
817
- e_cnt += 1
818
- stereo_bonds.append((bond_idx, "E"))
819
- elif element.descriptor == Chem.StereoDescriptor.Bond_Cis:
820
- z_cnt += 1
821
- stereo_bonds.append((bond_idx, "Z"))
822
-
823
- # Calculate chiral center fingerprint - unique bit vector for stereochemical configuration
824
- chiral_fp = 0
825
- if stereo_atoms:
826
- for i, (idx, stereo) in enumerate(sorted(stereo_atoms, key=lambda x: x[0])):
827
- bit_val = 1 if stereo == "R" else 0
828
- chiral_fp += bit_val << i # Shift bits to create a unique fingerprint
829
-
830
- # Calculate double bond fingerprint - bit vector for E/Z configurations
831
- db_fp = 0
832
- if stereo_bonds:
833
- for i, (idx, stereo) in enumerate(sorted(stereo_bonds, key=lambda x: x[0])):
834
- bit_val = 1 if stereo == "E" else 0
835
- db_fp += bit_val << i # Shift bits to create a unique fingerprint
836
-
837
- return {
838
- "chiral_centers": specified_centers,
839
- "r_cnt": r_cnt,
840
- "s_cnt": s_cnt,
841
- "db_stereo": specified_bonds,
842
- "e_cnt": e_cnt,
843
- "z_cnt": z_cnt,
844
- "chiral_fp": chiral_fp,
845
- "db_fp": db_fp,
846
- }
847
-
848
- except Exception as e:
849
- log.warning(f"Error processing stereochemistry: {str(e)}")
850
- return {
851
- "chiral_centers": 0,
852
- "r_cnt": 0,
853
- "s_cnt": 0,
854
- "db_stereo": 0,
855
- "e_cnt": 0,
856
- "z_cnt": 0,
857
- "chiral_fp": 0,
858
- "db_fp": 0,
859
- }
860
-
861
- # Process all molecules and collect results
862
- results = []
863
- for mol in df["molecule"]:
864
- results.append(process_molecule(mol))
865
-
866
- # Add all descriptors to the output dataframe
867
- for key in results[0].keys():
868
- output_df[key] = [r[key] for r in results]
869
-
870
- # Boolean flag indicating if the molecule has any stereochemistry defined
871
- output_df["has_stereo"] = (output_df["chiral_centers"] > 0) | (output_df["db_stereo"] > 0)
872
-
873
- return output_df
874
-
875
-
876
- def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
877
- """Compute and add Morgan fingerprints to the DataFrame.
878
-
879
- Args:
880
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
881
- radius (int): Radius for the Morgan fingerprint.
882
- n_bits (int): Number of bits for the fingerprint.
883
- counts (bool): Count simulation for the fingerprint.
884
-
885
- Returns:
886
- pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
887
-
888
- Note:
889
- See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
890
- """
891
- delete_mol_column = False
892
-
893
- # Check for the SMILES column (case-insensitive)
894
- smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
895
- if smiles_column is None:
896
- raise ValueError("Input DataFrame must have a 'smiles' column")
897
-
898
- # Sanity check the molecule column (sometimes it gets serialized, which doesn't work)
899
- if "molecule" in df.columns and df["molecule"].dtype == "string":
900
- log.warning("Detected serialized molecules in 'molecule' column. Removing...")
901
- del df["molecule"]
902
-
903
- # Convert SMILES to RDKit molecule objects (vectorized)
904
- if "molecule" not in df.columns:
905
- log.info("Converting SMILES to RDKit Molecules...")
906
- delete_mol_column = True
907
- df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
908
- # Make sure our molecules are not None
909
- failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
910
- if failed_smiles:
911
- log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
912
- df = df.dropna(subset=["molecule"])
913
-
914
- # If we have fragments in our compounds, get the largest fragment before computing fingerprints
915
- largest_frags = df["molecule"].apply(remove_disconnected_fragments)
916
-
917
- # Create a Morgan fingerprint generator
918
- if counts:
919
- n_bits *= 4 # Multiply by 4 to simulate counts
920
- morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
921
-
922
- # Compute Morgan fingerprints (vectorized)
923
- fingerprints = largest_frags.apply(
924
- lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
925
- )
926
-
927
- # Add the fingerprints to the DataFrame
928
- df["fingerprint"] = fingerprints
929
-
930
- # Drop the intermediate 'molecule' column if it was added
931
- if delete_mol_column:
932
- del df["molecule"]
933
- return df
934
-
935
-
936
- def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
937
- """
938
- Convert bitstring fingerprints to numpy matrix.
939
-
940
- Args:
941
- fingerprints: pandas Series or list of bitstring fingerprints
942
- dtype: numpy data type (uint8 is default: np.bool_ is good for Jaccard computations
943
-
944
- Returns:
945
- dense numpy array of shape (n_molecules, n_bits)
946
- """
947
-
948
- # Dense matrix representation (we might support sparse in the future)
949
- return np.array([list(fp) for fp in fingerprints], dtype=dtype)
950
-
951
-
952
- def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame:
953
- """Project fingerprints onto a 2D plane using dimensionality reduction techniques.
954
-
955
- Args:
956
- df (pd.DataFrame): Input DataFrame containing fingerprint data.
957
- projection (str): Dimensionality reduction technique to use (TSNE or UMAP).
958
-
959
- Returns:
960
- pd.DataFrame: The input DataFrame with the projected coordinates added as 'x' and 'y' columns.
961
- """
962
- # Check for the fingerprint column (case-insensitive)
963
- fingerprint_column = next((col for col in df.columns if "fingerprint" in col.lower()), None)
964
- if fingerprint_column is None:
965
- raise ValueError("Input DataFrame must have a fingerprint column")
966
-
967
- # Create a matrix of fingerprints
968
- X = fingerprints_to_matrix(df[fingerprint_column])
969
-
970
- # Check for UMAP availability
971
- if projection == "UMAP" and umap is None:
972
- log.warning("UMAP is not available. Using TSNE instead.")
973
- projection = "TSNE"
974
-
975
- # Run the projection
976
- if projection == "TSNE":
977
- # Run TSNE on the fingerprint matrix
978
- tsne = TSNE(n_components=2, perplexity=30, random_state=42)
979
- embedding = tsne.fit_transform(X)
980
- else:
981
- # Run UMAP
982
- # reducer = umap.UMAP(densmap=True)
983
- reducer = umap.UMAP(metric="jaccard")
984
- embedding = reducer.fit_transform(X)
985
-
986
- # Add coordinates to DataFrame
987
- df["x"] = embedding[:, 0]
988
- df["y"] = embedding[:, 1]
989
-
990
- # If vertices disconnect from the manifold, they are given NaN values (so replace with 0)
991
- df["x"] = df["x"].fillna(0)
992
- df["y"] = df["y"].fillna(0)
993
-
994
- # Jitter
995
- jitter_scale = 0.1
996
- df["x"] += np.random.uniform(0, jitter_scale, len(df))
997
- df["y"] += np.random.uniform(0, jitter_scale, len(df))
998
-
999
- return df
1000
-
1001
-
1002
- def canonicalize(df: pd.DataFrame, remove_mol_col: bool = True) -> pd.DataFrame:
1003
- """
1004
- Generate RDKit's canonical SMILES for each molecule in the input DataFrame.
1005
-
1006
- Args:
1007
- df (pd.DataFrame): Input DataFrame containing a column named 'SMILES' (case-insensitive).
1008
- remove_mol_col (bool): Whether to drop the intermediate 'molecule' column. Default is True.
1009
-
1010
- Returns:
1011
- pd.DataFrame: A DataFrame with an additional 'smiles_canonical' column and,
1012
- optionally, the 'molecule' column.
1013
- """
1014
- # Identify the SMILES column (case-insensitive)
1015
- smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
1016
- if smiles_column is None:
1017
- raise ValueError("Input DataFrame must have a 'SMILES' column")
1018
-
1019
- # Convert SMILES to RDKit molecules
1020
- df["molecule"] = df[smiles_column].apply(Chem.MolFromSmiles)
1021
-
1022
- # Log invalid SMILES
1023
- invalid_indices = df[df["molecule"].isna()].index
1024
- if not invalid_indices.empty:
1025
- log.critical(f"Invalid SMILES strings at indices: {invalid_indices.tolist()}")
1026
-
1027
- # Drop rows where SMILES failed to convert to molecule
1028
- df.dropna(subset=["molecule"], inplace=True)
1029
-
1030
- # Remove disconnected fragments (keep the largest fragment)
1031
- df["molecule"] = df["molecule"].apply(lambda mol: remove_disconnected_fragments(mol) if mol else None)
1032
-
1033
- # Convert molecules to canonical SMILES (preserving isomeric information)
1034
- df["smiles_canonical"] = df["molecule"].apply(
1035
- lambda mol: Chem.MolToSmiles(mol, isomericSmiles=True) if mol else None
1036
- )
1037
-
1038
- # Drop intermediate RDKit molecule column if requested
1039
- if remove_mol_col:
1040
- df.drop(columns=["molecule"], inplace=True)
1041
-
1042
- return df
1043
-
1044
-
1045
- def custom_tautomer_canonicalization(mol: Mol) -> str:
1046
- """Domain-specific processing of a molecule to select the canonical tautomer.
1047
-
1048
- This function enumerates all possible tautomers for a given molecule and applies
1049
- custom logic to select the canonical form.
1050
-
1051
- Args:
1052
- mol (Mol): The RDKit molecule for which the canonical tautomer is to be determined.
1053
-
1054
- Returns:
1055
- str: The SMILES string of the selected canonical tautomer.
1056
- """
1057
- tautomer_enumerator = TautomerEnumerator()
1058
- enumerated_tautomers = tautomer_enumerator.Enumerate(mol)
1059
-
1060
- # Example custom logic: prioritize based on use-case specific criteria
1061
- selected_tautomer = None
1062
- highest_score = float("-inf")
1063
-
1064
- for taut in enumerated_tautomers:
1065
- # Compute custom scoring logic:
1066
- # 1. Prefer forms with fewer hydrogen bond donors (HBD) if membrane permeability is important
1067
- # 2. Penalize forms with high molecular weight for better drug-likeness
1068
- # 3. Incorporate known functional group preferences (e.g., keto > enol for binding)
1069
-
1070
- hbd = CalcNumHBD(taut) # Hydrogen Bond Donors
1071
- mw = CalcExactMolWt(taut) # Molecular Weight
1072
- aromatic_rings = taut.GetRingInfo().NumAromaticRings() # Favor aromaticity
1073
-
1074
- # Example scoring: balance HBD, MW, and aromaticity
1075
- score = -hbd - 0.01 * mw + aromatic_rings * 2
1076
-
1077
- # Update selected tautomer
1078
- if score > highest_score:
1079
- highest_score = score
1080
- selected_tautomer = taut
1081
-
1082
- # Return the SMILES of the selected tautomer
1083
- return Chem.MolToSmiles(selected_tautomer)
1084
-
1085
-
1086
- def standard_tautomer_canonicalization(mol: Mol) -> str:
1087
- """Standard processing of a molecule to select the canonical tautomer.
1088
-
1089
- RDKit's `TautomerEnumerator` uses heuristics to select a canonical tautomer,
1090
- such as preferring keto over enol forms and minimizing formal charges.
1091
-
1092
- Args:
1093
- mol (Mol): The RDKit molecule for which the canonical tautomer is to be determined.
1094
-
1095
- Returns:
1096
- str: The SMILES string of the canonical tautomer.
1097
- """
1098
- tautomer_enumerator = TautomerEnumerator()
1099
- canonical_tautomer = tautomer_enumerator.Canonicalize(mol)
1100
- return Chem.MolToSmiles(canonical_tautomer)
1101
-
1102
-
1103
- def tautomerize_smiles(df: pd.DataFrame) -> pd.DataFrame:
1104
- """
1105
- Perform tautomer enumeration and canonicalization on a DataFrame.
1106
-
1107
- Args:
1108
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
1109
-
1110
- Returns:
1111
- pd.DataFrame: A new DataFrame with additional 'smiles_canonical' and 'smiles_tautomer' columns.
1112
- """
1113
- # Standardize SMILES strings and create 'molecule' column for further processing
1114
- df = canonicalize(df, remove_mol_col=False)
1115
-
1116
- # Helper function to safely canonicalize a molecule's tautomer
1117
- def safe_tautomerize(mol):
1118
- """Safely canonicalize a molecule's tautomer, handling errors gracefully."""
1119
- if not mol:
1120
- return pd.NA
1121
- try:
1122
- # Use RDKit's standard Tautomer enumeration and canonicalization
1123
- # For custom logic, replace with custom_tautomer_canonicalization(mol)
1124
- return standard_tautomer_canonicalization(mol)
1125
- except Exception as e:
1126
- log.warning(f"Tautomerization failed: {str(e)}")
1127
- return pd.NA
1128
-
1129
- # Apply tautomer canonicalization to each molecule
1130
- df["smiles_tautomer"] = df["molecule"].apply(safe_tautomerize)
1131
-
1132
- # Drop intermediate RDKit molecule column to clean up the DataFrame
1133
- df.drop(columns=["molecule"], inplace=True)
1134
-
1135
- # Now switch the smiles columns
1136
- df.rename(columns={"smiles": "smiles_orig", "smiles_tautomer": "smiles"}, inplace=True)
1137
-
1138
- return df
1139
-
1140
-
1141
- def _get_salt_feature_columns() -> List[str]:
1142
- """Internal: Return list of all salt feature column names"""
1143
- return [
1144
- "has_salt",
1145
- "mw_ratio",
1146
- "salt_to_api_ratio",
1147
- "has_metal_salt",
1148
- "has_halide",
1149
- "ionic_strength_proxy",
1150
- "has_organic_salt",
1151
- ]
1152
-
1153
-
1154
- def _classify_salt_types(salt_frags: List[Chem.Mol]) -> Dict[str, int]:
1155
- """Internal: Classify salt fragments into categories"""
1156
- features = {
1157
- "has_organic_salt": 0,
1158
- "has_metal_salt": 0,
1159
- "has_halide": 0,
1160
- }
1161
-
1162
- for frag in salt_frags:
1163
- # Get atoms
1164
- atoms = [atom.GetSymbol() for atom in frag.GetAtoms()]
1165
-
1166
- # Metal detection
1167
- metals = ["Na", "K", "Ca", "Mg", "Li", "Zn", "Fe", "Al"]
1168
- if any(metal in atoms for metal in metals):
1169
- features["has_metal_salt"] = 1
1170
-
1171
- # Halide detection
1172
- halides = ["Cl", "Br", "I", "F"]
1173
- if any(halide in atoms for halide in halides):
1174
- features["has_halide"] = 1
1175
-
1176
- # Organic vs inorganic (simple heuristic: contains C)
1177
- if "C" in atoms:
1178
- features["has_organic_salt"] = 1
1179
-
1180
- return features
1181
-
1182
-
1183
- def extract_advanced_salt_features(
1184
- mol: Optional[Chem.Mol],
1185
- ) -> Tuple[Optional[Dict[str, Union[int, float]]], Optional[Chem.Mol]]:
1186
- """Extract comprehensive salt-related features from RDKit molecule"""
1187
- if mol is None:
1188
- return None, None
1189
-
1190
- # Get fragments
1191
- fragments = Chem.GetMolFrags(mol, asMols=True)
1192
-
1193
- # Identify API (largest organic fragment) vs salt fragments
1194
- fragment_weights = [(frag, Descriptors.MolWt(frag)) for frag in fragments]
1195
- fragment_weights.sort(key=lambda x: x[1], reverse=True)
1196
-
1197
- # Find largest organic fragment as API
1198
- api_mol = None
1199
- salt_frags = []
1200
-
1201
- for frag, mw in fragment_weights:
1202
- atoms = [atom.GetSymbol() for atom in frag.GetAtoms()]
1203
- if "C" in atoms and api_mol is None: # First organic fragment = API
1204
- api_mol = frag
1205
- else:
1206
- salt_frags.append(frag)
1207
-
1208
- # Fallback: if no organic fragments, use largest
1209
- if api_mol is None:
1210
- api_mol = fragment_weights[0][0]
1211
- salt_frags = [frag for frag, _ in fragment_weights[1:]]
1212
-
1213
- # Initialize all features with default values
1214
- features = {col: 0 for col in _get_salt_feature_columns()}
1215
- features["mw_ratio"] = 1.0 # default for no salt
1216
-
1217
- # Basic features
1218
- features.update(
1219
- {
1220
- "has_salt": int(len(salt_frags) > 0),
1221
- "mw_ratio": Descriptors.MolWt(api_mol) / Descriptors.MolWt(mol),
1222
- }
1223
- )
1224
-
1225
- if salt_frags:
1226
- # Salt characterization
1227
- total_salt_mw = sum(Descriptors.MolWt(frag) for frag in salt_frags)
1228
- features.update(
1229
- {
1230
- "salt_to_api_ratio": total_salt_mw / Descriptors.MolWt(api_mol),
1231
- "ionic_strength_proxy": sum(abs(Chem.GetFormalCharge(frag)) for frag in salt_frags),
1232
- }
1233
- )
1234
-
1235
- # Salt type classification
1236
- features.update(_classify_salt_types(salt_frags))
1237
-
1238
- return features, api_mol
1239
-
1240
-
1241
- def add_salt_features(df: pd.DataFrame) -> pd.DataFrame:
1242
- """Add salt features to dataframe with 'molecule' column containing RDKit molecules"""
1243
- salt_features_list = []
1244
-
1245
- for idx, row in df.iterrows():
1246
- mol = row["molecule"]
1247
- features, clean_mol = extract_advanced_salt_features(mol)
1248
-
1249
- if features is None:
1250
- # Handle invalid molecules
1251
- features = {col: None for col in _get_salt_feature_columns()}
1252
-
1253
- salt_features_list.append(features)
1254
-
1255
- # Convert to DataFrame and concatenate
1256
- salt_df = pd.DataFrame(salt_features_list)
1257
- return pd.concat([df, salt_df], axis=1)
1258
-
1259
-
1260
- def feature_resolution_issues(df: pd.DataFrame, features: List[str], show_cols: Optional[List[str]] = None) -> None:
1261
- """
1262
- Identify and print groups in a DataFrame where the given features have more than one unique SMILES,
1263
- sorted by group size (largest number of unique SMILES first).
1264
-
1265
- Args:
1266
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
1267
- features (List[str]): List of features to check.
1268
- show_cols (Optional[List[str]]): Columns to display; defaults to all columns.
1269
- """
1270
- # Check for the 'smiles' column (case-insensitive)
1271
- smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
1272
- if smiles_column is None:
1273
- raise ValueError("Input DataFrame must have a 'smiles' column")
1274
-
1275
- show_cols = show_cols if show_cols is not None else df.columns.tolist()
1276
-
1277
- # Drop duplicates to keep only unique SMILES for each feature combination
1278
- unique_df = df.drop_duplicates(subset=[smiles_column] + features)
1279
-
1280
- # Find groups with more than one unique SMILES
1281
- group_counts = unique_df.groupby(features).size()
1282
- collision_groups = group_counts[group_counts > 1].sort_values(ascending=False)
1283
-
1284
- # Print each group in order of size (largest first)
1285
- for group, count in collision_groups.items():
1286
- # Get the rows for this group
1287
- if isinstance(group, tuple):
1288
- group_mask = (unique_df[features] == group).all(axis=1)
1289
- else:
1290
- group_mask = unique_df[features[0]] == group
1291
-
1292
- group_df = unique_df[group_mask]
1293
-
1294
- print(f"Feature Group (unique SMILES: {count}):")
1295
- print(group_df[show_cols])
1296
- print("\n")
1297
-
1298
-
1299
- if __name__ == "__main__":
1300
- from workbench.api import DataSource
1301
-
1302
- # Set pandas display options
1303
- pd.options.display.max_columns = 20
1304
- pd.options.display.max_colwidth = 200
1305
- pd.options.display.width = 1400
1306
-
1307
- # Test data
1308
- # Create test molecules with known E/Z stereochemistry
1309
- test_smiles = [
1310
- # E (trans) examples
1311
- "C/C=C/C", # trans-2-butene
1312
- "C/C=C/Cl", # trans-2-chloro-2-butene
1313
- "ClC=CCl", # non-stereo notation
1314
- "Cl/C=C/Cl", # trans-1,2-dichloroethene
1315
- # Z (cis) examples
1316
- "C/C=C\\C", # cis-2-butene
1317
- "C/C=C\\Cl", # cis-2-chloro-2-butene
1318
- "Cl/C=C\\Cl", # cis-1,2-dichloroethene
1319
- # More complex examples
1320
- "C/C=C/C=C", # trans-2,4-hexadiene
1321
- "C/C=C\\C=C", # mix of cis and trans
1322
- "C/C=C/C=C/C", # all-trans-2,4,6-octatriene
1323
- "C/C(Cl)=C\\C", # substituted example
1324
- # Non-stereochemical double bonds
1325
- "C=C", # ethene (no stereochemistry)
1326
- "C=CC=C", # 1,3-butadiene (no specified stereochemistry)
1327
- "C1=CCCCC1", # cyclohexene (no stereochemistry possible)
1328
- # Compare with chiral centers
1329
- "C[C@H](Cl)Br", # chiral molecule
1330
- "CC(Cl)Br" # non-chiral notation
1331
- "N[C@H]1CC[C@@H](CC1)[NH2+]CCF", # From RDKIT/Github discussion example
1332
- ]
1333
-
1334
- # AQSol Smiles
1335
- aqsol_smiles = [
1336
- r"CCCCCCCC\\C=C\\CCCCCCCCNCCCNCCCNCCCN",
1337
- r"COC1=CC=C(C=C1N\\N=C1/C(=O)C(=CC2=CC=CC=C12)C(=O)NC1=CC(Cl)=CC=C1C)C(=O)NC1=CC=CC=C1",
1338
- r"NC(=O)N\\N=C\\C(O)C(O)C(O)CO",
1339
- r"C1=CC=C(C=C1)\\N=N\\C1=CC=CC=C1",
1340
- r"CC(=O)N\\N=C\\C1=CC=C(O1)[N+]([O-])=O",
1341
- r"CC(=O)OCCN(CCC#N)C1=CC=C(C=C1)\\N=N\\C1=CC=C(C=C1)[N+]([O-])=O",
1342
- r"ClC1=CC=C(Cl)C(N\\N=C2/C(=O)C(=CC3=CC=CC=C23)C(=O)NC2=CC=C3NC(=O)NC3=C2)=C1",
1343
- r"NC1=CC=C(C=C1)\\N=N\\C1=CC=CC=C1",
1344
- r"OC(=O)\\C=C/C=C\\C(O)=O",
1345
- r"CCOC(=O)\\C=C\\C1=CC=CC=C1",
1346
- r"CC(=O)\\C=C\\C1=C(C)CCCC1(C)C",
1347
- r"C\\C(=C/C(O)=O)C(O)=O",
1348
- r"CCC\\C=C\\C",
1349
- r"CC1=NN(C(=O)\\C1=N\\NC1=CC=C(C=C1Cl)C1=CC=C(N\\N=C2/C(C)=NN(C2=O)C2=CC=CC=C2)C(Cl)=C1)C1=CC=CC=C1",
1350
- r"OC(C1=CC2C3C(C1\\C2=C(\\C1=CC=CC=C1)C1=CC=CC=N1)C(=O)NC3=O)(C1=CC=CC=C1)C1=CC=CC=N1",
1351
- r"COC1=CC=C(\\C=C\\C(=O)C2=C(O)C=CC=C2)C=C1",
1352
- r"CC\\C(=C(\\CC)C1=CC=C(O)C=C1)C1=CC=C(O)C=C1",
1353
- r"C\\C=C\\OC1CCC(CC1)O\\C=C\\C",
1354
- r"CC(C)=C[C@@H]1[C@@H](C(=O)O[C@H]2CC(=O)C(C\\C=C/C=C)=C2C)C1(C)C",
1355
- r"CC\\C=C\\C",
1356
- r"COC(=O)C(\\C)=C\\[C@@H]1[C@@H](C(=O)O[C@H]2CC(=O)C(C\\C=C/C=C)=C2C)C1(C)C",
1357
- r"CC1=C(F)C(F)=C(COC(=O)C2C(\\C=C(/Cl)C(F)(F)F)C2(C)C)C(F)=C1F",
1358
- r"CCC(=O)OC\\C=C(/C)\\C=C\\C=C(/C)\\C=C\\C1=C(C)CCCC1(C)C",
1359
- r"CC(=O)C(\\C)=C/C1C(C)=CCCC1(C)C",
1360
- r"CC(=O)C(\\N=N\\C1=CC=CC=C1C(O)=O)C(=O)NC1=CC=C2NC(=O)NC2=C1",
1361
- r"O\\N=C1\\CCCC=C1",
1362
- r"CCCCCCCCCCCCCCCC(=O)NCCCCCCCC\\C=C/CCCCCCCC",
1363
- r"ClC\\C=C/CCl",
1364
- r"CC(=O)C(\\N=N\\C1=CC=C(Cl)C=C1[N+]([O-])=O)C(=O)NC1=CC=C2NC(=O)NC2=C1",
1365
- r"OC(=O)\\C=C(/Cl)C1=CC=CC=C1",
1366
- r"CC(=O)C(\\N=N\\C1=CC=C(C=C1)[N+]([O-])=O)C(=O)NC1=CC=C2NC(=O)NC2=C1",
1367
- r"CC\\C=C/CCO",
1368
- ]
1369
- all_smiles = test_smiles + aqsol_smiles
1370
-
1371
- # Create molecules
1372
- mols = [Chem.MolFromSmiles(s) for s in all_smiles]
1373
-
1374
- # Create test dataframe
1375
- df = pd.DataFrame({"smiles": all_smiles, "molecule": mols})
1376
-
1377
- # Test Stereochemistry Descriptors
1378
- # See: https://github.com/rdkit/rdkit/discussions/6567
1379
- df = compute_stereochemistry_descriptors(df)
1380
- # Print all the columns except molecule
1381
- print(df.drop(columns=["molecule"]))
1382
-
1383
- # Toxicity tests
1384
- smiles = "O=C(CCl)c1ccc(Cl)cc1Cl"
1385
- mol = Chem.MolFromSmiles(smiles)
1386
- print(toxic_elements(mol))
1387
- print(toxic_groups(mol))
1388
-
1389
- # Pyridone molecule
1390
- smiles = "C1=CC=NC(=O)C=C1"
1391
- show(smiles)
1392
-
1393
- # SVG image of the molecule
1394
- svg = svg_from_smiles(smiles)
1395
-
1396
- # PIL image of the molecule
1397
- img = img_from_smiles(smiles)
1398
- print(type(img))
1399
-
1400
- # Test the concentration conversion functions
1401
- df = pd.DataFrame({"smiles": [smiles, smiles, smiles, smiles, smiles, smiles], "µM": [500, 50, 5, 1, 0.1, 0]})
1402
-
1403
- # Convert µM to log10
1404
- df["log10"] = micromolar_to_log(df["µM"])
1405
- print(df)
1406
-
1407
- # Convert log10 back to µM
1408
- df["µM_new"] = log_to_micromolar(df["log10"])
1409
- print(df)
1410
-
1411
- # Convert log10 to categories
1412
- df["category"] = log_to_category(df["log10"])
1413
- print(df)
1414
-
1415
- # Test drug-likeness filter and print results
1416
- druglike_smiles = ["CC(C)=CCC\\C(C)=C/CO", "CC(C)CCCCCOC(=O)CCS", "OC(=O)CCCCCCCCC=C", "CC(C)(C)CCCCCC(=O)OC=C"]
1417
- mols = [Chem.MolFromSmiles(smile) for smile in druglike_smiles]
1418
- druglike = [is_druglike_compound(mol) for mol in mols]
1419
-
1420
- for smile, is_druglike in zip(druglike_smiles, druglike):
1421
- print(f"SMILES: {smile} -> Drug-like: {is_druglike}")
1422
-
1423
- # Test mol/None issue
1424
- df = DataSource("aqsol_data").pull_dataframe()[:100]
1425
- mol_df = compute_molecular_descriptors(df)
1426
-
1427
- # Compute Molecular Descriptors
1428
- df = pd.DataFrame({"smiles": [smiles, smiles, smiles, smiles, smiles]})
1429
- df = compute_molecular_descriptors(df)
1430
- print(df)
1431
-
1432
- df = DataSource("aqsol_data").pull_dataframe()[:1000]
1433
-
1434
- # Test compound tags
1435
- df["molecule"] = df["smiles"].apply(Chem.MolFromSmiles)
1436
- df = add_compound_tags(df)
1437
-
1438
- # Compute Molecular Descriptors
1439
- df = compute_molecular_descriptors(df)
1440
- print(df)
1441
-
1442
- # Compute Morgan Fingerprints
1443
- df = compute_morgan_fingerprints(df)
1444
- print(df)
1445
-
1446
- # Debug a few compounds that have issues
1447
- debug_smiles = {
1448
- "id": ["B-1579", "B-1866"],
1449
- "smiles": ["CC1=CC=C[NH++]([O-])[CH-]1", "OC(=O)C1=C[NH++]([O-])[CH-]C=C1"],
1450
- }
1451
- debug_df = pd.DataFrame(debug_smiles)
1452
- debug_df = compute_morgan_fingerprints(debug_df)
1453
- print(debug_df)
1454
-
1455
- # Project Fingerprints
1456
- df = project_fingerprints(df, projection="UMAP")
1457
- print(df)
1458
-
1459
- # Perform Tautomerization
1460
- df = tautomerize_smiles(df)
1461
- print(df)
1462
-
1463
- # Test Rollup Experimental Data
1464
- test_data = {
1465
- "id": ["1", "1", "2", "2", "3", "4", "4", "5", "5", "6", "6"],
1466
- "time_hr": [1, 4, 3, 3, 2, np.nan, 5, 6, 6, np.nan, np.nan],
1467
- "target_value": [1.90, 4.03, 2.5, 3.5, 7.8, 6.2, 8.1, np.nan, 5.4, 6.7, 6.9],
1468
- "smiles": [
1469
- "CC(=O)O", # Acetic acid
1470
- "CC(=O)O",
1471
- "C1CCCCC1", # Cyclohexane
1472
- "C1CCCCC1",
1473
- "C1=CC=CC=C1", # Benzene
1474
- "CCO", # Ethanol
1475
- "CCO",
1476
- "CC(C)=O", # Acetone
1477
- "CC(C)=O",
1478
- "CC(C)=O",
1479
- "CC(C)=O",
1480
- ],
1481
- }
1482
-
1483
- # Create test DataFrame
1484
- test_df = pd.DataFrame(test_data)
1485
- print("Original Test DataFrame:")
1486
- print(test_df)
1487
- print("\n")
1488
-
1489
- # Test with arithmetic mean
1490
- result_df = rollup_experimental_data(test_df, id="id", time="time_hr", target="target_value", use_gmean=False)
1491
- print("Result with Arithmetic Mean:")
1492
- print(result_df)
1493
- print("\n")
1494
-
1495
- # Test with geometric mean
1496
- result_df_gmean = rollup_experimental_data(test_df, id="id", time="time_hr", target="target_value", use_gmean=True)
1497
- print("Result with Geometric Mean:")
1498
- print(result_df_gmean)
1499
-
1500
- # Test some salted compounds
1501
- test_data = {
1502
- "id": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
1503
- "target_value": [1.90, 4.03, 2.5, 3.5, 7.8, 6.2, 8.1, 6.9, 5.4, 3.2, 4.8, 7.1],
1504
- "smiles": [
1505
- "CC(=O)O", # Acetic acid (no salt)
1506
- "C1CCCCC1", # Cyclohexane (no salt)
1507
- "C1=CC=CC=C1", # Benzene (no salt)
1508
- "CC(=O)O.[K+]", # Potassium acetate (metal cation)
1509
- "CC(=O)O.[Ca+2]", # Calcium acetate (metal cation)
1510
- "CC(=O)O.[Na+]", # Sodium acetate (metal cation)
1511
- "CCO.Cl", # Ethanol hydrochloride (halide anion)
1512
- "C1=CC=CC=C1.O.O", # Benzene hydrate (inorganic)
1513
- "CC(=O)[O-].C[NH3+]", # Methylammonium acetate (organic anion + cation)
1514
- "c1ccc(cc1)[NH3+].[Cl-]", # Aniline HCl (organic cation, halide anion)
1515
- "CC(=O)[O-].CC[NH3+]", # Ethylammonium acetate (organic anion + cation)
1516
- "CCO.[Br-].[Na+]", # Multiple salt components
1517
- ],
1518
- }
1519
-
1520
- # Create test DataFrame
1521
- test_df = pd.DataFrame(test_data)
1522
-
1523
- # Convert SMILES to molecules
1524
- test_df["molecule"] = test_df["smiles"].apply(Chem.MolFromSmiles)
1525
-
1526
- # Test individual function
1527
- print("Testing individual salt feature extraction:")
1528
- for i, row in test_df.iterrows():
1529
- if i < 3: # Test first few
1530
- features, clean_mol = extract_advanced_salt_features(row["molecule"])
1531
- print(f"SMILES: {row['smiles']}")
1532
- print(f"Features: {features}")
1533
- print(f"Clean mol atoms: {clean_mol.GetNumAtoms() if clean_mol else 'None'}")
1534
- print("---")
1535
-
1536
- # Test full DataFrame processing
1537
- print("\nTesting DataFrame processing:")
1538
- result_df = add_salt_features(test_df)
1539
-
1540
- # Display results focusing on salt-related columns
1541
- salt_cols = _get_salt_feature_columns()
1542
- display_cols = ["smiles"] + salt_cols
1543
- print(result_df[display_cols].to_string())
1544
-
1545
- # Summary stats
1546
- print(f"\nDataFrame shape before: {test_df.shape}")
1547
- print(f"DataFrame shape after: {result_df.shape}")
1548
- print(f"Compounds with salts: {result_df['has_salt'].sum()}")
1549
-
1550
- # Test the SDF file writing
1551
- my_sdf_file = "test_compounds.sdf"
1552
- df_to_sdf_file(test_df, my_sdf_file, skip_invalid=False)
1553
-
1554
- # Test the SDF file reading
1555
- df = sdf_file_to_df(my_sdf_file)
1556
- print(df)