datachain 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release.



@@ -8,6 +8,8 @@ from datachain.sql.functions import conditional
 
 from .func import ColT, Func
 
+CaseT = Union[int, float, complex, bool, str]
+
 
 def greatest(*args: Union[ColT, float]) -> Func:
     """
@@ -85,9 +87,7 @@ def least(*args: Union[ColT, float]) -> Func:
     )
 
 
-def case(
-    *args: tuple[BinaryExpression, Union[int, float, complex, bool, str]], else_=None
-) -> Func:
+def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
     """
     Returns the case function that produces case expression which has a list of
     conditions and corresponding results. Results can only be python primitives
@@ -108,26 +108,48 @@ def case(
             res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
         )
         ```
-
-    Note:
-        - Result column will always be of the same type as the input columns.
     """
     supported_types = [int, float, complex, str, bool]
 
     type_ = type(else_) if else_ else None
 
     if not args:
-        raise DataChainParamsError("Missing case statements")
+        raise DataChainParamsError("Missing statements")
 
     for arg in args:
         if type_ and not isinstance(arg[1], type_):
-            raise DataChainParamsError("Case statement values must be of the same type")
+            raise DataChainParamsError("Statement values must be of the same type")
         type_ = type(arg[1])
 
         if type_ not in supported_types:
            raise DataChainParamsError(
-                f"Case supports only python literals ({supported_types}) for values"
+                f"Only python literals ({supported_types}) are supported for values"
            )
 
    kwargs = {"else_": else_}
    return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
+
+
+def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
+    """
+    Returns the ifelse function that produces if expression which has a condition
+    and values for true and false outcome. Results can only be python primitives
+    like string, numbes or booleans. Result type is inferred from the values.
+
+    Args:
+        condition: BinaryExpression - condition which is evaluated
+        if_val: (str | int | float | complex | bool): value for true condition outcome
+        else_val: (str | int | float | complex | bool): value for false condition
+            outcome
+
+    Returns:
+        Func: A Func object that represents the ifelse function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.ifelse(C("num") > 0, "P", "N"),
+        )
+        ```
+    """
+    return case((condition, if_val), else_=else_val)
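
The hunks above add a `CaseT` alias, tighten the `case()` signature, and introduce an `ifelse()` helper built on top of `case()`. Below is a minimal usage sketch based on the docstring examples in the diff; the sample column `num` is made up and import paths may vary slightly between datachain versions.

```py
from datachain import C, DataChain, func

# hypothetical in-memory chain with a numeric column
chain = DataChain.from_values(num=[-2, 0, 3])

chain = chain.mutate(
    sign=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
    positive=func.ifelse(C("num") > 0, "yes", "no"),  # sugar for a single-branch case
)
```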
datachain/lib/arrow.py CHANGED
@@ -91,7 +91,9 @@ class ArrowGenerator(Generator):
             yield from record_batch.to_pylist()
 
         it = islice(iter_records(), self.nrows)
-        with tqdm(it, desc="Parsed by pyarrow", unit="rows", total=self.nrows) as pbar:
+        with tqdm(
+            it, desc="Parsed by pyarrow", unit="rows", total=self.nrows, leave=False
+        ) as pbar:
             for index, record in enumerate(pbar):
                 yield self._process_record(
                     record, file, index, hf_schema, use_datachain_schema
datachain/lib/dc.py CHANGED
@@ -451,6 +451,7 @@ class DataChain:
             return dc
 
         if update or not list_ds_exists:
+            # disable prefetch for listing, as it pre-downloads all files
             (
                 cls.from_records(
                     DataChain.DEFAULT_FILE_RECORD,
@@ -458,6 +459,7 @@
                     settings=settings,
                     in_memory=in_memory,
                 )
+                .settings(prefetch=0)
                 .gen(
                     list_bucket(list_uri, cache, client_config=client_config),
                     output={f"{object_name}": File},
@@ -1534,7 +1536,7 @@
 
         Example:
             ```py
-            diff = persons.diff(
+            res = persons.compare(
                 new_persons,
                 on=["id"],
                 right_on=["other_id"],
@@ -1547,9 +1549,9 @@
             )
             ```
         """
-        from datachain.lib.diff import compare as chain_compare
+        from datachain.diff import _compare
 
-        return chain_compare(
+        return _compare(
             self,
             other,
             on,
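
Two behavioral notes from the `dc.py` hunks: listings now run with `.settings(prefetch=0)` so files are not pre-downloaded while a bucket is being listed, and the docstring example now calls `compare()`, which delegates to `datachain.diff._compare`. A sketch of the documented call, where `persons` and `new_persons` are hypothetical chains keyed by an id column:

```py
res = persons.compare(
    new_persons,
    on=["id"],             # key column on the left chain
    right_on=["other_id"], # matching key column on the right chain
)
```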
datachain/lib/file.py CHANGED
@@ -269,10 +269,21 @@ class File(DataModel):
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
-    async def _prefetch(self) -> None:
-        if self._caching_enabled:
-            client = self._catalog.get_client(self.source)
-            await client._download(self, callback=self._download_cb)
+    async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
+        from datachain.client.hf import HfClient
+
+        if self._catalog is None:
+            raise RuntimeError("cannot prefetch file because catalog is not setup")
+
+        client = self._catalog.get_client(self.source)
+        if client.protocol == HfClient.protocol:
+            return False
+
+        await client._download(self, callback=download_cb or self._download_cb)
+        self._set_stream(
+            self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
+        )
+        return True
 
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
datachain/lib/hf.py CHANGED
@@ -95,7 +95,7 @@ class HFGenerator(Generator):
         ds = self.ds_dict[split]
         if split:
             desc += f" split '{split}'"
-        with tqdm(desc=desc, unit=" rows") as pbar:
+        with tqdm(desc=desc, unit=" rows", leave=False) as pbar:
             for row in ds:
                 output_dict = {}
                 if split and "split" in self.output_schema.model_fields:
datachain/lib/pytorch.py CHANGED
@@ -1,5 +1,8 @@
 import logging
-from collections.abc import Iterator
+import os
+import weakref
+from collections.abc import Generator, Iterable, Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from PIL import Image
@@ -9,15 +12,19 @@ from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
 
 from datachain import Session
-from datachain.asyn import AsyncMapper
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback
 
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
 
+    from datachain.cache import DataChainCache as Cache
+
 
 
 logger = logging.getLogger("datachain")
@@ -75,6 +82,19 @@ class PytorchDataset(IterableDataset):
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch
 
+        self._cache = catalog.cache
+        self._prefetch_cache: Optional[Cache] = None
+        if prefetch and not self.cache:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+            self._cache = self._prefetch_cache
+            weakref.finalize(self, self._prefetch_cache.destroy)
+
+    def close(self) -> None:
+        if self._prefetch_cache:
+            self._prefetch_cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -89,9 +109,15 @@
         ms = ms_cls(*ms_args, **ms_kwargs)
         wh_cls, wh_args, wh_kwargs = self._wh_params
         wh = wh_cls(*wh_args, **wh_kwargs)
-        return Catalog(ms, wh, **self._catalog_params)
+        catalog = Catalog(ms, wh, **self._catalog_params)
+        catalog.cache = self._cache
+        return catalog
 
-    def _rows_iter(self, total_rank: int, total_workers: int):
+    def _row_iter(
+        self,
+        total_rank: int,
+        total_workers: int,
+    ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
         ds = DataChain.from_dataset(
@@ -104,16 +130,34 @@
             ds = ds.chunk(total_rank, total_workers)
         yield from ds.collect()
 
-    def __iter__(self) -> Iterator[Any]:
-        total_rank, total_workers = self.get_rank_and_workers()
-        rows = self._rows_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+    def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
+        from datachain.lib.udf import _prefetch_inputs
 
-    def _process_row(self, row_features):
+        total_rank, total_workers = self.get_rank_and_workers()
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}",
+                position=total_rank,
+                leave=True,
+            )
+
+        rows = self._row_iter(total_rank, total_workers)
+        rows = _prefetch_inputs(
+            rows,
+            self.prefetch,
+            download_cb=download_cb,
+            after_prefetch=download_cb.increment_file_count,
+        )
+
+        with download_cb, closing(rows):
+            yield from rows
+
+    def __iter__(self) -> Iterator[list[Any]]:
+        with closing(self._iter_with_prefetch()) as rows:
+            yield from map(self._process_row, rows)
+
+    def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
         row = []
         for fr in row_features:
             if hasattr(fr, "read"):
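
The `pytorch.py` changes move prefetching into a dedicated `_iter_with_prefetch()` step backed by a temporary cache that is destroyed when the dataset is garbage-collected or explicitly closed, with per-worker download progress available behind the `DATACHAIN_SHOW_PREFETCH_PROGRESS` environment variable. A hedged sketch of how this is typically driven from user code; the dataset name and loader parameters are invented, while `to_pytorch()` and `settings(prefetch=...)` are existing datachain APIs:

```py
import os

from torch.utils.data import DataLoader

from datachain import DataChain

os.environ["DATACHAIN_SHOW_PREFETCH_PROGRESS"] = "1"  # opt in to per-worker progress bars

ds = (
    DataChain.from_dataset("images")    # hypothetical dataset name
    .settings(prefetch=8, cache=False)  # prefetch into a temporary cache
    .to_pytorch()
)

loader = DataLoader(ds, batch_size=16, num_workers=2)
for batch in loader:
    ...
ds.close()  # new in this release: drops the temporary prefetch cache explicitly
```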
datachain/lib/udf.py CHANGED
@@ -1,14 +1,16 @@
-import contextlib
 import sys
 import traceback
-from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from contextlib import closing, nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
 from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -21,17 +23,22 @@ from datachain.query.batch import (
     Partition,
     RowsOutputBatch,
 )
+from datachain.utils import safe_closing
 
 if TYPE_CHECKING:
     from collections import abc
+    from contextlib import AbstractContextManager
 
     from typing_extensions import Self
 
+    from datachain.cache import DataChainCache as Cache
     from datachain.catalog import Catalog
     from datachain.lib.signal_schema import SignalSchema
     from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput
 
+T = TypeVar("T", bound=Sequence[Any])
+
 
 class UdfError(DataChainParamsError):
     def __init__(self, msg):
@@ -98,6 +105,10 @@ class UDFAdapter:
             processed_cb,
         )
 
+    @property
+    def prefetch(self) -> int:
+        return self.inner.prefetch
+
 
 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -148,12 +159,11 @@
     """
 
     is_output_batched = False
-    catalog: "Optional[Catalog]"
+    prefetch: int = 0
 
     def __init__(self):
         self.params: Optional[SignalSchema] = None
         self.output = None
-        self.catalog = None
         self._func = None
 
     def process(self, *args, **kwargs):
@@ -242,26 +252,23 @@
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]
 
     def _parse_row(
-        self, row_dict: RowDict, cache: bool, download_cb: Callback
+        self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
     ) -> list[DataValue]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
             if isinstance(obj, File):
-                assert self.catalog is not None
-                obj._set_stream(
-                    self.catalog, caching_enabled=cache, download_cb=download_cb
-                )
+                obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
         return obj_row
 
-    def _prepare_row(self, row, udf_fields, cache, download_cb):
+    def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        return self._parse_row(row_dict, cache, download_cb)
+        return self._parse_row(row_dict, catalog, cache, download_cb)
 
-    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+    def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
         row_dict = RowDict(zip(udf_fields, row))
-        udf_input = self._parse_row(row_dict, cache, download_cb)
+        udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
     def process_safe(self, obj_rows):
@@ -279,13 +286,47 @@
         return result_objs
 
 
-async def _prefetch_input(row):
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(
+    row: T,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> T:
     for obj in row:
-        if isinstance(obj, File):
-            await obj._prefetch()
+        if isinstance(obj, File) and await obj._prefetch(download_cb):
+            after_prefetch()
     return row
 
 
+def _prefetch_inputs(
+    prepared_inputs: "Iterable[T]",
+    prefetch: int = 0,
+    download_cb: Optional["Callback"] = None,
+    after_prefetch: "Callable[[], None]" = noop,
+) -> "abc.Generator[T, None, None]":
+    if prefetch > 0:
+        f = partial(
+            _prefetch_input,
+            download_cb=download_cb,
+            after_prefetch=after_prefetch,
+        )
+        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
+    yield from prepared_inputs
+
+
+def _get_cache(
+    cache: "Cache", prefetch: int = 0, use_cache: bool = False
+) -> "AbstractContextManager[Cache]":
+    tmp_dir = cache.tmp_dir
+    assert tmp_dir
+    if prefetch and not use_cache:
+        return temporary_cache(tmp_dir, prefix="prefetch-")
+    return nullcontext(cache)
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
@@ -300,18 +341,18 @@
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
-            for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
 
-        with contextlib.closing(prepared_inputs):
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
                 udf_output = self._flatten_row(result_objs)
@@ -336,14 +377,15 @@ class BatchMapper(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
 
         for batch in udf_inputs:
             n_rows = len(batch.rows)
             row_ids, *udf_args = zip(
                 *[
-                    self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+                    self._prepare_row_and_id(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
                     for row in batch.rows
                 ]
             )
@@ -378,17 +420,18 @@ class Generator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
-        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
-            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
-        )
-        if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
 
-        with contextlib.closing(prepared_inputs):
+        def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
+            with safe_closing(udf_inputs):
+                for row in udf_inputs:
+                    yield self._prepare_row(
+                        row, udf_fields, catalog, cache, download_cb
+                    )
+
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
             for row in prepared_inputs:
                 result_objs = self.process_safe(row)
                 udf_outputs = (self._flatten_row(row) for row in result_objs)
@@ -413,13 +456,12 @@ class Aggregator(UDFBase):
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable[UDFResult]]:
-        self.catalog = catalog
         self.setup()
 
         for batch in udf_inputs:
            udf_args = zip(
                *[
-                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                    for row in batch.rows
                ]
            )
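
The UDF changes replace the per-class `AsyncMapper` wiring with module-level helpers: `_prefetch_inputs()` wraps an iterator of prepared rows and drives `File._prefetch()` concurrently when `prefetch > 0`, while `_get_cache()` swaps in a temporary cache when prefetching without caching. A hedged sketch of the control flow these internal helpers implement (not public API; `prepared_rows` stands in for the generator built inside `run()`):

```py
from contextlib import closing

from datachain.lib.udf import _prefetch_inputs

def consume(prepared_rows, prefetch=4):
    rows = _prefetch_inputs(prepared_rows, prefetch)
    with closing(rows):  # closing() shuts down the generator and its AsyncMapper
        for row in rows:
            ...          # any File objects in `row` have already been downloaded
```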
datachain/listing.py CHANGED
@@ -153,6 +153,7 @@ class Listing:
             unit_scale=True,
             unit_divisor=1000,
             total=total_files,
+            leave=False,
         )
 
         counter = 0
datachain/progress.py CHANGED
@@ -5,6 +5,7 @@ import sys
 from threading import RLock
 from typing import Any, ClassVar
 
+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm
 
@@ -132,8 +133,24 @@ class Tqdm(tqdm):
         return d
 
 
-class CombinedDownloadCallback(TqdmCallback):
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
         pass
+
+    def increment_file_count(self, n: int = 1) -> None:
+        pass
+
+
+class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
+    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
+        self.files_count = 0
+        tqdm_kwargs = tqdm_kwargs or {}
+        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
+        super().__init__(tqdm_kwargs, *args, **kwargs)
+
+    def increment_file_count(self, n: int = 1) -> None:
+        self.files_count += n
+        if self.tqdm is not None:
+            self.tqdm.postfix = f"{self.files_count} files"
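
`CombinedDownloadCallback` is now a plain fsspec `Callback` subclass with a no-op file counter, and the tqdm-rendered variant lives in the new `TqdmCombinedDownloadCallback`. A hedged sketch of how the counter hook is meant to be used; the byte count is invented, and `relative_update` is the standard fsspec callback method:

```py
from datachain.progress import TqdmCombinedDownloadCallback

cb = TqdmCombinedDownloadCallback(
    tqdm_kwargs={"desc": "Download", "unit": "B", "unit_scale": True}
)
with cb:                      # fsspec callbacks support the context-manager protocol
    cb.relative_update(1024)  # add downloaded bytes to the combined counter
    cb.increment_file_count() # bump the "N files" postfix on the progress bar
```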