weco 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {weco-0.3.0 → weco-0.3.2}/.gitignore +5 -0
  2. {weco-0.3.0 → weco-0.3.2}/PKG-INFO +2 -1
  3. {weco-0.3.0 → weco-0.3.2}/examples/cuda/README.md +9 -5
  4. {weco-0.3.0 → weco-0.3.2}/examples/cuda/evaluate.py +11 -8
  5. weco-0.3.2/examples/extract-line-plot/README.md +72 -0
  6. weco-0.3.2/examples/extract-line-plot/eval.py +370 -0
  7. weco-0.3.2/examples/extract-line-plot/guide.md +53 -0
  8. weco-0.3.2/examples/extract-line-plot/optimize.py +116 -0
  9. weco-0.3.2/examples/extract-line-plot/prepare_data.py +94 -0
  10. weco-0.3.2/examples/extract-line-plot/pyproject.toml +18 -0
  11. {weco-0.3.0 → weco-0.3.2}/examples/triton/README.md +8 -7
  12. weco-0.3.2/examples/triton/evaluate.py +107 -0
  13. weco-0.3.2/examples/triton/optimize.py +23 -0
  14. {weco-0.3.0 → weco-0.3.2}/pyproject.toml +2 -1
  15. {weco-0.3.0 → weco-0.3.2}/weco/api.py +84 -46
  16. {weco-0.3.0 → weco-0.3.2}/weco/constants.py +0 -4
  17. {weco-0.3.0 → weco-0.3.2}/weco/optimizer.py +0 -2
  18. {weco-0.3.0 → weco-0.3.2}/weco/utils.py +41 -11
  19. {weco-0.3.0 → weco-0.3.2}/weco.egg-info/PKG-INFO +2 -1
  20. {weco-0.3.0 → weco-0.3.2}/weco.egg-info/SOURCES.txt +6 -0
  21. {weco-0.3.0 → weco-0.3.2}/weco.egg-info/requires.txt +1 -0
  22. weco-0.3.0/examples/triton/evaluate.py +0 -143
  23. weco-0.3.0/examples/triton/optimize.py +0 -44
  24. {weco-0.3.0 → weco-0.3.2}/.github/workflows/lint.yml +0 -0
  25. {weco-0.3.0 → weco-0.3.2}/.github/workflows/release.yml +0 -0
  26. {weco-0.3.0 → weco-0.3.2}/LICENSE +0 -0
  27. {weco-0.3.0 → weco-0.3.2}/README.md +0 -0
  28. {weco-0.3.0 → weco-0.3.2}/assets/example-optimization.gif +0 -0
  29. {weco-0.3.0 → weco-0.3.2}/assets/weco.svg +0 -0
  30. {weco-0.3.0 → weco-0.3.2}/contributing.md +0 -0
  31. {weco-0.3.0 → weco-0.3.2}/examples/README.md +0 -0
  32. {weco-0.3.0 → weco-0.3.2}/examples/cuda/optimize.py +0 -0
  33. {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/README.md +0 -0
  34. {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/colab_notebook_walkthrough.ipynb +0 -0
  35. {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/evaluate.py +0 -0
  36. {weco-0.3.0 → weco-0.3.2}/examples/hello-kernel-world/optimize.py +0 -0
  37. {weco-0.3.0 → weco-0.3.2}/examples/prompt/README.md +0 -0
  38. {weco-0.3.0 → weco-0.3.2}/examples/prompt/eval.py +0 -0
  39. {weco-0.3.0 → weco-0.3.2}/examples/prompt/optimize.py +0 -0
  40. {weco-0.3.0 → weco-0.3.2}/examples/prompt/prompt_guide.md +0 -0
  41. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/README.md +0 -0
  42. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/competition_description.md +0 -0
  43. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/data/sample_submission.csv +0 -0
  44. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/data/test.csv +0 -0
  45. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/data/train.csv +0 -0
  46. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/evaluate.py +0 -0
  47. {weco-0.3.0 → weco-0.3.2}/examples/spaceship-titanic/train.py +0 -0
  48. {weco-0.3.0 → weco-0.3.2}/setup.cfg +0 -0
  49. {weco-0.3.0 → weco-0.3.2}/weco/__init__.py +0 -0
  50. {weco-0.3.0 → weco-0.3.2}/weco/auth.py +0 -0
  51. {weco-0.3.0 → weco-0.3.2}/weco/chatbot.py +0 -0
  52. {weco-0.3.0 → weco-0.3.2}/weco/cli.py +0 -0
  53. {weco-0.3.0 → weco-0.3.2}/weco/credits.py +0 -0
  54. {weco-0.3.0 → weco-0.3.2}/weco/panels.py +0 -0
  55. {weco-0.3.0 → weco-0.3.2}/weco.egg-info/dependency_links.txt +0 -0
  56. {weco-0.3.0 → weco-0.3.2}/weco.egg-info/entry_points.txt +0 -0
  57. {weco-0.3.0 → weco-0.3.2}/weco.egg-info/top_level.txt +0 -0
@@ -84,3 +84,8 @@ repomix-output.*
 
 # Claude config
 .claude/
+
+# Example: extract-line-plot generated artifacts
+examples/extract-line-plot/predictions/
+examples/extract-line-plot/subset_line_*/
+examples/extract-line-plot/subset_line_*.zip
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: weco
-Version: 0.3.0
+Version: 0.3.2
 Summary: Documentation for `weco`, a CLI for using Weco AI's code optimizer.
 Author-email: Weco AI Team <contact@weco.ai>
 License:
@@ -219,6 +219,7 @@ Requires-Dist: packaging
 Requires-Dist: gitingest
 Requires-Dist: fastapi
 Requires-Dist: slowapi
+Requires-Dist: psutil
 Provides-Extra: dev
 Requires-Dist: ruff; extra == "dev"
 Requires-Dist: build; extra == "dev"
@@ -7,9 +7,11 @@ This approach aims for low-level optimization beyond standard PyTorch or even Tr
 
 Install the CLI and dependencies for the example:
 ```bash
-pip install weco torch ninja triton
+pip install weco ninja numpy torch triton
 ```
-> **Note:** This example requires a compatible NVIDIA GPU and the CUDA Toolkit installed on your system for compiling and running the generated CUDA code.
+> **Note:**
+> 1. This example requires a compatible NVIDIA GPU and the CUDA Toolkit installed on your system for compiling and running the generated CUDA code.
+> 2. If compatible, install [flash attention](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) (`pip install flash-attn --no-build-isolation`).
 
 ## Run Weco
 
@@ -20,8 +22,9 @@ weco run --source optimize.py \
 --metric speedup \
 --goal maximize \
 --steps 50 \
---model o4-mini \
---additional-instructions "Write in-line CUDA using pytorch's load_inline() to optimize the code while ensuring a small max float diff. Maintain the same code format. Do not use any fallbacks. Assume any required dependencies are installed and data is already on the gpu."
+--model gpt-5 \
+--additional-instructions "Write in-line CUDA using pytorch's load_inline() to optimize the code while ensuring a small max float diff. Maintain the same code interface. Do not use any fallbacks and never use the build_directory arg for load_inline(). Assume any required dependencies are installed and data is already on the gpu." \
+--eval-timeout 600
 ```
 
 ### Explanation
@@ -31,8 +34,9 @@ weco run --source optimize.py \
 * `--metric speedup`: The optimization target metric.
 * `--goal maximize`: Weco aims to increase the speedup.
 * `--steps 50`: The number of optimization iterations.
-* `--model o4-mini`: The LLM used for code generation.
+* `--model gpt-5`: The LLM used for code generation.
 * `--additional-instructions "..."`: Provides guidance to the LLM on the optimization approach.
+* `--eval-timeout 600`: Stop running the evaluation script if it does not complete within 600 seconds.
 
 Weco will iteratively modify `optimize.py`, generating and integrating CUDA C++ code, guided by the evaluation results and the additional instructions provided.
 
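The pattern these instructions request is PyTorch's `load_inline()` JIT-extension workflow. As a hedged illustration only (this elementwise kernel is a placeholder, not code generated by Weco or shipped in the package), a minimal inline-CUDA solution looks like the sketch below; note that `evaluate.py` (diff below) redirects builds to a throwaway cache via `TORCH_EXTENSIONS_DIR`:

```python
# Hypothetical sketch of the load_inline() pattern the instructions request.
# The scale_add kernel is a placeholder, not code from the weco package.
import torch
from torch.utils.cpp_extension import load_inline

cuda_src = r"""
#include <torch/extension.h>

__global__ void scale_add_kernel(const float* x, float* y, float a, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = a * x[i] + 1.0f;
}

torch::Tensor scale_add(torch::Tensor x, double a) {
    auto y = torch::empty_like(x);
    const int n = x.numel();
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    scale_add_kernel<<<blocks, threads>>>(
        x.data_ptr<float>(), y.data_ptr<float>(), static_cast<float>(a), n);
    return y;
}
"""

# load_inline() compiles with ninja on first use and caches the build under
# TORCH_EXTENSIONS_DIR (which evaluate.py points at a local throwaway dir).
ext = load_inline(
    name="scale_add_ext",
    cpp_sources="torch::Tensor scale_add(torch::Tensor x, double a);",
    cuda_sources=cuda_src,
    functions=["scale_add"],
)

x = torch.randn(1024, device="cuda")
y = ext.scale_add(x, 2.0)  # y == 2*x + 1, computed by the custom kernel
```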
@@ -1,5 +1,6 @@
 import sys
 import os
+import shutil
 import pathlib
 import importlib
 import importlib.util
@@ -55,10 +56,9 @@ class Model(nn.Module):
 
 
 ########################################################
-# Weco Solution
+# Benchmark
 ########################################################
 def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
-    # Clean out all old compiled extensions to prevent namespace collisions during build
     module_path = pathlib.Path(module_path)
     name = module_path.stem
     spec = importlib.util.spec_from_file_location(name, module_path)
@@ -69,12 +69,6 @@ def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
     return mod
 
 
-########################################################
-# Benchmark
-########################################################
-os.environ["MAX_JOBS"] = "1"  # number of workers for building with ninja
-
-
 def get_inputs(batch_size, seq_len, n_embd, device):
     return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
 
@@ -86,6 +80,12 @@ if __name__ == "__main__":
     parser.add_argument("--solution-path", type=str, required=True)
     args = parser.parse_args()
 
+    # setup local cache for PyTorch extensions
+    cache_dir = pathlib.Path.cwd() / ".weco-temp/torch_extensions"
+    shutil.rmtree(cache_dir.parent, ignore_errors=True)
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    os.environ["TORCH_EXTENSIONS_DIR"] = str(cache_dir)
+
     # benchmarking parameters
     n_correctness_trials = 10
     correctness_tolerance = 1e-5
@@ -145,3 +145,6 @@ if __name__ == "__main__":
     t_avg_optimized = do_bench(lambda: solution_model(inputs), warmup=warmup_ms, rep=rep_ms)
     print(f"optimized time: {t_avg_optimized:.2f}ms")
     print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
+
+    # clean up
+    shutil.rmtree(cache_dir.parent, ignore_errors=True)
@@ -0,0 +1,72 @@
+## Extract Line Plot (Chart → CSV) with a VLM
+
+This example optimizes an AI feature that turns an image of a chart into a table in CSV format.
+
+### Prerequisites
+
+- Python 3.9+
+- `uv` installed (see `https://docs.astral.sh/uv/`)
+- An OpenAI API key in your environment:
+
+```bash
+export OPENAI_API_KEY=your_key_here
+```
+
+### Files
+
+- `prepare_data.py`: downloads ChartQA (full) and prepares a 100-sample subset of line charts.
+- `optimize.py`: baseline VLM function (`VLMExtractor.image_to_csv`) to be optimized.
+- `eval.py`: evaluation harness that runs the baseline on images and reports a similarity score as "accuracy".
+
+Generated artifacts (gitignored):
+- `subset_line_100/` and `subset_line_100.zip`
+- `predictions/`
+
+### 1) Prepare the data
+
+From the repo root or this directory:
+
+```bash
+cd examples/extract-line-plot
+uv run --with huggingface_hub python prepare_data.py
+```
+
+Notes:
+- Downloads the ChartQA dataset snapshot and auto-extracts `ChartQA Dataset.zip` if needed.
+- Produces `subset_line_100/` with `index.csv`, `images/`, and `tables/`.
+
+### 2) Run a baseline evaluation once
+
+```bash
+uv run --with openai python eval.py --max-samples 10 --num-workers 4
+```
+
+This writes predicted CSVs to `predictions/` and prints a final line like `accuracy: 0.32`.
+
+Metric definition (summarized):
+- Per-sample score = 0.2 × header match + 0.8 × mean per-row numeric similarity (SMAPE-based; see `eval.py`).
+- Reported `accuracy` is the mean score over all evaluated samples.
+
+### 3) Optimize the baseline with Weco
+
+Run Weco to iteratively improve `optimize.py` using 100 examples and many workers:
+
+```bash
+weco run --source optimize.py --eval-command 'uv run --with openai python eval.py --max-samples 100 --num-workers 50' --metric accuracy --goal maximize --steps 20 --model gpt-5
+```
+
+Arguments:
+- `--source optimize.py`: file that Weco will edit to improve results.
+- `--eval-command '…'`: command Weco executes to measure the metric.
+- `--metric accuracy`: Weco parses `accuracy: <value>` from `eval.py` output.
+- `--goal maximize`: higher is better.
+- `--steps 20`: number of optimization iterations.
+- `--model gpt-5`: model used by Weco to propose edits (change as desired).
+
+### Tips
+
+- Ensure your OpenAI key has access to a vision-capable model (default: `gpt-4o-mini` in the eval; change via `--model`).
+- Adjust `--num-workers` to balance throughput and rate limits.
+- You can tweak baseline behavior in `optimize.py` (prompt, temperature); Weco will explore modifications automatically during optimization.
+
+
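`optimize.py` for this example is not reproduced in this diff, but `eval.py` (next file) depends on its interface: a `VLMExtractor` exposing `model`, `image_to_csv(path)`, and the `total_cost_usd` / `num_queries` counters behind the cost cap. A hypothetical sketch of that interface, assuming the OpenAI chat completions API and illustrative per-token prices:

```python
# Hypothetical sketch of the interface eval.py relies on; not the packaged
# optimize.py. Assumes the openai SDK (>=1.0) and OPENAI_API_KEY are available.
import base64
from pathlib import Path

from openai import OpenAI


class VLMExtractor:
    def __init__(self, model: str = "gpt-4o-mini") -> None:
        self.client = OpenAI()  # reads OPENAI_API_KEY from the environment
        self.model = model
        self.total_cost_usd = 0.0  # running cost, used by eval.py's cost cap
        self.num_queries = 0

    def image_to_csv(self, image_path: Path) -> str:
        b64 = base64.b64encode(Path(image_path).read_bytes()).decode()
        resp = self.client.chat.completions.create(
            model=self.model,
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Extract this line chart as CSV. Output only the CSV."},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                ],
            }],
        )
        self.num_queries += 1
        # Illustrative per-million-token rates; the real implementation must
        # track actual pricing for the $0.02/query cap to be meaningful.
        usage = resp.usage
        self.total_cost_usd += (usage.prompt_tokens * 0.15 + usage.completion_tokens * 0.60) / 1e6
        return resp.choices[0].message.content or ""
```

The rates above are assumptions for illustration only; correct cost tracking is an explicit constraint in `guide.md`.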
@@ -0,0 +1,370 @@
+import argparse
+import csv
+import math
+import os
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from optimize import VLMExtractor
+
+try:
+    import matplotlib
+
+    matplotlib.use("Agg")
+    import matplotlib.pyplot as plt
+except Exception:  # pragma: no cover - optional dependency
+    plt = None
+
+
+def read_index(index_csv_path: Path) -> List[Tuple[str, Path, Path]]:
+    rows: List[Tuple[str, Path, Path]] = []
+    with open(index_csv_path, "r", encoding="utf-8") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            rows.append((row["id"].strip(), Path(row["image"].strip()), Path(row["table"].strip())))
+    return rows
+
+
+def write_csv(output_dir: Path, example_id: str, csv_text: str) -> Path:
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out_path = output_dir / f"{example_id}.csv"
+    out_path.write_text(csv_text, encoding="utf-8")
+    return out_path
+
+
+def read_csv_table(path: Path) -> Tuple[List[str], List[List[str]]]:
+    header: List[str] = []
+    rows: List[List[str]] = []
+    with open(path, "r", encoding="utf-8") as f:
+        reader = csv.reader(f)
+        for row in reader:
+            cleaned = [cell.strip() for cell in row]
+            if not any(cleaned):
+                continue
+            if not header:
+                header = cleaned
+            else:
+                rows.append(cleaned)
+    return header, rows
+
+
+def header_match_score(gt_header: List[str], pred_header: List[str]) -> float:
+    if not gt_header or not pred_header:
+        return 0.0
+    normalized_gt = [cell.strip().lower() for cell in gt_header]
+    normalized_pred = [cell.strip().lower() for cell in pred_header]
+    return 1.0 if normalized_gt == normalized_pred else 0.0
+
+
+def _to_float(cell: str) -> Optional[float]:
+    cell = cell.strip()
+    if not cell:
+        return None
+    normalized = cell.replace(",", "")
+    if normalized.endswith("%"):
+        normalized = normalized[:-1]
+    if normalized.startswith("$"):
+        normalized = normalized[1:]
+    try:
+        return float(normalized)
+    except ValueError:
+        return None
+
+
+def _build_row_map(rows: List[List[str]]) -> Dict[str, List[List[str]]]:
+    mapping: Dict[str, List[List[str]]] = {}
+    for row in rows:
+        if not row:
+            continue
+        key = row[0].strip().lower()
+        mapping.setdefault(key, []).append(row[1:])
+    return mapping
+
+
+def _score_row(gt_values: List[str], pred_values: List[str]) -> float:
+    if not gt_values:
+        return 1.0
+
+    per_column_scores: List[float] = []
+    for idx, gt_cell in enumerate(gt_values):
+        pred_cell = pred_values[idx] if idx < len(pred_values) else ""
+        gt_num = _to_float(gt_cell)
+        pred_num = _to_float(pred_cell)
+        if gt_num is None or pred_num is None:
+            per_column_scores.append(0.0)
+            continue
+        denom = abs(gt_num) + abs(pred_num)
+        if denom == 0:
+            per_column_scores.append(1.0)
+            continue
+        smape = 2.0 * abs(pred_num - gt_num) / denom
+        per_column_scores.append(max(0.0, 1.0 - smape / 2.0))
+
+    if not per_column_scores:
+        return 0.0
+    return sum(per_column_scores) / len(per_column_scores)
+
+
+def visualize_difference(
+    gt_csv_path: Path,
+    pred_csv_path: Path,
+    example_id: str,
+    output_dir: Path,
+    ignore_header_mismatch: bool = False,
+    verbose: bool = False,
+) -> Optional[Path]:
+    def _viz_skip(reason: str) -> None:
+        if verbose:
+            print(f"[viz] skip {example_id}: {reason}", file=sys.stderr)
+
+    if plt is None:
+        _viz_skip("matplotlib not available")
+        return None
+
+    try:
+        gt_header, gt_rows = read_csv_table(gt_csv_path)
+        pred_header, pred_rows = read_csv_table(pred_csv_path)
+    except Exception:
+        _viz_skip("failed to read CSVs")
+        return None
+
+    if not gt_header or not pred_header:
+        _viz_skip("missing headers")
+        return None
+
+    if header_match_score(gt_header, pred_header) < 1.0 and not ignore_header_mismatch:
+        _viz_skip("header mismatch (use --visualize-allow-header-mismatch to override)")
+        return None
+
+    if len(gt_header) <= 1:
+        _viz_skip("no data columns in GT header")
+        return None
+
+    columns = gt_header[1:]
+    gt_series: Dict[str, List[float]] = {col: [] for col in columns}
+    pred_series: Dict[str, List[float]] = {col: [] for col in columns}
+    diff_series: Dict[str, List[float]] = {col: [] for col in columns}
+    x_labels: List[str] = []
+
+    pred_map = _build_row_map(pred_rows)
+    pred_consumed: Dict[str, int] = {}
+
+    for gt_row in gt_rows:
+        if not gt_row:
+            continue
+        x_label = gt_row[0].strip()
+        key = x_label.lower()
+        pred_entries = pred_map.get(key, [])
+        pred_idx = pred_consumed.get(key, 0)
+        pred_values = pred_entries[pred_idx] if pred_idx < len(pred_entries) else []
+        pred_consumed[key] = pred_idx + 1
+
+        x_labels.append(x_label or f"row_{len(x_labels) + 1}")
+
+        for col_idx, col_name in enumerate(columns):
+            gt_val = _to_float(gt_row[col_idx + 1]) if col_idx + 1 < len(gt_row) else None
+            pred_val = _to_float(pred_values[col_idx]) if col_idx < len(pred_values) else None
+            gt_float = gt_val if gt_val is not None else math.nan
+            pred_float = pred_val if pred_val is not None else math.nan
+
+            if math.isnan(gt_float):
+                # If GT is missing, treat as zero difference but keep nan for plotting gaps
+                diff = math.nan
+            elif math.isnan(pred_float):
+                diff = math.nan
+            else:
+                diff = pred_float - gt_float
+
+            gt_series[col_name].append(gt_float)
+            pred_series[col_name].append(pred_float)
+            diff_series[col_name].append(diff)
+
+    if not x_labels:
+        _viz_skip("no x labels in GT rows")
+        return None
+
+    num_series = len(columns)
+    if num_series == 0:
+        _viz_skip("no numeric series to plot")
+        return None
+
+    x_positions = list(range(len(x_labels)))
+    fig_height = max(3.0, 3.0 * num_series)
+    fig, axes = plt.subplots(num_series, 1, sharex=True, figsize=(11, fig_height))
+    if num_series == 1:
+        axes = [axes]
+
+    for ax, col_name in zip(axes, columns):
+        gt_values = gt_series[col_name]
+        pred_values = pred_series[col_name]
+        diff_values = diff_series[col_name]
+
+        ax.plot(x_positions, gt_values, marker="o", linewidth=1.5, label="Ground Truth")
+        ax.plot(x_positions, pred_values, marker="o", linewidth=1.5, label="Prediction")
+        ax.set_ylabel(col_name)
+        ax.grid(True, axis="y", alpha=0.3)
+
+        has_diff = any(not math.isnan(v) for v in diff_values)
+        legend_handles, legend_labels = ax.get_legend_handles_labels()
+
+        if has_diff:
+            ax2 = ax.twinx()
+            ax2.plot(x_positions, diff_values, linestyle="--", color="tab:red", marker="x", linewidth=1.2, label="Pred - GT")
+            ax2.axhline(0.0, color="tab:red", linewidth=0.8, alpha=0.4)
+            ax2.set_ylabel("Pred - GT")
+            handles2, labels2 = ax2.get_legend_handles_labels()
+            legend_handles += handles2
+            legend_labels += labels2
+
+        if legend_handles:
+            ax.legend(legend_handles, legend_labels, loc="upper left")
+
+    axes[-1].set_xticks(x_positions)
+    axes[-1].set_xticklabels(x_labels, rotation=45, ha="right")
+    axes[-1].set_xlabel(gt_header[0])
+
+    fig.suptitle(f"{example_id}: Ground Truth vs Prediction", fontsize=14)
+    fig.tight_layout(rect=(0, 0, 1, 0.96))
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    out_path = output_dir / f"{example_id}.png"
+    fig.savefig(out_path, dpi=150)
+    plt.close(fig)
+    return out_path
+
+
+def evaluate_predictions(gt_csv_path: Path, pred_csv_path: Path) -> float:
+    gt_header, gt_rows = read_csv_table(gt_csv_path)
+    pred_header, pred_rows = read_csv_table(pred_csv_path)
+    if not gt_header or not pred_header:
+        return 0.0
+
+    header_score = header_match_score(gt_header, pred_header)
+
+    gt_map = _build_row_map(gt_rows)
+    pred_map = _build_row_map(pred_rows)
+
+    row_scores: List[float] = []
+    for key, gt_entries in gt_map.items():
+        pred_entries = pred_map.get(key, [])
+        for idx, gt_values in enumerate(gt_entries):
+            pred_values = pred_entries[idx] if idx < len(pred_entries) else []
+            row_scores.append(_score_row(gt_values, pred_values))
+
+    content_score = sum(row_scores) / len(row_scores) if row_scores else 0.0
+    return 0.2 * header_score + 0.8 * content_score
+
+
+def process_one(
+    extractor: VLMExtractor, base_dir: Path, example_id: str, image_rel: Path, gt_table_rel: Path, output_dir: Path
+) -> Tuple[str, float, Path, Path]:
+    image_path = base_dir / image_rel
+    gt_csv_path = base_dir / gt_table_rel
+    pred_csv_text = extractor.image_to_csv(image_path)
+    pred_path = write_csv(output_dir, example_id, pred_csv_text)
+    score = evaluate_predictions(gt_csv_path, pred_path)
+    return example_id, score, pred_path, gt_csv_path
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Evaluate VLM extraction: image -> CSV")
+    parser.add_argument("--data-dir", type=str, default="subset_line_100")
+    parser.add_argument("--index", type=str, default="index.csv")
+    parser.add_argument("--out-dir", type=str, default="predictions")
+    parser.add_argument("--max-samples", type=int, default=100)
+    parser.add_argument("--num-workers", type=int, default=4)
+    parser.add_argument(
+        "--visualize-dir",
+        type=str,
+        default=None,
+        help="Directory where GT vs prediction plots will be saved (requires matplotlib).",
+    )
+    parser.add_argument(
+        "--visualize-max",
+        type=int,
+        default=1,
+        help="Maximum number of plots to generate when --visualize-dir is set. Use 0 for no limit.",
+    )
+    parser.add_argument(
+        "--visualize-allow-header-mismatch",
+        action="store_true",
+        help="If set, still plot GT vs prediction even when headers differ.",
+    )
+    parser.add_argument("--visualize-verbose", action="store_true", help="Print reasons when a visualization is skipped.")
+    args = parser.parse_args()
+
+    if not os.getenv("OPENAI_API_KEY"):
+        print("[error] OPENAI_API_KEY not set in environment", file=sys.stderr)
+        sys.exit(1)
+
+    base_dir = Path(args.data_dir)
+    index_path = base_dir / args.index
+    if not index_path.exists():
+        print(f"[error] index.csv not found at {index_path}", file=sys.stderr)
+        sys.exit(1)
+
+    rows = read_index(index_path)[: args.max_samples]
+    extractor = VLMExtractor()
+
+    visualize_dir: Optional[Path] = Path(args.visualize_dir) if args.visualize_dir else None
+    visualize_max = max(0, args.visualize_max)
+    if visualize_dir and plt is None:
+        print("[warn] matplotlib not available; skipping visualization.", file=sys.stderr)
+        visualize_dir = None
+
+    print(f"[setup] evaluating {len(rows)} samples using {extractor.model} …", flush=True)
+    start = time.time()
+    scores: List[float] = []
+    saved_visualizations = 0
+
+    with ThreadPoolExecutor(max_workers=max(1, args.num_workers)) as pool:
+        futures = [
+            pool.submit(process_one, extractor, base_dir, example_id, image_rel, gt_table_rel, Path(args.out_dir))
+            for (example_id, image_rel, gt_table_rel) in rows
+        ]
+
+        try:
+            for idx, fut in enumerate(as_completed(futures), 1):
+                try:
+                    example_id, score, pred_path, gt_csv_path = fut.result()
+                    scores.append(score)
+                    if visualize_dir and (visualize_max == 0 or saved_visualizations < visualize_max):
+                        out_path = visualize_difference(
+                            gt_csv_path,
+                            pred_path,
+                            example_id,
+                            visualize_dir,
+                            ignore_header_mismatch=args.visualize_allow_header_mismatch,
+                            verbose=args.visualize_verbose,
+                        )
+                        if out_path is not None:
+                            saved_visualizations += 1
+                            print(f"[viz] saved {out_path}", flush=True)
+                    if idx % 5 == 0 or idx == len(rows):
+                        elapsed = time.time() - start
+                        avg = sum(scores) / len(scores) if scores else 0.0
+                        print(f"[progress] {idx}/{len(rows)} done, avg score: {avg:.4f}, elapsed {elapsed:.1f}s", flush=True)
+                except Exception as e:
+                    print(f"[error] failed on sample {idx}: {e}", file=sys.stderr)
+        except KeyboardInterrupt:
+            print("\n[warn] interrupted by user", file=sys.stderr)
+            sys.exit(1)
+
+    final_score = sum(scores) / len(scores) if scores else 0.0
+
+    # Apply cost cap: accuracy is zeroed if average cost/query exceeds $0.02
+    avg_cost_per_query = (extractor.total_cost_usd / extractor.num_queries) if getattr(extractor, "num_queries", 0) else 0.0
+    if avg_cost_per_query > 0.02:
+        print(f"[cost] avg ${avg_cost_per_query:.4f}/query exceeds $0.02 cap; accuracy set to 0.0", flush=True)
+        final_score = 0.0
+    else:
+        print(f"[cost] avg ${avg_cost_per_query:.4f}/query within cap", flush=True)
+
+    print(f"accuracy: {final_score:.4f}")
+
+
+if __name__ == "__main__":
+    main()
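A quick worked example (not part of the package) showing how `evaluate_predictions` above combines its scores: the header match is binary with weight 0.2, and each numeric cell contributes `max(0, 1 - SMAPE/2)` to the 0.8-weighted content score.

```python
# Toy check of the metric in eval.py; importing eval also imports optimize,
# so the openai package must be installed (and OPENAI_API_KEY may need to be set).
import tempfile
from pathlib import Path

from eval import evaluate_predictions

with tempfile.TemporaryDirectory() as d:
    gt = Path(d) / "gt.csv"
    pred = Path(d) / "pred.csv"
    gt.write_text("Year,Sales\n2020,100\n2021,200\n", encoding="utf-8")
    pred.write_text("Year,Sales\n2020,110\n2021,200\n", encoding="utf-8")

    # header matches exactly              -> header_score = 1.0
    # 2020 row: smape = 20/210 ~ 0.0952   -> cell score ~ 0.9524
    # 2021 row: exact match               -> cell score = 1.0
    # total = 0.2*1.0 + 0.8*(0.9524 + 1.0)/2 ~ 0.981
    print(f"accuracy: {evaluate_predictions(gt, pred):.4f}")
```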
@@ -0,0 +1,53 @@
+# Constraints
+- Make sure the cost tracking is correct
+- The average cost per query should stay below $0.02; the baseline is around $0.007 per query
+
+# Here's a list of ideas to be tried:
+1. Try more recent models like gpt-5; note that gpt-5 does not have a temperature argument
+2. Try the Responses API as below:
+
+<api>
+Responses API (concise)
+
+Endpoint:
+- POST https://api.openai.com/v1/responses
+
+Purpose:
+- Generate text or JSON from text and/or image inputs. Supports optional tools.
+
+Minimum request fields:
+- model: string (e.g., gpt-5, gpt-4o)
+- input: string | array (text and/or image items)
+
+Useful options:
+- max_output_tokens: integer
+- temperature or top_p: sampling controls (non-reasoning models)
+- reasoning: object (controls for reasoning models such as gpt-5 and the o-series)
+- tools: array (web_search, file_search, function calls)
+- tool_choice: string | object
+- stream: boolean (default false)
+- store: boolean (default true)
+- parallel_tool_calls: boolean (default true)
+
+Conversation (optional):
+- conversation: string | object, or previous_response_id: string
+
+Notes:
+- gpt-5 does not support temperature; prefer reasoning options instead.
+
+Example request:
+
+```json
+{
+  "model": "gpt-5",
+  "input": "Explain how to extract a line plot from an image.",
+  "max_output_tokens": 400,
+  "reasoning": { "effort": "medium" },
+  "stream": false
+}
+```
+</api>
+
+3. Try playing with parameters like the reasoning effort
+4. Try building tools for the Responses API to use, or allow a Python interpreter
+5. Try adding preprocessing or postprocessing pipelines
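To make idea 2 concrete, here is a minimal sketch (not from the package) of a Responses API call with an image input through the official `openai` Python SDK; the model, prompt, image path, and effort level are placeholders to adapt inside `optimize.py`:

```python
# Hypothetical Responses API call with an image input (idea 2 in the guide).
import base64
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Encode a local chart image as a data URL (path is a placeholder).
b64 = base64.b64encode(open("chart.png", "rb").read()).decode()

resp = client.responses.create(
    model="gpt-5",
    input=[{
        "role": "user",
        "content": [
            {"type": "input_text", "text": "Extract this line chart as CSV. Output only the CSV."},
            {"type": "input_image", "image_url": f"data:image/png;base64,{b64}"},
        ],
    }],
    max_output_tokens=400,
    reasoning={"effort": "medium"},  # gpt-5 takes reasoning controls instead of temperature
)
print(resp.output_text)  # SDK convenience accessor for the generated text
```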