datachain 0.8.13__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic; the original page linked to further details.

@@ -4,11 +4,9 @@ from collections.abc import Sequence
4
4
  from enum import Enum
5
5
  from typing import TYPE_CHECKING, Optional, Union
6
6
 
7
- import sqlalchemy as sa
8
-
7
+ from datachain.func import case, ifelse, isnone, or_
9
8
  from datachain.lib.signal_schema import SignalSchema
10
9
  from datachain.query.schema import Column
11
- from datachain.sql.types import String
12
10
 
13
11
  if TYPE_CHECKING:
14
12
  from datachain.lib.dc import DataChain
@@ -32,7 +30,7 @@ class CompareStatus(str, Enum):
32
30
  SAME = "S"
33
31
 
34
32
 
35
- def _compare( # noqa: PLR0912, PLR0915, C901
33
+ def _compare( # noqa: C901
36
34
  left: "DataChain",
37
35
  right: "DataChain",
38
36
  on: Union[str, Sequence[str]],
@@ -47,63 +45,46 @@ def _compare( # noqa: PLR0912, PLR0915, C901
47
45
  ) -> "DataChain":
48
46
  """Comparing two chains by identifying rows that are added, deleted, modified
49
47
  or same"""
50
- dialect = left._query.dialect
51
-
52
48
  rname = "right_"
49
+ schema = left.signals_schema # final chain must have schema from left chain
53
50
 
54
- def _rprefix(c: str, rc: str) -> str:
55
- """Returns prefix of right of two companion left - right columns
56
- from merge. If companion columns have the same name then prefix will
57
- be present in right column name, otherwise it won't.
58
- """
59
- return rname if c == rc else ""
60
-
61
- def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
51
+ def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
52
+ if obj is None:
53
+ return None
62
54
  return [obj] if isinstance(obj, str) else list(obj)
63
55
 
64
- if on is None:
65
- raise ValueError("'on' must be specified")
66
-
67
- on = _to_list(on)
68
- if right_on:
69
- right_on = _to_list(right_on)
70
- if len(on) != len(right_on):
71
- raise ValueError("'on' and 'right_on' must be have the same length")
72
-
73
- if compare:
74
- compare = _to_list(compare)
75
-
76
- if right_compare:
77
- if not compare:
78
- raise ValueError("'compare' must be defined if 'right_compare' is defined")
79
-
80
- right_compare = _to_list(right_compare)
81
- if len(compare) != len(right_compare):
82
- raise ValueError(
83
- "'compare' and 'right_compare' must be have the same length"
84
- )
56
+ on = _to_list(on) # type: ignore[assignment]
57
+ right_on = _to_list(right_on)
58
+ compare = _to_list(compare)
59
+ right_compare = _to_list(right_compare)
85
60
 
86
61
  if not any([added, deleted, modified, same]):
87
62
  raise ValueError(
88
63
  "At least one of added, deleted, modified, same flags must be set"
89
64
  )
90
-
91
- need_status_col = bool(status_col)
92
- # we still need status column for internal implementation even if not
93
- # needed in the output
94
- status_col = status_col or get_status_col_name()
95
-
96
- # calculate on and compare column names
97
- right_on = right_on or on
65
+ if on is None:
66
+ raise ValueError("'on' must be specified")
67
+ if right_on and len(on) != len(right_on):
68
+ raise ValueError("'on' and 'right_on' must be have the same length")
69
+ if right_compare and not compare:
70
+ raise ValueError("'compare' must be defined if 'right_compare' is defined")
71
+ if compare and right_compare and len(compare) != len(right_compare):
72
+ raise ValueError("'compare' and 'right_compare' must have the same length")
73
+
74
+ # all left and right columns
98
75
  cols = left.signals_schema.clone_without_sys_signals().db_signals()
99
76
  right_cols = right.signals_schema.clone_without_sys_signals().db_signals()
100
77
 
78
+ # getting correct on and right_on column names
101
79
  on = left.signals_schema.resolve(*on).db_signals() # type: ignore[assignment]
102
- right_on = right.signals_schema.resolve(*right_on).db_signals() # type: ignore[assignment]
80
+ right_on = right.signals_schema.resolve(*(right_on or on)).db_signals() # type: ignore[assignment]
81
+
82
+ # getting correct compare and right_compare column names if they are defined
103
83
  if compare:
104
- right_compare = right_compare or compare
105
84
  compare = left.signals_schema.resolve(*compare).db_signals() # type: ignore[assignment]
106
- right_compare = right.signals_schema.resolve(*right_compare).db_signals() # type: ignore[assignment]
85
+ right_compare = right.signals_schema.resolve(
86
+ *(right_compare or compare)
87
+ ).db_signals() # type: ignore[assignment]
107
88
  elif not compare and len(cols) != len(right_cols):
108
89
  # here we will mark all rows that are not added or deleted as modified since
109
90
  # there was no explicit list of compare columns provided (meaning we need
@@ -113,103 +94,72 @@ def _compare( # noqa: PLR0912, PLR0915, C901
113
94
  compare = None
114
95
  right_compare = None
115
96
  else:
116
- compare = [c for c in cols if c in right_cols] # type: ignore[misc, assignment]
117
- right_compare = compare
97
+ # we are checking all columns as explicit compare is not defined
98
+ compare = right_compare = [c for c in cols if c in right_cols and c not in on] # type: ignore[misc]
118
99
 
119
- diff_cond = []
100
+ # get diff column names
101
+ diff_col = status_col or get_status_col_name()
102
+ ldiff_col = get_status_col_name()
103
+ rdiff_col = get_status_col_name()
120
104
 
121
- if added:
122
- added_cond = sa.and_(
123
- *[
124
- C(c) == None # noqa: E711
125
- for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
126
- ]
127
- )
128
- diff_cond.append((added_cond, CompareStatus.ADDED))
129
- if modified and compare:
130
- modified_cond = sa.or_(
131
- *[
132
- C(c) != C(f"{_rprefix(c, rc)}{rc}")
133
- for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
134
- ]
135
- )
136
- diff_cond.append((modified_cond, CompareStatus.MODIFIED))
137
- if same and compare:
138
- same_cond = sa.and_(
105
+ # adding helper diff columns, which will be removed after
106
+ left = left.mutate(**{ldiff_col: 1})
107
+ right = right.mutate(**{rdiff_col: 1})
108
+
109
+ if not compare:
110
+ modified_cond = True
111
+ else:
112
+ modified_cond = or_( # type: ignore[assignment]
139
113
  *[
140
- C(c) == C(f"{_rprefix(c, rc)}{rc}")
114
+ C(c) != (C(f"{rname}{rc}") if c == rc else C(rc))
141
115
  for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
142
116
  ]
143
117
  )
144
- diff_cond.append((same_cond, CompareStatus.SAME))
145
-
146
- diff = sa.case(*diff_cond, else_=None if compare else CompareStatus.MODIFIED).label(
147
- status_col
148
- )
149
- diff.type = String()
150
-
151
- left_right_merge = left.merge(
152
- right, on=on, right_on=right_on, inner=False, rname=rname
153
- )
154
- left_right_merge_select = left_right_merge._query.select(
155
- *(
156
- [C(c) for c in left_right_merge.signals_schema.db_signals("sys")]
157
- + [C(c) for c in on]
158
- + [C(c) for c in cols if c not in on]
159
- + [diff]
160
- )
161
- )
162
-
163
- diff_col = sa.literal(CompareStatus.DELETED).label(status_col)
164
- diff_col.type = String()
165
118
 
166
- right_left_merge = right.merge(
167
- left, on=right_on, right_on=on, inner=False, rname=rname
168
- ).filter(
169
- sa.and_(
170
- *[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)] # noqa: E711
119
+ dc_diff = (
120
+ left.merge(right, on=on, right_on=right_on, rname=rname, full=True)
121
+ .mutate(
122
+ **{
123
+ diff_col: case(
124
+ (isnone(ldiff_col), CompareStatus.DELETED),
125
+ (isnone(rdiff_col), CompareStatus.ADDED),
126
+ (modified_cond, CompareStatus.MODIFIED),
127
+ else_=CompareStatus.SAME,
128
+ )
129
+ }
171
130
  )
172
- )
173
-
174
- def _default_val(chain: "DataChain", col: str):
175
- col_type = chain._query.column_types[col] # type: ignore[index]
176
- val = sa.literal(col_type.default_value(dialect)).label(col)
177
- val.type = col_type()
178
- return val
179
-
180
- right_left_merge_select = right_left_merge._query.select(
181
- *(
182
- [C(c) for c in right_left_merge.signals_schema.db_signals("sys")]
183
- + [
184
- C(c) if c == rc else _default_val(left, c)
185
- for c, rc in zip(on, right_on)
186
- ]
187
- + [
188
- C(c) if c in right_cols else _default_val(left, c) # type: ignore[arg-type]
189
- for c in cols
190
- if c not in on
191
- ]
192
- + [diff_col]
131
+ # when the row is deleted, we need to take column values from the right chain
132
+ .mutate(
133
+ **{
134
+ f"{c}": ifelse(
135
+ C(diff_col) == CompareStatus.DELETED, C(f"{rname}{c}"), C(c)
136
+ )
137
+ for c in [c for c in cols if c in right_cols]
138
+ }
193
139
  )
140
+ .select_except(ldiff_col, rdiff_col)
194
141
  )
195
142
 
143
+ if not added:
144
+ dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.ADDED)
145
+ if not modified:
146
+ dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.MODIFIED)
147
+ if not same:
148
+ dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.SAME)
196
149
  if not deleted:
197
- res = left_right_merge_select
198
- elif deleted and not any([added, modified, same]):
199
- res = right_left_merge_select
200
- else:
201
- res = left_right_merge_select.union(right_left_merge_select)
150
+ dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.DELETED)
202
151
 
203
- res = res.filter(C(status_col) != None) # noqa: E711
152
+ if status_col:
153
+ cols.append(diff_col) # type: ignore[arg-type]
204
154
 
205
- schema = left.signals_schema
206
- if need_status_col:
207
- res = res.select()
208
- schema = SignalSchema({status_col: str}) | schema
209
- else:
210
- res = res.select_except(C(status_col))
155
+ dc_diff = dc_diff.select(*cols)
156
+
157
+ # final schema is schema from the left chain with status column added if needed
158
+ dc_diff.signals_schema = (
159
+ schema if not status_col else SignalSchema({status_col: str}) | schema
160
+ )
211
161
 
212
- return left._evolve(query=res, signal_schema=schema)
162
+ return dc_diff
213
163
 
214
164
 
215
165
  def compare_and_split(
@@ -0,0 +1,21 @@
1
+ import fsspec
2
+ from packaging.version import Version, parse
3
+
4
+ # fsspec==2025.2.0 added support for a proper `open()` in `ReferenceFileSystem`.
5
+ # Remove this module when `fsspec` minimum version requirement can be bumped.
6
+ if parse(fsspec.__version__) < Version("2025.2.0"):
7
+ from fsspec.core import split_protocol
8
+ from fsspec.implementations import reference
9
+
10
+ class ReferenceFileSystem(reference.ReferenceFileSystem):
11
+ def _open(self, path, mode="rb", *args, **kwargs):
12
+ # overriding because `fsspec`'s `ReferenceFileSystem._open`
13
+ # reads the whole file in-memory.
14
+ (uri,) = self.references[path]
15
+ protocol, _ = split_protocol(uri)
16
+ return self.fss[protocol].open(uri, mode, *args, **kwargs)
17
+ else:
18
+ from fsspec.implementations.reference import ReferenceFileSystem # type: ignore[no-redef] # noqa: I001
19
+
20
+
21
+ __all__ = ["ReferenceFileSystem"]
@@ -16,13 +16,14 @@ from .aggregate import (
16
16
  sum,
17
17
  )
18
18
  from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
19
- from .conditional import case, greatest, ifelse, isnone, least
19
+ from .conditional import and_, case, greatest, ifelse, isnone, least, or_
20
20
  from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
21
21
  from .random import rand
22
22
  from .string import byte_hamming_distance
23
23
  from .window import window
24
24
 
25
25
  __all__ = [
26
+ "and_",
26
27
  "any_value",
27
28
  "array",
28
29
  "avg",
@@ -49,6 +50,7 @@ __all__ = [
49
50
  "literal",
50
51
  "max",
51
52
  "min",
53
+ "or_",
52
54
  "path",
53
55
  "rand",
54
56
  "random",
@@ -1,7 +1,9 @@
1
1
  from typing import Optional, Union
2
2
 
3
3
  from sqlalchemy import ColumnElement
4
+ from sqlalchemy import and_ as sql_and
4
5
  from sqlalchemy import case as sql_case
6
+ from sqlalchemy import or_ as sql_or
5
7
 
6
8
  from datachain.lib.utils import DataChainParamsError
7
9
  from datachain.query.schema import Column
@@ -89,7 +91,7 @@ def least(*args: Union[ColT, float]) -> Func:
89
91
 
90
92
 
91
93
  def case(
92
- *args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
94
+ *args: tuple[Union[ColumnElement, Func, bool], CaseT], else_: Optional[CaseT] = None
93
95
  ) -> Func:
94
96
  """
95
97
  Returns the case function that produces case expression which has a list of
@@ -99,7 +101,7 @@ def case(
99
101
  Result type is inferred from condition results.
100
102
 
101
103
  Args:
102
- args tuple((ColumnElement | Func),(str | int | float | complex | bool, Func, ColumnElement)):
104
+ args tuple((ColumnElement | Func | bool),(str | int | float | complex | bool, Func, ColumnElement)):
103
105
  Tuple of condition and values pair.
104
106
  else_ (str | int | float | complex | bool, Func): optional else value in case
105
107
  expression. If omitted, and no case conditions are satisfied, the result
@@ -118,12 +120,16 @@ def case(
118
120
  supported_types = [int, float, complex, str, bool]
119
121
 
120
122
  def _get_type(val):
123
+ from enum import Enum
124
+
121
125
  if isinstance(val, Func):
122
126
  # nested functions
123
127
  return val.result_type
124
128
  if isinstance(val, Column):
125
129
  # at this point we cannot know what is the type of a column
126
130
  return None
131
+ if isinstance(val, Enum):
132
+ return type(val.value)
127
133
  return type(val)
128
134
 
129
135
  if not args:
@@ -204,3 +210,61 @@ def isnone(col: Union[str, Column]) -> Func:
204
210
  col = C(col)
205
211
 
206
212
  return case((col.is_(None) if col is not None else True, True), else_=False)
213
+
214
+
215
+ def or_(*args: Union[ColumnElement, Func]) -> Func:
216
+ """
217
+ Returns the function that produces conjunction of expressions joined by OR
218
+ logical operator.
219
+
220
+ Args:
221
+ args (ColumnElement | Func): The expressions for OR statement.
222
+
223
+ Returns:
224
+ Func: A Func object that represents the or function.
225
+
226
+ Example:
227
+ ```py
228
+ dc.mutate(
229
+ test=ifelse(or_(isnone("name"), C("name") == ''), "Empty", "Not Empty")
230
+ )
231
+ ```
232
+ """
233
+ cols, func_args = [], []
234
+
235
+ for arg in args:
236
+ if isinstance(arg, (str, Func)):
237
+ cols.append(arg)
238
+ else:
239
+ func_args.append(arg)
240
+
241
+ return Func("or", inner=sql_or, cols=cols, args=func_args, result_type=bool)
242
+
243
+
244
+ def and_(*args: Union[ColumnElement, Func]) -> Func:
245
+ """
246
+ Returns the function that produces conjunction of expressions joined by AND
247
+ logical operator.
248
+
249
+ Args:
250
+ args (ColumnElement | Func): The expressions for AND statement.
251
+
252
+ Returns:
253
+ Func: A Func object that represents the and function.
254
+
255
+ Example:
256
+ ```py
257
+ dc.mutate(
258
+ test=ifelse(and_(isnone("name"), isnone("surname")), "Empty", "Not Empty")
259
+ )
260
+ ```
261
+ """
262
+ cols, func_args = [], []
263
+
264
+ for arg in args:
265
+ if isinstance(arg, (str, Func)):
266
+ cols.append(arg)
267
+ else:
268
+ func_args.append(arg)
269
+
270
+ return Func("and", inner=sql_and, cols=cols, args=func_args, result_type=bool)
datachain/job.py CHANGED
@@ -25,7 +25,7 @@ class Job:
25
25
 
26
26
  @classmethod
27
27
  def parse(
28
- cls: type[J],
28
+ cls,
29
29
  id: Union[str, uuid.UUID],
30
30
  name: str,
31
31
  status: int,
datachain/lib/arrow.py CHANGED
@@ -2,13 +2,12 @@ from collections.abc import Sequence
2
2
  from itertools import islice
3
3
  from typing import TYPE_CHECKING, Any, Optional
4
4
 
5
- import fsspec.implementations.reference
6
5
  import orjson
7
6
  import pyarrow as pa
8
- from fsspec.core import split_protocol
9
7
  from pyarrow.dataset import CsvFileFormat, dataset
10
8
  from tqdm.auto import tqdm
11
9
 
10
+ from datachain.fs.reference import ReferenceFileSystem
12
11
  from datachain.lib.data_model import dict_to_data_model
13
12
  from datachain.lib.file import ArrowRow, File
14
13
  from datachain.lib.model_store import ModelStore
@@ -27,15 +26,6 @@ if TYPE_CHECKING:
27
26
  DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
28
27
 
29
28
 
30
- class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
31
- def _open(self, path, mode="rb", *args, **kwargs):
32
- # overriding because `fsspec`'s `ReferenceFileSystem._open`
33
- # reads the whole file in-memory.
34
- (uri,) = self.references[path]
35
- protocol, _ = split_protocol(uri)
36
- return self.fss[protocol].open(uri, mode, *args, **kwargs)
37
-
38
-
39
29
  class ArrowGenerator(Generator):
40
30
  DEFAULT_BATCH_SIZE = 2**17 # same as `pyarrow._dataset._DEFAULT_BATCH_SIZE`
41
31
 
datachain/lib/dc.py CHANGED
@@ -481,6 +481,7 @@ class DataChain:
481
481
  version: Optional[int] = None,
482
482
  session: Optional[Session] = None,
483
483
  settings: Optional[dict] = None,
484
+ fallback_to_remote: bool = True,
484
485
  ) -> "Self":
485
486
  """Get data from a saved Dataset. It returns the chain itself.
486
487
 
@@ -498,6 +499,7 @@ class DataChain:
498
499
  version=version,
499
500
  session=session,
500
501
  indexing_column_types=File._datachain_column_types,
502
+ fallback_to_remote=fallback_to_remote,
501
503
  )
502
504
  telemetry.send_event_once("class", "datachain_init", name=name, version=version)
503
505
  if settings: