langfun 0.1.2.dev202507260804__py3-none-any.whl → 0.1.2.dev202507280805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -0,0 +1,242 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Image drawing for facilitating UI understanding."""
15
+
16
+ import functools
17
+ from typing import Callable
18
+ import langfun as lf
19
+ from langfun.assistant.capabilities.gui import location
20
+ from PIL import Image as pil_image
21
+ from PIL import ImageDraw as pil_draw
22
+
23
+
24
def blank_image(
    size: tuple[int, int],
    background: tuple[int, int, int] = (0, 0, 0),
    pil: bool = False
) -> lf.Image | pil_image.Image:
  """Builds a solid-color RGB image.

  Args:
    size: Dimensions of the image, forwarded unchanged to `PIL.Image.new`.
    background: RGB color used to fill the whole image.
    pil: When True the raw `PIL.Image` is returned; when False it is wrapped
      in an `lf.Image`.

  Returns:
    The freshly created image, either a `PIL.Image` or an `lf.Image`
    depending on `pil`.
  """
  canvas = pil_image.new('RGB', size, background)
  if pil:
    return canvas
  return lf.Image.from_pil_image(canvas)
42
+
43
+
44
def draw(
    image: lf.Image | pil_image.Image,
    draw_fn: Callable[[pil_draw.ImageDraw], None],
) -> lf.Image | pil_image.Image:
  """Applies a drawing callback to an image.

  Args:
    image: The image to draw on, either an `lf.Image` or a `PIL.Image`.
    draw_fn: Callback invoked exactly once with a `PIL.ImageDraw` bound to the
      underlying PIL image.

  Returns:
    The drawn-on image, of the same kind (`lf.Image` or `PIL.Image`) as the
    input. A `PIL.Image` input is drawn on directly and returned as is.
  """
  if isinstance(image, pil_image.Image):
    # PIL input: draw in place and hand the same object back.
    draw_fn(pil_draw.Draw(image))
    return image

  # lf.Image input: draw on its PIL form, then wrap the result again.
  canvas = image.to_pil_image()
  draw_fn(pil_draw.Draw(canvas))
  return lf.Image.from_pil_image(canvas)
62
+
63
+
64
def draw_bboxes(
    image: lf.Image | pil_image.Image,
    bboxes: list[location.BBox],
    line_color: str = 'red',
    line_width: int = 3,
    text: str | None = None,
) -> lf.Image | pil_image.Image:
  """Outlines bounding boxes on an image, optionally labeling each one.

  Args:
    image: An `lf.Image` or a `PIL.Image` object.
    bboxes: The `location.BBox` objects to outline. When empty, the input
      image is returned untouched.
    line_color: Color of the box outlines; also used for the label text.
    line_width: Width of the box outlines.
    text: Optional label drawn 5 pixels inside the top-left corner of every
      box.

  Returns:
    The image after drawing. Its type will be the same as the input image.
  """
  if not bboxes:
    return image

  def _render(canvas: pil_draw.ImageDraw) -> None:
    for box in bboxes:
      corners = (box.x, box.y, box.right, box.bottom)
      canvas.rectangle(corners, outline=line_color, width=line_width)
      if not text:
        continue
      canvas.text((box.x + 5, box.y + 5), text, fill=line_color)

  return draw(image, _render)
95
+
96
+
97
def draw_points(
    image: lf.Image | pil_image.Image,
    points: list[location.Coordinate],
    color: str = 'red',
    radius: int = 3,
) -> lf.Image | pil_image.Image:
  """Marks points on an image as filled circles.

  Args:
    image: An `lf.Image` or a `PIL.Image` object.
    points: The `location.Coordinate` objects to mark. When empty, the input
      image is returned untouched.
    color: Fill color of each circle.
    radius: Radius of each circle, measured from the point.

  Returns:
    The image after drawing. Its type will be the same as the input image.
  """
  if not points:
    return image

  def _render(canvas: pil_draw.ImageDraw) -> None:
    for center in points:
      left, top = center.x - radius, center.y - radius
      right, bottom = center.x + radius, center.y + radius
      canvas.ellipse((left, top, right, bottom), fill=color)

  return draw(image, _render)
130
+
131
+
132
def draw_calibration_lines(
    image: lf.Image | pil_image.Image,
    coordinate: location.Coordinate,
    vline_color: str = 'green',
    vline_width: int = 3,
    hline_color: str = 'blue',
    hline_width: int = 3
) -> lf.Image | pil_image.Image:
  """Draws a full-width and a full-height line crossing at a coordinate.

  Args:
    image: An `lf.Image` or a `PIL.Image` object.
    coordinate: The point where the two lines cross.
    vline_color: Color of the vertical line.
    vline_width: Width of the vertical line.
    hline_color: Color of the horizontal line.
    hline_width: Width of the horizontal line.

  Returns:
    The image after drawing. Its type will be the same as the input image.
  """
  def _render(canvas: pil_draw.ImageDraw) -> None:
    # The extent of both lines comes from the size of the input image.
    full_width, full_height = image.size
    horizontal = (0, coordinate.y, full_width, coordinate.y)
    vertical = (coordinate.x, 0, coordinate.x, full_height)
    canvas.line(horizontal, fill=hline_color, width=hline_width)
    canvas.line(vertical, fill=vline_color, width=vline_width)

  return draw(image, _render)
163
+
164
+
165
def draw_cursor(
    image: lf.Image | pil_image.Image,
    coordinate: location.Coordinate,
) -> lf.Image | pil_image.Image:
  """Pastes the cursor sprite onto an image.

  Args:
    image: An `lf.Image` or a `PIL.Image` object.
    coordinate: The coordinate to draw the cursor at.

  Returns:
    The image after drawing. Its type will be the same as the input image.
  """
  wrap_result = not isinstance(image, pil_image.Image)
  canvas = image.to_pil_image() if wrap_result else image

  sprite = _cursor_image()
  # The sprite is shifted 3 px left and 2 px up from the coordinate
  # (presumably to align the pointer tip with it - confirm against the
  # sprite). It is also passed as its own paste mask.
  anchor = (coordinate.x - 3, coordinate.y - 2)
  canvas.paste(sprite, anchor, sprite)

  if wrap_result:
    return lf.Image.from_pil_image(canvas)
  return canvas
186
+
187
+
188
def draw_ref_points(
    image: lf.Image | pil_image.Image,
    width: int,
    height: int,
    rows: int,
    cols: int,
) -> lf.Image | pil_image.Image:
  """Draws a labeled grid of reference points on an image.

  The grid has `rows` x `cols` points. Columns are spaced
  `width // (cols + 1)` pixels apart and rows `height // (rows + 1)` pixels
  apart, so the grid leaves a margin of one step on the top and left. Each
  point is drawn as a small green dot with a green `(x=..., y=...)` label to
  its right.

  Args:
    image: An `lf.Image` or a `PIL.Image` object.
    width: Horizontal extent used to space the columns (typically the image
      width).
    height: Vertical extent used to space the rows (typically the image
      height).
    rows: Number of rows in the grid.
    cols: Number of columns in the grid.

  Returns:
    The image after drawing. Its type will be the same as the input image.
  """
  delta_x = width // (cols + 1)
  delta_y = height // (rows + 1)

  # The column index drives x and the row index drives y. The spacing of each
  # axis is derived from the matching count, so pairing them the other way
  # round would transpose (and potentially overflow) non-square grids.
  points = [
      ((col + 1) * delta_x, (row + 1) * delta_y)
      for row in range(rows)
      for col in range(cols)
  ]

  def _draw_fn(drawing: pil_draw.ImageDraw) -> None:
    for x, y in points:
      drawing.ellipse(
          (x - 2, y - 2, x + 2, y + 2), fill='green', outline='green'
      )
      drawing.text((x + 5, y - 6), f'(x={x}, y={y})', fill='green')

  return draw(image, _draw_fn)
215
+
216
+
217
def draw_text(
    image: lf.Image | pil_image.Image,
    text: str,
    coordinate: location.Coordinate,
    font_color: str = 'red',
) -> lf.Image | pil_image.Image:
  """Writes text onto an image.

  Args:
    image: An `lf.Image` or a `PIL.Image` object.
    text: The text to write.
    coordinate: The position passed to `PIL.ImageDraw.text` as the text
      anchor.
    font_color: Fill color of the text.

  Returns:
    The image after drawing. Its type will be the same as the input image.
  """
  wrap_result = not isinstance(image, pil_image.Image)
  canvas = image.to_pil_image() if wrap_result else image

  position = (coordinate.x, coordinate.y)
  pil_draw.ImageDraw(canvas).text(
      position, text, fill=font_color, spacing=4
  )

  if wrap_result:
    return lf.Image.from_pil_image(canvas)
  return canvas
236
+
237
+
238
@functools.cache
def _cursor_image() -> pil_image.Image:
  """Returns the cursor sprite decoded from an embedded PNG.

  The payload below starts with the PNG signature and is decoded through
  `lf.Image.from_bytes(...).to_pil_image()`. Because of `functools.cache`,
  decoding happens once and every call returns the same `PIL.Image` object.
  """
  return lf.Image.from_bytes(
      b"""\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x19\x00\x00\x00\x1e\x08\x06\x00\x00\x00\xd9\xec\xb5\xdb\x00\x00\x04gIDATx\x9c\xc5\x95_H[g\x18\xc6\x9f\x9c$\'1\xccl\x8ch\xe7\xe2\xb4%\xd2\xea\xa8s\xdd&\xd8Q\x9b\xba\x8d\x8eZA\x10\x0cQ\xca\x84\xad^l\x17e-+\xb8\x1b\x99]\xbb\xe26Ea\xa0\xb2\x8bA\x07\xcat\xa3(Sj\xa7\xad4C+v\x17\xa27\xb6\xb3\x81\xd9\xe9b\xa41mNr\xce\xc9I\x9e]\xa8\x99v\xadZk\xb7\x07^\xbes\xf1\x9d\xe7\xf7\xbd\x7f\xbes\x80\xffR\xe1p\xd8\xa5(J\xf9\xd3\xf0\x16V\xd6\xd9\xd9\xd9lQ\x14\x7f\x88\xc5b_\xf5\xf4\xf4X\x9e\x06\xcc422R\x17\x89D8;;K\x92W\xee\xdc\xb9\xb3{\xbb\xccW2\x01\xc9X(\x14\x82\xd3\xe9DOOO\x91\xddn\xbf"IR\xe9v\x81\x00@\xf4x<5\xc1`\x90V\xab\x95:\x9d\x8e\xb5\xb5\xb5T\x14E%\xf9Y[[\x9bq[ \x83\x83\x83\x9f\x84B!\xee\xd8\xb1\x83\x00\x08\x80\xc5\xc5\xc5\x9c\x9b\x9b#\xc9\xee\x9b7o\xa6?1\xa4\xbf\xbf\xffD8\x1c\xa6\xddnO@\x00p\xd7\xae]\xbcv\xed\x1aIN\x05\x02\x81C[1_\xe9\tUU\x8d\n\x82\x00\x83\xc1\xb0f\x83\xd7\xeb\xc5\xe1\xc3\x87\xd1\xdc\xdc\xbc\xdbj\xb5\xfe\xac\xaa\xea\x89\xadfb\xbcx\xf1\xe2\x07\x8a\xa2\xd0\xe1p\xac\xc9du\x1c;v\x8c\x81@\x80$\xbf\x1b\x1e\x1e~\xfe\xb1!\xed\xed\xedU\xd1h\x94\xd9\xd9\xd9\x8f\x84\x00`^^\x1e\'&&Hr\xcc\xe7\xf3\xbd\xba\x19\xf3\xc4\x08\xab\xaa\xaa\x02\x80\xd1\xb8\xfe \x8d\x8f\x8f\xa3\xb0\xb0\x10\x1d\x1d\x1do\xa4\xa6\xa6\xfe\x12\x0e\x87+7\x0b\xa1\xaa\xaaQ\x00\x14Eq\xc3\x93-..\xa2\xb2\xb2\x12\xa7N\x9d\xb2\x01\xf8\x9e\xe4\x97\r\r\rI\x1b\xbdghnn.\x8d\xc5b\xb1\x82\x82\x82u\xcb\xf5`\x14\x15\x15qff\x86$/y\xbd\xde\x9d\xebe\x02I\x92\xa2\x00h2\x996\xccd\xb5\xae^\xbd\x8a}\xfb\xf6\xa1\xb7\xb7\xf7\xdd\xf4\xf4\xf4\x81\xb9\xb9\xb9\xb7\xfe\x95\xc1\xf2\xcaH$\x12\x15\x04!&\x8a\xa2\xfeQ\x86\x82 ++\x0b\xa2(\xc2b\xb1 
))\tf\xb3\x19\x06\x83\x01\x9d\x9d\x9d\xd8\xbbw\xaf#33\xb37\x14\n}\xd4\xd7\xd7w\xc1\xe5r\xc5VCt\xb2,k\x00\xe2\x1b\xf5\xa4\xa6\xa6\x06UUU\x08\x85B\x7f\x01P\x01h$\xa3\xb2,3\x18\x0cbjj*n6\x9b\xdfKII\x99\x00\xf0\x1b\x00&n^$\x12\xd1\x00\xc4VC,\x16\x0b\xce\x9d;\x87\xee\xeen\x0c\r\r!\x1e\x8f\xe3\xfc\xf9\xf3p\xbb\xdd\x18\x19\x19\xb9TVV\xf6\xa3\xa6iPU5\x8e\xa5\xd2\x1b\x00\xc4\x00\xf8\x00\xfc\t@\x0f@[\xf1\xd3\x1f?~\xbc\x80d\xd0\xedv\x13\x00\xf7\xec\xd9\xc3\xeb\xd7\xaf\x93dtxx\x98F\xa31\xd1\xec3g\xce\x90\xa4\xdf\xe5r\x1d\x04\xf0\xc2r\xa4\x02H\x01`\x03\xf0\x1c\x003\x00\xdd\xea*\xe8].W>\xc9\xbbn\xb7\x9b\xc5\xc5\xc5\xf4\xfb\xfd\x9c\x9e\x9e\xfe\xf5\xf4\xe9\xd3\xf5$#\xd5\xd5\xd5\t\x88\xd5j\xe5\xcc\xcc\x0c%I\xfa\xf6A\xa3\xf5\xa4/))yM\x96e\xff\x8d\x1b7(I\x92:00pA\xaf\xd7\x1f\x05\xf0\xf6\xe8\xe8\xe8O~\xbf\x9f6\x9b-\x01\xaa\xa8\xa8 \xc9\x88\xc7\xe39\xb4i\xc8\x81\x03\x07^\x91e\xf9\xae\xa2(\xf7\xeb\xeb\xeb\xbf\x00P\x00 \x03\x80}\xff\xfe\xfdG5M\xf3555% \x82 \xd0\xe3\xf1\x90\xe4\x95\xb4\xb4\xb4M\xfd\xae\x85#G\x8e\xe4\xde\xbe}{\xb4\xba\xba\xfaS\x00\xafc\xa9\xb6z\x00&\x00;\xbb\xba\xba\xbe\xd64\x8d\xb9\xb9\xb9\tP~~>\xe3\xf18\xa7\xa7\xa7\xdf\xdfT&\x0e\x87\xe3\xa5\x9c\x9c\x9c\x83\x00r\xb0\xd4\xb8\x95\x8b\xaa\x03\xf0lrr\xf2\x9b\x81@`\xb2\xbf\xbf\x9f:\x9d.\x01jmm%\xc9\xdf\xcf\x9e=\x9b\xb6\x11D\x87\xa5iH\xc6C\xa6\x02K\xa3i\xaf\xab\xab\xfb\x98\xa4VZZ\x9a\x80dddP\x92$\xde\xbbw\xef\x1b\xa7\xd3i\xc0\x13\xea\x19\x00\xb9\xb7n\xdd\xba\xec\xf5zy\xf2\xe4Ivvvrrr\x92\x81@@\x9a\x9f\x9f\x9fhiiy\xf9!\x07|,\xe9\x01\xa4\x96\x97\x97W,,,\xfc\xe1\xf3\xf9\xc6\xc7\xc6\xc6:\x1a\x1b\x1b?w:\x9d\x1f\x9aL\xa6w\x00\xbc\xb8\xbco\x8d\x1e\x97j\x06\x90a\xb3\xd9v/,,h\x00\x16\x01\x84\x01H\x00\xee/\x87\x8c\xa5Rn\x19\xb2\xd2;\xcb\xf2\xb3\xb2\x1c\x1a\xfe\xf9\xfc\xff?\xfa\x1b\x9d\xbfI\x0e\x1da[h\x00\x00\x00\x00IEND\xaeB`\x82"""
  ).to_pil_image()
@@ -0,0 +1,103 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for the GUI drawing utilities."""
15
+
16
+ import unittest
17
+ import langfun as lf
18
+ from langfun.assistant.capabilities.gui import drawing
19
+ from langfun.assistant.capabilities.gui import location
20
+ from PIL import Image as pil_image
21
+
22
+
23
class DrawingTest(unittest.TestCase):
  """Smoke tests for the helpers in the `drawing` module."""

  def test_blank_image(self):
    """Both return flavors of `blank_image` encode to the same bytes."""
    lf_image = drawing.blank_image((3, 3), (0, 0, 0))
    self.assertIsInstance(lf_image, lf.Image)
    self.assertEqual(lf_image.size, (3, 3))
    self.assertEqual(lf_image.image_format, 'png')
    self.assertIn(b'\x89PNG\r\n\x1a\n', lf_image.to_bytes())

    raw_image = drawing.blank_image((3, 3), (0, 0, 0), pil=True)
    self.assertIsInstance(raw_image, pil_image.Image)
    wrapped = lf.Image.from_pil_image(raw_image)
    self.assertIsInstance(wrapped, lf.Image)
    self.assertEqual(wrapped.to_bytes(), lf_image.to_bytes())

  def test_draw_bboxes(self):
    """An empty box list is a no-op; drawing keeps size and type."""
    blank = drawing.blank_image((5, 5), (0, 0, 0))
    self.assertIs(blank, drawing.draw_bboxes(blank, []))

    boxes = [
        location.BBox(0, 0, 3, 3),
        location.BBox(1, 1, 4, 4),
    ]
    drawn = drawing.draw_bboxes(
        blank, boxes, line_color='red', line_width=1
    )
    self.assertEqual(drawn.size, (5, 5))
    self.assertIsInstance(drawn, lf.Image)

  def test_draw_points(self):
    """An empty point list is a no-op; drawing keeps size and type."""
    blank = drawing.blank_image((5, 5), (0, 0, 0))
    self.assertIs(blank, drawing.draw_points(blank, []))

    dots = [
        location.Coordinate(1, 1),
        location.Coordinate(2, 2),
    ]
    drawn = drawing.draw_points(blank, dots, color='red', radius=1)
    self.assertEqual(drawn.size, (5, 5))
    self.assertIsInstance(drawn, lf.Image)

  def test_draw_calibration_lines(self):
    """Calibration lines keep the image size and type."""
    blank = drawing.blank_image((5, 5), (0, 0, 0))
    drawn = drawing.draw_calibration_lines(
        blank,
        location.Coordinate(2, 2),
        vline_color='red',
        vline_width=1,
        hline_color='green',
        hline_width=1,
    )
    self.assertEqual(drawn.size, (5, 5))
    self.assertIsInstance(drawn, lf.Image)

  def test_draw_cursor(self):
    """Pasting the cursor keeps the image size and type."""
    blank = drawing.blank_image((50, 50), (255, 0, 0))
    drawn = drawing.draw_cursor(blank, location.Coordinate(10, 10))
    self.assertEqual(drawn.size, (50, 50))
    self.assertIsInstance(drawn, lf.Image)

  def test_draw_ref_points(self):
    """Drawing the reference grid keeps the image size and type."""
    blank = drawing.blank_image((50, 50), (255, 0, 0))
    drawn = drawing.draw_ref_points(blank, 10, 10, 3, 3)
    self.assertEqual(drawn.size, (50, 50))
    self.assertIsInstance(drawn, lf.Image)

  def test_draw_text(self):
    """Writing text keeps the image size and type."""
    blank = drawing.blank_image((50, 50), (255, 0, 0))
    anchor = location.Coordinate(10, 10)
    drawn = drawing.draw_text(blank, 'Hello World', anchor)
    self.assertEqual(drawn.size, (50, 50))
    self.assertIsInstance(drawn, lf.Image)
101
+
102
# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,288 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Structures for location objects in an image."""
15
+
16
+ import math
17
+ import random
18
+ from typing import Optional, Union
19
+ import pyglove as pg
20
+
21
+
22
class Coordinate(pg.Object):
  """A coordinate in a 2D image.

  Supports `+` and `-` with another `Coordinate` or an `(x, y)` tuple on
  either side, and `*` with a scalar ratio on either side.
  """
  # Horizontal position.
  x: int
  # Vertical position.
  y: int

  def as_tuple(self) -> tuple[int, int]:
    """Returns the coordinate as an `(x, y)` tuple."""
    return self.x, self.y

  def __add__(
      self, other: Union['Coordinate', tuple[int, int]]
  ) -> 'Coordinate':
    """Returns the component-wise sum of this coordinate and `other`."""
    other = Coordinate.from_value(other)
    return Coordinate(self.x + other.x, self.y + other.y)

  def __radd__(
      self, other: Union['Coordinate', tuple[int, int]]
  ) -> 'Coordinate':
    """Returns `other + self`; addition is commutative."""
    return self + other

  def __sub__(
      self, other: Union['Coordinate', tuple[int, int]]
  ) -> 'Coordinate':
    """Returns this coordinate minus `other`, component-wise."""
    other = Coordinate.from_value(other)
    return Coordinate(self.x - other.x, self.y - other.y)

  def __rsub__(
      self, other: Union['Coordinate', tuple[int, int]]
  ) -> 'Coordinate':
    """Returns `other` minus this coordinate, component-wise."""
    other = Coordinate.from_value(other)
    return other - self

  def __mul__(self, ratio: float) -> 'Coordinate':
    """Returns the coordinate scaled by `ratio`.

    Args:
      ratio: The scale factor applied to both components.

    Returns:
      A new coordinate whose components are the scaled values truncated to
      integers.
    """
    # `x` and `y` are declared as `int`, so a fractional ratio must be
    # converted back to integers before constructing the result. `int()`
    # matches the truncation used by `BBox.expand`.
    return Coordinate(int(self.x * ratio), int(self.y * ratio))

  def __rmul__(self, ratio: float) -> 'Coordinate':
    """Returns `ratio * self`; scaling is commutative."""
    return self * ratio

  @classmethod
  def random(cls,
             bound: 'BBox',
             rand: random.Random | None = None) -> 'Coordinate':
    """Generates a random coordinate.

    Args:
      bound: The bounding box within which the coordinate will be generated.
        Both edges are inclusive.
      rand: The random number generator to use. If None, the default random
        number generator will be used.

    Returns:
      A random coordinate.
    """
    rand = rand or random
    x = rand.randint(bound.left, bound.right)
    y = rand.randint(bound.top, bound.bottom)
    return cls(x, y)

  def distance_to(self, point: 'Coordinate') -> float:
    """Returns the Euclidean distance to `point`."""
    return math.sqrt((self.x - point.x) ** 2 + (self.y - point.y) ** 2)

  @classmethod
  def from_value(
      cls, value: Union[tuple[int, int], 'Coordinate']) -> 'Coordinate':
    """Creates a coordinate from a tuple or coordinate.

    Args:
      value: An `(x, y)` tuple or an existing `Coordinate`.

    Returns:
      A coordinate built from the tuple, or `value` itself when it already is
      a `Coordinate`.

    Raises:
      ValueError: If `value` is a tuple whose length is not 2.
    """
    if isinstance(value, tuple):
      if len(value) == 2:
        return cls(*value)
      else:
        raise ValueError(f'Invalid tuple size: {len(value)}')
    assert isinstance(value, Coordinate), value
    return value
100
+
101
+
102
class BBox(pg.Object):
  """A bounding box in a 2D image.

  The box is described by its top-left corner (`x`, `y`) and its `right` and
  `bottom` edges. A box must have positive width and height; this is asserted
  whenever the object is bound.
  """
  # Left edge.
  x: int
  # Top edge.
  y: int
  # Right edge.
  right: int
  # Bottom edge.
  bottom: int

  def _on_bound(self):
    super()._on_bound()
    # Reject empty or inverted boxes.
    assert self.left < self.right, self
    assert self.top < self.bottom, self

  @property
  def left(self) -> int:
    """Returns the left coordinate of the bounding box."""
    return self.x

  @property
  def top(self) -> int:
    """Returns the top coordinate of the bounding box."""
    return self.y

  @property
  def width(self) -> int:
    """Returns the width of the bounding box."""
    return self.right - self.left

  @property
  def height(self) -> int:
    """Returns the height of the bounding box."""
    return self.bottom - self.top

  @property
  def center(self) -> Coordinate:
    """Returns the center of the bounding box (floor-divided)."""
    return Coordinate(
        (self.left + self.right) // 2, (self.top + self.bottom) // 2)

  @property
  def area(self) -> int:
    """Returns the area of the bounding box."""
    return self.width * self.height

  @property
  def top_left(self) -> Coordinate:
    """Returns the top left corner of the bounding box."""
    return Coordinate(self.left, self.top)

  @property
  def bottom_right(self) -> Coordinate:
    """Returns the bottom right corner of the bounding box."""
    return Coordinate(self.right, self.bottom)

  def clip(self, image_size: tuple[int, int]) -> Optional['BBox']:
    """Clips the bounding box to the given image size.

    Args:
      image_size: The image size (width, height).

    Returns:
      The clipped bounding box, or None when nothing with positive width and
      height remains after clipping.
    """
    width, height = image_size
    x = min(max(0, self.x), width)
    y = min(max(0, self.y), height)
    right = min(max(0, self.right), width)
    bottom = min(max(0, self.bottom), height)
    # Detect an empty result explicitly instead of catching the
    # AssertionError raised by `_on_bound`: assertions are stripped under
    # `python -O`, which would let a degenerate box slip through.
    if x >= right or y >= bottom:
      return None
    return BBox(x, y, right, bottom)

  def __contains__(
      self,
      v: Union[
          tuple[int, int],
          tuple[int, int, int, int],
          Coordinate,
          'BBox'
      ]) -> bool:
    """Returns whether a point or a box lies within this box.

    Args:
      v: A `Coordinate`, a `BBox`, an `(x, y)` tuple or an
        `(x, y, right, bottom)` tuple. Edges are inclusive.

    Returns:
      True if `v` is fully inside this bounding box.

    Raises:
      ValueError: If `v` is a tuple of unsupported length or an unsupported
        type.
    """
    if isinstance(v, tuple):
      if len(v) == 2:
        v = Coordinate(*v)
      elif len(v) == 4:
        v = BBox(*v)
      else:
        raise ValueError(f'Invalid tuple size: {len(v)}')
    if isinstance(v, Coordinate):
      return self.left <= v.x <= self.right and self.top <= v.y <= self.bottom
    elif isinstance(v, BBox):
      return (
          v.left >= self.left
          and v.right <= self.right
          and v.top >= self.top
          and v.bottom <= self.bottom
      )
    else:
      raise ValueError(f'Invalid type: {type(v)}')

  def intersects(self, other: 'BBox') -> bool:
    """Returns whether this box overlaps or touches `other`."""
    return not (self.right < other.left or
                self.left > other.right or
                self.bottom < other.top or
                self.top > other.bottom)

  def as_tuple(self) -> tuple[int, int, int, int]:
    """Returns the bounding box as an `(x, y, right, bottom)` tuple."""
    return self.x, self.y, self.right, self.bottom

  @classmethod
  def random(
      cls,
      bound: 'BBox',
      min_width: int = 30,
      max_width: int = 200,
      min_height: int = 30,
      max_height: int = 200,
      rand: random.Random | None = None
  ) -> 'BBox':
    """Generates a random bounding box.

    Args:
      bound: The bounding box within which the random bounding box will be
        generated.
      min_width: The minimum width of the random bounding box.
      max_width: The maximum width of the random bounding box.
      min_height: The minimum height of the random bounding box.
      max_height: The maximum height of the random bounding box.
      rand: The random number generator to use. If None, the default random
        number generator will be used.

    Returns:
      A random bounding box.

    Raises:
      ValueError: If the minimum width or height does not fit in `bound`.
    """
    assert max_width >= min_width, (min_width, max_width)
    assert max_height >= min_height, (min_height, max_height)

    if min_width > bound.width or min_height > bound.height:
      raise ValueError('Minimum width or height is larger than the bound.')

    rand = rand or random

    # Never ask for a box larger than the bound itself.
    max_width = min(max_width, bound.width)
    max_height = min(max_height, bound.height)

    width = rand.randint(min_width, max_width)
    height = rand.randint(min_height, max_height)

    # Limit the top-left corner so the box stays inside the bound.
    max_left = bound.right - width
    max_top = bound.bottom - height

    left = rand.randint(bound.left, max_left)
    top = rand.randint(bound.top, max_top)

    right = left + width
    bottom = top + height

    return cls(left, top, right, bottom)

  def matches(
      self,
      other: 'BBox',
      area_diff_threshold: float = 5.0,
      max_center_distance: float = 100
  ) -> bool:
    """Returns whether the bounding boxes match.

    Args:
      other: The bounding box to compare against.
      area_diff_threshold: Upper bound (exclusive) on the absolute area
        difference expressed as a fraction of this box's area.
      max_center_distance: Upper bound (exclusive) on the distance between
        the two centers.

    Returns:
      True if both the area and the center-distance criteria are met.
    """
    if self.area == 0 or other.area == 0:
      return False

    return (
        abs(self.area - other.area) / self.area < area_diff_threshold
        and self.center.distance_to(other.center) < max_center_distance
    )

  def expand(
      self,
      width_scale: float = 1.0,
      height_scale: float = 1.0,
  ) -> 'BBox':
    """Expands the bounding box from the center.

    Args:
      width_scale: Factor applied to the width.
      height_scale: Factor applied to the height.

    Returns:
      A new bounding box with the scaled size (truncated to integers), with
      the top-left corner shifted by half of the size change.
    """
    new_width = int(self.width * width_scale)
    new_height = int(self.height * height_scale)
    new_x = self.x - (new_width - self.width) // 2
    new_y = self.y - (new_height - self.height) // 2
    return BBox(new_x, new_y, new_x + new_width, new_y + new_height)