datachain 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their public registries.

This release has been flagged as potentially problematic.

datachain/lib/udf.py CHANGED
@@ -1,31 +1,32 @@
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable, Optional
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
+import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.convert.unflatten import unflatten_to_json
+from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
-from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import UDFInputBatch
-from datachain.query.schema import ColumnParameter
-from datachain.query.udf import UDFBase as _UDFBase
-from datachain.query.udf import UDFProperties
+from datachain.query.batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+)
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Iterator, Sequence
-
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
-    from datachain.query.batch import RowsOutput, UDFInput
-    from datachain.query.udf import UDFResult
+    from datachain.lib.udf_signature import UdfSignature
+    from datachain.query.batch import RowsOutput
 
 
 class UdfError(DataChainParamsError):
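
The import shuffle above is the heart of this release: the old datachain.query.udf module is dissolved into datachain.lib.udf, which now defines the wrapper types itself. A minimal sketch of the migration for downstream imports (assuming datachain 0.6.0 as diffed here):

# before (0.5.0)
# from datachain.query.udf import UDFBase, UDFProperties, UDFResult
# after (0.6.0)
from datachain.lib.udf import UDFAdapter, UDFBase, UDFResult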
@@ -33,14 +34,47 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")
 
 
-class UDFAdapter(_UDFBase):
-    def __init__(
-        self,
-        inner: "UDFBase",
-        properties: UDFProperties,
-    ):
-        self.inner = inner
-        super().__init__(properties)
+ColumnType = Any
+
+# Specification for the output of a UDF
+UDFOutputSpec = Mapping[str, ColumnType]
+
+# Result type when calling the UDF wrapper around the actual
+# Python function / class implementing it.
+UDFResult = dict[str, Any]
+
+
+@attrs.define
+class UDFProperties:
+    udf: "UDFAdapter"
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        return self.udf.get_batching(use_partitioning)
+
+    @property
+    def batch(self):
+        return self.udf.batch
+
+
+@attrs.define(slots=False)
+class UDFAdapter:
+    inner: "UDFBase"
+    output: UDFOutputSpec
+    batch: int = 1
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        if use_partitioning:
+            return Partition()
+        if self.batch == 1:
+            return NoBatching()
+        if self.batch > 1:
+            return Batch(self.batch)
+        raise ValueError(f"invalid batch size {self.batch}")
+
+    @property
+    def properties(self):
+        # For backwards compatibility.
+        return UDFProperties(self)
 
     def run(
         self,
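
UDFAdapter.get_batching encodes the dispatch rule directly: partitioned execution wins, a batch size of 1 means row-by-row processing, and anything larger yields fixed-size batches. A standalone sketch of that rule, with stand-in classes replacing the datachain.query.batch imports:

import attrs

@attrs.define
class NoBatching:  # stand-in for datachain.query.batch.NoBatching
    pass

@attrs.define
class Batch:  # stand-in for datachain.query.batch.Batch
    count: int

@attrs.define
class Partition:  # stand-in for datachain.query.batch.Partition
    pass

def get_batching(batch: int, use_partitioning: bool = False):
    # Mirrors UDFAdapter.get_batching above.
    if use_partitioning:
        return Partition()
    if batch == 1:
        return NoBatching()
    if batch > 1:
        return Batch(batch)
    raise ValueError(f"invalid batch size {batch}")

assert isinstance(get_batching(1), NoBatching)
assert isinstance(get_batching(64), Batch)
assert isinstance(get_batching(64, use_partitioning=True), Partition)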
@@ -51,48 +85,16 @@ class UDFAdapter(_UDFBase):
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> "Iterator[Iterable[UDFResult]]":
-        self.inner._catalog = catalog
-        if hasattr(self.inner, "setup") and callable(self.inner.setup):
-            self.inner.setup()
-
-        yield from super().run(
+    ) -> Iterator[Iterable[UDFResult]]:
+        yield from self.inner.run(
             udf_fields,
             udf_inputs,
             catalog,
-            is_generator,
             cache,
             download_cb,
             processed_cb,
         )
 
-        if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
-            self.inner.teardown()
-
-    def run_once(
-        self,
-        catalog: "Catalog",
-        arg: "UDFInput",
-        is_generator: bool = False,
-        cache: bool = False,
-        cb: Callback = DEFAULT_CALLBACK,
-    ) -> "Iterable[UDFResult]":
-        if isinstance(arg, UDFInputBatch):
-            udf_inputs = [
-                self.bind_parameters(catalog, row, cache=cache, cb=cb)
-                for row in arg.rows
-            ]
-            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
-            return self._process_results(arg.rows, udf_outputs, is_generator)
-        if isinstance(arg, RowDict):
-            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.inner(*udf_inputs, cache=cache, download_cb=cb)
-            if not is_generator:
-                # udf_outputs is generator already if is_generator=True
-                udf_outputs = [udf_outputs]
-            return self._process_results([arg], udf_outputs, is_generator)
-        raise ValueError(f"Unexpected UDF argument: {arg}")
-
 
 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -142,18 +144,13 @@ class UDFBase(AbstractUDF):
     ```
     """
 
-    is_input_batched = False
     is_output_batched = False
-    is_input_grouped = False
-    params_spec: Optional[list[str]]
+    catalog: "Optional[Catalog]"
 
     def __init__(self):
-        self.params = None
+        self.params: Optional[SignalSchema] = None
         self.output = None
-        self.params_spec = None
-        self.output_spec = None
-        self._contains_stream = None
-        self._catalog = None
+        self.catalog = None
         self._func = None
 
     def process(self, *args, **kwargs):
@@ -174,29 +171,24 @@ class UDFBase(AbstractUDF):
 
     def _init(
         self,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
-        func: Callable,
+        func: Optional[Callable],
     ):
         self.params = params
         self.output = sign.output_schema
-
-        params_spec = self.params.to_udf_spec()
-        self.params_spec = list(params_spec.keys())
-        self.output_spec = self.output.to_udf_spec()
-
         self._func = func
 
     @classmethod
     def _create(
         cls,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
     ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
             if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
-                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f"cannot create UDF: provided UDF '{type(sign.func).__name__}'"
                     f" must be a child of target class '{cls.__name__}'",
                 )
             result = sign.func
@@ -212,57 +204,27 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
-    def set_catalog(self, catalog):
-        self._catalog = catalog.copy(db=False)
-
     @property
-    def catalog(self):
-        return self._catalog
+    def signal_names(self) -> Iterable[str]:
+        return self.output.to_udf_spec().keys()
 
     def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
-        assert self.params_spec is not None
-        properties = UDFProperties(
-            [ColumnParameter(p) for p in self.params_spec], self.output_spec, batch
+        return UDFAdapter(
+            self,
+            self.output.to_udf_spec(),
+            batch,
         )
-        return UDFAdapter(self, properties)
-
-    def validate_results(self, results, *args, **kwargs):
-        return results
-
-    def __call__(self, *rows, cache, download_cb):
-        if self.is_input_grouped:
-            objs = self._parse_grouped_rows(rows[0], cache, download_cb)
-        elif self.is_input_batched:
-            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
-        else:
-            objs = self._parse_rows([rows], cache, download_cb)[0]
 
-        result_objs = self.process_safe(objs)
-
-        if not self.is_output_batched:
-            result_objs = [result_objs]
-
-        # Generator expression is required, otherwise the value will be materialized
-        res = (self._flatten_row(row) for row in result_objs)
-
-        if not self.is_output_batched:
-            res = list(res)
-            assert (
-                len(res) == 1
-            ), f"{self.name} returns {len(res)} rows while it's not batched"
-            if isinstance(res[0], tuple):
-                res = res[0]
-        elif (
-            self.is_input_batched
-            and self.is_output_batched
-            and not self.is_input_grouped
-        ):
-            res = list(res)
-            assert len(res) == len(
-                rows[0]
-            ), f"{self.name} returns {len(res)} rows while len(rows[0]) expected"
-
-        return res
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Any]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        raise NotImplementedError
 
     def _flatten_row(self, row):
         if len(self.output.values) > 1 and not isinstance(row, BaseModel):
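
With __call__ and its batching flags gone, UDFBase.run becomes the single abstract entry point: every UDF flavor is a generator that owns its whole lifecycle. A toy illustration of the contract (no datachain imports; Echo is a hypothetical class):

from collections.abc import Iterable, Iterator, Sequence
from typing import Any

class Echo:
    def setup(self): pass
    def teardown(self): pass

    def run(
        self,
        udf_fields: Sequence[str],
        udf_inputs: Iterable[Sequence[Any]],
    ) -> Iterator[list[dict[str, Any]]]:
        self.setup()
        for row in udf_inputs:
            yield [dict(zip(udf_fields, row))]  # one output batch per input row
        self.teardown()

out = list(Echo().run(["sys__id", "name"], [(1, "a"), (2, "b")]))
assert out[0] == [{"sys__id": 1, "name": "a"}]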
@@ -276,47 +238,28 @@ class UDFBase(AbstractUDF):
     def _obj_to_list(obj):
         return flatten(obj) if isinstance(obj, BaseModel) else [obj]
 
-    def _parse_rows(self, rows, cache, download_cb):
-        objs = []
-        for row in rows:
-            obj_row = self.params.row_to_objs(row)
-            for obj in obj_row:
-                if isinstance(obj, File):
-                    obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
-                    )
-            objs.append(obj_row)
-        return objs
-
-    def _parse_grouped_rows(self, group, cache, download_cb):
-        spec_map = {}
-        output_map = {}
-        for name, (anno, subtree) in self.params.tree.items():
-            if ModelStore.is_pydantic(anno):
-                length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
-            else:
-                length = 1
-            spec_map[name] = anno, length
-            output_map[name] = []
-
-        for flat_obj in group:
-            position = 0
-            for signal, (cls, length) in spec_map.items():
-                slice = flat_obj[position : position + length]
-                position += length
-
-                if ModelStore.is_pydantic(cls):
-                    obj = cls(**unflatten_to_json(cls, slice))
-                else:
-                    obj = slice[0]
-
-                if isinstance(obj, File):
-                    obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
-                    )
-                output_map[signal].append(obj)
+    def _parse_row(
+        self, row_dict: RowDict, cache: bool, download_cb: Callback
+    ) -> list[DataValue]:
+        assert self.params
+        row = [row_dict[p] for p in self.params.to_udf_spec()]
+        obj_row = self.params.row_to_objs(row)
+        for obj in obj_row:
+            if isinstance(obj, File):
+                assert self.catalog is not None
+                obj._set_stream(
+                    self.catalog, caching_enabled=cache, download_cb=download_cb
+                )
+        return obj_row
+
+    def _prepare_row(self, row, udf_fields, cache, download_cb):
+        row_dict = RowDict(zip(udf_fields, row))
+        return self._parse_row(row_dict, cache, download_cb)
 
-        return list(output_map.values())
+    def _prepare_row_and_id(self, row, udf_fields, cache, download_cb):
+        row_dict = RowDict(zip(udf_fields, row))
+        udf_input = self._parse_row(row_dict, cache, download_cb)
+        return row_dict["sys__id"], *udf_input
 
     def process_safe(self, obj_rows):
         try:
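
The new helpers rebuild a field-name mapping from the raw warehouse row, then split off sys__id before the user function sees the values. The core move as a standalone sketch (a plain dict stands in for RowDict):

udf_fields = ("sys__id", "path", "size")
row = (42, "a/b.txt", 1024)

row_dict = dict(zip(udf_fields, row))  # RowDict(zip(udf_fields, row)) in the diff
sys_id = row_dict["sys__id"]
udf_args = [row_dict[p] for p in ("path", "size")]
assert (sys_id, udf_args) == (42, ["a/b.txt", 1024])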
@@ -336,23 +279,128 @@ class UDFBase(AbstractUDF):
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Sequence[Any]]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for row in udf_inputs:
+            id_, *udf_args = self._prepare_row_and_id(
+                row, udf_fields, cache, download_cb
+            )
+            result_objs = self.process_safe(udf_args)
+            udf_output = self._flatten_row(result_objs)
+            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+            processed_cb.relative_update(1)
+            yield output
+
+        self.teardown()
+
 
 class BatchMapper(UDFBase):
     """Inherit from this class to pass to `DataChain.batch_map()`."""
 
-    is_input_batched = True
     is_output_batched = True
 
+    def run(
+        self,
+        udf_fields: Sequence[str],
+        udf_inputs: Iterable[RowsOutputBatch],
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for batch in udf_inputs:
+            n_rows = len(batch.rows)
+            row_ids, *udf_args = zip(
+                *[
+                    self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+                    for row in batch.rows
+                ]
+            )
+            result_objs = list(self.process_safe(udf_args))
+            n_objs = len(result_objs)
+            assert (
+                n_objs == n_rows
+            ), f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = [
+                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+                for row_id, signals in zip(row_ids, udf_outputs)
+            ]
+            processed_cb.relative_update(n_rows)
+            yield output
+
+        self.teardown()
+
 
 class Generator(UDFBase):
     """Inherit from this class to pass to `DataChain.gen()`."""
 
     is_output_batched = True
 
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[Sequence[Any]]",
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for row in udf_inputs:
+            udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
+            result_objs = self.process_safe(udf_args)
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            processed_cb.relative_update(1)
+            yield output
+
+        self.teardown()
+
 
 class Aggregator(UDFBase):
     """Inherit from this class to pass to `DataChain.agg()`."""
 
-    is_input_batched = True
     is_output_batched = True
-    is_input_grouped = True
+
+    def run(
+        self,
+        udf_fields: "Sequence[str]",
+        udf_inputs: Iterable[RowsOutputBatch],
+        catalog: "Catalog",
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.catalog = catalog
+        self.setup()
+
+        for batch in udf_inputs:
+            udf_args = zip(
+                *[
+                    self._prepare_row(row, udf_fields, cache, download_cb)
+                    for row in batch.rows
+                ]
+            )
+            result_objs = self.process_safe(udf_args)
+            udf_outputs = (self._flatten_row(row) for row in result_objs)
+            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            processed_cb.relative_update(len(batch.rows))
+            yield output
+
+        self.teardown()
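
All four run implementations converge on one output shape: mappers attach the originating sys__id so results can be joined back to the source rows, while generators and aggregators emit fresh rows keyed only by signal name. The merge idiom in isolation (hypothetical signal names):

signal_names = ["embedding", "label"]
udf_output = (0.12, "cat")

mapper_row = {"sys__id": 7} | dict(zip(signal_names, udf_output))
generator_row = dict(zip(signal_names, udf_output))
assert mapper_row == {"sys__id": 7, "embedding": 0.12, "label": "cat"}
assert generator_row == {"embedding": 0.12, "label": "cat"}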
datachain/lib/udf_signature.py CHANGED
@@ -1,10 +1,11 @@
 import inspect
 from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Callable, Optional, Union, get_args, get_origin
+from typing import Callable, Union, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.udf import UDFBase
 from datachain.lib.utils import AbstractUDF, DataChainParamsError
 
 
@@ -16,7 +17,7 @@ class UdfSignatureError(DataChainParamsError):
 
 @dataclass
 class UdfSignature:
-    func: Callable
+    func: Union[Callable, UDFBase]
     params: Sequence[str]
     output_schema: SignalSchema
 
@@ -27,7 +28,7 @@ class UdfSignature:
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Optional[Callable] = None,
+        func: Union[None, UDFBase, Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
@@ -39,6 +40,7 @@ class UdfSignature:
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
+        udf_func: Union[UDFBase, Callable]
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
@@ -53,7 +55,7 @@ class UdfSignature:
             udf_func = func
             signal_name = None
 
-        if not callable(udf_func):
+        if not isinstance(udf_func, UDFBase) and not callable(udf_func):
             raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
 
         func_params_map_sign, func_outs_sign, is_iterator = (
@@ -73,7 +75,7 @@ class UdfSignature:
         if not func_outs_sign:
             raise UdfSignatureError(
                 chain,
-                f"outputs are not defined in function '{udf_func.__name__}'"
+                f"outputs are not defined in function '{udf_func}'"
                 " hints or 'output'",
             )
 
@@ -154,7 +156,7 @@ class UdfSignature:
 
     @staticmethod
    def _func_signature(
-        chain: str, udf_func: Callable
+        chain: str, udf_func: Union[Callable, UDFBase]
     ) -> tuple[dict[str, type], Sequence[type], bool]:
        if isinstance(udf_func, AbstractUDF):
            func = udf_func.process  # type: ignore[unreachable]
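
UdfSignature now accepts a UDFBase instance anywhere a plain callable was allowed, and the validity check is widened to match, since an instance without __call__ fails a bare callable() test. The check in isolation (stand-in class, not the real import):

class UDFBase:  # stand-in for datachain.lib.udf.UDFBase
    def process(self, *args): pass

def is_valid_udf(udf_func) -> bool:
    # Inverse of: not isinstance(udf_func, UDFBase) and not callable(udf_func)
    return isinstance(udf_func, UDFBase) or callable(udf_func)

assert is_valid_udf(UDFBase())        # instance passes via isinstance
assert is_valid_udf(lambda x: x)      # plain callable still passes
assert not is_valid_udf("not a udf")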
datachain/query/batch.py CHANGED
@@ -11,8 +11,6 @@ from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
 if TYPE_CHECKING:
     from sqlalchemy import Select
 
-    from datachain.dataset import RowDict
-
 
 @dataclass
 class RowsOutputBatch:
@@ -22,14 +20,6 @@ class RowsOutputBatch:
 RowsOutput = Union[Sequence, RowsOutputBatch]
 
 
-@dataclass
-class UDFInputBatch:
-    rows: Sequence["RowDict"]
-
-
-UDFInput = Union["RowDict", UDFInputBatch]
-
-
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
 
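With UDFInputBatch gone, batched rows reach the UDF as RowsOutputBatch directly, and BatchMapper/Aggregator transpose them into per-argument columns. A sketch of that consumption (dataclass stand-in mirroring the field above):

from collections.abc import Sequence
from dataclasses import dataclass

@dataclass
class RowsOutputBatch:  # stand-in for datachain.query.batch.RowsOutputBatch
    rows: Sequence

batch = RowsOutputBatch(rows=[(1, "a"), (2, "b")])
ids, names = zip(*batch.rows)  # the zip(*...) transposition used in run()
assert ids == (1, 2) and names == ("a", "b")
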
datachain/query/dataset.py CHANGED
@@ -42,6 +42,7 @@ from datachain.data_storage.schema import (
 )
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions import rand
 from datachain.utils import (
@@ -53,7 +54,6 @@ from datachain.utils import (
 
 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
-from .udf import UDFBase
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
@@ -299,7 +299,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFBase) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -320,7 +320,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFBase,
+    udf: UDFAdapter,
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFBase
+    udf: UDFAdapter
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -392,7 +392,7 @@ class UDFStep(Step, ABC):
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         use_partitioning = self.partition_by is not None
-        batching = self.udf.properties.get_batching(use_partitioning)
+        batching = self.udf.get_batching(use_partitioning)
         workers = self.workers
         if (
             not workers
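
Because of the properties compatibility shim on UDFAdapter, both spellings resolve to the same strategy; call sites like this one simply drop the extra hop. A self-contained illustration with attrs stand-ins (mirroring the classes at the top of this diff, not the real imports):

import attrs

@attrs.define
class UDFProperties:  # stand-in compat shim
    udf: "UDFAdapter"

    def get_batching(self, use_partitioning: bool = False):
        return self.udf.get_batching(use_partitioning)

@attrs.define(slots=False)
class UDFAdapter:  # stand-in
    batch: int = 1

    def get_batching(self, use_partitioning: bool = False):
        if use_partitioning:
            return "partition"
        return "none" if self.batch == 1 else ("batch", self.batch)

    @property
    def properties(self):
        return UDFProperties(self)

udf = UDFAdapter(batch=10)
assert udf.get_batching() == udf.properties.get_batching() == ("batch", 10)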
@@ -1465,7 +1465,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: UDFBase,
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1509,7 +1509,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: UDFBase,
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
datachain/query/dispatch.py CHANGED
@@ -13,6 +13,7 @@ from multiprocess import get_context
 
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import UDFAdapter, UDFResult
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -27,7 +28,6 @@ from datachain.query.queue import (
     put_into_queue,
     unmarshal,
 )
-from datachain.query.udf import UDFBase, UDFResult
 from datachain.utils import batched_it
 
 DEFAULT_BATCH_SIZE = 10000
@@ -114,7 +114,6 @@ class UDFDispatcher:
     catalog: Optional[Catalog] = None
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None
-    _batch_size: Optional[int] = None
 
     def __init__(
         self,
@@ -154,17 +153,6 @@ class UDFDispatcher:
         self.done_queue = None
         self.ctx = get_context("spawn")
 
-    @property
-    def batch_size(self):
-        if self._batch_size is None:
-            if hasattr(self.udf, "properties") and hasattr(
-                self.udf.properties, "batch"
-            ):
-                self._batch_size = self.udf.properties.batch
-            else:
-                self._batch_size = 1
-        return self._batch_size
-
     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
             id_generator = self.id_generator_class(
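
The defensive batch_size property disappears because UDFAdapter.batch is a plain attrs field that always exists, so the hasattr probing has nothing left to guard against. The simplification, sketched with a stand-in:

import attrs

@attrs.define
class UDFAdapter:  # stand-in carrying the field from this diff
    batch: int = 1

adapter = UDFAdapter(batch=500)
# 0.5.0: hasattr(self.udf, "properties") and hasattr(self.udf.properties, "batch")
# 0.6.0: just read the attribute
assert adapter.batch == 500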
@@ -336,7 +324,7 @@ class ProcessedCallback(Callback):
 @attrs.define
 class UDFWorker:
     catalog: Catalog
-    udf: UDFBase
+    udf: UDFAdapter
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
     is_generator: bool