datachain 0.14.1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +18 -18
- datachain/catalog/catalog.py +5 -5
- datachain/catalog/loader.py +4 -9
- datachain/cli/commands/show.py +2 -2
- datachain/data_storage/warehouse.py +9 -0
- datachain/lib/dc/__init__.py +18 -18
- datachain/lib/dc/csv.py +5 -5
- datachain/lib/dc/datachain.py +42 -42
- datachain/lib/dc/datasets.py +7 -7
- datachain/lib/dc/hf.py +5 -5
- datachain/lib/dc/json.py +5 -5
- datachain/lib/dc/listings.py +2 -2
- datachain/lib/dc/pandas.py +4 -4
- datachain/lib/dc/parquet.py +5 -5
- datachain/lib/dc/records.py +4 -4
- datachain/lib/dc/storage.py +13 -12
- datachain/lib/dc/values.py +4 -4
- datachain/lib/listing.py +11 -0
- datachain/lib/meta_formats.py +2 -2
- datachain/lib/pytorch.py +2 -2
- datachain/lib/udf.py +1 -1
- datachain/query/dataset.py +62 -50
- datachain/query/dispatch.py +6 -12
- datachain/query/udf.py +30 -1
- datachain/toolkit/split.py +1 -1
- datachain/utils.py +30 -4
- {datachain-0.14.1.dist-info → datachain-0.14.3.dist-info}/METADATA +5 -5
- {datachain-0.14.1.dist-info → datachain-0.14.3.dist-info}/RECORD +32 -32
- {datachain-0.14.1.dist-info → datachain-0.14.3.dist-info}/WHEEL +0 -0
- {datachain-0.14.1.dist-info → datachain-0.14.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.1.dist-info → datachain-0.14.3.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.1.dist-info → datachain-0.14.3.dist-info}/top_level.txt +0 -0
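Most of the per-file changes below are one coordinated rename: the `dc.from_*` constructors become `dc.read_*` (`from_storage` → `read_storage`, `from_values` → `read_values`, and so on), with imports and docstrings updated to match. A minimal before/after sketch at the call site, assuming the signatures shown in the docstrings below are otherwise unchanged (whether deprecated `from_*` aliases remain is not visible in this diff):

```py
import datachain as dc

# 0.14.1 spelling (renamed in this diff):
# chain = dc.from_storage("s3://my-bucket/my-dir")
# values = dc.from_values(fib=[1, 2, 3, 5, 8])

# 0.14.3 spelling:
chain = dc.read_storage("s3://my-bucket/my-dir")
values = dc.read_values(fib=[1, 2, 3, 5, 8])
```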
datachain/lib/dc/pandas.py
CHANGED
@@ -5,7 +5,7 @@ from typing import (
 
 from datachain.query import Session
 
-from .values import from_values
+from .values import read_values
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_pandas(  # type: ignore[override]
+def read_pandas(  # type: ignore[override]
     df: "pd.DataFrame",
     name: str = "",
     session: Optional[Session] = None,
@@ -32,7 +32,7 @@ def from_pandas(  # type: ignore[override]
        import datachain as dc
 
        df = pd.DataFrame({"fib": [1, 2, 3, 5, 8]})
-       dc.from_pandas(df)
+       dc.read_pandas(df)
        ```
    """
    from .utils import DatasetPrepareError
@@ -46,7 +46,7 @@ def from_pandas(  # type: ignore[override]
            f"import from pandas error - '{column}' cannot be a column name",
        )
 
-    return from_values(
+    return read_values(
        name,
        session,
        settings=settings,
datachain/lib/dc/parquet.py
CHANGED
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_parquet(
+def read_parquet(
     path,
     partitioning: Any = "hive",
     output: Optional[dict[str, DataType]] = None,
@@ -43,18 +43,18 @@ def from_parquet(
        Reading a single file:
        ```py
        import datachain as dc
-       dc.from_parquet("s3://mybucket/file.parquet")
+       dc.read_parquet("s3://mybucket/file.parquet")
        ```
 
        Reading a partitioned dataset from a directory:
        ```py
        import datachain as dc
-       dc.from_parquet("s3://mybucket/dir")
+       dc.read_parquet("s3://mybucket/dir")
        ```
    """
-    from .storage import from_storage
+    from .storage import read_storage
 
-    chain = from_storage(path, session=session, settings=settings, **kwargs)
+    chain = read_storage(path, session=session, settings=settings, **kwargs)
    return chain.parse_tabular(
        output=output,
        object_name=object_name,
datachain/lib/dc/records.py
CHANGED
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_records(
+def read_records(
     to_insert: Optional[Union[dict, list[dict]]],
     session: Optional[Session] = None,
     settings: Optional[dict] = None,
@@ -40,10 +40,10 @@ def from_records(
    Example:
        ```py
        import datachain as dc
-       single_record = dc.from_records(dc.DEFAULT_FILE_RECORD)
+       single_record = dc.read_records(dc.DEFAULT_FILE_RECORD)
        ```
    """
-    from .datasets import from_dataset
+    from .datasets import read_dataset
 
    session = Session.get(session, in_memory=in_memory)
    catalog = session.catalog
@@ -87,4 +87,4 @@ def from_records(
    insert_q = dr.get_table().insert()
    for record in to_insert:
        db.execute(insert_q.values(**record))
-    return from_dataset(name=dsr.name, session=session, settings=settings)
+    return read_dataset(name=dsr.name, session=session, settings=settings)
datachain/lib/dc/storage.py
CHANGED
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
     from .datachain import DataChain
 
 
-def from_storage(
+def read_storage(
     uri: Union[str, os.PathLike[str], list[str], list[os.PathLike[str]]],
     *,
     type: FileType = "binary",
@@ -55,12 +55,12 @@ def from_storage(
        Simple call from s3:
        ```python
        import datachain as dc
-       chain = dc.from_storage("s3://my-bucket/my-dir")
+       chain = dc.read_storage("s3://my-bucket/my-dir")
        ```
 
        Multiple URIs:
        ```python
-       chain = dc.from_storage([
+       chain = dc.read_storage([
            "s3://bucket1/dir1",
            "s3://bucket2/dir2"
        ])
@@ -68,7 +68,7 @@ def from_storage(
 
        With AWS S3-compatible storage:
        ```python
-       chain = dc.from_storage(
+       chain = dc.read_storage(
            "s3://my-bucket/my-dir",
            client_config = {"aws_endpoint_url": "<minio-endpoint-url>"}
        )
@@ -77,7 +77,7 @@ def from_storage(
        Pass existing session
        ```py
        session = Session.get()
-       chain = dc.from_storage([
+       chain = dc.read_storage([
            "path/to/dir1",
            "path/to/dir2"
        ], session=session, recursive=True)
@@ -88,9 +88,9 @@ def from_storage(
    avoiding redundant updates for URIs pointing to the same storage location.
    """
    from .datachain import DataChain
-    from .datasets import from_dataset
-    from .records import from_records
-    from .values import from_values
+    from .datasets import read_dataset
+    from .records import read_records
+    from .values import read_values
 
    file_type = get_file_type(type)
 
@@ -122,7 +122,8 @@ def from_storage(
            )
            continue
 
-        dc = from_dataset(list_ds_name, session=session, settings=settings)
+        dc = read_dataset(list_ds_name, session=session, settings=settings)
+        dc._query.update = update
        dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
 
        if update or not list_ds_exists:
@@ -130,7 +131,7 @@ def from_storage(
            def lst_fn(ds_name, lst_uri):
                # disable prefetch for listing, as it pre-downloads all files
                (
-                    from_records(
+                    read_records(
                        DataChain.DEFAULT_FILE_RECORD,
                        session=session,
                        settings=settings,
@@ -144,7 +145,7 @@ def from_storage(
                    .save(ds_name, listing=True)
                )
 
-            dc._query.
+            dc._query.set_listing_fn(
                lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri)
            )
@@ -154,7 +155,7 @@ def from_storage(
        listed_ds_name.add(list_ds_name)
 
    if file_values:
-        file_chain = from_values(
+        file_chain = read_values(
            session=session,
            settings=settings,
            in_memory=in_memory,
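Besides the rename, `read_storage` now defers the actual bucket listing: it attaches a listing function to the query via `set_listing_fn`, and the listing runs later only if the cached listing dataset is missing, explicitly refreshed, or expired (see `apply_steps` in `datachain/query/dataset.py` below). A usage sketch, assuming `update` stays a keyword argument of `read_storage` as the surrounding code suggests:

```py
import datachain as dc
from datachain.query import Session

session = Session.get()

# Reuses an existing, still-fresh listing dataset without re-listing the bucket.
chain = dc.read_storage("s3://my-bucket/my-dir", session=session)

# Forces a re-listing even if a listing dataset already exists.
fresh = dc.read_storage("s3://my-bucket/my-dir", session=session, update=True)
```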
datachain/lib/dc/values.py
CHANGED
@@ -6,7 +6,7 @@ from typing import (
 
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import dict_to_data_model
-from datachain.lib.dc.records import from_records
+from datachain.lib.dc.records import read_records
 from datachain.lib.dc.utils import OutputType
 from datachain.query import Session
 
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
 P = ParamSpec("P")
 
 
-def from_values(
+def read_values(
     ds_name: str = "",
     session: Optional[Session] = None,
     settings: Optional[dict] = None,
@@ -32,7 +32,7 @@ def from_values(
    Example:
        ```py
        import datachain as dc
-       dc.from_values(fib=[1, 2, 3, 5, 8])
+       dc.read_values(fib=[1, 2, 3, 5, 8])
        ```
    """
    from .datachain import DataChain
@@ -42,7 +42,7 @@ def from_values(
    def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
        yield from tuples
 
-    chain = from_records(
+    chain = read_records(
        DataChain.DEFAULT_FILE_RECORD,
        session=session,
        settings=settings,
datachain/lib/listing.py
CHANGED
@@ -4,6 +4,7 @@ import os
 import posixpath
 from collections.abc import Iterator
 from contextlib import contextmanager
+from datetime import datetime, timedelta, timezone
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 from fsspec.asyn import get_loop
@@ -32,6 +33,16 @@ logging.getLogger("aiobotocore.credentials").setLevel(logging.CRITICAL)
 logging.getLogger("gcsfs").setLevel(logging.CRITICAL)
 
 
+def listing_dataset_expired(lst_ds) -> bool:
+    """Function that checks if listing dataset is expired or not"""
+    lst_version = lst_ds.versions[-1]
+    if not lst_version.finished_at:
+        return False
+
+    expires = lst_version.finished_at + timedelta(seconds=LISTING_TTL)
+    return datetime.now(timezone.utc) > expires
+
+
 def list_bucket(uri: str, cache, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
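`listing_dataset_expired` looks only at the latest version of the listing dataset: an unfinished listing never counts as expired, and a finished one expires `LISTING_TTL` seconds after `finished_at`. A self-contained illustration with stand-in objects (the real argument is a dataset record, and the actual `LISTING_TTL` value is defined elsewhere in this module and not shown in this diff):

```py
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace

LISTING_TTL = 4 * 60 * 60  # assumed value, for illustration only

def listing_dataset_expired(lst_ds) -> bool:
    # Same logic as the function added above.
    lst_version = lst_ds.versions[-1]
    if not lst_version.finished_at:
        return False
    expires = lst_version.finished_at + timedelta(seconds=LISTING_TTL)
    return datetime.now(timezone.utc) > expires

now = datetime.now(timezone.utc)
stale = SimpleNamespace(versions=[SimpleNamespace(finished_at=now - timedelta(seconds=LISTING_TTL + 1))])
fresh = SimpleNamespace(versions=[SimpleNamespace(finished_at=now)])
unfinished = SimpleNamespace(versions=[SimpleNamespace(finished_at=None)])

print(listing_dataset_expired(stale))       # True
print(listing_dataset_expired(fresh))       # False
print(listing_dataset_expired(unfinished))  # False
```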
datachain/lib/meta_formats.py
CHANGED
@@ -103,10 +103,10 @@ def read_meta(  # noqa: C901
     model_name=None,
     nrows=None,
 ) -> Callable:
-    from datachain import from_storage
+    from datachain import read_storage
 
     if schema_from:
-        file = next(from_storage(schema_from, type="text").limit(1).collect("file"))
+        file = next(read_storage(schema_from, type="text").limit(1).collect("file"))
         model_code = gen_datamodel_code(
             file, format=format, jmespath=jmespath, model_name=model_name
         )
datachain/lib/pytorch.py
CHANGED
@@ -14,7 +14,7 @@ from torchvision.transforms import v2
 from datachain import Session
 from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
-from datachain.lib.dc.datasets import from_dataset
+from datachain.lib.dc.datasets import read_dataset
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
 from datachain.progress import CombinedDownloadCallback
@@ -122,7 +122,7 @@ class PytorchDataset(IterableDataset):
     ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
-        ds = from_dataset(
+        ds = read_dataset(
             name=self.name, version=self.version, session=session
         ).settings(cache=self.cache, prefetch=self.prefetch)
         ds = ds.remove_file_signals()
datachain/lib/udf.py
CHANGED
datachain/query/dataset.py
CHANGED
@@ -47,15 +47,20 @@ from datachain.error import (
     QueryScriptCancelError,
 )
 from datachain.func.base import Function
-from datachain.lib.listing import is_listing_dataset
+from datachain.lib.listing import (
+    is_listing_dataset,
+    listing_dataset_expired,
+)
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
+from datachain.query.udf import UdfInfo
 from datachain.sql.functions.random import rand
 from datachain.utils import (
     batched,
     determine_processes,
+    determine_workers,
     filtered_cloudpickle_dumps,
     get_datachain_executable,
     safe_closing,
@@ -71,7 +76,6 @@ if TYPE_CHECKING:
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
     from datachain.lib.udf import UDFAdapter, UDFResult
-    from datachain.query.udf import UdfInfo
 
 P = ParamSpec("P")
 
@@ -411,20 +415,15 @@ class UDFStep(Step, ABC):
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
 
-        use_partitioning = self.partition_by is not None
-        batching = self.udf.get_batching(use_partitioning)
-        workers = self.workers
-        if (
-            not workers
-            and os.environ.get("DATACHAIN_DISTRIBUTED")
-            and os.environ.get("DATACHAIN_SETTINGS_WORKERS")
-        ):
-            # Enable distributed processing by default if the module is available,
-            # and a default number of workers is provided.
-            workers = True
+        rows_total = self.catalog.warehouse.query_count(query)
+        if rows_total == 0:
+            return
 
-        processes = determine_processes(self.parallel)
+        workers = determine_workers(self.workers, rows_total=rows_total)
+        processes = determine_processes(self.parallel, rows_total=rows_total)
 
+        use_partitioning = self.partition_by is not None
+        batching = self.udf.get_batching(use_partitioning)
         udf_fields = [str(c.name) for c in query.selected_columns]
 
         prefetch = self.udf.prefetch
@@ -438,23 +437,24 @@ class UDFStep(Step, ABC):
                     "distributed processing."
                 )
 
-            from datachain.catalog.loader import
-
-
-
-
-
-
-
-
-
-
-                processes,
+            from datachain.catalog.loader import get_udf_distributor_class
+
+            udf_distributor_class = get_udf_distributor_class()
+            udf_distributor = udf_distributor_class(
+                catalog=catalog,
+                table=udf_table,
+                query=query,
+                udf_data=filtered_cloudpickle_dumps(self.udf),
+                batching=batching,
+                workers=workers,
+                processes=processes,
                 udf_fields=udf_fields,
+                rows_total=rows_total,
+                use_cache=self.cache,
                 is_generator=self.is_generator,
-
-                cache=self.cache,
+                min_task_size=self.min_task_size,
             )
+            udf_distributor()
         elif processes:
             # Parallel processing (faster for more CPU-heavy UDFs)
             if catalog.in_memory:
@@ -462,19 +462,21 @@ class UDFStep(Step, ABC):
                     "In-memory databases cannot be used "
                     "with parallel processing."
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            udf_info = UdfInfo(
+                udf_data=filtered_cloudpickle_dumps(self.udf),
+                catalog_init=catalog.get_init_params(),
+                metastore_clone_params=catalog.metastore.clone_params(),
+                warehouse_clone_params=catalog.warehouse.clone_params(),
+                table=udf_table,
+                query=query,
+                udf_fields=udf_fields,
+                batching=batching,
+                processes=processes,
+                is_generator=self.is_generator,
+                cache=self.cache,
+                rows_total=rows_total,
+            )
 
             # Run the UDFDispatcher in another process to avoid needing
             # if __name__ == '__main__': in user scripts
@@ -1080,6 +1082,7 @@ class DatasetQuery:
         indexing_column_types: Optional[dict[str, Any]] = None,
         in_memory: bool = False,
         fallback_to_studio: bool = True,
+        update: bool = False,
     ) -> None:
         from datachain.remote.studio import is_token_set
 
@@ -1097,6 +1100,8 @@ class DatasetQuery:
         self.feature_schema: Optional[dict] = None
         self.column_types: Optional[dict[str, Any]] = None
         self.before_steps: list[Callable] = []
+        self.listing_fn: Optional[Callable] = None
+        self.update = update
 
         self.list_ds_name: Optional[str] = None
 
@@ -1190,23 +1195,30 @@ class DatasetQuery:
         col.table = self.table
         return col
 
-    def
-        """
-
-        """
-        self.before_steps.append(fn)
+    def set_listing_fn(self, fn: Callable) -> None:
+        """Setting listing function to be run if needed"""
+        self.listing_fn = fn
 
     def apply_steps(self) -> QueryGenerator:
         """
         Apply the steps in the query and return the resulting
         sqlalchemy.SelectBase.
         """
-
-
+        if self.list_ds_name and not self.starting_step:
+            listing_ds = None
+            try:
+                listing_ds = self.catalog.get_dataset(self.list_ds_name)
+            except DatasetNotFoundError:
+                pass
+
+            if not listing_ds or self.update or listing_dataset_expired(listing_ds):
+                assert self.listing_fn
+                self.listing_fn()
+                listing_ds = self.catalog.get_dataset(self.list_ds_name)
 
-        if self.list_ds_name:
             # at this point we know what is our starting listing dataset name
-            self._set_starting_step(
+            self._set_starting_step(listing_ds)  # type: ignore [arg-type]
+
         query = self.clone()
 
         index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index)
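With the new `update` flag and `listing_fn` in place, `apply_steps` re-runs the stored listing function only when it has to. The added condition boils down to a single predicate; a small sketch using the names from the diff (`needs_relisting` is a hypothetical helper, not part of the package):

```py
from datachain.lib.listing import listing_dataset_expired

def needs_relisting(listing_ds, update: bool) -> bool:
    """Mirror of the condition added to DatasetQuery.apply_steps: re-list when
    no listing dataset exists yet, when the caller asked for a refresh
    (update=True), or when the existing listing has expired."""
    return not listing_ds or update or listing_dataset_expired(listing_ds)
```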
datachain/query/dispatch.py
CHANGED
@@ -11,11 +11,10 @@ import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
-from sqlalchemy.sql import func
 
 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
-from datachain.catalog.loader import
+from datachain.catalog.loader import get_udf_distributor_class
 from datachain.lib.udf import _get_cache
 from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
@@ -59,6 +58,7 @@ def udf_entrypoint() -> int:
     dispatch = UDFDispatcher(udf_info)
 
     query = udf_info["query"]
+    rows_total = udf_info["rows_total"]
     batching = udf_info["batching"]
     n_workers = udf_info["processes"]
     if n_workers is True:
@@ -67,12 +67,6 @@ def udf_entrypoint() -> int:
     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
 
-    total_rows = next(
-        warehouse.db.execute(
-            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
-        )
-    )[0]
-
     with contextlib.closing(
         batching(warehouse.dataset_select_paginated, query, ids_only=True)
     ) as udf_inputs:
@@ -81,7 +75,7 @@ def udf_entrypoint() -> int:
         try:
             dispatch.run_udf_parallel(
                 udf_inputs,
-                total_rows=total_rows,
+                rows_total=rows_total,
                 n_workers=n_workers,
                 processed_cb=processed_cb,
                 download_cb=download_cb,
@@ -94,7 +88,7 @@ def udf_entrypoint() -> int:
 
 
 def udf_worker_entrypoint() -> int:
-    return
+    return get_udf_distributor_class().run_worker()
 
 
 class UDFDispatcher:
@@ -164,14 +158,14 @@ class UDFDispatcher:
     def run_udf_parallel(  # noqa: C901, PLR0912
         self,
         input_rows: Iterable[RowsOutput],
-        total_rows: int,
+        rows_total: int,
         n_workers: Optional[int] = None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
     ) -> None:
         n_workers = get_n_workers_from_arg(n_workers)
 
-        input_batch_size = total_rows // n_workers
+        input_batch_size = rows_total // n_workers
         if input_batch_size == 0:
             input_batch_size = 1
         elif input_batch_size > DEFAULT_BATCH_SIZE:
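With row counting moved out of the dispatcher, `rows_total` now arrives through `udf_info` and feeds the per-worker batch sizing: `rows_total // n_workers`, raised to at least 1. The hunk ends before showing what happens when the quotient exceeds `DEFAULT_BATCH_SIZE`; the sketch below assumes it is capped at that constant, and the constant's value here is made up for illustration:

```py
DEFAULT_BATCH_SIZE = 10_000  # assumed value, for illustration only

def input_batch_size(rows_total: int, n_workers: int) -> int:
    # Mirrors the sizing logic in UDFDispatcher.run_udf_parallel.
    size = rows_total // n_workers
    if size == 0:
        size = 1
    elif size > DEFAULT_BATCH_SIZE:
        size = DEFAULT_BATCH_SIZE  # assumed cap; the branch body is outside the hunk
    return size

print(input_batch_size(rows_total=5, n_workers=8))          # 1
print(input_batch_size(rows_total=1_000_000, n_workers=4))  # 10000
```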
datachain/query/udf.py
CHANGED
@@ -1,8 +1,10 @@
-from
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
 
 if TYPE_CHECKING:
     from sqlalchemy import Select, Table
 
+    from datachain.catalog import Catalog
     from datachain.query.batch import BatchingStrategy
 
 
@@ -18,3 +20,30 @@ class UdfInfo(TypedDict):
     processes: Optional[int]
     is_generator: bool
     cache: bool
+    rows_total: int
+
+
+class AbstractUDFDistributor(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        catalog: "Catalog",
+        table: "Table",
+        query: "Select",
+        udf_data: bytes,
+        batching: "BatchingStrategy",
+        workers: Union[bool, int],
+        processes: Union[bool, int],
+        udf_fields: list[str],
+        rows_total: int,
+        use_cache: bool,
+        is_generator: bool = False,
+        min_task_size: Optional[Union[str, int]] = None,
+    ) -> None: ...
+
+    @abstractmethod
+    def __call__(self) -> None: ...
+
+    @staticmethod
+    @abstractmethod
+    def run_worker() -> int: ...
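`AbstractUDFDistributor` is the contract that `populate_udf_table` now programs against via `get_udf_distributor_class()`: construct it with the query, serialized UDF, and sizing information, call the instance to run the distributed step, and expose a worker entry point. A minimal illustrative stub that satisfies the interface (not the real distributed implementation):

```py
from typing import Optional, Union

from datachain.query.udf import AbstractUDFDistributor


class NoopUDFDistributor(AbstractUDFDistributor):
    """Hypothetical stub showing the required surface."""

    def __init__(
        self,
        catalog,
        table,
        query,
        udf_data: bytes,
        batching,
        workers: Union[bool, int],
        processes: Union[bool, int],
        udf_fields: list[str],
        rows_total: int,
        use_cache: bool,
        is_generator: bool = False,
        min_task_size: Optional[Union[str, int]] = None,
    ) -> None:
        self.rows_total = rows_total
        self.workers = workers

    def __call__(self) -> None:
        # A real implementation would partition the query and schedule the UDF.
        print(f"would distribute a UDF over {self.rows_total} rows ({self.workers} workers)")

    @staticmethod
    def run_worker() -> int:
        # A real implementation would start a worker loop; 0 means success.
        return 0
```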
datachain/toolkit/split.py
CHANGED
@@ -41,7 +41,7 @@ def train_test_split(
         from datachain.toolkit import train_test_split
 
         # Load a DataChain from a storage source (e.g., S3 bucket)
-        dc = dc.from_storage("s3://bucket/dir/")
+        dc = dc.read_storage("s3://bucket/dir/")
 
         # Perform a 70/30 train-test split
         train, test = train_test_split(dc, [0.7, 0.3])
datachain/utils.py
CHANGED
@@ -286,15 +286,41 @@ def retry_with_backoff(retries=5, backoff_sec=1, errors=(Exception,)):
     return retry
 
 
-def determine_processes(parallel: Optional[Union[bool, int]] = None) -> Union[bool, int]:
+def determine_workers(
+    workers: Union[bool, int],
+    rows_total: Optional[int] = None,
+) -> Union[bool, int]:
+    """Determine the number of workers to use for distributed processing."""
+    if rows_total is not None and rows_total <= 1:
+        # Disable distributed processing if there is no rows or only one row.
+        return False
+    if (
+        workers is False
+        and os.environ.get("DATACHAIN_DISTRIBUTED")
+        and os.environ.get("DATACHAIN_SETTINGS_WORKERS")
+    ):
+        # Enable distributed processing by default if the module is available,
+        # and a default number of workers is provided.
+        workers = int(os.environ["DATACHAIN_SETTINGS_WORKERS"])
+    if not workers or workers <= 0:
+        return False
+    return workers
+
+
+def determine_processes(
+    parallel: Optional[Union[bool, int]] = None,
+    rows_total: Optional[int] = None,
+) -> Union[bool, int]:
+    """Determine the number of processes to use for parallel processing."""
+    if rows_total is not None and rows_total <= 1:
+        # Disable parallel processing if there is no rows or only one row.
+        return False
     if parallel is None and os.environ.get("DATACHAIN_SETTINGS_PARALLEL") is not None:
         parallel = int(os.environ["DATACHAIN_SETTINGS_PARALLEL"])
-    if parallel is None or parallel is False:
+    if parallel is None or parallel is False or parallel == 0:
         return False
     if parallel is True:
         return True
-    if parallel == 0:
-        return False
     if parallel < 0:
         return True
     return parallel
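As defined above, both helpers short-circuit to `False` when there is at most one row, `determine_workers` only falls back to `DATACHAIN_SETTINGS_WORKERS` when distributed mode is configured, and `determine_processes` now treats `0` the same as `False`. A small demonstration of the return values, assuming none of the `DATACHAIN_*` environment variables are set:

```py
from datachain.utils import determine_processes, determine_workers

# Tiny inputs disable parallel/distributed processing outright.
assert determine_workers(4, rows_total=1) is False
assert determine_processes(4, rows_total=1) is False

# None, False and 0 all mean "no parallel processing".
assert determine_processes(None) is False
assert determine_processes(False) is False
assert determine_processes(0) is False

# True / negative values mean "use a default", positive counts pass through.
assert determine_processes(True) is True
assert determine_processes(-1) is True
assert determine_processes(8) == 8

# Explicit worker counts pass through when positive; False stays disabled.
assert determine_workers(4, rows_total=100) == 4
assert determine_workers(False, rows_total=100) is False
```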