workbench 0.8.162__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (147)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  3. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  4. workbench/algorithms/dataframe/fingerprint_proximity.py +422 -86
  5. workbench/algorithms/dataframe/projection_2d.py +44 -21
  6. workbench/algorithms/dataframe/proximity.py +259 -305
  7. workbench/algorithms/graph/light/proximity_graph.py +14 -12
  8. workbench/algorithms/models/cleanlab_model.py +382 -0
  9. workbench/algorithms/models/noise_model.py +388 -0
  10. workbench/algorithms/sql/outliers.py +3 -3
  11. workbench/api/__init__.py +5 -1
  12. workbench/api/compound.py +1 -1
  13. workbench/api/df_store.py +17 -108
  14. workbench/api/endpoint.py +18 -5
  15. workbench/api/feature_set.py +121 -15
  16. workbench/api/meta.py +5 -2
  17. workbench/api/meta_model.py +289 -0
  18. workbench/api/model.py +55 -21
  19. workbench/api/monitor.py +1 -16
  20. workbench/api/parameter_store.py +3 -52
  21. workbench/cached/cached_model.py +4 -4
  22. workbench/core/artifacts/__init__.py +11 -2
  23. workbench/core/artifacts/artifact.py +16 -8
  24. workbench/core/artifacts/data_capture_core.py +355 -0
  25. workbench/core/artifacts/df_store_core.py +114 -0
  26. workbench/core/artifacts/endpoint_core.py +382 -253
  27. workbench/core/artifacts/feature_set_core.py +249 -45
  28. workbench/core/artifacts/model_core.py +135 -80
  29. workbench/core/artifacts/monitor_core.py +33 -248
  30. workbench/core/artifacts/parameter_store_core.py +98 -0
  31. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  32. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  33. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  34. workbench/core/pipelines/pipeline_executor.py +1 -1
  35. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  36. workbench/core/transforms/features_to_model/features_to_model.py +62 -40
  37. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +76 -15
  38. workbench/core/transforms/pandas_transforms/pandas_to_features.py +38 -2
  39. workbench/core/views/training_view.py +113 -42
  40. workbench/core/views/view.py +53 -3
  41. workbench/core/views/view_utils.py +4 -4
  42. workbench/model_script_utils/model_script_utils.py +339 -0
  43. workbench/model_script_utils/pytorch_utils.py +405 -0
  44. workbench/model_script_utils/uq_harness.py +278 -0
  45. workbench/model_scripts/chemprop/chemprop.template +649 -0
  46. workbench/model_scripts/chemprop/generated_model_script.py +649 -0
  47. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  48. workbench/model_scripts/chemprop/requirements.txt +3 -0
  49. workbench/model_scripts/custom_models/chem_info/fingerprints.py +175 -0
  50. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  51. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  52. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  53. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  54. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  55. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +8 -10
  56. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  57. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  58. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  59. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  60. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  61. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  62. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  63. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  64. workbench/model_scripts/meta_model/meta_model.template +209 -0
  65. workbench/model_scripts/pytorch_model/generated_model_script.py +444 -500
  66. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  67. workbench/model_scripts/pytorch_model/pytorch.template +440 -496
  68. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  69. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  70. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  71. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  72. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  73. workbench/model_scripts/script_generation.py +20 -11
  74. workbench/model_scripts/uq_models/generated_model_script.py +248 -0
  75. workbench/model_scripts/xgb_model/generated_model_script.py +372 -404
  76. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  77. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  78. workbench/model_scripts/xgb_model/xgb_model.template +369 -401
  79. workbench/repl/workbench_shell.py +28 -19
  80. workbench/resources/open_source_api.key +1 -1
  81. workbench/scripts/endpoint_test.py +162 -0
  82. workbench/scripts/lambda_test.py +73 -0
  83. workbench/scripts/meta_model_sim.py +35 -0
  84. workbench/scripts/ml_pipeline_batch.py +137 -0
  85. workbench/scripts/ml_pipeline_sqs.py +186 -0
  86. workbench/scripts/monitor_cloud_watch.py +20 -100
  87. workbench/scripts/training_test.py +85 -0
  88. workbench/utils/aws_utils.py +4 -3
  89. workbench/utils/chem_utils/__init__.py +0 -0
  90. workbench/utils/chem_utils/fingerprints.py +175 -0
  91. workbench/utils/chem_utils/misc.py +194 -0
  92. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  93. workbench/utils/chem_utils/mol_standardize.py +450 -0
  94. workbench/utils/chem_utils/mol_tagging.py +348 -0
  95. workbench/utils/chem_utils/projections.py +219 -0
  96. workbench/utils/chem_utils/salts.py +256 -0
  97. workbench/utils/chem_utils/sdf.py +292 -0
  98. workbench/utils/chem_utils/toxicity.py +250 -0
  99. workbench/utils/chem_utils/vis.py +253 -0
  100. workbench/utils/chemprop_utils.py +141 -0
  101. workbench/utils/cloudwatch_handler.py +1 -1
  102. workbench/utils/cloudwatch_utils.py +137 -0
  103. workbench/utils/config_manager.py +3 -7
  104. workbench/utils/endpoint_utils.py +5 -7
  105. workbench/utils/license_manager.py +2 -6
  106. workbench/utils/meta_model_simulator.py +499 -0
  107. workbench/utils/metrics_utils.py +256 -0
  108. workbench/utils/model_utils.py +278 -79
  109. workbench/utils/monitor_utils.py +44 -62
  110. workbench/utils/pandas_utils.py +3 -3
  111. workbench/utils/pytorch_utils.py +87 -0
  112. workbench/utils/shap_utils.py +11 -57
  113. workbench/utils/workbench_logging.py +0 -3
  114. workbench/utils/workbench_sqs.py +1 -1
  115. workbench/utils/xgboost_local_crossfold.py +267 -0
  116. workbench/utils/xgboost_model_utils.py +127 -219
  117. workbench/web_interface/components/model_plot.py +14 -2
  118. workbench/web_interface/components/plugin_unit_test.py +5 -2
  119. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  120. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  121. workbench/web_interface/components/plugins/model_details.py +38 -74
  122. workbench/web_interface/components/plugins/scatter_plot.py +6 -10
  123. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/METADATA +31 -9
  124. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/RECORD +128 -96
  125. workbench-0.8.220.dist-info/entry_points.txt +11 -0
  126. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +1 -1
  127. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  128. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -280
  129. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  130. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  131. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  132. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  133. workbench/model_scripts/custom_models/proximity/proximity.py +0 -384
  134. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  135. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  136. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -273
  137. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -384
  138. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  139. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  140. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  141. workbench/utils/chem_utils.py +0 -1556
  142. workbench/utils/execution_environment.py +0 -211
  143. workbench/utils/fast_inference.py +0 -167
  144. workbench/utils/resource_utils.py +0 -39
  145. workbench-0.8.162.dist-info/entry_points.txt +0 -5
  146. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  147. {workbench-0.8.162.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
1
+ """Miscellaneous processing functions for molecular data."""
2
+
3
+ import logging
4
+ import numpy as np
5
+ import pandas as pd
6
+ from typing import List, Optional
7
+
8
+ # Set up the logger
9
+ log = logging.getLogger("workbench")
10
+
11
+
12
def geometric_mean(series: pd.Series) -> float:
    """Compute the geometric mean via log-space averaging (avoids a scipy dependency)."""
    mean_log = np.log(series).mean()
    return np.exp(mean_log)
15
+
16
+
17
def rollup_experimental_data(
    df: pd.DataFrame, id: str, time: str, target: str, use_gmean: bool = False
) -> pd.DataFrame:
    """
    Roll up a dataset by keeping, for each unique ID, only the rows at the latest
    time point, averaging the target when several records share that time.
    Supports both arithmetic and geometric mean.

    Parameters:
        df (pd.DataFrame): Input dataframe.
        id (str): Column representing the unique molecule ID.
        time (str): Column representing the time.
        target (str): Column representing the target value.
        use_gmean (bool): Whether to use the geometric mean instead of the arithmetic mean.

    Returns:
        pd.DataFrame: Rolled-up dataframe with all original columns retained.
    """
    # Keep only rows whose time equals the per-ID maximum
    latest_rows = df[df[time] == df.groupby(id)[time].transform("max")]

    # Mean function applied when an (id, time) group holds multiple records
    mean_fn = geometric_mean if use_gmean else np.mean

    # Non-key, non-target columns simply take their first value per group
    agg_spec = {c: "first" for c in df.columns if c not in (target, id, time)}
    agg_spec[target] = lambda s: mean_fn(s) if len(s) > 1 else s.iloc[0]

    return latest_rows.groupby([id, time]).agg(agg_spec).reset_index()
47
+
48
+
49
def micromolar_to_log(series_µM: pd.Series) -> pd.Series:
    """
    Convert a pandas Series of concentrations in µM (micromolar) to their log10 values.

    Parameters:
        series_µM (pd.Series): Series of concentrations in micromolar.

    Returns:
        pd.Series: Series of logarithmic values (log10).
    """
    # Floor at a tiny positive value so zero/negative entries don't break log10
    floored = series_µM.clip(lower=1e-9)  # Alignment with another project

    # µM → mol/L, then take log10
    return np.log10(floored * 1e-6)
65
+
66
+
67
def log_to_micromolar(log_series: pd.Series) -> pd.Series:
    """
    Convert a pandas Series of logarithmic values (log10) back to concentrations in µM (micromolar).

    Parameters:
        log_series (pd.Series): Series of logarithmic values (log10).

    Returns:
        pd.Series: Series of concentrations in micromolar.
    """
    # Invert the log10: mol/L, then scale up to µM
    mol_per_l = np.power(10.0, log_series)
    return mol_per_l * 1e6
80
+
81
+
82
def feature_resolution_issues(df: pd.DataFrame, features: List[str], show_cols: Optional[List[str]] = None) -> None:
    """
    Print groups where the given feature combination maps to more than one unique
    SMILES string, largest groups (most unique SMILES) first.

    Args:
        df (pd.DataFrame): Input DataFrame containing SMILES strings.
        features (List[str]): List of features to check.
        show_cols (Optional[List[str]]): Columns to display; defaults to all columns.

    Raises:
        ValueError: If the DataFrame has no 'smiles' column (any capitalization).
    """
    # Locate the SMILES column regardless of capitalization
    smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
    if smiles_col is None:
        raise ValueError("Input DataFrame must have a 'smiles' column")

    if show_cols is None:
        show_cols = df.columns.tolist()

    # One row per unique (SMILES, feature-values) combination
    deduped = df.drop_duplicates(subset=[smiles_col] + features)

    # Feature groups that collide across multiple SMILES, biggest first
    sizes = deduped.groupby(features).size()
    collisions = sizes[sizes > 1].sort_values(ascending=False)

    for group, count in collisions.items():
        # Build the row mask for this group (tuple for multi-feature keys)
        if isinstance(group, tuple):
            mask = (deduped[features] == group).all(axis=1)
        else:
            mask = deduped[features[0]] == group

        group_df = deduped[mask]

        print(f"Feature Group (unique SMILES: {count}):")
        print(group_df[show_cols])
        print("\n")
119
+
120
+
121
if __name__ == "__main__":
    # Manual smoke-test harness for the helpers above; exercises each function
    # with small, self-contained inputs and prints the results for eyeballing.
    print("Running molecular processing and transformation tests...")
    print("Note: This requires the molecular_filters module to be available")

    # Test 1: Concentration conversions
    print("\n1. Testing concentration conversions...")

    # Test micromolar to log (and the round-trip back to µM)
    test_conc = pd.Series([1.0, 10.0, 100.0, 1000.0, 0.001])
    log_values = micromolar_to_log(test_conc)
    back_to_uM = log_to_micromolar(log_values)

    print(" µM → log10 → µM:")
    for orig, log_val, back in zip(test_conc, log_values, back_to_uM):
        print(f" {orig:8.3f} µM → {log_val:6.2f} → {back:8.3f} µM")

    # Test 2: Geometric mean vs. arithmetic mean on a geometric progression
    print("\n2. Testing geometric mean...")
    test_series = pd.Series([2, 4, 8, 16])
    geo_mean = geometric_mean(test_series)
    arith_mean = np.mean(test_series)
    print(f" Series: {list(test_series)}")
    print(f" Arithmetic mean: {arith_mean:.2f}")
    print(f" Geometric mean: {geo_mean:.2f}")

    # Test 3: Experimental data rollup
    print("\n3. Testing experimental data rollup...")

    # Create test data with multiple timepoints and replicates
    test_data = pd.DataFrame(
        {
            "compound_id": ["A", "A", "A", "B", "B", "C", "C", "C"],
            "time": [1, 2, 2, 1, 2, 1, 1, 2],
            "activity": [10, 20, 22, 5, 8, 100, 110, 200],
            "assay": ["kinase", "kinase", "kinase", "kinase", "kinase", "cell", "cell", "cell"],
        }
    )

    # Rollup with arithmetic mean
    rolled_arith = rollup_experimental_data(test_data, "compound_id", "time", "activity", use_gmean=False)
    print(" Arithmetic mean rollup:")
    print(rolled_arith[["compound_id", "time", "activity"]])

    # Rollup with geometric mean
    rolled_geo = rollup_experimental_data(test_data, "compound_id", "time", "activity", use_gmean=True)
    print("\n Geometric mean rollup:")
    print(rolled_geo[["compound_id", "time", "activity"]])

    # Test 4: Feature resolution issues
    print("\n4. Testing feature resolution identification...")

    # Create data with some duplicate features but different SMILES
    resolution_df = pd.DataFrame(
        {
            "smiles": ["CCO", "C(C)O", "CC(C)O", "CCC(C)O", "CCCO"],
            "assay_id": ["A1", "A1", "A2", "A2", "A3"],
            "value": [1.0, 1.5, 2.0, 2.2, 3.0],
        }
    )

    print(" Checking for feature collisions in 'assay_id':")
    feature_resolution_issues(resolution_df, ["assay_id"], show_cols=["smiles", "assay_id", "value"])

    # Test 7: Edge cases
    # NOTE(review): numbering jumps from 4 to 7 — tests 5-6 appear to have been
    # removed; the printed labels below still say "7" to match this history.
    print("\n7. Testing edge cases...")

    # Zero and negative concentrations (exercises the clip-to-1e-9 floor)
    edge_conc = pd.Series([0, -1, 1e-10])
    edge_log = micromolar_to_log(edge_conc)
    print(" Edge concentration handling:")
    for c, l in zip(edge_conc, edge_log):
        print(f" {c:6.2e} µM → {l:6.2f}")

    print("\n✅ All molecular processing tests completed!")
@@ -0,0 +1,483 @@
1
+ """
2
+ mol_descriptors.py - Molecular descriptor computation for ADMET modeling
3
+
4
+ Purpose:
5
+ Computes comprehensive molecular descriptors for ADMET (Absorption, Distribution,
6
+ Metabolism, Excretion, Toxicity) property prediction. Combines RDKit's full
7
+ descriptor set with selected Mordred descriptors and custom stereochemistry features.
8
+
9
+ Descriptor Categories:
10
+ 1. RDKit Descriptors (~220 descriptors)
11
+ - Constitutional (MW, heavy atom count, rotatable bonds)
12
+ - Topological (Balaban J, Kappa indices, Chi indices)
13
+ - Geometric (radius of gyration, spherocity)
14
+ - Electronic (HOMO/LUMO estimates, partial charges)
15
+ - Lipophilicity (LogP, MolLogP)
16
+ - Pharmacophore (H-bond donors/acceptors, aromatic rings)
17
+ - ADMET-specific (TPSA, QED, Lipinski descriptors)
18
+
19
+ 2. Mordred Descriptors (~85 descriptors from 5 ADMET-relevant modules)
20
+ - AcidBase module: pH-dependent properties (nAcid, nBase)
21
+ - Aromatic module: CYP metabolism features (nAromAtom, nAromBond)
22
+ - Constitutional module: Structural complexity (~40 descriptors including nSpiro, nBridgehead)
23
+ - Chi module: Molecular connectivity indices (~42 descriptors, Chi0-Chi4 variants)
24
+ - CarbonTypes module: Carbon hybridization states for metabolism (~20 descriptors)
25
+
26
+ 3. Stereochemistry Features (10 custom descriptors)
27
+ - Stereocenter counts (R/S, defined/undefined)
28
+ - Stereobond counts (E/Z, defined/undefined)
29
+ - Stereochemical complexity and coverage metrics
30
+ - Critical for distinguishing drug enantiomers/diastereomers
31
+
32
+ Pipeline Integration:
33
+ This module expects standardized SMILES from mol_standardize.py:
34
+
35
+ 1. Standardize structures (mol_standardize.py)
36
+
37
+ 2. Compute descriptors (this module)
38
+
39
+ 3. Feature selection/ML modeling
40
+
41
+ Output:
42
+ Returns input DataFrame with added descriptor columns:
43
+ - ~220 RDKit descriptors
44
+ - ~85 Mordred descriptors (from 5 modules)
45
+ - 10 stereochemistry descriptors
46
+ Total: ~310 descriptors
47
+
48
+ Invalid molecules receive NaN values for all descriptors.
49
+
50
+ Performance Notes:
51
+ - RDKit descriptors: Fast, vectorized computation
52
+ - Mordred descriptors: Moderate speed
53
+ - Stereochemistry: Moderate speed, requires CIP labeling
54
+ - Memory: <1GB per 10,000 molecules with all descriptors
55
+
56
+ Special Considerations:
57
+ - Ipc descriptor excluded due to potential overflow issues
58
+ - Molecules failing descriptor calculation get NaN (not dropped)
59
+ - Stereochemistry features optional for non-chiral datasets
60
+ - Salt information from standardization not included in descriptors
61
+ (use separately as categorical feature if needed)
62
+ - Feature selection recommended due to descriptor redundancy
63
+
64
+ Example Usage:
65
+ import pandas as pd
66
+ from mol_standardize import standardize_dataframe
67
+ from mol_descriptors import compute_descriptors
68
+
69
+ # Standard pipeline
70
+ df = pd.read_csv("molecules.csv")
71
+ df = standardize_dataframe(df) # Standardize first
72
+ df = compute_descriptors(df) # Then compute descriptors
73
+
74
+ # For achiral molecules (faster)
75
+ df = compute_descriptors(df, include_stereo=False)
76
+
77
+ # Custom SMILES column
78
+ df = compute_descriptors(df, smiles_column='canonical_smiles')
79
+
80
+ # The resulting DataFrame is ready for ML modeling
81
+ X = df.select_dtypes(include=[np.number]) # All numeric descriptors
82
+ y = df['activity'] # Your target variable
83
+
84
+ References:
85
+ - RDKit descriptors: https://www.rdkit.org/docs/GettingStartedInPython.html#descriptors
86
+ - Mordred: https://github.com/mordred-descriptor/mordred
87
+ - Stereochemistry in drug discovery: https://doi.org/10.1021/acs.jmedchem.0c00915
88
+ """
89
+
90
+ import logging
91
+ import pandas as pd
92
+ import numpy as np
93
+ import re
94
+ import time
95
+ from contextlib import contextmanager
96
+ from rdkit import Chem
97
+ from rdkit.Chem import Descriptors, rdCIPLabeler
98
+ from rdkit.ML.Descriptors import MoleculeDescriptors
99
+ from mordred import Calculator as MordredCalculator
100
+ from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes
101
+
102
+
103
+ logger = logging.getLogger("workbench")
104
+ logger.setLevel(logging.DEBUG)
105
+
106
+
107
+ # Helper context manager for timing
108
# Helper context manager for timing
@contextmanager
def timer(name):
    """Print the wall-clock seconds spent inside the `with` block.

    Args:
        name: Label used as the prefix of the printed timing line.
    """
    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # backwards on clock adjustments and produce negative elapsed times
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report even when the timed block raises, so timings aren't lost
        print(f"{name}: {time.perf_counter() - start:.2f}s")
113
+
114
+
115
def compute_stereochemistry_features(mol):
    """
    Compute stereochemistry descriptors using modern RDKit methods.

    Returns dict with 10 stereochemistry descriptors commonly used in ADMET.
    Invalid molecules (None) and calculation failures yield all-NaN values.
    """
    descriptor_keys = (
        "num_stereocenters",
        "num_unspecified_stereocenters",
        "num_defined_stereocenters",
        "num_r_centers",
        "num_s_centers",
        "num_stereobonds",
        "num_e_bonds",
        "num_z_bonds",
        "stereo_complexity",
        "frac_defined_stereo",
    )
    nan_result = {key: np.nan for key in descriptor_keys}

    if mol is None:
        return nan_result

    try:
        # Enumerate every potential stereogenic element in the molecule
        stereo_elements = Chem.FindPotentialStereo(mol)

        # Counters for tetrahedral centers and double-bond stereo
        centers_defined = 0
        centers_undefined = 0
        r_count = 0
        s_count = 0
        bonds_defined = 0
        bonds_undefined = 0
        e_count = 0
        z_count = 0

        # Assign CIP labels so the R/S and E/Z codes below are accurate
        rdCIPLabeler.AssignCIPLabels(mol)

        for element in stereo_elements:
            is_specified = element.specified == Chem.StereoSpecified.Specified

            if element.type == Chem.StereoType.Atom_Tetrahedral:
                if not is_specified:
                    centers_undefined += 1
                    continue
                centers_defined += 1
                # Classify the defined center by its CIP code (R or S)
                atom = mol.GetAtomWithIdx(element.centeredOn)
                if atom.HasProp("_CIPCode"):
                    code = atom.GetProp("_CIPCode")
                    if code == "R":
                        r_count += 1
                    elif code == "S":
                        s_count += 1

            elif element.type == Chem.StereoType.Bond_Double:
                if not is_specified:
                    bonds_undefined += 1
                    continue
                bonds_defined += 1
                # Classify the defined double bond by its CIP code (E or Z)
                bond = mol.GetBondWithIdx(element.centeredOn)
                if bond.HasProp("_CIPCode"):
                    code = bond.GetProp("_CIPCode")
                    if code == "E":
                        e_count += 1
                    elif code == "Z":
                        z_count += 1

        total_centers = centers_defined + centers_undefined
        total_bonds = bonds_defined + bonds_undefined
        total_elements = total_centers + total_bonds

        # Fraction of stereo elements with defined configuration;
        # a molecule with no stereo elements counts as fully defined
        if total_elements > 0:
            frac_defined = (centers_defined + bonds_defined) / total_elements
        else:
            frac_defined = 1.0

        return {
            "num_stereocenters": total_centers,
            "num_unspecified_stereocenters": centers_undefined,
            "num_defined_stereocenters": centers_defined,
            "num_r_centers": r_count,
            "num_s_centers": s_count,
            "num_stereobonds": total_bonds,
            "num_e_bonds": e_count,
            "num_z_bonds": z_count,
            "stereo_complexity": total_elements,
            "frac_defined_stereo": frac_defined,
        }

    except Exception as e:
        logger.warning(f"Stereochemistry calculation failed: {e}")
        return nan_result
223
+
224
+
225
def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_stereo: bool = True) -> pd.DataFrame:
    """
    Compute all molecular descriptors for ADMET modeling.

    Args:
        df: Input DataFrame with a 'smiles' column (any capitalization)
        include_mordred: Whether to compute Mordred descriptors (default True)
        include_stereo: Whether to compute stereochemistry features (default True)

    Returns:
        DataFrame with all descriptor columns added. Invalid/missing SMILES
        rows receive NaN for every descriptor. Column names are sanitized
        for AWS Athena (lowercase, [a-z0-9_] only).

    Raises:
        ValueError: If the DataFrame has no 'smiles' column.

    Example:
        df = standardize(df)  # First standardize
        df = compute_descriptors(df)  # Then compute descriptors with stereo
        df = compute_descriptors(df, include_stereo=False)  # Without stereo
        df = compute_descriptors(df, include_mordred=False)  # RDKit only
    """

    # Check for the smiles column (any capitalization)
    smiles_column = next((col for col in df.columns if col.lower() == "smiles"), None)
    if smiles_column is None:
        raise ValueError("Input DataFrame must have a 'smiles' column")

    result = df.copy()

    # Create molecule objects (None for missing/empty/unparseable SMILES)
    logger.info("Creating molecule objects...")
    molecules = []
    for smiles in result[smiles_column]:
        if pd.isna(smiles) or smiles == "":
            molecules.append(None)
        else:
            molecules.append(Chem.MolFromSmiles(smiles))

    # Compute RDKit descriptors
    logger.info("Computing RDKit Descriptors...")

    # Get all RDKit descriptors
    all_descriptors = [x[0] for x in Descriptors._descList]

    # Remove IPC descriptor due to overflow issue
    # See: https://github.com/rdkit/rdkit/issues/1527
    if "Ipc" in all_descriptors:
        all_descriptors.remove("Ipc")

    # Make sure we don't have duplicates
    all_descriptors = list(set(all_descriptors))

    # Initialize calculator
    calc = MoleculeDescriptors.MolecularDescriptorCalculator(all_descriptors)

    # Compute descriptors (NaN row for invalid molecules or failed calculations)
    descriptor_values = []
    for mol in molecules:
        if mol is None:
            descriptor_values.append([np.nan] * len(all_descriptors))
        else:
            try:
                descriptor_values.append(calc.CalcDescriptors(mol))
            except Exception as e:
                logger.warning(f"RDKit descriptor calculation failed: {e}")
                descriptor_values.append([np.nan] * len(all_descriptors))

    # BUGFIX: build the feature frame on the input's index — a fresh RangeIndex
    # misaligns pd.concat(axis=1) whenever the caller's index is non-default
    rdkit_features_df = pd.DataFrame(descriptor_values, columns=calc.GetDescriptorNames(), index=result.index)

    # Replace any pre-existing columns with the freshly computed ones
    result = result.drop(columns=result.columns.intersection(rdkit_features_df.columns))
    result = pd.concat([result, rdkit_features_df], axis=1)

    # Compute Mordred descriptors
    if include_mordred:
        logger.info("Computing Mordred descriptors from relevant modules...")
        calc = MordredCalculator()

        # Register 5 ADMET-focused modules (avoiding overlap with RDKit)
        calc.register(AcidBase)  # ~2 descriptors: nAcid, nBase
        calc.register(Aromatic)  # ~2 descriptors: nAromAtom, nAromBond
        calc.register(Constitutional)  # ~30 descriptors: structural complexity
        calc.register(Chi)  # ~32 descriptors: connectivity indices
        calc.register(CarbonTypes)  # ~20 descriptors: carbon hybridization

        # Mordred needs a real Mol for every row; use placeholder methane for
        # invalid rows and blank those rows out afterwards
        valid_mols = [mol if mol is not None else Chem.MolFromSmiles("C") for mol in molecules]
        mordred_df = calc.pandas(valid_mols, nproc=1)  # Endpoint multiprocessing will fail with nproc>1
        mordred_df.index = result.index  # BUGFIX: align with the input index before concat

        # Replace values for invalid molecules with NaN (positional, so safe post-reindex)
        for i, mol in enumerate(molecules):
            if mol is None:
                mordred_df.iloc[i] = np.nan

        # Handle Mordred's special error values (error objects -> NaN)
        for col in mordred_df.columns:
            mordred_df[col] = pd.to_numeric(mordred_df[col], errors="coerce")

        # Add Mordred features to result, replacing any pre-existing columns
        result = result.drop(columns=result.columns.intersection(mordred_df.columns))
        result = pd.concat([result, mordred_df], axis=1)

    # Compute stereochemistry features if requested
    if include_stereo:
        logger.info("Computing Stereochemistry Descriptors...")

        stereo_features = [compute_stereochemistry_features(mol) for mol in molecules]

        # BUGFIX: index-aligned like the other feature frames
        stereo_df = pd.DataFrame(stereo_features, index=result.index)

        # Add stereochemistry features to result, replacing any pre-existing columns
        result = result.drop(columns=result.columns.intersection(stereo_df.columns))
        result = pd.concat([result, stereo_df], axis=1)

        logger.info(f"Added {len(stereo_df.columns)} stereochemistry descriptors")

    # Log summary
    valid_mols = sum(1 for m in molecules if m is not None)
    total_descriptors = len(result.columns) - len(df.columns)
    logger.info(f"Computed {total_descriptors} descriptors for {valid_mols}/{len(df)} valid molecules")

    # Log descriptor breakdown
    rdkit_count = len(rdkit_features_df.columns)
    mordred_count = len(mordred_df.columns) if include_mordred else 0
    stereo_count = len(stereo_df.columns) if include_stereo else 0
    logger.info(f"Descriptor breakdown: RDKit={rdkit_count}, Mordred={mordred_count}, Stereo={stereo_count}")

    # Sanitize column names for AWS Athena compatibility
    # - Must be lowercase, no special characters except underscore, no spaces
    result.columns = [re.sub(r"_+", "_", re.sub(r"[^a-z0-9_]", "_", col.lower())) for col in result.columns]

    # Drop duplicate columns if any exist after sanitization
    if result.columns.duplicated().any():
        logger.warning("Duplicate column names after sanitization - dropping duplicates!")
        result = result.loc[:, ~result.columns.duplicated()]

    return result
370
+
371
+
372
if __name__ == "__main__":
    # Manual test harness — NOT run under pytest. Requires the sibling
    # mol_standardize module plus a configured Workbench DataSource
    # (AWS access), so it only works in a full deployment environment.
    from mol_standardize import standardize
    from workbench.api import DataSource

    # Configure pandas display
    pd.set_option("display.max_columns", None)
    pd.set_option("display.max_colwidth", 100)
    pd.set_option("display.width", 1200)

    # Test data - stereochemistry examples
    stereo_test_data = pd.DataFrame(
        {
            "smiles": [
                "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin
                "C[C@H](N)C(=O)O",  # L-Alanine
                "C[C@@H](N)C(=O)O",  # D-Alanine
                "C/C=C/C=C/C",  # E,E-hexadiene
                "CC(F)(Cl)Br",  # Unspecified chiral
                "",
                "INVALID",  # Invalid cases
            ],
            "name": ["Aspirin", "L-Alanine", "D-Alanine", "E,E-hexadiene", "Unspecified", "Empty", "Invalid"],
        }
    )

    # Test data - salt handling examples
    salt_test_data = pd.DataFrame(
        {
            "smiles": [
                "CC(=O)O",  # Acetic acid
                "[Na+].CC(=O)[O-]",  # Sodium acetate
                "CC(C)NCC(O)c1ccc(O)c(O)c1.Cl",  # Drug HCl salt
                "Oc1ccccn1",  # Tautomer 1
                "O=c1cccc[nH]1",  # Tautomer 2
            ],
            "compound_id": [f"C{i:03d}" for i in range(1, 6)],
        }
    )

    def run_basic_tests():
        """Run basic functionality tests"""
        print("=" * 80)
        print("BASIC FUNCTIONALITY TESTS")
        print("=" * 80)

        # Test stereochemistry on the first four (valid) molecules
        result = compute_descriptors(stereo_test_data, include_stereo=True)

        print("\nStereochemistry features (selected molecules):")
        for idx, name in enumerate(stereo_test_data["name"][:4]):
            print(
                f"{name:15} - centers: {result.iloc[idx]['num_stereocenters']:.0f}, "
                f"R/S: {result.iloc[idx]['num_r_centers']:.0f}/"
                f"{result.iloc[idx]['num_s_centers']:.0f}"
            )

        # Test salt handling via mol_standardize (expects 'salt' column in output)
        print("\nSalt extraction test:")
        std_result = standardize(salt_test_data, extract_salts=True)
        for _, row in std_result.iterrows():
            salt_info = f" → salt: {row['salt']}" if pd.notna(row["salt"]) else ""
            print(f"{row['compound_id']}: {row['smiles'][:30]}{salt_info}")

    def run_performance_tests():
        """Run performance timing tests"""
        print("\n" + "=" * 80)
        print("PERFORMANCE TESTS on real world molecules")
        print("=" * 80)

        # Get a real dataset from Workbench
        # NOTE(review): assumes 'aqsol_data' exists with 'id' and 'smiles'
        # columns — verify against the deployment's registered DataSources
        ds = DataSource("aqsol_data")
        df = ds.pull_dataframe()[["id", "smiles"]][:1000]  # Limit to 1000 for testing
        n_mols = df.shape[0]
        print(f"Pulled {n_mols} molecules from DataSource 'aqsol_data'")

        # Test configurations: (label, callable, kwargs)
        configs = [
            ("Standardize (full)", standardize, {"extract_salts": True, "canonicalize_tautomer": True}),
            ("Standardize (minimal)", standardize, {"extract_salts": False, "canonicalize_tautomer": False}),
            ("Descriptors (all)", compute_descriptors, {"include_mordred": True, "include_stereo": True}),
            ("Descriptors (RDKit only)", compute_descriptors, {"include_mordred": False, "include_stereo": False}),
        ]

        results = []
        for name, func, params in configs:
            start = time.time()
            _ = func(df, **params)
            elapsed = time.time() - start
            throughput = n_mols / elapsed
            results.append((name, elapsed, throughput))
            print(f"{name:25} {elapsed:6.2f}s ({throughput:6.1f} mol/s)")

        # Full pipeline test: standardize then compute all descriptors
        print("\nFull pipeline (standardize + all descriptors):")
        start = time.time()
        std_data = standardize(df)
        standardize_time = time.time() - start
        print(f" Standardize: {standardize_time:.2f}s ({n_mols / standardize_time:.1f} mol/s)")
        start = time.time()
        _ = compute_descriptors(std_data)
        descriptor_time = time.time() - start
        print(f" Descriptors: {descriptor_time:.2f}s ({n_mols / descriptor_time:.1f} mol/s)")
        pipeline_time = standardize_time + descriptor_time
        print(f" Total: {pipeline_time:.2f}s ({n_mols / pipeline_time:.1f} mol/s)")

        return results

    # Run tests
    run_basic_tests()
    timing_results = run_performance_tests()

    print("\n✅ All tests completed!")