datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/studio.py CHANGED
@@ -1,10 +1,16 @@
  import asyncio
  import os
  import sys
- from typing import TYPE_CHECKING, Optional
+ import warnings
+ from datetime import datetime, timezone
+ from typing import TYPE_CHECKING
+
+ import dateparser
+ import tabulate

  from datachain.config import Config, ConfigLevel
- from datachain.dataset import QUERY_DATASET_PREFIX
+ from datachain.data_storage.job import JobStatus
+ from datachain.dataset import QUERY_DATASET_PREFIX, parse_dataset_name
  from datachain.error import DataChainError
  from datachain.remote.studio import StudioClient
  from datachain.utils import STUDIO_URL
@@ -16,6 +22,8 @@ POST_LOGIN_MESSAGE = (
      "Once you've logged in, return here "
      "and you'll be ready to start using DataChain with Studio."
  )
+ RETRY_MAX_TIMES = 10
+ RETRY_SLEEP_SEC = 1


  def process_jobs_args(args: "Namespace"):
@@ -35,14 +43,28 @@ def process_jobs_args(args: "Namespace"):
              args.workers,
              args.files,
              args.python_version,
+             args.repository,
              args.req,
              args.req_file,
+             args.priority,
+             args.cluster,
+             args.start_time,
+             args.cron,
+             args.no_wait,
+             args.credentials_name,
          )

      if args.cmd == "cancel":
          return cancel_job(args.id, args.team)
      if args.cmd == "logs":
          return show_job_logs(args.id, args.team)
+
+     if args.cmd == "ls":
+         return list_jobs(args.status, args.team, args.limit)
+
+     if args.cmd == "clusters":
+         return list_clusters(args.team)
+
      raise DataChainError(f"Unknown command '{args.cmd}'.")


@@ -60,14 +82,24 @@ def process_auth_cli_args(args: "Namespace"):
          return logout(args.local)
      if args.cmd == "token":
          return token()
-
      if args.cmd == "team":
          return set_team(args)
      raise DataChainError(f"Unknown command '{args.cmd}'.")


  def set_team(args: "Namespace"):
-     level = ConfigLevel.GLOBAL if args.__dict__.get("global") else ConfigLevel.LOCAL
+     if args.team_name is None:
+         config = Config().read().get("studio", {})
+         team = config.get("team")
+         if team:
+             print(f"Default team is '{team}'")
+             return 0
+
+         raise DataChainError(
+             "No default team set. Use `datachain auth team <team_name>` to set one."
+         )
+
+     level = ConfigLevel.LOCAL if args.local else ConfigLevel.GLOBAL
      config = Config(level)
      with config.edit() as conf:
          studio_conf = conf.get("studio", {})
@@ -80,11 +112,13 @@ def set_team(args: "Namespace"):
  def login(args: "Namespace"):
      from dvc_studio_client.auth import StudioAuthError, get_access_token

+     from datachain.remote.studio import get_studio_env_variable
+
      config = Config().read().get("studio", {})
      name = args.name
      hostname = (
          args.hostname
-         or os.environ.get("DVC_STUDIO_URL")
+         or get_studio_env_variable("URL")
          or config.get("url")
          or STUDIO_URL
      )
@@ -113,6 +147,7 @@ def login(args: "Namespace"):
      level = ConfigLevel.LOCAL if args.local else ConfigLevel.GLOBAL
      config_path = save_config(hostname, access_token, level=level)
      print(f"Authentication complete. Saved token to {config_path}.")
+     print("You can now use 'datachain auth team' to set the default team.")
      return 0


@@ -141,7 +176,12 @@ def token():
      print(token)


- def list_datasets(team: Optional[str] = None, name: Optional[str] = None):
+ def list_datasets(team: str | None = None, name: str | None = None):
+     def ds_full_name(ds: dict) -> str:
+         return (
+             f"{ds['project']['namespace']['name']}.{ds['project']['name']}.{ds['name']}"
+         )
+
      if name:
          yield from list_dataset_versions(team, name)
          return
@@ -158,18 +198,22 @@ def list_datasets(team: Optional[str] = None, name: Optional[str] = None):

      for d in response.data:
          name = d.get("name")
+         full_name = ds_full_name(d)
          if name and name.startswith(QUERY_DATASET_PREFIX):
              continue

          for v in d.get("versions", []):
              version = v.get("version")
-             yield (name, version)
+             yield (full_name, version)


- def list_dataset_versions(team: Optional[str] = None, name: str = ""):
+ def list_dataset_versions(team: str | None = None, name: str = ""):
      client = StudioClient(team=team)

-     response = client.dataset_info(name)
+     namespace_name, project_name, name = parse_dataset_name(name)
+     if not namespace_name or not project_name:
+         raise DataChainError(f"Missing namespace or project form dataset name {name}")
+     response = client.dataset_info(namespace_name, project_name, name)

      if not response.ok:
          raise DataChainError(response.message)
@@ -183,14 +227,18 @@


  def edit_studio_dataset(
-     team_name: Optional[str],
+     team_name: str | None,
      name: str,
-     new_name: Optional[str] = None,
-     description: Optional[str] = None,
-     labels: Optional[list[str]] = None,
+     namespace: str,
+     project: str,
+     new_name: str | None = None,
+     description: str | None = None,
+     attrs: list[str] | None = None,
  ):
      client = StudioClient(team=team_name)
-     response = client.edit_dataset(name, new_name, description, labels)
+     response = client.edit_dataset(
+         name, namespace, project, new_name, description, attrs
+     )
      if not response.ok:
          raise DataChainError(response.message)

@@ -198,13 +246,15 @@


  def remove_studio_dataset(
-     team_name: Optional[str],
+     team_name: str | None,
      name: str,
-     version: Optional[int] = None,
-     force: Optional[bool] = False,
+     namespace: str,
+     project: str,
+     version: str | None = None,
+     force: bool | None = False,
  ):
      client = StudioClient(team=team_name)
-     response = client.rm_dataset(name, version, force)
+     response = client.rm_dataset(name, namespace, project, version, force)
      if not response.ok:
          raise DataChainError(response.message)

@@ -222,42 +272,102 @@ def save_config(hostname, token, level=ConfigLevel.GLOBAL):
      return config.config_file()


+ def parse_start_time(start_time_str: str | None) -> str | None:
+     if not start_time_str:
+         return None
+
+     # dateparser#1246: it explores strptime patterns lacking a year, which
+     # triggers a CPython 3.13 DeprecationWarning. Suppress that noise until a
+     # new dateparser release includes the upstream fix.
+     # https://github.com/scrapinghub/dateparser/issues/1246
+     with warnings.catch_warnings():
+         warnings.filterwarnings(
+             "ignore",
+             category=DeprecationWarning,
+             module="dateparser\\.utils\\.strptime",
+         )
+         parsed_datetime = dateparser.parse(start_time_str)
+
+     if parsed_datetime is None:
+         raise DataChainError(
+             f"Could not parse datetime string: '{start_time_str}'. "
+             f"Supported formats include: '2024-01-15 14:30:00', 'tomorrow 3pm', "
+             f"'monday 9am', '2024-01-15T14:30:00Z', 'in 2 hours', etc."
+         )
+
+     # Convert to ISO format string
+     return parsed_datetime.isoformat()
+
+
  def show_logs_from_client(client, job_id):
      # Sync usage
      async def _run():
-         async for message in client.tail_job_logs(job_id):
-             if "logs" in message:
-                 for log in message["logs"]:
-                     print(log["message"], end="")
-             elif "job" in message:
-                 print(f"\n>>>> Job is now in {message['job']['status']} status.")
-
-     asyncio.run(_run())
+         retry_count = 0
+         latest_status = None
+         processed_statuses = set()
+         while True:
+             async for message in client.tail_job_logs(job_id):
+                 if "logs" in message:
+                     for log in message["logs"]:
+                         print(log["message"], end="")
+                 elif "job" in message:
+                     latest_status = message["job"]["status"]
+                     if latest_status in processed_statuses:
+                         continue
+                     processed_statuses.add(latest_status)
+                     print(f"\n>>>> Job is now in {latest_status} status.")
+
+             try:
+                 if retry_count > RETRY_MAX_TIMES or (
+                     latest_status and JobStatus[latest_status].finished()
+                 ):
+                     break
+                 await asyncio.sleep(RETRY_SLEEP_SEC)
+                 retry_count += 1
+             except KeyError:
+                 pass
+
+         return latest_status
+
+     final_status = asyncio.run(_run())

      response = client.dataset_job_versions(job_id)
      if not response.ok:
          raise DataChainError(response.message)

      response_data = response.data
-     if response_data:
+     if response_data and response_data.get("dataset_versions"):
          dataset_versions = response_data.get("dataset_versions", [])
          print("\n\n>>>> Dataset versions created during the job:")
          for version in dataset_versions:
              print(f" - {version.get('dataset_name')}@v{version.get('version')}")
      else:
-         print("No dataset versions created during the job.")
+         print("\n\nNo dataset versions created during the job.")
+
+     exit_code_by_status = {
+         "FAILED": 1,
+         "CANCELED": 2,
+     }
+     return exit_code_by_status.get(final_status.upper(), 0) if final_status else 0


  def create_job(
      query_file: str,
-     team_name: Optional[str],
-     env_file: Optional[str] = None,
-     env: Optional[list[str]] = None,
-     workers: Optional[int] = None,
-     files: Optional[list[str]] = None,
-     python_version: Optional[str] = None,
-     req: Optional[list[str]] = None,
-     req_file: Optional[str] = None,
+     team_name: str | None,
+     env_file: str | None = None,
+     env: list[str] | None = None,
+     workers: int | None = None,
+     files: list[str] | None = None,
+     python_version: str | None = None,
+     repository: str | None = None,
+     req: list[str] | None = None,
+     req_file: str | None = None,
+     priority: int | None = None,
+     cluster: str | None = None,
+     start_time: str | None = None,
+     cron: str | None = None,
+     no_wait: bool | None = False,
+     credentials_name: str | None = None,
  ):
      query_type = "PYTHON" if query_file.endswith(".py") else "SHELL"
      with open(query_file) as f:
@@ -276,6 +386,11 @@
      client = StudioClient(team=team_name)
      file_ids = upload_files(client, files) if files else []

+     # Parse start_time if provided
+     parsed_start_time = parse_start_time(start_time)
+     if cron and parsed_start_time is None:
+         parsed_start_time = datetime.now(timezone.utc).isoformat()
+
      response = client.create_job(
          query=query,
          query_type=query_type,
@@ -284,7 +399,13 @@
          query_name=os.path.basename(query_file),
          files=file_ids,
          python_version=python_version,
+         repository=repository,
          requirements=requirements,
+         priority=priority,
+         cluster=cluster,
+         start_time=parsed_start_time,
+         cron=cron,
+         credentials_name=credentials_name,
      )
      if not response.ok:
          raise DataChainError(response.message)
@@ -292,12 +413,17 @@
      if not response.data:
          raise DataChainError("Failed to create job")

-     job_id = response.data.get("job", {}).get("id")
+     job_id = response.data.get("id")
+
+     if parsed_start_time or cron:
+         print(f"Job {job_id} is scheduled as a task in Studio.")
+         return 0
+
      print(f"Job {job_id} created")
-     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+     print("Open the job in Studio at", response.data.get("url"))
      print("=" * 40)

-     show_logs_from_client(client, job_id)
+     return 0 if no_wait else show_logs_from_client(client, job_id)


  def upload_files(client: StudioClient, files: list[str]) -> list[str]:
@@ -305,21 +431,19 @@ def upload_files(client: StudioClient, files: list[str]) -> list[str]:
      for file in files:
          file_name = os.path.basename(file)
          with open(file, "rb") as f:
-             file_content = f.read()
-             response = client.upload_file(file_content, file_name)
+             response = client.upload_file(f, file_name)
          if not response.ok:
              raise DataChainError(response.message)

          if not response.data:
              raise DataChainError(f"Failed to upload file {file_name}")

-         file_id = response.data.get("blob", {}).get("id")
-         if file_id:
+         if file_id := response.data.get("id"):
              file_ids.append(str(file_id))
      return file_ids


- def cancel_job(job_id: str, team_name: Optional[str]):
+ def cancel_job(job_id: str, team_name: str | None):
      token = Config().read().get("studio", {}).get("token")
      if not token:
          raise DataChainError(
@@ -334,7 +458,32 @@
      print(f"Job {job_id} canceled")


- def show_job_logs(job_id: str, team_name: Optional[str]):
+ def list_jobs(status: str | None, team_name: str | None, limit: int):
+     client = StudioClient(team=team_name)
+     response = client.get_jobs(status, limit)
+     if not response.ok:
+         raise DataChainError(response.message)
+
+     jobs = response.data or []
+     if not jobs:
+         print("No jobs found")
+         return
+
+     rows = [
+         {
+             "ID": job.get("id"),
+             "Name": job.get("name"),
+             "Status": job.get("status"),
+             "Created at": job.get("created_at"),
+             "Created by": job.get("created_by"),
+         }
+         for job in jobs
+     ]
+
+     print(tabulate.tabulate(rows, headers="keys", tablefmt="grid"))
+
+
+ def show_job_logs(job_id: str, team_name: str | None):
      token = Config().read().get("studio", {}).get("token")
      if not token:
          raise DataChainError(
@@ -342,4 +491,32 @@
          )

      client = StudioClient(team=team_name)
-     show_logs_from_client(client, job_id)
+     return show_logs_from_client(client, job_id)
+
+
+ def list_clusters(team_name: str | None):
+     client = StudioClient(team=team_name)
+     response = client.get_clusters()
+     if not response.ok:
+         raise DataChainError(response.message)
+
+     clusters = response.data or []
+     if not clusters:
+         print("No clusters found")
+         return
+
+     rows = [
+         {
+             "ID": cluster.get("id"),
+             "Name": cluster.get("name"),
+             "Status": cluster.get("status"),
+             "Cloud Provider": cluster.get("cloud_provider"),
+             "Cloud Credentials": cluster.get("cloud_credentials"),
+             "Is Active": cluster.get("is_active"),
+             "Is Default": cluster.get("default"),
+             "Max Workers": cluster.get("max_workers"),
+         }
+         for cluster in clusters
+     ]
+
+     print(tabulate.tabulate(rows, headers="keys", tablefmt="grid"))
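
Note on the new scheduling path in `create_job` above: the `start_time` value is parsed with `dateparser`, and when a `cron` expression is given without a start time the job defaults to the current UTC time; jobs submitted with either option are registered as scheduled tasks in Studio and the command returns without tailing logs. The snippet below is a minimal standalone sketch of that parsing step, not part of the package; the helper name `to_iso_start_time` and the example cron string are illustrative only.

```python
# Standalone sketch of the start-time handling shown in the diff above
# (assumes `dateparser` is installed; helper name and cron string are illustrative).
from datetime import datetime, timezone

import dateparser


def to_iso_start_time(start_time_str: str | None) -> str | None:
    """Parse a human-readable time ('tomorrow 3pm', '2024-01-15 14:30') into ISO 8601."""
    if not start_time_str:
        return None
    parsed = dateparser.parse(start_time_str)
    if parsed is None:
        raise ValueError(f"Could not parse datetime string: {start_time_str!r}")
    return parsed.isoformat()


start = to_iso_start_time(None)
cron = "0 9 * * 1"  # illustrative schedule: every Monday at 09:00
if cron and start is None:
    # Mirrors create_job(): a cron schedule with no explicit start begins now (UTC).
    start = datetime.now(timezone.utc).isoformat()
print(start)
```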
datachain/toolkit/split.py CHANGED
@@ -1,7 +1,7 @@
  import random
- from typing import Optional

  from datachain import C, DataChain
+ from datachain.lib.signal_schema import SignalResolvingError

  RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.

@@ -9,7 +9,7 @@ RESOLUTION = 2**31 - 1  # Maximum positive value for a 32-bit signed integer.
  def train_test_split(
      dc: DataChain,
      weights: list[float],
-     seed: Optional[int] = None,
+     seed: int | None = None,
  ) -> list[DataChain]:
      """
      Splits a DataChain into multiple subsets based on the provided weights.
@@ -60,7 +60,10 @@
          ```

      Note:
-         The splits are random but deterministic, based on Dataset `sys__rand` field.
+         Splits reuse the same best-effort shuffle used by `DataChain.shuffle`. Results
+         are typically repeatable, but earlier operations such as `merge`, `union`, or
+         custom SQL that reshuffle rows can change the outcome between runs. Add order by
+         stable keys first when you need strict reproducibility.
      """
      if len(weights) < 2:
          raise ValueError("Weights should have at least two elements")
@@ -69,16 +72,34 @@

      weights_normalized = [weight / sum(weights) for weight in weights]

+     try:
+         dc.signals_schema.resolve("sys.rand")
+     except SignalResolvingError:
+         dc = dc.persist()
+
      rand_col = C("sys.rand")
      if seed is not None:
          uniform_seed = random.Random(seed).randrange(1, RESOLUTION)  # noqa: S311
          rand_col = (rand_col % RESOLUTION) * uniform_seed  # type: ignore[assignment]
          rand_col = rand_col % RESOLUTION  # type: ignore[assignment]

-     return [
-         dc.filter(
-             rand_col >= round(sum(weights_normalized[:index]) * (RESOLUTION - 1)),
-             rand_col < round(sum(weights_normalized[: index + 1]) * (RESOLUTION - 1)),
-         )
-         for index, _ in enumerate(weights_normalized)
-     ]
+     boundaries: list[int] = [0]
+     cumulative = 0.0
+     for weight in weights_normalized[:-1]:
+         cumulative += weight
+         boundary = round(cumulative * RESOLUTION)
+         boundaries.append(min(boundary, RESOLUTION))
+     boundaries.append(RESOLUTION)
+
+     splits: list[DataChain] = []
+     last_index = len(weights_normalized) - 1
+     for index in range(len(weights_normalized)):
+         lower = boundaries[index]
+         if index == last_index:
+             condition = rand_col >= lower
+         else:
+             upper = boundaries[index + 1]
+             condition = (rand_col >= lower) & (rand_col < upper)
+         splits.append(dc.filter(condition))
+
+     return splits
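
To make the new bucketing in `train_test_split` concrete, here is a standalone sketch of the boundary arithmetic (plain Python, no DataChain dependency): normalized weights become cumulative boundaries scaled to the 32-bit `RESOLUTION`, each split takes a half-open `[lower, upper)` range of the row's `sys.rand`-derived value, and the final split is left unbounded above so rounding cannot drop rows.

```python
# Standalone illustration of the boundary math used by train_test_split above.
# It only mirrors the arithmetic; the real function filters a DataChain.
RESOLUTION = 2**31 - 1  # maximum positive 32-bit signed integer


def split_boundaries(weights: list[float]) -> list[int]:
    normalized = [w / sum(weights) for w in weights]
    boundaries = [0]
    cumulative = 0.0
    for weight in normalized[:-1]:
        cumulative += weight
        boundaries.append(min(round(cumulative * RESOLUTION), RESOLUTION))
    boundaries.append(RESOLUTION)
    return boundaries


# weights [0.7, 0.2, 0.1] produce buckets roughly:
#   split 0: value in [0, 0.7 * RESOLUTION)
#   split 1: value in [0.7 * RESOLUTION, 0.9 * RESOLUTION)
#   split 2: value >= 0.9 * RESOLUTION  (last bucket has no upper bound)
print(split_boundaries([0.7, 0.2, 0.1]))
```

As the updated docstring note says, ordering the chain by stable keys before splitting is what guarantees strictly reproducible subsets across runs.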