genarena 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. genarena/__init__.py +49 -2
  2. genarena/__main__.py +10 -0
  3. genarena/arena.py +1685 -0
  4. genarena/battle.py +337 -0
  5. genarena/bt_elo.py +507 -0
  6. genarena/cli.py +1581 -0
  7. genarena/data.py +476 -0
  8. genarena/deploy/Dockerfile +25 -0
  9. genarena/deploy/README.md +55 -0
  10. genarena/deploy/__init__.py +5 -0
  11. genarena/deploy/app.py +84 -0
  12. genarena/experiments.py +121 -0
  13. genarena/leaderboard.py +270 -0
  14. genarena/logs.py +409 -0
  15. genarena/models.py +412 -0
  16. genarena/prompts/__init__.py +127 -0
  17. genarena/prompts/mmrb2.py +373 -0
  18. genarena/sampling.py +336 -0
  19. genarena/state.py +656 -0
  20. genarena/sync/__init__.py +105 -0
  21. genarena/sync/auto_commit.py +118 -0
  22. genarena/sync/deploy_ops.py +543 -0
  23. genarena/sync/git_ops.py +422 -0
  24. genarena/sync/hf_ops.py +891 -0
  25. genarena/sync/init_ops.py +431 -0
  26. genarena/sync/packer.py +587 -0
  27. genarena/sync/submit.py +837 -0
  28. genarena/utils.py +103 -0
  29. genarena/validation/__init__.py +19 -0
  30. genarena/validation/schema.py +327 -0
  31. genarena/validation/validator.py +329 -0
  32. genarena/visualize/README.md +148 -0
  33. genarena/visualize/__init__.py +14 -0
  34. genarena/visualize/app.py +938 -0
  35. genarena/visualize/data_loader.py +2335 -0
  36. genarena/visualize/static/app.js +3762 -0
  37. genarena/visualize/static/model_aliases.json +86 -0
  38. genarena/visualize/static/style.css +4104 -0
  39. genarena/visualize/templates/index.html +413 -0
  40. genarena/vlm.py +519 -0
  41. genarena-0.1.0.dist-info/METADATA +178 -0
  42. genarena-0.1.0.dist-info/RECORD +44 -0
  43. {genarena-0.0.1.dist-info → genarena-0.1.0.dist-info}/WHEEL +1 -2
  44. genarena-0.1.0.dist-info/entry_points.txt +2 -0
  45. genarena-0.0.1.dist-info/METADATA +0 -26
  46. genarena-0.0.1.dist-info/RECORD +0 -5
  47. genarena-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,2335 @@
1
+ # Copyright 2026 Ruihang Li.
2
+ # Licensed under the Apache License, Version 2.0.
3
+ # See LICENSE file in the project root for details.
4
+
5
+ """Data loader for arena visualization with preloading support."""
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Optional
13
+
14
+ from genarena.data import DataSample, ParquetDataset, discover_subsets
15
+ from genarena.models import GlobalModelOutputManager
16
+ from genarena.state import ArenaState, load_state
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class BattleRecord:
24
+ """A single battle record with all relevant information."""
25
+
26
+ # Battle identification
27
+ subset: str
28
+ exp_name: str
29
+ sample_index: int
30
+ model_a: str
31
+ model_b: str
32
+
33
+ # Battle result
34
+ final_winner: str # model name or "tie"
35
+ is_consistent: bool
36
+ timestamp: str = ""
37
+
38
+ # Raw VLM outputs (from audit logs, optional)
39
+ original_call: Optional[dict[str, Any]] = None
40
+ swapped_call: Optional[dict[str, Any]] = None
41
+
42
+ # Sample data (loaded on demand)
43
+ instruction: str = ""
44
+ task_type: str = ""
45
+ input_image_count: int = 1
46
+ prompt_source: Optional[str] = None
47
+ original_metadata: Optional[dict[str, Any]] = None
48
+
49
+ @property
50
+ def id(self) -> str:
51
+ """Unique identifier for this battle."""
52
+ return f"{self.subset}:{self.exp_name}:{self.model_a}_vs_{self.model_b}:{self.sample_index}"
53
+
54
+ @property
55
+ def winner_display(self) -> str:
56
+ """Display-friendly winner string."""
57
+ if self.final_winner == "tie":
58
+ return "Tie"
59
+ return self.final_winner
60
+
61
+ @property
62
+ def models(self) -> set[str]:
63
+ """Set of models involved in this battle."""
64
+ return {self.model_a, self.model_b}
65
+
66
+ def to_dict(self) -> dict[str, Any]:
67
+ """Convert to dictionary for JSON serialization."""
68
+ return {
69
+ "id": self.id,
70
+ "subset": self.subset,
71
+ "exp_name": self.exp_name,
72
+ "sample_index": self.sample_index,
73
+ "model_a": self.model_a,
74
+ "model_b": self.model_b,
75
+ "final_winner": self.final_winner,
76
+ "winner_display": self.winner_display,
77
+ "is_consistent": self.is_consistent,
78
+ "timestamp": self.timestamp,
79
+ "instruction": self.instruction,
80
+ "task_type": self.task_type,
81
+ "input_image_count": self.input_image_count,
82
+ "prompt_source": self.prompt_source,
83
+ "original_metadata": self.original_metadata,
84
+ "has_audit": self.original_call is not None,
85
+ }
86
+
87
+ def to_detail_dict(self) -> dict[str, Any]:
88
+ """Convert to detailed dictionary including VLM outputs."""
89
+ d = self.to_dict()
90
+ d["original_call"] = self.original_call
91
+ d["swapped_call"] = self.swapped_call
92
+ return d
93
+
94
+
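An illustrative sketch of how the BattleRecord dataclass above fits together; every field value here is invented for the example:

from genarena.visualize.data_loader import BattleRecord

record = BattleRecord(
    subset="subset01",        # hypothetical subset name
    exp_name="exp01",         # hypothetical experiment name
    sample_index=7,
    model_a="model-x",
    model_b="model-y",
    final_winner="model-x",
    is_consistent=True,
)
print(record.id)              # "subset01:exp01:model-x_vs_model-y:7"
print(record.winner_display)  # "model-x"
print(record.to_dict()["has_audit"])  # False until audit data is attached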
95
+ @dataclass
96
+ class SubsetInfo:
97
+ """Information about a subset."""
98
+
99
+ name: str
100
+ models: list[str]
101
+ experiments: list[str]
102
+ total_battles: int
103
+ state: Optional[ArenaState] = None
104
+ min_input_images: int = 1
105
+ max_input_images: int = 1
106
+ prompt_sources: list[str] = field(default_factory=list)
107
+
108
+
109
+ class ArenaDataLoader:
110
+ """
111
+ Data loader for arena visualization.
112
+
113
+ Manages loading and querying battle records across multiple subsets.
114
+ Supports preloading for better performance with large datasets.
115
+ """
116
+
117
+ def __init__(self, arena_dir: str, data_dir: str, preload: bool = True):
118
+ """
119
+ Initialize the data loader.
120
+
121
+ Args:
122
+ arena_dir: Path to arena directory containing subset folders
123
+ data_dir: Path to data directory containing parquet files
124
+ preload: If True, preload all data at initialization
125
+ """
126
+ self.arena_dir = arena_dir
127
+ self.data_dir = data_dir
128
+
129
+ # Cached data
130
+ self._subsets: Optional[list[str]] = None
131
+ self._subset_info_cache: dict[str, SubsetInfo] = {}
132
+ self._dataset_cache: dict[str, ParquetDataset] = {}
133
+ self._model_manager_cache: dict[str, GlobalModelOutputManager] = {}
134
+
135
+ # Battle records cache: (subset, exp_name) -> List[BattleRecord]
136
+ self._battle_cache: dict[tuple[str, str], list[BattleRecord]] = {}
137
+
138
+ # Index for faster lookups: (subset, exp_name) -> {model -> [record_indices]}
139
+ self._model_index: dict[tuple[str, str], dict[str, list[int]]] = {}
140
+
141
+ # Sample data cache: (subset, sample_index) -> SampleMetadata dict
142
+ self._sample_cache: dict[tuple[str, int], dict[str, Any]] = {}
143
+
144
+ # Sample to parquet file mapping: (subset, sample_index) -> parquet_file_path
145
+ self._sample_file_map: dict[tuple[str, int], str] = {}
146
+
147
+ # Input image count range per subset: subset -> (min_count, max_count)
148
+ self._image_count_range: dict[str, tuple[int, int]] = {}
149
+
150
+ # Prompt sources per subset: subset -> list of unique prompt_source values
151
+ self._prompt_sources: dict[str, list[str]] = {}
152
+
153
+ # Audit logs cache: (subset, exp_name, model_a, model_b, sample_index) -> audit data
154
+ self._audit_cache: dict[tuple[str, str, str, str, int], dict[str, Any]] = {}
155
+
156
+ # Cross-subset ELO cache: (sorted_subsets_tuple, exp_name, model_scope) -> result dict
157
+ self._cross_subset_elo_cache: dict[tuple[tuple[str, ...], str, str], dict[str, Any]] = {}
158
+
159
+ if preload:
160
+ self._preload_all()
161
+
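A minimal usage sketch for the loader defined above; the directory paths are placeholders, and preload=False skips the upfront parquet scan:

from genarena.visualize.data_loader import ArenaDataLoader

loader = ArenaDataLoader(arena_dir="./arena", data_dir="./data", preload=False)
for subset in loader.discover_subsets():
    info = loader.get_subset_info(subset)
    if info is None:
        continue
    print(subset, len(info.models), "models,", info.total_battles, "battles")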
162
+ def _preload_all(self) -> None:
163
+ """Preload all data at initialization for better performance."""
164
+ logger.info("Preloading arena data...")
165
+
166
+ subsets = self.discover_subsets()
167
+ logger.info(f"Found {len(subsets)} subsets: {subsets}")
168
+
169
+ for subset in subsets:
170
+ logger.info(f"Loading subset: {subset}")
171
+
172
+ # Preload parquet dataset
173
+ self._preload_dataset(subset)
174
+
175
+ # Load subset info (models, experiments)
176
+ info = self.get_subset_info(subset)
177
+ if info:
178
+ logger.info(f" - {len(info.models)} models, {len(info.experiments)} experiments")
179
+
180
+ # Preload battle logs for each experiment
181
+ for exp_name in info.experiments:
182
+ records = self._load_battle_logs(subset, exp_name)
183
+ logger.info(f" - Experiment '{exp_name}': {len(records)} battles")
184
+
185
+ logger.info("Preloading complete!")
186
+
187
+ def _preload_dataset(self, subset: str) -> None:
188
+ """
189
+ Preload sample text data (instruction, task_type) using pyarrow directly.
190
+
191
+ This is much faster than using HuggingFace datasets because we skip
192
+ decoding image columns. Images are loaded on-demand when requested.
193
+ """
194
+ import pyarrow.parquet as pq
195
+
196
+ subset_path = os.path.join(self.data_dir, subset)
197
+ if not os.path.isdir(subset_path):
198
+ return
199
+
200
+ # Find parquet files
201
+ parquet_files = sorted([
202
+ os.path.join(subset_path, f)
203
+ for f in os.listdir(subset_path)
204
+ if f.startswith("data-") and f.endswith(".parquet")
205
+ ])
206
+
207
+ if not parquet_files:
208
+ return
209
+
210
+ logger.info(f" - Loading metadata from parquet (fast mode)...")
211
+
212
+ # Read all metadata columns + input_images (only to count, not decode)
213
+ columns_to_read = ["index", "instruction", "task_type", "input_images", "prompt_source", "original_metadata"]
214
+
215
+ total_rows = 0
216
+ min_img_count = float('inf')
217
+ max_img_count = 0
218
+ prompt_sources_set: set[str] = set()
219
+
220
+ for pf in parquet_files:
221
+ try:
222
+ # Get available columns in this file
224
+ schema = pq.read_schema(pf)
225
+ available_columns = [c for c in columns_to_read if c in schema.names]
226
+
227
+ # Read the columns we need
228
+ table = pq.read_table(pf, columns=available_columns)
229
+
230
+ # Extract columns with defaults
231
+ def get_column(name, default=None):
232
+ if name in table.column_names:
233
+ return table.column(name).to_pylist()
234
+ return [default] * table.num_rows
235
+
236
+ indices = get_column("index", 0)
237
+ instructions = get_column("instruction", "")
238
+ task_types = get_column("task_type", "")
239
+ prompt_sources = get_column("prompt_source", None)
240
+ original_metadatas = get_column("original_metadata", None)
241
+
242
+ # Handle input_images separately for counting
243
+ has_input_images = "input_images" in table.column_names
244
+ input_images_col = table.column("input_images") if has_input_images else None
245
+
246
+ for i, idx in enumerate(indices):
247
+ idx = int(idx) if idx is not None else i
248
+
249
+ # Count input images without decoding
250
+ img_count = 0
251
+ if input_images_col is not None:
252
+ img_list = input_images_col[i].as_py()
253
+ img_count = len(img_list) if img_list else 0
254
+
255
+ min_img_count = min(min_img_count, img_count) if img_count > 0 else min_img_count
256
+ max_img_count = max(max_img_count, img_count)
257
+
258
+ # Track prompt sources
259
+ ps = prompt_sources[i] if prompt_sources[i] else None
260
+ if ps:
261
+ prompt_sources_set.add(str(ps))
262
+
263
+ # Build metadata dict
264
+ metadata = {
265
+ "instruction": str(instructions[i]) if instructions[i] else "",
266
+ "task_type": str(task_types[i]) if task_types[i] else "",
267
+ "input_image_count": img_count,
268
+ "prompt_source": ps,
269
+ "original_metadata": original_metadatas[i] if original_metadatas[i] else None,
270
+ }
271
+
272
+ self._sample_cache[(subset, idx)] = metadata
273
+ self._sample_file_map[(subset, idx)] = pf
274
+ total_rows += 1
275
+
276
+ except Exception as e:
277
+ logger.warning(f"Failed to read {pf}: {e}")
278
+ continue
279
+
280
+ # Store image count range for this subset
281
+ if total_rows > 0:
282
+ self._image_count_range[subset] = (
283
+ min_img_count if min_img_count != float('inf') else 1,
284
+ max_img_count if max_img_count > 0 else 1
285
+ )
286
+
287
+ # Store prompt sources for this subset
288
+ self._prompt_sources[subset] = sorted(prompt_sources_set)
289
+
290
+ logger.info(f" - Cached {total_rows} samples (input images: {self._image_count_range.get(subset, (1,1))}, sources: {len(prompt_sources_set)})")
291
+
292
+ def discover_subsets(self) -> list[str]:
293
+ """
294
+ Discover all available subsets.
295
+
296
+ A valid subset must exist in both arena_dir (with pk_logs) and data_dir.
297
+
298
+ Returns:
299
+ List of subset names
300
+ """
301
+ if self._subsets is not None:
302
+ return self._subsets
303
+
304
+ # Get subsets from data_dir (have parquet files)
305
+ data_subsets = set(discover_subsets(self.data_dir))
306
+
307
+ # Get subsets from arena_dir (have pk_logs)
308
+ arena_subsets = set()
309
+ if os.path.isdir(self.arena_dir):
310
+ for name in os.listdir(self.arena_dir):
311
+ subset_path = os.path.join(self.arena_dir, name)
312
+ pk_logs_path = os.path.join(subset_path, "pk_logs")
313
+ if os.path.isdir(pk_logs_path):
314
+ # Check if there are any experiment directories with battle logs
315
+ for exp_name in os.listdir(pk_logs_path):
316
+ exp_path = os.path.join(pk_logs_path, exp_name)
317
+ if os.path.isdir(exp_path):
318
+ # Check for .jsonl files
319
+ has_logs = any(
320
+ f.endswith(".jsonl")
321
+ for f in os.listdir(exp_path)
322
+ if os.path.isfile(os.path.join(exp_path, f))
323
+ )
324
+ if has_logs:
325
+ arena_subsets.add(name)
326
+ break
327
+
328
+ # Intersection: must have both data and battle logs
329
+ valid_subsets = sorted(data_subsets & arena_subsets)
330
+ self._subsets = valid_subsets
331
+ return valid_subsets
332
+
333
+ def get_subset_info(self, subset: str) -> Optional[SubsetInfo]:
334
+ """
335
+ Get information about a subset.
336
+
337
+ Args:
338
+ subset: Subset name
339
+
340
+ Returns:
341
+ SubsetInfo or None if subset doesn't exist
342
+ """
343
+ if subset in self._subset_info_cache:
344
+ return self._subset_info_cache[subset]
345
+
346
+ subset_path = os.path.join(self.arena_dir, subset)
347
+ if not os.path.isdir(subset_path):
348
+ return None
349
+
350
+ # Get models
351
+ model_manager = self._get_model_manager(subset)
352
+ models = model_manager.models if model_manager else []
353
+
354
+ # Get experiments
355
+ pk_logs_dir = os.path.join(subset_path, "pk_logs")
356
+ experiments = []
357
+ if os.path.isdir(pk_logs_dir):
358
+ for name in os.listdir(pk_logs_dir):
359
+ exp_path = os.path.join(pk_logs_dir, name)
360
+ if os.path.isdir(exp_path):
361
+ # Check for battle logs
362
+ has_logs = any(
363
+ f.endswith(".jsonl")
364
+ for f in os.listdir(exp_path)
365
+ if os.path.isfile(os.path.join(exp_path, f))
366
+ )
367
+ if has_logs:
368
+ experiments.append(name)
369
+ experiments.sort()
370
+
371
+ # Load state
372
+ state_path = os.path.join(subset_path, "arena", "state.json")
373
+ state = load_state(state_path)
374
+
375
+ # Get image count range
376
+ img_range = self._image_count_range.get(subset, (1, 1))
377
+
378
+ # Get prompt sources
379
+ prompt_sources = self._prompt_sources.get(subset, [])
380
+
381
+ info = SubsetInfo(
382
+ name=subset,
383
+ models=models,
384
+ experiments=experiments,
385
+ total_battles=state.total_battles,
386
+ state=state,
387
+ min_input_images=img_range[0],
388
+ max_input_images=img_range[1],
389
+ prompt_sources=prompt_sources,
390
+ )
391
+
392
+ self._subset_info_cache[subset] = info
393
+ return info
394
+
395
+ def _get_dataset(self, subset: str) -> Optional[ParquetDataset]:
396
+ """Get or create ParquetDataset for a subset."""
397
+ if subset not in self._dataset_cache:
398
+ try:
399
+ self._dataset_cache[subset] = ParquetDataset(self.data_dir, subset)
400
+ except Exception:
401
+ return None
402
+ return self._dataset_cache[subset]
403
+
404
+ def _get_model_manager(self, subset: str) -> Optional[GlobalModelOutputManager]:
405
+ """Get or create GlobalModelOutputManager for a subset."""
406
+ if subset not in self._model_manager_cache:
407
+ models_dir = os.path.join(self.arena_dir, subset, "models")
408
+ if os.path.isdir(models_dir):
409
+ self._model_manager_cache[subset] = GlobalModelOutputManager(models_dir)
410
+ else:
411
+ return None
412
+ return self._model_manager_cache[subset]
413
+
414
+ def _get_sample_data(self, subset: str, sample_index: int) -> dict[str, Any]:
415
+ """Get cached sample metadata."""
416
+ cache_key = (subset, sample_index)
417
+ if cache_key in self._sample_cache:
418
+ return self._sample_cache[cache_key]
419
+
420
+ # Fallback - return defaults
421
+ return {
422
+ "instruction": "",
423
+ "task_type": "",
424
+ "input_image_count": 1,
425
+ "prompt_source": None,
426
+ "original_metadata": None,
427
+ }
428
+
429
+ def _load_battle_logs(self, subset: str, exp_name: str) -> list[BattleRecord]:
430
+ """
431
+ Load battle records from log files.
432
+
433
+ Args:
434
+ subset: Subset name
435
+ exp_name: Experiment name
436
+
437
+ Returns:
438
+ List of BattleRecord objects
439
+ """
440
+ cache_key = (subset, exp_name)
441
+ if cache_key in self._battle_cache:
442
+ return self._battle_cache[cache_key]
443
+
444
+ records: list[BattleRecord] = []
445
+ exp_dir = os.path.join(self.arena_dir, subset, "pk_logs", exp_name)
446
+
447
+ if not os.path.isdir(exp_dir):
448
+ return records
449
+
450
+ # Load slim battle logs
451
+ for filename in os.listdir(exp_dir):
452
+ if not filename.endswith(".jsonl"):
453
+ continue
454
+
455
+ filepath = os.path.join(exp_dir, filename)
456
+ if not os.path.isfile(filepath):
457
+ continue
458
+
459
+ try:
460
+ with open(filepath, "r", encoding="utf-8") as f:
461
+ for line in f:
462
+ line = line.strip()
463
+ if not line:
464
+ continue
465
+ try:
466
+ data = json.loads(line)
467
+ sample_index = data.get("sample_index", -1)
468
+
469
+ # Get cached sample data
470
+ sample_meta = self._get_sample_data(subset, sample_index)
471
+
472
+ record = BattleRecord(
473
+ subset=subset,
474
+ exp_name=exp_name,
475
+ sample_index=sample_index,
476
+ model_a=data.get("model_a", ""),
477
+ model_b=data.get("model_b", ""),
478
+ final_winner=data.get("final_winner", "tie"),
479
+ is_consistent=data.get("is_consistent", False),
480
+ timestamp=data.get("timestamp", ""),
481
+ instruction=sample_meta.get("instruction", ""),
482
+ task_type=sample_meta.get("task_type", ""),
483
+ input_image_count=sample_meta.get("input_image_count", 1),
484
+ prompt_source=sample_meta.get("prompt_source"),
485
+ original_metadata=sample_meta.get("original_metadata"),
486
+ )
487
+ if record.model_a and record.model_b:
488
+ records.append(record)
489
+ except json.JSONDecodeError:
490
+ continue
491
+ except Exception:
492
+ continue
493
+
494
+ # Sort by sample_index
495
+ records.sort(key=lambda r: r.sample_index)
496
+
497
+ # Cache records
498
+ self._battle_cache[cache_key] = records
499
+
500
+ # Build model index for fast filtering
501
+ self._build_model_index(cache_key, records)
502
+
503
+ return records
504
+
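For reference, _load_battle_logs only relies on the keys shown below; a slim-log line equivalent to this dict (all values hypothetical) would be accepted:

example_line = {
    "sample_index": 12,
    "model_a": "model-x",
    "model_b": "model-y",
    "final_winner": "tie",   # a model name or "tie"
    "is_consistent": True,
    "timestamp": "",         # optional, defaults to ""
}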
505
+ def _build_model_index(
506
+ self, cache_key: tuple[str, str], records: list[BattleRecord]
507
+ ) -> None:
508
+ """Build index for fast model-based filtering."""
509
+ model_index: dict[str, list[int]] = {}
510
+
511
+ for i, record in enumerate(records):
512
+ for model in [record.model_a, record.model_b]:
513
+ if model not in model_index:
514
+ model_index[model] = []
515
+ model_index[model].append(i)
516
+
517
+ self._model_index[cache_key] = model_index
518
+
519
+ def _load_all_experiments_battles(self, subset: str) -> list[BattleRecord]:
520
+ """
521
+ Load battle records from all experiments for a subset.
522
+
523
+ Args:
524
+ subset: Subset name
525
+
526
+ Returns:
527
+ Combined list of BattleRecord objects from all experiments
528
+ """
529
+ info = self.get_subset_info(subset)
530
+ if not info:
531
+ return []
532
+
533
+ all_records: list[BattleRecord] = []
534
+ for exp_name in info.experiments:
535
+ records = self._load_battle_logs(subset, exp_name)
536
+ all_records.extend(records)
537
+
538
+ # Sort by sample_index for consistent ordering
539
+ all_records.sort(key=lambda r: (r.sample_index, r.exp_name, r.model_a, r.model_b))
540
+ return all_records
541
+
542
+ def _load_audit_log(
543
+ self, subset: str, exp_name: str, model_a: str, model_b: str, sample_index: int
544
+ ) -> Optional[dict[str, Any]]:
545
+ """
546
+ Load audit log for a specific battle.
547
+
548
+ Args:
549
+ subset: Subset name
550
+ exp_name: Experiment name
551
+ model_a: First model name
552
+ model_b: Second model name
553
+ sample_index: Sample index
554
+
555
+ Returns:
556
+ Audit data dict or None
557
+ """
558
+ cache_key = (subset, exp_name, model_a, model_b, sample_index)
559
+ if cache_key in self._audit_cache:
560
+ return self._audit_cache[cache_key]
561
+
562
+ # Determine filename (models are sorted alphabetically)
563
+ from genarena.utils import sanitize_name
564
+
565
+ first, second = sorted([model_a, model_b])
566
+ filename = f"{sanitize_name(first)}_vs_{sanitize_name(second)}.jsonl"
567
+ filepath = os.path.join(
568
+ self.arena_dir, subset, "pk_logs", exp_name, "raw_outputs", filename
569
+ )
570
+
571
+ if not os.path.isfile(filepath):
572
+ return None
573
+
574
+ try:
575
+ with open(filepath, "r", encoding="utf-8") as f:
576
+ for line in f:
577
+ line = line.strip()
578
+ if not line:
579
+ continue
580
+ try:
581
+ data = json.loads(line)
582
+ if data.get("sample_index") == sample_index:
583
+ self._audit_cache[cache_key] = data
584
+ return data
585
+ except json.JSONDecodeError:
586
+ continue
587
+ except Exception:
588
+ pass
589
+
590
+ return None
591
+
592
+ def get_battles(
593
+ self,
594
+ subset: str,
595
+ exp_name: str,
596
+ page: int = 1,
597
+ page_size: int = 20,
598
+ models: Optional[list[str]] = None,
599
+ result_filter: Optional[str] = None, # "wins", "losses", "ties"
600
+ consistency_filter: Optional[bool] = None,
601
+ min_images: Optional[int] = None,
602
+ max_images: Optional[int] = None,
603
+ prompt_source: Optional[str] = None,
604
+ ) -> tuple[list[BattleRecord], int]:
605
+ """
606
+ Get paginated battle records with filtering.
607
+
608
+ Args:
609
+ subset: Subset name
610
+ exp_name: Experiment name (use "__all__" for all experiments)
611
+ page: Page number (1-indexed)
612
+ page_size: Number of records per page
613
+ models: Filter by models (show battles involving ANY of these models)
614
+ result_filter: Filter by result relative to models ("wins", "losses", "ties")
615
+ consistency_filter: Filter by consistency (True/False/None for all)
+ min_images: Minimum input image count to include (inclusive)
+ max_images: Maximum input image count to include (inclusive)
+ prompt_source: Only include battles whose prompt_source matches exactly
616
+
617
+ Returns:
618
+ Tuple of (records, total_count)
619
+ """
620
+ # Handle "__all__" experiment - combine all experiments
621
+ if exp_name == "__all__":
622
+ all_records = self._load_all_experiments_battles(subset)
623
+ # For __all__, we don't use the model index optimization
624
+ cache_key = None
625
+ else:
626
+ all_records = self._load_battle_logs(subset, exp_name)
627
+ cache_key = (subset, exp_name)
628
+
629
+ # Apply filters using index for better performance
630
+ if models and cache_key and cache_key in self._model_index:
631
+ model_set = set(models)
632
+ model_index = self._model_index[cache_key]
633
+
634
+ if len(models) == 1:
635
+ # Single model: show battles involving this model
636
+ candidate_indices = set(model_index.get(models[0], []))
637
+ filtered = [all_records[i] for i in sorted(candidate_indices)]
638
+ else:
639
+ # 2+ models: show only battles BETWEEN these models (both participants must be in selected models)
640
+ # Find union of all records involving any selected model first
641
+ candidate_indices: set[int] = set()
642
+ for model in models:
643
+ if model in model_index:
644
+ candidate_indices.update(model_index[model])
645
+ # Then filter to keep only battles where BOTH models are in the selected set
646
+ filtered = [
647
+ all_records[i] for i in sorted(candidate_indices)
648
+ if all_records[i].model_a in model_set and all_records[i].model_b in model_set
649
+ ]
650
+
651
+ # Apply result filter
652
+ if result_filter:
653
+ if len(models) == 1:
654
+ # Single model: filter by that model's wins/losses/ties
655
+ model = models[0]
656
+ if result_filter == "wins":
657
+ filtered = [r for r in filtered if r.final_winner == model]
658
+ elif result_filter == "losses":
659
+ filtered = [
660
+ r
661
+ for r in filtered
662
+ if r.final_winner != "tie" and r.final_winner != model
663
+ ]
664
+ elif result_filter == "ties":
665
+ filtered = [r for r in filtered if r.final_winner == "tie"]
666
+ elif len(models) == 2:
667
+ # Two models: filter by winner (result_filter is the winning model name or "tie")
668
+ if result_filter == "ties":
669
+ filtered = [r for r in filtered if r.final_winner == "tie"]
670
+ elif result_filter in models:
671
+ # Filter by specific model winning
672
+ filtered = [r for r in filtered if r.final_winner == result_filter]
673
+ elif models:
674
+ # Fallback for __all__ mode or when index is not available
675
+ model_set = set(models)
676
+ if len(models) == 1:
677
+ model = models[0]
678
+ filtered = [r for r in all_records if model in r.models]
679
+ # Apply result filter
680
+ if result_filter:
681
+ if result_filter == "wins":
682
+ filtered = [r for r in filtered if r.final_winner == model]
683
+ elif result_filter == "losses":
684
+ filtered = [
685
+ r
686
+ for r in filtered
687
+ if r.final_winner != "tie" and r.final_winner != model
688
+ ]
689
+ elif result_filter == "ties":
690
+ filtered = [r for r in filtered if r.final_winner == "tie"]
691
+ else:
692
+ # 2+ models: show battles between these models
693
+ filtered = [
694
+ r for r in all_records
695
+ if r.model_a in model_set and r.model_b in model_set
696
+ ]
697
+ # Apply result filter
698
+ if result_filter:
699
+ if result_filter == "ties":
700
+ filtered = [r for r in filtered if r.final_winner == "tie"]
701
+ elif result_filter in models:
702
+ filtered = [r for r in filtered if r.final_winner == result_filter]
703
+ else:
704
+ filtered = all_records
705
+
706
+ # Apply consistency filter
707
+ if consistency_filter is not None:
708
+ filtered = [r for r in filtered if r.is_consistent == consistency_filter]
709
+
710
+ # Apply input image count filter
711
+ if min_images is not None or max_images is not None:
712
+ min_img = min_images if min_images is not None else 0
713
+ max_img = max_images if max_images is not None else float('inf')
714
+ filtered = [r for r in filtered if min_img <= r.input_image_count <= max_img]
715
+
716
+ # Apply prompt_source filter
717
+ if prompt_source:
718
+ filtered = [r for r in filtered if r.prompt_source == prompt_source]
719
+
720
+ total_count = len(filtered)
721
+
722
+ # Paginate
723
+ start = (page - 1) * page_size
724
+ end = start + page_size
725
+ page_records = filtered[start:end]
726
+
727
+ return page_records, total_count
728
+
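Two hypothetical calls illustrating the filter semantics above (subset, experiment, and model names are placeholders):

from genarena.visualize.data_loader import ArenaDataLoader

loader = ArenaDataLoader(arena_dir="./arena", data_dir="./data", preload=False)

# Wins of a single model, first page of 20 records:
records, total = loader.get_battles("subset01", "__all__", models=["model-x"], result_filter="wins")

# Only battles fought between two selected models, consistent judgments only:
records, total = loader.get_battles(
    "subset01", "exp01",
    models=["model-x", "model-y"],
    consistency_filter=True,
)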
729
+ def search_battles(
730
+ self,
731
+ subset: str,
732
+ exp_name: str,
733
+ query: str,
734
+ page: int = 1,
735
+ page_size: int = 20,
736
+ models: Optional[list[str]] = None,
737
+ consistency_filter: Optional[bool] = None,
738
+ search_fields: Optional[list[str]] = None,
739
+ ) -> tuple[list[BattleRecord], int]:
740
+ """
741
+ Search battle records by text query (full-text search).
742
+
743
+ Searches across instruction, task_type, prompt_source, and original_metadata.
744
+
745
+ Args:
746
+ subset: Subset name
747
+ exp_name: Experiment name (use "__all__" for all experiments)
748
+ query: Search query string (case-insensitive)
749
+ page: Page number (1-indexed)
750
+ page_size: Number of records per page
751
+ models: Optional filter by models
752
+ consistency_filter: Optional filter by consistency
753
+ search_fields: Fields to search in (default: all searchable fields)
754
+
755
+ Returns:
756
+ Tuple of (matching_records, total_count)
757
+ """
758
+ if not query or not query.strip():
759
+ # Empty query - return regular filtered results
760
+ return self.get_battles(
761
+ subset, exp_name, page, page_size,
762
+ models=models, consistency_filter=consistency_filter
763
+ )
764
+
765
+ # Normalize query for case-insensitive search
766
+ query_lower = query.lower().strip()
767
+ # Create regex pattern for more flexible matching
768
+ query_pattern = re.compile(re.escape(query_lower), re.IGNORECASE)
769
+
770
+ # Determine which fields to search
771
+ all_searchable_fields = ["instruction", "task_type", "prompt_source", "original_metadata"]
772
+ fields_to_search = search_fields if search_fields else all_searchable_fields
773
+
774
+ # Load all records
775
+ if exp_name == "__all__":
776
+ all_records = self._load_all_experiments_battles(subset)
777
+ else:
778
+ all_records = self._load_battle_logs(subset, exp_name)
779
+
780
+ # Apply model filter first (for efficiency)
781
+ if models:
782
+ model_set = set(models)
783
+ if len(models) == 1:
784
+ all_records = [r for r in all_records if models[0] in r.models]
785
+ else:
786
+ all_records = [
787
+ r for r in all_records
788
+ if r.model_a in model_set and r.model_b in model_set
789
+ ]
790
+
791
+ # Apply consistency filter
792
+ if consistency_filter is not None:
793
+ all_records = [r for r in all_records if r.is_consistent == consistency_filter]
794
+
795
+ # Search filter
796
+ def matches_query(record: BattleRecord) -> bool:
797
+ """Check if record matches the search query."""
798
+ for field_name in fields_to_search:
799
+ value = getattr(record, field_name, None)
800
+ if value is None:
801
+ continue
802
+
803
+ # Handle different field types
804
+ if field_name == "original_metadata" and isinstance(value, dict):
805
+ # Search in JSON string representation of metadata
806
+ metadata_str = json.dumps(value, ensure_ascii=False).lower()
807
+ if query_pattern.search(metadata_str):
808
+ return True
809
+ elif isinstance(value, str):
810
+ if query_pattern.search(value):
811
+ return True
812
+
813
+ return False
814
+
815
+ # Apply search filter
816
+ filtered = [r for r in all_records if matches_query(r)]
817
+
818
+ total_count = len(filtered)
819
+
820
+ # Paginate
821
+ start = (page - 1) * page_size
822
+ end = start + page_size
823
+ page_records = filtered[start:end]
824
+
825
+ return page_records, total_count
826
+
827
+ def search_prompts(
828
+ self,
829
+ subset: str,
830
+ exp_name: str,
831
+ query: str,
832
+ page: int = 1,
833
+ page_size: int = 10,
834
+ filter_models: Optional[list[str]] = None,
835
+ ) -> tuple[list[dict[str, Any]], int]:
836
+ """
837
+ Search prompts/samples by text query.
838
+
839
+ Args:
840
+ subset: Subset name
841
+ exp_name: Experiment name (use "__all__" for all experiments)
842
+ query: Search query string
843
+ page: Page number
844
+ page_size: Records per page
845
+ filter_models: Optional filter by models
846
+
847
+ Returns:
848
+ Tuple of (matching_prompts, total_count)
849
+ """
850
+ if not query or not query.strip():
851
+ # Empty query - return regular results
852
+ return self.get_prompts(subset, exp_name, page, page_size, filter_models=filter_models)
853
+
854
+ # Normalize query
855
+ query_lower = query.lower().strip()
856
+ query_pattern = re.compile(re.escape(query_lower), re.IGNORECASE)
857
+
858
+ # Load records and group by sample
859
+ if exp_name == "__all__":
860
+ all_records = self._load_all_experiments_battles(subset)
861
+ else:
862
+ all_records = self._load_battle_logs(subset, exp_name)
863
+
864
+ # Group by sample_index
865
+ sample_records: dict[int, list[BattleRecord]] = {}
866
+ for record in all_records:
867
+ if record.sample_index not in sample_records:
868
+ sample_records[record.sample_index] = []
869
+ sample_records[record.sample_index].append(record)
870
+
871
+ # Filter samples by query
872
+ matching_samples = []
873
+ for sample_index, records in sample_records.items():
874
+ if not records:
875
+ continue
876
+
877
+ first_record = records[0]
878
+
879
+ # Search in instruction, task_type, prompt_source, original_metadata
880
+ match_found = False
881
+
882
+ if first_record.instruction and query_pattern.search(first_record.instruction):
883
+ match_found = True
884
+ elif first_record.task_type and query_pattern.search(first_record.task_type):
885
+ match_found = True
886
+ elif first_record.prompt_source and query_pattern.search(first_record.prompt_source):
887
+ match_found = True
888
+ elif first_record.original_metadata:
889
+ metadata_str = json.dumps(first_record.original_metadata, ensure_ascii=False).lower()
890
+ if query_pattern.search(metadata_str):
891
+ match_found = True
892
+
893
+ if match_found:
894
+ matching_samples.append(sample_index)
895
+
896
+ # Sort and paginate
897
+ matching_samples.sort()
898
+ total_count = len(matching_samples)
899
+
900
+ start = (page - 1) * page_size
901
+ end = start + page_size
902
+ page_samples = matching_samples[start:end]
903
+
904
+ # Build result for each sample using get_sample_all_models
905
+ results = []
906
+ for sample_index in page_samples:
907
+ prompt_data = self.get_sample_all_models(subset, exp_name, sample_index, filter_models)
908
+ results.append(prompt_data)
909
+
910
+ return results, total_count
911
+
912
+ def get_battle_detail(
913
+ self, subset: str, exp_name: str, model_a: str, model_b: str, sample_index: int
914
+ ) -> Optional[BattleRecord]:
915
+ """
916
+ Get detailed battle record including VLM outputs.
917
+
918
+ Args:
919
+ subset: Subset name
920
+ exp_name: Experiment name (use "__all__" for all experiments)
921
+ model_a: First model name
922
+ model_b: Second model name
923
+ sample_index: Sample index
924
+
925
+ Returns:
926
+ BattleRecord with audit data, or None
927
+ """
928
+ # Find the battle record
929
+ if exp_name == "__all__":
930
+ all_records = self._load_all_experiments_battles(subset)
931
+ else:
932
+ all_records = self._load_battle_logs(subset, exp_name)
933
+
934
+ record = None
935
+ for r in all_records:
936
+ if (
937
+ r.sample_index == sample_index
938
+ and set([r.model_a, r.model_b]) == set([model_a, model_b])
939
+ ):
940
+ record = r
941
+ break
942
+
943
+ if not record:
944
+ return None
945
+
946
+ # Load audit data (use the record's actual exp_name for audit log lookup)
947
+ actual_exp_name = record.exp_name
948
+ audit = self._load_audit_log(
949
+ subset, actual_exp_name, record.model_a, record.model_b, sample_index
950
+ )
951
+ if audit:
952
+ record.original_call = audit.get("original_call")
953
+ record.swapped_call = audit.get("swapped_call")
954
+
955
+ return record
956
+
957
+ def get_image_path(
958
+ self, subset: str, model: str, sample_index: int
959
+ ) -> Optional[str]:
960
+ """
961
+ Get path to model output image.
962
+
963
+ Args:
964
+ subset: Subset name
965
+ model: Model name
966
+ sample_index: Sample index
967
+
968
+ Returns:
969
+ Image file path or None
970
+ """
971
+ model_manager = self._get_model_manager(subset)
972
+ if model_manager:
973
+ return model_manager.get_output_path(model, sample_index)
974
+ return None
975
+
976
+ def get_input_image(self, subset: str, sample_index: int) -> Optional[bytes]:
977
+ """
978
+ Get input image bytes for a sample.
979
+
980
+ Uses pyarrow to read directly from parquet for better performance.
981
+ Uses cached file mapping for fast lookup.
982
+
983
+ Args:
984
+ subset: Subset name
985
+ sample_index: Sample index
986
+
987
+ Returns:
988
+ Image bytes or None
989
+ """
991
+
992
+ # Use cached file mapping if available (fast path)
993
+ cache_key = (subset, sample_index)
994
+ if cache_key in self._sample_file_map:
995
+ pf = self._sample_file_map[cache_key]
996
+ result = self._read_image_from_parquet(pf, sample_index)
997
+ if result is not None:
998
+ return result
999
+
1000
+ # Fallback: search all parquet files (slow path)
1001
+ subset_path = os.path.join(self.data_dir, subset)
1002
+ if not os.path.isdir(subset_path):
1003
+ return None
1004
+
1005
+ parquet_files = sorted([
1006
+ os.path.join(subset_path, f)
1007
+ for f in os.listdir(subset_path)
1008
+ if f.startswith("data-") and f.endswith(".parquet")
1009
+ ])
1010
+
1011
+ for pf in parquet_files:
1012
+ result = self._read_image_from_parquet(pf, sample_index)
1013
+ if result is not None:
1014
+ return result
1015
+
1016
+ return None
1017
+
1018
+ def _read_image_from_parquet(self, parquet_file: str, sample_index: int) -> Optional[bytes]:
1019
+ """Read a single image from a parquet file."""
1020
+ import pyarrow.parquet as pq
1021
+
1022
+ try:
1023
+ table = pq.read_table(parquet_file, columns=["index", "input_images"])
1024
+ indices = table.column("index").to_pylist()
1025
+
1026
+ if sample_index not in indices:
1027
+ return None
1028
+
1029
+ row_idx = indices.index(sample_index)
1030
+ input_images = table.column("input_images")[row_idx].as_py()
1031
+
1032
+ if not input_images or len(input_images) == 0:
1033
+ return None
1034
+
1035
+ img_data = input_images[0]
1036
+
1037
+ # Handle different formats
1038
+ if isinstance(img_data, bytes):
1039
+ return img_data
1040
+ elif isinstance(img_data, dict):
1041
+ # HuggingFace Image format: {"bytes": ..., "path": ...}
1042
+ if "bytes" in img_data and img_data["bytes"]:
1043
+ return img_data["bytes"]
1044
+ elif "path" in img_data and img_data["path"]:
1045
+ path = img_data["path"]
1046
+ if os.path.isfile(path):
1047
+ with open(path, "rb") as f:
1048
+ return f.read()
1049
+
1050
+ except Exception as e:
1051
+ logger.debug(f"Error reading image from {parquet_file}: {e}")
1052
+
1053
+ return None
1054
+
1055
+ def get_input_image_count(self, subset: str, sample_index: int) -> int:
1056
+ """Get the number of input images for a sample."""
1057
+ import pyarrow.parquet as pq
1058
+
1059
+ cache_key = (subset, sample_index)
1060
+ if cache_key in self._sample_file_map:
1061
+ pf = self._sample_file_map[cache_key]
1062
+ try:
1063
+ table = pq.read_table(pf, columns=["index", "input_images"])
1064
+ indices = table.column("index").to_pylist()
1065
+ if sample_index in indices:
1066
+ row_idx = indices.index(sample_index)
1067
+ input_images = table.column("input_images")[row_idx].as_py()
1068
+ return len(input_images) if input_images else 0
1069
+ except Exception:
1070
+ pass
1071
+ return 1 # Default to 1
1072
+
1073
+ def get_input_image_by_idx(self, subset: str, sample_index: int, img_idx: int = 0) -> Optional[bytes]:
1074
+ """Get a specific input image by index."""
1075
+ import pyarrow.parquet as pq
1076
+
1077
+ cache_key = (subset, sample_index)
1078
+ if cache_key not in self._sample_file_map:
1079
+ return None
1080
+
1081
+ pf = self._sample_file_map[cache_key]
1082
+ try:
1083
+ table = pq.read_table(pf, columns=["index", "input_images"])
1084
+ indices = table.column("index").to_pylist()
1085
+
1086
+ if sample_index not in indices:
1087
+ return None
1088
+
1089
+ row_idx = indices.index(sample_index)
1090
+ input_images = table.column("input_images")[row_idx].as_py()
1091
+
1092
+ if not input_images or img_idx >= len(input_images):
1093
+ return None
1094
+
1095
+ img_data = input_images[img_idx]
1096
+
1097
+ if isinstance(img_data, bytes):
1098
+ return img_data
1099
+ elif isinstance(img_data, dict):
1100
+ if "bytes" in img_data and img_data["bytes"]:
1101
+ return img_data["bytes"]
1102
+ elif "path" in img_data and img_data["path"]:
1103
+ path = img_data["path"]
1104
+ if os.path.isfile(path):
1105
+ with open(path, "rb") as f:
1106
+ return f.read()
1107
+ except Exception as e:
1108
+ logger.debug(f"Error reading image: {e}")
1109
+
1110
+ return None
1111
+
1112
+ def get_head_to_head(
1113
+ self, subset: str, exp_name: str, model_a: str, model_b: str
1114
+ ) -> dict[str, Any]:
1115
+ """
1116
+ Get head-to-head statistics between two models.
1117
+
1118
+ Returns:
1119
+ Dict with wins_a, wins_b, ties, total, win_rate_a, win_rate_b
1120
+ """
1121
+ if exp_name == "__all__":
1122
+ all_records = self._load_all_experiments_battles(subset)
1123
+ # For __all__, we need to filter manually
1124
+ h2h_records = [
1125
+ r for r in all_records
1126
+ if set([r.model_a, r.model_b]) == set([model_a, model_b])
1127
+ ]
1128
+ else:
1129
+ all_records = self._load_battle_logs(subset, exp_name)
1130
+ cache_key = (subset, exp_name)
1131
+ model_index = self._model_index.get(cache_key, {})
1132
+
1133
+ # Find battles between these two models
1134
+ indices_a = set(model_index.get(model_a, []))
1135
+ indices_b = set(model_index.get(model_b, []))
1136
+ h2h_indices = indices_a & indices_b
1137
+ h2h_records = [all_records[idx] for idx in h2h_indices]
1138
+
1139
+ wins_a = 0
1140
+ wins_b = 0
1141
+ ties = 0
1142
+
1143
+ for record in h2h_records:
1144
+ if record.final_winner == model_a:
1145
+ wins_a += 1
1146
+ elif record.final_winner == model_b:
1147
+ wins_b += 1
1148
+ else:
1149
+ ties += 1
1150
+
1151
+ total = wins_a + wins_b + ties
1152
+
1153
+ return {
1154
+ "model_a": model_a,
1155
+ "model_b": model_b,
1156
+ "wins_a": wins_a,
1157
+ "wins_b": wins_b,
1158
+ "ties": ties,
1159
+ "total": total,
1160
+ "win_rate_a": wins_a / total if total > 0 else 0,
1161
+ "win_rate_b": wins_b / total if total > 0 else 0,
1162
+ "tie_rate": ties / total if total > 0 else 0,
1163
+ }
1164
+
1165
+ def get_win_rate_matrix(
1166
+ self,
1167
+ subset: str,
1168
+ exp_name: str = "__all__",
1169
+ filter_models: Optional[list[str]] = None,
1170
+ ) -> dict[str, Any]:
1171
+ """
1172
+ Compute win rate matrix for all model pairs.
1173
+
1174
+ Args:
1175
+ subset: Subset name
1176
+ exp_name: Experiment name (use "__all__" for all experiments)
1177
+ filter_models: Optional list of models to include
1178
+
1179
+ Returns:
1180
+ Dict with:
1181
+ - models: List of model names (sorted by ELO)
1182
+ - matrix: 2D array where matrix[i][j] = win rate of model i vs model j
1183
+ - counts: 2D array where counts[i][j] = number of battles between i and j
1184
+ - wins: 2D array where wins[i][j] = wins of model i vs model j
1185
+ """
1186
+ # Load all records
1187
+ if exp_name == "__all__":
1188
+ all_records = self._load_all_experiments_battles(subset)
1189
+ else:
1190
+ all_records = self._load_battle_logs(subset, exp_name)
1191
+
1192
+ # Determine models to include
1193
+ info = self.get_subset_info(subset)
1194
+ if filter_models:
1195
+ models = [m for m in filter_models if m in info.models]
1196
+ else:
1197
+ models = list(info.models)
1198
+
1199
+ # Get ELO leaderboard to sort models by ELO
1200
+ leaderboard = self.get_elo_leaderboard(subset, models)
1201
+ models = [entry["model"] for entry in leaderboard]
1202
+
1203
+ n = len(models)
1204
+ model_to_idx = {m: i for i, m in enumerate(models)}
1205
+
1206
+ # Initialize matrices
1207
+ wins_matrix = [[0] * n for _ in range(n)]
1208
+ counts_matrix = [[0] * n for _ in range(n)]
1209
+
1210
+ # Count wins for each pair
1211
+ model_set = set(models)
1212
+ for record in all_records:
1213
+ if record.model_a not in model_set or record.model_b not in model_set:
1214
+ continue
1215
+
1216
+ i = model_to_idx[record.model_a]
1217
+ j = model_to_idx[record.model_b]
1218
+
1219
+ # Count total battles (symmetric)
1220
+ counts_matrix[i][j] += 1
1221
+ counts_matrix[j][i] += 1
1222
+
1223
+ # Count wins
1224
+ if record.final_winner == record.model_a:
1225
+ wins_matrix[i][j] += 1
1226
+ elif record.final_winner == record.model_b:
1227
+ wins_matrix[j][i] += 1
1228
+ else:
1229
+ # Tie counts as 0.5 win for each
1230
+ wins_matrix[i][j] += 0.5
1231
+ wins_matrix[j][i] += 0.5
1232
+
1233
+ # Compute win rate matrix
1234
+ win_rate_matrix = [[0.0] * n for _ in range(n)]
1235
+ for i in range(n):
1236
+ for j in range(n):
1237
+ if counts_matrix[i][j] > 0:
1238
+ win_rate_matrix[i][j] = wins_matrix[i][j] / counts_matrix[i][j]
1239
+ elif i == j:
1240
+ win_rate_matrix[i][j] = 0.5 # Self vs self
1241
+
1242
+ return {
1243
+ "models": models,
1244
+ "matrix": win_rate_matrix,
1245
+ "counts": counts_matrix,
1246
+ "wins": wins_matrix,
1247
+ }
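A short sketch of reading the structure returned above, assuming `loader` is an ArenaDataLoader instance as in the earlier sketch; model names are hypothetical, and ties contribute 0.5 to each side's win count:

m = loader.get_win_rate_matrix("subset01", exp_name="__all__")
i = m["models"].index("model-x")
j = m["models"].index("model-y")
print(m["matrix"][i][j])  # win rate of model-x against model-y
print(m["counts"][i][j])  # number of battles between the two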
1248
+
1249
+ def get_elo_by_source(
1250
+ self,
1251
+ subset: str,
1252
+ exp_name: str = "__all__",
1253
+ ) -> dict[str, Any]:
1254
+ """
1255
+ Compute ELO rankings grouped by prompt_source.
1256
+
1257
+ Args:
1258
+ subset: Subset name
1259
+ exp_name: Experiment name
1260
+
1261
+ Returns:
1262
+ Dict with:
1263
+ - sources: List of source names
1264
+ - leaderboards: Dict mapping source -> list of model ELO entries
1265
+ - sample_counts: Dict mapping source -> number of samples
1266
+ - battle_counts: Dict mapping source -> number of battles
1267
+ """
1268
+ from genarena.bt_elo import compute_bt_elo_ratings
1269
+
1270
+ # Load all records
1271
+ if exp_name == "__all__":
1272
+ all_records = self._load_all_experiments_battles(subset)
1273
+ else:
1274
+ all_records = self._load_battle_logs(subset, exp_name)
1275
+
1276
+ # Group battles by prompt_source
1277
+ battles_by_source: dict[str, list[tuple[str, str, str]]] = {}
1278
+ sample_counts: dict[str, set[int]] = {}
1279
+
1280
+ for record in all_records:
1281
+ source = record.prompt_source or "unknown"
1282
+ if source not in battles_by_source:
1283
+ battles_by_source[source] = []
1284
+ sample_counts[source] = set()
1285
+
1286
+ # Convert winner to bt_elo format
1287
+ if record.final_winner == record.model_a:
1288
+ winner = "model_a"
1289
+ elif record.final_winner == record.model_b:
1290
+ winner = "model_b"
1291
+ else:
1292
+ winner = "tie"
1293
+
1294
+ battles_by_source[source].append((record.model_a, record.model_b, winner))
1295
+ sample_counts[source].add(record.sample_index)
1296
+
1297
+ # Compute ELO for each source
1298
+ leaderboards: dict[str, list[dict[str, Any]]] = {}
1299
+ battle_counts: dict[str, int] = {}
1300
+
1301
+ for source, battles in battles_by_source.items():
1302
+ if not battles:
1303
+ continue
1304
+
1305
+ battle_counts[source] = len(battles)
1306
+
1307
+ try:
1308
+ ratings = compute_bt_elo_ratings(battles)
1309
+
1310
+ # Build leaderboard
1311
+ entries = []
1312
+ for model, elo in ratings.items():
1313
+ # Count wins/losses/ties for this model in this source
1314
+ wins = losses = ties = 0
1315
+ for ma, mb, w in battles:
1316
+ if model == ma:
1317
+ if w == "model_a":
1318
+ wins += 1
1319
+ elif w == "model_b":
1320
+ losses += 1
1321
+ else:
1322
+ ties += 1
1323
+ elif model == mb:
1324
+ if w == "model_b":
1325
+ wins += 1
1326
+ elif w == "model_a":
1327
+ losses += 1
1328
+ else:
1329
+ ties += 1
1330
+
1331
+ total = wins + losses + ties
1332
+ entries.append({
1333
+ "model": model,
1334
+ "elo": round(elo, 1),
1335
+ "wins": wins,
1336
+ "losses": losses,
1337
+ "ties": ties,
1338
+ "total": total,
1339
+ "win_rate": (wins + 0.5 * ties) / total if total > 0 else 0,
1340
+ })
1341
+
1342
+ # Sort by ELO descending
1343
+ entries.sort(key=lambda x: -x["elo"])
1344
+ leaderboards[source] = entries
1345
+
1346
+ except Exception as e:
1347
+ logger.warning(f"Failed to compute ELO for source {source}: {e}")
1348
+ continue
1349
+
1350
+ # Sort sources by battle count
1351
+ sources = sorted(battle_counts.keys(), key=lambda s: -battle_counts[s])
1352
+
1353
+ return {
1354
+ "sources": sources,
1355
+ "leaderboards": leaderboards,
1356
+ "sample_counts": {s: len(sample_counts[s]) for s in sources},
1357
+ "battle_counts": battle_counts,
1358
+ }
1359
+
1360
+ def _load_elo_snapshot(self, snapshot_path: str) -> Optional[dict[str, Any]]:
1361
+ """
1362
+ Load ELO snapshot from a JSON file.
1363
+
1364
+ Args:
1365
+ snapshot_path: Path to elo_snapshot.json
1366
+
1367
+ Returns:
1368
+ Dict with elo ratings and metadata, or None if not found
1369
+ """
1370
+ if not os.path.isfile(snapshot_path):
1371
+ return None
1372
+
1373
+ try:
1374
+ with open(snapshot_path, "r", encoding="utf-8") as f:
1375
+ data = json.load(f)
1376
+
1377
+ if not isinstance(data, dict):
1378
+ return None
1379
+
1380
+ # Extract ELO ratings (support both {"elo": {...}} and direct {model: elo} format)
1381
+ elo_data = data.get("elo") if isinstance(data.get("elo"), dict) else data
1382
+ if not isinstance(elo_data, dict):
1383
+ return None
1384
+
1385
+ return {
1386
+ "elo": {str(k): float(v) for k, v in elo_data.items()},
1387
+ "battle_count": data.get("battle_count", 0),
1388
+ "model_count": data.get("model_count", len(elo_data)),
1389
+ "exp_name": data.get("exp_name", ""),
1390
+ }
1391
+ except Exception as e:
1392
+ logger.debug(f"Failed to load ELO snapshot from {snapshot_path}: {e}")
1393
+ return None
1394
+
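Both snapshot layouts that _load_elo_snapshot accepts, with invented ratings; the flat form simply maps model names to ELO values:

snapshot_nested = {
    "elo": {"model-x": 1032.5, "model-y": 967.5},
    "battle_count": 120,
    "model_count": 2,
    "exp_name": "exp01",
}
snapshot_flat = {"model-x": 1032.5, "model-y": 967.5}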
1395
+ def get_elo_history(
1396
+ self,
1397
+ subset: str,
1398
+ exp_name: str = "__all__",
1399
+ granularity: str = "experiment",
1400
+ filter_models: Optional[list[str]] = None,
1401
+ max_points: int = 50,
1402
+ ) -> dict[str, Any]:
1403
+ """
1404
+ Get ELO history over experiments by reading pre-computed elo_snapshot.json files.
1405
+
1406
+ Args:
1407
+ subset: Subset name
1408
+ exp_name: Experiment name (currently unused; history is always read per experiment from its elo_snapshot.json)
1409
+ granularity: Grouping method ("experiment" reads from snapshots; time-based not supported)
1410
+ filter_models: Optional models to track
1411
+ max_points: Maximum number of time points to return
1412
+
1413
+ Returns:
1414
+ Dict with:
1415
+ - timestamps: List of experiment names
1416
+ - models: Dict mapping model -> list of ELO values
1417
+ - battle_counts: List of cumulative battle counts
1418
+ """
1419
+ # Get subset info for experiment order
1420
+ info = self.get_subset_info(subset)
1421
+ if not info:
1422
+ return {"timestamps": [], "models": {}, "battle_counts": []}
1423
+
1424
+ # Only support experiment-level granularity (reading from snapshots)
1425
+ # Time-based granularity would require real-time computation which we want to avoid
1426
+ if granularity != "experiment":
1427
+ logger.warning(
1428
+ f"Time-based granularity '{granularity}' is not supported for ELO history. "
1429
+ f"Falling back to 'experiment' granularity."
1430
+ )
1431
+
1432
+ # Get ordered list of experiments
1433
+ experiments = info.experiments
1434
+ if not experiments:
1435
+ return {"timestamps": [], "models": {}, "battle_counts": []}
1436
+
1437
+ # If too many experiments, sample them
1438
+ if len(experiments) > max_points:
1439
+ step = len(experiments) // max_points
1440
+ sampled = [experiments[i] for i in range(0, len(experiments), step)]
1441
+ if sampled[-1] != experiments[-1]:
1442
+ sampled.append(experiments[-1])
1443
+ experiments = sampled
1444
+
1445
+ # Load ELO snapshots for each experiment
1446
+ timestamps: list[str] = []
1447
+ model_elos: dict[str, list[Optional[float]]] = {}
1448
+ battle_counts: list[int] = []
1449
+
1450
+ pk_logs_dir = os.path.join(self.arena_dir, subset, "pk_logs")
1451
+
1452
+ for exp in experiments:
1453
+ snapshot_path = os.path.join(pk_logs_dir, exp, "elo_snapshot.json")
1454
+ snapshot = self._load_elo_snapshot(snapshot_path)
1455
+
1456
+ if snapshot is None:
1457
+ # Skip experiments without snapshots
1458
+ continue
1459
+
1460
+ elo_ratings = snapshot["elo"]
1461
+ battle_count = snapshot["battle_count"]
1462
+
1463
+ timestamps.append(exp)
1464
+ battle_counts.append(battle_count)
1465
+
1466
+ # Update model ELOs
1467
+ all_models_so_far = set(model_elos.keys()) | set(elo_ratings.keys())
1468
+ for model in all_models_so_far:
1469
+ if model not in model_elos:
1470
+ # New model: fill with None for previous timestamps
1471
+ model_elos[model] = [None] * (len(timestamps) - 1)
1472
+ model_elos[model].append(elo_ratings.get(model))
1473
+
1474
+ # Ensure all models have the same length
1475
+ for model in model_elos:
1476
+ if len(model_elos[model]) < len(timestamps):
1477
+ model_elos[model].append(None)
1478
+
1479
+ # Filter to requested models if specified
1480
+ if filter_models:
1481
+ filter_set = set(filter_models)
1482
+ model_elos = {m: v for m, v in model_elos.items() if m in filter_set}
1483
+
1484
+ return {
1485
+ "timestamps": timestamps,
1486
+ "models": model_elos,
1487
+ "battle_counts": battle_counts,
1488
+ }
1489
+
1490
+ def get_cross_subset_info(
1491
+ self,
1492
+ subsets: list[str],
1493
+ ) -> dict[str, Any]:
1494
+ """
1495
+ Get information about models across multiple subsets.
1496
+
1497
+ Args:
1498
+ subsets: List of subset names
1499
+
1500
+ Returns:
1501
+ Dict with:
1502
+ - common_models: Models present in all subsets
1503
+ - all_models: Models present in any subset
1504
+ - per_subset_models: Dict mapping subset -> list of models
1505
+ - per_subset_battles: Dict mapping subset -> battle count
1506
+ """
1507
+ per_subset_models: dict[str, set[str]] = {}
1508
+ per_subset_battles: dict[str, int] = {}
1509
+
1510
+ for subset in subsets:
1511
+ info = self.get_subset_info(subset)
1512
+ if info:
1513
+ per_subset_models[subset] = set(info.models)
1514
+ per_subset_battles[subset] = info.total_battles
1515
+
1516
+ if not per_subset_models:
1517
+ return {
1518
+ "common_models": [],
1519
+ "all_models": [],
1520
+ "per_subset_models": {},
1521
+ "per_subset_battles": {},
1522
+ }
1523
+
1524
+ # Compute intersection and union
1525
+ all_model_sets = list(per_subset_models.values())
1526
+ common_models = set.intersection(*all_model_sets) if all_model_sets else set()
1527
+ all_models = set.union(*all_model_sets) if all_model_sets else set()
1528
+
1529
+ return {
1530
+ "common_models": sorted(common_models),
1531
+ "all_models": sorted(all_models),
1532
+ "per_subset_models": {s: sorted(m) for s, m in per_subset_models.items()},
1533
+ "per_subset_battles": per_subset_battles,
1534
+ "total_battles": sum(per_subset_battles.values()),
1535
+ }
1536
+
1537
+ def get_cross_subset_elo(
1538
+ self,
1539
+ subsets: list[str],
1540
+ exp_name: str = "__all__",
1541
+ model_scope: str = "all",
1542
+ ) -> dict[str, Any]:
1543
+ """
1544
+ Compute ELO rankings across multiple subsets.
1545
+
1546
+ Args:
1547
+ subsets: List of subset names
1548
+ exp_name: Experiment name (use "__all__" for all)
1549
+ model_scope: "common" = only models in all subsets, "all" = all models
1550
+
1551
+ Returns:
1552
+ Dict with merged leaderboard and per-subset comparison
1553
+ """
1554
+ # Check cache first
1555
+ cache_key = (tuple(sorted(subsets)), exp_name, model_scope)
1556
+ if cache_key in self._cross_subset_elo_cache:
1557
+ return self._cross_subset_elo_cache[cache_key]
1558
+
1559
+ from genarena.bt_elo import compute_bt_elo_ratings
1560
+
1561
+ # Get cross-subset info
1562
+ cross_info = self.get_cross_subset_info(subsets)
1563
+
1564
+ # Determine models to include
1565
+ if model_scope == "common":
1566
+ included_models = set(cross_info["common_models"])
1567
+ else:
1568
+ included_models = set(cross_info["all_models"])
1569
+
1570
+ if not included_models:
1571
+ return {
1572
+ "subsets": subsets,
1573
+ "model_scope": model_scope,
1574
+ "common_models": cross_info["common_models"],
1575
+ "all_models": cross_info["all_models"],
1576
+ "total_battles": 0,
1577
+ "leaderboard": [],
1578
+ "per_subset_elo": {},
1579
+ }
1580
+
1581
+ # Collect all battles
1582
+ all_battles = []
1583
+ model_presence: dict[str, set[str]] = {} # model -> set of subsets it's in
1584
+
1585
+ for subset in subsets:
1586
+ if exp_name == "__all__":
1587
+ records = self._load_all_experiments_battles(subset)
1588
+ else:
1589
+ records = self._load_battle_logs(subset, exp_name)
1590
+
1591
+ for record in records:
1592
+ # Skip if either model is not in included set
1593
+ if model_scope == "common":
1594
+ if record.model_a not in included_models or record.model_b not in included_models:
1595
+ continue
1596
+
1597
+ # Convert to bt_elo format
1598
+ if record.final_winner == record.model_a:
1599
+ winner = "model_a"
1600
+ elif record.final_winner == record.model_b:
1601
+ winner = "model_b"
1602
+ else:
1603
+ winner = "tie"
1604
+
1605
+ all_battles.append((record.model_a, record.model_b, winner))
1606
+
1607
+ # Track model presence
1608
+ for m in [record.model_a, record.model_b]:
1609
+ if m not in model_presence:
1610
+ model_presence[m] = set()
1611
+ model_presence[m].add(subset)
1612
+
1613
+ if not all_battles:
1614
+ return {
1615
+ "subsets": subsets,
1616
+ "model_scope": model_scope,
1617
+ "common_models": cross_info["common_models"],
1618
+ "all_models": cross_info["all_models"],
1619
+ "total_battles": 0,
1620
+ "leaderboard": [],
1621
+ "per_subset_elo": {},
1622
+ }
1623
+
1624
+ # Compute merged ELO
1625
+ try:
1626
+ ratings = compute_bt_elo_ratings(all_battles)
1627
+ except Exception as e:
1628
+ logger.error(f"Failed to compute cross-subset ELO: {e}")
1629
+ return {
1630
+ "subsets": subsets,
1631
+ "model_scope": model_scope,
1632
+ "error": str(e),
1633
+ "total_battles": len(all_battles),
1634
+ "leaderboard": [],
1635
+ }
1636
+
1637
+ # Count wins/losses/ties per model
1638
+ model_stats: dict[str, dict[str, int]] = {}
1639
+ for ma, mb, winner in all_battles:
1640
+ for m in [ma, mb]:
1641
+ if m not in model_stats:
1642
+ model_stats[m] = {"wins": 0, "losses": 0, "ties": 0}
1643
+
1644
+ if winner == "model_a":
1645
+ model_stats[ma]["wins"] += 1
1646
+ model_stats[mb]["losses"] += 1
1647
+ elif winner == "model_b":
1648
+ model_stats[mb]["wins"] += 1
1649
+ model_stats[ma]["losses"] += 1
1650
+ else:
1651
+ model_stats[ma]["ties"] += 1
1652
+ model_stats[mb]["ties"] += 1
1653
+
1654
+ # Build leaderboard
1655
+ leaderboard = []
1656
+ for model, elo in ratings.items():
1657
+ stats = model_stats.get(model, {"wins": 0, "losses": 0, "ties": 0})
1658
+ total = stats["wins"] + stats["losses"] + stats["ties"]
1659
+ leaderboard.append({
1660
+ "model": model,
1661
+ "elo": round(elo, 1),
1662
+ "wins": stats["wins"],
1663
+ "losses": stats["losses"],
1664
+ "ties": stats["ties"],
1665
+ "total": total,
1666
+ "win_rate": (stats["wins"] + 0.5 * stats["ties"]) / total if total > 0 else 0,
1667
+ "subset_presence": sorted(model_presence.get(model, set())),
1668
+ })
1669
+
1670
+ leaderboard.sort(key=lambda x: -x["elo"])
1671
+
1672
+ # Get per-subset ELO for comparison
1673
+ per_subset_elo: dict[str, dict[str, float]] = {}
1674
+ for subset in subsets:
1675
+ subset_lb = self.get_elo_leaderboard(subset)
1676
+ per_subset_elo[subset] = {entry["model"]: entry["elo"] for entry in subset_lb}
1677
+
1678
+ result = {
1679
+ "subsets": subsets,
1680
+ "model_scope": model_scope,
1681
+ "common_models": cross_info["common_models"],
1682
+ "all_models": cross_info["all_models"],
1683
+ "total_battles": len(all_battles),
1684
+ "leaderboard": leaderboard,
1685
+ "per_subset_elo": per_subset_elo,
1686
+ }
1687
+
1688
+ # Cache the result
1689
+ self._cross_subset_elo_cache[cache_key] = result
1690
+ return result
1691
+
1692
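
# Sketch: merging battles from several subsets into one Bradley-Terry Elo table,
# reusing the `loader` from the first sketch; subset names are placeholders.
# model_scope="common" restricts the fit to models present in every subset.
merged = loader.get_cross_subset_elo(["edit", "t2i"], exp_name="__all__", model_scope="common")
for entry in merged["leaderboard"]:
    print(entry["model"], entry["elo"], entry["subset_presence"])
# merged["per_subset_elo"] keeps each subset's own ratings for side-by-side comparison.
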
+ def get_stats(self, subset: str, exp_name: Optional[str] = None) -> dict[str, Any]:
1693
+ """
1694
+ Get statistics for a subset.
1695
+
1696
+ Args:
1697
+ subset: Subset name
1698
+ exp_name: Optional experiment name (if None, uses overall state; "__all__" for all experiments)
1699
+
1700
+ Returns:
1701
+ Statistics dictionary
1702
+ """
1703
+ info = self.get_subset_info(subset)
1704
+ if not info:
1705
+ return {}
1706
+
1707
+ if exp_name == "__all__":
1708
+ # Combine stats from all experiments
1709
+ records = self._load_all_experiments_battles(subset)
1710
+ total_battles = len(records)
1711
+ consistent = sum(1 for r in records if r.is_consistent)
1712
+ ties = sum(1 for r in records if r.final_winner == "tie")
1713
+ elif exp_name:
1714
+ records = self._load_battle_logs(subset, exp_name)
1715
+ total_battles = len(records)
1716
+ consistent = sum(1 for r in records if r.is_consistent)
1717
+ ties = sum(1 for r in records if r.final_winner == "tie")
1718
+ else:
1719
+ total_battles = info.total_battles
1720
+ consistent = 0
1721
+ ties = 0
1722
+
1723
+ return {
1724
+ "subset": subset,
1725
+ "models": info.models,
1726
+ "experiments": info.experiments,
1727
+ "total_battles": total_battles,
1728
+ "consistent_battles": consistent,
1729
+ "tie_battles": ties,
1730
+ "consistency_rate": consistent / total_battles if total_battles > 0 else 0,
1731
+ }
1732
+
1733
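
# Sketch: subset-level statistics, reusing the `loader` from the first sketch.
# With exp_name=None the totals come from the aggregate state and the
# consistency fields stay at zero; "__all__" recounts them from battle logs.
overall = loader.get_stats("edit")
recounted = loader.get_stats("edit", exp_name="__all__")
print(overall["total_battles"], recounted["consistency_rate"])
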
+ def get_model_win_stats(
1734
+ self, subset: str, exp_name: str, sample_index: int,
1735
+ filter_models: Optional[list[str]] = None
1736
+ ) -> dict[str, dict[str, Any]]:
1737
+ """
1738
+ Get win/loss statistics for all models on a specific sample.
1739
+
1740
+ Args:
1741
+ subset: Subset name
1742
+ exp_name: Experiment name (use "__all__" for all experiments)
1743
+ sample_index: Sample index
1744
+ filter_models: Optional list of models to filter (only count battles between these models)
1745
+
1746
+ Returns:
1747
+ Dict mapping model name to stats (wins, losses, ties, total, win_rate)
1748
+ """
1749
+ if exp_name == "__all__":
1750
+ all_records = self._load_all_experiments_battles(subset)
1751
+ else:
1752
+ all_records = self._load_battle_logs(subset, exp_name)
1753
+
1754
+ # Filter records for this sample
1755
+ sample_records = [r for r in all_records if r.sample_index == sample_index]
1756
+
1757
+ # If filter_models is specified, only count battles between those models
1758
+ if filter_models:
1759
+ filter_set = set(filter_models)
1760
+ sample_records = [
1761
+ r for r in sample_records
1762
+ if r.model_a in filter_set and r.model_b in filter_set
1763
+ ]
1764
+
1765
+ # Collect stats per model
1766
+ model_stats: dict[str, dict[str, int]] = {}
1767
+
1768
+ for record in sample_records:
1769
+ for model in [record.model_a, record.model_b]:
1770
+ if model not in model_stats:
1771
+ model_stats[model] = {"wins": 0, "losses": 0, "ties": 0}
1772
+
1773
+ if record.final_winner == "tie":
1774
+ model_stats[record.model_a]["ties"] += 1
1775
+ model_stats[record.model_b]["ties"] += 1
1776
+ elif record.final_winner == record.model_a:
1777
+ model_stats[record.model_a]["wins"] += 1
1778
+ model_stats[record.model_b]["losses"] += 1
1779
+ elif record.final_winner == record.model_b:
1780
+ model_stats[record.model_b]["wins"] += 1
1781
+ model_stats[record.model_a]["losses"] += 1
1782
+
1783
+ # Calculate win rate and total
1784
+ result: dict[str, dict[str, Any]] = {}
1785
+ for model, stats in model_stats.items():
1786
+ total = stats["wins"] + stats["losses"] + stats["ties"]
1787
+ win_rate = stats["wins"] / total if total > 0 else 0
1788
+ result[model] = {
1789
+ "wins": stats["wins"],
1790
+ "losses": stats["losses"],
1791
+ "ties": stats["ties"],
1792
+ "total": total,
1793
+ "win_rate": win_rate,
1794
+ }
1795
+
1796
+ return result
1797
+
1798
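
# Sketch: per-model tallies on a single sample, reusing the `loader` from the
# first sketch; the sample index and model names are placeholders. With
# filter_models set, only battles where both sides are in that list are counted.
win_stats = loader.get_model_win_stats("edit", "__all__", 0, filter_models=["model_x", "model_y"])
for name, s in win_stats.items():
    print(f"{name}: {s['wins']}W/{s['losses']}L/{s['ties']}T ({s['win_rate']:.0%})")
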
+ def get_sample_all_models(
1799
+ self, subset: str, exp_name: str, sample_index: int,
1800
+ filter_models: Optional[list[str]] = None,
1801
+ stats_scope: str = "filtered"
1802
+ ) -> dict[str, Any]:
1803
+ """
1804
+ Get all model outputs for a specific sample, sorted by win rate.
1805
+
1806
+ Args:
1807
+ subset: Subset name
1808
+ exp_name: Experiment name (use "__all__" for all experiments)
1809
+ sample_index: Sample index
1810
+ filter_models: Optional list of models to filter (show only these models)
1811
+ stats_scope: 'filtered' = only count battles between filtered models,
1812
+ 'all' = count all battles (but show only filtered models)
1813
+
1814
+ Returns:
1815
+ Dict with sample info and all model outputs sorted by win rate
1816
+ """
1817
+ # Get sample metadata
1818
+ sample_meta = self._get_sample_data(subset, sample_index)
1819
+
1820
+ # Determine which models to use for stats calculation
1821
+ # If stats_scope is 'all', don't filter battles by models
1822
+ stats_filter = filter_models if stats_scope == "filtered" else None
1823
+ model_stats = self.get_model_win_stats(subset, exp_name, sample_index, stats_filter)
1824
+
1825
+ # Get all models that have outputs
1826
+ model_manager = self._get_model_manager(subset)
1827
+ available_models = []
1828
+
1829
+ if model_manager:
1830
+ # Determine which models to include
1831
+ models_to_check = model_manager.models
1832
+ if filter_models:
1833
+ filter_set = set(filter_models)
1834
+ models_to_check = [m for m in models_to_check if m in filter_set]
1835
+
1836
+ for model in models_to_check:
1837
+ output_path = model_manager.get_output_path(model, sample_index)
1838
+ if output_path and os.path.isfile(output_path):
1839
+ stats = model_stats.get(model, {
1840
+ "wins": 0, "losses": 0, "ties": 0, "total": 0, "win_rate": 0
1841
+ })
1842
+ available_models.append({
1843
+ "model": model,
1844
+ "wins": stats["wins"],
1845
+ "losses": stats["losses"],
1846
+ "ties": stats["ties"],
1847
+ "total": stats["total"],
1848
+ "win_rate": stats["win_rate"],
1849
+ })
1850
+
1851
+ # Sort by win rate (descending), then by wins (descending), then by model name
1852
+ available_models.sort(key=lambda x: (-x["win_rate"], -x["wins"], x["model"]))
1853
+
1854
+ return {
1855
+ "subset": subset,
1856
+ "exp_name": exp_name,
1857
+ "sample_index": sample_index,
1858
+ "instruction": sample_meta.get("instruction", ""),
1859
+ "task_type": sample_meta.get("task_type", ""),
1860
+ "input_image_count": sample_meta.get("input_image_count", 1),
1861
+ "prompt_source": sample_meta.get("prompt_source"),
1862
+ "original_metadata": sample_meta.get("original_metadata"),
1863
+ "models": available_models,
1864
+ }
1865
+
1866
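
# Sketch: all outputs for one sample ranked by win rate, reusing the `loader`
# from the first sketch; arguments are placeholders. With stats_scope="all" the
# win rates count every battle on the sample while the listing still shows only
# the models in filter_models; "filtered" restricts both to the same pool.
sample = loader.get_sample_all_models(
    "edit", "__all__", 0,
    filter_models=["model_x", "model_y"],
    stats_scope="all",
)
print(sample["instruction"])
for m in sample["models"]:
    print(m["model"], m["win_rate"])
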
+ def get_model_battles_for_sample(
1867
+ self,
1868
+ subset: str,
1869
+ exp_name: str,
1870
+ sample_index: int,
1871
+ model: str,
1872
+ opponent_models: Optional[list[str]] = None,
1873
+ ) -> dict[str, Any]:
1874
+ """
1875
+ Get all battle records for a specific model on a specific sample.
1876
+
1877
+ Args:
1878
+ subset: Subset name
1879
+ exp_name: Experiment name (use "__all__" for all experiments)
1880
+ sample_index: Sample index
1881
+ model: The model to get battles for
1882
+ opponent_models: Optional list of opponent models to filter by
1883
+
1884
+ Returns:
1885
+ Dict with model info and list of battle records
1886
+ """
1887
+ if exp_name == "__all__":
1888
+ all_records = self._load_all_experiments_battles(subset)
1889
+ else:
1890
+ all_records = self._load_battle_logs(subset, exp_name)
1891
+
1892
+ # Filter records for this sample and involving this model
1893
+ model_battles = []
1894
+ all_opponents = set()
1895
+
1896
+ for record in all_records:
1897
+ if record.sample_index != sample_index:
1898
+ continue
1899
+ if model not in [record.model_a, record.model_b]:
1900
+ continue
1901
+
1902
+ # Determine opponent
1903
+ opponent = record.model_b if record.model_a == model else record.model_a
1904
+ all_opponents.add(opponent)
1905
+
1906
+ # Apply opponent filter if specified
1907
+ if opponent_models and opponent not in opponent_models:
1908
+ continue
1909
+
1910
+ # Determine result for this model
1911
+ if record.final_winner == "tie":
1912
+ result = "tie"
1913
+ elif record.final_winner == model:
1914
+ result = "win"
1915
+ else:
1916
+ result = "loss"
1917
+
1918
+ # Build battle data with judge outputs
1919
+ battle_data = {
1920
+ "opponent": opponent,
1921
+ "result": result,
1922
+ "is_consistent": record.is_consistent,
1923
+ "model_a": record.model_a,
1924
+ "model_b": record.model_b,
1925
+ "final_winner": record.final_winner,
1926
+ "exp_name": record.exp_name,
1927
+ }
1928
+
1929
+ # Load audit logs if not already loaded on the record
1930
+ if not record.original_call and not record.swapped_call:
1931
+ actual_exp_name = record.exp_name
1932
+ audit = self._load_audit_log(
1933
+ subset, actual_exp_name, record.model_a, record.model_b, sample_index
1934
+ )
1935
+ if audit:
1936
+ battle_data["original_call"] = audit.get("original_call")
1937
+ battle_data["swapped_call"] = audit.get("swapped_call")
1938
+ else:
1939
+ # Use existing data if available
1940
+ if record.original_call:
1941
+ battle_data["original_call"] = record.original_call
1942
+ if record.swapped_call:
1943
+ battle_data["swapped_call"] = record.swapped_call
1944
+
1945
+ model_battles.append(battle_data)
1946
+
1947
+ # Sort battles by opponent name
1948
+ model_battles.sort(key=lambda x: x["opponent"])
1949
+
1950
+ # Get model stats
1951
+ model_stats = self.get_model_win_stats(subset, exp_name, sample_index)
1952
+ stats = model_stats.get(model, {
1953
+ "wins": 0, "losses": 0, "ties": 0, "total": 0, "win_rate": 0
1954
+ })
1955
+
1956
+ return {
1957
+ "model": model,
1958
+ "sample_index": sample_index,
1959
+ "wins": stats["wins"],
1960
+ "losses": stats["losses"],
1961
+ "ties": stats["ties"],
1962
+ "total": stats["total"],
1963
+ "win_rate": stats["win_rate"],
1964
+ "battles": model_battles,
1965
+ "all_opponents": sorted(list(all_opponents)),
1966
+ }
1967
+
1968
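
# Sketch: one model's battles on a single sample, reusing the `loader` from the
# first sketch; model and sample index are placeholders. Each entry carries the
# judge's original_call/swapped_call payloads when audit logs are available.
detail = loader.get_model_battles_for_sample("edit", "__all__", 0, "model_x")
for b in detail["battles"]:
    print(b["opponent"], b["result"], b["is_consistent"])
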
+ def get_elo_leaderboard(
1969
+ self,
1970
+ subset: str,
1971
+ filter_models: Optional[list[str]] = None,
1972
+ ) -> list[dict[str, Any]]:
1973
+ """
1974
+ Get ELO leaderboard for a subset from state.json.
1975
+
1976
+ Args:
1977
+ subset: Subset name
1978
+ filter_models: Optional list of models to filter (show only these models)
1979
+
1980
+ Returns:
1981
+ List of model stats sorted by ELO rating (descending)
1982
+ """
1983
+ info = self.get_subset_info(subset)
1984
+ if not info or not info.state:
1985
+ return []
1986
+
1987
+ state = info.state
1988
+ leaderboard = []
1989
+
1990
+ for model_name, model_stats in state.models.items():
1991
+ # Apply filter if specified
1992
+ if filter_models and model_name not in filter_models:
1993
+ continue
1994
+
1995
+ leaderboard.append({
1996
+ "model": model_name,
1997
+ "elo": model_stats.elo,
1998
+ "wins": model_stats.wins,
1999
+ "losses": model_stats.losses,
2000
+ "ties": model_stats.ties,
2001
+ "total_battles": model_stats.total_battles,
2002
+ "win_rate": model_stats.win_rate,
2003
+ })
2004
+
2005
+ # Sort by ELO rating (descending)
2006
+ leaderboard.sort(key=lambda x: -x["elo"])
2007
+
2008
+ # Add rank
2009
+ for i, entry in enumerate(leaderboard):
2010
+ entry["rank"] = i + 1
2011
+
2012
+ return leaderboard
2013
+
2014
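
# Sketch: the per-subset leaderboard read from state.json, reusing the `loader`
# from the first sketch; the subset name is a placeholder.
for row in loader.get_elo_leaderboard("edit"):
    print(row["rank"], row["model"], row["elo"], f"{row['win_rate']:.0%}")
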
+ def get_model_vs_stats(
2015
+ self,
2016
+ subset: str,
2017
+ model: str,
2018
+ exp_name: str = "__all__",
2019
+ ) -> dict[str, Any]:
2020
+ """
2021
+ Get win/loss/tie stats of a specific model against all other models.
2022
+
2023
+ Args:
2024
+ subset: Subset name
2025
+ model: Target model name
2026
+ exp_name: Experiment name (default "__all__" for all experiments)
2027
+
2028
+ Returns:
2029
+ Dict with model stats and versus stats against each opponent
2030
+ """
2031
+ # Get overall ELO stats
2032
+ info = self.get_subset_info(subset)
2033
+ if not info or not info.state:
2034
+ return {}
2035
+
2036
+ state = info.state
2037
+ if model not in state.models:
2038
+ return {}
2039
+
2040
+ model_stats = state.models[model]
2041
+
2042
+ # Load battle records
2043
+ if exp_name == "__all__":
2044
+ all_records = self._load_all_experiments_battles(subset)
2045
+ else:
2046
+ all_records = self._load_battle_logs(subset, exp_name)
2047
+
2048
+ # Calculate stats against each opponent
2049
+ vs_stats: dict[str, dict[str, int]] = {}
2050
+
2051
+ for record in all_records:
2052
+ if model not in [record.model_a, record.model_b]:
2053
+ continue
2054
+
2055
+ opponent = record.model_b if record.model_a == model else record.model_a
2056
+
2057
+ if opponent not in vs_stats:
2058
+ vs_stats[opponent] = {"wins": 0, "losses": 0, "ties": 0}
2059
+
2060
+ if record.final_winner == "tie":
2061
+ vs_stats[opponent]["ties"] += 1
2062
+ elif record.final_winner == model:
2063
+ vs_stats[opponent]["wins"] += 1
2064
+ else:
2065
+ vs_stats[opponent]["losses"] += 1
2066
+
2067
+ # Convert to list with win rates and opponent ELO
2068
+ vs_list = []
2069
+ for opponent, stats in vs_stats.items():
2070
+ total = stats["wins"] + stats["losses"] + stats["ties"]
2071
+ opponent_elo = state.models[opponent].elo if opponent in state.models else 1000.0
2072
+ vs_list.append({
2073
+ "opponent": opponent,
2074
+ "opponent_elo": opponent_elo,
2075
+ "wins": stats["wins"],
2076
+ "losses": stats["losses"],
2077
+ "ties": stats["ties"],
2078
+ "total": total,
2079
+ "win_rate": stats["wins"] / total if total > 0 else 0,
2080
+ })
2081
+
2082
+ # Sort by opponent ELO (descending)
2083
+ vs_list.sort(key=lambda x: -x["opponent_elo"])
2084
+
2085
+ return {
2086
+ "model": model,
2087
+ "elo": model_stats.elo,
2088
+ "wins": model_stats.wins,
2089
+ "losses": model_stats.losses,
2090
+ "ties": model_stats.ties,
2091
+ "total_battles": model_stats.total_battles,
2092
+ "win_rate": model_stats.win_rate,
2093
+ "vs_stats": vs_list,
2094
+ }
2095
+
2096
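
# Sketch: one model's head-to-head record against every opponent, reusing the
# `loader` from the first sketch; "model_x" is a placeholder. Opponents come
# back sorted by their Elo, highest first.
vs = loader.get_model_vs_stats("edit", "model_x")
print(vs["elo"], vs["total_battles"])
for opp in vs["vs_stats"]:
    print(opp["opponent"], opp["opponent_elo"], opp["wins"], opp["losses"], opp["ties"])
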
+ def get_all_subsets_leaderboards(self) -> dict[str, Any]:
2097
+ """
2098
+ Get leaderboard data for all subsets (for Overview page).
2099
+
2100
+ Returns:
2101
+ Dict with:
2102
+ - subsets: List of subset names
2103
+ - models: List of all unique model names across all subsets
2104
+ - data: Dict mapping subset -> {model -> {elo, rank, wins, losses, ties, ...}}
2105
+ - subset_info: Dict mapping subset -> {total_battles, model_count}
2106
+ """
2107
+ subsets = self.discover_subsets()
2108
+ all_models: set[str] = set()
2109
+ data: dict[str, dict[str, dict[str, Any]]] = {}
2110
+ subset_info: dict[str, dict[str, Any]] = {}
2111
+
2112
+ for subset in subsets:
2113
+ leaderboard = self.get_elo_leaderboard(subset)
2114
+ info = self.get_subset_info(subset)
2115
+
2116
+ if not leaderboard:
2117
+ continue
2118
+
2119
+ # Build subset data
2120
+ subset_data: dict[str, dict[str, Any]] = {}
2121
+ for entry in leaderboard:
2122
+ model = entry["model"]
2123
+ all_models.add(model)
2124
+ subset_data[model] = {
2125
+ "elo": entry["elo"],
2126
+ "rank": entry["rank"],
2127
+ "wins": entry["wins"],
2128
+ "losses": entry["losses"],
2129
+ "ties": entry["ties"],
2130
+ "total_battles": entry["total_battles"],
2131
+ "win_rate": entry["win_rate"],
2132
+ }
2133
+
2134
+ data[subset] = subset_data
2135
+ subset_info[subset] = {
2136
+ "total_battles": info.total_battles if info else 0,
2137
+ "model_count": len(leaderboard),
2138
+ }
2139
+
2140
+ # Sort models by average ELO across all subsets (descending)
2141
+ model_avg_elo: dict[str, tuple[float, int]] = {} # model -> (sum_elo, count)
2142
+ for model in all_models:
2143
+ total_elo = 0.0
2144
+ count = 0
2145
+ for subset in subsets:
2146
+ if subset in data and model in data[subset]:
2147
+ total_elo += data[subset][model]["elo"]
2148
+ count += 1
2149
+ if count > 0:
2150
+ model_avg_elo[model] = (total_elo / count, count)
2151
+ else:
2152
+ model_avg_elo[model] = (0.0, 0)
2153
+
2154
+ sorted_models = sorted(
2155
+ all_models,
2156
+ key=lambda m: (-model_avg_elo[m][0], -model_avg_elo[m][1], m)
2157
+ )
2158
+
2159
+ return {
2160
+ "subsets": subsets,
2161
+ "models": sorted_models,
2162
+ "data": data,
2163
+ "subset_info": subset_info,
2164
+ }
2165
+
2166
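
# Sketch: the overview matrix, reusing the `loader` from the first sketch.
# Rows are models ordered by their average Elo over the subsets they appear in;
# a missing cell simply means the model has no battles in that subset.
overview = loader.get_all_subsets_leaderboards()
for model in overview["models"]:
    row = [overview["data"].get(s, {}).get(model, {}).get("elo") for s in overview["subsets"]]
    print(model, row)
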
+ def get_prompts(
2167
+ self,
2168
+ subset: str,
2169
+ exp_name: str,
2170
+ page: int = 1,
2171
+ page_size: int = 10,
2172
+ min_images: Optional[int] = None,
2173
+ max_images: Optional[int] = None,
2174
+ prompt_source: Optional[str] = None,
2175
+ filter_models: Optional[list[str]] = None,
2176
+ ) -> tuple[list[dict[str, Any]], int]:
2177
+ """
2178
+ Get paginated list of prompts/samples with all model outputs.
2179
+
2180
+ Args:
2181
+ subset: Subset name
2182
+ exp_name: Experiment name (use "__all__" for all experiments)
2183
+ page: Page number (1-indexed)
2184
+ page_size: Number of records per page
2185
+ min_images: Minimum number of input images
2186
+ max_images: Maximum number of input images
2187
+ prompt_source: Filter by prompt source
2188
+ filter_models: Optional list of models to filter (show only these models)
2189
+
2190
+ Returns:
2191
+ Tuple of (prompts_list, total_count)
2192
+ """
2193
+ # Get all sample indices from battle logs
2194
+ if exp_name == "__all__":
2195
+ all_records = self._load_all_experiments_battles(subset)
2196
+ else:
2197
+ all_records = self._load_battle_logs(subset, exp_name)
2198
+
2199
+ # Collect unique sample indices
2200
+ sample_indices = set()
2201
+ for record in all_records:
2202
+ sample_indices.add(record.sample_index)
2203
+
2204
+ # Sort sample indices
2205
+ sorted_indices = sorted(sample_indices)
2206
+
2207
+ # Apply filters
2208
+ filtered_indices = []
2209
+ for idx in sorted_indices:
2210
+ sample_meta = self._get_sample_data(subset, idx)
2211
+ img_count = sample_meta.get("input_image_count", 1)
2212
+ source = sample_meta.get("prompt_source")
2213
+
2214
+ # Apply image count filter
2215
+ if min_images is not None and img_count < min_images:
2216
+ continue
2217
+ if max_images is not None and img_count > max_images:
2218
+ continue
2219
+
2220
+ # Apply prompt source filter
2221
+ if prompt_source and source != prompt_source:
2222
+ continue
2223
+
2224
+ filtered_indices.append(idx)
2225
+
2226
+ total_count = len(filtered_indices)
2227
+
2228
+ # Paginate
2229
+ start = (page - 1) * page_size
2230
+ end = start + page_size
2231
+ page_indices = filtered_indices[start:end]
2232
+
2233
+ # Build prompt data for each sample
2234
+ prompts = []
2235
+ for idx in page_indices:
2236
+ prompt_data = self.get_sample_all_models(subset, exp_name, idx, filter_models)
2237
+ prompts.append(prompt_data)
2238
+
2239
+ return prompts, total_count
2240
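
# Sketch: paginated prompt browsing, reusing the `loader` from the first sketch;
# filter values are placeholders. The second return value is the total number of
# samples matching the filters, which a UI needs to render page controls.
prompts, total = loader.get_prompts("edit", "__all__", page=1, page_size=10, min_images=2)
print(f"showing {len(prompts)} of {total} samples")
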
+
2241
+
2242
+ class HFArenaDataLoader(ArenaDataLoader):
2243
+ """
2244
+ Data loader for HuggingFace Spaces deployment.
2245
+
2246
+ Extends ArenaDataLoader to:
2247
+ - Build image URL index from HF file list
2248
+ - Return HF CDN URLs for model output images instead of local paths
2249
+ """
2250
+
2251
+ def __init__(
2252
+ self,
2253
+ arena_dir: str,
2254
+ data_dir: str,
2255
+ hf_repo: str,
2256
+ image_files: list[str],
2257
+ preload: bool = True,
2258
+ ):
2259
+ """
2260
+ Initialize the HF data loader.
2261
+
2262
+ Args:
2263
+ arena_dir: Path to arena directory (metadata only, no images)
2264
+ data_dir: Path to data directory containing parquet files
2265
+ hf_repo: HuggingFace repo ID for image CDN URLs
2266
+ image_files: List of image file paths in the HF repo
2267
+ preload: If True, preload all data at initialization
2268
+ """
2269
+ self.hf_repo = hf_repo
2270
+ self._image_url_index = self._build_image_index(image_files)
2271
+ super().__init__(arena_dir, data_dir, preload=preload)
2272
+
2273
+ def _build_image_index(
2274
+ self, image_files: list[str]
2275
+ ) -> dict[tuple[str, str, int], str]:
2276
+ """
2277
+ Build index: (subset, model, sample_index) -> hf_file_path
2278
+
2279
+ Expected path format: {subset}/models/{exp_name}/{model}/{index}.png
2280
+
2281
+ Args:
2282
+ image_files: List of image file paths from HF repo
2283
+
2284
+ Returns:
2285
+ Dict mapping (subset, model, sample_index) to HF file path
2286
+ """
2287
+ from genarena.models import parse_image_index
2288
+
2289
+ index: dict[tuple[str, str, int], str] = {}
2290
+
2291
+ for path in image_files:
2292
+ parts = path.split("/")
2293
+ # Expected: subset/models/exp_name/model/000000.png
2294
+ if len(parts) >= 5 and parts[1] == "models":
2295
+ subset = parts[0]
2296
+ # exp_name = parts[2] # Not needed for lookup
2297
+ model = parts[3]
2298
+ filename = parts[4]
2299
+ idx = parse_image_index(filename)
2300
+ if idx is not None:
2301
+ # If duplicate, later entries overwrite earlier ones
2302
+ index[(subset, model, idx)] = path
2303
+
2304
+ logger.info(f"Built image URL index with {len(index)} entries")
2305
+ return index
2306
+
2307
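
# Worked example of the index built above (all names are placeholders): a repo
# file "edit/models/exp1/model_x/000042.png" splits into
# ["edit", "models", "exp1", "model_x", "000042.png"], so assuming
# parse_image_index("000042.png") returns 42, the entry becomes
#     ("edit", "model_x", 42) -> "edit/models/exp1/model_x/000042.png"
# and the experiment segment is dropped, since lookups key on subset and model only.
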
+ def get_model_image_url(
2308
+ self, subset: str, model: str, sample_index: int
2309
+ ) -> Optional[str]:
2310
+ """
2311
+ Get HF CDN URL for model output image.
2312
+
2313
+ Args:
2314
+ subset: Subset name
2315
+ model: Model name
2316
+ sample_index: Sample index
2317
+
2318
+ Returns:
2319
+ HF CDN URL or None if not found
2320
+ """
2321
+ path = self._image_url_index.get((subset, model, sample_index))
2322
+ if path:
2323
+ return f"https://huggingface.co/datasets/{self.hf_repo}/resolve/main/{path}"
2324
+ return None
2325
+
2326
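
# Sketch: resolving a model output to its CDN URL, assuming an HFArenaDataLoader
# instance `hf_loader` (a construction sketch follows this class); names and the
# index are placeholders. None means no uploaded image exists for that sample.
url = hf_loader.get_model_image_url("edit", "model_x", 42)
# e.g. "https://huggingface.co/datasets/<hf_repo>/resolve/main/edit/models/exp1/model_x/000042.png"
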
+ def get_image_path(
2327
+ self, subset: str, model: str, sample_index: int
2328
+ ) -> Optional[str]:
2329
+ """
2330
+ Override to return None since images are served via CDN.
2331
+
2332
+ For HF deployment, use get_model_image_url() instead.
2333
+ """
2334
+ # Return None to indicate image should be fetched via CDN
2335
+ return None
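
# Sketch: wiring the HF deployment loader, assuming the metadata-only arena
# directory and the parquet data directory have been fetched locally and that
# `image_files` comes from a file listing of the image dataset repo; every name
# and path below is a placeholder.
from genarena.visualize.data_loader import HFArenaDataLoader

hf_loader = HFArenaDataLoader(
    arena_dir="/path/to/arena",               # metadata only, no images
    data_dir="/path/to/data",                 # parquet files with the samples
    hf_repo="someuser/genarena-images",       # dataset repo serving the images
    image_files=["edit/models/exp1/model_x/000042.png"],
    preload=True,
)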