datachain 0.7.9-py3-none-any.whl → 0.7.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of datachain might be problematic.

@@ -1,4 +1,3 @@
 from .fsspec import Client
-from .s3 import ClientS3

-__all__ = ["Client", "ClientS3"]
+__all__ = ["Client"]
@@ -172,7 +172,7 @@ class Client(ABC):
         return url == cls.PREFIX

     @classmethod
-    def get_uri(cls, name) -> "StorageURI":
+    def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI

         return StorageURI(f"{cls.PREFIX}{name}")
@@ -278,7 +278,9 @@ class Client(ABC):
     ) -> None:
         await self._fetch_nested(start_prefix, result_queue)

-    async def _fetch_dir(self, prefix, pbar, result_queue: ResultQueue) -> set[str]:
+    async def _fetch_dir(
+        self, prefix: str, pbar, result_queue: ResultQueue
+    ) -> set[str]:
         path = f"{self.name}/{prefix}"
         infos = await self.ls_dir(path)
         files = []
datachain/client/local.py CHANGED
@@ -12,6 +12,7 @@ from datachain.lib.file import File
 from .fsspec import Client

 if TYPE_CHECKING:
+    from datachain.cache import DataChainCache
     from datachain.dataset import StorageURI


@@ -21,7 +22,11 @@ class FileClient(Client):
     protocol = "file"

     def __init__(
-        self, name: str, fs_kwargs: dict[str, Any], cache, use_symlinks: bool = False
+        self,
+        name: str,
+        fs_kwargs: dict[str, Any],
+        cache: "DataChainCache",
+        use_symlinks: bool = False,
     ) -> None:
         super().__init__(name, fs_kwargs, cache)
         self.use_symlinks = use_symlinks
@@ -30,7 +35,7 @@ class FileClient(Client):
         raise TypeError("Signed urls are not implemented for local file system")

     @classmethod
-    def get_uri(cls, name) -> "StorageURI":
+    def get_uri(cls, name: str) -> "StorageURI":
         from datachain.dataset import StorageURI

         return StorageURI(f'{cls.PREFIX}/{name.removeprefix("/")}')
@@ -77,7 +82,7 @@ class FileClient(Client):
         return bucket, path

     @classmethod
-    def from_name(cls, name: str, cache, kwargs) -> "FileClient":
+    def from_name(cls, name: str, cache: "DataChainCache", kwargs) -> "FileClient":
         use_symlinks = kwargs.pop("use_symlinks", False)
         return cls(name, kwargs, cache, use_symlinks=use_symlinks)

@@ -85,7 +90,7 @@
     def from_source(
         cls,
         uri: str,
-        cache,
+        cache: "DataChainCache",
         use_symlinks: bool = False,
         **kwargs,
     ) -> "FileClient":
@@ -17,8 +17,9 @@ from .aggregate import (
 )
 from .array import cosine_distance, euclidean_distance, length, sip_hash_64
 from .conditional import greatest, least
-from .numeric import bit_and, bit_or, bit_xor, int_hash_64
+from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
 from .random import rand
+from .string import byte_hamming_distance
 from .window import window

 __all__ = [
@@ -26,8 +27,10 @@ __all__ = [
     "array",
     "avg",
     "bit_and",
+    "bit_hamming_distance",
     "bit_or",
     "bit_xor",
+    "byte_hamming_distance",
     "case",
     "collect",
     "concat",
datachain/func/numeric.py CHANGED
@@ -160,3 +160,49 @@ def int_hash_64(col: Union[ColT, int]) -> Func:
     return Func(
         "int_hash_64", inner=numeric.int_hash_64, cols=cols, args=args, result_type=int
     )
+
+
+def bit_hamming_distance(*args: Union[ColT, int]) -> Func:
+    """
+    Computes the Hamming distance between the bit representations of two integer values.
+
+    The Hamming distance is the number of positions at which the corresponding bits
+    are different. This function returns the dissimilarity between the integers,
+    where 0 indicates identical integers and values closer to the number of bits
+    in the integer indicate higher dissimilarity.
+
+    Args:
+        args (str | int): Two integers to compute the Hamming distance between.
+            If a str is provided, it is assumed to be the name of the column.
+            If an int is provided, it is assumed to be an integer literal.
+
+    Returns:
+        Func: A Func object that represents the Hamming distance function.
+
+    Example:
+        ```py
+        dc.mutate(
+            ham_dist=func.bit_hamming_distance("embed1", 123456),
+        )
+        ```
+
+    Notes:
+        - Result column will always be of type int.
+    """
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, int):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    if len(cols) + len(func_args) != 2:
+        raise ValueError("bit_hamming_distance() requires exactly two arguments")
+
+    return Func(
+        "bit_hamming_distance",
+        inner=numeric.bit_hamming_distance,
+        cols=cols,
+        args=func_args,
+        result_type=int,
+    )
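For readers skimming the diff, the underlying operation is just the popcount of the XOR of the two integers. A plain-Python sketch of the same semantics (illustrative only, not part of the package; `int.bit_count()` needs Python 3.10+, which is why the SQLite helper further down falls back to `bin(...).count("1")`):

```py
def bit_hamming_distance_py(a: int, b: int) -> int:
    # Number of bit positions where the two integers differ.
    return (a ^ b).bit_count()

assert bit_hamming_distance_py(0b1010, 0b0110) == 2  # bits 2 and 3 differ
assert bit_hamming_distance_py(123456, 123456) == 0  # identical values
```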
datachain/func/string.py CHANGED
@@ -152,3 +152,49 @@ def regexp_replace(col: Union[str, Func], regex: str, replacement: str) -> Func:
     args = None

     return Func("regexp_replace", inner=inner, cols=cols, args=args, result_type=str)
+
+
+def byte_hamming_distance(*args: Union[str, Func]) -> Func:
+    """
+    Computes the Hamming distance between two strings.
+
+    The Hamming distance is the number of positions at which the corresponding
+    characters are different. This function returns the dissimilarity between
+    the strings, where 0 indicates identical strings and values closer to the length
+    of the strings indicate higher dissimilarity.
+
+    Args:
+        args (str | literal): Two strings to compute the Hamming distance between.
+            If a str is provided, it is assumed to be the name of the column.
+            If a literal is provided, it is assumed to be a string literal.
+
+    Returns:
+        Func: A Func object that represents the Hamming distance function.
+
+    Example:
+        ```py
+        dc.mutate(
+            ham_dist=func.byte_hamming_distance("file.phash", literal("hello")),
+        )
+        ```
+
+    Notes:
+        - Result column will always be of type int.
+    """
+    cols, func_args = [], []
+    for arg in args:
+        if get_origin(arg) is literal:
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    if len(cols) + len(func_args) != 2:
+        raise ValueError("byte_hamming_distance() requires exactly two arguments")
+
+    return Func(
+        "byte_hamming_distance",
+        inner=string.byte_hamming_distance,
+        cols=cols,
+        args=func_args,
+        result_type=int,
+    )
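As with the numeric variant, a short plain-Python sketch of the semantics may help: differing character positions are counted over the common prefix and, as the SQLite fallback later in this diff shows, the length difference is added when the strings have unequal length. Illustrative only, not part of the package:

```py
def byte_hamming_distance_py(a: str, b: str) -> int:
    # Positions that differ over the common prefix, plus the length difference.
    return sum(c1 != c2 for c1, c2 in zip(a, b)) + abs(len(a) - len(b))

assert byte_hamming_distance_py("hello", "hallo") == 1
assert byte_hamming_distance_py("abc", "abcdef") == 3  # extra characters all count
```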
@@ -1,19 +1,21 @@
+from collections.abc import Generator
+
 from pydantic import BaseModel

 from datachain.lib.model_store import ModelStore


-def flatten(obj: BaseModel):
+def flatten(obj: BaseModel) -> tuple:
     return tuple(_flatten_fields_values(obj.model_fields, obj))


-def flatten_list(obj_list):
+def flatten_list(obj_list: list[BaseModel]) -> tuple:
     return tuple(
         val for obj in obj_list for val in _flatten_fields_values(obj.model_fields, obj)
     )


-def _flatten_list_field(value: list):
+def _flatten_list_field(value: list) -> list:
     assert isinstance(value, list)
     if value and ModelStore.is_pydantic(type(value[0])):
         return [val.model_dump() for val in value]
@@ -22,7 +24,7 @@ def _flatten_list_field(value: list):
     return value


-def _flatten_fields_values(fields, obj: BaseModel):
+def _flatten_fields_values(fields: dict, obj: BaseModel) -> Generator:
     for name, f_info in fields.items():
         anno = f_info.annotation
         # Optimization: Access attributes directly to skip the model_dump() call.
@@ -40,5 +42,5 @@ def _flatten_fields_values(fields, obj: BaseModel):
            yield value


-def _flatten(obj):
+def _flatten(obj: BaseModel) -> tuple:
     return tuple(_flatten_fields_values(obj.model_fields, obj))
@@ -9,12 +9,12 @@ from pydantic import BaseModel
 from datachain.query.schema import DEFAULT_DELIMITER


-def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos=0) -> dict:
+def unflatten_to_json(model: type[BaseModel], row: Sequence[Any], pos: int = 0) -> dict:
     return unflatten_to_json_pos(model, row, pos)[0]


 def unflatten_to_json_pos(
-    model: type[BaseModel], row: Sequence[Any], pos=0
+    model: type[BaseModel], row: Sequence[Any], pos: int = 0
 ) -> tuple[dict, int]:
     res = {}
     for name, f_info in model.model_fields.items():
@@ -11,7 +11,7 @@ from datachain.lib.utils import DataChainParamsError


 class ValuesToTupleError(DataChainParamsError):
-    def __init__(self, ds_name, msg):
+    def __init__(self, ds_name: str, msg: str):
         if ds_name:
             ds_name = f"' {ds_name}'"
         super().__init__(f"Cannot convert signals for dataset{ds_name}: {msg}")
datachain/lib/dc.py CHANGED
@@ -19,7 +19,6 @@ from typing import (
 )

 import orjson
-import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
 from sqlalchemy.sql.functions import GenericFunction
@@ -57,6 +56,7 @@ from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

 if TYPE_CHECKING:
+    import pandas as pd
     from pyarrow import DataType as ArrowDataType
     from typing_extensions import Concatenate, ParamSpec, Self

@@ -1701,6 +1701,8 @@ class DataChain:
         Parameters:
             flatten : Whether to use a multiindex or flatten column names.
         """
+        import pandas as pd
+
         headers, max_length = self._effective_signals_schema.get_headers_with_length()
         if flatten or max_length < 2:
             columns = [".".join(filter(None, header)) for header in headers]
@@ -1724,6 +1726,8 @@ class DataChain:
             transpose : Whether to transpose rows and columns.
             truncate : Whether or not to truncate the contents of columns.
         """
+        import pandas as pd
+
         dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
         df = dc.to_pandas(flatten)

datachain/lib/file.py CHANGED
@@ -17,7 +17,6 @@ from urllib.request import url2pathname

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
-from pyarrow.dataset import dataset
 from pydantic import Field, field_validator

 from datachain.client.fileslice import FileSlice
@@ -452,6 +451,8 @@ class ArrowRow(DataModel):
     @contextmanager
     def open(self):
         """Stream row contents from indexed file."""
+        from pyarrow.dataset import dataset
+
         if self.file._caching_enabled:
             self.file.ensure_cached()
             path = self.file.get_local_path()
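The dc.py changes above and this file.py hunk (and the similar ones below in the meta-formats and query modules) all follow the same pattern: heavy third-party imports are moved out of module scope and into the functions that need them, with a `TYPE_CHECKING`-only import kept for annotations. A minimal sketch of the pattern with a hypothetical class, for readers unfamiliar with it:

```py
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    import pandas as pd


class Example:  # hypothetical, not a datachain class
    def to_pandas(self) -> "pd.DataFrame":
        # The heavy dependency is imported on first call instead of at
        # import time, keeping package import fast when pandas is unused.
        import pandas as pd

        return pd.DataFrame({"x": [1, 2, 3]})
```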
@@ -6,7 +6,6 @@ from collections.abc import Iterator
 from pathlib import Path
 from typing import Callable

-import datamodel_code_generator
 import jmespath as jsp
 from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401

@@ -67,6 +66,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
         data_type = "json"  # treat json line as plain JSON in auto-schema
         data_string = json.dumps(json_object)

+    import datamodel_code_generator
+
     input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType}
     input_file_type = input_file_types[data_type]
     with tempfile.TemporaryDirectory() as tmpdir:
datachain/lib/pytorch.py CHANGED
@@ -7,7 +7,6 @@ from torch import float32
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
-from tqdm import tqdm

 from datachain import Session
 from datachain.asyn import AsyncMapper
@@ -112,10 +111,7 @@ class PytorchDataset(IterableDataset):
         from datachain.lib.udf import _prefetch_input

         rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-
-        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
-            yield from map(self._process_row, rows_it)
+        yield from map(self._process_row, rows)

     def _process_row(self, row_features):
         row = []
@@ -402,9 +402,20 @@ class SignalSchema:
             if ModelStore.is_pydantic(finfo.annotation):
                 SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)

-    def get_column_type(self, col_name: str) -> DataType:
+    def get_column_type(self, col_name: str, with_subtree: bool = False) -> DataType:
+        """
+        Returns column type by column name.
+
+        If `with_subtree` is True, then it will return the type of the column
+        even if it has a subtree (e.g. model with nested fields), otherwise it will
+        return the type of the column (standard type field, not the model).
+
+        If column is not found, raises `SignalResolvingError`.
+        """
         for path, _type, has_subtree, _ in self.get_flat_tree():
-            if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
+            if (with_subtree or not has_subtree) and DEFAULT_DELIMITER.join(
+                path
+            ) == col_name:
                 return _type
         raise SignalResolvingError([col_name], "is not found")

@@ -492,14 +503,25 @@ class SignalSchema:
                 # renaming existing signal
                 del new_values[value.name]
                 new_values[name] = self.values[value.name]
-            elif isinstance(value, Func):
+                continue
+            if isinstance(value, Column):
+                # adding new signal from existing signal field
+                try:
+                    new_values[name] = self.get_column_type(
+                        value.name, with_subtree=True
+                    )
+                    continue
+                except SignalResolvingError:
+                    pass
+            if isinstance(value, Func):
                 # adding new signal with function
                 new_values[name] = value.get_result_type(self)
-            elif isinstance(value, ColumnElement):
+                continue
+            if isinstance(value, ColumnElement):
                 # adding new signal
                 new_values[name] = sql_to_python(value)
-            else:
-                new_values[name] = value
+                continue
+            new_values[name] = value

         return SignalSchema(new_values)

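The practical effect of the new `Column` branch: `mutate()` can now create a new signal as a copy of an existing signal field, and the copy keeps the resolved column type (including model subtrees) instead of falling through to `sql_to_python`. A hedged usage sketch, assuming a chain `dc` with a standard `file` signal; the column names are illustrative:

```py
from datachain import C

# Copy a nested signal field into a new top-level signal; its type is taken
# from the schema via get_column_type(..., with_subtree=True).
dc = dc.mutate(path_copy=C("file.path"))

# A Column whose name is already a top-level signal still hits the original
# "renaming existing signal" branch shown above.
dc = dc.mutate(document=C("file"))
```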
datachain/lib/utils.py CHANGED
@@ -28,7 +28,7 @@ class DataChainParamsError(DataChainError):


 class DataChainColumnError(DataChainParamsError):
-    def __init__(self, col_name, msg):
+    def __init__(self, col_name: str, msg: str):
         super().__init__(f"Error for column {col_name}: {msg}")


@@ -35,7 +35,6 @@ from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select

 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
-from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -215,7 +214,7 @@ class DatasetDiffOperation(Step):
     Should return select query that calculates desired diff between dataset queries
     """

-    def apply(self, query_generator, temp_tables: list[str]):
+    def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
         source_query = query_generator.exclude(("sys__id",))
         target_query = self.dq.apply_steps().select()
         temp_tables.extend(self.dq.temp_table_names)
@@ -394,6 +393,8 @@ class UDFStep(Step, ABC):
        """

     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         workers = self.workers
@@ -1087,6 +1088,8 @@ class DatasetQuery:
     def delete(
         name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None
     ) -> None:
+        from datachain.catalog import get_catalog
+
         catalog = catalog or get_catalog()
         version = version or catalog.get_dataset(name).latest_version
         catalog.remove_dataset(name, version)
@@ -35,9 +35,21 @@ class int_hash_64(GenericFunction):  # noqa: N801
     inherit_cache = True


+class bit_hamming_distance(GenericFunction):  # noqa: N801
+    """
+    Returns the Hamming distance between two integers.
+    """
+
+    type = Int64()
+    package = "numeric"
+    name = "hamming_distance"
+    inherit_cache = True
+
+
 compiler_not_implemented(bit_and)
 compiler_not_implemented(bit_or)
 compiler_not_implemented(bit_xor)
 compiler_not_implemented(bit_rshift)
 compiler_not_implemented(bit_lshift)
 compiler_not_implemented(int_hash_64)
+compiler_not_implemented(bit_hamming_distance)
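These `GenericFunction` subclasses are what let the chain-level helpers render as a plain `hamming_distance(...)` call on SQL backends; `compiler_not_implemented` means a backend without an explicit compiler raises instead of emitting unsupported SQL, and the SQLite hooks further down reroute the call to Python UDFs. A small standalone SQLAlchemy sketch of the mechanism (hypothetical class and package name, not datachain's module):

```py
from sqlalchemy import column, func, select
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy.types import Integer


class hamming_distance(GenericFunction):  # noqa: N801
    type = Integer()
    package = "demo"
    name = "hamming_distance"
    inherit_cache = True


# Defining the subclass registers it, so it is reachable via func.<package>.<name>
# and renders as a regular SQL function call with a typed return value.
stmt = select(func.demo.hamming_distance(column("a"), column("b")))
print(stmt)  # SELECT hamming_distance(a, b) AS hamming_distance_1
```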
@@ -48,7 +48,19 @@ class replace(GenericFunction):  # noqa: N801
     inherit_cache = True


+class byte_hamming_distance(GenericFunction):  # noqa: N801
+    """
+    Returns the Hamming distance between two strings.
+    """
+
+    type = Int64()
+    package = "string"
+    name = "hamming_distance"
+    inherit_cache = True
+
+
 compiler_not_implemented(length)
 compiler_not_implemented(split)
 compiler_not_implemented(regexp_replace)
 compiler_not_implemented(replace)
+compiler_not_implemented(byte_hamming_distance)
@@ -90,6 +90,7 @@ def setup():
     compiles(string.split, "sqlite")(compile_string_split)
     compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
     compiles(string.replace, "sqlite")(compile_string_replace)
+    compiles(string.byte_hamming_distance, "sqlite")(compile_byte_hamming_distance)
     compiles(conditional.greatest, "sqlite")(compile_greatest)
     compiles(conditional.least, "sqlite")(compile_least)
     compiles(Values, "sqlite")(compile_values)
@@ -104,6 +105,7 @@ def setup():
     compiles(numeric.bit_rshift, "sqlite")(compile_bitwise_rshift)
     compiles(numeric.bit_lshift, "sqlite")(compile_bitwise_lshift)
     compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
+    compiles(numeric.bit_hamming_distance, "sqlite")(compile_bit_hamming_distance)

     if load_usearch_extension(sqlite3.connect(":memory:")):
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -191,6 +193,26 @@ def sqlite_int_hash_64(x: int) -> int:
     return x if x < 1 << 63 else (x & MAX_INT64) - (1 << 64)


+def sqlite_bit_hamming_distance(a: int, b: int) -> int:
+    """Calculate the Hamming distance between two integers."""
+    diff = (a & MAX_INT64) ^ (b & MAX_INT64)
+    if hasattr(diff, "bit_count"):
+        return diff.bit_count()
+    return bin(diff).count("1")
+
+
+def sqlite_byte_hamming_distance(a: str, b: str) -> int:
+    """Calculate the Hamming distance between two strings."""
+    diff = 0
+    if len(a) < len(b):
+        diff = len(b) - len(a)
+        b = b[: len(a)]
+    elif len(b) < len(a):
+        diff = len(a) - len(b)
+        a = a[: len(b)]
+    return diff + sum(c1 != c2 for c1, c2 in zip(a, b))
+
+
 def register_user_defined_sql_functions() -> None:
     # Register optional functions if we have the necessary dependencies
     # and otherwise register functions that will raise an exception with
@@ -225,6 +247,9 @@ def register_user_defined_sql_functions() -> None:
             "bitwise_lshift", 2, lambda a, b: a << b, deterministic=True
         )
         conn.create_function("int_hash_64", 1, sqlite_int_hash_64, deterministic=True)
+        conn.create_function(
+            "bit_hamming_distance", 2, sqlite_bit_hamming_distance, deterministic=True
+        )

     _registered_function_creators["numeric_functions"] = create_numeric_functions

@@ -237,6 +262,9 @@ def register_user_defined_sql_functions() -> None:
         conn.create_function(
             "regexp_replace", 3, sqlite_regexp_replace, deterministic=True
         )
+        conn.create_function(
+            "byte_hamming_distance", 2, sqlite_byte_hamming_distance, deterministic=True
+        )

     _registered_function_creators["string_functions"] = create_string_functions

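Since the registration goes through the standard `sqlite3` `create_function` API, the new UDFs can be tried in isolation against an in-memory database; a minimal standalone sketch with a simplified reimplementation, independent of datachain's connection setup:

```py
import sqlite3


def bit_hamming(a: int, b: int) -> int:
    # Simplified: the package version also masks both operands to 64 bits.
    return bin(a ^ b).count("1")


conn = sqlite3.connect(":memory:")
conn.create_function("bit_hamming_distance", 2, bit_hamming, deterministic=True)
print(conn.execute("SELECT bit_hamming_distance(10, 6)").fetchone())  # (2,)
```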
@@ -383,6 +411,18 @@ def compile_int_hash_64(element, compiler, **kwargs):
     return compiler.process(func.int_hash_64(*element.clauses.clauses), **kwargs)


+def compile_bit_hamming_distance(element, compiler, **kwargs):
+    return compiler.process(
+        func.bit_hamming_distance(*element.clauses.clauses), **kwargs
+    )
+
+
+def compile_byte_hamming_distance(element, compiler, **kwargs):
+    return compiler.process(
+        func.byte_hamming_distance(*element.clauses.clauses), **kwargs
+    )
+
+
 def py_json_array_length(arr):
     return len(orjson.loads(arr))

@@ -1,7 +1,16 @@
+import random
+from typing import Optional
+
 from datachain import C, DataChain

+RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
+

-def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
+def train_test_split(
+    dc: DataChain,
+    weights: list[float],
+    seed: Optional[int] = None,
+) -> list[DataChain]:
     """
     Splits a DataChain into multiple subsets based on the provided weights.

@@ -18,6 +27,8 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:
            For example:
            - `[0.7, 0.3]` corresponds to a 70/30 split;
            - `[2, 1, 1]` corresponds to a 50/25/25 split.
+        seed (int, optional):
+            The seed for the random number generator. Defaults to None.

     Returns:
         list[DataChain]:
@@ -58,14 +69,16 @@ def train_test_split(dc: DataChain, weights: list[float]) -> list[DataChain]:

     weights_normalized = [weight / sum(weights) for weight in weights]

-    resolution = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
+    rand_col = C("sys.rand")
+    if seed is not None:
+        uniform_seed = random.Random(seed).randrange(1, RESOLUTION)  # noqa: S311
+        rand_col = (rand_col % RESOLUTION) * uniform_seed  # type: ignore[assignment]
+    rand_col = rand_col % RESOLUTION  # type: ignore[assignment]

     return [
         dc.filter(
-            C("sys__rand") % resolution
-            >= round(sum(weights_normalized[:index]) * resolution),
-            C("sys__rand") % resolution
-            < round(sum(weights_normalized[: index + 1]) * resolution),
+            rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
+            rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
         )
         for index, _ in enumerate(weights_normalized)
     ]
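A hedged usage sketch of the new `seed` parameter, assuming `train_test_split` is imported from `datachain.toolkit` (the module this hunk modifies) and `dc` is an existing chain: passing the same seed assigns the same rows to the same splits across runs, while `seed=None` keeps the previous behavior.

```py
from datachain.toolkit import train_test_split

train, test = train_test_split(dc, [0.8, 0.2], seed=42)    # reproducible 80/20 split
train2, test2 = train_test_split(dc, [0.8, 0.2], seed=42)  # same rows land in each split
```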