datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (119)
  1. datachain/__init__.py +4 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +5 -5
  4. datachain/catalog/__init__.py +0 -2
  5. datachain/catalog/catalog.py +276 -354
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +8 -3
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +10 -17
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +42 -27
  12. datachain/cli/commands/ls.py +15 -15
  13. datachain/cli/commands/show.py +2 -2
  14. datachain/cli/parser/__init__.py +3 -43
  15. datachain/cli/parser/job.py +1 -1
  16. datachain/cli/parser/utils.py +1 -2
  17. datachain/cli/utils.py +2 -15
  18. datachain/client/azure.py +2 -2
  19. datachain/client/fsspec.py +34 -23
  20. datachain/client/gcs.py +3 -3
  21. datachain/client/http.py +157 -0
  22. datachain/client/local.py +11 -7
  23. datachain/client/s3.py +3 -3
  24. datachain/config.py +4 -8
  25. datachain/data_storage/db_engine.py +12 -6
  26. datachain/data_storage/job.py +2 -0
  27. datachain/data_storage/metastore.py +716 -137
  28. datachain/data_storage/schema.py +20 -27
  29. datachain/data_storage/serializer.py +105 -15
  30. datachain/data_storage/sqlite.py +114 -114
  31. datachain/data_storage/warehouse.py +140 -48
  32. datachain/dataset.py +109 -89
  33. datachain/delta.py +117 -42
  34. datachain/diff/__init__.py +25 -33
  35. datachain/error.py +24 -0
  36. datachain/func/aggregate.py +9 -11
  37. datachain/func/array.py +12 -12
  38. datachain/func/base.py +7 -4
  39. datachain/func/conditional.py +9 -13
  40. datachain/func/func.py +63 -45
  41. datachain/func/numeric.py +5 -7
  42. datachain/func/string.py +2 -2
  43. datachain/hash_utils.py +123 -0
  44. datachain/job.py +11 -7
  45. datachain/json.py +138 -0
  46. datachain/lib/arrow.py +18 -15
  47. datachain/lib/audio.py +60 -59
  48. datachain/lib/clip.py +14 -13
  49. datachain/lib/convert/python_to_sql.py +6 -10
  50. datachain/lib/convert/values_to_tuples.py +151 -53
  51. datachain/lib/data_model.py +23 -19
  52. datachain/lib/dataset_info.py +7 -7
  53. datachain/lib/dc/__init__.py +2 -1
  54. datachain/lib/dc/csv.py +22 -26
  55. datachain/lib/dc/database.py +37 -34
  56. datachain/lib/dc/datachain.py +518 -324
  57. datachain/lib/dc/datasets.py +38 -30
  58. datachain/lib/dc/hf.py +16 -20
  59. datachain/lib/dc/json.py +17 -18
  60. datachain/lib/dc/listings.py +5 -8
  61. datachain/lib/dc/pandas.py +3 -6
  62. datachain/lib/dc/parquet.py +33 -21
  63. datachain/lib/dc/records.py +9 -13
  64. datachain/lib/dc/storage.py +103 -65
  65. datachain/lib/dc/storage_pattern.py +251 -0
  66. datachain/lib/dc/utils.py +17 -14
  67. datachain/lib/dc/values.py +3 -6
  68. datachain/lib/file.py +187 -50
  69. datachain/lib/hf.py +7 -5
  70. datachain/lib/image.py +13 -13
  71. datachain/lib/listing.py +5 -5
  72. datachain/lib/listing_info.py +1 -2
  73. datachain/lib/meta_formats.py +2 -3
  74. datachain/lib/model_store.py +20 -8
  75. datachain/lib/namespaces.py +59 -7
  76. datachain/lib/projects.py +51 -9
  77. datachain/lib/pytorch.py +31 -23
  78. datachain/lib/settings.py +188 -85
  79. datachain/lib/signal_schema.py +302 -64
  80. datachain/lib/text.py +8 -7
  81. datachain/lib/udf.py +103 -63
  82. datachain/lib/udf_signature.py +59 -34
  83. datachain/lib/utils.py +20 -0
  84. datachain/lib/video.py +3 -4
  85. datachain/lib/webdataset.py +31 -36
  86. datachain/lib/webdataset_laion.py +15 -16
  87. datachain/listing.py +12 -5
  88. datachain/model/bbox.py +3 -1
  89. datachain/namespace.py +22 -3
  90. datachain/node.py +6 -6
  91. datachain/nodes_thread_pool.py +0 -1
  92. datachain/plugins.py +24 -0
  93. datachain/project.py +4 -4
  94. datachain/query/batch.py +10 -12
  95. datachain/query/dataset.py +376 -194
  96. datachain/query/dispatch.py +112 -84
  97. datachain/query/metrics.py +3 -4
  98. datachain/query/params.py +2 -3
  99. datachain/query/queue.py +2 -1
  100. datachain/query/schema.py +7 -6
  101. datachain/query/session.py +190 -33
  102. datachain/query/udf.py +9 -6
  103. datachain/remote/studio.py +90 -53
  104. datachain/script_meta.py +12 -12
  105. datachain/sql/sqlite/base.py +37 -25
  106. datachain/sql/sqlite/types.py +1 -1
  107. datachain/sql/types.py +36 -5
  108. datachain/studio.py +49 -40
  109. datachain/toolkit/split.py +31 -10
  110. datachain/utils.py +39 -48
  111. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
  112. datachain-0.39.0.dist-info/RECORD +173 -0
  113. datachain/cli/commands/query.py +0 -54
  114. datachain/query/utils.py +0 -36
  115. datachain-0.30.5.dist-info/RECORD +0 -168
  116. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
  117. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  118. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  119. {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
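
The remainder appears to be the rendered diff for datachain/query/dataset.py, the largest single file change. Its most prominent addition is a hash_inputs()/hash() pair on each query Step, combined by a new DatasetQuery.hash(); together with the new datachain/checkpoint.py and datachain/hash_utils.py modules listed above, this looks like content-addressing of the query plan for checkpoint reuse. Below is a minimal sketch of that chaining pattern; the class names, dataset URI, and Limit example are simplified stand-ins, not the library's actual API.

    import hashlib


    class Step:
        """Simplified stand-in for a query step; only the hashing pattern is shown."""

        def hash_inputs(self) -> str:
            raise NotImplementedError

        def hash(self) -> str:
            # A step hash combines the step type with a digest of its inputs,
            # mirroring the Step.hash() added in the diff below.
            payload = f"{self.__class__.__name__}|{self.hash_inputs()}".encode()
            return hashlib.sha256(payload).hexdigest()


    class Limit(Step):
        def __init__(self, n: int) -> None:
            self.n = n

        def hash_inputs(self) -> str:
            return hashlib.sha256(str(self.n).encode()).hexdigest()


    def query_hash(starting_dataset_uri: str, steps: list[Step]) -> str:
        # Mirrors the shape of DatasetQuery.hash(): start from the source dataset
        # URI, then fold in each step's hash in order (ordering matters).
        hasher = hashlib.sha256()
        hasher.update(hashlib.sha256(starting_dataset_uri.encode()).hexdigest().encode())
        for step in steps:
            hasher.update(step.hash().encode())
        return hasher.hexdigest()


    print(query_hash("ds://animals@v1", [Limit(10)]))
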
@@ -1,25 +1,18 @@
1
1
  import contextlib
2
+ import hashlib
2
3
  import inspect
3
4
  import logging
4
5
  import os
5
- import random
6
+ import secrets
6
7
  import string
7
8
  import subprocess
8
9
  import sys
9
10
  from abc import ABC, abstractmethod
10
- from collections.abc import Generator, Iterable, Iterator, Sequence
11
+ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
11
12
  from copy import copy
12
13
  from functools import wraps
13
14
  from types import GeneratorType
14
- from typing import (
15
- TYPE_CHECKING,
16
- Any,
17
- Callable,
18
- Optional,
19
- Protocol,
20
- TypeVar,
21
- Union,
22
- )
15
+ from typing import TYPE_CHECKING, Any, Protocol, TypeVar
23
16
 
24
17
  import attrs
25
18
  import sqlalchemy
@@ -44,20 +37,21 @@ from datachain.data_storage.schema import (
44
37
  from datachain.dataset import DatasetDependency, DatasetStatus, RowDict
45
38
  from datachain.error import DatasetNotFoundError, QueryScriptCancelError
46
39
  from datachain.func.base import Function
40
+ from datachain.hash_utils import hash_column_elements
47
41
  from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
48
- from datachain.lib.signal_schema import SignalSchema
42
+ from datachain.lib.signal_schema import SignalSchema, generate_merge_root_mapping
49
43
  from datachain.lib.udf import UDFAdapter, _get_cache
50
44
  from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
51
45
  from datachain.project import Project
52
- from datachain.query.schema import C, UDFParamSpec, normalize_param
46
+ from datachain.query.schema import DEFAULT_DELIMITER, C, UDFParamSpec, normalize_param
53
47
  from datachain.query.session import Session
54
48
  from datachain.query.udf import UdfInfo
55
49
  from datachain.sql.functions.random import rand
56
50
  from datachain.sql.types import SQLType
57
51
  from datachain.utils import (
58
- batched,
59
52
  determine_processes,
60
53
  determine_workers,
54
+ ensure_sequence,
61
55
  filtered_cloudpickle_dumps,
62
56
  get_datachain_executable,
63
57
  safe_closing,
@@ -65,11 +59,12 @@ from datachain.utils import (
65
59
 
66
60
  if TYPE_CHECKING:
67
61
  from collections.abc import Mapping
62
+ from typing import Concatenate
68
63
 
69
- from sqlalchemy.sql.elements import ClauseElement
64
+ from sqlalchemy.sql.elements import ClauseElement, KeyedColumnElement
70
65
  from sqlalchemy.sql.schema import Table
71
66
  from sqlalchemy.sql.selectable import GenerativeSelect
72
- from typing_extensions import Concatenate, ParamSpec, Self
67
+ from typing_extensions import ParamSpec, Self
73
68
 
74
69
  from datachain.catalog import Catalog
75
70
  from datachain.data_storage import AbstractWarehouse
@@ -81,13 +76,10 @@ if TYPE_CHECKING:
81
76
 
82
77
  INSERT_BATCH_SIZE = 10000
83
78
 
84
- PartitionByType = Union[
85
- str,
86
- Function,
87
- ColumnElement,
88
- Sequence[Union[str, Function, ColumnElement]],
89
- ]
90
- JoinPredicateType = Union[str, ColumnClause, ColumnElement]
79
+ PartitionByType = (
80
+ str | Function | ColumnElement | Sequence[str | Function | ColumnElement]
81
+ )
82
+ JoinPredicateType = str | ColumnClause | ColumnElement
91
83
  DatasetDependencyType = tuple["DatasetRecord", str]
92
84
 
93
85
  logger = logging.getLogger("datachain")
@@ -168,6 +160,18 @@ class Step(ABC):
168
160
  ) -> "StepResult":
169
161
  """Apply the processing step."""
170
162
 
163
+ @abstractmethod
164
+ def hash_inputs(self) -> str:
165
+ """Calculates hash of step inputs"""
166
+
167
+ def hash(self) -> str:
168
+ """
169
+ Calculates hash for step which includes step name and hash of it's inputs
170
+ """
171
+ return hashlib.sha256(
172
+ f"{self.__class__.__name__}|{self.hash_inputs()}".encode()
173
+ ).hexdigest()
174
+
171
175
 
172
176
  @frozen
173
177
  class QueryStep:
@@ -187,6 +191,11 @@ class QueryStep:
187
191
  q, dr.columns, dependencies=[(self.dataset, self.dataset_version)]
188
192
  )
189
193
 
194
+ def hash(self) -> str:
195
+ return hashlib.sha256(
196
+ self.dataset.uri(self.dataset_version).encode()
197
+ ).hexdigest()
198
+
190
199
 
191
200
  def generator_then_call(generator, func: Callable):
192
201
  """
@@ -222,8 +231,9 @@ class DatasetDiffOperation(Step):
222
231
 
223
232
  def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
224
233
  source_query = query_generator.exclude(("sys__id",))
234
+ right_before = len(self.dq.temp_table_names)
225
235
  target_query = self.dq.apply_steps().select()
226
- temp_tables.extend(self.dq.temp_table_names)
236
+ temp_tables.extend(self.dq.temp_table_names[right_before:])
227
237
 
228
238
  # creating temp table that will hold subtract results
229
239
  temp_table_name = self.catalog.warehouse.temp_table_name()
@@ -257,6 +267,13 @@ class DatasetDiffOperation(Step):
257
267
  class Subtract(DatasetDiffOperation):
258
268
  on: Sequence[tuple[str, str]]
259
269
 
270
+ def hash_inputs(self) -> str:
271
+ on_bytes = b"".join(
272
+ f"{a}:{b}".encode() for a, b in sorted(self.on, key=lambda t: (t[0], t[1]))
273
+ )
274
+
275
+ return hashlib.sha256(bytes.fromhex(self.dq.hash()) + on_bytes).hexdigest()
276
+
260
277
  def query(self, source_query: Select, target_query: Select) -> sa.Selectable:
261
278
  sq = source_query.alias("source_query")
262
279
  tq = target_query.alias("target_query")
@@ -334,10 +351,10 @@ def process_udf_outputs(
334
351
  udf_results: Iterator[Iterable["UDFResult"]],
335
352
  udf: "UDFAdapter",
336
353
  cb: Callback = DEFAULT_CALLBACK,
354
+ batch_size: int = INSERT_BATCH_SIZE,
337
355
  ) -> None:
338
356
  # Optimization: Compute row types once, rather than for every row.
339
357
  udf_col_types = get_col_types(warehouse, udf.output)
340
- batch_rows = udf.batch_rows or INSERT_BATCH_SIZE
341
358
 
342
359
  def _insert_rows():
343
360
  for udf_output in udf_results:
@@ -349,9 +366,7 @@ def process_udf_outputs(
349
366
  cb.relative_update()
350
367
  yield adjust_outputs(warehouse, row, udf_col_types)
351
368
 
352
- for row_chunk in batched(_insert_rows(), batch_rows):
353
- warehouse.insert_rows(udf_table, row_chunk)
354
-
369
+ warehouse.insert_rows(udf_table, _insert_rows(), batch_size=batch_size)
355
370
  warehouse.insert_rows_done(udf_table)
356
371
 
357
372
 
@@ -387,21 +402,34 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
387
402
  class UDFStep(Step, ABC):
388
403
  udf: "UDFAdapter"
389
404
  catalog: "Catalog"
390
- partition_by: Optional[PartitionByType] = None
391
- parallel: Optional[int] = None
392
- workers: Union[bool, int] = False
393
- min_task_size: Optional[int] = None
405
+ partition_by: PartitionByType | None = None
394
406
  is_generator = False
407
+ # Parameters from Settings
395
408
  cache: bool = False
396
- batch_rows: Optional[int] = None
409
+ parallel: int | None = None
410
+ workers: bool | int = False
411
+ min_task_size: int | None = None
412
+ batch_size: int | None = None
413
+
414
+ def hash_inputs(self) -> str:
415
+ partition_by = ensure_sequence(self.partition_by or [])
416
+ parts = [
417
+ bytes.fromhex(self.udf.hash()),
418
+ bytes.fromhex(hash_column_elements(partition_by)),
419
+ str(self.is_generator).encode(),
420
+ ]
421
+
422
+ return hashlib.sha256(b"".join(parts)).hexdigest()
397
423
 
398
424
  @abstractmethod
399
425
  def create_udf_table(self, query: Select) -> "Table":
400
426
  """Method that creates a table where temp udf results will be saved"""
401
427
 
402
428
  def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
403
- """Apply any necessary processing to the input query"""
404
- return query, []
429
+ """Materialize inputs, ensure sys columns are available, needed for checkpoints,
430
+ needed for map to work (merge results)"""
431
+ table = self.catalog.warehouse.create_pre_udf_table(query)
432
+ return sqlalchemy.select(*table.c), [table]
405
433
 
406
434
  @abstractmethod
407
435
  def create_result_query(
@@ -450,6 +478,7 @@ class UDFStep(Step, ABC):
450
478
  use_cache=self.cache,
451
479
  is_generator=self.is_generator,
452
480
  min_task_size=self.min_task_size,
481
+ batch_size=self.batch_size,
453
482
  )
454
483
  udf_distributor()
455
484
  return
@@ -486,6 +515,7 @@ class UDFStep(Step, ABC):
486
515
  is_generator=self.is_generator,
487
516
  cache=self.cache,
488
517
  rows_total=rows_total,
518
+ batch_size=self.batch_size or INSERT_BATCH_SIZE,
489
519
  )
490
520
 
491
521
  # Run the UDFDispatcher in another process to avoid needing
@@ -534,6 +564,7 @@ class UDFStep(Step, ABC):
534
564
  udf_results,
535
565
  self.udf,
536
566
  cb=generated_cb,
567
+ batch_size=self.batch_size or INSERT_BATCH_SIZE,
537
568
  )
538
569
  finally:
539
570
  download_cb.close()
@@ -552,13 +583,10 @@ class UDFStep(Step, ABC):
552
583
  """
553
584
  Create temporary table with group by partitions.
554
585
  """
555
- # Check if partition_by is set, we need it to create partitions.
556
- assert self.partition_by is not None
557
- # Check if sys__id is in the query, we need it to be able to join
558
- # the partition table with the udf table later.
559
- assert any(c.name == "sys__id" for c in query.selected_columns), (
560
- "Query must have sys__id column to use partitioning."
561
- )
586
+ if self.partition_by is None:
587
+ raise RuntimeError("Query must have partition_by set to use partitioning")
588
+ if (id_col := query.selected_columns.get("sys__id")) is None:
589
+ raise RuntimeError("Query must have sys__id column to use partitioning")
562
590
 
563
591
  if isinstance(self.partition_by, (list, tuple, GeneratorType)):
564
592
  list_partition_by = list(self.partition_by)
@@ -574,7 +602,7 @@ class UDFStep(Step, ABC):
574
602
 
575
603
  # fill table with partitions
576
604
  cols = [
577
- query.selected_columns.sys__id,
605
+ id_col,
578
606
  f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID),
579
607
  ]
580
608
  self.catalog.warehouse.db.execute(
@@ -586,7 +614,7 @@ class UDFStep(Step, ABC):
586
614
 
587
615
  return tbl
588
616
 
589
- def clone(self, partition_by: Optional[PartitionByType] = None) -> "Self":
617
+ def clone(self, partition_by: PartitionByType | None = None) -> "Self":
590
618
  if partition_by is not None:
591
619
  return self.__class__(
592
620
  self.udf,
@@ -595,41 +623,25 @@ class UDFStep(Step, ABC):
595
623
  parallel=self.parallel,
596
624
  workers=self.workers,
597
625
  min_task_size=self.min_task_size,
598
- batch_rows=self.batch_rows,
626
+ batch_size=self.batch_size,
599
627
  )
600
628
  return self.__class__(self.udf, self.catalog)
601
629
 
602
630
  def apply(
603
631
  self, query_generator: QueryGenerator, temp_tables: list[str]
604
632
  ) -> "StepResult":
605
- _query = query = query_generator.select()
633
+ query, tables = self.process_input_query(query_generator.select())
634
+ _query = query
606
635
 
607
636
  # Apply partitioning if needed.
608
637
  if self.partition_by is not None:
609
- if not any(c.name == "sys__id" for c in query.selected_columns):
610
- # If sys__id is not in the query, we need to create a temp table
611
- # to hold the query results, so we can join it with the
612
- # partition table later.
613
- columns = [
614
- c if isinstance(c, Column) else Column(c.name, c.type)
615
- for c in query.subquery().columns
616
- ]
617
- temp_table = self.catalog.warehouse.create_dataset_rows_table(
618
- self.catalog.warehouse.temp_table_name(),
619
- columns=columns,
620
- )
621
- temp_tables.append(temp_table.name)
622
- self.catalog.warehouse.copy_table(temp_table, query)
623
- _query = query = temp_table.select()
624
-
625
638
  partition_tbl = self.create_partitions_table(query)
626
- temp_tables.append(partition_tbl.name)
627
639
  query = query.outerjoin(
628
640
  partition_tbl,
629
641
  partition_tbl.c.sys__id == query.selected_columns.sys__id,
630
642
  ).add_columns(*partition_columns())
643
+ tables = [*tables, partition_tbl]
631
644
 
632
- query, tables = self.process_input_query(query)
633
645
  temp_tables.extend(t.name for t in tables)
634
646
  udf_table = self.create_udf_table(_query)
635
647
  temp_tables.append(udf_table.name)
@@ -641,7 +653,16 @@ class UDFStep(Step, ABC):
641
653
 
642
654
  @frozen
643
655
  class UDFSignal(UDFStep):
656
+ udf: "UDFAdapter"
657
+ catalog: "Catalog"
658
+ partition_by: PartitionByType | None = None
644
659
  is_generator = False
660
+ # Parameters from Settings
661
+ cache: bool = False
662
+ parallel: int | None = None
663
+ workers: bool | int = False
664
+ min_task_size: int | None = None
665
+ batch_size: int | None = None
645
666
 
646
667
  def create_udf_table(self, query: Select) -> "Table":
647
668
  udf_output_columns: list[sqlalchemy.Column[Any]] = [
@@ -651,13 +672,6 @@ class UDFSignal(UDFStep):
651
672
 
652
673
  return self.catalog.warehouse.create_udf_table(udf_output_columns)
653
674
 
654
- def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
655
- if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
656
- return query, []
657
- table = self.catalog.warehouse.create_pre_udf_table(query)
658
- q: Select = sqlalchemy.select(*table.c)
659
- return q, [table]
660
-
661
675
  def create_result_query(
662
676
  self, udf_table, query
663
677
  ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
@@ -669,11 +683,26 @@ class UDFSignal(UDFStep):
669
683
  signal_name_cols = {c.name: c for c in signal_cols}
670
684
  cols = signal_cols
671
685
 
672
- overlap = {c.name for c in original_cols} & {c.name for c in cols}
686
+ original_names = {c.name for c in original_cols}
687
+ new_names = {c.name for c in cols}
688
+
689
+ overlap = original_names & new_names
673
690
  if overlap:
674
691
  raise ValueError(
675
692
  "Column already exists or added in the previous steps: "
676
- + ", ".join(overlap)
693
+ + ", ".join(sorted(overlap))
694
+ )
695
+
696
+ def _root(name: str) -> str:
697
+ return name.split(DEFAULT_DELIMITER, 1)[0]
698
+
699
+ existing_roots = {_root(name) for name in original_names}
700
+ new_roots = {_root(name) for name in new_names}
701
+ root_conflicts = existing_roots & new_roots
702
+ if root_conflicts:
703
+ raise ValueError(
704
+ "Signals already exist in the previous steps: "
705
+ + ", ".join(sorted(root_conflicts))
677
706
  )
678
707
 
679
708
  def q(*columns):
@@ -711,7 +740,16 @@ class UDFSignal(UDFStep):
711
740
  class RowGenerator(UDFStep):
712
741
  """Extend dataset with new rows."""
713
742
 
743
+ udf: "UDFAdapter"
744
+ catalog: "Catalog"
745
+ partition_by: PartitionByType | None = None
714
746
  is_generator = True
747
+ # Parameters from Settings
748
+ cache: bool = False
749
+ parallel: int | None = None
750
+ workers: bool | int = False
751
+ min_task_size: int | None = None
752
+ batch_size: int | None = None
715
753
 
716
754
  def create_udf_table(self, query: Select) -> "Table":
717
755
  warehouse = self.catalog.warehouse
@@ -758,18 +796,42 @@ class SQLClause(Step, ABC):
758
796
 
759
797
  def parse_cols(
760
798
  self,
761
- cols: Sequence[Union[Function, ColumnElement]],
799
+ cols: Sequence[Function | ColumnElement],
762
800
  ) -> tuple[ColumnElement, ...]:
763
801
  return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)
764
802
 
765
803
  @abstractmethod
766
- def apply_sql_clause(self, query):
804
+ def apply_sql_clause(self, query: Any) -> Any:
767
805
  pass
768
806
 
769
807
 
808
+ @frozen
809
+ class RegenerateSystemColumns(Step):
810
+ catalog: "Catalog"
811
+
812
+ def hash_inputs(self) -> str:
813
+ return hashlib.sha256(b"regenerate_system_columns").hexdigest()
814
+
815
+ def apply(
816
+ self, query_generator: QueryGenerator, temp_tables: list[str]
817
+ ) -> StepResult:
818
+ query = query_generator.select()
819
+ new_query = self.catalog.warehouse._regenerate_system_columns(
820
+ query, keep_existing_columns=True
821
+ )
822
+
823
+ def q(*columns):
824
+ return new_query.with_only_columns(*columns)
825
+
826
+ return step_result(q, new_query.selected_columns)
827
+
828
+
770
829
  @frozen
771
830
  class SQLSelect(SQLClause):
772
- args: tuple[Union[Function, ColumnElement], ...]
831
+ args: tuple[Function | ColumnElement, ...]
832
+
833
+ def hash_inputs(self) -> str:
834
+ return hash_column_elements(self.args)
773
835
 
774
836
  def apply_sql_clause(self, query) -> Select:
775
837
  subquery = query.subquery()
@@ -785,7 +847,10 @@ class SQLSelect(SQLClause):
785
847
 
786
848
  @frozen
787
849
  class SQLSelectExcept(SQLClause):
788
- args: tuple[Union[Function, ColumnElement], ...]
850
+ args: tuple[Function | ColumnElement, ...]
851
+
852
+ def hash_inputs(self) -> str:
853
+ return hash_column_elements(self.args)
789
854
 
790
855
  def apply_sql_clause(self, query: Select) -> Select:
791
856
  subquery = query.subquery()
@@ -798,6 +863,9 @@ class SQLMutate(SQLClause):
798
863
  args: tuple[Label, ...]
799
864
  new_schema: SignalSchema
800
865
 
866
+ def hash_inputs(self) -> str:
867
+ return hash_column_elements(self.args)
868
+
801
869
  def apply_sql_clause(self, query: Select) -> Select:
802
870
  original_subquery = query.subquery()
803
871
  to_mutate = {c.name for c in self.args}
@@ -825,7 +893,10 @@ class SQLMutate(SQLClause):
825
893
 
826
894
  @frozen
827
895
  class SQLFilter(SQLClause):
828
- expressions: tuple[Union[Function, ColumnElement], ...]
896
+ expressions: tuple[Function | ColumnElement, ...]
897
+
898
+ def hash_inputs(self) -> str:
899
+ return hash_column_elements(self.expressions)
829
900
 
830
901
  def __and__(self, other):
831
902
  expressions = self.parse_cols(self.expressions)
@@ -838,7 +909,10 @@ class SQLFilter(SQLClause):
838
909
 
839
910
  @frozen
840
911
  class SQLOrderBy(SQLClause):
841
- args: tuple[Union[Function, ColumnElement], ...]
912
+ args: tuple[Function | ColumnElement, ...]
913
+
914
+ def hash_inputs(self) -> str:
915
+ return hash_column_elements(self.args)
842
916
 
843
917
  def apply_sql_clause(self, query: Select) -> Select:
844
918
  args = self.parse_cols(self.args)
@@ -849,6 +923,9 @@ class SQLOrderBy(SQLClause):
849
923
  class SQLLimit(SQLClause):
850
924
  n: int
851
925
 
926
+ def hash_inputs(self) -> str:
927
+ return hashlib.sha256(str(self.n).encode()).hexdigest()
928
+
852
929
  def apply_sql_clause(self, query: Select) -> Select:
853
930
  return query.limit(self.n)
854
931
 
@@ -857,12 +934,18 @@ class SQLLimit(SQLClause):
857
934
  class SQLOffset(SQLClause):
858
935
  offset: int
859
936
 
937
+ def hash_inputs(self) -> str:
938
+ return hashlib.sha256(str(self.offset).encode()).hexdigest()
939
+
860
940
  def apply_sql_clause(self, query: "GenerativeSelect"):
861
941
  return query.offset(self.offset)
862
942
 
863
943
 
864
944
  @frozen
865
945
  class SQLCount(SQLClause):
946
+ def hash_inputs(self) -> str:
947
+ return ""
948
+
866
949
  def apply_sql_clause(self, query):
867
950
  return sqlalchemy.select(f.count(1)).select_from(query.subquery())
868
951
 
@@ -872,6 +955,9 @@ class SQLDistinct(SQLClause):
872
955
  args: tuple[ColumnElement, ...]
873
956
  dialect: str
874
957
 
958
+ def hash_inputs(self) -> str:
959
+ return hash_column_elements(self.args)
960
+
875
961
  def apply_sql_clause(self, query):
876
962
  if self.dialect == "sqlite":
877
963
  return query.group_by(*self.args)
@@ -884,24 +970,34 @@ class SQLUnion(Step):
884
970
  query1: "DatasetQuery"
885
971
  query2: "DatasetQuery"
886
972
 
973
+ def hash_inputs(self) -> str:
974
+ return hashlib.sha256(
975
+ bytes.fromhex(self.query1.hash()) + bytes.fromhex(self.query2.hash())
976
+ ).hexdigest()
977
+
887
978
  def apply(
888
979
  self, query_generator: QueryGenerator, temp_tables: list[str]
889
980
  ) -> StepResult:
981
+ left_before = len(self.query1.temp_table_names)
890
982
  q1 = self.query1.apply_steps().select().subquery()
891
- temp_tables.extend(self.query1.temp_table_names)
983
+ temp_tables.extend(self.query1.temp_table_names[left_before:])
984
+ right_before = len(self.query2.temp_table_names)
892
985
  q2 = self.query2.apply_steps().select().subquery()
893
- temp_tables.extend(self.query2.temp_table_names)
986
+ temp_tables.extend(self.query2.temp_table_names[right_before:])
894
987
 
895
- columns1, columns2 = _order_columns(q1.columns, q2.columns)
988
+ columns1 = _drop_system_columns(q1.columns)
989
+ columns2 = _drop_system_columns(q2.columns)
990
+ columns1, columns2 = _order_columns(columns1, columns2)
896
991
 
897
992
  def q(*columns):
898
- names = {c.name for c in columns}
899
- col1 = [c for c in columns1 if c.name in names]
900
- col2 = [c for c in columns2 if c.name in names]
901
- res = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))
993
+ selected_names = [c.name for c in columns]
994
+ col1 = [c for c in columns1 if c.name in selected_names]
995
+ col2 = [c for c in columns2 if c.name in selected_names]
996
+ union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))
902
997
 
903
- subquery = res.subquery()
904
- return sqlalchemy.select(*subquery.c).select_from(subquery)
998
+ union_cte = union_query.cte()
999
+ select_cols = [union_cte.c[name] for name in selected_names]
1000
+ return sqlalchemy.select(*select_cols)
905
1001
 
906
1002
  return step_result(
907
1003
  q,
@@ -915,14 +1011,42 @@ class SQLJoin(Step):
915
1011
  catalog: "Catalog"
916
1012
  query1: "DatasetQuery"
917
1013
  query2: "DatasetQuery"
918
- predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
1014
+ predicates: JoinPredicateType | tuple[JoinPredicateType, ...]
919
1015
  inner: bool
920
1016
  full: bool
921
1017
  rname: str
922
1018
 
1019
+ @staticmethod
1020
+ def _split_db_name(name: str) -> tuple[str, str]:
1021
+ if DEFAULT_DELIMITER in name:
1022
+ head, tail = name.split(DEFAULT_DELIMITER, 1)
1023
+ return head, tail
1024
+ return name, ""
1025
+
1026
+ @classmethod
1027
+ def _root_name(cls, name: str) -> str:
1028
+ return cls._split_db_name(name)[0]
1029
+
1030
+ def hash_inputs(self) -> str:
1031
+ predicates = (
1032
+ ensure_sequence(self.predicates) if self.predicates is not None else []
1033
+ )
1034
+
1035
+ parts = [
1036
+ bytes.fromhex(self.query1.hash()),
1037
+ bytes.fromhex(self.query2.hash()),
1038
+ bytes.fromhex(hash_column_elements(predicates)),
1039
+ str(self.inner).encode(),
1040
+ str(self.full).encode(),
1041
+ self.rname.encode("utf-8"),
1042
+ ]
1043
+
1044
+ return hashlib.sha256(b"".join(parts)).hexdigest()
1045
+
923
1046
  def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
1047
+ temp_tables_before = len(dq.temp_table_names)
924
1048
  query = dq.apply_steps().select()
925
- temp_tables.extend(dq.temp_table_names)
1049
+ temp_tables.extend(dq.temp_table_names[temp_tables_before:])
926
1050
 
927
1051
  if not any(isinstance(step, (SQLJoin, SQLUnion)) for step in dq.steps):
928
1052
  return query.subquery(dq.table.name)
@@ -978,22 +1102,39 @@ class SQLJoin(Step):
978
1102
  q1 = self.get_query(self.query1, temp_tables)
979
1103
  q2 = self.get_query(self.query2, temp_tables)
980
1104
 
981
- q1_columns = list(q1.c)
982
- q1_column_names = {c.name for c in q1_columns}
983
-
984
- q2_columns = []
985
- for c in q2.c:
986
- if c.name.startswith("sys__"):
1105
+ q1_columns = _drop_system_columns(q1.c)
1106
+ existing_column_names = {c.name for c in q1_columns}
1107
+ right_columns: list[KeyedColumnElement[Any]] = []
1108
+ right_column_names: list[str] = []
1109
+ for column in q2.c:
1110
+ if column.name.startswith("sys__"):
987
1111
  continue
1112
+ right_columns.append(column)
1113
+ right_column_names.append(column.name)
1114
+
1115
+ root_mapping = generate_merge_root_mapping(
1116
+ existing_column_names,
1117
+ right_column_names,
1118
+ extract_root=self._root_name,
1119
+ prefix=self.rname,
1120
+ )
1121
+
1122
+ q2_columns: list[KeyedColumnElement[Any]] = []
1123
+ for column in right_columns:
1124
+ original_name = column.name
1125
+ column_root, column_tail = self._split_db_name(original_name)
1126
+ mapped_root = root_mapping[column_root]
1127
+
1128
+ new_name = (
1129
+ mapped_root
1130
+ if not column_tail
1131
+ else DEFAULT_DELIMITER.join([mapped_root, column_tail])
1132
+ )
1133
+
1134
+ if new_name != original_name:
1135
+ column = column.label(new_name)
988
1136
 
989
- if c.name in q1_column_names:
990
- new_name = self.rname.format(name=c.name)
991
- new_name_idx = 0
992
- while new_name in q1_column_names:
993
- new_name_idx += 1
994
- new_name = self.rname.format(name=f"{c.name}_{new_name_idx}")
995
- c = c.label(new_name)
996
- q2_columns.append(c)
1137
+ q2_columns.append(column)
997
1138
 
998
1139
  res_columns = q1_columns + q2_columns
999
1140
  predicates = (
@@ -1038,8 +1179,15 @@ class SQLJoin(Step):
1038
1179
 
1039
1180
  @frozen
1040
1181
  class SQLGroupBy(SQLClause):
1041
- cols: Sequence[Union[str, Function, ColumnElement]]
1042
- group_by: Sequence[Union[str, Function, ColumnElement]]
1182
+ cols: Sequence[str | Function | ColumnElement]
1183
+ group_by: Sequence[str | Function | ColumnElement]
1184
+
1185
+ def hash_inputs(self) -> str:
1186
+ return hashlib.sha256(
1187
+ bytes.fromhex(
1188
+ hash_column_elements(self.cols) + hash_column_elements(self.group_by)
1189
+ )
1190
+ ).hexdigest()
1043
1191
 
1044
1192
  def apply_sql_clause(self, query) -> Select:
1045
1193
  if not self.cols:
@@ -1069,46 +1217,52 @@ class SQLGroupBy(SQLClause):
1069
1217
  return sqlalchemy.select(*unique_cols).select_from(subquery).group_by(*group_by)
1070
1218
 
1071
1219
 
1072
- def _validate_columns(
1073
- left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
1074
- ) -> set[str]:
1075
- left_names = {c.name for c in left_columns}
1076
- right_names = {c.name for c in right_columns}
1077
-
1078
- if left_names == right_names:
1079
- return left_names
1080
-
1081
- missing_right = left_names - right_names
1082
- missing_left = right_names - left_names
1083
-
1084
- def _prepare_msg_part(missing_columns: set[str], side: str) -> str:
1085
- return f"{', '.join(sorted(missing_columns))} only present in {side}"
1086
-
1087
- msg_parts = [
1088
- _prepare_msg_part(missing_columns, found_side)
1089
- for missing_columns, found_side in zip(
1090
- [
1091
- missing_right,
1092
- missing_left,
1093
- ],
1094
- ["left", "right"],
1095
- )
1096
- if missing_columns
1097
- ]
1098
- msg = f"Cannot perform union. {'. '.join(msg_parts)}"
1220
+ class UnionSchemaMismatchError(ValueError):
1221
+ """Union input columns mismatch."""
1099
1222
 
1100
- raise ValueError(msg)
1223
+ @classmethod
1224
+ def from_column_sets(
1225
+ cls,
1226
+ missing_left: set[str],
1227
+ missing_right: set[str],
1228
+ ) -> "UnionSchemaMismatchError":
1229
+ def _describe(cols: set[str], side: str) -> str:
1230
+ return f"{', '.join(sorted(cols))} only present in {side}"
1231
+
1232
+ parts = []
1233
+ if missing_left:
1234
+ parts.append(_describe(missing_left, "left"))
1235
+ if missing_right:
1236
+ parts.append(_describe(missing_right, "right"))
1237
+
1238
+ return cls(f"Cannot perform union. {'. '.join(parts)}")
1101
1239
 
1102
1240
 
1103
1241
  def _order_columns(
1104
1242
  left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
1105
1243
  ) -> list[list[ColumnElement]]:
1106
- column_order = _validate_columns(left_columns, right_columns)
1244
+ left_names = [c.name for c in left_columns]
1245
+ right_names = [c.name for c in right_columns]
1246
+
1247
+ # validate
1248
+ if sorted(left_names) != sorted(right_names):
1249
+ left_names_set = set(left_names)
1250
+ right_names_set = set(right_names)
1251
+ raise UnionSchemaMismatchError.from_column_sets(
1252
+ left_names_set - right_names_set,
1253
+ right_names_set - left_names_set,
1254
+ )
1255
+
1256
+ # Order columns to match left_names order
1107
1257
  column_dicts = [
1108
1258
  {c.name: c for c in columns} for columns in [left_columns, right_columns]
1109
1259
  ]
1110
1260
 
1111
- return [[d[n] for n in column_order] for d in column_dicts]
1261
+ return [[d[n] for n in left_names] for d in column_dicts]
1262
+
1263
+
1264
+ def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
1265
+ return [c for c in columns if not c.name.startswith("sys__")]
1112
1266
 
1113
1267
 
1114
1268
  @attrs.define
@@ -1124,40 +1278,42 @@ class DatasetQuery:
1124
1278
  def __init__(
1125
1279
  self,
1126
1280
  name: str,
1127
- version: Optional[str] = None,
1128
- project_name: Optional[str] = None,
1129
- namespace_name: Optional[str] = None,
1130
- catalog: Optional["Catalog"] = None,
1131
- session: Optional[Session] = None,
1281
+ version: str | None = None,
1282
+ project_name: str | None = None,
1283
+ namespace_name: str | None = None,
1284
+ catalog: "Catalog | None" = None,
1285
+ session: Session | None = None,
1132
1286
  in_memory: bool = False,
1133
1287
  update: bool = False,
1134
1288
  ) -> None:
1135
1289
  self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
1136
1290
  self.catalog = catalog or self.session.catalog
1137
1291
  self.steps: list[Step] = []
1138
- self._chunk_index: Optional[int] = None
1139
- self._chunk_total: Optional[int] = None
1292
+ self._chunk_index: int | None = None
1293
+ self._chunk_total: int | None = None
1140
1294
  self.temp_table_names: list[str] = []
1141
1295
  self.dependencies: set[DatasetDependencyType] = set()
1142
1296
  self.table = self.get_table()
1143
- self.starting_step: Optional[QueryStep] = None
1144
- self.name: Optional[str] = None
1145
- self.version: Optional[str] = None
1146
- self.feature_schema: Optional[dict] = None
1147
- self.column_types: Optional[dict[str, Any]] = None
1297
+ self.starting_step: QueryStep | None = None
1298
+ self.name: str | None = None
1299
+ self.version: str | None = None
1300
+ self.feature_schema: dict | None = None
1301
+ self.column_types: dict[str, Any] | None = None
1148
1302
  self.before_steps: list[Callable] = []
1149
- self.listing_fn: Optional[Callable] = None
1303
+ self.listing_fn: Callable | None = None
1150
1304
  self.update = update
1151
1305
 
1152
- self.list_ds_name: Optional[str] = None
1306
+ self.list_ds_name: str | None = None
1153
1307
 
1154
1308
  self.name = name
1155
1309
  self.dialect = self.catalog.warehouse.db.dialect
1156
1310
  if version:
1157
1311
  self.version = version
1158
1312
 
1159
- namespace_name = namespace_name or self.catalog.metastore.default_namespace_name
1160
- project_name = project_name or self.catalog.metastore.default_project_name
1313
+ if namespace_name is None:
1314
+ namespace_name = self.catalog.metastore.default_namespace_name
1315
+ if project_name is None:
1316
+ project_name = self.catalog.metastore.default_project_name
1161
1317
 
1162
1318
  if is_listing_dataset(name) and not version:
1163
1319
  # not setting query step yet as listing dataset might not exist at
@@ -1194,12 +1350,26 @@ class DatasetQuery:
1194
1350
  def __or__(self, other):
1195
1351
  return self.union(other)
1196
1352
 
1353
+ def hash(self) -> str:
1354
+ """
1355
+ Calculates hash of this class taking into account hash of starting step
1356
+ and hashes of each following steps. Ordering is important.
1357
+ """
1358
+ hasher = hashlib.sha256()
1359
+ if self.starting_step:
1360
+ hasher.update(self.starting_step.hash().encode("utf-8"))
1361
+ else:
1362
+ assert self.list_ds_name
1363
+ hasher.update(self.list_ds_name.encode("utf-8"))
1364
+
1365
+ for step in self.steps:
1366
+ hasher.update(step.hash().encode("utf-8"))
1367
+
1368
+ return hasher.hexdigest()
1369
+
1197
1370
  @staticmethod
1198
1371
  def get_table() -> "TableClause":
1199
- table_name = "".join(
1200
- random.choice(string.ascii_letters) # noqa: S311
1201
- for _ in range(16)
1202
- )
1372
+ table_name = "".join(secrets.choice(string.ascii_letters) for _ in range(16))
1203
1373
  return sqlalchemy.table(table_name)
1204
1374
 
1205
1375
  @property
@@ -1216,7 +1386,7 @@ class DatasetQuery:
1216
1386
  """
1217
1387
  return self.name is not None and self.version is not None
1218
1388
 
1219
- def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
1389
+ def c(self, column: C | str) -> "ColumnClause[Any]":
1220
1390
  col: sqlalchemy.ColumnClause = (
1221
1391
  sqlalchemy.column(column)
1222
1392
  if isinstance(column, str)
@@ -1311,6 +1481,7 @@ class DatasetQuery:
1311
1481
  # This is needed to always use a new connection with all metastore and warehouse
1312
1482
  # implementations, as errors may close or render unusable the existing
1313
1483
  # connections.
1484
+ assert len(self.temp_table_names) == len(set(self.temp_table_names))
1314
1485
  with self.catalog.metastore.clone(use_new_connection=True) as metastore:
1315
1486
  metastore.cleanup_tables(self.temp_table_names)
1316
1487
  with self.catalog.warehouse.clone(use_new_connection=True) as warehouse:
@@ -1325,7 +1496,7 @@ class DatasetQuery:
1325
1496
  return list(result)
1326
1497
 
1327
1498
  def to_db_records(self) -> list[dict[str, Any]]:
1328
- return self.db_results(lambda cols, row: dict(zip(cols, row)))
1499
+ return self.db_results(lambda cols, row: dict(zip(cols, row, strict=False)))
1329
1500
 
1330
1501
  @contextlib.contextmanager
1331
1502
  def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
@@ -1364,7 +1535,7 @@ class DatasetQuery:
1364
1535
  yield from rows
1365
1536
 
1366
1537
  async def get_params(row: Sequence) -> tuple:
1367
- row_dict = RowDict(zip(query_fields, row))
1538
+ row_dict = RowDict(zip(query_fields, row, strict=False))
1368
1539
  return tuple( # noqa: C409
1369
1540
  [
1370
1541
  await p.get_value_async(
@@ -1381,10 +1552,6 @@ class DatasetQuery:
1381
1552
  finally:
1382
1553
  self.cleanup()
1383
1554
 
1384
- def shuffle(self) -> "Self":
1385
- # ToDo: implement shaffle based on seed and/or generating random column
1386
- return self.order_by(C.sys__rand)
1387
-
1388
1555
  def sample(self, n) -> "Self":
1389
1556
  """
1390
1557
  Return a random sample from the dataset.
@@ -1404,6 +1571,7 @@ class DatasetQuery:
1404
1571
  obj.steps = obj.steps.copy()
1405
1572
  if new_table:
1406
1573
  obj.table = self.get_table()
1574
+ obj.temp_table_names = []
1407
1575
  return obj
1408
1576
 
1409
1577
  @detach
@@ -1584,10 +1752,10 @@ class DatasetQuery:
1584
1752
  def join(
1585
1753
  self,
1586
1754
  dataset_query: "DatasetQuery",
1587
- predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
1755
+ predicates: JoinPredicateType | Sequence[JoinPredicateType],
1588
1756
  inner=False,
1589
1757
  full=False,
1590
- rname="{name}_right",
1758
+ rname="right_",
1591
1759
  ) -> "Self":
1592
1760
  left = self.clone(new_table=False)
1593
1761
  if self.table.name == dataset_query.table.name:
@@ -1626,12 +1794,17 @@ class DatasetQuery:
1626
1794
  def add_signals(
1627
1795
  self,
1628
1796
  udf: "UDFAdapter",
1629
- parallel: Optional[int] = None,
1630
- workers: Union[bool, int] = False,
1631
- min_task_size: Optional[int] = None,
1632
- partition_by: Optional[PartitionByType] = None,
1797
+ partition_by: PartitionByType | None = None,
1798
+ # Parameters from Settings
1633
1799
  cache: bool = False,
1634
- batch_rows: Optional[int] = None,
1800
+ parallel: int | None = None,
1801
+ workers: bool | int = False,
1802
+ min_task_size: int | None = None,
1803
+ batch_size: int | None = None,
1804
+ # Parameters are unused, kept only to match the signature of Settings.to_dict
1805
+ prefetch: int | None = None,
1806
+ namespace: str | None = None,
1807
+ project: str | None = None,
1635
1808
  ) -> "Self":
1636
1809
  """
1637
1810
  Adds one or more signals based on the results from the provided UDF.
@@ -1657,7 +1830,7 @@ class DatasetQuery:
1657
1830
  workers=workers,
1658
1831
  min_task_size=min_task_size,
1659
1832
  cache=cache,
1660
- batch_rows=batch_rows,
1833
+ batch_size=batch_size,
1661
1834
  )
1662
1835
  )
1663
1836
  return query
@@ -1672,14 +1845,17 @@ class DatasetQuery:
1672
1845
  def generate(
1673
1846
  self,
1674
1847
  udf: "UDFAdapter",
1675
- parallel: Optional[int] = None,
1676
- workers: Union[bool, int] = False,
1677
- min_task_size: Optional[int] = None,
1678
- partition_by: Optional[PartitionByType] = None,
1679
- namespace: Optional[str] = None,
1680
- project: Optional[str] = None,
1848
+ partition_by: PartitionByType | None = None,
1849
+ # Parameters from Settings
1681
1850
  cache: bool = False,
1682
- batch_rows: Optional[int] = None,
1851
+ parallel: int | None = None,
1852
+ workers: bool | int = False,
1853
+ min_task_size: int | None = None,
1854
+ batch_size: int | None = None,
1855
+ # Parameters are unused, kept only to match the signature of Settings.to_dict:
1856
+ prefetch: int | None = None,
1857
+ namespace: str | None = None,
1858
+ project: str | None = None,
1683
1859
  ) -> "Self":
1684
1860
  query = self.clone()
1685
1861
  steps = query.steps
@@ -1692,7 +1868,7 @@ class DatasetQuery:
1692
1868
  workers=workers,
1693
1869
  min_task_size=min_task_size,
1694
1870
  cache=cache,
1695
- batch_rows=batch_rows,
1871
+ batch_size=batch_size,
1696
1872
  )
1697
1873
  )
1698
1874
  return query
@@ -1735,26 +1911,30 @@ class DatasetQuery:
1735
1911
 
1736
1912
  def exec(self) -> "Self":
1737
1913
  """Execute the query."""
1914
+ query = self.clone()
1738
1915
  try:
1739
- query = self.clone()
1740
1916
  query.apply_steps()
1741
1917
  finally:
1742
- self.cleanup()
1918
+ query.cleanup()
1743
1919
  return query
1744
1920
 
1745
1921
  def save(
1746
1922
  self,
1747
- name: Optional[str] = None,
1748
- version: Optional[str] = None,
1749
- project: Optional[Project] = None,
1750
- feature_schema: Optional[dict] = None,
1751
- dependencies: Optional[list[DatasetDependency]] = None,
1752
- description: Optional[str] = None,
1753
- attrs: Optional[list[str]] = None,
1754
- update_version: Optional[str] = "patch",
1923
+ name: str | None = None,
1924
+ version: str | None = None,
1925
+ project: Project | None = None,
1926
+ feature_schema: dict | None = None,
1927
+ dependencies: list[DatasetDependency] | None = None,
1928
+ description: str | None = None,
1929
+ attrs: list[str] | None = None,
1930
+ update_version: str | None = "patch",
1755
1931
  **kwargs,
1756
1932
  ) -> "Self":
1757
1933
  """Save the query as a dataset."""
1934
+ # Get job from session to link dataset version to job
1935
+ job = self.session.get_or_create_job()
1936
+ job_id = job.id
1937
+
1758
1938
  project = project or self.catalog.metastore.default_project
1759
1939
  try:
1760
1940
  if (
@@ -1797,14 +1977,11 @@ class DatasetQuery:
1797
1977
  description=description,
1798
1978
  attrs=attrs,
1799
1979
  update_version=update_version,
1980
+ job_id=job_id,
1800
1981
  **kwargs,
1801
1982
  )
1802
1983
  version = version or dataset.latest_version
1803
1984
 
1804
- self.session.add_dataset_version(
1805
- dataset=dataset, version=version, listing=kwargs.get("listing", False)
1806
- )
1807
-
1808
1985
  dr = self.catalog.warehouse.dataset_rows(dataset)
1809
1986
 
1810
1987
  self.catalog.warehouse.copy_table(dr.get_table(), query.select())
@@ -1814,6 +1991,11 @@ class DatasetQuery:
1814
1991
  )
1815
1992
  self.catalog.update_dataset_version_with_warehouse_info(dataset, version)
1816
1993
 
1994
+ # Link this dataset version to the job that created it
1995
+ self.catalog.metastore.link_dataset_version_to_job(
1996
+ dataset.get_version(version).id, job_id, is_creator=True
1997
+ )
1998
+
1817
1999
  if dependencies:
1818
2000
  # overriding dependencies
1819
2001
  self.dependencies = set()
@@ -1845,5 +2027,5 @@ class DatasetQuery:
1845
2027
  return isinstance(self.last_step, SQLOrderBy)
1846
2028
 
1847
2029
  @property
1848
- def last_step(self) -> Optional[Step]:
2030
+ def last_step(self) -> Step | None:
1849
2031
  return self.steps[-1] if self.steps else None
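
One behavioral change worth calling out from the UDFSignal hunk above: create_result_query now rejects not only exact column-name collisions but also collisions at the top-level signal ("root") level, where nested signal fields are flattened into column names with a delimiter. A minimal sketch of that check, assuming the DEFAULT_DELIMITER imported in the diff is the double underscore visible in column names such as sys__id:

    DEFAULT_DELIMITER = "__"  # assumption: delimiter used to flatten nested signals


    def check_new_signals(existing: set[str], new: set[str]) -> None:
        # Exact column collisions are reported first, as before this change.
        overlap = existing & new
        if overlap:
            raise ValueError(
                "Column already exists or added in the previous steps: "
                + ", ".join(sorted(overlap))
            )

        # Added in this diff: a UDF also may not reuse a top-level signal name.
        # "meta__width" and "meta__size" share the root "meta", so adding one
        # conflicts with an existing column under the other.
        def root(name: str) -> str:
            return name.split(DEFAULT_DELIMITER, 1)[0]

        root_conflicts = {root(n) for n in existing} & {root(n) for n in new}
        if root_conflicts:
            raise ValueError(
                "Signals already exist in the previous steps: "
                + ", ".join(sorted(root_conflicts))
            )


    check_new_signals({"file__path", "meta__size"}, {"score"})   # passes
    # check_new_signals({"file__path"}, {"file__width"})         # raises: root "file"
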