mirrorneuron-membrane-python-sdk 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirrorneuron_membrane_python_sdk-1.0.1.dist-info/METADATA +11 -0
- mirrorneuron_membrane_python_sdk-1.0.1.dist-info/RECORD +8 -0
- mirrorneuron_membrane_python_sdk-1.0.1.dist-info/WHEEL +4 -0
- mirrorneuron_membrane_python_sdk-1.0.1.dist-info/entry_points.txt +2 -0
- mn_context_engine_sdk/__init__.py +19 -0
- mn_context_engine_sdk/benchmarks/__init__.py +1 -0
- mn_context_engine_sdk/benchmarks/context_compression_accuracy_benchmark.py +911 -0
- mn_context_engine_sdk/working_memory.py +258 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: mirrorneuron-membrane-python-sdk
|
|
3
|
+
Version: 1.0.1
|
|
4
|
+
Summary: Python SDK shell for consuming the Rust Membrane context engine.
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
License: Proprietary
|
|
7
|
+
Requires-Dist: pydantic>=2.0,<3.0
|
|
8
|
+
Provides-Extra: dev
|
|
9
|
+
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
10
|
+
Provides-Extra: compression
|
|
11
|
+
Requires-Dist: llmlingua>=0.2.2; extra == "compression"
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
mn_context_engine_sdk/__init__.py,sha256=b_cv-mjvkN12vXM6F4FqFy3w-jtJ5oxmpppL1o1WA0o,301
|
|
2
|
+
mn_context_engine_sdk/benchmarks/__init__.py,sha256=vnIljW7ZUJMday5EqJNyyn2LRnIWj4jbcufLEaCj7sQ,58
|
|
3
|
+
mn_context_engine_sdk/benchmarks/context_compression_accuracy_benchmark.py,sha256=3Gj6VGJHf3_zQIHWhLfJtR8ynJuvsnoOJcGrAqc5Avo,34882
|
|
4
|
+
mn_context_engine_sdk/working_memory.py,sha256=uPSIJ3OUrqe8j1YLIjqy3RE5IFgZl8_9r8pJw6y_Cyg,8934
|
|
5
|
+
mirrorneuron_membrane_python_sdk-1.0.1.dist-info/METADATA,sha256=Dgo_iYqfiGagYZgd7AXe7IFs52p7i2EgP77xlUl4WMw,376
|
|
6
|
+
mirrorneuron_membrane_python_sdk-1.0.1.dist-info/WHEEL,sha256=s7mhKzNlJuzZwFeKTmqJmaumsjk0UzyKYJqivS04jas,87
|
|
7
|
+
mirrorneuron_membrane_python_sdk-1.0.1.dist-info/entry_points.txt,sha256=yTQ3U_D2lUs3VnQOtw87yd8YmwFjQYk0SVGPeLGOsgg,130
|
|
8
|
+
mirrorneuron_membrane_python_sdk-1.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .working_memory import (
|
|
2
|
+
MemoryEdge,
|
|
3
|
+
MemoryItem,
|
|
4
|
+
MNMemoryEdge,
|
|
5
|
+
MNMemoryItem,
|
|
6
|
+
MNWorkingMemory,
|
|
7
|
+
WorkingMemory,
|
|
8
|
+
utcnow,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Public API of the package: the MN-prefixed and unprefixed memory types
# plus the utcnow helper, all imported from .working_memory above.
__all__ = [
    "MNMemoryEdge", "MNMemoryItem", "MNWorkingMemory",
    "MemoryEdge", "MemoryItem", "WorkingMemory",
    "utcnow",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Benchmark entry points for the Membrane Python SDK."""
|
|
@@ -0,0 +1,911 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Compare raw vs compressed context on token use and answer accuracy.
|
|
3
|
+
|
|
4
|
+
The benchmark uses an Ollama-compatible /api/chat endpoint for both the
|
|
5
|
+
answering model and the judging model. It intentionally does not require the
|
|
6
|
+
Rust gRPC service; this keeps it useful for CI smoke checks, remote model
|
|
7
|
+
testing, and quick local experiments.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import json
|
|
14
|
+
import statistics
|
|
15
|
+
import time
|
|
16
|
+
import urllib.error
|
|
17
|
+
import urllib.request
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from datetime import datetime, timezone
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Ollama-compatible chat URL; presumably the default for the CLI's endpoint
# option (argument parsing is not visible here).
# NOTE(review): this is a private LAN address that will not resolve outside
# the author's network -- confirm whether a localhost default is intended.
DEFAULT_ENDPOINT = "http://192.168.4.173:11434/api/chat"

# Model tag presumably used as the default worker model -- verify against the CLI.
DEFAULT_MODEL = "nemotron3:33b"

# Rough characters-per-token ratio used by estimate_tokens().
CHARS_PER_TOKEN = 4
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class BenchmarkCase:
    """A single benchmark scenario with raw and compressed variants of its context.

    The dataclass is frozen, so instances are immutable after construction.
    """

    # Identifier copied onto every result row and used to pair variants.
    case_id: str
    # Human-readable case name copied onto result rows.
    title: str
    # Instruction handed to the answering model.
    task: str
    # Full context including distractor notes.
    raw_context: str
    # Compact context packet carrying the same key facts.
    compressed_context: str
    # Ground-truth answer shown to the judge model.
    expected_answer: str
    # Terms a correct answer is expected to contain.
    required_terms: tuple[str, ...]
    # Claims a correct answer must not contain; none by default.
    must_not_claim: tuple[str, ...] = ()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def estimate_tokens(text: str) -> int:
    """Roughly estimate how many tokens *text* occupies.

    Computes ``ceil(len(text) / CHARS_PER_TOKEN)`` and never returns less
    than 1, so even an empty string counts as one token.
    """
    whole, remainder = divmod(len(text), CHARS_PER_TOKEN)
    count = whole + (1 if remainder else 0)
    return count if count > 1 else 1
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def filler(label: str, count: int) -> str:
    """Return *count* newline-separated distractor notes tagged with *label*.

    Every note has the same text and differs only in its
    ``[distractor-<label>-NNN]`` prefix, where NNN is the zero-padded index.
    build_cases uses these notes to pad the raw contexts with irrelevant
    material.
    """
    note = (
        "Prior agent note: this item is stale, "
        "superseded, or unrelated to the active goal. It mentions generic "
        "planning details, old hypotheses, verbose log text, and abandoned "
        "tool outputs that should not affect the final answer."
    )
    return "\n".join(f"[distractor-{label}-{index:03d}] {note}" for index in range(count))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def build_cases() -> list[BenchmarkCase]:
    """Build the fixed set of benchmark scenarios.

    Each case pairs a long raw context (the real facts surrounded by
    ``filler`` distractor notes) with a hand-written compact JSON
    ``context_packet`` carrying the same facts. The two variants can then be
    compared on prompt size and answer accuracy.

    Returns:
        Three ``BenchmarkCase`` instances, in order: fee-disclosure ordering,
        incident root cause, and agent-handoff constraints.
    """
    # --- Case 1: the fee must be disclosed before the user agrees. ---
    audit_raw = f"""
RAW EVENT LOG FOR JOB finance-audit-77
{filler("audit-pre", 70)}
[task T-AUDIT-77] Determine whether the assistant disclosed the $49 activation fee before the user accepted the subscription.
[constraint C-FEE-001] The fee disclosure must happen before agreement. If disclosure happens after agreement, classify as FAIL.
[evidence EV-FEE-100] 2026-05-01T09:12:04Z assistant: "The plan includes a $49 activation fee before we continue."
[evidence EV-AGREE-200] 2026-05-01T09:13:11Z user: "I accept the subscription."
[evidence EV-FEE-OLD] 2026-04-28T11:00:00Z assistant disclosed a different legacy fee in an unrelated session.
[decision D-DRAFT] Earlier draft said unknown, but it was before EV-FEE-100 was extracted.
{filler("audit-post", 70)}
DO_NOT_LOSE: task=T-AUDIT-77, constraint=C-FEE-001, evidence=EV-FEE-100, evidence=EV-AGREE-200.
"""
    # Compact separators strip all optional whitespace from the packet.
    audit_compressed = json.dumps(
        {
            "context_packet": {
                "objective": {
                    "focus_id": "T-AUDIT-77",
                    "current_user_goal": "Decide if fee disclosure happened before agreement.",
                },
                "hard_constraints": [
                    {
                        "id": "C-FEE-001",
                        "text": "Fee disclosure must happen before agreement. After agreement means FAIL.",
                    }
                ],
                "retrieved_evidence": [
                    {
                        "id": "EV-FEE-100",
                        "timestamp": "2026-05-01T09:12:04Z",
                        "quote": "The plan includes a $49 activation fee before we continue.",
                    },
                    {
                        "id": "EV-AGREE-200",
                        "timestamp": "2026-05-01T09:13:11Z",
                        "quote": "I accept the subscription.",
                    },
                ],
                "do_not_lose": ["T-AUDIT-77", "C-FEE-001", "EV-FEE-100", "EV-AGREE-200"],
            }
        },
        separators=(",", ":"),
    )

    # --- Case 2: incident root cause, with a stale hypothesis as a trap. ---
    incident_raw = f"""
RAW MULTI-AGENT INCIDENT CONTEXT FOR JOB incident-414
{filler("incident-pre", 90)}
[task T-INC-414] Identify the root cause and cite the decisive logs.
[constraint C-INC-010] Do not blame the database unless DB-OK-9 is contradicted by newer evidence.
[log DB-OK-9] 2026-05-01T02:14:05Z database p95 latency remained below 11ms and no lock waits were observed.
[log API-ERR-22] 2026-05-01T02:16:39Z api-gateway began returning 502 for checkout requests.
[log CACHE-MISS-7] 2026-05-01T02:16:40Z feature flag ff_checkout_cache_v2 enabled in us-east.
[log TRACE-ROOT-5] 2026-05-01T02:16:41Z checkout service error: missing cache namespace checkout:v2:prices.
[fix FIX-88] Roll back ff_checkout_cache_v2 or create namespace checkout:v2:prices before enabling.
[stale hypothesis H-OLD-1] CPU saturation caused outage. Rejected after metrics review.
{filler("incident-post", 80)}
DO_NOT_LOSE: T-INC-414, C-INC-010, DB-OK-9, API-ERR-22, CACHE-MISS-7, TRACE-ROOT-5, FIX-88.
"""
    incident_compressed = json.dumps(
        {
            "context_packet": {
                "objective": {
                    "focus_id": "T-INC-414",
                    "current_user_goal": "Identify root cause and next fix.",
                },
                "hard_constraints": [
                    {
                        "id": "C-INC-010",
                        "text": "Do not blame the database unless DB-OK-9 is contradicted.",
                    }
                ],
                "retrieved_evidence": [
                    {"id": "DB-OK-9", "summary": "Database p95 < 11ms; no lock waits."},
                    {"id": "API-ERR-22", "summary": "API gateway 502s began at 02:16:39Z."},
                    {
                        "id": "CACHE-MISS-7",
                        "summary": "ff_checkout_cache_v2 enabled in us-east at 02:16:40Z.",
                    },
                    {
                        "id": "TRACE-ROOT-5",
                        "summary": "Checkout failed because cache namespace checkout:v2:prices was missing.",
                    },
                ],
                "shared_state": [
                    {
                        "id": "FIX-88",
                        "summary": "Rollback ff_checkout_cache_v2 or create checkout:v2:prices before enabling.",
                    }
                ],
                "do_not_lose": [
                    "T-INC-414",
                    "C-INC-010",
                    "DB-OK-9",
                    "API-ERR-22",
                    "CACHE-MISS-7",
                    "TRACE-ROOT-5",
                    "FIX-88",
                    "checkout:v2:prices",
                ],
            }
        },
        separators=(",", ":"),
    )

    # --- Case 3: handoff that must keep user constraints and the operator flag. ---
    handoff_raw = f"""
RAW AGENT HANDOFF CONTEXT FOR JOB code-review-209
{filler("handoff-pre", 75)}
[task T-CODE-209] Prepare the next coding agent to finish the Membrane context compiler safely.
[user-constraint UC-1] Keep existing gRPC APIs backward compatible.
[user-constraint UC-2] LLMLingua/model compression must be optional; the system must work without GPU or Python model dependencies.
[finding F-1] CompileContext should preserve exact IDs, source_refs, do_not_lose, and hard constraints.
[finding F-2] Deterministic compression should run before model compression.
[finding F-3] If model compression drops pinned terms, reject it and return deterministic packet.
[handoff H-209] Next agent should add docs, tests, and an operator flag named MN_CONTEXT_MODEL_COMPRESSION_ENABLED.
{filler("handoff-post", 75)}
DO_NOT_LOSE: T-CODE-209, UC-1, UC-2, F-1, F-2, F-3, H-209, MN_CONTEXT_MODEL_COMPRESSION_ENABLED.
"""
    handoff_compressed = json.dumps(
        {
            "context_packet": {
                "objective": {
                    "focus_id": "T-CODE-209",
                    "current_user_goal": "Prepare next coding agent handoff.",
                },
                "hard_constraints": [
                    {"id": "UC-1", "text": "Keep existing gRPC APIs backward compatible."},
                    {
                        "id": "UC-2",
                        "text": "Model compression must be optional; system works without GPU/Python model deps.",
                    },
                ],
                "shared_state": [
                    {
                        "id": "F-1",
                        "summary": "CompileContext preserves exact IDs, source_refs, do_not_lose, constraints.",
                    },
                    {"id": "F-2", "summary": "Run deterministic compression before model compression."},
                    {
                        "id": "F-3",
                        "summary": "Reject model output if pinned terms are lost.",
                    },
                ],
                "handoff": {
                    "id": "H-209",
                    "next_expected_action": "Add docs, tests, and MN_CONTEXT_MODEL_COMPRESSION_ENABLED operator flag.",
                },
                "do_not_lose": [
                    "T-CODE-209",
                    "UC-1",
                    "UC-2",
                    "F-1",
                    "F-2",
                    "F-3",
                    "H-209",
                    "MN_CONTEXT_MODEL_COMPRESSION_ENABLED",
                ],
            }
        },
        separators=(",", ":"),
    )

    return [
        BenchmarkCase(
            case_id="finance_fee_ordering",
            title="Fee disclosure ordering",
            task=(
                "Decide PASS or FAIL. Cite the decisive evidence IDs and give one sentence of reasoning."
            ),
            raw_context=audit_raw,
            compressed_context=audit_compressed,
            expected_answer=(
                "PASS. EV-FEE-100 at 09:12:04Z disclosed the $49 activation fee before "
                "EV-AGREE-200 at 09:13:11Z accepted the subscription, satisfying C-FEE-001."
            ),
            required_terms=("PASS", "EV-FEE-100", "EV-AGREE-200", "C-FEE-001"),
            # NOTE(review): the forbidden-claim check lowercases and matches
            # substrings, so "FAIL" also flags answers such as "PASS (not FAIL)".
            must_not_claim=("FAIL", "unknown", "after agreement"),
        ),
        BenchmarkCase(
            case_id="incident_root_cause",
            title="Incident root cause",
            task=(
                "Identify the root cause, say what not to blame, and cite the fix ID."
            ),
            raw_context=incident_raw,
            compressed_context=incident_compressed,
            expected_answer=(
                "Root cause was enabling ff_checkout_cache_v2 without namespace "
                "checkout:v2:prices, supported by CACHE-MISS-7 and TRACE-ROOT-5. "
                "Do not blame the database because DB-OK-9 shows it was healthy. Use FIX-88."
            ),
            required_terms=(
                "ff_checkout_cache_v2",
                "checkout:v2:prices",
                "TRACE-ROOT-5",
                "DB-OK-9",
                "FIX-88",
            ),
            must_not_claim=("database root cause", "CPU saturation", "H-OLD-1 caused"),
        ),
        BenchmarkCase(
            case_id="agent_handoff_constraints",
            title="Agent handoff constraints",
            task=(
                "Write the next-agent handoff in 3 bullets. Include the required operator flag."
            ),
            raw_context=handoff_raw,
            compressed_context=handoff_compressed,
            # NOTE(review): this reference answer writes "backward-compatible"
            # (hyphenated), which a plain substring check does not count as the
            # required term "backward compatible" below.
            expected_answer=(
                "The handoff must preserve backward-compatible gRPC APIs, keep LLMLingua optional "
                "so no GPU/Python model dependency is required, preserve pinned terms, and add "
                "MN_CONTEXT_MODEL_COMPRESSION_ENABLED with tests and docs."
            ),
            required_terms=(
                "backward compatible",
                "optional",
                "MN_CONTEXT_MODEL_COMPRESSION_ENABLED",
                "pinned",
            ),
            must_not_claim=("required GPU", "always use LLMLingua", "breaking API"),
        ),
    ]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def answer_prompt(context: str, task: str) -> str:
    """Assemble the worker prompt.

    The prompt is an instruction preamble, then a CONTEXT section, a TASK
    section and a closing reminder. Sections are separated by blank lines.
    """
    sections = [
        "You are a precise multi-agent runtime worker.",
        "Use only the context below. Preserve exact IDs and constraints.",
        f"CONTEXT:\n{context}",
        f"TASK:\n{task}",
        "Answer concisely. Cite exact IDs when available.",
    ]
    return "\n\n".join(sections)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def judge_prompt(case: BenchmarkCase, answer: str) -> str:
    """Build the single-answer grading prompt for the judge model.

    The prompt contains:
    - the case's task and expected answer;
    - the required terms and forbidden claims, rendered as JSON arrays;
    - the candidate answer;
    - the weighted scoring rubric;
    - the strict-JSON reply schema.

    run_variant passes the judge's reply to parse_judge_response.

    Doubled braces (``{{`` / ``}}``) are f-string escapes that render as
    literal braces in the JSON template.
    """
    return f"""You are an exacting benchmark judge for a context-compression experiment.

Evaluate whether the candidate answer correctly solves the task using the expected answer as ground truth.
Do not reward verbosity. Penalize missing exact IDs, inverted decisions, unsupported claims, or contradictions.

Task: {case.task}

Expected answer:
{case.expected_answer}

Required exact terms that should appear when relevant:
{json.dumps(case.required_terms)}

Claims that must NOT appear:
{json.dumps(case.must_not_claim)}

Candidate answer:
{answer}

Scoring rubric:
- factual_score: 0.0-1.0. Correct final conclusion and reasoning.
- required_terms_score: 0.0-1.0. Preserves required IDs/terms.
- constraint_score: 0.0-1.0. Obeys hard constraints and does not invert them.
- hallucination_score: 0.0-1.0. 1.0 means no unsupported or forbidden claims.
- score: weighted final score = 0.45*factual + 0.25*required_terms + 0.20*constraint + 0.10*hallucination.
- passed: true only if score >= 0.80 and no critical contradiction exists.

Return strict JSON only:
{{
"factual_score": <0.0 to 1.0>,
"required_terms_score": <0.0 to 1.0>,
"constraint_score": <0.0 to 1.0>,
"hallucination_score": <0.0 to 1.0>,
"score": <0.0 to 1.0>,
"passed": <true or false>,
"missing_terms": ["..."],
"forbidden_claims_found": ["..."],
"reason": "<short reason>"
}}"""
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def pairwise_judge_prompt(case: BenchmarkCase, raw_answer: str, compressed_answer: str) -> str:
    """Build the prompt asking the judge to compare raw- vs compressed-context answers.

    The requested reply fields are winner, raw_score, compressed_score,
    quality_delta, compression_preserved_accuracy and reason. These are the
    fields run_pairwise_judgments reads via parse_pairwise_response.

    Doubled braces are f-string escapes for literal ``{`` / ``}``.
    """
    return f"""You are judging whether context compression preserved answer quality.

Task: {case.task}
Expected answer: {case.expected_answer}
Required terms: {json.dumps(case.required_terms)}
Forbidden claims: {json.dumps(case.must_not_claim)}

Raw-context answer:
{raw_answer}

Compressed-context answer:
{compressed_answer}

Compare correctness, preservation of exact IDs/constraints, and unsupported claims.
Return strict JSON only:
{{
"winner": "raw" | "compressed" | "tie",
"raw_score": <0.0 to 1.0>,
"compressed_score": <0.0 to 1.0>,
"quality_delta": <compressed_score - raw_score>,
"compression_preserved_accuracy": <true or false>,
"reason": "<short reason>"
}}"""
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def ollama_chat(
    endpoint: str,
    model: str,
    messages: list[dict[str, str]],
    timeout: int,
) -> dict[str, Any]:
    """POST a non-streaming chat request to an Ollama-compatible ``/api/chat`` URL.

    The request pins ``temperature`` to 0. The decoded JSON response body is
    returned with an extra ``_latency_s`` key holding the wall-clock seconds
    spent on the request and decoding.

    Errors raised by ``urllib.request.urlopen`` and ``json.loads`` propagate
    to the caller unchanged.
    """
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "options": {"temperature": 0},
    }
    request = urllib.request.Request(
        endpoint,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    t0 = time.perf_counter()
    with urllib.request.urlopen(request, timeout=timeout) as reply:
        raw_body = reply.read()
    parsed = json.loads(raw_body.decode("utf-8"))
    parsed["_latency_s"] = time.perf_counter() - t0
    return parsed
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def response_text(response: dict[str, Any]) -> str:
    """Return the stripped assistant text from an Ollama chat response.

    A missing or non-object ``message``, or a missing or null ``content``,
    yields ``""``. Previously a null ``content`` produced the literal string
    ``"None"``, and a null ``message`` raised AttributeError.
    """
    message = response.get("message")
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    if content is None:
        return ""
    return str(content).strip()
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _fold_text(text: str) -> str:
    """Lowercase *text* and collapse every run of whitespace to a single space."""
    return " ".join(text.lower().split())


def _mentions(folded_answer: str, term: str) -> bool:
    """Return True if *term* appears in an answer already passed through _fold_text.

    Matching is a case-insensitive substring check that ignores line breaks
    and repeated spaces. A multi-word term also matches its hyphenated
    spelling, so "backward compatible" is found in "backward-compatible".
    Terms that already contain hyphens (e.g. IDs like "EV-FEE-100") are
    matched exactly as before.
    """
    folded_term = _fold_text(term)
    return folded_term in folded_answer or folded_term.replace(" ", "-") in folded_answer


def deterministic_score(answer: str, required_terms: tuple[str, ...]) -> float:
    """Return the fraction of *required_terms* mentioned in *answer*.

    With no required terms nothing can be missing, so the score is 1.0. It
    used to be 0.0, which contradicted the empty missing-terms list.
    """
    if not required_terms:
        return 1.0
    missing = missing_required_terms(answer, required_terms)
    return (len(required_terms) - len(missing)) / len(required_terms)


def missing_required_terms(answer: str, required_terms: tuple[str, ...]) -> list[str]:
    """Return the required terms (original spelling, original order) absent from *answer*."""
    folded = _fold_text(answer)
    return [term for term in required_terms if not _mentions(folded, term)]


def forbidden_claims_found(answer: str, forbidden_claims: tuple[str, ...]) -> list[str]:
    """Return the forbidden claims (original spelling, original order) present in *answer*.

    Note this is still substring matching: a claim such as "FAIL" is also
    found inside an answer that says "not FAIL".
    """
    folded = _fold_text(answer)
    return [claim for claim in forbidden_claims if _mentions(folded, claim)]
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def extract_json_object(text: str) -> dict[str, Any]:
    """Parse the JSON object embedded in a model reply.

    A surrounding markdown code fence (optionally tagged ``json``) is
    removed. Then the span from the first ``{`` to the last ``}`` is parsed,
    which also tolerates prose before or after the object.

    Raises:
        json.JSONDecodeError: the extracted span is not valid JSON.
        ValueError: the JSON value is not an object (e.g. a bare list).
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        stripped = stripped.strip("`")
        if stripped.lower().startswith("json"):
            stripped = stripped[4:].strip()
    start = stripped.find("{")
    end = stripped.rfind("}")
    if start >= 0 and end >= start:
        stripped = stripped[start : end + 1]
    data = json.loads(stripped)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    return data


def _to_float(value: Any, field: str) -> float:
    """Convert a judge-supplied number to float.

    Raises ValueError (never TypeError) for null, non-numeric or NaN values.
    That keeps malformed replies on the ValueError path that
    call_json_judge retries.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"{field} is not a number: {value!r}") from exc
    if math.isnan(number):
        # json.loads accepts NaN; clamping it would silently yield 1.0.
        raise ValueError(f"{field} is NaN")
    return number


def _unit_interval(value: Any, field: str) -> float:
    """Return *value* as a float clamped to [0.0, 1.0]; see _to_float for errors."""
    return max(0.0, min(1.0, _to_float(value, field)))


def _as_bool(value: Any, default: bool) -> bool:
    """Interpret a judge-supplied boolean.

    ``None`` (missing or null) returns *default*. Strings are read by meaning,
    so ``"false"`` is False rather than truthy. An unrecognized string raises
    ValueError.
    """
    if value is None:
        return default
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in {"true", "yes", "1"}:
            return True
        if lowered in {"false", "no", "0", ""}:
            return False
        raise ValueError(f"not a boolean: {value!r}")
    return bool(value)


def _as_list(value: Any) -> list[Any]:
    """Normalize a judge-supplied list field.

    Null becomes ``[]`` and a bare non-empty string becomes a one-item list
    instead of being split into characters. Any other non-list value raises
    ValueError.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value] if value.strip() else []
    if isinstance(value, (list, tuple)):
        return list(value)
    raise ValueError(f"expected a list, got {type(value).__name__}")


def parse_judge_response(text: str) -> dict[str, Any]:
    """Parse and normalize the single-answer judge's JSON reply.

    - Component scores that are present are clamped to [0, 1].
    - If ``score`` is missing or null but all four components exist, it is
      recomputed with the rubric weights 0.45/0.25/0.20/0.10. Otherwise a
      missing score becomes 0.0.
    - ``passed`` defaults to ``score >= 0.80``.
    - ``missing_terms`` and ``forbidden_claims_found`` are always lists.
    - ``reason`` is always a string.

    Raises:
        ValueError: the reply is not JSON, not an object, or has malformed
            field values (json.JSONDecodeError is a ValueError subclass).
    """
    data = extract_json_object(text)
    weights = {
        "factual_score": 0.45,
        "required_terms_score": 0.25,
        "constraint_score": 0.20,
        "hallucination_score": 0.10,
    }
    for key in weights:
        if key in data:
            data[key] = _unit_interval(data[key], key)
    if data.get("score") is None and all(key in data for key in weights):
        data["score"] = sum(weight * data[key] for key, weight in weights.items())
    score = data.get("score")
    data["score"] = 0.0 if score is None else _unit_interval(score, "score")
    data["passed"] = _as_bool(data.get("passed"), data["score"] >= 0.80)
    data["missing_terms"] = _as_list(data.get("missing_terms"))
    data["forbidden_claims_found"] = _as_list(data.get("forbidden_claims_found"))
    reason = data.get("reason")
    data["reason"] = "" if reason is None else str(reason)
    return data


def parse_pairwise_response(text: str) -> dict[str, Any]:
    """Parse and normalize the pairwise (raw vs compressed) judge's JSON reply.

    - ``winner`` is forced into {"raw", "compressed", "tie"}; anything else
      becomes "tie".
    - Both scores are clamped to [0, 1]; missing or null scores are 0.0.
    - ``quality_delta`` defaults to ``compressed_score - raw_score``.
    - ``compression_preserved_accuracy`` defaults to
      ``compressed_score >= raw_score - 0.05``.

    Raises:
        ValueError: the reply is not JSON, not an object, or has malformed
            field values.
    """
    data = extract_json_object(text)
    winner = str(data.get("winner", "tie")).strip().lower()
    if winner not in {"raw", "compressed", "tie"}:
        winner = "tie"
    raw_value = data.get("raw_score")
    compressed_value = data.get("compressed_score")
    raw_score = 0.0 if raw_value is None else _unit_interval(raw_value, "raw_score")
    compressed_score = (
        0.0 if compressed_value is None else _unit_interval(compressed_value, "compressed_score")
    )
    data["winner"] = winner
    data["raw_score"] = raw_score
    data["compressed_score"] = compressed_score
    delta = data.get("quality_delta")
    data["quality_delta"] = (
        compressed_score - raw_score if delta is None else _to_float(delta, "quality_delta")
    )
    data["compression_preserved_accuracy"] = _as_bool(
        data.get("compression_preserved_accuracy"),
        compressed_score >= raw_score - 0.05,
    )
    reason = data.get("reason")
    data["reason"] = "" if reason is None else str(reason)
    return data
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def call_json_judge(
    endpoint: str,
    model: str,
    prompt: str,
    timeout: int,
    retries: int,
    parser: Any,
) -> tuple[dict[str, Any] | None, str | None, float | None]:
    """Ask the judge model for a JSON verdict, retrying on failure.

    Makes ``retries + 1`` attempts (at least one). A transport failure is
    retried with the same prompt. After a reply that cannot be parsed, a
    repair note asking for bare JSON is appended to the prompt for later
    attempts.

    Args:
        endpoint, model, timeout: Forwarded to ``ollama_chat``.
        prompt: User prompt for the judge.
        retries: Extra attempts after the first one.
        parser: Callable mapping the reply text to a dict. Malformed replies
            may surface as ValueError (incl. json.JSONDecodeError), TypeError
            or AttributeError.

    Returns:
        ``(parsed, None, latency_s)`` on success, or
        ``(None, last_error, None)`` once every attempt has failed. The
        error string is never empty; callers treat a falsy error as success.
    """
    last_error: str | None = None
    repair_note = ""
    for _ in range(max(1, retries + 1)):
        try:
            response = ollama_chat(
                endpoint,
                model,
                [
                    {
                        "role": "system",
                        "content": "You are a strict evaluator. Return valid JSON only.",
                    },
                    {"role": "user", "content": prompt + repair_note},
                ],
                timeout,
            )
        except (OSError, ValueError, TypeError) as exc:
            # OSError covers urllib.error.URLError, TimeoutError and dropped
            # connections. ValueError/TypeError mean the HTTP body was not a
            # JSON object. No model reply was obtained, so no repair note.
            last_error = f"{type(exc).__name__}: {exc}"
            continue
        try:
            parsed = parser(response_text(response))
        except (ValueError, TypeError, AttributeError) as exc:
            # Malformed judge output: not JSON, or JSON of the wrong shape
            # (e.g. null scores). Ask for bare JSON on the next attempt.
            last_error = f"{type(exc).__name__}: {exc}"
            repair_note = (
                "\n\nYour previous answer was not valid JSON for this benchmark. "
                "Return only the requested JSON object, with no prose or markdown."
            )
            continue
        return parsed, None, response.get("_latency_s")
    return None, last_error, None
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def deterministic_accuracy(case: BenchmarkCase, answer: str) -> dict[str, Any]:
    """Score *answer* for *case* with the deterministic term checks (no judge model).

    The score is 0.80 x required-term coverage plus 0.20 when no forbidden
    claim is found. ``passed`` requires a score of at least 0.80 and no
    forbidden claim.
    """
    forbidden = forbidden_claims_found(answer, case.must_not_claim)
    coverage = deterministic_score(answer, case.required_terms)
    clean = 1.0 if not forbidden else 0.0
    total = 0.80 * coverage + 0.20 * clean
    return {
        "score": total,
        "required_terms_score": coverage,
        "hallucination_score": clean,
        "missing_terms": missing_required_terms(answer, case.required_terms),
        "forbidden_claims_found": forbidden,
        "passed": total >= 0.80 and not forbidden,
    }
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def run_variant(
    case: BenchmarkCase,
    variant: str,
    context: str,
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Answer *case* with one context variant, then score the answer.

    The answer is scored by the judge model unless ``args.skip_judge`` is
    set. If the judge is skipped, or every judge attempt fails, the
    deterministic term-matching score from ``deterministic_accuracy`` is used.

    Args:
        case: The benchmark scenario.
        variant: Label stored on the row; run_pairwise_judgments pairs rows
            labelled "raw" and "compressed".
        context: Context text placed into the worker prompt.
        args: Parsed CLI options. Reads ``dry_run``, ``endpoint``, ``model``,
            ``timeout``, ``skip_judge``, ``judge_model`` and ``judge_retries``.

    Returns:
        A result row.
        - Dry runs carry only token estimates; all scores are None.
        - A failed worker call records ``answer_error`` with None scores
          instead of aborting the benchmark.
    """
    user_prompt = answer_prompt(context, case.task)
    result: dict[str, Any] = {
        "case_id": case.case_id,
        "title": case.title,
        "variant": variant,
        "prompt_estimated_tokens": estimate_tokens(user_prompt),
        "context_estimated_tokens": estimate_tokens(context),
    }

    if args.dry_run:
        result.update(
            {
                "answer": "",
                "answer_latency_s": 0.0,
                "ollama_prompt_tokens": None,
                "ollama_completion_tokens": None,
                "deterministic_score": None,
                "judge_score": None,
                "accuracy_score": None,
                "passed": None,
            }
        )
        return result

    try:
        answer_response = ollama_chat(
            args.endpoint,
            args.model,
            [
                {"role": "system", "content": "You are a helpful, precise assistant."},
                {"role": "user", "content": user_prompt},
            ],
            args.timeout,
        )
    except (OSError, ValueError, TypeError) as exc:
        # OSError covers urllib.error.URLError, TimeoutError and dropped
        # connections. ValueError is a non-JSON body (json.JSONDecodeError),
        # and TypeError a JSON body that is not an object. Previously the
        # latter two escaped and aborted the whole run.
        result.update(
            {
                "answer": "",
                "answer_error": f"{type(exc).__name__}: {exc}",
                "answer_latency_s": None,
                "ollama_prompt_tokens": None,
                "ollama_completion_tokens": None,
                "deterministic_score": None,
                "judge_score": None,
                "judge_reason": None,
                "judge_error": None,
                "accuracy_score": None,
                "passed": None,
            }
        )
        return result

    answer = response_text(answer_response)
    det = deterministic_accuracy(case, answer)

    judge_data: dict[str, Any] | None = None
    judge_error = None
    if not args.skip_judge:
        judge_data, judge_error, judge_latency_s = call_json_judge(
            args.endpoint,
            args.judge_model,
            judge_prompt(case, answer),
            args.timeout,
            args.judge_retries,
            parse_judge_response,
        )
        result["judge_latency_s"] = judge_latency_s

    # Prefer the judge's verdict; fall back to the deterministic one.
    judge_score = judge_data["score"] if judge_data else None
    accuracy = judge_score if judge_score is not None else det["score"]
    passed = judge_data["passed"] if judge_data else det["passed"]
    missing_terms = (
        judge_data.get("missing_terms", det["missing_terms"]) if judge_data else det["missing_terms"]
    )
    forbidden_claims = (
        judge_data.get("forbidden_claims_found", det["forbidden_claims_found"])
        if judge_data
        else det["forbidden_claims_found"]
    )

    result.update(
        {
            "answer": answer,
            "answer_latency_s": answer_response.get("_latency_s"),
            "ollama_prompt_tokens": answer_response.get("prompt_eval_count"),
            "ollama_completion_tokens": answer_response.get("eval_count"),
            "deterministic_score": det["score"],
            "deterministic_required_terms_score": det["required_terms_score"],
            "deterministic_hallucination_score": det["hallucination_score"],
            "judge_score": judge_score,
            "judge_rubric": judge_data,
            "judge_reason": judge_data.get("reason") if judge_data else None,
            "judge_error": judge_error,
            "missing_terms": missing_terms,
            "forbidden_claims_found": forbidden_claims,
            "accuracy_score": accuracy,
            "passed": passed,
        }
    )
    return result
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def run_pairwise_judgments(
    cases: list[BenchmarkCase],
    results: list[dict[str, Any]],
    args: argparse.Namespace,
) -> list[dict[str, Any]]:
    """Ask the judge to compare each case's raw and compressed answers.

    Returns an empty list for dry runs or when judging is disabled. A case is
    judged only when both its "raw" and "compressed" rows have a non-empty
    answer and no ``answer_error``. Each output row holds the case id/title,
    the judge latency and error, plus the parsed verdict when one was
    obtained.
    """
    if args.dry_run or args.skip_judge or args.skip_pairwise_judge:
        return []

    # Index usable rows by (case_id, variant); a later duplicate wins.
    answered: dict[tuple[str, str], dict[str, Any]] = {}
    for entry in results:
        if entry.get("answer") and not entry.get("answer_error"):
            answered[(entry["case_id"], entry["variant"])] = entry

    judgments: list[dict[str, Any]] = []
    for case in cases:
        raw_row = answered.get((case.case_id, "raw"))
        compressed_row = answered.get((case.case_id, "compressed"))
        if raw_row is None or compressed_row is None:
            continue
        verdict, error, latency = call_json_judge(
            args.endpoint,
            args.judge_model,
            pairwise_judge_prompt(case, raw_row["answer"], compressed_row["answer"]),
            args.timeout,
            args.judge_retries,
            parse_pairwise_response,
        )
        judgment: dict[str, Any] = {
            "case_id": case.case_id,
            "title": case.title,
            "judge_latency_s": latency,
            "judge_error": error,
        }
        judgment.update(verdict or {})
        judgments.append(judgment)
    return judgments
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def summarize(
|
|
667
|
+
results: list[dict[str, Any]],
|
|
668
|
+
pairwise_judgments: list[dict[str, Any]] | None = None,
|
|
669
|
+
) -> dict[str, Any]:
|
|
670
|
+
by_variant: dict[str, list[dict[str, Any]]] = {}
|
|
671
|
+
for row in results:
|
|
672
|
+
by_variant.setdefault(row["variant"], []).append(row)
|
|
673
|
+
|
|
674
|
+
summary: dict[str, Any] = {}
|
|
675
|
+
for variant, rows in by_variant.items():
|
|
676
|
+
token_key = (
|
|
677
|
+
"ollama_prompt_tokens"
|
|
678
|
+
if all(row.get("ollama_prompt_tokens") is not None for row in rows)
|
|
679
|
+
else "prompt_estimated_tokens"
|
|
680
|
+
)
|
|
681
|
+
tokens = [float(row[token_key]) for row in rows]
|
|
682
|
+
accuracy = [
|
|
683
|
+
float(row["accuracy_score"])
|
|
684
|
+
for row in rows
|
|
685
|
+
if row.get("accuracy_score") is not None
|
|
686
|
+
]
|
|
687
|
+
summary[variant] = {
|
|
688
|
+
"cases": len(rows),
|
|
689
|
+
"token_metric": token_key,
|
|
690
|
+
"avg_prompt_tokens": statistics.mean(tokens),
|
|
691
|
+
"total_prompt_tokens": sum(tokens),
|
|
692
|
+
"avg_accuracy": statistics.mean(accuracy) if accuracy else None,
|
|
693
|
+
"pass_rate": (
|
|
694
|
+
sum(1 for row in rows if row.get("passed")) / len(rows)
|
|
695
|
+
if rows and rows[0].get("passed") is not None
|
|
696
|
+
else None
|
|
697
|
+
),
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
if "raw" in summary and "compressed" in summary:
|
|
701
|
+
raw = summary["raw"]["avg_prompt_tokens"]
|
|
702
|
+
compressed = summary["compressed"]["avg_prompt_tokens"]
|
|
703
|
+
summary["comparison"] = {
|
|
704
|
+
"avg_token_reduction_pct": ((raw - compressed) / raw * 100.0) if raw else 0.0,
|
|
705
|
+
"accuracy_delta": (
|
|
706
|
+
summary["compressed"]["avg_accuracy"] - summary["raw"]["avg_accuracy"]
|
|
707
|
+
if summary["raw"]["avg_accuracy"] is not None
|
|
708
|
+
and summary["compressed"]["avg_accuracy"] is not None
|
|
709
|
+
else None
|
|
710
|
+
),
|
|
711
|
+
}
|
|
712
|
+
if pairwise_judgments:
|
|
713
|
+
scored = [row for row in pairwise_judgments if not row.get("judge_error")]
|
|
714
|
+
if scored:
|
|
715
|
+
preserved = [
|
|
716
|
+
row
|
|
717
|
+
for row in scored
|
|
718
|
+
if bool(row.get("compression_preserved_accuracy", False))
|
|
719
|
+
]
|
|
720
|
+
deltas = [float(row.get("quality_delta", 0.0)) for row in scored]
|
|
721
|
+
summary["pairwise_judge"] = {
|
|
722
|
+
"cases": len(scored),
|
|
723
|
+
"compressed_wins": sum(1 for row in scored if row.get("winner") == "compressed"),
|
|
724
|
+
"raw_wins": sum(1 for row in scored if row.get("winner") == "raw"),
|
|
725
|
+
"ties": sum(1 for row in scored if row.get("winner") == "tie"),
|
|
726
|
+
"preservation_rate": len(preserved) / len(scored),
|
|
727
|
+
"avg_quality_delta": statistics.mean(deltas),
|
|
728
|
+
}
|
|
729
|
+
return summary
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def _md_cell(value: Any, empty: str = "-") -> str:
    """Render *value* as text that is safe inside a single Markdown table cell.

    ``None`` and empty strings render as *empty*. Pipes are escaped and line
    breaks are collapsed to single spaces, because either would otherwise
    split or terminate the table row (free-form text such as judge reasons or
    error messages may contain both).
    """
    if value is None:
        return empty
    text = " ".join(str(value).splitlines()).replace("|", "\\|")
    return text or empty


def _md_number(value: Any, spec: str = ".3f") -> str:
    """Format a numeric cell with *spec*, rendering ``None`` as ``-``."""
    return "-" if value is None else format(float(value), spec)


def markdown_report(payload: dict[str, Any]) -> str:
    """Render a benchmark *payload* (as assembled by ``main``) as Markdown.

    Sections, in order:
    - run metadata;
    - a per-variant summary table;
    - the raw-vs-compressed comparison and the pairwise-judge summary, each
      only when present in ``payload["summary"]``;
    - a per-case results table;
    - a pairwise-judgment table when ``payload["pairwise_judgments"]`` is
      non-empty.

    Free-text cells go through ``_md_cell``. As a result, ``None`` renders as
    ``-`` instead of the string ``"None"``, and embedded pipes or newlines
    cannot break table rows.

    Args:
        payload: Mapping with the keys ``generated_at``, ``endpoint``,
            ``model``, ``judge_model``, ``dry_run``, ``summary`` and
            ``results``, plus the optional ``pairwise_judgments``.

    Returns:
        The report text, terminated by a newline.
    """
    summary = payload["summary"]
    lines = [
        "# Context Compression Benchmark",
        "",
        f"Generated: {payload['generated_at']}",
        f"Endpoint: `{payload['endpoint']}`",
        f"Worker model: `{payload['model']}`",
        f"Judge model: `{payload['judge_model']}`",
        f"Dry run: `{payload['dry_run']}`",
        "",
        "## Summary",
        "",
        "| Variant | Token metric | Avg prompt tokens | Total prompt tokens | Avg accuracy | Pass rate |",
        "|---|---:|---:|---:|---:|---:|",
    ]
    for variant in ["raw", "compressed"]:
        data = summary.get(variant, {})
        pass_rate = data.get("pass_rate")
        lines.append(
            "| {variant} | {metric} | {avg_tokens:.1f} | {total_tokens:.0f} | {accuracy} | {pass_rate} |".format(
                variant=variant,
                metric=data.get("token_metric", "-"),
                avg_tokens=float(data.get("avg_prompt_tokens", 0.0)),
                total_tokens=float(data.get("total_prompt_tokens", 0.0)),
                accuracy=_md_number(data.get("avg_accuracy")),
                pass_rate="-" if pass_rate is None else f"{float(pass_rate) * 100:.1f}%",
            )
        )

    comparison = summary.get("comparison", {})
    if comparison:
        lines.extend(
            [
                "",
                "## Comparison",
                "",
                f"- Average token reduction: `{comparison['avg_token_reduction_pct']:.1f}%`",
                f"- Accuracy delta: `{_md_number(comparison.get('accuracy_delta'))}`",
            ]
        )

    pairwise_summary = summary.get("pairwise_judge", {})
    if pairwise_summary:
        lines.extend(
            [
                "",
                "## Pairwise Judge",
                "",
                f"- Cases judged: `{pairwise_summary['cases']}`",
                f"- Compressed wins: `{pairwise_summary['compressed_wins']}`",
                f"- Raw wins: `{pairwise_summary['raw_wins']}`",
                f"- Ties: `{pairwise_summary['ties']}`",
                f"- Accuracy preservation rate: `{pairwise_summary['preservation_rate'] * 100:.1f}%`",
                f"- Average quality delta: `{pairwise_summary['avg_quality_delta']:.3f}`",
            ]
        )

    lines.extend(
        [
            "",
            "## Case Details",
            "",
            "| Case | Variant | Estimated prompt tokens | Ollama prompt tokens | Accuracy | Missing terms | Forbidden claims | Passed | Error |",
            "|---|---|---:|---:|---:|---|---|---|---|",
        ]
    )
    for row in payload["results"]:
        # The answer error takes precedence over the judge error; the cell is
        # left blank (not "-") when neither is set.
        error = row.get("answer_error") or row.get("judge_error")
        missing = ", ".join(map(str, row.get("missing_terms") or []))
        forbidden = ", ".join(map(str, row.get("forbidden_claims_found") or []))
        lines.append(
            "| {case} | {variant} | {estimate} | {ollama} | {accuracy} | {missing} | {forbidden} | {passed} | {error} |".format(
                case=_md_cell(row["case_id"]),
                variant=_md_cell(row["variant"]),
                estimate=row["prompt_estimated_tokens"],
                # A genuine count of 0 is shown as "0"; only a missing value is "-".
                ollama=_md_cell(row.get("ollama_prompt_tokens")),
                accuracy=_md_number(row.get("accuracy_score")),
                missing=_md_cell(missing),
                forbidden=_md_cell(forbidden),
                passed=row.get("passed"),
                error=_md_cell(error, empty=""),
            )
        )

    if payload.get("pairwise_judgments"):
        lines.extend(
            [
                "",
                "## Pairwise Details",
                "",
                "| Case | Winner | Raw score | Compressed score | Delta | Preserved | Reason | Error |",
                "|---|---|---:|---:|---:|---|---|---|",
            ]
        )
        for row in payload["pairwise_judgments"]:
            lines.append(
                "| {case} | {winner} | {raw_score} | {compressed_score} | {delta} | {preserved} | {reason} | {error} |".format(
                    case=_md_cell(row["case_id"]),
                    winner=_md_cell(row.get("winner")),
                    raw_score=_md_number(row.get("raw_score")),
                    compressed_score=_md_number(row.get("compressed_score")),
                    delta=_md_number(row.get("quality_delta")),
                    preserved=row.get("compression_preserved_accuracy"),
                    reason=_md_cell(row.get("reason")),
                    error=_md_cell(row.get("judge_error")),
                )
            )
    return "\n".join(lines) + "\n"
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
def main() -> int:
    """Command-line entry point for the compression benchmark.

    Parses CLI options, runs every case from ``build_cases()`` once with its
    raw context and once with its compressed context, and collects pairwise
    judgments. It then writes ``context_compression_benchmark.json`` and
    ``context_compression_benchmark.md`` into ``--output-dir`` (creating the
    directory if needed) and prints the summary.

    Returns:
        ``0`` as the process exit status when no exception escapes.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark raw vs compressed context token use and answer accuracy."
    )
    # Connection / model selection.
    parser.add_argument("--endpoint", default=DEFAULT_ENDPOINT)
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--judge-model", default=DEFAULT_MODEL)
    parser.add_argument("--timeout", type=int, default=180)
    # Output location and boolean run-mode switches (registration order kept
    # so --help output is unchanged).
    parser.add_argument("--output-dir", default="benchmark_results")
    for switch in ("--dry-run", "--skip-judge", "--skip-pairwise-judge"):
        parser.add_argument(switch, action="store_true")
    # Judge tuning.
    parser.add_argument("--judge-retries", type=int, default=1)
    parser.add_argument("--pass-threshold", type=float, default=0.75)
    args = parser.parse_args()

    cases = build_cases()
    results = []
    for case in cases:
        # Each case yields two result rows: raw first, then compressed.
        results.extend(
            [
                run_variant(case, "raw", case.raw_context, args),
                run_variant(case, "compressed", case.compressed_context, args),
            ]
        )
    pairwise_judgments = run_pairwise_judgments(cases, results, args)

    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "endpoint": args.endpoint,
        "model": args.model,
        "judge_model": args.judge_model,
        "dry_run": args.dry_run,
        "results": results,
        "pairwise_judgments": pairwise_judgments,
        "summary": summarize(results, pairwise_judgments),
    }

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    json_report = out_dir / "context_compression_benchmark.json"
    md_report = out_dir / "context_compression_benchmark.md"
    json_report.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    md_report.write_text(markdown_report(payload), encoding="utf-8")

    for written in (json_report, md_report):
        print(f"Wrote {written}")
    print(json.dumps(payload["summary"], indent=2))
    return 0
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
2
|
+
from typing import Any, Optional, Dict, List
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
import uuid
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
def utcnow() -> datetime:
    """Return the current instant as a timezone-aware UTC ``datetime``."""
    current = datetime.now(timezone.utc)
    return current
|
|
9
|
+
|
|
10
|
+
# Item kinds accepted by MNMemoryItem.type.
VALID_TYPES = {
    "Fact",
    "Hypothesis",
    "Task",
    "Decision",
    "Evidence",
    "Constraint",
}

# Statuses accepted by MNMemoryItem.status: the generic lifecycle plus the
# extra statuses that appear in the Hypothesis lifecycle.
VALID_STATUSES = {"draft", "validated", "used", "archived"} | {
    "hypothesis",
    "tested",
    "confirmed",
    "rejected",
}

# current status -> statuses it may move to, for every type except
# "Hypothesis". "archived" is terminal.
GENERIC_TRANSITIONS = dict(
    draft={"validated", "archived"},
    validated={"used", "archived"},
    used={"archived"},
    archived=set(),
)

# current status -> statuses it may move to, for items of type "Hypothesis".
# "archived" is terminal.
HYPOTHESIS_TRANSITIONS = dict(
    draft={"tested", "archived"},
    hypothesis={"tested", "archived"},
    tested={"confirmed", "rejected", "archived"},
    confirmed={"archived"},
    rejected={"archived"},
    archived=set(),
)
|
|
37
|
+
|
|
38
|
+
class MNMemoryItem(BaseModel):
    """A single typed entry stored in working memory.

    ``validate_assignment=True`` makes pydantic re-run the validators below
    on every attribute assignment, not only at construction time.
    """

    model_config = ConfigDict(validate_assignment=True)

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    type: str  # must be one of VALID_TYPES
    content: Any  # must be a dict (enforced by validate_content)
    status: str  # must be one of VALID_STATUSES
    confidence: float = 1.0  # inclusive range [0.0, 1.0]
    source: str  # must contain a non-whitespace character
    created_at: datetime = Field(default_factory=utcnow)
    updated_at: datetime = Field(default_factory=utcnow)
    expires_at: Optional[datetime] = None  # None means the item never expires
    version: int = 1

    @field_validator("type")
    @classmethod
    def validate_type(cls, value: str) -> str:
        """Reject types outside VALID_TYPES."""
        if value not in VALID_TYPES:
            raise ValueError(f"invalid memory type: {value}")
        return value

    @field_validator("status")
    @classmethod
    def validate_status(cls, value: str) -> str:
        """Reject statuses outside VALID_STATUSES."""
        if value not in VALID_STATUSES:
            raise ValueError(f"invalid memory status: {value}")
        return value

    @field_validator("confidence")
    @classmethod
    def validate_confidence(cls, value: float) -> float:
        """Require 0.0 <= confidence <= 1.0 (NaN is rejected too)."""
        if not 0.0 <= value <= 1.0:
            raise ValueError("confidence must be between 0.0 and 1.0")
        return value

    @field_validator("source")
    @classmethod
    def validate_source(cls, value: str) -> str:
        """Reject blank sources; the value is stored unmodified."""
        if not value.strip():
            raise ValueError("source must not be empty")
        return value

    @field_validator("created_at", "updated_at", "expires_at")
    @classmethod
    def ensure_utc_aware(cls, value: Optional[datetime]) -> Optional[datetime]:
        """Interpret naive timestamps as UTC.

        Expiry checks compare ``expires_at`` with the aware ``utcnow()``.
        Python raises ``TypeError`` when a naive datetime is compared with an
        aware one, so naive inputs (e.g. ISO strings without an offset) are
        tagged as UTC here. Aware values and ``None`` pass through unchanged.
        """
        if value is not None and value.utcoffset() is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    @model_validator(mode="after")
    def validate_content(self) -> "MNMemoryItem":
        """Require ``content`` to be a dict (a structured object)."""
        if not isinstance(self.content, dict):
            raise ValueError("content must be a structured object")
        return self
|
|
85
|
+
|
|
86
|
+
class MNMemoryEdge(BaseModel):
    """A labelled relation from one memory item id to another."""

    source_id: str
    target_id: str
    relation: str  # free-form label; must contain a non-whitespace character

    @field_validator("relation")
    @classmethod
    def validate_relation(cls, value: str) -> str:
        """Reject blank relation labels; the value is stored unmodified."""
        has_text = bool(value.strip())
        if has_text:
            return value
        raise ValueError("relation must not be empty")
|
|
97
|
+
|
|
98
|
+
class MNWorkingMemory:
    """In-process store of MNMemoryItem objects and relation edges between them.

    Items are indexed by their ``id``. Edges are kept in insertion order and
    are followed in both directions when expanding context (``_expand_graph``).
    """

    def __init__(self) -> None:
        self._items: Dict[str, MNMemoryItem] = {}
        self._edges: List[MNMemoryEdge] = []

    def add(self, item: MNMemoryItem) -> None:
        """Insert *item*, replacing any existing item with the same id."""
        self._items[item.id] = item

    def update(self, item_id: str, updates: dict) -> None:
        """Apply field *updates* to an item; unknown item ids are ignored.

        Args:
            item_id: Id of the item to modify.
            updates: Field name -> new value. Keys that are not model fields
                are skipped. The optional key ``expected_version`` enables
                optimistic concurrency: the update is refused unless it equals
                the item's current ``version``.

        The merged state is validated before anything is written, so a
        rejected value leaves the item untouched. On success ``updated_at``
        is refreshed and ``version`` incremented.

        Raises:
            ValueError: on a version conflict, an invalid status transition,
                an attempt to change the item's ``id``, or a value that fails
                model validation (pydantic's ``ValidationError`` subclasses
                ``ValueError``).
        """
        item = self._items.get(item_id)
        if item is None:
            return
        updates = dict(updates)  # copy so popping below doesn't mutate the caller's dict
        expected_version = updates.pop("expected_version", None)
        if expected_version is not None and expected_version != item.version:
            raise ValueError(f"version conflict: item has {item.version}, expected {expected_version}")

        if "id" in updates and updates["id"] != item.id:
            # The store is keyed by id; renaming in place would orphan the key.
            raise ValueError("cannot change item id via update")

        if "status" in updates:
            self._validate_transition(item.type, item.status, updates["status"])

        model_cls = type(item)
        field_names = model_cls.model_fields
        changes = {key: value for key, value in updates.items() if key in field_names}
        if changes:
            # Validate the merged state up-front. Per-attribute assignment
            # would otherwise apply earlier fields before a later one fails,
            # leaving the item half-updated with an unchanged version.
            current = {name: getattr(item, name) for name in field_names}
            model_cls.model_validate({**current, **changes})
            for key, value in changes.items():
                setattr(item, key, value)
        item.updated_at = utcnow()
        item.version += 1

    def get(self, item_id: str) -> Optional[MNMemoryItem]:
        """Return the item with *item_id*, or None if it is absent or expired."""
        item = self._items.get(item_id)
        if item is None or self._is_expired(item):
            return None
        return item

    def query(self, filters: dict) -> List[MNMemoryItem]:
        """Return non-expired items whose attributes equal every value in *filters*.

        Attributes an item lacks are compared as ``None``. An empty *filters*
        dict matches every non-expired item.
        """
        return [
            item
            for item in self._items.values()
            if not self._is_expired(item)
            and not any(getattr(item, key, None) != value for key, value in filters.items())
        ]

    def link(self, source_id: str, target_id: str, relation: str) -> None:
        """Record a *relation* edge between two stored items.

        Raises:
            KeyError: if either id is not in the store.
            ValueError: if *relation* is blank (raised by MNMemoryEdge validation).
        """
        if source_id not in self._items:
            raise KeyError(f"source item not found: {source_id}")
        if target_id not in self._items:
            raise KeyError(f"target item not found: {target_id}")
        self._edges.append(MNMemoryEdge(source_id=source_id, target_id=target_id, relation=relation))

    def invalidate(self, item_id: str) -> None:
        """Archive an item.

        Unknown ids and already-archived items are left untouched; in
        particular, ``version`` is not bumped again for an archived item.
        """
        item = self._items.get(item_id)
        if item is None or item.status == "archived":
            return
        self._validate_transition(item.type, item.status, "archived")
        item.status = "archived"
        item.updated_at = utcnow()
        item.version += 1

    def get_context(self, agent_role: str, goal_id: str) -> List[MNMemoryItem]:
        """
        Retrieves selective context for an agent around a goal.
        Items must be visible to the role and either match the goal directly
        or sit within two graph hops of a matching item.

        Seeds are live (non-archived, non-expired) items whose ``id`` equals
        *goal_id* or whose ``content["goal_id"]`` equals it. Edges are
        followed in both directions. NOTE(review): archived or expired items
        can still act as intermediate hops; only the returned items are
        filtered.

        Args:
            agent_role: Role checked by ``_is_visible_to_role``.
            goal_id: Goal identifier that anchors the search.

        Returns:
            Matching items in store order; empty when no seed matches.
        """
        seeds = {
            item.id
            for item in self._items.values()
            if not self._is_expired(item)
            and item.status != "archived"
            and (item.id == goal_id or item.content.get("goal_id") == goal_id)
        }
        reachable = self._expand_graph(seeds, max_depth=2)

        return [
            item
            for item in self._items.values()
            if item.id in reachable
            and item.status != "archived"
            and not self._is_expired(item)
            and self._is_visible_to_role(agent_role, item)
        ]

    @staticmethod
    def _is_expired(item: MNMemoryItem) -> bool:
        """Return True once ``expires_at`` is set and lies in the past."""
        expires_at = item.expires_at
        if expires_at is None:
            return False
        if expires_at.utcoffset() is None:
            # Naive timestamps are interpreted as UTC. Comparing a naive
            # datetime with the aware utcnow() would raise TypeError.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return utcnow() > expires_at

    @staticmethod
    def _validate_transition(item_type: str, current: str, new_status: str) -> None:
        """Raise ValueError unless *current* -> *new_status* is allowed for *item_type*.

        Staying in the same status is always allowed. Archiving is allowed
        from any non-archived status: VALID_STATUSES is not type-specific, so
        an item can hold a status missing from its type's transition table
        (e.g. a Fact in "confirmed"), and it must still be archivable.
        """
        if current == new_status:
            return
        if new_status == "archived":
            # current != "archived" here because of the check above.
            return
        transitions = HYPOTHESIS_TRANSITIONS if item_type == "Hypothesis" else GENERIC_TRANSITIONS
        if new_status not in transitions.get(current, set()):
            raise ValueError(f"invalid status transition for {item_type}: {current} -> {new_status}")

    @staticmethod
    def _is_visible_to_role(agent_role: str, item: MNMemoryItem) -> bool:
        """Decide whether *agent_role* may see *item*.

        An ``acl`` dict in ``item.content`` takes precedence:
        - ``deny_roles`` always hides the item;
        - ``allow_roles``, when present, acts as a whitelist.

        Otherwise the role defaults apply:
        - planner sees everything except Evidence;
        - executor sees everything except Hypothesis;
        - reviewer sees everything;
        - any other role sees only items whose ``source`` equals the role.
        """
        acl = item.content.get("acl")
        if isinstance(acl, dict):
            # `or []` also covers an explicit "deny_roles": None.
            deny_roles = acl.get("deny_roles") or []
            if agent_role in deny_roles:
                return False
            allow_roles = acl.get("allow_roles")
            if allow_roles is not None:
                return agent_role in allow_roles

        if agent_role == "planner":
            return item.type != "Evidence"
        if agent_role == "executor":
            return item.type != "Hypothesis"
        if agent_role == "reviewer":
            return True
        return item.source == agent_role

    def _expand_graph(self, seed_ids: set[str], max_depth: int) -> set[str]:
        """Breadth-first expansion over edges in both directions.

        Returns *seed_ids* plus every id reachable within *max_depth* hops.
        """
        reachable = set(seed_ids)
        frontier = set(seed_ids)
        for _ in range(max_depth):
            next_frontier = set()
            for edge in self._edges:
                if edge.source_id in frontier and edge.target_id not in reachable:
                    next_frontier.add(edge.target_id)
                if edge.target_id in frontier and edge.source_id not in reachable:
                    next_frontier.add(edge.source_id)
            if not next_frontier:
                break
            reachable.update(next_frontier)
            frontier = next_frontier
        return reachable

    # --- Methods for compatibility with MicroNeuron payloads (Serialization) ---
    def to_dict(self) -> dict:
        """Serialize items and edges to JSON-compatible dicts."""
        return {
            "items": [item.model_dump(mode="json") for item in self._items.values()],
            "edges": [edge.model_dump(mode="json") for edge in self._edges]
        }

    @classmethod
    def from_dict(cls, data: dict) -> "MNWorkingMemory":
        """Rebuild a store from ``to_dict`` output.

        Every item and edge is validated. Edges are not checked against the
        loaded items.
        """
        wm = cls()
        for item_data in data.get("items", []):
            item = MNMemoryItem.model_validate(item_data)
            wm.add(item)
        for edge_data in data.get("edges", []):
            edge = MNMemoryEdge.model_validate(edge_data)
            wm._edges.append(edge)
        return wm

    def to_json(self) -> str:
        """Serialize the store to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "MNWorkingMemory":
        """Rebuild a store from ``to_json`` output."""
        return cls.from_dict(json.loads(json_str))
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# Backward-compatible aliases kept for code written against the older,
# un-prefixed SDK names. Each alias refers to the same class object.
MemoryItem, MemoryEdge, WorkingMemory = MNMemoryItem, MNMemoryEdge, MNWorkingMemory
|