datachain 0.1.13__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/__init__.py +0 -4
- datachain/asyn.py +3 -3
- datachain/catalog/__init__.py +3 -3
- datachain/catalog/catalog.py +6 -6
- datachain/catalog/loader.py +3 -3
- datachain/cli.py +10 -2
- datachain/client/azure.py +37 -1
- datachain/client/fsspec.py +1 -1
- datachain/client/local.py +1 -1
- datachain/data_storage/__init__.py +1 -1
- datachain/data_storage/metastore.py +11 -3
- datachain/data_storage/schema.py +12 -7
- datachain/data_storage/sqlite.py +3 -0
- datachain/data_storage/warehouse.py +31 -30
- datachain/dataset.py +1 -3
- datachain/lib/arrow.py +85 -0
- datachain/lib/cached_stream.py +3 -85
- datachain/lib/dc.py +382 -179
- datachain/lib/feature.py +46 -91
- datachain/lib/feature_registry.py +4 -1
- datachain/lib/feature_utils.py +2 -2
- datachain/lib/file.py +30 -44
- datachain/lib/image.py +9 -2
- datachain/lib/meta_formats.py +66 -34
- datachain/lib/settings.py +5 -5
- datachain/lib/signal_schema.py +103 -105
- datachain/lib/udf.py +10 -38
- datachain/lib/udf_signature.py +11 -6
- datachain/lib/webdataset_laion.py +5 -22
- datachain/listing.py +8 -8
- datachain/node.py +1 -1
- datachain/progress.py +1 -1
- datachain/query/builtins.py +1 -1
- datachain/query/dataset.py +42 -119
- datachain/query/dispatch.py +1 -1
- datachain/query/metrics.py +19 -0
- datachain/query/schema.py +13 -3
- datachain/sql/__init__.py +1 -1
- datachain/sql/sqlite/base.py +34 -2
- datachain/sql/sqlite/vector.py +13 -5
- datachain/utils.py +1 -122
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/METADATA +11 -4
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/RECORD +47 -47
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/WHEEL +1 -1
- datachain/_version.py +0 -16
- datachain/lib/parquet.py +0 -32
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/LICENSE +0 -0
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/entry_points.txt +0 -0
- {datachain-0.1.13.dist-info → datachain-0.2.1.dist-info}/top_level.txt +0 -0
datachain/lib/settings.py
CHANGED
@@ -18,19 +18,19 @@ class Settings:
 
         if not isinstance(cache, bool) and cache is not None:
             raise SettingsError(
-
+                "'cache' argument must be bool"
                 f" while {cache.__class__.__name__} was given"
             )
 
         if not isinstance(batch, int) and batch is not None:
             raise SettingsError(
-
+                "'batch' argument must be int or None"
                 f" while {batch.__class__.__name__} was given"
             )
 
         if not isinstance(parallel, int) and parallel is not None:
             raise SettingsError(
-
+                "'parallel' argument must be int or None"
                 f" while {parallel.__class__.__name__} was given"
             )
 
@@ -40,13 +40,13 @@
             and workers is not None
         ):
             raise SettingsError(
-
+                "'workers' argument must be int or bool"
                 f" while {workers.__class__.__name__} was given"
             )
 
         if min_task_size is not None and not isinstance(min_task_size, int):
             raise SettingsError(
-
+                "'min_task_size' argument must be int or None"
                 f", {min_task_size.__class__.__name__} was given"
            )
 
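A minimal sketch of the validation touched above, assuming Settings accepts these options as keyword arguments and that SettingsError is importable from the same module:

# Hypothetical usage; a wrong-typed option raises SettingsError.
from datachain.lib.settings import Settings, SettingsError

try:
    Settings(cache="yes")  # str instead of bool
except SettingsError as exc:
    print(exc)  # mentions: 'cache' argument must be bool while str was given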
datachain/lib/signal_schema.py
CHANGED
@@ -1,8 +1,11 @@
 import copy
-from collections.abc import Sequence
+from collections.abc import Iterator, Sequence
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin
 
+from pydantic import create_model
+
+from datachain.lib.arrow import Source
 from datachain.lib.feature import (
     DATACHAIN_TO_TYPE,
     DEFAULT_DELIMITER,
@@ -11,10 +14,11 @@ from datachain.lib.feature import (
     convert_type_to_datachain,
 )
 from datachain.lib.feature_registry import Registry
-from datachain.lib.file import File,
+from datachain.lib.file import File, TextFile
+from datachain.lib.image import ImageFile
 from datachain.lib.utils import DataChainParamsError
 from datachain.lib.webdataset import TarStream, WDSAllFile, WDSBasic
-from datachain.lib.webdataset_laion import Laion,
+from datachain.lib.webdataset_laion import Laion, WDSLaion
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
@@ -32,7 +36,7 @@ NAMES_TO_TYPES = {
     "datetime": datetime,
     "WDSLaion": WDSLaion,
     "Laion": Laion,
-    "
+    "Source": Source,
     "File": File,
     "ImageFile": ImageFile,
     "TextFile": TextFile,
@@ -64,6 +68,7 @@ class SignalResolvingTypeError(SignalResolvingError):
 class SignalSchema:
     def __init__(self, values: dict[str, FeatureType]):
         self.values = values
+        self.tree = self._build_tree()
 
     @staticmethod
     def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
@@ -119,26 +124,10 @@ class SignalSchema:
 
     def to_udf_spec(self) -> dict[str, Any]:
         res = {}
-        for
-
-
-
-            delimiter = fr_type._delimiter  # type: ignore[union-attr]
-            if fr_type._is_shallow:  # type: ignore[union-attr]
-                signal_path = []
-                spec = fr_type._to_udf_spec()  # type: ignore[union-attr]
-                for attr, value in spec:
-                    name_path = [*signal_path, attr]
-                    res[delimiter.join(name_path)] = value
-            else:
-                delimiter = DEFAULT_DELIMITER
-                try:
-                    type_ = convert_type_to_datachain(fr_type)
-                except TypeError as err:
-                    raise SignalSchemaError(
-                        f"unsupported type '{fr_type}' for signal '{signal}'"
-                    ) from err
-                res[delimiter.join(signal_path)] = type_
+        for path, type_, has_subtree, _ in self.get_flat_tree():
+            if not has_subtree:
+                db_name = DEFAULT_DELIMITER.join(path)
+                res[db_name] = convert_type_to_datachain(type_)
         return res
 
     def row_to_objs(self, row: Sequence[Any]) -> list[FeatureType]:
@@ -179,35 +168,37 @@ class SignalSchema:
         return res
 
     def db_signals(self) -> list[str]:
-
-
-
-
-
-            res.append(DEFAULT_DELIMITER.join(prefixes))
-        else:
-            if fr_cls._is_shallow:  # type: ignore[union-attr]
-                prefixes = []
-            spec = fr_cls._to_udf_spec()  # type: ignore[union-attr]
-            new_db_signals = [
-                DEFAULT_DELIMITER.join([*prefixes, name]) for name, type_ in spec
-            ]
-            res.extend(new_db_signals)
-        return res
+        return [
+            DEFAULT_DELIMITER.join(path)
+            for path, _, has_subtree, _ in self.get_flat_tree()
+            if not has_subtree
+        ]
 
     def resolve(self, *names: str) -> "SignalSchema":
         schema = {}
-        tree = self._get_prefix_tree()
         for field in names:
             if not isinstance(field, str):
                 raise SignalResolvingTypeError("select()", field)
-
-            path = field.split(".")
-            cls, position = self._find_feature_in_prefix_tree(tree, path)
-            schema[field] = self._find_in_feature(cls, path, position)
+            schema[field] = self._find_in_tree(field.split("."))
 
         return SignalSchema(schema)
 
+    def _find_in_tree(self, path: list[str]) -> FeatureType:
+        curr_tree = self.tree
+        curr_type = None
+        i = 0
+        while curr_tree is not None and i < len(path):
+            if val := curr_tree.get(path[i], None):
+                curr_type, curr_tree = val
+            else:
+                curr_type = None
+            i += 1
+
+        if curr_type is None:
+            raise SignalResolvingError(path, "is not found")
+
+        return curr_type
+
     def select_except_signals(self, *args: str) -> "SignalSchema":
         schema = copy.deepcopy(self.values)
         for field in args:
@@ -224,59 +215,6 @@ class SignalSchema:
 
         return SignalSchema(schema)
 
-    def _get_prefix_tree(self) -> dict[str, Any]:
-        tree: dict[str, Any] = {}
-        for name, fr_cls in self.values.items():
-            prefixes = name.split(".")
-
-            curr_tree = {}
-            curr_prefix = ""
-            for prefix in prefixes:
-                if not curr_prefix:
-                    curr_prefix = prefix
-                    curr_tree = tree
-                else:
-                    new_tree = curr_tree.get(curr_prefix, {})  #
-                    curr_tree[curr_prefix] = new_tree
-                    curr_tree = new_tree
-                    curr_prefix = prefix
-
-            curr_tree[curr_prefix] = fr_cls
-        return tree
-
-    def _find_feature_in_prefix_tree(
-        self, tree: dict, path: list[str]
-    ) -> tuple[FeatureType, int]:
-        for i in range(len(path)):
-            prefix = path[i]
-            if prefix not in tree:
-                raise SignalResolvingError(path, f"'{prefix}' is not found")
-            val = tree[prefix]
-            if not isinstance(val, dict):
-                return val, i + 1
-            tree = val
-
-        next_keys = ", ".join(tree.keys())
-        raise SignalResolvingError(
-            path,
-            f"it's not a terminal value or feature, next item might be '{next_keys}'",
-        )
-
-    def _find_in_feature(
-        self, cls: FeatureType, path: list[str], position: int
-    ) -> FeatureType:
-        if position == len(path):
-            return cls
-
-        name = path[position]
-        field_info = cls.model_fields.get(name, None)  # type: ignore[union-attr]
-        if field_info is None:
-            raise SignalResolvingError(
-                path, f"field '{name}' is not found in Feature '{cls.__name__}'"
-            )
-
-        return self._find_in_feature(field_info.annotation, path, position + 1)  # type: ignore[arg-type]
-
     def clone_without_file_signals(self) -> "SignalSchema":
         schema = copy.deepcopy(self.values)
 
@@ -297,14 +235,10 @@ class SignalSchema:
 
         return SignalSchema(self.values | schema_right)
 
-    def get_file_signals(self) ->
-
-
-
-            signals = fr.get_file_signals([name])  # type: ignore[union-attr]
-            for signal in signals:
-                res.append(".".join(signal))
-        return res
+    def get_file_signals(self) -> Iterator[str]:
+        for path, type_, has_subtree, _ in self.get_flat_tree():
+            if has_subtree and issubclass(type_, File):
+                yield ".".join(path)
 
     def get_file_signals_values(self, row: dict[str, Any]) -> dict[str, Any]:
         """
@@ -336,3 +270,67 @@ class SignalSchema:
         }
 
         return res
+
+    def create_model(self, name: str) -> type[Feature]:
+        fields = {key: (value, None) for key, value in self.values.items()}
+
+        return create_model(
+            name,
+            __base__=(Feature,),  # type: ignore[call-overload]
+            **fields,
+        )
+
+    def _build_tree(self) -> dict[str, Any]:
+        res = {}
+
+        for name, val in self.values.items():
+            subtree = val.build_tree() if Feature.is_feature(val) else None  # type: ignore[union-attr]
+            res[name] = (val, subtree)
+
+        return res
+
+    def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
+        yield from self._get_flat_tree(self.tree, [], 0)
+
+    def _get_flat_tree(
+        self, tree: dict, prefix: list[str], depth: int
+    ) -> Iterator[tuple[list[str], type, bool, int]]:
+        for name, (type_, substree) in tree.items():
+            suffix = name.split(".")
+            new_prefix = prefix + suffix
+            has_subtree = substree is not None
+            yield new_prefix, type_, has_subtree, depth
+            if substree is not None:
+                yield from self._get_flat_tree(substree, new_prefix, depth + 1)
+
+    def print_tree(self, indent: int = 4, start_at: int = 0):
+        for path, type_, _, depth in self.get_flat_tree():
+            total_indent = start_at + depth * indent
+            print(" " * total_indent, f"{path[-1]}:", SignalSchema._type_to_str(type_))
+
+            if get_origin(type_) is list:
+                args = get_args(type_)
+                if len(args) > 0 and Feature.is_feature(args[0]):
+                    sub_schema = SignalSchema({"* list of": args[0]})
+                    sub_schema.print_tree(indent=indent, start_at=total_indent + indent)
+
+    @staticmethod
+    def _type_to_str(type_):
+        if get_origin(type_) == Union:
+            args = get_args(type_)
+            formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
+            return f"Union[{formatted_types}]"
+        if get_origin(type_) == Optional:
+            args = get_args(type_)
+            type_str = SignalSchema._type_to_str(args[0])
+            return f"Optional[{type_str}]"
+        if get_origin(type_) is list:
+            args = get_args(type_)
+            type_str = SignalSchema._type_to_str(args[0])
+            return f"list[{type_str}]"
+        if get_origin(type_) is dict:
+            args = get_args(type_)
+            type_str = SignalSchema._type_to_str(args[0])
+            vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else ""
+            return f"dict[{type_str}{vals}]"
+        return type_.__name__
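The refactor above replaces the ad-hoc prefix-tree helpers with a single tree built in __init__ and walked by get_flat_tree. A hedged usage sketch (the Annotation feature and the printed shapes are illustrative, not taken from the package):

from datachain.lib.feature import Feature
from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

class Annotation(Feature):  # hypothetical feature with one leaf field
    label: str = ""

schema = SignalSchema({"file": File, "ann": Annotation})
print(schema.db_signals())              # leaf paths joined with DEFAULT_DELIMITER
print(list(schema.get_file_signals()))  # ["file"], the only File subtree
schema.print_tree()                     # indented view of the same flat tree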
datachain/lib/udf.py
CHANGED
@@ -6,10 +6,10 @@ from typing import TYPE_CHECKING, Callable, Optional
 from datachain.lib.feature import Feature
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.utils import DataChainError, DataChainParamsError
-from datachain.query import
+from datachain.query import udf
 
 if TYPE_CHECKING:
-    from
+    from datachain.query.udf import UDFWrapper
 
 
 class UdfError(DataChainParamsError):
@@ -21,7 +21,6 @@ class UDFBase:
     is_input_batched = False
     is_output_batched = False
     is_input_grouped = False
-    is_output_single = False
 
     def __init__(
         self,
@@ -35,11 +34,6 @@ class UDFBase:
 
         params_spec = params.to_udf_spec()
         self.params_spec = list(params_spec.keys())
-        self._contains_stream = False
-        if params.contains_file():
-            self.params_spec.insert(0, Stream())  # type: ignore[arg-type]
-            self._contains_stream = True
-
         self.output_spec = output.to_udf_spec()
 
         self._catalog = None
@@ -91,9 +85,6 @@ class UDFBase:
         if not self.is_output_batched:
             result_objs = [result_objs]
 
-        if self.is_output_single:
-            result_objs = [[x] for x in result_objs]
-
         if len(self.output.values) > 1:
             res = []
             for tuple_ in result_objs:
@@ -107,7 +98,7 @@ class UDFBase:
         else:
             # Generator expression is required, otherwise the value will be materialized
             res = (
-
+                obj._flatten() if isinstance(obj, Feature) else (obj,)
                 for obj in result_objs
             )
 
@@ -126,18 +117,10 @@ class UDFBase:
             rows = [rows]
         objs = []
         for row in rows:
-            if self._contains_stream:
-                stream, *row = row
-            else:
-                stream = None
-
             obj_row = self.params.row_to_objs(row)
-
-
-
-                if isinstance(obj, Feature):
-                    obj._set_stream(self._catalog, stream, True)
-
+            for obj in obj_row:
+                if isinstance(obj, Feature):
+                    obj._set_stream(self._catalog, caching_enabled=True)
             objs.append(obj_row)
         return objs
 
@@ -145,22 +128,16 @@ class UDFBase:
         group = rows[0]
         spec_map = {}
         output_map = {}
-        for name, anno in self.params.
+        for name, (anno, subtree) in self.params.tree.items():
             if inspect.isclass(anno) and issubclass(anno, Feature):
-                length =
+                length = sum(1 for _ in self.params._get_flat_tree(subtree, [], 0))
             else:
                 length = 1
             spec_map[name] = anno, length
             output_map[name] = []
 
         for flat_obj in group:
-
-                position = 1
-                stream = flat_obj[0]
-            else:
-                position = 0
-                stream = None
-
+            position = 0
             for signal, (cls, length) in spec_map.items():
                 slice = flat_obj[position : position + length]
                 position += length
@@ -171,7 +148,7 @@ class UDFBase:
                     obj = slice[0]
 
                 if isinstance(obj, Feature):
-                    obj._set_stream(self._catalog
+                    obj._set_stream(self._catalog)
                 output_map[signal].append(obj)
 
         return list(output_map.values())
@@ -208,8 +185,3 @@ class Aggregator(UDFBase):
     is_input_batched = True
     is_output_batched = True
     is_input_grouped = True
-
-
-class GroupMapper(UDFBase):
-    is_input_batched = True
-    is_output_batched = True
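With the Stream slot gone, _parse_grouped_rows starts slicing every flattened row at position 0. A standalone sketch of that slicing logic (signal names and lengths are made up for illustration):

# signal -> number of leaf columns, as computed above via _get_flat_tree
spec_map = {"file": 3, "score": 1}
flat_obj = ("s3://bucket", "a.jpg", "etag-1", 0.9)

position = 0  # previously started at 1 when a stream object occupied slot 0
for signal, length in spec_map.items():
    chunk = flat_obj[position : position + length]
    position += length
    print(signal, chunk)
# file ('s3://bucket', 'a.jpg', 'etag-1')
# score (0.9,)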
datachain/lib/udf_signature.py
CHANGED
@@ -20,6 +20,8 @@ class UdfSignature:
     params: Sequence[str]
     output_schema: SignalSchema
 
+    DEFAULT_RETURN_TYPE = str
+
     @classmethod
     def parse(
         cls,
@@ -35,7 +37,7 @@
             raise UdfSignatureError(
                 chain,
                 f"multiple signals '{keys}' are not supported in processors."
-
+                " Chain multiple processors instead.",
             )
         if len(signal_map) == 1:
             if func is not None:
@@ -69,7 +71,7 @@
             raise UdfSignatureError(
                 chain,
                 f"outputs are not defined in function '{udf_func.__name__}'"
-
+                " hints or 'output'",
             )
 
         if not signal_name:
@@ -83,7 +85,7 @@
             raise UdfSignatureError(
                 chain,
                 f"function '{func}' cannot be used in generator/aggregator"
-
+                " because it returns a type that is not Iterator/Generator."
                 f" Instead, it returns '{func_outs_sign}'",
             )
 
@@ -127,7 +129,7 @@
                 raise UdfSignatureError(
                     chain,
                     f"output signal '{key}' has type '{type(key)}'"
-
+                    " while 'str' is expected",
                 )
             if not Feature.is_feature_type(value):
                 raise UdfSignatureError(
@@ -143,7 +145,7 @@
             raise UdfSignatureError(
                 chain,
                 f"unknown output type: {output}. List of signals or dict of signals"
-
+                " to function are expected.",
             )
         return udf_output_map
 
@@ -182,9 +184,12 @@
             anno = args[0]
             orig = get_origin(anno)
 
-            if orig and orig
+            if orig and orig is tuple:
                 output_types = tuple(get_args(anno))  # type: ignore[assignment]
             else:
                 output_types = [anno]
 
+        if not output_types:
+            output_types = [UdfSignature.DEFAULT_RETURN_TYPE]
+
         return input_map, output_types, is_iterator
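A self-contained sketch of the return-annotation handling above: a tuple hint fans out into several output types, anything else stays a single type, and an empty result falls back to the new DEFAULT_RETURN_TYPE (this mirrors the diff, it is not the library's full parser):

from typing import get_args, get_origin

DEFAULT_RETURN_TYPE = str

def output_types_for(anno) -> list:
    orig = get_origin(anno)
    if orig and orig is tuple:
        output_types = list(get_args(anno))  # tuple hint -> multiple outputs
    else:
        output_types = [anno] if anno is not None else []
    if not output_types:
        output_types = [DEFAULT_RETURN_TYPE]  # no usable hint -> default to str
    return output_types

print(output_types_for(tuple[str, int]))  # [<class 'str'>, <class 'int'>]
print(output_types_for(float))            # [<class 'float'>]
print(output_types_for(None))             # [<class 'str'>]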
datachain/lib/webdataset_laion.py
CHANGED

@@ -4,8 +4,8 @@ from typing import Optional
 import numpy as np
 from pydantic import Field
 
+from datachain.lib.feature import Feature
 from datachain.lib.file import File
-from datachain.lib.parquet import BasicParquet
 from datachain.lib.webdataset import WDSBasic, WDSReadableSubclass
 
 
@@ -34,19 +34,9 @@ class WDSLaion(WDSBasic):
     json: Laion  # type: ignore[assignment]
 
 
-class
-
-
-    text: str = Field(default="")
-    original_width: int = Field(default=-1)
-    original_height: int = Field(default=-1)
-    clip_b32_similarity_score: float = Field(default=0.0)
-    clip_l14_similarity_score: float = Field(default=0.0)
-    face_bboxes: Optional[list[list[float]]] = Field(default=None)
-    sha256: str = Field(default="")
-
-
-class LaionMeta(BasicParquet):
+class LaionMeta(Feature):
+    file: File
+    index: Optional[int] = Field(default=None)
     b32_img: list[float] = Field(default=None)
     b32_txt: list[float] = Field(default=None)
     l14_img: list[float] = Field(default=None)
@@ -65,14 +55,7 @@ def process_laion_meta(file: File) -> Iterator[LaionMeta]:
 
     for index in range(len(b32_img)):
         yield LaionMeta(
-            file=
-                name=str(index),
-                source=file.source,
-                parent=f"{file.get_full_name()}",
-                version=file.version,
-                etag=f"{file.etag}_{index}",
-                location={"vtype": LaionMeta.__name__},
-            ),
+            file=file,
             index=index,
             b32_img=b32_img[index],
             b32_txt=b32_txt[index],
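LaionMeta is now a plain Feature that carries the source File directly instead of synthesizing a virtual file entry per row. A hedged construction example (the File(...) arguments are an assumption; unset fields rely on pydantic defaults):

from datachain.lib.file import File
from datachain.lib.webdataset_laion import LaionMeta

meta = LaionMeta(
    file=File(name="embeddings.npz"),  # assumed minimal constructor
    index=0,
    b32_img=[0.1, 0.2],
)
print(meta.file.name, meta.index)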
datachain/listing.py
CHANGED
@@ -192,25 +192,25 @@ class Listing:
         dr = self.dataset_rows
         conds = []
         if names:
-
-
+            f = Column("name").op("GLOB")
+            conds.extend(f(name) for name in names)
         if inames:
-
-
+            f = func.lower(Column("name")).op("GLOB")
+            conds.extend(f(iname.lower()) for iname in inames)
         if paths:
             node_path = case(
                 (Column("parent") == "", Column("name")),
                 else_=Column("parent") + "/" + Column("name"),
             )
-
-
+            f = node_path.op("GLOB")
+            conds.extend(f(path) for path in paths)
         if ipaths:
             node_path = case(
                 (Column("parent") == "", Column("name")),
                 else_=Column("parent") + "/" + Column("name"),
             )
-
-
+            f = func.lower(node_path).op("GLOB")
+            conds.extend(f(ipath.lower()) for ipath in ipaths)
 
         if size is not None:
             size_limit = suffix_to_number(size)
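Each rewritten branch binds a GLOB operator once and then extends conds with one condition per pattern. A standalone SQLAlchemy sketch of the same construction (the column name is illustrative):

import sqlalchemy as sa

name = sa.column("name")
glob = sa.func.lower(name).op("GLOB")  # case-insensitive variant, as in inames/ipaths
conds = [glob(p.lower()) for p in ("*.JPG", "*.png")]
print(sa.or_(*conds))
# roughly: lower(name) GLOB :lower_1 OR lower(name) GLOB :lower_2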
datachain/node.py
CHANGED
@@ -47,6 +47,7 @@ class DirTypeGroup:
 @attrs.define
 class Node:
     id: int = 0
+    random: int = -1
     vtype: str = ""
     dir_type: Optional[int] = None
     parent: str = ""
@@ -58,7 +59,6 @@ class Node:
     size: int = 0
     owner_name: str = ""
     owner_id: str = ""
-    random: int = -1
     location: Optional[str] = None
     source: StorageURI = StorageURI("")
 
datachain/progress.py
CHANGED