memory-arena 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memory_arena/__init__.py +19 -0
- memory_arena/arena/__init__.py +0 -0
- memory_arena/arena/engine.py +252 -0
- memory_arena/benchmark/__init__.py +15 -0
- memory_arena/benchmark/evaluator.py +315 -0
- memory_arena/benchmark/questions.py +150 -0
- memory_arena/benchmark/recall_lab.py +120 -0
- memory_arena/benchmark/recall_metrics.py +119 -0
- memory_arena/benchmark/runner.py +694 -0
- memory_arena/chatbot/__init__.py +0 -0
- memory_arena/chatbot/api.py +268 -0
- memory_arena/chatbot/router.py +12 -0
- memory_arena/chatbot/session.py +73 -0
- memory_arena/cli.py +439 -0
- memory_arena/data/longmemeval-s/processed/questions.jsonl +16 -0
- memory_arena/data/longmemeval-s/processed/sessions.jsonl +82 -0
- memory_arena/data/results_snapshot/longmemeval-s_bm25_seed0.json +1559 -0
- memory_arena/data/results_snapshot/longmemeval-s_bm25_seed1.json +1559 -0
- memory_arena/data/results_snapshot/longmemeval-s_bm25_seed2.json +1559 -0
- memory_arena/data/results_snapshot/longmemeval-s_bm25_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_cognee_seed0.json +1386 -0
- memory_arena/data/results_snapshot/longmemeval-s_cognee_summary.json +86 -0
- memory_arena/data/results_snapshot/longmemeval-s_full_context_seed0.json +2098 -0
- memory_arena/data/results_snapshot/longmemeval-s_full_context_summary.json +81 -0
- memory_arena/data/results_snapshot/longmemeval-s_graphiti_seed0.json +1444 -0
- memory_arena/data/results_snapshot/longmemeval-s_graphiti_summary.json +86 -0
- memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_seed0.json +1546 -0
- memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_seed1.json +1546 -0
- memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_seed2.json +1546 -0
- memory_arena/data/results_snapshot/longmemeval-s_hybrid_rrf_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_hyde_seed0.json +1540 -0
- memory_arena/data/results_snapshot/longmemeval-s_hyde_seed1.json +1540 -0
- memory_arena/data/results_snapshot/longmemeval-s_hyde_seed2.json +1537 -0
- memory_arena/data/results_snapshot/longmemeval-s_hyde_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_karpathy_llm_wiki_seed0.json +1496 -0
- memory_arena/data/results_snapshot/longmemeval-s_karpathy_llm_wiki_summary.json +86 -0
- memory_arena/data/results_snapshot/longmemeval-s_langmem_seed0.json +1386 -0
- memory_arena/data/results_snapshot/longmemeval-s_langmem_seed1.json +1386 -0
- memory_arena/data/results_snapshot/longmemeval-s_langmem_seed2.json +1386 -0
- memory_arena/data/results_snapshot/longmemeval-s_langmem_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_mem0_seed0.json +1453 -0
- memory_arena/data/results_snapshot/longmemeval-s_mem0_seed1.json +1455 -0
- memory_arena/data/results_snapshot/longmemeval-s_mem0_seed2.json +1453 -0
- memory_arena/data/results_snapshot/longmemeval-s_mem0_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_mem0g_seed0.json +1457 -0
- memory_arena/data/results_snapshot/longmemeval-s_mem0g_summary.json +86 -0
- memory_arena/data/results_snapshot/longmemeval-s_memori_seed0.json +1386 -0
- memory_arena/data/results_snapshot/longmemeval-s_memori_summary.json +86 -0
- memory_arena/data/results_snapshot/longmemeval-s_naive_vector_seed0.json +1540 -0
- memory_arena/data/results_snapshot/longmemeval-s_naive_vector_seed1.json +1540 -0
- memory_arena/data/results_snapshot/longmemeval-s_naive_vector_seed2.json +1538 -0
- memory_arena/data/results_snapshot/longmemeval-s_naive_vector_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_persona_profile_seed0.json +1537 -0
- memory_arena/data/results_snapshot/longmemeval-s_persona_profile_seed1.json +1538 -0
- memory_arena/data/results_snapshot/longmemeval-s_persona_profile_seed2.json +1540 -0
- memory_arena/data/results_snapshot/longmemeval-s_persona_profile_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_raptor_seed0.json +1521 -0
- memory_arena/data/results_snapshot/longmemeval-s_raptor_seed1.json +1521 -0
- memory_arena/data/results_snapshot/longmemeval-s_raptor_seed2.json +1521 -0
- memory_arena/data/results_snapshot/longmemeval-s_raptor_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_recency_window_seed0.json +1786 -0
- memory_arena/data/results_snapshot/longmemeval-s_recency_window_seed1.json +1786 -0
- memory_arena/data/results_snapshot/longmemeval-s_recency_window_seed2.json +1786 -0
- memory_arena/data/results_snapshot/longmemeval-s_recency_window_summary.json +90 -0
- memory_arena/data/results_snapshot/longmemeval-s_reflection_seed0.json +1539 -0
- memory_arena/data/results_snapshot/longmemeval-s_reflection_seed1.json +1539 -0
- memory_arena/data/results_snapshot/longmemeval-s_reflection_seed2.json +1541 -0
- memory_arena/data/results_snapshot/longmemeval-s_reflection_summary.json +90 -0
- memory_arena/exceptions.py +33 -0
- memory_arena/graph/__init__.py +0 -0
- memory_arena/graph/analyzer.py +143 -0
- memory_arena/graph/cypher_generator.py +115 -0
- memory_arena/graph/cypher_templates.py +106 -0
- memory_arena/graph/extractor.py +338 -0
- memory_arena/graph/neo4j_store.py +152 -0
- memory_arena/graph/resolver.py +77 -0
- memory_arena/graph/schema.py +55 -0
- memory_arena/llm/__init__.py +0 -0
- memory_arena/llm/client.py +327 -0
- memory_arena/llm/openrouter.py +122 -0
- memory_arena/llm/providers.py +252 -0
- memory_arena/models/__init__.py +1 -0
- memory_arena/models/api.py +61 -0
- memory_arena/models/benchmark.py +172 -0
- memory_arena/models/document.py +66 -0
- memory_arena/models/graph.py +51 -0
- memory_arena/models/retrieval.py +33 -0
- memory_arena/paths.py +90 -0
- memory_arena/py.typed +0 -0
- memory_arena/sessions/__init__.py +23 -0
- memory_arena/sessions/loaders.py +210 -0
- memory_arena/sessions/schema.py +110 -0
- memory_arena/settings.py +246 -0
- memory_arena/static/404/index.html +1 -0
- memory_arena/static/404.html +1 -0
- memory_arena/static/_next/static/chunks/117-020b9cef866aefe5.js +2 -0
- memory_arena/static/_next/static/chunks/648-77f86c7bea515da8.js +1 -0
- memory_arena/static/_next/static/chunks/app/_not-found/page-c59625541a56b4fe.js +1 -0
- memory_arena/static/_next/static/chunks/app/arena/page-410e2e6b1ec69f52.js +1 -0
- memory_arena/static/_next/static/chunks/app/benchmark/page-76db067e901c05a6.js +1 -0
- memory_arena/static/_next/static/chunks/app/layout-82e84594976f899d.js +1 -0
- memory_arena/static/_next/static/chunks/app/page-e07ee36911d8bec3.js +1 -0
- memory_arena/static/_next/static/chunks/app/recall-lab/page-5e983cc6e76fcbab.js +1 -0
- memory_arena/static/_next/static/chunks/fd9d1056-40eb70ab657cb18b.js +1 -0
- memory_arena/static/_next/static/chunks/framework-f66176bb897dc684.js +1 -0
- memory_arena/static/_next/static/chunks/main-app-ce09194c2205bef4.js +1 -0
- memory_arena/static/_next/static/chunks/main-f93c663012e61193.js +1 -0
- memory_arena/static/_next/static/chunks/pages/_app-72b849fbd24ac258.js +1 -0
- memory_arena/static/_next/static/chunks/pages/_error-7ba65e1336b92748.js +1 -0
- memory_arena/static/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
- memory_arena/static/_next/static/chunks/webpack-03f7c6bc932ce1e3.js +1 -0
- memory_arena/static/_next/static/css/7f0f5d6971a0dc74.css +3 -0
- memory_arena/static/_next/static/wwMdP-USmLrBi4wuoZSsd/_buildManifest.js +1 -0
- memory_arena/static/_next/static/wwMdP-USmLrBi4wuoZSsd/_ssgManifest.js +1 -0
- memory_arena/static/arena/index.html +1 -0
- memory_arena/static/arena/index.txt +8 -0
- memory_arena/static/benchmark/index.html +1 -0
- memory_arena/static/benchmark/index.txt +9 -0
- memory_arena/static/favicon.ico +0 -0
- memory_arena/static/index.html +1 -0
- memory_arena/static/index.txt +9 -0
- memory_arena/static/recall-lab/index.html +1 -0
- memory_arena/static/recall-lab/index.txt +9 -0
- memory_arena/strategies/__init__.py +190 -0
- memory_arena/strategies/amem.py +511 -0
- memory_arena/strategies/base.py +151 -0
- memory_arena/strategies/bm25.py +125 -0
- memory_arena/strategies/cognee.py +242 -0
- memory_arena/strategies/embeddings.py +62 -0
- memory_arena/strategies/full_context.py +87 -0
- memory_arena/strategies/graphiti.py +317 -0
- memory_arena/strategies/graphiti_falkor.py +131 -0
- memory_arena/strategies/hipporag2.py +539 -0
- memory_arena/strategies/hybrid_rrf.py +132 -0
- memory_arena/strategies/hyde.py +85 -0
- memory_arena/strategies/karpathy_llm_wiki.py +423 -0
- memory_arena/strategies/langmem.py +176 -0
- memory_arena/strategies/mem0.py +255 -0
- memory_arena/strategies/mem0g.py +240 -0
- memory_arena/strategies/memori.py +237 -0
- memory_arena/strategies/naive_vector.py +164 -0
- memory_arena/strategies/persona_profile.py +139 -0
- memory_arena/strategies/quantum/__init__.py +24 -0
- memory_arena/strategies/quantum/circuits.py +111 -0
- memory_arena/strategies/quantum/qiss.py +299 -0
- memory_arena/strategies/quantum/sqr.py +188 -0
- memory_arena/strategies/quantum/utils.py +106 -0
- memory_arena/strategies/raptor.py +254 -0
- memory_arena/strategies/recency_window.py +80 -0
- memory_arena/strategies/reflection.py +152 -0
- memory_arena/tokenizer.py +22 -0
- memory_arena/viz/__init__.py +0 -0
- memory_arena-0.1.8.dist-info/METADATA +746 -0
- memory_arena-0.1.8.dist-info/RECORD +157 -0
- memory_arena-0.1.8.dist-info/WHEEL +4 -0
- memory_arena-0.1.8.dist-info/entry_points.txt +2 -0
- memory_arena-0.1.8.dist-info/licenses/LICENSE +21 -0
memory_arena/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Memory Arena — Knowledge Base Benchmark. Find which retrieval architecture fits your data."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.8"
|
|
4
|
+
|
|
5
|
+
from memory_arena.models.benchmark import BenchmarkResult, Question
|
|
6
|
+
from memory_arena.models.document import Document, Section
|
|
7
|
+
from memory_arena.models.graph import Entity, Relationship
|
|
8
|
+
from memory_arena.strategies.base import Strategy
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"Document",
|
|
12
|
+
"Section",
|
|
13
|
+
"Entity",
|
|
14
|
+
"Relationship",
|
|
15
|
+
"Question",
|
|
16
|
+
"BenchmarkResult",
|
|
17
|
+
"Strategy",
|
|
18
|
+
"__version__",
|
|
19
|
+
]
|
|
File without changes
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Arena engine - blind A/B strategy matchups with ELO rating."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import random
|
|
9
|
+
import time
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from uuid import uuid4
|
|
13
|
+
|
|
14
|
+
from memory_arena.settings import settings
|
|
15
|
+
|
|
16
|
+
log = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
INITIAL_ELO = 1200.0
|
|
19
|
+
K_FACTOR = 32.0
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class Match:
    """Record of one blind head-to-head between two strategies on a question.

    The six leading fields are required; everything else defaults to a
    "not yet measured / not yet voted" value.
    """

    # Identifier of the match and the question both sides were asked.
    id: str
    question: str

    # Names of the strategies playing side "a" and side "b".
    strategy_a: str
    strategy_b: str

    # Answer text produced by each side.
    answer_a: str
    answer_b: str

    # Per-side latency in milliseconds.
    latency_a_ms: float = 0.0
    latency_b_ms: float = 0.0

    # Per-side cost of producing the answer.
    cost_a: float = 0.0
    cost_b: float = 0.0

    # Vote outcome: "a", "b", "tie", or None while the match is pending.
    winner: str | None = None

    # Creation time of the match.
    timestamp: float = 0.0

    # Source identifiers reported by each side; a fresh list per instance.
    sources_a: list[str] = field(default_factory=list)
    sources_b: list[str] = field(default_factory=list)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class ArenaState:
    """Persistent arena state with ELO ratings and match history."""

    # Strategy name -> current ELO rating.
    elo: dict[str, float] = field(default_factory=dict)
    # Matches known to this process (pending and voted).
    matches: list[Match] = field(default_factory=list)
    # Number of votes recorded so far.
    total_votes: int = 0

    def save(self, path: Path) -> None:
        """Write the state to ``path`` as indented JSON.

        Parent directories are created when missing. Only the most recent
        200 matches are written, answers are cut to 500 characters, and
        per-match sources are not persisted.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "elo": self.elo,
            "total_votes": self.total_votes,
            "matches": [
                {
                    "id": m.id,
                    "question": m.question,
                    "strategy_a": m.strategy_a,
                    "strategy_b": m.strategy_b,
                    "answer_a": m.answer_a[:500],  # truncate for storage
                    "answer_b": m.answer_b[:500],
                    "latency_a_ms": m.latency_a_ms,
                    "latency_b_ms": m.latency_b_ms,
                    "cost_a": m.cost_a,
                    "cost_b": m.cost_b,
                    "winner": m.winner,
                    "timestamp": m.timestamp,
                }
                for m in self.matches[-200:]  # keep last 200
            ],
        }
        # Explicit encoding so the file does not depend on the locale default.
        path.write_text(json.dumps(data, indent=2), encoding="utf-8")

    @classmethod
    def load(cls, path: Path) -> ArenaState:
        """Read a state previously written by :meth:`save`.

        Returns an empty state when ``path`` does not exist, or when the
        file cannot be interpreted: invalid JSON, undecodable bytes, a
        missing required match key, or JSON of an unexpected shape such as
        a top-level list or a non-object match entry. The unreadable case
        logs a warning.
        """
        if not path.exists():
            return cls()
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            state = cls(
                elo=data.get("elo", {}),
                total_votes=data.get("total_votes", 0),
            )
            for m in data.get("matches", []):
                state.matches.append(
                    Match(
                        id=m["id"],
                        question=m["question"],
                        strategy_a=m["strategy_a"],
                        strategy_b=m["strategy_b"],
                        answer_a=m.get("answer_a", ""),
                        answer_b=m.get("answer_b", ""),
                        latency_a_ms=m.get("latency_a_ms", 0),
                        latency_b_ms=m.get("latency_b_ms", 0),
                        cost_a=m.get("cost_a", 0),
                        cost_b=m.get("cost_b", 0),
                        winner=m.get("winner"),
                        timestamp=m.get("timestamp", 0),
                    )
                )
            return state
        except (
            json.JSONDecodeError,
            UnicodeDecodeError,
            KeyError,
            # Valid JSON of the wrong shape: `.get` on a list, or indexing
            # a string/number match entry.
            TypeError,
            AttributeError,
        ):
            log.warning("Corrupt arena state, starting fresh")
            return cls()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class ArenaEngine:
    """Manages blind A/B matches between retrieval strategies."""

    def __init__(self, strategies: dict) -> None:
        """Load persisted state and seed a rating for every unseen strategy.

        Args:
            strategies: mapping of strategy name to an object exposing an
                awaitable ``query(question)``.
        """
        self.strategies = strategies
        self._state_path = Path(settings.results_path) / "arena_state.json"
        self.state = ArenaState.load(self._state_path)
        # Initialize ELO for new strategies; existing ratings are kept.
        for name in strategies:
            self.state.elo.setdefault(name, INITIAL_ELO)

    async def create_match(self, question: str, corpus: str = "") -> Match:
        """Pick two random strategies, query both, return a blind match.

        The two queries run concurrently. The new match is appended to the
        in-memory history; it is not written to disk here. ``corpus`` is
        accepted but not referenced by this method.

        Raises:
            ValueError: fewer than two strategies are registered.
        """
        names = list(self.strategies.keys())
        if len(names) < 2:
            raise ValueError("Need at least 2 strategies for arena mode")
        a_name, b_name = random.sample(names, 2)

        result_a, result_b = await asyncio.gather(
            self.strategies[a_name].query(question),
            self.strategies[b_name].query(question),
        )

        match = Match(
            id=uuid4().hex[:8],
            question=question,
            strategy_a=a_name,
            strategy_b=b_name,
            answer_a=result_a.answer,
            answer_b=result_b.answer,
            latency_a_ms=result_a.latency_ms,
            latency_b_ms=result_b.latency_ms,
            cost_a=result_a.cost_usd,
            cost_b=result_b.cost_usd,
            sources_a=result_a.sources,
            sources_b=result_b.sources,
            timestamp=time.time(),
        )
        self.state.matches.append(match)
        return match

    def vote(self, match_id: str, winner: str) -> dict:
        """Record a vote and update ELO. winner: 'a', 'b', or 'tie'.

        Returns a dict with an ``"error"`` key when the winner value is
        invalid, the match is unknown, or the match already has a vote.
        Otherwise the ratings are updated, the state file is saved, the
        vote is appended to the JSONL log, and a summary dict is returned.
        """
        if winner not in ("a", "b", "tie"):
            return {"error": f"Invalid winner: {winner}. Must be 'a', 'b', or 'tie'"}

        match = next((m for m in self.state.matches if m.id == match_id), None)
        if not match:
            return {"error": "Match not found"}
        if match.winner is not None:
            return {"error": "Match already voted on"}

        match.winner = winner
        self.state.total_votes += 1
        self._update_elo(match)
        self.state.save(self._state_path)
        self._append_vote_jsonl(match)

        return {
            "strategy_a": match.strategy_a,
            "strategy_b": match.strategy_b,
            "winner": winner,
            "elo": dict(self.state.elo),
            "total_votes": self.state.total_votes,
        }

    def _update_elo(self, match: Match) -> None:
        """Standard ELO rating update."""
        ea = self.state.elo.get(match.strategy_a, INITIAL_ELO)
        eb = self.state.elo.get(match.strategy_b, INITIAL_ELO)
        expected_a = 1.0 / (1.0 + 10.0 ** ((eb - ea) / 400.0))

        if match.winner == "a":
            score_a = 1.0
        elif match.winner == "b":
            score_a = 0.0
        else:  # tie
            score_a = 0.5

        # Side b receives the mirror-image adjustment of side a.
        self.state.elo[match.strategy_a] = ea + K_FACTOR * (score_a - expected_a)
        self.state.elo[match.strategy_b] = eb + K_FACTOR * ((1 - score_a) - (1 - expected_a))

    def leaderboard(self) -> list[dict]:
        """Return strategies sorted by ELO rating.

        Win/loss/tie counts are tallied in one pass over the in-memory
        match history; pending matches (no winner) are ignored. Each row
        carries the rating rounded to one decimal, and rows are ordered by
        that rounded rating, highest first.
        """
        wins: dict[str, int] = {}
        losses: dict[str, int] = {}
        ties: dict[str, int] = {}
        for m in self.state.matches:
            if m.winner == "a":
                victor, loser = m.strategy_a, m.strategy_b
            elif m.winner == "b":
                victor, loser = m.strategy_b, m.strategy_a
            elif m.winner == "tie":
                # A set so a strategy is counted once per tied match.
                for side in {m.strategy_a, m.strategy_b}:
                    ties[side] = ties.get(side, 0) + 1
                continue
            else:
                continue
            wins[victor] = wins.get(victor, 0) + 1
            losses[loser] = losses.get(loser, 0) + 1

        board = []
        for name, elo in self.state.elo.items():
            won = wins.get(name, 0)
            lost = losses.get(name, 0)
            tied = ties.get(name, 0)
            board.append(
                {
                    "strategy": name,
                    "elo": round(elo, 1),
                    "wins": won,
                    "losses": lost,
                    "ties": tied,
                    "matches": won + lost + tied,
                }
            )
        return sorted(board, key=lambda x: x["elo"], reverse=True)

    def _append_vote_jsonl(self, match: Match) -> None:
        """Append-only JSONL log of all votes (survives state resets)."""
        jsonl_path = self._state_path.parent / "arena_votes.jsonl"
        jsonl_path.parent.mkdir(parents=True, exist_ok=True)
        record = {
            "match_id": match.id,
            "question": match.question[:200],
            "strategy_a": match.strategy_a,
            "strategy_b": match.strategy_b,
            "winner": match.winner,
            "latency_a_ms": round(match.latency_a_ms, 1),
            "latency_b_ms": round(match.latency_b_ms, 1),
            "cost_a": match.cost_a,
            "cost_b": match.cost_b,
            "timestamp": match.timestamp,
            "elo_snapshot": {k: round(v, 1) for k, v in self.state.elo.items()},
        }
        # Explicit encoding so the log does not depend on the locale default.
        with open(jsonl_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

    def get_pending_match(self, match_id: str) -> Match | None:
        """Get a match by ID (voted or not), or None when unknown."""
        return next((m for m in self.state.matches if m.id == match_id), None)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Memory Arena benchmark engine."""
|
|
2
|
+
|
|
3
|
+
from memory_arena.benchmark.evaluator import (
|
|
4
|
+
MemoryScore,
|
|
5
|
+
evaluate_memory_answer,
|
|
6
|
+
)
|
|
7
|
+
from memory_arena.benchmark.questions import load_memory_questions
|
|
8
|
+
from memory_arena.benchmark.recall_metrics import compute_memory_recall_metrics
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"MemoryScore",
|
|
12
|
+
"compute_memory_recall_metrics",
|
|
13
|
+
"evaluate_memory_answer",
|
|
14
|
+
"load_memory_questions",
|
|
15
|
+
]
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""7-axis evaluator for memory-arena.
|
|
2
|
+
|
|
3
|
+
Axes (4 lifted from kb-arena + 3 memory-specific):
|
|
4
|
+
1. structural — must_mention / must_not_claim / max_tokens
|
|
5
|
+
2. sources — was at least one supporting_session_id cited?
|
|
6
|
+
3. judge — Opus rates accuracy 0..100 against the reference
|
|
7
|
+
4. memo — identical (answer, reference) pairs cached in-process
|
|
8
|
+
5. temporal_correctness — answer's claimed timestamp is inside ground_truth.valid_as_of window
|
|
9
|
+
6. update_precision — answer reflects the latest fact version, not an earlier one
|
|
10
|
+
7. abstention_f1 — abstention questions get F1 over an abstain/no-abstain classifier
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
19
|
+
|
|
20
|
+
from memory_arena.llm.client import LLMClient
|
|
21
|
+
from memory_arena.sessions.schema import Constraints, GroundTruth, QuestionRecord
|
|
22
|
+
|
|
23
|
+
_JUDGE_SYSTEM = (
|
|
24
|
+
"You are a strict judge comparing a candidate answer to a reference answer for a question "
|
|
25
|
+
"about a chat history. Output a JSON object with this exact shape:\n"
|
|
26
|
+
'{"accuracy": <0-100>, "completeness": <0-100>, "rationale": "<one sentence>"}\n'
|
|
27
|
+
"Be strict: an answer that misses key facts should score below 60. An answer that "
|
|
28
|
+
"contradicts the reference should score below 30. Reply with only the JSON object."
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
_ABSTAIN_SYSTEM = (
|
|
33
|
+
"Classify whether the following answer abstains from answering (says it does not know, "
|
|
34
|
+
"cannot find the information, or refuses to answer). Reply with only YES or NO."
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
_TEMPORAL_SYSTEM = (
|
|
39
|
+
"Extract any explicit time marker mentioned in the answer (date, week-of, month, year, "
|
|
40
|
+
"or relative phrase like 'last week'). Reply with only the time marker, or NONE."
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MemoryScore(BaseModel):
    """Per-answer result of the 7-axis evaluation."""

    # Unknown field names are rejected at construction time.
    model_config = ConfigDict(extra="forbid")

    # Composite accuracy and judge completeness.
    accuracy: float = 0.0
    completeness: float = 0.0

    # Structural axis: overall verdict plus one message per failed check.
    structural_pass: bool = True
    structural_fails: list[str] = Field(default_factory=list)

    # Sources axis.
    sources_pass: bool = False

    # Judge axis: the judge's accuracy value and its one-line rationale.
    judge_score: float = 0.0
    judge_rationale: str = ""

    # Temporal axis.
    temporal_correct: bool = False

    # update_precision_correct can be None when the question does not
    # carry fact_versions to check — runner aggregates only non-None.
    update_precision_correct: bool | None = False

    # Abstention axis: whether the answer abstained, and whether that was
    # the expected behaviour.
    abstained: bool = False
    abstention_correct: bool = False

    # Evaluation spend accumulated while scoring this answer.
    cost_usd: float = 0.0
    tokens_used: int = 0
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Module-level memo cache. Keyed by (question_id, answer-hash, reference-hash)
|
|
67
|
+
# so long answers that diverge after the first 500 characters do not collide.
|
|
68
|
+
_judge_cache: dict[tuple[str, str, str], dict] = {}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _evaluate_structural(
|
|
72
|
+
answer: str,
|
|
73
|
+
constraints: Constraints,
|
|
74
|
+
) -> tuple[bool, list[str]]:
|
|
75
|
+
fails: list[str] = []
|
|
76
|
+
answer_lower = answer.lower()
|
|
77
|
+
for required in constraints.must_mention:
|
|
78
|
+
if required.lower() not in answer_lower:
|
|
79
|
+
fails.append(f"missing must_mention: {required}")
|
|
80
|
+
for forbidden in constraints.must_not_claim:
|
|
81
|
+
if forbidden.lower() in answer_lower:
|
|
82
|
+
fails.append(f"contains must_not_claim: {forbidden}")
|
|
83
|
+
if constraints.max_tokens > 0:
|
|
84
|
+
approx_tokens = max(1, len(answer) // 4)
|
|
85
|
+
if approx_tokens > constraints.max_tokens * 1.5:
|
|
86
|
+
fails.append(f"exceeds max_tokens by 50%: {approx_tokens} > {constraints.max_tokens}")
|
|
87
|
+
return (len(fails) == 0), fails
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _evaluate_sources(
|
|
91
|
+
answer: str,
|
|
92
|
+
supporting_session_ids: list[str],
|
|
93
|
+
ground_truth: GroundTruth,
|
|
94
|
+
) -> bool:
|
|
95
|
+
expected = set(ground_truth.supporting_session_ids)
|
|
96
|
+
if not expected:
|
|
97
|
+
return True # no labeled sources -> can't fail
|
|
98
|
+
return bool(set(supporting_session_ids) & expected)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
async def _evaluate_judge(
    answer: str,
    reference: str,
    llm: LLMClient,
    question_id: str = "",
) -> tuple[float, float, str, float, int]:
    """Ask the judge model to grade ``answer`` against ``reference``.

    Verdicts are memoized in the module-level cache under
    ``(question_id, sha256(answer), sha256(reference))``; a cache hit is
    reported with zero cost and zero tokens. An empty reference
    short-circuits without calling the judge.

    Returns:
        ``(accuracy, completeness, rationale, cost_usd, total_tokens)``.
        When the reply contains no JSON object, or the object cannot be
        parsed, the rationale starts with ``"judge parse failure:"`` so the
        result can be told apart from a genuine zero score.
    """
    import hashlib
    import json

    a_hash = hashlib.sha256(answer.strip().encode("utf-8")).hexdigest()
    r_hash = hashlib.sha256(reference.strip().encode("utf-8")).hexdigest()
    key = (question_id, a_hash, r_hash)
    cached = _judge_cache.get(key)
    if cached is not None:
        return (
            cached["accuracy"],
            cached["completeness"],
            cached["rationale"],
            0.0,
            0,
        )
    if not reference:
        return 0.0, 0.0, "no reference provided", 0.0, 0

    resp = await llm.judge(answer=answer, reference=reference, system_prompt=_JUDGE_SYSTEM)
    text = resp.text.strip()
    accuracy = 0.0
    completeness = 0.0
    rationale = ""
    try:
        # Greedy match: from the first "{" to the last "}" in the reply.
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match is None:
            # Previously this path fell through with an empty rationale.
            rationale = f"judge parse failure: {text[:80]}"
        else:
            data = json.loads(match.group(0))
            accuracy = float(data.get("accuracy", 0))
            completeness = float(data.get("completeness", 0))
            rationale = data.get("rationale", "")
    except Exception:
        # Deliberately broad: grading is best-effort and a malformed judge
        # reply must not abort the evaluation run.
        rationale = f"judge parse failure: {text[:80]}"
    _judge_cache[key] = {
        "accuracy": accuracy,
        "completeness": completeness,
        "rationale": rationale,
    }
    return accuracy, completeness, rationale, resp.cost_usd, resp.total_tokens
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
async def _classify_abstention(answer: str, llm: LLMClient) -> tuple[bool, float, int]:
    """Decide whether ``answer`` is an abstention.

    The phrase check runs first; the LLM classifier is consulted only when
    no known phrase matches, and sees at most the first 1000 characters.

    Returns:
        ``(abstained, 0.0, 0)`` — the last two slots are always zero.
    """
    if not _quick_abstain_match(answer):
        verdict = await llm.classify(
            query=answer[:1000],
            system_prompt=_ABSTAIN_SYSTEM,
            allowed_values=["YES", "NO"],
        )
        says_yes = verdict.upper().startswith("Y")
        return says_yes, 0.0, 0
    return True, 0.0, 0
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _quick_abstain_match(answer: str) -> bool:
|
|
157
|
+
a = answer.lower()
|
|
158
|
+
needles = (
|
|
159
|
+
"i do not have that information",
|
|
160
|
+
"i don't have that information",
|
|
161
|
+
"i don't know",
|
|
162
|
+
"i do not know",
|
|
163
|
+
"i'm not sure",
|
|
164
|
+
"i am not sure",
|
|
165
|
+
"no information",
|
|
166
|
+
"cannot answer",
|
|
167
|
+
"unable to answer",
|
|
168
|
+
)
|
|
169
|
+
return any(n in a for n in needles)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
async def _check_temporal(
|
|
173
|
+
answer: str,
|
|
174
|
+
valid_as_of: str | None,
|
|
175
|
+
llm: LLMClient,
|
|
176
|
+
) -> bool:
|
|
177
|
+
if not valid_as_of:
|
|
178
|
+
return True
|
|
179
|
+
resp = await llm.classify(query=answer[:1000], system_prompt=_TEMPORAL_SYSTEM)
|
|
180
|
+
if resp.upper() == "NONE" or not resp:
|
|
181
|
+
return False
|
|
182
|
+
return _temporal_overlap(resp, valid_as_of)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _temporal_overlap(claimed: str, valid_as_of: str) -> bool:
|
|
186
|
+
"""Naive substring overlap on year/month/day numbers. Tighten in v0.2."""
|
|
187
|
+
claimed_nums = set(re.findall(r"\d+", claimed))
|
|
188
|
+
expected_nums = set(re.findall(r"\d+", valid_as_of))
|
|
189
|
+
return bool(claimed_nums & expected_nums) if expected_nums else False
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
async def _check_update_precision(
|
|
193
|
+
answer: str,
|
|
194
|
+
ground_truth: GroundTruth,
|
|
195
|
+
) -> bool | None:
|
|
196
|
+
"""Answer must reflect the latest fact version, not an earlier one.
|
|
197
|
+
|
|
198
|
+
Returns None when fact_versions is empty/missing — there's nothing to
|
|
199
|
+
check, and reporting True would inflate the per-category metric to
|
|
200
|
+
1.0 across every strategy regardless of whether the question
|
|
201
|
+
actually exercised an update. The runner aggregates only non-None
|
|
202
|
+
values.
|
|
203
|
+
"""
|
|
204
|
+
if not ground_truth.fact_versions:
|
|
205
|
+
return None
|
|
206
|
+
a = answer.lower()
|
|
207
|
+
latest = ground_truth.fact_versions[-1].value.lower()
|
|
208
|
+
if latest and latest in a:
|
|
209
|
+
return True
|
|
210
|
+
earlier = [v.value.lower() for v in ground_truth.fact_versions[:-1]]
|
|
211
|
+
if any(v and v in a for v in earlier):
|
|
212
|
+
return False
|
|
213
|
+
return False
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
async def evaluate_memory_answer(
    answer: str,
    ground_truth: GroundTruth,
    constraints: Constraints,
    question: QuestionRecord,
    llm: LLMClient,
    supporting_session_ids: list[str] | None = None,
) -> MemoryScore:
    """Score one answer on every axis and fold the results into a MemoryScore.

    The composite ``accuracy`` starts from the judge's accuracy scaled to
    0..1, is halved on a structural failure, and is multiplied by 0.8 when
    labeled sources exist but none was cited. Abstention questions replace
    it with 1.0/0.0 depending on whether abstaining matched the
    expectation; temporal and knowledge-update questions halve it when
    their category check fails. The final value is clamped to 0..1.
    """

    def _to_unit(value: float) -> float:
        # Values above 1 are taken to be on the 0..100 scale.
        if value > 1:
            return value / 100.0
        return value

    score = MemoryScore()
    cited_sessions = [] if supporting_session_ids is None else supporting_session_ids

    # Axis 1: structural constraints.
    score.structural_pass, score.structural_fails = _evaluate_structural(answer, constraints)

    # Axis 2: cited sources.
    score.sources_pass = _evaluate_sources(answer, cited_sessions, ground_truth)

    # Axis 3: LLM judge (memoized inside _evaluate_judge).
    judged = await _evaluate_judge(answer, ground_truth.answer, llm, question_id=question.id)
    judge_accuracy, judge_completeness, judge_note, judge_cost, judge_tokens = judged
    score.judge_score = judge_accuracy
    score.judge_rationale = judge_note
    score.completeness = _to_unit(judge_completeness)
    score.cost_usd += judge_cost
    score.tokens_used += judge_tokens

    category = question.category

    # Axis 5: temporal correctness, checked only for temporal questions.
    if category != "temporal":
        score.temporal_correct = True
    else:
        score.temporal_correct = await _check_temporal(answer, ground_truth.valid_as_of, llm)

    # Axis 6: update precision. A None from _check_update_precision (no
    # fact_versions on the question) is stored as-is so the runner can
    # leave such questions out of the aggregate.
    if category != "knowledge_update":
        score.update_precision_correct = True
    else:
        score.update_precision_correct = await _check_update_precision(answer, ground_truth)

    # Axis 7: abstention, classified for every question.
    abstained, _, _ = await _classify_abstention(answer, llm)
    score.abstained = abstained
    score.abstention_correct = abstained == constraints.abstention_expected

    # Composite accuracy.
    composite = _to_unit(judge_accuracy)
    if not score.structural_pass:
        composite *= 0.5
    if ground_truth.supporting_session_ids and not score.sources_pass:
        composite *= 0.8
    if category == "abstention":
        composite = 1.0 if score.abstention_correct else 0.0
    elif category == "temporal":
        if not score.temporal_correct:
            composite *= 0.5
    elif category == "knowledge_update":
        if score.update_precision_correct is False:
            composite *= 0.5
    score.accuracy = max(0.0, min(1.0, composite))

    return score
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def clear_eval_cache() -> None:
    """Empty the module-level judge memo cache in place."""
    _judge_cache.clear()
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@dataclass
class EvaluatorBundle:
    """Holds one LLM client so many answers can be scored through it."""

    # A new LLMClient is built per bundle unless one is supplied.
    llm: LLMClient = field(default_factory=LLMClient)

    async def evaluate(
        self,
        answer: str,
        question: QuestionRecord,
        supporting_session_ids: list[str] | None = None,
    ) -> MemoryScore:
        """Score ``answer`` for ``question`` with this bundle's client.

        Ground truth and constraints are taken from the question record.
        """
        result = await evaluate_memory_answer(
            answer,
            question.ground_truth,
            question.constraints,
            question,
            self.llm,
            supporting_session_ids,
        )
        return result
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
__all__ = [
|
|
311
|
+
"EvaluatorBundle",
|
|
312
|
+
"MemoryScore",
|
|
313
|
+
"clear_eval_cache",
|
|
314
|
+
"evaluate_memory_answer",
|
|
315
|
+
]
|