workbench 0.8.224__py3-none-any.whl → 0.8.234__py3-none-any.whl

This diff compares two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (61)
  1. workbench/__init__.py +1 -0
  2. workbench/algorithms/dataframe/__init__.py +2 -0
  3. workbench/algorithms/dataframe/smart_aggregator.py +161 -0
  4. workbench/algorithms/sql/column_stats.py +0 -1
  5. workbench/algorithms/sql/correlations.py +0 -1
  6. workbench/algorithms/sql/descriptive_stats.py +0 -1
  7. workbench/api/meta.py +0 -1
  8. workbench/cached/cached_meta.py +0 -1
  9. workbench/cached/cached_model.py +37 -7
  10. workbench/core/artifacts/endpoint_core.py +12 -2
  11. workbench/core/artifacts/feature_set_core.py +66 -8
  12. workbench/core/cloud_platform/cloud_meta.py +0 -1
  13. workbench/model_script_utils/model_script_utils.py +30 -0
  14. workbench/model_script_utils/uq_harness.py +0 -1
  15. workbench/model_scripts/chemprop/chemprop.template +3 -0
  16. workbench/model_scripts/chemprop/generated_model_script.py +3 -3
  17. workbench/model_scripts/chemprop/model_script_utils.py +30 -0
  18. workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
  19. workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
  20. workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
  21. workbench/model_scripts/pytorch_model/generated_model_script.py +50 -32
  22. workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
  23. workbench/model_scripts/pytorch_model/pytorch.template +47 -29
  24. workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
  25. workbench/model_scripts/script_generation.py +0 -1
  26. workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
  27. workbench/model_scripts/xgb_model/uq_harness.py +0 -1
  28. workbench/themes/dark/custom.css +85 -8
  29. workbench/themes/dark/plotly.json +6 -6
  30. workbench/themes/light/custom.css +172 -70
  31. workbench/themes/light/plotly.json +9 -9
  32. workbench/themes/midnight_blue/custom.css +48 -29
  33. workbench/themes/midnight_blue/plotly.json +1 -1
  34. workbench/utils/aws_utils.py +0 -1
  35. workbench/utils/chem_utils/mol_descriptors.py +0 -1
  36. workbench/utils/chem_utils/vis.py +137 -27
  37. workbench/utils/clientside_callbacks.py +41 -0
  38. workbench/utils/markdown_utils.py +61 -0
  39. workbench/utils/pipeline_utils.py +0 -1
  40. workbench/utils/plot_utils.py +8 -110
  41. workbench/web_interface/components/experiments/outlier_plot.py +0 -1
  42. workbench/web_interface/components/model_plot.py +2 -0
  43. workbench/web_interface/components/plugin_unit_test.py +0 -1
  44. workbench/web_interface/components/plugins/ag_table.py +2 -4
  45. workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
  46. workbench/web_interface/components/plugins/model_details.py +28 -11
  47. workbench/web_interface/components/plugins/scatter_plot.py +56 -43
  48. workbench/web_interface/components/settings_menu.py +2 -1
  49. workbench/web_interface/page_views/main_page.py +0 -1
  50. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/METADATA +31 -29
  51. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/RECORD +55 -59
  52. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/WHEEL +1 -1
  53. workbench/themes/quartz/base_css.url +0 -1
  54. workbench/themes/quartz/custom.css +0 -117
  55. workbench/themes/quartz/plotly.json +0 -642
  56. workbench/themes/quartz_dark/base_css.url +0 -1
  57. workbench/themes/quartz_dark/custom.css +0 -131
  58. workbench/themes/quartz_dark/plotly.json +0 -642
  59. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/entry_points.txt +0 -0
  60. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/licenses/LICENSE +0 -0
  61. {workbench-0.8.224.dist-info → workbench-0.8.234.dist-info}/top_level.txt +0 -0
workbench/__init__.py CHANGED
@@ -29,6 +29,7 @@ Workbench Main Classes
  | json_to_data.set_output_tags(["abalone", "json", "whatever"])
  | json_to_data.transform()
  """
+
  import os
  from importlib.metadata import version
 
workbench/algorithms/dataframe/__init__.py CHANGED
@@ -9,10 +9,12 @@ from .proximity import Proximity
  from .feature_space_proximity import FeatureSpaceProximity
  from .fingerprint_proximity import FingerprintProximity
  from .projection_2d import Projection2D
+ from .smart_aggregator import smart_aggregator
 
  __all__ = [
      "Proximity",
      "FeatureSpaceProximity",
      "FingerprintProximity",
      "Projection2D",
+     "smart_aggregator",
  ]
workbench/algorithms/dataframe/smart_aggregator.py ADDED
@@ -0,0 +1,161 @@
+ """SmartSample: Intelligently reduce DataFrame rows by aggregating similar rows together."""
+
+ import pandas as pd
+ import numpy as np
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.cluster import MiniBatchKMeans
+ import logging
+
+ # Set up logging
+ log = logging.getLogger("workbench")
+
+
+ def smart_aggregator(df: pd.DataFrame, target_rows: int = 1000, outlier_column: str = "residual") -> pd.DataFrame:
+     """
+     Reduce DataFrame rows by aggregating similar rows based on numeric column similarity.
+
+     This is a performant (2-pass) algorithm:
+     1. Pass 1: Normalize numeric columns and cluster similar rows using MiniBatchKMeans
+     2. Pass 2: Aggregate each cluster (mean for numeric, first for non-numeric)
+
+     Args:
+         df: Input DataFrame.
+         target_rows: Target number of rows in output (default: 1000).
+         outlier_column: Column where high values should resist aggregation (default: "residual").
+             Rows with high values in this column will be kept separate while rows
+             with low values cluster together. Set to None to disable.
+
+     Returns:
+         Reduced DataFrame with 'aggregation_count' column showing how many rows were combined.
+     """
+     if df is None or df.empty:
+         return df
+
+     n_rows = len(df)
+
+     # Preserve original column order
+     original_columns = df.columns.tolist()
+
+     # If already at or below target, just add the count column and return
+     if n_rows <= target_rows:
+         result = df.copy()
+         result["aggregation_count"] = 1
+         return result
+
+     log.info(f"smart_aggregator: Reducing {n_rows} rows to ~{target_rows} rows")
+
+     # Identify columns by type
+     df = df.copy()
+     numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
+     non_numeric_cols = [c for c in df.columns if c not in numeric_cols]
+
+     if not numeric_cols:
+         log.warning("smart_aggregator: No numeric columns for clustering, falling back to random sample")
+         result = df.sample(n=target_rows)
+         result["aggregation_count"] = 1
+         return result.reset_index(drop=True)
+
+     # Handle NaN values - fill with column median
+     df_for_clustering = df[numeric_cols].fillna(df[numeric_cols].median())
+
+     # Normalize and cluster
+     X = StandardScaler().fit_transform(df_for_clustering)
+     df["_cluster"] = MiniBatchKMeans(
+         n_clusters=min(target_rows, n_rows), random_state=42, batch_size=min(1024, n_rows), n_init=3
+     ).fit_predict(X)
+
+     # Post-process: give high-outlier rows their own unique clusters so they don't get aggregated
+     if outlier_column and outlier_column in df.columns:
+         # Top 10% of outlier values get their own clusters, capped at 200
+         n_to_isolate = min(int(n_rows * 0.1), 200)
+         threshold = df[outlier_column].nlargest(n_to_isolate).min()
+         high_outlier_mask = df[outlier_column] >= threshold
+         n_high_outliers = high_outlier_mask.sum()
+         # Assign unique cluster IDs starting after the max existing cluster
+         max_cluster = df["_cluster"].max()
+         df.loc[high_outlier_mask, "_cluster"] = range(max_cluster + 1, max_cluster + 1 + n_high_outliers)
+         log.info(f"smart_aggregator: Isolated {n_high_outliers} high-outlier rows (>= {threshold:.3f})")
+     elif outlier_column:
+         log.warning(f"smart_aggregator: outlier_column '{outlier_column}' not found in columns")
+
+     # Aggregate each cluster (mean for numeric, first for non-numeric)
+     agg_dict = {col: "mean" for col in numeric_cols} | {col: "first" for col in non_numeric_cols}
+     grouped = df.groupby("_cluster")
+     result = grouped.agg(agg_dict).reset_index(drop=True)
+     result["aggregation_count"] = grouped.size().values
+
+     # Restore original column order, with aggregation_count at the end
+     result = result[original_columns + ["aggregation_count"]]
+
+     log.info(f"smart_aggregator: Reduced to {len(result)} rows")
+     return result
+
+
+ # Testing
+ if __name__ == "__main__":
+     pd.set_option("display.max_columns", None)
+     pd.set_option("display.width", 1000)
+
+     # Create test data with clusters
+     np.random.seed(42)
+     n_samples = 10000
+
+     # Create 3 distinct clusters
+     cluster_1 = np.random.randn(n_samples // 3, 3) + np.array([0, 0, 0])
+     cluster_2 = np.random.randn(n_samples // 3, 3) + np.array([5, 5, 5])
+     cluster_3 = np.random.randn(n_samples // 3, 3) + np.array([10, 0, 5])
+
+     features = np.vstack([cluster_1, cluster_2, cluster_3])
+
+     # Create target and prediction columns, then compute residuals
+     target = features[:, 0] + features[:, 1] * 0.5 + np.random.randn(len(features)) * 0.1
+     prediction = target + np.random.randn(len(features)) * 0.5 # Add noise for residuals
+     residuals = np.abs(target - prediction)
+
+     data = {
+         "id": [f"id_{i}" for i in range(len(features))],
+         "A": features[:, 0],
+         "B": features[:, 1],
+         "C": features[:, 2],
+         "category": np.random.choice(["cat1", "cat2", "cat3"], len(features)),
+         "target": target,
+         "prediction": prediction,
+         "residual": residuals,
+     }
+     df = pd.DataFrame(data)
+
+     print(f"Original DataFrame: {len(df)} rows")
+     print(df.head())
+     print()
+
+     # Test smart_aggregator with residuals preservation
+     result = smart_aggregator(df, target_rows=500)
+     print(f"smart_aggregator result: {len(result)} rows")
+     print(result.head(20))
+     print()
+     print("Aggregation count stats:")
+     print(result["aggregation_count"].describe())
+     print()
+     # Show that high-residual points have lower aggregation counts
+     print("Aggregation count by residual quartile:")
+     result["residual_quartile"] = pd.qcut(result["residual"], 4, labels=["Q1 (low)", "Q2", "Q3", "Q4 (high)"])
+     print(result.groupby("residual_quartile")["aggregation_count"].mean())
+
+     # Test with real Workbench data
+     print("\n" + "=" * 80)
+     print("Testing with Workbench data...")
+     print("=" * 80)
+
+     from workbench.api import Model
+
+     model = Model("abalone-regression")
+     df = model.get_inference_predictions()
+     if df is not None:
+         print(f"\nOriginal DataFrame: {len(df)} rows")
+         print(df.head())
+
+         result = smart_aggregator(df, target_rows=500)
+         print(f"\nsmart_aggregator result: {len(result)} rows")
+         print(result.head())
+         print("\nAggregation count stats:")
+         print(result["aggregation_count"].describe())
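For orientation, a minimal usage sketch of the new helper (not part of the diff; the synthetic column names below are illustrative):

    import numpy as np
    import pandas as pd
    from workbench.algorithms.dataframe import smart_aggregator

    # Synthetic 10,000-row frame with a residual-style column (names are illustrative)
    rng = np.random.default_rng(0)
    df = pd.DataFrame({
        "A": rng.normal(size=10_000),
        "B": rng.normal(size=10_000),
        "residual": np.abs(rng.normal(size=10_000)),
    })

    # Collapse to ~500 clusters; the highest-residual rows are isolated and survive un-aggregated
    reduced = smart_aggregator(df, target_rows=500, outlier_column="residual")
    print(len(reduced), reduced["aggregation_count"].sum())  # roughly 500-700 rows; counts sum to 10,000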
workbench/algorithms/sql/column_stats.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
  # Workbench Imports
  from workbench.core.artifacts.data_source_abstract import DataSourceAbstract
 
-
  # Workbench Logger
  log = logging.getLogger("workbench")
 
workbench/algorithms/sql/correlations.py CHANGED
@@ -7,7 +7,6 @@ from collections import defaultdict
  # Workbench Imports
  from workbench.core.artifacts.data_source_abstract import DataSourceAbstract
 
-
  # Workbench Logger
  log = logging.getLogger("workbench")
 
workbench/algorithms/sql/descriptive_stats.py CHANGED
@@ -7,7 +7,6 @@ from collections import defaultdict
  # Workbench Imports
  from workbench.core.artifacts.data_source_abstract import DataSourceAbstract
 
-
  # Workbench Logger
  log = logging.getLogger("workbench")
 
workbench/api/meta.py CHANGED
@@ -6,7 +6,6 @@ such as Data Sources, Feature Sets, Models, and Endpoints.
  from typing import Union
  import pandas as pd
 
-
  # Workbench Imports
  from workbench.core.cloud_platform.cloud_meta import CloudMeta
 
workbench/cached/cached_meta.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
  from functools import wraps
  from concurrent.futures import ThreadPoolExecutor
 
-
  # Workbench Imports
  from workbench.core.cloud_platform.cloud_meta import CloudMeta
  from workbench.utils.workbench_cache import WorkbenchCache
workbench/cached/cached_model.py CHANGED
@@ -4,8 +4,9 @@ from typing import Union
  import pandas as pd
 
  # Workbench Imports
- from workbench.core.artifacts.model_core import ModelCore
+ from workbench.core.artifacts.model_core import ModelCore, ModelType
  from workbench.core.artifacts.cached_artifact_mixin import CachedArtifactMixin
+ from workbench.algorithms.dataframe import smart_aggregator
 
 
  class CachedModel(CachedArtifactMixin, ModelCore):
@@ -84,20 +85,49 @@ class CachedModel(CachedArtifactMixin, ModelCore):
          return super().get_inference_metrics(capture_name=capture_name)
 
      @CachedArtifactMixin.cache_result
-     def get_inference_predictions(self, capture_name: str = "auto_inference") -> Union[pd.DataFrame, None]:
+     def get_inference_predictions(
+         self, capture_name: str = "full_cross_fold", target_rows: int = 1000
+     ) -> Union[pd.DataFrame, None]:
          """Retrieve the captured prediction results for this model
 
          Args:
-             capture_name (str, optional): Specific capture_name (default: training_holdout)
+             capture_name (str, optional): Specific capture_name (default: full_cross_fold)
+             target_rows (int, optional): Target number of rows to return (default: 1000)
 
          Returns:
              pd.DataFrame: DataFrame of the Captured Predictions (might be None)
          """
-         # Note: This method can generate larger dataframes, so we'll sample if needed
          df = super().get_inference_predictions(capture_name=capture_name)
-         if df is not None and len(df) > 5000:
-             self.log.warning(f"{self.name}:{capture_name} Sampling Inference Predictions to 5000 rows")
-             return df.sample(5000)
+         if df is None:
+             return None
+
+         # Compute residual based on model type
+         is_regressor = self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]
+         is_classifier = self.model_type == ModelType.CLASSIFIER
+
+         if is_regressor:
+             target = self.target()
+             if target and "prediction" in df.columns and target in df.columns:
+                 df["residual"] = abs(df["prediction"] - df[target])
+
+         elif is_classifier:
+             target = self.target()
+             class_labels = self.class_labels()
+             if target and "prediction" in df.columns and target in df.columns and class_labels:
+                 # Create a mapping from label to ordinal index
+                 label_to_idx = {label: idx for idx, label in enumerate(class_labels)}
+                 # Compute residual as distance between predicted and actual class
+                 df["residual"] = abs(
+                     df["prediction"].map(label_to_idx).fillna(-1) - df[target].map(label_to_idx).fillna(-1)
+                 )
+
+         # Use smart_aggregator to aggregate similar rows if we have too many
+         if len(df) > target_rows:
+             self.log.info(
+                 f"{self.name}:{capture_name} Using smart_aggregator to reduce {len(df)} rows to ~{target_rows}"
+             )
+             df = smart_aggregator(df, target_rows=target_rows)
+
          return df
 
      @CachedArtifactMixin.cache_result
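A hedged sketch of how the reworked method is called (the model name reuses the example from the smart_aggregator test above; the import path is an assumption based on the file layout in this diff):

    from workbench.cached.cached_model import CachedModel  # assumed path per the file list above

    model = CachedModel("abalone-regression")  # example model name

    # Defaults: capture_name="full_cross_fold", target_rows=1000
    preds = model.get_inference_predictions(target_rows=500)
    if preds is not None:
        # Regressors gain a "residual" column; rows are collapsed by smart_aggregator
        print(preds[["prediction", "residual", "aggregation_count"]].head())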
workbench/core/artifacts/endpoint_core.py CHANGED
@@ -546,7 +546,14 @@ class EndpointCore(Artifact):
          target_list = targets if isinstance(targets, list) else [targets]
          primary_target = target_list[0]
 
-         # Collect UQ columns (q_*, confidence) for additional tracking
+         # If we don't have a smiles column, try to merge it from the FeatureSet
+         if "smiles" not in out_of_fold_df.columns:
+             fs_df = fs.query(f'SELECT {fs.id_column}, "smiles" FROM "{fs.athena_table}"')
+             if "smiles" in fs_df.columns:
+                 self.log.info("Merging 'smiles' column from FeatureSet into out-of-fold predictions.")
+                 out_of_fold_df = out_of_fold_df.merge(fs_df, on=fs.id_column, how="left")
+
+         # Collect UQ columns (q_*, confidence) for additional tracking (used for hashing)
          additional_columns = [col for col in out_of_fold_df.columns if col.startswith("q_") or col == "confidence"]
          if additional_columns:
              self.log.info(f"UQ columns from training: {', '.join(additional_columns)}")
@@ -559,7 +566,6 @@
          # For single-target models (99% of cases), just save as "full_cross_fold"
          # For multi-target models, save each as cv_{target} plus primary as "full_cross_fold"
          is_multi_target = len(target_list) > 1
-
          for target in target_list:
              # Drop rows with NaN target values for metrics/plots
              target_df = out_of_fold_df.dropna(subset=[target])
@@ -899,6 +905,10 @@
          # Add UQ columns (q_*, confidence) and proba columns
          output_columns += [c for c in cols if c.startswith("q_") or c == "confidence" or c.endswith("_proba")]
 
+         # Add smiles column if present
+         if "smiles" in cols:
+             output_columns.append("smiles")
+
          # Write the predictions to S3
          output_file = f"{inference_capture_path}/inference_predictions.csv"
          self.log.info(f"Writing predictions to {output_file}")
workbench/core/artifacts/feature_set_core.py CHANGED
@@ -554,12 +554,29 @@ class FeatureSetCore(Artifact):
          aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
          source_columns = get_column_list(self.data_source, self.table)
          column_list = [col for col in source_columns if col not in aws_cols]
-         sql_columns = ", ".join([f'"{column}"' for column in column_list])
 
-         # Build inner query with both columns
+         # Build the training column CASE statement
          training_case = self._build_holdout_case(holdout_ids)
-         weight_case = self._build_weight_case(weight_dict, default_weight)
-         inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
+
+         # For large weight_dict, use supplemental table + JOIN
+         if len(weight_dict) >= 100:
+             self.log.info("Using supplemental table approach for large weight_dict")
+             weights_table = self._create_weights_table(weight_dict)
+
+             # Build column selection with table alias
+             sql_columns = ", ".join([f't."{col}"' for col in column_list])
+
+             # Build JOIN query with training CASE and weight from joined table
+             training_case_aliased = training_case.replace(f"WHEN {self.id_column} IN", f"WHEN t.{self.id_column} IN")
+             inner_sql = f"""SELECT {sql_columns}, {training_case_aliased},
+                 COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                 FROM {self.table} t
+                 LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+         else:
+             # For small weight_dict, use CASE statement
+             sql_columns = ", ".join([f'"{column}"' for column in column_list])
+             weight_case = self._build_weight_case(weight_dict, default_weight)
+             inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
 
          # Optionally filter out zero weights
          if exclude_zero_weights:
@@ -608,6 +625,10 @@
            }
            fs.set_sample_weights(weights) # zeros automatically excluded
            fs.set_sample_weights(weights, exclude_zero_weights=False) # keep zeros
+
+         Note:
+             For large weight_dict (100+ entries), weights are stored as a supplemental
+             table and joined to avoid Athena query size limits.
          """
          from workbench.core.views import TrainingView
 
@@ -618,14 +639,25 @@
 
          self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
 
-         # Build inner query with sample weights
-         weight_case = self._build_weight_case(weight_dict, default_weight)
-         inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
+         # For large weight_dict, use supplemental table + JOIN to avoid query size limits
+         if len(weight_dict) >= 100:
+             self.log.info("Using supplemental table approach for large weight_dict")
+             weights_table = self._create_weights_table(weight_dict)
+
+             # Build JOIN query with COALESCE for default weight
+             inner_sql = f"""SELECT t.*, COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                 FROM {self.table} t
+                 LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+         else:
+             # For small weight_dict, use CASE statement (simpler, no extra table)
+             weight_case = self._build_weight_case(weight_dict, default_weight)
+             inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
 
          # Optionally filter out zero weights
          if exclude_zero_weights:
              zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
-             self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
+             if zero_count:
+                 self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
              sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
          else:
              sql_query = inner_sql
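To illustrate the two branches above (a sketch only; the FeatureSet name is hypothetical and the public import is an assumption mirroring workbench.api.Model):

    from workbench.api import FeatureSet  # assumed public import

    fs = FeatureSet("abalone_features")  # hypothetical FeatureSet name

    # Under 100 ids: weights are inlined as a CASE expression in the training view
    fs.set_sample_weights({"id_1": 2.0, "id_2": 0.0})

    # 100+ ids: weights are written to a supplemental table named
    # _{table}___sample_weights and LEFT JOINed on the id column
    fs.set_sample_weights({f"id_{i}": 0.5 for i in range(250)})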
@@ -667,6 +699,32 @@
          create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
          self.data_source.execute_statement(create_view_query)
 
+     def _create_weights_table(self, weight_dict: Dict[Union[str, int], float]) -> str:
+         """Store sample weights as a supplemental data table.
+
+         Args:
+             weight_dict: Mapping of ID to sample weight
+
+         Returns:
+             str: The name of the created supplemental table
+         """
+         from workbench.core.views.view_utils import dataframe_to_table
+
+         # Create DataFrame from weight_dict
+         df = pd.DataFrame(
+             [(id_val, weight) for id_val, weight in weight_dict.items()],
+             columns=[self.id_column, "sample_weight"],
+         )
+
+         # Supplemental table name follows convention: _{base_table}___sample_weights
+         weights_table = f"_{self.table}___sample_weights"
+
+         # Store as supplemental data table
+         self.log.info(f"Creating supplemental weights table: {weights_table}")
+         dataframe_to_table(self.data_source, df, weights_table)
+
+         return weights_table
+
      @classmethod
      def delete_views(cls, table: str, database: str):
          """Delete any views associated with this FeatureSet
workbench/core/cloud_platform/cloud_meta.py CHANGED
@@ -7,7 +7,6 @@ import logging
  from typing import Union
  import pandas as pd
 
-
  # Workbench Imports
  from workbench.core.cloud_platform.aws.aws_meta import AWSMeta
 
workbench/model_script_utils/model_script_utils.py CHANGED
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
      raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+ def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+     """Cap extreme outliers in prediction_std using IQR method.
+
+     Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+     This prevents unreasonably large std values while preserving the
+     relative ordering and keeping meaningful high-uncertainty signals.
+
+     Args:
+         std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+     Returns:
+         Array with outliers capped at the upper fence
+     """
+     if std_array.ndim == 1:
+         std_array = std_array.reshape(-1, 1)
+         squeeze = True
+     else:
+         squeeze = False
+
+     capped = std_array.copy()
+     for col in range(capped.shape[1]):
+         col_data = capped[:, col]
+         q1, q3 = np.percentile(col_data, [25, 75])
+         iqr = q3 - q1
+         upper_bound = q3 + 1.5 * iqr
+         capped[:, col] = np.minimum(col_data, upper_bound)
+
+     return capped.squeeze() if squeeze else capped
+
+
  def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
      """Compute standard regression metrics.
 
workbench/model_script_utils/uq_harness.py CHANGED
@@ -22,7 +22,6 @@ import joblib
  from lightgbm import LGBMRegressor
  from mapie.regression import ConformalizedQuantileRegressor
 
-
  # Default confidence levels for prediction intervals
  DEFAULT_CONFIDENCE_LEVELS = [0.50, 0.68, 0.80, 0.90, 0.95]
 
workbench/model_scripts/chemprop/chemprop.template CHANGED
@@ -20,6 +20,7 @@ import torch
  from chemprop import data, models
 
  from model_script_utils import (
+     cap_std_outliers,
      expand_proba_column,
      input_fn,
      output_fn,
@@ -245,6 +246,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
      preds_std = np.std(np.stack(all_preds), axis=0)
      if preds.ndim == 1:
          preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+     preds_std = cap_std_outliers(preds_std)
 
      print(f"Inference complete: {preds.shape[0]} predictions")
 
@@ -701,6 +703,7 @@ if __name__ == "__main__":
      preds_std = np.std(np.stack(all_ens_preds), axis=0)
      if preds_std.ndim == 1:
          preds_std = preds_std.reshape(-1, 1)
+     preds_std = cap_std_outliers(preds_std)
 
      print("\n--- Per-target metrics ---")
      for t_idx, t_name in enumerate(target_columns):
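A quick worked example of the IQR fence that cap_std_outliers applies above (illustrative values only): for std values [0.1, 0.2, 0.3, 0.4, 5.0], Q1 = 0.2, Q3 = 0.4, IQR = 0.2, so the upper fence is 0.4 + 1.5 * 0.2 = 0.7 and the 5.0 entry is capped at 0.7:

    import numpy as np
    from model_script_utils import cap_std_outliers  # same import the chemprop template uses

    std = np.array([0.1, 0.2, 0.3, 0.4, 5.0])  # one extreme ensemble-std outlier
    print(cap_std_outliers(std))  # [0.1 0.2 0.3 0.4 0.7]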
workbench/model_scripts/chemprop/generated_model_script.py CHANGED
@@ -58,11 +58,11 @@ DEFAULT_HYPERPARAMETERS = {
  # Template parameters (filled in by Workbench)
  TEMPLATE_PARAMS = {
      "model_type": "uq_regressor",
-     "targets": ['udm_asy_res_free_percent'],
+     "targets": ['udm_asy_res_value'],
      "feature_list": ['smiles'],
      "id_column": "udm_mol_bat_id",
-     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/ppb-human-free-reg-chemprop-foundation-1-dt/training",
-     "hyperparameters": {'from_foundation': 'CheMeleon', 'freeze_mpnn_epochs': 10, 'n_folds': 5, 'max_epochs': 100, 'patience': 20, 'ffn_hidden_dim': 512, 'dropout': 0.15},
+     "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/logd-value-reg-chemprop-1-dt/training",
+     "hyperparameters": {},
  }
 
 
workbench/model_scripts/chemprop/model_script_utils.py CHANGED
@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
      raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+ def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+     """Cap extreme outliers in prediction_std using IQR method.
+
+     Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+     This prevents unreasonably large std values while preserving the
+     relative ordering and keeping meaningful high-uncertainty signals.
+
+     Args:
+         std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+     Returns:
+         Array with outliers capped at the upper fence
+     """
+     if std_array.ndim == 1:
+         std_array = std_array.reshape(-1, 1)
+         squeeze = True
+     else:
+         squeeze = False
+
+     capped = std_array.copy()
+     for col in range(capped.shape[1]):
+         col_data = capped[:, col]
+         q1, q3 = np.percentile(col_data, [25, 75])
+         iqr = q3 - q1
+         upper_bound = q3 + 1.5 * iqr
+         capped[:, col] = np.minimum(col_data, upper_bound)
+
+     return capped.squeeze() if squeeze else capped
+
+
  def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
      """Compute standard regression metrics.
 
workbench/model_scripts/custom_models/chem_info/mol_descriptors.py CHANGED
@@ -99,7 +99,6 @@ from rdkit.ML.Descriptors import MoleculeDescriptors
  from mordred import Calculator as MordredCalculator
  from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes
 
-
  logger = logging.getLogger("workbench")
  logger.setLevel(logging.DEBUG)
 
workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py CHANGED
@@ -15,7 +15,6 @@ import json
  from mol_standardize import standardize
  from mol_descriptors import compute_descriptors
 
-
  # TRAINING SECTION
  #
  # This section (__main__) is where SageMaker will execute the training job
workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py CHANGED
@@ -17,7 +17,6 @@ import json
  # Local imports
  from fingerprints import compute_morgan_fingerprints
 
-
  # TRAINING SECTION
  #
  # This section (__main__) is where SageMaker will execute the training job