weco 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {weco-0.3.0 → weco-0.3.2}/.gitignore +5 -0
- {weco-0.3.0 → weco-0.3.2}/PKG-INFO +2 -1
- {weco-0.3.0 → weco-0.3.2}/examples/cuda/README.md +9 -5
- {weco-0.3.0 → weco-0.3.2}/examples/cuda/evaluate.py +11 -8
- weco-0.3.2/examples/extract-line-plot/README.md +72 -0
- weco-0.3.2/examples/extract-line-plot/eval.py +370 -0
- weco-0.3.2/examples/extract-line-plot/guide.md +53 -0
- weco-0.3.2/examples/extract-line-plot/optimize.py +116 -0
- weco-0.3.2/examples/extract-line-plot/prepare_data.py +94 -0
- weco-0.3.2/examples/extract-line-plot/pyproject.toml +18 -0
- {weco-0.3.0 → weco-0.3.2}/examples/triton/README.md +8 -7
- weco-0.3.2/examples/triton/evaluate.py +107 -0
- weco-0.3.2/examples/triton/optimize.py +23 -0
- {weco-0.3.0 → weco-0.3.2}/pyproject.toml +2 -1
- {weco-0.3.0 → weco-0.3.2}/weco/api.py +84 -46
- {weco-0.3.0 → weco-0.3.2}/weco/constants.py +0 -4
- {weco-0.3.0 → weco-0.3.2}/weco/optimizer.py +0 -2
- {weco-0.3.0 → weco-0.3.2}/weco/utils.py +41 -11
- {weco-0.3.0 → weco-0.3.2}/weco.egg-info/PKG-INFO +2 -1
- {weco-0.3.0 → weco-0.3.2}/weco.egg-info/SOURCES.txt +6 -0
- {weco-0.3.0 → weco-0.3.2}/weco.egg-info/requires.txt +1 -0
- weco-0.3.0/examples/triton/evaluate.py +0 -143
- weco-0.3.0/examples/triton/optimize.py +0 -44
- {weco-0.3.0 → weco-0.3.2}/.github/workflows/lint.yml +0 -0
- {weco-0.3.0 → weco-0.3.2}/.github/workflows/release.yml +0 -0
- {weco-0.3.0 → weco-0.3.2}/LICENSE +0 -0
- {weco-0.3.0 → weco-0.3.2}/README.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/assets/example-optimization.gif +0 -0
- {weco-0.3.0 → weco-0.3.2}/assets/weco.svg +0 -0
- {weco-0.3.0 → weco-0.3.2}/contributing.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/README.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/cuda/optimize.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/README.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/colab_notebook_walkthrough.ipynb +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/evaluate.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/optimize.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/prompt/README.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/prompt/eval.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/prompt/optimize.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/prompt/prompt_guide.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/README.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/competition_description.md +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/data/sample_submission.csv +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/data/test.csv +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/data/train.csv +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/evaluate.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/train.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/setup.cfg +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco/__init__.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco/auth.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco/chatbot.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco/cli.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco/credits.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco/panels.py +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco.egg-info/dependency_links.txt +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco.egg-info/entry_points.txt +0 -0
- {weco-0.3.0 → weco-0.3.2}/weco.egg-info/top_level.txt +0 -0

{weco-0.3.0 → weco-0.3.2}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: weco
-Version: 0.3.0
+Version: 0.3.2
 Summary: Documentation for `weco`, a CLI for using Weco AI's code optimizer.
 Author-email: Weco AI Team <contact@weco.ai>
 License:
@@ -219,6 +219,7 @@ Requires-Dist: packaging
 Requires-Dist: gitingest
 Requires-Dist: fastapi
 Requires-Dist: slowapi
+Requires-Dist: psutil
 Provides-Extra: dev
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: build; extra == "dev"
```

{weco-0.3.0 → weco-0.3.2}/examples/cuda/README.md

````diff
@@ -7,9 +7,11 @@ This approach aims for low-level optimization beyond standard PyTorch or even Tr
 
 Install the CLI and dependencies for the example:
 ```bash
-pip install weco torch
+pip install weco ninja numpy torch triton
 ```
-> **Note:**
+> **Note:**
+> 1. This example requires a compatible NVIDIA GPU and the CUDA Toolkit installed on your system for compiling and running the generated CUDA code.
+> 2. If compatible, install [flash attention](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) (`pip install flash-attn --no-build-isolation`).
 
 ## Run Weco
 
@@ -20,8 +22,9 @@ weco run --source optimize.py \
   --metric speedup \
   --goal maximize \
   --steps 50 \
-  --model
-  --additional-instructions "Write in-line CUDA using pytorch's load_inline() to optimize the code while ensuring a small max float diff. Maintain the same code
+  --model gpt-5 \
+  --additional-instructions "Write in-line CUDA using pytorch's load_inline() to optimize the code while ensuring a small max float diff. Maintain the same code interface. Do not use any fallbacks and never use the build_directory arg for load_inline(). Assume any required dependencies are installed and data is already on the gpu." \
+  --eval-timeout 600
 ```
 
 ### Explanation
@@ -31,8 +34,9 @@ weco run --source optimize.py \
 * `--metric speedup`: The optimization target metric.
 * `--goal maximize`: Weco aims to increase the speedup.
 * `--steps 50`: The number of optimization iterations.
-* `--model
+* `--model gpt-5`: The LLM used for code generation.
 * `--additional-instructions "..."`: Provides guidance to the LLM on the optimization approach.
+* `--eval-timeout 600`: Stops the evaluation script if it does not complete within 600 seconds.
 
 Weco will iteratively modify `optimize.py`, generating and integrating CUDA C++ code, guided by the evaluation results and the additional instructions provided.
 
````
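
The additional instructions above revolve around PyTorch's `torch.utils.cpp_extension.load_inline()`. For readers who have not used it, here is a minimal sketch of the pattern Weco is asked to generate; the `scale` kernel is invented for illustration and is not code from this package (`load_inline` itself prepends the `torch/extension.h` and CUDA runtime includes):

```python
# Minimal sketch of the in-line CUDA pattern (hypothetical kernel, illustration only).
import torch
from torch.utils.cpp_extension import load_inline

cuda_source = r"""
__global__ void scale_kernel(const float* x, float* y, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i] * s;
}

torch::Tensor scale(torch::Tensor x, float s) {
    auto y = torch::empty_like(x);
    int n = x.numel();
    const int threads = 256;
    scale_kernel<<<(n + threads - 1) / threads, threads>>>(
        x.data_ptr<float>(), y.data_ptr<float>(), s, n);
    return y;
}
"""

# cpp_sources declares the function, cuda_sources defines it, and
# functions=["scale"] generates the Python binding.
ext = load_inline(
    name="scale_ext",
    cpp_sources="torch::Tensor scale(torch::Tensor x, float s);",
    cuda_sources=cuda_source,
    functions=["scale"],
)

x = torch.randn(1024, device="cuda")
print(ext.scale(x, 2.0))  # data already on the GPU, as the instructions assume
```

Note that the instructions forbid the `build_directory` argument to `load_inline()`; the evaluation script below controls build locations through `TORCH_EXTENSIONS_DIR` instead.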

{weco-0.3.0 → weco-0.3.2}/examples/cuda/evaluate.py

```diff
@@ -1,5 +1,6 @@
 import sys
 import os
+import shutil
 import pathlib
 import importlib
 import importlib.util
@@ -55,10 +56,9 @@ class Model(nn.Module):
 
 
 ########################################################
-#
+# Benchmark
 ########################################################
 def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
-    # Clean out all old compiled extensions to prevent namespace collisions during build
     module_path = pathlib.Path(module_path)
     name = module_path.stem
     spec = importlib.util.spec_from_file_location(name, module_path)
@@ -69,12 +69,6 @@ def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
     return mod
 
 
-########################################################
-# Benchmark
-########################################################
-os.environ["MAX_JOBS"] = "1"  # number of workers for building with ninja
-
-
 def get_inputs(batch_size, seq_len, n_embd, device):
     return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
 
@@ -86,6 +80,12 @@ if __name__ == "__main__":
     parser.add_argument("--solution-path", type=str, required=True)
     args = parser.parse_args()
 
+    # setup local cache for PyTorch extensions
+    cache_dir = pathlib.Path.cwd() / ".weco-temp/torch_extensions"
+    shutil.rmtree(cache_dir.parent, ignore_errors=True)
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    os.environ["TORCH_EXTENSIONS_DIR"] = str(cache_dir)
+
     # benchmarking parameters
     n_correctness_trials = 10
     correctness_tolerance = 1e-5
@@ -145,3 +145,6 @@ if __name__ == "__main__":
     t_avg_optimized = do_bench(lambda: solution_model(inputs), warmup=warmup_ms, rep=rep_ms)
     print(f"optimized time: {t_avg_optimized:.2f}ms")
     print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
+
+    # clean up
+    shutil.rmtree(cache_dir.parent, ignore_errors=True)
```
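
The `evaluate.py` changes above drop the global `MAX_JOBS` setting and instead isolate each run's compiled extensions: `TORCH_EXTENSIONS_DIR` is the environment variable PyTorch's extension builder consults for its build cache, so pointing it at a throwaway per-run directory prevents stale artifacts from leaking between candidate solutions. A sketch of the same idea packaged as a reusable helper (hypothetical; `evaluate.py` simply inlines these steps):

```python
# Sketch: per-run isolation of PyTorch extension builds (hypothetical helper).
import os
import pathlib
import shutil
from contextlib import contextmanager


@contextmanager
def isolated_torch_extensions(root: str = ".weco-temp"):
    """Build extensions into a fresh cache directory, then remove it."""
    cache_dir = pathlib.Path.cwd() / root / "torch_extensions"
    shutil.rmtree(cache_dir.parent, ignore_errors=True)
    cache_dir.mkdir(parents=True, exist_ok=True)
    previous = os.environ.get("TORCH_EXTENSIONS_DIR")
    os.environ["TORCH_EXTENSIONS_DIR"] = str(cache_dir)
    try:
        yield cache_dir
    finally:
        # restore the previous setting and remove the throwaway cache
        if previous is None:
            os.environ.pop("TORCH_EXTENSIONS_DIR", None)
        else:
            os.environ["TORCH_EXTENSIONS_DIR"] = previous
        shutil.rmtree(cache_dir.parent, ignore_errors=True)
```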

weco-0.3.2/examples/extract-line-plot/README.md (new file)

````diff
@@ -0,0 +1,72 @@
+## Extract Line Plot (Chart → CSV) with a VLM
+
+This example optimizes an AI feature that turns an image of a chart into a table in CSV format.
+
+### Prerequisites
+
+- Python 3.9+
+- `uv` installed (see `https://docs.astral.sh/uv/`)
+- An OpenAI API key in your environment:
+
+```bash
+export OPENAI_API_KEY=your_key_here
+```
+
+### Files
+
+- `prepare_data.py`: downloads ChartQA (full) and prepares a 100-sample subset of line charts.
+- `optimize.py`: baseline VLM function (`VLMExtractor.image_to_csv`) to be optimized.
+- `eval.py`: evaluation harness that runs the baseline on images and reports a similarity score as "accuracy".
+
+Generated artifacts (gitignored):
+- `subset_line_100/` and `subset_line_100.zip`
+- `predictions/`
+
+### 1) Prepare the data
+
+From the repo root or this directory:
+
+```bash
+cd examples/extract-line-plot
+uv run --with huggingface_hub python prepare_data.py
+```
+
+Notes:
+- Downloads the ChartQA dataset snapshot and auto-extracts `ChartQA Dataset.zip` if needed.
+- Produces `subset_line_100/` with `index.csv`, `images/`, and `tables/`.
+
+### 2) Run a baseline evaluation once
+
+```bash
+uv run --with openai python eval.py --max-samples 10 --num-workers 4
+```
+
+This writes predicted CSVs to `predictions/` and prints a final line like `accuracy: 0.32`.
+
+Metric definition (summarized):
+- Per-sample score = 0.2 × header match + 0.8 × content-row similarity.
+- Reported `accuracy` is the mean score over all evaluated samples.
+
+### 3) Optimize the baseline with Weco
+
+Run Weco to iteratively improve `optimize.py` using 100 examples and many workers:
+
+```bash
+weco run --source optimize.py --eval-command 'uv run --with openai python eval.py --max-samples 100 --num-workers 50' --metric accuracy --goal maximize --steps 20 --model gpt-5
+```
+
+Arguments:
+- `--source optimize.py`: file that Weco will edit to improve results.
+- `--eval-command '…'`: command Weco executes to measure the metric.
+- `--metric accuracy`: Weco parses `accuracy: <value>` from `eval.py` output.
+- `--goal maximize`: higher is better.
+- `--steps 20`: number of optimization iterations.
+- `--model gpt-5`: model used by Weco to propose edits (change as desired).
+
+### Tips
+
+- Ensure your OpenAI key has access to a vision-capable model (default: `gpt-4o-mini` in the eval; change via `--model`).
+- Adjust `--num-workers` to balance throughput and rate limits.
+- You can tweak baseline behavior in `optimize.py` (prompt, temperature); Weco will explore modifications automatically during optimization.
+
+
````
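
To make the metric weighting concrete, here is a tiny worked example with invented numbers (the real scoring, including the SMAPE-based row similarity, is in `eval.py` below):

```python
# Invented numbers: exact header match, content rows averaging 0.5 similarity.
header_score = 1.0
content_score = 0.5
print(0.2 * header_score + 0.8 * content_score)  # 0.6
```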

weco-0.3.2/examples/extract-line-plot/eval.py (new file)

```diff
@@ -0,0 +1,370 @@
+import argparse
+import csv
+import math
+import os
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from optimize import VLMExtractor
+
+try:
+    import matplotlib
+
+    matplotlib.use("Agg")
+    import matplotlib.pyplot as plt
+except Exception:  # pragma: no cover - optional dependency
+    plt = None
+
+
+def read_index(index_csv_path: Path) -> List[Tuple[str, Path, Path]]:
+    rows: List[Tuple[str, Path, Path]] = []
+    with open(index_csv_path, "r", encoding="utf-8") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            rows.append((row["id"].strip(), Path(row["image"].strip()), Path(row["table"].strip())))
+    return rows
+
+
+def write_csv(output_dir: Path, example_id: str, csv_text: str) -> Path:
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out_path = output_dir / f"{example_id}.csv"
+    out_path.write_text(csv_text, encoding="utf-8")
+    return out_path
+
+
+def read_csv_table(path: Path) -> Tuple[List[str], List[List[str]]]:
+    header: List[str] = []
+    rows: List[List[str]] = []
+    with open(path, "r", encoding="utf-8") as f:
+        reader = csv.reader(f)
+        for row in reader:
+            cleaned = [cell.strip() for cell in row]
+            if not any(cleaned):
+                continue
+            if not header:
+                header = cleaned
+            else:
+                rows.append(cleaned)
+    return header, rows
+
+
+def header_match_score(gt_header: List[str], pred_header: List[str]) -> float:
+    if not gt_header or not pred_header:
+        return 0.0
+    normalized_gt = [cell.strip().lower() for cell in gt_header]
+    normalized_pred = [cell.strip().lower() for cell in pred_header]
+    return 1.0 if normalized_gt == normalized_pred else 0.0
+
+
+def _to_float(cell: str) -> Optional[float]:
+    cell = cell.strip()
+    if not cell:
+        return None
+    normalized = cell.replace(",", "")
+    if normalized.endswith("%"):
+        normalized = normalized[:-1]
+    if normalized.startswith("$"):
+        normalized = normalized[1:]
+    try:
+        return float(normalized)
+    except ValueError:
+        return None
+
+
+def _build_row_map(rows: List[List[str]]) -> Dict[str, List[List[str]]]:
+    mapping: Dict[str, List[List[str]]] = {}
+    for row in rows:
+        if not row:
+            continue
+        key = row[0].strip().lower()
+        mapping.setdefault(key, []).append(row[1:])
+    return mapping
+
+
+def _score_row(gt_values: List[str], pred_values: List[str]) -> float:
+    if not gt_values:
+        return 1.0
+
+    per_column_scores: List[float] = []
+    for idx, gt_cell in enumerate(gt_values):
+        pred_cell = pred_values[idx] if idx < len(pred_values) else ""
+        gt_num = _to_float(gt_cell)
+        pred_num = _to_float(pred_cell)
+        if gt_num is None or pred_num is None:
+            per_column_scores.append(0.0)
+            continue
+        denom = abs(gt_num) + abs(pred_num)
+        if denom == 0:
+            per_column_scores.append(1.0)
+            continue
+        smape = 2.0 * abs(pred_num - gt_num) / denom
+        per_column_scores.append(max(0.0, 1.0 - smape / 2.0))
+
+    if not per_column_scores:
+        return 0.0
+    return sum(per_column_scores) / len(per_column_scores)
+
+
+def visualize_difference(
+    gt_csv_path: Path,
+    pred_csv_path: Path,
+    example_id: str,
+    output_dir: Path,
+    ignore_header_mismatch: bool = False,
+    verbose: bool = False,
+) -> Optional[Path]:
+    def _viz_skip(reason: str) -> None:
+        if verbose:
+            print(f"[viz] skip {example_id}: {reason}", file=sys.stderr)
+
+    if plt is None:
+        _viz_skip("matplotlib not available")
+        return None
+
+    try:
+        gt_header, gt_rows = read_csv_table(gt_csv_path)
+        pred_header, pred_rows = read_csv_table(pred_csv_path)
+    except Exception:
+        _viz_skip("failed to read CSVs")
+        return None
+
+    if not gt_header or not pred_header:
+        _viz_skip("missing headers")
+        return None
+
+    if header_match_score(gt_header, pred_header) < 1.0 and not ignore_header_mismatch:
+        _viz_skip("header mismatch (use --visualize-allow-header-mismatch to override)")
+        return None
+
+    if len(gt_header) <= 1:
+        _viz_skip("no data columns in GT header")
+        return None
+
+    columns = gt_header[1:]
+    gt_series: Dict[str, List[float]] = {col: [] for col in columns}
+    pred_series: Dict[str, List[float]] = {col: [] for col in columns}
+    diff_series: Dict[str, List[float]] = {col: [] for col in columns}
+    x_labels: List[str] = []
+
+    pred_map = _build_row_map(pred_rows)
+    pred_consumed: Dict[str, int] = {}
+
+    for gt_row in gt_rows:
+        if not gt_row:
+            continue
+        x_label = gt_row[0].strip()
+        key = x_label.lower()
+        pred_entries = pred_map.get(key, [])
+        pred_idx = pred_consumed.get(key, 0)
+        pred_values = pred_entries[pred_idx] if pred_idx < len(pred_entries) else []
+        pred_consumed[key] = pred_idx + 1
+
+        x_labels.append(x_label or f"row_{len(x_labels) + 1}")
+
+        for col_idx, col_name in enumerate(columns):
+            gt_val = _to_float(gt_row[col_idx + 1]) if col_idx + 1 < len(gt_row) else None
+            pred_val = _to_float(pred_values[col_idx]) if col_idx < len(pred_values) else None
+            gt_float = gt_val if gt_val is not None else math.nan
+            pred_float = pred_val if pred_val is not None else math.nan
+
+            if math.isnan(gt_float):
+                # If GT is missing, treat as zero difference but keep nan for plotting gaps
+                diff = math.nan
+            elif math.isnan(pred_float):
+                diff = math.nan
+            else:
+                diff = pred_float - gt_float
+
+            gt_series[col_name].append(gt_float)
+            pred_series[col_name].append(pred_float)
+            diff_series[col_name].append(diff)
+
+    if not x_labels:
+        _viz_skip("no x labels in GT rows")
+        return None
+
+    num_series = len(columns)
+    if num_series == 0:
+        _viz_skip("no numeric series to plot")
+        return None
+
+    x_positions = list(range(len(x_labels)))
+    fig_height = max(3.0, 3.0 * num_series)
+    fig, axes = plt.subplots(num_series, 1, sharex=True, figsize=(11, fig_height))
+    if num_series == 1:
+        axes = [axes]
+
+    for ax, col_name in zip(axes, columns):
+        gt_values = gt_series[col_name]
+        pred_values = pred_series[col_name]
+        diff_values = diff_series[col_name]
+
+        ax.plot(x_positions, gt_values, marker="o", linewidth=1.5, label="Ground Truth")
+        ax.plot(x_positions, pred_values, marker="o", linewidth=1.5, label="Prediction")
+        ax.set_ylabel(col_name)
+        ax.grid(True, axis="y", alpha=0.3)
+
+        has_diff = any(not math.isnan(v) for v in diff_values)
+        legend_handles, legend_labels = ax.get_legend_handles_labels()
+
+        if has_diff:
+            ax2 = ax.twinx()
+            ax2.plot(x_positions, diff_values, linestyle="--", color="tab:red", marker="x", linewidth=1.2, label="Pred - GT")
+            ax2.axhline(0.0, color="tab:red", linewidth=0.8, alpha=0.4)
+            ax2.set_ylabel("Pred - GT")
+            handles2, labels2 = ax2.get_legend_handles_labels()
+            legend_handles += handles2
+            legend_labels += labels2
+
+        if legend_handles:
+            ax.legend(legend_handles, legend_labels, loc="upper left")
+
+    axes[-1].set_xticks(x_positions)
+    axes[-1].set_xticklabels(x_labels, rotation=45, ha="right")
+    axes[-1].set_xlabel(gt_header[0])
+
+    fig.suptitle(f"{example_id}: Ground Truth vs Prediction", fontsize=14)
+    fig.tight_layout(rect=(0, 0, 1, 0.96))
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out_path = output_dir / f"{example_id}.png"
+    fig.savefig(out_path, dpi=150)
+    plt.close(fig)
+    return out_path
+
+
+def evaluate_predictions(gt_csv_path: Path, pred_csv_path: Path) -> float:
+    gt_header, gt_rows = read_csv_table(gt_csv_path)
+    pred_header, pred_rows = read_csv_table(pred_csv_path)
+    if not gt_header or not pred_header:
+        return 0.0
+
+    header_score = header_match_score(gt_header, pred_header)
+
+    gt_map = _build_row_map(gt_rows)
+    pred_map = _build_row_map(pred_rows)
+
+    row_scores: List[float] = []
+    for key, gt_entries in gt_map.items():
+        pred_entries = pred_map.get(key, [])
+        for idx, gt_values in enumerate(gt_entries):
+            pred_values = pred_entries[idx] if idx < len(pred_entries) else []
+            row_scores.append(_score_row(gt_values, pred_values))
+
+    content_score = sum(row_scores) / len(row_scores) if row_scores else 0.0
+    return 0.2 * header_score + 0.8 * content_score
+
+
+def process_one(
+    extractor: VLMExtractor, base_dir: Path, example_id: str, image_rel: Path, gt_table_rel: Path, output_dir: Path
+) -> Tuple[str, float, Path, Path]:
+    image_path = base_dir / image_rel
+    gt_csv_path = base_dir / gt_table_rel
+    pred_csv_text = extractor.image_to_csv(image_path)
+    pred_path = write_csv(output_dir, example_id, pred_csv_text)
+    score = evaluate_predictions(gt_csv_path, pred_path)
+    return example_id, score, pred_path, gt_csv_path
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Evaluate VLM extraction: image -> CSV")
+    parser.add_argument("--data-dir", type=str, default="subset_line_100")
+    parser.add_argument("--index", type=str, default="index.csv")
+    parser.add_argument("--out-dir", type=str, default="predictions")
+    parser.add_argument("--max-samples", type=int, default=100)
+    parser.add_argument("--num-workers", type=int, default=4)
+    parser.add_argument(
+        "--visualize-dir",
+        type=str,
+        default=None,
+        help="Directory where GT vs prediction plots will be saved (requires matplotlib).",
+    )
+    parser.add_argument(
+        "--visualize-max",
+        type=int,
+        default=1,
+        help="Maximum number of plots to generate when --visualize-dir is set. Use 0 for no limit.",
+    )
+    parser.add_argument(
+        "--visualize-allow-header-mismatch",
+        action="store_true",
+        help="If set, still plot GT vs prediction even when headers differ.",
+    )
+    parser.add_argument("--visualize-verbose", action="store_true", help="Print reasons when a visualization is skipped.")
+    args = parser.parse_args()
+
+    if not os.getenv("OPENAI_API_KEY"):
+        print("[error] OPENAI_API_KEY not set in environment", file=sys.stderr)
+        sys.exit(1)
+
+    base_dir = Path(args.data_dir)
+    index_path = base_dir / args.index
+    if not index_path.exists():
+        print(f"[error] index.csv not found at {index_path}", file=sys.stderr)
+        sys.exit(1)
+
+    rows = read_index(index_path)[: args.max_samples]
+    extractor = VLMExtractor()
+
+    visualize_dir: Optional[Path] = Path(args.visualize_dir) if args.visualize_dir else None
+    visualize_max = max(0, args.visualize_max)
+    if visualize_dir and plt is None:
+        print("[warn] matplotlib not available; skipping visualization.", file=sys.stderr)
+        visualize_dir = None
+
+    print(f"[setup] evaluating {len(rows)} samples using {extractor.model} …", flush=True)
+    start = time.time()
+    scores: List[float] = []
+    saved_visualizations = 0
+
+    with ThreadPoolExecutor(max_workers=max(1, args.num_workers)) as pool:
+        futures = [
+            pool.submit(process_one, extractor, base_dir, example_id, image_rel, gt_table_rel, Path(args.out_dir))
+            for (example_id, image_rel, gt_table_rel) in rows
+        ]
+
+        try:
+            for idx, fut in enumerate(as_completed(futures), 1):
+                try:
+                    example_id, score, pred_path, gt_csv_path = fut.result()
+                    scores.append(score)
+                    if visualize_dir and (visualize_max == 0 or saved_visualizations < visualize_max):
+                        out_path = visualize_difference(
+                            gt_csv_path,
+                            pred_path,
+                            example_id,
+                            visualize_dir,
+                            ignore_header_mismatch=args.visualize_allow_header_mismatch,
+                            verbose=args.visualize_verbose,
+                        )
+                        if out_path is not None:
+                            saved_visualizations += 1
+                            print(f"[viz] saved {out_path}", flush=True)
+                    if idx % 5 == 0 or idx == len(rows):
+                        elapsed = time.time() - start
+                        avg = sum(scores) / len(scores) if scores else 0.0
+                        print(f"[progress] {idx}/{len(rows)} done, avg score: {avg:.4f}, elapsed {elapsed:.1f}s", flush=True)
+                except Exception as e:
+                    print(f"[error] failed on sample {idx}: {e}", file=sys.stderr)
+        except KeyboardInterrupt:
+            print("\n[warn] interrupted by user", file=sys.stderr)
+            sys.exit(1)
+
+    final_score = sum(scores) / len(scores) if scores else 0.0
+
+    # Apply cost cap: accuracy is zeroed if average cost/query exceeds $0.02
+    avg_cost_per_query = (extractor.total_cost_usd / extractor.num_queries) if getattr(extractor, "num_queries", 0) else 0.0
+    if avg_cost_per_query > 0.02:
+        print(f"[cost] avg ${avg_cost_per_query:.4f}/query exceeds $0.02 cap; accuracy set to 0.0", flush=True)
+        final_score = 0.0
+    else:
+        print(f"[cost] avg ${avg_cost_per_query:.4f}/query within cap", flush=True)
+
+    print(f"accuracy: {final_score:.4f}")
+
+
+if __name__ == "__main__":
+    main()
```
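
The cost cap at the end of `main()` assumes the extractor exposes `total_cost_usd` and `num_queries` counters. The real implementation lives in `optimize.py` (+116 lines, not included in this section); a hypothetical sketch of such bookkeeping against the Chat Completions `usage` fields, with illustrative prices, might look like:

```python
# Hypothetical cost-tracking sketch; the real VLMExtractor is defined in optimize.py (not shown here).
from openai import OpenAI

# Illustrative per-million-token prices; check current pricing before relying on them.
INPUT_USD_PER_1M = 0.15
OUTPUT_USD_PER_1M = 0.60


class VLMExtractor:
    def __init__(self, model: str = "gpt-4o-mini"):
        self.client = OpenAI()
        self.model = model
        self.total_cost_usd = 0.0
        self.num_queries = 0

    def _track_usage(self, usage) -> None:
        # usage.prompt_tokens / usage.completion_tokens arrive with every response
        self.total_cost_usd += (
            usage.prompt_tokens * INPUT_USD_PER_1M
            + usage.completion_tokens * OUTPUT_USD_PER_1M
        ) / 1_000_000
        self.num_queries += 1
```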

weco-0.3.2/examples/extract-line-plot/guide.md (new file)

````diff
@@ -0,0 +1,53 @@
+# Constraints
+- Make sure the cost tracking is correct
+- The average cost per query should be below $0.02; the baseline is around $0.007 per query
+
+# Here's a list of ideas to try:
+1. Try more recent models like gpt-5; note that gpt-5 doesn't have a temperature argument
+2. Try the Responses API as below:
+
+<api>
+Responses API (concise)
+
+Endpoint:
+- POST https://api.openai.com/v1/responses
+
+Purpose:
+- Generate text or JSON from text and/or image inputs. Supports optional tools.
+
+Minimum request fields:
+- model: string (e.g., gpt-5, gpt-4o)
+- input: string | array (text and/or image items)
+
+Useful options:
+- max_output_tokens: integer
+- temperature or top_p: sampling controls (non-reasoning models)
+- reasoning: object (gpt-5, o-series): controls for reasoning models
+- tools: array (web_search, file_search, function calls)
+- tool_choice: string | object
+- stream: boolean (default false)
+- store: boolean (default true)
+- parallel_tool_calls: boolean (default true)
+
+Conversation (optional):
+- conversation: string | object, or previous_response_id: string
+
+Notes:
+- gpt-5 does not support temperature; prefer reasoning options instead.
+
+Example request:
+
+```json
+{
+  "model": "gpt-5",
+  "input": "Explain how to extract a line plot from an image.",
+  "max_output_tokens": 400,
+  "reasoning": { "effort": "medium" },
+  "stream": false
+}
+```
+</api>
+
+3. Try playing with parameters like reasoning effort
+4. Try building tools for the Responses API to use, or allow a Python interpreter
+5. Try adding preprocessing or post-processing pipelines
````
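
The guide describes the Responses API only as a raw JSON body. For idea 2, a minimal sketch of the equivalent call from Python with an image input, using the official `openai` package (file path and prompt are illustrative):

```python
# Sketch: Responses API call with an image input (illustrative file path and prompt).
import base64
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

with open("chart.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

resp = client.responses.create(
    model="gpt-5",
    input=[{
        "role": "user",
        "content": [
            {"type": "input_text", "text": "Extract this line chart as CSV."},
            {"type": "input_image", "image_url": f"data:image/png;base64,{image_b64}"},
        ],
    }],
    reasoning={"effort": "medium"},  # gpt-5 takes reasoning options, not temperature
    max_output_tokens=400,
)
print(resp.output_text)
```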