datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/catalog/dependency.py
ADDED
@@ -0,0 +1,164 @@
+import builtins
+from dataclasses import dataclass
+from datetime import datetime
+from typing import TypeVar
+
+from datachain.dataset import DatasetDependency
+
+DDN = TypeVar("DDN", bound="DatasetDependencyNode")
+
+
+@dataclass
+class DatasetDependencyNode:
+    namespace: str
+    project: str
+    id: int
+    dataset_id: int | None
+    dataset_version_id: int | None
+    dataset_name: str | None
+    dataset_version: str | None
+    created_at: datetime
+    source_dataset_id: int
+    source_dataset_version_id: int | None
+    depth: int
+
+    @classmethod
+    def parse(
+        cls: builtins.type[DDN],
+        namespace: str,
+        project: str,
+        id: int,
+        dataset_id: int | None,
+        dataset_version_id: int | None,
+        dataset_name: str | None,
+        dataset_version: str | None,
+        created_at: datetime,
+        source_dataset_id: int,
+        source_dataset_version_id: int | None,
+        depth: int,
+    ) -> "DatasetDependencyNode | None":
+        return cls(
+            namespace,
+            project,
+            id,
+            dataset_id,
+            dataset_version_id,
+            dataset_name,
+            dataset_version,
+            created_at,
+            source_dataset_id,
+            source_dataset_version_id,
+            depth,
+        )
+
+    def to_dependency(self) -> "DatasetDependency | None":
+        return DatasetDependency.parse(
+            namespace_name=self.namespace,
+            project_name=self.project,
+            id=self.id,
+            dataset_id=self.dataset_id,
+            dataset_version_id=self.dataset_version_id,
+            dataset_name=self.dataset_name,
+            dataset_version=self.dataset_version,
+            dataset_version_created_at=self.created_at,
+        )
+
+
+def build_dependency_hierarchy(
+    dependency_nodes: list[DatasetDependencyNode | None],
+) -> tuple[
+    dict[int, DatasetDependency | None], dict[tuple[int, int | None], list[int]]
+]:
+    """
+    Build dependency hierarchy from dependency nodes.
+
+    Args:
+        dependency_nodes: List of DatasetDependencyNode objects from the database
+
+    Returns:
+        Tuple of (dependency_map, children_map) where:
+        - dependency_map: Maps dependency_id -> DatasetDependency
+        - children_map: Maps (source_dataset_id, source_version_id) ->
+          list of dependency_ids
+    """
+    dependency_map: dict[int, DatasetDependency | None] = {}
+    children_map: dict[tuple[int, int | None], list[int]] = {}
+
+    for node in dependency_nodes:
+        if node is None:
+            continue
+        dependency = node.to_dependency()
+        parent_key = (node.source_dataset_id, node.source_dataset_version_id)
+
+        if dependency is not None:
+            dependency_map[dependency.id] = dependency
+            children_map.setdefault(parent_key, []).append(dependency.id)
+        else:
+            # Handle case where dependency creation failed (e.g., deleted dependency)
+            dependency_map[node.id] = None
+            children_map.setdefault(parent_key, []).append(node.id)
+
+    return dependency_map, children_map
+
+
+def populate_nested_dependencies(
+    dependency: DatasetDependency,
+    dependency_nodes: list[DatasetDependencyNode | None],
+    dependency_map: dict[int, DatasetDependency | None],
+    children_map: dict[tuple[int, int | None], list[int]],
+) -> None:
+    """
+    Recursively populate nested dependencies for a given dependency.
+
+    Args:
+        dependency: The dependency to populate nested dependencies for
+        dependency_nodes: All dependency nodes from the database
+        dependency_map: Maps dependency_id -> DatasetDependency
+        children_map: Maps (source_dataset_id, source_version_id) ->
+          list of dependency_ids
+    """
+    # Find the target dataset and version for this dependency
+    target_dataset_id, target_version_id = find_target_dataset_version(
+        dependency, dependency_nodes
+    )
+
+    if target_dataset_id is None or target_version_id is None:
+        return
+
+    # Get children for this target
+    target_key = (target_dataset_id, target_version_id)
+    if target_key not in children_map:
+        dependency.dependencies = []
+        return
+
+    child_dependency_ids = children_map[target_key]
+    child_dependencies = [dependency_map[child_id] for child_id in child_dependency_ids]
+
+    dependency.dependencies = child_dependencies
+
+    # Recursively populate children
+    for child_dependency in child_dependencies:
+        if child_dependency is not None:
+            populate_nested_dependencies(
+                child_dependency, dependency_nodes, dependency_map, children_map
+            )
+
+
+def find_target_dataset_version(
+    dependency: DatasetDependency,
+    dependency_nodes: list[DatasetDependencyNode | None],
+) -> tuple[int | None, int | None]:
+    """
+    Find the target dataset ID and version ID for a given dependency.
+
+    Args:
+        dependency: The dependency to find target for
+        dependency_nodes: All dependency nodes from the database
+
+    Returns:
+        Tuple of (target_dataset_id, target_version_id) or (None, None) if not found
+    """
+    for node in dependency_nodes:
+        if node is not None and node.id == dependency.id:
+            return node.dataset_id, node.dataset_version_id
+    return None, None
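
The new helpers are meant to be composed: build_dependency_hierarchy indexes the flat rows produced for DatasetDependencyNode, and populate_nested_dependencies then walks that index to attach children. A minimal sketch of how they might be wired together, using made-up IDs, names, and dates rather than real metastore rows:

    from datetime import datetime, timezone

    from datachain.catalog.dependency import (
        DatasetDependencyNode,
        build_dependency_hierarchy,
        populate_nested_dependencies,
    )

    # Hypothetical rows: dataset 1 (version row 1) depends on dataset 2 (version row 5),
    # which in turn depends on dataset 3 (version row 7).
    now = datetime.now(timezone.utc)
    nodes = [
        DatasetDependencyNode("local", "default", 10, 2, 5, "parents", "1.0.0", now, 1, 1, 0),
        DatasetDependencyNode("local", "default", 11, 3, 7, "grandparents", "1.0.0", now, 2, 5, 1),
    ]

    dependency_map, children_map = build_dependency_hierarchy(nodes)
    for dep in dependency_map.values():
        if dep is not None:
            populate_nested_dependencies(dep, nodes, dependency_map, children_map)

    # Direct dependencies of the dataset/version being inspected (source key (1, 1) here).
    roots = [dependency_map[i] for i in children_map.get((1, 1), [])]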
datachain/catalog/loader.py
CHANGED
@@ -1,12 +1,15 @@
 import os
+import sys
 from importlib import import_module
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any
 
+from datachain.plugins import ensure_plugins_loaded
 from datachain.utils import get_envs_by_prefix
 
 if TYPE_CHECKING:
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.query.udf import AbstractUDFDistributor
 
 METASTORE_SERIALIZED = "DATACHAIN__METASTORE"
 METASTORE_IMPORT_PATH = "DATACHAIN_METASTORE"
@@ -14,13 +17,16 @@ METASTORE_ARG_PREFIX = "DATACHAIN_METASTORE_ARG_"
 WAREHOUSE_SERIALIZED = "DATACHAIN__WAREHOUSE"
 WAREHOUSE_IMPORT_PATH = "DATACHAIN_WAREHOUSE"
 WAREHOUSE_ARG_PREFIX = "DATACHAIN_WAREHOUSE_ARG_"
+DISTRIBUTED_IMPORT_PYTHONPATH = "DATACHAIN_DISTRIBUTED_PYTHONPATH"
 DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
-
+DISTRIBUTED_DISABLED = "DATACHAIN_DISTRIBUTED_DISABLED"
 
 IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"
 
 
 def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
+    ensure_plugins_loaded()
+
     from datachain.data_storage import AbstractMetastore
     from datachain.data_storage.serializer import deserialize
 
@@ -61,6 +67,8 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
 
 
 def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
+    ensure_plugins_loaded()
+
     from datachain.data_storage import AbstractWarehouse
     from datachain.data_storage.serializer import deserialize
 
@@ -100,31 +108,32 @@ def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
     return warehouse_class(**warehouse_args)
 
 
-def
-
-
-    # Convert env variable names to keyword argument names by lowercasing them
-    distributed_args = {k.lower(): v for k, v in distributed_arg_envs.items()}
+def get_udf_distributor_class() -> type["AbstractUDFDistributor"] | None:
+    if os.environ.get(DISTRIBUTED_DISABLED) == "True":
+        return None
 
-    if not distributed_import_path:
-
-
-
-    )
-    # Distributed class paths are specified as (for example):
-    # module.classname
+    if not (distributed_import_path := os.environ.get(DISTRIBUTED_IMPORT_PATH)):
+        return None
+
+    # Distributed class paths are specified as (for example): module.classname
     if "." not in distributed_import_path:
         raise RuntimeError(
             f"Invalid {DISTRIBUTED_IMPORT_PATH} import path: {distributed_import_path}"
         )
+
+    # Optional: set the Python path to look for the module
+    distributed_import_pythonpath = os.environ.get(DISTRIBUTED_IMPORT_PYTHONPATH)
+    if distributed_import_pythonpath and distributed_import_pythonpath not in sys.path:
+        sys.path.insert(0, distributed_import_pythonpath)
+
     module_name, _, class_name = distributed_import_path.rpartition(".")
     distributed = import_module(module_name)
-
-    return distributed_class(**distributed_args | kwargs)
+    return getattr(distributed, class_name)
 
 
 def get_catalog(
-    client_config:
+    client_config: dict[str, Any] | None = None,
+    in_memory: bool = False,
 ) -> "Catalog":
     """
     Function that creates Catalog instance with appropriate metastore
@@ -139,8 +148,9 @@ def get_catalog(
     """
     from datachain.catalog import Catalog
 
+    metastore = get_metastore(in_memory=in_memory)
     return Catalog(
-        metastore=
+        metastore=metastore,
         warehouse=get_warehouse(in_memory=in_memory),
        client_config=client_config,
         in_memory=in_memory,
datachain/checkpoint.py
ADDED
@@ -0,0 +1,43 @@
+import uuid
+from dataclasses import dataclass
+from datetime import datetime
+
+
+@dataclass
+class Checkpoint:
+    """
+    Represents a checkpoint within a job run.
+
+    A checkpoint marks a successfully completed stage of execution. In the event
+    of a failure, the job can resume from the most recent checkpoint rather than
+    starting over from the beginning.
+
+    Checkpoints can also be created in a "partial" mode, which indicates that the
+    work at this stage was only partially completed. For example, if a failure
+    occurs halfway through running a UDF, already computed results can still be
+    saved, allowing the job to resume from that partially completed state on
+    restart.
+    """
+
+    id: str
+    job_id: str
+    hash: str
+    partial: bool
+    created_at: datetime
+
+    @classmethod
+    def parse(
+        cls,
+        id: str | uuid.UUID,
+        job_id: str,
+        _hash: str,
+        partial: bool,
+        created_at: datetime,
+    ) -> "Checkpoint":
+        return cls(
+            str(id),
+            job_id,
+            _hash,
+            bool(partial),
+            created_at,
+        )
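
Checkpoint.parse mirrors the stored row layout (note the _hash parameter name, since hash would shadow the builtin) and normalizes the id to a string and partial to a bool. A small illustrative sketch with made-up values:

    import uuid
    from datetime import datetime, timezone

    from datachain.checkpoint import Checkpoint

    row = (
        uuid.uuid4(),                # id as stored (UUID)
        "job-42",                    # job_id (made-up)
        "3f2a...",                   # hash of the completed stage (made-up)
        1,                           # partial flag as stored (truthy int)
        datetime.now(timezone.utc),  # created_at
    )
    checkpoint = Checkpoint.parse(*row)
    assert isinstance(checkpoint.id, str) and checkpoint.partial is True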
datachain/cli/__init__.py
CHANGED
@@ -3,7 +3,6 @@ import os
 import sys
 import traceback
 from multiprocessing import freeze_support
-from typing import Optional
 
 from datachain.cli.utils import get_logging_level
 
@@ -16,7 +15,6 @@ from .commands import (
     index,
     list_datasets,
     ls,
-    query,
     rm_dataset,
     show,
 )
@@ -25,7 +23,7 @@ from .parser import get_parser
 logger = logging.getLogger("datachain")
 
 
-def main(argv:
+def main(argv: list[str] | None = None) -> int:
     from datachain.catalog import get_catalog
 
     # Required for Windows multiprocessing support
@@ -34,8 +32,10 @@ def main(argv: Optional[list[str]] = None) -> int:
     datachain_parser = get_parser()
     args = datachain_parser.parse_args(argv)
 
-    if args.command
-        return handle_udf(
+    if args.command == "internal-run-udf":
+        return handle_udf()
+    if args.command == "internal-run-udf-worker":
+        return handle_udf_runner()
 
     if args.command is None:
         datachain_parser.print_help(sys.stderr)
@@ -59,16 +59,22 @@ def main(argv: Optional[list[str]] = None) -> int:
 
     error = None
 
+    catalog = None
     try:
         catalog = get_catalog(client_config=client_config)
         return handle_command(args, catalog, client_config)
     except BrokenPipeError as exc:
         error, return_code = handle_broken_pipe_error(exc)
         return return_code
-    except (KeyboardInterrupt, Exception) as exc:
+    except (KeyboardInterrupt, Exception) as exc:  # noqa: BLE001
         error, return_code = handle_general_exception(exc, args, logging_level)
         return return_code
     finally:
+        if catalog is not None:
+            try:
+                catalog.close()
+            except Exception:
+                logger.exception("Failed to close catalog")
         from datachain.telemetry import telemetry
 
         telemetry.send_cli_call(args.command, error=error)
@@ -89,7 +95,6 @@ def handle_command(args, catalog, client_config) -> int:
         "find": lambda: handle_find_command(args, catalog),
         "index": lambda: handle_index_command(args, catalog),
         "completion": lambda: handle_completion_command(args),
-        "query": lambda: handle_query_command(args, catalog),
         "clear-cache": lambda: clear_cache(catalog),
         "gc": lambda: garbage_collect(catalog),
         "auth": lambda: process_auth_cli_args(args),
@@ -98,8 +103,10 @@ def handle_command(args, catalog, client_config) -> int:
 
     handler = command_handlers.get(args.command)
     if handler:
-        handler()
-
+        return_code = handler()
+        if return_code is None:
+            return 0
+        return return_code
     print(f"invalid command: {args.command}", file=sys.stderr)
     return 1
 
@@ -149,10 +156,7 @@ def handle_dataset_command(args, catalog):
             args.name,
             new_name=args.new_name,
             description=args.description,
-
-            studio=args.studio,
-            local=args.local,
-            all=args.all,
+            attrs=args.attrs,
             team=args.team,
         ),
         "ls": lambda: list_datasets(
@@ -170,8 +174,6 @@ def handle_dataset_command(args, catalog):
             version=args.version,
             force=args.force,
             studio=args.studio,
-            local=args.local,
-            all=args.all,
             team=args.team,
         ),
         "remove": lambda: rm_dataset(
@@ -180,8 +182,6 @@ def handle_dataset_command(args, catalog):
             version=args.version,
             force=args.force,
             studio=args.studio,
-            local=args.local,
-            all=args.all,
             team=args.team,
         ),
     }
@@ -263,15 +263,6 @@ def handle_completion_command(args):
     print(completion(args.shell))
 
 
-def handle_query_command(args, catalog):
-    query(
-        catalog,
-        args.script,
-        parallel=args.parallel,
-        params=args.param,
-    )
-
-
 def handle_broken_pipe_error(exc):
     # Python flushes standard streams on exit; redirect remaining output
     # to devnull to avoid another BrokenPipeError at shutdown
@@ -303,13 +294,13 @@ def handle_general_exception(exc, args, logging_level):
         return error, 1
 
 
-def handle_udf(
-
-
+def handle_udf() -> int:
+    from datachain.query.dispatch import udf_entrypoint
+
+    return udf_entrypoint()
 
-    return udf_entrypoint()
 
-
-
+def handle_udf_runner() -> int:
+    from datachain.query.dispatch import udf_worker_entrypoint
 
-
+    return udf_worker_entrypoint()
datachain/cli/commands/__init__.py
CHANGED
@@ -1,14 +1,8 @@
-from .datasets import
-    edit_dataset,
-    list_datasets,
-    list_datasets_local,
-    rm_dataset,
-)
+from .datasets import edit_dataset, list_datasets, list_datasets_local, rm_dataset
 from .du import du
 from .index import index
 from .ls import ls
 from .misc import clear_cache, completion, garbage_collect
-from .query import query
 from .show import show
 
 __all__ = [
@@ -21,7 +15,6 @@ __all__ = [
     "list_datasets",
     "list_datasets_local",
     "ls",
-    "query",
     "rm_dataset",
     "show",
 ]
datachain/cli/commands/datasets.py
CHANGED
@@ -1,29 +1,41 @@
 import sys
-from
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING
 
 from tabulate import tabulate
 
-
-
-
+from datachain import semver
+from datachain.catalog import is_namespace_local
 from datachain.cli.utils import determine_flavors
 from datachain.config import Config
-from datachain.error import DatasetNotFoundError
+from datachain.error import DataChainError, DatasetNotFoundError
 from datachain.studio import list_datasets as list_datasets_studio
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+
+
+def group_dataset_versions(
+    datasets: Iterable[tuple[str, str]], latest_only=True
+) -> dict[str, str | list[str]]:
+    grouped: dict[str, list[tuple[int, int, int]]] = {}
 
-def group_dataset_versions(datasets, latest_only=True):
-    grouped = {}
     # Sort to ensure groupby works as expected
     # (groupby expects consecutive items with the same key)
     for name, version in sorted(datasets):
-        grouped.setdefault(name, []).append(version)
+        grouped.setdefault(name, []).append(semver.parse(version))
 
     if latest_only:
         # For each dataset name, pick the highest version.
-        return {
+        return {
+            name: semver.create(*(max(versions))) for name, versions in grouped.items()
+        }
+
     # For each dataset name, return a sorted list of unique versions.
-    return {
+    return {
+        name: [semver.create(*v) for v in sorted(set(versions))]
+        for name, versions in grouped.items()
+    }
 
 
 def list_datasets(
@@ -31,10 +43,10 @@ def list_datasets(
     studio: bool = False,
     local: bool = False,
     all: bool = True,
-    team:
+    team: str | None = None,
     latest_only: bool = True,
-    name:
-):
+    name: str | None = None,
+) -> None:
     token = Config().read().get("studio", {}).get("token")
     all, local, studio = determine_flavors(studio, local, all, token)
     if name:
@@ -94,23 +106,31 @@ def list_datasets(
     print(tabulate(rows, headers="keys"))
 
 
-def list_datasets_local(
+def list_datasets_local(
+    catalog: "Catalog", name: str | None = None
+) -> Iterator[tuple[str, str]]:
     if name:
         yield from list_datasets_local_versions(catalog, name)
         return
 
     for d in catalog.ls_datasets():
         for v in d.versions:
-            yield
+            yield d.full_name, v.version
 
 
-def list_datasets_local_versions(
-
+def list_datasets_local_versions(
+    catalog: "Catalog", name: str
+) -> Iterator[tuple[str, str]]:
+    namespace_name, project_name, name = catalog.get_full_dataset_name(name)
+
+    ds = catalog.get_dataset(
+        name, namespace_name=namespace_name, project_name=project_name
+    )
     for v in ds.versions:
-        yield
+        yield name, v.version
 
 
-def _datasets_tabulate_row(name, both, local_version, studio_version):
+def _datasets_tabulate_row(name, both, local_version, studio_version) -> dict[str, str]:
     row = {
         "Name": name,
     }
@@ -127,49 +147,60 @@ def _datasets_tabulate_row(name, both, local_version, studio_version):
 def rm_dataset(
     catalog: "Catalog",
     name: str,
-    version:
-    force:
-    studio: bool = False,
-
-
-
-
-
-
-
-
-
-
+    version: str | None = None,
+    force: bool | None = False,
+    studio: bool | None = False,
+    team: str | None = None,
+) -> None:
+    namespace_name, project_name, name = catalog.get_full_dataset_name(name)
+
+    if studio:
+        # removing Studio dataset from CLI
+        from datachain.studio import remove_studio_dataset
+
+        if Config().read().get("studio", {}).get("token"):
+            remove_studio_dataset(
+                team, name, namespace_name, project_name, version, force
+            )
+        else:
+            raise DataChainError(
+                "Not logged in to Studio. Log in with 'datachain auth login'."
+            )
+    else:
         try:
-            catalog.
+            project = catalog.metastore.get_project(project_name, namespace_name)
+            catalog.remove_dataset(name, project, version=version, force=force)
         except DatasetNotFoundError:
             print("Dataset not found in local", file=sys.stderr)
 
-    if (all or studio) and token:
-        remove_studio_dataset(team, name, version, force)
-
 
 def edit_dataset(
     catalog: "Catalog",
     name: str,
-    new_name:
-    description:
-
-
-
-
-    team: Optional[str] = None,
-):
-    from datachain.studio import edit_studio_dataset
+    new_name: str | None = None,
+    description: str | None = None,
+    attrs: list[str] | None = None,
+    team: str | None = None,
+) -> None:
+    from datachain.lib.dc.utils import is_studio
 
-
-    all, local, studio = determine_flavors(studio, local, all, token)
+    namespace_name, project_name, name = catalog.get_full_dataset_name(name)
 
-    if
+    if is_studio() or is_namespace_local(namespace_name):
         try:
-            catalog.edit_dataset(
+            catalog.edit_dataset(
+                name, catalog.metastore.default_project, new_name, description, attrs
+            )
         except DatasetNotFoundError:
             print("Dataset not found in local", file=sys.stderr)
-
-
-
+    else:
+        from datachain.studio import edit_studio_dataset
+
+        if Config().read().get("studio", {}).get("token"):
+            edit_studio_dataset(
+                team, name, namespace_name, project_name, new_name, description, attrs
+            )
+        else:
+            raise DataChainError(
+                "Not logged in to Studio. Log in with 'datachain auth login'."
            )
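
With the switch to semver-aware grouping, group_dataset_versions parses each version string before comparing, so "1.10.0" now sorts above "1.9.0" instead of losing a string comparison. A quick sketch with made-up dataset names, assuming semver.parse returns a (major, minor, patch) tuple and semver.create renders it back to a dotted string, which is what the annotations and calls above suggest:

    from datachain.cli.commands.datasets import group_dataset_versions

    pairs = [
        ("cats", "1.9.0"),
        ("cats", "1.10.0"),
        ("dogs", "2.0.0"),
    ]

    # Latest only: numeric comparison, not lexicographic.
    assert group_dataset_versions(pairs) == {"cats": "1.10.0", "dogs": "2.0.0"}

    # All versions, sorted and de-duplicated per dataset name.
    assert group_dataset_versions(pairs, latest_only=False) == {
        "cats": ["1.9.0", "1.10.0"],
        "dogs": ["2.0.0"],
    }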