datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (137)
  1. datachain/__init__.py +20 -0
  2. datachain/asyn.py +11 -12
  3. datachain/cache.py +7 -7
  4. datachain/catalog/__init__.py +2 -2
  5. datachain/catalog/catalog.py +621 -507
  6. datachain/catalog/dependency.py +164 -0
  7. datachain/catalog/loader.py +28 -18
  8. datachain/checkpoint.py +43 -0
  9. datachain/cli/__init__.py +24 -33
  10. datachain/cli/commands/__init__.py +1 -8
  11. datachain/cli/commands/datasets.py +83 -52
  12. datachain/cli/commands/ls.py +17 -17
  13. datachain/cli/commands/show.py +4 -4
  14. datachain/cli/parser/__init__.py +8 -74
  15. datachain/cli/parser/job.py +95 -3
  16. datachain/cli/parser/studio.py +11 -4
  17. datachain/cli/parser/utils.py +1 -2
  18. datachain/cli/utils.py +2 -15
  19. datachain/client/azure.py +4 -4
  20. datachain/client/fsspec.py +45 -28
  21. datachain/client/gcs.py +6 -6
  22. datachain/client/hf.py +29 -2
  23. datachain/client/http.py +157 -0
  24. datachain/client/local.py +15 -11
  25. datachain/client/s3.py +17 -9
  26. datachain/config.py +4 -8
  27. datachain/data_storage/db_engine.py +12 -6
  28. datachain/data_storage/job.py +5 -1
  29. datachain/data_storage/metastore.py +1252 -186
  30. datachain/data_storage/schema.py +58 -45
  31. datachain/data_storage/serializer.py +105 -15
  32. datachain/data_storage/sqlite.py +286 -127
  33. datachain/data_storage/warehouse.py +250 -113
  34. datachain/dataset.py +353 -148
  35. datachain/delta.py +391 -0
  36. datachain/diff/__init__.py +27 -29
  37. datachain/error.py +60 -0
  38. datachain/func/__init__.py +2 -1
  39. datachain/func/aggregate.py +66 -42
  40. datachain/func/array.py +242 -38
  41. datachain/func/base.py +7 -4
  42. datachain/func/conditional.py +110 -60
  43. datachain/func/func.py +96 -45
  44. datachain/func/numeric.py +55 -38
  45. datachain/func/path.py +32 -20
  46. datachain/func/random.py +2 -2
  47. datachain/func/string.py +67 -37
  48. datachain/func/window.py +7 -8
  49. datachain/hash_utils.py +123 -0
  50. datachain/job.py +11 -7
  51. datachain/json.py +138 -0
  52. datachain/lib/arrow.py +58 -22
  53. datachain/lib/audio.py +245 -0
  54. datachain/lib/clip.py +14 -13
  55. datachain/lib/convert/flatten.py +5 -3
  56. datachain/lib/convert/python_to_sql.py +6 -10
  57. datachain/lib/convert/sql_to_python.py +8 -0
  58. datachain/lib/convert/values_to_tuples.py +156 -51
  59. datachain/lib/data_model.py +42 -20
  60. datachain/lib/dataset_info.py +36 -8
  61. datachain/lib/dc/__init__.py +8 -2
  62. datachain/lib/dc/csv.py +25 -28
  63. datachain/lib/dc/database.py +398 -0
  64. datachain/lib/dc/datachain.py +1289 -425
  65. datachain/lib/dc/datasets.py +320 -38
  66. datachain/lib/dc/hf.py +38 -24
  67. datachain/lib/dc/json.py +29 -32
  68. datachain/lib/dc/listings.py +112 -8
  69. datachain/lib/dc/pandas.py +16 -12
  70. datachain/lib/dc/parquet.py +35 -23
  71. datachain/lib/dc/records.py +31 -23
  72. datachain/lib/dc/storage.py +154 -64
  73. datachain/lib/dc/storage_pattern.py +251 -0
  74. datachain/lib/dc/utils.py +24 -16
  75. datachain/lib/dc/values.py +8 -9
  76. datachain/lib/file.py +622 -89
  77. datachain/lib/hf.py +69 -39
  78. datachain/lib/image.py +14 -14
  79. datachain/lib/listing.py +14 -11
  80. datachain/lib/listing_info.py +1 -2
  81. datachain/lib/meta_formats.py +3 -4
  82. datachain/lib/model_store.py +39 -7
  83. datachain/lib/namespaces.py +125 -0
  84. datachain/lib/projects.py +130 -0
  85. datachain/lib/pytorch.py +32 -21
  86. datachain/lib/settings.py +192 -56
  87. datachain/lib/signal_schema.py +427 -104
  88. datachain/lib/tar.py +1 -2
  89. datachain/lib/text.py +8 -7
  90. datachain/lib/udf.py +164 -76
  91. datachain/lib/udf_signature.py +60 -35
  92. datachain/lib/utils.py +118 -4
  93. datachain/lib/video.py +17 -9
  94. datachain/lib/webdataset.py +61 -56
  95. datachain/lib/webdataset_laion.py +15 -16
  96. datachain/listing.py +22 -10
  97. datachain/model/bbox.py +3 -1
  98. datachain/model/ultralytics/bbox.py +16 -12
  99. datachain/model/ultralytics/pose.py +16 -12
  100. datachain/model/ultralytics/segment.py +16 -12
  101. datachain/namespace.py +84 -0
  102. datachain/node.py +6 -6
  103. datachain/nodes_thread_pool.py +0 -1
  104. datachain/plugins.py +24 -0
  105. datachain/project.py +78 -0
  106. datachain/query/batch.py +40 -41
  107. datachain/query/dataset.py +604 -322
  108. datachain/query/dispatch.py +261 -154
  109. datachain/query/metrics.py +4 -6
  110. datachain/query/params.py +2 -3
  111. datachain/query/queue.py +3 -12
  112. datachain/query/schema.py +11 -6
  113. datachain/query/session.py +200 -33
  114. datachain/query/udf.py +34 -2
  115. datachain/remote/studio.py +171 -69
  116. datachain/script_meta.py +12 -12
  117. datachain/semver.py +68 -0
  118. datachain/sql/__init__.py +2 -0
  119. datachain/sql/functions/array.py +33 -1
  120. datachain/sql/postgresql_dialect.py +9 -0
  121. datachain/sql/postgresql_types.py +21 -0
  122. datachain/sql/sqlite/__init__.py +5 -1
  123. datachain/sql/sqlite/base.py +102 -29
  124. datachain/sql/sqlite/types.py +8 -13
  125. datachain/sql/types.py +70 -15
  126. datachain/studio.py +223 -46
  127. datachain/toolkit/split.py +31 -10
  128. datachain/utils.py +101 -59
  129. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
  130. datachain-0.39.0.dist-info/RECORD +173 -0
  131. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
  132. datachain/cli/commands/query.py +0 -53
  133. datachain/query/utils.py +0 -42
  134. datachain-0.14.2.dist-info/RECORD +0 -158
  135. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
  136. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
  137. {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
@@ -1,25 +1,18 @@
  import contextlib
+ import hashlib
  import inspect
  import logging
  import os
- import random
+ import secrets
  import string
  import subprocess
  import sys
  from abc import ABC, abstractmethod
- from collections.abc import Generator, Iterable, Iterator, Sequence
+ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
  from copy import copy
  from functools import wraps
- from secrets import token_hex
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Optional,
- Protocol,
- TypeVar,
- Union,
- )
+ from types import GeneratorType
+ from typing import TYPE_CHECKING, Any, Protocol, TypeVar

  import attrs
  import sqlalchemy
@@ -28,7 +21,7 @@ from attrs import frozen
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
  from sqlalchemy import Column
  from sqlalchemy.sql import func as f
- from sqlalchemy.sql.elements import ColumnClause, ColumnElement
+ from sqlalchemy.sql.elements import ColumnClause, ColumnElement, Label
  from sqlalchemy.sql.expression import label
  from sqlalchemy.sql.schema import TableClause
  from sqlalchemy.sql.selectable import Select
@@ -41,51 +34,53 @@ from datachain.data_storage.schema import (
  partition_col_names,
  partition_columns,
  )
- from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
- from datachain.error import (
- DatasetNotFoundError,
- QueryScriptCancelError,
- )
+ from datachain.dataset import DatasetDependency, DatasetStatus, RowDict
+ from datachain.error import DatasetNotFoundError, QueryScriptCancelError
  from datachain.func.base import Function
- from datachain.lib.listing import (
- is_listing_dataset,
- listing_dataset_expired,
- )
+ from datachain.hash_utils import hash_column_elements
+ from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
+ from datachain.lib.signal_schema import SignalSchema, generate_merge_root_mapping
  from datachain.lib.udf import UDFAdapter, _get_cache
  from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
- from datachain.query.schema import C, UDFParamSpec, normalize_param
+ from datachain.project import Project
+ from datachain.query.schema import DEFAULT_DELIMITER, C, UDFParamSpec, normalize_param
  from datachain.query.session import Session
+ from datachain.query.udf import UdfInfo
  from datachain.sql.functions.random import rand
+ from datachain.sql.types import SQLType
  from datachain.utils import (
- batched,
  determine_processes,
+ determine_workers,
+ ensure_sequence,
  filtered_cloudpickle_dumps,
  get_datachain_executable,
  safe_closing,
  )

  if TYPE_CHECKING:
- from sqlalchemy.sql.elements import ClauseElement
+ from collections.abc import Mapping
+ from typing import Concatenate
+
+ from sqlalchemy.sql.elements import ClauseElement, KeyedColumnElement
  from sqlalchemy.sql.schema import Table
  from sqlalchemy.sql.selectable import GenerativeSelect
- from typing_extensions import Concatenate, ParamSpec, Self
+ from typing_extensions import ParamSpec, Self

  from datachain.catalog import Catalog
  from datachain.data_storage import AbstractWarehouse
  from datachain.dataset import DatasetRecord
  from datachain.lib.udf import UDFAdapter, UDFResult
- from datachain.query.udf import UdfInfo

  P = ParamSpec("P")


  INSERT_BATCH_SIZE = 10000

- PartitionByType = Union[
- Function, ColumnElement, Sequence[Union[Function, ColumnElement]]
- ]
- JoinPredicateType = Union[str, ColumnClause, ColumnElement]
- DatasetDependencyType = tuple[str, int]
+ PartitionByType = (
+ str | Function | ColumnElement | Sequence[str | Function | ColumnElement]
+ )
+ JoinPredicateType = str | ColumnClause | ColumnElement
+ DatasetDependencyType = tuple["DatasetRecord", str]

  logger = logging.getLogger("datachain")

@@ -165,24 +160,42 @@ class Step(ABC):
  ) -> "StepResult":
  """Apply the processing step."""

+ @abstractmethod
+ def hash_inputs(self) -> str:
+ """Calculates hash of step inputs"""
+
+ def hash(self) -> str:
+ """
+ Calculates hash for step which includes step name and hash of it's inputs
+ """
+ return hashlib.sha256(
+ f"{self.__class__.__name__}|{self.hash_inputs()}".encode()
+ ).hexdigest()
+

  @frozen
  class QueryStep:
+ """A query that returns all rows from specific dataset version"""
+
  catalog: "Catalog"
- dataset_name: str
- dataset_version: int
+ dataset: "DatasetRecord"
+ dataset_version: str

- def apply(self):
+ def apply(self) -> "StepResult":
  def q(*columns):
  return sqlalchemy.select(*columns)

- dataset = self.catalog.get_dataset(self.dataset_name)
- dr = self.catalog.warehouse.dataset_rows(dataset, self.dataset_version)
+ dr = self.catalog.warehouse.dataset_rows(self.dataset, self.dataset_version)

  return step_result(
- q, dr.columns, dependencies=[(self.dataset_name, self.dataset_version)]
+ q, dr.columns, dependencies=[(self.dataset, self.dataset_version)]
  )

+ def hash(self) -> str:
+ return hashlib.sha256(
+ self.dataset.uri(self.dataset_version).encode()
+ ).hexdigest()
+

  def generator_then_call(generator, func: Callable):
  """
@@ -218,8 +231,9 @@ class DatasetDiffOperation(Step):

  def apply(self, query_generator, temp_tables: list[str]) -> "StepResult":
  source_query = query_generator.exclude(("sys__id",))
+ right_before = len(self.dq.temp_table_names)
  target_query = self.dq.apply_steps().select()
- temp_tables.extend(self.dq.temp_table_names)
+ temp_tables.extend(self.dq.temp_table_names[right_before:])

  # creating temp table that will hold subtract results
  temp_table_name = self.catalog.warehouse.temp_table_name()
@@ -253,6 +267,13 @@ class DatasetDiffOperation(Step):
  class Subtract(DatasetDiffOperation):
  on: Sequence[tuple[str, str]]

+ def hash_inputs(self) -> str:
+ on_bytes = b"".join(
+ f"{a}:{b}".encode() for a, b in sorted(self.on, key=lambda t: (t[0], t[1]))
+ )
+
+ return hashlib.sha256(bytes.fromhex(self.dq.hash()) + on_bytes).hexdigest()
+
  def query(self, source_query: Select, target_query: Select) -> sa.Selectable:
  sq = source_query.alias("source_query")
  tq = target_query.alias("target_query")
@@ -272,7 +293,9 @@ class Subtract(DatasetDiffOperation):


  def adjust_outputs(
- warehouse: "AbstractWarehouse", row: dict[str, Any], udf_col_types: list[tuple]
+ warehouse: "AbstractWarehouse",
+ row: dict[str, Any],
+ col_types: list[tuple[str, SQLType, type, str, Any]],
  ) -> dict[str, Any]:
  """
  This function does a couple of things to prepare a row for inserting into the db:
@@ -288,7 +311,7 @@ def adjust_outputs(
  col_python_type,
  col_type_name,
  default_value,
- ) in udf_col_types:
+ ) in col_types:
  row_val = row.get(col_name)

  # Fill None or missing values with defaults (get returns None if not in the row)
@@ -303,8 +326,10 @@ def adjust_outputs(
  return row


- def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
- """Optimization: Precompute UDF column types so these don't have to be computed
+ def get_col_types(
+ warehouse: "AbstractWarehouse", output: "Mapping[str, Any]"
+ ) -> list[tuple]:
+ """Optimization: Precompute column types so these don't have to be computed
  in the convert_type function for each row in a loop."""
  dialect = warehouse.db.dialect
  return [
@@ -316,7 +341,7 @@ def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list
  type(col_type_inst).__name__,
  col_type.default_value(dialect),
  )
- for col_name, col_type in udf.output.items()
+ for col_name, col_type in output.items()
  ]

@@ -325,33 +350,23 @@ def process_udf_outputs(
  udf_table: "Table",
  udf_results: Iterator[Iterable["UDFResult"]],
  udf: "UDFAdapter",
- batch_size: int = INSERT_BATCH_SIZE,
  cb: Callback = DEFAULT_CALLBACK,
+ batch_size: int = INSERT_BATCH_SIZE,
  ) -> None:
- import psutil
-
- rows: list[UDFResult] = []
  # Optimization: Compute row types once, rather than for every row.
- udf_col_types = get_udf_col_types(warehouse, udf)
+ udf_col_types = get_col_types(warehouse, udf.output)

- for udf_output in udf_results:
- if not udf_output:
- continue
- with safe_closing(udf_output):
- for row in udf_output:
- cb.relative_update()
- rows.append(adjust_outputs(warehouse, row, udf_col_types))
- if len(rows) >= batch_size or (
- len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
- ):
- for row_chunk in batched(rows, batch_size):
- warehouse.insert_rows(udf_table, row_chunk)
- rows.clear()
+ def _insert_rows():
+ for udf_output in udf_results:
+ if not udf_output:
+ continue

- if rows:
- for row_chunk in batched(rows, batch_size):
- warehouse.insert_rows(udf_table, row_chunk)
+ with safe_closing(udf_output):
+ for row in udf_output:
+ cb.relative_update()
+ yield adjust_outputs(warehouse, row, udf_col_types)

+ warehouse.insert_rows(udf_table, _insert_rows(), batch_size=batch_size)
  warehouse.insert_rows_done(udf_table)


@@ -387,20 +402,34 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
  class UDFStep(Step, ABC):
  udf: "UDFAdapter"
  catalog: "Catalog"
- partition_by: Optional[PartitionByType] = None
- parallel: Optional[int] = None
- workers: Union[bool, int] = False
- min_task_size: Optional[int] = None
+ partition_by: PartitionByType | None = None
  is_generator = False
+ # Parameters from Settings
  cache: bool = False
+ parallel: int | None = None
+ workers: bool | int = False
+ min_task_size: int | None = None
+ batch_size: int | None = None
+
+ def hash_inputs(self) -> str:
+ partition_by = ensure_sequence(self.partition_by or [])
+ parts = [
+ bytes.fromhex(self.udf.hash()),
+ bytes.fromhex(hash_column_elements(partition_by)),
+ str(self.is_generator).encode(),
+ ]
+
+ return hashlib.sha256(b"".join(parts)).hexdigest()

  @abstractmethod
  def create_udf_table(self, query: Select) -> "Table":
  """Method that creates a table where temp udf results will be saved"""

  def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
- """Apply any necessary processing to the input query"""
- return query, []
+ """Materialize inputs, ensure sys columns are available, needed for checkpoints,
+ needed for map to work (merge results)"""
+ table = self.catalog.warehouse.create_pre_udf_table(query)
+ return sqlalchemy.select(*table.c), [table]

  @abstractmethod
  def create_result_query(
@@ -412,28 +441,48 @@ class UDFStep(Step, ABC):
  """

  def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
+ if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
+ return
+
  from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+ from datachain.catalog.loader import (
+ DISTRIBUTED_IMPORT_PATH,
+ get_udf_distributor_class,
+ )
+
+ workers = determine_workers(self.workers, rows_total=rows_total)
+ processes = determine_processes(self.parallel, rows_total=rows_total)

  use_partitioning = self.partition_by is not None
  batching = self.udf.get_batching(use_partitioning)
- workers = self.workers
- if (
- not workers
- and os.environ.get("DATACHAIN_DISTRIBUTED")
- and os.environ.get("DATACHAIN_SETTINGS_WORKERS")
- ):
- # Enable distributed processing by default if the module is available,
- # and a default number of workers is provided.
- workers = True
-
- processes = determine_processes(self.parallel)
-
  udf_fields = [str(c.name) for c in query.selected_columns]
+ udf_distributor_class = get_udf_distributor_class()

  prefetch = self.udf.prefetch
  with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
  catalog = clone_catalog_with_cache(self.catalog, _cache)
+
  try:
+ if udf_distributor_class and not catalog.in_memory:
+ # Use the UDF distributor if available (running in SaaS)
+ udf_distributor = udf_distributor_class(
+ catalog=catalog,
+ table=udf_table,
+ query=query,
+ udf_data=filtered_cloudpickle_dumps(self.udf),
+ batching=batching,
+ workers=workers,
+ processes=processes,
+ udf_fields=udf_fields,
+ rows_total=rows_total,
+ use_cache=self.cache,
+ is_generator=self.is_generator,
+ min_task_size=self.min_task_size,
+ batch_size=self.batch_size,
+ )
+ udf_distributor()
+ return
+
  if workers:
  if catalog.in_memory:
  raise RuntimeError(
@@ -441,43 +490,33 @@ class UDFStep(Step, ABC):
  "distributed processing."
  )

- from datachain.catalog.loader import get_distributed_class
-
- distributor = get_distributed_class(
- min_task_size=self.min_task_size
+ raise RuntimeError(
+ f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+ "for distributed UDF processing."
  )
- distributor(
- self.udf,
- catalog,
- udf_table,
- query,
- workers,
- processes,
- udf_fields=udf_fields,
- is_generator=self.is_generator,
- use_partitioning=use_partitioning,
- cache=self.cache,
- )
- elif processes:
+ if processes:
  # Parallel processing (faster for more CPU-heavy UDFs)
  if catalog.in_memory:
  raise RuntimeError(
  "In-memory databases cannot be used "
  "with parallel processing."
  )
- udf_info: UdfInfo = {
- "udf_data": filtered_cloudpickle_dumps(self.udf),
- "catalog_init": catalog.get_init_params(),
- "metastore_clone_params": catalog.metastore.clone_params(),
- "warehouse_clone_params": catalog.warehouse.clone_params(),
- "table": udf_table,
- "query": query,
- "udf_fields": udf_fields,
- "batching": batching,
- "processes": processes,
- "is_generator": self.is_generator,
- "cache": self.cache,
- }
+
+ udf_info = UdfInfo(
+ udf_data=filtered_cloudpickle_dumps(self.udf),
+ catalog_init=catalog.get_init_params(),
+ metastore_clone_params=catalog.metastore.clone_params(),
+ warehouse_clone_params=catalog.warehouse.clone_params(),
+ table=udf_table,
+ query=query,
+ udf_fields=udf_fields,
+ batching=batching,
+ processes=processes,
+ is_generator=self.is_generator,
+ cache=self.cache,
+ rows_total=rows_total,
+ batch_size=self.batch_size or INSERT_BATCH_SIZE,
+ )

  # Run the UDFDispatcher in another process to avoid needing
  # if __name__ == '__main__': in user scripts
@@ -490,7 +529,12 @@ class UDFStep(Step, ABC):
  with subprocess.Popen( # noqa: S603
  cmd, env=envs, stdin=subprocess.PIPE
  ) as process:
- process.communicate(process_data)
+ try:
+ process.communicate(process_data)
+ except KeyboardInterrupt:
+ raise QueryScriptCancelError(
+ "UDF execution was canceled by the user."
+ ) from None
  if retval := process.poll():
  raise RuntimeError(
  f"UDF Execution Failed! Exit code: {retval}"
@@ -520,6 +564,7 @@ class UDFStep(Step, ABC):
  udf_results,
  self.udf,
  cb=generated_cb,
+ batch_size=self.batch_size or INSERT_BATCH_SIZE,
  )
  finally:
  download_cb.close()
@@ -538,10 +583,13 @@ class UDFStep(Step, ABC):
  """
  Create temporary table with group by partitions.
  """
- assert self.partition_by is not None
+ if self.partition_by is None:
+ raise RuntimeError("Query must have partition_by set to use partitioning")
+ if (id_col := query.selected_columns.get("sys__id")) is None:
+ raise RuntimeError("Query must have sys__id column to use partitioning")

- if isinstance(self.partition_by, Sequence):
- list_partition_by = self.partition_by
+ if isinstance(self.partition_by, (list, tuple, GeneratorType)):
+ list_partition_by = list(self.partition_by)
  else:
  list_partition_by = [self.partition_by]

@@ -554,16 +602,19 @@ class UDFStep(Step, ABC):

  # fill table with partitions
  cols = [
- query.selected_columns.sys__id,
+ id_col,
  f.dense_rank().over(order_by=partition_by).label(PARTITION_COLUMN_ID),
  ]
  self.catalog.warehouse.db.execute(
- tbl.insert().from_select(cols, query.with_only_columns(*cols))
+ tbl.insert().from_select(
+ cols,
+ query.offset(None).limit(None).with_only_columns(*cols),
+ )
  )

  return tbl

- def clone(self, partition_by: Optional[PartitionByType] = None) -> "Self":
+ def clone(self, partition_by: PartitionByType | None = None) -> "Self":
  if partition_by is not None:
  return self.__class__(
  self.udf,
@@ -572,27 +623,25 @@ class UDFStep(Step, ABC):
  parallel=self.parallel,
  workers=self.workers,
  min_task_size=self.min_task_size,
+ batch_size=self.batch_size,
  )
  return self.__class__(self.udf, self.catalog)

  def apply(
  self, query_generator: QueryGenerator, temp_tables: list[str]
  ) -> "StepResult":
- _query = query = query_generator.select()
+ query, tables = self.process_input_query(query_generator.select())
+ _query = query

  # Apply partitioning if needed.
  if self.partition_by is not None:
  partition_tbl = self.create_partitions_table(query)
- temp_tables.append(partition_tbl.name)
+ query = query.outerjoin(
+ partition_tbl,
+ partition_tbl.c.sys__id == query.selected_columns.sys__id,
+ ).add_columns(*partition_columns())
+ tables = [*tables, partition_tbl]

- subq = query.subquery()
- query = (
- sqlalchemy.select(*subq.c)
- .outerjoin(partition_tbl, partition_tbl.c.sys__id == subq.c.sys__id)
- .add_columns(*partition_columns())
- )
-
- query, tables = self.process_input_query(query)
  temp_tables.extend(t.name for t in tables)
  udf_table = self.create_udf_table(_query)
  temp_tables.append(udf_table.name)
@@ -604,7 +653,16 @@ class UDFStep(Step, ABC):

  @frozen
  class UDFSignal(UDFStep):
+ udf: "UDFAdapter"
+ catalog: "Catalog"
+ partition_by: PartitionByType | None = None
  is_generator = False
+ # Parameters from Settings
+ cache: bool = False
+ parallel: int | None = None
+ workers: bool | int = False
+ min_task_size: int | None = None
+ batch_size: int | None = None

  def create_udf_table(self, query: Select) -> "Table":
  udf_output_columns: list[sqlalchemy.Column[Any]] = [
@@ -614,13 +672,6 @@ class UDFSignal(UDFStep):

  return self.catalog.warehouse.create_udf_table(udf_output_columns)

- def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
- if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
- return query, []
- table = self.catalog.warehouse.create_pre_udf_table(query)
- q: Select = sqlalchemy.select(*table.c)
- return q, [table]
-
  def create_result_query(
  self, udf_table, query
  ) -> tuple[QueryGeneratorFunc, list["sqlalchemy.Column"]]:
@@ -628,15 +679,30 @@ class UDFSignal(UDFStep):
  original_cols = [c for c in subq.c if c.name not in partition_col_names]

  # new signal columns that are added to udf_table
- signal_cols = [c for c in udf_table.c if c.name != "sys__id"]
+ signal_cols = [c for c in udf_table.c if not c.name.startswith("sys__")]
  signal_name_cols = {c.name: c for c in signal_cols}
  cols = signal_cols

- overlap = {c.name for c in original_cols} & {c.name for c in cols}
+ original_names = {c.name for c in original_cols}
+ new_names = {c.name for c in cols}
+
+ overlap = original_names & new_names
  if overlap:
  raise ValueError(
  "Column already exists or added in the previous steps: "
- + ", ".join(overlap)
+ + ", ".join(sorted(overlap))
+ )
+
+ def _root(name: str) -> str:
+ return name.split(DEFAULT_DELIMITER, 1)[0]
+
+ existing_roots = {_root(name) for name in original_names}
+ new_roots = {_root(name) for name in new_names}
+ root_conflicts = existing_roots & new_roots
+ if root_conflicts:
+ raise ValueError(
+ "Signals already exist in the previous steps: "
+ + ", ".join(sorted(root_conflicts))
  )

  def q(*columns):
@@ -674,7 +740,16 @@ class UDFSignal(UDFStep):
  class RowGenerator(UDFStep):
  """Extend dataset with new rows."""

+ udf: "UDFAdapter"
+ catalog: "Catalog"
+ partition_by: PartitionByType | None = None
  is_generator = True
+ # Parameters from Settings
+ cache: bool = False
+ parallel: int | None = None
+ workers: bool | int = False
+ min_task_size: int | None = None
+ batch_size: int | None = None

  def create_udf_table(self, query: Select) -> "Table":
  warehouse = self.catalog.warehouse
@@ -721,18 +796,42 @@ class SQLClause(Step, ABC):

  def parse_cols(
  self,
- cols: Sequence[Union[Function, ColumnElement]],
+ cols: Sequence[Function | ColumnElement],
  ) -> tuple[ColumnElement, ...]:
  return tuple(c.get_column() if isinstance(c, Function) else c for c in cols)

  @abstractmethod
- def apply_sql_clause(self, query):
+ def apply_sql_clause(self, query: Any) -> Any:
  pass


+ @frozen
+ class RegenerateSystemColumns(Step):
+ catalog: "Catalog"
+
+ def hash_inputs(self) -> str:
+ return hashlib.sha256(b"regenerate_system_columns").hexdigest()
+
+ def apply(
+ self, query_generator: QueryGenerator, temp_tables: list[str]
+ ) -> StepResult:
+ query = query_generator.select()
+ new_query = self.catalog.warehouse._regenerate_system_columns(
+ query, keep_existing_columns=True
+ )
+
+ def q(*columns):
+ return new_query.with_only_columns(*columns)
+
+ return step_result(q, new_query.selected_columns)
+
+
  @frozen
  class SQLSelect(SQLClause):
- args: tuple[Union[Function, ColumnElement], ...]
+ args: tuple[Function | ColumnElement, ...]
+
+ def hash_inputs(self) -> str:
+ return hash_column_elements(self.args)

  def apply_sql_clause(self, query) -> Select:
  subquery = query.subquery()
@@ -748,7 +847,10 @@ class SQLSelect(SQLClause):

  @frozen
  class SQLSelectExcept(SQLClause):
- args: tuple[Union[Function, ColumnElement], ...]
+ args: tuple[Function | ColumnElement, ...]
+
+ def hash_inputs(self) -> str:
+ return hash_column_elements(self.args)

  def apply_sql_clause(self, query: Select) -> Select:
  subquery = query.subquery()
@@ -758,33 +860,43 @@ class SQLSelectExcept(SQLClause):

  @frozen
  class SQLMutate(SQLClause):
- args: tuple[Union[Function, ColumnElement], ...]
+ args: tuple[Label, ...]
+ new_schema: SignalSchema
+
+ def hash_inputs(self) -> str:
+ return hash_column_elements(self.args)

  def apply_sql_clause(self, query: Select) -> Select:
  original_subquery = query.subquery()
- args = [
- original_subquery.c[str(c)] if isinstance(c, (str, C)) else c
- for c in self.parse_cols(self.args)
- ]
- to_mutate = {c.name for c in args}
+ to_mutate = {c.name for c in self.args}

- prefix = f"mutate{token_hex(8)}_"
- cols = [
- c.label(prefix + c.name) if c.name in to_mutate else c
+ # Drop the original versions to avoid name collisions, exclude renamed
+ # columns. Always keep system columns (sys__*) if they exist in original query
+ new_schema_columns = set(self.new_schema.db_signals())
+ base_cols = [
+ c
  for c in original_subquery.c
+ if c.name not in to_mutate
+ and (c.name in new_schema_columns or c.name.startswith("sys__"))
  ]
- # this is needed for new column to be used in clauses
- # like ORDER BY, otherwise new column is not recognized
- subquery = (
- sqlalchemy.select(*cols, *args).select_from(original_subquery).subquery()
+
+ # Create intermediate subquery to properly handle window functions
+ intermediate_query = sqlalchemy.select(*base_cols, *self.args).select_from(
+ original_subquery
  )
+ intermediate_subquery = intermediate_query.subquery()

- return sqlalchemy.select(*subquery.c).select_from(subquery)
+ return sqlalchemy.select(*intermediate_subquery.c).select_from(
+ intermediate_subquery
+ )


  @frozen
  class SQLFilter(SQLClause):
- expressions: tuple[Union[Function, ColumnElement], ...]
+ expressions: tuple[Function | ColumnElement, ...]
+
+ def hash_inputs(self) -> str:
+ return hash_column_elements(self.expressions)

  def __and__(self, other):
  expressions = self.parse_cols(self.expressions)
@@ -797,7 +909,10 @@ class SQLFilter(SQLClause):

  @frozen
  class SQLOrderBy(SQLClause):
- args: tuple[Union[Function, ColumnElement], ...]
+ args: tuple[Function | ColumnElement, ...]
+
+ def hash_inputs(self) -> str:
+ return hash_column_elements(self.args)

  def apply_sql_clause(self, query: Select) -> Select:
  args = self.parse_cols(self.args)
@@ -808,6 +923,9 @@ class SQLOrderBy(SQLClause):
  class SQLLimit(SQLClause):
  n: int

+ def hash_inputs(self) -> str:
+ return hashlib.sha256(str(self.n).encode()).hexdigest()
+
  def apply_sql_clause(self, query: Select) -> Select:
  return query.limit(self.n)

@@ -816,12 +934,18 @@ class SQLLimit(SQLClause):
  class SQLOffset(SQLClause):
  offset: int

+ def hash_inputs(self) -> str:
+ return hashlib.sha256(str(self.offset).encode()).hexdigest()
+
  def apply_sql_clause(self, query: "GenerativeSelect"):
  return query.offset(self.offset)


  @frozen
  class SQLCount(SQLClause):
+ def hash_inputs(self) -> str:
+ return ""
+
  def apply_sql_clause(self, query):
  return sqlalchemy.select(f.count(1)).select_from(query.subquery())

@@ -831,6 +955,9 @@ class SQLDistinct(SQLClause):
  args: tuple[ColumnElement, ...]
  dialect: str

+ def hash_inputs(self) -> str:
+ return hash_column_elements(self.args)
+
  def apply_sql_clause(self, query):
  if self.dialect == "sqlite":
  return query.group_by(*self.args)
@@ -843,24 +970,34 @@ class SQLUnion(Step):
  query1: "DatasetQuery"
  query2: "DatasetQuery"

+ def hash_inputs(self) -> str:
+ return hashlib.sha256(
+ bytes.fromhex(self.query1.hash()) + bytes.fromhex(self.query2.hash())
+ ).hexdigest()
+
  def apply(
  self, query_generator: QueryGenerator, temp_tables: list[str]
  ) -> StepResult:
+ left_before = len(self.query1.temp_table_names)
  q1 = self.query1.apply_steps().select().subquery()
- temp_tables.extend(self.query1.temp_table_names)
+ temp_tables.extend(self.query1.temp_table_names[left_before:])
+ right_before = len(self.query2.temp_table_names)
  q2 = self.query2.apply_steps().select().subquery()
- temp_tables.extend(self.query2.temp_table_names)
+ temp_tables.extend(self.query2.temp_table_names[right_before:])

- columns1, columns2 = _order_columns(q1.columns, q2.columns)
+ columns1 = _drop_system_columns(q1.columns)
+ columns2 = _drop_system_columns(q2.columns)
+ columns1, columns2 = _order_columns(columns1, columns2)

  def q(*columns):
- names = {c.name for c in columns}
- col1 = [c for c in columns1 if c.name in names]
- col2 = [c for c in columns2 if c.name in names]
- res = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))
+ selected_names = [c.name for c in columns]
+ col1 = [c for c in columns1 if c.name in selected_names]
+ col2 = [c for c in columns2 if c.name in selected_names]
+ union_query = sqlalchemy.select(*col1).union_all(sqlalchemy.select(*col2))

- subquery = res.subquery()
- return sqlalchemy.select(*subquery.c).select_from(subquery)
+ union_cte = union_query.cte()
+ select_cols = [union_cte.c[name] for name in selected_names]
+ return sqlalchemy.select(*select_cols)

  return step_result(
  q,
@@ -874,14 +1011,42 @@ class SQLJoin(Step):
  catalog: "Catalog"
  query1: "DatasetQuery"
  query2: "DatasetQuery"
- predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
+ predicates: JoinPredicateType | tuple[JoinPredicateType, ...]
  inner: bool
  full: bool
  rname: str

+ @staticmethod
+ def _split_db_name(name: str) -> tuple[str, str]:
+ if DEFAULT_DELIMITER in name:
+ head, tail = name.split(DEFAULT_DELIMITER, 1)
+ return head, tail
+ return name, ""
+
+ @classmethod
+ def _root_name(cls, name: str) -> str:
+ return cls._split_db_name(name)[0]
+
+ def hash_inputs(self) -> str:
+ predicates = (
+ ensure_sequence(self.predicates) if self.predicates is not None else []
+ )
+
+ parts = [
+ bytes.fromhex(self.query1.hash()),
+ bytes.fromhex(self.query2.hash()),
+ bytes.fromhex(hash_column_elements(predicates)),
+ str(self.inner).encode(),
+ str(self.full).encode(),
+ self.rname.encode("utf-8"),
+ ]
+
+ return hashlib.sha256(b"".join(parts)).hexdigest()
+
  def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
+ temp_tables_before = len(dq.temp_table_names)
  query = dq.apply_steps().select()
- temp_tables.extend(dq.temp_table_names)
+ temp_tables.extend(dq.temp_table_names[temp_tables_before:])

  if not any(isinstance(step, (SQLJoin, SQLUnion)) for step in dq.steps):
  return query.subquery(dq.table.name)
@@ -937,22 +1102,39 @@ class SQLJoin(Step):
  q1 = self.get_query(self.query1, temp_tables)
  q2 = self.get_query(self.query2, temp_tables)

- q1_columns = list(q1.c)
- q1_column_names = {c.name for c in q1_columns}
-
- q2_columns = []
- for c in q2.c:
- if c.name.startswith("sys__"):
+ q1_columns = _drop_system_columns(q1.c)
+ existing_column_names = {c.name for c in q1_columns}
+ right_columns: list[KeyedColumnElement[Any]] = []
+ right_column_names: list[str] = []
+ for column in q2.c:
+ if column.name.startswith("sys__"):
  continue
+ right_columns.append(column)
+ right_column_names.append(column.name)
+
+ root_mapping = generate_merge_root_mapping(
+ existing_column_names,
+ right_column_names,
+ extract_root=self._root_name,
+ prefix=self.rname,
+ )
+
+ q2_columns: list[KeyedColumnElement[Any]] = []
+ for column in right_columns:
+ original_name = column.name
+ column_root, column_tail = self._split_db_name(original_name)
+ mapped_root = root_mapping[column_root]
+
+ new_name = (
+ mapped_root
+ if not column_tail
+ else DEFAULT_DELIMITER.join([mapped_root, column_tail])
+ )
+
+ if new_name != original_name:
+ column = column.label(new_name)

- if c.name in q1_column_names:
- new_name = self.rname.format(name=c.name)
- new_name_idx = 0
- while new_name in q1_column_names:
- new_name_idx += 1
- new_name = self.rname.format(name=f"{c.name}_{new_name_idx}")
- c = c.label(new_name)
- q2_columns.append(c)
+ q2_columns.append(column)

  res_columns = q1_columns + q2_columns
  predicates = (
@@ -997,8 +1179,15 @@ class SQLJoin(Step):

  @frozen
  class SQLGroupBy(SQLClause):
- cols: Sequence[Union[str, Function, ColumnElement]]
- group_by: Sequence[Union[str, Function, ColumnElement]]
+ cols: Sequence[str | Function | ColumnElement]
+ group_by: Sequence[str | Function | ColumnElement]
+
+ def hash_inputs(self) -> str:
+ return hashlib.sha256(
+ bytes.fromhex(
+ hash_column_elements(self.cols) + hash_column_elements(self.group_by)
+ )
+ ).hexdigest()

  def apply_sql_clause(self, query) -> Select:
  if not self.cols:
@@ -1010,58 +1199,70 @@ class SQLGroupBy(SQLClause):
  c.get_column() if isinstance(c, Function) else c for c in self.group_by
  ]

- cols = [
- c.get_column()
- if isinstance(c, Function)
- else subquery.c[str(c)]
- if isinstance(c, (str, C))
- else c
- for c in (*group_by, *self.cols)
- ]
+ cols_dict: dict[str, Any] = {}
+ for c in (*group_by, *self.cols):
+ if isinstance(c, Function):
+ key = c.name
+ value = c.get_column()
+ elif isinstance(c, (str, C)):
+ key = str(c)
+ value = subquery.c[str(c)]
+ else:
+ key = c.name
+ value = c # type: ignore[assignment]
+ cols_dict[key] = value

- return sqlalchemy.select(*cols).select_from(subquery).group_by(*group_by)
+ unique_cols = cols_dict.values()

+ return sqlalchemy.select(*unique_cols).select_from(subquery).group_by(*group_by)

- def _validate_columns(
- left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
- ) -> set[str]:
- left_names = {c.name for c in left_columns}
- right_names = {c.name for c in right_columns}
-
- if left_names == right_names:
- return left_names
-
- missing_right = left_names - right_names
- missing_left = right_names - left_names
-
- def _prepare_msg_part(missing_columns: set[str], side: str) -> str:
- return f"{', '.join(sorted(missing_columns))} only present in {side}"
-
- msg_parts = [
- _prepare_msg_part(missing_columns, found_side)
- for missing_columns, found_side in zip(
- [
- missing_right,
- missing_left,
- ],
- ["left", "right"],
- )
- if missing_columns
- ]
- msg = f"Cannot perform union. {'. '.join(msg_parts)}"

- raise ValueError(msg)
+ class UnionSchemaMismatchError(ValueError):
+ """Union input columns mismatch."""
+
+ @classmethod
+ def from_column_sets(
+ cls,
+ missing_left: set[str],
+ missing_right: set[str],
+ ) -> "UnionSchemaMismatchError":
+ def _describe(cols: set[str], side: str) -> str:
+ return f"{', '.join(sorted(cols))} only present in {side}"
+
+ parts = []
+ if missing_left:
+ parts.append(_describe(missing_left, "left"))
+ if missing_right:
+ parts.append(_describe(missing_right, "right"))
+
+ return cls(f"Cannot perform union. {'. '.join(parts)}")


  def _order_columns(
  left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
  ) -> list[list[ColumnElement]]:
- column_order = _validate_columns(left_columns, right_columns)
+ left_names = [c.name for c in left_columns]
+ right_names = [c.name for c in right_columns]
+
+ # validate
+ if sorted(left_names) != sorted(right_names):
+ left_names_set = set(left_names)
+ right_names_set = set(right_names)
+ raise UnionSchemaMismatchError.from_column_sets(
+ left_names_set - right_names_set,
+ right_names_set - left_names_set,
+ )
+
+ # Order columns to match left_names order
  column_dicts = [
  {c.name: c for c in columns} for columns in [left_columns, right_columns]
  ]

- return [[d[n] for n in column_order] for d in column_dicts]
+ return [[d[n] for n in left_names] for d in column_dicts]
+
+
+ def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
+ return [c for c in columns if not c.name.startswith("sys__")]


  @attrs.define
@@ -1077,62 +1278,71 @@ class DatasetQuery:
  def __init__(
  self,
  name: str,
- version: Optional[int] = None,
- catalog: Optional["Catalog"] = None,
- session: Optional[Session] = None,
- indexing_column_types: Optional[dict[str, Any]] = None,
+ version: str | None = None,
+ project_name: str | None = None,
+ namespace_name: str | None = None,
+ catalog: "Catalog | None" = None,
+ session: Session | None = None,
  in_memory: bool = False,
- fallback_to_studio: bool = True,
  update: bool = False,
  ) -> None:
- from datachain.remote.studio import is_token_set
-
  self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
  self.catalog = catalog or self.session.catalog
  self.steps: list[Step] = []
- self._chunk_index: Optional[int] = None
- self._chunk_total: Optional[int] = None
+ self._chunk_index: int | None = None
+ self._chunk_total: int | None = None
  self.temp_table_names: list[str] = []
  self.dependencies: set[DatasetDependencyType] = set()
  self.table = self.get_table()
- self.starting_step: Optional[QueryStep] = None
- self.name: Optional[str] = None
- self.version: Optional[int] = None
- self.feature_schema: Optional[dict] = None
- self.column_types: Optional[dict[str, Any]] = None
+ self.starting_step: QueryStep | None = None
+ self.name: str | None = None
+ self.version: str | None = None
+ self.feature_schema: dict | None = None
+ self.column_types: dict[str, Any] | None = None
  self.before_steps: list[Callable] = []
- self.listing_fn: Optional[Callable] = None
+ self.listing_fn: Callable | None = None
  self.update = update

- self.list_ds_name: Optional[str] = None
+ self.list_ds_name: str | None = None

  self.name = name
  self.dialect = self.catalog.warehouse.db.dialect
  if version:
  self.version = version

- if is_listing_dataset(name):
+ if namespace_name is None:
+ namespace_name = self.catalog.metastore.default_namespace_name
+ if project_name is None:
+ project_name = self.catalog.metastore.default_project_name
+
+ if is_listing_dataset(name) and not version:
  # not setting query step yet as listing dataset might not exist at
  # this point
  self.list_ds_name = name
- elif fallback_to_studio and is_token_set():
+ else:
  self._set_starting_step(
- self.catalog.get_dataset_with_remote_fallback(name, version)
+ self.catalog.get_dataset_with_remote_fallback(
+ name,
+ namespace_name=namespace_name,
+ project_name=project_name,
+ version=version,
+ pull_dataset=True,
+ update=update,
+ )
  )
- else:
- self._set_starting_step(self.catalog.get_dataset(name))

  def _set_starting_step(self, ds: "DatasetRecord") -> None:
  if not self.version:
  self.version = ds.latest_version

- self.starting_step = QueryStep(self.catalog, ds.name, self.version)
+ self.starting_step = QueryStep(self.catalog, ds, self.version)

  # at this point we know our starting dataset so setting up schemas
  self.feature_schema = ds.get_version(self.version).feature_schema
  self.column_types = copy(ds.schema)
  if "sys__id" in self.column_types:
  self.column_types.pop("sys__id")
+ self.project = ds.project

  def __iter__(self):
  return iter(self.db_results())
@@ -1140,39 +1350,28 @@ class DatasetQuery:
  def __or__(self, other):
  return self.union(other)

- def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
- print("Dataset not found in local catalog, trying to get from studio")
-
- remote_ds_uri = f"{DATASET_PREFIX}{name}"
- if version:
- remote_ds_uri += f"@v{version}"
+ def hash(self) -> str:
+ """
+ Calculates hash of this class taking into account hash of starting step
+ and hashes of each following steps. Ordering is important.
+ """
+ hasher = hashlib.sha256()
+ if self.starting_step:
+ hasher.update(self.starting_step.hash().encode("utf-8"))
+ else:
+ assert self.list_ds_name
+ hasher.update(self.list_ds_name.encode("utf-8"))

- self.catalog.pull_dataset(
- remote_ds_uri=remote_ds_uri,
- local_ds_name=name,
- local_ds_version=version,
- )
+ for step in self.steps:
+ hasher.update(step.hash().encode("utf-8"))

- return self.catalog.get_dataset(name)
+ return hasher.hexdigest()

  @staticmethod
  def get_table() -> "TableClause":
- table_name = "".join(
- random.choice(string.ascii_letters) # noqa: S311
- for _ in range(16)
- )
+ table_name = "".join(secrets.choice(string.ascii_letters) for _ in range(16))
  return sqlalchemy.table(table_name)

- @staticmethod
- def delete(
- name: str, version: Optional[int] = None, catalog: Optional["Catalog"] = None
- ) -> None:
- from datachain.catalog import get_catalog
-
- catalog = catalog or get_catalog()
- version = version or catalog.get_dataset(name).latest_version
- catalog.remove_dataset(name, version)
-
  @property
  def attached(self) -> bool:
  """
@@ -1180,14 +1379,14 @@ class DatasetQuery:
  it completely. If this is the case, name and version of underlying dataset
  will be defined.
  DatasetQuery instance can become attached in two scenarios:
- 1. ds = DatasetQuery(name="dogs", version=1) -> ds is attached to dogs
- 2. ds = ds.save("dogs", version=1) -> ds is attached to dogs dataset
+ 1. ds = DatasetQuery(name="dogs", version="1.0.0") -> ds is attached to dogs
+ 2. ds = ds.save("dogs", version="1.0.0") -> ds is attached to dogs dataset
  It can move to detached state if filter or similar methods are called on it,
  as then it no longer 100% represents underlying datasets.
  """
  return self.name is not None and self.version is not None

- def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
+ def c(self, column: C | str) -> "ColumnClause[Any]":
  col: sqlalchemy.ColumnClause = (
  sqlalchemy.column(column)
  if isinstance(column, str)
@@ -1200,11 +1399,8 @@ class DatasetQuery:
  """Setting listing function to be run if needed"""
  self.listing_fn = fn

- def apply_steps(self) -> QueryGenerator:
- """
- Apply the steps in the query and return the resulting
- sqlalchemy.SelectBase.
- """
+ def apply_listing_pre_step(self) -> None:
+ """Runs listing pre-step if needed"""
  if self.list_ds_name and not self.starting_step:
  listing_ds = None
  try:
@@ -1220,6 +1416,13 @@ class DatasetQuery:
  # at this point we know what is our starting listing dataset name
  self._set_starting_step(listing_ds) # type: ignore [arg-type]

+ def apply_steps(self) -> QueryGenerator:
+ """
+ Apply the steps in the query and return the resulting
+ sqlalchemy.SelectBase.
+ """
+ self.apply_listing_pre_step()
+
  query = self.clone()

  index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)
@@ -1278,6 +1481,7 @@ class DatasetQuery:
  # This is needed to always use a new connection with all metastore and warehouse
  # implementations, as errors may close or render unusable the existing
  # connections.
+ assert len(self.temp_table_names) == len(set(self.temp_table_names))
  with self.catalog.metastore.clone(use_new_connection=True) as metastore:
  metastore.cleanup_tables(self.temp_table_names)
  with self.catalog.warehouse.clone(use_new_connection=True) as warehouse:
@@ -1292,7 +1496,7 @@ class DatasetQuery:
  return list(result)

  def to_db_records(self) -> list[dict[str, Any]]:
- return self.db_results(lambda cols, row: dict(zip(cols, row)))
+ return self.db_results(lambda cols, row: dict(zip(cols, row, strict=False)))

  @contextlib.contextmanager
  def as_iterable(self, **kwargs) -> Iterator[ResultIter]:
@@ -1331,8 +1535,8 @@ class DatasetQuery:
  yield from rows

  async def get_params(row: Sequence) -> tuple:
- row_dict = RowDict(zip(query_fields, row))
- return tuple(
+ row_dict = RowDict(zip(query_fields, row, strict=False))
+ return tuple( # noqa: C409
  [
  await p.get_value_async(
  self.catalog, row_dict, mapper, **kwargs
@@ -1348,10 +1552,6 @@ class DatasetQuery:
  finally:
  self.cleanup()

- def shuffle(self) -> "Self":
- # ToDo: implement shaffle based on seed and/or generating random column
- return self.order_by(C.sys__rand)
-
  def sample(self, n) -> "Self":
  """
  Return a random sample from the dataset.
@@ -1371,6 +1571,7 @@ class DatasetQuery:
  obj.steps = obj.steps.copy()
  if new_table:
  obj.table = self.get_table()
+ obj.temp_table_names = []
  return obj

  @detach
@@ -1441,7 +1642,7 @@ class DatasetQuery:
  return query

  @detach
- def mutate(self, *args, **kwargs) -> "Self":
+ def mutate(self, *args, new_schema, **kwargs) -> "Self":
  """
  Add new columns to this query.

@@ -1453,7 +1654,7 @@ class DatasetQuery:
  """
  query_args = [v.label(k) for k, v in dict(args, **kwargs).items()]
  query = self.clone()
- query.steps.append(SQLMutate((*query_args,)))
+ query.steps.append(SQLMutate((*query_args,), new_schema))
  return query

  @detach
@@ -1551,10 +1752,10 @@ class DatasetQuery:
  def join(
  self,
  dataset_query: "DatasetQuery",
- predicates: Union[JoinPredicateType, Sequence[JoinPredicateType]],
+ predicates: JoinPredicateType | Sequence[JoinPredicateType],
  inner=False,
  full=False,
- rname="{name}_right",
+ rname="right_",
  ) -> "Self":
  left = self.clone(new_table=False)
  if self.table.name == dataset_query.table.name:
@@ -1593,11 +1794,17 @@ class DatasetQuery:
  def add_signals(
  self,
  udf: "UDFAdapter",
- parallel: Optional[int] = None,
- workers: Union[bool, int] = False,
- min_task_size: Optional[int] = None,
- partition_by: Optional[PartitionByType] = None,
+ partition_by: PartitionByType | None = None,
+ # Parameters from Settings
  cache: bool = False,
+ parallel: int | None = None,
+ workers: bool | int = False,
+ min_task_size: int | None = None,
+ batch_size: int | None = None,
+ # Parameters are unused, kept only to match the signature of Settings.to_dict
+ prefetch: int | None = None,
+ namespace: str | None = None,
+ project: str | None = None,
  ) -> "Self":
  """
  Adds one or more signals based on the results from the provided UDF.
@@ -1623,6 +1830,7 @@ class DatasetQuery:
  workers=workers,
  min_task_size=min_task_size,
  cache=cache,
+ batch_size=batch_size,
  )
  )
  return query
@@ -1637,11 +1845,17 @@ class DatasetQuery:
  def generate(
  self,
  udf: "UDFAdapter",
- parallel: Optional[int] = None,
- workers: Union[bool, int] = False,
- min_task_size: Optional[int] = None,
- partition_by: Optional[PartitionByType] = None,
+ partition_by: PartitionByType | None = None,
+ # Parameters from Settings
  cache: bool = False,
+ parallel: int | None = None,
+ workers: bool | int = False,
+ min_task_size: int | None = None,
+ batch_size: int | None = None,
+ # Parameters are unused, kept only to match the signature of Settings.to_dict:
+ prefetch: int | None = None,
+ namespace: str | None = None,
+ project: str | None = None,
  ) -> "Self":
  query = self.clone()
  steps = query.steps
@@ -1654,41 +1868,84 @@ class DatasetQuery:
  workers=workers,
  min_task_size=min_task_size,
  cache=cache,
+ batch_size=batch_size,
  )
  )
  return query

- def _add_dependencies(self, dataset: "DatasetRecord", version: int):
- for dependency in self.dependencies:
- ds_dependency_name, ds_dependency_version = dependency
+ def _add_dependencies(self, dataset: "DatasetRecord", version: str):
+ dependencies: set[DatasetDependencyType] = set()
+ for dep_dataset, dep_dataset_version in self.dependencies:
+ if Session.is_temp_dataset(dep_dataset.name):
+ # temp dataset are created for optimization and they will be removed
+ # afterwards. Therefore, we should not put them as dependencies, but
+ # their own direct dependencies
+ for dep in self.catalog.get_dataset_dependencies(
+ dep_dataset.name,
+ dep_dataset_version,
+ namespace_name=dep_dataset.project.namespace.name,
+ project_name=dep_dataset.project.name,
+ indirect=False,
+ ):
+ if dep:
+ dependencies.add(
+ (
+ self.catalog.get_dataset(
+ dep.name,
+ namespace_name=dep.namespace,
+ project_name=dep.project,
+ ),
+ dep.version,
+ )
+ )
+ else:
+ dependencies.add((dep_dataset, dep_dataset_version))
+
+ for dep_dataset, dep_dataset_version in dependencies:
  self.catalog.metastore.add_dataset_dependency(
- dataset.name,
+ dataset,
  version,
- ds_dependency_name,
- ds_dependency_version,
+ dep_dataset,
+ dep_dataset_version,
  )

  def exec(self) -> "Self":
  """Execute the query."""
+ query = self.clone()
  try:
- query = self.clone()
  query.apply_steps()
  finally:
- self.cleanup()
+ query.cleanup()
  return query

  def save(
  self,
- name: Optional[str] = None,
- version: Optional[int] = None,
- feature_schema: Optional[dict] = None,
- description: Optional[str] = None,
- labels: Optional[list[str]] = None,
+ name: str | None = None,
+ version: str | None = None,
+ project: Project | None = None,
+ feature_schema: dict | None = None,
+ dependencies: list[DatasetDependency] | None = None,
+ description: str | None = None,
+ attrs: list[str] | None = None,
+ update_version: str | None = "patch",
  **kwargs,
  ) -> "Self":
  """Save the query as a dataset."""
+ # Get job from session to link dataset version to job
+ job = self.session.get_or_create_job()
+ job_id = job.id
+
+ project = project or self.catalog.metastore.default_project
  try:
- if name and version and self.catalog.get_dataset(name).has_version(version):
+ if (
+ name
+ and version
+ and self.catalog.get_dataset(
+ name,
+ namespace_name=project.namespace.name,
+ project_name=project.name,
+ ).has_version(version)
+ ):
  raise RuntimeError(f"Dataset {name} already has version {version}")
  except DatasetNotFoundError:
  pass
@@ -1713,19 +1970,18 @@ class DatasetQuery:

  dataset = self.catalog.create_dataset(
  name,
+ project,
  version=version,
  feature_schema=feature_schema,
  columns=columns,
  description=description,
- labels=labels,
+ attrs=attrs,
+ update_version=update_version,
+ job_id=job_id,
  **kwargs,
  )
  version = version or dataset.latest_version

- self.session.add_dataset_version(
- dataset=dataset, version=version, listing=kwargs.get("listing", False)
- )
-
  dr = self.catalog.warehouse.dataset_rows(dataset)

  self.catalog.warehouse.copy_table(dr.get_table(), query.select())
@@ -1735,15 +1991,41 @@ class DatasetQuery:
  )
  self.catalog.update_dataset_version_with_warehouse_info(dataset, version)

+ # Link this dataset version to the job that created it
+ self.catalog.metastore.link_dataset_version_to_job(
+ dataset.get_version(version).id, job_id, is_creator=True
+ )
+
+ if dependencies:
+ # overriding dependencies
+ self.dependencies = set()
+ for dep in dependencies:
+ self.dependencies.add(
+ (
+ self.catalog.get_dataset(
+ dep.name,
+ namespace_name=dep.namespace,
+ project_name=dep.project,
+ ),
+ dep.version,
+ )
+ )
+
  self._add_dependencies(dataset, version) # type: ignore [arg-type]
  finally:
  self.cleanup()
- return self.__class__(name=name, version=version, catalog=self.catalog)
+ return self.__class__(
+ name=name,
+ namespace_name=project.namespace.name,
+ project_name=project.name,
+ version=version,
+ catalog=self.catalog,
+ )

  @property
  def is_ordered(self) -> bool:
  return isinstance(self.last_step, SQLOrderBy)

  @property
- def last_step(self) -> Optional[Step]:
+ def last_step(self) -> Step | None:
  return self.steps[-1] if self.steps else None
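
A recurring theme in the hunks above (which appear to come from datachain/query/dataset.py) is per-step hashing: every Step gains hash_inputs()/hash(), QueryStep hashes the dataset URI, and DatasetQuery.hash() chains the starting step's digest with each subsequent step's digest in order. The following is a minimal, self-contained sketch of that chaining pattern only; ExampleStep, ExampleQuery, and the dataset URI are hypothetical stand-ins and not datachain's actual API.

import hashlib
from dataclasses import dataclass, field


def sha256_hex(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()


@dataclass(frozen=True)
class ExampleStep:
    """Hypothetical step: digest combines the class name with the step's inputs."""

    expression: str

    def hash_inputs(self) -> str:
        return sha256_hex(self.expression.encode())

    def hash(self) -> str:
        # Same shape as Step.hash() in the diff: "<class name>|<hash of inputs>"
        return sha256_hex(f"{type(self).__name__}|{self.hash_inputs()}".encode())


@dataclass
class ExampleQuery:
    """Hypothetical query: identity = starting dataset URI + ordered step digests."""

    starting_uri: str
    steps: list[ExampleStep] = field(default_factory=list)

    def hash(self) -> str:
        hasher = hashlib.sha256()
        hasher.update(sha256_hex(self.starting_uri.encode()).encode())
        for step in self.steps:  # ordering matters, as in DatasetQuery.hash()
            hasher.update(step.hash().encode())
        return hasher.hexdigest()


if __name__ == "__main__":
    q1 = ExampleQuery("ds://animals@v1.0.0", [ExampleStep("size > 100")])
    q2 = ExampleQuery("ds://animals@v1.0.0", [ExampleStep("size > 100")])
    q3 = ExampleQuery("ds://animals@v1.0.0", [ExampleStep("size > 200")])
    assert q1.hash() == q2.hash()  # identical inputs -> identical digest
    assert q1.hash() != q3.hash()  # different step inputs -> different digest

Because equal inputs produce equal digests, a chain of steps can be matched against previously computed results; the checkpoint.py and delta.py modules added in this release appear to build on this property.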