datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/dataset.py
CHANGED
@@ -1,25 +1,18 @@
 import contextlib
+import hashlib
 import inspect
 import logging
 import os
-import random
+import secrets
 import string
 import subprocess
 import sys
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 from copy import copy
 from functools import wraps
-from
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Optional,
-    Protocol,
-    TypeVar,
-    Union,
-)
+from types import GeneratorType
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar

 import attrs
 import sqlalchemy
@@ -28,7 +21,7 @@ from attrs import frozen
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
 from sqlalchemy import Column
 from sqlalchemy.sql import func as f
-from sqlalchemy.sql.elements import ColumnClause, ColumnElement
+from sqlalchemy.sql.elements import ColumnClause, ColumnElement, Label
 from sqlalchemy.sql.expression import label
 from sqlalchemy.sql.schema import TableClause
 from sqlalchemy.sql.selectable import Select
@@ -41,51 +34,53 @@ from datachain.data_storage.schema import (
     partition_col_names,
     partition_columns,
 )
-from datachain.dataset import
-from datachain.error import (
-    DatasetNotFoundError,
-    QueryScriptCancelError,
-)
+from datachain.dataset import DatasetDependency, DatasetStatus, RowDict
+from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.
-
-
-)
+from datachain.hash_utils import hash_column_elements
+from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
+from datachain.lib.signal_schema import SignalSchema, generate_merge_root_mapping
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
-from datachain.
+from datachain.project import Project
+from datachain.query.schema import DEFAULT_DELIMITER, C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
+from datachain.query.udf import UdfInfo
 from datachain.sql.functions.random import rand
+from datachain.sql.types import SQLType
 from datachain.utils import (
-    batched,
     determine_processes,
+    determine_workers,
+    ensure_sequence,
     filtered_cloudpickle_dumps,
     get_datachain_executable,
     safe_closing,
 )

 if TYPE_CHECKING:
-    from
+    from collections.abc import Mapping
+    from typing import Concatenate
+
+    from sqlalchemy.sql.elements import ClauseElement, KeyedColumnElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
-    from typing_extensions import
+    from typing_extensions import ParamSpec, Self

     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
     from datachain.lib.udf import UDFAdapter, UDFResult
-    from datachain.query.udf import UdfInfo

 P = ParamSpec("P")


 INSERT_BATCH_SIZE = 10000

-PartitionByType =
-    Function
-
-JoinPredicateType =
-DatasetDependencyType = tuple[
+PartitionByType = (
+    str | Function | ColumnElement | Sequence[str | Function | ColumnElement]
+)
+JoinPredicateType = str | ColumnClause | ColumnElement
+DatasetDependencyType = tuple["DatasetRecord", str]

 logger = logging.getLogger("datachain")

@@ -165,24 +160,42 @@ class Step(ABC):
     ) -> "StepResult":
         """Apply the processing step."""

+    @abstractmethod
+    def hash_inputs(self) -> str:
+        """Calculates hash of step inputs"""
+
+    def hash(self) -> str:
+        """
+        Calculates hash for step which includes step name and hash of it's inputs
+        """
+        return hashlib.sha256(
+            f"{self.__class__.__name__}|{self.hash_inputs()}".encode()
+        ).hexdigest()
+

 @frozen
 class QueryStep:
+    """A query that returns all rows from specific dataset version"""
+
     catalog: "Catalog"
-
-    dataset_version:
+    dataset: "DatasetRecord"
+    dataset_version: str

-    def apply(self):
+    def apply(self) -> "StepResult":
         def q(*columns):
             return sqlalchemy.select(*columns)

-
-        dr = self.catalog.warehouse.dataset_rows(dataset, self.dataset_version)
+        dr = self.catalog.warehouse.dataset_rows(self.dataset, self.dataset_version)

         return step_result(
-            q, dr.columns, dependencies=[(self.
+            q, dr.columns, dependencies=[(self.dataset, self.dataset_version)]
         )

+    def hash(self) -> str:
+        return hashlib.sha256(
+            self.dataset.uri(self.dataset_version).encode()
+        ).hexdigest()
+

 def generator_then_call(generator, func: Callable):
     """
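Every step now carries a stable identity: `hash_inputs()` digests the step's own inputs, and `hash()` folds that digest together with the class name, so two structurally identical steps hash to the same value. A minimal standalone sketch of the pattern this diff introduces (the helper names below are illustrative, not datachain API):

```python
import hashlib


def step_hash(step_name: str, inputs_digest: str) -> str:
    # Mirrors the added pattern: sha256 over "<ClassName>|<hash_inputs()>".
    return hashlib.sha256(f"{step_name}|{inputs_digest}".encode()).hexdigest()


# e.g. a LIMIT step digests its single integer input, and the step hash
# then combines that digest with the step's class name.
limit_inputs = hashlib.sha256(str(100).encode()).hexdigest()
print(step_hash("SQLLimit", limit_inputs))
```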
@@ -218,8 +231,9 @@ class DatasetDiffOperation(Step):

     def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
         source_query = query_generator.exclude(("sys__id",))
+        right_before = len(self.dq.temp_table_names)
         target_query = self.dq.apply_steps().select()
-        temp_tables.extend(self.dq.temp_table_names)
+        temp_tables.extend(self.dq.temp_table_names[right_before:])

         # creating temp table that will hold subtract results
         temp_table_name = self.catalog.warehouse.temp_table_name()
@@ -253,6 +267,13 @@ class DatasetDiffOperation(Step):
 class Subtract(DatasetDiffOperation):
     on: Sequence[tuple[str, str]]

+    def hash_inputs(self) -> str:
+        on_bytes = b"".join(
+            f"{a}:{b}".encode() for a, b in sorted(self.on, key=lambda t: (t[0], t[1]))
+        )
+
+        return hashlib.sha256(bytes.fromhex(self.dq.hash()) + on_bytes).hexdigest()
+
     def query(self, source_query: Select, target_query: Select) -> sa.Selectable:
         sq = source_query.alias("source_query")
         tq = target_query.alias("target_query")
@@ -272,7 +293,9 @@ class Subtract(DatasetDiffOperation):


 def adjust_outputs(
-    warehouse: "AbstractWarehouse",
+    warehouse: "AbstractWarehouse",
+    row: dict[str, Any],
+    col_types: list[tuple[str, SQLType, type, str, Any]],
 ) -> dict[str, Any]:
     """
     This function does a couple of things to prepare a row for inserting into the db:
@@ -288,7 +311,7 @@ def adjust_outputs(
         col_python_type,
         col_type_name,
         default_value,
-    ) in
+    ) in col_types:
         row_val = row.get(col_name)

         # Fill None or missing values with defaults (get returns None if not in the row)
@@ -303,8 +326,10 @@ def adjust_outputs(
     return row


-def
-    ""
+def get_col_types(
+    warehouse: "AbstractWarehouse", output: "Mapping[str, Any]"
+) -> list[tuple]:
+    """Optimization: Precompute column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
     return [
@@ -316,7 +341,7 @@ def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list
             type(col_type_inst).__name__,
             col_type.default_value(dialect),
         )
-        for col_name, col_type in
+        for col_name, col_type in output.items()
     ]


@@ -325,33 +350,23 @@ def process_udf_outputs(
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
     udf: "UDFAdapter",
-    batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
+    batch_size: int = INSERT_BATCH_SIZE,
 ) -> None:
-    import psutil
-
-    rows: list[UDFResult] = []
     # Optimization: Compute row types once, rather than for every row.
-    udf_col_types =
+    udf_col_types = get_col_types(warehouse, udf.output)

-
-
-
-
-        for row in udf_output:
-            cb.relative_update()
-            rows.append(adjust_outputs(warehouse, row, udf_col_types))
-            if len(rows) >= batch_size or (
-                len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
-            ):
-                for row_chunk in batched(rows, batch_size):
-                    warehouse.insert_rows(udf_table, row_chunk)
-                rows.clear()
+    def _insert_rows():
+        for udf_output in udf_results:
+            if not udf_output:
+                continue

-
-
-
+            with safe_closing(udf_output):
+                for row in udf_output:
+                    cb.relative_update()
+                    yield adjust_outputs(warehouse, row, udf_col_types)

+    warehouse.insert_rows(udf_table, _insert_rows(), batch_size=batch_size)
     warehouse.insert_rows_done(udf_table)


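`process_udf_outputs` no longer buffers rows in a list and flushes on size or memory pressure (the `psutil` check is gone); it now yields adjusted rows from a generator and lets `warehouse.insert_rows` do the batching. A self-contained sketch of that pattern with a stand-in insert function (the real warehouse API is not reproduced here):

```python
from collections.abc import Iterable, Iterator
from itertools import islice


def insert_rows(rows: Iterator[dict], batch_size: int = 10_000) -> None:
    # Stand-in for a bulk insert that consumes a row generator lazily,
    # so at most one batch is ever held in memory at a time.
    while batch := list(islice(rows, batch_size)):
        print(f"inserting {len(batch)} rows")


def udf_rows(udf_results: Iterable[Iterable[dict]]) -> Iterator[dict]:
    # Adjust and yield rows one by one instead of accumulating them in a list.
    for output in udf_results:
        for row in output:
            yield {"processed": True, **row}


insert_rows(udf_rows([[{"id": i} for i in range(3)]] * 2), batch_size=4)
```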
@@ -387,20 +402,34 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 class UDFStep(Step, ABC):
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by:
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
+    partition_by: PartitionByType | None = None
     is_generator = False
+    # Parameters from Settings
     cache: bool = False
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
+
+    def hash_inputs(self) -> str:
+        partition_by = ensure_sequence(self.partition_by or [])
+        parts = [
+            bytes.fromhex(self.udf.hash()),
+            bytes.fromhex(hash_column_elements(partition_by)),
+            str(self.is_generator).encode(),
+        ]
+
+        return hashlib.sha256(b"".join(parts)).hexdigest()

     @abstractmethod
     def create_udf_table(self, query: Select) -> "Table":
         """Method that creates a table where temp udf results will be saved"""

     def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
-        """
-
+        """Materialize inputs, ensure sys columns are available, needed for checkpoints,
+        needed for map to work (merge results)"""
+        table = self.catalog.warehouse.create_pre_udf_table(query)
+        return sqlalchemy.select(*table.c), [table]

     @abstractmethod
     def create_result_query(
@@ -412,28 +441,48 @@ class UDFStep(Step, ABC):
         """

     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+        if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
+            return
+
         from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+        from datachain.catalog.loader import (
+            DISTRIBUTED_IMPORT_PATH,
+            get_udf_distributor_class,
+        )
+
+        workers = determine_workers(self.workers, rows_total=rows_total)
+        processes = determine_processes(self.parallel, rows_total=rows_total)

         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
-        workers = self.workers
-        if (
-            not workers
-            and os.environ.get("DATACHAIN_DISTRIBUTED")
-            and os.environ.get("DATACHAIN_SETTINGS_WORKERS")
-        ):
-            # Enable distributed processing by default if the module is available,
-            # and a default number of workers is provided.
-            workers = True
-
-        processes = determine_processes(self.parallel)
-
         udf_fields = [str(c.name) for c in query.selected_columns]
+        udf_distributor_class = get_udf_distributor_class()

         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
+
             try:
+                if udf_distributor_class and not catalog.in_memory:
+                    # Use the UDF distributor if available (running in SaaS)
+                    udf_distributor = udf_distributor_class(
+                        catalog=catalog,
+                        table=udf_table,
+                        query=query,
+                        udf_data=filtered_cloudpickle_dumps(self.udf),
+                        batching=batching,
+                        workers=workers,
+                        processes=processes,
+                        udf_fields=udf_fields,
+                        rows_total=rows_total,
+                        use_cache=self.cache,
+                        is_generator=self.is_generator,
+                        min_task_size=self.min_task_size,
+                        batch_size=self.batch_size,
+                    )
+                    udf_distributor()
+                    return
+
                 if workers:
                     if catalog.in_memory:
                         raise RuntimeError(
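`populate_udf_table` now counts the input rows up front, returns early when there is nothing to process, derives workers and processes from the row count, and prefers a pluggable UDF distributor before falling back to multiprocess or in-process execution. A rough sketch of that decision order; the helper and its arguments are hypothetical and do not reproduce the real `determine_workers`/`determine_processes` from `datachain.utils`:

```python
def choose_execution(
    rows_total: int,
    workers: bool | int,
    parallel: int | None,
    has_distributor: bool,
    in_memory: bool,
) -> str:
    # Hypothetical helper that only mirrors the branch order shown above.
    if rows_total == 0:
        return "skip"                # nothing to process
    if has_distributor and not in_memory:
        return "distributed"         # the pluggable distributor handles the UDF
    if workers:
        return "error: distributor import path required"
    if parallel:
        return "multiprocess dispatcher"
    return "single process"


print(choose_execution(0, False, None, has_distributor=False, in_memory=False))
print(choose_execution(10_000, False, 4, has_distributor=False, in_memory=False))
```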
@@ -441,43 +490,33 @@ class UDFStep(Step, ABC):
                             "distributed processing."
                         )

-
-
-
-                        min_task_size=self.min_task_size
+                    raise RuntimeError(
+                        f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+                        "for distributed UDF processing."
                     )
-
-                        self.udf,
-                        catalog,
-                        udf_table,
-                        query,
-                        workers,
-                        processes,
-                        udf_fields=udf_fields,
-                        is_generator=self.is_generator,
-                        use_partitioning=use_partitioning,
-                        cache=self.cache,
-                    )
-                elif processes:
+                if processes:
                     # Parallel processing (faster for more CPU-heavy UDFs)
                     if catalog.in_memory:
                         raise RuntimeError(
                             "In-memory databases cannot be used "
                             "with parallel processing."
                         )
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+                    udf_info = UdfInfo(
+                        udf_data=filtered_cloudpickle_dumps(self.udf),
+                        catalog_init=catalog.get_init_params(),
+                        metastore_clone_params=catalog.metastore.clone_params(),
+                        warehouse_clone_params=catalog.warehouse.clone_params(),
+                        table=udf_table,
+                        query=query,
+                        udf_fields=udf_fields,
+                        batching=batching,
+                        processes=processes,
+                        is_generator=self.is_generator,
+                        cache=self.cache,
+                        rows_total=rows_total,
+                        batch_size=self.batch_size or INSERT_BATCH_SIZE,
+                    )

                     # Run the UDFDispatcher in another process to avoid needing
                     # if __name__ == '__main__': in user scripts
@@ -490,7 +529,12 @@ class UDFStep(Step, ABC):
                     with subprocess.Popen(  # noqa: S603
                         cmd, env=envs, stdin=subprocess.PIPE
                     ) as process:
-
+                        try:
+                            process.communicate(process_data)
+                        except KeyboardInterrupt:
+                            raise QueryScriptCancelError(
+                                "UDF execution was canceled by the user."
+                            ) from None
                         if retval := process.poll():
                             raise RuntimeError(
                                 f"UDF Execution Failed! Exit code: {retval}"
@@ -520,6 +564,7 @@ class UDFStep(Step, ABC):
                         udf_results,
                         self.udf,
                         cb=generated_cb,
+                        batch_size=self.batch_size or INSERT_BATCH_SIZE,
                     )
                 finally:
                     download_cb.close()
@@ -538,10 +583,13 @@ class UDFStep(Step, ABC):
         """
         Create temporary table with group by partitions.
         """
-
+        if self.partition_by is None:
+            raise RuntimeError("Query must have partition_by set to use partitioning")
+        if (id_col := query.selected_columns.get("sys__id")) is None:
+            raise RuntimeError("Query must have sys__id column to use partitioning")

-        if isinstance(self.partition_by,
-            list_partition_by = self.partition_by
+        if isinstance(self.partition_by, (list, tuple, GeneratorType)):
+            list_partition_by = list(self.partition_by)
         else:
             list_partition_by = [self.partition_by]

@@ -554,16 +602,19 @@ class UDFStep(Step, ABC):

         # fill table with partitions
         cols = [
-
+            id_col,
             f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID),
         ]
         self.catalog.warehouse.db.execute(
-            tbl.insert().from_select(
+            tbl.insert().from_select(
+                cols,
+                query.offset(None).limit(None).with_only_columns(*cols),
+            )
         )

         return tbl

-    def clone(self, partition_by:
+    def clone(self, partition_by: PartitionByType | None = None) -> "Self":
         if partition_by is not None:
             return self.__class__(
                 self.udf,
@@ -572,27 +623,25 @@ class UDFStep(Step, ABC):
                 parallel=self.parallel,
                 workers=self.workers,
                 min_task_size=self.min_task_size,
+                batch_size=self.batch_size,
             )
         return self.__class__(self.udf, self.catalog)

     def apply(
         self, query_generator: QueryGenerator, temp_tables: list[str]
     ) -> "StepResult":
-
+        query, tables = self.process_input_query(query_generator.select())
+        _query = query

         # Apply partitioning if needed.
         if self.partition_by is not None:
             partition_tbl = self.create_partitions_table(query)
-
+            query = query.outerjoin(
+                partition_tbl,
+                partition_tbl.c.sys__id == query.selected_columns.sys__id,
+            ).add_columns(*partition_columns())
+            tables = [*tables, partition_tbl]

-            subq = query.subquery()
-            query = (
-                sqlalchemy.select(*subq.c)
-                .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id)
-                .add_columns(*partition_columns())
-            )
-
-            query, tables = self.process_input_query(query)
         temp_tables.extend(t.name for t in tables)
         udf_table = self.create_udf_table(_query)
         temp_tables.append(udf_table.name)
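Partitioning now requires both `partition_by` and a `sys__id` column, fills the partitions table by ranking rows with `dense_rank()` over the partition columns, and then outer-joins that table back on `sys__id`. A small SQLAlchemy sketch of the ranking expression (table and column names here are illustrative):

```python
import sqlalchemy as sa
from sqlalchemy.sql import func as f

rows = sa.table("rows", sa.column("sys__id"), sa.column("file__path"))
partition_by = [rows.c.file__path]

# Rows with equal partition-column values receive the same dense_rank value,
# which serves as the partition id the UDF step later groups on.
partition_query = sa.select(
    rows.c.sys__id,
    f.dense_rank().over(order_by=partition_by).label("partition_id"),
)
print(partition_query)
```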
@@ -604,7 +653,16 @@ class UDFStep(Step, ABC):

 @frozen
 class UDFSignal(UDFStep):
+    udf: "UDFAdapter"
+    catalog: "Catalog"
+    partition_by: PartitionByType | None = None
     is_generator = False
+    # Parameters from Settings
+    cache: bool = False
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None

     def create_udf_table(self, query: Select) -> "Table":
         udf_output_columns: list[sqlalchemy.Column[Any]] = [
@@ -614,13 +672,6 @@ class UDFSignal(UDFStep):

         return self.catalog.warehouse.create_udf_table(udf_output_columns)

-    def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
-        if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
-            return query, []
-        table = self.catalog.warehouse.create_pre_udf_table(query)
-        q: Select = sqlalchemy.select(*table.c)
-        return q, [table]
-
     def create_result_query(
         self, udf_table, query
     ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
@@ -628,15 +679,30 @@ class UDFSignal(UDFStep):
         original_cols = [c for c in subq.c if c.name not in partition_col_names]

         # new signal columns that are added to udf_table
-        signal_cols = [c for c in udf_table.c if c.name
+        signal_cols = [c for c in udf_table.c if not c.name.startswith("sys__")]
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols

-
+        original_names = {c.name for c in original_cols}
+        new_names = {c.name for c in cols}
+
+        overlap = original_names & new_names
         if overlap:
             raise ValueError(
                 "Column already exists or added in the previous steps: "
-                + ", ".join(overlap)
+                + ", ".join(sorted(overlap))
+            )
+
+        def _root(name: str) -> str:
+            return name.split(DEFAULT_DELIMITER, 1)[0]
+
+        existing_roots = {_root(name) for name in original_names}
+        new_roots = {_root(name) for name in new_names}
+        root_conflicts = existing_roots & new_roots
+        if root_conflicts:
+            raise ValueError(
+                "Signals already exist in the previous steps: "
+                + ", ".join(sorted(root_conflicts))
             )

         def q(*columns):
@@ -674,7 +740,16 @@ class UDFSignal(UDFStep):
 class RowGenerator(UDFStep):
     """Extend dataset with new rows."""

+    udf: "UDFAdapter"
+    catalog: "Catalog"
+    partition_by: PartitionByType | None = None
     is_generator = True
+    # Parameters from Settings
+    cache: bool = False
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None

     def create_udf_table(self, query: Select) -> "Table":
         warehouse = self.catalog.warehouse
@@ -721,18 +796,42 @@ class SQLClause(Step, ABC):

     def parse_cols(
         self,
-        cols: Sequence[
+        cols: Sequence[Function | ColumnElement],
     ) -> tuple[ColumnElement, ...]:
         return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)

     @abstractmethod
-    def apply_sql_clause(self, query):
+    def apply_sql_clause(self, query: Any) -> Any:
         pass


+@frozen
+class RegenerateSystemColumns(Step):
+    catalog: "Catalog"
+
+    def hash_inputs(self) -> str:
+        return hashlib.sha256(b"regenerate_system_columns").hexdigest()
+
+    def apply(
+        self, query_generator: QueryGenerator, temp_tables: list[str]
+    ) -> StepResult:
+        query = query_generator.select()
+        new_query = self.catalog.warehouse._regenerate_system_columns(
+            query, keep_existing_columns=True
+        )
+
+        def q(*columns):
+            return new_query.with_only_columns(*columns)
+
+        return step_result(q, new_query.selected_columns)
+
+
 @frozen
 class SQLSelect(SQLClause):
-    args: tuple[
+    args: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query) -> Select:
         subquery = query.subquery()
@@ -748,7 +847,10 @@ class SQLSelect(SQLClause):

 @frozen
 class SQLSelectExcept(SQLClause):
-    args: tuple[
+    args: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query: Select) -> Select:
         subquery = query.subquery()
@@ -758,33 +860,43 @@ class SQLSelectExcept(SQLClause):

 @frozen
 class SQLMutate(SQLClause):
-    args: tuple[
+    args: tuple[Label, ...]
+    new_schema: SignalSchema
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query: Select) -> Select:
         original_subquery = query.subquery()
-
-            original_subquery.c[str(c)] if isinstance(c, (str, C)) else c
-            for c in self.parse_cols(self.args)
-        ]
-        to_mutate = {c.name for c in args}
+        to_mutate = {c.name for c in self.args}

-
-
-
+        # Drop the original versions to avoid name collisions, exclude renamed
+        # columns. Always keep system columns (sys__*) if they exist in original query
+        new_schema_columns = set(self.new_schema.db_signals())
+        base_cols = [
+            c
             for c in original_subquery.c
+            if c.name not in to_mutate
+            and (c.name in new_schema_columns or c.name.startswith("sys__"))
         ]
-
-        #
-
-
+
+        # Create intermediate subquery to properly handle window functions
+        intermediate_query = sqlalchemy.select(*base_cols, *self.args).select_from(
+            original_subquery
         )
+        intermediate_subquery = intermediate_query.subquery()

-        return sqlalchemy.select(*
+        return sqlalchemy.select(*intermediate_subquery.c).select_from(
+            intermediate_subquery
+        )


 @frozen
 class SQLFilter(SQLClause):
-    expressions: tuple[
+    expressions: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.expressions)

     def __and__(self, other):
         expressions = self.parse_cols(self.expressions)
@@ -797,7 +909,10 @@ class SQLFilter(SQLClause):

 @frozen
 class SQLOrderBy(SQLClause):
-    args: tuple[
+    args: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query: Select) -> Select:
         args = self.parse_cols(self.args)
@@ -808,6 +923,9 @@ class SQLOrderBy(SQLClause):
 class SQLLimit(SQLClause):
     n: int

+    def hash_inputs(self) -> str:
+        return hashlib.sha256(str(self.n).encode()).hexdigest()
+
     def apply_sql_clause(self, query: Select) -> Select:
         return query.limit(self.n)

@@ -816,12 +934,18 @@ class SQLLimit(SQLClause):
 class SQLOffset(SQLClause):
     offset: int

+    def hash_inputs(self) -> str:
+        return hashlib.sha256(str(self.offset).encode()).hexdigest()
+
     def apply_sql_clause(self, query: "GenerativeSelect"):
         return query.offset(self.offset)


 @frozen
 class SQLCount(SQLClause):
+    def hash_inputs(self) -> str:
+        return ""
+
     def apply_sql_clause(self, query):
         return sqlalchemy.select(f.count(1)).select_from(query.subquery())

@@ -831,6 +955,9 @@ class SQLDistinct(SQLClause):
     args: tuple[ColumnElement, ...]
     dialect: str

+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)
+
     def apply_sql_clause(self, query):
         if self.dialect == "sqlite":
             return query.group_by(*self.args)
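Each SQL clause's `hash_inputs` delegates to `hash_column_elements` from the new `datachain/hash_utils.py`, whose implementation is not shown in this diff. One plausible way to digest SQLAlchemy expressions deterministically is to compile them to SQL text first; the sketch below is only an assumption for illustration, not the real helper:

```python
import hashlib

import sqlalchemy as sa


def hash_column_elements_sketch(elements) -> str:
    # Assumption: the real datachain.hash_utils.hash_column_elements may work
    # differently. Compiling with literal binds gives a stable textual form.
    parts = [
        str(e.compile(compile_kwargs={"literal_binds": True})) for e in elements
    ]
    return hashlib.sha256("|".join(parts).encode()).hexdigest()


size = sa.column("size")
print(hash_column_elements_sketch([size > 100, size + 1]))
```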
@@ -843,24 +970,34 @@ class SQLUnion(Step):
     query1: "DatasetQuery"
     query2: "DatasetQuery"

+    def hash_inputs(self) -> str:
+        return hashlib.sha256(
+            bytes.fromhex(self.query1.hash()) + bytes.fromhex(self.query2.hash())
+        ).hexdigest()
+
     def apply(
         self, query_generator: QueryGenerator, temp_tables: list[str]
     ) -> StepResult:
+        left_before = len(self.query1.temp_table_names)
         q1 = self.query1.apply_steps().select().subquery()
-        temp_tables.extend(self.query1.temp_table_names)
+        temp_tables.extend(self.query1.temp_table_names[left_before:])
+        right_before = len(self.query2.temp_table_names)
         q2 = self.query2.apply_steps().select().subquery()
-        temp_tables.extend(self.query2.temp_table_names)
+        temp_tables.extend(self.query2.temp_table_names[right_before:])

-        columns1
+        columns1 = _drop_system_columns(q1.columns)
+        columns2 = _drop_system_columns(q2.columns)
+        columns1, columns2 = _order_columns(columns1, columns2)

         def q(*columns):
-
-            col1 = [c for c in columns1 if c.name in
-            col2 = [c for c in columns2 if c.name in
-
+            selected_names = [c.name for c in columns]
+            col1 = [c for c in columns1 if c.name in selected_names]
+            col2 = [c for c in columns2 if c.name in selected_names]
+            union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))

-
-
+            union_cte = union_query.cte()
+            select_cols = [union_cte.c[name] for name in selected_names]
+            return sqlalchemy.select(*select_cols)

         return step_result(
             q,
@@ -874,14 +1011,42 @@ class SQLJoin(Step):
     catalog: "Catalog"
     query1: "DatasetQuery"
     query2: "DatasetQuery"
-    predicates:
+    predicates: JoinPredicateType | tuple[JoinPredicateType, ...]
     inner: bool
     full: bool
     rname: str

+    @staticmethod
+    def _split_db_name(name: str) -> tuple[str, str]:
+        if DEFAULT_DELIMITER in name:
+            head, tail = name.split(DEFAULT_DELIMITER, 1)
+            return head, tail
+        return name, ""
+
+    @classmethod
+    def _root_name(cls, name: str) -> str:
+        return cls._split_db_name(name)[0]
+
+    def hash_inputs(self) -> str:
+        predicates = (
+            ensure_sequence(self.predicates) if self.predicates is not None else []
+        )
+
+        parts = [
+            bytes.fromhex(self.query1.hash()),
+            bytes.fromhex(self.query2.hash()),
+            bytes.fromhex(hash_column_elements(predicates)),
+            str(self.inner).encode(),
+            str(self.full).encode(),
+            self.rname.encode("utf-8"),
+        ]
+
+        return hashlib.sha256(b"".join(parts)).hexdigest()
+
     def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
+        temp_tables_before = len(dq.temp_table_names)
         query = dq.apply_steps().select()
-        temp_tables.extend(dq.temp_table_names)
+        temp_tables.extend(dq.temp_table_names[temp_tables_before:])

         if not any(isinstance(step, (SQLJoin, SQLUnion)) for step in dq.steps):
             return query.subquery(dq.table.name)
@@ -937,22 +1102,39 @@ class SQLJoin(Step):
         q1 = self.get_query(self.query1, temp_tables)
         q2 = self.get_query(self.query2, temp_tables)

-        q1_columns =
-
-
-
-        for
-            if
+        q1_columns = _drop_system_columns(q1.c)
+        existing_column_names = {c.name for c in q1_columns}
+        right_columns: list[KeyedColumnElement[Any]] = []
+        right_column_names: list[str] = []
+        for column in q2.c:
+            if column.name.startswith("sys__"):
                 continue
+            right_columns.append(column)
+            right_column_names.append(column.name)
+
+        root_mapping = generate_merge_root_mapping(
+            existing_column_names,
+            right_column_names,
+            extract_root=self._root_name,
+            prefix=self.rname,
+        )
+
+        q2_columns: list[KeyedColumnElement[Any]] = []
+        for column in right_columns:
+            original_name = column.name
+            column_root, column_tail = self._split_db_name(original_name)
+            mapped_root = root_mapping[column_root]
+
+            new_name = (
+                mapped_root
+                if not column_tail
+                else DEFAULT_DELIMITER.join([mapped_root, column_tail])
+            )
+
+            if new_name != original_name:
+                column = column.label(new_name)

-
-            new_name = self.rname.format(name=c.name)
-            new_name_idx = 0
-            while new_name in q1_column_names:
-                new_name_idx += 1
-                new_name = self.rname.format(name=f"{c.name}_{new_name_idx}")
-            c = c.label(new_name)
-            q2_columns.append(c)
+            q2_columns.append(column)

         res_columns = q1_columns + q2_columns
         predicates = (
@@ -997,8 +1179,15 @@ class SQLJoin(Step):

 @frozen
 class SQLGroupBy(SQLClause):
-    cols: Sequence[
-    group_by: Sequence[
+    cols: Sequence[str | Function | ColumnElement]
+    group_by: Sequence[str | Function | ColumnElement]
+
+    def hash_inputs(self) -> str:
+        return hashlib.sha256(
+            bytes.fromhex(
+                hash_column_elements(self.cols) + hash_column_elements(self.group_by)
+            )
+        ).hexdigest()

     def apply_sql_clause(self, query) -> Select:
         if not self.cols:
@@ -1010,58 +1199,70 @@ class SQLGroupBy(SQLClause):
             c.get_column() if isinstance(c, Function) else c for c in self.group_by
         ]

-
-
-            if isinstance(c, Function)
-
-
-
-
-
+        cols_dict: dict[str, Any] = {}
+        for c in (*group_by, *self.cols):
+            if isinstance(c, Function):
+                key = c.name
+                value = c.get_column()
+            elif isinstance(c, (str, C)):
+                key = str(c)
+                value = subquery.c[str(c)]
+            else:
+                key = c.name
+                value = c  # type: ignore[assignment]
+            cols_dict[key] = value

-
+        unique_cols = cols_dict.values()

+        return sqlalchemy.select(*unique_cols).select_from(subquery).group_by(*group_by)

-def _validate_columns(
-    left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
-) -> set[str]:
-    left_names = {c.name for c in left_columns}
-    right_names = {c.name for c in right_columns}
-
-    if left_names == right_names:
-        return left_names
-
-    missing_right = left_names - right_names
-    missing_left = right_names - left_names
-
-    def _prepare_msg_part(missing_columns: set[str], side: str) -> str:
-        return f"{', '.join(sorted(missing_columns))} only present in {side}"
-
-    msg_parts = [
-        _prepare_msg_part(missing_columns, found_side)
-        for missing_columns, found_side in zip(
-            [
-                missing_right,
-                missing_left,
-            ],
-            ["left", "right"],
-        )
-        if missing_columns
-    ]
-    msg = f"Cannot perform union. {'. '.join(msg_parts)}"

-
+class UnionSchemaMismatchError(ValueError):
+    """Union input columns mismatch."""
+
+    @classmethod
+    def from_column_sets(
+        cls,
+        missing_left: set[str],
+        missing_right: set[str],
+    ) -> "UnionSchemaMismatchError":
+        def _describe(cols: set[str], side: str) -> str:
+            return f"{', '.join(sorted(cols))} only present in {side}"
+
+        parts = []
+        if missing_left:
+            parts.append(_describe(missing_left, "left"))
+        if missing_right:
+            parts.append(_describe(missing_right, "right"))
+
+        return cls(f"Cannot perform union. {'. '.join(parts)}")


 def _order_columns(
     left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
 ) -> list[list[ColumnElement]]:
-
+    left_names = [c.name for c in left_columns]
+    right_names = [c.name for c in right_columns]
+
+    # validate
+    if sorted(left_names) != sorted(right_names):
+        left_names_set = set(left_names)
+        right_names_set = set(right_names)
+        raise UnionSchemaMismatchError.from_column_sets(
+            left_names_set - right_names_set,
+            right_names_set - left_names_set,
+        )
+
+    # Order columns to match left_names order
     column_dicts = [
         {c.name: c for c in columns} for columns in [left_columns, right_columns]
     ]

-    return [[d[n] for n in
+    return [[d[n] for n in left_names] for d in column_dicts]
+
+
+def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
+    return [c for c in columns if not c.name.startswith("sys__")]


 @attrs.define
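Unions now drop `sys__*` columns from both sides, raise the new `UnionSchemaMismatchError` when the remaining column names differ, and emit both sides in the left-hand column order so positional `UNION ALL` pairs the same signals. The same idea, reduced to plain name-to-expression mappings (a sketch, not the SQLAlchemy code above):

```python
def order_union_columns(
    left: dict[str, str], right: dict[str, str]
) -> tuple[list[str], list[str]]:
    # Drop system columns, as _drop_system_columns does in the diff above.
    left_cols = {k: v for k, v in left.items() if not k.startswith("sys__")}
    right_cols = {k: v for k, v in right.items() if not k.startswith("sys__")}
    if sorted(left_cols) != sorted(right_cols):
        missing = sorted(set(left_cols) ^ set(right_cols))
        raise ValueError(f"Cannot perform union. Mismatched columns: {missing}")
    # Both sides come back in the left-hand order so a positional UNION ALL
    # lines up the same column on both sides.
    order = list(left_cols)
    return [left_cols[n] for n in order], [right_cols[n] for n in order]


left = {"sys__id": "l.id", "file__path": "l.path", "size": "l.size"}
right = {"size": "r.size", "file__path": "r.path", "sys__rand": "r.rand"}
print(order_union_columns(left, right))
```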
@@ -1077,62 +1278,71 @@ class DatasetQuery:
     def __init__(
         self,
         name: str,
-        version:
-
-
-
+        version: str | None = None,
+        project_name: str | None = None,
+        namespace_name: str | None = None,
+        catalog: "Catalog | None" = None,
+        session: Session | None = None,
         in_memory: bool = False,
-        fallback_to_studio: bool = True,
         update: bool = False,
     ) -> None:
-        from datachain.remote.studio import is_token_set
-
         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
-        self._chunk_index:
-        self._chunk_total:
+        self._chunk_index: int | None = None
+        self._chunk_total: int | None = None
         self.temp_table_names: list[str] = []
         self.dependencies: set[DatasetDependencyType] = set()
         self.table = self.get_table()
-        self.starting_step:
-        self.name:
-        self.version:
-        self.feature_schema:
-        self.column_types:
+        self.starting_step: QueryStep | None = None
+        self.name: str | None = None
+        self.version: str | None = None
+        self.feature_schema: dict | None = None
+        self.column_types: dict[str, Any] | None = None
         self.before_steps: list[Callable] = []
-        self.listing_fn:
+        self.listing_fn: Callable | None = None
         self.update = update

-        self.list_ds_name:
+        self.list_ds_name: str | None = None

         self.name = name
         self.dialect = self.catalog.warehouse.db.dialect
         if version:
             self.version = version

-        if
+        if namespace_name is None:
+            namespace_name = self.catalog.metastore.default_namespace_name
+        if project_name is None:
+            project_name = self.catalog.metastore.default_project_name
+
+        if is_listing_dataset(name) and not version:
             # not setting query step yet as listing dataset might not exist at
             # this point
             self.list_ds_name = name
-
+        else:
             self._set_starting_step(
-                self.catalog.get_dataset_with_remote_fallback(
+                self.catalog.get_dataset_with_remote_fallback(
+                    name,
+                    namespace_name=namespace_name,
+                    project_name=project_name,
+                    version=version,
+                    pull_dataset=True,
+                    update=update,
+                )
             )
-        else:
-            self._set_starting_step(self.catalog.get_dataset(name))

     def _set_starting_step(self, ds: "DatasetRecord") -> None:
         if not self.version:
             self.version = ds.latest_version

-        self.starting_step = QueryStep(self.catalog, ds
+        self.starting_step = QueryStep(self.catalog, ds, self.version)

         # at this point we know our starting dataset so setting up schemas
         self.feature_schema = ds.get_version(self.version).feature_schema
         self.column_types = copy(ds.schema)
         if "sys__id" in self.column_types:
             self.column_types.pop("sys__id")
+        self.project = ds.project

     def __iter__(self):
         return iter(self.db_results())
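`DatasetQuery.__init__` now takes a semver version string plus optional namespace and project names (falling back to the metastore defaults) and always resolves the dataset through `get_dataset_with_remote_fallback`; the Studio-specific `fallback_to_studio` flag is gone. A usage sketch, assuming a dataset named "dogs" already exists; the namespace and project names are illustrative:

```python
from datachain.query.dataset import DatasetQuery

# Versions are semver strings now; namespace_name/project_name are optional
# and default to the metastore's defaults when omitted.
dq = DatasetQuery(
    "dogs",
    version="1.0.0",
    namespace_name="local",   # illustrative
    project_name="default",   # illustrative
)
print(dq.attached)
```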
@@ -1140,39 +1350,28 @@ class DatasetQuery:
     def __or__(self, other):
         return self.union(other)

-    def
-
-
-
-
-
+    def hash(self) -> str:
+        """
+        Calculates hash of this class taking into account hash of starting step
+        and hashes of each following steps. Ordering is important.
+        """
+        hasher = hashlib.sha256()
+        if self.starting_step:
+            hasher.update(self.starting_step.hash().encode("utf-8"))
+        else:
+            assert self.list_ds_name
+            hasher.update(self.list_ds_name.encode("utf-8"))

-        self.
-
-            local_ds_name=name,
-            local_ds_version=version,
-        )
+        for step in self.steps:
+            hasher.update(step.hash().encode("utf-8"))

-        return
+        return hasher.hexdigest()

     @staticmethod
     def get_table() -> "TableClause":
-        table_name = "".join(
-            random.choice(string.ascii_letters)  # noqa: S311
-            for _ in range(16)
-        )
+        table_name = "".join(secrets.choice(string.ascii_letters) for _ in range(16))
         return sqlalchemy.table(table_name)

-    @staticmethod
-    def delete(
-        name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None
-    ) -> None:
-        from datachain.catalog import get_catalog
-
-        catalog = catalog or get_catalog()
-        version = version or catalog.get_dataset(name).latest_version
-        catalog.remove_dataset(name, version)
-
     @property
     def attached(self) -> bool:
         """
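`DatasetQuery.hash()` chains the starting step's digest with every subsequent step's digest, in order, so the result identifies the whole pipeline (the `process_input_query` docstring earlier in this diff ties these hashes to checkpoints). A worked sketch of the folding order; the starting-step string below is a placeholder, not datachain's real dataset URI format:

```python
import hashlib


def chain_hash(starting_digest: str, step_hashes: list[str]) -> str:
    # Same folding order as DatasetQuery.hash(): starting step first,
    # then each step hash in pipeline order.
    hasher = hashlib.sha256()
    hasher.update(starting_digest.encode("utf-8"))
    for step_hash in step_hashes:
        hasher.update(step_hash.encode("utf-8"))
    return hasher.hexdigest()


a = chain_hash("dogs@1.0.0", ["filter-digest", "map-digest"])
b = chain_hash("dogs@1.0.0", ["map-digest", "filter-digest"])
print(a != b)  # reordering the steps changes the pipeline hash
```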
@@ -1180,14 +1379,14 @@ class DatasetQuery:
|
|
|
1180
1379
|
it completely. If this is the case, name and version of underlying dataset
|
|
1181
1380
|
will be defined.
|
|
1182
1381
|
DatasetQuery instance can become attached in two scenarios:
|
|
1183
|
-
1. ds = DatasetQuery(name="dogs", version=1) -> ds is attached to dogs
|
|
1184
|
-
2. ds = ds.save("dogs", version=1) -> ds is attached to dogs dataset
|
|
1382
|
+
1. ds = DatasetQuery(name="dogs", version="1.0.0") -> ds is attached to dogs
|
|
1383
|
+
2. ds = ds.save("dogs", version="1.0.0") -> ds is attached to dogs dataset
|
|
1185
1384
|
It can move to detached state if filter or similar methods are called on it,
|
|
1186
1385
|
as then it no longer 100% represents underlying datasets.
|
|
1187
1386
|
"""
|
|
1188
1387
|
return self.name is not None and self.version is not None
|
|
1189
1388
|
|
|
1190
|
-
def c(self, column:
|
|
1389
|
+
def c(self, column: C | str) -> "ColumnClause[Any]":
|
|
1191
1390
|
col: sqlalchemy.ColumnClause = (
|
|
1192
1391
|
sqlalchemy.column(column)
|
|
1193
1392
|
if isinstance(column, str)
|
|
@@ -1200,11 +1399,8 @@ class DatasetQuery:
|
|
|
1200
1399
|
"""Setting listing function to be run if needed"""
|
|
1201
1400
|
self.listing_fn = fn
|
|
1202
1401
|
|
|
1203
|
-
def
|
|
1204
|
-
"""
|
|
1205
|
-
Apply the steps in the query and return the resulting
|
|
1206
|
-
sqlalchemy.SelectBase.
|
|
1207
|
-
"""
|
|
1402
|
+
def apply_listing_pre_step(self) -> None:
|
|
1403
|
+
"""Runs listing pre-step if needed"""
|
|
1208
1404
|
if self.list_ds_name and not self.starting_step:
|
|
1209
1405
|
listing_ds = None
|
|
1210
1406
|
try:
|
|
@@ -1220,6 +1416,13 @@ class DatasetQuery:
|
|
|
1220
1416
|
# at this point we know what is our starting listing dataset name
|
|
1221
1417
|
self._set_starting_step(listing_ds) # type: ignore [arg-type]
|
|
1222
1418
|
|
|
1419
|
+
def apply_steps(self) -> QueryGenerator:
|
|
1420
|
+
"""
|
|
1421
|
+
Apply the steps in the query and return the resulting
|
|
1422
|
+
sqlalchemy.SelectBase.
|
|
1423
|
+
"""
|
|
1424
|
+
self.apply_listing_pre_step()
|
|
1425
|
+
|
|
1223
1426
|
query = self.clone()
|
|
1224
1427
|
|
|
1225
1428
|
index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)
|
|
@@ -1278,6 +1481,7 @@ class DatasetQuery:
|
|
|
1278
1481
|
# This is needed to always use a new connection with all metastore and warehouse
|
|
1279
1482
|
# implementations, as errors may close or render unusable the existing
|
|
1280
1483
|
# connections.
|
|
1484
|
+
assert len(self.temp_table_names) == len(set(self.temp_table_names))
|
|
1281
1485
|
with self.catalog.metastore.clone(use_new_connection=True) as metastore:
|
|
1282
1486
|
metastore.cleanup_tables(self.temp_table_names)
|
|
1283
1487
|
with self.catalog.warehouse.clone(use_new_connection=True) as warehouse:
|
|
@@ -1292,7 +1496,7 @@ class DatasetQuery:
|
|
|
1292
1496
|
return list(result)
|
|
1293
1497
|
|
|
1294
1498
|
def to_db_records(self) -> list[dict[str, Any]]:
|
|
1295
|
-
return self.db_results(lambda cols, row: dict(zip(cols, row)))
|
|
1499
|
+
return self.db_results(lambda cols, row: dict(zip(cols, row, strict=False)))
|
|
1296
1500
|
|
|
1297
1501
|
@contextlib.contextmanager
|
|
1298
1502
|
def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
|
|
@@ -1331,8 +1535,8 @@ class DatasetQuery:
|
|
|
1331
1535
|
yield from rows
|
|
1332
1536
|
|
|
1333
1537
|
async def get_params(row: Sequence) -> tuple:
|
|
1334
|
-
row_dict = RowDict(zip(query_fields, row))
|
|
1335
|
-
return tuple(
|
|
1538
|
+
row_dict = RowDict(zip(query_fields, row, strict=False))
|
|
1539
|
+
return tuple( # noqa: C409
|
|
1336
1540
|
[
|
|
1337
1541
|
await p.get_value_async(
|
|
1338
1542
|
self.catalog, row_dict, mapper, **kwargs
|
|
@@ -1348,10 +1552,6 @@ class DatasetQuery:
|
|
|
1348
1552
|
finally:
|
|
1349
1553
|
self.cleanup()
|
|
1350
1554
|
|
|
1351
|
-
def shuffle(self) -> "Self":
|
|
1352
|
-
# ToDo: implement shaffle based on seed and/or generating random column
|
|
1353
|
-
return self.order_by(C.sys__rand)
|
|
1354
|
-
|
|
1355
1555
|
def sample(self, n) -> "Self":
|
|
1356
1556
|
"""
|
|
1357
1557
|
Return a random sample from the dataset.
|
|
@@ -1371,6 +1571,7 @@ class DatasetQuery:
|
|
|
1371
1571
|
obj.steps = obj.steps.copy()
|
|
1372
1572
|
if new_table:
|
|
1373
1573
|
obj.table = self.get_table()
|
|
1574
|
+
obj.temp_table_names = []
|
|
1374
1575
|
return obj
|
|
1375
1576
|
|
|
1376
1577
|
@detach
|
|
@@ -1441,7 +1642,7 @@ class DatasetQuery:
         return query
 
     @detach
-    def mutate(self, *args, **kwargs) -> "Self":
+    def mutate(self, *args, new_schema, **kwargs) -> "Self":
         """
         Add new columns to this query.
 
@@ -1453,7 +1654,7 @@ class DatasetQuery:
         """
         query_args = [v.label(k) for k, v in dict(args, **kwargs).items()]
         query = self.clone()
-        query.steps.append(SQLMutate((*query_args,)))
+        query.steps.append(SQLMutate((*query_args,), new_schema))
         return query
 
     @detach
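With the two hunks above, `mutate` takes a required keyword-only `new_schema` and threads it into the `SQLMutate` step. A toy sketch of the keyword-only pattern (illustrative names, not the real `SQLMutate` or schema types):

```python
# Hypothetical stand-in for the new signature: new_schema is keyword-only and
# required, so call sites must pass the updated schema alongside the columns.
def mutate(*args, new_schema, **kwargs):
    query_args = [f"{expr} AS {name}" for name, expr in dict(args, **kwargs).items()]
    return ("SQLMutate", tuple(query_args), new_schema)


step = mutate(size_kb="size / 1024", new_schema={"size_kb": "float"})
print(step)
# ('SQLMutate', ('size / 1024 AS size_kb',), {'size_kb': 'float'})

# mutate(size_kb="size / 1024") would now fail with:
# TypeError: mutate() missing 1 required keyword-only argument: 'new_schema'
```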
@@ -1551,10 +1752,10 @@ class DatasetQuery:
     def join(
         self,
         dataset_query: "DatasetQuery",
-        predicates:
+        predicates: JoinPredicateType | Sequence[JoinPredicateType],
         inner=False,
         full=False,
-        rname="
+        rname="right_",
     ) -> "Self":
         left = self.clone(new_table=False)
         if self.table.name == dataset_query.table.name:
@@ -1593,11 +1794,17 @@ class DatasetQuery:
     def add_signals(
         self,
         udf: "UDFAdapter",
-
-
-        min_task_size: Optional[int] = None,
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
+        # Parameters from Settings
         cache: bool = False,
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
+        # Parameters are unused, kept only to match the signature of Settings.to_dict
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         """
         Adds one or more signals based on the results from the provided UDF.
@@ -1623,6 +1830,7 @@ class DatasetQuery:
                 workers=workers,
                 min_task_size=min_task_size,
                 cache=cache,
+                batch_size=batch_size,
             )
         )
         return query
@@ -1637,11 +1845,17 @@ class DatasetQuery:
     def generate(
         self,
         udf: "UDFAdapter",
-
-
-        min_task_size: Optional[int] = None,
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
+        # Parameters from Settings
         cache: bool = False,
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
+        # Parameters are unused, kept only to match the signature of Settings.to_dict:
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         query = self.clone()
         steps = query.steps
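Both `add_signals` and `generate` now accept the full set of settings keywords, including ones they ignore (`prefetch`, `namespace`, `project`), so a settings mapping can be forwarded unchanged. A sketch under the assumption that the mapping mirrors the parameter names shown above (the dict shape here is illustrative):

```python
# Keeping the unused keywords in the signature lets callers splat a settings
# dict wholesale instead of filtering it per method.
def generate(udf, *, partition_by=None, cache=False, parallel=None, workers=False,
             min_task_size=None, batch_size=None, prefetch=None, namespace=None,
             project=None):
    return {"udf": udf, "parallel": parallel, "batch_size": batch_size, "cache": cache}


settings = {
    "cache": True,
    "parallel": 4,
    "batch_size": 1000,
    "prefetch": 2,        # accepted but unused by generate()
    "namespace": "dev",   # accepted but unused by generate()
}
print(generate("embed_udf", **settings))
```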
@@ -1654,41 +1868,84 @@ class DatasetQuery:
                 workers=workers,
                 min_task_size=min_task_size,
                 cache=cache,
+                batch_size=batch_size,
             )
         )
         return query
 
-    def _add_dependencies(self, dataset: "DatasetRecord", version:
-
-
+    def _add_dependencies(self, dataset: "DatasetRecord", version: str):
+        dependencies: set[DatasetDependencyType] = set()
+        for dep_dataset, dep_dataset_version in self.dependencies:
+            if Session.is_temp_dataset(dep_dataset.name):
+                # temp dataset are created for optimization and they will be removed
+                # afterwards. Therefore, we should not put them as dependencies, but
+                # their own direct dependencies
+                for dep in self.catalog.get_dataset_dependencies(
+                    dep_dataset.name,
+                    dep_dataset_version,
+                    namespace_name=dep_dataset.project.namespace.name,
+                    project_name=dep_dataset.project.name,
+                    indirect=False,
+                ):
+                    if dep:
+                        dependencies.add(
+                            (
+                                self.catalog.get_dataset(
+                                    dep.name,
+                                    namespace_name=dep.namespace,
+                                    project_name=dep.project,
+                                ),
+                                dep.version,
+                            )
+                        )
+            else:
+                dependencies.add((dep_dataset, dep_dataset_version))
+
+        for dep_dataset, dep_dataset_version in dependencies:
             self.catalog.metastore.add_dataset_dependency(
-                dataset
+                dataset,
                 version,
-
-
+                dep_dataset,
+                dep_dataset_version,
             )
 
     def exec(self) -> "Self":
         """Execute the query."""
+        query = self.clone()
         try:
-            query = self.clone()
             query.apply_steps()
         finally:
-
+            query.cleanup()
         return query
 
     def save(
         self,
-        name:
-        version:
-
-
-
+        name: str | None = None,
+        version: str | None = None,
+        project: Project | None = None,
+        feature_schema: dict | None = None,
+        dependencies: list[DatasetDependency] | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
         **kwargs,
     ) -> "Self":
         """Save the query as a dataset."""
+        # Get job from session to link dataset version to job
+        job = self.session.get_or_create_job()
+        job_id = job.id
+
+        project = project or self.catalog.metastore.default_project
         try:
-            if
+            if (
+                name
+                and version
+                and self.catalog.get_dataset(
+                    name,
+                    namespace_name=project.namespace.name,
+                    project_name=project.name,
+                ).has_version(version)
+            ):
                 raise RuntimeError(f"Dataset {name} already has version {version}")
         except DatasetNotFoundError:
             pass
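`_add_dependencies` now flattens dependencies that point at temporary (per-session) datasets into those datasets' own direct dependencies, so only persistent datasets get recorded. A standalone sketch of that resolution step (the temp-name marker and lookup table are illustrative, not the real `Session`/catalog API):

```python
# Toy version of the flattening above: temp datasets are swapped for their own
# direct dependencies; everything else is kept as-is.
TEMP_PREFIX = "session_"  # assumed marker for temporary dataset names
DIRECT_DEPS = {"session_tmp_ab12": [("cats", "1.0.0"), ("dogs", "2.1.0")]}


def resolve_dependencies(deps):
    resolved = set()
    for name, version in deps:
        if name.startswith(TEMP_PREFIX):
            resolved.update(DIRECT_DEPS.get(name, []))
        else:
            resolved.add((name, version))
    return resolved


print(resolve_dependencies([("session_tmp_ab12", "1.0.0"), ("labels", "3.0.0")]))
# -> cats 1.0.0 and dogs 2.1.0 replace the temp dataset; labels 3.0.0 is kept
```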
@@ -1713,19 +1970,18 @@ class DatasetQuery:
 
             dataset = self.catalog.create_dataset(
                 name,
+                project,
                 version=version,
                 feature_schema=feature_schema,
                 columns=columns,
                 description=description,
-
+                attrs=attrs,
+                update_version=update_version,
+                job_id=job_id,
                 **kwargs,
             )
             version = version or dataset.latest_version
 
-            self.session.add_dataset_version(
-                dataset=dataset, version=version, listing=kwargs.get("listing", False)
-            )
-
             dr = self.catalog.warehouse.dataset_rows(dataset)
 
             self.catalog.warehouse.copy_table(dr.get_table(), query.select())
@@ -1735,15 +1991,41 @@ class DatasetQuery:
             )
             self.catalog.update_dataset_version_with_warehouse_info(dataset, version)
 
+            # Link this dataset version to the job that created it
+            self.catalog.metastore.link_dataset_version_to_job(
+                dataset.get_version(version).id, job_id, is_creator=True
+            )
+
+            if dependencies:
+                # overriding dependencies
+                self.dependencies = set()
+                for dep in dependencies:
+                    self.dependencies.add(
+                        (
+                            self.catalog.get_dataset(
+                                dep.name,
+                                namespace_name=dep.namespace,
+                                project_name=dep.project,
+                            ),
+                            dep.version,
+                        )
+                    )
+
             self._add_dependencies(dataset, version)  # type: ignore [arg-type]
         finally:
             self.cleanup()
-        return self.__class__(
+        return self.__class__(
+            name=name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+            version=version,
+            catalog=self.catalog,
+        )
 
     @property
     def is_ordered(self) -> bool:
         return isinstance(self.last_step, SQLOrderBy)
 
     @property
-    def last_step(self) ->
+    def last_step(self) -> Step | None:
         return self.steps[-1] if self.steps else None