datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py
CHANGED
@@ -1,23 +1,24 @@
 import contextlib
+import traceback
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
-from
-from typing import TYPE_CHECKING,
+from time import monotonic, sleep
+from typing import TYPE_CHECKING, Literal

-import attrs
 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import
-from
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue

 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
-from datachain.catalog.loader import
-from datachain.lib.
-from datachain.
+from datachain.catalog.loader import DISTRIBUTED_IMPORT_PATH, get_udf_distributor_class
+from datachain.lib.model_store import ModelStore
+from datachain.lib.udf import UdfRunError, _get_cache
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -26,7 +27,6 @@ from datachain.query.dataset import (
 )
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
-from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
@@ -34,6 +34,7 @@ if TYPE_CHECKING:

     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.lib.udf import UDFAdapter
+    from datachain.query.batch import RowsOutput

 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -43,7 +44,7 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"


-def get_n_workers_from_arg(n_workers:
+def get_n_workers_from_arg(n_workers: int | None = None) -> int:
     if not n_workers:
         return cpu_count()
     if n_workers < 1:
@@ -52,55 +53,60 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:


 def udf_entrypoint() -> int:
+    """Parallel processing (faster for more CPU-heavy UDFs)."""
     # Load UDF info from stdin
     udf_info: UdfInfo = load(stdin.buffer)

-    # Parallel processing (faster for more CPU-heavy UDFs)
-    dispatch = UDFDispatcher(udf_info)
-
     query = udf_info["query"]
+    if "sys__id" not in query.selected_columns:
+        raise RuntimeError("sys__id column is required in UDF query")
+
     batching = udf_info["batching"]
-
-
-
+    is_generator = udf_info["is_generator"]
+
+    download_cb = get_download_callback()
+    processed_cb = get_processed_callback()
+    generated_cb = get_generated_callback(is_generator)

     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)

-    total_rows = next(
-        warehouse.db.execute(
-            query.with_only_columns(func.count(query.c.sys__id)).order_by(None)
-        )
-    )[0]
-
     with contextlib.closing(
-        batching(
+        batching(
+            warehouse.dataset_select_paginated,
+            query,
+            id_col=query.selected_columns.sys__id,
+        )
     ) as udf_inputs:
-        download_cb = get_download_callback()
-        processed_cb = get_processed_callback()
         try:
-
+            UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-                total_rows=total_rows,
-                n_workers=n_workers,
-                processed_cb=processed_cb,
                 download_cb=download_cb,
+                processed_cb=processed_cb,
+                generated_cb=generated_cb,
             )
         finally:
             download_cb.close()
             processed_cb.close()
+            generated_cb.close()

     return 0


 def udf_worker_entrypoint() -> int:
-
+    if not (udf_distributor_class := get_udf_distributor_class()):
+        raise RuntimeError(
+            f"{DISTRIBUTED_IMPORT_PATH} import path is required "
+            "for distributed UDF processing."
+        )
+
+    return udf_distributor_class.run_udf()


 class UDFDispatcher:
-
-    task_queue:
-    done_queue:
+    _catalog: Catalog | None = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None

     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -113,30 +119,38 @@ class UDFDispatcher:
         self.cache = udf_info["cache"]
         self.is_generator = udf_info["is_generator"]
         self.is_batching = udf_info["batching"].is_batching
+        self.processes = udf_info["processes"]
+        self.rows_total = udf_info["rows_total"]
+        self.batch_size = udf_info["batch_size"]
         self.buffer_size = buffer_size
-        self.catalog = None
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")

-
-
+    @property
+    def catalog(self) -> "Catalog":
+        if not self._catalog:
             ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
             metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
             ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
             warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
-            self.
-
+            self._catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
+        return self._catalog
+
+    def _create_worker(self) -> "UDFWorker":
+        udf: UDFAdapter = loads(self.udf_data)
+        # Ensure all registered DataModels have rebuilt schemas in worker processes.
+        ModelStore.rebuild_all()
         return UDFWorker(
             self.catalog,
-
+            udf,
             self.task_queue,
             self.done_queue,
             self.query,
             self.table,
-            self.is_generator,
-            self.is_batching,
             self.cache,
+            self.is_batching,
+            self.batch_size,
             self.udf_fields,
         )

@@ -146,45 +160,109 @@ class UDFDispatcher:
             worker.run()
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
+                # We put the exception into the done queue so the main process
+                # can handle it appropriately. We include the stacktrace to propagate
+                # it to the main process and show it to the user.
                 put_into_queue(
                     self.done_queue,
-                    {
+                    {
+                        "status": FAILED_STATUS,
+                        "exception": e,
+                        "stacktrace": traceback.format_exc(),
+                    },
                 )
+            if isinstance(e, KeyboardInterrupt):
+                return
             raise

-
-
+    def run_udf(
+        self,
+        input_rows: Iterable["RowsOutput"],
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
+        n_workers = self.processes
+        if n_workers is True:
+            n_workers = None  # Use default number of CPUs (cores)
+        elif not n_workers or n_workers < 1:
+            n_workers = 1  # Single-threaded (on this worker)
         n_workers = get_n_workers_from_arg(n_workers)
-        for _ in range(n_workers):
-            put_into_queue(task_queue, STOP_SIGNAL)

-
-
+        if n_workers == 1:
+            # no need to spawn worker processes if we are running in a single process
+            self.run_udf_single(input_rows, download_cb, processed_cb, generated_cb)
+        else:
+            if self.buffer_size < n_workers:
+                raise RuntimeError(
+                    "Parallel run error: buffer size is smaller than "
+                    f"number of workers: {self.buffer_size} < {n_workers}"
+                )

-
+            self.run_udf_parallel(
+                n_workers, input_rows, download_cb, processed_cb, generated_cb
+            )
+
+    def run_udf_single(
         self,
-        input_rows: Iterable[RowsOutput],
-        total_rows: int,
-        n_workers: Optional[int] = None,
-        processed_cb: Callback = DEFAULT_CALLBACK,
+        input_rows: Iterable["RowsOutput"],
         download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
     ) -> None:
-
+        udf: UDFAdapter = loads(self.udf_data)
+        # Rebuild schemas in single process too for consistency (cheap, idempotent).
+        ModelStore.rebuild_all()
+
+        if not self.is_batching:
+            input_rows = flatten(input_rows)
+
+        def get_inputs() -> Iterable["RowsOutput"]:
+            warehouse = self.catalog.warehouse.clone()
+            for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
                )

-
+        prefetch = udf.prefetch
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            udf_results = udf.run(
+                self.udf_fields,
+                get_inputs(),
+                self.catalog,
+                self.cache,
+                download_cb=download_cb,
+                processed_cb=processed_cb,
+            )
+            with safe_closing(udf_results):
+                process_udf_outputs(
+                    self.catalog.warehouse.clone(),
+                    self.table,
+                    udf_results,
+                    udf,
+                    cb=generated_cb,
+                    batch_size=self.batch_size,
+                )
+
+    def input_batch_size(self, n_workers: int) -> int:
+        input_batch_size = self.rows_total // n_workers
         if input_batch_size == 0:
             input_batch_size = 1
         elif input_batch_size > DEFAULT_BATCH_SIZE:
             input_batch_size = DEFAULT_BATCH_SIZE
+        return input_batch_size

-
-
-
-
-
-
+    def run_udf_parallel(  # noqa: C901, PLR0912
+        self,
+        n_workers: int,
+        input_rows: Iterable["RowsOutput"],
+        download_cb: Callback = DEFAULT_CALLBACK,
+        processed_cb: Callback = DEFAULT_CALLBACK,
+        generated_cb: Callback = DEFAULT_CALLBACK,
+    ) -> None:
         self.task_queue = self.ctx.Queue()
         self.done_queue = self.ctx.Queue()
+
         pool = [
             self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
             for i in range(n_workers)
@@ -192,14 +270,14 @@ class UDFDispatcher:
         for p in pool:
             p.start()

-        # Will be set to True if all tasks complete normally
-        normal_completion = False
         try:
             # Will be set to True when the input is exhausted
             input_finished = False

-
-            input_rows
+            input_rows = batched(
+                input_rows if self.is_batching else flatten(input_rows),
+                self.input_batch_size(n_workers),
+            )

             # Stop all workers after the input rows have finished processing
             input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
@@ -214,12 +292,29 @@

             # Process all tasks
             while n_workers > 0:
-
-
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)
+
+                if bytes_downloaded := result.get("bytes_downloaded"):
+                    download_cb.relative_update(bytes_downloaded)
                 if downloaded := result.get("downloaded"):
-                    download_cb.
+                    download_cb.increment_file_count(downloaded)
                 if processed := result.get("processed"):
                     processed_cb.relative_update(processed)
+                if generated := result.get("generated"):
+                    generated_cb.relative_update(generated)

                 status = result["status"]
                 if status in (OK_STATUS, NOTIFY_STATUS):
@@ -229,7 +324,9 @@
                 else:  # Failed / error
                     n_workers -= 1
                     if exc := result.get("exception"):
-
+                        if isinstance(exc, KeyboardInterrupt):
+                            raise exc
+                        raise UdfRunError(exc, stacktrace=result.get("stacktrace"))
                     raise RuntimeError("Internal error: Parallel UDF execution failed")

                 if status == OK_STATUS and not input_finished:
@@ -237,75 +334,104 @@
                         put_into_queue(self.task_queue, next(input_data))
                     except StopIteration:
                         input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return
+
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()
+
+
+class DownloadCallback(Callback):
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()

     def relative_update(self, inc: int = 1) -> None:
+        put_into_queue(self.queue, {"status": NOTIFY_STATUS, "bytes_downloaded": inc})
+
+    def increment_file_count(self, inc: int = 1) -> None:
         put_into_queue(self.queue, {"status": NOTIFY_STATUS, "downloaded": inc})


 class ProcessedCallback(Callback):
-    def __init__(
-        self
+    def __init__(
+        self,
+        name: Literal["processed", "generated"],
+        queue: MultiprocessQueue,
+    ) -> None:
+        self.name = name
+        self.queue = queue
         super().__init__()

     def relative_update(self, inc: int = 1) -> None:
-        self.
+        put_into_queue(self.queue, {"status": NOTIFY_STATUS, self.name: inc})


-@attrs.define
 class UDFWorker:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        catalog: "Catalog",
+        udf: "UDFAdapter",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
+        query: "Select",
+        table: "Table",
+        cache: bool,
+        is_batching: bool,
+        batch_size: int,
+        udf_fields: Sequence[str],
+    ) -> None:
+        self.catalog = catalog
+        self.udf = udf
+        self.task_queue = task_queue
+        self.done_queue = done_queue
+        self.query = query
+        self.table = table
+        self.cache = cache
+        self.is_batching = is_batching
+        self.batch_size = batch_size
+        self.udf_fields = udf_fields
+
+        self.download_cb = DownloadCallback(self.done_queue)
+        self.processed_cb = ProcessedCallback("processed", self.done_queue)
+        self.generated_cb = ProcessedCallback("generated", self.done_queue)

     def run(self) -> None:
-        processed_cb = ProcessedCallback()
-        generated_cb = get_generated_callback(self.is_generator)
-
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
@@ -314,48 +440,29 @@
                 self.get_inputs(),
                 catalog,
                 self.cache,
-                download_cb=self.
-                processed_cb=processed_cb,
+                download_cb=self.download_cb,
+                processed_cb=self.processed_cb,
             )
             with safe_closing(udf_results):
                 process_udf_outputs(
                     catalog.warehouse,
                     self.table,
-                    self.notify_and_process(udf_results
+                    self.notify_and_process(udf_results),
                     self.udf,
-                    cb=generated_cb,
+                    cb=self.generated_cb,
+                    batch_size=self.batch_size,
                 )
+        put_into_queue(self.done_queue, {"status": FINISHED_STATUS})

-
-            self.done_queue,
-            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
-        )
-
-    def notify_and_process(self, udf_results, processed_cb):
+    def notify_and_process(self, udf_results):
         for row in udf_results:
-            put_into_queue(
-                self.done_queue,
-                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
-            )
+            put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row

-    def get_inputs(self):
+    def get_inputs(self) -> Iterable["RowsOutput"]:
         warehouse = self.catalog.warehouse.clone()
-
-
-
-
-                ids = [row[0] for row in batch.rows]
-                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
-                yield RowsOutputBatch(list(rows))
-        else:
-            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-                yield from warehouse.dataset_rows_select(
-                    self.query.where(col_id.in_(batch))
+        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+            for ids in batched(batch, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
                 )
-
-
-class RepeatTimer(Timer):
-    def run(self):
-        while not self.finished.wait(self.interval):
-            self.function(*self.args, **self.kwargs)
datachain/query/metrics.py
CHANGED
@@ -1,10 +1,9 @@
 import os
-from typing import Optional, Union

-metrics: dict[str,
+metrics: dict[str, str | int | float | bool | None] = {}


-def set(key: str, value:
+def set(key: str, value: str | int | float | bool | None) -> None:  # noqa: PYI041
     """Set a metric value."""
     if not isinstance(key, str):
         raise TypeError("Key must be a string")
@@ -15,13 +14,12 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None: # noqa: P
     metrics[key] = value

     if job_id := os.getenv("DATACHAIN_JOB_ID"):
-        from datachain.data_storage.job import JobStatus
         from datachain.query.session import Session

         metastore = Session.get().catalog.metastore
-        metastore.
+        metastore.update_job(job_id, metrics=metrics)


-def get(key: str) ->
+def get(key: str) -> str | int | float | bool | None:
     """Get a metric value."""
     return metrics[key]
datachain/query/params.py
CHANGED
@@ -1,11 +1,10 @@
 import json
 import os
-from typing import Optional

-params_cache:
+params_cache: dict[str, str] | None = None


-def param(key: str, default:
+def param(key: str, default: str | None = None) -> str | None:
     """Get query parameter."""
     if not isinstance(key, str):
         raise TypeError("Param key must be a string")