datachain 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

@@ -1,4 +1,4 @@
1
- from sqlalchemy import literal
1
+ from sqlalchemy import case, literal
2
2
 
3
3
  from . import array, path, random, string
4
4
  from .aggregate import (
@@ -17,6 +17,7 @@ from .aggregate import (
17
17
  )
18
18
  from .array import cosine_distance, euclidean_distance, length, sip_hash_64
19
19
  from .conditional import greatest, least
20
+ from .numeric import bit_and, bit_or, bit_xor, int_hash_64
20
21
  from .random import rand
21
22
  from .window import window
22
23
 
@@ -24,6 +25,10 @@ __all__ = [
24
25
  "any_value",
25
26
  "array",
26
27
  "avg",
28
+ "bit_and",
29
+ "bit_or",
30
+ "bit_xor",
31
+ "case",
27
32
  "collect",
28
33
  "concat",
29
34
  "cosine_distance",
@@ -32,6 +37,7 @@ __all__ = [
32
37
  "euclidean_distance",
33
38
  "first",
34
39
  "greatest",
40
+ "int_hash_64",
35
41
  "least",
36
42
  "length",
37
43
  "literal",
datachain/func/func.py CHANGED
@@ -2,11 +2,15 @@ import inspect
2
2
  from collections.abc import Sequence
3
3
  from typing import TYPE_CHECKING, Any, Callable, Optional, Union
4
4
 
5
- from sqlalchemy import BindParameter, ColumnElement, desc
5
+ from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc
6
+ from sqlalchemy.ext.hybrid import Comparator
7
+ from sqlalchemy.sql import func as sa_func
6
8
 
7
9
  from datachain.lib.convert.python_to_sql import python_to_sql
10
+ from datachain.lib.convert.sql_to_python import sql_to_python
8
11
  from datachain.lib.utils import DataChainColumnError, DataChainParamsError
9
12
  from datachain.query.schema import Column, ColumnMeta
13
+ from datachain.sql.functions import numeric
10
14
 
11
15
  from .base import Function
12
16
 
@@ -71,7 +75,7 @@ class Func(Function):
71
75
  return (
72
76
  [
73
77
  col
74
- if isinstance(col, (Func, BindParameter))
78
+ if isinstance(col, (Func, BindParameter, Case, Comparator))
75
79
  else ColumnMeta.to_db_name(
76
80
  col.name if isinstance(col, ColumnElement) else col
77
81
  )
@@ -96,94 +100,232 @@ class Func(Function):
96
100
  return list[col_type] if self.is_array else col_type # type: ignore[valid-type]
97
101
 
98
102
  def __add__(self, other: Union[ColT, float]) -> "Func":
99
- return math_add(self, other)
103
+ if isinstance(other, (int, float)):
104
+ return Func("add", lambda a: a + other, [self])
105
+ return Func("add", lambda a1, a2: a1 + a2, [self, other])
100
106
 
101
107
  def __radd__(self, other: Union[ColT, float]) -> "Func":
102
- return math_add(other, self)
108
+ if isinstance(other, (int, float)):
109
+ return Func("add", lambda a: other + a, [self])
110
+ return Func("add", lambda a1, a2: a1 + a2, [other, self])
103
111
 
104
112
  def __sub__(self, other: Union[ColT, float]) -> "Func":
105
- return math_sub(self, other)
113
+ if isinstance(other, (int, float)):
114
+ return Func("sub", lambda a: a - other, [self])
115
+ return Func("sub", lambda a1, a2: a1 - a2, [self, other])
106
116
 
107
117
  def __rsub__(self, other: Union[ColT, float]) -> "Func":
108
- return math_sub(other, self)
118
+ if isinstance(other, (int, float)):
119
+ return Func("sub", lambda a: other - a, [self])
120
+ return Func("sub", lambda a1, a2: a1 - a2, [other, self])
109
121
 
110
122
  def __mul__(self, other: Union[ColT, float]) -> "Func":
111
- return math_mul(self, other)
123
+ if isinstance(other, (int, float)):
124
+ return Func("mul", lambda a: a * other, [self])
125
+ return Func("mul", lambda a1, a2: a1 * a2, [self, other])
112
126
 
113
127
  def __rmul__(self, other: Union[ColT, float]) -> "Func":
114
- return math_mul(other, self)
128
+ if isinstance(other, (int, float)):
129
+ return Func("mul", lambda a: other * a, [self])
130
+ return Func("mul", lambda a1, a2: a1 * a2, [other, self])
115
131
 
116
132
  def __truediv__(self, other: Union[ColT, float]) -> "Func":
117
- return math_truediv(self, other)
133
+ if isinstance(other, (int, float)):
134
+ return Func("div", lambda a: _truediv(a, other), [self], result_type=float)
135
+ return Func(
136
+ "div", lambda a1, a2: _truediv(a1, a2), [self, other], result_type=float
137
+ )
118
138
 
119
139
  def __rtruediv__(self, other: Union[ColT, float]) -> "Func":
120
- return math_truediv(other, self)
140
+ if isinstance(other, (int, float)):
141
+ return Func("div", lambda a: _truediv(other, a), [self], result_type=float)
142
+ return Func(
143
+ "div", lambda a1, a2: _truediv(a1, a2), [other, self], result_type=float
144
+ )
121
145
 
122
146
  def __floordiv__(self, other: Union[ColT, float]) -> "Func":
123
- return math_floordiv(self, other)
147
+ if isinstance(other, (int, float)):
148
+ return Func(
149
+ "floordiv", lambda a: _floordiv(a, other), [self], result_type=int
150
+ )
151
+ return Func(
152
+ "floordiv", lambda a1, a2: _floordiv(a1, a2), [self, other], result_type=int
153
+ )
124
154
 
125
155
  def __rfloordiv__(self, other: Union[ColT, float]) -> "Func":
126
- return math_floordiv(other, self)
156
+ if isinstance(other, (int, float)):
157
+ return Func(
158
+ "floordiv", lambda a: _floordiv(other, a), [self], result_type=int
159
+ )
160
+ return Func(
161
+ "floordiv", lambda a1, a2: _floordiv(a1, a2), [other, self], result_type=int
162
+ )
127
163
 
128
164
  def __mod__(self, other: Union[ColT, float]) -> "Func":
129
- return math_mod(self, other)
165
+ if isinstance(other, (int, float)):
166
+ return Func("mod", lambda a: a % other, [self], result_type=int)
167
+ return Func("mod", lambda a1, a2: a1 % a2, [self, other], result_type=int)
130
168
 
131
169
  def __rmod__(self, other: Union[ColT, float]) -> "Func":
132
- return math_mod(other, self)
133
-
134
- def __pow__(self, other: Union[ColT, float]) -> "Func":
135
- return math_pow(self, other)
136
-
137
- def __rpow__(self, other: Union[ColT, float]) -> "Func":
138
- return math_pow(other, self)
139
-
140
- def __lshift__(self, other: Union[ColT, float]) -> "Func":
141
- return math_lshift(self, other)
142
-
143
- def __rlshift__(self, other: Union[ColT, float]) -> "Func":
144
- return math_lshift(other, self)
145
-
146
- def __rshift__(self, other: Union[ColT, float]) -> "Func":
147
- return math_rshift(self, other)
148
-
149
- def __rrshift__(self, other: Union[ColT, float]) -> "Func":
150
- return math_rshift(other, self)
170
+ if isinstance(other, (int, float)):
171
+ return Func("mod", lambda a: other % a, [self], result_type=int)
172
+ return Func("mod", lambda a1, a2: a1 % a2, [other, self], result_type=int)
151
173
 
152
174
  def __and__(self, other: Union[ColT, float]) -> "Func":
153
- return math_and(self, other)
175
+ if isinstance(other, (int, float)):
176
+ return Func(
177
+ "and", lambda a: numeric.bit_and(a, other), [self], result_type=int
178
+ )
179
+ return Func(
180
+ "and",
181
+ lambda a1, a2: numeric.bit_and(a1, a2),
182
+ [self, other],
183
+ result_type=int,
184
+ )
154
185
 
155
186
  def __rand__(self, other: Union[ColT, float]) -> "Func":
156
- return math_and(other, self)
187
+ if isinstance(other, (int, float)):
188
+ return Func(
189
+ "and", lambda a: numeric.bit_and(other, a), [self], result_type=int
190
+ )
191
+ return Func(
192
+ "and",
193
+ lambda a1, a2: numeric.bit_and(a1, a2),
194
+ [other, self],
195
+ result_type=int,
196
+ )
157
197
 
158
198
  def __or__(self, other: Union[ColT, float]) -> "Func":
159
- return math_or(self, other)
199
+ if isinstance(other, (int, float)):
200
+ return Func(
201
+ "or", lambda a: numeric.bit_or(a, other), [self], result_type=int
202
+ )
203
+ return Func(
204
+ "or", lambda a1, a2: numeric.bit_or(a1, a2), [self, other], result_type=int
205
+ )
160
206
 
161
207
  def __ror__(self, other: Union[ColT, float]) -> "Func":
162
- return math_or(other, self)
208
+ if isinstance(other, (int, float)):
209
+ return Func(
210
+ "or", lambda a: numeric.bit_or(other, a), [self], result_type=int
211
+ )
212
+ return Func(
213
+ "or", lambda a1, a2: numeric.bit_or(a1, a2), [other, self], result_type=int
214
+ )
163
215
 
164
216
  def __xor__(self, other: Union[ColT, float]) -> "Func":
165
- return math_xor(self, other)
217
+ if isinstance(other, (int, float)):
218
+ return Func(
219
+ "xor", lambda a: numeric.bit_xor(a, other), [self], result_type=int
220
+ )
221
+ return Func(
222
+ "xor",
223
+ lambda a1, a2: numeric.bit_xor(a1, a2),
224
+ [self, other],
225
+ result_type=int,
226
+ )
166
227
 
167
228
  def __rxor__(self, other: Union[ColT, float]) -> "Func":
168
- return math_xor(other, self)
229
+ if isinstance(other, (int, float)):
230
+ return Func(
231
+ "xor", lambda a: numeric.bit_xor(other, a), [self], result_type=int
232
+ )
233
+ return Func(
234
+ "xor",
235
+ lambda a1, a2: numeric.bit_xor(a1, a2),
236
+ [other, self],
237
+ result_type=int,
238
+ )
239
+
240
+ def __rshift__(self, other: Union[ColT, float]) -> "Func":
241
+ if isinstance(other, (int, float)):
242
+ return Func(
243
+ "rshift",
244
+ lambda a: numeric.bit_rshift(a, other),
245
+ [self],
246
+ result_type=int,
247
+ )
248
+ return Func(
249
+ "rshift",
250
+ lambda a1, a2: numeric.bit_rshift(a1, a2),
251
+ [self, other],
252
+ result_type=int,
253
+ )
254
+
255
+ def __rrshift__(self, other: Union[ColT, float]) -> "Func":
256
+ if isinstance(other, (int, float)):
257
+ return Func(
258
+ "rshift",
259
+ lambda a: numeric.bit_rshift(other, a),
260
+ [self],
261
+ result_type=int,
262
+ )
263
+ return Func(
264
+ "rshift",
265
+ lambda a1, a2: numeric.bit_rshift(a1, a2),
266
+ [other, self],
267
+ result_type=int,
268
+ )
269
+
270
+ def __lshift__(self, other: Union[ColT, float]) -> "Func":
271
+ if isinstance(other, (int, float)):
272
+ return Func(
273
+ "lshift",
274
+ lambda a: numeric.bit_lshift(a, other),
275
+ [self],
276
+ result_type=int,
277
+ )
278
+ return Func(
279
+ "lshift",
280
+ lambda a1, a2: numeric.bit_lshift(a1, a2),
281
+ [self, other],
282
+ result_type=int,
283
+ )
284
+
285
+ def __rlshift__(self, other: Union[ColT, float]) -> "Func":
286
+ if isinstance(other, (int, float)):
287
+ return Func(
288
+ "lshift",
289
+ lambda a: numeric.bit_lshift(other, a),
290
+ [self],
291
+ result_type=int,
292
+ )
293
+ return Func(
294
+ "lshift",
295
+ lambda a1, a2: numeric.bit_lshift(a1, a2),
296
+ [other, self],
297
+ result_type=int,
298
+ )
169
299
 
170
300
  def __lt__(self, other: Union[ColT, float]) -> "Func":
171
- return math_lt(self, other)
301
+ if isinstance(other, (int, float)):
302
+ return Func("lt", lambda a: a < other, [self], result_type=bool)
303
+ return Func("lt", lambda a1, a2: a1 < a2, [self, other], result_type=bool)
172
304
 
173
305
  def __le__(self, other: Union[ColT, float]) -> "Func":
174
- return math_le(self, other)
306
+ if isinstance(other, (int, float)):
307
+ return Func("le", lambda a: a <= other, [self], result_type=bool)
308
+ return Func("le", lambda a1, a2: a1 <= a2, [self, other], result_type=bool)
175
309
 
176
310
  def __eq__(self, other):
177
- return math_eq(self, other)
311
+ if isinstance(other, (int, float)):
312
+ return Func("eq", lambda a: a == other, [self], result_type=bool)
313
+ return Func("eq", lambda a1, a2: a1 == a2, [self, other], result_type=bool)
178
314
 
179
315
  def __ne__(self, other):
180
- return math_ne(self, other)
316
+ if isinstance(other, (int, float)):
317
+ return Func("ne", lambda a: a != other, [self], result_type=bool)
318
+ return Func("ne", lambda a1, a2: a1 != a2, [self, other], result_type=bool)
181
319
 
182
320
  def __gt__(self, other: Union[ColT, float]) -> "Func":
183
- return math_gt(self, other)
321
+ if isinstance(other, (int, float)):
322
+ return Func("gt", lambda a: a > other, [self], result_type=bool)
323
+ return Func("gt", lambda a1, a2: a1 > a2, [self, other], result_type=bool)
184
324
 
185
325
  def __ge__(self, other: Union[ColT, float]) -> "Func":
186
- return math_ge(self, other)
326
+ if isinstance(other, (int, float)):
327
+ return Func("ge", lambda a: a >= other, [self], result_type=bool)
328
+ return Func("ge", lambda a1, a2: a1 >= a2, [self, other], result_type=bool)
187
329
 
188
330
  def label(self, label: str) -> "Func":
189
331
  return Func(
@@ -273,112 +415,20 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
273
415
  if isinstance(col, Func):
274
416
  return col.get_result_type(signals_schema)
275
417
 
418
+ if isinstance(col, ColumnElement) and not hasattr(col, "name"):
419
+ return sql_to_python(col)
420
+
276
421
  return signals_schema.get_column_type(
277
422
  col.name if isinstance(col, ColumnElement) else col
278
423
  )
279
424
 
280
425
 
281
- def math_func(
282
- name: str,
283
- inner: Callable,
284
- params: Sequence[Union[ColT, float]],
285
- result_type: Optional["DataType"] = None,
286
- ) -> Func:
287
- """Returns math function from the columns."""
288
- cols, args = [], []
289
- for arg in params:
290
- if isinstance(arg, (int, float)):
291
- args.append(arg)
292
- else:
293
- cols.append(arg)
294
- return Func(name, inner, cols=cols, args=args, result_type=result_type)
295
-
296
-
297
- def math_add(*args: Union[ColT, float]) -> Func:
298
- """Computes the sum of the column."""
299
- return math_func("add", lambda a1, a2: a1 + a2, args)
300
-
301
-
302
- def math_sub(*args: Union[ColT, float]) -> Func:
303
- """Computes the diff of the column."""
304
- return math_func("sub", lambda a1, a2: a1 - a2, args)
305
-
306
-
307
- def math_mul(*args: Union[ColT, float]) -> Func:
308
- """Computes the product of the column."""
309
- return math_func("mul", lambda a1, a2: a1 * a2, args)
310
-
311
-
312
- def math_truediv(*args: Union[ColT, float]) -> Func:
313
- """Computes the division of the column."""
314
- return math_func("div", lambda a1, a2: a1 / a2, args, result_type=float)
315
-
316
-
317
- def math_floordiv(*args: Union[ColT, float]) -> Func:
318
- """Computes the floor division of the column."""
319
- return math_func("floordiv", lambda a1, a2: a1 // a2, args, result_type=float)
320
-
321
-
322
- def math_mod(*args: Union[ColT, float]) -> Func:
323
- """Computes the modulo of the column."""
324
- return math_func("mod", lambda a1, a2: a1 % a2, args, result_type=float)
325
-
326
-
327
- def math_pow(*args: Union[ColT, float]) -> Func:
328
- """Computes the power of the column."""
329
- return math_func("pow", lambda a1, a2: a1**a2, args, result_type=float)
330
-
331
-
332
- def math_lshift(*args: Union[ColT, float]) -> Func:
333
- """Computes the left shift of the column."""
334
- return math_func("lshift", lambda a1, a2: a1 << a2, args, result_type=int)
335
-
336
-
337
- def math_rshift(*args: Union[ColT, float]) -> Func:
338
- """Computes the right shift of the column."""
339
- return math_func("rshift", lambda a1, a2: a1 >> a2, args, result_type=int)
340
-
341
-
342
- def math_and(*args: Union[ColT, float]) -> Func:
343
- """Computes the logical AND of the column."""
344
- return math_func("and", lambda a1, a2: a1 & a2, args, result_type=bool)
345
-
346
-
347
- def math_or(*args: Union[ColT, float]) -> Func:
348
- """Computes the logical OR of the column."""
349
- return math_func("or", lambda a1, a2: a1 | a2, args, result_type=bool)
350
-
351
-
352
- def math_xor(*args: Union[ColT, float]) -> Func:
353
- """Computes the logical XOR of the column."""
354
- return math_func("xor", lambda a1, a2: a1 ^ a2, args, result_type=bool)
355
-
356
-
357
- def math_lt(*args: Union[ColT, float]) -> Func:
358
- """Computes the less than comparison of the column."""
359
- return math_func("lt", lambda a1, a2: a1 < a2, args, result_type=bool)
360
-
361
-
362
- def math_le(*args: Union[ColT, float]) -> Func:
363
- """Computes the less than or equal comparison of the column."""
364
- return math_func("le", lambda a1, a2: a1 <= a2, args, result_type=bool)
365
-
366
-
367
- def math_eq(*args: Union[ColT, float]) -> Func:
368
- """Computes the equality comparison of the column."""
369
- return math_func("eq", lambda a1, a2: a1 == a2, args, result_type=bool)
370
-
371
-
372
- def math_ne(*args: Union[ColT, float]) -> Func:
373
- """Computes the inequality comparison of the column."""
374
- return math_func("ne", lambda a1, a2: a1 != a2, args, result_type=bool)
375
-
376
-
377
- def math_gt(*args: Union[ColT, float]) -> Func:
378
- """Computes the greater than comparison of the column."""
379
- return math_func("gt", lambda a1, a2: a1 > a2, args, result_type=bool)
426
+ def _truediv(a, b):
427
+ # Using sqlalchemy.sql.func.divide here instead of / operator
428
+ # because of a bug in ClickHouse SQLAlchemy dialect
429
+ # See https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/335
430
+ return sa_func.divide(a, b)
380
431
 
381
432
 
382
- def math_ge(*args: Union[ColT, float]) -> Func:
383
- """Computes the greater than or equal comparison of the column."""
384
- return math_func("ge", lambda a1, a2: a1 >= a2, args, result_type=bool)
433
+ def _floordiv(a, b):
434
+ return cast(_truediv(a, b), Integer)
@@ -0,0 +1,162 @@
1
+ from typing import Union
2
+
3
+ from datachain.sql.functions import numeric
4
+
5
+ from .func import ColT, Func
6
+
7
+
8
+ def bit_and(*args: Union[ColT, int]) -> Func:
9
+ """
10
+ Computes the bitwise AND operation between two values.
11
+
12
+ Args:
13
+ args (str | int): Two values to compute the bitwise AND operation between.
14
+ If a string is provided, it is assumed to be the name of the column vector.
15
+ If an integer is provided, it is assumed to be a constant value.
16
+
17
+ Returns:
18
+ Func: A Func object that represents the bitwise AND function.
19
+
20
+ Example:
21
+ ```py
22
+ dc.mutate(
23
+ xor1=func.bit_and("signal.values", 0x0F),
24
+ )
25
+ ```
26
+
27
+ Notes:
28
+ - Result column will always be of type int.
29
+ """
30
+ cols, func_args = [], []
31
+ for arg in args:
32
+ if isinstance(arg, int):
33
+ func_args.append(arg)
34
+ else:
35
+ cols.append(arg)
36
+
37
+ if len(cols) + len(func_args) != 2:
38
+ raise ValueError("bit_and() requires exactly two arguments")
39
+
40
+ return Func(
41
+ "bit_and",
42
+ inner=numeric.bit_and,
43
+ cols=cols,
44
+ args=func_args,
45
+ result_type=int,
46
+ )
47
+
48
+
49
+ def bit_or(*args: Union[ColT, int]) -> Func:
50
+ """
51
+ Computes the bitwise AND operation between two values.
52
+
53
+ Args:
54
+ args (str | int): Two values to compute the bitwise OR operation between.
55
+ If a string is provided, it is assumed to be the name of the column vector.
56
+ If an integer is provided, it is assumed to be a constant value.
57
+
58
+ Returns:
59
+ Func: A Func object that represents the bitwise OR function.
60
+
61
+ Example:
62
+ ```py
63
+ dc.mutate(
64
+ xor1=func.bit_or("signal.values", 0x0F),
65
+ )
66
+ ```
67
+
68
+ Notes:
69
+ - Result column will always be of type int.
70
+ """
71
+ cols, func_args = [], []
72
+ for arg in args:
73
+ if isinstance(arg, int):
74
+ func_args.append(arg)
75
+ else:
76
+ cols.append(arg)
77
+
78
+ if len(cols) + len(func_args) != 2:
79
+ raise ValueError("bit_or() requires exactly two arguments")
80
+
81
+ return Func(
82
+ "bit_or",
83
+ inner=numeric.bit_or,
84
+ cols=cols,
85
+ args=func_args,
86
+ result_type=int,
87
+ )
88
+
89
+
90
+ def bit_xor(*args: Union[ColT, int]) -> Func:
91
+ """
92
+ Computes the bitwise XOR operation between two values.
93
+
94
+ Args:
95
+ args (str | int): Two values to compute the bitwise XOR operation between.
96
+ If a string is provided, it is assumed to be the name of the column vector.
97
+ If an integer is provided, it is assumed to be a constant value.
98
+
99
+ Returns:
100
+ Func: A Func object that represents the bitwise XOR function.
101
+
102
+ Example:
103
+ ```py
104
+ dc.mutate(
105
+ xor1=func.bit_xor("signal.values", 0x0F),
106
+ )
107
+ ```
108
+
109
+ Notes:
110
+ - Result column will always be of type int.
111
+ """
112
+ cols, func_args = [], []
113
+ for arg in args:
114
+ if isinstance(arg, int):
115
+ func_args.append(arg)
116
+ else:
117
+ cols.append(arg)
118
+
119
+ if len(cols) + len(func_args) != 2:
120
+ raise ValueError("bit_xor() requires exactly two arguments")
121
+
122
+ return Func(
123
+ "bit_xor",
124
+ inner=numeric.bit_xor,
125
+ cols=cols,
126
+ args=func_args,
127
+ result_type=int,
128
+ )
129
+
130
+
131
+ def int_hash_64(col: Union[ColT, int]) -> Func:
132
+ """
133
+ Returns the 64-bit hash of an integer.
134
+
135
+ Args:
136
+ col (str | int): String to compute the hash of.
137
+ If a string is provided, it is assumed to be the name of the column.
138
+ If a int is provided, it is assumed to be an int literal.
139
+ If a Func is provided, it is assumed to be a function returning an int.
140
+
141
+ Returns:
142
+ Func: A Func object that represents the 64-bit hash function.
143
+
144
+ Example:
145
+ ```py
146
+ dc.mutate(
147
+ val_hash=func.int_hash_64("val"),
148
+ )
149
+ ```
150
+
151
+ Note:
152
+ - Result column will always be of type int.
153
+ """
154
+ cols, args = [], []
155
+ if isinstance(col, int):
156
+ args.append(col)
157
+ else:
158
+ cols.append(col)
159
+
160
+ return Func(
161
+ "int_hash_64", inner=numeric.int_hash_64, cols=cols, args=args, result_type=int
162
+ )
datachain/lib/dc.py CHANGED
@@ -1150,7 +1150,7 @@ class DataChain:
1150
1150
  def group_by(
1151
1151
  self,
1152
1152
  *,
1153
- partition_by: Union[str, Func, Sequence[Union[str, Func]]],
1153
+ partition_by: Optional[Union[str, Func, Sequence[Union[str, Func]]]] = None,
1154
1154
  **kwargs: Func,
1155
1155
  ) -> "Self":
1156
1156
  """Group rows by specified set of signals and return new signals
@@ -1167,10 +1167,10 @@ class DataChain:
1167
1167
  )
1168
1168
  ```
1169
1169
  """
1170
- if isinstance(partition_by, (str, Func)):
1170
+ if partition_by is None:
1171
+ partition_by = []
1172
+ elif isinstance(partition_by, (str, Func)):
1171
1173
  partition_by = [partition_by]
1172
- if not partition_by:
1173
- raise ValueError("At least one column should be provided for partition_by")
1174
1174
 
1175
1175
  partition_by_columns: list[Column] = []
1176
1176
  signal_columns: list[Column] = []
@@ -966,8 +966,6 @@ class SQLGroupBy(SQLClause):
966
966
  def apply_sql_clause(self, query) -> Select:
967
967
  if not self.cols:
968
968
  raise ValueError("No columns to select")
969
- if not self.group_by:
970
- raise ValueError("No columns to group by")
971
969
 
972
970
  subquery = query.subquery()
973
971
 
@@ -38,6 +38,10 @@ class length(GenericFunction): # noqa: N801
38
38
 
39
39
 
40
40
  class sip_hash_64(GenericFunction): # noqa: N801
41
+ """
42
+ Computes the SipHash-64 hash of the array.
43
+ """
44
+
41
45
  type = Int64()
42
46
  package = "hash"
43
47
  name = "sip_hash_64"
@@ -0,0 +1,43 @@
1
+ from sqlalchemy.sql.functions import GenericFunction, ReturnTypeFromArgs
2
+
3
+ from datachain.sql.types import Int64
4
+ from datachain.sql.utils import compiler_not_implemented
5
+
6
+
7
+ class bit_and(ReturnTypeFromArgs): # noqa: N801
8
+ inherit_cache = True
9
+
10
+
11
+ class bit_or(ReturnTypeFromArgs): # noqa: N801
12
+ inherit_cache = True
13
+
14
+
15
+ class bit_xor(ReturnTypeFromArgs): # noqa: N801
16
+ inherit_cache = True
17
+
18
+
19
+ class bit_rshift(ReturnTypeFromArgs): # noqa: N801
20
+ inherit_cache = True
21
+
22
+
23
+ class bit_lshift(ReturnTypeFromArgs): # noqa: N801
24
+ inherit_cache = True
25
+
26
+
27
+ class int_hash_64(GenericFunction): # noqa: N801
28
+ """
29
+ Computes the 64-bit hash of an integer.
30
+ """
31
+
32
+ type = Int64()
33
+ package = "hash"
34
+ name = "int_hash_64"
35
+ inherit_cache = True
36
+
37
+
38
+ compiler_not_implemented(bit_and)
39
+ compiler_not_implemented(bit_or)
40
+ compiler_not_implemented(bit_xor)
41
+ compiler_not_implemented(bit_rshift)
42
+ compiler_not_implemented(bit_lshift)
43
+ compiler_not_implemented(int_hash_64)
@@ -15,7 +15,14 @@ from sqlalchemy.sql.elements import literal
15
15
  from sqlalchemy.sql.expression import case
16
16
  from sqlalchemy.sql.functions import func
17
17
 
18
- from datachain.sql.functions import aggregate, array, conditional, random, string
18
+ from datachain.sql.functions import (
19
+ aggregate,
20
+ array,
21
+ conditional,
22
+ numeric,
23
+ random,
24
+ string,
25
+ )
19
26
  from datachain.sql.functions import path as sql_path
20
27
  from datachain.sql.selectable import Values, base_values_compiler
21
28
  from datachain.sql.sqlite.types import (
@@ -47,6 +54,8 @@ slash = literal("/")
47
54
  empty_str = literal("")
48
55
  dot = literal(".")
49
56
 
57
+ MAX_INT64 = 2**64 - 1
58
+
50
59
 
51
60
  def setup():
52
61
  global setup_is_complete # noqa: PLW0603
@@ -89,6 +98,12 @@ def setup():
89
98
  compiles(aggregate.group_concat, "sqlite")(compile_group_concat)
90
99
  compiles(aggregate.any_value, "sqlite")(compile_any_value)
91
100
  compiles(aggregate.collect, "sqlite")(compile_collect)
101
+ compiles(numeric.bit_and, "sqlite")(compile_bitwise_and)
102
+ compiles(numeric.bit_or, "sqlite")(compile_bitwise_or)
103
+ compiles(numeric.bit_xor, "sqlite")(compile_bitwise_xor)
104
+ compiles(numeric.bit_rshift, "sqlite")(compile_bitwise_rshift)
105
+ compiles(numeric.bit_lshift, "sqlite")(compile_bitwise_lshift)
106
+ compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
92
107
 
93
108
  if load_usearch_extension(sqlite3.connect(":memory:")):
94
109
  compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -163,6 +178,19 @@ def sqlite_string_split(string: str, sep: str, maxsplit: int = -1) -> str:
163
178
  return orjson.dumps(string.split(sep, maxsplit)).decode("utf-8")
164
179
 
165
180
 
181
+ def sqlite_int_hash_64(x: int) -> int:
182
+ """IntHash64 implementation from ClickHouse."""
183
+ x ^= 0x4CF2D2BAAE6DA887
184
+ x ^= x >> 33
185
+ x = (x * 0xFF51AFD7ED558CCD) & MAX_INT64
186
+ x ^= x >> 33
187
+ x = (x * 0xC4CEB9FE1A85EC53) & MAX_INT64
188
+ x ^= x >> 33
189
+ # SQLite does not support unsigned 64-bit integers,
190
+ # so we need to convert to signed 64-bit
191
+ return x if x < 1 << 63 else (x & MAX_INT64) - (1 << 64)
192
+
193
+
166
194
  def register_user_defined_sql_functions() -> None:
167
195
  # Register optional functions if we have the necessary dependencies
168
196
  # and otherwise register functions that will raise an exception with
@@ -185,6 +213,21 @@ def register_user_defined_sql_functions() -> None:
185
213
 
186
214
  _registered_function_creators["vector_functions"] = create_vector_functions
187
215
 
216
+ def create_numeric_functions(conn):
217
+ conn.create_function("divide", 2, lambda a, b: a / b, deterministic=True)
218
+ conn.create_function("bitwise_and", 2, lambda a, b: a & b, deterministic=True)
219
+ conn.create_function("bitwise_or", 2, lambda a, b: a | b, deterministic=True)
220
+ conn.create_function("bitwise_xor", 2, lambda a, b: a ^ b, deterministic=True)
221
+ conn.create_function(
222
+ "bitwise_rshift", 2, lambda a, b: a >> b, deterministic=True
223
+ )
224
+ conn.create_function(
225
+ "bitwise_lshift", 2, lambda a, b: a << b, deterministic=True
226
+ )
227
+ conn.create_function("int_hash_64", 1, sqlite_int_hash_64, deterministic=True)
228
+
229
+ _registered_function_creators["numeric_functions"] = create_numeric_functions
230
+
188
231
  def sqlite_regexp_replace(string: str, pattern: str, replacement: str) -> str:
189
232
  return re.sub(pattern, replacement, string)
190
233
 
@@ -316,6 +359,30 @@ def compile_euclidean_distance(element, compiler, **kwargs):
316
359
  return f"euclidean_distance({compiler.process(element.clauses, **kwargs)})"
317
360
 
318
361
 
362
+ def compile_bitwise_and(element, compiler, **kwargs):
363
+ return compiler.process(func.bitwise_and(*element.clauses.clauses), **kwargs)
364
+
365
+
366
+ def compile_bitwise_or(element, compiler, **kwargs):
367
+ return compiler.process(func.bitwise_or(*element.clauses.clauses), **kwargs)
368
+
369
+
370
+ def compile_bitwise_xor(element, compiler, **kwargs):
371
+ return compiler.process(func.bitwise_xor(*element.clauses.clauses), **kwargs)
372
+
373
+
374
+ def compile_bitwise_rshift(element, compiler, **kwargs):
375
+ return compiler.process(func.bitwise_rshift(*element.clauses.clauses), **kwargs)
376
+
377
+
378
+ def compile_bitwise_lshift(element, compiler, **kwargs):
379
+ return compiler.process(func.bitwise_lshift(*element.clauses.clauses), **kwargs)
380
+
381
+
382
+ def compile_int_hash_64(element, compiler, **kwargs):
383
+ return compiler.process(func.int_hash_64(*element.clauses.clauses), **kwargs)
384
+
385
+
319
386
  def py_json_array_length(arr):
320
387
  return len(orjson.loads(arr))
321
388
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: datachain
3
- Version: 0.7.6
3
+ Version: 0.7.8
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License: Apache-2.0
@@ -37,12 +37,13 @@ datachain/data_storage/schema.py,sha256=-QVlRvD0dfu-ZFUxylEoSnLJLnleMEjVlcAb2OGu
37
37
  datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2krunWUum5o,927
38
38
  datachain/data_storage/sqlite.py,sha256=D_ZQ0PHmZzHO2dinv4naVJocUDIZUwV4WAz692C1cyk,22521
39
39
  datachain/data_storage/warehouse.py,sha256=tjIkU-5JywBR0apCyqTcwSyaRtGxhu2L7IVjrz-55uc,30802
40
- datachain/func/__init__.py,sha256=4VUt5BaLdBAl_BnAku0Jb8plqd7kDOiYrQTMG3pN0c4,794
40
+ datachain/func/__init__.py,sha256=oz-GbCcp5jnN82u6cghWTGzmU9IQvtvllOof73wE52g,934
41
41
  datachain/func/aggregate.py,sha256=7_IPrIwb2XSs3zG4iOr1eTvzn6kNVe2mkzvNzjusDHk,10942
42
42
  datachain/func/array.py,sha256=zHDNWuWLA7HVa9FEvQeHhVi00_xqenyleTqcLwkXWBI,5477
43
43
  datachain/func/base.py,sha256=wA0sBQAVyN9LPxoo7Ox83peS0zUVnyuKxukwAcjGLfY,534
44
44
  datachain/func/conditional.py,sha256=mQroxsoExpBW84Zm5dAYP4OpBblWmzfnF2qJq9rba54,2223
45
- datachain/func/func.py,sha256=9wqdxxisoDL0w8qKGQmL6sNdgJeIOzotEUPlxu9t2IQ,12326
45
+ datachain/func/func.py,sha256=mJ_rOXMpoqnK4-d5eF9boSMx5hWzgKoMLPGpZQqLAfw,15222
46
+ datachain/func/numeric.py,sha256=GcUX6ijZvzfac8CZrHE0gRc9WCPiutcMLKqNXtbn-Yo,4186
46
47
  datachain/func/path.py,sha256=mqN_mfkwv44z2II7DMTp_fGGw95hmTCNls_TOFNpr4k,3155
47
48
  datachain/func/random.py,sha256=pENOLj9rSmWfGCnOsUIaCsVC5486zQb66qfQvXaz9Z4,452
48
49
  datachain/func/string.py,sha256=NQzaXXYu7yb72HPADy4WrFlcgvTS77L9x7-qvCKJtnk,4522
@@ -52,7 +53,7 @@ datachain/lib/arrow.py,sha256=b5efxAUaNNYVwtXVJqj07D3zf5KC-BPlLCxKEZbEG6w,9429
52
53
  datachain/lib/clip.py,sha256=lm5CzVi4Cj1jVLEKvERKArb-egb9j1Ls-fwTItT6vlI,6150
53
54
  datachain/lib/data_model.py,sha256=zS4lmXHVBXc9ntcyea2a1CRLXGSAN_0glXcF88CohgY,2685
54
55
  datachain/lib/dataset_info.py,sha256=IjdF1E0TQNOq9YyynfWiCFTeZpbyGfyJvxgJY4YN810,2493
55
- datachain/lib/dc.py,sha256=J7liATKQBJCkeHanVLr0s3d1t5wxiiiSJuSbuxKBbLg,89527
56
+ datachain/lib/dc.py,sha256=t5y5tsYyU7uuk3gEPPhhcDSZ1tL1aHkKG2W54eHiUq8,89492
56
57
  datachain/lib/file.py,sha256=-XMkL6ED1sE7TMhWoMRTEuOXswZJw8X6AEmJDONFP74,15019
57
58
  datachain/lib/hf.py,sha256=a-zFpDmZIR4r8dlNNTjfpAKSnuJ9xyRXlgcdENiXt3E,5864
58
59
  datachain/lib/image.py,sha256=AMXYwQsmarZjRbPCZY3M1jDsM2WAB_b3cTY4uOIuXNU,2675
@@ -87,7 +88,7 @@ datachain/model/ultralytics/pose.py,sha256=71KBTcoST2wcEtsyGXqLVpvUtqbp9gwZGA15p
87
88
  datachain/model/ultralytics/segment.py,sha256=Z1ab0tZRJubSYNH4KkFlzhYeGNTfAyC71KmkQcToHDQ,2760
88
89
  datachain/query/__init__.py,sha256=7DhEIjAA8uZJfejruAVMZVcGFmvUpffuZJwgRqNwe-c,263
89
90
  datachain/query/batch.py,sha256=5fEhORFe7li12SdYddaSK3LyqksMfCHhwN1_A6TfsA4,3485
90
- datachain/query/dataset.py,sha256=o9Ssa47t1IM78qcaoCeTL-rp4fZCpYfR7XFjw2hGWeY,54632
91
+ datachain/query/dataset.py,sha256=J6SbCLnFlZgCxRchc3tVk5tcC7xo1Hp616JGlEZXCDo,54547
91
92
  datachain/query/dispatch.py,sha256=fZ0TgGFRcsrYh1iXQoZVjkUl4Xetom9PSHoeDes3IRs,11606
92
93
  datachain/query/metrics.py,sha256=r5b0ygYhokbXp8Mg3kCH8iFSRw0jxzyeBe-C-J_bKFc,938
93
94
  datachain/query/params.py,sha256=O_j89mjYRLOwWNhYZl-z7mi-rkdP7WyFmaDufsdTryE,863
@@ -104,21 +105,22 @@ datachain/sql/default/__init__.py,sha256=XQ2cEZpzWiABqjV-6yYHUBGI9vN_UHxbxZENESm
104
105
  datachain/sql/default/base.py,sha256=QD-31C6JnyOXzogyDx90sUhm7QvgXIYpeHEASH84igU,628
105
106
  datachain/sql/functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
106
107
  datachain/sql/functions/aggregate.py,sha256=3AQdA8YHPFdtCEfwZKQXTT8SlQWdG9gD5PBtGN3Odqs,944
107
- datachain/sql/functions/array.py,sha256=rvH27SWN9gdh_mFnp0GIiXuCrNW6n8ZbY4I_JUS-_e0,1140
108
+ datachain/sql/functions/array.py,sha256=Zq59CaMHf_hFapU4kxvy2mwteH344k5Wksxja4MfBks,1204
108
109
  datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
110
+ datachain/sql/functions/numeric.py,sha256=DFTTEWsvBBXwbaaC4zdxhAoqUYwI6nbymG-nzbzdPv8,972
109
111
  datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
110
112
  datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
111
113
  datachain/sql/functions/string.py,sha256=DYgiw8XSk7ge7GXvyRI1zbaMruIizNeI-puOjriQGZQ,1148
112
114
  datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
113
- datachain/sql/sqlite/base.py,sha256=X4iEynOAqqvqz8lmgUKvURleKO6aguULgG8RoufKrSk,14772
115
+ datachain/sql/sqlite/base.py,sha256=eQv2U32jChG9tnYSFE4SS2Mvfb7-W3Ok3Ffhew9qkKI,17254
114
116
  datachain/sql/sqlite/types.py,sha256=lPXS1XbkmUtlkkiRxy_A_UzsgpPv2VSkXYOD4zIHM4w,1734
115
117
  datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
116
118
  datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
117
119
  datachain/toolkit/split.py,sha256=ZgDcrNiKiPXZmKD591_1z9qRIXitu5zwAsoVPB7ykiU,2508
118
120
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
119
- datachain-0.7.6.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
120
- datachain-0.7.6.dist-info/METADATA,sha256=KMChqSG7d_lMaF9BYNIgmijvnxZbDm5gCEg980gUGOA,18006
121
- datachain-0.7.6.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
122
- datachain-0.7.6.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
123
- datachain-0.7.6.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
124
- datachain-0.7.6.dist-info/RECORD,,
121
+ datachain-0.7.8.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
122
+ datachain-0.7.8.dist-info/METADATA,sha256=r8znUWHdmY5y6hk8N9NFdlrKaHKkteeji7NXJTb2Ges,18006
123
+ datachain-0.7.8.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
124
+ datachain-0.7.8.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
125
+ datachain-0.7.8.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
126
+ datachain-0.7.8.dist-info/RECORD,,