datachain 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Potentially problematic release. This version of datachain might be problematic.
- datachain/catalog/catalog.py +8 -0
- datachain/data_storage/metastore.py +20 -1
- datachain/data_storage/sqlite.py +24 -32
- datachain/lib/arrow.py +64 -19
- datachain/lib/convert/values_to_tuples.py +2 -2
- datachain/lib/data_model.py +1 -1
- datachain/lib/dc.py +131 -12
- datachain/lib/signal_schema.py +6 -6
- datachain/lib/udf.py +208 -160
- datachain/lib/udf_signature.py +8 -6
- datachain/query/batch.py +0 -10
- datachain/query/dataset.py +7 -7
- datachain/query/dispatch.py +2 -14
- datachain/query/session.py +42 -0
- datachain/sql/functions/string.py +12 -0
- datachain/sql/sqlite/base.py +10 -5
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/METADATA +1 -1
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/RECORD +22 -23
- datachain/query/udf.py +0 -126
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/LICENSE +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/WHEEL +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.5.0.dist-info → datachain-0.6.0.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py
CHANGED
@@ -1,31 +1,32 @@
 import sys
 import traceback
-from …
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Optional

+import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel

 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.…
+from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
-from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import …
-…
-…
-…
+from datachain.query.batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+)

 if TYPE_CHECKING:
-    from collections.abc import Iterable, Iterator, Sequence
-
     from typing_extensions import Self

     from datachain.catalog import Catalog
-    from datachain.…
-    from datachain.query.…
+    from datachain.lib.udf_signature import UdfSignature
+    from datachain.query.batch import RowsOutput


 class UdfError(DataChainParamsError):
@@ -33,14 +34,47 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")


-…
-…
-…
-…
-…
-…
-…
-…
+ColumnType = Any
+
+# Specification for the output of a UDF
+UDFOutputSpec = Mapping[str, ColumnType]
+
+# Result type when calling the UDF wrapper around the actual
+# Python function / class implementing it.
+UDFResult = dict[str, Any]
+
+
+@attrs.define
+class UDFProperties:
+    udf: "UDFAdapter"
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        return self.udf.get_batching(use_partitioning)
+
+    @property
+    def batch(self):
+        return self.udf.batch
+
+
+@attrs.define(slots=False)
+class UDFAdapter:
+    inner: "UDFBase"
+    output: UDFOutputSpec
+    batch: int = 1
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        if use_partitioning:
+            return Partition()
+        if self.batch == 1:
+            return NoBatching()
+        if self.batch > 1:
+            return Batch(self.batch)
+        raise ValueError(f"invalid batch size {self.batch}")
+
+    @property
+    def properties(self):
+        # For backwards compatibility.
+        return UDFProperties(self)

     def run(
         self,
@@ -51,48 +85,16 @@ class UDFAdapter(_UDFBase):
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> …
-        self.inner.…
-        if hasattr(self.inner, "setup") and callable(self.inner.setup):
-            self.inner.setup()
-
-        yield from super().run(
+    ) -> Iterator[Iterable[UDFResult]]:
+        yield from self.inner.run(
             udf_fields,
             udf_inputs,
             catalog,
-            is_generator,
             cache,
             download_cb,
             processed_cb,
         )

-        if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
-            self.inner.teardown()
-
-    def run_once(
-        self,
-        catalog: "Catalog",
-        arg: "UDFInput",
-        is_generator: bool = False,
-        cache: bool = False,
-        cb: Callback = DEFAULT_CALLBACK,
-    ) -> "Iterable[UDFResult]":
-        if isinstance(arg, UDFInputBatch):
-            udf_inputs = [
-                self.bind_parameters(catalog, row, cache=cache, cb=cb)
-                for row in arg.rows
-            ]
-            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
-            return self._process_results(arg.rows, udf_outputs, is_generator)
-        if isinstance(arg, RowDict):
-            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.inner(*udf_inputs, cache=cache, download_cb=cb)
-            if not is_generator:
-                # udf_outputs is generator already if is_generator=True
-                udf_outputs = [udf_outputs]
-            return self._process_results([arg], udf_outputs, is_generator)
-        raise ValueError(f"Unexpected UDF argument: {arg}")
-

 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -142,18 +144,13 @@ class UDFBase(AbstractUDF):
     ```
     """

-    is_input_batched = False
     is_output_batched = False
-
-    params_spec: Optional[list[str]]
+    catalog: "Optional[Catalog]"

     def __init__(self):
-        self.params = None
+        self.params: Optional[SignalSchema] = None
         self.output = None
-        self.…
-        self.output_spec = None
-        self._contains_stream = None
-        self._catalog = None
+        self.catalog = None
         self._func = None

     def process(self, *args, **kwargs):
@@ -174,29 +171,24 @@ class UDFBase(AbstractUDF):

     def _init(
         self,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
-        func: Callable,
+        func: Optional[Callable],
     ):
         self.params = params
         self.output = sign.output_schema
-
-        params_spec = self.params.to_udf_spec()
-        self.params_spec = list(params_spec.keys())
-        self.output_spec = self.output.to_udf_spec()
-
         self._func = func

     @classmethod
     def _create(
         cls,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
     ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
             if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
-                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f"cannot create UDF: provided UDF '{type(sign.func).__name__}'"
                     f" must be a child of target class '{cls.__name__}'",
                 )
             result = sign.func
@@ -212,57 +204,27 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__

-    def set_catalog(self, catalog):
-        self._catalog = catalog.copy(db=False)
-
     @property
-    def …
-        return self.…
+    def signal_names(self) -> Iterable[str]:
+        return self.output.to_udf_spec().keys()

     def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
-        …
-        …
-        …
+        return UDFAdapter(
+            self,
+            self.output.to_udf_spec(),
+            batch,
         )
-        return UDFAdapter(self, properties)
-
-    def validate_results(self, results, *args, **kwargs):
-        return results
-
-    def __call__(self, *rows, cache, download_cb):
-        if self.is_input_grouped:
-            objs = self._parse_grouped_rows(rows[0], cache, download_cb)
-        elif self.is_input_batched:
-            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
-        else:
-            objs = self._parse_rows([rows], cache, download_cb)[0]

-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-            assert (
-                len(res) == 1
-            ), f"{self.name} returns {len(res)} rows while it's not batched"
-            if isinstance(res[0], tuple):
-                res = res[0]
-        elif (
-            self.is_input_batched
-            and self.is_output_batched
-            and not self.is_input_grouped
-        ):
-            res = list(res)
-            assert len(res) == len(
-                rows[0]
-            ), f"{self.name} returns {len(res)} rows while len(rows[0]) expected"
-
-        return res
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Any]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        raise NotImplementedError

     def _flatten_row(self, row):
         if len(self.output.values) > 1 and not isinstance(row, BaseModel):
@@ -276,47 +238,28 @@ class UDFBase(AbstractUDF):
     def _obj_to_list(obj):
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]

-    def …
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-                length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
-            else:
-                length = 1
-            spec_map[name] = anno, length
-            output_map[name] = []
-
-        for flat_obj in group:
-            position = 0
-            for signal, (cls, length) in spec_map.items():
-                slice = flat_obj[position : position + length]
-                position += length
-
-                if ModelStore.is_pydantic(cls):
-                    obj = cls(**unflatten_to_json(cls, slice))
-                else:
-                    obj = slice[0]
-
-                if isinstance(obj, File):
-                    obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
-                    )
-                output_map[signal].append(obj)
+    def _parse_row(
+        self, row_dict: RowDict, cache: bool, download_cb: Callback
+    ) -> list[DataValue]:
+        assert self.params
+        row = [row_dict[p] for p in self.params.to_udf_spec()]
+        obj_row = self.params.row_to_objs(row)
+        for obj in obj_row:
+            if isinstance(obj, File):
+                assert self.catalog is not None
+                obj._set_stream(
+                    self.catalog, caching_enabled=cache, download_cb=download_cb
+                )
+        return obj_row
+
+    def _prepare_row(self, row, udf_fields, cache, download_cb):
+        row_dict = RowDict(zip(udf_fields, row))
+        return self._parse_row(row_dict, cache, download_cb)

-…
+    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+        row_dict = RowDict(zip(udf_fields, row))
+        udf_input = self._parse_row(row_dict, cache, download_cb)
+        return row_dict["sys__id"], *udf_input

     def process_safe(self, obj_rows):
         try:
@@ -336,23 +279,128 @@ class UDFBase(AbstractUDF):
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""

+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Sequence[Any]]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for row in udf_inputs:
+            id_, *udf_args = self._prepare_row_and_id(
+                row, udf_fields, cache, download_cb
+            )
+            result_objs = self.process_safe(udf_args)
+            udf_output = self._flatten_row(result_objs)
+            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+            processed_cb.relative_update(1)
+            yield output
+
+        self.teardown()
+

 class BatchMapper(UDFBase):
     """Inherit from this class to pass to `DataChain.batch_map()`."""

-    is_input_batched = True
     is_output_batched = True

+    def run(
+        self,
+        udf_fields: Sequence[str],
+        udf_inputs: Iterable[RowsOutputBatch],
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for batch in udf_inputs:
+            n_rows = len(batch.rows)
+            row_ids, *udf_args = zip(
+                *[
+                    self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+                    for row in batch.rows
+                ]
+            )
+            result_objs = list(self.process_safe(udf_args))
+            n_objs = len(result_objs)
+            assert (
+                n_objs == n_rows
+            ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = [
+                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+                for row_id, signals in zip(row_ids, udf_outputs)
+            ]
+            processed_cb.relative_update(n_rows)
+            yield output
+
+        self.teardown()
+

 class Generator(UDFBase):
     """Inherit from this class to pass to `DataChain.gen()`."""

     is_output_batched = True

+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Sequence[Any]]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for row in udf_inputs:
+            udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
+            result_objs = self.process_safe(udf_args)
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            processed_cb.relative_update(1)
+            yield output
+
+        self.teardown()
+

 class Aggregator(UDFBase):
     """Inherit from this class to pass to `DataChain.agg()`."""

-    is_input_batched = True
     is_output_batched = True
-
+
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: Iterable[RowsOutputBatch],
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for batch in udf_inputs:
+            udf_args = zip(
+                *[
+                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    for row in batch.rows
+                ]
+            )
+            result_objs = self.process_safe(udf_args)
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            processed_cb.relative_update(len(batch.rows))
+            yield output
+
+        self.teardown()
datachain/lib/udf_signature.py
CHANGED
@@ -1,10 +1,11 @@
 import inspect
 from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Callable, …
+from typing import Callable, Union, get_args, get_origin

 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.udf import UDFBase
 from datachain.lib.utils import AbstractUDF, DataChainParamsError
@@ -16,7 +17,7 @@ class UdfSignatureError(DataChainParamsError):

 @dataclass
 class UdfSignature:
-    func: Callable
+    func: Union[Callable, UDFBase]
     params: Sequence[str]
     output_schema: SignalSchema
@@ -27,7 +28,7 @@ class UdfSignature:
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: …
+        func: Union[None, UDFBase, Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
@@ -39,6 +40,7 @@ class UdfSignature:
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
+        udf_func: Union[UDFBase, Callable]
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
@@ -53,7 +55,7 @@ class UdfSignature:
             udf_func = func
             signal_name = None

-        if not callable(udf_func):
+        if not isinstance(udf_func, UDFBase) and not callable(udf_func):
             raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")

         func_params_map_sign, func_outs_sign, is_iterator = (
@@ -73,7 +75,7 @@ class UdfSignature:
         if not func_outs_sign:
             raise UdfSignatureError(
                 chain,
-                f"outputs are not defined in function '{udf_func…
+                f"outputs are not defined in function '{udf_func}'"
                 " hints or 'output'",
             )
@@ -154,7 +156,7 @@ class UdfSignature:

     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Callable
+        chain: str, udf_func: Union[Callable, UDFBase]
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
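
The typing changes here track the udf.py refactor: a UDF can now be an instantiated UDFBase as well as a plain callable, and since UDFBase lost its old __call__, the callability check needs an explicit isinstance branch. A small illustration of the widened union; both toy UDFs are assumptions, not code from the release:

from typing import Callable, Union

from datachain.lib.udf import Mapper, UDFBase

def plain_udf(text: str) -> int:  # ordinary function UDF
    return len(text)

class ClassBased(Mapper):  # hypothetical UDFBase subclass
    def process(self, text: str) -> int:
        return len(text)

for udf_func in (plain_udf, ClassBased()):  # Union[Callable, UDFBase]
    # Mirrors the new guard: a UDFBase instance passes even though it is
    # no longer callable itself.
    assert isinstance(udf_func, UDFBase) or callable(udf_func)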
datachain/query/batch.py
CHANGED
@@ -11,8 +11,6 @@ from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
 if TYPE_CHECKING:
     from sqlalchemy import Select

-    from datachain.dataset import RowDict
-

 @dataclass
 class RowsOutputBatch:
@@ -22,14 +20,6 @@ class RowsOutputBatch:
 RowsOutput = Union[Sequence, RowsOutputBatch]


-@dataclass
-class UDFInputBatch:
-    rows: Sequence["RowDict"]
-
-
-UDFInput = Union["RowDict", UDFInputBatch]
-
-
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
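
With UDFInputBatch and the UDFInput union deleted, RowsOutputBatch is the only batch container left at the UDF boundary; the new run() loops in udf.py consume it through batch.rows. A sketch, assuming the dataclass exposes a single rows field (the only attribute those loops touch); the tuples are placeholder warehouse rows:

from datachain.query.batch import RowsOutputBatch

batch = RowsOutputBatch(rows=[(1, "a.txt"), (2, "b.txt")])
assert len(batch.rows) == 2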
datachain/query/dataset.py
CHANGED
@@ -42,6 +42,7 @@ from datachain.data_storage.schema import (
 )
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions import rand
 from datachain.utils import (
@@ -53,7 +54,6 @@ from datachain.utils import (

 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
-from .udf import UDFBase

 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
@@ -299,7 +299,7 @@ def adjust_outputs(
     return row


-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: …
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -320,7 +320,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: …
+    udf: UDFAdapter,
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

 @frozen
 class UDFStep(Step, ABC):
-    udf: …
+    udf: UDFAdapter
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -392,7 +392,7 @@ class UDFStep(Step, ABC):

     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         use_partitioning = self.partition_by is not None
-        batching = self.udf.…
+        batching = self.udf.get_batching(use_partitioning)
         workers = self.workers
         if (
             not workers
@@ -1465,7 +1465,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: …
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1509,7 +1509,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: …
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
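
Throughout this file the udf attributes and parameters are now typed as UDFAdapter, and populate_udf_table delegates strategy selection to the adapter instead of computing it query-side. A runnable sketch of that decision, using the rules shown in udf.py above; Echo and the direct adapter construction are illustrative:

from datachain.lib.udf import Mapper, UDFAdapter
from datachain.query.batch import Batch, NoBatching, Partition

class Echo(Mapper):  # hypothetical stand-in UDF
    def process(self, x):
        return x

# batching = self.udf.get_batching(use_partitioning): partitioned steps always
# get Partition(); otherwise the adapter's batch size decides.
assert isinstance(UDFAdapter(Echo(), {}, 1).get_batching(), NoBatching)
assert isinstance(UDFAdapter(Echo(), {}, 16).get_batching(), Batch)
assert isinstance(UDFAdapter(Echo(), {}, 1).get_batching(use_partitioning=True), Partition)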
datachain/query/dispatch.py
CHANGED
@@ -13,6 +13,7 @@ from multiprocess import get_context

 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import UDFAdapter, UDFResult
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -27,7 +28,6 @@ from datachain.query.queue import (
     put_into_queue,
     unmarshal,
 )
-from datachain.query.udf import UDFBase, UDFResult
 from datachain.utils import batched_it

 DEFAULT_BATCH_SIZE = 10000
@@ -114,7 +114,6 @@ class UDFDispatcher:
     catalog: Optional[Catalog] = None
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None
-    _batch_size: Optional[int] = None

     def __init__(
         self,
@@ -154,17 +153,6 @@ class UDFDispatcher:
         self.done_queue = None
         self.ctx = get_context("spawn")

-    @property
-    def batch_size(self):
-        if self._batch_size is None:
-            if hasattr(self.udf, "properties") and hasattr(
-                self.udf.properties, "batch"
-            ):
-                self._batch_size = self.udf.properties.batch
-            else:
-                self._batch_size = 1
-        return self._batch_size
-
     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
             id_generator = self.id_generator_class(
@@ -336,7 +324,7 @@ class ProcessedCallback(Callback):
 @attrs.define
 class UDFWorker:
     catalog: Catalog
-    udf: …
+    udf: UDFAdapter
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
     is_generator: bool
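
The dispatcher loses its cached batch_size property because the value now lives as a plain field on the UDFAdapter it is handed. What the removed lookup reduces to (stand-in UDF and direct adapter construction, as in the sketches above):

from datachain.lib.udf import Mapper, UDFAdapter

class Echo(Mapper):  # hypothetical stand-in UDF
    def process(self, x):
        return x

adapter = UDFAdapter(inner=Echo(), output={}, batch=10)
batch_size = adapter.batch  # previously: self.udf.properties.batch, defaulting to 1
assert batch_size == adapter.properties.batch == 10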