datachain 0.11.11__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- datachain/catalog/catalog.py +39 -7
- datachain/catalog/loader.py +19 -13
- datachain/cli/__init__.py +2 -1
- datachain/cli/commands/ls.py +8 -6
- datachain/cli/commands/show.py +7 -0
- datachain/cli/parser/studio.py +13 -1
- datachain/client/fsspec.py +12 -16
- datachain/client/gcs.py +1 -1
- datachain/client/hf.py +36 -14
- datachain/client/local.py +1 -4
- datachain/client/s3.py +1 -1
- datachain/data_storage/metastore.py +6 -0
- datachain/data_storage/warehouse.py +3 -8
- datachain/dataset.py +8 -0
- datachain/error.py +0 -12
- datachain/fs/utils.py +30 -0
- datachain/func/__init__.py +5 -0
- datachain/func/func.py +2 -1
- datachain/lib/dc.py +59 -15
- datachain/lib/file.py +63 -18
- datachain/lib/image.py +30 -6
- datachain/lib/listing.py +21 -39
- datachain/lib/meta_formats.py +2 -2
- datachain/lib/signal_schema.py +65 -18
- datachain/lib/udf.py +3 -0
- datachain/lib/udf_signature.py +17 -9
- datachain/lib/video.py +7 -5
- datachain/model/bbox.py +209 -58
- datachain/model/pose.py +49 -37
- datachain/model/segment.py +22 -18
- datachain/model/ultralytics/bbox.py +9 -9
- datachain/model/ultralytics/pose.py +7 -7
- datachain/model/ultralytics/segment.py +7 -7
- datachain/model/utils.py +191 -0
- datachain/query/dataset.py +8 -2
- datachain/sql/sqlite/base.py +2 -2
- datachain/studio.py +8 -6
- datachain/utils.py +0 -16
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/METADATA +4 -2
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/RECORD +44 -42
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/WHEEL +1 -1
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/LICENSE +0 -0
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/entry_points.txt +0 -0
- {datachain-0.11.11.dist-info → datachain-0.13.0.dist-info}/top_level.txt +0 -0
datachain/catalog/catalog.py
CHANGED
@@ -25,7 +25,6 @@ from typing import (
 )
 from uuid import uuid4
 
-import requests
 import sqlalchemy as sa
 from sqlalchemy import Column
 from tqdm.auto import tqdm
@@ -54,7 +53,6 @@ from datachain.error import (
 from datachain.lib.listing import get_listing
 from datachain.node import DirType, Node, NodeWithPath
 from datachain.nodes_thread_pool import NodesThreadPool
-from datachain.remote.studio import StudioClient
 from datachain.sql.types import DateTime, SQLType
 from datachain.utils import DataChainDir
 
@@ -162,6 +160,8 @@ class DatasetRowsFetcher(NodesThreadPool):
         max_threads: int = PULL_DATASET_MAX_THREADS,
         progress_bar=None,
     ):
+        from datachain.remote.studio import StudioClient
+
         super().__init__(max_threads)
         self._check_dependencies()
         self.metastore = metastore
@@ -234,6 +234,8 @@ class DatasetRowsFetcher(NodesThreadPool):
         return df.drop("sys__id", axis=1)
 
     def get_parquet_content(self, url: str):
+        import requests
+
         while True:
             if self.should_check_for_status():
                 self.check_for_status()
@@ -775,6 +777,8 @@ class Catalog:
         validate_version: Optional[bool] = True,
         listing: Optional[bool] = False,
         uuid: Optional[str] = None,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
     ) -> "DatasetRecord":
         """
         Creates new dataset of a specific version.
@@ -801,6 +805,8 @@ class Catalog:
             query_script=query_script,
             schema=schema,
             ignore_if_exists=True,
+            description=description,
+            labels=labels,
         )
 
         version = version or default_version
@@ -1130,6 +1136,8 @@ class Catalog:
         raise DatasetNotFoundError(f"Dataset with version uuid {uuid} not found.")
 
     def get_remote_dataset(self, name: str) -> DatasetRecord:
+        from datachain.remote.studio import StudioClient
+
         studio_client = StudioClient()
 
         info_response = studio_client.dataset_info(name)
@@ -1164,8 +1172,27 @@ class Catalog:
 
         return direct_dependencies
 
-    def ls_datasets(
-
+    def ls_datasets(
+        self, include_listing: bool = False, studio: bool = False
+    ) -> Iterator[DatasetListRecord]:
+        from datachain.remote.studio import StudioClient
+
+        if studio:
+            client = StudioClient()
+            response = client.ls_datasets()
+            if not response.ok:
+                raise DataChainError(response.message)
+            if not response.data:
+                return
+
+            datasets: Iterator[DatasetListRecord] = (
+                DatasetListRecord.from_dict(d)
+                for d in response.data
+                if not d.get("name", "").startswith(QUERY_DATASET_PREFIX)
+            )
+        else:
+            datasets = self.metastore.list_datasets()
+
         for d in datasets:
             if not d.is_bucket_listing or include_listing:
                 yield d
@@ -1173,9 +1200,12 @@ class Catalog:
     def list_datasets_versions(
         self,
         include_listing: bool = False,
+        studio: bool = False,
     ) -> Iterator[tuple[DatasetListRecord, "DatasetListVersion", Optional["Job"]]]:
         """Iterate over all dataset versions with related jobs."""
-        datasets = list(
+        datasets = list(
+            self.ls_datasets(include_listing=include_listing, studio=studio)
+        )
 
         # preselect dataset versions jobs from db to avoid multiple queries
         jobs_ids: set[str] = {
@@ -1345,6 +1375,8 @@ class Catalog:
         if cp and not output:
             raise ValueError("Please provide output directory for instantiation")
 
+        from datachain.remote.studio import StudioClient
+
         studio_client = StudioClient()
 
         try:
@@ -1580,7 +1612,7 @@ class Catalog:
         except TerminationSignal as exc:
             signal.signal(signal.SIGTERM, orig_sigterm_handler)
             signal.signal(signal.SIGINT, orig_sigint_handler)
-
+            logger.info("Shutting down process %s, received %r", proc.pid, exc)
             # Rather than forwarding the signal to the child, we try to shut it down
             # gracefully. This is because we consider the script to be interactive
             # and special, so we give it time to cleanup before exiting.
@@ -1595,7 +1627,7 @@ class Catalog:
         if thread:
             thread.join()  # wait for the reader thread
 
-
+        logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
         if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
             raise QueryScriptCancelError(
                 "Query script was canceled by user",
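With the change above, `Catalog.ls_datasets()` can enumerate datasets either from the local metastore or from Studio. A minimal sketch of the new flag in use (assumes a Studio token is already configured; only attributes that appear elsewhere in this diff are referenced):

    from datachain.catalog import get_catalog

    catalog = get_catalog()
    # studio=True goes through StudioClient.ls_datasets(); the default reads the local metastore.
    for ds in catalog.ls_datasets(studio=True):
        print(ds.name, len(ds.versions))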
datachain/catalog/loader.py
CHANGED
@@ -1,19 +1,13 @@
 import os
 from importlib import import_module
-from typing import Any, Optional
-
-from datachain.catalog import Catalog
-from datachain.data_storage import (
-    AbstractMetastore,
-    AbstractWarehouse,
-)
-from datachain.data_storage.serializer import deserialize
-from datachain.data_storage.sqlite import (
-    SQLiteMetastore,
-    SQLiteWarehouse,
-)
+from typing import TYPE_CHECKING, Any, Optional
+
 from datachain.utils import get_envs_by_prefix
 
+if TYPE_CHECKING:
+    from datachain.catalog import Catalog
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+
 METASTORE_SERIALIZED = "DATACHAIN__METASTORE"
 METASTORE_IMPORT_PATH = "DATACHAIN_METASTORE"
 METASTORE_ARG_PREFIX = "DATACHAIN_METASTORE_ARG_"
@@ -27,6 +21,9 @@ IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"
 
 
 def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
+    from datachain.data_storage import AbstractMetastore
+    from datachain.data_storage.serializer import deserialize
+
     metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
     if metastore_serialized:
         metastore_obj = deserialize(metastore_serialized)
@@ -45,6 +42,8 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
     }
 
     if not metastore_import_path:
+        from datachain.data_storage.sqlite import SQLiteMetastore
+
         metastore_args["in_memory"] = in_memory
         return SQLiteMetastore(**metastore_args)
     if in_memory:
@@ -62,6 +61,9 @@ def get_metastore(in_memory: bool = False) -> "AbstractMetastore":
 
 
 def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
+    from datachain.data_storage import AbstractWarehouse
+    from datachain.data_storage.serializer import deserialize
+
     warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
     if warehouse_serialized:
         warehouse_obj = deserialize(warehouse_serialized)
@@ -80,6 +82,8 @@ def get_warehouse(in_memory: bool = False) -> "AbstractWarehouse":
     }
 
     if not warehouse_import_path:
+        from datachain.data_storage.sqlite import SQLiteWarehouse
+
         warehouse_args["in_memory"] = in_memory
         return SQLiteWarehouse(**warehouse_args)
     if in_memory:
@@ -121,7 +125,7 @@ def get_distributed_class(**kwargs):
 
 def get_catalog(
     client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
-) -> Catalog:
+) -> "Catalog":
     """
     Function that creates Catalog instance with appropriate metastore
     and warehouse classes. Metastore class can be provided with env variable
@@ -133,6 +137,8 @@ def get_catalog(
     and name of variable after, e.g. if it accepts team_id as kwargs
     we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
    """
+    from datachain.catalog import Catalog
+
     return Catalog(
         metastore=get_metastore(in_memory=in_memory),
         warehouse=get_warehouse(in_memory=in_memory),
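As the `get_catalog` docstring notes, the metastore and warehouse classes are resolved from environment variables, now with all heavy imports deferred until they are needed. A hedged sketch of that wiring (the import path and `team_id` argument are illustrative, mirroring the docstring's own example):

    import os

    from datachain.catalog import get_catalog

    # Hypothetical custom metastore; the class is imported lazily inside get_metastore().
    os.environ["DATACHAIN_METASTORE"] = "my_pkg.metastore.MyMetastore"
    os.environ["DATACHAIN_METASTORE_ARG_TEAM_ID"] = "12345"  # forwarded as team_id=...

    catalog = get_catalog(in_memory=False)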
datachain/cli/__init__.py
CHANGED
@@ -6,7 +6,6 @@ from multiprocessing import freeze_support
 from typing import Optional
 
 from datachain.cli.utils import get_logging_level
-from datachain.telemetry import telemetry
 
 from .commands import (
     clear_cache,
@@ -70,6 +69,8 @@ def main(argv: Optional[list[str]] = None) -> int:
         error, return_code = handle_general_exception(exc, args, logging_level)
         return return_code
     finally:
+        from datachain.telemetry import telemetry
+
         telemetry.send_cli_call(args.command, error=error)
 
 
datachain/cli/commands/ls.py
CHANGED
@@ -38,11 +38,12 @@ def ls_local(
 ):
     from datachain import DataChain
 
-    if catalog is None:
-        from datachain.catalog import get_catalog
-
-        catalog = get_catalog(client_config=client_config)
     if sources:
+        if catalog is None:
+            from datachain.catalog import get_catalog
+
+            catalog = get_catalog(client_config=client_config)
+
         actual_sources = list(ls_urls(sources, catalog=catalog, long=long, **kwargs))
         if len(actual_sources) == 1:
             for _, entries in actual_sources:
@@ -61,8 +62,9 @@
                 for entry in entries:
                     print(format_ls_entry(entry))
     else:
-
-
+        # Collect results in a list here to prevent interference from `tqdm` and `print`
+        listing = list(DataChain.listings().collect("listing"))
+        for ls in listing:
             print(format_ls_entry(f"{ls.uri}@v{ls.version}"))  # type: ignore[union-attr]
 
 
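The new `else` branch relies on the public listings API; the equivalent standalone snippet, materializing results first so `tqdm` progress output does not interleave with `print`:

    from datachain import DataChain

    for ls in list(DataChain.listings().collect("listing")):
        print(f"{ls.uri}@v{ls.version}")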
datachain/cli/commands/show.py
CHANGED
@@ -40,6 +40,13 @@ def show(
         .offset(offset)
     )
     records = query.to_db_records()
+    print("Name: ", name)
+    if dataset.description:
+        print("Description: ", dataset.description)
+    if dataset.labels:
+        print("Labels: ", ",".join(dataset.labels))
+    print("\n")
+
     show_records(records, collapse_columns=not no_collapse, hidden_fields=hidden_fields)
 
     if schema and dataset_version.feature_schema:
datachain/cli/parser/studio.py
CHANGED
@@ -63,19 +63,31 @@ def add_auth_parser(subparsers, parent_parser) -> None:
         default=False,
         help="Use code-based authentication without browser",
     )
+    login_parser.add_argument(
+        "--local",
+        action="store_true",
+        default=False,
+        help="Save the token in the local project config",
+    )
 
     auth_logout_help = "Log out from Studio"
     auth_logout_description = (
         "Remove the Studio authentication token from global config."
     )
 
-    auth_subparser.add_parser(
+    logout_parser = auth_subparser.add_parser(
        "logout",
         parents=[parent_parser],
         description=auth_logout_description,
         help=auth_logout_help,
         formatter_class=CustomHelpFormatter,
     )
+    logout_parser.add_argument(
+        "--local",
+        action="store_true",
+        default=False,
+        help="Remove the token from the local project config",
+    )
 
     auth_team_help = "Set default team for Studio operations"
     auth_team_description = "Set the default team for Studio operations."
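Assuming the existing `auth` subcommand wiring, the new flag would be invoked roughly as `datachain auth login --local` (save the Studio token in the project-local config) and `datachain auth logout --local` (remove only that local copy); without `--local`, both continue to operate on the global config as before.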
datachain/client/fsspec.py
CHANGED
@@ -17,10 +17,10 @@ from typing import (
     ClassVar,
     NamedTuple,
     Optional,
+    Union,
 )
 from urllib.parse import urlparse
 
-from botocore.exceptions import ClientError
 from dvc_objects.fs.system import reflink
 from fsspec.asyn import get_loop, sync
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -28,7 +28,6 @@ from tqdm.auto import tqdm
 
 from datachain.cache import Cache
 from datachain.client.fileslice import FileWrapper
-from datachain.error import ClientError as DataChainClientError
 from datachain.nodes_fetcher import NodesFetcher
 from datachain.nodes_thread_pool import NodeChunk
 
@@ -83,19 +82,17 @@ class Client(ABC):
         self.uri = self.get_uri(self.name)
 
     @staticmethod
-    def get_implementation(url: str) -> type["Client"]:
+    def get_implementation(url: Union[str, os.PathLike[str]]) -> type["Client"]:
         from .azure import AzureClient
         from .gcs import GCSClient
         from .hf import HfClient
         from .local import FileClient
         from .s3 import ClientS3
 
-        protocol = urlparse(url).scheme
+        protocol = urlparse(str(url)).scheme
 
-        if not protocol or _is_win_local_path(url):
+        if not protocol or _is_win_local_path(str(url)):
             return FileClient
-
-        protocol = protocol.lower()
         if protocol == ClientS3.protocol:
             return ClientS3
         if protocol == GCSClient.protocol:
@@ -121,9 +118,11 @@ class Client(ABC):
         return cls.get_uri(storage_name), rel_path
 
     @staticmethod
-    def get_client(
+    def get_client(
+        source: Union[str, os.PathLike[str]], cache: Cache, **kwargs
+    ) -> "Client":
         cls = Client.get_implementation(source)
-        storage_url, _ = cls.split_url(source)
+        storage_url, _ = cls.split_url(str(source))
         if os.name == "nt":
             storage_url = storage_url.removeprefix("/")
 
@@ -209,7 +208,7 @@ class Client(ABC):
 
     async def get_current_etag(self, file: "File") -> str:
         kwargs = {}
-        if self.fs
+        if getattr(self.fs, "version_aware", False):
             kwargs["version_id"] = file.version
         info = await self.fs._info(
             self.get_full_path(file.path, file.version), **kwargs
@@ -286,11 +285,6 @@ class Client(ABC):
                     worker.cancel()
                 if excs:
                     raise excs[0]
-        except ClientError as exc:
-            raise DataChainClientError(
-                exc.response.get("Error", {}).get("Message") or exc,
-                exc.response.get("Error", {}).get("Code"),
-            ) from exc
         finally:
             # This ensures the progress bar is closed before any exceptions are raised
             progress_bar.close()
@@ -333,7 +327,9 @@ class Client(ABC):
         return not (key.startswith("/") or key.endswith("/") or "//" in key)
 
     async def ls_dir(self, path):
-
+        if getattr(self.fs, "version_aware", False):
+            kwargs = {"versions": True}
+        return await self.fs._ls(path, detail=True, **kwargs)
 
     def rel_path(self, path: str) -> str:
         return self.fs.split_path(path)[1]
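`get_implementation` and `get_client` now accept `os.PathLike` objects as well as strings. A small sketch of what that enables (only client classes named in this diff are referenced):

    from pathlib import Path

    from datachain.client.fsspec import Client

    # Plain local paths and remote URLs both resolve to a concrete client class.
    assert Client.get_implementation(Path("./data")).__name__ == "FileClient"
    assert Client.get_implementation("s3://bucket/prefix").__name__ == "ClientS3"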
datachain/client/gcs.py
CHANGED
@@ -30,7 +30,7 @@ class GCSClient(Client):
         if kwargs.pop("anon", False):
             kwargs["token"] = "anon"  # noqa: S105
 
-        return cast(GCSFileSystem, super().create_fs(**kwargs))
+        return cast("GCSFileSystem", super().create_fs(**kwargs))
 
     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         """
datachain/client/hf.py
CHANGED
@@ -1,25 +1,50 @@
-import
+import functools
 import posixpath
-from typing import Any
-
-from huggingface_hub import HfFileSystem
+from typing import Any
 
 from datachain.lib.file import File
 
 from .fsspec import Client
 
 
+class classproperty:  # noqa: N801
+    def __init__(self, func):
+        self.fget = func
+
+    def __get__(self, instance, owner):
+        return self.fget(owner)
+
+
+@functools.cache
+def get_hf_filesystem_cls():
+    import fsspec
+    from packaging.version import Version, parse
+
+    fsspec_version = parse(fsspec.__version__)
+    minver = Version("2024.12.0")
+
+    if fsspec_version < minver:
+        raise ImportError(
+            f"datachain requires 'fsspec>={minver}' but version "
+            f"{fsspec_version} is installed."
+        )
+
+    from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
+    from huggingface_hub import HfFileSystem
+
+    fs_cls = AsyncFileSystemWrapper.wrap_class(HfFileSystem)
+    # AsyncFileSystemWrapper does not set class properties, so we need to set them back.
+    fs_cls.protocol = HfFileSystem.protocol
+    return fs_cls
+
+
 class HfClient(Client):
-    FS_CLASS = HfFileSystem
     PREFIX = "hf://"
     protocol = "hf"
 
-    @
-    def
-
-        kwargs["token"] = os.environ["HF_TOKEN"]
-
-        return cast(HfFileSystem, super().create_fs(**kwargs))
+    @classproperty
+    def FS_CLASS(cls):  # noqa: N802, N805
+        return get_hf_filesystem_cls()
 
     def info_to_file(self, v: dict[str, Any], path: str) -> File:
         return File(
@@ -31,8 +56,5 @@ class HfClient(Client):
             last_modified=v["last_commit"].date,
         )
 
-    async def ls_dir(self, path):
-        return self.fs.ls(path, detail=True)
-
     def rel_path(self, path):
         return posixpath.relpath(path, self.name)
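The `classproperty` descriptor plus `functools.cache` is what keeps `HfClient.FS_CLASS` cheap to import: `huggingface_hub` is only loaded when the attribute is first read. A self-contained sketch of the same pattern with generic, hypothetical names:

    import functools


    class classproperty:  # descriptor: a property that resolves on the class itself
        def __init__(self, func):
            self.fget = func

        def __get__(self, instance, owner):
            return self.fget(owner)


    @functools.cache  # compute once, reuse on every later access
    def _load_backend():
        return {"loaded": True}  # stand-in for wrapping a filesystem class


    class Lazy:
        @classproperty
        def BACKEND(cls):  # evaluated only when Lazy.BACKEND is first accessed
            return _load_backend()


    print(Lazy.BACKEND)  # {'loaded': True}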
datachain/client/local.py
CHANGED
@@ -67,10 +67,7 @@ class FileClient(Client):
     @classmethod
     def split_url(cls, url: str) -> tuple[str, str]:
         parsed = urlparse(url)
-        if parsed.scheme
-            scheme, rest = url.split(":", 1)
-            url = f"{scheme.lower()}:{rest}"
-        else:
+        if parsed.scheme != "file":
             url = cls.path_to_uri(url)
 
         fill_path = url[len(cls.PREFIX) :]
datachain/client/s3.py
CHANGED
@@ -55,7 +55,7 @@ class ClientS3(Client):
         except NotImplementedError:
             pass
 
-        return cast(S3FileSystem, super().create_fs(**kwargs))
+        return cast("S3FileSystem", super().create_fs(**kwargs))
 
     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
         """
datachain/data_storage/metastore.py
CHANGED
@@ -119,6 +119,8 @@ class AbstractMetastore(ABC, Serializable):
         query_script: str = "",
         schema: Optional[dict[str, Any]] = None,
         ignore_if_exists: bool = False,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
     ) -> DatasetRecord:
         """Creates new dataset."""
 
@@ -518,6 +520,8 @@ class AbstractDBMetastore(AbstractMetastore):
         query_script: str = "",
         schema: Optional[dict[str, Any]] = None,
         ignore_if_exists: bool = False,
+        description: Optional[str] = None,
+        labels: Optional[list[str]] = None,
         **kwargs,  # TODO registered = True / False
     ) -> DatasetRecord:
         """Creates new dataset."""
@@ -533,6 +537,8 @@
             sources="\n".join(sources) if sources else "",
             query_script=query_script,
             schema=json.dumps(schema or {}),
+            description=description,
+            labels=json.dumps(labels or []),
         )
         if ignore_if_exists and hasattr(query, "on_conflict_do_nothing"):
             # SQLite and PostgreSQL both support 'on_conflict_do_nothing',
datachain/data_storage/warehouse.py
CHANGED
@@ -39,13 +39,6 @@ if TYPE_CHECKING:
     from datachain.data_storage.schema import DataTable
     from datachain.lib.file import File
 
-try:
-    import numpy as np
-
-    numpy_imported = True
-except ImportError:
-    numpy_imported = False
-
 
 logger = logging.getLogger("datachain")
 
@@ -96,7 +89,9 @@ class AbstractWarehouse(ABC, Serializable):
         If value is a list or some other iterable, it tries to convert sub elements
         as well
         """
-
+        import numpy as np
+
+        if isinstance(val, (np.ndarray, np.generic)):
             val = val.tolist()
 
         # Optimization: Precompute all the column type variables.
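The warehouse change drops the optional-numpy guard: `numpy` is now imported unconditionally at the call site and values are converted before writing. For reference, the conversion it relies on:

    import numpy as np

    # Both np.ndarray and np.generic expose .tolist(), yielding plain Python values.
    print(np.array([[1, 2], [3, 4]]).tolist())  # [[1, 2], [3, 4]]
    print(np.float32(0.5).tolist())             # 0.5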
datachain/dataset.py
CHANGED
@@ -302,6 +302,7 @@ class DatasetListVersion:
         size: Optional[int],
         query_script: str = "",
         job_id: Optional[str] = None,
+        **kwargs,
     ):
         return cls(
             id,
@@ -648,6 +649,13 @@ class DatasetListRecord:
     def has_version_with_uuid(self, uuid: str) -> bool:
         return any(v.uuid == uuid for v in self.versions)
 
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "DatasetListRecord":
+        versions = [DatasetListVersion.parse(**v) for v in d.get("versions", [])]
+        kwargs = {f.name: d[f.name] for f in fields(cls) if f.name in d}
+        kwargs["versions"] = versions
+        return cls(**kwargs)
+
 
 class RowDict(dict):
     pass
datachain/error.py
CHANGED
@@ -1,15 +1,3 @@
-import botocore.errorfactory
-import botocore.exceptions
-import gcsfs.retry
-
-REMOTE_ERRORS = (
-    gcsfs.retry.HttpError,  # GCS
-    OSError,  # GCS
-    botocore.exceptions.BotoCoreError,  # S3
-    ValueError,  # Azure
-)
-
-
 class DataChainError(RuntimeError):
     pass
 
datachain/fs/utils.py
ADDED
@@ -0,0 +1,30 @@
+from typing import TYPE_CHECKING
+
+from fsspec.implementations.local import LocalFileSystem
+
+if TYPE_CHECKING:
+    from fsspec import AbstractFileSystem
+
+
+def _isdir(fs: "AbstractFileSystem", path: str) -> bool:
+    info = fs.info(path)
+    return info["type"] == "directory" or (
+        info["size"] == 0 and info["type"] == "file" and info["name"].endswith("/")
+    )
+
+
+def isfile(fs: "AbstractFileSystem", path: str) -> bool:
+    """
+    Returns True if uri points to a file.
+
+    Supports special directories on object storages, e.g.:
+    Google creates a zero byte file with the same name as the directory with a trailing
+    slash at the end.
+    """
+    if isinstance(fs, LocalFileSystem):
+        return fs.isfile(path)
+
+    try:
+        return not _isdir(fs, path)
+    except FileNotFoundError:
+        return False
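A short usage sketch for the new helper; an in-memory fsspec filesystem stands in for an object store here, and the paths are illustrative:

    import fsspec

    from datachain.fs.utils import isfile

    fs = fsspec.filesystem("memory")
    fs.makedirs("/data", exist_ok=True)
    fs.pipe_file("/data/file.txt", b"hello")

    print(isfile(fs, "/data/file.txt"))  # True: info()["type"] == "file"
    print(isfile(fs, "/data"))           # False: resolves to a directory
    print(isfile(fs, "/data/missing"))   # False: FileNotFoundError is swallowed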
datachain/func/__init__.py
CHANGED
@@ -18,6 +18,7 @@ from .aggregate import (
 from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
 from .conditional import and_, case, greatest, ifelse, isnone, least, or_
 from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
+from .path import file_ext, file_stem, name, parent
 from .random import rand
 from .string import byte_hamming_distance
 from .window import window
@@ -40,6 +41,8 @@ __all__ = [
     "count",
     "dense_rank",
     "euclidean_distance",
+    "file_ext",
+    "file_stem",
     "first",
     "greatest",
     "ifelse",
@@ -50,7 +53,9 @@ __all__ = [
     "literal",
     "max",
     "min",
+    "name",
     "or_",
+    "parent",
     "path",
     "rand",
     "random",
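The re-exported path helpers operate on string path columns. A hedged sketch of combining them in a chain, assuming string column references are accepted here as elsewhere in `datachain.func`; the storage URI and output column names are illustrative, and `file.path` is the standard `File` signal:

    from datachain import DataChain
    from datachain.func import file_ext, file_stem, name, parent

    chain = DataChain.from_storage("gs://bucket/images/").mutate(
        fname=name("file.path"),        # file name component of the path
        ext=file_ext("file.path"),      # file extension
        stem=file_stem("file.path"),    # file name without the extension
        dirname=parent("file.path"),    # parent directory of the path
    )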
datachain/func/func.py
CHANGED
@@ -3,7 +3,6 @@ from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc
-from sqlalchemy.ext.hybrid import Comparator
 from sqlalchemy.sql import func as sa_func
 
 from datachain.lib.convert.python_to_sql import python_to_sql
@@ -75,6 +74,8 @@ class Func(Function):
 
     @property
     def _db_cols(self) -> Sequence[ColT]:
+        from sqlalchemy.ext.hybrid import Comparator
+
         return (
             [
                 col