workbench-0.8.234-py3-none-any.whl → workbench-0.8.239-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. workbench/algorithms/dataframe/smart_aggregator.py +17 -12
  2. workbench/api/endpoint.py +13 -4
  3. workbench/api/model.py +2 -2
  4. workbench/cached/cached_model.py +2 -2
  5. workbench/core/artifacts/athena_source.py +5 -3
  6. workbench/core/artifacts/endpoint_core.py +30 -5
  7. workbench/core/cloud_platform/aws/aws_meta.py +2 -1
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +27 -14
  9. workbench/model_script_utils/model_script_utils.py +225 -0
  10. workbench/model_script_utils/uq_harness.py +39 -21
  11. workbench/model_scripts/chemprop/chemprop.template +30 -15
  12. workbench/model_scripts/chemprop/generated_model_script.py +35 -18
  13. workbench/model_scripts/chemprop/model_script_utils.py +225 -0
  14. workbench/model_scripts/pytorch_model/generated_model_script.py +29 -15
  15. workbench/model_scripts/pytorch_model/model_script_utils.py +225 -0
  16. workbench/model_scripts/pytorch_model/pytorch.template +28 -14
  17. workbench/model_scripts/pytorch_model/uq_harness.py +39 -21
  18. workbench/model_scripts/xgb_model/generated_model_script.py +35 -22
  19. workbench/model_scripts/xgb_model/model_script_utils.py +225 -0
  20. workbench/model_scripts/xgb_model/uq_harness.py +39 -21
  21. workbench/model_scripts/xgb_model/xgb_model.template +29 -18
  22. workbench/scripts/ml_pipeline_batch.py +47 -2
  23. workbench/scripts/ml_pipeline_launcher.py +410 -0
  24. workbench/scripts/ml_pipeline_sqs.py +22 -2
  25. workbench/themes/dark/custom.css +29 -0
  26. workbench/themes/light/custom.css +29 -0
  27. workbench/themes/midnight_blue/custom.css +28 -0
  28. workbench/utils/model_utils.py +9 -0
  29. workbench/utils/theme_manager.py +95 -0
  30. workbench/web_interface/components/component_interface.py +3 -0
  31. workbench/web_interface/components/plugin_interface.py +26 -0
  32. workbench/web_interface/components/plugins/ag_table.py +4 -11
  33. workbench/web_interface/components/plugins/confusion_matrix.py +14 -8
  34. workbench/web_interface/components/plugins/model_plot.py +156 -0
  35. workbench/web_interface/components/plugins/scatter_plot.py +9 -2
  36. workbench/web_interface/components/plugins/shap_summary_plot.py +12 -4
  37. workbench/web_interface/components/settings_menu.py +10 -49
  38. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/METADATA +2 -2
  39. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/RECORD +43 -42
  40. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/WHEEL +1 -1
  41. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/entry_points.txt +1 -0
  42. workbench/web_interface/components/model_plot.py +0 -75
  43. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/licenses/LICENSE +0 -0
  44. {workbench-0.8.234.dist-info → workbench-0.8.239.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/smart_aggregator.py CHANGED
@@ -55,25 +55,30 @@ def smart_aggregator(df: pd.DataFrame, target_rows: int = 1000, outlier_column:
  result["aggregation_count"] = 1
  return result.reset_index(drop=True)
 
- # Handle NaN values - fill with column median
- df_for_clustering = df[numeric_cols].fillna(df[numeric_cols].median())
-
- # Normalize and cluster
- X = StandardScaler().fit_transform(df_for_clustering)
+ # Handle NaN values - fill with column median (use numpy for speed)
+ clustering_data = df[numeric_cols].values
+ col_medians = np.nanmedian(clustering_data, axis=0)
+ nan_mask = np.isnan(clustering_data)
+ clustering_data = np.where(nan_mask, col_medians, clustering_data)
+
+ # Normalize and cluster (n_init=1 since MiniBatchKMeans is already approximate)
+ X = StandardScaler().fit_transform(clustering_data)
  df["_cluster"] = MiniBatchKMeans(
- n_clusters=min(target_rows, n_rows), random_state=42, batch_size=min(1024, n_rows), n_init=3
+ n_clusters=min(target_rows, n_rows), random_state=42, batch_size=min(1024, n_rows), n_init=1
  ).fit_predict(X)
 
  # Post-process: give high-outlier rows their own unique clusters so they don't get aggregated
  if outlier_column and outlier_column in df.columns:
- # Top 10% of outlier values get their own clusters, capped at 200
- n_to_isolate = min(int(n_rows * 0.1), 200)
- threshold = df[outlier_column].nlargest(n_to_isolate).min()
- high_outlier_mask = df[outlier_column] >= threshold
+ # Top 10% of outlier values get their own clusters, capped at 20% of target_rows
+ n_to_isolate = min(int(n_rows * 0.1), int(target_rows * 0.2))
+ outlier_values = df[outlier_column].values
+ threshold = np.partition(outlier_values, -n_to_isolate)[-n_to_isolate]
+ high_outlier_mask = outlier_values >= threshold
  n_high_outliers = high_outlier_mask.sum()
- # Assign unique cluster IDs starting after the max existing cluster
+ # Assign unique cluster IDs starting after the max existing cluster (match dtype to avoid warning)
  max_cluster = df["_cluster"].max()
- df.loc[high_outlier_mask, "_cluster"] = range(max_cluster + 1, max_cluster + 1 + n_high_outliers)
+ new_cluster_ids = np.arange(max_cluster + 1, max_cluster + 1 + n_high_outliers, dtype=df["_cluster"].dtype)
+ df.loc[high_outlier_mask, "_cluster"] = new_cluster_ids
  log.info(f"smart_aggregator: Isolated {n_high_outliers} high-outlier rows (>= {threshold:.3f})")
  elif outlier_column:
  log.warning(f"smart_aggregator: outlier_column '{outlier_column}' not found in columns")
workbench/api/endpoint.py CHANGED
@@ -29,7 +29,12 @@ class Endpoint(EndpointCore):
  return super().details(**kwargs)
 
  def inference(
- self, eval_df: pd.DataFrame, capture_name: str = None, id_column: str = None, drop_error_rows: bool = False
+ self,
+ eval_df: pd.DataFrame,
+ capture_name: str = None,
+ id_column: str = None,
+ drop_error_rows: bool = False,
+ include_quantiles: bool = False,
  ) -> pd.DataFrame:
  """Run inference on the Endpoint using the provided DataFrame
 
@@ -38,11 +43,12 @@ class Endpoint(EndpointCore):
  capture_name (str, optional): The Name of the capture to use (default: None)
  id_column (str, optional): The name of the column to use as the ID (default: None)
  drop_error_rows (bool): Whether to drop rows with errors (default: False)
+ include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
 
  Returns:
  pd.DataFrame: The DataFrame with predictions
  """
- return super().inference(eval_df, capture_name, id_column, drop_error_rows)
+ return super().inference(eval_df, capture_name, id_column, drop_error_rows, include_quantiles)
 
  def auto_inference(self) -> pd.DataFrame:
  """Run inference on the Endpoint using the test data from the model training view
@@ -75,13 +81,16 @@ class Endpoint(EndpointCore):
  """
  return super().fast_inference(eval_df, threads=threads)
 
- def cross_fold_inference(self) -> pd.DataFrame:
+ def cross_fold_inference(self, include_quantiles: bool = False) -> pd.DataFrame:
  """Pull cross-fold inference from model associated with this Endpoint
 
+ Args:
+ include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
+
  Returns:
  pd.DataFrame: A DataFrame with cross fold predictions
  """
- return super().cross_fold_inference()
+ return super().cross_fold_inference(include_quantiles)
 
 
  if __name__ == "__main__":
workbench/api/model.py CHANGED
@@ -44,7 +44,7 @@ class Model(ModelCore):
  serverless: bool = True,
  mem_size: int = 2048,
  max_concurrency: int = 5,
- instance: str = "ml.t2.medium",
+ instance: str = None,
  data_capture: bool = False,
  ) -> Endpoint:
  """Create an Endpoint from the Model.
@@ -55,7 +55,7 @@ class Model(ModelCore):
  serverless (bool): Set the endpoint to be serverless (default: True)
  mem_size (int): The memory size for the Endpoint in MB (default: 2048)
  max_concurrency (int): The maximum concurrency for the Endpoint (default: 5)
- instance (str): The instance type to use for Realtime(serverless=False) Endpoints (default: "ml.t2.medium")
+ instance (str): The instance type for Realtime Endpoints (default: None = auto-select based on model)
  data_capture (bool): Enable data capture for the Endpoint (default: False)
 
  Returns:
workbench/cached/cached_model.py CHANGED
@@ -86,13 +86,13 @@ class CachedModel(CachedArtifactMixin, ModelCore):
 
  @CachedArtifactMixin.cache_result
  def get_inference_predictions(
- self, capture_name: str = "full_cross_fold", target_rows: int = 1000
+ self, capture_name: str = "full_cross_fold", target_rows: int = 2000
  ) -> Union[pd.DataFrame, None]:
  """Retrieve the captured prediction results for this model
 
  Args:
  capture_name (str, optional): Specific capture_name (default: full_cross_fold)
- target_rows (int, optional): Target number of rows to return (default: 1000)
+ target_rows (int, optional): Target number of rows to return (default: 2000)
 
  Returns:
  pd.DataFrame: DataFrame of the Captured Predictions (might be None)
workbench/core/artifacts/athena_source.py CHANGED
@@ -258,7 +258,7 @@ class AthenaSource(DataSourceAbstract):
 
  # Wait for the query to complete
  wr.athena.wait_query(query_execution_id=query_execution_id, boto3_session=self.boto3_session)
- self.log.debug(f"Statement executed successfully: {query_execution_id}")
+ self.log.debug(f"Query executed successfully: {query_execution_id}")
  break # If successful, exit the retry loop
  except wr.exceptions.QueryFailed as e:
  if "AlreadyExistsException" in str(e):
@@ -271,11 +271,13 @@ class AthenaSource(DataSourceAbstract):
  time.sleep(retry_delay)
  else:
  if not silence_errors:
- self.log.critical(f"Failed to execute statement after {max_retries} attempts: {e}")
+ self.log.critical(f"Failed to execute query after {max_retries} attempts: {query}")
+ self.log.critical(f"Error: {e}")
  raise
  else:
  if not silence_errors:
- self.log.critical(f"Failed to execute statement: {e}")
+ self.log.critical(f"Failed to execute query: {query}")
+ self.log.critical(f"Error: {e}")
  raise
 
  def s3_storage_location(self) -> str:
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -370,7 +370,12 @@ class EndpointCore(Artifact):
  return self.inference(eval_df, "full_inference")
 
  def inference(
- self, eval_df: pd.DataFrame, capture_name: str = None, id_column: str = None, drop_error_rows: bool = False
+ self,
+ eval_df: pd.DataFrame,
+ capture_name: str = None,
+ id_column: str = None,
+ drop_error_rows: bool = False,
+ include_quantiles: bool = False,
  ) -> pd.DataFrame:
  """Run inference on the Endpoint using the provided DataFrame
 
@@ -379,6 +384,7 @@ class EndpointCore(Artifact):
  capture_name (str, optional): Name of the inference capture (default=None)
  id_column (str, optional): Name of the ID column (default=None)
  drop_error_rows (bool, optional): If True, drop rows that had endpoint errors/issues (default=False)
+ include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
 
  Returns:
  pd.DataFrame: DataFrame with the inference results
@@ -478,6 +484,7 @@ class EndpointCore(Artifact):
  description,
  features,
  id_column,
+ include_quantiles,
  )
 
  # Save primary target (or single target) with original capture_name
@@ -491,6 +498,7 @@ class EndpointCore(Artifact):
  capture_name.replace("_", " ").title(),
  features,
  id_column,
+ include_quantiles,
  )
 
  # Capture uncertainty metrics if prediction_std is available (UQ, ChemProp, etc.)
@@ -501,9 +509,12 @@ class EndpointCore(Artifact):
  # Return the prediction DataFrame
  return prediction_df
 
- def cross_fold_inference(self) -> pd.DataFrame:
+ def cross_fold_inference(self, include_quantiles: bool = False) -> pd.DataFrame:
  """Pull cross-fold inference training results for this Endpoint's model
 
+ Args:
+ include_quantiles (bool): Include q_* quantile columns in saved output (default: False)
+
  Returns:
  pd.DataFrame: A DataFrame with cross fold predictions
  """
@@ -594,6 +605,7 @@ class EndpointCore(Artifact):
  description,
  features=additional_columns,
  id_column=id_column,
+ include_quantiles=include_quantiles,
  )
 
  # Save primary target (or single target) as "full_cross_fold"
@@ -607,6 +619,7 @@ class EndpointCore(Artifact):
  "Full Cross Fold",
  features=additional_columns,
  id_column=id_column,
+ include_quantiles=include_quantiles,
  )
 
  return out_of_fold_df
@@ -824,6 +837,7 @@ class EndpointCore(Artifact):
  description: str,
  features: list,
  id_column: str = None,
+ include_quantiles: bool = False,
  ):
  """Internal: Capture the inference results and metrics to S3 for a single target
 
@@ -836,6 +850,7 @@ class EndpointCore(Artifact):
  description (str): Description of the inference results
  features (list): List of features to include in the inference results
  id_column (str, optional): Name of the ID column (default=None)
+ include_quantiles (bool): Include q_* quantile columns in output (default: False)
  """
 
  # Compute a dataframe hash (just use the last 8)
@@ -862,7 +877,7 @@ class EndpointCore(Artifact):
  wr.s3.to_csv(metrics, f"{inference_capture_path}/inference_metrics.csv", index=False)
 
  # Save the inference predictions for this target
- self._save_target_inference(inference_capture_path, pred_results_df, target, id_column)
+ self._save_target_inference(inference_capture_path, pred_results_df, target, id_column, include_quantiles)
 
  # CLASSIFIER: Write the confusion matrix to our S3 Model Inference Folder
  if model_type == ModelType.CLASSIFIER:
@@ -882,6 +897,7 @@ class EndpointCore(Artifact):
  pred_results_df: pd.DataFrame,
  target: str,
  id_column: str = None,
+ include_quantiles: bool = False,
  ):
  """Save inference results for a single target.
 
@@ -890,6 +906,7 @@ class EndpointCore(Artifact):
  pred_results_df (pd.DataFrame): DataFrame with prediction results
  target (str): Target column name
  id_column (str, optional): Name of the ID column
+ include_quantiles (bool): Include q_* quantile columns in output (default: False)
  """
  cols = pred_results_df.columns
 
@@ -902,8 +919,16 @@ class EndpointCore(Artifact):
 
  output_columns += [c for c in ["prediction", "prediction_std"] if c in cols]
 
- # Add UQ columns (q_*, confidence) and proba columns
- output_columns += [c for c in cols if c.startswith("q_") or c == "confidence" or c.endswith("_proba")]
+ # Add confidence column (always include if present)
+ if "confidence" in cols:
+ output_columns.append("confidence")
+
+ # Add quantile columns (q_*) only if requested
+ if include_quantiles:
+ output_columns += [c for c in cols if c.startswith("q_")]
+
+ # Add proba columns for classifiers
+ output_columns += [c for c in cols if c.endswith("_proba")]
 
  # Add smiles column if present
  if "smiles" in cols:
workbench/core/cloud_platform/aws/aws_meta.py CHANGED
@@ -245,7 +245,8 @@ class AWSMeta:
  "Model Group": model_group_name,
  "Health": health_tags,
  "Owner": aws_tags.get("workbench_owner", "-"),
- "Model Type": aws_tags.get("workbench_model_type", "-"),
+ "Type": aws_tags.get("workbench_model_type", "-"),
+ "Framework": aws_tags.get("workbench_model_framework", "-"),
  "Created": created,
  "Ver": model_details.get("ModelPackageVersion", "-"),
  "Input": aws_tags.get("workbench_input", "-"),
workbench/core/transforms/model_to_endpoint/model_to_endpoint.py CHANGED
@@ -26,13 +26,13 @@ class ModelToEndpoint(Transform):
  ```
  """
 
- def __init__(self, model_name: str, endpoint_name: str, serverless: bool = True, instance: str = "ml.t2.medium"):
+ def __init__(self, model_name: str, endpoint_name: str, serverless: bool = True, instance: str = None):
  """ModelToEndpoint Initialization
  Args:
  model_name(str): The Name of the input Model
  endpoint_name(str): The Name of the output Endpoint
  serverless(bool): Deploy the Endpoint in serverless mode (default: True)
- instance(str): The instance type to use for the Endpoint (default: "ml.t2.medium")
+ instance(str): The instance type for Realtime Endpoints (default: None = auto-select)
  """
  # Make sure the endpoint_name is a valid name
  Artifact.is_name_valid(endpoint_name, delimiter="-", lower_case=False)
@@ -42,7 +42,7 @@ class ModelToEndpoint(Transform):
 
  # Set up all my instance attributes
  self.serverless = serverless
- self.instance_type = "serverless" if serverless else instance
+ self.instance = instance
  self.input_type = TransformInput.MODEL
  self.output_type = TransformOutput.ENDPOINT
 
@@ -100,24 +100,37 @@ class ModelToEndpoint(Transform):
  # Get the metadata/tags to push into AWS
  aws_tags = self.get_aws_tags()
 
+ # Check the model framework for resource requirements
+ from workbench.api import ModelFramework
+
+ self.log.info(f"Model Framework: {workbench_model.model_framework}")
+ needs_more_resources = workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]
+
  # Is this a serverless deployment?
  serverless_config = None
+ instance_type = None
  if self.serverless:
  # For PyTorch or ChemProp we need at least 4GB of memory
- from workbench.api import ModelFramework
-
- self.log.info(f"Model Framework: {workbench_model.model_framework}")
- if workbench_model.model_framework in [ModelFramework.PYTORCH, ModelFramework.CHEMPROP]:
- if mem_size < 4096:
- self.log.important(
- f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)"
- )
- mem_size = 4096
+ if needs_more_resources and mem_size < 4096:
+ self.log.important(f"{workbench_model.model_framework} needs at least 4GB of memory (setting to 4GB)")
+ mem_size = 4096
  serverless_config = ServerlessInferenceConfig(
  memory_size_in_mb=mem_size,
  max_concurrency=max_concurrency,
  )
+ instance_type = "serverless"
  self.log.important(f"Serverless Config: Memory={mem_size}MB, MaxConcurrency={max_concurrency}")
+ else:
+ # For realtime endpoints, use explicit instance if provided, otherwise auto-select
+ if self.instance:
+ instance_type = self.instance
+ self.log.important(f"Realtime Endpoint: Using specified instance type: {instance_type}")
+ elif needs_more_resources:
+ instance_type = "ml.c7i.xlarge"
+ self.log.important(f"{workbench_model.model_framework} needs more resources (using {instance_type})")
+ else:
+ instance_type = "ml.t2.medium"
+ self.log.important(f"Realtime Endpoint: Instance Type={instance_type}")
 
  # Configure data capture if requested (and not serverless)
  data_capture_config = None
@@ -141,7 +154,7 @@ class ModelToEndpoint(Transform):
  try:
  model_package.deploy(
  initial_instance_count=1,
- instance_type=self.instance_type,
+ instance_type=instance_type,
  serverless_inference_config=serverless_config,
  endpoint_name=self.output_name,
  serializer=CSVSerializer(),
@@ -158,7 +171,7 @@ class ModelToEndpoint(Transform):
  # Retry the deploy
  model_package.deploy(
  initial_instance_count=1,
- instance_type=self.instance_type,
+ instance_type=instance_type,
  serverless_inference_config=serverless_config,
  endpoint_name=self.output_name,
  serializer=CSVSerializer(),
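Note: realtime endpoints now auto-select an instance type when instance is None. A simplified sketch of the decision logic above (not the Transform itself):

    def pick_instance_type(serverless: bool, instance: str | None, needs_more_resources: bool) -> str:
        if serverless:
            return "serverless"
        if instance:                   # explicit instance type wins
            return instance
        if needs_more_resources:       # PyTorch / ChemProp models
            return "ml.c7i.xlarge"
        return "ml.t2.medium"

    print(pick_instance_type(serverless=False, instance=None, needs_more_resources=True))           # ml.c7i.xlarge
    print(pick_instance_type(serverless=False, instance="ml.m5.large", needs_more_resources=True))  # ml.m5.large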
workbench/model_script_utils/model_script_utils.py CHANGED
@@ -16,6 +16,7 @@ from sklearn.metrics import (
  r2_score,
  root_mean_squared_error,
  )
+ from sklearn.model_selection import GroupKFold, GroupShuffleSplit
  from scipy.stats import spearmanr
 
 
@@ -367,3 +368,227 @@ def print_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, label_names:
  for j, col_name in enumerate(label_names):
  value = conf_mtx[i, j]
  print(f"ConfusionMatrix:{row_name}:{col_name} {value}")
+
+
+ # =============================================================================
+ # Dataset Splitting Utilities for Molecular Data
+ # =============================================================================
+ def get_scaffold(smiles: str) -> str:
+ """Extract Bemis-Murcko scaffold from a SMILES string.
+
+ Args:
+ smiles: SMILES string of the molecule
+
+ Returns:
+ SMILES string of the scaffold, or empty string if molecule is invalid
+ """
+ from rdkit import Chem
+ from rdkit.Chem.Scaffolds import MurckoScaffold
+
+ mol = Chem.MolFromSmiles(smiles)
+ if mol is None:
+ return ""
+ try:
+ scaffold = MurckoScaffold.GetScaffoldForMol(mol)
+ return Chem.MolToSmiles(scaffold)
+ except Exception:
+ return ""
+
+
+ def get_scaffold_groups(smiles_list: list[str]) -> np.ndarray:
+ """Assign each molecule to a scaffold group.
+
+ Args:
+ smiles_list: List of SMILES strings
+
+ Returns:
+ Array of group indices (same scaffold = same group)
+ """
+ scaffold_to_group = {}
+ groups = []
+
+ for smi in smiles_list:
+ scaffold = get_scaffold(smi)
+ if scaffold not in scaffold_to_group:
+ scaffold_to_group[scaffold] = len(scaffold_to_group)
+ groups.append(scaffold_to_group[scaffold])
+
+ n_scaffolds = len(scaffold_to_group)
+ print(f"Found {n_scaffolds} unique scaffolds from {len(smiles_list)} molecules")
+ return np.array(groups)
+
+
+ def get_butina_clusters(smiles_list: list[str], cutoff: float = 0.4) -> np.ndarray:
+ """Cluster molecules using Butina algorithm on Morgan fingerprints.
+
+ Uses RDKit's Butina clustering with Tanimoto distance on Morgan fingerprints.
+ This is Pat Walters' recommended approach for creating diverse train/test splits.
+
+ Args:
+ smiles_list: List of SMILES strings
+ cutoff: Tanimoto distance cutoff for clustering (default 0.4)
+ Lower values = more clusters = more similar molecules per cluster
+
+ Returns:
+ Array of cluster indices
+ """
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
+ from rdkit.ML.Cluster import Butina
+
+ # Create Morgan fingerprint generator
+ fp_gen = GetMorganGenerator(radius=2, fpSize=2048)
+
+ # Generate Morgan fingerprints
+ fps = []
+ valid_indices = []
+ for i, smi in enumerate(smiles_list):
+ mol = Chem.MolFromSmiles(smi)
+ if mol is not None:
+ fp = fp_gen.GetFingerprint(mol)
+ fps.append(fp)
+ valid_indices.append(i)
+
+ if len(fps) == 0:
+ raise ValueError("No valid molecules found for clustering")
+
+ # Compute distance matrix (upper triangle only for efficiency)
+ n = len(fps)
+ dists = []
+ for i in range(1, n):
+ sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
+ dists.extend([1 - s for s in sims])
+
+ # Butina clustering
+ clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)
+
+ # Map back to original indices
+ cluster_labels = np.zeros(len(smiles_list), dtype=int)
+ for cluster_idx, cluster in enumerate(clusters):
+ for mol_idx in cluster:
+ original_idx = valid_indices[mol_idx]
+ cluster_labels[original_idx] = cluster_idx
+
+ # Assign invalid molecules to their own clusters
+ next_cluster = len(clusters)
+ for i in range(len(smiles_list)):
+ if i not in valid_indices:
+ cluster_labels[i] = next_cluster
+ next_cluster += 1
+
+ n_clusters = len(set(cluster_labels))
+ print(f"Butina clustering: {n_clusters} clusters from {len(smiles_list)} molecules (cutoff={cutoff})")
+ return cluster_labels
+
+
+ def _find_smiles_column(columns: list[str]) -> str | None:
+ """Find SMILES column (case-insensitive match for 'smiles').
+
+ Args:
+ columns: List of column names
+
+ Returns:
+ The matching column name, or None if not found
+ """
+ return next((c for c in columns if c.lower() == "smiles"), None)
+
+
+ def get_split_indices(
+ df: pd.DataFrame,
+ n_splits: int = 5,
+ strategy: str = "random",
+ smiles_column: str | None = None,
+ target_column: str | None = None,
+ test_size: float = 0.2,
+ random_state: int = 42,
+ butina_cutoff: float = 0.4,
+ ) -> list[tuple[np.ndarray, np.ndarray]]:
+ """Get train/validation split indices using various strategies.
+
+ This is a unified interface for generating splits that can be used across
+ all model templates (XGBoost, PyTorch, ChemProp).
+
+ Args:
+ df: DataFrame containing the data
+ n_splits: Number of CV folds (1 = single train/val split)
+ strategy: Split strategy - one of:
+ - "random": Standard random split (default sklearn behavior)
+ - "scaffold": Bemis-Murcko scaffold-based grouping
+ - "butina": Morgan fingerprint clustering (recommended for ADMET)
+ smiles_column: Column containing SMILES. If None, auto-detects 'smiles' (case-insensitive)
+ target_column: Column containing target values (for stratification, optional)
+ test_size: Fraction for validation set when n_splits=1 (default 0.2)
+ random_state: Random seed for reproducibility
+ butina_cutoff: Tanimoto distance cutoff for Butina clustering (default 0.4)
+
+ Returns:
+ List of (train_indices, val_indices) tuples
+
+ Note:
+ If scaffold/butina strategy is requested but no SMILES column is found,
+ automatically falls back to random split with a warning message.
+
+ Example:
+ >>> folds = get_split_indices(df, n_splits=5, strategy="scaffold")
+ >>> for train_idx, val_idx in folds:
+ ... X_train, X_val = df.iloc[train_idx], df.iloc[val_idx]
+ """
+ from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
+
+ n_samples = len(df)
+
+ # Random split (original behavior)
+ if strategy == "random":
+ if n_splits == 1:
+ indices = np.arange(n_samples)
+ train_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
+ return [(train_idx, val_idx)]
+ else:
+ if target_column and df[target_column].dtype in ["object", "category", "bool"]:
+ kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
+ return list(kfold.split(df, df[target_column]))
+ else:
+ kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
+ return list(kfold.split(df))
+
+ # Scaffold or Butina split requires SMILES - auto-detect if not provided
+ if smiles_column is None:
+ smiles_column = _find_smiles_column(df.columns.tolist())
+
+ # Fall back to random split if no SMILES column available
+ if smiles_column is None or smiles_column not in df.columns:
+ print(f"No 'smiles' column found for strategy='{strategy}', falling back to random split")
+ return get_split_indices(
+ df,
+ n_splits=n_splits,
+ strategy="random",
+ target_column=target_column,
+ test_size=test_size,
+ random_state=random_state,
+ )
+
+ smiles_list = df[smiles_column].tolist()
+
+ # Get group assignments
+ if strategy == "scaffold":
+ groups = get_scaffold_groups(smiles_list)
+ elif strategy == "butina":
+ groups = get_butina_clusters(smiles_list, cutoff=butina_cutoff)
+ else:
+ raise ValueError(f"Unknown strategy: {strategy}. Use 'random', 'scaffold', or 'butina'")
+
+ # Generate splits using GroupKFold or GroupShuffleSplit
+ if n_splits == 1:
+ # Single split: use GroupShuffleSplit
+ splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
+ return list(splitter.split(df, groups=groups))
+ else:
+ # K-fold: use GroupKFold (ensures no group appears in both train and val)
+ # Note: GroupKFold doesn't shuffle, so we shuffle group order first
+ unique_groups = np.unique(groups)
+ rng = np.random.default_rng(random_state)
+ shuffled_group_map = {g: i for i, g in enumerate(rng.permutation(unique_groups))}
+ shuffled_groups = np.array([shuffled_group_map[g] for g in groups])
+
+ gkf = GroupKFold(n_splits=n_splits)
+ return list(gkf.split(df, groups=shuffled_groups))
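Note: a usage sketch for the new splitting utilities added above. The import path is assumed from the file layout in this release (workbench/model_script_utils/model_script_utils.py), and the scaffold and butina strategies require RDKit to be installed:

    import pandas as pd
    from workbench.model_script_utils.model_script_utils import get_split_indices  # path assumed

    df = pd.DataFrame({
        "smiles": ["CCO", "c1ccccc1", "CC(=O)O", "c1ccccc1O", "CCN", "c1ccncc1"],
        "solubility": [0.1, 1.2, 0.4, 1.0, 0.3, 0.8],
    })

    # Group-aware folds: molecules sharing a Bemis-Murcko scaffold never span train and val
    folds = get_split_indices(df, n_splits=2, strategy="scaffold")
    for train_idx, val_idx in folds:
        train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]
        print(len(train_df), len(val_df))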