datachain 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic.
- datachain/cache.py +4 -2
- datachain/catalog/catalog.py +100 -54
- datachain/catalog/datasource.py +4 -6
- datachain/cli/__init__.py +311 -0
- datachain/cli/commands/__init__.py +29 -0
- datachain/cli/commands/datasets.py +129 -0
- datachain/cli/commands/du.py +14 -0
- datachain/cli/commands/index.py +12 -0
- datachain/cli/commands/ls.py +169 -0
- datachain/cli/commands/misc.py +28 -0
- datachain/cli/commands/query.py +53 -0
- datachain/cli/commands/show.py +38 -0
- datachain/cli/parser/__init__.py +547 -0
- datachain/cli/parser/job.py +120 -0
- datachain/cli/parser/studio.py +126 -0
- datachain/cli/parser/utils.py +63 -0
- datachain/{cli_utils.py → cli/utils.py} +27 -1
- datachain/client/azure.py +21 -1
- datachain/client/fsspec.py +45 -13
- datachain/client/gcs.py +10 -2
- datachain/client/local.py +4 -4
- datachain/client/s3.py +10 -0
- datachain/dataset.py +1 -0
- datachain/func/__init__.py +2 -2
- datachain/func/conditional.py +52 -0
- datachain/func/func.py +5 -1
- datachain/lib/arrow.py +4 -0
- datachain/lib/dc.py +18 -3
- datachain/lib/file.py +1 -1
- datachain/lib/listing.py +36 -3
- datachain/lib/signal_schema.py +89 -27
- datachain/listing.py +1 -5
- datachain/node.py +27 -1
- datachain/progress.py +2 -2
- datachain/query/session.py +1 -1
- datachain/studio.py +58 -38
- datachain/utils.py +1 -1
- {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/METADATA +6 -6
- {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/RECORD +43 -31
- {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/WHEEL +1 -1
- datachain/cli.py +0 -1475
- {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/LICENSE +0 -0
- {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.2.dist-info → datachain-0.8.4.dist-info}/top_level.txt +0 -0
datachain/cli/parser/studio.py
ADDED
@@ -0,0 +1,126 @@
+def add_studio_parser(subparsers, parent_parser) -> None:
+    studio_help = "Commands to authenticate DataChain with Iterative Studio"
+    studio_description = (
+        "Authenticate DataChain with Studio and set the token. "
+        "Once this token has been properly configured,\n"
+        "DataChain will utilize it for seamlessly sharing datasets\n"
+        "and using Studio features from CLI"
+    )
+
+    studio_parser = subparsers.add_parser(
+        "studio",
+        parents=[parent_parser],
+        description=studio_description,
+        help=studio_help,
+    )
+    studio_subparser = studio_parser.add_subparsers(
+        dest="cmd",
+        help="Use `DataChain studio CMD --help` to display command-specific help.",
+        required=True,
+    )
+
+    studio_login_help = "Authenticate DataChain with Studio host"
+    studio_login_description = (
+        "By default, this command authenticates the DataChain with Studio\n"
+        "using default scopes and assigns a random name as the token name."
+    )
+    login_parser = studio_subparser.add_parser(
+        "login",
+        parents=[parent_parser],
+        description=studio_login_description,
+        help=studio_login_help,
+    )
+
+    login_parser.add_argument(
+        "-H",
+        "--hostname",
+        action="store",
+        default=None,
+        help="The hostname of the Studio instance to authenticate with.",
+    )
+    login_parser.add_argument(
+        "-s",
+        "--scopes",
+        action="store",
+        default=None,
+        help="The scopes for the authentication token. ",
+    )
+
+    login_parser.add_argument(
+        "-n",
+        "--name",
+        action="store",
+        default=None,
+        help="The name of the authentication token. It will be used to\n"
+        "identify token shown in Studio profile.",
+    )
+
+    login_parser.add_argument(
+        "--no-open",
+        action="store_true",
+        default=False,
+        help="Use authentication flow based on user code.\n"
+        "You will be presented with user code to enter in browser.\n"
+        "DataChain will also use this if it cannot launch browser on your behalf.",
+    )
+
+    studio_logout_help = "Logout user from Studio"
+    studio_logout_description = "This removes the studio token from your global config."
+
+    studio_subparser.add_parser(
+        "logout",
+        parents=[parent_parser],
+        description=studio_logout_description,
+        help=studio_logout_help,
+    )
+
+    studio_team_help = "Set the default team for DataChain"
+    studio_team_description = (
+        "Set the default team for DataChain to use when interacting with Studio."
+    )
+
+    team_parser = studio_subparser.add_parser(
+        "team",
+        parents=[parent_parser],
+        description=studio_team_description,
+        help=studio_team_help,
+    )
+    team_parser.add_argument(
+        "team_name",
+        action="store",
+        help="The name of the team to set as the default.",
+    )
+    team_parser.add_argument(
+        "--global",
+        action="store_true",
+        default=False,
+        help="Set the team globally for all DataChain projects.",
+    )
+
+    studio_token_help = "View the token datachain uses to contact Studio"  # noqa: S105 # nosec B105
+
+    studio_subparser.add_parser(
+        "token",
+        parents=[parent_parser],
+        description=studio_token_help,
+        help=studio_token_help,
+    )
+
+    studio_ls_dataset_help = "List the available datasets from Studio"
+    studio_ls_dataset_description = (
+        "This command lists all the datasets available in Studio.\n"
+        "It will show the dataset name and the number of versions available."
+    )
+
+    ls_dataset_parser = studio_subparser.add_parser(
+        "dataset",
+        parents=[parent_parser],
+        description=studio_ls_dataset_description,
+        help=studio_ls_dataset_help,
+    )
+    ls_dataset_parser.add_argument(
+        "--team",
+        action="store",
+        default=None,
+        help="The team to list datasets for. By default, it will use team from config.",
+    )
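For orientation, a minimal sketch of how this new subparser might be wired up and exercised. The `add_studio_parser` name and module path come from the hunk and the file list above; the `-v`/`-q` flags on the parent parser are assumptions, not part of this diff.

import argparse

from datachain.cli.parser.studio import add_studio_parser

# Stand-in for datachain's shared parent parser (assumed flags).
parent = argparse.ArgumentParser(add_help=False)
parent.add_argument("-v", "--verbose", action="store_true", default=False)
parent.add_argument("-q", "--quiet", action="store_true", default=False)

root = argparse.ArgumentParser(prog="datachain")
subparsers = root.add_subparsers(dest="command")
add_studio_parser(subparsers, parent)

# e.g. `datachain studio login --name ci-token --no-open`
args = root.parse_args(["studio", "login", "--name", "ci-token", "--no-open"])
print(args.cmd, args.name, args.no_open)  # login ci-token True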
datachain/cli/parser/utils.py
ADDED
@@ -0,0 +1,63 @@
+from argparse import Action, ArgumentParser, ArgumentTypeError
+from typing import Union
+
+from datachain.cli.utils import CommaSeparatedArgs
+
+FIND_COLUMNS = ["du", "name", "path", "size", "type"]
+
+
+def find_columns_type(
+    columns_str: str,
+    default_colums_str: str = "path",
+) -> list[str]:
+    if not columns_str:
+        columns_str = default_colums_str
+
+    return [parse_find_column(c) for c in columns_str.split(",")]
+
+
+def parse_find_column(column: str) -> str:
+    column_lower = column.strip().lower()
+    if column_lower in FIND_COLUMNS:
+        return column_lower
+    raise ArgumentTypeError(
+        f"Invalid column for find: '{column}' Options are: {','.join(FIND_COLUMNS)}"
+    )
+
+
+def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Action:
+    return parser.add_argument(
+        "sources",
+        type=str,
+        nargs=nargs,
+        help="Data sources - paths to cloud storage dirs",
+    )
+
+
+def add_show_args(parser: ArgumentParser) -> None:
+    parser.add_argument(
+        "--limit",
+        action="store",
+        default=10,
+        type=int,
+        help="Number of rows to show",
+    )
+    parser.add_argument(
+        "--offset",
+        action="store",
+        default=0,
+        type=int,
+        help="Number of rows to offset",
+    )
+    parser.add_argument(
+        "--columns",
+        default=[],
+        action=CommaSeparatedArgs,
+        help="Columns to show",
+    )
+    parser.add_argument(
+        "--no-collapse",
+        action="store_true",
+        default=False,
+        help="Do not collapse the columns",
+    )
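A quick behavior sketch for the column helpers above (import path taken from the file list; requires datachain installed):

from datachain.cli.parser.utils import find_columns_type

print(find_columns_type("Name, SIZE"))  # ['name', 'size'] - normalized to lowercase
print(find_columns_type(""))            # ['path'] - falls back to the default
# find_columns_type("bogus") raises ArgumentTypeError listing du,name,path,size,type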
datachain/{cli_utils.py → cli/utils.py}
RENAMED
@@ -1,4 +1,8 @@
-
+import logging
+from argparse import SUPPRESS, Action, ArgumentError, Namespace, _AppendAction
+from typing import Optional
+
+from datachain.error import DataChainError
 
 
 class BooleanOptionalAction(Action):
@@ -70,3 +74,25 @@ class KeyValueArgs(_AppendAction):  # pylint: disable=protected-access
         items[key.strip()] = value
 
         setattr(namespace, self.dest, items)
+
+
+def get_logging_level(args: Namespace) -> int:
+    if args.quiet:
+        return logging.CRITICAL
+    if args.verbose:
+        return logging.DEBUG
+    return logging.INFO
+
+
+def determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
+    if studio and not token:
+        raise DataChainError(
+            "Not logged in to Studio. Log in with 'datachain studio login'."
+        )
+
+    if local or studio:
+        all = False
+
+    all = all and not (local or studio)
+
+    return all, local, studio
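The precedence encoded in `determine_flavors` is easy to misread, so here is a small illustration of its behavior (import path assumed from the rename above): an explicit `local` or `studio` flag wins over `all`, and `studio` demands a token.

from datachain.cli.utils import determine_flavors

print(determine_flavors(studio=False, local=False, all=True, token=None))  # (True, False, False)
print(determine_flavors(studio=False, local=True, all=True, token=None))   # (False, True, False)
print(determine_flavors(studio=True, local=False, all=True, token="tok"))  # (False, False, True)
# determine_flavors(studio=True, local=False, all=False, token=None) raises DataChainError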
datachain/client/azure.py
CHANGED
@@ -1,4 +1,5 @@
-from typing import Any
+from typing import Any, Optional
+from urllib.parse import parse_qs, urlsplit, urlunsplit
 
 from adlfs import AzureBlobFileSystem
 from tqdm import tqdm
@@ -25,6 +26,16 @@ class AzureClient(Client):
             size=v.get("size", ""),
         )
 
+    def url(self, path: str, expires: int = 3600, **kwargs) -> str:
+        """
+        Generate a signed URL for the given path.
+        """
+        version_id = kwargs.pop("version_id", None)
+        result = self.fs.sign(
+            self.get_full_path(path, version_id), expiration=expires, **kwargs
+        )
+        return result + (f"&versionid={version_id}" if version_id else "")
+
     async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
         prefix = start_prefix
         if prefix:
@@ -57,4 +68,13 @@ class AzureClient(Client):
         finally:
             result_queue.put_nowait(None)
 
+    @classmethod
+    def version_path(cls, path: str, version_id: Optional[str]) -> str:
+        parts = list(urlsplit(path))
+        query = parse_qs(parts[3])
+        if "versionid" in query:
+            raise ValueError("path already includes a version query")
+        parts[3] = f"versionid={version_id}" if version_id else ""
+        return urlunsplit(parts)
+
     _fetch_default = _fetch_flat
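Per the hunk above, Azure encodes the object version as a `versionid` query parameter. A sketch of the expected output (importing the module requires `adlfs`):

from datachain.client.azure import AzureClient

print(AzureClient.version_path("az://container/data.csv", "abc123"))
# az://container/data.csv?versionid=abc123
print(AzureClient.version_path("az://container/data.csv", None))
# az://container/data.csv  (unchanged when no version is given)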
datachain/client/fsspec.py
CHANGED
@@ -137,6 +137,10 @@ class Client(ABC):
         fs.invalidate_cache()
         return fs
 
+    @classmethod
+    def version_path(cls, path: str, version_id: Optional[str]) -> str:
+        return path
+
     @classmethod
     def from_name(
         cls,
@@ -198,17 +202,37 @@ class Client(ABC):
         return self._fs
 
     def url(self, path: str, expires: int = 3600, **kwargs) -> str:
-        return self.fs.sign(
+        return self.fs.sign(
+            self.get_full_path(path, kwargs.pop("version_id", None)),
+            expiration=expires,
+            **kwargs,
+        )
 
     async def get_current_etag(self, file: "File") -> str:
-
-
-
-
-
-
-
-
+        kwargs = {}
+        if self.fs.version_aware:
+            kwargs["version_id"] = file.version
+        info = await self.fs._info(
+            self.get_full_path(file.path, file.version), **kwargs
+        )
+        return self.info_to_file(info, file.path).etag
+
+    def get_file_info(self, path: str, version_id: Optional[str] = None) -> "File":
+        info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
+        return self.info_to_file(info, path)
+
+    async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
+        return await self.fs._size(
+            self.version_path(path, version_id), version_id=version_id
+        )
+
+    async def get_file(self, lpath, rpath, callback, version_id: Optional[str] = None):
+        return await self.fs._get_file(
+            self.version_path(lpath, version_id),
+            rpath,
+            callback=callback,
+            version_id=version_id,
+        )
 
     async def scandir(
         self, start_prefix: str, method: str = "default"
@@ -315,11 +339,11 @@ class Client(ABC):
     def rel_path(self, path: str) -> str:
         return self.fs.split_path(path)[1]
 
-    def get_full_path(self, rel_path: str) -> str:
-        return f"{self.PREFIX}{self.name}/{rel_path}"
+    def get_full_path(self, rel_path: str, version_id: Optional[str] = None) -> str:
+        return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)
 
     @abstractmethod
-    def info_to_file(self, v: dict[str, Any],
+    def info_to_file(self, v: dict[str, Any], path: str) -> "File": ...
 
     def fetch_nodes(
         self,
@@ -362,7 +386,15 @@ class Client(ABC):
         if use_cache and (cache_path := self.cache.get_path(file)):
             return open(cache_path, mode="rb")
         assert not file.location
-        return FileWrapper(
+        return FileWrapper(
+            self.fs.open(self.get_full_path(file.path, file.version)), cb
+        )  # type: ignore[return-value]
+
+    def upload(self, path: str, data: bytes) -> "File":
+        full_path = self.get_full_path(path)
+        self.fs.pipe_file(full_path, data)
+        file_info = self.fs.info(full_path)
+        return self.info_to_file(file_info, path)
 
     def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
         sync(get_loop(), functools.partial(self._download, file, callback=callback))
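A standalone sketch of the base-class pattern introduced above: `get_full_path` now threads an optional `version_id` through `version_path`, which is a no-op in the base `Client` and overridden per storage backend. Class names below are illustrative, not datachain's:

from typing import Optional

class BaseClient:
    PREFIX = "s3://"

    def __init__(self, name: str) -> None:
        self.name = name  # bucket/container name

    @classmethod
    def version_path(cls, path: str, version_id: Optional[str]) -> str:
        return path  # base class ignores versions

    def get_full_path(self, rel_path: str, version_id: Optional[str] = None) -> str:
        return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)

class VersionedClient(BaseClient):
    @classmethod
    def version_path(cls, path: str, version_id: Optional[str]) -> str:
        return f"{path}?versionId={version_id}" if version_id else path

print(BaseClient("bucket").get_full_path("a/b.csv", "v1"))       # s3://bucket/a/b.csv
print(VersionedClient("bucket").get_full_path("a/b.csv", "v1"))  # s3://bucket/a/b.csv?versionId=v1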
datachain/client/gcs.py
CHANGED
@@ -38,9 +38,13 @@ class GCSClient(Client):
         If the client is anonymous, a public URL is returned instead
         (see https://cloud.google.com/storage/docs/access-public-data#api-link).
         """
+        version_id = kwargs.pop("version_id", None)
         if self.fs.storage_options.get("token") == "anon":
-
-
+            query = f"?generation={version_id}" if version_id else ""
+            return f"https://storage.googleapis.com/{self.name}/{path}{query}"
+        return self.fs.sign(
+            self.get_full_path(path, version_id), expiration=expires, **kwargs
+        )
 
     @staticmethod
     def parse_timestamp(timestamp: str) -> datetime:
@@ -131,3 +135,7 @@ class GCSClient(Client):
             last_modified=self.parse_timestamp(v["updated"]),
             size=v.get("size", ""),
         )
+
+    @classmethod
+    def version_path(cls, path: str, version_id: Optional[str]) -> str:
+        return f"{path}#{version_id}" if version_id else path
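Note that GCS diverges from the query-string approach used by Azure and S3: the generation is appended as a URL fragment. A sketch of the output (importing the module requires `gcsfs`):

from datachain.client.gcs import GCSClient

print(GCSClient.version_path("gs://bucket/file.csv", "1712345678901234"))
# gs://bucket/file.csv#1712345678901234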
datachain/client/local.py
CHANGED
@@ -2,7 +2,7 @@ import os
 import posixpath
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 from urllib.parse import urlparse
 
 from fsspec.implementations.local import LocalFileSystem
@@ -105,10 +105,10 @@ class FileClient(Client):
         info = self.fs.info(self.get_full_path(file.path))
         return self.info_to_file(info, "").etag
 
-    async def get_size(self, path: str) -> int:
+    async def get_size(self, path: str, version_id: Optional[str] = None) -> int:
         return self.fs.size(path)
 
-    async def get_file(self, lpath, rpath, callback):
+    async def get_file(self, lpath, rpath, callback, version_id: Optional[str] = None):
         return self.fs.get_file(lpath, rpath, callback=callback)
 
     async def ls_dir(self, path):
@@ -117,7 +117,7 @@ class FileClient(Client):
     def rel_path(self, path):
         return posixpath.relpath(path, self.name)
 
-    def get_full_path(self, rel_path):
+    def get_full_path(self, rel_path, version_id: Optional[str] = None):
         full_path = Path(self.name, rel_path).as_posix()
         if rel_path.endswith("/") or not rel_path:
             full_path += "/"
datachain/client/s3.py
CHANGED
@@ -1,5 +1,6 @@
 import asyncio
 from typing import Any, Optional, cast
+from urllib.parse import parse_qs, urlsplit, urlunsplit
 
 from botocore.exceptions import NoCredentialsError
 from s3fs import S3FileSystem
@@ -121,6 +122,15 @@ class ClientS3(Client):
             size=v["Size"],
         )
 
+    @classmethod
+    def version_path(cls, path: str, version_id: Optional[str]) -> str:
+        parts = list(urlsplit(path))
+        query = parse_qs(parts[3])
+        if "versionId" in query:
+            raise ValueError("path already includes a version query")
+        parts[3] = f"versionId={version_id}" if version_id else ""
+        return urlunsplit(parts)
+
     async def _fetch_dir(
         self,
         prefix,
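S3 follows the same query-string scheme as Azure, but with the capitalized `versionId` key that S3 itself uses. A sketch (importing the module requires `s3fs`; the version string is made up):

from datachain.client.s3 import ClientS3

print(ClientS3.version_path("s3://bucket/file.csv", "3HL4kqtJlcpXrof3"))
# s3://bucket/file.csv?versionId=3HL4kqtJlcpXrof3
# Passing a path that already carries a versionId query raises ValueError.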
datachain/dataset.py
CHANGED
datachain/func/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from sqlalchemy import
+from sqlalchemy import literal
 
 from . import array, path, random, string
 from .aggregate import (
@@ -16,7 +16,7 @@ from .aggregate import (
     sum,
 )
 from .array import cosine_distance, euclidean_distance, length, sip_hash_64
-from .conditional import greatest, least
+from .conditional import case, greatest, least
 from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
 from .random import rand
 from .string import byte_hamming_distance
datachain/func/conditional.py
CHANGED
@@ -1,5 +1,9 @@
 from typing import Union
 
+from sqlalchemy import case as sql_case
+from sqlalchemy.sql.elements import BinaryExpression
+
+from datachain.lib.utils import DataChainParamsError
 from datachain.sql.functions import conditional
 
 from .func import ColT, Func
@@ -79,3 +83,51 @@ def least(*args: Union[ColT, float]) -> Func:
     return Func(
         "least", inner=conditional.least, cols=cols, args=func_args, result_type=int
     )
+
+
+def case(
+    *args: tuple[BinaryExpression, Union[int, float, complex, bool, str]], else_=None
+) -> Func:
+    """
+    Returns the case function that produces case expression which has a list of
+    conditions and corresponding results. Results can only be python primitives
+    like string, numbers or booleans. Result type is inferred from condition results.
+
+    Args:
+        args (tuple(BinaryExpression, value(str | int | float | complex | bool)):
+            Tuple of binary expression and values pair which corresponds to one
+            case condition - value
+        else_ (str | int | float | complex | bool): else value in case expression
+
+    Returns:
+        Func: A Func object that represents the case function.
+
+    Example:
+        ```py
+        dc.mutate(
+            res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
+        )
+        ```
+
+    Note:
+        - Result column will always be of the same type as the input columns.
+    """
+    supported_types = [int, float, complex, str, bool]
+
+    type_ = type(else_) if else_ else None
+
+    if not args:
+        raise DataChainParamsError("Missing case statements")
+
+    for arg in args:
+        if type_ and not isinstance(arg[1], type_):
+            raise DataChainParamsError("Case statement values must be of the same type")
+        type_ = type(arg[1])
+
+    if type_ not in supported_types:
+        raise DataChainParamsError(
+            f"Case supports only python literals ({supported_types}) for values"
+        )
+
+    kwargs = {"else_": else_}
+    return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
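Putting the new `case()` together with its docstring example, an end-to-end usage sketch (it assumes the usual top-level `DataChain`, `C`, and `func` re-exports; column values shown in the comment depend on row order):

from datachain import C, DataChain, func

dc = DataChain.from_values(num=[-2, 0, 3])
dc = dc.mutate(sign=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"))
dc.show()  # each row's sign is "N", "Z", or "P" depending on num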
datachain/func/func.py
CHANGED
@@ -35,6 +35,7 @@ class Func(Function):
         inner: Callable,
         cols: Optional[Sequence[ColT]] = None,
         args: Optional[Sequence[Any]] = None,
+        kwargs: Optional[dict[str, Any]] = None,
         result_type: Optional["DataType"] = None,
         is_array: bool = False,
         is_window: bool = False,
@@ -45,6 +46,7 @@ class Func(Function):
         self.inner = inner
         self.cols = cols or []
         self.args = args or []
+        self.kwargs = kwargs or {}
         self.result_type = result_type
         self.is_array = is_array
         self.is_window = is_window
@@ -63,6 +65,7 @@ class Func(Function):
             self.inner,
             self.cols,
             self.args,
+            self.kwargs,
             self.result_type,
             self.is_array,
             self.is_window,
@@ -333,6 +336,7 @@ class Func(Function):
             self.inner,
             self.cols,
             self.args,
+            self.kwargs,
             self.result_type,
             self.is_array,
             self.is_window,
@@ -387,7 +391,7 @@ class Func(Function):
             return col
 
         cols = [get_col(col) for col in self._db_cols]
-        func_col = self.inner(*cols, *self.args)
+        func_col = self.inner(*cols, *self.args, **self.kwargs)
 
         if self.is_window:
             if not self.window:
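The point of this `kwargs` plumbing shows in the last hunk: the wrapped callable now receives keyword arguments, which is what lets `case()` forward `else_` through to `sqlalchemy.case`. A standalone sketch of the pattern (the class below is an illustrative stand-in, not datachain's `Func`):

from typing import Any, Callable, Optional

class MiniFunc:
    def __init__(self, inner: Callable, args=None, kwargs: Optional[dict[str, Any]] = None):
        self.inner = inner
        self.args = args or []
        self.kwargs = kwargs or {}

    def build(self, *cols):
        # mirrors: func_col = self.inner(*cols, *self.args, **self.kwargs)
        return self.inner(*cols, *self.args, **self.kwargs)

join = MiniFunc(inner=lambda *parts, sep="": sep.join(parts), kwargs={"sep": "-"})
print(join.build("a", "b", "c"))  # a-b-c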
datachain/lib/arrow.py
CHANGED
@@ -149,6 +149,10 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
     for file in chain.collect("file"):
         ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
         schemas.append(ds.schema)
+    if not schemas:
+        raise ValueError(
+            "Cannot infer schema (no files to process or can't access them)"
+        )
     return pa.unify_schemas(schemas)
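The guard matters because without it `pa.unify_schemas` would be handed an empty list; the new check turns an empty or inaccessible file set into an actionable message instead. A minimal illustration of the same pattern:

import pyarrow as pa

def unify(schemas: list) -> pa.Schema:
    if not schemas:  # mirrors the guard added in infer_schema above
        raise ValueError("Cannot infer schema (no files to process or can't access them)")
    return pa.unify_schemas(schemas)

print(unify([pa.schema([("a", pa.int64())]), pa.schema([("b", pa.string())])]))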
datachain/lib/dc.py
CHANGED
@@ -32,7 +32,7 @@ from datachain.lib.data_model import DataModel, DataType, DataValue, dict_to_dat
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ArrowRow, File, FileType, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.listing import get_listing, list_bucket, ls
+from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
 from datachain.lib.listing_info import ListingInfo
 from datachain.lib.meta_formats import read_meta
 from datachain.lib.model_store import ModelStore
@@ -438,6 +438,18 @@ class DataChain:
             uri, session, update=update
         )
 
+        # ds_name is None if object is a file, we don't want to use cache
+        # or do listing in that case - just read that single object
+        if not list_ds_name:
+            dc = cls.from_values(
+                session=session,
+                settings=settings,
+                in_memory=in_memory,
+                file=[get_file_info(list_uri, cache, client_config=client_config)],
+            )
+            dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+            return dc
+
         if update or not list_ds_exists:
             (
                 cls.from_records(
@@ -1634,7 +1646,7 @@ class DataChain:
         output: OutputType = None,
         object_name: str = "",
         **fr_map,
-    ) -> "
+    ) -> "Self":
        """Generate chain from list of values.
 
         Example:
@@ -1647,7 +1659,7 @@ class DataChain:
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples
 
-        chain =
+        chain = cls.from_records(
             DataChain.DEFAULT_FILE_RECORD,
             session=session,
             settings=settings,
@@ -1870,6 +1882,9 @@ class DataChain:
                 "`nrows` only supported for csv and json formats.",
             )
 
+        if "file" not in self.schema or not self.count():
+            raise DatasetPrepareError(self.name, "no files to parse.")
+
         schema = None
         col_names = output if isinstance(output, Sequence) else None
         if col_names or not output:
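With the new branch above, pointing at a single object skips listing entirely and yields a one-row chain. A hedged usage sketch, assuming the hunk sits inside `from_storage` (which the surrounding context suggests); the bucket URI is hypothetical:

from datachain.lib.dc import DataChain

# A direct file URI produces one File row; no listing dataset is created or cached.
dc = DataChain.from_storage("s3://example-bucket/data/part-000.parquet")
print(dc.count())  # 1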
datachain/lib/file.py
CHANGED
@@ -364,7 +364,7 @@ class File(DataModel):
 
         try:
             info = client.fs.info(client.get_full_path(self.path))
-            converted_info = client.info_to_file(info, self.
+            converted_info = client.info_to_file(info, self.path)
             return type(self)(
                 path=self.path,
                 source=self.source,