nano-eval 0.2.4__tar.gz
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- nano_eval-0.2.4/LICENSE +21 -0
- nano_eval-0.2.4/PKG-INFO +133 -0
- nano_eval-0.2.4/README.md +100 -0
- nano_eval-0.2.4/core.py +328 -0
- nano_eval-0.2.4/nano_eval.egg-info/PKG-INFO +133 -0
- nano_eval-0.2.4/nano_eval.egg-info/SOURCES.txt +15 -0
- nano_eval-0.2.4/nano_eval.egg-info/dependency_links.txt +1 -0
- nano_eval-0.2.4/nano_eval.egg-info/entry_points.txt +2 -0
- nano_eval-0.2.4/nano_eval.egg-info/requires.txt +15 -0
- nano_eval-0.2.4/nano_eval.egg-info/top_level.txt +3 -0
- nano_eval-0.2.4/nano_eval.py +273 -0
- nano_eval-0.2.4/pyproject.toml +51 -0
- nano_eval-0.2.4/setup.cfg +4 -0
- nano_eval-0.2.4/tasks/__init__.py +14 -0
- nano_eval-0.2.4/tasks/chartqa.py +121 -0
- nano_eval-0.2.4/tasks/gsm8k.py +155 -0
- nano_eval-0.2.4/tests/test_nano_eval.py +229 -0
nano_eval-0.2.4/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Thomas Børstad
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
nano_eval-0.2.4/PKG-INFO
ADDED
@@ -0,0 +1,133 @@
+Metadata-Version: 2.4
+Name: nano-eval
+Version: 0.2.4
+Summary: Nano Eval - A minimal tool for verifying VLMs/LLMs across frameworks
+Author-email: Thomas Børstad <tboerstad@users.noreply.github.com>
+License-Expression: MIT
+Project-URL: Homepage, https://github.com/tboerstad/nano-eval
+Project-URL: Repository, https://github.com/tboerstad/nano-eval
+Project-URL: Issues, https://github.com/tboerstad/nano-eval/issues
+Keywords: llm,evaluation,benchmark,openai,api,vllm,tgi
+Classifier: Development Status :: 4 - Beta
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: OS Independent
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: click>=8.0.0
+Requires-Dist: datasets>=2.0.0
+Requires-Dist: httpx>=0.23.0
+Requires-Dist: pillow>=9.0.0
+Requires-Dist: tqdm>=4.62.0
+Requires-Dist: typing-extensions>=4.0.0
+Requires-Dist: multiprocess<=0.70.17
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
+Requires-Dist: ruff>=0.1.0; extra == "dev"
+Requires-Dist: pre-commit>=3.0.0; extra == "dev"
+Requires-Dist: respx>=0.20.0; extra == "dev"
+Requires-Dist: ty>=0.0.1; extra == "dev"
+Dynamic: license-file
+
+**nano-eval** is a minimal tool for measuring the quality of a text or vision model.
+
+## Quickstart
+
+```bash
+uvx nano-eval -t text -t vision --base-url http://localhost:8000/v1 --max-samples 100
+
+# prints:
+Task     Accuracy   Samples   Duration
+------   --------   -------   --------
+text     84.3%      100       45s
+vision   71.8%      100       38s
+```
+
+> **Note:** This tool is for eyeballing the accuracy of a model. One use case is comparing accuracy between inference frameworks (e.g., vLLM vs SGLang vs MAX running the same model).
+
+## Supported Types
+
+| Type | Dataset | Description |
+|------|---------|-------------|
+| `text` | gsm8k_cot_llama | Grade school math with chain-of-thought (8-shot) |
+| `vision` | HuggingFaceM4/ChartQA | Chart question answering with images |
+
+## Usage
+
+```
+$ nano-eval --help
+Usage: nano-eval [OPTIONS]
+
+  Evaluate LLMs on standardized tasks via OpenAI-compatible APIs.
+
+  Example: nano-eval -t text --base-url http://localhost:8000/v1
+
+Options:
+  -t, --type [text|vision]     Type to evaluate (can be repeated)
+                               [required]
+  --base-url TEXT              OpenAI-compatible API endpoint  [required]
+  --model TEXT                 Model name; auto-detected if endpoint serves
+                               one model
+  --api-key TEXT               Bearer token for API authentication
+  --max-concurrent INTEGER     [default: 8]
+  --extra-request-params TEXT  API params as key=value,...  [default:
+                               temperature=0,max_tokens=256,seed=42]
+  --max-samples INTEGER        If provided, limit samples per task
+  --output-path PATH           Write results.json and sample logs to this
+                               directory
+  --log-samples                Save per-sample results as JSONL (requires
+                               --output-path)
+  --seed INTEGER               Controls sample order  [default: 42]
+  -v, --verbose                Increase verbosity (up to -vv)
+  --version                    Show the version and exit.
+  --help                       Show this message and exit.
+```
+
+### Python API
+
+```python
+import asyncio
+from nano_eval import evaluate, EvalResult
+
+result: EvalResult = asyncio.run(evaluate(
+    types=["text"],
+    base_url="http://localhost:8000/v1",
+    model="google/gemma-3-4b-it",
+    max_samples=100,
+))
+text_result = result["results"]["text"]
+print(f"Accuracy: {text_result['metrics']['exact_match']:.1%}")
+```
+
+
+This tool is inspired by and borrows from [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). Please check it out!
+
+## Example Output
+
+When using `--output-path`, a `results.json` file is generated:
+
+```json
+{
+  "config": {
+    "max_samples": 37,
+    "model": "google/gemma-3-4b-it"
+  },
+  "framework_version": "0.2.1",
+  "results": {
+    "text": {
+      "elapsed_seconds": 28.45,
+      "metrics": {
+        "exact_match": 0.7837837837837838,
+        "exact_match_stderr": 0.06861056852129647
+      },
+      "num_samples": 37,
+      "samples_hash": "12a1e9404db6afe810290a474d69cfebdaffefd0b56e48ac80e1fec0f286d659",
+      "task": "gsm8k_cot_llama",
+      "task_type": "text"
+    }
+  },
+  "total_seconds": 28.45
+}
+```
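The `--extra-request-params` default above packs the generation settings into a single `key=value,...` string that ends up as extra JSON fields on each chat-completion request (they land in `APIConfig.gen_kwargs`, merged into the payload in `core.py` below). As a rough sketch of the kind of parsing this implies (the real parser lives in `nano_eval.py`, which this diff section does not show; `parse_extra_params` and its type-coercion order are assumptions):

```python
def parse_extra_params(spec: str) -> dict:
    """Parse 'temperature=0,max_tokens=256,seed=42' into request kwargs."""
    params = {}
    for pair in spec.split(","):
        key, _, raw = pair.partition("=")
        try:
            value: object = int(raw)  # integers first (max_tokens, seed)
        except ValueError:
            try:
                value = float(raw)    # then floats (temperature=0.7)
            except ValueError:
                value = raw           # anything else stays a string
        params[key.strip()] = value
    return params

assert parse_extra_params("temperature=0,max_tokens=256,seed=42") == {
    "temperature": 0, "max_tokens": 256, "seed": 42,
}
```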
nano_eval-0.2.4/README.md
ADDED
@@ -0,0 +1,100 @@
+**nano-eval** is a minimal tool for measuring the quality of a text or vision model.
+
+## Quickstart
+
+```bash
+uvx nano-eval -t text -t vision --base-url http://localhost:8000/v1 --max-samples 100
+
+# prints:
+Task     Accuracy   Samples   Duration
+------   --------   -------   --------
+text     84.3%      100       45s
+vision   71.8%      100       38s
+```
+
+> **Note:** This tool is for eyeballing the accuracy of a model. One use case is comparing accuracy between inference frameworks (e.g., vLLM vs SGLang vs MAX running the same model).
+
+## Supported Types
+
+| Type | Dataset | Description |
+|------|---------|-------------|
+| `text` | gsm8k_cot_llama | Grade school math with chain-of-thought (8-shot) |
+| `vision` | HuggingFaceM4/ChartQA | Chart question answering with images |
+
+## Usage
+
+```
+$ nano-eval --help
+Usage: nano-eval [OPTIONS]
+
+  Evaluate LLMs on standardized tasks via OpenAI-compatible APIs.
+
+  Example: nano-eval -t text --base-url http://localhost:8000/v1
+
+Options:
+  -t, --type [text|vision]     Type to evaluate (can be repeated)
+                               [required]
+  --base-url TEXT              OpenAI-compatible API endpoint  [required]
+  --model TEXT                 Model name; auto-detected if endpoint serves
+                               one model
+  --api-key TEXT               Bearer token for API authentication
+  --max-concurrent INTEGER     [default: 8]
+  --extra-request-params TEXT  API params as key=value,...  [default:
+                               temperature=0,max_tokens=256,seed=42]
+  --max-samples INTEGER        If provided, limit samples per task
+  --output-path PATH           Write results.json and sample logs to this
+                               directory
+  --log-samples                Save per-sample results as JSONL (requires
+                               --output-path)
+  --seed INTEGER               Controls sample order  [default: 42]
+  -v, --verbose                Increase verbosity (up to -vv)
+  --version                    Show the version and exit.
+  --help                       Show this message and exit.
+```
+
+### Python API
+
+```python
+import asyncio
+from nano_eval import evaluate, EvalResult
+
+result: EvalResult = asyncio.run(evaluate(
+    types=["text"],
+    base_url="http://localhost:8000/v1",
+    model="google/gemma-3-4b-it",
+    max_samples=100,
+))
+text_result = result["results"]["text"]
+print(f"Accuracy: {text_result['metrics']['exact_match']:.1%}")
+```
+
+
+This tool is inspired by and borrows from [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). Please check it out!
+
+## Example Output
+
+When using `--output-path`, a `results.json` file is generated:
+
+```json
+{
+  "config": {
+    "max_samples": 37,
+    "model": "google/gemma-3-4b-it"
+  },
+  "framework_version": "0.2.1",
+  "results": {
+    "text": {
+      "elapsed_seconds": 28.45,
+      "metrics": {
+        "exact_match": 0.7837837837837838,
+        "exact_match_stderr": 0.06861056852129647
+      },
+      "num_samples": 37,
+      "samples_hash": "12a1e9404db6afe810290a474d69cfebdaffefd0b56e48ac80e1fec0f286d659",
+      "task": "gsm8k_cot_llama",
+      "task_type": "text"
+    }
+  },
+  "total_seconds": 28.45
+}
+```
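Since `results.json` reports both `exact_match` and `exact_match_stderr`, the point estimate is easy to turn into a rough 95% confidence interval when eyeballing two runs against each other. A minimal sketch, using only the fields shown above and the usual normal approximation:

```python
import json

with open("results.json") as f:
    report = json.load(f)

metrics = report["results"]["text"]["metrics"]
acc, se = metrics["exact_match"], metrics["exact_match_stderr"]
# Normal approximation: accuracy ± 1.96 · stderr
print(f"exact_match = {acc:.1%} ± {1.96 * se:.1%}")  # 78.4% ± 13.4% for the run above
```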
nano_eval-0.2.4/core.py
ADDED
@@ -0,0 +1,328 @@
+"""
+Core utilities for nano-eval.
+
+Responsibilities:
+- APIConfig: endpoint, model, concurrency, timeout
+- Sample/Task: minimal task abstraction (generator + scorer)
+- complete(): async batch chat completions (OpenAI-compatible)
+- run_task(): evaluate a Task, return TaskResult
+- _normalize(): text normalization for comparison
+- _encode_image(): PIL→base64; rejects remote URLs
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import hashlib
+import logging
+import math
+import re
+import time
+from collections.abc import Callable
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from io import BytesIO
+from pathlib import Path
+from typing import Any
+
+import datasets.config as ds_config
+import httpx
+from PIL import Image
+from tqdm.asyncio import tqdm_asyncio
+from typing_extensions import NotRequired, TypedDict
+
+logger = logging.getLogger(__name__)
+
+
+class Metrics(TypedDict):
+    exact_match: float
+    exact_match_stderr: float
+
+
+class LoggedSample(TypedDict):
+    sample_id: int
+    target: str
+    prompt: str
+    response: str
+    exact_match: float
+
+
+class TaskResult(TypedDict):
+    elapsed_seconds: float
+    metrics: Metrics
+    num_samples: int
+    samples: NotRequired[list[LoggedSample]]
+    samples_hash: str
+    task: str
+    task_type: str
+
+
+@dataclass(frozen=True)
+class TextPrompt:
+    """Text-only prompt (simple string or pre-formatted messages)."""
+
+    text: str | list[dict[str, str]]
+
+
+@dataclass(frozen=True)
+class VisionPrompt:
+    """Multimodal prompt with text and images."""
+
+    text: str
+    images: list[Any]
+
+
+Input = TextPrompt | VisionPrompt
+
+
+@dataclass
+class Sample:
+    """A single evaluation sample: prompt + expected target."""
+
+    prompt: Input
+    target: str
+
+
+@dataclass(frozen=True)
+class Task:
+    """Minimal task definition: a loader of samples + a scoring function."""
+
+    name: str
+    task_type: str  # "text" or "vision"
+    samples: Callable[[int | None, int | None], list[Sample]]  # (max_samples, seed)
+    score: Callable[[str, str], float]  # (response, target) -> score
+
+
+@dataclass
+class APIConfig:
+    """API configuration."""
+
+    url: str
+    model: str
+    api_key: str = ""
+    max_concurrent: int = 8
+    timeout: int = 300
+    gen_kwargs: dict[str, Any] = field(default_factory=dict)
+
+
+async def _request(
+    client: httpx.AsyncClient,
+    url: str,
+    payload: dict[str, Any],
+) -> str:
+    """Single request. Raises RuntimeError on failure."""
+    resp = await client.post(url, json=payload)
+    if resp.is_success:
+        return resp.json()["choices"][0]["message"]["content"]
+    raise RuntimeError(f"Request failed: {resp.text}")
+
+
+async def complete(
+    prompts: list[Input],
+    config: APIConfig,
+    progress_desc: str = "Running evals",
+) -> list[str]:
+    """
+    Run batch of chat completions.
+
+    Args:
+        prompts: List of prompts (TextPrompt or VisionPrompt)
+        config: API configuration (includes gen_kwargs for temperature, max_tokens, etc.)
+
+    Returns:
+        List of response strings
+    """
+    headers = {"Content-Type": "application/json"}
+    if config.api_key:
+        headers["Authorization"] = f"Bearer {config.api_key}"
+
+    async with httpx.AsyncClient(
+        limits=httpx.Limits(max_connections=config.max_concurrent),
+        timeout=httpx.Timeout(config.timeout),
+        headers=headers,
+        trust_env=True,
+    ) as client:
+        tasks: list[asyncio.Task[str]] = []
+        for prompt in prompts:
+            if isinstance(prompt, VisionPrompt):
+                messages = _build_vision_message(prompt.text, prompt.images)
+            elif isinstance(prompt.text, list):
+                messages = prompt.text
+            else:
+                messages = [{"role": "user", "content": prompt.text}]
+
+            payload: dict[str, Any] = {
+                "model": config.model,
+                "messages": messages,
+                **config.gen_kwargs,
+            }
+
+            tasks.append(asyncio.create_task(_request(client, config.url, payload)))
+
+        try:
+            return list(
+                await tqdm_asyncio.gather(*tasks, desc=progress_desc, leave=False)
+            )
+        except BaseException:
+            # On any failure (including Ctrl-C), cancel all pending tasks and await
+            # them to properly "retrieve" their exceptions. Without this, Python logs
+            # "Task exception was never retrieved" for each concurrent task that failed.
+            # Using BaseException (not Exception) ensures cleanup runs on KeyboardInterrupt.
+            for task in tasks:
+                task.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
+
+
+def _build_vision_message(text: str, images: list[Any]) -> list[dict[str, Any]]:
+    """Build OpenAI vision API message."""
+    content: list[dict[str, Any]] = []
+    for img in images:
+        if b64 := _encode_image(img):
+            content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{b64}"},
+                }
+            )
+    content.append({"type": "text", "text": text.replace("<image>", "").strip()})
+    return [{"role": "user", "content": content}]
+
+
+def _encode_image(image: Any) -> str:
+    """Encode PIL image to base64, or pass through string."""
+    if isinstance(image, str):
+        if image.startswith("http"):
+            raise ValueError("Remote image URLs are not supported.")
+        return image
+
+    if isinstance(image, Image.Image):
+        try:
+            # Convert to RGB if needed to avoid save errors with CMYK/palette modes
+            if image.mode not in ("RGB", "L"):
+                image = image.convert("RGB")
+            buf = BytesIO()
+            image.save(buf, format="PNG")
+            return base64.b64encode(buf.getvalue()).decode()
+        except Exception as e:
+            raise ValueError(f"Failed to encode image: {e}") from e
+
+    raise TypeError(f"Unsupported image type: {type(image).__name__}")
+
+
+def _normalize(text: str) -> str:
+    """Normalize text for comparison."""
+    text = re.sub(r"[$,]", "", text)
+    text = re.sub(r"(?s).*#### ", "", text)
+    text = re.sub(r"\.$", "", text)
+    return text.lower().strip()
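+
+# Example: _normalize("The answer is #### $1,234.") -> "1234"
+# ("$" and "," are stripped first, everything through the final "#### " is
+# dropped, a trailing "." is removed, then the result is lowercased/stripped).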
+
+
+def _prompt_to_str(prompt: Input) -> str:
+    """Extract text from prompt (handles TextPrompt and VisionPrompt)."""
+    if isinstance(prompt, VisionPrompt):
+        return prompt.text
+    if isinstance(prompt.text, list):
+        return "\n".join(f"{m['role']}: {m['content']}" for m in prompt.text)
+    return prompt.text
+
+
+def compute_samples_hash(samples: list[Sample]) -> str:
+    """Compute SHA256 hash for all samples in a task (includes image data)."""
+    hasher = hashlib.sha256()
+    for s in samples:
+        hasher.update(_prompt_to_str(s.prompt).encode())
+        if isinstance(s.prompt, VisionPrompt):
+            for img in s.prompt.images:
+                hasher.update(_encode_image(img).encode())
+        hasher.update(s.target.encode())
+    return hasher.hexdigest()
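+
+# The hash lets two runs (e.g. the same model served by vLLM and by SGLang)
+# confirm they scored byte-identical prompts, images, and targets before
+# their exact_match numbers are compared.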
+
+
+async def run_task(
+    task: Task,
+    config: APIConfig,
+    max_samples: int | None = None,
+    seed: int | None = None,
+) -> TaskResult:
+    """
+    Evaluate a task: collect samples, run inference, compute scores.
+
+    Args:
+        task: Task definition with samples loader and scoring function
+        config: API configuration (includes gen_kwargs for temperature, max_tokens, etc.)
+        max_samples: Optional limit on number of samples
+        seed: Optional seed for shuffling sample order
+
+    Returns:
+        TaskResult with metrics, sample count, elapsed time, and per-sample data
+    """
+    samples = task.samples(max_samples, seed)
+    samples_hash = compute_samples_hash(samples)
+    prompts = [s.prompt for s in samples]
+
+    logger.info(
+        f"Starting {task.task_type} ({task.name}) eval: "
+        f"{len(samples)} samples, up to {config.max_concurrent} concurrent requests"
+    )
+    t0 = time.perf_counter()
+    desc = "Running vision eval" if task.task_type == "vision" else "Running text eval"
+    responses = await complete(prompts, config, desc)
+    elapsed = time.perf_counter() - t0
+
+    scores = [task.score(r, s.target) for r, s in zip(responses, samples)]
+    n = len(samples)
+    accuracy = sum(scores) / n if n else 0.0
+    stderr = math.sqrt(accuracy * (1 - accuracy) / (n - 1)) if n > 1 else 0.0
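+    # Bernoulli standard error of the mean, with n - 1 in the denominator
+    # (sample variance); this is the exact_match_stderr reported in results.json.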
+
+    logger.debug(f"{task.name}: accuracy={accuracy:.4f}±{stderr:.4f} ({elapsed:.2f}s)")
+
+    # Always collect per-sample data for optional JSONL export (negligible overhead)
+    logged_samples: list[LoggedSample] = [
+        LoggedSample(
+            sample_id=i,
+            target=s.target,
+            prompt=_prompt_to_str(s.prompt),
+            response=r,
+            exact_match=score,
+        )
+        for i, (s, r, score) in enumerate(zip(samples, responses, scores))
+    ]
+    return TaskResult(
+        elapsed_seconds=round(elapsed, 2),
+        metrics=Metrics(exact_match=accuracy, exact_match_stderr=stderr),
+        num_samples=n,
+        samples=logged_samples,
+        samples_hash=samples_hash,
+        task=task.name,
+        task_type=task.task_type,
+    )
+
+
+@contextmanager
+def offline_if_cached(dataset: str, revision: str):
+    """Context manager: enable HF offline mode if dataset is cached (avoids HEAD requests).
+
+    Yields:
+        Tuple of (cached: bool, hf_home: Path) where cached indicates if dataset
+        is in cache and hf_home is the HuggingFace cache directory.
+    """
+    from huggingface_hub.constants import HF_HOME, HF_HUB_CACHE
+
+    hub_path = (
+        Path(HF_HUB_CACHE)
+        / f"datasets--{dataset.replace('/', '--')}"
+        / "snapshots"
+        / revision
+    )
+    cached = hub_path.is_dir()
+
+    if cached:
+        old = ds_config.HF_HUB_OFFLINE
+        ds_config.HF_HUB_OFFLINE = True
+    try:
+        yield cached, Path(HF_HOME)
+    finally:
+        if cached:
+            ds_config.HF_HUB_OFFLINE = old
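Taken together, `Task`, `APIConfig`, and `run_task()` are enough to wire up a custom evaluation without touching the built-in `tasks/` modules. A minimal sketch under two assumptions: the flat module layout this sdist shows (`core.py` at the top level, so `from core import ...`), and that `APIConfig.url` is the full chat-completions URL, since `complete()` posts to it directly; the endpoint, model name, and toy task are placeholders:

```python
import asyncio

# Flat-layout import matching core.py at the sdist root (an assumption; adjust
# if the installed distribution exposes the module under a different path).
from core import APIConfig, Sample, Task, TextPrompt, run_task


def _samples(max_samples, seed):
    # A toy two-sample loader; real loaders shuffle with `seed` and slice to
    # `max_samples`.
    samples = [
        Sample(prompt=TextPrompt("What is 2 + 2? Answer with a number only."), target="4"),
        Sample(prompt=TextPrompt("What is 10 - 3? Answer with a number only."), target="7"),
    ]
    return samples[:max_samples] if max_samples else samples


toy_task = Task(
    name="toy_arithmetic",
    task_type="text",
    samples=_samples,
    score=lambda response, target: float(target in response),  # crude containment match
)

config = APIConfig(
    url="http://localhost:8000/v1/chat/completions",  # full endpoint, not just /v1
    model="google/gemma-3-4b-it",
    gen_kwargs={"temperature": 0, "max_tokens": 32},
)

result = asyncio.run(run_task(toy_task, config, max_samples=2))
print(result["metrics"]["exact_match"])
```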