datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py
CHANGED

@@ -1,20 +1,24 @@
 import contextlib
+import traceback
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
-from
+from time import monotonic, sleep
+from typing import TYPE_CHECKING, Literal

 import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue

 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.catalog.loader import DISTRIBUTED_IMPORT_PATH, get_udf_distributor_class
 from datachain.lib.model_store import ModelStore
-from datachain.lib.udf import _get_cache
+from datachain.lib.udf import UdfRunError, _get_cache
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -23,7 +27,6 @@ from datachain.query.dataset import (
 )
 from datachain.query.queue import get_from_queue, put_into_queue
 from datachain.query.udf import UdfInfo
-from datachain.query.utils import get_query_id_column
 from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
@@ -41,7 +44,7 @@ FAILED_STATUS = "FAILED"
 NOTIFY_STATUS = "NOTIFY"


-def get_n_workers_from_arg(n_workers:
+def get_n_workers_from_arg(n_workers: int | None = None) -> int:
     if not n_workers:
         return cpu_count()
     if n_workers < 1:
@@ -55,6 +58,9 @@ def udf_entrypoint() -> int:
     udf_info: UdfInfo = load(stdin.buffer)

     query = udf_info["query"]
+    if "sys__id" not in query.selected_columns:
+        raise RuntimeError("sys__id column is required in UDF query")
+
     batching = udf_info["batching"]
     is_generator = udf_info["is_generator"]

@@ -65,15 +71,16 @@ def udf_entrypoint() -> int:
     wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
     warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)

-    id_col = get_query_id_column(query)
-
     with contextlib.closing(
-        batching(
+        batching(
+            warehouse.dataset_select_paginated,
+            query,
+            id_col=query.selected_columns.sys__id,
+        )
     ) as udf_inputs:
         try:
             UDFDispatcher(udf_info).run_udf(
                 udf_inputs,
-                ids_only=id_col is not None,
                 download_cb=download_cb,
                 processed_cb=processed_cb,
                 generated_cb=generated_cb,
@@ -86,20 +93,20 @@ def udf_entrypoint() -> int:
     return 0


-def udf_worker_entrypoint(
+def udf_worker_entrypoint() -> int:
     if not (udf_distributor_class := get_udf_distributor_class()):
         raise RuntimeError(
             f"{DISTRIBUTED_IMPORT_PATH} import path is required "
             "for distributed UDF processing."
         )

-    return udf_distributor_class.run_udf(
+    return udf_distributor_class.run_udf()


 class UDFDispatcher:
-    _catalog:
-    task_queue:
-    done_queue:
+    _catalog: Catalog | None = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None

     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -114,10 +121,11 @@ class UDFDispatcher:
         self.is_batching = udf_info["batching"].is_batching
         self.processes = udf_info["processes"]
         self.rows_total = udf_info["rows_total"]
+        self.batch_size = udf_info["batch_size"]
         self.buffer_size = buffer_size
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")

     @property
     def catalog(self) -> "Catalog":
@@ -142,18 +150,26 @@ class UDFDispatcher:
             self.table,
             self.cache,
             self.is_batching,
+            self.batch_size,
             self.udf_fields,
         )

-    def _run_worker(self
+    def _run_worker(self) -> None:
         try:
             worker = self._create_worker()
-            worker.run(
+            worker.run()
         except (Exception, KeyboardInterrupt) as e:
             if self.done_queue:
+                # We put the exception into the done queue so the main process
+                # can handle it appropriately. We include the stacktrace to propagate
+                # it to the main process and show it to the user.
                 put_into_queue(
                     self.done_queue,
-                    {
+                    {
+                        "status": FAILED_STATUS,
+                        "exception": e,
+                        "stacktrace": traceback.format_exc(),
+                    },
                 )
             if isinstance(e, KeyboardInterrupt):
                 return
@@ -162,7 +178,6 @@ class UDFDispatcher:
     def run_udf(
         self,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -176,9 +191,7 @@ class UDFDispatcher:

         if n_workers == 1:
             # no need to spawn worker processes if we are running in a single process
-            self.run_udf_single(
-                input_rows, ids_only, download_cb, processed_cb, generated_cb
-            )
+            self.run_udf_single(input_rows, download_cb, processed_cb, generated_cb)
         else:
             if self.buffer_size < n_workers:
                 raise RuntimeError(
@@ -187,13 +200,12 @@
                 )

             self.run_udf_parallel(
-                n_workers, input_rows,
+                n_workers, input_rows, download_cb, processed_cb, generated_cb
             )

     def run_udf_single(
         self,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -202,18 +214,15 @@
         # Rebuild schemas in single process too for consistency (cheap, idempotent).
         ModelStore.rebuild_all()

-        if
+        if not self.is_batching:
             input_rows = flatten(input_rows)

         def get_inputs() -> Iterable["RowsOutput"]:
             warehouse = self.catalog.warehouse.clone()
-            [4 removed lines not captured in this view]
-                )
-            else:
-                yield from input_rows
+            for ids in batched(input_rows, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )

         prefetch = udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
@@ -232,6 +241,7 @@ class UDFDispatcher:
                 udf_results,
                 udf,
                 cb=generated_cb,
+                batch_size=self.batch_size,
             )

     def input_batch_size(self, n_workers: int) -> int:
@@ -246,7 +256,6 @@
         self,
         n_workers: int,
         input_rows: Iterable["RowsOutput"],
-        ids_only: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
         generated_cb: Callback = DEFAULT_CALLBACK,
@@ -255,16 +264,12 @@
         self.done_queue = self.ctx.Queue()

         pool = [
-            self.ctx.Process(
-                name=f"Worker-UDF-{i}", target=self._run_worker, args=[ids_only]
-            )
+            self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker)
             for i in range(n_workers)
         ]
         for p in pool:
             p.start()

-        # Will be set to True if all tasks complete normally
-        normal_completion = False
         try:
             # Will be set to True when the input is exhausted
             input_finished = False
@@ -287,10 +292,20 @@

             # Process all tasks
             while n_workers > 0:
-                [4 removed lines not captured in this view]
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)

                 if bytes_downloaded := result.get("bytes_downloaded"):
                     download_cb.relative_update(bytes_downloaded)
@@ -309,7 +324,9 @@
                 else: # Failed / error
                     n_workers -= 1
                     if exc := result.get("exception"):
-                        [1 removed line not captured in this view]
+                        if isinstance(exc, KeyboardInterrupt):
+                            raise exc
+                        raise UdfRunError(exc, stacktrace=result.get("stacktrace"))
                     raise RuntimeError("Internal error: Parallel UDF execution failed")

                 if status == OK_STATUS and not input_finished:
@@ -317,39 +334,50 @@
                         put_into_queue(self.task_queue, next(input_data))
                     except StopIteration:
                         input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-            [21 removed lines not captured in this view]
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return

-        [3 removed lines not captured in this view]
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()


 class DownloadCallback(Callback):
-    def __init__(self, queue:
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()

@@ -364,7 +392,7 @@ class ProcessedCallback(Callback):
     def __init__(
         self,
         name: Literal["processed", "generated"],
-        queue:
+        queue: MultiprocessQueue,
     ) -> None:
         self.name = name
         self.queue = queue
@@ -379,12 +407,13 @@ class UDFWorker:
         self,
         catalog: "Catalog",
         udf: "UDFAdapter",
-        task_queue:
-        done_queue:
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
         query: "Select",
         table: "Table",
         cache: bool,
         is_batching: bool,
+        batch_size: int,
         udf_fields: Sequence[str],
     ) -> None:
         self.catalog = catalog
@@ -395,19 +424,20 @@
         self.table = table
         self.cache = cache
         self.is_batching = is_batching
+        self.batch_size = batch_size
         self.udf_fields = udf_fields

         self.download_cb = DownloadCallback(self.done_queue)
         self.processed_cb = ProcessedCallback("processed", self.done_queue)
         self.generated_cb = ProcessedCallback("generated", self.done_queue)

-    def run(self
+    def run(self) -> None:
         prefetch = self.udf.prefetch
         with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
             catalog = clone_catalog_with_cache(self.catalog, _cache)
             udf_results = self.udf.run(
                 self.udf_fields,
-                self.get_inputs(
+                self.get_inputs(),
                 catalog,
                 self.cache,
                 download_cb=self.download_cb,
@@ -420,6 +450,7 @@
             self.notify_and_process(udf_results),
             self.udf,
             cb=self.generated_cb,
+            batch_size=self.batch_size,
         )
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})

@@ -428,13 +459,10 @@
             put_into_queue(self.done_queue, {"status": OK_STATUS})
             yield row

-    def get_inputs(self
+    def get_inputs(self) -> Iterable["RowsOutput"]:
        warehouse = self.catalog.warehouse.clone()
        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            [4 removed lines not captured in this view]
-                )
-            else:
-                yield from batch
+            for ids in batched(batch, DEFAULT_BATCH_SIZE):
+                yield from warehouse.dataset_rows_select_from_ids(
+                    self.query, ids, self.is_batching
+                )
datachain/query/metrics.py
CHANGED

@@ -1,10 +1,9 @@
 import os
-from typing import Optional, Union

-metrics: dict[str,
+metrics: dict[str, str | int | float | bool | None] = {}


-def set(key: str, value:
+def set(key: str, value: str | int | float | bool | None) -> None:  # noqa: PYI041
     """Set a metric value."""
     if not isinstance(key, str):
         raise TypeError("Key must be a string")
@@ -21,6 +20,6 @@ def set(key: str, value: Union[str, int, float, bool, None]) -> None:  # noqa: P
     metastore.update_job(job_id, metrics=metrics)


-def get(key: str) ->
+def get(key: str) -> str | int | float | bool | None:
     """Get a metric value."""
     return metrics[key]
datachain/query/params.py
CHANGED

@@ -1,11 +1,10 @@
 import json
 import os
-from typing import Optional

-params_cache:
+params_cache: dict[str, str] | None = None


-def param(key: str, default:
+def param(key: str, default: str | None = None) -> str | None:
     """Get query parameter."""
     if not isinstance(key, str):
         raise TypeError("Param key must be a string")
datachain/query/queue.py
CHANGED

@@ -1,11 +1,12 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any

 import msgpack
+from multiprocess.queues import Queue

 from datachain.query.batch import RowsOutput

datachain/query/schema.py
CHANGED

@@ -1,7 +1,8 @@
 import functools
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any

 import attrs
 import sqlalchemy as sa
@@ -42,7 +43,7 @@ class ColumnMeta(type):


 class Column(sa.ColumnClause, metaclass=ColumnMeta):
-    inherit_cache:
+    inherit_cache: bool | None = True

     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Dataset column."""
@@ -177,7 +178,7 @@ class LocalFilename(UDFParameter):
     otherwise None will be returned.
     """

-    glob:
+    glob: str | None = None

     def get_value(
         self,
@@ -186,7 +187,7 @@ class LocalFilename(UDFParameter):
         *,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) ->
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -205,7 +206,7 @@ class LocalFilename(UDFParameter):
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
         **kwargs,
-    ) ->
+    ) -> str | None:
         if self.glob and not fnmatch(row["name"], self.glob):  # type: ignore[type-var]
             # If the glob pattern is specified and the row filename
             # does not match it, then return None
@@ -216,7 +217,7 @@ class LocalFilename(UDFParameter):
         return client.cache.get_path(file)


-UDFParamSpec =
+UDFParamSpec = str | Column | UDFParameter


 def normalize_param(param: UDFParamSpec) -> UDFParameter: