datachain 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/catalog/catalog.py +8 -0
- datachain/cli.py +3 -2
- datachain/data_storage/metastore.py +28 -9
- datachain/data_storage/sqlite.py +24 -32
- datachain/data_storage/warehouse.py +1 -3
- datachain/dataset.py +0 -3
- datachain/lib/arrow.py +64 -19
- datachain/lib/dc.py +310 -123
- datachain/lib/listing.py +5 -3
- datachain/lib/pytorch.py +5 -1
- datachain/lib/udf.py +100 -78
- datachain/lib/udf_signature.py +8 -6
- datachain/query/dataset.py +7 -7
- datachain/query/dispatch.py +2 -2
- datachain/query/session.py +42 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/METADATA +1 -1
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/RECORD +21 -22
- datachain/query/udf.py +0 -126
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/LICENSE +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/WHEEL +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.4.0.dist-info → datachain-0.5.1.dist-info}/top_level.txt +0 -0
datachain/lib/listing.py
CHANGED
@@ -1,7 +1,7 @@
 import posixpath
 from collections.abc import Iterator
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar

 from fsspec.asyn import get_loop
 from sqlalchemy.sql.expression import true
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
 LISTING_TTL = 4 * 60 * 60  # cached listing lasts 4 hours
 LISTING_PREFIX = "lst__"  # listing datasets start with this name

+D = TypeVar("D", bound="DataChain")
+

 def list_bucket(uri: str, cache, client_config=None) -> Callable:
     """
@@ -38,11 +40,11 @@ def list_bucket(uri: str, cache, client_config=None) -> Callable:


 def ls(
-    dc: "DataChain",
+    dc: D,
     path: str,
     recursive: Optional[bool] = True,
     object_name="file",
-):
+) -> D:
     """
     Return files by some path from DataChain instance which contains bucket listing.
     Path can have globs.
datachain/lib/pytorch.py
CHANGED
@@ -9,6 +9,7 @@ from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
 from tqdm import tqdm

+from datachain import Session
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
 from datachain.lib.text import convert_text
@@ -87,8 +88,11 @@ class PytorchDataset(IterableDataset):
     def __iter__(self) -> Iterator[Any]:
         if self.catalog is None:
             self.catalog = self._get_catalog()
+        session = Session.get(catalog=self.catalog)
         total_rank, total_workers = self.get_rank_and_workers()
-        ds = DataChain(…
+        ds = DataChain.from_dataset(
+            name=self.name, version=self.version, session=session
+        )
         ds = ds.remove_file_signals()

         if self.num_samples > 0:
datachain/lib/udf.py
CHANGED
@@ -1,31 +1,33 @@
 import sys
 import traceback
-from …
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Optional

 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel

 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.convert.unflatten import unflatten_to_json
 from datachain.lib.file import File
-from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf_signature import UdfSignature
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.query.batch import …
+from datachain.query.batch import (
+    Batch,
+    BatchingStrategy,
+    NoBatching,
+    Partition,
+    RowsOutputBatch,
+    UDFInputBatch,
+)
+from datachain.query.schema import ColumnParameter, UDFParameter

 if TYPE_CHECKING:
-    from collections.abc import Iterable, Iterator, Sequence
-
     from typing_extensions import Self

     from datachain.catalog import Catalog
+    from datachain.lib.udf_signature import UdfSignature
     from datachain.query.batch import RowsOutput, UDFInput
-    from datachain.query.udf import UDFResult


 class UdfError(DataChainParamsError):
@@ -33,14 +35,47 @@ class UdfError(DataChainParamsError):
         super().__init__(f"UDF error: {msg}")


-class UDFAdapter(_UDFBase):
+ColumnType = Any
+
+# Specification for the output of a UDF
+UDFOutputSpec = Mapping[str, ColumnType]
+
+# Result type when calling the UDF wrapper around the actual
+# Python function / class implementing it.
+UDFResult = dict[str, Any]
+
+
+@dataclass
+class UDFProperties:
+    """Container for basic UDF properties."""
+
+    params: list[UDFParameter]
+    output: UDFOutputSpec
+    batch: int = 1
+
+    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
+        if use_partitioning:
+            return Partition()
+        if self.batch == 1:
+            return NoBatching()
+        if self.batch > 1:
+            return Batch(self.batch)
+        raise ValueError(f"invalid batch size {self.batch}")
+
+    def signal_names(self) -> Iterable[str]:
+        return self.output.keys()
+
+
+class UDFAdapter:
     def __init__(
         self,
         inner: "UDFBase",
         properties: UDFProperties,
     ):
         self.inner = inner
-        …
+        self.properties = properties
+        self.signal_names = properties.signal_names()
+        self.output = properties.output

     def run(
         self,
@@ -51,20 +86,23 @@ class UDFAdapter(_UDFBase):
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
-    ) -> …
-        self.inner.…
+    ) -> Iterator[Iterable[UDFResult]]:
+        self.inner.catalog = catalog
         if hasattr(self.inner, "setup") and callable(self.inner.setup):
             self.inner.setup()

-        …
+        for batch in udf_inputs:
+            if isinstance(batch, RowsOutputBatch):
+                n_rows = len(batch.rows)
+                inputs: UDFInput = UDFInputBatch(
+                    [RowDict(zip(udf_fields, row)) for row in batch.rows]
+                )
+            else:
+                n_rows = 1
+                inputs = RowDict(zip(udf_fields, batch))
+            output = self.run_once(catalog, inputs, is_generator, cache, cb=download_cb)
+            processed_cb.relative_update(n_rows)
+            yield output

         if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
             self.inner.teardown()
@@ -76,23 +114,46 @@ class UDFAdapter(_UDFBase):
         is_generator: bool = False,
         cache: bool = False,
         cb: Callback = DEFAULT_CALLBACK,
-    ) -> …
+    ) -> Iterable[UDFResult]:
         if isinstance(arg, UDFInputBatch):
             udf_inputs = [
                 self.bind_parameters(catalog, row, cache=cache, cb=cb)
                 for row in arg.rows
             ]
-            udf_outputs = self.inner(udf_inputs, cache=cache, download_cb=cb)
+            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
             return self._process_results(arg.rows, udf_outputs, is_generator)
         if isinstance(arg, RowDict):
             udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
-            udf_outputs = self.inner(…
+            udf_outputs = self.inner.run_once(udf_inputs, cache=cache, download_cb=cb)
             if not is_generator:
                 # udf_outputs is generator already if is_generator=True
                 udf_outputs = [udf_outputs]
             return self._process_results([arg], udf_outputs, is_generator)
         raise ValueError(f"Unexpected UDF argument: {arg}")

+    def bind_parameters(self, catalog: "Catalog", row: "RowDict", **kwargs) -> list:
+        return [p.get_value(catalog, row, **kwargs) for p in self.properties.params]
+
+    def _process_results(
+        self,
+        rows: Sequence["RowDict"],
+        results: Sequence[Sequence[Any]],
+        is_generator=False,
+    ) -> Iterable[UDFResult]:
+        """Create a list of dictionaries representing UDF results."""
+
+        # outputting rows
+        if is_generator:
+            # each row in results is a tuple of column values
+            return (dict(zip(self.signal_names, row)) for row in results)
+
+        # outputting signals
+        row_ids = [row["sys__id"] for row in rows]
+        return [
+            {"sys__id": row_id} | dict(zip(self.signal_names, signals))
+            for row_id, signals in zip(row_ids, results)
+        ]
+

 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -146,14 +207,14 @@ class UDFBase(AbstractUDF):
     is_output_batched = False
     is_input_grouped = False
     params_spec: Optional[list[str]]
+    catalog: "Optional[Catalog]"

     def __init__(self):
         self.params = None
         self.output = None
         self.params_spec = None
         self.output_spec = None
-        self.…
-        self._catalog = None
+        self.catalog = None
         self._func = None

     def process(self, *args, **kwargs):
@@ -174,9 +235,9 @@ class UDFBase(AbstractUDF):

     def _init(
         self,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
-        func: Callable,
+        func: Optional[Callable],
     ):
         self.params = params
         self.output = sign.output_schema
@@ -190,13 +251,13 @@ class UDFBase(AbstractUDF):
     @classmethod
     def _create(
         cls,
-        sign: UdfSignature,
+        sign: "UdfSignature",
         params: SignalSchema,
     ) -> "Self":
         if isinstance(sign.func, AbstractUDF):
             if not isinstance(sign.func, cls):  # type: ignore[unreachable]
                 raise UdfError(
-                    f"cannot create UDF: provided UDF '{sign.func.__name__}'"
+                    f"cannot create UDF: provided UDF '{type(sign.func).__name__}'"
                     f" must be a child of target class '{cls.__name__}'",
                 )
             result = sign.func
@@ -212,13 +273,6 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__

-    def set_catalog(self, catalog):
-        self._catalog = catalog.copy(db=False)
-
-    @property
-    def catalog(self):
-        return self._catalog
-
     def to_udf_wrapper(self, batch: int = 1) -> UDFAdapter:
         assert self.params_spec is not None
         properties = UDFProperties(
@@ -229,11 +283,9 @@ class UDFBase(AbstractUDF):
     def validate_results(self, results, *args, **kwargs):
         return results

-    def …
-        if self.…
-            objs = self.…
-        elif self.is_input_batched:
-            objs = zip(*self._parse_rows(rows[0], cache, download_cb))
+    def run_once(self, rows, cache, download_cb):
+        if self.is_input_batched:
+            objs = zip(*self._parse_rows(rows, cache, download_cb))
         else:
             objs = self._parse_rows([rows], cache, download_cb)[0]

@@ -259,8 +311,8 @@ class UDFBase(AbstractUDF):
         ):
             res = list(res)
             assert len(res) == len(
-                rows
-            ), f"{self.name} returns {len(res)} rows while len(rows…
+                rows
+            ), f"{self.name} returns {len(res)} rows while {len(rows)} expected"

         return res

@@ -283,41 +335,11 @@ class UDFBase(AbstractUDF):
             for obj in obj_row:
                 if isinstance(obj, File):
                     obj._set_stream(
-                        self.…
+                        self.catalog, caching_enabled=cache, download_cb=download_cb
                     )
             objs.append(obj_row)
         return objs

-    def _parse_grouped_rows(self, group, cache, download_cb):
-        spec_map = {}
-        output_map = {}
-        for name, (anno, subtree) in self.params.tree.items():
-            if ModelStore.is_pydantic(anno):
-                length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
-            else:
-                length = 1
-            spec_map[name] = anno, length
-            output_map[name] = []
-
-        for flat_obj in group:
-            position = 0
-            for signal, (cls, length) in spec_map.items():
-                slice = flat_obj[position : position + length]
-                position += length
-
-                if ModelStore.is_pydantic(cls):
-                    obj = cls(**unflatten_to_json(cls, slice))
-                else:
-                    obj = slice[0]
-
-                if isinstance(obj, File):
-                    obj._set_stream(
-                        self._catalog, caching_enabled=cache, download_cb=download_cb
-                    )
-                output_map[signal].append(obj)
-
-        return list(output_map.values())
-
     def process_safe(self, obj_rows):
         try:
             result_objs = self.process(*obj_rows)
datachain/lib/udf_signature.py
CHANGED
@@ -1,10 +1,11 @@
 import inspect
 from collections.abc import Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Callable, …
+from typing import Callable, Union, get_args, get_origin

 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
+from datachain.lib.udf import UDFBase
 from datachain.lib.utils import AbstractUDF, DataChainParamsError


@@ -16,7 +17,7 @@ class UdfSignatureError(DataChainParamsError):

 @dataclass
 class UdfSignature:
-    func: Callable
+    func: Union[Callable, UDFBase]
     params: Sequence[str]
     output_schema: SignalSchema

@@ -27,7 +28,7 @@ class UdfSignature:
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: …
+        func: Union[None, UDFBase, Callable] = None,
         params: Union[None, str, Sequence[str]] = None,
         output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
         is_generator: bool = True,
@@ -39,6 +40,7 @@ class UdfSignature:
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
+        udf_func: Union[UDFBase, Callable]
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
@@ -53,7 +55,7 @@ class UdfSignature:
             udf_func = func
             signal_name = None

-        if not callable(udf_func):
+        if not isinstance(udf_func, UDFBase) and not callable(udf_func):
             raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")

         func_params_map_sign, func_outs_sign, is_iterator = (
@@ -73,7 +75,7 @@ class UdfSignature:
         if not func_outs_sign:
             raise UdfSignatureError(
                 chain,
-                f"outputs are not defined in function '{udf_func…
+                f"outputs are not defined in function '{udf_func}'"
                 " hints or 'output'",
             )

@@ -154,7 +156,7 @@ class UdfSignature:

     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Callable
+        chain: str, udf_func: Union[Callable, UDFBase]
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
datachain/query/dataset.py
CHANGED
@@ -42,6 +42,7 @@ from datachain.data_storage.schema import (
 )
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions import rand
 from datachain.utils import (
@@ -53,7 +54,6 @@ from datachain.utils import (

 from .schema import C, UDFParamSpec, normalize_param
 from .session import Session
-from .udf import UDFBase

 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
@@ -299,7 +299,7 @@ def adjust_outputs(
     return row


-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: …) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -320,7 +320,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: …
+    udf: UDFAdapter,
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

 @frozen
 class UDFStep(Step, ABC):
-    udf: …
+    udf: UDFAdapter
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -1037,7 +1037,7 @@ class DatasetQuery:
         session: Optional[Session] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
         in_memory: bool = False,
-    ):
+    ) -> None:
         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
         self.catalog = catalog or self.session.catalog
         self.steps: list[Step] = []
@@ -1465,7 +1465,7 @@ class DatasetQuery:
     @detach
     def add_signals(
         self,
-        udf: …
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1509,7 +1509,7 @@ class DatasetQuery:
     @detach
     def generate(
         self,
-        udf: …
+        udf: UDFAdapter,
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
datachain/query/dispatch.py
CHANGED
@@ -13,6 +13,7 @@ from multiprocess import get_context

 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
+from datachain.lib.udf import UDFAdapter, UDFResult
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
@@ -27,7 +28,6 @@ from datachain.query.queue import (
     put_into_queue,
     unmarshal,
 )
-from datachain.query.udf import UDFBase, UDFResult
 from datachain.utils import batched_it

 DEFAULT_BATCH_SIZE = 10000
@@ -336,7 +336,7 @@ class ProcessedCallback(Callback):
 @attrs.define
 class UDFWorker:
     catalog: Catalog
-    udf: …
+    udf: UDFAdapter
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
     is_generator: bool
datachain/query/session.py
CHANGED
@@ -1,5 +1,8 @@
 import atexit
+import logging
+import os
 import re
+import sys
 from typing import TYPE_CHECKING, Optional
 from uuid import uuid4

@@ -9,6 +12,8 @@ from datachain.error import TableMissingError
 if TYPE_CHECKING:
     from datachain.catalog import Catalog

+logger = logging.getLogger("datachain")
+

 class Session:
     """
@@ -35,6 +40,7 @@ class Session:

     GLOBAL_SESSION_CTX: Optional["Session"] = None
     GLOBAL_SESSION: Optional["Session"] = None
+    ORIGINAL_EXCEPT_HOOK = None

     DATASET_PREFIX = "session_"
     GLOBAL_SESSION_NAME = "global"
@@ -58,6 +64,7 @@ class Session:

         session_uuid = uuid4().hex[: self.SESSION_UUID_LEN]
         self.name = f"{name}_{session_uuid}"
+        self.job_id = os.getenv("DATACHAIN_JOB_ID") or str(uuid4())
         self.is_new_catalog = not catalog
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
@@ -67,6 +74,9 @@ class Session:
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type:
+            self._cleanup_created_versions(self.name)
+
         self._cleanup_temp_datasets()
         if self.is_new_catalog:
             self.catalog.metastore.close_on_exit()
@@ -88,6 +98,21 @@ class Session:
         except TableMissingError:
             pass

+    def _cleanup_created_versions(self, job_id: str) -> None:
+        versions = self.catalog.metastore.get_job_dataset_versions(job_id)
+        if not versions:
+            return
+
+        datasets = {}
+        for dataset_name, version in versions:
+            if dataset_name not in datasets:
+                datasets[dataset_name] = self.catalog.get_dataset(dataset_name)
+            dataset = datasets[dataset_name]
+            logger.info(
+                "Removing dataset version %s@%s due to exception", dataset_name, version
+            )
+            self.catalog.remove_dataset_version(dataset, version)
+
     @classmethod
     def get(
         cls,
@@ -114,9 +139,23 @@ class Session:
                 in_memory=in_memory,
             )
             cls.GLOBAL_SESSION = cls.GLOBAL_SESSION_CTX.__enter__()
+
             atexit.register(cls._global_cleanup)
+            cls.ORIGINAL_EXCEPT_HOOK = sys.excepthook
+            sys.excepthook = cls.except_hook
+
         return cls.GLOBAL_SESSION

+    @staticmethod
+    def except_hook(exc_type, exc_value, exc_traceback):
+        Session._global_cleanup()
+        if Session.GLOBAL_SESSION_CTX is not None:
+            job_id = Session.GLOBAL_SESSION_CTX.job_id
+            Session.GLOBAL_SESSION_CTX._cleanup_created_versions(job_id)
+
+        if Session.ORIGINAL_EXCEPT_HOOK:
+            Session.ORIGINAL_EXCEPT_HOOK(exc_type, exc_value, exc_traceback)
+
     @classmethod
     def cleanup_for_tests(cls):
         if cls.GLOBAL_SESSION_CTX is not None:
@@ -125,6 +164,9 @@ class Session:
             cls.GLOBAL_SESSION_CTX = None
             atexit.unregister(cls._global_cleanup)

+        if cls.ORIGINAL_EXCEPT_HOOK:
+            sys.excepthook = cls.ORIGINAL_EXCEPT_HOOK
+
     @staticmethod
     def _global_cleanup():
         if Session.GLOBAL_SESSION_CTX is not None: