datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.

Potentially problematic release.



Files changed (51)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +225 -168
  16. datachain/lib/file.py +41 -41
  17. datachain/lib/gpt4_vision.py +1 -9
  18. datachain/lib/hf_image_to_text.py +9 -17
  19. datachain/lib/hf_pipeline.py +4 -12
  20. datachain/lib/image.py +2 -18
  21. datachain/lib/image_transform.py +0 -1
  22. datachain/lib/iptc_exif_xmp.py +8 -15
  23. datachain/lib/meta_formats.py +1 -5
  24. datachain/lib/model_store.py +77 -0
  25. datachain/lib/pytorch.py +9 -21
  26. datachain/lib/signal_schema.py +139 -60
  27. datachain/lib/text.py +5 -16
  28. datachain/lib/udf.py +114 -30
  29. datachain/lib/udf_signature.py +5 -5
  30. datachain/lib/webdataset.py +3 -3
  31. datachain/lib/webdataset_laion.py +2 -3
  32. datachain/node.py +4 -4
  33. datachain/query/batch.py +1 -1
  34. datachain/query/dataset.py +51 -178
  35. datachain/query/dispatch.py +43 -30
  36. datachain/query/udf.py +46 -26
  37. datachain/remote/studio.py +1 -9
  38. datachain/torch/__init__.py +21 -0
  39. datachain/utils.py +39 -0
  40. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
  41. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
  42. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
  43. datachain/image/__init__.py +0 -3
  44. datachain/lib/cached_stream.py +0 -38
  45. datachain/lib/claude.py +0 -69
  46. datachain/lib/feature.py +0 -412
  47. datachain/lib/feature_registry.py +0 -51
  48. datachain/lib/feature_utils.py +0 -154
  49. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
  50. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
  51. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
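The new datachain/lib/convert package (items 9-13 above) takes over the flattening and unflattening logic that previously lived on the removed Feature class. A minimal sketch of how these helpers appear to fit together, inferred from their usage in the hunks below and using a hypothetical pydantic model rather than any class shipped in this release:

from pydantic import BaseModel

from datachain.lib.convert.flatten import flatten
from datachain.lib.convert.unflatten import unflatten_to_json

class Point(BaseModel):  # hypothetical model, not part of datachain
    x: int
    y: int

point = Point(x=1, y=2)
row = flatten(point)  # model fields flattened into a plain row of values (assumed behavior)
assert Point(**unflatten_to_json(Point, row)) == point  # round trip back to the model (assumed)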
datachain/lib/signal_schema.py CHANGED
@@ -1,20 +1,30 @@
 import copy
 from collections.abc import Iterator, Sequence
+from dataclasses import dataclass
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    Callable,
+    Literal,
+    Optional,
+    Union,
+    get_args,
+    get_origin,
+)

-from pydantic import create_model
+from pydantic import BaseModel, create_model
+from typing_extensions import Literal as LiteralEx

-from datachain.lib.feature import (
-    DATACHAIN_TO_TYPE,
-    DEFAULT_DELIMITER,
-    Feature,
-    FeatureType,
-    convert_type_to_datachain,
-)
-from datachain.lib.feature_registry import Registry
+from datachain.lib.convert.flatten import DATACHAIN_TO_TYPE
+from datachain.lib.convert.type_converter import convert_to_db_type
+from datachain.lib.convert.unflatten import unflatten_to_json_pos
+from datachain.lib.data_model import DataModel, DataType
 from datachain.lib.file import File
+from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
+from datachain.query.schema import DEFAULT_DELIMITER

 if TYPE_CHECKING:
     from datachain.catalog import Catalog
@@ -56,10 +66,16 @@ class SignalResolvingTypeError(SignalResolvingError):
         )


+@dataclass
 class SignalSchema:
+    values: dict[str, DataType]
+    tree: dict[str, Any]
+    setup_func: dict[str, Callable]
+    setup_values: Optional[dict[str, Callable]]
+
     def __init__(
         self,
-        values: dict[str, FeatureType],
+        values: dict[str, DataType],
         setup: Optional[dict[str, Callable]] = None,
     ):
         self.values = values
@@ -85,7 +101,7 @@ class SignalSchema:

     @staticmethod
     def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
-        signals: dict[str, FeatureType] = {}
+        signals: dict[str, DataType] = {}
         for field, type_ in col_types.items():
             type_ = DATACHAIN_TO_TYPE.get(type_, None)
             if type_ is None:
@@ -99,15 +115,16 @@ class SignalSchema:
     def serialize(self) -> dict[str, str]:
         signals = {}
         for name, fr_type in self.values.items():
-            if Feature.is_feature(fr_type):
-                signals[name] = fr_type._name()  # type: ignore[union-attr]
+            if (fr := ModelStore.to_pydantic(fr_type)) is not None:
+                ModelStore.add(fr)
+                signals[name] = ModelStore.get_name(fr)
             else:
                 orig = get_origin(fr_type)
                 args = get_args(fr_type)
                 # Check if fr_type is Optional
                 if orig == Union and len(args) == 2 and (type(None) in args):
                     fr_type = args[0]
-                signals[name] = fr_type.__name__
+                signals[name] = str(fr_type.__name__)  # type: ignore[union-attr]
         return signals

     @staticmethod
@@ -115,80 +132,93 @@ class SignalSchema:
         if not isinstance(schema, dict):
             raise SignalSchemaError(f"cannot deserialize signal schema: {schema}")

-        signals: dict[str, FeatureType] = {}
+        signals: dict[str, DataType] = {}
         for signal, type_name in schema.items():
             try:
                 fr = NAMES_TO_TYPES.get(type_name)
                 if not fr:
-                    type_name, version = Registry.parse_name_version(type_name)
-                    fr = Registry.get(type_name, version)
+                    type_name, version = ModelStore.parse_name_version(type_name)
+                    fr = ModelStore.get(type_name, version)
+
+                if not fr:
+                    raise SignalSchemaError(
+                        f"cannot deserialize '{signal}': "
+                        f"unknown type '{type_name}'."
+                        f" Try to add it with `ModelStore.add({type_name})`."
+                    )
             except TypeError as err:
                 raise SignalSchemaError(
                     f"cannot deserialize '{signal}': {err}"
                 ) from err
-
-            if not fr:
-                raise SignalSchemaError(
-                    f"cannot deserialize '{signal}': unsupported type '{type_name}'"
-                )
             signals[signal] = fr

         return SignalSchema(signals)

-    def to_udf_spec(self) -> dict[str, Any]:
+    def to_udf_spec(self) -> dict[str, type]:
         res = {}
         for path, type_, has_subtree, _ in self.get_flat_tree():
             if path[0] in self.setup_func:
                 continue
             if not has_subtree:
                 db_name = DEFAULT_DELIMITER.join(path)
-                res[db_name] = convert_type_to_datachain(type_)
+                res[db_name] = convert_to_db_type(type_)
         return res

-    def row_to_objs(self, row: Sequence[Any]) -> list[FeatureType]:
+    def row_to_objs(self, row: Sequence[Any]) -> list[DataType]:
         self._init_setup_values()

         objs = []
         pos = 0
         for name, fr_type in self.values.items():
-            if val := self.setup_values.get(name, None):  # type: ignore[attr-defined]
+            if self.setup_values and (val := self.setup_values.get(name, None)):
                 objs.append(val)
-            elif Feature.is_feature(fr_type):
-                j, pos = fr_type._unflatten_to_json_pos(row, pos)  # type: ignore[union-attr]
-                objs.append(fr_type(**j))
+            elif (fr := ModelStore.to_pydantic(fr_type)) is not None:
+                j, pos = unflatten_to_json_pos(fr, row, pos)
+                objs.append(fr(**j))  # type: ignore[arg-type]
             else:
                 objs.append(row[pos])
                 pos += 1
         return objs  # type: ignore[return-value]

     def contains_file(self) -> bool:
-        return any(
-            fr._is_file  # type: ignore[union-attr]
-            for fr in self.values.values()
-            if Feature.is_feature(fr)
-        )
+        for type_ in self.values.values():
+            if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass(
+                fr, File
+            ):
+                return True
+
+        return False

     def slice(
         self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
     ) -> "SignalSchema":
+        # Make new schema that combines current schema and setup signals
         setup = setup or {}
         setup_no_types = dict.fromkeys(setup.keys(), str)
-        union = self.values | setup_no_types
-        schema = {k: union[k] for k in keys if k in union}
+        union = SignalSchema(self.values | setup_no_types)
+        # Slice combined schema by keys
+        schema = {}
+        for k in keys:
+            try:
+                schema[k] = union._find_in_tree(k.split("."))
+            except SignalResolvingError:
+                pass
         return SignalSchema(schema, setup)

-    def row_to_features(self, row: Sequence, catalog: "Catalog") -> list[FeatureType]:
+    def row_to_features(
+        self, row: Sequence, catalog: "Catalog", cache: bool = False
+    ) -> list[DataType]:
         res = []
         pos = 0
         for fr_cls in self.values.values():
-            if not Feature.is_feature(fr_cls):
+            if (fr := ModelStore.to_pydantic(fr_cls)) is None:
                 res.append(row[pos])
                 pos += 1
             else:
-                json, pos = fr_cls._unflatten_to_json_pos(row, pos)  # type: ignore[union-attr]
-                obj = fr_cls(**json)
+                json, pos = unflatten_to_json_pos(fr, row, pos)  # type: ignore[union-attr]
+                obj = fr(**json)
                 if isinstance(obj, File):
-                    obj._set_stream(catalog)
+                    obj._set_stream(catalog, caching_enabled=cache)
                 res.append(obj)
         return res

@@ -208,7 +238,7 @@ class SignalSchema:

         return SignalSchema(schema)

-    def _find_in_tree(self, path: list[str]) -> FeatureType:
+    def _find_in_tree(self, path: list[str]) -> DataType:
         curr_tree = self.tree
         curr_type = None
         i = 0
@@ -265,24 +295,23 @@ class SignalSchema:
             if has_subtree and issubclass(type_, File):
                 yield ".".join(path)

-    def create_model(self, name: str) -> type[Feature]:
+    def create_model(self, name: str) -> type[DataModel]:
         fields = {key: (value, None) for key, value in self.values.items()}

         return create_model(
             name,
-            __base__=(Feature,),  # type: ignore[call-overload]
+            __base__=(DataModel,),  # type: ignore[call-overload]
             **fields,
         )

     @staticmethod
-    def _build_tree(values: dict[str, FeatureType]) -> dict[str, Any]:
-        res = {}
-
-        for name, val in values.items():
-            subtree = val.build_tree() if Feature.is_feature(val) else None  # type: ignore[union-attr]
-            res[name] = (val, subtree)
-
-        return res
+    def _build_tree(
+        values: dict[str, DataType],
+    ) -> dict[str, tuple[DataType, Optional[dict]]]:
+        return {
+            name: (val, SignalSchema._build_tree_for_type(val))
+            for name, val in values.items()
+        }

     def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
         yield from self._get_flat_tree(self.tree, [], 0)
@@ -305,27 +334,77 @@

             if get_origin(type_) is list:
                 args = get_args(type_)
-                if len(args) > 0 and Feature.is_feature(args[0]):
+                if len(args) > 0 and ModelStore.is_pydantic(args[0]):
                     sub_schema = SignalSchema({"* list of": args[0]})
                     sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

+    def get_headers_with_length(self):
+        paths = [
+            path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree
+        ]
+        max_length = max([len(path) for path in paths], default=0)
+        return [
+            path + [""] * (max_length - len(path)) if len(path) < max_length else path
+            for path in paths
+        ], max_length
+
+    def __or__(self, other):
+        return self.__class__(self.values | other.values)
+
+    def __contains__(self, name: str):
+        return name in self.values
+
+    def remove(self, name: str):
+        return self.values.pop(name)
+
     @staticmethod
-    def _type_to_str(type_):
-        if get_origin(type_) == Union:
+    def _type_to_str(type_):  # noqa: PLR0911
+        origin = get_origin(type_)
+
+        if origin == Union:
             args = get_args(type_)
             formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
             return f"Union[{formatted_types}]"
-        if get_origin(type_) == Optional:
+        if origin == Optional:
             args = get_args(type_)
             type_str = SignalSchema._type_to_str(args[0])
             return f"Optional[{type_str}]"
-        if get_origin(type_) is list:
+        if origin is list:
             args = get_args(type_)
             type_str = SignalSchema._type_to_str(args[0])
             return f"list[{type_str}]"
-        if get_origin(type_) is dict:
+        if origin is dict:
             args = get_args(type_)
-            type_str = SignalSchema._type_to_str(args[0])
+            type_str = SignalSchema._type_to_str(args[0]) if len(args) > 0 else ""
             vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else ""
             return f"dict[{type_str}{vals}]"
+        if origin == Annotated:
+            args = get_args(type_)
+            return SignalSchema._type_to_str(args[0])
+        if origin in (Literal, LiteralEx):
+            return "Literal"
         return type_.__name__
+
+    @staticmethod
+    def _build_tree_for_type(
+        model: DataType,
+    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
+        if (fr := ModelStore.to_pydantic(model)) is not None:
+            return SignalSchema._build_tree_for_model(fr)
+        return None
+
+    @staticmethod
+    def _build_tree_for_model(
+        model: type[BaseModel],
+    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
+        res: dict[str, tuple[DataType, Optional[dict]]] = {}
+
+        for name, f_info in model.model_fields.items():
+            anno = f_info.annotation
+            if (fr := ModelStore.to_pydantic(anno)) is not None:
+                subtree = SignalSchema._build_tree_for_model(fr)
+            else:
+                subtree = None
+            res[name] = (anno, subtree)  # type: ignore[assignment]
+
+        return res
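A hedged usage sketch of the serialize/deserialize round trip shown above. The exact serialized names (and whether File must be registered explicitly via ModelStore.add) are assumptions drawn from these hunks, not verified against the release:

from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"name": str, "file": File})
serialized = schema.serialize()  # e.g. {"name": "str", "file": "File@<version>"} (format assumed)
restored = SignalSchema.deserialize(serialized)
assert restored.contains_file()  # File subclasses are detected by the new contains_file()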
datachain/lib/text.py CHANGED
@@ -1,7 +1,7 @@
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Union

-if TYPE_CHECKING:
-    import torch
+import torch
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase


 def convert_text(
@@ -9,7 +9,7 @@ def convert_text(
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
     encoder: Optional[Callable] = None,
-) -> Union[str, list[str], "torch.Tensor"]:
+) -> Union[str, list[str], torch.Tensor]:
     """
     Tokenize and otherwise transform text.

@@ -29,21 +29,10 @@ def convert_text(
         res = tokenizer(text, **tokenizer_kwargs)
     else:
         res = tokenizer(text)
-    try:
-        from transformers.tokenization_utils_base import PreTrainedTokenizerBase

-        tokens = (
-            res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
-        )
-    except ImportError:
-        tokens = res
+    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res

     if not encoder:
         return tokens

-    try:
-        import torch
-    except ImportError:
-        "Missing dependency 'torch' needed to encode text."
-
     return encoder(torch.tensor(tokens))
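Because convert_text now imports torch and transformers at module level, both become hard dependencies of this module instead of optional ones. A hedged usage sketch of the function as declared above; the checkpoint name and the toy encoder are illustrative only:

import torch
from transformers import AutoTokenizer

from datachain.lib.text import convert_text

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
tokens = convert_text("hello world", tokenizer=tokenizer)  # returns the tokenizer's input_ids
emb = convert_text(
    "hello world",
    tokenizer=tokenizer,
    encoder=torch.nn.Embedding(tokenizer.vocab_size, 16),  # toy encoder standing in for a real text model
)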
datachain/lib/udf.py CHANGED
@@ -1,16 +1,29 @@
-import inspect
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING, Callable, Optional

-from datachain.lib.feature import Feature
+from fsspec.callbacks import DEFAULT_CALLBACK, Callback
+from pydantic import BaseModel
+
+from datachain.dataset import RowDict
+from datachain.lib.convert.flatten import flatten
+from datachain.lib.convert.unflatten import unflatten_to_json
+from datachain.lib.data_model import FileBasic
+from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query import udf
+from datachain.query.batch import RowBatch
+from datachain.query.schema import ColumnParameter
+from datachain.query.udf import UDFBase as _UDFBase
+from datachain.query.udf import UDFProperties, UDFResult

 if TYPE_CHECKING:
-    from datachain.query.udf import UDFWrapper
+    from typing_extensions import Self
+
+    from datachain.catalog import Catalog
+    from datachain.query.batch import BatchingResult


 class UdfError(DataChainParamsError):
@@ -18,10 +31,67 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")


+class UDFAdapter(_UDFBase):
+    def __init__(
+        self,
+        inner: "UDFBase",
+        properties: UDFProperties,
+    ):
+        self.inner = inner
+        super().__init__(properties)
+
+    def run(
+        self,
+        udf_inputs: "Iterable[BatchingResult]",
+        catalog: "Catalog",
+        is_generator: bool,
+        cache: bool,
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterator[Iterable["UDFResult"]]:
+        self.inner._catalog = catalog
+        if hasattr(self.inner, "setup") and callable(self.inner.setup):
+            self.inner.setup()
+
+        for batch in udf_inputs:
+            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
+            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output
+
+        if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
+            self.inner.teardown()
+
+    def run_once(
+        self,
+        catalog: "Catalog",
+        arg: "BatchingResult",
+        is_generator: bool = False,
+        cache: bool = False,
+        cb: Callback = DEFAULT_CALLBACK,
+    ) -> Iterable[UDFResult]:
+        if isinstance(arg, RowBatch):
+            udf_inputs = [
+                self.bind_parameters(catalog, row, cache=cache, cb=cb)
+                for row in arg.rows
+            ]
+            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+            return self._process_results(arg.rows, udf_outputs, is_generator)
+        if isinstance(arg, RowDict):
+            udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
+            udf_outputs = self.inner(*udf_inputs, cache=cache, download_cb=cb)
+            if not is_generator:
+                # udf_outputs is generator already if is_generator=True
+                udf_outputs = [udf_outputs]
+            return self._process_results([arg], udf_outputs, is_generator)
+        raise ValueError(f"Unexpected UDF argument: {arg}")
+
+
 class UDFBase(AbstractUDF):
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False
+    params_spec: Optional[list[str]]

     def __init__(self):
         self.params = None
@@ -48,7 +118,12 @@ class UDFBase(AbstractUDF):
         This is needed for tasks like closing connections to end-points.
         """

-    def _init(self, sign: UdfSignature, params: SignalSchema, func: Callable):
+    def _init(
+        self,
+        sign: UdfSignature,
+        params: SignalSchema,
+        func: Callable,
+    ):
         self.params = params
         self.output = sign.output_schema

@@ -61,20 +136,19 @@
     @classmethod
     def _create(
         cls,
-        target_class: type["UDFBase"],
         sign: UdfSignature,
         params: SignalSchema,
-    ) -> "UDFBase":
+    ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
-            if not isinstance(sign.func, target_class):  # type: ignore[unreachable]
+            if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
                     f"cannot create UDF: provided UDF '{sign.func.__name__}'"
-                    f" must be a child of target class '{target_class.__name__}'",
+                    f" must be a child of target class '{cls.__name__}'",
                 )
            result = sign.func
            func = None
         else:
-            result = target_class()
+            result = cls()
             func = sign.func

         result._init(sign, params, func)
@@ -91,18 +165,21 @@
     def catalog(self):
         return self._catalog

-    def to_udf_wrapper(self, batch=1) -> "UDFWrapper":
-        udf_wrapper = udf(self.params_spec, self.output_spec, batch=batch)
-        return udf_wrapper(self)
+    def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
+        assert self.params_spec is not None
+        properties = UDFProperties(
+            [ColumnParameter(p) for p in self.params_spec], self.output_spec, batch
+        )
+        return UDFAdapter(self, properties)

     def validate_results(self, results, *args, **kwargs):
         return results

-    def __call__(self, *rows):
+    def __call__(self, *rows, cache, download_cb):
         if self.is_input_grouped:
-            objs = self._parse_grouped_rows(rows)
+            objs = self._parse_grouped_rows(rows[0], cache, download_cb)
         else:
-            objs = self._parse_rows(rows)
+            objs = self._parse_rows(rows, cache, download_cb)

         if not self.is_input_batched:
             objs = objs[0]
@@ -117,15 +194,19 @@
             for tuple_ in result_objs:
                 flat = []
                 for obj in tuple_:
-                    if isinstance(obj, Feature):
-                        flat.extend(Feature._flatten(obj))
+                    if isinstance(obj, BaseModel):
+                        flat.extend(flatten(obj))
                     else:
                         flat.append(obj)
                 res.append(flat)
         else:
             # Generator expression is required, otherwise the value will be materialized
             res = (
-                obj._flatten() if isinstance(obj, Feature) else (obj,)
+                flatten(obj)
+                if isinstance(obj, BaseModel)
+                else obj
+                if isinstance(obj, tuple)
+                else (obj,)
                 for obj in result_objs
             )

@@ -139,24 +220,25 @@

         return res

-    def _parse_rows(self, rows):
+    def _parse_rows(self, rows, cache, download_cb):
         if not self.is_input_batched:
             rows = [rows]
         objs = []
         for row in rows:
             obj_row = self.params.row_to_objs(row)
             for obj in obj_row:
-                if isinstance(obj, Feature):
-                    obj._set_stream(self._catalog, caching_enabled=True)
+                if isinstance(obj, FileBasic):
+                    obj._set_stream(
+                        self._catalog, caching_enabled=cache, download_cb=download_cb
+                    )
             objs.append(obj_row)
         return objs

-    def _parse_grouped_rows(self, rows):
-        group = rows[0]
+    def _parse_grouped_rows(self, group, cache, download_cb):
         spec_map = {}
         output_map = {}
         for name, (anno, subtree) in self.params.tree.items():
-            if inspect.isclass(anno) and issubclass(anno, Feature):
+            if ModelStore.is_pydantic(anno):
                 length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
             else:
                 length = 1
@@ -169,13 +251,15 @@
                 slice = flat_obj[position : position + length]
                 position += length

-                if Feature.is_feature(cls):
-                    obj = cls(**cls._unflatten_to_json(slice))
+                if ModelStore.is_pydantic(cls):
+                    obj = cls(**unflatten_to_json(cls, slice))
                 else:
                     obj = slice[0]

-                if isinstance(obj, Feature):
-                    obj._set_stream(self._catalog)
+                if isinstance(obj, FileBasic):
+                    obj._set_stream(
+                        self._catalog, caching_enabled=cache, download_cb=download_cb
+                    )
                 output_map[signal].append(obj)

         return list(output_map.values())
datachain/lib/udf_signature.py CHANGED
@@ -3,7 +3,7 @@ from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
 from typing import Callable, Optional, Union, get_args, get_origin

-from datachain.lib.feature import Feature, FeatureType, FeatureTypeNames
+from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.utils import AbstractUDF, DataChainParamsError

@@ -29,7 +29,7 @@ class UdfSignature:
         signal_map: dict[str, Callable],
         func: Optional[Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
-        output: Union[None, FeatureType, Sequence[str], dict[str, FeatureType]] = None,
+        output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
@@ -127,15 +127,15 @@ class UdfSignature:
                         f"output signal '{key}' has type '{type(key)}'"
                         " while 'str' is expected",
                     )
-                if not Feature.is_feature_type(value):
+                if not is_chain_type(value):
                     raise UdfSignatureError(
                         chain,
                         f"output type '{value.__name__}' of signal '{key}' is not"
-                        f" supported. Please use Feature types: {FeatureTypeNames}",
+                        f" supported. Please use Feature types: {DataTypeNames}",
                     )

             udf_output_map = output
-        elif Feature.is_feature_type(output):
+        elif is_chain_type(output):
             udf_output_map = {signal_name: output}
         else:
             raise UdfSignatureError(
datachain/lib/webdataset.py CHANGED
@@ -15,7 +15,7 @@ from typing import (

 from pydantic import Field

-from datachain.lib.feature import Feature
+from datachain.lib.data_model import DataModel
 from datachain.lib.file import File, TarVFile
 from datachain.lib.utils import DataChainError

@@ -46,7 +46,7 @@ class UnknownFileExtensionError(WDSError):
         super().__init__(tar_stream, f"unknown extension '{ext}' for file '{name}'")


-class WDSBasic(Feature):
+class WDSBasic(DataModel):
     file: File


@@ -75,7 +75,7 @@ class WDSAllFile(WDSBasic):
     cbor: Optional[bytes] = Field(default=None)


-class WDSReadableSubclass(Feature):
+class WDSReadableSubclass(DataModel):
     @staticmethod
     def _reader(builder, item: tarfile.TarInfo) -> "WDSReadableSubclass":
         raise NotImplementedError
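The WDSBasic and WDSReadableSubclass changes above illustrate the rename that runs through the whole release: user-facing schemas now derive from DataModel (a pydantic-based class from datachain.lib.data_model) instead of the removed Feature. A minimal sketch of a user-defined schema under that assumption; the Annotation and Sample names are hypothetical:

from typing import Optional

from datachain.lib.data_model import DataModel
from datachain.lib.file import File

class Annotation(DataModel):  # hypothetical user model
    label: str
    confidence: Optional[float] = None

class Sample(DataModel):  # hypothetical: nests a File signal plus a custom sub-model
    file: File
    annotation: Annotation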