datachain 0.16.4__py3-none-any.whl → 0.16.5__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. The information is provided for informational purposes only.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/catalog/catalog.py +5 -1
- datachain/cli/__init__.py +11 -9
- datachain/cli/commands/query.py +1 -0
- datachain/cli/parser/__init__.py +9 -1
- datachain/data_storage/job.py +1 -0
- datachain/data_storage/metastore.py +82 -71
- datachain/data_storage/warehouse.py +46 -34
- datachain/lib/arrow.py +23 -1
- datachain/lib/dc/csv.py +1 -0
- datachain/lib/dc/datachain.py +30 -13
- datachain/lib/listing.py +2 -0
- datachain/lib/udf.py +17 -5
- datachain/query/batch.py +40 -39
- datachain/query/dataset.py +33 -32
- datachain/query/dispatch.py +137 -75
- datachain/query/metrics.py +1 -2
- datachain/query/queue.py +1 -11
- datachain/query/udf.py +1 -1
- datachain/query/utils.py +8 -14
- datachain/utils.py +3 -0
- {datachain-0.16.4.dist-info → datachain-0.16.5.dist-info}/METADATA +1 -1
- {datachain-0.16.4.dist-info → datachain-0.16.5.dist-info}/RECORD +26 -26
- {datachain-0.16.4.dist-info → datachain-0.16.5.dist-info}/WHEEL +1 -1
- {datachain-0.16.4.dist-info → datachain-0.16.5.dist-info}/entry_points.txt +0 -0
- {datachain-0.16.4.dist-info → datachain-0.16.5.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.16.4.dist-info → datachain-0.16.5.dist-info}/top_level.txt +0 -0
datachain/lib/dc/datachain.py
CHANGED
@@ -1636,18 +1636,27 @@ class DataChain:
         """
         from pyarrow.dataset import CsvFileFormat, JsonFileFormat
 
-        from datachain.lib.arrow import ...
+        from datachain.lib.arrow import (
+            ArrowGenerator,
+            fix_pyarrow_format,
+            infer_schema,
+            schema_to_output,
+        )
 
-        ...
+        parse_options = kwargs.pop("parse_options", None)
+        if format := kwargs.get("format"):
+            kwargs["format"] = fix_pyarrow_format(format, parse_options)
+
+        if (
+            nrows
+            and format not in ["csv", "json"]
+            and not isinstance(format, (CsvFileFormat, JsonFileFormat))
+        ):
+            raise DatasetPrepareError(
+                self.name,
+                "error in `parse_tabular` - "
+                "`nrows` only supported for csv and json formats.",
+            )
 
         if "file" not in self.schema or not self.count():
             raise DatasetPrepareError(self.name, "no files to parse.")

@@ -1656,7 +1665,7 @@ class DataChain:
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
             try:
-                schema = infer_schema(self, **kwargs)
+                schema = infer_schema(self, **kwargs, parse_options=parse_options)
                 output, _ = schema_to_output(schema, col_names)
             except ValueError as e:
                 raise DatasetPrepareError(self.name, e) from e

@@ -1682,7 +1691,15 @@ class DataChain:
         # disable prefetch if nrows is set
         settings = {"prefetch": 0} if nrows else {}
         return self.settings(**settings).gen(  # type: ignore[arg-type]
-            ArrowGenerator(...
+            ArrowGenerator(
+                schema,
+                model,
+                source,
+                nrows,
+                parse_options=parse_options,
+                **kwargs,
+            ),
+            output=output,
         )
 
     @classmethod
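The user-visible effect of this change is that `parse_tabular` now accepts and forwards a `parse_options` argument to pyarrow. A minimal usage sketch; the `read_storage` call and the bucket path are illustrative, not taken from the diff:

```python
# Sketch: pass pyarrow CSV parse options through parse_tabular so that
# tab-separated files are split on the right delimiter.
import pyarrow.csv as pacsv

import datachain as dc

chain = dc.read_storage("s3://example-bucket/reports/")  # placeholder source
parsed = chain.parse_tabular(
    format="csv",
    parse_options=pacsv.ParseOptions(delimiter="\t"),
)
```

From the hunk above, `fix_pyarrow_format` appears to fold the parse options into the pyarrow format object before the dataset is built, and the same options are reused when inferring the output schema.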
datachain/lib/listing.py
CHANGED
datachain/lib/udf.py
CHANGED
@@ -218,6 +218,18 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
+    @property
+    def verbose_name(self):
+        """Returns the name of the function or class that implements the UDF."""
+        if self._func and callable(self._func):
+            if hasattr(self._func, "__name__"):
+                return self._func.__name__
+            if hasattr(self._func, "__class__") and hasattr(
+                self._func.__class__, "__name__"
+            ):
+                return self._func.__class__.__name__
+        return "<unknown>"
+
     @property
     def signal_names(self) -> Iterable[str]:
         return self.output.to_udf_spec().keys()

@@ -411,13 +423,13 @@ class BatchMapper(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-            n_rows = len(batch...
+            n_rows = len(batch)
             row_ids, *udf_args = zip(
                 *[
                     self._prepare_row_and_id(
                         row, udf_fields, catalog, cache, download_cb
                     )
-                    for row in batch...
+                    for row in batch
                 ]
             )
             result_objs = list(self.process_safe(udf_args))

@@ -489,7 +501,7 @@ class Aggregator(UDFBase):
 
     def run(
         self,
-        udf_fields: ...
+        udf_fields: Sequence[str],
         udf_inputs: Iterable[RowsOutputBatch],
         catalog: "Catalog",
         cache: bool,

@@ -502,13 +514,13 @@ class Aggregator(UDFBase):
             udf_args = zip(
                 *[
                     self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch...
+                    for row in batch
                 ]
             )
             result_objs = self.process_safe(udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(len(batch...
+            processed_cb.relative_update(len(batch))
             yield output
 
         self.teardown()
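The new `verbose_name` property only inspects the callable stored on the UDF. A standalone sketch of the same lookup; the helper function and sample callables are illustrative, not part of the diff:

```python
# Sketch of the verbose_name logic added above, applied to a bare callable.
def verbose_name(func) -> str:
    if func and callable(func):
        if hasattr(func, "__name__"):
            return func.__name__            # plain function or lambda
        if hasattr(func.__class__, "__name__"):
            return func.__class__.__name__  # callable class instance
    return "<unknown>"


def embed(file): ...


class Embedder:
    def __call__(self, file): ...


assert verbose_name(embed) == "embed"
assert verbose_name(Embedder()) == "Embedder"
```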
datachain/query/batch.py
CHANGED
@@ -2,22 +2,14 @@ import contextlib
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Sequence
-from ...
-from typing import TYPE_CHECKING, Callable, Optional, Union
-
-from datachain.data_storage.schema import PARTITION_COLUMN_ID
-from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
-from datachain.query.utils import get_query_column, get_query_id_column
-
-if TYPE_CHECKING:
-    from sqlalchemy import Select
+from typing import Callable, Optional, Union
 
+import sqlalchemy as sa
 
-...
-    rows: Sequence[Sequence]
-
+from datachain.data_storage.schema import PARTITION_COLUMN_ID
+from datachain.query.utils import get_query_column
 
+RowsOutputBatch = Sequence[Sequence]
 RowsOutput = Union[Sequence, RowsOutputBatch]
 
 

@@ -30,8 +22,8 @@ class BatchingStrategy(ABC):
     def __call__(
         self,
         execute: Callable,
-        query: ...
-        ...
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""

@@ -47,12 +39,16 @@ class NoBatching(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: ...
-        ...
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
     ) -> Generator[Sequence, None, None]:
-        ...
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
+
+        rows = execute(query)
+        yield from (r[0] for r in rows) if ids_only else rows
 
 
 class Batch(BatchingStrategy):

@@ -69,27 +65,31 @@ class Batch(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: ...
-        ...
-    ) -> Generator[...
-        ...
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
+        from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+
+        ids_only = False
+        if id_col is not None:
+            query = query.with_only_columns(id_col)
+            ids_only = True
 
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
         # select rows in batches
-        results...
+        results = []
 
-        with contextlib.closing(execute(query, page_size=page_size)) as ...
-            for row in ...
+        with contextlib.closing(execute(query, page_size=page_size)) as batch_rows:
+            for row in batch_rows:
                 results.append(row)
                 if len(results) >= self.count:
                     batch, results = results[: self.count], results[self.count :]
-                    yield ...
+                    yield [r[0] for r in batch] if ids_only else batch
 
         if len(results) > 0:
-            yield ...
+            yield [r[0] for r in results] if ids_only else results
 
 
 class Partition(BatchingStrategy):

@@ -104,18 +104,19 @@ class Partition(BatchingStrategy):
     def __call__(
         self,
         execute: Callable,
-        query: ...
-        ...
-    ) -> Generator[...
-        id_col = get_query_id_column(query)
+        query: sa.Select,
+        id_col: Optional[sa.ColumnElement] = None,
+    ) -> Generator[RowsOutput, None, None]:
         if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
             raise RuntimeError("partition column not found in query")
 
-        ...
+        ids_only = False
+        if id_col is not None:
             query = query.with_only_columns(id_col, partition_col)
+            ids_only = True
 
         current_partition: Optional[int] = None
-        batch: list...
+        batch: list = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")

@@ -132,9 +133,9 @@ class Partition(BatchingStrategy):
                 if current_partition != partition:
                     current_partition = partition
                     if len(batch) > 0:
-                        yield ...
+                        yield batch
                         batch = []
-                batch.append(...
+                batch.append(row[id_column_idx] if ids_only else row)
 
         if len(batch) > 0:
-            yield ...
+            yield batch
datachain/query/dataset.py
CHANGED
@@ -42,15 +42,9 @@ from datachain.data_storage.schema import (
     partition_columns,
 )
 from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
-from datachain.error import (
-    DatasetNotFoundError,
-    QueryScriptCancelError,
-)
+from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.listing import (
-    is_listing_dataset,
-    listing_dataset_expired,
-)
+from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param

@@ -420,41 +414,30 @@ class UDFStep(Step, ABC):
         """
 
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
-        ...
-        rows_total = self.catalog.warehouse.query_count(query)
-        if rows_total == 0:
+        if (rows_total := self.catalog.warehouse.query_count(query)) == 0:
             return
 
+        from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
+        from datachain.catalog.loader import (
+            DISTRIBUTED_IMPORT_PATH,
+            get_udf_distributor_class,
+        )
+
         workers = determine_workers(self.workers, rows_total=rows_total)
         processes = determine_processes(self.parallel, rows_total=rows_total)
 
         use_partitioning = self.partition_by is not None
         batching = self.udf.get_batching(use_partitioning)
         udf_fields = [str(c.name) for c in query.selected_columns]
+        udf_distributor_class = get_udf_distributor_class()
 
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
-            try:
-                if workers:
-                    if catalog.in_memory:
-                        raise RuntimeError(
-                            "In-memory databases cannot be used with "
-                            "distributed processing."
-                        )
-
-                    from datachain.catalog.loader import (
-                        DISTRIBUTED_IMPORT_PATH,
-                        get_udf_distributor_class,
-                    )
-
-                    if not (udf_distributor_class := get_udf_distributor_class()):
-                        raise RuntimeError(
-                            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
-                            "for distributed UDF processing."
-                        )
 
+            try:
+                if udf_distributor_class and not catalog.in_memory:
+                    # Use the UDF distributor if available (running in SaaS)
                     udf_distributor = udf_distributor_class(
                         catalog=catalog,
                         table=udf_table,

@@ -470,7 +453,20 @@ class UDFStep(Step, ABC):
                         min_task_size=self.min_task_size,
                     )
                     udf_distributor()
-                    ...
+                    return
+
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )
+
+                    raise RuntimeError(
+                        f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+                        "for distributed UDF processing."
+                    )
+
                 if processes:
                     # Parallel processing (faster for more CPU-heavy UDFs)
                     if catalog.in_memory:
                         raise RuntimeError(

@@ -504,7 +500,12 @@ class UDFStep(Step, ABC):
                 with subprocess.Popen(  # noqa: S603
                     cmd, env=envs, stdin=subprocess.PIPE
                 ) as process:
-                    ...
+                    try:
+                        process.communicate(process_data)
+                    except KeyboardInterrupt:
+                        raise QueryScriptCancelError(
+                            "UDF execution was canceled by the user."
+                        ) from None
                     if retval := process.poll():
                         raise RuntimeError(
                             f"UDF Execution Failed! Exit code: {retval}"