datachain 0.6.8__py3-none-any.whl → 0.6.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/catalog/catalog.py +20 -3
- datachain/client/fsspec.py +1 -1
- datachain/data_storage/metastore.py +4 -0
- datachain/data_storage/sqlite.py +6 -2
- datachain/dataset.py +5 -0
- datachain/lib/dataset_info.py +3 -0
- datachain/lib/dc.py +79 -6
- datachain/lib/meta_formats.py +1 -0
- datachain/lib/models/__init__.py +4 -3
- datachain/lib/models/bbox.py +96 -25
- datachain/lib/models/pose.py +79 -8
- datachain/lib/models/segment.py +53 -0
- datachain/lib/models/ultralytics/__init__.py +14 -0
- datachain/lib/models/ultralytics/bbox.py +189 -0
- datachain/lib/models/ultralytics/pose.py +126 -0
- datachain/lib/models/ultralytics/segment.py +121 -0
- datachain/lib/signal_schema.py +1 -1
- datachain/listing.py +24 -7
- datachain/toolkit/__init__.py +3 -0
- datachain/toolkit/split.py +67 -0
- {datachain-0.6.8.dist-info → datachain-0.6.10.dist-info}/METADATA +42 -22
- {datachain-0.6.8.dist-info → datachain-0.6.10.dist-info}/RECORD +26 -20
- {datachain-0.6.8.dist-info → datachain-0.6.10.dist-info}/WHEEL +1 -1
- datachain/lib/models/yolo.py +0 -39
- {datachain-0.6.8.dist-info → datachain-0.6.10.dist-info}/LICENSE +0 -0
- {datachain-0.6.8.dist-info → datachain-0.6.10.dist-info}/entry_points.txt +0 -0
- {datachain-0.6.8.dist-info → datachain-0.6.10.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the YOLO models.
|
|
3
|
+
|
|
4
|
+
YOLO stands for "You Only Look Once", a family of object detection models that
|
|
5
|
+
are designed to be fast and accurate. The models are trained to detect objects
|
|
6
|
+
in images by dividing the image into a grid and predicting the bounding boxes
|
|
7
|
+
and class probabilities for each grid cell.
|
|
8
|
+
|
|
9
|
+
More information about YOLO can be found here:
|
|
10
|
+
- https://pjreddie.com/darknet/yolo/
|
|
11
|
+
- https://docs.ultralytics.com/
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from io import BytesIO
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
from PIL import Image
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from datachain.lib.data_model import DataModel
|
|
21
|
+
from datachain.lib.models.bbox import BBox, OBBox
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from ultralytics.engine.results import Results
|
|
25
|
+
from ultralytics.models import YOLO
|
|
26
|
+
|
|
27
|
+
from datachain.lib.file import File
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class YoloBBox(DataModel):
    """
    A single bounding box produced by a YOLO detection model.

    Only the first detection of a result is kept; use `YoloBBoxes` to keep
    every detection.

    Attributes:
        cls: Class id of the detection (-1 when nothing was detected).
        name: Class label of the detection ("" when unknown).
        confidence: Confidence score of the detection (0 when nothing was
            detected).
        box: Bounding box of the detection (None unless one is supplied).
    """

    cls: int = Field(default=-1)
    name: str = Field(default="")
    confidence: float = Field(default=0)
    box: BBox = Field(default=None)

    @staticmethod
    def from_file(yolo: "YOLO", file: "File") -> "YoloBBox":
        """Run `yolo` on the image stored in `file` and wrap its first result.

        Returns an empty `YoloBBox` when the model yields no results.
        """
        image = Image.open(BytesIO(file.read()))
        results = yolo(image)
        return YoloBBox.from_result(results[0]) if len(results) else YoloBBox()

    @staticmethod
    def from_result(result: "Results") -> "YoloBBox":
        """Build a `YoloBBox` from the first entry of `result.summary()`.

        Returns an empty `YoloBBox` when the summary is empty. The entry must
        contain "class" and "confidence" keys; "name" and "box" are optional.
        """
        detections = result.summary()
        if not detections:
            return YoloBBox()

        first = detections[0]
        label = first.get("name", "")
        # A missing "box" falls back to a default-constructed BBox.
        if "box" in first:
            bbox = BBox.from_dict(first["box"], title=label)
        else:
            bbox = BBox()

        return YoloBBox(
            cls=first["class"],
            name=label,
            confidence=first["confidence"],
            box=bbox,
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class YoloBBoxes(DataModel):
    """
    All bounding boxes produced by a YOLO detection model.

    The four lists are parallel: index ``i`` of each list describes the same
    detection.

    Attributes:
        cls: Class ids of the detections.
        name: Class labels of the detections.
        confidence: Confidence scores of the detections.
        box: Bounding boxes of the detections.
    """

    cls: list[int]
    name: list[str]
    confidence: list[float]
    box: list[BBox]

    @staticmethod
    def from_file(yolo: "YOLO", file: "File") -> "YoloBBoxes":
        """Run `yolo` on the image stored in `file` and collect every detection."""
        image = Image.open(BytesIO(file.read()))
        return YoloBBoxes.from_results(yolo(image))

    @staticmethod
    def from_results(results: list["Results"]) -> "YoloBBoxes":
        """Flatten the summaries of all `results` into one `YoloBBoxes`.

        Each summary entry must contain "class" and "confidence" keys;
        "name" and "box" are optional.
        """
        classes: list[int] = []
        labels: list[str] = []
        scores: list[float] = []
        boxes: list[BBox] = []

        for result in results:
            for detection in result.summary():
                label = detection.get("name", "")
                classes.append(detection["class"])
                labels.append(label)
                scores.append(detection["confidence"])
                boxes.append(BBox.from_dict(detection.get("box", {}), title=label))

        return YoloBBoxes(cls=classes, name=labels, confidence=scores, box=boxes)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class YoloOBBox(DataModel):
    """
    A single oriented bounding box produced by a YOLO model.

    Only the first detection of a result is kept; use `YoloOBBoxes` to keep
    every detection.

    Attributes:
        cls: Class id of the detection (-1 when nothing was detected).
        name: Class label of the detection ("" when unknown).
        confidence: Confidence score of the detection (0 when nothing was
            detected).
        box: Oriented bounding box of the detection (None unless one is
            supplied).
    """

    cls: int = Field(default=-1)
    name: str = Field(default="")
    confidence: float = Field(default=0)
    box: OBBox = Field(default=None)

    @staticmethod
    def from_file(yolo: "YOLO", file: "File") -> "YoloOBBox":
        """Run `yolo` on the image stored in `file` and wrap its first result.

        Returns an empty `YoloOBBox` when the model yields no results.
        """
        image = Image.open(BytesIO(file.read()))
        results = yolo(image)
        return YoloOBBox.from_result(results[0]) if len(results) else YoloOBBox()

    @staticmethod
    def from_result(result: "Results") -> "YoloOBBox":
        """Build a `YoloOBBox` from the first entry of `result.summary()`.

        Returns an empty `YoloOBBox` when the summary is empty. The entry must
        contain "class" and "confidence" keys; "name" and "box" are optional.
        """
        detections = result.summary()
        if not detections:
            return YoloOBBox()

        first = detections[0]
        label = first.get("name", "")
        # A missing "box" falls back to a default-constructed OBBox.
        if "box" in first:
            obbox = OBBox.from_dict(first["box"], title=label)
        else:
            obbox = OBBox()

        return YoloOBBox(
            cls=first["class"],
            name=label,
            confidence=first["confidence"],
            box=obbox,
        )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class YoloOBBoxes(DataModel):
    """
    All oriented bounding boxes produced by a YOLO model.

    The four lists are parallel: index ``i`` of each list describes the same
    detection.

    Attributes:
        cls: Class ids of the detections.
        name: Class labels of the detections.
        confidence: Confidence scores of the detections.
        box: Oriented bounding boxes of the detections.
    """

    cls: list[int]
    name: list[str]
    confidence: list[float]
    box: list[OBBox]

    @staticmethod
    def from_file(yolo: "YOLO", file: "File") -> "YoloOBBoxes":
        """Run `yolo` on the image stored in `file` and collect every detection."""
        image = Image.open(BytesIO(file.read()))
        return YoloOBBoxes.from_results(yolo(image))

    @staticmethod
    def from_results(results: list["Results"]) -> "YoloOBBoxes":
        """Flatten the summaries of all `results` into one `YoloOBBoxes`.

        Each summary entry must contain "class" and "confidence" keys;
        "name" and "box" are optional.
        """
        classes: list[int] = []
        labels: list[str] = []
        scores: list[float] = []
        boxes: list[OBBox] = []

        for result in results:
            for detection in result.summary():
                label = detection.get("name", "")
                classes.append(detection["class"])
                labels.append(label)
                scores.append(detection["confidence"])
                boxes.append(OBBox.from_dict(detection.get("box", {}), title=label))

        return YoloOBBoxes(cls=classes, name=labels, confidence=scores, box=boxes)
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the YOLO models.
|
|
3
|
+
|
|
4
|
+
YOLO stands for "You Only Look Once", a family of object detection models that
|
|
5
|
+
are designed to be fast and accurate. The models are trained to detect objects
|
|
6
|
+
in images by dividing the image into a grid and predicting the bounding boxes
|
|
7
|
+
and class probabilities for each grid cell.
|
|
8
|
+
|
|
9
|
+
More information about YOLO can be found here:
|
|
10
|
+
- https://pjreddie.com/darknet/yolo/
|
|
11
|
+
- https://docs.ultralytics.com/
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from typing import TYPE_CHECKING
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
|
|
18
|
+
from datachain.lib.data_model import DataModel
|
|
19
|
+
from datachain.lib.models.bbox import BBox
|
|
20
|
+
from datachain.lib.models.pose import Pose3D
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from ultralytics.engine.results import Results
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class YoloPoseBodyPart:
    """
    Integer ids of the body parts used for YOLO pose keypoints.

    These are plain integer class attributes (this is not an `enum.Enum`),
    numbered 0 through 16. The order appears to follow the COCO 17-keypoint
    layout; confirm against the model in use.
    """

    # Head
    nose = 0
    left_eye = 1
    right_eye = 2
    left_ear = 3
    right_ear = 4

    # Arms
    left_shoulder = 5
    right_shoulder = 6
    left_elbow = 7
    right_elbow = 8
    left_wrist = 9
    right_wrist = 10

    # Legs
    left_hip = 11
    right_hip = 12
    left_knee = 13
    right_knee = 14
    left_ankle = 15
    right_ankle = 16
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class YoloPose(DataModel):
    """
    A single pose produced by a YOLO pose model.

    Only the first detection of a result is kept; use `YoloPoses` to keep
    every detection.

    Attributes:
        cls: Class id of the pose (-1 when nothing was detected).
        name: Class label of the pose ("" when unknown).
        confidence: Confidence score of the pose (0 when nothing was detected).
        box: Bounding box of the pose (None unless one is supplied).
        keypoints: 3D pose keypoints (None unless supplied).
    """

    cls: int = Field(default=-1)
    name: str = Field(default="")
    confidence: float = Field(default=0)
    box: BBox = Field(default=None)
    keypoints: Pose3D = Field(default=None)

    @staticmethod
    def from_result(result: "Results") -> "YoloPose":
        """Build a `YoloPose` from the first entry of `result.summary()`.

        Returns an empty `YoloPose` when the summary is empty. The entry must
        contain "class" and "confidence" keys; "name", "box" and "keypoints"
        are optional.
        """
        detections = result.summary()
        if not detections:
            return YoloPose()

        first = detections[0]
        label = first.get("name", "")

        # Missing "box"/"keypoints" fall back to default-constructed models.
        if "box" in first:
            bbox = BBox.from_dict(first["box"], title=label)
        else:
            bbox = BBox()

        if "keypoints" in first:
            pose = Pose3D.from_dict(first["keypoints"])
        else:
            pose = Pose3D()

        return YoloPose(
            cls=first["class"],
            name=label,
            confidence=first["confidence"],
            box=bbox,
            keypoints=pose,
        )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class YoloPoses(DataModel):
    """
    All poses produced by a YOLO pose model.

    The five lists are parallel: index ``i`` of each list describes the same
    detection.

    Attributes:
        cls: Class ids of the poses.
        name: Class labels of the poses.
        confidence: Confidence scores of the poses.
        box: Bounding boxes of the poses.
        keypoints: 3D pose keypoints of the poses.
    """

    cls: list[int]
    name: list[str]
    confidence: list[float]
    box: list[BBox]
    keypoints: list[Pose3D]

    @staticmethod
    def from_results(results: list["Results"]) -> "YoloPoses":
        """Flatten the summaries of all `results` into one `YoloPoses`.

        Each summary entry must contain "class" and "confidence" keys;
        "name", "box" and "keypoints" are optional.
        """
        classes: list[int] = []
        labels: list[str] = []
        scores: list[float] = []
        boxes: list[BBox] = []
        poses: list[Pose3D] = []

        for result in results:
            for detection in result.summary():
                label = detection.get("name", "")
                classes.append(detection["class"])
                labels.append(label)
                scores.append(detection["confidence"])
                boxes.append(BBox.from_dict(detection.get("box", {}), title=label))
                poses.append(Pose3D.from_dict(detection.get("keypoints", {})))

        return YoloPoses(
            cls=classes,
            name=labels,
            confidence=scores,
            box=boxes,
            keypoints=poses,
        )
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the YOLO models.
|
|
3
|
+
|
|
4
|
+
YOLO stands for "You Only Look Once", a family of object detection models that
|
|
5
|
+
are designed to be fast and accurate. The models are trained to detect objects
|
|
6
|
+
in images by dividing the image into a grid and predicting the bounding boxes
|
|
7
|
+
and class probabilities for each grid cell.
|
|
8
|
+
|
|
9
|
+
More information about YOLO can be found here:
|
|
10
|
+
- https://pjreddie.com/darknet/yolo/
|
|
11
|
+
- https://docs.ultralytics.com/
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from io import BytesIO
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
from PIL import Image
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from datachain.lib.data_model import DataModel
|
|
21
|
+
from datachain.lib.models.bbox import BBox
|
|
22
|
+
from datachain.lib.models.segment import Segments
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ultralytics.engine.results import Results
|
|
26
|
+
from ultralytics.models import YOLO
|
|
27
|
+
|
|
28
|
+
from datachain.lib.file import File
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class YoloSegment(DataModel):
    """
    A single segment produced by a YOLO segmentation model.

    Only the first detection of a result is kept; use `YoloSegments` to keep
    every detection.

    Attributes:
        cls (int): Class id of the segment (-1 when nothing was detected).
        name (str): Class label of the segment ("" when unknown).
        confidence (float): Confidence score of the segment (0 when nothing
            was detected).
        box (BBox): Bounding box of the segment (None unless supplied).
        segments (Segments): Segment outline data (None unless supplied).
    """

    cls: int = Field(default=-1)
    name: str = Field(default="")
    confidence: float = Field(default=0)
    box: BBox = Field(default=None)
    segments: Segments = Field(default=None)

    @staticmethod
    def from_file(yolo: "YOLO", file: "File") -> "YoloSegment":
        """Run `yolo` on the image stored in `file` and wrap its first result.

        Returns an empty `YoloSegment` when the model yields no results.
        """
        results = yolo(Image.open(BytesIO(file.read())))
        if len(results) == 0:
            return YoloSegment()
        return YoloSegment.from_result(results[0])

    @staticmethod
    def from_result(result: "Results") -> "YoloSegment":
        """Build a `YoloSegment` from the first entry of `result.summary()`.

        Returns an empty `YoloSegment` when the summary is empty. The entry
        must contain "class" and "confidence" keys; "name", "box" and
        "segments" are optional.
        """
        summary = result.summary()
        if not summary:
            return YoloSegment()
        first = summary[0]
        # "name" is optional: default it once and reuse it everywhere, so a
        # summary without a name no longer raises KeyError when building the
        # model (consistent with YoloSegments.from_results).
        name = first.get("name", "")
        box = (
            BBox.from_dict(first["box"], title=name) if "box" in first else BBox()
        )
        segments = (
            Segments.from_dict(first["segments"], title=name)
            if "segments" in first
            else Segments()
        )
        return YoloSegment(
            cls=first["class"],
            name=name,
            confidence=first["confidence"],
            box=box,
            segments=segments,
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class YoloSegments(DataModel):
    """
    All segments produced by a YOLO segmentation model.

    The five lists are parallel: index ``i`` of each list describes the same
    detection.

    Attributes:
        cls (list[int]): Class ids of the segments.
        name (list[str]): Class labels of the segments.
        confidence (list[float]): Confidence scores of the segments.
        box (list[BBox]): Bounding boxes of the segments.
        segments (list[Segments]): Segment outline data of each detection.
    """

    cls: list[int]
    name: list[str]
    confidence: list[float]
    box: list[BBox]
    segments: list[Segments]

    @staticmethod
    def from_file(yolo: "YOLO", file: "File") -> "YoloSegments":
        """Run `yolo` on the image stored in `file` and collect every detection."""
        image = Image.open(BytesIO(file.read()))
        return YoloSegments.from_results(yolo(image))

    @staticmethod
    def from_results(results: list["Results"]) -> "YoloSegments":
        """Flatten the summaries of all `results` into one `YoloSegments`.

        Each summary entry must contain "class" and "confidence" keys;
        "name", "box" and "segments" are optional.
        """
        classes: list[int] = []
        labels: list[str] = []
        scores: list[float] = []
        boxes: list[BBox] = []
        outlines: list[Segments] = []

        for result in results:
            for detection in result.summary():
                label = detection.get("name", "")
                classes.append(detection["class"])
                labels.append(label)
                scores.append(detection["confidence"])
                boxes.append(BBox.from_dict(detection.get("box", {}), title=label))
                outlines.append(
                    Segments.from_dict(detection.get("segments", {}), title=label)
                )

        return YoloSegments(
            cls=classes,
            name=labels,
            confidence=scores,
            box=boxes,
            segments=outlines,
        )
|
datachain/lib/signal_schema.py
CHANGED
datachain/listing.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import glob
|
|
2
2
|
import os
|
|
3
3
|
from collections.abc import Iterable, Iterator
|
|
4
|
+
from functools import cached_property
|
|
4
5
|
from itertools import zip_longest
|
|
5
6
|
from typing import TYPE_CHECKING, Optional
|
|
6
7
|
|
|
@@ -15,28 +16,34 @@ from datachain.utils import suffix_to_number
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
from datachain.catalog.datasource import DataSource
|
|
17
18
|
from datachain.client import Client
|
|
18
|
-
from datachain.data_storage import AbstractWarehouse
|
|
19
|
+
from datachain.data_storage import AbstractMetastore, AbstractWarehouse
|
|
19
20
|
from datachain.dataset import DatasetRecord
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class Listing:
|
|
23
24
|
def __init__(
|
|
24
25
|
self,
|
|
26
|
+
metastore: "AbstractMetastore",
|
|
25
27
|
warehouse: "AbstractWarehouse",
|
|
26
28
|
client: "Client",
|
|
27
|
-
|
|
29
|
+
dataset_name: Optional["str"] = None,
|
|
30
|
+
dataset_version: Optional[int] = None,
|
|
28
31
|
object_name: str = "file",
|
|
29
32
|
):
|
|
33
|
+
self.metastore = metastore
|
|
30
34
|
self.warehouse = warehouse
|
|
31
35
|
self.client = client
|
|
32
|
-
self.
|
|
36
|
+
self.dataset_name = dataset_name # dataset representing bucket listing
|
|
37
|
+
self.dataset_version = dataset_version # dataset representing bucket listing
|
|
33
38
|
self.object_name = object_name
|
|
34
39
|
|
|
35
40
|
def clone(self) -> "Listing":
|
|
36
41
|
return self.__class__(
|
|
42
|
+
self.metastore.clone(),
|
|
37
43
|
self.warehouse.clone(),
|
|
38
44
|
self.client,
|
|
39
|
-
self.
|
|
45
|
+
self.dataset_name,
|
|
46
|
+
self.dataset_version,
|
|
40
47
|
self.object_name,
|
|
41
48
|
)
|
|
42
49
|
|
|
@@ -53,12 +60,22 @@ class Listing:
|
|
|
53
60
|
def uri(self):
|
|
54
61
|
from datachain.lib.listing import listing_uri_from_name
|
|
55
62
|
|
|
56
|
-
|
|
63
|
+
assert self.dataset_name
|
|
57
64
|
|
|
58
|
-
|
|
65
|
+
return listing_uri_from_name(self.dataset_name)
|
|
66
|
+
|
|
67
|
+
@cached_property
|
|
68
|
+
def dataset(self) -> "DatasetRecord":
|
|
69
|
+
assert self.dataset_name
|
|
70
|
+
return self.metastore.get_dataset(self.dataset_name)
|
|
71
|
+
|
|
72
|
+
@cached_property
|
|
59
73
|
def dataset_rows(self):
|
|
74
|
+
dataset = self.dataset
|
|
60
75
|
return self.warehouse.dataset_rows(
|
|
61
|
-
|
|
76
|
+
dataset,
|
|
77
|
+
self.dataset_version or dataset.latest_version,
|
|
78
|
+
object_name=self.object_name,
|
|
62
79
|
)
|
|
63
80
|
|
|
64
81
|
def expand_path(self, path, use_glob=True) -> list[Node]:
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from datachain import C, DataChain
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def train_test_split(dc: "DataChain", weights: list[float]) -> list["DataChain"]:
    """
    Splits a DataChain into multiple subsets based on the provided weights.

    This function partitions the rows of a DataChain into disjoint subsets whose
    relative sizes approximately correspond to the given weights.
    It is particularly useful for creating training, validation, and test datasets.

    Args:
        dc (DataChain):
            The DataChain instance to split.
        weights (list[float]):
            A list of weights indicating the relative proportions of the splits.
            The weights do not need to sum to 1; they will be normalized internally.
            For example:
            - `[0.7, 0.3]` corresponds to a 70/30 split;
            - `[2, 1, 1]` corresponds to a 50/25/25 split.

    Returns:
        list[DataChain]:
            A list of DataChain instances, one for each weight in the weights list.

    Raises:
        ValueError:
            If fewer than two weights are given, if any weight is negative,
            or if all weights are zero.

    Examples:
        Train-test split:
        ```python
        from datachain import DataChain
        from datachain.toolkit import train_test_split

        # Load a DataChain from a storage source (e.g., S3 bucket)
        dc = DataChain.from_storage("s3://bucket/dir/")

        # Perform a 70/30 train-test split
        train, test = train_test_split(dc, [0.7, 0.3])

        # Save the resulting splits
        train.save("dataset_train")
        test.save("dataset_test")
        ```

        Train-test-validation split:
        ```python
        train, test, val = train_test_split(dc, [0.7, 0.2, 0.1])
        train.save("dataset_train")
        test.save("dataset_test")
        val.save("dataset_val")
        ```

    Note:
        The splits are random but deterministic, based on Dataset `sys__rand` field.
        Each row is assigned to a split by `sys__rand % 1000`, so proportions are
        applied on a 0.1% grid; a very small weight may produce an empty split.
    """
    from itertools import accumulate

    if len(weights) < 2:
        raise ValueError("Weights should have at least two elements")
    if any(weight < 0 for weight in weights):
        raise ValueError("Weights should be non-negative")

    total = sum(weights)
    if not total:
        # Every weight is zero: there is nothing to normalize by, and dividing
        # would raise a bare ZeroDivisionError instead of a clear error.
        raise ValueError("Weights should not all be zero")

    weights_normalized = [weight / total for weight in weights]

    # Cumulative bucket boundaries on the [0, 1000] grid; split i covers
    # [bounds[i], bounds[i + 1]). Accumulating left to right produces the same
    # floats as summing each prefix separately, without the quadratic cost.
    bounds = [0] + [round(cum * 1000) for cum in accumulate(weights_normalized)]

    return [
        dc.filter(
            C("sys__rand") % 1000 >= lower,
            C("sys__rand") % 1000 < upper,
        )
        for lower, upper in zip(bounds, bounds[1:])
    ]
|