cudag-0.3.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudag/__init__.py +334 -0
- cudag/annotation/__init__.py +77 -0
- cudag/annotation/codegen.py +648 -0
- cudag/annotation/config.py +545 -0
- cudag/annotation/loader.py +342 -0
- cudag/annotation/scaffold.py +121 -0
- cudag/annotation/transcription.py +296 -0
- cudag/cli/__init__.py +5 -0
- cudag/cli/main.py +315 -0
- cudag/cli/new.py +873 -0
- cudag/core/__init__.py +364 -0
- cudag/core/button.py +137 -0
- cudag/core/canvas.py +222 -0
- cudag/core/config.py +70 -0
- cudag/core/coords.py +233 -0
- cudag/core/data_grid.py +804 -0
- cudag/core/dataset.py +678 -0
- cudag/core/distribution.py +136 -0
- cudag/core/drawing.py +75 -0
- cudag/core/fonts.py +156 -0
- cudag/core/generator.py +163 -0
- cudag/core/grid.py +367 -0
- cudag/core/grounding_task.py +247 -0
- cudag/core/icon.py +207 -0
- cudag/core/iconlist_task.py +301 -0
- cudag/core/models.py +1251 -0
- cudag/core/random.py +130 -0
- cudag/core/renderer.py +190 -0
- cudag/core/screen.py +402 -0
- cudag/core/scroll_task.py +254 -0
- cudag/core/scrollable_grid.py +447 -0
- cudag/core/state.py +110 -0
- cudag/core/task.py +293 -0
- cudag/core/taskbar.py +350 -0
- cudag/core/text.py +212 -0
- cudag/core/utils.py +82 -0
- cudag/data/surnames.txt +5000 -0
- cudag/modal_apps/__init__.py +4 -0
- cudag/modal_apps/archive.py +103 -0
- cudag/modal_apps/extract.py +138 -0
- cudag/modal_apps/preprocess.py +529 -0
- cudag/modal_apps/upload.py +317 -0
- cudag/prompts/SYSTEM_PROMPT.txt +104 -0
- cudag/prompts/__init__.py +33 -0
- cudag/prompts/system.py +43 -0
- cudag/prompts/tools.py +382 -0
- cudag/py.typed +0 -0
- cudag/schemas/filesystem.json +90 -0
- cudag/schemas/test_record.schema.json +113 -0
- cudag/schemas/train_record.schema.json +90 -0
- cudag/server/__init__.py +21 -0
- cudag/server/app.py +232 -0
- cudag/server/services/__init__.py +9 -0
- cudag/server/services/generator.py +128 -0
- cudag/templates/scripts/archive.sh +35 -0
- cudag/templates/scripts/build.sh +13 -0
- cudag/templates/scripts/extract.sh +54 -0
- cudag/templates/scripts/generate.sh +116 -0
- cudag/templates/scripts/pre-commit.sh +44 -0
- cudag/templates/scripts/preprocess.sh +46 -0
- cudag/templates/scripts/upload.sh +63 -0
- cudag/templates/scripts/verify.py +428 -0
- cudag/validation/__init__.py +35 -0
- cudag/validation/validate.py +508 -0
- cudag-0.3.10.dist-info/METADATA +570 -0
- cudag-0.3.10.dist-info/RECORD +69 -0
- cudag-0.3.10.dist-info/WHEEL +4 -0
- cudag-0.3.10.dist-info/entry_points.txt +2 -0
- cudag-0.3.10.dist-info/licenses/LICENSE +66 -0
cudag/core/dataset.py
ADDED
@@ -0,0 +1,678 @@
# Copyright (c) 2025 Tylt LLC. All rights reserved.
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app

"""Dataset builder for orchestrating sample generation.

The DatasetBuilder coordinates Screen, State, Renderer, and Tasks
to produce JSONL training datasets.
"""

from __future__ import annotations

import json
import random
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable

from cudag.core.coords import normalize_coord
from cudag.core.task import BaseTask, TaskContext, TaskSample, TestCase
from cudag.prompts.tools import format_tool_call


@dataclass
class DatasetConfig:
    """Configuration for dataset generation."""

    name_prefix: str
    """Prefix for dataset name (e.g., "calendar-mike")."""

    seed: int = 42
    """Random seed for reproducibility."""

    task_counts: dict[str, int] = field(default_factory=dict)
    """Number of samples per task type."""

    train_split: float = 0.8
    """Fraction of data for training (rest is test/val)."""

    system_prompt: str = "computer-use"
    """System prompt style: "computer-use", "compact", or custom."""

    output_dir: Path | None = None
    """Output directory (auto-generated if None)."""

    image_format: str = "png"
    """Image format: "png" or "jpg"."""

    image_quality: int = 95
    """JPEG quality (ignored for PNG)."""

    held_out_enabled: bool = False
    """Whether to hold out samples for evaluation."""

    held_out_ratio: float = 0.1
    """Fraction of samples to hold out."""

    test_count: int = 100
    """Number of test cases to generate PER TASK TYPE."""

    test_distribution: dict[str, int] = field(default_factory=dict)
    """Per-task test counts. Overrides auto-distribution when set."""

    test_tolerance: tuple[int, int] = (10, 10)
    """Coordinate tolerance for test (x, y in RU units)."""

    annotation_ratio: float = 0.1
    """Fraction of test cases to annotate (0.0-1.0)."""

    annotation_enabled: bool = True
    """Whether to generate annotated test images."""

    annotation_per_type: dict[str, int] = field(default_factory=dict)
    """Number of annotations per task type. Overrides annotation_ratio when set."""

    task_distributions: dict[str, dict[str, float]] = field(default_factory=dict)
    """Distribution of sample types within each task type.

    Example:
        task_distributions:
          click-appointment:
            grey_grey: 0.80      # 80% grey background + grey status
            other_colors: 0.15   # 15% other color combos
            adversarial: 0.05    # 5% no match cases
          hover-appointment:
            grey_grey: 0.80
            other_colors: 0.15
            adversarial: 0.05
    """

    def __post_init__(self) -> None:
        """Set default output directory if not provided."""
        if self.output_dir is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            self.output_dir = Path("datasets") / f"{self.name_prefix}_{timestamp}"

    def get_distribution(self, task_type: str) -> dict[str, float]:
        """Get distribution for a task type.

        Returns the task-specific distribution if defined, otherwise
        returns an empty dict (task should use its own defaults).
        """
        return self.task_distributions.get(task_type, {})

    def sample_distribution_type(self, task_type: str, rng: Any) -> str | None:
        """Sample a distribution type for a task.

        Args:
            task_type: The task type to sample for.
            rng: Random number generator.

        Returns:
            The sampled distribution type name, or None if no distribution defined.
        """
        dist = self.get_distribution(task_type)
        if not dist:
            return None

        roll = rng.random()
        cumulative = 0.0
        for dist_type, prob in dist.items():
            cumulative += prob
            if roll < cumulative:
                return dist_type
        # Return last type if we somehow miss due to float precision
        return list(dist.keys())[-1] if dist else None
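    # Worked example for sample_distribution_type (illustrative, not part of
    # the package): with dist = {"grey_grey": 0.80, "other_colors": 0.15,
    # "adversarial": 0.05}, a roll of 0.87 walks cumulative sums
    # 0.80 -> 0.95; since 0.87 < 0.95, "other_colors" is returned.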
    @classmethod
    def from_yaml(cls, path: Path) -> DatasetConfig:
        """Load config from YAML file."""
        import yaml

        with open(path) as f:
            data = yaml.safe_load(f)

        return cls(
            name_prefix=data.get("name_prefix", "dataset"),
            seed=data.get("seed", 42),
            task_counts=data.get("tasks", {}),
            train_split=data.get("splits", {}).get("train", 0.8),
            system_prompt=data.get("system_prompt", "computer-use"),
            output_dir=Path(data["output_dir"]) if "output_dir" in data else None,
            image_format=data.get("output", {}).get("image_format", "png"),
            image_quality=data.get("output", {}).get("image_quality", 95),
            held_out_enabled=data.get("held_out", {}).get("enabled", False),
            held_out_ratio=data.get("held_out", {}).get("ratio", 0.1),
            test_count=data.get("test", {}).get("count", 100),
            test_distribution=data.get("test", {}).get("distribution", {}),
            test_tolerance=_parse_tolerance(data.get("test", {}).get("tolerance", [10, 10])),
            annotation_ratio=data.get("annotation", {}).get("ratio", 0.1),
            annotation_enabled=data.get("annotation", {}).get("enabled", True),
            annotation_per_type=data.get("annotation", {}).get("per_type", {}),
            task_distributions=data.get("task_distributions", {}),
        )


def _parse_tolerance(value: int | list[int]) -> tuple[int, int]:
    """Parse tolerance from config - handles both int and [x, y] formats."""
    if isinstance(value, int):
        return (value, value)
    return tuple(value)  # type: ignore[return-value]
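# --- Illustrative sketch (not part of the package): a minimal config file
# using the YAML keys that from_yaml() reads above. The task name and all
# values are assumptions. ---
_EXAMPLE_CONFIG = """\
name_prefix: calendar-mike
seed: 42
tasks:
  click-day: 1000
splits:
  train: 0.8
output:
  image_format: png
  image_quality: 95
held_out:
  enabled: false
test:
  count: 100
  tolerance: [10, 10]
annotation:
  enabled: true
  ratio: 0.1
task_distributions:
  click-day:
    grey_grey: 0.80
    other_colors: 0.15
    adversarial: 0.05
"""
# Path("dataset.yaml").write_text(_EXAMPLE_CONFIG)
# config = DatasetConfig.from_yaml(Path("dataset.yaml"))
# --- end sketch ---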
def _wrap_text(text: str, font: Any, max_width: int, draw: Any) -> list[str]:
    """Wrap text to fit within max_width pixels.

    Args:
        text: The text to wrap.
        font: PIL ImageFont to use for measuring.
        max_width: Maximum width in pixels.
        draw: PIL ImageDraw for measuring text.

    Returns:
        List of wrapped lines.
    """
    words = text.split()
    lines = []
    current_line: list[str] = []

    for word in words:
        test_line = " ".join(current_line + [word])
        bbox = draw.textbbox((0, 0), test_line, font=font)
        width = bbox[2] - bbox[0]

        if width <= max_width:
            current_line.append(word)
        else:
            if current_line:
                lines.append(" ".join(current_line))
            current_line = [word]

    if current_line:
        lines.append(" ".join(current_line))

    return lines if lines else [""]
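# --- Illustrative sketch (not part of the package): greedy wrapping with
# Pillow's built-in default font; the exact split depends on font metrics. ---
# from PIL import Image, ImageDraw, ImageFont
# _draw = ImageDraw.Draw(Image.new("RGB", (1, 1)))
# _wrap_text("wrap this prompt into narrow lines", ImageFont.load_default(), 60, _draw)
# => e.g. ["wrap this", "prompt into", "narrow lines"]
# --- end sketch ---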
def annotate_test_image(
    image_path: Path,
    tool_calls: list[dict[str, Any]],
    pixel_coords: tuple[int, int],
    prompt: str,
    output_path: Path | None = None,
    bbox_pixels: tuple[int, int, int, int] | None = None,
) -> Path:
    """Annotate a test image with tool call output and prompt.

    Draws:
    - Red crosshair at the click location (for click/hover tasks)
    - Red bounding box rectangle (for grounding tasks with bbox_pixels)
    - Extends canvas with white bar at bottom for prompt and <tool_call> output

    Args:
        image_path: Path to the original test image.
        tool_calls: List of tool call dicts (from ToolCall.to_dict()).
        pixel_coords: The (x, y) pixel coordinates for crosshair.
        prompt: The user prompt text to display.
        output_path: Where to save the annotated image. If None, saves
            to same directory with "_annotated" suffix.
        bbox_pixels: Optional bounding box as (x, y, width, height) in pixels.
            If provided, draws a rectangle instead of crosshair.

    Returns:
        Path to the annotated image.
    """
    from PIL import Image, ImageDraw, ImageFont

    # Load original image
    original = Image.open(image_path).convert("RGB")
    orig_width, orig_height = original.size

    # Try to load a monospace font for JSON, fall back to default
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Menlo.ttc", 11)
    except (OSError, IOError):
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 11)
        except (OSError, IOError):
            font = ImageFont.load_default()

    # Create temporary draw to measure text for wrapping
    temp_img = Image.new("RGB", (1, 1))
    temp_draw = ImageDraw.Draw(temp_img)

    # Wrap the prompt text to fit image width (with margins)
    margin = 10
    max_text_width = orig_width - margin * 2
    prompt_text = f"Prompt: {prompt}"
    wrapped_prompt = _wrap_text(prompt_text, font, max_text_width, temp_draw)

    # Format tool calls as prettified JSON with <tool_call> tags
    tool_call_lines: list[str] = []
    for tc in tool_calls:
        tool_call_lines.append("<tool_call>")
        # Pretty print JSON with 2-space indent
        pretty_json = json.dumps(tc, indent=2)
        for json_line in pretty_json.split("\n"):
            tool_call_lines.append(json_line)
        tool_call_lines.append("</tool_call>")

    # Calculate bar height based on number of lines
    line_height = 14
    total_lines = len(wrapped_prompt) + len(tool_call_lines) + 1  # +1 for spacing
    bar_height = (total_lines * line_height) + 12  # +12 for padding

    # Create new canvas with extra height for prompt and tool call output
    new_height = orig_height + bar_height
    img = Image.new("RGB", (orig_width, new_height), (255, 255, 255))

    # Paste original image at top
    img.paste(original, (0, 0))

    draw = ImageDraw.Draw(img)
    annotation_color = (255, 0, 0)  # Red

    if bbox_pixels is not None:
        # Draw bounding box rectangle for grounding tasks
        bx, by, bw, bh = bbox_pixels
        draw.rectangle(
            [(bx, by), (bx + bw, by + bh)],
            outline=annotation_color,
            width=3,
        )
    else:
        # Draw crosshair at click location
        x, y = pixel_coords
        crosshair_size = 10
        # Horizontal line
        draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=annotation_color, width=2)
        # Vertical line
        draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=annotation_color, width=2)
        # Circle around crosshair
        draw.ellipse(
            [(x - crosshair_size, y - crosshair_size), (x + crosshair_size, y + crosshair_size)],
            outline=annotation_color,
            width=2,
        )

    # Draw wrapped prompt text in the extended area below the original image
    current_y = orig_height + 4
    for line in wrapped_prompt:
        draw.text((5, current_y), line, fill=(0, 0, 0), font=font)
        current_y += line_height

    # Add spacing
    current_y += 4

    # Draw tool call output (model response)
    for line in tool_call_lines:
        # Color the XML tags differently
        if line.startswith("<tool_call>") or line.startswith("</tool_call>"):
            draw.text((5, current_y), line, fill=(128, 0, 128), font=font)  # Purple for tags
        else:
            draw.text((5, current_y), line, fill=(0, 100, 0), font=font)  # Dark green for JSON
        current_y += line_height

    # Determine output path
    if output_path is None:
        stem = image_path.stem
        output_path = image_path.parent / f"{stem}_annotated{image_path.suffix}"

    img.save(output_path)
    return output_path
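# --- Illustrative sketch (not part of the package): annotating a synthetic
# screenshot. The tool-call name and argument shape are assumptions. ---
# from PIL import Image
# Image.new("RGB", (800, 600), "white").save("screen.png")
# annotate_test_image(
#     image_path=Path("screen.png"),
#     tool_calls=[{"name": "computer_use",
#                  "arguments": {"action": "left_click", "coordinate": [400, 300]}}],
#     pixel_coords=(400, 300),
#     prompt="Click the OK button",
# )  # -> Path("screen_annotated.png"), crosshair at (400, 300)
# --- end sketch ---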
class DatasetBuilder:
    """Orchestrates dataset generation from tasks.

    Example:
        builder = DatasetBuilder(
            config=DatasetConfig(name_prefix="calendar", task_counts={"click-day": 1000}),
            tasks=[ClickDayTask(config, renderer)],
        )
        builder.build()
    """

    def __init__(
        self,
        config: DatasetConfig,
        tasks: list[BaseTask],
    ) -> None:
        """Initialize the builder.

        Args:
            config: Dataset configuration
            tasks: List of task instances to generate from
        """
        self.config = config
        self.tasks = {t.task_type: t for t in tasks}
        self.rng = random.Random(config.seed)

    def build(
        self,
        start_index: int = 0,
        checkpoint_callback: Callable[[int], None] | None = None,
        checkpoint_interval: int = 1000,
    ) -> Path:
        """Generate the complete dataset.

        Args:
            start_index: Skip samples up to this index (for resume after preemption)
            checkpoint_callback: Called with sample count every checkpoint_interval samples
            checkpoint_interval: How often to call checkpoint_callback (default 1000)

        Returns:
            Path to the output directory
        """
        output_dir = self.config.output_dir
        assert output_dir is not None

        # Create directories
        output_dir.mkdir(parents=True, exist_ok=True)
        (output_dir / "images").mkdir(exist_ok=True)

        # Generate samples
        samples: list[dict[str, Any]] = []
        held_out: list[dict[str, Any]] = []
        index = 0
        samples_generated = 0
        last_checkpoint = 0

        for task_type, count in self.config.task_counts.items():
            if task_type not in self.tasks:
                raise ValueError(f"Unknown task type: {task_type}")

            task = self.tasks[task_type]
            for _ in range(count):
                # Skip samples until we reach start_index
                if index < start_index:
                    index += 1
                    continue

                ctx = TaskContext(
                    rng=self.rng,
                    index=index,
                    output_dir=output_dir,
                    config=task.config,
                    dataset_name=self.config.name_prefix,
                )

                # Use generate_samples() for 1:N image-to-samples pattern
                # A single render can produce multiple training samples
                task_samples = task.generate_samples(ctx)
                for sample in task_samples:
                    record = self._to_record(sample)

                    # Decide if this should be held out
                    if self.config.held_out_enabled and self.rng.random() < self.config.held_out_ratio:
                        held_out.append(record)
                    else:
                        samples.append(record)

                index += 1
                samples_generated += 1

                # Checkpoint callback
                if checkpoint_callback and samples_generated - last_checkpoint >= checkpoint_interval:
                    checkpoint_callback(samples_generated)
                    last_checkpoint = samples_generated

        # Write data files
        self._write_jsonl(output_dir / "data.jsonl", samples + held_out)
        self._write_splits(output_dir, samples)

        if held_out:
            self._write_jsonl(output_dir / "held_out.jsonl", held_out)

        # Write config for reference
        self._write_config(output_dir)

        print(f"Generated {len(samples)} training samples, {len(held_out)} held out")
        print(f"Output: {output_dir}")

        return output_dir
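    # Illustrative sketch (not part of the package): resuming a preempted run
    # via the start_index / checkpoint hooks above. The file name is an
    # assumption.
    #
    #   ckpt = Path("checkpoint.txt")
    #   resume_from = int(ckpt.read_text()) if ckpt.exists() else 0
    #
    #   def save_checkpoint(n: int) -> None:
    #       # n counts renders completed after start_index; persist the absolute index
    #       ckpt.write_text(str(resume_from + n))
    #
    #   builder.build(start_index=resume_from, checkpoint_callback=save_checkpoint)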
    def _to_record(self, sample: TaskSample) -> dict[str, Any]:
        """Convert TaskSample to JSONL record."""
        # Get normalized coordinates
        norm_coord = normalize_coord(sample.pixel_coords, sample.image_size)

        # Check if sample has multiple tool_calls in metadata
        if "tool_calls" in sample.metadata and len(sample.metadata["tool_calls"]) > 1:
            # Format all tool calls for multi-action samples
            gpt_parts = []
            for tc in sample.metadata["tool_calls"]:
                gpt_parts.append(format_tool_call(tc))
            gpt_value = "\n".join(gpt_parts)
        else:
            # Single tool call - update with normalized coordinates
            tool_call = sample.tool_call.to_dict()
            if "coordinate" in tool_call["arguments"]:
                tool_call["arguments"]["coordinate"] = list(norm_coord)
            gpt_value = format_tool_call(tool_call)

        # Build relative image path
        assert self.config.output_dir is not None
        image_rel = str(sample.image_path.relative_to(self.config.output_dir))

        return {
            "id": sample.id,
            "image": image_rel,
            "conversations": [
                {"from": "human", "value": f"<image>\n{sample.human_prompt}"},
                {"from": "gpt", "value": gpt_value},
            ],
            "metadata": {
                "task_type": sample.metadata.get("task_type", "unknown"),
                "real_coords": list(sample.pixel_coords),
                **{k: v for k, v in sample.metadata.items() if k != "task_type"},
            },
        }
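    # Illustrative sketch (not part of the package): shape of one emitted
    # record, with assumed id/prompt values and the tool-call JSON elided.
    #
    #   {
    #     "id": "click-day-000042",
    #     "image": "images/click-day-000042.png",
    #     "conversations": [
    #       {"from": "human", "value": "<image>\nClick on March 14"},
    #       {"from": "gpt", "value": "<tool_call>...</tool_call>"}
    #     ],
    #     "metadata": {"task_type": "click-day", "real_coords": [412, 233]}
    #   }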
    def _write_jsonl(self, path: Path, records: list[dict[str, Any]]) -> None:
        """Write records to JSONL file."""
        with open(path, "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record) + "\n")

    def _write_splits(self, output_dir: Path, samples: list[dict[str, Any]]) -> None:
        """Split samples and write train/val files."""
        # Shuffle for splitting
        shuffled = samples.copy()
        self.rng.shuffle(shuffled)

        split_idx = int(len(shuffled) * self.config.train_split)
        train = shuffled[:split_idx]
        val = shuffled[split_idx:]

        self._write_jsonl(output_dir / "train.jsonl", train)
        self._write_jsonl(output_dir / "val.jsonl", val)

        print(f"Split: {len(train)} train, {len(val)} val")

    def _write_config(self, output_dir: Path) -> None:
        """Write generation config for reference."""
        # Extract task_types from task_counts keys
        task_types = list(self.config.task_counts.keys())

        config_data = {
            "name_prefix": self.config.name_prefix,
            "seed": self.config.seed,
            "task_types": task_types,
            "task_counts": self.config.task_counts,
            "train_split": self.config.train_split,
            "system_prompt": self.config.system_prompt,
            "task_distributions": self.config.task_distributions,
            "generated_at": datetime.now().isoformat(),
        }
        with open(output_dir / "config.json", "w") as f:
            json.dump(config_data, f, indent=2)

    def build_tests(self) -> Path:
        """Generate test cases.

        Returns:
            Path to the test directory
        """
        output_dir = self.config.output_dir
        assert output_dir is not None

        # Create test directory structure (test/images/)
        test_dir = output_dir / "test"
        test_dir.mkdir(parents=True, exist_ok=True)
        (test_dir / "images").mkdir(exist_ok=True)

        # Create annotated directory if annotations enabled
        annotated_dir = test_dir / "annotated"
        if self.config.annotation_enabled:
            annotated_dir.mkdir(exist_ok=True)

        # Generate test cases - stop when we reach test_count
        test_cases: list[dict[str, Any]] = []
        raw_test_cases: list[TestCase] = []
        index = 0

        # Get task types to iterate through
        task_types = [t for t in self.config.task_counts.keys() if t in self.tasks]
        if not task_types:
            return test_dir

        # Generate test_count tests PER task type
        for task_type in task_types:
            if task_type not in self.tasks:
                continue
            task = self.tasks[task_type]
            generated = 0

            while generated < self.config.test_count:
                ctx = TaskContext(
                    rng=self.rng,
                    index=index,
                    output_dir=test_dir,
                    config=task.config,
                    dataset_name=self.config.name_prefix,
                )
                tests = task.generate_tests(ctx)
                for test_case in tests:
                    if generated >= self.config.test_count:
                        break
                    record = self._test_to_record(test_case, test_dir)
                    test_cases.append(record)
                    raw_test_cases.append(test_case)
                    generated += 1
                index += 1

        # ALWAYS generate 1 annotation per task type for tests (regardless of config)
        # For grounding tasks, generate 1 annotation per unique element_label
        annotated_count = 0
        annotated_dir.mkdir(exist_ok=True)

        # Group indices by task type (and element_label for grounding)
        indices_by_key: dict[str, list[int]] = {}
        for idx, test_case in enumerate(raw_test_cases):
            task_type = test_case.metadata.get("task_type", "unknown")
            # For grounding tasks, use element_label as additional key
            if task_type == "grounding":
                element_label = test_case.metadata.get("element_label", "unknown")
                key = f"grounding:{element_label}"
            else:
                key = task_type
            if key not in indices_by_key:
                indices_by_key[key] = []
            indices_by_key[key].append(idx)

        # Select 1 index per key to annotate
        indices_to_annotate: set[int] = set()
        for key, indices in indices_by_key.items():
            if indices:
                indices_to_annotate.add(indices[0])

        for idx in sorted(indices_to_annotate):
            if idx >= len(raw_test_cases):
                continue
            test_case = raw_test_cases[idx]

            # Get pixel coordinates for crosshair
            pixel_coords = test_case.pixel_coords or (0, 0)

            # Build tool_calls list from expected_action and any additional actions
            tool_calls = [test_case.expected_action]

            # Check for additional tool calls in metadata (e.g., type action for textfields)
            if "additional_tool_calls" in test_case.metadata:
                tool_calls.extend(test_case.metadata["additional_tool_calls"])

            # Generate annotated image - include task type in filename
            task_type = test_case.metadata.get("task_type", "unknown")
            # For grounding, include element_label in filename
            if task_type == "grounding":
                element_label = test_case.metadata.get("element_label", "unknown")
                annotated_path = annotated_dir / f"{task_type}_{element_label}_{test_case.test_id}_annotated.png"
            else:
                annotated_path = annotated_dir / f"{task_type}_{test_case.test_id}_annotated.png"

            # Extract bbox_pixels for grounding tasks (format: [x, y, width, height])
            bbox_pixels = None
            if "bbox_pixels" in test_case.metadata:
                bp = test_case.metadata["bbox_pixels"]
                bbox_pixels = (bp[0], bp[1], bp[2], bp[3])

            annotate_test_image(
                image_path=test_case.screenshot,
                tool_calls=tool_calls,
                pixel_coords=pixel_coords,
                prompt=test_case.prompt,
                output_path=annotated_path,
                bbox_pixels=bbox_pixels,
            )
            annotated_count += 1

        # Write test.json
        with open(test_dir / "test.json", "w", encoding="utf-8") as f:
            json.dump(test_cases, f, indent=2)

        if annotated_count > 0:
            print(f"Generated {len(test_cases)} test cases ({annotated_count} annotated)")
        else:
            print(f"Generated {len(test_cases)} test cases")

        return test_dir

    def _test_to_record(self, test_case: TestCase, test_dir: Path) -> dict[str, Any]:
        """Convert TestCase to record for test.json."""
        # Get image size from metadata if available, default to 1920x1080
        image_size = test_case.metadata.get("image_size", (1920, 1080))

        # Normalize coordinates in expected_action
        expected_action = test_case.expected_action.copy()
        if "arguments" in expected_action and "coordinate" in expected_action["arguments"]:
            pixel_coords = expected_action["arguments"]["coordinate"]
            if test_case.pixel_coords:
                pixel_coords = test_case.pixel_coords
            norm_coord = normalize_coord(tuple(pixel_coords), image_size)
            expected_action["arguments"]["coordinate"] = list(norm_coord)

        # Build relative screenshot path (relative to test_dir)
        screenshot_rel = str(test_case.screenshot.relative_to(test_dir))

        # Tolerance can come from test_case directly or from metadata
        # Convert to list for JSON serialization
        tolerance = test_case.tolerance
        if isinstance(tolerance, tuple):
            tolerance = list(tolerance)
        elif isinstance(tolerance, int):
            tolerance = [tolerance, tolerance]

        return {
            "test_id": test_case.test_id,
            "screenshot": screenshot_rel,
            "prompt": test_case.prompt,
            "expected_action": expected_action,
            "tolerance": tolerance,
            "metadata": {
                "real_coords": list(test_case.pixel_coords) if test_case.pixel_coords else None,
                **test_case.metadata,
            },
        }
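if __name__ == "__main__":
    # Illustrative end-to-end sketch (not part of the package): wiring a task
    # into the builder. ClickDayTask is the hypothetical task from the
    # DatasetBuilder docstring above; substitute any BaseTask subclass.
    # from myproject.tasks import ClickDayTask
    # task = ClickDayTask(task_config, renderer)

    config = DatasetConfig(name_prefix="calendar", task_counts={"click-day": 1000})
    # builder = DatasetBuilder(config=config, tasks=[task])
    # builder.build()        # writes data.jsonl plus train/val split under config.output_dir
    # builder.build_tests()  # writes test/test.json plus annotated sample images
    print(f"Dataset would be written to: {config.output_dir}")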