genarena 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genarena/__init__.py +49 -2
- genarena/__main__.py +10 -0
- genarena/arena.py +1685 -0
- genarena/battle.py +337 -0
- genarena/bt_elo.py +507 -0
- genarena/cli.py +1581 -0
- genarena/data.py +476 -0
- genarena/deploy/Dockerfile +22 -0
- genarena/deploy/README.md +55 -0
- genarena/deploy/__init__.py +5 -0
- genarena/deploy/app.py +84 -0
- genarena/experiments.py +121 -0
- genarena/leaderboard.py +270 -0
- genarena/logs.py +409 -0
- genarena/models.py +412 -0
- genarena/prompts/__init__.py +127 -0
- genarena/prompts/mmrb2.py +373 -0
- genarena/sampling.py +336 -0
- genarena/state.py +656 -0
- genarena/sync/__init__.py +105 -0
- genarena/sync/auto_commit.py +118 -0
- genarena/sync/deploy_ops.py +543 -0
- genarena/sync/git_ops.py +422 -0
- genarena/sync/hf_ops.py +891 -0
- genarena/sync/init_ops.py +431 -0
- genarena/sync/packer.py +587 -0
- genarena/sync/submit.py +837 -0
- genarena/utils.py +103 -0
- genarena/validation/__init__.py +19 -0
- genarena/validation/schema.py +327 -0
- genarena/validation/validator.py +329 -0
- genarena/visualize/README.md +148 -0
- genarena/visualize/__init__.py +14 -0
- genarena/visualize/app.py +938 -0
- genarena/visualize/data_loader.py +2430 -0
- genarena/visualize/static/app.js +3762 -0
- genarena/visualize/static/model_aliases.json +86 -0
- genarena/visualize/static/style.css +4104 -0
- genarena/visualize/templates/index.html +413 -0
- genarena/vlm.py +519 -0
- genarena-0.1.1.dist-info/METADATA +178 -0
- genarena-0.1.1.dist-info/RECORD +44 -0
- {genarena-0.0.1.dist-info → genarena-0.1.1.dist-info}/WHEEL +1 -2
- genarena-0.1.1.dist-info/entry_points.txt +2 -0
- genarena-0.0.1.dist-info/METADATA +0 -26
- genarena-0.0.1.dist-info/RECORD +0 -5
- genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/visualize/data_loader.py (new file, +2430 lines)
@@ -0,0 +1,2430 @@
# Copyright 2026 Ruihang Li.
# Licensed under the Apache License, Version 2.0.
# See LICENSE file in the project root for details.

"""Data loader for arena visualization with preloading support."""

import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Any, Optional

from genarena.data import DataSample, ParquetDataset, discover_subsets
from genarena.models import GlobalModelOutputManager
from genarena.state import ArenaState, load_state


logger = logging.getLogger(__name__)


@dataclass
class BattleRecord:
    """A single battle record with all relevant information."""

    # Battle identification
    subset: str
    exp_name: str
    sample_index: int
    model_a: str
    model_b: str

    # Battle result
    final_winner: str  # model name or "tie"
    is_consistent: bool
    timestamp: str = ""

    # Raw VLM outputs (from audit logs, optional)
    original_call: Optional[dict[str, Any]] = None
    swapped_call: Optional[dict[str, Any]] = None

    # Sample data (loaded on demand)
    instruction: str = ""
    task_type: str = ""
    input_image_count: int = 1
    prompt_source: Optional[str] = None
    original_metadata: Optional[dict[str, Any]] = None

    @property
    def id(self) -> str:
        """Unique identifier for this battle."""
        return f"{self.subset}:{self.exp_name}:{self.model_a}_vs_{self.model_b}:{self.sample_index}"

    @property
    def winner_display(self) -> str:
        """Display-friendly winner string."""
        if self.final_winner == "tie":
            return "Tie"
        return self.final_winner

    @property
    def models(self) -> set[str]:
        """Set of models involved in this battle."""
        return {self.model_a, self.model_b}

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        return {
            "id": self.id,
            "subset": self.subset,
            "exp_name": self.exp_name,
            "sample_index": self.sample_index,
            "model_a": self.model_a,
            "model_b": self.model_b,
            "final_winner": self.final_winner,
            "winner_display": self.winner_display,
            "is_consistent": self.is_consistent,
            "timestamp": self.timestamp,
            "instruction": self.instruction,
            "task_type": self.task_type,
            "input_image_count": self.input_image_count,
            "prompt_source": self.prompt_source,
            "original_metadata": self.original_metadata,
            "has_audit": self.original_call is not None,
        }

    def to_detail_dict(self) -> dict[str, Any]:
        """Convert to detailed dictionary including VLM outputs."""
        d = self.to_dict()
        d["original_call"] = self.original_call
        d["swapped_call"] = self.swapped_call
        return d
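# Illustrative sketch only (not part of the packaged file): constructing a
# BattleRecord by hand and serializing it; every field value below is made up.
#
#     record = BattleRecord(
#         subset="example_subset", exp_name="exp_001", sample_index=42,
#         model_a="model-x", model_b="model-y",
#         final_winner="model-x", is_consistent=True,
#     )
#     summary = record.to_dict()        # JSON-safe summary; "has_audit" is False here
#     detail = record.to_detail_dict()  # adds original_call / swapped_call (still None)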


@dataclass
class SubsetInfo:
    """Information about a subset."""

    name: str
    models: list[str]
    experiments: list[str]
    total_battles: int
    state: Optional[ArenaState] = None
    min_input_images: int = 1
    max_input_images: int = 1
    prompt_sources: list[str] = field(default_factory=list)


class ArenaDataLoader:
    """
    Data loader for arena visualization.

    Manages loading and querying battle records across multiple subsets.
    Supports preloading for better performance with large datasets.
    """

    def __init__(self, arena_dir: str, data_dir: str, preload: bool = True):
        """
        Initialize the data loader.

        Args:
            arena_dir: Path to arena directory containing subset folders
            data_dir: Path to data directory containing parquet files
            preload: If True, preload all data at initialization
        """
        self.arena_dir = arena_dir
        self.data_dir = data_dir

        # Cached data
        self._subsets: Optional[list[str]] = None
        self._subset_info_cache: dict[str, SubsetInfo] = {}
        self._dataset_cache: dict[str, ParquetDataset] = {}
        self._model_manager_cache: dict[str, GlobalModelOutputManager] = {}

        # Battle records cache: (subset, exp_name) -> List[BattleRecord]
        self._battle_cache: dict[tuple[str, str], list[BattleRecord]] = {}

        # Index for faster lookups: (subset, exp_name) -> {model -> [record_indices]}
        self._model_index: dict[tuple[str, str], dict[str, list[int]]] = {}

        # Sample data cache: (subset, sample_index) -> SampleMetadata dict
        self._sample_cache: dict[tuple[str, int], dict[str, Any]] = {}

        # Sample to parquet file mapping: (subset, sample_index) -> parquet_file_path
        self._sample_file_map: dict[tuple[str, int], str] = {}

        # Input image count range per subset: subset -> (min_count, max_count)
        self._image_count_range: dict[str, tuple[int, int]] = {}

        # Prompt sources per subset: subset -> list of unique prompt_source values
        self._prompt_sources: dict[str, list[str]] = {}

        # Audit logs cache: (subset, exp_name, model_a, model_b, sample_index) -> audit data
        self._audit_cache: dict[tuple[str, str, str, str, int], dict[str, Any]] = {}

        # Cross-subset ELO cache: (sorted_subsets_tuple, exp_name, model_scope) -> result dict
        self._cross_subset_elo_cache: dict[tuple[tuple[str, ...], str, str], dict[str, Any]] = {}

        if preload:
            self._preload_all()

    def _preload_all(self) -> None:
        """Preload all data at initialization for better performance."""
        logger.info("Preloading arena data...")

        subsets = self.discover_subsets()
        logger.info(f"Found {len(subsets)} subsets: {subsets}")

        for subset in subsets:
            logger.info(f"Loading subset: {subset}")

            # Preload parquet dataset
            self._preload_dataset(subset)

            # Load subset info (models, experiments)
            info = self.get_subset_info(subset)
            if info:
                logger.info(f"  - {len(info.models)} models, {len(info.experiments)} experiments")

                # Preload battle logs for each experiment
                for exp_name in info.experiments:
                    records = self._load_battle_logs(subset, exp_name)
                    logger.info(f"  - Experiment '{exp_name}': {len(records)} battles")

        logger.info("Preloading complete!")

    def _preload_dataset(self, subset: str) -> None:
        """
        Preload sample text data (instruction, task_type) using pyarrow directly.

        This is much faster than using HuggingFace datasets because we skip
        decoding image columns. Images are loaded on-demand when requested.
        """
        import pyarrow.parquet as pq

        subset_path = os.path.join(self.data_dir, subset)
        if not os.path.isdir(subset_path):
            return

        # Find parquet files
        parquet_files = sorted([
            os.path.join(subset_path, f)
            for f in os.listdir(subset_path)
            if f.startswith("data-") and f.endswith(".parquet")
        ])

        if not parquet_files:
            return

        logger.info(f"  - Loading metadata from parquet (fast mode)...")

        # Read all metadata columns + input_images (only to count, not decode)
        columns_to_read = ["index", "instruction", "task_type", "input_images", "prompt_source", "original_metadata"]

        total_rows = 0
        min_img_count = float('inf')
        max_img_count = 0
        prompt_sources_set: set[str] = set()

        for pf in parquet_files:
            try:
                # Get available columns in this file
                schema = pq.read_schema(pf)
                available_columns = [c for c in columns_to_read if c in schema.names]

                # Read the columns we need
                table = pq.read_table(pf, columns=available_columns)

                # Extract columns with defaults
                def get_column(name, default=None):
                    if name in table.column_names:
                        return table.column(name).to_pylist()
                    return [default] * table.num_rows

                indices = get_column("index", 0)
                instructions = get_column("instruction", "")
                task_types = get_column("task_type", "")
                prompt_sources = get_column("prompt_source", None)
                original_metadatas = get_column("original_metadata", None)

                # Handle input_images separately for counting
                has_input_images = "input_images" in table.column_names
                input_images_col = table.column("input_images") if has_input_images else None

                for i, idx in enumerate(indices):
                    idx = int(idx) if idx is not None else i

                    # Count input images without decoding
                    img_count = 0
                    if input_images_col is not None:
                        img_list = input_images_col[i].as_py()
                        img_count = len(img_list) if img_list else 0

                    min_img_count = min(min_img_count, img_count) if img_count > 0 else min_img_count
                    max_img_count = max(max_img_count, img_count)

                    # Track prompt sources
                    ps = prompt_sources[i] if prompt_sources[i] else None
                    if ps:
                        prompt_sources_set.add(str(ps))

                    # Build metadata dict
                    metadata = {
                        "instruction": str(instructions[i]) if instructions[i] else "",
                        "task_type": str(task_types[i]) if task_types[i] else "",
                        "input_image_count": img_count,
                        "prompt_source": ps,
                        "original_metadata": original_metadatas[i] if original_metadatas[i] else None,
                    }

                    self._sample_cache[(subset, idx)] = metadata
                    self._sample_file_map[(subset, idx)] = pf
                    total_rows += 1

            except Exception as e:
                logger.warning(f"Failed to read {pf}: {e}")
                continue

        # Store image count range for this subset
        if total_rows > 0:
            self._image_count_range[subset] = (
                min_img_count if min_img_count != float('inf') else 1,
                max_img_count if max_img_count > 0 else 1
            )

        # Store prompt sources for this subset
        self._prompt_sources[subset] = sorted(prompt_sources_set)

        logger.info(f"  - Cached {total_rows} samples (input images: {self._image_count_range.get(subset, (1,1))}, sources: {len(prompt_sources_set)})")

    def discover_subsets(self) -> list[str]:
        """
        Discover all available subsets.

        A valid subset must exist in both arena_dir (with pk_logs) and data_dir.

        Returns:
            List of subset names
        """
        if self._subsets is not None:
            return self._subsets

        # Get subsets from data_dir (have parquet files)
        data_subsets = set(discover_subsets(self.data_dir))

        # Get subsets from arena_dir (have pk_logs)
        arena_subsets = set()
        if os.path.isdir(self.arena_dir):
            for name in os.listdir(self.arena_dir):
                subset_path = os.path.join(self.arena_dir, name)
                pk_logs_path = os.path.join(subset_path, "pk_logs")
                if os.path.isdir(pk_logs_path):
                    # Check if there are any experiment directories with battle logs
                    for exp_name in os.listdir(pk_logs_path):
                        exp_path = os.path.join(pk_logs_path, exp_name)
                        if os.path.isdir(exp_path):
                            # Check for .jsonl files
                            has_logs = any(
                                f.endswith(".jsonl")
                                for f in os.listdir(exp_path)
                                if os.path.isfile(os.path.join(exp_path, f))
                            )
                            if has_logs:
                                arena_subsets.add(name)
                                break

        # Intersection: must have both data and battle logs
        valid_subsets = sorted(data_subsets & arena_subsets)
        self._subsets = valid_subsets
        return valid_subsets
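    # For orientation, the path joins throughout this file imply roughly the
    # following on-disk layout; <subset> and <exp_name> are placeholders:
    #
    #     data_dir/<subset>/data-*.parquet                        # sample parquet shards
    #     arena_dir/<subset>/pk_logs/<exp_name>/*.jsonl           # slim battle logs
    #     arena_dir/<subset>/pk_logs/<exp_name>/raw_outputs/      # per-pair audit logs
    #     arena_dir/<subset>/pk_logs/<exp_name>/elo_snapshot.json # pre-computed ELO snapshot
    #     arena_dir/<subset>/models/                               # model output images
    #     arena_dir/<subset>/arena/state.json                      # ArenaState (total_battles, ...)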

    def get_subset_info(self, subset: str) -> Optional[SubsetInfo]:
        """
        Get information about a subset.

        Args:
            subset: Subset name

        Returns:
            SubsetInfo or None if subset doesn't exist
        """
        if subset in self._subset_info_cache:
            return self._subset_info_cache[subset]

        subset_path = os.path.join(self.arena_dir, subset)
        if not os.path.isdir(subset_path):
            return None

        # Get models
        model_manager = self._get_model_manager(subset)
        models = model_manager.models if model_manager else []

        # Get experiments
        pk_logs_dir = os.path.join(subset_path, "pk_logs")
        experiments = []
        if os.path.isdir(pk_logs_dir):
            for name in os.listdir(pk_logs_dir):
                exp_path = os.path.join(pk_logs_dir, name)
                if os.path.isdir(exp_path):
                    # Check for battle logs
                    has_logs = any(
                        f.endswith(".jsonl")
                        for f in os.listdir(exp_path)
                        if os.path.isfile(os.path.join(exp_path, f))
                    )
                    if has_logs:
                        experiments.append(name)
        experiments.sort()

        # Load state
        state_path = os.path.join(subset_path, "arena", "state.json")
        state = load_state(state_path)

        # Get image count range
        img_range = self._image_count_range.get(subset, (1, 1))

        # Get prompt sources
        prompt_sources = self._prompt_sources.get(subset, [])

        info = SubsetInfo(
            name=subset,
            models=models,
            experiments=experiments,
            total_battles=state.total_battles,
            state=state,
            min_input_images=img_range[0],
            max_input_images=img_range[1],
            prompt_sources=prompt_sources,
        )

        self._subset_info_cache[subset] = info
        return info

    def _get_dataset(self, subset: str) -> Optional[ParquetDataset]:
        """Get or create ParquetDataset for a subset."""
        if subset not in self._dataset_cache:
            try:
                self._dataset_cache[subset] = ParquetDataset(self.data_dir, subset)
            except Exception:
                return None
        return self._dataset_cache[subset]

    def _get_model_manager(self, subset: str) -> Optional[GlobalModelOutputManager]:
        """Get or create GlobalModelOutputManager for a subset."""
        if subset not in self._model_manager_cache:
            models_dir = os.path.join(self.arena_dir, subset, "models")
            if os.path.isdir(models_dir):
                self._model_manager_cache[subset] = GlobalModelOutputManager(models_dir)
            else:
                return None
        return self._model_manager_cache[subset]

    def _get_sample_data(self, subset: str, sample_index: int) -> dict[str, Any]:
        """Get cached sample metadata."""
        cache_key = (subset, sample_index)
        if cache_key in self._sample_cache:
            return self._sample_cache[cache_key]

        # Fallback - return defaults
        return {
            "instruction": "",
            "task_type": "",
            "input_image_count": 1,
            "prompt_source": None,
            "original_metadata": None,
        }

    def _load_battle_logs(self, subset: str, exp_name: str) -> list[BattleRecord]:
        """
        Load battle records from log files.

        Args:
            subset: Subset name
            exp_name: Experiment name

        Returns:
            List of BattleRecord objects
        """
        cache_key = (subset, exp_name)
        if cache_key in self._battle_cache:
            return self._battle_cache[cache_key]

        records: list[BattleRecord] = []
        exp_dir = os.path.join(self.arena_dir, subset, "pk_logs", exp_name)

        if not os.path.isdir(exp_dir):
            return records

        # Load slim battle logs
        for filename in os.listdir(exp_dir):
            if not filename.endswith(".jsonl"):
                continue

            filepath = os.path.join(exp_dir, filename)
            if not os.path.isfile(filepath):
                continue

            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                            sample_index = data.get("sample_index", -1)

                            # Get cached sample data
                            sample_meta = self._get_sample_data(subset, sample_index)

                            record = BattleRecord(
                                subset=subset,
                                exp_name=exp_name,
                                sample_index=sample_index,
                                model_a=data.get("model_a", ""),
                                model_b=data.get("model_b", ""),
                                final_winner=data.get("final_winner", "tie"),
                                is_consistent=data.get("is_consistent", False),
                                timestamp=data.get("timestamp", ""),
                                instruction=sample_meta.get("instruction", ""),
                                task_type=sample_meta.get("task_type", ""),
                                input_image_count=sample_meta.get("input_image_count", 1),
                                prompt_source=sample_meta.get("prompt_source"),
                                original_metadata=sample_meta.get("original_metadata"),
                            )
                            if record.model_a and record.model_b:
                                records.append(record)
                        except json.JSONDecodeError:
                            continue
            except Exception:
                continue

        # Sort by sample_index
        records.sort(key=lambda r: r.sample_index)

        # Cache records
        self._battle_cache[cache_key] = records

        # Build model index for fast filtering
        self._build_model_index(cache_key, records)

        return records

    def _build_model_index(
        self, cache_key: tuple[str, str], records: list[BattleRecord]
    ) -> None:
        """Build index for fast model-based filtering."""
        model_index: dict[str, list[int]] = {}

        for i, record in enumerate(records):
            for model in [record.model_a, record.model_b]:
                if model not in model_index:
                    model_index[model] = []
                model_index[model].append(i)

        self._model_index[cache_key] = model_index

    def _load_all_experiments_battles(self, subset: str) -> list[BattleRecord]:
        """
        Load battle records from all experiments for a subset.

        Args:
            subset: Subset name

        Returns:
            Combined list of BattleRecord objects from all experiments
        """
        info = self.get_subset_info(subset)
        if not info:
            return []

        all_records: list[BattleRecord] = []
        for exp_name in info.experiments:
            records = self._load_battle_logs(subset, exp_name)
            all_records.extend(records)

        # Sort by sample_index for consistent ordering
        all_records.sort(key=lambda r: (r.sample_index, r.exp_name, r.model_a, r.model_b))
        return all_records

    def _load_audit_log(
        self, subset: str, exp_name: str, model_a: str, model_b: str, sample_index: int
    ) -> Optional[dict[str, Any]]:
        """
        Load audit log for a specific battle.

        Args:
            subset: Subset name
            exp_name: Experiment name
            model_a: First model name
            model_b: Second model name
            sample_index: Sample index

        Returns:
            Audit data dict or None
        """
        cache_key = (subset, exp_name, model_a, model_b, sample_index)
        if cache_key in self._audit_cache:
            return self._audit_cache[cache_key]

        # Determine filename (models are sorted alphabetically)
        from genarena.utils import sanitize_name

        first, second = sorted([model_a, model_b])
        filename = f"{sanitize_name(first)}_vs_{sanitize_name(second)}.jsonl"
        filepath = os.path.join(
            self.arena_dir, subset, "pk_logs", exp_name, "raw_outputs", filename
        )

        if not os.path.isfile(filepath):
            return None

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        if data.get("sample_index") == sample_index:
                            self._audit_cache[cache_key] = data
                            return data
                    except json.JSONDecodeError:
                        continue
        except Exception:
            pass

        return None
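    # For reference, _load_battle_logs only relies on the keys read via data.get()
    # above; a hypothetical slim battle-log line (values invented) would parse to:
    #
    #     {"sample_index": 42, "model_a": "model-x", "model_b": "model-y",
    #      "final_winner": "model-x", "is_consistent": true,
    #      "timestamp": "2026-01-01T00:00:00"}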

    def get_battles(
        self,
        subset: str,
        exp_name: str,
        page: int = 1,
        page_size: int = 20,
        models: Optional[list[str]] = None,
        result_filter: Optional[str] = None,  # "wins", "losses", "ties"
        consistency_filter: Optional[bool] = None,
        min_images: Optional[int] = None,
        max_images: Optional[int] = None,
        prompt_source: Optional[str] = None,
    ) -> tuple[list[BattleRecord], int]:
        """
        Get paginated battle records with filtering.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            page: Page number (1-indexed)
            page_size: Number of records per page
            models: Filter by models (show battles involving ANY of these models)
            result_filter: Filter by result relative to models ("wins", "losses", "ties")
            consistency_filter: Filter by consistency (True/False/None for all)
            min_images: Minimum input image count (inclusive)
            max_images: Maximum input image count (inclusive)
            prompt_source: Filter by prompt_source value

        Returns:
            Tuple of (records, total_count)
        """
        # Handle "__all__" experiment - combine all experiments
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
            # For __all__, we don't use the model index optimization
            cache_key = None
        else:
            all_records = self._load_battle_logs(subset, exp_name)
            cache_key = (subset, exp_name)

        # Apply filters using index for better performance
        if models and cache_key and cache_key in self._model_index:
            model_set = set(models)
            model_index = self._model_index[cache_key]

            if len(models) == 1:
                # Single model: show battles involving this model
                candidate_indices = set(model_index.get(models[0], []))
                filtered = [all_records[i] for i in sorted(candidate_indices)]
            else:
                # 2+ models: show only battles BETWEEN these models (both participants must be in selected models)
                # Find union of all records involving any selected model first
                candidate_indices: set[int] = set()
                for model in models:
                    if model in model_index:
                        candidate_indices.update(model_index[model])
                # Then filter to keep only battles where BOTH models are in the selected set
                filtered = [
                    all_records[i] for i in sorted(candidate_indices)
                    if all_records[i].model_a in model_set and all_records[i].model_b in model_set
                ]

            # Apply result filter
            if result_filter:
                if len(models) == 1:
                    # Single model: filter by that model's wins/losses/ties
                    model = models[0]
                    if result_filter == "wins":
                        filtered = [r for r in filtered if r.final_winner == model]
                    elif result_filter == "losses":
                        filtered = [
                            r
                            for r in filtered
                            if r.final_winner != "tie" and r.final_winner != model
                        ]
                    elif result_filter == "ties":
                        filtered = [r for r in filtered if r.final_winner == "tie"]
                elif len(models) == 2:
                    # Two models: filter by winner (result_filter is the winning model name or "tie")
                    if result_filter == "ties":
                        filtered = [r for r in filtered if r.final_winner == "tie"]
                    elif result_filter in models:
                        # Filter by specific model winning
                        filtered = [r for r in filtered if r.final_winner == result_filter]
        elif models:
            # Fallback for __all__ mode or when index is not available
            model_set = set(models)
            if len(models) == 1:
                model = models[0]
                filtered = [r for r in all_records if model in r.models]
                # Apply result filter
                if result_filter:
                    if result_filter == "wins":
                        filtered = [r for r in filtered if r.final_winner == model]
                    elif result_filter == "losses":
                        filtered = [
                            r
                            for r in filtered
                            if r.final_winner != "tie" and r.final_winner != model
                        ]
                    elif result_filter == "ties":
                        filtered = [r for r in filtered if r.final_winner == "tie"]
            else:
                # 2+ models: show battles between these models
                filtered = [
                    r for r in all_records
                    if r.model_a in model_set and r.model_b in model_set
                ]
                # Apply result filter
                if result_filter:
                    if result_filter == "ties":
                        filtered = [r for r in filtered if r.final_winner == "tie"]
                    elif result_filter in models:
                        filtered = [r for r in filtered if r.final_winner == result_filter]
        else:
            filtered = all_records

        # Apply consistency filter
        if consistency_filter is not None:
            filtered = [r for r in filtered if r.is_consistent == consistency_filter]

        # Apply input image count filter
        if min_images is not None or max_images is not None:
            min_img = min_images if min_images is not None else 0
            max_img = max_images if max_images is not None else float('inf')
            filtered = [r for r in filtered if min_img <= r.input_image_count <= max_img]

        # Apply prompt_source filter
        if prompt_source:
            filtered = [r for r in filtered if r.prompt_source == prompt_source]

        total_count = len(filtered)

        # Paginate
        start = (page - 1) * page_size
        end = start + page_size
        page_records = filtered[start:end]

        return page_records, total_count
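    # Illustrative call of the pagination/filter API above (not part of the
    # packaged file); subset, experiment, and model names are placeholders:
    #
    #     records, total = loader.get_battles(
    #         subset="example_subset", exp_name="exp_001",
    #         page=1, page_size=20,
    #         models=["model-x"], result_filter="wins",
    #         min_images=2, max_images=4,
    #     )
    #     rows = [r.to_dict() for r in records]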

    def search_battles(
        self,
        subset: str,
        exp_name: str,
        query: str,
        page: int = 1,
        page_size: int = 20,
        models: Optional[list[str]] = None,
        consistency_filter: Optional[bool] = None,
        search_fields: Optional[list[str]] = None,
    ) -> tuple[list[BattleRecord], int]:
        """
        Search battle records by text query (full-text search).

        Searches across instruction, task_type, prompt_source, and original_metadata.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            query: Search query string (case-insensitive)
            page: Page number (1-indexed)
            page_size: Number of records per page
            models: Optional filter by models
            consistency_filter: Optional filter by consistency
            search_fields: Fields to search in (default: all searchable fields)

        Returns:
            Tuple of (matching_records, total_count)
        """
        if not query or not query.strip():
            # Empty query - return regular filtered results
            return self.get_battles(
                subset, exp_name, page, page_size,
                models=models, consistency_filter=consistency_filter
            )

        # Normalize query for case-insensitive search
        query_lower = query.lower().strip()
        # Create regex pattern for more flexible matching
        query_pattern = re.compile(re.escape(query_lower), re.IGNORECASE)

        # Determine which fields to search
        all_searchable_fields = ["instruction", "task_type", "prompt_source", "original_metadata"]
        fields_to_search = search_fields if search_fields else all_searchable_fields

        # Load all records
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Apply model filter first (for efficiency)
        if models:
            model_set = set(models)
            if len(models) == 1:
                all_records = [r for r in all_records if models[0] in r.models]
            else:
                all_records = [
                    r for r in all_records
                    if r.model_a in model_set and r.model_b in model_set
                ]

        # Apply consistency filter
        if consistency_filter is not None:
            all_records = [r for r in all_records if r.is_consistent == consistency_filter]

        # Search filter
        def matches_query(record: BattleRecord) -> bool:
            """Check if record matches the search query."""
            for field_name in fields_to_search:
                value = getattr(record, field_name, None)
                if value is None:
                    continue

                # Handle different field types
                if field_name == "original_metadata" and isinstance(value, dict):
                    # Search in JSON string representation of metadata
                    metadata_str = json.dumps(value, ensure_ascii=False).lower()
                    if query_pattern.search(metadata_str):
                        return True
                elif isinstance(value, str):
                    if query_pattern.search(value):
                        return True

            return False

        # Apply search filter
        filtered = [r for r in all_records if matches_query(r)]

        total_count = len(filtered)

        # Paginate
        start = (page - 1) * page_size
        end = start + page_size
        page_records = filtered[start:end]

        return page_records, total_count

    def search_prompts(
        self,
        subset: str,
        exp_name: str,
        query: str,
        page: int = 1,
        page_size: int = 10,
        filter_models: Optional[list[str]] = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """
        Search prompts/samples by text query.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            query: Search query string
            page: Page number
            page_size: Records per page
            filter_models: Optional filter by models

        Returns:
            Tuple of (matching_prompts, total_count)
        """
        if not query or not query.strip():
            # Empty query - return regular results
            return self.get_prompts(subset, exp_name, page, page_size, filter_models=filter_models)

        # Normalize query
        query_lower = query.lower().strip()
        query_pattern = re.compile(re.escape(query_lower), re.IGNORECASE)

        # Load records and group by sample
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Group by sample_index
        sample_records: dict[int, list[BattleRecord]] = {}
        for record in all_records:
            if record.sample_index not in sample_records:
                sample_records[record.sample_index] = []
            sample_records[record.sample_index].append(record)

        # Filter samples by query
        matching_samples = []
        for sample_index, records in sample_records.items():
            if not records:
                continue

            first_record = records[0]

            # Search in instruction, task_type, prompt_source, original_metadata
            match_found = False

            if first_record.instruction and query_pattern.search(first_record.instruction):
                match_found = True
            elif first_record.task_type and query_pattern.search(first_record.task_type):
                match_found = True
            elif first_record.prompt_source and query_pattern.search(first_record.prompt_source):
                match_found = True
            elif first_record.original_metadata:
                metadata_str = json.dumps(first_record.original_metadata, ensure_ascii=False).lower()
                if query_pattern.search(metadata_str):
                    match_found = True

            if match_found:
                matching_samples.append(sample_index)

        # Sort and paginate
        matching_samples.sort()
        total_count = len(matching_samples)

        start = (page - 1) * page_size
        end = start + page_size
        page_samples = matching_samples[start:end]

        # Build result for each sample using get_sample_all_models
        results = []
        for sample_index in page_samples:
            prompt_data = self.get_sample_all_models(subset, exp_name, sample_index, filter_models)
            results.append(prompt_data)

        return results, total_count

    def get_battle_detail(
        self, subset: str, exp_name: str, model_a: str, model_b: str, sample_index: int
    ) -> Optional[BattleRecord]:
        """
        Get detailed battle record including VLM outputs.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            model_a: First model name
            model_b: Second model name
            sample_index: Sample index

        Returns:
            BattleRecord with audit data, or None
        """
        # Find the battle record
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        record = None
        for r in all_records:
            if (
                r.sample_index == sample_index
                and set([r.model_a, r.model_b]) == set([model_a, model_b])
            ):
                record = r
                break

        if not record:
            return None

        # Load audit data (use the record's actual exp_name for audit log lookup)
        actual_exp_name = record.exp_name
        audit = self._load_audit_log(
            subset, actual_exp_name, record.model_a, record.model_b, sample_index
        )
        if audit:
            record.original_call = audit.get("original_call")
            record.swapped_call = audit.get("swapped_call")

        return record

    def get_image_path(
        self, subset: str, model: str, sample_index: int
    ) -> Optional[str]:
        """
        Get path to model output image.

        Args:
            subset: Subset name
            model: Model name
            sample_index: Sample index

        Returns:
            Image file path or None
        """
        model_manager = self._get_model_manager(subset)
        if model_manager:
            return model_manager.get_output_path(model, sample_index)
        return None

    def get_input_image(self, subset: str, sample_index: int) -> Optional[bytes]:
        """
        Get input image bytes for a sample.

        Uses pyarrow to read directly from parquet for better performance.
        Uses cached file mapping for fast lookup.

        Args:
            subset: Subset name
            sample_index: Sample index

        Returns:
            Image bytes or None
        """
        import pyarrow.parquet as pq

        # Use cached file mapping if available (fast path)
        cache_key = (subset, sample_index)
        if cache_key in self._sample_file_map:
            pf = self._sample_file_map[cache_key]
            result = self._read_image_from_parquet(pf, sample_index)
            if result is not None:
                return result

        # Fallback: search all parquet files (slow path)
        subset_path = os.path.join(self.data_dir, subset)
        if not os.path.isdir(subset_path):
            return None

        parquet_files = sorted([
            os.path.join(subset_path, f)
            for f in os.listdir(subset_path)
            if f.startswith("data-") and f.endswith(".parquet")
        ])

        for pf in parquet_files:
            result = self._read_image_from_parquet(pf, sample_index)
            if result is not None:
                return result

        return None

    def _read_image_from_parquet(self, parquet_file: str, sample_index: int) -> Optional[bytes]:
        """Read a single image from a parquet file."""
        import pyarrow.parquet as pq

        try:
            table = pq.read_table(parquet_file, columns=["index", "input_images"])
            indices = table.column("index").to_pylist()

            if sample_index not in indices:
                return None

            row_idx = indices.index(sample_index)
            input_images = table.column("input_images")[row_idx].as_py()

            if not input_images or len(input_images) == 0:
                return None

            img_data = input_images[0]

            # Handle different formats
            if isinstance(img_data, bytes):
                return img_data
            elif isinstance(img_data, dict):
                # HuggingFace Image format: {"bytes": ..., "path": ...}
                if "bytes" in img_data and img_data["bytes"]:
                    return img_data["bytes"]
                elif "path" in img_data and img_data["path"]:
                    path = img_data["path"]
                    if os.path.isfile(path):
                        with open(path, "rb") as f:
                            return f.read()

        except Exception as e:
            logger.debug(f"Error reading image from {parquet_file}: {e}")

        return None

    def get_input_image_count(self, subset: str, sample_index: int) -> int:
        """Get the number of input images for a sample."""
        import pyarrow.parquet as pq

        cache_key = (subset, sample_index)
        if cache_key in self._sample_file_map:
            pf = self._sample_file_map[cache_key]
            try:
                table = pq.read_table(pf, columns=["index", "input_images"])
                indices = table.column("index").to_pylist()
                if sample_index in indices:
                    row_idx = indices.index(sample_index)
                    input_images = table.column("input_images")[row_idx].as_py()
                    return len(input_images) if input_images else 0
            except Exception:
                pass
        return 1  # Default to 1

    def get_input_image_by_idx(self, subset: str, sample_index: int, img_idx: int = 0) -> Optional[bytes]:
        """Get a specific input image by index."""
        import pyarrow.parquet as pq

        cache_key = (subset, sample_index)
        if cache_key not in self._sample_file_map:
            return None

        pf = self._sample_file_map[cache_key]
        try:
            table = pq.read_table(pf, columns=["index", "input_images"])
            indices = table.column("index").to_pylist()

            if sample_index not in indices:
                return None

            row_idx = indices.index(sample_index)
            input_images = table.column("input_images")[row_idx].as_py()

            if not input_images or img_idx >= len(input_images):
                return None

            img_data = input_images[img_idx]

            if isinstance(img_data, bytes):
                return img_data
            elif isinstance(img_data, dict):
                if "bytes" in img_data and img_data["bytes"]:
                    return img_data["bytes"]
                elif "path" in img_data and img_data["path"]:
                    path = img_data["path"]
                    if os.path.isfile(path):
                        with open(path, "rb") as f:
                            return f.read()
        except Exception as e:
            logger.debug(f"Error reading image: {e}")

        return None

    def get_head_to_head(
        self, subset: str, exp_name: str, model_a: str, model_b: str
    ) -> dict[str, Any]:
        """
        Get head-to-head statistics between two models.

        Returns:
            Dict with wins_a, wins_b, ties, total, win_rate_a, win_rate_b
        """
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
            # For __all__, we need to filter manually
            h2h_records = [
                r for r in all_records
                if set([r.model_a, r.model_b]) == set([model_a, model_b])
            ]
        else:
            all_records = self._load_battle_logs(subset, exp_name)
            cache_key = (subset, exp_name)
            model_index = self._model_index.get(cache_key, {})

            # Find battles between these two models
            indices_a = set(model_index.get(model_a, []))
            indices_b = set(model_index.get(model_b, []))
            h2h_indices = indices_a & indices_b
            h2h_records = [all_records[idx] for idx in h2h_indices]

        wins_a = 0
        wins_b = 0
        ties = 0

        for record in h2h_records:
            if record.final_winner == model_a:
                wins_a += 1
            elif record.final_winner == model_b:
                wins_b += 1
            else:
                ties += 1

        total = wins_a + wins_b + ties

        return {
            "model_a": model_a,
            "model_b": model_b,
            "wins_a": wins_a,
            "wins_b": wins_b,
            "ties": ties,
            "total": total,
            "win_rate_a": wins_a / total if total > 0 else 0,
            "win_rate_b": wins_b / total if total > 0 else 0,
            "tie_rate": ties / total if total > 0 else 0,
        }
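    # Quick arithmetic check of the rates above, with hypothetical counts:
    # wins_a=6, wins_b=3, ties=1 gives total=10, so win_rate_a=0.6,
    # win_rate_b=0.3 and tie_rate=0.1 (the three rates sum to 1 whenever total > 0).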

    def get_win_rate_matrix(
        self,
        subset: str,
        exp_name: str = "__all__",
        filter_models: Optional[list[str]] = None,
    ) -> dict[str, Any]:
        """
        Compute win rate matrix for all model pairs.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            filter_models: Optional list of models to include

        Returns:
            Dict with:
            - models: List of model names (sorted by ELO)
            - matrix: 2D array where matrix[i][j] = win rate of model i vs model j
            - counts: 2D array where counts[i][j] = number of battles between i and j
            - wins: 2D array where wins[i][j] = wins of model i vs model j
        """
        # Load all records
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Determine models to include
        info = self.get_subset_info(subset)
        if filter_models:
            models = [m for m in filter_models if m in info.models]
        else:
            models = list(info.models)

        # Get ELO leaderboard to sort models by ELO
        leaderboard = self.get_elo_leaderboard(subset, models)
        models = [entry["model"] for entry in leaderboard]

        n = len(models)
        model_to_idx = {m: i for i, m in enumerate(models)}

        # Initialize matrices
        wins_matrix = [[0] * n for _ in range(n)]
        counts_matrix = [[0] * n for _ in range(n)]

        # Count wins for each pair
        model_set = set(models)
        for record in all_records:
            if record.model_a not in model_set or record.model_b not in model_set:
                continue

            i = model_to_idx[record.model_a]
            j = model_to_idx[record.model_b]

            # Count total battles (symmetric)
            counts_matrix[i][j] += 1
            counts_matrix[j][i] += 1

            # Count wins
            if record.final_winner == record.model_a:
                wins_matrix[i][j] += 1
            elif record.final_winner == record.model_b:
                wins_matrix[j][i] += 1
            else:
                # Tie counts as 0.5 win for each
                wins_matrix[i][j] += 0.5
                wins_matrix[j][i] += 0.5

        # Compute win rate matrix
        win_rate_matrix = [[0.0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                if counts_matrix[i][j] > 0:
                    win_rate_matrix[i][j] = wins_matrix[i][j] / counts_matrix[i][j]
                elif i == j:
                    win_rate_matrix[i][j] = 0.5  # Self vs self

        return {
            "models": models,
            "matrix": win_rate_matrix,
            "counts": counts_matrix,
            "wins": wins_matrix,
        }
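    # Worked example of the tie credit above (hypothetical pair): with 4 wins for
    # model i, 2 wins for model j, and 2 ties over 8 meetings, wins[i][j] = 4 + 0.5*2 = 5.0
    # and wins[j][i] = 2 + 0.5*2 = 3.0, so matrix[i][j] = 5/8 = 0.625 and
    # matrix[j][i] = 3/8 = 0.375 (the two cells of a pair sum to 1).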

    def get_elo_by_source(
        self,
        subset: str,
        exp_name: str = "__all__",
    ) -> dict[str, Any]:
        """
        Compute ELO rankings grouped by prompt_source.

        Args:
            subset: Subset name
            exp_name: Experiment name

        Returns:
            Dict with:
            - sources: List of source names
            - leaderboards: Dict mapping source -> list of model ELO entries
            - sample_counts: Dict mapping source -> number of samples
            - battle_counts: Dict mapping source -> number of battles
        """
        from genarena.bt_elo import compute_bt_elo_ratings

        # Load all records
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Group battles by prompt_source
        battles_by_source: dict[str, list[tuple[str, str, str]]] = {}
        sample_counts: dict[str, set[int]] = {}

        for record in all_records:
            source = record.prompt_source or "unknown"
            if source not in battles_by_source:
                battles_by_source[source] = []
                sample_counts[source] = set()

            # Convert winner to bt_elo format
            if record.final_winner == record.model_a:
                winner = "model_a"
            elif record.final_winner == record.model_b:
                winner = "model_b"
            else:
                winner = "tie"

            battles_by_source[source].append((record.model_a, record.model_b, winner))
            sample_counts[source].add(record.sample_index)

        # Compute ELO for each source
        leaderboards: dict[str, list[dict[str, Any]]] = {}
        battle_counts: dict[str, int] = {}

        for source, battles in battles_by_source.items():
            if not battles:
                continue

            battle_counts[source] = len(battles)

            try:
                ratings = compute_bt_elo_ratings(battles)

                # Build leaderboard
                entries = []
                for model, elo in ratings.items():
                    # Count wins/losses/ties for this model in this source
                    wins = losses = ties = 0
                    for ma, mb, w in battles:
                        if model == ma:
                            if w == "model_a":
                                wins += 1
                            elif w == "model_b":
                                losses += 1
                            else:
                                ties += 1
                        elif model == mb:
                            if w == "model_b":
                                wins += 1
                            elif w == "model_a":
                                losses += 1
                            else:
                                ties += 1

                    total = wins + losses + ties
                    entries.append({
                        "model": model,
                        "elo": round(elo, 1),
                        "wins": wins,
                        "losses": losses,
                        "ties": ties,
                        "total": total,
                        "win_rate": (wins + 0.5 * ties) / total if total > 0 else 0,
                    })

                # Sort by ELO descending
                entries.sort(key=lambda x: -x["elo"])
                leaderboards[source] = entries

            except Exception as e:
                logger.warning(f"Failed to compute ELO for source {source}: {e}")
                continue

        # Sort sources by battle count
        sources = sorted(battle_counts.keys(), key=lambda s: -battle_counts[s])

        return {
            "sources": sources,
            "leaderboards": leaderboards,
            "sample_counts": {s: len(sample_counts[s]) for s in sources},
            "battle_counts": battle_counts,
        }

    def _load_elo_snapshot(self, snapshot_path: str) -> Optional[dict[str, Any]]:
        """
        Load ELO snapshot from a JSON file.

        Args:
            snapshot_path: Path to elo_snapshot.json

        Returns:
            Dict with elo ratings and metadata, or None if not found
        """
        if not os.path.isfile(snapshot_path):
            return None

        try:
            with open(snapshot_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            if not isinstance(data, dict):
                return None

            # Extract ELO ratings (support both {"elo": {...}} and direct {model: elo} format)
            elo_data = data.get("elo") if isinstance(data.get("elo"), dict) else data
            if not isinstance(elo_data, dict):
                return None

            return {
                "elo": {str(k): float(v) for k, v in elo_data.items()},
                "battle_count": data.get("battle_count", 0),
                "model_count": data.get("model_count", len(elo_data)),
                "exp_name": data.get("exp_name", ""),
            }
        except Exception as e:
            logger.debug(f"Failed to load ELO snapshot from {snapshot_path}: {e}")
            return None
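    # Two hypothetical elo_snapshot.json payloads (model names and ratings invented)
    # that _load_elo_snapshot accepts and normalizes to the same structure:
    #
    #     {"elo": {"model-x": 1042.7, "model-y": 957.3},
    #      "battle_count": 180, "model_count": 2, "exp_name": "exp_001"}
    #
    #     {"model-x": 1042.7, "model-y": 957.3}   # flat layout, also accepted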
|
|
1394
|
+
|
|
1395
|
+
    def get_elo_history(
        self,
        subset: str,
        exp_name: str = "__all__",
        granularity: str = "experiment",
        filter_models: Optional[list[str]] = None,
        max_points: int = 50,
    ) -> dict[str, Any]:
        """
        Get ELO history over experiments by reading pre-computed elo_snapshot.json files.

        Args:
            subset: Subset name
            exp_name: Experiment name (only "__all__" or "experiment" granularity supported)
            granularity: Grouping method ("experiment" reads from snapshots; time-based not supported)
            filter_models: Optional models to track
            max_points: Maximum number of time points to return

        Returns:
            Dict with:
            - timestamps: List of experiment names
            - models: Dict mapping model -> list of ELO values
            - battle_counts: List of cumulative battle counts
        """
        # Get subset info for experiment order
        info = self.get_subset_info(subset)
        if not info:
            return {"timestamps": [], "models": {}, "battle_counts": []}

        # Only support experiment-level granularity (reading from snapshots)
        # Time-based granularity would require real-time computation which we want to avoid
        if granularity != "experiment":
            logger.warning(
                f"Time-based granularity '{granularity}' is not supported for ELO history. "
                f"Falling back to 'experiment' granularity."
            )

        # Get ordered list of experiments
        experiments = info.experiments
        if not experiments:
            return {"timestamps": [], "models": {}, "battle_counts": []}

        # If too many experiments, sample them
        if len(experiments) > max_points:
            step = len(experiments) // max_points
            sampled = [experiments[i] for i in range(0, len(experiments), step)]
            if sampled[-1] != experiments[-1]:
                sampled.append(experiments[-1])
            experiments = sampled

        # Load ELO snapshots for each experiment
        timestamps: list[str] = []
        model_elos: dict[str, list[Optional[float]]] = {}
        battle_counts: list[int] = []

        pk_logs_dir = os.path.join(self.arena_dir, subset, "pk_logs")

        for exp in experiments:
            snapshot_path = os.path.join(pk_logs_dir, exp, "elo_snapshot.json")
            snapshot = self._load_elo_snapshot(snapshot_path)

            if snapshot is None:
                # Skip experiments without snapshots
                continue

            elo_ratings = snapshot["elo"]
            battle_count = snapshot["battle_count"]

            timestamps.append(exp)
            battle_counts.append(battle_count)

            # Update model ELOs
            all_models_so_far = set(model_elos.keys()) | set(elo_ratings.keys())
            for model in all_models_so_far:
                if model not in model_elos:
                    # New model: fill with None for previous timestamps
                    model_elos[model] = [None] * (len(timestamps) - 1)
                model_elos[model].append(elo_ratings.get(model))

            # Ensure all models have the same length
            for model in model_elos:
                if len(model_elos[model]) < len(timestamps):
                    model_elos[model].append(None)

        # Filter to requested models if specified
        if filter_models:
            filter_set = set(filter_models)
            model_elos = {m: v for m, v in model_elos.items() if m in filter_set}

        return {
            "timestamps": timestamps,
            "models": model_elos,
            "battle_counts": battle_counts,
        }

    def get_cross_subset_info(
        self,
        subsets: list[str],
    ) -> dict[str, Any]:
        """
        Get information about models across multiple subsets.

        Args:
            subsets: List of subset names

        Returns:
            Dict with:
            - common_models: Models present in all subsets
            - all_models: Models present in any subset
            - per_subset_models: Dict mapping subset -> list of models
            - per_subset_battles: Dict mapping subset -> battle count
        """
        per_subset_models: dict[str, set[str]] = {}
        per_subset_battles: dict[str, int] = {}

        for subset in subsets:
            info = self.get_subset_info(subset)
            if info:
                per_subset_models[subset] = set(info.models)
                per_subset_battles[subset] = info.total_battles

        if not per_subset_models:
            return {
                "common_models": [],
                "all_models": [],
                "per_subset_models": {},
                "per_subset_battles": {},
            }

        # Compute intersection and union
        all_model_sets = list(per_subset_models.values())
        common_models = set.intersection(*all_model_sets) if all_model_sets else set()
        all_models = set.union(*all_model_sets) if all_model_sets else set()

        return {
            "common_models": sorted(common_models),
            "all_models": sorted(all_models),
            "per_subset_models": {s: sorted(m) for s, m in per_subset_models.items()},
            "per_subset_battles": per_subset_battles,
            "total_battles": sum(per_subset_battles.values()),
        }

    def get_cross_subset_elo(
        self,
        subsets: list[str],
        exp_name: str = "__all__",
        model_scope: str = "all",
    ) -> dict[str, Any]:
        """
        Compute ELO rankings across multiple subsets.

        Args:
            subsets: List of subset names
            exp_name: Experiment name (use "__all__" for all)
            model_scope: "common" = only models in all subsets, "all" = all models

        Returns:
            Dict with merged leaderboard and per-subset comparison
        """
        # Check cache first
        cache_key = (tuple(sorted(subsets)), exp_name, model_scope)
        if cache_key in self._cross_subset_elo_cache:
            return self._cross_subset_elo_cache[cache_key]

        from genarena.bt_elo import compute_bt_elo_ratings

        # Get cross-subset info
        cross_info = self.get_cross_subset_info(subsets)

        # Determine models to include
        if model_scope == "common":
            included_models = set(cross_info["common_models"])
        else:
            included_models = set(cross_info["all_models"])

        if not included_models:
            return {
                "subsets": subsets,
                "model_scope": model_scope,
                "common_models": cross_info["common_models"],
                "all_models": cross_info["all_models"],
                "total_battles": 0,
                "leaderboard": [],
                "per_subset_elo": {},
            }

        # Collect all battles
        all_battles = []
        model_presence: dict[str, set[str]] = {}  # model -> set of subsets it's in

        for subset in subsets:
            if exp_name == "__all__":
                records = self._load_all_experiments_battles(subset)
            else:
                records = self._load_battle_logs(subset, exp_name)

            for record in records:
                # Skip if either model is not in included set
                if model_scope == "common":
                    if record.model_a not in included_models or record.model_b not in included_models:
                        continue

                # Convert to bt_elo format
                if record.final_winner == record.model_a:
                    winner = "model_a"
                elif record.final_winner == record.model_b:
                    winner = "model_b"
                else:
                    winner = "tie"

                all_battles.append((record.model_a, record.model_b, winner))

                # Track model presence
                for m in [record.model_a, record.model_b]:
                    if m not in model_presence:
                        model_presence[m] = set()
                    model_presence[m].add(subset)

        if not all_battles:
            return {
                "subsets": subsets,
                "model_scope": model_scope,
                "common_models": cross_info["common_models"],
                "all_models": cross_info["all_models"],
                "total_battles": 0,
                "leaderboard": [],
                "per_subset_elo": {},
            }

        # Compute merged ELO
        try:
            ratings = compute_bt_elo_ratings(all_battles)
        except Exception as e:
            logger.error(f"Failed to compute cross-subset ELO: {e}")
            return {
                "subsets": subsets,
                "model_scope": model_scope,
                "error": str(e),
                "total_battles": len(all_battles),
                "leaderboard": [],
            }

        # Count wins/losses/ties per model
        model_stats: dict[str, dict[str, int]] = {}
        for ma, mb, winner in all_battles:
            for m in [ma, mb]:
                if m not in model_stats:
                    model_stats[m] = {"wins": 0, "losses": 0, "ties": 0}

            if winner == "model_a":
                model_stats[ma]["wins"] += 1
                model_stats[mb]["losses"] += 1
            elif winner == "model_b":
                model_stats[mb]["wins"] += 1
                model_stats[ma]["losses"] += 1
            else:
                model_stats[ma]["ties"] += 1
                model_stats[mb]["ties"] += 1

        # Build leaderboard
        leaderboard = []
        for model, elo in ratings.items():
            stats = model_stats.get(model, {"wins": 0, "losses": 0, "ties": 0})
            total = stats["wins"] + stats["losses"] + stats["ties"]
            leaderboard.append({
                "model": model,
                "elo": round(elo, 1),
                "wins": stats["wins"],
                "losses": stats["losses"],
                "ties": stats["ties"],
                "total": total,
                "win_rate": (stats["wins"] + 0.5 * stats["ties"]) / total if total > 0 else 0,
                "subset_presence": sorted(model_presence.get(model, set())),
            })

        leaderboard.sort(key=lambda x: -x["elo"])

        # Get per-subset ELO for comparison
        per_subset_elo: dict[str, dict[str, float]] = {}
        for subset in subsets:
            subset_lb = self.get_elo_leaderboard(subset)
            per_subset_elo[subset] = {entry["model"]: entry["elo"] for entry in subset_lb}

        result = {
            "subsets": subsets,
            "model_scope": model_scope,
            "common_models": cross_info["common_models"],
            "all_models": cross_info["all_models"],
            "total_battles": len(all_battles),
            "leaderboard": leaderboard,
            "per_subset_elo": per_subset_elo,
        }

        # Cache the result
        self._cross_subset_elo_cache[cache_key] = result
        return result

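    # Illustrative shape of the battle list passed to compute_bt_elo_ratings above
    # (model names are hypothetical): each entry is (model_a, model_b, winner), where
    # winner is "model_a", "model_b", or "tie", e.g. ("model-x", "model-y", "model_a").
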
    def get_stats(self, subset: str, exp_name: Optional[str] = None) -> dict[str, Any]:
        """
        Get statistics for a subset.

        Args:
            subset: Subset name
            exp_name: Optional experiment name (if None, uses overall state; "__all__" for all experiments)

        Returns:
            Statistics dictionary
        """
        info = self.get_subset_info(subset)
        if not info:
            return {}

        if exp_name == "__all__":
            # Combine stats from all experiments
            records = self._load_all_experiments_battles(subset)
            total_battles = len(records)
            consistent = sum(1 for r in records if r.is_consistent)
            ties = sum(1 for r in records if r.final_winner == "tie")
        elif exp_name:
            records = self._load_battle_logs(subset, exp_name)
            total_battles = len(records)
            consistent = sum(1 for r in records if r.is_consistent)
            ties = sum(1 for r in records if r.final_winner == "tie")
        else:
            total_battles = info.total_battles
            consistent = 0
            ties = 0

        return {
            "subset": subset,
            "models": info.models,
            "experiments": info.experiments,
            "total_battles": total_battles,
            "consistent_battles": consistent,
            "tie_battles": ties,
            "consistency_rate": consistent / total_battles if total_battles > 0 else 0,
        }

    def get_model_win_stats(
        self, subset: str, exp_name: str, sample_index: int,
        filter_models: Optional[list[str]] = None
    ) -> dict[str, dict[str, Any]]:
        """
        Get win/loss statistics for all models on a specific sample.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            sample_index: Sample index
            filter_models: Optional list of models to filter (only count battles between these models)

        Returns:
            Dict mapping model name to stats (wins, losses, ties, total, win_rate)
        """
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Filter records for this sample
        sample_records = [r for r in all_records if r.sample_index == sample_index]

        # If filter_models is specified, only count battles between those models
        if filter_models:
            filter_set = set(filter_models)
            sample_records = [
                r for r in sample_records
                if r.model_a in filter_set and r.model_b in filter_set
            ]

        # Collect stats per model
        model_stats: dict[str, dict[str, int]] = {}

        for record in sample_records:
            for model in [record.model_a, record.model_b]:
                if model not in model_stats:
                    model_stats[model] = {"wins": 0, "losses": 0, "ties": 0}

            if record.final_winner == "tie":
                model_stats[record.model_a]["ties"] += 1
                model_stats[record.model_b]["ties"] += 1
            elif record.final_winner == record.model_a:
                model_stats[record.model_a]["wins"] += 1
                model_stats[record.model_b]["losses"] += 1
            elif record.final_winner == record.model_b:
                model_stats[record.model_b]["wins"] += 1
                model_stats[record.model_a]["losses"] += 1

        # Calculate win rate and total
        result: dict[str, dict[str, Any]] = {}
        for model, stats in model_stats.items():
            total = stats["wins"] + stats["losses"] + stats["ties"]
            win_rate = stats["wins"] / total if total > 0 else 0
            result[model] = {
                "wins": stats["wins"],
                "losses": stats["losses"],
                "ties": stats["ties"],
                "total": total,
                "win_rate": win_rate,
            }

        return result

    def get_sample_all_models(
        self, subset: str, exp_name: str, sample_index: int,
        filter_models: Optional[list[str]] = None,
        stats_scope: str = "filtered"
    ) -> dict[str, Any]:
        """
        Get all model outputs for a specific sample, sorted by win rate.

        Args:
            subset: Subset name
            exp_name: Experiment name
            sample_index: Sample index
            filter_models: Optional list of models to filter (show only these models)
            stats_scope: 'filtered' = only count battles between filtered models,
                'all' = count all battles (but show only filtered models)

        Returns:
            Dict with sample info and all model outputs sorted by win rate
        """
        # Get sample metadata
        sample_meta = self._get_sample_data(subset, sample_index)

        # Determine which models to use for stats calculation
        # If stats_scope is 'all', don't filter battles by models
        stats_filter = filter_models if stats_scope == "filtered" else None
        model_stats = self.get_model_win_stats(subset, exp_name, sample_index, stats_filter)

        # Get all models that have outputs
        model_manager = self._get_model_manager(subset)
        available_models = []

        if model_manager:
            # Determine which models to include
            models_to_check = model_manager.models
            if filter_models:
                filter_set = set(filter_models)
                models_to_check = [m for m in models_to_check if m in filter_set]

            for model in models_to_check:
                output_path = model_manager.get_output_path(model, sample_index)
                if output_path and os.path.isfile(output_path):
                    stats = model_stats.get(model, {
                        "wins": 0, "losses": 0, "ties": 0, "total": 0, "win_rate": 0
                    })
                    available_models.append({
                        "model": model,
                        "wins": stats["wins"],
                        "losses": stats["losses"],
                        "ties": stats["ties"],
                        "total": stats["total"],
                        "win_rate": stats["win_rate"],
                    })

        # Sort by win rate (descending), then by wins (descending), then by model name
        available_models.sort(key=lambda x: (-x["win_rate"], -x["wins"], x["model"]))

        return {
            "subset": subset,
            "exp_name": exp_name,
            "sample_index": sample_index,
            "instruction": sample_meta.get("instruction", ""),
            "task_type": sample_meta.get("task_type", ""),
            "input_image_count": sample_meta.get("input_image_count", 1),
            "prompt_source": sample_meta.get("prompt_source"),
            "original_metadata": sample_meta.get("original_metadata"),
            "models": available_models,
        }

    def get_model_battles_for_sample(
        self,
        subset: str,
        exp_name: str,
        sample_index: int,
        model: str,
        opponent_models: Optional[list[str]] = None,
    ) -> dict[str, Any]:
        """
        Get all battle records for a specific model on a specific sample.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            sample_index: Sample index
            model: The model to get battles for
            opponent_models: Optional list of opponent models to filter by

        Returns:
            Dict with model info and list of battle records
        """
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Filter records for this sample and involving this model
        model_battles = []
        all_opponents = set()

        for record in all_records:
            if record.sample_index != sample_index:
                continue
            if model not in [record.model_a, record.model_b]:
                continue

            # Determine opponent
            opponent = record.model_b if record.model_a == model else record.model_a
            all_opponents.add(opponent)

            # Apply opponent filter if specified
            if opponent_models and opponent not in opponent_models:
                continue

            # Determine result for this model
            if record.final_winner == "tie":
                result = "tie"
            elif record.final_winner == model:
                result = "win"
            else:
                result = "loss"

            # Build battle data with judge outputs
            battle_data = {
                "opponent": opponent,
                "result": result,
                "is_consistent": record.is_consistent,
                "model_a": record.model_a,
                "model_b": record.model_b,
                "final_winner": record.final_winner,
                "exp_name": record.exp_name,
            }

            # Load audit logs if not already loaded on the record
            if not record.original_call and not record.swapped_call:
                actual_exp_name = record.exp_name
                audit = self._load_audit_log(
                    subset, actual_exp_name, record.model_a, record.model_b, sample_index
                )
                if audit:
                    battle_data["original_call"] = audit.get("original_call")
                    battle_data["swapped_call"] = audit.get("swapped_call")
            else:
                # Use existing data if available
                if record.original_call:
                    battle_data["original_call"] = record.original_call
                if record.swapped_call:
                    battle_data["swapped_call"] = record.swapped_call

            model_battles.append(battle_data)

        # Sort battles by opponent name
        model_battles.sort(key=lambda x: x["opponent"])

        # Get model stats
        model_stats = self.get_model_win_stats(subset, exp_name, sample_index)
        stats = model_stats.get(model, {
            "wins": 0, "losses": 0, "ties": 0, "total": 0, "win_rate": 0
        })

        return {
            "model": model,
            "sample_index": sample_index,
            "wins": stats["wins"],
            "losses": stats["losses"],
            "ties": stats["ties"],
            "total": stats["total"],
            "win_rate": stats["win_rate"],
            "battles": model_battles,
            "all_opponents": sorted(list(all_opponents)),
        }

    def get_elo_leaderboard(
        self,
        subset: str,
        filter_models: Optional[list[str]] = None,
    ) -> list[dict[str, Any]]:
        """
        Get ELO leaderboard for a subset from state.json.

        Args:
            subset: Subset name
            filter_models: Optional list of models to filter (show only these models)

        Returns:
            List of model stats sorted by ELO rating (descending)
        """
        info = self.get_subset_info(subset)
        if not info or not info.state:
            return []

        state = info.state
        leaderboard = []

        for model_name, model_stats in state.models.items():
            # Apply filter if specified
            if filter_models and model_name not in filter_models:
                continue

            leaderboard.append({
                "model": model_name,
                "elo": model_stats.elo,
                "wins": model_stats.wins,
                "losses": model_stats.losses,
                "ties": model_stats.ties,
                "total_battles": model_stats.total_battles,
                "win_rate": model_stats.win_rate,
            })

        # Sort by ELO rating (descending)
        leaderboard.sort(key=lambda x: -x["elo"])

        # Add rank
        for i, entry in enumerate(leaderboard):
            entry["rank"] = i + 1

        return leaderboard

    def get_model_vs_stats(
        self,
        subset: str,
        model: str,
        exp_name: str = "__all__",
    ) -> dict[str, Any]:
        """
        Get win/loss/tie stats of a specific model against all other models.

        Args:
            subset: Subset name
            model: Target model name
            exp_name: Experiment name (default "__all__" for all experiments)

        Returns:
            Dict with model stats and versus stats against each opponent
        """
        # Get overall ELO stats
        info = self.get_subset_info(subset)
        if not info or not info.state:
            return {}

        state = info.state
        if model not in state.models:
            return {}

        model_stats = state.models[model]

        # Load battle records
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Calculate stats against each opponent
        vs_stats: dict[str, dict[str, int]] = {}

        for record in all_records:
            if model not in [record.model_a, record.model_b]:
                continue

            opponent = record.model_b if record.model_a == model else record.model_a

            if opponent not in vs_stats:
                vs_stats[opponent] = {"wins": 0, "losses": 0, "ties": 0}

            if record.final_winner == "tie":
                vs_stats[opponent]["ties"] += 1
            elif record.final_winner == model:
                vs_stats[opponent]["wins"] += 1
            else:
                vs_stats[opponent]["losses"] += 1

        # Convert to list with win rates and opponent ELO
        vs_list = []
        for opponent, stats in vs_stats.items():
            total = stats["wins"] + stats["losses"] + stats["ties"]
            opponent_elo = state.models[opponent].elo if opponent in state.models else 1000.0
            vs_list.append({
                "opponent": opponent,
                "opponent_elo": opponent_elo,
                "wins": stats["wins"],
                "losses": stats["losses"],
                "ties": stats["ties"],
                "total": total,
                "win_rate": stats["wins"] / total if total > 0 else 0,
            })

        # Sort by opponent ELO (descending)
        vs_list.sort(key=lambda x: -x["opponent_elo"])

        return {
            "model": model,
            "elo": model_stats.elo,
            "wins": model_stats.wins,
            "losses": model_stats.losses,
            "ties": model_stats.ties,
            "total_battles": model_stats.total_battles,
            "win_rate": model_stats.win_rate,
            "vs_stats": vs_list,
        }

    def get_all_subsets_leaderboards(self) -> dict[str, Any]:
        """
        Get leaderboard data for all subsets (for Overview page).

        Returns:
            Dict with:
            - subsets: List of subset names
            - models: List of all unique model names across all subsets
            - data: Dict mapping subset -> {model -> {elo, rank, wins, losses, ties, ...}}
            - subset_info: Dict mapping subset -> {total_battles, model_count}
        """
        subsets = self.discover_subsets()
        all_models: set[str] = set()
        data: dict[str, dict[str, dict[str, Any]]] = {}
        subset_info: dict[str, dict[str, Any]] = {}

        for subset in subsets:
            leaderboard = self.get_elo_leaderboard(subset)
            info = self.get_subset_info(subset)

            if not leaderboard:
                continue

            # Build subset data
            subset_data: dict[str, dict[str, Any]] = {}
            for entry in leaderboard:
                model = entry["model"]
                all_models.add(model)
                subset_data[model] = {
                    "elo": entry["elo"],
                    "rank": entry["rank"],
                    "wins": entry["wins"],
                    "losses": entry["losses"],
                    "ties": entry["ties"],
                    "total_battles": entry["total_battles"],
                    "win_rate": entry["win_rate"],
                }

            data[subset] = subset_data
            subset_info[subset] = {
                "total_battles": info.total_battles if info else 0,
                "model_count": len(leaderboard),
            }

        # Sort models by average ELO across all subsets (descending)
        model_avg_elo: dict[str, tuple[float, int]] = {}  # model -> (sum_elo, count)
        for model in all_models:
            total_elo = 0.0
            count = 0
            for subset in subsets:
                if subset in data and model in data[subset]:
                    total_elo += data[subset][model]["elo"]
                    count += 1
            if count > 0:
                model_avg_elo[model] = (total_elo / count, count)
            else:
                model_avg_elo[model] = (0.0, 0)

        sorted_models = sorted(
            all_models,
            key=lambda m: (-model_avg_elo[m][0], -model_avg_elo[m][1], m)
        )

        return {
            "subsets": subsets,
            "models": sorted_models,
            "data": data,
            "subset_info": subset_info,
        }

    def get_prompts(
        self,
        subset: str,
        exp_name: str,
        page: int = 1,
        page_size: int = 10,
        min_images: Optional[int] = None,
        max_images: Optional[int] = None,
        prompt_source: Optional[str] = None,
        filter_models: Optional[list[str]] = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """
        Get paginated list of prompts/samples with all model outputs.

        Args:
            subset: Subset name
            exp_name: Experiment name (use "__all__" for all experiments)
            page: Page number (1-indexed)
            page_size: Number of records per page
            min_images: Minimum number of input images
            max_images: Maximum number of input images
            prompt_source: Filter by prompt source
            filter_models: Optional list of models to filter (show only these models)

        Returns:
            Tuple of (prompts_list, total_count)
        """
        # Get all sample indices from battle logs
        if exp_name == "__all__":
            all_records = self._load_all_experiments_battles(subset)
        else:
            all_records = self._load_battle_logs(subset, exp_name)

        # Collect unique sample indices
        sample_indices = set()
        for record in all_records:
            sample_indices.add(record.sample_index)

        # Sort sample indices
        sorted_indices = sorted(sample_indices)

        # Apply filters
        filtered_indices = []
        for idx in sorted_indices:
            sample_meta = self._get_sample_data(subset, idx)
            img_count = sample_meta.get("input_image_count", 1)
            source = sample_meta.get("prompt_source")

            # Apply image count filter
            if min_images is not None and img_count < min_images:
                continue
            if max_images is not None and img_count > max_images:
                continue

            # Apply prompt source filter
            if prompt_source and source != prompt_source:
                continue

            filtered_indices.append(idx)

        total_count = len(filtered_indices)

        # Paginate
        start = (page - 1) * page_size
        end = start + page_size
        page_indices = filtered_indices[start:end]

        # Build prompt data for each sample
        prompts = []
        for idx in page_indices:
            prompt_data = self.get_sample_all_models(subset, exp_name, idx, filter_models)
            prompts.append(prompt_data)

        return prompts, total_count


class HFArenaDataLoader(ArenaDataLoader):
    """
    Data loader for HuggingFace Spaces deployment.

    Extends ArenaDataLoader to:
    - Build image URL index from HF file list
    - Return HF CDN URLs for model output images instead of local paths
    """

    def __init__(
        self,
        arena_dir: str,
        data_dir: str,
        hf_repo: str,
        image_files: list[str],
        preload: bool = True,
    ):
        """
        Initialize the HF data loader.

        Args:
            arena_dir: Path to arena directory (metadata only, no images)
            data_dir: Path to data directory containing parquet files
            hf_repo: HuggingFace repo ID for image CDN URLs
            image_files: List of image file paths in the HF repo
            preload: If True, preload all data at initialization
        """
        self.hf_repo = hf_repo
        self._image_url_index = self._build_image_index(image_files)
        super().__init__(arena_dir, data_dir, preload=preload)

    def _build_image_index(
        self, image_files: list[str]
    ) -> dict[tuple[str, str, int], str]:
        """
        Build index: (subset, model, sample_index) -> hf_file_path

        Expected path format: {subset}/models/{exp_name}/{model}/{index}.png

        Args:
            image_files: List of image file paths from HF repo

        Returns:
            Dict mapping (subset, model, sample_index) to HF file path
        """
        from genarena.models import parse_image_index

        index: dict[tuple[str, str, int], str] = {}

        for path in image_files:
            parts = path.split("/")
            # Expected: subset/models/exp_name/model/000000.png
            if len(parts) >= 5 and parts[1] == "models":
                subset = parts[0]
                # exp_name = parts[2]  # Not needed for lookup
                model = parts[3]
                filename = parts[4]
                idx = parse_image_index(filename)
                if idx is not None:
                    # If duplicate, later entries overwrite earlier ones
                    index[(subset, model, idx)] = path

        logger.info(f"Built image URL index with {len(index)} entries")
        return index

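    # Illustrative index entry built above (the subset, experiment, and model names are
    # hypothetical), assuming parse_image_index("000042.png") returns 42:
    #
    #   "my_subset/models/exp_001/my-model/000042.png"
    #       -> index[("my_subset", "my-model", 42)]
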
    def get_model_image_url(
        self, subset: str, model: str, sample_index: int
    ) -> Optional[str]:
        """
        Get HF CDN URL for model output image.

        Args:
            subset: Subset name
            model: Model name
            sample_index: Sample index

        Returns:
            HF CDN URL or None if not found
        """
        path = self._image_url_index.get((subset, model, sample_index))
        if path:
            return f"https://huggingface.co/datasets/{self.hf_repo}/resolve/main/{path}"
        return None

    def get_image_path(
        self, subset: str, model: str, sample_index: int
    ) -> Optional[str]:
        """
        Override to return None since images are served via CDN.

        For HF deployment, use get_model_image_url() instead.
        """
        # Return None to indicate image should be fetched via CDN
        return None

    def _get_available_models_for_subset(self, subset: str) -> list[str]:
        """
        Get list of models that have images in the HF CDN for this subset.

        Returns:
            List of model names
        """
        models = set()
        for (s, model, _) in self._image_url_index.keys():
            if s == subset:
                models.add(model)
        return sorted(models)

    def _has_model_image(self, subset: str, model: str, sample_index: int) -> bool:
        """
        Check if a model has an image for a specific sample in the HF CDN.

        Args:
            subset: Subset name
            model: Model name
            sample_index: Sample index

        Returns:
            True if image exists in CDN index
        """
        return (subset, model, sample_index) in self._image_url_index

    def get_sample_all_models(
        self, subset: str, exp_name: str, sample_index: int,
        filter_models: Optional[list[str]] = None,
        stats_scope: str = "filtered"
    ) -> dict[str, Any]:
        """
        Get all model outputs for a specific sample, sorted by win rate.

        Override for HF deployment to use CDN image index instead of local files.

        Args:
            subset: Subset name
            exp_name: Experiment name
            sample_index: Sample index
            filter_models: Optional list of models to filter (show only these models)
            stats_scope: 'filtered' = only count battles between filtered models,
                'all' = count all battles (but show only filtered models)

        Returns:
            Dict with sample info and all model outputs sorted by win rate
        """
        # Get sample metadata
        sample_meta = self._get_sample_data(subset, sample_index)

        # Determine which models to use for stats calculation
        stats_filter = filter_models if stats_scope == "filtered" else None
        model_stats = self.get_model_win_stats(subset, exp_name, sample_index, stats_filter)

        # Get all models that have outputs in CDN
        available_models_list = self._get_available_models_for_subset(subset)

        # Apply filter if specified
        if filter_models:
            filter_set = set(filter_models)
            available_models_list = [m for m in available_models_list if m in filter_set]

        # Build model info for models that have images for this sample
        available_models = []
        for model in available_models_list:
            # Check if model has image for this sample in CDN index
            if self._has_model_image(subset, model, sample_index):
                stats = model_stats.get(model, {
                    "wins": 0, "losses": 0, "ties": 0, "total": 0, "win_rate": 0
                })
                available_models.append({
                    "model": model,
                    "wins": stats["wins"],
                    "losses": stats["losses"],
                    "ties": stats["ties"],
                    "total": stats["total"],
                    "win_rate": stats["win_rate"],
                })

        # Sort by win rate (descending), then by wins (descending), then by model name
        available_models.sort(key=lambda x: (-x["win_rate"], -x["wins"], x["model"]))

        return {
            "subset": subset,
            "exp_name": exp_name,
            "sample_index": sample_index,
            "instruction": sample_meta.get("instruction", ""),
            "task_type": sample_meta.get("task_type", ""),
            "input_image_count": sample_meta.get("input_image_count", 1),
            "prompt_source": sample_meta.get("prompt_source"),
            "original_metadata": sample_meta.get("original_metadata"),
            "models": available_models,
        }