datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/dataset.py
CHANGED
@@ -1,25 +1,18 @@
 import contextlib
+import hashlib
 import inspect
 import logging
 import os
-import random
+import secrets
 import string
 import subprocess
 import sys
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
 from copy import copy
 from functools import wraps
 from types import GeneratorType
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Optional,
-    Protocol,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar

 import attrs
 import sqlalchemy
@@ -44,20 +37,21 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetDependency, DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
+from datachain.hash_utils import hash_column_elements
 from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
-from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.signal_schema import SignalSchema, generate_merge_root_mapping
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.project import Project
-from datachain.query.schema import C, UDFParamSpec, normalize_param
+from datachain.query.schema import DEFAULT_DELIMITER, C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
 from datachain.query.udf import UdfInfo
 from datachain.sql.functions.random import rand
 from datachain.sql.types import SQLType
 from datachain.utils import (
-    batched,
     determine_processes,
     determine_workers,
+    ensure_sequence,
     filtered_cloudpickle_dumps,
     get_datachain_executable,
     safe_closing,
@@ -65,11 +59,12 @@ from datachain.utils import (

 if TYPE_CHECKING:
     from collections.abc import Mapping
+    from typing import Concatenate

-    from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql.elements import ClauseElement, KeyedColumnElement
     from sqlalchemy.sql.schema import Table
     from sqlalchemy.sql.selectable import GenerativeSelect
-    from typing_extensions import
+    from typing_extensions import ParamSpec, Self

     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
@@ -81,13 +76,10 @@ if TYPE_CHECKING:

 INSERT_BATCH_SIZE = 10000

-PartitionByType = Union[
-    str,
-    Function,
-    ColumnElement,
-    Sequence[Union[str, Function, ColumnElement]],
-]
-JoinPredicateType = Union[str, ColumnClause, ColumnElement]
+PartitionByType = (
+    str | Function | ColumnElement | Sequence[str | Function | ColumnElement]
+)
+JoinPredicateType = str | ColumnClause | ColumnElement
 DatasetDependencyType = tuple["DatasetRecord", str]

 logger = logging.getLogger("datachain")
@@ -168,6 +160,18 @@ class Step(ABC):
     ) -> "StepResult":
         """Apply the processing step."""

+    @abstractmethod
+    def hash_inputs(self) -> str:
+        """Calculates hash of step inputs"""
+
+    def hash(self) -> str:
+        """
+        Calculates hash for step which includes step name and hash of it's inputs
+        """
+        return hashlib.sha256(
+            f"{self.__class__.__name__}|{self.hash_inputs()}".encode()
+        ).hexdigest()
+

 @frozen
 class QueryStep:
@@ -187,6 +191,11 @@ class QueryStep:
             q, dr.columns, dependencies=[(self.dataset, self.dataset_version)]
         )

+    def hash(self) -> str:
+        return hashlib.sha256(
+            self.dataset.uri(self.dataset_version).encode()
+        ).hexdigest()
+

 def generator_then_call(generator, func: Callable):
     """
@@ -222,8 +231,9 @@ class DatasetDiffOperation(Step):

     def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
         source_query = query_generator.exclude(("sys__id",))
+        right_before = len(self.dq.temp_table_names)
         target_query = self.dq.apply_steps().select()
-        temp_tables.extend(self.dq.temp_table_names)
+        temp_tables.extend(self.dq.temp_table_names[right_before:])

         # creating temp table that will hold subtract results
         temp_table_name = self.catalog.warehouse.temp_table_name()
@@ -257,6 +267,13 @@ class DatasetDiffOperation(Step):
 class Subtract(DatasetDiffOperation):
     on: Sequence[tuple[str, str]]

+    def hash_inputs(self) -> str:
+        on_bytes = b"".join(
+            f"{a}:{b}".encode() for a, b in sorted(self.on, key=lambda t: (t[0], t[1]))
+        )
+
+        return hashlib.sha256(bytes.fromhex(self.dq.hash()) + on_bytes).hexdigest()
+
     def query(self, source_query: Select, target_query: Select) -> sa.Selectable:
         sq = source_query.alias("source_query")
         tq = target_query.alias("target_query")
@@ -334,10 +351,10 @@ def process_udf_outputs(
     udf_results: Iterator[Iterable["UDFResult"]],
     udf: "UDFAdapter",
     cb: Callback = DEFAULT_CALLBACK,
+    batch_size: int = INSERT_BATCH_SIZE,
 ) -> None:
     # Optimization: Compute row types once, rather than for every row.
     udf_col_types = get_col_types(warehouse, udf.output)
-    batch_rows = udf.batch_rows or INSERT_BATCH_SIZE

     def _insert_rows():
         for udf_output in udf_results:
@@ -349,9 +366,7 @@ def process_udf_outputs(
                 cb.relative_update()
                 yield adjust_outputs(warehouse, row, udf_col_types)

-
-        warehouse.insert_rows(udf_table, row_chunk)
-
+    warehouse.insert_rows(udf_table, _insert_rows(), batch_size=batch_size)
     warehouse.insert_rows_done(udf_table)


@@ -387,21 +402,34 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 class UDFStep(Step, ABC):
     udf: "UDFAdapter"
     catalog: "Catalog"
-    partition_by: Optional[PartitionByType] = None
-    parallel: Optional[int] = None
-    workers: Union[bool, int] = False
-    min_task_size: Optional[int] = None
+    partition_by: PartitionByType | None = None
     is_generator = False
+    # Parameters from Settings
     cache: bool = False
-
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None
+
+    def hash_inputs(self) -> str:
+        partition_by = ensure_sequence(self.partition_by or [])
+        parts = [
+            bytes.fromhex(self.udf.hash()),
+            bytes.fromhex(hash_column_elements(partition_by)),
+            str(self.is_generator).encode(),
+        ]
+
+        return hashlib.sha256(b"".join(parts)).hexdigest()

     @abstractmethod
     def create_udf_table(self, query: Select) -> "Table":
         """Method that creates a table where temp udf results will be saved"""

     def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
-        """
-
+        """Materialize inputs, ensure sys columns are available, needed for checkpoints,
+        needed for map to work (merge results)"""
+        table = self.catalog.warehouse.create_pre_udf_table(query)
+        return sqlalchemy.select(*table.c), [table]

     @abstractmethod
     def create_result_query(
@@ -450,6 +478,7 @@ class UDFStep(Step, ABC):
                 use_cache=self.cache,
                 is_generator=self.is_generator,
                 min_task_size=self.min_task_size,
+                batch_size=self.batch_size,
             )
             udf_distributor()
             return
@@ -486,6 +515,7 @@ class UDFStep(Step, ABC):
             is_generator=self.is_generator,
             cache=self.cache,
             rows_total=rows_total,
+            batch_size=self.batch_size or INSERT_BATCH_SIZE,
         )

         # Run the UDFDispatcher in another process to avoid needing
@@ -534,6 +564,7 @@ class UDFStep(Step, ABC):
                     udf_results,
                     self.udf,
                     cb=generated_cb,
+                    batch_size=self.batch_size or INSERT_BATCH_SIZE,
                 )
             finally:
                 download_cb.close()
@@ -552,13 +583,10 @@ class UDFStep(Step, ABC):
         """
         Create temporary table with group by partitions.
         """
-
-
-
-
-        assert any(c.name == "sys__id" for c in query.selected_columns), (
-            "Query must have sys__id column to use partitioning."
-        )
+        if self.partition_by is None:
+            raise RuntimeError("Query must have partition_by set to use partitioning")
+        if (id_col := query.selected_columns.get("sys__id")) is None:
+            raise RuntimeError("Query must have sys__id column to use partitioning")

         if isinstance(self.partition_by, (list, tuple, GeneratorType)):
             list_partition_by = list(self.partition_by)
@@ -574,7 +602,7 @@ class UDFStep(Step, ABC):

         # fill table with partitions
         cols = [
-
+            id_col,
             f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID),
         ]
         self.catalog.warehouse.db.execute(
@@ -586,7 +614,7 @@ class UDFStep(Step, ABC):

         return tbl

-    def clone(self, partition_by: Optional[PartitionByType] = None) -> "Self":
+    def clone(self, partition_by: PartitionByType | None = None) -> "Self":
         if partition_by is not None:
             return self.__class__(
                 self.udf,
@@ -595,41 +623,25 @@ class UDFStep(Step, ABC):
                 parallel=self.parallel,
                 workers=self.workers,
                 min_task_size=self.min_task_size,
-
+                batch_size=self.batch_size,
             )
         return self.__class__(self.udf, self.catalog)

     def apply(
         self, query_generator: QueryGenerator, temp_tables: list[str]
     ) -> "StepResult":
-
+        query, tables = self.process_input_query(query_generator.select())
+        _query = query

         # Apply partitioning if needed.
         if self.partition_by is not None:
-            if not any(c.name == "sys__id" for c in query.selected_columns):
-                # If sys__id is not in the query, we need to create a temp table
-                # to hold the query results, so we can join it with the
-                # partition table later.
-                columns = [
-                    c if isinstance(c, Column) else Column(c.name, c.type)
-                    for c in query.subquery().columns
-                ]
-                temp_table = self.catalog.warehouse.create_dataset_rows_table(
-                    self.catalog.warehouse.temp_table_name(),
-                    columns=columns,
-                )
-                temp_tables.append(temp_table.name)
-                self.catalog.warehouse.copy_table(temp_table, query)
-                _query = query = temp_table.select()
-
             partition_tbl = self.create_partitions_table(query)
-            temp_tables.append(partition_tbl.name)
             query = query.outerjoin(
                 partition_tbl,
                 partition_tbl.c.sys__id == query.selected_columns.sys__id,
             ).add_columns(*partition_columns())
+            tables = [*tables, partition_tbl]

-        query, tables = self.process_input_query(query)
         temp_tables.extend(t.name for t in tables)
         udf_table = self.create_udf_table(_query)
         temp_tables.append(udf_table.name)
@@ -641,7 +653,16 @@ class UDFStep(Step, ABC):

 @frozen
 class UDFSignal(UDFStep):
+    udf: "UDFAdapter"
+    catalog: "Catalog"
+    partition_by: PartitionByType | None = None
     is_generator = False
+    # Parameters from Settings
+    cache: bool = False
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None

     def create_udf_table(self, query: Select) -> "Table":
         udf_output_columns: list[sqlalchemy.Column[Any]] = [
@@ -651,13 +672,6 @@ class UDFSignal(UDFStep):

         return self.catalog.warehouse.create_udf_table(udf_output_columns)

-    def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
-        if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
-            return query, []
-        table = self.catalog.warehouse.create_pre_udf_table(query)
-        q: Select = sqlalchemy.select(*table.c)
-        return q, [table]
-
     def create_result_query(
         self, udf_table, query
     ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
@@ -669,11 +683,26 @@ class UDFSignal(UDFStep):
         signal_name_cols = {c.name: c for c in signal_cols}
         cols = signal_cols

-
+        original_names = {c.name for c in original_cols}
+        new_names = {c.name for c in cols}
+
+        overlap = original_names & new_names
         if overlap:
             raise ValueError(
                 "Column already exists or added in the previous steps: "
-                + ", ".join(overlap)
+                + ", ".join(sorted(overlap))
+            )
+
+        def _root(name: str) -> str:
+            return name.split(DEFAULT_DELIMITER, 1)[0]
+
+        existing_roots = {_root(name) for name in original_names}
+        new_roots = {_root(name) for name in new_names}
+        root_conflicts = existing_roots & new_roots
+        if root_conflicts:
+            raise ValueError(
+                "Signals already exist in the previous steps: "
+                + ", ".join(sorted(root_conflicts))
             )

         def q(*columns):
@@ -711,7 +740,16 @@ class UDFSignal(UDFStep):
 class RowGenerator(UDFStep):
     """Extend dataset with new rows."""

+    udf: "UDFAdapter"
+    catalog: "Catalog"
+    partition_by: PartitionByType | None = None
     is_generator = True
+    # Parameters from Settings
+    cache: bool = False
+    parallel: int | None = None
+    workers: bool | int = False
+    min_task_size: int | None = None
+    batch_size: int | None = None

     def create_udf_table(self, query: Select) -> "Table":
         warehouse = self.catalog.warehouse
@@ -758,18 +796,42 @@ class SQLClause(Step, ABC):

     def parse_cols(
         self,
-        cols: Sequence[Union[Function, ColumnElement]],
+        cols: Sequence[Function | ColumnElement],
     ) -> tuple[ColumnElement, ...]:
         return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)

     @abstractmethod
-    def apply_sql_clause(self, query):
+    def apply_sql_clause(self, query: Any) -> Any:
         pass


+@frozen
+class RegenerateSystemColumns(Step):
+    catalog: "Catalog"
+
+    def hash_inputs(self) -> str:
+        return hashlib.sha256(b"regenerate_system_columns").hexdigest()
+
+    def apply(
+        self, query_generator: QueryGenerator, temp_tables: list[str]
+    ) -> StepResult:
+        query = query_generator.select()
+        new_query = self.catalog.warehouse._regenerate_system_columns(
+            query, keep_existing_columns=True
+        )
+
+        def q(*columns):
+            return new_query.with_only_columns(*columns)
+
+        return step_result(q, new_query.selected_columns)
+
+
 @frozen
 class SQLSelect(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query) -> Select:
         subquery = query.subquery()
@@ -785,7 +847,10 @@ class SQLSelect(SQLClause):

 @frozen
 class SQLSelectExcept(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query: Select) -> Select:
         subquery = query.subquery()
@@ -798,6 +863,9 @@ class SQLMutate(SQLClause):
     args: tuple[Label, ...]
     new_schema: SignalSchema

+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)
+
     def apply_sql_clause(self, query: Select) -> Select:
         original_subquery = query.subquery()
         to_mutate = {c.name for c in self.args}
@@ -825,7 +893,10 @@ class SQLMutate(SQLClause):

 @frozen
 class SQLFilter(SQLClause):
-    expressions: tuple[Union[Function, ColumnElement], ...]
+    expressions: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.expressions)

     def __and__(self, other):
         expressions = self.parse_cols(self.expressions)
@@ -838,7 +909,10 @@ class SQLFilter(SQLClause):

 @frozen
 class SQLOrderBy(SQLClause):
-    args: tuple[Union[Function, ColumnElement], ...]
+    args: tuple[Function | ColumnElement, ...]
+
+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)

     def apply_sql_clause(self, query: Select) -> Select:
         args = self.parse_cols(self.args)
@@ -849,6 +923,9 @@ class SQLOrderBy(SQLClause):
 class SQLLimit(SQLClause):
     n: int

+    def hash_inputs(self) -> str:
+        return hashlib.sha256(str(self.n).encode()).hexdigest()
+
     def apply_sql_clause(self, query: Select) -> Select:
         return query.limit(self.n)

@@ -857,12 +934,18 @@ class SQLLimit(SQLClause):
 class SQLOffset(SQLClause):
     offset: int

+    def hash_inputs(self) -> str:
+        return hashlib.sha256(str(self.offset).encode()).hexdigest()
+
     def apply_sql_clause(self, query: "GenerativeSelect"):
         return query.offset(self.offset)


 @frozen
 class SQLCount(SQLClause):
+    def hash_inputs(self) -> str:
+        return ""
+
     def apply_sql_clause(self, query):
         return sqlalchemy.select(f.count(1)).select_from(query.subquery())

@@ -872,6 +955,9 @@ class SQLDistinct(SQLClause):
     args: tuple[ColumnElement, ...]
     dialect: str

+    def hash_inputs(self) -> str:
+        return hash_column_elements(self.args)
+
     def apply_sql_clause(self, query):
         if self.dialect == "sqlite":
             return query.group_by(*self.args)
@@ -884,24 +970,34 @@ class SQLUnion(Step):
     query1: "DatasetQuery"
     query2: "DatasetQuery"

+    def hash_inputs(self) -> str:
+        return hashlib.sha256(
+            bytes.fromhex(self.query1.hash()) + bytes.fromhex(self.query2.hash())
+        ).hexdigest()
+
     def apply(
         self, query_generator: QueryGenerator, temp_tables: list[str]
     ) -> StepResult:
+        left_before = len(self.query1.temp_table_names)
         q1 = self.query1.apply_steps().select().subquery()
-        temp_tables.extend(self.query1.temp_table_names)
+        temp_tables.extend(self.query1.temp_table_names[left_before:])
+        right_before = len(self.query2.temp_table_names)
         q2 = self.query2.apply_steps().select().subquery()
-        temp_tables.extend(self.query2.temp_table_names)
+        temp_tables.extend(self.query2.temp_table_names[right_before:])

-        columns1
+        columns1 = _drop_system_columns(q1.columns)
+        columns2 = _drop_system_columns(q2.columns)
+        columns1, columns2 = _order_columns(columns1, columns2)

         def q(*columns):
-
-            col1 = [c for c in columns1 if c.name in
-            col2 = [c for c in columns2 if c.name in
-
+            selected_names = [c.name for c in columns]
+            col1 = [c for c in columns1 if c.name in selected_names]
+            col2 = [c for c in columns2 if c.name in selected_names]
+            union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))

-
-
+            union_cte = union_query.cte()
+            select_cols = [union_cte.c[name] for name in selected_names]
+            return sqlalchemy.select(*select_cols)

         return step_result(
             q,
@@ -915,14 +1011,42 @@ class SQLJoin(Step):
     catalog: "Catalog"
     query1: "DatasetQuery"
     query2: "DatasetQuery"
-    predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
+    predicates: JoinPredicateType | tuple[JoinPredicateType, ...]
     inner: bool
     full: bool
     rname: str

+    @staticmethod
+    def _split_db_name(name: str) -> tuple[str, str]:
+        if DEFAULT_DELIMITER in name:
+            head, tail = name.split(DEFAULT_DELIMITER, 1)
+            return head, tail
+        return name, ""
+
+    @classmethod
+    def _root_name(cls, name: str) -> str:
+        return cls._split_db_name(name)[0]
+
+    def hash_inputs(self) -> str:
+        predicates = (
+            ensure_sequence(self.predicates) if self.predicates is not None else []
+        )
+
+        parts = [
+            bytes.fromhex(self.query1.hash()),
+            bytes.fromhex(self.query2.hash()),
+            bytes.fromhex(hash_column_elements(predicates)),
+            str(self.inner).encode(),
+            str(self.full).encode(),
+            self.rname.encode("utf-8"),
+        ]
+
+        return hashlib.sha256(b"".join(parts)).hexdigest()
+
     def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
+        temp_tables_before = len(dq.temp_table_names)
         query = dq.apply_steps().select()
-        temp_tables.extend(dq.temp_table_names)
+        temp_tables.extend(dq.temp_table_names[temp_tables_before:])

         if not any(isinstance(step, (SQLJoin, SQLUnion)) for step in dq.steps):
             return query.subquery(dq.table.name)
@@ -978,22 +1102,39 @@ class SQLJoin(Step):
         q1 = self.get_query(self.query1, temp_tables)
         q2 = self.get_query(self.query2, temp_tables)

-        q1_columns =
-
-
-
-        for
-            if
+        q1_columns = _drop_system_columns(q1.c)
+        existing_column_names = {c.name for c in q1_columns}
+        right_columns: list[KeyedColumnElement[Any]] = []
+        right_column_names: list[str] = []
+        for column in q2.c:
+            if column.name.startswith("sys__"):
                 continue
+            right_columns.append(column)
+            right_column_names.append(column.name)
+
+        root_mapping = generate_merge_root_mapping(
+            existing_column_names,
+            right_column_names,
+            extract_root=self._root_name,
+            prefix=self.rname,
+        )
+
+        q2_columns: list[KeyedColumnElement[Any]] = []
+        for column in right_columns:
+            original_name = column.name
+            column_root, column_tail = self._split_db_name(original_name)
+            mapped_root = root_mapping[column_root]
+
+            new_name = (
+                mapped_root
+                if not column_tail
+                else DEFAULT_DELIMITER.join([mapped_root, column_tail])
+            )
+
+            if new_name != original_name:
+                column = column.label(new_name)

-
-            new_name = self.rname.format(name=c.name)
-            new_name_idx = 0
-            while new_name in q1_column_names:
-                new_name_idx += 1
-                new_name = self.rname.format(name=f"{c.name}_{new_name_idx}")
-            c = c.label(new_name)
-            q2_columns.append(c)
+            q2_columns.append(column)

         res_columns = q1_columns + q2_columns
         predicates = (
@@ -1038,8 +1179,15 @@ class SQLJoin(Step):

 @frozen
 class SQLGroupBy(SQLClause):
-    cols: Sequence[Union[str, Function, ColumnElement]]
-    group_by: Sequence[Union[str, Function, ColumnElement]]
+    cols: Sequence[str | Function | ColumnElement]
+    group_by: Sequence[str | Function | ColumnElement]
+
+    def hash_inputs(self) -> str:
+        return hashlib.sha256(
+            bytes.fromhex(
+                hash_column_elements(self.cols) + hash_column_elements(self.group_by)
+            )
+        ).hexdigest()

     def apply_sql_clause(self, query) -> Select:
         if not self.cols:
@@ -1069,46 +1217,52 @@ class SQLGroupBy(SQLClause):
         return sqlalchemy.select(*unique_cols).select_from(subquery).group_by(*group_by)


-
-
-) -> set[str]:
-    left_names = {c.name for c in left_columns}
-    right_names = {c.name for c in right_columns}
-
-    if left_names == right_names:
-        return left_names
-
-    missing_right = left_names - right_names
-    missing_left = right_names - left_names
-
-    def _prepare_msg_part(missing_columns: set[str], side: str) -> str:
-        return f"{', '.join(sorted(missing_columns))} only present in {side}"
-
-    msg_parts = [
-        _prepare_msg_part(missing_columns, found_side)
-        for missing_columns, found_side in zip(
-            [
-                missing_right,
-                missing_left,
-            ],
-            ["left", "right"],
-        )
-        if missing_columns
-    ]
-    msg = f"Cannot perform union. {'. '.join(msg_parts)}"
+class UnionSchemaMismatchError(ValueError):
+    """Union input columns mismatch."""

-
+    @classmethod
+    def from_column_sets(
+        cls,
+        missing_left: set[str],
+        missing_right: set[str],
+    ) -> "UnionSchemaMismatchError":
+        def _describe(cols: set[str], side: str) -> str:
+            return f"{', '.join(sorted(cols))} only present in {side}"
+
+        parts = []
+        if missing_left:
+            parts.append(_describe(missing_left, "left"))
+        if missing_right:
+            parts.append(_describe(missing_right, "right"))
+
+        return cls(f"Cannot perform union. {'. '.join(parts)}")


 def _order_columns(
     left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
 ) -> list[list[ColumnElement]]:
-
+    left_names = [c.name for c in left_columns]
+    right_names = [c.name for c in right_columns]
+
+    # validate
+    if sorted(left_names) != sorted(right_names):
+        left_names_set = set(left_names)
+        right_names_set = set(right_names)
+        raise UnionSchemaMismatchError.from_column_sets(
+            left_names_set - right_names_set,
+            right_names_set - left_names_set,
+        )
+
+    # Order columns to match left_names order
     column_dicts = [
         {c.name: c for c in columns} for columns in [left_columns, right_columns]
     ]

-    return [[d[n] for n in
+    return [[d[n] for n in left_names] for d in column_dicts]
+
+
+def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
+    return [c for c in columns if not c.name.startswith("sys__")]


 @attrs.define
@@ -1124,40 +1278,42 @@ class DatasetQuery:
     def __init__(
         self,
         name: str,
-        version: Optional[str] = None,
-        project_name: Optional[str] = None,
-        namespace_name: Optional[str] = None,
-        catalog: Optional["Catalog"] = None,
-        session: Optional[Session] = None,
+        version: str | None = None,
+        project_name: str | None = None,
+        namespace_name: str | None = None,
+        catalog: "Catalog | None" = None,
+        session: Session | None = None,
         in_memory: bool = False,
         update: bool = False,
     ) -> None:
         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
-        self._chunk_index: Optional[int] = None
-        self._chunk_total: Optional[int] = None
+        self._chunk_index: int | None = None
+        self._chunk_total: int | None = None
         self.temp_table_names: list[str] = []
         self.dependencies: set[DatasetDependencyType] = set()
         self.table = self.get_table()
-        self.starting_step: Optional[QueryStep] = None
-        self.name: Optional[str] = None
-        self.version: Optional[str] = None
-        self.feature_schema: Optional[dict] = None
-        self.column_types: Optional[dict[str, Any]] = None
+        self.starting_step: QueryStep | None = None
+        self.name: str | None = None
+        self.version: str | None = None
+        self.feature_schema: dict | None = None
+        self.column_types: dict[str, Any] | None = None
         self.before_steps: list[Callable] = []
-        self.listing_fn: Optional[Callable] = None
+        self.listing_fn: Callable | None = None
         self.update = update

-        self.list_ds_name: Optional[str] = None
+        self.list_ds_name: str | None = None

         self.name = name
         self.dialect = self.catalog.warehouse.db.dialect
         if version:
             self.version = version

-
-
+        if namespace_name is None:
+            namespace_name = self.catalog.metastore.default_namespace_name
+        if project_name is None:
+            project_name = self.catalog.metastore.default_project_name

         if is_listing_dataset(name) and not version:
             # not setting query step yet as listing dataset might not exist at
@@ -1194,12 +1350,26 @@ class DatasetQuery:
     def __or__(self, other):
         return self.union(other)

+    def hash(self) -> str:
+        """
+        Calculates hash of this class taking into account hash of starting step
+        and hashes of each following steps. Ordering is important.
+        """
+        hasher = hashlib.sha256()
+        if self.starting_step:
+            hasher.update(self.starting_step.hash().encode("utf-8"))
+        else:
+            assert self.list_ds_name
+            hasher.update(self.list_ds_name.encode("utf-8"))
+
+        for step in self.steps:
+            hasher.update(step.hash().encode("utf-8"))
+
+        return hasher.hexdigest()
+
     @staticmethod
     def get_table() -> "TableClause":
-        table_name = "".join(
-            random.choice(string.ascii_letters)  # noqa: S311
-            for _ in range(16)
-        )
+        table_name = "".join(secrets.choice(string.ascii_letters) for _ in range(16))
         return sqlalchemy.table(table_name)

     @property
@@ -1216,7 +1386,7 @@ class DatasetQuery:
         """
         return self.name is not None and self.version is not None

-    def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
+    def c(self, column: C | str) -> "ColumnClause[Any]":
         col: sqlalchemy.ColumnClause = (
             sqlalchemy.column(column)
             if isinstance(column, str)
@@ -1311,6 +1481,7 @@ class DatasetQuery:
         # This is needed to always use a new connection with all metastore and warehouse
         # implementations, as errors may close or render unusable the existing
         # connections.
+        assert len(self.temp_table_names) == len(set(self.temp_table_names))
         with self.catalog.metastore.clone(use_new_connection=True) as metastore:
             metastore.cleanup_tables(self.temp_table_names)
         with self.catalog.warehouse.clone(use_new_connection=True) as warehouse:
@@ -1325,7 +1496,7 @@ class DatasetQuery:
         return list(result)

     def to_db_records(self) -> list[dict[str, Any]]:
-        return self.db_results(lambda cols, row: dict(zip(cols, row)))
+        return self.db_results(lambda cols, row: dict(zip(cols, row, strict=False)))

     @contextlib.contextmanager
     def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
@@ -1364,7 +1535,7 @@ class DatasetQuery:
             yield from rows

         async def get_params(row: Sequence) -> tuple:
-            row_dict = RowDict(zip(query_fields, row))
+            row_dict = RowDict(zip(query_fields, row, strict=False))
             return tuple(  # noqa: C409
                 [
                     await p.get_value_async(
@@ -1381,10 +1552,6 @@ class DatasetQuery:
         finally:
             self.cleanup()

-    def shuffle(self) -> "Self":
-        # ToDo: implement shaffle based on seed and/or generating random column
-        return self.order_by(C.sys__rand)
-
     def sample(self, n) -> "Self":
         """
         Return a random sample from the dataset.
@@ -1404,6 +1571,7 @@ class DatasetQuery:
         obj.steps = obj.steps.copy()
         if new_table:
             obj.table = self.get_table()
+            obj.temp_table_names = []
         return obj

     @detach
@@ -1584,10 +1752,10 @@ class DatasetQuery:
     def join(
         self,
         dataset_query: "DatasetQuery",
-        predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
+        predicates: JoinPredicateType | Sequence[JoinPredicateType],
         inner=False,
         full=False,
-        rname="
+        rname="right_",
     ) -> "Self":
         left = self.clone(new_table=False)
         if self.table.name == dataset_query.table.name:
@@ -1626,12 +1794,17 @@ class DatasetQuery:
     def add_signals(
         self,
         udf: "UDFAdapter",
-
-
-        min_task_size: Optional[int] = None,
-        partition_by: Optional[PartitionByType] = None,
+        partition_by: PartitionByType | None = None,
+        # Parameters from Settings
         cache: bool = False,
-
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
+        # Parameters are unused, kept only to match the signature of Settings.to_dict
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         """
         Adds one or more signals based on the results from the provided UDF.
@@ -1657,7 +1830,7 @@ class DatasetQuery:
                 workers=workers,
                 min_task_size=min_task_size,
                 cache=cache,
-
+                batch_size=batch_size,
             )
         )
         return query
@@ -1672,14 +1845,17 @@ class DatasetQuery:
     def generate(
         self,
         udf: "UDFAdapter",
-
-
-        min_task_size: Optional[int] = None,
-        partition_by: Optional[PartitionByType] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
+        partition_by: PartitionByType | None = None,
+        # Parameters from Settings
         cache: bool = False,
-
+        parallel: int | None = None,
+        workers: bool | int = False,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
+        # Parameters are unused, kept only to match the signature of Settings.to_dict:
+        prefetch: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
     ) -> "Self":
         query = self.clone()
         steps = query.steps
@@ -1692,7 +1868,7 @@ class DatasetQuery:
                 workers=workers,
                 min_task_size=min_task_size,
                 cache=cache,
-
+                batch_size=batch_size,
             )
         )
         return query
@@ -1735,26 +1911,30 @@ class DatasetQuery:

     def exec(self) -> "Self":
         """Execute the query."""
+        query = self.clone()
         try:
-            query = self.clone()
             query.apply_steps()
         finally:
-
+            query.cleanup()
         return query

     def save(
         self,
-        name: Optional[str] = None,
-        version: Optional[str] = None,
-        project: Optional[Project] = None,
-        feature_schema: Optional[dict] = None,
-        dependencies: Optional[list[DatasetDependency]] = None,
-        description: Optional[str] = None,
-        attrs: Optional[list[str]] = None,
-        update_version: Optional[str] = "patch",
+        name: str | None = None,
+        version: str | None = None,
+        project: Project | None = None,
+        feature_schema: dict | None = None,
+        dependencies: list[DatasetDependency] | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
         **kwargs,
     ) -> "Self":
         """Save the query as a dataset."""
+        # Get job from session to link dataset version to job
+        job = self.session.get_or_create_job()
+        job_id = job.id
+
         project = project or self.catalog.metastore.default_project
         try:
             if (
@@ -1797,14 +1977,11 @@ class DatasetQuery:
                 description=description,
                 attrs=attrs,
                 update_version=update_version,
+                job_id=job_id,
                 **kwargs,
             )
             version = version or dataset.latest_version

-            self.session.add_dataset_version(
-                dataset=dataset, version=version, listing=kwargs.get("listing", False)
-            )
-
             dr = self.catalog.warehouse.dataset_rows(dataset)

             self.catalog.warehouse.copy_table(dr.get_table(), query.select())
@@ -1814,6 +1991,11 @@ class DatasetQuery:
             )
             self.catalog.update_dataset_version_with_warehouse_info(dataset, version)

+            # Link this dataset version to the job that created it
+            self.catalog.metastore.link_dataset_version_to_job(
+                dataset.get_version(version).id, job_id, is_creator=True
+            )
+
             if dependencies:
                 # overriding dependencies
                 self.dependencies = set()
@@ -1845,5 +2027,5 @@ class DatasetQuery:
         return isinstance(self.last_step, SQLOrderBy)

     @property
-    def last_step(self) -> Optional[Step]:
+    def last_step(self) -> Step | None:
         return self.steps[-1] if self.steps else None