memory-arena 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. memory_arena/__init__.py +19 -0
  2. memory_arena/arena/__init__.py +0 -0
  3. memory_arena/arena/engine.py +252 -0
  4. memory_arena/benchmark/__init__.py +15 -0
  5. memory_arena/benchmark/evaluator.py +315 -0
  6. memory_arena/benchmark/questions.py +150 -0
  7. memory_arena/benchmark/recall_lab.py +120 -0
  8. memory_arena/benchmark/recall_metrics.py +119 -0
  9. memory_arena/benchmark/runner.py +694 -0
  10. memory_arena/chatbot/__init__.py +0 -0
  11. memory_arena/chatbot/api.py +268 -0
  12. memory_arena/chatbot/router.py +12 -0
  13. memory_arena/chatbot/session.py +73 -0
  14. memory_arena/cli.py +439 -0
  15. memory_arena/data/longmemeval-s/processed/questions.jsonl +16 -0
  16. memory_arena/data/longmemeval-s/processed/sessions.jsonl +82 -0
  17. memory_arena/data/results_snapshot/longmemeval-s_bm25_seed0.json +1559 -0
  18. memory_arena/data/results_snapshot/longmemeval-s_bm25_seed1.json +1559 -0
  19. memory_arena/data/results_snapshot/longmemeval-s_bm25_seed2.json +1559 -0
  20. memory_arena/data/results_snapshot/longmemeval-s_bm25_summary.json +90 -0
  21. memory_arena/data/results_snapshot/longmemeval-s_cognee_seed0.json +1386 -0
  22. memory_arena/data/results_snapshot/longmemeval-s_cognee_summary.json +86 -0
  23. memory_arena/data/results_snapshot/longmemeval-s_full_context_seed0.json +2098 -0
  24. memory_arena/data/results_snapshot/longmemeval-s_full_context_summary.json +81 -0
  25. memory_arena/data/results_snapshot/longmemeval-s_graphiti_seed0.json +1444 -0
  26. memory_arena/data/results_snapshot/longmemeval-s_graphiti_summary.json +86 -0
  27. memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_seed0.json +1546 -0
  28. memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_seed1.json +1546 -0
  29. memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_seed2.json +1546 -0
  30. memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_summary.json +90 -0
  31. memory_arena/data/results_snapshot/longmemeval-s_hyde_seed0.json +1540 -0
  32. memory_arena/data/results_snapshot/longmemeval-s_hyde_seed1.json +1540 -0
  33. memory_arena/data/results_snapshot/longmemeval-s_hyde_seed2.json +1537 -0
  34. memory_arena/data/results_snapshot/longmemeval-s_hyde_summary.json +90 -0
  35. memory_arena/data/results_snapshot/longmemeval-s_karpathy_llm_wiki_seed0.json +1496 -0
  36. memory_arena/data/results_snapshot/longmemeval-s_karpathy_llm_wiki_summary.json +86 -0
  37. memory_arena/data/results_snapshot/longmemeval-s_langmem_seed0.json +1386 -0
  38. memory_arena/data/results_snapshot/longmemeval-s_langmem_seed1.json +1386 -0
  39. memory_arena/data/results_snapshot/longmemeval-s_langmem_seed2.json +1386 -0
  40. memory_arena/data/results_snapshot/longmemeval-s_langmem_summary.json +90 -0
  41. memory_arena/data/results_snapshot/longmemeval-s_mem0_seed0.json +1453 -0
  42. memory_arena/data/results_snapshot/longmemeval-s_mem0_seed1.json +1455 -0
  43. memory_arena/data/results_snapshot/longmemeval-s_mem0_seed2.json +1453 -0
  44. memory_arena/data/results_snapshot/longmemeval-s_mem0_summary.json +90 -0
  45. memory_arena/data/results_snapshot/longmemeval-s_mem0g_seed0.json +1457 -0
  46. memory_arena/data/results_snapshot/longmemeval-s_mem0g_summary.json +86 -0
  47. memory_arena/data/results_snapshot/longmemeval-s_memori_seed0.json +1386 -0
  48. memory_arena/data/results_snapshot/longmemeval-s_memori_summary.json +86 -0
  49. memory_arena/data/results_snapshot/longmemeval-s_naive_vector_seed0.json +1540 -0
  50. memory_arena/data/results_snapshot/longmemeval-s_naive_vector_seed1.json +1540 -0
  51. memory_arena/data/results_snapshot/longmemeval-s_naive_vector_seed2.json +1538 -0
  52. memory_arena/data/results_snapshot/longmemeval-s_naive_vector_summary.json +90 -0
  53. memory_arena/data/results_snapshot/longmemeval-s_persona_profile_seed0.json +1537 -0
  54. memory_arena/data/results_snapshot/longmemeval-s_persona_profile_seed1.json +1538 -0
  55. memory_arena/data/results_snapshot/longmemeval-s_persona_profile_seed2.json +1540 -0
  56. memory_arena/data/results_snapshot/longmemeval-s_persona_profile_summary.json +90 -0
  57. memory_arena/data/results_snapshot/longmemeval-s_raptor_seed0.json +1521 -0
  58. memory_arena/data/results_snapshot/longmemeval-s_raptor_seed1.json +1521 -0
  59. memory_arena/data/results_snapshot/longmemeval-s_raptor_seed2.json +1521 -0
  60. memory_arena/data/results_snapshot/longmemeval-s_raptor_summary.json +90 -0
  61. memory_arena/data/results_snapshot/longmemeval-s_recency_window_seed0.json +1786 -0
  62. memory_arena/data/results_snapshot/longmemeval-s_recency_window_seed1.json +1786 -0
  63. memory_arena/data/results_snapshot/longmemeval-s_recency_window_seed2.json +1786 -0
  64. memory_arena/data/results_snapshot/longmemeval-s_recency_window_summary.json +90 -0
  65. memory_arena/data/results_snapshot/longmemeval-s_reflection_seed0.json +1539 -0
  66. memory_arena/data/results_snapshot/longmemeval-s_reflection_seed1.json +1539 -0
  67. memory_arena/data/results_snapshot/longmemeval-s_reflection_seed2.json +1541 -0
  68. memory_arena/data/results_snapshot/longmemeval-s_reflection_summary.json +90 -0
  69. memory_arena/exceptions.py +33 -0
  70. memory_arena/graph/__init__.py +0 -0
  71. memory_arena/graph/analyzer.py +143 -0
  72. memory_arena/graph/cypher_generator.py +115 -0
  73. memory_arena/graph/cypher_templates.py +106 -0
  74. memory_arena/graph/extractor.py +338 -0
  75. memory_arena/graph/neo4j_store.py +152 -0
  76. memory_arena/graph/resolver.py +77 -0
  77. memory_arena/graph/schema.py +55 -0
  78. memory_arena/llm/__init__.py +0 -0
  79. memory_arena/llm/client.py +327 -0
  80. memory_arena/llm/openrouter.py +122 -0
  81. memory_arena/llm/providers.py +252 -0
  82. memory_arena/models/__init__.py +1 -0
  83. memory_arena/models/api.py +61 -0
  84. memory_arena/models/benchmark.py +172 -0
  85. memory_arena/models/document.py +66 -0
  86. memory_arena/models/graph.py +51 -0
  87. memory_arena/models/retrieval.py +33 -0
  88. memory_arena/paths.py +90 -0
  89. memory_arena/py.typed +0 -0
  90. memory_arena/sessions/__init__.py +23 -0
  91. memory_arena/sessions/loaders.py +210 -0
  92. memory_arena/sessions/schema.py +110 -0
  93. memory_arena/settings.py +246 -0
  94. memory_arena/static/404/index.html +1 -0
  95. memory_arena/static/404.html +1 -0
  96. memory_arena/static/_next/static/chunks/117-020b9cef866aefe5.js +2 -0
  97. memory_arena/static/_next/static/chunks/648-77f86c7bea515da8.js +1 -0
  98. memory_arena/static/_next/static/chunks/app/_not-found/page-c59625541a56b4fe.js +1 -0
  99. memory_arena/static/_next/static/chunks/app/arena/page-410e2e6b1ec69f52.js +1 -0
  100. memory_arena/static/_next/static/chunks/app/benchmark/page-76db067e901c05a6.js +1 -0
  101. memory_arena/static/_next/static/chunks/app/layout-82e84594976f899d.js +1 -0
  102. memory_arena/static/_next/static/chunks/app/page-e07ee36911d8bec3.js +1 -0
  103. memory_arena/static/_next/static/chunks/app/recall-lab/page-5e983cc6e76fcbab.js +1 -0
  104. memory_arena/static/_next/static/chunks/fd9d1056-40eb70ab657cb18b.js +1 -0
  105. memory_arena/static/_next/static/chunks/framework-f66176bb897dc684.js +1 -0
  106. memory_arena/static/_next/static/chunks/main-app-ce09194c2205bef4.js +1 -0
  107. memory_arena/static/_next/static/chunks/main-f93c663012e61193.js +1 -0
  108. memory_arena/static/_next/static/chunks/pages/_app-72b849fbd24ac258.js +1 -0
  109. memory_arena/static/_next/static/chunks/pages/_error-7ba65e1336b92748.js +1 -0
  110. memory_arena/static/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
  111. memory_arena/static/_next/static/chunks/webpack-03f7c6bc932ce1e3.js +1 -0
  112. memory_arena/static/_next/static/css/7f0f5d6971a0dc74.css +3 -0
  113. memory_arena/static/_next/static/wwMdP-USmLrBi4wuoZSsd/_buildManifest.js +1 -0
  114. memory_arena/static/_next/static/wwMdP-USmLrBi4wuoZSsd/_ssgManifest.js +1 -0
  115. memory_arena/static/arena/index.html +1 -0
  116. memory_arena/static/arena/index.txt +8 -0
  117. memory_arena/static/benchmark/index.html +1 -0
  118. memory_arena/static/benchmark/index.txt +9 -0
  119. memory_arena/static/favicon.ico +0 -0
  120. memory_arena/static/index.html +1 -0
  121. memory_arena/static/index.txt +9 -0
  122. memory_arena/static/recall-lab/index.html +1 -0
  123. memory_arena/static/recall-lab/index.txt +9 -0
  124. memory_arena/strategies/__init__.py +190 -0
  125. memory_arena/strategies/amem.py +511 -0
  126. memory_arena/strategies/base.py +151 -0
  127. memory_arena/strategies/bm25.py +125 -0
  128. memory_arena/strategies/cognee.py +242 -0
  129. memory_arena/strategies/embeddings.py +62 -0
  130. memory_arena/strategies/full_context.py +87 -0
  131. memory_arena/strategies/graphiti.py +317 -0
  132. memory_arena/strategies/graphiti_falkor.py +131 -0
  133. memory_arena/strategies/hipporag2.py +539 -0
  134. memory_arena/strategies/hybrid_rrf.py +132 -0
  135. memory_arena/strategies/hyde.py +85 -0
  136. memory_arena/strategies/karpathy_llm_wiki.py +423 -0
  137. memory_arena/strategies/langmem.py +176 -0
  138. memory_arena/strategies/mem0.py +255 -0
  139. memory_arena/strategies/mem0g.py +240 -0
  140. memory_arena/strategies/memori.py +237 -0
  141. memory_arena/strategies/naive_vector.py +164 -0
  142. memory_arena/strategies/persona_profile.py +139 -0
  143. memory_arena/strategies/quantum/__init__.py +24 -0
  144. memory_arena/strategies/quantum/circuits.py +111 -0
  145. memory_arena/strategies/quantum/qiss.py +299 -0
  146. memory_arena/strategies/quantum/sqr.py +188 -0
  147. memory_arena/strategies/quantum/utils.py +106 -0
  148. memory_arena/strategies/raptor.py +254 -0
  149. memory_arena/strategies/recency_window.py +80 -0
  150. memory_arena/strategies/reflection.py +152 -0
  151. memory_arena/tokenizer.py +22 -0
  152. memory_arena/viz/__init__.py +0 -0
  153. memory_arena-0.1.8.dist-info/METADATA +746 -0
  154. memory_arena-0.1.8.dist-info/RECORD +157 -0
  155. memory_arena-0.1.8.dist-info/WHEEL +4 -0
  156. memory_arena-0.1.8.dist-info/entry_points.txt +2 -0
  157. memory_arena-0.1.8.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,19 @@
1
+ """Memory Arena — Knowledge Base Benchmark. Find which retrieval architecture fits your data."""
2
+
3
+ __version__ = "0.1.8"
4
+
5
+ from memory_arena.models.benchmark import BenchmarkResult, Question
6
+ from memory_arena.models.document import Document, Section
7
+ from memory_arena.models.graph import Entity, Relationship
8
+ from memory_arena.strategies.base import Strategy
9
+
10
+ __all__ = [
11
+ "Document",
12
+ "Section",
13
+ "Entity",
14
+ "Relationship",
15
+ "Question",
16
+ "BenchmarkResult",
17
+ "Strategy",
18
+ "__version__",
19
+ ]
File without changes
@@ -0,0 +1,252 @@
1
+ """Arena engine - blind A/B strategy matchups with ELO rating."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import random
9
+ import time
10
+ from dataclasses import dataclass, field
11
+ from pathlib import Path
12
+ from uuid import uuid4
13
+
14
+ from memory_arena.settings import settings
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+ INITIAL_ELO = 1200.0
19
+ K_FACTOR = 32.0
20
+
21
+
22
@dataclass
class Match:
    """A single A/B matchup between two strategies.

    Holds the question, the two competing strategy names with their answers
    and per-side latency/cost figures, and the vote outcome once one has been
    recorded.
    """

    id: str
    question: str
    strategy_a: str
    strategy_b: str
    answer_a: str
    answer_b: str
    latency_a_ms: float = 0.0
    latency_b_ms: float = 0.0
    cost_a: float = 0.0
    cost_b: float = 0.0
    winner: str | None = None  # "a", "b", "tie", or None (pending)
    timestamp: float = 0.0
    # Each instance gets its own fresh list (default_factory), never a shared one.
    sources_a: list[str] = field(default_factory=list)
    sources_b: list[str] = field(default_factory=list)
40
+
41
+
42
+ @dataclass
43
+ class ArenaState:
44
+ """Persistent arena state with ELO ratings and match history."""
45
+
46
+ elo: dict[str, float] = field(default_factory=dict)
47
+ matches: list[Match] = field(default_factory=list)
48
+ total_votes: int = 0
49
+
50
+ def save(self, path: Path) -> None:
51
+ path.parent.mkdir(parents=True, exist_ok=True)
52
+ data = {
53
+ "elo": self.elo,
54
+ "total_votes": self.total_votes,
55
+ "matches": [
56
+ {
57
+ "id": m.id,
58
+ "question": m.question,
59
+ "strategy_a": m.strategy_a,
60
+ "strategy_b": m.strategy_b,
61
+ "answer_a": m.answer_a[:500], # truncate for storage
62
+ "answer_b": m.answer_b[:500],
63
+ "latency_a_ms": m.latency_a_ms,
64
+ "latency_b_ms": m.latency_b_ms,
65
+ "cost_a": m.cost_a,
66
+ "cost_b": m.cost_b,
67
+ "winner": m.winner,
68
+ "timestamp": m.timestamp,
69
+ }
70
+ for m in self.matches[-200:] # keep last 200
71
+ ],
72
+ }
73
+ path.write_text(json.dumps(data, indent=2))
74
+
75
+ @classmethod
76
+ def load(cls, path: Path) -> ArenaState:
77
+ if not path.exists():
78
+ return cls()
79
+ try:
80
+ data = json.loads(path.read_text())
81
+ state = cls(
82
+ elo=data.get("elo", {}),
83
+ total_votes=data.get("total_votes", 0),
84
+ )
85
+ for m in data.get("matches", []):
86
+ state.matches.append(
87
+ Match(
88
+ id=m["id"],
89
+ question=m["question"],
90
+ strategy_a=m["strategy_a"],
91
+ strategy_b=m["strategy_b"],
92
+ answer_a=m.get("answer_a", ""),
93
+ answer_b=m.get("answer_b", ""),
94
+ latency_a_ms=m.get("latency_a_ms", 0),
95
+ latency_b_ms=m.get("latency_b_ms", 0),
96
+ cost_a=m.get("cost_a", 0),
97
+ cost_b=m.get("cost_b", 0),
98
+ winner=m.get("winner"),
99
+ timestamp=m.get("timestamp", 0),
100
+ )
101
+ )
102
+ return state
103
+ except (json.JSONDecodeError, KeyError):
104
+ log.warning("Corrupt arena state, starting fresh")
105
+ return cls()
106
+
107
+
108
+ class ArenaEngine:
109
+ """Manages blind A/B matches between retrieval strategies."""
110
+
111
+ def __init__(self, strategies: dict) -> None:
112
+ self.strategies = strategies
113
+ self._state_path = Path(settings.results_path) / "arena_state.json"
114
+ self.state = ArenaState.load(self._state_path)
115
+ # Initialize ELO for new strategies
116
+ for name in strategies:
117
+ if name not in self.state.elo:
118
+ self.state.elo[name] = INITIAL_ELO
119
+
120
+ async def create_match(self, question: str, corpus: str = "") -> Match:
121
+ """Pick two random strategies, query both, return a blind match."""
122
+ names = list(self.strategies.keys())
123
+ if len(names) < 2:
124
+ raise ValueError("Need at least 2 strategies for arena mode")
125
+ a_name, b_name = random.sample(names, 2)
126
+
127
+ result_a, result_b = await asyncio.gather(
128
+ self.strategies[a_name].query(question),
129
+ self.strategies[b_name].query(question),
130
+ )
131
+
132
+ match = Match(
133
+ id=uuid4().hex[:8],
134
+ question=question,
135
+ strategy_a=a_name,
136
+ strategy_b=b_name,
137
+ answer_a=result_a.answer,
138
+ answer_b=result_b.answer,
139
+ latency_a_ms=result_a.latency_ms,
140
+ latency_b_ms=result_b.latency_ms,
141
+ cost_a=result_a.cost_usd,
142
+ cost_b=result_b.cost_usd,
143
+ sources_a=result_a.sources,
144
+ sources_b=result_b.sources,
145
+ timestamp=time.time(),
146
+ )
147
+ self.state.matches.append(match)
148
+ return match
149
+
150
+ def vote(self, match_id: str, winner: str) -> dict:
151
+ """Record a vote and update ELO. winner: 'a', 'b', or 'tie'."""
152
+ if winner not in ("a", "b", "tie"):
153
+ return {"error": f"Invalid winner: {winner}. Must be 'a', 'b', or 'tie'"}
154
+
155
+ match = next((m for m in self.state.matches if m.id == match_id), None)
156
+ if not match:
157
+ return {"error": "Match not found"}
158
+ if match.winner is not None:
159
+ return {"error": "Match already voted on"}
160
+
161
+ match.winner = winner
162
+ self.state.total_votes += 1
163
+ self._update_elo(match)
164
+ self.state.save(self._state_path)
165
+ self._append_vote_jsonl(match)
166
+
167
+ return {
168
+ "strategy_a": match.strategy_a,
169
+ "strategy_b": match.strategy_b,
170
+ "winner": winner,
171
+ "elo": dict(self.state.elo),
172
+ "total_votes": self.state.total_votes,
173
+ }
174
+
175
+ def _update_elo(self, match: Match) -> None:
176
+ """Standard ELO rating update."""
177
+ ea = self.state.elo.get(match.strategy_a, INITIAL_ELO)
178
+ eb = self.state.elo.get(match.strategy_b, INITIAL_ELO)
179
+ expected_a = 1.0 / (1.0 + 10.0 ** ((eb - ea) / 400.0))
180
+
181
+ if match.winner == "a":
182
+ score_a = 1.0
183
+ elif match.winner == "b":
184
+ score_a = 0.0
185
+ else: # tie
186
+ score_a = 0.5
187
+
188
+ self.state.elo[match.strategy_a] = ea + K_FACTOR * (score_a - expected_a)
189
+ self.state.elo[match.strategy_b] = eb + K_FACTOR * ((1 - score_a) - (1 - expected_a))
190
+
191
+ def leaderboard(self) -> list[dict]:
192
+ """Return strategies sorted by ELO rating."""
193
+ board = []
194
+ for name, elo in self.state.elo.items():
195
+ wins = sum(
196
+ 1
197
+ for m in self.state.matches
198
+ if m.winner
199
+ and (
200
+ (m.strategy_a == name and m.winner == "a")
201
+ or (m.strategy_b == name and m.winner == "b")
202
+ )
203
+ )
204
+ losses = sum(
205
+ 1
206
+ for m in self.state.matches
207
+ if m.winner
208
+ and (
209
+ (m.strategy_a == name and m.winner == "b")
210
+ or (m.strategy_b == name and m.winner == "a")
211
+ )
212
+ )
213
+ ties = sum(
214
+ 1
215
+ for m in self.state.matches
216
+ if m.winner == "tie" and (m.strategy_a == name or m.strategy_b == name)
217
+ )
218
+ board.append(
219
+ {
220
+ "strategy": name,
221
+ "elo": round(elo, 1),
222
+ "wins": wins,
223
+ "losses": losses,
224
+ "ties": ties,
225
+ "matches": wins + losses + ties,
226
+ }
227
+ )
228
+ return sorted(board, key=lambda x: x["elo"], reverse=True)
229
+
230
+ def _append_vote_jsonl(self, match: Match) -> None:
231
+ """Append-only JSONL log of all votes (survives state resets)."""
232
+ jsonl_path = self._state_path.parent / "arena_votes.jsonl"
233
+ jsonl_path.parent.mkdir(parents=True, exist_ok=True)
234
+ record = {
235
+ "match_id": match.id,
236
+ "question": match.question[:200],
237
+ "strategy_a": match.strategy_a,
238
+ "strategy_b": match.strategy_b,
239
+ "winner": match.winner,
240
+ "latency_a_ms": round(match.latency_a_ms, 1),
241
+ "latency_b_ms": round(match.latency_b_ms, 1),
242
+ "cost_a": match.cost_a,
243
+ "cost_b": match.cost_b,
244
+ "timestamp": match.timestamp,
245
+ "elo_snapshot": {k: round(v, 1) for k, v in self.state.elo.items()},
246
+ }
247
+ with open(jsonl_path, "a") as f:
248
+ f.write(json.dumps(record) + "\n")
249
+
250
+ def get_pending_match(self, match_id: str) -> Match | None:
251
+ """Get a match by ID."""
252
+ return next((m for m in self.state.matches if m.id == match_id), None)
@@ -0,0 +1,15 @@
1
+ """Memory Arena benchmark engine."""
2
+
3
+ from memory_arena.benchmark.evaluator import (
4
+ MemoryScore,
5
+ evaluate_memory_answer,
6
+ )
7
+ from memory_arena.benchmark.questions import load_memory_questions
8
+ from memory_arena.benchmark.recall_metrics import compute_memory_recall_metrics
9
+
10
+ __all__ = [
11
+ "MemoryScore",
12
+ "compute_memory_recall_metrics",
13
+ "evaluate_memory_answer",
14
+ "load_memory_questions",
15
+ ]
@@ -0,0 +1,315 @@
1
+ """7-axis evaluator for memory-arena.
2
+
3
+ Axes (4 lifted from kb-arena + 3 memory-specific):
4
+ 1. structural — must_mention / must_not_claim / max_tokens
5
+ 2. sources — was at least one supporting_session_id cited?
6
+ 3. judge — Opus rates accuracy 0..100 against the reference
7
+ 4. memo — identical (answer, reference) pairs cached in-process
8
+ 5. temporal_correctness — answer's claimed timestamp is inside ground_truth.valid_as_of window
9
+ 6. update_precision — answer reflects the latest fact version, not an earlier one
10
+ 7. abstention_f1 — abstention questions get F1 over an abstain/no-abstain classifier
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import re
16
+ from dataclasses import dataclass, field
17
+
18
+ from pydantic import BaseModel, ConfigDict, Field
19
+
20
+ from memory_arena.llm.client import LLMClient
21
+ from memory_arena.sessions.schema import Constraints, GroundTruth, QuestionRecord
22
+
23
+ _JUDGE_SYSTEM = (
24
+ "You are a strict judge comparing a candidate answer to a reference answer for a question "
25
+ "about a chat history. Output a JSON object with this exact shape:\n"
26
+ '{"accuracy": <0-100>, "completeness": <0-100>, "rationale": "<one sentence>"}\n'
27
+ "Be strict: an answer that misses key facts should score below 60. An answer that "
28
+ "contradicts the reference should score below 30. Reply with only the JSON object."
29
+ )
30
+
31
+
32
+ _ABSTAIN_SYSTEM = (
33
+ "Classify whether the following answer abstains from answering (says it does not know, "
34
+ "cannot find the information, or refuses to answer). Reply with only YES or NO."
35
+ )
36
+
37
+
38
+ _TEMPORAL_SYSTEM = (
39
+ "Extract any explicit time marker mentioned in the answer (date, week-of, month, year, "
40
+ "or relative phrase like 'last week'). Reply with only the time marker, or NONE."
41
+ )
42
+
43
+
44
class MemoryScore(BaseModel):
    """7-axis evaluation score.

    Carries the composite ``accuracy`` alongside the per-axis results
    (structural, sources, judge, temporal, update precision, abstention) and
    the evaluation's cost/token totals. Unknown fields are rejected
    (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    accuracy: float = 0.0
    completeness: float = 0.0
    structural_pass: bool = True
    structural_fails: list[str] = Field(default_factory=list)
    sources_pass: bool = False
    judge_score: float = 0.0
    judge_rationale: str = ""
    temporal_correct: bool = False
    # update_precision_correct can be None when the question does not
    # carry fact_versions to check — runner aggregates only non-None.
    update_precision_correct: bool | None = False
    abstained: bool = False
    abstention_correct: bool = False
    cost_usd: float = 0.0
    tokens_used: int = 0
64
+
65
+
66
+ # Module-level memo cache. Keyed by (question_id, answer-hash, reference-hash)
67
+ # so long answers that diverge after the first 500 characters do not collide.
68
+ _judge_cache: dict[tuple[str, str, str], dict] = {}
69
+
70
+
71
+ def _evaluate_structural(
72
+ answer: str,
73
+ constraints: Constraints,
74
+ ) -> tuple[bool, list[str]]:
75
+ fails: list[str] = []
76
+ answer_lower = answer.lower()
77
+ for required in constraints.must_mention:
78
+ if required.lower() not in answer_lower:
79
+ fails.append(f"missing must_mention: {required}")
80
+ for forbidden in constraints.must_not_claim:
81
+ if forbidden.lower() in answer_lower:
82
+ fails.append(f"contains must_not_claim: {forbidden}")
83
+ if constraints.max_tokens > 0:
84
+ approx_tokens = max(1, len(answer) // 4)
85
+ if approx_tokens > constraints.max_tokens * 1.5:
86
+ fails.append(f"exceeds max_tokens by 50%: {approx_tokens} > {constraints.max_tokens}")
87
+ return (len(fails) == 0), fails
88
+
89
+
90
+ def _evaluate_sources(
91
+ answer: str,
92
+ supporting_session_ids: list[str],
93
+ ground_truth: GroundTruth,
94
+ ) -> bool:
95
+ expected = set(ground_truth.supporting_session_ids)
96
+ if not expected:
97
+ return True # no labeled sources -> can't fail
98
+ return bool(set(supporting_session_ids) & expected)
99
+
100
+
101
async def _evaluate_judge(
    answer: str,
    reference: str,
    llm: LLMClient,
    question_id: str = "",
) -> tuple[float, float, str, float, int]:
    """Ask the LLM judge to score *answer* against *reference*.

    Results are memoized in the module-level judge cache under
    ``(question_id, sha256(answer), sha256(reference))``; a cache hit reports
    zero cost and zero tokens because no call was made.

    Returns:
        ``(accuracy, completeness, rationale, cost_usd, total_tokens)``.
        With an empty reference the judge is not called and the rationale is
        ``"no reference provided"``. When the judge reply contains no JSON
        object, or the object cannot be parsed, the rationale starts with
        ``"judge parse failure:"`` followed by the first 80 characters of the
        reply, so a parse failure is distinguishable from a genuine zero.
    """
    import hashlib
    import json

    a_hash = hashlib.sha256(answer.strip().encode("utf-8")).hexdigest()
    r_hash = hashlib.sha256(reference.strip().encode("utf-8")).hexdigest()
    key = (question_id, a_hash, r_hash)
    if key in _judge_cache:
        cached = _judge_cache[key]
        return (
            cached["accuracy"],
            cached["completeness"],
            cached["rationale"],
            0.0,
            0,
        )
    if not reference:
        return 0.0, 0.0, "no reference provided", 0.0, 0
    resp = await llm.judge(answer=answer, reference=reference, system_prompt=_JUDGE_SYSTEM)
    text = resp.text.strip()
    accuracy = 0.0
    completeness = 0.0
    rationale = ""
    # Greedy match from the first "{" to the last "}" tolerates prose around the JSON.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        # No JSON object at all: record why the score is zero.
        rationale = f"judge parse failure: {text[:80]}"
    else:
        try:
            data = json.loads(match.group(0))
            accuracy = float(data.get("accuracy", 0))
            completeness = float(data.get("completeness", 0))
            rationale = data.get("rationale", "")
        except Exception:
            # Deliberately broad: the reply is untrusted model output and a
            # malformed one must not abort the evaluation run.
            rationale = f"judge parse failure: {text[:80]}"
    _judge_cache[key] = {
        "accuracy": accuracy,
        "completeness": completeness,
        "rationale": rationale,
    }
    return accuracy, completeness, rationale, resp.cost_usd, resp.total_tokens
145
+
146
+
147
async def _classify_abstention(answer: str, llm: LLMClient) -> tuple[bool, float, int]:
    """Decide whether *answer* abstains from answering.

    A phrase match is tried first; only when it finds nothing is the LLM
    classifier consulted, on the first 1000 characters of the answer.

    Returns:
        ``(abstained, cost_usd, tokens)``. The cost and token slots are always
        ``0.0`` and ``0`` on both paths.
    """
    if not _quick_abstain_match(answer):
        verdict = await llm.classify(
            query=answer[:1000],
            system_prompt=_ABSTAIN_SYSTEM,
            allowed_values=["YES", "NO"],
        )
        said_yes = verdict.upper().startswith("Y")
        return said_yes, 0.0, 0
    return True, 0.0, 0
154
+
155
+
156
+ def _quick_abstain_match(answer: str) -> bool:
157
+ a = answer.lower()
158
+ needles = (
159
+ "i do not have that information",
160
+ "i don't have that information",
161
+ "i don't know",
162
+ "i do not know",
163
+ "i'm not sure",
164
+ "i am not sure",
165
+ "no information",
166
+ "cannot answer",
167
+ "unable to answer",
168
+ )
169
+ return any(n in a for n in needles)
170
+
171
+
172
async def _check_temporal(
    answer: str,
    valid_as_of: str | None,
    llm: LLMClient,
) -> bool:
    """Check that the time marker claimed in *answer* matches ``valid_as_of``.

    With no ``valid_as_of`` to compare against the check passes. Otherwise the
    LLM extracts a time marker from the first 1000 characters of the answer;
    an empty or ``NONE`` reply fails, and any other reply is compared with
    ``valid_as_of`` by ``_temporal_overlap``.
    """
    if not valid_as_of:
        return True

    marker = await llm.classify(query=answer[:1000], system_prompt=_TEMPORAL_SYSTEM)
    no_marker = marker.upper() == "NONE" or not marker
    if no_marker:
        return False
    return _temporal_overlap(marker, valid_as_of)
183
+
184
+
185
+ def _temporal_overlap(claimed: str, valid_as_of: str) -> bool:
186
+ """Naive substring overlap on year/month/day numbers. Tighten in v0.2."""
187
+ claimed_nums = set(re.findall(r"\d+", claimed))
188
+ expected_nums = set(re.findall(r"\d+", valid_as_of))
189
+ return bool(claimed_nums & expected_nums) if expected_nums else False
190
+
191
+
192
+ async def _check_update_precision(
193
+ answer: str,
194
+ ground_truth: GroundTruth,
195
+ ) -> bool | None:
196
+ """Answer must reflect the latest fact version, not an earlier one.
197
+
198
+ Returns None when fact_versions is empty/missing — there's nothing to
199
+ check, and reporting True would inflate the per-category metric to
200
+ 1.0 across every strategy regardless of whether the question
201
+ actually exercised an update. The runner aggregates only non-None
202
+ values.
203
+ """
204
+ if not ground_truth.fact_versions:
205
+ return None
206
+ a = answer.lower()
207
+ latest = ground_truth.fact_versions[-1].value.lower()
208
+ if latest and latest in a:
209
+ return True
210
+ earlier = [v.value.lower() for v in ground_truth.fact_versions[:-1]]
211
+ if any(v and v in a for v in earlier):
212
+ return False
213
+ return False
214
+
215
+
216
async def evaluate_memory_answer(
    answer: str,
    ground_truth: GroundTruth,
    constraints: Constraints,
    question: QuestionRecord,
    llm: LLMClient,
    supporting_session_ids: list[str] | None = None,
) -> MemoryScore:
    """Run all 7 axes and return a MemoryScore.

    Args:
        answer: Candidate answer text to evaluate.
        ground_truth: Reference answer plus supporting session ids,
            ``valid_as_of`` and fact versions used by the individual axes.
        constraints: Structural constraints and the abstention expectation.
        question: Question record; its ``id`` keys the judge call and its
            ``category`` selects which category-specific axes apply.
        llm: Client used for the judge, abstention and temporal calls.
        supporting_session_ids: Session ids cited for the answer; ``None`` is
            treated as an empty list.

    Returns:
        A populated ``MemoryScore``. ``accuracy`` is the composite value
        clamped to the range 0..1; ``cost_usd`` and ``tokens_used`` include
        only what the judge call reported.
    """
    score = MemoryScore()
    if supporting_session_ids is None:
        supporting_session_ids = []

    # 1. structural
    structural_pass, fails = _evaluate_structural(answer, constraints)
    score.structural_pass = structural_pass
    score.structural_fails = fails

    # 2. sources
    score.sources_pass = _evaluate_sources(answer, supporting_session_ids, ground_truth)

    # 3. judge
    accuracy, completeness, rationale, cost, toks = await _evaluate_judge(
        answer, ground_truth.answer, llm, question_id=question.id
    )
    score.judge_score = accuracy
    score.judge_rationale = rationale
    # Values above 1 are read as percentages; values at or below 1 are kept as-is.
    score.completeness = completeness / 100.0 if completeness > 1 else completeness
    score.cost_usd += cost
    score.tokens_used += toks

    # 5. temporal
    # Only temporal questions are checked; every other category passes this axis.
    if question.category == "temporal":
        score.temporal_correct = await _check_temporal(answer, ground_truth.valid_as_of, llm)
    else:
        score.temporal_correct = True

    # 6. update precision. _check_update_precision returns None when the
    # question has no fact_versions to verify; preserve that so the
    # runner can aggregate only the questions that actually exercised an
    # update instead of polluting the metric with structural-trues.
    if question.category == "knowledge_update":
        score.update_precision_correct = await _check_update_precision(answer, ground_truth)
    else:
        score.update_precision_correct = True

    # 7. abstention
    # The classifier's cost/token slots are discarded here.
    abstained, _, _ = await _classify_abstention(answer, llm)
    score.abstained = abstained
    score.abstention_correct = abstained == constraints.abstention_expected

    # Composite accuracy: judge score normalized to 0-1, dampened by structural and sources
    # NOTE(review): a judge value of exactly 1 is not divided by 100, so it
    # yields a base of 1.0 — confirm this is intended for a 0-100 scale.
    base = accuracy / 100.0 if accuracy > 1 else accuracy
    if not structural_pass:
        base *= 0.5
    if not score.sources_pass and ground_truth.supporting_session_ids:
        base *= 0.8
    # Abstention questions are scored all-or-nothing; this overrides the
    # judge-based value and the structural/sources penalties above.
    if question.category == "abstention":
        base = 1.0 if score.abstention_correct else 0.0
    if question.category == "temporal" and not score.temporal_correct:
        base *= 0.5
    # `is False` on purpose: None (nothing to check) must not be penalized.
    if question.category == "knowledge_update" and score.update_precision_correct is False:
        base *= 0.5
    score.accuracy = max(0.0, min(1.0, base))

    return score
282
+
283
+
284
def clear_eval_cache() -> None:
    """Empty the module-level judge memo cache."""
    _judge_cache.clear()
286
+
287
+
288
@dataclass
class EvaluatorBundle:
    """Convenience wrapper for callers that batch many evaluations."""

    # Built via LLMClient() per instance unless the caller supplies one.
    llm: LLMClient = field(default_factory=LLMClient)

    async def evaluate(
        self,
        answer: str,
        question: QuestionRecord,
        supporting_session_ids: list[str] | None = None,
    ) -> MemoryScore:
        """Evaluate *answer* for *question* using this bundle's LLM client.

        Delegates to ``evaluate_memory_answer``, taking the ground truth and
        constraints from the question record itself.
        """
        return await evaluate_memory_answer(
            answer=answer,
            ground_truth=question.ground_truth,
            constraints=question.constraints,
            question=question,
            llm=self.llm,
            supporting_session_ids=supporting_session_ids,
        )
308
+
309
+
310
+ __all__ = [
311
+ "EvaluatorBundle",
312
+ "MemoryScore",
313
+ "clear_eval_cache",
314
+ "evaluate_memory_answer",
315
+ ]