datachain 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (49)
  1. datachain/__init__.py +3 -4
  2. datachain/cache.py +10 -4
  3. datachain/catalog/catalog.py +35 -15
  4. datachain/cli.py +37 -32
  5. datachain/data_storage/metastore.py +24 -0
  6. datachain/data_storage/warehouse.py +3 -1
  7. datachain/job.py +56 -0
  8. datachain/lib/arrow.py +19 -7
  9. datachain/lib/clip.py +89 -66
  10. datachain/lib/convert/{type_converter.py → python_to_sql.py} +6 -6
  11. datachain/lib/convert/sql_to_python.py +23 -0
  12. datachain/lib/convert/values_to_tuples.py +51 -33
  13. datachain/lib/data_model.py +6 -27
  14. datachain/lib/dataset_info.py +70 -0
  15. datachain/lib/dc.py +646 -152
  16. datachain/lib/file.py +117 -15
  17. datachain/lib/image.py +1 -1
  18. datachain/lib/meta_formats.py +14 -2
  19. datachain/lib/model_store.py +3 -2
  20. datachain/lib/pytorch.py +10 -7
  21. datachain/lib/signal_schema.py +39 -14
  22. datachain/lib/text.py +2 -1
  23. datachain/lib/udf.py +56 -5
  24. datachain/lib/udf_signature.py +1 -1
  25. datachain/lib/webdataset.py +4 -3
  26. datachain/node.py +11 -8
  27. datachain/query/dataset.py +66 -147
  28. datachain/query/dispatch.py +15 -13
  29. datachain/query/schema.py +2 -0
  30. datachain/query/session.py +4 -4
  31. datachain/sql/functions/array.py +12 -0
  32. datachain/sql/functions/string.py +8 -0
  33. datachain/torch/__init__.py +1 -1
  34. datachain/utils.py +45 -0
  35. datachain-0.2.12.dist-info/METADATA +412 -0
  36. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/RECORD +40 -45
  37. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/WHEEL +1 -1
  38. datachain/lib/feature_registry.py +0 -77
  39. datachain/lib/gpt4_vision.py +0 -97
  40. datachain/lib/hf_image_to_text.py +0 -97
  41. datachain/lib/hf_pipeline.py +0 -90
  42. datachain/lib/image_transform.py +0 -103
  43. datachain/lib/iptc_exif_xmp.py +0 -76
  44. datachain/lib/unstructured.py +0 -41
  45. datachain/text/__init__.py +0 -3
  46. datachain-0.2.10.dist-info/METADATA +0 -430
  47. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/LICENSE +0 -0
  48. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/entry_points.txt +0 -0
  49. {datachain-0.2.10.dist-info → datachain-0.2.12.dist-info}/top_level.txt +0 -0
datachain/lib/file.py CHANGED
@@ -1,5 +1,7 @@
  import io
  import json
+ import os
+ import posixpath
  from abc import ABC, abstractmethod
  from contextlib import contextmanager
  from datetime import datetime
@@ -16,7 +18,7 @@ from pydantic import Field, field_validator

  from datachain.cache import UniqueId
  from datachain.client.fileslice import FileSlice
- from datachain.lib.data_model import DataModel, FileBasic
+ from datachain.lib.data_model import DataModel
  from datachain.lib.utils import DataChainError
  from datachain.sql.types import JSON, Int, String
  from datachain.utils import TIME_ZERO
@@ -24,6 +26,9 @@ from datachain.utils import TIME_ZERO
  if TYPE_CHECKING:
      from datachain.catalog import Catalog

+ # how to create file path when exporting
+ ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]
+

  class VFileError(DataChainError):
      def __init__(self, file: "File", message: str, vtype: str = ""):
@@ -49,12 +54,15 @@ class VFile(ABC):


  class TarVFile(VFile):
+     """Virtual file model for files extracted from tar archives."""
+
      @classmethod
      def get_vtype(cls) -> str:
          return "tar"

      @classmethod
      def open(cls, file: "File", location: list[dict]):
+         """Stream file from tar archive based on location in archive."""
          if len(location) > 1:
              VFileError(file, "multiple 'location's are not supported yet")

@@ -100,7 +108,9 @@ class VFileRegistry:
          return reader.open(file, location)


- class File(FileBasic):
+ class File(DataModel):
+     """`DataModel` for reading binary files."""
+
      source: str = Field(default="")
      parent: str = Field(default="")
      name: str
@@ -127,14 +137,17 @@ class File(FileBasic):
          "source",
          "parent",
          "name",
-         "etag",
          "size",
+         "etag",
+         "version",
+         "is_latest",
          "vtype",
          "location",
+         "last_modified",
      ]

      @staticmethod
-     def to_dict(
+     def _validate_dict(
          v: Optional[Union[str, dict, list[dict]]],
      ) -> Optional[Union[str, dict, list[dict]]]:
          if v is None or v == "":
@@ -152,7 +165,7 @@ class File(FileBasic):
      @field_validator("location", mode="before")
      @classmethod
      def validate_location(cls, v):
-         return File.to_dict(v)
+         return File._validate_dict(v)

      @field_validator("parent", mode="before")
      @classmethod
@@ -172,9 +185,10 @@ class File(FileBasic):
          self._caching_enabled = False

      @contextmanager
-     def open(self):
+     def open(self, mode: Literal["rb", "r"] = "rb"):
+         """Open the file and return a file object."""
          if self.location:
-             with VFileRegistry.resolve(self, self.location) as f:
+             with VFileRegistry.resolve(self, self.location) as f:  # type: ignore[arg-type]
                  yield f

          uid = self.get_uid()
@@ -184,7 +198,41 @@ class File(FileBasic):
          with client.open_object(
              uid, use_cache=self._caching_enabled, cb=self._download_cb
          ) as f:
-             yield f
+             yield io.TextIOWrapper(f) if mode == "r" else f
+
+     def read(self, length: int = -1):
+         """Returns file contents."""
+         with self.open() as stream:
+             return stream.read(length)
+
+     def read_bytes(self):
+         """Returns file contents as bytes."""
+         return self.read()
+
+     def read_text(self):
+         """Returns file contents as text."""
+         with self.open(mode="r") as stream:
+             return stream.read()
+
+     def write(self, destination: str):
+         """Writes it's content to destination"""
+         with open(destination, mode="wb") as f:
+             f.write(self.read())
+
+     def export(
+         self,
+         output: str,
+         placement: ExportPlacement = "fullpath",
+         use_cache: bool = True,
+     ) -> None:
+         """Export file to new location."""
+         if use_cache:
+             self._caching_enabled = use_cache
+         dst = self.get_destination_path(output, placement)
+         dst_dir = os.path.dirname(dst)
+         os.makedirs(dst_dir, exist_ok=True)
+
+         self.write(dst)

      def _set_stream(
          self,
@@ -197,11 +245,12 @@ class File(FileBasic):
          self._download_cb = download_cb

      def get_uid(self) -> UniqueId:
+         """Returns unique ID for file."""
          dump = self.model_dump()
          return UniqueId(*(dump[k] for k in self._unique_id_keys))

      def get_local_path(self) -> Optional[str]:
-         """Get path to a file in a local cache.
+         """Returns path to a file in a local cache.
          Return None if file is not cached. Throws an exception if cache is not setup."""
          if self._catalog is None:
              raise RuntimeError(
@@ -210,21 +259,27 @@ class File(FileBasic):
          return self._catalog.cache.get_path(self.get_uid())

      def get_file_suffix(self):
+         """Returns last part of file name with `.`."""
          return Path(self.name).suffix

      def get_file_ext(self):
+         """Returns last part of file name without `.`."""
          return Path(self.name).suffix.strip(".")

      def get_file_stem(self):
+         """Returns file name without extension."""
          return Path(self.name).stem

      def get_full_name(self):
+         """Returns name with parent directories."""
          return (Path(self.parent) / self.name).as_posix()

      def get_uri(self):
+         """Returns file URI."""
          return f"{self.source}/{self.get_full_name()}"

      def get_path(self) -> str:
+         """Returns file path."""
          path = unquote(self.get_uri())
          fs = self.get_fs()
          if isinstance(fs, LocalFileSystem):
@@ -233,21 +288,65 @@ class File(FileBasic):
          path = url2pathname(path)
          return path

+     def get_destination_path(self, output: str, placement: ExportPlacement) -> str:
+         """
+         Returns full destination path of a file for exporting to some output
+         based on export placement
+         """
+         if placement == "filename":
+             path = unquote(self.name)
+         elif placement == "etag":
+             path = f"{self.etag}{self.get_file_suffix()}"
+         elif placement == "fullpath":
+             fs = self.get_fs()
+             if isinstance(fs, LocalFileSystem):
+                 path = unquote(self.get_full_name())
+             else:
+                 path = (
+                     Path(urlparse(self.source).netloc) / unquote(self.get_full_name())
+                 ).as_posix()
+         elif placement == "checksum":
+             raise NotImplementedError("Checksum placement not implemented yet")
+         else:
+             raise ValueError(f"Unsupported file export placement: {placement}")
+         return posixpath.join(output, path)  # type: ignore[union-attr]
+
      def get_fs(self):
+         """Returns `fsspec` filesystem for the file."""
          return self._catalog.get_client(self.source).fs


  class TextFile(File):
+     """`DataModel` for reading text files."""
+
      @contextmanager
      def open(self):
-         with super().open() as binary:
-             yield io.TextIOWrapper(binary)
+         """Open the file and return a file object in text mode."""
+         with super().open(mode="r") as stream:
+             yield stream
+
+     def read_text(self):
+         """Returns file contents as text."""
+         with self.open() as stream:
+             return stream.read()
+
+     def write(self, destination: str):
+         """Writes it's content to destination"""
+         with open(destination, mode="w") as f:
+             f.write(self.read_text())


  class ImageFile(File):
-     def get_value(self):
-         value = super().get_value()
-         return Image.open(BytesIO(value))
+     """`DataModel` for reading image files."""
+
+     def read(self):
+         """Returns `PIL.Image.Image` object."""
+         fobj = super().read()
+         return Image.open(BytesIO(fobj))
+
+     def write(self, destination: str):
+         """Writes it's content to destination"""
+         self.read().save(destination)


  def get_file(type_: Literal["binary", "text", "image"] = "binary"):
@@ -282,7 +381,10 @@ def get_file(type_: Literal["binary", "text", "image"] = "binary"):


  class IndexedFile(DataModel):
-     """File source info for tables."""
+     """Metadata indexed from tabular files.
+
+     Includes `file` and `index` signals.
+     """

      file: File
      index: int
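Note: the `File` model above replaces the old `get_value()`-style access with explicit `read()`, `read_bytes()`, `read_text()`, `write()`, and `export()` helpers. A minimal illustrative sketch of the export path follows; the bucket URI is borrowed from the UDF example later in this diff, while the output directory and the assumption that the file object is the row's only signal are hypothetical, not taken from this release.

```py
from datachain.lib.dc import DataChain

# Sketch only: copy the first few files out of a bucket with File.export().
chain = DataChain.from_storage("gs://datachain-demo/fashion-product-images/images")

for row in chain.limit(5).collect():
    file = row[0]  # assumes the File object is the only signal in each row
    # placement="filename" keeps just the file name; "fullpath" (the default)
    # keeps the bucket netloc plus the full path; "etag" names it after the etag.
    file.export("./exported", placement="filename")
```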
datachain/lib/image.py CHANGED
@@ -53,7 +53,7 @@ def convert_images(
      Resize, transform, and otherwise convert one or more images.

      Args:
-         img (Image, list[Image]): PIL.Image object or list of objects.
+         images (Image, list[Image]): PIL.Image object or list of objects.
          mode (str): PIL.Image mode.
          size (tuple[int, int]): Size in (width, height) pixels for resizing.
          transform (Callable): Torchvision transform or huggingface processor to apply.
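For reference, a hedged sketch of calling `convert_images()` with the renamed `images` argument; the in-memory images and target size below are made up for illustration.

```py
from PIL import Image

from datachain.lib.image import convert_images

# Illustrative only: resize two blank in-memory images to a common size.
imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (800, 600))]
resized = convert_images(imgs, size=(224, 224))
```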
datachain/lib/meta_formats.py CHANGED
@@ -13,6 +13,7 @@ from typing import Any, Callable
  import jmespath as jsp
  from pydantic import ValidationError

+ from datachain.lib.data_model import ModelStore  # noqa: F401
  from datachain.lib.file import File


@@ -86,6 +87,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
      except subprocess.CalledProcessError as e:
          model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
      print(f"{model_output}")
+     print("\n" + f"ModelStore.register({model_name})" + "\n")
+     print("\n" + f"spec={model_name}" + "\n")
      return model_output


@@ -99,6 +102,7 @@ def read_meta( # noqa: C901
      jmespath=None,
      show_schema=False,
      model_name=None,
+     nrows=None,
  ) -> Callable:
      from datachain.lib.dc import DataChain

@@ -118,8 +122,7 @@ def read_meta( # noqa: C901
                      output=str,
                  )
              )
-             # dummy executor (#1616)
-             chain.save()
+             chain.exec()
          finally:
              sys.stdout = current_stdout
          model_output = captured_output.getvalue()
@@ -147,6 +150,7 @@ def read_meta( # noqa: C901
          DataModel=spec,  # noqa: N803
          meta_type=meta_type,
          jmespath=jmespath,
+         nrows=nrows,
      ) -> Iterator[spec]:
          def validator(json_object: dict) -> spec:
              json_string = json.dumps(json_object)
@@ -175,14 +179,22 @@ def read_meta( # noqa: C901
                  yield from validator(json_object)

              else:
+                 nrow = 0
                  for json_dict in json_object:
+                     nrow = nrow + 1
+                     if nrows is not None and nrow > nrows:
+                         return
                      yield from validator(json_dict)

          if meta_type == "jsonl":
              try:
+                 nrow = 0
                  with file.open() as fd:
                      data_string = fd.readline().replace("\r", "")
                      while data_string:
+                         nrow = nrow + 1
+                         if nrows is not None and nrow > nrows:
+                             return
                          json_object = process_json(data_string, jmespath)
                          data_string = fd.readline()
                          yield from validator(json_object)
datachain/lib/model_store.py CHANGED
@@ -22,7 +22,8 @@ class ModelStore:
          return model.__name__

      @classmethod
-     def add(cls, fr: type):
+     def register(cls, fr: type):
+         """Register a class as a data model for deserialization."""
          if (model := ModelStore.to_pydantic(fr)) is None:
              return

@@ -34,7 +35,7 @@

          for f_info in model.model_fields.values():
              if (anno := ModelStore.to_pydantic(f_info.annotation)) is not None:
-                 cls.add(anno)
+                 cls.register(anno)

      @classmethod
      def get(cls, name: str, version: Optional[int] = None) -> Optional[type]:
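The rename from `ModelStore.add()` to `ModelStore.register()` matches the hint now printed by the schema reader and by the deserialization error in `signal_schema.py`. A minimal sketch of registering a custom model follows; the `BBox` class is invented for illustration.

```py
from pydantic import BaseModel

from datachain.lib.model_store import ModelStore


class BBox(BaseModel):
    x: float
    y: float
    width: float
    height: float


# Register the model so a saved signal schema that references "BBox"
# can be deserialized back into this class.
ModelStore.register(BBox)
```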
datachain/lib/pytorch.py CHANGED
@@ -3,7 +3,6 @@ from collections.abc import Iterator
  from typing import TYPE_CHECKING, Any, Callable, Optional

  from PIL import Image
- from pydantic import BaseModel
  from torch import float32
  from torch.distributed import get_rank, get_world_size
  from torch.utils.data import IterableDataset, get_worker_info
@@ -11,6 +10,7 @@ from torchvision.transforms import v2

  from datachain.catalog import Catalog, get_catalog
  from datachain.lib.dc import DataChain
+ from datachain.lib.file import File
  from datachain.lib.text import convert_text

  if TYPE_CHECKING:
@@ -24,6 +24,7 @@ DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])


  def label_to_int(value: str, classes: list) -> int:
+     """Given a value and list of classes, return the index of the value's class."""
      return classes.index(value)


@@ -33,7 +34,7 @@ class PytorchDataset(IterableDataset):
          name: str,
          version: Optional[int] = None,
          catalog: Optional["Catalog"] = None,
-         transform: Optional["Transform"] = DEFAULT_TRANSFORM,
+         transform: Optional["Transform"] = None,
          tokenizer: Optional[Callable] = None,
          tokenizer_kwargs: Optional[dict[str, Any]] = None,
          num_samples: int = 0,
@@ -41,6 +42,9 @@
          """
          Pytorch IterableDataset that streams DataChain datasets.

+         See Also:
+             `DataChain.to_pytorch()` - convert chain to PyTorch Dataset.
+
          Args:
              name (str): Name of DataChain dataset to stream.
              version (int): Version of DataChain dataset to stream.
@@ -53,7 +57,7 @@ class PytorchDataset(IterableDataset):
          """
          self.name = name
          self.version = version
-         self.transform = transform
+         self.transform = transform or DEFAULT_TRANSFORM
          self.tokenizer = tokenizer
          self.tokenizer_kwargs = tokenizer_kwargs or {}
          self.num_samples = num_samples
@@ -90,12 +94,11 @@ class PytorchDataset(IterableDataset):
          if self.num_samples > 0:
              ds = ds.sample(self.num_samples)
          ds = ds.chunk(total_rank, total_workers)
-         stream = ds.iterate()
-         for row_features in stream:
+         for row_features in ds.collect():
              row = []
              for fr in row_features:
-                 if isinstance(fr, BaseModel):
-                     row.append(fr.get_value())  # type: ignore[unreachable]
+                 if isinstance(fr, File):
+                     row.append(fr.read())  # type: ignore[unreachable]
                  else:
                      row.append(fr)
              # Apply transforms
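With `transform` now defaulting to `None` (and falling back to `DEFAULT_TRANSFORM` inside `__init__`), a `PytorchDataset` can be wrapped in a standard `DataLoader`. A sketch under assumptions: the dataset name, batch size, and worker count below are placeholders, and the docstring above points at `DataChain.to_pytorch()` as the higher-level entry point.

```py
from torch.utils.data import DataLoader

from datachain.lib.pytorch import PytorchDataset

# Placeholder dataset name; PytorchDataset streams saved dataset rows and
# calls File.read() on any File signals before applying the transform.
ds = PytorchDataset("fashion-images", version=1)
loader = DataLoader(ds, batch_size=16, num_workers=2)

for batch in loader:
    ...  # training / inference step
```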
datachain/lib/signal_schema.py CHANGED
@@ -18,7 +18,8 @@ from pydantic import BaseModel, create_model
  from typing_extensions import Literal as LiteralEx

  from datachain.lib.convert.flatten import DATACHAIN_TO_TYPE
- from datachain.lib.convert.type_converter import convert_to_db_type
+ from datachain.lib.convert.python_to_sql import python_to_sql
+ from datachain.lib.convert.sql_to_python import sql_to_python
  from datachain.lib.convert.unflatten import unflatten_to_json_pos
  from datachain.lib.data_model import DataModel, DataType
  from datachain.lib.file import File
@@ -102,21 +103,20 @@ class SignalSchema:
      @staticmethod
      def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
          signals: dict[str, DataType] = {}
-         for field, type_ in col_types.items():
-             type_ = DATACHAIN_TO_TYPE.get(type_, None)
-             if type_ is None:
+         for field, col_type in col_types.items():
+             if (py_type := DATACHAIN_TO_TYPE.get(col_type, None)) is None:
                  raise SignalSchemaError(
                      f"signal schema cannot be obtained for column '{field}':"
-                     f" unsupported type '{type_}'"
+                     f" unsupported type '{py_type}'"
                  )
-             signals[field] = type_
+             signals[field] = py_type
          return SignalSchema(signals)

      def serialize(self) -> dict[str, str]:
          signals = {}
          for name, fr_type in self.values.items():
              if (fr := ModelStore.to_pydantic(fr_type)) is not None:
-                 ModelStore.add(fr)
+                 ModelStore.register(fr)
                  signals[name] = ModelStore.get_name(fr)
              else:
                  orig = get_origin(fr_type)
@@ -143,8 +143,8 @@
                  if not fr:
                      raise SignalSchemaError(
                          f"cannot deserialize '{signal}': "
-                         f"unregistered type '{type_name}'."
-                         f" Try to register it with `Registry.add({type_name})`."
+                         f"unknown type '{type_name}'."
+                         f" Try to add it with `ModelStore.register({type_name})`."
                      )
              except TypeError as err:
                  raise SignalSchemaError(
@@ -161,7 +161,7 @@
                  continue
              if not has_subtree:
                  db_name = DEFAULT_DELIMITER.join(path)
-                 res[db_name] = convert_to_db_type(type_)
+                 res[db_name] = python_to_sql(type_)
          return res

      def row_to_objs(self, row: Sequence[Any]) -> list[DataType]:
@@ -192,10 +192,17 @@
      def slice(
          self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
      ) -> "SignalSchema":
+         # Make new schema that combines current schema and setup signals
          setup = setup or {}
          setup_no_types = dict.fromkeys(setup.keys(), str)
-         union = self.values | setup_no_types
-         schema = {k: union[k] for k in keys if k in union}
+         union = SignalSchema(self.values | setup_no_types)
+         # Slice combined schema by keys
+         schema = {}
+         for k in keys:
+             try:
+                 schema[k] = union._find_in_tree(k.split("."))
+             except SignalResolvingError:
+                 pass
          return SignalSchema(schema, setup)

      def row_to_features(
@@ -271,6 +278,14 @@
              del schema[signal]
          return SignalSchema(schema)

+     def mutate(self, args_map: dict) -> "SignalSchema":
+         return SignalSchema(self.values | sql_to_python(args_map))
+
+     def clone_without_sys_signals(self) -> "SignalSchema":
+         schema = copy.deepcopy(self.values)
+         schema.pop("sys", None)
+         return SignalSchema(schema)
+
      def merge(
          self,
          right_schema: "SignalSchema",
@@ -283,9 +298,9 @@

          return SignalSchema(self.values | schema_right)

-     def get_file_signals(self) -> Iterator[str]:
+     def get_signals(self, target_type: type[DataModel]) -> Iterator[str]:
          for path, type_, has_subtree, _ in self.get_flat_tree():
-             if has_subtree and issubclass(type_, File):
+             if has_subtree and issubclass(type_, target_type):
                  yield ".".join(path)

      def create_model(self, name: str) -> type[DataModel]:
@@ -331,6 +346,16 @@
                      sub_schema = SignalSchema({"* list of": args[0]})
                      sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

+     def get_headers_with_length(self):
+         paths = [
+             path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree
+         ]
+         max_length = max([len(path) for path in paths], default=0)
+         return [
+             path + [""] * (max_length - len(path)) if len(path) < max_length else path
+             for path in paths
+         ], max_length
+
      def __or__(self, other):
          return self.__class__(self.values | other.values)

datachain/lib/text.py CHANGED
@@ -31,8 +31,9 @@ def convert_text(
      res = tokenizer(text)

      tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
+     tokens = torch.tensor(tokens)

      if not encoder:
          return tokens

-     return encoder(torch.tensor(tokens))
+     return encoder(tokens)
datachain/lib/udf.py CHANGED
@@ -9,7 +9,7 @@ from pydantic import BaseModel
  from datachain.dataset import RowDict
  from datachain.lib.convert.flatten import flatten
  from datachain.lib.convert.unflatten import unflatten_to_json
- from datachain.lib.data_model import FileBasic
+ from datachain.lib.file import File
  from datachain.lib.model_store import ModelStore
  from datachain.lib.signal_schema import SignalSchema
  from datachain.lib.udf_signature import UdfSignature
@@ -88,6 +88,53 @@ class UDFAdapter(_UDFBase):


  class UDFBase(AbstractUDF):
+     """Base class for stateful user-defined functions.
+
+     Any class that inherits from it must have a `process()` method that takes input
+     params from one or more rows in the chain and produces the expected output.
+
+     Optionally, the class may include these methods:
+     - `setup()` to run code on each worker before `process()` is called.
+     - `teardown()` to run code on each worker after `process()` completes.
+
+     Example:
+         ```py
+         from datachain import C, DataChain, Mapper
+         import open_clip
+
+         class ImageEncoder(Mapper):
+             def __init__(self, model_name: str, pretrained: str):
+                 self.model_name = model_name
+                 self.pretrained = pretrained
+
+             def setup(self):
+                 self.model, _, self.preprocess = (
+                     open_clip.create_model_and_transforms(
+                         self.model_name, self.pretrained
+                     )
+                 )
+
+             def process(self, file) -> list[float]:
+                 img = file.get_value()
+                 img = self.preprocess(img).unsqueeze(0)
+                 emb = self.model.encode_image(img)
+                 return emb[0].tolist()
+
+         (
+             DataChain.from_storage(
+                 "gs://datachain-demo/fashion-product-images/images", type="image"
+             )
+             .limit(5)
+             .map(
+                 ImageEncoder("ViT-B-32", "laion2b_s34b_b79k"),
+                 params=["file"],
+                 output={"emb": list[float]},
+             )
+             .show()
+         )
+         ```
+     """
+
      is_input_batched = False
      is_output_batched = False
      is_input_grouped = False
@@ -198,7 +245,7 @@
                          flat.extend(flatten(obj))
                      else:
                          flat.append(obj)
-                 res.append(flat)
+                 res.append(tuple(flat))
          else:
              # Generator expression is required, otherwise the value will be materialized
              res = (
@@ -227,7 +274,7 @@
          for row in rows:
              obj_row = self.params.row_to_objs(row)
              for obj in obj_row:
-                 if isinstance(obj, FileBasic):
+                 if isinstance(obj, File):
                      obj._set_stream(
                          self._catalog, caching_enabled=cache, download_cb=download_cb
                      )
@@ -256,7 +303,7 @@
              else:
                  obj = slice[0]

-             if isinstance(obj, FileBasic):
+             if isinstance(obj, File):
                  obj._set_stream(
                      self._catalog, caching_enabled=cache, download_cb=download_cb
                  )
@@ -280,7 +327,7 @@


  class Mapper(UDFBase):
-     pass
+     """Inherit from this class to pass to `DataChain.map()`."""


  class BatchMapper(Mapper):
@@ -289,10 +336,14 @@ class BatchMapper(Mapper):


  class Generator(UDFBase):
+     """Inherit from this class to pass to `DataChain.gen()`."""
+
      is_output_batched = True


  class Aggregator(UDFBase):
+     """Inherit from this class to pass to `DataChain.agg()`."""
+
      is_input_batched = True
      is_output_batched = True
      is_input_grouped = True
datachain/lib/udf_signature.py CHANGED
@@ -131,7 +131,7 @@ class UdfSignature:
                      raise UdfSignatureError(
                          chain,
                          f"output type '{value.__name__}' of signal '{key}' is not"
-                         f" supported. Please use Feature types: {DataTypeNames}",
+                         f" supported. Please use DataModel types: {DataTypeNames}",
                      )

              udf_output_map = output