datachain 0.8.12__py3-none-any.whl → 0.8.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datachain might be problematic. More details are available on the original registry page (the link was not preserved in this text).
- datachain/func/__init__.py +2 -1
- datachain/func/array.py +39 -1
- datachain/sql/functions/array.py +13 -1
- datachain/sql/sqlite/base.py +17 -1
- datachain/sql/sqlite/types.py +5 -0
- {datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/METADATA +1 -1
- {datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/RECORD +11 -11
- {datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/LICENSE +0 -0
- {datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/WHEEL +0 -0
- {datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/entry_points.txt +0 -0
- {datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/top_level.txt +0 -0
datachain/func/__init__.py
CHANGED
|
@@ -15,7 +15,7 @@ from .aggregate import (
|
|
|
15
15
|
row_number,
|
|
16
16
|
sum,
|
|
17
17
|
)
|
|
18
|
-
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
|
|
18
|
+
from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
|
|
19
19
|
from .conditional import case, greatest, ifelse, isnone, least
|
|
20
20
|
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
|
|
21
21
|
from .random import rand
|
|
@@ -34,6 +34,7 @@ __all__ = [
|
|
|
34
34
|
"case",
|
|
35
35
|
"collect",
|
|
36
36
|
"concat",
|
|
37
|
+
"contains",
|
|
37
38
|
"cosine_distance",
|
|
38
39
|
"count",
|
|
39
40
|
"dense_rank",
|
datachain/func/array.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
|
-
from typing import Union
|
|
2
|
+
from typing import Any, Union
|
|
3
3
|
|
|
4
4
|
from datachain.sql.functions import array
|
|
5
5
|
|
|
@@ -140,6 +140,44 @@ def length(arg: Union[str, Sequence, Func]) -> Func:
|
|
|
140
140
|
return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
|
|
141
141
|
|
|
142
142
|
|
|
143
|
+
def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
|
|
144
|
+
"""
|
|
145
|
+
Checks whether the `arr` array has the `elem` element.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
arr (str | Sequence | Func): Array to check for the element.
|
|
149
|
+
If a string is provided, it is assumed to be the name of the array column.
|
|
150
|
+
If a sequence is provided, it is assumed to be an array of values.
|
|
151
|
+
If a Func is provided, it is assumed to be a function returning an array.
|
|
152
|
+
elem (Any): Element to check for in the array.
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
Func: A Func object that represents the contains function. Result of the
|
|
156
|
+
function will be 1 if the element is present in the array, and 0 otherwise.
|
|
157
|
+
|
|
158
|
+
Example:
|
|
159
|
+
```py
|
|
160
|
+
dc.mutate(
|
|
161
|
+
contains1=func.array.contains("signal.values", 3),
|
|
162
|
+
contains2=func.array.contains([1, 2, 3, 4, 5], 7),
|
|
163
|
+
)
|
|
164
|
+
```
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
def inner(arg):
|
|
168
|
+
is_json = type(elem) in [list, dict]
|
|
169
|
+
return array.contains(arg, elem, is_json)
|
|
170
|
+
|
|
171
|
+
if isinstance(arr, (str, Func)):
|
|
172
|
+
cols = [arr]
|
|
173
|
+
args = None
|
|
174
|
+
else:
|
|
175
|
+
cols = None
|
|
176
|
+
args = [arr]
|
|
177
|
+
|
|
178
|
+
return Func("contains", inner=inner, cols=cols, args=args, result_type=int)
|
|
179
|
+
|
|
180
|
+
|
|
143
181
|
def sip_hash_64(arg: Union[str, Sequence]) -> Func:
|
|
144
182
|
"""
|
|
145
183
|
Computes the SipHash-64 hash of the array.
|
datachain/sql/functions/array.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from sqlalchemy.sql.functions import GenericFunction
|
|
2
2
|
|
|
3
|
-
from datachain.sql.types import Float, Int64
|
|
3
|
+
from datachain.sql.types import Boolean, Float, Int64
|
|
4
4
|
from datachain.sql.utils import compiler_not_implemented
|
|
5
5
|
|
|
6
6
|
|
|
@@ -37,6 +37,17 @@ class length(GenericFunction): # noqa: N801
|
|
|
37
37
|
inherit_cache = True
|
|
38
38
|
|
|
39
39
|
|
|
40
|
+
class contains(GenericFunction): # noqa: N801
|
|
41
|
+
"""
|
|
42
|
+
Checks if element is in the array.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
type = Boolean()
|
|
46
|
+
package = "array"
|
|
47
|
+
name = "contains"
|
|
48
|
+
inherit_cache = True
|
|
49
|
+
|
|
50
|
+
|
|
40
51
|
class sip_hash_64(GenericFunction): # noqa: N801
|
|
41
52
|
"""
|
|
42
53
|
Computes the SipHash-64 hash of the array.
|
|
@@ -51,4 +62,5 @@ class sip_hash_64(GenericFunction): # noqa: N801
|
|
|
51
62
|
compiler_not_implemented(cosine_distance)
|
|
52
63
|
compiler_not_implemented(euclidean_distance)
|
|
53
64
|
compiler_not_implemented(length)
|
|
65
|
+
compiler_not_implemented(contains)
|
|
54
66
|
compiler_not_implemented(sip_hash_64)
|
datachain/sql/sqlite/base.py
CHANGED
|
@@ -87,6 +87,7 @@ def setup():
|
|
|
87
87
|
compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
|
|
88
88
|
compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
|
|
89
89
|
compiles(array.length, "sqlite")(compile_array_length)
|
|
90
|
+
compiles(array.contains, "sqlite")(compile_array_contains)
|
|
90
91
|
compiles(string.length, "sqlite")(compile_string_length)
|
|
91
92
|
compiles(string.split, "sqlite")(compile_string_split)
|
|
92
93
|
compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
|
|
@@ -269,13 +270,16 @@ def register_user_defined_sql_functions() -> None:
|
|
|
269
270
|
|
|
270
271
|
_registered_function_creators["string_functions"] = create_string_functions
|
|
271
272
|
|
|
272
|
-
has_json_extension = functions_exist(["json_array_length"])
|
|
273
|
+
has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
|
|
273
274
|
if not has_json_extension:
|
|
274
275
|
|
|
275
276
|
def create_json_functions(conn):
|
|
276
277
|
conn.create_function(
|
|
277
278
|
"json_array_length", 1, py_json_array_length, deterministic=True
|
|
278
279
|
)
|
|
280
|
+
conn.create_function(
|
|
281
|
+
"json_array_contains", 3, py_json_array_contains, deterministic=True
|
|
282
|
+
)
|
|
279
283
|
|
|
280
284
|
_registered_function_creators["json_functions"] = create_json_functions
|
|
281
285
|
|
|
@@ -428,10 +432,22 @@ def py_json_array_length(arr):
|
|
|
428
432
|
return len(orjson.loads(arr))
|
|
429
433
|
|
|
430
434
|
|
|
435
|
+
def py_json_array_contains(arr, value, is_json):
|
|
436
|
+
if is_json:
|
|
437
|
+
value = orjson.loads(value)
|
|
438
|
+
return value in orjson.loads(arr)
|
|
439
|
+
|
|
440
|
+
|
|
431
441
|
def compile_array_length(element, compiler, **kwargs):
|
|
432
442
|
return compiler.process(func.json_array_length(*element.clauses.clauses), **kwargs)
|
|
433
443
|
|
|
434
444
|
|
|
445
|
+
def compile_array_contains(element, compiler, **kwargs):
|
|
446
|
+
return compiler.process(
|
|
447
|
+
func.json_array_contains(*element.clauses.clauses), **kwargs
|
|
448
|
+
)
|
|
449
|
+
|
|
450
|
+
|
|
435
451
|
def compile_string_length(element, compiler, **kwargs):
|
|
436
452
|
return compiler.process(func.length(*element.clauses.clauses), **kwargs)
|
|
437
453
|
|
datachain/sql/sqlite/types.py
CHANGED
|
@@ -31,6 +31,10 @@ def adapt_array(arr):
|
|
|
31
31
|
return orjson.dumps(arr).decode("utf-8")
|
|
32
32
|
|
|
33
33
|
|
|
34
|
+
def adapt_dict(dct):
|
|
35
|
+
return orjson.dumps(dct).decode("utf-8")
|
|
36
|
+
|
|
37
|
+
|
|
34
38
|
def convert_array(arr):
|
|
35
39
|
return orjson.loads(arr)
|
|
36
40
|
|
|
@@ -52,6 +56,7 @@ def adapt_np_generic(val):
|
|
|
52
56
|
|
|
53
57
|
def register_type_converters():
|
|
54
58
|
sqlite3.register_adapter(list, adapt_array)
|
|
59
|
+
sqlite3.register_adapter(dict, adapt_dict)
|
|
55
60
|
sqlite3.register_converter("ARRAY", convert_array)
|
|
56
61
|
if numpy_imported:
|
|
57
62
|
sqlite3.register_adapter(np.ndarray, adapt_np_array)
|
|
@@ -50,9 +50,9 @@ datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2kru
|
|
|
50
50
|
datachain/data_storage/sqlite.py,sha256=KJ8hI0Hrwv9eAA-nLUlw2AYCQxiAAZ12a-ftUBtroNQ,24545
|
|
51
51
|
datachain/data_storage/warehouse.py,sha256=ovdH9LmOWLfCrvf0UvXnrNC-CrdAjns3EmXEgFdz4KM,30824
|
|
52
52
|
datachain/diff/__init__.py,sha256=OapNRBsyGDOQHelefUEoXoFHRWCJuBnhvD0ibebKvBc,10486
|
|
53
|
-
datachain/func/__init__.py,sha256=
|
|
53
|
+
datachain/func/__init__.py,sha256=DDpkbK6Kg53ONVGZo8szaNEFJrcepCAG9ev27DC6_7E,1125
|
|
54
54
|
datachain/func/aggregate.py,sha256=7_IPrIwb2XSs3zG4iOr1eTvzn6kNVe2mkzvNzjusDHk,10942
|
|
55
|
-
datachain/func/array.py,sha256=
|
|
55
|
+
datachain/func/array.py,sha256=O784_uwmaP5CjZX4VSF4RmS8cmpaForQc8zASxHJB6A,6717
|
|
56
56
|
datachain/func/base.py,sha256=wA0sBQAVyN9LPxoo7Ox83peS0zUVnyuKxukwAcjGLfY,534
|
|
57
57
|
datachain/func/conditional.py,sha256=g46zwW-i87uA45zWJnPHtHaqr6qOXSg6xLb4p9W3Gtk,6400
|
|
58
58
|
datachain/func/func.py,sha256=PnwTRAiEJUus3e4NYdQ-hldqLzKS9hY0FjiyBMZhsSo,16183
|
|
@@ -120,22 +120,22 @@ datachain/sql/default/__init__.py,sha256=XQ2cEZpzWiABqjV-6yYHUBGI9vN_UHxbxZENESm
|
|
|
120
120
|
datachain/sql/default/base.py,sha256=QD-31C6JnyOXzogyDx90sUhm7QvgXIYpeHEASH84igU,628
|
|
121
121
|
datachain/sql/functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
122
122
|
datachain/sql/functions/aggregate.py,sha256=3AQdA8YHPFdtCEfwZKQXTT8SlQWdG9gD5PBtGN3Odqs,944
|
|
123
|
-
datachain/sql/functions/array.py,sha256=
|
|
123
|
+
datachain/sql/functions/array.py,sha256=ZXPHNQxG-t5tl7Gr0nubu4oc-9fBZmX5q03z9ljWg-I,1443
|
|
124
124
|
datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
|
|
125
125
|
datachain/sql/functions/numeric.py,sha256=BK2KCiPSgM2IveCq-9M_PG3CtPBlztaS9TTn1LGzyLs,1250
|
|
126
126
|
datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
|
|
127
127
|
datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
|
|
128
128
|
datachain/sql/functions/string.py,sha256=E-T9OIzUR-GKaLgjZsEtg5CJrY_sLf1lt1awTvY7w2w,1426
|
|
129
129
|
datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
|
|
130
|
-
datachain/sql/sqlite/base.py,sha256=
|
|
131
|
-
datachain/sql/sqlite/types.py,sha256=
|
|
130
|
+
datachain/sql/sqlite/base.py,sha256=Rfemu8pj7V0aWhWwryDghhnbiMFfQS5X9FCihGuplb8,19593
|
|
131
|
+
datachain/sql/sqlite/types.py,sha256=cH6oge2E_YWFy22wY-txPJH8gxoQFSpCthtZR8PZjpo,1849
|
|
132
132
|
datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
|
|
133
133
|
datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
|
|
134
134
|
datachain/toolkit/split.py,sha256=z3zRJNzjWrpPuRw-zgFbCOBKInyYxJew8ygrYQRQLNc,2930
|
|
135
135
|
datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
|
|
136
|
-
datachain-0.8.
|
|
137
|
-
datachain-0.8.
|
|
138
|
-
datachain-0.8.
|
|
139
|
-
datachain-0.8.
|
|
140
|
-
datachain-0.8.
|
|
141
|
-
datachain-0.8.
|
|
136
|
+
datachain-0.8.13.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
|
|
137
|
+
datachain-0.8.13.dist-info/METADATA,sha256=ugNjNgfRl-dnR4tKM3AqiRHEmPGFXSYjA-7Bl1BRgOA,10880
|
|
138
|
+
datachain-0.8.13.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
139
|
+
datachain-0.8.13.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
|
|
140
|
+
datachain-0.8.13.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
|
|
141
|
+
datachain-0.8.13.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|