datachain 0.8.12__py3-none-any.whl → 0.8.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

datachain/func/__init__.py CHANGED
@@ -15,7 +15,7 @@ from .aggregate import (
15
15
  row_number,
16
16
  sum,
17
17
  )
18
- from .array import cosine_distance, euclidean_distance, length, sip_hash_64
18
+ from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
19
19
  from .conditional import case, greatest, ifelse, isnone, least
20
20
  from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
21
21
  from .random import rand
@@ -34,6 +34,7 @@ __all__ = [
34
34
  "case",
35
35
  "collect",
36
36
  "concat",
37
+ "contains",
37
38
  "cosine_distance",
38
39
  "count",
39
40
  "dense_rank",
datachain/func/array.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Sequence
2
- from typing import Union
2
+ from typing import Any, Union
3
3
 
4
4
  from datachain.sql.functions import array
5
5
 
@@ -140,6 +140,44 @@ def length(arg: Union[str, Sequence, Func]) -> Func:
140
140
  return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
141
141
 
142
142
 
143
def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
    """
    Builds a function that tests whether the `elem` element occurs in `arr`.

    Args:
        arr (str | Sequence | Func): Array to search for the element.
            A string is treated as the name of an array column.
            A sequence is treated as a literal array of values.
            A Func is treated as a function returning an array.
        elem (Any): Element to look for in the array.

    Returns:
        Func: A Func object that represents the contains function. The result
            of the function is 1 when the element is present in the array,
            and 0 otherwise.

    Example:
        ```py
        dc.mutate(
            contains1=func.array.contains("signal.values", 3),
            contains2=func.array.contains([1, 2, 3, 4, 5], 7),
        )
        ```
    """
    # Exact type match: only plain `list` and `dict` elements are flagged as
    # JSON; subclasses of either are passed through with the flag unset.
    elem_is_json = type(elem) in (list, dict)

    def inner(arg):
        return array.contains(arg, elem, elem_is_json)

    # Column names and nested functions go in `cols`; literal sequences go
    # in `args`.
    by_column = isinstance(arr, (str, Func))
    return Func(
        "contains",
        inner=inner,
        cols=[arr] if by_column else None,
        args=None if by_column else [arr],
        result_type=int,
    )
179
+
180
+
143
181
  def sip_hash_64(arg: Union[str, Sequence]) -> Func:
144
182
  """
145
183
  Computes the SipHash-64 hash of the array.
datachain/sql/functions/array.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from sqlalchemy.sql.functions import GenericFunction
2
2
 
3
- from datachain.sql.types import Float, Int64
3
+ from datachain.sql.types import Boolean, Float, Int64
4
4
  from datachain.sql.utils import compiler_not_implemented
5
5
 
6
6
 
@@ -37,6 +37,17 @@ class length(GenericFunction): # noqa: N801
37
37
  inherit_cache = True
38
38
 
39
39
 
40
class contains(GenericFunction):  # noqa: N801
    """
    SQL function that checks whether an element is present in an array.
    """

    # Registered as `array.contains`; the expression is typed as Boolean.
    name = "contains"
    package = "array"
    type = Boolean()
    inherit_cache = True
49
+
50
+
40
51
  class sip_hash_64(GenericFunction): # noqa: N801
41
52
  """
42
53
  Computes the SipHash-64 hash of the array.
@@ -51,4 +62,5 @@ class sip_hash_64(GenericFunction): # noqa: N801
51
62
  compiler_not_implemented(cosine_distance)
52
63
  compiler_not_implemented(euclidean_distance)
53
64
  compiler_not_implemented(length)
65
+ compiler_not_implemented(contains)
54
66
  compiler_not_implemented(sip_hash_64)
datachain/sql/sqlite/base.py CHANGED
@@ -87,6 +87,7 @@ def setup():
87
87
  compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
88
88
  compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
89
89
  compiles(array.length, "sqlite")(compile_array_length)
90
+ compiles(array.contains, "sqlite")(compile_array_contains)
90
91
  compiles(string.length, "sqlite")(compile_string_length)
91
92
  compiles(string.split, "sqlite")(compile_string_split)
92
93
  compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
@@ -269,13 +270,16 @@ def register_user_defined_sql_functions() -> None:
269
270
 
270
271
  _registered_function_creators["string_functions"] = create_string_functions
271
272
 
272
- has_json_extension = functions_exist(["json_array_length"])
273
+ has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
273
274
  if not has_json_extension:
274
275
 
275
276
  def create_json_functions(conn):
276
277
  conn.create_function(
277
278
  "json_array_length", 1, py_json_array_length, deterministic=True
278
279
  )
280
+ conn.create_function(
281
+ "json_array_contains", 3, py_json_array_contains, deterministic=True
282
+ )
279
283
 
280
284
  _registered_function_creators["json_functions"] = create_json_functions
281
285
 
@@ -428,10 +432,22 @@ def py_json_array_length(arr):
428
432
  return len(orjson.loads(arr))
429
433
 
430
434
 
435
def py_json_array_contains(arr, value, is_json):
    """
    Report whether `value` is an element of the JSON-encoded array `arr`.

    When `is_json` is truthy, `value` is itself JSON text and is decoded
    before the membership test; otherwise it is compared as given.
    """
    needle = orjson.loads(value) if is_json else value
    haystack = orjson.loads(arr)
    return needle in haystack
439
+
440
+
431
441
  def compile_array_length(element, compiler, **kwargs):
432
442
  return compiler.process(func.json_array_length(*element.clauses.clauses), **kwargs)
433
443
 
434
444
 
445
def compile_array_contains(element, compiler, **kwargs):
    """Render the element as a `json_array_contains` call over its clauses."""
    operands = element.clauses.clauses
    json_call = func.json_array_contains(*operands)
    return compiler.process(json_call, **kwargs)
449
+
450
+
435
451
  def compile_string_length(element, compiler, **kwargs):
436
452
  return compiler.process(func.length(*element.clauses.clauses), **kwargs)
437
453
 
datachain/sql/sqlite/types.py CHANGED
@@ -31,6 +31,10 @@ def adapt_array(arr):
31
31
  return orjson.dumps(arr).decode("utf-8")
32
32
 
33
33
 
34
def adapt_dict(dct):
    """Serialize `dct` with orjson and return the result as a UTF-8 string."""
    encoded = orjson.dumps(dct)
    return encoded.decode("utf-8")
36
+
37
+
34
38
  def convert_array(arr):
35
39
  return orjson.loads(arr)
36
40
 
@@ -52,6 +56,7 @@ def adapt_np_generic(val):
52
56
 
53
57
  def register_type_converters():
54
58
  sqlite3.register_adapter(list, adapt_array)
59
+ sqlite3.register_adapter(dict, adapt_dict)
55
60
  sqlite3.register_converter("ARRAY", convert_array)
56
61
  if numpy_imported:
57
62
  sqlite3.register_adapter(np.ndarray, adapt_np_array)
{datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: datachain
3
- Version: 0.8.12
3
+ Version: 0.8.13
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License: Apache-2.0
{datachain-0.8.12.dist-info → datachain-0.8.13.dist-info}/RECORD CHANGED
@@ -50,9 +50,9 @@ datachain/data_storage/serializer.py,sha256=6G2YtOFqqDzJf1KbvZraKGXl2XHZyVml2kru
50
50
  datachain/data_storage/sqlite.py,sha256=KJ8hI0Hrwv9eAA-nLUlw2AYCQxiAAZ12a-ftUBtroNQ,24545
51
51
  datachain/data_storage/warehouse.py,sha256=ovdH9LmOWLfCrvf0UvXnrNC-CrdAjns3EmXEgFdz4KM,30824
52
52
  datachain/diff/__init__.py,sha256=OapNRBsyGDOQHelefUEoXoFHRWCJuBnhvD0ibebKvBc,10486
53
- datachain/func/__init__.py,sha256=qaSjakSaTsRtnU7Hcb4lJk71tbwk7M0oWmjRqXExCLA,1099
53
+ datachain/func/__init__.py,sha256=DDpkbK6Kg53ONVGZo8szaNEFJrcepCAG9ev27DC6_7E,1125
54
54
  datachain/func/aggregate.py,sha256=7_IPrIwb2XSs3zG4iOr1eTvzn6kNVe2mkzvNzjusDHk,10942
55
- datachain/func/array.py,sha256=zHDNWuWLA7HVa9FEvQeHhVi00_xqenyleTqcLwkXWBI,5477
55
+ datachain/func/array.py,sha256=O784_uwmaP5CjZX4VSF4RmS8cmpaForQc8zASxHJB6A,6717
56
56
  datachain/func/base.py,sha256=wA0sBQAVyN9LPxoo7Ox83peS0zUVnyuKxukwAcjGLfY,534
57
57
  datachain/func/conditional.py,sha256=g46zwW-i87uA45zWJnPHtHaqr6qOXSg6xLb4p9W3Gtk,6400
58
58
  datachain/func/func.py,sha256=PnwTRAiEJUus3e4NYdQ-hldqLzKS9hY0FjiyBMZhsSo,16183
@@ -120,22 +120,22 @@ datachain/sql/default/__init__.py,sha256=XQ2cEZpzWiABqjV-6yYHUBGI9vN_UHxbxZENESm
120
120
  datachain/sql/default/base.py,sha256=QD-31C6JnyOXzogyDx90sUhm7QvgXIYpeHEASH84igU,628
121
121
  datachain/sql/functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
122
122
  datachain/sql/functions/aggregate.py,sha256=3AQdA8YHPFdtCEfwZKQXTT8SlQWdG9gD5PBtGN3Odqs,944
123
- datachain/sql/functions/array.py,sha256=Zq59CaMHf_hFapU4kxvy2mwteH344k5Wksxja4MfBks,1204
123
+ datachain/sql/functions/array.py,sha256=ZXPHNQxG-t5tl7Gr0nubu4oc-9fBZmX5q03z9ljWg-I,1443
124
124
  datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
125
125
  datachain/sql/functions/numeric.py,sha256=BK2KCiPSgM2IveCq-9M_PG3CtPBlztaS9TTn1LGzyLs,1250
126
126
  datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
127
127
  datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
128
128
  datachain/sql/functions/string.py,sha256=E-T9OIzUR-GKaLgjZsEtg5CJrY_sLf1lt1awTvY7w2w,1426
129
129
  datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
130
- datachain/sql/sqlite/base.py,sha256=bPrYfj2ZF9hFZFs0chgH7J5l_tdXI4VMZMgkuBjf7Ng,19070
131
- datachain/sql/sqlite/types.py,sha256=lPXS1XbkmUtlkkiRxy_A_UzsgpPv2VSkXYOD4zIHM4w,1734
130
+ datachain/sql/sqlite/base.py,sha256=Rfemu8pj7V0aWhWwryDghhnbiMFfQS5X9FCihGuplb8,19593
131
+ datachain/sql/sqlite/types.py,sha256=cH6oge2E_YWFy22wY-txPJH8gxoQFSpCthtZR8PZjpo,1849
132
132
  datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
133
133
  datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
134
134
  datachain/toolkit/split.py,sha256=z3zRJNzjWrpPuRw-zgFbCOBKInyYxJew8ygrYQRQLNc,2930
135
135
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
136
- datachain-0.8.12.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
137
- datachain-0.8.12.dist-info/METADATA,sha256=C1vaFTVw44GIVe32CcfLthfCi5nbbqTgS7HL61iSFGg,10880
138
- datachain-0.8.12.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
139
- datachain-0.8.12.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
140
- datachain-0.8.12.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
141
- datachain-0.8.12.dist-info/RECORD,,
136
+ datachain-0.8.13.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
137
+ datachain-0.8.13.dist-info/METADATA,sha256=ugNjNgfRl-dnR4tKM3AqiRHEmPGFXSYjA-7Bl1BRgOA,10880
138
+ datachain-0.8.13.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
139
+ datachain-0.8.13.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
140
+ datachain-0.8.13.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
141
+ datachain-0.8.13.dist-info/RECORD,,