genarena 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genarena/__init__.py +49 -2
- genarena/__main__.py +10 -0
- genarena/arena.py +1685 -0
- genarena/battle.py +337 -0
- genarena/bt_elo.py +507 -0
- genarena/cli.py +1581 -0
- genarena/data.py +476 -0
- genarena/deploy/Dockerfile +22 -0
- genarena/deploy/README.md +55 -0
- genarena/deploy/__init__.py +5 -0
- genarena/deploy/app.py +84 -0
- genarena/experiments.py +121 -0
- genarena/leaderboard.py +270 -0
- genarena/logs.py +409 -0
- genarena/models.py +412 -0
- genarena/prompts/__init__.py +127 -0
- genarena/prompts/mmrb2.py +373 -0
- genarena/sampling.py +336 -0
- genarena/state.py +656 -0
- genarena/sync/__init__.py +105 -0
- genarena/sync/auto_commit.py +118 -0
- genarena/sync/deploy_ops.py +543 -0
- genarena/sync/git_ops.py +422 -0
- genarena/sync/hf_ops.py +891 -0
- genarena/sync/init_ops.py +431 -0
- genarena/sync/packer.py +587 -0
- genarena/sync/submit.py +837 -0
- genarena/utils.py +103 -0
- genarena/validation/__init__.py +19 -0
- genarena/validation/schema.py +327 -0
- genarena/validation/validator.py +329 -0
- genarena/visualize/README.md +148 -0
- genarena/visualize/__init__.py +14 -0
- genarena/visualize/app.py +938 -0
- genarena/visualize/data_loader.py +2430 -0
- genarena/visualize/static/app.js +3762 -0
- genarena/visualize/static/model_aliases.json +86 -0
- genarena/visualize/static/style.css +4104 -0
- genarena/visualize/templates/index.html +413 -0
- genarena/vlm.py +519 -0
- genarena-0.1.1.dist-info/METADATA +178 -0
- genarena-0.1.1.dist-info/RECORD +44 -0
- {genarena-0.0.1.dist-info → genarena-0.1.1.dist-info}/WHEEL +1 -2
- genarena-0.1.1.dist-info/entry_points.txt +2 -0
- genarena-0.0.1.dist-info/METADATA +0 -26
- genarena-0.0.1.dist-info/RECORD +0 -5
- genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/sampling.py
ADDED
@@ -0,0 +1,336 @@
+# Copyright 2026 Ruihang Li.
+# Licensed under the Apache License, Version 2.0.
+# See LICENSE file in the project root for details.
+
+"""Adaptive sampling configuration and utilities for battle scheduling.
+
+This module provides configuration and logic for adaptive sampling strategies
+to reduce the number of battles needed while maintaining statistical precision.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+
+@dataclass
+class SamplingConfig:
+    """Configuration for battle sampling strategy.
+
+    Supports two modes:
+    - "full": Traditional full pairwise comparison (run battles on all samples)
+    - "adaptive": Adaptive sampling based on CI convergence
+
+    Attributes:
+        mode: Sampling mode, either "full" or "adaptive".
+
+        # Adaptive mode parameters
+        min_samples: Minimum samples per model pair before checking CI.
+        max_samples: Maximum samples per model pair (hard cap).
+        batch_size: Number of samples to add in each adaptive iteration.
+        target_ci_width: Target 95% CI width (full width, not ±).
+            Sampling stops when CI width <= target_ci_width.
+        num_bootstrap: Number of bootstrap iterations for CI computation.
+
+        # Full mode parameters
+        sample_size: Fixed number of samples per pair (None = all available).
+
+        # Milestone parameters
+        milestone_min_samples: Minimum samples per pair for milestone experiments.
+            Used to ensure milestone snapshots have sufficient precision.
+    """
+
+    mode: str = "adaptive"
+
+    # Adaptive mode parameters
+    min_samples: int = 100
+    max_samples: int = 1500
+    batch_size: int = 100
+    target_ci_width: float = 15.0  # ±7.5 Elo
+    num_bootstrap: int = 100
+
+    # Full mode parameters
+    sample_size: Optional[int] = None
+
+    # Milestone parameters
+    milestone_min_samples: int = 1000
+
+    def __post_init__(self) -> None:
+        """Validate configuration."""
+        if self.mode not in ("full", "adaptive"):
+            raise ValueError(f"Invalid mode: {self.mode}. Must be 'full' or 'adaptive'.")
+
+        if self.min_samples < 1:
+            raise ValueError(f"min_samples must be >= 1, got {self.min_samples}")
+
+        if self.max_samples < self.min_samples:
+            raise ValueError(
+                f"max_samples ({self.max_samples}) must be >= min_samples ({self.min_samples})"
+            )
+
+        if self.batch_size < 1:
+            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
+
+        if self.target_ci_width <= 0:
+            raise ValueError(f"target_ci_width must be > 0, got {self.target_ci_width}")
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "mode": self.mode,
+            "min_samples": self.min_samples,
+            "max_samples": self.max_samples,
+            "batch_size": self.batch_size,
+            "target_ci_width": self.target_ci_width,
+            "num_bootstrap": self.num_bootstrap,
+            "sample_size": self.sample_size,
+            "milestone_min_samples": self.milestone_min_samples,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "SamplingConfig":
+        """Create from dictionary."""
+        return cls(
+            mode=data.get("mode", "adaptive"),
+            min_samples=data.get("min_samples", 100),
+            max_samples=data.get("max_samples", 1500),
+            batch_size=data.get("batch_size", 100),
+            target_ci_width=data.get("target_ci_width", 15.0),
+            num_bootstrap=data.get("num_bootstrap", 100),
+            sample_size=data.get("sample_size"),
+            milestone_min_samples=data.get("milestone_min_samples", 1000),
+        )
+
+    @classmethod
+    def full_mode(cls, sample_size: Optional[int] = None) -> "SamplingConfig":
+        """Create a full-mode configuration.
+
+        Args:
+            sample_size: Fixed sample size per pair (None = all available).
+
+        Returns:
+            SamplingConfig in full mode.
+        """
+        return cls(mode="full", sample_size=sample_size)
+
+    @classmethod
+    def adaptive_mode(
+        cls,
+        target_ci_width: float = 15.0,
+        min_samples: int = 100,
+        max_samples: int = 1500,
+    ) -> "SamplingConfig":
+        """Create an adaptive-mode configuration.
+
+        Args:
+            target_ci_width: Target 95% CI width.
+            min_samples: Minimum samples before checking CI.
+            max_samples: Maximum samples per pair.
+
+        Returns:
+            SamplingConfig in adaptive mode.
+        """
+        return cls(
+            mode="adaptive",
+            target_ci_width=target_ci_width,
+            min_samples=min_samples,
+            max_samples=max_samples,
+        )
+
+
+@dataclass
+class PairSamplingState:
+    """Tracks sampling state for a single model pair.
+
+    This state is derived from existing battle logs, enabling resume.
+    """
+
+    model_a: str
+    model_b: str
+    current_samples: int = 0
+    ci_width: Optional[float] = None
+    converged: bool = False
+
+    def needs_more_samples(self, config: SamplingConfig) -> bool:
+        """Check if this pair needs more samples.
+
+        Args:
+            config: Sampling configuration.
+
+        Returns:
+            True if more samples are needed.
+        """
+        if config.mode == "full":
+            if config.sample_size is None:
+                return True  # Will be bounded by available samples
+            return self.current_samples < config.sample_size
+
+        # Adaptive mode
+        if self.current_samples < config.min_samples:
+            return True
+
+        if self.current_samples >= config.max_samples:
+            return False
+
+        if self.converged:
+            return False
+
+        if self.ci_width is not None and self.ci_width <= config.target_ci_width:
+            self.converged = True
+            return False
+
+        return True
+
+    def get_samples_to_run(self, config: SamplingConfig, available: int) -> int:
+        """Calculate how many samples to run in the next batch.
+
+        Args:
+            config: Sampling configuration.
+            available: Number of available samples (dataset size).
+
+        Returns:
+            Number of samples to run (can be 0).
+        """
+        if not self.needs_more_samples(config):
+            return 0
+
+        if config.mode == "full":
+            target = config.sample_size if config.sample_size else available
+            return max(0, min(target, available) - self.current_samples)
+
+        # Adaptive mode
+        if self.current_samples < config.min_samples:
+            # Initial batch: reach min_samples
+            target = min(config.min_samples, available)
+            return max(0, target - self.current_samples)
+
+        # Subsequent batches: add batch_size
+        target = min(self.current_samples + config.batch_size, config.max_samples, available)
+        return max(0, target - self.current_samples)
+
+
+@dataclass
+class AdaptiveSamplingScheduler:
+    """Scheduler for adaptive sampling across all model pairs.
+
+    Manages the sampling state for all pairs and determines which
+    pairs need more battles based on CI convergence.
+    """
+
+    config: SamplingConfig
+    pair_states: dict[tuple[str, str], PairSamplingState] = field(default_factory=dict)
+
+    def get_or_create_state(self, model_a: str, model_b: str) -> PairSamplingState:
+        """Get or create sampling state for a model pair.
+
+        Args:
+            model_a: First model name.
+            model_b: Second model name.
+
+        Returns:
+            PairSamplingState for this pair.
+        """
+        # Normalize pair order
+        key = (min(model_a, model_b), max(model_a, model_b))
+
+        if key not in self.pair_states:
+            self.pair_states[key] = PairSamplingState(
+                model_a=key[0],
+                model_b=key[1],
+            )
+
+        return self.pair_states[key]
+
+    def update_state(
+        self,
+        model_a: str,
+        model_b: str,
+        current_samples: int,
+        ci_width: Optional[float] = None,
+    ) -> None:
+        """Update sampling state for a model pair.
+
+        Args:
+            model_a: First model name.
+            model_b: Second model name.
+            current_samples: Current number of samples.
+            ci_width: Current CI width (if computed).
+        """
+        state = self.get_or_create_state(model_a, model_b)
+        state.current_samples = current_samples
+        state.ci_width = ci_width
+
+        # Check convergence
+        if ci_width is not None and ci_width <= self.config.target_ci_width:
+            state.converged = True
+
+    def get_pairs_needing_samples(self) -> list[tuple[str, str]]:
+        """Get list of pairs that need more samples.
+
+        Returns:
+            List of (model_a, model_b) tuples needing more samples.
+        """
+        result = []
+        for key, state in self.pair_states.items():
+            if state.needs_more_samples(self.config):
+                result.append(key)
+        return result
+
+    def get_pair_with_widest_ci(self) -> Optional[tuple[str, str]]:
+        """Get the pair with the widest CI that hasn't converged.
+
+        Returns:
+            (model_a, model_b) tuple, or None if all converged.
+        """
+        widest_pair = None
+        widest_ci = -1.0
+
+        for key, state in self.pair_states.items():
+            if state.converged:
+                continue
+            if state.current_samples >= self.config.max_samples:
+                continue
+            if state.ci_width is not None and state.ci_width > widest_ci:
+                widest_ci = state.ci_width
+                widest_pair = key
+
+        return widest_pair
+
+    def all_converged(self) -> bool:
+        """Check if all pairs have converged.
+
+        Returns:
+            True if all pairs have converged or reached max_samples.
+        """
+        for state in self.pair_states.values():
+            if state.needs_more_samples(self.config):
+                return False
+        return True
+
+    def get_summary(self) -> dict[str, Any]:
+        """Get summary statistics of current sampling state.
+
+        Returns:
+            Dictionary with summary statistics.
+        """
+        total_pairs = len(self.pair_states)
+        converged_pairs = sum(1 for s in self.pair_states.values() if s.converged)
+        maxed_pairs = sum(
+            1 for s in self.pair_states.values()
+            if s.current_samples >= self.config.max_samples
+        )
+
+        ci_widths = [s.ci_width for s in self.pair_states.values() if s.ci_width is not None]
+        total_samples = sum(s.current_samples for s in self.pair_states.values())
+
+        return {
+            "total_pairs": total_pairs,
+            "converged_pairs": converged_pairs,
+            "maxed_pairs": maxed_pairs,
+            "pending_pairs": total_pairs - converged_pairs - maxed_pairs,
+            "total_samples": total_samples,
+            "mean_ci_width": sum(ci_widths) / len(ci_widths) if ci_widths else None,
+            "max_ci_width": max(ci_widths) if ci_widths else None,
+            "min_ci_width": min(ci_widths) if ci_widths else None,
+        }
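
For orientation, a minimal sketch of how the new configuration and per-pair batch math compose. It assumes only what the diff above defines (the genarena.sampling module shipped in 0.1.1); the model names, sample counts, and CI width used below are illustrative, not taken from the package.

# Sketch: config construction, serialization round-trip, and batch sizing.
from genarena.sampling import PairSamplingState, SamplingConfig

# Adaptive config: stop a pair once its 95% CI is <= 15 Elo wide (±7.5).
cfg = SamplingConfig.adaptive_mode(target_ci_width=15.0, min_samples=100, max_samples=1500)
assert SamplingConfig.from_dict(cfg.to_dict()) == cfg  # dataclass round-trip

# A hypothetical pair with 250 battles run and a CI still wider than the target.
state = PairSamplingState(model_a="model_x", model_b="model_y",
                          current_samples=250, ci_width=32.4)
print(state.needs_more_samples(cfg))                   # True  (32.4 > 15.0)
print(state.get_samples_to_run(cfg, available=2000))   # 100   (one more batch_size batch)

# The same pair under a fixed-size "full" run of 300 samples per pair.
print(state.get_samples_to_run(SamplingConfig.full_mode(sample_size=300),
                               available=2000))        # 50    (top up to sample_size)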
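
And a self-contained sketch of the adaptive loop the scheduler is built for. The real battle runner and bootstrap CI estimator live outside this module; here a toy 1/sqrt(n) curve stands in for the CI estimate so the loop runs on its own, and the pair names are again illustrative.

# Sketch of an adaptive scheduling loop driven by AdaptiveSamplingScheduler.
from genarena.sampling import AdaptiveSamplingScheduler, SamplingConfig


def toy_ci_width(n_samples: int) -> float:
    """Stand-in for the real bootstrap CI: width shrinks roughly as 1/sqrt(n)."""
    return 300.0 / (n_samples ** 0.5)


scheduler = AdaptiveSamplingScheduler(config=SamplingConfig.adaptive_mode(target_ci_width=15.0))

# Register the pairs to battle (names are illustrative).
for a, b in [("model_a", "model_b"), ("model_a", "model_c"), ("model_b", "model_c")]:
    scheduler.get_or_create_state(a, b)

available = 2000  # size of the evaluation set
while not scheduler.all_converged():
    for a, b in scheduler.get_pairs_needing_samples():
        state = scheduler.get_or_create_state(a, b)
        n = state.get_samples_to_run(scheduler.config, available)
        if n == 0:
            continue
        # ... run `n` battles for (a, b) here, then recompute the pair's CI ...
        new_total = state.current_samples + n
        scheduler.update_state(a, b, new_total, ci_width=toy_ci_width(new_total))

print(scheduler.get_summary())
# With this toy curve every pair converges at 400 samples (CI width exactly 15.0).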