datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/sql/sqlite/base.py CHANGED
@@ -2,20 +2,20 @@ import logging
  import re
  import sqlite3
  import warnings
- from collections.abc import Iterable
+ from collections.abc import Callable, Iterable
+ from contextlib import closing
  from datetime import MAXYEAR, MINYEAR, datetime, timezone
  from functools import cache
  from types import MappingProxyType
- from typing import Callable, Optional

  import sqlalchemy as sa
- import ujson as json
  from sqlalchemy.dialects import sqlite
  from sqlalchemy.ext.compiler import compiles
  from sqlalchemy.sql.elements import literal
  from sqlalchemy.sql.expression import case
  from sqlalchemy.sql.functions import func

+ from datachain import json
  from datachain.sql.functions import (
      aggregate,
      array,
@@ -112,7 +112,10 @@ def setup():
      compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
      compiles(numeric.bit_hamming_distance, "sqlite")(compile_bit_hamming_distance)

-     if load_usearch_extension(sqlite3.connect(":memory:")):
+     with closing(sqlite3.connect(":memory:")) as _usearch_conn:
+         usearch_available = load_usearch_extension(_usearch_conn)
+
+     if usearch_available:
          compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
          compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
      else:
@@ -132,7 +135,7 @@ def run_compiler_hook(name):


  def functions_exist(
-     names: Iterable[str], connection: Optional[sqlite3.Connection] = None
+     names: Iterable[str], connection: sqlite3.Connection | None = None
  ) -> bool:
      """
      Returns True if all function names are defined for the given connection.
@@ -146,23 +149,34 @@ def functions_exist(
              f"Found value of type {type(n).__name__}: {n!r}"
          )

+     close_connection = False
      if connection is None:
          connection = sqlite3.connect(":memory:")
+         close_connection = True

-     if not names:
-         return True
-     column1 = sa.column("column1", sa.String)
-     func_name_query = column1.not_in(
-         sa.select(sa.column("name", sa.String)).select_from(func.pragma_function_list())
-     )
-     query = (
-         sa.select(func.count() == 0)
-         .select_from(sa.values(column1).data([(n,) for n in names]))
-         .where(func_name_query)
-     )
-     comp = query.compile(dialect=sqlite_dialect)
-     args = (comp.string, comp.params) if comp.params else (comp.string,)
-     return bool(connection.execute(*args).fetchone()[0])
+     try:
+         if not names:
+             return True
+         column1 = sa.column("column1", sa.String)
+         func_name_query = column1.not_in(
+             sa.select(sa.column("name", sa.String)).select_from(
+                 func.pragma_function_list()
+             )
+         )
+         query = (
+             sa.select(func.count() == 0)
+             .select_from(sa.values(column1).data([(n,) for n in names]))
+             .where(func_name_query)
+         )
+         comp = query.compile(dialect=sqlite_dialect)
+         if comp.params:
+             result = connection.execute(comp.string, comp.params)
+         else:
+             result = connection.execute(comp.string)
+         return bool(result.fetchone()[0])
+     finally:
+         if close_connection:
+             connection.close()


  def create_user_defined_sql_functions(connection):
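For reference, the optional-connection path now cleans up after itself: when no connection is passed, the helper opens a temporary in-memory database and closes it in the `finally` block. A minimal usage sketch (not part of the diff, and assuming `functions_exist` stays importable from datachain.sql.sqlite.base, the module patched here):

import sqlite3

from datachain.sql.sqlite.base import functions_exist

conn = sqlite3.connect(":memory:")
print(functions_exist(["abs", "length"], conn))     # SQLite built-ins are listed: True
print(functions_exist(["no_such_function"], conn))  # False
conn.close()

# Without an explicit connection a throwaway in-memory one is opened and,
# after this change, also closed once the check finishes.
print(functions_exist(["abs"]))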
@@ -201,9 +215,7 @@ def sqlite_int_hash_64(x: int) -> int:
  def sqlite_bit_hamming_distance(a: int, b: int) -> int:
      """Calculate the Hamming distance between two integers."""
      diff = (a & MAX_INT64) ^ (b & MAX_INT64)
-     if hasattr(diff, "bit_count"):
-         return diff.bit_count()
-     return bin(diff).count("1")
+     return diff.bit_count()


  def sqlite_byte_hamming_distance(a: str, b: str) -> int:
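Dropping the bin(diff).count("1") fallback quietly assumes int.bit_count() is always available, i.e. Python 3.10 or newer. The XOR-then-popcount identity being relied on, as a quick standalone check (not from the package):

# Hamming distance between two ints = number of set bits in their XOR;
# int.bit_count() (Python 3.10+) counts those bits directly.
a, b = 0b1011, 0b0001
assert (a ^ b).bit_count() == bin(a ^ b).count("1") == 2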
@@ -215,7 +227,7 @@ def sqlite_byte_hamming_distance(a: str, b: str) -> int:
      elif len(b) < len(a):
          diff = len(a) - len(b)
          a = a[: len(b)]
-     return diff + sum(c1 != c2 for c1, c2 in zip(a, b))
+     return diff + sum(c1 != c2 for c1, c2 in zip(a, b, strict=False))


  def register_user_defined_sql_functions() -> None:
@@ -470,7 +482,7 @@ def py_json_array_get_element(val, idx):
          return None


- def py_json_array_slice(val, offset: int, length: Optional[int] = None):
+ def py_json_array_slice(val, offset: int, length: int | None = None):
      arr = json.loads(val)
      try:
          return json.dumps(
@@ -605,7 +617,7 @@ def compile_collect(element, compiler, **kwargs):


  @cache
- def usearch_sqlite_path() -> Optional[str]:
+ def usearch_sqlite_path() -> str | None:
      try:
          import usearch
      except ImportError:
datachain/sql/sqlite/types.py CHANGED
@@ -1,8 +1,8 @@
  import sqlite3

- import ujson as json
  from sqlalchemy import types

+ from datachain import json
  from datachain.sql.types import TypeConverter, TypeReadConverter

  try:
datachain/sql/types.py CHANGED
@@ -12,14 +12,15 @@ for sqlite we can use `sqlite.register_converter`
  ( https://docs.python.org/3/library/sqlite3.html#sqlite3.register_converter )
  """

+ import numbers
  from datetime import datetime
  from types import MappingProxyType
  from typing import Any, Union

  import sqlalchemy as sa
- import ujson as jsonlib
  from sqlalchemy import TypeDecorator, types

+ from datachain import json as jsonlib
  from datachain.lib.data_model import StandardType

  _registry: dict[str, "TypeConverter"] = {}
@@ -336,10 +337,28 @@ class Array(SQLType):

      @classmethod
      def from_dict(cls, d: dict[str, Any]) -> Union[type["SQLType"], "SQLType"]:
-         sub_t = NAME_TYPES_MAPPING[d["item_type"]["type"]].from_dict(  # type: ignore [attr-defined]
-             d["item_type"]
-         )
-         return cls(sub_t)
+         try:
+             array_item = d["item_type"]
+         except KeyError as e:
+             raise ValueError("Array type must have 'item_type' field") from e
+
+         if not isinstance(array_item, dict):
+             raise TypeError("Array 'item_type' field must be a dictionary")
+
+         try:
+             item_type = array_item["type"]
+         except KeyError as e:
+             raise ValueError("Array 'item_type' must have 'type' field") from e
+
+         try:
+             sub_t = NAME_TYPES_MAPPING[item_type]
+         except KeyError as e:
+             raise ValueError(f"Array item type '{item_type}' is not supported") from e
+
+         try:
+             return cls(sub_t.from_dict(d["item_type"]))  # type: ignore [attr-defined]
+         except KeyError as e:
+             raise ValueError(f"Array item type '{item_type}' is not supported") from e

      @staticmethod
      def default_value(dialect):
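The permissive chained lookup is replaced by explicit checks, so malformed serialized types now fail with a clear message instead of a bare KeyError. A short sketch of what gets rejected (hypothetical inputs, assuming `Array` remains importable from datachain.sql.types, the file being patched; per the code above a well-formed entry looks like {"item_type": {"type": "Int64"}} with a supported item type name):

from datachain.sql.types import Array

for bad in (
    {},                                  # missing 'item_type'          -> ValueError
    {"item_type": "Int64"},              # 'item_type' is not a dict    -> TypeError
    {"item_type": {"type": "NoSuch"}},   # unsupported item type name   -> ValueError
):
    try:
        Array.from_dict(bad)
    except (TypeError, ValueError) as exc:
        print(type(exc).__name__, exc)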
@@ -427,6 +446,18 @@ class TypeReadConverter:
          return value

      def boolean(self, value):
+         if value is None or isinstance(value, bool):
+             return value
+
+         if isinstance(value, numbers.Integral):
+             return bool(value)
+         if isinstance(value, str):
+             normalized = value.strip().lower()
+             if normalized in {"true", "t", "yes", "y", "1"}:
+                 return True
+             if normalized in {"false", "f", "no", "n", "0"}:
+                 return False
+
          return value

      def int(self, value):
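The read converter now normalizes the values a database driver or JSON payload may hand back for boolean columns; anything it does not recognize still falls through unchanged. A small sketch of the resulting behavior (not from the package, and assuming `TypeReadConverter` can be instantiated without arguments, as its plain class definition suggests):

from datachain.sql.types import TypeReadConverter

conv = TypeReadConverter()
print(conv.boolean(None))     # None, passed through
print(conv.boolean(1))        # True  (integral values become bools)
print(conv.boolean(" YES "))  # True  (strings are stripped and lowercased)
print(conv.boolean("f"))      # False
print(conv.boolean("maybe"))  # 'maybe', unrecognized values are returned as-is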
datachain/studio.py CHANGED
@@ -1,8 +1,9 @@
  import asyncio
  import os
  import sys
+ import warnings
  from datetime import datetime, timezone
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING

  import dateparser
  import tabulate
@@ -175,7 +176,7 @@ def token():
      print(token)


- def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
+ def list_datasets(team: str | None = None, name: str | None = None):
      def ds_full_name(ds: dict) -> str:
          return (
              f"{ds['project']['namespace']['name']}.{ds['project']['name']}.{ds['name']}"
@@ -206,7 +207,7 @@ def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
          yield (full_name, version)


- def list_dataset_versions(team: Optional[str] = None, name: str = ""):
+ def list_dataset_versions(team: str | None = None, name: str = ""):
      client = StudioClient(team=team)

      namespace_name, project_name, name = parse_dataset_name(name)
@@ -226,13 +227,13 @@ def list_dataset_versions(team: Optional[str] = None, name: str = ""):


  def edit_studio_dataset(
-     team_name: Optional[str],
+     team_name: str | None,
      name: str,
      namespace: str,
      project: str,
-     new_name: Optional[str] = None,
-     description: Optional[str] = None,
-     attrs: Optional[list[str]] = None,
+     new_name: str | None = None,
+     description: str | None = None,
+     attrs: list[str] | None = None,
  ):
      client = StudioClient(team=team_name)
      response = client.edit_dataset(
@@ -245,12 +246,12 @@ def edit_studio_dataset(


  def remove_studio_dataset(
-     team_name: Optional[str],
+     team_name: str | None,
      name: str,
      namespace: str,
      project: str,
-     version: Optional[str] = None,
-     force: Optional[bool] = False,
+     version: str | None = None,
+     force: bool | None = False,
  ):
      client = StudioClient(team=team_name)
      response = client.rm_dataset(name, namespace, project, version, force)
@@ -271,12 +272,21 @@ def save_config(hostname, token, level=ConfigLevel.GLOBAL):
      return config.config_file()


- def parse_start_time(start_time_str: Optional[str]) -> Optional[str]:
+ def parse_start_time(start_time_str: str | None) -> str | None:
      if not start_time_str:
          return None

-     # Parse the datetime string using dateparser
-     parsed_datetime = dateparser.parse(start_time_str)
+     # dateparser#1246: it explores strptime patterns lacking a year, which
+     # triggers a CPython 3.13 DeprecationWarning. Suppress that noise until a
+     # new dateparser release includes the upstream fix.
+     # https://github.com/scrapinghub/dateparser/issues/1246
+     with warnings.catch_warnings():
+         warnings.filterwarnings(
+             "ignore",
+             category=DeprecationWarning,
+             module="dateparser\\.utils\\.strptime",
+         )
+         parsed_datetime = dateparser.parse(start_time_str)

      if parsed_datetime is None:
          raise DataChainError(
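The suppression is scoped: warnings.catch_warnings() restores the previous filter state on exit, so the ignore rule above cannot leak into code that calls parse_start_time. A standalone illustration of that scoping (stdlib only, not from the package):

import warnings

def noisy() -> str:
    # stand-in for a call that emits a DeprecationWarning internally
    warnings.warn("deprecated pattern", DeprecationWarning, stacklevel=2)
    return "ok"

warnings.simplefilter("error", DeprecationWarning)  # make the warning loud by default
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    print(noisy())  # suppressed inside the block: prints "ok"

try:
    noisy()  # outside the block the earlier "error" filter is back in force
except DeprecationWarning as exc:
    print("raised again:", exc)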
@@ -343,21 +353,21 @@ def show_logs_from_client(client, job_id):

  def create_job(
      query_file: str,
-     team_name: Optional[str],
-     env_file: Optional[str] = None,
-     env: Optional[list[str]] = None,
-     workers: Optional[int] = None,
-     files: Optional[list[str]] = None,
-     python_version: Optional[str] = None,
-     repository: Optional[str] = None,
-     req: Optional[list[str]] = None,
-     req_file: Optional[str] = None,
-     priority: Optional[int] = None,
-     cluster: Optional[str] = None,
-     start_time: Optional[str] = None,
-     cron: Optional[str] = None,
-     no_wait: Optional[bool] = False,
-     credentials_name: Optional[str] = None,
+     team_name: str | None,
+     env_file: str | None = None,
+     env: list[str] | None = None,
+     workers: int | None = None,
+     files: list[str] | None = None,
+     python_version: str | None = None,
+     repository: str | None = None,
+     req: list[str] | None = None,
+     req_file: str | None = None,
+     priority: int | None = None,
+     cluster: str | None = None,
+     start_time: str | None = None,
+     cron: str | None = None,
+     no_wait: bool | None = False,
+     credentials_name: str | None = None,
  ):
      query_type = "PYTHON" if query_file.endswith(".py") else "SHELL"
      with open(query_file) as f:
@@ -403,14 +413,14 @@ def create_job(
      if not response.data:
          raise DataChainError("Failed to create job")

-     job_id = response.data.get("job", {}).get("id")
+     job_id = response.data.get("id")

      if parsed_start_time or cron:
          print(f"Job {job_id} is scheduled as a task in Studio.")
          return 0

      print(f"Job {job_id} created")
-     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+     print("Open the job in Studio at", response.data.get("url"))
      print("=" * 40)

      return 0 if no_wait else show_logs_from_client(client, job_id)
@@ -421,21 +431,19 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
      for file in files:
          file_name = os.path.basename(file)
          with open(file, "rb") as f:
-             file_content = f.read()
-             response = client.upload_file(file_content, file_name)
+             response = client.upload_file(f, file_name)
          if not response.ok:
              raise DataChainError(response.message)

          if not response.data:
              raise DataChainError(f"Failed to upload file {file_name}")

-         file_id = response.data.get("blob", {}).get("id")
-         if file_id:
+         if file_id := response.data.get("id"):
              file_ids.append(str(file_id))
      return file_ids


- def cancel_job(job_id: str, team_name: Optional[str]):
+ def cancel_job(job_id: str, team_name: str | None):
      token = Config().read().get("studio", {}).get("token")
      if not token:
          raise DataChainError(
@@ -450,13 +458,13 @@ def cancel_job(job_id: str, team_name: Optional[str]):
      print(f"Job {job_id} canceled")


- def list_jobs(status: Optional[str], team_name: Optional[str], limit: int):
+ def list_jobs(status: str | None, team_name: str | None, limit: int):
      client = StudioClient(team=team_name)
      response = client.get_jobs(status, limit)
      if not response.ok:
          raise DataChainError(response.message)

-     jobs = response.data.get("jobs", [])
+     jobs = response.data or []
      if not jobs:
          print("No jobs found")
          return
@@ -475,7 +483,7 @@ def list_jobs(status: Optional[str], team_name: Optional[str], limit: int):
      print(tabulate.tabulate(rows, headers="keys", tablefmt="grid"))


- def show_job_logs(job_id: str, team_name: Optional[str]):
+ def show_job_logs(job_id: str, team_name: str | None):
      token = Config().read().get("studio", {}).get("token")
      if not token:
          raise DataChainError(
@@ -486,13 +494,13 @@ def show_job_logs(job_id: str, team_name: Optional[str]):
      return show_logs_from_client(client, job_id)


- def list_clusters(team_name: Optional[str]):
+ def list_clusters(team_name: str | None):
      client = StudioClient(team=team_name)
      response = client.get_clusters()
      if not response.ok:
          raise DataChainError(response.message)

-     clusters = response.data.get("clusters", [])
+     clusters = response.data or []
      if not clusters:
          print("No clusters found")
          return
@@ -505,6 +513,7 @@ def list_clusters(team_name: Optional[str]):
              "Cloud Provider": cluster.get("cloud_provider"),
              "Cloud Credentials": cluster.get("cloud_credentials"),
              "Is Active": cluster.get("is_active"),
+             "Is Default": cluster.get("default"),
              "Max Workers": cluster.get("max_workers"),
          }
          for cluster in clusters
datachain/toolkit/split.py CHANGED
@@ -1,7 +1,7 @@
  import random
- from typing import Optional

  from datachain import C, DataChain
+ from datachain.lib.signal_schema import SignalResolvingError

  RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.

@@ -9,7 +9,7 @@ RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
  def train_test_split(
      dc: DataChain,
      weights: list[float],
-     seed: Optional[int] = None,
+     seed: int | None = None,
  ) -> list[DataChain]:
      """
      Splits a DataChain into multiple subsets based on the provided weights.
@@ -60,7 +60,10 @@
      ```

      Note:
-         The splits are random but deterministic, based on Dataset `sys__rand` field.
+         Splits reuse the same best-effort shuffle used by `DataChain.shuffle`. Results
+         are typically repeatable, but earlier operations such as `merge`, `union`, or
+         custom SQL that reshuffle rows can change the outcome between runs. Add order by
+         stable keys first when you need strict reproducibility.
      """
      if len(weights) < 2:
          raise ValueError("Weights should have at least two elements")
@@ -69,16 +72,34 @@

      weights_normalized = [weight / sum(weights) for weight in weights]

+     try:
+         dc.signals_schema.resolve("sys.rand")
+     except SignalResolvingError:
+         dc = dc.persist()
+
      rand_col = C("sys.rand")
      if seed is not None:
          uniform_seed = random.Random(seed).randrange(1, RESOLUTION)  # noqa: S311
          rand_col = (rand_col % RESOLUTION) * uniform_seed  # type: ignore[assignment]
          rand_col = rand_col % RESOLUTION  # type: ignore[assignment]

-     return [
-         dc.filter(
-             rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
-             rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
-         )
-         for index, _ in enumerate(weights_normalized)
-     ]
+     boundaries: list[int] = [0]
+     cumulative = 0.0
+     for weight in weights_normalized[:-1]:
+         cumulative += weight
+         boundary = round(cumulative * RESOLUTION)
+         boundaries.append(min(boundary, RESOLUTION))
+     boundaries.append(RESOLUTION)
+
+     splits: list[DataChain] = []
+     last_index = len(weights_normalized) - 1
+     for index in range(len(weights_normalized)):
+         lower = boundaries[index]
+         if index == last_index:
+             condition = rand_col >= lower
+         else:
+             upper = boundaries[index + 1]
+             condition = (rand_col >= lower) & (rand_col < upper)
+         splits.append(dc.filter(condition))
+
+     return splits
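For a concrete feel of the new partitioning, here is the boundary computation mirrored standalone for weights [0.7, 0.2, 0.1] (illustration only, no datachain import):

RESOLUTION = 2**31 - 1  # same constant as the module

weights = [0.7, 0.2, 0.1]
normalized = [w / sum(weights) for w in weights]

boundaries = [0]
cumulative = 0.0
for w in normalized[:-1]:
    cumulative += w
    boundaries.append(min(round(cumulative * RESOLUTION), RESOLUTION))
boundaries.append(RESOLUTION)

print(boundaries)  # [0, 1503238553, 1932735282, 2147483647]
# Split 0 keeps rand values in [0, 1503238553), split 1 keeps [1503238553, 1932735282),
# and the final split keeps everything >= 1932735282, so each row lands in exactly
# one subset and none are dropped at the boundaries.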