@pentatonic-ai/ai-agent-sdk 0.10.4 → 0.10.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1 -1
- package/dist/index.js +1 -1
- package/package.json +1 -1
- package/packages/memory-engine-v2/compat/requirements.txt +6 -0
- package/packages/memory-engine-v2/compat/server.py +258 -18
- package/packages/memory-engine-v2/eval/recall_at_k.py +242 -0
- package/packages/memory-engine-v2/eval/retrieval_golden.seed.json +69 -0
- package/packages/memory-engine-v2/extractor-async/Dockerfile +1 -1
- package/packages/memory-engine-v2/extractor-async/extraction_schema.py +246 -0
- package/packages/memory-engine-v2/extractor-async/test_guided_json_parser.py +411 -0
- package/packages/memory-engine-v2/extractor-async/worker.py +417 -31
- package/packages/memory-engine-v2/resolution-queue-design.md +165 -0
- package/packages/memory-engine-v2/scripts/backfill_entity_reconciliation.py +11 -2
- package/packages/memory-engine-v2/scripts/backfill_sparse_vectors.py +369 -0
- package/packages/memory-engine-v2/scripts/bakeoff_guided_vs_kv.py +607 -0
- package/packages/memory-engine-v2/scripts/entity_resolution_v2.py +1041 -0
- package/packages/memory-engine-v2/tests/test_entity_resolution_v2.py +507 -0
- package/packages/memory-engine-v2/tests/test_hybrid_retrieval.py +810 -0
|
@@ -0,0 +1,607 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""bakeoff_guided_vs_kv — compare KV-text vs guided-JSON distiller output.
|
|
3
|
+
|
|
4
|
+
Replay/validation harness for DISTILL_OUTPUT_MODE. NO live LLM calls by
|
|
5
|
+
default: without --endpoint the script only re-parses existing outputs
|
|
6
|
+
(from distillation_traces or a JSONL file) and reports parse-quality
|
|
7
|
+
metrics. Live mode (both output formats against a vLLM box) requires an
|
|
8
|
+
explicit --endpoint.
|
|
9
|
+
|
|
10
|
+
Input sources (exactly one):
|
|
11
|
+
--input PATH JSONL of sample events. Each line is an event dict:
|
|
12
|
+
{"id", "source_kind", "content", "attributes", ...}.
|
|
13
|
+
Optional replay fields: "raw_response" (KV-format
|
|
14
|
+
model output) and/or "raw_response_guided"
|
|
15
|
+
(guided-JSON model output) — when present they are
|
|
16
|
+
parsed offline with the matching parser.
|
|
17
|
+
--from-traces Pull events + raw KV responses from
|
|
18
|
+
distillation_traces (joined to events). user_prompt
|
|
19
|
+
is exactly build_event_block() output, so it doubles
|
|
20
|
+
as the live-replay input; raw_response gives
|
|
21
|
+
KV-format parse stats for free. Connection from
|
|
22
|
+
--pg-dsn or PG_DSN env.
|
|
23
|
+
|
|
24
|
+
Eval-set hygiene (ALWAYS applied — mirrors worker.claim_next_batch's
|
|
25
|
+
claim-time pre-filters so the eval distribution matches what the
|
|
26
|
+
distiller actually processes):
|
|
27
|
+
- source_kind == 'agent' excluded (never distil the agent's own output)
|
|
28
|
+
- attributes.source in {seesa, claude-code-plugin, pip-code-ingest}
|
|
29
|
+
or prefixed openclaw-/triage-/briefing- excluded
|
|
30
|
+
- bytes-garbage excluded (> 10% U+FFFD replacement chars)
|
|
31
|
+
|
|
32
|
+
Per-mode metrics:
|
|
33
|
+
- parse yield: ENT / FCT / REL counts (total + per-event mean)
|
|
34
|
+
- dropped-record rate (kv: record-shaped lines the parser rejected;
|
|
35
|
+
guided: schema violations reported by validate_payload + events
|
|
36
|
+
lost to truncation salvage)
|
|
37
|
+
- FCT field-count violations (kv: FCT| lines without exactly 6 segments)
|
|
38
|
+
- pipe-debris instances (a literal `|` inside any parsed field value)
|
|
39
|
+
- subject-not-in-ENT rate (facts whose subject isn't a declared entity)
|
|
40
|
+
- junk-entity rate (noise_filter.is_noise_entity_name)
|
|
41
|
+
- groundedness (entity-name substring spot-check against the source)
|
|
42
|
+
- latency (live calls; or distillation_traces.llm_chunk_ms offline)
|
|
43
|
+
|
|
44
|
+
Outputs: <out-dir>/report.md + <out-dir>/records.jsonl
|
|
45
|
+
|
|
46
|
+
Examples:
|
|
47
|
+
# offline — replay 500 recent traces, KV parse stats only
|
|
48
|
+
python bakeoff_guided_vs_kv.py --from-traces --limit 500
|
|
49
|
+
|
|
50
|
+
# offline — replay a JSONL that already contains both raw outputs
|
|
51
|
+
python bakeoff_guided_vs_kv.py --input samples.jsonl
|
|
52
|
+
|
|
53
|
+
# LIVE — run both formats against the 7B box (explicit opt-in)
|
|
54
|
+
python bakeoff_guided_vs_kv.py --from-traces --limit 100 \
|
|
55
|
+
--endpoint http://172.31.91.6:8005/v1/chat/completions \
|
|
56
|
+
--model Qwen/Qwen2.5-7B-Instruct
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
from __future__ import annotations
|
|
60
|
+
|
|
61
|
+
import argparse
|
|
62
|
+
import importlib.util
|
|
63
|
+
import json
|
|
64
|
+
import os
|
|
65
|
+
import re
|
|
66
|
+
import statistics
|
|
67
|
+
import sys
|
|
68
|
+
import time
|
|
69
|
+
from pathlib import Path
|
|
70
|
+
from typing import Any
|
|
71
|
+
|
|
72
|
+
# ----------------------------------------------------------------------
|
|
73
|
+
# Load the worker + schema modules from extractor-async (flat scripts
|
|
74
|
+
# dir, no package) the same way the unit tests do.
|
|
75
|
+
# ----------------------------------------------------------------------
|
|
76
|
+
|
|
77
|
+
# Directory holding the flat extractor-async scripts (sibling of this
# script's parent directory). It is also put first on sys.path so that
# plain `import` statements executed inside those scripts can find their
# own neighbours.
_EXTRACTOR_DIR = Path(__file__).resolve().parent.parent / "extractor-async"
sys.path.insert(0, str(_EXTRACTOR_DIR))


def _load(name: str):
    """Execute ``<_EXTRACTOR_DIR>/<name>.py`` and return it as a module object.

    The module is built straight from its file path via importlib rather
    than through a regular ``import``, and it is NOT registered in
    ``sys.modules`` by this function.

    Args:
        name: module name, i.e. the file stem inside ``_EXTRACTOR_DIR``.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: importlib could not produce a spec/loader for the file.
        Any exception raised while executing the module body (including
        ``FileNotFoundError`` from the loader when the file is missing)
        propagates unchanged.
    """
    path = _EXTRACTOR_DIR / f"{name}.py"
    spec = importlib.util.spec_from_file_location(name, path)
    # Explicit raise rather than `assert`: asserts are removed under
    # `python -O`, which would turn this into an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {name!r} from {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Load the three extractor-async modules this harness drives, in the
# same order as before: worker, then extraction_schema, then noise_filter.
worker, extraction_schema, noise_filter = (
    _load(module_name)
    for module_name in ("worker", "extraction_schema", "noise_filter")
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# ----------------------------------------------------------------------
|
|
95
|
+
# Eval-set hygiene — mirrors worker.claim_next_batch's pre-filters
|
|
96
|
+
# (source_kind=agent skip, skip-sources, bytes-garbage ratio) plus the
|
|
97
|
+
# known agent-adjacent sources that pollute eval sets.
|
|
98
|
+
# ----------------------------------------------------------------------
|
|
99
|
+
|
|
100
|
+
# attributes.source values that are always excluded from the eval set.
EXCLUDE_SOURCES_EXACT = {"seesa", "claude-code-plugin", "pip-code-ingest"}
# attributes.source prefixes that are always excluded from the eval set.
EXCLUDE_SOURCE_PREFIXES = ("openclaw-", "triage-", "briefing-")
GARBAGE_CHAR_RATIO = 0.10  # mirrors DISTILL_GARBAGE_CHAR_RATIO default


def is_eval_eligible(event: dict[str, Any]) -> tuple[bool, str]:
    """Decide whether an event belongs in the eval set.

    Applies three filters in order (per the file's header, these are meant
    to mirror worker.claim_next_batch's claim-time pre-filters):

    1. ``source_kind == "agent"`` is rejected.
    2. ``attributes.source`` in ``EXCLUDE_SOURCES_EXACT`` or starting with
       one of ``EXCLUDE_SOURCE_PREFIXES`` is rejected.
    3. content whose share of U+FFFD replacement characters is strictly
       greater than ``GARBAGE_CHAR_RATIO`` is rejected.

    Args:
        event: event mapping; reads ``source_kind``, ``attributes`` (a dict
            or a JSON-encoded string) and ``content``.

    Returns:
        ``(True, "")`` when eligible, else ``(False, reason)``.
    """
    if (event.get("source_kind") or "") == "agent":
        return False, "source_kind=agent"

    attrs = event.get("attributes") or {}
    if isinstance(attrs, str):
        try:
            attrs = json.loads(attrs)
        except json.JSONDecodeError:
            attrs = {}
    # A JSON string can decode to null / a list / a scalar, and callers may
    # hand over other non-mapping values; none of those carry a `source`,
    # so treat them as "no attributes" instead of crashing on `.get`.
    if not isinstance(attrs, dict):
        attrs = {}

    src = str(attrs.get("source") or "")
    if src in EXCLUDE_SOURCES_EXACT or src.startswith(EXCLUDE_SOURCE_PREFIXES):
        return False, f"source={src}"

    content = event.get("content") or ""
    if content and content.count("\ufffd") / len(content) > GARBAGE_CHAR_RATIO:
        return False, "bytes_garbage"
    return True, ""
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# ----------------------------------------------------------------------
|
|
125
|
+
# Per-event metric extraction
|
|
126
|
+
# ----------------------------------------------------------------------
|
|
127
|
+
|
|
128
|
+
_KV_RECORD_RE = re.compile(r"^(?:[-*]\s+)?(ENT|FCT|REL)\|")


def _kv_raw_stats(raw: str) -> dict[str, int]:
    """Count record-shaped lines in raw KV output, before any parsing.

    A line is a candidate when, after stripping surrounding whitespace, it
    starts with ``ENT|``, ``FCT|`` or ``REL|`` (optionally behind a ``-`` or
    ``*`` bullet). A candidate FCT line is a field violation when splitting
    it on ``|`` does not give exactly 6 segments.

    Returns:
        ``{"candidates": int, "fct_field_violations": int}``.
    """
    record_lines = 0
    bad_fct_lines = 0
    for stripped in map(str.strip, raw.splitlines()):
        hit = _KV_RECORD_RE.match(stripped)
        if hit is None:
            continue
        record_lines += 1
        # Number of `|`-separated segments is one more than the pipe count.
        segment_count = stripped.count("|") + 1
        if hit.group(1) == "FCT" and segment_count != 6:
            bad_fct_lines += 1
    return {"candidates": record_lines, "fct_field_violations": bad_fct_lines}
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _iter_parsed_values(rec: dict[str, Any]):
    """Yield every textual field value of one parsed extraction record.

    Order: for each entity its name followed by its aliases; then for each
    fact its subject / predicate / object / statement; then for each
    relationship its from / to / type. Missing or falsy scalar fields are
    yielded as ``""``; aliases are yielded as stored.
    """
    for entity in rec.get("entities", []):
        yield entity.get("name") or ""
        yield from entity.get("aliases") or []

    fact_fields = ("subject", "predicate", "object", "statement")
    for fact in rec.get("facts", []):
        yield from (fact.get(field) or "" for field in fact_fields)

    rel_fields = ("from", "to", "type")
    for relationship in rec.get("relationships", []):
        yield from (relationship.get(field) or "" for field in rel_fields)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def score_event(
    rec: dict[str, Any], source_text: str
) -> dict[str, Any]:
    """Compute per-event quality counters for one parsed extraction dict.

    Args:
        rec: parsed record with optional ``entities`` / ``facts`` /
            ``relationships`` lists.
        source_text: text the extraction was produced from; used only for
            the groundedness substring check.

    Returns:
        Dict with keys ``ent``, ``fct``, ``rel``, ``junk_entities``,
        ``ungrounded_entities``, ``subject_not_in_ent`` and ``pipe_debris``.
    """
    entities = rec.get("entities", [])
    fact_list = rec.get("facts", [])
    relationships = rec.get("relationships", [])
    declared_names = {entity.get("name", "") for entity in entities}
    haystack = (source_text or "").lower()

    junk_entities = [
        entity
        for entity in entities
        if noise_filter.is_noise_entity_name(
            entity.get("type", ""), entity.get("name", "")
        )
    ]

    # Groundedness spot-check: a named entity counts as grounded when its
    # surface form occurs verbatim (case-insensitively) in the source.
    # Plain substring matching is intentionally rough: short names get
    # matched too easily and inflected forms get missed. But it needs no
    # model and treats both output modes the same way.
    ungrounded_entities = [
        entity
        for entity in entities
        if entity.get("name") and entity.get("name", "").lower() not in haystack
    ]

    # Facts whose (non-empty) subject was never declared as an entity.
    orphan_subject_facts = [
        fact
        for fact in fact_list
        if fact.get("subject") and fact.get("subject") not in declared_names
    ]

    # Any parsed field value that still contains a literal pipe.
    values_with_pipe = [
        value for value in _iter_parsed_values(rec) if "|" in str(value)
    ]

    return {
        "ent": len(entities),
        "fct": len(fact_list),
        "rel": len(relationships),
        "junk_entities": len(junk_entities),
        "ungrounded_entities": len(ungrounded_entities),
        "subject_not_in_ent": len(orphan_subject_facts),
        "pipe_debris": len(values_with_pipe),
    }
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# ----------------------------------------------------------------------
|
|
198
|
+
# Aggregation
|
|
199
|
+
# ----------------------------------------------------------------------
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class ModeStats:
    """Running totals of per-event scores for one output mode."""

    def __init__(self, mode: str) -> None:
        """Start an empty accumulator labelled with `mode`."""
        self.mode = mode
        self.events = 0
        # Every counter this accumulator tracks, all starting at zero.
        # Insertion order matters: summary() emits them in this order.
        self.totals: dict[str, int] = dict.fromkeys(
            (
                "ent", "fct", "rel",
                "junk_entities", "ungrounded_entities",
                "subject_not_in_ent", "pipe_debris",
                "dropped_records", "fct_field_violations",
                "schema_violations", "events_lost_to_truncation",
            ),
            0,
        )
        self.latencies_ms: list[float] = []

    def add(self, scores: dict[str, Any]) -> None:
        """Count one event and fold its scores into the totals.

        Keys of `scores` that are not tracked counters are ignored.
        """
        self.events += 1
        for key, amount in scores.items():
            if key not in self.totals:
                continue
            self.totals[key] += amount

    def rate(self, key: str, denom_key: str = "ent") -> float:
        """Return totals[key] / totals[denom_key], or 0.0 for a zero denominator."""
        denominator = self.totals[denom_key]
        if not denominator:
            return 0.0
        return self.totals[key] / denominator

    def summary(self) -> dict[str, Any]:
        """Return mode, event count, all totals, derived rates and latency stats.

        Rates are rounded to 4 decimals. Latency p50/mean (rounded to 1
        decimal) are only present when at least one latency was recorded.
        """
        report: dict[str, Any] = {"mode": self.mode, "events": self.events}
        report.update(self.totals)

        derived_rates = (
            ("junk_entity_rate", "junk_entities", "ent"),
            ("ungrounded_entity_rate", "ungrounded_entities", "ent"),
            ("subject_not_in_ent_rate", "subject_not_in_ent", "fct"),
        )
        for rate_name, numerator_key, denominator_key in derived_rates:
            report[rate_name] = round(self.rate(numerator_key, denominator_key), 4)

        if self.latencies_ms:
            samples = self.latencies_ms
            report["latency_ms_p50"] = round(statistics.median(samples), 1)
            report["latency_ms_mean"] = round(statistics.fmean(samples), 1)
        return report
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# ----------------------------------------------------------------------
|
|
240
|
+
# Offline replay
|
|
241
|
+
# ----------------------------------------------------------------------
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def replay_kv(raw: str, source_text: str, stats: ModeStats) -> dict[str, Any]:
    """Re-parse one stored KV-format response offline and score it.

    The response is parsed as a single-event batch (``expected_n=1``) and
    only the first parsed record is scored. On top of the ``score_event``
    metrics, two raw-text counters are added: ``dropped_records`` (candidate
    record lines minus parsed ENT+FCT+REL records, floored at 0) and
    ``fct_field_violations``. The scores are folded into `stats` and
    returned.
    """
    first_event = worker._parse_kv_records(raw, expected_n=1)[0]
    scores = score_event(first_event, source_text)
    line_stats = _kv_raw_stats(raw)
    kept_records = sum(scores[kind] for kind in ("ent", "fct", "rel"))
    scores.update(
        dropped_records=max(0, line_stats["candidates"] - kept_records),
        fct_field_violations=line_stats["fct_field_violations"],
    )
    stats.add(scores)
    return scores
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def replay_guided(raw: str, source_text: str, stats: ModeStats) -> dict[str, Any]:
    """Re-parse one stored guided-JSON response offline and score it.

    The response is parsed as a single-event batch (``expected_n=1``) and
    only the first parsed record is scored. Two guided-mode counters are
    added: ``events_lost_to_truncation`` is 1 when
    ``worker._load_guided_payload`` returns None or the payload has no
    non-empty ``events`` list (else 0); ``schema_violations`` is the number
    of items ``extraction_schema.validate_payload`` returns (0 when there is
    no payload). The scores are folded into `stats` and returned.
    """
    first_event = worker._parse_guided_json(raw, expected_n=1)[0]
    scores = score_event(first_event, source_text)

    payload = worker._load_guided_payload(raw)
    if payload is None:
        lost, violations = 1, 0
    else:
        events = payload.get("events")
        has_events = isinstance(events, list) and bool(events)
        lost = 0 if has_events else 1
        violations = len(extraction_schema.validate_payload(payload))

    scores["events_lost_to_truncation"] = lost
    scores["schema_violations"] = violations
    stats.add(scores)
    return scores
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
# ----------------------------------------------------------------------
|
|
272
|
+
# Live mode (explicit --endpoint only)
|
|
273
|
+
# ----------------------------------------------------------------------
|
|
274
|
+
|
|
275
|
+
# Matches an "[event N]" header at the start of any line (MULTILINE).
# run_live uses it to renumber the first header of a replayed prompt.
_EVENT_HEADER_LINE_RE = re.compile(r"^\[event \d+\]", re.MULTILINE)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _build_live_body(
    mode: str, user_prompt: str, n: int, model: str, param_style: str
) -> dict[str, Any]:
    """Build the chat-completions request body for one live call.

    Args:
        mode: ``"kv"`` selects the KV system prompt and token budget; any
            other value selects the guided-JSON prompt, budget and schema.
        user_prompt: the user message (joined event blocks).
        n: number of events in the call; scales ``max_tokens``.
        model: model name sent in the body.
        param_style: guided mode only. ``"guided_json"`` attaches the schema
            under the ``guided_json`` key; anything else attaches it as an
            OpenAI-style ``response_format`` json_schema.

    Returns:
        The JSON-serialisable request body.
    """
    kv_mode = mode == "kv"
    if kv_mode:
        system_prompt = worker.BATCH_SYSTEM_PROMPT
        tokens_per_event = worker.LLM_MAX_TOKENS_PER_EVENT
    else:
        system_prompt = worker.GUIDED_JSON_SYSTEM_PROMPT
        tokens_per_event = worker.LLM_MAX_TOKENS_PER_EVENT_JSON

    body: dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.0,
        "max_tokens": tokens_per_event * n,
    }
    if kv_mode:
        # KV output is free text: no schema is attached.
        return body

    schema = extraction_schema.EXTRACTION_SCHEMA
    if param_style == "guided_json":
        body["guided_json"] = schema
        return body

    body["response_format"] = {
        "type": "json_schema",
        "json_schema": {
            "name": "memory_extraction",
            "strict": True,
            "schema": schema,
        },
    }
    return body
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _completion_text(data: Any) -> str:
    """Return ``choices[0].message.content`` from a chat-completions reply.

    Any missing, null or wrongly-typed level (no choices, ``"message":
    null``, non-string content, ...) yields ``""`` instead of raising.
    """
    if not isinstance(data, dict):
        return ""
    choices = data.get("choices")
    first = choices[0] if isinstance(choices, list) and choices else None
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    return content if isinstance(content, str) else ""


def run_live(
    samples: list[dict[str, Any]],
    endpoint: str,
    model: str,
    api_key: str,
    param_style: str,
    batch_size: int,
    timeout: float,
    kv_stats: ModeStats,
    guided_stats: ModeStats,
    records: list[dict[str, Any]],
) -> None:
    """Run every sample through `endpoint` in both output modes and score it.

    Samples are sent in chunks of `batch_size`. Each chunk is posted twice,
    first in ``kv`` mode and then in ``guided_json`` mode, with the same
    user prompt. Each reply is parsed with the matching worker parser and
    every (sample, parsed record) pair is scored, folded into the mode's
    ModeStats and appended to `records`. One wall-clock latency per call is
    appended to the mode's ``latencies_ms``.

    Chunk-level counters (kv: ``dropped_records`` / ``fct_field_violations``;
    guided: ``events_lost_to_truncation`` / ``schema_violations``) cannot be
    attributed to a single event, so they are attached to the first event
    of the chunk only.

    Args:
        samples: eligible events; a truthy ``user_prompt`` is replayed as the
            event block, otherwise ``worker.build_event_block`` builds one.
        endpoint: chat-completions URL.
        model: model name for the request body.
        api_key: when non-empty, sent as both ``X-API-Key`` and
            ``Authorization: Bearer``.
        param_style: how the schema is attached in guided calls.
        batch_size: events per call; must be >= 1.
        timeout: httpx client timeout.
        kv_stats / guided_stats: accumulators, mutated in place.
        records: per-event output rows, appended in place.

    Raises:
        ValueError: `batch_size` is smaller than 1.
        Errors from the HTTP call, including the one ``raise_for_status``
        raises for a non-success response, propagate and abort the run.
    """
    import httpx  # already a worker dep

    # range() rejects a zero step with an unrelated message and silently
    # produces nothing for a negative one; fail early and clearly instead.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["X-API-Key"] = api_key
        headers["Authorization"] = f"Bearer {api_key}"

    with httpx.Client(timeout=timeout) as client:
        for start in range(0, len(samples), batch_size):
            chunk = samples[start : start + batch_size]
            n = len(chunk)
            blocks = []
            for i, s in enumerate(chunk):
                if s.get("user_prompt"):
                    # Trace replay: user_prompt IS build_event_block
                    # output; renumber the header to the chunk-local
                    # index so parsing reattaches correctly.
                    block = _EVENT_HEADER_LINE_RE.sub(
                        f"[event {i}]", s["user_prompt"], count=1
                    )
                else:
                    block = worker.build_event_block(i, s)
                blocks.append(block)
            user_prompt = "\n\n---\n\n".join(blocks)

            for mode, stats in (("kv", kv_stats), ("guided_json", guided_stats)):
                body = _build_live_body(mode, user_prompt, n, model, param_style)
                t0 = time.perf_counter()
                r = client.post(endpoint, json=body, headers=headers)
                r.raise_for_status()
                ms = (time.perf_counter() - t0) * 1000
                stats.latencies_ms.append(ms)
                text = _completion_text(r.json())

                # Parse the reply and work out the chunk-level counters once.
                chunk_extras: dict[str, int] = {}
                if mode == "kv":
                    parsed = worker._parse_kv_records(text, n)
                    raw_stats = _kv_raw_stats(text)
                    parsed_count = sum(
                        len(p.get("entities", []))
                        + len(p.get("facts", []))
                        + len(p.get("relationships", []))
                        for p in parsed
                    )
                    chunk_extras["dropped_records"] = max(
                        0, raw_stats["candidates"] - parsed_count
                    )
                    chunk_extras["fct_field_violations"] = raw_stats[
                        "fct_field_violations"
                    ]
                else:
                    parsed = worker._parse_guided_json(text, n)
                    payload = worker._load_guided_payload(text)
                    if payload is None:
                        # Nothing loadable: the whole chunk counts as lost.
                        chunk_extras["events_lost_to_truncation"] = n
                    else:
                        events = payload.get("events")
                        got = len(events) if isinstance(events, list) else 0
                        chunk_extras["events_lost_to_truncation"] = max(0, n - got)
                        chunk_extras["schema_violations"] = len(
                            extraction_schema.validate_payload(payload)
                        )

                for i, (s, rec) in enumerate(zip(chunk, parsed)):
                    src = s.get("user_prompt") or s.get("content") or ""
                    scores = score_event(rec, src)
                    if i == 0:
                        scores.update(chunk_extras)
                    stats.add(scores)
                    records.append(
                        {
                            "event_id": s.get("id") or s.get("event_id"),
                            "mode": mode,
                            "live": True,
                            "latency_ms_chunk": round(ms, 1),
                            **scores,
                        }
                    )
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
# ----------------------------------------------------------------------
|
|
416
|
+
# Input loading
|
|
417
|
+
# ----------------------------------------------------------------------
|
|
418
|
+
|
|
419
|
+
TRACES_SQL = """
SELECT t.event_id, t.user_prompt, t.raw_response, t.llm_chunk_ms,
t.system_prompt_hash,
e.source_kind, e.content, e.attributes
FROM distillation_traces t
JOIN events e ON e.id = t.event_id
ORDER BY t.created_at DESC
LIMIT %s
"""


def load_from_traces(pg_dsn: str, limit: int) -> tuple[list[dict[str, Any]], dict[str, int]]:
    """Read recent distillation traces from Postgres and keep the eligible ones.

    Runs ``TRACES_SQL`` (newest traces first) with a row cap of
    ``limit * 5``, filters each row through ``is_eval_eligible`` and stops
    as soon as `limit` rows have been kept.

    Args:
        pg_dsn: Postgres connection string passed to ``psycopg.connect``.
        limit: maximum number of samples to return.

    Returns:
        ``(samples, excluded)``: the kept rows as plain dicts, and a
        per-reason count of the rows the hygiene filter rejected.
    """
    import psycopg
    import psycopg.rows

    kept: list[dict[str, Any]] = []
    skip_counts: dict[str, int] = {}
    # Ask for more rows than needed so that, after hygiene filtering,
    # roughly `limit` rows are still left.
    fetch_cap = limit * 5
    with psycopg.connect(pg_dsn) as conn:
        with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
            cur.execute(TRACES_SQL, (fetch_cap,))
            for row in cur:
                eligible, why = is_eval_eligible(row)
                if eligible:
                    kept.append(dict(row))
                    if len(kept) >= limit:
                        break
                else:
                    skip_counts[why] = skip_counts.get(why, 0) + 1
    return kept, skip_counts
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def load_from_jsonl(path: Path, limit: int) -> tuple[list[dict[str, Any]], dict[str, int]]:
    """Read sample events from a JSONL file and keep the eligible ones.

    Blank lines are skipped. Each remaining line is decoded as JSON and
    filtered through ``is_eval_eligible``; reading stops once `limit`
    samples have been kept. A `limit` of 0 or less keeps nothing.

    Args:
        path: JSONL file, one event object per line, read as UTF-8.
        limit: maximum number of samples to return.

    Returns:
        ``(samples, excluded)``: the kept event dicts, and a per-reason
        count of rejected lines. Lines holding valid JSON that is not an
        object are counted under ``"not_an_object"``.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """
    excluded: dict[str, int] = {}
    samples: list[dict[str, Any]] = []
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            # Checked before appending so that limit <= 0 yields no samples.
            if len(samples) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            ev = json.loads(line)
            if not isinstance(ev, dict):
                # A bare list/number/string is not an event; count it
                # rather than crashing inside the eligibility check.
                excluded["not_an_object"] = excluded.get("not_an_object", 0) + 1
                continue
            ok, reason = is_eval_eligible(ev)
            if not ok:
                excluded[reason] = excluded.get(reason, 0) + 1
                continue
            samples.append(ev)
    return samples, excluded
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
# ----------------------------------------------------------------------
|
|
471
|
+
# Report
|
|
472
|
+
# ----------------------------------------------------------------------
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def write_report(
    out_dir: Path,
    kv_stats: "ModeStats",
    guided_stats: "ModeStats",
    excluded: dict[str, int],
    records: list[dict[str, Any]],
    live: bool,
) -> None:
    """Write ``records.jsonl`` and ``report.md`` into `out_dir` and print the report.

    Args:
        out_dir: output directory; created (with parents) when missing.
        kv_stats / guided_stats: per-mode accumulators. A mode gets a column
            in the report table only when it has at least one event.
        excluded: per-reason hygiene exclusion counts, shown in the header.
        records: per-event rows, written one JSON object per line.
        live: selects the "mode:" header line (live vs offline replay).

    Both files are written as UTF-8 regardless of the platform locale.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: rows are dumped with ensure_ascii=False and the report
    # contains non-ASCII dashes, which a non-UTF-8 locale may fail to encode.
    with (out_dir / "records.jsonl").open("w", encoding="utf-8") as fh:
        for rec in records:
            fh.write(json.dumps(rec, ensure_ascii=False) + "\n")

    rows = [s.summary() for s in (kv_stats, guided_stats) if s.events]
    keys = [
        "events", "ent", "fct", "rel",
        "dropped_records", "fct_field_violations",
        "schema_violations", "events_lost_to_truncation",
        "pipe_debris", "junk_entity_rate", "ungrounded_entity_rate",
        "subject_not_in_ent_rate", "latency_ms_p50", "latency_ms_mean",
    ]

    def table_row(cells: list[str]) -> str:
        # Joining whole cell lists keeps header, separator and body rows at
        # the same column count even when no mode has any events.
        return "| " + " | ".join(cells) + " |"

    lines = [
        "# Guided-JSON vs KV distiller bake-off",
        "",
        f"- generated: {time.strftime('%Y-%m-%d %H:%M:%S %Z')}",
        f"- mode: {'LIVE (called --endpoint)' if live else 'offline replay (no LLM calls)'}",
        f"- eval-set hygiene exclusions: {json.dumps(excluded) if excluded else 'none'}",
        "",
        table_row(["metric", *(r["mode"] for r in rows)]),
        "|---|" + "---|" * len(rows),
    ]
    for k in keys:
        # A metric a mode did not produce (e.g. latency offline) shows as a dash.
        lines.append(table_row([k, *(str(r.get(k, "—")) for r in rows)]))
    lines += [
        "",
        "Notes:",
        "- `dropped_records` / `fct_field_violations` / `pipe_debris` are",
        "  KV-format failure classes; guided decoding makes them",
        "  structurally impossible (`schema_violations` /",
        "  `events_lost_to_truncation` are the guided-mode analogues).",
        "- `ungrounded_entity_rate` is a crude substring spot-check of",
        "  entity names against the source text — compare across modes,",
        "  don't read it as an absolute hallucination rate.",
    ]
    (out_dir / "report.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
    print("\n".join(lines))
    print(f"\nwrote {out_dir / 'report.md'} and {out_dir / 'records.jsonl'}")
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
# ----------------------------------------------------------------------
|
|
526
|
+
# Main
|
|
527
|
+
# ----------------------------------------------------------------------
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def main() -> int:
    """Parse the CLI, load samples, replay and/or run live, write the report.

    Flow:
      1. Load samples from ``--from-traces`` or ``--input`` (hygiene
         filtering happens inside the loaders).
      2. Offline-replay whatever raw outputs the samples carry
         (``raw_response`` as KV, ``raw_response_guided`` as guided JSON).
      3. With ``--endpoint``, run both modes live.
      4. Write ``report.md`` and ``records.jsonl`` to ``--out-dir``.

    In live mode the report table aggregates the live calls only. Offline
    replay rows are still written to ``records.jsonl`` (``"live": false``)
    but are kept out of the table: otherwise inputs that carry
    ``raw_response`` (every ``--from-traces`` row) would be counted twice in
    the kv column and compared against a guided column counted once.

    Returns:
        Process exit code: 0 on success, 1 when no sample survives filtering.
    """
    # __doc__ is None under `python -OO`; fall back to an empty description.
    ap = argparse.ArgumentParser(description=(__doc__ or "").split("\n", 1)[0])
    src = ap.add_mutually_exclusive_group(required=True)
    src.add_argument("--input", type=Path, help="JSONL of sample events")
    src.add_argument(
        "--from-traces", action="store_true",
        help="pull events + raw KV responses from distillation_traces",
    )
    ap.add_argument("--pg-dsn", default=os.environ.get("PG_DSN", ""),
                    help="Postgres DSN for --from-traces (default: PG_DSN env)")
    ap.add_argument("--limit", type=int, default=200)
    ap.add_argument("--endpoint", default="",
                    help="chat-completions URL. REQUIRED for live mode; "
                         "omit for offline replay (no LLM calls).")
    ap.add_argument("--model", default=os.environ.get("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct"))
    ap.add_argument("--api-key", default=os.environ.get("LLM_API_KEY", ""))
    ap.add_argument("--param-style", choices=["response_format", "guided_json"],
                    default="response_format",
                    help="how the schema is attached in guided live calls")
    ap.add_argument("--batch-size", type=int, default=15,
                    help="events per live LLM call (mirrors EVENTS_PER_LLM_CALL)")
    ap.add_argument("--timeout", type=float, default=180.0)
    ap.add_argument("--out-dir", type=Path, default=Path("bakeoff-out"))
    args = ap.parse_args()

    # Reject sizes that would otherwise surface later as an SQL error, a
    # range() error, or a silently empty run.
    if args.limit < 1:
        ap.error("--limit must be >= 1")
    if args.endpoint and args.batch_size < 1:
        ap.error("--batch-size must be >= 1")

    if args.from_traces:
        if not args.pg_dsn:
            ap.error("--from-traces needs --pg-dsn or PG_DSN env")
        samples, excluded = load_from_traces(args.pg_dsn, args.limit)
    else:
        samples, excluded = load_from_jsonl(args.input, args.limit)

    if not samples:
        print("no eligible samples after hygiene filtering", file=sys.stderr)
        return 1
    print(f"{len(samples)} eligible samples "
          f"({sum(excluded.values())} excluded by hygiene filters)")

    live = bool(args.endpoint)
    kv_stats = ModeStats("kv")
    guided_stats = ModeStats("guided_json")
    records: list[dict[str, Any]] = []

    # Offline replay feeds the reported stats only when there is no live
    # run; in live mode it goes to scratch accumulators (see docstring).
    offline_kv = ModeStats("kv") if live else kv_stats
    offline_guided = ModeStats("guided_json") if live else guided_stats

    # Offline replay over whatever raw outputs the input carries.
    for s in samples:
        source_text = s.get("user_prompt") or s.get("content") or ""
        raw_kv = s.get("raw_response")
        if raw_kv:
            scores = replay_kv(raw_kv, source_text, offline_kv)
            if s.get("llm_chunk_ms"):
                offline_kv.latencies_ms.append(float(s["llm_chunk_ms"]))
            records.append({"event_id": s.get("event_id") or s.get("id"),
                            "mode": "kv", "live": False, **scores})
        raw_guided = s.get("raw_response_guided")
        if raw_guided:
            scores = replay_guided(raw_guided, source_text, offline_guided)
            records.append({"event_id": s.get("event_id") or s.get("id"),
                            "mode": "guided_json", "live": False, **scores})

    if live:
        run_live(
            samples, args.endpoint, args.model, args.api_key,
            args.param_style, args.batch_size, args.timeout,
            kv_stats, guided_stats, records,
        )
    elif not guided_stats.events:
        print(
            "note: no guided-JSON outputs in the input and no --endpoint — "
            "guided column will be empty. Pass --endpoint for a live "
            "side-by-side run.",
        )

    write_report(args.out_dir, kv_stats, guided_stats, excluded, records,
                 live=live)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|