genarena 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
- genarena/__init__.py +49 -2
- genarena/__main__.py +10 -0
- genarena/arena.py +1685 -0
- genarena/battle.py +337 -0
- genarena/bt_elo.py +507 -0
- genarena/cli.py +1581 -0
- genarena/data.py +476 -0
- genarena/deploy/Dockerfile +22 -0
- genarena/deploy/README.md +55 -0
- genarena/deploy/__init__.py +5 -0
- genarena/deploy/app.py +84 -0
- genarena/experiments.py +121 -0
- genarena/leaderboard.py +270 -0
- genarena/logs.py +409 -0
- genarena/models.py +412 -0
- genarena/prompts/__init__.py +127 -0
- genarena/prompts/mmrb2.py +373 -0
- genarena/sampling.py +336 -0
- genarena/state.py +656 -0
- genarena/sync/__init__.py +105 -0
- genarena/sync/auto_commit.py +118 -0
- genarena/sync/deploy_ops.py +543 -0
- genarena/sync/git_ops.py +422 -0
- genarena/sync/hf_ops.py +891 -0
- genarena/sync/init_ops.py +431 -0
- genarena/sync/packer.py +587 -0
- genarena/sync/submit.py +837 -0
- genarena/utils.py +103 -0
- genarena/validation/__init__.py +19 -0
- genarena/validation/schema.py +327 -0
- genarena/validation/validator.py +329 -0
- genarena/visualize/README.md +148 -0
- genarena/visualize/__init__.py +14 -0
- genarena/visualize/app.py +938 -0
- genarena/visualize/data_loader.py +2430 -0
- genarena/visualize/static/app.js +3762 -0
- genarena/visualize/static/model_aliases.json +86 -0
- genarena/visualize/static/style.css +4104 -0
- genarena/visualize/templates/index.html +413 -0
- genarena/vlm.py +519 -0
- genarena-0.1.1.dist-info/METADATA +178 -0
- genarena-0.1.1.dist-info/RECORD +44 -0
- {genarena-0.0.1.dist-info → genarena-0.1.1.dist-info}/WHEEL +1 -2
- genarena-0.1.1.dist-info/entry_points.txt +2 -0
- genarena-0.0.1.dist-info/METADATA +0 -26
- genarena-0.0.1.dist-info/RECORD +0 -5
- genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/arena.py
ADDED
@@ -0,0 +1,1685 @@
# Copyright 2026 Ruihang Li.
# Licensed under the Apache License, Version 2.0.
# See LICENSE file in the project root for details.

"""Arena core coordinator module."""

import itertools
import json
import logging
import os
import random
import threading
import queue as thread_queue
from collections import deque
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Optional, Union

from genarena.battle import BattleResult, execute_battle
from genarena.bt_elo import compute_bootstrap_bt_elo, BattleTuple
from genarena.data import ParquetDataset, discover_subsets
from genarena.experiments import pick_latest_experiment_name, require_valid_exp_name, is_milestone_exp, parse_exp_date_suffix
from genarena.leaderboard import save_leaderboard
from genarena.logs import AuditLogger, BattleLogger, load_battle_history, count_battles_per_pair, load_battle_records
from genarena.models import GlobalModelOutputManager, ModelOutputManager
from genarena.prompts import load_prompt
from genarena.sampling import SamplingConfig, AdaptiveSamplingScheduler
from genarena.state import ArenaState, load_state, rebuild_state_from_logs, save_state, update_stats
from genarena.utils import ensure_dir, get_sorted_model_pair, iso_timestamp
from genarena.vlm import VLMJudge


logger = logging.getLogger(__name__)


@dataclass
class BattlePair:
    """A pair of models and a sample index for a battle."""

    model_a: str
    model_b: str
    sample_index: int


@dataclass
class ArenaConfig:
    """Configuration for an arena run."""

    # Required paths
    arena_dir: str
    data_dir: str
    subset: str

    # Model configuration
    models: Optional[list[str]] = None  # None = all models

    # Experiment configuration
    exp_name: Optional[str] = None  # None = timestamp
    sample_size: Optional[int] = None  # None = all samples (used in full mode)
    num_threads: int = 8
    num_processes: int = 1
    parallel_swap_calls: bool = False
    enable_progress_bar: bool = False

    # Sampling configuration
    sampling: SamplingConfig = field(default_factory=SamplingConfig)

    # VLM configuration
    judge_model: str = "Qwen/Qwen3-VL-32B-Instruct-FP8"
    temperature: float = 0.0
    prompt: str = "mmrb2"
    timeout: int = 120
    max_retries: int = 3

    # Multi-endpoint configuration
    base_urls: Optional[Union[str, list[str]]] = None  # Comma-separated or list
    api_keys: Optional[Union[str, list[str]]] = None  # Comma-separated or list

    # Logging configuration
    enable_audit_log: bool = True
    verbose: bool = False

    # Model removal behavior
    clean_orphaned_logs: bool = True  # Delete battle logs involving removed models

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        # Parse base_urls for logging
        base_urls_list = []
        if self.base_urls:
            if isinstance(self.base_urls, str):
                base_urls_list = [u.strip() for u in self.base_urls.split(",") if u.strip()]
            else:
                base_urls_list = list(self.base_urls)

        # Count api_keys for logging (don't expose actual keys)
        num_api_keys = 0
        if self.api_keys:
            if isinstance(self.api_keys, str):
                num_api_keys = len([k for k in self.api_keys.split(",") if k.strip()])
            else:
                num_api_keys = len(self.api_keys)

        return {
            "arena_dir": self.arena_dir,
            "data_dir": self.data_dir,
            "subset": self.subset,
            "models": self.models,
            "exp_name": self.exp_name,
            "sample_size": self.sample_size,
            "num_threads": self.num_threads,
            "num_processes": self.num_processes,
            "parallel_swap_calls": self.parallel_swap_calls,
            "enable_progress_bar": self.enable_progress_bar,
            "sampling": self.sampling.to_dict(),
            "judge_model": self.judge_model,
            "temperature": self.temperature,
            "prompt": self.prompt,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "base_urls": base_urls_list,
            "num_api_keys": num_api_keys,
            "enable_audit_log": self.enable_audit_log,
            "clean_orphaned_logs": self.clean_orphaned_logs,
            "timestamp": iso_timestamp()
        }
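
# Illustrative usage of the endpoint parsing above (editor's sketch; the
# directory names, URLs, and keys are hypothetical, not genarena defaults):
#
#     config = ArenaConfig(
#         arena_dir="arena", data_dir="data", subset="t2i",
#         base_urls="http://h1:8000/v1, http://h2:8000/v1",
#         api_keys="k1,k2",
#     )
#     d = config.to_dict()
#     d["base_urls"]     # ["http://h1:8000/v1", "http://h2:8000/v1"]
#     d["num_api_keys"]  # 2 -- keys are only counted, never serialized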


def _run_parquet_bucket_worker(
    *,
    arena_dir: str,
    data_dir: str,
    subset: str,
    exp_name: str,
    parquet_work: list[tuple[str, list[int]]],
    models: list[str],
    new_models: list[str],
    num_threads: int,
    judge_model: str,
    temperature: float,
    prompt: str,
    timeout: int,
    max_retries: int,
    base_urls: Optional[Union[str, list[str]]],
    api_keys: Optional[Union[str, list[str]]],
    enable_audit_log: bool,
    parallel_swap_calls: bool,
    progress_queue: Any = None,
) -> dict[str, int]:
    """
    Worker entry point for multiprocessing: execute battles for a bucket of parquet files.

    Notes:
    - Each process initializes its own VLM client/endpoint manager.
    - Results are persisted via jsonl logs (with fcntl locks), so the parent process
      only needs counts for progress reporting.
    """
    # Local imports are avoided here because the module is already imported in workers,
    # but keep this function at module-level so it's picklable by ProcessPoolExecutor.
    subset_dir = os.path.join(arena_dir, subset)
    models_dir = os.path.join(subset_dir, "models")
    pk_logs_dir = os.path.join(subset_dir, "pk_logs")
    exp_dir = os.path.join(pk_logs_dir, exp_name)

    ensure_dir(exp_dir)

    prompt_module = load_prompt(prompt)
    # In v2 layout, models are stored under models/<exp_name>/<model>/...
    # and model names are globally unique across experiments.
    model_manager = GlobalModelOutputManager(models_dir)

    vlm = VLMJudge(
        model=judge_model,
        temperature=temperature,
        timeout=timeout,
        max_retries=max_retries,
        base_urls=base_urls,
        api_keys=api_keys,
        progress=progress_queue,
    )

    battle_logger = BattleLogger(exp_dir)
    audit_logger = AuditLogger(exp_dir) if enable_audit_log else None

    completed_set = load_battle_history(pk_logs_dir)

    class _ProgressBuffer:
        """Batch progress updates to reduce cross-process queue overhead."""

        def __init__(self, q: Any, flush_every: int = 20):
            self._q = q
            self._flush_every = flush_every
            self._buf = 0

        def put(self, n: int) -> None:
            if self._q is None:
                return
            self._buf += int(n)
            if self._buf >= self._flush_every:
                try:
                    self._q.put(self._buf)
                finally:
                    self._buf = 0

        def total(self, n: int) -> None:
            """Increase progress bar total by n (best-effort)."""
            if self._q is None:
                return
            try:
                n_int = int(n)
            except Exception:
                return
            if n_int <= 0:
                return
            try:
                self._q.put(("total", n_int))
            except Exception:
                pass

        def flush(self) -> None:
            if self._q is None:
                return
            if self._buf > 0:
                try:
                    self._q.put(self._buf)
                finally:
                    self._buf = 0

    progress = _ProgressBuffer(progress_queue) if progress_queue is not None else None

    def _execute_one(dataset: ParquetDataset, model_a: str, model_b: str, sample_index: int) -> bool:
        # Skip if already completed (sorted key)
        first, second, _ = get_sorted_model_pair(model_a, model_b)
        if (first, second, sample_index) in completed_set:
            return False

        sample = dataset.get_by_index(sample_index)
        if sample is None:
            return False

        output_a = model_manager.get_output_path(model_a, sample_index)
        output_b = model_manager.get_output_path(model_b, sample_index)
        if output_a is None or output_b is None:
            return False

        result = execute_battle(
            vlm=vlm,
            prompt_module=prompt_module,
            sample=sample,
            model_a_output=output_a,
            model_b_output=output_b,
            model_a=model_a,
            model_b=model_b,
            parallel_swap_calls=parallel_swap_calls,
            progress=progress,
        )

        battle_logger.log_battle_result(result)
        if audit_logger:
            audit_logger.log_battle_result(result)

        return True

    # Build tasks lazily and keep inflight bounded to reduce overhead for large runs.
    completed = 0
    total_attempted = 0
    total_indices = 0

    selected_models = set(models)
    new_models_filtered = [m for m in new_models if m in selected_models]
    if not new_models_filtered:
        return {"completed": 0, "attempted": 0, "indices": 0}

    pair_set: set[tuple[str, str]] = set()
    for m in new_models_filtered:
        for other in selected_models:
            if other == m:
                continue
            a, b, _ = get_sorted_model_pair(m, other)
            pair_set.add((a, b))

    model_pairs = sorted(pair_set)

    if num_threads <= 1:
        for pf, indices in parquet_work:
            if not indices:
                continue
            total_indices += len(indices)
            dataset = ParquetDataset(data_dir, subset, parquet_files=[pf])
            for model_a, model_b in model_pairs:
                valid_indices = model_manager.validate_coverage(model_a, model_b, indices)
                first, second, _ = get_sorted_model_pair(model_a, model_b)
                pending_indices = [idx for idx in valid_indices if (first, second, idx) not in completed_set]
                if progress is not None:
                    # Each battle always makes 2 API calls (original + swapped).
                    progress.total(2 * len(pending_indices))
                for idx in pending_indices:
                    total_attempted += 1
                    if _execute_one(dataset, model_a, model_b, idx):
                        completed += 1
    else:
        max_inflight = max(1, num_threads * 4)
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            inflight = set()

            def _drain_one() -> None:
                nonlocal completed
                done_future = next(as_completed(inflight))
                inflight.remove(done_future)
                try:
                    ok = done_future.result()
                    if ok:
                        completed += 1
                except Exception:
                    # Worker-level robustness: ignore individual battle failures.
                    pass

            for pf, indices in parquet_work:
                if not indices:
                    continue
                total_indices += len(indices)
                dataset = ParquetDataset(data_dir, subset, parquet_files=[pf])
                for model_a, model_b in model_pairs:
                    valid_indices = model_manager.validate_coverage(model_a, model_b, indices)
                    first, second, _ = get_sorted_model_pair(model_a, model_b)
                    pending_indices = [idx for idx in valid_indices if (first, second, idx) not in completed_set]
                    if progress is not None:
                        progress.total(2 * len(pending_indices))
                    for idx in pending_indices:
                        total_attempted += 1
                        inflight.add(executor.submit(_execute_one, dataset, model_a, model_b, idx))
                        if len(inflight) >= max_inflight:
                            _drain_one()

            while inflight:
                _drain_one()

    if progress is not None:
        progress.flush()

    return {
        "completed": completed,
        "attempted": total_attempted,
        "indices": total_indices,
    }
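
# Fan-out sketch for the worker above (editor's example; `buckets` and the
# keyword values are placeholders, not the actual CLI wiring): the parent
# shards work by parquet file and only collects counts, since each worker
# persists its own results to jsonl.
#
#     with ProcessPoolExecutor(max_workers=4) as pool:
#         futures = [
#             pool.submit(
#                 _run_parquet_bucket_worker,
#                 arena_dir="arena", data_dir="data", subset="t2i",
#                 exp_name="exp_20260101", parquet_work=bucket,
#                 models=models, new_models=new_models, num_threads=8,
#                 judge_model="judge", temperature=0.0, prompt="mmrb2",
#                 timeout=120, max_retries=3, base_urls=None, api_keys=None,
#                 enable_audit_log=True, parallel_swap_calls=False,
#             )
#             for bucket in buckets  # e.g. [[("part-0.parquet", [0, 1, 2])], ...]
#         ]
#         done = sum(f.result()["completed"] for f in futures)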


def _start_calls_progress_consumer(
    *,
    enabled: bool,
    total: Optional[int] = None,
) -> tuple[Any, Optional[threading.Thread], Any]:
    """
    Start a progress consumer thread that reads integer increments from a queue.

    Returns:
        (progress_queue, thread, close_fn)
    """
    if not enabled:
        return None, None, lambda: None

    try:
        from tqdm import tqdm  # type: ignore
    except Exception:
        logger.warning("tqdm is not available; progress bar disabled")
        return None, None, lambda: None

    q: thread_queue.Queue[Any] = thread_queue.Queue()
    stop_sentinel = object()
    bar = tqdm(total=total, unit="call", desc="API Calls", dynamic_ncols=True)
    recent: deque[str] = deque(maxlen=10)

    def _run() -> None:
        while True:
            item = q.get()
            if item is stop_sentinel:
                break
            if isinstance(item, (int, float)):
                try:
                    bar.update(int(item))
                except Exception:
                    pass
            elif isinstance(item, tuple) and len(item) == 2 and item[0] == "log":
                try:
                    recent.append(str(item[1]))
                    bar.set_postfix_str(" | ".join(recent))
                except Exception:
                    pass
            elif isinstance(item, tuple) and len(item) == 2 and item[0] == "total":
                try:
                    delta = int(item[1])
                    if delta > 0:
                        bar.total = (bar.total or 0) + delta
                        bar.refresh()
                except Exception:
                    pass
            else:
                # Unknown item type, ignore.
                pass

    t = threading.Thread(target=_run, name="calls-progress-consumer", daemon=True)
    t.start()

    def _close() -> None:
        try:
            q.put(stop_sentinel)
        except Exception:
            pass
        if t is not None:
            t.join(timeout=5)
        try:
            bar.close()
        except Exception:
            pass

    return q, t, _close
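
# Queue protocol handled by _run above (producer-side sketch; the consumer is
# only active when enabled=True and tqdm is installed):
#
#     q, _thread, close_fn = _start_calls_progress_consumer(enabled=True)
#     if q is not None:
#         q.put(("total", 10))        # grow the bar total by 10 calls
#         q.put(2)                    # plain numbers advance the bar
#         q.put(("log", "ep-1 ok"))   # short messages go to the postfix line
#     close_fn()                      # stop the consumer and close the bar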


class Arena:
    """
    Arena coordinator for running pairwise model evaluations.

    Manages:
    - Subset directory structure
    - Model discovery and output management
    - Battle pair generation
    - Checkpoint/resume functionality
    - Parallel battle execution
    - ELO state management
    - Leaderboard generation
    """

    def __init__(self, config: ArenaConfig):
        """
        Initialize the arena.

        Args:
            config: ArenaConfig with all settings
        """
        self.config = config

        # Set up paths
        self.subset_dir = os.path.join(config.arena_dir, config.subset)
        self.models_root_dir = os.path.join(self.subset_dir, "models")
        self.pk_logs_dir = os.path.join(self.subset_dir, "pk_logs")
        # Resolve experiment name (infer from models/ if not provided)
        if config.exp_name is not None:
            require_valid_exp_name(config.exp_name)
        else:
            config.exp_name = pick_latest_experiment_name(self.models_root_dir)

        # In v2 layout, per-experiment model outputs live under: models/<exp_name>/<model>/...
        self.models_dir = os.path.join(self.models_root_dir, config.exp_name)
        if not os.path.isdir(self.models_dir):
            raise ValueError(
                f"Experiment models directory does not exist: {self.models_dir}. "
                f"Expected `models/{config.exp_name}/<model_name>/...`."
            )
        self.exp_dir = os.path.join(self.pk_logs_dir, config.exp_name)
        self.arena_state_dir = os.path.join(self.subset_dir, "arena")
        self.state_path = os.path.join(self.arena_state_dir, "state.json")
        self.leaderboard_path = os.path.join(self.subset_dir, "README.md")

        # Initialize directories
        self._init_directories()

        # Load components
        self.prompt_module = load_prompt(config.prompt)
        # In multiprocessing mode, we only need fast index scanning in the parent
        # process (full data is loaded per-parquet inside workers).
        load_mode = "index_only" if config.num_processes > 1 else "full"
        self.dataset = ParquetDataset(config.data_dir, config.subset, load_mode=load_mode)
        # Global model registry (v2 layout): models/<exp_name>/<model>/...
        self.model_manager = GlobalModelOutputManager(self.models_root_dir)

        # Models that are newly introduced in this experiment (directory listing)
        self.new_models = self.model_manager.get_experiment_models(config.exp_name)

        # Parse experiment date for filtering eligible opponents
        self.exp_date = parse_exp_date_suffix(config.exp_name)

        # Resolve the selected model universe for this run.
        # When running an old experiment, only consider models from experiments
        # with date <= this experiment's date (to avoid battling "future" models).
        if config.models:
            self.models = [m for m in config.models if self.model_manager.has_model(m)]
        elif self.exp_date is not None:
            # Filter to models from experiments up to this experiment's date
            self.models = self.model_manager.get_models_up_to_date(self.exp_date)
        else:
            self.models = self.model_manager.models

        # Canonical "current models on disk" (used for state/log cleanup even when --models is used)
        self.all_models = self.model_manager.models

        # Initialize loggers
        self.battle_logger = BattleLogger(self.exp_dir)
        self.audit_logger = AuditLogger(self.exp_dir) if config.enable_audit_log else None

        # Initialize VLM judge with multi-endpoint support
        self.vlm = VLMJudge(
            model=config.judge_model,
            temperature=config.temperature,
            timeout=config.timeout,
            max_retries=config.max_retries,
            base_urls=config.base_urls,
            api_keys=config.api_keys,
        )

        # Save experiment config
        self._save_config()
        self._progress_queue = None

    def _init_directories(self) -> None:
        """Create necessary directory structure."""
        ensure_dir(self.subset_dir)
        ensure_dir(self.models_root_dir)
        ensure_dir(self.pk_logs_dir)
        ensure_dir(self.exp_dir)
        ensure_dir(self.arena_state_dir)

        if self.config.enable_audit_log:
            ensure_dir(os.path.join(self.exp_dir, "raw_outputs"))

    def _save_config(self) -> None:
        """Save experiment configuration."""
        config_path = os.path.join(self.exp_dir, "config.json")
        history_path = os.path.join(self.exp_dir, "config_history.json")

        config_dict = self.config.to_dict()
        config_dict["models_actual"] = self.models

        # If a config already exists, append it to the history
        if os.path.isfile(config_path):
            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    existing = json.load(f)

                # Load or create history
                history = []
                if os.path.isfile(history_path):
                    with open(history_path, "r", encoding="utf-8") as f:
                        history = json.load(f)

                history.append(existing)

                with open(history_path, "w", encoding="utf-8") as f:
                    json.dump(history, f, indent=2, ensure_ascii=False)
            except Exception:
                pass

        # Write current config
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2, ensure_ascii=False)

    def _sync_state_with_models(self) -> bool:
        """
        Synchronize arena state with currently available models.

        If models have been removed from the models directory, this method will:
        1. Detect removed models (from both state and pk_logs)
        2. Move battle logs involving removed models to .pk_logs_rm/ (if clean_orphaned_logs=True)
        3. Rebuild ELO state from remaining battle logs
        4. Save the updated state

        Returns:
            True if state was rebuilt due to model changes, False otherwise
        """
        state = load_state(self.state_path)
        # Use the canonical on-disk model set (do NOT treat --models filter as removals)
        current_models = set(self.all_models)

        # Get models that exist in state but not in current model list
        state_models = set(state.models.keys())
        removed_from_state = state_models - current_models

        # Also scan pk_logs to find models that exist in logs but not in models/
        logs_models = self._scan_models_from_logs()
        removed_from_logs = logs_models - current_models

        # Combine both sources of removed models
        removed_models = removed_from_state | removed_from_logs

        if not removed_models:
            return False

        logger.info(
            f"Detected removed models: {removed_models}. "
            f"Rebuilding ELO state from battle logs..."
        )

        # Clean up orphaned battle logs if enabled
        if self.config.clean_orphaned_logs:
            self._delete_orphaned_logs(removed_models)

        # Rebuild state from logs, only including current models
        new_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)

        # Save the rebuilt state
        save_state(new_state, self.state_path)

        logger.info(
            f"State rebuilt: {new_state.total_battles} battles, "
            f"{len(new_state.models)} models"
        )

        return True

    def _scan_models_from_logs(self) -> set[str]:
        """
        Scan all battle log files to extract model names.

        This method reads the actual content of jsonl files to get the original
        model names, which is more reliable than parsing sanitized filenames.

        Returns:
            Set of all model names found in battle logs
        """
        models_found: set[str] = set()

        if not os.path.isdir(self.pk_logs_dir):
            return models_found

        for exp_name in os.listdir(self.pk_logs_dir):
            exp_dir = os.path.join(self.pk_logs_dir, exp_name)
            if not os.path.isdir(exp_dir):
                continue

            for filename in os.listdir(exp_dir):
                if not filename.endswith(".jsonl"):
                    continue

                filepath = os.path.join(exp_dir, filename)
                if not os.path.isfile(filepath):
                    continue

                # Read first line to extract model names (all lines in a file
                # should have the same model pair)
                try:
                    with open(filepath, "r", encoding="utf-8") as f:
                        for line in f:
                            line = line.strip()
                            if not line:
                                continue
                            try:
                                record = json.loads(line)
                                model_a = record.get("model_a", "")
                                model_b = record.get("model_b", "")
                                if model_a:
                                    models_found.add(model_a)
                                if model_b:
                                    models_found.add(model_b)
                                # Only need first valid line per file
                                break
                            except json.JSONDecodeError:
                                continue
                except Exception:
                    pass

        return models_found
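
    # Shape of a battle log record as read above (one JSON object per line;
    # names are illustrative and extra keys are elided):
    #
    #     {"model_a": "model-x", "model_b": "model-y", "final_winner": "tie", ...}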

    def _delete_orphaned_logs(self, removed_models: set[str]) -> None:
        """
        Move battle log files that involve removed models to .pk_logs_rm/ directory.

        This method reads the actual content of each jsonl file to extract the
        original model names, which is more reliable than parsing sanitized
        filenames. Instead of deleting, files are moved to a backup directory
        (.pk_logs_rm/) at the same level as pk_logs/.

        Args:
            removed_models: Set of model names that have been removed
        """
        import shutil

        if not os.path.isdir(self.pk_logs_dir):
            return

        # Create backup directory at the same level as pk_logs
        pk_logs_rm_dir = os.path.join(self.subset_dir, ".pk_logs_rm")

        moved_count = 0

        def _file_involves_removed_model(filepath: str) -> bool:
            """
            Check if a jsonl file involves any removed model by reading its content.

            Returns True if any record in the file has model_a or model_b in removed_models.
            """
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            record = json.loads(line)
                            model_a = record.get("model_a", "")
                            model_b = record.get("model_b", "")
                            if model_a in removed_models or model_b in removed_models:
                                return True
                        except json.JSONDecodeError:
                            continue
            except Exception:
                pass
            return False

        def _move_to_backup(filepath: str, relative_path: str) -> bool:
            """
            Move a file to the backup directory, preserving relative path structure.

            Args:
                filepath: Absolute path to the source file
                relative_path: Relative path from pk_logs_dir (e.g., "exp_name/file.jsonl")

            Returns:
                True if moved successfully, False otherwise
            """
            dest_path = os.path.join(pk_logs_rm_dir, relative_path)
            dest_dir = os.path.dirname(dest_path)

            try:
                ensure_dir(dest_dir)
                shutil.move(filepath, dest_path)
                return True
            except Exception as e:
                logger.warning(f"Failed to move {filepath} to {dest_path}: {e}")
                return False

        # Iterate over all experiment directories
        for exp_name in os.listdir(self.pk_logs_dir):
            exp_dir = os.path.join(self.pk_logs_dir, exp_name)
            if not os.path.isdir(exp_dir):
                continue

            # Check battle log files (format: model_a_vs_model_b.jsonl)
            for filename in os.listdir(exp_dir):
                if not filename.endswith(".jsonl"):
                    continue

                filepath = os.path.join(exp_dir, filename)
                if not os.path.isfile(filepath):
                    continue

                # Check file content to determine if it involves removed models
                if _file_involves_removed_model(filepath):
                    relative_path = os.path.join(exp_name, filename)
                    if _move_to_backup(filepath, relative_path):
                        moved_count += 1
                        logger.debug(f"Moved orphaned log to backup: {filepath}")

            # Also check raw_outputs subdirectory
            raw_outputs_dir = os.path.join(exp_dir, "raw_outputs")
            if os.path.isdir(raw_outputs_dir):
                for filename in os.listdir(raw_outputs_dir):
                    if not filename.endswith(".jsonl"):
                        continue

                    filepath = os.path.join(raw_outputs_dir, filename)
                    if not os.path.isfile(filepath):
                        continue

                    if _file_involves_removed_model(filepath):
                        relative_path = os.path.join(exp_name, "raw_outputs", filename)
                        if _move_to_backup(filepath, relative_path):
                            moved_count += 1
                            logger.debug(f"Moved orphaned audit log to backup: {filepath}")

        if moved_count > 0:
            logger.info(f"Moved {moved_count} orphaned battle log files to {pk_logs_rm_dir}")

    def _generate_battle_pairs(self) -> list[BattlePair]:
        """
        Generate all battle pairs to execute.

        In full mode: generates all possible pairs up to sample_size.
        In adaptive mode: generates pairs based on sampling config, respecting
        min_samples and max_samples per model pair.

        Returns:
            List of BattlePair objects
        """
        pairs = []

        # Get all dataset indices
        all_indices = self.dataset.get_all_indices()

        # In full mode, apply the global sample_size limit.
        # In adaptive mode, we apply per-pair limits later.
        if self.config.sampling.mode == "full":
            if self.config.sample_size and self.config.sample_size < len(all_indices):
                indices = random.sample(all_indices, self.config.sample_size)
            else:
                indices = all_indices
        else:
            # Adaptive mode: use all indices, will limit per-pair
            indices = all_indices

        # Generate model pairs to run for this exp:
        # - only include pairs where at least one side is a "new model" in this exp
        # - but respect the user-provided --models filter (self.models)
        selected_models = set(self.models)
        new_models = [m for m in self.new_models if m in selected_models]

        if not new_models:
            return []

        # Build unique pair set (sorted) for: new-vs-all + new-vs-new
        pair_set: set[tuple[str, str]] = set()
        for m in new_models:
            for other in selected_models:
                if other == m:
                    continue
                a, b, _ = get_sorted_model_pair(m, other)
                pair_set.add((a, b))

        model_pairs = sorted(pair_set)

        # Load existing battle counts for adaptive mode
        if self.config.sampling.mode == "adaptive":
            existing_counts = count_battles_per_pair(self.pk_logs_dir)
            # Determine target samples per pair based on experiment type
            if is_milestone_exp(self.config.exp_name or ""):
                target_samples = self.config.sampling.milestone_min_samples
            else:
                target_samples = self.config.sampling.min_samples
        else:
            existing_counts = {}
            target_samples = None

        # Generate battle pairs for each model pair and sample
        for model_a, model_b in model_pairs:
            # Validate coverage
            valid_indices = self.model_manager.validate_coverage(
                model_a, model_b, indices
            )

            # In adaptive mode, limit samples per pair
            if self.config.sampling.mode == "adaptive" and target_samples is not None:
                key = (min(model_a, model_b), max(model_a, model_b))
                existing = existing_counts.get(key, 0)
                needed = max(0, target_samples - existing)

                if needed == 0:
                    continue  # This pair already has enough samples

                # Limit to needed samples (randomly select if more available)
                if len(valid_indices) > needed:
                    valid_indices = random.sample(valid_indices, needed)

            for idx in valid_indices:
                pairs.append(BattlePair(
                    model_a=model_a,
                    model_b=model_b,
                    sample_index=idx
                ))

        return pairs
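
    # The dedup above relies on get_sorted_model_pair returning a canonical
    # ordering, so (A, B) and (B, A) collapse to one key (sketch; the third
    # returned value is unused here):
    #
    #     a, b, _ = get_sorted_model_pair("m2", "m1")
    #     (a, b)  # same key as for get_sorted_model_pair("m1", "m2")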

    def _skip_completed(
        self,
        pairs: list[BattlePair]
    ) -> list[BattlePair]:
        """
        Filter out already completed battles.

        Only considers battles where both models still exist in the current
        model list. Battles involving removed models are ignored.

        Args:
            pairs: List of battle pairs

        Returns:
            Filtered list excluding completed battles
        """
        all_completed = load_battle_history(self.pk_logs_dir)

        # Filter completed battles to only include those with current on-disk models.
        # This avoids treating --models filters as removals.
        current_models = set(self.all_models)
        completed = {
            (m_a, m_b, idx)
            for m_a, m_b, idx in all_completed
            if m_a in current_models and m_b in current_models
        }

        remaining = []
        for pair in pairs:
            # Get sorted model names for lookup
            first, second, _ = get_sorted_model_pair(pair.model_a, pair.model_b)
            key = (first, second, pair.sample_index)

            if key not in completed:
                remaining.append(pair)

        skipped = len(pairs) - len(remaining)
        if skipped > 0:
            logger.info(f"Skipping {skipped} already completed battles")

        # Log if there are orphaned battles from removed models
        orphaned = len(all_completed) - len(completed)
        if orphaned > 0:
            logger.info(
                f"Ignoring {orphaned} battle records involving removed models"
            )

        return remaining

    def _execute_single_battle(
        self,
        pair: BattlePair
    ) -> Optional[BattleResult]:
        """
        Execute a single battle.

        Args:
            pair: BattlePair to execute

        Returns:
            BattleResult or None if failed
        """
        try:
            # Get sample data
            sample = self.dataset.get_by_index(pair.sample_index)
            if sample is None:
                logger.warning(
                    f"Sample {pair.sample_index} not found in dataset"
                )
                return None

            # Get model outputs
            output_a = self.model_manager.get_output_path(
                pair.model_a, pair.sample_index
            )
            output_b = self.model_manager.get_output_path(
                pair.model_b, pair.sample_index
            )

            if output_a is None or output_b is None:
                logger.warning(
                    f"Missing output for battle {pair.model_a} vs {pair.model_b} "
                    f"at index {pair.sample_index}"
                )
                return None

            # Execute battle
            result = execute_battle(
                vlm=self.vlm,
                prompt_module=self.prompt_module,
                sample=sample,
                model_a_output=output_a,
                model_b_output=output_b,
                model_a=pair.model_a,
                model_b=pair.model_b,
                parallel_swap_calls=self.config.parallel_swap_calls,
                progress=self._progress_queue,
            )

            return result

        except Exception as e:
            logger.error(
                f"Error executing battle {pair.model_a} vs {pair.model_b} "
                f"at index {pair.sample_index}: {e}"
            )
            return None

    def _process_result(
        self,
        result: BattleResult,
        state: ArenaState
    ) -> ArenaState:
        """
        Process a battle result: log and update state.

        Args:
            result: BattleResult from battle execution
            state: Current arena state

        Returns:
            Updated arena state
        """
        # Log battle result (slim)
        self.battle_logger.log_battle_result(result)

        # Log audit trail (detailed)
        if self.audit_logger:
            self.audit_logger.log_battle_result(result)

        # Update W/L/T stats only. Elo is recomputed via Bradley-Terry fitting
        # from accumulated battle logs (order-independent).
        state = update_stats(state, result.model_a, result.model_b, result.final_winner)

        return state
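
    # Background for the comment above (the standard Bradley-Terry/Elo
    # relation, not necessarily genarena's exact parameterization): the fitted
    # ratings satisfy
    #
    #     P(a beats b) = 1 / (1 + 10 ** ((R_b - R_a) / 400))
    #
    # and because they are refit from the full battle log each time, the
    # result does not depend on the order in which battles were played.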

    def _get_battles_from_logs(self) -> list[BattleTuple]:
        """
        Load battle records from logs and convert to BattleTuple format.

        Returns:
            List of (model_a, model_b, winner) tuples for BT-Elo computation.
        """
        records = load_battle_records(self.pk_logs_dir)
        battles: list[BattleTuple] = []

        current_models = set(self.all_models)

        for record in records:
            model_a = record.get("model_a", "")
            model_b = record.get("model_b", "")
            final_winner = record.get("final_winner", "")

            # Skip records involving removed models
            if model_a not in current_models or model_b not in current_models:
                continue

            # Convert winner to standard format
            if final_winner == model_a:
                winner = "model_a"
            elif final_winner == model_b:
                winner = "model_b"
            elif final_winner == "tie":
                winner = "tie"
            else:
                continue  # Skip invalid records

            battles.append((model_a, model_b, winner))

        return battles

    def _load_anchor_elo(self) -> dict[str, float]:
        """
        Load anchor ELO ratings from the latest milestone snapshot.

        Returns:
            Dict mapping model name to ELO rating for milestone models,
            or empty dict if no milestone exists.
        """
        # Discover milestone experiments
        exp_keys: list[tuple[tuple, str]] = []
        if not os.path.isdir(self.pk_logs_dir):
            return {}

        for name in os.listdir(self.pk_logs_dir):
            if name.startswith("."):
                continue
            exp_dir = os.path.join(self.pk_logs_dir, name)
            if not os.path.isdir(exp_dir):
                continue
            d = parse_exp_date_suffix(name)
            if d is None:
                continue
            exp_keys.append(((d, name), name))

        exp_keys.sort(key=lambda x: x[0])

        # Find milestones
        milestones = [name for (key, name) in exp_keys if is_milestone_exp(name)]
        if not milestones:
            return {}

        # Load from latest milestone snapshot
        latest_milestone = milestones[-1]
        snapshot_path = os.path.join(self.pk_logs_dir, latest_milestone, "elo_snapshot.json")

        if not os.path.isfile(snapshot_path):
            return {}

        try:
            with open(snapshot_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            return {}

        if not isinstance(data, dict):
            return {}

        # Accept either: {"elo": {...}} or a direct {model: elo} mapping
        raw = data.get("elo") if isinstance(data.get("elo"), dict) else data
        if not isinstance(raw, dict):
            return {}

        # Filter to only include models that exist in current model set
        current_models = set(self.all_models)
        anchor_elo: dict[str, float] = {}
        for k, v in raw.items():
            if str(k) in current_models:
                try:
                    anchor_elo[str(k)] = float(v)
                except Exception:
                    continue

        return anchor_elo
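
    # Both snapshot shapes below are accepted by the parsing above (model
    # names and ratings are illustrative):
    #
    #     {"elo": {"model-x": 1023.4, "model-y": 987.1}}
    #     {"model-x": 1023.4, "model-y": 987.1}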

    def _run_adaptive_with_ci_checking(self) -> ArenaState:
        """
        Run arena evaluation with adaptive CI-based sampling.

        This method implements the iterative loop:
        1. Run initial batch (min_samples per pair)
        2. Compute bootstrap CI
        3. If max CI width > target, add batch_size more samples to unconverged pairs
        4. Repeat until all pairs converge or reach max_samples

        Returns:
            Final ArenaState after all battles
        """
        sampling_config = self.config.sampling
        is_milestone = is_milestone_exp(self.config.exp_name or "")

        # Determine target samples per pair for initial batch
        if is_milestone:
            target_samples = sampling_config.milestone_min_samples
            logger.info(f"Milestone experiment: targeting {target_samples} samples/pair initially")
        else:
            target_samples = sampling_config.min_samples
            logger.info(f"Incremental experiment: targeting {target_samples} samples/pair initially")

        # Get all dataset indices
        all_indices = self.dataset.get_all_indices()

        # Build model pairs (new models vs all selected models)
        selected_models = set(self.models)
        new_models = [m for m in self.new_models if m in selected_models]

        if not new_models:
            logger.info("No new models to evaluate")
            return rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)

        # Build unique pair set
        pair_set: set[tuple[str, str]] = set()
        for m in new_models:
            for other in selected_models:
                if other == m:
                    continue
                a, b, _ = get_sorted_model_pair(m, other)
                pair_set.add((a, b))

        model_pairs = sorted(pair_set)
        logger.info(f"Evaluating {len(model_pairs)} model pairs with adaptive sampling")

        # Initialize scheduler
        scheduler = AdaptiveSamplingScheduler(config=sampling_config)

        # Load existing battle counts
        existing_counts = count_battles_per_pair(self.pk_logs_dir)
        for pair in model_pairs:
            count = existing_counts.get(pair, 0)
            scheduler.update_state(pair[0], pair[1], current_samples=count)

        # Load existing state
        state = load_state(self.state_path)

        # Progress tracking
        progress_queue, _progress_thread, progress_close = _start_calls_progress_consumer(
            enabled=self.config.enable_progress_bar,
            total=None,  # Dynamic total
        )
        self._progress_queue = progress_queue
        if self._progress_queue is not None:
            try:
                self.vlm.set_progress(self._progress_queue)
            except Exception:
                pass

        iteration = 0
        total_completed = 0

        while True:
            iteration += 1

            # Determine which pairs need more samples
            pairs_to_run: list[tuple[str, str]] = []
            samples_per_pair: dict[tuple[str, str], int] = {}

            for pair in model_pairs:
                pair_state = scheduler.get_or_create_state(pair[0], pair[1])
                samples_to_run = pair_state.get_samples_to_run(sampling_config, len(all_indices))

                if samples_to_run > 0:
                    pairs_to_run.append(pair)
                    samples_per_pair[pair] = samples_to_run

            if not pairs_to_run:
                logger.info("All pairs have converged or reached max_samples")
                break

            total_samples_this_iter = sum(samples_per_pair.values())
            logger.info(
                f"Iteration {iteration}: running {total_samples_this_iter} battles "
                f"across {len(pairs_to_run)} pairs"
            )

            # Generate battle pairs for this iteration
            completed_set = load_battle_history(self.pk_logs_dir)
            battle_pairs: list[BattlePair] = []

            for pair in pairs_to_run:
                model_a, model_b = pair
                needed = samples_per_pair[pair]

                # Get valid indices for this pair
                valid_indices = self.model_manager.validate_coverage(model_a, model_b, all_indices)

                # Filter out already completed
                pending_indices = [
                    idx for idx in valid_indices
                    if (model_a, model_b, idx) not in completed_set
                ]

                # Select up to 'needed' samples
                if len(pending_indices) > needed:
                    selected = random.sample(pending_indices, needed)
                else:
                    selected = pending_indices

                for idx in selected:
                    battle_pairs.append(BattlePair(
                        model_a=model_a,
                        model_b=model_b,
                        sample_index=idx
                    ))

            if not battle_pairs:
                logger.info("No more battles to execute")
                break

            # Update progress bar total
            if self._progress_queue is not None:
                try:
                    self._progress_queue.put(("total", 2 * len(battle_pairs)))
                except Exception:
                    pass

            # Execute battles
            iter_completed = 0

            if self.config.num_threads <= 1:
                # Sequential execution
                for pair in battle_pairs:
                    result = self._execute_single_battle(pair)
                    if result:
                        state = self._process_result(result, state)
                        iter_completed += 1
            else:
                # Parallel execution
                with ThreadPoolExecutor(max_workers=self.config.num_threads) as executor:
                    future_to_pair = {
                        executor.submit(self._execute_single_battle, pair): pair
                        for pair in battle_pairs
                    }

                    for future in as_completed(future_to_pair):
                        try:
                            result = future.result()
                            if result:
                                state = self._process_result(result, state)
                                iter_completed += 1
                        except Exception as e:
                            pair = future_to_pair[future]
                            logger.error(f"Battle {pair.model_a} vs {pair.model_b} failed: {e}")

            total_completed += iter_completed
            logger.info(f"Iteration {iteration} completed: {iter_completed} battles")

            # Save intermediate state
            save_state(state, self.state_path)

            # Update scheduler with new counts
            new_counts = count_battles_per_pair(self.pk_logs_dir)
            for pair in model_pairs:
                count = new_counts.get(pair, 0)
                scheduler.update_state(pair[0], pair[1], current_samples=count)

            # Compute bootstrap CI to check convergence
            battles = self._get_battles_from_logs()
            if battles:
                # Load anchor ELO from the latest milestone snapshot.
                # Milestone models have fixed ELO, so we only check CI for new models.
                anchor_elo = self._load_anchor_elo()

                bootstrap_result = compute_bootstrap_bt_elo(
                    battles,
                    models=self.all_models,
                    fixed_ratings=anchor_elo if anchor_elo else None,
                    num_bootstrap=sampling_config.num_bootstrap,
                )

                # Only check CI for new models (non-anchor models).
                # Anchor models have CI width = 0 since their ELO is fixed.
                new_models_set = set(new_models)
                new_model_ci_widths = [
                    bootstrap_result.ci_width.get(m, 0.0)
                    for m in new_models_set
                    if m in bootstrap_result.ci_width
                ]

                if new_model_ci_widths:
                    max_ci_width = max(new_model_ci_widths)
                    mean_ci_width = sum(new_model_ci_widths) / len(new_model_ci_widths)
                else:
                    max_ci_width = bootstrap_result.get_max_ci_width()
                    mean_ci_width = bootstrap_result.get_mean_ci_width()

                logger.info(
                    f"CI check (new models only): max_width={max_ci_width:.2f}, "
                    f"mean_width={mean_ci_width:.2f}, target={sampling_config.target_ci_width:.2f}"
                )

                # Check if all new models have converged
                if max_ci_width <= sampling_config.target_ci_width:
                    logger.info(f"CI target reached! Max CI width for new models: {max_ci_width:.2f}")
                    # Mark all pairs as converged
                    for pair in model_pairs:
                        pair_state = scheduler.get_or_create_state(pair[0], pair[1])
                        pair_state.converged = True
                    break

            # Check if all pairs have reached max_samples
            all_maxed = True
            for pair in model_pairs:
                pair_state = scheduler.get_or_create_state(pair[0], pair[1])
                if pair_state.current_samples < sampling_config.max_samples:
                    all_maxed = False
                    break

            if all_maxed:
                logger.info("All pairs reached max_samples limit")
                break

        progress_close()

        # Final summary
        summary = scheduler.get_summary()
        logger.info(
            f"Adaptive sampling complete: "
            f"{summary['total_pairs']} pairs, "
            f"{summary['converged_pairs']} converged, "
            f"{summary['maxed_pairs']} reached max_samples, "
            f"{summary['total_samples']} total samples"
        )

        # Final Elo recompute (Bradley-Terry) and state save
        final_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
        save_state(final_state, self.state_path)

        logger.info(f"Arena completed: {total_completed} battles executed in {iteration} iterations")

        return final_state
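
    # Worked example of the convergence check above (numbers illustrative):
    # each bootstrap round resamples the logged battles and refits BT-Elo, and
    # a model's CI width is the spread of its rating across rounds. With
    # target_ci_width=50, a new model whose bootstrap ratings span
    # [990, 1030] (width 40) has converged, while one spanning [950, 1040]
    # (width 90) is scheduled for another batch.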
|
|
1351
|
+
|
|
1352
|
+
def run(self) -> ArenaState:
|
|
1353
|
+
"""
|
|
1354
|
+
Run the arena evaluation.
|
|
1355
|
+
|
|
1356
|
+
If models have been removed from the arena directory, the ELO state
|
|
1357
|
+
will be automatically rebuilt from battle logs (excluding removed models).
|
|
1358
|
+
|
|
1359
|
+
Returns:
|
|
1360
|
+
Final ArenaState after all battles
|
|
1361
|
+
"""
|
|
1362
|
+
# Sync state with current models (rebuild if models were removed).
|
|
1363
|
+
# This rebuild uses Bradley-Terry Elo scoring from logs.
|
|
1364
|
+
self._sync_state_with_models()
|
|
1365
|
+
|
|
1366
|
+
# Use adaptive CI-checking mode if enabled (and not multiprocessing)
|
|
1367
|
+
if (self.config.sampling.mode == "adaptive" and
|
|
1368
|
+
self.config.num_processes <= 1):
|
|
1369
|
+
return self._run_adaptive_with_ci_checking()
|
|
1370
|
+
|
|
1371
|
+
# Generate and filter battle pairs
|
|
1372
|
+
# If we can shard by parquet file, we can avoid constructing the full pair list
|
|
1373
|
+
# in the parent process (and avoid pickling huge lists).
|
|
1374
|
+
all_indices = self.dataset.get_all_indices()
|
|
1375
|
+
|
|
1376
|
+
# Apply sample size limit
|
|
1377
|
+
if self.config.sample_size and self.config.sample_size < len(all_indices):
|
|
1378
|
+
indices = random.sample(all_indices, self.config.sample_size)
|
|
1379
|
+
else:
|
|
1380
|
+
indices = all_indices

        # If num_processes <= 1, fall back to the original thread-based implementation.
        if self.config.num_processes <= 1:
            all_pairs = self._generate_battle_pairs()
            pairs = self._skip_completed(all_pairs)
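            # _skip_completed() filters out pairs whose battles are already logged, so
            # an interrupted run can be restarted and should resume where it left off.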

            if not pairs:
                logger.info("No battles to execute")
                # Ensure state is up-to-date and order-independent
                state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
                save_state(state, self.state_path)
                return state

            logger.info(f"Starting arena with {len(pairs)} battles to execute")
            logger.info(f"Models: {self.models}")
            logger.info(f"Experiment: {self.config.exp_name}")
            logger.info("Sampling mode: full")

            # Load existing state
            state = load_state(self.state_path)

            # Progress tracking
            completed = 0
            total = len(pairs)

            progress_queue, _progress_thread, progress_close = _start_calls_progress_consumer(
                enabled=self.config.enable_progress_bar,
                total=(2 * len(pairs)) if self.config.enable_progress_bar else None,
            )
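            # The progress total is 2x the pair count: each battle is budgeted as two
            # judge calls (parallel_swap_calls suggests one call per response order,
            # to offset position bias).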
            self._progress_queue = progress_queue
            if self._progress_queue is not None:
                try:
                    self.vlm.set_progress(self._progress_queue)
                except Exception:
                    pass

            if self.config.num_threads <= 1:
                # Sequential execution
                for pair in pairs:
                    result = self._execute_single_battle(pair)

                    if result:
                        state = self._process_result(result, state)
                        completed += 1

                        # Progress logging every 10 battles
                        if completed % 10 == 0:
                            logger.info(f"Progress: {completed}/{total} battles")
                            # Save intermediate state
                            save_state(state, self.state_path)
            else:
                # Parallel execution
                with ThreadPoolExecutor(max_workers=self.config.num_threads) as executor:
                    # Submit all battles
                    future_to_pair = {
                        executor.submit(self._execute_single_battle, pair): pair
                        for pair in pairs
                    }

                    # Process completed futures
                    for future in as_completed(future_to_pair):
                        pair = future_to_pair[future]

                        try:
                            result = future.result()

                            if result:
                                state = self._process_result(result, state)
                                completed += 1

                                # Progress logging every 10 battles
                                if completed % 10 == 0:
                                    logger.info(f"Progress: {completed}/{total} battles")
                                    # Save intermediate state
                                    save_state(state, self.state_path)
                        except Exception as e:
                            logger.error(
                                f"Battle {pair.model_a} vs {pair.model_b} failed: {e}"
                            )

            progress_close()

            # Final Elo recompute (Bradley-Terry) and state save
            final_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
            save_state(final_state, self.state_path)

            logger.info(f"Arena completed: {completed}/{total} battles executed")

            return final_state

        # === Multiprocessing path (per-parquet sharding) ===
        grouped = self.dataset.group_indices_by_parquet(indices)
        if "" in grouped:
            logger.warning(
                "Parquet source mapping is incomplete (missing index->parquet mapping). "
                "Falling back to single-process execution."
            )
            self.config.num_processes = 1
            # Re-load dataset in full mode for single-process execution.
            self.dataset = ParquetDataset(self.config.data_dir, self.config.subset, load_mode="full")
            return self.run()
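
        # `grouped` maps each parquet file to the sample indices it contains, e.g.
        # {"shard-000.parquet": [0, 1, 7], "shard-001.parquet": [2, 5]} (names
        # hypothetical); an empty-string key marks indices with no resolved source file.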

        total_concurrency = max(1, int(self.config.num_processes)) * max(1, int(self.config.num_threads))
        logger.info(
            f"Starting arena with multiprocessing: num_processes={self.config.num_processes}, "
            f"num_threads={self.config.num_threads}, total_concurrency~{total_concurrency}"
        )
        logger.info(f"Models: {self.models}")
        logger.info(f"Experiment: {self.config.exp_name}")

        completed = 0
        attempted = 0
        parquet_tasks = [(pf, idxs) for pf, idxs in grouped.items() if idxs]
        parquet_tasks.sort(key=lambda x: x[0])

        # Assign parquet files to processes up-front (avoid per-parquet re-init overhead in a worker).
        # Simple greedy bin-packing by number of indices for load balancing.
        num_workers = max(1, int(self.config.num_processes))
        buckets: list[list[tuple[str, list[int]]]] = [[] for _ in range(num_workers)]
        bucket_sizes = [0 for _ in range(num_workers)]
        for pf, idxs in sorted(parquet_tasks, key=lambda x: len(x[1]), reverse=True):
            k = bucket_sizes.index(min(bucket_sizes))
            buckets[k].append((pf, idxs))
            bucket_sizes[k] += len(idxs)
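        # Worked example (sizes hypothetical): with num_workers=2 and parquet files
        # holding 5, 4, and 3 indices, the descending pass assigns 5 -> bucket 0,
        # 4 -> bucket 1, then 3 -> bucket 1 (currently lighter), ending at sizes [5, 7].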

        # Progress consumer (optional). Use a process-safe Manager queue and batch updates in workers.
        manager = None
        mp_progress_queue = None
        progress_close = lambda: None
        if self.config.enable_progress_bar:
            try:
                import multiprocessing
                manager = multiprocessing.Manager()
                mp_progress_queue = manager.Queue()
                # Capture the queue reference for the closure (type narrowing)
                _queue = mp_progress_queue
                # Reuse the same tqdm consumer pattern by wrapping the manager queue in a local consumer thread.
                try:
                    from tqdm import tqdm  # type: ignore
                    bar = tqdm(total=None, unit="call", desc="API Calls", dynamic_ncols=True)
                    # Must be picklable across processes.
                    stop_sentinel = ("stop", None)
                    recent: deque[str] = deque(maxlen=10)

                    def _mp_consumer() -> None:
                        while True:
                            item = _queue.get()
                            if item == stop_sentinel:
                                break
                            if isinstance(item, (int, float)):
                                try:
                                    bar.update(int(item))
                                except Exception:
                                    pass
                            elif isinstance(item, tuple) and len(item) == 2 and item[0] == "log":
                                try:
                                    recent.append(str(item[1]))
                                    bar.set_postfix_str(" | ".join(recent))
                                except Exception:
                                    pass
                            elif isinstance(item, tuple) and len(item) == 2 and item[0] == "total":
                                try:
                                    delta = int(item[1])
                                    if delta > 0:
                                        bar.total = (bar.total or 0) + delta
                                        bar.refresh()
                                except Exception:
                                    pass
                            else:
                                pass

                    t = threading.Thread(target=_mp_consumer, name="mp-calls-progress-consumer", daemon=True)
                    t.start()

                    def progress_close() -> None:
                        try:
                            mp_progress_queue.put(stop_sentinel)
                        except Exception:
                            pass
                        t.join(timeout=5)
                        try:
                            bar.close()
                        except Exception:
                            pass

                except Exception:
                    logger.warning("tqdm is not available; progress bar disabled")
                    mp_progress_queue = None
            except Exception:
                logger.warning("Failed to initialize multiprocessing progress queue; progress bar disabled")
                mp_progress_queue = None
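        # Queue protocol handled by the consumer above: an int/float increments the
        # bar, ("log", msg) updates the postfix with recent messages, ("total", n)
        # grows the expected total, and the ("stop", None) sentinel shuts it down.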

        with ProcessPoolExecutor(max_workers=self.config.num_processes) as executor:
            futures = []
            for work in buckets:
                if not work:
                    continue
                futures.append(executor.submit(
                    _run_parquet_bucket_worker,
                    arena_dir=self.config.arena_dir,
                    data_dir=self.config.data_dir,
                    subset=self.config.subset,
                    exp_name=self.config.exp_name or "",  # exp_name is guaranteed to be set in __init__
                    parquet_work=work,
                    models=self.models,
                    new_models=self.new_models,
                    num_threads=self.config.num_threads,
                    judge_model=self.config.judge_model,
                    temperature=self.config.temperature,
                    prompt=self.config.prompt,
                    timeout=self.config.timeout,
                    max_retries=self.config.max_retries,
                    base_urls=self.config.base_urls,
                    api_keys=self.config.api_keys,
                    enable_audit_log=self.config.enable_audit_log,
                    parallel_swap_calls=self.config.parallel_swap_calls,
                    progress_queue=mp_progress_queue,
                ))

            for fut in as_completed(futures):
                try:
                    res = fut.result()
                    completed += int(res.get("completed", 0))
                    attempted += int(res.get("attempted", 0))
                    if completed > 0 and completed % 50 == 0:
                        logger.info(f"Progress: completed={completed} attempted={attempted}")
                except Exception as e:
                    logger.error(f"Worker failed: {e}")
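            # Each bucket worker is expected to return counter dicts shaped like
            # {"completed": 120, "attempted": 124} (values hypothetical); missing keys
            # default to 0 via res.get().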

        progress_close()
        if manager is not None:
            try:
                manager.shutdown()
            except Exception:
                pass

        # Final Elo recompute (Bradley-Terry) and state save
        final_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
        save_state(final_state, self.state_path)

        logger.info(f"Arena completed (multiprocessing): completed={completed} attempted={attempted}")

        return final_state

    def update_leaderboard(self) -> None:
        """Update the leaderboard README.md file."""
        # Always rebuild state from logs to ensure BT Elo is consistent and up-to-date.
        state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
        save_state(state, self.state_path)

        title = f"{self.config.subset.capitalize()} Leaderboard"
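        # e.g. subset "puzzles" yields the title "Puzzles Leaderboard".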
        save_leaderboard(state, self.leaderboard_path, title)

        logger.info(f"Leaderboard saved to {self.leaderboard_path}")

    def get_status(self) -> dict[str, Any]:
        """
        Get arena status summary.

        Returns:
            Dict with status information
        """
        state = load_state(self.state_path)

        return {
            "subset": self.config.subset,
            "models": self.models,
            "total_models": len(self.models),
            "total_battles": state.total_battles,
            "last_updated": state.last_updated,
            "dataset_size": len(self.dataset),
            "arena_dir": self.config.arena_dir
        }
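
    # Example get_status() payload (all values hypothetical):
    #   {"subset": "puzzles", "models": ["model-a", "model-b"], "total_models": 2,
    #    "total_battles": 412, "last_updated": "...", "dataset_size": 500,
    #    "arena_dir": "my_arena"}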


def get_all_subsets_status(arena_dir: str, data_dir: str) -> list[dict[str, Any]]:
    """
    Get status for all subsets in an arena directory.

    Args:
        arena_dir: Arena directory path
        data_dir: Data directory path

    Returns:
        List of status dicts for each subset
    """
    subsets = discover_subsets(data_dir)
    statuses = []

    for subset in subsets:
        state_path = os.path.join(arena_dir, subset, "arena", "state.json")
        state = load_state(state_path)

        models_dir = os.path.join(arena_dir, subset, "models")
        model_manager = GlobalModelOutputManager(models_dir)

        statuses.append({
            "subset": subset,
            "models": model_manager.models,
            "total_models": len(model_manager.models),
            "total_battles": state.total_battles,
            "last_updated": state.last_updated
        })

    return statuses
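

# Minimal usage sketch (paths hypothetical):
#   for s in get_all_subsets_status("my_arena", "my_data"):
#       print(f"{s['subset']}: {s['total_models']} models, {s['total_battles']} battles")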