datachain 0.8.9__py3-none-any.whl → 0.8.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic; see the package registry's advisory page for more details.
- datachain/cache.py +4 -4
- datachain/catalog/__init__.py +0 -2
- datachain/catalog/catalog.py +102 -138
- datachain/cli/__init__.py +9 -9
- datachain/cli/parser/__init__.py +36 -20
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/studio.py +35 -34
- datachain/cli/parser/utils.py +19 -1
- datachain/cli/utils.py +1 -1
- datachain/client/fsspec.py +11 -8
- datachain/client/local.py +4 -4
- datachain/data_storage/schema.py +1 -1
- datachain/data_storage/sqlite.py +38 -7
- datachain/data_storage/warehouse.py +2 -2
- datachain/dataset.py +1 -1
- datachain/error.py +12 -0
- datachain/func/__init__.py +2 -1
- datachain/func/conditional.py +67 -23
- datachain/func/func.py +17 -5
- datachain/lib/convert/python_to_sql.py +15 -3
- datachain/lib/dc.py +27 -5
- datachain/lib/file.py +16 -0
- datachain/lib/listing.py +30 -12
- datachain/lib/pytorch.py +1 -1
- datachain/lib/udf.py +1 -1
- datachain/listing.py +1 -13
- datachain/node.py +0 -15
- datachain/nodes_fetcher.py +2 -2
- datachain/query/dataset.py +8 -4
- datachain/remote/studio.py +3 -3
- datachain/sql/sqlite/base.py +35 -14
- datachain/studio.py +8 -8
- {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/METADATA +3 -7
- {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/RECORD +38 -38
- {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/LICENSE +0 -0
- {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/WHEEL +0 -0
- {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/top_level.txt +0 -0
datachain/cli/parser/studio.py
CHANGED
|
@@ -1,31 +1,32 @@
|
|
|
1
|
-
def
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
)
|
|
1
|
+
def add_auth_parser(subparsers, parent_parser) -> None:
|
|
2
|
+
from dvc_studio_client.auth import AVAILABLE_SCOPES
|
|
3
|
+
|
|
4
|
+
auth_help = "Manage Studio authentication"
|
|
5
|
+
auth_description = "Manage authentication and settings for Studio. "
|
|
7
6
|
|
|
8
|
-
|
|
9
|
-
"
|
|
7
|
+
auth_parser = subparsers.add_parser(
|
|
8
|
+
"auth",
|
|
10
9
|
parents=[parent_parser],
|
|
11
|
-
description=
|
|
12
|
-
help=
|
|
10
|
+
description=auth_description,
|
|
11
|
+
help=auth_help,
|
|
13
12
|
)
|
|
14
|
-
|
|
13
|
+
auth_subparser = auth_parser.add_subparsers(
|
|
15
14
|
dest="cmd",
|
|
16
|
-
help="Use `datachain
|
|
15
|
+
help="Use `datachain auth CMD --help` to display command-specific help",
|
|
17
16
|
)
|
|
18
17
|
|
|
19
|
-
|
|
20
|
-
|
|
18
|
+
auth_login_help = "Authenticate with Studio"
|
|
19
|
+
auth_login_description = (
|
|
21
20
|
"Authenticate with Studio using default scopes. "
|
|
22
|
-
"A random name will be assigned
|
|
21
|
+
"A random name will be assigned if the token name is not specified."
|
|
23
22
|
)
|
|
24
|
-
|
|
23
|
+
|
|
24
|
+
allowed_scopes = ", ".join(AVAILABLE_SCOPES)
|
|
25
|
+
login_parser = auth_subparser.add_parser(
|
|
25
26
|
"login",
|
|
26
27
|
parents=[parent_parser],
|
|
27
|
-
description=
|
|
28
|
-
help=
|
|
28
|
+
description=auth_login_description,
|
|
29
|
+
help=auth_login_help,
|
|
29
30
|
)
|
|
30
31
|
|
|
31
32
|
login_parser.add_argument(
|
|
@@ -40,7 +41,7 @@ def add_studio_parser(subparsers, parent_parser) -> None:
|
|
|
40
41
|
"--scopes",
|
|
41
42
|
action="store",
|
|
42
43
|
default=None,
|
|
43
|
-
help="Authentication token scopes",
|
|
44
|
+
help=f"Authentication token scopes. Allowed scopes: {allowed_scopes}",
|
|
44
45
|
)
|
|
45
46
|
|
|
46
47
|
login_parser.add_argument(
|
|
@@ -58,26 +59,26 @@ def add_studio_parser(subparsers, parent_parser) -> None:
|
|
|
58
59
|
help="Use code-based authentication without browser",
|
|
59
60
|
)
|
|
60
61
|
|
|
61
|
-
|
|
62
|
-
|
|
62
|
+
auth_logout_help = "Log out from Studio"
|
|
63
|
+
auth_logout_description = (
|
|
63
64
|
"Remove the Studio authentication token from global config."
|
|
64
65
|
)
|
|
65
66
|
|
|
66
|
-
|
|
67
|
+
auth_subparser.add_parser(
|
|
67
68
|
"logout",
|
|
68
69
|
parents=[parent_parser],
|
|
69
|
-
description=
|
|
70
|
-
help=
|
|
70
|
+
description=auth_logout_description,
|
|
71
|
+
help=auth_logout_help,
|
|
71
72
|
)
|
|
72
73
|
|
|
73
|
-
|
|
74
|
-
|
|
74
|
+
auth_team_help = "Set default team for Studio operations"
|
|
75
|
+
auth_team_description = "Set the default team for Studio operations."
|
|
75
76
|
|
|
76
|
-
team_parser =
|
|
77
|
+
team_parser = auth_subparser.add_parser(
|
|
77
78
|
"team",
|
|
78
79
|
parents=[parent_parser],
|
|
79
|
-
description=
|
|
80
|
-
help=
|
|
80
|
+
description=auth_team_description,
|
|
81
|
+
help=auth_team_help,
|
|
81
82
|
)
|
|
82
83
|
team_parser.add_argument(
|
|
83
84
|
"team_name",
|
|
@@ -91,12 +92,12 @@ def add_studio_parser(subparsers, parent_parser) -> None:
|
|
|
91
92
|
help="Set team globally for all projects",
|
|
92
93
|
)
|
|
93
94
|
|
|
94
|
-
|
|
95
|
-
|
|
95
|
+
auth_token_help = "View Studio authentication token" # noqa: S105
|
|
96
|
+
auth_token_description = "Display the current authentication token for Studio." # noqa: S105
|
|
96
97
|
|
|
97
|
-
|
|
98
|
+
auth_subparser.add_parser(
|
|
98
99
|
"token",
|
|
99
100
|
parents=[parent_parser],
|
|
100
|
-
description=
|
|
101
|
-
help=
|
|
101
|
+
description=auth_token_description,
|
|
102
|
+
help=auth_token_help,
|
|
102
103
|
)
|
datachain/cli/parser/utils.py
CHANGED
|
@@ -30,7 +30,25 @@ def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Act
|
|
|
30
30
|
"sources",
|
|
31
31
|
type=str,
|
|
32
32
|
nargs=nargs,
|
|
33
|
-
help="Data sources - paths to
|
|
33
|
+
help="Data sources - paths to source storage directories or files",
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def add_anon_arg(parser: ArgumentParser) -> None:
|
|
38
|
+
parser.add_argument(
|
|
39
|
+
"--anon",
|
|
40
|
+
action="store_true",
|
|
41
|
+
help="Use anonymous access to storage",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def add_update_arg(parser: ArgumentParser) -> None:
|
|
46
|
+
parser.add_argument(
|
|
47
|
+
"-u",
|
|
48
|
+
"--update",
|
|
49
|
+
action="count",
|
|
50
|
+
default=0,
|
|
51
|
+
help="Update cached list of files for the sources",
|
|
34
52
|
)
|
|
35
53
|
|
|
36
54
|
|
datachain/cli/utils.py
CHANGED
|
@@ -87,7 +87,7 @@ def get_logging_level(args: Namespace) -> int:
|
|
|
87
87
|
def determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
|
|
88
88
|
if studio and not token:
|
|
89
89
|
raise DataChainError(
|
|
90
|
-
"Not logged in to Studio. Log in with 'datachain
|
|
90
|
+
"Not logged in to Studio. Log in with 'datachain auth login'."
|
|
91
91
|
)
|
|
92
92
|
|
|
93
93
|
if local or studio:
|
datachain/client/fsspec.py
CHANGED
|
@@ -3,6 +3,7 @@ import functools
|
|
|
3
3
|
import logging
|
|
4
4
|
import multiprocessing
|
|
5
5
|
import os
|
|
6
|
+
import posixpath
|
|
6
7
|
import re
|
|
7
8
|
import sys
|
|
8
9
|
from abc import ABC, abstractmethod
|
|
@@ -25,7 +26,7 @@ from fsspec.asyn import get_loop, sync
|
|
|
25
26
|
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
|
|
26
27
|
from tqdm.auto import tqdm
|
|
27
28
|
|
|
28
|
-
from datachain.cache import
|
|
29
|
+
from datachain.cache import Cache
|
|
29
30
|
from datachain.client.fileslice import FileWrapper
|
|
30
31
|
from datachain.error import ClientError as DataChainClientError
|
|
31
32
|
from datachain.nodes_fetcher import NodesFetcher
|
|
@@ -74,9 +75,7 @@ class Client(ABC):
|
|
|
74
75
|
PREFIX: ClassVar[str]
|
|
75
76
|
protocol: ClassVar[str]
|
|
76
77
|
|
|
77
|
-
def __init__(
|
|
78
|
-
self, name: str, fs_kwargs: dict[str, Any], cache: DataChainCache
|
|
79
|
-
) -> None:
|
|
78
|
+
def __init__(self, name: str, fs_kwargs: dict[str, Any], cache: Cache) -> None:
|
|
80
79
|
self.name = name
|
|
81
80
|
self.fs_kwargs = fs_kwargs
|
|
82
81
|
self._fs: Optional[AbstractFileSystem] = None
|
|
@@ -122,7 +121,7 @@ class Client(ABC):
|
|
|
122
121
|
return cls.get_uri(storage_name), rel_path
|
|
123
122
|
|
|
124
123
|
@staticmethod
|
|
125
|
-
def get_client(source: str, cache:
|
|
124
|
+
def get_client(source: str, cache: Cache, **kwargs) -> "Client":
|
|
126
125
|
cls = Client.get_implementation(source)
|
|
127
126
|
storage_url, _ = cls.split_url(source)
|
|
128
127
|
if os.name == "nt":
|
|
@@ -145,7 +144,7 @@ class Client(ABC):
|
|
|
145
144
|
def from_name(
|
|
146
145
|
cls,
|
|
147
146
|
name: str,
|
|
148
|
-
cache:
|
|
147
|
+
cache: Cache,
|
|
149
148
|
kwargs: dict[str, Any],
|
|
150
149
|
) -> "Client":
|
|
151
150
|
return cls(name, kwargs, cache)
|
|
@@ -154,7 +153,7 @@ class Client(ABC):
|
|
|
154
153
|
def from_source(
|
|
155
154
|
cls,
|
|
156
155
|
uri: "StorageURI",
|
|
157
|
-
cache:
|
|
156
|
+
cache: Cache,
|
|
158
157
|
**kwargs,
|
|
159
158
|
) -> "Client":
|
|
160
159
|
return cls(cls.FS_CLASS._strip_protocol(uri), kwargs, cache)
|
|
@@ -390,8 +389,12 @@ class Client(ABC):
|
|
|
390
389
|
self.fs.open(self.get_full_path(file.path, file.version)), cb
|
|
391
390
|
) # type: ignore[return-value]
|
|
392
391
|
|
|
393
|
-
def upload(self,
|
|
392
|
+
def upload(self, data: bytes, path: str) -> "File":
|
|
394
393
|
full_path = self.get_full_path(path)
|
|
394
|
+
|
|
395
|
+
parent = posixpath.dirname(full_path)
|
|
396
|
+
self.fs.makedirs(parent, exist_ok=True)
|
|
397
|
+
|
|
395
398
|
self.fs.pipe_file(full_path, data)
|
|
396
399
|
file_info = self.fs.info(full_path)
|
|
397
400
|
return self.info_to_file(file_info, path)
|
datachain/client/local.py
CHANGED
|
@@ -12,7 +12,7 @@ from datachain.lib.file import File
|
|
|
12
12
|
from .fsspec import Client
|
|
13
13
|
|
|
14
14
|
if TYPE_CHECKING:
|
|
15
|
-
from datachain.cache import
|
|
15
|
+
from datachain.cache import Cache
|
|
16
16
|
from datachain.dataset import StorageURI
|
|
17
17
|
|
|
18
18
|
|
|
@@ -25,7 +25,7 @@ class FileClient(Client):
|
|
|
25
25
|
self,
|
|
26
26
|
name: str,
|
|
27
27
|
fs_kwargs: dict[str, Any],
|
|
28
|
-
cache: "
|
|
28
|
+
cache: "Cache",
|
|
29
29
|
use_symlinks: bool = False,
|
|
30
30
|
) -> None:
|
|
31
31
|
super().__init__(name, fs_kwargs, cache)
|
|
@@ -82,7 +82,7 @@ class FileClient(Client):
|
|
|
82
82
|
return bucket, path
|
|
83
83
|
|
|
84
84
|
@classmethod
|
|
85
|
-
def from_name(cls, name: str, cache: "
|
|
85
|
+
def from_name(cls, name: str, cache: "Cache", kwargs) -> "FileClient":
|
|
86
86
|
use_symlinks = kwargs.pop("use_symlinks", False)
|
|
87
87
|
return cls(name, kwargs, cache, use_symlinks=use_symlinks)
|
|
88
88
|
|
|
@@ -90,7 +90,7 @@ class FileClient(Client):
|
|
|
90
90
|
def from_source(
|
|
91
91
|
cls,
|
|
92
92
|
uri: str,
|
|
93
|
-
cache: "
|
|
93
|
+
cache: "Cache",
|
|
94
94
|
use_symlinks: bool = False,
|
|
95
95
|
**kwargs,
|
|
96
96
|
) -> "FileClient":
|
datachain/data_storage/schema.py
CHANGED
|
@@ -200,7 +200,7 @@ class DataTable:
|
|
|
200
200
|
columns: Sequence["sa.Column"] = (),
|
|
201
201
|
metadata: Optional["sa.MetaData"] = None,
|
|
202
202
|
):
|
|
203
|
-
# copy columns, since
|
|
203
|
+
# copy columns, since reusing the same objects from another table
|
|
204
204
|
# may raise an error
|
|
205
205
|
columns = cls.sys_columns() + [cls.copy_column(c) for c in columns]
|
|
206
206
|
columns = dedup_columns(columns)
|
datachain/data_storage/sqlite.py
CHANGED
|
@@ -19,6 +19,7 @@ from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
|
|
|
19
19
|
from sqlalchemy.dialects import sqlite
|
|
20
20
|
from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
|
|
21
21
|
from sqlalchemy.sql import func
|
|
22
|
+
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
|
|
22
23
|
from sqlalchemy.sql.expression import bindparam, cast
|
|
23
24
|
from sqlalchemy.sql.selectable import Select
|
|
24
25
|
from tqdm.auto import tqdm
|
|
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
|
|
|
40
41
|
from sqlalchemy.schema import SchemaItem
|
|
41
42
|
from sqlalchemy.sql._typing import _FromClauseArgument, _OnClauseArgument
|
|
42
43
|
from sqlalchemy.sql.elements import ColumnElement
|
|
43
|
-
from sqlalchemy.sql.selectable import Join
|
|
44
44
|
from sqlalchemy.types import TypeEngine
|
|
45
45
|
|
|
46
46
|
from datachain.lib.file import File
|
|
@@ -654,16 +654,47 @@ class SQLiteWarehouse(AbstractWarehouse):
|
|
|
654
654
|
right: "_FromClauseArgument",
|
|
655
655
|
onclause: "_OnClauseArgument",
|
|
656
656
|
inner: bool = True,
|
|
657
|
-
|
|
657
|
+
full: bool = False,
|
|
658
|
+
columns=None,
|
|
659
|
+
) -> "Select":
|
|
658
660
|
"""
|
|
659
661
|
Join two tables together.
|
|
660
662
|
"""
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
663
|
+
if not full:
|
|
664
|
+
join_query = sqlalchemy.join(
|
|
665
|
+
left,
|
|
666
|
+
right,
|
|
667
|
+
onclause,
|
|
668
|
+
isouter=not inner,
|
|
669
|
+
)
|
|
670
|
+
return sqlalchemy.select(*columns).select_from(join_query)
|
|
671
|
+
|
|
672
|
+
left_right_join = sqlalchemy.select(*columns).select_from(
|
|
673
|
+
sqlalchemy.join(left, right, onclause, isouter=True)
|
|
666
674
|
)
|
|
675
|
+
right_left_join = sqlalchemy.select(*columns).select_from(
|
|
676
|
+
sqlalchemy.join(right, left, onclause, isouter=True)
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
def add_left_rows_filter(exp: BinaryExpression):
|
|
680
|
+
"""
|
|
681
|
+
Adds filter to right_left_join to remove unmatched left table rows by
|
|
682
|
+
getting column names that need to be NULL from BinaryExpressions in onclause
|
|
683
|
+
"""
|
|
684
|
+
return right_left_join.where(
|
|
685
|
+
getattr(left.c, exp.left.name) == None # type: ignore[union-attr] # noqa: E711
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
if isinstance(onclause, BinaryExpression):
|
|
689
|
+
right_left_join = add_left_rows_filter(onclause)
|
|
690
|
+
|
|
691
|
+
if isinstance(onclause, BooleanClauseList):
|
|
692
|
+
for c in onclause.get_children():
|
|
693
|
+
if isinstance(c, BinaryExpression):
|
|
694
|
+
right_left_join = add_left_rows_filter(c)
|
|
695
|
+
|
|
696
|
+
union = sqlalchemy.union(left_right_join, right_left_join).subquery()
|
|
697
|
+
return sqlalchemy.select(*union.c).select_from(union)
|
|
667
698
|
|
|
668
699
|
def create_pre_udf_table(self, query: "Select") -> "Table":
|
|
669
700
|
"""
|
|
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
|
|
31
31
|
_FromClauseArgument,
|
|
32
32
|
_OnClauseArgument,
|
|
33
33
|
)
|
|
34
|
-
from sqlalchemy.sql.selectable import
|
|
34
|
+
from sqlalchemy.sql.selectable import Select
|
|
35
35
|
from sqlalchemy.types import TypeEngine
|
|
36
36
|
|
|
37
37
|
from datachain.data_storage import schema
|
|
@@ -873,7 +873,7 @@ class AbstractWarehouse(ABC, Serializable):
|
|
|
873
873
|
right: "_FromClauseArgument",
|
|
874
874
|
onclause: "_OnClauseArgument",
|
|
875
875
|
inner: bool = True,
|
|
876
|
-
) -> "
|
|
876
|
+
) -> "Select":
|
|
877
877
|
"""
|
|
878
878
|
Join two tables together.
|
|
879
879
|
"""
|
datachain/dataset.py
CHANGED
|
@@ -91,7 +91,7 @@ class DatasetDependency:
|
|
|
91
91
|
if self.type == DatasetDependencyType.DATASET:
|
|
92
92
|
return self.name
|
|
93
93
|
|
|
94
|
-
list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"),
|
|
94
|
+
list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), {})
|
|
95
95
|
assert list_dataset_name
|
|
96
96
|
return list_dataset_name
|
|
97
97
|
|
datachain/error.py
CHANGED
|
@@ -1,3 +1,15 @@
|
|
|
1
|
+
import botocore.errorfactory
|
|
2
|
+
import botocore.exceptions
|
|
3
|
+
import gcsfs.retry
|
|
4
|
+
|
|
5
|
+
REMOTE_ERRORS = (
|
|
6
|
+
gcsfs.retry.HttpError, # GCS
|
|
7
|
+
OSError, # GCS
|
|
8
|
+
botocore.exceptions.BotoCoreError, # S3
|
|
9
|
+
ValueError, # Azure
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
1
13
|
class DataChainError(RuntimeError):
|
|
2
14
|
pass
|
|
3
15
|
|
datachain/func/__init__.py
CHANGED
|
@@ -16,7 +16,7 @@ from .aggregate import (
|
|
|
16
16
|
sum,
|
|
17
17
|
)
|
|
18
18
|
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
|
|
19
|
-
from .conditional import case, greatest, ifelse, least
|
|
19
|
+
from .conditional import case, greatest, ifelse, isnone, least
|
|
20
20
|
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
|
|
21
21
|
from .random import rand
|
|
22
22
|
from .string import byte_hamming_distance
|
|
@@ -42,6 +42,7 @@ __all__ = [
|
|
|
42
42
|
"greatest",
|
|
43
43
|
"ifelse",
|
|
44
44
|
"int_hash_64",
|
|
45
|
+
"isnone",
|
|
45
46
|
"least",
|
|
46
47
|
"length",
|
|
47
48
|
"literal",
|
datachain/func/conditional.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
|
-
from typing import Union
|
|
1
|
+
from typing import Optional, Union
|
|
2
2
|
|
|
3
|
+
from sqlalchemy import ColumnElement
|
|
3
4
|
from sqlalchemy import case as sql_case
|
|
4
|
-
from sqlalchemy.sql.elements import BinaryExpression
|
|
5
5
|
|
|
6
6
|
from datachain.lib.utils import DataChainParamsError
|
|
7
|
+
from datachain.query.schema import Column
|
|
7
8
|
from datachain.sql.functions import conditional
|
|
8
9
|
|
|
9
10
|
from .func import ColT, Func
|
|
10
11
|
|
|
11
|
-
CaseT = Union[int, float, complex, bool, str]
|
|
12
|
+
CaseT = Union[int, float, complex, bool, str, Func]
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def greatest(*args: Union[ColT, float]) -> Func:
|
|
@@ -87,17 +88,21 @@ def least(*args: Union[ColT, float]) -> Func:
|
|
|
87
88
|
)
|
|
88
89
|
|
|
89
90
|
|
|
90
|
-
def case(
|
|
91
|
+
def case(
|
|
92
|
+
*args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
|
|
93
|
+
) -> Func:
|
|
91
94
|
"""
|
|
92
95
|
Returns the case function that produces case expression which has a list of
|
|
93
|
-
conditions and corresponding results. Results can
|
|
94
|
-
|
|
96
|
+
conditions and corresponding results. Results can be python primitives like string,
|
|
97
|
+
numbers or booleans but can also be other nested function (including case function).
|
|
98
|
+
Result type is inferred from condition results.
|
|
95
99
|
|
|
96
100
|
Args:
|
|
97
|
-
args (tuple(
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
+
args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))):
|
|
102
|
+
Tuple of condition and values pair.
|
|
103
|
+
else_ (str | int | float | complex | bool, Func): optional else value in case
|
|
104
|
+
expression. If omitted, and no case conditions are satisfied, the result
|
|
105
|
+
will be None (NULL in DB).
|
|
101
106
|
|
|
102
107
|
Returns:
|
|
103
108
|
Func: A Func object that represents the case function.
|
|
@@ -111,15 +116,24 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
|
|
|
111
116
|
"""
|
|
112
117
|
supported_types = [int, float, complex, str, bool]
|
|
113
118
|
|
|
114
|
-
|
|
119
|
+
def _get_type(val):
|
|
120
|
+
if isinstance(val, Func):
|
|
121
|
+
# nested functions
|
|
122
|
+
return val.result_type
|
|
123
|
+
return type(val)
|
|
115
124
|
|
|
116
125
|
if not args:
|
|
117
126
|
raise DataChainParamsError("Missing statements")
|
|
118
127
|
|
|
128
|
+
type_ = _get_type(else_) if else_ is not None else None
|
|
129
|
+
|
|
119
130
|
for arg in args:
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
131
|
+
arg_type = _get_type(arg[1])
|
|
132
|
+
if type_ and arg_type != type_:
|
|
133
|
+
raise DataChainParamsError(
|
|
134
|
+
f"Statement values must be of the same type, got {type_} and {arg_type}"
|
|
135
|
+
)
|
|
136
|
+
type_ = arg_type
|
|
123
137
|
|
|
124
138
|
if type_ not in supported_types:
|
|
125
139
|
raise DataChainParamsError(
|
|
@@ -127,20 +141,25 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
|
|
|
127
141
|
)
|
|
128
142
|
|
|
129
143
|
kwargs = {"else_": else_}
|
|
130
|
-
|
|
144
|
+
|
|
145
|
+
return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)
|
|
131
146
|
|
|
132
147
|
|
|
133
|
-
def ifelse(
|
|
148
|
+
def ifelse(
|
|
149
|
+
condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT
|
|
150
|
+
) -> Func:
|
|
134
151
|
"""
|
|
135
152
|
Returns the ifelse function that produces if expression which has a condition
|
|
136
|
-
and values for true and false outcome. Results can
|
|
137
|
-
like string,
|
|
153
|
+
and values for true and false outcome. Results can be one of python primitives
|
|
154
|
+
like string, numbers or booleans, but can also be nested functions.
|
|
155
|
+
Result type is inferred from the values.
|
|
138
156
|
|
|
139
157
|
Args:
|
|
140
|
-
condition
|
|
141
|
-
if_val
|
|
142
|
-
|
|
143
|
-
|
|
158
|
+
condition (ColumnElement, Func): Condition which is evaluated.
|
|
159
|
+
if_val (str | int | float | complex | bool, Func): Value for true
|
|
160
|
+
condition outcome.
|
|
161
|
+
else_val (str | int | float | complex | bool, Func): Value for false condition
|
|
162
|
+
outcome.
|
|
144
163
|
|
|
145
164
|
Returns:
|
|
146
165
|
Func: A Func object that represents the ifelse function.
|
|
@@ -148,8 +167,33 @@ def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
|
|
|
148
167
|
Example:
|
|
149
168
|
```py
|
|
150
169
|
dc.mutate(
|
|
151
|
-
res=func.ifelse(
|
|
170
|
+
res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY")
|
|
152
171
|
)
|
|
153
172
|
```
|
|
154
173
|
"""
|
|
155
174
|
return case((condition, if_val), else_=else_val)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def isnone(col: Union[str, Column]) -> Func:
|
|
178
|
+
"""
|
|
179
|
+
Returns True if column value is None, otherwise False.
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
col (str | Column): Column to check if it's None or not.
|
|
183
|
+
If a string is provided, it is assumed to be the name of the column.
|
|
184
|
+
|
|
185
|
+
Returns:
|
|
186
|
+
Func: A Func object that represents the conditional to check if column is None.
|
|
187
|
+
|
|
188
|
+
Example:
|
|
189
|
+
```py
|
|
190
|
+
dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"))
|
|
191
|
+
```
|
|
192
|
+
"""
|
|
193
|
+
from datachain import C
|
|
194
|
+
|
|
195
|
+
if isinstance(col, str):
|
|
196
|
+
# if string, it is assumed to be the name of the column
|
|
197
|
+
col = C(col)
|
|
198
|
+
|
|
199
|
+
return case((col.is_(None) if col is not None else True, True), else_=False)
|
datachain/func/func.py
CHANGED
|
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|
|
23
23
|
from .window import Window
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
ColT = Union[str, ColumnElement, "Func"]
|
|
26
|
+
ColT = Union[str, ColumnElement, "Func", tuple]
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class Func(Function):
|
|
@@ -78,7 +78,7 @@ class Func(Function):
|
|
|
78
78
|
return (
|
|
79
79
|
[
|
|
80
80
|
col
|
|
81
|
-
if isinstance(col, (Func, BindParameter, Case, Comparator))
|
|
81
|
+
if isinstance(col, (Func, BindParameter, Case, Comparator, tuple))
|
|
82
82
|
else ColumnMeta.to_db_name(
|
|
83
83
|
col.name if isinstance(col, ColumnElement) else col
|
|
84
84
|
)
|
|
@@ -381,17 +381,24 @@ class Func(Function):
|
|
|
381
381
|
col_type = self.get_result_type(signals_schema)
|
|
382
382
|
sql_type = python_to_sql(col_type)
|
|
383
383
|
|
|
384
|
-
def get_col(col: ColT) -> ColT:
|
|
384
|
+
def get_col(col: ColT, string_as_literal=False) -> ColT:
|
|
385
|
+
# string_as_literal is used only for conditionals like `case()` where
|
|
386
|
+
# literals are nested inside ColT as we have tuples of condition - values
|
|
387
|
+
# and if user wants to set some case value as column, explicit `C("col")`
|
|
388
|
+
# syntax must be used to distinguish from literals
|
|
389
|
+
if isinstance(col, tuple):
|
|
390
|
+
return tuple(get_col(x, string_as_literal=True) for x in col)
|
|
385
391
|
if isinstance(col, Func):
|
|
386
392
|
return col.get_column(signals_schema, table=table)
|
|
387
|
-
if isinstance(col, str):
|
|
393
|
+
if isinstance(col, str) and not string_as_literal:
|
|
388
394
|
column = Column(col, sql_type)
|
|
389
395
|
column.table = table
|
|
390
396
|
return column
|
|
391
397
|
return col
|
|
392
398
|
|
|
393
399
|
cols = [get_col(col) for col in self._db_cols]
|
|
394
|
-
|
|
400
|
+
kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
|
|
401
|
+
func_col = self.inner(*cols, *self.args, **kwargs)
|
|
395
402
|
|
|
396
403
|
if self.is_window:
|
|
397
404
|
if not self.window:
|
|
@@ -416,6 +423,11 @@ class Func(Function):
|
|
|
416
423
|
|
|
417
424
|
|
|
418
425
|
def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
|
|
426
|
+
if isinstance(col, tuple):
|
|
427
|
+
raise DataChainParamsError(
|
|
428
|
+
"Cannot get type from tuple, please provide type hint to the function"
|
|
429
|
+
)
|
|
430
|
+
|
|
419
431
|
if isinstance(col, Func):
|
|
420
432
|
return col.get_result_type(signals_schema)
|
|
421
433
|
|
|
@@ -52,15 +52,15 @@ def python_to_sql(typ): # noqa: PLR0911
|
|
|
52
52
|
|
|
53
53
|
args = get_args(typ)
|
|
54
54
|
if inspect.isclass(orig) and (issubclass(list, orig) or issubclass(tuple, orig)):
|
|
55
|
-
if args is None
|
|
55
|
+
if args is None:
|
|
56
56
|
raise TypeError(f"Cannot resolve type '{typ}' for flattening features")
|
|
57
57
|
|
|
58
58
|
args0 = args[0]
|
|
59
59
|
if ModelStore.is_pydantic(args0):
|
|
60
60
|
return Array(JSON())
|
|
61
61
|
|
|
62
|
-
|
|
63
|
-
return Array(
|
|
62
|
+
list_type = list_of_args_to_type(args)
|
|
63
|
+
return Array(list_type)
|
|
64
64
|
|
|
65
65
|
if orig is Annotated:
|
|
66
66
|
# Ignoring annotations
|
|
@@ -82,6 +82,18 @@ def python_to_sql(typ): # noqa: PLR0911
|
|
|
82
82
|
raise TypeError(f"Cannot recognize type {typ}")
|
|
83
83
|
|
|
84
84
|
|
|
85
|
+
def list_of_args_to_type(args) -> SQLType:
|
|
86
|
+
first_type = python_to_sql(args[0])
|
|
87
|
+
for next_arg in args[1:]:
|
|
88
|
+
try:
|
|
89
|
+
next_type = python_to_sql(next_arg)
|
|
90
|
+
if next_type != first_type:
|
|
91
|
+
return JSON()
|
|
92
|
+
except TypeError:
|
|
93
|
+
return JSON()
|
|
94
|
+
return first_type
|
|
95
|
+
|
|
96
|
+
|
|
85
97
|
def _is_json_inside_union(orig, args) -> bool:
|
|
86
98
|
if orig == Union and len(args) >= 2:
|
|
87
99
|
# List in JSON: Union[dict, list[dict]]
|