datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff compares the contents of the two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
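The `datachain/lib/dc/datachain.py` diff below adds a public `DataChain.hash()` method that fingerprints a chain from its ordered steps and their inputs. A minimal sketch of how it might be used, assuming the public `read_values`/`filter` API; the values are illustrative only:

```python
import datachain as dc

chain = dc.read_values(num=[1, 2, 3]).filter(dc.C("num") > 1)

print(chain.hash())           # fingerprint of the steps composed so far
print(chain.limit(1).hash())  # appending a step changes the fingerprint
```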
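The new checkpoint handling in `save()` (see `_calculate_job_hash` in the diff below) chains SHA-256 digests: each checkpoint hash is the digest of the previous checkpoint's bytes concatenated with the current chain's hash bytes. A self-contained sketch of that chaining, using made-up hex values:

```python
import hashlib


def chained_checkpoint_hash(previous_hex: str | None, chain_hash_hex: str) -> str:
    # new checkpoint = sha256(previous checkpoint bytes + this chain's hash bytes),
    # mirroring _calculate_job_hash in the diff below
    prev = bytes.fromhex(previous_hex) if previous_hex else b""
    return hashlib.sha256(prev + bytes.fromhex(chain_hash_hex)).hexdigest()


first = chained_checkpoint_hash(None, "ab" * 32)    # first save() in a job
second = chained_checkpoint_hash(first, "cd" * 32)  # a later save() chains onto it
print(first, second)
```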
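`settings()` now spells out its keyword arguments, including the new `batch_size` knob that is threaded through to `to_udf_wrapper` in `map`/`gen`/`agg`/`batch_map`. A hedged usage sketch; `double` is a placeholder UDF, not part of the package:

```python
import datachain as dc


def double(num: int) -> int:
    return num * 2


chain = (
    dc.read_values(num=list(range(1_000)))
    .settings(cache=True, parallel=4, batch_size=500)  # explicit keywords, new batch_size
    .map(doubled=double)
)

# reset_settings() restores every setting to its default value.
chain = chain.reset_settings()
```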
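`union()` now compares both schemas up front and raises `UnionSchemaMismatchError` (imported from `datachain.query.dataset` in the diff below) instead of silently combining mismatched columns. A small sketch under that assumption:

```python
import datachain as dc
from datachain.query.dataset import UnionSchemaMismatchError

left = dc.read_values(name=["a", "b"])
right = dc.read_values(name=["c"], score=[0.5])  # extra "score" signal

try:
    left.union(right)
except UnionSchemaMismatchError as exc:
    print("schemas differ:", exc)
```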
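`to_csv()`, `to_json()`, and `to_jsonl()` now write through a `File` object and return it, so the export target (with refreshed metadata) can feed later steps. A sketch with a hypothetical local output path:

```python
import datachain as dc

chain = dc.read_values(name=["a", "b"], score=[0.1, 0.2])

stored = chain.to_csv("rows.csv")              # hypothetical path
print(stored.path, stored.size, stored.etag)   # metadata is refreshed after the write

jsonl_file = chain.to_jsonl("rows.jsonl")
```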
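`shuffle()` now regenerates the `sys` columns and orders by `sys.rand`, which makes the shuffle repeatable on re-runs (best effort; merge/union can still break determinism, per the docstring in the diff below). A short sketch of the usual sampling pattern:

```python
import datachain as dc

chain = dc.read_values(num=list(range(100)))

# Re-running this chain is expected to reproduce the same order, since shuffle()
# regenerates sys.rand and orders by it.
sample = chain.shuffle().limit(10)
sample.show()
```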
datachain/lib/dc/datachain.py
CHANGED
@@ -1,37 +1,40 @@
 import copy
+import hashlib
+import logging
 import os
 import os.path
 import sys
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from typing import (
     IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
-    Callable,
     ClassVar,
     Literal,
-    Optional,
     TypeVar,
-    Union,
     cast,
     overload,
 )

 import sqlalchemy
-import ujson as json
 from pydantic import BaseModel
 from sqlalchemy.sql.elements import ColumnElement
 from tqdm import tqdm

-from datachain import semver
+from datachain import json, semver
 from datachain.dataset import DatasetRecord
 from datachain.delta import delta_disabled
-from datachain.error import
+from datachain.error import (
+    JobAncestryDepthExceededError,
+    ProjectCreateNotAllowedError,
+    ProjectNotFoundError,
+)
 from datachain.func import literal
 from datachain.func.base import Function
 from datachain.func.func import Func
+from datachain.job import Job
 from datachain.lib.convert.python_to_sql import python_to_sql
 from datachain.lib.data_model import (
     DataModel,
@@ -40,11 +43,7 @@ from datachain.lib.data_model import (
     StandardType,
     dict_to_data_model,
 )
-from datachain.lib.file import
-    EXPORT_FILES_MAX_THREADS,
-    ArrowRow,
-    FileExporter,
-)
+from datachain.lib.file import EXPORT_FILES_MAX_THREADS, ArrowRow, File, FileExporter
 from datachain.lib.file import ExportPlacement as FileExportPlacement
 from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
@@ -52,11 +51,17 @@ from datachain.lib.signal_schema import SignalResolvingError, SignalSchema
 from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
 from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import DataChainColumnError, DataChainParamsError
+from datachain.project import Project
 from datachain.query import Session
-from datachain.query.dataset import
+from datachain.query.dataset import (
+    DatasetQuery,
+    PartitionByType,
+    RegenerateSystemColumns,
+    UnionSchemaMismatchError,
+)
 from datachain.query.schema import DEFAULT_DELIMITER, Column
 from datachain.sql.functions import path as pathfunc
-from datachain.utils import batched_it, inside_notebook, row_to_nested_dict
+from datachain.utils import batched_it, env2bool, inside_notebook, row_to_nested_dict

 from .database import DEFAULT_DATABASE_BATCH_SIZE
 from .utils import (
@@ -71,6 +76,8 @@ from .utils import (
     resolve_columns,
 )

+logger = logging.getLogger("datachain")
+
 C = Column

 _T = TypeVar("_T")
@@ -82,19 +89,20 @@ if TYPE_CHECKING:
     import sqlite3

     import pandas as pd
+    from sqlalchemy.orm import Session as OrmSession
     from typing_extensions import ParamSpec, Self

     P = ParamSpec("P")

-    ConnectionType =
-        str
-        sqlalchemy.engine.URL
-        sqlalchemy.engine.interfaces.Connectable
-        sqlalchemy.engine.Engine
-        sqlalchemy.engine.Connection
-
-        sqlite3.Connection
-
+    ConnectionType = (
+        str
+        | sqlalchemy.engine.URL
+        | sqlalchemy.engine.interfaces.Connectable
+        | sqlalchemy.engine.Engine
+        | sqlalchemy.engine.Connection
+        | OrmSession
+        | sqlite3.Connection
+    )


 T = TypeVar("T", bound="DataChain")
@@ -183,7 +191,7 @@ class DataChain:
         query: DatasetQuery,
         settings: Settings,
         signal_schema: SignalSchema,
-        setup:
+        setup: dict | None = None,
         _sys: bool = False,
     ) -> None:
         """Don't instantiate this directly, use one of the from_XXX constructors."""
@@ -193,10 +201,11 @@ class DataChain:
         self._setup: dict = setup or {}
         self._sys = _sys
         self._delta = False
-        self.
-        self.
-        self.
-        self.
+        self._delta_unsafe = False
+        self._delta_on: str | Sequence[str] | None = None
+        self._delta_result_on: str | Sequence[str] | None = None
+        self._delta_compare: str | Sequence[str] | None = None
+        self._delta_retry: bool | str | None = None

     def __repr__(self) -> str:
         """Return a string representation of the chain."""
@@ -210,12 +219,21 @@ class DataChain:
         self.print_schema(file=file)
         return file.getvalue()

+    def hash(self) -> str:
+        """
+        Calculates SHA hash of this chain. Hash calculation is fast and consistent.
+        It takes into account all the steps added to the chain and their inputs.
+        Order of the steps is important.
+        """
+        return self._query.hash()
+
     def _as_delta(
         self,
-        on:
-        right_on:
-        compare:
-        delta_retry:
+        on: str | Sequence[str] | None = None,
+        right_on: str | Sequence[str] | None = None,
+        compare: str | Sequence[str] | None = None,
+        delta_retry: bool | str | None = None,
+        delta_unsafe: bool = False,
     ) -> "Self":
         """Marks this chain as delta, which means special delta process will be
         called on saving dataset for optimization"""
@@ -226,6 +244,7 @@ class DataChain:
         self._delta_result_on = right_on
         self._delta_compare = compare
         self._delta_retry = delta_retry
+        self._delta_unsafe = delta_unsafe
         return self

     @property
@@ -238,6 +257,10 @@ class DataChain:
         """Returns True if this chain is ran in "delta" update mode"""
         return self._delta

+    @property
+    def delta_unsafe(self) -> bool:
+        return self._delta_unsafe
+
     @property
     def schema(self) -> dict[str, DataType]:
         """Get schema of the chain."""
@@ -259,7 +282,7 @@ class DataChain:

         raise ValueError(f"Column with name {name} not found in the schema")

-    def c(self, column:
+    def c(self, column: str | Column) -> Column:
         """Returns Column instance attached to the current chain."""
         c = self.column(column) if isinstance(column, str) else self.column(column.name)
         c.table = self._query.table
@@ -271,17 +294,17 @@ class DataChain:
         return self._query.session

     @property
-    def name(self) ->
+    def name(self) -> str | None:
         """Name of the underlying dataset, if there is one."""
         return self._query.name

     @property
-    def version(self) ->
+    def version(self) -> str | None:
         """Version of the underlying dataset, if there is one."""
         return self._query.version

     @property
-    def dataset(self) ->
+    def dataset(self) -> DatasetRecord | None:
         """Underlying dataset, if there is one."""
         if not self.name:
             return None
@@ -295,7 +318,7 @@ class DataChain:
         """Return `self.union(other)`."""
         return self.union(other)

-    def print_schema(self, file:
+    def print_schema(self, file: IO | None = None) -> None:
         """Print schema of the chain."""
         self._effective_signals_schema.print_tree(file=file)

@@ -306,8 +329,8 @@ class DataChain:
     def _evolve(
         self,
         *,
-        query:
-        settings:
+        query: DatasetQuery | None = None,
+        settings: Settings | None = None,
         signal_schema=None,
         _sys=None,
     ) -> "Self":
@@ -328,46 +351,51 @@ class DataChain:
                 right_on=self._delta_result_on,
                 compare=self._delta_compare,
                 delta_retry=self._delta_retry,
+                delta_unsafe=self._delta_unsafe,
             )

         return chain

     def settings(
         self,
-        cache=None,
-
-
-
-
-
-
-
-
+        cache: bool | None = None,
+        prefetch: bool | int | None = None,
+        parallel: bool | int | None = None,
+        workers: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
+        sys: bool | None = None,
     ) -> "Self":
-        """
-
-
-
+        """
+        Set chain execution parameters. Returns the chain itself, allowing method
+        chaining for subsequent operations. To restore all settings to their default
+        values, use `reset_settings()`.

         Parameters:
-            cache
-
-
-
-
-
-
-
-
-
-
-
+            cache: Enable files caching to speed up subsequent accesses to the same
+                files from the same or different chains. Defaults to False.
+            prefetch: Enable prefetching of files. This will download files in
+                advance in parallel. If an integer is provided, it specifies the number
+                of files to prefetch concurrently for each process on each worker.
+                Defaults to 2. Set to 0 or False to disable prefetching.
+            parallel: Number of processes to use for processing user-defined functions
+                (UDFs) in parallel. If an integer is provided, it specifies the number
+                of CPUs to use. If True, all available CPUs are used. Defaults to 1.
+            namespace: Namespace to use for the chain by default.
+            project: Project to use for the chain by default.
+            min_task_size: Minimum number of rows per worker/process for parallel
+                processing by UDFs. Defaults to 1.
+            batch_size: Number of rows per insert by UDF to fine tune and balance speed
+                and memory usage. This might be useful when processing large rows
+                or when running into memory issues. Defaults to 2000.

         Example:
             ```py
             chain = (
                 chain
-                .settings(cache=True, parallel=8,
+                .settings(cache=True, parallel=8, batch_size=300)
                 .map(laion=process_webdataset(spec=WDSLaion), params="file")
             )
             ```
@@ -377,20 +405,20 @@ class DataChain:
         settings = copy.copy(self._settings)
         settings.add(
             Settings(
-                cache,
-
-
-
-
-
-
-
+                cache=cache,
+                prefetch=prefetch,
+                parallel=parallel,
+                workers=workers,
+                namespace=namespace,
+                project=project,
+                min_task_size=min_task_size,
+                batch_size=batch_size,
             )
         )
         return self._evolve(settings=settings, _sys=sys)

-    def reset_settings(self, settings:
-        """Reset all settings to default values."""
+    def reset_settings(self, settings: Settings | None = None) -> "Self":
+        """Reset all chain settings to default values."""
         self._settings = settings if settings else Settings()
         return self

@@ -441,8 +469,8 @@ class DataChain:
     def explode(
         self,
         col: str,
-        model_name:
-        column:
+        model_name: str | None = None,
+        column: str | None = None,
         schema_sample_size: int = 1,
     ) -> "DataChain":
         """Explodes a column containing JSON objects (dict or str DataChain type) into
@@ -483,7 +511,7 @@ class DataChain:

         model = dict_to_data_model(model_name, output, original_names)

-        def json_to_model(json_value:
+        def json_to_model(json_value: str | dict):
             json_dict = (
                 json.loads(json_value) if isinstance(json_value, str) else json_value
             )
@@ -557,116 +585,258 @@ class DataChain:
             create=True,
         )
         return self._evolve(
-            query=self._query.save(project=project, feature_schema=schema)
+            query=self._query.save(project=project, feature_schema=schema),
+            signal_schema=self.signals_schema | SignalSchema({"sys": Sys}),
         )

+    def _calculate_job_hash(self, job_id: str) -> str:
+        """
+        Calculates hash of the job at the place of this chain's save method.
+        Hash is calculated using previous job checkpoint hash (if exists) and
+        adding hash of this chain to produce new hash.
+        """
+        last_checkpoint = self.session.catalog.metastore.get_last_checkpoint(job_id)
+
+        return hashlib.sha256(
+            (bytes.fromhex(last_checkpoint.hash) if last_checkpoint else b"")
+            + bytes.fromhex(self.hash())
+        ).hexdigest()
+
     def save(  # type: ignore[override]
         self,
         name: str,
-        version:
-        description:
-        attrs:
-        update_version:
+        version: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
         **kwargs,
     ) -> "DataChain":
         """Save to a Dataset. It returns the chain itself.

         Parameters:
-            name
-
-                case
-
-            version
+            name: dataset name. This can be either a fully qualified name, including
+                the namespace and project, or just a regular dataset name. In the latter
+                case, the namespace and project will be taken from the settings
+                (if specified) or from the default values otherwise.
+            version: version of a dataset. If version is not specified and dataset
                 already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
-            description
-            attrs
+            description: description of a dataset.
+            attrs: attributes of a dataset. They can be without value, e.g "NLP",
                 or with a value, e.g "location=US".
             update_version: which part of the dataset version to automatically increase.
                 Available values: `major`, `minor` or `patch`. Default is `patch`.
         """
+
         catalog = self.session.catalog
-        if version is not None:
-            semver.validate(version)

-
-
-
-
-
-
-
-
-        )
+        result = None  # result chain that will be returned at the end
+
+        # Version validation
+        self._validate_version(version)
+        self._validate_update_version(update_version)
+
+        # get existing job if running in SaaS, or creating new one if running locally
+        job = self.session.get_or_create_job()

         namespace_name, project_name, name = catalog.get_full_dataset_name(
             name,
             namespace_name=self._settings.namespace,
             project_name=self._settings.project,
         )
+        project = self._get_or_create_project(namespace_name, project_name)
+
+        # Checkpoint handling
+        _hash, result = self._resolve_checkpoint(name, project, job, kwargs)
+        if bool(result):
+            # Checkpoint was found and reused
+            print(f"Checkpoint found for dataset '{name}', skipping creation")
+
+        # Schema preparation
+        schema = self.signals_schema.clone_without_sys_signals().serialize()
+
+        # Handle retry and delta functionality
+        if not result:
+            result = self._handle_delta(name, version, project, schema, kwargs)
+
+        if not result:
+            # calculate chain if we already don't have result from checkpoint or delta
+            result = self._evolve(
+                query=self._query.save(
+                    name=name,
+                    version=version,
+                    project=project,
+                    description=description,
+                    attrs=attrs,
+                    feature_schema=schema,
+                    update_version=update_version,
+                    **kwargs,
+                )
+            )

+        catalog.metastore.create_checkpoint(job.id, _hash)  # type: ignore[arg-type]
+        return result
+
+    def _validate_version(self, version: str | None) -> None:
+        """Validate dataset version if provided."""
+        if version is not None:
+            semver.validate(version)
+
+    def _validate_update_version(self, update_version: str | None) -> None:
+        """Ensure update_version is one of: major, minor, patch."""
+        allowed = ["major", "minor", "patch"]
+        if update_version not in allowed:
+            raise ValueError(f"update_version must be one of {allowed}")
+
+    def _get_or_create_project(self, namespace: str, project_name: str) -> Project:
+        """Get project or raise if creation not allowed."""
         try:
-
+            return self.session.catalog.metastore.get_project(
                 project_name,
-
+                namespace,
                 create=is_studio(),
             )
         except ProjectNotFoundError as e:
-            # not being able to create it as creation is not allowed
             raise ProjectCreateNotAllowedError("Creating project is not allowed") from e

-
+    def _resolve_checkpoint(
+        self,
+        name: str,
+        project: Project,
+        job: Job,
+        kwargs: dict,
+    ) -> tuple[str, "DataChain | None"]:
+        """Check if checkpoint exists and return cached dataset if possible."""
+        from .datasets import read_dataset

-
-
-        from datachain.delta import delta_retry_update
+        metastore = self.session.catalog.metastore
+        checkpoints_reset = env2bool("DATACHAIN_CHECKPOINTS_RESET", undefined=True)

-
-        assert self._delta_on is not None, "Delta chain must have delta_on defined"
+        _hash = self._calculate_job_hash(job.id)

-
-
-
-
+        if (
+            job.parent_job_id
+            and not checkpoints_reset
+            and metastore.find_checkpoint(job.parent_job_id, _hash)
+        ):
+            # checkpoint found → find which dataset version to reuse
+
+            # Find dataset version that was created by any ancestor job
+            try:
+                dataset_version = metastore.get_dataset_version_for_job_ancestry(
+                    name,
+                    project.namespace.name,
+                    project.name,
+                    job.id,
+                )
+            except JobAncestryDepthExceededError:
+                raise JobAncestryDepthExceededError(
+                    "Job continuation chain is too deep. "
+                    "Please run the job from scratch without continuing from a "
+                    "parent job."
+                ) from None
+
+            if not dataset_version:
+                logger.debug(
+                    "Checkpoint found but no dataset version for '%s' "
+                    "in job ancestry (job_id=%s). Creating new version.",
+                    name,
+                    job.id,
+                )
+                # Dataset version not found (e.g deleted by user) - skip
+                # checkpoint and recreate
+                return _hash, None
+
+            logger.debug(
+                "Reusing dataset version '%s' v%s from job ancestry "
+                "(job_id=%s, dataset_version_id=%s)",
                 name,
-
-
-
-            delta_retry=self._delta_retry,
+                dataset_version.version,
+                job.id,
+                dataset_version.id,
             )

-
-
-
-
-
-
-
-
-                **kwargs,
-            )
-        )
+            # Read the specific version from ancestry
+            chain = read_dataset(
+                name,
+                namespace=project.namespace.name,
+                project=project.name,
+                version=dataset_version.version,
+                **kwargs,
+            )

-
-
-
-
-
-
+            # Link current job to this dataset version (not creator).
+            # This also updates dataset_version.job_id.
+            metastore.link_dataset_version_to_job(
+                dataset_version.id,
+                job.id,
+                is_creator=False,
+            )

-
+            return _hash, chain

-        return
-
-
-
-
-
-
-
-
+        return _hash, None
+
+    def _handle_delta(
+        self,
+        name: str,
+        version: str | None,
+        project: Project,
+        schema: dict,
+        kwargs: dict,
+    ) -> "DataChain | None":
+        """Try to save as a delta dataset.
+        Returns:
+            A DataChain if delta logic could handle it, otherwise None to fall back
+            to the regular save path (e.g., on first dataset creation).
+        """
+        from datachain.delta import delta_retry_update
+
+        from .datasets import read_dataset
+
+        if not self.delta or not name:
+            return None
+
+        assert self._delta_on is not None, "Delta chain must have delta_on defined"
+
+        result_ds, dependencies, has_changes = delta_retry_update(
+            self,
+            project.namespace.name,
+            project.name,
+            name,
+            on=self._delta_on,
+            right_on=self._delta_result_on,
+            compare=self._delta_compare,
+            delta_retry=self._delta_retry,
+        )
+
+        # Case 1: delta produced a new dataset
+        if result_ds:
+            return self._evolve(
+                query=result_ds._query.save(
+                    name=name,
+                    version=version,
+                    project=project,
+                    feature_schema=schema,
+                    dependencies=dependencies,
+                    **kwargs,
+                )
+            )
+
+        # Case 2: no changes → reuse last version
+        if not has_changes:
+            # sources have not been changed so new version of resulting dataset
+            # would be the same as previous one. To avoid duplicating exact
+            # datasets, we won't create new version of it and we will return
+            # current latest version instead.
+            return read_dataset(
+                name,
+                namespace=project.namespace.name,
+                project=project.name,
                 **kwargs,
             )
-
+
+        # Case 3: first creation of dataset
+        return None

     def apply(self, func, *args, **kwargs):
         """Apply any function to the chain.
@@ -693,10 +863,10 @@ class DataChain:

     def map(
         self,
-        func:
-        params:
+        func: Callable | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
-        **signal_map,
+        **signal_map: Any,
     ) -> "Self":
         """Apply a function to each row to create new signals. The function should
         return a new object for each row. It returns a chain itself with new signals.
@@ -704,17 +874,17 @@ class DataChain:
         Input-output relationship: 1:1

         Parameters:
-            func
-            params
+            func: Function applied to each row.
+            params: List of column names used as input for the function. Default
                 is taken from function signature.
-            output
+            output: Dictionary defining new signals and their corresponding types.
                 Default type is taken from function signature. Default can be also
                 taken from kwargs - **signal_map (see below).
                 If signal name is defined using signal_map (see below) only a single
                 type value can be used.
-            **signal_map
+            **signal_map: kwargs can be used to define `func` together with its return
                 signal name in format of `map(my_sign=my_func)`. This helps define
-                signal names and
+                signal names and functions in a nicer way.

         Example:
             Using signal_map and single type in output:
@@ -735,18 +905,19 @@ class DataChain:
         if (prefetch := self._settings.prefetch) is not None:
             udf_obj.prefetch = prefetch

+        sys_schema = SignalSchema({"sys": Sys})
         return self._evolve(
             query=self._query.add_signals(
-                udf_obj.to_udf_wrapper(self._settings.
+                udf_obj.to_udf_wrapper(self._settings.batch_size),
                 **self._settings.to_dict(),
             ),
-            signal_schema=self.signals_schema | udf_obj.output,
+            signal_schema=sys_schema | self.signals_schema | udf_obj.output,
         )

     def gen(
         self,
-        func:
-        params:
+        func: Callable | Generator | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
         **signal_map,
     ) -> "Self":
@@ -775,19 +946,19 @@ class DataChain:
             udf_obj.prefetch = prefetch
         return self._evolve(
             query=self._query.generate(
-                udf_obj.to_udf_wrapper(self._settings.
+                udf_obj.to_udf_wrapper(self._settings.batch_size),
                 **self._settings.to_dict(),
             ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
        )

     @delta_disabled
     def agg(
         self,
         /,
-        func:
-        partition_by:
-        params:
+        func: Callable | None = None,
+        partition_by: PartitionByType | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
         **signal_map: Callable,
     ) -> "Self":
@@ -911,17 +1082,17 @@ class DataChain:
         udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map)
         return self._evolve(
             query=self._query.generate(
-                udf_obj.to_udf_wrapper(self._settings.
+                udf_obj.to_udf_wrapper(self._settings.batch_size),
                 partition_by=processed_partition_by,
                 **self._settings.to_dict(),
             ),
-            signal_schema=udf_obj.output,
+            signal_schema=SignalSchema({"sys": Sys}) | udf_obj.output,
         )

     def batch_map(
         self,
-        func:
-        params:
+        func: Callable | None = None,
+        params: str | Sequence[str] | None = None,
         output: OutputType = None,
         batch: int = 1000,
         **signal_map,
@@ -933,7 +1104,7 @@ class DataChain:
         It accepts the same parameters plus an
         additional parameter:

-            batch
+            batch: Size of each batch passed to `func`. Defaults to 1000.

         Example:
             ```py
@@ -960,7 +1131,7 @@ class DataChain:

         return self._evolve(
             query=self._query.add_signals(
-                udf_obj.to_udf_wrapper(self._settings.
+                udf_obj.to_udf_wrapper(self._settings.batch_size, batch=batch),
                 **self._settings.to_dict(),
             ),
             signal_schema=self.signals_schema | udf_obj.output,
@@ -969,8 +1140,8 @@ class DataChain:
     def _udf_to_obj(
         self,
         target_class: type[UDFObjT],
-        func:
-        params:
+        func: Callable | UDFObjT | None,
+        params: str | Sequence[str] | None,
         output: OutputType,
         signal_map: dict[str, Callable],
     ) -> UDFObjT:
@@ -981,11 +1152,7 @@ class DataChain:
         sign = UdfSignature.parse(name, signal_map, func, params, output, is_generator)
         DataModel.register(list(sign.output_schema.values.values()))

-
-        if self._sys:
-            signals_schema = SignalSchema({"sys": Sys}) | signals_schema
-
-        params_schema = signals_schema.slice(
+        params_schema = self.signals_schema.slice(
             sign.params, self._setup, is_batch=is_batch
         )

@@ -1016,7 +1183,8 @@ class DataChain:
             the order of the records in the chain is important.
             Using `order_by` directly before `limit`, `to_list` and similar methods
             will give expected results.
-            See https://github.com/
+            See https://github.com/datachain-ai/datachain/issues/477
+            for further details.
         """
         if descending:
             args = tuple(sqlalchemy.desc(a) for a in args)
@@ -1040,11 +1208,9 @@ class DataChain:
             )
         )

-    def select(self, *args: str
+    def select(self, *args: str) -> "Self":
         """Select only a specified set of signals."""
         new_schema = self.signals_schema.resolve(*args)
-        if self._sys and _sys:
-            new_schema = SignalSchema({"sys": Sys}) | new_schema
         columns = new_schema.db_signals()
         return self._evolve(
             query=self._query.select(*columns), signal_schema=new_schema
@@ -1062,7 +1228,7 @@ class DataChain:
     def group_by(  # noqa: C901, PLR0912
         self,
         *,
-        partition_by:
+        partition_by: str | Func | Sequence[str | Func] | None = None,
         **kwargs: Func,
     ) -> "Self":
         """Group rows by specified set of signals and return new signals
@@ -1301,9 +1467,9 @@ class DataChain:
         """Yields flattened rows of values as a tuple.

         Args:
-            row_factory
-
-
+            row_factory: A callable to convert row to a custom format.
+                It should accept two arguments: a list of column names and
+                a tuple of row values.
             include_hidden: Whether to include hidden signals from the schema.
         """
         db_signals = self._effective_signals_schema.db_signals(
@@ -1368,7 +1534,7 @@ class DataChain:
         """Convert every row to a dictionary."""

         def to_dict(cols: list[str], row: tuple[Any, ...]) -> dict[str, Any]:
-            return dict(zip(cols, row))
+            return dict(zip(cols, row, strict=False))

         return self.results(row_factory=to_dict)

@@ -1426,7 +1592,7 @@ class DataChain:
     @overload
     def collect(self, *cols: str) -> Iterator[tuple[DataValue, ...]]: ...

-    def collect(self, *cols: str) -> Iterator[
+    def collect(self, *cols: str) -> Iterator[DataValue | tuple[DataValue, ...]]:  # type: ignore[overload-overlap,misc]
         """
         Deprecated. Use `to_iter` method instead.
         """
@@ -1491,8 +1657,8 @@ class DataChain:
     def merge(
         self,
         right_ds: "DataChain",
-        on:
-        right_on:
+        on: MergeColType | Sequence[MergeColType],
+        right_on: MergeColType | Sequence[MergeColType] | None = None,
         inner=False,
         full=False,
         rname="right_",
@@ -1560,8 +1726,8 @@ class DataChain:

         def _resolve(
             ds: DataChain,
-            col:
-            side:
+            col: str | Function | sqlalchemy.ColumnElement,
+            side: str | None,
         ):
             try:
                 if isinstance(col, Function):
@@ -1574,7 +1740,7 @@ class DataChain:
         ops = [
             _resolve(self, left, "left")
             == _resolve(right_ds, right, "right" if right_on else None)
-            for left, right in zip(on, right_on or on)
+            for left, right in zip(on, right_on or on, strict=False)
         ]

         if errors:
@@ -1583,16 +1749,17 @@ class DataChain:
             )

         query = self._query.join(
-            right_ds._query, sqlalchemy.and_(*ops), inner, full, rname
+            right_ds._query, sqlalchemy.and_(*ops), inner, full, rname
         )
         query.feature_schema = None
         ds = self._evolve(query=query)

+        # Note: merge drops sys signals from both sides, make sure to not include it
+        # in the resulting schema
         signals_schema = self.signals_schema.clone_without_sys_signals()
         right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
-
-
-        )
+
+        ds.signals_schema = signals_schema.merge(right_signals_schema, rname)

         return ds

@@ -1603,13 +1770,23 @@ class DataChain:
         Parameters:
             other: chain whose rows will be added to `self`.
         """
+        self_schema = self.signals_schema
+        other_schema = other.signals_schema
+        missing_left, missing_right = self_schema.compare_signals(other_schema)
+        if missing_left or missing_right:
+            raise UnionSchemaMismatchError.from_column_sets(
+                missing_left,
+                missing_right,
+            )
+
+        self.signals_schema = self_schema.clone_without_sys_signals()
         return self._evolve(query=self._query.union(other._query))

     def subtract(  # type: ignore[override]
         self,
         other: "DataChain",
-        on:
-        right_on:
+        on: str | Sequence[str] | None = None,
+        right_on: str | Sequence[str] | None = None,
     ) -> "Self":
         """Remove rows that appear in another chain.

@@ -1666,6 +1843,7 @@ class DataChain:
             zip(
                 self.signals_schema.resolve(*on).db_signals(),
                 other.signals_schema.resolve(*right_on).db_signals(),
+                strict=False,
             )  # type: ignore[arg-type]
         )
         return self._evolve(query=self._query.subtract(other._query, signals))  # type: ignore[arg-type]
@@ -1673,15 +1851,15 @@ class DataChain:
     def diff(
         self,
         other: "DataChain",
-        on:
-        right_on:
-        compare:
-        right_compare:
+        on: str | Sequence[str],
+        right_on: str | Sequence[str] | None = None,
+        compare: str | Sequence[str] | None = None,
+        right_compare: str | Sequence[str] | None = None,
         added: bool = True,
         deleted: bool = True,
         modified: bool = True,
         same: bool = False,
-        status_col:
+        status_col: str | None = None,
     ) -> "DataChain":
         """Calculate differences between two chains.

@@ -1742,12 +1920,12 @@ class DataChain:
         self,
         other: "DataChain",
         on: str = "file",
-        right_on:
+        right_on: str | None = None,
         added: bool = True,
         modified: bool = True,
         deleted: bool = False,
         same: bool = False,
-        status_col:
+        status_col: str | None = None,
     ) -> "DataChain":
         """Calculate differences between two chains containing files.

@@ -1845,12 +2023,15 @@ class DataChain:
         self,
         flatten: bool = False,
         include_hidden: bool = True,
+        as_object: bool = False,
     ) -> "pd.DataFrame":
         """Return a pandas DataFrame from the chain.

         Parameters:
             flatten: Whether to use a multiindex or flatten column names.
             include_hidden: Whether to include hidden columns.
+            as_object: Whether to emit a dataframe backed by Python objects
+                rather than pandas-inferred dtypes.

         Returns:
             pd.DataFrame: A pandas DataFrame representation of the chain.
@@ -1860,12 +2041,18 @@ class DataChain:
         headers, max_length = self._effective_signals_schema.get_headers_with_length(
             include_hidden=include_hidden
         )
+
+        columns: list[str] | pd.MultiIndex
         if flatten or max_length < 2:
             columns = [".".join(filter(None, header)) for header in headers]
         else:
             columns = pd.MultiIndex.from_tuples(map(tuple, headers))

         results = self.results(include_hidden=include_hidden)
+        if as_object:
+            df = pd.DataFrame(results, columns=columns, dtype=object)
+            df.where(pd.notna(df), None, inplace=True)
+            return df
         return pd.DataFrame.from_records(results, columns=columns)

     def show(
@@ -1888,7 +2075,11 @@ class DataChain:
         import pandas as pd

         dc = self.limit(limit) if limit > 0 else self  # type: ignore[misc]
-        df = dc.to_pandas(
+        df = dc.to_pandas(
+            flatten,
+            include_hidden=include_hidden,
+            as_object=True,
+        )

         if df.empty:
             print("Empty result")
@@ -1947,20 +2138,20 @@ class DataChain:
         column: str = "",
         model_name: str = "",
         source: bool = True,
-        nrows:
-        **kwargs,
+        nrows: int | None = None,
+        **kwargs: Any,
     ) -> "Self":
         """Generate chain from list of tabular files.

         Parameters:
-            output
+            output: Dictionary or feature class defining column names and their
                 corresponding types. List of column names is also accepted, in which
                 case types will be inferred.
-            column
-            model_name
-            source
-            nrows
-            kwargs
+            column: Generated column name.
+            model_name: Generated model name.
+            source: Whether to include info about the source file.
+            nrows: Optional row limit.
+            kwargs: Parameters to pass to pyarrow.dataset.dataset.

         Example:
             Reading a json lines file:
@@ -2081,23 +2272,23 @@ class DataChain:

     def to_parquet(
         self,
-        path:
-        partition_cols:
+        path: str | os.PathLike[str] | BinaryIO,
+        partition_cols: Sequence[str] | None = None,
         chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
-        fs_kwargs:
+        fs_kwargs: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         """Save chain to parquet file with SignalSchema metadata.

         Parameters:
-            path
+            path: Path or a file-like binary object to save the file. This supports
                 local paths as well as remote paths, such as s3:// or hf:// with fsspec.
-            partition_cols
-            chunk_size
+            partition_cols: Column names by which to partition the dataset.
+            chunk_size: The chunk size of results to read and convert to columnar
                 data, to avoid running out of memory.
-            fs_kwargs
-
-
+            fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+                when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+                are supported.
         """
         import pyarrow as pa
         import pyarrow.parquet as pq
@@ -2141,7 +2332,7 @@ class DataChain:
             # pyarrow infers the best parquet schema from the python types of
             # the input data.
             table = pa.Table.from_pydict(
-                dict(zip(column_names, chunk)),
+                dict(zip(column_names, chunk, strict=False)),
                 schema=parquet_schema,
             )

@@ -2179,137 +2370,116 @@ class DataChain:

     def to_csv(
         self,
-        path:
+        path: str | os.PathLike[str],
         delimiter: str = ",",
-        fs_kwargs:
+        fs_kwargs: dict[str, Any] | None = None,
         **kwargs,
-    ) ->
-        """Save chain to a csv (comma-separated values) file
+    ) -> File:
+        """Save chain to a csv (comma-separated values) file and return the stored
+        `File`.

         Parameters:
-            path
+            path: Path to save the file. This supports local paths as well as
                 remote paths, such as s3:// or hf:// with fsspec.
-            delimiter
-            fs_kwargs
-
-
+            delimiter: Delimiter to use for the resulting file.
+            fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+                when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+                are supported.
+        Returns:
+            File: The stored file with refreshed metadata (version, etag, size).
         """
         import csv

-
-
-        if isinstance(path, str) and "://" in path:
-            from datachain.client.fsspec import Client
-
-            fs_kwargs = {
-                **self._query.catalog.client_config,
-                **(fs_kwargs or {}),
-            }
-
-            client = Client.get_implementation(path)
-
-            fsspec_fs = client.create_fs(**fs_kwargs)
-
-            opener = fsspec_fs.open
+        target = File.at(path, session=self.session)

         headers, _ = self._effective_signals_schema.get_headers_with_length()
         column_names = [".".join(filter(None, header)) for header in headers]

-
-
-        with opener(path, "w", newline="") as f:
+        with target.open("w", newline="", client_config=fs_kwargs) as f:
             writer = csv.writer(f, delimiter=delimiter, **kwargs)
             writer.writerow(column_names)
-
-            for row in results_iter:
+            for row in self._leaf_values():
                 writer.writerow(row)

+        return target
+
     def to_json(
         self,
-        path:
-        fs_kwargs:
+        path: str | os.PathLike[str],
+        fs_kwargs: dict[str, Any] | None = None,
         include_outer_list: bool = True,
-    ) ->
-        """Save chain to a JSON file
+    ) -> File:
+        """Save chain to a JSON file and return the stored `File`.

         Parameters:
-            path
+            path: Path to save the file. This supports local paths as well as
                 remote paths, such as s3:// or hf:// with fsspec.
-            fs_kwargs
-
-
-            include_outer_list
+            fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+                when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+                are supported.
+            include_outer_list: Sets whether to include an outer list for all rows.
                 Setting this to True makes the file valid JSON, while False instead
                 writes in the JSON lines format.
+        Returns:
+            File: The stored file with refreshed metadata (version, etag, size).
         """
-
-
-        if isinstance(path, str) and "://" in path:
-            from datachain.client.fsspec import Client
-
-            fs_kwargs = {
-                **self._query.catalog.client_config,
-                **(fs_kwargs or {}),
-            }
-
-            client = Client.get_implementation(path)
-
-            fsspec_fs = client.create_fs(**fs_kwargs)
-
-            opener = fsspec_fs.open
-
+        target = File.at(path, session=self.session)
         headers, _ = self._effective_signals_schema.get_headers_with_length()
-        headers = [list(filter(None,
+        headers = [list(filter(None, h)) for h in headers]
+        with target.open("wb", client_config=fs_kwargs) as f:
+            self._write_json_stream(f, headers, include_outer_list)
+        return target

+    def _write_json_stream(
+        self,
+        f: IO[bytes],
+        headers: list[list[str]],
+        include_outer_list: bool,
+    ) -> None:
         is_first = True
-
-
-
-
-            f.write(b"
-
-
-
-
-
-
-
-
-
-
-                json.dumps(
-                    row_to_nested_dict(headers, row), ensure_ascii=False
-                ).encode("utf-8")
-            )
-        if include_outer_list:
-            # This makes the file JSON instead of JSON lines.
-            f.write(b"\n]\n")
+        if include_outer_list:
+            f.write(b"[\n")
+        for row in self._leaf_values():
+            if not is_first:
+                f.write(b",\n" if include_outer_list else b"\n")
+            else:
+                is_first = False
+            f.write(
+                json.dumps(
+                    row_to_nested_dict(headers, row),
+                    ensure_ascii=False,
+                ).encode("utf-8")
+            )
+        if include_outer_list:
+            f.write(b"\n]\n")

     def to_jsonl(
         self,
-        path:
-        fs_kwargs:
-    ) ->
+        path: str | os.PathLike[str],
+        fs_kwargs: dict[str, Any] | None = None,
+    ) -> File:
         """Save chain to a JSON lines file.

         Parameters:
-            path
+            path: Path to save the file. This supports local paths as well as
                 remote paths, such as s3:// or hf:// with fsspec.
-            fs_kwargs
-
-
+            fs_kwargs: Optional kwargs forwarded to the underlying fsspec filesystem
+                when writing (e.g., s3://, gs://, hf://), fsspec-specific options
+                are supported.
+        Returns:
+            File: The stored file with refreshed metadata (version, etag, size).
         """
-        self.to_json(path, fs_kwargs, include_outer_list=False)
+        return self.to_json(path, fs_kwargs, include_outer_list=False)

     def to_database(
         self,
         table_name: str,
         connection: "ConnectionType",
         *,
-
-        on_conflict:
-        conflict_columns:
-        column_mapping:
+        batch_size: int = DEFAULT_DATABASE_BATCH_SIZE,
+        on_conflict: str | None = None,
+        conflict_columns: list[str] | None = None,
+        column_mapping: dict[str, str | None] | None = None,
     ) -> int:
         """Save chain to a database table using a given database connection.

@@ -2328,7 +2498,7 @@ class DataChain:
             library. If a DBAPI2 object, only sqlite3 is supported. The user is
             responsible for engine disposal and connection closure for the
             SQLAlchemy connectable; str connections are closed automatically.
-
+            batch_size: Number of rows to insert per batch for optimal performance.
                 Larger batches are faster but use more memory. Default: 10,000.
             on_conflict: Strategy for handling duplicate rows (requires table
                 constraints):
@@ -2409,7 +2579,7 @@ class DataChain:
             self,
             table_name,
             connection,
-
+            batch_size=batch_size,
             on_conflict=on_conflict,
             conflict_columns=conflict_columns,
             column_mapping=column_mapping,
@@ -2545,13 +2715,13 @@ class DataChain:

     def to_storage(
         self,
-        output:
+        output: str | os.PathLike[str],
         signal: str = "file",
         placement: FileExportPlacement = "fullpath",
         link_type: Literal["copy", "symlink"] = "copy",
-        num_threads:
-        anon:
-        client_config:
+        num_threads: int | None = EXPORT_FILES_MAX_THREADS,
+        anon: bool | None = None,
+        client_config: dict | None = None,
     ) -> None:
         """Export files from a specified signal to a directory. Files can be
         exported to a local or cloud directory.
@@ -2560,12 +2730,24 @@ class DataChain:
             output: Path to the target directory for exporting files.
             signal: Name of the signal to export files from.
             placement: The method to use for naming exported files.
-                The possible values are: "filename", "etag", "fullpath",
+                The possible values are: "filename", "etag", "fullpath",
+                "filepath", and "checksum".
+                Example path translations for an object located at
+                ``s3://bucket/data/img.jpg`` and exported to ``./out``:
+
+                - "filename" -> ``./out/img.jpg`` (no directories)
+                - "filepath" -> ``./out/data/img.jpg`` (relative path kept)
+                - "fullpath" -> ``./out/bucket/data/img.jpg`` (remote host kept)
+                - "etag" -> ``./out/<etag>.jpg`` (unique name via object digest)
+
+                Local sources behave like "filepath" for "fullpath" placement.
+                Relative destinations such as "." or ".." and absolute paths
+                are supported for every strategy.
             link_type: Method to use for exporting files.
                 Falls back to `'copy'` if symlinking fails.
-            num_threads
-                By default it uses 5 threads.
-            anon: If True, we will treat cloud bucket as public one. Default behavior
+            num_threads: number of threads to use for exporting files.
+                By default, it uses 5 threads.
+            anon: If True, we will treat cloud bucket as a public one. Default behavior
                 depends on the previous session configuration (e.g. happens in the
                 initial `read_storage`) and particular cloud storage client
                 implementation (e.g. S3 fallbacks to anonymous access if no credentials
@@ -2614,8 +2796,20 @@ class DataChain:
         )

     def shuffle(self) -> "Self":
-        """Shuffle
-
+        """Shuffle rows with a best-effort deterministic ordering.
+
+        This produces repeatable shuffles. Merge and union operations can
+        lead to non-deterministic results. Use order by or save a dataset
+        afterward to guarantee the same result.
+        """
+        query = self._query.clone(new_table=False)
+        query.steps.append(RegenerateSystemColumns(self._query.catalog))
+
+        chain = self._evolve(
+            query=query,
+            signal_schema=SignalSchema({"sys": Sys}) | self.signals_schema,
+        )
+        return chain.order_by("sys.rand")

     def sample(self, n: int) -> "Self":
         """Return a random sample from the chain.
|