datachain 0.2.18__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/cache.py +5 -10
- datachain/catalog/catalog.py +10 -20
- datachain/client/azure.py +5 -12
- datachain/client/fsspec.py +6 -10
- datachain/client/gcs.py +4 -14
- datachain/client/local.py +4 -11
- datachain/client/s3.py +4 -8
- datachain/data_storage/schema.py +7 -15
- datachain/data_storage/warehouse.py +34 -45
- datachain/lib/dc.py +8 -6
- datachain/lib/file.py +19 -18
- datachain/lib/udf.py +21 -14
- datachain/lib/webdataset.py +2 -3
- datachain/listing.py +14 -20
- datachain/node.py +32 -21
- datachain/query/batch.py +45 -41
- datachain/query/builtins.py +5 -12
- datachain/query/dataset.py +15 -8
- datachain/query/dispatch.py +53 -68
- datachain/query/queue.py +120 -0
- datachain/query/schema.py +3 -7
- datachain/query/udf.py +23 -8
- datachain/utils.py +17 -2
- {datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/METADATA +1 -1
- {datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/RECORD +29 -28
- {datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/LICENSE +0 -0
- {datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/WHEEL +0 -0
- {datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py
CHANGED
@@ -2,11 +2,8 @@ import contextlib
 from collections.abc import Iterator, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
-from queue import Empty, Full, Queue
 from sys import stdin
-from time import sleep
-from types import GeneratorType
-from typing import Any, Optional
+from typing import Optional
 
 import attrs
 import multiprocess
@@ -22,7 +19,16 @@ from datachain.query.dataset import (
     get_processed_callback,
     process_udf_outputs,
 )
+from datachain.query.queue import (
+    get_from_queue,
+    marshal,
+    msgpack_pack,
+    msgpack_unpack,
+    put_into_queue,
+    unmarshal,
+)
 from datachain.query.udf import UDFBase, UDFFactory, UDFResult
+from datachain.utils import batched_it
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -44,44 +50,6 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     return n_workers
 
 
-# For more context on the get_from_queue and put_into_queue functions, see the
-# discussion here:
-# https://github.com/iterative/dvcx/pull/1297#issuecomment-2026308773
-# This problem is not exactly described by, but is also related to these Python issues:
-# https://github.com/python/cpython/issues/66587
-# https://github.com/python/cpython/issues/88628
-# https://github.com/python/cpython/issues/108645
-
-
-def get_from_queue(queue: Queue) -> Any:
-    """
-    Gets an item from a queue.
-    This is required to handle signals, such as KeyboardInterrupt exceptions
-    while waiting for items to be available, although only on certain installations.
-    (See the above comment for more context.)
-    """
-    while True:
-        try:
-            return queue.get_nowait()
-        except Empty:
-            sleep(0.01)
-
-
-def put_into_queue(queue: Queue, item: Any) -> None:
-    """
-    Puts an item into a queue.
-    This is required to handle signals, such as KeyboardInterrupt exceptions
-    while waiting for items to be queued, although only on certain installations.
-    (See the above comment for more context.)
-    """
-    while True:
-        try:
-            queue.put_nowait(item)
-            return
-        except Full:
-            sleep(0.01)
-
-
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
     udf_info = load(stdin.buffer)
@@ -100,8 +68,9 @@
         udf_info["id_generator_clone_params"],
         udf_info["metastore_clone_params"],
         udf_info["warehouse_clone_params"],
-
+        udf_fields=udf_info["udf_fields"],
         cache=udf_info["cache"],
+        is_generator=udf_info.get("is_generator", False),
     )
 
     query = udf_info["query"]
@@ -121,7 +90,7 @@
     generated_cb = get_generated_callback(dispatch.is_generator)
     try:
         udf_results = dispatch.run_udf_parallel(
-            udf_inputs,
+            marshal(udf_inputs),
             n_workers=n_workers,
             processed_cb=processed_cb,
             download_cb=download_cb,
@@ -142,6 +111,9 @@ def udf_worker_entrypoint() -> int:
 
 
 class UDFDispatcher:
+    catalog: Optional[Catalog] = None
+    task_queue: Optional[multiprocess.Queue] = None
+    done_queue: Optional[multiprocess.Queue] = None
     _batch_size: Optional[int] = None
 
     def __init__(
@@ -151,9 +123,10 @@ class UDFDispatcher:
         id_generator_clone_params,
         metastore_clone_params,
         warehouse_clone_params,
-
-
-
+        udf_fields: "Sequence[str]",
+        cache: bool,
+        is_generator: bool = False,
+        buffer_size: int = DEFAULT_BATCH_SIZE,
     ):
         self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
@@ -172,12 +145,13 @@ class UDFDispatcher:
             self.warehouse_args,
             self.warehouse_kwargs,
         ) = warehouse_clone_params
-        self.
+        self.udf_fields = udf_fields
         self.cache = cache
+        self.is_generator = is_generator
+        self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
         self.done_queue = None
-        self.buffer_size = buffer_size
         self.ctx = get_context("spawn")
 
     @property
@@ -226,6 +200,7 @@ class UDFDispatcher:
             self.done_queue,
             self.is_generator,
             self.cache,
+            self.udf_fields,
         )
 
     def _run_worker(self) -> None:
@@ -233,7 +208,11 @@
             worker = self._create_worker()
             worker.run()
         except (Exception, KeyboardInterrupt) as e:
-
+            if self.done_queue:
+                put_into_queue(
+                    self.done_queue,
+                    {"status": FAILED_STATUS, "exception": e},
+                )
             raise
 
     @staticmethod
@@ -249,7 +228,6 @@
         self,
         input_rows,
         n_workers: Optional[int] = None,
-        cache: bool = False,
        input_queue=None,
        processed_cb: Callback = DEFAULT_CALLBACK,
        download_cb: Callback = DEFAULT_CALLBACK,
@@ -299,21 +277,24 @@
            result = get_from_queue(self.done_queue)
            status = result["status"]
            if status == NOTIFY_STATUS:
-
+                if downloaded := result.get("downloaded"):
+                    download_cb.relative_update(downloaded)
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
            elif status == FINISHED_STATUS:
                # Worker finished
                n_workers -= 1
            elif status == OK_STATUS:
-
-
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
+                yield msgpack_unpack(result["result"])
            else:  # Failed / error
                n_workers -= 1
-                exc = result.get("exception")
-                if exc:
+                if exc := result.get("exception"):
                    raise exc
                raise RuntimeError("Internal error: Parallel UDF execution failed")
 
-            if not streaming_mode and not input_finished:
+            if status == OK_STATUS and not streaming_mode and not input_finished:
                try:
                    put_into_queue(self.task_queue, next(input_data))
                except StopIteration:
@@ -348,7 +329,7 @@
 
 
 class WorkerCallback(Callback):
-    def __init__(self, queue: multiprocess.Queue):
+    def __init__(self, queue: "multiprocess.Queue"):
        self.queue = queue
        super().__init__()
 
@@ -369,10 +350,11 @@ class ProcessedCallback(Callback):
 class UDFWorker:
     catalog: Catalog
     udf: UDFBase
-    task_queue: multiprocess.Queue
-    done_queue: multiprocess.Queue
+    task_queue: "multiprocess.Queue"
+    done_queue: "multiprocess.Queue"
     is_generator: bool
     cache: bool
+    udf_fields: Sequence[str]
     cb: Callback = attrs.field()
 
     @cb.default
@@ -382,7 +364,8 @@ class UDFWorker:
     def run(self) -> None:
         processed_cb = ProcessedCallback()
         udf_results = self.udf.run(
-            self.get_inputs(),
+            self.udf_fields,
+            unmarshal(self.get_inputs()),
             self.catalog,
             self.is_generator,
             self.cache,
@@ -390,15 +373,17 @@
             processed_cb=processed_cb,
         )
         for udf_output in udf_results:
-
-
+            for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE):
+                put_into_queue(
+                    self.done_queue,
+                    {
+                        "status": OK_STATUS,
+                        "result": msgpack_pack(list(batch)),
+                    },
+                )
         put_into_queue(
             self.done_queue,
-            {
-                "status": OK_STATUS,
-                "result": udf_output,
-                "processed": processed_cb.processed_rows,
-            },
+            {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
         )
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
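The net effect of the dispatch.py changes is a small message protocol on the done_queue: workers emit OK messages carrying msgpack-packed result batches, NOTIFY messages carrying progress counters, and a final FINISHED (or FAILED with the exception). A minimal sketch of that protocol using the helpers this release moves into datachain.query.queue; a plain queue.Queue stands in for the multiprocess queue used by the real dispatcher, and the sketch is illustrative rather than part of the package:

from queue import Queue

from datachain.query.queue import (
    FINISHED_STATUS,
    NOTIFY_STATUS,
    OK_STATUS,
    get_from_queue,
    msgpack_pack,
    msgpack_unpack,
    put_into_queue,
)

done_queue: Queue = Queue()

# Worker side: results leave in packed batches, progress as NOTIFY messages.
put_into_queue(done_queue, {"status": OK_STATUS, "result": msgpack_pack([[1, "a.jpg"]])})
put_into_queue(done_queue, {"status": NOTIFY_STATUS, "processed": 1})
put_into_queue(done_queue, {"status": FINISHED_STATUS})

# Dispatcher side: unpack OK payloads, update progress, stop on FINISHED.
while True:
    msg = get_from_queue(done_queue)
    if msg["status"] == OK_STATUS:
        print("batch:", msgpack_unpack(msg["result"]))
    elif msg["status"] == NOTIFY_STATUS:
        print("processed:", msg["processed"])
    elif msg["status"] == FINISHED_STATUS:
        break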
datachain/query/queue.py
ADDED
@@ -0,0 +1,120 @@
+import datetime
+from collections.abc import Iterable, Iterator
+from queue import Empty, Full, Queue
+from struct import pack, unpack
+from time import sleep
+from typing import Any
+
+import msgpack
+
+from datachain.query.batch import RowsOutput, RowsOutputBatch
+
+DEFAULT_BATCH_SIZE = 10000
+STOP_SIGNAL = "STOP"
+OK_STATUS = "OK"
+FINISHED_STATUS = "FINISHED"
+FAILED_STATUS = "FAILED"
+NOTIFY_STATUS = "NOTIFY"
+
+
+# For more context on the get_from_queue and put_into_queue functions, see the
+# discussion here:
+# https://github.com/iterative/dvcx/pull/1297#issuecomment-2026308773
+# This problem is not exactly described by, but is also related to these Python issues:
+# https://github.com/python/cpython/issues/66587
+# https://github.com/python/cpython/issues/88628
+# https://github.com/python/cpython/issues/108645
+
+
+def get_from_queue(queue: Queue) -> Any:
+    """
+    Gets an item from a queue.
+    This is required to handle signals, such as KeyboardInterrupt exceptions
+    while waiting for items to be available, although only on certain installations.
+    (See the above comment for more context.)
+    """
+    while True:
+        try:
+            return queue.get_nowait()
+        except Empty:
+            sleep(0.01)
+
+
+def put_into_queue(queue: Queue, item: Any) -> None:
+    """
+    Puts an item into a queue.
+    This is required to handle signals, such as KeyboardInterrupt exceptions
+    while waiting for items to be queued, although only on certain installations.
+    (See the above comment for more context.)
+    """
+    while True:
+        try:
+            queue.put_nowait(item)
+            return
+        except Full:
+            sleep(0.01)
+
+
+MSGPACK_EXT_TYPE_DATETIME = 42
+MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43
+
+
+def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
+    if isinstance(obj, datetime.datetime):
+        # packing date object as 1 or 2 variables, depending if timezone info is present
+        # - timestamp
+        # - [OPTIONAL] timezone offset from utc in seconds if timezone info exists
+        if obj.tzinfo:
+            data = (obj.timestamp(), int(obj.utcoffset().total_seconds()))  # type: ignore # noqa: PGH003
+            return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!dl", *data))
+        data = (obj.timestamp(),)  # type: ignore # noqa: PGH003
+        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))
+
+    if isinstance(obj, RowsOutputBatch):
+        return msgpack.ExtType(
+            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
+            msgpack_pack(obj.rows),
+        )
+
+    raise TypeError(f"Unknown type: {obj}")
+
+
+def msgpack_pack(obj: Any) -> bytes:
+    return msgpack.packb(obj, default=_msgpack_pack_extended_types)
+
+
+def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
+    if code == MSGPACK_EXT_TYPE_DATETIME:
+        has_timezone = False
+        if len(data) == 8:
+            # we send only timestamp without timezone if data is 8 bytes
+            values = unpack("!d", data)
+        else:
+            has_timezone = True
+            values = unpack("!dl", data)
+
+        timestamp = values[0]
+        tz_info = None
+        if has_timezone:
+            timezone_offset = values[1]
+            tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
+        return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)
+
+    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
+        return RowsOutputBatch(msgpack_unpack(data))
+
+    return msgpack.ExtType(code, data)
+
+
+def msgpack_unpack(data: bytes) -> Any:
+    return msgpack.unpackb(data, ext_hook=_msgpack_unpack_extended_types)
+
+
+def marshal(obj: Iterator[RowsOutput]) -> Iterable[bytes]:
+    for row in obj:
+        yield msgpack_pack(row)
+
+
+def unmarshal(obj: Iterator[bytes]) -> Iterable[RowsOutput]:
+    for row in obj:
+        yield msgpack_unpack(row)
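The two msgpack extension types above (codes 42 and 43) let row values that plain msgpack cannot represent survive the worker boundary: datetimes travel as a timestamp plus an optional UTC offset in seconds, and RowsOutputBatch as a nested packed payload. A quick round-trip check, assuming datachain 0.3.1 is installed (the row values are invented):

import datetime

from datachain.query.queue import msgpack_pack, msgpack_unpack

tz = datetime.timezone(datetime.timedelta(hours=2))
row = ["s3://bucket", "images/cat.jpg", datetime.datetime(2024, 7, 1, 12, 0, tzinfo=tz)]

packed = msgpack_pack(row)         # bytes; the datetime is encoded as ExtType 42
restored = msgpack_unpack(packed)  # the datetime comes back timezone-aware

assert restored[2] == row[2]
assert restored[2].utcoffset() == datetime.timedelta(hours=2)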
datachain/query/schema.py
CHANGED
@@ -215,8 +215,7 @@ def normalize_param(param: UDFParamSpec) -> UDFParameter:
 class DatasetRow:
     schema: ClassVar[dict[str, type[SQLType]]] = {
         "source": String,
-        "parent": String,
-        "name": String,
+        "path": String,
         "size": Int64,
         "location": JSON,
         "vtype": String,
@@ -231,9 +230,8 @@ class DatasetRow:
 
     @staticmethod
     def create(
-        name: str,
+        path: str,
         source: str = "",
-        parent: str = "",
         size: int = 0,
         location: Optional[dict[str, Any]] = None,
         vtype: str = "",
@@ -245,7 +243,6 @@ class DatasetRow:
         version: str = "",
         etag: str = "",
     ) -> tuple[
-        str,
         str,
         str,
         int,
@@ -267,8 +264,7 @@ class DatasetRow:
 
         return (  # type: ignore [return-value]
             source,
-            parent,
-            name,
+            path,
             size,
             location,
             vtype,
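For callers, the schema change collapses the old parent/name pair into a single path column, and DatasetRow.create() now takes the full path as its only required argument. A hedged illustration of the new call shape, with invented values:

from datachain.query.schema import DatasetRow

# 0.2.18 split the location into parent="images" and name="cat.jpg";
# 0.3.1 takes one path string instead.
row = DatasetRow.create("images/cat.jpg", source="s3://bucket", size=2048)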
datachain/query/udf.py
CHANGED
@@ -15,7 +15,14 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.dataset import RowDict
 
-from .batch import
+from .batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+    UDFInputBatch,
+)
 from .schema import (
     UDFParameter,
     UDFParamSpec,
@@ -25,7 +32,7 @@ from .schema import (
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
-    from .batch import
+    from .batch import RowsOutput, UDFInput
 
 ColumnType = Any
 
@@ -107,7 +114,8 @@ class UDFBase:
 
     def run(
         self,
-
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
         is_generator: bool,
         cache: bool,
@@ -115,15 +123,22 @@ class UDFBase:
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable["UDFResult"]]:
         for batch in udf_inputs:
-
-
+            if isinstance(batch, RowsOutputBatch):
+                n_rows = len(batch.rows)
+                inputs: UDFInput = UDFInputBatch(
+                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                )
+            else:
+                n_rows = 1
+                inputs = RowDict(zip(udf_fields, batch))
+            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
             processed_cb.relative_update(n_rows)
             yield output
 
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
@@ -199,12 +214,12 @@ class UDFWrapper(UDFBase):
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(arg,
+        if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
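UDFBase.run() no longer receives ready-made RowDict objects: workers now ship plain value tuples, and run() rebuilds the mapping from the separately passed udf_fields, using the zip() seen above. For instance (field names and values invented for illustration):

from datachain.dataset import RowDict

udf_fields = ["id", "source", "path", "size"]  # invented example fields
row = (1, "s3://bucket", "images/cat.jpg", 2048)

# the reconstruction used inside UDFBase.run() for each incoming tuple
row_dict = RowDict(zip(udf_fields, row))
assert row_dict["path"] == "images/cat.jpg"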
datachain/utils.py
CHANGED
@@ -10,7 +10,7 @@ import sys
 import time
 from collections.abc import Iterable, Iterator, Sequence
 from datetime import date, datetime, timezone
-from itertools import islice
+from itertools import chain, islice
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 from uuid import UUID
 
@@ -241,7 +241,7 @@ _T_co = TypeVar("_T_co", covariant=True)
 
 
 def batched(iterable: Iterable[_T_co], n: int) -> Iterator[tuple[_T_co, ...]]:
-    "Batch data into tuples of length n. The last batch may be shorter."
+    """Batch data into tuples of length n. The last batch may be shorter."""
     # Based on: https://docs.python.org/3/library/itertools.html#itertools-recipes
     # batched('ABCDEFG', 3) --> ABC DEF G
     if n < 1:
@@ -251,6 +251,21 @@ def batched(iterable: Iterable[_T_co], n: int) -> Iterator[tuple[_T_co, ...]]:
     yield batch
 
 
+def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]:
+    """Batch data into iterators of length n. The last batch may be shorter."""
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("Batch size must be at least one")
+    it = iter(iterable)
+    while True:
+        chunk_it = islice(it, n)
+        try:
+            first_el = next(chunk_it)
+        except StopIteration:
+            return
+        yield chain((first_el,), chunk_it)
+
+
 def flatten(items):
     for item in items:
         if isinstance(item, list):
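Unlike batched(), which materializes each chunk as a tuple, batched_it() yields lazy islice-backed iterators over the same underlying stream, so each chunk must be consumed (or deliberately discarded) before the next one is requested, otherwise leftover items spill into the following chunk. A usage sketch, assuming datachain 0.3.1 is installed:

from datachain.utils import batched, batched_it

list(batched("ABCDEFG", 3))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

for chunk in batched_it("ABCDEFG", 3):
    print(list(chunk))  # ['A', 'B', 'C'], then ['D', 'E', 'F'], then ['G']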
{datachain-0.2.18.dist-info → datachain-0.3.1.dist-info}/RECORD
CHANGED
@@ -1,49 +1,49 @@
 datachain/__init__.py,sha256=GeyhE-5LgfJav2OKYGaieP2lBvf2Gm-ihj7thnK9zjI,800
 datachain/__main__.py,sha256=hG3Y4ARGEqe1AWwNMd259rBlqtphx1Wk39YbueQ0yV8,91
 datachain/asyn.py,sha256=CKCFQJ0CbB3r04S7mUTXxriKzPnOvdUaVPXjM8vCtJw,7644
-datachain/cache.py,sha256=
+datachain/cache.py,sha256=wznC2pge6RhlPTaJfBVGjmBc6bxWCPThu4aTFMltvFU,4076
 datachain/cli.py,sha256=DbmI1sXs7-KCQz6RdLE_JAp3XO3yrTSRJ71LdUzx-XE,33099
 datachain/cli_utils.py,sha256=jrn9ejGXjybeO1ur3fjdSiAyCHZrX0qsLLbJzN9ErPM,2418
 datachain/config.py,sha256=PfC7W5yO6HFO6-iMB4YB-0RR88LPiGmD6sS_SfVbGso,1979
 datachain/dataset.py,sha256=MZezyuJWNj_3PEtzr0epPMNyWAOTrhTSPI5FmemV6L4,14470
 datachain/error.py,sha256=GY9KYTmb7GHXn2gGHV9X-PBhgwLj3i7VpK7tGHtAoGM,1279
 datachain/job.py,sha256=bk25bIqClhgRPzlXAhxpTtDeewibQe5l3S8Cf7db0gM,1229
-datachain/listing.py,sha256=
-datachain/node.py,sha256=
+datachain/listing.py,sha256=keLkvPfumDA3gijeIiinH5yGWe71qCxgF5HqqP5AeH4,8299
+datachain/node.py,sha256=frxZWoEvqUvk9pyXmVaeiNCs3W-xjC_sENmUD11V06Q,6006
 datachain/nodes_fetcher.py,sha256=kca19yvu11JxoVY1t4_ydp1FmchiV88GnNicNBQ9NIA,831
 datachain/nodes_thread_pool.py,sha256=ZyzBvUImIPmi4WlKC2SW2msA0UhtembbTdcs2nx29A0,3191
 datachain/progress.py,sha256=7_8FtJs770ITK9sMq-Lt4k4k18QmYl4yIG_kCoWID3o,4559
 datachain/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/storage.py,sha256=RiSJLYdHUjnrEWkLBKPcETHpAxld_B2WxLg711t0aZI,3733
-datachain/utils.py,sha256=
+datachain/utils.py,sha256=ROVCLwb37VmFRzgTlSGUDw4eJNgYGiQ4yMX581HfUX8,12988
 datachain/catalog/__init__.py,sha256=g2iAAFx_gEIrqshXlhSEbrc8qDaEH11cjU40n3CHDz4,409
-datachain/catalog/catalog.py,sha256=
+datachain/catalog/catalog.py,sha256=9-7SnMjh5ruH9sdKDo8P5EklX9oC2EHH6bnku6ZqLko,80275
 datachain/catalog/datasource.py,sha256=D-VWIVDCM10A8sQavLhRXdYSCG7F4o4ifswEF80_NAQ,1412
 datachain/catalog/loader.py,sha256=GJ8zhEYkC7TuaPzCsjJQ4LtTdECu-wwYzC12MikPOMQ,7307
 datachain/catalog/subclass.py,sha256=B5R0qxeTYEyVAAPM1RutBPSoXZc8L5mVVZeSGXki9Sw,2096
 datachain/client/__init__.py,sha256=T4wiYL9KIM0ZZ_UqIyzV8_ufzYlewmizlV4iymHNluE,86
-datachain/client/azure.py,sha256=
+datachain/client/azure.py,sha256=3RfDTAI_TszDy9WazHQd3bI3sS2wDFrNXfNqCDewZgE,2214
 datachain/client/fileslice.py,sha256=bT7TYco1Qe3bqoc8aUkUZcPdPofJDHlryL5BsTn9xsY,3021
-datachain/client/fsspec.py,sha256=
-datachain/client/gcs.py,sha256=
-datachain/client/local.py,sha256=
-datachain/client/s3.py,sha256=
+datachain/client/fsspec.py,sha256=G4QTm3KPhlaV74T3gLXJ86345_ak8CH38ezn2ET-oLc,13230
+datachain/client/gcs.py,sha256=Mt77W_l8_fK61gLm4mmxNmENuOM0ETwxdiFp4S8d-_w,4105
+datachain/client/local.py,sha256=SyGnqcrbtSvDK6IJsQa6NxxHwbWaWIP1GLZsQBXg_IA,4939
+datachain/client/s3.py,sha256=GfRZZzNPQPRsYjoef8bbsLbanJPUlCbyGTTK8ojzp8A,6136
 datachain/data_storage/__init__.py,sha256=cEOJpyu1JDZtfUupYucCDNFI6e5Wmp_Oyzq6rZv32Y8,398
 datachain/data_storage/db_engine.py,sha256=81Ol1of9TTTzD97ORajCnP366Xz2mEJt6C-kTUCaru4,3406
 datachain/data_storage/id_generator.py,sha256=lCEoU0BM37Ai2aRpSbwo5oQT0GqZnSpYwwvizathRMQ,4292
 datachain/data_storage/job.py,sha256=w-7spowjkOa1P5fUVtJou3OltT0L48P0RYWZ9rSJ9-s,383
 datachain/data_storage/metastore.py,sha256=nxcY6nwyEmQWMAo33sNGO-FgUFQs2amBGGnZz2ftEz0,55362
-datachain/data_storage/schema.py,sha256=
+datachain/data_storage/schema.py,sha256=Idi-29fckvZozzvkyz3nTR2FOIajPlSuPdIEO7SMvXM,7863
 datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
 datachain/data_storage/sqlite.py,sha256=0r6L_a2hdGRoR_gl06v1qWhEFOS_Q31aldHyk07Yx-M,26857
-datachain/data_storage/warehouse.py,sha256=
+datachain/data_storage/warehouse.py,sha256=MXYkUG69UK2wbIFsZFvT7rKzXlnSitDMp3Vzj_IIsnA,33089
 datachain/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/lib/arrow.py,sha256=R8wDUDEa-5hYjI3HW9cqvOYYJpeeah5lbhFIL3gkmcE,4915
 datachain/lib/clip.py,sha256=16u4b_y2Y15nUS2UN_8ximMo6r_-_4IQpmct2ol-e-g,5730
 datachain/lib/data_model.py,sha256=qfTtQNncS5pt9SvXdMEa5kClniaT6XBGBfO7onEz2TI,1632
 datachain/lib/dataset_info.py,sha256=lONGr71ozo1DS4CQEhnpKORaU4qFb6Ketv8Xm8CVm2U,2188
-datachain/lib/dc.py,sha256=
-datachain/lib/file.py,sha256=
+datachain/lib/dc.py,sha256=e24ecfIcypVkmVBqvr-p06zpwrw7GD20gy1gBJQPT-I,58012
+datachain/lib/file.py,sha256=ZHpdilDPYCob8uqtwUPtBvBNxVvQRq4AC_0IGg5m-G4,12003
 datachain/lib/image.py,sha256=TgYhRhzd4nkytfFMeykQkPyzqb5Le_-tU81unVMPn4Q,2328
 datachain/lib/meta_formats.py,sha256=jlSYWRUeDMjun_YCsQ2JxyaDJpEpokzHDPmKUAoCXnU,7034
 datachain/lib/model_store.py,sha256=c4USXsBBjrGH8VOh4seIgOiav-qHOwdoixtxfLgU63c,2409
@@ -51,11 +51,11 @@ datachain/lib/pytorch.py,sha256=9PsypKseyKfIimTmTQOgb-pbNXgeeAHLdlWx0qRPULY,5660
 datachain/lib/settings.py,sha256=39thOpYJw-zPirzeNO6pmRC2vPrQvt4eBsw1xLWDFsw,2344
 datachain/lib/signal_schema.py,sha256=VL9TR0CJ3eRzjIDr-8e-e7cZKuMBbPUZtY2lGAsucc0,15734
 datachain/lib/text.py,sha256=dVe2Ilc_gW2EV0kun0UwegiCkapWcd20cef7CgINWHU,1083
-datachain/lib/udf.py,sha256=
+datachain/lib/udf.py,sha256=n3x6No-7l5LAciPJPWwZbA8WtTnGUU7d0wRL6CyfZh8,11847
 datachain/lib/udf_signature.py,sha256=gMStcEeYJka5M6cg50Z9orC6y6HzCAJ3MkFqqn1fjZg,7137
 datachain/lib/utils.py,sha256=5-kJlAZE0D9nXXweAjo7-SP_AWGo28feaDByONYaooQ,463
 datachain/lib/vfile.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-datachain/lib/webdataset.py,sha256=
+datachain/lib/webdataset.py,sha256=SsjCKLSKEkHRRfeTHQhjoGqNPqIWw_SCWQcUwgUWWP0,8282
 datachain/lib/webdataset_laion.py,sha256=PQP6tQmUP7Xu9fPuAGK1JDBYA6T5UufYMUTGaxgspJA,2118
 datachain/lib/convert/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/lib/convert/flatten.py,sha256=YMoC00BqEy3zSpvCp6Q0DfxihuPmgjUJj1g2cesWGPs,1790
@@ -64,15 +64,16 @@ datachain/lib/convert/sql_to_python.py,sha256=lGnKzSF_tz9Y_5SSKkrIU95QEjpcDzvOxI
 datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xdq56Tw,2012
 datachain/lib/convert/values_to_tuples.py,sha256=aVoHWMOUGLAiS6_BBwKJqVIne91VffOW6-dWyNE7oHg,3715
 datachain/query/__init__.py,sha256=tv-spkjUCYamMN9ys_90scYrZ8kJ7C7d1MTYVmxGtk4,325
-datachain/query/batch.py,sha256
-datachain/query/builtins.py,sha256=
-datachain/query/dataset.py,sha256
-datachain/query/dispatch.py,sha256=
+datachain/query/batch.py,sha256=-vlpINJiertlnaoUVv1C95RatU0F6zuhpIYRufJRo1M,3660
+datachain/query/builtins.py,sha256=EmKPYsoQ46zwdyOn54MuCzvYFmfsBn5F8zyF7UBUfrc,2550
+datachain/query/dataset.py,sha256=sRKY2it_znlzTNOt_OCRe008rHu0TXMnFwvGsnthSO0,60209
+datachain/query/dispatch.py,sha256=GBh3EZHDp5AaXxrjOpfrpfsuy7Umnqxu-MAXcK9X3gc,12945
 datachain/query/metrics.py,sha256=vsECqbZfoSDBnvC3GQlziKXmISVYDLgHP1fMPEOtKyo,640
 datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
-datachain/query/
+datachain/query/queue.py,sha256=waqM_KzavU8C-G95-4211Nd4GXna_u2747Chgwtgz2w,3839
+datachain/query/schema.py,sha256=O3mTM5DRjvRAJCI7O9mR8wOdFJbgI1jIjvtfl5YvjI4,7755
 datachain/query/session.py,sha256=qTzkXgwMJdJhal3rVt3hdv3x1EXT1IHuXcwkC-Ex0As,4111
-datachain/query/udf.py,sha256=
+datachain/query/udf.py,sha256=j3NhmKK5rYG5TclcM2Sr0LhS1tmYLMjzMugx9G9iFLM,8100
 datachain/remote/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 datachain/remote/studio.py,sha256=f5s6qSZ9uB4URGUoU_8_W1KZRRQQVSm6cgEBkBUEfuE,7226
 datachain/sql/__init__.py,sha256=A2djrbQwSMUZZEIKGnm-mnRA-NDSbiDJNpAmmwGNyIo,303
@@ -92,9 +93,9 @@ datachain/sql/sqlite/base.py,sha256=LBYmXqXsVF30fbcnR55evCZHbPDCzMdGk_ogPLps63s,
 datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
 datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
 datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
-datachain-0.
-datachain-0.
-datachain-0.
-datachain-0.
-datachain-0.
-datachain-0.
+datachain-0.3.1.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
+datachain-0.3.1.dist-info/METADATA,sha256=qR3OMpGUkx0cKelnl51d9uksn5H-Wn4LvTJbUnTMDuQ,17268
+datachain-0.3.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+datachain-0.3.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
+datachain-0.3.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
+datachain-0.3.1.dist-info/RECORD,,