datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/udf.py
CHANGED

@@ -1,9 +1,8 @@
-import sys
-import traceback
+import hashlib
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
 from functools import partial
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -12,9 +11,10 @@ from pydantic import BaseModel
 from datachain.asyn import AsyncMapper
 from datachain.cache import temporary_cache
 from datachain.dataset import RowDict
+from datachain.hash_utils import hash_callable
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.file import DataModel, File
-from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -40,8 +40,44 @@ T = TypeVar("T", bound=Sequence[Any])
 
 
 class UdfError(DataChainParamsError):
-    def __init__(self, msg):
-        super().__init__(f"UDF error: {msg}")
+    """Exception raised for UDF-related errors."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__!s}: {self.message!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.message,)
+
+
+class UdfRunError(Exception):
+    """Exception raised when UDF execution fails."""
+
+    def __init__(
+        self,
+        error: Exception | str,
+        stacktrace: str | None = None,
+        udf_name: str | None = None,
+    ) -> None:
+        self.error = error
+        self.stacktrace = stacktrace
+        self.udf_name = udf_name
+        super().__init__(str(error))
+
+    def __str__(self) -> str:
+        if isinstance(self.error, UdfRunError):
+            return str(self.error)
+        if isinstance(self.error, Exception):
+            return f"{self.error.__class__.__name__!s}: {self.error!s}"
+        return f"{self.__class__.__name__!s}: {self.error!s}"
+
+    def __reduce__(self):
+        """Custom reduce method for pickling."""
+        return self.__class__, (self.error, self.stacktrace, self.udf_name)
 
 
 ColumnType = Any
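
Before the remaining udf.py hunks, one point worth spelling out: the __reduce__ overrides on the new exception classes exist because UDF errors can cross process boundaries, and BaseException pickles by replaying self.args into __init__. A standalone sketch of the failure mode this guards against (illustrative names, not DataChain code):

import pickle

class Fragile(Exception):
    def __init__(self, error, udf_name):
        super().__init__(f"{udf_name}: {error}")  # self.args holds one item

class Robust(Exception):
    def __init__(self, error, udf_name=None):
        self.error = error
        self.udf_name = udf_name
        super().__init__(str(error))

    def __reduce__(self):
        # Replay the original constructor arguments instead of self.args
        return self.__class__, (self.error, self.udf_name)

try:
    pickle.loads(pickle.dumps(Fragile("boom", "my_udf")))
except TypeError as exc:
    print(f"round-trip failed: {exc}")  # __init__ is missing 'udf_name'

clone = pickle.loads(pickle.dumps(Robust("boom", "my_udf")))
print(clone.udf_name)  # my_udf

Returning the constructor arguments from __reduce__ keeps attributes such as stacktrace and udf_name intact after unpickling.
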
@@ -54,25 +90,16 @@ UDFOutputSpec = Mapping[str, ColumnType]
 UDFResult = dict[str, Any]
 
 
-@attrs.define
-class UDFProperties:
-    udf: "UDFAdapter"
-
-    def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
-        return self.udf.get_batching(use_partitioning)
-
-    @property
-    def batch_rows(self):
-        return self.udf.batch_rows
-
-
 @attrs.define(slots=False)
 class UDFAdapter:
     inner: "UDFBase"
     output: UDFOutputSpec
-    batch_rows: Optional[int] = None
+    batch_size: int | None = None
     batch: int = 1
 
+    def hash(self) -> str:
+        return self.inner.hash()
+
     def get_batching(self, use_partitioning: bool = False) -> BatchingStrategy:
         if use_partitioning:
             return Partition()
@@ -83,11 +110,6 @@ class UDFAdapter:
             return Batch(self.batch)
         raise ValueError(f"invalid batch size {self.batch}")
 
-    @property
-    def properties(self):
-        # For backwards compatibility.
-        return UDFProperties(self)
-
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -164,10 +186,31 @@ class UDFBase(AbstractUDF):
     prefetch: int = 0
 
     def __init__(self):
-        self.params: Optional[SignalSchema] = None
+        self.params: SignalSchema | None = None
         self.output = None
         self._func = None
 
+    def hash(self) -> str:
+        """
+        Creates SHA hash of this UDF function. It takes into account function,
+        inputs and outputs.
+
+        For function-based UDFs, hashes self._func.
+        For class-based UDFs, hashes the process method.
+        """
+        # Hash user code: either _func (function-based) or process method (class-based)
+        func_to_hash = self._func if self._func else self.process
+
+        parts = [
+            hash_callable(func_to_hash),
+            self.params.hash() if self.params else "",
+            self.output.hash(),
+        ]
+
+        return hashlib.sha256(
+            b"".join([bytes.fromhex(part) for part in parts])
+        ).hexdigest()
+
     def process(self, *args, **kwargs):
         """Processing function that needs to be defined by user"""
         if not self._func:
@@ -188,7 +231,7 @@ class UDFBase(AbstractUDF):
         self,
         sign: "UdfSignature",
         params: "SignalSchema",
-        func: Optional[Callable],
+        func: Callable | None,
     ):
         self.params = params
         self.output = sign.output_schema
@@ -237,13 +280,13 @@ class UDFBase(AbstractUDF):
 
     def to_udf_wrapper(
         self,
-        batch_rows: Optional[int] = None,
+        batch_size: int | None = None,
         batch: int = 1,
     ) -> UDFAdapter:
         return UDFAdapter(
             self,
             self.output.to_udf_spec(),
-            batch_rows,
+            batch_size,
             batch,
         )
 
@@ -295,28 +338,14 @@ class UDFBase(AbstractUDF):
                 self._set_stream_recursive(field_value, catalog, cache, download_cb)
 
     def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         return self._parse_row(row_dict, catalog, cache, download_cb)
 
     def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input
 
-    def process_safe(self, obj_rows):
-        try:
-            result_objs = self.process(*obj_rows)
-        except Exception as e:  # noqa: BLE001
-            msg = f"============== Error in user code: '{self.name}' =============="
-            print(msg)
-            exc_type, exc_value, exc_traceback = sys.exc_info()
-            traceback.print_exception(exc_type, exc_value, exc_traceback.tb_next)
-            print("=" * len(msg))
-            raise DataChainError(
-                f"Error in user code in class '{self.name}': {e!s}"
-            ) from None
-        return result_objs
-
 
 def noop(*args, **kwargs):
     pass
@@ -324,7 +353,7 @@ def noop(*args, **kwargs):
 
 async def _prefetch_input(
     row: T,
-    download_cb: Optional[Callback] = None,
+    download_cb: Callback | None = None,
     after_prefetch: "Callable[[], None]" = noop,
 ) -> T:
     for obj in row:
@@ -347,8 +376,8 @@ def _remove_prefetched(row: T) -> None:
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
-    download_cb: Optional[Callback] = None,
-    after_prefetch: Optional[Callable[[], None]] = None,
+    download_cb: Callback | None = None,
+    after_prefetch: Callable[[], None] | None = None,
     remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
     if not prefetch:
@@ -415,9 +444,12 @@ class Mapper(UDFBase):
 
         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
-                result_objs = self.process_safe(udf_args)
+                result_objs = self.process(*udf_args)
                 udf_output = self._flatten_row(result_objs)
-                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                output = [
+                    {"sys__id": id_}
+                    | dict(zip(self.signal_names, udf_output, strict=False))
+                ]
                 processed_cb.relative_update(1)
                 yield output
 
@@ -465,17 +497,19 @@ class BatchMapper(UDFBase):
                         row, udf_fields, catalog, cache, download_cb
                     )
                     for row in batch
-                ]
+                ],
+                strict=False,
             )
-            result_objs = list(self.process_safe(udf_args))
+            result_objs = list(self.process(*udf_args))
             n_objs = len(result_objs)
             assert n_objs == n_rows, (
                 f"{self.name} returns {n_objs} rows, but {n_rows} were expected"
             )
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = [
-                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-                for row_id, signals in zip(row_ids, udf_outputs)
+                {"sys__id": row_id}
+                | dict(zip(self.signal_names, signals, strict=False))
+                for row_id, signals in zip(row_ids, udf_outputs, strict=False)
             ]
             processed_cb.relative_update(n_rows)
             yield output
@@ -508,10 +542,10 @@ class Generator(UDFBase):
         )
 
         def _process_row(row):
-            with safe_closing(self.process_safe(row)) as result_objs:
+            with safe_closing(self.process(*row)) as result_objs:
                 for result_obj in result_objs:
                     udf_output = self._flatten_row(result_obj)
-                    yield dict(zip(self.signal_names, udf_output))
+                    yield dict(zip(self.signal_names, udf_output, strict=False))
 
         prepared_inputs = _prepare_rows(udf_inputs)
         prepared_inputs = _prefetch_inputs(
@@ -546,15 +580,21 @@ class Aggregator(UDFBase):
         self.setup()
 
         for batch in udf_inputs:
-            udf_args = zip(
-                *[
-                    self._prepare_row(row, udf_fields, catalog, cache, download_cb)
-                    for row in batch
-                ]
-            )
-            result_objs = self.process_safe(udf_args)
+            prepared_rows = [
+                self._prepare_row(row, udf_fields, catalog, cache, download_cb)
+                for row in batch
+            ]
+            batched_args = zip(*prepared_rows, strict=False)
+            # Convert aggregated column values to lists. This keeps behavior
+            # consistent with the type hints promoted in the public API.
+            udf_args = [
+                list(arg) if isinstance(arg, tuple) else arg for arg in batched_args
+            ]
+            result_objs = self.process(*udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            output = (
+                dict(zip(self.signal_names, row, strict=False)) for row in udf_outputs
+            )
             processed_cb.relative_update(len(batch))
             yield output
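
The new UDFBase.hash() combines three SHA-256 hex digests (function, input schema, output schema) by converting each digest back to raw bytes and hashing the concatenation. A minimal self-contained sketch of that scheme, with stand-in digests where DataChain uses hash_callable() and the schemas' hash() methods:

import hashlib

def sha256_hex(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()

# Stand-ins for hash_callable(func_to_hash), self.params.hash(),
# and self.output.hash() from the diff above
parts = [
    sha256_hex(b"def process(file): ..."),
    sha256_hex(b"params schema"),
    sha256_hex(b"output schema"),
]

# Concatenate the raw digest bytes and hash once more; changing any part
# (function body, inputs, or outputs) changes the combined UDF hash
combined = hashlib.sha256(b"".join(bytes.fromhex(p) for p in parts)).hexdigest()
print(combined)

Since bytes.fromhex("") is simply empty, a UDF with no input schema can contribute an empty part without special-casing.
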
datachain/lib/udf_signature.py
CHANGED

@@ -1,12 +1,12 @@
 import inspect
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union, get_args, get_origin
+from typing import Any, get_args, get_origin
 
 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import UDFBase
-from datachain.lib.utils import AbstractUDF, DataChainParamsError
+from datachain.lib.utils import AbstractUDF, DataChainParamsError, callable_name
 
 
 class UdfSignatureError(DataChainParamsError):
@@ -17,8 +17,8 @@ class UdfSignatureError(DataChainParamsError):
 
 @dataclass
 class UdfSignature:  # noqa: PLW1641
-    func: Union[Callable, UDFBase]
-    params: dict[str, Union[DataType, Any]]
+    func: Callable | UDFBase
+    params: dict[str, DataType | Any]
     output_schema: SignalSchema
 
     DEFAULT_RETURN_TYPE = str
@@ -28,24 +28,29 @@ class UdfSignature:  # noqa: PLW1641
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Optional[Union[UDFBase, Callable]] = None,
-        params: Optional[Union[str, Sequence[str]]] = None,
-        output: Optional[Union[DataType, Sequence[str], dict[str, DataType]]] = None,
+        func: UDFBase | Callable | None = None,
+        params: str | Sequence[str] | None = None,
+        output: DataType | Sequence[str] | dict[str, DataType] | None = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
        keys = ", ".join(signal_map.keys())
        if len(signal_map) > 1:
            raise UdfSignatureError(
                chain,
-                f"multiple signals '{keys}' are not supported in processors."
-                " Chain multiple processors instead",
+                (
+                    f"multiple signals '{keys}' are not supported in processors."
+                    " Chain multiple processors instead."
+                ),
            )
-        udf_func: Union[UDFBase, Callable]
+        udf_func: UDFBase | Callable
        if len(signal_map) == 1:
            if func is not None:
                raise UdfSignatureError(
                    chain,
-                    f"processor can't have signal '{keys}' with function '{func}'",
+                    (
+                        "processor can't have signal "
+                        f"'{keys}' with function '{callable_name(func)}'"
+                    ),
                )
            signal_name, udf_func = next(iter(signal_map.items()))
        else:
@@ -56,13 +61,16 @@ class UdfSignature:  # noqa: PLW1641
            signal_name = None
 
        if not isinstance(udf_func, UDFBase) and not callable(udf_func):
-            raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")
+            raise UdfSignatureError(
+                chain,
+                f"UDF '{callable_name(udf_func)}' is not callable",
+            )
 
        func_params_map_sign, func_outs_sign, is_iterator = cls._func_signature(
            chain, udf_func
        )
 
-        udf_params: dict[str, Union[DataType, Any]] = {}
+        udf_params: dict[str, DataType | Any] = {}
        if params:
            udf_params = (
                {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
@@ -76,14 +84,15 @@ class UdfSignature:  # noqa: PLW1641
        }
 
        if output:
+            # Use the actual resolved function (udf_func) for clearer error messages
            udf_output_map = UdfSignature._validate_output(
-                chain, signal_name, func, func_outs_sign, output
+                chain, signal_name, udf_func, func_outs_sign, output
            )
        else:
            if not func_outs_sign:
                raise UdfSignatureError(
                    chain,
-                    f"outputs are not defined in function '{udf_func}'"
+                    f"outputs are not defined in function '{callable_name(udf_func)}'"
                    " hints or 'output'",
                )
 
@@ -97,9 +106,12 @@ class UdfSignature:  # noqa: PLW1641
        if is_generator and not is_iterator:
            raise UdfSignatureError(
                chain,
-                f"function '{udf_func}' cannot be used in generator/aggregator"
-                " because it returns a type that is not Iterator/Generator."
-                f" Instead, it returns '{func_outs_sign}'",
+                (
+                    f"function '{callable_name(udf_func)}' cannot be used in "
+                    "generator/aggregator because it returns a type that is "
+                    "not Iterator/Generator. "
+                    f"Instead, it returns '{func_outs_sign}'"
+                ),
            )
 
        if isinstance(func_outs_sign, tuple):
@@ -124,11 +136,14 @@ class UdfSignature:  # noqa: PLW1641
                if len(func_outs_sign) != len(output):
                    raise UdfSignatureError(
                        chain,
-                        f"length of outputs names ({len(output)}) and function "
-                        f"'{func}' return type length ({len(func_outs_sign)}) does not match",
+                        (
+                            f"length of outputs names ({len(output)}) and function "
+                            f"'{callable_name(func)}' return type length "
+                            f"({len(func_outs_sign)}) does not match"
+                        ),
                    )
 
-                udf_output_map = dict(zip(output, func_outs_sign))
+                udf_output_map = dict(zip(output, func_outs_sign, strict=False))
            elif isinstance(output, dict):
                for key, value in output.items():
                    if not isinstance(key, str):
@@ -164,7 +179,7 @@ class UdfSignature:  # noqa: PLW1641
 
    @staticmethod
    def _func_signature(
-        chain: str, udf_func: Union[Callable, UDFBase]
+        chain: str, udf_func: Callable | UDFBase
    ) -> tuple[dict[str, type], Sequence[type], bool]:
        if isinstance(udf_func, AbstractUDF):
            func = udf_func.process  # type: ignore[unreachable]
@@ -183,17 +198,27 @@ class UdfSignature:  # noqa: PLW1641
        orig = get_origin(anno)
        if inspect.isclass(orig) and issubclass(orig, Iterator):
            args = get_args(anno)
-            if len(args) > 1 and not (
-                issubclass(orig, Generator) and len(args) == 3
-            ):
-                raise UdfSignatureError(
-                    chain,
-                    f"function '{func}' should return iterator with a single "
-                    f"value while '{args}' are specified",
-                )
-            is_iterator = True
-            anno = args[0]
-            orig = get_origin(anno)
+            # For typing.Iterator without type args, default to DEFAULT_RETURN_TYPE
+            if len(args) == 0:
+                is_iterator = True
+                anno = UdfSignature.DEFAULT_RETURN_TYPE
+                orig = get_origin(anno)
+            else:
+                # typing.Generator[T, S, R] has 3 args; allow that shape
+                if len(args) > 1 and not (
+                    issubclass(orig, Generator) and len(args) == 3
+                ):
+                    raise UdfSignatureError(
+                        chain,
+                        (
+                            f"function '{callable_name(func)}' should return "
+                            "iterator with a single value while "
+                            f"'{args}' are specified"
+                        ),
+                    )
+                is_iterator = True
+                anno = args[0]
+                orig = get_origin(anno)
 
        if orig and orig is tuple:
            output_types = tuple(get_args(anno))  # type: ignore[assignment]
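
Much of this diff reworks how _func_signature pulls the yielded type out of a return annotation. A simplified, self-contained sketch of the same get_origin/get_args logic (yield_type is an illustrative helper, not DataChain API; it assumes the rules visible above: a bare Iterator falls back to the default return type, and Generator[Y, S, R] is accepted since its three args come from the protocol, not multiple yield values):

import inspect
from collections.abc import Generator, Iterator
from typing import get_args, get_origin

def yield_type(func, default=str):
    anno = inspect.signature(func).return_annotation
    orig = get_origin(anno)
    if inspect.isclass(orig) and issubclass(orig, Iterator):
        args = get_args(anno)
        if not args:  # bare Iterator: fall back to the default return type
            return default
        # Generator[Y, S, R] carries three args; only the yield type matters
        if len(args) > 1 and not (issubclass(orig, Generator) and len(args) == 3):
            raise TypeError(f"expected a single yield type, got {args}")
        return args[0]
    return anno  # not an iterator: the annotation itself is the output type

def ints() -> Iterator[int]:
    yield 1

def strs() -> Generator[str, None, None]:
    yield "a"

print(yield_type(ints))  # <class 'int'>
print(yield_type(strs))  # <class 'str'>
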
datachain/lib/utils.py
CHANGED

@@ -1,3 +1,4 @@
+import inspect
 import re
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
@@ -32,6 +33,25 @@ class DataChainColumnError(DataChainParamsError):
         super().__init__(f"Error for column {col_name}: {msg}")
 
 
+def callable_name(obj: object) -> str:
+    """Return a friendly name for a callable or UDF-like instance."""
+    # UDF classes in DataChain inherit from AbstractUDF; prefer class name
+    if isinstance(obj, AbstractUDF):
+        return obj.__class__.__name__
+
+    # Plain functions and bound/unbound methods
+    if inspect.ismethod(obj) or inspect.isfunction(obj):
+        # __name__ exists for functions/methods; includes "<lambda>" for lambdas
+        return obj.__name__  # type: ignore[attr-defined]
+
+    # Generic callable object
+    if callable(obj):
+        return obj.__class__.__name__
+
+    # Fallback for non-callables
+    return str(obj)
+
+
 def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
     """Returns normalized_name -> original_name dict."""
     gen_col_counter = 0
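
The new callable_name() helper is what the sharper udf_signature.py error messages rely on. Usage, assuming datachain 0.39.0 is installed:

from datachain.lib.utils import callable_name

def embed(text: str) -> list[float]:
    return [0.0]

class Scorer:
    def __call__(self, x):
        return x

print(callable_name(embed))        # embed
print(callable_name(lambda x: x))  # <lambda>
print(callable_name(Scorer()))     # Scorer (class name of a callable object)
print(callable_name(42))           # 42 (str() fallback for non-callables)
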
datachain/lib/video.py
CHANGED

@@ -1,7 +1,6 @@
 import posixpath
 import shutil
 import tempfile
-from typing import Optional, Union
 
 from numpy import ndarray
 
@@ -18,7 +17,7 @@ except ImportError as exc:
     ) from exc
 
 
-def video_info(file: Union[File, VideoFile]) -> Video:
+def video_info(file: File | VideoFile) -> Video:
     """
     Returns video file information.
 
@@ -108,7 +107,7 @@ def video_frame_np(video: VideoFile, frame: int) -> ndarray:
 def validate_frame_range(
     video: VideoFile,
     start: int = 0,
-    end: Optional[int] = None,
+    end: int | None = None,
     step: int = 1,
 ) -> tuple[int, int, int]:
     """
@@ -186,7 +185,7 @@ def save_video_fragment(
     start: float,
     end: float,
     output: str,
-    format: Optional[str] = None,
+    format: str | None = None,
 ) -> VideoFile:
     """
     Saves video interval as a new video file. If output is a remote path,
datachain/lib/webdataset.py
CHANGED

@@ -1,20 +1,13 @@
-import json
 import tarfile
+import types
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    ClassVar,
-    Optional,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Any, ClassVar, Union, get_args, get_origin
 
 from pydantic import Field
 
+from datachain import json
 from datachain.lib.data_model import DataModel
 from datachain.lib.file import File
 from datachain.lib.tar import build_tar_member
@@ -64,28 +57,28 @@ class WDSBasic(DataModel):
 
 
 class WDSAllFile(WDSBasic):
-    txt: Optional[str] = Field(default=None)
-    text: Optional[str] = Field(default=None)
-    cap: Optional[str] = Field(default=None)
-    transcript: Optional[str] = Field(default=None)
-    cls: Optional[int] = Field(default=None)
-    cls2: Optional[int] = Field(default=None)
-    index: Optional[int] = Field(default=None)
-    inx: Optional[int] = Field(default=None)
-    id: Optional[int] = Field(default=None)
-    json: Optional[dict] = Field(default=None)  # type: ignore[assignment]
-    jsn: Optional[dict] = Field(default=None)
-
-    pyd: Optional[bytes] = Field(default=None)
-    pickle: Optional[bytes] = Field(default=None)
-    pth: Optional[bytes] = Field(default=None)
-    ten: Optional[bytes] = Field(default=None)
-    tb: Optional[bytes] = Field(default=None)
-    mp: Optional[bytes] = Field(default=None)
-    msg: Optional[bytes] = Field(default=None)
-    npy: Optional[bytes] = Field(default=None)
-    npz: Optional[bytes] = Field(default=None)
-    cbor: Optional[bytes] = Field(default=None)
+    txt: str | None = Field(default=None)
+    text: str | None = Field(default=None)
+    cap: str | None = Field(default=None)
+    transcript: str | None = Field(default=None)
+    cls: int | None = Field(default=None)
+    cls2: int | None = Field(default=None)
+    index: int | None = Field(default=None)
+    inx: int | None = Field(default=None)
+    id: int | None = Field(default=None)
+    json: dict | None = Field(default=None)  # type: ignore[assignment]
+    jsn: dict | None = Field(default=None)
+
+    pyd: bytes | None = Field(default=None)
+    pickle: bytes | None = Field(default=None)
+    pth: bytes | None = Field(default=None)
+    ten: bytes | None = Field(default=None)
+    tb: bytes | None = Field(default=None)
+    mp: bytes | None = Field(default=None)
+    msg: bytes | None = Field(default=None)
+    npy: bytes | None = Field(default=None)
+    npz: bytes | None = Field(default=None)
+    cbor: bytes | None = Field(default=None)
 
 
 class WDSReadableSubclass(DataModel):
@@ -189,9 +182,11 @@ class Builder:
             return
 
         anno = field.annotation
-        if get_origin(anno) == Union:
-            args = get_args(anno)
-            anno = args[0]
+        anno_origin = get_origin(anno)
+        if anno_origin in (Union, types.UnionType):
+            anno_args = get_args(anno)
+            if len(anno_args) == 2 and type(None) in anno_args:
+                return anno_args[0] if anno_args[1] is type(None) else anno_args[1]
 
         return anno
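
The Builder change makes Optional-unwrapping work for both typing.Union and the PEP 604 types.UnionType spelling, so field annotations like str | None resolve the same way as Optional[str]. A self-contained sketch of the check (unwrap_optional is an illustrative name, not DataChain API; requires Python 3.10+ for the X | None syntax):

import types
from typing import Optional, Union, get_args, get_origin

def unwrap_optional(anno):
    # Both Optional[T] and T | None report a union origin
    if get_origin(anno) in (Union, types.UnionType):
        args = get_args(anno)
        # Exactly two members, one of them NoneType: this is an Optional
        if len(args) == 2 and type(None) in args:
            return args[0] if args[1] is type(None) else args[1]
    return anno

print(unwrap_optional(Optional[str]))  # <class 'str'>
print(unwrap_optional(str | None))     # <class 'str'> (PEP 604 spelling)
print(unwrap_optional(int | str))      # int | str (left untouched)
print(unwrap_optional(int))            # <class 'int'>
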