datachain 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/asyn.py +16 -6
- datachain/cache.py +32 -10
- datachain/catalog/catalog.py +17 -1
- datachain/client/azure.py +6 -2
- datachain/client/fsspec.py +1 -1
- datachain/client/gcs.py +6 -2
- datachain/client/s3.py +22 -4
- datachain/data_storage/db_engine.py +9 -0
- datachain/data_storage/schema.py +4 -10
- datachain/data_storage/sqlite.py +7 -1
- datachain/data_storage/warehouse.py +6 -4
- datachain/{lib/diff.py → diff/__init__.py} +116 -12
- datachain/func/__init__.py +2 -1
- datachain/func/conditional.py +31 -9
- datachain/lib/arrow.py +3 -1
- datachain/lib/dc.py +5 -3
- datachain/lib/file.py +15 -4
- datachain/lib/hf.py +1 -1
- datachain/lib/pytorch.py +57 -13
- datachain/lib/udf.py +82 -40
- datachain/listing.py +1 -0
- datachain/progress.py +18 -1
- datachain/query/dataset.py +122 -93
- datachain/query/dispatch.py +22 -16
- datachain/utils.py +13 -2
- {datachain-0.8.4.dist-info → datachain-0.8.6.dist-info}/METADATA +6 -6
- {datachain-0.8.4.dist-info → datachain-0.8.6.dist-info}/RECORD +31 -31
- {datachain-0.8.4.dist-info → datachain-0.8.6.dist-info}/WHEEL +1 -1
- {datachain-0.8.4.dist-info → datachain-0.8.6.dist-info}/LICENSE +0 -0
- {datachain-0.8.4.dist-info → datachain-0.8.6.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.4.dist-info → datachain-0.8.6.dist-info}/top_level.txt +0 -0
datachain/func/conditional.py
CHANGED
@@ -8,6 +8,8 @@ from datachain.sql.functions import conditional
 
 from .func import ColT, Func
 
+CaseT = Union[int, float, complex, bool, str]
+
 
 def greatest(*args: Union[ColT, float]) -> Func:
     """
@@ -85,9 +87,7 @@ def least(*args: Union[ColT, float]) -> Func:
     )
 
 
-def case(
-    *args: tuple[BinaryExpression, Union[int, float, complex, bool, str]], else_=None
-) -> Func:
+def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
     """
     Returns the case function that produces case expression which has a list of
     conditions and corresponding results. Results can only be python primitives
@@ -108,26 +108,48 @@ def case(
             res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
         )
         ```
-
-    Note:
-        - Result column will always be of the same type as the input columns.
     """
     supported_types = [int, float, complex, str, bool]
 
     type_ = type(else_) if else_ else None
 
     if not args:
-        raise DataChainParamsError("Missing
+        raise DataChainParamsError("Missing statements")
 
     for arg in args:
         if type_ and not isinstance(arg[1], type_):
-            raise DataChainParamsError("
+            raise DataChainParamsError("Statement values must be of the same type")
         type_ = type(arg[1])
 
     if type_ not in supported_types:
         raise DataChainParamsError(
-            f"
+            f"Only python literals ({supported_types}) are supported for values"
         )
 
     kwargs = {"else_": else_}
     return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
+
+
+def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
+    """
+    Returns the ifelse function that produces if expression which has a condition
+    and values for true and false outcome. Results can only be python primitives
+    like string, numbes or booleans. Result type is inferred from the values.
+
+    Args:
+        condition: BinaryExpression - condition which is evaluated
+        if_val: (str | int | float | complex | bool): value for true condition outcome
+        else_val: (str | int | float | complex | bool): value for false condition
+            outcome
+
+    Returns:
+        Func: A Func object that represents the ifelse function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.ifelse(C("num") > 0, "P", "N"),
+        )
+        ```
+    """
+    return case((condition, if_val), else_=else_val)
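Taken together, `case()` now shares the `CaseT` value type and the new `ifelse()` is a thin wrapper over it. A minimal sketch of how the two might be combined, based on the docstring examples above (the chain and the `num` column are illustrative, and `from_values`/`mutate` are assumed to be available as in other DataChain examples):

```py
from datachain import C, DataChain, func

# Hypothetical chain with a numeric "num" column.
chain = DataChain.from_values(num=[-2, 0, 3]).mutate(
    sign=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
    positive=func.ifelse(C("num") > 0, "yes", "no"),
)
```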
datachain/lib/arrow.py
CHANGED
@@ -91,7 +91,9 @@ class ArrowGenerator(Generator):
                 yield from record_batch.to_pylist()
 
         it = islice(iter_records(), self.nrows)
-        with tqdm(
+        with tqdm(
+            it, desc="Parsed by pyarrow", unit="rows", total=self.nrows, leave=False
+        ) as pbar:
             for index, record in enumerate(pbar):
                 yield self._process_record(
                     record, file, index, hf_schema, use_datachain_schema
datachain/lib/dc.py
CHANGED
@@ -451,6 +451,7 @@ class DataChain:
             return dc
 
         if update or not list_ds_exists:
+            # disable prefetch for listing, as it pre-downloads all files
             (
                 cls.from_records(
                     DataChain.DEFAULT_FILE_RECORD,
@@ -458,6 +459,7 @@ class DataChain:
                     settings=settings,
                     in_memory=in_memory,
                 )
+                .settings(prefetch=0)
                 .gen(
                     list_bucket(list_uri, cache, client_config=client_config),
                     output={f"{object_name}": File},
@@ -1534,7 +1536,7 @@
 
         Example:
             ```py
-
+            res = persons.compare(
                 new_persons,
                 on=["id"],
                 right_on=["other_id"],
@@ -1547,9 +1549,9 @@
             )
             ```
         """
-        from datachain.
+        from datachain.diff import _compare
 
-        return
+        return _compare(
             self,
             other,
             on,
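The listing chain now opts out of prefetching via `.settings(prefetch=0)`; the same knob is available on user chains. A hedged sketch of turning it off for a metadata-only pipeline (the bucket URI is illustrative):

```py
from datachain import DataChain

# prefetch pre-downloads files referenced by each row before a UDF runs;
# prefetch=0 disables it when only metadata is needed.
chain = (
    DataChain.from_storage("s3://my-bucket/images/")
    .settings(prefetch=0, cache=False)
)
```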
datachain/lib/file.py
CHANGED
@@ -269,10 +269,21 @@ class File(DataModel):
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
-    async def _prefetch(self) ->
-
-
-
+    async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
+        from datachain.client.hf import HfClient
+
+        if self._catalog is None:
+            raise RuntimeError("cannot prefetch file because catalog is not setup")
+
+        client = self._catalog.get_client(self.source)
+        if client.protocol == HfClient.protocol:
+            return False
+
+        await client._download(self, callback=download_cb or self._download_cb)
+        self._set_stream(
+            self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
+        )
+        return True
 
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
datachain/lib/hf.py
CHANGED
@@ -95,7 +95,7 @@ class HFGenerator(Generator):
         ds = self.ds_dict[split]
         if split:
             desc += f" split '{split}'"
-        with tqdm(desc=desc, unit=" rows") as pbar:
+        with tqdm(desc=desc, unit=" rows", leave=False) as pbar:
             for row in ds:
                 output_dict = {}
                 if split and "split" in self.output_schema.model_fields:
datachain/lib/pytorch.py
CHANGED
@@ -1,5 +1,8 @@
 import logging
-
+import os
+import weakref
+from collections.abc import Generator, Iterable, Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -9,15 +12,19 @@ from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback
 
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
 
+    from datachain.cache import DataChainCache as Cache
+
 
 logger = logging.getLogger("datachain")
 
@@ -75,6 +82,19 @@ class PytorchDataset(IterableDataset):
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch
 
+        self._cache = catalog.cache
+        self._prefetch_cache: Optional[Cache] = None
+        if prefetch and not self.cache:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+            self._cache = self._prefetch_cache
+            weakref.finalize(self, self._prefetch_cache.destroy)
+
+    def close(self) -> None:
+        if self._prefetch_cache:
+            self._prefetch_cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -89,9 +109,15 @@ class PytorchDataset(IterableDataset):
         ms = ms_cls(*ms_args, **ms_kwargs)
         wh_cls, wh_args, wh_kwargs = self._wh_params
         wh = wh_cls(*wh_args, **wh_kwargs)
-
+        catalog = Catalog(ms, wh, **self._catalog_params)
+        catalog.cache = self._cache
+        return catalog
 
-    def
+    def _row_iter(
+        self,
+        total_rank: int,
+        total_workers: int,
+    ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
         ds = DataChain.from_dataset(
@@ -104,16 +130,34 @@ class PytorchDataset(IterableDataset):
             ds = ds.chunk(total_rank, total_workers)
         yield from ds.collect()
 
-    def
-
-        rows = self._rows_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+    def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
+        from datachain.lib.udf import _prefetch_inputs
 
-
+        total_rank, total_workers = self.get_rank_and_workers()
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}",
+                position=total_rank,
+                leave=True,
+            )
+
+        rows = self._row_iter(total_rank, total_workers)
+        rows = _prefetch_inputs(
+            rows,
+            self.prefetch,
+            download_cb=download_cb,
+            after_prefetch=download_cb.increment_file_count,
+        )
+
+        with download_cb, closing(rows):
+            yield from rows
+
+    def __iter__(self) -> Iterator[list[Any]]:
+        with closing(self._iter_with_prefetch()) as rows:
+            yield from map(self._process_row, rows)
+
+    def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
         row = []
        for fr in row_features:
            if hasattr(fr, "read"):
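For consumers, prefetched files now land in a per-dataset temporary cache that is destroyed when the dataset is garbage-collected or via the new `close()`, and per-rank download bars can be enabled through `DATACHAIN_SHOW_PREFETCH_PROGRESS`. A rough sketch of the intended usage (the chain is illustrative, and `to_pytorch()` is assumed to return this `PytorchDataset`):

```py
import os
from torch.utils.data import DataLoader

os.environ["DATACHAIN_SHOW_PREFETCH_PROGRESS"] = "1"  # opt in to per-rank progress bars

# `chain` is assumed to be a DataChain whose rows carry File objects.
dataset = chain.settings(prefetch=4, cache=False).to_pytorch()
loader = DataLoader(dataset, batch_size=16, num_workers=2)
try:
    for batch in loader:
        ...  # training step
finally:
    dataset.close()  # drop the temporary prefetch cache explicitly
```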
datachain/lib/udf.py
CHANGED
@@ -1,14 +1,16 @@
-import contextlib
 import sys
 import traceback
-from collections.abc import Iterable, Iterator, Mapping, Sequence
-from
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from contextlib import closing, nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
 from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -21,17 +23,22 @@ from datachain.query.batch import (
     Partition,
     RowsOutputBatch,
 )
+from datachain.utils import safe_closing
 
 if TYPE_CHECKING:
     from collections import abc
+    from contextlib import AbstractContextManager
 
     from typing_extensions import Self
 
+    from datachain.cache import DataChainCache as Cache
     from datachain.catalog import Catalog
     from datachain.lib.signal_schema import SignalSchema
     from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput
 
+T = TypeVar("T", bound=Sequence[Any])
+
 
 class UdfError(DataChainParamsError):
     def __init__(self, msg):
@@ -98,6 +105,10 @@ class UDFAdapter:
             processed_cb,
         )
 
+    @property
+    def prefetch(self) -> int:
+        return self.inner.prefetch
+
 
 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -148,12 +159,11 @@ class UDFBase(AbstractUDF):
     """
 
     is_output_batched = False
-
+    prefetch: int = 0
 
     def __init__(self):
         self.params: Optional[SignalSchema] = None
         self.output = None
-        self.catalog = None
         self._func = None
 
     def process(self, *args, **kwargs):
@@ -242,26 +252,23 @@ class UDFBase(AbstractUDF):
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]
 
     def _parse_row(
-        self, row_dict: RowDict, cache: bool, download_cb: Callback
+        self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
     ) -> list[DataValue]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
             if isinstance(obj, File):
-
-                obj._set_stream(
-                    self.catalog, caching_enabled=cache, download_cb=download_cb
-                )
+                obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
         return obj_row
 
-    def _prepare_row(self, row, udf_fields, cache, download_cb):
+    def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        return self._parse_row(row_dict, cache, download_cb)
+        return self._parse_row(row_dict, catalog, cache, download_cb)
 
-    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+    def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        udf_input = self._parse_row(row_dict, cache, download_cb)
+        udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
     def process_safe(self, obj_rows):
@@ -279,13 +286,47 @@
         return result_objs
 
 
-
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(
+    row: T,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> T:
     for obj in row:
-        if isinstance(obj, File):
-
+        if isinstance(obj, File) and await obj._prefetch(download_cb):
+            after_prefetch()
     return row
 
 
+def _prefetch_inputs(
+    prepared_inputs: "Iterable[T]",
+    prefetch: int = 0,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> "abc.Generator[T, None, None]":
+    if prefetch > 0:
+        f = partial(
+            _prefetch_input,
+            download_cb=download_cb,
+            after_prefetch=after_prefetch,
+        )
+        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
+    yield from prepared_inputs
+
+
+def _get_cache(
+    cache: "Cache", prefetch: int = 0, use_cache: bool = False
+) -> "AbstractContextManager[Cache]":
+    tmp_dir = cache.tmp_dir
+    assert tmp_dir
+    if prefetch and not use_cache:
+        return temporary_cache(tmp_dir, prefix="prefetch-")
+    return nullcontext(cache)
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
@@ -300,18 +341,18 @@ class Mapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
-            for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
 
-
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
                 udf_output = self._flatten_row(result_objs)
@@ -336,14 +377,15 @@ class BatchMapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
 
         for batch in udf_inputs:
             n_rows = len(batch.rows)
             row_ids, *udf_args = zip(
                 *[
-                    self._prepare_row_and_id(
+                    self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
                     for row in batch.rows
                 ]
             )
@@ -378,17 +420,18 @@ class Generator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
 
-
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for row in prepared_inputs:
                 result_objs = self.process_safe(row)
                 udf_outputs = (self._flatten_row(row) for row in result_objs)
@@ -413,13 +456,12 @@ class Aggregator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
 
         for batch in udf_inputs:
            udf_args = zip(
                *[
-                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                    for row in batch.rows
                ]
            )
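The net effect for UDF authors is that `File` inputs can be downloaded ahead of each UDF call, into a temporary cache unless caching is enabled. A hedged sketch of a mapper that benefits from this (the storage URI and signal names are illustrative):

```py
from datachain import DataChain
from datachain.lib.file import File

def file_size(file: File) -> int:
    # By the time the mapper runs, the file may already be prefetched locally.
    return len(file.read())

chain = (
    DataChain.from_storage("s3://my-bucket/docs/")
    .settings(prefetch=8)  # number of files downloaded ahead of processing
    .map(size=file_size, output=int)
)
```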
datachain/listing.py
CHANGED
datachain/progress.py
CHANGED
@@ -5,6 +5,7 @@ import sys
 from threading import RLock
 from typing import Any, ClassVar
 
+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm
 
@@ -132,8 +133,24 @@ class Tqdm(tqdm):
         return d
 
 
-class CombinedDownloadCallback(
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
         pass
+
+    def increment_file_count(self, n: int = 1) -> None:
+        pass
+
+
+class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
+    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
+        self.files_count = 0
+        tqdm_kwargs = tqdm_kwargs or {}
+        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
+        super().__init__(tqdm_kwargs, *args, **kwargs)
+
+    def increment_file_count(self, n: int = 1) -> None:
+        self.files_count += n
+        if self.tqdm is not None:
+            self.tqdm.postfix = f"{self.files_count} files"