datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py CHANGED
@@ -1,9 +1,8 @@
-import sys
-import traceback
+import hashlib
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
 from functools import partial
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -12,9 +11,10 @@ from pydantic import BaseModel
 from datachain.asyn import AsyncMapper
 from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
+from datachain.hash_utils import hash_callable
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.file import DataModel, File
-from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -40,8 +40,44 @@ T = TypeVar("T", bound=Sequence[Any])
 
 
 class UdfError(DataChainParamsError):
-    def __init__(self, msg):
-        super().__init__(f"UDF error: {msg}")
+    """Exception raised for UDF-related errors."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__!s}: {self.message!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.message,)
+
+
+class UdfRunError(Exception):
+    """Exception raised when UDF execution fails."""
+
+    def __init__(
+        self,
+        error: Exception | str,
+        stacktrace: str | None = None,
+        udf_name: str | None = None,
+    ) -> None:
+        self.error = error
+        self.stacktrace = stacktrace
+        self.udf_name = udf_name
+        super().__init__(str(error))
+
+    def __str__(self) -> str:
+        if isinstance(self.error, UdfRunError):
+            return str(self.error)
+        if isinstance(self.error, Exception):
+            return f"{self.error.__class__.__name__!s}: {self.error!s}"
+        return f"{self.__class__.__name__!s}: {self.error!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.error, self.stacktrace, self.udf_name)
 
 
 ColumnType = Any
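
A note on the `__reduce__` overrides above: DataChain runs UDFs in worker processes, and an exception whose `__init__` takes extra arguments does not round-trip through pickle by default (`BaseException` only preserves `args`). A minimal sketch with a simplified stand-in for the `UdfRunError` added here:

import pickle

class UdfRunError(Exception):
    """Simplified stand-in for the class in the diff above."""

    def __init__(self, error, stacktrace=None, udf_name=None):
        self.error = error
        self.stacktrace = stacktrace
        self.udf_name = udf_name
        super().__init__(str(error))

    def __reduce__(self):
        # Without this, unpickling would call __init__ with only str(error)
        # and drop stacktrace/udf_name on the way back from a worker process.
        return self.__class__, (self.error, self.stacktrace, self.udf_name)

err = UdfRunError("division by zero", stacktrace="...", udf_name="my_udf")
restored = pickle.loads(pickle.dumps(err))
assert restored.udf_name == "my_udf"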
@@ -54,25 +90,16 @@ UDFOutputSpec = Mapping[str, ColumnType]
 UDFResult = dict[str, Any]
 
 
-@attrs.define
-class UDFProperties:
-    udf: "UDFAdapter"
-
-    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
-        return self.udf.get_batching(use_partitioning)
-
-    @property
-    def batch_rows(self):
-        return self.udf.batch_rows
-
-
 @attrs.define(slots=False)
 class UDFAdapter:
     inner: "UDFBase"
     output: UDFOutputSpec
-    batch_rows: Optional[int] = None
+    batch_size: int | None = None
     batch: int = 1
 
+    def hash(self) -> str:
+        return self.inner.hash()
+
     def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
         if use_partitioning:
             return Partition()
@@ -83,11 +110,6 @@ class UDFAdapter:
             return Batch(self.batch)
         raise ValueError(f"invalid batch size {self.batch}")
 
-    @property
-    def properties(self):
-        # For backwards compatibility.
-        return UDFProperties(self)
-
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -164,10 +186,31 @@ class UDFBase(AbstractUDF):
     prefetch: int = 0
 
     def __init__(self):
-        self.params: Optional[SignalSchema] = None
+        self.params: SignalSchema | None = None
         self.output = None
        self._func = None
 
+    def hash(self) -> str:
+        """
+        Creates SHA hash of this UDF function. It takes into account function,
+        inputs and outputs.
+
+        For function-based UDFs, hashes self._func.
+        For class-based UDFs, hashes the process method.
+        """
+        # Hash user code: either _func (function-based) or process method (class-based)
+        func_to_hash = self._func if self._func else self.process
+
+        parts = [
+            hash_callable(func_to_hash),
+            self.params.hash() if self.params else "",
+            self.output.hash(),
+        ]
+
+        return hashlib.sha256(
+            b"".join([bytes.fromhex(part) for part in parts])
+        ).hexdigest()
+
     def process(self, *args, **kwargs):
         """Processing function that needs to be defined by user"""
         if not self._func:
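
The `hash()` method above composes several hex digests by decoding them to raw bytes, concatenating, and hashing the result; `bytes.fromhex("")` yields `b""`, so a missing params hash simply contributes nothing. A minimal sketch of the same scheme, with placeholder digests standing in for `hash_callable()` and the schema `hash()` methods:

import hashlib

def combine_hashes(parts: list[str]) -> str:
    # Each part is a hex digest (or "" when absent); decode, concatenate, rehash.
    return hashlib.sha256(b"".join(bytes.fromhex(p) for p in parts)).hexdigest()

func_digest = hashlib.sha256(b"def process(file): ...").hexdigest()  # stand-in for hash_callable()
params_digest = hashlib.sha256(b"params: File").hexdigest()          # stand-in for SignalSchema.hash()
output_digest = hashlib.sha256(b"output: str").hexdigest()           # stand-in for output.hash()

print(combine_hashes([func_digest, params_digest, output_digest]))
print(combine_hashes([func_digest, "", output_digest]))  # "" is valid: fromhex("") == b""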
@@ -188,7 +231,7 @@
         self,
         sign: "UdfSignature",
         params: "SignalSchema",
-        func: Optional[Callable],
+        func: Callable | None,
     ):
         self.params = params
         self.output = sign.output_schema
@@ -237,13 +280,13 @@
 
     def to_udf_wrapper(
         self,
-        batch_rows: Optional[int] = None,
+        batch_size: int | None = None,
         batch: int = 1,
     ) -> UDFAdapter:
         return UDFAdapter(
             self,
             self.output.to_udf_spec(),
-            batch_rows,
+            batch_size,
             batch,
         )
 
@@ -295,28 +338,14 @@ class UDFBase(AbstractUDF):
             self._set_stream_recursive(field_value, catalog, cache, download_cb)
 
     def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         return self._parse_row(row_dict, catalog, cache, download_cb)
 
     def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
-    def process_safe(self, obj_rows):
-        try:
-            result_objs = self.process(*obj_rows)
-        except Exception as e:  # noqa: BLE001
-            msg = f"============== Error in user code: '{self.name}' =============="
-            print(msg)
-            exc_type, exc_value, exc_traceback = sys.exc_info()
-            traceback.print_exception(exc_type, exc_value, exc_traceback.tb_next)
-            print("=" * len(msg))
-            raise DataChainError(
-                f"Error in user code in class '{self.name}': {e!s}"
-            ) from None
-        return result_objs
-
 
 def noop(*args, **kwargs):
     pass
324
353
 
325
354
  async def _prefetch_input(
326
355
  row: T,
327
- download_cb: Optional["Callback"] = None,
356
+ download_cb: Callback | None = None,
328
357
  after_prefetch: "Callable[[], None]" = noop,
329
358
  ) -> T:
330
359
  for obj in row:
@@ -347,8 +376,8 @@ def _remove_prefetched(row: T) -> None:
347
376
  def _prefetch_inputs(
348
377
  prepared_inputs: "Iterable[T]",
349
378
  prefetch: int = 0,
350
- download_cb: Optional["Callback"] = None,
351
- after_prefetch: Optional[Callable[[], None]] = None,
379
+ download_cb: Callback | None = None,
380
+ after_prefetch: Callable[[], None] | None = None,
352
381
  remove_prefetched: bool = False,
353
382
  ) -> "abc.Generator[T, None, None]":
354
383
  if not prefetch:
@@ -415,9 +444,12 @@ class Mapper(UDFBase):
 
         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
-                result_objs = self.process_safe(udf_args)
+                result_objs = self.process(*udf_args)
                 udf_output = self._flatten_row(result_objs)
-                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                output = [
+                    {"sys__id": id_}
+                    | dict(zip(self.signal_names, udf_output, strict=False))
+                ]
                 processed_cb.relative_update(1)
                 yield output
 
@@ -465,17 +497,19 @@ class BatchMapper(UDFBase):
                         row, udf_fields, catalog, cache, download_cb
                     )
                    for row in batch
-                ]
+                ],
+                strict=False,
             )
-            result_objs = list(self.process_safe(udf_args))
+            result_objs = list(self.process(*udf_args))
             n_objs = len(result_objs)
             assert n_objs == n_rows, (
                 f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
             )
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = [
-                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-                for row_id, signals in zip(row_ids, udf_outputs)
+                {"sys__id": row_id}
+                | dict(zip(self.signal_names, signals, strict=False))
+                for row_id, signals in zip(row_ids, udf_outputs, strict=False)
             ]
             processed_cb.relative_update(n_rows)
             yield output
@@ -508,10 +542,10 @@ class Generator(UDFBase):
         )
 
         def _process_row(row):
-            with safe_closing(self.process_safe(row)) as result_objs:
+            with safe_closing(self.process(*row)) as result_objs:
                 for result_obj in result_objs:
                     udf_output = self._flatten_row(result_obj)
-                    yield dict(zip(self.signal_names, udf_output))
+                    yield dict(zip(self.signal_names, udf_output, strict=False))
 
         prepared_inputs = _prepare_rows(udf_inputs)
         prepared_inputs = _prefetch_inputs(
@@ -546,15 +580,21 @@ class Aggregator(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-            udf_args = zip(
-                *[
-                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch
-                ]
-            )
-            result_objs = self.process_safe(udf_args)
+            prepared_rows = [
+                self._prepare_row(row, udf_fields, catalog, cache, download_cb)
+                for row in batch
+            ]
+            batched_args = zip(*prepared_rows, strict=False)
+            # Convert aggregated column values to lists. This keeps behavior
+            # consistent with the type hints promoted in the public API.
+            udf_args = [
+                list(arg) if isinstance(arg, tuple) else arg for arg in batched_args
+            ]
+            result_objs = self.process(*udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            output = (
+                dict(zip(self.signal_names, row, strict=False)) for row in udf_outputs
+            )
             processed_cb.relative_update(len(batch))
             yield output
 
datachain/lib/udf_signature.py CHANGED
@@ -1,12 +1,12 @@
 import inspect
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Union, get_args, get_origin
+from typing import Any, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import UDFBase
-from datachain.lib.utils import AbstractUDF, DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError, callable_name
 
 
 class UdfSignatureError(DataChainParamsError):
@@ -17,8 +17,8 @@ class UdfSignatureError(DataChainParamsError):
 
 @dataclass
 class UdfSignature:  # noqa: PLW1641
-    func: Union[Callable, UDFBase]
-    params: dict[str, Union[DataType, Any]]
+    func: Callable | UDFBase
+    params: dict[str, DataType | Any]
     output_schema: SignalSchema
 
     DEFAULT_RETURN_TYPE = str
@@ -28,24 +28,29 @@ class UdfSignature:  # noqa: PLW1641
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Union[None, UDFBase, Callable] = None,
-        params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
+        func: UDFBase | Callable | None = None,
+        params: str | Sequence[str] | None = None,
+        output: DataType | Sequence[str] | dict[str, DataType] | None = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
         if len(signal_map) > 1:
             raise UdfSignatureError(
                 chain,
-                f"multiple signals '{keys}' are not supported in processors."
-                " Chain multiple processors instead.",
+                (
+                    f"multiple signals '{keys}' are not supported in processors."
+                    " Chain multiple processors instead."
+                ),
             )
-        udf_func: Union[UDFBase, Callable]
+        udf_func: UDFBase | Callable
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
                     chain,
-                    f"processor can't have signal '{keys}' with function '{func}'",
+                    (
+                        "processor can't have signal "
+                        f"'{keys}' with function '{callable_name(func)}'"
+                    ),
                 )
             signal_name, udf_func = next(iter(signal_map.items()))
         else:
@@ -56,13 +61,16 @@ class UdfSignature:  # noqa: PLW1641
             signal_name = None
 
         if not isinstance(udf_func, UDFBase) and not callable(udf_func):
-            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+            raise UdfSignatureError(
+                chain,
+                f"UDF '{callable_name(udf_func)}' is not callable",
+            )
 
         func_params_map_sign, func_outs_sign, is_iterator = cls._func_signature(
             chain, udf_func
         )
 
-        udf_params: dict[str, Union[DataType, Any]] = {}
+        udf_params: dict[str, DataType | Any] = {}
         if params:
             udf_params = (
                 {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
@@ -76,14 +84,15 @@ class UdfSignature:  # noqa: PLW1641
             }
 
         if output:
+            # Use the actual resolved function (udf_func) for clearer error messages
             udf_output_map = UdfSignature._validate_output(
-                chain, signal_name, func, func_outs_sign, output
+                chain, signal_name, udf_func, func_outs_sign, output
             )
         else:
             if not func_outs_sign:
                 raise UdfSignatureError(
                     chain,
-                    f"outputs are not defined in function '{udf_func}'"
+                    f"outputs are not defined in function '{callable_name(udf_func)}'"
                     " hints or 'output'",
                 )
 
@@ -97,9 +106,12 @@ class UdfSignature:  # noqa: PLW1641
         if is_generator and not is_iterator:
             raise UdfSignatureError(
                 chain,
-                f"function '{func}' cannot be used in generator/aggregator"
-                " because it returns a type that is not Iterator/Generator."
-                f" Instead, it returns '{func_outs_sign}'",
+                (
+                    f"function '{callable_name(udf_func)}' cannot be used in "
+                    "generator/aggregator because it returns a type that is "
+                    "not Iterator/Generator. "
+                    f"Instead, it returns '{func_outs_sign}'"
+                ),
             )
 
         if isinstance(func_outs_sign, tuple):
@@ -124,11 +136,14 @@ class UdfSignature:  # noqa: PLW1641
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
-                    f"length of outputs names ({len(output)}) and function '{func}'"
-                    f" return type length ({len(func_outs_sign)}) does not match",
+                    (
+                        f"length of outputs names ({len(output)}) and function "
+                        f"'{callable_name(func)}' return type length "
+                        f"({len(func_outs_sign)}) does not match"
+                    ),
                 )
 
-            udf_output_map = dict(zip(output, func_outs_sign))
+            udf_output_map = dict(zip(output, func_outs_sign, strict=False))
         elif isinstance(output, dict):
             for key, value in output.items():
                 if not isinstance(key, str):
@@ -164,7 +179,7 @@ class UdfSignature:  # noqa: PLW1641
 
     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Union[Callable, UDFBase]
+        chain: str, udf_func: Callable | UDFBase
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
@@ -183,17 +198,27 @@ class UdfSignature:  # noqa: PLW1641
         orig = get_origin(anno)
         if inspect.isclass(orig) and issubclass(orig, Iterator):
             args = get_args(anno)
-            if len(args) > 1 and not (
-                issubclass(orig, Generator) and len(args) == 3
-            ):
-                raise UdfSignatureError(
-                    chain,
-                    f"function '{func}' should return iterator with a single"
-                    f" value while '{args}' are specified",
-                )
-            is_iterator = True
-            anno = args[0]
-            orig = get_origin(anno)
+            # For typing.Iterator without type args, default to DEFAULT_RETURN_TYPE
+            if len(args) == 0:
+                is_iterator = True
+                anno = UdfSignature.DEFAULT_RETURN_TYPE
+                orig = get_origin(anno)
+            else:
+                # typing.Generator[T, S, R] has 3 args; allow that shape
+                if len(args) > 1 and not (
+                    issubclass(orig, Generator) and len(args) == 3
+                ):
+                    raise UdfSignatureError(
+                        chain,
+                        (
+                            f"function '{callable_name(func)}' should return "
+                            "iterator with a single value while "
+                            f"'{args}' are specified"
+                        ),
+                    )
+                is_iterator = True
+                anno = args[0]
+                orig = get_origin(anno)
 
         if orig and orig is tuple:
             output_types = tuple(get_args(anno))  # type: ignore[assignment]
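
For reference, the annotation shapes `_func_signature` now distinguishes (a sketch; expected outputs as comments, on Python 3.9+ where bare `typing.Iterator` reports no type args):

import typing
from collections.abc import Generator, Iterator
from typing import get_args, get_origin

print(get_origin(Iterator[int]))             # <class 'collections.abc.Iterator'>
print(get_args(Iterator[int]))               # (<class 'int'>,) -- single yield type, accepted
print(get_args(Generator[int, None, None]))  # 3 args (yield/send/return) -- the allowed 3-arg shape
print(get_origin(typing.Iterator))           # <class 'collections.abc.Iterator'>
print(get_args(typing.Iterator))             # () -- bare Iterator, now falls back to DEFAULT_RETURN_TYPE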
datachain/lib/utils.py CHANGED
@@ -1,3 +1,4 @@
+import inspect
 import re
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
@@ -32,6 +33,25 @@ class DataChainColumnError(DataChainParamsError):
         super().__init__(f"Error for column {col_name}: {msg}")
 
 
+def callable_name(obj: object) -> str:
+    """Return a friendly name for a callable or UDF-like instance."""
+    # UDF classes in DataChain inherit from AbstractUDF; prefer class name
+    if isinstance(obj, AbstractUDF):
+        return obj.__class__.__name__
+
+    # Plain functions and bound/unbound methods
+    if inspect.ismethod(obj) or inspect.isfunction(obj):
+        # __name__ exists for functions/methods; includes "<lambda>" for lambdas
+        return obj.__name__  # type: ignore[attr-defined]
+
+    # Generic callable object
+    if callable(obj):
+        return obj.__class__.__name__
+
+    # Fallback for non-callables
+    return str(obj)
+
+
 def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
     """Returns normalized_name -> original_name dict."""
     gen_col_counter = 0
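
A quick sketch of how `callable_name` behaves for the main shapes it handles (hypothetical names, assuming the definition above is in scope; the `AbstractUDF` branch is omitted since it needs a real UDF instance):

def my_mapper(file):
    return file

class Scorer:
    def __call__(self, x):
        return x

print(callable_name(my_mapper))       # 'my_mapper'
print(callable_name(lambda x: x))     # '<lambda>'
print(callable_name(Scorer()))        # 'Scorer' (generic callable object)
print(callable_name("not callable"))  # 'not callable' (str fallback)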
datachain/lib/video.py CHANGED
@@ -1,7 +1,6 @@
 import posixpath
 import shutil
 import tempfile
-from typing import Optional, Union
 
 from numpy import ndarray
 
@@ -18,7 +17,7 @@ except ImportError as exc:
     ) from exc
 
 
-def video_info(file: Union[File, VideoFile]) -> Video:
+def video_info(file: File | VideoFile) -> Video:
     """
     Returns video file information.
 
@@ -108,7 +107,7 @@ def video_frame_np(video: VideoFile, frame: int) -> ndarray:
 def validate_frame_range(
     video: VideoFile,
     start: int = 0,
-    end: Optional[int] = None,
+    end: int | None = None,
     step: int = 1,
 ) -> tuple[int, int, int]:
     """
@@ -186,7 +185,7 @@ def save_video_fragment(
     start: float,
     end: float,
     output: str,
-    format: Optional[str] = None,
+    format: str | None = None,
 ) -> VideoFile:
     """
     Saves video interval as a new video file. If output is a remote path,
datachain/lib/webdataset.py CHANGED
@@ -1,20 +1,13 @@
-import json
 import tarfile
+import types
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    ClassVar,
-    Optional,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Any, ClassVar, Union, get_args, get_origin
 
 from pydantic import Field
 
+from datachain import json
 from datachain.lib.data_model import DataModel
 from datachain.lib.file import File
 from datachain.lib.tar import build_tar_member
64
57
 
65
58
 
66
59
  class WDSAllFile(WDSBasic):
67
- txt: Optional[str] = Field(default=None)
68
- text: Optional[str] = Field(default=None)
69
- cap: Optional[str] = Field(default=None)
70
- transcript: Optional[str] = Field(default=None)
71
- cls: Optional[int] = Field(default=None)
72
- cls2: Optional[int] = Field(default=None)
73
- index: Optional[int] = Field(default=None)
74
- inx: Optional[int] = Field(default=None)
75
- id: Optional[int] = Field(default=None)
76
- json: Optional[dict] = Field(default=None) # type: ignore[assignment]
77
- jsn: Optional[dict] = Field(default=None)
78
-
79
- pyd: Optional[bytes] = Field(default=None)
80
- pickle: Optional[bytes] = Field(default=None)
81
- pth: Optional[bytes] = Field(default=None)
82
- ten: Optional[bytes] = Field(default=None)
83
- tb: Optional[bytes] = Field(default=None)
84
- mp: Optional[bytes] = Field(default=None)
85
- msg: Optional[bytes] = Field(default=None)
86
- npy: Optional[bytes] = Field(default=None)
87
- npz: Optional[bytes] = Field(default=None)
88
- cbor: Optional[bytes] = Field(default=None)
60
+ txt: str | None = Field(default=None)
61
+ text: str | None = Field(default=None)
62
+ cap: str | None = Field(default=None)
63
+ transcript: str | None = Field(default=None)
64
+ cls: int | None = Field(default=None)
65
+ cls2: int | None = Field(default=None)
66
+ index: int | None = Field(default=None)
67
+ inx: int | None = Field(default=None)
68
+ id: int | None = Field(default=None)
69
+ json: dict | None = Field(default=None) # type: ignore[assignment]
70
+ jsn: dict | None = Field(default=None)
71
+
72
+ pyd: bytes | None = Field(default=None)
73
+ pickle: bytes | None = Field(default=None)
74
+ pth: bytes | None = Field(default=None)
75
+ ten: bytes | None = Field(default=None)
76
+ tb: bytes | None = Field(default=None)
77
+ mp: bytes | None = Field(default=None)
78
+ msg: bytes | None = Field(default=None)
79
+ npy: bytes | None = Field(default=None)
80
+ npz: bytes | None = Field(default=None)
81
+ cbor: bytes | None = Field(default=None)
89
82
 
90
83
 
91
84
  class WDSReadableSubclass(DataModel):
@@ -189,9 +182,11 @@ class Builder:
             return
 
         anno = field.annotation
-        if get_origin(anno) == Union:
-            args = get_args(anno)
-            anno = args[0]
+        anno_origin = get_origin(anno)
+        if anno_origin in (Union, types.UnionType):
+            anno_args = get_args(anno)
+            if len(anno_args) == 2 and type(None) in anno_args:
+                return anno_args[0] if anno_args[1] is type(None) else anno_args[1]
 
         return anno
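
The `Builder` change above is needed because `Optional[X]` spelled with PEP 604 syntax (`X | None`) has a different `get_origin` result than `typing.Optional`/`typing.Union`. A short demonstration of the two cases the new check covers (Python 3.10+):

import types
from typing import Optional, Union, get_args, get_origin

print(get_origin(Optional[str]) is Union)         # True  -- typing-style Optional
print(get_origin(str | None) is types.UnionType)  # True  -- PEP 604 union

# Both forms unwrap identically under the new check:
for anno in (Optional[str], str | None):
    args = get_args(anno)
    if get_origin(anno) in (Union, types.UnionType):
        if len(args) == 2 and type(None) in args:
            inner = args[0] if args[1] is type(None) else args[1]
            print(inner)  # <class 'str'> in both cases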