datachain 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/catalog/catalog.py +11 -2
- datachain/client/fsspec.py +1 -4
- datachain/client/local.py +2 -7
- datachain/data_storage/schema.py +22 -8
- datachain/data_storage/sqlite.py +5 -0
- datachain/data_storage/warehouse.py +8 -14
- datachain/lib/dc.py +28 -14
- datachain/lib/meta_formats.py +8 -2
- datachain/lib/udf.py +21 -14
- datachain/node.py +1 -1
- datachain/query/batch.py +45 -41
- datachain/query/dataset.py +13 -6
- datachain/query/dispatch.py +53 -68
- datachain/query/queue.py +120 -0
- datachain/query/schema.py +4 -0
- datachain/query/udf.py +23 -8
- datachain/sql/default/base.py +3 -0
- datachain/sql/sqlite/base.py +3 -0
- datachain/sql/types.py +120 -11
- datachain/utils.py +17 -2
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/METADATA +74 -86
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/RECORD +26 -25
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/WHEEL +1 -1
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/LICENSE +0 -0
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/entry_points.txt +0 -0
- {datachain-0.3.0.dist-info → datachain-0.3.2.dist-info}/top_level.txt +0 -0
datachain/query/dispatch.py
CHANGED
@@ -2,11 +2,8 @@ import contextlib
 from collections.abc import Iterator, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
-from queue import Empty, Full, Queue
 from sys import stdin
-from time import sleep
-from types import GeneratorType
-from typing import Any, Optional
+from typing import Optional
 
 import attrs
 import multiprocess
@@ -22,7 +19,16 @@ from datachain.query.dataset import (
     get_processed_callback,
     process_udf_outputs,
 )
+from datachain.query.queue import (
+    get_from_queue,
+    marshal,
+    msgpack_pack,
+    msgpack_unpack,
+    put_into_queue,
+    unmarshal,
+)
 from datachain.query.udf import UDFBase, UDFFactory, UDFResult
+from datachain.utils import batched_it
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -44,44 +50,6 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int:
     return n_workers
 
 
-# For more context on the get_from_queue and put_into_queue functions, see the
-# discussion here:
-# https://github.com/iterative/dvcx/pull/1297#issuecomment-2026308773
-# This problem is not exactly described by, but is also related to these Python issues:
-# https://github.com/python/cpython/issues/66587
-# https://github.com/python/cpython/issues/88628
-# https://github.com/python/cpython/issues/108645
-
-
-def get_from_queue(queue: Queue) -> Any:
-    """
-    Gets an item from a queue.
-    This is required to handle signals, such as KeyboardInterrupt exceptions
-    while waiting for items to be available, although only on certain installations.
-    (See the above comment for more context.)
-    """
-    while True:
-        try:
-            return queue.get_nowait()
-        except Empty:
-            sleep(0.01)
-
-
-def put_into_queue(queue: Queue, item: Any) -> None:
-    """
-    Puts an item into a queue.
-    This is required to handle signals, such as KeyboardInterrupt exceptions
-    while waiting for items to be queued, although only on certain installations.
-    (See the above comment for more context.)
-    """
-    while True:
-        try:
-            queue.put_nowait(item)
-            return
-        except Full:
-            sleep(0.01)
-
-
 def udf_entrypoint() -> int:
     # Load UDF info from stdin
     udf_info = load(stdin.buffer)
@@ -100,8 +68,9 @@
         udf_info["id_generator_clone_params"],
         udf_info["metastore_clone_params"],
         udf_info["warehouse_clone_params"],
-        is_generator=udf_info.get("is_generator", False),
+        udf_fields=udf_info["udf_fields"],
         cache=udf_info["cache"],
+        is_generator=udf_info.get("is_generator", False),
     )
 
     query = udf_info["query"]
@@ -121,7 +90,7 @@
     generated_cb = get_generated_callback(dispatch.is_generator)
     try:
         udf_results = dispatch.run_udf_parallel(
-            udf_inputs,
+            marshal(udf_inputs),
             n_workers=n_workers,
             processed_cb=processed_cb,
             download_cb=download_cb,
@@ -142,6 +111,9 @@ def udf_worker_entrypoint() -> int:
 
 
 class UDFDispatcher:
+    catalog: Optional[Catalog] = None
+    task_queue: Optional[multiprocess.Queue] = None
+    done_queue: Optional[multiprocess.Queue] = None
     _batch_size: Optional[int] = None
 
     def __init__(
@@ -151,9 +123,10 @@ class UDFDispatcher:
         id_generator_clone_params,
         metastore_clone_params,
         warehouse_clone_params,
-        is_generator=False,
-        cache=False,
-        buffer_size=DEFAULT_BATCH_SIZE,
+        udf_fields: "Sequence[str]",
+        cache: bool,
+        is_generator: bool = False,
+        buffer_size: int = DEFAULT_BATCH_SIZE,
     ):
         self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
@@ -172,12 +145,13 @@ class UDFDispatcher:
             self.warehouse_args,
             self.warehouse_kwargs,
         ) = warehouse_clone_params
-        self.is_generator = is_generator
+        self.udf_fields = udf_fields
         self.cache = cache
+        self.is_generator = is_generator
+        self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
         self.done_queue = None
-        self.buffer_size = buffer_size
         self.ctx = get_context("spawn")
 
     @property
@@ -226,6 +200,7 @@ class UDFDispatcher:
             self.done_queue,
             self.is_generator,
             self.cache,
+            self.udf_fields,
         )
 
     def _run_worker(self) -> None:
@@ -233,7 +208,11 @@ class UDFDispatcher:
             worker = self._create_worker()
             worker.run()
         except (Exception, KeyboardInterrupt) as e:
-            put_into_queue(self.done_queue, {"status": FAILED_STATUS, "exception": e})
+            if self.done_queue:
+                put_into_queue(
+                    self.done_queue,
+                    {"status": FAILED_STATUS, "exception": e},
+                )
             raise
 
     @staticmethod
@@ -249,7 +228,6 @@ class UDFDispatcher:
         self,
         input_rows,
         n_workers: Optional[int] = None,
-        cache: bool = False,
         input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
@@ -299,21 +277,24 @@ class UDFDispatcher:
             result = get_from_queue(self.done_queue)
             status = result["status"]
             if status == NOTIFY_STATUS:
-                download_cb.relative_update(result["downloaded"])
+                if downloaded := result.get("downloaded"):
+                    download_cb.relative_update(downloaded)
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
             elif status == FINISHED_STATUS:
                 # Worker finished
                 n_workers -= 1
             elif status == OK_STATUS:
-                processed_cb.relative_update(result["processed"])
-                yield result["result"]
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
+                yield msgpack_unpack(result["result"])
             else:  # Failed / error
                 n_workers -= 1
-                exc = result.get("exception")
-                if exc:
+                if exc := result.get("exception"):
                     raise exc
                 raise RuntimeError("Internal error: Parallel UDF execution failed")
 
-            if not streaming_mode and not input_finished:
+            if status == OK_STATUS and not streaming_mode and not input_finished:
                 try:
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
@@ -348,7 +329,7 @@ class UDFDispatcher:
 
 
 class WorkerCallback(Callback):
-    def __init__(self, queue: multiprocess.Queue):
+    def __init__(self, queue: "multiprocess.Queue"):
         self.queue = queue
         super().__init__()
 
@@ -369,10 +350,11 @@ class ProcessedCallback(Callback):
 class UDFWorker:
     catalog: Catalog
     udf: UDFBase
-    task_queue: multiprocess.Queue
-    done_queue: multiprocess.Queue
+    task_queue: "multiprocess.Queue"
+    done_queue: "multiprocess.Queue"
     is_generator: bool
     cache: bool
+    udf_fields: Sequence[str]
     cb: Callback = attrs.field()
 
     @cb.default
@@ -382,7 +364,8 @@ class UDFWorker:
     def run(self) -> None:
         processed_cb = ProcessedCallback()
         udf_results = self.udf.run(
-            self.get_inputs(),
+            self.udf_fields,
+            unmarshal(self.get_inputs()),
             self.catalog,
             self.is_generator,
             self.cache,
@@ -390,15 +373,17 @@ class UDFWorker:
             processed_cb=processed_cb,
         )
         for udf_output in udf_results:
-            if isinstance(udf_output, GeneratorType):
-                udf_output = list(udf_output)
+            for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE):
+                put_into_queue(
+                    self.done_queue,
+                    {
+                        "status": OK_STATUS,
+                        "result": msgpack_pack(list(batch)),
+                    },
+                )
             put_into_queue(
                 self.done_queue,
-                {
-                    "status": OK_STATUS,
-                    "result": udf_output,
-                    "processed": processed_cb.processed_rows,
-                },
+                {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
             )
         put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
 
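The net effect of the dispatch.py changes above: inputs and results now cross the process boundary as msgpack-packed bytes tagged with a status, and progress reporting moves into separate NOTIFY messages instead of piggybacking on each result. A rough standalone sketch of that status-tagged protocol (a plain queue.Queue stands in for multiprocess.Queue, and the rows are made up):

# Toy reproduction of the worker -> dispatcher message protocol (not
# datachain's actual classes): packed OK batches, NOTIFY progress, FINISHED.
from queue import Queue

import msgpack

OK_STATUS, NOTIFY_STATUS, FINISHED_STATUS = "OK", "NOTIFY", "FINISHED"

done_queue: Queue = Queue()

# Worker side: results leave as msgpack bytes, progress as NOTIFY messages.
batch = [[1, "a.jpg"], [2, "b.jpg"]]  # invented rows
done_queue.put({"status": OK_STATUS, "result": msgpack.packb(batch)})
done_queue.put({"status": NOTIFY_STATUS, "processed": len(batch)})
done_queue.put({"status": FINISHED_STATUS})

# Dispatcher side: unpack OK payloads, stop when the worker reports FINISHED.
while True:
    msg = done_queue.get()
    status = msg["status"]
    if status == OK_STATUS:
        print(msgpack.unpackb(msg["result"]))  # [[1, 'a.jpg'], [2, 'b.jpg']]
    elif status == NOTIFY_STATUS:
        print("processed:", msg["processed"])
    elif status == FINISHED_STATUS:
        break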
datachain/query/queue.py
ADDED
@@ -0,0 +1,120 @@
+import datetime
+from collections.abc import Iterable, Iterator
+from queue import Empty, Full, Queue
+from struct import pack, unpack
+from time import sleep
+from typing import Any
+
+import msgpack
+
+from datachain.query.batch import RowsOutput, RowsOutputBatch
+
+DEFAULT_BATCH_SIZE = 10000
+STOP_SIGNAL = "STOP"
+OK_STATUS = "OK"
+FINISHED_STATUS = "FINISHED"
+FAILED_STATUS = "FAILED"
+NOTIFY_STATUS = "NOTIFY"
+
+
+# For more context on the get_from_queue and put_into_queue functions, see the
+# discussion here:
+# https://github.com/iterative/dvcx/pull/1297#issuecomment-2026308773
+# This problem is not exactly described by, but is also related to these Python issues:
+# https://github.com/python/cpython/issues/66587
+# https://github.com/python/cpython/issues/88628
+# https://github.com/python/cpython/issues/108645
+
+
+def get_from_queue(queue: Queue) -> Any:
+    """
+    Gets an item from a queue.
+    This is required to handle signals, such as KeyboardInterrupt exceptions
+    while waiting for items to be available, although only on certain installations.
+    (See the above comment for more context.)
+    """
+    while True:
+        try:
+            return queue.get_nowait()
+        except Empty:
+            sleep(0.01)
+
+
+def put_into_queue(queue: Queue, item: Any) -> None:
+    """
+    Puts an item into a queue.
+    This is required to handle signals, such as KeyboardInterrupt exceptions
+    while waiting for items to be queued, although only on certain installations.
+    (See the above comment for more context.)
+    """
+    while True:
+        try:
+            queue.put_nowait(item)
+            return
+        except Full:
+            sleep(0.01)
+
+
+MSGPACK_EXT_TYPE_DATETIME = 42
+MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH = 43
+
+
+def _msgpack_pack_extended_types(obj: Any) -> msgpack.ExtType:
+    if isinstance(obj, datetime.datetime):
+        # packing date object as 1 or 2 variables, depending if timezone info is present
+        # - timestamp
+        # - [OPTIONAL] timezone offset from utc in seconds if timezone info exists
+        if obj.tzinfo:
+            data = (obj.timestamp(), int(obj.utcoffset().total_seconds()))  # type: ignore # noqa: PGH003
+            return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!dl", *data))
+        data = (obj.timestamp(),)  # type: ignore # noqa: PGH003
+        return msgpack.ExtType(MSGPACK_EXT_TYPE_DATETIME, pack("!d", *data))
+
+    if isinstance(obj, RowsOutputBatch):
+        return msgpack.ExtType(
+            MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH,
+            msgpack_pack(obj.rows),
+        )
+
+    raise TypeError(f"Unknown type: {obj}")
+
+
+def msgpack_pack(obj: Any) -> bytes:
+    return msgpack.packb(obj, default=_msgpack_pack_extended_types)
+
+
+def _msgpack_unpack_extended_types(code: int, data: bytes) -> Any:
+    if code == MSGPACK_EXT_TYPE_DATETIME:
+        has_timezone = False
+        if len(data) == 8:
+            # we send only timestamp without timezone if data is 8 bytes
+            values = unpack("!d", data)
+        else:
+            has_timezone = True
+            values = unpack("!dl", data)
+
+        timestamp = values[0]
+        tz_info = None
+        if has_timezone:
+            timezone_offset = values[1]
+            tz_info = datetime.timezone(datetime.timedelta(seconds=timezone_offset))
+        return datetime.datetime.fromtimestamp(timestamp, tz=tz_info)
+
+    if code == MSGPACK_EXT_TYPE_ROWS_INPUT_BATCH:
+        return RowsOutputBatch(msgpack_unpack(data))
+
+    return msgpack.ExtType(code, data)
+
+
+def msgpack_unpack(data: bytes) -> Any:
+    return msgpack.unpackb(data, ext_hook=_msgpack_unpack_extended_types)
+
+
+def marshal(obj: Iterator[RowsOutput]) -> Iterable[bytes]:
+    for row in obj:
+        yield msgpack_pack(row)
+
+
+def unmarshal(obj: Iterator[bytes]) -> Iterable[RowsOutput]:
+    for row in obj:
+        yield msgpack_unpack(row)
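Two details of this new module are worth spelling out. The busy-wait in get_from_queue/put_into_queue (get_nowait/put_nowait plus a 10 ms sleep) keeps the waiting thread responsive to signals such as KeyboardInterrupt, per the linked CPython issues. And the datetime ext type encodes a value as a big-endian timestamp ("!d", 8 bytes), appending a UTC offset in seconds ("!dl", 12 bytes) only when tzinfo is present. A self-contained sketch of that codec (the constant and helper names are local to this example, not datachain's API):

# Round trip of the datetime ext-type encoding used in queue.py.
import datetime
from struct import pack, unpack

import msgpack

EXT_DATETIME = 42  # local stand-in for MSGPACK_EXT_TYPE_DATETIME

def default(obj):
    if isinstance(obj, datetime.datetime):
        if obj.tzinfo:
            payload = pack("!dl", obj.timestamp(), int(obj.utcoffset().total_seconds()))
        else:
            payload = pack("!d", obj.timestamp())
        return msgpack.ExtType(EXT_DATETIME, payload)
    raise TypeError(f"Unknown type: {obj}")

def ext_hook(code, data):
    if code == EXT_DATETIME:
        if len(data) == 8:  # timestamp only -> naive datetime
            (ts,) = unpack("!d", data)
            return datetime.datetime.fromtimestamp(ts)
        ts, offset = unpack("!dl", data)  # timestamp + UTC offset
        tz = datetime.timezone(datetime.timedelta(seconds=offset))
        return datetime.datetime.fromtimestamp(ts, tz=tz)
    return msgpack.ExtType(code, data)

dt = datetime.datetime.now(datetime.timezone.utc)
packed = msgpack.packb({"created": dt}, default=default)
assert msgpack.unpackb(packed, ext_hook=ext_hook)["created"] == dt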
datachain/query/schema.py
CHANGED
@@ -45,6 +45,10 @@ class Column(sa.ColumnClause, metaclass=ColumnMeta):
         """Search for matches using glob pattern matching."""
         return self.op("GLOB")(glob_str)
 
+    def regexp(self, regexp_str):
+        """Search for matches using regexp pattern matching."""
+        return self.op("REGEXP")(regexp_str)
+
 
 class UDFParameter(ABC):
     @abstractmethod
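The new regexp() mirrors the existing glob(): both are thin wrappers over SQLAlchemy's generic op(), so they emit whatever operator the backend understands. A quick sketch with plain SQLAlchemy (note that stock SQLite only supports REGEXP if a regexp function has been registered on the connection):

# What Column.regexp(...) builds under the hood: a generic custom operator.
import sqlalchemy as sa

path = sa.column("path", sa.String())
expr = path.op("REGEXP")(r".*\.jpe?g$")  # same shape as Column.regexp(r".*\.jpe?g$")
print(expr)  # path REGEXP :path_1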
datachain/query/udf.py
CHANGED
@@ -15,7 +15,14 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 
 from datachain.dataset import RowDict
 
-from .batch import Batch, BatchingStrategy, NoBatching, Partition, RowBatch
+from .batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+    UDFInputBatch,
+)
 from .schema import (
     UDFParameter,
     UDFParamSpec,
@@ -25,7 +32,7 @@ from .schema import (
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
 
-    from .batch import BatchingResult
+    from .batch import RowsOutput, UDFInput
 
 ColumnType = Any
 
@@ -107,7 +114,8 @@ class UDFBase:
 
     def run(
         self,
-        udf_inputs: "Iterable[BatchingResult]",
+        udf_fields: "Sequence[str]",
+        udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
         is_generator: bool,
         cache: bool,
@@ -115,15 +123,22 @@
         processed_cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterator[Iterable["UDFResult"]]:
         for batch in udf_inputs:
-            n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
-            output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
+            if isinstance(batch, RowsOutputBatch):
+                n_rows = len(batch.rows)
+                inputs: UDFInput = UDFInputBatch(
+                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                )
+            else:
+                n_rows = 1
+                inputs = RowDict(zip(udf_fields, batch))
+            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
             processed_cb.relative_update(n_rows)
             yield output
 
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "BatchingResult",
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
@@ -199,12 +214,12 @@ class UDFWrapper(UDFBase):
     def run_once(
         self,
         catalog: "Catalog",
-        arg: "BatchingResult",
+        arg: "UDFInput",
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
     ) -> Iterable[UDFResult]:
-        if isinstance(arg, RowBatch):
+        if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
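The heart of the udf.py change is visible in run(): workers now receive bare row tuples plus a separate udf_fields name list (tuples survive msgpack, arbitrary row objects would not) and reattach the names on arrival via RowDict(zip(udf_fields, row)). A minimal illustration with a plain dict standing in for RowDict (field names invented):

# Rebuilding named rows from a field list plus bare tuples, as run() does.
udf_fields = ["id", "path", "size"]  # invented field names
rows = [(1, "cats/1.jpg", 2048), (2, "dogs/2.jpg", 4096)]

batch = [dict(zip(udf_fields, row)) for row in rows]  # RowDict(...) in datachain
assert batch[0] == {"id": 1, "path": "cats/1.jpg", "size": 2048}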
datachain/sql/default/base.py
CHANGED
@@ -1,8 +1,10 @@
 from datachain.sql.types import (
+    DBDefaults,
     TypeConverter,
     TypeDefaults,
     TypeReadConverter,
     register_backend_types,
+    register_db_defaults,
     register_type_defaults,
     register_type_read_converters,
 )
@@ -18,5 +20,6 @@ def setup() -> None:
     register_backend_types("default", TypeConverter())
     register_type_read_converters("default", TypeReadConverter())
     register_type_defaults("default", TypeDefaults())
+    register_db_defaults("default", DBDefaults())
 
     setup_is_complete = True
datachain/sql/sqlite/base.py
CHANGED
@@ -22,8 +22,10 @@ from datachain.sql.sqlite.types import (
     register_type_converters,
 )
 from datachain.sql.types import (
+    DBDefaults,
     TypeDefaults,
     register_backend_types,
+    register_db_defaults,
     register_type_defaults,
     register_type_read_converters,
 )
@@ -66,6 +68,7 @@ def setup():
     register_backend_types("sqlite", SQLiteTypeConverter())
     register_type_read_converters("sqlite", SQLiteTypeReadConverter())
     register_type_defaults("sqlite", TypeDefaults())
+    register_db_defaults("sqlite", DBDefaults())
 
     compiles(sql_path.parent, "sqlite")(compile_path_parent)
     compiles(sql_path.name, "sqlite")(compile_path_name)
|