datachain 0.30.5__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +4 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +276 -354
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +8 -3
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +10 -17
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +42 -27
- datachain/cli/commands/ls.py +15 -15
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/__init__.py +3 -43
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +34 -23
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +157 -0
- datachain/client/local.py +11 -7
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +2 -0
- datachain/data_storage/metastore.py +716 -137
- datachain/data_storage/schema.py +20 -27
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +114 -114
- datachain/data_storage/warehouse.py +140 -48
- datachain/dataset.py +109 -89
- datachain/delta.py +117 -42
- datachain/diff/__init__.py +25 -33
- datachain/error.py +24 -0
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +63 -45
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +18 -15
- datachain/lib/audio.py +60 -59
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/values_to_tuples.py +151 -53
- datachain/lib/data_model.py +23 -19
- datachain/lib/dataset_info.py +7 -7
- datachain/lib/dc/__init__.py +2 -1
- datachain/lib/dc/csv.py +22 -26
- datachain/lib/dc/database.py +37 -34
- datachain/lib/dc/datachain.py +518 -324
- datachain/lib/dc/datasets.py +38 -30
- datachain/lib/dc/hf.py +16 -20
- datachain/lib/dc/json.py +17 -18
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +33 -21
- datachain/lib/dc/records.py +9 -13
- datachain/lib/dc/storage.py +103 -65
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +17 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +187 -50
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +2 -3
- datachain/lib/model_store.py +20 -8
- datachain/lib/namespaces.py +59 -7
- datachain/lib/projects.py +51 -9
- datachain/lib/pytorch.py +31 -23
- datachain/lib/settings.py +188 -85
- datachain/lib/signal_schema.py +302 -64
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +103 -63
- datachain/lib/udf_signature.py +59 -34
- datachain/lib/utils.py +20 -0
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +31 -36
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +12 -5
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +22 -3
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +4 -4
- datachain/query/batch.py +10 -12
- datachain/query/dataset.py +376 -194
- datachain/query/dispatch.py +112 -84
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/queue.py +2 -1
- datachain/query/schema.py +7 -6
- datachain/query/session.py +190 -33
- datachain/query/udf.py +9 -6
- datachain/remote/studio.py +90 -53
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +37 -25
- datachain/sql/sqlite/types.py +1 -1
- datachain/sql/types.py +36 -5
- datachain/studio.py +49 -40
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +39 -48
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/METADATA +26 -38
- datachain-0.39.0.dist-info/RECORD +173 -0
- datachain/cli/commands/query.py +0 -54
- datachain/query/utils.py +0 -36
- datachain-0.30.5.dist-info/RECORD +0 -168
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/WHEEL +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.30.5.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/lib/model_store.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
|
1
1
|
import inspect
|
|
2
|
-
import
|
|
3
|
-
from typing import Any, ClassVar, Optional
|
|
2
|
+
from typing import Any, ClassVar
|
|
4
3
|
|
|
5
4
|
from pydantic import BaseModel
|
|
6
5
|
|
|
7
|
-
logger = logging.getLogger(__name__)
|
|
8
|
-
|
|
9
6
|
|
|
10
7
|
class ModelStore:
|
|
11
8
|
store: ClassVar[dict[str, dict[int, type[BaseModel]]]] = {}
|
|
@@ -14,7 +11,7 @@ class ModelStore:
|
|
|
14
11
|
def get_version(cls, model: type[BaseModel]) -> int:
|
|
15
12
|
if not hasattr(model, "_version"):
|
|
16
13
|
return 0
|
|
17
|
-
return model._version
|
|
14
|
+
return model._version # type: ignore[attr-defined]
|
|
18
15
|
|
|
19
16
|
@classmethod
|
|
20
17
|
def get_name(cls, model) -> str:
|
|
@@ -39,7 +36,7 @@ class ModelStore:
|
|
|
39
36
|
cls.register(anno)
|
|
40
37
|
|
|
41
38
|
@classmethod
|
|
42
|
-
def get(cls, name: str, version:
|
|
39
|
+
def get(cls, name: str, version: int | None = None) -> type | None:
|
|
43
40
|
class_dict = cls.store.get(name, None)
|
|
44
41
|
if class_dict is None:
|
|
45
42
|
return None
|
|
@@ -77,7 +74,7 @@ class ModelStore:
|
|
|
77
74
|
)
|
|
78
75
|
|
|
79
76
|
@staticmethod
|
|
80
|
-
def to_pydantic(val) ->
|
|
77
|
+
def to_pydantic(val) -> type[BaseModel] | None:
|
|
81
78
|
if val is None or not ModelStore.is_pydantic(val):
|
|
82
79
|
return None
|
|
83
80
|
return val
|
|
@@ -98,6 +95,21 @@ class ModelStore:
|
|
|
98
95
|
(e.g. from by-value cloudpickle in workers) reports built state but
|
|
99
96
|
nested model field schemas aren't fully resolved yet.
|
|
100
97
|
"""
|
|
98
|
+
visited: set[type[BaseModel]] = set()
|
|
99
|
+
visiting: set[type[BaseModel]] = set()
|
|
100
|
+
|
|
101
|
+
def visit(model: type[BaseModel]) -> None:
|
|
102
|
+
if model in visited or model in visiting:
|
|
103
|
+
return
|
|
104
|
+
visiting.add(model)
|
|
105
|
+
for field in model.model_fields.values():
|
|
106
|
+
child = cls.to_pydantic(field.annotation)
|
|
107
|
+
if child is not None:
|
|
108
|
+
visit(child)
|
|
109
|
+
visiting.remove(model)
|
|
110
|
+
model.model_rebuild(force=True)
|
|
111
|
+
visited.add(model)
|
|
112
|
+
|
|
101
113
|
for versions in cls.store.values():
|
|
102
114
|
for model in versions.values():
|
|
103
|
-
model
|
|
115
|
+
visit(model)
|
datachain/lib/namespaces.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
1
|
+
from datachain.error import (
|
|
2
|
+
NamespaceCreateNotAllowedError,
|
|
3
|
+
NamespaceDeleteNotAllowedError,
|
|
4
|
+
)
|
|
5
|
+
from datachain.lib.projects import delete as delete_project
|
|
6
|
+
from datachain.namespace import Namespace, parse_name
|
|
5
7
|
from datachain.query import Session
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
def create(
|
|
9
|
-
name: str, descr:
|
|
11
|
+
name: str, descr: str | None = None, session: Session | None = None
|
|
10
12
|
) -> Namespace:
|
|
11
13
|
"""
|
|
12
14
|
Creates a new namespace.
|
|
@@ -38,7 +40,7 @@ def create(
|
|
|
38
40
|
return session.catalog.metastore.create_namespace(name, descr)
|
|
39
41
|
|
|
40
42
|
|
|
41
|
-
def get(name: str, session:
|
|
43
|
+
def get(name: str, session: Session | None = None) -> Namespace:
|
|
42
44
|
"""
|
|
43
45
|
Gets a namespace by name.
|
|
44
46
|
If the namespace is not found, a `NamespaceNotFoundError` is raised.
|
|
@@ -57,7 +59,7 @@ def get(name: str, session: Optional[Session] = None) -> Namespace:
|
|
|
57
59
|
return session.catalog.metastore.get_namespace(name)
|
|
58
60
|
|
|
59
61
|
|
|
60
|
-
def ls(session:
|
|
62
|
+
def ls(session: Session | None = None) -> list[Namespace]:
|
|
61
63
|
"""
|
|
62
64
|
Gets a list of all namespaces.
|
|
63
65
|
|
|
@@ -71,3 +73,53 @@ def ls(session: Optional[Session] = None) -> list[Namespace]:
|
|
|
71
73
|
```
|
|
72
74
|
"""
|
|
73
75
|
return Session.get(session).catalog.metastore.list_namespaces()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def delete_namespace(name: str, session: Session | None = None) -> None:
|
|
79
|
+
"""
|
|
80
|
+
Removes a namespace by name.
|
|
81
|
+
|
|
82
|
+
Raises:
|
|
83
|
+
NamespaceNotFoundError: If the namespace does not exist.
|
|
84
|
+
NamespaceDeleteNotAllowedError: If the namespace is non-empty,
|
|
85
|
+
is the default namespace, or is a system namespace,
|
|
86
|
+
as these cannot be removed.
|
|
87
|
+
|
|
88
|
+
Parameters:
|
|
89
|
+
name: The name of the namespace.
|
|
90
|
+
session: Session to use for getting project.
|
|
91
|
+
|
|
92
|
+
Example:
|
|
93
|
+
```py
|
|
94
|
+
import datachain as dc
|
|
95
|
+
dc.delete_namespace("dev")
|
|
96
|
+
```
|
|
97
|
+
"""
|
|
98
|
+
session = Session.get(session)
|
|
99
|
+
metastore = session.catalog.metastore
|
|
100
|
+
|
|
101
|
+
namespace_name, project_name = parse_name(name)
|
|
102
|
+
|
|
103
|
+
if project_name:
|
|
104
|
+
return delete_project(project_name, namespace_name, session)
|
|
105
|
+
|
|
106
|
+
namespace = metastore.get_namespace(name)
|
|
107
|
+
|
|
108
|
+
if name == metastore.system_namespace_name:
|
|
109
|
+
raise NamespaceDeleteNotAllowedError(
|
|
110
|
+
f"Namespace {metastore.system_namespace_name} cannot be removed"
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
if name == metastore.default_namespace_name:
|
|
114
|
+
raise NamespaceDeleteNotAllowedError(
|
|
115
|
+
f"Namespace {metastore.default_namespace_name} cannot be removed"
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
num_projects = metastore.count_projects(namespace.id)
|
|
119
|
+
if num_projects > 0:
|
|
120
|
+
raise NamespaceDeleteNotAllowedError(
|
|
121
|
+
f"Namespace cannot be removed. It contains {num_projects} project(s). "
|
|
122
|
+
"Please remove the project(s) first."
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
metastore.remove_namespace(namespace.id)
|
datachain/lib/projects.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from datachain.error import ProjectCreateNotAllowedError
|
|
1
|
+
from datachain.error import ProjectCreateNotAllowedError, ProjectDeleteNotAllowedError
|
|
4
2
|
from datachain.project import Project
|
|
5
3
|
from datachain.query import Session
|
|
6
4
|
|
|
@@ -8,8 +6,8 @@ from datachain.query import Session
|
|
|
8
6
|
def create(
|
|
9
7
|
namespace: str,
|
|
10
8
|
name: str,
|
|
11
|
-
descr:
|
|
12
|
-
session:
|
|
9
|
+
descr: str | None = None,
|
|
10
|
+
session: Session | None = None,
|
|
13
11
|
) -> Project:
|
|
14
12
|
"""
|
|
15
13
|
Creates a new project under a specified namespace.
|
|
@@ -42,7 +40,7 @@ def create(
|
|
|
42
40
|
return session.catalog.metastore.create_project(namespace, name, descr)
|
|
43
41
|
|
|
44
42
|
|
|
45
|
-
def get(name: str, namespace: str, session:
|
|
43
|
+
def get(name: str, namespace: str, session: Session | None) -> Project:
|
|
46
44
|
"""
|
|
47
45
|
Gets a project by name in some namespace.
|
|
48
46
|
If the project is not found, a `ProjectNotFoundError` is raised.
|
|
@@ -62,9 +60,7 @@ def get(name: str, namespace: str, session: Optional[Session]) -> Project:
|
|
|
62
60
|
return Session.get(session).catalog.metastore.get_project(name, namespace)
|
|
63
61
|
|
|
64
62
|
|
|
65
|
-
def ls(
|
|
66
|
-
namespace: Optional[str] = None, session: Optional[Session] = None
|
|
67
|
-
) -> list[Project]:
|
|
63
|
+
def ls(namespace: str | None = None, session: Session | None = None) -> list[Project]:
|
|
68
64
|
"""
|
|
69
65
|
Gets a list of projects in a specific namespace or from all namespaces.
|
|
70
66
|
|
|
@@ -86,3 +82,49 @@ def ls(
|
|
|
86
82
|
namespace_id = session.catalog.metastore.get_namespace(namespace).id
|
|
87
83
|
|
|
88
84
|
return session.catalog.metastore.list_projects(namespace_id)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def delete(name: str, namespace: str, session: Session | None = None) -> None:
|
|
88
|
+
"""
|
|
89
|
+
Removes a project by name within a namespace.
|
|
90
|
+
|
|
91
|
+
Raises:
|
|
92
|
+
ProjectNotFoundError: If the project does not exist.
|
|
93
|
+
ProjectDeleteNotAllowedError: If the project is non-empty,
|
|
94
|
+
is the default project, or is a listing project,
|
|
95
|
+
as these cannot be removed.
|
|
96
|
+
|
|
97
|
+
Parameters:
|
|
98
|
+
name : The name of the project.
|
|
99
|
+
namespace : The name of the namespace.
|
|
100
|
+
session : Session to use for getting project.
|
|
101
|
+
|
|
102
|
+
Example:
|
|
103
|
+
```py
|
|
104
|
+
import datachain as dc
|
|
105
|
+
dc.delete_project("my-project", "local")
|
|
106
|
+
```
|
|
107
|
+
"""
|
|
108
|
+
session = Session.get(session)
|
|
109
|
+
metastore = session.catalog.metastore
|
|
110
|
+
|
|
111
|
+
project = metastore.get_project(name, namespace)
|
|
112
|
+
|
|
113
|
+
if metastore.is_listing_project(name, namespace):
|
|
114
|
+
raise ProjectDeleteNotAllowedError(
|
|
115
|
+
f"Project {metastore.listing_project_name} cannot be removed"
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
if metastore.is_default_project(name, namespace):
|
|
119
|
+
raise ProjectDeleteNotAllowedError(
|
|
120
|
+
f"Project {metastore.default_project_name} cannot be removed"
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
num_datasets = metastore.count_datasets(project.id)
|
|
124
|
+
if num_datasets > 0:
|
|
125
|
+
raise ProjectDeleteNotAllowedError(
|
|
126
|
+
f"Project cannot be removed. It contains {num_datasets} dataset(s). "
|
|
127
|
+
"Please remove the dataset(s) first."
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
metastore.remove_project(project.id)
|
datachain/lib/pytorch.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
3
|
import weakref
|
|
4
|
-
from collections.abc import Generator, Iterable, Iterator
|
|
4
|
+
from collections.abc import Callable, Generator, Iterable, Iterator
|
|
5
5
|
from contextlib import closing
|
|
6
|
-
from typing import TYPE_CHECKING, Any
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from PIL import Image
|
|
9
9
|
from torch import float32
|
|
@@ -43,13 +43,13 @@ class PytorchDataset(IterableDataset):
|
|
|
43
43
|
def __init__(
|
|
44
44
|
self,
|
|
45
45
|
name: str,
|
|
46
|
-
version:
|
|
47
|
-
catalog:
|
|
48
|
-
transform:
|
|
49
|
-
tokenizer:
|
|
50
|
-
tokenizer_kwargs:
|
|
46
|
+
version: str | None = None,
|
|
47
|
+
catalog: Catalog | None = None,
|
|
48
|
+
transform: "Transform | None" = None,
|
|
49
|
+
tokenizer: Callable | None = None,
|
|
50
|
+
tokenizer_kwargs: dict[str, Any] | None = None,
|
|
51
51
|
num_samples: int = 0,
|
|
52
|
-
dc_settings:
|
|
52
|
+
dc_settings: Settings | None = None,
|
|
53
53
|
remove_prefetched: bool = False,
|
|
54
54
|
):
|
|
55
55
|
"""
|
|
@@ -74,6 +74,7 @@ class PytorchDataset(IterableDataset):
|
|
|
74
74
|
self.tokenizer = tokenizer
|
|
75
75
|
self.tokenizer_kwargs = tokenizer_kwargs or {}
|
|
76
76
|
self.num_samples = num_samples
|
|
77
|
+
owns_catalog = catalog is None
|
|
77
78
|
if catalog is None:
|
|
78
79
|
catalog = get_catalog()
|
|
79
80
|
self._init_catalog(catalog)
|
|
@@ -84,7 +85,7 @@ class PytorchDataset(IterableDataset):
|
|
|
84
85
|
self.prefetch = prefetch
|
|
85
86
|
|
|
86
87
|
self._cache = catalog.cache
|
|
87
|
-
self._prefetch_cache:
|
|
88
|
+
self._prefetch_cache: Cache | None = None
|
|
88
89
|
self._remove_prefetched = remove_prefetched
|
|
89
90
|
if prefetch and not self.cache:
|
|
90
91
|
tmp_dir = catalog.cache.tmp_dir
|
|
@@ -93,6 +94,10 @@ class PytorchDataset(IterableDataset):
|
|
|
93
94
|
self._cache = self._prefetch_cache
|
|
94
95
|
weakref.finalize(self, self._prefetch_cache.destroy)
|
|
95
96
|
|
|
97
|
+
# Close the catalog if we created it - we only needed it for clone params
|
|
98
|
+
if owns_catalog:
|
|
99
|
+
catalog.close()
|
|
100
|
+
|
|
96
101
|
def close(self) -> None:
|
|
97
102
|
if self._prefetch_cache:
|
|
98
103
|
self._prefetch_cache.destroy()
|
|
@@ -104,7 +109,7 @@ class PytorchDataset(IterableDataset):
|
|
|
104
109
|
self._ms_params = catalog.metastore.clone_params()
|
|
105
110
|
self._wh_params = catalog.warehouse.clone_params()
|
|
106
111
|
self._catalog_params = catalog.get_init_params()
|
|
107
|
-
self.catalog:
|
|
112
|
+
self.catalog: Catalog | None = None
|
|
108
113
|
|
|
109
114
|
def _get_catalog(self) -> "Catalog":
|
|
110
115
|
ms_cls, ms_args, ms_kwargs = self._ms_params
|
|
@@ -121,19 +126,22 @@ class PytorchDataset(IterableDataset):
|
|
|
121
126
|
total_workers: int,
|
|
122
127
|
) -> Generator[tuple[Any, ...], None, None]:
|
|
123
128
|
catalog = self._get_catalog()
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
129
|
+
try:
|
|
130
|
+
session = Session("PyTorch", catalog=catalog)
|
|
131
|
+
ds = read_dataset(
|
|
132
|
+
name=self.name, version=self.version, session=session
|
|
133
|
+
).settings(cache=self.cache, prefetch=self.prefetch)
|
|
134
|
+
|
|
135
|
+
# remove file signals from dataset
|
|
136
|
+
schema = ds.signals_schema.clone_without_file_signals()
|
|
137
|
+
ds = ds.select(*schema.values.keys())
|
|
138
|
+
|
|
139
|
+
if self.num_samples > 0:
|
|
140
|
+
ds = ds.sample(self.num_samples)
|
|
141
|
+
ds = ds.chunk(total_rank, total_workers)
|
|
142
|
+
yield from ds.to_iter()
|
|
143
|
+
finally:
|
|
144
|
+
catalog.close()
|
|
137
145
|
|
|
138
146
|
def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
|
|
139
147
|
from datachain.lib.udf import _prefetch_inputs
|
datachain/lib/settings.py
CHANGED
|
@@ -1,111 +1,214 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
1
3
|
from datachain.lib.utils import DataChainParamsError
|
|
2
|
-
|
|
4
|
+
|
|
5
|
+
DEFAULT_CACHE = False
|
|
6
|
+
DEFAULT_PREFETCH = 2
|
|
7
|
+
DEFAULT_BATCH_SIZE = 2_000
|
|
3
8
|
|
|
4
9
|
|
|
5
10
|
class SettingsError(DataChainParamsError):
|
|
6
|
-
def __init__(self, msg):
|
|
11
|
+
def __init__(self, msg: str) -> None:
|
|
7
12
|
super().__init__(f"Dataset settings error: {msg}")
|
|
8
13
|
|
|
9
14
|
|
|
10
15
|
class Settings:
|
|
11
|
-
|
|
16
|
+
"""Settings for datachain."""
|
|
17
|
+
|
|
18
|
+
_cache: bool | None
|
|
19
|
+
_prefetch: int | None
|
|
20
|
+
_parallel: bool | int | None
|
|
21
|
+
_workers: int | None
|
|
22
|
+
_namespace: str | None
|
|
23
|
+
_project: str | None
|
|
24
|
+
_min_task_size: int | None
|
|
25
|
+
_batch_size: int | None
|
|
26
|
+
|
|
27
|
+
def __init__( # noqa: C901, PLR0912
|
|
12
28
|
self,
|
|
13
|
-
cache=None,
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
):
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
)
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
)
|
|
29
|
+
cache: bool | None = None,
|
|
30
|
+
prefetch: bool | int | None = None,
|
|
31
|
+
parallel: bool | int | None = None,
|
|
32
|
+
workers: int | None = None,
|
|
33
|
+
namespace: str | None = None,
|
|
34
|
+
project: str | None = None,
|
|
35
|
+
min_task_size: int | None = None,
|
|
36
|
+
batch_size: int | None = None,
|
|
37
|
+
) -> None:
|
|
38
|
+
if cache is None:
|
|
39
|
+
self._cache = None
|
|
40
|
+
else:
|
|
41
|
+
if not isinstance(cache, bool):
|
|
42
|
+
raise SettingsError(
|
|
43
|
+
"'cache' argument must be bool"
|
|
44
|
+
f" while {cache.__class__.__name__} was given"
|
|
45
|
+
)
|
|
46
|
+
self._cache = cache
|
|
47
|
+
|
|
48
|
+
if prefetch is None or prefetch is True:
|
|
49
|
+
self._prefetch = None
|
|
50
|
+
elif prefetch is False:
|
|
51
|
+
self._prefetch = 0 # disable prefetch (False == 0)
|
|
52
|
+
else:
|
|
53
|
+
if not isinstance(prefetch, int):
|
|
54
|
+
raise SettingsError(
|
|
55
|
+
"'prefetch' argument must be int or bool"
|
|
56
|
+
f" while {prefetch.__class__.__name__} was given"
|
|
57
|
+
)
|
|
58
|
+
if prefetch < 0:
|
|
59
|
+
raise SettingsError(
|
|
60
|
+
"'prefetch' argument must be non-negative integer"
|
|
61
|
+
f", {prefetch} was given"
|
|
62
|
+
)
|
|
63
|
+
self._prefetch = prefetch
|
|
64
|
+
|
|
65
|
+
if parallel is None or parallel is False:
|
|
66
|
+
self._parallel = None
|
|
67
|
+
elif parallel is True:
|
|
68
|
+
self._parallel = True
|
|
69
|
+
else:
|
|
70
|
+
if not isinstance(parallel, int):
|
|
71
|
+
raise SettingsError(
|
|
72
|
+
"'parallel' argument must be int or bool"
|
|
73
|
+
f" while {parallel.__class__.__name__} was given"
|
|
74
|
+
)
|
|
75
|
+
if parallel <= 0:
|
|
76
|
+
raise SettingsError(
|
|
77
|
+
"'parallel' argument must be positive integer"
|
|
78
|
+
f", {parallel} was given"
|
|
79
|
+
)
|
|
80
|
+
self._parallel = parallel
|
|
81
|
+
|
|
82
|
+
if workers is None:
|
|
83
|
+
self._workers = None
|
|
84
|
+
else:
|
|
85
|
+
if not isinstance(workers, int) or isinstance(workers, bool):
|
|
86
|
+
raise SettingsError(
|
|
87
|
+
"'workers' argument must be int"
|
|
88
|
+
f" while {workers.__class__.__name__} was given"
|
|
89
|
+
)
|
|
90
|
+
if workers <= 0:
|
|
91
|
+
raise SettingsError(
|
|
92
|
+
f"'workers' argument must be positive integer, {workers} was given"
|
|
93
|
+
)
|
|
94
|
+
self._workers = workers
|
|
95
|
+
|
|
96
|
+
if namespace is None:
|
|
97
|
+
self._namespace = None
|
|
98
|
+
else:
|
|
99
|
+
if not isinstance(namespace, str):
|
|
100
|
+
raise SettingsError(
|
|
101
|
+
"'namespace' argument must be str"
|
|
102
|
+
f", {namespace.__class__.__name__} was given"
|
|
103
|
+
)
|
|
104
|
+
self._namespace = namespace
|
|
105
|
+
|
|
106
|
+
if project is None:
|
|
107
|
+
self._project = None
|
|
108
|
+
else:
|
|
109
|
+
if not isinstance(project, str):
|
|
110
|
+
raise SettingsError(
|
|
111
|
+
"'project' argument must be str"
|
|
112
|
+
f", {project.__class__.__name__} was given"
|
|
113
|
+
)
|
|
114
|
+
self._project = project
|
|
115
|
+
|
|
116
|
+
if min_task_size is None:
|
|
117
|
+
self._min_task_size = None
|
|
118
|
+
else:
|
|
119
|
+
if not isinstance(min_task_size, int) or isinstance(min_task_size, bool):
|
|
120
|
+
raise SettingsError(
|
|
121
|
+
"'min_task_size' argument must be int"
|
|
122
|
+
f", {min_task_size.__class__.__name__} was given"
|
|
123
|
+
)
|
|
124
|
+
if min_task_size <= 0:
|
|
125
|
+
raise SettingsError(
|
|
126
|
+
"'min_task_size' argument must be positive integer"
|
|
127
|
+
f", {min_task_size} was given"
|
|
128
|
+
)
|
|
129
|
+
self._min_task_size = min_task_size
|
|
130
|
+
|
|
131
|
+
if batch_size is None:
|
|
132
|
+
self._batch_size = None
|
|
133
|
+
else:
|
|
134
|
+
if not isinstance(batch_size, int) or isinstance(batch_size, bool):
|
|
135
|
+
raise SettingsError(
|
|
136
|
+
"'batch_size' argument must be int"
|
|
137
|
+
f", {batch_size.__class__.__name__} was given"
|
|
138
|
+
)
|
|
139
|
+
if batch_size <= 0:
|
|
140
|
+
raise SettingsError(
|
|
141
|
+
"'batch_size' argument must be positive integer"
|
|
142
|
+
f", {batch_size} was given"
|
|
143
|
+
)
|
|
144
|
+
self._batch_size = batch_size
|
|
145
|
+
|
|
146
|
+
@property
|
|
147
|
+
def cache(self) -> bool:
|
|
148
|
+
return self._cache if self._cache is not None else DEFAULT_CACHE
|
|
149
|
+
|
|
150
|
+
@property
|
|
151
|
+
def prefetch(self) -> int | None:
|
|
152
|
+
return self._prefetch if self._prefetch is not None else DEFAULT_PREFETCH
|
|
153
|
+
|
|
154
|
+
@property
|
|
155
|
+
def parallel(self) -> bool | int | None:
|
|
156
|
+
return self._parallel if self._parallel is not None else None
|
|
157
|
+
|
|
158
|
+
@property
|
|
159
|
+
def workers(self) -> int | None:
|
|
160
|
+
return self._workers if self._workers is not None else None
|
|
161
|
+
|
|
162
|
+
@property
|
|
163
|
+
def namespace(self) -> str | None:
|
|
164
|
+
return self._namespace if self._namespace is not None else None
|
|
70
165
|
|
|
71
166
|
@property
|
|
72
|
-
def
|
|
73
|
-
return self.
|
|
167
|
+
def project(self) -> str | None:
|
|
168
|
+
return self._project if self._project is not None else None
|
|
74
169
|
|
|
75
170
|
@property
|
|
76
|
-
def
|
|
77
|
-
return self.
|
|
171
|
+
def min_task_size(self) -> int | None:
|
|
172
|
+
return self._min_task_size if self._min_task_size is not None else None
|
|
78
173
|
|
|
79
174
|
@property
|
|
80
|
-
def
|
|
81
|
-
return self.
|
|
175
|
+
def batch_size(self) -> int:
|
|
176
|
+
return self._batch_size if self._batch_size is not None else DEFAULT_BATCH_SIZE
|
|
82
177
|
|
|
83
|
-
def to_dict(self):
|
|
84
|
-
res = {}
|
|
178
|
+
def to_dict(self) -> dict[str, Any]:
|
|
179
|
+
res: dict[str, Any] = {}
|
|
85
180
|
if self._cache is not None:
|
|
86
181
|
res["cache"] = self.cache
|
|
87
|
-
if self.
|
|
182
|
+
if self._prefetch is not None:
|
|
183
|
+
res["prefetch"] = self.prefetch
|
|
184
|
+
if self._parallel is not None:
|
|
88
185
|
res["parallel"] = self.parallel
|
|
89
186
|
if self._workers is not None:
|
|
90
187
|
res["workers"] = self.workers
|
|
91
|
-
if self.
|
|
188
|
+
if self._min_task_size is not None:
|
|
92
189
|
res["min_task_size"] = self.min_task_size
|
|
93
|
-
if self.
|
|
190
|
+
if self._namespace is not None:
|
|
94
191
|
res["namespace"] = self.namespace
|
|
95
|
-
if self.
|
|
192
|
+
if self._project is not None:
|
|
96
193
|
res["project"] = self.project
|
|
97
|
-
if self.
|
|
98
|
-
res["
|
|
194
|
+
if self._batch_size is not None:
|
|
195
|
+
res["batch_size"] = self.batch_size
|
|
99
196
|
return res
|
|
100
197
|
|
|
101
|
-
def add(self, settings: "Settings"):
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
if settings.
|
|
109
|
-
self.
|
|
110
|
-
if settings.
|
|
111
|
-
self.
|
|
198
|
+
def add(self, settings: "Settings") -> None:
|
|
199
|
+
if settings._cache is not None:
|
|
200
|
+
self._cache = settings._cache
|
|
201
|
+
if settings._prefetch is not None:
|
|
202
|
+
self._prefetch = settings._prefetch
|
|
203
|
+
if settings._parallel is not None:
|
|
204
|
+
self._parallel = settings._parallel
|
|
205
|
+
if settings._workers is not None:
|
|
206
|
+
self._workers = settings._workers
|
|
207
|
+
if settings._namespace is not None:
|
|
208
|
+
self._namespace = settings._namespace
|
|
209
|
+
if settings._project is not None:
|
|
210
|
+
self._project = settings._project
|
|
211
|
+
if settings._min_task_size is not None:
|
|
212
|
+
self._min_task_size = settings._min_task_size
|
|
213
|
+
if settings._batch_size is not None:
|
|
214
|
+
self._batch_size = settings._batch_size
|