invarlock 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invarlock/eval/data.py CHANGED
@@ -950,6 +950,9 @@ class WikiText2Provider:
950
950
  if override_size is not None:
951
951
  batch_size = max(1, min(override_size, len(candidates)))
952
952
 
953
+ config = getattr(model, "config", None)
954
+ scorer_vocab_size = getattr(config, "vocab_size", None)
955
+
953
956
  input_batch: list[list[int]] = []
954
957
  attention_batch: list[list[int]] = []
955
958
  candidate_batch: list[dict[str, Any]] = []
@@ -970,6 +973,14 @@ class WikiText2Provider:
970
973
  attention_batch, dtype=torch.long, device=device
971
974
  )
972
975
 
976
+ # Guard against out-of-range token IDs when scoring with GPT-2.
977
+ # Some model tokenizers emit IDs beyond GPT-2 vocab, which can
978
+ # trigger device-side asserts in embedding/gather kernels.
979
+ if scorer_vocab_size and scorer_vocab_size > 0:
980
+ input_tensor = input_tensor.clamp(
981
+ min=0, max=scorer_vocab_size - 1
982
+ )
983
+
973
984
  outputs = model(input_tensor, attention_mask=attention_tensor)
974
985
  shift_logits = outputs.logits[:, :-1, :].contiguous()
975
986
  shift_labels = input_tensor[:, 1:].contiguous()
@@ -214,9 +214,15 @@ class _PPLCausal(PrimaryMetric):
214
214
  ) -> dict[str, Any]:
215
215
  subj = self._coerce_contrib_array(subject)
216
216
  base = self._coerce_contrib_array(baseline)
217
- # Compute simple (unweighted) per-example arrays in log space; weights ignored for bootstrap here
217
+ # Compute per-example arrays in log space; use weights for paired bootstrap
218
218
  subj_vals = [v for (v, _w) in subj]
219
219
  base_vals = [v for (v, _w) in base]
220
+ pair_weights = []
221
+ for (_sv, sw), (_bv, bw) in zip(subj, base, strict=False):
222
+ weight = bw if math.isfinite(bw) and bw > 0 else sw
223
+ if not math.isfinite(weight) or weight <= 0:
224
+ weight = 1.0
225
+ pair_weights.append(float(weight))
220
226
 
221
227
  # Points in display space
222
228
  def _point(
@@ -249,15 +255,24 @@ class _PPLCausal(PrimaryMetric):
249
255
  dlog_lo, dlog_hi = compute_paired_delta_log_ci(
250
256
  subj_vals,
251
257
  base_vals,
258
+ weights=pair_weights,
252
259
  method="bca",
253
260
  replicates=reps_eff,
254
261
  alpha=alpha,
255
262
  seed=seed_eff,
256
263
  )
257
- delta_log = float(
258
- sum((s - b) for s, b in zip(subj_vals, base_vals, strict=False))
259
- / max(1, min(len(subj_vals), len(base_vals)))
260
- )
264
+ if pair_weights and len(pair_weights) >= min(len(subj_vals), len(base_vals)):
265
+ sw = 0.0
266
+ swx = 0.0
267
+ for s, b, w in zip(subj_vals, base_vals, pair_weights, strict=False):
268
+ sw += w
269
+ swx += w * (s - b)
270
+ delta_log = float(swx / sw) if sw > 0 else float("nan")
271
+ else:
272
+ delta_log = float(
273
+ sum((s - b) for s, b in zip(subj_vals, base_vals, strict=False))
274
+ / max(1, min(len(subj_vals), len(base_vals)))
275
+ )
261
276
  ratio = self.display_transform(delta_log)
262
277
  return {
263
278
  "kind": self.kind,