datachain 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -65,7 +65,7 @@ from datachain.listing import Listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
-from datachain.sql.types import DateTime, SQLType, String
+from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
 from datachain.storage import Storage, StorageStatus, StorageURI
 from datachain.utils import (
     DataChainDir,
@@ -714,7 +714,22 @@ class Catalog:
         source_metastore = self.metastore.clone(client.uri)
         source_warehouse = self.warehouse.clone()

-        columns = self.warehouse.schema.dataset_row_cls.file_columns()
+        columns = [
+            Column("vtype", String),
+            Column("dir_type", Int),
+            Column("parent", String),
+            Column("name", String),
+            Column("etag", String),
+            Column("version", String),
+            Column("is_latest", Boolean),
+            Column("last_modified", DateTime(timezone=True)),
+            Column("size", Int64),
+            Column("owner_name", String),
+            Column("owner_id", String),
+            Column("location", JSON),
+            Column("source", String),
+        ]
+
         if skip_indexing:
             source_metastore.create_storage_if_not_registered(client.uri)
             storage = source_metastore.get_storage(client.uri)
@@ -20,8 +20,6 @@ if TYPE_CHECKING:

 logger = logging.getLogger("datachain")

-RANDOM_BITS = 63  # size of the random integer field
-
 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time


@@ -14,7 +14,7 @@ from sqlalchemy.sql.expression import null, true

 from datachain.node import DirType
 from datachain.sql.functions import path
-from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
+from datachain.sql.types import Int, SQLType, UInt64

 if TYPE_CHECKING:
     from sqlalchemy import Engine
@@ -137,7 +137,7 @@ class DataTable:
         self.name: str = name
         self.engine = engine
         self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
-        self.column_types = column_types
+        self.column_types: dict[str, SQLType] = column_types or {}

     @staticmethod
     def copy_column(column: sa.Column):
@@ -186,12 +186,12 @@ class DataTable:
         # Grab it from metadata instead.
         table = self.metadata.tables[self.name]

+        column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
         # adjusting types for custom columns to be instances of SQLType if possible
-        if self.column_types:
-            for c in table.columns:
-                if c.name in self.column_types:
-                    t = self.column_types[c.name]
-                    c.type = t() if inspect.isclass(t) else t
+        for c in table.columns:
+            if c.name in column_types:
+                t = column_types[c.name]
+                c.type = t() if inspect.isclass(t) else t
         return table

     @property
@@ -234,26 +234,9 @@ class DataTable:
     def sys_columns():
         return [
             sa.Column("id", Int, primary_key=True),
-            sa.Column("random", Int64, nullable=False, default=f.random()),
-        ]
-
-    @classmethod
-    def file_columns(cls) -> list[sa.Column]:
-        return [
-            *cls.sys_columns(),
-            sa.Column("vtype", String, nullable=False, index=True),
-            sa.Column("dir_type", Int, index=True),
-            sa.Column("parent", String, index=True),
-            sa.Column("name", String, nullable=False, index=True),
-            sa.Column("etag", String),
-            sa.Column("version", String),
-            sa.Column("is_latest", Boolean),
-            sa.Column("last_modified", DateTime(timezone=True)),
-            sa.Column("size", Int64, nullable=False, index=True),
-            sa.Column("owner_name", String),
-            sa.Column("owner_id", String),
-            sa.Column("location", JSON),
-            sa.Column("source", String, nullable=False),
+            sa.Column(
+                "random", UInt64, nullable=False, server_default=f.abs(f.random())
+            ),
         ]

     def dir_expansion(self):
@@ -4,7 +4,6 @@ import logging
 import posixpath
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
-from random import getrandbits
 from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib.parse import urlparse

@@ -41,8 +40,6 @@ except ImportError:

 logger = logging.getLogger("datachain")

-RANDOM_BITS = 63  # size of the random integer field
-
 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time


@@ -408,10 +405,7 @@ class AbstractWarehouse(ABC, Serializable):

         def _prepare_entry(entry: Entry):
             assert entry.dir_type is not None
-            return attrs.asdict(entry) | {
-                "source": uri,
-                "random": getrandbits(RANDOM_BITS),
-            }
+            return attrs.asdict(entry) | {"source": uri}

         return [_prepare_entry(e) for e in entries]

datachain/lib/arrow.py CHANGED
@@ -3,21 +3,14 @@ from typing import TYPE_CHECKING, Optional

 from pyarrow.dataset import dataset

-from datachain.lib.feature import Feature
-from datachain.lib.file import File
+from datachain.lib.file import File, IndexedFile
+from datachain.lib.udf import Generator

 if TYPE_CHECKING:
     import pyarrow as pa


-class Source(Feature):
-    """File source info for tables."""
-
-    file: File
-    index: int
-
-
-class ArrowGenerator:
+class ArrowGenerator(Generator):
     def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
         """
         Generator for getting rows from tabular files.
@@ -27,16 +20,17 @@ class ArrowGenerator:
             schema : Optional pyarrow schema for validation.
            kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
+        super().__init__()
         self.schema = schema
         self.kwargs = kwargs

-    def __call__(self, file: File):
+    def process(self, file: File):
         path = file.get_path()
         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
         index = 0
         for record_batch in ds.to_batches():
             for record in record_batch.to_pylist():
-                source = Source(file=file, index=index)
+                source = IndexedFile(file=file, index=index)
                 yield [source, *record.values()]
                 index += 1

@@ -44,7 +38,7 @@ class ArrowGenerator:
 def schema_to_output(schema: "pa.Schema"):
     """Generate UDF output schema from pyarrow schema."""
     default_column = 0
-    output = {"source": Source}
+    output = {"source": IndexedFile}
     for field in schema:
         column = field.name.lower()
         column = re.sub("[^0-9a-z_]+", "", column)
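For orientation, a minimal sketch of how the reworked ArrowGenerator can be exercised directly; `f` stands in for a datachain.lib.file.File pointing at a tabular object, and in normal use the class runs as a generator UDF inside a chain rather than being called by hand:

from datachain.lib.arrow import ArrowGenerator

gen = ArrowGenerator()           # extra kwargs are forwarded to pyarrow.dataset.dataset
for row in gen.process(f):       # process() replaces the old __call__ entry point
    source, *values = row        # source is an IndexedFile(file=f, index=...)
    print(source.index, values)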
datachain/lib/clip.py ADDED
@@ -0,0 +1,151 @@
+import inspect
+from typing import Any, Callable, Literal, Union
+
+from datachain.lib.image import convert_images
+from datachain.lib.text import convert_text
+
+try:
+    import torch
+    from PIL import Image
+    from transformers.modeling_utils import PreTrainedModel
+except ImportError as exc:
+    raise ImportError(
+        "Missing dependencies for computer vision:\n"
+        "To install run:\n\n"
+        " pip install 'datachain[cv]'\n"
+    ) from exc
+
+
+def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
+    # Check for transformers CLIPModel
+    method_name = f"get_{type}_features"
+    if isinstance(model, PreTrainedModel) and (
+        hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name))
+    ):
+        method = getattr(model, method_name)
+        return lambda x: method(torch.tensor(x))
+
+    # Check for model from clip or open_clip library
+    method_name = f"encode_{type}"
+    if hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name)):
+        return getattr(model, method_name)
+
+    raise ValueError(
+        f"Error encoding {type}: "
+        "'model' must be a CLIP model from clip, open_clip, or transformers library."
+    )
+
+
+def similarity_scores(
+    images: Union[None, Image.Image, list[Image.Image]],
+    text: Union[None, str, list[str]],
+    model: Any,
+    preprocess: Callable,
+    tokenizer: Callable,
+    prob: bool = False,
+    image_to_text: bool = True,
+) -> list[list[float]]:
+    """
+    Calculate CLIP similarity scores between one or more images and/or text.
+
+    Args:
+        images: Images to use as inputs.
+        text: Text to use as inputs.
+        model: Model from clip or open_clip packages.
+        preprocess: Image preprocessor to apply.
+        tokenizer: Text tokenizer.
+        prob: Compute softmax probabilities.
+        image_to_text: Whether to compute for image-to-text or text-to-image. Ignored if
+            only one of images or text provided.
+
+
+    Examples
+    --------
+
+    using https://github.com/openai/CLIP
+    >>> import clip
+    >>> model, preprocess = clip.load("ViT-B/32")
+    >>> similarity_scores(img, "cat", model, preprocess, clip.tokenize)
+    [[21.813]]
+
+    using https://github.com/mlfoundations/open_clip
+    >>> import open_clip
+    >>> model, _, preprocess = open_clip.create_model_and_transforms(
+    ...     "ViT-B-32", pretrained="laion2b_s34b_b79k"
+    ... )
+    >>> tokenizer = open_clip.get_tokenizer("ViT-B-32")
+    >>> similarity_scores(img, "cat", model, preprocess, tokenizer)
+    [[21.813]]
+
+    using https://huggingface.co/docs/transformers/en/model_doc/clip
+    >>> from transformers import CLIPProcessor, CLIPModel
+    >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+    >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    >>> scores = similarity_scores(
+    ...     img, "cat", model, processor.image_processor, processor.tokenizer
+    ... )
+    [[21.813]]
+
+    image -> list of text
+    >>> similarity_scores(img, ["cat", "dog"], model, preprocess, tokenizer)
+    [[21.813, 35.313]]
+
+    list of images -> text
+    >>> similarity_scores([img1, img2], "cat", model, preprocess, tokenizer)
+    [[21.813], [83.123]]
+
+    list of images -> list of text
+    >>> similarity_scores([img1, img2], ["cat", "dog"], model, preprocess, tokenizer)
+    [[21.813, 35.313], [83.123, 34.843]]
+
+    list of images -> list of images
+    >>> similarity_scores([img1, img2], None, model, preprocess, tokenizer)
+    [[94.189, 37.092]]
+
+    list of text -> list of text
+    >>> similarity_scores(None, ["cat", "dog"], model, preprocess, tokenizer)
+    [[67.334, 23.588]]
+
+    text -> list of images
+    >>> similarity_scores([img1, img2], "cat", ..., image_to_text=False)
+    [[19.708, 19.842]]
+
+    show scores as softmax probabilities
+    >>> similarity_scores(img, ["cat", "dog"], ..., prob=True)
+    [[0.423, 0.577]]
+    """
+
+    with torch.no_grad():
+        if images is not None:
+            encoder = _get_encoder(model, "image")
+            image_features = convert_images(
+                images, transform=preprocess, encoder=encoder
+            )
+            image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
+
+        if text is not None:
+            encoder = _get_encoder(model, "text")
+            text_features = convert_text(text, tokenizer, encoder=encoder)
+            text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
+
+        if images is not None and text is not None:
+            if image_to_text:
+                logits = 100.0 * image_features @ text_features.T  # type: ignore[operator,union-attr]
+            else:
+                logits = 100.0 * text_features @ image_features.T  # type: ignore[operator,union-attr]
+        elif images is not None:
+            logits = 100.0 * image_features @ image_features.T  # type: ignore[operator,union-attr]
+        elif text is not None:
+            logits = 100.0 * text_features @ text_features.T  # type: ignore[operator,union-attr]
+        else:
+            raise ValueError(
+                "Error calculating CLIP similarity - "
+                "provide at least one of images or text"
+            )
+
+        if prob:
+            scores = logits.softmax(dim=1)
+        else:
+            scores = logits
+
+    return scores.tolist()
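Pulling the docstring examples above into one hedged end-to-end sketch with open_clip; the checkpoint name and image paths are illustrative only:

import open_clip
from PIL import Image

from datachain.lib.clip import similarity_scores

# Load a CLIP model plus its matching preprocessor and tokenizer.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

images = [Image.open("img1.jpg"), Image.open("img2.jpg")]
# One row per image, one softmax probability per caption.
scores = similarity_scores(
    images, ["a cat", "a dog"], model, preprocess, tokenizer, prob=True
)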
datachain/lib/dc.py CHANGED
@@ -14,7 +14,7 @@ import sqlalchemy

 from datachain.lib.feature import Feature, FeatureType
 from datachain.lib.feature_utils import features_to_tuples
-from datachain.lib.file import File, get_file
+from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
@@ -437,8 +437,7 @@ class DataChain(DatasetQuery):

         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)

-        chain = DatasetQuery.add_signals(
-            self,
+        chain = self.add_signals(
             udf_obj.to_udf_wrapper(self._settings.batch),
             **self._settings.to_dict(),
         )
@@ -534,23 +533,23 @@ class DataChain(DatasetQuery):
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
-        name = self.name or "Unknown"
+        name = self.name or ""
+
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        params_schema = self.signals_schema.slice(sign.params)

-        params_feature = self.signals_schema.slice(sign.params)
-        udf = target_class(params_feature, sign.output_schema, func=sign.func)
-        udf.set_catalog(self.catalog)
-        return udf
+        return UDFBase._create(target_class, sign, params_schema, self.catalog)

     def _extend_features(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)

         new_schema = self.signals_schema.resolve(*args)
-        columns = new_schema.db_signals()
-        chain = super_func(*columns, **kwargs)
-        chain.signals_schema = new_schema
+        columns = [C(col) for col in new_schema.db_signals()]
+        res = super_func(*columns, **kwargs)
+        if isinstance(res, DataChain):
+            res.signals_schema = new_schema

-        return chain
+        return res

     @detach
     def select(self, *args: str) -> "Self":
@@ -703,6 +702,9 @@ class DataChain(DatasetQuery):
             right_on = on
             right_on_columns = on_columns

+        if self == right_ds:
+            right_ds = right_ds.clone(new_table=True)
+
         ops = [
             self.c(left) == right_ds.c(right)
             for left, right in zip(on_columns, right_on_columns)
@@ -778,11 +780,11 @@ class DataChain(DatasetQuery):
         from pyarrow import unify_schemas
         from pyarrow.dataset import dataset

-        from datachain.lib.arrow import ArrowGenerator, Source, schema_to_output
+        from datachain.lib.arrow import ArrowGenerator, schema_to_output

         schema = None
         if output:
-            output = {"source": Source} | output
+            output = {"source": IndexedFile} | output
         else:
             schemas = []
             for row in self.select("file").iterate():
@@ -795,7 +797,6 @@
             schema = unify_schemas(schemas)
             try:
                 output = schema_to_output(schema)
-                print(f"Inferred tabular data schema: {output}")
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e

@@ -897,15 +898,26 @@ class DataChain(DatasetQuery):
         >>> single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
         """
         session = Session.get(session)
-        dsr = cls.create_empty_record(session=session)
-
-        if to_insert is not None:
-            if not isinstance(to_insert, list):
-                to_insert = [to_insert]
-
-            for record in to_insert:
-                cls.insert_record(dsr, record, session=session)
+        catalog = session.catalog

+        name = session.generate_temp_dataset_name()
+        columns: tuple[sqlalchemy.Column[Any], ...] = tuple(
+            sqlalchemy.Column(name, typ)
+            for name, typ in File._datachain_column_types.items()
+        )
+        dsr = catalog.create_dataset(name, columns=columns)
+
+        if isinstance(to_insert, dict):
+            to_insert = [to_insert]
+        elif not to_insert:
+            to_insert = []
+
+        warehouse = catalog.warehouse
+        dr = warehouse.dataset_rows(dsr)
+        db = warehouse.db
+        insert_q = dr.get_table().insert()
+        for record in to_insert:
+            db.execute(insert_q.values(**record))
         return DataChain(name=dsr.name)

     def sum(self, fr: FeatureType):  # type: ignore[override]
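A short usage sketch of the reworked create_empty path above, based on its own docstring; it assumes DataChain.DEFAULT_FILE_RECORD matches the File column types the temporary dataset is created with:

from datachain.lib.dc import DataChain

# One placeholder row with the default file schema.
single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)

# No rows at all; the temporary dataset still gets the File columns.
empty = DataChain.create_empty([])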
@@ -919,37 +931,3 @@ class DataChain(DatasetQuery):

     def max(self, fr: FeatureType):  # type: ignore[override]
         return self._extend_features("max", fr)
-
-    @detach
-    def gen_random(self) -> "DataChain":
-        from random import getrandbits
-
-        from datachain.data_storage.warehouse import RANDOM_BITS
-
-        if "random" not in self.signals_schema.values:
-            chain = self.map(random=lambda: getrandbits(RANDOM_BITS), output=int).save()
-            return chain.select_except("random")
-
-        return self
-
-    @detach
-    def shuffle(self) -> "DataChain":
-        """Return results in deterministic random order."""
-        chain = self.gen_random()
-        return DatasetQuery.shuffle(chain)
-
-    @detach
-    def chunk(self, index: int, total: int) -> "DataChain":
-        """Split a query into smaller chunks for e.g. parallelization.
-
-        Examples:
-            >>> dc = DataChain(...)
-            >>> chunk_1 = dc._chunk(0, 2)
-            >>> chunk_2 = dc._chunk(1, 2)
-
-        Note:
-            Bear in mind that `index` is 0-indexed but `total` isn't.
-            Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
-        """
-        chain = self.gen_random()
-        return DatasetQuery.chunk(chain, index, total)
@@ -11,11 +11,10 @@ from datachain.lib.feature import (
     FeatureTypeNames,
     convert_type_to_datachain,
 )
-from datachain.lib.reader import FeatureReader
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import Column

-FeatureLike = Union[type["Feature"], FeatureReader, Column, str]
+FeatureLike = Union[type["Feature"], Column, str]

 AUTO_FEATURE_PREFIX = "_auto_fr"
 SUFFIX_SYMBOLS = string.digits + string.ascii_lowercase
datachain/lib/file.py CHANGED
@@ -282,3 +282,10 @@ def get_file(type: Literal["binary", "text", "image"] = "binary"):
         )

     return get_file_type
+
+
+class IndexedFile(Feature):
+    """File source info for tables."""
+
+    file: File
+    index: int
datachain/lib/image.py CHANGED
@@ -1,6 +1,5 @@
-import inspect
 from io import BytesIO
-from typing import Any, Callable, Optional
+from typing import Callable, Optional, Union

 from datachain.lib.file import File

@@ -14,8 +13,6 @@ except ImportError as exc:
         " pip install 'datachain[cv]'\n"
     ) from exc

-from datachain.lib.reader import FeatureReader
-

 class ImageFile(File):
     def get_value(self):
@@ -28,8 +25,8 @@ def convert_image(
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
-    open_clip_model: Optional[Any] = None,
-):
+    encoder: Optional[Callable] = None,
+) -> Union[Image.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.

@@ -37,8 +34,8 @@
         img (Image): PIL.Image object.
         mode (str): PIL.Image mode.
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
-        transform (Callable): Torchvision v1 or other transform to apply.
-        open_clip_model (Any): Encode image using model from open_clip library.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
     if mode:
         img = img.convert(mode)
@@ -46,86 +43,47 @@
         img = img.resize(size)
     if transform:
         img = transform(img)
-        if open_clip_model:
+
+        try:
+            from transformers.image_processing_utils import BaseImageProcessor
+
+            if isinstance(transform, BaseImageProcessor):
+                img = torch.tensor(img.pixel_values[0])  # type: ignore[assignment,attr-defined]
+        except ImportError:
+            pass
+        if encoder:
             img = img.unsqueeze(0)  # type: ignore[attr-defined]
-    if open_clip_model:
-        method_name = "encode_image"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                "Unable to render Image: 'open_clip_model' doesn't support"
-                f" '{method_name}()'"
-            )
-        img = open_clip_model.encode_image(img)
+    if encoder:
+        img = encoder(img)
     return img


-class ImageReader(FeatureReader):
-    def __init__(
-        self,
-        mode: str = "RGB",
-        size: Optional[tuple[int, int]] = None,
-        transform: Optional[Callable] = None,
-        open_clip_model: Any = None,
-    ):
-        """
-        Read and optionally transform an image.
-
-        All kwargs are passed to `convert_image()`.
-        """
-        self.mode = mode
-        self.size = size
-        self.transform = transform
-        self.open_clip_model = open_clip_model
-        super().__init__(ImageFile)
-
-    def __call__(self, img: Image.Image):
-        return convert_image(
-            img,
-            mode=self.mode,
-            size=self.size,
-            transform=self.transform,
-            open_clip_model=self.open_clip_model,
-        )
-
-
-def similarity_scores(
-    model: Any,
-    preprocess: Callable,
-    tokenizer: Callable,
-    image: Image.Image,
-    text: str,
-    prob: bool = False,
-) -> list[float]:
+def convert_images(
+    images: Union[Image.Image, list[Image.Image]],
+    mode: str = "RGB",
+    size: Optional[tuple[int, int]] = None,
+    transform: Optional[Callable] = None,
+    encoder: Optional[Callable] = None,
+) -> Union[list[Image.Image], torch.Tensor]:
     """
-    Calculate CLIP similarity scores for one or more texts given an image.
+    Resize, transform, and otherwise convert one or more images.

     Args:
-        model: Model from clip or open_clip packages.
-        preprocess: Image preprocessing transforms.
-        tokenizer: Text tokenizer.
-        image: Image.
-        text: Text.
-        prob: Compute softmax probabilities across texts.
+        img (Image, list[Image]): PIL.Image object or list of objects.
+        mode (str): PIL.Image mode.
+        size (tuple[int, int]): Size in (width, height) pixels for resizing.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
+    if isinstance(images, Image.Image):
+        images = [images]

-    with torch.no_grad():
-        image = preprocess(image).unsqueeze(0)
-        text = tokenizer(text)
-
-        image_features = model.encode_image(image)
-        text_features = model.encode_text(text)
-
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        text_features /= text_features.norm(dim=-1, keepdim=True)
+    converted = [convert_image(img, mode, size, transform) for img in images]

-        logits_per_text = 100.0 * image_features @ text_features.T
+    if isinstance(converted[0], torch.Tensor):
+        converted = torch.stack(converted)  # type: ignore[assignment,arg-type]

-        if prob:
-            scores = logits_per_text.softmax(dim=1)
-        else:
-            scores = logits_per_text
+    if encoder:
+        converted = encoder(converted)

-    return scores[0].tolist()
+    return converted  # type: ignore[return-value]
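Closing out the image changes, a hedged sketch of the new convert_images helper, called the same way the clip module above calls it; the open_clip model and image paths are assumptions for illustration:

import open_clip
from PIL import Image

from datachain.lib.image import convert_images

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)

images = [Image.open("img1.jpg"), Image.open("img2.jpg")]
# Preprocess both images, stack them into one batch, and encode it;
# the result is a 2 x D tensor of image embeddings.
embeddings = convert_images(images, transform=preprocess, encoder=model.encode_image)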