workbench 0.8.172__py3-none-any.whl → 0.8.173__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/graph/light/proximity_graph.py +2 -1
- workbench/api/compound.py +1 -1
- workbench/api/monitor.py +1 -16
- workbench/core/artifacts/data_capture_core.py +315 -0
- workbench/core/artifacts/endpoint_core.py +9 -3
- workbench/core/artifacts/monitor_core.py +33 -249
- workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +471 -0
- workbench/model_scripts/custom_models/chem_info/mol_standardize.py +428 -0
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
- workbench/model_scripts/custom_models/uq_models/generated_model_script.py +95 -204
- workbench/model_scripts/xgb_model/generated_model_script.py +5 -5
- workbench/repl/workbench_shell.py +3 -3
- workbench/utils/chem_utils/__init__.py +0 -0
- workbench/utils/chem_utils/fingerprints.py +134 -0
- workbench/utils/chem_utils/misc.py +194 -0
- workbench/utils/chem_utils/mol_descriptors.py +471 -0
- workbench/utils/chem_utils/mol_standardize.py +428 -0
- workbench/utils/chem_utils/mol_tagging.py +348 -0
- workbench/utils/chem_utils/projections.py +209 -0
- workbench/utils/chem_utils/salts.py +256 -0
- workbench/utils/chem_utils/sdf.py +292 -0
- workbench/utils/chem_utils/toxicity.py +250 -0
- workbench/utils/chem_utils/vis.py +253 -0
- workbench/utils/monitor_utils.py +49 -56
- workbench/utils/pandas_utils.py +3 -3
- workbench/web_interface/components/plugins/generated_compounds.py +1 -1
- {workbench-0.8.172.dist-info → workbench-0.8.173.dist-info}/METADATA +1 -1
- {workbench-0.8.172.dist-info → workbench-0.8.173.dist-info}/RECORD +33 -22
- workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
- workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
- workbench/utils/chem_utils.py +0 -1556
- {workbench-0.8.172.dist-info → workbench-0.8.173.dist-info}/WHEEL +0 -0
- {workbench-0.8.172.dist-info → workbench-0.8.173.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.172.dist-info → workbench-0.8.173.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.172.dist-info → workbench-0.8.173.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""Toxicity detection utilities for molecular compounds"""
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional, Tuple
|
|
4
|
+
from rdkit import Chem
|
|
5
|
+
from rdkit.Chem import Mol
|
|
6
|
+
from rdkit.Chem import FunctionalGroups as FG
|
|
7
|
+
|
|
8
|
+
# Precompiled SMARTS patterns for custom toxic functional groups.
# toxic_groups() reports the raw SMARTS string of each matching pattern, so
# toxic_smarts_patterns and compiled_toxic_smarts are kept index-aligned.
toxic_smarts_patterns = [
    "C(=S)N",  # Dithiocarbamate (matches any thioamide C(=S)N linkage)
    "P(=O)(O)(O)O",  # Phosphate Ester
    "[As](=O)(=O)-[OH]",  # Arsenic Oxide
    "[C](Cl)(Cl)(Cl)",  # Trichloromethyl
    "[Cr](=O)(=O)=O",  # Chromium(VI)
    "[N+](C)(C)(C)(C)",  # Quaternary Ammonium
    "[Se][Se]",  # Diselenide
    "c1c(Cl)c(Cl)c(Cl)cc1",  # Trichlorinated Aromatic Ring (six-membered; was a 5-ring that never matched benzenes)
    "[CX3](=O)[CX4][Cl,Br,F,I]",  # Halogenated Carbonyl
    "[P+](C*)(C*)(C*)(C*)",  # Phosphonium Group
    "NC(=S)c1c(Cl)cccc1Cl",  # Chlorobenzene Thiocarbamate
    "NC(=S)Nc1ccccc1",  # Phenyl Thiocarbamate
    "S=C1NCCN1",  # Thiourea Derivative
]


def _compile_smarts(patterns: List[str]) -> List[Mol]:
    """Compile SMARTS strings into RDKit query molecules.

    Chem.MolFromSmarts returns None for an unparsable pattern; passing that None
    on would only fail later inside HasSubstructMatch. Fail fast instead.

    Args:
        patterns: SMARTS strings to compile.

    Returns:
        List[Mol]: Compiled query molecules, in the same order as ``patterns``.

    Raises:
        ValueError: If any pattern cannot be parsed.
    """
    compiled = []
    for smarts in patterns:
        query = Chem.MolFromSmarts(smarts)
        if query is None:
            raise ValueError(f"Invalid SMARTS pattern: {smarts!r}")
        compiled.append(query)
    return compiled


compiled_toxic_smarts = _compile_smarts(toxic_smarts_patterns)

# Precompiled SMARTS patterns for exemptions: toxic_groups() returns None for any
# molecule matching one of these.
exempt_smarts_patterns = [
    "c1ccc(O)c(O)c1",  # Catechol (1,2-dihydroxybenzene) substructure; does not match simple phenols
]
compiled_exempt_smarts = _compile_smarts(exempt_smarts_patterns)

# Load functional group hierarchy once during initialization
fgroup_hierarchy = FG.BuildFuncGroupHierarchy()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def contains_heavy_metals(mol: Mol) -> bool:
    """
    Report whether any atom of the molecule is one of a fixed set of heavy metals.

    This is a broad screen. The element symbols checked are
    Zn, Cu, Fe, Mn, Co, Pb, Hg, Cd and As.

    Args:
        mol: RDKit molecule object.

    Returns:
        bool: True as soon as a listed element is found, False otherwise.
    """
    metal_symbols = frozenset(("Zn", "Cu", "Fe", "Mn", "Co", "Pb", "Hg", "Cd", "As"))
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in metal_symbols:
            return True
    return False
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def halogen_toxicity_score(mol: Mol) -> Tuple[int, int]:
    """
    Count halogen atoms and compute a size-dependent tolerance threshold.

    Args:
        mol: RDKit molecule object.

    Returns:
        Tuple[int, int]: (halogen_count, halogen_threshold). The count covers
        F, Cl, Br and I atoms. The threshold is int(0.2 * mol.GetNumAtoms()),
        but never less than 2.
    """
    halogen_symbols = {"Cl", "Br", "I", "F"}
    count = 0
    for atom in mol.GetAtoms():
        if atom.GetSymbol() in halogen_symbols:
            count += 1

    # Larger molecules tolerate proportionally more halogen substitution;
    # small ones still get a floor of 2.
    scaled = int(mol.GetNumAtoms() * 0.2)
    threshold = scaled if scaled > 2 else 2

    return count, threshold
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def toxic_elements(mol: Mol) -> Optional[List[str]]:
    """
    Identify toxic elements, or toxic forms of elements, present in a molecule.

    Flags:
      * every atom of Pb, Hg, Cd, As, Be, Tl or Sb (reported by symbol);
      * positively charged nitrogen (reported as "N+"), except a nitrogen that is
        itself the center of a ``[N+](C)(C)(C)C`` quaternary ammonium
        (e.g. choline-like). Other N+ atoms in the same molecule are still flagged;
      * each halogen symbol present, when the molecule's total halogen count
        exceeds the threshold from `halogen_toxicity_score`.

    Args:
        mol: RDKit molecule object.

    Returns:
        Optional[List[str]]: Sorted list of flagged symbols (and/or "N+"),
        or None if nothing was flagged.

    Notes:
        Halogen toxicity logic integrates with `halogen_toxicity_score` and scales
        thresholds based on molecule size.
    """
    # Always toxic elements (heavy metals and known toxic single elements)
    always_toxic = {"Pb", "Hg", "Cd", "As", "Be", "Tl", "Sb"}
    halogens = {"Cl", "Br", "I", "F"}

    # Atom indices of nitrogens at the center of a benign quaternary ammonium.
    # The SMARTS starts with the N, so match[0] is that nitrogen. Exempting per
    # atom (rather than per molecule) keeps e.g. a nitro N+ flagged even when a
    # quaternary ammonium is present elsewhere in the molecule.
    quat_pattern = Chem.MolFromSmarts("[N+](C)(C)(C)C")
    benign_quat_n = {match[0] for match in mol.GetSubstructMatches(quat_pattern)}

    # Halogen excess is a whole-molecule property: compute it once, not per atom
    halogen_count, halogen_threshold = halogen_toxicity_score(mol)
    excess_halogens = halogen_count > halogen_threshold

    toxic_found = set()
    for atom in mol.GetAtoms():
        symbol = atom.GetSymbol()
        if symbol in always_toxic:
            toxic_found.add(symbol)
        elif symbol == "N" and atom.GetFormalCharge() > 0 and atom.GetIdx() not in benign_quat_n:
            toxic_found.add("N+")
        elif symbol in halogens and excess_halogens:
            toxic_found.add(symbol)

    # Sort so the output order is deterministic (set iteration order is not)
    return sorted(toxic_found) if toxic_found else None
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def toxic_groups(mol: Chem.Mol) -> Optional[List[str]]:
    """
    Check if a molecule contains known toxic functional groups using RDKit's functional groups and SMARTS patterns.

    Molecules matching any exemption pattern (``compiled_exempt_smarts``) are
    never reported. Otherwise the result lists, in order:
      1. the SMARTS of each matching top-level RDKit functional group among
         Nitro, Azide, Alcohol, Aldehyde, Halogen and TerminalAlkyne
         (groups missing from the loaded hierarchy are skipped);
      2. each matching custom pattern from ``toxic_smarts_patterns``;
      3. "[N+]" if the molecule has an N+ atom and no ``C[N+](C)(C)C`` quaternary ammonium.

    Args:
        mol (rdkit.Chem.Mol): The molecule to evaluate.

    Returns:
        Optional[List[str]]: List of SMARTS patterns for toxic groups if found, otherwise None.
    """
    # Exempt stabilizing functional groups first: an exempt molecule always
    # yields None, so there is no point computing any matches for it.
    for compiled in compiled_exempt_smarts:
        if mol.HasSubstructMatch(compiled):
            return None

    toxic_smarts_matches = []

    # Use RDKit's functional group definitions (top-level hierarchy nodes only)
    toxic_group_names = ["Nitro", "Azide", "Alcohol", "Aldehyde", "Halogen", "TerminalAlkyne"]
    for group_name in toxic_group_names:
        group_node = next((node for node in fgroup_hierarchy if node.label == group_name), None)
        if group_node is None:
            # Label absent from this RDKit hierarchy: skip it rather than letting
            # next() raise StopIteration out of this function.
            continue
        if mol.HasSubstructMatch(Chem.MolFromSmarts(group_node.smarts)):
            toxic_smarts_matches.append(group_node.smarts)  # Use group_node's SMARTS directly

    # Check for custom precompiled toxic SMARTS patterns
    for smarts, compiled in zip(toxic_smarts_patterns, compiled_toxic_smarts):
        if mol.HasSubstructMatch(compiled):
            toxic_smarts_matches.append(smarts)

    # Special handling for N+: flagged unless a benign quaternary ammonium is present
    has_cationic_n = mol.HasSubstructMatch(Chem.MolFromSmarts("[N+]"))
    if has_cationic_n and not mol.HasSubstructMatch(Chem.MolFromSmarts("C[N+](C)(C)C")):
        toxic_smarts_matches.append("[N+]")  # Append as SMARTS

    return toxic_smarts_matches or None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
if __name__ == "__main__":
    # Ad-hoc smoke tests: results are printed, not asserted
    print("Running toxicity detection tests...")

    # name -> (SMILES, human-readable description)
    test_molecules = {
        # Safe molecules
        "water": ("O", "Water - should be safe"),
        "benzene": ("c1ccccc1", "Benzene - simple aromatic"),
        "glucose": ("C(C1C(C(C(C(O1)O)O)O)O)O", "Glucose - sugar"),
        "ethanol": ("CCO", "Ethanol - simple alcohol"),
        # Heavy metal containing
        "lead_acetate": ("CC(=O)[O-].CC(=O)[O-].[Pb+2]", "Lead acetate - contains Pb"),
        "mercury_chloride": ("Cl[Hg]Cl", "Mercury chloride - contains Hg"),
        "arsenic_trioxide": ("O=[As]O[As]=O", "Arsenic trioxide - contains As"),
        # Halogenated compounds
        "chloroform": ("C(Cl)(Cl)Cl", "Chloroform - trichloromethyl"),
        "ddt": ("c1ccc(cc1)C(c2ccc(cc2)Cl)C(Cl)(Cl)Cl", "DDT - heavily chlorinated"),
        "fluorobenzene": ("Fc1ccccc1", "Fluorobenzene - single halogen"),
        # Nitrogen compounds
        "nitrobenzene": ("c1ccc(cc1)[N+](=O)[O-]", "Nitrobenzene - nitro group"),
        "choline": ("C[N+](C)(C)CCO", "Choline - benign quaternary ammonium"),
        "toxic_quat": ("[N+](C)(C)(C)(C)", "Toxic quaternary ammonium"),
        # Phenol (exempt)
        "catechol": ("c1ccc(O)c(O)c1", "Catechol - phenol, should be exempt"),
        # Phosphate
        "phosphate": ("P(=O)(O)(O)O", "Phosphate ester - toxic pattern"),
    }

    def _parsed(names):
        """Yield (name, mol) for each known test molecule whose SMILES parses."""
        for mol_name in names:
            if mol_name not in test_molecules:
                continue
            parsed = Chem.MolFromSmiles(test_molecules[mol_name][0])
            if parsed:
                yield mol_name, parsed

    # 1. Heavy metals detection
    print("\n1. Testing heavy metals detection...")
    metal_expected = {"lead_acetate", "mercury_chloride", "arsenic_trioxide"}
    for name, mol in _parsed(test_molecules):
        has_metals = contains_heavy_metals(mol)
        expected = name in metal_expected
        mark = "✓" if has_metals == expected else "✗"
        print(f" {mark} {name}: {has_metals} (expected: {expected})")

    # 2. Halogen toxicity score
    print("\n2. Testing halogen toxicity scoring...")
    for name, mol in _parsed(["chloroform", "ddt", "fluorobenzene", "benzene"]):
        count, threshold = halogen_toxicity_score(mol)
        print(f" {name}: {count} halogens, threshold: {threshold}, toxic: {count > threshold}")

    # 3. Toxic elements
    print("\n3. Testing toxic elements detection...")
    element_expected = {"lead_acetate", "mercury_chloride", "arsenic_trioxide", "chloroform", "ddt"}
    for name, mol in _parsed(test_molecules):
        toxics = toxic_elements(mol)
        if toxics:
            print(f" ⚠ {name}: {toxics}")
        elif name in element_expected:
            print(f" ✗ {name}: Should have detected toxic elements")
        else:
            print(f" ✓ {name}: No toxic elements (as expected)")

    # 4. Toxic functional groups
    print("\n4. Testing toxic functional groups...")
    group_expected = {"nitrobenzene", "phosphate", "chloroform", "ethanol"}
    for name, mol in _parsed(test_molecules):
        groups = toxic_groups(mol)
        if groups:
            print(f" ⚠ {name}: Found {len(groups)} toxic group(s)")
            for pattern in groups[:3]:  # Show first 3 patterns
                print(f" - {pattern[:50]}...")
        elif name == "catechol":
            print(f" ✓ {name}: Exempt (phenol)")
        elif name in group_expected:
            print(f" ✗ {name}: Should have detected toxic groups")
        else:
            print(f" ✓ {name}: No toxic groups")

    # 5. Edge cases
    print("\n5. Testing edge cases...")
    edge_cases = [
        ("", "Empty SMILES"),
        ("INVALID", "Invalid SMILES"),
        ("C" * 100, "Very long carbon chain"),
        ("[N+](C)(C)(C)C", "Benign quaternary ammonium"),
    ]

    for smiles, desc in edge_cases:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            print(f" {desc}: Invalid molecule (as expected)")
            continue
        metals = contains_heavy_metals(mol)
        elements = toxic_elements(mol)
        groups = toxic_groups(mol)
        print(f" {desc}: metals={metals}, elements={elements is not None}, groups={groups is not None}")

    print("\n✅ All toxicity detection tests completed!")
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""Molecular visualization utilities for Workbench"""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import base64
|
|
5
|
+
import re
|
|
6
|
+
from typing import Optional, Tuple
|
|
7
|
+
from rdkit import Chem
|
|
8
|
+
from rdkit.Chem import AllChem, Draw
|
|
9
|
+
from rdkit.Chem.Draw import rdMolDraw2D
|
|
10
|
+
|
|
11
|
+
# Module-level logger, shared by name with the rest of the Workbench package
log = logging.getLogger(name="workbench")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _is_dark(color: str) -> bool:
|
|
16
|
+
"""Determine if an rgba color is dark based on RGB average.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
color: Color in rgba(...) format
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
True if the color is dark, False otherwise
|
|
23
|
+
"""
|
|
24
|
+
match = re.match(r"rgba?\((\d+),\s*(\d+),\s*(\d+)", color)
|
|
25
|
+
if not match:
|
|
26
|
+
log.warning(f"Invalid color format: {color}, defaulting to dark")
|
|
27
|
+
return True # Default to dark mode on error
|
|
28
|
+
|
|
29
|
+
r, g, b = map(int, match.groups())
|
|
30
|
+
return (r + g + b) / 3 < 128
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _rgba_to_tuple(rgba: str) -> Tuple[float, float, float, float]:
|
|
34
|
+
"""Convert rgba string to normalized tuple (R, G, B, A).
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
rgba: RGBA color string (e.g., "rgba(255, 0, 0, 0.5)")
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
Normalized tuple of (R, G, B, A) with RGB in [0, 1]
|
|
41
|
+
"""
|
|
42
|
+
try:
|
|
43
|
+
components = rgba.strip("rgba() ").split(",")
|
|
44
|
+
r, g, b = (int(components[i]) / 255 for i in range(3))
|
|
45
|
+
a = float(components[3]) if len(components) > 3 else 1.0
|
|
46
|
+
return r, g, b, a
|
|
47
|
+
except (IndexError, ValueError) as e:
|
|
48
|
+
log.warning(f"Error parsing color '{rgba}': {e}, using default")
|
|
49
|
+
return 0.25, 0.25, 0.25, 1.0 # Default dark grey
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _validate_molecule(smiles: str) -> Optional[Chem.Mol]:
    """Parse a SMILES string into an RDKit molecule, rejecting unusable input.

    RDKit returns None for SMILES it cannot parse, but returns an empty
    (zero-atom) molecule for an empty string. Both are treated as invalid here,
    so callers get None instead of rendering a blank image.

    Args:
        smiles: SMILES string

    Returns:
        RDKit molecule, or None if the SMILES is invalid, yields no atoms, or
        raises during parsing
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception as e:
        log.error(f"Error parsing SMILES '{smiles}': {e}")
        return None

    if mol is None or mol.GetNumAtoms() == 0:
        log.warning(f"Invalid SMILES: {smiles}")
        return None
    return mol
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _configure_draw_options(options: Draw.MolDrawOptions, background: str) -> None:
    """Apply palette and background color to an RDKit drawing-options object.

    For a dark background, RDKit's dark-mode palette is enabled first. In every
    case the background colour is then set from the parsed color string.

    Args:
        options: RDKit drawing options object (modified in place)
        background: Background color string
    """
    dark_background = _is_dark(background)
    if dark_background:
        rdMolDraw2D.SetDarkMode(options)
    rgba = _rgba_to_tuple(background)
    options.setBackgroundColour(rgba)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def img_from_smiles(
    smiles: str, width: int = 500, height: int = 500, background: str = "rgba(64, 64, 64, 1)"
) -> Optional:
    """Render the molecule described by a SMILES string to a raster image.

    Args:
        smiles: SMILES string representing the molecule
        width: Width of the image in pixels (default: 500)
        height: Height of the image in pixels (default: 500)
        background: Background color (default: dark grey)

    Returns:
        The image returned by Draw.MolToImage (a PIL Image), or None when
        `_validate_molecule` rejects the SMILES
    """
    mol = _validate_molecule(smiles)
    if mol:
        draw_opts = Draw.MolDrawOptions()
        _configure_draw_options(draw_opts, background)
        return Draw.MolToImage(mol, options=draw_opts, size=(width, height))
    return None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def svg_from_smiles(
    smiles: str, width: int = 500, height: int = 500, background: str = "rgba(64, 64, 64, 1)"
) -> Optional[str]:
    """Render the molecule described by a SMILES string as an SVG data URI.

    Args:
        smiles: SMILES string representing the molecule
        width: Width of the image in pixels (default: 500)
        height: Height of the image in pixels (default: 500)
        background: Background color (default: dark grey)

    Returns:
        "data:image/svg+xml;base64,..." string, or None when `_validate_molecule`
        rejects the SMILES
    """
    mol = _validate_molecule(smiles)
    if not mol:
        return None

    # 2D layout, then draw with the shared option configuration
    AllChem.Compute2DCoords(mol)
    svg_drawer = rdMolDraw2D.MolDraw2DSVG(width, height)
    _configure_draw_options(svg_drawer.drawOptions(), background)
    svg_drawer.DrawMolecule(mol)
    svg_drawer.FinishDrawing()

    svg_bytes = svg_drawer.GetDrawingText().encode("utf-8")
    return "data:image/svg+xml;base64," + base64.b64encode(svg_bytes).decode("utf-8")
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def show(smiles: str, width: int = 500, height: int = 500, background: str = "rgba(64, 64, 64, 1)") -> None:
    """Open the rendered molecule image via its ``show()`` method.

    Logs an error instead when no image could be produced for the SMILES.

    Args:
        smiles: SMILES string representing the molecule
        width: Width of the image in pixels (default: 500)
        height: Height of the image in pixels (default: 500)
        background: Background color (default: dark grey)
    """
    img = img_from_smiles(smiles, width, height, background)
    if not img:
        log.error(f"Cannot display molecule for SMILES: {smiles}")
        return
    img.show()
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
if __name__ == "__main__":
    # Ad-hoc smoke tests: results are printed, not asserted
    print("Running molecular visualization tests...")

    test_molecules = {
        "benzene": "c1ccccc1",
        "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
        "invalid": "not_a_smiles",
        "empty": "",
    }
    valid_names = [key for key in test_molecules if key not in ("invalid", "empty")]

    def _mark(ok) -> str:
        """Check mark for a truthy result, cross otherwise."""
        return "✓" if ok else "✗"

    # 1. Raster images for every valid SMILES
    print("\n1. Testing image generation from valid SMILES...")
    for name in valid_names:
        img = img_from_smiles(test_molecules[name], width=200, height=200)
        print(f" {_mark(img)} {name}: {'Success' if img else 'Failed'}")

    # 2. Invalid and empty SMILES should yield None
    print("\n2. Testing invalid SMILES handling...")
    for key, label in (("invalid", "Invalid"), ("empty", "Empty")):
        img = img_from_smiles(test_molecules[key])
        is_none = img is None
        print(f" {_mark(is_none)} {label} SMILES returns None: {is_none}")

    # 3. SVG data URIs
    print("\n3. Testing SVG generation...")
    for name in valid_names:
        svg = svg_from_smiles(test_molecules[name], width=200, height=200)
        is_valid = svg and svg.startswith("data:image/svg+xml;base64,")
        print(f" {_mark(is_valid)} {name}: {'Valid SVG data URI' if is_valid else 'Failed'}")

    # 4. Background colors
    print("\n4. Testing different background colors...")
    backgrounds = [
        ("Light", "rgba(255, 255, 255, 1)"),
        ("Dark", "rgba(0, 0, 0, 1)"),
        ("Custom", "rgba(100, 150, 200, 0.8)"),
    ]

    for bg_name, bg_color in backgrounds:
        img = img_from_smiles(test_molecules["benzene"], background=bg_color)
        print(f" {_mark(img)} {bg_name} background: {'Success' if img else 'Failed'}")

    # 5. Image sizes
    print("\n5. Testing different image sizes...")
    sizes = [(100, 100), (500, 500), (1000, 800)]

    for w, h in sizes:
        img = img_from_smiles(test_molecules["caffeine"], width=w, height=h)
        print(f" {_mark(img)} Size {w}x{h}: {'Success' if img else 'Failed'}")

    # 6. Color parsing helpers
    print("\n6. Testing color utility functions...")
    test_colors = [
        ("invalid_color", True, (0.25, 0.25, 0.25, 1.0)),  # Should use defaults
        ("rgba(255, 255, 255, 1)", False, (1.0, 1.0, 1.0, 1.0)),
        ("rgba(0, 0, 0, 1)", True, (0.0, 0.0, 0.0, 1.0)),
        ("rgba(64, 64, 64, 0.5)", True, (0.251, 0.251, 0.251, 0.5)),
        ("rgb(128, 128, 128)", False, (0.502, 0.502, 0.502, 1.0)),
    ]

    for color, expected_dark, expected_tuple in test_colors:
        is_dark_result = _is_dark(color)
        tuple_result = _rgba_to_tuple(color)

        print(f" {_mark(is_dark_result == expected_dark)} is_dark('{color[:20]}...'): {is_dark_result} == {expected_dark}")

        # Expected tuples are rounded, so compare with a small tolerance
        tuple_match = all(abs(got - want) < 0.01 for got, want in zip(tuple_result, expected_tuple))
        print(f" {_mark(tuple_match)} rgba_to_tuple('{color[:20]}...'): matches expected")

    # 7. show() opens image windows
    print("\n7. Testing show function (will open image windows)...")
    try:
        show(test_molecules["aspirin"])
        show(test_molecules["aspirin"], background="rgba(220, 220, 220, 1)")
        print(" ✓ show() function executed (check for image window)")
    except Exception as e:
        print(f" ✗ show() function failed: {e}")

    print("\n✅ All tests completed!")
|
workbench/utils/monitor_utils.py
CHANGED
|
@@ -14,7 +14,7 @@ from workbench.utils.s3_utils import read_content_from_s3
|
|
|
14
14
|
log = logging.getLogger("workbench")
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def
|
|
17
|
+
def pull_data_capture_for_testing(data_capture_path, max_files=1) -> Union[pd.DataFrame, None]:
|
|
18
18
|
"""
|
|
19
19
|
Read and process captured data from S3.
|
|
20
20
|
|
|
@@ -26,7 +26,12 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
|
|
|
26
26
|
|
|
27
27
|
Returns:
|
|
28
28
|
Union[pd.DataFrame, None]: A dataframe of the captured data (or None if no data is found).
|
|
29
|
+
|
|
30
|
+
Notes:
|
|
31
|
+
This method is really only for testing and debugging.
|
|
29
32
|
"""
|
|
33
|
+
log.important("This method is for testing and debugging only.")
|
|
34
|
+
|
|
30
35
|
# List files in the specified S3 path
|
|
31
36
|
files = wr.s3.list_objects(data_capture_path)
|
|
32
37
|
if not files:
|
|
@@ -64,7 +69,8 @@ def pull_data_capture(data_capture_path, max_files=1) -> Union[pd.DataFrame, Non
|
|
|
64
69
|
def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
65
70
|
"""
|
|
66
71
|
Process the captured data DataFrame to extract input and output data.
|
|
67
|
-
|
|
72
|
+
Handles cases where input or output might not be captured.
|
|
73
|
+
|
|
68
74
|
Args:
|
|
69
75
|
df (DataFrame): DataFrame with captured data.
|
|
70
76
|
Returns:
|
|
@@ -77,46 +83,50 @@ def process_data_capture(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
|
77
83
|
try:
|
|
78
84
|
capture_data = row["captureData"]
|
|
79
85
|
|
|
80
|
-
#
|
|
81
|
-
if "endpointInput"
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
elif input_data["encoding"].upper() == "JSON":
|
|
96
|
-
json_data = json.loads(input_data["data"])
|
|
97
|
-
if isinstance(json_data, dict):
|
|
98
|
-
input_df = pd.DataFrame({k: [v] if not isinstance(v, list) else v for k, v in json_data.items()})
|
|
86
|
+
# Process input data if present
|
|
87
|
+
if "endpointInput" in capture_data:
|
|
88
|
+
input_data = capture_data["endpointInput"]
|
|
89
|
+
encoding = input_data["encoding"].upper()
|
|
90
|
+
|
|
91
|
+
if encoding == "CSV":
|
|
92
|
+
input_df = pd.read_csv(StringIO(input_data["data"]))
|
|
93
|
+
elif encoding == "JSON":
|
|
94
|
+
json_data = json.loads(input_data["data"])
|
|
95
|
+
if isinstance(json_data, dict):
|
|
96
|
+
input_df = pd.DataFrame(
|
|
97
|
+
{k: [v] if not isinstance(v, list) else v for k, v in json_data.items()}
|
|
98
|
+
)
|
|
99
|
+
else:
|
|
100
|
+
input_df = pd.DataFrame(json_data)
|
|
99
101
|
else:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
102
|
+
continue # Skip unknown encodings
|
|
103
|
+
|
|
104
|
+
input_dfs.append(input_df)
|
|
105
|
+
|
|
106
|
+
# Process output data if present
|
|
107
|
+
if "endpointOutput" in capture_data:
|
|
108
|
+
output_data = capture_data["endpointOutput"]
|
|
109
|
+
encoding = output_data["encoding"].upper()
|
|
110
|
+
|
|
111
|
+
if encoding == "CSV":
|
|
112
|
+
output_df = pd.read_csv(StringIO(output_data["data"]))
|
|
113
|
+
elif encoding == "JSON":
|
|
114
|
+
json_data = json.loads(output_data["data"])
|
|
115
|
+
if isinstance(json_data, dict):
|
|
116
|
+
output_df = pd.DataFrame(
|
|
117
|
+
{k: [v] if not isinstance(v, list) else v for k, v in json_data.items()}
|
|
118
|
+
)
|
|
119
|
+
else:
|
|
120
|
+
output_df = pd.DataFrame(json_data)
|
|
110
121
|
else:
|
|
111
|
-
|
|
122
|
+
continue # Skip unknown encodings
|
|
112
123
|
|
|
113
|
-
|
|
114
|
-
input_dfs.append(input_df)
|
|
115
|
-
output_dfs.append(output_df)
|
|
124
|
+
output_dfs.append(output_df)
|
|
116
125
|
|
|
117
126
|
except Exception as e:
|
|
118
|
-
log.
|
|
127
|
+
log.debug(f"Row {idx}: Failed to process row: {e}")
|
|
119
128
|
continue
|
|
129
|
+
|
|
120
130
|
# Combine and return results
|
|
121
131
|
return (
|
|
122
132
|
pd.concat(input_dfs, ignore_index=True) if input_dfs else pd.DataFrame(),
|
|
@@ -178,23 +188,6 @@ def parse_monitoring_results(results_json: str) -> Dict[str, Any]:
|
|
|
178
188
|
return {"error": str(e)}
|
|
179
189
|
|
|
180
190
|
|
|
181
|
-
"""TEMP
|
|
182
|
-
# If the status is "CompletedWithViolations", we grab the lastest
|
|
183
|
-
# violation file and add it to the result
|
|
184
|
-
if status == "CompletedWithViolations":
|
|
185
|
-
violation_file = f"{self.monitoring_path}/
|
|
186
|
-
{last_run['CreationTime'].strftime('%Y/%m/%d')}/constraint_violations.json"
|
|
187
|
-
if wr.s3.does_object_exist(violation_file):
|
|
188
|
-
violations_json = read_content_from_s3(violation_file)
|
|
189
|
-
violations = parse_monitoring_results(violations_json)
|
|
190
|
-
result["violations"] = violations.get("constraint_violations", [])
|
|
191
|
-
result["violation_count"] = len(result["violations"])
|
|
192
|
-
else:
|
|
193
|
-
result["violations"] = []
|
|
194
|
-
result["violation_count"] = 0
|
|
195
|
-
"""
|
|
196
|
-
|
|
197
|
-
|
|
198
191
|
def preprocessing_script(feature_list: list[str]) -> str:
|
|
199
192
|
"""
|
|
200
193
|
A preprocessing script for monitoring jobs.
|
|
@@ -245,8 +238,8 @@ if __name__ == "__main__":
|
|
|
245
238
|
from workbench.api.monitor import Monitor
|
|
246
239
|
|
|
247
240
|
# Test pulling data capture
|
|
248
|
-
mon = Monitor("
|
|
249
|
-
df =
|
|
241
|
+
mon = Monitor("abalone-regression-rt")
|
|
242
|
+
df = pull_data_capture_for_testing(mon.data_capture_path)
|
|
250
243
|
print("Data Capture:")
|
|
251
244
|
print(df.head())
|
|
252
245
|
|
|
@@ -262,4 +255,4 @@ if __name__ == "__main__":
|
|
|
262
255
|
# Test preprocessing script
|
|
263
256
|
script = preprocessing_script(["feature1", "feature2", "feature3"])
|
|
264
257
|
print("\nPreprocessing Script:")
|
|
265
|
-
print(script)
|
|
258
|
+
# print(script)
|
workbench/utils/pandas_utils.py
CHANGED
|
@@ -152,7 +152,7 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
|
|
|
152
152
|
|
|
153
153
|
# Check for differences in common columns
|
|
154
154
|
for column in common_columns:
|
|
155
|
-
if pd.api.types.is_string_dtype(df1[column])
|
|
155
|
+
if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
|
|
156
156
|
# String comparison with NaNs treated as equal
|
|
157
157
|
differences = ~(df1[column].fillna("") == df2[column].fillna(""))
|
|
158
158
|
elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
|
|
@@ -161,8 +161,8 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
|
|
|
161
161
|
pd.isna(df1[column]) & pd.isna(df2[column])
|
|
162
162
|
)
|
|
163
163
|
else:
|
|
164
|
-
# Other types (
|
|
165
|
-
differences =
|
|
164
|
+
# Other types (int, Int64, etc.) - compare with NaNs treated as equal
|
|
165
|
+
differences = (df1[column] != df2[column]) & ~(pd.isna(df1[column]) & pd.isna(df2[column]))
|
|
166
166
|
|
|
167
167
|
# If differences exist, display them
|
|
168
168
|
if differences.any():
|