datachain 0.6.8__py3-none-any.whl → 0.6.10__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

@@ -603,9 +603,10 @@ class Catalog:
603
603
  )
604
604
 
605
605
  lst = Listing(
606
+ self.metastore.clone(),
606
607
  self.warehouse.clone(),
607
608
  Client.get_client(list_uri, self.cache, **self.client_config),
608
- self.get_dataset(list_ds_name),
609
+ dataset_name=list_ds_name,
609
610
  object_name=object_name,
610
611
  )
611
612
 
@@ -698,9 +699,13 @@ class Catalog:
698
699
 
699
700
  client = self.get_client(source, **client_config)
700
701
  uri = client.uri
701
- st = self.warehouse.clone()
702
702
  dataset_name, _, _, _ = DataChain.parse_uri(uri, self.session)
703
- listing = Listing(st, client, self.get_dataset(dataset_name))
703
+ listing = Listing(
704
+ self.metastore.clone(),
705
+ self.warehouse.clone(),
706
+ client,
707
+ dataset_name=dataset_name,
708
+ )
704
709
  rows = DatasetQuery(
705
710
  name=dataset.name, version=ds_version, catalog=self
706
711
  ).to_db_records()
@@ -769,6 +774,7 @@ class Catalog:
769
774
  create_rows: Optional[bool] = True,
770
775
  validate_version: Optional[bool] = True,
771
776
  listing: Optional[bool] = False,
777
+ uuid: Optional[str] = None,
772
778
  ) -> "DatasetRecord":
773
779
  """
774
780
  Creates new dataset of a specific version.
@@ -816,6 +822,7 @@ class Catalog:
816
822
  query_script=query_script,
817
823
  create_rows_table=create_rows,
818
824
  columns=columns,
825
+ uuid=uuid,
819
826
  )
820
827
 
821
828
  def create_new_dataset_version(
@@ -832,6 +839,7 @@ class Catalog:
832
839
  script_output="",
833
840
  create_rows_table=True,
834
841
  job_id: Optional[str] = None,
842
+ uuid: Optional[str] = None,
835
843
  ) -> DatasetRecord:
836
844
  """
837
845
  Creates dataset version if it doesn't exist.
@@ -855,6 +863,7 @@ class Catalog:
855
863
  schema=schema,
856
864
  job_id=job_id,
857
865
  ignore_if_exists=True,
866
+ uuid=uuid,
858
867
  )
859
868
 
860
869
  if create_rows_table:
@@ -1350,6 +1359,13 @@ class Catalog:
1350
1359
  # we will create new one if it doesn't exist
1351
1360
  pass
1352
1361
 
1362
+ if dataset and version and dataset.has_version(version):
1363
+ """No need to communicate with Studio at all"""
1364
+ dataset_uri = create_dataset_uri(remote_dataset_name, version)
1365
+ print(f"Local copy of dataset {dataset_uri} already present")
1366
+ _instantiate_dataset()
1367
+ return
1368
+
1353
1369
  remote_dataset = self.get_remote_dataset(remote_dataset_name)
1354
1370
  # if version is not specified in uri, take the latest one
1355
1371
  if not version:
@@ -1400,6 +1416,7 @@ class Catalog:
1400
1416
  columns=columns,
1401
1417
  feature_schema=remote_dataset_version.feature_schema,
1402
1418
  validate_version=False,
1419
+ uuid=remote_dataset_version.uuid,
1403
1420
  )
1404
1421
 
1405
1422
  # asking remote to export dataset rows table to s3 and to return signed
@@ -358,7 +358,7 @@ class Client(ABC):
358
358
  ) -> BinaryIO:
359
359
  """Open a file, including files in tar archives."""
360
360
  if use_cache and (cache_path := self.cache.get_path(file)):
361
- return open(cache_path, mode="rb") # noqa: SIM115
361
+ return open(cache_path, mode="rb")
362
362
  assert not file.location
363
363
  return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]
364
364
 
@@ -138,6 +138,7 @@ class AbstractMetastore(ABC, Serializable):
138
138
  size: Optional[int] = None,
139
139
  preview: Optional[list[dict]] = None,
140
140
  job_id: Optional[str] = None,
141
+ uuid: Optional[str] = None,
141
142
  ) -> DatasetRecord:
142
143
  """Creates new dataset version."""
143
144
 
@@ -352,6 +353,7 @@ class AbstractDBMetastore(AbstractMetastore):
352
353
  """Datasets versions table columns."""
353
354
  return [
354
355
  Column("id", Integer, primary_key=True),
356
+ Column("uuid", Text, nullable=False, default=uuid4()),
355
357
  Column(
356
358
  "dataset_id",
357
359
  Integer,
@@ -545,6 +547,7 @@ class AbstractDBMetastore(AbstractMetastore):
545
547
  size: Optional[int] = None,
546
548
  preview: Optional[list[dict]] = None,
547
549
  job_id: Optional[str] = None,
550
+ uuid: Optional[str] = None,
548
551
  conn=None,
549
552
  ) -> DatasetRecord:
550
553
  """Creates new dataset version."""
@@ -555,6 +558,7 @@ class AbstractDBMetastore(AbstractMetastore):
555
558
 
556
559
  query = self._datasets_versions_insert().values(
557
560
  dataset_id=dataset.id,
561
+ uuid=uuid or str(uuid4()),
558
562
  version=version,
559
563
  status=status,
560
564
  feature_schema=json.dumps(feature_schema or {}),
@@ -747,8 +747,12 @@ class SQLiteWarehouse(AbstractWarehouse):
747
747
 
748
748
  ids = self.db.execute(select_ids).fetchall()
749
749
 
750
- select_q = query.with_only_columns(
751
- *[c for c in query.selected_columns if c.name != "sys__id"]
750
+ select_q = (
751
+ query.with_only_columns(
752
+ *[c for c in query.selected_columns if c.name != "sys__id"]
753
+ )
754
+ .offset(None)
755
+ .limit(None)
752
756
  )
753
757
 
754
758
  for batch in batched_it(ids, 10_000):
datachain/dataset.py CHANGED
@@ -163,6 +163,7 @@ class DatasetStatus:
163
163
  @dataclass
164
164
  class DatasetVersion:
165
165
  id: int
166
+ uuid: str
166
167
  dataset_id: int
167
168
  version: int
168
169
  status: int
@@ -184,6 +185,7 @@ class DatasetVersion:
184
185
  def parse( # noqa: PLR0913
185
186
  cls: type[V],
186
187
  id: int,
188
+ uuid: str,
187
189
  dataset_id: int,
188
190
  version: int,
189
191
  status: int,
@@ -203,6 +205,7 @@ class DatasetVersion:
203
205
  ):
204
206
  return cls(
205
207
  id,
208
+ uuid,
206
209
  dataset_id,
207
210
  version,
208
211
  status,
@@ -306,6 +309,7 @@ class DatasetRecord:
306
309
  query_script: str,
307
310
  schema: str,
308
311
  version_id: int,
312
+ version_uuid: str,
309
313
  version_dataset_id: int,
310
314
  version: int,
311
315
  version_status: int,
@@ -331,6 +335,7 @@ class DatasetRecord:
331
335
 
332
336
  dataset_version = DatasetVersion.parse(
333
337
  version_id,
338
+ version_uuid,
334
339
  version_dataset_id,
335
340
  version,
336
341
  version_status,
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  from datetime import datetime
3
3
  from typing import TYPE_CHECKING, Any, Optional, Union
4
+ from uuid import uuid4
4
5
 
5
6
  from pydantic import Field, field_validator
6
7
 
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
15
16
 
16
17
  class DatasetInfo(DataModel):
17
18
  name: str
19
+ uuid: str = Field(default=str(uuid4()))
18
20
  version: int = Field(default=1)
19
21
  status: int = Field(default=DatasetStatus.CREATED)
20
22
  created_at: datetime = Field(default=TIME_ZERO)
@@ -60,6 +62,7 @@ class DatasetInfo(DataModel):
60
62
  job: Optional[Job],
61
63
  ) -> "Self":
62
64
  return cls(
65
+ uuid=version.uuid,
63
66
  name=dataset.name,
64
67
  version=version.version,
65
68
  status=version.status,
datachain/lib/dc.py CHANGED
@@ -30,7 +30,7 @@ from datachain.client.local import FileClient
30
30
  from datachain.dataset import DatasetRecord
31
31
  from datachain.lib.convert.python_to_sql import python_to_sql
32
32
  from datachain.lib.convert.values_to_tuples import values_to_tuples
33
- from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
33
+ from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
34
34
  from datachain.lib.dataset_info import DatasetInfo
35
35
  from datachain.lib.file import ArrowRow, File, get_file_type
36
36
  from datachain.lib.file import ExportPlacement as FileExportPlacement
@@ -642,6 +642,59 @@ class DataChain:
642
642
  }
643
643
  return chain.gen(**signal_dict) # type: ignore[misc, arg-type]
644
644
 
645
+ def explode(
646
+ self,
647
+ col: str,
648
+ model_name: Optional[str] = None,
649
+ object_name: Optional[str] = None,
650
+ ) -> "DataChain":
651
+ """Explodes a column containing JSON objects (dict or str DataChain type) into
652
+ individual columns based on the schema of the JSON. Schema is inferred from
653
+ the first row of the column.
654
+
655
+ Args:
656
+ col: the name of the column containing JSON to be exploded.
657
+ model_name: optional generated model name. By default generates the name
658
+ automatically.
659
+ object_name: optional generated object column name. By default generates the
660
+ name automatically.
661
+
662
+ Returns:
663
+ DataChain: A new DataChain instance with the new set of columns.
664
+ """
665
+ import json
666
+
667
+ import pyarrow as pa
668
+
669
+ from datachain.lib.arrow import schema_to_output
670
+
671
+ json_value = next(self.limit(1).collect(col))
672
+ json_dict = (
673
+ json.loads(json_value) if isinstance(json_value, str) else json_value
674
+ )
675
+
676
+ if not isinstance(json_dict, dict):
677
+ raise TypeError(f"Column {col} should be a string or dict type with JSON")
678
+
679
+ schema = pa.Table.from_pylist([json_dict]).schema
680
+ output = schema_to_output(schema, None)
681
+
682
+ if not model_name:
683
+ model_name = f"{col.title()}ExplodedModel"
684
+
685
+ model = dict_to_data_model(model_name, output)
686
+
687
+ def json_to_model(json_value: Union[str, dict]):
688
+ json_dict = (
689
+ json.loads(json_value) if isinstance(json_value, str) else json_value
690
+ )
691
+ return model.model_validate(json_dict)
692
+
693
+ if not object_name:
694
+ object_name = f"{col}_expl"
695
+
696
+ return self.map(json_to_model, params=col, output={object_name: model})
697
+
645
698
  @classmethod
646
699
  def datasets(
647
700
  cls,
@@ -895,7 +948,7 @@ class DataChain:
895
948
  2. Group-based UDF function input: Instead of individual rows, the function
896
949
  receives a list all rows within each group defined by `partition_by`.
897
950
 
898
- Example:
951
+ Examples:
899
952
  ```py
900
953
  chain = chain.agg(
901
954
  total=lambda category, amount: [sum(amount)],
@@ -904,6 +957,26 @@ class DataChain:
904
957
  )
905
958
  chain.save("new_dataset")
906
959
  ```
960
+
961
+ An alternative syntax, when you need to specify a more complex function:
962
+
963
+ ```py
964
+ # It automatically resolves which columns to pass to the function
965
+ # by looking at the function signature.
966
+ def agg_sum(
967
+ file: list[File], amount: list[float]
968
+ ) -> Iterator[tuple[File, float]]:
969
+ yield file[0], sum(amount)
970
+
971
+ chain = chain.agg(
972
+ agg_sum,
973
+ output={"file": File, "total": float},
974
+ # Alternative syntax is to use `C` (short for Column) to specify
975
+ # a column name or a nested column, e.g. C("file.path").
976
+ partition_by=C("category"),
977
+ )
978
+ chain.save("new_dataset")
979
+ ```
907
980
  """
908
981
  udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
909
982
  return self._evolve(
@@ -1242,15 +1315,15 @@ class DataChain:
1242
1315
  return self.results(row_factory=to_dict)
1243
1316
 
1244
1317
  @overload
1245
- def collect(self) -> Iterator[tuple[DataType, ...]]: ...
1318
+ def collect(self) -> Iterator[tuple[DataValue, ...]]: ...
1246
1319
 
1247
1320
  @overload
1248
- def collect(self, col: str) -> Iterator[DataType]: ... # type: ignore[overload-overlap]
1321
+ def collect(self, col: str) -> Iterator[DataValue]: ...
1249
1322
 
1250
1323
  @overload
1251
- def collect(self, *cols: str) -> Iterator[tuple[DataType, ...]]: ...
1324
+ def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...
1252
1325
 
1253
- def collect(self, *cols: str) -> Iterator[Union[DataType, tuple[DataType, ...]]]: # type: ignore[overload-overlap,misc]
1326
+ def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]: # type: ignore[overload-overlap,misc]
1254
1327
  """Yields rows of values, optionally limited to the specified columns.
1255
1328
 
1256
1329
  Args:
@@ -114,6 +114,7 @@ def read_meta( # noqa: C901
114
114
  )
115
115
  )
116
116
  (model_output,) = chain.collect("meta_schema")
117
+ assert isinstance(model_output, str)
117
118
  if print_schema:
118
119
  print(f"{model_output}")
119
120
  # Below 'spec' should be a dynamically converted DataModel from Pydantic
@@ -1,5 +1,6 @@
1
- from . import yolo
2
- from .bbox import BBox
1
+ from . import ultralytics
2
+ from .bbox import BBox, OBBox
3
3
  from .pose import Pose, Pose3D
4
+ from .segment import Segments
4
5
 
5
- __all__ = ["BBox", "Pose", "Pose3D", "yolo"]
6
+ __all__ = ["BBox", "OBBox", "Pose", "Pose3D", "Segments", "ultralytics"]
@@ -1,5 +1,3 @@
1
- from typing import Optional
2
-
3
1
  from pydantic import Field
4
2
 
5
3
  from datachain.lib.data_model import DataModel
@@ -11,10 +9,7 @@ class BBox(DataModel):
11
9
 
12
10
  Attributes:
13
11
  title (str): The title of the bounding box.
14
- x1 (float): The x-coordinate of the top-left corner of the bounding box.
15
- y1 (float): The y-coordinate of the top-left corner of the bounding box.
16
- x2 (float): The x-coordinate of the bottom-right corner of the bounding box.
17
- y2 (float): The y-coordinate of the bottom-right corner of the bounding box.
12
+ coords (list[int]): The coordinates of the bounding box.
18
13
 
19
14
  The bounding box is defined by two points:
20
15
  - (x1, y1): The top-left corner of the box.
@@ -22,24 +17,100 @@ class BBox(DataModel):
22
17
  """
23
18
 
24
19
  title: str = Field(default="")
25
- x1: float = Field(default=0)
26
- y1: float = Field(default=0)
27
- x2: float = Field(default=0)
28
- y2: float = Field(default=0)
20
+ coords: list[int] = Field(default=None)
21
+
22
+ @staticmethod
23
+ def from_list(coords: list[float], title: str = "") -> "BBox":
24
+ assert len(coords) == 4, "Bounding box coordinates must be a list of 4 floats."
25
+ assert all(
26
+ isinstance(value, (int, float)) for value in coords
27
+ ), "Bounding box coordinates must be integers or floats."
28
+ return BBox(
29
+ title=title,
30
+ coords=[round(c) for c in coords],
31
+ )
32
+
33
+ @staticmethod
34
+ def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
35
+ assert (
36
+ len(coords) == 4
37
+ ), "Bounding box coordinates must be a dictionary of 4 floats."
38
+ assert set(coords) == {
39
+ "x1",
40
+ "y1",
41
+ "x2",
42
+ "y2",
43
+ }, "Bounding box coordinates must contain keys with coordinates."
44
+ assert all(
45
+ isinstance(value, (int, float)) for value in coords.values()
46
+ ), "Bounding box coordinates must be integers or floats."
47
+ return BBox(
48
+ title=title,
49
+ coords=[
50
+ round(coords["x1"]),
51
+ round(coords["y1"]),
52
+ round(coords["x2"]),
53
+ round(coords["y2"]),
54
+ ],
55
+ )
56
+
57
+
58
+ class OBBox(DataModel):
59
+ """
60
+ A data model for representing oriented bounding boxes.
61
+
62
+ Attributes:
63
+ title (str): The title of the oriented bounding box.
64
+ coords (list[int]): The coordinates of the oriented bounding box.
65
+
66
+ The oriented bounding box is defined by four points:
67
+ - (x1, y1): The first corner of the box.
68
+ - (x2, y2): The second corner of the box.
69
+ - (x3, y3): The third corner of the box.
70
+ - (x4, y4): The fourth corner of the box.
71
+ """
72
+
73
+ title: str = Field(default="")
74
+ coords: list[int] = Field(default=None)
75
+
76
+ @staticmethod
77
+ def from_list(coords: list[float], title: str = "") -> "OBBox":
78
+ assert (
79
+ len(coords) == 8
80
+ ), "Oriented bounding box coordinates must be a list of 8 floats."
81
+ assert all(
82
+ isinstance(value, (int, float)) for value in coords
83
+ ), "Oriented bounding box coordinates must be integers or floats."
84
+ return OBBox(
85
+ title=title,
86
+ coords=[round(c) for c in coords],
87
+ )
29
88
 
30
89
  @staticmethod
31
- def from_xywh(bbox: list[float], title: Optional[str] = None) -> "BBox":
32
- """
33
- Converts a bounding box in (x, y, width, height) format
34
- to a BBox data model instance.
35
-
36
- Args:
37
- bbox (list[float]): A bounding box, represented as a list
38
- of four floats [x, y, width, height].
39
-
40
- Returns:
41
- BBox2D: An instance of the BBox data model.
42
- """
43
- assert len(bbox) == 4, f"Bounding box must have 4 elements, got f{len(bbox)}"
44
- x, y, w, h = bbox
45
- return BBox(title=title or "", x1=x, y1=y, x2=x + w, y2=y + h)
90
+ def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
91
+ assert set(coords) == {
92
+ "x1",
93
+ "y1",
94
+ "x2",
95
+ "y2",
96
+ "x3",
97
+ "y3",
98
+ "x4",
99
+ "y4",
100
+ }, "Oriented bounding box coordinates must contain keys with coordinates."
101
+ assert all(
102
+ isinstance(value, (int, float)) for value in coords.values()
103
+ ), "Oriented bounding box coordinates must be integers or floats."
104
+ return OBBox(
105
+ title=title,
106
+ coords=[
107
+ round(coords["x1"]),
108
+ round(coords["y1"]),
109
+ round(coords["x2"]),
110
+ round(coords["y2"]),
111
+ round(coords["x3"]),
112
+ round(coords["y3"]),
113
+ round(coords["x4"]),
114
+ round(coords["y4"]),
115
+ ],
116
+ )
@@ -8,15 +8,48 @@ class Pose(DataModel):
8
8
  A data model for representing pose keypoints.
9
9
 
10
10
  Attributes:
11
- x (list[float]): The x-coordinates of the keypoints.
12
- y (list[float]): The y-coordinates of the keypoints.
11
+ x (list[int]): The x-coordinates of the keypoints.
12
+ y (list[int]): The y-coordinates of the keypoints.
13
13
 
14
14
  The keypoints are represented as lists of x and y coordinates, where each index
15
15
  corresponds to a specific body part.
16
16
  """
17
17
 
18
- x: list[float] = Field(default=None)
19
- y: list[float] = Field(default=None)
18
+ x: list[int] = Field(default=None)
19
+ y: list[int] = Field(default=None)
20
+
21
+ @staticmethod
22
+ def from_list(points: list[list[float]]) -> "Pose":
23
+ assert len(points) == 2, "Pose coordinates must be a list of 2 lists."
24
+ points_x, points_y = points
25
+ assert (
26
+ len(points_x) == len(points_y) == 17
27
+ ), "Pose x and y coordinates must have the same length of 17."
28
+ assert all(
29
+ isinstance(value, (int, float)) for value in [*points_x, *points_y]
30
+ ), "Pose coordinates must be integers or floats."
31
+ return Pose(
32
+ x=[round(coord) for coord in points_x],
33
+ y=[round(coord) for coord in points_y],
34
+ )
35
+
36
+ @staticmethod
37
+ def from_dict(points: dict[str, list[float]]) -> "Pose":
38
+ assert set(points) == {
39
+ "x",
40
+ "y",
41
+ }, "Pose coordinates must contain keys 'x' and 'y'."
42
+ points_x, points_y = points["x"], points["y"]
43
+ assert (
44
+ len(points_x) == len(points_y) == 17
45
+ ), "Pose x and y coordinates must have the same length of 17."
46
+ assert all(
47
+ isinstance(value, (int, float)) for value in [*points_x, *points_y]
48
+ ), "Pose coordinates must be integers or floats."
49
+ return Pose(
50
+ x=[round(coord) for coord in points_x],
51
+ y=[round(coord) for coord in points_y],
52
+ )
20
53
 
21
54
 
22
55
  class Pose3D(DataModel):
@@ -24,14 +57,52 @@ class Pose3D(DataModel):
24
57
  A data model for representing 3D pose keypoints.
25
58
 
26
59
  Attributes:
27
- x (list[float]): The x-coordinates of the keypoints.
28
- y (list[float]): The y-coordinates of the keypoints.
60
+ x (list[int]): The x-coordinates of the keypoints.
61
+ y (list[int]): The y-coordinates of the keypoints.
29
62
  visible (list[float]): The visibility of the keypoints.
30
63
 
31
64
  The keypoints are represented as lists of x, y, and visibility values,
32
65
  where each index corresponds to a specific body part.
33
66
  """
34
67
 
35
- x: list[float] = Field(default=None)
36
- y: list[float] = Field(default=None)
68
+ x: list[int] = Field(default=None)
69
+ y: list[int] = Field(default=None)
37
70
  visible: list[float] = Field(default=None)
71
+
72
+ @staticmethod
73
+ def from_list(points: list[list[float]]) -> "Pose3D":
74
+ assert len(points) == 3, "Pose coordinates must be a list of 3 lists."
75
+ points_x, points_y, points_v = points
76
+ assert (
77
+ len(points_x) == len(points_y) == len(points_v) == 17
78
+ ), "Pose x, y, and visibility coordinates must have the same length of 17."
79
+ assert all(
80
+ isinstance(value, (int, float))
81
+ for value in [*points_x, *points_y, *points_v]
82
+ ), "Pose coordinates must be integers or floats."
83
+ return Pose3D(
84
+ x=[round(coord) for coord in points_x],
85
+ y=[round(coord) for coord in points_y],
86
+ visible=points_v,
87
+ )
88
+
89
+ @staticmethod
90
+ def from_dict(points: dict[str, list[float]]) -> "Pose3D":
91
+ assert set(points) == {
92
+ "x",
93
+ "y",
94
+ "visible",
95
+ }, "Pose coordinates must contain keys 'x', 'y', and 'visible'."
96
+ points_x, points_y, points_v = points["x"], points["y"], points["visible"]
97
+ assert (
98
+ len(points_x) == len(points_y) == len(points_v) == 17
99
+ ), "Pose x, y, and visibility coordinates must have the same length of 17."
100
+ assert all(
101
+ isinstance(value, (int, float))
102
+ for value in [*points_x, *points_y, *points_v]
103
+ ), "Pose coordinates must be integers or floats."
104
+ return Pose3D(
105
+ x=[round(coord) for coord in points_x],
106
+ y=[round(coord) for coord in points_y],
107
+ visible=points_v,
108
+ )
@@ -0,0 +1,53 @@
1
+ from pydantic import Field
2
+
3
+ from datachain.lib.data_model import DataModel
4
+
5
+
6
+ class Segments(DataModel):
7
+ """
8
+ A data model for representing segments.
9
+
10
+ Attributes:
11
+ title (str): The title of the segments.
12
+ x (list[int]): The x-coordinates of the segments.
13
+ y (list[int]): The y-coordinates of the segments.
14
+
15
+ The segments are represented as lists of x and y coordinates, where each index
16
+ corresponds to a specific segment.
17
+ """
18
+
19
+ title: str = Field(default="")
20
+ x: list[int] = Field(default=None)
21
+ y: list[int] = Field(default=None)
22
+
23
+ @staticmethod
24
+ def from_list(points: list[list[float]], title: str = "") -> "Segments":
25
+ assert len(points) == 2, "Segments coordinates must be a list of 2 lists."
26
+ points_x, points_y = points
27
+ assert len(points_x) == len(
28
+ points_y
29
+ ), "Segments x and y coordinates must have the same length."
30
+ assert all(
31
+ isinstance(value, (int, float)) for value in [*points_x, *points_y]
32
+ ), "Segments coordinates must be integers or floats."
33
+ return Segments(
34
+ title=title,
35
+ x=[round(coord) for coord in points_x],
36
+ y=[round(coord) for coord in points_y],
37
+ )
38
+
39
+ @staticmethod
40
+ def from_dict(points: dict[str, list[float]], title: str = "") -> "Segments":
41
+ assert set(points) == {
42
+ "x",
43
+ "y",
44
+ }, "Segments coordinates must contain keys 'x' and 'y'."
45
+ points_x, points_y = points["x"], points["y"]
46
+ assert all(
47
+ isinstance(value, (int, float)) for value in [*points_x, *points_y]
48
+ ), "Segments coordinates must be integers or floats."
49
+ return Segments(
50
+ title=title,
51
+ x=[round(coord) for coord in points_x],
52
+ y=[round(coord) for coord in points_y],
53
+ )
@@ -0,0 +1,14 @@
1
+ from .bbox import YoloBBox, YoloBBoxes, YoloOBBox, YoloOBBoxes
2
+ from .pose import YoloPose, YoloPoses
3
+ from .segment import YoloSegment, YoloSegments
4
+
5
+ __all__ = [
6
+ "YoloBBox",
7
+ "YoloBBoxes",
8
+ "YoloOBBox",
9
+ "YoloOBBoxes",
10
+ "YoloPose",
11
+ "YoloPoses",
12
+ "YoloSegment",
13
+ "YoloSegments",
14
+ ]