datachain 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -1,10 +1,8 @@
  def add_studio_parser(subparsers, parent_parser) -> None:
- studio_help = "Commands to authenticate DataChain with Iterative Studio"
+ studio_help = "Manage Studio authentication"
  studio_description = (
- "Authenticate DataChain with Studio and set the token. "
- "Once this token has been properly configured,\n"
- "DataChain will utilize it for seamlessly sharing datasets\n"
- "and using Studio features from CLI"
+ "Manage authentication and settings for Studio. "
+ "Configure tokens for sharing datasets and using Studio features."
  )

  studio_parser = subparsers.add_parser(
@@ -15,14 +13,13 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  )
  studio_subparser = studio_parser.add_subparsers(
  dest="cmd",
- help="Use `DataChain studio CMD --help` to display command-specific help.",
- required=True,
+ help="Use `datachain studio CMD --help` to display command-specific help",
  )

- studio_login_help = "Authenticate DataChain with Studio host"
+ studio_login_help = "Authenticate with Studio"
  studio_login_description = (
- "By default, this command authenticates the DataChain with Studio\n"
- "using default scopes and assigns a random name as the token name."
+ "Authenticate with Studio using default scopes. "
+ "A random name will be assigned as the token name if not specified."
  )
  login_parser = studio_subparser.add_parser(
  "login",
@@ -36,14 +33,14 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  "--hostname",
  action="store",
  default=None,
- help="The hostname of the Studio instance to authenticate with.",
+ help="Hostname of the Studio instance",
  )
  login_parser.add_argument(
  "-s",
  "--scopes",
  action="store",
  default=None,
- help="The scopes for the authentication token. ",
+ help="Authentication token scopes",
  )

  login_parser.add_argument(
@@ -51,21 +48,20 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  "--name",
  action="store",
  default=None,
- help="The name of the authentication token. It will be used to\n"
- "identify token shown in Studio profile.",
+ help="Authentication token name (shown in Studio profile)",
  )

  login_parser.add_argument(
  "--no-open",
  action="store_true",
  default=False,
- help="Use authentication flow based on user code.\n"
- "You will be presented with user code to enter in browser.\n"
- "DataChain will also use this if it cannot launch browser on your behalf.",
+ help="Use code-based authentication without browser",
  )

- studio_logout_help = "Logout user from Studio"
- studio_logout_description = "This removes the studio token from your global config."
+ studio_logout_help = "Log out from Studio"
+ studio_logout_description = (
+ "Remove the Studio authentication token from global config."
+ )

  studio_subparser.add_parser(
  "logout",
@@ -74,10 +70,8 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  help=studio_logout_help,
  )

- studio_team_help = "Set the default team for DataChain"
- studio_team_description = (
- "Set the default team for DataChain to use when interacting with Studio."
- )
+ studio_team_help = "Set default team for Studio operations"
+ studio_team_description = "Set the default team for Studio operations."

  team_parser = studio_subparser.add_parser(
  "team",
@@ -88,39 +82,21 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  team_parser.add_argument(
  "team_name",
  action="store",
- help="The name of the team to set as the default.",
+ help="Name of the team to set as default",
  )
  team_parser.add_argument(
  "--global",
  action="store_true",
  default=False,
- help="Set the team globally for all DataChain projects.",
+ help="Set team globally for all projects",
  )

- studio_token_help = "View the token datachain uses to contact Studio" # noqa: S105 # nosec B105
+ studio_token_help = "View Studio authentication token" # noqa: S105
+ studio_token_description = "Display the current authentication token for Studio." # noqa: S105

  studio_subparser.add_parser(
  "token",
  parents=[parent_parser],
- description=studio_token_help,
+ description=studio_token_description,
  help=studio_token_help,
  )
-
- studio_ls_dataset_help = "List the available datasets from Studio"
- studio_ls_dataset_description = (
- "This command lists all the datasets available in Studio.\n"
- "It will show the dataset name and the number of versions available."
- )
-
- ls_dataset_parser = studio_subparser.add_parser(
- "dataset",
- parents=[parent_parser],
- description=studio_ls_dataset_description,
- help=studio_ls_dataset_help,
- )
- ls_dataset_parser.add_argument(
- "--team",
- action="store",
- default=None,
- help="The team to list datasets for. By default, it will use team from config.",
- )
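The net effect of these hunks is shorter help strings, a no-longer-required `cmd` subcommand, and removal of the `studio dataset` subcommand, leaving `login`, `logout`, `team`, and `token`. A standalone argparse sketch of the resulting command surface (illustrative only, not the actual datachain wiring; argument lists are abbreviated):

```python
# Illustrative sketch: mirrors the studio subcommands left after this change.
import argparse

parser = argparse.ArgumentParser(prog="datachain")
subparsers = parser.add_subparsers(dest="command")

studio = subparsers.add_parser("studio", description="Manage Studio authentication")
studio_sub = studio.add_subparsers(
    dest="cmd",
    help="Use `datachain studio CMD --help` to display command-specific help",
)

login = studio_sub.add_parser("login", help="Authenticate with Studio")
login.add_argument("--hostname", default=None, help="Hostname of the Studio instance")
login.add_argument("--name", default=None, help="Authentication token name")
login.add_argument("--no-open", action="store_true", help="Code-based auth without browser")

studio_sub.add_parser("logout", help="Log out from Studio")
team = studio_sub.add_parser("team", help="Set default team for Studio operations")
team.add_argument("team_name", help="Name of the team to set as default")
studio_sub.add_parser("token", help="View Studio authentication token")

args = parser.parse_args(["studio", "login", "--name", "ci-token", "--no-open"])
print(args.cmd, args.name, args.no_open)  # -> login ci-token True
```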
@@ -30,7 +30,7 @@ def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Act
  "sources",
  type=str,
  nargs=nargs,
- help="Data sources - paths to cloud storage dirs",
+ help="Data sources - paths to cloud storage directories",
  )


datachain/client/azure.py CHANGED
@@ -2,7 +2,7 @@ from typing import Any, Optional
  from urllib.parse import parse_qs, urlsplit, urlunsplit

  from adlfs import AzureBlobFileSystem
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.lib.file import File

@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError
  from dvc_objects.fs.system import reflink
  from fsspec.asyn import get_loop, sync
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.cache import DataChainCache
  from datachain.client.fileslice import FileWrapper
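The recurring import change in these hunks (and in several modules below) swaps `tqdm` for `tqdm.auto`, which selects the notebook-friendly widget bar when running under Jupyter and falls back to the plain console bar elsewhere. A minimal illustration (not datachain code):

```python
# tqdm.auto resolves to the Jupyter widget bar in notebooks,
# and to the standard terminal bar everywhere else.
from tqdm.auto import tqdm

for _ in tqdm(range(3), desc="demo"):
    pass
```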
datachain/client/gcs.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any, Optional, cast

  from dateutil.parser import isoparse
  from gcsfs import GCSFileSystem
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.lib.file import File

datachain/client/local.py CHANGED
@@ -38,7 +38,7 @@ class FileClient(Client):
  def get_uri(cls, name: str) -> "StorageURI":
  from datachain.dataset import StorageURI

- return StorageURI(f'{cls.PREFIX}/{name.removeprefix("/")}')
+ return StorageURI(f"{cls.PREFIX}/{name.removeprefix('/')}")

  @classmethod
  def ls_buckets(cls, **kwargs):
datachain/client/s3.py CHANGED
@@ -5,7 +5,7 @@ from urllib.parse import parse_qs, urlsplit, urlunsplit

  from botocore.exceptions import NoCredentialsError
  from s3fs import S3FileSystem
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.lib.file import File

@@ -21,7 +21,7 @@ from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
  from sqlalchemy.sql import func
  from sqlalchemy.sql.expression import bindparam, cast
  from sqlalchemy.sql.selectable import Select
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  import datachain.sql.sqlite
  from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
@@ -14,7 +14,7 @@ import sqlalchemy as sa
  from sqlalchemy import Table, case, select
  from sqlalchemy.sql import func
  from sqlalchemy.sql.expression import true
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.client import Client
  from datachain.data_storage.schema import convert_rows_custom_column_types
datachain/lib/arrow.py CHANGED
@@ -7,7 +7,7 @@ import orjson
  import pyarrow as pa
  from fsspec.core import split_protocol
  from pyarrow.dataset import CsvFileFormat, dataset
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.lib.data_model import dict_to_data_model
  from datachain.lib.file import ArrowRow, File
@@ -33,7 +33,7 @@ class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
  # reads the whole file in-memory.
  (uri,) = self.references[path]
  protocol, _ = split_protocol(uri)
- return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+ return self.fss[protocol].open(uri, mode, *args, **kwargs)


  class ArrowGenerator(Generator):
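The `_open` → `open` change here swaps a private fsspec hook for the public `AbstractFileSystem.open()` entry point, which wraps `_open` and takes care of block size, compression, and text-mode handling. A small standalone illustration using fsspec's in-memory filesystem (not datachain code):

```python
# Prefer the public open() over the internal _open() hook.
import fsspec

fs = fsspec.filesystem("memory")
with fs.open("demo.txt", "wb") as f:   # public API; delegates to _open internally
    f.write(b"hello")
with fs.open("demo.txt", "rb") as f:
    print(f.read())  # b'hello'
```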
@@ -35,8 +35,7 @@ def unflatten_to_json_pos(
  def _normalize(name: str) -> str:
  if DEFAULT_DELIMITER in name:
  raise RuntimeError(
- f"variable '{name}' cannot be used "
- f"because it contains {DEFAULT_DELIMITER}"
+ f"variable '{name}' cannot be used because it contains {DEFAULT_DELIMITER}"
  )
  return _to_snake_case(name)

datachain/lib/dc.py CHANGED
@@ -11,6 +11,7 @@ from typing import (
  BinaryIO,
  Callable,
  ClassVar,
+ Literal,
  Optional,
  TypeVar,
  Union,
@@ -1276,7 +1277,12 @@ class DataChain:
  yield ret[0] if len(cols) == 1 else tuple(ret)

  def to_pytorch(
- self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
+ self,
+ transform=None,
+ tokenizer=None,
+ tokenizer_kwargs=None,
+ num_samples=0,
+ remove_prefetched: bool = False,
  ):
  """Convert to pytorch dataset format.

@@ -1286,6 +1292,7 @@ class DataChain:
  tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
  num_samples (int): Number of random samples to draw for each epoch.
  This argument is ignored if `num_samples=0` (the default).
+ remove_prefetched (bool): Whether to remove prefetched files after reading.

  Example:
  ```py
@@ -1312,6 +1319,7 @@ class DataChain:
  tokenizer_kwargs=tokenizer_kwargs,
  num_samples=num_samples,
  dc_settings=chain._settings,
+ remove_prefetched=remove_prefetched,
  )

  def remove_file_signals(self) -> "Self": # noqa: D102
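For context, the new `remove_prefetched` flag is forwarded to `PytorchDataset` (see `datachain/lib/pytorch.py` below) and causes prefetched files to be evicted from the temporary cache once each row has been consumed. A hedged usage sketch; the bucket URI is hypothetical, `settings(prefetch=...)` is assumed available, and the chain's signals are assumed collate-friendly:

```python
# Sketch: stream a chain to PyTorch and drop prefetched files after use.
from torch.utils.data import DataLoader
from datachain import DataChain

chain = (
    DataChain.from_storage("s3://example-bucket/images/")  # hypothetical source
    .settings(prefetch=8)                                   # prefetch files ahead of the consumer
)
loader = DataLoader(
    chain.to_pytorch(remove_prefetched=True),  # new flag in this release
    batch_size=16,
)
for batch in loader:
    ...
```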
@@ -1330,19 +1338,27 @@

  Parameters:
  right_ds: Chain to join with.
- on: Predicate or list of Predicates to join on. If both chains have the
- same predicates then this predicate is enough for the join. Otherwise,
- `right_on` parameter has to specify the predicates for the other chain.
- right_on: Optional predicate or list of Predicates
- for the `right_ds` to join.
+ on: Predicate ("column.name", C("column.name"), or Func) or list of
+ Predicates to join on. If both chains have the same predicates then
+ this predicate is enough for the join. Otherwise, `right_on` parameter
+ has to specify the predicates for the other chain.
+ right_on: Optional predicate or list of Predicates for the `right_ds`
+ to join.
  inner (bool): Whether to run inner join or outer join.
- rname (str): name prefix for conflicting signal names.
+ rname (str): Name prefix for conflicting signal names.

- Example:
+ Examples:
  ```py
  meta = meta_emd.merge(meta_pq, on=(C.name, C.emd__index),
  right_on=(C.name, C.pq__index))
  ```
+
+ ```py
+ imgs.merge(captions,
+ on=func.path.file_stem(imgs.c("file.path")),
+ right_on=func.path.file_stem(captions.c("file.path"))
+ ```
+ )
  """
  if on is None:
  raise DatasetMergeError(["None"], None, "'on' must be specified")
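Note that the second docstring example added in this hunk appears to leave the closing parenthesis of the `merge()` call outside the fenced block (the stray `)` after the closing backticks). The intended snippet presumably reads as below; `imgs` and `captions` are the chains from the surrounding example, and the column names are illustrative:

```python
# Corrected form of the added docstring example (docstring fragment, not standalone code).
imgs.merge(
    captions,
    on=func.path.file_stem(imgs.c("file.path")),
    right_on=func.path.file_stem(captions.c("file.path")),
)
```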
@@ -2407,11 +2423,22 @@
  def export_files(
  self,
  output: str,
- signal="file",
+ signal: str = "file",
  placement: FileExportPlacement = "fullpath",
  use_cache: bool = True,
+ link_type: Literal["copy", "symlink"] = "copy",
  ) -> None:
- """Method that exports all files from chain to some folder."""
+ """Export files from a specified signal to a directory.
+
+ Args:
+ output: Path to the target directory for exporting files.
+ signal: Name of the signal to export files from.
+ placement: The method to use for naming exported files.
+ The possible values are: "filename", "etag", "fullpath", and "checksum".
+ use_cache: If `True`, cache the files before exporting.
+ link_type: Method to use for exporting files.
+ Falls back to `'copy'` if symlinking fails.
+ """
  if placement == "filename" and (
  self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
  != self._query.count()
@@ -2419,7 +2446,7 @@
  raise ValueError("Files with the same name found")

  for file in self.collect(signal):
- file.export(output, placement, use_cache) # type: ignore[union-attr]
+ file.export(output, placement, use_cache, link_type=link_type) # type: ignore[union-attr]

  def shuffle(self) -> "Self":
  """Shuffle the rows of the chain deterministically."""
datachain/lib/file.py CHANGED
@@ -1,3 +1,4 @@
+ import errno
  import hashlib
  import io
  import json
@@ -76,18 +77,18 @@ class TarVFile(VFile):
  def open(cls, file: "File", location: list[dict]):
  """Stream file from tar archive based on location in archive."""
  if len(location) > 1:
- VFileError(file, "multiple 'location's are not supported yet")
+ raise VFileError(file, "multiple 'location's are not supported yet")

  loc = location[0]

  if (offset := loc.get("offset", None)) is None:
- VFileError(file, "'offset' is not specified")
+ raise VFileError(file, "'offset' is not specified")

  if (size := loc.get("size", None)) is None:
- VFileError(file, "'size' is not specified")
+ raise VFileError(file, "'size' is not specified")

  if (parent := loc.get("parent", None)) is None:
- VFileError(file, "'parent' is not specified")
+ raise VFileError(file, "'parent' is not specified")

  tar_file = File(**parent)
  tar_file._set_stream(file._catalog)
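The four edits above add the missing `raise`: previously each `VFileError(...)` was constructed and immediately discarded, so malformed `location` entries slipped through silently. A minimal illustration of why that matters (generic Python, not datachain code):

```python
# Instantiating an exception does nothing by itself; it must be raised.
def check(value: int) -> None:
    if value < 0:
        ValueError("negative value")         # bug: created, never raised
    if value > 100:
        raise ValueError("value too large")  # correct: actually propagates

check(-5)  # passes silently despite the "error" above
try:
    check(101)
except ValueError as exc:
    print(exc)  # -> value too large
```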
@@ -236,11 +237,26 @@ class File(DataModel):
  with open(destination, mode="wb") as f:
  f.write(self.read())

+ def _symlink_to(self, destination: str):
+ if self.location:
+ raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")
+
+ if self._caching_enabled:
+ self.ensure_cached()
+ source = self.get_local_path()
+ assert source, "File was not cached"
+ elif self.source.startswith("file://"):
+ source = self.get_path()
+ else:
+ raise OSError(errno.EXDEV, "can't link across filesystems")
+ return os.symlink(source, destination)
+
  def export(
  self,
  output: str,
  placement: ExportPlacement = "fullpath",
  use_cache: bool = True,
+ link_type: Literal["copy", "symlink"] = "copy",
  ) -> None:
  """Export file to new location."""
  if use_cache:
@@ -249,6 +265,13 @@
  dst_dir = os.path.dirname(dst)
  os.makedirs(dst_dir, exist_ok=True)

+ if link_type == "symlink":
+ try:
+ return self._symlink_to(dst)
+ except OSError as exc:
+ if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
+ raise
+
  self.save(dst)

  def _set_stream(
datachain/lib/hf.py CHANGED
@@ -29,7 +29,7 @@ from io import BytesIO
  from typing import TYPE_CHECKING, Any, Union

  import PIL
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.lib.arrow import arrow_type_mapper
  from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
datachain/lib/listing.py CHANGED
@@ -113,14 +113,14 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], st
  telemetry.log_param("client", client.PREFIX)

  if not uri.endswith("/") and _isfile(client, uri):
- return None, f'{storage_uri}/{path.lstrip("/")}', path
+ return None, f"{storage_uri}/{path.lstrip('/')}", path
  if uses_glob(path):
  lst_uri_path = posixpath.dirname(path)
  else:
- storage_uri, path = Client.parse_url(f'{uri.rstrip("/")}/')
+ storage_uri, path = Client.parse_url(f"{uri.rstrip('/')}/")
  lst_uri_path = path

- lst_uri = f'{storage_uri}/{lst_uri_path.lstrip("/")}'
+ lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
  ds_name = (
  f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
  )
@@ -180,7 +180,7 @@ def get_listing(
  # for local file system we need to fix listing path / prefix
  # if we are reusing existing listing
  if isinstance(client, FileClient) and listing and listing.name != ds_name:
- list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
+ list_path = f"{ds_name.strip('/').removeprefix(listing.name)}/{list_path}"

  ds_name = listing.name if listing else ds_name

datachain/lib/pytorch.py CHANGED
@@ -50,6 +50,7 @@ class PytorchDataset(IterableDataset):
  tokenizer_kwargs: Optional[dict[str, Any]] = None,
  num_samples: int = 0,
  dc_settings: Optional[Settings] = None,
+ remove_prefetched: bool = False,
  ):
  """
  Pytorch IterableDataset that streams DataChain datasets.
@@ -84,6 +85,7 @@

  self._cache = catalog.cache
  self._prefetch_cache: Optional[Cache] = None
+ self._remove_prefetched = remove_prefetched
  if prefetch and not self.cache:
  tmp_dir = catalog.cache.tmp_dir
  assert tmp_dir
@@ -147,7 +149,7 @@
  rows,
  self.prefetch,
  download_cb=download_cb,
- after_prefetch=download_cb.increment_file_count,
+ remove_prefetched=self._remove_prefetched,
  )

  with download_cb, closing(rows):
datachain/lib/udf.py CHANGED
@@ -16,6 +16,7 @@ from datachain.lib.convert.flatten import flatten
  from datachain.lib.data_model import DataValue
  from datachain.lib.file import File
  from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+ from datachain.progress import CombinedDownloadCallback
  from datachain.query.batch import (
  Batch,
  BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
  return row


+ def _remove_prefetched(row: T) -> None:
+ for obj in row:
+ if isinstance(obj, File):
+ catalog = obj._catalog
+ assert catalog is not None
+ try:
+ catalog.cache.remove(obj)
+ except Exception as e: # noqa: BLE001
+ print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")
+
+
  def _prefetch_inputs(
  prepared_inputs: "Iterable[T]",
  prefetch: int = 0,
  download_cb: Optional["Callback"] = None,
- after_prefetch: "Callable[[], None]" = noop,
+ after_prefetch: Optional[Callable[[], None]] = None,
+ remove_prefetched: bool = False,
  ) -> "abc.Generator[T, None, None]":
- if prefetch > 0:
- f = partial(
- _prefetch_input,
- download_cb=download_cb,
- after_prefetch=after_prefetch,
- )
- prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate() # type: ignore[assignment]
- yield from prepared_inputs
+ if not prefetch:
+ yield from prepared_inputs
+ return
+
+ if after_prefetch is None:
+ after_prefetch = noop
+ if isinstance(download_cb, CombinedDownloadCallback):
+ after_prefetch = download_cb.increment_file_count
+
+ f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
+ mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
+ with closing(mapper.iterate()) as row_iter:
+ for row in row_iter:
+ try:
+ yield row # type: ignore[misc]
+ finally:
+ if remove_prefetched:
+ _remove_prefetched(row)


  def _get_cache(
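The rewritten `_prefetch_inputs` pins down two behaviors: the callback fallback (using `increment_file_count` only when the callback is a `CombinedDownloadCallback`), and per-row cleanup via `try`/`finally` around the `yield`, so a prefetched file is removed once the consumer has finished with that row or stops iterating. The cleanup pattern in isolation (a generic sketch, not datachain code):

```python
from contextlib import closing

def stream_with_cleanup(items, cleanup):
    """Yield each item; run cleanup for it once the consumer advances or closes the generator."""
    for item in items:
        try:
            yield item
        finally:
            cleanup(item)

# closing() guarantees the pending cleanup runs even if the consumer exits early.
with closing(stream_with_cleanup([1, 2, 3], lambda x: print("cleaned", x))) as rows:
    for row in rows:
        print("processing", row)
```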
@@ -351,7 +374,13 @@ class Mapper(UDFBase):
  )

  prepared_inputs = _prepare_rows(udf_inputs)
- prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+ prepared_inputs = _prefetch_inputs(
+ prepared_inputs,
+ self.prefetch,
+ download_cb=download_cb,
+ remove_prefetched=bool(self.prefetch) and not cache,
+ )
+
  with closing(prepared_inputs):
  for id_, *udf_args in prepared_inputs:
  result_objs = self.process_safe(udf_args)
@@ -391,9 +420,9 @@ class BatchMapper(UDFBase):
  )
  result_objs = list(self.process_safe(udf_args))
  n_objs = len(result_objs)
- assert (
- n_objs == n_rows
- ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+ assert n_objs == n_rows, (
+ f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+ )
  udf_outputs = (self._flatten_row(row) for row in result_objs)
  output = [
  {"sys__id": row_id} | dict(zip(self.signal_names, signals))
@@ -429,15 +458,22 @@ class Generator(UDFBase):
  row, udf_fields, catalog, cache, download_cb
  )

+ def _process_row(row):
+ with safe_closing(self.process_safe(row)) as result_objs:
+ for result_obj in result_objs:
+ udf_output = self._flatten_row(result_obj)
+ yield dict(zip(self.signal_names, udf_output))
+
  prepared_inputs = _prepare_rows(udf_inputs)
- prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+ prepared_inputs = _prefetch_inputs(
+ prepared_inputs,
+ self.prefetch,
+ download_cb=download_cb,
+ remove_prefetched=bool(self.prefetch) and not cache,
+ )
  with closing(prepared_inputs):
- for row in prepared_inputs:
- result_objs = self.process_safe(row)
- udf_outputs = (self._flatten_row(row) for row in result_objs)
- output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
- processed_cb.relative_update(1)
- yield output
+ for row in processed_cb.wrap(prepared_inputs):
+ yield _process_row(row)

  self.teardown()

datachain/listing.py CHANGED
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional

  from sqlalchemy import Column
  from sqlalchemy.sql import func
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  from datachain.node import DirType, Node, NodeWithPath
  from datachain.sql.functions import path as pathfunc
datachain/model/bbox.py CHANGED
@@ -22,9 +22,9 @@ class BBox(DataModel):
  @staticmethod
  def from_list(coords: list[float], title: str = "") -> "BBox":
  assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
- assert all(
- isinstance(value, (int, float)) for value in coords
- ), "Bounding box coordinates must be floats or integers."
+ assert all(isinstance(value, (int, float)) for value in coords), (
+ "Bounding box coordinates must be floats or integers."
+ )
  return BBox(
  title=title,
  coords=[round(c) for c in coords],
@@ -64,12 +64,12 @@ class OBBox(DataModel):

  @staticmethod
  def from_list(coords: list[float], title: str = "") -> "OBBox":
- assert (
- len(coords) == 8
- ), "Oriented bounding box must be a list of 8 coordinates."
- assert all(
- isinstance(value, (int, float)) for value in coords
- ), "Oriented bounding box coordinates must be floats or integers."
+ assert len(coords) == 8, (
+ "Oriented bounding box must be a list of 8 coordinates."
+ )
+ assert all(isinstance(value, (int, float)) for value in coords), (
+ "Oriented bounding box coordinates must be floats or integers."
+ )
  return OBBox(
  title=title,
  coords=[round(c) for c in coords],