datachain 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/catalog/catalog.py +30 -6
- datachain/data_storage/db_engine.py +0 -2
- datachain/data_storage/schema.py +10 -27
- datachain/data_storage/warehouse.py +1 -7
- datachain/lib/arrow.py +7 -13
- datachain/lib/clip.py +151 -0
- datachain/lib/dc.py +35 -57
- datachain/lib/feature_utils.py +1 -2
- datachain/lib/file.py +7 -0
- datachain/lib/image.py +37 -79
- datachain/lib/pytorch.py +4 -2
- datachain/lib/signal_schema.py +2 -47
- datachain/lib/text.py +18 -49
- datachain/lib/udf.py +58 -30
- datachain/lib/udf_signature.py +11 -10
- datachain/lib/utils.py +17 -0
- datachain/lib/webdataset.py +2 -2
- datachain/listing.py +0 -3
- datachain/query/dataset.py +63 -37
- datachain/query/dispatch.py +2 -2
- datachain/query/schema.py +1 -8
- datachain/query/udf.py +16 -18
- datachain/utils.py +28 -0
- {datachain-0.2.1.dist-info → datachain-0.2.3.dist-info}/METADATA +2 -1
- {datachain-0.2.1.dist-info → datachain-0.2.3.dist-info}/RECORD +29 -29
- {datachain-0.2.1.dist-info → datachain-0.2.3.dist-info}/WHEEL +1 -1
- datachain/lib/reader.py +0 -49
- {datachain-0.2.1.dist-info → datachain-0.2.3.dist-info}/LICENSE +0 -0
- {datachain-0.2.1.dist-info → datachain-0.2.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.1.dist-info → datachain-0.2.3.dist-info}/top_level.txt +0 -0
datachain/lib/image.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
|
-
import inspect
|
|
2
1
|
from io import BytesIO
|
|
3
|
-
from typing import
|
|
2
|
+
from typing import Callable, Optional, Union
|
|
4
3
|
|
|
5
4
|
from datachain.lib.file import File
|
|
6
5
|
|
|
@@ -14,8 +13,6 @@ except ImportError as exc:
|
|
|
14
13
|
" pip install 'datachain[cv]'\n"
|
|
15
14
|
) from exc
|
|
16
15
|
|
|
17
|
-
from datachain.lib.reader import FeatureReader
|
|
18
|
-
|
|
19
16
|
|
|
20
17
|
class ImageFile(File):
|
|
21
18
|
def get_value(self):
|
|
@@ -28,8 +25,8 @@ def convert_image(
|
|
|
28
25
|
mode: str = "RGB",
|
|
29
26
|
size: Optional[tuple[int, int]] = None,
|
|
30
27
|
transform: Optional[Callable] = None,
|
|
31
|
-
|
|
32
|
-
):
|
|
28
|
+
encoder: Optional[Callable] = None,
|
|
29
|
+
) -> Union[Image.Image, torch.Tensor]:
|
|
33
30
|
"""
|
|
34
31
|
Resize, transform, and otherwise convert an image.
|
|
35
32
|
|
|
@@ -37,8 +34,8 @@ def convert_image(
|
|
|
37
34
|
img (Image): PIL.Image object.
|
|
38
35
|
mode (str): PIL.Image mode.
|
|
39
36
|
size (tuple[int, int]): Size in (width, height) pixels for resizing.
|
|
40
|
-
transform (Callable): Torchvision
|
|
41
|
-
|
|
37
|
+
transform (Callable): Torchvision transform or huggingface processor to apply.
|
|
38
|
+
encoder (Callable): Encode image using model.
|
|
42
39
|
"""
|
|
43
40
|
if mode:
|
|
44
41
|
img = img.convert(mode)
|
|
@@ -46,86 +43,47 @@ def convert_image(
|
|
|
46
43
|
img = img.resize(size)
|
|
47
44
|
if transform:
|
|
48
45
|
img = transform(img)
|
|
49
|
-
|
|
46
|
+
|
|
47
|
+
try:
|
|
48
|
+
from transformers.image_processing_utils import BaseImageProcessor
|
|
49
|
+
|
|
50
|
+
if isinstance(transform, BaseImageProcessor):
|
|
51
|
+
img = torch.tensor(img.pixel_values[0]) # type: ignore[assignment,attr-defined]
|
|
52
|
+
except ImportError:
|
|
53
|
+
pass
|
|
54
|
+
if encoder:
|
|
50
55
|
img = img.unsqueeze(0) # type: ignore[attr-defined]
|
|
51
|
-
if
|
|
52
|
-
|
|
53
|
-
if not (
|
|
54
|
-
hasattr(open_clip_model, method_name)
|
|
55
|
-
and inspect.ismethod(getattr(open_clip_model, method_name))
|
|
56
|
-
):
|
|
57
|
-
raise ValueError(
|
|
58
|
-
"Unable to render Image: 'open_clip_model' doesn't support"
|
|
59
|
-
f" '{method_name}()'"
|
|
60
|
-
)
|
|
61
|
-
img = open_clip_model.encode_image(img)
|
|
56
|
+
if encoder:
|
|
57
|
+
img = encoder(img)
|
|
62
58
|
return img
|
|
63
59
|
|
|
64
60
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
):
|
|
73
|
-
"""
|
|
74
|
-
Read and optionally transform an image.
|
|
75
|
-
|
|
76
|
-
All kwargs are passed to `convert_image()`.
|
|
77
|
-
"""
|
|
78
|
-
self.mode = mode
|
|
79
|
-
self.size = size
|
|
80
|
-
self.transform = transform
|
|
81
|
-
self.open_clip_model = open_clip_model
|
|
82
|
-
super().__init__(ImageFile)
|
|
83
|
-
|
|
84
|
-
def __call__(self, img: Image.Image):
|
|
85
|
-
return convert_image(
|
|
86
|
-
img,
|
|
87
|
-
mode=self.mode,
|
|
88
|
-
size=self.size,
|
|
89
|
-
transform=self.transform,
|
|
90
|
-
open_clip_model=self.open_clip_model,
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def similarity_scores(
|
|
95
|
-
model: Any,
|
|
96
|
-
preprocess: Callable,
|
|
97
|
-
tokenizer: Callable,
|
|
98
|
-
image: Image.Image,
|
|
99
|
-
text: str,
|
|
100
|
-
prob: bool = False,
|
|
101
|
-
) -> list[float]:
|
|
61
|
+
def convert_images(
|
|
62
|
+
images: Union[Image.Image, list[Image.Image]],
|
|
63
|
+
mode: str = "RGB",
|
|
64
|
+
size: Optional[tuple[int, int]] = None,
|
|
65
|
+
transform: Optional[Callable] = None,
|
|
66
|
+
encoder: Optional[Callable] = None,
|
|
67
|
+
) -> Union[list[Image.Image], torch.Tensor]:
|
|
102
68
|
"""
|
|
103
|
-
|
|
69
|
+
Resize, transform, and otherwise convert one or more images.
|
|
104
70
|
|
|
105
71
|
Args:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
prob: Compute softmax probabilities across texts.
|
|
72
|
+
img (Image, list[Image]): PIL.Image object or list of objects.
|
|
73
|
+
mode (str): PIL.Image mode.
|
|
74
|
+
size (tuple[int, int]): Size in (width, height) pixels for resizing.
|
|
75
|
+
transform (Callable): Torchvision transform or huggingface processor to apply.
|
|
76
|
+
encoder (Callable): Encode image using model.
|
|
112
77
|
"""
|
|
78
|
+
if isinstance(images, Image.Image):
|
|
79
|
+
images = [images]
|
|
113
80
|
|
|
114
|
-
|
|
115
|
-
image = preprocess(image).unsqueeze(0)
|
|
116
|
-
text = tokenizer(text)
|
|
117
|
-
|
|
118
|
-
image_features = model.encode_image(image)
|
|
119
|
-
text_features = model.encode_text(text)
|
|
120
|
-
|
|
121
|
-
image_features /= image_features.norm(dim=-1, keepdim=True)
|
|
122
|
-
text_features /= text_features.norm(dim=-1, keepdim=True)
|
|
81
|
+
converted = [convert_image(img, mode, size, transform) for img in images]
|
|
123
82
|
|
|
124
|
-
|
|
83
|
+
if isinstance(converted[0], torch.Tensor):
|
|
84
|
+
converted = torch.stack(converted) # type: ignore[assignment,arg-type]
|
|
125
85
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
else:
|
|
129
|
-
scores = logits_per_text
|
|
86
|
+
if encoder:
|
|
87
|
+
converted = encoder(converted)
|
|
130
88
|
|
|
131
|
-
|
|
89
|
+
return converted # type: ignore[return-value]
|
datachain/lib/pytorch.py
CHANGED
|
@@ -116,10 +116,12 @@ class PytorchDataset(IterableDataset):
|
|
|
116
116
|
self.transform = None
|
|
117
117
|
if self.tokenizer:
|
|
118
118
|
for i, val in enumerate(row):
|
|
119
|
-
if isinstance(val, str)
|
|
119
|
+
if isinstance(val, str) or (
|
|
120
|
+
isinstance(val, list) and isinstance(val[0], str)
|
|
121
|
+
):
|
|
120
122
|
row[i] = convert_text(
|
|
121
123
|
val, self.tokenizer, self.tokenizer_kwargs
|
|
122
|
-
)
|
|
124
|
+
).squeeze(0) # type: ignore[union-attr]
|
|
123
125
|
yield row
|
|
124
126
|
|
|
125
127
|
@staticmethod
|
datachain/lib/signal_schema.py
CHANGED
|
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin
|
|
|
5
5
|
|
|
6
6
|
from pydantic import create_model
|
|
7
7
|
|
|
8
|
-
from datachain.lib.arrow import Source
|
|
9
8
|
from datachain.lib.feature import (
|
|
10
9
|
DATACHAIN_TO_TYPE,
|
|
11
10
|
DEFAULT_DELIMITER,
|
|
@@ -14,17 +13,13 @@ from datachain.lib.feature import (
|
|
|
14
13
|
convert_type_to_datachain,
|
|
15
14
|
)
|
|
16
15
|
from datachain.lib.feature_registry import Registry
|
|
17
|
-
from datachain.lib.file import File
|
|
18
|
-
from datachain.lib.image import ImageFile
|
|
16
|
+
from datachain.lib.file import File
|
|
19
17
|
from datachain.lib.utils import DataChainParamsError
|
|
20
|
-
from datachain.lib.webdataset import TarStream, WDSAllFile, WDSBasic
|
|
21
|
-
from datachain.lib.webdataset_laion import Laion, WDSLaion
|
|
22
18
|
|
|
23
19
|
if TYPE_CHECKING:
|
|
24
20
|
from datachain.catalog import Catalog
|
|
25
21
|
|
|
26
22
|
|
|
27
|
-
# TODO fix hardcoded Feature class names with://github.com/iterative/dvcx/issues/1625
|
|
28
23
|
NAMES_TO_TYPES = {
|
|
29
24
|
"int": int,
|
|
30
25
|
"str": str,
|
|
@@ -34,15 +29,6 @@ NAMES_TO_TYPES = {
|
|
|
34
29
|
"dict": dict,
|
|
35
30
|
"bytes": bytes,
|
|
36
31
|
"datetime": datetime,
|
|
37
|
-
"WDSLaion": WDSLaion,
|
|
38
|
-
"Laion": Laion,
|
|
39
|
-
"Source": Source,
|
|
40
|
-
"File": File,
|
|
41
|
-
"ImageFile": ImageFile,
|
|
42
|
-
"TextFile": TextFile,
|
|
43
|
-
"TarStream": TarStream,
|
|
44
|
-
"WDSBasic": WDSBasic,
|
|
45
|
-
"WDSAllFile": WDSAllFile,
|
|
46
32
|
}
|
|
47
33
|
|
|
48
34
|
|
|
@@ -150,7 +136,7 @@ class SignalSchema:
|
|
|
150
136
|
)
|
|
151
137
|
|
|
152
138
|
def slice(self, keys: Sequence[str]) -> "SignalSchema":
|
|
153
|
-
return SignalSchema({k:
|
|
139
|
+
return SignalSchema({k: self.values[k] for k in keys if k in self.values})
|
|
154
140
|
|
|
155
141
|
def row_to_features(self, row: Sequence, catalog: "Catalog") -> list[FeatureType]:
|
|
156
142
|
res = []
|
|
@@ -240,37 +226,6 @@ class SignalSchema:
|
|
|
240
226
|
if has_subtree and issubclass(type_, File):
|
|
241
227
|
yield ".".join(path)
|
|
242
228
|
|
|
243
|
-
def get_file_signals_values(self, row: dict[str, Any]) -> dict[str, Any]:
|
|
244
|
-
"""
|
|
245
|
-
Method that returns values with clean field names (without prefix) for
|
|
246
|
-
all file signals found in this schema for some row
|
|
247
|
-
Output example:
|
|
248
|
-
{
|
|
249
|
-
laion.file: {
|
|
250
|
-
"source": "s3://ldb-public",
|
|
251
|
-
"name": "dog.jpg",
|
|
252
|
-
...
|
|
253
|
-
},
|
|
254
|
-
meta.file: {
|
|
255
|
-
"source": "s3://datacomp",
|
|
256
|
-
"name": "cat.jpg",
|
|
257
|
-
...
|
|
258
|
-
}
|
|
259
|
-
}
|
|
260
|
-
"""
|
|
261
|
-
res = {}
|
|
262
|
-
|
|
263
|
-
for file_signals in self.get_file_signals():
|
|
264
|
-
prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
|
|
265
|
-
res[file_signals] = {
|
|
266
|
-
c_name.removeprefix(prefix): c_value
|
|
267
|
-
for c_name, c_value in row.items()
|
|
268
|
-
if c_name.startswith(prefix)
|
|
269
|
-
and DEFAULT_DELIMITER not in c_name.removeprefix(prefix)
|
|
270
|
-
}
|
|
271
|
-
|
|
272
|
-
return res
|
|
273
|
-
|
|
274
229
|
def create_model(self, name: str) -> type[Feature]:
|
|
275
230
|
fields = {key: (value, None) for key, value in self.values.items()}
|
|
276
231
|
|
datachain/lib/text.py
CHANGED
|
@@ -1,19 +1,15 @@
|
|
|
1
|
-
import inspect
|
|
2
1
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
|
3
2
|
|
|
4
|
-
from datachain.lib.file import TextFile
|
|
5
|
-
from datachain.lib.reader import FeatureReader
|
|
6
|
-
|
|
7
3
|
if TYPE_CHECKING:
|
|
8
|
-
|
|
4
|
+
import torch
|
|
9
5
|
|
|
10
6
|
|
|
11
7
|
def convert_text(
|
|
12
8
|
text: Union[str, list[str]],
|
|
13
9
|
tokenizer: Optional[Callable] = None,
|
|
14
10
|
tokenizer_kwargs: Optional[dict[str, Any]] = None,
|
|
15
|
-
|
|
16
|
-
):
|
|
11
|
+
encoder: Optional[Callable] = None,
|
|
12
|
+
) -> Union[str, list[str], "torch.Tensor"]:
|
|
17
13
|
"""
|
|
18
14
|
Tokenize and otherwise transform text.
|
|
19
15
|
|
|
@@ -21,18 +17,8 @@ def convert_text(
|
|
|
21
17
|
text (str): Text to convert.
|
|
22
18
|
tokenizer (Callable): Tokenizer to use to tokenize objects.
|
|
23
19
|
tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
|
|
24
|
-
|
|
20
|
+
encoder (Callable): Encode text using model.
|
|
25
21
|
"""
|
|
26
|
-
if open_clip_model:
|
|
27
|
-
method_name = "encode_text"
|
|
28
|
-
if not (
|
|
29
|
-
hasattr(open_clip_model, method_name)
|
|
30
|
-
and inspect.ismethod(getattr(open_clip_model, method_name))
|
|
31
|
-
):
|
|
32
|
-
raise ValueError(
|
|
33
|
-
f"TextColumn error: 'model' doesn't support '{method_name}()'"
|
|
34
|
-
)
|
|
35
|
-
|
|
36
22
|
if not tokenizer:
|
|
37
23
|
return text
|
|
38
24
|
|
|
@@ -43,38 +29,21 @@ def convert_text(
|
|
|
43
29
|
res = tokenizer(text, **tokenizer_kwargs)
|
|
44
30
|
else:
|
|
45
31
|
res = tokenizer(text)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
|
|
49
|
-
|
|
50
|
-
if not open_clip_model:
|
|
51
|
-
return tokens.squeeze(0)
|
|
52
|
-
|
|
53
|
-
return open_clip_model.encode_text(tokens).squeeze(0)
|
|
32
|
+
try:
|
|
33
|
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|
54
34
|
|
|
35
|
+
tokens = (
|
|
36
|
+
res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
|
|
37
|
+
)
|
|
38
|
+
except ImportError:
|
|
39
|
+
tokens = res
|
|
55
40
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
self,
|
|
59
|
-
fr_class: "FeatureLike" = TextFile,
|
|
60
|
-
tokenizer: Optional[Callable] = None,
|
|
61
|
-
tokenizer_kwargs: Optional[dict[str, Any]] = None,
|
|
62
|
-
open_clip_model: Optional[Any] = None,
|
|
63
|
-
):
|
|
64
|
-
"""
|
|
65
|
-
Read and optionally transform a text column.
|
|
41
|
+
if not encoder:
|
|
42
|
+
return tokens
|
|
66
43
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
self.open_clip_model = open_clip_model
|
|
72
|
-
super().__init__(fr_class)
|
|
44
|
+
try:
|
|
45
|
+
import torch
|
|
46
|
+
except ImportError:
|
|
47
|
+
"Missing dependency 'torch' needed to encode text."
|
|
73
48
|
|
|
74
|
-
|
|
75
|
-
return convert_text(
|
|
76
|
-
value,
|
|
77
|
-
tokenizer=self.tokenizer,
|
|
78
|
-
tokenizer_kwargs=self.tokenizer_kwargs,
|
|
79
|
-
open_clip_model=self.open_clip_model,
|
|
80
|
-
)
|
|
49
|
+
return encoder(torch.tensor(tokens))
|
datachain/lib/udf.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import sys
|
|
3
3
|
import traceback
|
|
4
|
-
from typing import TYPE_CHECKING, Callable
|
|
4
|
+
from typing import TYPE_CHECKING, Callable
|
|
5
5
|
|
|
6
6
|
from datachain.lib.feature import Feature
|
|
7
7
|
from datachain.lib.signal_schema import SignalSchema
|
|
8
|
-
from datachain.lib.
|
|
8
|
+
from datachain.lib.udf_signature import UdfSignature
|
|
9
|
+
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
|
|
9
10
|
from datachain.query import udf
|
|
10
11
|
|
|
11
12
|
if TYPE_CHECKING:
|
|
@@ -17,26 +18,68 @@ class UdfError(DataChainParamsError):
|
|
|
17
18
|
super().__init__(f"UDF error: {msg}")
|
|
18
19
|
|
|
19
20
|
|
|
20
|
-
class UDFBase:
|
|
21
|
+
class UDFBase(AbstractUDF):
|
|
21
22
|
is_input_batched = False
|
|
22
23
|
is_output_batched = False
|
|
23
24
|
is_input_grouped = False
|
|
24
25
|
|
|
25
|
-
def __init__(
|
|
26
|
-
self
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
26
|
+
def __init__(self):
|
|
27
|
+
self.params = None
|
|
28
|
+
self.output = None
|
|
29
|
+
self.params_spec = None
|
|
30
|
+
self.output_spec = None
|
|
31
|
+
self._contains_stream = None
|
|
32
|
+
self._catalog = None
|
|
33
|
+
self._func = None
|
|
34
|
+
|
|
35
|
+
def process(self, *args, **kwargs):
|
|
36
|
+
"""Processing function that needs to be defined by user"""
|
|
37
|
+
if not self._func:
|
|
38
|
+
raise NotImplementedError("UDF processing is not implemented")
|
|
39
|
+
return self._func(*args, **kwargs)
|
|
40
|
+
|
|
41
|
+
def setup(self):
|
|
42
|
+
"""Initialization process executed on each worker before processing begins.
|
|
43
|
+
This is needed for tasks like pre-loading ML models prior to scoring.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def teardown(self):
|
|
47
|
+
"""Teardown process executed on each process/worker after processing ends.
|
|
48
|
+
This is needed for tasks like closing connections to end-points.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
def _init(self, sign: UdfSignature, params: SignalSchema, func: Callable):
|
|
31
52
|
self.params = params
|
|
32
|
-
self.output =
|
|
33
|
-
self._func = func
|
|
53
|
+
self.output = sign.output_schema
|
|
34
54
|
|
|
35
|
-
params_spec = params.to_udf_spec()
|
|
55
|
+
params_spec = self.params.to_udf_spec()
|
|
36
56
|
self.params_spec = list(params_spec.keys())
|
|
37
|
-
self.output_spec = output.to_udf_spec()
|
|
57
|
+
self.output_spec = self.output.to_udf_spec()
|
|
38
58
|
|
|
39
|
-
self.
|
|
59
|
+
self._func = func
|
|
60
|
+
|
|
61
|
+
@classmethod
|
|
62
|
+
def _create(
|
|
63
|
+
cls,
|
|
64
|
+
target_class: type["UDFBase"],
|
|
65
|
+
sign: UdfSignature,
|
|
66
|
+
params: SignalSchema,
|
|
67
|
+
catalog,
|
|
68
|
+
) -> "UDFBase":
|
|
69
|
+
if isinstance(sign.func, AbstractUDF):
|
|
70
|
+
if not isinstance(sign.func, target_class): # type: ignore[unreachable]
|
|
71
|
+
raise UdfError(
|
|
72
|
+
f"cannot create UDF: provided UDF '{sign.func.__name__}'"
|
|
73
|
+
f" must be a child of target class '{target_class.__name__}'",
|
|
74
|
+
)
|
|
75
|
+
result = sign.func
|
|
76
|
+
func = None
|
|
77
|
+
else:
|
|
78
|
+
result = target_class()
|
|
79
|
+
func = sign.func
|
|
80
|
+
|
|
81
|
+
result._init(sign, params, func)
|
|
82
|
+
return result
|
|
40
83
|
|
|
41
84
|
@property
|
|
42
85
|
def name(self):
|
|
@@ -53,25 +96,10 @@ class UDFBase:
|
|
|
53
96
|
udf_wrapper = udf(self.params_spec, self.output_spec, batch=batch)
|
|
54
97
|
return udf_wrapper(self)
|
|
55
98
|
|
|
56
|
-
def bootstrap(self):
|
|
57
|
-
"""Initialization process executed on each worker before processing begins.
|
|
58
|
-
This is needed for tasks like pre-loading ML models prior to scoring.
|
|
59
|
-
"""
|
|
60
|
-
|
|
61
|
-
def teardown(self):
|
|
62
|
-
"""Teardown process executed on each process/worker after processing ends.
|
|
63
|
-
This is needed for tasks like closing connections to end-points.
|
|
64
|
-
"""
|
|
65
|
-
|
|
66
|
-
def process(self, *args, **kwargs):
|
|
67
|
-
if not self._func:
|
|
68
|
-
raise NotImplementedError("UDF processing is not implemented")
|
|
69
|
-
return self._func(*args, **kwargs)
|
|
70
|
-
|
|
71
99
|
def validate_results(self, results, *args, **kwargs):
|
|
72
100
|
return results
|
|
73
101
|
|
|
74
|
-
def __call__(self, *rows
|
|
102
|
+
def __call__(self, *rows):
|
|
75
103
|
if self.is_input_grouped:
|
|
76
104
|
objs = self._parse_grouped_rows(rows)
|
|
77
105
|
else:
|
datachain/lib/udf_signature.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Callable, Optional, Union, get_args, get_origin
|
|
|
5
5
|
|
|
6
6
|
from datachain.lib.feature import Feature, FeatureType, FeatureTypeNames
|
|
7
7
|
from datachain.lib.signal_schema import SignalSchema
|
|
8
|
-
from datachain.lib.utils import DataChainParamsError
|
|
8
|
+
from datachain.lib.utils import AbstractUDF, DataChainParamsError
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class UdfSignatureError(DataChainParamsError):
|
|
@@ -49,10 +49,13 @@ class UdfSignature:
|
|
|
49
49
|
else:
|
|
50
50
|
if func is None:
|
|
51
51
|
raise UdfSignatureError(chain, "user function is not defined")
|
|
52
|
+
|
|
52
53
|
udf_func = func
|
|
53
54
|
signal_name = None
|
|
55
|
+
|
|
54
56
|
if not callable(udf_func):
|
|
55
|
-
raise UdfSignatureError(chain, f"
|
|
57
|
+
raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
|
|
58
|
+
|
|
56
59
|
func_params_map_sign, func_outs_sign, is_iterator = (
|
|
57
60
|
UdfSignature._func_signature(chain, udf_func)
|
|
58
61
|
)
|
|
@@ -108,13 +111,6 @@ class UdfSignature:
|
|
|
108
111
|
if isinstance(output, str):
|
|
109
112
|
output = [output]
|
|
110
113
|
if isinstance(output, Sequence):
|
|
111
|
-
if not func_outs_sign:
|
|
112
|
-
raise UdfSignatureError(
|
|
113
|
-
chain,
|
|
114
|
-
"output types are not specified. Specify types in 'output' as"
|
|
115
|
-
" a dict or as function return value hint.",
|
|
116
|
-
)
|
|
117
|
-
|
|
118
114
|
if len(func_outs_sign) != len(output):
|
|
119
115
|
raise UdfSignatureError(
|
|
120
116
|
chain,
|
|
@@ -158,8 +154,13 @@ class UdfSignature:
|
|
|
158
154
|
|
|
159
155
|
@staticmethod
|
|
160
156
|
def _func_signature(
|
|
161
|
-
chain: str,
|
|
157
|
+
chain: str, udf_func: Callable
|
|
162
158
|
) -> tuple[dict[str, type], Sequence[type], bool]:
|
|
159
|
+
if isinstance(udf_func, AbstractUDF):
|
|
160
|
+
func = udf_func.process # type: ignore[unreachable]
|
|
161
|
+
else:
|
|
162
|
+
func = udf_func
|
|
163
|
+
|
|
163
164
|
sign = inspect.signature(func)
|
|
164
165
|
|
|
165
166
|
input_map = {prm.name: prm.annotation for prm in sign.parameters.values()}
|
datachain/lib/utils.py
CHANGED
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AbstractUDF(ABC):
|
|
5
|
+
@abstractmethod
|
|
6
|
+
def process(self, *args, **kwargs):
|
|
7
|
+
pass
|
|
8
|
+
|
|
9
|
+
@abstractmethod
|
|
10
|
+
def setup(self):
|
|
11
|
+
pass
|
|
12
|
+
|
|
13
|
+
@abstractmethod
|
|
14
|
+
def teardown(self):
|
|
15
|
+
pass
|
|
16
|
+
|
|
17
|
+
|
|
1
18
|
class DataChainError(Exception):
|
|
2
19
|
def __init__(self, message):
|
|
3
20
|
super().__init__(message)
|
datachain/lib/webdataset.py
CHANGED
|
@@ -2,6 +2,7 @@ import hashlib
|
|
|
2
2
|
import json
|
|
3
3
|
import tarfile
|
|
4
4
|
from collections.abc import Iterator, Sequence
|
|
5
|
+
from pathlib import Path
|
|
5
6
|
from typing import (
|
|
6
7
|
Any,
|
|
7
8
|
Callable,
|
|
@@ -240,10 +241,9 @@ class TarStream(File):
|
|
|
240
241
|
def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
|
|
241
242
|
builder = Builder(stream, core_extensions, spec, tar, encoding)
|
|
242
243
|
|
|
243
|
-
for item in tar.getmembers():
|
|
244
|
+
for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
|
|
244
245
|
if not item.isfile():
|
|
245
246
|
continue
|
|
246
|
-
|
|
247
247
|
try:
|
|
248
248
|
builder.add(item)
|
|
249
249
|
except StopIteration:
|