genarena 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genarena/__init__.py +49 -2
- genarena/__main__.py +10 -0
- genarena/arena.py +1685 -0
- genarena/battle.py +337 -0
- genarena/bt_elo.py +507 -0
- genarena/cli.py +1581 -0
- genarena/data.py +476 -0
- genarena/deploy/Dockerfile +25 -0
- genarena/deploy/README.md +55 -0
- genarena/deploy/__init__.py +5 -0
- genarena/deploy/app.py +84 -0
- genarena/experiments.py +121 -0
- genarena/leaderboard.py +270 -0
- genarena/logs.py +409 -0
- genarena/models.py +412 -0
- genarena/prompts/__init__.py +127 -0
- genarena/prompts/mmrb2.py +373 -0
- genarena/sampling.py +336 -0
- genarena/state.py +656 -0
- genarena/sync/__init__.py +105 -0
- genarena/sync/auto_commit.py +118 -0
- genarena/sync/deploy_ops.py +543 -0
- genarena/sync/git_ops.py +422 -0
- genarena/sync/hf_ops.py +891 -0
- genarena/sync/init_ops.py +431 -0
- genarena/sync/packer.py +587 -0
- genarena/sync/submit.py +837 -0
- genarena/utils.py +103 -0
- genarena/validation/__init__.py +19 -0
- genarena/validation/schema.py +327 -0
- genarena/validation/validator.py +329 -0
- genarena/visualize/README.md +148 -0
- genarena/visualize/__init__.py +14 -0
- genarena/visualize/app.py +938 -0
- genarena/visualize/data_loader.py +2335 -0
- genarena/visualize/static/app.js +3762 -0
- genarena/visualize/static/model_aliases.json +86 -0
- genarena/visualize/static/style.css +4104 -0
- genarena/visualize/templates/index.html +413 -0
- genarena/vlm.py +519 -0
- genarena-0.1.0.dist-info/METADATA +178 -0
- genarena-0.1.0.dist-info/RECORD +44 -0
- {genarena-0.0.1.dist-info → genarena-0.1.0.dist-info}/WHEEL +1 -2
- genarena-0.1.0.dist-info/entry_points.txt +2 -0
- genarena-0.0.1.dist-info/METADATA +0 -26
- genarena-0.0.1.dist-info/RECORD +0 -5
- genarena-0.0.1.dist-info/top_level.txt +0 -1
genarena/models.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
# Copyright 2026 Ruihang Li.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0.
|
|
3
|
+
# See LICENSE file in the project root for details.
|
|
4
|
+
|
|
5
|
+
"""Model output management module."""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import warnings
|
|
12
|
+
from datetime import date
|
|
13
|
+
from typing import Optional
|
|
14
|
+
|
|
15
|
+
from genarena.experiments import parse_exp_date_suffix
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def parse_image_index(filename: str) -> Optional[int]:
    """
    Extract numeric index from image filename.

    Supports both zero-padded format (e.g., '000001.png') and simple numeric
    format (e.g., '1.png', '42.png'). When the stem merely *starts* with
    digits (e.g., '12_edit.png'), that leading run of digits is used.

    Args:
        filename: Image filename (e.g., '000001.png' or '1.png')

    Returns:
        Integer index, or None if the stem (filename without its final
        extension) does not start with a decimal digit
    """
    # Remove only the final extension
    name = os.path.splitext(filename)[0]

    # In a str pattern, `\d` matches exactly the Unicode decimal digits (Nd)
    # that int() accepts. str.isdigit() is broader (e.g. True for '²'), and
    # int('²') raises ValueError, so it must not be used to gate int().
    match = re.match(r'\d+', name)
    if match is None:
        return None
    return int(match.group())
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def discover_models(models_dir: str) -> list[str]:
    """
    Discover all model subdirectories in the models directory.

    A subdirectory counts as a model when it directly contains at least one
    entry whose name ends (case-insensitively) in .png, .jpg, .jpeg or .webp.
    Subdirectories that cannot be listed (e.g. permission denied) are skipped
    with a warning instead of aborting the whole discovery.

    Args:
        models_dir: Path to the models directory

    Returns:
        Sorted list of model names (directory names). Empty, with a warning,
        if ``models_dir`` is not an existing directory.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    models: list[str] = []

    if not os.path.isdir(models_dir):
        warnings.warn(f"Models directory does not exist: {models_dir}")
        return models

    for name in os.listdir(models_dir):
        model_path = os.path.join(models_dir, name)
        if not os.path.isdir(model_path):
            continue
        try:
            entries = os.listdir(model_path)
        except OSError as exc:
            # One unreadable directory should not hide every other model.
            warnings.warn(f"Skipping unreadable model directory {model_path}: {exc}")
            continue
        if any(entry.lower().endswith(image_exts) for entry in entries):
            models.append(name)

    return sorted(models)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ModelOutputManager:
    """
    Manager for model output images.

    Handles discovery and retrieval of model output images,
    supporting various naming formats. Expected layout::

        <models_dir>/<model_name>/<image files>

    Each image's index is derived from its filename via ``parse_image_index``
    (so '000001.png' and '1.png' both map to index 1). Model directories that
    contain no indexable image are ignored.
    """

    # Lower-case filename suffixes treated as output images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, models_dir: str):
        """
        Initialize the manager and scan for model outputs.

        Args:
            models_dir: Path to the models directory
        """
        self.models_dir = models_dir

        # Model name -> {index -> filepath}
        self._index_map: dict[str, dict[int, str]] = {}

        # Sorted names of models that have at least one indexable image
        self._models: list[str] = []
        self._scan_models()

    @classmethod
    def _build_index_map(cls, model_path: str, filenames: list[str]) -> dict[int, str]:
        """Map image index -> file path for one model directory's file listing."""
        index_map: dict[int, str] = {}
        # Iterate in sorted order: os.listdir() order is arbitrary, and the
        # duplicate-index rule below depends on which file is seen first.
        for filename in sorted(filenames):
            lowered = filename.lower()
            if not lowered.endswith(cls._IMAGE_EXTENSIONS):
                continue
            idx = parse_image_index(filename)
            if idx is None:
                continue
            # Duplicate index: a .png replaces the existing entry; any other
            # format never replaces one.
            if idx in index_map and not lowered.endswith('.png'):
                continue
            index_map[idx] = os.path.join(model_path, filename)
        return index_map

    def _scan_models(self) -> None:
        """Scan models directory and build index mapping."""
        if not os.path.isdir(self.models_dir):
            warnings.warn(f"Models directory does not exist: {self.models_dir}")
            return

        for model_name in os.listdir(self.models_dir):
            model_path = os.path.join(self.models_dir, model_name)
            if not os.path.isdir(model_path):
                continue

            try:
                filenames = os.listdir(model_path)
            except OSError as exc:
                # Skip (rather than abort on) a single unreadable model directory.
                warnings.warn(f"Skipping unreadable model directory {model_path}: {exc}")
                continue

            index_map = self._build_index_map(model_path, filenames)
            if index_map:
                self._models.append(model_name)
                self._index_map[model_name] = index_map

        self._models.sort()

    @property
    def models(self) -> list[str]:
        """Get list of discovered model names (a copy, sorted)."""
        return self._models.copy()

    def get_output_path(self, model: str, index: int) -> Optional[str]:
        """
        Get the output image path for a model at a given index.

        Args:
            model: Model name
            index: Sample index

        Returns:
            Path to the output image, or None if not found
        """
        if model not in self._index_map:
            return None

        return self._index_map[model].get(index)

    def get_model_indices(self, model: str) -> set[int]:
        """
        Get all available indices for a model.

        Args:
            model: Model name

        Returns:
            Set of available indices (empty for an unknown model)
        """
        if model not in self._index_map:
            return set()

        return set(self._index_map[model].keys())

    def validate_coverage(
        self,
        model_a: str,
        model_b: str,
        indices: list[int]
    ) -> list[int]:
        """
        Validate which indices have outputs from both models.

        Emits a warning per model that lacks outputs for some of ``indices``
        (listing at most the first five missing indices).

        Args:
            model_a: First model name
            model_b: Second model name
            indices: List of indices to validate

        Returns:
            List of indices where both models have outputs, in input order
        """
        indices_a = self.get_model_indices(model_a)
        indices_b = self.get_model_indices(model_b)

        valid_indices = []
        missing_a = []
        missing_b = []

        for idx in indices:
            has_a = idx in indices_a
            has_b = idx in indices_b

            if has_a and has_b:
                valid_indices.append(idx)
            else:
                if not has_a:
                    missing_a.append(idx)
                if not has_b:
                    missing_b.append(idx)

        # Log warnings for missing outputs
        if missing_a:
            warnings.warn(
                f"Model '{model_a}' missing outputs for {len(missing_a)} indices: "
                f"{missing_a[:5]}{'...' if len(missing_a) > 5 else ''}"
            )
        if missing_b:
            warnings.warn(
                f"Model '{model_b}' missing outputs for {len(missing_b)} indices: "
                f"{missing_b[:5]}{'...' if len(missing_b) > 5 else ''}"
            )

        return valid_indices

    def has_model(self, model: str) -> bool:
        """Check if a model exists in the manager."""
        return model in self._index_map

    def refresh(self) -> None:
        """Re-scan the models directory, discarding previously discovered state."""
        self._models = []
        self._index_map = {}
        self._scan_models()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class GlobalModelOutputManager:
    """
    Manager for model outputs stored under an experiment-scoped layout:

        models/<exp_name>/<model_name>/<image files>

    This manager enforces the GenArena v2 constraint:
    - Within one subset, **model names must be globally unique across exp folders**.
      If the same model directory name appears under two different exp directories,
      this manager raises ValueError.

    Entries whose names start with '.' are ignored at both the exp and model
    level, and model directories without any indexable image are skipped.
    """

    # Lower-case filename suffixes treated as output images.
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

    def __init__(self, models_root_dir: str):
        """
        Args:
            models_root_dir: Path to `arena_dir/<subset>/models`

        Raises:
            ValueError: If a model name appears under more than one exp directory.
        """
        self.models_root_dir = models_root_dir

        # model_name -> exp_name
        self._model_to_exp: dict[str, str] = {}
        # model_name -> {index -> filepath}
        self._index_map: dict[str, dict[int, str]] = {}
        # exp_name -> list[model_name]
        self._exp_to_models: dict[str, list[str]] = {}
        # cached sorted model list
        self._models: list[str] = []

        self._scan()

    @classmethod
    def _build_index_map(cls, model_dir: str, filenames: list[str]) -> dict[int, str]:
        """Map image index -> file path for one model directory's file listing."""
        index_map: dict[int, str] = {}
        # Sorted so duplicate-index resolution does not depend on the
        # arbitrary order returned by os.listdir().
        for filename in sorted(filenames):
            lowered = filename.lower()
            if not lowered.endswith(cls._IMAGE_EXTENSIONS):
                continue
            idx = parse_image_index(filename)
            if idx is None:
                continue
            # Duplicate index: a .png replaces the existing entry; any other
            # format never replaces one.
            if idx in index_map and not lowered.endswith(".png"):
                continue
            index_map[idx] = os.path.join(model_dir, filename)
        return index_map

    def _iter_model_dirs(self):
        """Yield (exp_name, model_name, model_dir) for each non-hidden model directory, sorted."""
        for exp_name in sorted(os.listdir(self.models_root_dir)):
            if exp_name.startswith("."):
                continue
            exp_dir = os.path.join(self.models_root_dir, exp_name)
            if not os.path.isdir(exp_dir):
                continue
            try:
                model_names = sorted(os.listdir(exp_dir))
            except OSError as exc:
                warnings.warn(f"Skipping unreadable experiment directory {exp_dir}: {exc}")
                continue
            for model_name in model_names:
                if model_name.startswith("."):
                    continue
                model_dir = os.path.join(exp_dir, model_name)
                if os.path.isdir(model_dir):
                    yield exp_name, model_name, model_dir

    def _scan(self) -> None:
        """
        Rebuild all lookup tables from disk.

        The tables are assembled in locals and assigned to ``self`` only after
        the whole tree has been scanned, so a ValueError for a duplicate model
        name never leaves the manager half-populated.

        Raises:
            ValueError: If a model name appears under more than one exp directory.
        """
        model_to_exp: dict[str, str] = {}
        index_maps: dict[str, dict[int, str]] = {}
        exp_to_models: dict[str, list[str]] = {}

        if not os.path.isdir(self.models_root_dir):
            warnings.warn(f"Models directory does not exist: {self.models_root_dir}")
        else:
            for exp_name, model_name, model_dir in self._iter_model_dirs():
                try:
                    filenames = os.listdir(model_dir)
                except OSError as exc:
                    warnings.warn(f"Skipping unreadable model directory {model_dir}: {exc}")
                    continue

                index_map = self._build_index_map(model_dir, filenames)
                # Ignore empty model directories (no images found)
                if not index_map:
                    continue

                if model_name in model_to_exp:
                    prev_exp = model_to_exp[model_name]
                    raise ValueError(
                        f"Duplicate model name across experiments: '{model_name}' found in both "
                        f"'{prev_exp}' and '{exp_name}'. Model names must be unique across exp folders."
                    )

                model_to_exp[model_name] = exp_name
                index_maps[model_name] = index_map
                exp_to_models.setdefault(exp_name, []).append(model_name)

        for names in exp_to_models.values():
            names.sort()

        # Commit the freshly built snapshot in one go.
        self._model_to_exp = model_to_exp
        self._index_map = index_maps
        self._exp_to_models = exp_to_models
        self._models = sorted(model_to_exp)

    @property
    def models(self) -> list[str]:
        """Get list of discovered model names (a copy, sorted)."""
        return self._models.copy()

    @property
    def experiments(self) -> list[str]:
        """Get list of experiment names that contain at least one model."""
        return sorted(self._exp_to_models.keys())

    def get_model_exp(self, model: str) -> Optional[str]:
        """Return exp_name containing the model, or None if unknown."""
        return self._model_to_exp.get(model)

    def get_experiment_models(self, exp_name: str) -> list[str]:
        """Return models discovered under a specific exp directory (a copy, sorted)."""
        return self._exp_to_models.get(exp_name, []).copy()

    def has_model(self, model: str) -> bool:
        """Check if a model exists in the manager."""
        return model in self._index_map

    def get_output_path(self, model: str, index: int) -> Optional[str]:
        """Return the image path for ``model`` at ``index``, or None if absent."""
        if model not in self._index_map:
            return None
        return self._index_map[model].get(index)

    def get_model_indices(self, model: str) -> set[int]:
        """Return all indices with an output for ``model`` (empty if unknown)."""
        if model not in self._index_map:
            return set()
        return set(self._index_map[model].keys())

    def validate_coverage(self, model_a: str, model_b: str, indices: list[int]) -> list[int]:
        """
        Return the indices (in input order) for which both models have outputs.

        Emits a warning per model that lacks outputs for some of ``indices``
        (listing at most the first five missing indices).
        """
        indices_a = self.get_model_indices(model_a)
        indices_b = self.get_model_indices(model_b)

        valid_indices = []
        missing_a = []
        missing_b = []

        for idx in indices:
            has_a = idx in indices_a
            has_b = idx in indices_b

            if has_a and has_b:
                valid_indices.append(idx)
            else:
                if not has_a:
                    missing_a.append(idx)
                if not has_b:
                    missing_b.append(idx)

        if missing_a:
            warnings.warn(
                f"Model '{model_a}' missing outputs for {len(missing_a)} indices: "
                f"{missing_a[:5]}{'...' if len(missing_a) > 5 else ''}"
            )
        if missing_b:
            warnings.warn(
                f"Model '{model_b}' missing outputs for {len(missing_b)} indices: "
                f"{missing_b[:5]}{'...' if len(missing_b) > 5 else ''}"
            )

        return valid_indices

    def refresh(self) -> None:
        """
        Re-scan the models root directory.

        Raises:
            ValueError: On duplicate model names; the state from the previous
                successful scan is then kept unchanged.
        """
        self._scan()

    def get_models_up_to_date(self, exp_date: date) -> list[str]:
        """Return models from experiments with date <= exp_date.

        This is used to ensure that when running battles for an old experiment,
        we only consider models from experiments that existed at that time
        (same date or earlier), not models from future experiments.

        Args:
            exp_date: The cutoff date (inclusive).

        Returns:
            Sorted list of model names from experiments with date <= exp_date.
            Experiments for which ``parse_exp_date_suffix`` returns None are
            excluded.
        """
        result: list[str] = []
        for exp_name, models in self._exp_to_models.items():
            exp_day = parse_exp_date_suffix(exp_name)
            if exp_day is not None and exp_day <= exp_date:
                result.extend(models)
        return sorted(result)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Prompt module loader and validator."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import importlib.util
|
|
5
|
+
import os
|
|
6
|
+
from types import ModuleType
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Required attributes for a valid prompt module.
# validate_prompt() and get_missing_attributes() below check each name for
# presence, and additionally require build_prompt and parse_response to be
# callable.
REQUIRED_ATTRIBUTES = ["PROMPT_TEXT", "ALLOW_TIE", "build_prompt", "parse_response"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_prompt(name: str) -> ModuleType:
    """
    Load a prompt module by name.

    First tries to load from the genarena.prompts package, then attempts
    to load from a file path if the name looks like a path (ends with '.py'
    or contains a path separator).

    Args:
        name: Prompt module name (e.g., 'mmrb2') or path to a .py file

    Returns:
        Loaded module

    Raises:
        ImportError: If module cannot be found, or if a prompt module that
            does exist fails while importing its own dependencies
        ValueError: If module is invalid
    """
    package_module_name = f"genarena.prompts.{name}"
    module: Optional[ModuleType] = None
    lookup_error: Optional[ImportError] = None

    # Try loading from genarena.prompts package
    try:
        module = importlib.import_module(package_module_name)
    except ModuleNotFoundError as exc:
        # Treat only "the prompt module (or one of its parents) does not
        # exist" as a miss. A ModuleNotFoundError naming some *other* module
        # means the prompt exists but its own imports are broken; reporting
        # it as "could not load prompt module" would hide the real cause.
        missing = exc.name
        if missing is not None and not (
            package_module_name == missing
            or package_module_name.startswith(missing + ".")
        ):
            raise
        lookup_error = exc

    # If not found and name looks like a path, try loading from file
    separators = [sep for sep in (os.path.sep, os.path.altsep) if sep]
    looks_like_path = name.endswith('.py') or any(sep in name for sep in separators)
    if module is None and looks_like_path and os.path.isfile(name):
        spec = importlib.util.spec_from_file_location("custom_prompt", name)
        if spec and spec.loader:
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

    if module is None:
        raise ImportError(
            f"Could not load prompt module '{name}'. "
            f"Make sure it exists in genarena/prompts/ or provide a valid file path."
        ) from lookup_error

    # Validate the module
    if not validate_prompt(module):
        missing_attrs = get_missing_attributes(module)
        raise ValueError(
            f"Invalid prompt module '{name}'. "
            f"Missing required attributes: {missing_attrs}"
        )

    return module
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def validate_prompt(module: ModuleType) -> bool:
    """
    Report whether ``module`` exposes the full prompt-module interface.

    Every name in REQUIRED_ATTRIBUTES must be present on the module:
    - PROMPT_TEXT - the evaluation prompt text
    - ALLOW_TIE - whether single-round ties are allowed
    - build_prompt - function that builds the VLM messages (must be callable)
    - parse_response - function that parses the VLM response (must be callable)

    Only presence is checked for PROMPT_TEXT and ALLOW_TIE (their types are
    not inspected); the two functions must additionally be callable.

    Args:
        module: Module to validate

    Returns:
        True if every requirement holds, False at the first one that fails
    """
    function_names = ("build_prompt", "parse_response")

    def _satisfied(attr_name: str) -> bool:
        if not hasattr(module, attr_name):
            return False
        return attr_name not in function_names or callable(getattr(module, attr_name))

    # all() short-circuits, so attributes are inspected in declaration order
    # and checking stops at the first failure.
    return all(_satisfied(attr_name) for attr_name in REQUIRED_ATTRIBUTES)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_missing_attributes(module: ModuleType) -> list[str]:
    """
    Describe which prompt-module requirements ``module`` fails.

    Args:
        module: Module to check

    Returns:
        One entry per failing name in REQUIRED_ATTRIBUTES, in that order:
        the bare name if it is absent, or "<name> (not callable)" if
        build_prompt / parse_response exists but is not callable.
    """
    function_names = ("build_prompt", "parse_response")

    def _problem(attr_name: str) -> Optional[str]:
        if not hasattr(module, attr_name):
            return attr_name
        if attr_name in function_names and not callable(getattr(module, attr_name)):
            return f"{attr_name} (not callable)"
        return None

    return [issue for issue in map(_problem, REQUIRED_ATTRIBUTES) if issue is not None]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def list_available_prompts() -> list[str]:
    """
    Return the names of prompt modules shipped next to this file.

    A module qualifies when its filename ends in '.py' and does not start
    with an underscore (so package internals like __init__.py are skipped).

    Returns:
        Sorted module names, without the '.py' suffix
    """
    package_dir = os.path.dirname(__file__)
    suffix = '.py'
    names = [
        entry[:-len(suffix)]
        for entry in os.listdir(package_dir)
        if entry.endswith(suffix) and not entry.startswith('_')
    ]
    names.sort()
    return names
|