datachain 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (51)
  1. datachain/__init__.py +17 -8
  2. datachain/catalog/catalog.py +5 -5
  3. datachain/cli.py +0 -2
  4. datachain/data_storage/schema.py +5 -5
  5. datachain/data_storage/sqlite.py +1 -1
  6. datachain/data_storage/warehouse.py +7 -7
  7. datachain/lib/arrow.py +25 -8
  8. datachain/lib/clip.py +6 -11
  9. datachain/lib/convert/__init__.py +0 -0
  10. datachain/lib/convert/flatten.py +67 -0
  11. datachain/lib/convert/type_converter.py +96 -0
  12. datachain/lib/convert/unflatten.py +69 -0
  13. datachain/lib/convert/values_to_tuples.py +85 -0
  14. datachain/lib/data_model.py +74 -0
  15. datachain/lib/dc.py +225 -168
  16. datachain/lib/file.py +41 -41
  17. datachain/lib/gpt4_vision.py +1 -9
  18. datachain/lib/hf_image_to_text.py +9 -17
  19. datachain/lib/hf_pipeline.py +4 -12
  20. datachain/lib/image.py +2 -18
  21. datachain/lib/image_transform.py +0 -1
  22. datachain/lib/iptc_exif_xmp.py +8 -15
  23. datachain/lib/meta_formats.py +1 -5
  24. datachain/lib/model_store.py +77 -0
  25. datachain/lib/pytorch.py +9 -21
  26. datachain/lib/signal_schema.py +139 -60
  27. datachain/lib/text.py +5 -16
  28. datachain/lib/udf.py +114 -30
  29. datachain/lib/udf_signature.py +5 -5
  30. datachain/lib/webdataset.py +3 -3
  31. datachain/lib/webdataset_laion.py +2 -3
  32. datachain/node.py +4 -4
  33. datachain/query/batch.py +1 -1
  34. datachain/query/dataset.py +51 -178
  35. datachain/query/dispatch.py +43 -30
  36. datachain/query/udf.py +46 -26
  37. datachain/remote/studio.py +1 -9
  38. datachain/torch/__init__.py +21 -0
  39. datachain/utils.py +39 -0
  40. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/METADATA +14 -12
  41. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/RECORD +45 -43
  42. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/WHEEL +1 -1
  43. datachain/image/__init__.py +0 -3
  44. datachain/lib/cached_stream.py +0 -38
  45. datachain/lib/claude.py +0 -69
  46. datachain/lib/feature.py +0 -412
  47. datachain/lib/feature_registry.py +0 -51
  48. datachain/lib/feature_utils.py +0 -154
  49. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/LICENSE +0 -0
  50. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/entry_points.txt +0 -0
  51. {datachain-0.2.9.dist-info → datachain-0.2.11.dist-info}/top_level.txt +0 -0
datachain/lib/file.py CHANGED
@@ -1,18 +1,22 @@
+import io
 import json
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from datetime import datetime
+from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
 from urllib.parse import unquote, urlparse
 from urllib.request import url2pathname
 
+from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from fsspec.implementations.local import LocalFileSystem
+from PIL import Image
 from pydantic import Field, field_validator
 
 from datachain.cache import UniqueId
 from datachain.client.fileslice import FileSlice
-from datachain.lib.cached_stream import PreCachedStream, PreDownloadStream
-from datachain.lib.feature import Feature
+from datachain.lib.data_model import DataModel, FileBasic
 from datachain.lib.utils import DataChainError
 from datachain.sql.types import JSON, Int, String
 from datachain.utils import TIME_ZERO
@@ -21,20 +25,6 @@ if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
 
-class FileFeature(Feature):
-    _is_file = True
-
-    def open(self):
-        raise NotImplementedError
-
-    def read(self):
-        with self.open() as stream:
-            return stream.read()
-
-    def get_value(self):
-        return self.read()
-
-
 class VFileError(DataChainError):
     def __init__(self, file: "File", message: str, vtype: str = ""):
         type_ = f" of vtype '{vtype}'" if vtype else ""
@@ -110,7 +100,7 @@ class VFileRegistry:
         return reader.open(file, location)
 
 
-class File(FileFeature):
+class File(FileBasic):
     source: str = Field(default="")
     parent: str = Field(default="")
     name: str
@@ -178,24 +168,33 @@ class File(FileFeature):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self._stream = None
         self._catalog = None
         self._caching_enabled = False
 
+    @contextmanager
     def open(self):
-        if self._stream is None:
-            raise FileError(self, "stream is not set")
-
         if self.location:
-            return VFileRegistry.resolve(self, self.location)
-
-        return self._stream
-
-    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
+            with VFileRegistry.resolve(self, self.location) as f:
+                yield f
+
+        uid = self.get_uid()
+        client = self._catalog.get_client(self.source)
+        if self._caching_enabled:
+            client.download(uid, callback=self._download_cb)
+        with client.open_object(
+            uid, use_cache=self._caching_enabled, cb=self._download_cb
+        ) as f:
+            yield f
+
+    def _set_stream(
+        self,
+        catalog: "Catalog",
+        caching_enabled: bool = False,
+        download_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
         self._catalog = catalog
-        stream_class = PreCachedStream if caching_enabled else PreDownloadStream
-        self._stream = stream_class(self._catalog, self.get_uid())
         self._caching_enabled = caching_enabled
+        self._download_cb = download_cb
 
     def get_uid(self) -> UniqueId:
         dump = self.model_dump()
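With this change, File.open() becomes a generator-based context manager instead of returning a pre-built stream object, and _set_stream() only records the catalog, caching flag, and download callback. A minimal usage sketch (hedged: _set_stream() is internal wiring normally done by DataChain itself, and the field values are illustrative):

    from datachain.catalog import get_catalog
    from datachain.lib.file import File

    file = File(source="s3://my-bucket", parent="images", name="cat.jpg")
    file._set_stream(get_catalog(), caching_enabled=True)  # internal wiring

    # open() now yields the underlying stream and closes it on exit
    with file.open() as f:
        data = f.read()  # bytes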
@@ -239,22 +238,23 @@ class File(FileFeature):
 
 
 class TextFile(File):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self._stream = None
+    @contextmanager
+    def open(self):
+        with super().open() as binary:
+            yield io.TextIOWrapper(binary)
+
 
-    def _set_stream(self, catalog: "Catalog", caching_enabled: bool = False) -> None:
-        super()._set_stream(catalog, caching_enabled)
-        self._stream.set_mode("r")
+class ImageFile(File):
+    def get_value(self):
+        value = super().get_value()
+        return Image.open(BytesIO(value))
 
 
-def get_file(type: Literal["binary", "text", "image"] = "binary"):
-    file = File
-    if type == "text":
+def get_file(type_: Literal["binary", "text", "image"] = "binary"):
+    file: type[File] = File
+    if type_ == "text":
         file = TextFile
-    elif type == "image":
-        from datachain.lib.image import ImageFile
-
+    elif type_ == "image":
         file = ImageFile  # type: ignore[assignment]
 
     def get_file_type(
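TextFile and ImageFile now build directly on the binary File.open(): TextFile wraps the stream in io.TextIOWrapper to yield text, and ImageFile decodes the bytes with PIL. A short sketch under the same assumptions as the previous example:

    from datachain.lib.file import TextFile

    text_file = TextFile(source="s3://my-bucket", parent="docs", name="notes.txt")
    text_file._set_stream(get_catalog())  # internal wiring, as above

    with text_file.open() as f:
        first_line = f.readline()  # str, decoded by io.TextIOWrapper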
@@ -281,7 +281,7 @@ def get_file(type: Literal["binary", "text", "image"] = "binary"):
     return get_file_type
 
 
-class IndexedFile(Feature):
+class IndexedFile(DataModel):
     """File source info for tables."""
 
     file: File
datachain/lib/gpt4_vision.py CHANGED
@@ -3,15 +3,7 @@ import io
 import os
 
 import requests
-
-try:
-    from PIL import Image, ImageOps, UnidentifiedImageError
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependency Pillow for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc
+from PIL import Image, ImageOps, UnidentifiedImageError
 
 from datachain.query import Object, udf
 from datachain.sql.types import String
datachain/lib/hf_image_to_text.py CHANGED
@@ -1,20 +1,12 @@
-try:
-    import numpy as np
-    import torch
-    from PIL import Image, ImageOps, UnidentifiedImageError
-    from transformers import (
-        AutoProcessor,
-        Blip2ForConditionalGeneration,
-        Blip2Processor,
-        LlavaForConditionalGeneration,
-    )
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependencies for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc
-
+import numpy as np
+import torch
+from PIL import Image, ImageOps, UnidentifiedImageError
+from transformers import (
+    AutoProcessor,
+    Blip2ForConditionalGeneration,
+    Blip2Processor,
+    LlavaForConditionalGeneration,
+)
 
 from datachain.query import Object, udf
 from datachain.sql.types import String
datachain/lib/hf_pipeline.py CHANGED
@@ -1,22 +1,14 @@
 import json
 
+from PIL import (
+    Image,
+    UnidentifiedImageError,
+)
 from transformers import pipeline
 
 from datachain.query import Object, udf
 from datachain.sql.types import JSON, String
 
-try:
-    from PIL import (
-        Image,
-        UnidentifiedImageError,
-    )
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependency Pillow for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc
-
 
 def read_image(raw):
     try:
datachain/lib/image.py CHANGED
@@ -1,23 +1,7 @@
-from io import BytesIO
 from typing import Callable, Optional, Union
 
-from datachain.lib.file import File
-
-try:
-    import torch
-    from PIL import Image
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependencies for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc
-
-
-class ImageFile(File):
-    def get_value(self):
-        value = super().get_value()
-        return Image.open(BytesIO(value))
+import torch
+from PIL import Image
 
 
 def convert_image(
datachain/lib/image_transform.py CHANGED
@@ -66,7 +66,6 @@ class ImageTransform:
     ):
         # Build a dict from row contents
        record = dict(zip(DatasetRow.schema.keys(), args))
-        del record["random"]  # random will be populated automatically
        record["is_latest"] = record["is_latest"] > 0  # needs to be a bool
 
        # yield same row back
datachain/lib/iptc_exif_xmp.py CHANGED
@@ -1,23 +1,16 @@
 import json
 
+from PIL import (
+    ExifTags,
+    Image,
+    IptcImagePlugin,
+    TiffImagePlugin,
+    UnidentifiedImageError,
+)
+
 from datachain.query import Object, udf
 from datachain.sql.types import JSON, String
 
-try:
-    from PIL import (
-        ExifTags,
-        Image,
-        IptcImagePlugin,
-        TiffImagePlugin,
-        UnidentifiedImageError,
-    )
-except ImportError as exc:
-    raise ImportError(
-        "Missing dependency Pillow for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    ) from exc
-
 
 def encode_image(raw):
     try:
datachain/lib/meta_formats.py CHANGED
@@ -13,11 +13,8 @@ from typing import Any, Callable
 import jmespath as jsp
 from pydantic import ValidationError
 
-from datachain.lib.feature_utils import pydantic_to_feature  # noqa: F401
 from datachain.lib.file import File
 
-# from datachain.lib.dc import C, DataChain
-
 
 def generate_uuid():
     return uuid.uuid4()  # Generates a random UUID.
@@ -89,7 +86,6 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
     except subprocess.CalledProcessError as e:
         model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
     print(f"{model_output}")
-    print("\n" + f"spec=pydantic_to_feature({model_name})" + "\n")
     return model_output
 
 
@@ -131,7 +127,7 @@ def read_meta(  # noqa: C901
 
     if show_schema:
         print(f"{model_output}")
-    # Below 'spec' should be a dynamically converted Feature from Pydantic datamodel
+    # Below 'spec' should be a dynamically converted DataModel from Pydantic
     if not spec:
         local_vars: dict[str, Any] = {}
         exec(model_output, globals(), local_vars)  # noqa: S102
datachain/lib/model_store.py ADDED
@@ -0,0 +1,77 @@
+import logging
+from typing import ClassVar, Optional
+
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class ModelStore:
+    store: ClassVar[dict[str, dict[int, type[BaseModel]]]] = {}
+
+    @classmethod
+    def get_version(cls, model: type[BaseModel]) -> int:
+        if not hasattr(model, "_version"):
+            return 0
+        return model._version
+
+    @classmethod
+    def get_name(cls, model) -> str:
+        if (version := cls.get_version(model)) > 0:
+            return f"{model.__name__}@v{version}"
+        return model.__name__
+
+    @classmethod
+    def add(cls, fr: type):
+        if (model := ModelStore.to_pydantic(fr)) is None:
+            return
+
+        name = model.__name__
+        if name not in cls.store:
+            cls.store[name] = {}
+        version = ModelStore.get_version(model)
+        cls.store[name][version] = model
+
+        for f_info in model.model_fields.values():
+            if (anno := ModelStore.to_pydantic(f_info.annotation)) is not None:
+                cls.add(anno)
+
+    @classmethod
+    def get(cls, name: str, version: Optional[int] = None) -> Optional[type]:
+        class_dict = cls.store.get(name, None)
+        if class_dict is None:
+            return None
+        if version is None:
+            max_ver = max(class_dict.keys(), default=None)
+            if max_ver is None:
+                return None
+            return class_dict[max_ver]
+        return class_dict.get(version, None)
+
+    @classmethod
+    def parse_name_version(cls, fullname: str) -> tuple[str, int]:
+        name = fullname
+        version = 0
+
+        if "@" in fullname:
+            name, version_str = fullname.split("@")
+            if version_str.strip() != "":
+                version = int(version_str[1:])
+
+        return name, version
+
+    @classmethod
+    def remove(cls, fr: type) -> None:
+        version = fr._version  # type: ignore[attr-defined]
+        if fr.__name__ in cls.store and version in cls.store[fr.__name__]:
+            del cls.store[fr.__name__][version]
+
+    @staticmethod
+    def is_pydantic(val):
+        return not hasattr(val, "__origin__") and issubclass(val, BaseModel)
+
+    @staticmethod
+    def to_pydantic(val) -> Optional[type[BaseModel]]:
+        if val is None or not ModelStore.is_pydantic(val):
+            return None
+        return val
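ModelStore is a process-wide registry mapping a model name and version to its Pydantic class; get() without an explicit version resolves to the highest registered one, and versioned names serialize as "Name@vN". A minimal sketch (MyModel and its version number are illustrative; the ClassVar annotation keeps _version a plain class attribute under Pydantic v2):

    from typing import ClassVar

    from pydantic import BaseModel

    from datachain.lib.model_store import ModelStore

    class MyModel(BaseModel):
        _version: ClassVar[int] = 2
        name: str

    ModelStore.add(MyModel)  # registers under name "MyModel", version 2
    assert ModelStore.get_name(MyModel) == "MyModel@v2"
    assert ModelStore.parse_name_version("MyModel@v2") == ("MyModel", 2)
    assert ModelStore.get("MyModel") is MyModel  # latest version when omitted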
datachain/lib/pytorch.py CHANGED
@@ -2,13 +2,15 @@ import logging
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
+from PIL import Image
+from pydantic import BaseModel
 from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
+from torchvision.transforms import v2
 
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
-from datachain.lib.feature import Feature
 from datachain.lib.text import convert_text
 
 if TYPE_CHECKING:
@@ -18,20 +20,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger("datachain")
 
 
-try:
-    from PIL import Image
-    from torchvision.transforms import v2
-
-    DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])
-except ImportError:
-    logger.warning(
-        "Missing dependencies for computer vision:\n"
-        "To install run:\n\n"
-        "  pip install 'datachain[cv]'\n"
-    )
-    Image = None  # type: ignore[assignment]
-    v2 = None
-    DEFAULT_TRANSFORM = None
+DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])
 
 
 def label_to_int(value: str, classes: list) -> int:
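DEFAULT_TRANSFORM is now built unconditionally, since torchvision is a hard dependency here; it converts a PIL image into a float32 tensor scaled to [0, 1]. An illustrative example:

    from PIL import Image

    img = Image.new("RGB", (4, 4))
    tensor = DEFAULT_TRANSFORM(img)  # CHW float32 tensor, values in [0, 1]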
@@ -105,19 +94,18 @@ class PytorchDataset(IterableDataset):
         for row_features in stream:
             row = []
             for fr in row_features:
-                if isinstance(fr, Feature):
+                if isinstance(fr, BaseModel):
                     row.append(fr.get_value())  # type: ignore[unreachable]
                 else:
                     row.append(fr)
             # Apply transforms
             if self.transform:
                 try:
-                    if v2 and isinstance(self.transform, v2.Transform):
+                    if isinstance(self.transform, v2.Transform):
                         row = self.transform(row)
-                    elif Image:
-                        for i, val in enumerate(row):
-                            if isinstance(val, Image.Image):
-                                row[i] = self.transform(val)
+                    for i, val in enumerate(row):
+                        if isinstance(val, Image.Image):
+                            row[i] = self.transform(val)
                 except ValueError:
                     logger.warning("Skipping transform due to unsupported data types.")
                     self.transform = None
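For context, the two paths above differ: a torchvision v2.Transform receives the whole row and passes through entries it does not recognize, while any other callable is applied only to the PIL images in the row. A standalone sketch with illustrative values:

    from PIL import Image
    from torch import float32
    from torchvision.transforms import v2

    row = [Image.new("RGB", (8, 8)), "label"]
    transform = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])
    if isinstance(transform, v2.Transform):  # v2.Compose subclasses v2.Transform
        row = transform(row)  # the image becomes a tensor; the string passes through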