datachain 0.11.11__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +33 -5
- datachain/catalog/loader.py +19 -13
- datachain/cli/__init__.py +2 -1
- datachain/cli/parser/studio.py +13 -1
- datachain/client/fsspec.py +12 -16
- datachain/client/hf.py +36 -14
- datachain/client/local.py +1 -4
- datachain/data_storage/warehouse.py +3 -8
- datachain/dataset.py +8 -0
- datachain/error.py +0 -12
- datachain/fs/utils.py +30 -0
- datachain/func/__init__.py +5 -0
- datachain/func/func.py +2 -1
- datachain/lib/dc.py +23 -8
- datachain/lib/file.py +55 -17
- datachain/lib/image.py +30 -6
- datachain/lib/listing.py +21 -39
- datachain/lib/video.py +7 -5
- datachain/model/bbox.py +209 -58
- datachain/model/pose.py +49 -37
- datachain/model/segment.py +22 -18
- datachain/model/ultralytics/bbox.py +9 -9
- datachain/model/ultralytics/pose.py +7 -7
- datachain/model/ultralytics/segment.py +7 -7
- datachain/model/utils.py +191 -0
- datachain/query/dataset.py +4 -2
- datachain/studio.py +8 -6
- datachain/utils.py +0 -16
- {datachain-0.11.11.dist-info → datachain-0.12.0.dist-info}/METADATA +4 -2
- {datachain-0.11.11.dist-info → datachain-0.12.0.dist-info}/RECORD +34 -32
- {datachain-0.11.11.dist-info → datachain-0.12.0.dist-info}/WHEEL +1 -1
- {datachain-0.11.11.dist-info → datachain-0.12.0.dist-info}/LICENSE +0 -0
- {datachain-0.11.11.dist-info → datachain-0.12.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.11.11.dist-info → datachain-0.12.0.dist-info}/top_level.txt +0 -0
datachain/lib/file.py
CHANGED
@@ -18,7 +18,6 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
-from PIL import Image as PilImage
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice
@@ -52,7 +51,7 @@ class FileExporter(NodesThreadPool):
 
     def __init__(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement,
         use_cache: bool,
         link_type: Literal["copy", "symlink"],
@@ -243,6 +242,30 @@ class File(DataModel):
         self._catalog = None
         self._caching_enabled: bool = False
 
+    def as_text_file(self) -> "TextFile":
+        """Convert the file to a `TextFile` object."""
+        if isinstance(self, TextFile):
+            return self
+        file = TextFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_image_file(self) -> "ImageFile":
+        """Convert the file to a `ImageFile` object."""
+        if isinstance(self, ImageFile):
+            return self
+        file = ImageFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_video_file(self) -> "VideoFile":
+        """Convert the file to a `VideoFile` object."""
+        if isinstance(self, VideoFile):
+            return self
+        file = VideoFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
     @classmethod
     def upload(
         cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
@@ -292,20 +315,20 @@ class File(DataModel):
         ) as f:
             yield io.TextIOWrapper(f) if mode == "r" else f
 
-    def read(self, length: int = -1):
-        """Returns file contents."""
+    def read_bytes(self, length: int = -1):
+        """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)
 
-    def read_bytes(self):
-        """Returns file contents as bytes."""
-        return self.read()
-
     def read_text(self):
         """Returns file contents as text."""
         with self.open(mode="r") as stream:
             return stream.read()
 
+    def read(self, length: int = -1):
+        """Returns file contents."""
+        return self.read_bytes(length)
+
     def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
@@ -333,7 +356,7 @@ class File(DataModel):
 
     def export(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
@@ -374,15 +397,10 @@ class File(DataModel):
         client.download(self, callback=self._download_cb)
 
     async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
-        from datachain.client.hf import HfClient
-
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")
 
         client = self._catalog.get_client(self.source)
-        if client.protocol == HfClient.protocol:
-            return False
-
         await client._download(self, callback=download_cb or self._download_cb)
         self._set_stream(
             self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
@@ -430,7 +448,9 @@ class File(DataModel):
         path = url2pathname(path)
         return path
 
-    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+    def get_destination_path(
+        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+    ) -> str:
         """
         Returns full destination path of a file for exporting to some output
         based on export placement
@@ -551,18 +571,36 @@ class TextFile(File):
 class ImageFile(File):
     """`DataModel` for reading image files."""
 
+    def get_info(self) -> "Image":
+        """
+        Retrieves metadata and information about the image file.
+
+        Returns:
+            Image: A Model containing image metadata such as width, height and format.
+        """
+        from .image import image_info
+
+        return image_info(self)
+
     def read(self):
         """Returns `PIL.Image.Image` object."""
+        from PIL import Image as PilImage
+
         fobj = super().read()
         return PilImage.open(BytesIO(fobj))
 
-    def save(self, destination: str, client_config: Optional[dict] = None):
+    def save(  # type: ignore[override]
+        self,
+        destination: str,
+        format: Optional[str] = None,
+        client_config: Optional[dict] = None,
+    ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
 
         client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="wb") as f:
-            self.read().save(f)
+            self.read().save(f, format=format)
 
 
 class Image(DataModel):
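Taken together, these File changes mean any generic File row can be re-typed as a text, image, or video file, and byte reads accept an optional length. A minimal sketch of how the new helpers could be combined in a per-row function (the function name here is ours, not part of the library):

from datachain.lib.file import File

def image_meta(file: File):
    # New read_bytes(length) signature: peek at the first bytes only.
    header = file.read_bytes(16)
    # New as_image_file() helper: re-type the File, then read its metadata.
    info = file.as_image_file().get_info()
    return header, info.width, info.height, info.format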
datachain/lib/image.py
CHANGED
@@ -1,17 +1,41 @@
 from typing import Callable, Optional, Union
 
 import torch
-from PIL import Image
+from PIL import Image as PILImage
+
+from datachain.lib.file import File, FileError, Image, ImageFile
+
+
+def image_info(file: Union[File, ImageFile]) -> Image:
+    """
+    Returns image file information.
+
+    Args:
+        file (ImageFile): Image file object.
+
+    Returns:
+        Image: Image file information.
+    """
+    try:
+        img = file.as_image_file().read()
+    except Exception as exc:
+        raise FileError(file, "unable to open image file") from exc
+
+    return Image(
+        width=img.width,
+        height=img.height,
+        format=img.format or "",
+    )
 
 
 def convert_image(
-    img: Image.Image,
+    img: PILImage.Image,
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[Image.Image, torch.Tensor]:
+) -> Union[PILImage.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -47,13 +71,13 @@ def convert_image(
 
 
 def convert_images(
-    images: Union[Image.Image, list[Image.Image]],
+    images: Union[PILImage.Image, list[PILImage.Image]],
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[Image.Image], torch.Tensor]:
+) -> Union[list[PILImage.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.
 
@@ -65,7 +89,7 @@ def convert_images(
         encoder (Callable): Encode image using model.
         device (str or torch.device): Device to use.
     """
-    if isinstance(images, Image.Image):
+    if isinstance(images, PILImage.Image):
         images = [images]
 
     converted = [
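The new image_info() accepts either a plain File or an ImageFile and returns the Image model, raising FileError when the object cannot be opened as an image. A small sketch of defensive use (the helper name is ours):

from datachain.lib.file import File, FileError
from datachain.lib.image import image_info

def safe_dimensions(file: File) -> tuple[int, int]:
    # image_info() converts via File.as_image_file() internally;
    # FileError signals an unreadable or non-image object.
    try:
        info = image_info(file)
    except FileError:
        return (0, 0)
    return (info.width, info.height)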
datachain/lib/listing.py
CHANGED
@@ -1,19 +1,21 @@
+import glob
 import logging
 import os
 import posixpath
 from collections.abc import Iterator
-from
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true
 
+import datachain.fs.utils as fsutils
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import
+from datachain.error import ClientError
 from datachain.lib.file import File
 from datachain.query.schema import Column
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import uses_glob
 
 if TYPE_CHECKING:
@@ -92,38 +94,6 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
 
 
-def _isfile(client: "Client", path: str) -> bool:
-    """
-    Returns True if uri points to a file
-    """
-    try:
-        if "://" in path:
-            # This makes sure that the uppercase scheme is converted to lowercase
-            scheme, path = path.split("://", 1)
-            path = f"{scheme.lower()}://{path}"
-
-        if os.name == "nt" and "*" in path:
-            # On Windows, the glob pattern "*" is not supported
-            return False
-
-        info = client.fs.info(path)
-        name = info.get("name")
-        # case for special simulated directories on some clouds
-        # e.g. Google creates a zero byte file with the same name as the
-        # directory with a trailing slash at the end
-        if not name or name.endswith("/"):
-            return False
-
-        return info["type"] == "file"
-    except FileNotFoundError:
-        return False
-    except REMOTE_ERRORS as e:
-        raise ClientError(
-            message=str(e),
-            error_code=getattr(e, "code", None),
-        ) from e
-
-
 def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -156,8 +126,16 @@ def listing_uri_from_name(dataset_name: str) -> str:
     return dataset_name.removeprefix(LISTING_PREFIX)
 
 
+@contextmanager
+def _reraise_as_client_error() -> Iterator[None]:
+    try:
+        yield
+    except Exception as e:
+        raise ClientError(message=str(e), error_code=getattr(e, "code", None)) from e
+
+
 def get_listing(
-    uri: str, session: "Session", update: bool = False
+    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
 ) -> tuple[Optional[str], str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
@@ -167,6 +145,7 @@ def get_listing(
     be used to find rows based on uri.
     """
     from datachain.client.local import FileClient
+    from datachain.telemetry import telemetry
 
     catalog = session.catalog
     cache = catalog.cache
@@ -174,11 +153,14 @@ def get_listing(
 
     client = Client.get_client(uri, cache, **client_config)
     telemetry.log_param("client", client.PREFIX)
+    if not isinstance(uri, str):
+        uri = os.fspath(uri)
 
     # we don't want to use cached dataset (e.g. for a single file listing)
-
-
-
+    isfile = _reraise_as_client_error()(fsutils.isfile)
+    if not glob.has_magic(uri) and not uri.endswith("/") and isfile(client.fs, uri):
+        _, path = Client.parse_url(uri)
+        return None, uri, path, False
 
     ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
     listing = None
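The replacement for the removed _isfile() leans on the fact that objects produced by contextlib.contextmanager can also wrap a callable, so _reraise_as_client_error()(fsutils.isfile) yields an isfile that converts any failure into ClientError. A self-contained sketch of the same pattern, with an illustrative error type and function rather than datachain's:

from collections.abc import Iterator
from contextlib import contextmanager

class WrappedError(Exception):
    """Stand-in for an error type like ClientError."""

@contextmanager
def reraise() -> Iterator[None]:
    try:
        yield
    except Exception as e:
        raise WrappedError(str(e)) from e

def flaky(x: int) -> int:
    if x < 0:
        raise ValueError("negative input")
    return x * 2

safe_flaky = reraise()(flaky)  # context manager instance used as a function wrapper
assert safe_flaky(3) == 6
# safe_flaky(-1) raises WrappedError instead of ValueError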
datachain/lib/video.py
CHANGED
@@ -1,11 +1,11 @@
 import posixpath
 import shutil
 import tempfile
-from typing import Optional
+from typing import Optional, Union
 
 from numpy import ndarray
 
-from datachain.lib.file import FileError, ImageFile, Video, VideoFile
+from datachain.lib.file import File, FileError, ImageFile, Video, VideoFile
 
 try:
     import ffmpeg
@@ -18,7 +18,7 @@ except ImportError as exc:
     ) from exc
 
 
-def video_info(file: VideoFile) -> Video:
+def video_info(file: Union[File, VideoFile]) -> Video:
     """
     Returns video file information.
 
@@ -28,6 +28,8 @@ def video_info(file: VideoFile) -> Video:
     Returns:
         Video: Video file information.
     """
+    file = file.as_video_file()
+
     if not (file_path := file.get_local_path()):
         file.ensure_cached()
         file_path = file.get_local_path()
@@ -170,7 +172,7 @@ def save_video_frame(
     output_file = posixpath.join(
         output, f"{video.get_file_stem()}_{frame:04d}.{format}"
     )
-    return ImageFile.upload(img, output_file)
+    return ImageFile.upload(img, output_file, catalog=video._catalog)
 
 
 def save_video_fragment(
@@ -218,6 +220,6 @@ def save_video_fragment(
         ).output(output_file_tmp).run(quiet=True)
 
         with open(output_file_tmp, "rb") as f:
-            return VideoFile.upload(f.read(), output_file)
+            return VideoFile.upload(f.read(), output_file, catalog=video._catalog)
     finally:
         shutil.rmtree(temp_dir)
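video_info() now coerces its argument with as_video_file(), so a generic File row can be probed directly, and the frame/fragment helpers upload through the source video's catalog. A minimal sketch (the wrapper name is ours):

from datachain.lib.file import File
from datachain.lib.video import video_info

def probe(file: File):
    # Previously a VideoFile was required; a plain File is now converted internally.
    return video_info(file)  # -> Video model describing the stream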
datachain/model/bbox.py
CHANGED
@@ -1,47 +1,216 @@
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Union
+
 from pydantic import Field
 
 from datachain.lib.data_model import DataModel
 
+from .utils import convert_bbox, validate_bbox
+
+if TYPE_CHECKING:
+    from .pose import Pose, Pose3D
+
 
 class BBox(DataModel):
     """
-    A data model
+    A data model representing a bounding box.
 
     Attributes:
-        title (str): The title
-        coords (list[int]):
+        title (str): The title or label associated with the bounding box.
+        coords (list[int]): A list of four bounding box coordinates.
 
-    The bounding box
-    - (x1, y1)
-    - (x2, y2)
+    The bounding box follows the PASCAL VOC format, where:
+    - (x1, y1) represents the pixel coordinates of the top-left corner.
+    - (x2, y2) represents the pixel coordinates of the bottom-right corner.
     """
 
     title: str = Field(default="")
     coords: list[int] = Field(default=[])
 
     @staticmethod
-    def
-
-
-
-
-
-
-
+    def from_albumentations(
+        coords: Sequence[float],
+        img_size: Sequence[int],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from Albumentations format.
+
+        Albumentations represents bounding boxes as `[x_min, y_min, x_max, y_max]`
+        with normalized coordinates (values between 0 and 1) relative to the image size.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in
+                Albumentations format.
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+            title (str, optional): The title or label of the bounding box.
+                Defaults to an empty string.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float)
+        bbox_coords = convert_bbox(coords, img_size, "albumentations", "voc")
+        return BBox(title=title, coords=list(map(round, bbox_coords)))
+
+    def to_albumentations(self, img_size: Sequence[int]) -> list[float]:
+        """
+        Convert the bounding box coordinates to Albumentations format.
+
+        Albumentations represents bounding boxes as `[x_min, y_min, x_max, y_max]`
+        with normalized coordinates (values between 0 and 1) relative to the image size.
+
+        Args:
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+
+        Returns:
+            list[float]: The bounding box coordinates in Albumentations format.
+        """
+        return convert_bbox(self.coords, img_size, "voc", "albumentations")
+
+    @staticmethod
+    def from_coco(
+        coords: Sequence[float],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from COCO format.
+
+        COCO format represents bounding boxes as [x_min, y_min, width, height], where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - width and height define the size of the bounding box in pixels.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in COCO format.
+            title (str): The title of the bounding box.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float, int)
+        bbox_coords = convert_bbox(coords, [], "coco", "voc")
+        return BBox(title=title, coords=list(map(round, bbox_coords)))
+
+    def to_coco(self) -> list[int]:
+        """
+        Return the bounding box coordinates in COCO format.
+
+        COCO format represents bounding boxes as [x_min, y_min, width, height], where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - width and height define the size of the bounding box in pixels.
+
+        Returns:
+            list[int]: The bounding box coordinates in COCO format.
+        """
+        res = convert_bbox(self.coords, [], "voc", "coco")
+        return list(map(round, res))
+
+    @staticmethod
+    def from_voc(
+        coords: Sequence[float],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from PASCAL VOC format.
+
+        PASCAL VOC format represents bounding boxes as [x_min, y_min, x_max, y_max],
+        where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - (x_max, y_max) are the pixel coordinates of the bottom-right corner.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in VOC format.
+            title (str): The title of the bounding box.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float, int)
+        return BBox(title=title, coords=list(map(round, coords)))
+
+    def to_voc(self) -> list[int]:
+        """
+        Return the bounding box coordinates in PASCAL VOC format.
+
+        PASCAL VOC format represents bounding boxes as [x_min, y_min, x_max, y_max],
+        where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - (x_max, y_max) are the pixel coordinates of the bottom-right corner.
+
+        Returns:
+            list[int]: The bounding box coordinates in VOC format.
+        """
+        return self.coords
+
+    @staticmethod
+    def from_yolo(
+        coords: Sequence[float],
+        img_size: Sequence[int],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from YOLO format.
+
+        YOLO format represents bounding boxes as [x_center, y_center, width, height],
+        where:
+        - (x_center, y_center) are the normalized coordinates of the box center.
+        - width and height normalized values define the size of the bounding box.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in YOLO format.
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+            title (str): The title of the bounding box.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float)
+        bbox_coords = convert_bbox(coords, img_size, "yolo", "voc")
+        return BBox(title=title, coords=list(map(round, bbox_coords)))
+
+    def to_yolo(self, img_size: Sequence[int]) -> list[float]:
+        """
+        Return the bounding box coordinates in YOLO format.
+
+        YOLO format represents bounding boxes as [x_center, y_center, width, height],
+        where:
+        - (x_center, y_center) are the normalized coordinates of the box center.
+        - width and height normalized values define the size of the bounding box.
+
+        Args:
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+
+        Returns:
+            list[float]: The bounding box coordinates in YOLO format.
+        """
+        return convert_bbox(self.coords, img_size, "voc", "yolo")
+
+    def point_inside(self, x: int, y: int) -> bool:
+        """
+        Return True if the point is inside the bounding box.
+
+        Assumes that if the point is on the edge of the bounding box,
+        it is considered inside.
+        """
+        x1, y1, x2, y2 = self.coords
+        return x1 <= x <= x2 and y1 <= y <= y2
+
+    def pose_inside(self, pose: Union["Pose", "Pose3D"]) -> bool:
+        """Return True if the pose is inside the bounding box."""
+        return all(
+            self.point_inside(x, y) for x, y in zip(pose.x, pose.y) if x > 0 or y > 0
         )
 
+    @staticmethod
+    def from_list(coords: Sequence[float], title: str = "") -> "BBox":
+        return BBox.from_voc(coords, title=title)
+
     @staticmethod
     def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
-        assert set(coords) == {
-            "x1",
-            "y1",
-            "x2",
-            "y2",
-        }, "Bounding box must be a dictionary with keys 'x1', 'y1', 'x2' and 'y2'."
-        return BBox.from_list(
-            [coords["x1"], coords["y1"], coords["x2"], coords["y2"]],
-            title=title,
-        )
+        keys = ("x1", "y1", "x2", "y2")
+        if not isinstance(coords, dict) or set(coords) != set(keys):
+            raise ValueError("Bounding box must be a dictionary with coordinates.")
+        return BBox.from_voc([coords[k] for k in keys], title=title)
 
 
 class OBBox(DataModel):
@@ -63,40 +232,22 @@ class OBBox(DataModel):
     coords: list[int] = Field(default=[])
 
     @staticmethod
-    def from_list(coords:
-
-        "Oriented bounding box must be a list of
-        )
-
-
-
-
-
-
-        )
+    def from_list(coords: Sequence[float], title: str = "") -> "OBBox":
+        if not isinstance(coords, (list, tuple)):
+            raise TypeError("Oriented bounding box must be a list of coordinates.")
+        if len(coords) != 8:
+            raise ValueError("Oriented bounding box must have 8 coordinates.")
+        if not all(isinstance(value, (int, float)) for value in coords):
+            raise ValueError(
+                "Oriented bounding box coordinates must be floats or integers."
+            )
+        return OBBox(title=title, coords=list(map(round, coords)))
 
     @staticmethod
     def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
-        assert set(coords) == {
-            "x1",
-            "y1",
-            "x2",
-            "y2",
-            "x3",
-            "y3",
-            "x4",
-            "y4",
-        }, "Oriented bounding box must be a dictionary with coordinates."
-        return OBBox.from_list(
-            [
-                coords["x1"],
-                coords["y1"],
-                coords["x2"],
-                coords["y2"],
-                coords["x3"],
-                coords["y3"],
-                coords["x4"],
-                coords["y4"],
-            ],
-            title=title,
-        )
+        keys = ("x1", "y1", "x2", "y2", "x3", "y3", "x4", "y4")
+        if not isinstance(coords, dict) or set(coords) != set(keys):
+            raise ValueError(
+                "Oriented bounding box must be a dictionary with coordinates."
+            )
+        return OBBox.from_list([coords[k] for k in keys], title=title)