workbench 0.8.331__py3-none-any.whl → 0.8.332__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. workbench/algorithms/dataframe/__init__.py +2 -2
  2. workbench/algorithms/dataframe/multi_task_alignment.py +443 -0
  3. workbench/api/endpoint.py +17 -0
  4. workbench/api/inference_cache.py +8 -73
  5. workbench/api/meta_endpoint.py +35 -39
  6. workbench/api/model.py +5 -0
  7. workbench/core/artifacts/async_endpoint_core.py +22 -17
  8. workbench/core/transforms/model_to_endpoint/model_to_endpoint.py +8 -17
  9. workbench/model_scripts/chemprop/chemprop.template +98 -8
  10. workbench/model_scripts/meta_endpoint/meta_endpoint_dag.py +0 -46
  11. workbench/utils/chem_utils/mol_descriptors_3d.py +30 -15
  12. workbench/utils/endpoint_autoscaling.py +5 -1
  13. workbench/utils/inference_cache_utils.py +15 -16
  14. workbench/utils/meta_endpoint_dag.py +0 -46
  15. workbench/utils/multi_task.py +3 -360
  16. workbench/utils/synthetic_data_generator.py +3 -2
  17. workbench/web_interface/components/plugins/multi_task_alignment_map.py +519 -0
  18. {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/METADATA +1 -1
  19. {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/RECORD +23 -24
  20. workbench/algorithms/dataframe/dataset_comparison.py +0 -401
  21. workbench/web_interface/components/plugins/concordance_explorer.py +0 -194
  22. workbench/web_interface/components/plugins/concordance_map.py +0 -392
  23. {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/WHEEL +0 -0
  24. {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/entry_points.txt +0 -0
  25. {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/licenses/LICENSE +0 -0
  26. {workbench-0.8.331.dist-info → workbench-0.8.332.dist-info}/top_level.txt +0 -0
workbench/algorithms/dataframe/__init__.py CHANGED
@@ -10,7 +10,7 @@ from .feature_space_proximity import FeatureSpaceProximity
 from .fingerprint_proximity import FingerprintProximity
 from .projection_2d import Projection2D
 from .smart_aggregator import smart_aggregator
-from .dataset_comparison import DatasetComparison
+from .multi_task_alignment import MultiTaskAlignment
 
 __all__ = [
     "Proximity",
@@ -18,5 +18,5 @@ __all__ = [
     "FingerprintProximity",
     "Projection2D",
     "smart_aggregator",
-    "DatasetComparison",
+    "MultiTaskAlignment",
 ]
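Downstream imports change in lockstep with this export swap. A minimal before/after sketch (the package-level import path follows from the updated ``__all__``; nothing here is new API beyond what the hunk above shows):

    # 0.8.331
    from workbench.algorithms.dataframe import DatasetComparison
    # 0.8.332
    from workbench.algorithms.dataframe import MultiTaskAlignment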
workbench/algorithms/dataframe/multi_task_alignment.py ADDED
@@ -0,0 +1,443 @@
+"""Multi-task alignment: chemical-space coverage and per-aux concordance against a primary target.
+
+Driver: deciding whether a multi-task chemprop run (e.g. ``mdr1_er + caco2_er + caco2_pappab + logd``
+auxiliaries) will lift over a single-task model on the primary. Two ingredients matter:
+
+1. **Chemical-space coverage** — do auxiliary compounds occupy the same chemistry as the
+   primary? Strong coverage means the aux head's gradient on shared chemistry can refine
+   the encoder. Aux-only chemistry that's well-connected to primary extends coverage.
+2. **Per-aux alignment** — where they overlap, do the targets agree (Pearson r) and do
+   the local SAR neighborhoods predict each other (z-scored residual)?
+
+Build one shared ``FingerprintProximity`` (ECFP + KNN + UMAP) on the union of all rows that
+have any target value, then per aux compute label-only stats (counts, Pearson r, recommendation)
+and chemistry stats (Tanimoto coverage, z-scored residual).
+
+For the matching UI / "map" view, see
+``workbench.web_interface.components.plugins.MultiTaskAlignmentMap``.
+"""
+
+import logging
+from typing import Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from workbench.algorithms.dataframe.fingerprint_proximity import FingerprintProximity
+
+log = logging.getLogger("workbench")
+
+
+class MultiTaskAlignment:
+    """Per-aux alignment of a multi-task DataFrame against a primary target.
+
+    Input is a wide multi-task DataFrame: one row per compound, with ``id_column``,
+    ``smiles``, the primary target column, and one or more auxiliary target columns.
+    Targets are NaN where not measured.
+
+    A single ``FingerprintProximity`` is built on the union of all rows that have any
+    target value, so every per-aux computation reuses the same fingerprints, KNN graph,
+    and UMAP coordinates.
+
+    Use ``summary()`` for the per-aux quantitative table, ``results()`` for the
+    per-compound DataFrame (with shared UMAP coords + per-aux ``tanimoto_to_primary_<aux>``
+    and ``residual_<aux>`` columns), and ``neighbors()`` for compound-level lookups.
+    """
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        primary: str,
+        auxiliaries: Optional[list[str]] = None,
+        id_column: str = "id",
+        k_neighbors: int = 5,
+        radius: int = 2,
+        n_bits: int = 2048,
+        min_n_shared: int = 10,
+        extension_ratio_threshold: float = 0.5,
+    ) -> None:
+        """Initialize the alignment and run all per-aux computations.
+
+        Args:
+            df: Wide multi-task DataFrame with ``id_column``, ``smiles``, ``primary``,
+                and aux target columns. NaN means the target wasn't measured.
+            primary: Name of the primary target column.
+            auxiliaries: Aux target column names. If None, defaults to every numeric
+                column that isn't ``id_column``, ``smiles``, or ``primary``.
+            id_column: Identifier column name (default: ``"id"``).
+            k_neighbors: Number of primary-having neighbors used for the residual
+                computation (default: 5).
+            radius: Morgan fingerprint radius (default: 2 = ECFP4).
+            n_bits: Number of fingerprint bits (default: 2048).
+            min_n_shared: Minimum rows-with-both-targets before Pearson r is trusted
+                (default: 10).
+            extension_ratio_threshold: ``n_aux_only / n_primary`` above which the
+                extension region counts as "substantial volume" (default: 0.5).
+        """
+        for col in (id_column, "smiles", primary):
+            if col not in df.columns:
+                raise ValueError(f"DataFrame missing required column: {col!r}")
+
+        if auxiliaries is None:
+            reserved = {id_column, "smiles", primary}
+            auxiliaries = [c for c in df.columns if c not in reserved and pd.api.types.is_numeric_dtype(df[c])]
+        else:
+            for aux in auxiliaries:
+                if aux not in df.columns:
+                    raise ValueError(f"Aux column {aux!r} not in DataFrame")
+                if aux == primary:
+                    raise ValueError(f"Aux {aux!r} is the same as primary")
+
+        if not auxiliaries:
+            raise ValueError("No auxiliaries provided and none auto-detected from numeric columns")
+
+        self.id_column = id_column
+        self.primary = primary
+        self.auxiliaries = list(auxiliaries)
+        self.k_neighbors = k_neighbors
+        self.min_n_shared = min_n_shared
+        self.extension_ratio_threshold = extension_ratio_threshold
+
+        all_targets = [primary, *self.auxiliaries]
+
+        # Drop rows with no target values at all — they can't contribute
+        any_target = df[all_targets].notna().any(axis=1)
+        n_dropped = int((~any_target).sum())
+        if n_dropped:
+            log.info(f"Dropping {n_dropped} rows with no target values")
+            df = df.loc[any_target].copy()
+
+        dup_mask = df.duplicated(subset=id_column, keep="first")
+        if dup_mask.any():
+            log.warning(f"Dropping {dup_mask.sum()} duplicate {id_column!r} value(s) (keeping first)")
+            df = df.loc[~dup_mask].copy()
+
+        log.info(f"MultiTaskAlignment: {len(df)} compounds, primary={primary!r}, " f"auxiliaries={self.auxiliaries}")
+
+        self._prox = FingerprintProximity(
+            df,
+            id_column=id_column,
+            target=None,
+            include_all_columns=True,
+            radius=radius,
+            n_bits=n_bits,
+        )
+
+        log.info("Computing per-compound alignment metrics...")
+        self._per_compound = self._compute_per_compound()
+        self._summary = self._compute_summary()
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+
+    def results(self) -> pd.DataFrame:
+        """Per-compound DataFrame with shared UMAP coords and per-aux alignment columns.
+
+        Returns:
+            DataFrame with original columns (``id``, ``smiles``, primary, all auxes), plus:
+            - ``x``, ``y``: shared UMAP 2D coordinates
+            - ``tanimoto_to_primary``: best Tanimoto similarity from this row to any
+              primary-having row (1.0 for primary rows themselves)
+            - ``residual_<aux>``: z-scored residual for each aux, defined only on
+              rows where the aux is measured. Computed as
+              ``z(aux) - median(z(primary)`` over top-k primary-having neighbors``)``.
+              Sign indicates direction of disagreement; magnitude is in std units.
+        """
+        df = self._prox.df.copy()
+        internal_cols = [
+            "nn_distance",
+            "nn_id",
+            "nn_target",
+            "nn_target_diff",
+            "nn_similarity",
+            "fingerprint",
+        ]
+        df = df.drop(columns=[c for c in internal_cols if c in df.columns])
+        df = df.merge(self._per_compound, on=self.id_column, how="left")
+        return df
+
+    def summary(self) -> pd.DataFrame:
+        """Per-aux quantitative summary, one row per aux.
+
+        Returns:
+            DataFrame with columns:
+            - ``aux``: aux target name
+            - ``n_primary``, ``n_aux``, ``n_shared``, ``n_aux_only``: row counts
+            - ``pearson_r``: correlation on shared rows (NaN if ``n_shared < min_n_shared``)
+            - ``r_confidence``: ``high`` / ``moderate`` / ``low`` / ``unmeasured``
+            - ``tanimoto_coverage_mean``: mean Tanimoto from aux-having rows to nearest
+              primary-having row
+            - ``frac_coverage_ge_05``, ``frac_coverage_ge_03``: fraction of aux-having
+              rows with Tanimoto coverage above the threshold
+            - ``residual_abs_mean``, ``residual_abs_p95``: z-scored residual stats over
+              aux-having rows that have at least one primary neighbor
+            - ``overlap``: ``Beneficial`` / ``Neutral`` / ``Harmful`` / ``N/A``
+            - ``extension``: ``Strong`` / ``Modest`` / ``Minimal`` / ``None``
+            - ``recommendation``: ``Use`` / ``Marginal`` / ``Risky`` / ``Skip``
+        """
+        return self._summary.copy()
+
+    def neighbors(
+        self,
+        compound_id: Union[str, int, list],
+        n_neighbors: int = 10,
+    ) -> pd.DataFrame:
+        """Nearest neighbors for a compound (or list) in the shared fingerprint space.
+
+        Args:
+            compound_id: ID or list of IDs to look up.
+            n_neighbors: Number of neighbors to return (default: 10).
+
+        Returns:
+            DataFrame sorted by similarity (descending) with the query id, ``neighbor_id``,
+            ``similarity``, ``smiles``, and the primary + aux target columns.
+        """
+        nbrs = self._prox.neighbors(compound_id, n_neighbors=n_neighbors)
+        keep = [self.id_column, "neighbor_id", "similarity", "smiles", self.primary, *self.auxiliaries]
+        return nbrs[[c for c in keep if c in nbrs.columns]].reset_index(drop=True)
+
+    # ------------------------------------------------------------------
+    # Internals
+    # ------------------------------------------------------------------
+
+    def _compute_per_compound(self) -> pd.DataFrame:
+        """Per-row chemistry signals: Tanimoto coverage and z-scored residual per aux."""
+        df = self._prox.df
+        all_ids = df[self.id_column].tolist()
+        primary_mask = df[self.primary].notna()
+        primary_ids = set(df.loc[primary_mask, self.id_column])
+
+        # One bulk neighbor lookup; we'll filter to primary-having neighbors below
+        n_lookup = max(50, self.k_neighbors * 10)
+        nbrs = self._prox.neighbors(all_ids, n_neighbors=n_lookup)
+
+        # Drop self-neighbors so a primary row's own value doesn't satisfy its coverage
+        nbrs_no_self = nbrs[nbrs[self.id_column] != nbrs["neighbor_id"]].copy()
+        primary_nbrs = nbrs_no_self[nbrs_no_self["neighbor_id"].isin(primary_ids)].copy()
+
+        # Tanimoto-to-primary: best similarity to any primary-having compound
+        best_sim = primary_nbrs.groupby(self.id_column)["similarity"].max()
+
+        result = pd.DataFrame({self.id_column: all_ids})
+        result["tanimoto_to_primary"] = result[self.id_column].map(best_sim).fillna(0.0)
+        # Primary rows are by definition fully covered
+        result.loc[result[self.id_column].isin(primary_ids), "tanimoto_to_primary"] = 1.0
+
+        # Z-scored "predicted primary" from top-k primary-having neighbors
+        primary_nbrs = primary_nbrs.sort_values([self.id_column, "similarity"], ascending=[True, False])
+        primary_nbrs["_rank"] = primary_nbrs.groupby(self.id_column).cumcount() + 1
+        topk = primary_nbrs[primary_nbrs["_rank"] <= self.k_neighbors].copy()
+
+        primary_vals = df.set_index(self.id_column)[self.primary]
+        primary_z = self._zscore(primary_vals)
+        topk["primary_z"] = topk["neighbor_id"].map(primary_z)
+        primary_z_pred = topk.groupby(self.id_column)["primary_z"].median()
+
+        # Per-aux z-scored residual: defined only where the aux is measured
+        for aux in self.auxiliaries:
+            aux_vals = df.set_index(self.id_column)[aux]
+            aux_z = self._zscore(aux_vals)
+            aux_z_aligned = result[self.id_column].map(aux_z)
+            primary_pred_aligned = result[self.id_column].map(primary_z_pred)
+            residual = aux_z_aligned - primary_pred_aligned
+            # Mask rows that don't have the aux value
+            aux_present = result[self.id_column].map(aux_vals.notna()).fillna(False).values
+            result[f"residual_{aux}"] = np.where(aux_present, residual.values, np.nan)
+
+        return result
+
+    def _compute_summary(self) -> pd.DataFrame:
+        """Per-aux quantitative summary (counts, pearson, coverage, residuals, verdicts)."""
+        df = self._prox.df
+        primary_mask = df[self.primary].notna()
+        n_primary = int(primary_mask.sum())
+        rows = []
+
+        for aux in self.auxiliaries:
+            aux_mask = df[aux].notna()
+            n_aux = int(aux_mask.sum())
+            n_shared = int((primary_mask & aux_mask).sum())
+            n_aux_only = int((~primary_mask & aux_mask).sum())
+
+            if n_shared >= self.min_n_shared:
+                pearson_r = float(df.loc[primary_mask & aux_mask, [self.primary, aux]].corr().iloc[0, 1])
+            else:
+                pearson_r = float("nan")
+
+            aux_ids = df.loc[aux_mask, self.id_column]
+            cov = self._per_compound.set_index(self.id_column).loc[aux_ids, "tanimoto_to_primary"]
+            cov_mean = float(cov.mean()) if len(cov) else 0.0
+            cov_05 = float((cov >= 0.5).mean()) if len(cov) else 0.0
+            cov_03 = float((cov >= 0.3).mean()) if len(cov) else 0.0
+
+            residuals = self._per_compound[f"residual_{aux}"].dropna()
+            res_abs_mean = float(residuals.abs().mean()) if len(residuals) else float("nan")
+            res_abs_p95 = float(residuals.abs().quantile(0.95)) if len(residuals) else float("nan")
+
+            overlap, _ = _assess_overlap(pearson_r, n_shared, self.min_n_shared)
+            extension, _ = _assess_extension(pearson_r, n_aux_only, n_primary, self.extension_ratio_threshold)
+            recommendation, _ = _combine_assessments(overlap, extension)
+
+            rows.append(
+                {
+                    "aux": aux,
+                    "n_primary": n_primary,
+                    "n_aux": n_aux,
+                    "n_shared": n_shared,
+                    "n_aux_only": n_aux_only,
+                    "pearson_r": pearson_r,
+                    "r_confidence": _confidence_tier(n_shared, self.min_n_shared),
+                    "tanimoto_coverage_mean": cov_mean,
+                    "frac_coverage_ge_05": cov_05,
+                    "frac_coverage_ge_03": cov_03,
+                    "residual_abs_mean": res_abs_mean,
+                    "residual_abs_p95": res_abs_p95,
+                    "overlap": overlap,
+                    "extension": extension,
+                    "recommendation": recommendation,
+                }
+            )
+            r_str = f"r={pearson_r:.3f}" if not np.isnan(pearson_r) else "r=NA"
+            log.info(
+                f" {aux}: shared={n_shared:,} aux_only={n_aux_only:,} {r_str} "
+                f"cov_mean={cov_mean:.2f} -> overlap={overlap}, extension={extension} "
+                f"-> {recommendation}"
+            )
+
+        return pd.DataFrame(rows)
+
+    @staticmethod
+    def _zscore(s: pd.Series) -> pd.Series:
+        """Z-score a Series; returns zeros if std is zero (constant column)."""
+        std = s.std(ddof=0)
+        if not np.isfinite(std) or std == 0:
+            return s - s.mean()
+        return (s - s.mean()) / std
+
+
+# ----------------------------------------------------------------------
+# Verdict helpers (label-only scoring; same thresholds across multi-task code)
+# ----------------------------------------------------------------------
+
+
+def _confidence_tier(n_shared: int, min_n_shared: int) -> str:
+    """Bucket how trustworthy a Pearson r is, given shared-compound count."""
+    if n_shared < min_n_shared:
+        return "unmeasured"
+    if n_shared < 30:
+        return "low"
+    if n_shared < 100:
+        return "moderate"
+    return "high"
+
+
+def _assess_overlap(r: float, n_shared: int, min_n_shared: int) -> tuple[str, str]:
+    """Score the overlap region (compounds with both primary and aux measured).
+
+    Thresholds (label-correlation on shared rows):
+        r in [0.4, 0.95] -> Beneficial : sweet spot — heads predict related but distinct targets
+        r > 0.95         -> Neutral    : redundant; aux head just re-weights primary
+        r < 0.4          -> Harmful    : discordant; gradient conflict / negative-transfer risk
+        n_shared too low -> N/A
+    """
+    if n_shared < min_n_shared:
+        return ("N/A", f"only {n_shared} shared compounds (need >= {min_n_shared} to score)")
+    if 0.4 <= r <= 0.95:
+        return (
+            "Beneficial",
+            f"sweet-spot r={r:.2f} on {n_shared:,} shared compounds — encoder learns richer features",
+        )
+    if r > 0.95:
+        return (
+            "Neutral",
+            f"redundant r={r:.2f} on {n_shared:,} shared compounds — aux head just re-weights primary",
+        )
+    return (
+        "Harmful",
+        f"discordant r={r:.2f} on {n_shared:,} shared compounds — gradient conflict, negative-transfer risk",
+    )
+
+
+def _assess_extension(
+    r: float,
+    n_aux_only: int,
+    n_primary: int,
+    ratio_threshold: float,
+) -> tuple[str, str]:
+    """Score the extension region (aux-only compounds; primary head is masked there)."""
+    if n_aux_only == 0:
+        return ("None", "no aux-only compounds")
+
+    ratio = n_aux_only / n_primary if n_primary > 0 else 0.0
+    has_volume = ratio >= ratio_threshold
+
+    if np.isnan(r):
+        sim_str = "task similarity unknown (no overlap to measure)"
+        if has_volume:
+            return ("Strong", f"{ratio:.1f}x primary of novel chemistry; {sim_str}")
+        return ("Modest", f"{n_aux_only:,} aux-only compounds ({ratio:.1f}x primary); {sim_str}")
+
+    similar = r >= 0.4
+    if has_volume and similar:
+        return ("Strong", f"{ratio:.1f}x primary of novel chemistry x similar task (r={r:.2f})")
+    if has_volume:
+        return ("Modest", f"{ratio:.1f}x primary of novel chemistry but weak similarity (r={r:.2f})")
+    if similar:
+        return ("Modest", f"limited volume ({ratio:.1f}x primary) but similar task (r={r:.2f})")
+    return ("Minimal", f"low volume ({ratio:.1f}x primary) and weak similarity (r={r:.2f})")
+
+
+def _combine_assessments(overlap: str, extension: str) -> tuple[str, str]:
+    """Combine overlap + extension scores into an actionable recommendation."""
+    if overlap == "Harmful":
+        if extension in ("Strong", "Modest"):
+            return ("Risky", "harmful overlap; extension might rescue but cross-task signal could hurt primary")
+        return ("Skip", "negative transfer from overlap with no extension to compensate")
+
+    if overlap == "Beneficial":
+        if extension in ("Strong", "Modest"):
+            return ("Use", "both mechanisms contribute lift")
+        return ("Use", "cross-task signal contributes lift")
+
+    # overlap is Neutral or N/A — extension is the only available mechanism
+    if extension == "Strong":
+        return ("Use", "extension is the primary lift mechanism")
+    if extension == "Modest":
+        return ("Marginal", "limited extension lift; consider domain knowledge")
+    return ("Skip", "no clear lift mechanism")
+
+
+# =============================================================================
+# Testing
+# =============================================================================
+if __name__ == "__main__":
+    from workbench.utils.synthetic_data_generator import SyntheticDataGenerator
+
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+    pd.set_option("display.max_columns", None)
+    pd.set_option("display.width", 1400)
+
+    # Build a synthetic multi-task DataFrame from two AQSol partitions
+    ref_df, query_df = SyntheticDataGenerator().aqsol_alignment_data(overlap="medium", alignment="medium")
+    ref_df = ref_df.assign(id=ref_df["id"].astype(str).radd("ref_"))
+    query_df = query_df.assign(id=query_df["id"].astype(str).radd("qry_"))
+    ref_df = ref_df.rename(columns={"solubility": "primary_sol"})
+    query_df = query_df.rename(columns={"solubility": "aux_sol"})
+    mt_df = pd.concat(
+        [
+            ref_df[["id", "smiles", "primary_sol"]],
+            query_df[["id", "smiles", "aux_sol"]],
+        ],
+        ignore_index=True,
+    )
+
+    mta = MultiTaskAlignment(mt_df, primary="primary_sol", auxiliaries=["aux_sol"], id_column="id")
+
+    print("\n=== summary() ===")
+    print(mta.summary().to_string(index=False))
+
+    print("\n=== results() (first 5 rows) ===")
+    print(mta.results().head().to_string(index=False))
+
+    print("\nMultiTaskAlignment tests completed!")
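For orientation, a hedged sketch of driving the new class from a wide multi-task DataFrame. The column names and the CSV path below are illustrative (not part of the package); the calls mirror the docstrings and the ``__main__`` block above:

    import pandas as pd
    from workbench.algorithms.dataframe import MultiTaskAlignment

    # Hypothetical wide multi-task frame: one row per compound, NaN where a target wasn't measured
    mt_df = pd.read_csv("assay_panel.csv")  # e.g. columns: id, smiles, mdr1_er, caco2_er, logd

    mta = MultiTaskAlignment(mt_df, primary="mdr1_er", auxiliaries=["caco2_er", "logd"], id_column="id")

    print(mta.summary())          # one row per aux: counts, pearson_r, coverage, verdicts, recommendation
    per_compound = mta.results()  # adds UMAP x/y, tanimoto_to_primary, and residual_<aux> columns
    print(mta.neighbors("CHEMBL25", n_neighbors=5))  # illustrative id; shared fingerprint-space lookup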
workbench/api/endpoint.py CHANGED
@@ -82,6 +82,23 @@ class Endpoint(EndpointCore):
             return self._async.auto_inference()
         return super().auto_inference()
 
+    def purge_async_queue(self) -> int:
+        """Cancel queued async invocations by deleting their staged S3 inputs.
+
+        Useful when a long-running client was killed and you want to abandon
+        the orphaned backlog instead of waiting for the fleet to drain it.
+        Only valid on async endpoints — raises on sync endpoints.
+
+        Returns:
+            int: Number of staged input objects deleted.
+        """
+        if self._async is None:
+            raise RuntimeError(
+                f"Endpoint '{self.name}' is not async — purge_async_queue is only "
+                f"meaningful for endpoints with an async invocation queue."
+            )
+        return self._async.purge_async_queue()
+
     def full_inference(self) -> pd.DataFrame:
         """Run inference on the Endpoint using the full data from the model training view
 
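A hedged usage sketch of the new method, assuming the usual ``Endpoint(name)`` constructor; the endpoint name is illustrative. Per the docstring the call raises on sync endpoints and returns the number of staged inputs it deleted:

    from workbench.api.endpoint import Endpoint

    end = Endpoint("my-async-endpoint")  # hypothetical async endpoint name
    n_deleted = end.purge_async_queue()  # abandon the orphaned backlog instead of draining it
    print(f"Purged {n_deleted} staged async inputs")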
workbench/api/inference_cache.py CHANGED
@@ -23,10 +23,7 @@ import pandas as pd
 
 from workbench.api.df_store import DFStore
 from workbench.api.endpoint import Endpoint
-from workbench.utils.inference_cache_utils import (
-    DEFAULT_CHUNK_SIZE,
-    chunked_with_cache_writes,
-)
+from workbench.utils.inference_cache_utils import chunked_with_cache_writes
 
 
 class InferenceCache:
@@ -49,33 +46,13 @@ class InferenceCache:
     ```
     """
 
-    # Rows per cache write. The endpoint is called once per chunk and the
-    # cache is persisted between chunks, so this also bounds the blast radius
-    # of an interrupted/failed write to one chunk worth of work.
-    #
-    # The actual chunk_size on each instance is set in __init__: either the
-    # explicit ``chunk_size`` constructor kwarg, or — for async endpoints with
-    # max_instances in their workbench_meta — derived from fleet capacity
-    # (max_instances × batch_size × 2) so each chunk holds an integer number
-    # of full fleet-waves. This avoids the "10 batches / 8 instances → tail"
-    # utilization loss. Falls back to this class attribute (DEFAULT_CHUNK_SIZE)
-    # for sync endpoints or legacy endpoints without max_instances in meta.
-    chunk_size: int = DEFAULT_CHUNK_SIZE
-
-    # Number of fleet-waves per chunk when auto-deriving chunk_size. With k
-    # batches per worker, relative tail-variance scales as 1/√k, so bumping
-    # k from 2 to 4 cuts tail overhead ~30% without changing batch_size (so
-    # per-batch polling cost is unchanged). Crash-recovery loss is one chunk
-    # = 4 fleet-waves of work, modest for any reasonable batch pipeline.
-    _CHUNK_WAVES = 4
-
     def __init__(
         self,
         endpoint: Endpoint,
         cache_key_column: str = "smiles",
         output_key_column: Optional[str] = None,
         auto_invalidate_cache: bool = False,
-        chunk_size: Optional[int] = None,
+        snapshot: int = 500,
     ):
         """Initialize the InferenceCache.
 
@@ -98,15 +75,16 @@ class InferenceCache:
                 kept regardless of endpoint changes — the manifest is
                 reseeded on first load so subsequent calls have a consistent
                 baseline.
-            chunk_size (Optional[int]): Rows per cache write. If ``None``
-                (default), derived from the endpoint's ``max_instances`` and
-                ``inference_batch_size`` to produce full fleet-waves see
-                :meth:`_derive_chunk_size`. Falls back to
-                ``DEFAULT_CHUNK_SIZE`` when fleet info isn't available.
+            snapshot (int): Rows per cache write (default: 500). The
+                endpoint is called once per snapshot's worth of rows and the
+                cache is persisted between calls, so this also bounds the
+                blast radius of an interrupted run to one snapshot's worth
+                of work.
         """
         self._endpoint = endpoint
         self.cache_key_column = cache_key_column
         self.output_key_column = output_key_column
+        self.snapshot = int(snapshot)
         self.cache_path = f"/workbench/inference_cache/{endpoint.name}"
         self.manifest_path = f"{self.cache_path}__meta"
         self._df_store = DFStore()
@@ -122,49 +100,6 @@ class InferenceCache:
         self._coerce_warned: set[tuple] = set()
         self.log = logging.getLogger("workbench")
 
-        # Resolve chunk_size: explicit override wins; else try fleet-derivation;
-        # else fall through to the class-level DEFAULT_CHUNK_SIZE.
-        if chunk_size is not None:
-            self.chunk_size = int(chunk_size)
-        else:
-            derived = self._derive_chunk_size()
-            if derived is not None:
-                self.chunk_size = derived
-
-    def _derive_chunk_size(self) -> Optional[int]:
-        """Derive chunk_size from the wrapped endpoint's fleet capacity.
-
-        Returns ``capacity × batch_size × _CHUNK_WAVES`` so each chunk
-        holds an integer number of full fleet-waves — preventing the
-        "10 batches / 8 instances → 2-batch tail" utilization loss on async
-        endpoints. ``capacity`` prefers ``effective_max_instances`` when
-        present (set by MetaEndpoint to reflect the largest child fleet,
-        since the meta itself deploys with ``max_instances=1`` regardless
-        of downstream capacity), otherwise falls back to ``max_instances``.
-        Returns ``None`` (→ caller should use DEFAULT_CHUNK_SIZE) when the
-        endpoint's meta has neither, which is the case for sync endpoints
-        and legacy async deploys.
-        """
-        try:
-            meta = self._endpoint.workbench_meta() or {}
-        except Exception:
-            return None
-        capacity = meta.get("effective_max_instances", meta.get("max_instances"))
-        if capacity is None:
-            return None
-        # Mirror AsyncEndpointCore's own resolution: explicit meta override
-        # wins, otherwise the core default.
-        from workbench.core.artifacts.async_endpoint_core import _DEFAULT_BATCH_SIZE
-
-        batch_size = int(meta.get("inference_batch_size", _DEFAULT_BATCH_SIZE))
-        derived = int(capacity) * batch_size * self._CHUNK_WAVES
-        self.log.info(
-            f"InferenceCache[{self._endpoint.name}]: chunk_size={derived} "
-            f"(capacity={capacity} × batch_size={batch_size} × "
-            f"{self._CHUNK_WAVES} waves — full fleet utilization per chunk)"
-        )
-        return derived
-
     def __getattr__(self, name):
         """Delegate any unrecognized attribute access to the wrapped Endpoint."""
         # __getattr__ is only called when normal lookup fails, so this won't
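With the fleet-derived chunk sizing removed, cache-write granularity is now the single ``snapshot`` argument. A hedged sketch of constructing the cache under 0.8.332 (the endpoint name is illustrative; the keyword arguments mirror the ``__init__`` signature above):

    from workbench.api.endpoint import Endpoint
    from workbench.api.inference_cache import InferenceCache

    end = Endpoint("my-async-endpoint")  # hypothetical endpoint name
    cache = InferenceCache(
        end,
        cache_key_column="smiles",
        snapshot=1000,                   # persist the cache every 1000 rows
    )
    # Anything InferenceCache doesn't define delegates to the wrapped Endpoint (see __getattr__ above)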