datachain 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of datachain might be problematic.
Files changed (39)
  1. datachain/__init__.py +0 -4
  2. datachain/catalog/catalog.py +17 -2
  3. datachain/cli.py +8 -1
  4. datachain/data_storage/db_engine.py +0 -2
  5. datachain/data_storage/schema.py +15 -26
  6. datachain/data_storage/sqlite.py +3 -0
  7. datachain/data_storage/warehouse.py +1 -7
  8. datachain/lib/arrow.py +7 -13
  9. datachain/lib/cached_stream.py +3 -85
  10. datachain/lib/clip.py +151 -0
  11. datachain/lib/dc.py +41 -59
  12. datachain/lib/feature.py +5 -1
  13. datachain/lib/feature_registry.py +3 -2
  14. datachain/lib/feature_utils.py +1 -2
  15. datachain/lib/file.py +17 -24
  16. datachain/lib/image.py +37 -79
  17. datachain/lib/pytorch.py +4 -2
  18. datachain/lib/signal_schema.py +3 -4
  19. datachain/lib/text.py +18 -49
  20. datachain/lib/udf.py +64 -55
  21. datachain/lib/udf_signature.py +11 -10
  22. datachain/lib/utils.py +17 -0
  23. datachain/lib/webdataset.py +2 -2
  24. datachain/listing.py +0 -3
  25. datachain/query/dataset.py +66 -46
  26. datachain/query/dispatch.py +2 -2
  27. datachain/query/schema.py +1 -8
  28. datachain/query/udf.py +16 -18
  29. datachain/sql/sqlite/base.py +34 -2
  30. datachain/sql/sqlite/vector.py +13 -5
  31. datachain/utils.py +28 -0
  32. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/METADATA +3 -2
  33. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/RECORD +37 -38
  34. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/WHEEL +1 -1
  35. datachain/_version.py +0 -16
  36. datachain/lib/reader.py +0 -49
  37. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/LICENSE +0 -0
  38. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/entry_points.txt +0 -0
  39. {datachain-0.2.0.dist-info → datachain-0.2.2.dist-info}/top_level.txt +0 -0
datachain/__init__.py CHANGED
@@ -1,4 +0,0 @@
- try:
- from ._version import version as __version__
- except ImportError:
- __version__ = "UNKNOWN"
datachain/catalog/catalog.py CHANGED
@@ -65,7 +65,7 @@ from datachain.listing import Listing
  from datachain.node import DirType, Node, NodeWithPath
  from datachain.nodes_thread_pool import NodesThreadPool
  from datachain.remote.studio import StudioClient
- from datachain.sql.types import DateTime, SQLType, String
+ from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
  from datachain.storage import Storage, StorageStatus, StorageURI
  from datachain.utils import (
  DataChainDir,
@@ -714,7 +714,22 @@ class Catalog:
  source_metastore = self.metastore.clone(client.uri)
  source_warehouse = self.warehouse.clone()

- columns = self.warehouse.schema.dataset_row_cls.file_columns()
+ columns = [
+ Column("vtype", String),
+ Column("dir_type", Int),
+ Column("parent", String),
+ Column("name", String),
+ Column("etag", String),
+ Column("version", String),
+ Column("is_latest", Boolean),
+ Column("last_modified", DateTime(timezone=True)),
+ Column("size", Int64),
+ Column("owner_name", String),
+ Column("owner_id", String),
+ Column("location", JSON),
+ Column("source", String),
+ ]
+
  if skip_indexing:
  source_metastore.create_storage_if_not_registered(client.uri)
  storage = source_metastore.get_storage(client.uri)
datachain/cli.py CHANGED
@@ -5,13 +5,14 @@ import sys
  import traceback
  from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace
  from collections.abc import Iterable, Iterator, Mapping, Sequence
+ from importlib.metadata import PackageNotFoundError, version
  from itertools import chain
  from multiprocessing import freeze_support
  from typing import TYPE_CHECKING, Optional, Union

  import shtab

- from datachain import __version__, utils
+ from datachain import utils
  from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
  from datachain.utils import DataChainDir

@@ -96,6 +97,12 @@ def add_show_args(parser: ArgumentParser) -> None:


  def get_parser() -> ArgumentParser: # noqa: PLR0915
+ try:
+ __version__ = version("datachain")
+ except PackageNotFoundError:
+ # package is not installed
+ __version__ = "unknown"
+
  parser = ArgumentParser(
  description="DataChain: Wrangle unstructured AI data at scale", prog="datachain"
  )
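
Note: the CLI now resolves its version at runtime via `importlib.metadata` instead of importing a generated `_version` module. A minimal sketch of that pattern (the `get_version` helper name is mine, not part of the diff):

```python
from importlib.metadata import PackageNotFoundError, version


def get_version(dist_name: str = "datachain") -> str:
    # Ask the installed distribution for its version; this avoids relying on
    # a build-time generated _version.py inside the package.
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # e.g. running from a source checkout that was never pip-installed
        return "unknown"


print(get_version())
```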
datachain/data_storage/db_engine.py CHANGED
@@ -20,8 +20,6 @@ if TYPE_CHECKING:

  logger = logging.getLogger("datachain")

- RANDOM_BITS = 63 # size of the random integer field
-
  SELECT_BATCH_SIZE = 100_000 # number of rows to fetch at a time


datachain/data_storage/schema.py CHANGED
@@ -14,7 +14,7 @@ from sqlalchemy.sql.expression import null, true

  from datachain.node import DirType
  from datachain.sql.functions import path
- from datachain.sql.types import JSON, Boolean, DateTime, Int, Int64, SQLType, String
+ from datachain.sql.types import Int, SQLType, UInt64

  if TYPE_CHECKING:
  from sqlalchemy import Engine
@@ -31,7 +31,7 @@ def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
  """
  c_set: dict[str, sa.Column] = {}
  for c in columns:
- if ec := c_set.get(c.name, None):
+ if (ec := c_set.get(c.name, None)) is not None:
  if str(ec.type) != str(c.type):
  raise ValueError(
  f"conflicting types for column {c.name}:{c.type!s} and {ec.type!s}"
@@ -137,7 +137,7 @@ class DataTable:
  self.name: str = name
  self.engine = engine
  self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
- self.column_types = column_types
+ self.column_types: dict[str, SQLType] = column_types or {}

  @staticmethod
  def copy_column(column: sa.Column):
@@ -171,8 +171,8 @@
  ):
  # copy columns, since re-using the same objects from another table
  # may raise an error
- columns = [cls.copy_column(c) for c in columns if c.name != "id"]
- columns = [sa.Column("id", Int, primary_key=True), *columns]
+ columns = cls.sys_columns() + [cls.copy_column(c) for c in columns]
+ columns = dedup_columns(columns)

  if metadata is None:
  metadata = sa.MetaData()
@@ -186,12 +186,12 @@
  # Grab it from metadata instead.
  table = self.metadata.tables[self.name]

+ column_types = self.column_types | {c.name: c.type for c in self.sys_columns()}
  # adjusting types for custom columns to be instances of SQLType if possible
- if self.column_types:
- for c in table.columns:
- if c.name in self.column_types:
- t = self.column_types[c.name]
- c.type = t() if inspect.isclass(t) else t
+ for c in table.columns:
+ if c.name in column_types:
+ t = column_types[c.name]
+ c.type = t() if inspect.isclass(t) else t
  return table

  @property
@@ -230,24 +230,13 @@
  def delete(self):
  return self.apply_conditions(self.table.delete())

- @classmethod
- def file_columns(cls) -> list[sa.Column]:
+ @staticmethod
+ def sys_columns():
  return [
  sa.Column("id", Int, primary_key=True),
- sa.Column("random", Int64, nullable=False),
- sa.Column("vtype", String, nullable=False, index=True),
- sa.Column("dir_type", Int, index=True),
- sa.Column("parent", String, index=True),
- sa.Column("name", String, nullable=False, index=True),
- sa.Column("etag", String),
- sa.Column("version", String),
- sa.Column("is_latest", Boolean),
- sa.Column("last_modified", DateTime(timezone=True)),
- sa.Column("size", Int64, nullable=False, index=True),
- sa.Column("owner_name", String),
- sa.Column("owner_id", String),
- sa.Column("location", JSON),
- sa.Column("source", String, nullable=False),
+ sa.Column(
+ "random", UInt64, nullable=False, server_default=f.abs(f.random())
+ ),
  ]

  def dir_expansion(self):
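
The `sys_columns` change above also moves generation of the `random` value into the database (`server_default=f.abs(f.random())`) instead of computing `getrandbits()` in Python for every inserted row (see the matching removal in warehouse.py below). A rough standalone sketch of a server-side random default, using plain SQLAlchemy types and `sa.text` in place of datachain's custom `UInt64` type and function helpers:

```python
import sqlalchemy as sa

metadata = sa.MetaData()
rows = sa.Table(
    "rows",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    # The database fills this column on INSERT; SQLite wants expression
    # defaults wrapped in parentheses, hence the textual form here.
    sa.Column(
        "random",
        sa.BigInteger,
        nullable=False,
        server_default=sa.text("(abs(random()))"),
    ),
    sa.Column("name", sa.String),
)

engine = sa.create_engine("sqlite:///:memory:")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(sa.insert(rows), [{"name": "a"}, {"name": "b"}])
    print(conn.execute(sa.select(rows.c.name, rows.c.random)).all())
```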
datachain/data_storage/sqlite.py CHANGED
@@ -33,6 +33,7 @@ from datachain.data_storage.schema import (
  from datachain.dataset import DatasetRecord
  from datachain.error import DataChainError
  from datachain.sql.sqlite import create_user_defined_sql_functions, sqlite_dialect
+ from datachain.sql.sqlite.base import load_usearch_extension
  from datachain.sql.types import SQLType
  from datachain.storage import StorageURI
  from datachain.utils import DataChainDir
@@ -114,6 +115,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
  if os.environ.get("DEBUG_SHOW_SQL_QUERIES"):
  db.set_trace_callback(print)

+ load_usearch_extension(db)
+
  return cls(engine, MetaData(), db, db_file)
  except RuntimeError:
  raise DataChainError("Can't connect to SQLite DB") from None
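
`load_usearch_extension(db)` is called on the freshly created SQLite connection; its body lives in datachain/sql/sqlite/base.py and is not shown in the hunks above. For orientation, loading a SQLite extension with the stdlib driver generally looks like the sketch below (the function name and extension path are placeholders, not datachain's actual implementation):

```python
import sqlite3


def try_load_extension(db: sqlite3.Connection, ext_path: str) -> bool:
    """Best-effort loading of a SQLite extension (e.g. a vector-search one)."""
    try:
        db.enable_load_extension(True)  # not available in every Python build
        db.load_extension(ext_path)
        return True
    except (AttributeError, sqlite3.OperationalError):
        # extension loading disabled, or the extension file is missing
        return False
    finally:
        if hasattr(db, "enable_load_extension"):
            db.enable_load_extension(False)
```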
datachain/data_storage/warehouse.py CHANGED
@@ -4,7 +4,6 @@ import logging
  import posixpath
  from abc import ABC, abstractmethod
  from collections.abc import Generator, Iterable, Iterator, Sequence
- from random import getrandbits
  from typing import TYPE_CHECKING, Any, Optional, Union
  from urllib.parse import urlparse

@@ -41,8 +40,6 @@ except ImportError:

  logger = logging.getLogger("datachain")

- RANDOM_BITS = 63 # size of the random integer field
-
  SELECT_BATCH_SIZE = 100_000 # number of rows to fetch at a time


@@ -408,10 +405,7 @@ class AbstractWarehouse(ABC, Serializable):

  def _prepare_entry(entry: Entry):
  assert entry.dir_type is not None
- return attrs.asdict(entry) | {
- "source": uri,
- "random": getrandbits(RANDOM_BITS),
- }
+ return attrs.asdict(entry) | {"source": uri}

  return [_prepare_entry(e) for e in entries]

datachain/lib/arrow.py CHANGED
@@ -3,21 +3,14 @@ from typing import TYPE_CHECKING, Optional

  from pyarrow.dataset import dataset

- from datachain.lib.feature import Feature
- from datachain.lib.file import File
+ from datachain.lib.file import File, IndexedFile
+ from datachain.lib.udf import Generator

  if TYPE_CHECKING:
  import pyarrow as pa


- class Source(Feature):
- """File source info for tables."""
-
- file: File
- index: int
-
-
- class ArrowGenerator:
+ class ArrowGenerator(Generator):
  def __init__(self, schema: Optional["pa.Schema"] = None, **kwargs):
  """
  Generator for getting rows from tabular files.
@@ -27,16 +20,17 @@ class ArrowGenerator:
  schema : Optional pyarrow schema for validation.
  kwargs: Parameters to pass to pyarrow.dataset.dataset.
  """
+ super().__init__()
  self.schema = schema
  self.kwargs = kwargs

- def __call__(self, file: File):
+ def process(self, file: File):
  path = file.get_path()
  ds = dataset(path, filesystem=file.get_fs(), schema=self.schema, **self.kwargs)
  index = 0
  for record_batch in ds.to_batches():
  for record in record_batch.to_pylist():
- source = Source(file=file, index=index)
+ source = IndexedFile(file=file, index=index)
  yield [source, *record.values()]
  index += 1

@@ -44,7 +38,7 @@ class ArrowGenerator:
  def schema_to_output(schema: "pa.Schema"):
  """Generate UDF output schema from pyarrow schema."""
  default_column = 0
- output = {"source": Source}
+ output = {"source": IndexedFile}
  for field in schema:
  column = field.name.lower()
  column = re.sub("[^0-9a-z_]+", "", column)
datachain/lib/cached_stream.py CHANGED
@@ -1,6 +1,3 @@
- import os
- import shutil
- import tempfile
  from abc import ABC
  from contextlib import AbstractContextManager

@@ -8,9 +5,7 @@ from datachain.cache import UniqueId


  class AbstractCachedStream(AbstractContextManager, ABC):
- def __init__(self, stream, size, catalog, uid: UniqueId):
- self.stream = stream
- self.size = size
+ def __init__(self, catalog, uid: UniqueId):
  self.catalog = catalog
  self.uid = uid
  self.mode = "rb"
@@ -19,86 +14,9 @@ class AbstractCachedStream(AbstractContextManager, ABC):
  self.mode = mode


- class ProgressiveCacheStream(AbstractCachedStream):
- BUF_SIZE = 4096
-
- def __init__(self, stream, size, catalog, uid: UniqueId):
- super().__init__(stream, size, catalog, uid)
-
- self.target_path = self.catalog.cache.path_from_checksum(self.uid.get_hash())
- self.cached_file = None
-
- self.temp_file = None
- self.temp_file_pos = 0
-
- def __enter__(self):
- if os.path.exists(self.target_path):
- self.cached_file = open(self.target_path, mode=self.mode)
- return self.cached_file
-
- tmp_dir = self.catalog.cache.tmp_dir
- if not os.path.exists(tmp_dir):
- os.makedirs(tmp_dir)
- self.temp_file = tempfile.NamedTemporaryFile(
- prefix=str(self.uid.get_hash()), dir=tmp_dir, delete=False
- )
- return self
-
- def __exit__(self, *args):
- self.close()
-
- def read(self, size=-1):
- buf = self.stream.read(size)
- pos = self.stream.tell()
-
- if pos >= self.temp_file_pos:
- self._cache_catch_up(pos, buf)
-
- return buf
-
- def close(self):
- if self.cached_file:
- self.cached_file.close()
-
- if self.temp_file:
- if self.temp_file_pos < self.size:
- self._cache_catch_up(self.size)
-
- self.temp_file.close()
- if not os.path.exists(self.target_path):
- os.makedirs(os.path.dirname(self.target_path), exist_ok=True)
- shutil.move(self.temp_file.name, self.target_path)
-
- self.stream.close()
-
- def _cache_catch_up(self, pos_target, latest_buf=None):
- pos_to_restore = self.stream.tell()
- try:
- remainder = pos_target - self.temp_file_pos
- self.stream.seek(self.temp_file_pos)
- while remainder > 0:
- chunk_size = min(self.BUF_SIZE, remainder)
- buf = self.stream.read(chunk_size)
- self._cache_update(buf)
- remainder -= len(buf)
- finally:
- self.stream.seek(pos_to_restore)
-
- def _cache_update(self, buf):
- length = len(buf)
- self.temp_file.write(buf)
- self.temp_file_pos += length
-
- def seek(self, offset, whence=0):
- return self.stream.seek(offset, whence)
-
- def tell(self):
- return self.stream.tell()
-
-
  class PreCachedStream(AbstractCachedStream):
- def __init__(self, stream, size, catalog, uid: UniqueId):
- super().__init__(stream, size, catalog, uid)
+ def __init__(self, catalog, uid: UniqueId):
+ super().__init__(catalog, uid)
  self.client = self.catalog.get_client(self.uid.storage)
  self.cached_file = None

datachain/lib/clip.py ADDED
@@ -0,0 +1,151 @@
+ import inspect
+ from typing import Any, Callable, Literal, Union
+
+ from datachain.lib.image import convert_images
+ from datachain.lib.text import convert_text
+
+ try:
+ import torch
+ from PIL import Image
+ from transformers.modeling_utils import PreTrainedModel
+ except ImportError as exc:
+ raise ImportError(
+ "Missing dependencies for computer vision:\n"
+ "To install run:\n\n"
+ " pip install 'datachain[cv]'\n"
+ ) from exc
+
+
+ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
+ # Check for transformers CLIPModel
+ method_name = f"get_{type}_features"
+ if isinstance(model, PreTrainedModel) and (
+ hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name))
+ ):
+ method = getattr(model, method_name)
+ return lambda x: method(torch.tensor(x))
+
+ # Check for model from clip or open_clip library
+ method_name = f"encode_{type}"
+ if hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name)):
+ return getattr(model, method_name)
+
+ raise ValueError(
+ f"Error encoding {type}: "
+ "'model' must be a CLIP model from clip, open_clip, or transformers library."
+ )
+
+
+ def similarity_scores(
+ images: Union[None, Image.Image, list[Image.Image]],
+ text: Union[None, str, list[str]],
+ model: Any,
+ preprocess: Callable,
+ tokenizer: Callable,
+ prob: bool = False,
+ image_to_text: bool = True,
+ ) -> list[list[float]]:
+ """
+ Calculate CLIP similarity scores between one or more images and/or text.
+
+ Args:
+ images: Images to use as inputs.
+ text: Text to use as inputs.
+ model: Model from clip or open_clip packages.
+ preprocess: Image preprocessor to apply.
+ tokenizer: Text tokenizer.
+ prob: Compute softmax probabilities.
+ image_to_text: Whether to compute for image-to-text or text-to-image. Ignored if
+ only one of images or text provided.
+
+
+ Examples
+ --------
+
+ using https://github.com/openai/CLIP
+ >>> import clip
+ >>> model, preprocess = clip.load("ViT-B/32")
+ >>> similarity_scores(img, "cat", model, preprocess, clip.tokenize)
+ [[21.813]]
+
+ using https://github.com/mlfoundations/open_clip
+ >>> import open_clip
+ >>> model, _, preprocess = open_clip.create_model_and_transforms(
+ ... "ViT-B-32", pretrained="laion2b_s34b_b79k"
+ ... )
+ >>> tokenizer = open_clip.get_tokenizer("ViT-B-32")
+ >>> similarity_scores(img, "cat", model, preprocess, tokenizer)
+ [[21.813]]
+
+ using https://huggingface.co/docs/transformers/en/model_doc/clip
+ >>> from transformers import CLIPProcessor, CLIPModel
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ >>> scores = similarity_scores(
+ ... img, "cat", model, processor.image_processor, processor.tokenizer
+ ... )
+ [[21.813]]
+
+ image -> list of text
+ >>> similarity_scores(img, ["cat", "dog"], model, preprocess, tokenizer)
+ [[21.813, 35.313]]
+
+ list of images -> text
+ >>> similarity_scores([img1, img2], "cat", model, preprocess, tokenizer)
+ [[21.813], [83.123]]
+
+ list of images -> list of text
+ >>> similarity_scores([img1, img2], ["cat", "dog"], model, preprocess, tokenizer)
+ [[21.813, 35.313], [83.123, 34.843]]
+
+ list of images -> list of images
+ >>> similarity_scores([img1, img2], None, model, preprocess, tokenizer)
+ [[94.189, 37.092]]
+
+ list of text -> list of text
+ >>> similarity_scores(None, ["cat", "dog"], model, preprocess, tokenizer)
+ [[67.334, 23.588]]
+
+ text -> list of images
+ >>> similarity_scores([img1, img2], "cat", ..., image_to_text=False)
+ [[19.708, 19.842]]
+
+ show scores as softmax probabilities
+ >>> similarity_scores(img, ["cat", "dog"], ..., prob=True)
+ [[0.423, 0.577]]
+ """
+
+ with torch.no_grad():
+ if images is not None:
+ encoder = _get_encoder(model, "image")
+ image_features = convert_images(
+ images, transform=preprocess, encoder=encoder
+ )
+ image_features /= image_features.norm(dim=-1, keepdim=True) # type: ignore[union-attr]
+
+ if text is not None:
+ encoder = _get_encoder(model, "text")
+ text_features = convert_text(text, tokenizer, encoder=encoder)
+ text_features /= text_features.norm(dim=-1, keepdim=True) # type: ignore[union-attr]
+
+ if images is not None and text is not None:
+ if image_to_text:
+ logits = 100.0 * image_features @ text_features.T # type: ignore[operator,union-attr]
+ else:
+ logits = 100.0 * text_features @ image_features.T # type: ignore[operator,union-attr]
+ elif images is not None:
+ logits = 100.0 * image_features @ image_features.T # type: ignore[operator,union-attr]
+ elif text is not None:
+ logits = 100.0 * text_features @ text_features.T # type: ignore[operator,union-attr]
+ else:
+ raise ValueError(
+ "Error calculating CLIP similarity - "
+ "provide at least one of images or text"
+ )
+
+ if prob:
+ scores = logits.softmax(dim=1)
+ else:
+ scores = logits
+
+ return scores.tolist()
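
For context, a quick usage sketch of the new `similarity_scores` helper, following the open_clip example from its docstring (the image path and printed numbers are placeholders):

```python
import open_clip
from PIL import Image

from datachain.lib.clip import similarity_scores

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

img = Image.open("cat.jpg")  # placeholder image

# One image scored against two captions; prob=True turns the raw logits
# into softmax probabilities over the captions.
scores = similarity_scores(
    img, ["a cat", "a dog"], model, preprocess, tokenizer, prob=True
)
print(scores)  # e.g. [[0.98, 0.02]]
```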