datachain 0.7.9__py3-none-any.whl → 0.7.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain has been flagged as potentially problematic.
- datachain/client/__init__.py +1 -2
- datachain/client/fsspec.py +4 -2
- datachain/client/local.py +9 -4
- datachain/func/__init__.py +4 -1
- datachain/func/numeric.py +46 -0
- datachain/func/string.py +46 -0
- datachain/lib/convert/flatten.py +7 -5
- datachain/lib/convert/unflatten.py +2 -2
- datachain/lib/convert/values_to_tuples.py +1 -1
- datachain/lib/dc.py +5 -1
- datachain/lib/file.py +2 -1
- datachain/lib/meta_formats.py +2 -1
- datachain/lib/pytorch.py +1 -5
- datachain/lib/signal_schema.py +28 -6
- datachain/lib/utils.py +1 -1
- datachain/query/dataset.py +5 -2
- datachain/sql/functions/numeric.py +12 -0
- datachain/sql/functions/string.py +12 -0
- datachain/sql/sqlite/base.py +40 -0
- datachain/toolkit/split.py +19 -6
- datachain-0.7.11.dist-info/METADATA +206 -0
- {datachain-0.7.9.dist-info → datachain-0.7.11.dist-info}/RECORD +26 -26
- datachain-0.7.9.dist-info/METADATA +0 -488
- {datachain-0.7.9.dist-info → datachain-0.7.11.dist-info}/LICENSE +0 -0
- {datachain-0.7.9.dist-info → datachain-0.7.11.dist-info}/WHEEL +0 -0
- {datachain-0.7.9.dist-info → datachain-0.7.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.9.dist-info → datachain-0.7.11.dist-info}/top_level.txt +0 -0
datachain/client/__init__.py
CHANGED
datachain/client/fsspec.py
CHANGED
```diff
@@ -172,7 +172,7 @@ class Client(ABC):
         return url == cls.PREFIX
 
     @classmethod
-    def get_uri(cls, name) -> "StorageURI":
+    def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI
 
         return StorageURI(f"{cls.PREFIX}{name}")
@@ -278,7 +278,9 @@ class Client(ABC):
     ) -> None:
         await self._fetch_nested(start_prefix, result_queue)
 
-    async def _fetch_dir(self, prefix, pbar, result_queue: ResultQueue) -> set[str]:
+    async def _fetch_dir(
+        self, prefix: str, pbar, result_queue: ResultQueue
+    ) -> set[str]:
         path = f"{self.name}/{prefix}"
         infos = await self.ls_dir(path)
         files = []
```
datachain/client/local.py
CHANGED
```diff
@@ -12,6 +12,7 @@ from datachain.lib.file import File
 from .fsspec import Client
 
 if TYPE_CHECKING:
+    from datachain.cache import DataChainCache
     from datachain.dataset import StorageURI
 
 
@@ -21,7 +22,11 @@ class FileClient(Client):
     protocol = "file"
 
     def __init__(
-        self, name: str, fs_kwargs: dict[str, Any], cache, use_symlinks: bool = False
+        self,
+        name: str,
+        fs_kwargs: dict[str, Any],
+        cache: "DataChainCache",
+        use_symlinks: bool = False,
     ) -> None:
         super().__init__(name, fs_kwargs, cache)
         self.use_symlinks = use_symlinks
@@ -30,7 +35,7 @@ class FileClient(Client):
         raise TypeError("Signed urls are not implemented for local file system")
 
     @classmethod
-    def get_uri(cls, name) -> "StorageURI":
+    def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI
 
         return StorageURI(f'{cls.PREFIX}/{name.removeprefix("/")}')
@@ -77,7 +82,7 @@ class FileClient(Client):
         return bucket, path
 
     @classmethod
-    def from_name(cls, name: str, cache, kwargs) -> "FileClient":
+    def from_name(cls, name: str, cache: "DataChainCache", kwargs) -> "FileClient":
         use_symlinks = kwargs.pop("use_symlinks", False)
         return cls(name, kwargs, cache, use_symlinks=use_symlinks)
 
@@ -85,7 +90,7 @@ class FileClient(Client):
     def from_source(
         cls,
         uri: str,
-        cache,
+        cache: "DataChainCache",
         use_symlinks: bool = False,
         **kwargs,
     ) -> "FileClient":
```
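Both client modules rely on the same typing pattern: `DataChainCache` is imported only under `if TYPE_CHECKING:` and referenced as the string annotation `"DataChainCache"`, so type checkers see the real class while `datachain.cache` is never imported at runtime. A minimal sketch of the pattern, with an illustrative class body that is not part of the diff:

```py
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; no runtime import (and no import cycle).
    from datachain.cache import DataChainCache


class FileClient:
    def __init__(self, name: str, cache: "DataChainCache") -> None:
        # The quoted annotation is resolved lazily, so this module loads
        # even though datachain.cache was never imported at runtime.
        self.name = name
        self.cache = cache
```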
datachain/func/__init__.py
CHANGED
```diff
@@ -17,8 +17,9 @@ from .aggregate import (
 )
 from .array import cosine_distance, euclidean_distance, length, sip_hash_64
 from .conditional import greatest, least
-from .numeric import bit_and, bit_or, bit_xor, int_hash_64
+from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
 from .random import rand
+from .string import byte_hamming_distance
 from .window import window
 
 __all__ = [
@@ -26,8 +27,10 @@ __all__ = [
     "array",
     "avg",
     "bit_and",
+    "bit_hamming_distance",
     "bit_or",
     "bit_xor",
+    "byte_hamming_distance",
     "case",
     "collect",
     "concat",
```
datachain/func/numeric.py
CHANGED
```diff
@@ -160,3 +160,49 @@ def int_hash_64(col: Union[ColT, int]) -> Func:
     return Func(
         "int_hash_64", inner=numeric.int_hash_64, cols=cols, args=args, result_type=int
     )
+
+
+def bit_hamming_distance(*args: Union[ColT, int]) -> Func:
+    """
+    Computes the Hamming distance between the bit representations of two integer values.
+
+    The Hamming distance is the number of positions at which the corresponding bits
+    are different. This function returns the dissimilarity between the integers,
+    where 0 indicates identical integers and values closer to the number of bits
+    in the integer indicate higher dissimilarity.
+
+    Args:
+        args (str | int): Two integers to compute the Hamming distance between.
+            If a str is provided, it is assumed to be the name of the column.
+            If an int is provided, it is assumed to be an integer literal.
+
+    Returns:
+        Func: A Func object that represents the Hamming distance function.
+
+    Example:
+        ```py
+        dc.mutate(
+            ham_dist=func.bit_hamming_distance("embed1", 123456),
+        )
+        ```
+
+    Notes:
+        - Result column will always be of type int.
+    """
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, int):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    if len(cols) + len(func_args) != 2:
+        raise ValueError("bit_hamming_distance() requires exactly two arguments")
+
+    return Func(
+        "bit_hamming_distance",
+        inner=numeric.bit_hamming_distance,
+        cols=cols,
+        args=func_args,
+        result_type=int,
+    )
```
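Paired with `int_hash_64` or a perceptual hash stored as an integer column, this enables near-duplicate filtering directly in a query. A hedged usage sketch; the dataset name `images`, the `phash` column, and the threshold are assumptions for illustration:

```py
from datachain import C, DataChain, func

# Hypothetical dataset with an integer perceptual-hash signal "phash".
near_dupes = (
    DataChain.from_dataset("images")
    .mutate(ham_dist=func.bit_hamming_distance("phash", 123456))
    .filter(C("ham_dist") < 8)  # small distance = near-identical hashes
)
```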
datachain/func/string.py
CHANGED
```diff
@@ -152,3 +152,49 @@ def regexp_replace(col: Union[str, Func], regex: str, replacement: str) -> Func:
         args = None
 
     return Func("regexp_replace", inner=inner, cols=cols, args=args, result_type=str)
+
+
+def byte_hamming_distance(*args: Union[str, Func]) -> Func:
+    """
+    Computes the Hamming distance between two strings.
+
+    The Hamming distance is the number of positions at which the corresponding
+    characters are different. This function returns the dissimilarity between
+    the strings, where 0 indicates identical strings and values closer to the length
+    of the strings indicate higher dissimilarity.
+
+    Args:
+        args (str | literal): Two strings to compute the Hamming distance between.
+            If a str is provided, it is assumed to be the name of the column.
+            If a Literal is provided, it is assumed to be a string literal.
+
+    Returns:
+        Func: A Func object that represents the Hamming distance function.
+
+    Example:
+        ```py
+        dc.mutate(
+            ham_dist=func.byte_hamming_distance("file.phash", literal("hello")),
+        )
+        ```
+
+    Notes:
+        - Result column will always be of type int.
+    """
+    cols, func_args = [], []
+    for arg in args:
+        if get_origin(arg) is literal:
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    if len(cols) + len(func_args) != 2:
+        raise ValueError("byte_hamming_distance() requires exactly two arguments")
+
+    return Func(
+        "byte_hamming_distance",
+        inner=string.byte_hamming_distance,
+        cols=cols,
+        args=func_args,
+        result_type=int,
+    )
```
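The string variant follows the same shape, with `literal()` marking string constants so they are not mistaken for column names. A hedged sketch along the lines of the docstring example, assuming `literal` is importable from `datachain.func` as that example implies; dataset and column names are illustrative:

```py
from datachain import DataChain, func
from datachain.func import literal

# Hypothetical dataset with a hex-digest string signal "file.phash".
chain = DataChain.from_dataset("images").mutate(
    ham_dist=func.byte_hamming_distance("file.phash", literal("cafef00d")),
)
```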
datachain/lib/convert/flatten.py
CHANGED
```diff
@@ -1,19 +1,21 @@
+from collections.abc import Generator
+
 from pydantic import BaseModel
 
 from datachain.lib.model_store import ModelStore
 
 
-def flatten(obj: BaseModel):
+def flatten(obj: BaseModel) -> tuple:
     return tuple(_flatten_fields_values(obj.model_fields, obj))
 
 
-def flatten_list(obj_list):
+def flatten_list(obj_list: list[BaseModel]) -> tuple:
     return tuple(
         val for obj in obj_list for val in _flatten_fields_values(obj.model_fields, obj)
     )
 
 
-def _flatten_list_field(value: list):
+def _flatten_list_field(value: list) -> list:
     assert isinstance(value, list)
     if value and ModelStore.is_pydantic(type(value[0])):
         return [val.model_dump() for val in value]
@@ -22,7 +24,7 @@ def _flatten_list_field(value: list):
     return value
 
 
-def _flatten_fields_values(fields, obj: BaseModel):
+def _flatten_fields_values(fields: dict, obj: BaseModel) -> Generator:
     for name, f_info in fields.items():
         anno = f_info.annotation
         # Optimization: Access attributes directly to skip the model_dump() call.
@@ -40,5 +42,5 @@ def _flatten_fields_values(fields, obj: BaseModel):
     yield value
 
 
-def _flatten(obj):
+def _flatten(obj: BaseModel) -> tuple:
     return tuple(_flatten_fields_values(obj.model_fields, obj))
```

datachain/lib/convert/unflatten.py
CHANGED

```diff
@@ -9,12 +9,12 @@ from pydantic import BaseModel
 from datachain.query.schema import DEFAULT_DELIMITER
 
 
-def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos=0) -> dict:
+def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos: int = 0) -> dict:
     return unflatten_to_json_pos(model, row, pos)[0]
 
 
 def unflatten_to_json_pos(
-    model: type[BaseModel], row: Sequence[Any], pos=0
+    model: type[BaseModel], row: Sequence[Any], pos: int = 0
 ) -> tuple[dict, int]:
     res = {}
     for name, f_info in model.model_fields.items():
```

datachain/lib/convert/values_to_tuples.py
CHANGED

```diff
@@ -11,7 +11,7 @@ from datachain.lib.utils import DataChainParamsError
 
 
 class ValuesToTupleError(DataChainParamsError):
-    def __init__(self, ds_name, msg):
+    def __init__(self, ds_name: str, msg: str):
         if ds_name:
             ds_name = f"' {ds_name}'"
         super().__init__(f"Cannot convert signals for dataset{ds_name}: {msg}")
```
datachain/lib/dc.py
CHANGED
```diff
@@ -19,7 +19,6 @@ from typing import (
 )
 
 import orjson
-import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
@@ -57,6 +56,7 @@ from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
 
 if TYPE_CHECKING:
+    import pandas as pd
     from pyarrow import DataType as ArrowDataType
     from typing_extensions import Concatenate, ParamSpec, Self
 
@@ -1701,6 +1701,8 @@ class DataChain:
         Parameters:
             flatten : Whether to use a multiindex or flatten column names.
         """
+        import pandas as pd
+
         headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
             columns = [".".join(filter(None, header)) for header in headers]
@@ -1724,6 +1726,8 @@ class DataChain:
             transpose : Whether to transpose rows and columns.
            truncate : Whether or not to truncate the contents of columns.
         """
+        import pandas as pd
+
         dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
         df = dc.to_pandas(flatten)
 
```
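The `pandas` move is one instance of a pattern applied across this release (`pyarrow.dataset` in `file.py`, `datamodel_code_generator` in `meta_formats.py`, and two `datachain.catalog` names in `query/dataset.py`): heavy modules are imported inside the functions that need them, with a `TYPE_CHECKING` import keeping annotations intact, so `import datachain` itself gets cheaper. The pattern in isolation, on a hypothetical module:

```py
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd  # annotation-only import; free at runtime


def to_pandas(rows: list[dict]) -> "pd.DataFrame":
    # Deferred import: pandas loads on first call, not when this module loads.
    import pandas as pd

    return pd.DataFrame(rows)
```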
datachain/lib/file.py
CHANGED
```diff
@@ -17,7 +17,6 @@ from urllib.request import url2pathname
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
-from pyarrow.dataset import dataset
 from pydantic import Field, field_validator
 
 from datachain.client.fileslice import FileSlice
@@ -452,6 +451,8 @@ class ArrowRow(DataModel):
     @contextmanager
     def open(self):
         """Stream row contents from indexed file."""
+        from pyarrow.dataset import dataset
+
         if self.file._caching_enabled:
             self.file.ensure_cached()
             path = self.file.get_local_path()
```
datachain/lib/meta_formats.py
CHANGED
```diff
@@ -6,7 +6,6 @@ from collections.abc import Iterator
 from pathlib import Path
 from typing import Callable
 
-import datamodel_code_generator
 import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401
 
@@ -67,6 +66,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
         data_type = "json"  # treat json line as plain JSON in auto-schema
         data_string = json.dumps(json_object)
 
+    import datamodel_code_generator
+
     input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType}
     input_file_type = input_file_types[data_type]
     with tempfile.TemporaryDirectory() as tmpdir:
```
datachain/lib/pytorch.py
CHANGED
```diff
@@ -7,7 +7,6 @@ from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
-from tqdm import tqdm
 
 from datachain import Session
 from datachain.asyn import AsyncMapper
@@ -112,10 +111,7 @@ class PytorchDataset(IterableDataset):
         from datachain.lib.udf import _prefetch_input
 
         rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-
-        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
-            yield from map(self._process_row, rows_it)
+        yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
         row = []
```
datachain/lib/signal_schema.py
CHANGED
```diff
@@ -402,9 +402,20 @@ class SignalSchema:
         if ModelStore.is_pydantic(finfo.annotation):
             SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
 
-    def get_column_type(self, col_name: str) -> DataType:
+    def get_column_type(self, col_name: str, with_subtree: bool = False) -> DataType:
+        """
+        Returns column type by column name.
+
+        If `with_subtree` is True, then it will return the type of the column
+        even if it has a subtree (e.g. model with nested fields), otherwise it will
+        return the type of the column (standard type field, not the model).
+
+        If column is not found, raises `SignalResolvingError`.
+        """
         for path, _type, has_subtree, _ in self.get_flat_tree():
-            if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
+            if (with_subtree or not has_subtree) and DEFAULT_DELIMITER.join(
+                path
+            ) == col_name:
                 return _type
         raise SignalResolvingError([col_name], "is not found")
 
@@ -492,14 +503,25 @@ class SignalSchema:
                 # renaming existing signal
                 del new_values[value.name]
                 new_values[name] = self.values[value.name]
-            elif isinstance(value, Func):
+                continue
+            if isinstance(value, Column):
+                # adding new signal from existing signal field
+                try:
+                    new_values[name] = self.get_column_type(
+                        value.name, with_subtree=True
+                    )
+                    continue
+                except SignalResolvingError:
+                    pass
+            if isinstance(value, Func):
                 # adding new signal with function
                 new_values[name] = value.get_result_type(self)
-            elif isinstance(value, ColumnElement):
+                continue
+            if isinstance(value, ColumnElement):
                 # adding new signal
                 new_values[name] = sql_to_python(value)
-            else:
-                new_values[name] = value
+                continue
+            new_values[name] = value
 
         return SignalSchema(new_values)
 
```
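The `if`/`elif`/`else` chain becomes a series of `if` blocks ending in `continue`, which is what lets the new `Column` branch fall through to the later branches when `get_column_type` raises; with `with_subtree=True`, a new signal can now be derived from an existing field, including a whole nested model. A hedged sketch of what each branch handles at the API level; dataset and column names are assumptions:

```py
from datachain import C, DataChain, func

chain = DataChain.from_dataset("images").mutate(
    fname=C("file.path"),           # Column: type copied from the existing signal
    rnd=func.rand(),                # Func: type from value.get_result_type(schema)
    size_kb=C("file.size") / 1024,  # ColumnElement: type via sql_to_python(value)
)
```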
datachain/lib/utils.py
CHANGED
datachain/query/dataset.py
CHANGED
```diff
@@ -35,7 +35,6 @@ from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select
 
 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
-from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -215,7 +214,7 @@ class DatasetDiffOperation(Step):
     Should return select query that calculates desired diff between dataset queries
     """
 
-    def apply(self, query_generator, temp_tables: list[str]):
+    def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
         source_query = query_generator.exclude(("sys__id",))
         target_query = self.dq.apply_steps().select()
         temp_tables.extend(self.dq.temp_table_names)
@@ -394,6 +393,8 @@ class UDFStep(Step, ABC):
     """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         workers = self.workers
@@ -1087,6 +1088,8 @@ class DatasetQuery:
     def delete(
         name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None
     ) -> None:
+        from datachain.catalog import get_catalog
+
         catalog = catalog or get_catalog()
         version = version or catalog.get_dataset(name).latest_version
         catalog.remove_dataset(name, version)
```

datachain/sql/functions/numeric.py
CHANGED

```diff
@@ -35,9 +35,21 @@ class int_hash_64(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
+class bit_hamming_distance(GenericFunction):  # noqa: N801
+    """
+    Returns the Hamming distance between two integers.
+    """
+
+    type = Int64()
+    package = "numeric"
+    name = "hamming_distance"
+    inherit_cache = True
+
+
 compiler_not_implemented(bit_and)
 compiler_not_implemented(bit_or)
 compiler_not_implemented(bit_xor)
 compiler_not_implemented(bit_rshift)
 compiler_not_implemented(bit_lshift)
 compiler_not_implemented(int_hash_64)
+compiler_not_implemented(bit_hamming_distance)
```

datachain/sql/functions/string.py
CHANGED

```diff
@@ -48,7 +48,19 @@ class replace(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
+class byte_hamming_distance(GenericFunction):  # noqa: N801
+    """
+    Returns the Hamming distance between two strings.
+    """
+
+    type = Int64()
+    package = "string"
+    name = "hamming_distance"
+    inherit_cache = True
+
+
 compiler_not_implemented(length)
 compiler_not_implemented(split)
 compiler_not_implemented(regexp_replace)
 compiler_not_implemented(replace)
+compiler_not_implemented(byte_hamming_distance)
```
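Both new SQL functions are declared as SQLAlchemy `GenericFunction` subclasses with `compiler_not_implemented` as the default, so each dialect must explicitly opt in (SQLite does, below). A standalone sketch of the mechanism using plain SQLAlchemy; the `package`/`name` attributes mirror the diff, while the `Integer` type and the compile-hook body are simplified assumptions rather than datachain's exact code:

```py
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.types import Integer


class bit_hamming_distance(GenericFunction):  # noqa: N801
    type = Integer()            # result type reported to SQLAlchemy
    package = "numeric"         # registry namespace, not the rendered SQL name
    name = "hamming_distance"
    inherit_cache = True


@compiles(bit_hamming_distance, "sqlite")
def _compile_sqlite(element, compiler, **kwargs):
    # Re-emit the node as a call to the UDF registered on the connection.
    return compiler.process(
        sa.func.bit_hamming_distance(*element.clauses.clauses), **kwargs
    )


expr = bit_hamming_distance(sa.column("a"), sa.column("b"))
print(expr.compile(dialect=sqlite.dialect()))  # bit_hamming_distance(a, b)
```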
datachain/sql/sqlite/base.py
CHANGED
```diff
@@ -90,6 +90,7 @@ def setup():
     compiles(string.split, "sqlite")(compile_string_split)
     compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
     compiles(string.replace, "sqlite")(compile_string_replace)
+    compiles(string.byte_hamming_distance, "sqlite")(compile_byte_hamming_distance)
     compiles(conditional.greatest, "sqlite")(compile_greatest)
     compiles(conditional.least, "sqlite")(compile_least)
     compiles(Values, "sqlite")(compile_values)
@@ -104,6 +105,7 @@ def setup():
     compiles(numeric.bit_rshift, "sqlite")(compile_bitwise_rshift)
     compiles(numeric.bit_lshift, "sqlite")(compile_bitwise_lshift)
     compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
+    compiles(numeric.bit_hamming_distance, "sqlite")(compile_bit_hamming_distance)
 
     if load_usearch_extension(sqlite3.connect(":memory:")):
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -191,6 +193,26 @@ def sqlite_int_hash_64(x: int) -> int:
     return x if x < 1 << 63 else (x & MAX_INT64) - (1 << 64)
 
 
+def sqlite_bit_hamming_distance(a: int, b: int) -> int:
+    """Calculate the Hamming distance between two integers."""
+    diff = (a & MAX_INT64) ^ (b & MAX_INT64)
+    if hasattr(diff, "bit_count"):
+        return diff.bit_count()
+    return bin(diff).count("1")
+
+
+def sqlite_byte_hamming_distance(a: str, b: str) -> int:
+    """Calculate the Hamming distance between two strings."""
+    diff = 0
+    if len(a) < len(b):
+        diff = len(b) - len(a)
+        b = b[: len(a)]
+    elif len(b) < len(a):
+        diff = len(a) - len(b)
+        a = a[: len(b)]
+    return diff + sum(c1 != c2 for c1, c2 in zip(a, b))
+
+
 def register_user_defined_sql_functions() -> None:
     # Register optional functions if we have the necessary dependencies
     # and otherwise register functions that will raise an exception with
@@ -225,6 +247,9 @@ def register_user_defined_sql_functions() -> None:
         "bitwise_lshift", 2, lambda a, b: a << b, deterministic=True
     )
     conn.create_function("int_hash_64", 1, sqlite_int_hash_64, deterministic=True)
+    conn.create_function(
+        "bit_hamming_distance", 2, sqlite_bit_hamming_distance, deterministic=True
+    )
 
     _registered_function_creators["numeric_functions"] = create_numeric_functions
 
@@ -237,6 +262,9 @@ def register_user_defined_sql_functions() -> None:
         conn.create_function(
             "regexp_replace", 3, sqlite_regexp_replace, deterministic=True
         )
+        conn.create_function(
+            "byte_hamming_distance", 2, sqlite_byte_hamming_distance, deterministic=True
+        )
 
     _registered_function_creators["string_functions"] = create_string_functions
 
@@ -383,6 +411,18 @@ def compile_int_hash_64(element, compiler, **kwargs):
     return compiler.process(func.int_hash_64(*element.clauses.clauses), **kwargs)
 
 
+def compile_bit_hamming_distance(element, compiler, **kwargs):
+    return compiler.process(
+        func.bit_hamming_distance(*element.clauses.clauses), **kwargs
+    )
+
+
+def compile_byte_hamming_distance(element, compiler, **kwargs):
+    return compiler.process(
+        func.byte_hamming_distance(*element.clauses.clauses), **kwargs
+    )
+
+
 def py_json_array_length(arr):
     return len(orjson.loads(arr))
```
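Both UDFs are pure Python and easy to sanity-check outside SQLite. A standalone replica of their logic; `MAX_INT64` is assumed to be `2**64 - 1`, matching how `sqlite_int_hash_64` masks with it:

```py
MAX_INT64 = 2**64 - 1  # assumption: mirrors the module-level constant


def bit_hamming(a: int, b: int) -> int:
    # XOR keeps exactly the differing bits; count them (popcount).
    diff = (a & MAX_INT64) ^ (b & MAX_INT64)
    return bin(diff).count("1")  # int.bit_count() where available


def byte_hamming(a: str, b: str) -> int:
    # Extra length counts as distance; the overlap is compared pairwise.
    overlap = min(len(a), len(b))
    return abs(len(a) - len(b)) + sum(
        c1 != c2 for c1, c2 in zip(a[:overlap], b[:overlap])
    )


assert bit_hamming(0b1011, 0b0010) == 2
assert byte_hamming("karolin", "kathrin") == 3
assert byte_hamming("abc", "abcdef") == 3
```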
datachain/toolkit/split.py
CHANGED
```diff
@@ -1,7 +1,16 @@
+import random
+from typing import Optional
+
 from datachain import C, DataChain
 
+RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
+
 
-def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
+def train_test_split(
+    dc: DataChain,
+    weights: list[float],
+    seed: Optional[int] = None,
+) -> list[DataChain]:
     """
     Splits a DataChain into multiple subsets based on the provided weights.
 
@@ -18,6 +27,8 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
         For example:
         - `[0.7, 0.3]` corresponds to a 70/30 split;
         - `[2, 1, 1]` corresponds to a 50/25/25 split.
+        seed (int, optional):
+            The seed for the random number generator. Defaults to None.
 
     Returns:
         list[DataChain]:
@@ -58,14 +69,16 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
 
     weights_normalized = [weight / sum(weights) for weight in weights]
 
-    resolution = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
+    rand_col = C("sys.rand")
+    if seed is not None:
+        uniform_seed = random.Random(seed).randrange(1, RESOLUTION)  # noqa: S311
+        rand_col = (rand_col % RESOLUTION) * uniform_seed  # type: ignore[assignment]
+    rand_col = rand_col % RESOLUTION  # type: ignore[assignment]
 
     return [
         dc.filter(
-            C("sys__rand") % resolution
-            >= round(sum(weights_normalized[:index]) * resolution),
-            C("sys__rand") % resolution
-            < round(sum(weights_normalized[: index + 1]) * resolution),
+            rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
+            rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
         )
        for index, _ in enumerate(weights_normalized)
     ]
```