datachain 0.34.6__py3-none-any.whl → 0.34.7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Potentially problematic release.
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/catalog.py +75 -83
- datachain/catalog/loader.py +3 -3
- datachain/checkpoint.py +1 -2
- datachain/cli/__init__.py +2 -4
- datachain/cli/commands/datasets.py +13 -13
- datachain/cli/commands/ls.py +4 -4
- datachain/cli/commands/query.py +3 -3
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +1 -2
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +11 -21
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +4 -4
- datachain/client/local.py +4 -4
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +5 -5
- datachain/data_storage/metastore.py +107 -107
- datachain/data_storage/schema.py +18 -24
- datachain/data_storage/sqlite.py +21 -28
- datachain/data_storage/warehouse.py +13 -13
- datachain/dataset.py +64 -70
- datachain/delta.py +21 -18
- datachain/diff/__init__.py +13 -13
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +45 -42
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +54 -81
- datachain/job.py +8 -8
- datachain/lib/arrow.py +17 -14
- datachain/lib/audio.py +6 -6
- datachain/lib/clip.py +5 -4
- datachain/lib/convert/python_to_sql.py +4 -22
- datachain/lib/convert/values_to_tuples.py +4 -9
- datachain/lib/data_model.py +20 -19
- datachain/lib/dataset_info.py +6 -6
- datachain/lib/dc/csv.py +10 -10
- datachain/lib/dc/database.py +28 -29
- datachain/lib/dc/datachain.py +98 -97
- datachain/lib/dc/datasets.py +22 -22
- datachain/lib/dc/hf.py +4 -4
- datachain/lib/dc/json.py +9 -10
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +5 -5
- datachain/lib/dc/records.py +5 -5
- datachain/lib/dc/storage.py +12 -12
- datachain/lib/dc/storage_pattern.py +2 -2
- datachain/lib/dc/utils.py +11 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +26 -26
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +1 -2
- datachain/lib/model_store.py +3 -3
- datachain/lib/namespaces.py +4 -6
- datachain/lib/projects.py +5 -9
- datachain/lib/pytorch.py +10 -10
- datachain/lib/settings.py +23 -23
- datachain/lib/signal_schema.py +52 -44
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +25 -17
- datachain/lib/udf_signature.py +11 -11
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +30 -35
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +4 -4
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +1 -7
- datachain/project.py +4 -4
- datachain/query/batch.py +7 -8
- datachain/query/dataset.py +80 -87
- datachain/query/dispatch.py +7 -7
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/schema.py +7 -6
- datachain/query/session.py +7 -7
- datachain/query/udf.py +8 -7
- datachain/query/utils.py +3 -5
- datachain/remote/studio.py +33 -39
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +6 -9
- datachain/studio.py +30 -30
- datachain/toolkit/split.py +1 -2
- datachain/utils.py +21 -21
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/METADATA +2 -3
- datachain-0.34.7.dist-info/RECORD +173 -0
- datachain-0.34.6.dist-info/RECORD +0 -173
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/WHEEL +0 -0
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/entry_points.txt +0 -0
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/top_level.txt +0 -0
datachain/lib/projects.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from datachain.error import ProjectCreateNotAllowedError, ProjectDeleteNotAllowedError
 from datachain.project import Project
 from datachain.query import Session
@@ -8,8 +6,8 @@ from datachain.query import Session
 def create(
     namespace: str,
     name: str,
-    descr: Optional[str] = None,
-    session: Optional[Session] = None,
+    descr: str | None = None,
+    session: Session | None = None,
 ) -> Project:
     """
     Creates a new project under a specified namespace.
@@ -42,7 +40,7 @@ def create(
     return session.catalog.metastore.create_project(namespace, name, descr)


-def get(name: str, namespace: str, session: Optional[Session]) -> Project:
+def get(name: str, namespace: str, session: Session | None) -> Project:
     """
     Gets a project by name in some namespace.
     If the project is not found, a `ProjectNotFoundError` is raised.
@@ -62,9 +60,7 @@ def get(name: str, namespace: str, session: Optional[Session]) -> Project:
     return Session.get(session).catalog.metastore.get_project(name, namespace)


-def ls(
-    namespace: Optional[str] = None, session: Optional[Session] = None
-) -> list[Project]:
+def ls(namespace: str | None = None, session: Session | None = None) -> list[Project]:
     """
     Gets a list of projects in a specific namespace or from all namespaces.

@@ -88,7 +84,7 @@ def ls(
     return session.catalog.metastore.list_projects(namespace_id)


-def delete(name: str, namespace: str, session: Optional[Session] = None) -> None:
+def delete(name: str, namespace: str, session: Session | None = None) -> None:
     """
     Removes a project by name within a namespace.

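The pattern in this file recurs throughout the release: `typing.Optional[X]` annotations are rewritten as PEP 604 unions (`X | None`). A minimal sketch of the equivalence, outside datachain (`Session` below is a stand-in name, never resolved at runtime):

```python
from typing import Optional

# On Python 3.10+ the two spellings compare equal and are interchangeable
# in annotations; isinstance() accepts the new form as well.
assert (str | None) == Optional[str]
assert isinstance(None, str | None)

# On older interpreters, "X | None" still works in annotations when
# `from __future__ import annotations` defers their evaluation.
def get(name: str, session: "Session | None" = None):
    ...
```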
datachain/lib/pytorch.py
CHANGED
@@ -1,9 +1,9 @@
 import logging
 import os
 import weakref
-from collections.abc import Generator, Iterable, Iterator
+from collections.abc import Callable, Generator, Iterable, Iterator
 from contextlib import closing
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any

 from PIL import Image
 from torch import float32
@@ -43,13 +43,13 @@ class PytorchDataset(IterableDataset):
     def __init__(
         self,
         name: str,
-        version: Optional[str] = None,
-        catalog: Optional[Catalog] = None,
-        transform: Optional["Transform"] = None,
-        tokenizer: Optional[Callable] = None,
-        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        version: str | None = None,
+        catalog: Catalog | None = None,
+        transform: "Transform | None" = None,
+        tokenizer: Callable | None = None,
+        tokenizer_kwargs: dict[str, Any] | None = None,
         num_samples: int = 0,
-        dc_settings: Optional[Settings] = None,
+        dc_settings: Settings | None = None,
         remove_prefetched: bool = False,
     ):
         """
@@ -84,7 +84,7 @@ class PytorchDataset(IterableDataset):
         self.prefetch = prefetch

         self._cache = catalog.cache
-        self._prefetch_cache: Optional[Cache] = None
+        self._prefetch_cache: Cache | None = None
         self._remove_prefetched = remove_prefetched
         if prefetch and not self.cache:
             tmp_dir = catalog.cache.tmp_dir
@@ -104,7 +104,7 @@ class PytorchDataset(IterableDataset):
         self._ms_params = catalog.metastore.clone_params()
         self._wh_params = catalog.warehouse.clone_params()
         self._catalog_params = catalog.get_init_params()
-        self.catalog: Optional[Catalog] = None
+        self.catalog: Catalog | None = None

     def _get_catalog(self) -> "Catalog":
         ms_cls, ms_args, ms_kwargs = self._ms_params
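A second recurring edit, visible in this file's imports: `Callable` now comes from `collections.abc` rather than `typing`. The two are interchangeable in annotations; `collections.abc.Callable` has been subscriptable since Python 3.9, and `typing.Callable` is documented as a deprecated alias of it. A standalone illustration (the function is invented for the example):

```python
from collections.abc import Callable

def apply_tokenizer(fn: Callable[[str], list[str]], text: str) -> list[str]:
    # Any callable matching the signature works, e.g. str.split.
    return fn(text)

assert apply_tokenizer(str.split, "a b c") == ["a", "b", "c"]
```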
datachain/lib/settings.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any

 from datachain.lib.utils import DataChainParamsError

@@ -15,25 +15,25 @@ class SettingsError(DataChainParamsError):
 class Settings:
     """Settings for datachain."""

-    _cache: Optional[bool]
-    _prefetch: Optional[int]
-    _parallel: Optional[Union[bool, int]]
-    _workers: Optional[int]
-    _namespace: Optional[str]
-    _project: Optional[str]
-    _min_task_size: Optional[int]
-    _batch_size: Optional[int]
+    _cache: bool | None
+    _prefetch: int | None
+    _parallel: bool | int | None
+    _workers: int | None
+    _namespace: str | None
+    _project: str | None
+    _min_task_size: int | None
+    _batch_size: int | None

     def __init__(  # noqa: C901, PLR0912
         self,
-        cache: Optional[bool] = None,
-        prefetch: Optional[Union[bool, int]] = None,
-        parallel: Optional[Union[bool, int]] = None,
-        workers: Optional[int] = None,
-        namespace: Optional[str] = None,
-        project: Optional[str] = None,
-        min_task_size: Optional[int] = None,
-        batch_size: Optional[int] = None,
+        cache: bool | None = None,
+        prefetch: bool | int | None = None,
+        parallel: bool | int | None = None,
+        workers: int | None = None,
+        namespace: str | None = None,
+        project: str | None = None,
+        min_task_size: int | None = None,
+        batch_size: int | None = None,
     ) -> None:
         if cache is None:
             self._cache = None
@@ -148,27 +148,27 @@ class Settings:
         return self._cache if self._cache is not None else DEFAULT_CACHE

     @property
-    def prefetch(self) -> Optional[int]:
+    def prefetch(self) -> int | None:
         return self._prefetch if self._prefetch is not None else DEFAULT_PREFETCH

     @property
-    def parallel(self) -> Optional[Union[bool, int]]:
+    def parallel(self) -> bool | int | None:
         return self._parallel if self._parallel is not None else None

     @property
-    def workers(self) -> Optional[int]:
+    def workers(self) -> int | None:
         return self._workers if self._workers is not None else None

     @property
-    def namespace(self) -> Optional[str]:
+    def namespace(self) -> str | None:
         return self._namespace if self._namespace is not None else None

     @property
-    def project(self) -> Optional[str]:
+    def project(self) -> str | None:
         return self._project if self._project is not None else None

     @property
-    def min_task_size(self) -> Optional[int]:
+    def min_task_size(self) -> int | None:
         return self._min_task_size if self._min_task_size is not None else None

     @property
datachain/lib/signal_schema.py
CHANGED
@@ -3,22 +3,21 @@ import hashlib
 import json
 import logging
 import math
+import types
 import warnings
-from collections.abc import Iterator, Sequence
+from collections.abc import Callable, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from datetime import datetime
 from inspect import isclass
-from typing import (
+from typing import (
     IO,
     TYPE_CHECKING,
     Annotated,
     Any,
-    Callable,
-    Dict,
+    Dict,  # type: ignore[UP035]
     Final,
-    List,
+    List,  # type: ignore[UP035]
     Literal,
-    Mapping,
     Optional,
     Union,
     get_args,
@@ -75,7 +74,7 @@ class SignalSchemaWarning(RuntimeWarning):


 class SignalResolvingError(SignalSchemaError):
-    def __init__(self, path: Optional[list[str]], msg: str):
+    def __init__(self, path: list[str] | None, msg: str):
         name = " '" + ".".join(path) + "'" if path else ""
         super().__init__(f"cannot resolve signal name{name}: {msg}")

@@ -95,7 +94,7 @@ class SignalResolvingTypeError(SignalResolvingError):


 class SignalRemoveError(SignalSchemaError):
-    def __init__(self, path: Optional[list[str]], msg: str):
+    def __init__(self, path: list[str] | None, msg: str):
         name = " '" + ".".join(path) + "'" if path else ""
         super().__init__(f"cannot remove signal name{name}: {msg}")

@@ -104,8 +103,8 @@ class CustomType(BaseModel):
     schema_version: int = Field(ge=1, le=2, strict=True)
     name: str
     fields: dict[str, str]
-    bases: list[tuple[str, str, Optional[str]]]
-    hidden_fields: Optional[list[str]] = None
+    bases: list[tuple[str, str, str | None]]
+    hidden_fields: list[str] | None = None

     @classmethod
     def deserialize(cls, data: dict[str, Any], type_name: str) -> "CustomType":
@@ -125,8 +124,8 @@ class CustomType(BaseModel):

 def create_feature_model(
     name: str,
-    fields: Mapping[str, Union[type, tuple[type, Any], None]],
-    base: Optional[type] = None,
+    fields: Mapping[str, type | tuple[type, Any] | None],
+    base: type | None = None,
 ) -> type[BaseModel]:
     """
     This gets or returns a dynamic feature model for use in restoring a model
@@ -152,12 +151,12 @@ class SignalSchema:
     values: dict[str, DataType]
     tree: dict[str, Any]
     setup_func: dict[str, Callable]
-    setup_values: Optional[dict[str, Any]]
+    setup_values: dict[str, Any] | None

     def __init__(
         self,
         values: dict[str, DataType],
-        setup: Optional[dict[str, Callable]] = None,
+        setup: dict[str, Callable] | None = None,
     ):
         self.values = values
         self.tree = self._build_tree(values)
@@ -196,8 +195,8 @@ class SignalSchema:
         return SignalSchema(signals)

     @staticmethod
-    def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
-        bases: list[tuple[str, str, Optional[str]]] = []
+    def _get_bases(fr: type) -> list[tuple[str, str, str | None]]:
+        bases: list[tuple[str, str, str | None]] = []
         for base in fr.__mro__:
             model_store_name = (
                 ModelStore.get_name(base) if issubclass(base, DataModel) else None
@@ -294,7 +293,7 @@ class SignalSchema:
     @staticmethod
     def _deserialize_custom_type(
         type_name: str, custom_types: dict[str, Any]
-    ) -> Optional[type]:
+    ) -> type | None:
         """Given a type name like MyType@v1 gets a type from ModelStore or recreates
         it based on the information from the custom types dict that includes fields and
         bases."""
@@ -327,7 +326,7 @@ class SignalSchema:
         return None

     @staticmethod
-    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
+    def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> type | None:
         """Convert a string-based type back into a python type."""
         type_name = type_name.strip()
         if not type_name:
@@ -336,7 +335,7 @@ class SignalSchema:
             return None

         bracket_idx = type_name.find("[")
-        subtypes: Optional[tuple[Optional[type], ...]] = None
+        subtypes: tuple[type | None, ...] | None = None
         if bracket_idx > -1:
             if bracket_idx == 0:
                 raise ValueError("Type cannot start with '['")
@@ -493,7 +492,7 @@ class SignalSchema:
             return math.isnan(value) or value is None
         return value is None

-    def get_file_signal(self) -> Optional[str]:
+    def get_file_signal(self) -> str | None:
         for signal_name, signal_type in self.values.items():
             if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass(
                 fr, File
@@ -503,8 +502,8 @@ class SignalSchema:

     def slice(
         self,
-        params: dict[str, Union[DataType, Any]],
-        setup: Optional[dict[str, Callable]] = None,
+        params: dict[str, DataType | Any],
+        setup: dict[str, Callable] | None = None,
         is_batch: bool = False,
     ) -> "SignalSchema":
         """
@@ -528,9 +527,13 @@
         schema_origin = get_origin(schema_type)
         param_origin = get_origin(param_type)

-        if schema_origin == Union and type(None) in get_args(schema_type):
+        if schema_origin in (Union, types.UnionType) and type(None) in get_args(
+            schema_type
+        ):
             schema_type = get_args(schema_type)[0]
-        if param_origin == Union and type(None) in get_args(param_type):
+        if param_origin in (Union, types.UnionType) and type(None) in get_args(
+            param_type
+        ):
             param_type = get_args(param_type)[0]

         if is_batch:
@@ -610,8 +613,8 @@ class SignalSchema:
             raise SignalResolvingError([col_name], "is not found")

     def db_signals(
-        self, name: Optional[str] = None, as_columns=False, include_hidden: bool = True
-    ) -> Union[list[str], list[Column]]:
+        self, name: str | None = None, as_columns=False, include_hidden: bool = True
+    ) -> list[str] | list[Column]:
         """
         Returns DB columns as strings or Column objects with proper types
         Optionally, it can filter results by specific object, returning only his signals
@@ -802,7 +805,7 @@ class SignalSchema:
     @staticmethod
     def _build_tree(
         values: dict[str, DataType],
-    ) -> dict[str, tuple[DataType, Optional[dict]]]:
+    ) -> dict[str, tuple[DataType, dict | None]]:
         return {
             name: (val, SignalSchema._build_tree_for_type(val))
             for name, val in values.items()
@@ -834,7 +837,7 @@ class SignalSchema:
                 substree, new_prefix, depth + 1, include_hidden
             )

-    def print_tree(self, indent: int = 2, start_at: int = 0, file: Optional[IO] = None):
+    def print_tree(self, indent: int = 2, start_at: int = 0, file: IO | None = None):
         for path, type_, _, depth in self.get_flat_tree():
             total_indent = start_at + depth * indent
             col_name = " " * total_indent + path[-1]
@@ -873,15 +876,20 @@ class SignalSchema:
         return self.values.pop(name)

     @staticmethod
-    def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str:  # noqa: C901, PLR0911
+    def _type_to_str(type_: type | None, subtypes: list | None = None) -> str:  # noqa: C901, PLR0911
         """Convert a type to a string-based representation."""
         if type_ is None:
             return "NoneType"

         origin = get_origin(type_)

-        if origin == Union:
+        if origin in (Union, types.UnionType):
             args = get_args(type_)
+            if len(args) == 2 and type(None) in args:
+                # This is an Optional type.
+                non_none_type = args[0] if args[1] is type(None) else args[1]
+                type_str = SignalSchema._type_to_str(non_none_type, subtypes)
+                return f"Optional[{type_str}]"
             formatted_types = ", ".join(
                 SignalSchema._type_to_str(arg, subtypes) for arg in args
             )
@@ -892,19 +900,19 @@ class SignalSchema:
             return f"Optional[{type_str}]"
         if origin in (list, List):  # noqa: UP006
             args = get_args(type_)
+            if len(args) == 0:
+                return "list"
             type_str = SignalSchema._type_to_str(args[0], subtypes)
             return f"list[{type_str}]"
         if origin in (dict, Dict):  # noqa: UP006
             args = get_args(type_)
-            type_str = (
-                SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
-            )
-            vals = (
-                f", {SignalSchema._type_to_str(args[1], subtypes)}"
-                if len(args) > 1
-                else ""
-            )
-            return f"dict[{type_str}{vals}]"
+            if len(args) == 0:
+                return "dict"
+            key_type = SignalSchema._type_to_str(args[0], subtypes)
+            if len(args) == 1:
+                return f"dict[{key_type}, Any]"
+            val_type = SignalSchema._type_to_str(args[1], subtypes)
+            return f"dict[{key_type}, {val_type}]"
         if origin == Annotated:
             args = get_args(type_)
             return SignalSchema._type_to_str(args[0], subtypes)
@@ -918,7 +926,7 @@ class SignalSchema:
         # Include this type in the list of all subtypes, if requested.
         subtypes.append(type_)
         if not hasattr(type_, "__name__"):
-            # This can happen for some third-party or custom types
+            # This can happen for some third-party or custom types
             warnings.warn(
                 f"Unable to determine name of type '{type_}'.",
                 SignalSchemaWarning,
@@ -933,7 +941,7 @@ class SignalSchema:
     @staticmethod
     def _build_tree_for_type(
         model: DataType,
-    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
+    ) -> dict[str, tuple[DataType, dict | None]] | None:
         if (fr := ModelStore.to_pydantic(model)) is not None:
             return SignalSchema._build_tree_for_model(fr)
         return None
@@ -941,8 +949,8 @@ class SignalSchema:
     @staticmethod
     def _build_tree_for_model(
         model: type[BaseModel],
-    ) -> Optional[dict[str, tuple[DataType, Optional[dict]]]]:
-        res: dict[str, tuple[DataType, Optional[dict]]] = {}
+    ) -> dict[str, tuple[DataType, dict | None]] | None:
+        res: dict[str, tuple[DataType, dict | None]] = {}

         for name, f_info in model.model_fields.items():
             anno = f_info.annotation
@@ -991,7 +999,7 @@ class SignalSchema:
         schema: dict[str, Any] = {}
         schema_custom_types: dict[str, CustomType] = {}

-        data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None
+        data_model_bases: list[tuple[str, str, str | None]] | None = None

         signal_partials: dict[str, str] = {}
         partial_versions: dict[str, int] = {}
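The substantive change in this file is union detection: a union written with PEP 604 syntax (`X | Y`) has origin `types.UnionType` at runtime, while `Union[X, Y]` and `Optional[X]` have origin `typing.Union`, so the code now checks both. This is directly observable (CPython 3.10+ behavior):

```python
import types
from typing import Optional, Union, get_args, get_origin

# Equal as types, but with different origins - hence the new
# `origin in (Union, types.UnionType)` checks above.
assert Optional[int] == (int | None)
assert get_origin(Optional[int]) is Union
assert get_origin(int | None) is types.UnionType

# get_args() is consistent for both spellings, so
# `type(None) in get_args(t)` detects optionality either way.
assert get_args(Optional[int]) == get_args(int | None) == (int, type(None))
```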
datachain/lib/text.py
CHANGED
@@ -1,16 +1,17 @@
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any

 import torch
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase


 def convert_text(
-    text: Union[str, list[str]],
-    tokenizer: Optional[Callable] = None,
-    tokenizer_kwargs: Optional[dict[str, Any]] = None,
-    encoder: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
-) -> Union[str, list[str], torch.Tensor]:
+    text: str | list[str],
+    tokenizer: Callable | None = None,
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    encoder: Callable | None = None,
+    device: str | torch.device | None = None,
+) -> str | list[str] | torch.Tensor:
     """
     Tokenize and otherwise transform text.

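Only annotations changed here; `convert_text` behaves as before. A hedged usage sketch based solely on the signature above (the model name is an arbitrary example, and the transformers package must be installed):

```python
from transformers import AutoTokenizer

from datachain.lib.text import convert_text

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Per the signature, text may be a single string or a list of strings,
# and the result is a string, list, or tensor depending on the kwargs.
tokens = convert_text("hello world", tokenizer=tokenizer)
```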
datachain/lib/udf.py
CHANGED
@@ -4,7 +4,7 @@ import traceback
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from contextlib import closing, nullcontext
 from functools import partial
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

 import attrs
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -60,7 +60,7 @@ UDFResult = dict[str, Any]
 class UDFAdapter:
     inner: "UDFBase"
     output: UDFOutputSpec
-    batch_size: Optional[int] = None
+    batch_size: int | None = None
     batch: int = 1

     def hash(self) -> str:
@@ -152,7 +152,7 @@ class UDFBase(AbstractUDF):
     prefetch: int = 0

     def __init__(self):
-        self.params: Optional[SignalSchema] = None
+        self.params: SignalSchema | None = None
         self.output = None
         self._func = None

@@ -197,7 +197,7 @@ class UDFBase(AbstractUDF):
         self,
         sign: "UdfSignature",
         params: "SignalSchema",
-        func: Optional[Callable],
+        func: Callable | None,
     ):
         self.params = params
         self.output = sign.output_schema
@@ -246,7 +246,7 @@ class UDFBase(AbstractUDF):

     def to_udf_wrapper(
         self,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
         batch: int = 1,
     ) -> UDFAdapter:
         return UDFAdapter(
@@ -304,11 +304,11 @@ class UDFBase(AbstractUDF):
         self._set_stream_recursive(field_value, catalog, cache, download_cb)

     def _prepare_row(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         return self._parse_row(row_dict, catalog, cache, download_cb)

     def _prepare_row_and_id(self, row, udf_fields, catalog, cache, download_cb):
-        row_dict = RowDict(zip(udf_fields, row))
+        row_dict = RowDict(zip(udf_fields, row, strict=False))
         udf_input = self._parse_row(row_dict, catalog, cache, download_cb)
         return row_dict["sys__id"], *udf_input

@@ -333,7 +333,7 @@ def noop(*args, **kwargs):

 async def _prefetch_input(
     row: T,
-    download_cb: Optional[Callback] = None,
+    download_cb: Callback | None = None,
     after_prefetch: "Callable[[], None]" = noop,
 ) -> T:
     for obj in row:
@@ -356,8 +356,8 @@ def _remove_prefetched(row: T) -> None:
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
-    download_cb: Optional[Callback] = None,
-    after_prefetch: Optional[Callable[[], None]] = None,
+    download_cb: Callback | None = None,
+    after_prefetch: Callable[[], None] | None = None,
     remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
     if not prefetch:
@@ -426,7 +426,10 @@ class Mapper(UDFBase):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
                 udf_output = self._flatten_row(result_objs)
-                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                output = [
+                    {"sys__id": id_}
+                    | dict(zip(self.signal_names, udf_output, strict=False))
+                ]
                 processed_cb.relative_update(1)
                 yield output

@@ -474,7 +477,8 @@ class BatchMapper(UDFBase):
                         row, udf_fields, catalog, cache, download_cb
                     )
                     for row in batch
-                ]
+                ],
+                strict=False,
             )
             result_objs = list(self.process_safe(udf_args))
             n_objs = len(result_objs)
@@ -483,8 +487,9 @@
             )
             udf_outputs = (self._flatten_row(row) for row in result_objs)
             output = [
-                {"sys__id": row_id} | dict(zip(self.signal_names, signals))
-                for row_id, signals in zip(row_ids, udf_outputs)
+                {"sys__id": row_id}
+                | dict(zip(self.signal_names, signals, strict=False))
+                for row_id, signals in zip(row_ids, udf_outputs, strict=False)
             ]
             processed_cb.relative_update(n_rows)
             yield output
@@ -520,7 +525,7 @@ class Generator(UDFBase):
             with safe_closing(self.process_safe(row)) as result_objs:
                 for result_obj in result_objs:
                     udf_output = self._flatten_row(result_obj)
-                    yield dict(zip(self.signal_names, udf_output))
+                    yield dict(zip(self.signal_names, udf_output, strict=False))

         prepared_inputs = _prepare_rows(udf_inputs)
         prepared_inputs = _prefetch_inputs(
@@ -559,11 +564,14 @@ class Aggregator(UDFBase):
                 *[
                     self._prepare_row(row, udf_fields, catalog, cache, download_cb)
                     for row in batch
-                ]
+                ],
+                strict=False,
             )
             result_objs = self.process_safe(udf_args)
             udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+            output = (
+                dict(zip(self.signal_names, row, strict=False)) for row in udf_outputs
+            )
             processed_cb.relative_update(len(batch))
             yield output

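Most edits in this file append `strict=False` to `zip()` calls. `zip()` accepts a `strict` keyword since Python 3.10 (PEP 618); writing `strict=False` pins down the historical truncate-to-shortest behavior explicitly, which also satisfies linters that flag bare `zip()`. In isolation:

```python
names = ["a", "b", "c"]
values = [1, 2]

# strict=False (the default) silently stops at the shorter iterable.
assert dict(zip(names, values, strict=False)) == {"a": 1, "b": 2}

# strict=True raises instead when lengths differ.
try:
    dict(zip(names, values, strict=True))
except ValueError:
    pass  # "zip() argument 2 is shorter than argument 1"
```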
datachain/lib/udf_signature.py
CHANGED
@@ -1,7 +1,7 @@
 import inspect
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union, get_args, get_origin
+from typing import Any, get_args, get_origin

 from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
 from datachain.lib.signal_schema import SignalSchema
@@ -17,8 +17,8 @@ class UdfSignatureError(DataChainParamsError):

 @dataclass
 class UdfSignature:  # noqa: PLW1641
-    func: Union[Callable, UDFBase]
-    params: dict[str, Union[DataType, Any]]
+    func: Callable | UDFBase
+    params: dict[str, DataType | Any]
     output_schema: SignalSchema

     DEFAULT_RETURN_TYPE = str
@@ -28,9 +28,9 @@ class UdfSignature:  # noqa: PLW1641
         cls,
         chain: str,
         signal_map: dict[str, Callable],
-        func: Optional[Union[UDFBase, Callable]] = None,
-        params: Optional[Union[str, Sequence[str]]] = None,
-        output: Optional[Union[DataType, Sequence[str], dict[str, DataType]]] = None,
+        func: UDFBase | Callable | None = None,
+        params: str | Sequence[str] | None = None,
+        output: DataType | Sequence[str] | dict[str, DataType] | None = None,
         is_generator: bool = True,
     ) -> "UdfSignature":
         keys = ", ".join(signal_map.keys())
@@ -40,7 +40,7 @@ class UdfSignature:  # noqa: PLW1641
                 f"multiple signals '{keys}' are not supported in processors."
                 " Chain multiple processors instead.",
             )
-        udf_func: Union[UDFBase, Callable]
+        udf_func: UDFBase | Callable
         if len(signal_map) == 1:
             if func is not None:
                 raise UdfSignatureError(
@@ -62,7 +62,7 @@ class UdfSignature:  # noqa: PLW1641
                 chain, udf_func
             )

-        udf_params: dict[str, Union[DataType, Any]] = {}
+        udf_params: dict[str, DataType | Any] = {}
         if params:
             udf_params = (
                 {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
@@ -128,7 +128,7 @@ class UdfSignature:  # noqa: PLW1641
                     f" return type length ({len(func_outs_sign)}) does not match",
                 )

-            udf_output_map = dict(zip(output, func_outs_sign))
+            udf_output_map = dict(zip(output, func_outs_sign, strict=False))
         elif isinstance(output, dict):
             for key, value in output.items():
                 if not isinstance(key, str):
@@ -164,7 +164,7 @@ class UdfSignature:  # noqa: PLW1641

     @staticmethod
     def _func_signature(
-        chain: str, udf_func: Union[Callable, UDFBase]
+        chain: str, udf_func: Callable | UDFBase
     ) -> tuple[dict[str, type], Sequence[type], bool]:
         if isinstance(udf_func, AbstractUDF):
             func = udf_func.process  # type: ignore[unreachable]
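One detail worth noting from the unchanged context above: `params` may arrive as a single signal name or a sequence of names, and is normalized with `dict.fromkeys`. Reproduced in isolation (the helper name is ours, not the package's):

```python
from typing import Any

def normalize_params(params: str | list[str]) -> dict[str, Any]:
    # Mirrors the expression in UdfSignature above: one name maps to
    # {name: Any}; a sequence maps every name to Any.
    return {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)

assert normalize_params("file") == {"file": Any}
assert normalize_params(["file", "text"]) == {"file": Any, "text": Any}
```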