datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/delta.py
CHANGED

```diff
@@ -1,17 +1,22 @@
 from collections.abc import Sequence
 from copy import copy
 from functools import wraps
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, TypeVar
 
 import datachain
-from datachain.dataset import DatasetDependency
-from datachain.error import DatasetNotFoundError
+from datachain.dataset import DatasetDependency, DatasetRecord
+from datachain.error import DatasetNotFoundError, SchemaDriftError
 from datachain.project import Project
+from datachain.query.dataset import UnionSchemaMismatchError
 
 if TYPE_CHECKING:
-    from
+    from collections.abc import Callable
+    from typing import Concatenate
+
+    from typing_extensions import ParamSpec
 
     from datachain.lib.dc import DataChain
+    from datachain.lib.signal_schema import SignalSchema
 
 P = ParamSpec("P")
 
@@ -30,9 +35,10 @@ def delta_disabled(
 
     @wraps(method)
     def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
-        if self.delta:
+        if self.delta and not self._delta_unsafe:
             raise NotImplementedError(
-                f"
+                f"Cannot use {method.__name__} with delta datasets - may cause"
+                " inconsistency. Use delta_unsafe flag to allow this operation."
             )
         return method(self, *args, **kwargs)
 
@@ -49,13 +55,55 @@ def _append_steps(dc: "DataChain", other: "DataChain"):
     return dc
 
 
+def _format_schema_drift_message(
+    context: str,
+    existing_schema: "SignalSchema",
+    updated_schema: "SignalSchema",
+) -> tuple[str, bool]:
+    missing_cols, new_cols = existing_schema.compare_signals(updated_schema)
+
+    if not new_cols and not missing_cols:
+        return "", False
+
+    parts: list[str] = []
+    if new_cols:
+        parts.append("new columns detected: " + ", ".join(sorted(new_cols)))
+    if missing_cols:
+        parts.append(
+            "columns missing in updated data: " + ", ".join(sorted(missing_cols))
+        )
+
+    details = "; ".join(parts)
+    message = f"Delta update failed: schema drift detected while {context}: {details}."
+
+    return message, True
+
+
+def _safe_union(
+    left: "DataChain",
+    right: "DataChain",
+    context: str,
+) -> "DataChain":
+    try:
+        return left.union(right)
+    except UnionSchemaMismatchError as exc:
+        message, has_drift = _format_schema_drift_message(
+            context,
+            left.signals_schema,
+            right.signals_schema,
+        )
+        if has_drift:
+            raise SchemaDriftError(message) from exc
+        raise
+
+
 def _get_delta_chain(
     source_ds_name: str,
     source_ds_project: Project,
     source_ds_version: str,
     source_ds_latest_version: str,
-    on: Union[str, Sequence[str]],
-    compare: Optional[Union[str, Sequence[str]]] = None,
+    on: str | Sequence[str],
+    compare: str | Sequence[str] | None = None,
 ) -> "DataChain":
     """Get delta chain for processing changes between versions."""
     source_dc = datachain.read_dataset(
@@ -83,11 +131,11 @@ def _get_retry_chain(
     source_ds_name: str,
     source_ds_project: Project,
     source_ds_version: str,
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]],
-    delta_retry: Optional[Union[bool, str]],
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None,
+    delta_retry: bool | str | None,
     diff_chain: "DataChain",
-) -> Optional["DataChain"]:
+) -> "DataChain | None":
     """Get retry chain for processing error records and missing records."""
     # Import here to avoid circular import
     from datachain.lib.dc import C
@@ -113,7 +161,9 @@ def _get_retry_chain(
         error_records = result_dataset.filter(C(delta_retry) != "")
         error_source_records = source_dc.merge(
             error_records, on=on, right_on=right_on, inner=True
-        ).select(
+        ).select(
+            *list(source_dc.signals_schema.clone_without_sys_signals().values.keys())
+        )
         retry_chain = error_source_records
 
     # Handle missing records if delta_retry is True
@@ -124,21 +174,30 @@ def _get_retry_chain(
     # Subtract also diff chain since some items might be picked
     # up by `delta=True` itself (e.g. records got modified AND are missing in the
     # result dataset atm)
-
+    on = [on] if isinstance(on, str) else on
+
+    return (
+        retry_chain.diff(
+            diff_chain, on=on, added=True, same=True, modified=False, deleted=False
+        ).distinct(*on)
+        if retry_chain
+        else None
+    )
 
 
 def _get_source_info(
+    source_ds: DatasetRecord,
     name: str,
     namespace_name: str,
     project_name: str,
     latest_version: str,
     catalog,
 ) -> tuple[
-    Optional[str],
-    Optional[Project],
-    Optional[str],
-    Optional[str],
-    Optional[list[DatasetDependency]],
+    str | None,
+    Project | None,
+    str | None,
+    str | None,
+    list[DatasetDependency] | None,
 ]:
     """Get source dataset information and dependencies.
 
@@ -154,25 +213,25 @@ def _get_source_info(
         indirect=False,
     )
 
-
-
+    source_ds_dep = next(
+        (d for d in dependencies if d and d.name == source_ds.name), None
+    )
+    if not source_ds_dep:
         # Starting dataset was removed, back off to normal dataset creation
         return None, None, None, None, None
 
-
-
-
-
-
-
-        project_name=source_ds_project.name,
-    ).latest_version
+    # Refresh starting dataset to have new versions if they are created
+    source_ds = catalog.get_dataset(
+        source_ds.name,
+        namespace_name=source_ds.project.namespace.name,
+        project_name=source_ds.project.name,
+    )
 
     return (
-
-
-
-
+        source_ds.name,
+        source_ds.project,
+        source_ds_dep.version,
+        source_ds.latest_version,
         dependencies,
     )
 
@@ -182,11 +241,11 @@ def delta_retry_update(
     namespace_name: str,
     project_name: str,
     name: str,
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]] = None,
-    compare: Optional[Union[str, Sequence[str]]] = None,
-    delta_retry: Optional[Union[bool, str]] = None,
-) -> tuple[
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None = None,
+    compare: str | Sequence[str] | None = None,
+    delta_retry: bool | str | None = None,
+) -> tuple["DataChain | None", list[DatasetDependency] | None, bool]:
     """
     Creates new chain that consists of the last version of current delta dataset
     plus diff from the source with all needed modifications.
@@ -244,7 +303,14 @@ def delta_retry_update(
         source_ds_version,
         source_ds_latest_version,
         dependencies,
-    ) = _get_source_info(
+    ) = _get_source_info(
+        dc._query.starting_step.dataset,  # type: ignore[union-attr]
+        name,
+        namespace_name,
+        project_name,
+        latest_version,
+        catalog,
+    )
 
     # If source_ds_name is None, starting dataset was removed
     if source_ds_name is None:
@@ -267,8 +333,9 @@ def delta_retry_update(
     if dependencies:
         dependencies = copy(dependencies)
         dependencies = [d for d in dependencies if d is not None]
+        source_ds_dep = next(d for d in dependencies if d.name == source_ds_name)
         # Update to latest version
-
+        source_ds_dep.version = source_ds_latest_version  # type: ignore[union-attr]
 
     # Handle retry functionality if enabled
     if delta_retry:
@@ -288,7 +355,11 @@ def delta_retry_update(
 
     # Combine delta and retry chains
     if retry_chain is not None:
-        processing_chain = diff_chain.union(retry_chain)
+        processing_chain = _safe_union(
+            diff_chain,
+            retry_chain,
+            context="combining retry records with delta changes",
+        )
     else:
         processing_chain = diff_chain
 
@@ -312,5 +383,9 @@ def delta_retry_update(
         modified=False,
         deleted=False,
     )
-    result_chain = compared_chain.union(processing_chain)
+    result_chain = _safe_union(
+        compared_chain,
+        processing_chain,
+        context="merging the delta output with the existing dataset version",
+    )
     return result_chain, dependencies, True
```
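
Taken together, these changes give delta processing an explicit failure mode: both union points now go through `_safe_union`, so schema drift surfaces as a dedicated `SchemaDriftError` instead of a generic union mismatch, and destructive chain operations are gated behind the new `delta_unsafe` flag. A minimal sketch of how calling code might react (the storage URI and dataset name are placeholders, and the `delta`/`delta_on` keywords are assumed from the public `read_storage` delta API rather than shown in this diff):

```python
import datachain as dc
from datachain.error import SchemaDriftError  # new in this release

try:
    # Hypothetical delta run: only files added or changed since the last
    # saved version of "processed" are reprocessed.
    dc.read_storage("s3://my-bucket/data/", delta=True, delta_on="file.path").save(
        "processed"
    )
except SchemaDriftError as exc:
    # Raised via _safe_union() when the delta output adds or drops columns
    # relative to the existing dataset version.
    print(f"Delta update aborted: {exc}")
```
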
datachain/diff/__init__.py
CHANGED

```diff
@@ -1,8 +1,6 @@
-import random
-import string
 from collections.abc import Sequence
 from enum import Enum
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
 
 from datachain.func import case, ifelse, isnone, or_
 from datachain.lib.signal_schema import SignalSchema
@@ -11,16 +9,12 @@ from datachain.query.schema import Column
 if TYPE_CHECKING:
     from datachain.lib.dc import DataChain
 
-
 C = Column
 
 
-
-
-
-        random.choice(string.ascii_letters)  # noqa: S311
-        for _ in range(10)
-    )
+STATUS_COL_NAME = "diff_7aeed3aa17ba4d50b8d1c368c76e16a6"
+LEFT_DIFF_COL_NAME = "diff_95f95344064a4b819c8625cd1a5cfc2b"
+RIGHT_DIFF_COL_NAME = "diff_5808838a49b54849aa461d7387376d34"
 
 
 class CompareStatus(str, Enum):
@@ -30,25 +24,25 @@ class CompareStatus(str, Enum):
     SAME = "S"
 
 
-def _compare(  # noqa: C901, PLR0912
+def _compare(  # noqa: C901
     left: "DataChain",
     right: "DataChain",
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]] = None,
-    compare: Optional[Union[str, Sequence[str]]] = None,
-    right_compare: Optional[Union[str, Sequence[str]]] = None,
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None = None,
+    compare: str | Sequence[str] | None = None,
+    right_compare: str | Sequence[str] | None = None,
     added: bool = True,
     deleted: bool = True,
     modified: bool = True,
     same: bool = True,
-    status_col: Optional[str] = None,
+    status_col: str | None = None,
 ) -> "DataChain":
     """Comparing two chains by identifying rows that are added, deleted, modified
     or same"""
     rname = "right_"
     schema = left.signals_schema  # final chain must have schema from left chain
 
-    def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
+    def _to_list(obj: str | Sequence[str] | None) -> list[str] | None:
         if obj is None:
             return None
         return [obj] if isinstance(obj, str) else list(obj)
@@ -101,21 +95,23 @@ def _compare(  # noqa: C901, PLR0912
     compare = right_compare = [c for c in cols if c in right_cols and c not in on]  # type: ignore[misc]
 
     # get diff column names
-    diff_col = status_col or
-    ldiff_col =
-    rdiff_col =
+    diff_col = status_col or STATUS_COL_NAME
+    ldiff_col = LEFT_DIFF_COL_NAME
+    rdiff_col = RIGHT_DIFF_COL_NAME
 
     # adding helper diff columns, which will be removed after
     left = left.mutate(**{ldiff_col: 1})
     right = right.mutate(**{rdiff_col: 1})
 
-    if not compare:
+    if compare is None:
         modified_cond = True
+    elif len(compare) == 0:
+        modified_cond = False
     else:
         modified_cond = or_(  # type: ignore[assignment]
             *[
                 C(c) != (C(f"{rname}{rc}") if c == rc else C(rc))
-                for c, rc in zip(compare, right_compare)  # type: ignore[arg-type]
+                for c, rc in zip(compare, right_compare, strict=False)  # type: ignore[arg-type]
             ]
         )
 
@@ -139,7 +135,7 @@ def _compare(  # noqa: C901, PLR0912
                     C(f"{rname + l_on if on == right_on else r_on}"),
                     C(l_on),
                 )
-                for l_on, r_on in zip(on, right_on)  # type: ignore[arg-type]
+                for l_on, r_on in zip(on, right_on, strict=False)  # type: ignore[arg-type]
             }
         )
         .select_except(ldiff_col, rdiff_col)
@@ -157,11 +153,7 @@ def _compare(  # noqa: C901, PLR0912
     if status_col:
         cols_select.append(diff_col)
 
-
-        # TODO workaround when sys signal is not available in diff
-        dc_diff = dc_diff.settings(sys=True).select(*cols_select).settings(sys=False)
-    else:
-        dc_diff = dc_diff.select(*cols_select)
+    dc_diff = dc_diff.select(*cols_select)
 
     # final schema is schema from the left chain with status column added if needed
     dc_diff.signals_schema = (
@@ -174,10 +166,10 @@ def _compare(  # noqa: C901, PLR0912
 def compare_and_split(
     left: "DataChain",
     right: "DataChain",
-    on: Union[str, Sequence[str]],
-    right_on: Optional[Union[str, Sequence[str]]] = None,
-    compare: Optional[Union[str, Sequence[str]]] = None,
-    right_compare: Optional[Union[str, Sequence[str]]] = None,
+    on: str | Sequence[str],
+    right_on: str | Sequence[str] | None = None,
+    compare: str | Sequence[str] | None = None,
+    right_compare: str | Sequence[str] | None = None,
     added: bool = True,
     deleted: bool = True,
     modified: bool = True,
@@ -227,7 +219,7 @@ def compare_and_split(
     )
     ```
     """
-    status_col =
+    status_col = STATUS_COL_NAME
 
     res = _compare(
         left,
```
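
Replacing per-run random suffixes with fixed, UUID-style helper column names makes the diff internals deterministic across runs and processes. A short sketch of the `DataChain.diff` entry point these helpers back (the input values are illustrative; when `status_col` is omitted, the fixed internal `STATUS_COL_NAME` is used and dropped from the final selection):

```python
import datachain as dc

left = dc.read_values(id=[1, 2, 3], value=["a", "b", "c"])
right = dc.read_values(id=[1, 2, 4], value=["a", "B", "d"])

# Keep the added/deleted/modified/same status in an explicit column.
changes = left.diff(right, on="id", status_col="status")
```
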
datachain/error.py
CHANGED

```diff
@@ -2,6 +2,10 @@ class DataChainError(RuntimeError):
     pass
 
 
+class SchemaDriftError(DataChainError):
+    pass
+
+
 class InvalidDatasetNameError(RuntimeError):
     pass
 
@@ -34,6 +38,14 @@ class ProjectCreateNotAllowedError(NotAllowedError):
     pass
 
 
+class ProjectDeleteNotAllowedError(NotAllowedError):
+    pass
+
+
+class NamespaceDeleteNotAllowedError(NotAllowedError):
+    pass
+
+
 class ProjectNotFoundError(NotFoundError):
     pass
 
@@ -89,3 +101,15 @@ class TableMissingError(DataChainError):
 
 class OutdatedDatabaseSchemaError(DataChainError):
     pass
+
+
+class CheckpointNotFoundError(NotFoundError):
+    pass
+
+
+class JobNotFoundError(NotFoundError):
+    pass
+
+
+class JobAncestryDepthExceededError(DataChainError):
+    pass
```
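
All of the new exceptions slot into the existing hierarchy, so broad handlers keep working and only code that wants to special-case the new failure modes needs changes. A quick illustration (assuming `NotFoundError` and `NotAllowedError` are the pre-existing bases this module already defines, as the diff context indicates):

```python
from datachain.error import (
    CheckpointNotFoundError,
    DataChainError,
    JobNotFoundError,
    NotFoundError,
    SchemaDriftError,
)

# Existing `except NotFoundError` blocks also catch the new lookup errors.
assert issubclass(CheckpointNotFoundError, NotFoundError)
assert issubclass(JobNotFoundError, NotFoundError)

# Schema drift is a DataChainError, which is itself a RuntimeError.
assert issubclass(SchemaDriftError, DataChainError)
assert issubclass(SchemaDriftError, RuntimeError)
```
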
datachain/func/aggregate.py
CHANGED

```diff
@@ -1,5 +1,3 @@
-from typing import Optional, Union
-
 from sqlalchemy import func as sa_func
 
 from datachain.query.schema import Column
@@ -8,7 +6,7 @@ from datachain.sql.functions import aggregate
 from .func import Func
 
 
-def count(col: Optional[Union[str, Column]] = None) -> Func:
+def count(col: str | Column | None = None) -> Func:
     """
     Returns a COUNT aggregate SQL function for the specified column.
 
@@ -44,7 +42,7 @@ def count(col: Optional[Union[str, Column]] = None) -> Func:
     )
 
 
-def sum(col: Union[str, Column]) -> Func:
+def sum(col: str | Column) -> Func:
     """
     Returns the SUM aggregate SQL function for the specified column.
 
@@ -74,7 +72,7 @@ def sum(col: Union[str, Column]) -> Func:
     return Func("sum", inner=sa_func.sum, cols=[col])
 
 
-def avg(col: Union[str, Column]) -> Func:
+def avg(col: str | Column) -> Func:
     """
     Returns the AVG aggregate SQL function for the specified column.
 
@@ -104,7 +102,7 @@ def avg(col: Union[str, Column]) -> Func:
     return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
 
 
-def min(col: Union[str, Column]) -> Func:
+def min(col: str | Column) -> Func:
     """
     Returns the MIN aggregate SQL function for the specified column.
 
@@ -134,7 +132,7 @@ def min(col: Union[str, Column]) -> Func:
     return Func("min", inner=sa_func.min, cols=[col])
 
 
-def max(col: Union[str, Column]) -> Func:
+def max(col: str | Column) -> Func:
     """
     Returns the MAX aggregate SQL function for the given column name.
 
@@ -164,7 +162,7 @@ def max(col: Union[str, Column]) -> Func:
    return Func("max", inner=sa_func.max, cols=[col])
 
 
-def any_value(col: Union[str, Column]) -> Func:
+def any_value(col: str | Column) -> Func:
     """
     Returns the ANY_VALUE aggregate SQL function for the given column name.
 
@@ -198,7 +196,7 @@ def any_value(col: Union[str, Column]) -> Func:
     return Func("any_value", inner=aggregate.any_value, cols=[col])
 
 
-def collect(col: Union[str, Column]) -> Func:
+def collect(col: str | Column) -> Func:
     """
     Returns the COLLECT aggregate SQL function for the given column name.
 
@@ -229,7 +227,7 @@ def collect(col: Union[str, Column]) -> Func:
     return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
 
 
-def concat(col: Union[str, Column], separator="") -> Func:
+def concat(col: str | Column, separator="") -> Func:
     """
     Returns the CONCAT aggregate SQL function for the given column name.
 
@@ -348,7 +346,7 @@ def dense_rank() -> Func:
     return Func("dense_rank", inner=sa_func.dense_rank, result_type=int, is_window=True)
 
 
-def first(col: Union[str, Column]) -> Func:
+def first(col: str | Column) -> Func:
     """
     Returns the FIRST_VALUE window function for SQL queries.
 
```
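
The aggregate changes are mechanical: `Optional[...]`/`Union[...]` annotations become PEP 604 unions, so the accepted argument types are unchanged and call sites need no updates. For reference, a typical use of these functions in a group-by (values are illustrative):

```python
import datachain as dc
from datachain import func

chain = dc.read_values(team=["a", "a", "b"], score=[1, 2, 5])
summary = chain.group_by(
    total=func.sum("score"),
    mean=func.avg("score"),
    rows=func.count(),
    partition_by="team",
)
```
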
datachain/func/array.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any
 
 from datachain.query.schema import Column
 from datachain.sql.functions import array
@@ -7,7 +7,7 @@ from datachain.sql.functions import array
 from .func import Func
 
 
-def cosine_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
+def cosine_distance(*args: str | Column | Func | Sequence) -> Func:
     """
     Returns the cosine distance between two vectors.
 
@@ -62,7 +62,7 @@ def cosine_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
     )
 
 
-def euclidean_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
+def euclidean_distance(*args: str | Column | Func | Sequence) -> Func:
     """
     Returns the Euclidean distance between two vectors.
 
@@ -115,7 +115,7 @@ def euclidean_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
     )
 
 
-def length(arg: Union[str, Column, Func, Sequence]) -> Func:
+def length(arg: str | Column | Func | Sequence) -> Func:
     """
     Returns the length of the array.
 
@@ -151,7 +151,7 @@ def length(arg: Union[str, Column, Func, Sequence]) -> Func:
     return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
 
 
-def contains(arr: Union[str, Column, Func, Sequence], elem: Any) -> Func:
+def contains(arr: str | Column | Func | Sequence, elem: Any) -> Func:
     """
     Checks whether the array contains the specified element.
 
@@ -196,9 +196,9 @@ def contains(arr: Union[str, Column, Func, Sequence], elem: Any) -> Func:
 
 
 def slice(
-    arr: Union[str, Column, Func, Sequence],
+    arr: str | Column | Func | Sequence,
     offset: int,
-    length: Optional[int] = None,
+    length: int | None = None,
 ) -> Func:
     """
     Returns a slice of the array starting from the specified offset.
@@ -272,7 +272,7 @@ def slice(
 
 
 def join(
-    arr: Union[str, Column, Func, Sequence],
+    arr: str | Column | Func | Sequence,
     sep: str = "",
 ) -> Func:
     """
@@ -322,7 +322,7 @@ def join(
     )
 
 
-def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
+def get_element(arg: str | Column | Func | Sequence, index: int) -> Func:
     """
     Returns the element at the given index from the array.
     If the index is out of bounds, it returns None or columns default value.
@@ -359,8 +359,8 @@ def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
             return str  # if the array is empty, return str as default type
         return None
 
-    cols: Optional[Union[str, Column, Func, Sequence]]
-    args: Union[str, Column, Func, Sequence, int]
+    cols: str | Column | Func | Sequence | None
+    args: str | Column | Func | Sequence | int
 
     if isinstance(arg, (str, Column, Func)):
         cols = [arg]
@@ -379,7 +379,7 @@ def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
     )
 
 
-def sip_hash_64(arg: Union[str, Column, Func, Sequence]) -> Func:
+def sip_hash_64(arg: str | Column | Func | Sequence) -> Func:
     """
     Returns the SipHash-64 hash of the array.
 
```
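
Same mechanical modernization here; the array helpers behave exactly as before. A small usage sketch (the column name and vectors are made up):

```python
import datachain as dc
from datachain.func import array

chain = dc.read_values(emb=[[0.1, 0.2], [0.9, 0.8]])
chain = chain.mutate(
    emb_len=array.length("emb"),
    first=array.get_element("emb", 0),
    dist=array.cosine_distance("emb", [0.1, 0.2]),
)
```
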
datachain/func/base.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Optional
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from sqlalchemy import TableClause
@@ -12,12 +13,14 @@ class Function:
     __metaclass__ = ABCMeta
 
     name: str
+    cols: Sequence
+    args: Sequence
 
     @abstractmethod
     def get_column(
         self,
-        signals_schema: Optional["SignalSchema"] = None,
-        label: Optional[str] = None,
-        table: Optional["TableClause"] = None,
+        signals_schema: "SignalSchema | None" = None,
+        label: str | None = None,
+        table: "TableClause | None" = None,
     ) -> "Column":
         pass
```