datachain 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.



datachain/lib/dc.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
     overload,
 )
 
+import orjson
 import pandas as pd
 import sqlalchemy
 from pydantic import BaseModel
@@ -54,12 +55,11 @@ from datachain.query import Session
 from datachain.query.dataset import (
     DatasetQuery,
     PartitionByType,
-    detach,
 )
 from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
-from datachain.utils import inside_notebook
+from datachain.utils import batched_it, inside_notebook
 
 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec, Self
@@ -72,6 +72,10 @@ C = Column
 
 _T = TypeVar("_T")
 D = TypeVar("D", bound="DataChain")
+UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
+
+
+DEFAULT_PARQUET_CHUNK_SIZE = 100_000
 
 
 def resolve_columns(
@@ -159,7 +163,7 @@ class Sys(DataModel):
     rand: int
 
 
-class DataChain(DatasetQuery):
+class DataChain:
     """DataChain - a data structure for batch data processing and evaluation.
 
     It represents a sequence of data manipulation steps such as reading data from
@@ -238,33 +242,20 @@ class DataChain(DatasetQuery):
         "size": 0,
     }
 
-    def __init__(self, *args, settings: Optional[dict] = None, **kwargs):
-        """This method needs to be redefined as a part of Dataset and DataChain
-        decoupling.
-        """
-        super().__init__(  # type: ignore[misc]
-            *args,
-            **kwargs,
-            indexing_column_types=File._datachain_column_types,
-        )
-
-        telemetry.send_event_once("class", "datachain_init", **kwargs)
-
-        if settings:
-            self._settings = Settings(**settings)
-        else:
-            self._settings = Settings()
-        self._setup: dict = {}
-
-        self.signals_schema = SignalSchema({"sys": Sys})
-        if self.feature_schema:
-            self.signals_schema |= SignalSchema.deserialize(self.feature_schema)
-        else:
-            self.signals_schema |= SignalSchema.from_column_types(
-                self.column_types or {}
-            )
-
-        self._sys = False
+    def __init__(
+        self,
+        query: DatasetQuery,
+        settings: Settings,
+        signal_schema: SignalSchema,
+        setup: Optional[dict] = None,
+        _sys: bool = False,
+    ) -> None:
+        """Don't instantiate this directly, use one of the from_XXX constructors."""
+        self._query = query
+        self._settings = settings
+        self.signals_schema = signal_schema
+        self._setup: dict = setup or {}
+        self._sys = _sys
 
     @property
     def schema(self) -> dict[str, DataType]:
@@ -290,18 +281,55 @@ class DataChain(DatasetQuery):
     def c(self, column: Union[str, Column]) -> Column:
         """Returns Column instance attached to the current chain."""
         c = self.column(column) if isinstance(column, str) else self.column(column.name)
-        c.table = self.table
+        c.table = self._query.table
         return c
 
+    @property
+    def session(self) -> Session:
+        """Session of the chain."""
+        return self._query.session
+
+    @property
+    def name(self) -> Optional[str]:
+        """Name of the underlying dataset, if there is one."""
+        return self._query.name
+
+    @property
+    def version(self) -> Optional[int]:
+        """Version of the underlying dataset, if there is one."""
+        return self._query.version
+
+    def __or__(self, other: "Self") -> "Self":
+        """Return `self.union(other)`."""
+        return self.union(other)
+
     def print_schema(self) -> None:
         """Print schema of the chain."""
         self._effective_signals_schema.print_tree()
 
-    def clone(self, new_table: bool = True) -> "Self":
+    def clone(self) -> "Self":
         """Make a copy of the chain in a new table."""
-        obj = super().clone(new_table=new_table)
-        obj.signals_schema = copy.deepcopy(self.signals_schema)
-        return obj
+        return self._evolve(query=self._query.clone(new_table=True))
+
+    def _evolve(
+        self,
+        *,
+        query: Optional[DatasetQuery] = None,
+        settings: Optional[Settings] = None,
+        signal_schema=None,
+        _sys=None,
+    ) -> "Self":
+        if query is None:
+            query = self._query.clone(new_table=False)
+        if settings is None:
+            settings = self._settings
+        if signal_schema is None:
+            signal_schema = copy.deepcopy(self.signals_schema)
+        if _sys is None:
+            _sys = self._sys
+        return type(self)(
+            query, settings, signal_schema=signal_schema, setup=self._setup, _sys=_sys
+        )
 
     def settings(
         self,
@@ -332,11 +360,11 @@ class DataChain(DatasetQuery):
            )
            ```
        """
-        chain = self.clone()
-        if sys is not None:
-            chain._sys = sys
-        chain._settings.add(Settings(cache, parallel, workers, min_task_size))
-        return chain
+        if sys is None:
+            sys = self._sys
+        settings = copy.copy(self._settings)
+        settings.add(Settings(cache, parallel, workers, min_task_size))
+        return self._evolve(settings=settings, _sys=sys)
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
         """Reset all settings to default values."""
@@ -434,7 +462,7 @@ class DataChain(DatasetQuery):
         version: Optional[int] = None,
         session: Optional[Session] = None,
         settings: Optional[dict] = None,
-    ) -> "DataChain":
+    ) -> "Self":
         """Get data from a saved Dataset. It returns the chain itself.
 
         Parameters:
@@ -446,7 +474,24 @@ class DataChain(DatasetQuery):
            chain = DataChain.from_dataset("my_cats")
            ```
        """
-        return DataChain(name=name, version=version, session=session, settings=settings)
+        query = DatasetQuery(
+            name=name,
+            version=version,
+            session=session,
+            indexing_column_types=File._datachain_column_types,
+        )
+        telemetry.send_event_once("class", "datachain_init", name=name, version=version)
+        if settings:
+            _settings = Settings(**settings)
+        else:
+            _settings = Settings()
+
+        signals_schema = SignalSchema({"sys": Sys})
+        if query.feature_schema:
+            signals_schema |= SignalSchema.deserialize(query.feature_schema)
+        else:
+            signals_schema |= SignalSchema.from_column_types(query.column_types or {})
+        return cls(query, _settings, signals_schema)
 
     @classmethod
     def from_json(
@@ -699,7 +744,11 @@ class DataChain(DatasetQuery):
            version : version of a dataset. Default - the last version that exist.
        """
        schema = self.signals_schema.clone_without_sys_signals().serialize()
-        return super().save(name=name, version=version, feature_schema=schema, **kwargs)
+        return self._evolve(
+            query=self._query.save(
+                name=name, version=version, feature_schema=schema, **kwargs
+            )
+        )
 
     def apply(self, func, *args, **kwargs):
         """Apply any function to the chain.
@@ -765,16 +814,17 @@ class DataChain(DatasetQuery):
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
 
-        chain = self.add_signals(
-            udf_obj.to_udf_wrapper(),
-            **self._settings.to_dict(),
+        return self._evolve(
+            query=self._query.add_signals(
+                udf_obj.to_udf_wrapper(),
+                **self._settings.to_dict(),
+            ),
+            signal_schema=self.signals_schema | udf_obj.output,
         )
 
-        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
     def gen(
         self,
-        func: Optional[Callable] = None,
+        func: Optional[Union[Callable, Generator]] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: OutputType = None,
         **signal_map,
@@ -800,14 +850,14 @@ class DataChain(DatasetQuery):
            ```
        """
        udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
-        chain = DatasetQuery.generate(
-            self,
-            udf_obj.to_udf_wrapper(),
-            **self._settings.to_dict(),
+        return self._evolve(
+            query=self._query.generate(
+                udf_obj.to_udf_wrapper(),
+                **self._settings.to_dict(),
+            ),
+            signal_schema=udf_obj.output,
        )
 
-        return chain.reset_schema(udf_obj.output).reset_settings(self._settings)
-
     def agg(
         self,
         func: Optional[Callable] = None,
@@ -840,15 +890,15 @@ class DataChain(DatasetQuery):
            ```
        """
        udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
-        chain = DatasetQuery.generate(
-            self,
-            udf_obj.to_udf_wrapper(),
-            partition_by=partition_by,
-            **self._settings.to_dict(),
+        return self._evolve(
+            query=self._query.generate(
+                udf_obj.to_udf_wrapper(),
+                partition_by=partition_by,
+                **self._settings.to_dict(),
+            ),
+            signal_schema=udf_obj.output,
        )
 
-        return chain.reset_schema(udf_obj.output).reset_settings(self._settings)
-
     def batch_map(
         self,
         func: Optional[Callable] = None,
@@ -876,22 +926,22 @@ class DataChain(DatasetQuery):
            ```
        """
        udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
-        chain = DatasetQuery.add_signals(
-            self,
-            udf_obj.to_udf_wrapper(batch),
-            **self._settings.to_dict(),
+        return self._evolve(
+            query=self._query.add_signals(
+                udf_obj.to_udf_wrapper(batch),
+                **self._settings.to_dict(),
+            ),
+            signal_schema=self.signals_schema | udf_obj.output,
        )
 
-        return chain.add_schema(udf_obj.output).reset_settings(self._settings)
-
     def _udf_to_obj(
         self,
-        target_class: type[UDFBase],
-        func: Optional[Callable],
+        target_class: type[UDFObjT],
+        func: Optional[Union[Callable, UDFObjT]],
         params: Union[None, str, Sequence[str]],
         output: OutputType,
         signal_map,
-    ) -> UDFBase:
+    ) -> UDFObjT:
         is_generator = target_class.is_output_batched
         name = self.name or ""
 
@@ -907,17 +957,12 @@ class DataChain(DatasetQuery):
         return target_class._create(sign, params_schema)
 
     def _extend_to_data_model(self, method_name, *args, **kwargs):
-        super_func = getattr(super(), method_name)
+        query_func = getattr(self._query, method_name)
 
         new_schema = self.signals_schema.resolve(*args)
         columns = [C(col) for col in new_schema.db_signals()]
-        res = super_func(*columns, **kwargs)
-        if isinstance(res, DataChain):
-            res.signals_schema = new_schema
-
-        return res
+        return query_func(*columns, **kwargs)
 
-    @detach
     @resolve_columns
     def order_by(self, *args, descending: bool = False) -> "Self":
         """Orders by specified set of signals.
@@ -928,9 +973,8 @@ class DataChain(DatasetQuery):
         if descending:
             args = tuple(sqlalchemy.desc(a) for a in args)
 
-        return super().order_by(*args)
+        return self._evolve(query=self._query.order_by(*args))
 
-    @detach
     def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
         """Removes duplicate rows based on uniqueness of some input column(s)
         i.e if rows are found with the same value of input column(s), only one
@@ -942,29 +986,30 @@ class DataChain(DatasetQuery):
            )
            ```
        """
-        return super().distinct(*self.signals_schema.resolve(arg, *args).db_signals())
+        return self._evolve(
+            query=self._query.distinct(
+                *self.signals_schema.resolve(arg, *args).db_signals()
+            )
+        )
 
-    @detach
     def select(self, *args: str, _sys: bool = True) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
         if _sys:
             new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
-        chain = super().select(*columns)
-        chain.signals_schema = new_schema
-        return chain
+        return self._evolve(
+            query=self._query.select(*columns), signal_schema=new_schema
+        )
 
-    @detach
     def select_except(self, *args: str) -> "Self":
         """Select all the signals expect the specified signals."""
         new_schema = self.signals_schema.select_except_signals(*args)
         columns = new_schema.db_signals()
-        chain = super().select(*columns)
-        chain.signals_schema = new_schema
-        return chain
+        return self._evolve(
+            query=self._query.select(*columns), signal_schema=new_schema
+        )
 
-    @detach
     def mutate(self, **kwargs) -> "Self":
         """Create new signals based on existing signals.
 
@@ -1029,9 +1074,9 @@ class DataChain(DatasetQuery):
            # adding new signal
            mutated[name] = value
 
-        chain = super().mutate(**mutated)
-        chain.signals_schema = schema.mutate(kwargs)
-        return chain
+        return self._evolve(
+            query=self._query.mutate(**mutated), signal_schema=schema.mutate(kwargs)
+        )
 
     @property
     def _effective_signals_schema(self) -> "SignalSchema":
@@ -1058,11 +1103,34 @@ class DataChain(DatasetQuery):
        a tuple of row values.
        """
        db_signals = self._effective_signals_schema.db_signals()
-        with super().select(*db_signals).as_iterable() as rows:
+        with self._query.select(*db_signals).as_iterable() as rows:
            if row_factory:
                rows = (row_factory(db_signals, r) for r in rows)
            yield from rows
 
+    def to_columnar_data_with_names(
+        self, chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE
+    ) -> tuple[list[str], Iterator[list[list[Any]]]]:
+        """Returns column names and the results as an iterator that provides chunks,
+        with each chunk containing a list of columns, where each column contains a
+        list of the row values for that column in that chunk. Useful for columnar data
+        formats, such as parquet or other OLAP databases.
+        """
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        def column_chunks() -> Iterator[list[list[Any]]]:
+            for chunk_iter in batched_it(results_iter, chunk_size):
+                columns: list[list[Any]] = [[] for _ in column_names]
+                for row in chunk_iter:
+                    for i, col in enumerate(columns):
+                        col.append(row[i])
+                yield columns
+
+        return column_names, column_chunks()
+
     @overload
     def results(self) -> list[tuple[Any, ...]]: ...
 
@@ -1126,7 +1194,7 @@ class DataChain(DatasetQuery):
         chain = self.select(*cols) if cols else self
         signals_schema = chain._effective_signals_schema
         db_signals = signals_schema.db_signals()
-        with super().select(*db_signals).as_iterable() as rows:
+        with self._query.select(*db_signals).as_iterable() as rows:
             for row in rows:
                 ret = signals_schema.row_to_features(
                     row, catalog=chain.session.catalog, cache=chain._settings.cache
@@ -1156,7 +1224,7 @@ class DataChain(DatasetQuery):
         """
         from datachain.torch import PytorchDataset
 
-        if self.attached:
+        if self._query.attached:
             chain = self
         else:
             chain = self.save()
@@ -1164,7 +1232,7 @@ class DataChain(DatasetQuery):
         return PytorchDataset(
             chain.name,
             chain.version,
-            catalog=self.catalog,
+            catalog=self.session.catalog,
             transform=transform,
             tokenizer=tokenizer,
             tokenizer_kwargs=tokenizer_kwargs,
@@ -1175,7 +1243,6 @@ class DataChain(DatasetQuery):
         schema = self.signals_schema.clone_without_file_signals()
         return self.select(*schema.values.keys())
 
-    @detach
     def merge(
         self,
         right_ds: "DataChain",
@@ -1240,7 +1307,7 @@ class DataChain(DatasetQuery):
        )
 
        if self == right_ds:
-            right_ds = right_ds.clone(new_table=True)
+            right_ds = right_ds.clone()
 
        errors = []
 
@@ -1266,9 +1333,11 @@ class DataChain(DatasetQuery):
                on, right_on, f"Could not resolve {', '.join(errors)}"
            )
 
-        ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")
-
-        ds.feature_schema = None
+        query = self._query.join(
+            right_ds._query, sqlalchemy.and_(*ops), inner, rname + "{name}"
+        )
+        query.feature_schema = None
+        ds = self._evolve(query=query)
 
        signals_schema = self.signals_schema.clone_without_sys_signals()
        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
@@ -1278,6 +1347,14 @@ class DataChain(DatasetQuery):
 
        return ds
 
+    def union(self, other: "Self") -> "Self":
+        """Return the set union of the two datasets.
+
+        Parameters:
+            other: chain whose rows will be added to `self`.
+        """
+        return self._evolve(query=self._query.union(other._query))
+
     def subtract(  # type: ignore[override]
         self,
         other: "DataChain",
@@ -1341,7 +1418,7 @@ class DataChain(DatasetQuery):
                other.signals_schema.resolve(*right_on).db_signals(),
            )  # type: ignore[arg-type]
        )
-        return super().subtract(other, signals)  # type: ignore[arg-type]
+        return self._evolve(query=self._query.subtract(other._query, signals))  # type: ignore[arg-type]
 
     @classmethod
     def from_values(
@@ -1449,7 +1526,7 @@ class DataChain(DatasetQuery):
            transpose : Whether to transpose rows and columns.
            truncate : Whether or not to truncate the contents of columns.
        """
-        dc = self.limit(limit) if limit > 0 else self
+        dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
        df = dc.to_pandas(flatten)
 
        if df.empty:
@@ -1759,21 +1836,96 @@ class DataChain(DatasetQuery):
         self,
         path: Union[str, os.PathLike[str], BinaryIO],
         partition_cols: Optional[Sequence[str]] = None,
+        chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
         **kwargs,
     ) -> None:
-        """Save chain to parquet file.
+        """Save chain to parquet file with SignalSchema metadata.
 
         Parameters:
             path : Path or a file-like binary object to save the file.
             partition_cols : Column names by which to partition the dataset.
+            chunk_size : The chunk size of results to read and convert to columnar
+                data, to avoid running out of memory.
         """
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
+
         _partition_cols = list(partition_cols) if partition_cols else None
-        return self.to_pandas().to_parquet(
-            path,
-            partition_cols=_partition_cols,
-            **kwargs,
+        signal_schema_metadata = orjson.dumps(
+            self._effective_signals_schema.serialize()
         )
 
+        column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)
+
+        parquet_schema = None
+        parquet_writer = None
+        first_chunk = True
+
+        for chunk in column_chunks:
+            # pyarrow infers the best parquet schema from the python types of
+            # the input data.
+            table = pa.Table.from_pydict(
+                dict(zip(column_names, chunk)),
+                schema=parquet_schema,
+            )
+
+            # Preserve any existing metadata, and add the DataChain SignalSchema.
+            existing_metadata = table.schema.metadata or {}
+            merged_metadata = {
+                **existing_metadata,
+                DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY: signal_schema_metadata,
+            }
+            table = table.replace_schema_metadata(merged_metadata)
+            parquet_schema = table.schema
+
+            if _partition_cols:
+                # Write to a partitioned parquet dataset.
+                pq.write_to_dataset(
+                    table,
+                    root_path=path,
+                    partition_cols=_partition_cols,
+                    **kwargs,
+                )
+            else:
+                if first_chunk:
+                    # Write to a single parquet file.
+                    parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+                    first_chunk = False
+
+                assert parquet_writer
+                parquet_writer.write_table(table)
+
+        if parquet_writer:
+            parquet_writer.close()
+
+    def to_csv(
+        self,
+        path: Union[str, os.PathLike[str]],
+        delimiter: str = ",",
+        **kwargs,
+    ) -> None:
+        """Save chain to a csv (comma-separated values) file.
+
+        Parameters:
+            path : Path to save the file.
+            delimiter : Delimiter to use for the resulting file.
+        """
+        import csv
+
+        headers, _ = self._effective_signals_schema.get_headers_with_length()
+        column_names = [".".join(filter(None, header)) for header in headers]
+
+        results_iter = self.collect_flatten()
+
+        with open(path, "w", newline="") as f:
+            writer = csv.writer(f, delimiter=delimiter, **kwargs)
+            writer.writerow(column_names)
+
+            for row in results_iter:
+                writer.writerow(row)
+
     @classmethod
     def from_records(
         cls,
@@ -1782,7 +1934,7 @@ class DataChain(DatasetQuery):
         settings: Optional[dict] = None,
         in_memory: bool = False,
         schema: Optional[dict[str, DataType]] = None,
-    ) -> "DataChain":
+    ) -> "Self":
         """Create a DataChain from the provided records. This method can be used for
         programmatically generating a chain in contrast of reading data from storages
         or other sources.
@@ -1837,7 +1989,7 @@ class DataChain(DatasetQuery):
         insert_q = dr.get_table().insert()
         for record in to_insert:
             db.execute(insert_q.values(**record))
-        return DataChain(name=dsr.name, settings=settings)
+        return cls.from_dataset(name=dsr.name, session=session, settings=settings)
 
     def sum(self, fr: DataType):  # type: ignore[override]
         """Compute the sum of a column."""
@@ -1898,8 +2050,8 @@ class DataChain(DatasetQuery):
     ) -> None:
         """Method that exports all files from chain to some folder."""
         if placement == "filename" and (
-            super().distinct(pathfunc.name(C(f"{signal}__path"))).count()
-            != self.count()
+            self._query.distinct(pathfunc.name(C(f"{signal}__path"))).count()
+            != self._query.count()
         ):
             raise ValueError("Files with the same name found")
 
@@ -1919,10 +2071,9 @@ class DataChain(DatasetQuery):
         NOTE: Samples are not deterministic, and streamed/paginated queries or
         multiple workers will draw samples with replacement.
         """
-        return super().sample(n)
+        return self._evolve(query=self._query.sample(n))
 
-    @detach
-    def filter(self, *args) -> "Self":
+    def filter(self, *args: Any) -> "Self":
         """Filter the chain according to conditions.
 
         Example:
@@ -1955,14 +2106,50 @@ class DataChain(DatasetQuery):
            )
            ```
        """
-        return super().filter(*args)
+        return self._evolve(query=self._query.filter(*args))
 
-    @detach
     def limit(self, n: int) -> "Self":
-        """Return the first n rows of the chain."""
-        return super().limit(n)
+        """Return the first `n` rows of the chain.
+
+        If the chain is unordered, which rows are returned is undefined.
+        If the chain has less than `n` rows, the whole chain is returned.
+
+        Parameters:
+            n (int): Number of rows to return.
+        """
+        return self._evolve(query=self._query.limit(n))
 
-    @detach
     def offset(self, offset: int) -> "Self":
-        """Return the results starting with the offset row."""
-        return super().offset(offset)
+        """Return the results starting with the offset row.
+
+        If the chain is unordered, which rows are skipped in undefined.
+        If the chain has less than `offset` rows, the result is an empty chain.
+
+        Parameters:
+            offset (int): Number of rows to skip.
+        """
+        return self._evolve(query=self._query.offset(offset))
+
+    def count(self) -> int:
+        """Return the number of rows in the chain."""
+        return self._query.count()
+
+    def exec(self) -> "Self":
+        """Execute the chain."""
+        return self._evolve(query=self._query.exec())
+
+    def chunk(self, index: int, total: int) -> "Self":
+        """Split a chain into smaller chunks for e.g. parallelization.
+
+        Example:
+            ```py
+            chain = DataChain.from_storage(...)
+            chunk_1 = query._chunk(0, 2)
+            chunk_2 = query._chunk(1, 2)
+            ```
+
+        Note:
+            Bear in mind that `index` is 0-indexed but `total` isn't.
+            Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
+        """
+        return self._evolve(query=self._query.chunk(index, total))
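Taken together, the hunks above decouple `DataChain` from `DatasetQuery`: a chain now wraps a query object and every operation `_evolve`s into a new chain instead of mutating inherited state. A minimal sketch of how the 0.5.1 surface composes; the dataset names are illustrative and assume chains were saved earlier (e.g. with `chain.save("cats")`), while `from_dataset`, `union`/`|`, `limit`, `count`, and `chunk` are taken from the diff above:

```py
from datachain.lib.dc import DataChain

# Illustrative dataset names; assumes "cats" and "dogs" were saved previously.
cats = DataChain.from_dataset("cats")
dogs = DataChain.from_dataset("dogs")

# New in 0.5.1: set union of two chains; `cats | dogs` is the __or__ form.
pets = cats.union(dogs)

# Operations evolve into new chains rather than mutating in place,
# so `pets` stays usable after deriving `first_100` from it.
first_100 = pets.limit(100)
print(first_100.count())

# chunk() splits a chain, e.g. across workers; index is 0-based, total is not.
part_0 = pets.chunk(0, 2)
part_1 = pets.chunk(1, 2)
```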
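The last large hunk also reworks exports: `to_parquet` now streams columnar chunks from `to_columnar_data_with_names` through pyarrow, embedding the serialized `SignalSchema` in the parquet metadata, and `to_csv` is new. A hedged usage sketch continuing the example above; the file paths and the "breed" partition column are hypothetical:

```py
# Single parquet file; chunk_size bounds memory use (the default is 100_000
# rows, from DEFAULT_PARQUET_CHUNK_SIZE in the diff above).
pets.to_parquet("pets.parquet", chunk_size=50_000)

# With partition_cols, chunks go through pyarrow.parquet.write_to_dataset
# and land in a partitioned directory; "breed" is a hypothetical signal name.
pets.to_parquet("pets_dataset/", partition_cols=["breed"])

# New in 0.5.1: CSV export with a configurable delimiter.
pets.to_csv("pets.csv", delimiter=";")
```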