cheesebench 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
analysis.py ADDED
@@ -0,0 +1,552 @@
1
+ """
2
+ CheeseBench Analysis Pipeline
3
+
4
+ Computes cognitive profiles, learning curves, strategy metrics,
5
+ and generates publication-quality figures from benchmark results.
6
+ """
7
+
8
+ import json
9
+ import math
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from dataclasses import dataclass, field
13
+ from typing import Dict, List, Optional, Tuple
14
+ from collections import defaultdict
15
+
16
+
17
# ============================================================================
# Cognitive Taxonomy — maps environments to cognitive dimensions
# ============================================================================

# Canonical ordering of the six cognitive dimensions scored by the benchmark;
# reports and radar profiles iterate this list in order.
COGNITIVE_DIMENSIONS = [
    "Allocentric Spatial Learning",
    "Egocentric Navigation",
    "Working Memory",
    "Instrumental Conditioning",
    "Avoidance Learning",
    "Associative Learning",
]

# Each environment loads on one or more dimensions with a weight in (0, 1];
# the weights for a single environment sum to 1.0.
ENV_COGNITIVE_MAP: Dict[str, Dict[str, float]] = {
    "MorrisWaterMaze": {"Allocentric Spatial Learning": 1.0},
    "BarnesMaze": {"Allocentric Spatial Learning": 1.0},
    "TMaze": {"Egocentric Navigation": 0.7, "Working Memory": 0.3},
    "StarMaze": {"Allocentric Spatial Learning": 0.5, "Egocentric Navigation": 0.5},
    "RadialArmMaze": {"Working Memory": 0.7, "Allocentric Spatial Learning": 0.3},
    "OperantChamber": {"Instrumental Conditioning": 1.0},
    "ShuttleBox": {"Avoidance Learning": 1.0},
    "PlacePreference": {"Associative Learning": 1.0},
    "DNMSTask": {"Working Memory": 1.0},
}

# Neural circuit dependencies per dimension (cited in the discussion section).
NEURAL_CIRCUITS: Dict[str, str] = {
    "Allocentric Spatial Learning": "Hippocampus (place cells, grid cells)",
    "Egocentric Navigation": "Dorsomedial striatum, parietal cortex",
    "Working Memory": "Prefrontal cortex, hippocampus",
    "Instrumental Conditioning": "Dorsolateral striatum, nucleus accumbens",
    "Avoidance Learning": "Amygdala, periaqueductal gray",
    "Associative Learning": "Ventral tegmental area, nucleus accumbens",
}
73
+
74
+
75
# ============================================================================
# Published Animal Baselines — extracted from source papers
# ============================================================================

# Per-environment behavioral baselines taken from the cited publications.
# "learning_curve" holds per-session success rates; its first and last
# entries mirror the explicit session-1 / final-session keys.
ANIMAL_BASELINES: Dict[str, Dict] = {
    "MorrisWaterMaze": {
        "species": "C57BL/6 mice",
        "source": "PMC3259155 (de Fiebre et al., 2006)",
        "trials_to_criterion": 15,
        "success_rate_session_1": 0.30,  # ~30% find platform in first session
        "success_rate_session_5": 0.85,  # ~85% by session 5
        "learning_curve": [0.30, 0.50, 0.65, 0.78, 0.85],  # per-session
        "avg_latency_session_1_s": 55.0,  # seconds
        "avg_latency_session_5_s": 15.0,
        "notes": "Platform acquisition over 5 sessions, 5 trials/session",
    },
    "TMaze": {
        "species": "NMRI mice",
        "source": "PMC3399492 (Shoji et al., 2012)",
        "trials_to_criterion": 20,
        "success_rate_session_1": 0.50,  # chance level
        "success_rate_session_4": 0.80,
        "learning_curve": [0.50, 0.60, 0.72, 0.80],
        "notes": "Forced alternation, 10 trials/session over 4 days",
    },
    "BarnesMaze": {
        "species": "B6C3F1/J mice",
        "source": "PMC1783636 (Harrison et al., 2006)",
        "trials_to_criterion": 20,
        "success_rate_session_1": 0.20,
        "success_rate_session_5": 0.80,
        "learning_curve": [0.20, 0.40, 0.55, 0.70, 0.80],
        "avg_latency_session_1_s": 120.0,
        "avg_latency_session_5_s": 25.0,
        "notes": "12 holes, 4 trials/session over 5 sessions",
    },
    "RadialArmMaze": {
        "species": "C57BL/6 mice",
        "source": "PMC4030456 (Penley et al., 2013)",
        "trials_to_criterion": 36,
        "success_rate_session_1": 0.15,
        "success_rate_session_6": 0.70,
        "learning_curve": [0.15, 0.25, 0.40, 0.50, 0.60, 0.70],
        "working_memory_errors_session_1": 3.5,
        "working_memory_errors_session_6": 0.8,
        "notes": "8 arms, 4 baited. Errors = revisits to depleted arms",
    },
    "OperantChamber": {
        "species": "C57BL/6 mice",
        "source": "PMC6619163 (Jurado-Parras et al., 2013)",
        "trials_to_criterion": 50,
        "success_rate_session_1": 0.40,
        "success_rate_session_5": 0.90,
        "learning_curve": [0.40, 0.60, 0.75, 0.85, 0.90],
        "avg_lever_presses_session_1": 12.0,
        "avg_lever_presses_session_5": 45.0,
        "notes": "FR-1 schedule, 30 min sessions",
    },
    "ShuttleBox": {
        "species": "Sprague Dawley rats",
        "source": "PMC4633642 (Chacon et al., 2016)",
        "trials_to_criterion": 90,
        "success_rate_session_1": 0.10,
        "success_rate_session_5": 0.70,
        "learning_curve": [0.10, 0.25, 0.45, 0.60, 0.70],
        "notes": "Active avoidance, 30 trials/session over 5 sessions",
    },
    "PlacePreference": {
        "species": "C57BL/6 mice",
        "source": "PMC6101638 (Blanco-Gandía et al., 2018)",
        "trials_to_criterion": 6,
        "success_rate_session_1": 0.50,  # chance
        "success_rate_session_6": 0.75,
        "learning_curve": [0.50, 0.55, 0.60, 0.65, 0.70, 0.75],
        "time_in_paired_chamber_pct_pre": 50.0,
        "time_in_paired_chamber_pct_post": 65.0,
        "notes": "3-phase CPP protocol, 6 conditioning sessions",
    },
    "StarMaze": {
        "species": "C57BL/6 mice",
        "source": "PMC3695082 (Rondi-Reig et al., 2006)",
        "trials_to_criterion": 40,
        "success_rate_session_1": 0.25,
        "success_rate_session_10": 0.80,
        "learning_curve": [0.25, 0.30, 0.40, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80],
        "notes": "5-arm maze, 4 trials/session over 10 sessions",
    },
    "DNMSTask": {
        "species": "Long-Evans rats",
        "source": "PMC3982138 (Oomen et al., 2013)",
        "trials_to_criterion": 2856,
        "success_rate_session_1": 0.50,  # chance (2AFC)
        "success_rate_session_30": 0.80,
        "learning_curve": [0.50, 0.55, 0.58, 0.60, 0.62, 0.64, 0.66, 0.68, 0.70, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.78, 0.79, 0.79, 0.79, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80],
        "notes": "TUNL task, 84 trials/session, ~30 sessions to criterion",
    },
}
172
+
173
+
174
+ # ============================================================================
175
+ # Metric Computation
176
+ # ============================================================================
177
+
178
@dataclass
class TrialMetrics:
    """Rich metrics for a single trial."""

    # Raw observations recorded during the trial
    steps: int = 0          # number of steps taken
    reward: float = 0.0     # total reward collected
    success: bool = False   # whether the trial ended in success
    actions: List[str] = field(default_factory=list)  # raw action sequence

    # Derived strategy metrics (populated by compute_trial_metrics)
    action_entropy: float = 0.0          # Shannon entropy of action distribution
    forward_ratio: float = 0.0           # fraction of FORWARD actions
    rotation_ratio: float = 0.0          # fraction of ROTATE_LEFT/ROTATE_RIGHT
    stay_ratio: float = 0.0              # fraction of STAY actions
    action_repetition_rate: float = 0.0  # same action back-to-back
    direction_changes: int = 0           # how often rotation direction changes
193
+
194
+
195
@dataclass
class EnvResult:
    """Aggregated results for one (env, view_mode, agent) combination."""

    env_name: str
    view_mode: str
    agent_type: str
    trials: List[TrialMetrics] = field(default_factory=list)

    @property
    def n_trials(self) -> int:
        """Number of recorded trials."""
        return len(self.trials)

    @property
    def success_rate(self) -> float:
        """Fraction of successful trials; 0.0 when there are no trials."""
        if not self.trials:
            return 0.0
        return sum(1 for t in self.trials if t.success) / len(self.trials)

    @property
    def success_rate_ci(self) -> Tuple[float, float]:
        """95% Wilson score confidence interval for the binomial proportion.

        Returns (0.0, 0.0) when there are no trials; bounds are clamped
        to [0, 1].
        """
        n = len(self.trials)
        if n == 0:
            return (0.0, 0.0)
        p = self.success_rate
        z = 1.96  # two-sided 95% z-score
        denom = 1 + z**2 / n
        center = (p + z**2 / (2 * n)) / denom
        spread = z * math.sqrt((p * (1 - p) + z**2 / (4 * n)) / n) / denom
        return (max(0.0, center - spread), min(1.0, center + spread))

    @property
    def avg_steps_success(self) -> Optional[float]:
        """Mean step count over successful trials, or None if none succeeded."""
        succ = [t.steps for t in self.trials if t.success]
        return sum(succ) / len(succ) if succ else None

    @property
    def avg_steps_failure(self) -> Optional[float]:
        """Mean step count over failed trials, or None if none failed."""
        fail = [t.steps for t in self.trials if not t.success]
        return sum(fail) / len(fail) if fail else None

    @property
    def avg_reward(self) -> float:
        """Mean reward per trial; 0.0 when there are no trials."""
        if not self.trials:
            return 0.0
        return sum(t.reward for t in self.trials) / len(self.trials)

    @property
    def timeout_rate(self) -> float:
        """Fraction of trials that hit max steps without success.

        The true step cap is not stored, so it is inferred as the maximum
        observed step count; failed trials within 95% of that value count
        as timeouts. (The previous `else 200` fallback was unreachable —
        the empty case is handled by the early return.)
        """
        if not self.trials:
            return 0.0
        max_s = max(t.steps for t in self.trials)
        timeouts = sum(1 for t in self.trials if not t.success and t.steps >= max_s * 0.95)
        return timeouts / len(self.trials)

    def learning_curve(self, window: int = 4) -> List[float]:
        """Smoothed success rate over trials (rolling window)."""
        if len(self.trials) < window:
            return [self.success_rate]
        curve = []
        for i in range(len(self.trials) - window + 1):
            block = self.trials[i:i + window]
            curve.append(sum(1 for t in block if t.success) / len(block))
        return curve

    def learning_curve_blocks(self, block_size: int = 5) -> List[float]:
        """Success rate per non-overlapping block of `block_size` trials."""
        curve = []
        for i in range(0, len(self.trials), block_size):
            block = self.trials[i:i + block_size]
            if block:
                curve.append(sum(1 for t in block if t.success) / len(block))
        return curve

    def to_dict(self) -> Dict:
        """Serialize to a JSON-friendly dict (floats rounded for readability)."""
        sr_ci = self.success_rate_ci
        return {
            "env_name": self.env_name,
            "view_mode": self.view_mode,
            "agent_type": self.agent_type,
            "n_trials": self.n_trials,
            "success_rate": round(self.success_rate, 4),
            "success_rate_ci_95": [round(sr_ci[0], 4), round(sr_ci[1], 4)],
            # `is not None` so a legitimate 0.0 average is not collapsed to None
            "avg_steps_success": round(self.avg_steps_success, 1) if self.avg_steps_success is not None else None,
            "avg_steps_failure": round(self.avg_steps_failure, 1) if self.avg_steps_failure is not None else None,
            "avg_reward": round(self.avg_reward, 4),
            "timeout_rate": round(self.timeout_rate, 4),
            "learning_curve_blocks": [round(v, 4) for v in self.learning_curve_blocks()],
            "trials": [
                {
                    "steps": t.steps,
                    "reward": round(t.reward, 4),
                    "success": t.success,
                    "n_actions": len(t.actions),
                    "action_entropy": round(t.action_entropy, 4),
                    "forward_ratio": round(t.forward_ratio, 4),
                    "rotation_ratio": round(t.rotation_ratio, 4),
                }
                for t in self.trials
            ],
        }
297
+
298
+
299
def compute_trial_metrics(actions: List[str]) -> TrialMetrics:
    """Compute strategy metrics (action mix, entropy, repetition) for one trial."""
    metrics = TrialMetrics(steps=len(actions), actions=actions)
    if not actions:
        return metrics

    normalized = [a.upper() for a in actions]
    total = len(normalized)

    tally = defaultdict(int)
    for act in normalized:
        tally[act] += 1

    # Action-type composition
    metrics.forward_ratio = tally["FORWARD"] / total
    metrics.rotation_ratio = (tally["ROTATE_LEFT"] + tally["ROTATE_RIGHT"]) / total
    metrics.stay_ratio = tally["STAY"] / total

    # Shannon entropy of the action distribution (0.0 with a single action type)
    probs = [c / total for c in tally.values() if c > 0]
    if len(probs) > 1:
        metrics.action_entropy = -sum(p * math.log2(p) for p in probs)

    # Fraction of consecutive identical actions (compared as recorded)
    if total > 1:
        repeats = sum(1 for prev, cur in zip(actions, actions[1:]) if prev == cur)
        metrics.action_repetition_rate = repeats / (total - 1)

    # Count left<->right rotation reversals
    flips = 0
    prev_rot = None
    for act in normalized:
        if act in ("ROTATE_LEFT", "ROTATE_RIGHT"):
            if prev_rot is not None and prev_rot != act:
                flips += 1
            prev_rot = act
    metrics.direction_changes = flips

    return metrics
336
+
337
+
338
+ # ============================================================================
339
+ # Cognitive Profile
340
+ # ============================================================================
341
+
342
def compute_cognitive_profile(
    results: List[EnvResult],
) -> Dict[str, float]:
    """
    Build a cognitive profile (radar chart values) from benchmark results.

    Returns a dict mapping each cognitive dimension to a score in [0, 1],
    computed as the weighted average of success rates from the environments
    that load on that dimension.
    """
    buckets: Dict[str, List[Tuple[float, float]]] = {dim: [] for dim in COGNITIVE_DIMENSIONS}

    for result in results:
        for dim, weight in ENV_COGNITIVE_MAP.get(result.env_name, {}).items():
            buckets[dim].append((result.success_rate, weight))

    profile: Dict[str, float] = {}
    for dim in COGNITIVE_DIMENSIONS:
        pairs = buckets[dim]
        weight_sum = sum(w for _, w in pairs)
        if pairs and weight_sum > 0:
            profile[dim] = sum(sr * w for sr, w in pairs) / weight_sum
        else:
            # No contributing environments (or degenerate weights) → score 0
            profile[dim] = 0.0

    return profile
368
+
369
+
370
def compute_animal_profile() -> Dict[str, float]:
    """
    Compute the 'animal baseline' cognitive profile from published data.

    Uses the final-session success rate (last point of each environment's
    published learning curve, falling back to chance 0.5 when absent) and
    averages it into each dimension with the ENV_COGNITIVE_MAP weights.

    The rates are used exactly: the previous implementation synthesized 20
    fake trials per environment, which quantized every baseline to the
    nearest 1/20 via round() — harmless for the current data (all multiples
    of 0.05) but silently distorting for any other value.
    """
    buckets: Dict[str, List[Tuple[float, float]]] = {d: [] for d in COGNITIVE_DIMENSIONS}

    for env_name, baseline in ANIMAL_BASELINES.items():
        lc = baseline.get("learning_curve", [])
        final_sr = lc[-1] if lc else 0.5  # chance fallback when no curve
        for dim, weight in ENV_COGNITIVE_MAP.get(env_name, {}).items():
            buckets[dim].append((final_sr, weight))

    profile: Dict[str, float] = {}
    for dim in COGNITIVE_DIMENSIONS:
        pairs = buckets[dim]
        total_w = sum(w for _, w in pairs)
        profile[dim] = sum(sr * w for sr, w in pairs) / total_w if pairs and total_w > 0 else 0.0

    return profile
388
+
389
+
390
+ # ============================================================================
391
+ # Result Loading & Parsing
392
+ # ============================================================================
393
+
394
def load_results(path: str) -> List[EnvResult]:
    """Load benchmark results JSON and parse into EnvResult objects.

    Args:
        path: Path to a JSON file with a top-level "results" list; each
            entry carries env_name/view_mode/agent_type and a "trials" list.

    Returns:
        One EnvResult per entry, with per-trial strategy metrics derived
        from each trial's action sequence.
    """
    # Explicit encoding: JSON interchange is UTF-8 (RFC 8259); the platform
    # default text encoding is not guaranteed to decode it correctly.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    env_results = []
    for entry in data.get("results", []):
        r = EnvResult(
            env_name=entry["env_name"],
            view_mode=entry["view_mode"],
            agent_type=entry["agent_type"],
        )
        for t in entry.get("trials", []):
            actions = t.get("actions", [])
            m = compute_trial_metrics(actions)
            m.reward = t.get("reward", 0.0)
            m.success = t.get("success", False)
            # Prefer the recorded step count; fall back to the action count.
            m.steps = t.get("steps", len(actions))
            r.trials.append(m)
        env_results.append(r)
    return env_results
415
+
416
+
417
def group_by_agent(results: List[EnvResult]) -> Dict[str, List[EnvResult]]:
    """Group results by agent type."""
    grouped: Dict[str, List[EnvResult]] = {}
    for result in results:
        grouped.setdefault(result.agent_type, []).append(result)
    return grouped
423
+
424
+
425
def group_by_env(results: List[EnvResult]) -> Dict[str, List[EnvResult]]:
    """Group results by environment name."""
    grouped: Dict[str, List[EnvResult]] = {}
    for result in results:
        grouped.setdefault(result.env_name, []).append(result)
    return grouped
431
+
432
+
433
+ # ============================================================================
434
+ # Summary Report
435
+ # ============================================================================
436
+
437
def generate_summary(results: List[EnvResult]) -> Dict:
    """Generate a comprehensive, JSON-serializable summary report."""
    summary: Dict = {
        "agents": {},
        "cognitive_profiles": {},
        "animal_baseline_profile": compute_animal_profile(),
    }

    for agent_name, agent_results in group_by_agent(results).items():
        per_env = group_by_env(agent_results)
        env_summaries = {}
        best_per_env = []

        for env_name, env_group in per_env.items():
            # The best view mode for this env is the one with the highest
            # success rate; it also feeds the agent's cognitive profile.
            best = max(env_group, key=lambda r: r.success_rate)
            best_per_env.append(best)

            view_breakdown = {}
            for r in env_group:
                view_breakdown[r.view_mode] = {
                    "success_rate": r.success_rate,
                    "ci_95": list(r.success_rate_ci),
                    "avg_steps_success": r.avg_steps_success,
                    "learning_curve": r.learning_curve_blocks(),
                }

            env_summaries[env_name] = {
                "best_view_mode": best.view_mode,
                "best_success_rate": best.success_rate,
                "by_view_mode": view_breakdown,
            }

        if best_per_env:
            overall = sum(r.success_rate for r in best_per_env) / len(best_per_env)
        else:
            overall = 0

        summary["agents"][agent_name] = {
            "overall_success_rate": overall,
            "environments": env_summaries,
        }
        summary["cognitive_profiles"][agent_name] = compute_cognitive_profile(best_per_env)

    return summary
481
+
482
+
483
def print_report(results: List[EnvResult]):
    """Print a formatted text report to stdout."""

    def render_bar(score: float) -> str:
        # 20-character filled/empty bar for a score in [0, 1]
        filled = int(score * 20)
        return "█" * filled + "░" * (20 - filled)

    grouped = group_by_agent(results)

    print("\n" + "=" * 80)
    print("CHEESEBENCH RESULTS REPORT")
    print("=" * 80)

    for agent_name in sorted(grouped):
        agent_results = grouped[agent_name]
        rule = "─" * 60
        print(f"\n{rule}")
        print(f"Agent: {agent_name}")
        print(f"{rule}")

        per_env = group_by_env(agent_results)
        for env_name in sorted(per_env):
            print(f"\n  {env_name}:")
            for r in sorted(per_env[env_name], key=lambda x: x.view_mode):
                lo, hi = r.success_rate_ci
                steps_str = f"steps={r.avg_steps_success:.0f}" if r.avg_steps_success else "N/A"
                print(
                    f"    {r.view_mode:<14} "
                    f"SR={r.success_rate:.0%} [{lo:.0%}-{hi:.0%}] "
                    f"{steps_str} "
                    f"timeout={r.timeout_rate:.0%}"
                )
                curve = r.learning_curve_blocks()
                if len(curve) > 1:
                    curve_str = " → ".join(f"{v:.0%}" for v in curve)
                    print(f"{'':>20}learning: {curve_str}")

        # Cognitive profile built from the best view mode per environment
        best_per_env = [max(group, key=lambda r: r.success_rate) for group in per_env.values()]
        profile = compute_cognitive_profile(best_per_env)
        print(f"\n  Cognitive Profile:")
        for dim in COGNITIVE_DIMENSIONS:
            print(f"    {dim:<32} {render_bar(profile[dim])} {profile[dim]:.0%}")

    # Published animal baseline for comparison
    animal = compute_animal_profile()
    rule = "─" * 60
    print(f"\n{rule}")
    print("Animal Baselines (final session, from published literature):")
    print(f"{rule}")
    for dim in COGNITIVE_DIMENSIONS:
        print(f"    {dim:<32} {render_bar(animal[dim])} {animal[dim]:.0%}")
532
+
533
+
534
# ============================================================================
# CLI
# ============================================================================

if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python analysis.py <results.json>")
        sys.exit(1)

    results = load_results(sys.argv[1])
    print_report(results)

    summary = generate_summary(results)
    in_path = sys.argv[1]
    if in_path.endswith(".json"):
        out_path = in_path[: -len(".json")] + "_analysis.json"
    else:
        # Append rather than replace: `str.replace(".json", ...)` was a
        # silent no-op for inputs without a .json extension, which made
        # out_path identical to the input and overwrote the results file.
        out_path = in_path + "_analysis.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"\nAnalysis saved to {out_path}")