cudag 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudag/__init__.py +334 -0
- cudag/annotation/__init__.py +77 -0
- cudag/annotation/codegen.py +648 -0
- cudag/annotation/config.py +545 -0
- cudag/annotation/loader.py +342 -0
- cudag/annotation/scaffold.py +121 -0
- cudag/annotation/transcription.py +296 -0
- cudag/cli/__init__.py +5 -0
- cudag/cli/main.py +315 -0
- cudag/cli/new.py +873 -0
- cudag/core/__init__.py +364 -0
- cudag/core/button.py +137 -0
- cudag/core/canvas.py +222 -0
- cudag/core/config.py +70 -0
- cudag/core/coords.py +233 -0
- cudag/core/data_grid.py +804 -0
- cudag/core/dataset.py +678 -0
- cudag/core/distribution.py +136 -0
- cudag/core/drawing.py +75 -0
- cudag/core/fonts.py +156 -0
- cudag/core/generator.py +163 -0
- cudag/core/grid.py +367 -0
- cudag/core/grounding_task.py +247 -0
- cudag/core/icon.py +207 -0
- cudag/core/iconlist_task.py +301 -0
- cudag/core/models.py +1251 -0
- cudag/core/random.py +130 -0
- cudag/core/renderer.py +190 -0
- cudag/core/screen.py +402 -0
- cudag/core/scroll_task.py +254 -0
- cudag/core/scrollable_grid.py +447 -0
- cudag/core/state.py +110 -0
- cudag/core/task.py +293 -0
- cudag/core/taskbar.py +350 -0
- cudag/core/text.py +212 -0
- cudag/core/utils.py +82 -0
- cudag/data/surnames.txt +5000 -0
- cudag/modal_apps/__init__.py +4 -0
- cudag/modal_apps/archive.py +103 -0
- cudag/modal_apps/extract.py +138 -0
- cudag/modal_apps/preprocess.py +529 -0
- cudag/modal_apps/upload.py +317 -0
- cudag/prompts/SYSTEM_PROMPT.txt +104 -0
- cudag/prompts/__init__.py +33 -0
- cudag/prompts/system.py +43 -0
- cudag/prompts/tools.py +382 -0
- cudag/py.typed +0 -0
- cudag/schemas/filesystem.json +90 -0
- cudag/schemas/test_record.schema.json +113 -0
- cudag/schemas/train_record.schema.json +90 -0
- cudag/server/__init__.py +21 -0
- cudag/server/app.py +232 -0
- cudag/server/services/__init__.py +9 -0
- cudag/server/services/generator.py +128 -0
- cudag/templates/scripts/archive.sh +35 -0
- cudag/templates/scripts/build.sh +13 -0
- cudag/templates/scripts/extract.sh +54 -0
- cudag/templates/scripts/generate.sh +116 -0
- cudag/templates/scripts/pre-commit.sh +44 -0
- cudag/templates/scripts/preprocess.sh +46 -0
- cudag/templates/scripts/upload.sh +63 -0
- cudag/templates/scripts/verify.py +428 -0
- cudag/validation/__init__.py +35 -0
- cudag/validation/validate.py +508 -0
- cudag-0.3.10.dist-info/METADATA +570 -0
- cudag-0.3.10.dist-info/RECORD +69 -0
- cudag-0.3.10.dist-info/WHEEL +4 -0
- cudag-0.3.10.dist-info/entry_points.txt +2 -0
- cudag-0.3.10.dist-info/licenses/LICENSE +66 -0
cudag/core/grid.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tylt LLC. All rights reserved.
|
|
2
|
+
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
|
|
3
|
+
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app
|
|
4
|
+
|
|
5
|
+
"""Grid abstraction for UI grids.
|
|
6
|
+
|
|
7
|
+
Provides a reusable Grid class for any grid-based UI component:
|
|
8
|
+
- Calendar day grids
|
|
9
|
+
- Data grids/tables
|
|
10
|
+
- Spreadsheets
|
|
11
|
+
- Game boards
|
|
12
|
+
|
|
13
|
+
The Grid class handles:
|
|
14
|
+
- Geometry (cell positions, sizes, gaps)
|
|
15
|
+
- Cell coordinate calculations
|
|
16
|
+
- Content/data management
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from typing import Any, Generic, TypeVar
|
|
23
|
+
|
|
24
|
+
# Generic type parameter describing the content held by GridCell / Grid.
T = TypeVar("T")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class GridGeometry:
    """Defines the physical layout of a grid.

    All measurements are in pixels. Gaps can be float for sub-pixel accuracy.
    Cell origins are rounded to whole pixels by :meth:`cell_position`. The
    derived queries :attr:`width`, :attr:`height` and :meth:`point_to_cell`
    are computed from those same rounded origins. They therefore always agree
    with :meth:`cell_bounds` and always return ``int`` values.
    """

    x: int
    """X position of grid top-left corner."""

    y: int
    """Y position of grid top-left corner."""

    rows: int
    """Number of rows."""

    cols: int
    """Number of columns."""

    cell_width: int
    """Width of each cell."""

    cell_height: int
    """Height of each cell."""

    row_gap: float = 0
    """Gap between rows in pixels (can be float)."""

    col_gap: float = 0
    """Gap between columns in pixels (can be float)."""

    first_row_header: bool = False
    """If True, first row is a fixed header (doesn't scroll)."""

    last_col_scroll: bool = False
    """If True, last column is reserved for scrollbar."""

    last_row_scroll: bool = False
    """If True, last row is reserved for horizontal scrollbar."""

    @property
    def data_rows(self) -> int:
        """Number of rows available for data (excluding header/scroll rows)."""
        count = self.rows
        if self.first_row_header:
            count -= 1
        if self.last_row_scroll:
            count -= 1
        return max(0, count)

    @property
    def data_cols(self) -> int:
        """Number of columns available for data (excluding scroll column)."""
        count = self.cols
        if self.last_col_scroll:
            count -= 1
        return max(0, count)

    @property
    def header_row(self) -> int | None:
        """Row index of header, or None if no header."""
        return 0 if self.first_row_header else None

    @property
    def scroll_col(self) -> int | None:
        """Column index of scrollbar, or None if no scrollbar."""
        return self.cols - 1 if self.last_col_scroll else None

    @property
    def scroll_row(self) -> int | None:
        """Row index of horizontal scrollbar, or None if no scrollbar."""
        return self.rows - 1 if self.last_row_scroll else None

    @property
    def width(self) -> int:
        """Total grid width in whole pixels, including gaps.

        Measured from ``x`` to the right edge of the last column, using the
        rounded column origin from :meth:`cell_position`. With integer gaps
        this equals ``cols * cell_width + (cols - 1) * col_gap``. With
        fractional gaps it stays an ``int`` that matches the cells' bounds.
        Returns 0 when there are no columns.
        """
        if self.cols <= 0:
            return 0
        last_x, _ = self.cell_position(0, self.cols - 1)
        return last_x - self.x + self.cell_width

    @property
    def height(self) -> int:
        """Total grid height in whole pixels, including gaps.

        Measured from ``y`` to the bottom edge of the last row, using the
        rounded row origin from :meth:`cell_position`. With integer gaps this
        equals ``rows * cell_height + (rows - 1) * row_gap``. Returns 0 when
        there are no rows.
        """
        if self.rows <= 0:
            return 0
        _, last_y = self.cell_position(self.rows - 1, 0)
        return last_y - self.y + self.cell_height

    @property
    def bounds(self) -> tuple[int, int, int, int]:
        """Grid bounds as (x, y, width, height)."""
        return (self.x, self.y, self.width, self.height)

    def tolerance_pixels(self, padding_ratio: float = 0.15) -> tuple[int, int]:
        """Natural tolerance in pixels based on cell size.

        Args:
            padding_ratio: Padding on each side as ratio of cell size (default 15%)

        Returns:
            (x_tolerance, y_tolerance) in pixels, truncated with ``int``
        """
        # Tolerance is cell size minus padding on each side
        tol_ratio = 1.0 - (2 * padding_ratio)  # 70% for 15% padding
        return (int(self.cell_width * tol_ratio), int(self.cell_height * tol_ratio))

    def tolerance_ru(
        self,
        image_size: tuple[int, int],
        padding_ratio: float = 0.15,
    ) -> tuple[int, int]:
        """Natural tolerance in RU (normalized 0-1000) based on cell size.

        Calculates tolerance as a percentage of the cell size in normalized coordinates.
        A 15% padding on each side means 70% tolerance.

        Args:
            image_size: (width, height) of the image in pixels
            padding_ratio: Padding on each side as ratio of cell size (default 15%)

        Returns:
            (x_tolerance, y_tolerance) in RU units (0-1000 scale)

        Examples:
            For a 24x15 cell on 224x208 image with 15% padding:
            - x: (24/224 * 1000) * 0.7 = ~75 RU
            - y: (15/208 * 1000) * 0.7 = ~50 RU
        """
        tol_ratio = 1.0 - (2 * padding_ratio)  # 70% for 15% padding
        x_ru = (self.cell_width / image_size[0]) * 1000 * tol_ratio
        y_ru = (self.cell_height / image_size[1]) * 1000 * tol_ratio
        return (int(x_ru), int(y_ru))

    def cell_position(self, row: int, col: int) -> tuple[int, int]:
        """Get top-left position of a cell.

        Args:
            row: Row index (0-based)
            col: Column index (0-based)

        Returns:
            (x, y) position of cell top-left corner
        """
        # Use float math then round to avoid drift with fractional gaps.
        # Note: Python's round() rounds halves to even (round(10.5) == 10).
        x = round(self.x + col * (self.cell_width + self.col_gap))
        y = round(self.y + row * (self.cell_height + self.row_gap))
        return (x, y)

    def cell_center(self, row: int, col: int) -> tuple[int, int]:
        """Get center position of a cell.

        Args:
            row: Row index (0-based)
            col: Column index (0-based)

        Returns:
            (x, y) center position
        """
        x, y = self.cell_position(row, col)
        return (x + self.cell_width // 2, y + self.cell_height // 2)

    def cell_bounds(self, row: int, col: int) -> tuple[int, int, int, int]:
        """Get bounds of a cell.

        Args:
            row: Row index (0-based)
            col: Column index (0-based)

        Returns:
            (x, y, width, height) bounds
        """
        x, y = self.cell_position(row, col)
        return (x, y, self.cell_width, self.cell_height)

    def index_to_rowcol(self, index: int) -> tuple[int, int]:
        """Convert linear index to (row, col).

        Args:
            index: Linear index (0 to rows*cols-1)

        Returns:
            (row, col) tuple
        """
        return divmod(index, self.cols)

    def rowcol_to_index(self, row: int, col: int) -> int:
        """Convert (row, col) to linear index.

        Args:
            row: Row index
            col: Column index

        Returns:
            Linear index
        """
        return row * self.cols + col

    def point_to_cell(self, x: int, y: int) -> tuple[int, int] | None:
        """Find which cell contains a point.

        This is the exact inverse of :meth:`cell_bounds`. A point belongs to
        cell ``(row, col)`` when it lies in the half-open box that starts at
        that cell's rounded origin and spans ``cell_width`` x ``cell_height``
        pixels. Points in gaps or outside the grid map to None.

        Args:
            x: X coordinate
            y: Y coordinate

        Returns:
            (row, col) as ints if the point is inside a cell, None otherwise
        """
        col = self._axis_index(x, self.x, self.cell_width, self.col_gap, self.cols)
        if col is None:
            return None
        row = self._axis_index(y, self.y, self.cell_height, self.row_gap, self.rows)
        if row is None:
            return None
        return (row, col)

    @staticmethod
    def _axis_index(
        pos: float, origin: int, size: int, gap: float, count: int
    ) -> int | None:
        """Return the index of the cell containing ``pos`` along one axis, or None."""
        pitch = size + gap
        if count <= 0 or size <= 0 or pitch <= 0:
            return None
        # Floor division gives the nominal cell. Because cell origins are
        # rounded (see cell_position), the real owner may be an adjacent slot,
        # so the neighbours are checked as well.
        guess = int((pos - origin) // pitch)
        for idx in (guess, guess - 1, guess + 1):
            if 0 <= idx < count:
                start = round(origin + idx * pitch)
                if start <= pos < start + size:
                    return idx
        return None
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@dataclass
class GridCell(Generic[T]):
    """One grid cell: its (row, col) position, its content, and free-form metadata."""

    # Row index of the cell (0-based, matching GridGeometry indexing).
    row: int

    # Column index of the cell (0-based, matching GridGeometry indexing).
    col: int

    # Cell payload; its type is the T chosen for the owning grid.
    content: T

    # Optional extra data attached to the cell; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def index(self) -> int:
        """Not supported: a lone cell cannot know its grid's column count.

        Raises:
            NotImplementedError: always; use ``grid.rowcol_to_index()`` instead.
        """
        raise NotImplementedError("Use grid.rowcol_to_index() instead")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@dataclass
class Grid(Generic[T]):
    """A grid of cells with geometry and content.

    Generic type T is the content type for cells. When no cells are supplied,
    ``__post_init__`` builds them in row-major order, so
    ``geometry.rowcol_to_index(row, col)`` gives each cell's list position.

    Example:
        # Date grid (calendar)
        geometry = GridGeometry(x=2, y=72, rows=6, cols=7, cell_width=24, cell_height=15)
        grid = Grid(geometry)
        for i, d in enumerate(dates):
            row, col = geometry.index_to_rowcol(i)
            grid.set_cell(row, col, d)

        # Get click position for a date
        cell = grid.find_cell(lambda c: c.content == target_date)
        click_pos = geometry.cell_center(cell.row, cell.col)
    """

    geometry: GridGeometry
    """Grid layout/geometry."""

    cells: list[GridCell[T]] = field(default_factory=list)
    """All cells in the grid."""

    def __post_init__(self) -> None:
        """Initialize empty cells (row-major, content None) if none were provided."""
        if not self.cells:
            for row in range(self.geometry.rows):
                for col in range(self.geometry.cols):
                    self.cells.append(GridCell(row=row, col=col, content=None))  # type: ignore

    def _index_of(self, row: int, col: int) -> int | None:
        """Return the ``cells`` list index for (row, col), or None if out of range."""
        # Check each axis separately: an out-of-range column must not wrap
        # into the neighbouring row via the flattened row-major index.
        if not (0 <= row < self.geometry.rows and 0 <= col < self.geometry.cols):
            return None
        index = self.geometry.rowcol_to_index(row, col)
        return index if index < len(self.cells) else None

    def get_cell(self, row: int, col: int) -> GridCell[T] | None:
        """Get the cell at (row, col), or None if the position is outside the grid."""
        index = self._index_of(row, col)
        return None if index is None else self.cells[index]

    def set_cell(self, row: int, col: int, content: T, **metadata: Any) -> None:
        """Set cell content at (row, col) and merge ``metadata`` into the cell's metadata.

        Positions outside the grid are silently ignored.
        """
        index = self._index_of(row, col)
        if index is None:
            return
        cell = self.cells[index]
        cell.content = content
        cell.metadata.update(metadata)

    def find_cell(self, predicate: Any) -> GridCell[T] | None:
        """Find first cell matching predicate.

        Args:
            predicate: Function taking GridCell and returning bool

        Returns:
            First matching cell (in ``cells`` order) or None
        """
        for cell in self.cells:
            if predicate(cell):
                return cell
        return None

    def find_cells(self, predicate: Any) -> list[GridCell[T]]:
        """Find all cells matching predicate, in ``cells`` order."""
        return [cell for cell in self.cells if predicate(cell)]

    def cell_center(self, row: int, col: int) -> tuple[int, int]:
        """Get center position of cell (convenience method)."""
        return self.geometry.cell_center(row, col)

    def cell_bounds(self, row: int, col: int) -> tuple[int, int, int, int]:
        """Get bounds of cell (convenience method)."""
        return self.geometry.cell_bounds(row, col)

    def iter_cells(self) -> Any:
        """Iterate over all cells."""
        return iter(self.cells)

    @property
    def total_cells(self) -> int:
        """Total number of cells implied by the geometry (rows * cols)."""
        return self.geometry.rows * self.geometry.cols
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tylt LLC. All rights reserved.
|
|
2
|
+
# CONFIDENTIAL AND PROPRIETARY. Unauthorized use, copying, or distribution
|
|
3
|
+
# is strictly prohibited. For licensing inquiries: hello@claimhawk.app
|
|
4
|
+
|
|
5
|
+
"""Base grounding task for element bounding box detection.
|
|
6
|
+
|
|
7
|
+
This module provides a base task class for generating "grounding" training data,
|
|
8
|
+
where the model must identify the bounding box of a specified element.
|
|
9
|
+
|
|
10
|
+
Example output:
|
|
11
|
+
<tool_call>
|
|
12
|
+
{"name": "get_bbox", "arguments": {"element": "search button", "bbox_2d": [123, 456, 789, 1011]}}
|
|
13
|
+
</tool_call>
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from abc import abstractmethod
|
|
19
|
+
from typing import TYPE_CHECKING, Any
|
|
20
|
+
|
|
21
|
+
from cudag.annotation import AnnotatedElement, AnnotationConfig
|
|
22
|
+
from cudag.core.task import BaseTask, TaskContext, TaskSample, TestCase
|
|
23
|
+
from cudag.prompts.tools import BboxCall, format_tool_call
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from random import Random
|
|
27
|
+
|
|
28
|
+
from PIL import Image
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def bbox_to_ru(
    bbox: tuple[int, int, int, int],
    image_size: tuple[int, int],
) -> tuple[int, int, int, int]:
    """Convert bbox (x, y, width, height) to RU coordinates [x1, y1, x2, y2].

    Each edge is divided by the matching image dimension, scaled to 1000 and
    truncated toward zero with ``int``. The result is then clamped to the
    0-1000 RU range. This clamp keeps a box that extends past the image edge
    (e.g. after scaling) a valid RU box. In-range boxes are unaffected.

    Args:
        bbox: Bounding box as (x, y, width, height) in pixels
        image_size: Image size as (width, height) in pixels

    Returns:
        Bounding box as (x1, y1, x2, y2) in RU units (0-1000)
    """
    x, y, w, h = bbox
    img_w, img_h = image_size

    def _to_ru(value: float, extent: int) -> int:
        # Truncate exactly as before, then clamp into the RU range.
        return min(1000, max(0, int((value / extent) * 1000)))

    return (
        _to_ru(x, img_w),
        _to_ru(y, img_h),
        _to_ru(x + w, img_w),
        _to_ru(y + h, img_h),
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def scale_bbox(
    bbox: tuple[int, int, int, int],
    scale_x: float,
    scale_y: float,
) -> tuple[int, int, int, int]:
    """Rescale an (x, y, width, height) box by independent horizontal/vertical factors.

    The horizontal components (x and width) are multiplied by ``scale_x``.
    The vertical components (y and height) are multiplied by ``scale_y``.
    Every product is truncated toward zero with ``int``.

    Args:
        bbox: Bounding box as (x, y, width, height)
        scale_x: X scale factor
        scale_y: Y scale factor

    Returns:
        Scaled bounding box
    """
    left, top, box_w, box_h = bbox
    new_left = int(left * scale_x)
    new_top = int(top * scale_y)
    new_w = int(box_w * scale_x)
    new_h = int(box_h * scale_y)
    return (new_left, new_top, new_w, new_h)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class GroundingTaskBase(BaseTask):
    """Base class for grounding tasks that identify element bounding boxes.

    Subclasses must implement:
    - get_annotation_config(): Return the annotation config
    - get_image_scale(): Return (scale_x, scale_y) for coordinate scaling
    - render_image(ctx): Render an image and return (image, metadata)

    The task will:
    1. Pick a random element from the annotation
    2. Generate a prompt like "Locate the {element_label}"
    3. Return a BboxCall with the element's bounding box in RU coordinates
    """

    task_type: str = "grounding"

    # Prompt templates that can be customized by subclasses. get_prompt()
    # fills the "{element}" placeholder with the element's label.
    PROMPT_TEMPLATES = [
        "Locate the {element}",
        "Find the bounding box of the {element}",
        "Where is the {element}?",
        "Identify the {element} region",
    ]

    @abstractmethod
    def get_annotation_config(self) -> AnnotationConfig:
        """Return the annotation config for this generator."""
        pass

    @abstractmethod
    def get_image_scale(self) -> tuple[float, float]:
        """Return (scale_x, scale_y) for coordinate scaling.

        If annotation was made at different size than generator output,
        return the scale factors to convert annotation coords to output coords.
        """
        pass

    @abstractmethod
    def render_image(self, ctx: TaskContext) -> tuple[Any, dict[str, Any]]:
        """Render an image for this task.

        Args:
            ctx: Task context with RNG, index, etc.

        Returns:
            Tuple of (PIL.Image, metadata dict)
        """
        pass

    def get_groundable_elements(self) -> list[AnnotatedElement]:
        """Get elements that can be used for grounding.

        Override this to filter which elements are included in grounding tasks.
        By default, returns all elements with non-empty labels.

        Returns:
            List of elements to use for grounding
        """
        config = self.get_annotation_config()
        return [el for el in config.elements if el.label]

    def get_prompt(self, element: AnnotatedElement, rng: Random) -> str:
        """Generate a prompt for the given element.

        Override this to customize prompt generation.

        Args:
            element: The element to generate a prompt for
            rng: Random number generator

        Returns:
            Prompt string
        """
        template = rng.choice(self.PROMPT_TEMPLATES)
        return template.format(element=element.label)

    def _pick_target(
        self,
        ctx: TaskContext,
        image_size: tuple[int, int],
    ) -> tuple[
        AnnotatedElement,
        tuple[int, int, int, int],
        tuple[int, int, int, int],
        str,
        tuple[int, int],
    ]:
        """Choose an element and derive the values both sample and test need.

        Calls happen in the same order the two generators used before this
        refactor: element choice, image scale, bbox conversion, then prompt.
        The ``ctx.rng`` draws therefore happen in the same order, so a given
        seed yields the same element and prompt.

        Args:
            ctx: Task context; its ``rng`` picks the element and prompt template.
            image_size: (width, height) of the rendered image in pixels.

        Returns:
            (element, scaled_bbox_pixels, bbox_ru, prompt, center_pixels)

        Raises:
            IndexError: if get_groundable_elements() returns no elements
                (the same exception type random.choice raised here before).
        """
        elements = self.get_groundable_elements()
        if not elements:
            raise IndexError(
                f"{type(self).__name__}.get_groundable_elements() returned no "
                "elements to ground"
            )
        element = ctx.rng.choice(elements)

        # Scale the bbox to generator output size
        scale_x, scale_y = self.get_image_scale()
        scaled_bbox = scale_bbox(element.bbox, scale_x, scale_y)

        # Convert to RU coordinates [x1, y1, x2, y2]
        bbox_ru = bbox_to_ru(scaled_bbox, image_size)

        prompt = self.get_prompt(element, ctx.rng)

        # Center point (midpoint of the scaled pixel bbox, floor-divided)
        center = (
            scaled_bbox[0] + scaled_bbox[2] // 2,
            scaled_bbox[1] + scaled_bbox[3] // 2,
        )
        return element, scaled_bbox, bbox_ru, prompt, center

    def generate_sample(self, ctx: TaskContext) -> TaskSample:
        """Generate a grounding training sample.

        Renders and saves an image, picks an element, and returns a TaskSample
        whose tool call is a BboxCall carrying the element's RU bbox.
        """
        image, metadata = self.render_image(ctx)
        image_path = self.save_image(image, ctx)

        element, scaled_bbox, bbox_ru, prompt, center = self._pick_target(ctx, image.size)
        bbox_call = BboxCall.create(element=element.label, bbox_2d=bbox_ru)

        return TaskSample(
            id=self.build_id(ctx),
            image_path=image_path,
            human_prompt=prompt,
            tool_call=bbox_call,  # type: ignore[arg-type]
            pixel_coords=center,
            metadata={
                "task_type": self.task_type,
                "element_label": element.label,
                "element_type": element.element_type,
                "bbox_pixels": list(scaled_bbox),
                "bbox_ru": list(bbox_ru),
                **metadata,
            },
            image_size=image.size,
        )

    def generate_test(self, ctx: TaskContext) -> TestCase:
        """Generate a grounding test case.

        Renders and saves an image with the "test" prefix, picks an element,
        and returns a TestCase whose expected action is a ``get_bbox`` call
        dict with the element's RU bbox.
        """
        image, metadata = self.render_image(ctx)
        image_path = self.save_image(image, ctx, prefix="test")

        element, scaled_bbox, bbox_ru, prompt, center = self._pick_target(ctx, image.size)

        # Create expected action as dict
        expected_action = {
            "name": "get_bbox",
            "arguments": {
                "element": element.label,
                "bbox_2d": list(bbox_ru),
            },
        }

        return TestCase(
            test_id=f"test_{ctx.index:04d}",
            screenshot=image_path,
            prompt=prompt,
            expected_action=expected_action,
            tolerance=(50, 50),  # Generous tolerance for bbox matching
            metadata={
                "task_type": self.task_type,
                "element_label": element.label,
                "element_type": element.element_type,
                "bbox_pixels": list(scaled_bbox),
                "image_size": image.size,
                **metadata,
            },
            pixel_coords=center,
        )

    def format_gpt_response(self, tool_call: BboxCall) -> str:  # type: ignore[override]
        """Format the GPT response for this sample."""
        return format_tool_call(tool_call)
|