datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- datachain/__init__.py +17 -8
- datachain/catalog/catalog.py +5 -5
- datachain/cli.py +0 -2
- datachain/data_storage/schema.py +5 -5
- datachain/data_storage/sqlite.py +1 -1
- datachain/data_storage/warehouse.py +7 -7
- datachain/lib/arrow.py +25 -8
- datachain/lib/clip.py +6 -11
- datachain/lib/convert/__init__.py +0 -0
- datachain/lib/convert/flatten.py +67 -0
- datachain/lib/convert/type_converter.py +96 -0
- datachain/lib/convert/unflatten.py +69 -0
- datachain/lib/convert/values_to_tuples.py +85 -0
- datachain/lib/data_model.py +74 -0
- datachain/lib/dc.py +225 -168
- datachain/lib/file.py +41 -41
- datachain/lib/gpt4_vision.py +1 -9
- datachain/lib/hf_image_to_text.py +9 -17
- datachain/lib/hf_pipeline.py +4 -12
- datachain/lib/image.py +2 -18
- datachain/lib/image_transform.py +0 -1
- datachain/lib/iptc_exif_xmp.py +8 -15
- datachain/lib/meta_formats.py +1 -5
- datachain/lib/model_store.py +77 -0
- datachain/lib/pytorch.py +9 -21
- datachain/lib/signal_schema.py +139 -60
- datachain/lib/text.py +5 -16
- datachain/lib/udf.py +114 -30
- datachain/lib/udf_signature.py +5 -5
- datachain/lib/webdataset.py +3 -3
- datachain/lib/webdataset_laion.py +2 -3
- datachain/node.py +4 -4
- datachain/query/batch.py +1 -1
- datachain/query/dataset.py +51 -178
- datachain/query/dispatch.py +43 -30
- datachain/query/udf.py +46 -26
- datachain/remote/studio.py +1 -9
- datachain/torch/__init__.py +21 -0
- datachain/utils.py +39 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
- datachain/image/__init__.py +0 -3
- datachain/lib/cached_stream.py +0 -38
- datachain/lib/claude.py +0 -69
- datachain/lib/feature.py +0 -412
- datachain/lib/feature_registry.py +0 -51
- datachain/lib/feature_utils.py +0 -154
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/lib/signal_schema.py
CHANGED

@@ -1,20 +1,30 @@
 import copy
 from collections.abc import Iterator, Sequence
+from dataclasses import dataclass
 from datetime import datetime
-from typing import …
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    Callable,
+    Literal,
+    Optional,
+    Union,
+    get_args,
+    get_origin,
+)

-from pydantic import create_model
+from pydantic import BaseModel, create_model
+from typing_extensions import Literal as LiteralEx

-from datachain.lib. …
-…
-…
-…
-    FeatureType,
-    convert_type_to_datachain,
-)
-from datachain.lib.feature_registry import Registry
+from datachain.lib.convert.flatten import DATACHAIN_TO_TYPE
+from datachain.lib.convert.type_converter import convert_to_db_type
+from datachain.lib.convert.unflatten import unflatten_to_json_pos
+from datachain.lib.data_model import DataModel, DataType
 from datachain.lib.file import File
+from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
+from datachain.query.schema import DEFAULT_DELIMITER

 if TYPE_CHECKING:
     from datachain.catalog import Catalog
@@ -56,10 +66,16 @@ class SignalResolvingTypeError(SignalResolvingError):
         )


+@dataclass
 class SignalSchema:
+    values: dict[str, DataType]
+    tree: dict[str, Any]
+    setup_func: dict[str, Callable]
+    setup_values: Optional[dict[str, Callable]]
+
     def __init__(
         self,
-        values: dict[str, …
+        values: dict[str, DataType],
         setup: Optional[dict[str, Callable]] = None,
     ):
         self.values = values
@@ -85,7 +101,7 @@ class SignalSchema:

     @staticmethod
     def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
-        signals: dict[str, …
+        signals: dict[str, DataType] = {}
         for field, type_ in col_types.items():
             type_ = DATACHAIN_TO_TYPE.get(type_, None)
             if type_ is None:
@@ -99,15 +115,16 @@ class SignalSchema:
     def serialize(self) -> dict[str, str]:
         signals = {}
         for name, fr_type in self.values.items():
-            if …
-…
+            if (fr := ModelStore.to_pydantic(fr_type)) is not None:
+                ModelStore.add(fr)
+                signals[name] = ModelStore.get_name(fr)
             else:
                 orig = get_origin(fr_type)
                 args = get_args(fr_type)
                 # Check if fr_type is Optional
                 if orig == Union and len(args) == 2 and (type(None) in args):
                     fr_type = args[0]
-                signals[name] = fr_type.__name__
+                signals[name] = str(fr_type.__name__)  # type: ignore[union-attr]
         return signals

     @staticmethod
@@ -115,80 +132,93 @@ class SignalSchema:
         if not isinstance(schema, dict):
             raise SignalSchemaError(f"cannot deserialize signal schema: {schema}")

-        signals: dict[str, …
+        signals: dict[str, DataType] = {}
         for signal, type_name in schema.items():
             try:
                 fr = NAMES_TO_TYPES.get(type_name)
                 if not fr:
-                    type_name, version = …
-                    fr = …
+                    type_name, version = ModelStore.parse_name_version(type_name)
+                    fr = ModelStore.get(type_name, version)
+
+                    if not fr:
+                        raise SignalSchemaError(
+                            f"cannot deserialize '{signal}': "
+                            f"unknown type '{type_name}'."
+                            f" Try to add it with `ModelStore.add({type_name})`."
+                        )
             except TypeError as err:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': {err}"
                 ) from err
-
-            if not fr:
-                raise SignalSchemaError(
-                    f"cannot deserialize '{signal}': unsupported type '{type_name}'"
-                )
             signals[signal] = fr

         return SignalSchema(signals)

-    def to_udf_spec(self) -> dict[str, …
+    def to_udf_spec(self) -> dict[str, type]:
         res = {}
         for path, type_, has_subtree, _ in self.get_flat_tree():
             if path[0] in self.setup_func:
                 continue
             if not has_subtree:
                 db_name = DEFAULT_DELIMITER.join(path)
-                res[db_name] = …
+                res[db_name] = convert_to_db_type(type_)
         return res

-    def row_to_objs(self, row: Sequence[Any]) -> list[ …
+    def row_to_objs(self, row: Sequence[Any]) -> list[DataType]:
         self._init_setup_values()

         objs = []
         pos = 0
         for name, fr_type in self.values.items():
-            if val := self.setup_values.get(name, None):
+            if self.setup_values and (val := self.setup_values.get(name, None)):
                 objs.append(val)
-            elif …
-                j, pos = …
-                objs.append( …
+            elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
+                j, pos = unflatten_to_json_pos(fr, row, pos)
+                objs.append(fr(**j))  # type: ignore[arg-type]
             else:
                 objs.append(row[pos])
                 pos += 1
         return objs  # type: ignore[return-value]

     def contains_file(self) -> bool:
-…
-… fr. …
-…
-…
-…
+        for type_ in self.values.values():
+            if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass(
+                fr, File
+            ):
+                return True
+
+        return False

     def slice(
         self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
     ) -> "SignalSchema":
+        # Make new schema that combines current schema and setup signals
         setup = setup or {}
         setup_no_types = dict.fromkeys(setup.keys(), str)
-        union = self.values | setup_no_types
-…
+        union = SignalSchema(self.values | setup_no_types)
+        # Slice combined schema by keys
+        schema = {}
+        for k in keys:
+            try:
+                schema[k] = union._find_in_tree(k.split("."))
+            except SignalResolvingError:
+                pass
         return SignalSchema(schema, setup)

-    def row_to_features( …
+    def row_to_features(
+        self, row: Sequence, catalog: "Catalog", cache: bool = False
+    ) -> list[DataType]:
         res = []
         pos = 0
         for fr_cls in self.values.values():
-            if …
+            if (fr := ModelStore.to_pydantic(fr_cls)) is None:
                 res.append(row[pos])
                 pos += 1
             else:
-                json, pos = …
-                obj = …
+                json, pos = unflatten_to_json_pos(fr, row, pos)  # type: ignore[union-attr]
+                obj = fr(**json)
                 if isinstance(obj, File):
-                    obj._set_stream(catalog)
+                    obj._set_stream(catalog, caching_enabled=cache)
                 res.append(obj)
         return res

@@ -208,7 +238,7 @@ class SignalSchema:

         return SignalSchema(schema)

-    def _find_in_tree(self, path: list[str]) -> …
+    def _find_in_tree(self, path: list[str]) -> DataType:
        curr_tree = self.tree
        curr_type = None
        i = 0
@@ -265,24 +295,23 @@ class SignalSchema:
             if has_subtree and issubclass(type_, File):
                 yield ".".join(path)

-    def create_model(self, name: str) -> type[ …
+    def create_model(self, name: str) -> type[DataModel]:
         fields = {key: (value, None) for key, value in self.values.items()}

         return create_model(
             name,
-            __base__=( …
+            __base__=(DataModel,),  # type: ignore[call-overload]
             **fields,
         )

     @staticmethod
-    def _build_tree( …
-…
-…
-…
-…
-…
-…
-        return res
+    def _build_tree(
+        values: dict[str, DataType],
+    ) -> dict[str, tuple[DataType, Optional[dict]]]:
+        return {
+            name: (val, SignalSchema._build_tree_for_type(val))
+            for name, val in values.items()
+        }

     def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
         yield from self._get_flat_tree(self.tree, [], 0)
@@ -305,27 +334,77 @@ class SignalSchema:

         if get_origin(type_) is list:
             args = get_args(type_)
-            if len(args) > 0 and …
+            if len(args) > 0 and ModelStore.is_pydantic(args[0]):
                 sub_schema = SignalSchema({"* list of": args[0]})
                 sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

+    def get_headers_with_length(self):
+        paths = [
+            path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree
+        ]
+        max_length = max([len(path) for path in paths], default=0)
+        return [
+            path + [""] * (max_length - len(path)) if len(path) < max_length else path
+            for path in paths
+        ], max_length
+
+    def __or__(self, other):
+        return self.__class__(self.values | other.values)
+
+    def __contains__(self, name: str):
+        return name in self.values
+
+    def remove(self, name: str):
+        return self.values.pop(name)
+
     @staticmethod
-    def _type_to_str(type_):
-…
+    def _type_to_str(type_):  # noqa: PLR0911
+        origin = get_origin(type_)
+
+        if origin == Union:
             args = get_args(type_)
             formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
             return f"Union[{formatted_types}]"
-        if …
+        if origin == Optional:
             args = get_args(type_)
             type_str = SignalSchema._type_to_str(args[0])
             return f"Optional[{type_str}]"
-        if …
+        if origin is list:
             args = get_args(type_)
             type_str = SignalSchema._type_to_str(args[0])
             return f"list[{type_str}]"
-        if …
+        if origin is dict:
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0])
+            type_str = SignalSchema._type_to_str(args[0]) if len(args) > 0 else ""
             vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else ""
             return f"dict[{type_str}{vals}]"
+        if origin == Annotated:
+            args = get_args(type_)
+            return SignalSchema._type_to_str(args[0])
+        if origin in (Literal, LiteralEx):
+            return "Literal"
         return type_.__name__
+
+    @staticmethod
+    def _build_tree_for_type(
+        model: DataType,
+    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
+        if (fr := ModelStore.to_pydantic(model)) is not None:
+            return SignalSchema._build_tree_for_model(fr)
+        return None
+
+    @staticmethod
+    def _build_tree_for_model(
+        model: type[BaseModel],
+    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
+        res: dict[str, tuple[DataType, Optional[dict]]] = {}
+
+        for name, f_info in model.model_fields.items():
+            anno = f_info.annotation
+            if (fr := ModelStore.to_pydantic(anno)) is not None:
+                subtree = SignalSchema._build_tree_for_model(fr)
+            else:
+                subtree = None
+            res[name] = (anno, subtree)  # type: ignore[assignment]
+
+        return res
datachain/lib/text.py
CHANGED

@@ -1,7 +1,7 @@
-from typing import …
+from typing import Any, Callable, Optional, Union

-…
-…
+import torch
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase


 def convert_text(
@@ -9,7 +9,7 @@ def convert_text(
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
     encoder: Optional[Callable] = None,
-) -> Union[str, list[str], …
+) -> Union[str, list[str], torch.Tensor]:
     """
     Tokenize and otherwise transform text.

@@ -29,21 +29,10 @@ def convert_text(
         res = tokenizer(text, **tokenizer_kwargs)
     else:
         res = tokenizer(text)
-    try:
-        from transformers.tokenization_utils_base import PreTrainedTokenizerBase

-…
-            res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
-        )
-    except ImportError:
-        tokens = res
+    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res

     if not encoder:
         return tokens

-    try:
-        import torch
-    except ImportError:
-        … "Missing dependency 'torch' needed to encode text." …
-…
     return encoder(torch.tensor(tokens))
datachain/lib/udf.py
CHANGED

@@ -1,16 +1,29 @@
-import inspect
 import sys
 import traceback
-from …
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING, Callable, Optional

-from …
+from fsspec.callbacks import DEFAULT_CALLBACK, Callback
+from pydantic import BaseModel
+
+from datachain.dataset import RowDict
+from datachain.lib.convert.flatten import flatten
+from datachain.lib.convert.unflatten import unflatten_to_json
+from datachain.lib.data_model import FileBasic
+from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query import …
+from datachain.query.batch import RowBatch
+from datachain.query.schema import ColumnParameter
+from datachain.query.udf import UDFBase as _UDFBase
+from datachain.query.udf import UDFProperties, UDFResult

 if TYPE_CHECKING:
-    from …
+    from typing_extensions import Self
+
+    from datachain.catalog import Catalog
+    from datachain.query.batch import BatchingResult


 class UdfError(DataChainParamsError):
@@ -18,10 +31,67 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")


+class UDFAdapter(_UDFBase):
+    def __init__(
+        self,
+        inner: "UDFBase",
+        properties: UDFProperties,
+    ):
+        self.inner = inner
+        super().__init__(properties)
+
+    def run(
+        self,
+        udf_inputs: "Iterable[BatchingResult]",
+        catalog: "Catalog",
+        is_generator: bool,
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable["UDFResult"]]:
+        self.inner._catalog = catalog
+        if hasattr(self.inner, "setup") and callable(self.inner.setup):
+            self.inner.setup()
+
+        for batch in udf_inputs:
+            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
+            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output
+
+        if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
+            self.inner.teardown()
+
+    def run_once(
+        self,
+        catalog: "Catalog",
+        arg: "BatchingResult",
+        is_generator: bool = False,
+        cache: bool = False,
+        cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterable[UDFResult]:
+        if isinstance(arg, RowBatch):
+            udf_inputs = [
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
+            ]
+            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
+            udf_outputs = self.inner(*udf_inputs, cache=cache, download_cb=cb)
+            if not is_generator:
+                # udf_outputs is generator already if is_generator=True
+                udf_outputs = [udf_outputs]
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
+
 class UDFBase(AbstractUDF):
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False
+    params_spec: Optional[list[str]]

     def __init__(self):
         self.params = None
@@ -48,7 +118,12 @@ class UDFBase(AbstractUDF):
         This is needed for tasks like closing connections to end-points.
         """

-    def _init( …
+    def _init(
+        self,
+        sign: UdfSignature,
+        params: SignalSchema,
+        func: Callable,
+    ):
         self.params = params
         self.output = sign.output_schema

@@ -61,20 +136,19 @@ class UDFBase(AbstractUDF):
     @classmethod
     def _create(
         cls,
-        target_class: type["UDFBase"],
         sign: UdfSignature,
         params: SignalSchema,
-    ) -> " …
+    ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
-            if not isinstance(sign.func, …
+            if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
                     f"cannot create UDF: provided UDF '{sign.func.__name__}'"
-                    f" must be a child of target class '{ …
+                    f" must be a child of target class '{cls.__name__}'",
                 )
             result = sign.func
             func = None
         else:
-            result = …
+            result = cls()
             func = sign.func

         result._init(sign, params, func)
@@ -91,18 +165,21 @@ class UDFBase(AbstractUDF):
     def catalog(self):
         return self._catalog

-    def to_udf_wrapper(self, batch=1) -> …
-…
-…
+    def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
+        assert self.params_spec is not None
+        properties = UDFProperties(
+            [ColumnParameter(p) for p in self.params_spec], self.output_spec, batch
+        )
+        return UDFAdapter(self, properties)

     def validate_results(self, results, *args, **kwargs):
         return results

-    def __call__(self, *rows):
+    def __call__(self, *rows, cache, download_cb):
         if self.is_input_grouped:
-            objs = self._parse_grouped_rows(rows)
+            objs = self._parse_grouped_rows(rows[0], cache, download_cb)
         else:
-            objs = self._parse_rows(rows)
+            objs = self._parse_rows(rows, cache, download_cb)

         if not self.is_input_batched:
             objs = objs[0]
@@ -117,15 +194,19 @@ class UDFBase(AbstractUDF):
             for tuple_ in result_objs:
                 flat = []
                 for obj in tuple_:
-                    if isinstance(obj, …
-                        flat.extend( …
+                    if isinstance(obj, BaseModel):
+                        flat.extend(flatten(obj))
                     else:
                         flat.append(obj)
                 res.append(flat)
         else:
             # Generator expression is required, otherwise the value will be materialized
             res = (
-…
+                flatten(obj)
+                if isinstance(obj, BaseModel)
+                else obj
+                if isinstance(obj, tuple)
+                else (obj,)
                 for obj in result_objs
             )

@@ -139,24 +220,25 @@ class UDFBase(AbstractUDF):

         return res

-    def _parse_rows(self, rows):
+    def _parse_rows(self, rows, cache, download_cb):
         if not self.is_input_batched:
             rows = [rows]
         objs = []
         for row in rows:
             obj_row = self.params.row_to_objs(row)
             for obj in obj_row:
-                if isinstance(obj, …
-                    obj._set_stream( …
+                if isinstance(obj, FileBasic):
+                    obj._set_stream(
+                        self._catalog, caching_enabled=cache, download_cb=download_cb
+                    )
             objs.append(obj_row)
         return objs

-    def _parse_grouped_rows(self, …
-        group = rows[0]
+    def _parse_grouped_rows(self, group, cache, download_cb):
         spec_map = {}
         output_map = {}
         for name, (anno, subtree) in self.params.tree.items():
-            if …
+            if ModelStore.is_pydantic(anno):
                 length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
             else:
                 length = 1
@@ -169,13 +251,15 @@ class UDFBase(AbstractUDF):
                 slice = flat_obj[position : position + length]
                 position += length

-                if …
-                    obj = cls(**cls …
+                if ModelStore.is_pydantic(cls):
+                    obj = cls(**unflatten_to_json(cls, slice))
                 else:
                     obj = slice[0]

-                if isinstance(obj, …
-                    obj._set_stream( …
+                if isinstance(obj, FileBasic):
+                    obj._set_stream(
+                        self._catalog, caching_enabled=cache, download_cb=download_cb
+                    )
                 output_map[signal].append(obj)

         return list(output_map.values())
datachain/lib/udf_signature.py
CHANGED

@@ -3,7 +3,7 @@ from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
 from typing import Callable, Optional, Union, get_args, get_origin

-from datachain.lib. …
+from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.utils import AbstractUDF, DataChainParamsError

@@ -29,7 +29,7 @@ class UdfSignature:
         signal_map: dict[str, Callable],
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, …
+        output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
@@ -127,15 +127,15 @@ class UdfSignature:
                    f"output signal '{key}' has type '{type(key)}'"
                    " while 'str' is expected",
                )
-                if not …
+                if not is_chain_type(value):
                    raise UdfSignatureError(
                        chain,
                        f"output type '{value.__name__}' of signal '{key}' is not"
-                        f" supported. Please use Feature types: { …
+                        f" supported. Please use Feature types: {DataTypeNames}",
                    )

            udf_output_map = output
-        elif …
+        elif is_chain_type(output):
            udf_output_map = {signal_name: output}
        else:
            raise UdfSignatureError(
datachain/lib/webdataset.py
CHANGED

@@ -15,7 +15,7 @@ from typing import (

 from pydantic import Field

-from datachain.lib. …
+from datachain.lib.data_model import DataModel
 from datachain.lib.file import File, TarVFile
 from datachain.lib.utils import DataChainError

@@ -46,7 +46,7 @@ class UnknownFileExtensionError(WDSError):
         super().__init__(tar_stream, f"unknown extension '{ext}' for file '{name}'")


-class WDSBasic( …
+class WDSBasic(DataModel):
     file: File


@@ -75,7 +75,7 @@ class WDSAllFile(WDSBasic):
     cbor: Optional[bytes] = Field(default=None)


-class WDSReadableSubclass( …
+class WDSReadableSubclass(DataModel):
     @staticmethod
     def _reader(builder, item: tarfile.TarInfo) -> "WDSReadableSubclass":
         raise NotImplementedError