workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of workbench might be problematic. Click here for more details.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
- workbench/algorithms/dataframe/proximity.py +261 -235
- workbench/algorithms/graph/light/proximity_graph.py +10 -8
- workbench/api/__init__.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/endpoint.py +11 -0
- workbench/api/feature_set.py +11 -8
- workbench/api/meta.py +5 -2
- workbench/api/model.py +16 -15
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +11 -3
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/endpoint_core.py +256 -118
- workbench/core/artifacts/feature_set_core.py +265 -16
- workbench/core/artifacts/model_core.py +107 -60
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +42 -32
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_scripts/chemprop/chemprop.template +852 -0
- workbench/model_scripts/chemprop/generated_model_script.py +852 -0
- workbench/model_scripts/chemprop/requirements.txt +11 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
- workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
- workbench/model_scripts/pytorch_model/pytorch.template +370 -187
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +17 -9
- workbench/model_scripts/uq_models/generated_model_script.py +605 -0
- workbench/model_scripts/uq_models/mapie.template +605 -0
- workbench/model_scripts/uq_models/requirements.txt +1 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
- workbench/model_scripts/xgb_model/xgb_model.template +44 -46
- workbench/repl/workbench_shell.py +28 -14
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +760 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/model_utils.py +95 -34
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +526 -0
- workbench/utils/shap_utils.py +10 -2
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_model_utils.py +371 -156
- workbench/web_interface/components/model_plot.py +7 -1
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +9 -7
- workbench/web_interface/components/plugins/scatter_plot.py +3 -3
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
"""
|
|
2
|
+
mol_tagging.py - Molecular property tagging for ADMET modeling
|
|
3
|
+
Adds a 'tags' column to DataFrames for filtering and classification
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from typing import List, Set, Optional
|
|
8
|
+
import pandas as pd
|
|
9
|
+
from rdkit import Chem
|
|
10
|
+
from rdkit.Chem import Mol, Descriptors
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# ============================================================================
|
|
16
|
+
# Property Detection Functions (Internal)
|
|
17
|
+
# ============================================================================
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _get_metal_tags(mol: Mol) -> Set[str]:
    """Return metal-content tags for ``mol``.

    Tags emitted:
      - "metalloenzyme_metal": at least one of Zn, Cu, Fe, Mn, Co, Ni, Mo, V is present
      - "heavy_metal": at least one element from the heavy/toxic list below is present

    A ``None`` molecule yields an empty set.
    """
    if mol is None:
        return set()

    # Metals associated with metalloenzymes
    metalloenzyme_metals = {"Zn", "Cu", "Fe", "Mn", "Co", "Ni", "Mo", "V"}
    # Heavy/toxic elements
    heavy_metals = {"Pb", "Hg", "Cd", "As", "Cr", "Tl", "Ba", "Be", "Al", "Sb", "Se", "Bi", "Ag"}

    # Distinct element symbols in the molecule; one pass over the atoms
    present = {atom.GetSymbol() for atom in mol.GetAtoms()}

    found = set()
    if present & metalloenzyme_metals:
        found.add("metalloenzyme_metal")
    if present & heavy_metals:
        found.add("heavy_metal")
    return found
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_halogen_tags(mol: Mol) -> Set[str]:
    """Return halogenation tags for ``mol``.

    Tags emitted:
      - "halogenated": the molecule contains at least one F, Cl, Br or I atom
      - "highly_halogenated": more than 4 halogen atoms, or halogens make up
        more than half of the heavy atoms

    A ``None`` molecule yields an empty set.
    """
    if mol is None:
        return set()

    halogens = ("F", "Cl", "Br", "I")
    n_halogen = sum(atom.GetSymbol() in halogens for atom in mol.GetAtoms())

    found = set()
    if n_halogen:
        found.add("halogenated")

    # Guard against division by zero for molecules with no heavy atoms
    n_heavy = mol.GetNumHeavyAtoms()
    if n_heavy > 0 and (n_halogen > 4 or n_halogen / n_heavy > 0.5):
        found.add("highly_halogenated")

    return found
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_druglike_tags(mol: Mol) -> Set[str]:
    """Return drug-likeness and size tags for ``mol``.

    Uses RDKit ``Descriptors`` (MolWt, MolLogP, NumHDonors, NumHAcceptors,
    NumRotatableBonds, TPSA). Tags emitted:
      - "ro5_pass":      at most one Lipinski Rule-of-Five violation
                         (MW > 500, LogP > 5, HBD > 5, HBA > 10)
      - "ro5_strict":    zero Rule-of-Five violations
      - "veber_pass":    rotatable bonds <= 10 and TPSA <= 140
      - "lead_like":     150 <= MW <= 350 and -3 <= LogP <= 3.5
      - "fragment_like": MW <= 300, LogP <= 3, HBD <= 3, HBA <= 3 (Rule of Three)
      - "small_molecule" (MW < 200) or "large_molecule" (MW > 700)

    A ``None`` molecule yields an empty set.
    """
    if mol is None:
        return set()

    # Compute every descriptor once up front
    mol_weight = Descriptors.MolWt(mol)
    log_p = Descriptors.MolLogP(mol)
    h_donors = Descriptors.NumHDonors(mol)
    h_acceptors = Descriptors.NumHAcceptors(mol)
    rot_bonds = Descriptors.NumRotatableBonds(mol)
    polar_sa = Descriptors.TPSA(mol)

    # Each true condition is one Lipinski violation
    violations = sum((mol_weight > 500, log_p > 5, h_donors > 5, h_acceptors > 10))

    found = set()
    if violations <= 1:
        found.add("ro5_pass")
    if violations == 0:
        found.add("ro5_strict")

    if rot_bonds <= 10 and polar_sa <= 140:
        found.add("veber_pass")

    if 150 <= mol_weight <= 350 and -3 <= log_p <= 3.5:
        found.add("lead_like")

    if mol_weight <= 300 and log_p <= 3 and h_donors <= 3 and h_acceptors <= 3:
        found.add("fragment_like")

    # Size buckets are mutually exclusive; 200 <= MW <= 700 gets neither
    if mol_weight < 200:
        found.add("small_molecule")
    elif mol_weight > 700:
        found.add("large_molecule")

    return found
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _get_structural_tags(mol: Mol) -> Set[str]:
    """Return structural-feature tags for ``mol``.

    Tags emitted:
      - "multi_fragment": more than one disconnected fragment (e.g. salts, mixtures)
      - "acyclic" or "cyclic": whether RDKit ring info reports any rings
      - "aromatic": some ring contains an aromatic atom (only checked when cyclic)
      - "chiral": some atom carries a chiral tag other than CHI_UNSPECIFIED.
        Presumably this means stereocenters whose stereo is not specified
        in the input are not flagged.

    A ``None`` molecule yields an empty set.
    """
    if mol is None:
        return set()

    found = set()

    if len(Chem.GetMolFrags(mol)) > 1:
        found.add("multi_fragment")

    ring_info = mol.GetRingInfo()
    if ring_info.NumRings():
        found.add("cyclic")
        has_aromatic_ring = any(
            mol.GetAtomWithIdx(atom_idx).GetIsAromatic() for ring in ring_info.AtomRings() for atom_idx in ring
        )
        if has_aromatic_ring:
            found.add("aromatic")
    else:
        found.add("acyclic")

    unspecified = Chem.ChiralType.CHI_UNSPECIFIED
    if any(atom.GetChiralTag() != unspecified for atom in mol.GetAtoms()):
        found.add("chiral")

    return found
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# ============================================================================
|
|
145
|
+
# Main Tagging Function
|
|
146
|
+
# ============================================================================
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def tag_molecules(
    df: pd.DataFrame,
    smiles_column: str = "smiles",
    tag_column: str = "tags",
    tag_categories: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Add molecular property tags to a DataFrame.

    Designed to work after mol_standardize.py processing.
    Adds a single tags column (``tag_column``) containing a sorted list of string tags.
    Rows whose SMILES is missing, empty, not a string, or unparsable by RDKit
    get the single tag "invalid_smiles".

    Args:
        df: Input DataFrame with SMILES (not modified; a copy is returned)
        smiles_column: Column containing SMILES strings
        tag_column: Name for output tags column (default: "tags")
        tag_categories: Which tag categories to include (a list, or a single
            category name). Options:
            - "metals": Metal content tags
            - "halogens": Halogenation tags
            - "druglike": Drug-likeness assessments
            - "structure": Structural features
            - None (default): Include all categories
            Unknown category names are ignored with a logged warning.

    Returns:
        DataFrame with tags column added

    Raises:
        KeyError: If ``smiles_column`` is not a column of ``df``.

    Example:
        df = tag_molecules(df)  # Add all tags
        df = tag_molecules(df, tag_categories=["druglike"])  # Only drug-likeness
    """
    # Category name -> tagging helper (each returns a set of tag strings)
    category_taggers = {
        "metals": _get_metal_tags,
        "halogens": _get_halogen_tags,
        "druglike": _get_druglike_tags,
        "structure": _get_structural_tags,
    }

    # Default to all categories; accept a bare string for a single category
    if tag_categories is None:
        tag_categories = list(category_taggers)
    elif isinstance(tag_categories, str):
        tag_categories = [tag_categories]

    unknown = [name for name in tag_categories if name not in category_taggers]
    if unknown:
        logger.warning("Ignoring unknown tag categories: %s", unknown)
    taggers = [tagger for name, tagger in category_taggers.items() if name in tag_categories]

    result = df.copy()
    all_tags: List[List[str]] = []

    for smiles in result[smiles_column]:
        # Missing (None/NaN), empty, or non-string values cannot be parsed; RDKit
        # raises on non-string input rather than returning None.
        if not isinstance(smiles, str) or not smiles:
            all_tags.append(["invalid_smiles"])
            continue

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            all_tags.append(["invalid_smiles"])
            continue

        tags: Set[str] = set()
        for tagger in taggers:
            tags.update(tagger(mol))

        # Sorted list for deterministic output
        all_tags.append(sorted(tags))

    result[tag_column] = all_tags

    # Log summary
    total = len(result)
    valid = sum(1 for tags in all_tags if "invalid_smiles" not in tags)
    ro5_pass = sum(1 for tags in all_tags if "ro5_pass" in tags)
    logger.info("Tagged %d molecules: %d valid, %d pass Ro5", total, valid, ro5_pass)

    return result
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# ============================================================================
|
|
233
|
+
# Utility Functions
|
|
234
|
+
# ============================================================================
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def filter_by_tags(
    df: pd.DataFrame, require: Optional[List[str]] = None, exclude: Optional[List[str]] = None, tag_column: str = "tags"
) -> pd.DataFrame:
    """
    Filter DataFrame rows based on tags.

    Args:
        df: DataFrame with tags column
        require: Tags that must be present (AND logic). A single tag string is also accepted.
        exclude: Tags that must not be present. A single tag string is also accepted.
        tag_column: Name of tags column

    Returns:
        Filtered DataFrame (always a new, independent copy; columns are preserved
        even when no rows match). Rows whose tag entry is missing (None/NaN) are
        treated as having no tags.

    Example:
        # Get drug-like molecules without heavy metals
        filtered = filter_by_tags(df,
                                  require=["ro5_pass"],
                                  exclude=["heavy_metal"])
    """
    # Normalize filters; a bare string would otherwise be iterated character by character
    required = [require] if isinstance(require, str) else list(require or [])
    excluded = [exclude] if isinstance(exclude, str) else list(exclude or [])

    if not required and not excluded:
        result = df.copy()
    else:

        def _keep(tags) -> bool:
            # Missing entries (None/NaN) support no membership test: treat them as tagless
            if not hasattr(tags, "__contains__"):
                tags = ()
            return all(tag in tags for tag in required) and not any(tag in tags for tag in excluded)

        # One explicit bool mask aligned to df's index. The bool dtype matters for empty
        # frames: an empty object-dtype mask is interpreted by pandas as a list of column
        # labels, which silently dropped every column (and then raised KeyError).
        mask = pd.Series([_keep(tags) for tags in df[tag_column]], index=df.index, dtype=bool)
        result = df.loc[mask].copy()

    logger.info("Filtered %d → %d molecules", len(df), len(result))

    return result
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def get_tag_summary(df: pd.DataFrame, tag_column: str = "tags") -> pd.Series:
    """
    Get summary statistics of tags in DataFrame.

    Args:
        df: DataFrame with tags column
        tag_column: Name of tags column

    Returns:
        Series with tag counts (tag -> number of occurrences), sorted by count
        descending. Rows with a missing tag entry (None/NaN) contribute nothing.
    """
    # Flatten all tags; skip non-iterable entries such as None/NaN, which would
    # otherwise raise TypeError
    all_tags = [tag for tags in df[tag_column] if hasattr(tags, "__iter__") for tag in tags]

    # Explicit object dtype keeps an empty result well-typed (no default-dtype warning)
    return pd.Series(all_tags, dtype=object).value_counts()
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
if __name__ == "__main__":
    # Smoke test: tag a small, deliberately varied set of SMILES, then
    # exercise the filtering and summary helpers on the result.
    separator = "=" * 60

    print("Testing molecular tagging system")
    print(separator)

    demo_smiles = [
        "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin
        "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
        "C" * 50,  # Large alkane
        "C(Cl)(Cl)(Cl)Cl",  # Carbon tetrachloride
        "[Zn+2].[Cl-].[Cl-]",  # Zinc chloride
        "CCC",  # Propane
        "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # Ibuprofen
        "[Pb+2].[O-]C(=O)C",  # Lead acetate
        "",  # Empty
        "INVALID_SMILES",  # Invalid
    ]
    demo_ids = [f"C{number:03d}" for number in range(1, len(demo_smiles) + 1)]
    demo_df = pd.DataFrame({"smiles": demo_smiles, "compound_id": demo_ids})

    print("Input data:")
    print(demo_df[["compound_id", "smiles"]])

    print("\n" + separator)
    print("Applying molecular tags...")
    tagged = tag_molecules(demo_df)

    print("\nTagged results:")
    for compound_id, tag_list in zip(tagged["compound_id"], tagged["tags"]):
        rendered = ", ".join(tag_list) if tag_list else "none"
        print(f"{compound_id}: {rendered}")

    print("\n" + separator)
    print("Testing filters...")

    # Molecules with at most one Lipinski violation
    passing = filter_by_tags(tagged, require=["ro5_pass"])
    print(f"Drug-like molecules: {list(passing['compound_id'])}")

    # Drop heavy-metal, heavily halogenated and unparsable entries
    kept = filter_by_tags(tagged, exclude=["heavy_metal", "highly_halogenated", "invalid_smiles"])
    print(f"Clean molecules: {list(kept['compound_id'])}")

    print("\n" + separator)
    print("Tag summary:")
    print(get_tag_summary(tagged).head(10))

    print("\n✅ All tests completed!")
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
"""Dimensionality reduction and projection utilities for molecular fingerprints"""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from sklearn.manifold import TSNE
|
|
7
|
+
|
|
8
|
+
# Try importing UMAP
|
|
9
|
+
try:
|
|
10
|
+
import umap
|
|
11
|
+
except ImportError:
|
|
12
|
+
umap = None
|
|
13
|
+
|
|
14
|
+
# Set up the logger
|
|
15
|
+
log = logging.getLogger("workbench")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
    """
    Convert bitstring fingerprints to numpy matrix.

    Args:
        fingerprints: pandas Series or list of bitstring fingerprints (e.g. "0101..."),
            or of equal-length sequences of 0/1 values
        dtype: numpy data type of the result (uint8 is default; np.bool_ is good for
            Jaccard computations)

    Returns:
        dense numpy array of shape (n_molecules, n_bits)
    """
    # Dense matrix representation (we might support sparse in the future).
    # Parse the '0'/'1' characters into integers first, then cast. Casting the
    # characters straight to a non-numeric dtype such as np.bool_ goes by string
    # truthiness, which turns '0' into True.
    bits = np.array([list(fp) for fp in fingerprints], dtype=np.uint8)
    return bits.astype(dtype, copy=False)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame:
    """Project fingerprints onto a 2D plane using dimensionality reduction techniques.

    Args:
        df (pd.DataFrame): Input DataFrame containing fingerprint data. The first column whose
            name contains "fingerprint" (case-insensitive) is used. The DataFrame is modified
            in place ('x' and 'y' columns are added) and also returned.
        projection (str): Dimensionality reduction technique to use ("TSNE" or "UMAP",
            case-insensitive). Falls back to TSNE when UMAP is not installed.

    Returns:
        pd.DataFrame: The input DataFrame with the projected coordinates added as 'x' and 'y' columns.
            Datasets with fewer than 4 rows get random coordinates in [-10, 10) instead.

    Raises:
        ValueError: If no fingerprint column is found, or ``projection`` is not TSNE/UMAP.
    """
    # Check for the fingerprint column (case-insensitive; tolerate non-string column labels)
    fingerprint_column = next((col for col in df.columns if "fingerprint" in str(col).lower()), None)
    if fingerprint_column is None:
        raise ValueError("Input DataFrame must have a fingerprint column")

    # Validate the technique up front. Unrecognized values used to fall through to the
    # UMAP branch, which crashed when UMAP was not installed.
    method = str(projection).upper()
    if method not in ("TSNE", "UMAP"):
        raise ValueError(f"Unsupported projection '{projection}': expected 'TSNE' or 'UMAP'")

    # Create a matrix of fingerprints
    X = fingerprints_to_matrix(df[fingerprint_column])
    n_samples = X.shape[0]

    # Check for UMAP availability
    if method == "UMAP" and umap is None:
        log.warning("UMAP is not available. Using TSNE instead.")
        method = "TSNE"

    # Neither technique works on tiny datasets (TSNE needs perplexity < n_samples, UMAP
    # needs n_neighbors >= 2). Return random coordinates instead.
    if n_samples < 4:
        log.warning(f"Dataset too small for {method} (n={n_samples}). Need at least 4 samples.")
        df["x"] = np.random.uniform(-10, 10, n_samples)
        df["y"] = np.random.uniform(-10, 10, n_samples)
        return df

    if method == "TSNE":
        # Perplexity must be less than n_samples
        perplexity = min(30, n_samples - 1)
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        embedding = tsne.fit_transform(X)
    else:
        # n_neighbors must be smaller than the number of samples
        n_neighbors = min(15, n_samples - 1)
        reducer = umap.UMAP(metric="jaccard", n_neighbors=n_neighbors)
        embedding = reducer.fit_transform(X)

    # Add coordinates to DataFrame
    df["x"] = embedding[:, 0]
    df["y"] = embedding[:, 1]

    # If vertices disconnect from the manifold, they are given NaN values (so replace with 0)
    df["x"] = df["x"].fillna(0)
    df["y"] = df["y"].fillna(0)

    # Small jitter so that identical fingerprints do not overlap exactly
    jitter_scale = 0.1
    df["x"] += np.random.uniform(0, jitter_scale, len(df))
    df["y"] += np.random.uniform(0, jitter_scale, len(df))

    return df
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
if __name__ == "__main__":
    # Manual smoke tests for the projection helpers (requires RDKit; UMAP optional)
    print("Running molecular projection tests...")

    from rdkit import Chem
    from rdkit.Chem import rdFingerprintGenerator

    # Test molecules
    test_molecules = {
        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
        "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",
        "sodium_acetate": "CC(=O)[O-].[Na+]",
        "benzene": "c1ccccc1",
        "toluene": "Cc1ccccc1",
        "phenol": "Oc1ccccc1",
        "aniline": "Nc1ccccc1",
    }

    # Test 1: Generate fingerprints
    print("\n1. Generating test fingerprints...")

    test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})

    # Generate Morgan fingerprints as bitstrings (None for unparsable SMILES)
    mols = [Chem.MolFromSmiles(smi) for smi in test_df["SMILES"]]
    morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=512)
    fingerprints = [morgan_gen.GetFingerprint(mol).ToBitString() if mol else None for mol in mols]
    test_df["fingerprint"] = fingerprints

    # Remove any failed molecules
    test_df = test_df.dropna(subset=["fingerprint"])
    print(f" Generated {len(test_df)} fingerprints")

    # Test 2: Fingerprint to matrix conversion
    print("\n2. Testing fingerprint matrix conversion...")

    matrix = fingerprints_to_matrix(test_df["fingerprint"])
    print(f" Matrix shape: {matrix.shape}")
    print(f" Matrix dtype: {matrix.dtype}")
    print(f" Non-zero elements: {np.count_nonzero(matrix)}")

    # Test 3: TSNE projection
    print("\n3. Testing TSNE projection...")

    try:
        proj_df = project_fingerprints(test_df.copy(), projection="TSNE")

        print(" TSNE projection results:")
        for _, row in proj_df.head(4).iterrows():
            print(f" {row['name']:15} → x:{row['x']:7.2f} y:{row['y']:7.2f}")

        # Check that coordinates were added
        assert "x" in proj_df.columns and "y" in proj_df.columns
        print(f" ✓ Successfully projected {len(proj_df)} molecules")

    except Exception as e:
        print(f" Note: TSNE projection test limited: {e}")

    # Test 4: UMAP projection (if available)
    print("\n4. Testing UMAP projection...")

    if umap is not None:
        try:
            proj_umap_df = project_fingerprints(test_df.copy(), projection="UMAP")

            print(" UMAP projection results:")
            for _, row in proj_umap_df.head(4).iterrows():
                print(f" {row['name']:15} → x:{row['x']:7.2f} y:{row['y']:7.2f}")

            print(f" ✓ Successfully projected {len(proj_umap_df)} molecules with UMAP")

        except Exception as e:
            print(f" Note: UMAP projection failed: {e}")
    else:
        print(" UMAP not available - skipping test")

    # Test 5: Edge cases
    print("\n5. Testing edge cases...")

    # Test with missing fingerprint column
    no_fp_df = pd.DataFrame({"SMILES": ["CCO", "CC"]})
    try:
        project_fingerprints(no_fp_df)
        print(" ✗ Should have raised error for missing fingerprint column")
    except ValueError as e:
        print(f" ✓ Correctly raised error for missing fingerprint: {str(e)}")

    # Test with small dataset (less than perplexity)
    small_df = test_df.head(2).copy()
    if len(small_df) > 0:
        try:
            project_fingerprints(small_df, projection="TSNE")
            print(" Note: Small dataset projection handled")
        except Exception as e:
            print(f" Note: Small dataset appropriately failed: {type(e).__name__}")

    # Test 6: Testing NaN value handling
    print("\n6. Testing NaN value handling...")

    try:
        # The projection should handle NaN values by replacing with 0
        proj_test = project_fingerprints(test_df.copy(), projection="TSNE")
        has_nan = proj_test[["x", "y"]].isnull().any().any()
        print(f" NaN values in output: {has_nan}")
        # Only report success when the output really is NaN-free
        if has_nan:
            print(" ✗ NaN values remain in the projected coordinates")
        else:
            print(" ✓ NaN values properly handled")
    except Exception as e:
        print(f" Note: Could not test NaN handling due to: {e}")

    print("\n✅ All projection tests completed!")
|