datachain 0.8.10__py3-none-any.whl → 0.8.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic.
- datachain/cache.py +4 -4
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +103 -158
- datachain/cli/__init__.py +7 -14
- datachain/cli/commands/__init__.py +0 -2
- datachain/cli/commands/datasets.py +0 -19
- datachain/cli/parser/__init__.py +27 -41
- datachain/cli/parser/studio.py +7 -6
- datachain/cli/parser/utils.py +18 -0
- datachain/client/fsspec.py +11 -8
- datachain/client/local.py +4 -4
- datachain/data_storage/schema.py +1 -1
- datachain/dataset.py +1 -7
- datachain/error.py +12 -0
- datachain/func/__init__.py +2 -1
- datachain/func/conditional.py +77 -26
- datachain/func/func.py +17 -6
- datachain/lib/dc.py +24 -4
- datachain/lib/file.py +16 -0
- datachain/lib/listing.py +30 -12
- datachain/lib/pytorch.py +1 -1
- datachain/lib/udf.py +1 -1
- datachain/listing.py +1 -13
- datachain/node.py +0 -15
- datachain/nodes_fetcher.py +2 -2
- datachain/remote/studio.py +2 -14
- datachain/studio.py +1 -1
- {datachain-0.8.10.dist-info → datachain-0.8.12.dist-info}/METADATA +3 -7
- {datachain-0.8.10.dist-info → datachain-0.8.12.dist-info}/RECORD +33 -33
- {datachain-0.8.10.dist-info → datachain-0.8.12.dist-info}/LICENSE +0 -0
- {datachain-0.8.10.dist-info → datachain-0.8.12.dist-info}/WHEEL +0 -0
- {datachain-0.8.10.dist-info → datachain-0.8.12.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.10.dist-info → datachain-0.8.12.dist-info}/top_level.txt +0 -0
datachain/cli/parser/__init__.py
CHANGED
@@ -8,7 +8,14 @@ from datachain.cli.utils import BooleanOptionalAction, KeyValueArgs
 
 from .job import add_jobs_parser
 from .studio import add_auth_parser
-from .utils import FIND_COLUMNS, add_show_args, add_sources_arg, find_columns_type
+from .utils import (
+    FIND_COLUMNS,
+    add_anon_arg,
+    add_show_args,
+    add_sources_arg,
+    add_update_arg,
+    find_columns_type,
+)
 
 
 def get_parser() -> ArgumentParser:  # noqa: PLR0915
@@ -32,19 +39,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "-q", "--quiet", action="count", default=0, help="Be quiet"
     )
 
-    parent_parser.add_argument(
-        "--anon",
-        action="store_true",
-        help="Use anonymous access to storage",
-    )
-    parent_parser.add_argument(
-        "-u",
-        "--update",
-        action="count",
-        default=0,
-        help="Update cached list of files for the sources",
-    )
-
     parent_parser.add_argument(
         "--debug-sql",
         action="store_true",
@@ -92,6 +86,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         action="store_true",
         help="Do not expand globs (such as * or ?)",
     )
+    add_anon_arg(parse_cp)
+    add_update_arg(parse_cp)
 
     parse_clone = subp.add_parser(
         "clone", parents=[parent_parser], description="Copy data files from the cloud."
@@ -127,6 +123,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         action="store_true",
         help="Do not copy files, just create a dataset",
     )
+    add_anon_arg(parse_clone)
+    add_update_arg(parse_clone)
 
     add_auth_parser(subp, parent_parser)
     add_jobs_parser(subp, parent_parser)
@@ -137,6 +135,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         parents=[parent_parser],
         description="Commands for managing datasets.",
     )
+    add_anon_arg(datasets_parser)
    datasets_subparser = datasets_parser.add_subparsers(
         dest="datasets_cmd",
         help="Use `datachain dataset CMD --help` to display command-specific help",
@@ -308,34 +307,11 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="The team to delete a dataset. By default, it will use team from config",
     )
 
-    dataset_stats_parser = datasets_subparser.add_parser(
-        "stats", parents=[parent_parser], description="Show basic dataset statistics."
-    )
-    dataset_stats_parser.add_argument("name", type=str, help="Dataset name")
-    dataset_stats_parser.add_argument(
-        "--version",
-        action="store",
-        default=None,
-        type=int,
-        help="Dataset version",
-    )
-    dataset_stats_parser.add_argument(
-        "-b",
-        "--bytes",
-        default=False,
-        action="store_true",
-        help="Display size in bytes instead of human-readable size",
-    )
-    dataset_stats_parser.add_argument(
-        "--si",
-        default=False,
-        action="store_true",
-        help="Display size using powers of 1000 not 1024",
-    )
-
     parse_ls = subp.add_parser(
         "ls", parents=[parent_parser], description="List storage contents."
     )
+    add_anon_arg(parse_ls)
+    add_update_arg(parse_ls)
     add_sources_arg(parse_ls, nargs="*")
     parse_ls.add_argument(
         "-l",
@@ -375,6 +351,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         "du", parents=[parent_parser], description="Display space usage."
     )
     add_sources_arg(parse_du)
+    add_anon_arg(parse_du)
+    add_update_arg(parse_du)
     parse_du.add_argument(
         "-b",
         "--bytes",
@@ -404,6 +382,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     parse_find = subp.add_parser(
         "find", parents=[parent_parser], description="Search in a directory hierarchy."
     )
+    add_anon_arg(parse_find)
+    add_update_arg(parse_find)
     add_sources_arg(parse_find)
     parse_find.add_argument(
         "--name",
@@ -457,6 +437,8 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
     parse_index = subp.add_parser(
         "index", parents=[parent_parser], description="Index storage location."
     )
+    add_anon_arg(parse_index)
+    add_update_arg(parse_index)
     add_sources_arg(parse_index)
 
     show_parser = subp.add_parser(
@@ -480,6 +462,7 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         parents=[parent_parser],
         description="Create a new dataset with a query script.",
     )
+    add_anon_arg(query_parser)
     query_parser.add_argument(
         "script", metavar="<script.py>", type=str, help="Filepath for script"
     )
@@ -504,14 +487,17 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         help="Query parameters",
     )
 
-    subp.add_parser(
+    parse_clear_cache = subp.add_parser(
         "clear-cache",
         parents=[parent_parser],
         description="Clear the local file cache.",
     )
-    subp.add_parser(
+    add_anon_arg(parse_clear_cache)
+
+    parse_gc = subp.add_parser(
         "gc", parents=[parent_parser], description="Garbage collect temporary tables."
     )
+    add_anon_arg(parse_gc)
 
     subp.add_parser("internal-run-udf", parents=[parent_parser])
     subp.add_parser("internal-run-udf-worker", parents=[parent_parser])
datachain/cli/parser/studio.py
CHANGED
@@ -1,9 +1,8 @@
 def add_auth_parser(subparsers, parent_parser) -> None:
+    from dvc_studio_client.auth import AVAILABLE_SCOPES
+
     auth_help = "Manage Studio authentication"
-    auth_description = (
-        "Manage authentication and settings for Studio. "
-        "Configure tokens for sharing datasets and using Studio features."
-    )
+    auth_description = "Manage authentication and settings for Studio. "
 
     auth_parser = subparsers.add_parser(
         "auth",
@@ -19,8 +18,10 @@ def add_auth_parser(subparsers, parent_parser) -> None:
     auth_login_help = "Authenticate with Studio"
     auth_login_description = (
         "Authenticate with Studio using default scopes. "
-        "A random name will be assigned
+        "A random name will be assigned if the token name is not specified."
     )
+
+    allowed_scopes = ", ".join(AVAILABLE_SCOPES)
     login_parser = auth_subparser.add_parser(
         "login",
         parents=[parent_parser],
@@ -40,7 +41,7 @@ def add_auth_parser(subparsers, parent_parser) -> None:
         "--scopes",
         action="store",
         default=None,
-        help="Authentication token scopes",
+        help=f"Authentication token scopes. Allowed scopes: {allowed_scopes}",
     )
 
     login_parser.add_argument(
datachain/cli/parser/utils.py
CHANGED
@@ -34,6 +34,24 @@ def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Act
     )
 
 
+def add_anon_arg(parser: ArgumentParser) -> None:
+    parser.add_argument(
+        "--anon",
+        action="store_true",
+        help="Use anonymous access to storage",
+    )
+
+
+def add_update_arg(parser: ArgumentParser) -> None:
+    parser.add_argument(
+        "-u",
+        "--update",
+        action="count",
+        default=0,
+        help="Update cached list of files for the sources",
+    )
+
+
 def add_show_args(parser: ArgumentParser) -> None:
     parser.add_argument(
         "--limit",
datachain/client/fsspec.py
CHANGED
@@ -3,6 +3,7 @@ import functools
 import logging
 import multiprocessing
 import os
+import posixpath
 import re
 import sys
 from abc import ABC, abstractmethod
@@ -25,7 +26,7 @@ from fsspec.asyn import get_loop, sync
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from tqdm.auto import tqdm
 
-from datachain.cache import DataChainCache
+from datachain.cache import Cache
 from datachain.client.fileslice import FileWrapper
 from datachain.error import ClientError as DataChainClientError
 from datachain.nodes_fetcher import NodesFetcher
@@ -74,9 +75,7 @@ class Client(ABC):
     PREFIX: ClassVar[str]
     protocol: ClassVar[str]
 
-    def __init__(
-        self, name: str, fs_kwargs: dict[str, Any], cache: DataChainCache
-    ) -> None:
+    def __init__(self, name: str, fs_kwargs: dict[str, Any], cache: Cache) -> None:
         self.name = name
         self.fs_kwargs = fs_kwargs
         self._fs: Optional[AbstractFileSystem] = None
@@ -122,7 +121,7 @@ class Client(ABC):
         return cls.get_uri(storage_name), rel_path
 
     @staticmethod
-    def get_client(source: str, cache: DataChainCache, **kwargs) -> "Client":
+    def get_client(source: str, cache: Cache, **kwargs) -> "Client":
         cls = Client.get_implementation(source)
         storage_url, _ = cls.split_url(source)
         if os.name == "nt":
@@ -145,7 +144,7 @@ class Client(ABC):
     def from_name(
         cls,
         name: str,
-        cache: DataChainCache,
+        cache: Cache,
         kwargs: dict[str, Any],
     ) -> "Client":
         return cls(name, kwargs, cache)
@@ -154,7 +153,7 @@ class Client(ABC):
     def from_source(
         cls,
         uri: "StorageURI",
-        cache: DataChainCache,
+        cache: Cache,
         **kwargs,
     ) -> "Client":
         return cls(cls.FS_CLASS._strip_protocol(uri), kwargs, cache)
@@ -390,8 +389,12 @@ class Client(ABC):
             self.fs.open(self.get_full_path(file.path, file.version)), cb
         )  # type: ignore[return-value]
 
-    def upload(self,
+    def upload(self, data: bytes, path: str) -> "File":
         full_path = self.get_full_path(path)
+
+        parent = posixpath.dirname(full_path)
+        self.fs.makedirs(parent, exist_ok=True)
+
         self.fs.pipe_file(full_path, data)
         file_info = self.fs.info(full_path)
         return self.info_to_file(file_info, path)
datachain/client/local.py
CHANGED
@@ -12,7 +12,7 @@ from datachain.lib.file import File
 from .fsspec import Client
 
 if TYPE_CHECKING:
-    from datachain.cache import DataChainCache
+    from datachain.cache import Cache
     from datachain.dataset import StorageURI
 
 
@@ -25,7 +25,7 @@ class FileClient(Client):
         self,
         name: str,
         fs_kwargs: dict[str, Any],
-        cache: "DataChainCache",
+        cache: "Cache",
         use_symlinks: bool = False,
     ) -> None:
         super().__init__(name, fs_kwargs, cache)
@@ -82,7 +82,7 @@ class FileClient(Client):
         return bucket, path
 
     @classmethod
-    def from_name(cls, name: str, cache: "DataChainCache", kwargs) -> "FileClient":
+    def from_name(cls, name: str, cache: "Cache", kwargs) -> "FileClient":
         use_symlinks = kwargs.pop("use_symlinks", False)
         return cls(name, kwargs, cache, use_symlinks=use_symlinks)
 
@@ -90,7 +90,7 @@ class FileClient(Client):
     def from_source(
         cls,
         uri: str,
-        cache: "DataChainCache",
+        cache: "Cache",
         use_symlinks: bool = False,
         **kwargs,
     ) -> "FileClient":
datachain/data_storage/schema.py
CHANGED
@@ -200,7 +200,7 @@ class DataTable:
         columns: Sequence["sa.Column"] = (),
         metadata: Optional["sa.MetaData"] = None,
     ):
-        # copy columns, since
+        # copy columns, since reusing the same objects from another table
         # may raise an error
         columns = cls.sys_columns() + [cls.copy_column(c) for c in columns]
         columns = dedup_columns(columns)
datachain/dataset.py
CHANGED
@@ -91,7 +91,7 @@ class DatasetDependency:
         if self.type == DatasetDependencyType.DATASET:
             return self.name
 
-        list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"),
+        list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), {})
         assert list_dataset_name
         return list_dataset_name
 
@@ -150,12 +150,6 @@ class DatasetDependency:
         return hash(f"{self.type}_{self.name}_{self.version}")
 
 
-@dataclass
-class DatasetStats:
-    num_objects: Optional[int]  # None if table is missing
-    size: Optional[int]  # in bytes None if table is missing or empty
-
-
 class DatasetStatus:
     CREATED = 1
     PENDING = 2
datachain/error.py
CHANGED
@@ -1,3 +1,15 @@
+import botocore.errorfactory
+import botocore.exceptions
+import gcsfs.retry
+
+REMOTE_ERRORS = (
+    gcsfs.retry.HttpError,  # GCS
+    OSError,  # GCS
+    botocore.exceptions.BotoCoreError,  # S3
+    ValueError,  # Azure
+)
+
+
 class DataChainError(RuntimeError):
     pass
 
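Since `REMOTE_ERRORS` is a plain tuple of exception classes, it can be passed directly to `except`. A sketch of how a caller might use it (`fetch` is a hypothetical callable):

```py
from datachain.error import REMOTE_ERRORS


def fetch_with_context(fetch, *args):
    """`fetch` is a hypothetical callable that talks to remote storage."""
    try:
        return fetch(*args)
    except REMOTE_ERRORS as exc:
        # Provider-specific failures (GCS, S3, Azure) are caught uniformly.
        raise RuntimeError(f"remote storage error: {exc}") from exc
```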
datachain/func/__init__.py
CHANGED
@@ -16,7 +16,7 @@ from .aggregate import (
     sum,
 )
 from .array import cosine_distance, euclidean_distance, length, sip_hash_64
-from .conditional import case, greatest, ifelse, least
+from .conditional import case, greatest, ifelse, isnone, least
 from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
 from .random import rand
 from .string import byte_hamming_distance
@@ -42,6 +42,7 @@ __all__ = [
     "greatest",
     "ifelse",
     "int_hash_64",
+    "isnone",
     "least",
     "length",
     "literal",
datachain/func/conditional.py
CHANGED
@@ -1,14 +1,15 @@
-from typing import Union
+from typing import Optional, Union
 
+from sqlalchemy import ColumnElement
 from sqlalchemy import case as sql_case
-from sqlalchemy.sql.elements import BinaryExpression
 
 from datachain.lib.utils import DataChainParamsError
+from datachain.query.schema import Column
 from datachain.sql.functions import conditional
 
 from .func import ColT, Func
 
-CaseT = Union[int, float, complex, bool, str]
+CaseT = Union[int, float, complex, bool, str, Func, ColumnElement]
 
 
 def greatest(*args: Union[ColT, float]) -> Func:
@@ -87,17 +88,22 @@ def least(*args: Union[ColT, float]) -> Func:
     )
 
 
-def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
+def case(
+    *args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
+) -> Func:
     """
     Returns the case function that produces case expression which has a list of
-    conditions and corresponding results. Results can
-
+    conditions and corresponding results. Results can be python primitives like string,
+    numbers or booleans but can also be other nested functions (including case function)
+    or columns.
+    Result type is inferred from condition results.
 
     Args:
-        args
-
-
-
+        args tuple((ColumnElement | Func),(str | int | float | complex | bool, Func, ColumnElement)):
+            Tuple of condition and values pair.
+        else_ (str | int | float | complex | bool, Func): optional else value in case
+            expression. If omitted, and no case conditions are satisfied, the result
+            will be None (NULL in DB).
 
     Returns:
         Func: A Func object that represents the case function.
@@ -108,39 +114,59 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
         res=func.case((C("num") > 0, "P"), (C("num") < 0, "N"), else_="Z"),
     )
     ```
-    """
+    """  # noqa: E501
     supported_types = [int, float, complex, str, bool]
 
-
+    def _get_type(val):
+        if isinstance(val, Func):
+            # nested functions
+            return val.result_type
+        if isinstance(val, Column):
+            # at this point we cannot know what is the type of a column
+            return None
+        return type(val)
 
     if not args:
         raise DataChainParamsError("Missing statements")
 
-
-        if type_ and not isinstance(arg[1], type_):
-            raise DataChainParamsError("Statement values must be of the same type")
-        type_ = type(arg[1])
+    type_ = _get_type(else_) if else_ is not None else None
 
-
+    for arg in args:
+        arg_type = _get_type(arg[1])
+        if arg_type is None:
+            # we couldn't figure out the type of case value
+            continue
+        if type_ and arg_type != type_:
+            raise DataChainParamsError(
+                f"Statement values must be of the same type, got {type_} and {arg_type}"
+            )
+        type_ = arg_type
+
+    if type_ is not None and type_ not in supported_types:
         raise DataChainParamsError(
             f"Only python literals ({supported_types}) are supported for values"
         )
 
     kwargs = {"else_": else_}
-    return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
 
+    return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)
 
-def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
+
+def ifelse(
+    condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT
+) -> Func:
     """
     Returns the ifelse function that produces if expression which has a condition
-    and values for true and false outcome. Results can
-    like string,
+    and values for true and false outcome. Results can be one of python primitives
+    like string, numbers or booleans, but can also be nested functions or columns.
+    Result type is inferred from the values.
 
     Args:
-        condition
-        if_val
-
-
+        condition (ColumnElement, Func): Condition which is evaluated.
+        if_val (str | int | float | complex | bool, Func, ColumnElement): Value for true
+            condition outcome.
+        else_val (str | int | float | complex | bool, Func, ColumnElement): Value for
+            false condition outcome.
 
     Returns:
         Func: A Func object that represents the ifelse function.
@@ -148,8 +174,33 @@ def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
     Example:
         ```py
        dc.mutate(
-            res=func.ifelse(
+            res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY")
         )
         ```
     """
     return case((condition, if_val), else_=else_val)
+
+
+def isnone(col: Union[str, Column]) -> Func:
+    """
+    Returns True if column value is None, otherwise False.
+
+    Args:
+        col (str | Column): Column to check if it's None or not.
+            If a string is provided, it is assumed to be the name of the column.
+
+    Returns:
+        Func: A Func object that represents the conditional to check if column is None.
+
+    Example:
+        ```py
+        dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"))
+        ```
+    """
+    from datachain import C
+
+    if isinstance(col, str):
+        # if string, it is assumed to be the name of the column
+        col = C(col)
+
+    return case((col.is_(None) if col is not None else True, True), else_=False)
datachain/func/func.py
CHANGED
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
     from .window import Window
 
 
-ColT = Union[str, ColumnElement, "Func"]
+ColT = Union[str, ColumnElement, "Func", tuple]
 
 
 class Func(Function):
@@ -78,7 +78,7 @@ class Func(Function):
         return (
             [
                 col
-                if isinstance(col, (Func, BindParameter, Case, Comparator))
+                if isinstance(col, (Func, BindParameter, Case, Comparator, tuple))
                 else ColumnMeta.to_db_name(
                     col.name if isinstance(col, ColumnElement) else col
                 )
@@ -381,17 +381,24 @@ class Func(Function):
         col_type = self.get_result_type(signals_schema)
         sql_type = python_to_sql(col_type)
 
-        def get_col(col: ColT) -> ColT:
+        def get_col(col: ColT, string_as_literal=False) -> ColT:
+            # string_as_literal is used only for conditionals like `case()` where
+            # literals are nested inside ColT as we have tuples of condition - values
+            # and if user wants to set some case value as column, explicit `C("col")`
+            # syntax must be used to distinguish from literals
+            if isinstance(col, tuple):
+                return tuple(get_col(x, string_as_literal=True) for x in col)
             if isinstance(col, Func):
                 return col.get_column(signals_schema, table=table)
-            if isinstance(col, str):
+            if isinstance(col, str) and not string_as_literal:
                 column = Column(col, sql_type)
                 column.table = table
                 return column
             return col
 
         cols = [get_col(col) for col in self._db_cols]
-        func_col = self.inner(*cols, *self.args, **self.kwargs)
+        kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
+        func_col = self.inner(*cols, *self.args, **kwargs)
 
         if self.is_window:
             if not self.window:
@@ -416,6 +423,10 @@ class Func(Function):
 
 
 def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
+    if isinstance(col, tuple):
+        # we can only get tuple from case statement where the first tuple item
+        # is condition, and second one is value which type is important
+        col = col[1]
     if isinstance(col, Func):
         return col.get_result_type(signals_schema)
 
@@ -423,7 +434,7 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
         return sql_to_python(col)
 
     return signals_schema.get_column_type(
-        col.name if isinstance(col, ColumnElement) else col
+        col.name if isinstance(col, ColumnElement) else col  # type: ignore[arg-type]
     )
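The `string_as_literal` distinction means a bare string inside a `case()` tuple stays a literal value, while a column must be named explicitly with `C(...)`. A small sketch of the difference (column names are illustrative):

```py
from datachain import C
from datachain.func import case

# "unknown" is treated as a literal result value...
as_literal = case((C("score") > 0.5, "high"), else_="unknown")

# ...while C("fallback_label") makes the result come from another column.
as_column = case((C("score") > 0.5, "high"), else_=C("fallback_label"))
```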