workbench 0.8.217__py3-none-any.whl → 0.8.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. workbench/algorithms/dataframe/compound_dataset_overlap.py +321 -0
  2. workbench/algorithms/dataframe/fingerprint_proximity.py +190 -31
  3. workbench/algorithms/dataframe/projection_2d.py +8 -2
  4. workbench/algorithms/dataframe/proximity.py +3 -0
  5. workbench/algorithms/sql/outliers.py +3 -3
  6. workbench/api/feature_set.py +0 -1
  7. workbench/core/artifacts/endpoint_core.py +2 -2
  8. workbench/core/artifacts/feature_set_core.py +185 -230
  9. workbench/core/transforms/features_to_model/features_to_model.py +2 -8
  10. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +2 -0
  11. workbench/model_script_utils/model_script_utils.py +15 -11
  12. workbench/model_scripts/chemprop/chemprop.template +195 -70
  13. workbench/model_scripts/chemprop/generated_model_script.py +198 -73
  14. workbench/model_scripts/chemprop/model_script_utils.py +15 -11
  15. workbench/model_scripts/custom_models/chem_info/fingerprints.py +80 -43
  16. workbench/model_scripts/pytorch_model/generated_model_script.py +2 -2
  17. workbench/model_scripts/pytorch_model/model_script_utils.py +15 -11
  18. workbench/model_scripts/xgb_model/generated_model_script.py +7 -7
  19. workbench/model_scripts/xgb_model/model_script_utils.py +15 -11
  20. workbench/scripts/meta_model_sim.py +35 -0
  21. workbench/scripts/ml_pipeline_sqs.py +71 -2
  22. workbench/themes/light/custom.css +7 -1
  23. workbench/themes/midnight_blue/custom.css +34 -0
  24. workbench/utils/chem_utils/fingerprints.py +80 -43
  25. workbench/utils/chem_utils/projections.py +16 -6
  26. workbench/utils/meta_model_simulator.py +41 -13
  27. workbench/utils/model_utils.py +0 -1
  28. workbench/utils/plot_utils.py +146 -28
  29. workbench/utils/shap_utils.py +1 -55
  30. workbench/utils/theme_manager.py +95 -30
  31. workbench/web_interface/components/plugins/scatter_plot.py +152 -66
  32. workbench/web_interface/components/settings_menu.py +184 -0
  33. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/METADATA +4 -13
  34. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/RECORD +38 -37
  35. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/entry_points.txt +1 -0
  36. workbench/model_scripts/custom_models/meta_endpoints/example.py +0 -53
  37. workbench/model_scripts/custom_models/uq_models/meta_uq.template +0 -377
  38. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/WHEEL +0 -0
  39. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/licenses/LICENSE +0 -0
  40. {workbench-0.8.217.dist-info → workbench-0.8.224.dist-info}/top_level.txt +0 -0
@@ -61,10 +61,10 @@ DEFAULT_HYPERPARAMETERS = {
61
61
  TEMPLATE_PARAMS = {
62
62
  "model_type": "uq_regressor",
63
63
  "target": "udm_asy_res_efflux_ratio",
64
- "features": ['smr_vsa4', 'tpsa', 'numhdonors', 'nhohcount', 'nbase', 'vsa_estate3', 'fr_guanido', 'mollogp', 'peoe_vsa8', 'peoe_vsa1', 'fr_imine', 'vsa_estate2', 'estate_vsa10', 'asphericity', 'xc_3dv', 'smr_vsa3', 'charge_centroid_distance', 'c3sp3', 'nitrogen_span', 'estate_vsa2', 'minpartialcharge', 'hba_hbd_ratio', 'slogp_vsa1', 'axp_7d', 'nocount', 'vsa_estate4', 'vsa_estate6', 'estate_vsa4', 'xc_4dv', 'xc_4d', 'num_s_centers', 'vsa_estate9', 'chi2v', 'axp_5d', 'mi', 'mse', 'bcut2d_mrhi', 'smr_vsa6', 'hallkieralpha', 'balabanj', 'amphiphilic_moment', 'type_ii_pattern_count', 'minabsestateindex', 'bcut2d_mwlow', 'axp_0dv', 'slogp_vsa5', 'axp_2d', 'axp_1dv', 'xch_5d', 'peoe_vsa10', 'molecular_asymmetry', 'kappa3', 'estate_vsa3', 'sse', 'bcut2d_logphi', 'fr_imidazole', 'molecular_volume_3d', 'bertzct', 'maxestateindex', 'aromatic_interaction_score', 'axp_3d', 'radius_of_gyration', 'vsa_estate7', 'si', 'axp_5dv', 'molecular_axis_length', 'estate_vsa6', 'fpdensitymorgan1', 'axp_6d', 'estate_vsa9', 'fpdensitymorgan2', 'xp_0dv', 'xp_6dv', 'molmr', 'qed', 'estate_vsa8', 'peoe_vsa9', 'xch_6dv', 'xp_7d', 'slogp_vsa2', 'xp_5dv', 'bcut2d_chghi', 'xch_6d', 'chi0n', 'slogp_vsa3', 'chi1v', 'chi3v', 'bcut2d_chglo', 'axp_1d', 'mp', 'num_defined_stereocenters', 'xp_3dv', 'bcut2d_mrlow', 'fr_al_oh', 'peoe_vsa7', 'chi2n', 'axp_6dv', 'axp_2dv', 'chi4n', 'xc_3d', 'axp_7dv', 'vsa_estate8', 'xch_7d', 'maxpartialcharge', 'chi1n', 'peoe_vsa2', 'axp_3dv', 'bcut2d_logplow', 'mv', 'xpc_5dv', 'kappa2', 'vsa_estate5', 'xp_5d', 'mm', 'maxabspartialcharge', 'axp_4dv', 'maxabsestateindex', 'axp_4d', 'xch_4dv', 'xp_2dv', 'heavyatommolwt', 'numatomstereocenters', 'xp_7dv', 'numsaturatedheterocycles', 'xp_3d', 'kappa1', 'mz', 'axp_0d', 'chi1', 'xch_4d', 'smr_vsa1', 'xp_2d', 'estate_vsa5', 'phi', 'fr_ether', 'xc_5d', 'c1sp3', 'estate_vsa7', 'estate_vsa1', 'vsa_estate1', 'slogp_vsa4', 'avgipc', 'smr_vsa10', 'numvalenceelectrons', 'xc_5dv', 'peoe_vsa12', 'peoe_vsa6', 'xpc_5d', 'xpc_6d', 'minestateindex', 'chi3n', 'smr_vsa5', 'xp_4d', 'numheteroatoms', 'fpdensitymorgan3', 'xpc_4d', 'sps', 'xp_1d', 'sv', 'fr_ar_n', 'slogp_vsa10', 'c2sp3', 'xpc_4dv', 'chi0v', 'xpc_6dv', 'xp_1dv', 'vsa_estate10', 'sare', 'c2sp2', 'mpe', 'xch_7dv', 'chi4v', 'type_i_pattern_count', 'sp', 'slogp_vsa8', 'amide_count', 'num_stereocenters', 'num_r_centers', 'tertiary_amine_count', 'spe', 'xp_4dv', 'numsaturatedrings', 'mare', 'numhacceptors', 'chi0', 'fractioncsp3', 'fr_nh0', 'xch_5dv', 'fr_aniline', 'smr_vsa7', 'labuteasa', 'c3sp2', 'xp_0d', 'xp_6d', 'peoe_vsa11', 'fr_ar_nh', 'molwt', 'intramolecular_hbond_potential', 'peoe_vsa3', 'fr_nhpyrrole', 'numaliphaticrings', 'hybratio', 'smr_vsa9', 'peoe_vsa13', 'bcut2d_mwhi', 'c1sp2', 'slogp_vsa11', 'numrotatablebonds', 'numaliphaticcarbocycles', 'slogp_vsa6', 'peoe_vsa4', 'numunspecifiedatomstereocenters', 'xc_6d', 'xc_6dv', 'num_unspecified_stereocenters', 'sz', 'minabspartialcharge', 'fcsp3', 'c1sp1', 'fr_piperzine', 'numaliphaticheterocycles', 'numamidebonds', 'fr_benzene', 'numaromaticheterocycles', 'sm', 'fr_priamide', 'fr_piperdine', 'fr_methoxy', 'c4sp3', 'fr_c_o_nocoo', 'exactmolwt', 'stereo_complexity', 'fr_hoccn', 'numaromaticcarbocycles', 'fr_nh2', 'numheterocycles', 'fr_morpholine', 'fr_ketone', 'fr_nh1', 'frac_defined_stereo', 'fr_aryl_methyl', 'fr_alkyl_halide', 'fr_phenol', 'fr_al_oh_notert', 'fr_ar_oh', 'fr_pyridine', 'fr_amide', 'slogp_vsa7', 'fr_halogen', 'numsaturatedcarbocycles', 'slogp_vsa12', 'fr_ndealkylation1', 'xch_3d', 'fr_bicyclic', 'naromatom', 'narombond'],
64
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
65
65
  "id_column": "udm_mol_bat_id",
66
66
  "compressed_features": [],
67
- "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-pytorch/training",
67
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-pytorch-260113/training",
68
68
  "hyperparameters": {},
69
69
  }
70
70
 
@@ -148,12 +148,16 @@ def convert_categorical_types(
148
148
  def decompress_features(
149
149
  df: pd.DataFrame, features: list[str], compressed_features: list[str]
150
150
  ) -> tuple[pd.DataFrame, list[str]]:
151
- """Decompress bitstring features into individual bit columns.
151
+ """Decompress compressed features (bitstrings or count vectors) into individual columns.
152
+
153
+ Supports two formats (auto-detected):
154
+ - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
155
+ - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
152
156
 
153
157
  Args:
154
158
  df: The features DataFrame
155
159
  features: Full list of feature names
156
- compressed_features: List of feature names to decompress (bitstrings)
160
+ compressed_features: List of feature names to decompress
157
161
 
158
162
  Returns:
159
163
  Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@ def decompress_features(
178
182
  # Remove the feature from the list to avoid duplication
179
183
  decompressed_features.remove(feature)
180
184
 
181
- # Handle all compressed features as bitstrings
182
- bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
183
- prefix = feature[:3]
185
+ # Auto-detect format and parse: comma-separated counts or bitstring
186
+ sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
187
+ parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
188
+ feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
184
189
 
185
- # Create all new columns at once - avoids fragmentation
186
- new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
187
- new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
190
+ # Create new columns with prefix from feature name
191
+ prefix = feature[:3]
192
+ new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
193
+ new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
188
194
 
189
- # Add to features list
195
+ # Update features list and dataframe
190
196
  decompressed_features.extend(new_col_names)
191
-
192
- # Drop original column and concatenate new ones
193
197
  df = df.drop(columns=[feature])
194
198
  df = pd.concat([df, new_df], axis=1)
195
199
 
@@ -63,13 +63,13 @@ REGRESSION_ONLY_PARAMS = {"objective"}
63
63
 
64
64
  # Template parameters (filled in by Workbench)
65
65
  TEMPLATE_PARAMS = {
66
- "model_type": "classifier",
67
- "target": "solubility_class",
68
- "features": ['molwt', 'mollogp', 'molmr', 'heavyatomcount', 'numhacceptors', 'numhdonors', 'numheteroatoms', 'numrotatablebonds', 'numvalenceelectrons', 'numaromaticrings', 'numsaturatedrings', 'numaliphaticrings', 'ringcount', 'tpsa', 'labuteasa', 'balabanj', 'bertzct'],
69
- "id_column": "id",
70
- "compressed_features": [],
71
- "model_metrics_s3_path": "s3://sandbox-sageworks-artifacts/models/aqsol-class/training",
72
- "hyperparameters": {},
66
+ "model_type": "uq_regressor",
67
+ "target": "udm_asy_res_efflux_ratio",
68
+ "features": ['chi2v', 'fr_sulfone', 'chi1v', 'bcut2d_logplow', 'fr_piperzine', 'kappa3', 'smr_vsa1', 'slogp_vsa5', 'fr_ketone_topliss', 'fr_sulfonamd', 'fr_imine', 'fr_benzene', 'fr_ester', 'chi2n', 'labuteasa', 'peoe_vsa2', 'smr_vsa6', 'bcut2d_chglo', 'fr_sh', 'peoe_vsa1', 'fr_allylic_oxid', 'chi4n', 'fr_ar_oh', 'fr_nh0', 'fr_term_acetylene', 'slogp_vsa7', 'slogp_vsa4', 'estate_vsa1', 'vsa_estate4', 'numbridgeheadatoms', 'numheterocycles', 'fr_ketone', 'fr_morpholine', 'fr_guanido', 'estate_vsa2', 'numheteroatoms', 'fr_nitro_arom_nonortho', 'fr_piperdine', 'nocount', 'numspiroatoms', 'fr_aniline', 'fr_thiophene', 'slogp_vsa10', 'fr_amide', 'slogp_vsa2', 'fr_epoxide', 'vsa_estate7', 'fr_ar_coo', 'fr_imidazole', 'fr_nitrile', 'fr_oxazole', 'numsaturatedrings', 'fr_pyridine', 'fr_hoccn', 'fr_ndealkylation1', 'numaliphaticheterocycles', 'fr_phenol', 'maxpartialcharge', 'vsa_estate5', 'peoe_vsa13', 'minpartialcharge', 'qed', 'fr_al_oh', 'slogp_vsa11', 'chi0n', 'fr_bicyclic', 'peoe_vsa12', 'fpdensitymorgan1', 'fr_oxime', 'molwt', 'fr_dihydropyridine', 'smr_vsa5', 'peoe_vsa5', 'fr_nitro', 'hallkieralpha', 'heavyatommolwt', 'fr_alkyl_halide', 'peoe_vsa8', 'fr_nhpyrrole', 'fr_isocyan', 'bcut2d_chghi', 'fr_lactam', 'peoe_vsa11', 'smr_vsa9', 'tpsa', 'chi4v', 'slogp_vsa1', 'phi', 'bcut2d_logphi', 'avgipc', 'estate_vsa11', 'fr_coo', 'bcut2d_mwhi', 'numunspecifiedatomstereocenters', 'vsa_estate10', 'estate_vsa8', 'numvalenceelectrons', 'fr_nh2', 'fr_lactone', 'vsa_estate1', 'estate_vsa4', 'numatomstereocenters', 'vsa_estate8', 'fr_para_hydroxylation', 'peoe_vsa3', 'fr_thiazole', 'peoe_vsa10', 'fr_ndealkylation2', 'slogp_vsa12', 'peoe_vsa9', 'maxestateindex', 'fr_quatn', 'smr_vsa7', 'minestateindex', 'numaromaticheterocycles', 'numrotatablebonds', 'fr_ar_nh', 'fr_ether', 'exactmolwt', 'fr_phenol_noorthohbond', 'slogp_vsa3', 'fr_ar_n', 'sps', 'fr_c_o_nocoo', 'bertzct', 'peoe_vsa7', 'slogp_vsa8', 'numradicalelectrons', 'molmr', 'fr_tetrazole', 'numsaturatedcarbocycles', 'bcut2d_mrhi', 'kappa1', 'numamidebonds', 'fpdensitymorgan2', 'smr_vsa8', 'chi1n', 'estate_vsa6', 'fr_barbitur', 'fr_diazo', 'kappa2', 'chi0', 'bcut2d_mrlow', 'balabanj', 'peoe_vsa4', 'numhacceptors', 'fr_sulfide', 'chi3n', 'smr_vsa2', 'fr_al_oh_notert', 'fr_benzodiazepine', 'fr_phos_ester', 'fr_aldehyde', 'fr_coo2', 'estate_vsa5', 'fr_prisulfonamd', 'numaromaticcarbocycles', 'fr_unbrch_alkane', 'fr_urea', 'fr_nitroso', 'smr_vsa10', 'fr_c_s', 'smr_vsa3', 'fr_methoxy', 'maxabspartialcharge', 'slogp_vsa9', 'heavyatomcount', 'fr_azide', 'chi3v', 'smr_vsa4', 'mollogp', 'chi0v', 'fr_aryl_methyl', 'fr_nh1', 'fpdensitymorgan3', 'fr_furan', 'fr_hdrzine', 'fr_arn', 'numaromaticrings', 'vsa_estate3', 'fr_azo', 'fr_halogen', 'estate_vsa9', 'fr_hdrzone', 'numhdonors', 'fr_alkyl_carbamate', 'fr_isothiocyan', 'minabspartialcharge', 'fr_al_coo', 'ringcount', 'chi1', 'estate_vsa7', 'fr_nitro_arom', 'vsa_estate9', 'minabsestateindex', 'maxabsestateindex', 'vsa_estate6', 'estate_vsa10', 'estate_vsa3', 'fr_n_o', 'fr_amidine', 'fr_thiocyan', 'fr_phos_acid', 'fr_c_o', 'fr_imide', 'numaliphaticrings', 'peoe_vsa6', 'vsa_estate2', 'nhohcount', 'numsaturatedheterocycles', 'slogp_vsa6', 'peoe_vsa14', 'fractioncsp3', 'bcut2d_mwlow', 'numaliphaticcarbocycles', 'fr_priamide', 'nacid', 'nbase', 'naromatom', 'narombond', 'sz', 'sm', 'sv', 'sse', 'spe', 'sare', 'sp', 'si', 'mz', 'mm', 'mv', 'mse', 'mpe', 'mare', 'mp', 'mi', 'xch_3d', 'xch_4d', 'xch_5d', 'xch_6d', 'xch_7d', 'xch_3dv', 'xch_4dv', 'xch_5dv', 'xch_6dv', 'xch_7dv', 'xc_3d', 'xc_4d', 'xc_5d', 'xc_6d', 'xc_3dv', 'xc_4dv', 'xc_5dv', 'xc_6dv', 'xpc_4d', 'xpc_5d', 'xpc_6d', 'xpc_4dv', 'xpc_5dv', 'xpc_6dv', 'xp_0d', 'xp_1d', 'xp_2d', 'xp_3d', 'xp_4d', 'xp_5d', 'xp_6d', 'xp_7d', 'axp_0d', 'axp_1d', 'axp_2d', 'axp_3d', 'axp_4d', 'axp_5d', 'axp_6d', 'axp_7d', 'xp_0dv', 'xp_1dv', 'xp_2dv', 'xp_3dv', 'xp_4dv', 'xp_5dv', 'xp_6dv', 'xp_7dv', 'axp_0dv', 'axp_1dv', 'axp_2dv', 'axp_3dv', 'axp_4dv', 'axp_5dv', 'axp_6dv', 'axp_7dv', 'c1sp1', 'c2sp1', 'c1sp2', 'c2sp2', 'c3sp2', 'c1sp3', 'c2sp3', 'c3sp3', 'c4sp3', 'hybratio', 'fcsp3', 'num_stereocenters', 'num_unspecified_stereocenters', 'num_defined_stereocenters', 'num_r_centers', 'num_s_centers', 'num_stereobonds', 'num_e_bonds', 'num_z_bonds', 'stereo_complexity', 'frac_defined_stereo'],
69
+ "id_column": "udm_mol_bat_id",
70
+ "compressed_features": ['fingerprint'],
71
+ "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/caco2-er-reg-temporal/training",
72
+ "hyperparameters": {'n_folds': 1},
73
73
  }
74
74
 
75
75
 
@@ -148,12 +148,16 @@ def convert_categorical_types(
148
148
  def decompress_features(
149
149
  df: pd.DataFrame, features: list[str], compressed_features: list[str]
150
150
  ) -> tuple[pd.DataFrame, list[str]]:
151
- """Decompress bitstring features into individual bit columns.
151
+ """Decompress compressed features (bitstrings or count vectors) into individual columns.
152
+
153
+ Supports two formats (auto-detected):
154
+ - Bitstrings: "10110010..." → individual uint8 columns (0 or 1)
155
+ - Count vectors: "0,3,0,1,5,..." → individual uint8 columns (0-255)
152
156
 
153
157
  Args:
154
158
  df: The features DataFrame
155
159
  features: Full list of feature names
156
- compressed_features: List of feature names to decompress (bitstrings)
160
+ compressed_features: List of feature names to decompress
157
161
 
158
162
  Returns:
159
163
  Tuple of (DataFrame with decompressed features, updated feature list)
@@ -178,18 +182,18 @@ def decompress_features(
178
182
  # Remove the feature from the list to avoid duplication
179
183
  decompressed_features.remove(feature)
180
184
 
181
- # Handle all compressed features as bitstrings
182
- bit_matrix = np.array([list(bitstring) for bitstring in df[feature]], dtype=np.uint8)
183
- prefix = feature[:3]
185
+ # Auto-detect format and parse: comma-separated counts or bitstring
186
+ sample = str(df[feature].dropna().iloc[0]) if not df[feature].dropna().empty else ""
187
+ parse_fn = (lambda s: list(map(int, s.split(",")))) if "," in sample else list
188
+ feature_matrix = np.array([parse_fn(s) for s in df[feature]], dtype=np.uint8)
184
189
 
185
- # Create all new columns at once - avoids fragmentation
186
- new_col_names = [f"{prefix}_{i}" for i in range(bit_matrix.shape[1])]
187
- new_df = pd.DataFrame(bit_matrix, columns=new_col_names, index=df.index)
190
+ # Create new columns with prefix from feature name
191
+ prefix = feature[:3]
192
+ new_col_names = [f"{prefix}_{i}" for i in range(feature_matrix.shape[1])]
193
+ new_df = pd.DataFrame(feature_matrix, columns=new_col_names, index=df.index)
188
194
 
189
- # Add to features list
195
+ # Update features list and dataframe
190
196
  decompressed_features.extend(new_col_names)
191
-
192
- # Drop original column and concatenate new ones
193
197
  df = df.drop(columns=[feature])
194
198
  df = pd.concat([df, new_df], axis=1)
195
199
 
@@ -0,0 +1,35 @@
1
+ """MetaModelSimulator: Simulate and analyze ensemble model performance.
2
+
3
+ This class helps evaluate whether a meta model (ensemble) would outperform
4
+ individual child models by analyzing endpoint inference predictions.
5
+ """
6
+
7
+ import argparse
8
+ from workbench.utils.meta_model_simulator import MetaModelSimulator
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(
13
+ description="Simulate and analyze ensemble model performance using MetaModelSimulator."
14
+ )
15
+ parser.add_argument(
16
+ "models",
17
+ nargs="+",
18
+ help="List of model endpoint names to include in the ensemble simulation.",
19
+ )
20
+ parser.add_argument(
21
+ "--id-column",
22
+ default="molecule_name",
23
+ help="Name of the ID column (default: molecule_name)",
24
+ )
25
+ args = parser.parse_args()
26
+ models = args.models
27
+ id_column = args.id_column
28
+
29
+ # Create MetaModelSimulator instance and generate report
30
+ sim = MetaModelSimulator(models, id_column=id_column)
31
+ sim.report()
32
+
33
+
34
+ if __name__ == "__main__":
35
+ main()
@@ -1,6 +1,8 @@
1
1
  import argparse
2
+ import ast
2
3
  import logging
3
4
  import json
5
+ import re
4
6
  from pathlib import Path
5
7
 
6
8
  # Workbench Imports
@@ -13,6 +15,56 @@ cm = ConfigManager()
13
15
  workbench_bucket = cm.get_config("WORKBENCH_BUCKET")
14
16
 
15
17
 
18
+ def parse_workbench_batch(script_content: str) -> dict | None:
19
+ """Parse WORKBENCH_BATCH config from a script.
20
+
21
+ Looks for a dictionary assignment like:
22
+ WORKBENCH_BATCH = {
23
+ "outputs": ["feature_set_xyz"],
24
+ }
25
+ or:
26
+ WORKBENCH_BATCH = {
27
+ "inputs": ["feature_set_xyz"],
28
+ }
29
+
30
+ Args:
31
+ script_content: The Python script content as a string
32
+
33
+ Returns:
34
+ The parsed dictionary or None if not found
35
+ """
36
+ pattern = r"WORKBENCH_BATCH\s*=\s*(\{[^}]+\})"
37
+ match = re.search(pattern, script_content, re.DOTALL)
38
+ if match:
39
+ try:
40
+ return ast.literal_eval(match.group(1))
41
+ except (ValueError, SyntaxError) as e:
42
+ print(f"⚠️ Warning: Failed to parse WORKBENCH_BATCH: {e}")
43
+ return None
44
+ return None
45
+
46
+
47
+ def get_message_group_id(batch_config: dict | None) -> str:
48
+ """Derive MessageGroupId from outputs or inputs.
49
+
50
+ - Scripts with outputs use first output as group
51
+ - Scripts with inputs use first input as group
52
+ - Default to "ml-pipeline-jobs" if no config
53
+ """
54
+ if not batch_config:
55
+ return "ml-pipeline-jobs"
56
+
57
+ outputs = batch_config.get("outputs", [])
58
+ inputs = batch_config.get("inputs", [])
59
+
60
+ if outputs:
61
+ return outputs[0]
62
+ elif inputs:
63
+ return inputs[0]
64
+ else:
65
+ return "ml-pipeline-jobs"
66
+
67
+
16
68
  def submit_to_sqs(
17
69
  script_path: str,
18
70
  size: str = "small",
@@ -44,12 +96,24 @@ def submit_to_sqs(
44
96
  if not script_file.exists():
45
97
  raise FileNotFoundError(f"Script not found: {script_path}")
46
98
 
99
+ # Read script content and parse WORKBENCH_BATCH config
100
+ script_content = script_file.read_text()
101
+ batch_config = parse_workbench_batch(script_content)
102
+ group_id = get_message_group_id(batch_config)
103
+ outputs = (batch_config or {}).get("outputs", [])
104
+ inputs = (batch_config or {}).get("inputs", [])
105
+
47
106
  print(f"📄 Script: {script_file.name}")
48
107
  print(f"📏 Size tier: {size}")
49
108
  print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (serverless={'False' if realtime else 'True'})")
50
109
  print(f"🔄 DynamicTraining: {dt}")
51
110
  print(f"🆕 Promote: {promote}")
52
111
  print(f"🪣 Bucket: {workbench_bucket}")
112
+ if outputs:
113
+ print(f"📤 Outputs: {outputs}")
114
+ if inputs:
115
+ print(f"📥 Inputs: {inputs}")
116
+ print(f"📦 Batch Group: {group_id}")
53
117
  sqs = AWSAccountClamp().boto3_session.client("sqs")
54
118
  script_name = script_file.name
55
119
 
@@ -75,7 +139,7 @@ def submit_to_sqs(
75
139
  print(f" Destination: {s3_path}")
76
140
 
77
141
  try:
78
- upload_content_to_s3(script_file.read_text(), s3_path)
142
+ upload_content_to_s3(script_content, s3_path)
79
143
  print("✅ Script uploaded successfully")
80
144
  except Exception as e:
81
145
  print(f"❌ Upload failed: {e}")
@@ -118,7 +182,7 @@ def submit_to_sqs(
118
182
  response = sqs.send_message(
119
183
  QueueUrl=queue_url,
120
184
  MessageBody=json.dumps(message, indent=2),
121
- MessageGroupId="ml-pipeline-jobs", # Required for FIFO
185
+ MessageGroupId=group_id, # From WORKBENCH_BATCH or default
122
186
  )
123
187
  message_id = response["MessageId"]
124
188
  print("✅ Message sent successfully!")
@@ -136,6 +200,11 @@ def submit_to_sqs(
136
200
  print(f"⚡ Mode: {'Real-time' if realtime else 'Serverless'} (SERVERLESS={'False' if realtime else 'True'})")
137
201
  print(f"🔄 DynamicTraining: {dt}")
138
202
  print(f"🆕 Promote: {promote}")
203
+ if outputs:
204
+ print(f"📤 Outputs: {outputs}")
205
+ if inputs:
206
+ print(f"📥 Inputs: {inputs}")
207
+ print(f"📦 Batch Group: {group_id}")
139
208
  print(f"🆔 Message ID: {message_id}")
140
209
  print("\n🔍 MONITORING LOCATIONS:")
141
210
  print(f" • SQS Queue: AWS Console → SQS → {queue_name}")
@@ -30,9 +30,10 @@ ul, ol {
30
30
  --ag-header-background-color: rgba(150, 150, 195);
31
31
  }
32
32
 
33
- /* Adjust cell background */
33
+ /* Adjust cell background and text color */
34
34
  .ag-cell {
35
35
  background-color: rgb(240, 240, 240);
36
+ color: rgb(80, 80, 80);
36
37
  }
37
38
 
38
39
  /* Alternate row colors */
@@ -40,6 +41,11 @@ ul, ol {
40
41
  background-color: rgb(230, 230, 230);
41
42
  }
42
43
 
44
+ /* AgGrid header text color */
45
+ .ag-header-cell-text {
46
+ color: rgb(60, 60, 60);
47
+ }
48
+
43
49
  /* Selection color for the entire row */
44
50
  .ag-row.ag-row-selected .ag-cell {
45
51
  background-color: rgba(170, 170, 205, 1.0);
@@ -133,6 +133,40 @@ a:hover {
133
133
  color: rgb(100, 255, 100);
134
134
  }
135
135
 
136
+ /* Dropdown styling (dcc.Dropdown) - override Bootstrap's --bs-body-bg variable */
137
+ .dash-dropdown {
138
+ --bs-body-bg: rgb(55, 60, 90);
139
+ --bs-border-color: rgb(80, 85, 115);
140
+ }
141
+
142
+
143
+ /* Bootstrap form controls (dbc components) */
144
+ .form-select, .form-control {
145
+ background-color: rgb(55, 60, 90) !important;
146
+ border: 1px solid rgb(80, 85, 115) !important;
147
+ color: rgb(210, 210, 210) !important;
148
+ }
149
+
150
+ .form-select:focus, .form-control:focus {
151
+ background-color: rgb(60, 65, 95) !important;
152
+ border-color: rgb(100, 105, 140) !important;
153
+ box-shadow: 0 0 0 0.2rem rgba(100, 105, 140, 0.25) !important;
154
+ }
155
+
156
+ .dropdown-menu {
157
+ background-color: rgb(55, 60, 90) !important;
158
+ border: 1px solid rgb(80, 85, 115) !important;
159
+ }
160
+
161
+ .dropdown-item {
162
+ color: rgb(210, 210, 210) !important;
163
+ }
164
+
165
+ .dropdown-item:hover, .dropdown-item:focus {
166
+ background-color: rgb(70, 75, 110) !important;
167
+ color: rgb(230, 230, 230) !important;
168
+ }
169
+
136
170
  /* Table styling */
137
171
  table {
138
172
  width: 100%;
@@ -1,11 +1,19 @@
1
- """Molecular fingerprint computation utilities"""
1
+ """Molecular fingerprint computation utilities for ADMET modeling.
2
+
3
+ This module provides Morgan count fingerprints, the standard for ADMET prediction.
4
+ Count fingerprints outperform binary fingerprints for molecular property prediction.
5
+
6
+ References:
7
+ - Count vs Binary: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
8
+ - ECFP/Morgan: https://pubs.acs.org/doi/10.1021/ci100050t
9
+ """
2
10
 
3
11
  import logging
4
- import pandas as pd
5
12
 
6
- # Molecular Descriptor Imports
13
+ import numpy as np
14
+ import pandas as pd
7
15
  from rdkit import Chem, RDLogger
8
- from rdkit.Chem import rdFingerprintGenerator
16
+ from rdkit.Chem import AllChem
9
17
  from rdkit.Chem.MolStandardize import rdMolStandardize
10
18
 
11
19
  # Suppress RDKit warnings (e.g., "not removing hydrogen atom without neighbors")
@@ -16,20 +24,25 @@ RDLogger.DisableLog("rdApp.warning")
16
24
  log = logging.getLogger("workbench")
17
25
 
18
26
 
19
- def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=True) -> pd.DataFrame:
20
- """Compute and add Morgan fingerprints to the DataFrame.
27
+ def compute_morgan_fingerprints(df: pd.DataFrame, radius: int = 2, n_bits: int = 2048) -> pd.DataFrame:
28
+ """Compute Morgan count fingerprints for ADMET modeling.
29
+
30
+ Generates true count fingerprints where each bit position contains the
31
+ number of times that substructure appears in the molecule (clamped to 0-255).
32
+ This is the recommended approach for ADMET prediction per 2025 research.
21
33
 
22
34
  Args:
23
- df (pd.DataFrame): Input DataFrame containing SMILES strings.
24
- radius (int): Radius for the Morgan fingerprint.
25
- n_bits (int): Number of bits for the fingerprint.
26
- counts (bool): Count simulation for the fingerprint.
35
+ df: Input DataFrame containing SMILES strings.
36
+ radius: Radius for the Morgan fingerprint (default 2 = ECFP4 equivalent).
37
+ n_bits: Number of bits for the fingerprint (default 2048).
27
38
 
28
39
  Returns:
29
- pd.DataFrame: The input DataFrame with the Morgan fingerprints added as bit strings.
40
+ pd.DataFrame: Input DataFrame with 'fingerprint' column added.
41
+ Values are comma-separated uint8 counts.
30
42
 
31
43
  Note:
32
- See: https://greglandrum.github.io/rdkit-blog/posts/2021-07-06-simulating-counts.html
44
+ Count fingerprints outperform binary for ADMET prediction.
45
+ See: https://pubs.acs.org/doi/10.1021/acs.est.3c02198
33
46
  """
34
47
  delete_mol_column = False
35
48
 
@@ -43,7 +56,7 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
43
56
  log.warning("Detected serialized molecules in 'molecule' column. Removing...")
44
57
  del df["molecule"]
45
58
 
46
- # Convert SMILES to RDKit molecule objects (vectorized)
59
+ # Convert SMILES to RDKit molecule objects
47
60
  if "molecule" not in df.columns:
48
61
  log.info("Converting SMILES to RDKit Molecules...")
49
62
  delete_mol_column = True
@@ -59,15 +72,24 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
59
72
  lambda mol: rdMolStandardize.LargestFragmentChooser().choose(mol) if mol else None
60
73
  )
61
74
 
62
- # Create a Morgan fingerprint generator
63
- if counts:
64
- n_bits *= 4 # Multiply by 4 to simulate counts
65
- morgan_generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=n_bits, countSimulation=counts)
75
+ def mol_to_count_string(mol):
76
+ """Convert molecule to comma-separated count fingerprint string."""
77
+ if mol is None:
78
+ return pd.NA
66
79
 
67
- # Compute Morgan fingerprints (vectorized)
68
- fingerprints = largest_frags.apply(
69
- lambda mol: (morgan_generator.GetFingerprint(mol).ToBitString() if mol else pd.NA)
70
- )
80
+ # Get hashed Morgan fingerprint with counts
81
+ fp = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=n_bits)
82
+
83
+ # Initialize array and populate with counts (clamped to uint8 range)
84
+ counts = np.zeros(n_bits, dtype=np.uint8)
85
+ for idx, count in fp.GetNonzeroElements().items():
86
+ counts[idx] = min(count, 255)
87
+
88
+ # Return as comma-separated string
89
+ return ",".join(map(str, counts))
90
+
91
+ # Compute Morgan count fingerprints
92
+ fingerprints = largest_frags.apply(mol_to_count_string)
71
93
 
72
94
  # Add the fingerprints to the DataFrame
73
95
  df["fingerprint"] = fingerprints
@@ -75,59 +97,62 @@ def compute_morgan_fingerprints(df: pd.DataFrame, radius=2, n_bits=2048, counts=
75
97
  # Drop the intermediate 'molecule' column if it was added
76
98
  if delete_mol_column:
77
99
  del df["molecule"]
100
+
78
101
  return df
79
102
 
80
103
 
81
104
  if __name__ == "__main__":
82
- print("Running molecular fingerprint tests...")
83
- print("Note: This requires molecular_screening module to be available")
105
+ print("Running Morgan count fingerprint tests...")
84
106
 
85
107
  # Test molecules
86
108
  test_molecules = {
87
109
  "aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O",
88
110
  "caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
89
111
  "glucose": "C([C@@H]1[C@H]([C@@H]([C@H](C(O1)O)O)O)O)O", # With stereochemistry
90
- "sodium_acetate": "CC(=O)[O-].[Na+]", # Salt
112
+ "sodium_acetate": "CC(=O)[O-].[Na+]", # Salt (largest fragment used)
91
113
  "benzene": "c1ccccc1",
92
114
  "butene_e": "C/C=C/C", # E-butene
93
115
  "butene_z": "C/C=C\\C", # Z-butene
94
116
  }
95
117
 
96
- # Test 1: Morgan Fingerprints
97
- print("\n1. Testing Morgan fingerprint generation...")
118
+ # Test 1: Morgan Count Fingerprints (default parameters)
119
+ print("\n1. Testing Morgan fingerprint generation (radius=2, n_bits=2048)...")
98
120
 
99
121
  test_df = pd.DataFrame({"SMILES": list(test_molecules.values()), "name": list(test_molecules.keys())})
100
-
101
- fp_df = compute_morgan_fingerprints(test_df.copy(), radius=2, n_bits=512, counts=False)
122
+ fp_df = compute_morgan_fingerprints(test_df.copy())
102
123
 
103
124
  print(" Fingerprint generation results:")
104
125
  for _, row in fp_df.iterrows():
105
126
  fp = row.get("fingerprint", "N/A")
106
- fp_len = len(fp) if fp != "N/A" else 0
107
- print(f" {row['name']:15} {fp_len} bits")
127
+ if pd.notna(fp):
128
+ counts = [int(x) for x in fp.split(",")]
129
+ non_zero = sum(1 for c in counts if c > 0)
130
+ max_count = max(counts)
131
+ print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero, max={max_count}")
132
+ else:
133
+ print(f" {row['name']:15} → N/A")
108
134
 
109
- # Test 2: Different fingerprint parameters
110
- print("\n2. Testing different fingerprint parameters...")
135
+ # Test 2: Different parameters
136
+ print("\n2. Testing with different parameters (radius=3, n_bits=1024)...")
111
137
 
112
- # Test with counts enabled
113
- fp_counts_df = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=256, counts=True)
138
+ fp_df_custom = compute_morgan_fingerprints(test_df.copy(), radius=3, n_bits=1024)
114
139
 
115
- print(" With count simulation (256 bits * 4):")
116
- for _, row in fp_counts_df.iterrows():
140
+ for _, row in fp_df_custom.iterrows():
117
141
  fp = row.get("fingerprint", "N/A")
118
- fp_len = len(fp) if fp != "N/A" else 0
119
- print(f" {row['name']:15} {fp_len} bits")
142
+ if pd.notna(fp):
143
+ counts = [int(x) for x in fp.split(",")]
144
+ non_zero = sum(1 for c in counts if c > 0)
145
+ print(f" {row['name']:15} → {len(counts)} features, {non_zero} non-zero")
146
+ else:
147
+ print(f" {row['name']:15} → N/A")
120
148
 
121
149
  # Test 3: Edge cases
122
150
  print("\n3. Testing edge cases...")
123
151
 
124
152
  # Invalid SMILES
125
153
  invalid_df = pd.DataFrame({"SMILES": ["INVALID", ""]})
126
- try:
127
- fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
128
- print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} valid molecules")
129
- except Exception as e:
130
- print(f" ✓ Invalid SMILES properly raised error: {type(e).__name__}")
154
+ fp_invalid = compute_morgan_fingerprints(invalid_df.copy())
155
+ print(f" ✓ Invalid SMILES handled: {len(fp_invalid)} rows returned")
131
156
 
132
157
  # Test with pre-existing molecule column
133
158
  mol_df = test_df.copy()
@@ -135,4 +160,16 @@ if __name__ == "__main__":
135
160
  fp_with_mol = compute_morgan_fingerprints(mol_df)
136
161
  print(f" ✓ Pre-existing molecule column handled: {len(fp_with_mol)} fingerprints generated")
137
162
 
163
+ # Test 4: Verify count values are reasonable
164
+ print("\n4. Verifying count distribution...")
165
+ all_counts = []
166
+ for _, row in fp_df.iterrows():
167
+ fp = row.get("fingerprint", "N/A")
168
+ if pd.notna(fp):
169
+ counts = [int(x) for x in fp.split(",")]
170
+ all_counts.extend([c for c in counts if c > 0])
171
+
172
+ if all_counts:
173
+ print(f" Non-zero counts: min={min(all_counts)}, max={max(all_counts)}, mean={np.mean(all_counts):.2f}")
174
+
138
175
  print("\n✅ All fingerprint tests completed!")
@@ -17,18 +17,28 @@ log = logging.getLogger("workbench")
17
17
 
18
18
  def fingerprints_to_matrix(fingerprints, dtype=np.uint8):
19
19
  """
20
- Convert bitstring fingerprints to numpy matrix.
20
+ Convert fingerprints to numpy matrix.
21
+
22
+ Supports two formats (auto-detected):
23
+ - Bitstrings: "10110010..." → matrix of 0s and 1s
24
+ - Count vectors: "0,3,0,1,5,..." → matrix of counts (or binary if dtype=np.bool_)
21
25
 
22
26
  Args:
23
- fingerprints: pandas Series or list of bitstring fingerprints
24
- dtype: numpy data type (uint8 is default: np.bool_ is good for Jaccard computations
27
+ fingerprints: pandas Series or list of fingerprints
28
+ dtype: numpy data type (uint8 is default; np.bool_ for Jaccard computations)
25
29
 
26
30
  Returns:
27
31
  dense numpy array of shape (n_molecules, n_bits)
28
32
  """
29
-
30
- # Dense matrix representation (we might support sparse in the future)
31
- return np.array([list(fp) for fp in fingerprints], dtype=dtype)
33
+ # Auto-detect format based on first fingerprint
34
+ sample = str(fingerprints.iloc[0] if hasattr(fingerprints, "iloc") else fingerprints[0])
35
+ if "," in sample:
36
+ # Count vector format: comma-separated integers
37
+ matrix = np.array([list(map(int, fp.split(","))) for fp in fingerprints], dtype=dtype)
38
+ else:
39
+ # Bitstring format: each character is a bit
40
+ matrix = np.array([list(fp) for fp in fingerprints], dtype=dtype)
41
+ return matrix
32
42
 
33
43
 
34
44
  def project_fingerprints(df: pd.DataFrame, projection: str = "UMAP") -> pd.DataFrame: