datachain 0.16.4__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release.
- datachain/catalog/catalog.py +25 -92
- datachain/cli/__init__.py +11 -9
- datachain/cli/commands/datasets.py +1 -1
- datachain/cli/commands/query.py +1 -0
- datachain/cli/commands/show.py +1 -1
- datachain/cli/parser/__init__.py +11 -3
- datachain/data_storage/job.py +1 -0
- datachain/data_storage/metastore.py +105 -94
- datachain/data_storage/sqlite.py +8 -7
- datachain/data_storage/warehouse.py +58 -46
- datachain/dataset.py +88 -45
- datachain/lib/arrow.py +23 -1
- datachain/lib/dataset_info.py +2 -1
- datachain/lib/dc/csv.py +1 -0
- datachain/lib/dc/datachain.py +38 -16
- datachain/lib/dc/datasets.py +28 -7
- datachain/lib/dc/storage.py +10 -2
- datachain/lib/listing.py +2 -0
- datachain/lib/pytorch.py +2 -2
- datachain/lib/udf.py +17 -5
- datachain/listing.py +1 -1
- datachain/query/batch.py +40 -39
- datachain/query/dataset.py +42 -41
- datachain/query/dispatch.py +137 -75
- datachain/query/metrics.py +1 -2
- datachain/query/queue.py +1 -11
- datachain/query/session.py +2 -2
- datachain/query/udf.py +1 -1
- datachain/query/utils.py +8 -14
- datachain/remote/studio.py +4 -4
- datachain/semver.py +58 -0
- datachain/studio.py +1 -1
- datachain/utils.py +3 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/METADATA +1 -1
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/RECORD +39 -38
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/WHEEL +1 -1
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.16.4.dist-info → datachain-0.17.0.dist-info}/top_level.txt +0 -0
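
The headline change in 0.17.0 is that dataset versions move from integers to semver strings, backed by the new `datachain/semver.py` module (+58 lines). A before/after sketch of the user-facing API, based on the docstring updates in the diffs below (`my_cats` is the example dataset name those docstrings use):

```py
import datachain as dc

# 0.16.x accepted integer versions:
# chain = dc.read_dataset("my_cats", version=1)

# 0.17.0 uses semver strings; integers are still accepted and resolve to
# the latest version with that major part (see read_dataset below)
chain = dc.read_dataset("my_cats", version="1.0.0")

# saving without a version bumps the patch part, e.g. 1.0.0 -> 1.0.1
chain.save("my_cats")
```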
datachain/lib/dataset_info.py
CHANGED
@@ -6,6 +6,7 @@ from uuid import uuid4
 from pydantic import Field, field_validator
 
 from datachain.dataset import (
+    DEFAULT_DATASET_VERSION,
     DatasetListRecord,
     DatasetListVersion,
     DatasetStatus,
@@ -22,7 +23,7 @@ if TYPE_CHECKING:
 class DatasetInfo(DataModel):
     name: str
     uuid: str = Field(default=str(uuid4()))
-    version: int = Field(default=1)
+    version: str = Field(default=DEFAULT_DATASET_VERSION)
     status: int = Field(default=DatasetStatus.CREATED)
     created_at: datetime = Field(default=TIME_ZERO)
     finished_at: Optional[datetime] = Field(default=None)
datachain/lib/dc/csv.py
CHANGED
datachain/lib/dc/datachain.py
CHANGED
@@ -23,6 +23,7 @@ import sqlalchemy
 from pydantic import BaseModel
 from tqdm import tqdm
 
+from datachain import semver
 from datachain.dataset import DatasetRecord
 from datachain.func import literal
 from datachain.func.base import Function
@@ -214,7 +215,7 @@ class DataChain:
         return self._query.name
 
     @property
-    def version(self) -> Optional[int]:
+    def version(self) -> Optional[str]:
         """Version of the underlying dataset, if there is one."""
         return self._query.version
 
@@ -457,7 +458,7 @@
     def save(  # type: ignore[override]
         self,
         name: str,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         description: Optional[str] = None,
         attrs: Optional[list[str]] = None,
         **kwargs,
@@ -466,11 +467,15 @@
 
         Parameters:
             name : dataset name.
-            version : version of a dataset.
+            version : version of a dataset. If version is not specified and dataset
+                already exists, version patch increment will happen e.g 1.2.1 -> 1.2.2.
             description : description of a dataset.
             attrs : attributes of a dataset. They can be without value, e.g "NLP",
                 or with a value, e.g "location=US".
         """
+        if version is not None:
+            semver.validate(version)
+
         schema = self.signals_schema.clone_without_sys_signals().serialize()
         return self._evolve(
             query=self._query.save(
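
The `semver.validate(version)` call above comes from the new `datachain/semver.py` module, whose contents are not shown in this diff. A minimal sketch of what such a validator plausibly looks like (the names and regex are assumptions, not the actual module):

```py
import re

# Assumed MAJOR.MINOR.PATCH format, per the "1.2.1 -> 1.2.2" docstring above
SEMVER_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)$")


def parse(version: str) -> tuple[int, int, int]:
    match = SEMVER_RE.match(version)
    if not match:
        raise ValueError(
            f"Invalid version: {version}. Expected MAJOR.MINOR.PATCH, e.g. 1.2.3"
        )
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def validate(version: str) -> None:
    parse(version)  # raises ValueError on malformed input
```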
@@ -1636,18 +1641,27 @@
         """
         from pyarrow.dataset import CsvFileFormat, JsonFileFormat
 
-        from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output
+        from datachain.lib.arrow import (
+            ArrowGenerator,
+            fix_pyarrow_format,
+            infer_schema,
+            schema_to_output,
+        )
 
-        if nrows:
-            format = kwargs.get("format")
-            if format not in ["csv", "json"] and not isinstance(
-                format, (CsvFileFormat, JsonFileFormat)
-            ):
-                raise DatasetPrepareError(
-                    self.name,
-                    "error in `parse_tabular` - "
-                    "`nrows` only supported for csv and json formats.",
-                )
+        parse_options = kwargs.pop("parse_options", None)
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
+        if (
+            nrows
+            and format not in ["csv", "json"]
+            and not isinstance(format, (CsvFileFormat, JsonFileFormat))
+        ):
+            raise DatasetPrepareError(
+                self.name,
+                "error in `parse_tabular` - "
+                "`nrows` only supported for csv and json formats.",
+            )
 
         if "file" not in self.schema or not self.count():
             raise DatasetPrepareError(self.name, "no files to parse.")
@@ -1656,7 +1670,7 @@
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
             try:
-                schema = infer_schema(self, **kwargs)
+                schema = infer_schema(self, **kwargs, parse_options=parse_options)
                 output, _ = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e
@@ -1682,7 +1696,15 @@
         # disable prefetch if nrows is set
         settings = {"prefetch": 0} if nrows else {}
         return self.settings(**settings).gen(  # type: ignore[arg-type]
-            ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
+            ArrowGenerator(
+                schema,
+                model,
+                source,
+                nrows,
+                parse_options=parse_options,
+                **kwargs,
+            ),
+            output=output,
         )
 
     @classmethod
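
Taken together, these `parse_tabular` changes thread a `parse_options` argument from the caller through both schema inference and row generation. A hedged usage sketch (the bucket path is hypothetical):

```py
import datachain as dc
from pyarrow import csv as pa_csv

# parse_options is popped from kwargs, used to fix up the pyarrow format,
# and forwarded to both infer_schema and ArrowGenerator
chain = dc.read_storage("s3://my-bucket/tables/").parse_tabular(
    format="csv",
    parse_options=pa_csv.ParseOptions(delimiter=";"),
)
```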
datachain/lib/dc/datasets.py
CHANGED
@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING, Optional, get_origin, get_type_hints
+from typing import TYPE_CHECKING, Optional, Union, get_origin, get_type_hints
 
+from datachain.error import DatasetVersionNotFoundError
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import (
     File,
@@ -22,7 +23,7 @@ if TYPE_CHECKING:
 
 def read_dataset(
     name: str,
-    version: Optional[int] = None,
+    version: Optional[Union[str, int]] = None,
     session: Optional[Session] = None,
     settings: Optional[dict] = None,
     fallback_to_studio: bool = True,
@@ -49,7 +50,7 @@
     ```
 
     ```py
-    chain = dc.read_dataset("my_cats", version=1)
+    chain = dc.read_dataset("my_cats", version="1.0.0")
     ```
 
     ```py
@@ -63,7 +64,7 @@
     }
     chain = dc.read_dataset(
         name="my_cats",
-        version=1,
+        version="1.0.0",
         session=session,
         settings=settings,
         fallback_to_studio=True,
@@ -74,9 +75,29 @@
 
     from .datachain import DataChain
 
+    if version is not None:
+        try:
+            # for backward compatibility we still allow users to put version as integer
+            # in which case we are trying to find latest version where major part is
+            # equal to that input version. For example if user sets version=2, we could
+            # continue with something like 2.4.3 (assuming 2.4.3 is the biggest among
+            # all 2.* dataset versions). If dataset doesn't have any versions where
+            # major part is equal to that input, exception is thrown.
+            major = int(version)
+            dataset = Session.get(session).catalog.get_dataset(name)
+            latest_major = dataset.latest_major_version(major)
+            if not latest_major:
+                raise DatasetVersionNotFoundError(
+                    f"Dataset {name} does not have version {version}"
+                )
+            version = latest_major
+        except ValueError:
+            # version is in new semver string format, continuing as normal
+            pass
+
     query = DatasetQuery(
         name=name,
-        version=version,
+        version=version,  # type: ignore[arg-type]
         session=session,
         indexing_column_types=File._datachain_column_types,
         fallback_to_studio=fallback_to_studio,
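
In practice, the backward-compatibility branch above behaves like this (dataset name and versions are hypothetical):

```py
import datachain as dc

# exact semver string: used as-is
chain = dc.read_dataset("my_cats", version="2.4.3")

# integer: resolved via latest_major_version(2) to the newest 2.* version,
# e.g. 2.4.3; raises DatasetVersionNotFoundError if no 2.* version exists
chain = dc.read_dataset("my_cats", version=2)
```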
@@ -179,7 +200,7 @@ def datasets(
 
 def delete_dataset(
     name: str,
-    version: Optional[int] = None,
+    version: Optional[str] = None,
     force: Optional[bool] = False,
     studio: Optional[bool] = False,
     session: Optional[Session] = None,
@@ -207,7 +228,7 @@
 
     ```py
     import datachain as dc
-    dc.delete_dataset("cats", version=1)
+    dc.delete_dataset("cats", version="1.0.0")
     ```
     """
 
datachain/lib/dc/storage.py
CHANGED
@@ -5,6 +5,7 @@ from typing import (
     Union,
 )
 
+from datachain.error import DatasetNotFoundError
 from datachain.lib.file import (
     FileType,
     get_file_type,
@@ -97,7 +98,8 @@ def read_storage(
     if anon:
         client_config = (client_config or {}) | {"anon": True}
     session = Session.get(session, client_config=client_config, in_memory=in_memory)
-    cache = session.catalog.cache
+    catalog = session.catalog
+    cache = catalog.cache
     client_config = session.catalog.client_config
 
     uris = uri if isinstance(uri, (list, tuple)) else [uri]
@@ -130,6 +132,11 @@
 
     def lst_fn(ds_name, lst_uri):
         # disable prefetch for listing, as it pre-downloads all files
+        try:
+            version = catalog.get_dataset(ds_name).next_version_major
+        except DatasetNotFoundError:
+            version = None
+
         (
             read_records(
                 DataChain.DEFAULT_FILE_RECORD,
@@ -142,7 +149,8 @@
             list_bucket(lst_uri, cache, client_config=client_config),
             output={f"{column}": file_type},
         )
-        .save(ds_name, listing=True)
+        # for internal listing datasets, we always bump major version
+        .save(ds_name, listing=True, version=version)
     )
 
     dc._query.set_listing_fn(
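
`next_version_major` is a property of the dataset record that is not shown in this diff; judging by the comment above, it yields the version that bumps the major part of the latest existing version. A sketch of the assumed semantics:

```py
# Assumed behavior of next_version_major: if the latest version of a
# listing dataset is "2.4.3", the re-listing is saved as "3.0.0".
def next_version_major(latest: str) -> str:
    major, _minor, _patch = (int(part) for part in latest.split("."))
    return f"{major + 1}.0.0"


assert next_version_major("2.4.3") == "3.0.0"
```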
datachain/lib/listing.py
CHANGED
datachain/lib/pytorch.py
CHANGED
@@ -43,7 +43,7 @@ class PytorchDataset(IterableDataset):
     def __init__(
         self,
         name: str,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         catalog: Optional["Catalog"] = None,
         transform: Optional["Transform"] = None,
         tokenizer: Optional[Callable] = None,
@@ -60,7 +60,7 @@ class PytorchDataset(IterableDataset):
 
         Args:
             name (str): Name of DataChain dataset to stream.
-            version (int): Version of DataChain dataset to stream.
+            version (str): Version of DataChain dataset to stream.
             catalog (Catalog): DataChain catalog to which dataset belongs.
             transform (Transform): Torchvision transforms to apply to the dataset.
             tokenizer (Callable): Tokenizer to use to tokenize text values.
datachain/lib/udf.py
CHANGED
@@ -218,6 +218,18 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
+    @property
+    def verbose_name(self):
+        """Returns the name of the function or class that implements the UDF."""
+        if self._func and callable(self._func):
+            if hasattr(self._func, "__name__"):
+                return self._func.__name__
+            if hasattr(self._func, "__class__") and hasattr(
+                self._func.__class__, "__name__"
+            ):
+                return self._func.__class__.__name__
+        return "<unknown>"
+
     @property
     def signal_names(self) -> Iterable[str]:
         return self.output.to_udf_spec().keys()
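
A standalone illustration of the resolution order in `verbose_name`: plain functions carry `__name__`, while callable instances fall back to their class name (the `double`/`Doubler` names here are illustrative only):

```py
def double(x):
    return x * 2


class Doubler:
    def __call__(self, x):
        return x * 2


print(double.__name__)               # "double": functions have __name__
print(Doubler().__class__.__name__)  # "Doubler": instances fall back to the class
```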
@@ -411,13 +423,13 @@ class BatchMapper(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-            n_rows = len(batch.rows)
+            n_rows = len(batch)
             row_ids, *udf_args = zip(
                 *[
                     self._prepare_row_and_id(
                         row, udf_fields, catalog, cache, download_cb
                     )
-                    for row in batch.rows
+                    for row in batch
                 ]
             )
             result_objs = list(self.process_safe(udf_args))
@@ -489,7 +501,7 @@ class Aggregator(UDFBase):
 
     def run(
         self,
-        udf_fields: "Sequence[str]",
+        udf_fields: Sequence[str],
         udf_inputs: Iterable[RowsOutputBatch],
         catalog: "Catalog",
         cache: bool,
@@ -502,13 +514,13 @@ class Aggregator(UDFBase):
             udf_args = zip(
                 *[
                     self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch.rows
+                    for row in batch
                 ]
             )
             result_objs = self.process_safe(udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(len(batch.rows))
+            processed_cb.relative_update(len(batch))
             yield output
 
         self.teardown()
datachain/listing.py
CHANGED
datachain/query/batch.py
CHANGED
@@ -2,22 +2,14 @@ import contextlib
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Sequence
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Union
-
-from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
-from datachain.query.utils import get_query_column, get_query_id_column
-
-if TYPE_CHECKING:
-    from sqlalchemy import Select
+from typing import Callable, Optional, Union
 
+import sqlalchemy as sa
 
-@dataclass
-class RowsOutputBatch:
-    rows: Sequence[Sequence]
-
+from datachain.data_storage.schema import PARTITION_COLUMN_ID
+from datachain.query.utils import get_query_column
 
+RowsOutputBatch = Sequence[Sequence]
 RowsOutput = Union[Sequence, RowsOutputBatch]
 
 
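
Note that `RowsOutputBatch` is now a plain type alias rather than a dataclass wrapper, which is why call sites in `udf.py` switch from `len(batch.rows)` to `len(batch)`. A small sketch of the difference:

```py
from collections.abc import Sequence
from typing import Union

# 0.17.0: a batch is the sequence of rows itself
RowsOutputBatch = Sequence[Sequence]
RowsOutput = Union[Sequence, RowsOutputBatch]

batch: RowsOutputBatch = [(1, "a"), (2, "b")]
assert len(batch) == 2  # previously len(batch.rows) on a dataclass instance
```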
@@ -30,8 +22,8 @@ class BatchingStrategy(ABC):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
 
@@ -47,12 +39,16 @@ class NoBatching(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[Sequence, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
-        return execute(query)
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
+
+        rows = execute(query)
+        yield from (r[0] for r in rows) if ids_only else rows
 
 
 class Batch(BatchingStrategy):
@@ -69,27 +65,31 @@ class Batch(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        if ids_only:
-            query = query.with_only_columns(get_query_id_column(query))
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
+        from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
 
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
         # select rows in batches
-        results: list[Sequence] = []
+        results = []
 
-        with contextlib.closing(execute(query, page_size=page_size)) as rows:
-            for row in rows:
+        with contextlib.closing(execute(query, page_size=page_size)) as batch_rows:
+            for row in batch_rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield RowsOutputBatch(batch)
+                    yield [r[0] for r in batch] if ids_only else batch
 
         if len(results) > 0:
-            yield RowsOutputBatch(results)
+            yield [r[0] for r in results] if ids_only else results
 
 
 class Partition(BatchingStrategy):
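
The regrouping loop in `Batch.__call__` above streams rows in database pages and re-chunks them into batches of exactly `count` rows (the last batch may be smaller). A standalone sketch of just that loop:

```py
from collections.abc import Iterable, Iterator, Sequence


def rebatch(rows: Iterable[Sequence], count: int) -> Iterator[list[Sequence]]:
    """Re-chunk an incoming row stream into batches of `count` rows."""
    results: list[Sequence] = []
    for row in rows:
        results.append(row)
        if len(results) >= count:
            batch, results = results[:count], results[count:]
            yield batch
    if results:
        yield results


assert list(rebatch([(1,), (2,), (3,)], 2)) == [[(1,), (2,)], [(3,)]]
```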
@@ -104,18 +104,19 @@ class Partition(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: "Select",
-        ids_only: bool = False,
-    ) -> Generator[RowsOutputBatch, None, None]:
-        id_col = get_query_id_column(query)
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
         if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")
 
-        if ids_only:
+        ids_only = False
+        if id_col is not None:
             query = query.with_only_columns(id_col, partition_col)
+            ids_only = True
 
         current_partition: Optional[int] = None
-        batch: list[Sequence] = []
+        batch: list = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")
@@ -132,9 +133,9 @@
             if current_partition != partition:
                 current_partition = partition
                 if len(batch) > 0:
-                    yield RowsOutputBatch(batch)
+                    yield batch
                     batch = []
-            batch.append(row)
+            batch.append(row[id_column_idx] if ids_only else row)
 
         if len(batch) > 0:
-            yield RowsOutputBatch(batch)
+            yield batch
datachain/query/dataset.py
CHANGED
@@ -42,15 +42,9 @@ from datachain.data_storage.schema import (
     partition_columns,
 )
 from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
-from datachain.error import (
-    DatasetNotFoundError,
-    QueryScriptCancelError,
-)
+from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.listing import (
-    is_listing_dataset,
-    listing_dataset_expired,
-)
+from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
@@ -89,7 +83,7 @@ PartitionByType = Union[
     Function, ColumnElement, Sequence[Union[Function, ColumnElement]]
 ]
 JoinPredicateType = Union[str, ColumnClause, ColumnElement]
-DatasetDependencyType = tuple[str, int]
+DatasetDependencyType = tuple[str, str]
 
 logger = logging.getLogger("datachain")
 
@@ -174,7 +168,7 @@ class Step(ABC):
 class QueryStep:
     catalog: "Catalog"
     dataset_name: str
-    dataset_version: int
+    dataset_version: str
 
     def apply(self):
         def q(*columns):
@@ -420,41 +414,30 @@ class UDFStep(Step, ABC):
     """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
-        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
-
-        rows_total = self.catalog.warehouse.query_count(query)
-        if rows_total == 0:
+        if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
             return
 
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+        from datachain.catalog.loader import (
+            DISTRIBUTED_IMPORT_PATH,
+            get_udf_distributor_class,
+        )
+
         workers = determine_workers(self.workers, rows_total=rows_total)
         processes = determine_processes(self.parallel, rows_total=rows_total)
 
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         udf_fields = [str(c.name) for c in query.selected_columns]
+        udf_distributor_class = get_udf_distributor_class()
 
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
            catalog = clone_catalog_with_cache(self.catalog, _cache)
-            try:
-                if workers:
-                    if catalog.in_memory:
-                        raise RuntimeError(
-                            "In-memory databases cannot be used with "
-                            "distributed processing."
-                        )
-
-                    from datachain.catalog.loader import (
-                        DISTRIBUTED_IMPORT_PATH,
-                        get_udf_distributor_class,
-                    )
-
-                    if not (udf_distributor_class := get_udf_distributor_class()):
-                        raise RuntimeError(
-                            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
-                            "for distributed UDF processing."
-                        )
 
+            try:
+                if udf_distributor_class and not catalog.in_memory:
+                    # Use the UDF distributor if available (running in SaaS)
                     udf_distributor = udf_distributor_class(
                         catalog=catalog,
                         table=udf_table,
@@ -470,7 +453,20 @@
                         min_task_size=self.min_task_size,
                     )
                     udf_distributor()
-                elif processes:
+                    return
+
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )
+
+                    raise RuntimeError(
+                        f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+                        "for distributed UDF processing."
+                    )
+                if processes:
                     # Parallel processing (faster for more CPU-heavy UDFs)
                     if catalog.in_memory:
                         raise RuntimeError(
@@ -504,7 +500,12 @@
             with subprocess.Popen(  # noqa: S603
                 cmd, env=envs, stdin=subprocess.PIPE
             ) as process:
-                process.communicate(process_data)
+                try:
+                    process.communicate(process_data)
+                except KeyboardInterrupt:
+                    raise QueryScriptCancelError(
+                        "UDF execution was canceled by the user."
+                    ) from None
                 if retval := process.poll():
                     raise RuntimeError(
                         f"UDF Execution Failed! Exit code: {retval}"
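
The pattern above converts a Ctrl-C during `process.communicate()` into datachain's own `QueryScriptCancelError` instead of letting a bare `KeyboardInterrupt` escape. A self-contained sketch of the same pattern (the error class here is a stand-in for the one in `datachain.error`):

```py
import subprocess


class QueryScriptCancelError(RuntimeError):
    """Stand-in for datachain.error.QueryScriptCancelError."""


def run_udf_subprocess(cmd: list[str], payload: bytes) -> None:
    with subprocess.Popen(cmd, stdin=subprocess.PIPE) as process:
        try:
            process.communicate(payload)
        except KeyboardInterrupt:
            raise QueryScriptCancelError(
                "UDF execution was canceled by the user."
            ) from None
        if retval := process.poll():
            raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
```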
@@ -1091,7 +1092,7 @@ class DatasetQuery:
     def __init__(
         self,
         name: str,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         catalog: Optional["Catalog"] = None,
         session: Optional[Session] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
@@ -1111,7 +1112,7 @@
         self.table = self.get_table()
         self.starting_step: Optional[QueryStep] = None
         self.name: Optional[str] = None
-        self.version: Optional[int] = None
+        self.version: Optional[str] = None
         self.feature_schema: Optional[dict] = None
         self.column_types: Optional[dict[str, Any]] = None
         self.before_steps: list[Callable] = []
@@ -1154,7 +1155,7 @@
     def __or__(self, other):
         return self.union(other)
 
-    def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
+    def pull_dataset(self, name: str, version: Optional[str] = None) -> "DatasetRecord":
         print("Dataset not found in local catalog, trying to get from studio")
 
         remote_ds_uri = f"{DATASET_PREFIX}{name}"
@@ -1184,8 +1185,8 @@
         it completely. If this is the case, name and version of underlying dataset
         will be defined.
         DatasetQuery instance can become attached in two scenarios:
-            1. ds = DatasetQuery(name="dogs", version=1) -> ds is attached to dogs
-            2. ds = ds.save("dogs", version=1) -> ds is attached to dogs dataset
+            1. ds = DatasetQuery(name="dogs", version="1.0.0") -> ds is attached to dogs
+            2. ds = ds.save("dogs", version="1.0.0") -> ds is attached to dogs dataset
         It can move to detached state if filter or similar methods are called on it,
         as then it no longer 100% represents underlying datasets.
         """
@@ -1662,7 +1663,7 @@
         )
         return query
 
-    def _add_dependencies(self, dataset: "DatasetRecord", version: int):
+    def _add_dependencies(self, dataset: "DatasetRecord", version: str):
         for dependency in self.dependencies:
            ds_dependency_name, ds_dependency_version = dependency
             self.catalog.metastore.add_dataset_dependency(
@@ -1684,7 +1685,7 @@
     def save(
         self,
         name: Optional[str] = None,
-        version: Optional[int] = None,
+        version: Optional[str] = None,
         feature_schema: Optional[dict] = None,
         description: Optional[str] = None,
         attrs: Optional[list[str]] = None,