workbench 0.8.162__py3-none-any.whl → 0.8.202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (113)
  1. workbench/algorithms/dataframe/__init__.py +1 -2
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +2 -2
  3. workbench/algorithms/dataframe/proximity.py +261 -235
  4. workbench/algorithms/graph/light/proximity_graph.py +10 -8
  5. workbench/api/__init__.py +2 -1
  6. workbench/api/compound.py +1 -1
  7. workbench/api/endpoint.py +11 -0
  8. workbench/api/feature_set.py +11 -8
  9. workbench/api/meta.py +5 -2
  10. workbench/api/model.py +16 -15
  11. workbench/api/monitor.py +1 -16
  12. workbench/core/artifacts/__init__.py +11 -2
  13. workbench/core/artifacts/artifact.py +11 -3
  14. workbench/core/artifacts/data_capture_core.py +355 -0
  15. workbench/core/artifacts/endpoint_core.py +256 -118
  16. workbench/core/artifacts/feature_set_core.py +265 -16
  17. workbench/core/artifacts/model_core.py +107 -60
  18. workbench/core/artifacts/monitor_core.py +33 -248
  19. workbench/core/cloud_platform/aws/aws_account_clamp.py +50 -1
  20. workbench/core/cloud_platform/aws/aws_meta.py +12 -5
  21. workbench/core/cloud_platform/aws/aws_parameter_store.py +18 -2
  22. workbench/core/cloud_platform/aws/aws_session.py +4 -4
  23. workbench/core/transforms/data_to_features/light/molecular_descriptors.py +4 -4
  24. workbench/core/transforms/features_to_model/features_to_model.py +42 -32
  25. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +36 -6
  26. workbench/core/transforms/pandas_transforms/pandas_to_features.py +27 -0
  27. workbench/core/views/training_view.py +113 -42
  28. workbench/core/views/view.py +53 -3
  29. workbench/core/views/view_utils.py +4 -4
  30. workbench/model_scripts/chemprop/chemprop.template +852 -0
  31. workbench/model_scripts/chemprop/generated_model_script.py +852 -0
  32. workbench/model_scripts/chemprop/requirements.txt +11 -0
  33. workbench/model_scripts/custom_models/chem_info/fingerprints.py +134 -0
  34. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +483 -0
  35. workbench/model_scripts/custom_models/chem_info/mol_standardize.py +450 -0
  36. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +7 -9
  37. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +1 -1
  38. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +3 -5
  39. workbench/model_scripts/custom_models/proximity/proximity.py +261 -235
  40. workbench/model_scripts/custom_models/uq_models/bayesian_ridge.template +7 -8
  41. workbench/model_scripts/custom_models/uq_models/ensemble_xgb.template +20 -21
  42. workbench/model_scripts/custom_models/uq_models/gaussian_process.template +5 -11
  43. workbench/model_scripts/custom_models/uq_models/meta_uq.template +166 -62
  44. workbench/model_scripts/custom_models/uq_models/ngboost.template +30 -18
  45. workbench/model_scripts/custom_models/uq_models/proximity.py +261 -235
  46. workbench/model_scripts/custom_models/uq_models/requirements.txt +1 -3
  47. workbench/model_scripts/ensemble_xgb/ensemble_xgb.template +15 -17
  48. workbench/model_scripts/pytorch_model/generated_model_script.py +373 -190
  49. workbench/model_scripts/pytorch_model/pytorch.template +370 -187
  50. workbench/model_scripts/scikit_learn/generated_model_script.py +7 -12
  51. workbench/model_scripts/scikit_learn/scikit_learn.template +4 -9
  52. workbench/model_scripts/script_generation.py +17 -9
  53. workbench/model_scripts/uq_models/generated_model_script.py +605 -0
  54. workbench/model_scripts/uq_models/mapie.template +605 -0
  55. workbench/model_scripts/uq_models/requirements.txt +1 -0
  56. workbench/model_scripts/xgb_model/generated_model_script.py +37 -46
  57. workbench/model_scripts/xgb_model/xgb_model.template +44 -46
  58. workbench/repl/workbench_shell.py +28 -14
  59. workbench/scripts/endpoint_test.py +162 -0
  60. workbench/scripts/lambda_test.py +73 -0
  61. workbench/scripts/ml_pipeline_batch.py +137 -0
  62. workbench/scripts/ml_pipeline_sqs.py +186 -0
  63. workbench/scripts/monitor_cloud_watch.py +20 -100
  64. workbench/utils/aws_utils.py +4 -3
  65. workbench/utils/chem_utils/__init__.py +0 -0
  66. workbench/utils/chem_utils/fingerprints.py +134 -0
  67. workbench/utils/chem_utils/misc.py +194 -0
  68. workbench/utils/chem_utils/mol_descriptors.py +483 -0
  69. workbench/utils/chem_utils/mol_standardize.py +450 -0
  70. workbench/utils/chem_utils/mol_tagging.py +348 -0
  71. workbench/utils/chem_utils/projections.py +209 -0
  72. workbench/utils/chem_utils/salts.py +256 -0
  73. workbench/utils/chem_utils/sdf.py +292 -0
  74. workbench/utils/chem_utils/toxicity.py +250 -0
  75. workbench/utils/chem_utils/vis.py +253 -0
  76. workbench/utils/chemprop_utils.py +760 -0
  77. workbench/utils/cloudwatch_handler.py +1 -1
  78. workbench/utils/cloudwatch_utils.py +137 -0
  79. workbench/utils/config_manager.py +3 -7
  80. workbench/utils/endpoint_utils.py +5 -7
  81. workbench/utils/license_manager.py +2 -6
  82. workbench/utils/model_utils.py +95 -34
  83. workbench/utils/monitor_utils.py +44 -62
  84. workbench/utils/pandas_utils.py +3 -3
  85. workbench/utils/pytorch_utils.py +526 -0
  86. workbench/utils/shap_utils.py +10 -2
  87. workbench/utils/workbench_logging.py +0 -3
  88. workbench/utils/workbench_sqs.py +1 -1
  89. workbench/utils/xgboost_model_utils.py +371 -156
  90. workbench/web_interface/components/model_plot.py +7 -1
  91. workbench/web_interface/components/plugin_unit_test.py +5 -2
  92. workbench/web_interface/components/plugins/dashboard_status.py +3 -1
  93. workbench/web_interface/components/plugins/generated_compounds.py +1 -1
  94. workbench/web_interface/components/plugins/model_details.py +9 -7
  95. workbench/web_interface/components/plugins/scatter_plot.py +3 -3
  96. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/METADATA +27 -6
  97. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/RECORD +101 -85
  98. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/entry_points.txt +4 -0
  99. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/licenses/LICENSE +1 -1
  100. workbench/model_scripts/custom_models/chem_info/local_utils.py +0 -769
  101. workbench/model_scripts/custom_models/chem_info/tautomerize.py +0 -83
  102. workbench/model_scripts/custom_models/proximity/generated_model_script.py +0 -138
  103. workbench/model_scripts/custom_models/uq_models/generated_model_script.py +0 -393
  104. workbench/model_scripts/custom_models/uq_models/mapie_xgb.template +0 -203
  105. workbench/model_scripts/ensemble_xgb/generated_model_script.py +0 -279
  106. workbench/model_scripts/quant_regression/quant_regression.template +0 -279
  107. workbench/model_scripts/quant_regression/requirements.txt +0 -1
  108. workbench/utils/chem_utils.py +0 -1556
  109. workbench/utils/execution_environment.py +0 -211
  110. workbench/utils/fast_inference.py +0 -167
  111. workbench/utils/resource_utils.py +0 -39
  112. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/WHEEL +0 -0
  113. {workbench-0.8.162.dist-info → workbench-0.8.202.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,348 @@
1
+ """
2
+ mol_tagging.py - Molecular property tagging for ADMET modeling
3
+ Adds a 'tags' column to DataFrames for filtering and classification
4
+ """
5
+
6
+ import logging
7
+ from typing import List, Set, Optional
8
+ import pandas as pd
9
+ from rdkit import Chem
10
+ from rdkit.Chem import Mol, Descriptors
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ # ============================================================================
16
+ # Property Detection Functions (Internal)
17
+ # ============================================================================
18
+
19
+
20
def _get_metal_tags(mol: Mol) -> Set[str]:
    """Return the metal-related tags present in ``mol``.

    Tags:
        - ``"metalloenzyme_metal"``: the molecule contains any of
          Zn, Cu, Fe, Mn, Co, Ni, Mo, V.
        - ``"heavy_metal"``: the molecule contains any of
          Pb, Hg, Cd, As, Cr, Tl, Ba, Be, Al, Sb, Se, Bi, Ag.

    A ``None`` molecule yields an empty set.
    """
    if mol is None:
        return set()

    metalloenzyme_metals = frozenset({"Zn", "Cu", "Fe", "Mn", "Co", "Ni", "Mo", "V"})
    heavy_metals = frozenset(
        {"Pb", "Hg", "Cd", "As", "Cr", "Tl", "Ba", "Be", "Al", "Sb", "Se", "Bi", "Ag"}
    )

    # Gather each element symbol once, then test each group with a set intersection
    element_symbols = {atom.GetSymbol() for atom in mol.GetAtoms()}

    found = set()
    if element_symbols & metalloenzyme_metals:
        found.add("metalloenzyme_metal")
    if element_symbols & heavy_metals:
        found.add("heavy_metal")
    return found
40
+
41
+
42
def _get_halogen_tags(mol: Mol) -> Set[str]:
    """Return the halogenation tags for ``mol``.

    Tags:
        - ``"halogenated"``: at least one F, Cl, Br or I atom.
        - ``"highly_halogenated"``: halogens make up more than half of the
          heavy atoms, or there are more than four halogen atoms.

    A ``None`` molecule yields an empty set.
    """
    found = set()
    if mol is None:
        return found

    halogens = ("F", "Cl", "Br", "I")
    n_halogen = len([atom for atom in mol.GetAtoms() if atom.GetSymbol() in halogens])

    # No halogens means neither tag can apply
    if not n_halogen:
        return found

    found.add("halogenated")

    n_heavy = mol.GetNumHeavyAtoms()
    # Guard the ratio against a zero heavy-atom count
    if n_heavy > 0 and (n_halogen / n_heavy > 0.5 or n_halogen > 4):
        found.add("highly_halogenated")
    return found
62
+
63
+
64
def _get_druglike_tags(mol: Mol) -> Set[str]:
    """Return drug-likeness tags for ``mol`` based on RDKit descriptors.

    Tags (thresholds exactly as implemented below):
        - ``"ro5_pass"`` / ``"ro5_strict"``: at most one / zero Lipinski
          Rule-of-Five violations (MW > 500, LogP > 5, HBD > 5, HBA > 10).
        - ``"veber_pass"``: rotatable bonds <= 10 and TPSA <= 140.
        - ``"lead_like"``: 150 <= MW <= 350 and -3 <= LogP <= 3.5.
        - ``"fragment_like"``: Rule of Three (MW <= 300, LogP <= 3,
          HBD <= 3, HBA <= 3).
        - ``"small_molecule"`` (MW < 200) or ``"large_molecule"`` (MW > 700).

    A ``None`` molecule yields an empty set.
    """
    if mol is None:
        return set()

    # Each descriptor is computed once and shared by the rules below
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    donors = Descriptors.NumHDonors(mol)
    acceptors = Descriptors.NumHAcceptors(mol)
    rot_bonds = Descriptors.NumRotatableBonds(mol)
    tpsa = Descriptors.TPSA(mol)

    # Booleans sum as 0/1, giving the number of Lipinski violations
    violations = sum((mw > 500, logp > 5, donors > 5, acceptors > 10))

    rules = {
        "ro5_pass": violations <= 1,
        "ro5_strict": violations == 0,
        "veber_pass": rot_bonds <= 10 and tpsa <= 140,
        "lead_like": 150 <= mw <= 350 and -3 <= logp <= 3.5,
        "fragment_like": mw <= 300 and logp <= 3 and donors <= 3 and acceptors <= 3,
        "small_molecule": mw < 200,
        "large_molecule": mw > 700,
    }
    return {tag for tag, satisfied in rules.items() if satisfied}
113
+
114
+
115
def _get_structural_tags(mol: Mol) -> Set[str]:
    """Return structural-feature tags for ``mol``.

    Tags:
        - ``"multi_fragment"``: more than one disconnected fragment.
        - ``"cyclic"`` or ``"acyclic"``: whether RDKit's ring info reports any ring.
        - ``"aromatic"``: some ring contains an aromatic atom.
        - ``"chiral"``: some atom carries a chiral tag other than
          ``CHI_UNSPECIFIED``. Stereocenters without an explicit chiral tag
          in the input are therefore not flagged.

    A ``None`` molecule yields an empty set.
    """
    found = set()
    if mol is None:
        return found

    if len(Chem.GetMolFrags(mol)) > 1:
        found.add("multi_fragment")

    rings = mol.GetRingInfo()
    if rings.NumRings():
        found.add("cyclic")
        # Lazily walk ring atoms; any() stops at the first aromatic one
        ring_atom_aromaticity = (
            mol.GetAtomWithIdx(atom_idx).GetIsAromatic()
            for ring in rings.AtomRings()
            for atom_idx in ring
        )
        if any(ring_atom_aromaticity):
            found.add("aromatic")
    else:
        found.add("acyclic")

    unspecified = Chem.ChiralType.CHI_UNSPECIFIED
    if any(atom.GetChiralTag() != unspecified for atom in mol.GetAtoms()):
        found.add("chiral")

    return found
142
+
143
+
144
+ # ============================================================================
145
+ # Main Tagging Function
146
+ # ============================================================================
147
+
148
+
149
def _tags_for_smiles(smiles, taggers) -> List[str]:
    """Return the sorted tag list for a single SMILES cell value.

    Args:
        smiles: Raw value from the SMILES column (may be NaN/None/non-string)
        taggers: Tag functions (molecule -> set of tag strings) to apply

    Returns:
        ``["invalid_smiles"]`` when the value cannot be turned into a molecule
        with at least one atom; otherwise the sorted union of all tagger outputs.
    """
    # NaN/None and other non-string values cannot be parsed as SMILES
    if not isinstance(smiles, str) or not smiles.strip():
        return ["invalid_smiles"]

    mol = Chem.MolFromSmiles(smiles.strip())
    # RDKit can return an atom-less (non-None) molecule for empty/blank input;
    # never tag such a result as if it were a real compound
    if mol is None or mol.GetNumAtoms() == 0:
        return ["invalid_smiles"]

    tags: Set[str] = set()
    for tagger in taggers:
        tags.update(tagger(mol))

    # Sorted list for consistent, comparable output
    return sorted(tags)


def tag_molecules(
    df: pd.DataFrame,
    smiles_column: str = "smiles",
    tag_column: str = "tags",
    tag_categories: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Add molecular property tags to a DataFrame.

    Designed to work after mol_standardize.py processing.
    Adds a single tags column (named by ``tag_column``) whose entries are
    sorted lists of string tags.

    Args:
        df: Input DataFrame with SMILES (not modified; a copy is returned)
        smiles_column: Column containing SMILES strings
        tag_column: Name for output tags column (default: "tags")
        tag_categories: Which tag categories to include. Options:
            - "metals": Metal content tags
            - "halogens": Halogenation tags
            - "druglike": Drug-likeness assessments
            - "structure": Structural features
            - None (default): Include all categories
            A single category may also be given as a plain string.
            Unknown names are ignored with a logged warning.

    Returns:
        DataFrame with tags column added. Rows whose SMILES value is missing,
        blank, not a string, unparseable, or parses to a molecule with no
        atoms get ``["invalid_smiles"]``.

    Raises:
        KeyError: If ``smiles_column`` is not a column of ``df``.

    Example:
        df = tag_molecules(df)  # Add all tags
        df = tag_molecules(df, tag_categories=["druglike"])  # Only drug-likeness
    """
    # Category name -> function producing that category's tags
    taggers = {
        "metals": _get_metal_tags,
        "halogens": _get_halogen_tags,
        "druglike": _get_druglike_tags,
        "structure": _get_structural_tags,
    }

    # Default to all categories; accept a bare string as a single category
    if tag_categories is None:
        tag_categories = list(taggers)
    elif isinstance(tag_categories, str):
        tag_categories = [tag_categories]

    unknown = [name for name in tag_categories if name not in taggers]
    if unknown:
        logger.warning(f"Ignoring unknown tag categories {unknown}; valid options are {list(taggers)}")
    selected = [func for name, func in taggers.items() if name in tag_categories]

    result = df.copy()

    # Iterate the SMILES column directly (no per-row Series as with iterrows)
    all_tags = [_tags_for_smiles(smiles, selected) for smiles in result[smiles_column]]
    result[tag_column] = all_tags

    # Log summary
    total = len(result)
    valid = sum(1 for tags in all_tags if "invalid_smiles" not in tags)
    ro5_pass = sum(1 for tags in all_tags if "ro5_pass" in tags)

    logger.info(f"Tagged {total} molecules: {valid} valid, {ro5_pass} pass Ro5")

    return result
230
+
231
+
232
+ # ============================================================================
233
+ # Utility Functions
234
+ # ============================================================================
235
+
236
+
237
def filter_by_tags(
    df: pd.DataFrame, require: Optional[List[str]] = None, exclude: Optional[List[str]] = None, tag_column: str = "tags"
) -> pd.DataFrame:
    """
    Filter DataFrame rows based on tags.

    Args:
        df: DataFrame with tags column (not modified)
        require: Tags that must all be present (AND logic); a single tag may be given as a string
        exclude: Tags of which none may be present; a single tag may be given as a string
        tag_column: Name of tags column

    Returns:
        Filtered copy of ``df`` with its original index. Rows whose tag entry
        is missing (None/NaN) are treated as having no tags. When neither
        ``require`` nor ``exclude`` is given, an unfiltered copy is returned.

    Raises:
        KeyError: If filters are given and ``tag_column`` is not a column of ``df``.

    Example:
        # Get drug-like molecules without heavy metals
        filtered = filter_by_tags(df,
                                  require=["ro5_pass"],
                                  exclude=["heavy_metal"])
    """

    def _as_list(tags) -> List[str]:
        # Wrap a bare string so it is not iterated character by character
        if tags is None:
            return []
        return [tags] if isinstance(tags, str) else list(tags)

    required = _as_list(require)
    excluded = _as_list(exclude)

    if not required and not excluded:
        # Nothing to filter on: the tag column is not even consulted
        result = df.copy()
    else:

        def _keep(tags) -> bool:
            # A missing entry (e.g. a row added after tagging) counts as "no tags"
            # instead of raising TypeError on the membership test
            if pd.api.types.is_scalar(tags) and pd.isna(tags):
                tags = ()
            return all(tag in tags for tag in required) and not any(tag in tags for tag in excluded)

        # One pass over the tag column builds a single boolean mask
        mask = df[tag_column].map(_keep).astype(bool)
        result = df[mask.to_numpy()].copy()

    logger.info(f"Filtered {len(df)} → {len(result)} molecules")

    return result
271
+
272
+
273
def get_tag_summary(df: pd.DataFrame, tag_column: str = "tags") -> pd.Series:
    """
    Get summary statistics of tags in DataFrame.

    Args:
        df: DataFrame with tags column (each entry an iterable of tag strings)
        tag_column: Name of tags column

    Returns:
        Series mapping each tag to its number of occurrences, most frequent
        first (``Series.value_counts`` ordering). Missing entries (None/NaN)
        are skipped; with no tags at all the result is an empty Series.

    Raises:
        KeyError: If ``tag_column`` is not a column of ``df``.
    """
    # Flatten all tag lists, skipping missing entries that would otherwise
    # raise TypeError when iterated
    all_tags = [
        tag
        for tags_list in df[tag_column]
        if not (pd.api.types.is_scalar(tags_list) and pd.isna(tags_list))
        for tag in tags_list
    ]

    # Explicit object dtype keeps the empty case well-typed
    return pd.Series(all_tags, dtype=object).value_counts()
291
+
292
+
293
if __name__ == "__main__":
    # Manual smoke test of tagging, filtering and summarising on a toy dataset
    divider = "=" * 60

    print("Testing molecular tagging system")
    print(divider)

    # Hand-picked inputs covering drugs, metals, halogens and invalid entries
    smiles_list = [
        "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin
        "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
        "C" * 50,  # Large alkane
        "C(Cl)(Cl)(Cl)Cl",  # Carbon tetrachloride
        "[Zn+2].[Cl-].[Cl-]",  # Zinc chloride
        "CCC",  # Propane
        "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # Ibuprofen
        "[Pb+2].[O-]C(=O)C",  # Lead acetate
        "",  # Empty
        "INVALID_SMILES",  # Invalid
    ]
    test_data = pd.DataFrame(
        {
            "smiles": smiles_list,
            "compound_id": [f"C{number:03d}" for number in range(1, len(smiles_list) + 1)],
        }
    )

    print("Input data:")
    print(test_data[["compound_id", "smiles"]])

    print(f"\n{divider}")
    print("Applying molecular tags...")
    tagged_df = tag_molecules(test_data)

    print("\nTagged results:")
    for compound_id, tag_list in zip(tagged_df["compound_id"], tagged_df["tags"]):
        rendered = ", ".join(tag_list) if tag_list else "none"
        print(f"{compound_id}: {rendered}")

    print(f"\n{divider}")
    print("Testing filters...")

    # Molecules with at most one Rule-of-Five violation
    druglike = filter_by_tags(tagged_df, require=["ro5_pass"])
    print(f"Drug-like molecules: {druglike['compound_id'].tolist()}")

    # Drop heavy metals, heavily halogenated compounds and unparseable SMILES
    clean = filter_by_tags(tagged_df, exclude=["heavy_metal", "highly_halogenated", "invalid_smiles"])
    print(f"Clean molecules: {clean['compound_id'].tolist()}")

    print(f"\n{divider}")
    print("Tag summary:")
    print(get_tag_summary(tagged_df).head(10))

    print("\n✅ All tests completed!")
@@ -0,0 +1,209 @@
1
+ """Dimensionality reduction and projection utilities for molecular fingerprints"""
2
+
3
+ import logging
4
+ import numpy as np
5
+ import pandas as pd
6
+ from sklearn.manifold import TSNE
7
+
8
+ # Try importing UMAP
9
+ try:
10
+ import umap
11
+ except ImportError:
12
+ umap = None
13
+
14
+ # Set up the logger
15
+ log = logging.getLogger("workbench")
16
+
17
+
18
def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
    """
    Convert bitstring fingerprints to numpy matrix.

    Args:
        fingerprints: pandas Series or list of bitstring fingerprints (e.g. "0101...")
        dtype: numpy data type of the result (uint8 is default; np.bool_ is good for Jaccard computations)

    Returns:
        dense numpy array of shape (n_molecules, n_bits); shape (0, 0) for empty input

    Raises:
        ValueError: If fingerprints have different lengths or contain characters
            that do not parse as integers.
    """
    # Parse the characters as integers first. Casting the characters straight to
    # np.bool_ is NumPy-version dependent and can treat any non-empty string
    # (including "0") as True.
    bits = np.array([list(fp) for fp in fingerprints], dtype=np.uint8)

    # Keep the documented 2-D shape even when there are no fingerprints
    if bits.ndim == 1 and bits.size == 0:
        bits = bits.reshape(0, 0)

    # Dense matrix representation (we might support sparse in the future)
    return bits.astype(dtype, copy=False)
32
+
33
+
34
def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame:
    """Project fingerprints onto a 2D plane using dimensionality reduction techniques.

    Args:
        df (pd.DataFrame): Input DataFrame containing fingerprint data. The first column whose
            name contains "fingerprint" (case-insensitive) is used. The 'x' and 'y' columns are
            written into this DataFrame in place.
        projection (str): Dimensionality reduction technique to use ("TSNE" or "UMAP",
            case-insensitive). Unrecognized values fall back to UMAP with a warning, and UMAP
            falls back to TSNE when the umap package is not installed.

    Returns:
        pd.DataFrame: The input DataFrame with the projected coordinates added as 'x' and 'y' columns.
            DataFrames with fewer than 4 rows get random coordinates in [-10, 10) instead.

    Raises:
        ValueError: If the DataFrame has no fingerprint column.
    """
    # Check for the fingerprint column (case-insensitive; str() tolerates non-string column labels)
    fingerprint_column = next((col for col in df.columns if "fingerprint" in str(col).lower()), None)
    if fingerprint_column is None:
        raise ValueError("Input DataFrame must have a fingerprint column")

    # Create a matrix of fingerprints, one row per molecule
    X = fingerprints_to_matrix(df[fingerprint_column])
    n_samples = X.shape[0]

    # Normalize the requested method so "tsne"/"umap" behave like "TSNE"/"UMAP"
    method = str(projection).upper()
    if method not in ("TSNE", "UMAP"):
        log.warning(f"Unknown projection '{projection}'. Using UMAP instead.")
        method = "UMAP"
    if method == "UMAP" and umap is None:
        log.warning("UMAP is not available. Using TSNE instead.")
        method = "TSNE"

    # Guard tiny datasets for BOTH methods: TSNE needs perplexity < n_samples, and the UMAP
    # branch would otherwise request n_neighbors=1 for n_samples <= 2 (including empty input)
    if n_samples < 4:
        log.warning(f"Dataset too small for {method} (n={n_samples}). Need at least 4 samples.")
        # Return with random coordinates for very small datasets
        df["x"] = np.random.uniform(-10, 10, n_samples)
        df["y"] = np.random.uniform(-10, 10, n_samples)
        return df

    if method == "TSNE":
        # Perplexity must be less than n_samples
        perplexity = min(30, n_samples - 1)
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        embedding = tsne.fit_transform(X)
    else:
        # Neighborhood size cannot exceed the number of other samples
        n_neighbors = min(15, n_samples - 1)
        reducer = umap.UMAP(metric="jaccard", n_neighbors=n_neighbors)
        embedding = reducer.fit_transform(X)

    # Add coordinates to DataFrame
    df["x"] = embedding[:, 0]
    df["y"] = embedding[:, 1]

    # If vertices disconnect from the manifold, they are given NaN values (so replace with 0)
    df["x"] = df["x"].fillna(0)
    df["y"] = df["y"].fillna(0)

    # Small positive jitter (presumably to separate points with identical coordinates)
    jitter_scale = 0.1
    df["x"] += np.random.uniform(0, jitter_scale, len(df))
    df["y"] += np.random.uniform(0, jitter_scale, len(df))

    return df
99
+
100
+
101
if __name__ == "__main__":
    # Manual smoke tests for the projection helpers (requires RDKit for fingerprint generation)
    print("Running molecular projection tests...")

    from rdkit import Chem
    from rdkit.Chem import rdFingerprintGenerator

    # Test molecules
    test_molecules = {
        "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
        "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O",
        "sodium_acetate": "CC(=O)[O-].[Na+]",
        "benzene": "c1ccccc1",
        "toluene": "Cc1ccccc1",
        "phenol": "Oc1ccccc1",
        "aniline": "Nc1ccccc1",
    }

    # Test 1: Generate fingerprints
    print("\n1. Generating test fingerprints...")

    test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})

    # Generate Morgan fingerprints as bitstrings (None where the SMILES does not parse)
    mols = [Chem.MolFromSmiles(smi) for smi in test_df["SMILES"]]
    morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=512)
    test_df["fingerprint"] = [
        morgan_gen.GetFingerprint(mol).ToBitString() if mol is not None else None for mol in mols
    ]

    # Remove any failed molecules
    test_df = test_df.dropna(subset=["fingerprint"])
    print(f" Generated {len(test_df)} fingerprints")

    # Test 2: Fingerprint to matrix conversion
    print("\n2. Testing fingerprint matrix conversion...")

    matrix = fingerprints_to_matrix(test_df["fingerprint"])
    print(f" Matrix shape: {matrix.shape}")
    print(f" Matrix dtype: {matrix.dtype}")
    print(f" Non-zero elements: {np.count_nonzero(matrix)}")

    # Test 3: TSNE projection
    print("\n3. Testing TSNE projection...")

    try:
        proj_df = project_fingerprints(test_df.copy(), projection="TSNE")

        print(" TSNE projection results:")
        for _, row in proj_df.head(4).iterrows():
            print(f" {row['name']:15} → x:{row['x']:7.2f} y:{row['y']:7.2f}")

        # Report explicitly rather than via assert (which the except below would swallow)
        if {"x", "y"}.issubset(proj_df.columns):
            print(f" ✓ Successfully projected {len(proj_df)} molecules")
        else:
            print(" ✗ TSNE projection did not add x/y columns")

    except Exception as e:
        print(f" Note: TSNE projection test limited: {e}")

    # Test 4: UMAP projection (if available)
    print("\n4. Testing UMAP projection...")

    if umap is not None:
        try:
            proj_umap_df = project_fingerprints(test_df.copy(), projection="UMAP")

            print(" UMAP projection results:")
            for _, row in proj_umap_df.head(4).iterrows():
                print(f" {row['name']:15} → x:{row['x']:7.2f} y:{row['y']:7.2f}")

            print(f" ✓ Successfully projected {len(proj_umap_df)} molecules with UMAP")

        except Exception as e:
            print(f" Note: UMAP projection failed: {e}")
    else:
        print(" UMAP not available - skipping test")

    # Test 5: Edge cases
    print("\n5. Testing edge cases...")

    # Test with missing fingerprint column
    no_fp_df = pd.DataFrame({"SMILES": ["CCO", "CC"]})
    try:
        project_fingerprints(no_fp_df)
        print(" ✗ Should have raised error for missing fingerprint column")
    except ValueError as e:
        print(f" ✓ Correctly raised error for missing fingerprint: {str(e)}")

    # Test with a dataset smaller than the 4-sample minimum used by project_fingerprints
    small_df = test_df.head(2).copy()
    if len(small_df) > 0:
        try:
            project_fingerprints(small_df, projection="TSNE")
            print(" Note: Small dataset projection handled")
        except Exception as e:
            print(f" Note: Small dataset appropriately failed: {type(e).__name__}")

    # Test 6: Testing NaN value handling
    print("\n6. Testing NaN value handling...")

    try:
        # The projection should handle NaN values by replacing with 0
        proj_test = project_fingerprints(test_df.copy(), projection="TSNE")
        has_nan = proj_test[["x", "y"]].isnull().any().any()
        print(f" NaN values in output: {has_nan}")
        # Only claim success when the output is actually NaN-free
        if has_nan:
            print(" ✗ NaN values found in projected coordinates")
        else:
            print(" ✓ NaN values properly handled")
    except Exception as e:
        print(f" Note: Could not test NaN handling due to: {e}")

    print("\n✅ All projection tests completed!")