datachain 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +0 -4
- datachain/catalog/catalog.py +17 -2
- datachain/cli.py +8 -1
- datachain/data_storage/db_engine.py +0 -2
- datachain/data_storage/schema.py +15 -26
- datachain/data_storage/sqlite.py +3 -0
- datachain/data_storage/warehouse.py +1 -7
- datachain/lib/arrow.py +7 -13
- datachain/lib/cached_stream.py +3 -85
- datachain/lib/clip.py +151 -0
- datachain/lib/dc.py +41 -59
- datachain/lib/feature.py +5 -1
- datachain/lib/feature_registry.py +3 -2
- datachain/lib/feature_utils.py +1 -2
- datachain/lib/file.py +17 -24
- datachain/lib/image.py +37 -79
- datachain/lib/pytorch.py +4 -2
- datachain/lib/signal_schema.py +3 -4
- datachain/lib/text.py +18 -49
- datachain/lib/udf.py +64 -55
- datachain/lib/udf_signature.py +11 -10
- datachain/lib/utils.py +17 -0
- datachain/lib/webdataset.py +2 -2
- datachain/listing.py +0 -3
- datachain/query/dataset.py +66 -46
- datachain/query/dispatch.py +2 -2
- datachain/query/schema.py +1 -8
- datachain/query/udf.py +16 -18
- datachain/sql/sqlite/base.py +34 -2
- datachain/sql/sqlite/vector.py +13 -5
- datachain/utils.py +28 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/METADATA +3 -2
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/RECORD +37 -38
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
- datachain/_version.py +0 -16
- datachain/lib/reader.py +0 -49
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/lib/dc.py
CHANGED
@@ -14,7 +14,7 @@ import sqlalchemy
 
 from datachain.lib.feature import Feature, FeatureType
 from datachain.lib.feature_utils import features_to_tuples
-from datachain.lib.file import File, get_file
+from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
@@ -39,6 +39,8 @@ if TYPE_CHECKING:
     import pandas as pd
     from typing_extensions import Self
 
+    from datachain.catalog import Catalog
+
 C = Column
 
 
@@ -200,10 +202,12 @@ class DataChain(DatasetQuery):
     def from_storage(
         cls,
         path,
+        *,
         type: Literal["binary", "text", "image"] = "binary",
+        catalog: Optional["Catalog"] = None,
         recursive: Optional[bool] = True,
         anon: bool = False,
-    ) -> "DataChain":
+    ) -> "Self":
         """Get data from a storage as a list of file with all file attributes. It
         returns the chain itself as usual.
 
@@ -220,7 +224,7 @@ class DataChain(DatasetQuery):
         ```
         """
         func = get_file(type)
-        return cls(path, recursive=recursive, anon=anon).map(file=func)
+        return cls(path, catalog=catalog, recursive=recursive, anon=anon).map(file=func)
 
     @classmethod
     def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
@@ -433,8 +437,7 @@ class DataChain(DatasetQuery):
 
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
 
-        chain =
-            self,
+        chain = self.add_signals(
             udf_obj.to_udf_wrapper(self._settings.batch),
             **self._settings.to_dict(),
         )
@@ -530,23 +533,23 @@ class DataChain(DatasetQuery):
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
-        name = self.name or "
+        name = self.name or ""
+
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        params_schema = self.signals_schema.slice(sign.params)
 
-
-        udf = target_class(params_feature, sign.output_schema, func=sign.func)
-        udf.set_catalog(self.catalog)
-        return udf
+        return UDFBase._create(target_class, sign, params_schema, self.catalog)
 
     def _extend_features(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)
 
         new_schema = self.signals_schema.resolve(*args)
-        columns = new_schema.db_signals()
-
-
+        columns = [C(col) for col in new_schema.db_signals()]
+        res = super_func(*columns, **kwargs)
+        if isinstance(res, DataChain):
+            res.signals_schema = new_schema
 
-        return
+        return res
 
     @detach
     def select(self, *args: str) -> "Self":
@@ -699,6 +702,9 @@ class DataChain(DatasetQuery):
            right_on = on
            right_on_columns = on_columns
 
+        if self == right_ds:
+            right_ds = right_ds.clone(new_table=True)
+
         ops = [
             self.c(left) == right_ds.c(right)
             for left, right in zip(on_columns, right_on_columns)
@@ -774,11 +780,11 @@ class DataChain(DatasetQuery):
         from pyarrow import unify_schemas
         from pyarrow.dataset import dataset
 
-        from datachain.lib.arrow import ArrowGenerator,
+        from datachain.lib.arrow import ArrowGenerator, schema_to_output
 
         schema = None
         if output:
-            output = {"source": Source} | output
+            output = {"source": IndexedFile} | output
         else:
             schemas = []
             for row in self.select("file").iterate():
@@ -791,7 +797,6 @@ class DataChain(DatasetQuery):
             schema = unify_schemas(schemas)
             try:
                 output = schema_to_output(schema)
-                print(f"Inferred tabular data schema: {output}")
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
 
@@ -893,15 +898,26 @@ class DataChain(DatasetQuery):
         >>> single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
         """
         session = Session.get(session)
-
-
-        if to_insert is not None:
-            if not isinstance(to_insert, list):
-                to_insert = [to_insert]
-
-            for record in to_insert:
-                cls.insert_record(dsr, record, session=session)
+        catalog = session.catalog
 
+        name = session.generate_temp_dataset_name()
+        columns: tuple[sqlalchemy.Column[Any], ...] = tuple(
+            sqlalchemy.Column(name, typ)
+            for name, typ in File._datachain_column_types.items()
+        )
+        dsr = catalog.create_dataset(name, columns=columns)
+
+        if isinstance(to_insert, dict):
+            to_insert = [to_insert]
+        elif not to_insert:
+            to_insert = []
+
+        warehouse = catalog.warehouse
+        dr = warehouse.dataset_rows(dsr)
+        db = warehouse.db
+        insert_q = dr.get_table().insert()
+        for record in to_insert:
+            db.execute(insert_q.values(**record))
         return DataChain(name=dsr.name)
 
     def sum(self, fr: FeatureType):  # type: ignore[override]
@@ -915,37 +931,3 @@ class DataChain(DatasetQuery):
 
     def max(self, fr: FeatureType):  # type: ignore[override]
         return self._extend_features("max", fr)
-
-    @detach
-    def gen_random(self) -> "DataChain":
-        from random import getrandbits
-
-        from datachain.data_storage.warehouse import RANDOM_BITS
-
-        if "random" not in self.signals_schema.values:
-            chain = self.map(random=lambda: getrandbits(RANDOM_BITS), output=int).save()
-            return chain.select_except("random")
-
-        return self
-
-    @detach
-    def shuffle(self) -> "DataChain":
-        """Return results in deterministic random order."""
-        chain = self.gen_random()
-        return DatasetQuery.shuffle(chain)
-
-    @detach
-    def chunk(self, index: int, total: int) -> "DataChain":
-        """Split a query into smaller chunks for e.g. parallelization.
-
-        Examples:
-            >>> dc = DataChain(...)
-            >>> chunk_1 = dc._chunk(0, 2)
-            >>> chunk_2 = dc._chunk(1, 2)
-
-        Note:
-            Bear in mind that `index` is 0-indexed but `total` isn't.
-            Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
-        """
-        chain = self.gen_random()
-        return DatasetQuery.chunk(chain, index, total)
datachain/lib/feature.py
CHANGED
@@ -7,6 +7,7 @@ from datetime import datetime
 from functools import lru_cache
 from types import GenericAlias
 from typing import (
+    TYPE_CHECKING,
     Any,
     ClassVar,
     Literal,
@@ -39,6 +40,9 @@ from datachain.sql.types import (
     String,
 )
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
 FeatureStandardType = Union[
     type[int],
     type[str],
@@ -158,7 +162,7 @@ class Feature(BaseModel):
         s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
         return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
 
-    def _set_stream(self, catalog
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
        pass
 
     @classmethod
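For orientation, user-defined signals subclass `Feature` (a pydantic `BaseModel`), and `_set_stream()` above is the no-op hook that file-backed features such as `File` override. A hypothetical sketch of a plain feature:

    from datachain.lib.feature import Feature

    # Illustrative only: a simple user-defined feature with two float fields.
    class BBox(Feature):
        x: float
        y: float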
datachain/lib/feature_registry.py
CHANGED
@@ -1,6 +1,7 @@
+import logging
 from typing import Any, ClassVar, Optional
 
-
+logger = logging.getLogger(__name__)
 
 
 class Registry:
@@ -16,7 +17,7 @@ class Registry:
         version = fr._version  # type: ignore[attr-defined]
         if version in cls.reg[name]:
             full_name = f"{name}@{version}"
-            logger.warning(
+            logger.warning("Feature %s is already registered", full_name)
         cls.reg[name][version] = fr
 
     @classmethod
datachain/lib/feature_utils.py
CHANGED
@@ -11,11 +11,10 @@ from datachain.lib.feature import (
     FeatureTypeNames,
     convert_type_to_datachain,
 )
-from datachain.lib.reader import FeatureReader
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import Column
 
-FeatureLike = Union[type["Feature"],
+FeatureLike = Union[type["Feature"], Column, str]
 
 AUTO_FEATURE_PREFIX = "_auto_fr"
 SUFFIX_SYMBOLS = string.digits + string.ascii_lowercase
datachain/lib/file.py
CHANGED
@@ -2,11 +2,10 @@ import json
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
-from typing import Any, ClassVar, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
 from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname
 
-from fsspec import Callback
 from fsspec.implementations.local import LocalFileSystem
 from pydantic import Field, field_validator
 
@@ -18,6 +17,9 @@ from datachain.lib.utils import DataChainError
 from datachain.sql.types import JSON, Int, String
 from datachain.utils import TIME_ZERO
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
 
 class FileFeature(Feature):
     _is_file = True
@@ -182,26 +184,17 @@ class File(FileFeature):
 
     def open(self):
         if self._stream is None:
-
-            raise FileError(self, "stream is not set")
-            self._stream = self._open_stream()
+            raise FileError(self, "stream is not set")
 
         if self.location:
             return VFileRegistry.resolve(self, self.location)
 
         return self._stream
 
-    def _set_stream(
-        self
-    ) -> None:
-        if self._catalog is None and catalog is None:
-            raise DataChainError(f"Cannot set file '{stream}' without catalog")
-
-        if catalog:
-            self._catalog = catalog
-
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+        self._catalog = catalog
         stream_class = PreCachedStream if caching_enabled else PreDownloadStream
-        self._stream = stream_class(
+        self._stream = stream_class(self._catalog, self.get_uid())
         self._caching_enabled = caching_enabled
 
     def get_uid(self) -> UniqueId:
@@ -232,11 +225,6 @@ class File(FileFeature):
     def get_uri(self):
         return f"{self.source}/{self.get_full_name()}"
 
-    def _open_stream(self, cache: bool = False, cb: Optional[Callback] = None):
-        client = self._catalog.get_client(self.source)
-        uid = self.get_uid()
-        return client.open_object(uid, use_cache=cache, cb=cb)
-
     def get_path(self) -> str:
         path = unquote(self.get_uri())
         fs = self.get_fs()
@@ -258,10 +246,8 @@ class TextFile(File):
         super().__init__(**kwargs)
         self._stream = None
 
-    def _set_stream(
-
-    ) -> None:
-        super()._set_stream(catalog, stream, caching_enabled)
+    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+        super()._set_stream(catalog, caching_enabled)
         self._stream.set_mode("r")
 
 
@@ -296,3 +282,10 @@ def get_file(type: Literal["binary", "text", "image"] = "binary"):
     )
 
     return get_file_type
+
+
+class IndexedFile(Feature):
+    """File source info for tables."""
+
+    file: File
+    index: int
datachain/lib/image.py
CHANGED
@@ -1,6 +1,5 @@
-import inspect
 from io import BytesIO
-from typing
+from typing import Callable, Optional, Union
 
 from datachain.lib.file import File
 
@@ -14,8 +13,6 @@ except ImportError as exc:
         " pip install 'datachain[cv]'\n"
     ) from exc
 
-from datachain.lib.reader import FeatureReader
-
 
 class ImageFile(File):
     def get_value(self):
@@ -28,8 +25,8 @@ def convert_image(
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
-
-):
+    encoder: Optional[Callable] = None,
+) -> Union[Image.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -37,8 +34,8 @@ def convert_image(
         img (Image): PIL.Image object.
         mode (str): PIL.Image mode.
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
-        transform (Callable): Torchvision
-
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
     if mode:
         img = img.convert(mode)
@@ -46,86 +43,47 @@ def convert_image(
         img = img.resize(size)
     if transform:
         img = transform(img)
-
+
+        try:
+            from transformers.image_processing_utils import BaseImageProcessor
+
+            if isinstance(transform, BaseImageProcessor):
+                img = torch.tensor(img.pixel_values[0])  # type: ignore[assignment,attr-defined]
+        except ImportError:
+            pass
+        if encoder:
             img = img.unsqueeze(0)  # type: ignore[attr-defined]
-    if open_clip_model:
-
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                "Unable to render Image: 'open_clip_model' doesn't support"
-                f" '{method_name}()'"
-            )
-        img = open_clip_model.encode_image(img)
+    if encoder:
+        img = encoder(img)
     return img
 
 
-
-
-
-
-
-
-
-    ):
-        """
-        Read and optionally transform an image.
-
-        All kwargs are passed to `convert_image()`.
-        """
-        self.mode = mode
-        self.size = size
-        self.transform = transform
-        self.open_clip_model = open_clip_model
-        super().__init__(ImageFile)
-
-    def __call__(self, img: Image.Image):
-        return convert_image(
-            img,
-            mode=self.mode,
-            size=self.size,
-            transform=self.transform,
-            open_clip_model=self.open_clip_model,
-        )
-
-
-def similarity_scores(
-    model: Any,
-    preprocess: Callable,
-    tokenizer: Callable,
-    image: Image.Image,
-    text: str,
-    prob: bool = False,
-) -> list[float]:
+def convert_images(
+    images: Union[Image.Image, list[Image.Image]],
+    mode: str = "RGB",
+    size: Optional[tuple[int, int]] = None,
+    transform: Optional[Callable] = None,
+    encoder: Optional[Callable] = None,
+) -> Union[list[Image.Image], torch.Tensor]:
     """
-
+    Resize, transform, and otherwise convert one or more images.
 
     Args:
-
-
-
-
-
-        prob: Compute softmax probabilities across texts.
+        img (Image, list[Image]): PIL.Image object or list of objects.
+        mode (str): PIL.Image mode.
+        size (tuple[int, int]): Size in (width, height) pixels for resizing.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
+    if isinstance(images, Image.Image):
+        images = [images]
 
-
-    image = preprocess(image).unsqueeze(0)
-    text = tokenizer(text)
-
-    image_features = model.encode_image(image)
-    text_features = model.encode_text(text)
-
-    image_features /= image_features.norm(dim=-1, keepdim=True)
-    text_features /= text_features.norm(dim=-1, keepdim=True)
+    converted = [convert_image(img, mode, size, transform) for img in images]
 
-
+    if isinstance(converted[0], torch.Tensor):
+        converted = torch.stack(converted)  # type: ignore[assignment,arg-type]
 
-
-
-    else:
-        scores = logits_per_text
+    if encoder:
+        converted = encoder(converted)
 
-
+    return converted  # type: ignore[return-value]
datachain/lib/pytorch.py
CHANGED
@@ -116,10 +116,12 @@ class PytorchDataset(IterableDataset):
                     self.transform = None
             if self.tokenizer:
                 for i, val in enumerate(row):
-                    if isinstance(val, str):
+                    if isinstance(val, str) or (
+                        isinstance(val, list) and isinstance(val[0], str)
+                    ):
                         row[i] = convert_text(
                             val, self.tokenizer, self.tokenizer_kwargs
-                        )
+                        ).squeeze(0)  # type: ignore[union-attr]
             yield row
 
     @staticmethod
datachain/lib/signal_schema.py
CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin
 
 from pydantic import create_model
 
-from datachain.lib.arrow import Source
 from datachain.lib.feature import (
     DATACHAIN_TO_TYPE,
     DEFAULT_DELIMITER,
@@ -14,7 +13,7 @@ from datachain.lib.feature import (
     convert_type_to_datachain,
 )
 from datachain.lib.feature_registry import Registry
-from datachain.lib.file import File, TextFile
+from datachain.lib.file import File, IndexedFile, TextFile
 from datachain.lib.image import ImageFile
 from datachain.lib.utils import DataChainParamsError
 from datachain.lib.webdataset import TarStream, WDSAllFile, WDSBasic
@@ -36,7 +35,7 @@ NAMES_TO_TYPES = {
     "datetime": datetime,
     "WDSLaion": WDSLaion,
     "Laion": Laion,
-    "Source": Source,
+    "Source": IndexedFile,
     "File": File,
     "ImageFile": ImageFile,
     "TextFile": TextFile,
@@ -150,7 +149,7 @@ class SignalSchema:
         )
 
     def slice(self, keys: Sequence[str]) -> "SignalSchema":
-        return SignalSchema({k:
+        return SignalSchema({k: self.values[k] for k in keys if k in self.values})
 
     def row_to_features(self, row: Sequence, catalog: "Catalog") -> list[FeatureType]:
         res = []
datachain/lib/text.py
CHANGED
@@ -1,19 +1,15 @@
-import inspect
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-from datachain.lib.file import TextFile
-from datachain.lib.reader import FeatureReader
-
 if TYPE_CHECKING:
-
+    import torch
 
 
 def convert_text(
     text: Union[str, list[str]],
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
-
-):
+    encoder: Optional[Callable] = None,
+) -> Union[str, list[str], "torch.Tensor"]:
     """
     Tokenize and otherwise transform text.
 
@@ -21,18 +17,8 @@ def convert_text(
         text (str): Text to convert.
         tokenizer (Callable): Tokenizer to use to tokenize objects.
         tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
-
+        encoder (Callable): Encode text using model.
     """
-    if open_clip_model:
-        method_name = "encode_text"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                f"TextColumn error: 'model' doesn't support '{method_name}()'"
-            )
-
     if not tokenizer:
         return text
 
@@ -43,38 +29,21 @@ def convert_text(
         res = tokenizer(text, **tokenizer_kwargs)
     else:
         res = tokenizer(text)
-
-
-    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
-
-    if not open_clip_model:
-        return tokens.squeeze(0)
-
-    return open_clip_model.encode_text(tokens).squeeze(0)
+    try:
+        from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+        tokens = (
+            res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
+        )
+    except ImportError:
+        tokens = res
 
-
-
-        self,
-        fr_class: "FeatureLike" = TextFile,
-        tokenizer: Optional[Callable] = None,
-        tokenizer_kwargs: Optional[dict[str, Any]] = None,
-        open_clip_model: Optional[Any] = None,
-    ):
-        """
-        Read and optionally transform a text column.
+    if not encoder:
+        return tokens
 
-
-
-
-
-        self.open_clip_model = open_clip_model
-        super().__init__(fr_class)
+    try:
+        import torch
+    except ImportError:
+        "Missing dependency 'torch' needed to encode text."
 
-
-        return convert_text(
-            value,
-            tokenizer=self.tokenizer,
-            tokenizer_kwargs=self.tokenizer_kwargs,
-            open_clip_model=self.open_clip_model,
-        )
+    return encoder(torch.tensor(tokens))