workbench 0.8.331__py3-none-any.whl → 0.8.332__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- workbench/algorithms/dataframe/__init__.py +2 -2
- workbench/algorithms/dataframe/multi_task_alignment.py +443 -0
- workbench/api/endpoint.py +17 -0
- workbench/api/inference_cache.py +8 -73
- workbench/api/meta_endpoint.py +35 -39
- workbench/api/model.py +5 -0
- workbench/core/artifacts/async_endpoint_core.py +22 -17
- workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +8 -17
- workbench/model_scripts/chemprop/chemprop.template +98 -8
- workbench/model_scripts/meta_endpoint/meta_endpoint_dag.py +0 -46
- workbench/utils/chem_utils/mol_descriptors_3d.py +30 -15
- workbench/utils/endpoint_autoscaling.py +5 -1
- workbench/utils/inference_cache_utils.py +15 -16
- workbench/utils/meta_endpoint_dag.py +0 -46
- workbench/utils/multi_task.py +3 -360
- workbench/utils/synthetic_data_generator.py +3 -2
- workbench/web_interface/components/plugins/multi_task_alignment_map.py +519 -0
- {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/METADATA +1 -1
- {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/RECORD +23 -24
- workbench/algorithms/dataframe/dataset_comparison.py +0 -401
- workbench/web_interface/components/plugins/concordance_explorer.py +0 -194
- workbench/web_interface/components/plugins/concordance_map.py +0 -392
- {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/WHEEL +0 -0
- {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/entry_points.txt +0 -0
- {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/licenses/LICENSE +0 -0
- {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/__init__.py
CHANGED

@@ -10,7 +10,7 @@ from .feature_space_proximity import FeatureSpaceProximity
 from .fingerprint_proximity import FingerprintProximity
 from .projection_2d import Projection2D
 from .smart_aggregator import smart_aggregator
-from .
+from .multi_task_alignment import MultiTaskAlignment

 __all__ = [
     "Proximity",
@@ -18,5 +18,5 @@ __all__ = [
     "FingerprintProximity",
     "Projection2D",
     "smart_aggregator",
-    "
+    "MultiTaskAlignment",
 ]
workbench/algorithms/dataframe/multi_task_alignment.py
ADDED (+443 -0)

"""Multi-task alignment: chemical-space coverage and per-aux concordance against a primary target.

Driver: deciding whether a multi-task chemprop run (e.g. ``mdr1_er + caco2_er + caco2_pappab + logd``
auxiliaries) will lift over a single-task model on the primary. Two ingredients matter:

1. **Chemical-space coverage** — do auxiliary compounds occupy the same chemistry as the
   primary? Strong coverage means the aux head's gradient on shared chemistry can refine
   the encoder. Aux-only chemistry that's well-connected to primary extends coverage.
2. **Per-aux alignment** — where they overlap, do the targets agree (Pearson r) and do
   the local SAR neighborhoods predict each other (z-scored residual)?

Build one shared ``FingerprintProximity`` (ECFP + KNN + UMAP) on the union of all rows that
have any target value, then per aux compute label-only stats (counts, Pearson r, recommendation)
and chemistry stats (Tanimoto coverage, z-scored residual).

For the matching UI / "map" view, see
``workbench.web_interface.components.plugins.MultiTaskAlignmentMap``.
"""

import logging
from typing import Optional, Union

import numpy as np
import pandas as pd

from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity

log = logging.getLogger("workbench")


class MultiTaskAlignment:
    """Per-aux alignment of a multi-task DataFrame against a primary target.

    Input is a wide multi-task DataFrame: one row per compound, with ``id_column``,
    ``smiles``, the primary target column, and one or more auxiliary target columns.
    Targets are NaN where not measured.

    A single ``FingerprintProximity`` is built on the union of all rows that have any
    target value, so every per-aux computation reuses the same fingerprints, KNN graph,
    and UMAP coordinates.

    Use ``summary()`` for the per-aux quantitative table, ``results()`` for the
    per-compound DataFrame (with shared UMAP coords, a ``tanimoto_to_primary`` column,
    and per-aux ``residual_<aux>`` columns), and ``neighbors()`` for compound-level lookups.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        primary: str,
        auxiliaries: Optional[list[str]] = None,
        id_column: str = "id",
        k_neighbors: int = 5,
        radius: int = 2,
        n_bits: int = 2048,
        min_n_shared: int = 10,
        extension_ratio_threshold: float = 0.5,
    ) -> None:
        """Initialize the alignment and run all per-aux computations.

        Args:
            df: Wide multi-task DataFrame with ``id_column``, ``smiles``, ``primary``,
                and aux target columns. NaN means the target wasn't measured.
            primary: Name of the primary target column.
            auxiliaries: Aux target column names. If None, defaults to every numeric
                column that isn't ``id_column``, ``smiles``, or ``primary``.
            id_column: Identifier column name (default: ``"id"``).
            k_neighbors: Number of primary-having neighbors used for the residual
                computation (default: 5).
            radius: Morgan fingerprint radius (default: 2 = ECFP4).
            n_bits: Number of fingerprint bits (default: 2048).
            min_n_shared: Minimum rows-with-both-targets before Pearson r is trusted
                (default: 10).
            extension_ratio_threshold: ``n_aux_only / n_primary`` above which the
                extension region counts as "substantial volume" (default: 0.5).
        """
        for col in (id_column, "smiles", primary):
            if col not in df.columns:
                raise ValueError(f"DataFrame missing required column: {col!r}")

        if auxiliaries is None:
            reserved = {id_column, "smiles", primary}
            auxiliaries = [c for c in df.columns if c not in reserved and pd.api.types.is_numeric_dtype(df[c])]
        else:
            for aux in auxiliaries:
                if aux not in df.columns:
                    raise ValueError(f"Aux column {aux!r} not in DataFrame")
                if aux == primary:
                    raise ValueError(f"Aux {aux!r} is the same as primary")

        if not auxiliaries:
            raise ValueError("No auxiliaries provided and none auto-detected from numeric columns")

        self.id_column = id_column
        self.primary = primary
        self.auxiliaries = list(auxiliaries)
        self.k_neighbors = k_neighbors
        self.min_n_shared = min_n_shared
        self.extension_ratio_threshold = extension_ratio_threshold

        all_targets = [primary, *self.auxiliaries]

        # Drop rows with no target values at all — they can't contribute
        any_target = df[all_targets].notna().any(axis=1)
        n_dropped = int((~any_target).sum())
        if n_dropped:
            log.info(f"Dropping {n_dropped} rows with no target values")
            df = df.loc[any_target].copy()

        dup_mask = df.duplicated(subset=id_column, keep="first")
        if dup_mask.any():
            log.warning(f"Dropping {dup_mask.sum()} duplicate {id_column!r} value(s) (keeping first)")
            df = df.loc[~dup_mask].copy()

        log.info(f"MultiTaskAlignment: {len(df)} compounds, primary={primary!r}, auxiliaries={self.auxiliaries}")

        self._prox = FingerprintProximity(
            df,
            id_column=id_column,
            target=None,
            include_all_columns=True,
            radius=radius,
            n_bits=n_bits,
        )

        log.info("Computing per-compound alignment metrics...")
        self._per_compound = self._compute_per_compound()
        self._summary = self._compute_summary()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def results(self) -> pd.DataFrame:
        """Per-compound DataFrame with shared UMAP coords and per-aux alignment columns.

        Returns:
            DataFrame with original columns (``id``, ``smiles``, primary, all auxes), plus:
            - ``x``, ``y``: shared UMAP 2D coordinates
            - ``tanimoto_to_primary``: best Tanimoto similarity from this row to any
              primary-having row (1.0 for primary rows themselves)
            - ``residual_<aux>``: z-scored residual for each aux, defined only on
              rows where the aux is measured. Computed as
              ``z(aux) - median(z(primary) of the top-k primary-having neighbors)``.
              Sign indicates direction of disagreement; magnitude is in std units.
        """
        df = self._prox.df.copy()
        internal_cols = [
            "nn_distance",
            "nn_id",
            "nn_target",
            "nn_target_diff",
            "nn_similarity",
            "fingerprint",
        ]
        df = df.drop(columns=[c for c in internal_cols if c in df.columns])
        df = df.merge(self._per_compound, on=self.id_column, how="left")
        return df

    def summary(self) -> pd.DataFrame:
        """Per-aux quantitative summary, one row per aux.

        Returns:
            DataFrame with columns:
            - ``aux``: aux target name
            - ``n_primary``, ``n_aux``, ``n_shared``, ``n_aux_only``: row counts
            - ``pearson_r``: correlation on shared rows (NaN if ``n_shared < min_n_shared``)
            - ``r_confidence``: ``high`` / ``moderate`` / ``low`` / ``unmeasured``
            - ``tanimoto_coverage_mean``: mean Tanimoto from aux-having rows to nearest
              primary-having row
            - ``frac_coverage_ge_05``, ``frac_coverage_ge_03``: fraction of aux-having
              rows with Tanimoto coverage >= 0.5 and >= 0.3, respectively
            - ``residual_abs_mean``, ``residual_abs_p95``: z-scored residual stats over
              aux-having rows that have at least one primary neighbor
            - ``overlap``: ``Beneficial`` / ``Neutral`` / ``Harmful`` / ``N/A``
            - ``extension``: ``Strong`` / ``Modest`` / ``Minimal`` / ``None``
            - ``recommendation``: ``Use`` / ``Marginal`` / ``Risky`` / ``Skip``
        """
        return self._summary.copy()

    def neighbors(
        self,
        compound_id: Union[str, int, list],
        n_neighbors: int = 10,
    ) -> pd.DataFrame:
        """Nearest neighbors for a compound (or list) in the shared fingerprint space.

        Args:
            compound_id: ID or list of IDs to look up.
            n_neighbors: Number of neighbors to return (default: 10).

        Returns:
            DataFrame sorted by similarity (descending) with the query id, ``neighbor_id``,
            ``similarity``, ``smiles``, and the primary + aux target columns.
        """
        nbrs = self._prox.neighbors(compound_id, n_neighbors=n_neighbors)
        keep = [self.id_column, "neighbor_id", "similarity", "smiles", self.primary, *self.auxiliaries]
        return nbrs[[c for c in keep if c in nbrs.columns]].reset_index(drop=True)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _compute_per_compound(self) -> pd.DataFrame:
        """Per-row chemistry signals: Tanimoto coverage and z-scored residual per aux."""
        df = self._prox.df
        all_ids = df[self.id_column].tolist()
        primary_mask = df[self.primary].notna()
        primary_ids = set(df.loc[primary_mask, self.id_column])

        # One bulk neighbor lookup; we'll filter to primary-having neighbors below
        n_lookup = max(50, self.k_neighbors * 10)
        nbrs = self._prox.neighbors(all_ids, n_neighbors=n_lookup)

        # Drop self-neighbors so a primary row's own value doesn't satisfy its coverage
        nbrs_no_self = nbrs[nbrs[self.id_column] != nbrs["neighbor_id"]].copy()
        primary_nbrs = nbrs_no_self[nbrs_no_self["neighbor_id"].isin(primary_ids)].copy()

        # Tanimoto-to-primary: best similarity to any primary-having compound
        best_sim = primary_nbrs.groupby(self.id_column)["similarity"].max()

        result = pd.DataFrame({self.id_column: all_ids})
        result["tanimoto_to_primary"] = result[self.id_column].map(best_sim).fillna(0.0)
        # Primary rows are by definition fully covered
        result.loc[result[self.id_column].isin(primary_ids), "tanimoto_to_primary"] = 1.0

        # Z-scored "predicted primary" from top-k primary-having neighbors
        primary_nbrs = primary_nbrs.sort_values([self.id_column, "similarity"], ascending=[True, False])
        primary_nbrs["_rank"] = primary_nbrs.groupby(self.id_column).cumcount() + 1
        topk = primary_nbrs[primary_nbrs["_rank"] <= self.k_neighbors].copy()

        primary_vals = df.set_index(self.id_column)[self.primary]
        primary_z = self._zscore(primary_vals)
        topk["primary_z"] = topk["neighbor_id"].map(primary_z)
        primary_z_pred = topk.groupby(self.id_column)["primary_z"].median()

        # Per-aux z-scored residual: defined only where the aux is measured
        for aux in self.auxiliaries:
            aux_vals = df.set_index(self.id_column)[aux]
            aux_z = self._zscore(aux_vals)
            aux_z_aligned = result[self.id_column].map(aux_z)
            primary_pred_aligned = result[self.id_column].map(primary_z_pred)
            residual = aux_z_aligned - primary_pred_aligned
            # Mask rows that don't have the aux value
            aux_present = result[self.id_column].map(aux_vals.notna()).fillna(False).values
            result[f"residual_{aux}"] = np.where(aux_present, residual.values, np.nan)

        return result

    def _compute_summary(self) -> pd.DataFrame:
        """Per-aux quantitative summary (counts, pearson, coverage, residuals, verdicts)."""
        df = self._prox.df
        primary_mask = df[self.primary].notna()
        n_primary = int(primary_mask.sum())
        rows = []

        for aux in self.auxiliaries:
            aux_mask = df[aux].notna()
            n_aux = int(aux_mask.sum())
            n_shared = int((primary_mask & aux_mask).sum())
            n_aux_only = int((~primary_mask & aux_mask).sum())

            if n_shared >= self.min_n_shared:
                pearson_r = float(df.loc[primary_mask & aux_mask, [self.primary, aux]].corr().iloc[0, 1])
            else:
                pearson_r = float("nan")

            aux_ids = df.loc[aux_mask, self.id_column]
            cov = self._per_compound.set_index(self.id_column).loc[aux_ids, "tanimoto_to_primary"]
            cov_mean = float(cov.mean()) if len(cov) else 0.0
            cov_05 = float((cov >= 0.5).mean()) if len(cov) else 0.0
            cov_03 = float((cov >= 0.3).mean()) if len(cov) else 0.0

            residuals = self._per_compound[f"residual_{aux}"].dropna()
            res_abs_mean = float(residuals.abs().mean()) if len(residuals) else float("nan")
            res_abs_p95 = float(residuals.abs().quantile(0.95)) if len(residuals) else float("nan")

            overlap, _ = _assess_overlap(pearson_r, n_shared, self.min_n_shared)
            extension, _ = _assess_extension(pearson_r, n_aux_only, n_primary, self.extension_ratio_threshold)
            recommendation, _ = _combine_assessments(overlap, extension)

            rows.append(
                {
                    "aux": aux,
                    "n_primary": n_primary,
                    "n_aux": n_aux,
                    "n_shared": n_shared,
                    "n_aux_only": n_aux_only,
                    "pearson_r": pearson_r,
                    "r_confidence": _confidence_tier(n_shared, self.min_n_shared),
                    "tanimoto_coverage_mean": cov_mean,
                    "frac_coverage_ge_05": cov_05,
                    "frac_coverage_ge_03": cov_03,
                    "residual_abs_mean": res_abs_mean,
                    "residual_abs_p95": res_abs_p95,
                    "overlap": overlap,
                    "extension": extension,
                    "recommendation": recommendation,
                }
            )
            r_str = f"r={pearson_r:.3f}" if not np.isnan(pearson_r) else "r=NA"
            log.info(
                f"  {aux}: shared={n_shared:,} aux_only={n_aux_only:,} {r_str} "
                f"cov_mean={cov_mean:.2f} -> overlap={overlap}, extension={extension} "
                f"-> {recommendation}"
            )

        return pd.DataFrame(rows)

    @staticmethod
    def _zscore(s: pd.Series) -> pd.Series:
        """Z-score a Series; returns the centered values (all zeros) if std is zero (constant column)."""
        std = s.std(ddof=0)
        if not np.isfinite(std) or std == 0:
            return s - s.mean()
        return (s - s.mean()) / std


# ----------------------------------------------------------------------
# Verdict helpers (label-only scoring; same thresholds across multi-task code)
# ----------------------------------------------------------------------


def _confidence_tier(n_shared: int, min_n_shared: int) -> str:
    """Bucket how trustworthy a Pearson r is, given shared-compound count."""
    if n_shared < min_n_shared:
        return "unmeasured"
    if n_shared < 30:
        return "low"
    if n_shared < 100:
        return "moderate"
    return "high"


def _assess_overlap(r: float, n_shared: int, min_n_shared: int) -> tuple[str, str]:
    """Score the overlap region (compounds with both primary and aux measured).

    Thresholds (label-correlation on shared rows):
        r in [0.4, 0.95] -> Beneficial : sweet spot — heads predict related but distinct targets
        r > 0.95         -> Neutral    : redundant; aux head just re-weights primary
        r < 0.4          -> Harmful    : discordant; gradient conflict / negative-transfer risk
        n_shared too low -> N/A
    """
    if n_shared < min_n_shared:
        return ("N/A", f"only {n_shared} shared compounds (need >= {min_n_shared} to score)")
    if 0.4 <= r <= 0.95:
        return (
            "Beneficial",
            f"sweet-spot r={r:.2f} on {n_shared:,} shared compounds — encoder learns richer features",
        )
    if r > 0.95:
        return (
            "Neutral",
            f"redundant r={r:.2f} on {n_shared:,} shared compounds — aux head just re-weights primary",
        )
    return (
        "Harmful",
        f"discordant r={r:.2f} on {n_shared:,} shared compounds — gradient conflict, negative-transfer risk",
    )


def _assess_extension(
    r: float,
    n_aux_only: int,
    n_primary: int,
    ratio_threshold: float,
) -> tuple[str, str]:
    """Score the extension region (aux-only compounds; primary head is masked there)."""
    if n_aux_only == 0:
        return ("None", "no aux-only compounds")

    ratio = n_aux_only / n_primary if n_primary > 0 else 0.0
    has_volume = ratio >= ratio_threshold

    if np.isnan(r):
        sim_str = "task similarity unknown (no overlap to measure)"
        if has_volume:
            return ("Strong", f"{ratio:.1f}x primary of novel chemistry; {sim_str}")
        return ("Modest", f"{n_aux_only:,} aux-only compounds ({ratio:.1f}x primary); {sim_str}")

    similar = r >= 0.4
    if has_volume and similar:
        return ("Strong", f"{ratio:.1f}x primary of novel chemistry x similar task (r={r:.2f})")
    if has_volume:
        return ("Modest", f"{ratio:.1f}x primary of novel chemistry but weak similarity (r={r:.2f})")
    if similar:
        return ("Modest", f"limited volume ({ratio:.1f}x primary) but similar task (r={r:.2f})")
    return ("Minimal", f"low volume ({ratio:.1f}x primary) and weak similarity (r={r:.2f})")


def _combine_assessments(overlap: str, extension: str) -> tuple[str, str]:
    """Combine overlap + extension scores into an actionable recommendation."""
    if overlap == "Harmful":
        if extension in ("Strong", "Modest"):
            return ("Risky", "harmful overlap; extension might rescue but cross-task signal could hurt primary")
        return ("Skip", "negative transfer from overlap with no extension to compensate")

    if overlap == "Beneficial":
        if extension in ("Strong", "Modest"):
            return ("Use", "both mechanisms contribute lift")
        return ("Use", "cross-task signal contributes lift")

    # overlap is Neutral or N/A — extension is the only available mechanism
    if extension == "Strong":
        return ("Use", "extension is the primary lift mechanism")
    if extension == "Modest":
        return ("Marginal", "limited extension lift; consider domain knowledge")
    return ("Skip", "no clear lift mechanism")


# =============================================================================
# Testing
# =============================================================================
if __name__ == "__main__":
    from workbench.utils.synthetic_data_generator import SyntheticDataGenerator

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    pd.set_option("display.max_columns", None)
    pd.set_option("display.width", 1400)

    # Build a synthetic multi-task DataFrame from two AQSol partitions
    ref_df, query_df = SyntheticDataGenerator().aqsol_alignment_data(overlap="medium", alignment="medium")
    ref_df = ref_df.assign(id=ref_df["id"].astype(str).radd("ref_"))
    query_df = query_df.assign(id=query_df["id"].astype(str).radd("qry_"))
    ref_df = ref_df.rename(columns={"solubility": "primary_sol"})
    query_df = query_df.rename(columns={"solubility": "aux_sol"})
    mt_df = pd.concat(
        [
            ref_df[["id", "smiles", "primary_sol"]],
            query_df[["id", "smiles", "aux_sol"]],
        ],
        ignore_index=True,
    )

    mta = MultiTaskAlignment(mt_df, primary="primary_sol", auxiliaries=["aux_sol"], id_column="id")

    print("\n=== summary() ===")
    print(mta.summary().to_string(index=False))

    print("\n=== results() (first 5 rows) ===")
    print(mta.results().head().to_string(index=False))

    print("\nMultiTaskAlignment tests completed!")
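
Usage sketch (editor's note, not part of the package diff): the new class is driven roughly as below. The CSV path and the target column names are hypothetical, chosen to mirror the module docstring's example; the calls follow the docstrings and the __main__ block above.

import pandas as pd

from workbench.algorithms.dataframe import MultiTaskAlignment

# Hypothetical wide multi-task table: one row per compound, NaN = unmeasured
mt_df = pd.read_csv("multi_task_targets.csv")  # id, smiles, mdr1_er, caco2_er, logd

mta = MultiTaskAlignment(
    mt_df,
    primary="mdr1_er",
    auxiliaries=["caco2_er", "logd"],
    id_column="id",
)

summary = mta.summary()  # one row per aux: counts, pearson_r, coverage, verdicts
keep = summary.loc[summary["recommendation"].isin(["Use", "Marginal"]), "aux"].tolist()
per_compound = mta.results()  # shared UMAP coords + per-aux residual columns for plotting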
workbench/api/endpoint.py
CHANGED

@@ -82,6 +82,23 @@ class Endpoint(EndpointCore):
             return self._async.auto_inference()
         return super().auto_inference()

+    def purge_async_queue(self) -> int:
+        """Cancel queued async invocations by deleting their staged S3 inputs.
+
+        Useful when a long-running client was killed and you want to abandon
+        the orphaned backlog instead of waiting for the fleet to drain it.
+        Only valid on async endpoints — raises on sync endpoints.
+
+        Returns:
+            int: Number of staged input objects deleted.
+        """
+        if self._async is None:
+            raise RuntimeError(
+                f"Endpoint '{self.name}' is not async — purge_async_queue is only "
+                f"meaningful for endpoints with an async invocation queue."
+            )
+        return self._async.purge_async_queue()
+
     def full_inference(self) -> pd.DataFrame:
         """Run inference on the Endpoint using the full data from the model training view
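
Usage sketch (editor's note, not part of the package diff): the endpoint name below is hypothetical; the method and its RuntimeError behavior on sync endpoints are as shown in the hunk above.

from workbench.api.endpoint import Endpoint

end = Endpoint("my-async-endpoint")  # hypothetical async endpoint name
# A killed client can leave staged inputs queued in S3; purge them rather
# than waiting for the fleet to drain the orphaned backlog.
n_purged = end.purge_async_queue()
print(f"Purged {n_purged} staged async input object(s)")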
workbench/api/inference_cache.py
CHANGED

@@ -23,10 +23,7 @@ import pandas as pd

 from workbench.api.df_store import DFStore
 from workbench.api.endpoint import Endpoint
-from workbench.utils.inference_cache_utils import (
-    DEFAULT_CHUNK_SIZE,
-    chunked_with_cache_writes,
-)
+from workbench.utils.inference_cache_utils import chunked_with_cache_writes


 class InferenceCache:
@@ -49,33 +46,13 @@ class InferenceCache:
     ```
     """

-    # Rows per cache write. The endpoint is called once per chunk and the
-    # cache is persisted between chunks, so this also bounds the blast radius
-    # of an interrupted/failed write to one chunk worth of work.
-    #
-    # The actual chunk_size on each instance is set in __init__: either the
-    # explicit ``chunk_size`` constructor kwarg, or — for async endpoints with
-    # max_instances in their workbench_meta — derived from fleet capacity
-    # (max_instances × batch_size × 2) so each chunk holds an integer number
-    # of full fleet-waves. This avoids the "10 batches / 8 instances → tail"
-    # utilization loss. Falls back to this class attribute (DEFAULT_CHUNK_SIZE)
-    # for sync endpoints or legacy endpoints without max_instances in meta.
-    chunk_size: int = DEFAULT_CHUNK_SIZE
-
-    # Number of fleet-waves per chunk when auto-deriving chunk_size. With k
-    # batches per worker, relative tail-variance scales as 1/√k, so bumping
-    # k from 2 to 4 cuts tail overhead ~30% without changing batch_size (so
-    # per-batch polling cost is unchanged). Crash-recovery loss is one chunk
-    # = 4 fleet-waves of work, modest for any reasonable batch pipeline.
-    _CHUNK_WAVES = 4
-
     def __init__(
         self,
         endpoint: Endpoint,
         cache_key_column: str = "smiles",
         output_key_column: Optional[str] = None,
         auto_invalidate_cache: bool = False,
-        chunk_size: Optional[int] = None,
+        snapshot: int = 500,
     ):
         """Initialize the InferenceCache.

@@ -98,15 +75,16 @@
                 kept regardless of endpoint changes — the manifest is
                 reseeded on first load so subsequent calls have a consistent
                 baseline.
-
-
-
-
-
+            snapshot (int): Rows per cache write (default: 500). The
+                endpoint is called once per snapshot's worth of rows and the
+                cache is persisted between calls, so this also bounds the
+                blast radius of an interrupted run to one snapshot's worth
+                of work.
         """
         self._endpoint = endpoint
         self.cache_key_column = cache_key_column
         self.output_key_column = output_key_column
+        self.snapshot = int(snapshot)
         self.cache_path = f"/workbench/inference_cache/{endpoint.name}"
         self.manifest_path = f"{self.cache_path}__meta"
         self._df_store = DFStore()
@@ -122,49 +100,6 @@
         self._coerce_warned: set[tuple] = set()
         self.log = logging.getLogger("workbench")

-        # Resolve chunk_size: explicit override wins; else try fleet-derivation;
-        # else fall through to the class-level DEFAULT_CHUNK_SIZE.
-        if chunk_size is not None:
-            self.chunk_size = int(chunk_size)
-        else:
-            derived = self._derive_chunk_size()
-            if derived is not None:
-                self.chunk_size = derived
-
-    def _derive_chunk_size(self) -> Optional[int]:
-        """Derive chunk_size from the wrapped endpoint's fleet capacity.
-
-        Returns ``capacity × batch_size × _CHUNK_WAVES`` so each chunk
-        holds an integer number of full fleet-waves — preventing the
-        "10 batches / 8 instances → 2-batch tail" utilization loss on async
-        endpoints. ``capacity`` prefers ``effective_max_instances`` when
-        present (set by MetaEndpoint to reflect the largest child fleet,
-        since the meta itself deploys with ``max_instances=1`` regardless
-        of downstream capacity), otherwise falls back to ``max_instances``.
-        Returns ``None`` (→ caller should use DEFAULT_CHUNK_SIZE) when the
-        endpoint's meta has neither, which is the case for sync endpoints
-        and legacy async deploys.
-        """
-        try:
-            meta = self._endpoint.workbench_meta() or {}
-        except Exception:
-            return None
-        capacity = meta.get("effective_max_instances", meta.get("max_instances"))
-        if capacity is None:
-            return None
-        # Mirror AsyncEndpointCore's own resolution: explicit meta override
-        # wins, otherwise the core default.
-        from workbench.core.artifacts.async_endpoint_core import _DEFAULT_BATCH_SIZE
-
-        batch_size = int(meta.get("inference_batch_size", _DEFAULT_BATCH_SIZE))
-        derived = int(capacity) * batch_size * self._CHUNK_WAVES
-        self.log.info(
-            f"InferenceCache[{self._endpoint.name}]: chunk_size={derived} "
-            f"(capacity={capacity} × batch_size={batch_size} × "
-            f"{self._CHUNK_WAVES} waves — full fleet utilization per chunk)"
-        )
-        return derived
-
     def __getattr__(self, name):
         """Delegate any unrecognized attribute access to the wrapped Endpoint."""
         # __getattr__ is only called when normal lookup fails, so this won't
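
Usage sketch (editor's note, not part of the package diff): the class-level chunk_size knob and its fleet-capacity derivation are removed in this version; callers now pass a plain per-write row count via the new snapshot kwarg. The endpoint name below is hypothetical.

from workbench.api.endpoint import Endpoint
from workbench.api.inference_cache import InferenceCache

end = Endpoint("my-endpoint")  # hypothetical endpoint name
# The cache is persisted every 500 rows, so an interrupted run re-does at
# most one snapshot's worth of endpoint calls; attributes the cache doesn't
# define delegate to the wrapped Endpoint (see __getattr__ above).
cache = InferenceCache(end, cache_key_column="smiles", snapshot=500)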
|