genarena 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (47)
  1. genarena/__init__.py +49 -2
  2. genarena/__main__.py +10 -0
  3. genarena/arena.py +1685 -0
  4. genarena/battle.py +337 -0
  5. genarena/bt_elo.py +507 -0
  6. genarena/cli.py +1581 -0
  7. genarena/data.py +476 -0
  8. genarena/deploy/Dockerfile +25 -0
  9. genarena/deploy/README.md +55 -0
  10. genarena/deploy/__init__.py +5 -0
  11. genarena/deploy/app.py +84 -0
  12. genarena/experiments.py +121 -0
  13. genarena/leaderboard.py +270 -0
  14. genarena/logs.py +409 -0
  15. genarena/models.py +412 -0
  16. genarena/prompts/__init__.py +127 -0
  17. genarena/prompts/mmrb2.py +373 -0
  18. genarena/sampling.py +336 -0
  19. genarena/state.py +656 -0
  20. genarena/sync/__init__.py +105 -0
  21. genarena/sync/auto_commit.py +118 -0
  22. genarena/sync/deploy_ops.py +543 -0
  23. genarena/sync/git_ops.py +422 -0
  24. genarena/sync/hf_ops.py +891 -0
  25. genarena/sync/init_ops.py +431 -0
  26. genarena/sync/packer.py +587 -0
  27. genarena/sync/submit.py +837 -0
  28. genarena/utils.py +103 -0
  29. genarena/validation/__init__.py +19 -0
  30. genarena/validation/schema.py +327 -0
  31. genarena/validation/validator.py +329 -0
  32. genarena/visualize/README.md +148 -0
  33. genarena/visualize/__init__.py +14 -0
  34. genarena/visualize/app.py +938 -0
  35. genarena/visualize/data_loader.py +2335 -0
  36. genarena/visualize/static/app.js +3762 -0
  37. genarena/visualize/static/model_aliases.json +86 -0
  38. genarena/visualize/static/style.css +4104 -0
  39. genarena/visualize/templates/index.html +413 -0
  40. genarena/vlm.py +519 -0
  41. genarena-0.1.0.dist-info/METADATA +178 -0
  42. genarena-0.1.0.dist-info/RECORD +44 -0
  43. {genarena-0.0.1.dist-info → genarena-0.1.0.dist-info}/WHEEL +1 -2
  44. genarena-0.1.0.dist-info/entry_points.txt +2 -0
  45. genarena-0.0.1.dist-info/METADATA +0 -26
  46. genarena-0.0.1.dist-info/RECORD +0 -5
  47. genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/arena.py ADDED
@@ -0,0 +1,1685 @@
+ # Copyright 2026 Ruihang Li.
+ # Licensed under the Apache License, Version 2.0.
+ # See LICENSE file in the project root for details.
+
+ """Arena core coordinator module."""
+
+ import itertools
+ import json
+ import logging
+ import os
+ import random
+ import threading
+ import queue as thread_queue
+ from collections import deque
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
+ from dataclasses import dataclass, field
+ from typing import Any, Optional, Union
+
+ from genarena.battle import BattleResult, execute_battle
+ from genarena.bt_elo import compute_bootstrap_bt_elo, BattleTuple
+ from genarena.data import ParquetDataset, discover_subsets
+ from genarena.experiments import pick_latest_experiment_name, require_valid_exp_name, is_milestone_exp, parse_exp_date_suffix
+ from genarena.leaderboard import save_leaderboard
+ from genarena.logs import AuditLogger, BattleLogger, load_battle_history, count_battles_per_pair, load_battle_records
+ from genarena.models import GlobalModelOutputManager, ModelOutputManager
+ from genarena.prompts import load_prompt
+ from genarena.sampling import SamplingConfig, AdaptiveSamplingScheduler
+ from genarena.state import ArenaState, load_state, rebuild_state_from_logs, save_state, update_stats
+ from genarena.utils import ensure_dir, get_sorted_model_pair, iso_timestamp
+ from genarena.vlm import VLMJudge
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class BattlePair:
+     """A pair of models and sample for a battle."""
+
+     model_a: str
+     model_b: str
+     sample_index: int
+
+
+ @dataclass
+ class ArenaConfig:
+     """Configuration for an arena run."""
+
+     # Required paths
+     arena_dir: str
+     data_dir: str
+     subset: str
+
+     # Model configuration
+     models: Optional[list[str]] = None  # None = all models
+
+     # Experiment configuration
+     exp_name: Optional[str] = None  # None = timestamp
+     sample_size: Optional[int] = None  # None = all samples (used in full mode)
+     num_threads: int = 8
+     num_processes: int = 1
+     parallel_swap_calls: bool = False
+     enable_progress_bar: bool = False
+
+     # Sampling configuration
+     sampling: SamplingConfig = field(default_factory=SamplingConfig)
+
+     # VLM configuration
+     judge_model: str = "Qwen/Qwen3-VL-32B-Instruct-FP8"
+     temperature: float = 0.0
+     prompt: str = "mmrb2"
+     timeout: int = 120
+     max_retries: int = 3
+
+     # Multi-endpoint configuration
+     base_urls: Optional[Union[str, list[str]]] = None  # Comma-separated or list
+     api_keys: Optional[Union[str, list[str]]] = None  # Comma-separated or list
+
+     # Logging configuration
+     enable_audit_log: bool = True
+     verbose: bool = False
+
+     # Model removal behavior
+     clean_orphaned_logs: bool = True  # Delete battle logs involving removed models
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to dictionary for serialization."""
+         # Parse base_urls for logging
+         base_urls_list = []
+         if self.base_urls:
+             if isinstance(self.base_urls, str):
+                 base_urls_list = [u.strip() for u in self.base_urls.split(",") if u.strip()]
+             else:
+                 base_urls_list = list(self.base_urls)
+
+         # Count api_keys for logging (don't expose actual keys)
+         num_api_keys = 0
+         if self.api_keys:
+             if isinstance(self.api_keys, str):
+                 num_api_keys = len([k for k in self.api_keys.split(",") if k.strip()])
+             else:
+                 num_api_keys = len(self.api_keys)
+
+         return {
+             "arena_dir": self.arena_dir,
+             "data_dir": self.data_dir,
+             "subset": self.subset,
+             "models": self.models,
+             "exp_name": self.exp_name,
+             "sample_size": self.sample_size,
+             "num_threads": self.num_threads,
+             "num_processes": self.num_processes,
+             "parallel_swap_calls": self.parallel_swap_calls,
+             "enable_progress_bar": self.enable_progress_bar,
+             "sampling": self.sampling.to_dict(),
+             "judge_model": self.judge_model,
+             "temperature": self.temperature,
+             "prompt": self.prompt,
+             "timeout": self.timeout,
+             "max_retries": self.max_retries,
+             "base_urls": base_urls_list,
+             "num_api_keys": num_api_keys,
+             "enable_audit_log": self.enable_audit_log,
+             "clean_orphaned_logs": self.clean_orphaned_logs,
+             "timestamp": iso_timestamp()
+         }
+
+
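For orientation, a minimal usage sketch of this configuration object. The paths and subset name below are illustrative placeholders, not part of the package; note that Arena.__init__ raises a ValueError unless models/<exp_name>/ already exists under the subset directory.

from genarena.arena import Arena, ArenaConfig
from genarena.sampling import SamplingConfig

config = ArenaConfig(
    arena_dir="./my_arena",      # placeholder path
    data_dir="./my_data",        # placeholder path
    subset="example_subset",     # placeholder subset name
    sampling=SamplingConfig(),   # library defaults; mode may be "full" or "adaptive"
)
arena = Arena(config)            # infers exp_name from models/ if not given
final_state = arena.run()        # executes pending battles, returns an ArenaState
arena.update_leaderboard()       # rewrites <subset>/README.md from the rebuilt state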
+ def _run_parquet_bucket_worker(
+     *,
+     arena_dir: str,
+     data_dir: str,
+     subset: str,
+     exp_name: str,
+     parquet_work: list[tuple[str, list[int]]],
+     models: list[str],
+     new_models: list[str],
+     num_threads: int,
+     judge_model: str,
+     temperature: float,
+     prompt: str,
+     timeout: int,
+     max_retries: int,
+     base_urls: Optional[Union[str, list[str]]],
+     api_keys: Optional[Union[str, list[str]]],
+     enable_audit_log: bool,
+     parallel_swap_calls: bool,
+     progress_queue: Any = None,
+ ) -> dict[str, int]:
+     """
+     Worker entry point for multiprocessing: execute battles for a bucket of parquet files.
+
+     Notes:
+         - Each process initializes its own VLM client/endpoint manager.
+         - Results are persisted via jsonl logs (with fcntl locks), so the parent process
+           only needs counts for progress reporting.
+     """
+     # Local imports are avoided here because the module is already imported in workers,
+     # but keep this function at module-level so it's picklable by ProcessPoolExecutor.
+     subset_dir = os.path.join(arena_dir, subset)
+     models_dir = os.path.join(subset_dir, "models")
+     pk_logs_dir = os.path.join(subset_dir, "pk_logs")
+     exp_dir = os.path.join(pk_logs_dir, exp_name)
+
+     ensure_dir(exp_dir)
+
+     prompt_module = load_prompt(prompt)
+     # In v2 layout, models are stored under models/<exp_name>/<model>/...
+     # and model names are globally unique across experiments.
+     model_manager = GlobalModelOutputManager(models_dir)
+
+     vlm = VLMJudge(
+         model=judge_model,
+         temperature=temperature,
+         timeout=timeout,
+         max_retries=max_retries,
+         base_urls=base_urls,
+         api_keys=api_keys,
+         progress=progress_queue,
+     )
+
+     battle_logger = BattleLogger(exp_dir)
+     audit_logger = AuditLogger(exp_dir) if enable_audit_log else None
+
+     completed_set = load_battle_history(pk_logs_dir)
+
+     class _ProgressBuffer:
+         """Batch progress updates to reduce cross-process queue overhead."""
+
+         def __init__(self, q: Any, flush_every: int = 20):
+             self._q = q
+             self._flush_every = flush_every
+             self._buf = 0
+
+         def put(self, n: int) -> None:
+             if self._q is None:
+                 return
+             self._buf += int(n)
+             if self._buf >= self._flush_every:
+                 try:
+                     self._q.put(self._buf)
+                 finally:
+                     self._buf = 0
+
+         def total(self, n: int) -> None:
+             """Increase progress bar total by n (best-effort)."""
+             if self._q is None:
+                 return
+             try:
+                 n_int = int(n)
+             except Exception:
+                 return
+             if n_int <= 0:
+                 return
+             try:
+                 self._q.put(("total", n_int))
+             except Exception:
+                 pass
+
+         def flush(self) -> None:
+             if self._q is None:
+                 return
+             if self._buf > 0:
+                 try:
+                     self._q.put(self._buf)
+                 finally:
+                     self._buf = 0
+
+     progress = _ProgressBuffer(progress_queue) if progress_queue is not None else None
+
+     def _execute_one(dataset: ParquetDataset, model_a: str, model_b: str, sample_index: int) -> bool:
+         # Skip if already completed (sorted key)
+         first, second, _ = get_sorted_model_pair(model_a, model_b)
+         if (first, second, sample_index) in completed_set:
+             return False
+
+         sample = dataset.get_by_index(sample_index)
+         if sample is None:
+             return False
+
+         output_a = model_manager.get_output_path(model_a, sample_index)
+         output_b = model_manager.get_output_path(model_b, sample_index)
+         if output_a is None or output_b is None:
+             return False
+
+         result = execute_battle(
+             vlm=vlm,
+             prompt_module=prompt_module,
+             sample=sample,
+             model_a_output=output_a,
+             model_b_output=output_b,
+             model_a=model_a,
+             model_b=model_b,
+             parallel_swap_calls=parallel_swap_calls,
+             progress=progress,
+         )
+
+         battle_logger.log_battle_result(result)
+         if audit_logger:
+             audit_logger.log_battle_result(result)
+
+         return True
+
+     # Build tasks lazily and keep inflight bounded to reduce overhead for large runs.
+     completed = 0
+     total_attempted = 0
+     total_indices = 0
+
+     selected_models = set(models)
+     new_models_filtered = [m for m in new_models if m in selected_models]
+     if not new_models_filtered:
+         return {"completed": 0, "attempted": 0, "indices": 0}
+
+     pair_set: set[tuple[str, str]] = set()
+     for m in new_models_filtered:
+         for other in selected_models:
+             if other == m:
+                 continue
+             a, b, _ = get_sorted_model_pair(m, other)
+             pair_set.add((a, b))
+
+     model_pairs = sorted(pair_set)
+
+     if num_threads <= 1:
+         for pf, indices in parquet_work:
+             if not indices:
+                 continue
+             total_indices += len(indices)
+             dataset = ParquetDataset(data_dir, subset, parquet_files=[pf])
+             for model_a, model_b in model_pairs:
+                 valid_indices = model_manager.validate_coverage(model_a, model_b, indices)
+                 first, second, _ = get_sorted_model_pair(model_a, model_b)
+                 pending_indices = [idx for idx in valid_indices if (first, second, idx) not in completed_set]
+                 if progress is not None:
+                     # Each battle always makes 2 API calls (original + swapped).
+                     progress.total(2 * len(pending_indices))
+                 for idx in pending_indices:
+                     total_attempted += 1
+                     if _execute_one(dataset, model_a, model_b, idx):
+                         completed += 1
+     else:
+         max_inflight = max(1, num_threads * 4)
+         with ThreadPoolExecutor(max_workers=num_threads) as executor:
+             inflight = set()
+
+             def _drain_one() -> None:
+                 nonlocal completed
+                 done_future = next(as_completed(inflight))
+                 inflight.remove(done_future)
+                 try:
+                     ok = done_future.result()
+                     if ok:
+                         completed += 1
+                 except Exception:
+                     # Worker-level robustness: ignore individual battle failures.
+                     pass
+
+             for pf, indices in parquet_work:
+                 if not indices:
+                     continue
+                 total_indices += len(indices)
+                 dataset = ParquetDataset(data_dir, subset, parquet_files=[pf])
+                 for model_a, model_b in model_pairs:
+                     valid_indices = model_manager.validate_coverage(model_a, model_b, indices)
+                     first, second, _ = get_sorted_model_pair(model_a, model_b)
+                     pending_indices = [idx for idx in valid_indices if (first, second, idx) not in completed_set]
+                     if progress is not None:
+                         progress.total(2 * len(pending_indices))
+                     for idx in pending_indices:
+                         total_attempted += 1
+                         inflight.add(executor.submit(_execute_one, dataset, model_a, model_b, idx))
+                         if len(inflight) >= max_inflight:
+                             _drain_one()
+
+             while inflight:
+                 _drain_one()
+
+     if progress is not None:
+         progress.flush()
+
+     return {
+         "completed": completed,
+         "attempted": total_attempted,
+         "indices": total_indices,
+     }
+
+
+ def _start_calls_progress_consumer(
+     *,
+     enabled: bool,
+     total: Optional[int] = None,
+ ) -> tuple[Any, Optional[threading.Thread], Any]:
+     """
+     Start a progress consumer thread that reads integer increments from a queue.
+
+     Returns:
+         (progress_queue, thread, close_fn)
+     """
+     if not enabled:
+         return None, None, lambda: None
+
+     try:
+         from tqdm import tqdm  # type: ignore
+     except Exception:
+         logger.warning("tqdm is not available; progress bar disabled")
+         return None, None, lambda: None
+
+     q: thread_queue.Queue[Any] = thread_queue.Queue()
+     stop_sentinel = object()
+     bar = tqdm(total=total, unit="call", desc="API Calls", dynamic_ncols=True)
+     recent: deque[str] = deque(maxlen=10)
+
+     def _run() -> None:
+         while True:
+             item = q.get()
+             if item is stop_sentinel:
+                 break
+             if isinstance(item, (int, float)):
+                 try:
+                     bar.update(int(item))
+                 except Exception:
+                     pass
+             elif isinstance(item, tuple) and len(item) == 2 and item[0] == "log":
+                 try:
+                     recent.append(str(item[1]))
+                     bar.set_postfix_str(" | ".join(recent))
+                 except Exception:
+                     pass
+             elif isinstance(item, tuple) and len(item) == 2 and item[0] == "total":
+                 try:
+                     delta = int(item[1])
+                     if delta > 0:
+                         bar.total = (bar.total or 0) + delta
+                         bar.refresh()
+                 except Exception:
+                     pass
+             else:
+                 # Unknown item type, ignore.
+                 pass
+
+     t = threading.Thread(target=_run, name="calls-progress-consumer", daemon=True)
+     t.start()
+
+     def _close() -> None:
+         try:
+             q.put(stop_sentinel)
+         except Exception:
+             pass
+         if t is not None:
+             t.join(timeout=5)
+         try:
+             bar.close()
+         except Exception:
+             pass
+
+     return q, t, _close
+
+
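The consumer above defines a small ad-hoc queue protocol. A sketch of what producers may put on the returned queue, grounded in the branches of _run above (requires tqdm; with enabled=False or no tqdm the function returns (None, None, no-op)):

q, thread, close = _start_calls_progress_consumer(enabled=True, total=10)
if q is not None:
    q.put(2)                               # plain number: advance the bar by 2 calls
    q.put(("total", 4))                    # ("total", n): grow the bar's total by n
    q.put(("log", "endpoint 1 retrying"))  # ("log", msg): rolling postfix message
close()                                    # sends the stop sentinel, joins, closes the bar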
+ class Arena:
+     """
+     Arena coordinator for running pairwise model evaluations.
+
+     Manages:
+         - Subset directory structure
+         - Model discovery and output management
+         - Battle pair generation
+         - Checkpoint/resume functionality
+         - Parallel battle execution
+         - ELO state management
+         - Leaderboard generation
+     """
+
+     def __init__(self, config: ArenaConfig):
+         """
+         Initialize the arena.
+
+         Args:
+             config: ArenaConfig with all settings
+         """
+         self.config = config
+
+         # Set up paths
+         self.subset_dir = os.path.join(config.arena_dir, config.subset)
+         self.models_root_dir = os.path.join(self.subset_dir, "models")
+         self.pk_logs_dir = os.path.join(self.subset_dir, "pk_logs")
+         # Resolve experiment name (infer from models/ if not provided)
+         if config.exp_name is not None:
+             require_valid_exp_name(config.exp_name)
+         else:
+             config.exp_name = pick_latest_experiment_name(self.models_root_dir)
+
+         # In v2 layout, per-experiment model outputs live under: models/<exp_name>/<model>/...
+         self.models_dir = os.path.join(self.models_root_dir, config.exp_name)
+         if not os.path.isdir(self.models_dir):
+             raise ValueError(
+                 f"Experiment models directory does not exist: {self.models_dir}. "
+                 f"Expected `models/{config.exp_name}/<model_name>/...`."
+             )
+         self.exp_dir = os.path.join(self.pk_logs_dir, config.exp_name)
+         self.arena_state_dir = os.path.join(self.subset_dir, "arena")
+         self.state_path = os.path.join(self.arena_state_dir, "state.json")
+         self.leaderboard_path = os.path.join(self.subset_dir, "README.md")
+
+         # Initialize directories
+         self._init_directories()
+
+         # Load components
+         self.prompt_module = load_prompt(config.prompt)
+         # In multiprocessing mode, we only need fast index scanning in the parent
+         # process (full data is loaded per-parquet inside workers).
+         load_mode = "index_only" if config.num_processes > 1 else "full"
+         self.dataset = ParquetDataset(config.data_dir, config.subset, load_mode=load_mode)
+         # Global model registry (v2 layout): models/<exp_name>/<model>/...
+         self.model_manager = GlobalModelOutputManager(self.models_root_dir)
+
+         # Models that are newly introduced in this experiment (directory listing)
+         self.new_models = self.model_manager.get_experiment_models(config.exp_name)
+
+         # Parse experiment date for filtering eligible opponents
+         self.exp_date = parse_exp_date_suffix(config.exp_name)
+
+         # Resolve selected model universe for this run
+         # When running an old experiment, only consider models from experiments
+         # with date <= this experiment's date (to avoid battling "future" models).
+         if config.models:
+             self.models = [m for m in config.models if self.model_manager.has_model(m)]
+         elif self.exp_date is not None:
+             # Filter to models from experiments up to this experiment's date
+             self.models = self.model_manager.get_models_up_to_date(self.exp_date)
+         else:
+             self.models = self.model_manager.models
+
+         # Canonical "current models on disk" (used for state/log cleanup even when --models is used)
+         self.all_models = self.model_manager.models
+
+         # Initialize loggers
+         self.battle_logger = BattleLogger(self.exp_dir)
+         self.audit_logger = AuditLogger(self.exp_dir) if config.enable_audit_log else None
+
+         # Initialize VLM judge with multi-endpoint support
+         self.vlm = VLMJudge(
+             model=config.judge_model,
+             temperature=config.temperature,
+             timeout=config.timeout,
+             max_retries=config.max_retries,
+             base_urls=config.base_urls,
+             api_keys=config.api_keys,
+         )
+
+         # Save experiment config
+         self._save_config()
+         self._progress_queue = None
+
+     def _init_directories(self) -> None:
+         """Create necessary directory structure."""
+         ensure_dir(self.subset_dir)
+         ensure_dir(self.models_root_dir)
+         ensure_dir(self.pk_logs_dir)
+         ensure_dir(self.exp_dir)
+         ensure_dir(self.arena_state_dir)
+
+         if self.config.enable_audit_log:
+             ensure_dir(os.path.join(self.exp_dir, "raw_outputs"))
+
+     def _save_config(self) -> None:
+         """Save experiment configuration."""
+         config_path = os.path.join(self.exp_dir, "config.json")
+         history_path = os.path.join(self.exp_dir, "config_history.json")
+
+         config_dict = self.config.to_dict()
+         config_dict["models_actual"] = self.models
+
+         # If config exists, append to history
+         if os.path.isfile(config_path):
+             # Read existing config and append to history
+             try:
+                 with open(config_path, "r", encoding="utf-8") as f:
+                     existing = json.load(f)
+
+                 # Load or create history
+                 history = []
+                 if os.path.isfile(history_path):
+                     with open(history_path, "r", encoding="utf-8") as f:
+                         history = json.load(f)
+
+                 history.append(existing)
+
+                 with open(history_path, "w", encoding="utf-8") as f:
+                     json.dump(history, f, indent=2, ensure_ascii=False)
+             except Exception:
+                 pass
+
+         # Write current config
+         with open(config_path, "w", encoding="utf-8") as f:
+             json.dump(config_dict, f, indent=2, ensure_ascii=False)
+
+     def _sync_state_with_models(self) -> bool:
+         """
+         Synchronize arena state with current available models.
+
+         If models have been removed from the models directory, this method will:
+         1. Detect removed models (from both state and pk_logs)
+         2. Move battle logs involving removed models to .pk_logs_rm/ (if clean_orphaned_logs=True)
+         3. Rebuild ELO state from remaining battle logs
+         4. Save the updated state
+
+         Returns:
+             True if state was rebuilt due to model changes, False otherwise
+         """
+         state = load_state(self.state_path)
+         # Use the canonical on-disk model set (do NOT treat --models filter as removals)
+         current_models = set(self.all_models)
+
+         # Get models that exist in state but not in current model list
+         state_models = set(state.models.keys())
+         removed_from_state = state_models - current_models
+
+         # Also scan pk_logs to find models that exist in logs but not in models/
+         logs_models = self._scan_models_from_logs()
+         removed_from_logs = logs_models - current_models
+
+         # Combine both sources of removed models
+         removed_models = removed_from_state | removed_from_logs
+
+         if not removed_models:
+             return False
+
+         logger.info(
+             f"Detected removed models: {removed_models}. "
+             f"Rebuilding ELO state from battle logs..."
+         )
+
+         # Clean up orphaned battle logs if enabled
+         if self.config.clean_orphaned_logs:
+             self._delete_orphaned_logs(removed_models)
+
+         # Rebuild state from logs, only including current models
+         new_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+
+         # Save the rebuilt state
+         save_state(new_state, self.state_path)
+
+         logger.info(
+             f"State rebuilt: {new_state.total_battles} battles, "
+             f"{len(new_state.models)} models"
+         )
+
+         return True
+
+     def _scan_models_from_logs(self) -> set[str]:
+         """
+         Scan all battle log files to extract model names.
+
+         This method reads the actual content of jsonl files to get the original
+         model names, which is more reliable than parsing sanitized filenames.
+
+         Returns:
+             Set of all model names found in battle logs
+         """
+         models_found: set[str] = set()
+
+         if not os.path.isdir(self.pk_logs_dir):
+             return models_found
+
+         for exp_name in os.listdir(self.pk_logs_dir):
+             exp_dir = os.path.join(self.pk_logs_dir, exp_name)
+             if not os.path.isdir(exp_dir):
+                 continue
+
+             for filename in os.listdir(exp_dir):
+                 if not filename.endswith(".jsonl"):
+                     continue
+
+                 filepath = os.path.join(exp_dir, filename)
+                 if not os.path.isfile(filepath):
+                     continue
+
+                 # Read first line to extract model names (all lines in a file
+                 # should have the same model pair)
+                 try:
+                     with open(filepath, "r", encoding="utf-8") as f:
+                         for line in f:
+                             line = line.strip()
+                             if not line:
+                                 continue
+                             try:
+                                 record = json.loads(line)
+                                 model_a = record.get("model_a", "")
+                                 model_b = record.get("model_b", "")
+                                 if model_a:
+                                     models_found.add(model_a)
+                                 if model_b:
+                                     models_found.add(model_b)
+                                 # Only need first valid line per file
+                                 break
+                             except json.JSONDecodeError:
+                                 continue
+                 except Exception:
+                     pass
+
+         return models_found
+
+     def _delete_orphaned_logs(self, removed_models: set[str]) -> None:
+         """
+         Move battle log files that involve removed models to .pk_logs_rm/ directory.
+
+         This method reads the actual content of each jsonl file to extract the
+         original model names, which is more reliable than parsing sanitized
+         filenames. Instead of deleting, files are moved to a backup directory
+         (.pk_logs_rm/) at the same level as pk_logs/.
+
+         Args:
+             removed_models: Set of model names that have been removed
+         """
+         import shutil
+
+         if not os.path.isdir(self.pk_logs_dir):
+             return
+
+         # Create backup directory at the same level as pk_logs
+         pk_logs_rm_dir = os.path.join(self.subset_dir, ".pk_logs_rm")
+
+         moved_count = 0
+
+         def _file_involves_removed_model(filepath: str) -> bool:
+             """
+             Check if a jsonl file involves any removed model by reading its content.
+
+             Returns True if any record in the file has model_a or model_b in removed_models.
+             """
+             try:
+                 with open(filepath, "r", encoding="utf-8") as f:
+                     for line in f:
+                         line = line.strip()
+                         if not line:
+                             continue
+                         try:
+                             record = json.loads(line)
+                             model_a = record.get("model_a", "")
+                             model_b = record.get("model_b", "")
+                             if model_a in removed_models or model_b in removed_models:
+                                 return True
+                         except json.JSONDecodeError:
+                             continue
+             except Exception:
+                 pass
+             return False
+
+         def _move_to_backup(filepath: str, relative_path: str) -> bool:
+             """
+             Move a file to the backup directory, preserving relative path structure.
+
+             Args:
+                 filepath: Absolute path to the source file
+                 relative_path: Relative path from pk_logs_dir (e.g., "exp_name/file.jsonl")
+
+             Returns:
+                 True if moved successfully, False otherwise
+             """
+             dest_path = os.path.join(pk_logs_rm_dir, relative_path)
+             dest_dir = os.path.dirname(dest_path)
+
+             try:
+                 ensure_dir(dest_dir)
+                 shutil.move(filepath, dest_path)
+                 return True
+             except Exception as e:
+                 logger.warning(f"Failed to move {filepath} to {dest_path}: {e}")
+                 return False
+
+         # Iterate over all experiment directories
+         for exp_name in os.listdir(self.pk_logs_dir):
+             exp_dir = os.path.join(self.pk_logs_dir, exp_name)
+             if not os.path.isdir(exp_dir):
+                 continue
+
+             # Check battle log files (format: model_a_vs_model_b.jsonl)
+             for filename in os.listdir(exp_dir):
+                 if not filename.endswith(".jsonl"):
+                     continue
+
+                 filepath = os.path.join(exp_dir, filename)
+                 if not os.path.isfile(filepath):
+                     continue
+
+                 # Check file content to determine if it involves removed models
+                 if _file_involves_removed_model(filepath):
+                     relative_path = os.path.join(exp_name, filename)
+                     if _move_to_backup(filepath, relative_path):
+                         moved_count += 1
+                         logger.debug(f"Moved orphaned log to backup: {filepath}")
+
+             # Also check raw_outputs subdirectory
+             raw_outputs_dir = os.path.join(exp_dir, "raw_outputs")
+             if os.path.isdir(raw_outputs_dir):
+                 for filename in os.listdir(raw_outputs_dir):
+                     if not filename.endswith(".jsonl"):
+                         continue
+
+                     filepath = os.path.join(raw_outputs_dir, filename)
+                     if not os.path.isfile(filepath):
+                         continue
+
+                     if _file_involves_removed_model(filepath):
+                         relative_path = os.path.join(exp_name, "raw_outputs", filename)
+                         if _move_to_backup(filepath, relative_path):
+                             moved_count += 1
+                             logger.debug(f"Moved orphaned audit log to backup: {filepath}")
+
+         if moved_count > 0:
+             logger.info(f"Moved {moved_count} orphaned battle log files to {pk_logs_rm_dir}")
+
+     def _generate_battle_pairs(self) -> list[BattlePair]:
+         """
+         Generate all battle pairs to execute.
+
+         In full mode: generates all possible pairs up to sample_size.
+         In adaptive mode: generates pairs based on sampling config, respecting
+         min_samples and max_samples per model pair.
+
+         Returns:
+             List of BattlePair objects
+         """
+         pairs = []
+
+         # Get all dataset indices
+         all_indices = self.dataset.get_all_indices()
+
+         # In full mode, apply global sample_size limit
+         # In adaptive mode, we apply per-pair limits later
+         if self.config.sampling.mode == "full":
+             if self.config.sample_size and self.config.sample_size < len(all_indices):
+                 indices = random.sample(all_indices, self.config.sample_size)
+             else:
+                 indices = all_indices
+         else:
+             # Adaptive mode: use all indices, will limit per-pair
+             indices = all_indices
+
+         # Generate model pairs to run for this exp:
+         # - only include pairs where at least one side is a "new model" in this exp
+         # - but respect the user-provided --models filter (self.models)
+         selected_models = set(self.models)
+         new_models = [m for m in self.new_models if m in selected_models]
+
+         if not new_models:
+             return []
+
+         # Build unique pair set (sorted) for: new-vs-all + new-vs-new
+         pair_set: set[tuple[str, str]] = set()
+         for m in new_models:
+             for other in selected_models:
+                 if other == m:
+                     continue
+                 a, b, _ = get_sorted_model_pair(m, other)
+                 pair_set.add((a, b))
+
+         model_pairs = sorted(pair_set)
+
+         # Load existing battle counts for adaptive mode
+         if self.config.sampling.mode == "adaptive":
+             existing_counts = count_battles_per_pair(self.pk_logs_dir)
+             # Determine target samples per pair based on experiment type
+             if is_milestone_exp(self.config.exp_name or ""):
+                 target_samples = self.config.sampling.milestone_min_samples
+             else:
+                 target_samples = self.config.sampling.min_samples
+         else:
+             existing_counts = {}
+             target_samples = None
+
+         # Generate battle pairs for each model pair and sample
+         for model_a, model_b in model_pairs:
+             # Validate coverage
+             valid_indices = self.model_manager.validate_coverage(
+                 model_a, model_b, indices
+             )
+
+             # In adaptive mode, limit samples per pair
+             if self.config.sampling.mode == "adaptive" and target_samples is not None:
+                 key = (min(model_a, model_b), max(model_a, model_b))
+                 existing = existing_counts.get(key, 0)
+                 needed = max(0, target_samples - existing)
+
+                 if needed == 0:
+                     continue  # This pair already has enough samples
+
+                 # Limit to needed samples (randomly select if more available)
+                 if len(valid_indices) > needed:
+                     valid_indices = random.sample(valid_indices, needed)
+
+             for idx in valid_indices:
+                 pairs.append(BattlePair(
+                     model_a=model_a,
+                     model_b=model_b,
+                     sample_index=idx
+                 ))
+
+         return pairs
+
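A worked example of the pair-generation rule above (assuming get_sorted_model_pair orders names lexicographically, which the sorted lookup keys elsewhere in this file suggest):

selected_models = {"model-a", "model-b", "model-c"}
new_models = ["model-c"]  # only model-c is new in this experiment

pair_set = set()
for m in new_models:
    for other in selected_models:
        if other != m:
            pair_set.add(tuple(sorted((m, other))))

print(sorted(pair_set))
# [('model-a', 'model-c'), ('model-b', 'model-c')]
# i.e. new-vs-all: pairs between two pre-existing models are never regenerated.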
+     def _skip_completed(
+         self,
+         pairs: list[BattlePair]
+     ) -> list[BattlePair]:
+         """
+         Filter out already completed battles.
+
+         Only considers battles where both models still exist in the current
+         model list. Battles involving removed models are ignored.
+
+         Args:
+             pairs: List of battle pairs
+
+         Returns:
+             Filtered list excluding completed battles
+         """
+         all_completed = load_battle_history(self.pk_logs_dir)
+
+         # Filter completed battles to only include those with current on-disk models.
+         # This avoids treating --models filters as removals.
+         current_models = set(self.all_models)
+         completed = {
+             (m_a, m_b, idx)
+             for m_a, m_b, idx in all_completed
+             if m_a in current_models and m_b in current_models
+         }
+
+         remaining = []
+         for pair in pairs:
+             # Get sorted model names for lookup
+             first, second, _ = get_sorted_model_pair(pair.model_a, pair.model_b)
+             key = (first, second, pair.sample_index)
+
+             if key not in completed:
+                 remaining.append(pair)
+
+         skipped = len(pairs) - len(remaining)
+         if skipped > 0:
+             logger.info(f"Skipping {skipped} already completed battles")
+
+         # Log if there are orphaned battles from removed models
+         orphaned = len(all_completed) - len(completed)
+         if orphaned > 0:
+             logger.info(
+                 f"Ignoring {orphaned} battle records involving removed models"
+             )
+
+         return remaining
+
+     def _execute_single_battle(
+         self,
+         pair: BattlePair
+     ) -> Optional[BattleResult]:
+         """
+         Execute a single battle.
+
+         Args:
+             pair: BattlePair to execute
+
+         Returns:
+             BattleResult or None if failed
+         """
+         try:
+             # Get sample data
+             sample = self.dataset.get_by_index(pair.sample_index)
+             if sample is None:
+                 logger.warning(
+                     f"Sample {pair.sample_index} not found in dataset"
+                 )
+                 return None
+
+             # Get model outputs
+             output_a = self.model_manager.get_output_path(
+                 pair.model_a, pair.sample_index
+             )
+             output_b = self.model_manager.get_output_path(
+                 pair.model_b, pair.sample_index
+             )
+
+             if output_a is None or output_b is None:
+                 logger.warning(
+                     f"Missing output for battle {pair.model_a} vs {pair.model_b} "
+                     f"at index {pair.sample_index}"
+                 )
+                 return None
+
+             # Execute battle
+             result = execute_battle(
+                 vlm=self.vlm,
+                 prompt_module=self.prompt_module,
+                 sample=sample,
+                 model_a_output=output_a,
+                 model_b_output=output_b,
+                 model_a=pair.model_a,
+                 model_b=pair.model_b,
+                 parallel_swap_calls=self.config.parallel_swap_calls,
+                 progress=self._progress_queue,
+             )
+
+             return result
+
+         except Exception as e:
+             logger.error(
+                 f"Error executing battle {pair.model_a} vs {pair.model_b} "
+                 f"at index {pair.sample_index}: {e}"
+             )
+             return None
+
+     def _process_result(
+         self,
+         result: BattleResult,
+         state: ArenaState
+     ) -> ArenaState:
+         """
+         Process a battle result: log and update state.
+
+         Args:
+             result: BattleResult from battle execution
+             state: Current arena state
+
+         Returns:
+             Updated arena state
+         """
+         # Log battle result (slim)
+         self.battle_logger.log_battle_result(result)
+
+         # Log audit trail (detailed)
+         if self.audit_logger:
+             self.audit_logger.log_battle_result(result)
+
+         # Update W/L/T stats only. Elo is recomputed via Bradley-Terry fitting
+         # from accumulated battle logs (order-independent).
+         state = update_stats(state, result.model_a, result.model_b, result.final_winner)
+
+         return state
+
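_process_result only tallies W/L/T; ratings come from a Bradley-Terry fit over the logs. For reference, the standard Bradley-Terry win probability under the usual Elo parameterization (a textbook formula, not a quote of genarena's bt_elo module):

def bt_win_prob(elo_a: float, elo_b: float) -> float:
    """P(model_a beats model_b) under the Elo-scaled Bradley-Terry model."""
    return 1.0 / (1.0 + 10.0 ** ((elo_b - elo_a) / 400.0))

bt_win_prob(1100, 1000)  # ~0.64: a 100-point gap wins about 64% of the time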
+     def _get_battles_from_logs(self) -> list[BattleTuple]:
+         """
+         Load battle records from logs and convert to BattleTuple format.
+
+         Returns:
+             List of (model_a, model_b, winner) tuples for BT-Elo computation.
+         """
+         records = load_battle_records(self.pk_logs_dir)
+         battles: list[BattleTuple] = []
+
+         current_models = set(self.all_models)
+
+         for record in records:
+             model_a = record.get("model_a", "")
+             model_b = record.get("model_b", "")
+             final_winner = record.get("final_winner", "")
+
+             # Skip records involving removed models
+             if model_a not in current_models or model_b not in current_models:
+                 continue
+
+             # Convert winner to standard format
+             if final_winner == model_a:
+                 winner = "model_a"
+             elif final_winner == model_b:
+                 winner = "model_b"
+             elif final_winner == "tie":
+                 winner = "tie"
+             else:
+                 continue  # Skip invalid records
+
+             battles.append((model_a, model_b, winner))
+
+         return battles
+
+     def _load_anchor_elo(self) -> dict[str, float]:
+         """
+         Load anchor ELO ratings from the latest milestone snapshot.
+
+         Returns:
+             Dict mapping model name to ELO rating for milestone models,
+             or empty dict if no milestone exists.
+         """
+         # Discover milestone experiments
+         exp_keys: list[tuple[tuple, str]] = []
+         if not os.path.isdir(self.pk_logs_dir):
+             return {}
+
+         for name in os.listdir(self.pk_logs_dir):
+             if name.startswith("."):
+                 continue
+             exp_dir = os.path.join(self.pk_logs_dir, name)
+             if not os.path.isdir(exp_dir):
+                 continue
+             d = parse_exp_date_suffix(name)
+             if d is None:
+                 continue
+             exp_keys.append(((d, name), name))
+
+         exp_keys.sort(key=lambda x: x[0])
+
+         # Find milestones
+         milestones = [name for (key, name) in exp_keys if is_milestone_exp(name)]
+         if not milestones:
+             return {}
+
+         # Load from latest milestone snapshot
+         latest_milestone = milestones[-1]
+         snapshot_path = os.path.join(self.pk_logs_dir, latest_milestone, "elo_snapshot.json")
+
+         if not os.path.isfile(snapshot_path):
+             return {}
+
+         try:
+             with open(snapshot_path, "r", encoding="utf-8") as f:
+                 data = json.load(f)
+         except Exception:
+             return {}
+
+         if not isinstance(data, dict):
+             return {}
+
+         # Accept either: {"elo": {...}} or a direct {model: elo} mapping
+         raw = data.get("elo") if isinstance(data.get("elo"), dict) else data
+         if not isinstance(raw, dict):
+             return {}
+
+         # Filter to only include models that exist in current model set
+         current_models = set(self.all_models)
+         anchor_elo: dict[str, float] = {}
+         for k, v in raw.items():
+             if str(k) in current_models:
+                 try:
+                     anchor_elo[str(k)] = float(v)
+                 except Exception:
+                     continue
+
+         return anchor_elo
+
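Both snapshot layouts accepted by _load_anchor_elo, with illustrative model names and ratings:

import json

# Wrapped form: pk_logs/<milestone_exp>/elo_snapshot.json
wrapped = json.loads('{"elo": {"model-x": 1032.5, "model-y": 987.1}}')
# Flat form: a direct {model: elo} mapping
flat = json.loads('{"model-x": 1032.5, "model-y": 987.1}')

# Mirror of the unwrapping logic above:
raw = wrapped.get("elo") if isinstance(wrapped.get("elo"), dict) else wrapped
assert raw == flat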
+     def _run_adaptive_with_ci_checking(self) -> ArenaState:
+         """
+         Run arena evaluation with adaptive CI-based sampling.
+
+         This method implements the iterative loop:
+         1. Run initial batch (min_samples per pair)
+         2. Compute bootstrap CI
+         3. If max CI width > target, add batch_size more samples to unconverged pairs
+         4. Repeat until all pairs converge or reach max_samples
+
+         Returns:
+             Final ArenaState after all battles
+         """
+         sampling_config = self.config.sampling
+         is_milestone = is_milestone_exp(self.config.exp_name or "")
+
+         # Determine target samples per pair for initial batch
+         if is_milestone:
+             target_samples = sampling_config.milestone_min_samples
+             logger.info(f"Milestone experiment: targeting {target_samples} samples/pair initially")
+         else:
+             target_samples = sampling_config.min_samples
+             logger.info(f"Incremental experiment: targeting {target_samples} samples/pair initially")
+
+         # Get all dataset indices
+         all_indices = self.dataset.get_all_indices()
+
+         # Build model pairs (new models vs all selected models)
+         selected_models = set(self.models)
+         new_models = [m for m in self.new_models if m in selected_models]
+
+         if not new_models:
+             logger.info("No new models to evaluate")
+             return rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+
+         # Build unique pair set
+         pair_set: set[tuple[str, str]] = set()
+         for m in new_models:
+             for other in selected_models:
+                 if other == m:
+                     continue
+                 a, b, _ = get_sorted_model_pair(m, other)
+                 pair_set.add((a, b))
+
+         model_pairs = sorted(pair_set)
+         logger.info(f"Evaluating {len(model_pairs)} model pairs with adaptive sampling")
+
+         # Initialize scheduler
+         scheduler = AdaptiveSamplingScheduler(config=sampling_config)
+
+         # Load existing battle counts
+         existing_counts = count_battles_per_pair(self.pk_logs_dir)
+         for pair in model_pairs:
+             count = existing_counts.get(pair, 0)
+             scheduler.update_state(pair[0], pair[1], current_samples=count)
+
+         # Load existing state
+         state = load_state(self.state_path)
+
+         # Progress tracking
+         progress_queue, _progress_thread, progress_close = _start_calls_progress_consumer(
+             enabled=self.config.enable_progress_bar,
+             total=None,  # Dynamic total
+         )
+         self._progress_queue = progress_queue
+         if self._progress_queue is not None:
+             try:
+                 self.vlm.set_progress(self._progress_queue)
+             except Exception:
+                 pass
+
+         iteration = 0
+         total_completed = 0
+
+         while True:
+             iteration += 1
+
+             # Determine which pairs need more samples
+             pairs_to_run: list[tuple[str, str]] = []
+             samples_per_pair: dict[tuple[str, str], int] = {}
+
+             for pair in model_pairs:
+                 pair_state = scheduler.get_or_create_state(pair[0], pair[1])
+                 samples_to_run = pair_state.get_samples_to_run(sampling_config, len(all_indices))
+
+                 if samples_to_run > 0:
+                     pairs_to_run.append(pair)
+                     samples_per_pair[pair] = samples_to_run
+
+             if not pairs_to_run:
+                 logger.info("All pairs have converged or reached max_samples")
+                 break
+
+             total_samples_this_iter = sum(samples_per_pair.values())
+             logger.info(
+                 f"Iteration {iteration}: running {total_samples_this_iter} battles "
+                 f"across {len(pairs_to_run)} pairs"
+             )
+
+             # Generate battle pairs for this iteration
+             completed_set = load_battle_history(self.pk_logs_dir)
+             battle_pairs: list[BattlePair] = []
+
+             for pair in pairs_to_run:
+                 model_a, model_b = pair
+                 needed = samples_per_pair[pair]
+
+                 # Get valid indices for this pair
+                 valid_indices = self.model_manager.validate_coverage(model_a, model_b, all_indices)
+
+                 # Filter out already completed
+                 pending_indices = [
+                     idx for idx in valid_indices
+                     if (model_a, model_b, idx) not in completed_set
+                 ]
+
+                 # Select up to 'needed' samples
+                 if len(pending_indices) > needed:
+                     selected = random.sample(pending_indices, needed)
+                 else:
+                     selected = pending_indices
+
+                 for idx in selected:
+                     battle_pairs.append(BattlePair(
+                         model_a=model_a,
+                         model_b=model_b,
+                         sample_index=idx
+                     ))
+
+             if not battle_pairs:
+                 logger.info("No more battles to execute")
+                 break
+
+             # Update progress bar total
+             if self._progress_queue is not None:
+                 try:
+                     self._progress_queue.put(("total", 2 * len(battle_pairs)))
+                 except Exception:
+                     pass
+
+             # Execute battles
+             iter_completed = 0
+
+             if self.config.num_threads <= 1:
+                 # Sequential execution
+                 for pair in battle_pairs:
+                     result = self._execute_single_battle(pair)
+                     if result:
+                         state = self._process_result(result, state)
+                         iter_completed += 1
+             else:
+                 # Parallel execution
+                 with ThreadPoolExecutor(max_workers=self.config.num_threads) as executor:
+                     future_to_pair = {
+                         executor.submit(self._execute_single_battle, pair): pair
+                         for pair in battle_pairs
+                     }
+
+                     for future in as_completed(future_to_pair):
+                         try:
+                             result = future.result()
+                             if result:
+                                 state = self._process_result(result, state)
+                                 iter_completed += 1
+                         except Exception as e:
+                             pair = future_to_pair[future]
+                             logger.error(f"Battle {pair.model_a} vs {pair.model_b} failed: {e}")
+
+             total_completed += iter_completed
+             logger.info(f"Iteration {iteration} completed: {iter_completed} battles")
+
+             # Save intermediate state
+             save_state(state, self.state_path)
+
+             # Update scheduler with new counts
+             new_counts = count_battles_per_pair(self.pk_logs_dir)
+             for pair in model_pairs:
+                 count = new_counts.get(pair, 0)
+                 scheduler.update_state(pair[0], pair[1], current_samples=count)
+
+             # Compute bootstrap CI to check convergence
+             battles = self._get_battles_from_logs()
+             if battles:
+                 # Load anchor ELO from latest milestone snapshot
+                 # Milestone models have fixed ELO, so we only check CI for new models
+                 anchor_elo = self._load_anchor_elo()
+
+                 bootstrap_result = compute_bootstrap_bt_elo(
+                     battles,
+                     models=self.all_models,
+                     fixed_ratings=anchor_elo if anchor_elo else None,
+                     num_bootstrap=sampling_config.num_bootstrap,
+                 )
+
+                 # Only check CI for new models (non-anchor models)
+                 # Anchor models have CI width = 0 since their ELO is fixed
+                 new_models_set = set(new_models)
+                 new_model_ci_widths = [
+                     bootstrap_result.ci_width.get(m, 0.0)
+                     for m in new_models_set
+                     if m in bootstrap_result.ci_width
+                 ]
+
+                 if new_model_ci_widths:
+                     max_ci_width = max(new_model_ci_widths)
+                     mean_ci_width = sum(new_model_ci_widths) / len(new_model_ci_widths)
+                 else:
+                     max_ci_width = bootstrap_result.get_max_ci_width()
+                     mean_ci_width = bootstrap_result.get_mean_ci_width()
+
+                 logger.info(
+                     f"CI check (new models only): max_width={max_ci_width:.2f}, "
+                     f"mean_width={mean_ci_width:.2f}, target={sampling_config.target_ci_width:.2f}"
+                 )
+
+                 # Check if all new models have converged
+                 if max_ci_width <= sampling_config.target_ci_width:
+                     logger.info(f"CI target reached! Max CI width for new models: {max_ci_width:.2f}")
+                     # Mark all pairs as converged
+                     for pair in model_pairs:
+                         pair_state = scheduler.get_or_create_state(pair[0], pair[1])
+                         pair_state.converged = True
+                     break
+
+             # Check if all pairs have reached max_samples
+             all_maxed = True
+             for pair in model_pairs:
+                 pair_state = scheduler.get_or_create_state(pair[0], pair[1])
+                 if pair_state.current_samples < sampling_config.max_samples:
+                     all_maxed = False
+                     break
+
+             if all_maxed:
+                 logger.info("All pairs reached max_samples limit")
+                 break
+
+         progress_close()
+
+         # Final summary
+         summary = scheduler.get_summary()
+         logger.info(
+             f"Adaptive sampling complete: "
+             f"{summary['total_pairs']} pairs, "
+             f"{summary['converged_pairs']} converged, "
+             f"{summary['maxed_pairs']} reached max_samples, "
+             f"{summary['total_samples']} total samples"
+         )
+
+         # Final Elo recompute (Bradley-Terry) and state save
+         final_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+         save_state(final_state, self.state_path)
+
+         logger.info(f"Arena completed: {total_completed} battles executed in {iteration} iterations")
+
+         return final_state
+
+     def run(self) -> ArenaState:
+         """
+         Run the arena evaluation.
+
+         If models have been removed from the arena directory, the ELO state
+         will be automatically rebuilt from battle logs (excluding removed models).
+
+         Returns:
+             Final ArenaState after all battles
+         """
+         # Sync state with current models (rebuild if models were removed).
+         # This rebuild uses Bradley-Terry Elo scoring from logs.
+         self._sync_state_with_models()
+
+         # Use adaptive CI-checking mode if enabled (and not multiprocessing)
+         if (self.config.sampling.mode == "adaptive" and
+                 self.config.num_processes <= 1):
+             return self._run_adaptive_with_ci_checking()
+
+         # Generate and filter battle pairs
+         # If we can shard by parquet file, we can avoid constructing the full pair list
+         # in the parent process (and avoid pickling huge lists).
+         all_indices = self.dataset.get_all_indices()
+
+         # Apply sample size limit
+         if self.config.sample_size and self.config.sample_size < len(all_indices):
+             indices = random.sample(all_indices, self.config.sample_size)
+         else:
+             indices = all_indices
+
+         # If num_processes <= 1, fall back to the original thread-based implementation.
+         if self.config.num_processes <= 1:
+             all_pairs = self._generate_battle_pairs()
+             pairs = self._skip_completed(all_pairs)
+
+             if not pairs:
+                 logger.info("No battles to execute")
+                 # Ensure state is up-to-date and order-independent
+                 state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+                 save_state(state, self.state_path)
+                 return state
+
+             logger.info(f"Starting arena with {len(pairs)} battles to execute")
+             logger.info(f"Models: {self.models}")
+             logger.info(f"Experiment: {self.config.exp_name}")
+             logger.info("Sampling mode: full")
+
+             # Load existing state
+             state = load_state(self.state_path)
+
+             # Progress tracking
+             completed = 0
+             total = len(pairs)
+
+             progress_queue, _progress_thread, progress_close = _start_calls_progress_consumer(
+                 enabled=self.config.enable_progress_bar,
+                 total=(2 * len(pairs)) if self.config.enable_progress_bar else None,
+             )
+             self._progress_queue = progress_queue
+             if self._progress_queue is not None:
+                 try:
+                     self.vlm.set_progress(self._progress_queue)
+                 except Exception:
+                     pass
+
+             if self.config.num_threads <= 1:
+                 # Sequential execution
+                 for pair in pairs:
+                     result = self._execute_single_battle(pair)
+
+                     if result:
+                         state = self._process_result(result, state)
+                         completed += 1
+
+                         # Progress logging every 10 battles
+                         if completed % 10 == 0:
+                             logger.info(f"Progress: {completed}/{total} battles")
+                             # Save intermediate state
+                             save_state(state, self.state_path)
+             else:
+                 # Parallel execution
+                 with ThreadPoolExecutor(max_workers=self.config.num_threads) as executor:
+                     # Submit all battles
+                     future_to_pair = {
+                         executor.submit(self._execute_single_battle, pair): pair
+                         for pair in pairs
+                     }
+
+                     # Process completed futures
+                     for future in as_completed(future_to_pair):
+                         pair = future_to_pair[future]
+
+                         try:
+                             result = future.result()
+
+                             if result:
+                                 state = self._process_result(result, state)
+                                 completed += 1
+
+                                 # Progress logging every 10 battles
+                                 if completed % 10 == 0:
+                                     logger.info(f"Progress: {completed}/{total} battles")
+                                     # Save intermediate state
+                                     save_state(state, self.state_path)
+                         except Exception as e:
+                             logger.error(
+                                 f"Battle {pair.model_a} vs {pair.model_b} failed: {e}"
+                             )
+
+             progress_close()
+
+             # Final Elo recompute (Bradley-Terry) and state save
+             final_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+             save_state(final_state, self.state_path)
+
+             logger.info(f"Arena completed: {completed}/{total} battles executed")
+
+             return final_state
+
+         # === Multiprocessing path (per-parquet sharding) ===
+         grouped = self.dataset.group_indices_by_parquet(indices)
+         if "" in grouped:
+             logger.warning(
+                 "Parquet source mapping is incomplete (missing index->parquet mapping). "
+                 "Falling back to single-process execution."
+             )
+             self.config.num_processes = 1
+             # Re-load dataset in full mode for single-process execution.
+             self.dataset = ParquetDataset(self.config.data_dir, self.config.subset, load_mode="full")
+             return self.run()
+
+         total_concurrency = max(1, int(self.config.num_processes)) * max(1, int(self.config.num_threads))
+         logger.info(
+             f"Starting arena with multiprocessing: num_processes={self.config.num_processes}, "
+             f"num_threads={self.config.num_threads}, total_concurrency~{total_concurrency}"
+         )
+         logger.info(f"Models: {self.models}")
+         logger.info(f"Experiment: {self.config.exp_name}")
+
+         completed = 0
+         attempted = 0
+         parquet_tasks = [(pf, idxs) for pf, idxs in grouped.items() if idxs]
+         parquet_tasks.sort(key=lambda x: x[0])
+
+         # Assign parquet files to processes up-front (avoid per-parquet re-init overhead in a worker).
+         # Simple greedy bin-packing by number of indices for load balancing.
+         num_workers = max(1, int(self.config.num_processes))
+         buckets: list[list[tuple[str, list[int]]]] = [[] for _ in range(num_workers)]
+         bucket_sizes = [0 for _ in range(num_workers)]
+         for pf, idxs in sorted(parquet_tasks, key=lambda x: len(x[1]), reverse=True):
+             k = bucket_sizes.index(min(bucket_sizes))
+             buckets[k].append((pf, idxs))
+             bucket_sizes[k] += len(idxs)
+
+         # Progress consumer (optional). Use a process-safe Manager queue and batch updates in workers.
+         manager = None
+         mp_progress_queue = None
+         progress_close = lambda: None
+         if self.config.enable_progress_bar:
+             try:
+                 import multiprocessing
+                 manager = multiprocessing.Manager()
+                 mp_progress_queue = manager.Queue()
+                 # Capture the queue reference for the closure (type narrowing)
+                 _queue = mp_progress_queue
+                 # Reuse same tqdm consumer code by wrapping manager queue into a local consumer thread.
+                 try:
+                     from tqdm import tqdm  # type: ignore
+                     bar = tqdm(total=None, unit="call", desc="API Calls", dynamic_ncols=True)
+                     # Must be picklable across processes.
+                     stop_sentinel = ("stop", None)
+                     recent: deque[str] = deque(maxlen=10)
+
+                     def _mp_consumer() -> None:
+                         while True:
+                             item = _queue.get()
+                             if item == stop_sentinel:
+                                 break
+                             if isinstance(item, (int, float)):
+                                 try:
+                                     bar.update(int(item))
+                                 except Exception:
+                                     pass
+                             elif isinstance(item, tuple) and len(item) == 2 and item[0] == "log":
+                                 try:
+                                     recent.append(str(item[1]))
+                                     bar.set_postfix_str(" | ".join(recent))
+                                 except Exception:
+                                     pass
+                             elif isinstance(item, tuple) and len(item) == 2 and item[0] == "total":
+                                 try:
+                                     delta = int(item[1])
+                                     if delta > 0:
+                                         bar.total = (bar.total or 0) + delta
+                                         bar.refresh()
+                                 except Exception:
+                                     pass
+                             else:
+                                 pass
+
+                     t = threading.Thread(target=_mp_consumer, name="mp-calls-progress-consumer", daemon=True)
+                     t.start()
+
+                     def progress_close() -> None:
+                         try:
+                             mp_progress_queue.put(stop_sentinel)
+                         except Exception:
+                             pass
+                         t.join(timeout=5)
+                         try:
+                             bar.close()
+                         except Exception:
+                             pass
+
+                 except Exception:
+                     logger.warning("tqdm is not available; progress bar disabled")
+                     mp_progress_queue = None
+             except Exception:
+                 logger.warning("Failed to initialize multiprocessing progress queue; progress bar disabled")
+                 mp_progress_queue = None
+
+         with ProcessPoolExecutor(max_workers=self.config.num_processes) as executor:
+             futures = []
+             for work in buckets:
+                 if not work:
+                     continue
+                 futures.append(executor.submit(
+                     _run_parquet_bucket_worker,
+                     arena_dir=self.config.arena_dir,
+                     data_dir=self.config.data_dir,
+                     subset=self.config.subset,
+                     exp_name=self.config.exp_name or "",  # exp_name is guaranteed to be set in __init__
+                     parquet_work=work,
+                     models=self.models,
+                     new_models=self.new_models,
+                     num_threads=self.config.num_threads,
+                     judge_model=self.config.judge_model,
+                     temperature=self.config.temperature,
+                     prompt=self.config.prompt,
+                     timeout=self.config.timeout,
+                     max_retries=self.config.max_retries,
+                     base_urls=self.config.base_urls,
+                     api_keys=self.config.api_keys,
+                     enable_audit_log=self.config.enable_audit_log,
+                     parallel_swap_calls=self.config.parallel_swap_calls,
+                     progress_queue=mp_progress_queue,
+                 ))
+
+             for fut in as_completed(futures):
+                 try:
+                     res = fut.result()
+                     completed += int(res.get("completed", 0))
+                     attempted += int(res.get("attempted", 0))
+                     if completed > 0 and completed % 50 == 0:
+                         logger.info(f"Progress: completed={completed} attempted={attempted}")
+                 except Exception as e:
+                     logger.error(f"Worker failed: {e}")
+
+         progress_close()
+         if manager is not None:
+             try:
+                 manager.shutdown()
+             except Exception:
+                 pass
+
+         # Final Elo recompute (Bradley-Terry) and state save
+         final_state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+         save_state(final_state, self.state_path)
+
+         logger.info(f"Arena completed (multiprocessing): completed={completed} attempted={attempted}")
+
+         return final_state
+
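The greedy bin-packing in the multiprocessing path above is the standard longest-processing-time heuristic: sort shards by size, then always hand the next one to the lightest bucket. A standalone sketch with made-up shard sizes:

tasks = [("a.parquet", list(range(100))), ("b.parquet", list(range(60))), ("c.parquet", list(range(50)))]
num_workers = 2
buckets = [[] for _ in range(num_workers)]
sizes = [0] * num_workers

for pf, idxs in sorted(tasks, key=lambda x: len(x[1]), reverse=True):
    k = sizes.index(min(sizes))  # lightest bucket so far
    buckets[k].append((pf, idxs))
    sizes[k] += len(idxs)

print(sizes)  # [100, 110] -> roughly balanced across the two workers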
+     def update_leaderboard(self) -> None:
+         """Update the leaderboard README.md file."""
+         # Always rebuild state from logs to ensure BT Elo is consistent and up-to-date.
+         state = rebuild_state_from_logs(self.pk_logs_dir, models=self.all_models)
+         save_state(state, self.state_path)
+
+         title = f"{self.config.subset.capitalize()} Leaderboard"
+         save_leaderboard(state, self.leaderboard_path, title)
+
+         logger.info(f"Leaderboard saved to {self.leaderboard_path}")
+
+     def get_status(self) -> dict[str, Any]:
+         """
+         Get arena status summary.
+
+         Returns:
+             Dict with status information
+         """
+         state = load_state(self.state_path)
+
+         return {
+             "subset": self.config.subset,
+             "models": self.models,
+             "total_models": len(self.models),
+             "total_battles": state.total_battles,
+             "last_updated": state.last_updated,
+             "dataset_size": len(self.dataset),
+             "arena_dir": self.config.arena_dir
+         }
+
+
+ def get_all_subsets_status(arena_dir: str, data_dir: str) -> list[dict[str, Any]]:
+     """
+     Get status for all subsets in an arena directory.
+
+     Args:
+         arena_dir: Arena directory path
+         data_dir: Data directory path
+
+     Returns:
+         List of status dicts for each subset
+     """
+     subsets = discover_subsets(data_dir)
+     statuses = []
+
+     for subset in subsets:
+         state_path = os.path.join(arena_dir, subset, "arena", "state.json")
+         state = load_state(state_path)
+
+         models_dir = os.path.join(arena_dir, subset, "models")
+         model_manager = GlobalModelOutputManager(models_dir)
+
+         statuses.append({
+             "subset": subset,
+             "models": model_manager.models,
+             "total_models": len(model_manager.models),
+             "total_battles": state.total_battles,
+             "last_updated": state.last_updated
+         })
+
+     return statuses
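And a small usage sketch for the helper above (directory paths are placeholders):

from genarena.arena import get_all_subsets_status

for status in get_all_subsets_status("./my_arena", "./my_data"):
    print(status["subset"], status["total_models"], status["total_battles"])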