datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
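The largest single change in this release is `datachain/lib/dc/datachain.py` (+1289 / -425), whose diff follows. For orientation, here is a minimal, hypothetical sketch of the user-facing API after the upgrade, based only on the signatures visible in the diff below; the storage URI and column names are made up, and `read_storage` is assumed as the chain constructor (it is not part of the diff shown):

```py
import datachain as dc

chain = (
    dc.read_storage("s3://example-bucket/images/")      # hypothetical source
    .settings(cache=True, parallel=8, batch_size=300)   # batch_size is new in settings()
    .map(size_kb=lambda file: file.size / 1024, output=float)
)

# save() now requires a name, uses semver dataset versions ("1.2.3"),
# and takes attrs= instead of labels=
chain.save("my_dataset", attrs=["demo"], update_version="patch")

# collect() is deprecated in favor of to_iter(), which always yields tuples
for (path, size_kb) in chain.to_iter("file.path", "size_kb"):
    print(path, size_kb)

# an unnamed save() is replaced by persist() for temporary datasets
tmp = chain.persist()
```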
@@ -1,51 +1,69 @@
1
1
  import copy
2
+ import hashlib
3
+ import logging
2
4
  import os
3
5
  import os.path
4
6
  import sys
5
7
  import warnings
6
- from collections.abc import Iterator, Sequence
8
+ from collections.abc import Callable, Iterator, Sequence
7
9
  from typing import (
8
10
  IO,
9
11
  TYPE_CHECKING,
10
12
  Any,
11
13
  BinaryIO,
12
- Callable,
13
14
  ClassVar,
14
15
  Literal,
15
- Optional,
16
16
  TypeVar,
17
- Union,
17
+ cast,
18
18
  overload,
19
19
  )
20
20
 
21
- import orjson
22
21
  import sqlalchemy
23
22
  from pydantic import BaseModel
23
+ from sqlalchemy.sql.elements import ColumnElement
24
24
  from tqdm import tqdm
25
25
 
26
+ from datachain import json, semver
26
27
  from datachain.dataset import DatasetRecord
28
+ from datachain.delta import delta_disabled
29
+ from datachain.error import (
30
+ JobAncestryDepthExceededError,
31
+ ProjectCreateNotAllowedError,
32
+ ProjectNotFoundError,
33
+ )
27
34
  from datachain.func import literal
28
35
  from datachain.func.base import Function
29
36
  from datachain.func.func import Func
37
+ from datachain.job import Job
30
38
  from datachain.lib.convert.python_to_sql import python_to_sql
31
- from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_data_model
32
- from datachain.lib.file import (
33
- EXPORT_FILES_MAX_THREADS,
34
- ArrowRow,
35
- FileExporter,
39
+ from datachain.lib.data_model import (
40
+ DataModel,
41
+ DataType,
42
+ DataValue,
43
+ StandardType,
44
+ dict_to_data_model,
36
45
  )
46
+ from datachain.lib.file import EXPORT_FILES_MAX_THREADS, ArrowRow, File, FileExporter
37
47
  from datachain.lib.file import ExportPlacement as FileExportPlacement
48
+ from datachain.lib.model_store import ModelStore
38
49
  from datachain.lib.settings import Settings
39
- from datachain.lib.signal_schema import SignalSchema
50
+ from datachain.lib.signal_schema import SignalResolvingError, SignalSchema
40
51
  from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
41
52
  from datachain.lib.udf_signature import UdfSignature
42
53
  from datachain.lib.utils import DataChainColumnError, DataChainParamsError
54
+ from datachain.project import Project
43
55
  from datachain.query import Session
44
- from datachain.query.dataset import DatasetQuery, PartitionByType
45
- from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
56
+ from datachain.query.dataset import (
57
+ DatasetQuery,
58
+ PartitionByType,
59
+ RegenerateSystemColumns,
60
+ UnionSchemaMismatchError,
61
+ )
62
+ from datachain.query.schema import DEFAULT_DELIMITER, Column
46
63
  from datachain.sql.functions import path as pathfunc
47
- from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
64
+ from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict
48
65
 
66
+ from .database import DEFAULT_DATABASE_BATCH_SIZE
49
67
  from .utils import (
50
68
  DatasetMergeError,
51
69
  DatasetPrepareError,
@@ -54,9 +72,12 @@ from .utils import (
54
72
  Sys,
55
73
  _get_merge_error_str,
56
74
  _validate_merge_on,
75
+ is_studio,
57
76
  resolve_columns,
58
77
  )
59
78
 
79
+ logger = logging.getLogger("datachain")
80
+
60
81
  C = Column
61
82
 
62
83
  _T = TypeVar("_T")
@@ -65,11 +86,27 @@ UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
65
86
  DEFAULT_PARQUET_CHUNK_SIZE = 100_000
66
87
 
67
88
  if TYPE_CHECKING:
89
+ import sqlite3
90
+
68
91
  import pandas as pd
92
+ from sqlalchemy.orm import Session as OrmSession
69
93
  from typing_extensions import ParamSpec, Self
70
94
 
71
95
  P = ParamSpec("P")
72
96
 
97
+ ConnectionType = (
98
+ str
99
+ | sqlalchemy.engine.URL
100
+ | sqlalchemy.engine.interfaces.Connectable
101
+ | sqlalchemy.engine.Engine
102
+ | sqlalchemy.engine.Connection
103
+ | OrmSession
104
+ | sqlite3.Connection
105
+ )
106
+
107
+
108
+ T = TypeVar("T", bound="DataChain")
109
+
73
110
 
74
111
  class DataChain:
75
112
  """DataChain - a data structure for batch data processing and evaluation.
@@ -133,7 +170,7 @@ class DataChain:
133
170
  .choices[0]
134
171
  .message.content,
135
172
  )
136
- .save()
173
+ .persist()
137
174
  )
138
175
 
139
176
  try:
@@ -154,7 +191,7 @@ class DataChain:
154
191
  query: DatasetQuery,
155
192
  settings: Settings,
156
193
  signal_schema: SignalSchema,
157
- setup: Optional[dict] = None,
194
+ setup: dict | None = None,
158
195
  _sys: bool = False,
159
196
  ) -> None:
160
197
  """Don't instantiate this directly, use one of the from_XXX constructors."""
@@ -163,6 +200,12 @@ class DataChain:
163
200
  self.signals_schema = signal_schema
164
201
  self._setup: dict = setup or {}
165
202
  self._sys = _sys
203
+ self._delta = False
204
+ self._delta_unsafe = False
205
+ self._delta_on: str | Sequence[str] | None = None
206
+ self._delta_result_on: str | Sequence[str] | None = None
207
+ self._delta_compare: str | Sequence[str] | None = None
208
+ self._delta_retry: bool | str | None = None
166
209
 
167
210
  def __repr__(self) -> str:
168
211
  """Return a string representation of the chain."""
@@ -176,6 +219,48 @@ class DataChain:
176
219
  self.print_schema(file=file)
177
220
  return file.getvalue()
178
221
 
222
+ def hash(self) -> str:
223
+ """
224
+ Calculates SHA hash of this chain. Hash calculation is fast and consistent.
225
+ It takes into account all the steps added to the chain and their inputs.
226
+ Order of the steps is important.
227
+ """
228
+ return self._query.hash()
229
+
230
+ def _as_delta(
231
+ self,
232
+ on: str | Sequence[str] | None = None,
233
+ right_on: str | Sequence[str] | None = None,
234
+ compare: str | Sequence[str] | None = None,
235
+ delta_retry: bool | str | None = None,
236
+ delta_unsafe: bool = False,
237
+ ) -> "Self":
238
+ """Marks this chain as delta, which means special delta process will be
239
+ called on saving dataset for optimization"""
240
+ if on is None:
241
+ raise ValueError("'delta on' fields must be defined")
242
+ self._delta = True
243
+ self._delta_on = on
244
+ self._delta_result_on = right_on
245
+ self._delta_compare = compare
246
+ self._delta_retry = delta_retry
247
+ self._delta_unsafe = delta_unsafe
248
+ return self
249
+
250
+ @property
251
+ def empty(self) -> bool:
252
+ """Returns True if chain has zero number of rows"""
253
+ return not bool(self.count())
254
+
255
+ @property
256
+ def delta(self) -> bool:
257
+ """Returns True if this chain is ran in "delta" update mode"""
258
+ return self._delta
259
+
260
+ @property
261
+ def delta_unsafe(self) -> bool:
262
+ return self._delta_unsafe
263
+
179
264
  @property
180
265
  def schema(self) -> dict[str, DataType]:
181
266
  """Get schema of the chain."""
@@ -197,7 +282,7 @@ class DataChain:
197
282
 
198
283
  raise ValueError(f"Column with name {name} not found in the schema")
199
284
 
200
- def c(self, column: Union[str, Column]) -> Column:
285
+ def c(self, column: str | Column) -> Column:
201
286
  """Returns Column instance attached to the current chain."""
202
287
  c = self.column(column) if isinstance(column, str) else self.column(column.name)
203
288
  c.table = self._query.table
@@ -209,27 +294,31 @@ class DataChain:
209
294
  return self._query.session
210
295
 
211
296
  @property
212
- def name(self) -> Optional[str]:
297
+ def name(self) -> str | None:
213
298
  """Name of the underlying dataset, if there is one."""
214
299
  return self._query.name
215
300
 
216
301
  @property
217
- def version(self) -> Optional[int]:
302
+ def version(self) -> str | None:
218
303
  """Version of the underlying dataset, if there is one."""
219
304
  return self._query.version
220
305
 
221
306
  @property
222
- def dataset(self) -> Optional[DatasetRecord]:
307
+ def dataset(self) -> DatasetRecord | None:
223
308
  """Underlying dataset, if there is one."""
224
309
  if not self.name:
225
310
  return None
226
- return self.session.catalog.get_dataset(self.name)
311
+ return self.session.catalog.get_dataset(
312
+ self.name,
313
+ namespace_name=self._query.project.namespace.name,
314
+ project_name=self._query.project.name,
315
+ )
227
316
 
228
317
  def __or__(self, other: "Self") -> "Self":
229
318
  """Return `self.union(other)`."""
230
319
  return self.union(other)
231
320
 
232
- def print_schema(self, file: Optional[IO] = None) -> None:
321
+ def print_schema(self, file: IO | None = None) -> None:
233
322
  """Print schema of the chain."""
234
323
  self._effective_signals_schema.print_tree(file=file)
235
324
 
@@ -240,8 +329,8 @@ class DataChain:
240
329
  def _evolve(
241
330
  self,
242
331
  *,
243
- query: Optional[DatasetQuery] = None,
244
- settings: Optional[Settings] = None,
332
+ query: DatasetQuery | None = None,
333
+ settings: Settings | None = None,
245
334
  signal_schema=None,
246
335
  _sys=None,
247
336
  ) -> "Self":
@@ -253,39 +342,60 @@ class DataChain:
253
342
  signal_schema = copy.deepcopy(self.signals_schema)
254
343
  if _sys is None:
255
344
  _sys = self._sys
256
- return type(self)(
345
+ chain = type(self)(
257
346
  query, settings, signal_schema=signal_schema, setup=self._setup, _sys=_sys
258
347
  )
348
+ if self.delta:
349
+ chain = chain._as_delta(
350
+ on=self._delta_on,
351
+ right_on=self._delta_result_on,
352
+ compare=self._delta_compare,
353
+ delta_retry=self._delta_retry,
354
+ delta_unsafe=self._delta_unsafe,
355
+ )
356
+
357
+ return chain
259
358
 
260
359
  def settings(
261
360
  self,
262
- cache=None,
263
- parallel=None,
264
- workers=None,
265
- min_task_size=None,
266
- prefetch: Optional[int] = None,
267
- sys: Optional[bool] = None,
361
+ cache: bool | None = None,
362
+ prefetch: bool | int | None = None,
363
+ parallel: bool | int | None = None,
364
+ workers: int | None = None,
365
+ namespace: str | None = None,
366
+ project: str | None = None,
367
+ min_task_size: int | None = None,
368
+ batch_size: int | None = None,
369
+ sys: bool | None = None,
268
370
  ) -> "Self":
269
- """Change settings for chain.
270
-
271
- This function changes specified settings without changing not specified ones.
272
- It returns chain, so, it can be chained later with next operation.
371
+ """
372
+ Set chain execution parameters. Returns the chain itself, allowing method
373
+ chaining for subsequent operations. To restore all settings to their default
374
+ values, use `reset_settings()`.
273
375
 
274
376
  Parameters:
275
- cache : data caching (default=False)
276
- parallel : number of thread for processors. True is a special value to
277
- enable all available CPUs (default=1)
278
- workers : number of distributed workers. Only for Studio mode. (default=1)
279
- min_task_size : minimum number of tasks (default=1)
280
- prefetch: number of workers to use for downloading files in advance.
281
- This is enabled by default and uses 2 workers.
282
- To disable prefetching, set it to 0.
377
+ cache: Enable files caching to speed up subsequent accesses to the same
378
+ files from the same or different chains. Defaults to False.
379
+ prefetch: Enable prefetching of files. This will download files in
380
+ advance in parallel. If an integer is provided, it specifies the number
381
+ of files to prefetch concurrently for each process on each worker.
382
+ Defaults to 2. Set to 0 or False to disable prefetching.
383
+ parallel: Number of processes to use for processing user-defined functions
384
+ (UDFs) in parallel. If an integer is provided, it specifies the number
385
+ of CPUs to use. If True, all available CPUs are used. Defaults to 1.
386
+ namespace: Namespace to use for the chain by default.
387
+ project: Project to use for the chain by default.
388
+ min_task_size: Minimum number of rows per worker/process for parallel
389
+ processing by UDFs. Defaults to 1.
390
+ batch_size: Number of rows per insert by UDF to fine tune and balance speed
391
+ and memory usage. This might be useful when processing large rows
392
+ or when running into memory issues. Defaults to 2000.
283
393
 
284
394
  Example:
285
395
  ```py
286
396
  chain = (
287
397
  chain
288
- .settings(cache=True, parallel=8)
398
+ .settings(cache=True, parallel=8, batch_size=300)
289
399
  .map(laion=process_webdataset(spec=WDSLaion), params="file")
290
400
  )
291
401
  ```
@@ -293,22 +403,25 @@ class DataChain:
293
403
  if sys is None:
294
404
  sys = self._sys
295
405
  settings = copy.copy(self._settings)
296
- settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
406
+ settings.add(
407
+ Settings(
408
+ cache=cache,
409
+ prefetch=prefetch,
410
+ parallel=parallel,
411
+ workers=workers,
412
+ namespace=namespace,
413
+ project=project,
414
+ min_task_size=min_task_size,
415
+ batch_size=batch_size,
416
+ )
417
+ )
297
418
  return self._evolve(settings=settings, _sys=sys)
298
419
 
299
- def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
300
- """Reset all settings to default values."""
420
+ def reset_settings(self, settings: Settings | None = None) -> "Self":
421
+ """Reset all chain settings to default values."""
301
422
  self._settings = settings if settings else Settings()
302
423
  return self
303
424
 
304
- def reset_schema(self, signals_schema: SignalSchema) -> "Self":
305
- self.signals_schema = signals_schema
306
- return self
307
-
308
- def add_schema(self, signals_schema: SignalSchema) -> "Self":
309
- self.signals_schema |= signals_schema
310
- return self
311
-
312
425
  @classmethod
313
426
  def from_storage(
314
427
  cls,
@@ -356,8 +469,8 @@ class DataChain:
356
469
  def explode(
357
470
  self,
358
471
  col: str,
359
- model_name: Optional[str] = None,
360
- object_name: Optional[str] = None,
472
+ model_name: str | None = None,
473
+ column: str | None = None,
361
474
  schema_sample_size: int = 1,
362
475
  ) -> "DataChain":
363
476
  """Explodes a column containing JSON objects (dict or str DataChain type) into
@@ -368,7 +481,7 @@ class DataChain:
368
481
  col: the name of the column containing JSON to be exploded.
369
482
  model_name: optional generated model name. By default generates the name
370
483
  automatically.
371
- object_name: optional generated object column name. By default generates the
484
+ column: optional generated column name. By default generates the
372
485
  name automatically.
373
486
  schema_sample_size: the number of rows to use for inferring the schema of
374
487
  the JSON (in case some fields are optional and it's not enough to
@@ -377,16 +490,14 @@ class DataChain:
377
490
  Returns:
378
491
  DataChain: A new DataChain instance with the new set of columns.
379
492
  """
380
- import json
381
-
382
493
  import pyarrow as pa
383
494
 
384
495
  from datachain.lib.arrow import schema_to_output
385
496
 
386
- json_values = list(self.limit(schema_sample_size).collect(col))
497
+ json_values = self.limit(schema_sample_size).to_list(col)
387
498
  json_dicts = [
388
499
  json.loads(json_value) if isinstance(json_value, str) else json_value
389
- for json_value in json_values
500
+ for (json_value,) in json_values
390
501
  ]
391
502
 
392
503
  if any(not isinstance(json_dict, dict) for json_dict in json_dicts):
@@ -400,16 +511,16 @@ class DataChain:
400
511
 
401
512
  model = dict_to_data_model(model_name, output, original_names)
402
513
 
403
- def json_to_model(json_value: Union[str, dict]):
514
+ def json_to_model(json_value: str | dict):
404
515
  json_dict = (
405
516
  json.loads(json_value) if isinstance(json_value, str) else json_value
406
517
  )
407
518
  return model.model_validate(json_dict)
408
519
 
409
- if not object_name:
410
- object_name = f"{col}_expl"
520
+ if not column:
521
+ column = f"{col}_expl"
411
522
 
412
- return self.map(json_to_model, params=col, output={object_name: model})
523
+ return self.map(json_to_model, params=col, output={column: model})
413
524
 
414
525
  @classmethod
415
526
  def datasets(
@@ -443,35 +554,290 @@ class DataChain:
443
554
  )
444
555
  return listings(*args, **kwargs)
445
556
 
557
+ @property
558
+ def namespace_name(self) -> str:
559
+ """Current namespace name in which the chain is running"""
560
+ return (
561
+ self._settings.namespace
562
+ or self.session.catalog.metastore.default_namespace_name
563
+ )
564
+
565
+ @property
566
+ def project_name(self) -> str:
567
+ """Current project name in which the chain is running"""
568
+ return (
569
+ self._settings.project
570
+ or self.session.catalog.metastore.default_project_name
571
+ )
572
+
573
+ def persist(self) -> "Self":
574
+ """Saves temporary chain that will be removed after the process ends.
575
+ Temporary datasets are useful for optimization, for example when we have
576
+ multiple chains starting with identical sub-chain. We can then persist that
577
+ common chain and use it to calculate other chains, to avoid re-calculation
578
+ every time.
579
+ It returns the chain itself.
580
+ """
581
+ schema = self.signals_schema.clone_without_sys_signals().serialize()
582
+ project = self.session.catalog.metastore.get_project(
583
+ self.project_name,
584
+ self.namespace_name,
585
+ create=True,
586
+ )
587
+ return self._evolve(
588
+ query=self._query.save(project=project, feature_schema=schema),
589
+ signal_schema=self.signals_schema | SignalSchema({"sys": Sys}),
590
+ )
591
+
592
+ def _calculate_job_hash(self, job_id: str) -> str:
593
+ """
594
+ Calculates hash of the job at the place of this chain's save method.
595
+ Hash is calculated using previous job checkpoint hash (if exists) and
596
+ adding hash of this chain to produce new hash.
597
+ """
598
+ last_checkpoint = self.session.catalog.metastore.get_last_checkpoint(job_id)
599
+
600
+ return hashlib.sha256(
601
+ (bytes.fromhex(last_checkpoint.hash) if last_checkpoint else b"")
602
+ + bytes.fromhex(self.hash())
603
+ ).hexdigest()
604
+
446
605
  def save( # type: ignore[override]
447
606
  self,
448
- name: Optional[str] = None,
449
- version: Optional[int] = None,
450
- description: Optional[str] = None,
451
- labels: Optional[list[str]] = None,
607
+ name: str,
608
+ version: str | None = None,
609
+ description: str | None = None,
610
+ attrs: list[str] | None = None,
611
+ update_version: str | None = "patch",
452
612
  **kwargs,
453
- ) -> "Self":
613
+ ) -> "DataChain":
454
614
  """Save to a Dataset. It returns the chain itself.
455
615
 
456
616
  Parameters:
457
- name : dataset name. Empty name saves to a temporary dataset that will be
458
- removed after process ends. Temp dataset are useful for optimization.
459
- version : version of a dataset. Default - the last version that exist.
460
- description : description of a dataset.
461
- labels : labels of a dataset.
617
+ name: dataset name. This can be either a fully qualified name, including
618
+ the namespace and project, or just a regular dataset name. In the latter
619
+ case, the namespace and project will be taken from the settings
620
+ (if specified) or from the default values otherwise.
621
+ version: version of a dataset. If version is not specified and dataset
622
+ already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
623
+ description: description of a dataset.
624
+ attrs: attributes of a dataset. They can be without value, e.g "NLP",
625
+ or with a value, e.g "location=US".
626
+ update_version: which part of the dataset version to automatically increase.
627
+ Available values: `major`, `minor` or `patch`. Default is `patch`.
462
628
  """
629
+
630
+ catalog = self.session.catalog
631
+
632
+ result = None # result chain that will be returned at the end
633
+
634
+ # Version validation
635
+ self._validate_version(version)
636
+ self._validate_update_version(update_version)
637
+
638
+ # get existing job if running in SaaS, or creating new one if running locally
639
+ job = self.session.get_or_create_job()
640
+
641
+ namespace_name, project_name, name = catalog.get_full_dataset_name(
642
+ name,
643
+ namespace_name=self._settings.namespace,
644
+ project_name=self._settings.project,
645
+ )
646
+ project = self._get_or_create_project(namespace_name, project_name)
647
+
648
+ # Checkpoint handling
649
+ _hash, result = self._resolve_checkpoint(name, project, job, kwargs)
650
+ if bool(result):
651
+ # Checkpoint was found and reused
652
+ print(f"Checkpoint found for dataset '{name}', skipping creation")
653
+
654
+ # Schema preparation
463
655
  schema = self.signals_schema.clone_without_sys_signals().serialize()
464
- return self._evolve(
465
- query=self._query.save(
466
- name=name,
467
- version=version,
468
- description=description,
469
- labels=labels,
470
- feature_schema=schema,
656
+
657
+ # Handle retry and delta functionality
658
+ if not result:
659
+ result = self._handle_delta(name, version, project, schema, kwargs)
660
+
661
+ if not result:
662
+ # calculate chain if we already don't have result from checkpoint or delta
663
+ result = self._evolve(
664
+ query=self._query.save(
665
+ name=name,
666
+ version=version,
667
+ project=project,
668
+ description=description,
669
+ attrs=attrs,
670
+ feature_schema=schema,
671
+ update_version=update_version,
672
+ **kwargs,
673
+ )
674
+ )
675
+
676
+ catalog.metastore.create_checkpoint(job.id, _hash) # type: ignore[arg-type]
677
+ return result
678
+
679
+ def _validate_version(self, version: str | None) -> None:
680
+ """Validate dataset version if provided."""
681
+ if version is not None:
682
+ semver.validate(version)
683
+
684
+ def _validate_update_version(self, update_version: str | None) -> None:
685
+ """Ensure update_version is one of: major, minor, patch."""
686
+ allowed = ["major", "minor", "patch"]
687
+ if update_version not in allowed:
688
+ raise ValueError(f"update_version must be one of {allowed}")
689
+
690
+ def _get_or_create_project(self, namespace: str, project_name: str) -> Project:
691
+ """Get project or raise if creation not allowed."""
692
+ try:
693
+ return self.session.catalog.metastore.get_project(
694
+ project_name,
695
+ namespace,
696
+ create=is_studio(),
697
+ )
698
+ except ProjectNotFoundError as e:
699
+ raise ProjectCreateNotAllowedError("Creating project is not allowed") from e
700
+
701
+ def _resolve_checkpoint(
702
+ self,
703
+ name: str,
704
+ project: Project,
705
+ job: Job,
706
+ kwargs: dict,
707
+ ) -> tuple[str, "DataChain | None"]:
708
+ """Check if checkpoint exists and return cached dataset if possible."""
709
+ from .datasets import read_dataset
710
+
711
+ metastore = self.session.catalog.metastore
712
+ checkpoints_reset = env2bool("DATACHAIN_CHECKPOINTS_RESET", undefined=True)
713
+
714
+ _hash = self._calculate_job_hash(job.id)
715
+
716
+ if (
717
+ job.parent_job_id
718
+ and not checkpoints_reset
719
+ and metastore.find_checkpoint(job.parent_job_id, _hash)
720
+ ):
721
+ # checkpoint found → find which dataset version to reuse
722
+
723
+ # Find dataset version that was created by any ancestor job
724
+ try:
725
+ dataset_version = metastore.get_dataset_version_for_job_ancestry(
726
+ name,
727
+ project.namespace.name,
728
+ project.name,
729
+ job.id,
730
+ )
731
+ except JobAncestryDepthExceededError:
732
+ raise JobAncestryDepthExceededError(
733
+ "Job continuation chain is too deep. "
734
+ "Please run the job from scratch without continuing from a "
735
+ "parent job."
736
+ ) from None
737
+
738
+ if not dataset_version:
739
+ logger.debug(
740
+ "Checkpoint found but no dataset version for '%s' "
741
+ "in job ancestry (job_id=%s). Creating new version.",
742
+ name,
743
+ job.id,
744
+ )
745
+ # Dataset version not found (e.g deleted by user) - skip
746
+ # checkpoint and recreate
747
+ return _hash, None
748
+
749
+ logger.debug(
750
+ "Reusing dataset version '%s' v%s from job ancestry "
751
+ "(job_id=%s, dataset_version_id=%s)",
752
+ name,
753
+ dataset_version.version,
754
+ job.id,
755
+ dataset_version.id,
756
+ )
757
+
758
+ # Read the specific version from ancestry
759
+ chain = read_dataset(
760
+ name,
761
+ namespace=project.namespace.name,
762
+ project=project.name,
763
+ version=dataset_version.version,
471
764
  **kwargs,
472
765
  )
766
+
767
+ # Link current job to this dataset version (not creator).
768
+ # This also updates dataset_version.job_id.
769
+ metastore.link_dataset_version_to_job(
770
+ dataset_version.id,
771
+ job.id,
772
+ is_creator=False,
773
+ )
774
+
775
+ return _hash, chain
776
+
777
+ return _hash, None
778
+
779
+ def _handle_delta(
780
+ self,
781
+ name: str,
782
+ version: str | None,
783
+ project: Project,
784
+ schema: dict,
785
+ kwargs: dict,
786
+ ) -> "DataChain | None":
787
+ """Try to save as a delta dataset.
788
+ Returns:
789
+ A DataChain if delta logic could handle it, otherwise None to fall back
790
+ to the regular save path (e.g., on first dataset creation).
791
+ """
792
+ from datachain.delta import delta_retry_update
793
+
794
+ from .datasets import read_dataset
795
+
796
+ if not self.delta or not name:
797
+ return None
798
+
799
+ assert self._delta_on is not None, "Delta chain must have delta_on defined"
800
+
801
+ result_ds, dependencies, has_changes = delta_retry_update(
802
+ self,
803
+ project.namespace.name,
804
+ project.name,
805
+ name,
806
+ on=self._delta_on,
807
+ right_on=self._delta_result_on,
808
+ compare=self._delta_compare,
809
+ delta_retry=self._delta_retry,
473
810
  )
474
811
 
812
+ # Case 1: delta produced a new dataset
813
+ if result_ds:
814
+ return self._evolve(
815
+ query=result_ds._query.save(
816
+ name=name,
817
+ version=version,
818
+ project=project,
819
+ feature_schema=schema,
820
+ dependencies=dependencies,
821
+ **kwargs,
822
+ )
823
+ )
824
+
825
+ # Case 2: no changes → reuse last version
826
+ if not has_changes:
827
+ # sources have not been changed so new version of resulting dataset
828
+ # would be the same as previous one. To avoid duplicating exact
829
+ # datasets, we won't create new version of it and we will return
830
+ # current latest version instead.
831
+ return read_dataset(
832
+ name,
833
+ namespace=project.namespace.name,
834
+ project=project.name,
835
+ **kwargs,
836
+ )
837
+
838
+ # Case 3: first creation of dataset
839
+ return None
840
+
475
841
  def apply(self, func, *args, **kwargs):
476
842
  """Apply any function to the chain.
477
843
 
@@ -497,10 +863,10 @@ class DataChain:
497
863
 
498
864
  def map(
499
865
  self,
500
- func: Optional[Callable] = None,
501
- params: Union[None, str, Sequence[str]] = None,
866
+ func: Callable | None = None,
867
+ params: str | Sequence[str] | None = None,
502
868
  output: OutputType = None,
503
- **signal_map,
869
+ **signal_map: Any,
504
870
  ) -> "Self":
505
871
  """Apply a function to each row to create new signals. The function should
506
872
  return a new object for each row. It returns a chain itself with new signals.
@@ -508,17 +874,17 @@ class DataChain:
508
874
  Input-output relationship: 1:1
509
875
 
510
876
  Parameters:
511
- func : Function applied to each row.
512
- params : List of column names used as input for the function. Default
877
+ func: Function applied to each row.
878
+ params: List of column names used as input for the function. Default
513
879
  is taken from function signature.
514
- output : Dictionary defining new signals and their corresponding types.
880
+ output: Dictionary defining new signals and their corresponding types.
515
881
  Default type is taken from function signature. Default can be also
516
882
  taken from kwargs - **signal_map (see below).
517
883
  If signal name is defined using signal_map (see below) only a single
518
884
  type value can be used.
519
- **signal_map : kwargs can be used to define `func` together with it's return
885
+ **signal_map: kwargs can be used to define `func` together with its return
520
886
  signal name in format of `map(my_sign=my_func)`. This helps define
521
- signal names and function in a nicer way.
887
+ signal names and functions in a nicer way.
522
888
 
523
889
  Example:
524
890
  Using signal_map and single type in output:
@@ -539,18 +905,19 @@ class DataChain:
539
905
  if (prefetch := self._settings.prefetch) is not None:
540
906
  udf_obj.prefetch = prefetch
541
907
 
908
+ sys_schema = SignalSchema({"sys": Sys})
542
909
  return self._evolve(
543
910
  query=self._query.add_signals(
544
- udf_obj.to_udf_wrapper(),
911
+ udf_obj.to_udf_wrapper(self._settings.batch_size),
545
912
  **self._settings.to_dict(),
546
913
  ),
547
- signal_schema=self.signals_schema | udf_obj.output,
914
+ signal_schema=sys_schema | self.signals_schema | udf_obj.output,
548
915
  )
549
916
 
550
917
  def gen(
551
918
  self,
552
- func: Optional[Union[Callable, Generator]] = None,
553
- params: Union[None, str, Sequence[str]] = None,
919
+ func: Callable | Generator | None = None,
920
+ params: str | Sequence[str] | None = None,
554
921
  output: OutputType = None,
555
922
  **signal_map,
556
923
  ) -> "Self":
@@ -579,19 +946,21 @@ class DataChain:
579
946
  udf_obj.prefetch = prefetch
580
947
  return self._evolve(
581
948
  query=self._query.generate(
582
- udf_obj.to_udf_wrapper(),
949
+ udf_obj.to_udf_wrapper(self._settings.batch_size),
583
950
  **self._settings.to_dict(),
584
951
  ),
585
- signal_schema=udf_obj.output,
952
+ signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
586
953
  )
587
954
 
955
+ @delta_disabled
588
956
  def agg(
589
957
  self,
590
- func: Optional[Callable] = None,
591
- partition_by: Optional[PartitionByType] = None,
592
- params: Union[None, str, Sequence[str]] = None,
958
+ /,
959
+ func: Callable | None = None,
960
+ partition_by: PartitionByType | None = None,
961
+ params: str | Sequence[str] | None = None,
593
962
  output: OutputType = None,
594
- **signal_map,
963
+ **signal_map: Callable,
595
964
  ) -> "Self":
596
965
  """Aggregate rows using `partition_by` statement and apply a function to the
597
966
  groups of aggregated rows. The function needs to return new objects for each
@@ -601,12 +970,28 @@ class DataChain:
601
970
 
602
971
  This method bears similarity to `gen()` and `map()`, employing a comparable set
603
972
  of parameters, yet differs in two crucial aspects:
973
+
604
974
  1. The `partition_by` parameter: This specifies the column name or a list of
605
975
  column names that determine the grouping criteria for aggregation.
606
976
  2. Group-based UDF function input: Instead of individual rows, the function
607
- receives a list all rows within each group defined by `partition_by`.
977
+ receives a list of all rows within each group defined by `partition_by`.
978
+
979
+ If `partition_by` is not set or is an empty list, all rows will be placed
980
+ into a single group.
981
+
982
+ Parameters:
983
+ func: Function applied to each group of rows.
984
+ partition_by: Column name(s) to group by. If None, all rows go
985
+ into one group.
986
+ params: List of column names used as input for the function. Default is
987
+ taken from function signature.
988
+ output: Dictionary defining new signals and their corresponding types.
989
+ Default type is taken from function signature.
990
+ **signal_map: kwargs can be used to define `func` together with its return
991
+ signal name in format of `agg(result_column=my_func)`.
608
992
 
609
993
  Examples:
994
+ Basic aggregation with lambda function:
610
995
  ```py
611
996
  chain = chain.agg(
612
997
  total=lambda category, amount: [sum(amount)],
@@ -617,7 +1002,6 @@ class DataChain:
617
1002
  ```
618
1003
 
619
1004
  An alternative syntax, when you need to specify a more complex function:
620
-
621
1005
  ```py
622
1006
  # It automatically resolves which columns to pass to the function
623
1007
  # by looking at the function signature.
@@ -635,21 +1019,80 @@ class DataChain:
635
1019
  )
636
1020
  chain.save("new_dataset")
637
1021
  ```
1022
+
1023
+ Using complex signals for partitioning (`File` or any Pydantic `BaseModel`):
1024
+ ```py
1025
+ def my_agg(files: list[File]) -> Iterator[tuple[File, int]]:
1026
+ yield files[0], sum(f.size for f in files)
1027
+
1028
+ chain = chain.agg(
1029
+ my_agg,
1030
+ params=("file",),
1031
+ output={"file": File, "total": int},
1032
+ partition_by="file", # Column referring to all sub-columns of File
1033
+ )
1034
+ chain.save("new_dataset")
1035
+ ```
1036
+
1037
+ Aggregating all rows into a single group (when `partition_by` is not set):
1038
+ ```py
1039
+ chain = chain.agg(
1040
+ total_size=lambda file, size: [sum(size)],
1041
+ output=int,
1042
+ # No partition_by specified - all rows go into one group
1043
+ )
1044
+ chain.save("new_dataset")
1045
+ ```
1046
+
1047
+ Multiple partition columns:
1048
+ ```py
1049
+ chain = chain.agg(
1050
+ total=lambda category, subcategory, amount: [sum(amount)],
1051
+ output=float,
1052
+ partition_by=["category", "subcategory"],
1053
+ )
1054
+ chain.save("new_dataset")
1055
+ ```
638
1056
  """
1057
+ if partition_by is not None:
1058
+ # Convert string partition_by parameters to Column objects
1059
+ if isinstance(partition_by, (str, Function, ColumnElement)):
1060
+ list_partition_by = [partition_by]
1061
+ else:
1062
+ list_partition_by = list(partition_by)
1063
+
1064
+ processed_partition_columns: list[ColumnElement] = []
1065
+ for col in list_partition_by:
1066
+ if isinstance(col, str):
1067
+ columns = self.signals_schema.db_signals(name=col, as_columns=True)
1068
+ if not columns:
1069
+ raise SignalResolvingError([col], "is not found")
1070
+ processed_partition_columns.extend(cast("list[Column]", columns))
1071
+ elif isinstance(col, Function):
1072
+ column = col.get_column(self.signals_schema)
1073
+ processed_partition_columns.append(column)
1074
+ else:
1075
+ # Assume it's already a ColumnElement
1076
+ processed_partition_columns.append(col)
1077
+
1078
+ processed_partition_by = processed_partition_columns
1079
+ else:
1080
+ processed_partition_by = []
1081
+
639
1082
  udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
640
1083
  return self._evolve(
641
1084
  query=self._query.generate(
642
- udf_obj.to_udf_wrapper(),
643
- partition_by=partition_by,
1085
+ udf_obj.to_udf_wrapper(self._settings.batch_size),
1086
+ partition_by=processed_partition_by,
644
1087
  **self._settings.to_dict(),
645
1088
  ),
646
- signal_schema=udf_obj.output,
1089
+ signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
647
1090
  )
648
1091
 
649
1092
  def batch_map(
650
1093
  self,
651
- func: Optional[Callable] = None,
652
- params: Union[None, str, Sequence[str]] = None,
1094
+ func: Callable | None = None,
1095
+ params: str | Sequence[str] | None = None,
653
1096
  output: OutputType = None,
654
1097
  batch: int = 1000,
655
1098
  **signal_map,
@@ -661,7 +1104,7 @@ class DataChain:
661
1104
  It accepts the same parameters plus an
662
1105
  additional parameter:
663
1106
 
664
- batch : Size of each batch passed to `func`. Defaults to 1000.
1107
+ batch: Size of each batch passed to `func`. Defaults to 1000.
665
1108
 
666
1109
  Example:
667
1110
  ```py
@@ -671,11 +1114,24 @@ class DataChain:
671
1114
  )
672
1115
  chain.save("new_dataset")
673
1116
  ```
1117
+
1118
+ .. deprecated:: 0.29.0
1119
+ This method is deprecated and will be removed in a future version.
1120
+ Use `agg()` instead, which provides the similar functionality.
674
1121
  """
1122
+ import warnings
1123
+
1124
+ warnings.warn(
1125
+ "batch_map() is deprecated and will be removed in a future version. "
1126
+ "Use agg() instead, which provides the similar functionality.",
1127
+ DeprecationWarning,
1128
+ stacklevel=2,
1129
+ )
675
1130
  udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
1131
+
676
1132
  return self._evolve(
677
1133
  query=self._query.add_signals(
678
- udf_obj.to_udf_wrapper(batch),
1134
+ udf_obj.to_udf_wrapper(self._settings.batch_size, batch=batch),
679
1135
  **self._settings.to_dict(),
680
1136
  ),
681
1137
  signal_schema=self.signals_schema | udf_obj.output,
@@ -684,8 +1140,8 @@ class DataChain:
684
1140
  def _udf_to_obj(
685
1141
  self,
686
1142
  target_class: type[UDFObjT],
687
- func: Optional[Union[Callable, UDFObjT]],
688
- params: Union[None, str, Sequence[str]],
1143
+ func: Callable | UDFObjT | None,
1144
+ params: str | Sequence[str] | None,
689
1145
  output: OutputType,
690
1146
  signal_map: dict[str, Callable],
691
1147
  ) -> UDFObjT:
@@ -696,11 +1152,7 @@ class DataChain:
696
1152
  sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
697
1153
  DataModel.register(list(sign.output_schema.values.values()))
698
1154
 
699
- signals_schema = self.signals_schema
700
- if self._sys:
701
- signals_schema = SignalSchema({"sys": Sys}) | signals_schema
702
-
703
- params_schema = signals_schema.slice(
1155
+ params_schema = self.signals_schema.slice(
704
1156
  sign.params, self._setup, is_batch=is_batch
705
1157
  )
706
1158
 
@@ -710,7 +1162,7 @@ class DataChain:
710
1162
  query_func = getattr(self._query, method_name)
711
1163
 
712
1164
  new_schema = self.signals_schema.resolve(*args)
713
- columns = [C(col) for col in new_schema.db_signals()]
1165
+ columns = new_schema.db_signals(as_columns=True)
714
1166
  return query_func(*columns, **kwargs)
715
1167
 
716
1168
  @resolve_columns
@@ -729,15 +1181,17 @@ class DataChain:
729
1181
  Order is not guaranteed when steps are added after an `order_by` statement.
730
1182
  I.e. when using `read_dataset` an `order_by` statement should be used if
731
1183
  the order of the records in the chain is important.
732
- Using `order_by` directly before `limit`, `collect` and `collect_flatten`
1184
+ Using `order_by` directly before `limit`, `to_list` and similar methods
733
1185
  will give expected results.
734
- See https://github.com/iterative/datachain/issues/477 for further details.
1186
+ See https://github.com/datachain-ai/datachain/issues/477
1187
+ for further details.
735
1188
  """
736
1189
  if descending:
737
1190
  args = tuple(sqlalchemy.desc(a) for a in args)
738
1191
 
739
1192
  return self._evolve(query=self._query.order_by(*args))
740
1193
 
1194
+ @delta_disabled
741
1195
  def distinct(self, arg: str, *args: str) -> "Self": # type: ignore[override]
742
1196
  """Removes duplicate rows based on uniqueness of some input column(s)
743
1197
  i.e if rows are found with the same value of input column(s), only one
@@ -745,7 +1199,7 @@ class DataChain:
745
1199
 
746
1200
  Example:
747
1201
  ```py
748
- dc.distinct("file.parent", "file.name")
1202
+ dc.distinct("file.path")
749
1203
  ```
750
1204
  """
751
1205
  return self._evolve(
@@ -754,11 +1208,9 @@ class DataChain:
754
1208
  )
755
1209
  )
756
1210
 
757
- def select(self, *args: str, _sys: bool = True) -> "Self":
1211
+ def select(self, *args: str) -> "Self":
758
1212
  """Select only a specified set of signals."""
759
1213
  new_schema = self.signals_schema.resolve(*args)
760
- if self._sys and _sys:
761
- new_schema = SignalSchema({"sys": Sys}) | new_schema
762
1214
  columns = new_schema.db_signals()
763
1215
  return self._evolve(
764
1216
  query=self._query.select(*columns), signal_schema=new_schema
@@ -772,10 +1224,11 @@ class DataChain:
772
1224
  query=self._query.select(*columns), signal_schema=new_schema
773
1225
  )
774
1226
 
775
- def group_by(
1227
+ @delta_disabled # type: ignore[arg-type]
1228
+ def group_by( # noqa: C901, PLR0912
776
1229
  self,
777
1230
  *,
778
- partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None,
1231
+ partition_by: str | Func | Sequence[str | Func] | None = None,
779
1232
  **kwargs: Func,
780
1233
  ) -> "Self":
781
1234
  """Group rows by specified set of signals and return new signals
@@ -791,6 +1244,15 @@ class DataChain:
791
1244
  partition_by=("file_source", "file_ext"),
792
1245
  )
793
1246
  ```
1247
+
1248
+ Using complex signals:
1249
+ ```py
1250
+ chain = chain.group_by(
1251
+ total_size=func.sum("file.size"),
1252
+ count=func.count(),
1253
+ partition_by="file", # Uses column name, expands to File's unique keys
1254
+ )
1255
+ ```
794
1256
  """
795
1257
  if partition_by is None:
796
1258
  partition_by = []
@@ -801,20 +1263,61 @@ class DataChain:
801
1263
  signal_columns: list[Column] = []
802
1264
  schema_fields: dict[str, DataType] = {}
803
1265
  keep_columns: list[str] = []
1266
+ partial_fields: list[str] = [] # Track specific fields for partial creation
1267
+ schema_partition_by: list[str] = []
804
1268
 
805
- # validate partition_by columns and add them to the schema
806
1269
  for col in partition_by:
807
1270
  if isinstance(col, str):
808
- col_db_name = ColumnMeta.to_db_name(col)
809
- col_type = self.signals_schema.get_column_type(col_db_name)
810
- column = Column(col_db_name, python_to_sql(col_type))
811
- if col not in keep_columns:
812
- keep_columns.append(col)
1271
+ columns = self.signals_schema.db_signals(name=col, as_columns=True)
1272
+ if not columns:
1273
+ raise SignalResolvingError([col], "is not found")
1274
+ partition_by_columns.extend(cast("list[Column]", columns))
1275
+
1276
+ # For nested field references (e.g., "nested.level1.name"),
1277
+ # we need to distinguish between:
1278
+ # 1. References to fields within a complex signal (create partials)
1279
+ # 2. Deep nested references that should be flattened
1280
+ if "." in col:
1281
+ # Split the column reference to analyze it
1282
+ parts = col.split(".")
1283
+ parent_signal = parts[0]
1284
+ parent_type = self.signals_schema.values.get(parent_signal)
1285
+
1286
+ if ModelStore.is_partial(parent_type):
1287
+ if parent_signal not in keep_columns:
1288
+ keep_columns.append(parent_signal)
1289
+ partial_fields.append(col)
1290
+ schema_partition_by.append(col)
1291
+ else:
1292
+ # BaseModel or other - add flattened columns directly
1293
+ for column in cast("list[Column]", columns):
1294
+ col_type = self.signals_schema.get_column_type(column.name)
1295
+ schema_fields[column.name] = col_type
1296
+ schema_partition_by.append(col)
1297
+ else:
1298
+ # simple signal - but we need to check if it's a complex signal
1299
+ # complex signal - only include the columns used for partitioning
1300
+ col_type = self.signals_schema.get_column_type(
1301
+ col, with_subtree=True
1302
+ )
1303
+ if isinstance(col_type, type) and issubclass(col_type, BaseModel):
1304
+ # Complex signal - add only the partitioning columns
1305
+ for column in cast("list[Column]", columns):
1306
+ col_type = self.signals_schema.get_column_type(column.name)
1307
+ schema_fields[column.name] = col_type
1308
+ schema_partition_by.append(col)
1309
+ # Simple signal - keep the entire signal
1310
+ else:
1311
+ if col not in keep_columns:
1312
+ keep_columns.append(col)
1313
+ schema_partition_by.append(col)
813
1314
  elif isinstance(col, Function):
814
1315
  column = col.get_column(self.signals_schema)
815
1316
  col_db_name = column.name
816
1317
  col_type = column.type.python_type
817
1318
  schema_fields[col_db_name] = col_type
1319
+ partition_by_columns.append(column)
1320
+ signal_columns.append(column)
818
1321
  else:
819
1322
  raise DataChainColumnError(
820
1323
  col,
@@ -823,9 +1326,7 @@ class DataChain:
823
1326
  " but expected str or Function"
824
1327
  ),
825
1328
  )
826
- partition_by_columns.append(column)
827
1329
 
828
- # validate signal columns and add them to the schema
829
1330
  if not kwargs:
830
1331
  raise ValueError("At least one column should be provided for group_by")
831
1332
  for col_name, func in kwargs.items():
@@ -838,9 +1339,9 @@ class DataChain:
838
1339
  signal_columns.append(column)
839
1340
  schema_fields[col_name] = func.get_result_type(self.signals_schema)
840
1341
 
841
- signal_schema = SignalSchema(schema_fields)
842
- if keep_columns:
843
- signal_schema |= self.signals_schema.to_partial(*keep_columns)
1342
+ signal_schema = self.signals_schema.group_by(
1343
+ schema_partition_by, signal_columns
1344
+ )
844
1345
 
845
1346
  return self._evolve(
846
1347
  query=self._query.group_by(signal_columns, partition_by_columns),
@@ -848,17 +1349,13 @@ class DataChain:
848
1349
  )
849
1350
 
850
1351
  def mutate(self, **kwargs) -> "Self":
851
- """Create new signals based on existing signals.
852
-
853
- This method cannot modify existing columns. If you need to modify an
854
- existing column, use a different name for the new column and then use
855
- `select()` to choose which columns to keep.
1352
+ """Create or modify signals based on existing signals.
856
1353
 
857
1354
  This method is vectorized and more efficient compared to map(), and it does not
858
1355
  extract or download any data from the internal database. However, it can only
859
1356
  utilize predefined built-in functions and their combinations.
860
1357
 
861
- The supported functions:
1358
+ Supported functions:
862
1359
  Numerical: +, -, *, /, rand(), avg(), count(), func(),
863
1360
  greatest(), least(), max(), min(), sum()
864
1361
  String: length(), split(), replace(), regexp_replace()
@@ -871,7 +1368,7 @@ class DataChain:
871
1368
  ```py
872
1369
  dc.mutate(
873
1370
  area=Column("image.height") * Column("image.width"),
874
- extension=file_ext(Column("file.name")),
1371
+ extension=file_ext(Column("file.path")),
875
1372
  dist=cosine_distance(embedding_text, embedding_image)
876
1373
  )
877
1374
  ```
@@ -885,13 +1382,20 @@ class DataChain:
885
1382
  ```
886
1383
 
887
1384
  This method can be also used to rename signals. If the Column("name") provided
888
- as value for the new signal - the old column will be dropped. Otherwise a new
889
- column is created.
1385
+ as value for the new signal - the old signal will be dropped. Otherwise a new
1386
+ signal is created. Exception, if the old signal is nested one (e.g.
1387
+ `C("file.path")`), it will be kept to keep the object intact.
890
1388
 
891
1389
  Example:
892
1390
  ```py
893
1391
  dc.mutate(
894
- newkey=Column("oldkey")
1392
+ newkey=Column("oldkey") # drops oldkey
1393
+ )
1394
+ ```
1395
+
1396
+ ```py
1397
+ dc.mutate(
1398
+ size=Column("file.size") # keeps `file.size`
895
1399
  )
896
1400
  ```
897
1401
  """
@@ -926,49 +1430,52 @@ class DataChain:
926
1430
  # adding new signal
927
1431
  mutated[name] = value
928
1432
 
1433
+ new_schema = schema.mutate(kwargs)
929
1434
  return self._evolve(
930
- query=self._query.mutate(**mutated), signal_schema=schema.mutate(kwargs)
1435
+ query=self._query.mutate(new_schema=new_schema, **mutated),
1436
+ signal_schema=new_schema,
931
1437
  )
932
1438
 
933
1439
  @property
934
1440
  def _effective_signals_schema(self) -> "SignalSchema":
935
- """Effective schema used for user-facing API like collect, to_pandas, etc."""
1441
+ """Effective schema used for user-facing API like to_list, to_pandas, etc."""
936
1442
  signals_schema = self.signals_schema
937
1443
  if not self._sys:
938
1444
  return signals_schema.clone_without_sys_signals()
939
1445
  return signals_schema
940
1446
 
941
1447
  @overload
942
- def collect_flatten(self) -> Iterator[tuple[Any, ...]]: ...
1448
+ def _leaf_values(self) -> Iterator[tuple[Any, ...]]: ...
943
1449
 
944
1450
  @overload
945
- def collect_flatten(self, *, include_hidden: bool) -> Iterator[tuple[Any, ...]]: ...
1451
+ def _leaf_values(self, *, include_hidden: bool) -> Iterator[tuple[Any, ...]]: ...
946
1452
 
947
1453
  @overload
948
- def collect_flatten(
1454
+ def _leaf_values(
949
1455
  self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
950
1456
  ) -> Iterator[_T]: ...
951
1457
 
952
1458
  @overload
953
- def collect_flatten(
1459
+ def _leaf_values(
954
1460
  self,
955
1461
  *,
956
1462
  row_factory: Callable[[list[str], tuple[Any, ...]], _T],
957
1463
  include_hidden: bool,
958
1464
  ) -> Iterator[_T]: ...
959
1465
 
960
- def collect_flatten(self, *, row_factory=None, include_hidden: bool = True):
1466
+ def _leaf_values(self, *, row_factory=None, include_hidden: bool = True):
961
1467
  """Yields flattened rows of values as a tuple.
962
1468
 
963
1469
  Args:
964
- row_factory : A callable to convert row to a custom format.
965
- It should accept two arguments: a list of column names and
966
- a tuple of row values.
1470
+ row_factory: A callable to convert row to a custom format.
1471
+ It should accept two arguments: a list of column names and
1472
+ a tuple of row values.
967
1473
  include_hidden: Whether to include hidden signals from the schema.
968
1474
  """
969
1475
  db_signals = self._effective_signals_schema.db_signals(
970
1476
  include_hidden=include_hidden
971
1477
  )
1478
+
972
1479
  with self._query.ordered_select(*db_signals).as_iterable() as rows:
973
1480
  if row_factory:
974
1481
  rows = (row_factory(db_signals, r) for r in rows) # type: ignore[assignment]
@@ -985,7 +1492,7 @@ class DataChain:
985
1492
  headers, _ = self._effective_signals_schema.get_headers_with_length()
986
1493
  column_names = [".".join(filter(None, header)) for header in headers]
987
1494
 
988
- results_iter = self.collect_flatten()
1495
+ results_iter = self._leaf_values()
989
1496
 
990
1497
  def column_chunks() -> Iterator[list[list[Any]]]:
991
1498
  for chunk_iter in batched_it(results_iter, chunk_size):
@@ -1018,55 +1525,51 @@ class DataChain:
1018
1525
 
1019
1526
  def results(self, *, row_factory=None, include_hidden=True):
1020
1527
  if row_factory is None:
1021
- return list(self.collect_flatten(include_hidden=include_hidden))
1528
+ return list(self._leaf_values(include_hidden=include_hidden))
1022
1529
  return list(
1023
- self.collect_flatten(row_factory=row_factory, include_hidden=include_hidden)
1530
+ self._leaf_values(row_factory=row_factory, include_hidden=include_hidden)
1024
1531
  )
1025
1532
 
1026
1533
  def to_records(self) -> list[dict[str, Any]]:
1027
1534
  """Convert every row to a dictionary."""
1028
1535
 
1029
1536
  def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
1030
- return dict(zip(cols, row))
1537
+ return dict(zip(cols, row, strict=False))
1031
1538
 
1032
1539
  return self.results(row_factory=to_dict)
1033
1540
 
1034
- @overload
1035
- def collect(self) -> Iterator[tuple[DataValue, ...]]: ...
1036
-
1037
- @overload
1038
- def collect(self, col: str) -> Iterator[DataValue]: ...
1039
-
1040
- @overload
1041
- def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...
1042
-
1043
- def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]: # type: ignore[overload-overlap,misc]
1541
+ def to_iter(self, *cols: str) -> Iterator[tuple[DataValue, ...]]:
1044
1542
  """Yields rows of values, optionally limited to the specified columns.
1045
1543
 
1046
1544
  Args:
1047
1545
  *cols: Limit to the specified columns. By default, all columns are selected.
1048
1546
 
1049
1547
  Yields:
1050
- (DataType): Yields a single item if a column is selected.
1051
- (tuple[DataType, ...]): Yields a tuple of items if multiple columns are
1052
- selected.
1548
+ (tuple[DataType, ...]): Yields a tuple of items for each row.
1053
1549
 
1054
1550
  Example:
1055
1551
  Iterating over all rows:
1056
1552
  ```py
1057
- for row in dc.collect():
1553
+ for row in ds.to_iter():
1554
+ print(row)
1555
+ ```
1556
+
1557
+ DataChain is iterable and can be used in a for loop directly, which is
1558
+ equivalent to `ds.to_iter()`:
1559
+ ```py
1560
+ for row in ds:
1058
1561
  print(row)
1059
1562
  ```
1060
1563
 
1061
1564
  Iterating over all rows with selected columns:
1062
1565
  ```py
1063
- for name, size in dc.collect("file.name", "file.size"):
1566
+ for name, size in ds.to_iter("file.path", "file.size"):
1064
1567
  print(name, size)
1065
1568
  ```
1066
1569
 
1067
1570
  Iterating over a single column:
1068
1571
  ```py
1069
- for file in dc.collect("file.name"):
1572
+ for (file,) in ds.to_iter("file.path"):
1070
1573
  print(file)
1071
1574
  ```
1072
1575
  """
@@ -1078,7 +1581,31 @@ class DataChain:
1078
1581
  ret = signals_schema.row_to_features(
1079
1582
  row, catalog=chain.session.catalog, cache=chain._settings.cache
1080
1583
  )
1081
- yield ret[0] if len(cols) == 1 else tuple(ret)
1584
+ yield tuple(ret)
1585
+
1586
+ @overload
1587
+ def collect(self) -> Iterator[tuple[DataValue, ...]]: ...
1588
+
1589
+ @overload
1590
+ def collect(self, col: str) -> Iterator[DataValue]: ...
1591
+
1592
+ @overload
1593
+ def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...
1594
+
1595
+ def collect(self, *cols: str) -> Iterator[DataValue | tuple[DataValue, ...]]: # type: ignore[overload-overlap,misc]
1596
+ """
1597
+ Deprecated. Use `to_iter` method instead.
1598
+ """
1599
+ warnings.warn(
1600
+ "Method `collect` is deprecated. Use `to_iter` method instead.",
1601
+ DeprecationWarning,
1602
+ stacklevel=2,
1603
+ )
1604
+
1605
+ if len(cols) == 1:
1606
+ yield from [item[0] for item in self.to_iter(*cols)]
1607
+ else:
1608
+ yield from self.to_iter(*cols)
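Since `collect` is now only a deprecation shim over `to_iter`, a hedged migration sketch (the bucket URL and column name are illustrative):

```py
import datachain as dc

chain = dc.read_storage("s3://mybucket")

# Deprecated: chain.collect("file.path") yielded bare values for a single
# column and now emits a DeprecationWarning. to_iter() always yields tuples:
for (path,) in chain.to_iter("file.path"):
    print(path)

# For a flat list of one column, to_values() is the shortcut:
paths = chain.to_values("file.path")
```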
1082
1609
 
1083
1610
  def to_pytorch(
1084
1611
  self,
@@ -1112,7 +1639,7 @@ class DataChain:
1112
1639
  if self._query.attached:
1113
1640
  chain = self
1114
1641
  else:
1115
- chain = self.save()
1642
+ chain = self.persist()
1116
1643
  assert chain.name is not None # for mypy
1117
1644
  return PytorchDataset(
1118
1645
  chain.name,
@@ -1126,15 +1653,12 @@ class DataChain:
1126
1653
  remove_prefetched=remove_prefetched,
1127
1654
  )
1128
1655
 
1129
- def remove_file_signals(self) -> "Self":
1130
- schema = self.signals_schema.clone_without_file_signals()
1131
- return self.select(*schema.values.keys())
1132
-
1656
+ @delta_disabled
1133
1657
  def merge(
1134
1658
  self,
1135
1659
  right_ds: "DataChain",
1136
- on: Union[MergeColType, Sequence[MergeColType]],
1137
- right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None,
1660
+ on: MergeColType | Sequence[MergeColType],
1661
+ right_on: MergeColType | Sequence[MergeColType] | None = None,
1138
1662
  inner=False,
1139
1663
  full=False,
1140
1664
  rname="right_",
@@ -1202,8 +1726,8 @@ class DataChain:
1202
1726
 
1203
1727
  def _resolve(
1204
1728
  ds: DataChain,
1205
- col: Union[str, Function, sqlalchemy.ColumnElement],
1206
- side: Union[str, None],
1729
+ col: str | Function | sqlalchemy.ColumnElement,
1730
+ side: str | None,
1207
1731
  ):
1208
1732
  try:
1209
1733
  if isinstance(col, Function):
@@ -1216,7 +1740,7 @@ class DataChain:
1216
1740
  ops = [
1217
1741
  _resolve(self, left, "left")
1218
1742
  == _resolve(right_ds, right, "right" if right_on else None)
1219
- for left, right in zip(on, right_on or on)
1743
+ for left, right in zip(on, right_on or on, strict=False)
1220
1744
  ]
1221
1745
 
1222
1746
  if errors:
@@ -1225,32 +1749,44 @@ class DataChain:
1225
1749
  )
1226
1750
 
1227
1751
  query = self._query.join(
1228
- right_ds._query, sqlalchemy.and_(*ops), inner, full, rname + "{name}"
1752
+ right_ds._query, sqlalchemy.and_(*ops), inner, full, rname
1229
1753
  )
1230
1754
  query.feature_schema = None
1231
1755
  ds = self._evolve(query=query)
1232
1756
 
1757
+ # Note: merge drops sys signals from both sides, make sure to not include it
1758
+ # in the resulting schema
1233
1759
  signals_schema = self.signals_schema.clone_without_sys_signals()
1234
1760
  right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
1235
- ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
1236
- right_signals_schema, rname
1237
- )
1761
+
1762
+ ds.signals_schema = signals_schema.merge(right_signals_schema, rname)
1238
1763
 
1239
1764
  return ds
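A short sketch of the resulting `merge` behaviour, assuming two chains that share an `id` column (`read_csv`, the URLs and the column names are illustrative):

```py
import datachain as dc

meta = dc.read_csv("s3://mybucket/meta.csv")    # id, label, ...
preds = dc.read_csv("s3://mybucket/preds.csv")  # id, score, ...

joined = meta.merge(preds, on="id", rname="right_")
# Right-hand columns that clash with the left side get the `rname` prefix;
# per the note above, `sys` signals are dropped from the merged schema.
```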
1240
1765
 
1766
+ @delta_disabled
1241
1767
  def union(self, other: "Self") -> "Self":
1242
1768
  """Return the set union of the two datasets.
1243
1769
 
1244
1770
  Parameters:
1245
1771
  other: chain whose rows will be added to `self`.
1246
1772
  """
1773
+ self_schema = self.signals_schema
1774
+ other_schema = other.signals_schema
1775
+ missing_left, missing_right = self_schema.compare_signals(other_schema)
1776
+ if missing_left or missing_right:
1777
+ raise UnionSchemaMismatchError.from_column_sets(
1778
+ missing_left,
1779
+ missing_right,
1780
+ )
1781
+
1782
+ self.signals_schema = self_schema.clone_without_sys_signals()
1247
1783
  return self._evolve(query=self._query.union(other._query))
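With the added schema comparison, `union` now fails fast instead of silently misaligning columns. A sketch, assuming a `dc.read_values` helper for small in-memory chains (illustrative):

```py
import datachain as dc

a = dc.read_values(num=[1, 2, 3])
b = dc.read_values(num=[4, 5])
combined = a.union(b)  # matching signals on both sides: rows are appended

c = dc.read_values(other=["x"])
# a.union(c) would raise UnionSchemaMismatchError, since "num" is missing
# on one side and "other" on the other.
```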
1248
1784
 
1249
1785
  def subtract( # type: ignore[override]
1250
1786
  self,
1251
1787
  other: "DataChain",
1252
- on: Optional[Union[str, Sequence[str]]] = None,
1253
- right_on: Optional[Union[str, Sequence[str]]] = None,
1788
+ on: str | Sequence[str] | None = None,
1789
+ right_on: str | Sequence[str] | None = None,
1254
1790
  ) -> "Self":
1255
1791
  """Remove rows that appear in another chain.
1256
1792
 
@@ -1307,58 +1843,51 @@ class DataChain:
1307
1843
  zip(
1308
1844
  self.signals_schema.resolve(*on).db_signals(),
1309
1845
  other.signals_schema.resolve(*right_on).db_signals(),
1846
+ strict=False,
1310
1847
  ) # type: ignore[arg-type]
1311
1848
  )
1312
1849
  return self._evolve(query=self._query.subtract(other._query, signals)) # type: ignore[arg-type]
1313
1850
 
1314
- def compare(
1851
+ def diff(
1315
1852
  self,
1316
1853
  other: "DataChain",
1317
- on: Union[str, Sequence[str]],
1318
- right_on: Optional[Union[str, Sequence[str]]] = None,
1319
- compare: Optional[Union[str, Sequence[str]]] = None,
1320
- right_compare: Optional[Union[str, Sequence[str]]] = None,
1854
+ on: str | Sequence[str],
1855
+ right_on: str | Sequence[str] | None = None,
1856
+ compare: str | Sequence[str] | None = None,
1857
+ right_compare: str | Sequence[str] | None = None,
1321
1858
  added: bool = True,
1322
1859
  deleted: bool = True,
1323
1860
  modified: bool = True,
1324
1861
  same: bool = False,
1325
- status_col: Optional[str] = None,
1862
+ status_col: str | None = None,
1326
1863
  ) -> "DataChain":
1327
- """Comparing two chains by identifying rows that are added, deleted, modified
1328
- or same. Result is the new chain that has additional column with possible
1329
- values: `A`, `D`, `M`, `U` representing added, deleted, modified and same
1330
- rows respectively. Note that if only one "status" is asked, by setting proper
1331
- flags, this additional column is not created as it would have only one value
1332
- for all rows. Beside additional diff column, new chain has schema of the chain
1333
- on which method was called.
1864
+ """Calculate differences between two chains.
1865
+
1866
+ This method identifies records that are added, deleted, modified, or unchanged
1867
+ between two chains. When `status_col` is set, it adds a status column with
1868
+ values: A=added, D=deleted, M=modified, S=same.
1334
1869
 
1335
1870
  Parameters:
1336
- other: Chain to calculate diff from.
1337
- on: Column or list of columns to match on. If both chains have the
1338
- same columns then this column is enough for the match. Otherwise,
1339
- `right_on` parameter has to specify the columns for the other chain.
1340
- This value is used to find corresponding row in other dataset. If not
1341
- found there, row is considered as added (or removed if vice versa), and
1342
- if found then row can be either modified or same.
1343
- right_on: Optional column or list of columns
1344
- for the `other` to match.
1345
- compare: Column or list of columns to compare on. If both chains have
1346
- the same columns then this column is enough for the compare. Otherwise,
1347
- `right_compare` parameter has to specify the columns for the other
1348
- chain. This value is used to see if row is modified or same. If
1349
- not set, all columns will be used for comparison
1350
- right_compare: Optional column or list of columns
1351
- for the `other` to compare to.
1352
- added (bool): Whether to return added rows in resulting chain.
1353
- deleted (bool): Whether to return deleted rows in resulting chain.
1354
- modified (bool): Whether to return modified rows in resulting chain.
1355
- same (bool): Whether to return unchanged rows in resulting chain.
1356
- status_col (str): Name of the new column that is created in resulting chain
1357
- representing diff status.
1871
+ other: Chain to compare against.
1872
+ on: Column(s) to match records between chains.
1873
+ right_on: Column(s) in the other chain to match against. Defaults to `on`.
1874
+ compare: Column(s) to check for changes.
1875
+ If not specified, all columns are used.
1876
+ right_compare: Column(s) in the other chain to compare against.
1877
+ Defaults to values of `compare`.
1878
+ added (bool): Include records that exist in this chain but not in the other.
1879
+ deleted (bool): Include records that exist only in the other chain.
1880
+ modified (bool): Include records that exist in both
1881
+ but have different values.
1882
+ same (bool): Include records that are identical in both chains.
1883
+ status_col (str): Name for the status column showing differences.
1884
+
1885
+ Default behavior: shows added, deleted, and modified records, but excludes
1886
+ unchanged records (same=False); a status column is only added when `status_col` is set.
1358
1887
 
1359
1888
  Example:
1360
1889
  ```py
1361
- res = persons.compare(
1890
+ res = persons.diff(
1362
1891
  new_persons,
1363
1892
  on=["id"],
1364
1893
  right_on=["other_id"],
@@ -1387,42 +1916,40 @@ class DataChain:
1387
1916
  status_col=status_col,
1388
1917
  )
1389
1918
 
1390
- def diff(
1919
+ def file_diff(
1391
1920
  self,
1392
1921
  other: "DataChain",
1393
1922
  on: str = "file",
1394
- right_on: Optional[str] = None,
1923
+ right_on: str | None = None,
1395
1924
  added: bool = True,
1396
1925
  modified: bool = True,
1397
1926
  deleted: bool = False,
1398
1927
  same: bool = False,
1399
- status_col: Optional[str] = None,
1928
+ status_col: str | None = None,
1400
1929
  ) -> "DataChain":
1401
- """Similar to `.compare()`, which is more generic method to calculate difference
1402
- between two chains. Unlike `.compare()`, this method works only on those chains
1403
- that have `File` object, or it's derivatives, in it. File `source` and `path`
1404
- are used for matching, and file `version` and `etag` for comparing, while in
1405
- `.compare()` user needs to provide arbitrary columns for matching and comparing.
1930
+ """Calculate differences between two chains containing files.
1931
+
1932
+ This method is specifically designed for file chains. It uses file `source`
1933
+ and `path` to match files, and file `version` and `etag` to detect changes.
1406
1934
 
1407
1935
  Parameters:
1408
- other: Chain to calculate diff from.
1409
- on: File signal to match on. If both chains have the
1410
- same file signal then this column is enough for the match. Otherwise,
1411
- `right_on` parameter has to specify the file signal for the other chain.
1412
- This value is used to find corresponding row in other dataset. If not
1413
- found there, row is considered as added (or removed if vice versa), and
1414
- if found then row can be either modified or same.
1415
- right_on: Optional file signal for the `other` to match.
1416
- added (bool): Whether to return added rows in resulting chain.
1417
- deleted (bool): Whether to return deleted rows in resulting chain.
1418
- modified (bool): Whether to return modified rows in resulting chain.
1419
- same (bool): Whether to return unchanged rows in resulting chain.
1420
- status_col (str): Optional name of the new column that is created in
1421
- resulting chain representing diff status.
1936
+ other: Chain to compare against.
1937
+ on: File column name in this chain. Default is "file".
1938
+ right_on: File column name in the other chain. Defaults to `on`.
1939
+ added (bool): Include files that exist in this chain but not in the other.
1940
+ deleted (bool): Include files that exist only in the other chain.
1941
+ modified (bool): Include files that exist in both but have different
1942
+ versions/etags.
1943
+ same (bool): Include files that are identical in both chains.
1944
+ status_col (str): Name for the status column showing differences
1945
+ (A=added, D=deleted, M=modified, S=same).
1946
+
1947
+ Default behavior: includes only added and modified files (added=True,
1948
+ modified=True), which is useful for incremental processing.
1422
1949
 
1423
1950
  Example:
1424
1951
  ```py
1425
- diff = images.diff(
1952
+ diff = images.file_diff(
1426
1953
  new_images,
1427
1954
  on="file",
1428
1955
  right_on="other_file",
@@ -1447,7 +1974,7 @@ class DataChain:
1447
1974
  compare_cols = get_file_signals(on, compare_file_signals)
1448
1975
  right_compare_cols = get_file_signals(right_on, compare_file_signals)
1449
1976
 
1450
- return self.compare(
1977
+ return self.diff(
1451
1978
  other,
1452
1979
  on_cols,
1453
1980
  right_on=right_on_cols,
@@ -1492,47 +2019,67 @@ class DataChain:
1492
2019
  )
1493
2020
  return read_pandas(*args, **kwargs)
1494
2021
 
1495
- def to_pandas(self, flatten=False, include_hidden=True) -> "pd.DataFrame":
2022
+ def to_pandas(
2023
+ self,
2024
+ flatten: bool = False,
2025
+ include_hidden: bool = True,
2026
+ as_object: bool = False,
2027
+ ) -> "pd.DataFrame":
1496
2028
  """Return a pandas DataFrame from the chain.
1497
2029
 
1498
2030
  Parameters:
1499
- flatten : Whether to use a multiindex or flatten column names.
1500
- include_hidden : Whether to include hidden columns.
2031
+ flatten: Whether to use a multiindex or flatten column names.
2032
+ include_hidden: Whether to include hidden columns.
2033
+ as_object: Whether to emit a dataframe backed by Python objects
2034
+ rather than pandas-inferred dtypes.
2035
+
2036
+ Returns:
2037
+ pd.DataFrame: A pandas DataFrame representation of the chain.
1501
2038
  """
1502
2039
  import pandas as pd
1503
2040
 
1504
2041
  headers, max_length = self._effective_signals_schema.get_headers_with_length(
1505
2042
  include_hidden=include_hidden
1506
2043
  )
2044
+
2045
+ columns: list[str] | pd.MultiIndex
1507
2046
  if flatten or max_length < 2:
1508
2047
  columns = [".".join(filter(None, header)) for header in headers]
1509
2048
  else:
1510
2049
  columns = pd.MultiIndex.from_tuples(map(tuple, headers))
1511
2050
 
1512
2051
  results = self.results(include_hidden=include_hidden)
2052
+ if as_object:
2053
+ df = pd.DataFrame(results, columns=columns, dtype=object)
2054
+ df.where(pd.notna(df), None, inplace=True)
2055
+ return df
1513
2056
  return pd.DataFrame.from_records(results, columns=columns)
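A usage sketch for the new `as_object` flag (the bucket URL is illustrative). With `as_object=True` the frame keeps plain Python objects and `None` for missing values instead of pandas-inferred dtypes and `NaN`, which is what `show()` relies on below:

```py
import datachain as dc

chain = dc.read_storage("s3://mybucket")

df = chain.to_pandas(flatten=True)        # pandas-inferred dtypes
df_obj = chain.to_pandas(as_object=True)  # object-backed, None for missing
print(df.dtypes)
print(df_obj.dtypes)
```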
1514
2057
 
1515
2058
  def show(
1516
2059
  self,
1517
2060
  limit: int = 20,
1518
- flatten=False,
1519
- transpose=False,
1520
- truncate=True,
1521
- include_hidden=False,
2061
+ flatten: bool = False,
2062
+ transpose: bool = False,
2063
+ truncate: bool = True,
2064
+ include_hidden: bool = False,
1522
2065
  ) -> None:
1523
2066
  """Show a preview of the chain results.
1524
2067
 
1525
2068
  Parameters:
1526
- limit : How many rows to show.
1527
- flatten : Whether to use a multiindex or flatten column names.
1528
- transpose : Whether to transpose rows and columns.
1529
- truncate : Whether or not to truncate the contents of columns.
1530
- include_hidden : Whether to include hidden columns.
2069
+ limit: How many rows to show.
2070
+ flatten: Whether to use a multiindex or flatten column names.
2071
+ transpose: Whether to transpose rows and columns.
2072
+ truncate: Whether or not to truncate the contents of columns.
2073
+ include_hidden: Whether to include hidden columns.
1531
2074
  """
1532
2075
  import pandas as pd
1533
2076
 
1534
2077
  dc = self.limit(limit) if limit > 0 else self # type: ignore[misc]
1535
- df = dc.to_pandas(flatten, include_hidden=include_hidden)
2078
+ df = dc.to_pandas(
2079
+ flatten,
2080
+ include_hidden=include_hidden,
2081
+ as_object=True,
2082
+ )
1536
2083
 
1537
2084
  if df.empty:
1538
2085
  print("Empty result")
@@ -1588,23 +2135,23 @@ class DataChain:
1588
2135
  def parse_tabular(
1589
2136
  self,
1590
2137
  output: OutputType = None,
1591
- object_name: str = "",
2138
+ column: str = "",
1592
2139
  model_name: str = "",
1593
2140
  source: bool = True,
1594
- nrows: Optional[int] = None,
1595
- **kwargs,
2141
+ nrows: int | None = None,
2142
+ **kwargs: Any,
1596
2143
  ) -> "Self":
1597
2144
  """Generate chain from list of tabular files.
1598
2145
 
1599
2146
  Parameters:
1600
- output : Dictionary or feature class defining column names and their
2147
+ output: Dictionary or feature class defining column names and their
1601
2148
  corresponding types. List of column names is also accepted, in which
1602
2149
  case types will be inferred.
1603
- object_name : Generated object column name.
1604
- model_name : Generated model name.
1605
- source : Whether to include info about the source file.
1606
- nrows : Optional row limit.
1607
- kwargs : Parameters to pass to pyarrow.dataset.dataset.
2150
+ column: Generated column name.
2151
+ model_name: Generated model name.
2152
+ source: Whether to include info about the source file.
2153
+ nrows: Optional row limit.
2154
+ kwargs: Parameters to pass to pyarrow.dataset.dataset.
1608
2155
 
1609
2156
  Example:
1610
2157
  Reading a json lines file:
@@ -1619,24 +2166,33 @@ class DataChain:
1619
2166
  import datachain as dc
1620
2167
 
1621
2168
  chain = dc.read_storage("s3://mybucket")
1622
- chain = chain.filter(dc.C("file.name").glob("*.jsonl"))
2169
+ chain = chain.filter(dc.C("file.path").glob("*.jsonl"))
1623
2170
  chain = chain.parse_tabular(format="json")
1624
2171
  ```
1625
2172
  """
1626
2173
  from pyarrow.dataset import CsvFileFormat, JsonFileFormat
1627
2174
 
1628
- from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
2175
+ from datachain.lib.arrow import (
2176
+ ArrowGenerator,
2177
+ fix_pyarrow_format,
2178
+ infer_schema,
2179
+ schema_to_output,
2180
+ )
1629
2181
 
1630
- if nrows:
1631
- format = kwargs.get("format")
1632
- if format not in ["csv", "json"] and not isinstance(
1633
- format, (CsvFileFormat, JsonFileFormat)
1634
- ):
1635
- raise DatasetPrepareError(
1636
- self.name,
1637
- "error in `parse_tabular` - "
1638
- "`nrows` only supported for csv and json formats.",
1639
- )
2182
+ parse_options = kwargs.pop("parse_options", None)
2183
+ if format := kwargs.get("format"):
2184
+ kwargs["format"] = fix_pyarrow_format(format, parse_options)
2185
+
2186
+ if (
2187
+ nrows
2188
+ and format not in ["csv", "json"]
2189
+ and not isinstance(format, (CsvFileFormat, JsonFileFormat))
2190
+ ):
2191
+ raise DatasetPrepareError(
2192
+ self.name,
2193
+ "error in `parse_tabular` - "
2194
+ "`nrows` only supported for csv and json formats.",
2195
+ )
1640
2196
 
1641
2197
  if "file" not in self.schema or not self.count():
1642
2198
  raise DatasetPrepareError(self.name, "no files to parse.")
@@ -1645,20 +2201,20 @@ class DataChain:
1645
2201
  col_names = output if isinstance(output, Sequence) else None
1646
2202
  if col_names or not output:
1647
2203
  try:
1648
- schema = infer_schema(self, **kwargs)
2204
+ schema = infer_schema(self, **kwargs, parse_options=parse_options)
1649
2205
  output, _ = schema_to_output(schema, col_names)
1650
2206
  except ValueError as e:
1651
2207
  raise DatasetPrepareError(self.name, e) from e
1652
2208
 
1653
2209
  if isinstance(output, dict):
1654
- model_name = model_name or object_name or ""
2210
+ model_name = model_name or column or ""
1655
2211
  model = dict_to_data_model(model_name, output)
1656
2212
  output = model
1657
2213
  else:
1658
2214
  model = output # type: ignore[assignment]
1659
2215
 
1660
- if object_name:
1661
- output = {object_name: model} # type: ignore[dict-item]
2216
+ if column:
2217
+ output = {column: model} # type: ignore[dict-item]
1662
2218
  elif isinstance(output, type(BaseModel)):
1663
2219
  output = {
1664
2220
  name: info.annotation # type: ignore[misc]
@@ -1671,7 +2227,15 @@ class DataChain:
1671
2227
  # disable prefetch if nrows is set
1672
2228
  settings = {"prefetch": 0} if nrows else {}
1673
2229
  return self.settings(**settings).gen( # type: ignore[arg-type]
1674
- ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
2230
+ ArrowGenerator(
2231
+ schema,
2232
+ model,
2233
+ source,
2234
+ nrows,
2235
+ parse_options=parse_options,
2236
+ **kwargs,
2237
+ ),
2238
+ output=output,
1675
2239
  )
1676
2240
 
1677
2241
  @classmethod
@@ -1708,23 +2272,23 @@ class DataChain:
1708
2272
 
1709
2273
  def to_parquet(
1710
2274
  self,
1711
- path: Union[str, os.PathLike[str], BinaryIO],
1712
- partition_cols: Optional[Sequence[str]] = None,
2275
+ path: str | os.PathLike[str] | BinaryIO,
2276
+ partition_cols: Sequence[str] | None = None,
1713
2277
  chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
1714
- fs_kwargs: Optional[dict[str, Any]] = None,
2278
+ fs_kwargs: dict[str, Any] | None = None,
1715
2279
  **kwargs,
1716
2280
  ) -> None:
1717
2281
  """Save chain to parquet file with SignalSchema metadata.
1718
2282
 
1719
2283
  Parameters:
1720
- path : Path or a file-like binary object to save the file. This supports
2284
+ path: Path or a file-like binary object to save the file. This supports
1721
2285
  local paths as well as remote paths, such as s3:// or hf:// with fsspec.
1722
- partition_cols : Column names by which to partition the dataset.
1723
- chunk_size : The chunk size of results to read and convert to columnar
2286
+ partition_cols: Column names by which to partition the dataset.
2287
+ chunk_size: The chunk size of results to read and convert to columnar
1724
2288
  data, to avoid running out of memory.
1725
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
1726
- write, for fsspec-type URLs, such as s3:// or hf:// when
1727
- provided as the destination path.
2289
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2290
+ when writing to fsspec-style URLs (e.g., s3://, gs://, hf://);
2291
+ fsspec-specific options are supported.
1728
2292
  """
1729
2293
  import pyarrow as pa
1730
2294
  import pyarrow.parquet as pq
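For completeness, a short `to_parquet` sketch covering local and fsspec destinations as documented above (paths and options are illustrative):

```py
import datachain as dc

chain = dc.read_storage("s3://mybucket/images/")

# Local parquet file; the signal schema is embedded as table metadata.
chain.to_parquet("images.parquet")

# Remote destination via fsspec, passing filesystem options through fs_kwargs.
chain.to_parquet("s3://other-bucket/images.parquet", fs_kwargs={"anon": False})
```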
@@ -1754,9 +2318,9 @@ class DataChain:
1754
2318
  fsspec_fs = client.create_fs(**fs_kwargs)
1755
2319
 
1756
2320
  _partition_cols = list(partition_cols) if partition_cols else None
1757
- signal_schema_metadata = orjson.dumps(
1758
- self._effective_signals_schema.serialize()
1759
- )
2321
+ signal_schema_metadata = json.dumps(
2322
+ self._effective_signals_schema.serialize(), ensure_ascii=False
2323
+ ).encode("utf-8")
1760
2324
 
1761
2325
  column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)
1762
2326
 
@@ -1768,7 +2332,7 @@ class DataChain:
1768
2332
  # pyarrow infers the best parquet schema from the python types of
1769
2333
  # the input data.
1770
2334
  table = pa.Table.from_pydict(
1771
- dict(zip(column_names, chunk)),
2335
+ dict(zip(column_names, chunk, strict=False)),
1772
2336
  schema=parquet_schema,
1773
2337
  )
1774
2338
 
@@ -1806,123 +2370,220 @@ class DataChain:
1806
2370
 
1807
2371
  def to_csv(
1808
2372
  self,
1809
- path: Union[str, os.PathLike[str]],
2373
+ path: str | os.PathLike[str],
1810
2374
  delimiter: str = ",",
1811
- fs_kwargs: Optional[dict[str, Any]] = None,
2375
+ fs_kwargs: dict[str, Any] | None = None,
1812
2376
  **kwargs,
1813
- ) -> None:
1814
- """Save chain to a csv (comma-separated values) file.
2377
+ ) -> File:
2378
+ """Save chain to a csv (comma-separated values) file and return the stored
2379
+ `File`.
1815
2380
 
1816
2381
  Parameters:
1817
- path : Path to save the file. This supports local paths as well as
2382
+ path: Path to save the file. This supports local paths as well as
1818
2383
  remote paths, such as s3:// or hf:// with fsspec.
1819
- delimiter : Delimiter to use for the resulting file.
1820
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
1821
- write, for fsspec-type URLs, such as s3:// or hf:// when
1822
- provided as the destination path.
2384
+ delimiter: Delimiter to use for the resulting file.
2385
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2386
+ when writing to fsspec-style URLs (e.g., s3://, gs://, hf://);
2387
+ fsspec-specific options are supported.
2388
+ Returns:
2389
+ File: The stored file with refreshed metadata (version, etag, size).
1823
2390
  """
1824
2391
  import csv
1825
2392
 
1826
- opener = open
1827
-
1828
- if isinstance(path, str) and "://" in path:
1829
- from datachain.client.fsspec import Client
1830
-
1831
- fs_kwargs = {
1832
- **self._query.catalog.client_config,
1833
- **(fs_kwargs or {}),
1834
- }
1835
-
1836
- client = Client.get_implementation(path)
1837
-
1838
- fsspec_fs = client.create_fs(**fs_kwargs)
1839
-
1840
- opener = fsspec_fs.open
2393
+ target = File.at(path, session=self.session)
1841
2394
 
1842
2395
  headers, _ = self._effective_signals_schema.get_headers_with_length()
1843
2396
  column_names = [".".join(filter(None, header)) for header in headers]
1844
2397
 
1845
- results_iter = self.collect_flatten()
1846
-
1847
- with opener(path, "w", newline="") as f:
2398
+ with target.open("w", newline="", client_config=fs_kwargs) as f:
1848
2399
  writer = csv.writer(f, delimiter=delimiter, **kwargs)
1849
2400
  writer.writerow(column_names)
1850
-
1851
- for row in results_iter:
2401
+ for row in self._leaf_values():
1852
2402
  writer.writerow(row)
1853
2403
 
2404
+ return target
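Since `to_csv` now returns the stored `File`, the export can be inspected or reused directly; a sketch (paths are illustrative):

```py
import datachain as dc

chain = dc.read_storage("s3://mybucket")

exported = chain.to_csv("s3://mybucket/exports/files.csv", delimiter=";")
print(exported.path, exported.size, exported.etag)  # refreshed metadata
```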
2405
+
1854
2406
  def to_json(
1855
2407
  self,
1856
- path: Union[str, os.PathLike[str]],
1857
- fs_kwargs: Optional[dict[str, Any]] = None,
2408
+ path: str | os.PathLike[str],
2409
+ fs_kwargs: dict[str, Any] | None = None,
1858
2410
  include_outer_list: bool = True,
1859
- ) -> None:
1860
- """Save chain to a JSON file.
2411
+ ) -> File:
2412
+ """Save chain to a JSON file and return the stored `File`.
1861
2413
 
1862
2414
  Parameters:
1863
- path : Path to save the file. This supports local paths as well as
2415
+ path: Path to save the file. This supports local paths as well as
1864
2416
  remote paths, such as s3:// or hf:// with fsspec.
1865
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
1866
- write, for fsspec-type URLs, such as s3:// or hf:// when
1867
- provided as the destination path.
1868
- include_outer_list : Sets whether to include an outer list for all rows.
2417
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2418
+ when writing to fsspec-style URLs (e.g., s3://, gs://, hf://);
2419
+ fsspec-specific options are supported.
2420
+ include_outer_list: Sets whether to include an outer list for all rows.
1869
2421
  Setting this to True makes the file valid JSON, while False instead
1870
2422
  writes in the JSON lines format.
2423
+ Returns:
2424
+ File: The stored file with refreshed metadata (version, etag, size).
1871
2425
  """
1872
- opener = open
1873
-
1874
- if isinstance(path, str) and "://" in path:
1875
- from datachain.client.fsspec import Client
1876
-
1877
- fs_kwargs = {
1878
- **self._query.catalog.client_config,
1879
- **(fs_kwargs or {}),
1880
- }
1881
-
1882
- client = Client.get_implementation(path)
1883
-
1884
- fsspec_fs = client.create_fs(**fs_kwargs)
1885
-
1886
- opener = fsspec_fs.open
1887
-
2426
+ target = File.at(path, session=self.session)
1888
2427
  headers, _ = self._effective_signals_schema.get_headers_with_length()
1889
- headers = [list(filter(None, header)) for header in headers]
2428
+ headers = [list(filter(None, h)) for h in headers]
2429
+ with target.open("wb", client_config=fs_kwargs) as f:
2430
+ self._write_json_stream(f, headers, include_outer_list)
2431
+ return target
1890
2432
 
2433
+ def _write_json_stream(
2434
+ self,
2435
+ f: IO[bytes],
2436
+ headers: list[list[str]],
2437
+ include_outer_list: bool,
2438
+ ) -> None:
1891
2439
  is_first = True
1892
-
1893
- with opener(path, "wb") as f:
1894
- if include_outer_list:
1895
- # This makes the file JSON instead of JSON lines.
1896
- f.write(b"[\n")
1897
- for row in self.collect_flatten():
1898
- if not is_first:
1899
- if include_outer_list:
1900
- # This makes the file JSON instead of JSON lines.
1901
- f.write(b",\n")
1902
- else:
1903
- f.write(b"\n")
1904
- else:
1905
- is_first = False
1906
- f.write(orjson.dumps(row_to_nested_dict(headers, row)))
1907
- if include_outer_list:
1908
- # This makes the file JSON instead of JSON lines.
1909
- f.write(b"\n]\n")
2440
+ if include_outer_list:
2441
+ f.write(b"[\n")
2442
+ for row in self._leaf_values():
2443
+ if not is_first:
2444
+ f.write(b",\n" if include_outer_list else b"\n")
2445
+ else:
2446
+ is_first = False
2447
+ f.write(
2448
+ json.dumps(
2449
+ row_to_nested_dict(headers, row),
2450
+ ensure_ascii=False,
2451
+ ).encode("utf-8")
2452
+ )
2453
+ if include_outer_list:
2454
+ f.write(b"\n]\n")
1910
2455
 
1911
2456
  def to_jsonl(
1912
2457
  self,
1913
- path: Union[str, os.PathLike[str]],
1914
- fs_kwargs: Optional[dict[str, Any]] = None,
1915
- ) -> None:
2458
+ path: str | os.PathLike[str],
2459
+ fs_kwargs: dict[str, Any] | None = None,
2460
+ ) -> File:
1916
2461
  """Save chain to a JSON lines file.
1917
2462
 
1918
2463
  Parameters:
1919
- path : Path to save the file. This supports local paths as well as
2464
+ path: Path to save the file. This supports local paths as well as
1920
2465
  remote paths, such as s3:// or hf:// with fsspec.
1921
- fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
1922
- write, for fsspec-type URLs, such as s3:// or hf:// when
1923
- provided as the destination path.
2466
+ fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
2467
+ when writing to fsspec-style URLs (e.g., s3://, gs://, hf://);
2468
+ fsspec-specific options are supported.
2469
+ Returns:
2470
+ File: The stored file with refreshed metadata (version, etag, size).
1924
2471
  """
1925
- self.to_json(path, fs_kwargs, include_outer_list=False)
2472
+ return self.to_json(path, fs_kwargs, include_outer_list=False)
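Both JSON writers go through the same streaming helper and now return the stored `File` as well; a sketch of the two output modes (paths are illustrative):

```py
import datachain as dc

chain = dc.read_storage("s3://mybucket")

json_file = chain.to_json("rows.json")     # a single JSON array ([...])
jsonl_file = chain.to_jsonl("rows.jsonl")  # one JSON object per line
print(json_file.size, jsonl_file.size)
```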
2473
+
2474
+ def to_database(
2475
+ self,
2476
+ table_name: str,
2477
+ connection: "ConnectionType",
2478
+ *,
2479
+ batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
2480
+ on_conflict: str | None = None,
2481
+ conflict_columns: list[str] | None = None,
2482
+ column_mapping: dict[str, str | None] | None = None,
2483
+ ) -> int:
2484
+ """Save chain to a database table using a given database connection.
2485
+
2486
+ This method exports all DataChain records to a database table, creating the
2487
+ table if it doesn't exist and appending data if it does. The table schema
2488
+ is automatically inferred from the DataChain's signal schema.
2489
+
2490
+ For PostgreSQL, tables are created in the schema specified by the connection's
2491
+ search_path (defaults to 'public'). Use URL parameters to target specific
2492
+ schemas.
2493
+
2494
+ Parameters:
2495
+ table_name: Name of the database table to create/write to.
2496
+ connection: SQLAlchemy connectable, str, or a sqlite3 connection
2497
+ Using SQLAlchemy makes it possible to use any DB supported by that
2498
+ library. If a DBAPI2 object, only sqlite3 is supported. The user is
2499
+ responsible for engine disposal and connection closure for the
2500
+ SQLAlchemy connectable; str connections are closed automatically.
2501
+ batch_size: Number of rows to insert per batch for optimal performance.
2502
+ Larger batches are faster but use more memory. Default: 10,000.
2503
+ on_conflict: Strategy for handling duplicate rows (requires table
2504
+ constraints):
2505
+ - None: Raise error (`sqlalchemy.exc.IntegrityError`) on conflict
2506
+ (default)
2507
+ - "ignore": Skip duplicate rows silently
2508
+ - "update": Update existing rows with new values
2509
+ conflict_columns: List of column names that form a unique constraint
2510
+ for conflict resolution. Required when on_conflict='update' and
2511
+ using PostgreSQL.
2512
+ column_mapping: Optional mapping to rename or skip columns:
2513
+ - Dict mapping DataChain column names to database column names
2514
+ - Set values to None to skip columns entirely, or use `defaultdict` to
2515
+ skip all columns except those specified.
2516
+
2517
+ Returns:
2518
+ int: Number of rows affected (inserted/updated). -1 if DB driver doesn't
2519
+ support telemetry.
2520
+
2521
+ Examples:
2522
+ Basic usage with PostgreSQL:
2523
+ ```py
2524
+ import datachain as dc
2525
+
2526
+ rows_affected = (dc
2527
+ .read_storage("s3://my-bucket/")
2528
+ .to_database("files_table", "postgresql://user:pass@localhost/mydb")
2529
+ )
2530
+ print(f"Inserted/updated {rows_affected} rows")
2531
+ ```
2532
+
2533
+ Using SQLite with connection string:
2534
+ ```py
2535
+ rows_affected = chain.to_database("my_table", "sqlite:///data.db")
2536
+ print(f"Affected {rows_affected} rows")
2537
+ ```
2538
+
2539
+ Column mapping and renaming:
2540
+ ```py
2541
+ mapping = {
2542
+ "user.id": "id",
2543
+ "user.name": "name",
2544
+ "user.password": None # Skip this column
2545
+ }
2546
+ chain.to_database("users", engine, column_mapping=mapping)
2547
+ ```
2548
+
2549
+ Handling conflicts (requires PRIMARY KEY or UNIQUE constraints):
2550
+ ```py
2551
+ # Skip duplicates
2552
+ chain.to_database("my_table", engine, on_conflict="ignore")
2553
+
2554
+ # Update existing records
2555
+ chain.to_database(
2556
+ "my_table", engine, on_conflict="update", conflict_columns=["id"]
2557
+ )
2558
+ ```
2559
+
2560
+ Working with different databases:
2561
+ ```py
2562
+ # MySQL
2563
+ mysql_engine = sa.create_engine("mysql+pymysql://user:pass@host/db")
2564
+ chain.to_database("mysql_table", mysql_engine)
2565
+
2566
+ # SQLite in-memory
2567
+ chain.to_database("temp_table", "sqlite:///:memory:")
2568
+ ```
2569
+
2570
+ PostgreSQL with schema support:
2571
+ ```py
2572
+ pg_url = "postgresql://user:pass@host/db?options=-c search_path=analytics"
2573
+ chain.to_database("processed_data", pg_url)
2574
+ ```
2575
+ """
2576
+ from .database import to_database
2577
+
2578
+ return to_database(
2579
+ self,
2580
+ table_name,
2581
+ connection,
2582
+ batch_size=batch_size,
2583
+ on_conflict=on_conflict,
2584
+ conflict_columns=conflict_columns,
2585
+ column_mapping=column_mapping,
2586
+ )
1926
2587
 
1927
2588
  @classmethod
1928
2589
  def from_records(
@@ -1940,28 +2601,85 @@ class DataChain:
1940
2601
  )
1941
2602
  return read_records(*args, **kwargs)
1942
2603
 
1943
- def sum(self, fr: DataType): # type: ignore[override]
1944
- """Compute the sum of a column."""
1945
- return self._extend_to_data_model("sum", fr)
2604
+ def sum(self, col: str) -> StandardType: # type: ignore[override]
2605
+ """Compute the sum of a column.
2606
+
2607
+ Parameters:
2608
+ col: The column to compute the sum for.
2609
+
2610
+ Returns:
2611
+ The sum of the column values.
2612
+
2613
+ Example:
2614
+ ```py
2615
+ total_size = chain.sum("file.size")
2616
+ print(f"Total size: {total_size}")
2617
+ ```
2618
+ """
2619
+ return self._extend_to_data_model("sum", col)
2620
+
2621
+ def avg(self, col: str) -> StandardType: # type: ignore[override]
2622
+ """Compute the average of a column.
2623
+
2624
+ Parameters:
2625
+ col: The column to compute the average for.
2626
+
2627
+ Returns:
2628
+ The average of the column values.
2629
+
2630
+ Example:
2631
+ ```py
2632
+ average_size = chain.avg("file.size")
2633
+ print(f"Average size: {average_size}")
2634
+ ```
2635
+ """
2636
+ return self._extend_to_data_model("avg", col)
2637
+
2638
+ def min(self, col: str) -> StandardType: # type: ignore[override]
2639
+ """Compute the minimum of a column.
1946
2640
 
1947
- def avg(self, fr: DataType): # type: ignore[override]
1948
- """Compute the average of a column."""
1949
- return self._extend_to_data_model("avg", fr)
2641
+ Parameters:
2642
+ col: The column to compute the minimum for.
1950
2643
 
1951
- def min(self, fr: DataType): # type: ignore[override]
1952
- """Compute the minimum of a column."""
1953
- return self._extend_to_data_model("min", fr)
2644
+ Returns:
2645
+ The minimum value in the column.
1954
2646
 
1955
- def max(self, fr: DataType): # type: ignore[override]
1956
- """Compute the maximum of a column."""
1957
- return self._extend_to_data_model("max", fr)
2647
+ Example:
2648
+ ```py
2649
+ min_size = chain.min("file.size")
2650
+ print(f"Minimum size: {min_size}")
2651
+ ```
2652
+ """
2653
+ return self._extend_to_data_model("min", col)
2654
+
2655
+ def max(self, col: str) -> StandardType: # type: ignore[override]
2656
+ """Compute the maximum of a column.
2657
+
2658
+ Parameters:
2659
+ col: The column to compute the maximum for.
2660
+
2661
+ Returns:
2662
+ The maximum value in the column.
2663
+
2664
+ Example:
2665
+ ```py
2666
+ max_size = chain.max("file.size")
2667
+ print(f"Maximum size: {max_size}")
2668
+ ```
2669
+ """
2670
+ return self._extend_to_data_model("max", col)
1958
2671
 
1959
2672
  def setup(self, **kwargs) -> "Self":
1960
2673
  """Setup variables to pass to UDF functions.
1961
2674
 
1962
- Use before running map/gen/agg/batch_map to save an object and pass it as an
2675
+ Use before running map/gen/agg to save an object and pass it as an
1963
2676
  argument to the UDF.
1964
2677
 
2678
+ The value must be a callable (a `lambda: <value>` syntax can be used to quickly
2679
+ create one) that returns the object to be passed to the UDF. It is evaluated
2680
+ lazily when the UDF runs; if multiple machines are used, the callable runs on
2681
+ a worker machine.
2682
+
1965
2683
  Example:
1966
2684
  ```py
1967
2685
  import anthropic
@@ -1971,7 +2689,11 @@ class DataChain:
1971
2689
  (
1972
2690
  dc.read_storage(DATA, type="text")
1973
2691
  .settings(parallel=4, cache=True)
2692
+
2693
+ # Setup Anthropic client and pass it to the UDF below automatically
2694
+ # The value is callable (see the note above)
1974
2695
  .setup(client=lambda: anthropic.Anthropic(api_key=API_KEY))
2696
+
1975
2697
  .map(
1976
2698
  claude=lambda client, file: client.messages.create(
1977
2699
  model=MODEL,
@@ -1993,13 +2715,13 @@ class DataChain:
1993
2715
 
1994
2716
  def to_storage(
1995
2717
  self,
1996
- output: Union[str, os.PathLike[str]],
2718
+ output: str | os.PathLike[str],
1997
2719
  signal: str = "file",
1998
2720
  placement: FileExportPlacement = "fullpath",
1999
2721
  link_type: Literal["copy", "symlink"] = "copy",
2000
- num_threads: Optional[int] = EXPORT_FILES_MAX_THREADS,
2001
- anon: bool = False,
2002
- client_config: Optional[dict] = None,
2722
+ num_threads: int | None = EXPORT_FILES_MAX_THREADS,
2723
+ anon: bool | None = None,
2724
+ client_config: dict | None = None,
2003
2725
  ) -> None:
2004
2726
  """Export files from a specified signal to a directory. Files can be
2005
2727
  exported to a local or cloud directory.
@@ -2008,12 +2730,28 @@ class DataChain:
2008
2730
  output: Path to the target directory for exporting files.
2009
2731
  signal: Name of the signal to export files from.
2010
2732
  placement: The method to use for naming exported files.
2011
- The possible values are: "filename", "etag", "fullpath", and "checksum".
2733
+ The possible values are: "filename", "etag", "fullpath",
2734
+ "filepath", and "checksum".
2735
+ Example path translations for an object located at
2736
+ ``s3://bucket/data/img.jpg`` and exported to ``./out``:
2737
+
2738
+ - "filename" -> ``./out/img.jpg`` (no directories)
2739
+ - "filepath" -> ``./out/data/img.jpg`` (relative path kept)
2740
+ - "fullpath" -> ``./out/bucket/data/img.jpg`` (remote host kept)
2741
+ - "etag" -> ``./out/<etag>.jpg`` (unique name via object digest)
2742
+
2743
+ Local sources behave like "filepath" for "fullpath" placement.
2744
+ Relative destinations such as "." or ".." and absolute paths
2745
+ are supported for every strategy.
2012
2746
  link_type: Method to use for exporting files.
2013
2747
  Falls back to `'copy'` if symlinking fails.
2014
- num_threads : number of threads to use for exporting files.
2015
- By default it uses 5 threads.
2016
- anon: If true, we will treat cloud bucket as public one
2748
+ num_threads: number of threads to use for exporting files.
2749
+ By default, it uses 5 threads.
2750
+ anon: If True, we will treat cloud bucket as a public one. Default behavior
2751
+ depends on the previous session configuration (e.g. as set in the
2752
+ initial `read_storage`) and the particular cloud storage client
2753
+ implementation (e.g. S3 falls back to anonymous access if no credentials
2754
+ were found).
2017
2755
  client_config: Optional configuration for the destination storage client
2018
2756
 
2019
2757
  Example:
@@ -2025,21 +2763,23 @@ class DataChain:
2025
2763
  ds.to_storage("gs://mybucket", placement="filename")
2026
2764
  ```
2027
2765
  """
2766
+ chain = self.persist()
2767
+ count = chain.count()
2768
+
2028
2769
  if placement == "filename" and (
2029
- self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
2030
- != self._query.count()
2770
+ chain._query.distinct(pathfunc.name(C(f"{signal}__path"))).count() != count
2031
2771
  ):
2032
2772
  raise ValueError("Files with the same name found")
2033
2773
 
2034
- if anon:
2035
- client_config = (client_config or {}) | {"anon": True}
2774
+ if anon is not None:
2775
+ client_config = (client_config or {}) | {"anon": anon}
2036
2776
 
2037
2777
  progress_bar = tqdm(
2038
2778
  desc=f"Exporting files to {output}: ",
2039
2779
  unit=" files",
2040
2780
  unit_scale=True,
2041
2781
  unit_divisor=10,
2042
- total=self.count(),
2782
+ total=count,
2043
2783
  leave=False,
2044
2784
  )
2045
2785
  file_exporter = FileExporter(
@@ -2050,20 +2790,36 @@ class DataChain:
2050
2790
  max_threads=num_threads or 1,
2051
2791
  client_config=client_config,
2052
2792
  )
2053
- file_exporter.run(self.collect(signal), progress_bar)
2793
+ file_exporter.run(
2794
+ (rows[0] for rows in chain.to_iter(signal)),
2795
+ progress_bar,
2796
+ )
2054
2797
 
2055
2798
  def shuffle(self) -> "Self":
2056
- """Shuffle the rows of the chain deterministically."""
2057
- return self.order_by("sys.rand")
2799
+ """Shuffle rows with a best-effort deterministic ordering.
2800
+
2801
+ This produces repeatable shuffles. Merge and union operations can
2802
+ lead to non-deterministic results. Use `order_by` or save the dataset
2803
+ afterward to guarantee the same result.
2804
+ """
2805
+ query = self._query.clone(new_table=False)
2806
+ query.steps.append(RegenerateSystemColumns(self._query.catalog))
2058
2807
 
2059
- def sample(self, n) -> "Self":
2808
+ chain = self._evolve(
2809
+ query=query,
2810
+ signal_schema=SignalSchema({"sys": Sys}) | self.signals_schema,
2811
+ )
2812
+ return chain.order_by("sys.rand")
2813
+
2814
+ def sample(self, n: int) -> "Self":
2060
2815
  """Return a random sample from the chain.
2061
2816
 
2062
2817
  Parameters:
2063
- n (int): Number of samples to draw.
2818
+ n: Number of samples to draw.
2064
2819
 
2065
- NOTE: Samples are not deterministic, and streamed/paginated queries or
2066
- multiple workers will draw samples with replacement.
2820
+ Note:
2821
+ Samples are not deterministic, and streamed/paginated queries or
2822
+ multiple workers will draw samples with replacement.
2067
2823
  """
2068
2824
  return self._evolve(query=self._query.sample(n))
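A sketch contrasting the two randomisation helpers above: `shuffle` aims for a repeatable ordering over all rows, while `sample` draws a non-deterministic subset (the bucket URL is illustrative):

```py
import datachain as dc

chain = dc.read_storage("s3://mybucket")

shuffled = chain.shuffle()   # best-effort deterministic ordering
subset = chain.sample(100)   # 100 random rows, not repeatable
print(subset.count())
```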
2069
2825
 
@@ -2078,27 +2834,62 @@ class DataChain:
2078
2834
 
2079
2835
  Using glob to match patterns
2080
2836
  ```py
2081
- dc.filter(C("file.name").glob("*.jpg"))
2837
+ dc.filter(C("file.path").glob("*.jpg"))
2838
+ ```
2839
+
2840
+ Using in to match lists
2841
+ ```py
2842
+ ids = [1,2,3]
2843
+ dc.filter(C("experiment_id").in_(ids))
2082
2844
  ```
2083
2845
 
2084
2846
  Using `datachain.func`
2085
2847
  ```py
2086
2848
  from datachain.func import string
2087
- dc.filter(string.length(C("file.name")) > 5)
2849
+ dc.filter(string.length(C("file.path")) > 5)
2088
2850
  ```
2089
2851
 
2090
2852
  Combining filters with "or"
2091
2853
  ```py
2092
- dc.filter(C("file.name").glob("cat*") | C("file.name").glob("dog*))
2854
+ dc.filter(
2855
+ C("file.path").glob("cat*") |
2856
+ C("file.path").glob("dog*")
2857
+ )
2858
+ ```
2859
+
2860
+ ```py
2861
+ dc.filter(dc.func.or_(
2862
+ C("file.path").glob("cat*"),
2863
+ C("file.path").glob("dog*")
2864
+ ))
2093
2865
  ```
2094
2866
 
2095
2867
  Combining filters with "and"
2096
2868
  ```py
2097
2869
  dc.filter(
2098
- C("file.name").glob("*.jpg) &
2099
- (string.length(C("file.name")) > 5)
2870
+ C("file.path").glob("*.jpg"),
2871
+ string.length(C("file.path")) > 5
2872
+ )
2873
+ ```
2874
+
2875
+ ```py
2876
+ dc.filter(
2877
+ C("file.path").glob("*.jpg") &
2878
+ (string.length(C("file.path")) > 5)
2100
2879
  )
2101
2880
  ```
2881
+
2882
+ ```py
2883
+ dc.filter(dc.func.and_(
2884
+ C("file.path").glob("*.jpg"),
2885
+ string.length(C("file.path")) > 5
2886
+ ))
2887
+ ```
2888
+
2889
+ Combining filters with "not"
2890
+ ```py
2891
+ dc.filter(~(C("file.path").glob("*.jpg")))
2892
+ ```
2102
2893
  """
2103
2894
  return self._evolve(query=self._query.filter(*args))
2104
2895
 
@@ -2135,6 +2926,10 @@ class DataChain:
2135
2926
  def chunk(self, index: int, total: int) -> "Self":
2136
2927
  """Split a chain into smaller chunks for e.g. parallelization.
2137
2928
 
2929
+ Parameters:
2930
+ index: The index of the chunk (0-indexed).
2931
+ total: The total number of chunks.
2932
+
2138
2933
  Example:
2139
2934
  ```py
2140
2935
  import datachain as dc
@@ -2149,3 +2944,72 @@ class DataChain:
2149
2944
  Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
2150
2945
  """
2151
2946
  return self._evolve(query=self._query.chunk(index, total))
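Only the tail of the `chunk` example is visible in this hunk, so here is a hedged sketch of the 0-indexed usage it describes (`read_dataset` and the dataset name are assumptions for illustration):

```py
import datachain as dc

total = 3  # number of workers
for index in range(total):  # 0-based: 0, 1, 2 (not 1..3)
    part = dc.read_dataset("my_dataset").chunk(index, total)
    print(index, part.count())  # each worker processes its own slice
```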
2947
+
2948
+ def to_list(self, *cols: str) -> list[tuple[DataValue, ...]]:
2949
+ """Returns a list of rows of values, optionally limited to the specified
2950
+ columns.
2951
+
2952
+ Parameters:
2953
+ *cols: Limit to the specified columns. By default, all columns are selected.
2954
+
2955
+ Returns:
2956
+ list[tuple[DataType, ...]]: Returns a list of tuples of items for each row.
2957
+
2958
+ Example:
2959
+ Getting all rows as a list:
2960
+ ```py
2961
+ rows = dc.to_list()
2962
+ print(rows)
2963
+ ```
2964
+
2965
+ Getting all rows with selected columns as a list:
2966
+ ```py
2967
+ name_size_pairs = dc.to_list("file.path", "file.size")
2968
+ print(name_size_pairs)
2969
+ ```
2970
+
2971
+ Getting a single column as a list:
2972
+ ```py
2973
+ files = dc.to_list("file.path")
2974
+ print(files) # Returns list of 1-tuples
2975
+ ```
2976
+ """
2977
+ return list(self.to_iter(*cols))
2978
+
2979
+ def to_values(self, col: str) -> list[DataValue]:
2980
+ """Returns a flat list of values from a single column.
2981
+
2982
+ Parameters:
2983
+ col: The name of the column to extract values from.
2984
+
2985
+ Returns:
2986
+ list[DataValue]: Returns a flat list of values from the specified column.
2987
+
2988
+ Example:
2989
+ Getting all values from a single column:
2990
+ ```py
2991
+ file_paths = dc.to_values("file.path")
2992
+ print(file_paths) # Returns list of strings
2993
+ ```
2994
+
2995
+ Getting all file sizes:
2996
+ ```py
2997
+ sizes = dc.to_values("file.size")
2998
+ print(sizes) # Returns list of integers
2999
+ ```
3000
+ """
3001
+ return [row[0] for row in self.to_list(col)]
3002
+
3003
+ def __iter__(self) -> Iterator[tuple[DataValue, ...]]:
3004
+ """Make DataChain objects iterable.
3005
+
3006
+ Yields:
3007
+ (tuple[DataValue, ...]): Yields tuples of all column values for each row.
3008
+
3009
+ Example:
3010
+ ```py
3011
+ for row in chain:
3012
+ print(row)
3013
+ ```
3014
+ """
3015
+ return self.to_iter()