datachain 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff compares two package versions as published to one of the supported public registries. It is provided for informational purposes only and reflects the changes between the released versions.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/asyn.py +16 -6
- datachain/cache.py +32 -10
- datachain/catalog/catalog.py +17 -1
- datachain/cli/__init__.py +311 -0
- datachain/cli/commands/__init__.py +29 -0
- datachain/cli/commands/datasets.py +129 -0
- datachain/cli/commands/du.py +14 -0
- datachain/cli/commands/index.py +12 -0
- datachain/cli/commands/ls.py +169 -0
- datachain/cli/commands/misc.py +28 -0
- datachain/cli/commands/query.py +53 -0
- datachain/cli/commands/show.py +38 -0
- datachain/cli/parser/__init__.py +547 -0
- datachain/cli/parser/job.py +120 -0
- datachain/cli/parser/studio.py +126 -0
- datachain/cli/parser/utils.py +63 -0
- datachain/{cli_utils.py → cli/utils.py} +27 -1
- datachain/client/azure.py +6 -2
- datachain/client/fsspec.py +9 -3
- datachain/client/gcs.py +6 -2
- datachain/client/s3.py +16 -1
- datachain/data_storage/db_engine.py +9 -0
- datachain/data_storage/schema.py +4 -10
- datachain/data_storage/sqlite.py +7 -1
- datachain/data_storage/warehouse.py +6 -4
- datachain/{lib/diff.py → diff/__init__.py} +116 -12
- datachain/func/__init__.py +3 -2
- datachain/func/conditional.py +74 -0
- datachain/func/func.py +5 -1
- datachain/lib/arrow.py +7 -1
- datachain/lib/dc.py +8 -3
- datachain/lib/file.py +16 -5
- datachain/lib/hf.py +1 -1
- datachain/lib/listing.py +19 -1
- datachain/lib/pytorch.py +57 -13
- datachain/lib/signal_schema.py +89 -27
- datachain/lib/udf.py +82 -40
- datachain/listing.py +1 -0
- datachain/progress.py +20 -3
- datachain/query/dataset.py +122 -93
- datachain/query/dispatch.py +22 -16
- datachain/studio.py +58 -38
- datachain/utils.py +14 -3
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/METADATA +9 -9
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/RECORD +49 -37
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/WHEEL +1 -1
- datachain/cli.py +0 -1475
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/LICENSE +0 -0
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.3.dist-info → datachain-0.8.5.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py
CHANGED
```diff
@@ -1,14 +1,16 @@
-import contextlib
 import sys
 import traceback
-from collections.abc import Iterable, Iterator, Mapping, Sequence
-from
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from contextlib import closing, nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Any, Optional, TypeVar

 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel

 from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -21,17 +23,22 @@ from datachain.query.batch import (
     Partition,
     RowsOutputBatch,
 )
+from datachain.utils import safe_closing

 if TYPE_CHECKING:
     from collections import abc
+    from contextlib import AbstractContextManager

     from typing_extensions import Self

+    from datachain.cache import DataChainCache as Cache
     from datachain.catalog import Catalog
     from datachain.lib.signal_schema import SignalSchema
     from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput

+T = TypeVar("T", bound=Sequence[Any])
+

 class UdfError(DataChainParamsError):
     def __init__(self, msg):
@@ -98,6 +105,10 @@ class UDFAdapter:
             processed_cb,
         )

+    @property
+    def prefetch(self) -> int:
+        return self.inner.prefetch
+

 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -148,12 +159,11 @@ class UDFBase(AbstractUDF):
     """

     is_output_batched = False
-
+    prefetch: int = 0

     def __init__(self):
         self.params: Optional[SignalSchema] = None
         self.output = None
-        self.catalog = None
         self._func = None

     def process(self, *args, **kwargs):
@@ -242,26 +252,23 @@ class UDFBase(AbstractUDF):
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]

     def _parse_row(
-        self, row_dict: RowDict, cache: bool, download_cb: Callback
+        self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
     ) -> list[DataValue]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
             if isinstance(obj, File):
-
-                obj._set_stream(
-                    self.catalog, caching_enabled=cache, download_cb=download_cb
-                )
+                obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
         return obj_row

-    def _prepare_row(self, row, udf_fields, cache, download_cb):
+    def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        return self._parse_row(row_dict, cache, download_cb)
+        return self._parse_row(row_dict, catalog, cache, download_cb)

-    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+    def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        udf_input = self._parse_row(row_dict, cache, download_cb)
+        udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input

     def process_safe(self, obj_rows):
@@ -279,13 +286,47 @@
         return result_objs


-
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(
+    row: T,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> T:
     for obj in row:
-        if isinstance(obj, File):
-
+        if isinstance(obj, File) and await obj._prefetch(download_cb):
+            after_prefetch()
     return row


+def _prefetch_inputs(
+    prepared_inputs: "Iterable[T]",
+    prefetch: int = 0,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> "abc.Generator[T, None, None]":
+    if prefetch > 0:
+        f = partial(
+            _prefetch_input,
+            download_cb=download_cb,
+            after_prefetch=after_prefetch,
+        )
+        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
+    yield from prepared_inputs
+
+
+def _get_cache(
+    cache: "Cache", prefetch: int = 0, use_cache: bool = False
+) -> "AbstractContextManager[Cache]":
+    tmp_dir = cache.tmp_dir
+    assert tmp_dir
+    if prefetch and not use_cache:
+        return temporary_cache(tmp_dir, prefix="prefetch-")
+    return nullcontext(cache)
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""

@@ -300,18 +341,18 @@ class Mapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
-            for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()

-
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
                 udf_output = self._flatten_row(result_objs)
@@ -336,14 +377,15 @@ class BatchMapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()

         for batch in udf_inputs:
             n_rows = len(batch.rows)
             row_ids, *udf_args = zip(
                 *[
-                    self._prepare_row_and_id(
+                    self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
                     for row in batch.rows
                 ]
             )
@@ -378,17 +420,18 @@ class Generator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()

-
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for row in prepared_inputs:
                 result_objs = self.process_safe(row)
                 udf_outputs = (self._flatten_row(row) for row in result_objs)
@@ -413,13 +456,12 @@ class Aggregator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()

         for batch in udf_inputs:
             udf_args = zip(
                 *[
-                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                     for row in batch.rows
                 ]
             )
```
datachain/listing.py
CHANGED
datachain/progress.py
CHANGED
```diff
@@ -5,6 +5,7 @@ import sys
 from threading import RLock
 from typing import Any, ClassVar

+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm

@@ -61,7 +62,7 @@ class Tqdm(tqdm):
         disable : If (default: None) or False,
             will be determined by logging level.
             May be overridden to `True` due to non-TTY status.
-            Skip override by specifying env var `
+            Skip override by specifying env var `DATACHAIN_IGNORE_ISATTY`.
         kwargs : anything accepted by `tqdm.tqdm()`
         """
         kwargs = kwargs.copy()
@@ -77,7 +78,7 @@ class Tqdm(tqdm):
         # auto-disable based on TTY
         if (
             not disable
-            and not env2bool("
+            and not env2bool("DATACHAIN_IGNORE_ISATTY")
             and hasattr(file, "isatty")
         ):
             disable = not file.isatty()
@@ -132,8 +133,24 @@ class Tqdm(tqdm):
         return d


-class CombinedDownloadCallback(
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
         pass
+
+    def increment_file_count(self, n: int = 1) -> None:
+        pass
+
+
+class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
+    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
+        self.files_count = 0
+        tqdm_kwargs = tqdm_kwargs or {}
+        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
+        super().__init__(tqdm_kwargs, *args, **kwargs)
+
+    def increment_file_count(self, n: int = 1) -> None:
+        self.files_count += n
+        if self.tqdm is not None:
+            self.tqdm.postfix = f"{self.files_count} files"
```
datachain/query/dataset.py
CHANGED
```diff
@@ -35,6 +35,7 @@ from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select

 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
+from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -43,7 +44,8 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.
+from datachain.lib.udf import UDFAdapter, _get_cache
+from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
 from datachain.sql.functions.random import rand
@@ -52,6 +54,7 @@ from datachain.utils import (
     determine_processes,
     filtered_cloudpickle_dumps,
     get_datachain_executable,
+    safe_closing,
 )

 if TYPE_CHECKING:
@@ -349,19 +352,26 @@ def process_udf_outputs(
         warehouse.insert_rows_done(udf_table)


-def get_download_callback() ->
-    return
-        {
+def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
+    return TqdmCombinedDownloadCallback(
+        {
+            "desc": "Download" + suffix,
+            "unit": "B",
+            "unit_scale": True,
+            "unit_divisor": 1024,
+            "leave": False,
+            **kwargs,
+        },
     )


 def get_processed_callback() -> Callback:
-    return TqdmCallback({"desc": "Processed", "unit": " rows"})
+    return TqdmCallback({"desc": "Processed", "unit": " rows", "leave": False})


 def get_generated_callback(is_generator: bool = False) -> Callback:
     if is_generator:
-        return TqdmCallback({"desc": "Generated", "unit": " rows"})
+        return TqdmCallback({"desc": "Generated", "unit": " rows", "leave": False})
     return DEFAULT_CALLBACK


@@ -412,97 +422,109 @@ class UDFStep(Step, ABC):

         udf_fields = [str(c.name) for c in query.selected_columns]

-
-
-
-
-
-
-
+        prefetch = self.udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            catalog = clone_catalog_with_cache(self.catalog, _cache)
+            try:
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )

-
-
-
-
-                self.udf,
-                self.catalog,
-                udf_table,
-                query,
-                workers,
-                processes,
-                udf_fields=udf_fields,
-                is_generator=self.is_generator,
-                use_partitioning=use_partitioning,
-                cache=self.cache,
-            )
-        elif processes:
-            # Parallel processing (faster for more CPU-heavy UDFs)
-            if self.catalog.in_memory:
-                raise RuntimeError(
-                    "In-memory databases cannot be used with parallel processing."
-                )
-            udf_info: UdfInfo = {
-                "udf_data": filtered_cloudpickle_dumps(self.udf),
-                "catalog_init": self.catalog.get_init_params(),
-                "metastore_clone_params": self.catalog.metastore.clone_params(),
-                "warehouse_clone_params": self.catalog.warehouse.clone_params(),
-                "table": udf_table,
-                "query": query,
-                "udf_fields": udf_fields,
-                "batching": batching,
-                "processes": processes,
-                "is_generator": self.is_generator,
-                "cache": self.cache,
-            }
-
-            # Run the UDFDispatcher in another process to avoid needing
-            # if __name__ == '__main__': in user scripts
-            exec_cmd = get_datachain_executable()
-            cmd = [*exec_cmd, "internal-run-udf"]
-            envs = dict(os.environ)
-            envs.update({"PYTHONPATH": os.getcwd()})
-            process_data = filtered_cloudpickle_dumps(udf_info)
-
-            with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
-                process.communicate(process_data)
-                if retval := process.poll():
-                    raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
-        else:
-            # Otherwise process single-threaded (faster for smaller UDFs)
-            warehouse = self.catalog.warehouse
-
-            udf_inputs = batching(warehouse.dataset_select_paginated, query)
-            download_cb = get_download_callback()
-            processed_cb = get_processed_callback()
-            generated_cb = get_generated_callback(self.is_generator)
-            try:
-                udf_results = self.udf.run(
-                    udf_fields,
-                    udf_inputs,
-                    self.catalog,
-                    self.cache,
-                    download_cb,
-                    processed_cb,
+                    from datachain.catalog.loader import get_distributed_class
+
+                    distributor = get_distributed_class(
+                        min_task_size=self.min_task_size
                     )
-
-                    warehouse,
-                    udf_table,
-                    udf_results,
+                    distributor(
                         self.udf,
-
+                        catalog,
+                        udf_table,
+                        query,
+                        workers,
+                        processes,
+                        udf_fields=udf_fields,
+                        is_generator=self.is_generator,
+                        use_partitioning=use_partitioning,
+                        cache=self.cache,
                     )
-
-
-
-
-
-
-
-
-
-
-
-
+                elif processes:
+                    # Parallel processing (faster for more CPU-heavy UDFs)
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used "
+                            "with parallel processing."
+                        )
+                    udf_info: UdfInfo = {
+                        "udf_data": filtered_cloudpickle_dumps(self.udf),
+                        "catalog_init": catalog.get_init_params(),
+                        "metastore_clone_params": catalog.metastore.clone_params(),
+                        "warehouse_clone_params": catalog.warehouse.clone_params(),
+                        "table": udf_table,
+                        "query": query,
+                        "udf_fields": udf_fields,
+                        "batching": batching,
+                        "processes": processes,
+                        "is_generator": self.is_generator,
+                        "cache": self.cache,
+                    }
+
+                    # Run the UDFDispatcher in another process to avoid needing
+                    # if __name__ == '__main__': in user scripts
+                    exec_cmd = get_datachain_executable()
+                    cmd = [*exec_cmd, "internal-run-udf"]
+                    envs = dict(os.environ)
+                    envs.update({"PYTHONPATH": os.getcwd()})
+                    process_data = filtered_cloudpickle_dumps(udf_info)
+
+                    with subprocess.Popen(  # noqa: S603
+                        cmd, env=envs, stdin=subprocess.PIPE
+                    ) as process:
+                        process.communicate(process_data)
+                        if retval := process.poll():
+                            raise RuntimeError(
+                                f"UDF Execution Failed! Exit code: {retval}"
+                            )
+                else:
+                    # Otherwise process single-threaded (faster for smaller UDFs)
+                    warehouse = catalog.warehouse
+
+                    udf_inputs = batching(warehouse.dataset_select_paginated, query)
+                    download_cb = get_download_callback()
+                    processed_cb = get_processed_callback()
+                    generated_cb = get_generated_callback(self.is_generator)
+
+                    try:
+                        udf_results = self.udf.run(
+                            udf_fields,
+                            udf_inputs,
+                            catalog,
+                            self.cache,
+                            download_cb,
+                            processed_cb,
+                        )
+                        with safe_closing(udf_results):
+                            process_udf_outputs(
+                                warehouse,
+                                udf_table,
+                                udf_results,
+                                self.udf,
+                                cb=generated_cb,
+                            )
+                    finally:
+                        download_cb.close()
+                        processed_cb.close()
+                        generated_cb.close()
+
+            except QueryScriptCancelError:
+                self.catalog.warehouse.close()
+                sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
+            except (Exception, KeyboardInterrupt):
+                # Close any open database connections if an error is encountered
+                self.catalog.warehouse.close()
+                raise

     def create_partitions_table(self, query: Select) -> "Table":
         """
@@ -602,6 +624,13 @@ class UDFSignal(UDFStep):
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols

+        overlap = {c.name for c in original_cols} & {c.name for c in cols}
+        if overlap:
+            raise ValueError(
+                "Column already exists or added in the previous steps: "
+                + ", ".join(overlap)
+            )
+
         def q(*columns):
             cols1 = []
             cols2 = []
```
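Two pieces of plumbing recur in this file and in `dispatch.py` below: `_get_cache` decides where prefetched files land, and `clone_catalog_with_cache` hands the UDF a catalog wired to that location. The rule is that prefetching needs somewhere to write even when the user did not request caching, so such runs get a throwaway cache that is deleted when the step finishes. A simplified sketch of the same decision, using bare directories instead of `DataChainCache` objects (the function names here are illustrative, not library API):

```python
import tempfile
from contextlib import contextmanager, nullcontext


@contextmanager
def throwaway_cache(tmp_dir: str, prefix: str = "prefetch-"):
    # Stand-in for datachain.cache.temporary_cache: a cache rooted in a
    # temporary directory (tmp_dir must exist) that is removed on exit.
    with tempfile.TemporaryDirectory(dir=tmp_dir, prefix=prefix) as cache_dir:
        yield cache_dir


def pick_cache(persistent: str, tmp_dir: str, prefetch: int, use_cache: bool):
    # Same branch as _get_cache: prefetch without user-requested caching gets
    # the throwaway cache; every other combination keeps the persistent one.
    if prefetch and not use_cache:
        return throwaway_cache(tmp_dir)
    return nullcontext(persistent)
```

Wrapping the whole `workers`/`processes`/single-threaded dispatch inside that context manager is what guarantees the prefetch directory is cleaned up on every exit path, including the new `finally` block that closes the progress callbacks.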
datachain/query/dispatch.py
CHANGED
```diff
@@ -14,7 +14,9 @@ from multiprocess import get_context
 from sqlalchemy.sql import func

 from datachain.catalog import Catalog
+from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import _get_cache
 from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
@@ -25,7 +27,7 @@ from datachain.query.dataset import (
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
 from datachain.query.utils import get_query_id_column
-from datachain.utils import batched, flatten
+from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
@@ -304,21 +306,25 @@ class UDFWorker:
         processed_cb = ProcessedCallback()
         generated_cb = get_generated_callback(self.is_generator)

-
-
-        self.
-        self.
-
-
-
-
-
-
-
-
-
-
+        prefetch = self.udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            catalog = clone_catalog_with_cache(self.catalog, _cache)
+            udf_results = self.udf.run(
+                self.udf_fields,
+                self.get_inputs(),
+                catalog,
+                self.cache,
+                download_cb=self.cb,
+                processed_cb=processed_cb,
+            )
+            with safe_closing(udf_results):
+                process_udf_outputs(
+                    catalog.warehouse,
+                    self.table,
+                    self.notify_and_process(udf_results, processed_cb),
+                    self.udf,
+                    cb=generated_cb,
+                )

         put_into_queue(
             self.done_queue,
```
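Both this worker and the single-threaded path in `dataset.py` wrap `udf_results` in `safe_closing` so the result generator (and, through it, the temporary prefetch cache) is torn down even if writing outputs fails. `safe_closing` comes from `datachain.utils`; its body is not part of this diff, but a plausible minimal version (an assumption, not the library's exact code) is:

```python
from contextlib import contextmanager


@contextmanager
def safe_closing(thing):
    # Like contextlib.closing, but tolerates objects without a close() method,
    # so plain iterators and generators can share one cleanup path.
    try:
        yield thing
    finally:
        if hasattr(thing, "close"):
            thing.close()
```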