datachain 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/__init__.py +0 -2
- datachain/catalog/catalog.py +12 -9
- datachain/cli.py +109 -9
- datachain/client/fsspec.py +9 -9
- datachain/data_storage/metastore.py +63 -11
- datachain/data_storage/schema.py +2 -2
- datachain/data_storage/sqlite.py +5 -4
- datachain/data_storage/warehouse.py +18 -18
- datachain/dataset.py +142 -14
- datachain/func/__init__.py +49 -0
- datachain/{lib/func → func}/aggregate.py +13 -11
- datachain/func/array.py +176 -0
- datachain/func/base.py +23 -0
- datachain/func/conditional.py +81 -0
- datachain/func/func.py +384 -0
- datachain/func/path.py +110 -0
- datachain/func/random.py +23 -0
- datachain/func/string.py +154 -0
- datachain/func/window.py +49 -0
- datachain/lib/arrow.py +24 -12
- datachain/lib/data_model.py +25 -9
- datachain/lib/dataset_info.py +9 -5
- datachain/lib/dc.py +94 -56
- datachain/lib/hf.py +1 -1
- datachain/lib/signal_schema.py +1 -1
- datachain/lib/utils.py +1 -0
- datachain/lib/webdataset_laion.py +5 -5
- datachain/model/bbox.py +2 -2
- datachain/model/pose.py +5 -5
- datachain/model/segment.py +2 -2
- datachain/nodes_fetcher.py +2 -2
- datachain/query/dataset.py +57 -34
- datachain/remote/studio.py +40 -8
- datachain/sql/__init__.py +0 -2
- datachain/sql/functions/__init__.py +0 -26
- datachain/sql/selectable.py +11 -5
- datachain/sql/sqlite/base.py +11 -2
- datachain/studio.py +29 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/METADATA +2 -2
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/RECORD +44 -37
- datachain/lib/func/__init__.py +0 -32
- datachain/lib/func/func.py +0 -152
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/LICENSE +0 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/WHEEL +0 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/entry_points.txt +0 -0
- {datachain-0.7.1.dist-info → datachain-0.7.3.dist-info}/top_level.txt +0 -0
datachain/__init__.py
CHANGED
@@ -1,4 +1,3 @@
-from datachain.lib import func
 from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
@@ -35,7 +34,6 @@ __all__ = [
     "Sys",
     "TarVFile",
     "TextFile",
-    "func",
     "is_chain_type",
     "metrics",
     "param",
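Together with the deletion of datachain/lib/func/ and the new top-level datachain/func/ package in the file list above, this hunk moves the function helpers from datachain.lib.func to datachain.func. A minimal usage sketch of the new import path; the exported names are an assumption based on the new module files (func/string.py, func/random.py):

# Sketch only: assumes datachain.func exposes helpers matching the new
# module files listed in this release.
from datachain.func import random, string  # was: from datachain.lib import func

name_length = string.length("name")  # assumed string helper
shuffle_key = random.rand()          # assumed random helper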
datachain/catalog/catalog.py
CHANGED
@@ -38,6 +38,7 @@ from datachain.dataset import (
     DATASET_PREFIX,
     QUERY_DATASET_PREFIX,
     DatasetDependency,
+    DatasetListRecord,
     DatasetRecord,
     DatasetStats,
     DatasetStatus,
@@ -54,7 +55,6 @@ from datachain.error import (
     QueryScriptCancelError,
     QueryScriptRunError,
 )
-from datachain.listing import Listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
 from datachain.remote.studio import StudioClient
@@ -73,9 +73,10 @@ if TYPE_CHECKING:
         AbstractMetastore,
         AbstractWarehouse,
     )
-    from datachain.dataset import
+    from datachain.dataset import DatasetListVersion
     from datachain.job import Job
     from datachain.lib.file import File
+    from datachain.listing import Listing
 
 logger = logging.getLogger("datachain")
 
@@ -236,7 +237,7 @@ class DatasetRowsFetcher(NodesThreadPool):
 class NodeGroup:
     """Class for a group of nodes from the same source"""
 
-    listing: Listing
+    listing: "Listing"
     sources: list[DataSource]
 
     # The source path within the bucket
@@ -591,8 +592,9 @@ class Catalog:
         client_config=None,
         object_name="file",
         skip_indexing=False,
-    ) -> tuple[Listing, str]:
+    ) -> tuple["Listing", str]:
         from datachain.lib.dc import DataChain
+        from datachain.listing import Listing
 
         DataChain.from_storage(
             source, session=self.session, update=update, object_name=object_name
@@ -660,7 +662,8 @@ class Catalog:
         no_glob: bool = False,
         client_config=None,
     ) -> list[NodeGroup]:
-        from datachain.
+        from datachain.listing import Listing
+        from datachain.query.dataset import DatasetQuery
 
         def _row_to_node(d: dict[str, Any]) -> Node:
             del d["file__source"]
@@ -876,7 +879,7 @@ class Catalog:
     def update_dataset_version_with_warehouse_info(
         self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
     ) -> None:
-        from datachain.query import DatasetQuery
+        from datachain.query.dataset import DatasetQuery
 
         dataset_version = dataset.get_version(version)
 
@@ -1133,7 +1136,7 @@ class Catalog:
 
         return direct_dependencies
 
-    def ls_datasets(self, include_listing: bool = False) -> Iterator[
+    def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetListRecord]:
         datasets = self.metastore.list_datasets()
         for d in datasets:
             if not d.is_bucket_listing or include_listing:
@@ -1142,7 +1145,7 @@ class Catalog:
     def list_datasets_versions(
         self,
         include_listing: bool = False,
-    ) -> Iterator[tuple[
+    ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
         """Iterate over all dataset versions with related jobs."""
         datasets = list(self.ls_datasets(include_listing=include_listing))
 
@@ -1177,7 +1180,7 @@ class Catalog:
     def ls_dataset_rows(
         self, name: str, version: int, offset=None, limit=None
     ) -> list[dict]:
-        from datachain.query import DatasetQuery
+        from datachain.query.dataset import DatasetQuery
 
         dataset = self.get_dataset(name)
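Most hunks in this file apply one refactor: Listing and DatasetQuery are no longer imported at module scope; they move under TYPE_CHECKING or into function bodies, and the annotations that name them are quoted. A minimal sketch of the pattern, with a hypothetical function, assuming the intent is to defer import-cycle-prone modules to call time:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    from datachain.listing import Listing

def enlist_source(source: str) -> "Listing":  # quoted forward reference
    # The runtime import is deferred until the function is called,
    # after any circular module initialization has finished.
    from datachain.listing import Listing

    return Listing.from_source(source)  # hypothetical constructor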
datachain/cli.py
CHANGED
@@ -18,7 +18,12 @@ from datachain.cli_utils import BooleanOptionalAction, CommaSeparatedArgs, KeyVa
 from datachain.config import Config
 from datachain.error import DataChainError
 from datachain.lib.dc import DataChain
-from datachain.studio import
+from datachain.studio import (
+    edit_studio_dataset,
+    list_datasets,
+    process_studio_cli_args,
+    remove_studio_dataset,
+)
 from datachain.telemetry import telemetry
 
 if TYPE_CHECKING:
@@ -403,21 +408,44 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     parse_edit_dataset.add_argument(
         "--new-name",
         action="store",
-        default="",
         help="Dataset new name",
     )
     parse_edit_dataset.add_argument(
         "--description",
         action="store",
-        default="",
         help="Dataset description",
     )
     parse_edit_dataset.add_argument(
         "--labels",
-        default=[],
         nargs="+",
         help="Dataset labels",
     )
+    parse_edit_dataset.add_argument(
+        "--studio",
+        action="store_true",
+        default=False,
+        help="Edit dataset from Studio",
+    )
+    parse_edit_dataset.add_argument(
+        "-L",
+        "--local",
+        action="store_true",
+        default=False,
+        help="Edit local dataset only",
+    )
+    parse_edit_dataset.add_argument(
+        "-a",
+        "--all",
+        action="store_true",
+        default=True,
+        help="Edit both datasets from studio and local",
+    )
+    parse_edit_dataset.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="The team to edit a dataset. By default, it will use team from config.",
+    )
 
     datasets_parser = subp.add_parser(
         "datasets", parents=[parent_parser], description="List datasets"
@@ -466,6 +494,32 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         action=BooleanOptionalAction,
         help="Force delete registered dataset with all of it's versions",
     )
+    rm_dataset_parser.add_argument(
+        "--studio",
+        action="store_true",
+        default=False,
+        help="Remove dataset from Studio",
+    )
+    rm_dataset_parser.add_argument(
+        "-L",
+        "--local",
+        action="store_true",
+        default=False,
+        help="Remove local datasets only",
+    )
+    rm_dataset_parser.add_argument(
+        "-a",
+        "--all",
+        action="store_true",
+        default=True,
+        help="Remove both local and studio",
+    )
+    rm_dataset_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="The team to delete a dataset. By default, it will use team from config.",
+    )
 
     dataset_stats_parser = subp.add_parser(
         "dataset-stats",
@@ -909,8 +963,40 @@ def rm_dataset(
     name: str,
     version: Optional[int] = None,
     force: Optional[bool] = False,
+    studio: bool = False,
+    local: bool = False,
+    all: bool = True,
+    team: Optional[str] = None,
+):
+    token = Config().read().get("studio", {}).get("token")
+    all, local, studio = _determine_flavors(studio, local, all, token)
+
+    if all or local:
+        catalog.remove_dataset(name, version=version, force=force)
+
+    if (all or studio) and token:
+        remove_studio_dataset(team, name, version, force)
+
+
+def edit_dataset(
+    catalog: "Catalog",
+    name: str,
+    new_name: Optional[str] = None,
+    description: Optional[str] = None,
+    labels: Optional[list[str]] = None,
+    studio: bool = False,
+    local: bool = False,
+    all: bool = True,
+    team: Optional[str] = None,
 ):
-
+    token = Config().read().get("studio", {}).get("token")
+    all, local, studio = _determine_flavors(studio, local, all, token)
+
+    if all or local:
+        catalog.edit_dataset(name, new_name, description, labels)
+
+    if (all or studio) and token:
+        edit_studio_dataset(team, name, new_name, description, labels)
 
 
 def dataset_stats(
@@ -957,7 +1043,7 @@ def show(
     schema: bool = False,
 ) -> None:
     from datachain.lib.dc import DataChain
-    from datachain.query import DatasetQuery
+    from datachain.query.dataset import DatasetQuery
     from datachain.utils import show_records
 
     dataset = catalog.get_dataset(name)
@@ -1127,11 +1213,16 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             edatachain_file=args.edatachain_file,
         )
     elif args.command == "edit-dataset":
-
+        edit_dataset(
+            catalog,
             args.name,
-            description=args.description,
             new_name=args.new_name,
+            description=args.description,
             labels=args.labels,
+            studio=args.studio,
+            local=args.local,
+            all=args.all,
+            team=args.team,
         )
     elif args.command == "ls":
         ls(
@@ -1164,7 +1255,16 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             schema=args.schema,
         )
     elif args.command == "rm-dataset":
-        rm_dataset(
+        rm_dataset(
+            catalog,
+            args.name,
+            version=args.version,
+            force=args.force,
+            studio=args.studio,
+            local=args.local,
+            all=args.all,
+            team=args.team,
+        )
     elif args.command == "dataset-stats":
         dataset_stats(
             catalog,
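Both rewritten handlers share one dispatch shape: resolve the --studio/--local/--all flags against the configured Studio token, act locally, then act on Studio. The _determine_flavors helper they call is not part of this diff; a plausible sketch of what the call sites imply, offered as an assumption rather than the shipped implementation:

from typing import Optional

def _determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
    # Assumed behavior, inferred from the call sites above: an explicit
    # --studio or --local turns off the implicit --all default, and Studio
    # is only attempted when a token is configured (the callers re-check
    # `token` before any remote call).
    if studio or local:
        all = False
    studio = studio and bool(token)
    return all, local, studio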
datachain/client/fsspec.py
CHANGED
@@ -28,7 +28,6 @@ from tqdm import tqdm
 from datachain.cache import DataChainCache
 from datachain.client.fileslice import FileWrapper
 from datachain.error import ClientError as DataChainClientError
-from datachain.lib.file import File
 from datachain.nodes_fetcher import NodesFetcher
 from datachain.nodes_thread_pool import NodeChunk
 
@@ -36,6 +35,7 @@ if TYPE_CHECKING:
     from fsspec.spec import AbstractFileSystem
 
     from datachain.dataset import StorageURI
+    from datachain.lib.file import File
 
 
 logger = logging.getLogger("datachain")
@@ -45,7 +45,7 @@ DELIMITER = "/"  # Path delimiter.
 
 DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")
 
-ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
+ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]
 
 
 def _is_win_local_path(uri: str) -> bool:
@@ -212,7 +212,7 @@ class Client(ABC):
 
     async def scandir(
         self, start_prefix: str, method: str = "default"
-    ) -> AsyncIterator[Sequence[File]]:
+    ) -> AsyncIterator[Sequence["File"]]:
         try:
             impl = getattr(self, f"_fetch_{method}")
         except AttributeError:
@@ -317,7 +317,7 @@ class Client(ABC):
         return f"{self.PREFIX}{self.name}/{rel_path}"
 
     @abstractmethod
-    def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
+    def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...
 
     def fetch_nodes(
         self,
@@ -354,7 +354,7 @@ class Client(ABC):
         copy2(src, dst)
 
     def open_object(
-        self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
+        self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
     ) -> BinaryIO:
         """Open a file, including files in tar archives."""
         if use_cache and (cache_path := self.cache.get_path(file)):
@@ -362,19 +362,19 @@ class Client(ABC):
         assert not file.location
         return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb)  # type: ignore[return-value]
 
-    def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+    def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
         sync(get_loop(), functools.partial(self._download, file, callback=callback))
 
-    async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+    async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
         if self.cache.contains(file):
             # Already in cache, so there's nothing to do.
             return
         await self._put_in_cache(file, callback=callback)
 
-    def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+    def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
         sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))
 
-    async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+    async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
         assert not file.location
         if file.etag:
             etag = await self.get_current_etag(file)
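Every change in this file serves one refactor: File becomes a type-checking-only import, so each annotation that names it is quoted, including inside the ResultQueue alias. A short self-contained sketch of why the quoting works at runtime:

import asyncio
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from datachain.lib.file import File  # not imported at runtime

# "File" is stored as an unevaluated forward reference, so this alias can be
# built even though File is not a runtime name in the module (Python >= 3.9,
# where asyncio.Queue supports subscription).
ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]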
datachain/data_storage/metastore.py
CHANGED
@@ -27,6 +27,8 @@ from datachain.data_storage import JobQueryType, JobStatus
 from datachain.data_storage.serializer import Serializable
 from datachain.dataset import (
     DatasetDependency,
+    DatasetListRecord,
+    DatasetListVersion,
     DatasetRecord,
     DatasetStatus,
     DatasetVersion,
@@ -59,6 +61,8 @@ class AbstractMetastore(ABC, Serializable):
 
     schema: "schema.Schema"
     dataset_class: type[DatasetRecord] = DatasetRecord
+    dataset_list_class: type[DatasetListRecord] = DatasetListRecord
+    dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
     dependency_class: type[DatasetDependency] = DatasetDependency
     job_class: type[Job] = Job
 
@@ -166,11 +170,11 @@ class AbstractMetastore(ABC, Serializable):
     """
 
     @abstractmethod
-    def list_datasets(self) -> Iterator[
+    def list_datasets(self) -> Iterator[DatasetListRecord]:
         """Lists all datasets."""
 
     @abstractmethod
-    def list_datasets_by_prefix(self, prefix: str) -> Iterator["
+    def list_datasets_by_prefix(self, prefix: str) -> Iterator["DatasetListRecord"]:
         """Lists all datasets which names start with prefix."""
 
     @abstractmethod
@@ -348,6 +352,14 @@ class AbstractDBMetastore(AbstractMetastore):
             if c.name  # type: ignore [attr-defined]
         ]
 
+    @cached_property
+    def _dataset_list_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore [attr-defined]
+            for c in self._datasets_columns()
+            if c.name in self.dataset_list_class.__dataclass_fields__  # type: ignore [attr-defined]
+        ]
+
     @classmethod
     def _datasets_versions_columns(cls) -> list["SchemaItem"]:
         """Datasets versions table columns."""
@@ -390,6 +402,15 @@ class AbstractDBMetastore(AbstractMetastore):
             if c.name  # type: ignore [attr-defined]
         ]
 
+    @cached_property
+    def _dataset_list_version_fields(self) -> list[str]:
+        return [
+            c.name  # type: ignore [attr-defined]
+            for c in self._datasets_versions_columns()
+            if c.name  # type: ignore [attr-defined]
+            in self.dataset_list_version_class.__dataclass_fields__
+        ]
+
     @classmethod
     def _datasets_dependencies_columns(cls) -> list["SchemaItem"]:
         """Datasets dependencies table columns."""
@@ -671,7 +692,25 @@ class AbstractDBMetastore(AbstractMetastore):
             if dataset:
                 yield dataset
 
-    def
+    def _parse_list_dataset(self, rows) -> Optional[DatasetListRecord]:
+        versions = [self.dataset_list_class.parse(*r) for r in rows]
+        if not versions:
+            return None
+        return reduce(lambda ds, version: ds.merge_versions(version), versions)
+
+    def _parse_dataset_list(self, rows) -> Iterator["DatasetListRecord"]:
+        # grouping rows by dataset id
+        for _, g in groupby(rows, lambda r: r[0]):
+            dataset = self._parse_list_dataset(list(g))
+            if dataset:
+                yield dataset
+
+    def _get_dataset_query(
+        self,
+        dataset_fields: list[str],
+        dataset_version_fields: list[str],
+        isouter: bool = True,
+    ):
         if not (
             self.db.has_table(self._datasets.name)
             and self.db.has_table(self._datasets_versions.name)
@@ -680,23 +719,36 @@ class AbstractDBMetastore(AbstractMetastore):
 
         d = self._datasets
         dv = self._datasets_versions
+
         query = self._datasets_select(
-            *(getattr(d.c, f) for f in
-            *(getattr(dv.c, f) for f in
+            *(getattr(d.c, f) for f in dataset_fields),
+            *(getattr(dv.c, f) for f in dataset_version_fields),
         )
-        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=
+        j = d.join(dv, d.c.id == dv.c.dataset_id, isouter=isouter)
         return query.select_from(j)
 
-    def
+    def _base_dataset_query(self):
+        return self._get_dataset_query(
+            self._dataset_fields, self._dataset_version_fields
+        )
+
+    def _base_list_datasets_query(self):
+        return self._get_dataset_query(
+            self._dataset_list_fields, self._dataset_list_version_fields, isouter=False
+        )
+
+    def list_datasets(self) -> Iterator["DatasetListRecord"]:
         """Lists all datasets."""
-        yield from self.
+        yield from self._parse_dataset_list(
+            self.db.execute(self._base_list_datasets_query())
+        )
 
     def list_datasets_by_prefix(
         self, prefix: str, conn=None
-    ) -> Iterator["
-        query = self.
+    ) -> Iterator["DatasetListRecord"]:
+        query = self._base_list_datasets_query()
         query = query.where(self._datasets.c.name.startswith(prefix))
-        yield from self.
+        yield from self._parse_dataset_list(self.db.execute(query))
 
     def get_dataset(self, name: str, conn=None) -> DatasetRecord:
         """Gets a single dataset by name"""
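The new _parse_dataset_list depends on rows arriving ordered by dataset id: itertools.groupby splits consecutive rows per dataset, and functools.reduce folds each group's versions into one record via merge_versions. A self-contained sketch of that fold with a simplified stand-in type:

from dataclasses import dataclass, field
from functools import reduce
from itertools import groupby

@dataclass
class DatasetStub:  # simplified stand-in for DatasetListRecord
    id: int
    versions: list[int] = field(default_factory=list)

    def merge_versions(self, other: "DatasetStub") -> "DatasetStub":
        self.versions.extend(other.versions)
        return self

# (dataset_id, version) rows, already ordered by dataset_id, as the inner
# join built by _base_list_datasets_query would return them.
rows = [(1, 1), (1, 2), (2, 1)]
for ds_id, group in groupby(rows, key=lambda r: r[0]):
    per_version = [DatasetStub(r[0], [r[1]]) for r in group]
    merged = reduce(lambda ds, v: ds.merge_versions(v), per_version)
    print(ds_id, merged.versions)  # -> 1 [1, 2], then 2 [1]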
datachain/data_storage/schema.py
CHANGED
@@ -12,7 +12,7 @@ import sqlalchemy as sa
 from sqlalchemy.sql import func as f
 from sqlalchemy.sql.expression import false, null, true
 
-from datachain.sql.functions import path
+from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import Int, SQLType, UInt64
 
 if TYPE_CHECKING:
@@ -130,7 +130,7 @@ class DirExpansion:
 
     def query(self, q):
        q = self.base_select(q).cte(recursive=True)
-        parent = path.parent(self.c(q, "path"))
+        parent = pathfunc.parent(self.c(q, "path"))
        q = q.union_all(
            sa.select(
                sa.literal(-1).label("sys__id"),
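The only substantive change here is importing the module under an alias so the bare name path stays available. A small sketch of the shadowing this avoids; the local parameter is hypothetical:

from datachain.sql.functions import path as pathfunc

def parent_expr(column, path: str = ""):  # hypothetical local named `path`
    # With the alias, the parameter no longer shadows the SQL helper module.
    return pathfunc.parent(column)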
datachain/data_storage/sqlite.py
CHANGED
@@ -122,7 +122,9 @@ class SQLiteDatabaseEngine(DatabaseEngine):
         return cls(*cls._connect(db_file=db_file))
 
     @staticmethod
-    def _connect(
+    def _connect(
+        db_file: Optional[str] = None,
+    ) -> tuple["Engine", "MetaData", sqlite3.Connection, str]:
         try:
             if db_file == ":memory:":
                 # Enable multithreaded usage of the same in-memory db
@@ -130,9 +132,8 @@ class SQLiteDatabaseEngine(DatabaseEngine):
                     _get_in_memory_uri(), uri=True, detect_types=DETECT_TYPES
                 )
             else:
-
-
-                )
+                db_file = db_file or DataChainDir.find().db
+                db = sqlite3.connect(db_file, detect_types=DETECT_TYPES)
             create_user_defined_sql_functions(db)
             engine = sqlalchemy.create_engine(
                 "sqlite+pysqlite:///", creator=lambda: db, future=True
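_connect gains an explicit optional db_file parameter and a spelled-out return type: engine, metadata, the raw sqlite3 connection, and (presumably) the resolved database path. A hedged usage sketch mirroring the from_db_file call visible at the top of the hunk:

import sqlite3

from datachain.data_storage.sqlite import SQLiteDatabaseEngine

# Sketch only: exercises the private helper directly for illustration.
# ":memory:" takes the in-memory branch; omitting db_file is assumed to
# fall back to DataChainDir.find().db as shown in the hunk.
engine, metadata, db, db_path = SQLiteDatabaseEngine._connect(db_file=":memory:")
assert isinstance(db, sqlite3.Connection)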
@@ -224,28 +224,28 @@ class AbstractWarehouse(ABC, Serializable):
|
|
|
224
224
|
offset = 0
|
|
225
225
|
num_yielded = 0
|
|
226
226
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
if limit
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
227
|
+
# Ensure we're using a thread-local connection
|
|
228
|
+
with self.clone() as wh:
|
|
229
|
+
while True:
|
|
230
|
+
if limit is not None:
|
|
231
|
+
limit -= num_yielded
|
|
232
|
+
if limit == 0:
|
|
233
|
+
break
|
|
234
|
+
if limit < page_size:
|
|
235
|
+
paginated_query = paginated_query.limit(None).limit(limit)
|
|
236
|
+
|
|
237
237
|
# Cursor results are not thread-safe, so we convert them to a list
|
|
238
238
|
results = list(wh.dataset_rows_select(paginated_query.offset(offset)))
|
|
239
239
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
240
|
+
processed = False
|
|
241
|
+
for row in results:
|
|
242
|
+
processed = True
|
|
243
|
+
yield row
|
|
244
|
+
num_yielded += 1
|
|
245
245
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
246
|
+
if not processed:
|
|
247
|
+
break # no more results
|
|
248
|
+
offset += page_size
|
|
249
249
|
|
|
250
250
|
#
|
|
251
251
|
# Table Name Internal Functions
|