datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/tar.py CHANGED
@@ -6,12 +6,11 @@ from datachain.lib.file import File, TarVFile


 def build_tar_member(parent: File, info: tarfile.TarInfo) -> File:
-    new_parent = parent.get_full_name()
     etag_string = "-".join([parent.etag, info.name, str(info.mtime)])
     etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
     return File(
         source=parent.source,
-        path=f"{new_parent}/{info.name}",
+        path=f"{parent.path}/{info.name}",
         version=parent.version,
         size=info.size,
         etag=etag,
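
Note on the tar.py change above: member paths are now built from `parent.path` rather than `parent.get_full_name()`. The following standalone sketch (not part of the diff; the archive path, etag, and mtime values are made up) shows how a tar member's path and etag are derived after this change:

import hashlib
import tarfile

# Hypothetical values standing in for the parent File object.
parent_path = "data/archive.tar"  # parent.path
parent_etag = "0123abcd"          # parent.etag

info = tarfile.TarInfo(name="images/cat.jpg")
info.mtime = 1700000000

# Same derivation as build_tar_member() in 0.39.0.
etag_string = "-".join([parent_etag, info.name, str(info.mtime)])
etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
member_path = f"{parent_path}/{info.name}"  # "data/archive.tar/images/cat.jpg"
print(member_path, etag)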
datachain/lib/text.py CHANGED
@@ -1,16 +1,17 @@
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any

 import torch
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase


 def convert_text(
-    text: Union[str, list[str]],
-    tokenizer: Optional[Callable] = None,
-    tokenizer_kwargs: Optional[dict[str, Any]] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[str, list[str], torch.Tensor]:
+    text: str | list[str],
+    tokenizer: Callable | None = None,
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> str | list[str] | torch.Tensor:
     """
     Tokenize and otherwise transform text.

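Note on the text.py change above: it is purely a typing modernization. `Optional`/`Union` annotations become PEP 604 unions and `Callable` is imported from `collections.abc`. A generic before/after sketch of the pattern (illustrative only, not code from the package):

from collections.abc import Callable
from typing import Optional, Union

# 0.14.2-style annotations:
def old_style(text: Union[str, list[str]], tokenizer: Optional[Callable] = None) -> Union[str, list[str]]: ...

# 0.39.0-style annotations (PEP 604 unions, collections.abc.Callable):
def new_style(text: str | list[str], tokenizer: Callable | None = None) -> str | list[str]: ...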
datachain/lib/udf.py CHANGED
@@ -1,9 +1,8 @@
-import sys
-import traceback
+import hashlib
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
 from functools import partial
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -12,11 +11,10 @@ from pydantic import BaseModel
 from datachain.asyn import AsyncMapper
 from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
+from datachain.hash_utils import hash_callable
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.data_model import DataValue
-from datachain.lib.file import File
-from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.progress import CombinedDownloadCallback
+from datachain.lib.file import DataModel, File
+from datachain.lib.utils import AbstractUDF, DataChainParamsError
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -42,8 +40,44 @@ T = TypeVar("T", bound=Sequence[Any])


 class UdfError(DataChainParamsError):
-    def __init__(self, msg):
-        super().__init__(f"UDF error: {msg}")
+    """Exception raised for UDF-related errors."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__!s}: {self.message!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.message,)
+
+
+class UdfRunError(Exception):
+    """Exception raised when UDF execution fails."""
+
+    def __init__(
+        self,
+        error: Exception | str,
+        stacktrace: str | None = None,
+        udf_name: str | None = None,
+    ) -> None:
+        self.error = error
+        self.stacktrace = stacktrace
+        self.udf_name = udf_name
+        super().__init__(str(error))
+
+    def __str__(self) -> str:
+        if isinstance(self.error, UdfRunError):
+            return str(self.error)
+        if isinstance(self.error, Exception):
+            return f"{self.error.__class__.__name__!s}: {self.error!s}"
+        return f"{self.__class__.__name__!s}: {self.error!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.error, self.stacktrace, self.udf_name)


 ColumnType = Any
@@ -56,38 +90,26 @@ UDFOutputSpec = Mapping[str, ColumnType]
 UDFResult = dict[str, Any]


-@attrs.define
-class UDFProperties:
-    udf: "UDFAdapter"
-
-    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
-        return self.udf.get_batching(use_partitioning)
-
-    @property
-    def batch(self):
-        return self.udf.batch
-
-
 @attrs.define(slots=False)
 class UDFAdapter:
     inner: "UDFBase"
     output: UDFOutputSpec
+    batch_size: int | None = None
     batch: int = 1

+    def hash(self) -> str:
+        return self.inner.hash()
+
     def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
         if use_partitioning:
             return Partition()
+
         if self.batch == 1:
             return NoBatching()
         if self.batch > 1:
             return Batch(self.batch)
         raise ValueError(f"invalid batch size {self.batch}")

-    @property
-    def properties(self):
-        # For backwards compatibility.
-        return UDFProperties(self)
-
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -164,10 +186,31 @@ class UDFBase(AbstractUDF):
     prefetch: int = 0

     def __init__(self):
-        self.params: Optional[SignalSchema] = None
+        self.params: SignalSchema | None = None
         self.output = None
         self._func = None

+    def hash(self) -> str:
+        """
+        Creates SHA hash of this UDF function. It takes into account function,
+        inputs and outputs.
+
+        For function-based UDFs, hashes self._func.
+        For class-based UDFs, hashes the process method.
+        """
+        # Hash user code: either _func (function-based) or process method (class-based)
+        func_to_hash = self._func if self._func else self.process
+
+        parts = [
+            hash_callable(func_to_hash),
+            self.params.hash() if self.params else "",
+            self.output.hash(),
+        ]
+
+        return hashlib.sha256(
+            b"".join([bytes.fromhex(part) for part in parts])
+        ).hexdigest()
+
     def process(self, *args, **kwargs):
         """Processing function that needs to be defined by user"""
         if not self._func:
@@ -188,7 +231,7 @@ class UDFBase(AbstractUDF):
         self,
         sign: "UdfSignature",
         params: "SignalSchema",
-        func: Optional[Callable],
+        func: Callable | None,
     ):
         self.params = params
         self.output = sign.output_schema
@@ -219,14 +262,31 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__

+    @property
+    def verbose_name(self):
+        """Returns the name of the function or class that implements the UDF."""
+        if self._func and callable(self._func):
+            if hasattr(self._func, "__name__"):
+                return self._func.__name__
+            if hasattr(self._func, "__class__") and hasattr(
+                self._func.__class__, "__name__"
+            ):
+                return self._func.__class__.__name__
+        return "<unknown>"
+
     @property
     def signal_names(self) -> Iterable[str]:
         return self.output.to_udf_spec().keys()

-    def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
+    def to_udf_wrapper(
+        self,
+        batch_size: int | None = None,
+        batch: int = 1,
+    ) -> UDFAdapter:
         return UDFAdapter(
             self,
             self.output.to_udf_spec(),
+            batch_size,
             batch,
         )

@@ -255,38 +315,37 @@ class UDFBase(AbstractUDF):

     def _parse_row(
         self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
-    ) -> list[DataValue]:
+    ) -> list[Any]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
-            if isinstance(obj, File):
-                obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
+            self._set_stream_recursive(obj, catalog, cache, download_cb)
         return obj_row

+    def _set_stream_recursive(
+        self, obj: Any, catalog: "Catalog", cache: bool, download_cb: Callback
+    ) -> None:
+        """Recursively set the catalog stream on all File objects within an object."""
+        if isinstance(obj, File):
+            obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
+
+        # Check all fields for nested File objects, but only for DataModel objects
+        if isinstance(obj, DataModel):
+            for field_name in type(obj).model_fields:
+                field_value = getattr(obj, field_name, None)
+                if isinstance(field_value, DataModel):
+                    self._set_stream_recursive(field_value, catalog, cache, download_cb)
+
     def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         return self._parse_row(row_dict, catalog, cache, download_cb)

     def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input

-    def process_safe(self, obj_rows):
-        try:
-            result_objs = self.process(*obj_rows)
-        except Exception as e:  # noqa: BLE001
-            msg = f"============== Error in user code: '{self.name}' =============="
-            print(msg)
-            exc_type, exc_value, exc_traceback = sys.exc_info()
-            traceback.print_exception(exc_type, exc_value, exc_traceback.tb_next)
-            print("=" * len(msg))
-            raise DataChainError(
-                f"Error in user code in class '{self.name}': {e!s}"
-            ) from None
-        return result_objs
-

 def noop(*args, **kwargs):
     pass
@@ -294,11 +353,11 @@ def noop(*args, **kwargs):

 async def _prefetch_input(
     row: T,
-    download_cb: Optional["Callback"] = None,
+    download_cb: Callback | None = None,
     after_prefetch: "Callable[[], None]" = noop,
 ) -> T:
     for obj in row:
-        if isinstance(obj, File) and await obj._prefetch(download_cb):
+        if isinstance(obj, File) and obj.path and await obj._prefetch(download_cb):
             after_prefetch()
     return row

@@ -317,8 +376,8 @@ def _remove_prefetched(row: T) -> None:
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
-    download_cb: Optional["Callback"] = None,
-    after_prefetch: Optional[Callable[[], None]] = None,
+    download_cb: Callback | None = None,
+    after_prefetch: Callable[[], None] | None = None,
     remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
     if not prefetch:
@@ -327,8 +386,9 @@ def _prefetch_inputs(

     if after_prefetch is None:
         after_prefetch = noop
-        if isinstance(download_cb, CombinedDownloadCallback):
-            after_prefetch = download_cb.increment_file_count
+        if download_cb and hasattr(download_cb, "increment_file_count"):
+            increment_file_count: Callable[[], None] = download_cb.increment_file_count
+            after_prefetch = increment_file_count

     f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
     mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
@@ -384,9 +444,12 @@ class Mapper(UDFBase):

         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
-                result_objs = self.process_safe(udf_args)
+                result_objs = self.process(*udf_args)
                 udf_output = self._flatten_row(result_objs)
-                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                output = [
+                    {"sys__id": id_}
+                    | dict(zip(self.signal_names, udf_output, strict=False))
+                ]
                 processed_cb.relative_update(1)
                 yield output

@@ -394,11 +457,27 @@ class Mapper(UDFBase):


 class BatchMapper(UDFBase):
-    """Inherit from this class to pass to `DataChain.batch_map()`."""
+    """Inherit from this class to pass to `DataChain.batch_map()`.
+
+    .. deprecated:: 0.29.0
+        This class is deprecated and will be removed in a future version.
+        Use `Aggregator` instead, which provides the similar functionality.
+    """

     is_input_batched = True
     is_output_batched = True

+    def __init__(self):
+        import warnings
+
+        warnings.warn(
+            "BatchMapper is deprecated and will be removed in a future version. "
+            "Use Aggregator instead, which provides the similar functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__()
+
     def run(
         self,
         udf_fields: Sequence[str],
@@ -411,24 +490,26 @@ class BatchMapper(UDFBase):
         self.setup()

         for batch in udf_inputs:
-            n_rows = len(batch.rows)
+            n_rows = len(batch)
             row_ids, *udf_args = zip(
                 *[
                     self._prepare_row_and_id(
                         row, udf_fields, catalog, cache, download_cb
                     )
-                    for row in batch.rows
-                ]
+                    for row in batch
+                ],
+                strict=False,
             )
-            result_objs = list(self.process_safe(udf_args))
+            result_objs = list(self.process(*udf_args))
             n_objs = len(result_objs)
             assert n_objs == n_rows, (
                 f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
             )
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = [
-                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-                for row_id, signals in zip(row_ids, udf_outputs)
+                {"sys__id": row_id}
+                | dict(zip(self.signal_names, signals, strict=False))
+                for row_id, signals in zip(row_ids, udf_outputs, strict=False)
             ]
             processed_cb.relative_update(n_rows)
             yield output
@@ -461,10 +542,10 @@ class Generator(UDFBase):
         )

         def _process_row(row):
-            with safe_closing(self.process_safe(row)) as result_objs:
+            with safe_closing(self.process(*row)) as result_objs:
                 for result_obj in result_objs:
                     udf_output = self._flatten_row(result_obj)
-                    yield dict(zip(self.signal_names, udf_output))
+                    yield dict(zip(self.signal_names, udf_output, strict=False))

         prepared_inputs = _prepare_rows(udf_inputs)
         prepared_inputs = _prefetch_inputs(
@@ -474,8 +555,9 @@ class Generator(UDFBase):
             remove_prefetched=bool(self.prefetch) and not cache,
         )
         with closing(prepared_inputs):
-            for row in processed_cb.wrap(prepared_inputs):
+            for row in prepared_inputs:
                 yield _process_row(row)
+                processed_cb.relative_update(1)

         self.teardown()

@@ -488,7 +570,7 @@ class Aggregator(UDFBase):

     def run(
         self,
-        udf_fields: "Sequence[str]",
+        udf_fields: Sequence[str],
         udf_inputs: Iterable[RowsOutputBatch],
         catalog: "Catalog",
         cache: bool,
@@ -498,16 +580,22 @@ class Aggregator(UDFBase):
         self.setup()

         for batch in udf_inputs:
-            udf_args = zip(
-                *[
-                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch.rows
-                ]
-            )
-            result_objs = self.process_safe(udf_args)
+            prepared_rows = [
+                self._prepare_row(row, udf_fields, catalog, cache, download_cb)
+                for row in batch
+            ]
+            batched_args = zip(*prepared_rows, strict=False)
+            # Convert aggregated column values to lists. This keeps behavior
+            # consistent with the type hints promoted in the public API.
+            udf_args = [
+                list(arg) if isinstance(arg, tuple) else arg for arg in batched_args
+            ]
+            result_objs = self.process(*udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(len(batch.rows))
+            output = (
+                dict(zip(self.signal_names, row, strict=False)) for row in udf_outputs
+            )
+            processed_cb.relative_update(len(batch))
             yield output

         self.teardown()
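
Note on the udf.py changes above: the new `UDFBase.hash()` combines the digest of the user callable (via `hash_callable`) with the input and output schema digests into a single SHA-256 value. A standalone sketch of that combination step, using made-up byte strings in place of the real function and schemas (each helper is assumed to return a hex-encoded digest):

import hashlib

# Stand-ins for hash_callable(func), params.hash(), and output.hash().
func_digest = hashlib.sha256(b"def process(file): ...").hexdigest()
params_digest = hashlib.sha256(b"input signal schema").hexdigest()
output_digest = hashlib.sha256(b"output signal schema").hexdigest()

parts = [func_digest, params_digest, output_digest]
# Same combination as UDFBase.hash(): decode each hex digest back to bytes,
# concatenate, and hash the result.
combined = hashlib.sha256(b"".join(bytes.fromhex(part) for part in parts)).hexdigest()
print(combined)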
datachain/lib/udf_signature.py CHANGED
@@ -1,12 +1,12 @@
 import inspect
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Union, get_args, get_origin
+from typing import Any, get_args, get_origin

 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import UDFBase
-from datachain.lib.utils import AbstractUDF, DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError, callable_name


 class UdfSignatureError(DataChainParamsError):
@@ -16,9 +16,9 @@ class UdfSignatureError(DataChainParamsError):


 @dataclass
-class UdfSignature:
-    func: Union[Callable, UDFBase]
-    params: dict[str, Union[DataType, Any]]
+class UdfSignature:  # noqa: PLW1641
+    func: Callable | UDFBase
+    params: dict[str, DataType | Any]
     output_schema: SignalSchema

     DEFAULT_RETURN_TYPE = str
@@ -28,24 +28,29 @@ class UdfSignature:
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Union[None, UDFBase, Callable] = None,
-        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
+        func: UDFBase | Callable | None = None,
+        params: str | Sequence[str] | None = None,
+        output: DataType | Sequence[str] | dict[str, DataType] | None = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
         if len(signal_map) > 1:
             raise UdfSignatureError(
                 chain,
-                f"multiple signals '{keys}' are not supported in processors."
-                " Chain multiple processors instead.",
+                (
+                    f"multiple signals '{keys}' are not supported in processors."
+                    " Chain multiple processors instead.",
+                ),
             )
-        udf_func: Union[UDFBase, Callable]
+        udf_func: UDFBase | Callable
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
                     chain,
-                    f"processor can't have signal '{keys}' with function '{func}'",
+                    (
+                        "processor can't have signal "
+                        f"'{keys}' with function '{callable_name(func)}'"
+                    ),
                 )
             signal_name, udf_func = next(iter(signal_map.items()))
         else:
@@ -56,13 +61,16 @@ class UdfSignature:
             signal_name = None

         if not isinstance(udf_func, UDFBase) and not callable(udf_func):
-            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+            raise UdfSignatureError(
+                chain,
+                f"UDF '{callable_name(udf_func)}' is not callable",
+            )

         func_params_map_sign, func_outs_sign, is_iterator = cls._func_signature(
             chain, udf_func
         )

-        udf_params: dict[str, Union[DataType, Any]] = {}
+        udf_params: dict[str, DataType | Any] = {}
         if params:
             udf_params = (
                 {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
@@ -76,14 +84,15 @@ class UdfSignature:
             }

         if output:
+            # Use the actual resolved function (udf_func) for clearer error messages
             udf_output_map = UdfSignature._validate_output(
-                chain, signal_name, func, func_outs_sign, output
+                chain, signal_name, udf_func, func_outs_sign, output
             )
         else:
             if not func_outs_sign:
                 raise UdfSignatureError(
                     chain,
-                    f"outputs are not defined in function '{udf_func}'"
+                    f"outputs are not defined in function '{callable_name(udf_func)}'"
                     " hints or 'output'",
                 )

@@ -97,9 +106,12 @@ class UdfSignature:
         if is_generator and not is_iterator:
             raise UdfSignatureError(
                 chain,
-                f"function '{func}' cannot be used in generator/aggregator"
-                " because it returns a type that is not Iterator/Generator."
-                f" Instead, it returns '{func_outs_sign}'",
+                (
+                    f"function '{callable_name(udf_func)}' cannot be used in "
+                    "generator/aggregator because it returns a type that is "
+                    "not Iterator/Generator. "
+                    f"Instead, it returns '{func_outs_sign}'"
+                ),
             )

         if isinstance(func_outs_sign, tuple):
@@ -124,11 +136,14 @@ class UdfSignature:
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
-                    f"length of outputs names ({len(output)}) and function '{func}'"
-                    f" return type length ({len(func_outs_sign)}) does not match",
+                    (
+                        f"length of outputs names ({len(output)}) and function "
+                        f"'{callable_name(func)}' return type length "
+                        f"({len(func_outs_sign)}) does not match"
+                    ),
                 )

-            udf_output_map = dict(zip(output, func_outs_sign))
+            udf_output_map = dict(zip(output, func_outs_sign, strict=False))
         elif isinstance(output, dict):
             for key, value in output.items():
                 if not isinstance(key, str):
@@ -164,7 +179,7 @@ class UdfSignature:

     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Union[Callable, UDFBase]
+        chain: str, udf_func: Callable | UDFBase
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
@@ -183,17 +198,27 @@ class UdfSignature:
             orig = get_origin(anno)
             if inspect.isclass(orig) and issubclass(orig, Iterator):
                 args = get_args(anno)
-                if len(args) > 1 and not (
-                    issubclass(orig, Generator) and len(args) == 3
-                ):
-                    raise UdfSignatureError(
-                        chain,
-                        f"function '{func}' should return iterator with a single"
-                        f" value while '{args}' are specified",
-                    )
-                is_iterator = True
-                anno = args[0]
-                orig = get_origin(anno)
+                # For typing.Iterator without type args, default to DEFAULT_RETURN_TYPE
+                if len(args) == 0:
+                    is_iterator = True
+                    anno = UdfSignature.DEFAULT_RETURN_TYPE
+                    orig = get_origin(anno)
+                else:
+                    # typing.Generator[T, S, R] has 3 args; allow that shape
+                    if len(args) > 1 and not (
+                        issubclass(orig, Generator) and len(args) == 3
+                    ):
+                        raise UdfSignatureError(
+                            chain,
+                            (
+                                f"function '{callable_name(func)}' should return "
+                                "iterator with a single value while "
+                                f"'{args}' are specified"
+                            ),
+                        )
+                    is_iterator = True
+                    anno = args[0]
+                    orig = get_origin(anno)

         if orig and orig is tuple:
             output_types = tuple(get_args(anno))  # type: ignore[assignment]
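
Note on the udf_signature.py change above: the new branch in `_func_signature` distinguishes a bare `typing.Iterator` return hint (origin present, no type arguments, so the output type falls back to `DEFAULT_RETURN_TYPE`) from subscripted forms. A small probe of the `typing` introspection it relies on (illustrative only):

from collections.abc import Generator
from typing import Iterator, get_args, get_origin

# Bare typing.Iterator: origin is collections.abc.Iterator, args are empty,
# which is the case the new len(args) == 0 branch handles.
print(get_origin(Iterator), get_args(Iterator))

# Subscripted Iterator: a single type argument becomes the output type.
print(get_args(Iterator[int]))  # (int,)

# Generator[T, S, R] carries three arguments, which the existing check allows.
print(get_args(Generator[int, None, None]))  # (int, NoneType, NoneType)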