datachain 0.8.9__py3-none-any.whl → 0.8.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

Files changed (38)
  1. datachain/cache.py +4 -4
  2. datachain/catalog/__init__.py +0 -2
  3. datachain/catalog/catalog.py +102 -138
  4. datachain/cli/__init__.py +9 -9
  5. datachain/cli/parser/__init__.py +36 -20
  6. datachain/cli/parser/job.py +1 -1
  7. datachain/cli/parser/studio.py +35 -34
  8. datachain/cli/parser/utils.py +19 -1
  9. datachain/cli/utils.py +1 -1
  10. datachain/client/fsspec.py +11 -8
  11. datachain/client/local.py +4 -4
  12. datachain/data_storage/schema.py +1 -1
  13. datachain/data_storage/sqlite.py +38 -7
  14. datachain/data_storage/warehouse.py +2 -2
  15. datachain/dataset.py +1 -1
  16. datachain/error.py +12 -0
  17. datachain/func/__init__.py +2 -1
  18. datachain/func/conditional.py +67 -23
  19. datachain/func/func.py +17 -5
  20. datachain/lib/convert/python_to_sql.py +15 -3
  21. datachain/lib/dc.py +27 -5
  22. datachain/lib/file.py +16 -0
  23. datachain/lib/listing.py +30 -12
  24. datachain/lib/pytorch.py +1 -1
  25. datachain/lib/udf.py +1 -1
  26. datachain/listing.py +1 -13
  27. datachain/node.py +0 -15
  28. datachain/nodes_fetcher.py +2 -2
  29. datachain/query/dataset.py +8 -4
  30. datachain/remote/studio.py +3 -3
  31. datachain/sql/sqlite/base.py +35 -14
  32. datachain/studio.py +8 -8
  33. {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/METADATA +3 -7
  34. {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/RECORD +38 -38
  35. {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/LICENSE +0 -0
  36. {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/WHEEL +0 -0
  37. {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/entry_points.txt +0 -0
  38. {datachain-0.8.9.dist-info → datachain-0.8.11.dist-info}/top_level.txt +0 -0
@@ -1,31 +1,32 @@
1
- def add_studio_parser(subparsers, parent_parser) -> None:
2
- studio_help = "Manage Studio authentication"
3
- studio_description = (
4
- "Manage authentication and settings for Studio. "
5
- "Configure tokens for sharing datasets and using Studio features."
6
- )
1
+ def add_auth_parser(subparsers, parent_parser) -> None:
2
+ from dvc_studio_client.auth import AVAILABLE_SCOPES
3
+
4
+ auth_help = "Manage Studio authentication"
5
+ auth_description = "Manage authentication and settings for Studio. "
7
6
 
8
- studio_parser = subparsers.add_parser(
9
- "studio",
7
+ auth_parser = subparsers.add_parser(
8
+ "auth",
10
9
  parents=[parent_parser],
11
- description=studio_description,
12
- help=studio_help,
10
+ description=auth_description,
11
+ help=auth_help,
13
12
  )
14
- studio_subparser = studio_parser.add_subparsers(
13
+ auth_subparser = auth_parser.add_subparsers(
15
14
  dest="cmd",
16
- help="Use `datachain studio CMD --help` to display command-specific help",
15
+ help="Use `datachain auth CMD --help` to display command-specific help",
17
16
  )
18
17
 
19
- studio_login_help = "Authenticate with Studio"
20
- studio_login_description = (
18
+ auth_login_help = "Authenticate with Studio"
19
+ auth_login_description = (
21
20
  "Authenticate with Studio using default scopes. "
22
- "A random name will be assigned as the token name if not specified."
21
+ "A random name will be assigned if the token name is not specified."
23
22
  )
24
- login_parser = studio_subparser.add_parser(
23
+
24
+ allowed_scopes = ", ".join(AVAILABLE_SCOPES)
25
+ login_parser = auth_subparser.add_parser(
25
26
  "login",
26
27
  parents=[parent_parser],
27
- description=studio_login_description,
28
- help=studio_login_help,
28
+ description=auth_login_description,
29
+ help=auth_login_help,
29
30
  )
30
31
 
31
32
  login_parser.add_argument(
@@ -40,7 +41,7 @@ def add_studio_parser(subparsers, parent_parser) -> None:
40
41
  "--scopes",
41
42
  action="store",
42
43
  default=None,
43
- help="Authentication token scopes",
44
+ help=f"Authentication token scopes. Allowed scopes: {allowed_scopes}",
44
45
  )
45
46
 
46
47
  login_parser.add_argument(
@@ -58,26 +59,26 @@ def add_studio_parser(subparsers, parent_parser) -> None:
58
59
  help="Use code-based authentication without browser",
59
60
  )
60
61
 
61
- studio_logout_help = "Log out from Studio"
62
- studio_logout_description = (
62
+ auth_logout_help = "Log out from Studio"
63
+ auth_logout_description = (
63
64
  "Remove the Studio authentication token from global config."
64
65
  )
65
66
 
66
- studio_subparser.add_parser(
67
+ auth_subparser.add_parser(
67
68
  "logout",
68
69
  parents=[parent_parser],
69
- description=studio_logout_description,
70
- help=studio_logout_help,
70
+ description=auth_logout_description,
71
+ help=auth_logout_help,
71
72
  )
72
73
 
73
- studio_team_help = "Set default team for Studio operations"
74
- studio_team_description = "Set the default team for Studio operations."
74
+ auth_team_help = "Set default team for Studio operations"
75
+ auth_team_description = "Set the default team for Studio operations."
75
76
 
76
- team_parser = studio_subparser.add_parser(
77
+ team_parser = auth_subparser.add_parser(
77
78
  "team",
78
79
  parents=[parent_parser],
79
- description=studio_team_description,
80
- help=studio_team_help,
80
+ description=auth_team_description,
81
+ help=auth_team_help,
81
82
  )
82
83
  team_parser.add_argument(
83
84
  "team_name",
@@ -91,12 +92,12 @@ def add_studio_parser(subparsers, parent_parser) -> None:
91
92
  help="Set team globally for all projects",
92
93
  )
93
94
 
94
- studio_token_help = "View Studio authentication token" # noqa: S105
95
- studio_token_description = "Display the current authentication token for Studio." # noqa: S105
95
+ auth_token_help = "View Studio authentication token" # noqa: S105
96
+ auth_token_description = "Display the current authentication token for Studio." # noqa: S105
96
97
 
97
- studio_subparser.add_parser(
98
+ auth_subparser.add_parser(
98
99
  "token",
99
100
  parents=[parent_parser],
100
- description=studio_token_description,
101
- help=studio_token_help,
101
+ description=auth_token_description,
102
+ help=auth_token_help,
102
103
  )
@@ -30,7 +30,25 @@ def add_sources_arg(parser: ArgumentParser, nargs: Union[str, int] = "+") -> Act
30
30
  "sources",
31
31
  type=str,
32
32
  nargs=nargs,
33
- help="Data sources - paths to cloud storage directories",
33
+ help="Data sources - paths to source storage directories or files",
34
+ )
35
+
36
+
37
+ def add_anon_arg(parser: ArgumentParser) -> None:
38
+ parser.add_argument(
39
+ "--anon",
40
+ action="store_true",
41
+ help="Use anonymous access to storage",
42
+ )
43
+
44
+
45
+ def add_update_arg(parser: ArgumentParser) -> None:
46
+ parser.add_argument(
47
+ "-u",
48
+ "--update",
49
+ action="count",
50
+ default=0,
51
+ help="Update cached list of files for the sources",
34
52
  )
35
53
 
36
54
 
datachain/cli/utils.py CHANGED
@@ -87,7 +87,7 @@ def get_logging_level(args: Namespace) -> int:
87
87
  def determine_flavors(studio: bool, local: bool, all: bool, token: Optional[str]):
88
88
  if studio and not token:
89
89
  raise DataChainError(
90
- "Not logged in to Studio. Log in with 'datachain studio login'."
90
+ "Not logged in to Studio. Log in with 'datachain auth login'."
91
91
  )
92
92
 
93
93
  if local or studio:
@@ -3,6 +3,7 @@ import functools
3
3
  import logging
4
4
  import multiprocessing
5
5
  import os
6
+ import posixpath
6
7
  import re
7
8
  import sys
8
9
  from abc import ABC, abstractmethod
@@ -25,7 +26,7 @@ from fsspec.asyn import get_loop, sync
25
26
  from fsspec.callbacks import DEFAULT_CALLBACK, Callback
26
27
  from tqdm.auto import tqdm
27
28
 
28
- from datachain.cache import DataChainCache
29
+ from datachain.cache import Cache
29
30
  from datachain.client.fileslice import FileWrapper
30
31
  from datachain.error import ClientError as DataChainClientError
31
32
  from datachain.nodes_fetcher import NodesFetcher
@@ -74,9 +75,7 @@ class Client(ABC):
74
75
  PREFIX: ClassVar[str]
75
76
  protocol: ClassVar[str]
76
77
 
77
- def __init__(
78
- self, name: str, fs_kwargs: dict[str, Any], cache: DataChainCache
79
- ) -> None:
78
+ def __init__(self, name: str, fs_kwargs: dict[str, Any], cache: Cache) -> None:
80
79
  self.name = name
81
80
  self.fs_kwargs = fs_kwargs
82
81
  self._fs: Optional[AbstractFileSystem] = None
@@ -122,7 +121,7 @@ class Client(ABC):
122
121
  return cls.get_uri(storage_name), rel_path
123
122
 
124
123
  @staticmethod
125
- def get_client(source: str, cache: DataChainCache, **kwargs) -> "Client":
124
+ def get_client(source: str, cache: Cache, **kwargs) -> "Client":
126
125
  cls = Client.get_implementation(source)
127
126
  storage_url, _ = cls.split_url(source)
128
127
  if os.name == "nt":
@@ -145,7 +144,7 @@ class Client(ABC):
145
144
  def from_name(
146
145
  cls,
147
146
  name: str,
148
- cache: DataChainCache,
147
+ cache: Cache,
149
148
  kwargs: dict[str, Any],
150
149
  ) -> "Client":
151
150
  return cls(name, kwargs, cache)
@@ -154,7 +153,7 @@ class Client(ABC):
154
153
  def from_source(
155
154
  cls,
156
155
  uri: "StorageURI",
157
- cache: DataChainCache,
156
+ cache: Cache,
158
157
  **kwargs,
159
158
  ) -> "Client":
160
159
  return cls(cls.FS_CLASS._strip_protocol(uri), kwargs, cache)
@@ -390,8 +389,12 @@ class Client(ABC):
390
389
  self.fs.open(self.get_full_path(file.path, file.version)), cb
391
390
  ) # type: ignore[return-value]
392
391
 
393
- def upload(self, path: str, data: bytes) -> "File":
392
+ def upload(self, data: bytes, path: str) -> "File":
394
393
  full_path = self.get_full_path(path)
394
+
395
+ parent = posixpath.dirname(full_path)
396
+ self.fs.makedirs(parent, exist_ok=True)
397
+
395
398
  self.fs.pipe_file(full_path, data)
396
399
  file_info = self.fs.info(full_path)
397
400
  return self.info_to_file(file_info, path)
datachain/client/local.py CHANGED
@@ -12,7 +12,7 @@ from datachain.lib.file import File
12
12
  from .fsspec import Client
13
13
 
14
14
  if TYPE_CHECKING:
15
- from datachain.cache import DataChainCache
15
+ from datachain.cache import Cache
16
16
  from datachain.dataset import StorageURI
17
17
 
18
18
 
@@ -25,7 +25,7 @@ class FileClient(Client):
25
25
  self,
26
26
  name: str,
27
27
  fs_kwargs: dict[str, Any],
28
- cache: "DataChainCache",
28
+ cache: "Cache",
29
29
  use_symlinks: bool = False,
30
30
  ) -> None:
31
31
  super().__init__(name, fs_kwargs, cache)
@@ -82,7 +82,7 @@ class FileClient(Client):
82
82
  return bucket, path
83
83
 
84
84
  @classmethod
85
- def from_name(cls, name: str, cache: "DataChainCache", kwargs) -> "FileClient":
85
+ def from_name(cls, name: str, cache: "Cache", kwargs) -> "FileClient":
86
86
  use_symlinks = kwargs.pop("use_symlinks", False)
87
87
  return cls(name, kwargs, cache, use_symlinks=use_symlinks)
88
88
 
@@ -90,7 +90,7 @@ class FileClient(Client):
90
90
  def from_source(
91
91
  cls,
92
92
  uri: str,
93
- cache: "DataChainCache",
93
+ cache: "Cache",
94
94
  use_symlinks: bool = False,
95
95
  **kwargs,
96
96
  ) -> "FileClient":
@@ -200,7 +200,7 @@ class DataTable:
200
200
  columns: Sequence["sa.Column"] = (),
201
201
  metadata: Optional["sa.MetaData"] = None,
202
202
  ):
203
- # copy columns, since re-using the same objects from another table
203
+ # copy columns, since reusing the same objects from another table
204
204
  # may raise an error
205
205
  columns = cls.sys_columns() + [cls.copy_column(c) for c in columns]
206
206
  columns = dedup_columns(columns)
@@ -19,6 +19,7 @@ from sqlalchemy import MetaData, Table, UniqueConstraint, exists, select
19
19
  from sqlalchemy.dialects import sqlite
20
20
  from sqlalchemy.schema import CreateIndex, CreateTable, DropTable
21
21
  from sqlalchemy.sql import func
22
+ from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
22
23
  from sqlalchemy.sql.expression import bindparam, cast
23
24
  from sqlalchemy.sql.selectable import Select
24
25
  from tqdm.auto import tqdm
@@ -40,7 +41,6 @@ if TYPE_CHECKING:
40
41
  from sqlalchemy.schema import SchemaItem
41
42
  from sqlalchemy.sql._typing import _FromClauseArgument, _OnClauseArgument
42
43
  from sqlalchemy.sql.elements import ColumnElement
43
- from sqlalchemy.sql.selectable import Join
44
44
  from sqlalchemy.types import TypeEngine
45
45
 
46
46
  from datachain.lib.file import File
@@ -654,16 +654,47 @@ class SQLiteWarehouse(AbstractWarehouse):
654
654
  right: "_FromClauseArgument",
655
655
  onclause: "_OnClauseArgument",
656
656
  inner: bool = True,
657
- ) -> "Join":
657
+ full: bool = False,
658
+ columns=None,
659
+ ) -> "Select":
658
660
  """
659
661
  Join two tables together.
660
662
  """
661
- return sqlalchemy.join(
662
- left,
663
- right,
664
- onclause,
665
- isouter=not inner,
663
+ if not full:
664
+ join_query = sqlalchemy.join(
665
+ left,
666
+ right,
667
+ onclause,
668
+ isouter=not inner,
669
+ )
670
+ return sqlalchemy.select(*columns).select_from(join_query)
671
+
672
+ left_right_join = sqlalchemy.select(*columns).select_from(
673
+ sqlalchemy.join(left, right, onclause, isouter=True)
666
674
  )
675
+ right_left_join = sqlalchemy.select(*columns).select_from(
676
+ sqlalchemy.join(right, left, onclause, isouter=True)
677
+ )
678
+
679
+ def add_left_rows_filter(exp: BinaryExpression):
680
+ """
681
+ Adds filter to right_left_join to remove unmatched left table rows by
682
+ getting column names that need to be NULL from BinaryExpressions in onclause
683
+ """
684
+ return right_left_join.where(
685
+ getattr(left.c, exp.left.name) == None # type: ignore[union-attr] # noqa: E711
686
+ )
687
+
688
+ if isinstance(onclause, BinaryExpression):
689
+ right_left_join = add_left_rows_filter(onclause)
690
+
691
+ if isinstance(onclause, BooleanClauseList):
692
+ for c in onclause.get_children():
693
+ if isinstance(c, BinaryExpression):
694
+ right_left_join = add_left_rows_filter(c)
695
+
696
+ union = sqlalchemy.union(left_right_join, right_left_join).subquery()
697
+ return sqlalchemy.select(*union.c).select_from(union)
667
698
 
668
699
  def create_pre_udf_table(self, query: "Select") -> "Table":
669
700
  """
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
31
31
  _FromClauseArgument,
32
32
  _OnClauseArgument,
33
33
  )
34
- from sqlalchemy.sql.selectable import Join, Select
34
+ from sqlalchemy.sql.selectable import Select
35
35
  from sqlalchemy.types import TypeEngine
36
36
 
37
37
  from datachain.data_storage import schema
@@ -873,7 +873,7 @@ class AbstractWarehouse(ABC, Serializable):
873
873
  right: "_FromClauseArgument",
874
874
  onclause: "_OnClauseArgument",
875
875
  inner: bool = True,
876
- ) -> "Join":
876
+ ) -> "Select":
877
877
  """
878
878
  Join two tables together.
879
879
  """
datachain/dataset.py CHANGED
@@ -91,7 +91,7 @@ class DatasetDependency:
91
91
  if self.type == DatasetDependencyType.DATASET:
92
92
  return self.name
93
93
 
94
- list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), None, {})
94
+ list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), {})
95
95
  assert list_dataset_name
96
96
  return list_dataset_name
97
97
 
datachain/error.py CHANGED
@@ -1,3 +1,15 @@
1
+ import botocore.errorfactory
2
+ import botocore.exceptions
3
+ import gcsfs.retry
4
+
5
+ REMOTE_ERRORS = (
6
+ gcsfs.retry.HttpError, # GCS
7
+ OSError, # GCS
8
+ botocore.exceptions.BotoCoreError, # S3
9
+ ValueError, # Azure
10
+ )
11
+
12
+
1
13
  class DataChainError(RuntimeError):
2
14
  pass
3
15
 
@@ -16,7 +16,7 @@ from .aggregate import (
16
16
  sum,
17
17
  )
18
18
  from .array import cosine_distance, euclidean_distance, length, sip_hash_64
19
- from .conditional import case, greatest, ifelse, least
19
+ from .conditional import case, greatest, ifelse, isnone, least
20
20
  from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
21
21
  from .random import rand
22
22
  from .string import byte_hamming_distance
@@ -42,6 +42,7 @@ __all__ = [
42
42
  "greatest",
43
43
  "ifelse",
44
44
  "int_hash_64",
45
+ "isnone",
45
46
  "least",
46
47
  "length",
47
48
  "literal",
@@ -1,14 +1,15 @@
1
- from typing import Union
1
+ from typing import Optional, Union
2
2
 
3
+ from sqlalchemy import ColumnElement
3
4
  from sqlalchemy import case as sql_case
4
- from sqlalchemy.sql.elements import BinaryExpression
5
5
 
6
6
  from datachain.lib.utils import DataChainParamsError
7
+ from datachain.query.schema import Column
7
8
  from datachain.sql.functions import conditional
8
9
 
9
10
  from .func import ColT, Func
10
11
 
11
- CaseT = Union[int, float, complex, bool, str]
12
+ CaseT = Union[int, float, complex, bool, str, Func]
12
13
 
13
14
 
14
15
  def greatest(*args: Union[ColT, float]) -> Func:
@@ -87,17 +88,21 @@ def least(*args: Union[ColT, float]) -> Func:
87
88
  )
88
89
 
89
90
 
90
- def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
91
+ def case(
92
+ *args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
93
+ ) -> Func:
91
94
  """
92
95
  Returns the case function that produces case expression which has a list of
93
- conditions and corresponding results. Results can only be python primitives
94
- like string, numbes or booleans. Result type is inferred from condition results.
96
+ conditions and corresponding results. Results can be python primitives like string,
97
+ numbers or booleans but can also be other nested function (including case function).
98
+ Result type is inferred from condition results.
95
99
 
96
100
  Args:
97
- args (tuple(BinaryExpression, value(str | int | float | complex | bool):
98
- - Tuple of binary expression and values pair which corresponds to one
99
- case condition - value
100
- else_ (str | int | float | complex | bool): else value in case expression
101
+ args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))):
102
+ Tuple of condition and values pair.
103
+ else_ (str | int | float | complex | bool, Func): optional else value in case
104
+ expression. If omitted, and no case conditions are satisfied, the result
105
+ will be None (NULL in DB).
101
106
 
102
107
  Returns:
103
108
  Func: A Func object that represents the case function.
@@ -111,15 +116,24 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
111
116
  """
112
117
  supported_types = [int, float, complex, str, bool]
113
118
 
114
- type_ = type(else_) if else_ else None
119
+ def _get_type(val):
120
+ if isinstance(val, Func):
121
+ # nested functions
122
+ return val.result_type
123
+ return type(val)
115
124
 
116
125
  if not args:
117
126
  raise DataChainParamsError("Missing statements")
118
127
 
128
+ type_ = _get_type(else_) if else_ is not None else None
129
+
119
130
  for arg in args:
120
- if type_ and not isinstance(arg[1], type_):
121
- raise DataChainParamsError("Statement values must be of the same type")
122
- type_ = type(arg[1])
131
+ arg_type = _get_type(arg[1])
132
+ if type_ and arg_type != type_:
133
+ raise DataChainParamsError(
134
+ f"Statement values must be of the same type, got {type_} and {arg_type}"
135
+ )
136
+ type_ = arg_type
123
137
 
124
138
  if type_ not in supported_types:
125
139
  raise DataChainParamsError(
@@ -127,20 +141,25 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
127
141
  )
128
142
 
129
143
  kwargs = {"else_": else_}
130
- return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)
144
+
145
+ return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)
131
146
 
132
147
 
133
- def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
148
+ def ifelse(
149
+ condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT
150
+ ) -> Func:
134
151
  """
135
152
  Returns the ifelse function that produces if expression which has a condition
136
- and values for true and false outcome. Results can only be python primitives
137
- like string, numbes or booleans. Result type is inferred from the values.
153
+ and values for true and false outcome. Results can be one of python primitives
154
+ like string, numbers or booleans, but can also be nested functions.
155
+ Result type is inferred from the values.
138
156
 
139
157
  Args:
140
- condition: BinaryExpression - condition which is evaluated
141
- if_val: (str | int | float | complex | bool): value for true condition outcome
142
- else_val: (str | int | float | complex | bool): value for false condition
143
- outcome
158
+ condition (ColumnElement, Func): Condition which is evaluated.
159
+ if_val (str | int | float | complex | bool, Func): Value for true
160
+ condition outcome.
161
+ else_val (str | int | float | complex | bool, Func): Value for false condition
162
+ outcome.
144
163
 
145
164
  Returns:
146
165
  Func: A Func object that represents the ifelse function.
@@ -148,8 +167,33 @@ def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
148
167
  Example:
149
168
  ```py
150
169
  dc.mutate(
151
- res=func.ifelse(C("num") > 0, "P", "N"),
170
+ res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY")
152
171
  )
153
172
  ```
154
173
  """
155
174
  return case((condition, if_val), else_=else_val)
175
+
176
+
177
+ def isnone(col: Union[str, Column]) -> Func:
178
+ """
179
+ Returns True if column value is None, otherwise False.
180
+
181
+ Args:
182
+ col (str | Column): Column to check if it's None or not.
183
+ If a string is provided, it is assumed to be the name of the column.
184
+
185
+ Returns:
186
+ Func: A Func object that represents the conditional to check if column is None.
187
+
188
+ Example:
189
+ ```py
190
+ dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"))
191
+ ```
192
+ """
193
+ from datachain import C
194
+
195
+ if isinstance(col, str):
196
+ # if string, it is assumed to be the name of the column
197
+ col = C(col)
198
+
199
+ return case((col.is_(None) if col is not None else True, True), else_=False)
datachain/func/func.py CHANGED
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
23
23
  from .window import Window
24
24
 
25
25
 
26
- ColT = Union[str, ColumnElement, "Func"]
26
+ ColT = Union[str, ColumnElement, "Func", tuple]
27
27
 
28
28
 
29
29
  class Func(Function):
@@ -78,7 +78,7 @@ class Func(Function):
78
78
  return (
79
79
  [
80
80
  col
81
- if isinstance(col, (Func, BindParameter, Case, Comparator))
81
+ if isinstance(col, (Func, BindParameter, Case, Comparator, tuple))
82
82
  else ColumnMeta.to_db_name(
83
83
  col.name if isinstance(col, ColumnElement) else col
84
84
  )
@@ -381,17 +381,24 @@ class Func(Function):
381
381
  col_type = self.get_result_type(signals_schema)
382
382
  sql_type = python_to_sql(col_type)
383
383
 
384
- def get_col(col: ColT) -> ColT:
384
+ def get_col(col: ColT, string_as_literal=False) -> ColT:
385
+ # string_as_literal is used only for conditionals like `case()` where
386
+ # literals are nested inside ColT as we have tuples of condition - values
387
+ # and if user wants to set some case value as column, explicit `C("col")`
388
+ # syntax must be used to distinguish from literals
389
+ if isinstance(col, tuple):
390
+ return tuple(get_col(x, string_as_literal=True) for x in col)
385
391
  if isinstance(col, Func):
386
392
  return col.get_column(signals_schema, table=table)
387
- if isinstance(col, str):
393
+ if isinstance(col, str) and not string_as_literal:
388
394
  column = Column(col, sql_type)
389
395
  column.table = table
390
396
  return column
391
397
  return col
392
398
 
393
399
  cols = [get_col(col) for col in self._db_cols]
394
- func_col = self.inner(*cols, *self.args, **self.kwargs)
400
+ kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
401
+ func_col = self.inner(*cols, *self.args, **kwargs)
395
402
 
396
403
  if self.is_window:
397
404
  if not self.window:
@@ -416,6 +423,11 @@ class Func(Function):
416
423
 
417
424
 
418
425
  def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
426
+ if isinstance(col, tuple):
427
+ raise DataChainParamsError(
428
+ "Cannot get type from tuple, please provide type hint to the function"
429
+ )
430
+
419
431
  if isinstance(col, Func):
420
432
  return col.get_result_type(signals_schema)
421
433
 
@@ -52,15 +52,15 @@ def python_to_sql(typ): # noqa: PLR0911
52
52
 
53
53
  args = get_args(typ)
54
54
  if inspect.isclass(orig) and (issubclass(list, orig) or issubclass(tuple, orig)):
55
- if args is None or len(args) != 1:
55
+ if args is None:
56
56
  raise TypeError(f"Cannot resolve type '{typ}' for flattening features")
57
57
 
58
58
  args0 = args[0]
59
59
  if ModelStore.is_pydantic(args0):
60
60
  return Array(JSON())
61
61
 
62
- next_type = python_to_sql(args0)
63
- return Array(next_type)
62
+ list_type = list_of_args_to_type(args)
63
+ return Array(list_type)
64
64
 
65
65
  if orig is Annotated:
66
66
  # Ignoring annotations
@@ -82,6 +82,18 @@ def python_to_sql(typ): # noqa: PLR0911
82
82
  raise TypeError(f"Cannot recognize type {typ}")
83
83
 
84
84
 
85
+ def list_of_args_to_type(args) -> SQLType:
86
+ first_type = python_to_sql(args[0])
87
+ for next_arg in args[1:]:
88
+ try:
89
+ next_type = python_to_sql(next_arg)
90
+ if next_type != first_type:
91
+ return JSON()
92
+ except TypeError:
93
+ return JSON()
94
+ return first_type
95
+
96
+
85
97
  def _is_json_inside_union(orig, args) -> bool:
86
98
  if orig == Union and len(args) >= 2:
87
99
  # List in JSON: Union[dict, list[dict]]