datachain 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/__init__.py +3 -4
- datachain/cache.py +10 -4
- datachain/catalog/catalog.py +42 -16
- datachain/cli.py +48 -32
- datachain/data_storage/metastore.py +24 -0
- datachain/data_storage/warehouse.py +3 -1
- datachain/job.py +56 -0
- datachain/lib/arrow.py +19 -7
- datachain/lib/clip.py +89 -66
- datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
- datachain/lib/convert/sql_to_python.py +23 -0
- datachain/lib/convert/values_to_tuples.py +51 -33
- datachain/lib/data_model.py +6 -27
- datachain/lib/dataset_info.py +70 -0
- datachain/lib/dc.py +618 -156
- datachain/lib/file.py +130 -22
- datachain/lib/image.py +1 -1
- datachain/lib/meta_formats.py +14 -2
- datachain/lib/model_store.py +3 -2
- datachain/lib/pytorch.py +10 -7
- datachain/lib/signal_schema.py +19 -11
- datachain/lib/text.py +2 -1
- datachain/lib/udf.py +56 -5
- datachain/lib/udf_signature.py +1 -1
- datachain/node.py +11 -8
- datachain/query/dataset.py +62 -28
- datachain/query/schema.py +2 -0
- datachain/query/session.py +4 -4
- datachain/sql/functions/array.py +12 -0
- datachain/sql/functions/string.py +8 -0
- datachain/torch/__init__.py +1 -1
- datachain/utils.py +6 -0
- datachain-0.2.13.dist-info/METADATA +411 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/RECORD +38 -42
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/WHEEL +1 -1
- datachain/lib/gpt4_vision.py +0 -97
- datachain/lib/hf_image_to_text.py +0 -97
- datachain/lib/hf_pipeline.py +0 -90
- datachain/lib/image_transform.py +0 -103
- datachain/lib/iptc_exif_xmp.py +0 -76
- datachain/lib/unstructured.py +0 -41
- datachain/text/__init__.py +0 -3
- datachain-0.2.11.dist-info/METADATA +0 -431
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/LICENSE +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.11.dist-info → datachain-0.2.13.dist-info}/top_level.txt +0 -0
datachain/lib/file.py
CHANGED
@@ -1,5 +1,7 @@
 import io
 import json
+import os
+import posixpath
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from datetime import datetime
@@ -16,14 +18,17 @@ from pydantic import Field, field_validator
 
 from datachain.cache import UniqueId
 from datachain.client.fileslice import FileSlice
-from datachain.lib.data_model import DataModel
+from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError
-from datachain.sql.types import JSON, Int, String
+from datachain.sql.types import JSON, Boolean, DateTime, Int, String
 from datachain.utils import TIME_ZERO
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
+# how to create file path when exporting
+ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
+
 
 class VFileError(DataChainError):
     def __init__(self, file: "File", message: str, vtype: str = ""):
@@ -49,12 +54,15 @@ class VFile(ABC):
 
 
 class TarVFile(VFile):
+    """Virtual file model for files extracted from tar archives."""
+
     @classmethod
     def get_vtype(cls) -> str:
         return "tar"
 
     @classmethod
     def open(cls, file: "File", location: list[dict]):
+        """Stream file from tar archive based on location in archive."""
         if len(location) > 1:
             VFileError(file, "multiple 'location's are not supported yet")
 
@@ -100,7 +108,9 @@ class VFileRegistry:
         return reader.open(file, location)
 
 
-class File(
+class File(DataModel):
+    """`DataModel` for reading binary files."""
+
     source: str = Field(default="")
     parent: str = Field(default="")
     name: str
@@ -116,25 +126,30 @@ class File(FileBasic):
         "source": String,
         "parent": String,
         "name": String,
+        "size": Int,
         "version": String,
         "etag": String,
-        "
-        "
+        "is_latest": Boolean,
+        "last_modified": DateTime,
         "location": JSON,
+        "vtype": String,
     }
 
     _unique_id_keys: ClassVar[list[str]] = [
         "source",
         "parent",
         "name",
-        "etag",
         "size",
+        "etag",
+        "version",
+        "is_latest",
         "vtype",
         "location",
+        "last_modified",
     ]
 
     @staticmethod
-    def
+    def _validate_dict(
         v: Optional[Union[str, dict, list[dict]]],
     ) -> Optional[Union[str, dict, list[dict]]]:
         if v is None or v == "":
@@ -152,7 +167,7 @@ class File(FileBasic):
     @field_validator("location", mode="before")
     @classmethod
     def validate_location(cls, v):
-        return File.
+        return File._validate_dict(v)
 
     @field_validator("parent", mode="before")
     @classmethod
@@ -172,9 +187,10 @@ class File(FileBasic):
         self._caching_enabled = False
 
     @contextmanager
-    def open(self):
+    def open(self, mode: Literal["rb", "r"] = "rb"):
+        """Open the file and return a file object."""
         if self.location:
-            with VFileRegistry.resolve(self, self.location) as f:
+            with VFileRegistry.resolve(self, self.location) as f:  # type: ignore[arg-type]
                 yield f
 
         uid = self.get_uid()
@@ -184,7 +200,41 @@ class File(FileBasic):
         with client.open_object(
             uid, use_cache=self._caching_enabled, cb=self._download_cb
         ) as f:
-            yield f
+            yield io.TextIOWrapper(f) if mode == "r" else f
+
+    def read(self, length: int = -1):
+        """Returns file contents."""
+        with self.open() as stream:
+            return stream.read(length)
+
+    def read_bytes(self):
+        """Returns file contents as bytes."""
+        return self.read()
+
+    def read_text(self):
+        """Returns file contents as text."""
+        with self.open(mode="r") as stream:
+            return stream.read()
+
+    def save(self, destination: str):
+        """Writes it's content to destination"""
+        with open(destination, mode="wb") as f:
+            f.write(self.read())
+
+    def export(
+        self,
+        output: str,
+        placement: ExportPlacement = "fullpath",
+        use_cache: bool = True,
+    ) -> None:
+        """Export file to new location."""
+        if use_cache:
+            self._caching_enabled = use_cache
+        dst = self.get_destination_path(output, placement)
+        dst_dir = os.path.dirname(dst)
+        os.makedirs(dst_dir, exist_ok=True)
+
+        self.save(dst)
 
     def _set_stream(
         self,
@@ -197,11 +247,12 @@ class File(FileBasic):
         self._download_cb = download_cb
 
     def get_uid(self) -> UniqueId:
+        """Returns unique ID for file."""
        dump = self.model_dump()
        return UniqueId(*(dump[k] for k in self._unique_id_keys))
 
     def get_local_path(self) -> Optional[str]:
-        """
+        """Returns path to a file in a local cache.
         Return None if file is not cached. Throws an exception if cache is not setup."""
         if self._catalog is None:
             raise RuntimeError(
@@ -210,21 +261,27 @@ class File(FileBasic):
         return self._catalog.cache.get_path(self.get_uid())
 
     def get_file_suffix(self):
+        """Returns last part of file name with `.`."""
         return Path(self.name).suffix
 
     def get_file_ext(self):
+        """Returns last part of file name without `.`."""
         return Path(self.name).suffix.strip(".")
 
     def get_file_stem(self):
+        """Returns file name without extension."""
         return Path(self.name).stem
 
     def get_full_name(self):
+        """Returns name with parent directories."""
         return (Path(self.parent) / self.name).as_posix()
 
     def get_uri(self):
+        """Returns file URI."""
         return f"{self.source}/{self.get_full_name()}"
 
     def get_path(self) -> str:
+        """Returns file path."""
         path = unquote(self.get_uri())
         fs = self.get_fs()
         if isinstance(fs, LocalFileSystem):
@@ -233,21 +290,65 @@ class File(FileBasic):
             path = url2pathname(path)
         return path
 
+    def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+        """
+        Returns full destination path of a file for exporting to some output
+        based on export placement
+        """
+        if placement == "filename":
+            path = unquote(self.name)
+        elif placement == "etag":
+            path = f"{self.etag}{self.get_file_suffix()}"
+        elif placement == "fullpath":
+            fs = self.get_fs()
+            if isinstance(fs, LocalFileSystem):
+                path = unquote(self.get_full_name())
+            else:
+                path = (
+                    Path(urlparse(self.source).netloc) / unquote(self.get_full_name())
+                ).as_posix()
+        elif placement == "checksum":
+            raise NotImplementedError("Checksum placement not implemented yet")
+        else:
+            raise ValueError(f"Unsupported file export placement: {placement}")
+        return posixpath.join(output, path)  # type: ignore[union-attr]
+
     def get_fs(self):
+        """Returns `fsspec` filesystem for the file."""
         return self._catalog.get_client(self.source).fs
 
 
 class TextFile(File):
+    """`DataModel` for reading text files."""
+
     @contextmanager
     def open(self):
-
-
+        """Open the file and return a file object in text mode."""
+        with super().open(mode="r") as stream:
+            yield stream
+
+    def read_text(self):
+        """Returns file contents as text."""
+        with self.open() as stream:
+            return stream.read()
+
+    def save(self, destination: str):
+        """Writes it's content to destination"""
+        with open(destination, mode="w") as f:
+            f.write(self.read_text())
 
 
 class ImageFile(File):
-
-
-
+    """`DataModel` for reading image files."""
+
+    def read(self):
+        """Returns `PIL.Image.Image` object."""
+        fobj = super().read()
+        return Image.open(BytesIO(fobj))
+
+    def save(self, destination: str):
+        """Writes it's content to destination"""
+        self.read().save(destination)
 
 
 def get_file(type_: Literal["binary", "text", "image"] = "binary"):
@@ -261,28 +362,35 @@ def get_file(type_: Literal["binary", "text", "image"] = "binary"):
         source: str,
         parent: str,
         name: str,
+        size: int,
         version: str,
         etag: str,
-
-
+        is_latest: bool,
+        last_modified: datetime,
         location: Optional[Union[dict, list[dict]]],
+        vtype: str,
     ) -> file:  # type: ignore[valid-type]
         return file(
             source=source,
             parent=parent,
             name=name,
+            size=size,
             version=version,
             etag=etag,
-
-
+            is_latest=is_latest,
+            last_modified=last_modified,
            location=location,
+            vtype=vtype,
        )
 
     return get_file_type
 
 
 class IndexedFile(DataModel):
-    """
+    """Metadata indexed from tabular files.
+
+    Includes `file` and `index` signals.
+    """
 
     file: File
     index: int
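
The hunks above add a high-level read/save/export API to `File`. A minimal usage sketch (hedged: the bucket URI and output directory are hypothetical, and the tuple shape yielded by `collect()` is assumed from the `pytorch.py` hunk below):

from datachain.lib.dc import DataChain

chain = DataChain.from_storage("gs://datachain-demo/some-bucket")

for (file,) in chain.limit(3).collect():
    data = file.read_bytes()  # whole object as bytes, via File.read()
    # Copy the object out of storage; placement="etag" names the copy
    # "<etag><suffix>", while "fullpath" (the default) recreates the
    # bucket layout under output/.
    file.export("output/", placement="etag")
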
datachain/lib/image.py
CHANGED
@@ -53,7 +53,7 @@ def convert_images(
     Resize, transform, and otherwise convert one or more images.
 
     Args:
-
+        images (Image, list[Image]): PIL.Image object or list of objects.
         mode (str): PIL.Image mode.
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
         transform (Callable): Torchvision transform or huggingface processor to apply.
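
The restored `images` line completes the `convert_images()` docstring. A small sketch based only on the arguments documented here (the exact call shape is an assumption):

from PIL import Image
from datachain.lib.image import convert_images

img = Image.new("RGB", (256, 256))
# Resize to 64x64 pixels and keep RGB mode, per the documented args.
converted = convert_images(img, mode="RGB", size=(64, 64))
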
datachain/lib/meta_formats.py
CHANGED
@@ -13,6 +13,7 @@ from typing import Any, Callable
 import jmespath as jsp
 from pydantic import ValidationError
 
+from datachain.lib.data_model import ModelStore  # noqa: F401
 from datachain.lib.file import File
 
 
@@ -86,6 +87,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
     except subprocess.CalledProcessError as e:
         model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
     print(f"{model_output}")
+    print("\n" + f"ModelStore.register({model_name})" + "\n")
+    print("\n" + f"spec={model_name}" + "\n")
     return model_output
 
 
@@ -99,6 +102,7 @@ def read_meta(  # noqa: C901
     jmespath=None,
     show_schema=False,
     model_name=None,
+    nrows=None,
 ) -> Callable:
     from datachain.lib.dc import DataChain
 
@@ -118,8 +122,7 @@ def read_meta(  # noqa: C901
                 output=str,
             )
         )
-
-        chain.save()
+        chain.exec()
     finally:
         sys.stdout = current_stdout
     model_output = captured_output.getvalue()
@@ -147,6 +150,7 @@ def read_meta(  # noqa: C901
         DataModel=spec,  # noqa: N803
         meta_type=meta_type,
         jmespath=jmespath,
+        nrows=nrows,
     ) -> Iterator[spec]:
         def validator(json_object: dict) -> spec:
             json_string = json.dumps(json_object)
@@ -175,14 +179,22 @@ def read_meta(  # noqa: C901
                 yield from validator(json_object)
 
             else:
+                nrow = 0
                 for json_dict in json_object:
+                    nrow = nrow + 1
+                    if nrows is not None and nrow > nrows:
+                        return
                     yield from validator(json_dict)
 
         if meta_type == "jsonl":
             try:
+                nrow = 0
                 with file.open() as fd:
                     data_string = fd.readline().replace("\r", "")
                     while data_string:
+                        nrow = nrow + 1
+                        if nrows is not None and nrow > nrows:
+                            return
                         json_object = process_json(data_string, jmespath)
                         data_string = fd.readline()
                         yield from validator(json_object)
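
The new `nrows` parameter caps how many JSON array items or JSONL lines are parsed. Both hunks use the same guard; a standalone sketch of that pattern:

from collections.abc import Iterator

def take_rows(objects, nrows=None) -> Iterator:
    # Mirrors the loop added in read_meta(): count rows and stop once
    # the cap is exceeded; nrows=None means no limit.
    nrow = 0
    for obj in objects:
        nrow = nrow + 1
        if nrows is not None and nrow > nrows:
            return
        yield obj

assert list(take_rows(range(10), nrows=3)) == [0, 1, 2]
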
datachain/lib/model_store.py
CHANGED
@@ -22,7 +22,8 @@ class ModelStore:
         return model.__name__
 
     @classmethod
-    def
+    def register(cls, fr: type):
+        """Register a class as a data model for deserialization."""
         if (model := ModelStore.to_pydantic(fr)) is None:
             return
 
@@ -34,7 +35,7 @@ class ModelStore:
 
         for f_info in model.model_fields.values():
             if (anno := ModelStore.to_pydantic(f_info.annotation)) is not None:
-                cls.
+                cls.register(anno)
 
     @classmethod
     def get(cls, name: str, version: Optional[int] = None) -> Optional[type]:
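
The previously truncated method is `register()`, which the deserialization error in `signal_schema.py` now points at. A minimal sketch, assuming a user-defined pydantic model:

from pydantic import BaseModel
from datachain.lib.model_store import ModelStore

class Review(BaseModel):  # hypothetical user model
    rating: int
    text: str

# Make the type known so serialized schemas naming "Review" can be
# deserialized; pydantic model fields are registered recursively.
ModelStore.register(Review)
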
datachain/lib/pytorch.py
CHANGED
@@ -3,7 +3,6 @@ from collections.abc import Iterator
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
-from pydantic import BaseModel
 from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
@@ -11,6 +10,7 @@ from torchvision.transforms import v2
 
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
+from datachain.lib.file import File
 from datachain.lib.text import convert_text
 
 if TYPE_CHECKING:
@@ -24,6 +24,7 @@ DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])
 
 
 def label_to_int(value: str, classes: list) -> int:
+    """Given a value and list of classes, return the index of the value's class."""
     return classes.index(value)
 
 
@@ -33,7 +34,7 @@ class PytorchDataset(IterableDataset):
         name: str,
         version: Optional[int] = None,
         catalog: Optional["Catalog"] = None,
-        transform: Optional["Transform"] =
+        transform: Optional["Transform"] = None,
         tokenizer: Optional[Callable] = None,
         tokenizer_kwargs: Optional[dict[str, Any]] = None,
         num_samples: int = 0,
@@ -41,6 +42,9 @@ class PytorchDataset(IterableDataset):
         """
         Pytorch IterableDataset that streams DataChain datasets.
 
+        See Also:
+            `DataChain.to_pytorch()` - convert chain to PyTorch Dataset.
+
         Args:
             name (str): Name of DataChain dataset to stream.
             version (int): Version of DataChain dataset to stream.
@@ -53,7 +57,7 @@ class PytorchDataset(IterableDataset):
         """
         self.name = name
         self.version = version
-        self.transform = transform
+        self.transform = transform or DEFAULT_TRANSFORM
         self.tokenizer = tokenizer
         self.tokenizer_kwargs = tokenizer_kwargs or {}
         self.num_samples = num_samples
@@ -90,12 +94,11 @@ class PytorchDataset(IterableDataset):
         if self.num_samples > 0:
             ds = ds.sample(self.num_samples)
         ds = ds.chunk(total_rank, total_workers)
-
-        for row_features in stream:
+        for row_features in ds.collect():
             row = []
             for fr in row_features:
-                if isinstance(fr,
-                    row.append(fr.
+                if isinstance(fr, File):
+                    row.append(fr.read())  # type: ignore[unreachable]
                 else:
                     row.append(fr)
             # Apply transforms
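
With the restored default, omitting `transform` now falls back to `DEFAULT_TRANSFORM` inside `__init__`. A construction sketch (the dataset name is hypothetical; `DataChain.to_pytorch()` is the documented entry point):

from torch.utils.data import DataLoader
from datachain.lib.pytorch import PytorchDataset

# Streams a saved DataChain dataset; File values are materialized via
# File.read(), and images get DEFAULT_TRANSFORM since transform=None.
ds = PytorchDataset("fashion-images", version=1)
loader = DataLoader(ds, batch_size=16)
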
datachain/lib/signal_schema.py
CHANGED
@@ -18,7 +18,8 @@ from pydantic import BaseModel, create_model
 from typing_extensions import Literal as LiteralEx
 
 from datachain.lib.convert.flatten import DATACHAIN_TO_TYPE
-from datachain.lib.convert.
+from datachain.lib.convert.python_to_sql import python_to_sql
+from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.convert.unflatten import unflatten_to_json_pos
 from datachain.lib.data_model import DataModel, DataType
 from datachain.lib.file import File
@@ -102,21 +103,20 @@ class SignalSchema:
     @staticmethod
     def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
         signals: dict[str, DataType] = {}
-        for field,
-
-        if type_ is None:
+        for field, col_type in col_types.items():
+            if (py_type := DATACHAIN_TO_TYPE.get(col_type, None)) is None:
                 raise SignalSchemaError(
                     f"signal schema cannot be obtained for column '{field}':"
-                    f" unsupported type '{
+                    f" unsupported type '{py_type}'"
                 )
-            signals[field] =
+            signals[field] = py_type
         return SignalSchema(signals)
 
     def serialize(self) -> dict[str, str]:
         signals = {}
         for name, fr_type in self.values.items():
             if (fr := ModelStore.to_pydantic(fr_type)) is not None:
-                ModelStore.
+                ModelStore.register(fr)
                 signals[name] = ModelStore.get_name(fr)
             else:
                 orig = get_origin(fr_type)
@@ -144,7 +144,7 @@ class SignalSchema:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': "
                     f"unknown type '{type_name}'."
-                    f" Try to add it with `ModelStore.
+                    f" Try to add it with `ModelStore.register({type_name})`."
                 )
         except TypeError as err:
             raise SignalSchemaError(
@@ -161,7 +161,7 @@ class SignalSchema:
                 continue
             if not has_subtree:
                 db_name = DEFAULT_DELIMITER.join(path)
-                res[db_name] =
+                res[db_name] = python_to_sql(type_)
         return res
 
     def row_to_objs(self, row: Sequence[Any]) -> list[DataType]:
@@ -278,6 +278,14 @@ class SignalSchema:
             del schema[signal]
         return SignalSchema(schema)
 
+    def mutate(self, args_map: dict) -> "SignalSchema":
+        return SignalSchema(self.values | sql_to_python(args_map))
+
+    def clone_without_sys_signals(self) -> "SignalSchema":
+        schema = copy.deepcopy(self.values)
+        schema.pop("sys", None)
+        return SignalSchema(schema)
+
     def merge(
         self,
         right_schema: "SignalSchema",
@@ -290,9 +298,9 @@ class SignalSchema:
 
         return SignalSchema(self.values | schema_right)
 
-    def
+    def get_signals(self, target_type: type[DataModel]) -> Iterator[str]:
         for path, type_, has_subtree, _ in self.get_flat_tree():
-            if has_subtree and issubclass(type_,
+            if has_subtree and issubclass(type_, target_type):
                 yield ".".join(path)
 
     def create_model(self, name: str) -> type[DataModel]:
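
Two of the added helpers are easy to exercise: `get_signals()` walks the flattened schema tree and yields dotted paths whose type subclasses the target model, and `clone_without_sys_signals()` returns a copy without the internal `sys` entry. A sketch, assuming `SignalSchema` is constructed from a plain values dict as `from_column_types()` does:

from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"file": File, "label": str})

# Dotted paths of signals whose type subclasses File.
assert list(schema.get_signals(File)) == ["file"]

# Copy without the internal "sys" signal (a no-op for this schema).
clean = schema.clone_without_sys_signals()
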
datachain/lib/text.py
CHANGED
@@ -31,8 +31,9 @@ def convert_text(
     res = tokenizer(text)
 
     tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
+    tokens = torch.tensor(tokens)
 
     if not encoder:
         return tokens
 
-    return encoder(
+    return encoder(tokens)
datachain/lib/udf.py
CHANGED
@@ -9,7 +9,7 @@ from pydantic import BaseModel
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.convert.unflatten import unflatten_to_json
-from datachain.lib.
+from datachain.lib.file import File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf_signature import UdfSignature
@@ -88,6 +88,53 @@ class UDFAdapter(_UDFBase):
 
 
 class UDFBase(AbstractUDF):
+    """Base class for stateful user-defined functions.
+
+    Any class that inherits from it must have a `process()` method that takes input
+    params from one or more rows in the chain and produces the expected output.
+
+    Optionally, the class may include these methods:
+    - `setup()` to run code on each worker before `process()` is called.
+    - `teardown()` to run code on each worker after `process()` completes.
+
+    Example:
+        ```py
+        from datachain import C, DataChain, Mapper
+        import open_clip
+
+        class ImageEncoder(Mapper):
+            def __init__(self, model_name: str, pretrained: str):
+                self.model_name = model_name
+                self.pretrained = pretrained
+
+            def setup(self):
+                self.model, _, self.preprocess = (
+                    open_clip.create_model_and_transforms(
+                        self.model_name, self.pretrained
+                    )
+                )
+
+            def process(self, file) -> list[float]:
+                img = file.get_value()
+                img = self.preprocess(img).unsqueeze(0)
+                emb = self.model.encode_image(img)
+                return emb[0].tolist()
+
+        (
+            DataChain.from_storage(
+                "gs://datachain-demo/fashion-product-images/images", type="image"
+            )
+            .limit(5)
+            .map(
+                ImageEncoder("ViT-B-32", "laion2b_s34b_b79k"),
+                params=["file"],
+                output={"emb": list[float]},
+            )
+            .show()
+        )
+        ```
+    """
+
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False
@@ -198,7 +245,7 @@ class UDFBase(AbstractUDF):
                 flat.extend(flatten(obj))
             else:
                 flat.append(obj)
-            res.append(flat)
+            res.append(tuple(flat))
         else:
             # Generator expression is required, otherwise the value will be materialized
             res = (
@@ -227,7 +274,7 @@ class UDFBase(AbstractUDF):
         for row in rows:
             obj_row = self.params.row_to_objs(row)
             for obj in obj_row:
-                if isinstance(obj,
+                if isinstance(obj, File):
                     obj._set_stream(
                         self._catalog, caching_enabled=cache, download_cb=download_cb
                     )
@@ -256,7 +303,7 @@ class UDFBase(AbstractUDF):
         else:
             obj = slice[0]
 
-        if isinstance(obj,
+        if isinstance(obj, File):
             obj._set_stream(
                 self._catalog, caching_enabled=cache, download_cb=download_cb
             )
@@ -280,7 +327,7 @@ class UDFBase(AbstractUDF):
 
 
 class Mapper(UDFBase):
-    pass
+    """Inherit from this class to pass to `DataChain.map()`."""
 
 
 class BatchMapper(Mapper):
@@ -289,10 +336,14 @@ class BatchMapper(Mapper):
 
 
 class Generator(UDFBase):
+    """Inherit from this class to pass to `DataChain.gen()`."""
+
     is_output_batched = True
 
 
 class Aggregator(UDFBase):
+    """Inherit from this class to pass to `DataChain.agg()`."""
+
     is_input_batched = True
     is_output_batched = True
     is_input_grouped = True
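
`Generator` and `Aggregator` now document their pairing with `DataChain.gen()` and `DataChain.agg()`. A minimal `Generator` sketch in the spirit of the `UDFBase` example above (the splitter and signal names are made up, and the `gen()` call shape is assumed to mirror `map()`):

from datachain.lib.udf import Generator

class LineSplitter(Generator):
    # gen() expects process() to yield zero or more output rows per input.
    def process(self, file):
        for i, line in enumerate(file.read_text().splitlines()):
            yield i, line

# chain.gen(LineSplitter(), params=["file"], output={"idx": int, "line": str})
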
datachain/lib/udf_signature.py
CHANGED
@@ -131,7 +131,7 @@ class UdfSignature:
             raise UdfSignatureError(
                 chain,
                 f"output type '{value.__name__}' of signal '{key}' is not"
-                f" supported. Please use
+                f" supported. Please use DataModel types: {DataTypeNames}",
             )
 
         udf_output_map = output
|