datachain 0.8.8__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.


@@ -1,10 +1,8 @@
  def add_studio_parser(subparsers, parent_parser) -> None:
- studio_help = "Commands to authenticate DataChain with Iterative Studio"
+ studio_help = "Manage Studio authentication"
  studio_description = (
- "Authenticate DataChain with Studio and set the token. "
- "Once this token has been properly configured,\n"
- "DataChain will utilize it for seamlessly sharing datasets\n"
- "and using Studio features from CLI"
+ "Manage authentication and settings for Studio. "
+ "Configure tokens for sharing datasets and using Studio features."
  )

  studio_parser = subparsers.add_parser(
@@ -15,14 +13,13 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  )
  studio_subparser = studio_parser.add_subparsers(
  dest="cmd",
- help="Use `DataChain studio CMD --help` to display command-specific help.",
- required=True,
+ help="Use `datachain studio CMD --help` to display command-specific help",
  )

- studio_login_help = "Authenticate DataChain with Studio host"
+ studio_login_help = "Authenticate with Studio"
  studio_login_description = (
- "By default, this command authenticates the DataChain with Studio\n"
- "using default scopes and assigns a random name as the token name."
+ "Authenticate with Studio using default scopes. "
+ "A random name will be assigned as the token name if not specified."
  )
  login_parser = studio_subparser.add_parser(
  "login",
@@ -36,14 +33,14 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  "--hostname",
  action="store",
  default=None,
- help="The hostname of the Studio instance to authenticate with.",
+ help="Hostname of the Studio instance",
  )
  login_parser.add_argument(
  "-s",
  "--scopes",
  action="store",
  default=None,
- help="The scopes for the authentication token. ",
+ help="Authentication token scopes",
  )

  login_parser.add_argument(
@@ -51,21 +48,20 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  "--name",
  action="store",
  default=None,
- help="The name of the authentication token. It will be used to\n"
- "identify token shown in Studio profile.",
+ help="Authentication token name (shown in Studio profile)",
  )

  login_parser.add_argument(
  "--no-open",
  action="store_true",
  default=False,
- help="Use authentication flow based on user code.\n"
- "You will be presented with user code to enter in browser.\n"
- "DataChain will also use this if it cannot launch browser on your behalf.",
+ help="Use code-based authentication without browser",
  )

- studio_logout_help = "Logout user from Studio"
- studio_logout_description = "This removes the studio token from your global config."
+ studio_logout_help = "Log out from Studio"
+ studio_logout_description = (
+ "Remove the Studio authentication token from global config."
+ )

  studio_subparser.add_parser(
  "logout",
@@ -74,10 +70,8 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  help=studio_logout_help,
  )

- studio_team_help = "Set the default team for DataChain"
- studio_team_description = (
- "Set the default team for DataChain to use when interacting with Studio."
- )
+ studio_team_help = "Set default team for Studio operations"
+ studio_team_description = "Set the default team for Studio operations."

  team_parser = studio_subparser.add_parser(
  "team",
@@ -88,39 +82,21 @@ def add_studio_parser(subparsers, parent_parser) -> None:
  team_parser.add_argument(
  "team_name",
  action="store",
- help="The name of the team to set as the default.",
+ help="Name of the team to set as default",
  )
  team_parser.add_argument(
  "--global",
  action="store_true",
  default=False,
- help="Set the team globally for all DataChain projects.",
+ help="Set team globally for all projects",
  )

- studio_token_help = "View the token datachain uses to contact Studio" # noqa: S105 # nosec B105
+ studio_token_help = "View Studio authentication token" # noqa: S105
+ studio_token_description = "Display the current authentication token for Studio." # noqa: S105

  studio_subparser.add_parser(
  "token",
  parents=[parent_parser],
- description=studio_token_help,
+ description=studio_token_description,
  help=studio_token_help,
  )
-
- studio_ls_dataset_help = "List the available datasets from Studio"
- studio_ls_dataset_description = (
- "This command lists all the datasets available in Studio.\n"
- "It will show the dataset name and the number of versions available."
- )
-
- ls_dataset_parser = studio_subparser.add_parser(
- "dataset",
- parents=[parent_parser],
- description=studio_ls_dataset_description,
- help=studio_ls_dataset_help,
- )
- ls_dataset_parser.add_argument(
- "--team",
- action="store",
- default=None,
- help="The team to list datasets for. By default, it will use team from config.",
- )
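The hunks above only reword help and description strings and drop the `dataset` subcommand from this parser. As a generic illustration of the argparse pattern they use (subparsers with a short `help` and a longer `description`), here is a standalone sketch; it is not datachain's actual CLI wiring, and `build_parser` plus the reduced flag set are assumptions.

import argparse

def build_parser() -> argparse.ArgumentParser:
    # Top-level parser with a "studio" command group, mirroring the pattern above.
    parser = argparse.ArgumentParser(prog="datachain")
    subparsers = parser.add_subparsers(dest="command")

    studio = subparsers.add_parser(
        "studio",
        help="Manage Studio authentication",
        description="Manage authentication and settings for Studio.",
    )
    studio_sub = studio.add_subparsers(
        dest="cmd",
        help="Use `datachain studio CMD --help` to display command-specific help",
    )

    login = studio_sub.add_parser("login", help="Authenticate with Studio")
    login.add_argument("--hostname", default=None, help="Hostname of the Studio instance")
    login.add_argument(
        "--no-open",
        action="store_true",
        help="Use code-based authentication without browser",
    )

    studio_sub.add_parser("logout", help="Log out from Studio")
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args(["studio", "login", "--no-open"])
    print(args.command, args.cmd, args.no_open)  # studio login True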
@@ -30,7 +30,7 @@ def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Act
  "sources",
  type=str,
  nargs=nargs,
- help="Data sources - paths to cloud storage dirs",
+ help="Data sources - paths to cloud storage directories",
  )

datachain/client/local.py CHANGED
@@ -38,7 +38,7 @@ class FileClient(Client):
  def get_uri(cls, name: str) -> "StorageURI":
  from datachain.dataset import StorageURI

- return StorageURI(f'{cls.PREFIX}/{name.removeprefix("/")}')
+ return StorageURI(f"{cls.PREFIX}/{name.removeprefix('/')}")

  @classmethod
  def ls_buckets(cls, **kwargs):
datachain/lib/arrow.py CHANGED
@@ -33,7 +33,7 @@ class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
  # reads the whole file in-memory.
  (uri,) = self.references[path]
  protocol, _ = split_protocol(uri)
- return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+ return self.fss[protocol].open(uri, mode, *args, **kwargs)


  class ArrowGenerator(Generator):
@@ -35,8 +35,7 @@ def unflatten_to_json_pos(
  def _normalize(name: str) -> str:
  if DEFAULT_DELIMITER in name:
  raise RuntimeError(
- f"variable '{name}' cannot be used "
- f"because it contains {DEFAULT_DELIMITER}"
+ f"variable '{name}' cannot be used because it contains {DEFAULT_DELIMITER}"
  )
  return _to_snake_case(name)
datachain/lib/dc.py CHANGED
@@ -11,6 +11,7 @@ from typing import (
  BinaryIO,
  Callable,
  ClassVar,
+ Literal,
  Optional,
  TypeVar,
  Union,
@@ -1276,7 +1277,12 @@ class DataChain:
  yield ret[0] if len(cols) == 1 else tuple(ret)

  def to_pytorch(
- self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
+ self,
+ transform=None,
+ tokenizer=None,
+ tokenizer_kwargs=None,
+ num_samples=0,
+ remove_prefetched: bool = False,
  ):
  """Convert to pytorch dataset format.

@@ -1286,6 +1292,7 @@ class DataChain:
  tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
  num_samples (int): Number of random samples to draw for each epoch.
  This argument is ignored if `num_samples=0` (the default).
+ remove_prefetched (bool): Whether to remove prefetched files after reading.

  Example:
  ```py
@@ -1312,6 +1319,7 @@ class DataChain:
  tokenizer_kwargs=tokenizer_kwargs,
  num_samples=num_samples,
  dc_settings=chain._settings,
+ remove_prefetched=remove_prefetched,
  )

  def remove_file_signals(self) -> "Self": # noqa: D102
@@ -2415,11 +2423,22 @@ class DataChain:
  def export_files(
  self,
  output: str,
- signal="file",
+ signal: str = "file",
  placement: FileExportPlacement = "fullpath",
  use_cache: bool = True,
+ link_type: Literal["copy", "symlink"] = "copy",
  ) -> None:
- """Method that exports all files from chain to some folder."""
+ """Export files from a specified signal to a directory.
+
+ Args:
+ output: Path to the target directory for exporting files.
+ signal: Name of the signal to export files from.
+ placement: The method to use for naming exported files.
+ The possible values are: "filename", "etag", "fullpath", and "checksum".
+ use_cache: If `True`, cache the files before exporting.
+ link_type: Method to use for exporting files.
+ Falls back to `'copy'` if symlinking fails.
+ """
  if placement == "filename" and (
  self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
  != self._query.count()
@@ -2427,7 +2446,7 @@ class DataChain:
  raise ValueError("Files with the same name found")

  for file in self.collect(signal):
- file.export(output, placement, use_cache) # type: ignore[union-attr]
+ file.export(output, placement, use_cache, link_type=link_type) # type: ignore[union-attr]

  def shuffle(self) -> "Self":
  """Shuffle the rows of the chain deterministically."""
datachain/lib/file.py CHANGED
@@ -1,3 +1,4 @@
+ import errno
  import hashlib
  import io
  import json
@@ -76,18 +77,18 @@ class TarVFile(VFile):
  def open(cls, file: "File", location: list[dict]):
  """Stream file from tar archive based on location in archive."""
  if len(location) > 1:
- VFileError(file, "multiple 'location's are not supported yet")
+ raise VFileError(file, "multiple 'location's are not supported yet")

  loc = location[0]

  if (offset := loc.get("offset", None)) is None:
- VFileError(file, "'offset' is not specified")
+ raise VFileError(file, "'offset' is not specified")

  if (size := loc.get("size", None)) is None:
- VFileError(file, "'size' is not specified")
+ raise VFileError(file, "'size' is not specified")

  if (parent := loc.get("parent", None)) is None:
- VFileError(file, "'parent' is not specified")
+ raise VFileError(file, "'parent' is not specified")

  tar_file = File(**parent)
  tar_file._set_stream(file._catalog)
@@ -236,11 +237,26 @@ class File(DataModel):
  with open(destination, mode="wb") as f:
  f.write(self.read())

+ def _symlink_to(self, destination: str):
+ if self.location:
+ raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")
+
+ if self._caching_enabled:
+ self.ensure_cached()
+ source = self.get_local_path()
+ assert source, "File was not cached"
+ elif self.source.startswith("file://"):
+ source = self.get_path()
+ else:
+ raise OSError(errno.EXDEV, "can't link across filesystems")
+ return os.symlink(source, destination)
+
  def export(
  self,
  output: str,
  placement: ExportPlacement = "fullpath",
  use_cache: bool = True,
+ link_type: Literal["copy", "symlink"] = "copy",
  ) -> None:
  """Export file to new location."""
  if use_cache:
@@ -249,6 +265,13 @@ class File(DataModel):
  dst_dir = os.path.dirname(dst)
  os.makedirs(dst_dir, exist_ok=True)

+ if link_type == "symlink":
+ try:
+ return self._symlink_to(dst)
+ except OSError as exc:
+ if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
+ raise
+
  self.save(dst)

  def _set_stream(
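The new `_symlink_to` helper and the `export` change implement a symlink-first strategy that falls back to copying when the OS reports the link cannot be made. A standalone sketch of that errno-based pattern (illustrative only; `place_file` is a made-up helper, not datachain code):

import errno
import os
import shutil

def place_file(source: str, destination: str, link_type: str = "copy") -> None:
    # Try to symlink first; fall back to copying when the OS or filesystem
    # cannot satisfy the request (unsupported op, cross-device link, etc.).
    if link_type == "symlink":
        try:
            os.symlink(source, destination)
            return
        except OSError as exc:
            if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
                raise  # unrelated failure: surface it
    shutil.copy2(source, destination)  # default path, or symlink fallback

Only `ENOTSUP`, `EXDEV`, and `ENOSYS` trigger the fallback; any other `OSError` is re-raised so genuine failures are not silently converted into copies.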
datachain/lib/listing.py CHANGED
@@ -113,14 +113,14 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], st
  telemetry.log_param("client", client.PREFIX)

  if not uri.endswith("/") and _isfile(client, uri):
- return None, f'{storage_uri}/{path.lstrip("/")}', path
+ return None, f"{storage_uri}/{path.lstrip('/')}", path
  if uses_glob(path):
  lst_uri_path = posixpath.dirname(path)
  else:
- storage_uri, path = Client.parse_url(f'{uri.rstrip("/")}/')
+ storage_uri, path = Client.parse_url(f"{uri.rstrip('/')}/")
  lst_uri_path = path

- lst_uri = f'{storage_uri}/{lst_uri_path.lstrip("/")}'
+ lst_uri = f"{storage_uri}/{lst_uri_path.lstrip('/')}"
  ds_name = (
  f"{LISTING_PREFIX}{storage_uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}"
  )
@@ -180,7 +180,7 @@ def get_listing(
  # for local file system we need to fix listing path / prefix
  # if we are reusing existing listing
  if isinstance(client, FileClient) and listing and listing.name != ds_name:
- list_path = f'{ds_name.strip("/").removeprefix(listing.name)}/{list_path}'
+ list_path = f"{ds_name.strip('/').removeprefix(listing.name)}/{list_path}"

  ds_name = listing.name if listing else ds_name

datachain/lib/pytorch.py CHANGED
@@ -50,6 +50,7 @@ class PytorchDataset(IterableDataset):
  tokenizer_kwargs: Optional[dict[str, Any]] = None,
  num_samples: int = 0,
  dc_settings: Optional[Settings] = None,
+ remove_prefetched: bool = False,
  ):
  """
  Pytorch IterableDataset that streams DataChain datasets.
@@ -84,6 +85,7 @@ class PytorchDataset(IterableDataset):

  self._cache = catalog.cache
  self._prefetch_cache: Optional[Cache] = None
+ self._remove_prefetched = remove_prefetched
  if prefetch and not self.cache:
  tmp_dir = catalog.cache.tmp_dir
  assert tmp_dir
@@ -147,7 +149,7 @@ class PytorchDataset(IterableDataset):
  rows,
  self.prefetch,
  download_cb=download_cb,
- after_prefetch=download_cb.increment_file_count,
+ remove_prefetched=self._remove_prefetched,
  )

  with download_cb, closing(rows):
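Seen from the user side, the new constructor argument is exposed through `DataChain.to_pytorch` (see the dc.py hunk above). A minimal sketch, assuming a placeholder storage URI and a standard `torch` `DataLoader`:

from datachain import DataChain
from torch.utils.data import DataLoader

# With remove_prefetched=True, prefetched files are deleted from the temporary
# cache after each sample is read, keeping disk usage bounded during training.
chain = DataChain.from_storage("s3://example-bucket/images/")
dataset = chain.to_pytorch(remove_prefetched=True)

loader = DataLoader(dataset, batch_size=16)
for batch in loader:
    ...  # training step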
datachain/lib/udf.py CHANGED
@@ -16,6 +16,7 @@ from datachain.lib.convert.flatten import flatten
  from datachain.lib.data_model import DataValue
  from datachain.lib.file import File
  from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+ from datachain.progress import CombinedDownloadCallback
  from datachain.query.batch import (
  Batch,
  BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
  return row


+ def _remove_prefetched(row: T) -> None:
+ for obj in row:
+ if isinstance(obj, File):
+ catalog = obj._catalog
+ assert catalog is not None
+ try:
+ catalog.cache.remove(obj)
+ except Exception as e: # noqa: BLE001
+ print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")
+
+
  def _prefetch_inputs(
  prepared_inputs: "Iterable[T]",
  prefetch: int = 0,
  download_cb: Optional["Callback"] = None,
- after_prefetch: "Callable[[], None]" = noop,
+ after_prefetch: Optional[Callable[[], None]] = None,
+ remove_prefetched: bool = False,
  ) -> "abc.Generator[T, None, None]":
- if prefetch > 0:
- f = partial(
- _prefetch_input,
- download_cb=download_cb,
- after_prefetch=after_prefetch,
- )
- prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate() # type: ignore[assignment]
- yield from prepared_inputs
+ if not prefetch:
+ yield from prepared_inputs
+ return
+
+ if after_prefetch is None:
+ after_prefetch = noop
+ if isinstance(download_cb, CombinedDownloadCallback):
+ after_prefetch = download_cb.increment_file_count
+
+ f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
+ mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
+ with closing(mapper.iterate()) as row_iter:
+ for row in row_iter:
+ try:
+ yield row # type: ignore[misc]
+ finally:
+ if remove_prefetched:
+ _remove_prefetched(row)


  def _get_cache(
@@ -351,7 +374,13 @@ class Mapper(UDFBase):
  )

  prepared_inputs = _prepare_rows(udf_inputs)
- prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+ prepared_inputs = _prefetch_inputs(
+ prepared_inputs,
+ self.prefetch,
+ download_cb=download_cb,
+ remove_prefetched=bool(self.prefetch) and not cache,
+ )
+
  with closing(prepared_inputs):
  for id_, *udf_args in prepared_inputs:
  result_objs = self.process_safe(udf_args)
@@ -391,9 +420,9 @@ class BatchMapper(UDFBase):
  )
  result_objs = list(self.process_safe(udf_args))
  n_objs = len(result_objs)
- assert (
- n_objs == n_rows
- ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+ assert n_objs == n_rows, (
+ f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+ )
  udf_outputs = (self._flatten_row(row) for row in result_objs)
  output = [
  {"sys__id": row_id} | dict(zip(self.signal_names, signals))
@@ -429,15 +458,22 @@ class Generator(UDFBase):
  row, udf_fields, catalog, cache, download_cb
  )

+ def _process_row(row):
+ with safe_closing(self.process_safe(row)) as result_objs:
+ for result_obj in result_objs:
+ udf_output = self._flatten_row(result_obj)
+ yield dict(zip(self.signal_names, udf_output))
+
  prepared_inputs = _prepare_rows(udf_inputs)
- prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+ prepared_inputs = _prefetch_inputs(
+ prepared_inputs,
+ self.prefetch,
+ download_cb=download_cb,
+ remove_prefetched=bool(self.prefetch) and not cache,
+ )
  with closing(prepared_inputs):
- for row in prepared_inputs:
- result_objs = self.process_safe(row)
- udf_outputs = (self._flatten_row(row) for row in result_objs)
- output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
- processed_cb.relative_update(1)
- yield output
+ for row in processed_cb.wrap(prepared_inputs):
+ yield _process_row(row)

  self.teardown()
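The reworked `_prefetch_inputs` wraps each yielded row in `try`/`finally` so prefetched files are removed only after the consumer is done with them, and closes the mapper's iterator even if iteration stops early. A standalone sketch of that generator pattern (illustrative only; `fetch` and `discard` are hypothetical callables, not datachain APIs):

from contextlib import closing
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")

def stream_with_cleanup(
    items: Iterable[T],
    fetch: Callable[[T], T],
    discard: Callable[[T], None],
) -> Iterator[T]:
    # Yield each fetched item, then discard it once the consumer moves on,
    # mirroring the prefetch/remove flow above.
    it = (fetch(item) for item in items)
    with closing(it):
        for item in it:
            try:
                yield item
            finally:
                discard(item)

Because the cleanup runs in `finally`, it also fires when the consuming loop breaks or the generator is closed early.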
 
datachain/model/bbox.py CHANGED
@@ -22,9 +22,9 @@ class BBox(DataModel):
  @staticmethod
  def from_list(coords: list[float], title: str = "") -> "BBox":
  assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
- assert all(
- isinstance(value, (int, float)) for value in coords
- ), "Bounding box coordinates must be floats or integers."
+ assert all(isinstance(value, (int, float)) for value in coords), (
+ "Bounding box coordinates must be floats or integers."
+ )
  return BBox(
  title=title,
  coords=[round(c) for c in coords],
@@ -64,12 +64,12 @@ class OBBox(DataModel):

  @staticmethod
  def from_list(coords: list[float], title: str = "") -> "OBBox":
- assert (
- len(coords) == 8
- ), "Oriented bounding box must be a list of 8 coordinates."
- assert all(
- isinstance(value, (int, float)) for value in coords
- ), "Oriented bounding box coordinates must be floats or integers."
+ assert len(coords) == 8, (
+ "Oriented bounding box must be a list of 8 coordinates."
+ )
+ assert all(isinstance(value, (int, float)) for value in coords), (
+ "Oriented bounding box coordinates must be floats or integers."
+ )
  return OBBox(
  title=title,
  coords=[round(c) for c in coords],
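These model hunks only reformat the assert statements; `from_list` behaves as before. For reference, a small usage sketch of the validation they enforce (the coordinate values are arbitrary):

from datachain.model.bbox import BBox, OBBox

# Axis-aligned box: exactly 4 numeric coords, rounded to ints on construction.
box = BBox.from_list([10.2, 20.8, 110.0, 220.5], title="cat")

# Oriented box: exactly 8 numeric coords (4 corner points).
obox = OBBox.from_list([10, 20, 110, 20, 110, 220, 10, 220], title="cat")

# Anything else trips the asserts shown above, e.g.:
# BBox.from_list([10, 20, 110])  # AssertionError: must be a list of 4 coordinates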
datachain/model/pose.py CHANGED
@@ -22,9 +22,9 @@ class Pose(DataModel):
  def from_list(points: list[list[float]]) -> "Pose":
  assert len(points) == 2, "Pose must be a list of 2 lists: x and y coordinates."
  points_x, points_y = points
- assert (
- len(points_x) == len(points_y) == 17
- ), "Pose x and y coordinates must have the same length of 17."
+ assert len(points_x) == len(points_y) == 17, (
+ "Pose x and y coordinates must have the same length of 17."
+ )
  assert all(
  isinstance(value, (int, float)) for value in [*points_x, *points_y]
  ), "Pose coordinates must be floats or integers."
@@ -61,13 +61,13 @@ class Pose3D(DataModel):

  @staticmethod
  def from_list(points: list[list[float]]) -> "Pose3D":
- assert (
- len(points) == 3
- ), "Pose3D must be a list of 3 lists: x, y coordinates and visible."
+ assert len(points) == 3, (
+ "Pose3D must be a list of 3 lists: x, y coordinates and visible."
+ )
  points_x, points_y, points_v = points
- assert (
- len(points_x) == len(points_y) == len(points_v) == 17
- ), "Pose3D x, y coordinates and visible must have the same length of 17."
+ assert len(points_x) == len(points_y) == len(points_v) == 17, (
+ "Pose3D x, y coordinates and visible must have the same length of 17."
+ )
  assert all(
  isinstance(value, (int, float))
  for value in [*points_x, *points_y, *points_v]
@@ -22,13 +22,13 @@ class Segment(DataModel):

  @staticmethod
  def from_list(points: list[list[float]], title: str = "") -> "Segment":
- assert (
- len(points) == 2
- ), "Segment must be a list of 2 lists: x and y coordinates."
+ assert len(points) == 2, (
+ "Segment must be a list of 2 lists: x and y coordinates."
+ )
  points_x, points_y = points
- assert len(points_x) == len(
- points_y
- ), "Segment x and y coordinates must have the same length."
+ assert len(points_x) == len(points_y), (
+ "Segment x and y coordinates must have the same length."
+ )
  assert all(
  isinstance(value, (int, float)) for value in [*points_x, *points_y]
  ), "Segment coordinates must be floats or integers."
datachain/progress.py CHANGED
@@ -1,14 +1,5 @@
- """Manages progress bars."""
-
- import logging
- from threading import RLock
-
  from fsspec import Callback
  from fsspec.callbacks import TqdmCallback
- from tqdm.auto import tqdm
-
- logger = logging.getLogger(__name__)
- tqdm.set_lock(RLock())


  class CombinedDownloadCallback(Callback):
@@ -24,10 +15,6 @@ class CombinedDownloadCallback(Callback):
  class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
  def __init__(self, tqdm_kwargs=None, *args, **kwargs):
  self.files_count = 0
- tqdm_kwargs = tqdm_kwargs or {}
- tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
- kwargs = kwargs or {}
- kwargs["tqdm_cls"] = tqdm
  super().__init__(tqdm_kwargs, *args, **kwargs)

  def increment_file_count(self, n: int = 1) -> None:
@@ -336,15 +336,16 @@ def process_udf_outputs(
  for udf_output in udf_results:
  if not udf_output:
  continue
- for row in udf_output:
- cb.relative_update()
- rows.append(adjust_outputs(warehouse, row, udf_col_types))
- if len(rows) >= batch_size or (
- len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
- ):
- for row_chunk in batched(rows, batch_size):
- warehouse.insert_rows(udf_table, row_chunk)
- rows.clear()
+ with safe_closing(udf_output):
+ for row in udf_output:
+ cb.relative_update()
+ rows.append(adjust_outputs(warehouse, row, udf_col_types))
+ if len(rows) >= batch_size or (
+ len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
+ ):
+ for row_chunk in batched(rows, batch_size):
+ warehouse.insert_rows(udf_table, row_chunk)
+ rows.clear()

  if rows:
  for row_chunk in batched(rows, batch_size):
@@ -355,7 +356,7 @@ def process_udf_outputs(

  def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallback:
  return TqdmCombinedDownloadCallback(
- {
+ tqdm_kwargs={
  "desc": "Download" + suffix,
  "unit": "B",
  "unit_scale": True,
@@ -363,6 +364,7 @@ def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallbac
  "leave": False,
  **kwargs,
  },
+ tqdm_cls=tqdm,
  )

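The last two hunks pass `tqdm_kwargs` and `tqdm_cls` explicitly at the call site instead of patching them inside `__init__`. fsspec's `TqdmCallback` accepts both, as in this standalone sketch (the URL and local filename are placeholders):

import fsspec
from fsspec.callbacks import TqdmCallback
from tqdm.auto import tqdm

# Progress bar for a download, configured the same way get_download_callback
# now does it: tqdm options via tqdm_kwargs, tqdm implementation via tqdm_cls.
callback = TqdmCallback(
    tqdm_kwargs={"desc": "Download", "unit": "B", "unit_scale": True, "leave": False},
    tqdm_cls=tqdm,
)

fs = fsspec.filesystem("http")
fs.get("https://example.com/data.bin", "data.bin", callback=callback)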