datachain 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +17 -2
- datachain/data_storage/db_engine.py +0 -2
- datachain/data_storage/schema.py +10 -27
- datachain/data_storage/warehouse.py +1 -7
- datachain/lib/arrow.py +7 -13
- datachain/lib/clip.py +151 -0
- datachain/lib/dc.py +35 -57
- datachain/lib/feature_utils.py +1 -2
- datachain/lib/file.py +7 -0
- datachain/lib/image.py +37 -79
- datachain/lib/pytorch.py +4 -2
- datachain/lib/signal_schema.py +3 -4
- datachain/lib/text.py +18 -49
- datachain/lib/udf.py +58 -30
- datachain/lib/udf_signature.py +11 -10
- datachain/lib/utils.py +17 -0
- datachain/lib/webdataset.py +2 -2
- datachain/listing.py +0 -3
- datachain/query/dataset.py +63 -37
- datachain/query/dispatch.py +2 -2
- datachain/query/schema.py +1 -8
- datachain/query/udf.py +16 -18
- datachain/utils.py +28 -0
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/METADATA +2 -1
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/RECORD +29 -29
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
- datachain/lib/reader.py +0 -49
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.1.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -65,7 +65,7 @@ from datachain.listing import Listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
-from datachain.sql.types import DateTime, SQLType, String
+from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
 from datachain.storage import Storage, StorageStatus, StorageURI
 from datachain.utils import (
     DataChainDir,
@@ -714,7 +714,22 @@ class Catalog:
         source_metastore = self.metastore.clone(client.uri)
         source_warehouse = self.warehouse.clone()

-        columns =
+        columns = [
+            Column("vtype", String),
+            Column("dir_type", Int),
+            Column("parent", String),
+            Column("name", String),
+            Column("etag", String),
+            Column("version", String),
+            Column("is_latest", Boolean),
+            Column("last_modified", DateTime(timezone=True)),
+            Column("size", Int64),
+            Column("owner_name", String),
+            Column("owner_id", String),
+            Column("location", JSON),
+            Column("source", String),
+        ]
+
         if skip_indexing:
             source_metastore.create_storage_if_not_registered(client.uri)
             storage = source_metastore.get_storage(client.uri)
datachain/data_storage/schema.py
CHANGED
@@ -14,7 +14,7 @@ from sqlalchemy.sql.expression import null, true

 from datachain.node import DirType
 from datachain.sql.functions import path
-from datachain.sql.types import
+from datachain.sql.types import Int, SQLType, UInt64

 if TYPE_CHECKING:
     from sqlalchemy import Engine
@@ -137,7 +137,7 @@ class DataTable:
         self.name: str = name
         self.engine = engine
         self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
-        self.column_types = column_types
+        self.column_types: dict[str, SQLType] = column_types or {}

     @staticmethod
     def copy_column(column: sa.Column):
@@ -186,12 +186,12 @@ class DataTable:
         # Grab it from metadata instead.
         table = self.metadata.tables[self.name]

+        column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
         # adjusting types for custom columns to be instances of SQLType if possible
-
-
-
-
-            c.type = t() if inspect.isclass(t) else t
+        for c in table.columns:
+            if c.name in column_types:
+                t = column_types[c.name]
+                c.type = t() if inspect.isclass(t) else t
         return table

     @property
@@ -234,26 +234,9 @@
     def sys_columns():
         return [
            sa.Column("id", Int, primary_key=True),
-            sa.Column(
-
-
-    @classmethod
-    def file_columns(cls) -> list[sa.Column]:
-        return [
-            *cls.sys_columns(),
-            sa.Column("vtype", String, nullable=False, index=True),
-            sa.Column("dir_type", Int, index=True),
-            sa.Column("parent", String, index=True),
-            sa.Column("name", String, nullable=False, index=True),
-            sa.Column("etag", String),
-            sa.Column("version", String),
-            sa.Column("is_latest", Boolean),
-            sa.Column("last_modified", DateTime(timezone=True)),
-            sa.Column("size", Int64, nullable=False, index=True),
-            sa.Column("owner_name", String),
-            sa.Column("owner_id", String),
-            sa.Column("location", JSON),
-            sa.Column("source", String, nullable=False),
+            sa.Column(
+                "random", UInt64, nullable=False, server_default=f.abs(f.random())
+            ),
         ]

     def dir_expansion(self):
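Taken together with the warehouse.py hunks below, this change moves population of the sys column "random" from Python (getrandbits(RANDOM_BITS) per inserted row) into the database itself via server_default=f.abs(f.random()). A minimal standalone sketch of the same pattern in plain SQLAlchemy on SQLite; the table and column names here are illustrative, not datachain's:

import sqlalchemy as sa

# Illustrative only: a column whose value is assigned by the database at INSERT
# time, mirroring the server_default=f.abs(f.random()) pattern above.
metadata = sa.MetaData()
rows = sa.Table(
    "example_rows",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column(
        "random",
        sa.BigInteger,
        nullable=False,
        server_default=sa.text("(abs(random()))"),  # SQLite expression default
    ),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(rows.insert().values(id=1))  # "random" is filled by the database
    print(conn.execute(sa.select(rows)).all())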
datachain/data_storage/warehouse.py
CHANGED
@@ -4,7 +4,6 @@ import logging
 import posixpath
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
-from random import getrandbits
 from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib.parse import urlparse

@@ -41,8 +40,6 @@ except ImportError:

 logger = logging.getLogger("datachain")

-RANDOM_BITS = 63  # size of the random integer field
-
 SELECT_BATCH_SIZE = 100_000  # number of rows to fetch at a time


@@ -408,10 +405,7 @@ class AbstractWarehouse(ABC, Serializable):

         def _prepare_entry(entry: Entry):
             assert entry.dir_type is not None
-            return attrs.asdict(entry) | {
-                "source": uri,
-                "random": getrandbits(RANDOM_BITS),
-            }
+            return attrs.asdict(entry) | {"source": uri}

         return [_prepare_entry(e) for e in entries]
datachain/lib/arrow.py
CHANGED
@@ -3,21 +3,14 @@ from typing import TYPE_CHECKING, Optional

 from pyarrow.dataset import dataset

-from datachain.lib.
-from datachain.lib.
+from datachain.lib.file import File, IndexedFile
+from datachain.lib.udf import Generator

 if TYPE_CHECKING:
     import pyarrow as pa


-class
-    """File source info for tables."""
-
-    file: File
-    index: int
-
-
-class ArrowGenerator:
+class ArrowGenerator(Generator):
     def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
         """
         Generator for getting rows from tabular files.
@@ -27,16 +20,17 @@ class ArrowGenerator:
         schema : Optional pyarrow schema for validation.
         kwargs: Parameters to pass to pyarrow.dataset.dataset.
         """
+        super().__init__()
         self.schema = schema
         self.kwargs = kwargs

-    def
+    def process(self, file: File):
         path = file.get_path()
         ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
         index = 0
         for record_batch in ds.to_batches():
             for record in record_batch.to_pylist():
-                source =
+                source = IndexedFile(file=file, index=index)
                 yield [source, *record.values()]
                 index += 1

@@ -44,7 +38,7 @@ class ArrowGenerator:
 def schema_to_output(schema: "pa.Schema"):
     """Generate UDF output schema from pyarrow schema."""
     default_column = 0
-    output = {"source":
+    output = {"source": IndexedFile}
     for field in schema:
         column = field.name.lower()
         column = re.sub("[^0-9a-z_]+", "", column)
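With this change ArrowGenerator is a regular datachain generator UDF: rows come out of process(file), each prefixed with an IndexedFile(file=..., index=...) that points back to the source file and row index, and schema_to_output now emits IndexedFile as the "source" signal. A small sketch of the schema mapping; the column names are made up and the resulting mapping is only indicative:

import pyarrow as pa

from datachain.lib.arrow import schema_to_output

# Map a pyarrow schema to the UDF output signature; "source" becomes IndexedFile,
# which records which file and row index each record came from.
schema = pa.schema([("name", pa.string()), ("age", pa.int64())])
print(schema_to_output(schema))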
datachain/lib/clip.py
ADDED
@@ -0,0 +1,151 @@
+import inspect
+from typing import Any, Callable, Literal, Union
+
+from datachain.lib.image import convert_images
+from datachain.lib.text import convert_text
+
+try:
+    import torch
+    from PIL import Image
+    from transformers.modeling_utils import PreTrainedModel
+except ImportError as exc:
+    raise ImportError(
+        "Missing dependencies for computer vision:\n"
+        "To install run:\n\n"
+        "  pip install 'datachain[cv]'\n"
+    ) from exc
+
+
+def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
+    # Check for transformers CLIPModel
+    method_name = f"get_{type}_features"
+    if isinstance(model, PreTrainedModel) and (
+        hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name))
+    ):
+        method = getattr(model, method_name)
+        return lambda x: method(torch.tensor(x))
+
+    # Check for model from clip or open_clip library
+    method_name = f"encode_{type}"
+    if hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name)):
+        return getattr(model, method_name)
+
+    raise ValueError(
+        f"Error encoding {type}: "
+        "'model' must be a CLIP model from clip, open_clip, or transformers library."
+    )
+
+
+def similarity_scores(
+    images: Union[None, Image.Image, list[Image.Image]],
+    text: Union[None, str, list[str]],
+    model: Any,
+    preprocess: Callable,
+    tokenizer: Callable,
+    prob: bool = False,
+    image_to_text: bool = True,
+) -> list[list[float]]:
+    """
+    Calculate CLIP similarity scores between one or more images and/or text.
+
+    Args:
+        images: Images to use as inputs.
+        text: Text to use as inputs.
+        model: Model from clip or open_clip packages.
+        preprocess: Image preprocessor to apply.
+        tokenizer: Text tokenizer.
+        prob: Compute softmax probabilities.
+        image_to_text: Whether to compute for image-to-text or text-to-image. Ignored if
+            only one of images or text provided.
+
+
+    Examples
+    --------
+
+    using https://github.com/openai/CLIP
+    >>> import clip
+    >>> model, preprocess = clip.load("ViT-B/32")
+    >>> similarity_scores(img, "cat", model, preprocess, clip.tokenize)
+    [[21.813]]
+
+    using https://github.com/mlfoundations/open_clip
+    >>> import open_clip
+    >>> model, _, preprocess = open_clip.create_model_and_transforms(
+    ...     "ViT-B-32", pretrained="laion2b_s34b_b79k"
+    ... )
+    >>> tokenizer = open_clip.get_tokenizer("ViT-B-32")
+    >>> similarity_scores(img, "cat", model, preprocess, tokenizer)
+    [[21.813]]
+
+    using https://huggingface.co/docs/transformers/en/model_doc/clip
+    >>> from transformers import CLIPProcessor, CLIPModel
+    >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+    >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    >>> scores = similarity_scores(
+    ...     img, "cat", model, processor.image_processor, processor.tokenizer
+    ... )
+    [[21.813]]
+
+    image -> list of text
+    >>> similarity_scores(img, ["cat", "dog"], model, preprocess, tokenizer)
+    [[21.813, 35.313]]
+
+    list of images -> text
+    >>> similarity_scores([img1, img2], "cat", model, preprocess, tokenizer)
+    [[21.813], [83.123]]
+
+    list of images -> list of text
+    >>> similarity_scores([img1, img2], ["cat", "dog"], model, preprocess, tokenizer)
+    [[21.813, 35.313], [83.123, 34.843]]
+
+    list of images -> list of images
+    >>> similarity_scores([img1, img2], None, model, preprocess, tokenizer)
+    [[94.189, 37.092]]
+
+    list of text -> list of text
+    >>> similarity_scores(None, ["cat", "dog"], model, preprocess, tokenizer)
+    [[67.334, 23.588]]
+
+    text -> list of images
+    >>> similarity_scores([img1, img2], "cat", ..., image_to_text=False)
+    [[19.708, 19.842]]
+
+    show scores as softmax probabilities
+    >>> similarity_scores(img, ["cat", "dog"], ..., prob=True)
+    [[0.423, 0.577]]
+    """
+
+    with torch.no_grad():
+        if images is not None:
+            encoder = _get_encoder(model, "image")
+            image_features = convert_images(
+                images, transform=preprocess, encoder=encoder
+            )
+            image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
+
+        if text is not None:
+            encoder = _get_encoder(model, "text")
+            text_features = convert_text(text, tokenizer, encoder=encoder)
+            text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]
+
+        if images is not None and text is not None:
+            if image_to_text:
+                logits = 100.0 * image_features @ text_features.T  # type: ignore[operator,union-attr]
+            else:
+                logits = 100.0 * text_features @ image_features.T  # type: ignore[operator,union-attr]
+        elif images is not None:
+            logits = 100.0 * image_features @ image_features.T  # type: ignore[operator,union-attr]
+        elif text is not None:
+            logits = 100.0 * text_features @ text_features.T  # type: ignore[operator,union-attr]
+        else:
+            raise ValueError(
+                "Error calculating CLIP similarity - "
+                "provide at least one of images or text"
+            )
+
+    if prob:
+        scores = logits.softmax(dim=1)
+    else:
+        scores = logits
+
+    return scores.tolist()
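For reference, a minimal end-to-end sketch of the new helper using open_clip, mirroring the docstring examples above (the checkpoint tag and image path are placeholders):

import open_clip
from PIL import Image

from datachain.lib.clip import similarity_scores

# Placeholder inputs: any local image and any open_clip checkpoint will do.
img = Image.open("photo.jpg")
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

# One row per image, one score per text; prob=True returns softmax probabilities instead.
scores = similarity_scores(img, ["a cat", "a dog"], model, preprocess, tokenizer)
print(scores)  # e.g. [[21.8, 17.2]]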
datachain/lib/dc.py
CHANGED
@@ -14,7 +14,7 @@ import sqlalchemy

 from datachain.lib.feature import Feature, FeatureType
 from datachain.lib.feature_utils import features_to_tuples
-from datachain.lib.file import File, get_file
+from datachain.lib.file import File, IndexedFile, get_file
 from datachain.lib.meta_formats import read_meta, read_schema
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
@@ -437,8 +437,7 @@ class DataChain(DatasetQuery):

         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)

-        chain =
-            self,
+        chain = self.add_signals(
             udf_obj.to_udf_wrapper(self._settings.batch),
             **self._settings.to_dict(),
         )
@@ -534,23 +533,23 @@
         signal_map,
     ) -> UDFBase:
         is_generator = target_class.is_output_batched
-        name = self.name or "
+        name = self.name or ""
+
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
+        params_schema = self.signals_schema.slice(sign.params)

-
-        udf = target_class(params_feature, sign.output_schema, func=sign.func)
-        udf.set_catalog(self.catalog)
-        return udf
+        return UDFBase._create(target_class, sign, params_schema, self.catalog)

     def _extend_features(self, method_name, *args, **kwargs):
         super_func = getattr(super(), method_name)

         new_schema = self.signals_schema.resolve(*args)
-        columns = new_schema.db_signals()
-
-
+        columns = [C(col) for col in new_schema.db_signals()]
+        res = super_func(*columns, **kwargs)
+        if isinstance(res, DataChain):
+            res.signals_schema = new_schema

-        return
+        return res

     @detach
     def select(self, *args: str) -> "Self":
@@ -703,6 +702,9 @@
             right_on = on
             right_on_columns = on_columns

+        if self == right_ds:
+            right_ds = right_ds.clone(new_table=True)
+
         ops = [
             self.c(left) == right_ds.c(right)
             for left, right in zip(on_columns, right_on_columns)
@@ -778,11 +780,11 @@
         from pyarrow import unify_schemas
         from pyarrow.dataset import dataset

-        from datachain.lib.arrow import ArrowGenerator,
+        from datachain.lib.arrow import ArrowGenerator, schema_to_output

         schema = None
         if output:
-            output = {"source":
+            output = {"source": IndexedFile} | output
         else:
             schemas = []
             for row in self.select("file").iterate():
@@ -795,7 +797,6 @@
             schema = unify_schemas(schemas)
             try:
                 output = schema_to_output(schema)
-                print(f"Inferred tabular data schema: {output}")
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e

@@ -897,15 +898,26 @@
         >>> single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)
         """
         session = Session.get(session)
-
-
-        if to_insert is not None:
-            if not isinstance(to_insert, list):
-                to_insert = [to_insert]
-
-            for record in to_insert:
-                cls.insert_record(dsr, record, session=session)
+        catalog = session.catalog

+        name = session.generate_temp_dataset_name()
+        columns: tuple[sqlalchemy.Column[Any], ...] = tuple(
+            sqlalchemy.Column(name, typ)
+            for name, typ in File._datachain_column_types.items()
+        )
+        dsr = catalog.create_dataset(name, columns=columns)
+
+        if isinstance(to_insert, dict):
+            to_insert = [to_insert]
+        elif not to_insert:
+            to_insert = []
+
+        warehouse = catalog.warehouse
+        dr = warehouse.dataset_rows(dsr)
+        db = warehouse.db
+        insert_q = dr.get_table().insert()
+        for record in to_insert:
+            db.execute(insert_q.values(**record))
         return DataChain(name=dsr.name)

     def sum(self, fr: FeatureType):  # type: ignore[override]
@@ -919,37 +931,3 @@

     def max(self, fr: FeatureType):  # type: ignore[override]
         return self._extend_features("max", fr)
-
-    @detach
-    def gen_random(self) -> "DataChain":
-        from random import getrandbits
-
-        from datachain.data_storage.warehouse import RANDOM_BITS
-
-        if "random" not in self.signals_schema.values:
-            chain = self.map(random=lambda: getrandbits(RANDOM_BITS), output=int).save()
-            return chain.select_except("random")
-
-        return self
-
-    @detach
-    def shuffle(self) -> "DataChain":
-        """Return results in deterministic random order."""
-        chain = self.gen_random()
-        return DatasetQuery.shuffle(chain)
-
-    @detach
-    def chunk(self, index: int, total: int) -> "DataChain":
-        """Split a query into smaller chunks for e.g. parallelization.
-
-        Examples:
-            >>> dc = DataChain(...)
-            >>> chunk_1 = dc._chunk(0, 2)
-            >>> chunk_2 = dc._chunk(1, 2)
-
-        Note:
-            Bear in mind that `index` is 0-indexed but `total` isn't.
-            Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
-        """
-        chain = self.gen_random()
-        return DatasetQuery.chunk(chain, index, total)
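The create_empty rework above builds the temporary dataset itself (a temp name from the session plus the File column types) and inserts records straight through the warehouse, while the public entry point from the docstring stays the same. A minimal usage sketch based on that docstring; the resulting chain's contents depend on your local catalog:

from datachain.lib.dc import DataChain

# An empty chain with the default File columns, and one seeded with a single record.
empty = DataChain.create_empty([])
single_record = DataChain.create_empty(DataChain.DEFAULT_FILE_RECORD)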
datachain/lib/feature_utils.py
CHANGED
@@ -11,11 +11,10 @@ from datachain.lib.feature import (
     FeatureTypeNames,
     convert_type_to_datachain,
 )
-from datachain.lib.reader import FeatureReader
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import Column

-FeatureLike = Union[type["Feature"],
+FeatureLike = Union[type["Feature"], Column, str]

 AUTO_FEATURE_PREFIX = "_auto_fr"
 SUFFIX_SYMBOLS = string.digits + string.ascii_lowercase
datachain/lib/file.py
CHANGED
datachain/lib/image.py
CHANGED
@@ -1,6 +1,5 @@
-import inspect
 from io import BytesIO
-from typing import
+from typing import Callable, Optional, Union

 from datachain.lib.file import File

@@ -14,8 +13,6 @@ except ImportError as exc:
     " pip install 'datachain[cv]'\n"
 ) from exc

-from datachain.lib.reader import FeatureReader
-

 class ImageFile(File):
     def get_value(self):
@@ -28,8 +25,8 @@ def convert_image(
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
-
-):
+    encoder: Optional[Callable] = None,
+) -> Union[Image.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.

@@ -37,8 +34,8 @@ def convert_image(
         img (Image): PIL.Image object.
         mode (str): PIL.Image mode.
         size (tuple[int, int]): Size in (width, height) pixels for resizing.
-        transform (Callable): Torchvision
-
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
     if mode:
         img = img.convert(mode)
@@ -46,86 +43,47 @@ def convert_image(
         img = img.resize(size)
     if transform:
         img = transform(img)
-
+
+        try:
+            from transformers.image_processing_utils import BaseImageProcessor
+
+            if isinstance(transform, BaseImageProcessor):
+                img = torch.tensor(img.pixel_values[0])  # type: ignore[assignment,attr-defined]
+        except ImportError:
+            pass
+        if encoder:
             img = img.unsqueeze(0)  # type: ignore[attr-defined]
-    if
-
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                "Unable to render Image: 'open_clip_model' doesn't support"
-                f" '{method_name}()'"
-            )
-        img = open_clip_model.encode_image(img)
+    if encoder:
+        img = encoder(img)
     return img


-
-
-
-
-
-
-
-):
-    """
-    Read and optionally transform an image.
-
-    All kwargs are passed to `convert_image()`.
-    """
-    self.mode = mode
-    self.size = size
-    self.transform = transform
-    self.open_clip_model = open_clip_model
-    super().__init__(ImageFile)
-
-    def __call__(self, img: Image.Image):
-        return convert_image(
-            img,
-            mode=self.mode,
-            size=self.size,
-            transform=self.transform,
-            open_clip_model=self.open_clip_model,
-        )
-
-
-def similarity_scores(
-    model: Any,
-    preprocess: Callable,
-    tokenizer: Callable,
-    image: Image.Image,
-    text: str,
-    prob: bool = False,
-) -> list[float]:
+def convert_images(
+    images: Union[Image.Image, list[Image.Image]],
+    mode: str = "RGB",
+    size: Optional[tuple[int, int]] = None,
+    transform: Optional[Callable] = None,
+    encoder: Optional[Callable] = None,
+) -> Union[list[Image.Image], torch.Tensor]:
     """
-
+    Resize, transform, and otherwise convert one or more images.

     Args:
-
-
-
-
-
-        prob: Compute softmax probabilities across texts.
+        img (Image, list[Image]): PIL.Image object or list of objects.
+        mode (str): PIL.Image mode.
+        size (tuple[int, int]): Size in (width, height) pixels for resizing.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
+    if isinstance(images, Image.Image):
+        images = [images]

-
-    image = preprocess(image).unsqueeze(0)
-    text = tokenizer(text)
-
-    image_features = model.encode_image(image)
-    text_features = model.encode_text(text)
-
-    image_features /= image_features.norm(dim=-1, keepdim=True)
-    text_features /= text_features.norm(dim=-1, keepdim=True)
+    converted = [convert_image(img, mode, size, transform) for img in images]

-
+    if isinstance(converted[0], torch.Tensor):
+        converted = torch.stack(converted)  # type: ignore[assignment,arg-type]

-
-
-    else:
-        scores = logits_per_text
+    if encoder:
+        converted = encoder(converted)

-
+    return converted  # type: ignore[return-value]
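convert_image and the new convert_images now take a generic encoder callable in place of the old open_clip_model argument, and a list of images is stacked into a single tensor once the transform produces tensors. A short sketch with a torchvision transform; the file paths and transform choice are illustrative:

from PIL import Image
from torchvision import transforms

from datachain.lib.image import convert_images

# Placeholder inputs: any two local images.
imgs = [Image.open("a.jpg"), Image.open("b.jpg")]

batch = convert_images(imgs, size=(64, 64), transform=transforms.ToTensor())
print(batch.shape)  # torch.Size([2, 3, 64, 64]): stacked because the transform yields tensors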