mirrorneuron-membrane-python-sdk 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.3
2
+ Name: mirrorneuron-membrane-python-sdk
3
+ Version: 1.0.1
4
+ Summary: Python SDK shell for consuming the Rust Membrane context engine.
5
+ Requires-Python: >=3.10
6
+ License: Proprietary
7
+ Requires-Dist: pydantic>=2.0,<3.0
8
+ Provides-Extra: dev
9
+ Requires-Dist: pytest>=8.0; extra == "dev"
10
+ Provides-Extra: compression
11
+ Requires-Dist: llmlingua>=0.2.2; extra == "compression"
@@ -0,0 +1,8 @@
1
+ mn_context_engine_sdk/__init__.py,sha256=b_cv-mjvkN12vXM6F4FqFy3w-jtJ5oxmpppL1o1WA0o,301
2
+ mn_context_engine_sdk/benchmarks/__init__.py,sha256=vnIljW7ZUJMday5EqJNyyn2LRnIWj4jbcufLEaCj7sQ,58
3
+ mn_context_engine_sdk/benchmarks/context_compression_accuracy_benchmark.py,sha256=3Gj6VGJHf3_zQIHWhLfJtR8ynJuvsnoOJcGrAqc5Avo,34882
4
+ mn_context_engine_sdk/working_memory.py,sha256=uPSIJ3OUrqe8j1YLIjqy3RE5IFgZl8_9r8pJw6y_Cyg,8934
5
+ mirrorneuron_membrane_python_sdk-1.0.1.dist-info/METADATA,sha256=Dgo_iYqfiGagYZgd7AXe7IFs52p7i2EgP77xlUl4WMw,376
6
+ mirrorneuron_membrane_python_sdk-1.0.1.dist-info/WHEEL,sha256=s7mhKzNlJuzZwFeKTmqJmaumsjk0UzyKYJqivS04jas,87
7
+ mirrorneuron_membrane_python_sdk-1.0.1.dist-info/entry_points.txt,sha256=yTQ3U_D2lUs3VnQOtw87yd8YmwFjQYk0SVGPeLGOsgg,130
8
+ mirrorneuron_membrane_python_sdk-1.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: mn-build-backend
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ mn-context-compression-benchmark = mn_context_engine_sdk.benchmarks.context_compression_accuracy_benchmark:main
@@ -0,0 +1,19 @@
1
+ from .working_memory import (
2
+ MemoryEdge,
3
+ MemoryItem,
4
+ MNMemoryEdge,
5
+ MNMemoryItem,
6
+ MNWorkingMemory,
7
+ WorkingMemory,
8
+ utcnow,
9
+ )
10
+
11
+ __all__ = [
12
+ "MNMemoryEdge",
13
+ "MNMemoryItem",
14
+ "MNWorkingMemory",
15
+ "MemoryEdge",
16
+ "MemoryItem",
17
+ "WorkingMemory",
18
+ "utcnow",
19
+ ]
@@ -0,0 +1 @@
1
+ """Benchmark entry points for the Membrane Python SDK."""
@@ -0,0 +1,911 @@
1
+ #!/usr/bin/env python3
2
+ """Compare raw vs compressed context on token use and answer accuracy.
3
+
4
+ The benchmark uses an Ollama-compatible /api/chat endpoint for both the
5
+ answering model and the judging model. It intentionally does not require the
6
+ Rust gRPC service; this keeps it useful for CI smoke checks, remote model
7
+ testing, and quick local experiments.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import json
14
+ import statistics
15
+ import time
16
+ import urllib.error
17
+ import urllib.request
18
+ from dataclasses import dataclass
19
+ from datetime import datetime, timezone
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+
24
# Default Ollama-compatible /api/chat URL, used when --endpoint is not given.
# NOTE(review): this is a private-network (192.168.x.x) address, so it is only
# reachable from that LAN; confirm whether a localhost default was intended.
DEFAULT_ENDPOINT = "http://192.168.4.173:11434/api/chat"
# Default model name for both --model and --judge-model.
DEFAULT_MODEL = "nemotron3:33b"
# Characters-per-token ratio used by estimate_tokens() as a rough size heuristic.
CHARS_PER_TOKEN = 4
27
+
28
+
29
@dataclass(frozen=True)
class BenchmarkCase:
    """One benchmark scenario: a task, two renderings of its context, and the answer key."""

    case_id: str  # stable identifier copied into every result row
    title: str  # human-readable name copied into every result row
    task: str  # instruction placed in the TASK section of the answering prompt
    raw_context: str  # uncompressed context, including distractor filler
    compressed_context: str  # compact JSON "context_packet" rendering of the same facts
    expected_answer: str  # ground-truth answer shown to the judge model
    required_terms: tuple[str, ...]  # terms the deterministic scorer looks for (case-insensitive)
    must_not_claim: tuple[str, ...] = ()  # phrases whose presence counts as a forbidden claim
39
+
40
+
41
def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text* from its character length.

    The estimate is ``len(text) / CHARS_PER_TOKEN`` rounded up, and is never
    less than 1 (so the empty string still counts as one token).
    """
    whole, leftover = divmod(len(text), CHARS_PER_TOKEN)
    estimate = whole + 1 if leftover else whole
    return estimate if estimate > 1 else 1
43
+
44
+
45
def filler(label: str, count: int) -> str:
    """Return *count* newline-separated distractor lines tagged with *label*.

    Each line starts with ``[distractor-<label>-<NNN>]`` where ``NNN`` is the
    zero-padded, zero-based index. A non-positive *count* yields ``""``.
    """
    note = (
        "Prior agent note: this item is stale, "
        "superseded, or unrelated to the active goal. It mentions generic "
        "planning details, old hypotheses, verbose log text, and abandoned "
        "tool outputs that should not affect the final answer."
    )
    return "\n".join(
        f"[distractor-{label}-{number:03d}] {note}" for number in range(count)
    )
55
+
56
+
57
def build_cases() -> list[BenchmarkCase]:
    """Build the three built-in benchmark scenarios.

    Each scenario pairs a verbose "raw" context (the key facts surrounded by
    distractor lines from ``filler``) with a compact JSON ``context_packet``
    carrying the same IDs, plus the expected answer and the term lists used
    for scoring.

    Returns:
        The finance-audit, incident and agent-handoff cases, in that order.
    """
    # Case 1: fee-disclosure ordering. The decisive facts are the two timestamps.
    audit_raw = f"""
RAW EVENT LOG FOR JOB finance-audit-77
{filler("audit-pre", 70)}
[task T-AUDIT-77] Determine whether the assistant disclosed the $49 activation fee before the user accepted the subscription.
[constraint C-FEE-001] The fee disclosure must happen before agreement. If disclosure happens after agreement, classify as FAIL.
[evidence EV-FEE-100] 2026-05-01T09:12:04Z assistant: "The plan includes a $49 activation fee before we continue."
[evidence EV-AGREE-200] 2026-05-01T09:13:11Z user: "I accept the subscription."
[evidence EV-FEE-OLD] 2026-04-28T11:00:00Z assistant disclosed a different legacy fee in an unrelated session.
[decision D-DRAFT] Earlier draft said unknown, but it was before EV-FEE-100 was extracted.
{filler("audit-post", 70)}
DO_NOT_LOSE: task=T-AUDIT-77, constraint=C-FEE-001, evidence=EV-FEE-100, evidence=EV-AGREE-200.
"""
    # Compact separators keep the compressed packet as small as possible.
    audit_compressed = json.dumps(
        {
            "context_packet": {
                "objective": {
                    "focus_id": "T-AUDIT-77",
                    "current_user_goal": "Decide if fee disclosure happened before agreement.",
                },
                "hard_constraints": [
                    {
                        "id": "C-FEE-001",
                        "text": "Fee disclosure must happen before agreement. After agreement means FAIL.",
                    }
                ],
                "retrieved_evidence": [
                    {
                        "id": "EV-FEE-100",
                        "timestamp": "2026-05-01T09:12:04Z",
                        "quote": "The plan includes a $49 activation fee before we continue.",
                    },
                    {
                        "id": "EV-AGREE-200",
                        "timestamp": "2026-05-01T09:13:11Z",
                        "quote": "I accept the subscription.",
                    },
                ],
                "do_not_lose": ["T-AUDIT-77", "C-FEE-001", "EV-FEE-100", "EV-AGREE-200"],
            }
        },
        separators=(",", ":"),
    )

    # Case 2: incident root cause. The raw log also contains a rejected hypothesis.
    incident_raw = f"""
RAW MULTI-AGENT INCIDENT CONTEXT FOR JOB incident-414
{filler("incident-pre", 90)}
[task T-INC-414] Identify the root cause and cite the decisive logs.
[constraint C-INC-010] Do not blame the database unless DB-OK-9 is contradicted by newer evidence.
[log DB-OK-9] 2026-05-01T02:14:05Z database p95 latency remained below 11ms and no lock waits were observed.
[log API-ERR-22] 2026-05-01T02:16:39Z api-gateway began returning 502 for checkout requests.
[log CACHE-MISS-7] 2026-05-01T02:16:40Z feature flag ff_checkout_cache_v2 enabled in us-east.
[log TRACE-ROOT-5] 2026-05-01T02:16:41Z checkout service error: missing cache namespace checkout:v2:prices.
[fix FIX-88] Roll back ff_checkout_cache_v2 or create namespace checkout:v2:prices before enabling.
[stale hypothesis H-OLD-1] CPU saturation caused outage. Rejected after metrics review.
{filler("incident-post", 80)}
DO_NOT_LOSE: T-INC-414, C-INC-010, DB-OK-9, API-ERR-22, CACHE-MISS-7, TRACE-ROOT-5, FIX-88.
"""
    incident_compressed = json.dumps(
        {
            "context_packet": {
                "objective": {
                    "focus_id": "T-INC-414",
                    "current_user_goal": "Identify root cause and next fix.",
                },
                "hard_constraints": [
                    {
                        "id": "C-INC-010",
                        "text": "Do not blame the database unless DB-OK-9 is contradicted.",
                    }
                ],
                "retrieved_evidence": [
                    {"id": "DB-OK-9", "summary": "Database p95 < 11ms; no lock waits."},
                    {"id": "API-ERR-22", "summary": "API gateway 502s began at 02:16:39Z."},
                    {
                        "id": "CACHE-MISS-7",
                        "summary": "ff_checkout_cache_v2 enabled in us-east at 02:16:40Z.",
                    },
                    {
                        "id": "TRACE-ROOT-5",
                        "summary": "Checkout failed because cache namespace checkout:v2:prices was missing.",
                    },
                ],
                "shared_state": [
                    {
                        "id": "FIX-88",
                        "summary": "Rollback ff_checkout_cache_v2 or create checkout:v2:prices before enabling.",
                    }
                ],
                "do_not_lose": [
                    "T-INC-414",
                    "C-INC-010",
                    "DB-OK-9",
                    "API-ERR-22",
                    "CACHE-MISS-7",
                    "TRACE-ROOT-5",
                    "FIX-88",
                    "checkout:v2:prices",
                ],
            }
        },
        separators=(",", ":"),
    )

    # Case 3: agent handoff. The answer must carry constraints and the operator flag.
    handoff_raw = f"""
RAW AGENT HANDOFF CONTEXT FOR JOB code-review-209
{filler("handoff-pre", 75)}
[task T-CODE-209] Prepare the next coding agent to finish the Membrane context compiler safely.
[user-constraint UC-1] Keep existing gRPC APIs backward compatible.
[user-constraint UC-2] LLMLingua/model compression must be optional; the system must work without GPU or Python model dependencies.
[finding F-1] CompileContext should preserve exact IDs, source_refs, do_not_lose, and hard constraints.
[finding F-2] Deterministic compression should run before model compression.
[finding F-3] If model compression drops pinned terms, reject it and return deterministic packet.
[handoff H-209] Next agent should add docs, tests, and an operator flag named MN_CONTEXT_MODEL_COMPRESSION_ENABLED.
{filler("handoff-post", 75)}
DO_NOT_LOSE: T-CODE-209, UC-1, UC-2, F-1, F-2, F-3, H-209, MN_CONTEXT_MODEL_COMPRESSION_ENABLED.
"""
    handoff_compressed = json.dumps(
        {
            "context_packet": {
                "objective": {
                    "focus_id": "T-CODE-209",
                    "current_user_goal": "Prepare next coding agent handoff.",
                },
                "hard_constraints": [
                    {"id": "UC-1", "text": "Keep existing gRPC APIs backward compatible."},
                    {
                        "id": "UC-2",
                        "text": "Model compression must be optional; system works without GPU/Python model deps.",
                    },
                ],
                "shared_state": [
                    {
                        "id": "F-1",
                        "summary": "CompileContext preserves exact IDs, source_refs, do_not_lose, constraints.",
                    },
                    {"id": "F-2", "summary": "Run deterministic compression before model compression."},
                    {
                        "id": "F-3",
                        "summary": "Reject model output if pinned terms are lost.",
                    },
                ],
                "handoff": {
                    "id": "H-209",
                    "next_expected_action": "Add docs, tests, and MN_CONTEXT_MODEL_COMPRESSION_ENABLED operator flag.",
                },
                "do_not_lose": [
                    "T-CODE-209",
                    "UC-1",
                    "UC-2",
                    "F-1",
                    "F-2",
                    "F-3",
                    "H-209",
                    "MN_CONTEXT_MODEL_COMPRESSION_ENABLED",
                ],
            }
        },
        separators=(",", ":"),
    )

    return [
        BenchmarkCase(
            case_id="finance_fee_ordering",
            title="Fee disclosure ordering",
            task=(
                "Decide PASS or FAIL. Cite the decisive evidence IDs and give one sentence of reasoning."
            ),
            raw_context=audit_raw,
            compressed_context=audit_compressed,
            expected_answer=(
                "PASS. EV-FEE-100 at 09:12:04Z disclosed the $49 activation fee before "
                "EV-AGREE-200 at 09:13:11Z accepted the subscription, satisfying C-FEE-001."
            ),
            required_terms=("PASS", "EV-FEE-100", "EV-AGREE-200", "C-FEE-001"),
            # NOTE(review): the deterministic check is a case-insensitive substring
            # test, so a correct answer that merely mentions "fail" or "unknown"
            # (e.g. while restating C-FEE-001) is flagged as a forbidden claim.
            must_not_claim=("FAIL", "unknown", "after agreement"),
        ),
        BenchmarkCase(
            case_id="incident_root_cause",
            title="Incident root cause",
            task=(
                "Identify the root cause, say what not to blame, and cite the fix ID."
            ),
            raw_context=incident_raw,
            compressed_context=incident_compressed,
            expected_answer=(
                "Root cause was enabling ff_checkout_cache_v2 without namespace "
                "checkout:v2:prices, supported by CACHE-MISS-7 and TRACE-ROOT-5. "
                "Do not blame the database because DB-OK-9 shows it was healthy. Use FIX-88."
            ),
            required_terms=(
                "ff_checkout_cache_v2",
                "checkout:v2:prices",
                "TRACE-ROOT-5",
                "DB-OK-9",
                "FIX-88",
            ),
            must_not_claim=("database root cause", "CPU saturation", "H-OLD-1 caused"),
        ),
        BenchmarkCase(
            case_id="agent_handoff_constraints",
            title="Agent handoff constraints",
            task=(
                "Write the next-agent handoff in 3 bullets. Include the required operator flag."
            ),
            raw_context=handoff_raw,
            compressed_context=handoff_compressed,
            expected_answer=(
                "The handoff must preserve backward-compatible gRPC APIs, keep LLMLingua optional "
                "so no GPU/Python model dependency is required, preserve pinned terms, and add "
                "MN_CONTEXT_MODEL_COMPRESSION_ENABLED with tests and docs."
            ),
            # NOTE(review): "backward compatible" (with a space) does not occur in
            # the expected answer above, which spells it "backward-compatible", so
            # the substring check reports it missing even for the reference answer.
            required_terms=(
                "backward compatible",
                "optional",
                "MN_CONTEXT_MODEL_COMPRESSION_ENABLED",
                "pinned",
            ),
            must_not_claim=("required GPU", "always use LLMLingua", "breaking API"),
        ),
    ]
278
+
279
+
280
def answer_prompt(context: str, task: str) -> str:
    """Render the user prompt sent to the answering model.

    The prompt tells the model to rely only on *context*, then lists the
    context under ``CONTEXT:`` and the instruction under ``TASK:``.
    """
    sections = [
        "You are a precise multi-agent runtime worker.",
        "",
        "Use only the context below. Preserve exact IDs and constraints.",
        "",
        "CONTEXT:",
        context,
        "",
        "TASK:",
        task,
        "",
        "Answer concisely. Cite exact IDs when available.",
    ]
    return "\n".join(sections)
292
+
293
+
294
def judge_prompt(case: BenchmarkCase, answer: str) -> str:
    """Render the grading prompt that asks a judge model to score one answer.

    The prompt embeds the case's task, expected answer, required terms and
    forbidden claims (the two term lists as JSON arrays), followed by the
    candidate *answer*, the scoring rubric, and the JSON shape the judge must
    return. The doubled braces are f-string escapes that render as literal
    ``{`` and ``}`` around that JSON template.
    """
    return f"""You are an exacting benchmark judge for a context-compression experiment.

Evaluate whether the candidate answer correctly solves the task using the expected answer as ground truth.
Do not reward verbosity. Penalize missing exact IDs, inverted decisions, unsupported claims, or contradictions.

Task: {case.task}

Expected answer:
{case.expected_answer}

Required exact terms that should appear when relevant:
{json.dumps(case.required_terms)}

Claims that must NOT appear:
{json.dumps(case.must_not_claim)}

Candidate answer:
{answer}

Scoring rubric:
- factual_score: 0.0-1.0. Correct final conclusion and reasoning.
- required_terms_score: 0.0-1.0. Preserves required IDs/terms.
- constraint_score: 0.0-1.0. Obeys hard constraints and does not invert them.
- hallucination_score: 0.0-1.0. 1.0 means no unsupported or forbidden claims.
- score: weighted final score = 0.45*factual + 0.25*required_terms + 0.20*constraint + 0.10*hallucination.
- passed: true only if score >= 0.80 and no critical contradiction exists.

Return strict JSON only:
{{
"factual_score": <0.0 to 1.0>,
"required_terms_score": <0.0 to 1.0>,
"constraint_score": <0.0 to 1.0>,
"hallucination_score": <0.0 to 1.0>,
"score": <0.0 to 1.0>,
"passed": <true or false>,
"missing_terms": ["..."],
"forbidden_claims_found": ["..."],
"reason": "<short reason>"
}}"""
334
+
335
+
336
def pairwise_judge_prompt(case: BenchmarkCase, raw_answer: str, compressed_answer: str) -> str:
    """Render the prompt that asks a judge model to compare the two answers.

    The prompt shows the case's task, expected answer and term lists, then the
    answer produced from the raw context and the one produced from the
    compressed context, and finally the JSON shape the judge must return. The
    doubled braces are f-string escapes for literal ``{`` and ``}``.
    """
    return f"""You are judging whether context compression preserved answer quality.

Task: {case.task}
Expected answer: {case.expected_answer}
Required terms: {json.dumps(case.required_terms)}
Forbidden claims: {json.dumps(case.must_not_claim)}

Raw-context answer:
{raw_answer}

Compressed-context answer:
{compressed_answer}

Compare correctness, preservation of exact IDs/constraints, and unsupported claims.
Return strict JSON only:
{{
"winner": "raw" | "compressed" | "tie",
"raw_score": <0.0 to 1.0>,
"compressed_score": <0.0 to 1.0>,
"quality_delta": <compressed_score - raw_score>,
"compression_preserved_accuracy": <true or false>,
"reason": "<short reason>"
}}"""
360
+
361
+
362
def ollama_chat(
    endpoint: str,
    model: str,
    messages: list[dict[str, str]],
    timeout: int,
) -> dict[str, Any]:
    """POST a non-streaming chat request and return the decoded JSON reply.

    The request asks for ``temperature`` 0 and ``stream`` False. The returned
    dict is the server's JSON body with one extra key, ``_latency_s``, holding
    the wall-clock seconds from just before the request until after the body
    was decoded.

    Raises:
        urllib.error.URLError: if the request cannot be completed.
        json.JSONDecodeError: if the response body is not valid JSON.
    """
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "options": {"temperature": 0},
    }
    http_request = urllib.request.Request(
        endpoint,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    clock_start = time.perf_counter()
    with urllib.request.urlopen(http_request, timeout=timeout) as http_response:
        raw_body = http_response.read()
        reply = json.loads(raw_body.decode("utf-8"))
    elapsed = time.perf_counter() - clock_start
    reply["_latency_s"] = elapsed
    return reply
387
+
388
+
389
def response_text(response: dict[str, Any]) -> str:
    """Return the stripped assistant text from a chat response, or ``""``.

    Args:
        response: decoded chat reply; the text is read from
            ``response["message"]["content"]``.

    Returns:
        The content with surrounding whitespace removed. An empty string is
        returned when ``message`` is missing or not a dict, or when
        ``content`` is missing or null (a null would otherwise be rendered
        as the literal text ``"None"`` and then scored as an answer).
    """
    message = response.get("message")
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    if content is None:
        return ""
    return str(content).strip()
391
+
392
+
393
def deterministic_score(answer: str, required_terms: tuple[str, ...]) -> float:
    """Return the fraction of *required_terms* that occur in *answer*.

    Matching is a case-insensitive substring test. With no required terms the
    result is ``0.0``.
    """
    haystack = answer.lower()
    found = [term for term in required_terms if term.lower() in haystack]
    denominator = len(required_terms) or 1
    return len(found) / denominator
397
+
398
+
399
def missing_required_terms(answer: str, required_terms: tuple[str, ...]) -> list[str]:
    """List the required terms that do not occur in *answer*.

    Matching is a case-insensitive substring test; terms are returned with
    their original spelling and in their original order.
    """
    haystack = answer.lower()
    absent: list[str] = []
    for term in required_terms:
        if term.lower() in haystack:
            continue
        absent.append(term)
    return absent
402
+
403
+
404
def forbidden_claims_found(answer: str, forbidden_claims: tuple[str, ...]) -> list[str]:
    """List the forbidden claims that occur in *answer*.

    Matching is a case-insensitive substring test, so a claim also matches
    when it appears inside a longer word or a negated sentence. Claims are
    returned with their original spelling and in their original order.
    """
    haystack = answer.lower()

    def occurs(claim: str) -> bool:
        return claim.lower() in haystack

    return list(filter(occurs, forbidden_claims))
407
+
408
+
409
def extract_json_object(text: str) -> dict[str, Any]:
    """Extract the first JSON object embedded in model output.

    Decoding starts at the first ``{`` and stops at the end of that object, so
    Markdown code fences, a leading ``json`` tag, and prose before or after
    the object are all ignored -- including trailing prose that itself
    contains braces, which a "first ``{`` to last ``}``" slice cannot handle.

    Args:
        text: raw model output.

    Returns:
        The decoded object; always a ``dict``.

    Raises:
        json.JSONDecodeError: if *text* contains no ``{`` or the object
            starting there is malformed. This is a ``ValueError`` subclass.
            Bare scalars and arrays are rejected too, since callers rely on
            getting a dict.
    """
    start = text.find("{")
    if start < 0:
        raise json.JSONDecodeError("no JSON object found in model output", text, 0)
    # raw_decode parses one complete value at `start` and ignores what follows.
    parsed, _end = json.JSONDecoder().raw_decode(text, start)
    return parsed
420
+
421
+
422
def parse_judge_response(text: str) -> dict[str, Any]:
    """Parse and normalise the judge's JSON verdict for a single answer.

    Component scores and the final score are clamped to ``[0.0, 1.0]``. When
    the judge omits ``score`` but supplies all four components, the weighted
    score (0.45 / 0.25 / 0.20 / 0.10) is computed here. When ``passed`` is
    missing or null it falls back to ``score >= 0.80``.

    Args:
        text: raw judge model output.

    Returns:
        The judge's object with ``score``, ``passed``, ``missing_terms``,
        ``forbidden_claims_found`` and ``reason`` always present.

    Raises:
        ValueError: if the output has no JSON object, or a score field is
            null, non-numeric or NaN. Reporting these as ``ValueError``
            (rather than ``TypeError``) lets the retry loop handle them.
    """

    def unit_interval(value: Any, field: str) -> float:
        # Convert to float and clamp; reject values that cannot be a score.
        try:
            number = float(value)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"judge field {field!r} is not numeric: {value!r}") from exc
        if number != number:  # NaN would otherwise clamp to 1.0
            raise ValueError(f"judge field {field!r} is NaN")
        return max(0.0, min(1.0, number))

    def string_list(value: Any) -> list[str]:
        # A bare string is one item, not a sequence of characters.
        if value is None:
            return []
        if isinstance(value, str):
            return [value] if value else []
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return [str(value)]

    data = extract_json_object(text)
    component_keys = [
        "factual_score",
        "required_terms_score",
        "constraint_score",
        "hallucination_score",
    ]
    for key in component_keys:
        if key in data:
            data[key] = unit_interval(data[key], key)
    if "score" not in data and all(key in data for key in component_keys):
        data["score"] = (
            0.45 * data["factual_score"]
            + 0.25 * data["required_terms_score"]
            + 0.20 * data["constraint_score"]
            + 0.10 * data["hallucination_score"]
        )
    data["score"] = unit_interval(data.get("score", 0.0), "score")

    verdict = data.get("passed")
    if isinstance(verdict, str):
        # bool("false") is True, so string verdicts are interpreted explicitly.
        data["passed"] = verdict.strip().lower() in {"true", "yes", "1"}
    elif verdict is None:
        data["passed"] = data["score"] >= 0.80
    else:
        data["passed"] = bool(verdict)

    data["missing_terms"] = string_list(data.get("missing_terms"))
    data["forbidden_claims_found"] = string_list(data.get("forbidden_claims_found"))
    reason = data.get("reason")
    data["reason"] = "" if reason is None else str(reason)
    return data
447
+
448
+
449
def parse_pairwise_response(text: str) -> dict[str, Any]:
    """Parse and normalise the judge's JSON verdict for a raw-vs-compressed pair.

    ``winner`` is lower-cased and forced to ``"tie"`` when it is not one of
    ``raw`` / ``compressed`` / ``tie``. Both scores are clamped to
    ``[0.0, 1.0]`` and default to 0.0 when absent. ``quality_delta`` falls
    back to ``compressed_score - raw_score`` when missing or unusable, and
    ``compression_preserved_accuracy`` falls back to
    ``compressed_score >= raw_score - 0.05`` when missing or null.

    Args:
        text: raw judge model output.

    Returns:
        The judge's object with the six fields above normalised.

    Raises:
        ValueError: if the output has no JSON object, or a score is null,
            non-numeric or NaN. Reporting these as ``ValueError`` (rather
            than ``TypeError``) lets the retry loop handle them.
    """
    data = extract_json_object(text)

    def unit_interval(field: str) -> float:
        value = data.get(field, 0.0)
        try:
            number = float(value)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"judge field {field!r} is not numeric: {value!r}") from exc
        if number != number:  # NaN would otherwise clamp to 1.0
            raise ValueError(f"judge field {field!r} is NaN")
        return max(0.0, min(1.0, number))

    winner = str(data.get("winner", "tie")).strip().lower()
    if winner not in {"raw", "compressed", "tie"}:
        winner = "tie"
    raw_score = unit_interval("raw_score")
    compressed_score = unit_interval("compressed_score")

    computed_delta = compressed_score - raw_score
    try:
        quality_delta = float(data.get("quality_delta", computed_delta))
    except (TypeError, ValueError):
        quality_delta = computed_delta
    if quality_delta != quality_delta:  # NaN
        quality_delta = computed_delta

    preserved = data.get("compression_preserved_accuracy")
    if isinstance(preserved, str):
        # bool("false") is True, so string verdicts are interpreted explicitly.
        preserved = preserved.strip().lower() in {"true", "yes", "1"}
    elif preserved is None:
        preserved = compressed_score >= raw_score - 0.05
    else:
        preserved = bool(preserved)

    data["winner"] = winner
    data["raw_score"] = raw_score
    data["compressed_score"] = compressed_score
    data["quality_delta"] = quality_delta
    data["compression_preserved_accuracy"] = preserved
    reason = data.get("reason")
    data["reason"] = "" if reason is None else str(reason)
    return data
465
+
466
+
467
def call_json_judge(
    endpoint: str,
    model: str,
    prompt: str,
    timeout: int,
    retries: int,
    parser: Any,
) -> tuple[dict[str, Any] | None, str | None, float | None]:
    """Ask the judge model for a JSON verdict, retrying on failure.

    Args:
        endpoint: chat endpoint URL.
        model: judge model name.
        prompt: user prompt for the judge.
        timeout: per-request timeout in seconds.
        retries: extra attempts after the first; at least one attempt is made.
        parser: callable that turns the judge's text into a dict.

    Returns:
        ``(parsed, None, latency_s)`` on success, or
        ``(None, last_error, None)`` once every attempt has failed.
    """
    last_error = None
    repair_note = ""
    for _attempt in range(max(1, retries + 1)):
        try:
            response = ollama_chat(
                endpoint,
                model,
                [
                    {
                        "role": "system",
                        "content": "You are a strict evaluator. Return valid JSON only.",
                    },
                    {"role": "user", "content": prompt + repair_note},
                ],
                timeout,
            )
        except (OSError, ValueError) as exc:
            # OSError covers URLError, TimeoutError and raw connection resets;
            # ValueError covers a non-JSON HTTP body. The model produced no
            # answer here, so the "not valid JSON" repair note does not apply.
            last_error = str(exc) or type(exc).__name__
            continue
        try:
            parsed = parser(response_text(response))
        except (ValueError, TypeError, AttributeError) as exc:
            # The judge answered but the answer was unusable: ask it to repair.
            last_error = str(exc) or type(exc).__name__
            repair_note = (
                "\n\nYour previous answer was not valid JSON for this benchmark. "
                "Return only the requested JSON object, with no prose or markdown."
            )
            continue
        return parsed, None, response.get("_latency_s")
    return None, last_error, None
500
+
501
+
502
def deterministic_accuracy(case: BenchmarkCase, answer: str) -> dict[str, Any]:
    """Score *answer* against *case* using only string matching.

    The score is 80% the fraction of required terms present plus 20% a
    hallucination component, which is 1.0 when no forbidden claim was found
    and 0.0 otherwise. ``passed`` needs a score of at least 0.80 and no
    forbidden claims.
    """
    absent_terms = missing_required_terms(answer, case.required_terms)
    flagged_claims = forbidden_claims_found(answer, case.must_not_claim)
    terms_component = deterministic_score(answer, case.required_terms)
    if flagged_claims:
        clean_component = 0.0
    else:
        clean_component = 1.0
    total = 0.80 * terms_component + 0.20 * clean_component

    report: dict[str, Any] = {}
    report["score"] = total
    report["required_terms_score"] = terms_component
    report["hallucination_score"] = clean_component
    report["missing_terms"] = absent_terms
    report["forbidden_claims_found"] = flagged_claims
    report["passed"] = total >= 0.80 and not flagged_claims
    return report
516
+
517
+
518
def run_variant(
    case: BenchmarkCase,
    variant: str,
    context: str,
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Run one case with one context variant and return its result row.

    Args:
        case: the benchmark scenario.
        variant: label stored in the row (``main`` passes ``"raw"`` or
            ``"compressed"``).
        context: the context text to place in the answering prompt.
        args: parsed CLI options; reads ``dry_run``, ``endpoint``, ``model``,
            ``timeout``, ``skip_judge``, ``judge_model`` and ``judge_retries``.

    Returns:
        A row that always holds the case identity and token estimates. In dry
        run mode no model is called and the score fields are ``None``. If the
        answering call fails, ``answer_error`` is set and the score fields are
        ``None``. Otherwise the row holds the answer, the deterministic
        scores, and the judge verdict when one was obtained; accuracy and
        pass/fail come from the judge when available, else from the
        deterministic check.
    """
    user_prompt = answer_prompt(context, case.task)
    prompt_estimated_tokens = estimate_tokens(user_prompt)
    result: dict[str, Any] = {
        "case_id": case.case_id,
        "title": case.title,
        "variant": variant,
        "prompt_estimated_tokens": prompt_estimated_tokens,
        "context_estimated_tokens": estimate_tokens(context),
    }

    if args.dry_run:
        result.update(
            {
                "answer": "",
                "answer_latency_s": 0.0,
                "ollama_prompt_tokens": None,
                "ollama_completion_tokens": None,
                "deterministic_score": None,
                "judge_score": None,
                "accuracy_score": None,
                "passed": None,
            }
        )
        return result

    try:
        answer_response = ollama_chat(
            args.endpoint,
            args.model,
            [
                {"role": "system", "content": "You are a helpful, precise assistant."},
                {"role": "user", "content": user_prompt},
            ],
            args.timeout,
        )
    except (OSError, ValueError) as exc:
        # OSError covers URLError and TimeoutError. ValueError covers a
        # response body that is not JSON (json.JSONDecodeError), which would
        # otherwise abort the whole benchmark run.
        result.update(
            {
                "answer": "",
                # Some exceptions have an empty message; keep the error truthy
                # so later stages still recognise this row as failed.
                "answer_error": str(exc) or type(exc).__name__,
                "answer_latency_s": None,
                "ollama_prompt_tokens": None,
                "ollama_completion_tokens": None,
                "deterministic_score": None,
                "judge_score": None,
                "judge_reason": None,
                "judge_error": None,
                "accuracy_score": None,
                "passed": None,
            }
        )
        return result

    answer = response_text(answer_response)
    det = deterministic_accuracy(case, answer)

    judge_data: dict[str, Any] | None = None
    judge_error = None
    if not args.skip_judge:
        judge_data, judge_error, judge_latency_s = call_json_judge(
            args.endpoint,
            args.judge_model,
            judge_prompt(case, answer),
            args.timeout,
            args.judge_retries,
            parse_judge_response,
        )
        result["judge_latency_s"] = judge_latency_s

    # Prefer the judge's verdict; fall back to the deterministic check when
    # the judge was skipped or produced no usable output.
    judge_score = judge_data["score"] if judge_data else None
    accuracy = judge_score if judge_score is not None else det["score"]
    passed = judge_data["passed"] if judge_data else det["passed"]
    missing_terms = (
        judge_data.get("missing_terms", det["missing_terms"]) if judge_data else det["missing_terms"]
    )
    forbidden_claims = (
        judge_data.get("forbidden_claims_found", det["forbidden_claims_found"])
        if judge_data
        else det["forbidden_claims_found"]
    )

    result.update(
        {
            "answer": answer,
            "answer_latency_s": answer_response.get("_latency_s"),
            "ollama_prompt_tokens": answer_response.get("prompt_eval_count"),
            "ollama_completion_tokens": answer_response.get("eval_count"),
            "deterministic_score": det["score"],
            "deterministic_required_terms_score": det["required_terms_score"],
            "deterministic_hallucination_score": det["hallucination_score"],
            "judge_score": judge_score,
            "judge_rubric": judge_data,
            "judge_reason": judge_data.get("reason") if judge_data else None,
            "judge_error": judge_error,
            "missing_terms": missing_terms,
            "forbidden_claims_found": forbidden_claims,
            "accuracy_score": accuracy,
            "passed": passed,
        }
    )
    return result
625
+
626
+
627
def run_pairwise_judgments(
    cases: list[BenchmarkCase],
    results: list[dict[str, Any]],
    args: argparse.Namespace,
) -> list[dict[str, Any]]:
    """Have the judge compare the raw and compressed answers for each case.

    Nothing is judged when ``dry_run``, ``skip_judge`` or
    ``skip_pairwise_judge`` is set. A case is skipped unless both of its
    variants produced a non-empty answer without an answer error.

    Returns:
        One row per judged case with the case identity, judge latency and
        judge error, merged with the parsed verdict when one was obtained.
    """
    if args.dry_run or args.skip_judge or args.skip_pairwise_judge:
        return []

    # Index only the rows that hold a usable answer.
    usable: dict[tuple[Any, Any], dict[str, Any]] = {}
    for entry in results:
        if entry.get("answer") and not entry.get("answer_error"):
            usable[(entry["case_id"], entry["variant"])] = entry

    verdicts: list[dict[str, Any]] = []
    for case in cases:
        raw_row = usable.get((case.case_id, "raw"))
        compressed_row = usable.get((case.case_id, "compressed"))
        if raw_row is None or compressed_row is None:
            continue
        prompt = pairwise_judge_prompt(case, raw_row["answer"], compressed_row["answer"])
        verdict, failure, elapsed = call_json_judge(
            args.endpoint,
            args.judge_model,
            prompt,
            args.timeout,
            args.judge_retries,
            parse_pairwise_response,
        )
        record: dict[str, Any] = {
            "case_id": case.case_id,
            "title": case.title,
            "judge_latency_s": elapsed,
            "judge_error": failure,
        }
        if verdict:
            record.update(verdict)
        verdicts.append(record)
    return verdicts
664
+
665
+
666
+ def summarize(
667
+ results: list[dict[str, Any]],
668
+ pairwise_judgments: list[dict[str, Any]] | None = None,
669
+ ) -> dict[str, Any]:
670
+ by_variant: dict[str, list[dict[str, Any]]] = {}
671
+ for row in results:
672
+ by_variant.setdefault(row["variant"], []).append(row)
673
+
674
+ summary: dict[str, Any] = {}
675
+ for variant, rows in by_variant.items():
676
+ token_key = (
677
+ "ollama_prompt_tokens"
678
+ if all(row.get("ollama_prompt_tokens") is not None for row in rows)
679
+ else "prompt_estimated_tokens"
680
+ )
681
+ tokens = [float(row[token_key]) for row in rows]
682
+ accuracy = [
683
+ float(row["accuracy_score"])
684
+ for row in rows
685
+ if row.get("accuracy_score") is not None
686
+ ]
687
+ summary[variant] = {
688
+ "cases": len(rows),
689
+ "token_metric": token_key,
690
+ "avg_prompt_tokens": statistics.mean(tokens),
691
+ "total_prompt_tokens": sum(tokens),
692
+ "avg_accuracy": statistics.mean(accuracy) if accuracy else None,
693
+ "pass_rate": (
694
+ sum(1 for row in rows if row.get("passed")) / len(rows)
695
+ if rows and rows[0].get("passed") is not None
696
+ else None
697
+ ),
698
+ }
699
+
700
+ if "raw" in summary and "compressed" in summary:
701
+ raw = summary["raw"]["avg_prompt_tokens"]
702
+ compressed = summary["compressed"]["avg_prompt_tokens"]
703
+ summary["comparison"] = {
704
+ "avg_token_reduction_pct": ((raw - compressed) / raw * 100.0) if raw else 0.0,
705
+ "accuracy_delta": (
706
+ summary["compressed"]["avg_accuracy"] - summary["raw"]["avg_accuracy"]
707
+ if summary["raw"]["avg_accuracy"] is not None
708
+ and summary["compressed"]["avg_accuracy"] is not None
709
+ else None
710
+ ),
711
+ }
712
+ if pairwise_judgments:
713
+ scored = [row for row in pairwise_judgments if not row.get("judge_error")]
714
+ if scored:
715
+ preserved = [
716
+ row
717
+ for row in scored
718
+ if bool(row.get("compression_preserved_accuracy", False))
719
+ ]
720
+ deltas = [float(row.get("quality_delta", 0.0)) for row in scored]
721
+ summary["pairwise_judge"] = {
722
+ "cases": len(scored),
723
+ "compressed_wins": sum(1 for row in scored if row.get("winner") == "compressed"),
724
+ "raw_wins": sum(1 for row in scored if row.get("winner") == "raw"),
725
+ "ties": sum(1 for row in scored if row.get("winner") == "tie"),
726
+ "preservation_rate": len(preserved) / len(scored),
727
+ "avg_quality_delta": statistics.mean(deltas),
728
+ }
729
+ return summary
730
+
731
+
732
def markdown_report(payload: dict[str, Any]) -> str:
    """Render the benchmark payload as a Markdown report.

    Args:
        payload: dict with ``generated_at``, ``endpoint``, ``model``,
            ``judge_model``, ``dry_run``, ``results`` and ``summary``, plus an
            optional ``pairwise_judgments`` list.

    Returns:
        Markdown text ending in a newline, with a summary table, optional
        comparison and pairwise-judge sections, a per-case table and an
        optional pairwise-details table.
    """

    def cell(value: Any, placeholder: str = "-") -> str:
        # Make a value safe for a table cell: a missing value becomes the
        # placeholder (not the text "None"), line breaks become spaces so the
        # row stays on one line, and pipes are escaped.
        text = "" if value is None else str(value)
        text = text.replace("\r\n", " ").replace("\n", " ").replace("|", "\\|")
        return text or placeholder

    def score(value: Any) -> str:
        return "-" if value is None else f"{float(value):.3f}"

    summary = payload["summary"]
    lines = [
        "# Context Compression Benchmark",
        "",
        f"Generated: {payload['generated_at']}",
        f"Endpoint: `{payload['endpoint']}`",
        f"Worker model: `{payload['model']}`",
        f"Judge model: `{payload['judge_model']}`",
        f"Dry run: `{payload['dry_run']}`",
        "",
        "## Summary",
        "",
        "| Variant | Token metric | Avg prompt tokens | Total prompt tokens | Avg accuracy | Pass rate |",
        "|---|---:|---:|---:|---:|---:|",
    ]
    for variant in ["raw", "compressed"]:
        data = summary.get(variant, {})
        pass_rate = data.get("pass_rate")
        lines.append(
            "| {variant} | {metric} | {avg_tokens:.1f} | {total_tokens:.0f} | {accuracy} | {pass_rate} |".format(
                variant=variant,
                metric=data.get("token_metric", "-"),
                avg_tokens=float(data.get("avg_prompt_tokens", 0.0)),
                total_tokens=float(data.get("total_prompt_tokens", 0.0)),
                accuracy=score(data.get("avg_accuracy")),
                pass_rate="-" if pass_rate is None else f"{float(pass_rate) * 100:.1f}%",
            )
        )
    comparison = summary.get("comparison", {})
    if comparison:
        lines.extend(
            [
                "",
                "## Comparison",
                "",
                f"- Average token reduction: `{comparison['avg_token_reduction_pct']:.1f}%`",
                "- Accuracy delta: `{}`".format(score(comparison.get("accuracy_delta"))),
            ]
        )
    pairwise_summary = summary.get("pairwise_judge", {})
    if pairwise_summary:
        lines.extend(
            [
                "",
                "## Pairwise Judge",
                "",
                f"- Cases judged: `{pairwise_summary['cases']}`",
                f"- Compressed wins: `{pairwise_summary['compressed_wins']}`",
                f"- Raw wins: `{pairwise_summary['raw_wins']}`",
                f"- Ties: `{pairwise_summary['ties']}`",
                f"- Accuracy preservation rate: `{pairwise_summary['preservation_rate'] * 100:.1f}%`",
                f"- Average quality delta: `{pairwise_summary['avg_quality_delta']:.3f}`",
            ]
        )
    lines.extend(
        [
            "",
            "## Case Details",
            "",
            "| Case | Variant | Estimated prompt tokens | Ollama prompt tokens | Accuracy | Missing terms | Forbidden claims | Passed | Error |",
            "|---|---|---:|---:|---:|---|---|---|---|",
        ]
    )
    for row in payload["results"]:
        error = row.get("answer_error") or row.get("judge_error") or ""
        # str() guards against a judge returning non-string list items.
        missing = ", ".join(str(term) for term in (row.get("missing_terms") or []))
        forbidden = ", ".join(str(claim) for claim in (row.get("forbidden_claims_found") or []))
        lines.append(
            "| {case} | {variant} | {estimate} | {ollama} | {accuracy} | {missing} | {forbidden} | {passed} | {error} |".format(
                case=row["case_id"],
                variant=row["variant"],
                estimate=row["prompt_estimated_tokens"],
                ollama=row.get("ollama_prompt_tokens") or "-",
                accuracy=score(row.get("accuracy_score")),
                missing=cell(missing),
                forbidden=cell(forbidden),
                passed=row.get("passed"),
                # This column stays empty (no placeholder) when there is no error.
                error=cell(error, placeholder=""),
            )
        )
    if payload.get("pairwise_judgments"):
        lines.extend(
            [
                "",
                "## Pairwise Details",
                "",
                "| Case | Winner | Raw score | Compressed score | Delta | Preserved | Reason | Error |",
                "|---|---|---:|---:|---:|---|---|---|",
            ]
        )
        for row in payload["pairwise_judgments"]:
            lines.append(
                "| {case} | {winner} | {raw_score} | {compressed_score} | {delta} | {preserved} | {reason} | {error} |".format(
                    case=row["case_id"],
                    winner=row.get("winner", "-"),
                    raw_score=score(row.get("raw_score")),
                    compressed_score=score(row.get("compressed_score")),
                    delta=score(row.get("quality_delta")),
                    preserved=row.get("compression_preserved_accuracy"),
                    reason=cell(row.get("reason")),
                    # judge_error is stored as None on success; show "-" then.
                    error=cell(row.get("judge_error")),
                )
            )
    return "\n".join(lines) + "\n"
861
+
862
+
863
def main() -> int:
    """Run the benchmark from the command line and write both reports.

    Every built-in case is run with its raw and its compressed context, the
    pairwise judge is run where applicable, and the payload is written to
    ``context_compression_benchmark.json`` and ``.md`` inside ``--output-dir``
    (created if needed). The summary is also printed as JSON.

    Returns:
        Always ``0``.
    """
    cli = argparse.ArgumentParser(
        description="Benchmark raw vs compressed context token use and answer accuracy."
    )
    option_table: list[tuple[str, dict[str, Any]]] = [
        ("--endpoint", {"default": DEFAULT_ENDPOINT}),
        ("--model", {"default": DEFAULT_MODEL}),
        ("--judge-model", {"default": DEFAULT_MODEL}),
        ("--timeout", {"type": int, "default": 180}),
        ("--output-dir", {"default": "benchmark_results"}),
        ("--dry-run", {"action": "store_true"}),
        ("--skip-judge", {"action": "store_true"}),
        ("--skip-pairwise-judge", {"action": "store_true"}),
        ("--judge-retries", {"type": int, "default": 1}),
        # NOTE(review): this option is parsed, but no function in this module
        # reads args.pass_threshold; pass/fail uses a fixed 0.80 cut-off.
        ("--pass-threshold", {"type": float, "default": 0.75}),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    args = cli.parse_args()

    cases = build_cases()
    results = []
    for case in cases:
        for label, context in (
            ("raw", case.raw_context),
            ("compressed", case.compressed_context),
        ):
            results.append(run_variant(case, label, context, args))
    pairwise_judgments = run_pairwise_judgments(cases, results, args)

    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "endpoint": args.endpoint,
        "model": args.model,
        "judge_model": args.judge_model,
        "dry_run": args.dry_run,
        "results": results,
        "pairwise_judgments": pairwise_judgments,
        "summary": summarize(results, pairwise_judgments),
    }

    target_dir = Path(args.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    json_path = target_dir / "context_compression_benchmark.json"
    md_path = target_dir / "context_compression_benchmark.md"
    json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    md_path.write_text(markdown_report(payload), encoding="utf-8")

    for written in (json_path, md_path):
        print(f"Wrote {written}")
    print(json.dumps(payload["summary"], indent=2))
    return 0
908
+
909
+
910
+ if __name__ == "__main__":
911
+ raise SystemExit(main())
@@ -0,0 +1,258 @@
1
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
2
+ from typing import Any, Optional, Dict, List
3
+ from datetime import datetime, timezone
4
+ import uuid
5
+ import json
6
+
7
def utcnow() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)
9
+
10
# Kinds of memory item accepted by MNMemoryItem.type.
VALID_TYPES = {
    "Fact",
    "Hypothesis",
    "Task",
    "Decision",
    "Evidence",
    "Constraint",
}

# Every lifecycle status accepted by MNMemoryItem.status. The first four are
# the states of the generic lifecycle; the remaining four appear only in the
# hypothesis lifecycle below.
VALID_STATUSES = {
    "draft", "validated", "used", "archived",
    "hypothesis", "tested", "confirmed", "rejected",
}

# Allowed status moves for every item type except "Hypothesis":
# current status -> set of statuses it may change to.
# "archived" maps to the empty set, so nothing can leave it.
GENERIC_TRANSITIONS = {
    "draft": {"validated", "archived"},
    "validated": {"used", "archived"},
    "used": {"archived"},
    "archived": set(),
}

# Allowed status moves for "Hypothesis" items, in the same shape as above.
# Both "draft" and "hypothesis" may move to "tested"; a tested hypothesis is
# then either confirmed or rejected, and every state can be archived.
HYPOTHESIS_TRANSITIONS = {
    "draft": {"tested", "archived"},
    "hypothesis": {"tested", "archived"},
    "tested": {"confirmed", "rejected", "archived"},
    "confirmed": {"archived"},
    "rejected": {"archived"},
    "archived": set(),
}
37
+
38
class MNMemoryItem(BaseModel):
    """A single typed, versioned entry held in working memory.

    Validation runs both at construction time and on every attribute
    assignment (``validate_assignment=True``), so an invalid ``type``,
    ``status``, ``confidence``, ``source`` or ``content`` raises a pydantic
    ``ValidationError`` in either case.
    """

    model_config = ConfigDict(validate_assignment=True)

    # Unique identifier; a random UUID4 string unless the caller supplies one.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # One of VALID_TYPES.
    type: str
    # Structured payload; must be a dict (enforced by validate_content).
    content: Any
    # One of VALID_STATUSES.
    status: str
    # Confidence score in the closed interval [0.0, 1.0].
    confidence: float = 1.0
    # Non-blank name of whoever produced the item.
    source: str
    created_at: datetime = Field(default_factory=utcnow)
    updated_at: datetime = Field(default_factory=utcnow)
    # Optional expiry instant; None means the item never expires.
    expires_at: Optional[datetime] = None
    # Revision counter, starting at 1.
    version: int = 1

    @field_validator("type")
    @classmethod
    def validate_type(cls, value: str) -> str:
        """Reject any type that is not listed in VALID_TYPES."""
        if value not in VALID_TYPES:
            raise ValueError(f"invalid memory type: {value}")
        return value

    @field_validator("status")
    @classmethod
    def validate_status(cls, value: str) -> str:
        """Reject any status that is not listed in VALID_STATUSES."""
        if value not in VALID_STATUSES:
            raise ValueError(f"invalid memory status: {value}")
        return value

    @field_validator("confidence")
    @classmethod
    def validate_confidence(cls, value: float) -> float:
        """Require confidence to lie within [0.0, 1.0] (NaN is rejected too)."""
        if not 0.0 <= value <= 1.0:
            raise ValueError("confidence must be between 0.0 and 1.0")
        return value

    @field_validator("source")
    @classmethod
    def validate_source(cls, value: str) -> str:
        """Reject an empty or whitespace-only source."""
        if not value.strip():
            raise ValueError("source must not be empty")
        return value

    @field_validator("content")
    @classmethod
    def validate_content(cls, value: Any) -> Any:
        """Require content to be a dict.

        This is a field validator rather than a ``mode="after"`` model
        validator so that, with ``validate_assignment=True``, a bad value is
        rejected *before* it is stored. A model-level "after" hook only runs
        once the assignment has already been applied, which would leave the
        item holding the rejected non-dict content.
        """
        if not isinstance(value, dict):
            raise ValueError("content must be a structured object")
        return value
85
+
86
class MNMemoryEdge(BaseModel):
    """A labelled link between two memory items, identified by their ids."""

    source_id: str
    target_id: str
    # Label describing how the two items relate; must not be blank.
    relation: str

    @field_validator("relation")
    @classmethod
    def validate_relation(cls, value: str) -> str:
        """Reject an empty or whitespace-only relation label."""
        stripped = value.strip()
        if stripped:
            return value
        raise ValueError("relation must not be empty")
97
+
98
class MNWorkingMemory:
    """In-memory store of memory items and the labelled edges between them.

    Items are kept in a dict keyed by item id; edges are kept in an
    append-only list. Expired items stay in the store but are hidden from
    ``get``, ``query`` and ``get_context``.
    """

    def __init__(self):
        self._items: Dict[str, MNMemoryItem] = {}
        self._edges: List[MNMemoryEdge] = []

    def add(self, item: MNMemoryItem) -> None:
        """Store ``item`` under its id, replacing any item with the same id."""
        self._items[item.id] = item

    def update(self, item_id: str, updates: dict) -> None:
        """Apply ``updates`` to a stored item and bump its version.

        Unknown ``item_id`` values are ignored. The special key
        ``expected_version`` is not a field: when present and not None it
        must equal the item's current version. Keys that are not declared
        fields of the item are ignored. The caller's dict is never mutated.

        Raises:
            ValueError: on a version conflict, an attempt to change the
                item's id, an illegal status transition, or a value that
                fails the item's validation (pydantic's ValidationError is a
                ValueError). In every case the item is left untouched.
        """
        item = self._items.get(item_id)
        if item is None:
            return

        pending = dict(updates)
        expected_version = pending.pop("expected_version", None)
        if expected_version is not None and expected_version != item.version:
            raise ValueError(f"version conflict: item has {item.version}, expected {expected_version}")

        # The id doubles as the key in self._items; letting it change would
        # leave the item filed under a stale key.
        if "id" in pending and pending["id"] != item_id:
            raise ValueError(f"item id is immutable: {item_id} -> {pending['id']}")

        if "status" in pending:
            self._validate_transition(item.type, item.status, pending["status"])

        # Only declared model fields are assignable. hasattr() would also
        # match methods and properties, which setattr() then rejects.
        declared = type(item).model_fields
        changes = {key: value for key, value in pending.items() if key in declared}

        # Validate the merged state once, up front, so that a bad value in
        # the middle of `changes` cannot leave the item half-updated.
        if changes:
            type(item).model_validate({**item.model_dump(), **changes})

        for key, value in changes.items():
            setattr(item, key, value)
        item.updated_at = utcnow()
        item.version += 1

    def get(self, item_id: str) -> Optional[MNMemoryItem]:
        """Return the item with ``item_id``, or None if missing or expired."""
        item = self._items.get(item_id)
        if item is None or self._is_expired(item):
            return None
        return item

    def query(self, filters: dict) -> List[MNMemoryItem]:
        """Return non-expired items whose attributes equal every filter value.

        Each key of ``filters`` is read from the item with ``getattr`` and a
        missing attribute is treated as None. An empty ``filters`` dict
        matches every non-expired item.
        """
        matches = []
        for item in self._items.values():
            if self._is_expired(item):
                continue
            if any(getattr(item, key, None) != value for key, value in filters.items()):
                continue
            matches.append(item)
        return matches

    def link(self, source_id: str, target_id: str, relation: str) -> None:
        """Record an edge between two stored items.

        Raises:
            KeyError: if either endpoint is not in the store.
            ValueError: if ``relation`` is blank (raised by MNMemoryEdge).
        """
        if source_id not in self._items:
            raise KeyError(f"source item not found: {source_id}")
        if target_id not in self._items:
            raise KeyError(f"target item not found: {target_id}")
        self._edges.append(
            MNMemoryEdge(source_id=source_id, target_id=target_id, relation=relation)
        )

    def invalidate(self, item_id: str) -> None:
        """Archive a stored item and bump its version; unknown ids are ignored.

        Raises:
            ValueError: if the item's current status may not move to
                "archived".
        """
        item = self._items.get(item_id)
        if item is None:
            return
        self._validate_transition(item.type, item.status, "archived")
        item.status = "archived"
        item.updated_at = utcnow()
        item.version += 1

    def get_context(self, agent_role: str, goal_id: str) -> List[MNMemoryItem]:
        """
        Retrieves selective context for an agent around a goal.
        Items must be visible to the role and either match the goal directly
        or sit within two graph hops of a matching item.

        An item matches the goal when its id equals ``goal_id`` or its
        content carries ``"goal_id": goal_id``. Archived and expired items
        are never used as seeds and never returned.
        """
        seeds = {
            item.id
            for item in self._items.values()
            if not self._is_expired(item)
            and item.status != "archived"
            and (item.id == goal_id or item.content.get("goal_id") == goal_id)
        }
        reachable = self._expand_graph(seeds, max_depth=2)

        return [
            item
            for item in self._items.values()
            if item.id in reachable
            and item.status != "archived"
            and not self._is_expired(item)
            and self._is_visible_to_role(agent_role, item)
        ]

    @staticmethod
    def _is_expired(item: MNMemoryItem) -> bool:
        """Return True when the item has an expiry instant that has passed."""
        expires_at = item.expires_at
        if expires_at is None:
            return False
        # A naive expiry (e.g. parsed from an ISO string without an offset)
        # cannot be compared with the aware utcnow() and would raise
        # TypeError. Interpret it as UTC, matching the timestamps this
        # module generates itself.
        if expires_at.tzinfo is None or expires_at.utcoffset() is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return utcnow() > expires_at

    @staticmethod
    def _validate_transition(item_type: str, current: str, new_status: str) -> None:
        """Raise ValueError unless ``current -> new_status`` is permitted.

        Staying in the same status is always allowed. "Hypothesis" items use
        HYPOTHESIS_TRANSITIONS; every other type uses GENERIC_TRANSITIONS. A
        current status absent from the chosen table has no legal moves.
        """
        if current == new_status:
            return
        transitions = HYPOTHESIS_TRANSITIONS if item_type == "Hypothesis" else GENERIC_TRANSITIONS
        if new_status not in transitions.get(current, set()):
            raise ValueError(f"invalid status transition for {item_type}: {current} -> {new_status}")

    @staticmethod
    def _is_visible_to_role(agent_role: str, item: MNMemoryItem) -> bool:
        """Decide whether ``agent_role`` may see ``item``.

        A dict under ``content["acl"]`` is consulted first: ``deny_roles``
        hides the item, and a non-None ``allow_roles`` decides visibility by
        membership. Otherwise built-in role rules apply: planners do not see
        Evidence, executors do not see Hypothesis items, reviewers see
        everything, and any other role sees only items whose source equals
        the role name.
        """
        acl = item.content.get("acl")
        if isinstance(acl, dict):
            # `or []` also covers an explicit null, which would otherwise
            # make the membership test raise TypeError.
            deny_roles = acl.get("deny_roles") or []
            if agent_role in deny_roles:
                return False
            allow_roles = acl.get("allow_roles")
            if allow_roles is not None:
                return agent_role in allow_roles

        if agent_role == "planner":
            return item.type != "Evidence"
        if agent_role == "executor":
            return item.type != "Hypothesis"
        if agent_role == "reviewer":
            return True
        return item.source == agent_role

    def _expand_graph(self, seed_ids: set[str], max_depth: int) -> set[str]:
        """Return the seeds plus every id within ``max_depth`` hops of them.

        Edges are followed in both directions. Expansion proceeds one hop
        per round and stops early once a round discovers nothing new.
        """
        reachable = set(seed_ids)
        frontier = set(seed_ids)
        for _ in range(max_depth):
            next_frontier = set()
            for edge in self._edges:
                if edge.source_id in frontier and edge.target_id not in reachable:
                    next_frontier.add(edge.target_id)
                if edge.target_id in frontier and edge.source_id not in reachable:
                    next_frontier.add(edge.source_id)
            if not next_frontier:
                break
            reachable.update(next_frontier)
            frontier = next_frontier
        return reachable

    # --- Methods for compatibility with MicroNeuron payloads (Serialization) ---
    def to_dict(self) -> dict:
        """Return a JSON-compatible dict with "items" and "edges" lists."""
        return {
            "items": [item.model_dump(mode="json") for item in self._items.values()],
            "edges": [edge.model_dump(mode="json") for edge in self._edges],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "MNWorkingMemory":
        """Rebuild a working memory from the structure ``to_dict`` produces.

        Missing or null "items" / "edges" entries are treated as empty.
        Edges are appended as given, without checking that their endpoints
        exist among the loaded items.
        """
        wm = cls()
        for item_data in data.get("items") or []:
            wm.add(MNMemoryItem.model_validate(item_data))
        for edge_data in data.get("edges") or []:
            wm._edges.append(MNMemoryEdge.model_validate(edge_data))
        return wm

    def to_json(self) -> str:
        """Serialize the working memory to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "MNWorkingMemory":
        """Rebuild a working memory from a string produced by ``to_json``."""
        return cls.from_dict(json.loads(json_str))
253
+
254
+
255
# Backward-compatible aliases for older SDK users: each un-prefixed name is
# bound to the same class object as its MN-prefixed counterpart.
MemoryItem = MNMemoryItem
MemoryEdge = MNMemoryEdge
WorkingMemory = MNWorkingMemory