datachain 0.7.8__py3-none-any.whl → 0.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of datachain might be problematic.
- datachain/cli.py +9 -3
- datachain/client/fsspec.py +4 -2
- datachain/client/local.py +9 -4
- datachain/data_storage/metastore.py +3 -2
- datachain/func/__init__.py +4 -1
- datachain/func/numeric.py +46 -0
- datachain/func/string.py +46 -0
- datachain/lib/convert/flatten.py +7 -5
- datachain/lib/convert/unflatten.py +2 -2
- datachain/lib/convert/values_to_tuples.py +1 -1
- datachain/lib/dc.py +1 -0
- datachain/lib/pytorch.py +54 -37
- datachain/lib/utils.py +1 -1
- datachain/query/dataset.py +1 -1
- datachain/remote/studio.py +44 -25
- datachain/sql/functions/numeric.py +12 -0
- datachain/sql/functions/string.py +12 -0
- datachain/sql/sqlite/base.py +40 -0
- datachain/studio.py +2 -2
- datachain-0.7.10.dist-info/METADATA +207 -0
- {datachain-0.7.8.dist-info → datachain-0.7.10.dist-info}/RECORD +25 -25
- datachain-0.7.8.dist-info/METADATA +0 -488
- {datachain-0.7.8.dist-info → datachain-0.7.10.dist-info}/LICENSE +0 -0
- {datachain-0.7.8.dist-info → datachain-0.7.10.dist-info}/WHEEL +0 -0
- {datachain-0.7.8.dist-info → datachain-0.7.10.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.8.dist-info → datachain-0.7.10.dist-info}/top_level.txt +0 -0
datachain/cli.py
CHANGED
@@ -16,7 +16,7 @@ from tabulate import tabulate
 from datachain import Session, utils
 from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyValueArgs
 from datachain.config import Config
-from datachain.error import DataChainError
+from datachain.error import DataChainError, DatasetNotFoundError
 from datachain.lib.dc import DataChain
 from datachain.studio import (
     edit_studio_dataset,
@@ -1056,7 +1056,10 @@ def rm_dataset(
     all, local, studio = _determine_flavors(studio, local, all, token)
 
     if all or local:
-        catalog.remove_dataset(name, version=version, force=force)
+        try:
+            catalog.remove_dataset(name, version=version, force=force)
+        except DatasetNotFoundError:
+            print("Dataset not found in local", file=sys.stderr)
 
     if (all or studio) and token:
         remove_studio_dataset(team, name, version, force)
@@ -1077,7 +1080,10 @@ def edit_dataset(
     all, local, studio = _determine_flavors(studio, local, all, token)
 
     if all or local:
-        catalog.edit_dataset(name, new_name, description, labels)
+        try:
+            catalog.edit_dataset(name, new_name, description, labels)
+        except DatasetNotFoundError:
+            print("Dataset not found in local", file=sys.stderr)
 
     if (all or studio) and token:
         edit_studio_dataset(team, name, new_name, description, labels)
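For context, a short sketch of the behavior this change enables: removing (or editing) a dataset that exists only in Studio no longer aborts the local half of the command with a traceback. The catalog call and dataset name below are hypothetical; only `DatasetNotFoundError` and the message mirror the diff above.

```python
import sys

from datachain.catalog import get_catalog
from datachain.error import DatasetNotFoundError

catalog = get_catalog()
try:
    # Hypothetical name of a dataset that exists only in Studio, not locally.
    catalog.remove_dataset("studio-only-dataset", version=None, force=False)
except DatasetNotFoundError:
    # Mirrors the CLI's new graceful handling instead of a crash.
    print("Dataset not found in local", file=sys.stderr)
```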
datachain/client/fsspec.py
CHANGED
@@ -172,7 +172,7 @@ class Client(ABC):
         return url == cls.PREFIX
 
     @classmethod
-    def get_uri(cls, name) -> "StorageURI":
+    def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI
 
         return StorageURI(f"{cls.PREFIX}{name}")
@@ -278,7 +278,9 @@ class Client(ABC):
     ) -> None:
         await self._fetch_nested(start_prefix, result_queue)
 
-    async def _fetch_dir(
+    async def _fetch_dir(
+        self, prefix: str, pbar, result_queue: ResultQueue
+    ) -> set[str]:
         path = f"{self.name}/{prefix}"
         infos = await self.ls_dir(path)
         files = []
datachain/client/local.py
CHANGED
@@ -12,6 +12,7 @@ from datachain.lib.file import File
 from .fsspec import Client
 
 if TYPE_CHECKING:
+    from datachain.cache import DataChainCache
     from datachain.dataset import StorageURI
 
 
@@ -21,7 +22,11 @@ class FileClient(Client):
     protocol = "file"
 
     def __init__(
-        self,
+        self,
+        name: str,
+        fs_kwargs: dict[str, Any],
+        cache: "DataChainCache",
+        use_symlinks: bool = False,
     ) -> None:
         super().__init__(name, fs_kwargs, cache)
         self.use_symlinks = use_symlinks
@@ -30,7 +35,7 @@ class FileClient(Client):
         raise TypeError("Signed urls are not implemented for local file system")
 
     @classmethod
-    def get_uri(cls, name) -> "StorageURI":
+    def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI
 
         return StorageURI(f'{cls.PREFIX}/{name.removeprefix("/")}')
@@ -77,7 +82,7 @@ class FileClient(Client):
         return bucket, path
 
     @classmethod
-    def from_name(cls, name: str, cache, kwargs) -> "FileClient":
+    def from_name(cls, name: str, cache: "DataChainCache", kwargs) -> "FileClient":
         use_symlinks = kwargs.pop("use_symlinks", False)
         return cls(name, kwargs, cache, use_symlinks=use_symlinks)
 
@@ -85,7 +90,7 @@ class FileClient(Client):
     def from_source(
         cls,
         uri: str,
-        cache,
+        cache: "DataChainCache",
         use_symlinks: bool = False,
         **kwargs,
     ) -> "FileClient":
datachain/data_storage/metastore.py
CHANGED
@@ -725,9 +725,10 @@ class AbstractDBMetastore(AbstractMetastore):
 
     def list_datasets(self) -> Iterator["DatasetListRecord"]:
         """Lists all datasets."""
-        self.
+        query = self._base_list_datasets_query().order_by(
+            self._datasets.c.name, self._datasets_versions.c.version
         )
+        yield from self._parse_dataset_list(self.db.execute(query))
 
     def list_datasets_by_prefix(
         self, prefix: str, conn=None
datachain/func/__init__.py
CHANGED
@@ -17,8 +17,9 @@ from .aggregate import (
 )
 from .array import cosine_distance, euclidean_distance, length, sip_hash_64
 from .conditional import greatest, least
-from .numeric import bit_and, bit_or, bit_xor, int_hash_64
+from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
 from .random import rand
+from .string import byte_hamming_distance
 from .window import window
 
 __all__ = [
@@ -26,8 +27,10 @@ __all__ = [
     "array",
     "avg",
     "bit_and",
+    "bit_hamming_distance",
     "bit_or",
     "bit_xor",
+    "byte_hamming_distance",
     "case",
     "collect",
     "concat",
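With the `__all__` update above, both new Hamming-distance helpers are importable directly from the public `datachain.func` namespace; a minimal sketch (column names are hypothetical):

```python
from datachain.func import bit_hamming_distance, byte_hamming_distance

# Each call returns a Func object whose result column is an int.
bit_expr = bit_hamming_distance("phash_a", "phash_b")
byte_expr = byte_hamming_distance("hash_a", "hash_b")
```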
datachain/func/numeric.py
CHANGED
@@ -160,3 +160,49 @@ def int_hash_64(col: Union[ColT, int]) -> Func:
     return Func(
         "int_hash_64", inner=numeric.int_hash_64, cols=cols, args=args, result_type=int
     )
+
+
+def bit_hamming_distance(*args: Union[ColT, int]) -> Func:
+    """
+    Computes the Hamming distance between the bit representations of two integer values.
+
+    The Hamming distance is the number of positions at which the corresponding bits
+    are different. This function returns the dissimilarity between the integers,
+    where 0 indicates identical integers and values closer to the number of bits
+    in the integer indicate higher dissimilarity.
+
+    Args:
+        args (str | int): Two integers to compute the Hamming distance between.
+            If a str is provided, it is assumed to be the name of the column.
+            If an int is provided, it is assumed to be an integer literal.
+
+    Returns:
+        Func: A Func object that represents the Hamming distance function.
+
+    Example:
+        ```py
+        dc.mutate(
+            ham_dist=func.bit_hamming_distance("embed1", 123456),
+        )
+        ```
+
+    Notes:
+        - Result column will always be of type int.
+    """
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, int):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    if len(cols) + len(func_args) != 2:
+        raise ValueError("bit_hamming_distance() requires exactly two arguments")
+
+    return Func(
+        "bit_hamming_distance",
+        inner=numeric.bit_hamming_distance,
+        cols=cols,
+        args=func_args,
+        result_type=int,
+    )
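To make the intended call pattern concrete, a hypothetical end-to-end sketch (the column name and values are made up; it assumes `DataChain.from_values` and `mutate` behave as in other 0.7.x releases):

```python
from datachain import DataChain, func

# Hypothetical integer hashes stored in a "phash" column.
chain = (
    DataChain.from_values(phash=[0b1010, 0b1001, 0b1111])
    .mutate(ham_dist=func.bit_hamming_distance("phash", 0b1010))
)
# ham_dist counts differing bits against the literal 0b1010: 0, 2, 2 for these rows.
```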
datachain/func/string.py
CHANGED
@@ -152,3 +152,49 @@ def regexp_replace(col: Union[str, Func], regex: str, replacement: str) -> Func:
         args = None
 
     return Func("regexp_replace", inner=inner, cols=cols, args=args, result_type=str)
+
+
+def byte_hamming_distance(*args: Union[str, Func]) -> Func:
+    """
+    Computes the Hamming distance between two strings.
+
+    The Hamming distance is the number of positions at which the corresponding
+    characters are different. This function returns the dissimilarity between
+    the strings, where 0 indicates identical strings and values closer to the length
+    of the strings indicate higher dissimilarity.
+
+    Args:
+        args (str | literal): Two strings to compute the Hamming distance between.
+            If a str is provided, it is assumed to be the name of the column.
+            If a Literal is provided, it is assumed to be a string literal.
+
+    Returns:
+        Func: A Func object that represents the Hamming distance function.
+
+    Example:
+        ```py
+        dc.mutate(
+            ham_dist=func.byte_hamming_distance("file.phash", literal("hello")),
+        )
+        ```
+
+    Notes:
+        - Result column will always be of type int.
+    """
+    cols, func_args = [], []
+    for arg in args:
+        if get_origin(arg) is literal:
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    if len(cols) + len(func_args) != 2:
+        raise ValueError("byte_hamming_distance() requires exactly two arguments")
+
+    return Func(
+        "byte_hamming_distance",
+        inner=string.byte_hamming_distance,
+        cols=cols,
+        args=func_args,
+        result_type=int,
+    )
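A matching sketch for the string variant, comparing two hypothetical columns of equal-length hash strings (same assumptions about `from_values` and `mutate` as above):

```python
from datachain import DataChain, func

chain = (
    DataChain.from_values(h1=["abcd", "abcd"], h2=["abcf", "abcd"])
    .mutate(dist=func.byte_hamming_distance("h1", "h2"))
)
# dist is 1 for the first row (one differing character) and 0 for the second.
```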
datachain/lib/convert/flatten.py
CHANGED
@@ -1,19 +1,21 @@
+from collections.abc import Generator
+
 from pydantic import BaseModel
 
 from datachain.lib.model_store import ModelStore
 
 
-def flatten(obj: BaseModel):
+def flatten(obj: BaseModel) -> tuple:
     return tuple(_flatten_fields_values(obj.model_fields, obj))
 
 
-def flatten_list(obj_list):
+def flatten_list(obj_list: list[BaseModel]) -> tuple:
     return tuple(
         val for obj in obj_list for val in _flatten_fields_values(obj.model_fields, obj)
     )
 
 
-def _flatten_list_field(value: list):
+def _flatten_list_field(value: list) -> list:
     assert isinstance(value, list)
     if value and ModelStore.is_pydantic(type(value[0])):
         return [val.model_dump() for val in value]
@@ -22,7 +24,7 @@ def _flatten_list_field(value: list):
     return value
 
 
-def _flatten_fields_values(fields, obj: BaseModel):
+def _flatten_fields_values(fields: dict, obj: BaseModel) -> Generator:
     for name, f_info in fields.items():
         anno = f_info.annotation
         # Optimization: Access attributes directly to skip the model_dump() call.
@@ -40,5 +42,5 @@ def _flatten_fields_values(fields, obj: BaseModel):
         yield value
 
 
-def _flatten(obj):
+def _flatten(obj: BaseModel) -> tuple:
     return tuple(_flatten_fields_values(obj.model_fields, obj))
datachain/lib/convert/unflatten.py
CHANGED
@@ -9,12 +9,12 @@ from pydantic import BaseModel
 from datachain.query.schema import DEFAULT_DELIMITER
 
 
-def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos=0) -> dict:
+def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos: int = 0) -> dict:
     return unflatten_to_json_pos(model, row, pos)[0]
 
 
 def unflatten_to_json_pos(
-    model: type[BaseModel], row: Sequence[Any], pos=0
+    model: type[BaseModel], row: Sequence[Any], pos: int = 0
 ) -> tuple[dict, int]:
     res = {}
     for name, f_info in model.model_fields.items():
datachain/lib/convert/values_to_tuples.py
CHANGED
@@ -11,7 +11,7 @@ from datachain.lib.utils import DataChainParamsError
 
 
 class ValuesToTupleError(DataChainParamsError):
-    def __init__(self, ds_name, msg):
+    def __init__(self, ds_name: str, msg: str):
         if ds_name:
             ds_name = f"' {ds_name}'"
         super().__init__(f"Cannot convert signals for dataset{ds_name}: {msg}")
datachain/lib/dc.py
CHANGED
datachain/lib/pytorch.py
CHANGED
@@ -10,8 +10,10 @@ from torchvision.transforms import v2
 from tqdm import tqdm
 
 from datachain import Session
+from datachain.asyn import AsyncMapper
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
+from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
 
 if TYPE_CHECKING:
@@ -30,6 +32,8 @@ def label_to_int(value: str, classes: list) -> int:
 
 
 class PytorchDataset(IterableDataset):
+    prefetch: int = 2
+
     def __init__(
         self,
         name: str,
@@ -39,6 +43,7 @@ class PytorchDataset(IterableDataset):
         tokenizer: Optional[Callable] = None,
         tokenizer_kwargs: Optional[dict[str, Any]] = None,
         num_samples: int = 0,
+        dc_settings: Optional[Settings] = None,
     ):
         """
         Pytorch IterableDataset that streams DataChain datasets.
@@ -66,6 +71,11 @@ class PytorchDataset(IterableDataset):
             catalog = get_catalog()
         self._init_catalog(catalog)
 
+        dc_settings = dc_settings or Settings()
+        self.cache = dc_settings.cache
+        if (prefetch := dc_settings.prefetch) is not None:
+            self.prefetch = prefetch
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -82,51 +92,58 @@ class PytorchDataset(IterableDataset):
         wh = wh_cls(*wh_args, **wh_kwargs)
         return Catalog(ms, wh, **self._catalog_params)
 
-    def
-        session = Session.get(catalog=self.catalog)
-        total_rank, total_workers = self.get_rank_and_workers()
+    def _rows_iter(self, total_rank: int, total_workers: int):
+        catalog = self._get_catalog()
+        session = Session("PyTorch", catalog=catalog)
         ds = DataChain.from_dataset(
             name=self.name, version=self.version, session=session
-        )
+        ).settings(cache=self.cache, prefetch=self.prefetch)
         ds = ds.remove_file_signals()
 
         if self.num_samples > 0:
             ds = ds.sample(self.num_samples)
         ds = ds.chunk(total_rank, total_workers)
+        yield from ds.collect()
+
+    def __iter__(self) -> Iterator[Any]:
+        total_rank, total_workers = self.get_rank_and_workers()
+        rows = self._rows_iter(total_rank, total_workers)
+        if self.prefetch > 0:
+            from datachain.lib.udf import _prefetch_input
+
+            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
+
         desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(desc=desc, unit=" rows") as
-            pbar.update(1)
+        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
+            yield from map(self._process_row, rows_it)
+
+    def _process_row(self, row_features):
+        row = []
+        for fr in row_features:
+            if hasattr(fr, "read"):
+                row.append(fr.read())  # type: ignore[unreachable]
+            else:
+                row.append(fr)
+        # Apply transforms
+        if self.transform:
+            try:
+                if isinstance(self.transform, v2.Transform):
+                    row = self.transform(row)
+                for i, val in enumerate(row):
+                    if isinstance(val, Image.Image):
+                        row[i] = self.transform(val)
+            except ValueError:
+                logger.warning("Skipping transform due to unsupported data types.")
+                self.transform = None
+        if self.tokenizer:
+            for i, val in enumerate(row):
+                if isinstance(val, str) or (
+                    isinstance(val, list) and isinstance(val[0], str)
+                ):
+                    row[i] = convert_text(
+                        val, self.tokenizer, self.tokenizer_kwargs
+                    ).squeeze(0)  # type: ignore[union-attr]
+        return row
 
     @staticmethod
     def get_rank_and_workers() -> tuple[int, int]:
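For orientation, a hypothetical sketch of how the new knobs could be passed through to a DataLoader. The dataset name/version and the `Settings(cache=..., prefetch=...)` keywords are assumptions inferred from the attributes read in `__init__` above, not a verified API.

```python
from torch.utils.data import DataLoader

from datachain.lib.pytorch import PytorchDataset
from datachain.lib.settings import Settings

# Hypothetical dataset registered in the local catalog; keyword names for
# Settings are assumed to mirror dc_settings.cache / dc_settings.prefetch.
ds = PytorchDataset(
    "my_dataset",
    version=1,
    dc_settings=Settings(cache=True, prefetch=4),
)
loader = DataLoader(ds, batch_size=16, num_workers=2)
```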
datachain/lib/utils.py
CHANGED
datachain/query/dataset.py
CHANGED
@@ -215,7 +215,7 @@ class DatasetDiffOperation(Step):
     Should return select query that calculates desired diff between dataset queries
     """
 
-    def apply(self, query_generator, temp_tables: list[str]):
+    def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
         source_query = query_generator.exclude(("sys__id",))
         target_query = self.dq.apply_steps().select()
         temp_tables.extend(self.dq.temp_table_names)
datachain/remote/studio.py
CHANGED
@@ -119,18 +119,27 @@ class StudioClient:
             "\tpip install 'datachain[remote]'"
         ) from None
 
-    def _send_request_msgpack(
+    def _send_request_msgpack(
+        self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+    ) -> Response[Any]:
         import msgpack
         import requests
 
+        kwargs = (
+            {"params": {**data, "team_name": self.team}}
+            if method == "GET"
+            else {"json": {**data, "team_name": self.team}}
+        )
+
+        response = requests.request(
+            method=method,  # type: ignore[arg-type]
+            url=f"{self.url}/{route}",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"token {self.token}",
            },
            timeout=self.timeout,
+            **kwargs,  # type: ignore[arg-type]
         )
         ok = response.ok
         if not ok:
@@ -148,7 +157,9 @@ class StudioClient:
         return Response(response_data, ok, message)
 
     @retry_with_backoff(retries=5)
-    def _send_request(
+    def _send_request(
+        self, route: str, data: dict[str, Any], method: Optional[str] = "POST"
+    ) -> Response[Any]:
         """
         Function that communicate Studio API.
         It will raise an exception, and try to retry, if 5xx status code is
@@ -157,14 +168,21 @@
         """
         import requests
 
+        kwargs = (
+            {"params": {**data, "team_name": self.team}}
+            if method == "GET"
+            else {"json": {**data, "team_name": self.team}}
+        )
+
+        response = requests.request(
+            method=method,  # type: ignore[arg-type]
+            url=f"{self.url}/{route}",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"token {self.token}",
            },
            timeout=self.timeout,
+            **kwargs,  # type: ignore[arg-type]
         )
         try:
             response.raise_for_status()
@@ -222,7 +240,7 @@
         yield path, response
 
     def ls_datasets(self) -> Response[LsData]:
-        return self._send_request("datachain/
+        return self._send_request("datachain/datasets", {}, method="GET")
 
     def edit_dataset(
         self,
@@ -232,20 +250,14 @@
         labels: Optional[list[str]] = None,
     ) -> Response[DatasetInfoData]:
         body = {
+            "new_name": new_name,
             "dataset_name": name,
+            "description": description,
+            "labels": labels,
         }
 
-        if new_name is not None:
-            body["new_name"] = new_name
-
-        if description is not None:
-            body["description"] = description
-
-        if labels is not None:
-            body["labels"] = labels  # type: ignore[assignment]
-
         return self._send_request(
-            "datachain/
+            "datachain/datasets",
             body,
         )
@@ -256,12 +268,13 @@
         force: Optional[bool] = False,
     ) -> Response[DatasetInfoData]:
         return self._send_request(
-            "datachain/
+            "datachain/datasets",
             {
                 "dataset_name": name,
                 "version": version,
                 "force": force,
             },
+            method="DELETE",
         )
 
     def dataset_info(self, name: str) -> Response[DatasetInfoData]:
@@ -272,7 +285,9 @@
 
         return dataset_info
 
-        response = self._send_request(
+        response = self._send_request(
+            "datachain/datasets/info", {"dataset_name": name}, method="GET"
+        )
         if response.ok:
             response.data = _parse_dataset_info(response.data)
         return response
@@ -282,14 +297,16 @@
     ) -> Response[DatasetRowsData]:
         req_data = {"dataset_name": name, "dataset_version": version}
         return self._send_request_msgpack(
-            "datachain/
+            "datachain/datasets/rows",
             {**req_data, "offset": offset, "limit": DATASET_ROWS_CHUNK_SIZE},
+            method="GET",
         )
 
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
-            "datachain/
+            "datachain/datasets/stats",
             {"dataset_name": name, "dataset_version": version},
+            method="GET",
         )
         if response.ok:
             response.data = DatasetStats(**response.data)
@@ -299,16 +316,18 @@
         self, name: str, version: int
     ) -> Response[DatasetExportSignedUrls]:
         return self._send_request(
-            "datachain/
+            "datachain/datasets/export",
            {"dataset_name": name, "dataset_version": version},
+            method="GET",
         )
 
     def dataset_export_status(
         self, name: str, version: int
     ) -> Response[DatasetExportStatus]:
         return self._send_request(
-            "datachain/
+            "datachain/datasets/export-status",
            {"dataset_name": name, "dataset_version": version},
+            method="GET",
         )
 
     def upload_file(self, file_name: str, content: bytes) -> Response[FileUploadData]:
datachain/sql/functions/numeric.py
CHANGED
@@ -35,9 +35,21 @@ class int_hash_64(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
+class bit_hamming_distance(GenericFunction):  # noqa: N801
+    """
+    Returns the Hamming distance between two integers.
+    """
+
+    type = Int64()
+    package = "numeric"
+    name = "hamming_distance"
+    inherit_cache = True
+
+
 compiler_not_implemented(bit_and)
 compiler_not_implemented(bit_or)
 compiler_not_implemented(bit_xor)
 compiler_not_implemented(bit_rshift)
 compiler_not_implemented(bit_lshift)
 compiler_not_implemented(int_hash_64)
+compiler_not_implemented(bit_hamming_distance)