workbench 0.8.224__py3-none-any.whl → 0.8.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/__init__.py +1 -0
- workbench/algorithms/dataframe/__init__.py +2 -0
- workbench/algorithms/dataframe/smart_aggregator.py +161 -0
- workbench/algorithms/sql/column_stats.py +0 -1
- workbench/algorithms/sql/correlations.py +0 -1
- workbench/algorithms/sql/descriptive_stats.py +0 -1
- workbench/api/meta.py +0 -1
- workbench/cached/cached_meta.py +0 -1
- workbench/cached/cached_model.py +37 -7
- workbench/core/artifacts/endpoint_core.py +12 -2
- workbench/core/artifacts/feature_set_core.py +66 -8
- workbench/core/cloud_platform/cloud_meta.py +0 -1
- workbench/model_script_utils/model_script_utils.py +30 -0
- workbench/model_script_utils/uq_harness.py +0 -1
- workbench/model_scripts/chemprop/chemprop.template +3 -0
- workbench/model_scripts/chemprop/generated_model_script.py +3 -3
- workbench/model_scripts/chemprop/model_script_utils.py +30 -0
- workbench/model_scripts/custom_models/chem_info/mol_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py +0 -1
- workbench/model_scripts/custom_models/chem_info/morgan_fingerprints.py +0 -1
- workbench/model_scripts/pytorch_model/generated_model_script.py +50 -32
- workbench/model_scripts/pytorch_model/model_script_utils.py +30 -0
- workbench/model_scripts/pytorch_model/pytorch.template +47 -29
- workbench/model_scripts/pytorch_model/uq_harness.py +0 -1
- workbench/model_scripts/script_generation.py +0 -1
- workbench/model_scripts/xgb_model/model_script_utils.py +30 -0
- workbench/model_scripts/xgb_model/uq_harness.py +0 -1
- workbench/themes/dark/custom.css +85 -8
- workbench/themes/dark/plotly.json +6 -6
- workbench/themes/light/custom.css +172 -70
- workbench/themes/light/plotly.json +9 -9
- workbench/themes/midnight_blue/custom.css +48 -29
- workbench/themes/midnight_blue/plotly.json +1 -1
- workbench/utils/aws_utils.py +0 -1
- workbench/utils/chem_utils/mol_descriptors.py +0 -1
- workbench/utils/chem_utils/vis.py +137 -27
- workbench/utils/clientside_callbacks.py +41 -0
- workbench/utils/markdown_utils.py +57 -0
- workbench/utils/pipeline_utils.py +0 -1
- workbench/utils/plot_utils.py +8 -110
- workbench/web_interface/components/experiments/outlier_plot.py +0 -1
- workbench/web_interface/components/model_plot.py +2 -0
- workbench/web_interface/components/plugin_unit_test.py +0 -1
- workbench/web_interface/components/plugins/ag_table.py +2 -4
- workbench/web_interface/components/plugins/confusion_matrix.py +3 -6
- workbench/web_interface/components/plugins/model_details.py +10 -6
- workbench/web_interface/components/plugins/scatter_plot.py +56 -43
- workbench/web_interface/components/settings_menu.py +2 -1
- workbench/web_interface/page_views/main_page.py +0 -1
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/METADATA +31 -29
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/RECORD +55 -59
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/WHEEL +1 -1
- workbench/themes/quartz/base_css.url +0 -1
- workbench/themes/quartz/custom.css +0 -117
- workbench/themes/quartz/plotly.json +0 -642
- workbench/themes/quartz_dark/base_css.url +0 -1
- workbench/themes/quartz_dark/custom.css +0 -131
- workbench/themes/quartz_dark/plotly.json +0 -642
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.224.dist-info → workbench-0.8.231.dist-info}/top_level.txt +0 -0
workbench/__init__.py
CHANGED

workbench/algorithms/dataframe/__init__.py
CHANGED

@@ -9,10 +9,12 @@ from .proximity import Proximity
 from .feature_space_proximity import FeatureSpaceProximity
 from .fingerprint_proximity import FingerprintProximity
 from .projection_2d import Projection2D
+from .smart_aggregator import smart_aggregator
 
 __all__ = [
     "Proximity",
     "FeatureSpaceProximity",
     "FingerprintProximity",
     "Projection2D",
+    "smart_aggregator",
 ]

workbench/algorithms/dataframe/smart_aggregator.py
ADDED

@@ -0,0 +1,161 @@
+"""SmartSample: Intelligently reduce DataFrame rows by aggregating similar rows together."""
+
+import pandas as pd
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+from sklearn.cluster import MiniBatchKMeans
+import logging
+
+# Set up logging
+log = logging.getLogger("workbench")
+
+
+def smart_aggregator(df: pd.DataFrame, target_rows: int = 1000, outlier_column: str = "residual") -> pd.DataFrame:
+    """
+    Reduce DataFrame rows by aggregating similar rows based on numeric column similarity.
+
+    This is a performant (2-pass) algorithm:
+    1. Pass 1: Normalize numeric columns and cluster similar rows using MiniBatchKMeans
+    2. Pass 2: Aggregate each cluster (mean for numeric, first for non-numeric)
+
+    Args:
+        df: Input DataFrame.
+        target_rows: Target number of rows in output (default: 1000).
+        outlier_column: Column where high values should resist aggregation (default: "residual").
+            Rows with high values in this column will be kept separate while rows
+            with low values cluster together. Set to None to disable.
+
+    Returns:
+        Reduced DataFrame with 'aggregation_count' column showing how many rows were combined.
+    """
+    if df is None or df.empty:
+        return df
+
+    n_rows = len(df)
+
+    # Preserve original column order
+    original_columns = df.columns.tolist()
+
+    # If already at or below target, just add the count column and return
+    if n_rows <= target_rows:
+        result = df.copy()
+        result["aggregation_count"] = 1
+        return result
+
+    log.info(f"smart_aggregator: Reducing {n_rows} rows to ~{target_rows} rows")
+
+    # Identify columns by type
+    df = df.copy()
+    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
+    non_numeric_cols = [c for c in df.columns if c not in numeric_cols]
+
+    if not numeric_cols:
+        log.warning("smart_aggregator: No numeric columns for clustering, falling back to random sample")
+        result = df.sample(n=target_rows)
+        result["aggregation_count"] = 1
+        return result.reset_index(drop=True)
+
+    # Handle NaN values - fill with column median
+    df_for_clustering = df[numeric_cols].fillna(df[numeric_cols].median())
+
+    # Normalize and cluster
+    X = StandardScaler().fit_transform(df_for_clustering)
+    df["_cluster"] = MiniBatchKMeans(
+        n_clusters=min(target_rows, n_rows), random_state=42, batch_size=min(1024, n_rows), n_init=3
+    ).fit_predict(X)
+
+    # Post-process: give high-outlier rows their own unique clusters so they don't get aggregated
+    if outlier_column and outlier_column in df.columns:
+        # Top 10% of outlier values get their own clusters, capped at 200
+        n_to_isolate = min(int(n_rows * 0.1), 200)
+        threshold = df[outlier_column].nlargest(n_to_isolate).min()
+        high_outlier_mask = df[outlier_column] >= threshold
+        n_high_outliers = high_outlier_mask.sum()
+        # Assign unique cluster IDs starting after the max existing cluster
+        max_cluster = df["_cluster"].max()
+        df.loc[high_outlier_mask, "_cluster"] = range(max_cluster + 1, max_cluster + 1 + n_high_outliers)
+        log.info(f"smart_aggregator: Isolated {n_high_outliers} high-outlier rows (>= {threshold:.3f})")
+    elif outlier_column:
+        log.warning(f"smart_aggregator: outlier_column '{outlier_column}' not found in columns")
+
+    # Aggregate each cluster (mean for numeric, first for non-numeric)
+    agg_dict = {col: "mean" for col in numeric_cols} | {col: "first" for col in non_numeric_cols}
+    grouped = df.groupby("_cluster")
+    result = grouped.agg(agg_dict).reset_index(drop=True)
+    result["aggregation_count"] = grouped.size().values
+
+    # Restore original column order, with aggregation_count at the end
+    result = result[original_columns + ["aggregation_count"]]
+
+    log.info(f"smart_aggregator: Reduced to {len(result)} rows")
+    return result
+
+
+# Testing
+if __name__ == "__main__":
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1000)
+
+    # Create test data with clusters
+    np.random.seed(42)
+    n_samples = 10000
+
+    # Create 3 distinct clusters
+    cluster_1 = np.random.randn(n_samples // 3, 3) + np.array([0, 0, 0])
+    cluster_2 = np.random.randn(n_samples // 3, 3) + np.array([5, 5, 5])
+    cluster_3 = np.random.randn(n_samples // 3, 3) + np.array([10, 0, 5])
+
+    features = np.vstack([cluster_1, cluster_2, cluster_3])
+
+    # Create target and prediction columns, then compute residuals
+    target = features[:, 0] + features[:, 1] * 0.5 + np.random.randn(len(features)) * 0.1
+    prediction = target + np.random.randn(len(features)) * 0.5  # Add noise for residuals
+    residuals = np.abs(target - prediction)
+
+    data = {
+        "id": [f"id_{i}" for i in range(len(features))],
+        "A": features[:, 0],
+        "B": features[:, 1],
+        "C": features[:, 2],
+        "category": np.random.choice(["cat1", "cat2", "cat3"], len(features)),
+        "target": target,
+        "prediction": prediction,
+        "residual": residuals,
+    }
+    df = pd.DataFrame(data)
+
+    print(f"Original DataFrame: {len(df)} rows")
+    print(df.head())
+    print()
+
+    # Test smart_aggregator with residuals preservation
+    result = smart_aggregator(df, target_rows=500)
+    print(f"smart_aggregator result: {len(result)} rows")
+    print(result.head(20))
+    print()
+    print("Aggregation count stats:")
+    print(result["aggregation_count"].describe())
+    print()
+    # Show that high-residual points have lower aggregation counts
+    print("Aggregation count by residual quartile:")
+    result["residual_quartile"] = pd.qcut(result["residual"], 4, labels=["Q1 (low)", "Q2", "Q3", "Q4 (high)"])
+    print(result.groupby("residual_quartile")["aggregation_count"].mean())
+
+    # Test with real Workbench data
+    print("\n" + "=" * 80)
+    print("Testing with Workbench data...")
+    print("=" * 80)
+
+    from workbench.api import Model
+
+    model = Model("abalone-regression")
+    df = model.get_inference_predictions()
+    if df is not None:
+        print(f"\nOriginal DataFrame: {len(df)} rows")
+        print(df.head())
+
+        result = smart_aggregator(df, target_rows=500)
+        print(f"\nsmart_aggregator result: {len(result)} rows")
+        print(result.head())
+        print("\nAggregation count stats:")
+        print(result["aggregation_count"].describe())
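
A minimal usage sketch of the new smart_aggregator function; the DataFrame and column names below are invented for illustration, only the function itself comes from this release:

import numpy as np
import pandas as pd
from workbench.algorithms.dataframe import smart_aggregator

# Hypothetical 50,000-row frame with a "residual" column (the default outlier_column)
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "feature_1": rng.normal(size=50_000),
    "feature_2": rng.normal(size=50_000),
    "residual": np.abs(rng.normal(scale=0.5, size=50_000)),
})

reduced = smart_aggregator(df, target_rows=1000, outlier_column="residual")
print(len(reduced))                        # ~1000 clusters plus up to 200 isolated high-residual rows
print(reduced["aggregation_count"].sum())  # 50000 -- every original row is accounted for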
workbench/api/meta.py
CHANGED
workbench/cached/cached_meta.py
CHANGED
workbench/cached/cached_model.py
CHANGED

@@ -4,8 +4,9 @@ from typing import Union
 import pandas as pd
 
 # Workbench Imports
-from workbench.core.artifacts.model_core import ModelCore
+from workbench.core.artifacts.model_core import ModelCore, ModelType
 from workbench.core.artifacts.cached_artifact_mixin import CachedArtifactMixin
+from workbench.algorithms.dataframe import smart_aggregator
 
 
 class CachedModel(CachedArtifactMixin, ModelCore):

@@ -84,20 +85,49 @@ class CachedModel(CachedArtifactMixin, ModelCore):
         return super().get_inference_metrics(capture_name=capture_name)
 
     @CachedArtifactMixin.cache_result
-    def get_inference_predictions(
+    def get_inference_predictions(
+        self, capture_name: str = "full_cross_fold", target_rows: int = 1000
+    ) -> Union[pd.DataFrame, None]:
         """Retrieve the captured prediction results for this model
 
         Args:
-            capture_name (str, optional): Specific capture_name (default:
+            capture_name (str, optional): Specific capture_name (default: full_cross_fold)
+            target_rows (int, optional): Target number of rows to return (default: 1000)
 
         Returns:
             pd.DataFrame: DataFrame of the Captured Predictions (might be None)
         """
-        # Note: This method can generate larger dataframes, so we'll sample if needed
         df = super().get_inference_predictions(capture_name=capture_name)
-        if df is
-
-
+        if df is None:
+            return None
+
+        # Compute residual based on model type
+        is_regressor = self.model_type in [ModelType.REGRESSOR, ModelType.UQ_REGRESSOR, ModelType.ENSEMBLE_REGRESSOR]
+        is_classifier = self.model_type == ModelType.CLASSIFIER
+
+        if is_regressor:
+            target = self.target()
+            if target and "prediction" in df.columns and target in df.columns:
+                df["residual"] = abs(df["prediction"] - df[target])
+
+        elif is_classifier:
+            target = self.target()
+            class_labels = self.class_labels()
+            if target and "prediction" in df.columns and target in df.columns and class_labels:
+                # Create a mapping from label to ordinal index
+                label_to_idx = {label: idx for idx, label in enumerate(class_labels)}
+                # Compute residual as distance between predicted and actual class
+                df["residual"] = abs(
+                    df["prediction"].map(label_to_idx).fillna(-1) - df[target].map(label_to_idx).fillna(-1)
+                )
+
+        # Use smart_aggregator to aggregate similar rows if we have too many
+        if len(df) > target_rows:
+            self.log.info(
+                f"{self.name}:{capture_name} Using smart_aggregator to reduce {len(df)} rows to ~{target_rows}"
+            )
+            df = smart_aggregator(df, target_rows=target_rows)
+
         return df
 
     @CachedArtifactMixin.cache_result
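
For the classifier branch above, the residual is the ordinal distance between the predicted and actual class labels. A standalone sketch of that mapping, with made-up labels and data:

import pandas as pd

class_labels = ["low", "medium", "high"]  # assumed ordering as returned by class_labels()
label_to_idx = {label: idx for idx, label in enumerate(class_labels)}

df = pd.DataFrame({
    "solubility_class": ["low", "high", "medium"],  # hypothetical target column
    "prediction": ["low", "medium", "high"],
})

# Unknown labels map to -1, mirroring the .fillna(-1) calls in get_inference_predictions()
df["residual"] = abs(
    df["prediction"].map(label_to_idx).fillna(-1) - df["solubility_class"].map(label_to_idx).fillna(-1)
)
print(df)  # residual column: 0, 1, 1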

workbench/core/artifacts/endpoint_core.py
CHANGED

@@ -546,7 +546,14 @@ class EndpointCore(Artifact):
         target_list = targets if isinstance(targets, list) else [targets]
         primary_target = target_list[0]
 
-        #
+        # If we don't have a smiles column, try to merge it from the FeatureSet
+        if "smiles" not in out_of_fold_df.columns:
+            fs_df = fs.query(f'SELECT {fs.id_column}, "smiles" FROM "{fs.athena_table}"')
+            if "smiles" in fs_df.columns:
+                self.log.info("Merging 'smiles' column from FeatureSet into out-of-fold predictions.")
+                out_of_fold_df = out_of_fold_df.merge(fs_df, on=fs.id_column, how="left")
+
+        # Collect UQ columns (q_*, confidence) for additional tracking (used for hashing)
         additional_columns = [col for col in out_of_fold_df.columns if col.startswith("q_") or col == "confidence"]
         if additional_columns:
             self.log.info(f"UQ columns from training: {', '.join(additional_columns)}")

@@ -559,7 +566,6 @@
         # For single-target models (99% of cases), just save as "full_cross_fold"
         # For multi-target models, save each as cv_{target} plus primary as "full_cross_fold"
         is_multi_target = len(target_list) > 1
-
         for target in target_list:
             # Drop rows with NaN target values for metrics/plots
             target_df = out_of_fold_df.dropna(subset=[target])

@@ -899,6 +905,10 @@ class EndpointCore(Artifact):
         # Add UQ columns (q_*, confidence) and proba columns
         output_columns += [c for c in cols if c.startswith("q_") or c == "confidence" or c.endswith("_proba")]
 
+        # Add smiles column if present
+        if "smiles" in cols:
+            output_columns.append("smiles")
+
         # Write the predictions to S3
         output_file = f"{inference_capture_path}/inference_predictions.csv"
         self.log.info(f"Writing predictions to {output_file}")
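
The smiles backfill above is a plain left join on the FeatureSet id column. A self-contained sketch of the same pattern with toy data (in the real code fs_df comes from fs.query() against Athena, and the id column is whatever fs.id_column names; "compound_id" here is hypothetical):

import pandas as pd

# Out-of-fold predictions that are missing the "smiles" column
out_of_fold_df = pd.DataFrame({"compound_id": [1, 2, 3], "prediction": [0.7, 1.2, 0.4]})

# Id-to-smiles mapping pulled from the FeatureSet
fs_df = pd.DataFrame({"compound_id": [1, 2, 3], "smiles": ["CCO", "c1ccccc1", "CC(=O)O"]})

# Left join keeps every prediction row and attaches smiles where the id matches
if "smiles" not in out_of_fold_df.columns:
    out_of_fold_df = out_of_fold_df.merge(fs_df, on="compound_id", how="left")
print(out_of_fold_df)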

workbench/core/artifacts/feature_set_core.py
CHANGED

@@ -554,12 +554,29 @@ class FeatureSetCore(Artifact):
         aws_cols = ["write_time", "api_invocation_time", "is_deleted", "event_time"]
         source_columns = get_column_list(self.data_source, self.table)
         column_list = [col for col in source_columns if col not in aws_cols]
-        sql_columns = ", ".join([f'"{column}"' for column in column_list])
 
-        # Build
+        # Build the training column CASE statement
         training_case = self._build_holdout_case(holdout_ids)
-
-
+
+        # For large weight_dict, use supplemental table + JOIN
+        if len(weight_dict) >= 100:
+            self.log.info("Using supplemental table approach for large weight_dict")
+            weights_table = self._create_weights_table(weight_dict)
+
+            # Build column selection with table alias
+            sql_columns = ", ".join([f't."{col}"' for col in column_list])
+
+            # Build JOIN query with training CASE and weight from joined table
+            training_case_aliased = training_case.replace(f"WHEN {self.id_column} IN", f"WHEN t.{self.id_column} IN")
+            inner_sql = f"""SELECT {sql_columns}, {training_case_aliased},
+                COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                FROM {self.table} t
+                LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+        else:
+            # For small weight_dict, use CASE statement
+            sql_columns = ", ".join([f'"{column}"' for column in column_list])
+            weight_case = self._build_weight_case(weight_dict, default_weight)
+            inner_sql = f"SELECT {sql_columns}, {training_case}, {weight_case} FROM {self.table}"
 
         # Optionally filter out zero weights
         if exclude_zero_weights:

@@ -608,6 +625,10 @@ class FeatureSetCore(Artifact):
             }
             fs.set_sample_weights(weights)  # zeros automatically excluded
             fs.set_sample_weights(weights, exclude_zero_weights=False)  # keep zeros
+
+        Note:
+            For large weight_dict (100+ entries), weights are stored as a supplemental
+            table and joined to avoid Athena query size limits.
         """
         from workbench.core.views import TrainingView
 

@@ -618,14 +639,25 @@ class FeatureSetCore(Artifact):
 
         self.log.important(f"Setting sample weights for {len(weight_dict)} IDs")
 
-        #
-
-
+        # For large weight_dict, use supplemental table + JOIN to avoid query size limits
+        if len(weight_dict) >= 100:
+            self.log.info("Using supplemental table approach for large weight_dict")
+            weights_table = self._create_weights_table(weight_dict)
+
+            # Build JOIN query with COALESCE for default weight
+            inner_sql = f"""SELECT t.*, COALESCE(w.sample_weight, {default_weight}) AS sample_weight
+                FROM {self.table} t
+                LEFT JOIN {weights_table} w ON t.{self.id_column} = w.{self.id_column}"""
+        else:
+            # For small weight_dict, use CASE statement (simpler, no extra table)
+            weight_case = self._build_weight_case(weight_dict, default_weight)
+            inner_sql = f"SELECT *, {weight_case} FROM {self.table}"
 
         # Optionally filter out zero weights
         if exclude_zero_weights:
             zero_count = sum(1 for w in weight_dict.values() if w == 0.0)
-
+            if zero_count:
+                self.log.important(f"Filtering out {zero_count} rows with sample_weight = 0")
             sql_query = f"SELECT * FROM ({inner_sql}) WHERE sample_weight > 0"
         else:
             sql_query = inner_sql

@@ -667,6 +699,32 @@ class FeatureSetCore(Artifact):
         create_view_query = f"CREATE OR REPLACE VIEW {view_table} AS\n{sql_query}"
         self.data_source.execute_statement(create_view_query)
 
+    def _create_weights_table(self, weight_dict: Dict[Union[str, int], float]) -> str:
+        """Store sample weights as a supplemental data table.
+
+        Args:
+            weight_dict: Mapping of ID to sample weight
+
+        Returns:
+            str: The name of the created supplemental table
+        """
+        from workbench.core.views.view_utils import dataframe_to_table
+
+        # Create DataFrame from weight_dict
+        df = pd.DataFrame(
+            [(id_val, weight) for id_val, weight in weight_dict.items()],
+            columns=[self.id_column, "sample_weight"],
+        )
+
+        # Supplemental table name follows convention: _{base_table}___sample_weights
+        weights_table = f"_{self.table}___sample_weights"
+
+        # Store as supplemental data table
+        self.log.info(f"Creating supplemental weights table: {weights_table}")
+        dataframe_to_table(self.data_source, df, weights_table)
+
+        return weights_table
+
     @classmethod
     def delete_views(cls, table: str, database: str):
         """Delete any views associated with this FeatureSet
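
A rough sketch of the two inner_sql shapes the weighting code chooses between. The CASE expression below is an assumption (_build_weight_case() is not shown in this diff), while the JOIN branch mirrors the query text above; table and id column names are made up:

import textwrap

weight_dict = {"id_001": 2.0, "id_002": 0.0}  # toy weights
default_weight = 1.0
table = "my_feature_table"                    # hypothetical table name
id_column = "my_id"                           # hypothetical id column

# Small dict (< 100 entries): inline CASE expression (assumed shape of _build_weight_case)
weight_case = (
    "CASE "
    + " ".join(f"WHEN {id_column} = '{k}' THEN {v}" for k, v in weight_dict.items())
    + f" ELSE {default_weight} END AS sample_weight"
)
inner_sql_small = f"SELECT *, {weight_case} FROM {table}"

# Large dict (>= 100 entries): weights live in a supplemental table and are joined in
weights_table = f"_{table}___sample_weights"
inner_sql_large = textwrap.dedent(f"""\
    SELECT t.*, COALESCE(w.sample_weight, {default_weight}) AS sample_weight
    FROM {table} t
    LEFT JOIN {weights_table} w ON t.{id_column} = w.{id_column}""")

print(inner_sql_small)
print(inner_sql_large)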

workbench/model_script_utils/model_script_utils.py
CHANGED

@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.
 
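
A quick numeric check of the IQR capping added above (toy values, same arithmetic as cap_std_outliers):

import numpy as np

std = np.array([0.10, 0.12, 0.15, 0.20, 5.00])  # one extreme prediction_std value

q1, q3 = np.percentile(std, [25, 75])           # 0.12 and 0.20
upper_bound = q3 + 1.5 * (q3 - q1)              # 0.20 + 1.5 * 0.08 = 0.32
print(np.minimum(std, upper_bound))             # [0.1  0.12 0.15 0.2  0.32]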

workbench/model_scripts/chemprop/chemprop.template
CHANGED

@@ -20,6 +20,7 @@ import torch
 from chemprop import data, models
 
 from model_script_utils import (
+    cap_std_outliers,
     expand_proba_column,
     input_fn,
     output_fn,

@@ -245,6 +246,7 @@ def predict_fn(df: pd.DataFrame, model_dict: dict) -> pd.DataFrame:
     preds_std = np.std(np.stack(all_preds), axis=0)
     if preds.ndim == 1:
         preds, preds_std = preds.reshape(-1, 1), preds_std.reshape(-1, 1)
+    preds_std = cap_std_outliers(preds_std)
 
     print(f"Inference complete: {preds.shape[0]} predictions")
 

@@ -701,6 +703,7 @@ if __name__ == "__main__":
     preds_std = np.std(np.stack(all_ens_preds), axis=0)
     if preds_std.ndim == 1:
         preds_std = preds_std.reshape(-1, 1)
+    preds_std = cap_std_outliers(preds_std)
 
     print("\n--- Per-target metrics ---")
     for t_idx, t_name in enumerate(target_columns):

workbench/model_scripts/chemprop/generated_model_script.py
CHANGED

@@ -58,11 +58,11 @@ DEFAULT_HYPERPARAMETERS = {
 # Template parameters (filled in by Workbench)
 TEMPLATE_PARAMS = {
     "model_type": "uq_regressor",
-    "targets": ['
+    "targets": ['udm_asy_res_value'],
     "feature_list": ['smiles'],
     "id_column": "udm_mol_bat_id",
-    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/
-    "hyperparameters": {
+    "model_metrics_s3_path": "s3://ideaya-sageworks-bucket/models/logd-value-reg-chemprop-1-dt/training",
+    "hyperparameters": {},
 }
 
 

workbench/model_scripts/chemprop/model_script_utils.py
CHANGED

@@ -249,6 +249,36 @@ def output_fn(output_df: pd.DataFrame, accept_type: str) -> tuple[str, str]:
     raise RuntimeError(f"{accept_type} accept type is not supported by this script.")
 
 
+def cap_std_outliers(std_array: np.ndarray) -> np.ndarray:
+    """Cap extreme outliers in prediction_std using IQR method.
+
+    Uses the standard IQR fence (Q3 + 1.5*IQR) to cap extreme values.
+    This prevents unreasonably large std values while preserving the
+    relative ordering and keeping meaningful high-uncertainty signals.
+
+    Args:
+        std_array: Array of standard deviations (n_samples,) or (n_samples, n_targets)
+
+    Returns:
+        Array with outliers capped at the upper fence
+    """
+    if std_array.ndim == 1:
+        std_array = std_array.reshape(-1, 1)
+        squeeze = True
+    else:
+        squeeze = False
+
+    capped = std_array.copy()
+    for col in range(capped.shape[1]):
+        col_data = capped[:, col]
+        q1, q3 = np.percentile(col_data, [25, 75])
+        iqr = q3 - q1
+        upper_bound = q3 + 1.5 * iqr
+        capped[:, col] = np.minimum(col_data, upper_bound)
+
+    return capped.squeeze() if squeeze else capped
+
+
 def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
     """Compute standard regression metrics.
 

workbench/model_scripts/custom_models/chem_info/mol_descriptors.py
CHANGED

@@ -99,7 +99,6 @@ from rdkit.ML.Descriptors import MoleculeDescriptors
 from mordred import Calculator as MordredCalculator
 from mordred import AcidBase, Aromatic, Constitutional, Chi, CarbonTypes
 
-
 logger = logging.getLogger("workbench")
 logger.setLevel(logging.DEBUG)
 