datachain 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +3 -4
- datachain/client/gcs.py +9 -0
- datachain/data_storage/warehouse.py +0 -1
- datachain/lib/arrow.py +82 -58
- datachain/lib/dc.py +12 -57
- datachain/lib/file.py +3 -1
- datachain/lib/listing.py +44 -0
- datachain/lib/udf.py +0 -1
- datachain/query/batch.py +32 -6
- datachain/query/dataset.py +17 -17
- datachain/query/dispatch.py +125 -125
- datachain/query/session.py +8 -5
- datachain/query/udf.py +20 -0
- datachain/query/utils.py +42 -0
- datachain/utils.py +1 -1
- {datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/METADATA +3 -3
- {datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/RECORD +21 -19
- {datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/LICENSE +0 -0
- {datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/WHEEL +0 -0
- {datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/top_level.txt +0 -0
datachain/query/dataset.py
CHANGED
@@ -43,8 +43,9 @@ from datachain.data_storage.schema import (
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import C, UDFParamSpec, normalize_param
+from datachain.query.session import Session
 from datachain.sql.functions.random import rand
 from datachain.utils import (
     batched,
@@ -53,9 +54,6 @@ from datachain.utils import (
     get_datachain_executable,
 )

-from .schema import C, UDFParamSpec, normalize_param
-from .session import Session
-
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
@@ -65,7 +63,8 @@ if TYPE_CHECKING:
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.lib.udf import UDFResult
+    from datachain.lib.udf import UDFAdapter, UDFResult
+    from datachain.query.udf import UdfInfo

 P = ParamSpec("P")

@@ -301,7 +300,7 @@ def adjust_outputs(
     return row


-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -322,7 +321,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFAdapter,
+    udf: "UDFAdapter",
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -347,6 +346,8 @@ def process_udf_outputs(
     for row_chunk in batched(rows, batch_size):
         warehouse.insert_rows(udf_table, row_chunk)

+    warehouse.insert_rows_done(udf_table)
+

 def get_download_callback() -> Callback:
     return CombinedDownloadCallback(
@@ -366,7 +367,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFAdapter
+    udf: "UDFAdapter"
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -440,7 +441,7 @@ class UDFStep(Step, ABC):
                 raise RuntimeError(
                     "In-memory databases cannot be used with parallel processing."
                 )
-            udf_info = {
+            udf_info: UdfInfo = {
                 "udf_data": filtered_cloudpickle_dumps(self.udf),
                 "catalog_init": self.catalog.get_init_params(),
                 "metastore_clone_params": self.catalog.metastore.clone_params(),
@@ -464,8 +465,8 @@ class UDFStep(Step, ABC):

             with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
                 process.communicate(process_data)
-                if process.poll():
-                    raise RuntimeError("UDF Execution Failed!")
+                if retval := process.poll():
+                    raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
             warehouse = self.catalog.warehouse
@@ -479,7 +480,6 @@ class UDFStep(Step, ABC):
                 udf_fields,
                 udf_inputs,
                 self.catalog,
-                self.is_generator,
                 self.cache,
                 download_cb,
                 processed_cb,
@@ -496,8 +496,6 @@ class UDFStep(Step, ABC):
                 processed_cb.close()
                 generated_cb.close()

-            warehouse.insert_rows_done(udf_table)
-
         except QueryScriptCancelError:
             self.catalog.warehouse.close()
             sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
@@ -1491,7 +1489,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1535,7 +1533,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1617,7 +1615,9 @@ class DatasetQuery:
         )
         version = version or dataset.latest_version

-        self.session.add_dataset_version(
+        self.session.add_dataset_version(
+            dataset=dataset, version=version, listing=kwargs.get("listing", False)
+        )

         dr = self.catalog.warehouse.dataset_rows(dataset)
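The subprocess failure check in the hunk above now captures the exit code via a walrus assignment. A minimal standalone sketch of that pattern (illustrative only; the child command is made up and exits with code 3 on purpose, so the script ends with the raised error):

import subprocess
import sys

# poll() returns the child's exit code once it has terminated; a non-zero code
# is truthy, so it both triggers the error path and can be reported verbatim.
cmd = [sys.executable, "-c", "import sys; sys.exit(3)"]
with subprocess.Popen(cmd, stdin=subprocess.PIPE) as process:
    process.communicate(b"")
    if retval := process.poll():
        raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")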
datachain/query/dispatch.py
CHANGED
@@ -1,34 +1,37 @@
 import contextlib
-from collections.abc import
+from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from
+from threading import Timer
+from typing import TYPE_CHECKING, Optional

 import attrs
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from multiprocess import get_context
+from sqlalchemy.sql import func

 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.
+from datachain.query.batch import RowsOutput, RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
     get_processed_callback,
     process_udf_outputs,
 )
-from datachain.query.queue import
-
-
-
-
-
-
-
-from datachain.
+from datachain.query.queue import get_from_queue, put_into_queue
+from datachain.query.udf import UdfInfo
+from datachain.query.utils import get_query_id_column
+from datachain.utils import batched, flatten
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.lib.udf import UDFAdapter

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -38,10 +41,6 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"


-def full_module_type_path(typ: type) -> str:
-    return f"{typ.__module__}.{typ.__qualname__}"
-
-
 def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     if not n_workers:
         return cpu_count()
@@ -52,55 +51,42 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:

 def udf_entrypoint() -> int:
     # Load UDF info from stdin
-    udf_info = load(stdin.buffer)
-
-    (
-        warehouse_class,
-        warehouse_args,
-        warehouse_kwargs,
-    ) = udf_info["warehouse_clone_params"]
-    warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs)
+    udf_info: UdfInfo = load(stdin.buffer)

     # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(
-        udf_info["udf_data"],
-        udf_info["catalog_init"],
-        udf_info["metastore_clone_params"],
-        udf_info["warehouse_clone_params"],
-        udf_fields=udf_info["udf_fields"],
-        cache=udf_info["cache"],
-        is_generator=udf_info.get("is_generator", False),
-    )
+    dispatch = UDFDispatcher(udf_info)

     query = udf_info["query"]
     batching = udf_info["batching"]
-    table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = loads(udf_info["udf_data"])
     if n_workers is True:
-        # Use default number of CPUs (cores)
-
+        n_workers = None  # Use default number of CPUs (cores)
+
+    wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
+    warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
+
+    total_rows = next(
+        warehouse.db.execute(
+            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
+        )
+    )[0]

     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query)
+        batching(warehouse.dataset_select_paginated, query, ids_only=True)
     ) as udf_inputs:
         download_cb = get_download_callback()
         processed_cb = get_processed_callback()
-        generated_cb = get_generated_callback(dispatch.is_generator)
         try:
-
-
+            dispatch.run_udf_parallel(
+                udf_inputs,
+                total_rows=total_rows,
                 n_workers=n_workers,
                 processed_cb=processed_cb,
                 download_cb=download_cb,
             )
-            process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
         finally:
             download_cb.close()
             processed_cb.close()
-            generated_cb.close()
-
-        warehouse.insert_rows_done(table)

     return 0

@@ -114,32 +100,17 @@ class UDFDispatcher:
     task_queue: Optional[multiprocess.Queue] = None
     done_queue: Optional[multiprocess.Queue] = None

-    def __init__(
-        self
-
-
-
-
-
-
-
-
-
-        self.udf_data = udf_data
-        self.catalog_init_params = catalog_init_params
-        (
-            self.metastore_class,
-            self.metastore_args,
-            self.metastore_kwargs,
-        ) = metastore_clone_params
-        (
-            self.warehouse_class,
-            self.warehouse_args,
-            self.warehouse_kwargs,
-        ) = warehouse_clone_params
-        self.udf_fields = udf_fields
-        self.cache = cache
-        self.is_generator = is_generator
+    def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
+        self.udf_data = udf_info["udf_data"]
+        self.catalog_init_params = udf_info["catalog_init"]
+        self.metastore_clone_params = udf_info["metastore_clone_params"]
+        self.warehouse_clone_params = udf_info["warehouse_clone_params"]
+        self.query = udf_info["query"]
+        self.table = udf_info["table"]
+        self.udf_fields = udf_info["udf_fields"]
+        self.cache = udf_info["cache"]
+        self.is_generator = udf_info["is_generator"]
+        self.is_batching = udf_info["batching"].is_batching
         self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
@@ -148,12 +119,10 @@ class UDFDispatcher:

     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
-
-
-
-            warehouse =
-                *self.warehouse_args, **self.warehouse_kwargs
-            )
+            ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
+            metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
+            ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
+            warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
         self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
         self.udf = loads(self.udf_data)
         return UDFWorker(
@@ -161,7 +130,10 @@ class UDFDispatcher:
             self.udf,
             self.task_queue,
             self.done_queue,
+            self.query,
+            self.table,
             self.is_generator,
+            self.is_batching,
             self.cache,
             self.udf_fields,
         )
@@ -189,26 +161,27 @@ class UDFDispatcher:

     def run_udf_parallel(  # noqa: C901, PLR0912
         self,
-        input_rows,
+        input_rows: Iterable[RowsOutput],
+        total_rows: int,
         n_workers: Optional[int] = None,
-        input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
-    ) ->
+    ) -> None:
         n_workers = get_n_workers_from_arg(n_workers)

+        input_batch_size = total_rows // n_workers
+        if input_batch_size == 0:
+            input_batch_size = 1
+        elif input_batch_size > DEFAULT_BATCH_SIZE:
+            input_batch_size = DEFAULT_BATCH_SIZE
+
         if self.buffer_size < n_workers:
             raise RuntimeError(
                 "Parallel run error: buffer size is smaller than "
                 f"number of workers: {self.buffer_size} < {n_workers}"
             )

-
-            streaming_mode = True
-            self.task_queue = input_queue
-        else:
-            streaming_mode = False
-            self.task_queue = self.ctx.Queue()
+        self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
         pool = [
             self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
@@ -223,41 +196,41 @@
         # Will be set to True when the input is exhausted
         input_finished = False

-        if not
-
-            input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+        if not self.is_batching:
+            input_rows = batched(flatten(input_rows), input_batch_size)

-
-
-
-
-
-
-
+        # Stop all workers after the input rows have finished processing
+        input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
+
+        # Add initial buffer of tasks
+        for _ in range(self.buffer_size):
+            try:
+                put_into_queue(self.task_queue, next(input_data))
+            except StopIteration:
+                input_finished = True
+                break

         # Process all tasks
         while n_workers > 0:
             result = get_from_queue(self.done_queue)
+
+            if downloaded := result.get("downloaded"):
+                download_cb.relative_update(downloaded)
+            if processed := result.get("processed"):
+                processed_cb.relative_update(processed)
+
             status = result["status"]
-            if status
-
-                download_cb.relative_update(downloaded)
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
+            if status in (OK_STATUS, NOTIFY_STATUS):
+                pass  # Do nothing here
             elif status == FINISHED_STATUS:
-                # Worker finished
-                n_workers -= 1
-            elif status == OK_STATUS:
-                if processed := result.get("processed"):
-                    processed_cb.relative_update(processed)
-                yield msgpack_unpack(result["result"])
+                n_workers -= 1  # Worker finished
             else:  # Failed / error
                 n_workers -= 1
                 if exc := result.get("exception"):
                     raise exc
                 raise RuntimeError("Internal error: Parallel UDF execution failed")

-            if status == OK_STATUS and not
+            if status == OK_STATUS and not input_finished:
                 try:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
@@ -311,11 +284,14 @@ class ProcessedCallback(Callback):

 @attrs.define
 class UDFWorker:
-    catalog: Catalog
-    udf: UDFAdapter
+    catalog: "Catalog"
+    udf: "UDFAdapter"
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
+    query: "Select"
+    table: "Table"
     is_generator: bool
+    is_batching: bool
     cache: bool
     udf_fields: Sequence[str]
     cb: Callback = attrs.field()
@@ -326,30 +302,54 @@ class UDFWorker:

     def run(self) -> None:
         processed_cb = ProcessedCallback()
+        generated_cb = get_generated_callback(self.is_generator)
+
         udf_results = self.udf.run(
             self.udf_fields,
-
+            self.get_inputs(),
             self.catalog,
-            self.is_generator,
             self.cache,
             download_cb=self.cb,
             processed_cb=processed_cb,
         )
-
-
-
-
-
-
-
-
-
+        process_udf_outputs(
+            self.catalog.warehouse,
+            self.table,
+            self.notify_and_process(udf_results, processed_cb),
+            self.udf,
+            cb=generated_cb,
+        )
+
+        put_into_queue(
+            self.done_queue,
+            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
+        )
+
+    def notify_and_process(self, udf_results, processed_cb):
+        for row in udf_results:
             put_into_queue(
                 self.done_queue,
-                {"status":
+                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
             )
-
+            yield row

     def get_inputs(self):
-
-
+        warehouse = self.catalog.warehouse.clone()
+        col_id = get_query_id_column(self.query)
+
+        if self.is_batching:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                ids = [row[0] for row in batch.rows]
+                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
+                yield RowsOutputBatch(list(rows))
+        else:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                yield from warehouse.dataset_rows_select(
+                    self.query.where(col_id.in_(batch))
+                )
+
+
+class RepeatTimer(Timer):
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
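The new RepeatTimer subclass overrides Timer.run so the callback fires on every interval rather than once. A minimal usage sketch, assuming it is meant to be started and cancelled like a regular threading.Timer:

import time

from datachain.query.dispatch import RepeatTimer

# Fires print("heartbeat") roughly every 0.5s in a background thread
# until cancel() sets the timer's finished event and the loop exits.
heartbeat = RepeatTimer(0.5, print, args=("heartbeat",))
heartbeat.start()
time.sleep(2)
heartbeat.cancel()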
datachain/query/session.py
CHANGED
@@ -69,7 +69,7 @@ class Session:
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
-        self.dataset_versions: list[tuple[DatasetRecord, int]] = []
+        self.dataset_versions: list[tuple[DatasetRecord, int, bool]] = []

     def __enter__(self):
         # Push the current context onto the stack
@@ -89,8 +89,10 @@ class Session:
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()

-    def add_dataset_version(
-        self
+    def add_dataset_version(
+        self, dataset: "DatasetRecord", version: int, listing: bool = False
+    ) -> None:
+        self.dataset_versions.append((dataset, version, listing))

     def generate_temp_dataset_name(self) -> str:
         return self.get_temp_prefix() + uuid4().hex[: self.TEMP_TABLE_UUID_LEN]
@@ -111,8 +113,9 @@ class Session:
         if not self.dataset_versions:
             return

-        for dataset, version in self.dataset_versions:
-
+        for dataset, version, listing in self.dataset_versions:
+            if not listing:
+                self.catalog.remove_dataset_version(dataset, version)

         self.dataset_versions.clear()
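Session now records a listing flag alongside each temporary dataset version, and cleanup skips versions marked as listings. A self-contained sketch of that rule using plain tuples (the names are illustrative, not the real API):

# (dataset_name, version, listing) triples, mirroring Session.dataset_versions
dataset_versions = [("temp_ds", 1, False), ("lst__example_bucket", 1, True)]

removed = [(d, v) for d, v, listing in dataset_versions if not listing]
kept = [(d, v) for d, v, listing in dataset_versions if listing]
print(removed)  # [('temp_ds', 1)] -- only non-listing versions are cleaned up
print(kept)     # [('lst__example_bucket', 1)]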
datachain/query/udf.py
ADDED
@@ -0,0 +1,20 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.query.batch import BatchingStrategy
+
+
+class UdfInfo(TypedDict):
+    udf_data: bytes
+    catalog_init: dict[str, Any]
+    metastore_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    warehouse_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]]
+    table: "Table"
+    query: "Select"
+    udf_fields: list[str]
+    batching: "BatchingStrategy"
+    processes: Optional[int]
+    is_generator: bool
+    cache: bool
datachain/query/utils.py
ADDED
@@ -0,0 +1,42 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+from sqlalchemy import Column
+
+if TYPE_CHECKING:
+    from sqlalchemy import ColumnElement, Select, TextClause
+
+
+ColT = Union[Column, "ColumnElement", "TextClause"]
+
+
+def column_name(col: ColT) -> str:
+    """Returns column name from column element."""
+    return col.name if isinstance(col, Column) else str(col)
+
+
+def get_query_column(query: "Select", name: str) -> Optional[ColT]:
+    """Returns column element from query by name or None if column not found."""
+    return next((col for col in query.inner_columns if column_name(col) == name), None)
+
+
+def get_query_id_column(query: "Select") -> ColT:
+    """Returns ID column element from query or None if column not found."""
+    col = get_query_column(query, "sys__id")
+    if col is None:
+        raise RuntimeError("sys__id column not found in query")
+    return col
+
+
+def select_only_columns(query: "Select", *names: str) -> "Select":
+    """Returns query selecting defined columns only."""
+    if not names:
+        return query
+
+    cols: list[ColT] = []
+    for name in names:
+        col = get_query_column(query, name)
+        if col is None:
+            raise ValueError(f"Column '{name}' not found in query")
+        cols.append(col)
+
+    return query.with_only_columns(*cols)
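The new helpers look up columns on a SQLAlchemy Select by name. A small usage sketch against a throwaway table (the table and columns are made up for illustration; only the datachain imports come from the module added above):

import sqlalchemy as sa

from datachain.query.utils import get_query_id_column, select_only_columns

metadata = sa.MetaData()
rows = sa.Table(
    "rows",
    metadata,
    sa.Column("sys__id", sa.Integer),
    sa.Column("path", sa.String),
)
query = sa.select(rows)

id_col = get_query_id_column(query)            # the sys__id Column element
narrowed = select_only_columns(query, "path")  # SELECT rows.path FROM rows
print(narrowed)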
{datachain-0.8.0.dist-info → datachain-0.8.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: datachain
-Version: 0.8.0
+Version: 0.8.1
 Summary: Wrangle unstructured AI data at scale
 Author-email: Dmitry Petrov <support@dvc.org>
 License: Apache-2.0
@@ -84,7 +84,7 @@ Requires-Dist: requests-mock; extra == "tests"
 Requires-Dist: scipy; extra == "tests"
 Provides-Extra: dev
 Requires-Dist: datachain[docs,tests]; extra == "dev"
-Requires-Dist: mypy==1.
+Requires-Dist: mypy==1.14.0; extra == "dev"
 Requires-Dist: types-python-dateutil; extra == "dev"
 Requires-Dist: types-pytz; extra == "dev"
 Requires-Dist: types-PyYAML; extra == "dev"
@@ -99,7 +99,7 @@ Requires-Dist: unstructured[pdf]; extra == "examples"
 Requires-Dist: pdfplumber==0.11.4; extra == "examples"
 Requires-Dist: huggingface_hub[hf_transfer]; extra == "examples"
 Requires-Dist: onnx==1.16.1; extra == "examples"
-Requires-Dist: ultralytics==8.3.
+Requires-Dist: ultralytics==8.3.53; extra == "examples"

 ================
 |logo| DataChain