datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/dc/datachain.py
CHANGED
@@ -1,51 +1,69 @@
 import copy
+import hashlib
+import logging
 import os
 import os.path
 import sys
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from typing import (
     IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
-    Callable,
     ClassVar,
     Literal,
-    Optional,
     TypeVar,
-
+    cast,
     overload,
 )

-import orjson
 import sqlalchemy
 from pydantic import BaseModel
+from sqlalchemy.sql.elements import ColumnElement
 from tqdm import tqdm

+from datachain import json, semver
 from datachain.dataset import DatasetRecord
+from datachain.delta import delta_disabled
+from datachain.error import (
+    JobAncestryDepthExceededError,
+    ProjectCreateNotAllowedError,
+    ProjectNotFoundError,
+)
 from datachain.func import literal
 from datachain.func.base import Function
 from datachain.func.func import Func
+from datachain.job import Job
 from datachain.lib.convert.python_to_sql import python_to_sql
-from datachain.lib.data_model import
-
-
-
-
+from datachain.lib.data_model import (
+    DataModel,
+    DataType,
+    DataValue,
+    StandardType,
+    dict_to_data_model,
 )
+from datachain.lib.file import EXPORT_FILES_MAX_THREADS, ArrowRow, File, FileExporter
 from datachain.lib.file import ExportPlacement as FileExportPlacement
+from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
-from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.signal_schema import SignalResolvingError, SignalSchema
 from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import DataChainColumnError, DataChainParamsError
+from datachain.project import Project
 from datachain.query import Session
-from datachain.query.dataset import
-
+from datachain.query.dataset import (
+    DatasetQuery,
+    PartitionByType,
+    RegenerateSystemColumns,
+    UnionSchemaMismatchError,
+)
+from datachain.query.schema import DEFAULT_DELIMITER, Column
 from datachain.sql.functions import path as pathfunc
-from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
+from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict

+from .database import DEFAULT_DATABASE_BATCH_SIZE
 from .utils import (
     DatasetMergeError,
     DatasetPrepareError,
@@ -54,9 +72,12 @@ from .utils import (
     Sys,
     _get_merge_error_str,
     _validate_merge_on,
+    is_studio,
     resolve_columns,
 )

+logger = logging.getLogger("datachain")
+
 C = Column

 _T = TypeVar("_T")
@@ -65,11 +86,27 @@ UDFObjT = TypeVar("UDFObjT", bound=UDFBase)
 DEFAULT_PARQUET_CHUNK_SIZE = 100_000

 if TYPE_CHECKING:
+    import sqlite3
+
     import pandas as pd
+    from sqlalchemy.orm import Session as OrmSession
     from typing_extensions import ParamSpec, Self

     P = ParamSpec("P")

+    ConnectionType = (
+        str
+        | sqlalchemy.engine.URL
+        | sqlalchemy.engine.interfaces.Connectable
+        | sqlalchemy.engine.Engine
+        | sqlalchemy.engine.Connection
+        | OrmSession
+        | sqlite3.Connection
+    )
+
+
+T = TypeVar("T", bound="DataChain")
+

 class DataChain:
     """DataChain - a data structure for batch data processing and evaluation.
@@ -133,7 +170,7 @@ class DataChain:
             .choices[0]
             .message.content,
         )
-        .
+        .persist()
     )

     try:
@@ -154,7 +191,7 @@ class DataChain:
         query: DatasetQuery,
         settings: Settings,
         signal_schema: SignalSchema,
-        setup:
+        setup: dict | None = None,
         _sys: bool = False,
     ) -> None:
         """Don't instantiate this directly, use one of the from_XXX constructors."""
@@ -163,6 +200,12 @@ class DataChain:
         self.signals_schema = signal_schema
         self._setup: dict = setup or {}
         self._sys = _sys
+        self._delta = False
+        self._delta_unsafe = False
+        self._delta_on: str | Sequence[str] | None = None
+        self._delta_result_on: str | Sequence[str] | None = None
+        self._delta_compare: str | Sequence[str] | None = None
+        self._delta_retry: bool | str | None = None

     def __repr__(self) -> str:
         """Return a string representation of the chain."""
@@ -176,6 +219,48 @@ class DataChain:
         self.print_schema(file=file)
         return file.getvalue()

+    def hash(self) -> str:
+        """
+        Calculates SHA hash of this chain. Hash calculation is fast and consistent.
+        It takes into account all the steps added to the chain and their inputs.
+        Order of the steps is important.
+        """
+        return self._query.hash()
+
+    def _as_delta(
+        self,
+        on: str | Sequence[str] | None = None,
+        right_on: str | Sequence[str] | None = None,
+        compare: str | Sequence[str] | None = None,
+        delta_retry: bool | str | None = None,
+        delta_unsafe: bool = False,
+    ) -> "Self":
+        """Marks this chain as delta, which means special delta process will be
+        called on saving dataset for optimization"""
+        if on is None:
+            raise ValueError("'delta on' fields must be defined")
+        self._delta = True
+        self._delta_on = on
+        self._delta_result_on = right_on
+        self._delta_compare = compare
+        self._delta_retry = delta_retry
+        self._delta_unsafe = delta_unsafe
+        return self
+
+    @property
+    def empty(self) -> bool:
+        """Returns True if chain has zero number of rows"""
+        return not bool(self.count())
+
+    @property
+    def delta(self) -> bool:
+        """Returns True if this chain is ran in "delta" update mode"""
+        return self._delta
+
+    @property
+    def delta_unsafe(self) -> bool:
+        return self._delta_unsafe
+
     @property
     def schema(self) -> dict[str, DataType]:
         """Get schema of the chain."""
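The hunk above introduces `hash()`, `empty`, and the delta-related members on `DataChain`. A minimal usage sketch of the two public additions; the `read_values` constructor, `filter`, and the `C` column helper are assumed from the wider datachain API and are not part of this hunk:

```python
import datachain as dc

# Assumed constructor: a tiny in-memory chain, used only for illustration.
chain = dc.read_values(num=[1, 2, 3])

# hash() covers every step added to the chain, so adding a step changes it.
filtered = chain.filter(dc.C("num") > 1)
print(chain.hash() != filtered.hash())  # expected: True

# empty is True only when the chain has zero rows.
print(chain.filter(dc.C("num") > 10).empty)  # expected: True
```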
@@ -197,7 +282,7 @@ class DataChain:

         raise ValueError(f"Column with name {name} not found in the schema")

-    def c(self, column:
+    def c(self, column: str | Column) -> Column:
         """Returns Column instance attached to the current chain."""
         c = self.column(column) if isinstance(column, str) else self.column(column.name)
         c.table = self._query.table
@@ -209,27 +294,31 @@ class DataChain:
         return self._query.session

     @property
-    def name(self) ->
+    def name(self) -> str | None:
         """Name of the underlying dataset, if there is one."""
         return self._query.name

     @property
-    def version(self) ->
+    def version(self) -> str | None:
         """Version of the underlying dataset, if there is one."""
         return self._query.version

     @property
-    def dataset(self) ->
+    def dataset(self) -> DatasetRecord | None:
         """Underlying dataset, if there is one."""
         if not self.name:
             return None
-        return self.session.catalog.get_dataset(
+        return self.session.catalog.get_dataset(
+            self.name,
+            namespace_name=self._query.project.namespace.name,
+            project_name=self._query.project.name,
+        )

     def __or__(self, other: "Self") -> "Self":
         """Return `self.union(other)`."""
         return self.union(other)

-    def print_schema(self, file:
+    def print_schema(self, file: IO | None = None) -> None:
         """Print schema of the chain."""
         self._effective_signals_schema.print_tree(file=file)

@@ -240,8 +329,8 @@ class DataChain:
     def _evolve(
         self,
         *,
-        query:
-        settings:
+        query: DatasetQuery | None = None,
+        settings: Settings | None = None,
         signal_schema=None,
         _sys=None,
     ) -> "Self":
@@ -253,39 +342,60 @@ class DataChain:
             signal_schema = copy.deepcopy(self.signals_schema)
         if _sys is None:
             _sys = self._sys
-
+        chain = type(self)(
             query, settings, signal_schema=signal_schema, setup=self._setup, _sys=_sys
         )
+        if self.delta:
+            chain = chain._as_delta(
+                on=self._delta_on,
+                right_on=self._delta_result_on,
+                compare=self._delta_compare,
+                delta_retry=self._delta_retry,
+                delta_unsafe=self._delta_unsafe,
+            )
+
+        return chain

     def settings(
         self,
-        cache=None,
-
-
-
-
-
+        cache: bool | None = None,
+        prefetch: bool | int | None = None,
+        parallel: bool | int | None = None,
+        workers: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
+        sys: bool | None = None,
     ) -> "Self":
-        """
-
-
-
+        """
+        Set chain execution parameters. Returns the chain itself, allowing method
+        chaining for subsequent operations. To restore all settings to their default
+        values, use `reset_settings()`.

         Parameters:
-            cache
-
-
-
-
-
-
-
+            cache: Enable files caching to speed up subsequent accesses to the same
+                files from the same or different chains. Defaults to False.
+            prefetch: Enable prefetching of files. This will download files in
+                advance in parallel. If an integer is provided, it specifies the number
+                of files to prefetch concurrently for each process on each worker.
+                Defaults to 2. Set to 0 or False to disable prefetching.
+            parallel: Number of processes to use for processing user-defined functions
+                (UDFs) in parallel. If an integer is provided, it specifies the number
+                of CPUs to use. If True, all available CPUs are used. Defaults to 1.
+            namespace: Namespace to use for the chain by default.
+            project: Project to use for the chain by default.
+            min_task_size: Minimum number of rows per worker/process for parallel
+                processing by UDFs. Defaults to 1.
+            batch_size: Number of rows per insert by UDF to fine tune and balance speed
+                and memory usage. This might be useful when processing large rows
+                or when running into memory issues. Defaults to 2000.

         Example:
             ```py
             chain = (
                 chain
-                .settings(cache=True, parallel=8)
+                .settings(cache=True, parallel=8, batch_size=300)
                 .map(laion=process_webdataset(spec=WDSLaion), params="file")
             )
             ```
@@ -293,22 +403,25 @@ class DataChain:
         if sys is None:
             sys = self._sys
         settings = copy.copy(self._settings)
-        settings.add(
+        settings.add(
+            Settings(
+                cache=cache,
+                prefetch=prefetch,
+                parallel=parallel,
+                workers=workers,
+                namespace=namespace,
+                project=project,
+                min_task_size=min_task_size,
+                batch_size=batch_size,
+            )
+        )
         return self._evolve(settings=settings, _sys=sys)

-    def reset_settings(self, settings:
-        """Reset all settings to default values."""
+    def reset_settings(self, settings: Settings | None = None) -> "Self":
+        """Reset all chain settings to default values."""
         self._settings = settings if settings else Settings()
         return self

-    def reset_schema(self, signals_schema: SignalSchema) -> "Self":
-        self.signals_schema = signals_schema
-        return self
-
-    def add_schema(self, signals_schema: SignalSchema) -> "Self":
-        self.signals_schema |= signals_schema
-        return self
-
     @classmethod
     def from_storage(
         cls,
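Putting the widened `settings()` signature above into context, a hedged usage sketch; `read_storage`, the bucket URI, and the `size_kb` mapper are assumptions for illustration, not part of this diff:

```python
import datachain as dc
from datachain.lib.file import File

def size_kb(file: File) -> float:
    # Hypothetical mapper used only to have something to parallelize.
    return file.size / 1024

chain = (
    dc.read_storage("s3://my-bucket/images/")  # assumed constructor
    .settings(
        cache=True,       # cache files between runs
        parallel=8,       # number of UDF worker processes
        prefetch=4,       # files prefetched per process
        batch_size=500,   # rows per insert, new in this version
        namespace="dev",  # default namespace for saved datasets
        project="demo",   # default project for saved datasets
    )
    .map(size_kb=size_kb)
)
```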
@@ -356,8 +469,8 @@ class DataChain:
     def explode(
         self,
         col: str,
-        model_name:
-
+        model_name: str | None = None,
+        column: str | None = None,
         schema_sample_size: int = 1,
     ) -> "DataChain":
         """Explodes a column containing JSON objects (dict or str DataChain type) into
@@ -368,7 +481,7 @@ class DataChain:
             col: the name of the column containing JSON to be exploded.
             model_name: optional generated model name. By default generates the name
                 automatically.
-
+            column: optional generated column name. By default generates the
                 name automatically.
             schema_sample_size: the number of rows to use for inferring the schema of
                 the JSON (in case some fields are optional and it's not enough to
@@ -377,16 +490,14 @@ class DataChain:
         Returns:
             DataChain: A new DataChain instance with the new set of columns.
         """
-        import json
-
         import pyarrow as pa

         from datachain.lib.arrow import schema_to_output

-        json_values =
+        json_values = self.limit(schema_sample_size).to_list(col)
         json_dicts = [
             json.loads(json_value) if isinstance(json_value, str) else json_value
-            for json_value in json_values
+            for (json_value,) in json_values
         ]

         if any(not isinstance(json_dict, dict) for json_dict in json_dicts):
@@ -400,16 +511,16 @@ class DataChain:

         model = dict_to_data_model(model_name, output, original_names)

-        def json_to_model(json_value:
+        def json_to_model(json_value: str | dict):
             json_dict = (
                 json.loads(json_value) if isinstance(json_value, str) else json_value
             )
             return model.model_validate(json_dict)

-        if not
-
+        if not column:
+            column = f"{col}_expl"

-        return self.map(json_to_model, params=col, output={
+        return self.map(json_to_model, params=col, output={column: model})

     @classmethod
     def datasets(
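The `explode()` changes above rename the generated-column parameter to `column` and sample rows via `to_list()`. A hedged sketch of calling it on a chain that has a JSON string column named `meta`; the chain construction and `show()` call are assumptions from the broader API:

```python
import datachain as dc

# Assumed constructor: a single-row chain with a JSON string column.
chain = dc.read_values(meta=['{"width": 640, "height": 480}'])

# Parse the JSON column into a new model-typed signal named "meta_parsed";
# without column=..., the generated name would default to "meta_expl".
exploded = chain.explode("meta", column="meta_parsed", schema_sample_size=1)
exploded.show()
```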
@@ -443,35 +554,290 @@ class DataChain:
         )
         return listings(*args, **kwargs)

+    @property
+    def namespace_name(self) -> str:
+        """Current namespace name in which the chain is running"""
+        return (
+            self._settings.namespace
+            or self.session.catalog.metastore.default_namespace_name
+        )
+
+    @property
+    def project_name(self) -> str:
+        """Current project name in which the chain is running"""
+        return (
+            self._settings.project
+            or self.session.catalog.metastore.default_project_name
+        )
+
+    def persist(self) -> "Self":
+        """Saves temporary chain that will be removed after the process ends.
+        Temporary datasets are useful for optimization, for example when we have
+        multiple chains starting with identical sub-chain. We can then persist that
+        common chain and use it to calculate other chains, to avoid re-calculation
+        every time.
+        It returns the chain itself.
+        """
+        schema = self.signals_schema.clone_without_sys_signals().serialize()
+        project = self.session.catalog.metastore.get_project(
+            self.project_name,
+            self.namespace_name,
+            create=True,
+        )
+        return self._evolve(
+            query=self._query.save(project=project, feature_schema=schema),
+            signal_schema=self.signals_schema | SignalSchema({"sys": Sys}),
+        )
+
+    def _calculate_job_hash(self, job_id: str) -> str:
+        """
+        Calculates hash of the job at the place of this chain's save method.
+        Hash is calculated using previous job checkpoint hash (if exists) and
+        adding hash of this chain to produce new hash.
+        """
+        last_checkpoint = self.session.catalog.metastore.get_last_checkpoint(job_id)
+
+        return hashlib.sha256(
+            (bytes.fromhex(last_checkpoint.hash) if last_checkpoint else b"")
+            + bytes.fromhex(self.hash())
+        ).hexdigest()
+
     def save(  # type: ignore[override]
         self,
-        name:
-        version:
-        description:
-
+        name: str,
+        version: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
         **kwargs,
-    ) -> "
+    ) -> "DataChain":
         """Save to a Dataset. It returns the chain itself.

         Parameters:
-            name
-
-
-
-
+            name: dataset name. This can be either a fully qualified name, including
+                the namespace and project, or just a regular dataset name. In the latter
+                case, the namespace and project will be taken from the settings
+                (if specified) or from the default values otherwise.
+            version: version of a dataset. If version is not specified and dataset
+                already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
+            description: description of a dataset.
+            attrs: attributes of a dataset. They can be without value, e.g "NLP",
+                or with a value, e.g "location=US".
+            update_version: which part of the dataset version to automatically increase.
+                Available values: `major`, `minor` or `patch`. Default is `patch`.
         """
+
+        catalog = self.session.catalog
+
+        result = None  # result chain that will be returned at the end
+
+        # Version validation
+        self._validate_version(version)
+        self._validate_update_version(update_version)
+
+        # get existing job if running in SaaS, or creating new one if running locally
+        job = self.session.get_or_create_job()
+
+        namespace_name, project_name, name = catalog.get_full_dataset_name(
+            name,
+            namespace_name=self._settings.namespace,
+            project_name=self._settings.project,
+        )
+        project = self._get_or_create_project(namespace_name, project_name)
+
+        # Checkpoint handling
+        _hash, result = self._resolve_checkpoint(name, project, job, kwargs)
+        if bool(result):
+            # Checkpoint was found and reused
+            print(f"Checkpoint found for dataset '{name}', skipping creation")
+
+        # Schema preparation
         schema = self.signals_schema.clone_without_sys_signals().serialize()
-
-
-
-
-
-
-
+
+        # Handle retry and delta functionality
+        if not result:
+            result = self._handle_delta(name, version, project, schema, kwargs)
+
+        if not result:
+            # calculate chain if we already don't have result from checkpoint or delta
+            result = self._evolve(
+                query=self._query.save(
+                    name=name,
+                    version=version,
+                    project=project,
+                    description=description,
+                    attrs=attrs,
+                    feature_schema=schema,
+                    update_version=update_version,
+                    **kwargs,
+                )
+            )
+
+        catalog.metastore.create_checkpoint(job.id, _hash)  # type: ignore[arg-type]
+        return result
+
+    def _validate_version(self, version: str | None) -> None:
+        """Validate dataset version if provided."""
+        if version is not None:
+            semver.validate(version)
+
+    def _validate_update_version(self, update_version: str | None) -> None:
+        """Ensure update_version is one of: major, minor, patch."""
+        allowed = ["major", "minor", "patch"]
+        if update_version not in allowed:
+            raise ValueError(f"update_version must be one of {allowed}")
+
+    def _get_or_create_project(self, namespace: str, project_name: str) -> Project:
+        """Get project or raise if creation not allowed."""
+        try:
+            return self.session.catalog.metastore.get_project(
+                project_name,
+                namespace,
+                create=is_studio(),
+            )
+        except ProjectNotFoundError as e:
+            raise ProjectCreateNotAllowedError("Creating project is not allowed") from e
+
+    def _resolve_checkpoint(
+        self,
+        name: str,
+        project: Project,
+        job: Job,
+        kwargs: dict,
+    ) -> tuple[str, "DataChain | None"]:
+        """Check if checkpoint exists and return cached dataset if possible."""
+        from .datasets import read_dataset
+
+        metastore = self.session.catalog.metastore
+        checkpoints_reset = env2bool("DATACHAIN_CHECKPOINTS_RESET", undefined=True)
+
+        _hash = self._calculate_job_hash(job.id)
+
+        if (
+            job.parent_job_id
+            and not checkpoints_reset
+            and metastore.find_checkpoint(job.parent_job_id, _hash)
+        ):
+            # checkpoint found → find which dataset version to reuse
+
+            # Find dataset version that was created by any ancestor job
+            try:
+                dataset_version = metastore.get_dataset_version_for_job_ancestry(
+                    name,
+                    project.namespace.name,
+                    project.name,
+                    job.id,
+                )
+            except JobAncestryDepthExceededError:
+                raise JobAncestryDepthExceededError(
+                    "Job continuation chain is too deep. "
+                    "Please run the job from scratch without continuing from a "
+                    "parent job."
+                ) from None
+
+            if not dataset_version:
+                logger.debug(
+                    "Checkpoint found but no dataset version for '%s' "
+                    "in job ancestry (job_id=%s). Creating new version.",
+                    name,
+                    job.id,
+                )
+                # Dataset version not found (e.g deleted by user) - skip
+                # checkpoint and recreate
+                return _hash, None
+
+            logger.debug(
+                "Reusing dataset version '%s' v%s from job ancestry "
+                "(job_id=%s, dataset_version_id=%s)",
+                name,
+                dataset_version.version,
+                job.id,
+                dataset_version.id,
+            )
+
+            # Read the specific version from ancestry
+            chain = read_dataset(
+                name,
+                namespace=project.namespace.name,
+                project=project.name,
+                version=dataset_version.version,
                 **kwargs,
             )
+
+            # Link current job to this dataset version (not creator).
+            # This also updates dataset_version.job_id.
+            metastore.link_dataset_version_to_job(
+                dataset_version.id,
+                job.id,
+                is_creator=False,
+            )
+
+            return _hash, chain
+
+        return _hash, None
+
+    def _handle_delta(
+        self,
+        name: str,
+        version: str | None,
+        project: Project,
+        schema: dict,
+        kwargs: dict,
+    ) -> "DataChain | None":
+        """Try to save as a delta dataset.
+        Returns:
+            A DataChain if delta logic could handle it, otherwise None to fall back
+            to the regular save path (e.g., on first dataset creation).
+        """
+        from datachain.delta import delta_retry_update
+
+        from .datasets import read_dataset
+
+        if not self.delta or not name:
+            return None
+
+        assert self._delta_on is not None, "Delta chain must have delta_on defined"
+
+        result_ds, dependencies, has_changes = delta_retry_update(
+            self,
+            project.namespace.name,
+            project.name,
+            name,
+            on=self._delta_on,
+            right_on=self._delta_result_on,
+            compare=self._delta_compare,
+            delta_retry=self._delta_retry,
         )

+        # Case 1: delta produced a new dataset
+        if result_ds:
+            return self._evolve(
+                query=result_ds._query.save(
+                    name=name,
+                    version=version,
+                    project=project,
+                    feature_schema=schema,
+                    dependencies=dependencies,
+                    **kwargs,
+                )
+            )
+
+        # Case 2: no changes → reuse last version
+        if not has_changes:
+            # sources have not been changed so new version of resulting dataset
+            # would be the same as previous one. To avoid duplicating exact
+            # datasets, we won't create new version of it and we will return
+            # current latest version instead.
+            return read_dataset(
+                name,
+                namespace=project.namespace.name,
+                project=project.name,
+                **kwargs,
+            )
+
+        # Case 3: first creation of dataset
+        return None
+
     def apply(self, func, *args, **kwargs):
         """Apply any function to the chain.

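The reworked `save()` above adds semantic versioning (`version`, `update_version`), dataset `attrs`, and checkpoint/delta handling. A hedged sketch of the versioning side only; the chain construction and dataset name are assumptions for illustration:

```python
import datachain as dc

chain = dc.read_values(idx=[1, 2, 3])  # assumed constructor

# A plain name uses the namespace/project from settings or the defaults;
# a fully qualified "namespace.project.name" is also accepted.
chain.save("numbers", description="toy dataset", attrs=["demo", "location=US"])

# By default the patch part is bumped on re-save (e.g. 1.0.0 -> 1.0.1);
# update_version selects which part to increase instead.
chain.save("numbers", update_version="minor")

# Or pin an explicit semver string.
chain.save("numbers", version="2.0.0")
```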
@@ -497,10 +863,10 @@ class DataChain:

     def map(
         self,
-        func:
-        params:
+        func: Callable | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
-        **signal_map,
+        **signal_map: Any,
     ) -> "Self":
         """Apply a function to each row to create new signals. The function should
         return a new object for each row. It returns a chain itself with new signals.
@@ -508,17 +874,17 @@ class DataChain:
         Input-output relationship: 1:1

         Parameters:
-            func
-            params
+            func: Function applied to each row.
+            params: List of column names used as input for the function. Default
                 is taken from function signature.
-            output
+            output: Dictionary defining new signals and their corresponding types.
                 Default type is taken from function signature. Default can be also
                 taken from kwargs - **signal_map (see below).
                 If signal name is defined using signal_map (see below) only a single
                 type value can be used.
-            **signal_map
+            **signal_map: kwargs can be used to define `func` together with its return
                 signal name in format of `map(my_sign=my_func)`. This helps define
-                signal names and
+                signal names and functions in a nicer way.

         Example:
             Using signal_map and single type in output:
@@ -539,18 +905,19 @@ class DataChain:
         if (prefetch := self._settings.prefetch) is not None:
             udf_obj.prefetch = prefetch

+        sys_schema = SignalSchema({"sys": Sys})
         return self._evolve(
             query=self._query.add_signals(
-                udf_obj.to_udf_wrapper(),
+                udf_obj.to_udf_wrapper(self._settings.batch_size),
                 **self._settings.to_dict(),
             ),
-            signal_schema=self.signals_schema | udf_obj.output,
+            signal_schema=sys_schema | self.signals_schema | udf_obj.output,
         )

     def gen(
         self,
-        func:
-        params:
+        func: Callable | Generator | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
         **signal_map,
     ) -> "Self":
@@ -579,19 +946,21 @@ class DataChain:
             udf_obj.prefetch = prefetch
         return self._evolve(
             query=self._query.generate(
-                udf_obj.to_udf_wrapper(),
+                udf_obj.to_udf_wrapper(self._settings.batch_size),
                 **self._settings.to_dict(),
             ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
         )

+    @delta_disabled
     def agg(
         self,
-
-
-
+        /,
+        func: Callable | None = None,
+        partition_by: PartitionByType | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
-        **signal_map,
+        **signal_map: Callable,
     ) -> "Self":
         """Aggregate rows using `partition_by` statement and apply a function to the
         groups of aggregated rows. The function needs to return new objects for each
@@ -601,12 +970,28 @@ class DataChain:

         This method bears similarity to `gen()` and `map()`, employing a comparable set
         of parameters, yet differs in two crucial aspects:
+
         1. The `partition_by` parameter: This specifies the column name or a list of
            column names that determine the grouping criteria for aggregation.
         2. Group-based UDF function input: Instead of individual rows, the function
-           receives a list all rows within each group defined by `partition_by`.
+           receives a list of all rows within each group defined by `partition_by`.
+
+        If `partition_by` is not set or is an empty list, all rows will be placed
+        into a single group.
+
+        Parameters:
+            func: Function applied to each group of rows.
+            partition_by: Column name(s) to group by. If None, all rows go
+                into one group.
+            params: List of column names used as input for the function. Default is
+                taken from function signature.
+            output: Dictionary defining new signals and their corresponding types.
+                Default type is taken from function signature.
+            **signal_map: kwargs can be used to define `func` together with its return
+                signal name in format of `agg(result_column=my_func)`.

         Examples:
+            Basic aggregation with lambda function:
             ```py
             chain = chain.agg(
                 total=lambda category, amount: [sum(amount)],
@@ -617,7 +1002,6 @@ class DataChain:
             ```

             An alternative syntax, when you need to specify a more complex function:
-
             ```py
             # It automatically resolves which columns to pass to the function
             # by looking at the function signature.
@@ -635,21 +1019,80 @@ class DataChain:
             )
             chain.save("new_dataset")
             ```
+
+            Using complex signals for partitioning (`File` or any Pydantic `BaseModel`):
+            ```py
+            def my_agg(files: list[File]) -> Iterator[tuple[File, int]]:
+                yield files[0], sum(f.size for f in files)
+
+            chain = chain.agg(
+                my_agg,
+                params=("file",),
+                output={"file": File, "total": int},
+                partition_by="file",  # Column referring to all sub-columns of File
+            )
+            chain.save("new_dataset")
+            ```
+
+            Aggregating all rows into a single group (when `partition_by` is not set):
+            ```py
+            chain = chain.agg(
+                total_size=lambda file, size: [sum(size)],
+                output=int,
+                # No partition_by specified - all rows go into one group
+            )
+            chain.save("new_dataset")
+            ```
+
+            Multiple partition columns:
+            ```py
+            chain = chain.agg(
+                total=lambda category, subcategory, amount: [sum(amount)],
+                output=float,
+                partition_by=["category", "subcategory"],
+            )
+            chain.save("new_dataset")
+            ```
         """
+        if partition_by is not None:
+            # Convert string partition_by parameters to Column objects
+            if isinstance(partition_by, (str, Function, ColumnElement)):
+                list_partition_by = [partition_by]
+            else:
+                list_partition_by = list(partition_by)
+
+            processed_partition_columns: list[ColumnElement] = []
+            for col in list_partition_by:
+                if isinstance(col, str):
+                    columns = self.signals_schema.db_signals(name=col, as_columns=True)
+                    if not columns:
+                        raise SignalResolvingError([col], "is not found")
+                    processed_partition_columns.extend(cast("list[Column]", columns))
+                elif isinstance(col, Function):
+                    column = col.get_column(self.signals_schema)
+                    processed_partition_columns.append(column)
+                else:
+                    # Assume it's already a ColumnElement
+                    processed_partition_columns.append(col)
+
+            processed_partition_by = processed_partition_columns
+        else:
+            processed_partition_by = []
+
         udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
         return self._evolve(
             query=self._query.generate(
-                udf_obj.to_udf_wrapper(),
-                partition_by=
+                udf_obj.to_udf_wrapper(self._settings.batch_size),
+                partition_by=processed_partition_by,
                 **self._settings.to_dict(),
             ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
         )

     def batch_map(
         self,
-        func:
-        params:
+        func: Callable | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
         batch: int = 1000,
         **signal_map,
@@ -661,7 +1104,7 @@ class DataChain:
         It accepts the same parameters plus an
         additional parameter:

-            batch
+            batch: Size of each batch passed to `func`. Defaults to 1000.

         Example:
             ```py
@@ -671,11 +1114,24 @@ class DataChain:
             )
             chain.save("new_dataset")
             ```
+
+        .. deprecated:: 0.29.0
+            This method is deprecated and will be removed in a future version.
+            Use `agg()` instead, which provides the similar functionality.
         """
+        import warnings
+
+        warnings.warn(
+            "batch_map() is deprecated and will be removed in a future version. "
+            "Use agg() instead, which provides the similar functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         udf_obj = self._udf_to_obj(BatchMapper, func, params, output, signal_map)
+
         return self._evolve(
             query=self._query.add_signals(
-                udf_obj.to_udf_wrapper(batch),
+                udf_obj.to_udf_wrapper(self._settings.batch_size, batch=batch),
                 **self._settings.to_dict(),
             ),
             signal_schema=self.signals_schema | udf_obj.output,
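Given that `batch_map()` above now warns and points at `agg()`, a hedged migration sketch; `agg()` without `partition_by` reduces over all rows as a single group, which covers the typical whole-chain reduction previously done with `batch_map()` (the `read_values` call is an assumption):

```python
import datachain as dc

chain = dc.read_values(size=[10, 20, 30])  # assumed constructor

# Before (deprecated): chain.batch_map(total=lambda size: [sum(size)], batch=1000)

# After: no partition_by means a single group containing every row.
chain = chain.agg(total=lambda size: [sum(size)], output=int)
chain.save("totals")
```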
@@ -684,8 +1140,8 @@ class DataChain:
     def _udf_to_obj(
         self,
         target_class: type[UDFObjT],
-        func:
-        params:
+        func: Callable | UDFObjT | None,
+        params: str | Sequence[str] | None,
         output: OutputType,
         signal_map: dict[str, Callable],
     ) -> UDFObjT:
@@ -696,11 +1152,7 @@ class DataChain:
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
         DataModel.register(list(sign.output_schema.values.values()))

-
-        if self._sys:
-            signals_schema = SignalSchema({"sys": Sys}) | signals_schema
-
-        params_schema = signals_schema.slice(
+        params_schema = self.signals_schema.slice(
             sign.params, self._setup, is_batch=is_batch
         )

@@ -710,7 +1162,7 @@ class DataChain:
         query_func = getattr(self._query, method_name)

         new_schema = self.signals_schema.resolve(*args)
-        columns =
+        columns = new_schema.db_signals(as_columns=True)
         return query_func(*columns, **kwargs)

     @resolve_columns
@@ -729,15 +1181,17 @@ class DataChain:
         Order is not guaranteed when steps are added after an `order_by` statement.
         I.e. when using `read_dataset` an `order_by` statement should be used if
         the order of the records in the chain is important.
-        Using `order_by` directly before `limit`, `
+        Using `order_by` directly before `limit`, `to_list` and similar methods
         will give expected results.
-        See https://github.com/
+        See https://github.com/datachain-ai/datachain/issues/477
+        for further details.
         """
         if descending:
             args = tuple(sqlalchemy.desc(a) for a in args)

         return self._evolve(query=self._query.order_by(*args))

+    @delta_disabled
     def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
         """Removes duplicate rows based on uniqueness of some input column(s)
         i.e if rows are found with the same value of input column(s), only one
@@ -745,7 +1199,7 @@ class DataChain:

         Example:
             ```py
-            dc.distinct("file.
+            dc.distinct("file.path")
             ```
         """
         return self._evolve(
@@ -754,11 +1208,9 @@ class DataChain:
             )
         )

-    def select(self, *args: str
+    def select(self, *args: str) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
-        if self._sys and _sys:
-            new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         return self._evolve(
             query=self._query.select(*columns), signal_schema=new_schema
@@ -772,10 +1224,11 @@ class DataChain:
             query=self._query.select(*columns), signal_schema=new_schema
         )

-
+    @delta_disabled  # type: ignore[arg-type]
+    def group_by(  # noqa: C901, PLR0912
         self,
         *,
-        partition_by:
+        partition_by: str | Func | Sequence[str | Func] | None = None,
         **kwargs: Func,
     ) -> "Self":
         """Group rows by specified set of signals and return new signals
@@ -791,6 +1244,15 @@ class DataChain:
                 partition_by=("file_source", "file_ext"),
             )
             ```
+
+            Using complex signals:
+            ```py
+            chain = chain.group_by(
+                total_size=func.sum("file.size"),
+                count=func.count(),
+                partition_by="file",  # Uses column name, expands to File's unique keys
+            )
+            ```
         """
         if partition_by is None:
             partition_by = []
@@ -801,20 +1263,61 @@ class DataChain:
         signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
         keep_columns: list[str] = []
+        partial_fields: list[str] = []  # Track specific fields for partial creation
+        schema_partition_by: list[str] = []

-        # validate partition_by columns and add them to the schema
         for col in partition_by:
             if isinstance(col, str):
-
-
-
-
-
+                columns = self.signals_schema.db_signals(name=col, as_columns=True)
+                if not columns:
+                    raise SignalResolvingError([col], "is not found")
+                partition_by_columns.extend(cast("list[Column]", columns))
+
+                # For nested field references (e.g., "nested.level1.name"),
+                # we need to distinguish between:
+                # 1. References to fields within a complex signal (create partials)
+                # 2. Deep nested references that should be flattened
+                if "." in col:
+                    # Split the column reference to analyze it
+                    parts = col.split(".")
+                    parent_signal = parts[0]
+                    parent_type = self.signals_schema.values.get(parent_signal)
+
+                    if ModelStore.is_partial(parent_type):
+                        if parent_signal not in keep_columns:
+                            keep_columns.append(parent_signal)
+                        partial_fields.append(col)
+                        schema_partition_by.append(col)
+                    else:
+                        # BaseModel or other - add flattened columns directly
+                        for column in cast("list[Column]", columns):
+                            col_type = self.signals_schema.get_column_type(column.name)
+                            schema_fields[column.name] = col_type
+                        schema_partition_by.append(col)
+                else:
+                    # simple signal - but we need to check if it's a complex signal
+                    # complex signal - only include the columns used for partitioning
+                    col_type = self.signals_schema.get_column_type(
+                        col, with_subtree=True
+                    )
+                    if isinstance(col_type, type) and issubclass(col_type, BaseModel):
+                        # Complex signal - add only the partitioning columns
+                        for column in cast("list[Column]", columns):
+                            col_type = self.signals_schema.get_column_type(column.name)
+                            schema_fields[column.name] = col_type
+                        schema_partition_by.append(col)
+                    # Simple signal - keep the entire signal
+                    else:
+                        if col not in keep_columns:
+                            keep_columns.append(col)
+                        schema_partition_by.append(col)
             elif isinstance(col, Function):
                 column = col.get_column(self.signals_schema)
                 col_db_name = column.name
                 col_type = column.type.python_type
                 schema_fields[col_db_name] = col_type
+                partition_by_columns.append(column)
+                signal_columns.append(column)
             else:
                 raise DataChainColumnError(
                     col,
@@ -823,9 +1326,7 @@ class DataChain:
                     " but expected str or Function"
                 ),
             )
-            partition_by_columns.append(column)

-        # validate signal columns and add them to the schema
         if not kwargs:
             raise ValueError("At least one column should be provided for group_by")
         for col_name, func in kwargs.items():
@@ -838,9 +1339,9 @@ class DataChain:
             signal_columns.append(column)
             schema_fields[col_name] = func.get_result_type(self.signals_schema)

-        signal_schema =
-
-
+        signal_schema = self.signals_schema.group_by(
+            schema_partition_by, signal_columns
+        )

         return self._evolve(
             query=self._query.group_by(signal_columns, partition_by_columns),
@@ -848,17 +1349,13 @@ class DataChain:
         )

     def mutate(self, **kwargs) -> "Self":
-        """Create
-
-        This method cannot modify existing columns. If you need to modify an
-        existing column, use a different name for the new column and then use
-        `select()` to choose which columns to keep.
+        """Create or modify signals based on existing signals.

         This method is vectorized and more efficient compared to map(), and it does not
         extract or download any data from the internal database. However, it can only
         utilize predefined built-in functions and their combinations.

-
+        Supported functions:
            Numerical:   +, -, *, /, rand(), avg(), count(), func(),
                         greatest(), least(), max(), min(), sum()
            String:      length(), split(), replace(), regexp_replace()
@@ -871,7 +1368,7 @@ class DataChain:
         ```py
         dc.mutate(
             area=Column("image.height") * Column("image.width"),
-            extension=file_ext(Column("file.
+            extension=file_ext(Column("file.path")),
             dist=cosine_distance(embedding_text, embedding_image)
         )
         ```
@@ -885,13 +1382,20 @@ class DataChain:
         ```

         This method can be also used to rename signals. If the Column("name") provided
-        as value for the new signal - the old
-
+        as value for the new signal - the old signal will be dropped. Otherwise a new
+        signal is created. Exception, if the old signal is nested one (e.g.
+        `C("file.path")`), it will be kept to keep the object intact.

         Example:
         ```py
         dc.mutate(
-            newkey=Column("oldkey")
+            newkey=Column("oldkey")  # drops oldkey
+        )
+        ```
+
+        ```py
+        dc.mutate(
+            size=Column("file.size")  # keeps `file.size`
         )
         ```
         """
@@ -926,49 +1430,52 @@ class DataChain:
 # adding new signal
 mutated[name] = value

+new_schema = schema.mutate(kwargs)
 return self._evolve(
-query=self._query.mutate(**mutated),
+query=self._query.mutate(new_schema=new_schema, **mutated),
+signal_schema=new_schema,
 )

 @property
 def _effective_signals_schema(self) -> "SignalSchema":
-"""Effective schema used for user-facing API like
+"""Effective schema used for user-facing API like to_list, to_pandas, etc."""
 signals_schema = self.signals_schema
 if not self._sys:
 return signals_schema.clone_without_sys_signals()
 return signals_schema

 @overload
-def
+def _leaf_values(self) -> Iterator[tuple[Any, ...]]: ...

 @overload
-def
+def _leaf_values(self, *, include_hidden: bool) -> Iterator[tuple[Any, ...]]: ...

 @overload
-def
+def _leaf_values(
 self, *, row_factory: Callable[[list[str], tuple[Any, ...]], _T]
 ) -> Iterator[_T]: ...

 @overload
-def
+def _leaf_values(
 self,
 *,
 row_factory: Callable[[list[str], tuple[Any, ...]], _T],
 include_hidden: bool,
 ) -> Iterator[_T]: ...

-def
+def _leaf_values(self, *, row_factory=None, include_hidden: bool = True):
 """Yields flattened rows of values as a tuple.

 Args:
-row_factory
-
-
+row_factory: A callable to convert row to a custom format.
+It should accept two arguments: a list of column names and
+a tuple of row values.
 include_hidden: Whether to include hidden signals from the schema.
 """
 db_signals = self._effective_signals_schema.db_signals(
 include_hidden=include_hidden
 )
+
 with self._query.ordered_select(*db_signals).as_iterable() as rows:
 if row_factory:
 rows = (row_factory(db_signals, r) for r in rows) # type: ignore[assignment]
@@ -985,7 +1492,7 @@ class DataChain:
 headers, _ = self._effective_signals_schema.get_headers_with_length()
 column_names = [".".join(filter(None, header)) for header in headers]

-results_iter = self.
+results_iter = self._leaf_values()

 def column_chunks() -> Iterator[list[list[Any]]]:
 for chunk_iter in batched_it(results_iter, chunk_size):
@@ -1018,55 +1525,51 @@ class DataChain:

 def results(self, *, row_factory=None, include_hidden=True):
 if row_factory is None:
-return list(self.
+return list(self._leaf_values(include_hidden=include_hidden))
 return list(
-self.
+self._leaf_values(row_factory=row_factory, include_hidden=include_hidden)
 )

 def to_records(self) -> list[dict[str, Any]]:
 """Convert every row to a dictionary."""

 def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
-return dict(zip(cols, row))
+return dict(zip(cols, row, strict=False))

 return self.results(row_factory=to_dict)

-
-def collect(self) -> Iterator[tuple[DataValue, ...]]: ...
-
-@overload
-def collect(self, col: str) -> Iterator[DataValue]: ...
-
-@overload
-def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...
-
-def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]: # type: ignore[overload-overlap,misc]
+def to_iter(self, *cols: str) -> Iterator[tuple[DataValue, ...]]:
 """Yields rows of values, optionally limited to the specified columns.

 Args:
 *cols: Limit to the specified columns. By default, all columns are selected.

 Yields:
-(DataType): Yields a
-(tuple[DataType, ...]): Yields a tuple of items if multiple columns are
-selected.
+(tuple[DataType, ...]): Yields a tuple of items for each row.

 Example:
 Iterating over all rows:
 ```py
-for row in
+for row in ds.to_iter():
+print(row)
+```
+
+DataChain is iterable and can be used in a for loop directly which is
+equivalent to `ds.to_iter()`:
+```py
+for row in ds:
 print(row)
 ```

 Iterating over all rows with selected columns:
 ```py
-for name, size in
+for name, size in ds.to_iter("file.path", "file.size"):
 print(name, size)
 ```

 Iterating over a single column:
 ```py
-for file in
+for (file,) in ds.to_iter("file.path"):
 print(file)
 ```
 """
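A behavioural note on the new `to_iter` shown above: unlike the deprecated `collect` (see the shim added in the next hunk), it always yields tuples, even for a single column. A minimal sketch, assuming `ds` is an existing chain with the standard `file` signal:

```py
# Single column still comes back as 1-tuples; unpack explicitly.
for (path,) in ds.to_iter("file.path"):
    print(path)

# Iterating the chain itself is equivalent to ds.to_iter() over all columns.
first_rows = [row for _, row in zip(range(3), ds)]
```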
@@ -1078,7 +1581,31 @@ class DataChain:
 ret = signals_schema.row_to_features(
 row, catalog=chain.session.catalog, cache=chain._settings.cache
 )
-yield
+yield tuple(ret)
+
+@overload
+def collect(self) -> Iterator[tuple[DataValue, ...]]: ...
+
+@overload
+def collect(self, col: str) -> Iterator[DataValue]: ...
+
+@overload
+def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...
+
+def collect(self, *cols: str) -> Iterator[DataValue | tuple[DataValue, ...]]: # type: ignore[overload-overlap,misc]
+"""
+Deprecated. Use `to_iter` method instead.
+"""
+warnings.warn(
+"Method `collect` is deprecated. Use `to_iter` method instead.",
+DeprecationWarning,
+stacklevel=2,
+)
+
+if len(cols) == 1:
+yield from [item[0] for item in self.to_iter(*cols)]
+else:
+yield from self.to_iter(*cols)

 def to_pytorch(
 self,
@@ -1112,7 +1639,7 @@ class DataChain:
 if self._query.attached:
 chain = self
 else:
-chain = self.
+chain = self.persist()
 assert chain.name is not None # for mypy
 return PytorchDataset(
 chain.name,
@@ -1126,15 +1653,12 @@ class DataChain:
 remove_prefetched=remove_prefetched,
 )

-
-schema = self.signals_schema.clone_without_file_signals()
-return self.select(*schema.values.keys())
-
+@delta_disabled
 def merge(
 self,
 right_ds: "DataChain",
-on:
-right_on:
+on: MergeColType | Sequence[MergeColType],
+right_on: MergeColType | Sequence[MergeColType] | None = None,
 inner=False,
 full=False,
 rname="right_",
@@ -1202,8 +1726,8 @@ class DataChain:

 def _resolve(
 ds: DataChain,
-col:
-side:
+col: str | Function | sqlalchemy.ColumnElement,
+side: str | None,
 ):
 try:
 if isinstance(col, Function):
@@ -1216,7 +1740,7 @@ class DataChain:
 ops = [
 _resolve(self, left, "left")
 == _resolve(right_ds, right, "right" if right_on else None)
-for left, right in zip(on, right_on or on)
+for left, right in zip(on, right_on or on, strict=False)
 ]

 if errors:
@@ -1225,32 +1749,44 @@ class DataChain:
 )

 query = self._query.join(
-right_ds._query, sqlalchemy.and_(*ops), inner, full, rname
+right_ds._query, sqlalchemy.and_(*ops), inner, full, rname
 )
 query.feature_schema = None
 ds = self._evolve(query=query)

+# Note: merge drops sys signals from both sides, make sure to not include it
+# in the resulting schema
 signals_schema = self.signals_schema.clone_without_sys_signals()
 right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
-
-
-)
+
+ds.signals_schema = signals_schema.merge(right_signals_schema, rname)

 return ds

+@delta_disabled
 def union(self, other: "Self") -> "Self":
 """Return the set union of the two datasets.

 Parameters:
 other: chain whose rows will be added to `self`.
 """
+self_schema = self.signals_schema
+other_schema = other.signals_schema
+missing_left, missing_right = self_schema.compare_signals(other_schema)
+if missing_left or missing_right:
+raise UnionSchemaMismatchError.from_column_sets(
+missing_left,
+missing_right,
+)
+
+self.signals_schema = self_schema.clone_without_sys_signals()
 return self._evolve(query=self._query.union(other._query))

 def subtract( # type: ignore[override]
 self,
 other: "DataChain",
-on:
-right_on:
+on: str | Sequence[str] | None = None,
+right_on: str | Sequence[str] | None = None,
 ) -> "Self":
 """Remove rows that appear in another chain.

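The hunk above makes `union` validate schemas up front instead of failing later in the query. A hedged sketch of the new failure mode, assuming `left` and `right` are chains whose signal sets differ:

```py
try:
    combined = left.union(right)
except Exception as exc:  # UnionSchemaMismatchError raised by the comparison above
    print("cannot union, schemas differ:", exc)
```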
@@ -1307,58 +1843,51 @@ class DataChain:
 zip(
 self.signals_schema.resolve(*on).db_signals(),
 other.signals_schema.resolve(*right_on).db_signals(),
+strict=False,
 ) # type: ignore[arg-type]
 )
 return self._evolve(query=self._query.subtract(other._query, signals)) # type: ignore[arg-type]

-def
+def diff(
 self,
 other: "DataChain",
-on:
-right_on:
-compare:
-right_compare:
+on: str | Sequence[str],
+right_on: str | Sequence[str] | None = None,
+compare: str | Sequence[str] | None = None,
+right_compare: str | Sequence[str] | None = None,
 added: bool = True,
 deleted: bool = True,
 modified: bool = True,
 same: bool = False,
-status_col:
+status_col: str | None = None,
 ) -> "DataChain":
-"""
-
-
-
-
-for all rows. Beside additional diff column, new chain has schema of the chain
-on which method was called.
+"""Calculate differences between two chains.
+
+This method identifies records that are added, deleted, modified, or unchanged
+between two chains. It adds a status column with values: A=added, D=deleted,
+M=modified, S=same.

 Parameters:
-other: Chain to
-on: Column
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-added (bool): Whether to return added rows in resulting chain.
-deleted (bool): Whether to return deleted rows in resulting chain.
-modified (bool): Whether to return modified rows in resulting chain.
-same (bool): Whether to return unchanged rows in resulting chain.
-status_col (str): Name of the new column that is created in resulting chain
-representing diff status.
+other: Chain to compare against.
+on: Column(s) to match records between chains.
+right_on: Column(s) in the other chain to match against. Defaults to `on`.
+compare: Column(s) to check for changes.
+If not specified,all columns are used.
+right_compare: Column(s) in the other chain to compare against.
+Defaults to values of `compare`.
+added (bool): Include records that exist in this chain but not in the other.
+deleted (bool): Include records that exist only in the other chain.
+modified (bool): Include records that exist in both
+but have different values.
+same (bool): Include records that are identical in both chains.
+status_col (str): Name for the status column showing differences.
+
+Default behavior: By default, shows added, deleted, and modified records,
+but excludes unchanged records (same=False). Status column is not created.

 Example:
 ```py
-res = persons.
+res = persons.diff(
 new_persons,
 on=["id"],
 right_on=["other_id"],
@@ -1387,42 +1916,40 @@ class DataChain:
 status_col=status_col,
 )

-def
+def file_diff(
 self,
 other: "DataChain",
 on: str = "file",
-right_on:
+right_on: str | None = None,
 added: bool = True,
 modified: bool = True,
 deleted: bool = False,
 same: bool = False,
-status_col:
+status_col: str | None = None,
 ) -> "DataChain":
-"""
-
-
-
-`.compare()` user needs to provide arbitrary columns for matching and comparing.
+"""Calculate differences between two chains containing files.
+
+This method is specifically designed for file chains. It uses file `source`
+and `path` to match files, and file `version` and `etag` to detect changes.

 Parameters:
-other: Chain to
-on: File
-
-
-
-
-
-
-
-
-
-
-
-resulting chain representing diff status.
+other: Chain to compare against.
+on: File column name in this chain. Default is "file".
+right_on: File column name in the other chain. Defaults to `on`.
+added (bool): Include files that exist in this chain but not in the other.
+deleted (bool): Include files that exist only in the other chain.
+modified (bool): Include files that exist in both but have different
+versions/etags.
+same (bool): Include files that are identical in both chains.
+status_col (str): Name for the status column showing differences
+(A=added, D=deleted, M=modified, S=same).
+
+Default behavior: By default, includes only new files (added=True and
+modified=True). This is useful for incremental processing.

 Example:
 ```py
-diff = images.
+diff = images.file_diff(
 new_images,
 on="file",
 right_on="other_file",
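With `added=True` and `modified=True` as defaults, `file_diff` becomes the incremental-processing primitive described in the docstring above. A sketch, assuming a previously saved dataset named "images" and `dc.read_dataset` as the reader (both illustrative assumptions):

```py
import datachain as dc

seen = dc.read_dataset("images")                   # earlier snapshot (assumed name)
latest = dc.read_storage("s3://mybucket/images/")  # hypothetical bucket
to_process = latest.file_diff(seen, on="file")     # only new or changed files remain
```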
@@ -1447,7 +1974,7 @@ class DataChain:
 compare_cols = get_file_signals(on, compare_file_signals)
 right_compare_cols = get_file_signals(right_on, compare_file_signals)

-return self.
+return self.diff(
 other,
 on_cols,
 right_on=right_on_cols,
@@ -1492,47 +2019,67 @@ class DataChain:
 )
 return read_pandas(*args, **kwargs)

-def to_pandas(
+def to_pandas(
+self,
+flatten: bool = False,
+include_hidden: bool = True,
+as_object: bool = False,
+) -> "pd.DataFrame":
 """Return a pandas DataFrame from the chain.

 Parameters:
-flatten
-include_hidden
+flatten: Whether to use a multiindex or flatten column names.
+include_hidden: Whether to include hidden columns.
+as_object: Whether to emit a dataframe backed by Python objects
+rather than pandas-inferred dtypes.
+
+Returns:
+pd.DataFrame: A pandas DataFrame representation of the chain.
 """
 import pandas as pd

 headers, max_length = self._effective_signals_schema.get_headers_with_length(
 include_hidden=include_hidden
 )
+
+columns: list[str] | pd.MultiIndex
 if flatten or max_length < 2:
 columns = [".".join(filter(None, header)) for header in headers]
 else:
 columns = pd.MultiIndex.from_tuples(map(tuple, headers))

 results = self.results(include_hidden=include_hidden)
+if as_object:
+df = pd.DataFrame(results, columns=columns, dtype=object)
+df.where(pd.notna(df), None, inplace=True)
+return df
 return pd.DataFrame.from_records(results, columns=columns)

 def show(
 self,
 limit: int = 20,
-flatten=False,
-transpose=False,
-truncate=True,
-include_hidden=False,
+flatten: bool = False,
+transpose: bool = False,
+truncate: bool = True,
+include_hidden: bool = False,
 ) -> None:
 """Show a preview of the chain results.

 Parameters:
-limit
-flatten
-transpose
-truncate
-include_hidden
+limit: How many rows to show.
+flatten: Whether to use a multiindex or flatten column names.
+transpose: Whether to transpose rows and columns.
+truncate: Whether or not to truncate the contents of columns.
+include_hidden: Whether to include hidden columns.
 """
 import pandas as pd

 dc = self.limit(limit) if limit > 0 else self # type: ignore[misc]
-df = dc.to_pandas(
+df = dc.to_pandas(
+flatten,
+include_hidden=include_hidden,
+as_object=True,
+)

 if df.empty:
 print("Empty result")
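The new `as_object` flag above keeps cell values as Python objects instead of letting pandas infer dtypes, and missing values are normalized to `None`; `show()` now relies on it for display. A minimal sketch, assuming `chain` is any existing chain:

```py
df = chain.to_pandas(flatten=True)        # dotted column names, pandas-inferred dtypes
df_obj = chain.to_pandas(as_object=True)  # object dtype, NaN replaced with None
```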
@@ -1588,23 +2135,23 @@ class DataChain:
 def parse_tabular(
 self,
 output: OutputType = None,
-
+column: str = "",
 model_name: str = "",
 source: bool = True,
-nrows:
-**kwargs,
+nrows: int | None = None,
+**kwargs: Any,
 ) -> "Self":
 """Generate chain from list of tabular files.

 Parameters:
-output
+output: Dictionary or feature class defining column names and their
 corresponding types. List of column names is also accepted, in which
 case types will be inferred.
-
-model_name
-source
-nrows
-kwargs
+column: Generated column name.
+model_name: Generated model name.
+source: Whether to include info about the source file.
+nrows: Optional row limit.
+kwargs: Parameters to pass to pyarrow.dataset.dataset.

 Example:
 Reading a json lines file:
@@ -1619,24 +2166,33 @@ class DataChain:
 import datachain as dc

 chain = dc.read_storage("s3://mybucket")
-chain = chain.filter(dc.C("file.
+chain = chain.filter(dc.C("file.path").glob("*.jsonl"))
 chain = chain.parse_tabular(format="json")
 ```
 """
 from pyarrow.dataset import CsvFileFormat, JsonFileFormat

-from datachain.lib.arrow import
+from datachain.lib.arrow import (
+ArrowGenerator,
+fix_pyarrow_format,
+infer_schema,
+schema_to_output,
+)

-
-
-
-
-
-
-
-
-
-
+parse_options = kwargs.pop("parse_options", None)
+if format := kwargs.get("format"):
+kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
+if (
+nrows
+and format not in ["csv", "json"]
+and not isinstance(format, (CsvFileFormat, JsonFileFormat))
+):
+raise DatasetPrepareError(
+self.name,
+"error in `parse_tabular` - "
+"`nrows` only supported for csv and json formats.",
+)

 if "file" not in self.schema or not self.count():
 raise DatasetPrepareError(self.name, "no files to parse.")
@@ -1645,20 +2201,20 @@ class DataChain:
 col_names = output if isinstance(output, Sequence) else None
 if col_names or not output:
 try:
-schema = infer_schema(self, **kwargs)
+schema = infer_schema(self, **kwargs, parse_options=parse_options)
 output, _ = schema_to_output(schema, col_names)
 except ValueError as e:
 raise DatasetPrepareError(self.name, e) from e

 if isinstance(output, dict):
-model_name = model_name or
+model_name = model_name or column or ""
 model = dict_to_data_model(model_name, output)
 output = model
 else:
 model = output # type: ignore[assignment]

-if
-output = {
+if column:
+output = {column: model} # type: ignore[dict-item]
 elif isinstance(output, type(BaseModel)):
 output = {
 name: info.annotation # type: ignore[misc]
@@ -1671,7 +2227,15 @@ class DataChain:
 # disable prefetch if nrows is set
 settings = {"prefetch": 0} if nrows else {}
 return self.settings(**settings).gen( # type: ignore[arg-type]
-ArrowGenerator(
+ArrowGenerator(
+schema,
+model,
+source,
+nrows,
+parse_options=parse_options,
+**kwargs,
+),
+output=output,
 )

 @classmethod
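`parse_tabular` now pops a `parse_options` kwarg and forwards it to schema inference and the Arrow generator, and the `column` parameter nests parsed columns under one signal. The semicolon delimiter below is purely an illustrative assumption:

```py
import datachain as dc
from pyarrow import csv as pa_csv

chain = dc.read_storage("s3://mybucket/tables/")       # hypothetical source
parsed = chain.parse_tabular(
    format="csv",
    column="row",                                      # nest parsed columns under "row"
    nrows=100,                                         # nrows only works for csv/json
    parse_options=pa_csv.ParseOptions(delimiter=";"),  # assumed semicolon-delimited files
)
```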
@@ -1708,23 +2272,23 @@ class DataChain:

 def to_parquet(
 self,
-path:
-partition_cols:
+path: str | os.PathLike[str] | BinaryIO,
+partition_cols: Sequence[str] | None = None,
 chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
-fs_kwargs:
+fs_kwargs: dict[str, Any] | None = None,
 **kwargs,
 ) -> None:
 """Save chain to parquet file with SignalSchema metadata.

 Parameters:
-path
+path: Path or a file-like binary object to save the file. This supports
 local paths as well as remote paths, such as s3:// or hf:// with fsspec.
-partition_cols
-chunk_size
+partition_cols: Column names by which to partition the dataset.
+chunk_size: The chunk size of results to read and convert to columnar
 data, to avoid running out of memory.
-fs_kwargs
-
-
+fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+are supported.
 """
 import pyarrow as pa
 import pyarrow.parquet as pq
@@ -1754,9 +2318,9 @@ class DataChain:
 fsspec_fs = client.create_fs(**fs_kwargs)

 _partition_cols = list(partition_cols) if partition_cols else None
-signal_schema_metadata =
-self._effective_signals_schema.serialize()
-)
+signal_schema_metadata = json.dumps(
+self._effective_signals_schema.serialize(), ensure_ascii=False
+).encode("utf-8")

 column_names, column_chunks = self.to_columnar_data_with_names(chunk_size)

@@ -1768,7 +2332,7 @@ class DataChain:
 # pyarrow infers the best parquet schema from the python types of
 # the input data.
 table = pa.Table.from_pydict(
-dict(zip(column_names, chunk)),
+dict(zip(column_names, chunk, strict=False)),
 schema=parquet_schema,
 )

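The change above serializes the signal schema to UTF-8 JSON before attaching it as parquet metadata, so non-ASCII signal names survive a round trip; the call signature is otherwise unchanged. A small sketch with an assumed local target and an assumed `split` column for partitioning:

```py
chain.to_parquet("out/data.parquet")                            # single file
chain.to_parquet("out/partitioned/", partition_cols=["split"])  # "split" must exist on the chain
```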
@@ -1806,123 +2370,220 @@ class DataChain:

 def to_csv(
 self,
-path:
+path: str | os.PathLike[str],
 delimiter: str = ",",
-fs_kwargs:
+fs_kwargs: dict[str, Any] | None = None,
 **kwargs,
-) ->
-"""Save chain to a csv (comma-separated values) file
+) -> File:
+"""Save chain to a csv (comma-separated values) file and return the stored
+`File`.

 Parameters:
-path
+path: Path to save the file. This supports local paths as well as
 remote paths, such as s3:// or hf:// with fsspec.
-delimiter
-fs_kwargs
-
-
+delimiter: Delimiter to use for the resulting file.
+fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+are supported.
+Returns:
+File: The stored file with refreshed metadata (version, etag, size).
 """
 import csv

-
-
-if isinstance(path, str) and "://" in path:
-from datachain.client.fsspec import Client
-
-fs_kwargs = {
-**self._query.catalog.client_config,
-**(fs_kwargs or {}),
-}
-
-client = Client.get_implementation(path)
-
-fsspec_fs = client.create_fs(**fs_kwargs)
-
-opener = fsspec_fs.open
+target = File.at(path, session=self.session)

 headers, _ = self._effective_signals_schema.get_headers_with_length()
 column_names = [".".join(filter(None, header)) for header in headers]

-
-
-with opener(path, "w", newline="") as f:
+with target.open("w", newline="", client_config=fs_kwargs) as f:
 writer = csv.writer(f, delimiter=delimiter, **kwargs)
 writer.writerow(column_names)
-
-for row in results_iter:
+for row in self._leaf_values():
 writer.writerow(row)

+return target
+
 def to_json(
 self,
-path:
-fs_kwargs:
+path: str | os.PathLike[str],
+fs_kwargs: dict[str, Any] | None = None,
 include_outer_list: bool = True,
-) ->
-"""Save chain to a JSON file
+) -> File:
+"""Save chain to a JSON file and return the stored `File`.

 Parameters:
-path
+path: Path to save the file. This supports local paths as well as
 remote paths, such as s3:// or hf:// with fsspec.
-fs_kwargs
-
-
-include_outer_list
+fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+are supported.
+include_outer_list: Sets whether to include an outer list for all rows.
 Setting this to True makes the file valid JSON, while False instead
 writes in the JSON lines format.
+Returns:
+File: The stored file with refreshed metadata (version, etag, size).
 """
-
-
-if isinstance(path, str) and "://" in path:
-from datachain.client.fsspec import Client
-
-fs_kwargs = {
-**self._query.catalog.client_config,
-**(fs_kwargs or {}),
-}
-
-client = Client.get_implementation(path)
-
-fsspec_fs = client.create_fs(**fs_kwargs)
-
-opener = fsspec_fs.open
-
+target = File.at(path, session=self.session)
 headers, _ = self._effective_signals_schema.get_headers_with_length()
-headers = [list(filter(None,
+headers = [list(filter(None, h)) for h in headers]
+with target.open("wb", client_config=fs_kwargs) as f:
+self._write_json_stream(f, headers, include_outer_list)
+return target

+def _write_json_stream(
+self,
+f: IO[bytes],
+headers: list[list[str]],
+include_outer_list: bool,
+) -> None:
 is_first = True
-
-
-
-
-f.write(b"
-
-
-
-
-
-
-
-
-
-
-if include_outer_list:
-# This makes the file JSON instead of JSON lines.
-f.write(b"\n]\n")
+if include_outer_list:
+f.write(b"[\n")
+for row in self._leaf_values():
+if not is_first:
+f.write(b",\n" if include_outer_list else b"\n")
+else:
+is_first = False
+f.write(
+json.dumps(
+row_to_nested_dict(headers, row),
+ensure_ascii=False,
+).encode("utf-8")
+)
+if include_outer_list:
+f.write(b"\n]\n")

 def to_jsonl(
 self,
-path:
-fs_kwargs:
-) ->
+path: str | os.PathLike[str],
+fs_kwargs: dict[str, Any] | None = None,
+) -> File:
 """Save chain to a JSON lines file.

 Parameters:
-path
+path: Path to save the file. This supports local paths as well as
 remote paths, such as s3:// or hf:// with fsspec.
-fs_kwargs
-
-
+fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+are supported.
+Returns:
+File: The stored file with refreshed metadata (version, etag, size).
 """
-self.to_json(path, fs_kwargs, include_outer_list=False)
+return self.to_json(path, fs_kwargs, include_outer_list=False)
+
+def to_database(
+self,
+table_name: str,
+connection: "ConnectionType",
+*,
+batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
+on_conflict: str | None = None,
+conflict_columns: list[str] | None = None,
+column_mapping: dict[str, str | None] | None = None,
+) -> int:
+"""Save chain to a database table using a given database connection.
+
+This method exports all DataChain records to a database table, creating the
+table if it doesn't exist and appending data if it does. The table schema
+is automatically inferred from the DataChain's signal schema.
+
+For PostgreSQL, tables are created in the schema specified by the connection's
+search_path (defaults to 'public'). Use URL parameters to target specific
+schemas.
+
+Parameters:
+table_name: Name of the database table to create/write to.
+connection: SQLAlchemy connectable, str, or a sqlite3 connection
+Using SQLAlchemy makes it possible to use any DB supported by that
+library. If a DBAPI2 object, only sqlite3 is supported. The user is
+responsible for engine disposal and connection closure for the
+SQLAlchemy connectable; str connections are closed automatically.
+batch_size: Number of rows to insert per batch for optimal performance.
+Larger batches are faster but use more memory. Default: 10,000.
+on_conflict: Strategy for handling duplicate rows (requires table
+constraints):
+- None: Raise error (`sqlalchemy.exc.IntegrityError`) on conflict
+(default)
+- "ignore": Skip duplicate rows silently
+- "update": Update existing rows with new values
+conflict_columns: List of column names that form a unique constraint
+for conflict resolution. Required when on_conflict='update' and
+using PostgreSQL.
+column_mapping: Optional mapping to rename or skip columns:
+- Dict mapping DataChain column names to database column names
+- Set values to None to skip columns entirely, or use `defaultdict` to
+skip all columns except those specified.
+
+Returns:
+int: Number of rows affected (inserted/updated). -1 if DB driver doesn't
+support telemetry.
+
+Examples:
+Basic usage with PostgreSQL:
+```py
+import datachain as dc
+
+rows_affected = (dc
+.read_storage("s3://my-bucket/")
+.to_database("files_table", "postgresql://user:pass@localhost/mydb")
+)
+print(f"Inserted/updated {rows_affected} rows")
+```
+
+Using SQLite with connection string:
+```py
+rows_affected = chain.to_database("my_table", "sqlite:///data.db")
+print(f"Affected {rows_affected} rows")
+```
+
+Column mapping and renaming:
+```py
+mapping = {
+"user.id": "id",
+"user.name": "name",
+"user.password": None # Skip this column
+}
+chain.to_database("users", engine, column_mapping=mapping)
+```
+
+Handling conflicts (requires PRIMARY KEY or UNIQUE constraints):
+```py
+# Skip duplicates
+chain.to_database("my_table", engine, on_conflict="ignore")
+
+# Update existing records
+chain.to_database(
+"my_table", engine, on_conflict="update", conflict_columns=["id"]
+)
+```
+
+Working with different databases:
+```py
+# MySQL
+mysql_engine = sa.create_engine("mysql+pymysql://user:pass@host/db")
+chain.to_database("mysql_table", mysql_engine)
+
+# SQLite in-memory
+chain.to_database("temp_table", "sqlite:///:memory:")
+```
+
+PostgreSQL with schema support:
+```py
+pg_url = "postgresql://user:pass@host/db?options=-c search_path=analytics"
+chain.to_database("processed_data", pg_url)
+```
+"""
+from .database import to_database
+
+return to_database(
+self,
+table_name,
+connection,
+batch_size=batch_size,
+on_conflict=on_conflict,
+conflict_columns=conflict_columns,
+column_mapping=column_mapping,
+)

 @classmethod
 def from_records(
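The exporters rewritten above now write through `File` and return it with refreshed metadata. A sketch, assuming a destination the session can write to:

```py
stored = chain.to_csv("s3://mybucket/export.csv")    # hypothetical destination
print(stored.path, stored.size, stored.etag)         # metadata refreshed after the write
lines = chain.to_jsonl("s3://mybucket/export.jsonl")
```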
@@ -1940,28 +2601,85 @@ class DataChain:
 )
 return read_records(*args, **kwargs)

-def sum(self,
-"""Compute the sum of a column.
-
+def sum(self, col: str) -> StandardType: # type: ignore[override]
+"""Compute the sum of a column.
+
+Parameters:
+col: The column to compute the sum for.
+
+Returns:
+The sum of the column values.
+
+Example:
+```py
+total_size = chain.sum("file.size")
+print(f"Total size: {total_size}")
+```
+"""
+return self._extend_to_data_model("sum", col)
+
+def avg(self, col: str) -> StandardType: # type: ignore[override]
+"""Compute the average of a column.
+
+Parameters:
+col: The column to compute the average for.
+
+Returns:
+The average of the column values.
+
+Example:
+```py
+average_size = chain.avg("file.size")
+print(f"Average size: {average_size}")
+```
+"""
+return self._extend_to_data_model("avg", col)
+
+def min(self, col: str) -> StandardType: # type: ignore[override]
+"""Compute the minimum of a column.

-
-
-return self._extend_to_data_model("avg", fr)
+Parameters:
+col: The column to compute the minimum for.

-
-
-return self._extend_to_data_model("min", fr)
+Returns:
+The minimum value in the column.

-
-
-
+Example:
+```py
+min_size = chain.min("file.size")
+print(f"Minimum size: {min_size}")
+```
+"""
+return self._extend_to_data_model("min", col)
+
+def max(self, col: str) -> StandardType: # type: ignore[override]
+"""Compute the maximum of a column.
+
+Parameters:
+col: The column to compute the maximum for.
+
+Returns:
+The maximum value in the column.
+
+Example:
+```py
+max_size = chain.max("file.size")
+print(f"Maximum size: {max_size}")
+```
+"""
+return self._extend_to_data_model("max", col)

 def setup(self, **kwargs) -> "Self":
 """Setup variables to pass to UDF functions.

-Use before running map/gen/agg
+Use before running map/gen/agg to save an object and pass it as an
 argument to the UDF.

+The value must be a callable (a `lambda: <value>` syntax can be used to quickly
+create one) that returns the object to be passed to the UDF. It is evaluated
+lazily when UDF is running, in case of multiple machines the callable is run on
+a worker machine.
+
 Example:
 ```py
 import anthropic
@@ -1971,7 +2689,11 @@ class DataChain:
 (
 dc.read_storage(DATA, type="text")
 .settings(parallel=4, cache=True)
+
+# Setup Anthropic client and pass it to the UDF below automatically
+# The value is callable (see the note above)
 .setup(client=lambda: anthropic.Anthropic(api_key=API_KEY))
+
 .map(
 claude=lambda client, file: client.messages.create(
 model=MODEL,
@@ -1993,13 +2715,13 @@ class DataChain:

 def to_storage(
 self,
-output:
+output: str | os.PathLike[str],
 signal: str = "file",
 placement: FileExportPlacement = "fullpath",
 link_type: Literal["copy", "symlink"] = "copy",
-num_threads:
-anon: bool =
-client_config:
+num_threads: int | None = EXPORT_FILES_MAX_THREADS,
+anon: bool | None = None,
+client_config: dict | None = None,
 ) -> None:
 """Export files from a specified signal to a directory. Files can be
 exported to a local or cloud directory.
@@ -2008,12 +2730,28 @@ class DataChain:
 output: Path to the target directory for exporting files.
 signal: Name of the signal to export files from.
 placement: The method to use for naming exported files.
-The possible values are: "filename", "etag", "fullpath",
+The possible values are: "filename", "etag", "fullpath",
+"filepath", and "checksum".
+Example path translations for an object located at
+``s3://bucket/data/img.jpg`` and exported to ``./out``:
+
+- "filename" -> ``./out/img.jpg`` (no directories)
+- "filepath" -> ``./out/data/img.jpg`` (relative path kept)
+- "fullpath" -> ``./out/bucket/data/img.jpg`` (remote host kept)
+- "etag" -> ``./out/<etag>.jpg`` (unique name via object digest)
+
+Local sources behave like "filepath" for "fullpath" placement.
+Relative destinations such as "." or ".." and absolute paths
+are supported for every strategy.
 link_type: Method to use for exporting files.
 Falls back to `'copy'` if symlinking fails.
-num_threads
-By default it uses 5 threads.
-anon: If
+num_threads: number of threads to use for exporting files.
+By default, it uses 5 threads.
+anon: If True, we will treat cloud bucket as a public one. Default behavior
+depends on the previous session configuration (e.g. happens in the
+initial `read_storage`) and particular cloud storage client
+implementation (e.g. S3 fallbacks to anonymous access if no credentials
+were found).
 client_config: Optional configuration for the destination storage client

 Example:
@@ -2025,21 +2763,23 @@ class DataChain:
 ds.to_storage("gs://mybucket", placement="filename")
 ```
 """
+chain = self.persist()
+count = chain.count()
+
 if placement == "filename" and (
-
-!= self._query.count()
+chain._query.distinct(pathfunc.name(C(f"{signal}__path"))).count() != count
 ):
 raise ValueError("Files with the same name found")

-if anon:
-client_config = (client_config or {}) | {"anon":
+if anon is not None:
+client_config = (client_config or {}) | {"anon": anon}

 progress_bar = tqdm(
 desc=f"Exporting files to {output}: ",
 unit=" files",
 unit_scale=True,
 unit_divisor=10,
-total=
+total=count,
 leave=False,
 )
 file_exporter = FileExporter(
@@ -2050,20 +2790,36 @@ class DataChain:
 max_threads=num_threads or 1,
 client_config=client_config,
 )
-file_exporter.run(
+file_exporter.run(
+(rows[0] for rows in chain.to_iter(signal)),
+progress_bar,
+)

 def shuffle(self) -> "Self":
-"""Shuffle
-
+"""Shuffle rows with a best-effort deterministic ordering.
+
+This produces repeatable shuffles. Merge and union operations can
+lead to non-deterministic results. Use order by or save a dataset
+afterward to guarantee the same result.
+"""
+query = self._query.clone(new_table=False)
+query.steps.append(RegenerateSystemColumns(self._query.catalog))

-
+chain = self._evolve(
+query=query,
+signal_schema=SignalSchema({"sys": Sys}) | self.signals_schema,
+)
+return chain.order_by("sys.rand")
+
+def sample(self, n: int) -> "Self":
 """Return a random sample from the chain.

 Parameters:
-n
+n: Number of samples to draw.

-
-
+Note:
+Samples are not deterministic, and streamed/paginated queries or
+multiple workers will draw samples with replacement.
 """
 return self._evolve(query=self._query.sample(n))

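Per the new docstrings above, `shuffle` now produces a repeatable ordering over a regenerated `sys.rand` column, while `sample` stays non-deterministic. A minimal sketch:

```py
preview = chain.shuffle().limit(10)  # same rows on re-run, subject to the docstring's caveats
rough = chain.sample(1000)           # may repeat rows across workers or paginated reads
```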
@@ -2078,27 +2834,62 @@ class DataChain:

 Using glob to match patterns
 ```py
-dc.filter(C("file.
+dc.filter(C("file.path").glob("*.jpg"))
+```
+
+Using in to match lists
+```py
+ids = [1,2,3]
+dc.filter(C("experiment_id").in_(ids))
 ```

 Using `datachain.func`
 ```py
 from datachain.func import string
-dc.filter(string.length(C("file.
+dc.filter(string.length(C("file.path")) > 5)
 ```

 Combining filters with "or"
 ```py
-dc.filter(
+dc.filter(
+C("file.path").glob("cat*") |
+C("file.path").glob("dog*")
+)
+```
+
+```py
+dc.filter(dc.func.or_(
+C("file.path").glob("cat*"),
+C("file.path").glob("dog*")
+))
 ```

 Combining filters with "and"
 ```py
 dc.filter(
-C("file.
-
+C("file.path").glob("*.jpg"),
+string.length(C("file.path")) > 5
+)
+```
+
+```py
+dc.filter(
+C("file.path").glob("*.jpg") &
+(string.length(C("file.path")) > 5)
 )
 ```
+
+```py
+dc.filter(dc.func.and_(
+C("file.path").glob("*.jpg"),
+string.length(C("file.path")) > 5
+))
+```
+
+Combining filters with "not"
+```py
+dc.filter(~(C("file.path").glob("*.jpg")))
+```
 """
 return self._evolve(query=self._query.filter(*args))

@@ -2135,6 +2926,10 @@ class DataChain:
 def chunk(self, index: int, total: int) -> "Self":
 """Split a chain into smaller chunks for e.g. parallelization.

+Parameters:
+index: The index of the chunk (0-indexed).
+total: The total number of chunks.
+
 Example:
 ```py
 import datachain as dc
@@ -2149,3 +2944,72 @@ class DataChain:
 Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
 """
 return self._evolve(query=self._query.chunk(index, total))
+
+def to_list(self, *cols: str) -> list[tuple[DataValue, ...]]:
+"""Returns a list of rows of values, optionally limited to the specified
+columns.
+
+Parameters:
+*cols: Limit to the specified columns. By default, all columns are selected.
+
+Returns:
+list[tuple[DataType, ...]]: Returns a list of tuples of items for each row.
+
+Example:
+Getting all rows as a list:
+```py
+rows = dc.to_list()
+print(rows)
+```
+
+Getting all rows with selected columns as a list:
+```py
+name_size_pairs = dc.to_list("file.path", "file.size")
+print(name_size_pairs)
+```
+
+Getting a single column as a list:
+```py
+files = dc.to_list("file.path")
+print(files) # Returns list of 1-tuples
+```
+"""
+return list(self.to_iter(*cols))
+
+def to_values(self, col: str) -> list[DataValue]:
+"""Returns a flat list of values from a single column.
+
+Parameters:
+col: The name of the column to extract values from.
+
+Returns:
+list[DataValue]: Returns a flat list of values from the specified column.
+
+Example:
+Getting all values from a single column:
+```py
+file_paths = dc.to_values("file.path")
+print(file_paths) # Returns list of strings
+```
+
+Getting all file sizes:
+```py
+sizes = dc.to_values("file.size")
+print(sizes) # Returns list of integers
+```
+"""
+return [row[0] for row in self.to_list(col)]
+
+def __iter__(self) -> Iterator[tuple[DataValue, ...]]:
+"""Make DataChain objects iterable.
+
+Yields:
+(tuple[DataValue, ...]): Yields tuples of all column values for each row.
+
+Example:
+```py
+for row in chain:
+print(row)
+```
+"""
+return self.to_iter()
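The three accessors added at the end of the class differ only in shape; a quick side-by-side, assuming `chain` has the standard `file` signal:

```py
rows = chain.to_list("file.path", "file.size")  # list of tuples
paths = chain.to_values("file.path")            # flat list of values
same_rows = list(chain)                         # __iter__ delegates to to_iter()
```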