datachain 0.28.0__py3-none-any.whl → 0.28.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

@@ -324,6 +324,7 @@ class DataChain:
324
324
  sys: Optional[bool] = None,
325
325
  namespace: Optional[str] = None,
326
326
  project: Optional[str] = None,
327
+ batch_rows: Optional[int] = None,
327
328
  ) -> "Self":
328
329
  """Change settings for chain.
329
330
 
@@ -331,22 +332,24 @@ class DataChain:
331
332
  It returns chain, so, it can be chained later with next operation.
332
333
 
333
334
  Parameters:
334
- cache : data caching (default=False)
335
+ cache : data caching. (default=False)
335
336
  parallel : number of thread for processors. True is a special value to
336
- enable all available CPUs (default=1)
337
+ enable all available CPUs. (default=1)
337
338
  workers : number of distributed workers. Only for Studio mode. (default=1)
338
- min_task_size : minimum number of tasks (default=1)
339
- prefetch: number of workers to use for downloading files in advance.
339
+ min_task_size : minimum number of tasks. (default=1)
340
+ prefetch : number of workers to use for downloading files in advance.
340
341
  This is enabled by default and uses 2 workers.
341
342
  To disable prefetching, set it to 0.
342
- namespace: namespace name.
343
- project: project name.
343
+ namespace : namespace name.
344
+ project : project name.
345
+ batch_rows : row limit per insert to balance speed and memory usage.
346
+ (default=2000)
344
347
 
345
348
  Example:
346
349
  ```py
347
350
  chain = (
348
351
  chain
349
- .settings(cache=True, parallel=8)
352
+ .settings(cache=True, parallel=8, batch_rows=300)
350
353
  .map(laion=process_webdataset(spec=WDSLaion), params="file")
351
354
  )
352
355
  ```
@@ -356,7 +359,14 @@ class DataChain:
356
359
  settings = copy.copy(self._settings)
357
360
  settings.add(
358
361
  Settings(
359
- cache, parallel, workers, min_task_size, prefetch, namespace, project
362
+ cache,
363
+ parallel,
364
+ workers,
365
+ min_task_size,
366
+ prefetch,
367
+ namespace,
368
+ project,
369
+ batch_rows,
360
370
  )
361
371
  )
362
372
  return self._evolve(settings=settings, _sys=sys)
@@ -711,7 +721,7 @@ class DataChain:
711
721
 
712
722
  return self._evolve(
713
723
  query=self._query.add_signals(
714
- udf_obj.to_udf_wrapper(),
724
+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
715
725
  **self._settings.to_dict(),
716
726
  ),
717
727
  signal_schema=self.signals_schema | udf_obj.output,
@@ -749,7 +759,7 @@ class DataChain:
749
759
  udf_obj.prefetch = prefetch
750
760
  return self._evolve(
751
761
  query=self._query.generate(
752
- udf_obj.to_udf_wrapper(),
762
+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
753
763
  **self._settings.to_dict(),
754
764
  ),
755
765
  signal_schema=udf_obj.output,
@@ -885,7 +895,7 @@ class DataChain:
885
895
  udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
886
896
  return self._evolve(
887
897
  query=self._query.generate(
888
- udf_obj.to_udf_wrapper(),
898
+ udf_obj.to_udf_wrapper(self._settings.batch_rows),
889
899
  partition_by=processed_partition_by,
890
900
  **self._settings.to_dict(),
891
901
  ),
@@ -917,11 +927,24 @@ class DataChain:
917
927
  )
918
928
  chain.save("new_dataset")
919
929
  ```
930
+
931
+ .. deprecated:: 0.29.0
932
+ This method is deprecated and will be removed in a future version.
933
+ Use `agg()` instead, which provides the similar functionality.
920
934
  """
935
+ import warnings
936
+
937
+ warnings.warn(
938
+ "batch_map() is deprecated and will be removed in a future version. "
939
+ "Use agg() instead, which provides the similar functionality.",
940
+ DeprecationWarning,
941
+ stacklevel=2,
942
+ )
921
943
  udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
944
+
922
945
  return self._evolve(
923
946
  query=self._query.add_signals(
924
- udf_obj.to_udf_wrapper(batch),
947
+ udf_obj.to_udf_wrapper(self._settings.batch_rows, batch=batch),
925
948
  **self._settings.to_dict(),
926
949
  ),
927
950
  signal_schema=self.signals_schema | udf_obj.output,
@@ -2340,7 +2363,7 @@ class DataChain:
2340
2363
  def setup(self, **kwargs) -> "Self":
2341
2364
  """Setup variables to pass to UDF functions.
2342
2365
 
2343
- Use before running map/gen/agg/batch_map to save an object and pass it as an
2366
+ Use before running map/gen/agg to save an object and pass it as an
2344
2367
  argument to the UDF.
2345
2368
 
2346
2369
  The value must be a callable (a `lambda: <value>` syntax can be used to quickly
@@ -2419,9 +2442,11 @@ class DataChain:
2419
2442
  ds.to_storage("gs://mybucket", placement="filename")
2420
2443
  ```
2421
2444
  """
2445
+ chain = self.persist()
2446
+ count = chain.count()
2447
+
2422
2448
  if placement == "filename" and (
2423
- self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
2424
- != self._query.count()
2449
+ chain._query.distinct(pathfunc.name(C(f"{signal}__path"))).count() != count
2425
2450
  ):
2426
2451
  raise ValueError("Files with the same name found")
2427
2452
 
@@ -2433,7 +2458,7 @@ class DataChain:
2433
2458
  unit=" files",
2434
2459
  unit_scale=True,
2435
2460
  unit_divisor=10,
2436
- total=self.count(),
2461
+ total=count,
2437
2462
  leave=False,
2438
2463
  )
2439
2464
  file_exporter = FileExporter(
@@ -2444,7 +2469,10 @@ class DataChain:
2444
2469
  max_threads=num_threads or 1,
2445
2470
  client_config=client_config,
2446
2471
  )
2447
- file_exporter.run(self.to_values(signal), progress_bar)
2472
+ file_exporter.run(
2473
+ (rows[0] for rows in chain.to_iter(signal)),
2474
+ progress_bar,
2475
+ )
2448
2476
 
2449
2477
  def shuffle(self) -> "Self":
2450
2478
  """Shuffle the rows of the chain deterministically."""
@@ -15,6 +15,8 @@ if TYPE_CHECKING:
15
15
 
16
16
  P = ParamSpec("P")
17
17
 
18
+ READ_RECORDS_BATCH_SIZE = 10000
19
+
18
20
 
19
21
  def read_records(
20
22
  to_insert: Optional[Union[dict, Iterable[dict]]],
@@ -41,7 +43,7 @@ def read_records(
41
43
  Notes:
42
44
  This call blocks until all records are inserted.
43
45
  """
44
- from datachain.query.dataset import INSERT_BATCH_SIZE, adjust_outputs, get_col_types
46
+ from datachain.query.dataset import adjust_outputs, get_col_types
45
47
  from datachain.sql.types import SQLType
46
48
  from datachain.utils import batched
47
49
 
@@ -94,7 +96,7 @@ def read_records(
94
96
  {c.name: c.type for c in columns if isinstance(c.type, SQLType)},
95
97
  )
96
98
  records = (adjust_outputs(warehouse, record, col_types) for record in to_insert)
97
- for chunk in batched(records, INSERT_BATCH_SIZE):
99
+ for chunk in batched(records, READ_RECORDS_BATCH_SIZE):
98
100
  warehouse.insert_rows(table, chunk)
99
101
  warehouse.insert_rows_done(table)
100
102
  return read_dataset(name=dsr.full_name, session=session, settings=settings)
datachain/lib/file.py CHANGED
@@ -23,7 +23,7 @@ from pydantic import Field, field_validator
23
23
 
24
24
  from datachain.client.fileslice import FileSlice
25
25
  from datachain.lib.data_model import DataModel
26
- from datachain.lib.utils import DataChainError
26
+ from datachain.lib.utils import DataChainError, rebase_path
27
27
  from datachain.nodes_thread_pool import NodesThreadPool
28
28
  from datachain.sql.types import JSON, Boolean, DateTime, Int, String
29
29
  from datachain.utils import TIME_ZERO
@@ -634,6 +634,40 @@ class File(DataModel):
634
634
  location=self.location,
635
635
  )
636
636
 
637
def rebase(
    self,
    old_base: str,
    new_base: str,
    suffix: str = "",
    extension: str = "",
) -> str:
    """
    Rebase the file's URI from one base directory to another.

    The file's full URI (``self.get_uri()``) is handed to ``rebase_path``
    together with the arguments below; nothing is read from or written to
    storage.

    Args:
        old_base: Base directory to remove from the file's URI
        new_base: New base directory to prepend
        suffix: Optional suffix to add before file extension
        extension: Optional new file extension (without dot)

    Returns:
        str: Rebased URI with new base directory

    Raises:
        ValueError: If old_base is not found in the file's URI

    Examples:
        >>> file = File(source="s3://bucket", path="data/2025-05-27/file.wav")
        >>> file.rebase("s3://bucket/data", "s3://output-bucket/processed",
        ...             extension="mp3")
        's3://output-bucket/processed/2025-05-27/file.mp3'

        >>> file = File(source="s3://bucket", path="data/audio/file.wav")
        >>> file.rebase("data/audio", "/local/output", suffix="_ch1",
        ...             extension="npy")
        '/local/output/file_ch1.npy'
    """
    return rebase_path(self.get_uri(), old_base, new_base, suffix, extension)
670
+
637
671
 
638
672
  def resolve(file: File) -> File:
639
673
  """
@@ -1219,6 +1253,24 @@ class Audio(DataModel):
1219
1253
  codec: str = Field(default="")
1220
1254
  bit_rate: int = Field(default=-1)
1221
1255
 
1256
@staticmethod
def get_channel_name(num_channels: int, channel_idx: int) -> str:
    """Map a channel index to a meaningful name based on common audio formats.

    Args:
        num_channels: Total number of channels in the audio stream.
        channel_idx: Zero-based index of the channel to name.

    Returns:
        The conventional name for the channel when ``num_channels`` matches a
        known layout and ``channel_idx`` falls inside it, otherwise the
        generic one-based label ``"Ch<N>"``.

    Raises:
        ValueError: If ``channel_idx`` is negative.
    """
    # A negative index used to fall through to the generic label and produce
    # names such as "Ch0" or "Ch-1"; reject it instead of inventing a name.
    if channel_idx < 0:
        raise ValueError(f"channel_idx must be non-negative, got {channel_idx}")

    channel_mappings = {
        1: ["Mono"],
        2: ["Left", "Right"],
        4: ["W", "X", "Y", "Z"],  # First-order Ambisonics
        6: ["FL", "FR", "FC", "LFE", "BL", "BR"],  # 5.1 surround
        8: ["FL", "FR", "FC", "LFE", "BL", "BR", "SL", "SR"],  # 7.1 surround
    }

    channels = channel_mappings.get(num_channels, [])
    if channel_idx < len(channels):
        return channels[channel_idx]

    # Unknown layout, or index beyond the known layout: generic 1-based label.
    return f"Ch{channel_idx + 1}"
1273
+
1222
1274
 
1223
1275
  class ArrowRow(DataModel):
1224
1276
  """`DataModel` for reading row from Arrow-supported file."""
datachain/lib/settings.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from datachain.lib.utils import DataChainParamsError
2
+ from datachain.utils import DEFAULT_CHUNK_ROWS
2
3
 
3
4
 
4
5
  class SettingsError(DataChainParamsError):
@@ -16,6 +17,7 @@ class Settings:
16
17
  prefetch=None,
17
18
  namespace=None,
18
19
  project=None,
20
+ batch_rows=None,
19
21
  ):
20
22
  self._cache = cache
21
23
  self.parallel = parallel
@@ -24,6 +26,7 @@ class Settings:
24
26
  self.prefetch = prefetch
25
27
  self.namespace = namespace
26
28
  self.project = project
29
+ self._chunk_rows = batch_rows
27
30
 
28
31
  if not isinstance(cache, bool) and cache is not None:
29
32
  raise SettingsError(
@@ -53,6 +56,18 @@ class Settings:
53
56
  f", {min_task_size.__class__.__name__} was given"
54
57
  )
55
58
 
59
+ if batch_rows is not None and not isinstance(batch_rows, int):
60
+ raise SettingsError(
61
+ "'batch_rows' argument must be int or None"
62
+ f", {batch_rows.__class__.__name__} was given"
63
+ )
64
+
65
+ if batch_rows is not None and batch_rows <= 0:
66
+ raise SettingsError(
67
+ "'batch_rows' argument must be positive integer"
68
+ f", {batch_rows} was given"
69
+ )
70
+
56
71
  @property
57
72
  def cache(self):
58
73
  return self._cache if self._cache is not None else False
@@ -61,6 +76,10 @@ class Settings:
61
76
  def workers(self):
62
77
  return self._workers if self._workers is not None else False
63
78
 
79
+ @property
80
+ def batch_rows(self):
81
+ return self._chunk_rows if self._chunk_rows is not None else DEFAULT_CHUNK_ROWS
82
+
64
83
  def to_dict(self):
65
84
  res = {}
66
85
  if self._cache is not None:
@@ -75,6 +94,8 @@ class Settings:
75
94
  res["namespace"] = self.namespace
76
95
  if self.project is not None:
77
96
  res["project"] = self.project
97
+ if self._chunk_rows is not None:
98
+ res["batch_rows"] = self._chunk_rows
78
99
  return res
79
100
 
80
101
  def add(self, settings: "Settings"):
@@ -86,3 +107,5 @@ class Settings:
86
107
  self.project = settings.project or self.project
87
108
  if settings.prefetch is not None:
88
109
  self.prefetch = settings.prefetch
110
+ if settings._chunk_rows is not None:
111
+ self._chunk_rows = settings._chunk_rows
datachain/lib/udf.py CHANGED
@@ -62,19 +62,21 @@ class UDFProperties:
62
62
  return self.udf.get_batching(use_partitioning)
63
63
 
64
64
  @property
65
- def batch(self):
66
- return self.udf.batch
65
+ def batch_rows(self):
66
+ return self.udf.batch_rows
67
67
 
68
68
 
69
69
  @attrs.define(slots=False)
70
70
  class UDFAdapter:
71
71
  inner: "UDFBase"
72
72
  output: UDFOutputSpec
73
+ batch_rows: Optional[int] = None
73
74
  batch: int = 1
74
75
 
75
76
  def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
76
77
  if use_partitioning:
77
78
  return Partition()
79
+
78
80
  if self.batch == 1:
79
81
  return NoBatching()
80
82
  if self.batch > 1:
@@ -233,10 +235,15 @@ class UDFBase(AbstractUDF):
233
235
  def signal_names(self) -> Iterable[str]:
234
236
  return self.output.to_udf_spec().keys()
235
237
 
236
- def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
238
+ def to_udf_wrapper(
239
+ self,
240
+ batch_rows: Optional[int] = None,
241
+ batch: int = 1,
242
+ ) -> UDFAdapter:
237
243
  return UDFAdapter(
238
244
  self,
239
245
  self.output.to_udf_spec(),
246
+ batch_rows,
240
247
  batch,
241
248
  )
242
249
 
@@ -418,11 +425,27 @@ class Mapper(UDFBase):
418
425
 
419
426
 
420
427
  class BatchMapper(UDFBase):
421
- """Inherit from this class to pass to `DataChain.batch_map()`."""
428
+ """Inherit from this class to pass to `DataChain.batch_map()`.
429
+
430
+ .. deprecated:: 0.29.0
431
+ This class is deprecated and will be removed in a future version.
432
+ Use `Aggregator` instead, which provides the similar functionality.
433
+ """
422
434
 
423
435
  is_input_batched = True
424
436
  is_output_batched = True
425
437
 
438
+ def __init__(self):
439
+ import warnings
440
+
441
+ warnings.warn(
442
+ "BatchMapper is deprecated and will be removed in a future version. "
443
+ "Use Aggregator instead, which provides the similar functionality.",
444
+ DeprecationWarning,
445
+ stacklevel=2,
446
+ )
447
+ super().__init__()
448
+
426
449
  def run(
427
450
  self,
428
451
  udf_fields: Sequence[str],
datachain/lib/utils.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import re
2
2
  from abc import ABC, abstractmethod
3
3
  from collections.abc import Sequence
4
+ from pathlib import PurePosixPath
5
+ from urllib.parse import urlparse
4
6
 
5
7
 
6
8
  class AbstractUDF(ABC):
@@ -57,3 +59,97 @@ def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
57
59
  new_col_names[generated_column] = org_column
58
60
 
59
61
  return new_col_names
62
+
63
+
64
def _find_base_index(path: str, base: str) -> int:
    """Return the index where `base` occurs in `path` on component boundaries.

    A match only counts when it starts at the beginning of `path` or right
    after a "/", and ends at the end of `path` or right before a "/" (a
    `base` that itself ends with "/", i.e. the root, satisfies the end
    condition). Returns -1 when there is no such occurrence.
    """
    start = 0
    while True:
        idx = path.find(base, start)
        if idx == -1:
            return -1
        end = idx + len(base)
        starts_component = idx == 0 or path[idx - 1] == "/"
        ends_component = base.endswith("/") or end == len(path) or path[end] == "/"
        if starts_component and ends_component:
            return idx
        start = idx + 1


def rebase_path(
    src_path: str,
    old_base: str,
    new_base: str,
    suffix: str = "",
    extension: str = "",
) -> str:
    """
    Rebase a file path from one base directory to another.

    `old_base` may appear anywhere in `src_path`, but it must match whole
    path components: "data" matches "bucket/data/file.wav" but not
    "bucket/mydata/file.wav". Everything up to and including the first such
    match is replaced by `new_base`.

    Args:
        src_path: Source file path (can include URI scheme like s3://)
        old_base: Base directory to remove from src_path
        new_base: New base directory to prepend
        suffix: Optional suffix to add before file extension
        extension: Optional new file extension (without dot)

    Returns:
        str: Rebased path with new base directory

    Raises:
        ValueError: If old_base is not found in src_path
    """

    def _without_scheme(location: str) -> str:
        # For "scheme://netloc/path" keep "netloc/path"; plain paths are
        # returned untouched. Only netloc and path are kept from a URI.
        parsed = urlparse(location)
        if parsed.scheme:
            return parsed.netloc + parsed.path
        return location

    # Normalize both sides so redundant slashes do not affect matching.
    src_path_norm = PurePosixPath(_without_scheme(src_path)).as_posix()
    old_base_norm = PurePosixPath(_without_scheme(old_base)).as_posix()

    # Match on component boundaries; a plain substring test would accept
    # "data" inside "mydata", and an empty old_base (normalized to ".")
    # would match the dot of the file extension.
    idx = _find_base_index(src_path_norm, old_base_norm)
    if idx == -1:
        raise ValueError(f"old_base '{old_base}' not found in src_path")

    # Extract the relative path after old_base, skipping one separator.
    relative_start = idx + len(old_base_norm)
    if relative_start < len(src_path_norm) and src_path_norm[relative_start] == "/":
        relative_start += 1
    relative_path = src_path_norm[relative_start:]

    # Apply suffix and extension changes to the file name.
    path_obj = PurePosixPath(relative_path)
    new_ext = f".{extension}" if extension else path_obj.suffix
    new_name = path_obj.stem + suffix + new_ext

    # Re-attach the (possibly empty) parent directories.
    parent = str(path_obj.parent)
    if parent == ".":
        new_relative_path = new_name
    else:
        new_relative_path = str(PurePosixPath(parent) / new_name)

    # Handle new_base URI scheme: join on "netloc/path", then put the
    # scheme back in front.
    new_base_parsed = urlparse(new_base)
    if new_base_parsed.scheme:
        base_path = PurePosixPath(new_base_parsed.netloc + new_base_parsed.path)
        full_path = str(base_path / new_relative_path)
        return f"{new_base_parsed.scheme}://{full_path}"

    # Regular path
    return str(PurePosixPath(new_base) / new_relative_path)
@@ -333,32 +333,24 @@ def process_udf_outputs(
333
333
  udf_table: "Table",
334
334
  udf_results: Iterator[Iterable["UDFResult"]],
335
335
  udf: "UDFAdapter",
336
- batch_size: int = INSERT_BATCH_SIZE,
337
336
  cb: Callback = DEFAULT_CALLBACK,
338
337
  ) -> None:
339
- import psutil
340
-
341
- rows: list[UDFResult] = []
342
338
  # Optimization: Compute row types once, rather than for every row.
343
339
  udf_col_types = get_col_types(warehouse, udf.output)
340
+ batch_rows = udf.batch_rows or INSERT_BATCH_SIZE
344
341
 
345
- for udf_output in udf_results:
346
- if not udf_output:
347
- continue
348
- with safe_closing(udf_output):
349
- for row in udf_output:
350
- cb.relative_update()
351
- rows.append(adjust_outputs(warehouse, row, udf_col_types))
352
- if len(rows) >= batch_size or (
353
- len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
354
- ):
355
- for row_chunk in batched(rows, batch_size):
356
- warehouse.insert_rows(udf_table, row_chunk)
357
- rows.clear()
342
+ def _insert_rows():
343
+ for udf_output in udf_results:
344
+ if not udf_output:
345
+ continue
346
+
347
+ with safe_closing(udf_output):
348
+ for row in udf_output:
349
+ cb.relative_update()
350
+ yield adjust_outputs(warehouse, row, udf_col_types)
358
351
 
359
- if rows:
360
- for row_chunk in batched(rows, batch_size):
361
- warehouse.insert_rows(udf_table, row_chunk)
352
+ for row_chunk in batched(_insert_rows(), batch_rows):
353
+ warehouse.insert_rows(udf_table, row_chunk)
362
354
 
363
355
  warehouse.insert_rows_done(udf_table)
364
356
 
@@ -401,6 +393,7 @@ class UDFStep(Step, ABC):
401
393
  min_task_size: Optional[int] = None
402
394
  is_generator = False
403
395
  cache: bool = False
396
+ batch_rows: Optional[int] = None
404
397
 
405
398
  @abstractmethod
406
399
  def create_udf_table(self, query: Select) -> "Table":
@@ -602,6 +595,7 @@ class UDFStep(Step, ABC):
602
595
  parallel=self.parallel,
603
596
  workers=self.workers,
604
597
  min_task_size=self.min_task_size,
598
+ batch_rows=self.batch_rows,
605
599
  )
606
600
  return self.__class__(self.udf, self.catalog)
607
601
 
@@ -1633,6 +1627,7 @@ class DatasetQuery:
1633
1627
  min_task_size: Optional[int] = None,
1634
1628
  partition_by: Optional[PartitionByType] = None,
1635
1629
  cache: bool = False,
1630
+ batch_rows: Optional[int] = None,
1636
1631
  ) -> "Self":
1637
1632
  """
1638
1633
  Adds one or more signals based on the results from the provided UDF.
@@ -1658,6 +1653,7 @@ class DatasetQuery:
1658
1653
  workers=workers,
1659
1654
  min_task_size=min_task_size,
1660
1655
  cache=cache,
1656
+ batch_rows=batch_rows,
1661
1657
  )
1662
1658
  )
1663
1659
  return query
@@ -1679,6 +1675,7 @@ class DatasetQuery:
1679
1675
  namespace: Optional[str] = None,
1680
1676
  project: Optional[str] = None,
1681
1677
  cache: bool = False,
1678
+ batch_rows: Optional[int] = None,
1682
1679
  ) -> "Self":
1683
1680
  query = self.clone()
1684
1681
  steps = query.steps
@@ -1691,6 +1688,7 @@ class DatasetQuery:
1691
1688
  workers=workers,
1692
1689
  min_task_size=min_task_size,
1693
1690
  cache=cache,
1691
+ batch_rows=batch_rows,
1694
1692
  )
1695
1693
  )
1696
1694
  return query
datachain/utils.py CHANGED
@@ -11,7 +11,6 @@ import time
11
11
  from collections.abc import Iterable, Iterator, Sequence
12
12
  from contextlib import contextmanager
13
13
  from datetime import date, datetime, timezone
14
- from itertools import chain, islice
15
14
  from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
16
15
  from uuid import UUID
17
16
 
@@ -26,6 +25,8 @@ if TYPE_CHECKING:
26
25
  from typing_extensions import Self
27
26
 
28
27
 
28
+ DEFAULT_CHUNK_ROWS = 2000
29
+
29
30
  logger = logging.getLogger("datachain")
30
31
 
31
32
  NUL = b"\0"
@@ -225,30 +226,44 @@ def get_envs_by_prefix(prefix: str) -> dict[str, str]:
225
226
  _T_co = TypeVar("_T_co", covariant=True)
226
227
 
227
228
 
228
def _dynamic_batched_core(
    iterable: Iterable[_T_co],
    batch_rows: int,
) -> Iterator[list[_T_co]]:
    """Core batching logic that yields lists of at most `batch_rows` items.

    Raises:
        ValueError: If `batch_rows` is less than one (raised when iteration
            starts, because this is a generator).
    """
    # Without this check a size of zero or less would silently degrade to
    # one-element batches instead of signalling the caller's mistake.
    if batch_rows < 1:
        raise ValueError("Batch size must be at least one")

    batch: list[_T_co] = []

    for item in iterable:
        batch.append(item)
        # Emit as soon as the batch is full, so a complete batch is never
        # held back waiting for the next item of a slow producer.
        if len(batch) >= batch_rows:
            yield batch
            batch = []

    # Yield any remaining items
    if batch:
        yield batch


def batched(iterable: Iterable[_T_co], batch_rows: int) -> Iterator[tuple[_T_co, ...]]:
    """
    Batch data into tuples of length batch_rows.
    The last batch may be shorter.

    Example: batched('ABCDEFG', 3) --> ABC DEF G

    Raises:
        ValueError: If `batch_rows` is less than one.
    """
    for batch in _dynamic_batched_core(iterable, batch_rows):
        yield tuple(batch)
256
+
257
+
258
def batched_it(
    iterable: Iterable[_T_co],
    batch_rows: int = DEFAULT_CHUNK_ROWS,
) -> Iterator[Iterator[_T_co]]:
    """
    Batch data into iterators of at most batch_rows items each.
    The last batch may be shorter.

    Batches are sized by row count only. Each batch is produced by
    `_dynamic_batched_core` as a list and then wrapped in an iterator.

    Raises:
        ValueError: If `batch_rows` is less than one (raised when iteration
            starts, because this is a generator).
    """
    # Reject sizes that cannot form a batch rather than silently falling
    # back to one-element batches.
    if batch_rows < 1:
        raise ValueError("Batch size must be at least one")

    for batch in _dynamic_batched_core(iterable, batch_rows):
        yield iter(batch)
252
267
 
253
268
 
254
269
  def flatten(items):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datachain
3
- Version: 0.28.0
3
+ Version: 0.28.2
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License-Expression: Apache-2.0
@@ -19,7 +19,7 @@ datachain/script_meta.py,sha256=V-LaFOZG84pD0Zc0NvejYdzwDgzITv6yHvAHggDCnuY,4978
19
19
  datachain/semver.py,sha256=UB8GHPBtAP3UJGeiuJoInD7SK-DnB93_Xd1qy_CQ9cU,2074
20
20
  datachain/studio.py,sha256=-BmKLVNBLPFveUgVVE2So3aaiGndO2jK2qbHZ0zBDd8,15239
21
21
  datachain/telemetry.py,sha256=0A4IOPPp9VlP5pyW9eBfaTK3YhHGzHl7dQudQjUAx9A,994
22
- datachain/utils.py,sha256=DNqOi-Ydb7InyWvD9m7_yailxz6-YGpZzh00biQaHNo,15305
22
+ datachain/utils.py,sha256=Gp5JVr_m7nVWQGDOjrGnZjRXF9-Ai-MBxiPJIcpPvWQ,15451
23
23
  datachain/catalog/__init__.py,sha256=cMZzSz3VoUi-6qXSVaHYN-agxQuAcz2XSqnEPZ55crE,353
24
24
  datachain/catalog/catalog.py,sha256=QTWCXy75iWo-0MCXyfV_WbsKeZ1fpLpvL8d60rxn1ws,65528
25
25
  datachain/catalog/datasource.py,sha256=IkGMh0Ttg6Q-9DWfU_H05WUnZepbGa28HYleECi6K7I,1353
@@ -75,7 +75,7 @@ datachain/lib/audio.py,sha256=fQmIBq-9hrUZtkgeJdPHYA_D8Wfe9D4cQZk4_ijxpNc,7580
75
75
  datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
76
76
  datachain/lib/data_model.py,sha256=Rjah76GHwIV6AZQk4rsdg6JLre5D8Kb9T4PS5SXzsPA,3740
77
77
  datachain/lib/dataset_info.py,sha256=7w-DoKOyIVoOtWGCgciMLcP5CiAWJB3rVI-vUDF80k0,3311
78
- datachain/lib/file.py,sha256=_ch7xYcpl0kzImgEwccbQ-a5qb9rbEvx1vcuWerOn9k,42608
78
+ datachain/lib/file.py,sha256=IGwpCwjsSOpZXlRsatcMKToMmuvYiX6_UtaTjUKAAdg,44511
79
79
  datachain/lib/hf.py,sha256=3xdvPQPilnJiGv3H4S4bTGqvrGGlZgZmqjE1n_SMJZg,7293
80
80
  datachain/lib/image.py,sha256=erWvZW5M3emnbl6_fGAOPyKm-1EKbt3vOdWPfe3Oo7U,3265
81
81
  datachain/lib/listing.py,sha256=U-2stsTEwEsq4Y80dqGfktGzkmB5-ZntnL1_rzXlH0k,7089
@@ -85,13 +85,13 @@ datachain/lib/model_store.py,sha256=dkL2rcT5ag-kbgkhQPL_byEs-TCYr29qvdltroL5NxM,
85
85
  datachain/lib/namespaces.py,sha256=it52UbbwB8dzhesO2pMs_nThXiPQ1Ph9sD9I3GQkg5s,2099
86
86
  datachain/lib/projects.py,sha256=8lN0qV8czX1LGtWURCUvRlSJk-RpO9w9Rra_pOZus6g,2595
87
87
  datachain/lib/pytorch.py,sha256=S-st2SAczYut13KMf6eSqP_OQ8otWI5TRmzhK5fN3k0,7828
88
- datachain/lib/settings.py,sha256=9wi0FoHxRxNiyn99pR28IYsMkoo47jQxeXuObQr2Ar0,2929
88
+ datachain/lib/settings.py,sha256=n0YYhCVdgCdMkCSLY7kscJF9mUhlQ0a4ENWBsJFynkw,3809
89
89
  datachain/lib/signal_schema.py,sha256=JMsL8c4iCRH9PoRumvjimsOLQQslTjm_aDR2jh1zT2Q,38558
90
90
  datachain/lib/tar.py,sha256=MLcVjzIgBqRuJacCNpZ6kwSZNq1i2tLyROc8PVprHsA,999
91
91
  datachain/lib/text.py,sha256=UNHm8fhidk7wdrWqacEWaA6I9ykfYqarQ2URby7jc7M,1261
92
- datachain/lib/udf.py,sha256=SUnJWRDC3TlLhvpi8iqqJbeZGn5DChot7DyH-0Q-z20,17305
92
+ datachain/lib/udf.py,sha256=IB1IKF5KyA-NiyfhVzmBPpF_aITPS3zSlrt24f_Ofjo,17956
93
93
  datachain/lib/udf_signature.py,sha256=Yz20iJ-WF1pijT3hvcDIKFzgWV9gFxZM73KZRx3NbPk,7560
94
- datachain/lib/utils.py,sha256=rG2y7NwTqZOuomZZRmrA-Q-ANM_j1cToQYqDJoOeGyU,1480
94
+ datachain/lib/utils.py,sha256=RLji1gHnfDXtJCnBo8BcNu1obndFpVsXJ_1Vb-FQ9Qo,4554
95
95
  datachain/lib/video.py,sha256=ddVstiMkfxyBPDsnjCKY0d_93bw-DcMqGqN60yzsZoo,6851
96
96
  datachain/lib/webdataset.py,sha256=CkW8FfGigNx6wo2EEK4KMjhEE8FamRHWGs2HZuH7jDY,7214
97
97
  datachain/lib/webdataset_laion.py,sha256=xvT6m_r5y0KbOx14BUe7UC5mOgrktJq53Mh-H0EVlUE,2525
@@ -104,14 +104,14 @@ datachain/lib/convert/values_to_tuples.py,sha256=j5yZMrVUH6W7b-7yUvdCTGI7JCUAYUO
104
104
  datachain/lib/dc/__init__.py,sha256=TFci5HTvYGjBesNUxDAnXaX36PnzPEUSn5a6JxB9o0U,872
105
105
  datachain/lib/dc/csv.py,sha256=q6a9BpapGwP6nwy6c5cklxQumep2fUp9l2LAjtTJr6s,4411
106
106
  datachain/lib/dc/database.py,sha256=g5M6NjYR1T0vKte-abV-3Ejnm-HqxTIMir5cRi_SziE,6051
107
- datachain/lib/dc/datachain.py,sha256=mLE5v4KhzEQm7HVWBTxY6EwJ2J-YeFVcLUY4I21216c,93212
107
+ datachain/lib/dc/datachain.py,sha256=T5-b2LLCF0zYhXQjOgtzzr6cm5NfrKVGxcJTWn7tfNU,94164
108
108
  datachain/lib/dc/datasets.py,sha256=P6CIJizD2IYFwOQG5D3VbQRjDmUiRH0ysdtb551Xdm8,15098
109
109
  datachain/lib/dc/hf.py,sha256=AP_MUHg6HJWae10PN9hD_beQVjrl0cleZ6Cvhtl1yoI,2901
110
110
  datachain/lib/dc/json.py,sha256=dNijfJ-H92vU3soyR7X1IiDrWhm6yZIGG3bSnZkPdAE,2733
111
111
  datachain/lib/dc/listings.py,sha256=V379Cb-7ZyquM0w7sWArQZkzInZy4GB7QQ1ZfowKzQY,4544
112
112
  datachain/lib/dc/pandas.py,sha256=ObueUXDUFKJGu380GmazdG02ARpKAHPhSaymfmOH13E,1489
113
113
  datachain/lib/dc/parquet.py,sha256=zYcSgrWwyEDW9UxGUSVdIVsCu15IGEf0xL8KfWQqK94,1782
114
- datachain/lib/dc/records.py,sha256=FpPbApWopUri1gIaSMsfXN4fevja4mjmfb6Q5eiaGxI,3116
114
+ datachain/lib/dc/records.py,sha256=4N1Fq-j5r4GK-PR5jIO-9B2u_zTNX9l-6SmcRhQDAsw,3136
115
115
  datachain/lib/dc/storage.py,sha256=FXroEdxOZfbuEBIWfWTkbGwrI0D4_mrLZSRsIQm0WFE,7693
116
116
  datachain/lib/dc/utils.py,sha256=VawOAlJSvAtZbsMg33s5tJe21TRx1Km3QggI1nN6tnw,3984
117
117
  datachain/lib/dc/values.py,sha256=7l1n352xWrEdql2NhBcZ3hj8xyPglWiY4qHjFPjn6iw,1428
@@ -126,7 +126,7 @@ datachain/model/ultralytics/pose.py,sha256=pBlmt63Qe68FKmexHimUGlNbNOoOlMHXG4fzX
126
126
  datachain/model/ultralytics/segment.py,sha256=63bDCj43E6iZ0hFI5J6uQfksdCmjEp6sEm1XzVaE8pw,2986
127
127
  datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
128
128
  datachain/query/batch.py,sha256=-goxLpE0EUvaDHu66rstj53UnfHpYfBUGux8GSpJ93k,4306
129
- datachain/query/dataset.py,sha256=cYNrg1QyrZpO-oup3mqmSYHUvgEYBKe8RgkVbyQa6p0,62777
129
+ datachain/query/dataset.py,sha256=OJZ_YwpS5i4B0wVmosMmMNW1qABr6zyOmqNHQdAWir4,62704
130
130
  datachain/query/dispatch.py,sha256=A0nPxn6mEN5d9dDo6S8m16Ji_9IvJLXrgF2kqXdi4fs,15546
131
131
  datachain/query/metrics.py,sha256=DOK5HdNVaRugYPjl8qnBONvTkwjMloLqAr7Mi3TjCO0,858
132
132
  datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
@@ -158,9 +158,9 @@ datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR
158
158
  datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
159
159
  datachain/toolkit/split.py,sha256=ktGWzY4kyzjWyR86dhvzw-Zhl0lVk_LOX3NciTac6qo,2914
160
160
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
161
- datachain-0.28.0.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
162
- datachain-0.28.0.dist-info/METADATA,sha256=lA3lv9RX2NeQPobrEjoEbAwg5K3zmnAnbDJ_hjR8KLw,13766
163
- datachain-0.28.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
164
- datachain-0.28.0.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
165
- datachain-0.28.0.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
166
- datachain-0.28.0.dist-info/RECORD,,
161
+ datachain-0.28.2.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
162
+ datachain-0.28.2.dist-info/METADATA,sha256=dYo2qW8RMNNCyy6KOXztfXOIldyS4_mADxeAlCI9cKw,13766
163
+ datachain-0.28.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
164
+ datachain-0.28.2.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
165
+ datachain-0.28.2.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
166
+ datachain-0.28.2.dist-info/RECORD,,