datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. The information is provided for informational purposes only.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/tar.py
CHANGED
@@ -6,12 +6,11 @@ from datachain.lib.file import File, TarVFile
 
 
 def build_tar_member(parent: File, info: tarfile.TarInfo) -> File:
-    new_parent = parent.get_full_name()
     etag_string = "-".join([parent.etag, info.name, str(info.mtime)])
     etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
     return File(
         source=parent.source,
-        path=f"{
+        path=f"{parent.path}/{info.name}",
         version=parent.version,
         size=info.size,
         etag=etag,

datachain/lib/text.py
CHANGED
@@ -1,16 +1,17 @@
-from
+from collections.abc import Callable
+from typing import Any
 
 import torch
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 
 def convert_text(
-    text:
-    tokenizer:
-    tokenizer_kwargs:
-    encoder:
-    device:
-) ->
+    text: str | list[str],
+    tokenizer: Callable | None = None,
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> str | list[str] | torch.Tensor:
     """
     Tokenize and otherwise transform text.
 

datachain/lib/udf.py
CHANGED
@@ -1,9 +1,8 @@
-import
-import traceback
+import hashlib
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
 from functools import partial
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -12,11 +11,10 @@ from pydantic import BaseModel
 from datachain.asyn import AsyncMapper
 from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
+from datachain.hash_utils import hash_callable
 from datachain.lib.convert.flatten import flatten
-from datachain.lib.
-from datachain.lib.
-from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
-from datachain.progress import CombinedDownloadCallback
+from datachain.lib.file import DataModel, File
+from datachain.lib.utils import AbstractUDF, DataChainParamsError
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -42,8 +40,44 @@ T = TypeVar("T", bound=Sequence[Any])
 
 
 class UdfError(DataChainParamsError):
-
-
+    """Exception raised for UDF-related errors."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__!s}: {self.message!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.message,)
+
+
+class UdfRunError(Exception):
+    """Exception raised when UDF execution fails."""
+
+    def __init__(
+        self,
+        error: Exception | str,
+        stacktrace: str | None = None,
+        udf_name: str | None = None,
+    ) -> None:
+        self.error = error
+        self.stacktrace = stacktrace
+        self.udf_name = udf_name
+        super().__init__(str(error))
+
+    def __str__(self) -> str:
+        if isinstance(self.error, UdfRunError):
+            return str(self.error)
+        if isinstance(self.error, Exception):
+            return f"{self.error.__class__.__name__!s}: {self.error!s}"
+        return f"{self.__class__.__name__!s}: {self.error!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.error, self.stacktrace, self.udf_name)
 
 
 ColumnType = Any
@@ -56,38 +90,26 @@ UDFOutputSpec = Mapping[str, ColumnType]
 UDFResult = dict[str, Any]
 
 
-@attrs.define
-class UDFProperties:
-    udf: "UDFAdapter"
-
-    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
-        return self.udf.get_batching(use_partitioning)
-
-    @property
-    def batch(self):
-        return self.udf.batch
-
-
 @attrs.define(slots=False)
 class UDFAdapter:
     inner: "UDFBase"
     output: UDFOutputSpec
+    batch_size: int | None = None
     batch: int = 1
 
+    def hash(self) -> str:
+        return self.inner.hash()
+
     def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
         if use_partitioning:
             return Partition()
+
         if self.batch == 1:
             return NoBatching()
         if self.batch > 1:
             return Batch(self.batch)
         raise ValueError(f"invalid batch size {self.batch}")
 
-    @property
-    def properties(self):
-        # For backwards compatibility.
-        return UDFProperties(self)
-
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -164,10 +186,31 @@ class UDFBase(AbstractUDF):
     prefetch: int = 0
 
     def __init__(self):
-        self.params:
+        self.params: SignalSchema | None = None
         self.output = None
         self._func = None
 
+    def hash(self) -> str:
+        """
+        Creates SHA hash of this UDF function. It takes into account function,
+        inputs and outputs.
+
+        For function-based UDFs, hashes self._func.
+        For class-based UDFs, hashes the process method.
+        """
+        # Hash user code: either _func (function-based) or process method (class-based)
+        func_to_hash = self._func if self._func else self.process
+
+        parts = [
+            hash_callable(func_to_hash),
+            self.params.hash() if self.params else "",
+            self.output.hash(),
+        ]
+
+        return hashlib.sha256(
+            b"".join([bytes.fromhex(part) for part in parts])
+        ).hexdigest()
+
     def process(self, *args, **kwargs):
         """Processing function that needs to be defined by user"""
         if not self._func:
@@ -188,7 +231,7 @@ class UDFBase(AbstractUDF):
         self,
         sign: "UdfSignature",
         params: "SignalSchema",
-        func:
+        func: Callable | None,
     ):
         self.params = params
         self.output = sign.output_schema
@@ -219,14 +262,31 @@ class UDFBase(AbstractUDF):
     def name(self):
         return self.__class__.__name__
 
+    @property
+    def verbose_name(self):
+        """Returns the name of the function or class that implements the UDF."""
+        if self._func and callable(self._func):
+            if hasattr(self._func, "__name__"):
+                return self._func.__name__
+            if hasattr(self._func, "__class__") and hasattr(
+                self._func.__class__, "__name__"
+            ):
+                return self._func.__class__.__name__
+        return "<unknown>"
+
     @property
     def signal_names(self) -> Iterable[str]:
         return self.output.to_udf_spec().keys()
 
-    def to_udf_wrapper(
+    def to_udf_wrapper(
+        self,
+        batch_size: int | None = None,
+        batch: int = 1,
+    ) -> UDFAdapter:
         return UDFAdapter(
             self,
             self.output.to_udf_spec(),
+            batch_size,
             batch,
         )
 
@@ -255,38 +315,37 @@ class UDFBase(AbstractUDF):
 
     def _parse_row(
         self, row_dict: RowDict, catalog: "Catalog", cache: bool, download_cb: Callback
-    ) -> list[
+    ) -> list[Any]:
         assert self.params
         row = [row_dict[p] for p in self.params.to_udf_spec()]
         obj_row = self.params.row_to_objs(row)
         for obj in obj_row:
-
-            obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
+            self._set_stream_recursive(obj, catalog, cache, download_cb)
         return obj_row
 
+    def _set_stream_recursive(
+        self, obj: Any, catalog: "Catalog", cache: bool, download_cb: Callback
+    ) -> None:
+        """Recursively set the catalog stream on all File objects within an object."""
+        if isinstance(obj, File):
+            obj._set_stream(catalog, caching_enabled=cache, download_cb=download_cb)
+
+        # Check all fields for nested File objects, but only for DataModel objects
+        if isinstance(obj, DataModel):
+            for field_name in type(obj).model_fields:
+                field_value = getattr(obj, field_name, None)
+                if isinstance(field_value, DataModel):
+                    self._set_stream_recursive(field_value, catalog, cache, download_cb)
+
     def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
        return self._parse_row(row_dict, catalog, cache, download_cb)
 
     def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
-    def process_safe(self, obj_rows):
-        try:
-            result_objs = self.process(*obj_rows)
-        except Exception as e:  # noqa: BLE001
-            msg = f"============== Error in user code: '{self.name}' =============="
-            print(msg)
-            exc_type, exc_value, exc_traceback = sys.exc_info()
-            traceback.print_exception(exc_type, exc_value, exc_traceback.tb_next)
-            print("=" * len(msg))
-            raise DataChainError(
-                f"Error in user code in class '{self.name}': {e!s}"
-            ) from None
-        return result_objs
-
 
 def noop(*args, **kwargs):
     pass
@@ -294,11 +353,11 @@ def noop(*args, **kwargs):
 
 async def _prefetch_input(
     row: T,
-    download_cb:
+    download_cb: Callback | None = None,
     after_prefetch: "Callable[[], None]" = noop,
 ) -> T:
     for obj in row:
-        if isinstance(obj, File) and await obj._prefetch(download_cb):
+        if isinstance(obj, File) and obj.path and await obj._prefetch(download_cb):
             after_prefetch()
     return row
 
@@ -317,8 +376,8 @@ def _remove_prefetched(row: T) -> None:
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
-    download_cb:
-    after_prefetch:
+    download_cb: Callback | None = None,
+    after_prefetch: Callable[[], None] | None = None,
     remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
     if not prefetch:
@@ -327,8 +386,9 @@ def _prefetch_inputs(
 
     if after_prefetch is None:
         after_prefetch = noop
-        if
-
+        if download_cb and hasattr(download_cb, "increment_file_count"):
+            increment_file_count: Callable[[], None] = download_cb.increment_file_count
+            after_prefetch = increment_file_count
 
     f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
     mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
@@ -384,9 +444,12 @@ class Mapper(UDFBase):
 
         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
-                result_objs = self.
+                result_objs = self.process(*udf_args)
                 udf_output = self._flatten_row(result_objs)
-                output = [
+                output = [
+                    {"sys__id": id_}
+                    | dict(zip(self.signal_names, udf_output, strict=False))
+                ]
                 processed_cb.relative_update(1)
                 yield output
 
@@ -394,11 +457,27 @@ class Mapper(UDFBase):
 
 
 class BatchMapper(UDFBase):
-    """Inherit from this class to pass to `DataChain.batch_map()`.
+    """Inherit from this class to pass to `DataChain.batch_map()`.
+
+    .. deprecated:: 0.29.0
+        This class is deprecated and will be removed in a future version.
+        Use `Aggregator` instead, which provides the similar functionality.
+    """
 
     is_input_batched = True
     is_output_batched = True
 
+    def __init__(self):
+        import warnings
+
+        warnings.warn(
+            "BatchMapper is deprecated and will be removed in a future version. "
+            "Use Aggregator instead, which provides the similar functionality.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__()
+
     def run(
         self,
         udf_fields: Sequence[str],
@@ -411,24 +490,26 @@ class BatchMapper(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-            n_rows = len(batch
+            n_rows = len(batch)
             row_ids, *udf_args = zip(
                 *[
                     self._prepare_row_and_id(
                        row, udf_fields, catalog, cache, download_cb
                    )
-                    for row in batch
-                ]
+                    for row in batch
+                ],
+                strict=False,
             )
-            result_objs = list(self.
+            result_objs = list(self.process(*udf_args))
             n_objs = len(result_objs)
             assert n_objs == n_rows, (
                 f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
            )
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = [
-                {"sys__id": row_id}
-
+                {"sys__id": row_id}
+                | dict(zip(self.signal_names, signals, strict=False))
+                for row_id, signals in zip(row_ids, udf_outputs, strict=False)
             ]
             processed_cb.relative_update(n_rows)
             yield output
@@ -461,10 +542,10 @@ class Generator(UDFBase):
         )
 
         def _process_row(row):
-            with safe_closing(self.
+            with safe_closing(self.process(*row)) as result_objs:
                 for result_obj in result_objs:
                     udf_output = self._flatten_row(result_obj)
-                    yield dict(zip(self.signal_names, udf_output))
+                    yield dict(zip(self.signal_names, udf_output, strict=False))
 
         prepared_inputs = _prepare_rows(udf_inputs)
         prepared_inputs = _prefetch_inputs(
@@ -474,8 +555,9 @@ class Generator(UDFBase):
             remove_prefetched=bool(self.prefetch) and not cache,
         )
         with closing(prepared_inputs):
-            for row in
+            for row in prepared_inputs:
                 yield _process_row(row)
+                processed_cb.relative_update(1)
 
         self.teardown()
 
@@ -488,7 +570,7 @@ class Aggregator(UDFBase):
 
     def run(
         self,
-        udf_fields:
+        udf_fields: Sequence[str],
         udf_inputs: Iterable[RowsOutputBatch],
         catalog: "Catalog",
         cache: bool,
@@ -498,16 +580,22 @@ class Aggregator(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-
-
-
-
-
-
-
+            prepared_rows = [
+                self._prepare_row(row, udf_fields, catalog, cache, download_cb)
+                for row in batch
+            ]
+            batched_args = zip(*prepared_rows, strict=False)
+            # Convert aggregated column values to lists. This keeps behavior
+            # consistent with the type hints promoted in the public API.
+            udf_args = [
+                list(arg) if isinstance(arg, tuple) else arg for arg in batched_args
+            ]
+            result_objs = self.process(*udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (
-
+            output = (
+                dict(zip(self.signal_names, row, strict=False)) for row in udf_outputs
+            )
+            processed_cb.relative_update(len(batch))
             yield output
 
         self.teardown()

CHANGED
@@ -1,12 +1,12 @@
 import inspect
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any,
+from typing import Any, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import UDFBase
-from datachain.lib.utils import AbstractUDF, DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError, callable_name
 
 
 class UdfSignatureError(DataChainParamsError):
@@ -16,9 +16,9 @@ class UdfSignatureError(DataChainParamsError):
 
 
 @dataclass
-class UdfSignature:
-    func:
-    params: dict[str,
+class UdfSignature:  # noqa: PLW1641
+    func: Callable | UDFBase
+    params: dict[str, DataType | Any]
     output_schema: SignalSchema
 
     DEFAULT_RETURN_TYPE = str
@@ -28,24 +28,29 @@ class UdfSignature:
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func:
-        params:
-        output:
+        func: UDFBase | Callable | None = None,
+        params: str | Sequence[str] | None = None,
+        output: DataType | Sequence[str] | dict[str, DataType] | None = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
         if len(signal_map) > 1:
             raise UdfSignatureError(
                 chain,
-
-
+                (
+                    f"multiple signals '{keys}' are not supported in processors."
+                    " Chain multiple processors instead.",
+                ),
             )
-        udf_func:
+        udf_func: UDFBase | Callable
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
                     chain,
-
+                    (
+                        "processor can't have signal "
+                        f"'{keys}' with function '{callable_name(func)}'"
+                    ),
                 )
             signal_name, udf_func = next(iter(signal_map.items()))
         else:
@@ -56,13 +61,16 @@ class UdfSignature:
             signal_name = None
 
         if not isinstance(udf_func, UDFBase) and not callable(udf_func):
-            raise UdfSignatureError(
+            raise UdfSignatureError(
+                chain,
+                f"UDF '{callable_name(udf_func)}' is not callable",
+            )
 
         func_params_map_sign, func_outs_sign, is_iterator = cls._func_signature(
             chain, udf_func
         )
 
-        udf_params: dict[str,
+        udf_params: dict[str, DataType | Any] = {}
         if params:
             udf_params = (
                 {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
@@ -76,14 +84,15 @@ class UdfSignature:
             }
 
         if output:
+            # Use the actual resolved function (udf_func) for clearer error messages
             udf_output_map = UdfSignature._validate_output(
-                chain, signal_name,
+                chain, signal_name, udf_func, func_outs_sign, output
             )
         else:
             if not func_outs_sign:
                 raise UdfSignatureError(
                     chain,
-                    f"outputs are not defined in function '{udf_func}'"
+                    f"outputs are not defined in function '{callable_name(udf_func)}'"
                     " hints or 'output'",
                 )
 
@@ -97,9 +106,12 @@ class UdfSignature:
         if is_generator and not is_iterator:
             raise UdfSignatureError(
                 chain,
-
-
-
+                (
+                    f"function '{callable_name(udf_func)}' cannot be used in "
+                    "generator/aggregator because it returns a type that is "
+                    "not Iterator/Generator. "
+                    f"Instead, it returns '{func_outs_sign}'"
+                ),
             )
 
         if isinstance(func_outs_sign, tuple):
@@ -124,11 +136,14 @@ class UdfSignature:
             if len(func_outs_sign) != len(output):
                 raise UdfSignatureError(
                     chain,
-
-
+                    (
+                        f"length of outputs names ({len(output)}) and function "
+                        f"'{callable_name(func)}' return type length "
+                        f"({len(func_outs_sign)}) does not match"
+                    ),
                 )
 
-            udf_output_map = dict(zip(output, func_outs_sign))
+            udf_output_map = dict(zip(output, func_outs_sign, strict=False))
         elif isinstance(output, dict):
             for key, value in output.items():
                 if not isinstance(key, str):
@@ -164,7 +179,7 @@ class UdfSignature:
 
     @staticmethod
     def _func_signature(
-        chain: str, udf_func:
+        chain: str, udf_func: Callable | UDFBase
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
@@ -183,17 +198,27 @@ class UdfSignature:
         orig = get_origin(anno)
         if inspect.isclass(orig) and issubclass(orig, Iterator):
             args = get_args(anno)
-
-
-
-
-
-
-
-            )
-
-
-
+            # For typing.Iterator without type args, default to DEFAULT_RETURN_TYPE
+            if len(args) == 0:
+                is_iterator = True
+                anno = UdfSignature.DEFAULT_RETURN_TYPE
+                orig = get_origin(anno)
+            else:
+                # typing.Generator[T, S, R] has 3 args; allow that shape
+                if len(args) > 1 and not (
+                    issubclass(orig, Generator) and len(args) == 3
+                ):
+                    raise UdfSignatureError(
+                        chain,
+                        (
+                            f"function '{callable_name(func)}' should return "
+                            "iterator with a single value while "
+                            f"'{args}' are specified"
+                        ),
+                    )
+                is_iterator = True
+                anno = args[0]
+                orig = get_origin(anno)
 
         if orig and orig is tuple:
             output_types = tuple(get_args(anno))  # type: ignore[assignment]