cheesebench 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- analysis.py +552 -0
- benchmark.py +1021 -0
- cheesebench-0.1.0.dist-info/METADATA +179 -0
- cheesebench-0.1.0.dist-info/RECORD +28 -0
- cheesebench-0.1.0.dist-info/WHEEL +5 -0
- cheesebench-0.1.0.dist-info/entry_points.txt +2 -0
- cheesebench-0.1.0.dist-info/licenses/LICENSE +21 -0
- cheesebench-0.1.0.dist-info/top_level.txt +11 -0
- config.py +51 -0
- cscgql_agent.py +189 -0
- environments/__init__.py +77 -0
- environments/barnes_maze.py +564 -0
- environments/base_env.py +1725 -0
- environments/dnms_task.py +439 -0
- environments/morris_water_maze.py +573 -0
- environments/operant_chamber.py +606 -0
- environments/place_preference.py +471 -0
- environments/radial_arm_maze.py +529 -0
- environments/registry.py +378 -0
- environments/shuttle_box.py +520 -0
- environments/star_maze.py +454 -0
- environments/t_maze.py +455 -0
- error_analysis.py +324 -0
- heuristic_agent.py +109 -0
- model_server.py +346 -0
- play.py +469 -0
- stat_tests.py +488 -0
- visualize.py +418 -0
analysis.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CheeseBench Analysis Pipeline
|
|
3
|
+
|
|
4
|
+
Computes cognitive profiles, learning curves, strategy metrics,
|
|
5
|
+
and generates publication-quality figures from benchmark results.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import math
|
|
10
|
+
import numpy as np
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Dict, List, Optional, Tuple
|
|
14
|
+
from collections import defaultdict
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# ============================================================================
|
|
18
|
+
# Cognitive Taxonomy — maps environments to cognitive dimensions
|
|
19
|
+
# ============================================================================
|
|
20
|
+
|
|
21
|
+
# Ordered list of cognitive-profile axes. The order is fixed so that
# radar-chart values / report rows line up across agents and with the
# animal baseline (see compute_cognitive_profile / print_report).
COGNITIVE_DIMENSIONS = [
    "Allocentric Spatial Learning",
    "Egocentric Navigation",
    "Working Memory",
    "Instrumental Conditioning",
    "Avoidance Learning",
    "Associative Learning",
]
|
|
29
|
+
|
|
30
|
+
# Each env maps to one or more dimensions with a weight (0-1)
|
|
31
|
+
# Weights of each environment's contribution to the cognitive dimensions.
# Keys are env names as they appear in results; inner weights are in [0, 1]
# and are used as weighted-average coefficients in compute_cognitive_profile.
ENV_COGNITIVE_MAP: Dict[str, Dict[str, float]] = {
    "MorrisWaterMaze": {
        "Allocentric Spatial Learning": 1.0,
    },
    "BarnesMaze": {
        "Allocentric Spatial Learning": 1.0,
    },
    # Mixed-weight envs split their success rate across two dimensions.
    "TMaze": {
        "Egocentric Navigation": 0.7,
        "Working Memory": 0.3,
    },
    "StarMaze": {
        "Allocentric Spatial Learning": 0.5,
        "Egocentric Navigation": 0.5,
    },
    "RadialArmMaze": {
        "Working Memory": 0.7,
        "Allocentric Spatial Learning": 0.3,
    },
    "OperantChamber": {
        "Instrumental Conditioning": 1.0,
    },
    "ShuttleBox": {
        "Avoidance Learning": 1.0,
    },
    "PlacePreference": {
        "Associative Learning": 1.0,
    },
    "DNMSTask": {
        "Working Memory": 1.0,
    },
}
|
|
63
|
+
|
|
64
|
+
# Neural circuit dependencies (for discussion section)
|
|
65
|
+
# Human-readable circuit attribution per cognitive dimension. Purely
# descriptive (used for discussion text); not read by any computation here.
NEURAL_CIRCUITS: Dict[str, str] = {
    "Allocentric Spatial Learning": "Hippocampus (place cells, grid cells)",
    "Egocentric Navigation": "Dorsomedial striatum, parietal cortex",
    "Working Memory": "Prefrontal cortex, hippocampus",
    "Instrumental Conditioning": "Dorsolateral striatum, nucleus accumbens",
    "Avoidance Learning": "Amygdala, periaqueductal gray",
    "Associative Learning": "Ventral tegmental area, nucleus accumbens",
}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ============================================================================
|
|
76
|
+
# Published Animal Baselines — extracted from source papers
|
|
77
|
+
# ============================================================================
|
|
78
|
+
|
|
79
|
+
# Published rodent performance per environment. Conventions:
#   - "learning_curve" is per-session success rate; its LAST element is the
#     value used as the animal baseline (see compute_animal_profile).
#   - "success_rate_session_N" keys duplicate curve endpoints for readability.
#   - "source" cites the PubMed Central article the numbers were taken from.
ANIMAL_BASELINES: Dict[str, Dict] = {
    "MorrisWaterMaze": {
        "species": "C57BL/6 mice",
        "source": "PMC3259155 (de Fiebre et al., 2006)",
        "trials_to_criterion": 15,
        "success_rate_session_1": 0.30,  # ~30% find platform in first session
        "success_rate_session_5": 0.85,  # ~85% by session 5
        "learning_curve": [0.30, 0.50, 0.65, 0.78, 0.85],  # per-session
        "avg_latency_session_1_s": 55.0,  # seconds
        "avg_latency_session_5_s": 15.0,
        "notes": "Platform acquisition over 5 sessions, 5 trials/session",
    },
    "TMaze": {
        "species": "NMRI mice",
        "source": "PMC3399492 (Shoji et al., 2012)",
        "trials_to_criterion": 20,
        "success_rate_session_1": 0.50,  # chance level
        "success_rate_session_4": 0.80,
        "learning_curve": [0.50, 0.60, 0.72, 0.80],
        "notes": "Forced alternation, 10 trials/session over 4 days",
    },
    "BarnesMaze": {
        "species": "B6C3F1/J mice",
        "source": "PMC1783636 (Harrison et al., 2006)",
        "trials_to_criterion": 20,
        "success_rate_session_1": 0.20,
        "success_rate_session_5": 0.80,
        "learning_curve": [0.20, 0.40, 0.55, 0.70, 0.80],
        "avg_latency_session_1_s": 120.0,
        "avg_latency_session_5_s": 25.0,
        "notes": "12 holes, 4 trials/session over 5 sessions",
    },
    "RadialArmMaze": {
        "species": "C57BL/6 mice",
        "source": "PMC4030456 (Penley et al., 2013)",
        "trials_to_criterion": 36,
        "success_rate_session_1": 0.15,
        "success_rate_session_6": 0.70,
        "learning_curve": [0.15, 0.25, 0.40, 0.50, 0.60, 0.70],
        "working_memory_errors_session_1": 3.5,
        "working_memory_errors_session_6": 0.8,
        "notes": "8 arms, 4 baited. Errors = revisits to depleted arms",
    },
    "OperantChamber": {
        "species": "C57BL/6 mice",
        "source": "PMC6619163 (Jurado-Parras et al., 2013)",
        "trials_to_criterion": 50,
        "success_rate_session_1": 0.40,
        "success_rate_session_5": 0.90,
        "learning_curve": [0.40, 0.60, 0.75, 0.85, 0.90],
        "avg_lever_presses_session_1": 12.0,
        "avg_lever_presses_session_5": 45.0,
        "notes": "FR-1 schedule, 30 min sessions",
    },
    "ShuttleBox": {
        "species": "Sprague Dawley rats",
        "source": "PMC4633642 (Chacon et al., 2016)",
        "trials_to_criterion": 90,
        "success_rate_session_1": 0.10,
        "success_rate_session_5": 0.70,
        "learning_curve": [0.10, 0.25, 0.45, 0.60, 0.70],
        "notes": "Active avoidance, 30 trials/session over 5 sessions",
    },
    "PlacePreference": {
        "species": "C57BL/6 mice",
        "source": "PMC6101638 (Blanco-Gandía et al., 2018)",
        "trials_to_criterion": 6,
        "success_rate_session_1": 0.50,  # chance
        "success_rate_session_6": 0.75,
        "learning_curve": [0.50, 0.55, 0.60, 0.65, 0.70, 0.75],
        "time_in_paired_chamber_pct_pre": 50.0,
        "time_in_paired_chamber_pct_post": 65.0,
        "notes": "3-phase CPP protocol, 6 conditioning sessions",
    },
    "StarMaze": {
        "species": "C57BL/6 mice",
        "source": "PMC3695082 (Rondi-Reig et al., 2006)",
        "trials_to_criterion": 40,
        "success_rate_session_1": 0.25,
        "success_rate_session_10": 0.80,
        "learning_curve": [0.25, 0.30, 0.40, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80],
        "notes": "5-arm maze, 4 trials/session over 10 sessions",
    },
    "DNMSTask": {
        "species": "Long-Evans rats",
        "source": "PMC3982138 (Oomen et al., 2013)",
        "trials_to_criterion": 2856,
        "success_rate_session_1": 0.50,  # chance (2AFC)
        "success_rate_session_30": 0.80,
        "learning_curve": [0.50, 0.55, 0.58, 0.60, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.78, 0.79, 0.79, 0.79, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80],
        "notes": "TUNL task, 84 trials/session, ~30 sessions to criterion",
    },
}
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# ============================================================================
|
|
175
|
+
# Metric Computation
|
|
176
|
+
# ============================================================================
|
|
177
|
+
|
|
178
|
+
@dataclass
class TrialMetrics:
    """Rich metrics for a single trial."""
    # Raw per-trial observations (filled in by load_results / callers)
    steps: int = 0
    reward: float = 0.0
    success: bool = False
    actions: List[str] = field(default_factory=list)

    # Computed metrics (filled in by compute_trial_metrics)
    action_entropy: float = 0.0  # Shannon entropy (bits) of the action distribution
    forward_ratio: float = 0.0  # fraction of FORWARD actions
    rotation_ratio: float = 0.0  # fraction of ROTATE_LEFT/ROTATE_RIGHT actions
    stay_ratio: float = 0.0  # fraction of STAY actions
    action_repetition_rate: float = 0.0  # same action back-to-back
    direction_changes: int = 0  # how often rotation direction changes
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@dataclass
class EnvResult:
    """Aggregated results for one (env, view_mode, agent) combination.

    Trials are stored in run order, so the learning-curve helpers below can
    slice them chronologically.
    """

    env_name: str
    view_mode: str
    agent_type: str
    # Quoted annotation: dataclass only needs the default_factory, and the
    # string keeps this class importable regardless of definition order.
    trials: List["TrialMetrics"] = field(default_factory=list)

    @property
    def n_trials(self) -> int:
        """Number of trials recorded."""
        return len(self.trials)

    @property
    def success_rate(self) -> float:
        """Fraction of successful trials (0.0 when there are no trials)."""
        if not self.trials:
            return 0.0
        return sum(1 for t in self.trials if t.success) / len(self.trials)

    @property
    def success_rate_ci(self) -> Tuple[float, float]:
        """95% Wilson score confidence interval for binomial proportion."""
        n = len(self.trials)
        if n == 0:
            return (0.0, 0.0)
        p = self.success_rate
        z = 1.96  # two-sided 95% normal quantile
        denom = 1 + z**2 / n
        center = (p + z**2 / (2 * n)) / denom
        spread = z * math.sqrt((p * (1 - p) + z**2 / (4 * n)) / n) / denom
        # Clamp to [0, 1]; the Wilson interval can otherwise nudge past the
        # boundaries through floating-point rounding.
        return (max(0.0, center - spread), min(1.0, center + spread))

    @property
    def avg_steps_success(self) -> Optional[float]:
        """Mean step count over successful trials; None if none succeeded."""
        succ = [t.steps for t in self.trials if t.success]
        return sum(succ) / len(succ) if succ else None

    @property
    def avg_steps_failure(self) -> Optional[float]:
        """Mean step count over failed trials; None if none failed."""
        fail = [t.steps for t in self.trials if not t.success]
        return sum(fail) / len(fail) if fail else None

    @property
    def avg_reward(self) -> float:
        """Mean reward across all trials (0.0 when there are no trials)."""
        if not self.trials:
            return 0.0
        return sum(t.reward for t in self.trials) / len(self.trials)

    @property
    def timeout_rate(self) -> float:
        """Fraction of trials that hit max steps without success.

        The step cap is not stored per trial, so it is estimated as the
        longest observed trial; failures within 5% of it count as timeouts.
        """
        if not self.trials:
            return 0.0
        # The empty case is handled above, so max() is always safe here
        # (the original carried a dead `if self.trials else 200` fallback).
        max_s = max(t.steps for t in self.trials)
        timeouts = sum(1 for t in self.trials if not t.success and t.steps >= max_s * 0.95)
        return timeouts / len(self.trials)

    def learning_curve(self, window: int = 4) -> List[float]:
        """Smoothed success rate over trials (rolling window)."""
        if len(self.trials) < window:
            return [self.success_rate]
        curve = []
        for i in range(len(self.trials) - window + 1):
            block = self.trials[i:i + window]
            curve.append(sum(1 for t in block if t.success) / len(block))
        return curve

    def learning_curve_blocks(self, block_size: int = 5) -> List[float]:
        """Success rate per block of trials (non-overlapping)."""
        curve = []
        for i in range(0, len(self.trials), block_size):
            block = self.trials[i:i + block_size]
            if block:
                curve.append(sum(1 for t in block if t.success) / len(block))
        return curve

    def to_dict(self) -> Dict:
        """Serialize to a JSON-friendly dict (floats rounded for readability)."""
        sr_ci = self.success_rate_ci
        avg_succ = self.avg_steps_success
        avg_fail = self.avg_steps_failure
        return {
            "env_name": self.env_name,
            "view_mode": self.view_mode,
            "agent_type": self.agent_type,
            "n_trials": self.n_trials,
            "success_rate": round(self.success_rate, 4),
            "success_rate_ci_95": [round(sr_ci[0], 4), round(sr_ci[1], 4)],
            # `is not None` (not truthiness): a legitimate mean of 0.0 steps
            # must serialize as 0.0 rather than null.
            "avg_steps_success": round(avg_succ, 1) if avg_succ is not None else None,
            "avg_steps_failure": round(avg_fail, 1) if avg_fail is not None else None,
            "avg_reward": round(self.avg_reward, 4),
            "timeout_rate": round(self.timeout_rate, 4),
            "learning_curve_blocks": [round(v, 4) for v in self.learning_curve_blocks()],
            "trials": [
                {
                    "steps": t.steps,
                    "reward": round(t.reward, 4),
                    "success": t.success,
                    "n_actions": len(t.actions),
                    "action_entropy": round(t.action_entropy, 4),
                    "forward_ratio": round(t.forward_ratio, 4),
                    "rotation_ratio": round(t.rotation_ratio, 4),
                }
                for t in self.trials
            ],
        }
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def compute_trial_metrics(actions: List[str]) -> TrialMetrics:
    """Compute rich strategy metrics from an action sequence."""
    metrics = TrialMetrics(steps=len(actions), actions=actions)
    if not actions:
        return metrics

    total = len(actions)

    # Tally actions case-insensitively.
    tallies: Dict[str, int] = {}
    for act in actions:
        key = act.upper()
        tallies[key] = tallies.get(key, 0) + 1

    metrics.forward_ratio = tallies.get("FORWARD", 0) / total
    rotations = tallies.get("ROTATE_LEFT", 0) + tallies.get("ROTATE_RIGHT", 0)
    metrics.rotation_ratio = rotations / total
    metrics.stay_ratio = tallies.get("STAY", 0) / total

    # Shannon entropy (bits) of the action distribution. A single distinct
    # action carries no information, so entropy stays at its 0.0 default.
    if len(tallies) > 1:
        metrics.action_entropy = -sum(
            (c / total) * math.log2(c / total) for c in tallies.values()
        )

    # Back-to-back repeats, compared exactly as recorded.
    if total > 1:
        repeats = sum(1 for prev, cur in zip(actions, actions[1:]) if prev == cur)
        metrics.action_repetition_rate = repeats / (total - 1)

    # Count left<->right reversals within the rotation subsequence.
    flips = 0
    prev_rotation = None
    for act in actions:
        upper = act.upper()
        if upper in ("ROTATE_LEFT", "ROTATE_RIGHT"):
            if prev_rotation is not None and prev_rotation != upper:
                flips += 1
            prev_rotation = upper
    metrics.direction_changes = flips

    return metrics
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
# ============================================================================
|
|
339
|
+
# Cognitive Profile
|
|
340
|
+
# ============================================================================
|
|
341
|
+
|
|
342
|
+
def compute_cognitive_profile(
    results: List[EnvResult],
) -> Dict[str, float]:
    """
    Compute a cognitive profile (radar chart values) from benchmark results.

    Returns a dict mapping each cognitive dimension to a score in [0, 1],
    computed as the weighted average of success rates from contributing envs.
    """
    # Collect (success_rate, weight) pairs per dimension.
    contributions: Dict[str, List[Tuple[float, float]]] = defaultdict(list)
    for result in results:
        for dim, weight in ENV_COGNITIVE_MAP.get(result.env_name, {}).items():
            contributions[dim].append((result.success_rate, weight))

    profile: Dict[str, float] = {}
    for dim in COGNITIVE_DIMENSIONS:
        entries = contributions[dim]
        weight_sum = sum(w for _, w in entries)
        # Dimensions with no contributing env (or zero total weight) score 0.
        if weight_sum > 0:
            profile[dim] = sum(sr * w for sr, w in entries) / weight_sum
        else:
            profile[dim] = 0.0
    return profile
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def compute_animal_profile() -> Dict[str, float]:
    """
    Compute the 'animal baseline' cognitive profile from published data.
    Uses the final-session success rate for each environment.
    """
    synthetic: List[EnvResult] = []
    for env_name, baseline in ANIMAL_BASELINES.items():
        curve = baseline.get("learning_curve", [])
        final_sr = curve[-1] if curve else 0.5
        result = EnvResult(env_name=env_name, view_mode="N/A", agent_type="animal")
        # Synthesize 20 trials whose success rate matches the published
        # final-session rate, so the profile math can reuse EnvResult.
        n_trials = 20
        n_success = round(final_sr * n_trials)
        result.trials = [
            TrialMetrics(success=(idx < n_success), steps=50)
            for idx in range(n_trials)
        ]
        synthetic.append(result)
    return compute_cognitive_profile(synthetic)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
# ============================================================================
|
|
391
|
+
# Result Loading & Parsing
|
|
392
|
+
# ============================================================================
|
|
393
|
+
|
|
394
|
+
def load_results(path: str) -> List[EnvResult]:
    """Load benchmark results JSON and parse into EnvResult objects."""
    with open(path) as fh:
        payload = json.load(fh)

    parsed: List[EnvResult] = []
    for entry in payload.get("results", []):
        result = EnvResult(
            env_name=entry["env_name"],
            view_mode=entry["view_mode"],
            agent_type=entry["agent_type"],
        )
        for raw in entry.get("trials", []):
            action_seq = raw.get("actions", [])
            trial = compute_trial_metrics(action_seq)
            trial.reward = raw.get("reward", 0.0)
            trial.success = raw.get("success", False)
            # Prefer the recorded step count; fall back to the action count.
            trial.steps = raw.get("steps", len(action_seq))
            result.trials.append(trial)
        parsed.append(result)
    return parsed
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def group_by_agent(results: List[EnvResult]) -> Dict[str, List[EnvResult]]:
    """Group results by agent type."""
    grouped: Dict[str, List[EnvResult]] = {}
    for result in results:
        grouped.setdefault(result.agent_type, []).append(result)
    return grouped
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def group_by_env(results: List[EnvResult]) -> Dict[str, List[EnvResult]]:
    """Group results by environment name."""
    grouped: Dict[str, List[EnvResult]] = {}
    for result in results:
        grouped.setdefault(result.env_name, []).append(result)
    return grouped
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
# ============================================================================
|
|
434
|
+
# Summary Report
|
|
435
|
+
# ============================================================================
|
|
436
|
+
|
|
437
|
+
def generate_summary(results: List[EnvResult]) -> Dict:
    """Generate a comprehensive summary report."""
    summary: Dict = {
        "agents": {},
        "cognitive_profiles": {},
        "animal_baseline_profile": compute_animal_profile(),
    }

    for agent_name, agent_results in group_by_agent(results).items():
        by_env = group_by_env(agent_results)
        env_summaries: Dict[str, Dict] = {}
        best_per_env: List[EnvResult] = []

        for env_name, env_results in by_env.items():
            # The best-performing view mode represents this env for the agent.
            best = max(env_results, key=lambda r: r.success_rate)
            best_per_env.append(best)

            per_view: Dict[str, Dict] = {}
            for r in env_results:
                per_view[r.view_mode] = {
                    "success_rate": r.success_rate,
                    "ci_95": list(r.success_rate_ci),
                    "avg_steps_success": r.avg_steps_success,
                    "learning_curve": r.learning_curve_blocks(),
                }

            env_summaries[env_name] = {
                "best_view_mode": best.view_mode,
                "best_success_rate": best.success_rate,
                "by_view_mode": per_view,
            }

        if best_per_env:
            overall = sum(r.success_rate for r in best_per_env) / len(best_per_env)
        else:
            overall = 0
        summary["agents"][agent_name] = {
            "overall_success_rate": overall,
            "environments": env_summaries,
        }
        # Cognitive profile for this agent (using best per-env results).
        summary["cognitive_profiles"][agent_name] = compute_cognitive_profile(best_per_env)

    return summary
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def print_report(results: List[EnvResult]):
|
|
484
|
+
"""Print a formatted text report to stdout."""
|
|
485
|
+
by_agent = group_by_agent(results)
|
|
486
|
+
|
|
487
|
+
print("\n" + "=" * 80)
|
|
488
|
+
print("CHEESEBENCH RESULTS REPORT")
|
|
489
|
+
print("=" * 80)
|
|
490
|
+
|
|
491
|
+
for agent_name, agent_results in sorted(by_agent.items()):
|
|
492
|
+
print(f"\n{'─' * 60}")
|
|
493
|
+
print(f"Agent: {agent_name}")
|
|
494
|
+
print(f"{'─' * 60}")
|
|
495
|
+
|
|
496
|
+
by_env = group_by_env(agent_results)
|
|
497
|
+
for env_name in sorted(by_env.keys()):
|
|
498
|
+
env_results = by_env[env_name]
|
|
499
|
+
print(f"\n {env_name}:")
|
|
500
|
+
for r in sorted(env_results, key=lambda x: x.view_mode):
|
|
501
|
+
ci = r.success_rate_ci
|
|
502
|
+
steps_str = f"steps={r.avg_steps_success:.0f}" if r.avg_steps_success else "N/A"
|
|
503
|
+
print(
|
|
504
|
+
f" {r.view_mode:<14} "
|
|
505
|
+
f"SR={r.success_rate:.0%} [{ci[0]:.0%}-{ci[1]:.0%}] "
|
|
506
|
+
f"{steps_str} "
|
|
507
|
+
f"timeout={r.timeout_rate:.0%}"
|
|
508
|
+
)
|
|
509
|
+
lc = r.learning_curve_blocks()
|
|
510
|
+
if len(lc) > 1:
|
|
511
|
+
lc_str = " → ".join(f"{v:.0%}" for v in lc)
|
|
512
|
+
print(f"{'':>20}learning: {lc_str}")
|
|
513
|
+
|
|
514
|
+
# Cognitive profile
|
|
515
|
+
best_per_env = []
|
|
516
|
+
for env_results in by_env.values():
|
|
517
|
+
best_per_env.append(max(env_results, key=lambda r: r.success_rate))
|
|
518
|
+
profile = compute_cognitive_profile(best_per_env)
|
|
519
|
+
print(f"\n Cognitive Profile:")
|
|
520
|
+
for dim in COGNITIVE_DIMENSIONS:
|
|
521
|
+
bar = "█" * int(profile[dim] * 20) + "░" * (20 - int(profile[dim] * 20))
|
|
522
|
+
print(f" {dim:<32} {bar} {profile[dim]:.0%}")
|
|
523
|
+
|
|
524
|
+
# Animal baseline
|
|
525
|
+
animal = compute_animal_profile()
|
|
526
|
+
print(f"\n{'─' * 60}")
|
|
527
|
+
print("Animal Baselines (final session, from published literature):")
|
|
528
|
+
print(f"{'─' * 60}")
|
|
529
|
+
for dim in COGNITIVE_DIMENSIONS:
|
|
530
|
+
bar = "█" * int(animal[dim] * 20) + "░" * (20 - int(animal[dim] * 20))
|
|
531
|
+
print(f" {dim:<32} {bar} {animal[dim]:.0%}")
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
# ============================================================================
|
|
535
|
+
# CLI
|
|
536
|
+
# ============================================================================
|
|
537
|
+
|
|
538
|
+
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python analysis.py <results.json>")
        sys.exit(1)

    results = load_results(sys.argv[1])
    print_report(results)

    summary = generate_summary(results)
    # Derive the output name with pathlib rather than str.replace():
    # replace() silently yields out_path == input (and json.dump would then
    # overwrite the input file) when the argument does not end in ".json",
    # and it also rewrites ".json" occurring elsewhere in the path.
    in_path = Path(sys.argv[1])
    out_path = in_path.with_name(in_path.stem + "_analysis.json")
    with open(out_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\nAnalysis saved to {out_path}")
|