workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. See the package's registry page for more details.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,256 @@
1
+ """Salt feature extraction utilities for molecular analysis"""
2
+
3
+ import logging
4
+ import pandas as pd
5
+ from typing import List, Dict, Tuple, Optional, Union
6
+
7
+ # Molecular Descriptor Imports
8
+ from rdkit import Chem
9
+ from rdkit.Chem import Descriptors
10
+
11
+ # Set up the logger
12
+ log = logging.getLogger("workbench")
13
+
14
+
15
+ def _get_salt_feature_columns() -> List[str]:
16
+ """Internal: Return list of all salt feature column names"""
17
+ return [
18
+ "has_salt",
19
+ "mw_ratio",
20
+ "salt_to_api_ratio",
21
+ "has_metal_salt",
22
+ "has_halide",
23
+ "ionic_strength_proxy",
24
+ "has_organic_salt",
25
+ ]
26
+
27
+
28
def _classify_salt_types(salt_frags: List[Chem.Mol]) -> Dict[str, int]:
    """Internal: flag which kinds of counter-ion occur among the salt fragments.

    Args:
        salt_frags: The non-API fragments of a molecule.

    Returns:
        Dict of 0/1 flags:
            has_organic_salt - some fragment contains a carbon atom (simple heuristic)
            has_metal_salt   - some fragment contains Na, K, Ca, Mg, Li, Zn, Fe or Al
            has_halide       - some fragment contains Cl, Br, I or F
    """
    metal_symbols = {"Na", "K", "Ca", "Mg", "Li", "Zn", "Fe", "Al"}
    halide_symbols = {"Cl", "Br", "I", "F"}

    # One set of element symbols per fragment; each flag asks "does any fragment hit?"
    symbol_sets = [{atom.GetSymbol() for atom in frag.GetAtoms()} for frag in salt_frags]

    return {
        "has_organic_salt": int(any("C" in symbols for symbols in symbol_sets)),
        "has_metal_salt": int(any(symbols & metal_symbols for symbols in symbol_sets)),
        "has_halide": int(any(symbols & halide_symbols for symbols in symbol_sets)),
    }
55
+
56
+
57
def extract_advanced_salt_features(
    mol: Optional[Chem.Mol],
) -> Tuple[Optional[Dict[str, Union[int, float]]], Optional[Chem.Mol]]:
    """Extract salt-related features from an RDKit molecule.

    The molecule is split into disconnected fragments. The heaviest fragment that
    contains a carbon atom is taken as the API; every other fragment is treated as
    a salt/counter-ion. If no fragment contains carbon, the heaviest fragment is
    used as the API instead.

    Args:
        mol: RDKit molecule, possibly multi-fragment (e.g. "CC(=O)[O-].[Na+]").

    Returns:
        Tuple (features, api_mol). ``features`` maps every name returned by
        ``_get_salt_feature_columns()`` to an int/float value; ``api_mol`` is the
        fragment chosen as the API. Returns ``(None, None)`` when ``mol`` is None,
        has no atoms (e.g. parsed from an empty SMILES), or has zero total
        molecular weight.
    """
    if mol is None:
        return None, None

    # An empty Mol is not None but has no fragments: without this guard the
    # fallback below raises IndexError and the mw_ratio division by zero fails.
    total_mw = Descriptors.MolWt(mol)
    if mol.GetNumAtoms() == 0 or total_mw <= 0:
        return None, None

    # Pair each fragment with its molecular weight, heaviest first
    fragment_weights = [(frag, Descriptors.MolWt(frag)) for frag in Chem.GetMolFrags(mol, asMols=True)]
    fragment_weights.sort(key=lambda x: x[1], reverse=True)

    # First (i.e. heaviest) carbon-containing fragment is the API; the rest are salts
    api_mol = None
    salt_frags = []
    for frag, _ in fragment_weights:
        is_organic = any(atom.GetSymbol() == "C" for atom in frag.GetAtoms())
        if is_organic and api_mol is None:
            api_mol = frag
        else:
            salt_frags.append(frag)

    # Fallback: no organic fragment at all, so use the heaviest fragment as the API
    if api_mol is None:
        api_mol = fragment_weights[0][0]
        salt_frags = [frag for frag, _ in fragment_weights[1:]]

    # api_mw > 0 here: the API either contains carbon or is the heaviest fragment of
    # a molecule whose total weight is positive.
    api_mw = Descriptors.MolWt(api_mol)

    # Start every feature at 0, then fill in what applies
    features = {col: 0 for col in _get_salt_feature_columns()}
    features["has_salt"] = int(len(salt_frags) > 0)
    features["mw_ratio"] = api_mw / total_mw  # 1.0 when there is no salt

    if salt_frags:
        total_salt_mw = sum(Descriptors.MolWt(frag) for frag in salt_frags)
        features["salt_to_api_ratio"] = total_salt_mw / api_mw
        # Sum of absolute formal charges over the salt fragments
        features["ionic_strength_proxy"] = sum(abs(Chem.GetFormalCharge(frag)) for frag in salt_frags)

        # Metal / halide / organic counter-ion flags
        features.update(_classify_salt_types(salt_frags))

    return features, api_mol
113
+
114
+
115
def add_salt_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add salt features to dataframe with 'molecule' column containing RDKit molecules.

    Args:
        df: DataFrame with a "molecule" column of RDKit Mol objects (None for
            molecules that failed to parse).

    Returns:
        A new DataFrame containing all original columns followed by the columns
        from ``_get_salt_feature_columns()``. Rows whose molecule is invalid get
        None in every salt column. The input DataFrame is not modified.
    """
    salt_cols = _get_salt_feature_columns()

    salt_features_list = []
    for mol in df["molecule"]:
        features, _ = extract_advanced_salt_features(mol)
        if features is None:
            # Invalid/empty molecule: keep the row, but with missing salt features
            features = {col: None for col in salt_cols}
        salt_features_list.append(features)

    # Reuse df's index so the axis=1 concat lines rows up correctly even when df
    # has a non-default index (e.g. after filtering); explicit columns ensure an
    # empty input still gains the salt columns.
    salt_df = pd.DataFrame(salt_features_list, index=df.index, columns=salt_cols)
    return pd.concat([df, salt_df], axis=1)
132
+
133
+
134
if __name__ == "__main__":
    # Ad-hoc smoke tests: running this module directly prints salt features for a
    # set of sample SMILES. Results are printed for inspection, not asserted.
    print("Running salt feature extraction tests...")

    # Test molecules with various salt forms
    test_molecules = {
        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",  # No salt
        "sodium_acetate": "CC(=O)[O-].[Na+]",  # Simple sodium salt
        "calcium_acetate": "CC(=O)[O-].[Ca+2].CC(=O)[O-]",  # Calcium salt (divalent)
        "ammonium_chloride": "[NH4+].[Cl-]",  # Inorganic salt
        "methylamine_hcl": "CN.Cl",  # Organic salt
        "benzene": "c1ccccc1",  # No salt
        "aspirin_sodium": "CC(=O)OC1=CC=CC=C1C(=O)[O-].[Na+]",  # API with sodium salt
        "complex_salt": "CC(C)(C)c1ccc(O)cc1.Cl.Cl.[Zn+2]",  # Complex with metal
    }

    # Test 1: Basic salt feature extraction
    print("\n1. Testing basic salt feature extraction...")

    salt_test_df = pd.DataFrame({"smiles": list(test_molecules.values()), "name": list(test_molecules.keys())})

    # add_salt_features expects RDKit Mol objects in a "molecule" column
    salt_test_df["molecule"] = salt_test_df["smiles"].apply(Chem.MolFromSmiles)
    salt_result = add_salt_features(salt_test_df)

    # Tabular summary of the main flags per molecule
    print(" Salt feature results:")
    print(f" {'Molecule':20} has_salt metal halide organic mw_ratio")
    print(" " + "-" * 65)
    for _, row in salt_result.iterrows():
        print(
            f" {row['name']:20} {row['has_salt']:^8} {row['has_metal_salt']:^6} "
            f"{row['has_halide']:^7} {row['has_organic_salt']:^8} {row['mw_ratio']:>8.3f}"
        )

    # Test 2: Detailed feature extraction for specific cases
    print("\n2. Testing detailed salt feature extraction...")

    for name, smiles in [
        ("sodium_acetate", test_molecules["sodium_acetate"]),
        ("calcium_acetate", test_molecules["calcium_acetate"]),
        ("aspirin", test_molecules["aspirin"]),
    ]:
        mol = Chem.MolFromSmiles(smiles)
        features, api_mol = extract_advanced_salt_features(mol)

        print(f"\n {name}:")
        print(f" Total MW: {Descriptors.MolWt(mol):.2f}")
        if api_mol:
            print(f" API MW: {Descriptors.MolWt(api_mol):.2f}")
        if features:
            # For a no-salt molecule (aspirin) these print the default value 0
            print(f" Salt/API ratio: {features['salt_to_api_ratio']:.3f}")
            print(f" Ionic strength proxy: {features['ionic_strength_proxy']}")

    # Test 3: Edge cases
    print("\n3. Testing edge cases...")

    # Empty molecule (None)
    empty_features, empty_api = extract_advanced_salt_features(None)
    print(f" None molecule: features={empty_features}, api={empty_api}")

    # Single fragment (no salt)
    benzene_mol = Chem.MolFromSmiles("c1ccccc1")
    benzene_features, benzene_api = extract_advanced_salt_features(benzene_mol)
    print(
        f" Single fragment (benzene): has_salt={benzene_features['has_salt']}, "
        f"mw_ratio={benzene_features['mw_ratio']:.3f}"
    )

    # Multiple organic fragments (the heavier organic fragment becomes the API,
    # the lighter one is counted as a salt fragment)
    multi_org = Chem.MolFromSmiles("c1ccccc1.CC(=O)O")
    multi_features, multi_api = extract_advanced_salt_features(multi_org)
    print(f" Multiple organic fragments: has_salt={multi_features['has_salt']}")

    # Test 4: Salt type classification
    print("\n4. Testing salt type classification...")

    # (label, SMILES, description) triples
    test_salts = [
        ("NaCl", "[Na+].[Cl-]", "Sodium chloride"),
        ("KBr", "[K+].[Br-]", "Potassium bromide"),
        ("CaCl2", "[Ca+2].[Cl-].[Cl-]", "Calcium chloride"),
        ("Organic salt", "CC(=O)[O-].CN", "Acetate with methylamine"),
    ]

    for name, smiles, description in test_salts:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            features, _ = extract_advanced_salt_features(mol)
            print(f" {name:15} ({description})")
            print(
                f" Metal: {features['has_metal_salt']}, "
                f"Halide: {features['has_halide']}, "
                f"Organic: {features['has_organic_salt']}"
            )

    # Test 5: DataFrame integration
    print("\n5. Testing DataFrame integration...")

    # Create a mixed DataFrame
    mixed_df = pd.DataFrame(
        {
            "smiles": [
                test_molecules["aspirin"],
                test_molecules["sodium_acetate"],
                test_molecules["calcium_acetate"],
            ],
            "name": ["aspirin", "sodium_acetate", "calcium_acetate"],
            "existing_col": [1, 2, 3],  # Test that existing columns are preserved
        }
    )

    mixed_df["molecule"] = mixed_df["smiles"].apply(Chem.MolFromSmiles)
    result_df = add_salt_features(mixed_df)

    # Check that all columns are present (originals plus every salt feature column)
    expected_cols = ["smiles", "name", "existing_col", "molecule"] + _get_salt_feature_columns()
    missing_cols = [col for col in expected_cols if col not in result_df.columns]

    if missing_cols:
        print(f" ✗ Missing columns: {missing_cols}")
    else:
        print(" ✓ All expected columns present")
        print(f" ✓ Original columns preserved: 'existing_col' in result = {('existing_col' in result_df.columns)}")
        print(f" ✓ Salt features added: {len(_get_salt_feature_columns())} new columns")

    print("\n✅ All salt feature tests completed!")
@@ -0,0 +1,292 @@
1
+ """SDF File utilities for molecular data in Workbench"""
2
+
3
+ import logging
4
+ import pandas as pd
5
+ from typing import List, Optional
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem, SDWriter
8
+
9
+ # Set up the logger
10
+ log = logging.getLogger("workbench")
11
+
12
+
13
def df_to_sdf_file(
    df: pd.DataFrame,
    output_file: str,
    smiles_col: str = "smiles",
    id_col: Optional[str] = None,
    include_cols: Optional[List[str]] = None,
    skip_invalid: bool = True,
    generate_3d: bool = True,
) -> int:
    """
    Convert DataFrame with SMILES to SDF file.

    Records are written in V3000 format. Each row whose SMILES parses becomes one
    SDF record, and the selected DataFrame columns are stored as SD properties
    (converted to strings).

    Args:
        df: DataFrame containing SMILES and other data
        output_file: Path to output SDF file
        smiles_col: Column name containing SMILES strings
        id_col: Column to use as molecule ID/name (ignored if not a column of df)
        include_cols: Specific columns to include as properties (default: all except smiles and molecule columns)
        skip_invalid: Skip missing/invalid/empty SMILES and failed 3D embeddings instead of raising error
        generate_3d: Add explicit hydrogens, generate 3D coordinates and MMFF-optimize geometry

    Returns:
        Number of molecules written.

    Raises:
        ValueError: If skip_invalid is False and a SMILES is missing, unparseable or
            empty, or 3D coordinates could not be generated.
        KeyError: If df has rows but no ``smiles_col`` column.
    """
    # Embedding strategies, tried in order until one succeeds. boxSizeMult only
    # affects random-coordinate embedding, so the last strategy must also set
    # useRandomCoords=True -- otherwise it repeats the first one exactly (same seed).
    embed_strategies = [
        {"maxAttempts": 1000, "randomSeed": 42},
        {"maxAttempts": 1000, "randomSeed": 42, "useRandomCoords": True},
        {"maxAttempts": 1000, "randomSeed": 42, "useRandomCoords": True, "boxSizeMult": 5.0},
    ]

    # The property columns and ID handling are identical for every row: resolve once
    if include_cols:
        cols_to_add = [col for col in include_cols if col in df.columns and col != smiles_col]
    else:
        # Auto-exclude common molecule column names and SMILES column
        mol_col_names = ["mol", "molecule", "rdkit_mol", "Mol"]
        cols_to_add = [col for col in df.columns if col != smiles_col and col not in mol_col_names]
    use_id = bool(id_col) and id_col in df.columns

    written_count = 0
    with SDWriter(output_file) as writer:
        writer.SetForceV3000(True)
        for idx, row in df.iterrows():
            smiles = row[smiles_col]
            # MolFromSmiles raises (instead of returning None) for non-string input
            # such as NaN/None, which would bypass skip_invalid; an empty SMILES
            # parses to a Mol with no atoms, which is not a meaningful record.
            mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
            if mol is None or mol.GetNumAtoms() == 0:
                if not skip_invalid:
                    raise ValueError(f"Invalid SMILES at row {idx}: {smiles}")
                continue

            # Generate 3D coordinates
            if generate_3d:
                mol = Chem.AddHs(mol)

                embedded = any(AllChem.EmbedMolecule(mol, **strategy) != -1 for strategy in embed_strategies)
                if not embedded:
                    if not skip_invalid:
                        raise ValueError(f"Could not generate 3D coords for row {idx}")
                    continue

                AllChem.MMFFOptimizeMolecule(mol)

            # Set molecule name/ID
            if use_id:
                mol.SetProp("_Name", str(row[id_col]))

            # SetProp needs string keys; column labels may be ints or other types
            for col in cols_to_add:
                mol.SetProp(str(col), str(row[col]))

            writer.write(mol)
            written_count += 1

    log.info(f"Wrote {written_count} molecules to SDF: {output_file}")
    return written_count
90
+
91
+
92
def sdf_file_to_df(
    sdf_file: str,
    include_smiles: bool = True,
    smiles_col: str = "smiles",
    id_col: Optional[str] = None,
    include_props: Optional[List[str]] = None,
    exclude_props: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Read an SDF file into a DataFrame, one row per molecule RDKit can parse.

    Records that fail to parse are skipped with a warning.

    Args:
        sdf_file: Path to input SDF file
        include_smiles: Add a (RDKit-generated) SMILES column to the output
        smiles_col: Name for SMILES column
        id_col: Column name for molecule ID/name (filled from the _Name property)
        include_props: Specific properties to include (default: all)
        exclude_props: Properties to exclude from output (applied after include_props)

    Returns:
        DataFrame with one row per molecule; SD property values are kept as strings
    """

    def _keep(prop: str) -> bool:
        # Include list first (when non-empty), then exclude list (when non-empty);
        # _Name is never copied as a regular property column.
        if include_props and prop not in include_props:
            return False
        if exclude_props and prop in exclude_props:
            return False
        return prop != "_Name"

    records = []
    for position, mol in enumerate(Chem.SDMolSupplier(sdf_file)):
        if mol is None:
            log.warning(f"Could not parse molecule at index {position}")
            continue

        record = {}
        if include_smiles:
            record[smiles_col] = Chem.MolToSmiles(mol)
        if id_col and mol.HasProp("_Name"):
            record[id_col] = mol.GetProp("_Name")
        record.update((prop, mol.GetProp(prop)) for prop in mol.GetPropNames() if _keep(prop))

        records.append(record)

    result = pd.DataFrame(records)
    log.info(f"Read {len(result)} molecules from SDF: {sdf_file}")
    return result
152
+
153
+
154
if __name__ == "__main__":
    # Ad-hoc smoke tests: running this module directly exercises the DataFrame <-> SDF
    # round trip against a temp file. Results are printed for inspection, not asserted.
    import tempfile
    import os

    print("Running SDF utilities tests...")

    # Create test data (one deliberately invalid SMILES; "mol" should be auto-excluded)
    test_data = pd.DataFrame(
        {
            "smiles": [
                "CCO",  # Ethanol
                "c1ccccc1",  # Benzene
                "CC(=O)O",  # Acetic acid
                "INVALID_SMILES",  # Invalid
                "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
            ],
            "name": ["Ethanol", "Benzene", "Acetic Acid", "Invalid", "Caffeine"],
            "mol_weight": [46.07, 78.11, 60.05, 0, 194.19],
            "category": ["alcohol", "aromatic", "acid", "error", "alkaloid"],
            "mol": ["should_exclude", "should_exclude", "should_exclude", "should_exclude", "should_exclude"],
        }
    )

    # Test 1: Basic DataFrame to SDF conversion
    print("\n1. Testing DataFrame to SDF conversion...")

    # Reserve a temp path; delete=False keeps the file after the with block so the
    # tests below can reopen it by name
    with tempfile.NamedTemporaryFile(suffix=".sdf", delete=False) as tmp:
        tmp_path = tmp.name

    # try/finally guarantees the temp file is removed even if an unexpected
    # exception escapes one of the tests below
    try:
        try:
            # Test with 3D generation
            count = df_to_sdf_file(
                test_data, tmp_path, smiles_col="smiles", id_col="name", skip_invalid=True, generate_3d=True
            )
            print(f" ✓ Wrote {count} molecules with 3D coords (expected 4, skipped 1 invalid)")

            # Test without 3D generation
            count = df_to_sdf_file(
                test_data, tmp_path, smiles_col="smiles", id_col="name", skip_invalid=True, generate_3d=False
            )
            print(f" ✓ Wrote {count} molecules without 3D coords")

        except Exception as e:
            print(f" ✗ Error writing SDF: {e}")

        # Test 2: SDF to DataFrame conversion
        print("\n2. Testing SDF to DataFrame conversion...")
        try:
            # Read back the SDF
            df_read = sdf_file_to_df(tmp_path, include_smiles=True, smiles_col="SMILES", id_col="mol_name")
            print(f" ✓ Read {len(df_read)} molecules from SDF")
            print(f" ✓ Columns: {list(df_read.columns)}")

        except Exception as e:
            print(f" ✗ Error reading SDF: {e}")

        # Test 3: Column filtering
        print("\n3. Testing column inclusion/exclusion...")
        try:
            # Test with specific columns only
            count = df_to_sdf_file(
                test_data,
                tmp_path,
                smiles_col="smiles",
                include_cols=["name", "mol_weight"],
                skip_invalid=True,
                generate_3d=False,
            )

            df_filtered = sdf_file_to_df(tmp_path)
            excluded_mol = "mol" not in df_filtered.columns
            included_weight = any("mol_weight" in str(col) for col in df_filtered.columns)

            print(f" {'✓' if excluded_mol else '✗'} 'mol' column excluded")
            print(f" {'✓' if included_weight else '✗'} 'mol_weight' included")

        except Exception as e:
            print(f" ✗ Error with column filtering: {e}")

        # Test 4: Error handling
        print("\n4. Testing error handling...")

        # Test with skip_invalid=False
        try:
            count = df_to_sdf_file(test_data, tmp_path, smiles_col="smiles", skip_invalid=False, generate_3d=False)
            print(" ✗ Should have raised error for invalid SMILES")
        except ValueError:
            print(" ✓ Correctly raised error for invalid SMILES")

        # Test 5: Property filtering on read
        print("\n5. Testing property filtering on read...")
        try:
            # Write full data
            df_to_sdf_file(test_data, tmp_path, smiles_col="smiles", skip_invalid=True, generate_3d=False)

            # Read with include filter
            df_include = sdf_file_to_df(tmp_path, include_props=["mol_weight", "category"])
            print(f" ✓ Include filter: {list(df_include.columns)}")

            # Read with exclude filter
            df_exclude = sdf_file_to_df(tmp_path, exclude_props=["category"])
            has_category = "category" in df_exclude.columns
            print(f" {'✗' if has_category else '✓'} Exclude filter: 'category' excluded")

        except Exception as e:
            print(f" ✗ Error with property filtering: {e}")

        # Test 6: Edge cases
        print("\n6. Testing edge cases...")

        # Empty DataFrame
        empty_df = pd.DataFrame(columns=["smiles", "name"])
        try:
            count = df_to_sdf_file(empty_df, tmp_path)
            print(f" ✓ Empty DataFrame: wrote {count} molecules")
        except Exception as e:
            print(f" ✗ Empty DataFrame error: {e}")

        # Missing columns
        bad_df = pd.DataFrame({"not_smiles": ["CCO"]})
        try:
            count = df_to_sdf_file(bad_df, tmp_path, smiles_col="smiles")
            print(" ✗ Should have raised error for missing column")
        except KeyError:
            print(" ✓ Correctly raised error for missing SMILES column")

        # Large molecule test (3D generation stress test)
        large_mol_df = pd.DataFrame({"smiles": ["C" * 50], "name": ["Long Chain"]})  # Very long carbon chain
        try:
            count = df_to_sdf_file(large_mol_df, tmp_path, generate_3d=True, skip_invalid=True)
            print(f" ✓ Large molecule: wrote {count} molecule(s)")
        except Exception as e:
            print(f" ✗ Large molecule error: {e}")

    finally:
        # Cleanup
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
            print(f"\n✓ Cleaned up temp file: {tmp_path}")

    print("\n✅ All SDF utilities tests completed!")