datachain 0.16.4__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (39)
  1. datachain/catalog/catalog.py +25 -92
  2. datachain/cli/__init__.py +11 -9
  3. datachain/cli/commands/datasets.py +1 -1
  4. datachain/cli/commands/query.py +1 -0
  5. datachain/cli/commands/show.py +1 -1
  6. datachain/cli/parser/__init__.py +11 -3
  7. datachain/data_storage/job.py +1 -0
  8. datachain/data_storage/metastore.py +105 -94
  9. datachain/data_storage/sqlite.py +8 -7
  10. datachain/data_storage/warehouse.py +58 -46
  11. datachain/dataset.py +88 -45
  12. datachain/lib/arrow.py +23 -1
  13. datachain/lib/dataset_info.py +2 -1
  14. datachain/lib/dc/csv.py +1 -0
  15. datachain/lib/dc/datachain.py +38 -16
  16. datachain/lib/dc/datasets.py +28 -7
  17. datachain/lib/dc/storage.py +10 -2
  18. datachain/lib/listing.py +2 -0
  19. datachain/lib/pytorch.py +2 -2
  20. datachain/lib/udf.py +17 -5
  21. datachain/listing.py +1 -1
  22. datachain/query/batch.py +40 -39
  23. datachain/query/dataset.py +42 -41
  24. datachain/query/dispatch.py +137 -75
  25. datachain/query/metrics.py +1 -2
  26. datachain/query/queue.py +1 -11
  27. datachain/query/session.py +2 -2
  28. datachain/query/udf.py +1 -1
  29. datachain/query/utils.py +8 -14
  30. datachain/remote/studio.py +4 -4
  31. datachain/semver.py +58 -0
  32. datachain/studio.py +1 -1
  33. datachain/utils.py +3 -0
  34. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/METADATA +1 -1
  35. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/RECORD +39 -38
  36. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/WHEEL +1 -1
  37. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/entry_points.txt +0 -0
  38. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/licenses/LICENSE +0 -0
  39. {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/top_level.txt +0 -0
datachain/lib/dataset_info.py CHANGED
@@ -6,6 +6,7 @@ from uuid import uuid4
 from pydantic import Field, field_validator
 
 from datachain.dataset import (
+    DEFAULT_DATASET_VERSION,
     DatasetListRecord,
     DatasetListVersion,
     DatasetStatus,
@@ -22,7 +23,7 @@ if TYPE_CHECKING:
 class DatasetInfo(DataModel):
     name: str
     uuid: str = Field(default=str(uuid4()))
-    version: int = Field(default=1)
+    version: str = Field(default=DEFAULT_DATASET_VERSION)
     status: int = Field(default=DatasetStatus.CREATED)
     created_at: datetime = Field(default=TIME_ZERO)
     finished_at: Optional[datetime] = Field(default=None)
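
The version field switches from an integer to a semver string. A minimal sketch of the new shape of the field, assuming DEFAULT_DATASET_VERSION is a semver string such as "1.0.0" (the constant lives in datachain/dataset.py and is not shown in this diff):

```py
# Sketch only: a pydantic model mirroring the DatasetInfo change above.
# The DEFAULT_DATASET_VERSION value is an assumption for illustration.
from pydantic import BaseModel, Field

DEFAULT_DATASET_VERSION = "1.0.0"  # assumed default

class VersionedInfo(BaseModel):
    name: str
    version: str = Field(default=DEFAULT_DATASET_VERSION)

print(VersionedInfo(name="my_cats").version)  # -> "1.0.0"
```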
datachain/lib/dc/csv.py CHANGED
@@ -124,4 +124,5 @@ def read_csv(
         source=source,
         nrows=nrows,
         format=format,
+        parse_options=parse_options,
     )
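
read_csv now forwards parse_options down to the Arrow layer. A hedged usage sketch, assuming parse_options accepts a pyarrow.csv.ParseOptions instance (the diff only shows the argument being passed through):

```py
# Sketch only: forwarding custom CSV parse options; ParseOptions comes from pyarrow.
import datachain as dc
from pyarrow import csv as pa_csv

chain = dc.read_csv(
    "s3://bucket/data.csv",
    parse_options=pa_csv.ParseOptions(delimiter=";"),  # assumed accepted type
)
```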
datachain/lib/dc/datachain.py CHANGED
@@ -23,6 +23,7 @@ import sqlalchemy
 from pydantic import BaseModel
 from tqdm import tqdm
 
+from datachain import semver
 from datachain.dataset import DatasetRecord
 from datachain.func import literal
 from datachain.func.base import Function
@@ -214,7 +215,7 @@ class DataChain:
         return self._query.name
 
     @property
-    def version(self) -> Optional[int]:
+    def version(self) -> Optional[str]:
         """Version of the underlying dataset, if there is one."""
         return self._query.version
 
@@ -457,7 +458,7 @@ class DataChain:
     def save(  # type: ignore[override]
         self,
         name: str,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         description: Optional[str] = None,
         attrs: Optional[list[str]] = None,
         **kwargs,
@@ -466,11 +467,15 @@ class DataChain:
 
         Parameters:
             name : dataset name.
-            version : version of a dataset. Default - the last version that exist.
+            version : version of a dataset. If version is not specified and dataset
+                already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
             description : description of a dataset.
             attrs : attributes of a dataset. They can be without value, e.g "NLP",
                 or with a value, e.g "location=US".
         """
+        if version is not None:
+            semver.validate(version)
+
         schema = self.signals_schema.clone_without_sys_signals().serialize()
         return self._evolve(
             query=self._query.save(
@@ -1636,18 +1641,27 @@ class DataChain:
         """
         from pyarrow.dataset import CsvFileFormat, JsonFileFormat
 
-        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
+        from datachain.lib.arrow import (
+            ArrowGenerator,
+            fix_pyarrow_format,
+            infer_schema,
+            schema_to_output,
+        )
 
-        if nrows:
-            format = kwargs.get("format")
-            if format not in ["csv", "json"] and not isinstance(
-                format, (CsvFileFormat, JsonFileFormat)
-            ):
-                raise DatasetPrepareError(
-                    self.name,
-                    "error in `parse_tabular` - "
-                    "`nrows` only supported for csv and json formats.",
-                )
+        parse_options = kwargs.pop("parse_options", None)
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
+        if (
+            nrows
+            and format not in ["csv", "json"]
+            and not isinstance(format, (CsvFileFormat, JsonFileFormat))
+        ):
+            raise DatasetPrepareError(
+                self.name,
+                "error in `parse_tabular` - "
+                "`nrows` only supported for csv and json formats.",
+            )
 
         if "file" not in self.schema or not self.count():
             raise DatasetPrepareError(self.name, "no files to parse.")
@@ -1656,7 +1670,7 @@ class DataChain:
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
             try:
-                schema = infer_schema(self, **kwargs)
+                schema = infer_schema(self, **kwargs, parse_options=parse_options)
                 output, _ = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
@@ -1682,7 +1696,15 @@ class DataChain:
         # disable prefetch if nrows is set
        settings = {"prefetch": 0} if nrows else {}
         return self.settings(**settings).gen(  # type: ignore[arg-type]
-            ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
+            ArrowGenerator(
+                schema,
+                model,
+                source,
+                nrows,
+                parse_options=parse_options,
+                **kwargs,
+            ),
+            output=output,
         )
 
     @classmethod
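
The save() docstring above describes the new semver behavior: an explicit version must pass semver.validate, and omitting the version patch-bumps an existing dataset. A short usage sketch based on that docstring (the version numbers are illustrative):

```py
# Sketch only: saving with explicit and implicit semver versions.
import datachain as dc

chain = dc.read_values(num=[1, 2, 3])
chain.save("nums", version="1.2.1")  # explicit semver, checked by semver.validate
chain.save("nums")                   # no version: patch bump, e.g. 1.2.1 -> 1.2.2
chain.save("nums", version="1.2")    # presumably rejected: not a full MAJOR.MINOR.PATCH string
```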
datachain/lib/dc/datasets.py CHANGED
@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING, Optional, get_origin, get_type_hints
+from typing import TYPE_CHECKING, Optional, Union, get_origin, get_type_hints
 
+from datachain.error import DatasetVersionNotFoundError
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import (
     File,
@@ -22,7 +23,7 @@ if TYPE_CHECKING:
 
 def read_dataset(
     name: str,
-    version: Optional[int] = None,
+    version: Optional[Union[str, int]] = None,
     session: Optional[Session] = None,
     settings: Optional[dict] = None,
     fallback_to_studio: bool = True,
@@ -49,7 +50,7 @@ def read_dataset(
     ```
 
     ```py
-    chain = dc.read_dataset("my_cats", version=1)
+    chain = dc.read_dataset("my_cats", version="1.0.0")
     ```
 
     ```py
@@ -63,7 +64,7 @@ def read_dataset(
     }
     chain = dc.read_dataset(
         name="my_cats",
-        version=1,
+        version="1.0.0",
         session=session,
         settings=settings,
         fallback_to_studio=True,
@@ -74,9 +75,29 @@ def read_dataset(
 
     from .datachain import DataChain
 
+    if version is not None:
+        try:
+            # for backward compatibility we still allow users to put version as integer
+            # in which case we are trying to find latest version where major part is
+            # equal to that input version. For example if user sets version=2, we could
+            # continue with something like 2.4.3 (assuming 2.4.3 is the biggest among
+            # all 2.* dataset versions). If dataset doesn't have any versions where
+            # major part is equal to that input, exception is thrown.
+            major = int(version)
+            dataset = Session.get(session).catalog.get_dataset(name)
+            latest_major = dataset.latest_major_version(major)
+            if not latest_major:
+                raise DatasetVersionNotFoundError(
+                    f"Dataset {name} does not have version {version}"
+                )
+            version = latest_major
+        except ValueError:
+            # version is in new semver string format, continuing as normal
+            pass
+
     query = DatasetQuery(
         name=name,
-        version=version,
+        version=version,  # type: ignore[arg-type]
         session=session,
         indexing_column_types=File._datachain_column_types,
         fallback_to_studio=fallback_to_studio,
@@ -179,7 +200,7 @@ def datasets(
 
 def delete_dataset(
     name: str,
-    version: Optional[int] = None,
+    version: Optional[str] = None,
     force: Optional[bool] = False,
     studio: Optional[bool] = False,
     session: Optional[Session] = None,
@@ -207,7 +228,7 @@ def delete_dataset(
 
     ```py
     import datachain as dc
-    dc.delete_dataset("cats", version=1)
+    dc.delete_dataset("cats", version="1.0.0")
     ```
     """
 
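The inline comment in read_dataset above spells out the backward-compatibility rule for integer versions. A small sketch of the resolution it describes, using a hypothetical list of existing versions (latest_major_version itself is defined on the dataset record in datachain/dataset.py, not shown here):

```py
# Sketch only: resolving an integer version to the latest matching semver version.
# `existing` is hypothetical data standing in for a dataset's version list.
from typing import Optional

existing = ["1.0.0", "2.1.0", "2.4.3"]

def latest_major_version(major: int) -> Optional[str]:
    candidates = [v for v in existing if int(v.split(".")[0]) == major]
    if not candidates:
        return None
    return max(candidates, key=lambda v: tuple(map(int, v.split("."))))

print(latest_major_version(2))  # -> "2.4.3"
print(latest_major_version(3))  # -> None; read_dataset raises DatasetVersionNotFoundError
```

So dc.read_dataset("my_cats", version=2) keeps working and resolves to the newest 2.x.y version.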
datachain/lib/dc/storage.py CHANGED
@@ -5,6 +5,7 @@ from typing import (
     Union,
 )
 
+from datachain.error import DatasetNotFoundError
 from datachain.lib.file import (
     FileType,
     get_file_type,
@@ -97,7 +98,8 @@ def read_storage(
     if anon:
         client_config = (client_config or {}) | {"anon": True}
     session = Session.get(session, client_config=client_config, in_memory=in_memory)
-    cache = session.catalog.cache
+    catalog = session.catalog
+    cache = catalog.cache
     client_config = session.catalog.client_config
 
     uris = uri if isinstance(uri, (list, tuple)) else [uri]
@@ -130,6 +132,11 @@ def read_storage(
 
     def lst_fn(ds_name, lst_uri):
         # disable prefetch for listing, as it pre-downloads all files
+        try:
+            version = catalog.get_dataset(ds_name).next_version_major
+        except DatasetNotFoundError:
+            version = None
+
         (
             read_records(
                 DataChain.DEFAULT_FILE_RECORD,
@@ -142,7 +149,8 @@ def read_storage(
                 list_bucket(lst_uri, cache, client_config=client_config),
                 output={f"{column}": file_type},
             )
-            .save(ds_name, listing=True)
+            # for internal listing datasets, we always bump major version
+            .save(ds_name, listing=True, version=version)
         )
 
     dc._query.set_listing_fn(
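
Per the comment above, internal listing datasets always bump the major version via next_version_major. A hedged illustration of that bump, assuming next_version_major returns the next major with minor and patch reset (the property is defined on DatasetRecord and is not part of this diff):

```py
# Sketch only: assumed semantics of next_version_major for a listing dataset.
from typing import Optional

def next_version_major(latest: Optional[str]) -> Optional[str]:
    if latest is None:  # dataset does not exist yet (the DatasetNotFoundError path above)
        return None     # save() then falls back to its default first version
    major = int(latest.split(".")[0])
    return f"{major + 1}.0.0"

print(next_version_major("2.4.3"))  # -> "3.0.0"
print(next_version_major(None))     # -> None
```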
datachain/lib/listing.py CHANGED
@@ -56,6 +56,8 @@ def list_bucket(uri: str, cache, client_config=None) -> Callable:
         for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
             yield from entries
 
+    list_func.__name__ = "read_storage"
+
     return list_func
 
 
datachain/lib/pytorch.py CHANGED
@@ -43,7 +43,7 @@ class PytorchDataset(IterableDataset):
     def __init__(
         self,
         name: str,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         catalog: Optional["Catalog"] = None,
         transform: Optional["Transform"] = None,
         tokenizer: Optional[Callable] = None,
@@ -60,7 +60,7 @@ class PytorchDataset(IterableDataset):
 
         Args:
             name (str): Name of DataChain dataset to stream.
-            version (int): Version of DataChain dataset to stream.
+            version (str): Version of DataChain dataset to stream.
             catalog (Catalog): DataChain catalog to which dataset belongs.
             transform (Transform): Torchvision transforms to apply to the dataset.
             tokenizer (Callable): Tokenizer to use to tokenize text values.
datachain/lib/udf.py CHANGED
@@ -218,6 +218,18 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
+    @property
+    def verbose_name(self):
+        """Returns the name of the function or class that implements the UDF."""
+        if self._func and callable(self._func):
+            if hasattr(self._func, "__name__"):
+                return self._func.__name__
+            if hasattr(self._func, "__class__") and hasattr(
+                self._func.__class__, "__name__"
+            ):
+                return self._func.__class__.__name__
+        return "<unknown>"
+
     @property
     def signal_names(self) -> Iterable[str]:
         return self.output.to_udf_spec().keys()
@@ -411,13 +423,13 @@ class BatchMapper(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-            n_rows = len(batch.rows)
+            n_rows = len(batch)
             row_ids, *udf_args = zip(
                 *[
                     self._prepare_row_and_id(
                         row, udf_fields, catalog, cache, download_cb
                     )
-                    for row in batch.rows
+                    for row in batch
                 ]
             )
             result_objs = list(self.process_safe(udf_args))
@@ -489,7 +501,7 @@ class Aggregator(UDFBase):
 
     def run(
         self,
-        udf_fields: "Sequence[str]",
+        udf_fields: Sequence[str],
         udf_inputs: Iterable[RowsOutputBatch],
         catalog: "Catalog",
         cache: bool,
@@ -502,13 +514,13 @@ class Aggregator(UDFBase):
             udf_args = zip(
                 *[
                     self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch.rows
+                    for row in batch
                 ]
             )
             result_objs = self.process_safe(udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(len(batch.rows))
+            processed_cb.relative_update(len(batch))
             yield output
 
         self.teardown()
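
The new verbose_name property above falls back from a function's __name__ to its class name. A standalone sketch of that lookup order, applied to a plain function and to a callable object (UDFBase itself is not reproduced here):

```py
# Sketch only: the fallback order used by verbose_name.
def describe(func) -> str:
    if func and callable(func):
        if hasattr(func, "__name__"):
            return func.__name__
        return func.__class__.__name__
    return "<unknown>"

def embed(row): ...

class Embedder:
    def __call__(self, row): ...

print(describe(embed))       # -> "embed"
print(describe(Embedder()))  # -> "Embedder"
# list_bucket() in datachain/lib/listing.py now sets list_func.__name__ = "read_storage",
# so listing UDFs report "read_storage" through this property.
```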
datachain/listing.py CHANGED
@@ -26,7 +26,7 @@ class Listing:
         warehouse: "AbstractWarehouse",
         client: "Client",
         dataset_name: Optional["str"] = None,
-        dataset_version: Optional[int] = None,
+        dataset_version: Optional[str] = None,
         column: str = "file",
     ):
         self.metastore = metastore
datachain/query/batch.py CHANGED
@@ -2,22 +2,14 @@ import contextlib
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Sequence
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
-
-from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
-from datachain.query.utils import get_query_column, get_query_id_column
-
-if TYPE_CHECKING:
-    from sqlalchemy import Select
+from typing import Callable, Optional, Union
 
+import sqlalchemy as sa
 
-@dataclass
-class RowsOutputBatch:
-    rows: Sequence[Sequence]
-
+from datachain.data_storage.schema import PARTITION_COLUMN_ID
+from datachain.query.utils import get_query_column
 
+RowsOutputBatch = Sequence[Sequence]
 RowsOutput = Union[Sequence, RowsOutputBatch]
 
 
@@ -30,8 +22,8 @@ class BatchingStrategy(ABC):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
 
@@ -47,12 +39,16 @@ class NoBatching(BatchingStrategy):
     def __call__(
        self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[Sequence, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
-        return execute(query)
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
+
+        rows = execute(query)
+        yield from (r[0] for r in rows) if ids_only else rows
 
 
 class Batch(BatchingStrategy):
@@ -69,27 +65,31 @@ class Batch(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
+        from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
 
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
         # select rows in batches
-        results: list[Sequence] = []
+        results = []
 
-        with contextlib.closing(execute(query, page_size=page_size)) as rows:
-            for row in rows:
+        with contextlib.closing(execute(query, page_size=page_size)) as batch_rows:
+            for row in batch_rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield RowsOutputBatch(batch)
+                    yield [r[0] for r in batch] if ids_only else batch
 
         if len(results) > 0:
-            yield RowsOutputBatch(results)
+            yield [r[0] for r in results] if ids_only else results
 
 
 class Partition(BatchingStrategy):
@@ -104,18 +104,19 @@ class Partition(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        id_col = get_query_id_column(query)
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
         if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")
 
-        if ids_only:
+        ids_only = False
+        if id_col is not None:
             query = query.with_only_columns(id_col, partition_col)
+            ids_only = True
 
         current_partition: Optional[int] = None
-        batch: list[Sequence] = []
+        batch: list = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")
@@ -132,9 +133,9 @@ class Partition(BatchingStrategy):
             if current_partition != partition:
                 current_partition = partition
                 if len(batch) > 0:
-                    yield RowsOutputBatch(batch)
+                    yield batch
                     batch = []
-            batch.append([row[id_column_idx]] if ids_only else row)
+            batch.append(row[id_column_idx] if ids_only else row)
 
         if len(batch) > 0:
-            yield RowsOutputBatch(batch)
+            yield batch
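
RowsOutputBatch is now a plain Sequence alias, and strategies take an optional id_col instead of an ids_only flag; when id_col is given, batches carry bare ids rather than one-element rows. A minimal sketch of the Batch chunking behavior under that contract (the row shape is a stand-in):

```py
# Sketch only: fixed-size batching with optional id projection, mirroring Batch.__call__.
from collections.abc import Iterator, Sequence

def batch_rows(rows: Sequence[Sequence], count: int, ids_only: bool) -> Iterator[list]:
    results: list = []
    for row in rows:
        results.append(row)
        if len(results) >= count:
            chunk, results = results[:count], results[count:]
            yield [r[0] for r in chunk] if ids_only else chunk
    if results:
        yield [r[0] for r in results] if ids_only else results

rows = [(1, "a"), (2, "b"), (3, "c")]
print(list(batch_rows(rows, 2, ids_only=True)))   # -> [[1, 2], [3]]
print(list(batch_rows(rows, 2, ids_only=False)))  # -> [[(1, 'a'), (2, 'b')], [(3, 'c')]]
```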
datachain/query/dataset.py CHANGED
@@ -42,15 +42,9 @@ from datachain.data_storage.schema import (
     partition_columns,
 )
 from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
-from datachain.error import (
-    DatasetNotFoundError,
-    QueryScriptCancelError,
-)
+from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.listing import (
-    is_listing_dataset,
-    listing_dataset_expired,
-)
+from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
@@ -89,7 +83,7 @@ PartitionByType = Union[
     Function, ColumnElement, Sequence[Union[Function, ColumnElement]]
 ]
 JoinPredicateType = Union[str, ColumnClause, ColumnElement]
-DatasetDependencyType = tuple[str, int]
+DatasetDependencyType = tuple[str, str]
 
 logger = logging.getLogger("datachain")
 
@@ -174,7 +168,7 @@ class Step(ABC):
 class QueryStep:
     catalog: "Catalog"
     dataset_name: str
-    dataset_version: int
+    dataset_version: str
 
     def apply(self):
         def q(*columns):
@@ -420,41 +414,30 @@ class UDFStep(Step, ABC):
     """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
-        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
-
-        rows_total = self.catalog.warehouse.query_count(query)
-        if rows_total == 0:
+        if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
            return
 
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+        from datachain.catalog.loader import (
+            DISTRIBUTED_IMPORT_PATH,
+            get_udf_distributor_class,
+        )
+
         workers = determine_workers(self.workers, rows_total=rows_total)
         processes = determine_processes(self.parallel, rows_total=rows_total)
 
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         udf_fields = [str(c.name) for c in query.selected_columns]
+        udf_distributor_class = get_udf_distributor_class()
 
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
-            try:
-                if workers:
-                    if catalog.in_memory:
-                        raise RuntimeError(
-                            "In-memory databases cannot be used with "
-                            "distributed processing."
-                        )
-
-                    from datachain.catalog.loader import (
-                        DISTRIBUTED_IMPORT_PATH,
-                        get_udf_distributor_class,
-                    )
-
-                    if not (udf_distributor_class := get_udf_distributor_class()):
-                        raise RuntimeError(
-                            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
-                            "for distributed UDF processing."
-                        )
 
+            try:
+                if udf_distributor_class and not catalog.in_memory:
+                    # Use the UDF distributor if available (running in SaaS)
                     udf_distributor = udf_distributor_class(
                         catalog=catalog,
                         table=udf_table,
@@ -470,7 +453,20 @@ class UDFStep(Step, ABC):
                         min_task_size=self.min_task_size,
                     )
                     udf_distributor()
-                elif processes:
+                    return
+
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )
+
+                    raise RuntimeError(
+                        f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+                        "for distributed UDF processing."
+                    )
+                if processes:
                     # Parallel processing (faster for more CPU-heavy UDFs)
                     if catalog.in_memory:
                         raise RuntimeError(
@@ -504,7 +500,12 @@ class UDFStep(Step, ABC):
                     with subprocess.Popen(  # noqa: S603
                         cmd, env=envs, stdin=subprocess.PIPE
                     ) as process:
-                        process.communicate(process_data)
+                        try:
+                            process.communicate(process_data)
+                        except KeyboardInterrupt:
+                            raise QueryScriptCancelError(
+                                "UDF execution was canceled by the user."
+                            ) from None
                         if retval := process.poll():
                             raise RuntimeError(
                                 f"UDF Execution Failed! Exit code: {retval}"
@@ -1091,7 +1092,7 @@ class DatasetQuery:
     def __init__(
         self,
         name: str,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
        catalog: Optional["Catalog"] = None,
         session: Optional[Session] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
@@ -1111,7 +1112,7 @@ class DatasetQuery:
         self.table = self.get_table()
         self.starting_step: Optional[QueryStep] = None
         self.name: Optional[str] = None
-        self.version: Optional[int] = None
+        self.version: Optional[str] = None
         self.feature_schema: Optional[dict] = None
         self.column_types: Optional[dict[str, Any]] = None
         self.before_steps: list[Callable] = []
@@ -1154,7 +1155,7 @@ class DatasetQuery:
     def __or__(self, other):
         return self.union(other)
 
-    def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
+    def pull_dataset(self, name: str, version: Optional[str] = None) -> "DatasetRecord":
         print("Dataset not found in local catalog, trying to get from studio")
 
         remote_ds_uri = f"{DATASET_PREFIX}{name}"
@@ -1184,8 +1185,8 @@ class DatasetQuery:
         it completely. If this is the case, name and version of underlying dataset
         will be defined.
         DatasetQuery instance can become attached in two scenarios:
-            1. ds = DatasetQuery(name="dogs", version=1) -> ds is attached to dogs
-            2. ds = ds.save("dogs", version=1) -> ds is attached to dogs dataset
+            1. ds = DatasetQuery(name="dogs", version="1.0.0") -> ds is attached to dogs
+            2. ds = ds.save("dogs", version="1.0.0") -> ds is attached to dogs dataset
         It can move to detached state if filter or similar methods are called on it,
         as then it no longer 100% represents underlying datasets.
         """
@@ -1662,7 +1663,7 @@ class DatasetQuery:
         )
         return query
 
-    def _add_dependencies(self, dataset: "DatasetRecord", version: int):
+    def _add_dependencies(self, dataset: "DatasetRecord", version: str):
         for dependency in self.dependencies:
             ds_dependency_name, ds_dependency_version = dependency
             self.catalog.metastore.add_dataset_dependency(
@@ -1684,7 +1685,7 @@ class DatasetQuery:
     def save(
         self,
         name: Optional[str] = None,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         feature_schema: Optional[dict] = None,
         description: Optional[str] = None,
         attrs: Optional[list[str]] = None,
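
The populate_udf_table changes above reorder the dispatch: a configured UDF distributor is tried first (the SaaS path), workers without a distributor is now an error, then parallel processes, then in-process execution; Ctrl-C during the subprocess path surfaces as QueryScriptCancelError. A compressed, hedged sketch of that control flow with stand-in callables:

```py
# Sketch only: the dispatch order after this change. All callables are stand-ins,
# not the actual datachain API.
def run_udf(distributor_cls, workers, processes, in_memory,
            run_distributed, run_parallel, run_local):
    if distributor_cls and not in_memory:
        run_distributed(distributor_cls)  # preferred when a distributor class is configured
        return
    if workers:
        if in_memory:
            raise RuntimeError("In-memory databases cannot be used with distributed processing.")
        raise RuntimeError("a UDF distributor import path is required for distributed UDF processing.")
    if processes:
        run_parallel(processes)  # subprocess path; KeyboardInterrupt -> QueryScriptCancelError
        return
    run_local()
```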