datachain 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain has been flagged by the registry; see its release page for details.
- datachain/__init__.py +0 -4
- datachain/catalog/catalog.py +17 -2
- datachain/cli.py +8 -1
- datachain/data_storage/db_engine.py +0 -2
- datachain/data_storage/schema.py +15 -26
- datachain/data_storage/sqlite.py +3 -0
- datachain/data_storage/warehouse.py +1 -7
- datachain/lib/arrow.py +7 -13
- datachain/lib/cached_stream.py +3 -85
- datachain/lib/clip.py +151 -0
- datachain/lib/dc.py +41 -59
- datachain/lib/feature.py +5 -1
- datachain/lib/feature_registry.py +3 -2
- datachain/lib/feature_utils.py +1 -2
- datachain/lib/file.py +17 -24
- datachain/lib/image.py +37 -79
- datachain/lib/pytorch.py +4 -2
- datachain/lib/signal_schema.py +3 -4
- datachain/lib/text.py +18 -49
- datachain/lib/udf.py +64 -55
- datachain/lib/udf_signature.py +11 -10
- datachain/lib/utils.py +17 -0
- datachain/lib/webdataset.py +2 -2
- datachain/listing.py +0 -3
- datachain/query/dataset.py +66 -46
- datachain/query/dispatch.py +2 -2
- datachain/query/schema.py +1 -8
- datachain/query/udf.py +16 -18
- datachain/sql/sqlite/base.py +34 -2
- datachain/sql/sqlite/vector.py +13 -5
- datachain/utils.py +28 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/METADATA +3 -2
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/RECORD +37 -38
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
- datachain/_version.py +0 -16
- datachain/lib/reader.py +0 -49
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
datachain/catalog/catalog.py
CHANGED
@@ -65,7 +65,7 @@ from datachain.listing import Listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
-from datachain.sql.types import DateTime, SQLType, String
+from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
 from datachain.storage import Storage, StorageStatus, StorageURI
 from datachain.utils import (
     DataChainDir,
@@ -714,7 +714,22 @@ class Catalog:
         source_metastore = self.metastore.clone(client.uri)
         source_warehouse = self.warehouse.clone()

-        columns =
+        columns = [
+            Column("vtype", String),
+            Column("dir_type", Int),
+            Column("parent", String),
+            Column("name", String),
+            Column("etag", String),
+            Column("version", String),
+            Column("is_latest", Boolean),
+            Column("last_modified", DateTime(timezone=True)),
+            Column("size", Int64),
+            Column("owner_name", String),
+            Column("owner_id", String),
+            Column("location", JSON),
+            Column("source", String),
+        ]
+
         if skip_indexing:
             source_metastore.create_storage_if_not_registered(client.uri)
             storage = source_metastore.get_storage(client.uri)

datachain/cli.py
CHANGED
@@ -5,13 +5,14 @@ import sys
 import traceback
 from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace
 from collections.abc import Iterable, Iterator, Mapping, Sequence
+from importlib.metadata import PackageNotFoundError, version
 from itertools import chain
 from multiprocessing import freeze_support
 from typing import TYPE_CHECKING, Optional, Union

 import shtab

-from datachain import
+from datachain import utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
 from datachain.utils import DataChainDir

@@ -96,6 +97,12 @@ def add_show_args(parser: ArgumentParser) -> None:


 def get_parser() -> ArgumentParser:  # noqa: PLR0915
+    try:
+        __version__ = version("datachain")
+    except PackageNotFoundError:
+        # package is not installed
+        __version__ = "unknown"
+
     parser = ArgumentParser(
         description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
     )

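The added block resolves the CLI's reported version at runtime through the standard library's importlib.metadata. A minimal, standalone sketch of the same lookup pattern (the helper name is illustrative, not part of the package):

```python
# Sketch of the importlib.metadata version lookup used in the hunk above.
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str = "datachain") -> str:
    """Return the installed distribution version, or "unknown" if not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # e.g. running from a source checkout without an installed distribution
        return "unknown"


print(installed_version())
```
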
datachain/data_storage/schema.py
CHANGED
@@ -14,7 +14,7 @@ from sqlalchemy.sql.expression import null, true

 from datachain.node import DirType
 from datachain.sql.functions import path
-from datachain.sql.types import
+from datachain.sql.types import Int, SQLType, UInt64

 if TYPE_CHECKING:
     from sqlalchemy import Engine
@@ -31,7 +31,7 @@ def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
     """
     c_set: dict[str, sa.Column] = {}
     for c in columns:
-        if ec := c_set.get(c.name, None):
+        if (ec := c_set.get(c.name, None)) is not None:
             if str(ec.type) != str(c.type):
                 raise ValueError(
                     f"conflicting types for column {c.name}:{c.type!s} and {ec.type!s}"
@@ -137,7 +137,7 @@ class DataTable:
         self.name: str = name
         self.engine = engine
         self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
-        self.column_types = column_types
+        self.column_types: dict[str, SQLType] = column_types or {}

     @staticmethod
     def copy_column(column: sa.Column):
@@ -171,8 +171,8 @@ class DataTable:
     ):
         # copy columns, since re-using the same objects from another table
         # may raise an error
-        columns = [cls.copy_column(c) for c in columns
-        columns =
+        columns = cls.sys_columns() + [cls.copy_column(c) for c in columns]
+        columns = dedup_columns(columns)

         if metadata is None:
             metadata = sa.MetaData()
@@ -186,12 +186,12 @@ class DataTable:
         # Grab it from metadata instead.
         table = self.metadata.tables[self.name]

+        column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
         # adjusting types for custom columns to be instances of SQLType if possible
-
-
-
-
-            c.type = t() if inspect.isclass(t) else t
+        for c in table.columns:
+            if c.name in column_types:
+                t = column_types[c.name]
+                c.type = t() if inspect.isclass(t) else t
         return table

     @property
@@ -230,24 +230,13 @@ class DataTable:
     def delete(self):
         return self.apply_conditions(self.table.delete())

-    @
-    def
+    @staticmethod
+    def sys_columns():
         return [
             sa.Column("id", Int, primary_key=True),
-            sa.Column(
-
-
-            sa.Column("parent", String, index=True),
-            sa.Column("name", String, nullable=False, index=True),
-            sa.Column("etag", String),
-            sa.Column("version", String),
-            sa.Column("is_latest", Boolean),
-            sa.Column("last_modified", DateTime(timezone=True)),
-            sa.Column("size", Int64, nullable=False, index=True),
-            sa.Column("owner_name", String),
-            sa.Column("owner_id", String),
-            sa.Column("location", JSON),
-            sa.Column("source", String, nullable=False),
+            sa.Column(
+                "random", UInt64, nullable=False, server_default=f.abs(f.random())
+            ),
         ]

     def dir_expansion(self):

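The net effect of these schema.py changes is that the per-file columns move out of `sys_columns()` (they are now passed in explicitly, as in the catalog.py hunk above), while the `random` system column is filled by the database through a server-side default instead of being generated in Python. A rough standalone sketch of that server-default idea, assuming SQLAlchemy with a SQLite backend that supports parenthesized expression defaults; the table and column names are illustrative:

```python
import sqlalchemy as sa

metadata = sa.MetaData()

# Toy table mirroring the pattern: "random" gets its value from the database
# itself via a DDL-level default, so Python inserts no longer need to supply it.
entries = sa.Table(
    "demo_entries",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("source", sa.String, nullable=False),
    sa.Column(
        "random",
        sa.BigInteger,
        nullable=False,
        server_default=sa.text("(abs(random()))"),  # SQLite expression default
    ),
)

engine = sa.create_engine("sqlite:///:memory:")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(entries).values(source="s3://bucket"))
    print(conn.execute(sa.select(entries)).all())  # "random" filled by SQLite
```
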
datachain/data_storage/sqlite.py
CHANGED
@@ -33,6 +33,7 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetRecord
 from datachain.error import DataChainError
 from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_dialect
+from datachain.sql.sqlite.base import load_usearch_extension
 from datachain.sql.types import SQLType
 from datachain.storage import StorageURI
 from datachain.utils import DataChainDir
@@ -114,6 +115,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
             if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
                 db.set_trace_callback(print)

+            load_usearch_extension(db)
+
             return cls(engine, MetaData(), db, db_file)
         except RuntimeError:
             raise DataChainError("Can't connect to SQLite DB") from None

datachain/data_storage/warehouse.py
CHANGED

@@ -4,7 +4,6 @@ import logging
 import posixpath
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
-from random import getrandbits
 from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib.parse import urlparse

@@ -41,8 +40,6 @@ except ImportError:

 logger = logging.getLogger("datachain")

-RANDOM_BITS = 63  # size of the random integer field
-
 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time


@@ -408,10 +405,7 @@ class AbstractWarehouse(ABC, Serializable):

         def _prepare_entry(entry: Entry):
             assert entry.dir_type is not None
-            return attrs.asdict(entry) | {
-                "source": uri,
-                "random": getrandbits(RANDOM_BITS),
-            }
+            return attrs.asdict(entry) | {"source": uri}

         return [_prepare_entry(e) for e in entries]

datachain/lib/arrow.py
CHANGED
@@ -3,21 +3,14 @@ from typing import TYPE_CHECKING, Optional

 from pyarrow.dataset import dataset

-from datachain.lib.
-from datachain.lib.
+from datachain.lib.file import File, IndexedFile
+from datachain.lib.udf import Generator

 if TYPE_CHECKING:
     import pyarrow as pa


-class
-    """File source info for tables."""
-
-    file: File
-    index: int
-
-
-class ArrowGenerator:
+class ArrowGenerator(Generator):
     def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
         """
         Generator for getting rows from tabular files.
@@ -27,16 +20,17 @@ class ArrowGenerator:
         schema : Optional pyarrow schema for validation.
         kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
+        super().__init__()
         self.schema = schema
         self.kwargs = kwargs

-    def
+    def process(self, file: File):
         path = file.get_path()
         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
         index = 0
         for record_batch in ds.to_batches():
             for record in record_batch.to_pylist():
-                source =
+                source = IndexedFile(file=file, index=index)
                 yield [source, *record.values()]
                 index += 1
@@ -44,7 +38,7 @@ class ArrowGenerator:
 def schema_to_output(schema: "pa.Schema"):
     """Generate UDF output schema from pyarrow schema."""
     default_column = 0
-    output = {"source":
+    output = {"source": IndexedFile}
     for field in schema:
         column = field.name.lower()
         column = re.sub("[^0-9a-z_]+", "", column)

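`ArrowGenerator.process` follows the usual pyarrow dataset iteration pattern (dataset → record batches → Python dicts), yielding one row per record together with an `IndexedFile` source. A self-contained sketch of that iteration outside DataChain (the parquet path and sample data are illustrative):

```python
import pyarrow as pa
import pyarrow.parquet as pq
from pyarrow.dataset import dataset

# Write a tiny parquet file so the example is self-contained.
pq.write_table(
    pa.table({"name": ["a.jpg", "b.jpg"], "size": [100, 200]}),
    "/tmp/demo.parquet",
)

ds = dataset("/tmp/demo.parquet")  # DataChain also passes filesystem= and schema= here
index = 0
for record_batch in ds.to_batches():
    for record in record_batch.to_pylist():
        # ArrowGenerator yields [IndexedFile(file=file, index=index), *record.values()]
        print(index, list(record.values()))
        index += 1
```
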
datachain/lib/cached_stream.py
CHANGED
@@ -1,6 +1,3 @@
-import os
-import shutil
-import tempfile
 from abc import ABC
 from contextlib import AbstractContextManager

@@ -8,9 +5,7 @@ from datachain.cache import UniqueId


 class AbstractCachedStream(AbstractContextManager, ABC):
-    def __init__(self,
-        self.stream = stream
-        self.size = size
+    def __init__(self, catalog, uid: UniqueId):
         self.catalog = catalog
         self.uid = uid
         self.mode = "rb"
@@ -19,86 +14,9 @@ class AbstractCachedStream(AbstractContextManager, ABC):
         self.mode = mode


-class ProgressiveCacheStream(AbstractCachedStream):
-    BUF_SIZE = 4096
-
-    def __init__(self, stream, size, catalog, uid: UniqueId):
-        super().__init__(stream, size, catalog, uid)
-
-        self.target_path = self.catalog.cache.path_from_checksum(self.uid.get_hash())
-        self.cached_file = None
-
-        self.temp_file = None
-        self.temp_file_pos = 0
-
-    def __enter__(self):
-        if os.path.exists(self.target_path):
-            self.cached_file = open(self.target_path, mode=self.mode)
-            return self.cached_file
-
-        tmp_dir = self.catalog.cache.tmp_dir
-        if not os.path.exists(tmp_dir):
-            os.makedirs(tmp_dir)
-        self.temp_file = tempfile.NamedTemporaryFile(
-            prefix=str(self.uid.get_hash()), dir=tmp_dir, delete=False
-        )
-        return self
-
-    def __exit__(self, *args):
-        self.close()
-
-    def read(self, size=-1):
-        buf = self.stream.read(size)
-        pos = self.stream.tell()
-
-        if pos >= self.temp_file_pos:
-            self._cache_catch_up(pos, buf)
-
-        return buf
-
-    def close(self):
-        if self.cached_file:
-            self.cached_file.close()
-
-        if self.temp_file:
-            if self.temp_file_pos < self.size:
-                self._cache_catch_up(self.size)
-
-            self.temp_file.close()
-            if not os.path.exists(self.target_path):
-                os.makedirs(os.path.dirname(self.target_path), exist_ok=True)
-                shutil.move(self.temp_file.name, self.target_path)
-
-        self.stream.close()
-
-    def _cache_catch_up(self, pos_target, latest_buf=None):
-        pos_to_restore = self.stream.tell()
-        try:
-            remainder = pos_target - self.temp_file_pos
-            self.stream.seek(self.temp_file_pos)
-            while remainder > 0:
-                chunk_size = min(self.BUF_SIZE, remainder)
-                buf = self.stream.read(chunk_size)
-                self._cache_update(buf)
-                remainder -= len(buf)
-        finally:
-            self.stream.seek(pos_to_restore)
-
-    def _cache_update(self, buf):
-        length = len(buf)
-        self.temp_file.write(buf)
-        self.temp_file_pos += length
-
-    def seek(self, offset, whence=0):
-        return self.stream.seek(offset, whence)
-
-    def tell(self):
-        return self.stream.tell()
-
-
 class PreCachedStream(AbstractCachedStream):
-    def __init__(self,
-        super().__init__(
+    def __init__(self, catalog, uid: UniqueId):
+        super().__init__(catalog, uid)
         self.client = self.catalog.get_client(self.uid.storage)
         self.cached_file = None

datachain/lib/clip.py
ADDED
@@ -0,0 +1,151 @@
+import inspect
+from typing import Any, Callable, Literal, Union
+
+from datachain.lib.image import convert_images
+from datachain.lib.text import convert_text
+
+try:
+    import torch
+    from PIL import Image
+    from transformers.modeling_utils import PreTrainedModel
+except ImportError as exc:
+    raise ImportError(
+        "Missing dependencies for computer vision:\n"
+        "To install run:\n\n"
+        "  pip install 'datachain[cv]'\n"
+    ) from exc
+
+
+def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
+    # Check for transformers CLIPModel
+    method_name = f"get_{type}_features"
+    if isinstance(model, PreTrainedModel) and (
+        hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name))
+    ):
+        method = getattr(model, method_name)
+        return lambda x: method(torch.tensor(x))
+
+    # Check for model from clip or open_clip library
+    method_name = f"encode_{type}"
+    if hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name)):
+        return getattr(model, method_name)
+
+    raise ValueError(
+        f"Error encoding {type}: "
+        "'model' must be a CLIP model from clip, open_clip, or transformers library."
+    )
+
+
+def similarity_scores(
+    images: Union[None, Image.Image, list[Image.Image]],
+    text: Union[None, str, list[str]],
+    model: Any,
+    preprocess: Callable,
+    tokenizer: Callable,
+    prob: bool = False,
+    image_to_text: bool = True,
+) -> list[list[float]]:
+    """
+    Calculate CLIP similarity scores between one or more images and/or text.
+
+    Args:
+        images: Images to use as inputs.
+        text: Text to use as inputs.
+        model: Model from clip or open_clip packages.
+        preprocess: Image preprocessor to apply.
+        tokenizer: Text tokenizer.
+        prob: Compute softmax probabilities.
+        image_to_text: Whether to compute for image-to-text or text-to-image. Ignored if
+            only one of images or text provided.
+
+
+    Examples
+    --------
+
+    using https://github.com/openai/CLIP
+    >>> import clip
+    >>> model, preprocess = clip.load("ViT-B/32")
+    >>> similarity_scores(img, "cat", model, preprocess, clip.tokenize)
+    [[21.813]]
+
+    using https://github.com/mlfoundations/open_clip
+    >>> import open_clip
+    >>> model, _, preprocess = open_clip.create_model_and_transforms(
+    ...     "ViT-B-32", pretrained="laion2b_s34b_b79k"
+    ... )
+    >>> tokenizer = open_clip.get_tokenizer("ViT-B-32")
+    >>> similarity_scores(img, "cat", model, preprocess, tokenizer)
+    [[21.813]]
+
+    using https://huggingface.co/docs/transformers/en/model_doc/clip
+    >>> from transformers import CLIPProcessor, CLIPModel
+    >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+    >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    >>> scores = similarity_scores(
+    ...     img, "cat", model, processor.image_processor, processor.tokenizer
+    ... )
+    [[21.813]]
+
+    image -> list of text
+    >>> similarity_scores(img, ["cat", "dog"], model, preprocess, tokenizer)
+    [[21.813, 35.313]]
+
+    list of images -> text
+    >>> similarity_scores([img1, img2], "cat", model, preprocess, tokenizer)
+    [[21.813], [83.123]]
+
+    list of images -> list of text
+    >>> similarity_scores([img1, img2], ["cat", "dog"], model, preprocess, tokenizer)
+    [[21.813, 35.313], [83.123, 34.843]]
+
+    list of images -> list of images
+    >>> similarity_scores([img1, img2], None, model, preprocess, tokenizer)
+    [[94.189, 37.092]]
+
+    list of text -> list of text
+    >>> similarity_scores(None, ["cat", "dog"], model, preprocess, tokenizer)
+    [[67.334, 23.588]]
+
+    text -> list of images
+    >>> similarity_scores([img1, img2], "cat", ..., image_to_text=False)
+    [[19.708, 19.842]]
+
+    show scores as softmax probabilities
+    >>> similarity_scores(img, ["cat", "dog"], ..., prob=True)
+    [[0.423, 0.577]]
+    """
+
+    with torch.no_grad():
+        if images is not None:
+            encoder = _get_encoder(model, "image")
+            image_features = convert_images(
+                images, transform=preprocess, encoder=encoder
+            )
+            image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
+
+        if text is not None:
+            encoder = _get_encoder(model, "text")
+            text_features = convert_text(text, tokenizer, encoder=encoder)
+            text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
+
+        if images is not None and text is not None:
+            if image_to_text:
+                logits = 100.0 * image_features @ text_features.T  # type: ignore[operator,union-attr]
+            else:
+                logits = 100.0 * text_features @ image_features.T  # type: ignore[operator,union-attr]
+        elif images is not None:
+            logits = 100.0 * image_features @ image_features.T  # type: ignore[operator,union-attr]
+        elif text is not None:
+            logits = 100.0 * text_features @ text_features.T  # type: ignore[operator,union-attr]
+        else:
+            raise ValueError(
+                "Error calculating CLIP similarity - "
+                "provide at least one of images or text"
+            )
+
+        if prob:
+            scores = logits.softmax(dim=1)
+        else:
+            scores = logits
+
+    return scores.tolist()

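Putting the new module together, a usage sketch based on the open_clip example from the docstring above (model weights download at runtime; the image path and printed scores are illustrative):

```python
# Assumes the optional dependencies from `pip install 'datachain[cv]'` plus open_clip.
import open_clip
from PIL import Image

from datachain.lib.clip import similarity_scores

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

img = Image.open("cat.jpg")  # any PIL image

# Softmax probability of the image matching each label.
probs = similarity_scores(img, ["a cat", "a dog"], model, preprocess, tokenizer, prob=True)
print(probs)  # e.g. [[0.98, 0.02]]
```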