genarena 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. genarena/__init__.py +49 -2
  2. genarena/__main__.py +10 -0
  3. genarena/arena.py +1685 -0
  4. genarena/battle.py +337 -0
  5. genarena/bt_elo.py +507 -0
  6. genarena/cli.py +1581 -0
  7. genarena/data.py +476 -0
  8. genarena/deploy/Dockerfile +22 -0
  9. genarena/deploy/README.md +55 -0
  10. genarena/deploy/__init__.py +5 -0
  11. genarena/deploy/app.py +84 -0
  12. genarena/experiments.py +121 -0
  13. genarena/leaderboard.py +270 -0
  14. genarena/logs.py +409 -0
  15. genarena/models.py +412 -0
  16. genarena/prompts/__init__.py +127 -0
  17. genarena/prompts/mmrb2.py +373 -0
  18. genarena/sampling.py +336 -0
  19. genarena/state.py +656 -0
  20. genarena/sync/__init__.py +105 -0
  21. genarena/sync/auto_commit.py +118 -0
  22. genarena/sync/deploy_ops.py +543 -0
  23. genarena/sync/git_ops.py +422 -0
  24. genarena/sync/hf_ops.py +891 -0
  25. genarena/sync/init_ops.py +431 -0
  26. genarena/sync/packer.py +587 -0
  27. genarena/sync/submit.py +837 -0
  28. genarena/utils.py +103 -0
  29. genarena/validation/__init__.py +19 -0
  30. genarena/validation/schema.py +327 -0
  31. genarena/validation/validator.py +329 -0
  32. genarena/visualize/README.md +148 -0
  33. genarena/visualize/__init__.py +14 -0
  34. genarena/visualize/app.py +938 -0
  35. genarena/visualize/data_loader.py +2430 -0
  36. genarena/visualize/static/app.js +3762 -0
  37. genarena/visualize/static/model_aliases.json +86 -0
  38. genarena/visualize/static/style.css +4104 -0
  39. genarena/visualize/templates/index.html +413 -0
  40. genarena/vlm.py +519 -0
  41. genarena-0.1.1.dist-info/METADATA +178 -0
  42. genarena-0.1.1.dist-info/RECORD +44 -0
  43. {genarena-0.0.1.dist-info → genarena-0.1.1.dist-info}/WHEEL +1 -2
  44. genarena-0.1.1.dist-info/entry_points.txt +2 -0
  45. genarena-0.0.1.dist-info/METADATA +0 -26
  46. genarena-0.0.1.dist-info/RECORD +0 -5
  47. genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/sampling.py ADDED
@@ -0,0 +1,336 @@
1
+ # Copyright 2026 Ruihang Li.
2
+ # Licensed under the Apache License, Version 2.0.
3
+ # See LICENSE file in the project root for details.
4
+
5
+ """Adaptive sampling configuration and utilities for battle scheduling.
6
+
7
+ This module provides configuration and logic for adaptive sampling strategies
8
+ to reduce the number of battles needed while maintaining statistical precision.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Optional
15
+
16
+
17
@dataclass
class SamplingConfig:
    """Settings that control how battles are sampled for each model pair.

    Two modes are supported:
    - "full": traditional full pairwise comparison (run every sample).
    - "adaptive": adaptive sampling driven by CI convergence.

    Attributes:
        mode: Sampling mode, either "full" or "adaptive".

        # Adaptive mode parameters
        min_samples: Minimum samples per model pair before the CI is checked.
        max_samples: Hard cap on samples per model pair.
        batch_size: Number of samples added in each adaptive iteration.
        target_ci_width: Target 95% CI width (the full width, not the
            +/- half width). Sampling stops once the CI width is at or
            below this value.
        num_bootstrap: Number of bootstrap iterations for CI computation.

        # Full mode parameters
        sample_size: Fixed number of samples per pair (None = all available).

        # Milestone parameters
        milestone_min_samples: Minimum samples per pair for milestone
            experiments, so milestone snapshots have sufficient precision.
    """

    mode: str = "adaptive"

    # Adaptive mode parameters
    min_samples: int = 100
    max_samples: int = 1500
    batch_size: int = 100
    target_ci_width: float = 15.0  # i.e. +/- 7.5 Elo
    num_bootstrap: int = 100

    # Full mode parameters
    sample_size: Optional[int] = None

    # Milestone parameters
    milestone_min_samples: int = 1000

    def __post_init__(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If ``mode`` is unknown, ``min_samples`` or
                ``batch_size`` is below 1, ``max_samples`` is below
                ``min_samples``, or ``target_ci_width`` is not positive.
        """
        # Checks run in a fixed order; only the first failure is reported.
        problem = None
        if self.mode not in ("full", "adaptive"):
            problem = f"Invalid mode: {self.mode}. Must be 'full' or 'adaptive'."
        elif self.min_samples < 1:
            problem = f"min_samples must be >= 1, got {self.min_samples}"
        elif self.max_samples < self.min_samples:
            problem = (
                f"max_samples ({self.max_samples}) must be >= min_samples ({self.min_samples})"
            )
        elif self.batch_size < 1:
            problem = f"batch_size must be >= 1, got {self.batch_size}"
        elif self.target_ci_width <= 0:
            problem = f"target_ci_width must be > 0, got {self.target_ci_width}"

        if problem is not None:
            raise ValueError(problem)

    def to_dict(self) -> dict[str, Any]:
        """Return the settings as a plain dictionary for serialization."""
        names = (
            "mode",
            "min_samples",
            "max_samples",
            "batch_size",
            "target_ci_width",
            "num_bootstrap",
            "sample_size",
            "milestone_min_samples",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SamplingConfig":
        """Build a configuration from a dictionary.

        Keys missing from ``data`` fall back to the same values as the
        dataclass defaults.
        """
        lookup = data.get
        return cls(
            mode=lookup("mode", "adaptive"),
            min_samples=lookup("min_samples", 100),
            max_samples=lookup("max_samples", 1500),
            batch_size=lookup("batch_size", 100),
            target_ci_width=lookup("target_ci_width", 15.0),
            num_bootstrap=lookup("num_bootstrap", 100),
            sample_size=lookup("sample_size"),
            milestone_min_samples=lookup("milestone_min_samples", 1000),
        )

    @classmethod
    def full_mode(cls, sample_size: Optional[int] = None) -> "SamplingConfig":
        """Create a full-mode configuration.

        Args:
            sample_size: Fixed sample size per pair (None = all available).

        Returns:
            SamplingConfig in full mode.
        """
        settings = {"mode": "full", "sample_size": sample_size}
        return cls(**settings)

    @classmethod
    def adaptive_mode(
        cls,
        target_ci_width: float = 15.0,
        min_samples: int = 100,
        max_samples: int = 1500,
    ) -> "SamplingConfig":
        """Create an adaptive-mode configuration.

        Args:
            target_ci_width: Target 95% CI width.
            min_samples: Minimum samples before checking CI.
            max_samples: Maximum samples per pair.

        Returns:
            SamplingConfig in adaptive mode.
        """
        settings = {
            "mode": "adaptive",
            "target_ci_width": target_ci_width,
            "min_samples": min_samples,
            "max_samples": max_samples,
        }
        return cls(**settings)
140
+
141
+
142
@dataclass
class PairSamplingState:
    """Sampling progress for a single model pair.

    The state is derived from existing battle logs, which makes it
    possible to resume an interrupted run.

    Attributes:
        model_a: Name of the first model in the pair.
        model_b: Name of the second model in the pair.
        current_samples: Number of samples recorded so far.
        ci_width: Most recent CI width, or None if not computed.
        converged: Whether the pair has reached the target CI width.
    """

    model_a: str
    model_b: str
    current_samples: int = 0
    ci_width: Optional[float] = None
    converged: bool = False

    def needs_more_samples(self, config: SamplingConfig) -> bool:
        """Decide whether this pair should receive further samples.

        In adaptive mode this also sets ``converged`` to True when the
        stored CI width is at or below the target.

        Args:
            config: Sampling configuration.

        Returns:
            True if more samples are needed.
        """
        if config.mode == "full":
            limit = config.sample_size
            # With no explicit limit the answer is always True here; the
            # number of available samples is applied in get_samples_to_run.
            return True if limit is None else self.current_samples < limit

        # Adaptive mode: the minimum is checked before any stop condition.
        taken = self.current_samples
        if taken < config.min_samples:
            return True
        if taken >= config.max_samples or self.converged:
            return False

        reached_target = (
            self.ci_width is not None and self.ci_width <= config.target_ci_width
        )
        if reached_target:
            self.converged = True
        return not reached_target

    def get_samples_to_run(self, config: SamplingConfig, available: int) -> int:
        """Work out the size of the next batch for this pair.

        Args:
            config: Sampling configuration.
            available: Number of available samples (dataset size).

        Returns:
            Number of samples to run (can be 0).
        """
        if not self.needs_more_samples(config):
            return 0

        done = self.current_samples
        if config.mode == "full":
            # A missing (or falsy) sample_size means "use everything available".
            goal = min(config.sample_size or available, available)
        elif done < config.min_samples:
            # Initial batch: climb to min_samples.
            goal = min(config.min_samples, available)
        else:
            # Later batches: one more batch_size, capped by max_samples.
            goal = min(done + config.batch_size, config.max_samples, available)
        return max(0, goal - done)
211
+
212
+
213
@dataclass
class AdaptiveSamplingScheduler:
    """Scheduler for adaptive sampling across all model pairs.

    Holds one ``PairSamplingState`` per model pair and answers which pairs
    still need battles, based on sample counts and CI convergence.

    Attributes:
        config: Sampling configuration shared by every pair.
        pair_states: Per-pair state keyed by the (smaller, larger) ordering
            of the two model names.
    """

    config: SamplingConfig
    pair_states: dict[tuple[str, str], PairSamplingState] = field(default_factory=dict)

    def get_or_create_state(self, model_a: str, model_b: str) -> PairSamplingState:
        """Get or create sampling state for a model pair.

        Args:
            model_a: First model name.
            model_b: Second model name.

        Returns:
            PairSamplingState for this pair.
        """
        # Normalize pair order so (a, b) and (b, a) share one entry.
        key = (min(model_a, model_b), max(model_a, model_b))

        if key not in self.pair_states:
            self.pair_states[key] = PairSamplingState(
                model_a=key[0],
                model_b=key[1],
            )

        return self.pair_states[key]

    def update_state(
        self,
        model_a: str,
        model_b: str,
        current_samples: int,
        ci_width: Optional[float] = None,
    ) -> None:
        """Update sampling state for a model pair.

        The pair is marked converged when ``ci_width`` is at or below the
        configured target. This method never clears the converged flag.

        Args:
            model_a: First model name.
            model_b: Second model name.
            current_samples: Current number of samples.
            ci_width: Current CI width (if computed).
        """
        state = self.get_or_create_state(model_a, model_b)
        state.current_samples = current_samples
        state.ci_width = ci_width

        # Check convergence
        if ci_width is not None and ci_width <= self.config.target_ci_width:
            state.converged = True

    def get_pairs_needing_samples(self) -> list[tuple[str, str]]:
        """Get list of pairs that need more samples.

        Returns:
            List of (model_a, model_b) tuples needing more samples.
        """
        return [
            key
            for key, state in self.pair_states.items()
            if state.needs_more_samples(self.config)
        ]

    def get_pair_with_widest_ci(self) -> Optional[tuple[str, str]]:
        """Get the non-converged pair with the widest known CI.

        Pairs that are converged, have reached ``max_samples``, or have no
        CI width computed yet are skipped.

        Returns:
            (model_a, model_b) tuple, or None if no pair qualifies.
        """
        widest_pair = None
        widest_ci = -1.0

        for key, state in self.pair_states.items():
            if state.converged:
                continue
            if state.current_samples >= self.config.max_samples:
                continue
            if state.ci_width is not None and state.ci_width > widest_ci:
                widest_ci = state.ci_width
                widest_pair = key

        return widest_pair

    def all_converged(self) -> bool:
        """Check if all pairs have converged.

        Returns:
            True if no pair needs more samples (also True when there are
            no pairs at all).
        """
        return not any(
            state.needs_more_samples(self.config)
            for state in self.pair_states.values()
        )

    def get_summary(self) -> dict[str, Any]:
        """Get summary statistics of current sampling state.

        Returns:
            Dictionary with the keys ``total_pairs``, ``converged_pairs``,
            ``maxed_pairs``, ``pending_pairs``, ``total_samples``,
            ``mean_ci_width``, ``max_ci_width`` and ``min_ci_width``. The
            three CI entries are None when no pair has a CI width.
        """
        states = list(self.pair_states.values())
        max_samples = self.config.max_samples

        converged_pairs = sum(1 for s in states if s.converged)
        maxed_pairs = sum(1 for s in states if s.current_samples >= max_samples)
        # A pair can be both converged and maxed, so subtracting the two
        # counts from the total would remove it twice. Count the pairs
        # that are in neither group instead.
        pending_pairs = sum(
            1 for s in states
            if not (s.converged or s.current_samples >= max_samples)
        )

        ci_widths = [s.ci_width for s in states if s.ci_width is not None]
        total_samples = sum(s.current_samples for s in states)

        return {
            "total_pairs": len(states),
            "converged_pairs": converged_pairs,
            "maxed_pairs": maxed_pairs,
            "pending_pairs": pending_pairs,
            "total_samples": total_samples,
            "mean_ci_width": sum(ci_widths) / len(ci_widths) if ci_widths else None,
            "max_ci_width": max(ci_widths) if ci_widths else None,
            "min_ci_width": min(ci_widths) if ci_widths else None,
        }