datachain 0.14.2__py3-none-any.whl → 0.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datachain/__init__.py +20 -0
- datachain/asyn.py +11 -12
- datachain/cache.py +7 -7
- datachain/catalog/__init__.py +2 -2
- datachain/catalog/catalog.py +621 -507
- datachain/catalog/dependency.py +164 -0
- datachain/catalog/loader.py +28 -18
- datachain/checkpoint.py +43 -0
- datachain/cli/__init__.py +24 -33
- datachain/cli/commands/__init__.py +1 -8
- datachain/cli/commands/datasets.py +83 -52
- datachain/cli/commands/ls.py +17 -17
- datachain/cli/commands/show.py +4 -4
- datachain/cli/parser/__init__.py +8 -74
- datachain/cli/parser/job.py +95 -3
- datachain/cli/parser/studio.py +11 -4
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +2 -15
- datachain/client/azure.py +4 -4
- datachain/client/fsspec.py +45 -28
- datachain/client/gcs.py +6 -6
- datachain/client/hf.py +29 -2
- datachain/client/http.py +157 -0
- datachain/client/local.py +15 -11
- datachain/client/s3.py +17 -9
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +12 -6
- datachain/data_storage/job.py +5 -1
- datachain/data_storage/metastore.py +1252 -186
- datachain/data_storage/schema.py +58 -45
- datachain/data_storage/serializer.py +105 -15
- datachain/data_storage/sqlite.py +286 -127
- datachain/data_storage/warehouse.py +250 -113
- datachain/dataset.py +353 -148
- datachain/delta.py +391 -0
- datachain/diff/__init__.py +27 -29
- datachain/error.py +60 -0
- datachain/func/__init__.py +2 -1
- datachain/func/aggregate.py +66 -42
- datachain/func/array.py +242 -38
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +110 -60
- datachain/func/func.py +96 -45
- datachain/func/numeric.py +55 -38
- datachain/func/path.py +32 -20
- datachain/func/random.py +2 -2
- datachain/func/string.py +67 -37
- datachain/func/window.py +7 -8
- datachain/hash_utils.py +123 -0
- datachain/job.py +11 -7
- datachain/json.py +138 -0
- datachain/lib/arrow.py +58 -22
- datachain/lib/audio.py +245 -0
- datachain/lib/clip.py +14 -13
- datachain/lib/convert/flatten.py +5 -3
- datachain/lib/convert/python_to_sql.py +6 -10
- datachain/lib/convert/sql_to_python.py +8 -0
- datachain/lib/convert/values_to_tuples.py +156 -51
- datachain/lib/data_model.py +42 -20
- datachain/lib/dataset_info.py +36 -8
- datachain/lib/dc/__init__.py +8 -2
- datachain/lib/dc/csv.py +25 -28
- datachain/lib/dc/database.py +398 -0
- datachain/lib/dc/datachain.py +1289 -425
- datachain/lib/dc/datasets.py +320 -38
- datachain/lib/dc/hf.py +38 -24
- datachain/lib/dc/json.py +29 -32
- datachain/lib/dc/listings.py +112 -8
- datachain/lib/dc/pandas.py +16 -12
- datachain/lib/dc/parquet.py +35 -23
- datachain/lib/dc/records.py +31 -23
- datachain/lib/dc/storage.py +154 -64
- datachain/lib/dc/storage_pattern.py +251 -0
- datachain/lib/dc/utils.py +24 -16
- datachain/lib/dc/values.py +8 -9
- datachain/lib/file.py +622 -89
- datachain/lib/hf.py +69 -39
- datachain/lib/image.py +14 -14
- datachain/lib/listing.py +14 -11
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +3 -4
- datachain/lib/model_store.py +39 -7
- datachain/lib/namespaces.py +125 -0
- datachain/lib/projects.py +130 -0
- datachain/lib/pytorch.py +32 -21
- datachain/lib/settings.py +192 -56
- datachain/lib/signal_schema.py +427 -104
- datachain/lib/tar.py +1 -2
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +164 -76
- datachain/lib/udf_signature.py +60 -35
- datachain/lib/utils.py +118 -4
- datachain/lib/video.py +17 -9
- datachain/lib/webdataset.py +61 -56
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +22 -10
- datachain/model/bbox.py +3 -1
- datachain/model/ultralytics/bbox.py +16 -12
- datachain/model/ultralytics/pose.py +16 -12
- datachain/model/ultralytics/segment.py +16 -12
- datachain/namespace.py +84 -0
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +24 -0
- datachain/project.py +78 -0
- datachain/query/batch.py +40 -41
- datachain/query/dataset.py +604 -322
- datachain/query/dispatch.py +261 -154
- datachain/query/metrics.py +4 -6
- datachain/query/params.py +2 -3
- datachain/query/queue.py +3 -12
- datachain/query/schema.py +11 -6
- datachain/query/session.py +200 -33
- datachain/query/udf.py +34 -2
- datachain/remote/studio.py +171 -69
- datachain/script_meta.py +12 -12
- datachain/semver.py +68 -0
- datachain/sql/__init__.py +2 -0
- datachain/sql/functions/array.py +33 -1
- datachain/sql/postgresql_dialect.py +9 -0
- datachain/sql/postgresql_types.py +21 -0
- datachain/sql/sqlite/__init__.py +5 -1
- datachain/sql/sqlite/base.py +102 -29
- datachain/sql/sqlite/types.py +8 -13
- datachain/sql/types.py +70 -15
- datachain/studio.py +223 -46
- datachain/toolkit/split.py +31 -10
- datachain/utils.py +101 -59
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/METADATA +77 -22
- datachain-0.39.0.dist-info/RECORD +173 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/WHEEL +1 -1
- datachain/cli/commands/query.py +0 -53
- datachain/query/utils.py +0 -42
- datachain-0.14.2.dist-info/RECORD +0 -158
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.14.2.dist-info → datachain-0.39.0.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED

```diff
@@ -1,28 +1,16 @@
 import io
-import json
 import logging
 import os
 import os.path
 import posixpath
-import signal
-import subprocess
-import sys
 import time
 import traceback
-from collections.abc import Iterable, Iterator,
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from contextlib import contextmanager, suppress
 from copy import copy
 from dataclasses import dataclass
 from functools import cached_property, reduce
-from
-from typing import (
-    IO,
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    NoReturn,
-    Optional,
-    Union,
-)
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4

 import sqlalchemy as sa
@@ -33,6 +21,7 @@ from datachain.cache import Cache
 from datachain.client import Client
 from datachain.dataset import (
     DATASET_PREFIX,
+    DEFAULT_DATASET_VERSION,
     QUERY_DATASET_PREFIX,
     DatasetDependency,
     DatasetListRecord,
@@ -40,31 +29,33 @@ from datachain.dataset import (
     DatasetStatus,
     StorageURI,
     create_dataset_uri,
+    parse_dataset_name,
     parse_dataset_uri,
+    parse_schema,
 )
 from datachain.error import (
     DataChainError,
     DatasetInvalidVersionError,
     DatasetNotFoundError,
     DatasetVersionNotFoundError,
-
-
+    NamespaceNotFoundError,
+    ProjectNotFoundError,
 )
 from datachain.lib.listing import get_listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
+from datachain.project import Project
 from datachain.sql.types import DateTime, SQLType
 from datachain.utils import DataChainDir

 from .datasource import DataSource
+from .dependency import build_dependency_hierarchy, populate_nested_dependencies

 if TYPE_CHECKING:
-    from datachain.data_storage import
-        AbstractMetastore,
-        AbstractWarehouse,
-    )
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.dataset import DatasetListVersion
     from datachain.job import Job
+    from datachain.lib.listing_info import ListingInfo
     from datachain.listing import Listing

 logger = logging.getLogger("datachain")
@@ -75,10 +66,9 @@ TTL_INT = 4 * 60 * 60

 INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing"
 DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
-# exit code we use if last statement in query script is not instance of DatasetQuery
-QUERY_SCRIPT_INVALID_LAST_STATEMENT_EXIT_CODE = 10
 # exit code we use if query script was canceled
 QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
+QUERY_SCRIPT_SIGTERM_EXIT_CODE = -15  # if query script was terminated by SIGTERM

 # dataset pull
 PULL_DATASET_MAX_THREADS = 5
@@ -87,64 +77,9 @@ PULL_DATASET_SLEEP_INTERVAL = 0.1  # sleep time while waiting for chunk to be av
 PULL_DATASET_CHECK_STATUS_INTERVAL = 20  # interval to check export status in Studio


-def
-
-
-
-class TerminationSignal(RuntimeError):  # noqa: N818
-    def __init__(self, signal):
-        self.signal = signal
-        super().__init__("Received termination signal", signal)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.signal})"
-
-
-if sys.platform == "win32":
-    SIGINT = signal.CTRL_C_EVENT
-else:
-    SIGINT = signal.SIGINT
-
-
-def shutdown_process(
-    proc: subprocess.Popen,
-    interrupt_timeout: Optional[int] = None,
-    terminate_timeout: Optional[int] = None,
-) -> int:
-    """Shut down the process gracefully with SIGINT -> SIGTERM -> SIGKILL."""
-
-    logger.info("sending interrupt signal to the process %s", proc.pid)
-    proc.send_signal(SIGINT)
-
-    logger.info("waiting for the process %s to finish", proc.pid)
-    try:
-        return proc.wait(interrupt_timeout)
-    except subprocess.TimeoutExpired:
-        logger.info(
-            "timed out waiting, sending terminate signal to the process %s", proc.pid
-        )
-        proc.terminate()
-        try:
-            return proc.wait(terminate_timeout)
-        except subprocess.TimeoutExpired:
-            logger.info("timed out waiting, killing the process %s", proc.pid)
-            proc.kill()
-            return proc.wait()
-
-
-def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
-    buffer = b""
-    while byt := stream.read(1):  # Read one byte at a time
-        buffer += byt
-
-        if byt in (b"\n", b"\r"):  # Check for newline or carriage return
-            line = buffer.decode("utf-8")
-            callback(line)
-            buffer = b""  # Clear buffer for next line
-
-    if buffer:  # Handle any remaining data in the buffer
-        line = buffer.decode("utf-8")
-        callback(line)
+def is_namespace_local(namespace_name) -> bool:
+    """Checks if namespace is from local environment, i.e. is `local`"""
+    return namespace_name == "local"


 class DatasetRowsFetcher(NodesThreadPool):
@@ -152,11 +87,11 @@ class DatasetRowsFetcher(NodesThreadPool):
         self,
         metastore: "AbstractMetastore",
         warehouse: "AbstractWarehouse",
-
-        remote_ds_version:
-
-        local_ds_version:
-        schema: dict[str,
+        remote_ds: DatasetRecord,
+        remote_ds_version: str,
+        local_ds: DatasetRecord,
+        local_ds_version: str,
+        schema: dict[str, SQLType | type[SQLType]],
         max_threads: int = PULL_DATASET_MAX_THREADS,
         progress_bar=None,
     ):
@@ -166,12 +101,12 @@ class DatasetRowsFetcher(NodesThreadPool):
         self._check_dependencies()
         self.metastore = metastore
         self.warehouse = warehouse
-        self.
+        self.remote_ds = remote_ds
         self.remote_ds_version = remote_ds_version
-        self.
+        self.local_ds = local_ds
         self.local_ds_version = local_ds_version
         self.schema = schema
-        self.last_status_check:
+        self.last_status_check: float | None = None
         self.studio_client = StudioClient()
         self.progress_bar = progress_bar

@@ -204,7 +139,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         Checks are done every PULL_DATASET_CHECK_STATUS_INTERVAL seconds
         """
         export_status_response = self.studio_client.dataset_export_status(
-            self.
+            self.remote_ds, self.remote_ds_version
         )
         if not export_status_response.ok:
             raise DataChainError(export_status_response.message)
@@ -251,9 +186,7 @@ class DatasetRowsFetcher(NodesThreadPool):
         import pandas as pd

         # metastore and warehouse are not thread safe
-        with self.
-            local_ds = metastore.get_dataset(self.local_ds_name)
-
+        with self.warehouse.clone() as warehouse:
             urls = list(urls)

             for url in urls:
@@ -266,7 +199,7 @@ class DatasetRowsFetcher(NodesThreadPool):
                 df = self.fix_columns(df)

                 inserted = warehouse.insert_dataset_rows(
-                    df, local_ds, self.local_ds_version
+                    df, self.local_ds, self.local_ds_version
                 )
                 self.increase_counter(inserted)  # type: ignore [arg-type]
                 # sometimes progress bar doesn't get updated so manually updating it
@@ -277,16 +210,16 @@ class DatasetRowsFetcher(NodesThreadPool):
 class NodeGroup:
     """Class for a group of nodes from the same source"""

-    listing:
-    client:
+    listing: "Listing | None"
+    client: Client
     sources: list[DataSource]

     # The source path within the bucket
     # (not including the bucket name or s3:// prefix)
     source_path: str = ""
-    dataset_name:
-    dataset_version:
-    instantiated_nodes:
+    dataset_name: str | None = None
+    dataset_version: str | None = None
+    instantiated_nodes: list[NodeWithPath] | None = None

     @property
     def is_dataset(self) -> bool:
@@ -307,13 +240,23 @@ class NodeGroup:
         if self.sources:
             self.client.fetch_nodes(self.iternodes(recursive), shared_progress_bar=pbar)

+    def close(self) -> None:
+        if self.listing:
+            self.listing.close()
+
+    def __enter__(self) -> "NodeGroup":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.close()
+

 def prepare_output_for_cp(
     node_groups: list[NodeGroup],
     output: str,
     force: bool = False,
     no_cp: bool = False,
-) -> tuple[bool,
+) -> tuple[bool, str | None]:
     total_node_count = 0
     for node_group in node_groups:
         if not node_group.sources:
@@ -362,7 +305,7 @@ def collect_nodes_for_cp(

     # Collect all sources to process
     for node_group in node_groups:
-        listing:
+        listing: Listing | None = node_group.listing
         valid_sources: list[DataSource] = []
         for dsrc in node_group.sources:
             if dsrc.is_single_object():
@@ -406,7 +349,7 @@ def instantiate_node_groups(
     recursive: bool = False,
     virtual_only: bool = False,
     always_copy_dir_contents: bool = False,
-    copy_to_filename:
+    copy_to_filename: str | None = None,
 ) -> None:
     instantiate_progress_bar = (
         None
@@ -434,7 +377,7 @@ def instantiate_node_groups(
     for node_group in node_groups:
         if not node_group.sources:
             continue
-        listing:
+        listing: Listing | None = node_group.listing
         source_path: str = node_group.source_path

         copy_dir_contents = always_copy_dir_contents or source_path.endswith("/")
@@ -517,10 +460,8 @@ class Catalog:
         warehouse: "AbstractWarehouse",
         cache_dir=None,
         tmp_dir=None,
-        client_config:
-        warehouse_ready_callback:
-            Callable[["AbstractWarehouse"], None]
-        ] = None,
+        client_config: dict[str, Any] | None = None,
+        warehouse_ready_callback: Callable[["AbstractWarehouse"], None] | None = None,
         in_memory: bool = False,
     ):
         datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
@@ -535,6 +476,7 @@ class Catalog:
         }
         self._warehouse_ready_callback = warehouse_ready_callback
         self.in_memory = in_memory
+        self._owns_connections = True  # False for copies, prevents double-close

     @cached_property
     def warehouse(self) -> "AbstractWarehouse":
@@ -556,13 +498,36 @@ class Catalog:
         }

     def copy(self, cache=True, db=True):
+        """
+        Create a shallow copy of this catalog.
+
+        The copy shares metastore and warehouse with the original but will not
+        close them - only the original catalog owns the connections.
+        """
         result = copy(self)
+        result._owns_connections = False
         if not db:
             result.metastore = None
             result._warehouse = None
             result.warehouse = None
         return result

+    def close(self) -> None:
+        if not self._owns_connections:
+            return
+        if self.metastore is not None:
+            with suppress(Exception):
+                self.metastore.close_on_exit()
+        if self._warehouse is not None:
+            with suppress(Exception):
+                self._warehouse.close_on_exit()
+
+    def __enter__(self) -> "Catalog":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.close()
+
     @classmethod
     def generate_query_dataset_name(cls) -> str:
         return f"{QUERY_DATASET_PREFIX}_{uuid4().hex}"
```
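The hunk above gives `Catalog` an ownership flag so that shallow copies share the metastore and warehouse but never close them; only the original catalog releases the connections, and it can now be used as a context manager. Below is a minimal, self-contained sketch of that same pattern; the `Pool` and `Owner` classes are hypothetical stand-ins, not DataChain code.

```python
from contextlib import suppress
from copy import copy


class Pool:
    """Hypothetical stand-in for a shared connection that must be closed once."""

    def close(self) -> None:
        print("pool closed")


class Owner:
    """Minimal sketch of the ownership-flag pattern from the hunk above."""

    def __init__(self) -> None:
        self.pool = Pool()
        self._owns_connections = True  # False for copies, prevents double-close

    def copy(self) -> "Owner":
        clone = copy(self)  # shallow copy: shares self.pool
        clone._owns_connections = False
        return clone

    def close(self) -> None:
        if not self._owns_connections:
            return  # copies are no-ops
        with suppress(Exception):
            self.pool.close()

    def __enter__(self) -> "Owner":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()


with Owner() as original:
    original.copy().close()  # no-op: the copy does not own the pool
# leaving the with-block closes the pool exactly once
```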
```diff
@@ -580,15 +545,13 @@ class Catalog:
         source: str,
         update=False,
         client_config=None,
-
+        column="file",
         skip_indexing=False,
-    ) -> tuple[
+    ) -> tuple["Listing | None", Client, str]:
         from datachain import read_storage
         from datachain.listing import Listing

-        read_storage(
-            source, session=self.session, update=update, object_name=object_name
-        ).exec()
+        read_storage(source, session=self.session, update=update, column=column).exec()

         list_ds_name, list_uri, list_path, _ = get_listing(
             source, self.session, update=update
@@ -602,13 +565,13 @@ class Catalog:
             self.warehouse.clone(),
             client,
             dataset_name=list_ds_name,
-
+            column=column,
         )

         return lst, client, list_path

     def _remove_dataset_rows_and_warehouse_info(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str, **kwargs
     ):
         self.warehouse.drop_dataset_rows_table(dataset, version)
         self.update_dataset_version_with_warehouse_info(
@@ -618,6 +581,7 @@ class Catalog:
             **kwargs,
         )

+    @contextmanager
     def enlist_sources(
         self,
         sources: list[str],
@@ -625,34 +589,41 @@ class Catalog:
         skip_indexing=False,
         client_config=None,
         only_index=False,
-    ) ->
-        enlisted_sources = []
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ) -> Iterator[list["DataSource"] | None]:
+        enlisted_sources: list[tuple[Listing | None, Client, str]] = []
+        try:
+            for src in sources:  # Opt: parallel
+                listing, client, file_path = self.enlist_source(
+                    src,
+                    update,
+                    client_config=client_config or self.client_config,
+                    skip_indexing=skip_indexing,
+                )
+                enlisted_sources.append((listing, client, file_path))
+
+            if only_index:
+                # sometimes we don't really need listing result (e.g. on indexing
+                # process) so this is to improve performance
+                yield None
+                return
+
+            dsrc_all: list[DataSource] = []
+            for listing, client, file_path in enlisted_sources:
+                if not listing:
+                    nodes = [Node.from_file(client.get_file_info(file_path))]
+                    dir_only = False
+                else:
+                    nodes = listing.expand_path(file_path)
+                    dir_only = file_path.endswith("/")
+                dsrc_all.extend(
+                    DataSource(listing, client, node, dir_only) for node in nodes
+                )
+            yield dsrc_all
+        finally:
+            for listing, _, _ in enlisted_sources:
+                if listing:
+                    with suppress(Exception):
+                        listing.close()

     def enlist_sources_grouped(
         self,
```
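In the hunk above, `enlist_sources` becomes a generator-based context manager: listings are opened up front, handed to the caller via `yield`, and always closed in the `finally` block (callers later in this diff consume it with `with ... as data_sources:`). A minimal, self-contained sketch of that open/yield/close shape follows; the `Listing` class and `enlist` function here are hypothetical illustrations, not the package's API.

```python
from collections.abc import Iterator
from contextlib import contextmanager


class Listing:
    """Hypothetical stand-in for a listing handle that must be closed."""

    def __init__(self, name: str) -> None:
        self.name = name

    def close(self) -> None:
        print(f"closed {self.name}")


@contextmanager
def enlist(names: list[str]) -> Iterator[list[Listing]]:
    # open resources, hand them to the caller, always clean up
    opened: list[Listing] = []
    try:
        for name in names:
            opened.append(Listing(name))
        yield opened
    finally:
        for listing in opened:
            listing.close()


with enlist(["s3://bucket-a", "s3://bucket-b"]) as listings:
    print([item.name for item in listings])
# both listings are closed here, even if the body raised
```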
```diff
@@ -671,10 +642,15 @@ class Catalog:
         enlisted_sources: list[tuple[bool, bool, Any]] = []
         client_config = client_config or self.client_config
         for src in sources:  # Opt: parallel
-            listing:
+            listing: Listing | None
             if src.startswith("ds://"):
                 ds_name, ds_version = parse_dataset_uri(src)
-
+                ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
+                assert ds_namespace
+                assert ds_project
+                dataset = self.get_dataset(
+                    ds_name, namespace_name=ds_namespace, project_name=ds_project
+                )
                 if not ds_version:
                     ds_version = dataset.latest_version
                 dataset_sources = self.warehouse.get_dataset_sources(
@@ -694,7 +670,11 @@ class Catalog:
                     dataset_name=dataset_name,
                 )
                 rows = DatasetQuery(
-                    name=dataset.name,
+                    name=dataset.name,
+                    namespace_name=dataset.project.namespace.name,
+                    project_name=dataset.project.name,
+                    version=ds_version,
+                    catalog=self,
                 ).to_db_records()
                 indexed_sources.append(
                     (
@@ -768,44 +748,56 @@ class Catalog:
     def create_dataset(
         self,
         name: str,
-
+        project: Project | None = None,
+        version: str | None = None,
         *,
         columns: Sequence[Column],
-        feature_schema:
+        feature_schema: dict | None = None,
         query_script: str = "",
-        create_rows:
-        validate_version:
-        listing:
-        uuid:
-        description:
-
+        create_rows: bool | None = True,
+        validate_version: bool | None = True,
+        listing: bool | None = False,
+        uuid: str | None = None,
+        description: str | None = None,
+        attrs: list[str] | None = None,
+        update_version: str | None = "patch",
+        job_id: str | None = None,
     ) -> "DatasetRecord":
         """
         Creates new dataset of a specific version.
         If dataset is not yet created, it will create it with version 1
         If version is None, then next unused version is created.
-        If version is given, then it must be an unused version
+        If version is given, then it must be an unused version.
         """
+        DatasetRecord.validate_name(name)
         assert [c.name for c in columns if c.name != "sys__id"], f"got {columns=}"
         if not listing and Client.is_data_source_uri(name):
             raise RuntimeError(
                 "Cannot create dataset that starts with source prefix, e.g s3://"
             )
-        default_version =
+        default_version = DEFAULT_DATASET_VERSION
         try:
-            dataset = self.get_dataset(
-
-
-
-
+            dataset = self.get_dataset(
+                name,
+                namespace_name=project.namespace.name if project else None,
+                project_name=project.name if project else None,
+            )
+            default_version = dataset.next_version_patch
+            if update_version == "major":
+                default_version = dataset.next_version_major
+            if update_version == "minor":
+                default_version = dataset.next_version_minor
+
+            if (description or attrs) and (
+                dataset.description != description or dataset.attrs != attrs
             ):
                 description = description or dataset.description
-
+                attrs = attrs or dataset.attrs

                 self.update_dataset(
                     dataset,
                     description=description,
-
+                    attrs=attrs,
                 )

         except DatasetNotFoundError:
```
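The `create_dataset` hunk above picks the default next version from the requested bump level: patch by default, or minor/major when `update_version` says so, on the new string (semver-style) versions. A small illustrative sketch of that selection is below; the `bump` helper is written for this page and is not the package's own semver module.

```python
def bump(version: str, update_version: str = "patch") -> str:
    """Illustrative semver bump mirroring the default-version selection above."""
    major, minor, patch = (int(part) for part in version.split("."))
    if update_version == "major":
        return f"{major + 1}.0.0"
    if update_version == "minor":
        return f"{major}.{minor + 1}.0"
    return f"{major}.{minor}.{patch + 1}"  # "patch" is the default


assert bump("1.2.3") == "1.2.4"
assert bump("1.2.3", "minor") == "1.3.0"
assert bump("1.2.3", "major") == "2.0.0"
```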
```diff
@@ -814,12 +806,13 @@ class Catalog:
             }
             dataset = self.metastore.create_dataset(
                 name,
+                project.id if project else None,
                 feature_schema=feature_schema,
                 query_script=query_script,
                 schema=schema,
                 ignore_if_exists=True,
                 description=description,
-
+                attrs=attrs,
             )

         version = version or default_version
@@ -834,7 +827,7 @@ class Catalog:
                 f"Version {version} must be higher than the current latest one"
             )

-        return self.
+        return self.create_dataset_version(
             dataset,
             version,
             feature_schema=feature_schema,
@@ -842,12 +835,13 @@ class Catalog:
             create_rows_table=create_rows,
             columns=columns,
             uuid=uuid,
+            job_id=job_id,
         )

-    def
+    def create_dataset_version(
         self,
         dataset: DatasetRecord,
-        version:
+        version: str,
         *,
         columns: Sequence[Column],
         sources="",
@@ -857,8 +851,8 @@ class Catalog:
         error_stack="",
         script_output="",
         create_rows_table=True,
-        job_id:
-        uuid:
+        job_id: str | None = None,
+        uuid: str | None = None,
     ) -> DatasetRecord:
         """
         Creates dataset version if it doesn't exist.
@@ -872,7 +866,7 @@ class Catalog:
         dataset = self.metastore.create_dataset_version(
             dataset,
             version,
-            status=DatasetStatus.
+            status=DatasetStatus.CREATED,
             sources=sources,
             feature_schema=feature_schema,
             query_script=query_script,
@@ -886,14 +880,14 @@ class Catalog:
         )

         if create_rows_table:
-            table_name = self.warehouse.dataset_table_name(dataset
+            table_name = self.warehouse.dataset_table_name(dataset, version)
             self.warehouse.create_dataset_rows_table(table_name, columns=columns)
             self.update_dataset_version_with_warehouse_info(dataset, version)

         return dataset

     def update_dataset_version_with_warehouse_info(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str, rows_dropped=False, **kwargs
     ) -> None:
         from datachain.query.dataset import DatasetQuery

@@ -905,11 +899,7 @@ class Catalog:
             values["num_objects"] = None
             values["size"] = None
             values["preview"] = None
-            self.metastore.update_dataset_version(
-                dataset,
-                version,
-                **values,
-            )
+            self.metastore.update_dataset_version(dataset, version, **values)
             return

         if not dataset_version.num_objects:
@@ -921,7 +911,13 @@ class Catalog:

         if not dataset_version.preview:
             values["preview"] = (
-                DatasetQuery(
+                DatasetQuery(
+                    name=dataset.name,
+                    namespace_name=dataset.project.namespace.name,
+                    project_name=dataset.project.name,
+                    version=version,
+                    catalog=self,
+                )
                 .limit(20)
                 .to_db_records()
             )
@@ -929,38 +925,18 @@ class Catalog:
         if not values:
             return

-        self.metastore.update_dataset_version(
-            dataset,
-            version,
-            **values,
-        )
+        self.metastore.update_dataset_version(dataset, version, **values)

     def update_dataset(
         self, dataset: DatasetRecord, conn=None, **kwargs
     ) -> DatasetRecord:
         """Updates dataset fields."""
-
-
-
-        old_name = dataset.name
-        new_name = kwargs["name"]
-
-        dataset = self.metastore.update_dataset(dataset, conn=conn, **kwargs)
-
-        if old_name and new_name:
-            # updating name must result in updating dataset table names as well
-            for version in [v.version for v in dataset.versions]:
-                self.warehouse.rename_dataset_table(
-                    old_name,
-                    new_name,
-                    old_version=version,
-                    new_version=version,
-                )
-
-        return dataset
+        dataset_updated = self.metastore.update_dataset(dataset, conn=conn, **kwargs)
+        self.warehouse.rename_dataset_tables(dataset, dataset_updated)
+        return dataset_updated

     def remove_dataset_version(
-        self, dataset: DatasetRecord, version:
+        self, dataset: DatasetRecord, version: str, drop_rows: bool | None = True
     ) -> None:
         """
         Deletes one single dataset version.
@@ -988,6 +964,7 @@ class Catalog:
         self,
         name: str,
         sources: list[str],
+        project: Project | None = None,
         client_config=None,
         recursive=False,
     ) -> DatasetRecord:
@@ -996,6 +973,8 @@ class Catalog:

         from datachain import read_dataset, read_storage

+        project = project or self.metastore.default_project
+
         chains = []
         for source in sources:
             if source.startswith(DATASET_PREFIX):
@@ -1008,10 +987,15 @@ class Catalog:
         # create union of all dataset queries created from sources
         dc = reduce(lambda dc1, dc2: dc1.union(dc2), chains)
         try:
+            dc = dc.settings(project=project.name, namespace=project.namespace.name)
             dc.save(name)
         except Exception as e:  # noqa: BLE001
             try:
-                ds = self.get_dataset(
+                ds = self.get_dataset(
+                    name,
+                    namespace_name=project.namespace.name,
+                    project_name=project.name,
+                )
                 self.metastore.update_dataset_status(
                     ds,
                     DatasetStatus.FAILED,
@@ -1028,7 +1012,11 @@ class Catalog:
             except DatasetNotFoundError:
                 raise e from None

-        ds = self.get_dataset(
+        ds = self.get_dataset(
+            name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+        )

         self.update_dataset_version_with_warehouse_info(
             ds,
@@ -1036,159 +1024,231 @@ class Catalog:
             sources="\n".join(sources),
         )

-        return self.get_dataset(
+        return self.get_dataset(
+            name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+        )

-    def
+    def get_full_dataset_name(
         self,
-
-
-
-
-    ) -> DatasetRecord:
+        name: str,
+        project_name: str | None = None,
+        namespace_name: str | None = None,
+    ) -> tuple[str, str, str]:
         """
-
-
-        It also removes original dataset version
+        Returns dataset name together with separated namespace and project name.
+        It takes into account all the ways namespace and project can be added.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if not dataset_version.is_final_status():
-            raise ValueError("Cannot register dataset version in non final status")
-
-        # copy dataset version
-        target_dataset = self.metastore.create_dataset_version(
-            target_dataset,
-            target_version,
-            sources=dataset_version.sources,
-            status=dataset_version.status,
-            query_script=dataset_version.query_script,
-            error_message=dataset_version.error_message,
-            error_stack=dataset_version.error_stack,
-            script_output=dataset_version.script_output,
-            created_at=dataset_version.created_at,
-            finished_at=dataset_version.finished_at,
-            schema=dataset_version.serialized_schema,
-            num_objects=dataset_version.num_objects,
-            size=dataset_version.size,
-            preview=dataset_version.preview,
-            job_id=dataset_version.job_id,
-        )
-
-        # to avoid re-creating rows table, we are just renaming it for a new version
-        # of target dataset
-        self.warehouse.rename_dataset_table(
-            dataset.name,
-            target_dataset.name,
-            old_version=version,
-            new_version=target_version,
+        parsed_namespace_name, parsed_project_name, name = parse_dataset_name(name)
+
+        namespace_env = os.environ.get("DATACHAIN_NAMESPACE")
+        project_env = os.environ.get("DATACHAIN_PROJECT")
+        if project_env and len(project_env.split(".")) == 2:
+            # we allow setting both namespace and project in DATACHAIN_PROJECT
+            namespace_env, project_env = project_env.split(".")
+
+        namespace_name = (
+            parsed_namespace_name
+            or namespace_name
+            or namespace_env
+            or self.metastore.default_namespace_name
         )
-
-
-
-
+        project_name = (
+            parsed_project_name
+            or project_name
+            or project_env
+            or self.metastore.default_project_name
         )

-
-
-
-
+        return namespace_name, project_name, name
+
+    def get_dataset(
+        self,
+        name: str,
+        namespace_name: str | None = None,
+        project_name: str | None = None,
+    ) -> DatasetRecord:
+        from datachain.lib.listing import is_listing_dataset

-        self.
+        namespace_name = namespace_name or self.metastore.default_namespace_name
+        project_name = project_name or self.metastore.default_project_name

-
+        if is_listing_dataset(name):
+            namespace_name = self.metastore.system_namespace_name
+            project_name = self.metastore.listing_project_name

-
-
+        return self.metastore.get_dataset(
+            name, namespace_name=namespace_name, project_name=project_name
+        )

     def get_dataset_with_remote_fallback(
-        self,
+        self,
+        name: str,
+        namespace_name: str,
+        project_name: str,
+        version: str | None = None,
+        pull_dataset: bool = False,
+        update: bool = False,
     ) -> DatasetRecord:
-
-
-
-
-
+        from datachain.lib.dc.utils import is_studio
+
+        # Intentionally ignore update flag is version is provided. Here only exact
+        # version can be provided and update then doesn't make sense.
+        # It corresponds to a query like this for example:
+        #
+        # dc.read_dataset("some.remote.dataset", version="1.0.0", update=True)
+        if version:
+            update = False
+
+        # we don't do Studio fallback is script is already ran in Studio, or if we try
+        # to fetch dataset with local namespace as that one cannot
+        # exist in Studio in the first place
+        no_fallback = is_studio() or is_namespace_local(namespace_name)
+
+        if no_fallback or not update:
+            try:
+                ds = self.get_dataset(
+                    name,
+                    namespace_name=namespace_name,
+                    project_name=project_name,
                 )
-
+                if not version or ds.has_version(version):
+                    return ds
+            except (NamespaceNotFoundError, ProjectNotFoundError, DatasetNotFoundError):
+                pass
+
+        if no_fallback:
+            raise DatasetNotFoundError(
+                f"Dataset {name}"
+                + (f" version {version} " if version else " ")
+                + f"not found in namespace {namespace_name} and project {project_name}"
+            )

-
+        if pull_dataset:
             print("Dataset not found in local catalog, trying to get from studio")
-
-
-
-            remote_ds_uri += f"@v{version}"
+            remote_ds_uri = create_dataset_uri(
+                name, namespace_name, project_name, version
+            )

             self.pull_dataset(
                 remote_ds_uri=remote_ds_uri,
                 local_ds_name=name,
                 local_ds_version=version,
             )
-            return self.get_dataset(
+            return self.get_dataset(
+                name,
+                namespace_name=namespace_name,
+                project_name=project_name,
+            )
+
+        return self.get_remote_dataset(namespace_name, project_name, name)

     def get_dataset_with_version_uuid(self, uuid: str) -> DatasetRecord:
         """Returns dataset that contains version with specific uuid"""
         for dataset in self.ls_datasets():
             if dataset.has_version_with_uuid(uuid):
-                return self.get_dataset(
+                return self.get_dataset(
+                    dataset.name,
+                    namespace_name=dataset.project.namespace.name,
+                    project_name=dataset.project.name,
+                )
         raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")

-    def get_remote_dataset(
+    def get_remote_dataset(
+        self, namespace: str, project: str, name: str
+    ) -> DatasetRecord:
         from datachain.remote.studio import StudioClient

         studio_client = StudioClient()

-        info_response = studio_client.dataset_info(name)
+        info_response = studio_client.dataset_info(namespace, project, name)
         if not info_response.ok:
+            if info_response.status == 404:
+                raise DatasetNotFoundError(
+                    f"Dataset {namespace}.{project}.{name} not found"
+                )
             raise DataChainError(info_response.message)

         dataset_info = info_response.data
         assert isinstance(dataset_info, dict)
         return DatasetRecord.from_dict(dataset_info)

-    def
-        self,
-
-
+    def get_dataset_dependencies_by_ids(
+        self,
+        dataset_id: int,
+        version_id: int,
+        indirect: bool = True,
+    ) -> list[DatasetDependency | None]:
+        dependency_nodes = self.metastore.get_dataset_dependency_nodes(
+            dataset_id=dataset_id,
+            version_id=version_id,
+        )
+
+        if not dependency_nodes:
+            return []
+
+        dependency_map, children_map = build_dependency_hierarchy(dependency_nodes)

-
-
+        root_key = (dataset_id, version_id)
+        if root_key not in children_map:
+            return []
+
+        root_dependency_ids = children_map[root_key]
+        root_dependencies = [dependency_map[dep_id] for dep_id in root_dependency_ids]
+
+        if indirect:
+            for dependency in root_dependencies:
+                if dependency is not None:
+                    populate_nested_dependencies(
+                        dependency, dependency_nodes, dependency_map, children_map
+                    )
+
+        return root_dependencies
+
+    def get_dataset_dependencies(
+        self,
+        name: str,
+        version: str,
+        namespace_name: str | None = None,
+        project_name: str | None = None,
+        indirect=False,
+    ) -> list[DatasetDependency | None]:
+        dataset = self.get_dataset(
+            name,
+            namespace_name=namespace_name,
+            project_name=project_name,
         )
+        dataset_version = dataset.get_version(version)
+        dataset_id = dataset.id
+        dataset_version_id = dataset_version.id

         if not indirect:
-            return
-
-
-
-                # dependency has been removed
-                continue
-            if d.is_dataset:
-                # only datasets can have dependencies
-                d.dependencies = self.get_dataset_dependencies(
-                    d.name, int(d.version), indirect=indirect
-                )
+            return self.metastore.get_direct_dataset_dependencies(
+                dataset,
+                version,
+            )

-        return
+        return self.get_dataset_dependencies_by_ids(
+            dataset_id,
+            dataset_version_id,
+            indirect,
+        )

     def ls_datasets(
-        self,
+        self,
+        prefix: str | None = None,
+        include_listing: bool = False,
+        studio: bool = False,
+        project: Project | None = None,
     ) -> Iterator[DatasetListRecord]:
         from datachain.remote.studio import StudioClient

+        project_id = project.id if project else None
+
         if studio:
             client = StudioClient()
-            response = client.ls_datasets()
+            response = client.ls_datasets(prefix=prefix)
             if not response.ok:
                 raise DataChainError(response.message)
             if not response.data:
```
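In the hunk above, `get_full_dataset_name` resolves the namespace and project by precedence: a namespace/project embedded in the dataset name wins, then the explicit argument, then the `DATACHAIN_NAMESPACE` / `DATACHAIN_PROJECT` environment variables (where `DATACHAIN_PROJECT` may carry `namespace.project`), then the metastore defaults. A tiny deterministic sketch of that precedence chain is below; `resolve_namespace` and the `"default-namespace"` placeholder are illustrative only, not the package's helpers.

```python
def resolve_namespace(
    parsed: str | None,
    explicit: str | None,
    env: str | None,
    default: str = "default-namespace",
) -> str:
    # Same first-non-empty-wins chain as in get_full_dataset_name above:
    # parsed-from-name > explicit argument > environment variable > default.
    return parsed or explicit or env or default


assert resolve_namespace("prod", "team", "env-ns") == "prod"
assert resolve_namespace(None, "team", "env-ns") == "team"
assert resolve_namespace(None, None, "env-ns") == "env-ns"
assert resolve_namespace(None, None, None) == "default-namespace"
```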
@@ -1199,8 +1259,12 @@ class Catalog:
|
|
|
1199
1259
|
for d in response.data
|
|
1200
1260
|
if not d.get("name", "").startswith(QUERY_DATASET_PREFIX)
|
|
1201
1261
|
)
|
|
1262
|
+
elif prefix:
|
|
1263
|
+
datasets = self.metastore.list_datasets_by_prefix(
|
|
1264
|
+
prefix, project_id=project_id
|
|
1265
|
+
)
|
|
1202
1266
|
else:
|
|
1203
|
-
datasets = self.metastore.list_datasets()
|
|
1267
|
+
datasets = self.metastore.list_datasets(project_id=project_id)
|
|
1204
1268
|
|
|
1205
1269
|
for d in datasets:
|
|
1206
1270
|
if not d.is_bucket_listing or include_listing:
|
|
@@ -1208,50 +1272,79 @@ class Catalog:
|
|
|
1208
1272
|
|
|
1209
1273
|
def list_datasets_versions(
|
|
1210
1274
|
self,
|
|
1275
|
+
prefix: str | None = None,
|
|
1211
1276
|
include_listing: bool = False,
|
|
1277
|
+
with_job: bool = True,
|
|
1212
1278
|
studio: bool = False,
|
|
1213
|
-
|
|
1279
|
+
project: Project | None = None,
|
|
1280
|
+
) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", "Job | None"]]:
|
|
1214
1281
|
"""Iterate over all dataset versions with related jobs."""
|
|
1215
1282
|
datasets = list(
|
|
1216
|
-
self.ls_datasets(
|
|
1283
|
+
self.ls_datasets(
|
|
1284
|
+
prefix=prefix,
|
|
1285
|
+
include_listing=include_listing,
|
|
1286
|
+
studio=studio,
|
|
1287
|
+
project=project,
|
|
1288
|
+
)
|
|
1217
1289
|
)
|
|
1218
1290
|
|
|
1219
1291
|
# preselect dataset versions jobs from db to avoid multiple queries
|
|
1220
|
-
jobs_ids: set[str] = {
|
|
1221
|
-
v.job_id for ds in datasets for v in ds.versions if v.job_id
|
|
1222
|
-
}
|
|
1223
1292
|
jobs: dict[str, Job] = {}
|
|
1224
|
-
if
|
|
1225
|
-
|
|
1293
|
+
if with_job:
|
|
1294
|
+
jobs_ids: set[str] = {
|
|
1295
|
+
v.job_id for ds in datasets for v in ds.versions if v.job_id
|
|
1296
|
+
}
|
|
1297
|
+
if jobs_ids:
|
|
1298
|
+
jobs = {
|
|
1299
|
+
j.id: j for j in self.metastore.list_jobs_by_ids(list(jobs_ids))
|
|
1300
|
+
}
|
|
1226
1301
|
|
|
1227
1302
|
for d in datasets:
|
|
1228
1303
|
yield from (
|
|
1229
|
-
(d, v, jobs.get(str(v.job_id)) if v.job_id else None)
|
|
1304
|
+
(d, v, jobs.get(str(v.job_id)) if with_job and v.job_id else None)
|
|
1230
1305
|
for v in d.versions
|
|
1231
1306
|
)
|
|
1232
1307
|
|
|
1233
|
-
def listings(self):
|
|
1308
|
+
def listings(self, prefix: str | None = None) -> list["ListingInfo"]:
|
|
1234
1309
|
"""
|
|
1235
1310
|
Returns list of ListingInfo objects which are representing specific
|
|
1236
1311
|
storage listing datasets
|
|
1237
1312
|
"""
|
|
1238
|
-
from datachain.lib.listing import is_listing_dataset
|
|
1313
|
+
from datachain.lib.listing import LISTING_PREFIX, is_listing_dataset
|
|
1239
1314
|
from datachain.lib.listing_info import ListingInfo
|
|
1240
1315
|
|
|
1316
|
+
if prefix and not prefix.startswith(LISTING_PREFIX):
|
|
1317
|
+
prefix = LISTING_PREFIX + prefix
|
|
1318
|
+
|
|
1319
|
+
listing_datasets_versions = self.list_datasets_versions(
|
|
1320
|
+
prefix=prefix,
|
|
1321
|
+
include_listing=True,
|
|
1322
|
+
with_job=False,
|
|
1323
|
+
project=self.metastore.listing_project,
|
|
1324
|
+
)
|
|
1325
|
+
|
|
1241
1326
|
return [
|
|
1242
1327
|
ListingInfo.from_models(d, v, j)
|
|
1243
|
-
for d, v, j in
|
|
1328
|
+
for d, v, j in listing_datasets_versions
|
|
1244
1329
|
if is_listing_dataset(d.name)
|
|
1245
1330
|
]
|
|
1246
1331
|
|
|
1247
1332
|
def ls_dataset_rows(
|
|
1248
|
-
self,
|
|
1333
|
+
self,
|
|
1334
|
+
dataset: DatasetRecord,
|
|
1335
|
+
version: str,
|
|
1336
|
+
offset=None,
|
|
1337
|
+
limit=None,
|
|
1249
1338
|
) -> list[dict]:
|
|
1250
1339
|
from datachain.query.dataset import DatasetQuery
|
|
1251
1340
|
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1341
|
+
q = DatasetQuery(
|
|
1342
|
+
name=dataset.name,
|
|
1343
|
+
namespace_name=dataset.project.namespace.name,
|
|
1344
|
+
project_name=dataset.project.name,
|
|
1345
|
+
version=version,
|
|
1346
|
+
catalog=self,
|
|
1347
|
+
)
|
|
1255
1348
|
if limit:
|
|
1256
1349
|
q = q.limit(limit)
|
|
1257
1350
|
if offset:
|
|
@@ -1263,9 +1356,9 @@ class Catalog:
|
|
|
1263
1356
|
self,
|
|
1264
1357
|
source: str,
|
|
1265
1358
|
path: str,
|
|
1266
|
-
version_id:
|
|
1359
|
+
version_id: str | None = None,
|
|
1267
1360
|
client_config=None,
|
|
1268
|
-
content_disposition:
|
|
1361
|
+
content_disposition: str | None = None,
|
|
1269
1362
|
**kwargs,
|
|
1270
1363
|
) -> str:
|
|
1271
1364
|
client_config = client_config or self.client_config
|
|
@@ -1283,26 +1376,42 @@ class Catalog:
|
|
|
1283
1376
|
self,
|
|
1284
1377
|
bucket_uri: str,
|
|
1285
1378
|
name: str,
|
|
1286
|
-
version:
|
|
1379
|
+
version: str,
|
|
1380
|
+
project: Project | None = None,
|
|
1287
1381
|
client_config=None,
|
|
1288
1382
|
) -> list[str]:
|
|
1289
|
-
dataset = self.get_dataset(
|
|
1383
|
+
dataset = self.get_dataset(
|
|
1384
|
+
name,
|
|
1385
|
+
namespace_name=project.namespace.name if project else None,
|
|
1386
|
+
project_name=project.name if project else None,
|
|
1387
|
+
)
|
|
1290
1388
|
|
|
1291
1389
|
return self.warehouse.export_dataset_table(
|
|
1292
1390
|
bucket_uri, dataset, version, client_config
|
|
1293
1391
|
)
|
|
1294
1392
|
|
|
1295
|
-
def dataset_table_export_file_names(
|
|
1296
|
-
|
|
1393
|
+
def dataset_table_export_file_names(
|
|
1394
|
+
self, name: str, version: str, project: Project | None = None
|
|
1395
|
+
) -> list[str]:
|
|
1396
|
+
dataset = self.get_dataset(
|
|
1397
|
+
name,
|
|
1398
|
+
namespace_name=project.namespace.name if project else None,
|
|
1399
|
+
project_name=project.name if project else None,
|
|
1400
|
+
)
|
|
1297
1401
|
return self.warehouse.dataset_table_export_file_names(dataset, version)
|
|
1298
1402
|
|
|
1299
1403
|
def remove_dataset(
|
|
1300
1404
|
self,
|
|
1301
1405
|
name: str,
|
|
1302
|
-
|
|
1303
|
-
|
|
1406
|
+
project: Project | None = None,
|
|
1407
|
+
version: str | None = None,
|
|
1408
|
+
force: bool | None = False,
|
|
1304
1409
|
):
|
|
1305
|
-
dataset = self.get_dataset(
|
|
1410
|
+
dataset = self.get_dataset(
|
|
1411
|
+
name,
|
|
1412
|
+
namespace_name=project.namespace.name if project else None,
|
|
1413
|
+
project_name=project.name if project else None,
|
|
1414
|
+
)
|
|
1306
1415
|
if not version and not force:
|
|
1307
1416
|
raise ValueError(f"Missing dataset version from input for dataset {name}")
|
|
1308
1417
|
if version and not dataset.has_version(version):
|
|
@@ -1324,19 +1433,25 @@ class Catalog:
|
|
|
1324
1433
|
def edit_dataset(
|
|
1325
1434
|
self,
|
|
1326
1435
|
name: str,
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1436
|
+
project: Project | None = None,
|
|
1437
|
+
new_name: str | None = None,
|
|
1438
|
+
description: str | None = None,
|
|
1439
|
+
attrs: list[str] | None = None,
|
|
1330
1440
|
) -> DatasetRecord:
|
|
1331
1441
|
update_data = {}
|
|
1332
1442
|
if new_name:
|
|
1443
|
+
DatasetRecord.validate_name(new_name)
|
|
1333
1444
|
update_data["name"] = new_name
|
|
1334
1445
|
if description is not None:
|
|
1335
1446
|
update_data["description"] = description
|
|
1336
|
-
if
|
|
1337
|
-
update_data["
|
|
1447
|
+
if attrs is not None:
|
|
1448
|
+
update_data["attrs"] = attrs # type: ignore[assignment]
|
|
1338
1449
|
|
|
1339
|
-
dataset = self.get_dataset(
|
|
1450
|
+
dataset = self.get_dataset(
|
|
1451
|
+
name,
|
|
1452
|
+
namespace_name=project.namespace.name if project else None,
|
|
1453
|
+
project_name=project.name if project else None,
|
|
1454
|
+
)
|
|
1340
1455
|
return self.update_dataset(dataset, **update_data)
|
|
1341
1456
|
|
|
1342
1457
|
def ls(
|
|
@@ -1348,22 +1463,24 @@ class Catalog:
|
|
|
1348
1463
|
*,
|
|
1349
1464
|
client_config=None,
|
|
1350
1465
|
) -> Iterator[tuple[DataSource, Iterable[tuple]]]:
|
|
1351
|
-
|
|
1466
|
+
with self.enlist_sources(
|
|
1352
1467
|
sources,
|
|
1353
1468
|
update,
|
|
1354
1469
|
skip_indexing=skip_indexing,
|
|
1355
1470
|
client_config=client_config or self.client_config,
|
|
1356
|
-
)
|
|
1471
|
+
) as data_sources:
|
|
1472
|
+
if data_sources is None:
|
|
1473
|
+
return
|
|
1357
1474
|
|
|
1358
|
-
|
|
1359
|
-
|
|
1475
|
+
for source in data_sources:
|
|
1476
|
+
yield source, source.ls(fields)
|
|
1360
1477
|
|
|
1361
1478
|
def pull_dataset( # noqa: C901, PLR0915
|
|
1362
1479
|
self,
|
|
1363
1480
|
remote_ds_uri: str,
|
|
1364
|
-
output:
|
|
1365
|
-
local_ds_name:
|
|
1366
|
-
local_ds_version:
|
|
1481
|
+
output: str | None = None,
|
|
1482
|
+
local_ds_name: str | None = None,
|
|
1483
|
+
local_ds_version: str | None = None,
|
|
1367
1484
|
cp: bool = False,
|
|
1368
1485
|
force: bool = False,
|
|
1369
1486
|
*,
|
|
@@ -1393,7 +1510,29 @@ class Catalog:
|
|
|
1393
1510
|
except Exception as e:
|
|
1394
1511
|
raise DataChainError("Error when parsing dataset uri") from e
|
|
1395
1512
|
|
|
1396
|
-
|
|
1513
|
+
remote_namespace, remote_project, remote_ds_name = parse_dataset_name(
|
|
1514
|
+
remote_ds_name
|
|
1515
|
+
)
|
|
1516
|
+
if not remote_namespace or not remote_project:
|
|
1517
|
+
raise DataChainError(
|
|
1518
|
+
f"Invalid fully qualified dataset name {remote_ds_name}, namespace"
|
|
1519
|
+
f" or project missing"
|
|
1520
|
+
)
|
|
1521
|
+
|
|
1522
|
+
if local_ds_name:
|
|
1523
|
+
local_namespace, local_project, local_ds_name = parse_dataset_name(
|
|
1524
|
+
local_ds_name
|
|
1525
|
+
)
|
|
1526
|
+
if local_namespace and local_namespace != remote_namespace:
|
|
1527
|
+
raise DataChainError(
|
|
1528
|
+
"Local namespace must be the same to remote namespace"
|
|
1529
|
+
)
|
|
1530
|
+
if local_project and local_project != remote_project:
|
|
1531
|
+
raise DataChainError("Local project must be the same to remote project")
|
|
1532
|
+
|
|
1533
|
+
remote_ds = self.get_remote_dataset(
|
|
1534
|
+
remote_namespace, remote_project, remote_ds_name
|
|
1535
|
+
)
|
|
1397
1536
|
|
|
1398
1537
|
try:
|
|
1399
1538
|
# if version is not specified in uri, take the latest one
|
|
@@ -1401,7 +1540,12 @@ class Catalog:
|
|
|
1401
1540
|
version = remote_ds.latest_version
|
|
1402
1541
|
print(f"Version not specified, pulling the latest one (v{version})")
|
|
1403
1542
|
# updating dataset uri with latest version
|
|
1404
|
-
remote_ds_uri = create_dataset_uri(
|
|
1543
|
+
remote_ds_uri = create_dataset_uri(
|
|
1544
|
+
remote_ds.name,
|
|
1545
|
+
remote_ds.project.namespace.name,
|
|
1546
|
+
remote_ds.project.name,
|
|
1547
|
+
version,
|
|
1548
|
+
)
|
|
1405
1549
|
remote_ds_version = remote_ds.get_version(version)
|
|
1406
1550
|
except (DatasetVersionNotFoundError, StopIteration) as exc:
|
|
1407
1551
|
raise DataChainError(
|
|
@@ -1410,7 +1554,13 @@ class Catalog:

         local_ds_name = local_ds_name or remote_ds.name
         local_ds_version = local_ds_version or remote_ds_version.version
-
+
+        local_ds_uri = create_dataset_uri(
+            local_ds_name,
+            remote_ds.project.namespace.name,
+            remote_ds.project.name,
+            local_ds_version,
+        )

         try:
             # try to find existing dataset with the same uuid to avoid pulling again
@@ -1419,7 +1569,10 @@ class Catalog:
                 remote_ds_version.uuid
             )
             existing_ds_uri = create_dataset_uri(
-                existing_ds.name,
+                existing_ds.name,
+                existing_ds.project.namespace.name,
+                existing_ds.project.name,
+                existing_ds_version.version,
             )
             if existing_ds_uri == remote_ds_uri:
                 print(f"Local copy of dataset {remote_ds_uri} already present")
@@ -1433,8 +1586,30 @@ class Catalog:
         except DatasetNotFoundError:
             pass

+        # Create namespace and project if doesn't exist
+        print(
+            f"Creating namespace {remote_ds.project.namespace.name} and project"
+            f" {remote_ds.project.name}"
+        )
+
+        namespace = self.metastore.create_namespace(
+            remote_ds.project.namespace.name,
+            description=remote_ds.project.namespace.descr,
+            uuid=remote_ds.project.namespace.uuid,
+            validate=False,
+        )
+        project = self.metastore.create_project(
+            namespace.name,
+            remote_ds.project.name,
+            description=remote_ds.project.descr,
+            uuid=remote_ds.project.uuid,
+            validate=False,
+        )
+
         try:
-            local_dataset = self.get_dataset(
+            local_dataset = self.get_dataset(
+                local_ds_name, namespace_name=namespace.name, project_name=project.name
+            )
             if local_dataset and local_dataset.has_version(local_ds_version):
                 raise DataChainError(
                     f"Local dataset {local_ds_uri} already exists with different uuid,"
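Before pulling, the catalog now creates the remote dataset's namespace and project locally when they are missing. A minimal in-memory sketch of an idempotent get-or-create, using stand-in names rather than the real metastore:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class InMemoryMetastore:  # stand-in for datachain's metastore
    namespaces: dict[str, dict] = field(default_factory=dict)
    projects: dict[tuple[str, str], dict] = field(default_factory=dict)

    def create_namespace(self, name: str, description: str = "", uuid: str | None = None) -> dict:
        # Idempotent: return the existing record if the namespace already exists.
        return self.namespaces.setdefault(
            name, {"name": name, "description": description, "uuid": uuid}
        )

    def create_project(
        self, namespace: str, name: str, description: str = "", uuid: str | None = None
    ) -> dict:
        return self.projects.setdefault(
            (namespace, name),
            {"namespace": namespace, "name": name, "description": description, "uuid": uuid},
        )


if __name__ == "__main__":
    metastore = InMemoryMetastore()
    ns = metastore.create_namespace("dev", uuid="ns-123")
    project = metastore.create_project(ns["name"], "analytics", uuid="proj-456")
    print(ns, project)
```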
@@ -1452,10 +1627,11 @@ class Catalog:
             leave=False,
         )

-        schema =
+        schema = parse_schema(remote_ds_version.schema)

         local_ds = self.create_dataset(
             local_ds_name,
+            project,
             local_ds_version,
             query_script=remote_ds_version.query_script,
             create_rows=True,
@@ -1468,7 +1644,7 @@ class Catalog:
         # asking remote to export dataset rows table to s3 and to return signed
         # urls of exported parts, which are in parquet format
         export_response = studio_client.export_dataset_table(
-
+            remote_ds, remote_ds_version.version
         )
         if not export_response.ok:
             raise DataChainError(export_response.message)
@@ -1499,9 +1675,9 @@ class Catalog:
         rows_fetcher = DatasetRowsFetcher(
             metastore,
             warehouse,
-
+            remote_ds,
             remote_ds_version.version,
-
+            local_ds,
             local_ds_version,
             schema,
             progress_bar=dataset_save_progress_bar,
@@ -1511,7 +1687,7 @@ class Catalog:
                 iter(batch(signed_urls)), dataset_save_progress_bar
             )
         except:
-            self.remove_dataset(local_ds_name, local_ds_version)
+            self.remove_dataset(local_ds_name, project, local_ds_version)
             raise

         local_ds = self.metastore.update_dataset_status(
@@ -1561,92 +1737,20 @@ class Catalog:
         else:
             # since we don't call cp command, which does listing implicitly,
             # it needs to be done here
-            self.enlist_sources(
+            with self.enlist_sources(
                 sources,
                 update,
                 client_config=client_config or self.client_config,
-            )
+            ):
+                pass

         self.create_dataset_from_sources(
-            output,
-
-
-
-
-        query_script: str,
-        env: Optional[Mapping[str, str]] = None,
-        python_executable: str = sys.executable,
-        capture_output: bool = False,
-        output_hook: Callable[[str], None] = noop,
-        params: Optional[dict[str, str]] = None,
-        job_id: Optional[str] = None,
-        interrupt_timeout: Optional[int] = None,
-        terminate_timeout: Optional[int] = None,
-    ) -> None:
-        cmd = [python_executable, "-c", query_script]
-        env = dict(env or os.environ)
-        env.update(
-            {
-                "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
-                "DATACHAIN_JOB_ID": job_id or "",
-            },
+            output,
+            sources,
+            self.metastore.default_project,
+            client_config=client_config,
+            recursive=recursive,
         )
-        popen_kwargs: dict[str, Any] = {}
-        if capture_output:
-            popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
-
-        def raise_termination_signal(sig: int, _: Any) -> NoReturn:
-            raise TerminationSignal(sig)
-
-        thread: Optional[Thread] = None
-        with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
-            logger.info("Starting process %s", proc.pid)
-
-            orig_sigint_handler = signal.getsignal(signal.SIGINT)
-            # ignore SIGINT in the main process.
-            # In the terminal, SIGINTs are received by all the processes in
-            # the foreground process group, so the script will receive the signal too.
-            # (If we forward the signal to the child, it will receive it twice.)
-            signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-            orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
-            signal.signal(signal.SIGTERM, raise_termination_signal)
-            try:
-                if capture_output:
-                    args = (proc.stdout, output_hook)
-                    thread = Thread(target=_process_stream, args=args, daemon=True)
-                    thread.start()
-
-                proc.wait()
-            except TerminationSignal as exc:
-                signal.signal(signal.SIGTERM, orig_sigterm_handler)
-                signal.signal(signal.SIGINT, orig_sigint_handler)
-                logger.info("Shutting down process %s, received %r", proc.pid, exc)
-                # Rather than forwarding the signal to the child, we try to shut it down
-                # gracefully. This is because we consider the script to be interactive
-                # and special, so we give it time to cleanup before exiting.
-                shutdown_process(proc, interrupt_timeout, terminate_timeout)
-                if proc.returncode:
-                    raise QueryScriptCancelError(
-                        "Query script was canceled by user", return_code=proc.returncode
-                    ) from exc
-            finally:
-                signal.signal(signal.SIGTERM, orig_sigterm_handler)
-                signal.signal(signal.SIGINT, orig_sigint_handler)
-                if thread:
-                    thread.join() # wait for the reader thread
-
-        logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
-        if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
-            raise QueryScriptCancelError(
-                "Query script was canceled by user",
-                return_code=proc.returncode,
-            )
-        if proc.returncode:
-            raise QueryScriptRunError(
-                f"Query script exited with error code {proc.returncode}",
-                return_code=proc.returncode,
-            )

     def cp(
         self,
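The removed `query()` runner above executed a script in a subprocess, ignored SIGINT in the parent (the child receives Ctrl-C itself), converted SIGTERM into an exception, and gave the child time to shut down gracefully. A distilled, standalone sketch of that pattern, not the datachain API:

```python
import signal
import subprocess
import sys
from typing import Any, NoReturn


class TerminationSignal(RuntimeError):
    def __init__(self, sig: int) -> None:
        super().__init__(f"received signal {sig}")
        self.signal = sig


def run_script(script: str, python: str = sys.executable, timeout: float = 5.0) -> int:
    def raise_termination_signal(sig: int, _: Any) -> NoReturn:
        raise TerminationSignal(sig)

    with subprocess.Popen([python, "-c", script]) as proc:
        orig_sigint = signal.getsignal(signal.SIGINT)
        orig_sigterm = signal.getsignal(signal.SIGTERM)
        # The child is in the foreground process group and receives Ctrl-C
        # itself, so the parent ignores SIGINT instead of forwarding it.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        signal.signal(signal.SIGTERM, raise_termination_signal)
        try:
            proc.wait()
        except TerminationSignal:
            # Ask the child to stop, then kill it if it ignores the request.
            proc.terminate()
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
        finally:
            signal.signal(signal.SIGINT, orig_sigint)
            signal.signal(signal.SIGTERM, orig_sigterm)
    return proc.returncode


if __name__ == "__main__":
    print(run_script("print('hello from the child process')"))
```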
@@ -1658,7 +1762,7 @@ class Catalog:
         no_cp: bool = False,
         no_glob: bool = False,
         *,
-        client_config:
+        client_config: dict | None = None,
     ) -> None:
         """
         This function copies files from cloud sources to local destination directory
@@ -1671,38 +1775,42 @@ class Catalog:
             no_glob,
             client_config=client_config,
         )
+        try:
+            always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
+                node_groups, output, force, no_cp
+            )
+            total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
+            if not total_files:
+                return

-
-
-
-
-
-
-
-        desc_max_len = max(len(output) + 16, 19)
-        bar_format = (
-            "{desc:<"
-            f"{desc_max_len}"
-            "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
-            "[{elapsed}<{remaining}, {rate_fmt:>8}]"
-        )
+            desc_max_len = max(len(output) + 16, 19)
+            bar_format = (
+                "{desc:<"
+                f"{desc_max_len}"
+                "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
+                "[{elapsed}<{remaining}, {rate_fmt:>8}]"
+            )

-
-
-
-
+            if not no_cp:
+                with get_download_bar(bar_format, total_size) as pbar:
+                    for node_group in node_groups:
+                        node_group.download(recursive=recursive, pbar=pbar)

-
-
-
-
-
-
-
-
-
-
-
+            instantiate_node_groups(
+                node_groups,
+                output,
+                bar_format,
+                total_files,
+                force,
+                recursive,
+                no_cp,
+                always_copy_dir_contents,
+                copy_to_filename,
+            )
+        finally:
+            for node_group in node_groups:
+                with suppress(Exception):
+                    node_group.close()

     def du(
         self,
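`cp()` now wraps the copy work in `try`/`finally` and always closes every node group, suppressing errors from `close()` so they cannot mask the original exception. A small sketch of that cleanup pattern with a stand-in `NodeGroup`:

```python
from contextlib import suppress


class NodeGroup:  # stand-in; not datachain's NodeGroup
    def __init__(self, name: str) -> None:
        self.name = name

    def download(self) -> None:
        print(f"downloading {self.name}")

    def close(self) -> None:
        print(f"closing {self.name}")


def cp(node_groups: list[NodeGroup], no_cp: bool = False) -> None:
    try:
        if not no_cp:
            for group in node_groups:
                group.download()
    finally:
        # Always release resources; a failing close() must not hide the
        # exception (if any) raised by the copy itself.
        for group in node_groups:
            with suppress(Exception):
                group.close()


if __name__ == "__main__":
    cp([NodeGroup("images"), NodeGroup("labels")])
```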
@@ -1712,24 +1820,26 @@ class Catalog:
         *,
         client_config=None,
     ) -> Iterable[tuple[str, float]]:
-
+        with self.enlist_sources(
             sources,
             update,
             client_config=client_config or self.client_config,
-        )
+        ) as matched_sources:
+            if matched_sources is None:
+                return

-
-
-
-
-
-
-
-
-
+            def du_dirs(src, node, subdepth):
+                if subdepth > 0:
+                    subdirs = src.listing.get_dirs_by_parent_path(node.path)
+                    for sd in subdirs:
+                        yield from du_dirs(src, sd, subdepth - 1)
+                yield (
+                    src.get_node_full_path(node),
+                    src.listing.du(node)[0],
+                )

-
-
+            for src in matched_sources:
+                yield from du_dirs(src, src.node, depth)

     def find(
         self,
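The new `du()` walks matched sources with a recursive generator that reports children before the directory itself, down to the requested depth. A standalone analogue against the local filesystem (not the datachain listing API):

```python
from __future__ import annotations

import os
from collections.abc import Iterator


def dir_size(path: str) -> int:
    # Total size of regular files under `path` (symlinks skipped for simplicity).
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            file_path = os.path.join(root, name)
            if not os.path.islink(file_path):
                total += os.path.getsize(file_path)
    return total


def du_dirs(path: str, depth: int) -> Iterator[tuple[str, int]]:
    # Same shape as the diff's du_dirs: recurse into children while depth
    # remains, then report the current directory itself.
    if depth > 0:
        for entry in os.scandir(path):
            if entry.is_dir(follow_symlinks=False):
                yield from du_dirs(entry.path, depth - 1)
    yield path, dir_size(path)


if __name__ == "__main__":
    for reported_path, size in du_dirs(".", depth=1):
        print(f"{size:>12}  {reported_path}")
```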
@@ -1745,39 +1855,42 @@ class Catalog:
         *,
         client_config=None,
     ) -> Iterator[str]:
-
+        with self.enlist_sources(
             sources,
             update,
             client_config=client_config or self.client_config,
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    find_column_to_str(row, field_lookup, src, column)
-                    for column in columns
+        ) as matched_sources:
+            if matched_sources is None:
+                return
+
+            if not columns:
+                columns = ["path"]
+            field_set = set()
+            for column in columns:
+                if column == "du":
+                    field_set.add("dir_type")
+                    field_set.add("size")
+                    field_set.add("path")
+                elif column == "name":
+                    field_set.add("path")
+                elif column == "path":
+                    field_set.add("dir_type")
+                    field_set.add("path")
+                elif column == "size":
+                    field_set.add("size")
+                elif column == "type":
+                    field_set.add("dir_type")
+            fields = list(field_set)
+            field_lookup = {f: i for i, f in enumerate(fields)}
+            for src in matched_sources:
+                results = src.listing.find(
+                    src.node, fields, names, inames, paths, ipaths, size, typ
                 )
+                for row in results:
+                    yield "\t".join(
+                        find_column_to_str(row, field_lookup, src, column)
+                        for column in columns
+                    )

     def index(
         self,
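`find()` now maps each requested output column to the raw listing fields it needs and builds a field-to-index lookup before querying. A sketch of that mapping step; the column names mirror the diff, while the helper itself is illustrative:

```python
from __future__ import annotations


def fields_for_columns(columns: list[str]) -> tuple[list[str], dict[str, int]]:
    # Which raw listing fields each output column depends on (from the diff).
    needed = {
        "du": {"dir_type", "size", "path"},
        "name": {"path"},
        "path": {"dir_type", "path"},
        "size": {"size"},
        "type": {"dir_type"},
    }
    field_set: set[str] = set()
    for column in columns or ["path"]:
        field_set |= needed.get(column, set())
    # Sorted here for a stable order; the diff keeps whatever order the set yields.
    fields = sorted(field_set)
    return fields, {f: i for i, f in enumerate(fields)}


if __name__ == "__main__":
    fields, lookup = fields_for_columns(["name", "size", "type"])
    print(fields)   # ['dir_type', 'path', 'size']
    print(lookup)   # {'dir_type': 0, 'path': 1, 'size': 2}
```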
@@ -1786,9 +1899,10 @@ class Catalog:
         *,
         client_config=None,
     ) -> None:
-        self.enlist_sources(
+        with self.enlist_sources(
             sources,
             update,
             client_config=client_config or self.client_config,
             only_index=True,
-        )
+        ):
+            pass