datachain 0.5.1__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/__init__.py +2 -0
- datachain/catalog/catalog.py +1 -9
- datachain/data_storage/sqlite.py +8 -0
- datachain/data_storage/warehouse.py +0 -4
- datachain/lib/convert/sql_to_python.py +8 -12
- datachain/lib/convert/values_to_tuples.py +2 -2
- datachain/lib/data_model.py +1 -1
- datachain/lib/dc.py +82 -30
- datachain/lib/func/__init__.py +14 -0
- datachain/lib/func/aggregate.py +42 -0
- datachain/lib/func/func.py +64 -0
- datachain/lib/signal_schema.py +15 -9
- datachain/lib/udf.py +177 -151
- datachain/lib/utils.py +5 -0
- datachain/query/__init__.py +1 -2
- datachain/query/batch.py +0 -11
- datachain/query/dataset.py +23 -44
- datachain/query/dispatch.py +0 -12
- datachain/query/schema.py +1 -61
- datachain/query/session.py +33 -25
- datachain/sql/functions/__init__.py +1 -1
- datachain/sql/functions/aggregate.py +47 -0
- datachain/sql/functions/array.py +0 -8
- datachain/sql/functions/string.py +12 -0
- datachain/sql/sqlite/base.py +30 -7
- {datachain-0.5.1.dist-info → datachain-0.6.1.dist-info}/METADATA +2 -2
- {datachain-0.5.1.dist-info → datachain-0.6.1.dist-info}/RECORD +31 -27
- {datachain-0.5.1.dist-info → datachain-0.6.1.dist-info}/LICENSE +0 -0
- {datachain-0.5.1.dist-info → datachain-0.6.1.dist-info}/WHEEL +0 -0
- {datachain-0.5.1.dist-info → datachain-0.6.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.5.1.dist-info → datachain-0.6.1.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py
CHANGED
@@ -1,14 +1,15 @@
 import sys
 import traceback
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
+import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
+from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
@@ -18,16 +19,14 @@ from datachain.query.batch import (
     NoBatching,
     Partition,
     RowsOutputBatch,
-    UDFInputBatch,
 )
-from datachain.query.schema import ColumnParameter, UDFParameter
 
 if TYPE_CHECKING:
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
     from datachain.lib.udf_signature import UdfSignature
-    from datachain.query.batch import RowsOutput
+    from datachain.query.batch import RowsOutput
 
 
 class UdfError(DataChainParamsError):
@@ -45,11 +44,21 @@ UDFOutputSpec = Mapping[str, ColumnType]
 UDFResult = dict[str, Any]
 
 
-@dataclass
+@attrs.define
 class UDFProperties:
-
+    udf: "UDFAdapter"
 
-
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        return self.udf.get_batching(use_partitioning)
+
+    @property
+    def batch(self):
+        return self.udf.batch
+
+
+@attrs.define(slots=False)
+class UDFAdapter:
+    inner: "UDFBase"
     output: UDFOutputSpec
     batch: int = 1
 
@@ -62,20 +71,10 @@ class UDFProperties:
             return Batch(self.batch)
         raise ValueError(f"invalid batch size {self.batch}")
 
-
-
-
-
-class UDFAdapter:
-    def __init__(
-        self,
-        inner: "UDFBase",
-        properties: UDFProperties,
-    ):
-        self.inner = inner
-        self.properties = properties
-        self.signal_names = properties.signal_names()
-        self.output = properties.output
+    @property
+    def properties(self):
+        # For backwards compatibility.
+        return UDFProperties(self)
 
     def run(
         self,
@@ -87,72 +86,14 @@ class UDFAdapter:
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.inner.
-
-
-
-
-
-
-
-                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
-                )
-            else:
-                n_rows = 1
-                inputs = RowDict(zip(udf_fields, batch))
-            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
-            processed_cb.relative_update(n_rows)
-            yield output
-
-        if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
-            self.inner.teardown()
-
-    def run_once(
-        self,
-        catalog: "Catalog",
-        arg: "UDFInput",
-        is_generator: bool = False,
-        cache: bool = False,
-        cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterable[UDFResult]:
-        if isinstance(arg, UDFInputBatch):
-            udf_inputs = [
-                self.bind_parameters(catalog, row, cache=cache, cb=cb)
-                for row in arg.rows
-            ]
-            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
-            return self._process_results(arg.rows, udf_outputs, is_generator)
-        if isinstance(arg, RowDict):
-            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
-            if not is_generator:
-                # udf_outputs is generator already if is_generator=True
-                udf_outputs = [udf_outputs]
-            return self._process_results([arg], udf_outputs, is_generator)
-        raise ValueError(f"Unexpected UDF argument: {arg}")
-
-    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
-        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
-
-    def _process_results(
-        self,
-        rows: Sequence["RowDict"],
-        results: Sequence[Sequence[Any]],
-        is_generator=False,
-    ) -> Iterable[UDFResult]:
-        """Create a list of dictionaries representing UDF results."""
-
-        # outputting rows
-        if is_generator:
-            # each row in results is a tuple of column values
-            return (dict(zip(self.signal_names, row)) for row in results)
-
-        # outputting signals
-        row_ids = [row["sys__id"] for row in rows]
-        return [
-            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-            for row_id, signals in zip(row_ids, results)
-        ]
+        yield from self.inner.run(
+            udf_fields,
+            udf_inputs,
+            catalog,
+            cache,
+            download_cb,
+            processed_cb,
+        )
 
 
 class UDFBase(AbstractUDF):
@@ -203,17 +144,12 @@ class UDFBase(AbstractUDF):
     ```
     """
 
-    is_input_batched = False
     is_output_batched = False
-    is_input_grouped = False
-    params_spec: Optional[list[str]]
     catalog: "Optional[Catalog]"
 
     def __init__(self):
-        self.params = None
+        self.params: Optional[SignalSchema] = None
        self.output = None
-        self.params_spec = None
-        self.output_spec = None
        self.catalog = None
        self._func = None
 
@@ -241,11 +177,6 @@ class UDFBase(AbstractUDF):
     ):
         self.params = params
         self.output = sign.output_schema
-
-        params_spec = self.params.to_udf_spec()
-        self.params_spec = list(params_spec.keys())
-        self.output_spec = self.output.to_udf_spec()
-
         self._func = func
 
     @classmethod
@@ -273,48 +204,27 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
+    @property
+    def signal_names(self) -> Iterable[str]:
+        return self.output.to_udf_spec().keys()
+
     def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
-
-
-
+        return UDFAdapter(
+            self,
+            self.output.to_udf_spec(),
+            batch,
         )
-        return UDFAdapter(self, properties)
-
-    def validate_results(self, results, *args, **kwargs):
-        return results
 
-    def
-
-
-
-
-
-
-
-
-
-
-        # Generator expression is required, otherwise the value will be materialized
-        res = (self._flatten_row(row) for row in result_objs)
-
-        if not self.is_output_batched:
-            res = list(res)
-            assert (
-                len(res) == 1
-            ), f"{self.name} returns {len(res)} rows while it's not batched"
-            if isinstance(res[0], tuple):
-                res = res[0]
-        elif (
-            self.is_input_batched
-            and self.is_output_batched
-            and not self.is_input_grouped
-        ):
-            res = list(res)
-            assert len(res) == len(
-                rows
-            ), f"{self.name} returns {len(res)} rows while {len(rows)} expected"
-
-        return res
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Any]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        raise NotImplementedError
 
     def _flatten_row(self, row):
         if len(self.output.values) > 1 and not isinstance(row, BaseModel):
@@ -328,17 +238,28 @@ class UDFBase(AbstractUDF):
     def _obj_to_list(obj):
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]
 
-    def
-
-
-
-
-
-
-
-
-
-
+    def _parse_row(
+        self, row_dict: RowDict, cache: bool, download_cb: Callback
+    ) -> list[DataValue]:
+        assert self.params
+        row = [row_dict[p] for p in self.params.to_udf_spec()]
+        obj_row = self.params.row_to_objs(row)
+        for obj in obj_row:
+            if isinstance(obj, File):
+                assert self.catalog is not None
+                obj._set_stream(
+                    self.catalog, caching_enabled=cache, download_cb=download_cb
+                )
+        return obj_row
+
+    def _prepare_row(self, row, udf_fields, cache, download_cb):
+        row_dict = RowDict(zip(udf_fields, row))
+        return self._parse_row(row_dict, cache, download_cb)
+
+    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+        row_dict = RowDict(zip(udf_fields, row))
+        udf_input = self._parse_row(row_dict, cache, download_cb)
+        return row_dict["sys__id"], *udf_input
 
     def process_safe(self, obj_rows):
         try:
@@ -358,23 +279,128 @@ class UDFBase(AbstractUDF):
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Sequence[Any]]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for row in udf_inputs:
+            id_, *udf_args = self._prepare_row_and_id(
+                row, udf_fields, cache, download_cb
+            )
+            result_objs = self.process_safe(udf_args)
+            udf_output = self._flatten_row(result_objs)
+            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+            processed_cb.relative_update(1)
+            yield output
+
+        self.teardown()
+
 
 
 class BatchMapper(UDFBase):
     """Inherit from this class to pass to `DataChain.batch_map()`."""
-    is_input_batched = True
     is_output_batched = True
 
+    def run(
+        self,
+        udf_fields: Sequence[str],
+        udf_inputs: Iterable[RowsOutputBatch],
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for batch in udf_inputs:
+            n_rows = len(batch.rows)
+            row_ids, *udf_args = zip(
+                *[
+                    self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+                    for row in batch.rows
+                ]
+            )
+            result_objs = list(self.process_safe(udf_args))
+            n_objs = len(result_objs)
+            assert (
+                n_objs == n_rows
+            ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = [
+                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+                for row_id, signals in zip(row_ids, udf_outputs)
+            ]
+            processed_cb.relative_update(n_rows)
+            yield output
+
+        self.teardown()
+
 
 class Generator(UDFBase):
     """Inherit from this class to pass to `DataChain.gen()`."""
 
     is_output_batched = True
 
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Sequence[Any]]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for row in udf_inputs:
+            udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
+            result_objs = self.process_safe(udf_args)
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            processed_cb.relative_update(1)
+            yield output
+
+        self.teardown()
+
 
 class Aggregator(UDFBase):
     """Inherit from this class to pass to `DataChain.agg()`."""
 
-    is_input_batched = True
     is_output_batched = True
-
+
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: Iterable[RowsOutputBatch],
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for batch in udf_inputs:
+            udf_args = zip(
+                *[
+                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    for row in batch.rows
+                ]
+            )
+            result_objs = self.process_safe(udf_args)
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            processed_cb.relative_update(len(batch.rows))
+            yield output
+
+        self.teardown()
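The practical upshot of the udf.py rewrite is that UDFAdapter no longer performs batching, parameter binding, or result assembly itself: its run() is a thin shim that forwards to the wrapped UDF's own run() generator, and each flavor (Mapper, BatchMapper, Generator, Aggregator) now owns its loop, setup/teardown, and output shaping. A standalone sketch of that delegation pattern, using made-up minimal classes rather than datachain's real ones:

```python
# Simplified, self-contained sketch (hypothetical names, not datachain's API)
# of the adapter-delegates-to-UDF pattern introduced above.
from collections.abc import Iterable, Iterator


class MiniMapper:
    signal_names = ("value",)

    def process(self, x: int) -> int:
        # user-defined per-row logic, analogous to a Mapper subclass
        return x * 2

    def run(self, rows: Iterable[tuple[int, int]]) -> Iterator[list[dict]]:
        # the UDF owns its own loop and output shaping
        for row_id, x in rows:
            yield [{"sys__id": row_id, "value": self.process(x)}]


class MiniAdapter:
    def __init__(self, inner: MiniMapper):
        self.inner = inner

    def run(self, rows: Iterable[tuple[int, int]]) -> Iterator[list[dict]]:
        # pure delegation, mirroring the new UDFAdapter.run above
        yield from self.inner.run(rows)


if __name__ == "__main__":
    for out in MiniAdapter(MiniMapper()).run([(1, 10), (2, 20)]):
        print(out)  # [{'sys__id': 1, 'value': 20}] then [{'sys__id': 2, 'value': 40}]
```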
datachain/lib/utils.py
CHANGED
@@ -23,3 +23,8 @@ class DataChainError(Exception):
 class DataChainParamsError(DataChainError):
     def __init__(self, message):
         super().__init__(message)
+
+
+class DataChainColumnError(DataChainParamsError):
+    def __init__(self, col_name, msg):
+        super().__init__(f"Error for column {col_name}: {msg}")
datachain/query/__init__.py
CHANGED
@@ -1,12 +1,11 @@
 from .dataset import DatasetQuery
 from .params import param
-from .schema import C,
+from .schema import C, LocalFilename, Object, Stream
 from .session import Session
 
 __all__ = [
     "C",
     "DatasetQuery",
-    "DatasetRow",
     "LocalFilename",
     "Object",
     "Session",
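For callers, the visible effect is that DatasetRow is no longer re-exported from datachain.query; the other names shown in the hunk keep importing as before:

```python
# Still valid after this release (DatasetRow has been dropped from __all__ here).
from datachain.query import C, DatasetQuery, LocalFilename, Object, Session
```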
datachain/query/batch.py
CHANGED
@@ -11,8 +11,6 @@ from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
 if TYPE_CHECKING:
     from sqlalchemy import Select
 
-    from datachain.dataset import RowDict
-
 
 @dataclass
 class RowsOutputBatch:
@@ -22,14 +20,6 @@ class RowsOutputBatch:
 RowsOutput = Union[Sequence, RowsOutputBatch]
 
 
-@dataclass
-class UDFInputBatch:
-    rows: Sequence["RowDict"]
-
-
-UDFInput = Union["RowDict", UDFInputBatch]
-
-
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
 
@@ -107,7 +97,6 @@ class Partition(BatchingStrategy):
 
         ordered_query = query.order_by(None).order_by(
            PARTITION_COLUMN_ID,
-            "sys__id",
            *query._order_by_clauses,
        )
 
datachain/query/dataset.py
CHANGED
@@ -392,7 +392,7 @@ class UDFStep(Step, ABC):
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         use_partitioning = self.partition_by is not None
-        batching = self.udf.
+        batching = self.udf.get_batching(use_partitioning)
         workers = self.workers
         if (
             not workers
@@ -591,10 +591,6 @@ class UDFSignal(UDFStep):
             return query, []
         table = self.catalog.warehouse.create_pre_udf_table(query)
         q: Select = sqlalchemy.select(*table.c)
-        if query._order_by_clauses:
-            # we are adding ordering only if it's explicitly added by user in
-            # query part before adding signals
-            q = q.order_by(table.c.sys__id)
         return q, [table]
 
     def create_result_query(
@@ -630,11 +626,6 @@
         else:
             res = sqlalchemy.select(*cols1).select_from(subq)
 
-        if query._order_by_clauses:
-            # if ordering is used in query part before adding signals, we
-            # will have it as order by id from select from pre-created udf table
-            res = res.order_by(subq.c.sys__id)
-
         if self.partition_by is not None:
             subquery = res.subquery()
             res = sqlalchemy.select(*subquery.c).select_from(subquery)
@@ -666,13 +657,6 @@ class RowGenerator(UDFStep):
     def create_result_query(
         self, udf_table, query: Select
     ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
-        if not query._order_by_clauses:
-            # if we are not selecting all rows in UDF, we need to ensure that
-            # we get the same rows as we got as inputs of UDF since selecting
-            # without ordering can be non deterministic in some databases
-            c = query.selected_columns
-            query = query.order_by(c.sys__id)
-
         udf_table_query = udf_table.select().subquery()
         udf_table_cols: list[sqlalchemy.Label[Any]] = [
             label(c.name, c) for c in udf_table_query.columns
@@ -957,24 +941,24 @@ class SQLJoin(Step):
 
 
 @frozen
-class
-
-
-    cols: PartitionByType
+class SQLGroupBy(SQLClause):
+    cols: Sequence[Union[str, ColumnElement]]
+    group_by: Sequence[Union[str, ColumnElement]]
 
-    def
-
+    def apply_sql_clause(self, query) -> Select:
+        if not self.cols:
+            raise ValueError("No columns to select")
+        if not self.group_by:
+            raise ValueError("No columns to group by")
 
-
-        self, query_generator: QueryGenerator, temp_tables: list[str]
-    ) -> StepResult:
-        query = query_generator.select()
-        grouped_query = query.group_by(*self.cols)
+        subquery = query.subquery()
 
-
-
+        cols = [
+            subquery.c[str(c)] if isinstance(c, (str, C)) else c
+            for c in [*self.group_by, *self.cols]
+        ]
 
-        return
+        return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)
 
 
 def _validate_columns(
@@ -1130,25 +1114,14 @@ class DatasetQuery:
         query.steps = query.steps[-1:] + query.steps[:-1]
 
         result = query.starting_step.apply()
-        group_by = None
         self.dependencies.update(result.dependencies)
 
         for step in query.steps:
-            if isinstance(step, GroupBy):
-                if group_by is not None:
-                    raise TypeError("only one group_by allowed")
-                group_by = step
-                continue
-
             result = step.apply(
                 result.query_generator, self.temp_table_names
             )  # a chain of steps linked by results
             self.dependencies.update(result.dependencies)
 
-        if group_by:
-            result = group_by.apply(result.query_generator, self.temp_table_names)
-            self.dependencies.update(result.dependencies)
-
         return result.query_generator
 
     @staticmethod
@@ -1410,9 +1383,13 @@ class DatasetQuery:
         return query.as_scalar()
 
     @detach
-    def group_by(
+    def group_by(
+        self,
+        cols: Sequence[ColumnElement],
+        group_by: Sequence[ColumnElement],
+    ) -> "Self":
         query = self.clone()
-        query.steps.append(
+        query.steps.append(SQLGroupBy(cols, group_by))
         return query
 
     @detach
@@ -1591,6 +1568,8 @@ class DatasetQuery:
         )
         version = version or dataset.latest_version
 
+        self.session.add_dataset_version(dataset=dataset, version=version)
+
         dr = self.catalog.warehouse.dataset_rows(dataset)
 
         self.catalog.warehouse.copy_table(dr.get_table(), query.select())
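For reference, the SQL that the new SQLGroupBy.apply_sql_clause builds is a plain "wrap the current query in a subquery, then select the grouping columns plus aggregates from it with a GROUP BY". A standalone SQLAlchemy sketch of the same shape, with a made-up table so it runs on its own:

```python
import sqlalchemy as sa

metadata = sa.MetaData()
files = sa.Table(
    "files",
    metadata,
    sa.Column("dir", sa.String),
    sa.Column("size", sa.Integer),
)

query = sa.select(files.c.dir, files.c.size)  # the query built before group_by()
subquery = query.subquery()                   # mirrors apply_sql_clause()
grouped = (
    sa.select(subquery.c.dir, sa.func.sum(subquery.c.size).label("total"))
    .select_from(subquery)
    .group_by(subquery.c.dir)
)
print(grouped)  # SELECT ... FROM (SELECT ...) AS anon_1 GROUP BY anon_1.dir
```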
datachain/query/dispatch.py
CHANGED
@@ -114,7 +114,6 @@ class UDFDispatcher:
     catalog: Optional[Catalog] = None
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None
-    _batch_size: Optional[int] = None
 
     def __init__(
         self,
@@ -154,17 +153,6 @@ class UDFDispatcher:
         self.done_queue = None
         self.ctx = get_context("spawn")
 
-    @property
-    def batch_size(self):
-        if self._batch_size is None:
-            if hasattr(self.udf, "properties") and hasattr(
-                self.udf.properties, "batch"
-            ):
-                self._batch_size = self.udf.properties.batch
-            else:
-                self._batch_size = 1
-        return self._batch_size
-
     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
             id_generator = self.id_generator_class(
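The dispatcher drops its cached batch_size because the batch size now lives directly on the adapter, while the old udf.properties.batch path keeps resolving through the backwards-compatibility wrapper added in udf.py. A minimal attrs sketch of that shim, with hypothetical class names:

```python
import attrs


@attrs.define(slots=False)
class Adapter:
    batch: int = 1

    @property
    def properties(self) -> "Properties":
        # backwards-compatibility shim, as in the new UDFAdapter
        return Properties(self)


@attrs.define
class Properties:
    udf: "Adapter"

    @property
    def batch(self) -> int:
        return self.udf.batch


adapter = Adapter(batch=10)
assert adapter.batch == adapter.properties.batch == 10
```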