datachain 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic.

Files changed (39)
  1. datachain/__init__.py +0 -4
  2. datachain/catalog/catalog.py +17 -2
  3. datachain/cli.py +8 -1
  4. datachain/data_storage/db_engine.py +0 -2
  5. datachain/data_storage/schema.py +15 -26
  6. datachain/data_storage/sqlite.py +3 -0
  7. datachain/data_storage/warehouse.py +1 -7
  8. datachain/lib/arrow.py +7 -13
  9. datachain/lib/cached_stream.py +3 -85
  10. datachain/lib/clip.py +151 -0
  11. datachain/lib/dc.py +41 -59
  12. datachain/lib/feature.py +5 -1
  13. datachain/lib/feature_registry.py +3 -2
  14. datachain/lib/feature_utils.py +1 -2
  15. datachain/lib/file.py +17 -24
  16. datachain/lib/image.py +37 -79
  17. datachain/lib/pytorch.py +4 -2
  18. datachain/lib/signal_schema.py +3 -4
  19. datachain/lib/text.py +18 -49
  20. datachain/lib/udf.py +64 -55
  21. datachain/lib/udf_signature.py +11 -10
  22. datachain/lib/utils.py +17 -0
  23. datachain/lib/webdataset.py +2 -2
  24. datachain/listing.py +0 -3
  25. datachain/query/dataset.py +66 -46
  26. datachain/query/dispatch.py +2 -2
  27. datachain/query/schema.py +1 -8
  28. datachain/query/udf.py +16 -18
  29. datachain/sql/sqlite/base.py +34 -2
  30. datachain/sql/sqlite/vector.py +13 -5
  31. datachain/utils.py +28 -0
  32. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/METADATA +3 -2
  33. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/RECORD +37 -38
  34. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
  35. datachain/_version.py +0 -16
  36. datachain/lib/reader.py +0 -49
  37. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
  38. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
  39. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py CHANGED
@@ -14,7 +14,7 @@ import sqlalchemy
 
 from datachain.lib.feature import Feature, FeatureType
 from datachain.lib.feature_utils import features_to_tuples
-from datachain.lib.file import File, get_file
+from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
@@ -39,6 +39,8 @@ if TYPE_CHECKING:
     import pandas as pd
     from typing_extensions import Self
 
+    from datachain.catalog import Catalog
+
 C = Column
 
 
@@ -200,10 +202,12 @@ class DataChain(DatasetQuery):
     def from_storage(
         cls,
         path,
+        *,
         type: Literal["binary", "text", "image"] = "binary",
+        catalog: Optional["Catalog"] = None,
         recursive: Optional[bool] = True,
         anon: bool = False,
-    ) -> "DataChain":
+    ) -> "Self":
         """Get data from a storage as a list of file with all file attributes. It
         returns the chain itself as usual.
 
@@ -220,7 +224,7 @@ class DataChain(DatasetQuery):
         ```
         """
         func = get_file(type)
-        return DataChain(path, recursive=recursive, anon=anon).map(file=func)
+        return cls(path, catalog=catalog, recursive=recursive, anon=anon).map(file=func)
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
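As of 0.2.2 everything after `path` is keyword-only, and an explicit `Catalog` can be injected. A hedged usage sketch of the new signature; the bucket URI is illustrative and not taken from the diff:

```python
from datachain.lib.dc import DataChain

# Everything after the path must be passed by keyword in 0.2.2.
images = DataChain.from_storage(
    "s3://my-bucket/images/",  # illustrative URI
    type="image",
    recursive=True,
    anon=True,
)
```

The new `catalog=` argument is optional and defaults to `None`.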
@@ -433,8 +437,7 @@ class DataChain(DatasetQuery):
 
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
 
-        chain = DatasetQuery.add_signals(
-            self,
+        chain = self.add_signals(
             udf_obj.to_udf_wrapper(self._settings.batch),
             **self._settings.to_dict(),
         )
@@ -530,23 +533,23 @@ class DataChain(DatasetQuery):
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
-        name = self.name or "Unknown"
+        name = self.name or ""
+
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        params_schema = self.signals_schema.slice(sign.params)
 
-        params_feature = self.signals_schema.slice(sign.params)
-        udf = target_class(params_feature, sign.output_schema, func=sign.func)
-        udf.set_catalog(self.catalog)
-        return udf
+        return UDFBase._create(target_class, sign, params_schema, self.catalog)
 
     def _extend_features(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)
 
         new_schema = self.signals_schema.resolve(*args)
-        columns = new_schema.db_signals()
-        chain = super_func(*columns, **kwargs)
-        chain.signals_schema = new_schema
+        columns = [C(col) for col in new_schema.db_signals()]
+        res = super_func(*columns, **kwargs)
+        if isinstance(res, DataChain):
+            res.signals_schema = new_schema
 
-        return chain
+        return res
 
     @detach
     def select(self, *args: str) -> "Self":
@@ -699,6 +702,9 @@ class DataChain(DatasetQuery):
         right_on = on
         right_on_columns = on_columns
 
+        if self == right_ds:
+            right_ds = right_ds.clone(new_table=True)
+
         ops = [
             self.c(left) == right_ds.c(right)
             for left, right in zip(on_columns, right_on_columns)
@@ -774,11 +780,11 @@ class DataChain(DatasetQuery):
         from pyarrow import unify_schemas
         from pyarrow.dataset import dataset
 
-        from datachain.lib.arrow import ArrowGenerator, Source, schema_to_output
+        from datachain.lib.arrow import ArrowGenerator, schema_to_output
 
         schema = None
         if output:
-            output = {"source": Source} | output
+            output = {"source": IndexedFile} | output
         else:
             schemas = []
             for row in self.select("file").iterate():
@@ -791,7 +797,6 @@ class DataChain(DatasetQuery):
             schema = unify_schemas(schemas)
             try:
                 output = schema_to_output(schema)
-                print(f"Inferred tabular data schema: {output}")
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
 
@@ -893,15 +898,26 @@ class DataChain(DatasetQuery):
        >>> single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
        """
        session = Session.get(session)
-        dsr = cls.create_empty_record(session=session)
-
-        if to_insert is not None:
-            if not isinstance(to_insert, list):
-                to_insert = [to_insert]
-
-            for record in to_insert:
-                cls.insert_record(dsr, record, session=session)
+        catalog = session.catalog
 
+        name = session.generate_temp_dataset_name()
+        columns: tuple[sqlalchemy.Column[Any], ...] = tuple(
+            sqlalchemy.Column(name, typ)
+            for name, typ in File._datachain_column_types.items()
+        )
+        dsr = catalog.create_dataset(name, columns=columns)
+
+        if isinstance(to_insert, dict):
+            to_insert = [to_insert]
+        elif not to_insert:
+            to_insert = []
+
+        warehouse = catalog.warehouse
+        dr = warehouse.dataset_rows(dsr)
+        db = warehouse.db
+        insert_q = dr.get_table().insert()
+        for record in to_insert:
+            db.execute(insert_q.values(**record))
        return DataChain(name=dsr.name)
 
    def sum(self, fr: FeatureType):  # type: ignore[override]
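`create_empty` now creates a temporary dataset with the `File` columns through the catalog and inserts the records directly via the warehouse, but the call site from the docstring above is unchanged. A hedged sketch of how it is driven:

```python
from datachain.lib.dc import DataChain

# Single record, as in the docstring example above.
single = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)

# A list of record dicts (column name -> value) is also accepted,
# and passing an empty/None value produces an empty chain.
several = DataChain.create_empty([DataChain.DEFAULT_FILE_RECORD] * 3)
empty = DataChain.create_empty(None)
```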
@@ -915,37 +931,3 @@ class DataChain(DatasetQuery):
 
     def max(self, fr: FeatureType): # type: ignore[override]
         return self._extend_features("max", fr)
-
-    @detach
-    def gen_random(self) -> "DataChain":
-        from random import getrandbits
-
-        from datachain.data_storage.warehouse import RANDOM_BITS
-
-        if "random" not in self.signals_schema.values:
-            chain = self.map(random=lambda: getrandbits(RANDOM_BITS), output=int).save()
-            return chain.select_except("random")
-
-        return self
-
-    @detach
-    def shuffle(self) -> "DataChain":
-        """Return results in deterministic random order."""
-        chain = self.gen_random()
-        return DatasetQuery.shuffle(chain)
-
-    @detach
-    def chunk(self, index: int, total: int) -> "DataChain":
-        """Split a query into smaller chunks for e.g. parallelization.
-
-        Examples:
-            >>> dc = DataChain(...)
-            >>> chunk_1 = dc._chunk(0, 2)
-            >>> chunk_2 = dc._chunk(1, 2)
-
-        Note:
-            Bear in mind that `index` is 0-indexed but `total` isn't.
-            Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
-        """
-        chain = self.gen_random()
-        return DatasetQuery.chunk(chain, index, total)
datachain/lib/feature.py CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
 from functools import lru_cache
 from types import GenericAlias
 from typing import (
+    TYPE_CHECKING,
     Any,
     ClassVar,
     Literal,
@@ -39,6 +40,9 @@ from datachain.sql.types import (
     String,
 )
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
 FeatureStandardType = Union[
     type[int],
     type[str],
@@ -158,7 +162,7 @@ class Feature(BaseModel):
         s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
         return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
 
-    def _set_stream(self, catalog, stream=None, caching_enabled: bool = False) -> None:
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
         pass
 
     @classmethod
datachain/lib/feature_registry.py CHANGED
@@ -1,6 +1,7 @@
+import logging
 from typing import Any, ClassVar, Optional
 
-from datachain.cli import logger
+logger = logging.getLogger(__name__)
 
 
 class Registry:
@@ -16,7 +17,7 @@ class Registry:
         version = fr._version  # type: ignore[attr-defined]
         if version in cls.reg[name]:
             full_name = f"{name}@{version}"
-            logger.warning(f"Feature {full_name} is already registered")
+            logger.warning("Feature %s is already registered", full_name)
         cls.reg[name][version] = fr
 
     @classmethod
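The registry now creates its own module-level logger instead of importing the CLI's, and switches to lazy %-style formatting so the message is only interpolated when the record is actually emitted. A minimal sketch of the same pattern (the `register` helper and its arguments are illustrative, not datachain code):

```python
import logging

logger = logging.getLogger(__name__)  # per-module logger, no dependency on the CLI module


def register(name: str, version: int, seen: set[tuple[str, int]]) -> None:
    # Lazy %-formatting: no string building happens unless WARNING is enabled.
    if (name, version) in seen:
        logger.warning("Feature %s@%s is already registered", name, version)
    seen.add((name, version))
```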
datachain/lib/feature_utils.py CHANGED
@@ -11,11 +11,10 @@ from datachain.lib.feature import (
     FeatureTypeNames,
     convert_type_to_datachain,
 )
-from datachain.lib.reader import FeatureReader
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import Column
 
-FeatureLike = Union[type["Feature"], FeatureReader, Column, str]
+FeatureLike = Union[type["Feature"], Column, str]
 
 AUTO_FEATURE_PREFIX = "_auto_fr"
 SUFFIX_SYMBOLS = string.digits + string.ascii_lowercase
datachain/lib/file.py CHANGED
@@ -2,11 +2,10 @@ import json
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
-from typing import Any, ClassVar, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
 from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname
 
-from fsspec import Callback
 from fsspec.implementations.local import LocalFileSystem
 from pydantic import Field, field_validator
 
@@ -18,6 +17,9 @@ from datachain.lib.utils import DataChainError
 from datachain.sql.types import JSON, Int, String
 from datachain.utils import TIME_ZERO
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
 
 class FileFeature(Feature):
     _is_file = True
@@ -182,26 +184,17 @@ class File(FileFeature):
 
     def open(self):
         if self._stream is None:
-            if self._catalog is None:
-                raise FileError(self, "stream is not set")
-            self._stream = self._open_stream()
+            raise FileError(self, "stream is not set")
 
         if self.location:
             return VFileRegistry.resolve(self, self.location)
         return self._stream
 
 
-    def _set_stream(
-        self, catalog=None, stream=None, caching_enabled: bool = False
-    ) -> None:
-        if self._catalog is None and catalog is None:
-            raise DataChainError(f"Cannot set file '{stream}' without catalog")
-
-        if catalog:
-            self._catalog = catalog
-
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+        self._catalog = catalog
         stream_class = PreCachedStream if caching_enabled else PreDownloadStream
-        self._stream = stream_class(stream, self.size, self._catalog, self.get_uid())
+        self._stream = stream_class(self._catalog, self.get_uid())
         self._caching_enabled = caching_enabled
 
     def get_uid(self) -> UniqueId:
@@ -232,11 +225,6 @@ class File(FileFeature):
     def get_uri(self):
         return f"{self.source}/{self.get_full_name()}"
 
-    def _open_stream(self, cache: bool = False, cb: Optional[Callback] = None):
-        client = self._catalog.get_client(self.source)
-        uid = self.get_uid()
-        return client.open_object(uid, use_cache=cache, cb=cb)
-
     def get_path(self) -> str:
         path = unquote(self.get_uri())
         fs = self.get_fs()
@@ -258,10 +246,8 @@ class TextFile(File):
         super().__init__(**kwargs)
         self._stream = None
 
-    def _set_stream(
-        self, catalog=None, stream=None, caching_enabled: bool = False
-    ) -> None:
-        super()._set_stream(catalog, stream, caching_enabled)
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+        super()._set_stream(catalog, caching_enabled)
         self._stream.set_mode("r")
 
 
@@ -296,3 +282,10 @@ def get_file(type: Literal["binary", "text", "image"] = "binary"):
     )
 
     return get_file_type
+
+
+class IndexedFile(Feature):
+    """File source info for tables."""
+
+    file: File
+    index: int
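Stream wiring also changed shape: `_set_stream()` no longer receives an already-open file object; the stream class is constructed from the catalog plus the file's `UniqueId` and is expected to open the object on demand (most of `datachain/lib/cached_stream.py` was removed in the same release, see the file list above). A rough stand-in sketch of that lazy pattern, with a hypothetical `DemoStream` in place of `PreDownloadStream`/`PreCachedStream`; it also takes the source URI separately for simplicity:

```python
class DemoStream:
    """Illustrative only; not datachain's actual stream classes."""

    def __init__(self, catalog, uid, source):
        self._catalog = catalog
        self._uid = uid
        self._source = source
        self._mode = "rb"
        self._fobj = None

    def set_mode(self, mode: str) -> None:
        self._mode = mode

    def read(self, *args):
        if self._fobj is None:
            # The same calls the removed File._open_stream() made,
            # deferred until the data is actually needed.
            client = self._catalog.get_client(self._source)
            self._fobj = client.open_object(self._uid, use_cache=False)
        return self._fobj.read(*args)
```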
datachain/lib/image.py CHANGED
@@ -1,6 +1,5 @@
-import inspect
 from io import BytesIO
-from typing import Any, Callable, Optional
+from typing import Callable, Optional, Union
 
 from datachain.lib.file import File
 
@@ -14,8 +13,6 @@ except ImportError as exc:
         " pip install 'datachain[cv]'\n"
     ) from exc
 
-from datachain.lib.reader import FeatureReader
-
 
 class ImageFile(File):
     def get_value(self):
@@ -28,8 +25,8 @@ def convert_image(
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
-    open_clip_model: Optional[Any] = None,
-):
+    encoder: Optional[Callable] = None,
+) -> Union[Image.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -37,8 +34,8 @@
         img (Image): PIL.Image object.
         mode (str): PIL.Image mode.
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
-        transform (Callable): Torchvision v1 or other transform to apply.
-        open_clip_model (Any): Encode image using model from open_clip library.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
     if mode:
         img = img.convert(mode)
@@ -46,86 +43,47 @@
         img = img.resize(size)
     if transform:
         img = transform(img)
-    if open_clip_model:
+
+        try:
+            from transformers.image_processing_utils import BaseImageProcessor
+
+            if isinstance(transform, BaseImageProcessor):
+                img = torch.tensor(img.pixel_values[0])  # type: ignore[assignment,attr-defined]
+        except ImportError:
+            pass
+    if encoder:
         img = img.unsqueeze(0)  # type: ignore[attr-defined]
-    if open_clip_model:
-        method_name = "encode_image"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                "Unable to render Image: 'open_clip_model' doesn't support"
-                f" '{method_name}()'"
-            )
-        img = open_clip_model.encode_image(img)
+    if encoder:
+        img = encoder(img)
     return img
 
 
-class ImageReader(FeatureReader):
-    def __init__(
-        self,
-        mode: str = "RGB",
-        size: Optional[tuple[int, int]] = None,
-        transform: Optional[Callable] = None,
-        open_clip_model: Any = None,
-    ):
-        """
-        Read and optionally transform an image.
-
-        All kwargs are passed to `convert_image()`.
-        """
-        self.mode = mode
-        self.size = size
-        self.transform = transform
-        self.open_clip_model = open_clip_model
-        super().__init__(ImageFile)
-
-    def __call__(self, img: Image.Image):
-        return convert_image(
-            img,
-            mode=self.mode,
-            size=self.size,
-            transform=self.transform,
-            open_clip_model=self.open_clip_model,
-        )
-
-
-def similarity_scores(
-    model: Any,
-    preprocess: Callable,
-    tokenizer: Callable,
-    image: Image.Image,
-    text: str,
-    prob: bool = False,
-) -> list[float]:
+def convert_images(
+    images: Union[Image.Image, list[Image.Image]],
+    mode: str = "RGB",
+    size: Optional[tuple[int, int]] = None,
+    transform: Optional[Callable] = None,
+    encoder: Optional[Callable] = None,
+) -> Union[list[Image.Image], torch.Tensor]:
     """
-    Calculate CLIP similarity scores for one or more texts given an image.
+    Resize, transform, and otherwise convert one or more images.
 
     Args:
-        model: Model from clip or open_clip packages.
-        preprocess: Image preprocessing transforms.
-        tokenizer: Text tokenizer.
-        image: Image.
-        text: Text.
-        prob: Compute softmax probabilities across texts.
+        img (Image, list[Image]): PIL.Image object or list of objects.
+        mode (str): PIL.Image mode.
+        size (tuple[int, int]): Size in (width, height) pixels for resizing.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
+    if isinstance(images, Image.Image):
+        images = [images]
 
-    with torch.no_grad():
-        image = preprocess(image).unsqueeze(0)
-        text = tokenizer(text)
-
-        image_features = model.encode_image(image)
-        text_features = model.encode_text(text)
-
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        text_features /= text_features.norm(dim=-1, keepdim=True)
+    converted = [convert_image(img, mode, size, transform) for img in images]
 
-        logits_per_text = 100.0 * image_features @ text_features.T
+    if isinstance(converted[0], torch.Tensor):
+        converted = torch.stack(converted)  # type: ignore[assignment,arg-type]
 
-        if prob:
-            scores = logits_per_text.softmax(dim=1)
-        else:
-            scores = logits_per_text
+    if encoder:
+        converted = encoder(converted)
 
-    return scores[0].tolist()
+    return converted  # type: ignore[return-value]
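The removed `ImageReader` and `similarity_scores` helpers give way to the generic `encoder` hook (CLIP-style scoring now lives in the new datachain/lib/clip.py listed above). A hedged usage sketch of `convert_images` with open_clip; the model/checkpoint names and file paths are illustrative, and open_clip, torch, and PIL must be installed:

```python
import open_clip
from PIL import Image

from datachain.lib.image import convert_images

# Illustrative model/checkpoint; any open_clip pairing works the same way.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)

imgs = [Image.open("cat.jpg"), Image.open("dog.jpg")]  # hypothetical local files

# Preprocess both images into a stacked tensor and encode the batch.
embeddings = convert_images(
    imgs,
    transform=preprocess,
    encoder=model.encode_image,
)
print(embeddings.shape)  # e.g. torch.Size([2, 512])
```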
datachain/lib/pytorch.py CHANGED
@@ -116,10 +116,12 @@ class PytorchDataset(IterableDataset):
                 self.transform = None
             if self.tokenizer:
                 for i, val in enumerate(row):
-                    if isinstance(val, str):
+                    if isinstance(val, str) or (
+                        isinstance(val, list) and isinstance(val[0], str)
+                    ):
                         row[i] = convert_text(
                             val, self.tokenizer, self.tokenizer_kwargs
-                        )
+                        ).squeeze(0)  # type: ignore[union-attr]
             yield row
 
     @staticmethod
datachain/lib/signal_schema.py CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin
 
 from pydantic import create_model
 
-from datachain.lib.arrow import Source
 from datachain.lib.feature import (
     DATACHAIN_TO_TYPE,
     DEFAULT_DELIMITER,
@@ -14,7 +13,7 @@ from datachain.lib.feature import (
     convert_type_to_datachain,
 )
 from datachain.lib.feature_registry import Registry
-from datachain.lib.file import File, TextFile
+from datachain.lib.file import File, IndexedFile, TextFile
 from datachain.lib.image import ImageFile
 from datachain.lib.utils import DataChainParamsError
 from datachain.lib.webdataset import TarStream, WDSAllFile, WDSBasic
@@ -36,7 +35,7 @@ NAMES_TO_TYPES = {
     "datetime": datetime,
     "WDSLaion": WDSLaion,
     "Laion": Laion,
-    "Source": Source,
+    "Source": IndexedFile,
     "File": File,
     "ImageFile": ImageFile,
     "TextFile": TextFile,
@@ -150,7 +149,7 @@ class SignalSchema:
         )
 
     def slice(self, keys: Sequence[str]) -> "SignalSchema":
-        return SignalSchema({k: v for k, v in self.values.items() if k in keys})
+        return SignalSchema({k: self.values[k] for k in keys if k in self.values})
 
     def row_to_features(self, row: Sequence, catalog: "Catalog") -> list[FeatureType]:
         res = []
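The rewritten `slice()` iterates over the requested `keys` instead of over `self.values`, so the sliced schema now follows the order of the keys you ask for (and silently skips unknown ones). The difference is the same as between these two plain-dict comprehensions (values here are illustrative):

```python
values = {"file": "File", "label": "str", "score": "float"}
keys = ["score", "file", "missing"]

# 0.2.0 behaviour: iterate the schema, keep its original order.
old = {k: v for k, v in values.items() if k in keys}
# {'file': 'File', 'score': 'float'}

# 0.2.2 behaviour: iterate the requested keys, keep their order.
new = {k: values[k] for k in keys if k in values}
# {'score': 'float', 'file': 'File'}
```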
datachain/lib/text.py CHANGED
@@ -1,19 +1,15 @@
-import inspect
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-from datachain.lib.file import TextFile
-from datachain.lib.reader import FeatureReader
-
 if TYPE_CHECKING:
-    from datachain.lib.feature_utils import FeatureLike
+    import torch
 
 
 def convert_text(
     text: Union[str, list[str]],
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
-    open_clip_model: Optional[Any] = None,
-):
+    encoder: Optional[Callable] = None,
+) -> Union[str, list[str], "torch.Tensor"]:
     """
     Tokenize and otherwise transform text.
 
@@ -21,18 +17,8 @@ def convert_text(
         text (str): Text to convert.
         tokenizer (Callable): Tokenizer to use to tokenize objects.
         tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
-        open_clip_model (Any): Encode text using model from open_clip library.
+        encoder (Callable): Encode text using model.
     """
-    if open_clip_model:
-        method_name = "encode_text"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                f"TextColumn error: 'model' doesn't support '{method_name}()'"
-            )
-
     if not tokenizer:
         return text
 
@@ -43,38 +29,21 @@ def convert_text(
         res = tokenizer(text, **tokenizer_kwargs)
     else:
         res = tokenizer(text)
-    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-
-    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
-
-    if not open_clip_model:
-        return tokens.squeeze(0)
-
-    return open_clip_model.encode_text(tokens).squeeze(0)
+    try:
+        from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+        tokens = (
+            res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
+        )
+    except ImportError:
+        tokens = res
 
-class TextReader(FeatureReader):
-    def __init__(
-        self,
-        fr_class: "FeatureLike" = TextFile,
-        tokenizer: Optional[Callable] = None,
-        tokenizer_kwargs: Optional[dict[str, Any]] = None,
-        open_clip_model: Optional[Any] = None,
-    ):
-        """
-        Read and optionally transform a text column.
+    if not encoder:
+        return tokens
 
-        All kwargs are passed to `convert_text()`.
-        """
-        self.tokenizer = tokenizer
-        self.tokenizer_kwargs = tokenizer_kwargs
-        self.open_clip_model = open_clip_model
-        super().__init__(fr_class)
+    try:
+        import torch
+    except ImportError:
+        "Missing dependency 'torch' needed to encode text."
 
-    def __call__(self, value: Union[str, list[str]]):
-        return convert_text(
-            value,
-            tokenizer=self.tokenizer,
-            tokenizer_kwargs=self.tokenizer_kwargs,
-            open_clip_model=self.open_clip_model,
-        )
+    return encoder(torch.tensor(tokens))
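As in image.py, the removed `TextReader` is replaced by the generic `encoder` hook on `convert_text`. A hedged usage sketch with open_clip; the model and tokenizer names are illustrative, and open_clip plus torch must be installed:

```python
import open_clip

from datachain.lib.text import convert_text

# Illustrative open_clip model; any tokenizer/encoder pair can be plugged in.
model, _, _ = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

# Tokenize the text and encode the token tensor with the model.
emb = convert_text(
    "a photo of a dog",
    tokenizer=tokenizer,
    encoder=model.encode_text,
)
print(emb.shape)  # e.g. torch.Size([1, 512])
```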