datachain 0.11.11__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their public registry. It is provided for informational purposes only.

Note: this release of datachain has been flagged as potentially problematic.

datachain/lib/file.py CHANGED
@@ -18,7 +18,6 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.utils import stringify_path
-from PIL import Image as PilImage
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice
@@ -52,7 +51,7 @@ class FileExporter(NodesThreadPool):
 
     def __init__(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement,
         use_cache: bool,
         link_type: Literal["copy", "symlink"],
@@ -243,6 +242,30 @@ class File(DataModel):
         self._catalog = None
         self._caching_enabled: bool = False
 
+    def as_text_file(self) -> "TextFile":
+        """Convert the file to a `TextFile` object."""
+        if isinstance(self, TextFile):
+            return self
+        file = TextFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_image_file(self) -> "ImageFile":
+        """Convert the file to a `ImageFile` object."""
+        if isinstance(self, ImageFile):
+            return self
+        file = ImageFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
+    def as_video_file(self) -> "VideoFile":
+        """Convert the file to a `VideoFile` object."""
+        if isinstance(self, VideoFile):
+            return self
+        file = VideoFile(**self.model_dump())
+        file._set_stream(self._catalog, caching_enabled=self._caching_enabled)
+        return file
+
     @classmethod
     def upload(
         cls, data: bytes, path: str, catalog: Optional["Catalog"] = None
@@ -292,20 +315,20 @@ class File(DataModel):
         ) as f:
             yield io.TextIOWrapper(f) if mode == "r" else f
 
-    def read(self, length: int = -1):
-        """Returns file contents."""
+    def read_bytes(self, length: int = -1):
+        """Returns file contents as bytes."""
         with self.open() as stream:
             return stream.read(length)
 
-    def read_bytes(self):
-        """Returns file contents as bytes."""
-        return self.read()
-
     def read_text(self):
         """Returns file contents as text."""
         with self.open(mode="r") as stream:
             return stream.read()
 
+    def read(self, length: int = -1):
+        """Returns file contents."""
+        return self.read_bytes(length)
+
     def save(self, destination: str, client_config: Optional[dict] = None):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
@@ -333,7 +356,7 @@ class File(DataModel):
 
     def export(
         self,
-        output: str,
+        output: Union[str, os.PathLike[str]],
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
         link_type: Literal["copy", "symlink"] = "copy",
@@ -374,15 +397,10 @@ class File(DataModel):
         client.download(self, callback=self._download_cb)
 
     async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
-        from datachain.client.hf import HfClient
-
         if self._catalog is None:
             raise RuntimeError("cannot prefetch file because catalog is not setup")
 
         client = self._catalog.get_client(self.source)
-        if client.protocol == HfClient.protocol:
-            return False
-
         await client._download(self, callback=download_cb or self._download_cb)
         self._set_stream(
             self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
@@ -430,7 +448,9 @@ class File(DataModel):
             path = url2pathname(path)
         return path
 
-    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+    def get_destination_path(
+        self, output: Union[str, os.PathLike[str]], placement: ExportPlacement
+    ) -> str:
         """
         Returns full destination path of a file for exporting to some output
         based on export placement
@@ -551,18 +571,36 @@ class TextFile(File):
 class ImageFile(File):
     """`DataModel` for reading image files."""
 
+    def get_info(self) -> "Image":
+        """
+        Retrieves metadata and information about the image file.
+
+        Returns:
+            Image: A Model containing image metadata such as width, height and format.
+        """
+        from .image import image_info
+
+        return image_info(self)
+
     def read(self):
         """Returns `PIL.Image.Image` object."""
+        from PIL import Image as PilImage
+
         fobj = super().read()
         return PilImage.open(BytesIO(fobj))
 
-    def save(self, destination: str, client_config: Optional[dict] = None):
+    def save(  # type: ignore[override]
+        self,
+        destination: str,
+        format: Optional[str] = None,
+        client_config: Optional[dict] = None,
+    ):
         """Writes it's content to destination"""
         destination = stringify_path(destination)
 
         client: Client = self._catalog.get_client(destination, **(client_config or {}))
         with client.fs.open(destination, mode="wb") as f:
-            self.read().save(f)
+            self.read().save(f, format=format)
 
 
 class Image(DataModel):
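The new as_text_file(), as_image_file(), and as_video_file() helpers rebuild a typed copy of a generic File and re-attach its catalog stream, while read_bytes() becomes the primary byte reader and read() remains as an alias. A minimal usage sketch, assuming a File object whose stream is already set up; the helper function names below are illustrative, not part of datachain:

from datachain.lib.file import File

def preview(file: File) -> str:
    # Promote the generic File to an ImageFile; if it already is one,
    # the same object is returned, otherwise the fields are copied and
    # the catalog stream is re-attached.
    image = file.as_image_file()
    # read() on ImageFile returns a PIL.Image.Image (PIL is imported lazily).
    img = image.read()
    return f"{file.path}: {img.size}"

def first_kilobyte(file: File) -> bytes:
    # read_bytes(length) is now the canonical byte reader;
    # read(length) simply delegates to it.
    return file.read_bytes(1024)
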
datachain/lib/image.py CHANGED
@@ -1,17 +1,41 @@
 from typing import Callable, Optional, Union
 
 import torch
-from PIL import Image
+from PIL import Image as PILImage
+
+from datachain.lib.file import File, FileError, Image, ImageFile
+
+
+def image_info(file: Union[File, ImageFile]) -> Image:
+    """
+    Returns image file information.
+
+    Args:
+        file (ImageFile): Image file object.
+
+    Returns:
+        Image: Image file information.
+    """
+    try:
+        img = file.as_image_file().read()
+    except Exception as exc:
+        raise FileError(file, "unable to open image file") from exc
+
+    return Image(
+        width=img.width,
+        height=img.height,
+        format=img.format or "",
+    )
 
 
 def convert_image(
-    img: Image.Image,
+    img: PILImage.Image,
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[Image.Image, torch.Tensor]:
+) -> Union[PILImage.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -47,13 +71,13 @@ def convert_image(
 
 
 def convert_images(
-    images: Union[Image.Image, list[Image.Image]],
+    images: Union[PILImage.Image, list[PILImage.Image]],
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
     encoder: Optional[Callable] = None,
     device: Optional[Union[str, torch.device]] = None,
-) -> Union[list[Image.Image], torch.Tensor]:
+) -> Union[list[PILImage.Image], torch.Tensor]:
     """
     Resize, transform, and otherwise convert one or more images.
 
@@ -65,7 +89,7 @@ def convert_images(
         encoder (Callable): Encode image using model.
         device (str or torch.device): Device to use.
     """
-    if isinstance(images, Image.Image):
+    if isinstance(images, PILImage.Image):
        images = [images]
 
    converted = [
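The new image_info() helper opens the file via as_image_file().read() and packs width, height, and format into the Image model; ImageFile.get_info() in file.py is a thin wrapper around it. A short hedged sketch of calling it directly (the function name describe is illustrative):

from datachain.lib.file import ImageFile

def describe(file: ImageFile) -> str:
    # get_info() delegates to datachain.lib.image.image_info and returns
    # an Image model with width, height and format fields.
    info = file.get_info()
    return f"{file.path}: {info.width}x{info.height} {info.format}"
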
datachain/lib/listing.py CHANGED
@@ -1,19 +1,21 @@
+import glob
 import logging
 import os
 import posixpath
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true
 
+import datachain.fs.utils as fsutils
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import REMOTE_ERRORS, ClientError
+from datachain.error import ClientError
 from datachain.lib.file import File
 from datachain.query.schema import Column
 from datachain.sql.functions import path as pathfunc
-from datachain.telemetry import telemetry
 from datachain.utils import uses_glob
 
 if TYPE_CHECKING:
@@ -92,38 +94,6 @@ def ls(
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
 
 
-def _isfile(client: "Client", path: str) -> bool:
-    """
-    Returns True if uri points to a file
-    """
-    try:
-        if "://" in path:
-            # This makes sure that the uppercase scheme is converted to lowercase
-            scheme, path = path.split("://", 1)
-            path = f"{scheme.lower()}://{path}"
-
-        if os.name == "nt" and "*" in path:
-            # On Windows, the glob pattern "*" is not supported
-            return False
-
-        info = client.fs.info(path)
-        name = info.get("name")
-        # case for special simulated directories on some clouds
-        # e.g. Google creates a zero byte file with the same name as the
-        # directory with a trailing slash at the end
-        if not name or name.endswith("/"):
-            return False
-
-        return info["type"] == "file"
-    except FileNotFoundError:
-        return False
-    except REMOTE_ERRORS as e:
-        raise ClientError(
-            message=str(e),
-            error_code=getattr(e, "code", None),
-        ) from e
-
-
 def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
@@ -156,8 +126,16 @@ def listing_uri_from_name(dataset_name: str) -> str:
     return dataset_name.removeprefix(LISTING_PREFIX)
 
 
+@contextmanager
+def _reraise_as_client_error() -> Iterator[None]:
+    try:
+        yield
+    except Exception as e:
+        raise ClientError(message=str(e), error_code=getattr(e, "code", None)) from e
+
+
 def get_listing(
-    uri: str, session: "Session", update: bool = False
+    uri: Union[str, os.PathLike[str]], session: "Session", update: bool = False
 ) -> tuple[Optional[str], str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
@@ -167,6 +145,7 @@ def get_listing(
     be used to find rows based on uri.
     """
     from datachain.client.local import FileClient
+    from datachain.telemetry import telemetry
 
     catalog = session.catalog
     cache = catalog.cache
@@ -174,11 +153,14 @@ def get_listing(
 
     client = Client.get_client(uri, cache, **client_config)
     telemetry.log_param("client", client.PREFIX)
+    if not isinstance(uri, str):
+        uri = os.fspath(uri)
 
     # we don't want to use cached dataset (e.g. for a single file listing)
-    if not uri.endswith("/") and _isfile(client, uri):
-        storage_uri, path = Client.parse_url(uri)
-        return None, f"{storage_uri}/{path.lstrip('/')}", path, False
+    isfile = _reraise_as_client_error()(fsutils.isfile)
+    if not glob.has_magic(uri) and not uri.endswith("/") and isfile(client.fs, uri):
+        _, path = Client.parse_url(uri)
+        return None, uri, path, False
 
     ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
     listing = None
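get_listing() now wraps fsutils.isfile with _reraise_as_client_error() used as a decorator: functions produced by contextlib.contextmanager return instances that inherit from ContextDecorator, so calling the instance on a function wraps that function's body in the same try/except. A standalone sketch of the pattern under that assumption; the names below are illustrative, not datachain's:

from collections.abc import Iterator
from contextlib import contextmanager

class ClientError(Exception):
    pass

@contextmanager
def reraise_as_client_error() -> Iterator[None]:
    # Generator-based context managers double as decorators:
    # reraise_as_client_error()(func) returns func wrapped in this block.
    try:
        yield
    except Exception as e:
        raise ClientError(str(e)) from e

def flaky_isfile(path: str) -> bool:
    raise OSError("backend unavailable")

wrapped = reraise_as_client_error()(flaky_isfile)
# wrapped("s3://bucket/key") now raises ClientError instead of OSError.
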
datachain/lib/video.py CHANGED
@@ -1,11 +1,11 @@
 import posixpath
 import shutil
 import tempfile
-from typing import Optional
+from typing import Optional, Union
 
 from numpy import ndarray
 
-from datachain.lib.file import FileError, ImageFile, Video, VideoFile
+from datachain.lib.file import File, FileError, ImageFile, Video, VideoFile
 
 try:
     import ffmpeg
@@ -18,7 +18,7 @@ except ImportError as exc:
     ) from exc
 
 
-def video_info(file: VideoFile) -> Video:
+def video_info(file: Union[File, VideoFile]) -> Video:
     """
     Returns video file information.
 
@@ -28,6 +28,8 @@ def video_info(file: VideoFile) -> Video:
     Returns:
         Video: Video file information.
     """
+    file = file.as_video_file()
+
     if not (file_path := file.get_local_path()):
         file.ensure_cached()
         file_path = file.get_local_path()
@@ -170,7 +172,7 @@ def save_video_frame(
     output_file = posixpath.join(
         output, f"{video.get_file_stem()}_{frame:04d}.{format}"
     )
-    return ImageFile.upload(img, output_file)
+    return ImageFile.upload(img, output_file, catalog=video._catalog)
 
 
 def save_video_fragment(
@@ -218,6 +220,6 @@ def save_video_fragment(
         ).output(output_file_tmp).run(quiet=True)
 
         with open(output_file_tmp, "rb") as f:
-            return VideoFile.upload(f.read(), output_file)
+            return VideoFile.upload(f.read(), output_file, catalog=video._catalog)
     finally:
         shutil.rmtree(temp_dir)
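video_info() now coerces its argument with as_video_file(), so a plain File row can be passed directly, and the frame/fragment savers propagate the source file's catalog to the uploaded outputs. A brief usage sketch, assuming the Video model exposes fields such as width, height, and duration (those field names are not shown in this diff):

from datachain.lib.file import File
from datachain.lib.video import video_info

def probe(file: File) -> tuple[int, int, float]:
    # A generic File is promoted internally via file.as_video_file(),
    # so callers no longer need to construct a VideoFile themselves.
    meta = video_info(file)
    return meta.width, meta.height, meta.duration
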
datachain/model/bbox.py CHANGED
@@ -1,47 +1,216 @@
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Union
+
 from pydantic import Field
 
 from datachain.lib.data_model import DataModel
 
+from .utils import convert_bbox, validate_bbox
+
+if TYPE_CHECKING:
+    from .pose import Pose, Pose3D
+
 
 class BBox(DataModel):
     """
-    A data model for representing bounding box.
+    A data model representing a bounding box.
 
     Attributes:
-        title (str): The title of the bounding box.
-        coords (list[int]): The coordinates of the bounding box.
+        title (str): The title or label associated with the bounding box.
+        coords (list[int]): A list of four bounding box coordinates.
 
-    The bounding box is defined by two points:
-        - (x1, y1): The top-left corner of the box.
-        - (x2, y2): The bottom-right corner of the box.
+    The bounding box follows the PASCAL VOC format, where:
+        - (x1, y1) represents the pixel coordinates of the top-left corner.
+        - (x2, y2) represents the pixel coordinates of the bottom-right corner.
     """
 
     title: str = Field(default="")
     coords: list[int] = Field(default=[])
 
     @staticmethod
-    def from_list(coords: list[float], title: str = "") -> "BBox":
-        assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
-        assert all(isinstance(value, (int, float)) for value in coords), (
-            "Bounding box coordinates must be floats or integers."
-        )
-        return BBox(
-            title=title,
-            coords=[round(c) for c in coords],
+    def from_albumentations(
+        coords: Sequence[float],
+        img_size: Sequence[int],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from Albumentations format.
+
+        Albumentations represents bounding boxes as `[x_min, y_min, x_max, y_max]`
+        with normalized coordinates (values between 0 and 1) relative to the image size.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in
+                Albumentations format.
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+            title (str, optional): The title or label of the bounding box.
+                Defaults to an empty string.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float)
+        bbox_coords = convert_bbox(coords, img_size, "albumentations", "voc")
+        return BBox(title=title, coords=list(map(round, bbox_coords)))
+
+    def to_albumentations(self, img_size: Sequence[int]) -> list[float]:
+        """
+        Convert the bounding box coordinates to Albumentations format.
+
+        Albumentations represents bounding boxes as `[x_min, y_min, x_max, y_max]`
+        with normalized coordinates (values between 0 and 1) relative to the image size.
+
+        Args:
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+
+        Returns:
+            list[float]: The bounding box coordinates in Albumentations format.
+        """
+        return convert_bbox(self.coords, img_size, "voc", "albumentations")
+
+    @staticmethod
+    def from_coco(
+        coords: Sequence[float],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from COCO format.
+
+        COCO format represents bounding boxes as [x_min, y_min, width, height], where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - width and height define the size of the bounding box in pixels.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in COCO format.
+            title (str): The title of the bounding box.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float, int)
+        bbox_coords = convert_bbox(coords, [], "coco", "voc")
+        return BBox(title=title, coords=list(map(round, bbox_coords)))
+
+    def to_coco(self) -> list[int]:
+        """
+        Return the bounding box coordinates in COCO format.
+
+        COCO format represents bounding boxes as [x_min, y_min, width, height], where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - width and height define the size of the bounding box in pixels.
+
+        Returns:
+            list[int]: The bounding box coordinates in COCO format.
+        """
+        res = convert_bbox(self.coords, [], "voc", "coco")
+        return list(map(round, res))
+
+    @staticmethod
+    def from_voc(
+        coords: Sequence[float],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from PASCAL VOC format.
+
+        PASCAL VOC format represents bounding boxes as [x_min, y_min, x_max, y_max],
+        where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - (x_max, y_max) are the pixel coordinates of the bottom-right corner.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in VOC format.
+            title (str): The title of the bounding box.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float, int)
+        return BBox(title=title, coords=list(map(round, coords)))
+
+    def to_voc(self) -> list[int]:
+        """
+        Return the bounding box coordinates in PASCAL VOC format.
+
+        PASCAL VOC format represents bounding boxes as [x_min, y_min, x_max, y_max],
+        where:
+        - (x_min, y_min) are the pixel coordinates of the top-left corner.
+        - (x_max, y_max) are the pixel coordinates of the bottom-right corner.
+
+        Returns:
+            list[int]: The bounding box coordinates in VOC format.
+        """
+        return self.coords
+
+    @staticmethod
+    def from_yolo(
+        coords: Sequence[float],
+        img_size: Sequence[int],
+        title: str = "",
+    ) -> "BBox":
+        """
+        Create a bounding box from YOLO format.
+
+        YOLO format represents bounding boxes as [x_center, y_center, width, height],
+        where:
+        - (x_center, y_center) are the normalized coordinates of the box center.
+        - width and height normalized values define the size of the bounding box.
+
+        Args:
+            coords (Sequence[float]): The bounding box coordinates in YOLO format.
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+            title (str): The title of the bounding box.
+
+        Returns:
+            BBox: The bounding box data model.
+        """
+        validate_bbox(coords, float)
+        bbox_coords = convert_bbox(coords, img_size, "yolo", "voc")
+        return BBox(title=title, coords=list(map(round, bbox_coords)))
+
+    def to_yolo(self, img_size: Sequence[int]) -> list[float]:
+        """
+        Return the bounding box coordinates in YOLO format.
+
+        YOLO format represents bounding boxes as [x_center, y_center, width, height],
+        where:
+        - (x_center, y_center) are the normalized coordinates of the box center.
+        - width and height normalized values define the size of the bounding box.
+
+        Args:
+            img_size (Sequence[int]): The reference image size as `[width, height]`.
+
+        Returns:
+            list[float]: The bounding box coordinates in YOLO format.
+        """
+        return convert_bbox(self.coords, img_size, "voc", "yolo")
+
+    def point_inside(self, x: int, y: int) -> bool:
+        """
+        Return True if the point is inside the bounding box.
+
+        Assumes that if the point is on the edge of the bounding box,
+        it is considered inside.
+        """
+        x1, y1, x2, y2 = self.coords
+        return x1 <= x <= x2 and y1 <= y <= y2
+
+    def pose_inside(self, pose: Union["Pose", "Pose3D"]) -> bool:
+        """Return True if the pose is inside the bounding box."""
+        return all(
+            self.point_inside(x, y) for x, y in zip(pose.x, pose.y) if x > 0 or y > 0
         )
 
+    @staticmethod
+    def from_list(coords: Sequence[float], title: str = "") -> "BBox":
+        return BBox.from_voc(coords, title=title)
+
     @staticmethod
     def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
-        assert isinstance(coords, dict) and set(coords) == {
-            "x1",
-            "y1",
-            "x2",
-            "y2",
-        }, "Bounding box must be a dictionary with keys 'x1', 'y1', 'x2' and 'y2'."
-        return BBox.from_list(
-            [coords["x1"], coords["y1"], coords["x2"], coords["y2"]],
-            title=title,
-        )
+        keys = ("x1", "y1", "x2", "y2")
+        if not isinstance(coords, dict) or set(coords) != set(keys):
+            raise ValueError("Bounding box must be a dictionary with coordinates.")
+        return BBox.from_voc([coords[k] for k in keys], title=title)
 
 
 class OBBox(DataModel):
@@ -63,40 +232,22 @@ class OBBox(DataModel):
     coords: list[int] = Field(default=[])
 
     @staticmethod
-    def from_list(coords: list[float], title: str = "") -> "OBBox":
-        assert len(coords) == 8, (
-            "Oriented bounding box must be a list of 8 coordinates."
-        )
-        assert all(isinstance(value, (int, float)) for value in coords), (
-            "Oriented bounding box coordinates must be floats or integers."
-        )
-        return OBBox(
-            title=title,
-            coords=[round(c) for c in coords],
-        )
+    def from_list(coords: Sequence[float], title: str = "") -> "OBBox":
+        if not isinstance(coords, (list, tuple)):
+            raise TypeError("Oriented bounding box must be a list of coordinates.")
+        if len(coords) != 8:
+            raise ValueError("Oriented bounding box must have 8 coordinates.")
+        if not all(isinstance(value, (int, float)) for value in coords):
+            raise ValueError(
+                "Oriented bounding box coordinates must be floats or integers."
+            )
+        return OBBox(title=title, coords=list(map(round, coords)))
 
     @staticmethod
     def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
-        assert isinstance(coords, dict) and set(coords) == {
-            "x1",
-            "y1",
-            "x2",
-            "y2",
-            "x3",
-            "y3",
-            "x4",
-            "y4",
-        }, "Oriented bounding box must be a dictionary with coordinates."
-        return OBBox.from_list(
-            [
-                coords["x1"],
-                coords["y1"],
-                coords["x2"],
-                coords["y2"],
-                coords["x3"],
-                coords["y3"],
-                coords["x4"],
-                coords["y4"],
-            ],
-            title=title,
-        )
+        keys = ("x1", "y1", "x2", "y2", "x3", "y3", "x4", "y4")
+        if not isinstance(coords, dict) or set(coords) != set(keys):
+            raise ValueError(
+                "Oriented bounding box must be a dictionary with coordinates."
+            )
+        return OBBox.from_list([coords[k] for k in keys], title=title)
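The rewritten BBox stores PASCAL VOC pixel coordinates internally and converts to and from COCO, YOLO, and Albumentations layouts through convert_bbox. A round-trip sketch, assuming convert_bbox follows the standard definitions of those formats (the coordinate values and image size below are made up):

from datachain.model.bbox import BBox

# VOC: [x_min, y_min, x_max, y_max] in pixels.
box = BBox.from_voc([10, 20, 110, 220], title="cat")

# COCO: [x_min, y_min, width, height] in pixels.
assert box.to_coco() == [10, 20, 100, 200]

# YOLO: normalized [x_center, y_center, width, height] relative to (width, height).
yolo = box.to_yolo((640, 480))

# Albumentations: normalized [x_min, y_min, x_max, y_max].
alb = box.to_albumentations((640, 480))

# Points on the edge count as inside.
assert box.point_inside(10, 20)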