datachain 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.


datachain/lib/image.py CHANGED
@@ -1,6 +1,5 @@
-import inspect
 from io import BytesIO
-from typing import Any, Callable, Optional
+from typing import Callable, Optional, Union
 
 from datachain.lib.file import File
 
@@ -14,8 +13,6 @@ except ImportError as exc:
         " pip install 'datachain[cv]'\n"
     ) from exc
 
-from datachain.lib.reader import FeatureReader
-
 
 class ImageFile(File):
     def get_value(self):
@@ -28,8 +25,8 @@ def convert_image(
     mode: str = "RGB",
     size: Optional[tuple[int, int]] = None,
     transform: Optional[Callable] = None,
-    open_clip_model: Optional[Any] = None,
-):
+    encoder: Optional[Callable] = None,
+) -> Union[Image.Image, torch.Tensor]:
     """
     Resize, transform, and otherwise convert an image.
 
@@ -37,8 +34,8 @@ def convert_image(
         img (Image): PIL.Image object.
         mode (str): PIL.Image mode.
        size (tuple[int, int]): Size in (width, height) pixels for resizing.
-        transform (Callable): Torchvision v1 or other transform to apply.
-        open_clip_model (Any): Encode image using model from open_clip library.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
     if mode:
         img = img.convert(mode)
@@ -46,86 +43,47 @@ def convert_image(
         img = img.resize(size)
     if transform:
         img = transform(img)
-        if open_clip_model:
+
+        try:
+            from transformers.image_processing_utils import BaseImageProcessor
+
+            if isinstance(transform, BaseImageProcessor):
+                img = torch.tensor(img.pixel_values[0])  # type: ignore[assignment,attr-defined]
+        except ImportError:
+            pass
+        if encoder:
             img = img.unsqueeze(0)  # type: ignore[attr-defined]
-    if open_clip_model:
-        method_name = "encode_image"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                "Unable to render Image: 'open_clip_model' doesn't support"
-                f" '{method_name}()'"
-            )
-        img = open_clip_model.encode_image(img)
+    if encoder:
+        img = encoder(img)
     return img
 
 
-class ImageReader(FeatureReader):
-    def __init__(
-        self,
-        mode: str = "RGB",
-        size: Optional[tuple[int, int]] = None,
-        transform: Optional[Callable] = None,
-        open_clip_model: Any = None,
-    ):
-        """
-        Read and optionally transform an image.
-
-        All kwargs are passed to `convert_image()`.
-        """
-        self.mode = mode
-        self.size = size
-        self.transform = transform
-        self.open_clip_model = open_clip_model
-        super().__init__(ImageFile)
-
-    def __call__(self, img: Image.Image):
-        return convert_image(
-            img,
-            mode=self.mode,
-            size=self.size,
-            transform=self.transform,
-            open_clip_model=self.open_clip_model,
-        )
-
-
-def similarity_scores(
-    model: Any,
-    preprocess: Callable,
-    tokenizer: Callable,
-    image: Image.Image,
-    text: str,
-    prob: bool = False,
-) -> list[float]:
+def convert_images(
+    images: Union[Image.Image, list[Image.Image]],
+    mode: str = "RGB",
+    size: Optional[tuple[int, int]] = None,
+    transform: Optional[Callable] = None,
+    encoder: Optional[Callable] = None,
+) -> Union[list[Image.Image], torch.Tensor]:
     """
-    Calculate CLIP similarity scores for one or more texts given an image.
+    Resize, transform, and otherwise convert one or more images.
 
     Args:
-        model: Model from clip or open_clip packages.
-        preprocess: Image preprocessing transforms.
-        tokenizer: Text tokenizer.
-        image: Image.
-        text: Text.
-        prob: Compute softmax probabilities across texts.
+        img (Image, list[Image]): PIL.Image object or list of objects.
+        mode (str): PIL.Image mode.
+        size (tuple[int, int]): Size in (width, height) pixels for resizing.
+        transform (Callable): Torchvision transform or huggingface processor to apply.
+        encoder (Callable): Encode image using model.
     """
+    if isinstance(images, Image.Image):
+        images = [images]
 
-    with torch.no_grad():
-        image = preprocess(image).unsqueeze(0)
-        text = tokenizer(text)
-
-        image_features = model.encode_image(image)
-        text_features = model.encode_text(text)
-
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        text_features /= text_features.norm(dim=-1, keepdim=True)
+    converted = [convert_image(img, mode, size, transform) for img in images]
 
-        logits_per_text = 100.0 * image_features @ text_features.T
+    if isinstance(converted[0], torch.Tensor):
+        converted = torch.stack(converted)  # type: ignore[assignment,arg-type]
 
-        if prob:
-            scores = logits_per_text.softmax(dim=1)
-        else:
-            scores = logits_per_text
+    if encoder:
+        converted = encoder(converted)
 
-        return scores[0].tolist()
+    return converted  # type: ignore[return-value]
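Taken together, these changes drop the open_clip-specific hook in favor of a generic encoder callable and add batching via convert_images. A minimal usage sketch under those assumptions (the torchvision pipeline and the commented open_clip encoder are illustrative, not from the diff):

from PIL import Image
from torchvision import transforms

from datachain.lib.image import convert_images

# Two in-memory images stand in for real files.
images = [Image.new("RGB", (64, 64)), Image.new("RGB", (32, 32))]

# A torchvision transform turns each image into a tensor, so
# convert_images stacks the results into one (N, C, H, W) batch.
to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
batch = convert_images(images, transform=to_tensor)  # torch.Size([2, 3, 224, 224])

# The new encoder argument accepts any callable over the stacked batch,
# e.g. model.encode_image from an open_clip model (hypothetical here),
# replacing the removed open_clip_model parameter:
# embeddings = convert_images(images, transform=preprocess, encoder=model.encode_image)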
datachain/lib/pytorch.py CHANGED
@@ -116,10 +116,12 @@ class PytorchDataset(IterableDataset):
                 self.transform = None
             if self.tokenizer:
                 for i, val in enumerate(row):
-                    if isinstance(val, str):
+                    if isinstance(val, str) or (
+                        isinstance(val, list) and isinstance(val[0], str)
+                    ):
                         row[i] = convert_text(
                             val, self.tokenizer, self.tokenizer_kwargs
-                        )
+                        ).squeeze(0)  # type: ignore[union-attr]
             yield row
 
     @staticmethod
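This extends tokenization from single strings to lists of strings and strips the leading batch dimension the tokenizer adds. A rough sketch of the branch's behavior, assuming a HuggingFace tokenizer with tensor output (the model name and kwargs are illustrative):

from transformers import AutoTokenizer

from datachain.lib.text import convert_text

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
kwargs = {"return_tensors": "pt", "padding": True}

# A plain caption and a list of captions both hit the new branch;
# squeeze(0) drops the (1, L) batch dim for the single-string case
# and is a no-op for the (N, L) list case.
single = convert_text("a photo of a dog", tokenizer, kwargs).squeeze(0)
batch = convert_text(["a dog", "a cat"], tokenizer, kwargs).squeeze(0)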
datachain/lib/signal_schema.py CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin
 
 from pydantic import create_model
 
-from datachain.lib.arrow import Source
 from datachain.lib.feature import (
     DATACHAIN_TO_TYPE,
     DEFAULT_DELIMITER,
@@ -14,17 +13,13 @@ from datachain.lib.feature import (
     convert_type_to_datachain,
 )
 from datachain.lib.feature_registry import Registry
-from datachain.lib.file import File, TextFile
-from datachain.lib.image import ImageFile
+from datachain.lib.file import File
 from datachain.lib.utils import DataChainParamsError
-from datachain.lib.webdataset import TarStream, WDSAllFile, WDSBasic
-from datachain.lib.webdataset_laion import Laion, WDSLaion
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
 
-# TODO fix hardcoded Feature class names with://github.com/iterative/dvcx/issues/1625
 NAMES_TO_TYPES = {
     "int": int,
     "str": str,
@@ -34,15 +29,6 @@ NAMES_TO_TYPES = {
     "dict": dict,
     "bytes": bytes,
     "datetime": datetime,
-    "WDSLaion": WDSLaion,
-    "Laion": Laion,
-    "Source": Source,
-    "File": File,
-    "ImageFile": ImageFile,
-    "TextFile": TextFile,
-    "TarStream": TarStream,
-    "WDSBasic": WDSBasic,
-    "WDSAllFile": WDSAllFile,
 }
 
 
@@ -150,7 +136,7 @@ class SignalSchema:
         )
 
     def slice(self, keys: Sequence[str]) -> "SignalSchema":
-        return SignalSchema({k: v for k, v in self.values.items() if k in keys})
+        return SignalSchema({k: self.values[k] for k in keys if k in self.values})
 
     def row_to_features(self, row: Sequence, catalog: "Catalog") -> list[FeatureType]:
         res = []
@@ -240,37 +226,6 @@ class SignalSchema:
         if has_subtree and issubclass(type_, File):
             yield ".".join(path)
 
-    def get_file_signals_values(self, row: dict[str, Any]) -> dict[str, Any]:
-        """
-        Method that returns values with clean field names (without prefix) for
-        all file signals found in this schema for some row
-        Output example:
-        {
-            laion.file: {
-                "source": "s3://ldb-public",
-                "name": "dog.jpg",
-                ...
-            },
-            meta.file: {
-                "source": "s3://datacomp",
-                "name": "cat.jpg",
-                ...
-            }
-        }
-        """
-        res = {}
-
-        for file_signals in self.get_file_signals():
-            prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
-            res[file_signals] = {
-                c_name.removeprefix(prefix): c_value
-                for c_name, c_value in row.items()
-                if c_name.startswith(prefix)
-                and DEFAULT_DELIMITER not in c_name.removeprefix(prefix)
-            }
-
-        return res
-
     def create_model(self, name: str) -> type[Feature]:
         fields = {key: (value, None) for key, value in self.values.items()}
 
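The rewritten slice() changes behavior subtly: the result now follows the order of keys rather than schema insertion order, and keys missing from the schema are skipped instead of ignored by filtering. A small sketch, assuming a schema of simple signal types:

from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"file": File, "label": str, "score": float})

# Result follows the order of `keys`; "missing" is silently dropped.
sliced = schema.slice(["score", "file", "missing"])
# sliced.values == {"score": float, "file": File}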
datachain/lib/text.py CHANGED
@@ -1,19 +1,15 @@
-import inspect
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-from datachain.lib.file import TextFile
-from datachain.lib.reader import FeatureReader
-
 if TYPE_CHECKING:
-    from datachain.lib.feature_utils import FeatureLike
+    import torch
 
 
 def convert_text(
     text: Union[str, list[str]],
     tokenizer: Optional[Callable] = None,
     tokenizer_kwargs: Optional[dict[str, Any]] = None,
-    open_clip_model: Optional[Any] = None,
-):
+    encoder: Optional[Callable] = None,
+) -> Union[str, list[str], "torch.Tensor"]:
     """
     Tokenize and otherwise transform text.
 
@@ -21,18 +17,8 @@ def convert_text(
         text (str): Text to convert.
         tokenizer (Callable): Tokenizer to use to tokenize objects.
         tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
-        open_clip_model (Any): Encode text using model from open_clip library.
+        encoder (Callable): Encode text using model.
     """
-    if open_clip_model:
-        method_name = "encode_text"
-        if not (
-            hasattr(open_clip_model, method_name)
-            and inspect.ismethod(getattr(open_clip_model, method_name))
-        ):
-            raise ValueError(
-                f"TextColumn error: 'model' doesn't support '{method_name}()'"
-            )
-
     if not tokenizer:
         return text
 
@@ -43,38 +29,21 @@ def convert_text(
         res = tokenizer(text, **tokenizer_kwargs)
     else:
         res = tokenizer(text)
-    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-
-    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
-
-    if not open_clip_model:
-        return tokens.squeeze(0)
-
-    return open_clip_model.encode_text(tokens).squeeze(0)
+    try:
+        from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+        tokens = (
+            res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
+        )
+    except ImportError:
+        tokens = res
 
-class TextReader(FeatureReader):
-    def __init__(
-        self,
-        fr_class: "FeatureLike" = TextFile,
-        tokenizer: Optional[Callable] = None,
-        tokenizer_kwargs: Optional[dict[str, Any]] = None,
-        open_clip_model: Optional[Any] = None,
-    ):
-        """
-        Read and optionally transform a text column.
+    if not encoder:
+        return tokens
 
-        All kwargs are passed to `convert_text()`.
-        """
-        self.tokenizer = tokenizer
-        self.tokenizer_kwargs = tokenizer_kwargs
-        self.open_clip_model = open_clip_model
-        super().__init__(fr_class)
+    try:
+        import torch
+    except ImportError:
+        "Missing dependency 'torch' needed to encode text."
 
-    def __call__(self, value: Union[str, list[str]]):
-        return convert_text(
-            value,
-            tokenizer=self.tokenizer,
-            tokenizer_kwargs=self.tokenizer_kwargs,
-            open_clip_model=self.open_clip_model,
-        )
+    return encoder(torch.tensor(tokens))
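As with images, text encoding now takes a generic encoder callable instead of an open_clip model object. A sketch of the two call styles, assuming open_clip is installed (the model name and pretrained tag are illustrative):

import open_clip

from datachain.lib.text import convert_text

model, _, _ = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

# Tokenize only: returns the token tensor unchanged.
tokens = convert_text("a photo of a dog", tokenizer=tokenizer)

# Tokenize and encode; encoder= replaces the old open_clip_model= argument.
embedding = convert_text(
    "a photo of a dog", tokenizer=tokenizer, encoder=model.encode_text
)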
datachain/lib/udf.py CHANGED
@@ -1,11 +1,12 @@
 import inspect
 import sys
 import traceback
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable
 
 from datachain.lib.feature import Feature
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.utils import DataChainError, DataChainParamsError
+from datachain.lib.udf_signature import UdfSignature
+from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
 from datachain.query import udf
 
 if TYPE_CHECKING:
@@ -17,26 +18,68 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")
 
 
-class UDFBase:
+class UDFBase(AbstractUDF):
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False
 
-    def __init__(
-        self,
-        params: SignalSchema,
-        output: SignalSchema,
-        func: Optional[Callable] = None,
-    ):
+    def __init__(self):
+        self.params = None
+        self.output = None
+        self.params_spec = None
+        self.output_spec = None
+        self._contains_stream = None
+        self._catalog = None
+        self._func = None
+
+    def process(self, *args, **kwargs):
+        """Processing function that needs to be defined by user"""
+        if not self._func:
+            raise NotImplementedError("UDF processing is not implemented")
+        return self._func(*args, **kwargs)
+
+    def setup(self):
+        """Initialization process executed on each worker before processing begins.
+        This is needed for tasks like pre-loading ML models prior to scoring.
+        """
+
+    def teardown(self):
+        """Teardown process executed on each process/worker after processing ends.
+        This is needed for tasks like closing connections to end-points.
+        """
+
+    def _init(self, sign: UdfSignature, params: SignalSchema, func: Callable):
         self.params = params
-        self.output = output
-        self._func = func
+        self.output = sign.output_schema
 
-        params_spec = params.to_udf_spec()
+        params_spec = self.params.to_udf_spec()
         self.params_spec = list(params_spec.keys())
-        self.output_spec = output.to_udf_spec()
+        self.output_spec = self.output.to_udf_spec()
 
-        self._catalog = None
+        self._func = func
+
+    @classmethod
+    def _create(
+        cls,
+        target_class: type["UDFBase"],
+        sign: UdfSignature,
+        params: SignalSchema,
+        catalog,
+    ) -> "UDFBase":
+        if isinstance(sign.func, AbstractUDF):
+            if not isinstance(sign.func, target_class):  # type: ignore[unreachable]
+                raise UdfError(
+                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f" must be a child of target class '{target_class.__name__}'",
+                )
+            result = sign.func
+            func = None
+        else:
+            result = target_class()
+            func = sign.func
+
+        result._init(sign, params, func)
+        return result
 
     @property
     def name(self):
@@ -53,25 +96,10 @@ class UDFBase:
         udf_wrapper = udf(self.params_spec, self.output_spec, batch=batch)
         return udf_wrapper(self)
 
-    def bootstrap(self):
-        """Initialization process executed on each worker before processing begins.
-        This is needed for tasks like pre-loading ML models prior to scoring.
-        """
-
-    def teardown(self):
-        """Teardown process executed on each process/worker after processing ends.
-        This is needed for tasks like closing connections to end-points.
-        """
-
-    def process(self, *args, **kwargs):
-        if not self._func:
-            raise NotImplementedError("UDF processing is not implemented")
-        return self._func(*args, **kwargs)
-
     def validate_results(self, results, *args, **kwargs):
         return results
 
-    def __call__(self, *rows, **kwargs):
+    def __call__(self, *rows):
         if self.is_input_grouped:
             objs = self._parse_grouped_rows(rows)
         else:
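The UDFBase refactor makes UDFs stateful classes: process() is the user hook, setup() replaces the old bootstrap(), and _create() accepts either a plain function or an AbstractUDF instance. A sketch of the class-based style this enables, with a hypothetical toy UDF:

from datachain.lib.udf import UDFBase


class TextLength(UDFBase):  # hypothetical user-defined UDF
    def setup(self):
        # Runs once per worker before processing (replaces bootstrap()).
        self.prefix = "len:"

    def process(self, text: str) -> str:
        # The per-row hook, used in place of a plain function UDF.
        return f"{self.prefix}{len(text)}"

    def teardown(self):
        # Runs on each worker after processing ends.
        self.prefix = None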
datachain/lib/udf_signature.py CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Optional, Union, get_args, get_origin
 
 from datachain.lib.feature import Feature, FeatureType, FeatureTypeNames
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.utils import DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError
 
 
 class UdfSignatureError(DataChainParamsError):
@@ -49,10 +49,13 @@ class UdfSignature:
         else:
             if func is None:
                 raise UdfSignatureError(chain, "user function is not defined")
+
             udf_func = func
             signal_name = None
+
         if not callable(udf_func):
-            raise UdfSignatureError(chain, f"function '{func}' is not callable")
+            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+
         func_params_map_sign, func_outs_sign, is_iterator = (
             UdfSignature._func_signature(chain, udf_func)
         )
@@ -108,13 +111,6 @@ class UdfSignature:
         if isinstance(output, str):
             output = [output]
         if isinstance(output, Sequence):
-            if not func_outs_sign:
-                raise UdfSignatureError(
-                    chain,
-                    "output types are not specified. Specify types in 'output' as"
-                    " a dict or as function return value hint.",
-                )
-
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
@@ -158,8 +154,13 @@ class UdfSignature:
 
     @staticmethod
     def _func_signature(
-        chain: str, func: Callable
+        chain: str, udf_func: Callable
    ) -> tuple[dict[str, type], Sequence[type], bool]:
+        if isinstance(udf_func, AbstractUDF):
+            func = udf_func.process  # type: ignore[unreachable]
+        else:
+            func = udf_func
+
         sign = inspect.signature(func)
 
         input_map = {prm.name: prm.annotation for prm in sign.parameters.values()}
datachain/lib/utils.py CHANGED
@@ -1,3 +1,20 @@
+from abc import ABC, abstractmethod
+
+
+class AbstractUDF(ABC):
+    @abstractmethod
+    def process(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def setup(self):
+        pass
+
+    @abstractmethod
+    def teardown(self):
+        pass
+
+
 class DataChainError(Exception):
     def __init__(self, message):
         super().__init__(message)
datachain/lib/webdataset.py CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import json
 import tarfile
 from collections.abc import Iterator, Sequence
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -240,10 +241,9 @@ class TarStream(File):
 def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
     builder = Builder(stream, core_extensions, spec, tar, encoding)
 
-    for item in tar.getmembers():
+    for item in sorted(tar.getmembers(), key=lambda m: Path(m.name).stem):
         if not item.isfile():
             continue
-
         try:
             builder.add(item)
         except StopIteration:
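Sorting tar members by filename stem groups each WebDataset sample's files together even when the archive interleaves them, which is what the sequential Builder expects. The effect in isolation, as a self-contained sketch:

from pathlib import Path

# Interleaved archive order: all .jpg members, then all .json members.
names = ["0001.jpg", "0002.jpg", "0001.json", "0002.json"]
print(sorted(names, key=lambda n: Path(n).stem))
# ['0001.jpg', '0001.json', '0002.jpg', '0002.json']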
datachain/listing.py CHANGED
@@ -20,9 +20,6 @@ if TYPE_CHECKING:
     from datachain.storage import Storage
 
 
-RANDOM_BITS = 63  # size of the random integer field
-
-
 class Listing:
     def __init__(
         self,