datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,40 @@
  import copy
+ import hashlib
+ import logging
  import os
  import os.path
  import sys
  import warnings
- from collections.abc import Iterator, Sequence
+ from collections.abc import Callable, Iterator, Sequence
  from typing import (
  IO,
  TYPE_CHECKING,
  Any,
  BinaryIO,
- Callable,
  ClassVar,
  Literal,
- Optional,
  TypeVar,
- Union,
  cast,
  overload,
  )

  import sqlalchemy
- import ujson as json
  from pydantic import BaseModel
  from sqlalchemy.sql.elements import ColumnElement
  from tqdm import tqdm

- from datachain import semver
+ from datachain import json, semver
  from datachain.dataset import DatasetRecord
  from datachain.delta import delta_disabled
- from datachain.error import ProjectCreateNotAllowedError, ProjectNotFoundError
+ from datachain.error import (
+ JobAncestryDepthExceededError,
+ ProjectCreateNotAllowedError,
+ ProjectNotFoundError,
+ )
  from datachain.func import literal
  from datachain.func.base import Function
  from datachain.func.func import Func
+ from datachain.job import Job
  from datachain.lib.convert.python_to_sql import python_to_sql
  from datachain.lib.data_model import (
  DataModel,
@@ -40,11 +43,7 @@ from datachain.lib.data_model import (
  StandardType,
  dict_to_data_model,
  )
- from datachain.lib.file import (
- EXPORT_FILES_MAX_THREADS,
- ArrowRow,
- FileExporter,
- )
+ from datachain.lib.file import EXPORT_FILES_MAX_THREADS, ArrowRow, File, FileExporter
  from datachain.lib.file import ExportPlacement as FileExportPlacement
  from datachain.lib.model_store import ModelStore
  from datachain.lib.settings import Settings
@@ -52,11 +51,17 @@ from datachain.lib.signal_schema import SignalResolvingError, SignalSchema
  from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
  from datachain.lib.udf_signature import UdfSignature
  from datachain.lib.utils import DataChainColumnError, DataChainParamsError
+ from datachain.project import Project
  from datachain.query import Session
- from datachain.query.dataset import DatasetQuery, PartitionByType
+ from datachain.query.dataset import (
+ DatasetQuery,
+ PartitionByType,
+ RegenerateSystemColumns,
+ UnionSchemaMismatchError,
+ )
  from datachain.query.schema import DEFAULT_DELIMITER, Column
  from datachain.sql.functions import path as pathfunc
- from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
+ from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict

  from .database import DEFAULT_DATABASE_BATCH_SIZE
  from .utils import (
@@ -71,6 +76,8 @@ from .utils import (
  resolve_columns,
  )

+ logger = logging.getLogger("datachain")
+
  C = Column

  _T = TypeVar("_T")
@@ -82,19 +89,20 @@ if TYPE_CHECKING:
  import sqlite3

  import pandas as pd
+ from sqlalchemy.orm import Session as OrmSession
  from typing_extensions import ParamSpec, Self

  P = ParamSpec("P")

- ConnectionType = Union[
- str,
- sqlalchemy.engine.URL,
- sqlalchemy.engine.interfaces.Connectable,
- sqlalchemy.engine.Engine,
- sqlalchemy.engine.Connection,
- "sqlalchemy.orm.Session",
- sqlite3.Connection,
- ]
+ ConnectionType = (
+ str
+ | sqlalchemy.engine.URL
+ | sqlalchemy.engine.interfaces.Connectable
+ | sqlalchemy.engine.Engine
+ | sqlalchemy.engine.Connection
+ | OrmSession
+ | sqlite3.Connection
+ )


  T = TypeVar("T", bound="DataChain")
@@ -183,7 +191,7 @@ class DataChain:
  query: DatasetQuery,
  settings: Settings,
  signal_schema: SignalSchema,
- setup: Optional[dict] = None,
+ setup: dict | None = None,
  _sys: bool = False,
  ) -> None:
  """Don't instantiate this directly, use one of the from_XXX constructors."""
@@ -193,10 +201,11 @@ class DataChain:
  self._setup: dict = setup or {}
  self._sys = _sys
  self._delta = False
- self._delta_on: Optional[Union[str, Sequence[str]]] = None
- self._delta_result_on: Optional[Union[str, Sequence[str]]] = None
- self._delta_compare: Optional[Union[str, Sequence[str]]] = None
- self._delta_retry: Optional[Union[bool, str]] = None
+ self._delta_unsafe = False
+ self._delta_on: str | Sequence[str] | None = None
+ self._delta_result_on: str | Sequence[str] | None = None
+ self._delta_compare: str | Sequence[str] | None = None
+ self._delta_retry: bool | str | None = None

  def __repr__(self) -> str:
  """Return a string representation of the chain."""
@@ -210,12 +219,21 @@ class DataChain:
  self.print_schema(file=file)
  return file.getvalue()

+ def hash(self) -> str:
+ """
+ Calculates SHA hash of this chain. Hash calculation is fast and consistent.
+ It takes into account all the steps added to the chain and their inputs.
+ Order of the steps is important.
+ """
+ return self._query.hash()
+
  def _as_delta(
  self,
- on: Optional[Union[str, Sequence[str]]] = None,
- right_on: Optional[Union[str, Sequence[str]]] = None,
- compare: Optional[Union[str, Sequence[str]]] = None,
- delta_retry: Optional[Union[bool, str]] = None,
+ on: str | Sequence[str] | None = None,
+ right_on: str | Sequence[str] | None = None,
+ compare: str | Sequence[str] | None = None,
+ delta_retry: bool | str | None = None,
+ delta_unsafe: bool = False,
  ) -> "Self":
  """Marks this chain as delta, which means special delta process will be
  called on saving dataset for optimization"""
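A minimal sketch of the new `hash()` method added in the hunk above; the input values and the lambda are illustrative, and `read_values` is assumed to be available at the package top level as in current releases:

```py
import datachain as dc

chain = dc.read_values(num=[1, 2, 3]).map(square=lambda num: num * num, output=int)

# Same steps with the same inputs yield the same digest; reordering or editing
# a step changes it.
print(chain.hash())
```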
@@ -226,6 +244,7 @@ class DataChain:
226
244
  self._delta_result_on = right_on
227
245
  self._delta_compare = compare
228
246
  self._delta_retry = delta_retry
247
+ self._delta_unsafe = delta_unsafe
229
248
  return self
230
249
 
231
250
  @property
@@ -238,6 +257,10 @@ class DataChain:
238
257
  """Returns True if this chain is ran in "delta" update mode"""
239
258
  return self._delta
240
259
 
260
+ @property
261
+ def delta_unsafe(self) -> bool:
262
+ return self._delta_unsafe
263
+
241
264
  @property
242
265
  def schema(self) -> dict[str, DataType]:
243
266
  """Get schema of the chain."""
@@ -259,7 +282,7 @@ class DataChain:
259
282
 
260
283
  raise ValueError(f"Column with name {name} not found in the schema")
261
284
 
262
- def c(self, column: Union[str, Column]) -> Column:
285
+ def c(self, column: str | Column) -> Column:
263
286
  """Returns Column instance attached to the current chain."""
264
287
  c = self.column(column) if isinstance(column, str) else self.column(column.name)
265
288
  c.table = self._query.table
@@ -271,17 +294,17 @@ class DataChain:
271
294
  return self._query.session
272
295
 
273
296
  @property
274
- def name(self) -> Optional[str]:
297
+ def name(self) -> str | None:
275
298
  """Name of the underlying dataset, if there is one."""
276
299
  return self._query.name
277
300
 
278
301
  @property
279
- def version(self) -> Optional[str]:
302
+ def version(self) -> str | None:
280
303
  """Version of the underlying dataset, if there is one."""
281
304
  return self._query.version
282
305
 
283
306
  @property
284
- def dataset(self) -> Optional[DatasetRecord]:
307
+ def dataset(self) -> DatasetRecord | None:
285
308
  """Underlying dataset, if there is one."""
286
309
  if not self.name:
287
310
  return None
@@ -295,7 +318,7 @@ class DataChain:
295
318
  """Return `self.union(other)`."""
296
319
  return self.union(other)
297
320
 
298
- def print_schema(self, file: Optional[IO] = None) -> None:
321
+ def print_schema(self, file: IO | None = None) -> None:
299
322
  """Print schema of the chain."""
300
323
  self._effective_signals_schema.print_tree(file=file)
301
324
 
@@ -306,8 +329,8 @@ class DataChain:
306
329
  def _evolve(
307
330
  self,
308
331
  *,
309
- query: Optional[DatasetQuery] = None,
310
- settings: Optional[Settings] = None,
332
+ query: DatasetQuery | None = None,
333
+ settings: Settings | None = None,
311
334
  signal_schema=None,
312
335
  _sys=None,
313
336
  ) -> "Self":
@@ -328,46 +351,51 @@ class DataChain:
  right_on=self._delta_result_on,
  compare=self._delta_compare,
  delta_retry=self._delta_retry,
+ delta_unsafe=self._delta_unsafe,
  )

  return chain

  def settings(
  self,
- cache=None,
- parallel=None,
- workers=None,
- min_task_size=None,
- prefetch: Optional[int] = None,
- sys: Optional[bool] = None,
- namespace: Optional[str] = None,
- project: Optional[str] = None,
- batch_rows: Optional[int] = None,
+ cache: bool | None = None,
+ prefetch: bool | int | None = None,
+ parallel: bool | int | None = None,
+ workers: int | None = None,
+ namespace: str | None = None,
+ project: str | None = None,
+ min_task_size: int | None = None,
+ batch_size: int | None = None,
+ sys: bool | None = None,
  ) -> "Self":
- """Change settings for chain.
-
- This function changes specified settings without changing not specified ones.
- It returns chain, so, it can be chained later with next operation.
+ """
+ Set chain execution parameters. Returns the chain itself, allowing method
+ chaining for subsequent operations. To restore all settings to their default
+ values, use `reset_settings()`.

  Parameters:
- cache : data caching. (default=False)
- parallel : number of thread for processors. True is a special value to
- enable all available CPUs. (default=1)
- workers : number of distributed workers. Only for Studio mode. (default=1)
- min_task_size : minimum number of tasks. (default=1)
- prefetch : number of workers to use for downloading files in advance.
- This is enabled by default and uses 2 workers.
- To disable prefetching, set it to 0.
- namespace : namespace name.
- project : project name.
- batch_rows : row limit per insert to balance speed and memory usage.
- (default=2000)
+ cache: Enable files caching to speed up subsequent accesses to the same
+ files from the same or different chains. Defaults to False.
+ prefetch: Enable prefetching of files. This will download files in
+ advance in parallel. If an integer is provided, it specifies the number
+ of files to prefetch concurrently for each process on each worker.
+ Defaults to 2. Set to 0 or False to disable prefetching.
+ parallel: Number of processes to use for processing user-defined functions
+ (UDFs) in parallel. If an integer is provided, it specifies the number
+ of CPUs to use. If True, all available CPUs are used. Defaults to 1.
+ namespace: Namespace to use for the chain by default.
+ project: Project to use for the chain by default.
+ min_task_size: Minimum number of rows per worker/process for parallel
+ processing by UDFs. Defaults to 1.
+ batch_size: Number of rows per insert by UDF to fine tune and balance speed
+ and memory usage. This might be useful when processing large rows
+ or when running into memory issues. Defaults to 2000.

  Example:
  ```py
  chain = (
  chain
- .settings(cache=True, parallel=8, batch_rows=300)
+ .settings(cache=True, parallel=8, batch_size=300)
  .map(laion=process_webdataset(spec=WDSLaion), params="file")
  )
  ```
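The hunk above renames the `batch_rows` keyword of `settings()` to `batch_size` and tightens the other annotations. A migration sketch (the storage URI is a placeholder):

```py
import datachain as dc

chain = dc.read_storage("s3://my-bucket/images/")  # placeholder URI

# 0.30.x: chain.settings(cache=True, parallel=8, batch_rows=300)
# 0.39.x:
chain = chain.settings(cache=True, parallel=8, batch_size=300)
```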
@@ -377,20 +405,20 @@ class DataChain:
377
405
  settings = copy.copy(self._settings)
378
406
  settings.add(
379
407
  Settings(
380
- cache,
381
- parallel,
382
- workers,
383
- min_task_size,
384
- prefetch,
385
- namespace,
386
- project,
387
- batch_rows,
408
+ cache=cache,
409
+ prefetch=prefetch,
410
+ parallel=parallel,
411
+ workers=workers,
412
+ namespace=namespace,
413
+ project=project,
414
+ min_task_size=min_task_size,
415
+ batch_size=batch_size,
388
416
  )
389
417
  )
390
418
  return self._evolve(settings=settings, _sys=sys)
391
419
 
392
- def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
393
- """Reset all settings to default values."""
420
+ def reset_settings(self, settings: Settings | None = None) -> "Self":
421
+ """Reset all chain settings to default values."""
394
422
  self._settings = settings if settings else Settings()
395
423
  return self
396
424
 
@@ -441,8 +469,8 @@ class DataChain:
441
469
  def explode(
442
470
  self,
443
471
  col: str,
444
- model_name: Optional[str] = None,
445
- column: Optional[str] = None,
472
+ model_name: str | None = None,
473
+ column: str | None = None,
446
474
  schema_sample_size: int = 1,
447
475
  ) -> "DataChain":
448
476
  """Explodes a column containing JSON objects (dict or str DataChain type) into
@@ -483,7 +511,7 @@ class DataChain:
483
511
 
484
512
  model = dict_to_data_model(model_name, output, original_names)
485
513
 
486
- def json_to_model(json_value: Union[str, dict]):
514
+ def json_to_model(json_value: str | dict):
487
515
  json_dict = (
488
516
  json.loads(json_value) if isinstance(json_value, str) else json_value
489
517
  )
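A small sketch of `explode()` based on the signature shown above; the JSON payloads and the generated model/column names are illustrative only:

```py
import datachain as dc

chain = dc.read_values(
    meta=['{"width": 640, "height": 480}', '{"width": 800, "height": 600}']
)

# Parses the JSON strings in `meta` and adds a typed `size` column backed by a
# generated `Size` data model inferred from the first two rows.
exploded = chain.explode("meta", model_name="Size", column="size", schema_sample_size=2)
```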
@@ -557,116 +585,258 @@ class DataChain:
557
585
  create=True,
558
586
  )
559
587
  return self._evolve(
560
- query=self._query.save(project=project, feature_schema=schema)
588
+ query=self._query.save(project=project, feature_schema=schema),
589
+ signal_schema=self.signals_schema | SignalSchema({"sys": Sys}),
561
590
  )
562
591
 
592
+ def _calculate_job_hash(self, job_id: str) -> str:
593
+ """
594
+ Calculates hash of the job at the place of this chain's save method.
595
+ Hash is calculated using previous job checkpoint hash (if exists) and
596
+ adding hash of this chain to produce new hash.
597
+ """
598
+ last_checkpoint = self.session.catalog.metastore.get_last_checkpoint(job_id)
599
+
600
+ return hashlib.sha256(
601
+ (bytes.fromhex(last_checkpoint.hash) if last_checkpoint else b"")
602
+ + bytes.fromhex(self.hash())
603
+ ).hexdigest()
604
+
563
605
  def save( # type: ignore[override]
564
606
  self,
565
607
  name: str,
566
- version: Optional[str] = None,
567
- description: Optional[str] = None,
568
- attrs: Optional[list[str]] = None,
569
- update_version: Optional[str] = "patch",
608
+ version: str | None = None,
609
+ description: str | None = None,
610
+ attrs: list[str] | None = None,
611
+ update_version: str | None = "patch",
570
612
  **kwargs,
571
613
  ) -> "DataChain":
572
614
  """Save to a Dataset. It returns the chain itself.
573
615
 
574
616
  Parameters:
575
- name : dataset name. It can be full name consisting of namespace and
576
- project, but it can also be just a regular dataset name in which
577
- case we are taking namespace and project from settings, if they
578
- are defined there, or default ones instead.
579
- version : version of a dataset. If version is not specified and dataset
617
+ name: dataset name. This can be either a fully qualified name, including
618
+ the namespace and project, or just a regular dataset name. In the latter
619
+ case, the namespace and project will be taken from the settings
620
+ (if specified) or from the default values otherwise.
621
+ version: version of a dataset. If version is not specified and dataset
580
622
  already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
581
- description : description of a dataset.
582
- attrs : attributes of a dataset. They can be without value, e.g "NLP",
623
+ description: description of a dataset.
624
+ attrs: attributes of a dataset. They can be without value, e.g "NLP",
583
625
  or with a value, e.g "location=US".
584
626
  update_version: which part of the dataset version to automatically increase.
585
627
  Available values: `major`, `minor` or `patch`. Default is `patch`.
586
628
  """
629
+
587
630
  catalog = self.session.catalog
588
- if version is not None:
589
- semver.validate(version)
590
631
 
591
- if update_version is not None and update_version not in [
592
- "patch",
593
- "major",
594
- "minor",
595
- ]:
596
- raise ValueError(
597
- "update_version can have one of the following values: major, minor or"
598
- " patch"
599
- )
632
+ result = None # result chain that will be returned at the end
633
+
634
+ # Version validation
635
+ self._validate_version(version)
636
+ self._validate_update_version(update_version)
637
+
638
+ # get existing job if running in SaaS, or creating new one if running locally
639
+ job = self.session.get_or_create_job()
600
640
 
601
641
  namespace_name, project_name, name = catalog.get_full_dataset_name(
602
642
  name,
603
643
  namespace_name=self._settings.namespace,
604
644
  project_name=self._settings.project,
605
645
  )
646
+ project = self._get_or_create_project(namespace_name, project_name)
647
+
648
+ # Checkpoint handling
649
+ _hash, result = self._resolve_checkpoint(name, project, job, kwargs)
650
+ if bool(result):
651
+ # Checkpoint was found and reused
652
+ print(f"Checkpoint found for dataset '{name}', skipping creation")
653
+
654
+ # Schema preparation
655
+ schema = self.signals_schema.clone_without_sys_signals().serialize()
656
+
657
+ # Handle retry and delta functionality
658
+ if not result:
659
+ result = self._handle_delta(name, version, project, schema, kwargs)
660
+
661
+ if not result:
662
+ # calculate chain if we already don't have result from checkpoint or delta
663
+ result = self._evolve(
664
+ query=self._query.save(
665
+ name=name,
666
+ version=version,
667
+ project=project,
668
+ description=description,
669
+ attrs=attrs,
670
+ feature_schema=schema,
671
+ update_version=update_version,
672
+ **kwargs,
673
+ )
674
+ )
606
675
 
676
+ catalog.metastore.create_checkpoint(job.id, _hash) # type: ignore[arg-type]
677
+ return result
678
+
679
+ def _validate_version(self, version: str | None) -> None:
680
+ """Validate dataset version if provided."""
681
+ if version is not None:
682
+ semver.validate(version)
683
+
684
+ def _validate_update_version(self, update_version: str | None) -> None:
685
+ """Ensure update_version is one of: major, minor, patch."""
686
+ allowed = ["major", "minor", "patch"]
687
+ if update_version not in allowed:
688
+ raise ValueError(f"update_version must be one of {allowed}")
689
+
690
+ def _get_or_create_project(self, namespace: str, project_name: str) -> Project:
691
+ """Get project or raise if creation not allowed."""
607
692
  try:
608
- project = self.session.catalog.metastore.get_project(
693
+ return self.session.catalog.metastore.get_project(
609
694
  project_name,
610
- namespace_name,
695
+ namespace,
611
696
  create=is_studio(),
612
697
  )
613
698
  except ProjectNotFoundError as e:
614
- # not being able to create it as creation is not allowed
615
699
  raise ProjectCreateNotAllowedError("Creating project is not allowed") from e
616
700
 
617
- schema = self.signals_schema.clone_without_sys_signals().serialize()
701
+ def _resolve_checkpoint(
702
+ self,
703
+ name: str,
704
+ project: Project,
705
+ job: Job,
706
+ kwargs: dict,
707
+ ) -> tuple[str, "DataChain | None"]:
708
+ """Check if checkpoint exists and return cached dataset if possible."""
709
+ from .datasets import read_dataset
618
710
 
619
- # Handle retry and delta functionality
620
- if self.delta and name:
621
- from datachain.delta import delta_retry_update
711
+ metastore = self.session.catalog.metastore
712
+ checkpoints_reset = env2bool("DATACHAIN_CHECKPOINTS_RESET", undefined=True)
622
713
 
623
- # Delta chains must have delta_on defined (ensured by _as_delta method)
624
- assert self._delta_on is not None, "Delta chain must have delta_on defined"
714
+ _hash = self._calculate_job_hash(job.id)
625
715
 
626
- result_ds, dependencies, has_changes = delta_retry_update(
627
- self,
628
- namespace_name,
629
- project_name,
716
+ if (
717
+ job.parent_job_id
718
+ and not checkpoints_reset
719
+ and metastore.find_checkpoint(job.parent_job_id, _hash)
720
+ ):
721
+ # checkpoint found → find which dataset version to reuse
722
+
723
+ # Find dataset version that was created by any ancestor job
724
+ try:
725
+ dataset_version = metastore.get_dataset_version_for_job_ancestry(
726
+ name,
727
+ project.namespace.name,
728
+ project.name,
729
+ job.id,
730
+ )
731
+ except JobAncestryDepthExceededError:
732
+ raise JobAncestryDepthExceededError(
733
+ "Job continuation chain is too deep. "
734
+ "Please run the job from scratch without continuing from a "
735
+ "parent job."
736
+ ) from None
737
+
738
+ if not dataset_version:
739
+ logger.debug(
740
+ "Checkpoint found but no dataset version for '%s' "
741
+ "in job ancestry (job_id=%s). Creating new version.",
742
+ name,
743
+ job.id,
744
+ )
745
+ # Dataset version not found (e.g deleted by user) - skip
746
+ # checkpoint and recreate
747
+ return _hash, None
748
+
749
+ logger.debug(
750
+ "Reusing dataset version '%s' v%s from job ancestry "
751
+ "(job_id=%s, dataset_version_id=%s)",
630
752
  name,
631
- on=self._delta_on,
632
- right_on=self._delta_result_on,
633
- compare=self._delta_compare,
634
- delta_retry=self._delta_retry,
753
+ dataset_version.version,
754
+ job.id,
755
+ dataset_version.id,
635
756
  )
636
757
 
637
- if result_ds:
638
- return self._evolve(
639
- query=result_ds._query.save(
640
- name=name,
641
- version=version,
642
- project=project,
643
- feature_schema=schema,
644
- dependencies=dependencies,
645
- **kwargs,
646
- )
647
- )
758
+ # Read the specific version from ancestry
759
+ chain = read_dataset(
760
+ name,
761
+ namespace=project.namespace.name,
762
+ project=project.name,
763
+ version=dataset_version.version,
764
+ **kwargs,
765
+ )
648
766
 
649
- if not has_changes:
650
- # sources have not been changed so new version of resulting dataset
651
- # would be the same as previous one. To avoid duplicating exact
652
- # datasets, we won't create new version of it and we will return
653
- # current latest version instead.
654
- from .datasets import read_dataset
767
+ # Link current job to this dataset version (not creator).
768
+ # This also updates dataset_version.job_id.
769
+ metastore.link_dataset_version_to_job(
770
+ dataset_version.id,
771
+ job.id,
772
+ is_creator=False,
773
+ )
655
774
 
656
- return read_dataset(name, **kwargs)
775
+ return _hash, chain
657
776
 
658
- return self._evolve(
659
- query=self._query.save(
660
- name=name,
661
- version=version,
662
- project=project,
663
- description=description,
664
- attrs=attrs,
665
- feature_schema=schema,
666
- update_version=update_version,
777
+ return _hash, None
778
+
779
+ def _handle_delta(
780
+ self,
781
+ name: str,
782
+ version: str | None,
783
+ project: Project,
784
+ schema: dict,
785
+ kwargs: dict,
786
+ ) -> "DataChain | None":
787
+ """Try to save as a delta dataset.
788
+ Returns:
789
+ A DataChain if delta logic could handle it, otherwise None to fall back
790
+ to the regular save path (e.g., on first dataset creation).
791
+ """
792
+ from datachain.delta import delta_retry_update
793
+
794
+ from .datasets import read_dataset
795
+
796
+ if not self.delta or not name:
797
+ return None
798
+
799
+ assert self._delta_on is not None, "Delta chain must have delta_on defined"
800
+
801
+ result_ds, dependencies, has_changes = delta_retry_update(
802
+ self,
803
+ project.namespace.name,
804
+ project.name,
805
+ name,
806
+ on=self._delta_on,
807
+ right_on=self._delta_result_on,
808
+ compare=self._delta_compare,
809
+ delta_retry=self._delta_retry,
810
+ )
811
+
812
+ # Case 1: delta produced a new dataset
813
+ if result_ds:
814
+ return self._evolve(
815
+ query=result_ds._query.save(
816
+ name=name,
817
+ version=version,
818
+ project=project,
819
+ feature_schema=schema,
820
+ dependencies=dependencies,
821
+ **kwargs,
822
+ )
823
+ )
824
+
825
+ # Case 2: no changes → reuse last version
826
+ if not has_changes:
827
+ # sources have not been changed so new version of resulting dataset
828
+ # would be the same as previous one. To avoid duplicating exact
829
+ # datasets, we won't create new version of it and we will return
830
+ # current latest version instead.
831
+ return read_dataset(
832
+ name,
833
+ namespace=project.namespace.name,
834
+ project=project.name,
667
835
  **kwargs,
668
836
  )
669
- )
837
+
838
+ # Case 3: first creation of dataset
839
+ return None
670
840
 
671
841
  def apply(self, func, *args, **kwargs):
672
842
  """Apply any function to the chain.
@@ -693,10 +863,10 @@ class DataChain:
693
863
 
694
864
  def map(
695
865
  self,
696
- func: Optional[Callable] = None,
697
- params: Union[None, str, Sequence[str]] = None,
866
+ func: Callable | None = None,
867
+ params: str | Sequence[str] | None = None,
698
868
  output: OutputType = None,
699
- **signal_map,
869
+ **signal_map: Any,
700
870
  ) -> "Self":
701
871
  """Apply a function to each row to create new signals. The function should
702
872
  return a new object for each row. It returns a chain itself with new signals.
@@ -704,17 +874,17 @@ class DataChain:
704
874
  Input-output relationship: 1:1
705
875
 
706
876
  Parameters:
707
- func : Function applied to each row.
708
- params : List of column names used as input for the function. Default
877
+ func: Function applied to each row.
878
+ params: List of column names used as input for the function. Default
709
879
  is taken from function signature.
710
- output : Dictionary defining new signals and their corresponding types.
880
+ output: Dictionary defining new signals and their corresponding types.
711
881
  Default type is taken from function signature. Default can be also
712
882
  taken from kwargs - **signal_map (see below).
713
883
  If signal name is defined using signal_map (see below) only a single
714
884
  type value can be used.
715
- **signal_map : kwargs can be used to define `func` together with it's return
885
+ **signal_map: kwargs can be used to define `func` together with its return
716
886
  signal name in format of `map(my_sign=my_func)`. This helps define
717
- signal names and function in a nicer way.
887
+ signal names and functions in a nicer way.
718
888
 
719
889
  Example:
720
890
  Using signal_map and single type in output:
@@ -735,18 +905,19 @@ class DataChain:
735
905
  if (prefetch := self._settings.prefetch) is not None:
736
906
  udf_obj.prefetch = prefetch
737
907
 
908
+ sys_schema = SignalSchema({"sys": Sys})
738
909
  return self._evolve(
739
910
  query=self._query.add_signals(
740
- udf_obj.to_udf_wrapper(self._settings.batch_rows),
911
+ udf_obj.to_udf_wrapper(self._settings.batch_size),
741
912
  **self._settings.to_dict(),
742
913
  ),
743
- signal_schema=self.signals_schema | udf_obj.output,
914
+ signal_schema=sys_schema | self.signals_schema | udf_obj.output,
744
915
  )
745
916
 
746
917
  def gen(
747
918
  self,
748
- func: Optional[Union[Callable, Generator]] = None,
749
- params: Union[None, str, Sequence[str]] = None,
919
+ func: Callable | Generator | None = None,
920
+ params: str | Sequence[str] | None = None,
750
921
  output: OutputType = None,
751
922
  **signal_map,
752
923
  ) -> "Self":
@@ -775,19 +946,19 @@ class DataChain:
775
946
  udf_obj.prefetch = prefetch
776
947
  return self._evolve(
777
948
  query=self._query.generate(
778
- udf_obj.to_udf_wrapper(self._settings.batch_rows),
949
+ udf_obj.to_udf_wrapper(self._settings.batch_size),
779
950
  **self._settings.to_dict(),
780
951
  ),
781
- signal_schema=udf_obj.output,
952
+ signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
782
953
  )
783
954
 
784
955
  @delta_disabled
785
956
  def agg(
786
957
  self,
787
958
  /,
788
- func: Optional[Callable] = None,
789
- partition_by: Optional[PartitionByType] = None,
790
- params: Union[None, str, Sequence[str]] = None,
959
+ func: Callable | None = None,
960
+ partition_by: PartitionByType | None = None,
961
+ params: str | Sequence[str] | None = None,
791
962
  output: OutputType = None,
792
963
  **signal_map: Callable,
793
964
  ) -> "Self":
@@ -911,17 +1082,17 @@ class DataChain:
911
1082
  udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
912
1083
  return self._evolve(
913
1084
  query=self._query.generate(
914
- udf_obj.to_udf_wrapper(self._settings.batch_rows),
1085
+ udf_obj.to_udf_wrapper(self._settings.batch_size),
915
1086
  partition_by=processed_partition_by,
916
1087
  **self._settings.to_dict(),
917
1088
  ),
918
- signal_schema=udf_obj.output,
1089
+ signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
919
1090
  )
920
1091
 
921
1092
  def batch_map(
922
1093
  self,
923
- func: Optional[Callable] = None,
924
- params: Union[None, str, Sequence[str]] = None,
1094
+ func: Callable | None = None,
1095
+ params: str | Sequence[str] | None = None,
925
1096
  output: OutputType = None,
926
1097
  batch: int = 1000,
927
1098
  **signal_map,
@@ -933,7 +1104,7 @@ class DataChain:
933
1104
  It accepts the same parameters plus an
934
1105
  additional parameter:
935
1106
 
936
- batch : Size of each batch passed to `func`. Defaults to 1000.
1107
+ batch: Size of each batch passed to `func`. Defaults to 1000.
937
1108
 
938
1109
  Example:
939
1110
  ```py
@@ -960,7 +1131,7 @@ class DataChain:
960
1131
 
961
1132
  return self._evolve(
962
1133
  query=self._query.add_signals(
963
- udf_obj.to_udf_wrapper(self._settings.batch_rows, batch=batch),
1134
+ udf_obj.to_udf_wrapper(self._settings.batch_size, batch=batch),
964
1135
  **self._settings.to_dict(),
965
1136
  ),
966
1137
  signal_schema=self.signals_schema | udf_obj.output,
@@ -969,8 +1140,8 @@ class DataChain:
969
1140
  def _udf_to_obj(
970
1141
  self,
971
1142
  target_class: type[UDFObjT],
972
- func: Optional[Union[Callable, UDFObjT]],
973
- params: Union[None, str, Sequence[str]],
1143
+ func: Callable | UDFObjT | None,
1144
+ params: str | Sequence[str] | None,
974
1145
  output: OutputType,
975
1146
  signal_map: dict[str, Callable],
976
1147
  ) -> UDFObjT:
@@ -981,11 +1152,7 @@ class DataChain:
981
1152
  sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
982
1153
  DataModel.register(list(sign.output_schema.values.values()))
983
1154
 
984
- signals_schema = self.signals_schema
985
- if self._sys:
986
- signals_schema = SignalSchema({"sys": Sys}) | signals_schema
987
-
988
- params_schema = signals_schema.slice(
1155
+ params_schema = self.signals_schema.slice(
989
1156
  sign.params, self._setup, is_batch=is_batch
990
1157
  )
991
1158
 
@@ -1016,7 +1183,8 @@ class DataChain:
1016
1183
  the order of the records in the chain is important.
1017
1184
  Using `order_by` directly before `limit`, `to_list` and similar methods
1018
1185
  will give expected results.
1019
- See https://github.com/iterative/datachain/issues/477 for further details.
1186
+ See https://github.com/datachain-ai/datachain/issues/477
1187
+ for further details.
1020
1188
  """
1021
1189
  if descending:
1022
1190
  args = tuple(sqlalchemy.desc(a) for a in args)
@@ -1040,11 +1208,9 @@ class DataChain:
1040
1208
  )
1041
1209
  )
1042
1210
 
1043
- def select(self, *args: str, _sys: bool = True) -> "Self":
1211
+ def select(self, *args: str) -> "Self":
1044
1212
  """Select only a specified set of signals."""
1045
1213
  new_schema = self.signals_schema.resolve(*args)
1046
- if self._sys and _sys:
1047
- new_schema = SignalSchema({"sys": Sys}) | new_schema
1048
1214
  columns = new_schema.db_signals()
1049
1215
  return self._evolve(
1050
1216
  query=self._query.select(*columns), signal_schema=new_schema
@@ -1062,7 +1228,7 @@ class DataChain:
1062
1228
  def group_by( # noqa: C901, PLR0912
1063
1229
  self,
1064
1230
  *,
1065
- partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None,
1231
+ partition_by: str | Func | Sequence[str | Func] | None = None,
1066
1232
  **kwargs: Func,
1067
1233
  ) -> "Self":
1068
1234
  """Group rows by specified set of signals and return new signals
@@ -1301,9 +1467,9 @@ class DataChain:
1301
1467
  """Yields flattened rows of values as a tuple.
1302
1468
 
1303
1469
  Args:
1304
- row_factory : A callable to convert row to a custom format.
1305
- It should accept two arguments: a list of column names and
1306
- a tuple of row values.
1470
+ row_factory: A callable to convert row to a custom format.
1471
+ It should accept two arguments: a list of column names and
1472
+ a tuple of row values.
1307
1473
  include_hidden: Whether to include hidden signals from the schema.
1308
1474
  """
1309
1475
  db_signals = self._effective_signals_schema.db_signals(
@@ -1368,7 +1534,7 @@ class DataChain:
1368
1534
  """Convert every row to a dictionary."""
1369
1535
 
1370
1536
  def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
1371
- return dict(zip(cols, row))
1537
+ return dict(zip(cols, row, strict=False))
1372
1538
 
1373
1539
  return self.results(row_factory=to_dict)
1374
1540
 
@@ -1426,7 +1592,7 @@ class DataChain:
1426
1592
  @overload
1427
1593
  def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...
1428
1594
 
1429
- def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]: # type: ignore[overload-overlap,misc]
1595
+ def collect(self, *cols: str) -> Iterator[DataValue | tuple[DataValue, ...]]: # type: ignore[overload-overlap,misc]
1430
1596
  """
1431
1597
  Deprecated. Use `to_iter` method instead.
1432
1598
  """
@@ -1491,8 +1657,8 @@ class DataChain:
1491
1657
  def merge(
1492
1658
  self,
1493
1659
  right_ds: "DataChain",
1494
- on: Union[MergeColType, Sequence[MergeColType]],
1495
- right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None,
1660
+ on: MergeColType | Sequence[MergeColType],
1661
+ right_on: MergeColType | Sequence[MergeColType] | None = None,
1496
1662
  inner=False,
1497
1663
  full=False,
1498
1664
  rname="right_",
@@ -1560,8 +1726,8 @@ class DataChain:
1560
1726
 
1561
1727
  def _resolve(
1562
1728
  ds: DataChain,
1563
- col: Union[str, Function, sqlalchemy.ColumnElement],
1564
- side: Union[str, None],
1729
+ col: str | Function | sqlalchemy.ColumnElement,
1730
+ side: str | None,
1565
1731
  ):
1566
1732
  try:
1567
1733
  if isinstance(col, Function):
@@ -1574,7 +1740,7 @@ class DataChain:
1574
1740
  ops = [
1575
1741
  _resolve(self, left, "left")
1576
1742
  == _resolve(right_ds, right, "right" if right_on else None)
1577
- for left, right in zip(on, right_on or on)
1743
+ for left, right in zip(on, right_on or on, strict=False)
1578
1744
  ]
1579
1745
 
1580
1746
  if errors:
@@ -1583,16 +1749,17 @@ class DataChain:
1583
1749
  )
1584
1750
 
1585
1751
  query = self._query.join(
1586
- right_ds._query, sqlalchemy.and_(*ops), inner, full, rname + "{name}"
1752
+ right_ds._query, sqlalchemy.and_(*ops), inner, full, rname
1587
1753
  )
1588
1754
  query.feature_schema = None
1589
1755
  ds = self._evolve(query=query)
1590
1756
 
1757
+ # Note: merge drops sys signals from both sides, make sure to not include it
1758
+ # in the resulting schema
1591
1759
  signals_schema = self.signals_schema.clone_without_sys_signals()
1592
1760
  right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
1593
- ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
1594
- right_signals_schema, rname
1595
- )
1761
+
1762
+ ds.signals_schema = signals_schema.merge(right_signals_schema, rname)
1596
1763
 
1597
1764
  return ds
1598
1765
 
@@ -1603,13 +1770,23 @@ class DataChain:
1603
1770
  Parameters:
1604
1771
  other: chain whose rows will be added to `self`.
1605
1772
  """
1773
+ self_schema = self.signals_schema
1774
+ other_schema = other.signals_schema
1775
+ missing_left, missing_right = self_schema.compare_signals(other_schema)
1776
+ if missing_left or missing_right:
1777
+ raise UnionSchemaMismatchError.from_column_sets(
1778
+ missing_left,
1779
+ missing_right,
1780
+ )
1781
+
1782
+ self.signals_schema = self_schema.clone_without_sys_signals()
1606
1783
  return self._evolve(query=self._query.union(other._query))
1607
1784
 
1608
1785
  def subtract( # type: ignore[override]
1609
1786
  self,
1610
1787
  other: "DataChain",
1611
- on: Optional[Union[str, Sequence[str]]] = None,
1612
- right_on: Optional[Union[str, Sequence[str]]] = None,
1788
+ on: str | Sequence[str] | None = None,
1789
+ right_on: str | Sequence[str] | None = None,
1613
1790
  ) -> "Self":
1614
1791
  """Remove rows that appear in another chain.
1615
1792
 
@@ -1666,6 +1843,7 @@ class DataChain:
1666
1843
  zip(
1667
1844
  self.signals_schema.resolve(*on).db_signals(),
1668
1845
  other.signals_schema.resolve(*right_on).db_signals(),
1846
+ strict=False,
1669
1847
  ) # type: ignore[arg-type]
1670
1848
  )
1671
1849
  return self._evolve(query=self._query.subtract(other._query, signals)) # type: ignore[arg-type]
@@ -1673,15 +1851,15 @@ class DataChain:
1673
1851
  def diff(
1674
1852
  self,
1675
1853
  other: "DataChain",
1676
- on: Union[str, Sequence[str]],
1677
- right_on: Optional[Union[str, Sequence[str]]] = None,
1678
- compare: Optional[Union[str, Sequence[str]]] = None,
1679
- right_compare: Optional[Union[str, Sequence[str]]] = None,
1854
+ on: str | Sequence[str],
1855
+ right_on: str | Sequence[str] | None = None,
1856
+ compare: str | Sequence[str] | None = None,
1857
+ right_compare: str | Sequence[str] | None = None,
1680
1858
  added: bool = True,
1681
1859
  deleted: bool = True,
1682
1860
  modified: bool = True,
1683
1861
  same: bool = False,
1684
- status_col: Optional[str] = None,
1862
+ status_col: str | None = None,
1685
1863
  ) -> "DataChain":
1686
1864
  """Calculate differences between two chains.
1687
1865
 
@@ -1742,12 +1920,12 @@ class DataChain:
1742
1920
  self,
1743
1921
  other: "DataChain",
1744
1922
  on: str = "file",
1745
- right_on: Optional[str] = None,
1923
+ right_on: str | None = None,
1746
1924
  added: bool = True,
1747
1925
  modified: bool = True,
1748
1926
  deleted: bool = False,
1749
1927
  same: bool = False,
1750
- status_col: Optional[str] = None,
1928
+ status_col: str | None = None,
1751
1929
  ) -> "DataChain":
1752
1930
  """Calculate differences between two chains containing files.
1753
1931
 
@@ -1845,12 +2023,15 @@ class DataChain:
1845
2023
  self,
1846
2024
  flatten: bool = False,
1847
2025
  include_hidden: bool = True,
2026
+ as_object: bool = False,
1848
2027
  ) -> "pd.DataFrame":
1849
2028
  """Return a pandas DataFrame from the chain.
1850
2029
 
1851
2030
  Parameters:
1852
2031
  flatten: Whether to use a multiindex or flatten column names.
1853
2032
  include_hidden: Whether to include hidden columns.
2033
+ as_object: Whether to emit a dataframe backed by Python objects
2034
+ rather than pandas-inferred dtypes.
1854
2035
 
1855
2036
  Returns:
1856
2037
  pd.DataFrame: A pandas DataFrame representation of the chain.
@@ -1860,12 +2041,18 @@ class DataChain:
1860
2041
  headers, max_length = self._effective_signals_schema.get_headers_with_length(
1861
2042
  include_hidden=include_hidden
1862
2043
  )
2044
+
2045
+ columns: list[str] | pd.MultiIndex
1863
2046
  if flatten or max_length < 2:
1864
2047
  columns = [".".join(filter(None, header)) for header in headers]
1865
2048
  else:
1866
2049
  columns = pd.MultiIndex.from_tuples(map(tuple, headers))
1867
2050
 
1868
2051
  results = self.results(include_hidden=include_hidden)
2052
+ if as_object:
2053
+ df = pd.DataFrame(results, columns=columns, dtype=object)
2054
+ df.where(pd.notna(df), None, inplace=True)
2055
+ return df
1869
2056
  return pd.DataFrame.from_records(results, columns=columns)
1870
2057
 
1871
2058
  def show(
@@ -1888,7 +2075,11 @@ class DataChain:
1888
2075
  import pandas as pd
1889
2076
 
1890
2077
  dc = self.limit(limit) if limit > 0 else self # type: ignore[misc]
1891
- df = dc.to_pandas(flatten, include_hidden=include_hidden)
2078
+ df = dc.to_pandas(
2079
+ flatten,
2080
+ include_hidden=include_hidden,
2081
+ as_object=True,
2082
+ )
1892
2083
 
1893
2084
  if df.empty:
1894
2085
  print("Empty result")
@@ -1947,20 +2138,20 @@ class DataChain:
1947
2138
  column: str = "",
1948
2139
  model_name: str = "",
1949
2140
  source: bool = True,
1950
- nrows: Optional[int] = None,
1951
- **kwargs,
2141
+ nrows: int | None = None,
2142
+ **kwargs: Any,
1952
2143
  ) -> "Self":
1953
2144
  """Generate chain from list of tabular files.
1954
2145
 
1955
2146
  Parameters:
1956
- output : Dictionary or feature class defining column names and their
2147
+ output: Dictionary or feature class defining column names and their
1957
2148
  corresponding types. List of column names is also accepted, in which
1958
2149
  case types will be inferred.
1959
- column : Generated column name.
1960
- model_name : Generated model name.
1961
- source : Whether to include info about the source file.
1962
- nrows : Optional row limit.
1963
- kwargs : Parameters to pass to pyarrow.dataset.dataset.
2150
+ column: Generated column name.
2151
+ model_name: Generated model name.
2152
+ source: Whether to include info about the source file.
2153
+ nrows: Optional row limit.
2154
+ kwargs: Parameters to pass to pyarrow.dataset.dataset.
1964
2155
 
1965
2156
  Example:
1966
2157
  Reading a json lines file:
@@ -2081,23 +2272,23 @@ class DataChain:
2081
2272
 
2082
2273
  def to_parquet(
2083
2274
  self,
2084
- path: Union[str, os.PathLike[str], BinaryIO],
2085
- partition_cols: Optional[Sequence[str]] = None,
2275
+ path: str | os.PathLike[str] | BinaryIO,
2276
+ partition_cols: Sequence[str] | None = None,
2086
2277
  chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
2087
- fs_kwargs: Optional[dict[str, Any]] = None,
2278
+ fs_kwargs: dict[str, Any] | None = None,
2088
2279
  **kwargs,
2089
2280
  ) -> None:
2090
2281
  """Save chain to parquet file with SignalSchema metadata.
2091
2282
 
2092
2283
  Parameters:
2093
- path : Path or a file-like binary object to save the file. This supports
2284
+ path: Path or a file-like binary object to save the file. This supports
2094
2285
  local paths as well as remote paths, such as s3:// or hf:// with fsspec.
2095
- partition_cols : Column names by which to partition the dataset.
2096
- chunk_size : The chunk size of results to read and convert to columnar
2286
+ partition_cols: Column names by which to partition the dataset.
2287
+ chunk_size: The chunk size of results to read and convert to columnar
2097
2288
  data, to avoid running out of memory.
2098
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
2099
- write, for fsspec-type URLs, such as s3:// or hf:// when
2100
- provided as the destination path.
2289
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2290
+ when writing (e.g., s3://, gs://, hf://), fsspec-specific options
2291
+ are supported.
2101
2292
  """
2102
2293
  import pyarrow as pa
2103
2294
  import pyarrow.parquet as pq
@@ -2141,7 +2332,7 @@ class DataChain:
2141
2332
  # pyarrow infers the best parquet schema from the python types of
2142
2333
  # the input data.
2143
2334
  table = pa.Table.from_pydict(
2144
- dict(zip(column_names, chunk)),
2335
+ dict(zip(column_names, chunk, strict=False)),
2145
2336
  schema=parquet_schema,
2146
2337
  )
2147
2338
 
@@ -2179,137 +2370,116 @@ class DataChain:
2179
2370
 
2180
2371
  def to_csv(
2181
2372
  self,
2182
- path: Union[str, os.PathLike[str]],
2373
+ path: str | os.PathLike[str],
2183
2374
  delimiter: str = ",",
2184
- fs_kwargs: Optional[dict[str, Any]] = None,
2375
+ fs_kwargs: dict[str, Any] | None = None,
2185
2376
  **kwargs,
2186
- ) -> None:
2187
- """Save chain to a csv (comma-separated values) file.
2377
+ ) -> File:
2378
+ """Save chain to a csv (comma-separated values) file and return the stored
2379
+ `File`.
2188
2380
 
2189
2381
  Parameters:
2190
- path : Path to save the file. This supports local paths as well as
2382
+ path: Path to save the file. This supports local paths as well as
2191
2383
  remote paths, such as s3:// or hf:// with fsspec.
2192
- delimiter : Delimiter to use for the resulting file.
2193
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
2194
- write, for fsspec-type URLs, such as s3:// or hf:// when
2195
- provided as the destination path.
2384
+ delimiter: Delimiter to use for the resulting file.
2385
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2386
+ when writing (e.g., s3://, gs://, hf://), fsspec-specific options
2387
+ are supported.
2388
+ Returns:
2389
+ File: The stored file with refreshed metadata (version, etag, size).
2196
2390
  """
2197
2391
  import csv
2198
2392
 
2199
- opener = open
2200
-
2201
- if isinstance(path, str) and "://" in path:
2202
- from datachain.client.fsspec import Client
2203
-
2204
- fs_kwargs = {
2205
- **self._query.catalog.client_config,
2206
- **(fs_kwargs or {}),
2207
- }
2208
-
2209
- client = Client.get_implementation(path)
2210
-
2211
- fsspec_fs = client.create_fs(**fs_kwargs)
2212
-
2213
- opener = fsspec_fs.open
2393
+ target = File.at(path, session=self.session)
2214
2394
 
2215
2395
  headers, _ = self._effective_signals_schema.get_headers_with_length()
2216
2396
  column_names = [".".join(filter(None, header)) for header in headers]
2217
2397
 
2218
- results_iter = self._leaf_values()
2219
-
2220
- with opener(path, "w", newline="") as f:
2398
+ with target.open("w", newline="", client_config=fs_kwargs) as f:
2221
2399
  writer = csv.writer(f, delimiter=delimiter, **kwargs)
2222
2400
  writer.writerow(column_names)
2223
-
2224
- for row in results_iter:
2401
+ for row in self._leaf_values():
2225
2402
  writer.writerow(row)
2226
2403
 
2404
+ return target
2405
+
2227
2406
  def to_json(
2228
2407
  self,
2229
- path: Union[str, os.PathLike[str]],
2230
- fs_kwargs: Optional[dict[str, Any]] = None,
2408
+ path: str | os.PathLike[str],
2409
+ fs_kwargs: dict[str, Any] | None = None,
2231
2410
  include_outer_list: bool = True,
2232
- ) -> None:
2233
- """Save chain to a JSON file.
2411
+ ) -> File:
2412
+ """Save chain to a JSON file and return the stored `File`.
2234
2413
 
2235
2414
  Parameters:
2236
- path : Path to save the file. This supports local paths as well as
2415
+ path: Path to save the file. This supports local paths as well as
2237
2416
  remote paths, such as s3:// or hf:// with fsspec.
2238
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
2239
- write, for fsspec-type URLs, such as s3:// or hf:// when
2240
- provided as the destination path.
2241
- include_outer_list : Sets whether to include an outer list for all rows.
2417
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2418
+ when writing (e.g., s3://, gs://, hf://), fsspec-specific options
2419
+ are supported.
2420
+ include_outer_list: Sets whether to include an outer list for all rows.
2242
2421
  Setting this to True makes the file valid JSON, while False instead
2243
2422
  writes in the JSON lines format.
2423
+ Returns:
2424
+ File: The stored file with refreshed metadata (version, etag, size).
2244
2425
  """
2245
- opener = open
2246
-
2247
- if isinstance(path, str) and "://" in path:
2248
- from datachain.client.fsspec import Client
2249
-
2250
- fs_kwargs = {
2251
- **self._query.catalog.client_config,
2252
- **(fs_kwargs or {}),
2253
- }
2254
-
2255
- client = Client.get_implementation(path)
2256
-
2257
- fsspec_fs = client.create_fs(**fs_kwargs)
2258
-
2259
- opener = fsspec_fs.open
2260
-
2426
+ target = File.at(path, session=self.session)
2261
2427
  headers, _ = self._effective_signals_schema.get_headers_with_length()
2262
- headers = [list(filter(None, header)) for header in headers]
2428
+ headers = [list(filter(None, h)) for h in headers]
2429
+ with target.open("wb", client_config=fs_kwargs) as f:
2430
+ self._write_json_stream(f, headers, include_outer_list)
2431
+ return target
2263
2432
 
2433
+ def _write_json_stream(
2434
+ self,
2435
+ f: IO[bytes],
2436
+ headers: list[list[str]],
2437
+ include_outer_list: bool,
2438
+ ) -> None:
2264
2439
  is_first = True
2265
-
2266
- with opener(path, "wb") as f:
2267
- if include_outer_list:
2268
- # This makes the file JSON instead of JSON lines.
2269
- f.write(b"[\n")
2270
- for row in self._leaf_values():
2271
- if not is_first:
2272
- if include_outer_list:
2273
- # This makes the file JSON instead of JSON lines.
2274
- f.write(b",\n")
2275
- else:
2276
- f.write(b"\n")
2277
- else:
2278
- is_first = False
2279
- f.write(
2280
- json.dumps(
2281
- row_to_nested_dict(headers, row), ensure_ascii=False
2282
- ).encode("utf-8")
2283
- )
2284
- if include_outer_list:
2285
- # This makes the file JSON instead of JSON lines.
2286
- f.write(b"\n]\n")
2440
+ if include_outer_list:
2441
+ f.write(b"[\n")
2442
+ for row in self._leaf_values():
2443
+ if not is_first:
2444
+ f.write(b",\n" if include_outer_list else b"\n")
2445
+ else:
2446
+ is_first = False
2447
+ f.write(
2448
+ json.dumps(
2449
+ row_to_nested_dict(headers, row),
2450
+ ensure_ascii=False,
2451
+ ).encode("utf-8")
2452
+ )
2453
+ if include_outer_list:
2454
+ f.write(b"\n]\n")
2287
2455
 
2288
2456
  def to_jsonl(
2289
2457
  self,
2290
- path: Union[str, os.PathLike[str]],
2291
- fs_kwargs: Optional[dict[str, Any]] = None,
2292
- ) -> None:
2458
+ path: str | os.PathLike[str],
2459
+ fs_kwargs: dict[str, Any] | None = None,
2460
+ ) -> File:
2293
2461
  """Save chain to a JSON lines file.
2294
2462
 
2295
2463
  Parameters:
2296
- path : Path to save the file. This supports local paths as well as
2464
+ path: Path to save the file. This supports local paths as well as
2297
2465
  remote paths, such as s3:// or hf:// with fsspec.
2298
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
2299
- write, for fsspec-type URLs, such as s3:// or hf:// when
2300
- provided as the destination path.
2466
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2467
+ when writing (e.g., s3://, gs://, hf://), fsspec-specific options
2468
+ are supported.
2469
+ Returns:
2470
+ File: The stored file with refreshed metadata (version, etag, size).
2301
2471
  """
2302
- self.to_json(path, fs_kwargs, include_outer_list=False)
2472
+ return self.to_json(path, fs_kwargs, include_outer_list=False)
2303
2473
 
2304
2474
  def to_database(
2305
2475
  self,
2306
2476
  table_name: str,
2307
2477
  connection: "ConnectionType",
2308
2478
  *,
2309
- batch_rows: int = DEFAULT_DATABASE_BATCH_SIZE,
2310
- on_conflict: Optional[str] = None,
2311
- conflict_columns: Optional[list[str]] = None,
2312
- column_mapping: Optional[dict[str, Optional[str]]] = None,
2479
+ batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
2480
+ on_conflict: str | None = None,
2481
+ conflict_columns: list[str] | None = None,
2482
+ column_mapping: dict[str, str | None] | None = None,
2313
2483
  ) -> int:
2314
2484
  """Save chain to a database table using a given database connection.
2315
2485
 
@@ -2328,7 +2498,7 @@ class DataChain:
2328
2498
  library. If a DBAPI2 object, only sqlite3 is supported. The user is
2329
2499
  responsible for engine disposal and connection closure for the
2330
2500
  SQLAlchemy connectable; str connections are closed automatically.
2331
- batch_rows: Number of rows to insert per batch for optimal performance.
2501
+ batch_size: Number of rows to insert per batch for optimal performance.
2332
2502
  Larger batches are faster but use more memory. Default: 10,000.
2333
2503
  on_conflict: Strategy for handling duplicate rows (requires table
2334
2504
  constraints):
@@ -2409,7 +2579,7 @@ class DataChain:
2409
2579
  self,
2410
2580
  table_name,
2411
2581
  connection,
2412
- batch_rows=batch_rows,
2582
+ batch_size=batch_size,
2413
2583
  on_conflict=on_conflict,
2414
2584
  conflict_columns=conflict_columns,
2415
2585
  column_mapping=column_mapping,
@@ -2545,13 +2715,13 @@ class DataChain:
2545
2715
 
2546
2716
  def to_storage(
2547
2717
  self,
2548
- output: Union[str, os.PathLike[str]],
2718
+ output: str | os.PathLike[str],
2549
2719
  signal: str = "file",
2550
2720
  placement: FileExportPlacement = "fullpath",
2551
2721
  link_type: Literal["copy", "symlink"] = "copy",
2552
- num_threads: Optional[int] = EXPORT_FILES_MAX_THREADS,
2553
- anon: Optional[bool] = None,
2554
- client_config: Optional[dict] = None,
2722
+ num_threads: int | None = EXPORT_FILES_MAX_THREADS,
2723
+ anon: bool | None = None,
2724
+ client_config: dict | None = None,
2555
2725
  ) -> None:
2556
2726
  """Export files from a specified signal to a directory. Files can be
2557
2727
  exported to a local or cloud directory.
@@ -2560,12 +2730,24 @@ class DataChain:
2560
2730
  output: Path to the target directory for exporting files.
2561
2731
  signal: Name of the signal to export files from.
2562
2732
  placement: The method to use for naming exported files.
2563
- The possible values are: "filename", "etag", "fullpath", and "checksum".
2733
+ The possible values are: "filename", "etag", "fullpath",
2734
+ "filepath", and "checksum".
2735
+ Example path translations for an object located at
2736
+ ``s3://bucket/data/img.jpg`` and exported to ``./out``:
2737
+
2738
+ - "filename" -> ``./out/img.jpg`` (no directories)
2739
+ - "filepath" -> ``./out/data/img.jpg`` (relative path kept)
2740
+ - "fullpath" -> ``./out/bucket/data/img.jpg`` (remote host kept)
2741
+ - "etag" -> ``./out/<etag>.jpg`` (unique name via object digest)
2742
+
2743
+ Local sources behave like "filepath" for "fullpath" placement.
2744
+ Relative destinations such as "." or ".." and absolute paths
2745
+ are supported for every strategy.
2564
2746
  link_type: Method to use for exporting files.
2565
2747
  Falls back to `'copy'` if symlinking fails.
2566
- num_threads : number of threads to use for exporting files.
2567
- By default it uses 5 threads.
2568
- anon: If True, we will treat cloud bucket as public one. Default behavior
2748
+ num_threads: number of threads to use for exporting files.
2749
+ By default, it uses 5 threads.
2750
+ anon: If True, we will treat cloud bucket as a public one. Default behavior
2569
2751
  depends on the previous session configuration (e.g. happens in the
2570
2752
  initial `read_storage`) and particular cloud storage client
2571
2753
  implementation (e.g. S3 fallbacks to anonymous access if no credentials
@@ -2614,8 +2796,20 @@ class DataChain:
  )

  def shuffle(self) -> "Self":
- """Shuffle the rows of the chain deterministically."""
- return self.order_by("sys.rand")
+ """Shuffle rows with a best-effort deterministic ordering.
+
+ This produces repeatable shuffles. Merge and union operations can
+ lead to non-deterministic results. Use order by or save a dataset
+ afterward to guarantee the same result.
+ """
+ query = self._query.clone(new_table=False)
+ query.steps.append(RegenerateSystemColumns(self._query.catalog))
+
+ chain = self._evolve(
+ query=query,
+ signal_schema=SignalSchema({"sys": Sys}) | self.signals_schema,
+ )
+ return chain.order_by("sys.rand")

  def sample(self, n: int) -> "Self":
  """Return a random sample from the chain.