cudag-0.3.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. cudag/__init__.py +334 -0
  2. cudag/annotation/__init__.py +77 -0
  3. cudag/annotation/codegen.py +648 -0
  4. cudag/annotation/config.py +545 -0
  5. cudag/annotation/loader.py +342 -0
  6. cudag/annotation/scaffold.py +121 -0
  7. cudag/annotation/transcription.py +296 -0
  8. cudag/cli/__init__.py +5 -0
  9. cudag/cli/main.py +315 -0
  10. cudag/cli/new.py +873 -0
  11. cudag/core/__init__.py +364 -0
  12. cudag/core/button.py +137 -0
  13. cudag/core/canvas.py +222 -0
  14. cudag/core/config.py +70 -0
  15. cudag/core/coords.py +233 -0
  16. cudag/core/data_grid.py +804 -0
  17. cudag/core/dataset.py +678 -0
  18. cudag/core/distribution.py +136 -0
  19. cudag/core/drawing.py +75 -0
  20. cudag/core/fonts.py +156 -0
  21. cudag/core/generator.py +163 -0
  22. cudag/core/grid.py +367 -0
  23. cudag/core/grounding_task.py +247 -0
  24. cudag/core/icon.py +207 -0
  25. cudag/core/iconlist_task.py +301 -0
  26. cudag/core/models.py +1251 -0
  27. cudag/core/random.py +130 -0
  28. cudag/core/renderer.py +190 -0
  29. cudag/core/screen.py +402 -0
  30. cudag/core/scroll_task.py +254 -0
  31. cudag/core/scrollable_grid.py +447 -0
  32. cudag/core/state.py +110 -0
  33. cudag/core/task.py +293 -0
  34. cudag/core/taskbar.py +350 -0
  35. cudag/core/text.py +212 -0
  36. cudag/core/utils.py +82 -0
  37. cudag/data/surnames.txt +5000 -0
  38. cudag/modal_apps/__init__.py +4 -0
  39. cudag/modal_apps/archive.py +103 -0
  40. cudag/modal_apps/extract.py +138 -0
  41. cudag/modal_apps/preprocess.py +529 -0
  42. cudag/modal_apps/upload.py +317 -0
  43. cudag/prompts/SYSTEM_PROMPT.txt +104 -0
  44. cudag/prompts/__init__.py +33 -0
  45. cudag/prompts/system.py +43 -0
  46. cudag/prompts/tools.py +382 -0
  47. cudag/py.typed +0 -0
  48. cudag/schemas/filesystem.json +90 -0
  49. cudag/schemas/test_record.schema.json +113 -0
  50. cudag/schemas/train_record.schema.json +90 -0
  51. cudag/server/__init__.py +21 -0
  52. cudag/server/app.py +232 -0
  53. cudag/server/services/__init__.py +9 -0
  54. cudag/server/services/generator.py +128 -0
  55. cudag/templates/scripts/archive.sh +35 -0
  56. cudag/templates/scripts/build.sh +13 -0
  57. cudag/templates/scripts/extract.sh +54 -0
  58. cudag/templates/scripts/generate.sh +116 -0
  59. cudag/templates/scripts/pre-commit.sh +44 -0
  60. cudag/templates/scripts/preprocess.sh +46 -0
  61. cudag/templates/scripts/upload.sh +63 -0
  62. cudag/templates/scripts/verify.py +428 -0
  63. cudag/validation/__init__.py +35 -0
  64. cudag/validation/validate.py +508 -0
  65. cudag-0.3.10.dist-info/METADATA +570 -0
  66. cudag-0.3.10.dist-info/RECORD +69 -0
  67. cudag-0.3.10.dist-info/WHEEL +4 -0
  68. cudag-0.3.10.dist-info/entry_points.txt +2 -0
  69. cudag-0.3.10.dist-info/licenses/LICENSE +66 -0
cudag/core/dataset.py ADDED
@@ -0,0 +1,678 @@
+ # Copyright (c) 2025 Tylt LLC. All rights reserved.
+ # CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
+ # is strictly prohibited. For licensing inquiries: hello@claimhawk.app
+
+ """Dataset builder for orchestrating sample generation.
+
+ The DatasetBuilder coordinates Screen, State, Renderer, and Tasks
+ to produce JSONL training datasets.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import random
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Callable
+
+ from cudag.core.coords import normalize_coord
+ from cudag.core.task import BaseTask, TaskContext, TaskSample, TestCase
+ from cudag.prompts.tools import format_tool_call
+
+
+ @dataclass
+ class DatasetConfig:
+     """Configuration for dataset generation."""
+
+     name_prefix: str
+     """Prefix for dataset name (e.g., "calendar-mike")."""
+
+     seed: int = 42
+     """Random seed for reproducibility."""
+
+     task_counts: dict[str, int] = field(default_factory=dict)
+     """Number of samples per task type."""
+
+     train_split: float = 0.8
+     """Fraction of data for training (rest is test/val)."""
+
+     system_prompt: str = "computer-use"
+     """System prompt style: "computer-use", "compact", or custom."""
+
+     output_dir: Path | None = None
+     """Output directory (auto-generated if None)."""
+
+     image_format: str = "png"
+     """Image format: "png" or "jpg"."""
+
+     image_quality: int = 95
+     """JPEG quality (ignored for PNG)."""
+
+     held_out_enabled: bool = False
+     """Whether to hold out samples for evaluation."""
+
+     held_out_ratio: float = 0.1
+     """Fraction of samples to hold out."""
+
+     test_count: int = 100
+     """Number of test cases to generate PER TASK TYPE."""
+
+     test_distribution: dict[str, int] = field(default_factory=dict)
+     """Per-task test counts. Overrides auto-distribution when set."""
+
+     test_tolerance: tuple[int, int] = (10, 10)
+     """Coordinate tolerance for test (x, y in RU units)."""
+
+     annotation_ratio: float = 0.1
+     """Fraction of test cases to annotate (0.0-1.0)."""
+
+     annotation_enabled: bool = True
+     """Whether to generate annotated test images."""
+
+     annotation_per_type: dict[str, int] = field(default_factory=dict)
+     """Number of annotations per task type. Overrides annotation_ratio when set."""
+
+     task_distributions: dict[str, dict[str, float]] = field(default_factory=dict)
+     """Distribution of sample types within each task type.
+
+     Example:
+         task_distributions:
+           click-appointment:
+             grey_grey: 0.80      # 80% grey background + grey status
+             other_colors: 0.15   # 15% other color combos
+             adversarial: 0.05    # 5% no match cases
+           hover-appointment:
+             grey_grey: 0.80
+             other_colors: 0.15
+             adversarial: 0.05
+     """
+
+     def __post_init__(self) -> None:
+         """Set default output directory if not provided."""
+         if self.output_dir is None:
+             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+             self.output_dir = Path("datasets") / f"{self.name_prefix}_{timestamp}"
+
+     def get_distribution(self, task_type: str) -> dict[str, float]:
+         """Get distribution for a task type.
+
+         Returns the task-specific distribution if defined, otherwise
+         returns an empty dict (task should use its own defaults).
+         """
+         return self.task_distributions.get(task_type, {})
+
+     def sample_distribution_type(self, task_type: str, rng: Any) -> str | None:
+         """Sample a distribution type for a task.
+
+         Args:
+             task_type: The task type to sample for.
+             rng: Random number generator.
+
+         Returns:
+             The sampled distribution type name, or None if no distribution defined.
+         """
+         dist = self.get_distribution(task_type)
+         if not dist:
+             return None
+
+         roll = rng.random()
+         cumulative = 0.0
+         for dist_type, prob in dist.items():
+             cumulative += prob
+             if roll < cumulative:
+                 return dist_type
+         # Return last type if we somehow miss due to float precision
+         return list(dist.keys())[-1] if dist else None
+
+     @classmethod
+     def from_yaml(cls, path: Path) -> DatasetConfig:
+         """Load config from YAML file."""
+         import yaml
+
+         with open(path) as f:
+             data = yaml.safe_load(f)
+
+         return cls(
+             name_prefix=data.get("name_prefix", "dataset"),
+             seed=data.get("seed", 42),
+             task_counts=data.get("tasks", {}),
+             train_split=data.get("splits", {}).get("train", 0.8),
+             system_prompt=data.get("system_prompt", "computer-use"),
+             output_dir=Path(data["output_dir"]) if "output_dir" in data else None,
+             image_format=data.get("output", {}).get("image_format", "png"),
+             image_quality=data.get("output", {}).get("image_quality", 95),
+             held_out_enabled=data.get("held_out", {}).get("enabled", False),
+             held_out_ratio=data.get("held_out", {}).get("ratio", 0.1),
+             test_count=data.get("test", {}).get("count", 100),
+             test_distribution=data.get("test", {}).get("distribution", {}),
+             test_tolerance=_parse_tolerance(data.get("test", {}).get("tolerance", [10, 10])),
+             annotation_ratio=data.get("annotation", {}).get("ratio", 0.1),
+             annotation_enabled=data.get("annotation", {}).get("enabled", True),
+             annotation_per_type=data.get("annotation", {}).get("per_type", {}),
+             task_distributions=data.get("task_distributions", {}),
+         )
+
+
+ def _parse_tolerance(value: int | list[int]) -> tuple[int, int]:
+     """Parse tolerance from config - handles both int and [x, y] formats."""
+     if isinstance(value, int):
+         return (value, value)
+     return tuple(value)  # type: ignore[return-value]
+
+
+ def _wrap_text(text: str, font: Any, max_width: int, draw: Any) -> list[str]:
+     """Wrap text to fit within max_width pixels.
+
+     Args:
+         text: The text to wrap.
+         font: PIL ImageFont to use for measuring.
+         max_width: Maximum width in pixels.
+         draw: PIL ImageDraw for measuring text.
+
+     Returns:
+         List of wrapped lines.
+     """
+     words = text.split()
+     lines = []
+     current_line: list[str] = []
+
+     for word in words:
+         test_line = " ".join(current_line + [word])
+         bbox = draw.textbbox((0, 0), test_line, font=font)
+         width = bbox[2] - bbox[0]
+
+         if width <= max_width:
+             current_line.append(word)
+         else:
+             if current_line:
+                 lines.append(" ".join(current_line))
+             current_line = [word]
+
+     if current_line:
+         lines.append(" ".join(current_line))
+
+     return lines if lines else [""]
+
+
+ def annotate_test_image(
+     image_path: Path,
+     tool_calls: list[dict[str, Any]],
+     pixel_coords: tuple[int, int],
+     prompt: str,
+     output_path: Path | None = None,
+     bbox_pixels: tuple[int, int, int, int] | None = None,
+ ) -> Path:
+     """Annotate a test image with tool call output and prompt.
+
+     Draws:
+     - Red crosshair at the click location (for click/hover tasks)
+     - Red bounding box rectangle (for grounding tasks with bbox_pixels)
+     - Extends canvas with white bar at bottom for prompt and <tool_call> output
+
+     Args:
+         image_path: Path to the original test image.
+         tool_calls: List of tool call dicts (from ToolCall.to_dict()).
+         pixel_coords: The (x, y) pixel coordinates for crosshair.
+         prompt: The user prompt text to display.
+         output_path: Where to save the annotated image. If None, saves
+             to same directory with "_annotated" suffix.
+         bbox_pixels: Optional bounding box as (x, y, width, height) in pixels.
+             If provided, draws a rectangle instead of crosshair.
+
+     Returns:
+         Path to the annotated image.
+     """
+     from PIL import Image, ImageDraw, ImageFont
+
+     # Load original image
+     original = Image.open(image_path).convert("RGB")
+     orig_width, orig_height = original.size
+
+     # Try to load a monospace font for JSON, fall back to default
+     try:
+         font = ImageFont.truetype("/System/Library/Fonts/Menlo.ttc", 11)
+     except (OSError, IOError):
+         try:
+             font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 11)
+         except (OSError, IOError):
+             font = ImageFont.load_default()
+
+     # Create temporary draw to measure text for wrapping
+     temp_img = Image.new("RGB", (1, 1))
+     temp_draw = ImageDraw.Draw(temp_img)
+
+     # Wrap the prompt text to fit image width (with margins)
+     margin = 10
+     max_text_width = orig_width - margin * 2
+     prompt_text = f"Prompt: {prompt}"
+     wrapped_prompt = _wrap_text(prompt_text, font, max_text_width, temp_draw)
+
+     # Format tool calls as prettified JSON with <tool_call> tags
+     tool_call_lines: list[str] = []
+     for tc in tool_calls:
+         tool_call_lines.append("<tool_call>")
+         # Pretty print JSON with 2-space indent
+         pretty_json = json.dumps(tc, indent=2)
+         for json_line in pretty_json.split("\n"):
+             tool_call_lines.append(json_line)
+         tool_call_lines.append("</tool_call>")
+
+     # Calculate bar height based on number of lines
+     line_height = 14
+     total_lines = len(wrapped_prompt) + len(tool_call_lines) + 1  # +1 for spacing
+     bar_height = (total_lines * line_height) + 12  # +12 for padding
+
+     # Create new canvas with extra height for prompt and tool call output
+     new_height = orig_height + bar_height
+     img = Image.new("RGB", (orig_width, new_height), (255, 255, 255))
+
+     # Paste original image at top
+     img.paste(original, (0, 0))
+
+     draw = ImageDraw.Draw(img)
+     annotation_color = (255, 0, 0)  # Red
+
+     if bbox_pixels is not None:
+         # Draw bounding box rectangle for grounding tasks
+         bx, by, bw, bh = bbox_pixels
+         draw.rectangle(
+             [(bx, by), (bx + bw, by + bh)],
+             outline=annotation_color,
+             width=3,
+         )
+     else:
+         # Draw crosshair at click location
+         x, y = pixel_coords
+         crosshair_size = 10
+         # Horizontal line
+         draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=annotation_color, width=2)
+         # Vertical line
+         draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=annotation_color, width=2)
+         # Circle around crosshair
+         draw.ellipse(
+             [(x - crosshair_size, y - crosshair_size), (x + crosshair_size, y + crosshair_size)],
+             outline=annotation_color,
+             width=2,
+         )
+
+     # Draw wrapped prompt text in the extended area below the original image
+     current_y = orig_height + 4
+     for line in wrapped_prompt:
+         draw.text((5, current_y), line, fill=(0, 0, 0), font=font)
+         current_y += line_height
+
+     # Add spacing
+     current_y += 4
+
+     # Draw tool call output (model response)
+     for line in tool_call_lines:
+         # Color the XML tags differently
+         if line.startswith("<tool_call>") or line.startswith("</tool_call>"):
+             draw.text((5, current_y), line, fill=(128, 0, 128), font=font)  # Purple for tags
+         else:
+             draw.text((5, current_y), line, fill=(0, 100, 0), font=font)  # Dark green for JSON
+         current_y += line_height
+
+     # Determine output path
+     if output_path is None:
+         stem = image_path.stem
+         output_path = image_path.parent / f"{stem}_annotated{image_path.suffix}"
+
+     img.save(output_path)
+     return output_path
+
+
+ class DatasetBuilder:
+     """Orchestrates dataset generation from tasks.
+
+     Example:
+         builder = DatasetBuilder(
+             config=DatasetConfig(name_prefix="calendar", task_counts={"click-day": 1000}),
+             tasks=[ClickDayTask(config, renderer)],
+         )
+         builder.build()
+     """
+
+     def __init__(
+         self,
+         config: DatasetConfig,
+         tasks: list[BaseTask],
+     ) -> None:
+         """Initialize the builder.
+
+         Args:
+             config: Dataset configuration
+             tasks: List of task instances to generate from
+         """
+         self.config = config
+         self.tasks = {t.task_type: t for t in tasks}
+         self.rng = random.Random(config.seed)
+
+     def build(
+         self,
+         start_index: int = 0,
+         checkpoint_callback: Callable[[int], None] | None = None,
+         checkpoint_interval: int = 1000,
+     ) -> Path:
+         """Generate the complete dataset.
+
+         Args:
+             start_index: Skip samples up to this index (for resume after preemption)
+             checkpoint_callback: Called with sample count every checkpoint_interval samples
+             checkpoint_interval: How often to call checkpoint_callback (default 1000)
+
+         Returns:
+             Path to the output directory
+         """
+         output_dir = self.config.output_dir
+         assert output_dir is not None
+
+         # Create directories
+         output_dir.mkdir(parents=True, exist_ok=True)
+         (output_dir / "images").mkdir(exist_ok=True)
+
+         # Generate samples
+         samples: list[dict[str, Any]] = []
+         held_out: list[dict[str, Any]] = []
+         index = 0
+         samples_generated = 0
+         last_checkpoint = 0
+
+         for task_type, count in self.config.task_counts.items():
+             if task_type not in self.tasks:
+                 raise ValueError(f"Unknown task type: {task_type}")
+
+             task = self.tasks[task_type]
+             for _ in range(count):
+                 # Skip samples until we reach start_index
+                 if index < start_index:
+                     index += 1
+                     continue
+
+                 ctx = TaskContext(
+                     rng=self.rng,
+                     index=index,
+                     output_dir=output_dir,
+                     config=task.config,
+                     dataset_name=self.config.name_prefix,
+                 )
+
+                 # Use generate_samples() for 1:N image-to-samples pattern
+                 # A single render can produce multiple training samples
+                 task_samples = task.generate_samples(ctx)
+                 for sample in task_samples:
+                     record = self._to_record(sample)
+
+                     # Decide if this should be held out
+                     if self.config.held_out_enabled and self.rng.random() < self.config.held_out_ratio:
+                         held_out.append(record)
+                     else:
+                         samples.append(record)
+
+                 index += 1
+                 samples_generated += 1
+
+                 # Checkpoint callback
+                 if checkpoint_callback and samples_generated - last_checkpoint >= checkpoint_interval:
+                     checkpoint_callback(samples_generated)
+                     last_checkpoint = samples_generated
+
+         # Write data files
+         self._write_jsonl(output_dir / "data.jsonl", samples + held_out)
+         self._write_splits(output_dir, samples)
+
+         if held_out:
+             self._write_jsonl(output_dir / "held_out.jsonl", held_out)
+
+         # Write config for reference
+         self._write_config(output_dir)
+
+         print(f"Generated {len(samples)} training samples, {len(held_out)} held out")
+         print(f"Output: {output_dir}")
+
+         return output_dir
+
+     def _to_record(self, sample: TaskSample) -> dict[str, Any]:
+         """Convert TaskSample to JSONL record."""
+         # Get normalized coordinates
+         norm_coord = normalize_coord(sample.pixel_coords, sample.image_size)
+
+         # Check if sample has multiple tool_calls in metadata
+         if "tool_calls" in sample.metadata and len(sample.metadata["tool_calls"]) > 1:
+             # Format all tool calls for multi-action samples
+             gpt_parts = []
+             for tc in sample.metadata["tool_calls"]:
+                 gpt_parts.append(format_tool_call(tc))
+             gpt_value = "\n".join(gpt_parts)
+         else:
+             # Single tool call - update with normalized coordinates
+             tool_call = sample.tool_call.to_dict()
+             if "coordinate" in tool_call["arguments"]:
+                 tool_call["arguments"]["coordinate"] = list(norm_coord)
+             gpt_value = format_tool_call(tool_call)
+
+         # Build relative image path
+         assert self.config.output_dir is not None
+         image_rel = str(sample.image_path.relative_to(self.config.output_dir))
+
+         return {
+             "id": sample.id,
+             "image": image_rel,
+             "conversations": [
+                 {"from": "human", "value": f"<image>\n{sample.human_prompt}"},
+                 {"from": "gpt", "value": gpt_value},
+             ],
+             "metadata": {
+                 "task_type": sample.metadata.get("task_type", "unknown"),
+                 "real_coords": list(sample.pixel_coords),
+                 **{k: v for k, v in sample.metadata.items() if k != "task_type"},
+             },
+         }
+
+     def _write_jsonl(self, path: Path, records: list[dict[str, Any]]) -> None:
+         """Write records to JSONL file."""
+         with open(path, "w", encoding="utf-8") as f:
+             for record in records:
+                 f.write(json.dumps(record) + "\n")
+
+     def _write_splits(self, output_dir: Path, samples: list[dict[str, Any]]) -> None:
+         """Split samples and write train/val files."""
+         # Shuffle for splitting
+         shuffled = samples.copy()
+         self.rng.shuffle(shuffled)
+
+         split_idx = int(len(shuffled) * self.config.train_split)
+         train = shuffled[:split_idx]
+         val = shuffled[split_idx:]
+
+         self._write_jsonl(output_dir / "train.jsonl", train)
+         self._write_jsonl(output_dir / "val.jsonl", val)
+
+         print(f"Split: {len(train)} train, {len(val)} val")
+
+     def _write_config(self, output_dir: Path) -> None:
+         """Write generation config for reference."""
+         # Extract task_types from task_counts keys
+         task_types = list(self.config.task_counts.keys())
+
+         config_data = {
+             "name_prefix": self.config.name_prefix,
+             "seed": self.config.seed,
+             "task_types": task_types,
+             "task_counts": self.config.task_counts,
+             "train_split": self.config.train_split,
+             "system_prompt": self.config.system_prompt,
+             "task_distributions": self.config.task_distributions,
+             "generated_at": datetime.now().isoformat(),
+         }
+         with open(output_dir / "config.json", "w") as f:
+             json.dump(config_data, f, indent=2)
+
+     def build_tests(self) -> Path:
+         """Generate test cases.
+
+         Returns:
+             Path to the test directory
+         """
+         output_dir = self.config.output_dir
+         assert output_dir is not None
+
+         # Create test directory structure (test/images/)
+         test_dir = output_dir / "test"
+         test_dir.mkdir(parents=True, exist_ok=True)
+         (test_dir / "images").mkdir(exist_ok=True)
+
+         # Create annotated directory if annotations enabled
+         annotated_dir = test_dir / "annotated"
+         if self.config.annotation_enabled:
+             annotated_dir.mkdir(exist_ok=True)
+
+         # Generate test cases - stop when we reach test_count
+         test_cases: list[dict[str, Any]] = []
+         raw_test_cases: list[TestCase] = []
+         index = 0
+
+         # Get task types to iterate through
+         task_types = [t for t in self.config.task_counts.keys() if t in self.tasks]
+         if not task_types:
+             return test_dir
+
+         # Generate test_count tests PER task type
+         for task_type in task_types:
+             if task_type not in self.tasks:
+                 continue
+             task = self.tasks[task_type]
+             generated = 0
+
+             while generated < self.config.test_count:
+                 ctx = TaskContext(
+                     rng=self.rng,
+                     index=index,
+                     output_dir=test_dir,
+                     config=task.config,
+                     dataset_name=self.config.name_prefix,
+                 )
+                 tests = task.generate_tests(ctx)
+                 for test_case in tests:
+                     if generated >= self.config.test_count:
+                         break
+                     record = self._test_to_record(test_case, test_dir)
+                     test_cases.append(record)
+                     raw_test_cases.append(test_case)
+                     generated += 1
+                 index += 1
+
+         # ALWAYS generate 1 annotation per task type for tests (regardless of config)
+         # For grounding tasks, generate 1 annotation per unique element_label
+         annotated_count = 0
+         annotated_dir.mkdir(exist_ok=True)
+
+         # Group indices by task type (and element_label for grounding)
+         indices_by_key: dict[str, list[int]] = {}
+         for idx, test_case in enumerate(raw_test_cases):
+             task_type = test_case.metadata.get("task_type", "unknown")
+             # For grounding tasks, use element_label as additional key
+             if task_type == "grounding":
+                 element_label = test_case.metadata.get("element_label", "unknown")
+                 key = f"grounding:{element_label}"
+             else:
+                 key = task_type
+             if key not in indices_by_key:
+                 indices_by_key[key] = []
+             indices_by_key[key].append(idx)
+
+         # Select 1 index per key to annotate
+         indices_to_annotate: set[int] = set()
+         for key, indices in indices_by_key.items():
+             if indices:
+                 indices_to_annotate.add(indices[0])
+
+         for idx in sorted(indices_to_annotate):
+             if idx >= len(raw_test_cases):
+                 continue
+             test_case = raw_test_cases[idx]
+
+             # Get pixel coordinates for crosshair
+             pixel_coords = test_case.pixel_coords or (0, 0)
+
+             # Build tool_calls list from expected_action and any additional actions
+             tool_calls = [test_case.expected_action]
+
+             # Check for additional tool calls in metadata (e.g., type action for textfields)
+             if "additional_tool_calls" in test_case.metadata:
+                 tool_calls.extend(test_case.metadata["additional_tool_calls"])
+
+             # Generate annotated image - include task type in filename
+             task_type = test_case.metadata.get("task_type", "unknown")
+             # For grounding, include element_label in filename
+             if task_type == "grounding":
+                 element_label = test_case.metadata.get("element_label", "unknown")
+                 annotated_path = annotated_dir / f"{task_type}_{element_label}_{test_case.test_id}_annotated.png"
+             else:
+                 annotated_path = annotated_dir / f"{task_type}_{test_case.test_id}_annotated.png"
+
+             # Extract bbox_pixels for grounding tasks (format: [x, y, width, height])
+             bbox_pixels = None
+             if "bbox_pixels" in test_case.metadata:
+                 bp = test_case.metadata["bbox_pixels"]
+                 bbox_pixels = (bp[0], bp[1], bp[2], bp[3])
+
+             annotate_test_image(
+                 image_path=test_case.screenshot,
+                 tool_calls=tool_calls,
+                 pixel_coords=pixel_coords,
+                 prompt=test_case.prompt,
+                 output_path=annotated_path,
+                 bbox_pixels=bbox_pixels,
+             )
+             annotated_count += 1
+
+         # Write test.json
+         with open(test_dir / "test.json", "w", encoding="utf-8") as f:
+             json.dump(test_cases, f, indent=2)
+
+         if annotated_count > 0:
+             print(f"Generated {len(test_cases)} test cases ({annotated_count} annotated)")
+         else:
+             print(f"Generated {len(test_cases)} test cases")
+
+         return test_dir
+
+     def _test_to_record(self, test_case: TestCase, test_dir: Path) -> dict[str, Any]:
+         """Convert TestCase to record for test.json."""
+         # Get image size from metadata if available, default to 1920x1080
+         image_size = test_case.metadata.get("image_size", (1920, 1080))
+
+         # Normalize coordinates in expected_action
+         expected_action = test_case.expected_action.copy()
+         if "arguments" in expected_action and "coordinate" in expected_action["arguments"]:
+             pixel_coords = expected_action["arguments"]["coordinate"]
+             if test_case.pixel_coords:
+                 pixel_coords = test_case.pixel_coords
+             norm_coord = normalize_coord(tuple(pixel_coords), image_size)
+             expected_action["arguments"]["coordinate"] = list(norm_coord)
+
+         # Build relative screenshot path (relative to test_dir)
+         screenshot_rel = str(test_case.screenshot.relative_to(test_dir))
+
+         # Tolerance can come from test_case directly or from metadata
+         # Convert to list for JSON serialization
+         tolerance = test_case.tolerance
+         if isinstance(tolerance, tuple):
+             tolerance = list(tolerance)
+         elif isinstance(tolerance, int):
+             tolerance = [tolerance, tolerance]
+
+         return {
+             "test_id": test_case.test_id,
+             "screenshot": screenshot_rel,
+             "prompt": test_case.prompt,
+             "expected_action": expected_action,
+             "tolerance": tolerance,
+             "metadata": {
+                 "real_coords": list(test_case.pixel_coords) if test_case.pixel_coords else None,
+                 **test_case.metadata,
+             },
+         }
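
Usage sketch (not part of the packaged file above): the YAML shape below mirrors the key lookups in DatasetConfig.from_yaml() from the listing; the specific values, the "dataset.yaml" file name, and the final comment are illustrative assumptions, and running it requires the cudag package and PyYAML to be installed.

# Illustrative only: YAML keys are taken from DatasetConfig.from_yaml() above;
# the values and the file name "dataset.yaml" are assumptions.
from pathlib import Path

from cudag.core.dataset import DatasetConfig

yaml_text = """\
name_prefix: calendar
seed: 42
tasks:            # becomes DatasetConfig.task_counts
  click-day: 1000
splits:
  train: 0.8
output:
  image_format: png
  image_quality: 95
held_out:
  enabled: false
  ratio: 0.1
test:
  count: 100
  tolerance: [10, 10]
annotation:
  enabled: true
  ratio: 0.1
"""

cfg_path = Path("dataset.yaml")
cfg_path.write_text(yaml_text)

config = DatasetConfig.from_yaml(cfg_path)
print(config.task_counts, config.test_tolerance)  # {'click-day': 1000} (10, 10)

# Per the DatasetBuilder docstring, a builder would then be constructed with this
# config plus concrete BaseTask instances and driven via build() and build_tests().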