workbench-0.8.162-py3-none-any.whl → workbench-0.8.220-py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release: this version of workbench has been flagged as potentially problematic.
- workbench/algorithms/dataframe/__init__.py +1 -2
- workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
- workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
- workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
- workbench/algorithms/dataframe/projection_2d.py +44 -21
- workbench/algorithms/dataframe/proximity.py +259 -305
- workbench/algorithms/graph/light/proximity_graph.py +14 -12
- workbench/algorithms/models/cleanlab_model.py +382 -0
- workbench/algorithms/models/noise_model.py +388 -0
- workbench/algorithms/sql/outliers.py +3 -3
- workbench/api/__init__.py +5 -1
- workbench/api/compound.py +1 -1
- workbench/api/df_store.py +17 -108
- workbench/api/endpoint.py +18 -5
- workbench/api/feature_set.py +121 -15
- workbench/api/meta.py +5 -2
- workbench/api/meta_model.py +289 -0
- workbench/api/model.py +55 -21
- workbench/api/monitor.py +1 -16
- workbench/api/parameter_store.py +3 -52
- workbench/cached/cached_model.py +4 -4
- workbench/core/artifacts/__init__.py +11 -2
- workbench/core/artifacts/artifact.py +16 -8
- workbench/core/artifacts/data_capture_core.py +355 -0
- workbench/core/artifacts/df_store_core.py +114 -0
- workbench/core/artifacts/endpoint_core.py +382 -253
- workbench/core/artifacts/feature_set_core.py +249 -45
- workbench/core/artifacts/model_core.py +135 -80
- workbench/core/artifacts/monitor_core.py +33 -248
- workbench/core/artifacts/parameter_store_core.py +98 -0
- workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
- workbench/core/cloud_platform/aws/aws_meta.py +12 -5
- workbench/core/cloud_platform/aws/aws_session.py +4 -4
- workbench/core/pipelines/pipeline_executor.py +1 -1
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/core/transforms/features_to_model/features_to_model.py +62 -40
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
- workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
- workbench/core/views/training_view.py +113 -42
- workbench/core/views/view.py +53 -3
- workbench/core/views/view_utils.py +4 -4
- workbench/model_script_utils/model_script_utils.py +339 -0
- workbench/model_script_utils/pytorch_utils.py +405 -0
- workbench/model_script_utils/uq_harness.py +278 -0
- workbench/model_scripts/chemprop/chemprop.template +649 -0
- workbench/model_scripts/chemprop/generated_model_script.py +649 -0
- workbench/model_scripts/chemprop/model_script_utils.py +339 -0
- workbench/model_scripts/chemprop/requirements.txt +3 -0
- workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
- workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
- workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
- workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
- workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
- workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
- workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
- workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
- workbench/model_scripts/meta_model/generated_model_script.py +209 -0
- workbench/model_scripts/meta_model/meta_model.template +209 -0
- workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
- workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
- workbench/model_scripts/pytorch_model/pytorch.template +440 -496
- workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
- workbench/model_scripts/pytorch_model/requirements.txt +1 -1
- workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
- workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
- workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
- workbench/model_scripts/script_generation.py +20 -11
- workbench/model_scripts/uq_models/generated_model_script.py +248 -0
- workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
- workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
- workbench/model_scripts/xgb_model/uq_harness.py +278 -0
- workbench/model_scripts/xgb_model/xgb_model.template +369 -401
- workbench/repl/workbench_shell.py +28 -19
- workbench/resources/open_source_api.key +1 -1
- workbench/scripts/endpoint_test.py +162 -0
- workbench/scripts/lambda_test.py +73 -0
- workbench/scripts/meta_model_sim.py +35 -0
- workbench/scripts/ml_pipeline_batch.py +137 -0
- workbench/scripts/ml_pipeline_sqs.py +186 -0
- workbench/scripts/monitor_cloud_watch.py +20 -100
- workbench/scripts/training_test.py +85 -0
- workbench/utils/aws_utils.py +4 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +175 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +483 -0
- workbench/utils/chem_utils/mol_standardize.py +450 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +219 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/chemprop_utils.py +141 -0
- workbench/utils/cloudwatch_handler.py +1 -1
- workbench/utils/cloudwatch_utils.py +137 -0
- workbench/utils/config_manager.py +3 -7
- workbench/utils/endpoint_utils.py +5 -7
- workbench/utils/license_manager.py +2 -6
- workbench/utils/meta_model_simulator.py +499 -0
- workbench/utils/metrics_utils.py +256 -0
- workbench/utils/model_utils.py +278 -79
- workbench/utils/monitor_utils.py +44 -62
- workbench/utils/pandas_utils.py +3 -3
- workbench/utils/pytorch_utils.py +87 -0
- workbench/utils/shap_utils.py +11 -57
- workbench/utils/workbench_logging.py +0 -3
- workbench/utils/workbench_sqs.py +1 -1
- workbench/utils/xgboost_local_crossfold.py +267 -0
- workbench/utils/xgboost_model_utils.py +127 -219
- workbench/web_interface/components/model_plot.py +14 -2
- workbench/web_interface/components/plugin_unit_test.py +5 -2
- workbench/web_interface/components/plugins/dashboard_status.py +3 -1
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- workbench/web_interface/components/plugins/model_details.py +38 -74
- workbench/web_interface/components/plugins/scatter_plot.py +6 -10
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
- workbench-0.8.220.dist-info/entry_points.txt +11 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
- workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
- workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
- workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
- workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
- workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
- workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
- workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
- workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
- workbench/model_scripts/quant_regression/quant_regression.template +0 -279
- workbench/model_scripts/quant_regression/requirements.txt +0 -1
- workbench/utils/chem_utils.py +0 -1556
- workbench/utils/execution_environment.py +0 -211
- workbench/utils/fast_inference.py +0 -167
- workbench/utils/resource_utils.py +0 -39
- workbench-0.8.162.dist-info/entry_points.txt +0 -5
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
- {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
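
The largest structural change in this release is the removal of the monolithic workbench/utils/chem_utils.py (-1556 lines) in favor of the new workbench/utils/chem_utils/ subpackage (fingerprints, misc, mol_descriptors, mol_standardize, mol_tagging, projections, salts, sdf, toxicity, vis). The two hunks reproduced below are the new salts.py and sdf.py modules. As a minimal orientation sketch only, assuming the import paths simply mirror the file layout above, the new utilities would be used roughly like this:

import pandas as pd
from rdkit import Chem

# Paths assumed from the file list above; the functions are defined in the hunks below.
from workbench.utils.chem_utils.salts import add_salt_features
from workbench.utils.chem_utils.sdf import df_to_sdf_file, sdf_file_to_df

df = pd.DataFrame({"smiles": ["CC(=O)[O-].[Na+]"], "name": ["sodium_acetate"]})
df["molecule"] = df["smiles"].apply(Chem.MolFromSmiles)

df = add_salt_features(df)  # adds has_salt, mw_ratio, ... columns
df_to_sdf_file(df, "example.sdf", id_col="name", generate_3d=False)  # "example.sdf" is a hypothetical path
round_trip = sdf_file_to_df("example.sdf", id_col="name")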
workbench/utils/chem_utils/salts.py (new file)
@@ -0,0 +1,256 @@
"""Salt feature extraction utilities for molecular analysis"""

import logging
import pandas as pd
from typing import List, Dict, Tuple, Optional, Union

# Molecular Descriptor Imports
from rdkit import Chem
from rdkit.Chem import Descriptors

# Set up the logger
log = logging.getLogger("workbench")


def _get_salt_feature_columns() -> List[str]:
    """Internal: Return list of all salt feature column names"""
    return [
        "has_salt",
        "mw_ratio",
        "salt_to_api_ratio",
        "has_metal_salt",
        "has_halide",
        "ionic_strength_proxy",
        "has_organic_salt",
    ]


def _classify_salt_types(salt_frags: List[Chem.Mol]) -> Dict[str, int]:
    """Internal: Classify salt fragments into categories"""
    features = {
        "has_organic_salt": 0,
        "has_metal_salt": 0,
        "has_halide": 0,
    }

    for frag in salt_frags:
        # Get atoms
        atoms = [atom.GetSymbol() for atom in frag.GetAtoms()]

        # Metal detection
        metals = ["Na", "K", "Ca", "Mg", "Li", "Zn", "Fe", "Al"]
        if any(metal in atoms for metal in metals):
            features["has_metal_salt"] = 1

        # Halide detection
        halides = ["Cl", "Br", "I", "F"]
        if any(halide in atoms for halide in halides):
            features["has_halide"] = 1

        # Organic vs inorganic (simple heuristic: contains C)
        if "C" in atoms:
            features["has_organic_salt"] = 1

    return features


def extract_advanced_salt_features(
    mol: Optional[Chem.Mol],
) -> Tuple[Optional[Dict[str, Union[int, float]]], Optional[Chem.Mol]]:
    """Extract comprehensive salt-related features from RDKit molecule"""
    if mol is None:
        return None, None

    # Get fragments
    fragments = Chem.GetMolFrags(mol, asMols=True)

    # Identify API (largest organic fragment) vs salt fragments
    fragment_weights = [(frag, Descriptors.MolWt(frag)) for frag in fragments]
    fragment_weights.sort(key=lambda x: x[1], reverse=True)

    # Find largest organic fragment as API
    api_mol = None
    salt_frags = []

    for frag, mw in fragment_weights:
        atoms = [atom.GetSymbol() for atom in frag.GetAtoms()]
        if "C" in atoms and api_mol is None:  # First organic fragment = API
            api_mol = frag
        else:
            salt_frags.append(frag)

    # Fallback: if no organic fragments, use largest
    if api_mol is None:
        api_mol = fragment_weights[0][0]
        salt_frags = [frag for frag, _ in fragment_weights[1:]]

    # Initialize all features with default values
    features = {col: 0 for col in _get_salt_feature_columns()}
    features["mw_ratio"] = 1.0  # default for no salt

    # Basic features
    features.update(
        {
            "has_salt": int(len(salt_frags) > 0),
            "mw_ratio": Descriptors.MolWt(api_mol) / Descriptors.MolWt(mol),
        }
    )

    if salt_frags:
        # Salt characterization
        total_salt_mw = sum(Descriptors.MolWt(frag) for frag in salt_frags)
        features.update(
            {
                "salt_to_api_ratio": total_salt_mw / Descriptors.MolWt(api_mol),
                "ionic_strength_proxy": sum(abs(Chem.GetFormalCharge(frag)) for frag in salt_frags),
            }
        )

        # Salt type classification
        features.update(_classify_salt_types(salt_frags))

    return features, api_mol


def add_salt_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add salt features to dataframe with 'molecule' column containing RDKit molecules"""
    salt_features_list = []

    for idx, row in df.iterrows():
        mol = row["molecule"]
        features, clean_mol = extract_advanced_salt_features(mol)

        if features is None:
            # Handle invalid molecules
            features = {col: None for col in _get_salt_feature_columns()}

        salt_features_list.append(features)

    # Convert to DataFrame and concatenate
    salt_df = pd.DataFrame(salt_features_list)
    return pd.concat([df, salt_df], axis=1)


if __name__ == "__main__":
    print("Running salt feature extraction tests...")

    # Test molecules with various salt forms
    test_molecules = {
        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",  # No salt
        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Simple sodium salt
        "calcium_acetate": "CC(=O)[O-].[Ca+2].CC(=O)[O-]",  # Calcium salt (divalent)
        "ammonium_chloride": "[NH4+].[Cl-]",  # Inorganic salt
        "methylamine_hcl": "CN.Cl",  # Organic salt
        "benzene": "c1ccccc1",  # No salt
        "aspirin_sodium": "CC(=O)OC1=CC=CC=C1C(=O)[O-].[Na+]",  # API with sodium salt
        "complex_salt": "CC(C)(C)c1ccc(O)cc1.Cl.Cl.[Zn+2]",  # Complex with metal
    }

    # Test 1: Basic salt feature extraction
    print("\n1. Testing basic salt feature extraction...")

    salt_test_df = pd.DataFrame({"smiles": list(test_molecules.values()), "name": list(test_molecules.keys())})

    salt_test_df["molecule"] = salt_test_df["smiles"].apply(Chem.MolFromSmiles)
    salt_result = add_salt_features(salt_test_df)

    print(" Salt feature results:")
    print(f" {'Molecule':20} has_salt metal halide organic mw_ratio")
    print(" " + "-" * 65)
    for _, row in salt_result.iterrows():
        print(
            f" {row['name']:20} {row['has_salt']:^8} {row['has_metal_salt']:^6} "
            f"{row['has_halide']:^7} {row['has_organic_salt']:^8} {row['mw_ratio']:>8.3f}"
        )

    # Test 2: Detailed feature extraction for specific cases
    print("\n2. Testing detailed salt feature extraction...")

    for name, smiles in [
        ("sodium_acetate", test_molecules["sodium_acetate"]),
        ("calcium_acetate", test_molecules["calcium_acetate"]),
        ("aspirin", test_molecules["aspirin"]),
    ]:
        mol = Chem.MolFromSmiles(smiles)
        features, api_mol = extract_advanced_salt_features(mol)

        print(f"\n {name}:")
        print(f" Total MW: {Descriptors.MolWt(mol):.2f}")
        if api_mol:
            print(f" API MW: {Descriptors.MolWt(api_mol):.2f}")
        if features:
            print(f" Salt/API ratio: {features['salt_to_api_ratio']:.3f}")
            print(f" Ionic strength proxy: {features['ionic_strength_proxy']}")

    # Test 3: Edge cases
    print("\n3. Testing edge cases...")

    # Empty molecule (None)
    empty_features, empty_api = extract_advanced_salt_features(None)
    print(f" None molecule: features={empty_features}, api={empty_api}")

    # Single fragment (no salt)
    benzene_mol = Chem.MolFromSmiles("c1ccccc1")
    benzene_features, benzene_api = extract_advanced_salt_features(benzene_mol)
    print(
        f" Single fragment (benzene): has_salt={benzene_features['has_salt']}, "
        f"mw_ratio={benzene_features['mw_ratio']:.3f}"
    )

    # Multiple organic fragments
    multi_org = Chem.MolFromSmiles("c1ccccc1.CC(=O)O")
    multi_features, multi_api = extract_advanced_salt_features(multi_org)
    print(f" Multiple organic fragments: has_salt={multi_features['has_salt']}")

    # Test 4: Salt type classification
    print("\n4. Testing salt type classification...")

    test_salts = [
        ("NaCl", "[Na+].[Cl-]", "Sodium chloride"),
        ("KBr", "[K+].[Br-]", "Potassium bromide"),
        ("CaCl2", "[Ca+2].[Cl-].[Cl-]", "Calcium chloride"),
        ("Organic salt", "CC(=O)[O-].CN", "Acetate with methylamine"),
    ]

    for name, smiles, description in test_salts:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            features, _ = extract_advanced_salt_features(mol)
            print(f" {name:15} ({description})")
            print(
                f" Metal: {features['has_metal_salt']}, "
                f"Halide: {features['has_halide']}, "
                f"Organic: {features['has_organic_salt']}"
            )

    # Test 5: DataFrame integration
    print("\n5. Testing DataFrame integration...")

    # Create a mixed DataFrame
    mixed_df = pd.DataFrame(
        {
            "smiles": [
                test_molecules["aspirin"],
                test_molecules["sodium_acetate"],
                test_molecules["calcium_acetate"],
            ],
            "name": ["aspirin", "sodium_acetate", "calcium_acetate"],
            "existing_col": [1, 2, 3],  # Test that existing columns are preserved
        }
    )

    mixed_df["molecule"] = mixed_df["smiles"].apply(Chem.MolFromSmiles)
    result_df = add_salt_features(mixed_df)

    # Check that all columns are present
    expected_cols = ["smiles", "name", "existing_col", "molecule"] + _get_salt_feature_columns()
    missing_cols = [col for col in expected_cols if col not in result_df.columns]

    if missing_cols:
        print(f" ✗ Missing columns: {missing_cols}")
    else:
        print(" ✓ All expected columns present")
    print(f" ✓ Original columns preserved: 'existing_col' in result = {('existing_col' in result_df.columns)}")
    print(f" ✓ Salt features added: {len(_get_salt_feature_columns())} new columns")

    print("\n✅ All salt feature tests completed!")
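
For a concrete sense of what these features encode, here is a rough hand calculation for the sodium_acetate test case above ("CC(=O)[O-].[Na+]"), using approximate molecular weights rather than the exact values Descriptors.MolWt would return:

# Rough check of the salt features for sodium acetate (approximate MW values).
api_mw = 59.0    # acetate fragment: the only carbon-containing fragment, so it is taken as the API
salt_mw = 23.0   # [Na+] counterion
total_mw = api_mw + salt_mw

mw_ratio = api_mw / total_mw            # ~0.72: API fraction of the whole record
salt_to_api_ratio = salt_mw / api_mw    # ~0.39
ionic_strength_proxy = abs(+1)          # one singly charged counterion -> 1
# Classification: [Na+] contains a metal but no halide and no carbon,
# so has_salt=1, has_metal_salt=1, has_halide=0, has_organic_salt=0.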
workbench/utils/chem_utils/sdf.py (new file)
@@ -0,0 +1,292 @@
"""SDF File utilities for molecular data in Workbench"""

import logging
import pandas as pd
from typing import List, Optional
from rdkit import Chem
from rdkit.Chem import AllChem, SDWriter

# Set up the logger
log = logging.getLogger("workbench")


def df_to_sdf_file(
    df: pd.DataFrame,
    output_file: str,
    smiles_col: str = "smiles",
    id_col: Optional[str] = None,
    include_cols: Optional[List[str]] = None,
    skip_invalid: bool = True,
    generate_3d: bool = True,
):
    """
    Convert DataFrame with SMILES to SDF file.

    Args:
        df: DataFrame containing SMILES and other data
        output_file: Path to output SDF file
        smiles_col: Column name containing SMILES strings
        id_col: Column to use as molecule ID/name
        include_cols: Specific columns to include as properties (default: all except smiles and molecule columns)
        skip_invalid: Skip invalid SMILES instead of raising error
        generate_3d: Generate 3D coordinates and optimize geometry
    """
    written_count = 0

    with SDWriter(output_file) as writer:
        writer.SetForceV3000(True)
        for idx, row in df.iterrows():
            mol = Chem.MolFromSmiles(row[smiles_col])
            if mol is None:
                if not skip_invalid:
                    raise ValueError(f"Invalid SMILES at row {idx}: {row[smiles_col]}")
                continue

            # Generate 3D coordinates
            if generate_3d:
                mol = Chem.AddHs(mol)

                # Try progressively more aggressive embedding strategies
                embed_strategies = [
                    {"maxAttempts": 1000, "randomSeed": 42},
                    {"maxAttempts": 1000, "randomSeed": 42, "useRandomCoords": True},
                    {"maxAttempts": 1000, "randomSeed": 42, "boxSizeMult": 5.0},
                ]

                embedded = False
                for strategy in embed_strategies:
                    if AllChem.EmbedMolecule(mol, **strategy) != -1:
                        embedded = True
                        break

                if not embedded:
                    if not skip_invalid:
                        raise ValueError(f"Could not generate 3D coords for row {idx}")
                    continue

                AllChem.MMFFOptimizeMolecule(mol)

            # Set molecule name/ID
            if id_col and id_col in df.columns:
                mol.SetProp("_Name", str(row[id_col]))

            # Determine which columns to include
            if include_cols:
                cols_to_add = [col for col in include_cols if col in df.columns and col != smiles_col]
            else:
                # Auto-exclude common molecule column names and SMILES column
                mol_col_names = ["mol", "molecule", "rdkit_mol", "Mol"]
                cols_to_add = [col for col in df.columns if col != smiles_col and col not in mol_col_names]

            # Add properties
            for col in cols_to_add:
                mol.SetProp(col, str(row[col]))

            writer.write(mol)
            written_count += 1

    log.info(f"Wrote {written_count} molecules to SDF: {output_file}")
    return written_count


def sdf_file_to_df(
    sdf_file: str,
    include_smiles: bool = True,
    smiles_col: str = "smiles",
    id_col: Optional[str] = None,
    include_props: Optional[List[str]] = None,
    exclude_props: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Convert SDF file to DataFrame.

    Args:
        sdf_file: Path to input SDF file
        include_smiles: Add SMILES column to output
        smiles_col: Name for SMILES column
        id_col: Column name for molecule ID/name (uses _Name property)
        include_props: Specific properties to include (default: all)
        exclude_props: Properties to exclude from output

    Returns:
        DataFrame with molecules and their properties
    """
    data = []

    suppl = Chem.SDMolSupplier(sdf_file)
    for idx, mol in enumerate(suppl):
        if mol is None:
            log.warning(f"Could not parse molecule at index {idx}")
            continue

        row_data = {}

        # Add SMILES if requested
        if include_smiles:
            row_data[smiles_col] = Chem.MolToSmiles(mol)

        # Add molecule name/ID if requested
        if id_col and mol.HasProp("_Name"):
            row_data[id_col] = mol.GetProp("_Name")

        # Get all properties
        prop_names = mol.GetPropNames()

        # Filter properties based on include/exclude lists
        if include_props:
            prop_names = [p for p in prop_names if p in include_props]
        if exclude_props:
            prop_names = [p for p in prop_names if p not in exclude_props]

        # Add properties to row
        for prop in prop_names:
            if prop != "_Name":  # Skip _Name if we already handled it
                row_data[prop] = mol.GetProp(prop)

        data.append(row_data)

    df = pd.DataFrame(data)
    log.info(f"Read {len(df)} molecules from SDF: {sdf_file}")

    return df


if __name__ == "__main__":
    import tempfile
    import os

    print("Running SDF utilities tests...")

    # Create test data
    test_data = pd.DataFrame(
        {
            "smiles": [
                "CCO",  # Ethanol
                "c1ccccc1",  # Benzene
                "CC(=O)O",  # Acetic acid
                "INVALID_SMILES",  # Invalid
                "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
            ],
            "name": ["Ethanol", "Benzene", "Acetic Acid", "Invalid", "Caffeine"],
            "mol_weight": [46.07, 78.11, 60.05, 0, 194.19],
            "category": ["alcohol", "aromatic", "acid", "error", "alkaloid"],
            "mol": ["should_exclude", "should_exclude", "should_exclude", "should_exclude", "should_exclude"],
        }
    )

    # Test 1: Basic DataFrame to SDF conversion
    print("\n1. Testing DataFrame to SDF conversion...")
    with tempfile.NamedTemporaryFile(suffix=".sdf", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        # Test with 3D generation
        count = df_to_sdf_file(
            test_data, tmp_path, smiles_col="smiles", id_col="name", skip_invalid=True, generate_3d=True
        )
        print(f" ✓ Wrote {count} molecules with 3D coords (expected 4, skipped 1 invalid)")

        # Test without 3D generation
        count = df_to_sdf_file(
            test_data, tmp_path, smiles_col="smiles", id_col="name", skip_invalid=True, generate_3d=False
        )
        print(f" ✓ Wrote {count} molecules without 3D coords")

    except Exception as e:
        print(f" ✗ Error writing SDF: {e}")

    # Test 2: SDF to DataFrame conversion
    print("\n2. Testing SDF to DataFrame conversion...")
    try:
        # Read back the SDF
        df_read = sdf_file_to_df(tmp_path, include_smiles=True, smiles_col="SMILES", id_col="mol_name")
        print(f" ✓ Read {len(df_read)} molecules from SDF")
        print(f" ✓ Columns: {list(df_read.columns)}")

    except Exception as e:
        print(f" ✗ Error reading SDF: {e}")

    # Test 3: Column filtering
    print("\n3. Testing column inclusion/exclusion...")
    try:
        # Test with specific columns only
        count = df_to_sdf_file(
            test_data,
            tmp_path,
            smiles_col="smiles",
            include_cols=["name", "mol_weight"],
            skip_invalid=True,
            generate_3d=False,
        )

        df_filtered = sdf_file_to_df(tmp_path)
        excluded_mol = "mol" not in df_filtered.columns
        included_weight = any("mol_weight" in str(col) for col in df_filtered.columns)

        print(f" {'✓' if excluded_mol else '✗'} 'mol' column excluded")
        print(f" {'✓' if included_weight else '✗'} 'mol_weight' included")

    except Exception as e:
        print(f" ✗ Error with column filtering: {e}")

    # Test 4: Error handling
    print("\n4. Testing error handling...")

    # Test with skip_invalid=False
    try:
        count = df_to_sdf_file(test_data, tmp_path, smiles_col="smiles", skip_invalid=False, generate_3d=False)
        print(" ✗ Should have raised error for invalid SMILES")
    except ValueError:
        print(" ✓ Correctly raised error for invalid SMILES")

    # Test 5: Property filtering on read
    print("\n5. Testing property filtering on read...")
    try:
        # Write full data
        df_to_sdf_file(test_data, tmp_path, smiles_col="smiles", skip_invalid=True, generate_3d=False)

        # Read with include filter
        df_include = sdf_file_to_df(tmp_path, include_props=["mol_weight", "category"])
        print(f" ✓ Include filter: {list(df_include.columns)}")

        # Read with exclude filter
        df_exclude = sdf_file_to_df(tmp_path, exclude_props=["category"])
        has_category = "category" in df_exclude.columns
        print(f" {'✗' if has_category else '✓'} Exclude filter: 'category' excluded")

    except Exception as e:
        print(f" ✗ Error with property filtering: {e}")

    # Test 6: Edge cases
    print("\n6. Testing edge cases...")

    # Empty DataFrame
    empty_df = pd.DataFrame(columns=["smiles", "name"])
    try:
        count = df_to_sdf_file(empty_df, tmp_path)
        print(f" ✓ Empty DataFrame: wrote {count} molecules")
    except Exception as e:
        print(f" ✗ Empty DataFrame error: {e}")

    # Missing columns
    bad_df = pd.DataFrame({"not_smiles": ["CCO"]})
    try:
        count = df_to_sdf_file(bad_df, tmp_path, smiles_col="smiles")
        print(" ✗ Should have raised error for missing column")
    except KeyError:
        print(" ✓ Correctly raised error for missing SMILES column")

    # Large molecule test (3D generation stress test)
    large_mol_df = pd.DataFrame({"smiles": ["C" * 50], "name": ["Long Chain"]})  # Very long carbon chain
    try:
        count = df_to_sdf_file(large_mol_df, tmp_path, generate_3d=True, skip_invalid=True)
        print(f" ✓ Large molecule: wrote {count} molecule(s)")
    except Exception as e:
        print(f" ✗ Large molecule error: {e}")

    # Cleanup
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
        print(f"\n✓ Cleaned up temp file: {tmp_path}")

    print("\n✅ All SDF utilities tests completed!")
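
One practical note on reading files back with sdf_file_to_df: properties are retrieved with mol.GetProp, which returns strings, so numeric columns such as mol_weight come back as text and need an explicit cast. A small sketch, assuming an SDF written earlier and the column names from the test data above:

import pandas as pd
from workbench.utils.chem_utils.sdf import sdf_file_to_df  # path assumed from the file list

df = sdf_file_to_df("example.sdf")  # hypothetical file written earlier
df["mol_weight"] = pd.to_numeric(df["mol_weight"], errors="coerce")  # SD properties arrive as strings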