workbench 0.8.202__py3-none-any.whl → 0.8.220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of workbench might be problematic. Click here for more details.

Files changed (84)
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/feature_space_proximity.py +168 -75
  3. workbench/algorithms/dataframe/fingerprint_proximity.py +421 -85
  4. workbench/algorithms/dataframe/projection_2d.py +44 -21
  5. workbench/algorithms/dataframe/proximity.py +78 -150
  6. workbench/algorithms/graph/light/proximity_graph.py +5 -5
  7. workbench/algorithms/models/cleanlab_model.py +382 -0
  8. workbench/algorithms/models/noise_model.py +388 -0
  9. workbench/algorithms/sql/outliers.py +3 -3
  10. workbench/api/__init__.py +3 -0
  11. workbench/api/df_store.py +17 -108
  12. workbench/api/endpoint.py +13 -11
  13. workbench/api/feature_set.py +111 -8
  14. workbench/api/meta_model.py +289 -0
  15. workbench/api/model.py +45 -12
  16. workbench/api/parameter_store.py +3 -52
  17. workbench/cached/cached_model.py +4 -4
  18. workbench/core/artifacts/artifact.py +5 -5
  19. workbench/core/artifacts/df_store_core.py +114 -0
  20. workbench/core/artifacts/endpoint_core.py +228 -237
  21. workbench/core/artifacts/feature_set_core.py +185 -230
  22. workbench/core/artifacts/model_core.py +34 -26
  23. workbench/core/artifacts/parameter_store_core.py +98 -0
  24. workbench/core/pipelines/pipeline_executor.py +1 -1
  25. workbench/core/transforms/features_to_model/features_to_model.py +22 -10
  26. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +41 -10
  27. workbench/core/transforms/pandas_transforms/pandas_to_features.py +11 -2
  28. workbench/model_script_utils/model_script_utils.py +339 -0
  29. workbench/model_script_utils/pytorch_utils.py +405 -0
  30. workbench/model_script_utils/uq_harness.py +278 -0
  31. workbench/model_scripts/chemprop/chemprop.template +428 -631
  32. workbench/model_scripts/chemprop/generated_model_script.py +432 -635
  33. workbench/model_scripts/chemprop/model_script_utils.py +339 -0
  34. workbench/model_scripts/chemprop/requirements.txt +2 -10
  35. workbench/model_scripts/custom_models/chem_info/fingerprints.py +87 -46
  36. workbench/model_scripts/custom_models/proximity/feature_space_proximity.py +194 -0
  37. workbench/model_scripts/custom_models/proximity/feature_space_proximity.template +6 -6
  38. workbench/model_scripts/custom_models/uq_models/feature_space_proximity.py +194 -0
  39. workbench/model_scripts/meta_model/generated_model_script.py +209 -0
  40. workbench/model_scripts/meta_model/meta_model.template +209 -0
  41. workbench/model_scripts/pytorch_model/generated_model_script.py +374 -613
  42. workbench/model_scripts/pytorch_model/model_script_utils.py +339 -0
  43. workbench/model_scripts/pytorch_model/pytorch.template +370 -609
  44. workbench/model_scripts/pytorch_model/pytorch_utils.py +405 -0
  45. workbench/model_scripts/pytorch_model/requirements.txt +1 -1
  46. workbench/model_scripts/pytorch_model/uq_harness.py +278 -0
  47. workbench/model_scripts/script_generation.py +6 -5
  48. workbench/model_scripts/uq_models/generated_model_script.py +65 -422
  49. workbench/model_scripts/xgb_model/generated_model_script.py +372 -395
  50. workbench/model_scripts/xgb_model/model_script_utils.py +339 -0
  51. workbench/model_scripts/xgb_model/uq_harness.py +278 -0
  52. workbench/model_scripts/xgb_model/xgb_model.template +366 -396
  53. workbench/repl/workbench_shell.py +0 -5
  54. workbench/resources/open_source_api.key +1 -1
  55. workbench/scripts/endpoint_test.py +2 -2
  56. workbench/scripts/meta_model_sim.py +35 -0
  57. workbench/scripts/training_test.py +85 -0
  58. workbench/utils/chem_utils/fingerprints.py +87 -46
  59. workbench/utils/chem_utils/projections.py +16 -6
  60. workbench/utils/chemprop_utils.py +36 -655
  61. workbench/utils/meta_model_simulator.py +499 -0
  62. workbench/utils/metrics_utils.py +256 -0
  63. workbench/utils/model_utils.py +192 -54
  64. workbench/utils/pytorch_utils.py +33 -472
  65. workbench/utils/shap_utils.py +1 -55
  66. workbench/utils/xgboost_local_crossfold.py +267 -0
  67. workbench/utils/xgboost_model_utils.py +49 -356
  68. workbench/web_interface/components/model_plot.py +7 -1
  69. workbench/web_interface/components/plugins/model_details.py +30 -68
  70. workbench/web_interface/components/plugins/scatter_plot.py +4 -8
  71. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/METADATA +6 -5
  72. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/RECORD +76 -60
  73. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/entry_points.txt +2 -0
  74. workbench/core/cloud_platform/aws/aws_df_store.py +0 -404
  75. workbench/core/cloud_platform/aws/aws_parameter_store.py +0 -296
  76. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  77. workbench/model_scripts/custom_models/proximity/proximity.py +0 -410
  78. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  79. workbench/model_scripts/custom_models/uq_models/proximity.py +0 -410
  80. workbench/model_scripts/uq_models/mapie.template +0 -605
  81. workbench/model_scripts/uq_models/requirements.txt +0 -1
  82. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/WHEEL +0 -0
  83. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/licenses/LICENSE +0 -0
  84. {workbench-0.8.202.dist-info → workbench-0.8.220.dist-info}/top_level.txt +0 -0
@@ -302,11 +302,6 @@ class WorkbenchShell:
302
302
  self.commands["PandasToView"] = importlib.import_module("workbench.core.views.pandas_to_view").PandasToView
303
303
  self.commands["Pipeline"] = importlib.import_module("workbench.api.pipeline").Pipeline
304
304
 
305
- # Algorithms
306
- self.commands["FSP"] = importlib.import_module(
307
- "workbench.algorithms.dataframe.feature_space_proximity"
308
- ).FeatureSpaceProximity
309
-
310
305
  # These are 'nice to have' imports
311
306
  self.commands["pd"] = importlib.import_module("pandas")
312
307
  self.commands["wr"] = importlib.import_module("awswrangler")
@@ -1 +1 @@
1
- eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTAxLTE0In02zCDRy41wKRViRnGmodczFWexLyfXYrJWSuVQQbhbWeRttQRv6zpo9x4O2yBjdRfhb9E7mFUppNiOS_ZGK-bL71nGHt_Mc8niG8jkpvKX9qZ6BqkXF_vzDIOcI8iGiwB3wikeVO4zRLD1AI0U3cgYmIyGXI9QKJ9L7IHyQ0TWqw==
1
+ eyJsaWNlbnNlX2lkIjogIk9wZW5fU291cmNlX0xpY2Vuc2UiLCAiY29tcGFueSI6ICIiLCAiYXdzX2FjY291bnRfaWQiOiAiIiwgInRpZXIiOiAiRW50ZXJwcmlzZSBQcm8iLCAiZmVhdHVyZXMiOiBbInBsdWdpbnMiLCAicGFnZXMiLCAidGhlbWVzIiwgInBpcGVsaW5lcyIsICJicmFuZGluZyJdLCAiZXhwaXJlcyI6ICIyMDI2LTEyLTA1In1IsmpkuybFALADkRj_RfmkQ0LAIsQeXRE7Uoc3DL1UrDr-rSnwu-PDqsKBUkX6jPRFZV3DLxNjBapxPeEIFhfvxvjzz_sc6CwtxNpZ3bPmxSPs2W-j3xZS4-XyEqIilcwSkWh-NU1u27gCuuivn5eiUmIYJGAp0wdVkeE6_Z9dlg==
@@ -5,7 +5,7 @@ Usage:
5
5
  python model_script_harness.py <local_script.py> <model_name>
6
6
 
7
7
  Example:
8
- python model_script_harness.py pytorch.py aqsol-pytorch-reg
8
+ python model_script_harness.py pytorch.py aqsol-reg-pytorch
9
9
 
10
10
  This allows you to test LOCAL changes to a model script against deployed model artifacts.
11
11
  Evaluation data is automatically pulled from the FeatureSet (training = FALSE rows).
@@ -72,7 +72,7 @@ def main():
72
72
  print("Usage: python model_script_harness.py <local_script.py> <model_name>")
73
73
  print("\nArguments:")
74
74
  print(" local_script.py - Path to your LOCAL model script to test")
75
- print(" model_name - Workbench model name (e.g., aqsol-pytorch-reg)")
75
+ print(" model_name - Workbench model name (e.g., aqsol-reg-pytorch)")
76
76
  print("\nOptional: testing/env.json with additional environment variables")
77
77
  sys.exit(1)
78
78
 
@@ -0,0 +1,35 @@
1
+ """MetaModelSimulator: Simulate and analyze ensemble model performance.
2
+
3
+ This class helps evaluate whether a meta model (ensemble) would outperform
4
+ individual child models by analyzing endpoint inference predictions.
5
+ """
6
+
7
+ import argparse
8
+ from workbench.utils.meta_model_simulator import MetaModelSimulator
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(
13
+ description="Simulate and analyze ensemble model performance using MetaModelSimulator."
14
+ )
15
+ parser.add_argument(
16
+ "models",
17
+ nargs="+",
18
+ help="List of model endpoint names to include in the ensemble simulation.",
19
+ )
20
+ parser.add_argument(
21
+ "--id-column",
22
+ default="molecule_name",
23
+ help="Name of the ID column (default: molecule_name)",
24
+ )
25
+ args = parser.parse_args()
26
+ models = args.models
27
+ id_column = args.id_column
28
+
29
+ # Create MetaModelSimulator instance and generate report
30
+ sim = MetaModelSimulator(models, id_column=id_column)
31
+ sim.report()
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()
@@ -0,0 +1,85 @@
1
+ """
2
+ Local test harness for SageMaker training scripts.
3
+
4
+ Usage:
5
+ python training_test.py <model_script.py> <featureset_name>
6
+
7
+ Example:
8
+ python training_test.py ../model_scripts/pytorch_model/generated_model_script.py caco2-class-features
9
+ """
10
+
11
+ import os
12
+ import shutil
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+
17
+ import pandas as pd
18
+
19
+ from workbench.api import FeatureSet
20
+
21
+
22
+ def get_training_data(featureset_name: str) -> pd.DataFrame:
23
+ """Get training data from the FeatureSet."""
24
+ fs = FeatureSet(featureset_name)
25
+ return fs.pull_dataframe()
26
+
27
+
28
+ def main():
29
+ if len(sys.argv) < 3:
30
+ print("Usage: python training_test.py <model_script.py> <featureset_name>")
31
+ sys.exit(1)
32
+
33
+ script_path = sys.argv[1]
34
+ featureset_name = sys.argv[2]
35
+
36
+ if not os.path.exists(script_path):
37
+ print(f"Error: Script not found: {script_path}")
38
+ sys.exit(1)
39
+
40
+ # Create temp directories
41
+ model_dir = tempfile.mkdtemp(prefix="training_model_")
42
+ train_dir = tempfile.mkdtemp(prefix="training_data_")
43
+ output_dir = tempfile.mkdtemp(prefix="training_output_")
44
+
45
+ print(f"Model dir: {model_dir}")
46
+ print(f"Train dir: {train_dir}")
47
+
48
+ try:
49
+ # Get training data and save to CSV
50
+ print(f"Loading FeatureSet: {featureset_name}")
51
+ df = get_training_data(featureset_name)
52
+ print(f"Data shape: {df.shape}")
53
+
54
+ train_file = os.path.join(train_dir, "training_data.csv")
55
+ df.to_csv(train_file, index=False)
56
+
57
+ # Set up environment
58
+ env = os.environ.copy()
59
+ env["SM_MODEL_DIR"] = model_dir
60
+ env["SM_CHANNEL_TRAIN"] = train_dir
61
+ env["SM_OUTPUT_DATA_DIR"] = output_dir
62
+
63
+ print("\n" + "=" * 60)
64
+ print("Starting training...")
65
+ print("=" * 60 + "\n")
66
+
67
+ # Run the script
68
+ cmd = [sys.executable, script_path, "--model-dir", model_dir, "--train", train_dir]
69
+ result = subprocess.run(cmd, env=env)
70
+
71
+ print("\n" + "=" * 60)
72
+ if result.returncode == 0:
73
+ print("Training completed successfully!")
74
+ else:
75
+ print(f"Training failed with return code: {result.returncode}")
76
+ print("=" * 60)
77
+
78
+ finally:
79
+ shutil.rmtree(model_dir, ignore_errors=True)
80
+ shutil.rmtree(train_dir, ignore_errors=True)
81
+ shutil.rmtree(output_dir, ignore_errors=True)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
@@ -1,31 +1,48 @@
1
- """Molecular fingerprint computation utilities"""
1
+ """Molecular fingerprint computation utilities for ADMET modeling.
2
+
3
+ This module provides Morgan count fingerprints, the standard for ADMET prediction.
4
+ Count fingerprints outperform binary fingerprints for molecular property prediction.
5
+
6
+ References:
7
+ - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
8
+ - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
9
+ """
2
10
 
3
11
  import logging
4
- import pandas as pd
5
12
 
6
- # Molecular Descriptor Imports
7
- from rdkit import Chem
8
- from rdkit.Chem import rdFingerprintGenerator
13
+ import numpy as np
14
+ import pandas as pd
15
+ from rdkit import Chem, RDLogger
16
+ from rdkit.Chem import AllChem
9
17
  from rdkit.Chem.MolStandardize import rdMolStandardize
10
18
 
19
+ # Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
20
+ # Keep errors enabled so we see actual problems
21
+ RDLogger.DisableLog("rdApp.warning")
22
+
11
23
  # Set up the logger
12
24
  log = logging.getLogger("workbench")
13
25
 
14
26
 
15
- def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
16
- """Compute and add Morgan fingerprints to the DataFrame.
27
+ def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
28
+ """Compute Morgan count fingerprints for ADMET modeling.
29
+
30
+ Generates true count fingerprints where each bit position contains the
31
+ number of times that substructure appears in the molecule (clamped to 0-255).
32
+ This is the recommended approach for ADMET prediction per 2025 research.
17
33
 
18
34
  Args:
19
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
20
- radius (int): Radius for the Morgan fingerprint.
21
- n_bits (int): Number of bits for the fingerprint.
22
- counts (bool): Count simulation for the fingerprint.
35
+ df: Input DataFrame containing SMILES strings.
36
+ radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
37
+ n_bits: Number of bits for the fingerprint (default 2048).
23
38
 
24
39
  Returns:
25
- pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
40
+ pd.DataFrame: Input DataFrame with 'fingerprint' column added.
41
+ Values are comma-separated uint8 counts.
26
42
 
27
43
  Note:
28
- See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
44
+ Count fingerprints outperform binary for ADMET prediction.
45
+ See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
29
46
  """
30
47
  delete_mol_column = False
31
48
 
@@ -39,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
39
56
  log.warning("Detected serialized molecules in 'molecule' column. Removing...")
40
57
  del df["molecule"]
41
58
 
42
- # Convert SMILES to RDKit molecule objects (vectorized)
59
+ # Convert SMILES to RDKit molecule objects
43
60
  if "molecule" not in df.columns:
44
61
  log.info("Converting SMILES to RDKit Molecules...")
45
62
  delete_mol_column = True
@@ -47,23 +64,32 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
47
64
  # Make sure our molecules are not None
48
65
  failed_smiles = df[df["molecule"].isnull()][smiles_column].tolist()
49
66
  if failed_smiles:
50
- log.error(f"Failed to convert the following SMILES to molecules: {failed_smiles}")
51
- df = df.dropna(subset=["molecule"])
67
+ log.warning(f"Failed to convert {len(failed_smiles)} SMILES to molecules ({failed_smiles})")
68
+ df = df.dropna(subset=["molecule"]).copy()
52
69
 
53
70
  # If we have fragments in our compounds, get the largest fragment before computing fingerprints
54
71
  largest_frags = df["molecule"].apply(
55
72
  lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
56
73
  )
57
74
 
58
- # Create a Morgan fingerprint generator
59
- if counts:
60
- n_bits *= 4 # Multiply by 4 to simulate counts
61
- morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
75
+ def mol_to_count_string(mol):
76
+ """Convert molecule to comma-separated count fingerprint string."""
77
+ if mol is None:
78
+ return pd.NA
62
79
 
63
- # Compute Morgan fingerprints (vectorized)
64
- fingerprints = largest_frags.apply(
65
- lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
66
- )
80
+ # Get hashed Morgan fingerprint with counts
81
+ fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
82
+
83
+ # Initialize array and populate with counts (clamped to uint8 range)
84
+ counts = np.zeros(n_bits, dtype=np.uint8)
85
+ for idx, count in fp.GetNonzeroElements().items():
86
+ counts[idx] = min(count, 255)
87
+
88
+ # Return as comma-separated string
89
+ return ",".join(map(str, counts))
90
+
91
+ # Compute Morgan count fingerprints
92
+ fingerprints = largest_frags.apply(mol_to_count_string)
67
93
 
68
94
  # Add the fingerprints to the DataFrame
69
95
  df["fingerprint"] = fingerprints
@@ -71,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
71
97
  # Drop the intermediate 'molecule' column if it was added
72
98
  if delete_mol_column:
73
99
  del df["molecule"]
100
+
74
101
  return df
75
102
 
76
103
 
77
104
  if __name__ == "__main__":
78
- print("Running molecular fingerprint tests...")
79
- print("Note: This requires molecular_screening module to be available")
105
+ print("Running Morgan count fingerprint tests...")
80
106
 
81
107
  # Test molecules
82
108
  test_molecules = {
83
109
  "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
84
110
  "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
85
111
  "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # With stereochemistry
86
- "sodium_acetate": "CC(=O)[O-].[Na+]", # Salt
112
+ "sodium_acetate": "CC(=O)[O-].[Na+]", # Salt (largest fragment used)
87
113
  "benzene": "c1ccccc1",
88
114
  "butene_e": "C/C=C/C", # E-butene
89
115
  "butene_z": "C/C=C\\C", # Z-butene
90
116
  }
91
117
 
92
- # Test 1: Morgan Fingerprints
93
- print("\n1. Testing Morgan fingerprint generation...")
118
+ # Test 1: Morgan Count Fingerprints (default parameters)
119
+ print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
94
120
 
95
121
  test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
96
-
97
- fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
122
+ fp_df = compute_morgan_fingerprints(test_df.copy())
98
123
 
99
124
  print(" Fingerprint generation results:")
100
125
  for _, row in fp_df.iterrows():
101
126
  fp = row.get("fingerprint", "N/A")
102
- fp_len = len(fp) if fp != "N/A" else 0
103
- print(f" {row['name']:15} {fp_len} bits")
127
+ if pd.notna(fp):
128
+ counts = [int(x) for x in fp.split(",")]
129
+ non_zero = sum(1 for c in counts if c > 0)
130
+ max_count = max(counts)
131
+ print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
132
+ else:
133
+ print(f" {row['name']:15} → N/A")
104
134
 
105
- # Test 2: Different fingerprint parameters
106
- print("\n2. Testing different fingerprint parameters...")
135
+ # Test 2: Different parameters
136
+ print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
107
137
 
108
- # Test with counts enabled
109
- fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
138
+ fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
110
139
 
111
- print(" With count simulation (256 bits * 4):")
112
- for _, row in fp_counts_df.iterrows():
140
+ for _, row in fp_df_custom.iterrows():
113
141
  fp = row.get("fingerprint", "N/A")
114
- fp_len = len(fp) if fp != "N/A" else 0
115
- print(f" {row['name']:15} {fp_len} bits")
142
+ if pd.notna(fp):
143
+ counts = [int(x) for x in fp.split(",")]
144
+ non_zero = sum(1 for c in counts if c > 0)
145
+ print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
146
+ else:
147
+ print(f" {row['name']:15} → N/A")
116
148
 
117
149
  # Test 3: Edge cases
118
150
  print("\n3. Testing edge cases...")
119
151
 
120
152
  # Invalid SMILES
121
153
  invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
122
- try:
123
- fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
124
- print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
125
- except Exception as e:
126
- print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
154
+ fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
155
+ print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
127
156
 
128
157
  # Test with pre-existing molecule column
129
158
  mol_df = test_df.copy()
@@ -131,4 +160,16 @@ if __name__ == "__main__":
131
160
  fp_with_mol = compute_morgan_fingerprints(mol_df)
132
161
  print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
133
162
 
163
+ # Test 4: Verify count values are reasonable
164
+ print("\n4. Verifying count distribution...")
165
+ all_counts = []
166
+ for _, row in fp_df.iterrows():
167
+ fp = row.get("fingerprint", "N/A")
168
+ if pd.notna(fp):
169
+ counts = [int(x) for x in fp.split(",")]
170
+ all_counts.extend([c for c in counts if c > 0])
171
+
172
+ if all_counts:
173
+ print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
174
+
134
175
  print("\n✅ All fingerprint tests completed!")
@@ -17,18 +17,28 @@ log = logging.getLogger("workbench")
17
17
 
18
18
  def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
19
19
  """
20
- Convert bitstring fingerprints to numpy matrix.
20
+ Convert fingerprints to numpy matrix.
21
+
22
+ Supports two formats (auto-detected):
23
+ - Bitstrings: "10110010..." → matrix of 0s and 1s
24
+ - Count vectors: "0,3,0,1,5,..." → matrix of counts (or binary if dtype=np.bool_)
21
25
 
22
26
  Args:
23
- fingerprints: pandas Series or list of bitstring fingerprints
24
- dtype: numpy data type (uint8 is default: np.bool_ is good for Jaccard computations
27
+ fingerprints: pandas Series or list of fingerprints
28
+ dtype: numpy data type (uint8 is default; np.bool_ for Jaccard computations)
25
29
 
26
30
  Returns:
27
31
  dense numpy array of shape (n_molecules, n_bits)
28
32
  """
29
-
30
- # Dense matrix representation (we might support sparse in the future)
31
- return np.array([list(fp) for fp in fingerprints], dtype=dtype)
33
+ # Auto-detect format based on first fingerprint
34
+ sample = str(fingerprints.iloc[0] if hasattr(fingerprints, "iloc") else fingerprints[0])
35
+ if "," in sample:
36
+ # Count vector format: comma-separated integers
37
+ matrix = np.array([list(map(int, fp.split(","))) for fp in fingerprints], dtype=dtype)
38
+ else:
39
+ # Bitstring format: each character is a bit
40
+ matrix = np.array([list(fp) for fp in fingerprints], dtype=dtype)
41
+ return matrix
32
42
 
33
43
 
34
44
  def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame: